import wandb
import os
import os.path as osp
import json
from collections import deque
import time
import re
import shutil
import gc

import numpy as np
import glog as log
import torch
import torch.utils.data as tud
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.utils import clip_grad_value_


class Runner:
    def __init__(self, config):
        self.config = config
        if 'rank' in config:
            self.gpu_rank = config['rank']
        else:
            self.gpu_rank = 0

        self.epoch_idx = 0
        self.min_gen_val_loss = float('inf')
        self.best_epoch_idx = 0

        if self.config["max_ckpt_to_keep"] > 0:
            self.checkpoint_queue = deque([], maxlen=config["max_ckpt_to_keep"])
            self.metrics_queue = deque([], maxlen=config["max_ckpt_to_keep"])

        self.setup_wandb()

    def setup_wandb(self):
        # Only the rank-0 process logs to wandb; all other ranks get run=None.
        if self.gpu_rank == 0:
            print("[INFO] Set wandb logging on rank {}".format(0))
            run = wandb.init(
                project=self.config['wandb_project'],
                config=self.config,
                mode=self.config['wandb_mode'])
        else:
            run = None
        self.run = run

    def forward(self, batch, eval=False):
        raise NotImplementedError

    def train(self, dataset, dataset_eval):
        batch_size = self.config['batch_size']
        if self.config['parallel'] and self.config['dp_type'] != 'dp':
            sampler = tud.distributed.DistributedSampler(
                dataset,
                num_replicas=self.config['num_gpus'],
                rank=self.gpu_rank
            )
        else:
            sampler = None

        data_loader = tud.DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=self.config['training'] and not self.config['parallel'],
            collate_fn=dataset.collate_fn,
            num_workers=self.config['num_workers'],
            sampler=sampler
        )

        start_epoch_idx = self.epoch_idx
        num_iter_epoch = self.config['num_iter_per_epoch']
        if self.config['display']:
            log.info(f'{num_iter_epoch} iter per epoch.')

        num_epochs = self.config['num_epochs']

        # Perform validation before training
        if self.config['eval_first']:
            _ = self.val(dataset_eval)

        for epoch_idx in range(start_epoch_idx, num_epochs):
            if self.config['parallel'] and self.config['dp_type'] != 'dp':
                sampler.set_epoch(epoch_idx)

            self.epoch_idx = epoch_idx

            if self.config['display']:
                log.info(f'starting epoch {epoch_idx}')
                log.info('training')

            self.model.train()

            num_batch = 0
            next_logging_pct = .1
            start_time = time.time()
            self.optimizer.zero_grad()

            for batch in data_loader:
                num_batch += 1
                pct = num_batch / num_iter_epoch * 100
                iter_now = num_iter_epoch * epoch_idx + num_batch

                output = self.forward(batch)
                losses = output['losses']

                # optimizer step
                losses['tot_loss'] /= self.config['batch_multiply']

                # debug
                if self.config['debugging']:
                    log.info('try backward')

                losses['tot_loss'].backward()
                if self.config['clip_grad_value'] > 0:
                    clip_grad_value_(self.model.parameters(), self.config['clip_grad_value'])
                if self.config['debugging']:
                    log.info('backward done')

                # Gradient accumulation: only step the optimizer every
                # `batch_multiply` iterations.
                if iter_now % self.config['batch_multiply'] == 0:
                    self.optimizer.step()
                    self.optimizer.zero_grad()
                self.scheduler.step()

                # display and eval
                if pct >= next_logging_pct:
                    if self.config['display']:
                        loss_to_print = ''
                        for key in losses:
                            if losses[key] is not None and isinstance(losses[key], torch.Tensor):
                                loss_to_print += f'[{key}: {losses[key].item():.4f}]'
                        print(
                            f'[{int(pct)}%][Epoch: {epoch_idx + 1}/{num_epochs}]'
                            f'[Iter: {num_batch}/{len(data_loader)}] '
                            f'[time: {time.time() - start_time:.2f}] {loss_to_print}'
                        )
                        if self.config['print_output']:
                            print(10 * '-' + 'responses' + 10 * '-')
                            print(output['responses'])
                            print(10 * '-' + 'gt' + 10 * '-')
                            print(output['gt'])

                    next_logging_pct += self.config["next_logging_pct"]

                    if self.config['debugging']:
                        break

                lr_bart, lr_other = self.scheduler.get_lr()[0], self.scheduler.get_lr()[-1]
                elbo_global_key = 'elbo_loss_global (x{})'.format(self.config['elbo_global_coeff'])
                elbo_local_key = 'elbo_loss_local (x{})'.format(self.config['elbo_local_coeff'])
                gen_key = 'gen_loss (x{})'.format(self.config['gen_coeff'])

                if self.run:
                    self.run.log(
                        {
                            f"Train/{gen_key}": losses[gen_key].item(),
                            f"Train/{elbo_global_key}": losses[elbo_global_key].item(),
                            f"Train/{elbo_local_key}": losses[elbo_local_key].item(),
                            "Train/total_loss": losses['tot_loss'].item(),
                        },
                        step=iter_now
                    )
                    self.run.log(
                        {"Train/lr_bart": lr_bart, "Train/lr_other": lr_other},
                        step=iter_now
                    )

                del losses
                del output

            if self.config['display']:
                log.info(f'100%,\ttime:\t{time.time() - start_time:.2f}')

            if not self.config['overfit'] and self.run:
                self.save_ckpt()

            if not self.config['skip_eval']:
                iter_now = num_iter_epoch * (epoch_idx + 1)
                val_losses = self.val(dataset_eval)
                if self.config['display']:
                    log.info('#' * 100)
                    for k in val_losses:
                        log.info('Average val {} (epoch {}) = {}'.format(k, self.epoch_idx, val_losses[k]))
                    log.info('#' * 100)

                gen_val_loss = val_losses[gen_key]
                if gen_val_loss < self.min_gen_val_loss:
                    self.min_gen_val_loss = gen_val_loss
                    self.best_epoch_idx = epoch_idx
                    # Log the best model w.r.t. the validation data
                    if self.run and self.config['save_ckpt']:
                        self.save_ckpt_best()

                if self.run:
                    self.run.log(
                        {
                            f"Val/{gen_key}": val_losses[gen_key],
                            f"Val/{elbo_global_key}": val_losses[elbo_global_key],
                            f"Val/{elbo_local_key}": val_losses[elbo_local_key],
                            "Val/total_loss": val_losses['tot_loss'],
                            "Val/min_gen_loss": self.min_gen_val_loss
                        },
                        step=iter_now
                    )

            if self.config['parallel']:
                if self.config['dp_type'] == 'dp':
                    gc.collect()
                    torch.cuda.empty_cache()
                else:
                    dist.barrier()
                    torch.cuda.empty_cache()

            if self.config['stop_epochs'] >= 0 and epoch_idx + 1 >= self.config['stop_epochs']:
                if self.config['display']:
                    log.info('Stop for reaching stop_epochs.')
                break

        if self.config['display']:
            log.info(f'Best validation loss was reached at epoch {self.best_epoch_idx}.')

    def val(self, dataset):
        total_loss_val = 0.0
        total_gen_loss_val = 0.0
        total_elbo_global_val = 0.0
        total_elbo_local_val = 0.0
        num_batch_val = 0
        next_logging_pct_val = 0.05

        elbo_global_key = 'elbo_loss_global (x{})'.format(self.config['elbo_global_coeff'])
        elbo_local_key = 'elbo_loss_local (x{})'.format(self.config['elbo_local_coeff'])
        gen_key = 'gen_loss (x{})'.format(self.config['gen_coeff'])

        # Prepare the dataloader
        if self.config['parallel'] and self.config['dp_type'] != 'dp':
            sampler_val = tud.distributed.DistributedSampler(
                dataset,
                num_replicas=self.config['num_gpus'],
                rank=self.gpu_rank
            )
            sampler_val.set_epoch(self.epoch_idx)
        else:
            sampler_val = None

        data_loader_val = tud.DataLoader(
            dataset=dataset,
            batch_size=self.config['batch_size'],
            shuffle=False,
            collate_fn=dataset.collate_fn,
            num_workers=self.config['num_workers'],
            sampler=sampler_val
        )

        if self.config['parallel'] and self.config['dp_type'] == 'dp':
            num_iter_per_epoch_val = int(np.ceil(len(dataset) / self.config['batch_size']))
        else:
            num_iter_per_epoch_val = int(np.ceil(len(dataset) / (self.config['batch_size'] * self.config['num_gpus'])))

        self.model.eval()

        if self.gpu_rank == 0:
            start_time = time.time()

        for batch in data_loader_val:
            num_batch_val += 1
            pct = num_batch_val / num_iter_per_epoch_val * 100

            with torch.no_grad():
                output = self.forward(batch)

            losses = output['losses']
            losses['tot_loss'] /= self.config['batch_multiply']
            losses[elbo_global_key] /= self.config['batch_multiply']
            losses[elbo_local_key] /= self.config['batch_multiply']
            losses[gen_key] /= self.config['batch_multiply']

            total_loss_val += losses['tot_loss'].item()
            total_gen_loss_val += losses[gen_key].item()
            total_elbo_global_val += losses[elbo_global_key].item()
            total_elbo_local_val += losses[elbo_local_key].item()

            # display and eval
            if pct >= next_logging_pct_val:
                if self.config['display']:
                    loss_to_print = ''
                    for key in losses:
                        if losses[key] is not None and isinstance(losses[key], torch.Tensor):
                            loss_to_print += f'[{key}: {losses[key].item():.4f}]'
                    print(
                        f'[{int(pct)}%][Validating][Iter: {num_batch_val}/{num_iter_per_epoch_val}] '
                        f'[time: {time.time() - start_time:.2f}] {loss_to_print}'
                    )
                next_logging_pct_val += self.config["next_logging_pct"]

        loss_val = total_loss_val / num_batch_val
        gen_loss_val = total_gen_loss_val / num_batch_val
        elbo_global_val = total_elbo_global_val / num_batch_val
        elbo_local_val = total_elbo_local_val / num_batch_val

        losses_val = {
            'tot_loss': loss_val,
            elbo_global_key: elbo_global_val,
            elbo_local_key: elbo_local_val,
            gen_key: gen_loss_val
        }

        self.model.train()
        return losses_val

    def save_eval_results(self, split, epoch_idx, metrics_results):
        metrics_filename = osp.join(self.config['log_dir'], f'metrics_epoch_{epoch_idx}.json')
        with open(metrics_filename, 'w') as f:
            json.dump(metrics_results, f)
        log.info(f'Results of metrics saved to {metrics_filename}')

        if self.config["max_ckpt_to_keep"] > 0:
            if len(self.metrics_queue) == self.metrics_queue.maxlen:
                todel = self.metrics_queue.popleft()
                os.remove(todel)
            self.metrics_queue.append(metrics_filename)

        if epoch_idx == 'best':
            self.copy_best_predictions(split)

    def copy_best_results(self, split, epoch_idx):
        to_print = 'Copy '

        if not self.config['skip_saving_ckpt']:
            ckpt_path = osp.join(self.config['log_dir'], f'epoch_{epoch_idx}.ckpt')
            best_ckpt_path = ckpt_path.replace(f'{epoch_idx}.ckpt', 'best.ckpt')
            shutil.copyfile(ckpt_path, best_ckpt_path)
            to_print += best_ckpt_path + ' '

        metrics_filename = osp.join(self.config['log_dir'], f'metrics_epoch_{epoch_idx}.json')
        best_metric_filename = metrics_filename.replace(f'{epoch_idx}.json', 'best.json')
        shutil.copyfile(metrics_filename, best_metric_filename)
        to_print += best_metric_filename + ' '
        log.info(to_print)

    def set_ckpt(self, ckpt_dict):
        if self.config['parallel']:
            model = self.model.module
        else:
            model = self.model
        model_state_dict = model.state_dict()

        # Only transfer weights whose keys exist in the current model.
        former_dict = {k: v for k, v in ckpt_dict['model_state_dict'].items() if k in model_state_dict}

        if self.config['display']:
            log.info("number of keys transferred: %d" % len(former_dict))
        assert len(former_dict.keys()) > 0

        model_state_dict.update(former_dict)
        model.load_state_dict(model_state_dict)
        if self.config['display']:
            log.info('loaded model')
        del model_state_dict, former_dict

        # if not self.config['uses_new_optimizer']:
        if not self.config['generating'] and not (self.config['uses_new_optimizer'] or self.config['sets_new_lr']):
            if not self.config['restarts']:
                self.epoch_idx = ckpt_dict['epoch_idx'] + 1
            if not self.config['resets_min_val_loss']:
                self.min_gen_val_loss = ckpt_dict['min_gen_val_loss']
            self.optimizer.load_state_dict(ckpt_dict['optimizer'])
            if self.config['display']:
                log.info('loaded optimizer')
            if 'scheduler' in ckpt_dict:
                self.scheduler.last_epoch = ckpt_dict['epoch_idx'] * self.config['num_iter_per_epoch']
                self.scheduler.load_state_dict(ckpt_dict['scheduler'])

        del ckpt_dict
        torch.cuda.empty_cache()

    def save_ckpt(self):
        ckpt_path = f'{self.config["log_dir"]}/epoch_{self.epoch_idx}.ckpt'
        log.info(f'saving checkpoint {ckpt_path}')
        ckpt = self.get_ckpt()
        if self.config['skip_saving_ckpt']:
            return ckpt_path
        # Use the legacy serialization format on torch > 1.4 so checkpoints
        # remain loadable with older torch versions.
        major, minor = (int(v) for v in torch.__version__.split('.')[:2])
        if (major, minor) > (1, 4):
            torch.save(ckpt, f=ckpt_path, _use_new_zipfile_serialization=False)
        else:
            torch.save(ckpt, f=ckpt_path)
        del ckpt

        if not self.config['parallel']:
            torch.cuda.empty_cache()

        # Keep at most `max_ckpt_to_keep` checkpoints on disk.
        if self.config["max_ckpt_to_keep"] > 0:
            if len(self.checkpoint_queue) == self.checkpoint_queue.maxlen:
                todel = self.checkpoint_queue.popleft()
                os.remove(todel)
            self.checkpoint_queue.append(ckpt_path)

    def save_ckpt_best(self):
        ckpt_path = f'{self.config["log_dir"]}/epoch_best.ckpt'
        log.info(f'saving checkpoint {ckpt_path}')
        ckpt = self.get_ckpt()
        torch.save(ckpt, f=ckpt_path)
        del ckpt
        return ckpt_path

    def load_ckpt_best(self):
        ckpt_path = f'{self.config["log_dir"]}/epoch_best.ckpt'
        if not osp.exists(ckpt_path):
            # Fall back to the checkpoint with the highest epoch number.
            ckpt_paths = [path for path in os.listdir(f'{self.config["log_dir"]}/')
                          if path.endswith('.ckpt') and 'best' not in path]
            if len(ckpt_paths) == 0:
                if self.config['display']:
                    log.info(f'No .ckpt found in {self.config["log_dir"]}')
                return
            sort_func = lambda x: int(re.search(r"(\d+)", x).groups()[0])
            ckpt_path = f'{self.config["log_dir"]}/{sorted(ckpt_paths, key=sort_func, reverse=True)[0]}'

        if self.config['display']:
            log.info(f'loading checkpoint {ckpt_path}')
        map_location = {'cuda:0': f'cuda:{self.gpu_rank}'}
        self.set_ckpt(torch.load(ckpt_path, map_location=map_location))

    def get_ckpt(self):
        ckpt = {
            'epoch_idx': self.epoch_idx,
            'min_gen_val_loss': self.min_gen_val_loss,
            'seed': self.config['random_seed'],
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict()
        }
        if self.config['parallel']:
            ckpt['model_state_dict'] = self.model.module.state_dict()
        else:
            ckpt['model_state_dict'] = self.model.state_dict()
        return ckpt

    def load_ckpt(self, ckpt_path=None):
        if not ckpt_path:
            if self.config['generating']:  # or self.config['start_ckpt_for_generating']:
                ckpt_path = f'{self.config["log_dir"]}/epoch_best.ckpt'
            else:
                ckpt_paths = [path for path in os.listdir(f'{self.config["log_dir"]}/')
                              if path.endswith('.ckpt') and 'best' not in path]
                if len(ckpt_paths) == 0:
                    if self.config['display']:
                        log.info(f'No .ckpt found in {self.config["log_dir"]}')
                    return
                sort_func = lambda x: int(re.search(r"(\d+)", x).groups()[0])
                ckpt_path = f'{self.config["log_dir"]}/{sorted(ckpt_paths, key=sort_func, reverse=True)[0]}'

        if self.config['display']:
            log.info(f'loading checkpoint {ckpt_path}')

        epoch_name = osp.split(ckpt_path)[1].split('.')[0]
        if re.search(r"(\d+)", epoch_name):
            self.checkpoint_queue.append(ckpt_path)
            metrics_filename = osp.join(self.config['log_dir'], f'metrics_{epoch_name}.json')
            if osp.exists(metrics_filename):
                self.metrics_queue.append(metrics_filename)

        map_location = {'cuda:0': f'cuda:{self.gpu_rank}'}
        self.set_ckpt(torch.load(ckpt_path, map_location=map_location))

    def match_model_key(self, pretrained_dict, model_dict):
        # Map pretrained keys onto model keys, tolerating 'encoder.'/'module.'
        # prefixes that wrappers may add or strip.
        matched_dict = dict()
        not_found = []
        for key in pretrained_dict:
            if key in model_dict:
                matched_key = key
            elif key.startswith('encoder.') and key[8:] in model_dict:
                matched_key = key[8:]
            elif key.startswith('module.') and key[7:] in model_dict:
                matched_key = key[7:]
            elif 'encoder.' + key in model_dict:
                matched_key = 'encoder.' + key
            elif 'module.' + key in model_dict:
                matched_key = 'module.' + key
            else:
                not_found.append(key)
                continue
            matched_dict[matched_key] = pretrained_dict[key]

        print("Keys from pretrained_dict that were not found in model_dict:\n", not_found)
        return matched_dict
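

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original training
# code). Runner is abstract: it expects a subclass to attach `self.model`,
# `self.optimizer`, and `self.scheduler`, and to override `forward()` so it
# returns {'losses': ...} with the keys the train/val loops read ('tot_loss',
# 'gen_loss (x...)', 'elbo_loss_global (x...)', 'elbo_loss_local (x...)').
# ToyDataset, ToyRunner, and every config value below are hypothetical
# stand-ins that only demonstrate the wiring on CPU with wandb disabled.
if __name__ == '__main__':
    class ToyDataset(tud.Dataset):
        """Random regression pairs; a collate_fn attribute is required by Runner."""

        def __len__(self):
            return 8

        def __getitem__(self, idx):
            return torch.randn(4), torch.randn(1)

        @staticmethod
        def collate_fn(items):
            xs, ys = zip(*items)
            return torch.stack(xs), torch.stack(ys)

    class ToyRunner(Runner):
        def __init__(self, config):
            super().__init__(config)
            self.model = torch.nn.Linear(4, 1)
            self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-3)
            self.scheduler = torch.optim.lr_scheduler.LambdaLR(
                self.optimizer, lr_lambda=lambda step: 1.0)

        def forward(self, batch, eval=False):
            x, y = batch
            gen_loss = F.mse_loss(self.model(x), y)  # stand-in generation loss
            zero = torch.zeros(())                   # stand-in ELBO terms
            losses = {
                'gen_loss (x{})'.format(self.config['gen_coeff']): gen_loss,
                'elbo_loss_global (x{})'.format(self.config['elbo_global_coeff']): zero,
                'elbo_loss_local (x{})'.format(self.config['elbo_local_coeff']): zero,
                'tot_loss': gen_loss,
            }
            return {'losses': losses}

    # Hypothetical config covering every key the base class reads on this path.
    config = dict(
        wandb_project='toy', wandb_mode='disabled', log_dir='logs',
        batch_size=2, num_workers=0, num_epochs=1, num_iter_per_epoch=4,
        batch_multiply=1, clip_grad_value=0, next_logging_pct=25,
        gen_coeff=1.0, elbo_global_coeff=1.0, elbo_local_coeff=1.0,
        parallel=False, dp_type='dp', num_gpus=1, training=True,
        display=True, print_output=False, debugging=False, eval_first=False,
        overfit=False, skip_eval=True, stop_epochs=-1, max_ckpt_to_keep=1,
        save_ckpt=False, skip_saving_ckpt=True, random_seed=0,
    )
    ToyRunner(config).train(ToyDataset(), ToyDataset())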