initial commit

Andreas Bulling 2025-06-24 08:38:09 +02:00
commit a82bbc593e
129 changed files with 33981 additions and 0 deletions

0
utils/__init__.py Normal file

309
utils/basic.py Normal file

@@ -0,0 +1,309 @@
import numpy as np
import io
import os
import json
import logging
import random
import time
from collections import defaultdict, deque
import datetime
from pathlib import Path
from typing import List, Union
import itertools
import torch
import torch.distributed as dist
from .dist import is_dist_avail_and_initialized
logger = logging.getLogger(__name__)
def average_dicts(dicts):
# media = list(dicts.keys())
# keys = [list(d.keys()) for d in dicts.values]
# keys = list(itertools.chain.from_iterable(keys))
# keys = list(set(keys))
res = {}
counter = {}
for medium, medium_dict in dicts.items():
for loss_key, loss_value in medium_dict.items():
if loss_key not in res:
res[loss_key] = loss_value
counter[loss_key] = 1
else:
res[loss_key] += loss_value
counter[loss_key] += 1
for k in res:
res[k] = res[k] / counter[k]
return res
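# Illustrative example (the media and loss keys are hypothetical):
# average_dicts({"webvid": {"loss_itm": 2.0}, "cc3m": {"loss_itm": 4.0, "loss_mlm": 1.0}})
# -> {"loss_itm": 3.0, "loss_mlm": 1.0}, i.e. a per-key mean over the media that report that key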
class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window=20, fmt=None):
if fmt is None:
fmt = "{median:.4f} ({global_avg:.4f})"
self.deque = deque(maxlen=window)
self.total = 0.0
self.count = 0
self.fmt = fmt
def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n
def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
if not is_dist_avail_and_initialized():
return
t = torch.tensor([self.count, self.total],
dtype=torch.float64, device='cuda')
dist.barrier()
dist.all_reduce(t)
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@property
def median(self):
d = torch.tensor(list(self.deque))
return d.median().item()
@property
def avg(self):
d = torch.tensor(list(self.deque), dtype=torch.float32)
return d.mean().item()
@property
def global_avg(self):
return self.total / self.count
@property
def max(self):
return max(self.deque)
@property
def value(self):
return self.deque[-1]
def __str__(self):
return self.fmt.format(
median=self.median,
avg=self.avg,
global_avg=self.global_avg,
max=self.max,
value=self.value)
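# Minimal usage sketch (values are made up); the printed string follows `fmt`,
# i.e. the windowed median followed by the global average:
# meter = SmoothedValue(window=20)
# for v in (0.9, 0.7, 0.5):
#     meter.update(v)
# print(meter)      # "0.7000 (0.7000)"
# print(meter.max)  # 0.9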
class MetricLogger(object):
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(
type(self).__name__, attr))
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
if meter.count == 0: # skip empty meter
loss_str.append(
"{}: {}".format(name, "No data")
)
else:
loss_str.append(
"{}: {}".format(name, str(meter))
)
return self.delimiter.join(loss_str)
def global_avg(self):
loss_str = []
for name, meter in self.meters.items():
if meter.count == 0:
loss_str.append(
"{}: {}".format(name, "No data")
)
else:
loss_str.append(
"{}: {:.4f}".format(name, meter.global_avg)
)
return self.delimiter.join(loss_str)
def get_global_avg_dict(self, prefix=""):
"""include a separator (e.g., `/`, or "_") at the end of `prefix`"""
d = {f"{prefix}{k}": m.global_avg if m.count > 0 else 0. for k, m in self.meters.items()}
return d
def synchronize_between_processes(self):
for meter in self.meters.values():
meter.synchronize_between_processes()
def add_meter(self, name, meter):
self.meters[name] = meter
def log_every(self, iterable, log_freq, header=None):
i = 0
if not header:
header = ''
start_time = time.time()
end = time.time()
iter_time = SmoothedValue(fmt='{avg:.4f}')
data_time = SmoothedValue(fmt='{avg:.4f}')
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
log_msg = [
header,
'[{0' + space_fmt + '}/{1}]',
'eta: {eta}\n',
'{meters}\n',
'time: {time}',
'data: {data}'
]
if torch.cuda.is_available():
log_msg.append('max mem: {memory:.0f} res mem: {res_mem:.0f}')
log_msg = self.delimiter.join(log_msg)
MB = 1024.0 * 1024.0
for obj in iterable:
data_time.update(time.time() - end)
yield obj
iter_time.update(time.time() - end)
if i % log_freq == 0 or i == len(iterable) - 1:
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available():
logger.info(log_msg.format(
i, len(iterable), eta=eta_string,
meters=str(self),
time=str(iter_time), data=str(data_time),
memory=torch.cuda.max_memory_allocated() / MB,
res_mem=torch.cuda.max_memory_reserved() / MB,
))
else:
logger.info(log_msg.format(
i, len(iterable), eta=eta_string,
meters=str(self),
time=str(iter_time), data=str(data_time)))
i += 1
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
logger.info('{} Total time: {} ({:.4f} s / it)'.format(
header, total_time_str, total_time / len(iterable)))
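# Typical training-loop usage (a sketch; `dataloader` and `train_step` are assumptions,
# not part of this module):
# metric_logger = MetricLogger(delimiter="  ")
# for batch in metric_logger.log_every(dataloader, log_freq=50, header="Train:"):
#     loss = train_step(batch)
#     metric_logger.update(loss=loss)
# metric_logger.synchronize_between_processes()  # aggregate counts/totals across ranks
# logger.info(metric_logger.global_avg())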
class AttrDict(dict):
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
self.__dict__ = self
def compute_acc(logits, label, reduction='mean'):
ret = (torch.argmax(logits, dim=1) == label).float()
if reduction == 'none':
return ret.detach()
elif reduction == 'mean':
return ret.mean().item()
def compute_n_params(model, return_str=True):
tot = 0
for p in model.parameters():
w = 1
for x in p.shape:
w *= x
tot += w
if return_str:
if tot >= 1e6:
return '{:.1f}M'.format(tot / 1e6)
else:
return '{:.1f}K'.format(tot / 1e3)
else:
return tot
def setup_seed(seed):
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
def remove_files_if_exist(file_paths):
for fp in file_paths:
if os.path.isfile(fp):
os.remove(fp)
def save_json(data, filename, save_pretty=False, sort_keys=False):
with open(filename, "w") as f:
if save_pretty:
f.write(json.dumps(data, indent=4, sort_keys=sort_keys))
else:
json.dump(data, f)
def load_json(filename):
with open(filename, "r") as f:
return json.load(f)
def flat_list_of_lists(l):
"""flatten a list of lists [[1,2], [3,4]] to [1,2,3,4]"""
return [item for sublist in l for item in sublist]
def find_files_by_suffix_recursively(root: str, suffix: Union[str, List[str]]):
"""
Args:
root: path to the directory in which to start searching for files
suffix: a single suffix string, or a List[str] to match
multiple suffixes at once.
Example 1: suffix `.jpg` or [`.jpg`, `.png`]
Example 2: a `*` may be used inside the suffix, e.g., `START*.jpg`.
"""
if isinstance(suffix, str):
suffix = [suffix, ]
filepaths = flat_list_of_lists(
[list(Path(root).rglob(f"*{e}")) for e in suffix])
return filepaths
def match_key_and_shape(state_dict1, state_dict2):
keys1 = set(state_dict1.keys())
keys2 = set(state_dict2.keys())
print(f"keys1 - keys2: {keys1 - keys2}")
print(f"keys2 - keys1: {keys2 - keys1}")
mismatch = 0
for k in list(keys1):
if state_dict1[k].shape != state_dict2[k].shape:
print(
f"k={k}, state_dict1[k].shape={state_dict1[k].shape}, state_dict2[k].shape={state_dict2[k].shape}")
mismatch += 1
print(f"mismatch {mismatch}")
def merge_dicts(list_dicts):
merged_dict = list_dicts[0].copy()
for i in range(1, len(list_dicts)):
merged_dict.update(list_dicts[i])
return merged_dict

25
utils/dist.py Normal file

@@ -0,0 +1,25 @@
import torch.distributed as dist
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_world_size():
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()
def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def is_main_process():
return get_rank() == 0

149
utils/easydict.py Normal file

@@ -0,0 +1,149 @@
class EasyDict(dict):
"""
Get attributes
>>> d = EasyDict({'foo':3})
>>> d['foo']
3
>>> d.foo
3
>>> d.bar
Traceback (most recent call last):
...
AttributeError: 'EasyDict' object has no attribute 'bar'
Works recursively
>>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}})
>>> isinstance(d.bar, dict)
True
>>> d.bar.x
1
Bullet-proof
>>> EasyDict({})
{}
>>> EasyDict(d={})
{}
>>> EasyDict(None)
{}
>>> d = {'a': 1}
>>> EasyDict(**d)
{'a': 1}
Set attributes
>>> d = EasyDict()
>>> d.foo = 3
>>> d.foo
3
>>> d.bar = {'prop': 'value'}
>>> d.bar.prop
'value'
>>> d
{'foo': 3, 'bar': {'prop': 'value'}}
>>> d.bar.prop = 'newer'
>>> d.bar.prop
'newer'
Values extraction
>>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]})
>>> isinstance(d.bar, list)
True
>>> from operator import attrgetter
>>> map(attrgetter('x'), d.bar)
[1, 3]
>>> map(attrgetter('y'), d.bar)
[2, 4]
>>> d = EasyDict()
>>> d.keys()
[]
>>> d = EasyDict(foo=3, bar=dict(x=1, y=2))
>>> d.foo
3
>>> d.bar.x
1
Still like a dict though
>>> o = EasyDict({'clean':True})
>>> o.items()
[('clean', True)]
And like a class
>>> class Flower(EasyDict):
... power = 1
...
>>> f = Flower()
>>> f.power
1
>>> f = Flower({'height': 12})
>>> f.height
12
>>> f['power']
1
>>> sorted(f.keys())
['height', 'power']
update and pop items
>>> d = EasyDict(a=1, b='2')
>>> e = EasyDict(c=3.0, a=9.0)
>>> d.update(e)
>>> d.c
3.0
>>> d['c']
3.0
>>> d.get('c')
3.0
>>> d.update(a=4, b=4)
>>> d.b
4
>>> d.pop('a')
4
>>> d.a
Traceback (most recent call last):
...
AttributeError: 'EasyDict' object has no attribute 'a'
"""
def __init__(self, d=None, **kwargs):
if d is None:
d = {}
if kwargs:
d.update(**kwargs)
for k, v in d.items():
setattr(self, k, v)
# Class attributes
for k in self.__class__.__dict__.keys():
if not (k.startswith("__") and k.endswith("__")) and not k in ("update", "pop"):
setattr(self, k, getattr(self, k))
def __setattr__(self, name, value):
if isinstance(value, (list, tuple)):
value = [self.__class__(x) if isinstance(x, dict) else x for x in value]
elif isinstance(value, dict) and not isinstance(value, self.__class__):
value = self.__class__(value)
super(EasyDict, self).__setattr__(name, value)
super(EasyDict, self).__setitem__(name, value)
__setitem__ = __setattr__
def update(self, e=None, **f):
d = e or dict()
d.update(f)
for k in d:
setattr(self, k, d[k])
def pop(self, k, d=None):
if hasattr(self, k):
delattr(self, k)
return super(EasyDict, self).pop(k, d)
if __name__ == "__main__":
import doctest
doctest.testmod()

154
utils/init.py Normal file

@@ -0,0 +1,154 @@
import os
import torch
import random
import pyhocon
import datetime
import json
import subprocess
import itertools
import glob
import glog as log
import sys
import re
from os import path as osp
import numpy as np
# def load_runner(config, tokenizer, vocab_size):
# if config['task'] == 'avsd':
# return AVSDRunner(config, tokenizer, vocab_size)
# if config['task'] == 'simmc':
# return SIMMCRunner(config, tokenizer, vocab_size)
# elif config['task'] == 'nextqa':
# return NEXTQARunner(config, tokenizer, vocab_size)
# else:
# raise ValueError
def set_random_seed(random_seed):
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
random.seed(random_seed)
np.random.seed(random_seed)
def copy_file_to_log(log_dir):
dirs_to_cp = ['.', 'config', 'datasets', 'runners', 'models']
files_to_cp = ['*.py', '*.json', '*.sh', '*.conf']
for dir_name in dirs_to_cp:
dir_name = osp.join(log_dir, 'code', dir_name)
if not osp.exists(dir_name):
os.makedirs(dir_name)
for dir_name, file_name in itertools.product(dirs_to_cp, files_to_cp):
filename = osp.join(dir_name, file_name)
if len(glob.glob(filename)) > 0:
os.system(f'cp {filename} {osp.join(log_dir, "code", dir_name)}')
log.info(f'Files copied to {osp.join(log_dir, "code")}')
def set_log_file(fname, file_only=False):
# if fname already exists, find all log files under the log dir
# and name the current log file with a new number
if osp.exists(fname):
prefix, suffix = osp.splitext(fname)
log_files = glob.glob(prefix + '*' + suffix)
count = 0
for log_file in log_files:
num = re.search(r'(\d+)', log_file)
if num is not None:
num = int(num.group(0))
count = max(num, count)
fname = fname.replace(suffix, str(count + 1) + suffix)
# set log file
# simple tricks for duplicating logging destination in the logging module such as:
# logging.getLogger().addHandler(logging.FileHandler(filename))
# does NOT work well here, because python Traceback message (not via logging module) is not sent to the file,
# the following solution (copied from : https://stackoverflow.com/questions/616645) is a little bit
# complicated but simulates exactly the "tee" command in linux shell, and it redirects everything
if file_only:
# we only output messages to the file, and stdout/stderr receive nothing.
# this feature is designed for executing the script via ssh:
# ssh uses a windowing kind of flow control, i.e., if the controller does not read data from an
# ssh channel and its buffer fills up, the executing machine cannot write anything into the
# channel and the process is put into sleeping (S) status until someone reads all data from the channel.
# this is not desired, since we do not want to read stdout/stderr from the controller machine.
# so we use a simple solution here: disable output to stdout/stderr and only write messages to the log file.
log.logger.handlers[0].stream = log.handler.stream = sys.stdout = sys.stderr = f = open(fname, 'w', buffering=1)
else:
# we output messages to both file and stdout/stderr
tee = subprocess.Popen(['tee', fname], stdin=subprocess.PIPE)
os.dup2(tee.stdin.fileno(), sys.stdout.fileno())
os.dup2(tee.stdin.fileno(), sys.stderr.fileno())
def set_training_steps(config, num_samples):
if config['parallel'] and config['dp_type'] == 'dp':
config['num_iter_per_epoch'] = int(np.ceil(num_samples / config['batch_size']))
else:
config['num_iter_per_epoch'] = int(np.ceil(num_samples / (config['batch_size'] * config['num_gpus'])))
if 'train_steps' not in config:
config['train_steps'] = config['num_iter_per_epoch'] * config['num_epochs']
if 'warmup_steps' not in config:
config['warmup_steps'] = int(config['train_steps'] * config['warmup_ratio'])
return config
def initialize_from_env(model, mode, stage, eval_dir, tag=''):
if mode in ['train', 'debug']:
path_config = f"config/{model}_{stage}.conf"
config = pyhocon.ConfigFactory.parse_file(path_config)[stage]
else:
path_config = os.path.join(eval_dir, f'{model}_{stage}.conf')
config = pyhocon.ConfigFactory.parse_file(path_config)[stage]
config['log_dir'] = eval_dir
if "CUDA_VISIBLE_DEVICES" in os.environ:
config['num_gpus'] = len(os.environ["CUDA_VISIBLE_DEVICES"].split(','))
# multi-gpu setting
if config['num_gpus'] > 1:
os.environ['MASTER_ADDR'] = 'localhost'
os.environ["MASTER_PORT"] = str(config['master_port'])
else:
config['num_gpus'] = 1
model += '-' + config.llm_name.replace('/', '_')
if mode == 'debug':
model += '_debug'
if tag:
model += '-' + tag
if mode != 'generate':
config["log_dir"] = os.path.join(config["log_dir"], model)
if not os.path.exists(config["log_dir"]):
os.makedirs(config["log_dir"])
# copy the config file
os.system(f'cp {path_config} {config["log_dir"]}')
config['timestamp'] = datetime.datetime.now().strftime('%m%d-%H%M%S')
config['expert_config'] = config['bert_config_{}'.format(config['expert_size'])]
config['expert_config_json'] = json.load(open(config['expert_config'], 'r'))
config['beit_config_json'] = json.load(open(config['beit_config'], 'r'))
config['model'] = model
config['stage'] = stage
config['loss_dict'] = {k:v for k,v in zip(config['loss_names'], config['loss_weights'])}
return config
def set_training_steps(config, num_samples, batch_sizes):
config['num_iter_per_epoch'] = sum([int(np.ceil(num_sample / (bs * config['accum_grad_every'] * config['num_gpus']))) for num_sample, bs in zip(num_samples, batch_sizes)])
if 'num_training_steps' not in config:
config['num_training_steps'] = config['num_iter_per_epoch'] * config['epochs']
if 'num_warmup_steps' not in config:
config['num_warmup_steps'] = int(config['num_iter_per_epoch'] * config.get('warmup_epochs', 1.0))
# config['num_warmup_steps'] = int(config['num_training_steps'] * config['warmup_ratio'])
return config
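# Worked example (hypothetical numbers): with num_samples=[100000], batch_sizes=[32],
# accum_grad_every=2 and num_gpus=4, one epoch is ceil(100000 / (32 * 2 * 4)) = 391
# optimizer steps; with epochs=10 this gives num_training_steps = 3910 and, with the
# default warmup_epochs=1.0, num_warmup_steps = 391.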

286
utils/logger.py Normal file

@@ -0,0 +1,286 @@
# from MMF: https://github.com/facebookresearch/mmf/blob/master/mmf/utils/logger.py
# Copyright (c) Facebook, Inc. and its affiliates.
import functools
import logging
import os
import sys
import time
import wandb
from .dist import get_rank, is_main_process
from termcolor import colored
def log_dict_to_wandb(log_dict, step, prefix=""):
"""include a separator `/` at the end of `prefix`"""
if not is_main_process():
return
log_dict = {f"{prefix}{k}": v for k, v in log_dict.items()}
wandb.log(log_dict, step)
def setup_wandb(config):
if not (config.wandb_enabled and is_main_process()):
return
run = wandb.init(
config=config,
project=config.wandb_project,
# entity=config.wandb.entity,
mode=config.wandb_mode,
# name=os.path.basename(config.output_dir),
reinit=True
)
wandb.define_metric('train/webvid/step')
wandb.define_metric('train/webvid/*', 'train/webvid/step')
wandb.define_metric('train/cc3m/step')
wandb.define_metric('train/cc3m/*', 'train/cc3m/step')
wandb.define_metric('train/other/step')
wandb.define_metric('train/other/*', 'train/other/step')
wandb.define_metric('val/msrvtt/step')
wandb.define_metric('val/msrvtt/*', 'val/msrvtt/step')
wandb.define_metric('train/champagne/step')
wandb.define_metric('train/champagne/*', 'train/champagne/step')
wandb.define_metric('train/visdial/step')
wandb.define_metric('train/visdial/*', 'train/visdial/step')
wandb.define_metric('train/avsd/step')
wandb.define_metric('train/avsd/*', 'train/avsd/step')
wandb.define_metric('train/nextqa/step')
wandb.define_metric('train/nextqa/*', 'train/nextqa/step')
return run
def setup_output_folder(save_dir: str, folder_only: bool = False):
"""Sets up and returns the output file where the logs will be placed
based on the configuration passed. Usually "save_dir/logs/train_<timestamp>.log".
If env.log_dir is passed, logs will be directly saved in this folder.
Args:
folder_only (bool, optional): If folder should be returned and not the file.
Defaults to False.
Returns:
str: folder or file path depending on folder_only flag
"""
log_filename = "train_"
log_filename += time.strftime("%Y_%m_%dT%H_%M_%S")
log_filename += ".log"
log_folder = os.path.join(save_dir, "logs")
if not os.path.exists(log_folder):
os.makedirs(log_folder)
if folder_only:
return log_folder
log_filename = os.path.join(log_folder, log_filename)
return log_filename
def setup_logger(
output: str = None,
color: bool = True,
name: str = "mmf",
disable: bool = False,
clear_handlers=True,
*args,
**kwargs,
):
"""
Initialize the MMF logger and set its verbosity level to "INFO".
Outside libraries shouldn't call this in case they have set their
own logging handlers and setup. If they do, and don't want the
existing handlers to be cleared, pass clear_handlers=False.
The initial version of this function was taken from D2 and adapted
for MMF.
Args:
output (str): a file name or a directory to save log.
If ends with ".txt" or ".log", assumed to be a file name.
Default: Saved to file <save_dir/logs/train_[timestamp].log>
color (bool): If false, won't log colored logs. Default: true
name (str): the root module name of this logger. Defaults to "mmf".
disable: do not use
clear_handlers (bool): If false, won't clear existing handlers.
Returns:
logging.Logger: a logger
"""
if disable:
return None
logger = logging.getLogger(name)
logger.propagate = False
logging.captureWarnings(True)
warnings_logger = logging.getLogger("py.warnings")
plain_formatter = logging.Formatter(
"%(asctime)s | %(levelname)s | %(name)s : %(message)s",
datefmt="%Y-%m-%dT%H:%M:%S",
)
distributed_rank = get_rank()
handlers = []
logging_level = logging.INFO
# logging_level = logging.DEBUG
if distributed_rank == 0:
logger.setLevel(logging_level)
ch = logging.StreamHandler(stream=sys.stdout)
ch.setLevel(logging_level)
if color:
formatter = ColorfulFormatter(
colored("%(asctime)s | %(name)s: ", "green") + "%(message)s",
datefmt="%Y-%m-%dT%H:%M:%S",
)
else:
formatter = plain_formatter
ch.setFormatter(formatter)
logger.addHandler(ch)
warnings_logger.addHandler(ch)
handlers.append(ch)
# file logging: all workers
if output is None:
output = setup_output_folder()
if output is not None:
if output.endswith(".txt") or output.endswith(".log"):
filename = output
else:
filename = os.path.join(output, "train.log")
if distributed_rank > 0:
filename = filename + f".rank{distributed_rank}"
os.makedirs(os.path.dirname(filename), exist_ok=True)
fh = logging.StreamHandler(_cached_log_stream(filename))
fh.setLevel(logging_level)
fh.setFormatter(plain_formatter)
logger.addHandler(fh)
warnings_logger.addHandler(fh)
handlers.append(fh)
# Slurm/FB output, only log the main process
# save_dir = get_mmf_env(key="save_dir")
if "train.log" not in filename and distributed_rank == 0:
filename = os.path.join(output, "train.log")
sh = logging.StreamHandler(_cached_log_stream(filename))
sh.setLevel(logging_level)
sh.setFormatter(plain_formatter)
logger.addHandler(sh)
warnings_logger.addHandler(sh)
handlers.append(sh)
logger.info(f"Logging to: {filename}")
# Remove existing handlers to add MMF specific handlers
if clear_handlers:
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
# Now, add our handlers.
logging.basicConfig(level=logging_level, handlers=handlers)
return logger
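# Illustrative call (the output path is an assumption): writes colored logs to stdout
# on rank 0 and plain-formatted logs to <output>/train.log on every rank.
# logger = setup_logger(output="logs/experiment", name="mmf")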
def setup_very_basic_config(color=True):
plain_formatter = logging.Formatter(
"%(asctime)s | %(levelname)s | %(name)s : %(message)s",
datefmt="%Y-%m-%dT%H:%M:%S",
)
ch = logging.StreamHandler(stream=sys.stdout)
ch.setLevel(logging.INFO)
if color:
formatter = ColorfulFormatter(
colored("%(asctime)s | %(name)s: ", "green") + "%(message)s",
datefmt="%Y-%m-%dT%H:%M:%S",
)
else:
formatter = plain_formatter
ch.setFormatter(formatter)
# Setup a minimal configuration for logging in case something tries to
# log a message even before logging is setup by MMF.
logging.basicConfig(level=logging.INFO, handlers=[ch])
# cache the opened file object, so that different calls to `setup_logger`
# with the same file name can safely write to the same file.
@functools.lru_cache(maxsize=None)
def _cached_log_stream(filename):
return open(filename, "a")
# ColorfulFormatter is adopted from Detectron2 and adapted for MMF
class ColorfulFormatter(logging.Formatter):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def formatMessage(self, record):
log = super().formatMessage(record)
if record.levelno == logging.WARNING:
prefix = colored("WARNING", "red", attrs=["blink"])
elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
prefix = colored("ERROR", "red", attrs=["blink", "underline"])
else:
return log
return prefix + " " + log
class TensorboardLogger:
def __init__(self, log_folder="./logs", iteration=0):
# This would handle warning of missing tensorboard
from torch.utils.tensorboard import SummaryWriter
self.summary_writer = None
self._is_master = is_main_process()
# self.timer = Timer()
self.log_folder = log_folder
if self._is_master:
# current_time = self.timer.get_time_hhmmss(None, format=self.time_format)
current_time = time.strftime("%Y-%m-%dT%H:%M:%S")
# self.timer.get_time_hhmmss(None, format=self.time_format)
tensorboard_folder = os.path.join(
self.log_folder, f"tensorboard_{current_time}"
)
self.summary_writer = SummaryWriter(tensorboard_folder)
def __del__(self):
if getattr(self, "summary_writer", None) is not None:
self.summary_writer.close()
def _should_log_tensorboard(self):
if self.summary_writer is None or not self._is_master:
return False
else:
return True
def add_scalar(self, key, value, iteration):
if not self._should_log_tensorboard():
return
self.summary_writer.add_scalar(key, value, iteration)
def add_scalars(self, scalar_dict, iteration):
if not self._should_log_tensorboard():
return
for key, val in scalar_dict.items():
self.summary_writer.add_scalar(key, val, iteration)
def add_histogram_for_model(self, model, iteration):
if not self._should_log_tensorboard():
return
for name, param in model.named_parameters():
np_param = param.clone().cpu().data.numpy()
self.summary_writer.add_histogram(name, np_param, iteration)

174
utils/metrcis.py Normal file

@@ -0,0 +1,174 @@
"""
A Metric observes the output of a model, for example in the form of logits or
scores, and accumulates a particular metric with reference to some provided
targets. In the context of VisDial, we use Recall (@ 1, 5, 10), Mean Rank, Mean
Reciprocal Rank (MRR) and Normalized Discounted Cumulative Gain (NDCG).
Each ``Metric`` must implement at least three methods:
- ``observe``, to update the accumulated metric with the currently observed
outputs and targets.
- ``retrieve``, to return the accumulated metric and optionally reset the
internally accumulated metric (this is commonly done between two epochs,
after validation).
- ``reset``, to explicitly reset the internally accumulated metric.
Caveat: if you wish to implement your own Metric class, make sure you call
``detach`` on output tensors (like logits), else it will cause memory leaks.
"""
import torch
def scores_to_ranks(scores: torch.Tensor):
"""Convert model output scores into ranks."""
batch_size, num_rounds, num_options = scores.size()
scores = scores.view(-1, num_options)
# sort in descending order - largest score gets highest rank
sorted_ranks, ranked_idx = scores.sort(1, descending=True)
# the i-th position in ranked_idx tells which option's score occupies that
# position after sorting, but we want the i-th position to hold the rank of
# option i, so invert the mapping
ranks = ranked_idx.clone().fill_(0)
for i in range(ranked_idx.size(0)):
for j in range(num_options):
ranks[i][ranked_idx[i][j]] = j
# convert from 0-99 ranks to 1-100 ranks
ranks += 1
ranks = ranks.view(batch_size, num_rounds, num_options)
return ranks
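# Worked example (single round, made-up scores): scores [[[0.1, 0.9, 0.5]]] sort to
# ranked_idx [[1, 2, 0]], so the loop assigns ranks [[[3, 1, 2]]] after the +1 shift,
# i.e. the option with the highest score receives rank 1.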
class SparseGTMetrics(object):
"""
A class to accumulate all metrics with sparse ground truth annotations.
These include Recall (@ 1, 5, 10), Mean Rank and Mean Reciprocal Rank.
"""
def __init__(self):
self._rank_list = []
def observe(
self, predicted_scores: torch.Tensor, target_ranks: torch.Tensor
):
predicted_scores = predicted_scores.detach()
# shape: (batch_size, num_rounds, num_options)
predicted_ranks = scores_to_ranks(predicted_scores)
batch_size, num_rounds, num_options = predicted_ranks.size()
# collapse batch dimension
predicted_ranks = predicted_ranks.view(
batch_size * num_rounds, num_options
)
# shape: (batch_size * num_rounds, )
target_ranks = target_ranks.view(batch_size * num_rounds).long()
# shape: (batch_size * num_rounds, )
predicted_gt_ranks = predicted_ranks[
torch.arange(batch_size * num_rounds), target_ranks
]
self._rank_list.extend(list(predicted_gt_ranks.cpu().numpy()))
def retrieve(self, reset: bool = True):
num_examples = len(self._rank_list)
if num_examples > 0:
# convert to numpy array for easy calculation.
__rank_list = torch.tensor(self._rank_list).float()
metrics = {
"r@1": torch.mean((__rank_list <= 1).float()).item(),
"r@5": torch.mean((__rank_list <= 5).float()).item(),
"r@10": torch.mean((__rank_list <= 10).float()).item(),
"mean": torch.mean(__rank_list).item(),
"mrr": torch.mean(__rank_list.reciprocal()).item(),
}
else:
metrics = {}
if reset:
self.reset()
return metrics
def reset(self):
self._rank_list = []
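# Usage sketch (random tensors stand in for real model output):
# metric = SparseGTMetrics()
# scores = torch.rand(2, 10, 100)            # (batch_size, num_rounds, num_options)
# gt_index = torch.randint(0, 100, (2, 10))  # index of the ground-truth answer option
# metric.observe(scores, gt_index)
# print(metric.retrieve())                   # {"r@1": ..., "r@5": ..., "r@10": ..., "mean": ..., "mrr": ...}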
class NDCG(object):
def __init__(self):
self._ndcg_numerator = 0.0
self._ndcg_denominator = 0.0
def observe(
self, predicted_scores: torch.Tensor, target_relevance: torch.Tensor
):
"""
Observe model output scores and target ground truth relevance and
accumulate NDCG metric.
Parameters
----------
predicted_scores: torch.Tensor
A tensor of shape (batch_size, num_options), because dense
annotations are available for 1 randomly picked round out of 10.
target_relevance: torch.Tensor
A tensor of shape same as predicted scores, indicating ground truth
relevance of each answer option for a particular round.
"""
predicted_scores = predicted_scores.detach()
# shape: (batch_size, 1, num_options)
predicted_scores = predicted_scores.unsqueeze(1)
predicted_ranks = scores_to_ranks(predicted_scores)
# shape: (batch_size, num_options)
predicted_ranks = predicted_ranks.squeeze(1)
batch_size, num_options = predicted_ranks.size()
k = torch.sum(target_relevance != 0, dim=-1)
# shape: (batch_size, num_options)
_, rankings = torch.sort(predicted_ranks, dim=-1)
# Sort relevance in descending order so the highest relevance gets the top rank.
_, best_rankings = torch.sort(
target_relevance, dim=-1, descending=True
)
# shape: (batch_size, )
batch_ndcg = []
for batch_index in range(batch_size):
num_relevant = k[batch_index]
dcg = self._dcg(
rankings[batch_index][:num_relevant],
target_relevance[batch_index],
)
best_dcg = self._dcg(
best_rankings[batch_index][:num_relevant],
target_relevance[batch_index],
)
batch_ndcg.append(dcg / best_dcg)
self._ndcg_denominator += batch_size
self._ndcg_numerator += sum(batch_ndcg)
def _dcg(self, rankings: torch.Tensor, relevance: torch.Tensor):
sorted_relevance = relevance[rankings].cpu().float()
discounts = torch.log2(torch.arange(len(rankings)).float() + 2)
return torch.sum(sorted_relevance / discounts, dim=-1)
def retrieve(self, reset: bool = True):
if self._ndcg_denominator > 0:
metrics = {
"ndcg": float(self._ndcg_numerator / self._ndcg_denominator)
}
else:
metrics = {}
if reset:
self.reset()
return metrics
def reset(self):
self._ndcg_numerator = 0.0
self._ndcg_denominator = 0.0
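# Note on the computation above: _dcg implements the standard definition
# DCG@k = sum_{i=1..k} rel_i / log2(i + 1), evaluated over the top-k predicted options
# (k = number of options with non-zero relevance), and NDCG is the ratio of this value
# to the DCG of the relevance-sorted ("ideal") ranking.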

35
utils/optimizer.py Normal file

@@ -0,0 +1,35 @@
""" Optimizer Factory w/ Custom Weight Decay
Hacked together by / Copyright 2020 Ross Wightman
"""
import re
import torch
from torch import optim as optim
from utils.dist import is_main_process
import glog as logger
# from transformers import create_optimizer
# from transformers import AdamW
# import math
def create_optimizer(config, model):
lr_scale = config.get('lr_layer_decay', 1)
weight_decay = config.get('weight_decay', 0.01)
optim_params = model.get_optimizer_params(weight_decay, lr_scale)
num_parameters = 0
for p_group in optim_params:
for p in p_group['params']:
num_parameters += p.data.nelement()
logger.info('number of trainable parameters: {}'.format(num_parameters))
lr = config.get('lr', 1e-4)
betas = config.get('opt_betas', [0.9, 0.999])
optimizer = torch.optim.AdamW(
optim_params,
lr=float(lr),
betas=betas
)
return optimizer

240
utils/scheduler.py Normal file

@@ -0,0 +1,240 @@
""" Scheduler Factory
Hacked together by / Copyright 2020 Ross Wightman
"""
from torch.optim import Optimizer
import math
from torch.optim.lr_scheduler import LambdaLR, _LRScheduler
import math
# class LinearWarmupStepLRScheduler:
# def __init__(
# self,
# optimizer,
# max_epoch,
# min_lr,
# init_lr,
# decay_rate=1,
# warmup_start_lr=-1,
# warmup_steps=0,
# **kwargs
# ):
# self.optimizer = optimizer
# self.max_epoch = max_epoch
# self.min_lr = min_lr
# self.decay_rate = decay_rate
# self.init_lr = init_lr
# self.warmup_steps = warmup_steps
# self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
# def step(self, cur_epoch, cur_step):
# if cur_epoch == 0:
# warmup_lr_schedule(
# step=cur_step,
# optimizer=self.optimizer,
# max_step=self.warmup_steps,
# init_lr=self.warmup_start_lr,
# max_lr=self.init_lr,
# )
# else:
# step_lr_schedule(
# epoch=cur_epoch,
# optimizer=self.optimizer,
# init_lr=self.init_lr,
# min_lr=self.min_lr,
# decay_rate=self.decay_rate,
# )
# class LinearWarmupCosineLRScheduler:
# def __init__(
# self,
# optimizer,
# max_epoch,
# min_lr,
# init_lr,
# warmup_steps=0,
# warmup_start_lr=-1,
# **kwargs
# ):
# self.optimizer = optimizer
# self.max_epoch = max_epoch
# self.min_lr = min_lr
# self.init_lr = init_lr
# self.warmup_steps = warmup_steps
# self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
# def step(self, cur_epoch, cur_step):
# # assuming the warmup iters less than one epoch
# if cur_epoch == 0:
# warmup_lr_schedule(
# step=cur_step,
# optimizer=self.optimizer,
# max_step=self.warmup_steps,
# init_lr=self.warmup_start_lr,
# max_lr=self.init_lr,
# )
# else:
# cosine_lr_schedule(
# epoch=cur_epoch,
# optimizer=self.optimizer,
# max_epoch=self.max_epoch,
# init_lr=self.init_lr,
# min_lr=self.min_lr,
# )
# class ConstantLRScheduler:
# def __init__(self, optimizer, init_lr, warmup_start_lr=-1, warmup_steps=0, **kwargs):
# self.optimizer = optimizer
# self.lr = init_lr
# self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
# self.warmup_steps = warmup_steps
# def step(self, cur_epoch, cur_step):
# if cur_epoch == 0:
# warmup_lr_schedule(
# step=cur_step,
# optimizer=self.optimizer,
# max_step=self.warmup_steps,
# init_lr=self.warmup_start_lr,
# max_lr=self.lr,
# )
# else:
# for param_group in self.optimizer.param_groups:
# param_group["lr"] = self.lr
# schedulers = {
# 'constant_lr': ConstantLRScheduler,
# 'linear_warmup_cosine_lr': LinearWarmupCosineLRScheduler,
# 'linear_warmup_step_lr': LinearWarmupStepLRScheduler
# }
# def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
# """Decay the learning rate"""
# lr = (init_lr - min_lr) * 0.5 * (
# 1.0 + math.cos(math.pi * epoch / max_epoch)
# ) + min_lr
# for param_group in optimizer.param_groups:
# param_group["lr"] = lr
# def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
# """Warmup the learning rate"""
# lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))
# for param_group in optimizer.param_groups:
# param_group["lr"] = lr
# def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
# """Decay the learning rate"""
# lr = max(min_lr, init_lr * (decay_rate**epoch))
# for param_group in optimizer.param_groups:
# param_group["lr"] = lr
# def create_scheduler(config, optimizer):
# scheduler_cls = schedulers[config.get('scheduler', 'constant_lr')]
# max_epoch = config.epochs
# min_lr = config.min_lr
# init_lr = config.lr
# warmup_start_lr = config.get('warmup_lr', -1)
# warmup_steps = config.get('warmup_steps', 0)
# scheduler = scheduler_cls(
# optimizer=optimizer,
# max_epoch=max_epoch,
# min_lr=min_lr,
# init_lr=init_lr,
# decay_rate=None,
# warmup_start_lr=warmup_start_lr,
# warmup_steps=warmup_steps
# )
# return scheduler
class WarmupLinearScheduleNonZero(_LRScheduler):
""" Linear warmup and then linear decay.
Linearly increases learning rate from 0 to max_lr over `warmup_steps` training steps.
Linearly decreases learning rate linearly to min_lr over remaining `t_total - warmup_steps` steps.
"""
def __init__(self, optimizer, warmup_steps, t_total, min_lr=1e-5, last_epoch=-1):
self.warmup_steps = warmup_steps
self.t_total = t_total
self.min_lr = min_lr
super(WarmupLinearScheduleNonZero, self).__init__(optimizer, last_epoch=last_epoch)
def get_lr(self):
step = self.last_epoch
if step < self.warmup_steps:
lr_factor = float(step) / float(max(1, self.warmup_steps))
else:
lr_factor = max(0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps)))
return [base_lr * lr_factor if (base_lr * lr_factor) > self.min_lr else self.min_lr for base_lr in self.base_lrs]
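# Worked example (hypothetical schedule): with warmup_steps=1000 and t_total=10000,
# step 500 gives lr_factor = 500 / 1000 = 0.5 (warmup) and step 5500 gives
# lr_factor = (10000 - 5500) / 9000 = 0.5 (decay); the resulting per-group lr is never
# allowed to drop below min_lr.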
def create_scheduler(config, optimizer):
lr_scheduler = None
if config['scheduler'] == 'cosine':
lr_scheduler = get_cosine_schedule_with_warmup(
optimizer,
num_warmup_steps=config['num_warmup_steps'],
num_training_steps=config['num_training_steps'],
num_cycles=0.5,
min_lr_multi=config['min_lr_multi']
)
elif config['scheduler'] == 'linear':
lr_scheduler = WarmupLinearScheduleNonZero(
optimizer,
config['num_warmup_steps'],
config['num_training_steps'],
min_lr = config['min_lr']
)
return lr_scheduler
def get_cosine_schedule_with_warmup(
optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int,
num_cycles: float = 0.5, min_lr_multi: float = 0., last_epoch: int = -1
):
"""
Modified from https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/optimization.py
Create a schedule with a learning rate that decreases following the values of the cosine function from the
initial lr set in the optimizer down to 0, after a warmup period during which it increases linearly between 0 and the
initial lr set in the optimizer.
Args:
optimizer ([`~torch.optim.Optimizer`]):
The optimizer for which to schedule the learning rate.
num_warmup_steps (`int`):
The number of steps for the warmup phase.
num_training_steps (`int`):
The total number of training steps.
num_cycles (`float`, *optional*, defaults to 0.5):
The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
following a half-cosine).
min_lr_multi (`float`, *optional*, defaults to 0):
The minimum learning rate multiplier. Thus the minimum learning rate is base_lr * min_lr_multi.
last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training.
Return:
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""
def lr_lambda(current_step):
if current_step < num_warmup_steps:
return max(min_lr_multi, float(current_step) / float(max(1, num_warmup_steps)))
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
return max(min_lr_multi, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
return LambdaLR(optimizer, lr_lambda, last_epoch)
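# Usage sketch (optimizer, parameters and step counts are placeholders):
# import torch
# params = [torch.nn.Parameter(torch.zeros(1))]
# optimizer = torch.optim.AdamW(params, lr=1e-4)
# scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=1000,
#                                             num_training_steps=10000, min_lr_multi=0.01)
# for _ in range(10000):
#     optimizer.step()
#     scheduler.step()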