# VDGR/models/runner.py
import os
import os.path as osp
import json
from collections import deque
import time
import re
import shutil
import glob
import pickle
import gc
import numpy as np
import glog as log
try:
from apex import amp
except ModuleNotFoundError:
print('apex not found')
import torch
import torch.utils.data as tud
import torch.nn.functional as F
import torch.distributed as dist
from utils.data_utils import load_pickle_lines
from utils.visdial_metrics import SparseGTMetrics, NDCG, scores_to_ranks
import wandb
class Runner:
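    """Base training / evaluation runner for the VDGR VisDial model.

    Subclasses are expected to set up `self.model`, `self.optimizer` and
    `self.scheduler` and to implement `forward`; this class provides the
    training loop, evaluation, checkpointing and wandb logging.
    """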
def __init__(self, config):
self.config = config
if 'rank' in config:
self.gpu_rank = config['rank']
else:
self.gpu_rank = 0
self.epoch_idx = 0
self.max_metric = 0.
self.max_metric_epoch_idx = 0
self.na_str = 'N/A'
if self.config["max_ckpt_to_keep"] > 0:
self.checkpoint_queue = deque(
[], maxlen=config["max_ckpt_to_keep"])
self.metrics_queue = deque([], maxlen=config["max_ckpt_to_keep"])
self.setup_wandb()
def setup_wandb(self):
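        """Initialise a wandb run on rank 0 only; all other ranks keep `self.run = None`."""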
if self.gpu_rank == 0:
print("[INFO] Set wandb logging on rank {}".format(0))
run = wandb.init(
project=self.config['wandb_project'], config=self.config, mode=self.config['wandb_mode'])
else:
run = None
self.run = run
def forward(self, batch, eval_visdial=False):
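        """Run a forward pass on `batch`; subclasses must implement this and return a dict
        with a 'losses' entry (and 'nsp_scores' when eval_visdial is True)."""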
        raise NotImplementedError
def train(self, dataset, dataset_eval=None):
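        """Train the model on `dataset`, optionally validating on `dataset_eval`.

        Handles checkpoint resumption, distributed samplers, gradient accumulation,
        periodic logging and evaluation, early stopping, and best-checkpoint tracking.
        """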
# wandb.login()
if os.path.exists(self.config['log_dir']) or self.config['loads_ckpt'] or self.config['loads_best_ckpt']:
self.load_ckpt()
if self.config['use_trainval']:
dataset.split = 'trainval'
else:
dataset.split = 'train'
batch_size = self.config['batch_size']
if self.config['parallel'] and self.config['dp_type'] != 'dp':
sampler_tr = tud.distributed.DistributedSampler(
dataset,
num_replicas=self.config['num_gpus'],
rank=self.gpu_rank
)
else:
sampler_tr = None
data_loader_tr = tud.DataLoader(
dataset=dataset,
batch_size=batch_size,
shuffle=self.config['training'] and not self.config['parallel'],
collate_fn=dataset.collate_fn,
num_workers=self.config['num_workers'],
sampler=sampler_tr
)
start_epoch_idx = self.epoch_idx
num_iter_epoch = self.config['num_iter_per_epoch']
if self.config['display']:
log.info(f'{num_iter_epoch} iter per epoch.')
# eval before training
eval_dense_at_first = self.config['train_on_dense'] and self.config['skip_mrr_eval'] and start_epoch_idx == 0
        # Evaluate before training in two cases:
        #   1) dense finetuning: compute NDCG once before the first epoch;
        #   2) MRR training: resuming from a checkpoint whose last epoch has not been evaluated yet.
if (eval_dense_at_first or (self.config['eval_at_start'] and len(self.metrics_queue) == 0 and start_epoch_idx > 0)):
if eval_dense_at_first:
iter_now = 0
else:
iter_now = max(num_iter_epoch * start_epoch_idx, 0)
if dataset_eval is None:
dataset.split = 'val'
dataset_to_eval = dataset
else:
dataset_to_eval = dataset_eval
metrics_results = {}
metrics_to_maximize, metrics_results['val'] = self.evaluate(
dataset_to_eval, iter_now)
if eval_dense_at_first:
self.max_metric = metrics_to_maximize
self.max_metric_epoch_idx = -1
else:
if self.config['display']:
self.save_eval_results(
'val', start_epoch_idx - 1, metrics_results)
if metrics_to_maximize > self.max_metric:
self.max_metric = metrics_to_maximize
self.max_metric_epoch_idx = start_epoch_idx - 1
self.copy_best_results('val', start_epoch_idx - 1)
self.copy_best_predictions('val')
if dataset_eval is None:
if self.config['use_trainval']:
dataset.split = 'trainval'
else:
dataset.split = 'train'
num_epochs = self.config['num_epochs']
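        # epoch loop: train, checkpoint, then (optionally) evaluate at the end of each epoch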
for epoch_idx in range(start_epoch_idx, num_epochs):
if self.config['parallel'] and self.config['dp_type'] != 'dp':
sampler_tr.set_epoch(epoch_idx)
self.epoch_idx = epoch_idx
if self.config['display']:
log.info(f'starting epoch {epoch_idx}')
log.info('training')
self.model.train()
num_batch = 0
next_logging_pct = .1
next_evaluating_pct = self.config["next_evaluating_pct"] + .1
start_time = time.time()
self.optimizer.zero_grad()
for batch in data_loader_tr:
if self.config['eval_before_training']:
                    log.info('Skipping straight to evaluation...')
break
num_batch += 1
pct = num_batch / num_iter_epoch * 100
iter_now = num_iter_epoch * epoch_idx + num_batch
output = self.forward(batch)
losses = output['losses']
# optimizer step
losses['tot_loss'] /= self.config['batch_multiply']
# debug
if self.config['debugging']:
log.info('try backward')
if self.config['dp_type'] == 'apex':
with amp.scale_loss(losses['tot_loss'], self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
losses['tot_loss'].backward()
if self.config['debugging']:
log.info('backward done')
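                # gradient accumulation: the loss was scaled by 1/batch_multiply above and
                # the optimizer / scheduler only step every batch_multiply iterations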
if iter_now % self.config['batch_multiply'] == 0:
self.optimizer.step()
self.optimizer.zero_grad()
self.scheduler.step()
# display and eval
if pct >= next_logging_pct:
if self.config['display']:
loss_to_print = ''
for key in losses:
if losses[key] is not None and isinstance(losses[key], torch.Tensor):
loss_to_print += f'[{key}: {losses[key].item():.4f}]'
print(
f'[{int(pct)}%][Epoch: {epoch_idx + 1}/{num_epochs}][Iter : {num_batch}/{len(data_loader_tr)}] [time: {time.time() - start_time:.2f}] {loss_to_print}'
)
next_logging_pct += self.config["next_logging_pct"]
if self.config['debugging']:
break
if pct >= next_evaluating_pct:
next_evaluating_pct += self.config["next_evaluating_pct"]
if self.run:
if self.config['train_on_dense']:
self.run.log(
{
"Train/dense_loss": losses['dense_loss'],
"Train/total_loss": losses['tot_loss'],
},
step=iter_now
)
else:
self.run.log(
{
"Train/lm_loss": losses['lm_loss'],
"Train/img_loss": losses['img_loss'],
"Train/nsp_loss": losses['nsp_loss'],
"Train/total_loss": losses['tot_loss'],
},
step=iter_now
)
lr_gnn, lr_bert = self.scheduler.get_lr()[0], self.scheduler.get_lr()[1]
self.run.log(
{
"Train/lr_gnn": lr_gnn,
"Train/lr_bert": lr_bert,
},
step=iter_now
)
del losses
# debug
torch.cuda.empty_cache()
if self.config['display']:
log.info(
f'100%,\ttime:\t{time.time() - start_time:.2f}'
)
ckpt_path = self.save_ckpt()
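            # end-of-epoch validation: evaluate on the val split and keep the best checkpoint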
if not self.config['skip_visdial_eval'] and self.epoch_idx % self.config['eval_visdial_every'] == 0:
iter_now = num_iter_epoch * (epoch_idx + 1)
if dataset_eval is None:
dataset.split = 'val'
dataset_to_eval = dataset
else:
dataset_to_eval = dataset_eval
metrics_results = {}
metrics_to_maximize, metrics_results['val'] = self.evaluate(
dataset_to_eval, iter_now)
if dataset_eval is None:
if self.config['use_trainval']:
dataset.split = 'trainval'
else:
dataset.split = 'train'
if self.config['display']:
self.save_eval_results('val', epoch_idx, metrics_results)
if self.config['display']:
if metrics_to_maximize > self.max_metric:
self.max_metric = metrics_to_maximize
self.max_metric_epoch_idx = epoch_idx
self.copy_best_results('val', epoch_idx)
self.copy_best_predictions('val')
elif not self.config['parallel'] and epoch_idx - self.max_metric_epoch_idx > self.config["early_stop_epoch"]:
log.info('Early stop.')
break
if self.run:
self.run.log(
{"Val/metric_best": self.max_metric}, step=iter_now)
if self.config['parallel']:
if self.config['dp_type'] == 'dp':
gc.collect()
torch.cuda.empty_cache()
else:
dist.barrier()
log.info('Rank {} passed barrier...'.format(self.gpu_rank))
if self.config['stop_epochs'] >= 0 and epoch_idx + 1 >= self.config['stop_epochs']:
if self.config['display']:
log.info('Stop for reaching stop_epochs.')
break
def evaluate(self, dataset, training_iter=None, eval_visdial=True):
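        """Evaluate (or predict) on `dataset`.

        In prediction mode, per-dialog ranks / scores are written to a jsonlines or
        pickle file; in evaluation mode, per-dialog outputs are saved as .npz files
        and aggregated into MRR / NDCG metrics on the display process. Returns the
        metric to maximize and a dict of all VisDial metrics, or (None, None) when
        predicting or on non-display ranks.
        """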
# create files to save output
if self.config['predicting']:
visdial_file_name = None
if self.config['save_score']:
visdial_file_name = osp.join(
self.config['log_dir'], f'visdial_prediction.pkl')
if osp.exists(visdial_file_name):
dialogs_predicted = load_pickle_lines(
visdial_file_name)
dialogs_predicted = [d['image_id']
for d in dialogs_predicted]
else:
dialogs_predicted = []
f_visdial = open(visdial_file_name, 'ab')
else:
visdial_file_name = osp.join(
self.config['log_dir'], f'visdial_prediction.jsonlines')
if self.config['parallel'] and self.config['dp_type'] != 'dp':
visdial_file_name = visdial_file_name.replace(
'.jsonlines', f'_{self.config["rank"]}of{self.config["num_gpus"]}.jsonlines')
if osp.exists(visdial_file_name):
dialogs_predicted_visdial = [json.loads(
line)['image_id'] for line in open(visdial_file_name)]
f_visdial = open(visdial_file_name, 'a')
else:
dialogs_predicted_visdial = []
f_visdial = open(visdial_file_name, 'w')
dialogs_predicted = dialogs_predicted_visdial
if len(dialogs_predicted) > 0:
log.info(f'Found {len(dialogs_predicted)} predicted results.')
if self.config['display']:
if visdial_file_name is not None:
log.info(
f'VisDial predictions saved to {visdial_file_name}')
elif self.config['display']:
if self.config['continue_evaluation']:
predicted_files = os.listdir(
osp.join(self.config['visdial_output_dir'], dataset.split))
dialogs_predicted = [
int(re.match(r'(\d+).npz', p).group(1)) for p in predicted_files]
else:
if osp.exists(osp.join(self.config['visdial_output_dir'], dataset.split)):
shutil.rmtree(
osp.join(self.config['visdial_output_dir'], dataset.split))
os.makedirs(
osp.join(self.config['visdial_output_dir'], dataset.split))
dialogs_predicted = []
log.info(f'Found {len(dialogs_predicted)} predicted results.')
if self.config['parallel'] and self.config['dp_type'] != 'dp':
sampler_val = tud.distributed.DistributedSampler(
dataset,
num_replicas=self.config['num_gpus'],
rank=self.gpu_rank
)
sampler_val.set_epoch(self.epoch_idx)
else:
sampler_val = None
data_loader_val = tud.DataLoader(
dataset=dataset,
batch_size=self.config['eval_batch_size'],
shuffle=False,
collate_fn=dataset.collate_fn,
num_workers=self.config['num_workers'],
sampler=sampler_val
)
self.model.eval()
with torch.no_grad():
if self.config['display']:
log.info(f'Evaluating {len(dataset)} samples')
next_logging_pct = self.config["next_logging_pct"] + .1
if self.config['parallel'] and self.config['dp_type'] == 'dp':
num_batch_tot = int(
np.ceil(len(dataset) / self.config['eval_batch_size']))
else:
num_batch_tot = int(np.ceil(
len(dataset) / (self.config['eval_batch_size'] * self.config['num_gpus'])))
num_batch = 0
if dataset.split == 'val':
num_options = self.config["num_options"]
if self.config['skip_mrr_eval']:
num_rounds = 1
else:
num_rounds = 10
elif dataset.split == 'test':
num_options = 100
num_rounds = 1
if self.gpu_rank == 0:
start_time = time.time()
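            # main evaluation loop over the validation / test dialogs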
for batch in data_loader_val:
num_batch += 1
# skip dialogs that have been predicted
if self.config['predicting']:
image_ids = batch['image_id'].tolist()
skip_batch = True
for image_id in image_ids:
if image_id not in dialogs_predicted:
skip_batch = False
if skip_batch:
continue
output = self.forward(
batch, eval_visdial=eval_visdial)
# visdial evaluation
if eval_visdial:
img_ids = batch['image_id'].tolist()
batch_size = len(img_ids)
if not self.config['skip_ndcg_eval']:
gt_relevance_round_id = batch['round_id'].tolist()
# [batch_size * num_rounds * num_options, 2]
nsp_scores = output['nsp_scores']
nsp_probs = F.softmax(nsp_scores, dim=1)
assert nsp_probs.shape[-1] == 2
                    # dim 1 has size 2: index 0 is the positive score, index 1 the negative
nsp_probs = nsp_probs[:, 0]
nsp_probs = nsp_probs.view(
batch_size, num_rounds, num_options)
# could be predicting or evaluating
if dataset.split == 'val':
if self.config['skip_ndcg_eval']:
gt_option_inds = batch['gt_option_inds']
for b in range(batch_size):
filename = osp.join(
self.config['visdial_output_dir'], dataset.split, f'{img_ids[b]}.npz')
if not osp.exists(filename):
np.savez(
filename,
nsp_probs=nsp_probs[b].cpu().numpy(),
gt_option_inds=gt_option_inds[b].cpu().numpy()
)
else:
# [batch_size, num_rounds]
gt_option_inds = batch['gt_option_inds']
# [batch_size, num_options]
gt_relevance = batch['gt_relevance']
for b in range(batch_size):
filename = osp.join(
self.config['visdial_output_dir'], dataset.split, f'{img_ids[b]}.npz')
if not osp.exists(filename):
np.savez(filename,
nsp_probs=nsp_probs[b].cpu().numpy(),
gt_option_inds=gt_option_inds[b].cpu(
).numpy(),
gt_relevance=gt_relevance[b].cpu(
).numpy(),
gt_relevance_round_id=gt_relevance_round_id[b])
# must be predicting
if dataset.split == 'test':
if self.config['save_score']:
for b in range(batch_size):
prediction = {
"image_id": img_ids[b],
"nsp_probs": nsp_probs[b].cpu().numpy(),
"gt_relevance_round_id": gt_relevance_round_id[b]
}
pickle.dump(prediction, f_visdial)
else:
# [eval_batch_size, num_rounds, num_options]
ranks = scores_to_ranks(nsp_probs)
ranks = ranks.squeeze(1)
for b in range(batch_size):
prediction = {
"image_id": img_ids[b],
"round_id": gt_relevance_round_id[b],
"ranks": ranks[b].tolist()
}
f_visdial.write(json.dumps(prediction) + '\n')
# debug
if self.config['debugging']:
break
pct = num_batch / num_batch_tot * 100
if pct >= next_logging_pct:
if self.config['display'] and self.gpu_rank == 0:
log.info(
f'{int(pct)}%,\ttime:\t{time.time() - start_time:.2f}'
)
next_logging_pct += self.config["next_logging_pct"]
# debug
if self.config['debugging']:
break
if self.config['display'] and self.gpu_rank == 0:
pct = num_batch / num_batch_tot * 100
log.info(
f'{int(pct)}%,\ttime:\t{time.time() - start_time:.2f}'
)
if not self.config['validating']:
self.model.train()
if self.config['parallel'] and self.config['dp_type'] != 'dp':
dist.barrier()
print(f'{self.gpu_rank} passed barrier')
if self.config['predicting']:
f_visdial.close()
if not self.config['save_score']:
all_visdial_predictions = [json.loads(
line) for line in open(visdial_file_name)]
if self.config['predict_split'] == 'test' and len(all_visdial_predictions) == self.config['num_test_dialogs']:
visdial_file_name = visdial_file_name.replace(
'jsonlines', 'json')
with open(visdial_file_name, 'w') as f_visdial:
json.dump(all_visdial_predictions, f_visdial)
                    log.info(
                        f'Prediction for submission saved to {visdial_file_name}.')
return None, None
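        # the display process aggregates the per-dialog .npz outputs into sparse (MRR / recall) and dense (NDCG) metrics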
if self.config['display']:
if dataset.split == 'val' and eval_visdial:
if not self.config['skip_mrr_eval']:
sparse_metrics = SparseGTMetrics()
if not self.config['skip_ndcg_eval']:
ndcg = NDCG()
if dataset.split == 'val' and eval_visdial:
visdial_output_filenames = glob.glob(
osp.join(self.config['visdial_output_dir'], dataset.split, '*.npz'))
log.info(
f'Calculating visdial metrics for {len(visdial_output_filenames)} dialogs')
for visdial_output_filename in visdial_output_filenames:
output = np.load(visdial_output_filename)
nsp_probs = torch.from_numpy(
output['nsp_probs']).unsqueeze(0)
if not self.config['skip_ndcg_eval']:
gt_relevance = torch.from_numpy(output['gt_relevance']).unsqueeze(0)
if not self.config['skip_mrr_eval']:
gt_option_inds = torch.from_numpy(
output['gt_option_inds']).unsqueeze(0)
sparse_metrics.observe(nsp_probs, gt_option_inds)
if not self.config['skip_ndcg_eval']:
gt_relevance_round_id = output['gt_relevance_round_id']
nsp_probs_dense = nsp_probs[0, gt_relevance_round_id - 1, :].unsqueeze(0)
else:
nsp_probs_dense = nsp_probs.squeeze(0) # [1, 100]
if not self.config['skip_ndcg_eval']:
ndcg.observe(nsp_probs_dense, gt_relevance)
# visdial eval output
visdial_metrics = {}
if dataset.split == 'val' and eval_visdial:
if not self.config['skip_mrr_eval']:
visdial_metrics.update(sparse_metrics.retrieve(reset=True))
if not self.config['skip_ndcg_eval']:
visdial_metrics.update(ndcg.retrieve(reset=True))
if self.config['display']:
to_print = ''
for metric_name, metric_value in visdial_metrics.items():
if 'round' not in metric_name:
to_print += f"\n{metric_name}: {metric_value}"
if training_iter is not None:
if self.run:
self.run.log(
{'Val/' + metric_name: metric_value}, step=training_iter)
log.info(to_print)
if self.config['metrics_to_maximize'] in visdial_metrics:
metrics_to_maximize = visdial_metrics[self.config['metrics_to_maximize']]
else:
metrics_to_maximize = None
torch.cuda.empty_cache()
return metrics_to_maximize, visdial_metrics
else:
torch.cuda.empty_cache()
return None, None
def save_eval_results(self, split, epoch_idx, metrics_results):
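        """Dump `metrics_results` for one epoch to metrics_epoch_{epoch_idx}.json and rotate old metric files."""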
metrics_filename = osp.join(
self.config['log_dir'], f'metrics_epoch_{epoch_idx}.json')
with open(metrics_filename, 'w') as f:
json.dump(metrics_results, f)
log.info(f'Results of metrics saved to {metrics_filename}')
if self.config["max_ckpt_to_keep"] > 0:
if len(self.metrics_queue) == self.metrics_queue.maxlen:
todel = self.metrics_queue.popleft()
os.remove(todel)
self.metrics_queue.append(metrics_filename)
if epoch_idx == 'best':
self.copy_best_predictions(split)
def copy_best_results(self, split, epoch_idx):
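        """Copy the checkpoint (unless skip_saving_ckpt) and metrics file of `epoch_idx` to their 'best' counterparts."""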
to_print = 'Copy '
if not self.config['skip_saving_ckpt']:
ckpt_path = osp.join(
self.config['log_dir'], f'epoch_{epoch_idx}.ckpt')
best_ckpt_path = ckpt_path.replace(
f'{epoch_idx}.ckpt', 'best.ckpt')
shutil.copyfile(ckpt_path, best_ckpt_path)
to_print += best_ckpt_path + ' '
metrics_filename = osp.join(
self.config['log_dir'], f'metrics_epoch_{epoch_idx}.json')
best_metric_filename = metrics_filename.replace(
f'{epoch_idx}.json', 'best.json')
shutil.copyfile(metrics_filename, best_metric_filename)
to_print += best_metric_filename + ' '
log.info(to_print)
def copy_best_predictions(self, split):
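        """Copy the saved per-dialog predictions of `split` into the corresponding *_best output directory."""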
to_print = 'Copy '
visdial_output_dir = osp.join(self.config['visdial_output_dir'], split)
if osp.exists(visdial_output_dir):
dir_best = visdial_output_dir.replace('output', 'output_best')
if osp.exists(dir_best):
shutil.rmtree(dir_best)
shutil.copytree(visdial_output_dir, dir_best)
to_print += dir_best + ' '
log.info(to_print)
def get_ckpt(self):
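        """Collect the current training state (epoch, best metric, seed, model, optimizer,
        scheduler, and amp state when using apex) into a dict."""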
ckpt = {
'epoch_idx': self.epoch_idx,
'max_metric': self.max_metric,
'seed': self.config['random_seed'],
'optimizer': self.optimizer.state_dict(),
'scheduler': self.scheduler.state_dict()
}
if self.config['parallel']:
ckpt['model_state_dict'] = self.model.module.state_dict()
else:
ckpt['model_state_dict'] = self.model.state_dict()
if self.config['dp_type'] == 'apex':
ckpt['amp'] = amp.state_dict()
return ckpt
def set_ckpt(self, ckpt_dict):
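        """Restore training state from a checkpoint dict.

        Only model parameters whose names match the current model are loaded;
        optimizer / scheduler / amp state is restored unless validating or a new
        optimizer / learning rate is requested.
        """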
if not self.config['restarts']:
self.epoch_idx = ckpt_dict.get('epoch_idx', -1) + 1
if not self.config['resets_max_metric']:
self.max_metric = ckpt_dict.get('max_metric', -1)
if self.config['parallel']:
model = self.model.module
else:
model = self.model
model_state_dict = model.state_dict()
former_dict = {
k: v for k, v in ckpt_dict['model_state_dict'].items() if k in model_state_dict}
if self.config['display']:
log.info("number of keys transferred: %d" % len(former_dict))
assert len(former_dict.keys()) > 0
model_state_dict.update(former_dict)
model.load_state_dict(model_state_dict)
if self.config['display']:
log.info('loaded model')
del model_state_dict, former_dict
if not self.config['validating'] and not (self.config['uses_new_optimizer'] or self.config['sets_new_lr']):
if 'optimizer' in ckpt_dict:
self.optimizer.load_state_dict(ckpt_dict['optimizer'])
if self.config['display']:
log.info('loaded optimizer')
if 'scheduler' in ckpt_dict:
                self.scheduler.last_epoch = ckpt_dict['epoch_idx'] * \
                    self.config['num_iter_per_epoch']
                self.scheduler.load_state_dict(ckpt_dict['scheduler'])
if 'amp' in ckpt_dict and self.config['dp_type'] == 'apex':
amp.load_state_dict(ckpt_dict['amp'])
del ckpt_dict
torch.cuda.empty_cache()
def save_ckpt(self):
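        """Save the current training state to epoch_{epoch_idx}.ckpt and rotate old checkpoints;
        returns the checkpoint path (nothing is written when skip_saving_ckpt is set)."""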
ckpt_path = f'{self.config["log_dir"]}/epoch_{self.epoch_idx}.ckpt'
log.info(f'saving checkpoint {ckpt_path}')
ckpt = self.get_ckpt()
if self.config['skip_saving_ckpt']:
return ckpt_path
        # use the legacy (non-zipfile) serialization on newer torch so the checkpoint
        # stays loadable with older torch versions; parse major/minor explicitly, since
        # float(torch.__version__[:3]) mis-parses versions such as 1.10
        torch_major, torch_minor = (int(v) for v in torch.__version__.split('.')[:2])
        if (torch_major, torch_minor) > (1, 4):
            torch.save(ckpt, f=ckpt_path, _use_new_zipfile_serialization=False)
        else:
            torch.save(ckpt, f=ckpt_path)
del ckpt
if not (self.config['parallel'] and self.config['dp_type'] in ['ddp', 'apex']):
torch.cuda.empty_cache()
if self.config["max_ckpt_to_keep"] > 0:
if len(self.checkpoint_queue) == self.checkpoint_queue.maxlen:
todel = self.checkpoint_queue.popleft()
os.remove(todel)
self.checkpoint_queue.append(ckpt_path)
return ckpt_path
def save_ckpt_best(self):
ckpt_path = f'{self.config["log_dir"]}/epoch_best.ckpt'
log.info(f'saving checkpoint {ckpt_path}')
ckpt = self.get_ckpt()
torch.save(ckpt, f=ckpt_path)
del ckpt
return ckpt_path
def load_ckpt_best(self):
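        """Load epoch_best.ckpt from the parent of log_dir, falling back to the most recent epoch checkpoint in log_dir."""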
ckpt_path = f'{osp.dirname(self.config["log_dir"])}/epoch_best.ckpt'
if not osp.exists(ckpt_path):
ckpt_paths = [path for path in os.listdir(
f'{self.config["log_dir"]}/') if path.endswith('.ckpt') and 'best' not in path]
if len(ckpt_paths) == 0:
if self.config['display']:
log.info(f'No .ckpt found in {self.config["log_dir"]}')
return
def sort_func(x): return int(re.search(r"(\d+)", x).groups()[0])
ckpt_path = f'{self.config["log_dir"]}/{sorted(ckpt_paths, key=sort_func, reverse=True)[0]}'
if self.config['display']:
log.info(f'loading checkpoint {ckpt_path}')
map_location = {'cuda:0': f'cuda:{self.gpu_rank}'}
self.set_ckpt(torch.load(ckpt_path, map_location=map_location))
def load_ckpt(self, ckpt_path=None):
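        """Load a checkpoint: epoch_best.ckpt when validating or loads_best_ckpt is set,
        otherwise the most recent epoch_*.ckpt found in log_dir."""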
if not ckpt_path:
if self.config['validating'] or self.config['loads_best_ckpt']:
ckpt_path = f'{self.config["log_dir"]}/epoch_best.ckpt'
else:
ckpt_paths = [path for path in os.listdir(
f'{self.config["log_dir"]}/') if path.endswith('.ckpt') and 'best' not in path]
if len(ckpt_paths) == 0:
if self.config['display']:
log.info(f'No .ckpt found in {self.config["log_dir"]}')
return
def sort_func(x): return int(
re.search(r"(\d+)", x).groups()[0])
ckpt_path = f'{self.config["log_dir"]}/{sorted(ckpt_paths, key=sort_func, reverse=True)[0]}'
if self.config['display']:
log.info(f'loading checkpoint {ckpt_path}')
epoch_name = osp.split(ckpt_path)[1].split('.')[0]
if re.search(r"(\d+)", epoch_name):
self.checkpoint_queue.append(ckpt_path)
metrics_filename = osp.join(
self.config['log_dir'], f'metrics_{epoch_name}.json')
if osp.exists(metrics_filename):
self.metrics_queue.append(metrics_filename)
map_location = {'cuda:0': f'cuda:{self.gpu_rank}'}
self.set_ckpt(torch.load(ckpt_path, map_location=map_location))
def match_model_key(self, pretrained_dict, model_dict):
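        """Match keys of a pretrained state dict to the current model's keys.

        Tolerates missing or extra 'encoder.' / 'module.' prefixes and logs the
        model keys that could not be matched.
        """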
matched_dict = dict()
for key in pretrained_dict:
if key in model_dict:
matched_key = key
elif key.startswith('encoder.') and key[8:] in model_dict:
matched_key = key[8:]
elif key.startswith('module.') and key[7:] in model_dict:
matched_key = key[7:]
elif 'encoder.' + key in model_dict:
matched_key = 'encoder.' + key
elif 'module.' + key in model_dict:
matched_key = 'module.' + key
else:
# not_found.append(key)
continue
matched_dict[matched_key] = pretrained_dict[key]
not_found = ""
for k in model_dict:
if k not in matched_dict:
not_found += k + '\n'
log.info("Keys from model_dict that were not found in pretrained_dict:")
log.info(not_found)
return matched_dict
def load_pretrained_vilbert(self, start_from=None):
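        """Initialise the model from a pretrained ViLBERT checkpoint (config['start_path'])
        unless training checkpoints already exist in log_dir."""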
if start_from is not None:
self.config["start_path"] = start_from
if self.config['training'] or self.config['debugging']:
ckpt_paths = [path for path in os.listdir(
f'{self.config["log_dir"]}/') if path.endswith('.ckpt') and 'best' not in path]
if len(ckpt_paths) > 0:
if self.config['display']:
log.info('Continue training')
return
if self.config['display']:
log.info(
f'Loading pretrained VilBERT from {self.config["start_path"]}')
map_location = {'cuda:0': f'cuda:{self.gpu_rank}'}
pretrained_dict = torch.load(
self.config['start_path'], map_location=map_location)
if 'model_state_dict' in pretrained_dict:
pretrained_dict = pretrained_dict['model_state_dict']
if self.config['parallel']:
model = self.model.module
else:
model = self.model
model_dict = model.state_dict()
matched_dict = self.match_model_key(pretrained_dict, model_dict)
if self.config['display']:
log.info("number of keys transferred: %d" % len(matched_dict))
assert len(matched_dict.keys()) > 0
model_dict.update(matched_dict)
model.load_state_dict(model_dict)
del pretrained_dict, model_dict, matched_dict
if not self.config['parallel'] or self.config['dp_type'] == 'dp':
torch.cuda.empty_cache()
if self.config['display']:
log.info(f'Pretrained VilBERT loaded')
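
# A minimal usage sketch (the subclass name and variables below are illustrative,
# not part of this file):
#
#     class VDGRRunner(Runner):
#         def forward(self, batch, eval_visdial=False):
#             ...  # must return a dict with 'losses' (and 'nsp_scores' for evaluation)
#
#     runner = VDGRRunner(config)
#     runner.load_pretrained_vilbert()
#     runner.train(train_dataset, dataset_eval=val_dataset)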