initial commit
This commit is contained in:
commit
a82bbc593e
129 changed files with 33981 additions and 0 deletions
286
utils/logger.py
Normal file
286
utils/logger.py
Normal file
|
@ -0,0 +1,286 @@
|
|||
# from MMF: https://github.com/facebookresearch/mmf/blob/master/mmf/utils/logger.py
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import wandb
|
||||
|
||||
from .dist import get_rank, is_main_process
|
||||
from termcolor import colored
|
||||
|
||||
|
||||
def log_dict_to_wandb(log_dict, step, prefix=""):
    """Log a metrics dict to wandb from the main process only.

    Include a separator `/` at the end of `prefix`.
    """
    if is_main_process():
        prefixed = {}
        for key, value in log_dict.items():
            prefixed[f"{prefix}{key}"] = value
        wandb.log(prefixed, step)
|
||||
|
||||
|
||||
def setup_wandb(config):
    """Initialize a wandb run on the main process.

    Args:
        config: experiment config object providing `wandb_enabled`,
            `wandb_project` and `wandb_mode` attributes.

    Returns:
        The `wandb` run object on the main process when wandb is enabled,
        otherwise ``None``.
    """
    if not (config.wandb_enabled and is_main_process()):
        return

    run = wandb.init(
        config=config,
        project=config.wandb_project,
        # entity=config.wandb.entity,
        mode=config.wandb_mode,
        # name=os.path.basename(config.output_dir),
        reinit=True,
    )

    # Give each dataset-specific metric group its own step axis so streams
    # logged at different frequencies do not clobber each other's x-axis.
    # (Same define_metric calls as before, expressed as a loop.)
    metric_groups = (
        "train/webvid",
        "train/cc3m",
        "train/other",
        "val/msrvtt",
        "train/champagne",
        "train/visdial",
        "train/avsd",
        "train/nextqa",
    )
    for group in metric_groups:
        wandb.define_metric(f"{group}/step")
        wandb.define_metric(f"{group}/*", f"{group}/step")

    return run
|
||||
|
||||
|
||||
def setup_output_folder(save_dir: str, folder_only: bool = False):
    """Sets up and returns the output file where the logs will be placed
    based on the configuration passed. Usually "save_dir/logs/log_<timestamp>.txt".
    If env.log_dir is passed, logs will be directly saved in this folder.

    Args:
        save_dir (str): root directory under which the "logs" folder is created.
        folder_only (bool, optional): If folder should be returned and not the file.
            Defaults to False.

    Returns:
        str: folder or file path depending on folder_only flag
    """
    log_filename = "train_"
    log_filename += time.strftime("%Y_%m_%dT%H_%M_%S")
    log_filename += ".log"

    log_folder = os.path.join(save_dir, "logs")

    # BUG FIX: the original called `os.path.mkdirs`, which does not exist
    # (AttributeError as soon as the folder is missing). The correct API is
    # `os.makedirs`; `exist_ok=True` also avoids a race between processes.
    os.makedirs(log_folder, exist_ok=True)

    if folder_only:
        return log_folder

    log_filename = os.path.join(log_folder, log_filename)

    return log_filename
|
||||
|
||||
|
||||
def setup_logger(
    output: str = None,
    color: bool = True,
    name: str = "mmf",
    disable: bool = False,
    clear_handlers=True,
    *args,
    **kwargs,
):
    """
    Initialize the MMF logger and set its verbosity level to "INFO".
    Outside libraries shouldn't call this in case they have set there
    own logging handlers and setup. If they do, and don't want to
    clear handlers, pass clear_handlers options.
    The initial version of this function was taken from D2 and adapted
    for MMF.

    Args:
        output (str): a file name or a directory to save log.
            If ends with ".txt" or ".log", assumed to be a file name.
            Default: Saved to file <save_dir/logs/log_[timestamp].txt>
        color (bool): If false, won't log colored logs. Default: true
        name (str): the root module name of this logger. Defaults to "mmf".
        disable: do not use
        clear_handlers (bool): If false, won't clear existing handlers.

    Returns:
        logging.Logger: a logger
    """
    if disable:
        return None
    logger = logging.getLogger(name)
    # Do not bubble records up to the root logger; this function installs
    # its own handlers below.
    logger.propagate = False

    # Route `warnings.warn(...)` through logging so warnings land in the
    # same console/file handlers as regular records.
    logging.captureWarnings(True)
    warnings_logger = logging.getLogger("py.warnings")

    plain_formatter = logging.Formatter(
        "%(asctime)s | %(levelname)s | %(name)s : %(message)s",
        datefmt="%Y-%m-%dT%H:%M:%S",
    )

    distributed_rank = get_rank()
    handlers = []

    logging_level = logging.INFO
    # logging_level = logging.DEBUG

    # Console (stdout) logging: only on the main (rank 0) process.
    if distributed_rank == 0:
        logger.setLevel(logging_level)
        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(logging_level)
        if color:
            formatter = ColorfulFormatter(
                colored("%(asctime)s | %(name)s: ", "green") + "%(message)s",
                datefmt="%Y-%m-%dT%H:%M:%S",
            )
        else:
            formatter = plain_formatter
        ch.setFormatter(formatter)
        logger.addHandler(ch)
        warnings_logger.addHandler(ch)
        handlers.append(ch)

    # file logging: all workers
    if output is None:
        # NOTE(review): setup_output_folder requires a `save_dir` argument,
        # so this call raises TypeError if reached — confirm intended default.
        output = setup_output_folder()

    if output is not None:
        if output.endswith(".txt") or output.endswith(".log"):
            filename = output
        else:
            filename = os.path.join(output, "train.log")
        # Non-main ranks get their own suffixed file so workers do not
        # interleave writes into the same log.
        if distributed_rank > 0:
            filename = filename + f".rank{distributed_rank}"
        os.makedirs(os.path.dirname(filename), exist_ok=True)

        # _cached_log_stream returns one shared open handle per path, so
        # repeated setup_logger calls append to the same file object.
        fh = logging.StreamHandler(_cached_log_stream(filename))
        fh.setLevel(logging_level)
        fh.setFormatter(plain_formatter)
        logger.addHandler(fh)
        warnings_logger.addHandler(fh)
        handlers.append(fh)

        # Slurm/FB output, only log the main process
        # save_dir = get_mmf_env(key="save_dir")
        if "train.log" not in filename and distributed_rank == 0:
            filename = os.path.join(output, "train.log")
            sh = logging.StreamHandler(_cached_log_stream(filename))
            sh.setLevel(logging_level)
            sh.setFormatter(plain_formatter)
            logger.addHandler(sh)
            warnings_logger.addHandler(sh)
            handlers.append(sh)

        # NOTE(review): the message literally logs the text "(unknown)";
        # presumably it was meant to interpolate `filename` — confirm.
        logger.info(f"Logging to: (unknown)")

    # Remove existing handlers to add MMF specific handlers
    if clear_handlers:
        for handler in logging.root.handlers[:]:
            logging.root.removeHandler(handler)
        # Now, add our handlers.
        logging.basicConfig(level=logging_level, handlers=handlers)

    return logger
|
||||
|
||||
|
||||
def setup_very_basic_config(color=True):
    """Install a minimal stdout logging configuration.

    Acts as a fallback so that anything logged before MMF's full logging
    setup still reaches the console.

    Args:
        color (bool): when True, prefix records with a colored header.
    """
    datefmt = "%Y-%m-%dT%H:%M:%S"
    plain_formatter = logging.Formatter(
        "%(asctime)s | %(levelname)s | %(name)s : %(message)s",
        datefmt=datefmt,
    )

    if color:
        formatter = ColorfulFormatter(
            colored("%(asctime)s | %(name)s: ", "green") + "%(message)s",
            datefmt=datefmt,
        )
    else:
        formatter = plain_formatter

    handler = logging.StreamHandler(stream=sys.stdout)
    handler.setLevel(logging.INFO)
    handler.setFormatter(formatter)
    # Setup a minimal configuration for logging in case something tries to
    # log a message even before logging is setup by MMF.
    logging.basicConfig(level=logging.INFO, handlers=[handler])
|
||||
|
||||
|
||||
# cache the opened file object, so that different calls to `setup_logger`
|
||||
# with the same file name can safely write to the same file.
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def _cached_log_stream(filename):
|
||||
return open(filename, "a")
|
||||
|
||||
|
||||
# ColorfulFormatter is adopted from Detectron2 and adapted for MMF
class ColorfulFormatter(logging.Formatter):
    """Formatter that prepends a colored severity tag to WARNING+ records."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def formatMessage(self, record):
        """Format `record`, tagging warnings/errors with a red marker."""
        message = super().formatMessage(record)
        level = record.levelno
        if level == logging.WARNING:
            tag = colored("WARNING", "red", attrs=["blink"])
        elif level in (logging.ERROR, logging.CRITICAL):
            tag = colored("ERROR", "red", attrs=["blink", "underline"])
        else:
            return message
        return tag + " " + message
|
||||
|
||||
|
||||
class TensorboardLogger:
    """Tensorboard writer wrapper that is active only on the main process."""

    def __init__(self, log_folder="./logs", iteration=0):
        # This would handle warning of missing tensorboard
        from torch.utils.tensorboard import SummaryWriter

        self.summary_writer = None
        self._is_master = is_main_process()
        self.log_folder = log_folder

        if self._is_master:
            # Timestamped subfolder keeps runs from different launches apart.
            timestamp = time.strftime("%Y-%m-%dT%H:%M:%S")
            tb_dir = os.path.join(self.log_folder, f"tensorboard_{timestamp}")
            self.summary_writer = SummaryWriter(tb_dir)

    def __del__(self):
        # Flush and release the writer when the logger is garbage-collected.
        writer = getattr(self, "summary_writer", None)
        if writer is not None:
            writer.close()

    def _should_log_tensorboard(self):
        # Only the master process with an initialized writer may log.
        return self.summary_writer is not None and self._is_master

    def add_scalar(self, key, value, iteration):
        """Write one scalar under `key` at `iteration` (no-op off-master)."""
        if self._should_log_tensorboard():
            self.summary_writer.add_scalar(key, value, iteration)

    def add_scalars(self, scalar_dict, iteration):
        """Write each entry of `scalar_dict` as a scalar at `iteration`."""
        if self._should_log_tensorboard():
            for name, scalar in scalar_dict.items():
                self.summary_writer.add_scalar(name, scalar, iteration)

    def add_histogram_for_model(self, model, iteration):
        """Record a histogram for every named parameter of `model`."""
        if self._should_log_tensorboard():
            for name, param in model.named_parameters():
                weights = param.clone().cpu().data.numpy()
                self.summary_writer.add_histogram(name, weights, iteration)
|
Loading…
Add table
Add a link
Reference in a new issue