initial commit
commit a82bbc593e
129 changed files with 33981 additions and 0 deletions
0  utils/__init__.py  Normal file
309  utils/basic.py  Normal file
@@ -0,0 +1,309 @@
import numpy as np
import io
import os
import json
import logging
import random
import time
from collections import defaultdict, deque
import datetime
from pathlib import Path
from typing import List, Union
import itertools

import torch
import torch.distributed as dist
from .dist import is_dist_avail_and_initialized


logger = logging.getLogger(__name__)


def average_dicts(dicts):
    # media = list(dicts.keys())
    # keys = [list(d.keys()) for d in dicts.values]
    # keys = list(itertools.chain.from_iterable(keys))
    # keys = list(set(keys))
    res = {}
    counter = {}
    for medium, medium_dict in dicts.items():
        for loss_key, loss_value in medium_dict.items():
            if loss_key not in res:
                res[loss_key] = loss_value
                counter[loss_key] = 1
            else:
                res[loss_key] += loss_value
                counter[loss_key] += 1
    for k in res:
        res[k] = res[k] / counter[k]
    return res

class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total],
                         dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)

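# Usage sketch (editor's illustration, not part of the original file): the
# windowed statistics (`median`, `avg`) look at the last `window` updates
# only, while `global_avg` covers the whole run. Values are hypothetical.
#
#     v = SmoothedValue(window=3)
#     for x in [1.0, 2.0, 3.0, 4.0]:
#         v.update(x)
#     v.median      # 3.0  -- median over the last 3 values only
#     v.global_avg  # 2.5  -- mean over all 4 updates
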
class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            if meter.count == 0:  # skip empty meter
                loss_str.append(
                    "{}: {}".format(name, "No data")
                )
            else:
                loss_str.append(
                    "{}: {}".format(name, str(meter))
                )
        return self.delimiter.join(loss_str)

    def global_avg(self):
        loss_str = []
        for name, meter in self.meters.items():
            if meter.count == 0:
                loss_str.append(
                    "{}: {}".format(name, "No data")
                )
            else:
                loss_str.append(
                    "{}: {:.4f}".format(name, meter.global_avg)
                )
        return self.delimiter.join(loss_str)

    def get_global_avg_dict(self, prefix=""):
        """include a separator (e.g., `/`, or "_") at the end of `prefix`"""
        d = {f"{prefix}{k}": m.global_avg if m.count > 0 else 0. for k, m in self.meters.items()}
        return d

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, log_freq, header=None):
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}\n',
            '{meters}\n',
            'time: {time}',
            'data: {data}'
        ]
        if torch.cuda.is_available():
            log_msg.append('max mem: {memory:.0f} res mem: {res_mem:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % log_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    logger.info(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB,
                        res_mem=torch.cuda.max_memory_reserved() / MB,
                    ))
                else:
                    logger.info(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        logger.info('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))

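# Usage sketch (editor's illustration; `train_loader` and `model_step` are
# hypothetical): wrap any sized iterable to get periodic logs with ETA and
# smoothed iteration/data timings.
#
#     metric_logger = MetricLogger(delimiter="  ")
#     for batch in metric_logger.log_every(train_loader, log_freq=50, header="Epoch [0]"):
#         loss = model_step(batch)
#         metric_logger.update(loss=loss)
#     metric_logger.synchronize_between_processes()
#     print(metric_logger.global_avg())
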
class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


def compute_acc(logits, label, reduction='mean'):
    ret = (torch.argmax(logits, dim=1) == label).float()
    if reduction == 'none':
        return ret.detach()
    elif reduction == 'mean':
        return ret.mean().item()


def compute_n_params(model, return_str=True):
    tot = 0
    for p in model.parameters():
        w = 1
        for x in p.shape:
            w *= x
        tot += w
    if return_str:
        if tot >= 1e6:
            return '{:.1f}M'.format(tot / 1e6)
        else:
            return '{:.1f}K'.format(tot / 1e3)
    else:
        return tot


def setup_seed(seed):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)


def remove_files_if_exist(file_paths):
    for fp in file_paths:
        if os.path.isfile(fp):
            os.remove(fp)


def save_json(data, filename, save_pretty=False, sort_keys=False):
    with open(filename, "w") as f:
        if save_pretty:
            f.write(json.dumps(data, indent=4, sort_keys=sort_keys))
        else:
            json.dump(data, f)


def load_json(filename):
    with open(filename, "r") as f:
        return json.load(f)


def flat_list_of_lists(l):
    """flatten a list of lists [[1,2], [3,4]] to [1,2,3,4]"""
    return [item for sublist in l for item in sublist]


def find_files_by_suffix_recursively(root: str, suffix: Union[str, List[str]]):
    """
    Args:
        root: path to the directory from which to search for files.
        suffix: a single suffix string, or multiple suffixes
            when given as List[str].
        Example 1: a plain suffix, e.g. `.jpg` or [`.jpg`, `.png`].
        Example 2: a `*` inside `suffix`, e.g. `START*.jpg`.
    """
    if isinstance(suffix, str):
        suffix = [suffix, ]
    filepaths = flat_list_of_lists(
        [list(Path(root).rglob(f"*{e}")) for e in suffix])
    return filepaths

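# Usage sketch (editor's illustration; paths are hypothetical): each suffix is
# matched via Path.rglob(f"*{suffix}"), so plain extensions and glob
# fragments both work.
#
#     images = find_files_by_suffix_recursively("data/images", [".jpg", ".png"])
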
def match_key_and_shape(state_dict1, state_dict2):
    keys1 = set(state_dict1.keys())
    keys2 = set(state_dict2.keys())
    print(f"keys1 - keys2: {keys1 - keys2}")
    print(f"keys2 - keys1: {keys2 - keys1}")

    mismatch = 0
    # compare shapes only on shared keys, so keys missing from state_dict2
    # do not raise a KeyError
    for k in list(keys1 & keys2):
        if state_dict1[k].shape != state_dict2[k].shape:
            print(
                f"k={k}, state_dict1[k].shape={state_dict1[k].shape}, state_dict2[k].shape={state_dict2[k].shape}")
            mismatch += 1
    print(f"mismatch {mismatch}")


def merge_dicts(list_dicts):
    merged_dict = list_dicts[0].copy()
    for i in range(1, len(list_dicts)):
        merged_dict.update(list_dicts[i])
    return merged_dict
25  utils/dist.py  Normal file
@@ -0,0 +1,25 @@
import torch.distributed as dist


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0
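# Usage sketch (editor's illustration): these helpers fall back to
# world_size=1 / rank=0 when torch.distributed is unavailable or not
# initialized, so the same call sites run single-process. Paths are
# hypothetical.
#
#     if is_main_process():
#         torch.save(model.state_dict(), "ckpt.pth")  # I/O on rank 0 only
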
149  utils/easydict.py  Normal file
@@ -0,0 +1,149 @@
class EasyDict(dict):
    """
    Get attributes

    >>> d = EasyDict({'foo':3})
    >>> d['foo']
    3
    >>> d.foo
    3
    >>> d.bar
    Traceback (most recent call last):
    ...
    AttributeError: 'EasyDict' object has no attribute 'bar'

    Works recursively

    >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}})
    >>> isinstance(d.bar, dict)
    True
    >>> d.bar.x
    1

    Bullet-proof

    >>> EasyDict({})
    {}
    >>> EasyDict(d={})
    {}
    >>> EasyDict(None)
    {}
    >>> d = {'a': 1}
    >>> EasyDict(**d)
    {'a': 1}

    Set attributes

    >>> d = EasyDict()
    >>> d.foo = 3
    >>> d.foo
    3
    >>> d.bar = {'prop': 'value'}
    >>> d.bar.prop
    'value'
    >>> d
    {'foo': 3, 'bar': {'prop': 'value'}}
    >>> d.bar.prop = 'newer'
    >>> d.bar.prop
    'newer'


    Values extraction

    >>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]})
    >>> isinstance(d.bar, list)
    True
    >>> from operator import attrgetter
    >>> map(attrgetter('x'), d.bar)
    [1, 3]
    >>> map(attrgetter('y'), d.bar)
    [2, 4]
    >>> d = EasyDict()
    >>> d.keys()
    []
    >>> d = EasyDict(foo=3, bar=dict(x=1, y=2))
    >>> d.foo
    3
    >>> d.bar.x
    1

    Still like a dict though

    >>> o = EasyDict({'clean':True})
    >>> o.items()
    [('clean', True)]

    And like a class

    >>> class Flower(EasyDict):
    ...     power = 1
    ...
    >>> f = Flower()
    >>> f.power
    1
    >>> f = Flower({'height': 12})
    >>> f.height
    12
    >>> f['power']
    1
    >>> sorted(f.keys())
    ['height', 'power']

    update and pop items
    >>> d = EasyDict(a=1, b='2')
    >>> e = EasyDict(c=3.0, a=9.0)
    >>> d.update(e)
    >>> d.c
    3.0
    >>> d['c']
    3.0
    >>> d.get('c')
    3.0
    >>> d.update(a=4, b=4)
    >>> d.b
    4
    >>> d.pop('a')
    4
    >>> d.a
    Traceback (most recent call last):
    ...
    AttributeError: 'EasyDict' object has no attribute 'a'
    """

    def __init__(self, d=None, **kwargs):
        if d is None:
            d = {}
        if kwargs:
            d.update(**kwargs)
        for k, v in d.items():
            setattr(self, k, v)
        # Class attributes
        for k in self.__class__.__dict__.keys():
            if not (k.startswith("__") and k.endswith("__")) and not k in ("update", "pop"):
                setattr(self, k, getattr(self, k))

    def __setattr__(self, name, value):
        if isinstance(value, (list, tuple)):
            value = [self.__class__(x) if isinstance(x, dict) else x for x in value]
        elif isinstance(value, dict) and not isinstance(value, self.__class__):
            value = self.__class__(value)
        super(EasyDict, self).__setattr__(name, value)
        super(EasyDict, self).__setitem__(name, value)

    __setitem__ = __setattr__

    def update(self, e=None, **f):
        d = e or dict()
        d.update(f)
        for k in d:
            setattr(self, k, d[k])

    def pop(self, k, d=None):
        if hasattr(self, k):
            delattr(self, k)
        return super(EasyDict, self).pop(k, d)


if __name__ == "__main__":
    import doctest
    doctest.testmod()
154  utils/init.py  Normal file
@@ -0,0 +1,154 @@
import os
import torch
import random
import pyhocon
import datetime
import json
import subprocess
import itertools
import glob
import glog as log
import sys
import re
from os import path as osp
import numpy as np


# def load_runner(config, tokenizer, vocab_size):
#     if config['task'] == 'avsd':
#         return AVSDRunner(config, tokenizer, vocab_size)
#     if config['task'] == 'simmc':
#         return SIMMCRunner(config, tokenizer, vocab_size)
#     elif config['task'] == 'nextqa':
#         return NEXTQARunner(config, tokenizer, vocab_size)
#     else:
#         raise ValueError


def set_random_seed(random_seed):
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    random.seed(random_seed)
    np.random.seed(random_seed)


def copy_file_to_log(log_dir):
    dirs_to_cp = ['.', 'config', 'datasets', 'runners', 'models']
    files_to_cp = ['*.py', '*.json', '*.sh', '*.conf']
    for dir_name in dirs_to_cp:
        dir_name = osp.join(log_dir, 'code', dir_name)
        if not osp.exists(dir_name):
            os.makedirs(dir_name)
    for dir_name, file_name in itertools.product(dirs_to_cp, files_to_cp):
        filename = osp.join(dir_name, file_name)
        if len(glob.glob(filename)) > 0:
            os.system(f'cp {filename} {osp.join(log_dir, "code", dir_name)}')
    log.info(f'Files copied to {osp.join(log_dir, "code")}')


def set_log_file(fname, file_only=False):
    # if fname already exists, find all log files under the log dir,
    # and name the current log file with a new number
    if osp.exists(fname):
        prefix, suffix = osp.splitext(fname)
        log_files = glob.glob(prefix + '*' + suffix)
        count = 0
        for log_file in log_files:
            num = re.search(r'(\d+)', log_file)
            if num is not None:
                num = int(num.group(0))
                count = max(num, count)
        fname = fname.replace(suffix, str(count + 1) + suffix)
    # set log file
    # simple tricks for duplicating the logging destination in the logging module, such as:
    #     logging.getLogger().addHandler(logging.FileHandler(filename))
    # do NOT work well here, because Python traceback messages (which bypass the logging
    # module) are not sent to the file. The following solution (copied from
    # https://stackoverflow.com/questions/616645) is a little more complicated but
    # simulates exactly the "tee" command in a linux shell, and it redirects everything.
    if file_only:
        # we only output messages to file, and stdout/stderr receives nothing.
        # this feature is designed for executing the script via ssh:
        # since ssh has a windowing kind of flow control, i.e., if the controller does not read data from a
        # ssh channel and its buffer fills up, the execution machine will not be able to write anything into the
        # channel and the process will be set to sleeping (S) status until someone reads all data from the channel.
        # this is not desired since we do not want to read stdout/stderr from the controller machine.
        # so, here we use a simple solution: disable output to stdout/stderr and only output messages to log file.
        log.logger.handlers[0].stream = log.handler.stream = sys.stdout = sys.stderr = f = open(fname, 'w', buffering=1)
    else:
        # we output messages to both file and stdout/stderr
        tee = subprocess.Popen(['tee', fname], stdin=subprocess.PIPE)
        os.dup2(tee.stdin.fileno(), sys.stdout.fileno())
        os.dup2(tee.stdin.fileno(), sys.stderr.fileno())


def set_training_steps(config, num_samples):
    if config['parallel'] and config['dp_type'] == 'dp':
        config['num_iter_per_epoch'] = int(np.ceil(num_samples / config['batch_size']))
    else:
        config['num_iter_per_epoch'] = int(np.ceil(num_samples / (config['batch_size'] * config['num_gpus'])))
    if 'train_steps' not in config:
        config['train_steps'] = config['num_iter_per_epoch'] * config['num_epochs']
    if 'warmup_steps' not in config:
        config['warmup_steps'] = int(config['train_steps'] * config['warmup_ratio'])
    return config


def initialize_from_env(model, mode, stage, eval_dir, tag=''):
    if mode in ['train', 'debug']:
        path_config = f"config/{model}_{stage}.conf"
        config = pyhocon.ConfigFactory.parse_file(path_config)[stage]
    else:
        path_config = os.path.join(eval_dir, f'{model}_{stage}.conf')
        config = pyhocon.ConfigFactory.parse_file(path_config)[stage]
        config['log_dir'] = eval_dir

    if "CUDA_VISIBLE_DEVICES" in os.environ:
        config['num_gpus'] = len(os.environ["CUDA_VISIBLE_DEVICES"].split(','))
        # multi-gpu setting
        if config['num_gpus'] > 1:
            os.environ['MASTER_ADDR'] = 'localhost'
            os.environ["MASTER_PORT"] = str(config['master_port'])
    else:
        config['num_gpus'] = 1

    model += '-' + config.llm_name.replace('/', '_')

    if mode == 'debug':
        model += '_debug'

    if tag:
        model += '-' + tag
    if mode != 'generate':
        config["log_dir"] = os.path.join(config["log_dir"], model)
        if not os.path.exists(config["log_dir"]):
            os.makedirs(config["log_dir"])
        # copy the config file
        os.system(f'cp {path_config} {config["log_dir"]}')

    config['timestamp'] = datetime.datetime.now().strftime('%m%d-%H%M%S')

    config['expert_config'] = config['bert_config_{}'.format(config['expert_size'])]
    config['expert_config_json'] = json.load(open(config['expert_config'], 'r'))

    config['beit_config_json'] = json.load(open(config['beit_config'], 'r'))

    config['model'] = model
    config['stage'] = stage
    config['loss_dict'] = {k: v for k, v in zip(config['loss_names'], config['loss_weights'])}

    return config


# NOTE: this second definition shadows the set_training_steps defined above;
# the multi-dataset version below is the one actually in effect.
def set_training_steps(config, num_samples, batch_sizes):
    config['num_iter_per_epoch'] = sum([int(np.ceil(num_sample / (bs * config['accum_grad_every'] * config['num_gpus']))) for num_sample, bs in zip(num_samples, batch_sizes)])
    if 'num_training_steps' not in config:
        config['num_training_steps'] = config['num_iter_per_epoch'] * config['epochs']
    if 'num_warmup_steps' not in config:
        config['num_warmup_steps'] = int(config['num_iter_per_epoch'] * config.get('warmup_epochs', 1.0))

    # config['num_warmup_steps'] = int(config['num_training_steps'] * config['warmup_ratio'])
    return config
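# Worked sketch (editor's illustration; all numbers are hypothetical): with
# num_samples=[10000], batch_sizes=[32], accum_grad_every=1 and num_gpus=2,
# set_training_steps gives num_iter_per_epoch = ceil(10000 / (32 * 1 * 2)) = 157;
# with epochs=10 that yields num_training_steps = 1570 and, with
# warmup_epochs=1.0, num_warmup_steps = 157.
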
286  utils/logger.py  Normal file
@@ -0,0 +1,286 @@
# from MMF: https://github.com/facebookresearch/mmf/blob/master/mmf/utils/logger.py
# Copyright (c) Facebook, Inc. and its affiliates.

import functools
import logging
import os
import sys
import time
import wandb

from .dist import get_rank, is_main_process
from termcolor import colored


def log_dict_to_wandb(log_dict, step, prefix=""):
    """include a separator `/` at the end of `prefix`"""
    if not is_main_process():
        return

    log_dict = {f"{prefix}{k}": v for k, v in log_dict.items()}
    wandb.log(log_dict, step)


def setup_wandb(config):
    if not (config.wandb_enabled and is_main_process()):
        return

    run = wandb.init(
        config=config,
        project=config.wandb_project,
        # entity=config.wandb.entity,
        mode=config.wandb_mode,
        # name=os.path.basename(config.output_dir),
        reinit=True
    )
    wandb.define_metric('train/webvid/step')
    wandb.define_metric('train/webvid/*', 'train/webvid/step')

    wandb.define_metric('train/cc3m/step')
    wandb.define_metric('train/cc3m/*', 'train/cc3m/step')

    wandb.define_metric('train/other/step')
    wandb.define_metric('train/other/*', 'train/other/step')

    wandb.define_metric('val/msrvtt/step')
    wandb.define_metric('val/msrvtt/*', 'val/msrvtt/step')

    wandb.define_metric('train/champagne/step')
    wandb.define_metric('train/champagne/*', 'train/champagne/step')

    wandb.define_metric('train/visdial/step')
    wandb.define_metric('train/visdial/*', 'train/visdial/step')

    wandb.define_metric('train/avsd/step')
    wandb.define_metric('train/avsd/*', 'train/avsd/step')

    wandb.define_metric('train/nextqa/step')
    wandb.define_metric('train/nextqa/*', 'train/nextqa/step')

    return run


def setup_output_folder(save_dir: str = "./", folder_only: bool = False):
    """Sets up and returns the output file where the logs will be placed
    based on the configuration passed. Usually "save_dir/logs/log_<timestamp>.txt".
    If env.log_dir is passed, logs will be directly saved in this folder.

    Args:
        save_dir (str, optional): Base directory for logs. Defaults to "./";
            upstream MMF resolves this from its env config instead.
        folder_only (bool, optional): If folder should be returned and not the file.
            Defaults to False.

    Returns:
        str: folder or file path depending on folder_only flag
    """
    log_filename = "train_"
    log_filename += time.strftime("%Y_%m_%dT%H_%M_%S")
    log_filename += ".log"

    log_folder = os.path.join(save_dir, "logs")

    if not os.path.exists(log_folder):
        os.makedirs(log_folder)

    if folder_only:
        return log_folder

    log_filename = os.path.join(log_folder, log_filename)

    return log_filename


def setup_logger(
    output: str = None,
    color: bool = True,
    name: str = "mmf",
    disable: bool = False,
    clear_handlers=True,
    *args,
    **kwargs,
):
    """
    Initialize the MMF logger and set its verbosity level to "INFO".
    Outside libraries shouldn't call this in case they have set their
    own logging handlers and setup. If they do, and don't want to
    clear handlers, pass the clear_handlers option.
    The initial version of this function was taken from D2 and adapted
    for MMF.

    Args:
        output (str): a file name or a directory to save log.
            If it ends with ".txt" or ".log", it is assumed to be a file name.
            Default: saved to file <save_dir/logs/log_[timestamp].txt>
        color (bool): If false, won't log colored logs. Default: true
        name (str): the root module name of this logger. Defaults to "mmf".
        disable: do not use
        clear_handlers (bool): If false, won't clear existing handlers.

    Returns:
        logging.Logger: a logger
    """
    if disable:
        return None
    logger = logging.getLogger(name)
    logger.propagate = False

    logging.captureWarnings(True)
    warnings_logger = logging.getLogger("py.warnings")

    plain_formatter = logging.Formatter(
        "%(asctime)s | %(levelname)s | %(name)s : %(message)s",
        datefmt="%Y-%m-%dT%H:%M:%S",
    )

    distributed_rank = get_rank()
    handlers = []

    logging_level = logging.INFO
    # logging_level = logging.DEBUG

    if distributed_rank == 0:
        logger.setLevel(logging_level)
        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(logging_level)
        if color:
            formatter = ColorfulFormatter(
                colored("%(asctime)s | %(name)s: ", "green") + "%(message)s",
                datefmt="%Y-%m-%dT%H:%M:%S",
            )
        else:
            formatter = plain_formatter
        ch.setFormatter(formatter)
        logger.addHandler(ch)
        warnings_logger.addHandler(ch)
        handlers.append(ch)

    # file logging: all workers
    if output is None:
        output = setup_output_folder()

    if output is not None:
        if output.endswith(".txt") or output.endswith(".log"):
            filename = output
        else:
            filename = os.path.join(output, "train.log")
        if distributed_rank > 0:
            filename = filename + f".rank{distributed_rank}"
        os.makedirs(os.path.dirname(filename), exist_ok=True)

        fh = logging.StreamHandler(_cached_log_stream(filename))
        fh.setLevel(logging_level)
        fh.setFormatter(plain_formatter)
        logger.addHandler(fh)
        warnings_logger.addHandler(fh)
        handlers.append(fh)

        # Slurm/FB output, only log the main process
        # save_dir = get_mmf_env(key="save_dir")
        if "train.log" not in filename and distributed_rank == 0:
            filename = os.path.join(output, "train.log")
            sh = logging.StreamHandler(_cached_log_stream(filename))
            sh.setLevel(logging_level)
            sh.setFormatter(plain_formatter)
            logger.addHandler(sh)
            warnings_logger.addHandler(sh)
            handlers.append(sh)

        logger.info(f"Logging to: {filename}")

    # Remove existing handlers to add MMF specific handlers
    if clear_handlers:
        for handler in logging.root.handlers[:]:
            logging.root.removeHandler(handler)
    # Now, add our handlers.
    logging.basicConfig(level=logging_level, handlers=handlers)

    return logger

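# Usage sketch (editor's illustration; the output path is hypothetical):
#
#     logger = setup_logger(output="save/exp1/train.log", name="mmf")
#     logger.info("starting run")  # colored stdout on rank 0; every rank also
#                                  # writes to its train.log(.rankN) file
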
def setup_very_basic_config(color=True):
    plain_formatter = logging.Formatter(
        "%(asctime)s | %(levelname)s | %(name)s : %(message)s",
        datefmt="%Y-%m-%dT%H:%M:%S",
    )
    ch = logging.StreamHandler(stream=sys.stdout)
    ch.setLevel(logging.INFO)
    if color:
        formatter = ColorfulFormatter(
            colored("%(asctime)s | %(name)s: ", "green") + "%(message)s",
            datefmt="%Y-%m-%dT%H:%M:%S",
        )
    else:
        formatter = plain_formatter
    ch.setFormatter(formatter)
    # Setup a minimal configuration for logging in case something tries to
    # log a message even before logging is setup by MMF.
    logging.basicConfig(level=logging.INFO, handlers=[ch])


# cache the opened file object, so that different calls to `setup_logger`
# with the same file name can safely write to the same file.
@functools.lru_cache(maxsize=None)
def _cached_log_stream(filename):
    return open(filename, "a")


# ColorfulFormatter is adopted from Detectron2 and adapted for MMF
class ColorfulFormatter(logging.Formatter):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def formatMessage(self, record):
        log = super().formatMessage(record)
        if record.levelno == logging.WARNING:
            prefix = colored("WARNING", "red", attrs=["blink"])
        elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
            prefix = colored("ERROR", "red", attrs=["blink", "underline"])
        else:
            return log
        return prefix + " " + log


class TensorboardLogger:
    def __init__(self, log_folder="./logs", iteration=0):
        # importing here surfaces a clear warning if tensorboard is missing
        from torch.utils.tensorboard import SummaryWriter

        self.summary_writer = None
        self._is_master = is_main_process()
        # self.timer = Timer()
        self.log_folder = log_folder

        if self._is_master:
            # current_time = self.timer.get_time_hhmmss(None, format=self.time_format)
            current_time = time.strftime("%Y-%m-%dT%H:%M:%S")
            tensorboard_folder = os.path.join(
                self.log_folder, f"tensorboard_{current_time}"
            )
            self.summary_writer = SummaryWriter(tensorboard_folder)

    def __del__(self):
        if getattr(self, "summary_writer", None) is not None:
            self.summary_writer.close()

    def _should_log_tensorboard(self):
        if self.summary_writer is None or not self._is_master:
            return False
        else:
            return True

    def add_scalar(self, key, value, iteration):
        if not self._should_log_tensorboard():
            return

        self.summary_writer.add_scalar(key, value, iteration)

    def add_scalars(self, scalar_dict, iteration):
        if not self._should_log_tensorboard():
            return

        for key, val in scalar_dict.items():
            self.summary_writer.add_scalar(key, val, iteration)

    def add_histogram_for_model(self, model, iteration):
        if not self._should_log_tensorboard():
            return

        for name, param in model.named_parameters():
            np_param = param.clone().cpu().data.numpy()
            self.summary_writer.add_histogram(name, np_param, iteration)
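# Usage sketch (editor's illustration; folder and values are hypothetical):
# only the main process creates a SummaryWriter; other ranks silently no-op.
#
#     tb = TensorboardLogger(log_folder="save/exp1", iteration=0)
#     tb.add_scalar("train/loss", 0.42, iteration=100)
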
174  utils/metrcis.py  Normal file
@@ -0,0 +1,174 @@
"""
|
||||
A Metric observes output of certain model, for example, in form of logits or
|
||||
scores, and accumulates a particular metric with reference to some provided
|
||||
targets. In context of VisDial, we use Recall (@ 1, 5, 10), Mean Rank, Mean
|
||||
Reciprocal Rank (MRR) and Normalized Discounted Cumulative Gain (NDCG).
|
||||
|
||||
Each ``Metric`` must atleast implement three methods:
|
||||
- ``observe``, update accumulated metric with currently observed outputs
|
||||
and targets.
|
||||
- ``retrieve`` to return the accumulated metric., an optionally reset
|
||||
internally accumulated metric (this is commonly done between two epochs
|
||||
after validation).
|
||||
- ``reset`` to explicitly reset the internally accumulated metric.
|
||||
|
||||
Caveat, if you wish to implement your own class of Metric, make sure you call
|
||||
``detach`` on output tensors (like logits), else it will cause memory leaks.
|
||||
"""
|
||||
import torch
|
||||
|
||||
|
||||
def scores_to_ranks(scores: torch.Tensor):
|
||||
"""Convert model output scores into ranks."""
|
||||
batch_size, num_rounds, num_options = scores.size()
|
||||
scores = scores.view(-1, num_options)
|
||||
|
||||
# sort in descending order - largest score gets highest rank
|
||||
sorted_ranks, ranked_idx = scores.sort(1, descending=True)
|
||||
|
||||
# i-th position in ranked_idx specifies which score shall take this
|
||||
# position but we want i-th position to have rank of score at that
|
||||
# position, do this conversion
|
||||
ranks = ranked_idx.clone().fill_(0)
|
||||
for i in range(ranked_idx.size(0)):
|
||||
for j in range(num_options):
|
||||
ranks[i][ranked_idx[i][j]] = j
|
||||
# convert from 0-99 ranks to 1-100 ranks
|
||||
ranks += 1
|
||||
ranks = ranks.view(batch_size, num_rounds, num_options)
|
||||
return ranks
|
||||
|
||||
|
||||
class SparseGTMetrics(object):
|
||||
"""
|
||||
A class to accumulate all metrics with sparse ground truth annotations.
|
||||
These include Recall (@ 1, 5, 10), Mean Rank and Mean Reciprocal Rank.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._rank_list = []
|
||||
|
||||
def observe(
|
||||
self, predicted_scores: torch.Tensor, target_ranks: torch.Tensor
|
||||
):
|
||||
predicted_scores = predicted_scores.detach()
|
||||
|
||||
# shape: (batch_size, num_rounds, num_options)
|
||||
predicted_ranks = scores_to_ranks(predicted_scores)
|
||||
batch_size, num_rounds, num_options = predicted_ranks.size()
|
||||
|
||||
# collapse batch dimension
|
||||
predicted_ranks = predicted_ranks.view(
|
||||
batch_size * num_rounds, num_options
|
||||
)
|
||||
|
||||
# shape: (batch_size * num_rounds, )
|
||||
target_ranks = target_ranks.view(batch_size * num_rounds).long()
|
||||
|
||||
# shape: (batch_size * num_rounds, )
|
||||
predicted_gt_ranks = predicted_ranks[
|
||||
torch.arange(batch_size * num_rounds), target_ranks
|
||||
]
|
||||
self._rank_list.extend(list(predicted_gt_ranks.cpu().numpy()))
|
||||
|
||||
def retrieve(self, reset: bool = True):
|
||||
num_examples = len(self._rank_list)
|
||||
if num_examples > 0:
|
||||
# convert to numpy array for easy calculation.
|
||||
__rank_list = torch.tensor(self._rank_list).float()
|
||||
metrics = {
|
||||
"r@1": torch.mean((__rank_list <= 1).float()).item(),
|
||||
"r@5": torch.mean((__rank_list <= 5).float()).item(),
|
||||
"r@10": torch.mean((__rank_list <= 10).float()).item(),
|
||||
"mean": torch.mean(__rank_list).item(),
|
||||
"mrr": torch.mean(__rank_list.reciprocal()).item(),
|
||||
}
|
||||
else:
|
||||
metrics = {}
|
||||
|
||||
if reset:
|
||||
self.reset()
|
||||
return metrics
|
||||
|
||||
def reset(self):
|
||||
self._rank_list = []
|
||||
|
||||
|
||||
class NDCG(object):
|
||||
def __init__(self):
|
||||
self._ndcg_numerator = 0.0
|
||||
self._ndcg_denominator = 0.0
|
||||
|
||||
def observe(
|
||||
self, predicted_scores: torch.Tensor, target_relevance: torch.Tensor
|
||||
):
|
||||
"""
|
||||
Observe model output scores and target ground truth relevance and
|
||||
accumulate NDCG metric.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
predicted_scores: torch.Tensor
|
||||
A tensor of shape (batch_size, num_options), because dense
|
||||
annotations are available for 1 randomly picked round out of 10.
|
||||
target_relevance: torch.Tensor
|
||||
A tensor of shape same as predicted scores, indicating ground truth
|
||||
relevance of each answer option for a particular round.
|
||||
"""
|
||||
predicted_scores = predicted_scores.detach()
|
||||
|
||||
# shape: (batch_size, 1, num_options)
|
||||
predicted_scores = predicted_scores.unsqueeze(1)
|
||||
predicted_ranks = scores_to_ranks(predicted_scores)
|
||||
|
||||
# shape: (batch_size, num_options)
|
||||
predicted_ranks = predicted_ranks.squeeze(1)
|
||||
batch_size, num_options = predicted_ranks.size()
|
||||
|
||||
k = torch.sum(target_relevance != 0, dim=-1)
|
||||
|
||||
# shape: (batch_size, num_options)
|
||||
_, rankings = torch.sort(predicted_ranks, dim=-1)
|
||||
# Sort relevance in descending order so highest relevance gets top rnk.
|
||||
_, best_rankings = torch.sort(
|
||||
target_relevance, dim=-1, descending=True
|
||||
)
|
||||
|
||||
# shape: (batch_size, )
|
||||
batch_ndcg = []
|
||||
for batch_index in range(batch_size):
|
||||
|
||||
num_relevant = k[batch_index]
|
||||
dcg = self._dcg(
|
||||
rankings[batch_index][:num_relevant],
|
||||
target_relevance[batch_index],
|
||||
)
|
||||
best_dcg = self._dcg(
|
||||
best_rankings[batch_index][:num_relevant],
|
||||
target_relevance[batch_index],
|
||||
)
|
||||
batch_ndcg.append(dcg / best_dcg)
|
||||
|
||||
self._ndcg_denominator += batch_size
|
||||
self._ndcg_numerator += sum(batch_ndcg)
|
||||
|
||||
def _dcg(self, rankings: torch.Tensor, relevance: torch.Tensor):
|
||||
sorted_relevance = relevance[rankings].cpu().float()
|
||||
discounts = torch.log2(torch.arange(len(rankings)).float() + 2)
|
||||
return torch.sum(sorted_relevance / discounts, dim=-1)
|
||||
|
||||
def retrieve(self, reset: bool = True):
|
||||
if self._ndcg_denominator > 0:
|
||||
metrics = {
|
||||
"ndcg": float(self._ndcg_numerator / self._ndcg_denominator)
|
||||
}
|
||||
else:
|
||||
metrics = {}
|
||||
|
||||
if reset:
|
||||
self.reset()
|
||||
return metrics
|
||||
|
||||
def reset(self):
|
||||
self._ndcg_numerator = 0.0
|
||||
self._ndcg_denominator = 0.0
|
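# Usage sketch (editor's illustration; tensors are hypothetical):
#
#     sparse = SparseGTMetrics()
#     sparse.observe(scores, gt_option_idx)  # scores: (B, rounds, options)
#     sparse.retrieve(reset=True)            # {"r@1": ..., "mrr": ..., ...}
#
#     ndcg = NDCG()
#     ndcg.observe(round_scores, relevance)  # both: (B, options)
#     ndcg.retrieve(reset=True)              # {"ndcg": ...}
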
35  utils/optimizer.py  Normal file
@@ -0,0 +1,35 @@
""" Optimizer Factory w/ Custom Weight Decay
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import re
|
||||
import torch
|
||||
from torch import optim as optim
|
||||
from utils.dist import is_main_process
|
||||
import glog as logger
|
||||
# from transformers import create_optimizer
|
||||
# from transformers import AdamW
|
||||
# import math
|
||||
|
||||
|
||||
def create_optimizer(config, model):
|
||||
lr_scale = config.get('lr_layer_decay', 1)
|
||||
weight_decay = config.get('weight_decay', 0.01)
|
||||
|
||||
optim_params = model.get_optimizer_params(weight_decay, lr_scale)
|
||||
|
||||
num_parameters = 0
|
||||
for p_group in optim_params:
|
||||
for p in p_group['params']:
|
||||
num_parameters += p.data.nelement()
|
||||
logger.info('number of trainable parameters: {}'.format(num_parameters))
|
||||
|
||||
lr = config.get('lr', 1e-4)
|
||||
betas = config.get('opt_betas', [0.9, 0.999])
|
||||
|
||||
optimizer = torch.optim.AdamW(
|
||||
optim_params,
|
||||
lr=float(lr),
|
||||
betas=betas
|
||||
)
|
||||
|
||||
return optimizer
|
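# Contract sketch (editor's illustration): create_optimizer assumes the model
# defines get_optimizer_params(weight_decay, lr_scale) elsewhere in this repo,
# returning AdamW-style parameter groups, e.g.
#
#     [{"params": decay_params, "weight_decay": 0.01},
#      {"params": no_decay_params, "weight_decay": 0.0}]
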
240  utils/scheduler.py  Normal file
@@ -0,0 +1,240 @@
""" Scheduler Factory
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
"""
|
||||
from torch.optim import Optimizer
|
||||
import math
|
||||
from torch.optim.lr_scheduler import LambdaLR, _LRScheduler
|
||||
import math
|
||||
|
||||
|
||||
# class LinearWarmupStepLRScheduler:
|
||||
# def __init__(
|
||||
# self,
|
||||
# optimizer,
|
||||
# max_epoch,
|
||||
# min_lr,
|
||||
# init_lr,
|
||||
# decay_rate=1,
|
||||
# warmup_start_lr=-1,
|
||||
# warmup_steps=0,
|
||||
# **kwargs
|
||||
# ):
|
||||
# self.optimizer = optimizer
|
||||
|
||||
# self.max_epoch = max_epoch
|
||||
# self.min_lr = min_lr
|
||||
|
||||
# self.decay_rate = decay_rate
|
||||
|
||||
# self.init_lr = init_lr
|
||||
# self.warmup_steps = warmup_steps
|
||||
# self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
|
||||
|
||||
# def step(self, cur_epoch, cur_step):
|
||||
# if cur_epoch == 0:
|
||||
# warmup_lr_schedule(
|
||||
# step=cur_step,
|
||||
# optimizer=self.optimizer,
|
||||
# max_step=self.warmup_steps,
|
||||
# init_lr=self.warmup_start_lr,
|
||||
# max_lr=self.init_lr,
|
||||
# )
|
||||
# else:
|
||||
# step_lr_schedule(
|
||||
# epoch=cur_epoch,
|
||||
# optimizer=self.optimizer,
|
||||
# init_lr=self.init_lr,
|
||||
# min_lr=self.min_lr,
|
||||
# decay_rate=self.decay_rate,
|
||||
# )
|
||||
|
||||
|
||||
# class LinearWarmupCosineLRScheduler:
|
||||
# def __init__(
|
||||
# self,
|
||||
# optimizer,
|
||||
# max_epoch,
|
||||
# min_lr,
|
||||
# init_lr,
|
||||
# warmup_steps=0,
|
||||
# warmup_start_lr=-1,
|
||||
# **kwargs
|
||||
# ):
|
||||
# self.optimizer = optimizer
|
||||
|
||||
# self.max_epoch = max_epoch
|
||||
# self.min_lr = min_lr
|
||||
|
||||
# self.init_lr = init_lr
|
||||
# self.warmup_steps = warmup_steps
|
||||
# self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
|
||||
|
||||
# def step(self, cur_epoch, cur_step):
|
||||
# # assuming the warmup iters less than one epoch
|
||||
# if cur_epoch == 0:
|
||||
# warmup_lr_schedule(
|
||||
# step=cur_step,
|
||||
# optimizer=self.optimizer,
|
||||
# max_step=self.warmup_steps,
|
||||
# init_lr=self.warmup_start_lr,
|
||||
# max_lr=self.init_lr,
|
||||
# )
|
||||
# else:
|
||||
# cosine_lr_schedule(
|
||||
# epoch=cur_epoch,
|
||||
# optimizer=self.optimizer,
|
||||
# max_epoch=self.max_epoch,
|
||||
# init_lr=self.init_lr,
|
||||
# min_lr=self.min_lr,
|
||||
# )
|
||||
|
||||
|
||||
# class ConstantLRScheduler:
|
||||
# def __init__(self, optimizer, init_lr, warmup_start_lr=-1, warmup_steps=0, **kwargs):
|
||||
# self.optimizer = optimizer
|
||||
# self.lr = init_lr
|
||||
# self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
|
||||
# self.warmup_steps = warmup_steps
|
||||
|
||||
# def step(self, cur_epoch, cur_step):
|
||||
# if cur_epoch == 0:
|
||||
# warmup_lr_schedule(
|
||||
# step=cur_step,
|
||||
# optimizer=self.optimizer,
|
||||
# max_step=self.warmup_steps,
|
||||
# init_lr=self.warmup_start_lr,
|
||||
# max_lr=self.lr,
|
||||
# )
|
||||
# else:
|
||||
# for param_group in self.optimizer.param_groups:
|
||||
# param_group["lr"] = self.lr
|
||||
|
||||
|
||||
# schedulers = {
|
||||
# 'constant_lr': ConstantLRScheduler,
|
||||
# 'linear_warmup_cosine_lr': LinearWarmupCosineLRScheduler,
|
||||
# 'linear_warmup_step_lr': LinearWarmupStepLRScheduler
|
||||
# }
|
||||
|
||||
|
||||
# def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
|
||||
# """Decay the learning rate"""
|
||||
# lr = (init_lr - min_lr) * 0.5 * (
|
||||
# 1.0 + math.cos(math.pi * epoch / max_epoch)
|
||||
# ) + min_lr
|
||||
# for param_group in optimizer.param_groups:
|
||||
# param_group["lr"] = lr
|
||||
|
||||
|
||||
# def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
|
||||
# """Warmup the learning rate"""
|
||||
# lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))
|
||||
# for param_group in optimizer.param_groups:
|
||||
# param_group["lr"] = lr
|
||||
|
||||
|
||||
# def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
|
||||
# """Decay the learning rate"""
|
||||
# lr = max(min_lr, init_lr * (decay_rate**epoch))
|
||||
# for param_group in optimizer.param_groups:
|
||||
# param_group["lr"] = lr
|
||||
|
||||
|
||||
# def create_scheduler(config, optimizer):
|
||||
# scheduler_cls = schedulers[config.get('scheduler', 'constant_lr')]
|
||||
# max_epoch = config.epochs
|
||||
# min_lr = config.min_lr
|
||||
# init_lr = config.lr
|
||||
# warmup_start_lr = config.get('warmup_lr', -1)
|
||||
# warmup_steps = config.get('warmup_steps', 0)
|
||||
|
||||
# scheduler = scheduler_cls(
|
||||
# optimizer=optimizer,
|
||||
# max_epoch=max_epoch,
|
||||
# min_lr=min_lr,
|
||||
# init_lr=init_lr,
|
||||
# decay_rate=None,
|
||||
# warmup_start_lr=warmup_start_lr,
|
||||
# warmup_steps=warmup_steps
|
||||
# )
|
||||
|
||||
# return scheduler
|
||||
|
||||
|
||||
|
||||
class WarmupLinearScheduleNonZero(_LRScheduler):
|
||||
""" Linear warmup and then linear decay.
|
||||
Linearly increases learning rate from 0 to max_lr over `warmup_steps` training steps.
|
||||
Linearly decreases learning rate linearly to min_lr over remaining `t_total - warmup_steps` steps.
|
||||
"""
|
||||
def __init__(self, optimizer, warmup_steps, t_total, min_lr=1e-5, last_epoch=-1):
|
||||
self.warmup_steps = warmup_steps
|
||||
self.t_total = t_total
|
||||
self.min_lr = min_lr
|
||||
super(WarmupLinearScheduleNonZero, self).__init__(optimizer, last_epoch=last_epoch)
|
||||
|
||||
def get_lr(self):
|
||||
step = self.last_epoch
|
||||
if step < self.warmup_steps:
|
||||
lr_factor = float(step) / float(max(1, self.warmup_steps))
|
||||
else:
|
||||
lr_factor = max(0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps)))
|
||||
|
||||
return [base_lr * lr_factor if (base_lr * lr_factor) > self.min_lr else self.min_lr for base_lr in self.base_lrs]
|
||||
|
||||
|
||||
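# Worked sketch (editor's illustration; numbers are hypothetical): with
# warmup_steps=100, t_total=1000, base_lr=1e-4 and min_lr=1e-5, at step 50 the
# factor is 50/100 = 0.5 (lr = 5e-5); at step 550 it is (1000-550)/900 = 0.5
# (lr = 5e-5); past step 910 the linear value drops below min_lr = 1e-5 and is
# clamped there, hence "NonZero".
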
def create_scheduler(config, optimizer):
    lr_scheduler = None
    if config['scheduler'] == 'cosine':
        lr_scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=config['num_warmup_steps'],
            num_training_steps=config['num_training_steps'],
            num_cycles=0.5,
            min_lr_multi=config['min_lr_multi']
        )
    elif config['scheduler'] == 'linear':
        lr_scheduler = WarmupLinearScheduleNonZero(
            optimizer,
            config['num_warmup_steps'],
            config['num_training_steps'],
            min_lr=config['min_lr']
        )
    return lr_scheduler


def get_cosine_schedule_with_warmup(
    optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int,
    num_cycles: float = 0.5, min_lr_multi: float = 0., last_epoch: int = -1
):
    """
    Modified from https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/optimization.py

    Create a schedule with a learning rate that decreases following the values of the cosine function between the
    initial lr set in the optimizer and 0, after a warmup period during which it increases linearly between 0 and the
    initial lr set in the optimizer.
    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`float`, *optional*, defaults to 0.5):
            The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
            following a half-cosine).
        min_lr_multi (`float`, *optional*, defaults to 0):
            The minimum learning rate multiplier. Thus the minimum learning rate is base_lr * min_lr_multi.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.
    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return max(min_lr_multi, float(current_step) / float(max(1, num_warmup_steps)))
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(min_lr_multi, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

    return LambdaLR(optimizer, lr_lambda, last_epoch)
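# Usage sketch (editor's illustration; config values are hypothetical):
#
#     config = {"scheduler": "cosine", "num_warmup_steps": 157,
#               "num_training_steps": 1570, "min_lr_multi": 0.01}
#     scheduler = create_scheduler(config, optimizer)
#     for step in range(config["num_training_steps"]):
#         train_step()        # hypothetical
#         optimizer.step()
#         scheduler.step()    # LambdaLR advances once per iteration
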