import os
import os.path as osp
import json
from collections import deque
import time
import re
import shutil
import glob
import pickle
import gc

import numpy as np
import glog as log

try:
    from apex import amp
except ModuleNotFoundError:
    print('apex not found')

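# Note: apex/amp is optional. It is only exercised when config['dp_type'] is
# 'apex' (the mixed-precision branches in train(), get_ckpt() and set_ckpt()
# below); plain DP/DDP runs do not need it.
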
import torch
import torch.utils.data as tud
import torch.nn.functional as F
import torch.distributed as dist

from utils.data_utils import load_pickle_lines
from utils.visdial_metrics import SparseGTMetrics, NDCG, scores_to_ranks
import wandb


class Runner:
    def __init__(self, config):
        self.config = config
        if 'rank' in config:
            self.gpu_rank = config['rank']
        else:
            self.gpu_rank = 0

        self.epoch_idx = 0
        self.max_metric = 0.
        self.max_metric_epoch_idx = 0
        self.na_str = 'N/A'

        if self.config["max_ckpt_to_keep"] > 0:
            self.checkpoint_queue = deque(
                [], maxlen=config["max_ckpt_to_keep"])
            self.metrics_queue = deque([], maxlen=config["max_ckpt_to_keep"])

        self.setup_wandb()

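    # Note: checkpoint_queue and metrics_queue (created in __init__ above) only
    # exist when config['max_ckpt_to_keep'] > 0; load_ckpt() and the
    # eval-at-start check in train() use them without re-checking that flag, so
    # a positive max_ckpt_to_keep is effectively assumed.
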
    def setup_wandb(self):
        if self.gpu_rank == 0:
            print("[INFO] Set wandb logging on rank {}".format(0))
            run = wandb.init(
                project=self.config['wandb_project'], config=self.config, mode=self.config['wandb_mode'])
        else:
            run = None
        self.run = run

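    # wandb.init() is only called on rank 0, so self.run is None on every other
    # rank; the `if self.run:` guards in train() and evaluate() rely on this.
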
    def forward(self, batch, eval_visdial=False):
        raise NotImplementedError

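    # Subclasses are expected to override forward(). From its usage below,
    # train() needs it to return a dict like {'losses': {...}}, where the losses
    # dict holds scalar tensors ('tot_loss' plus, depending on the mode,
    # 'lm_loss'/'img_loss'/'nsp_loss' or 'dense_loss'), and evaluate() also reads
    # output['nsp_scores'] with shape [batch_size * num_rounds * num_options, 2].
    # A minimal sketch of an override (the model call and its outputs are
    # illustrative, not part of this class):
    #
    #     def forward(self, batch, eval_visdial=False):
    #         losses, nsp_scores = self.model(batch)  # hypothetical model API
    #         return {'losses': losses, 'nsp_scores': nsp_scores}
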
    def train(self, dataset, dataset_eval=None):
        # wandb.login()
        if os.path.exists(self.config['log_dir']) or self.config['loads_ckpt'] or self.config['loads_best_ckpt']:
            self.load_ckpt()

        if self.config['use_trainval']:
            dataset.split = 'trainval'
        else:
            dataset.split = 'train'
        batch_size = self.config['batch_size']
        if self.config['parallel'] and self.config['dp_type'] != 'dp':
            sampler_tr = tud.distributed.DistributedSampler(
                dataset,
                num_replicas=self.config['num_gpus'],
                rank=self.gpu_rank
            )
        else:
            sampler_tr = None

        data_loader_tr = tud.DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=self.config['training'] and not self.config['parallel'],
            collate_fn=dataset.collate_fn,
            num_workers=self.config['num_workers'],
            sampler=sampler_tr
        )

        start_epoch_idx = self.epoch_idx
        num_iter_epoch = self.config['num_iter_per_epoch']
        if self.config['display']:
            log.info(f'{num_iter_epoch} iter per epoch.')

        # Evaluate before training in two circumstances:
        # 1) dense finetuning: evaluate NDCG once before the first epoch;
        # 2) resumed MRR training: the last finished epoch has not been evaluated yet.
        eval_dense_at_first = self.config['train_on_dense'] and self.config['skip_mrr_eval'] and start_epoch_idx == 0

        if eval_dense_at_first or (self.config['eval_at_start'] and len(self.metrics_queue) == 0 and start_epoch_idx > 0):
            if eval_dense_at_first:
                iter_now = 0
            else:
                iter_now = max(num_iter_epoch * start_epoch_idx, 0)

            if dataset_eval is None:
                dataset.split = 'val'
                dataset_to_eval = dataset
            else:
                dataset_to_eval = dataset_eval

            metrics_results = {}
            metrics_to_maximize, metrics_results['val'] = self.evaluate(
                dataset_to_eval, iter_now)
            if eval_dense_at_first:
                self.max_metric = metrics_to_maximize
                self.max_metric_epoch_idx = -1
            else:
                if self.config['display']:
                    self.save_eval_results(
                        'val', start_epoch_idx - 1, metrics_results)
                    if metrics_to_maximize > self.max_metric:
                        self.max_metric = metrics_to_maximize
                        self.max_metric_epoch_idx = start_epoch_idx - 1
                        self.copy_best_results('val', start_epoch_idx - 1)
                        self.copy_best_predictions('val')
            if dataset_eval is None:
                if self.config['use_trainval']:
                    dataset.split = 'trainval'
                else:
                    dataset.split = 'train'

        num_epochs = self.config['num_epochs']

        for epoch_idx in range(start_epoch_idx, num_epochs):
            if self.config['parallel'] and self.config['dp_type'] != 'dp':
                sampler_tr.set_epoch(epoch_idx)

            self.epoch_idx = epoch_idx

            if self.config['display']:
                log.info(f'starting epoch {epoch_idx}')
                log.info('training')

            self.model.train()

            num_batch = 0
            next_logging_pct = .1
            next_evaluating_pct = self.config["next_evaluating_pct"] + .1
            start_time = time.time()
            self.optimizer.zero_grad()

            for batch in data_loader_tr:
                if self.config['eval_before_training']:
                    log.info('Skipping straight to evaluation...')
                    break
                num_batch += 1
                pct = num_batch / num_iter_epoch * 100
                iter_now = num_iter_epoch * epoch_idx + num_batch

                output = self.forward(batch)
                losses = output['losses']

                # optimizer step
                losses['tot_loss'] /= self.config['batch_multiply']
                # debug
                if self.config['debugging']:
                    log.info('try backward')
                if self.config['dp_type'] == 'apex':
                    with amp.scale_loss(losses['tot_loss'], self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    losses['tot_loss'].backward()
                if self.config['debugging']:
                    log.info('backward done')

                if iter_now % self.config['batch_multiply'] == 0:
                    self.optimizer.step()
                    self.optimizer.zero_grad()
                    self.scheduler.step()
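
                # Gradient accumulation: the loss was divided by batch_multiply
                # above and gradients are accumulated for batch_multiply batches
                # before optimizer.step(), so the effective batch size is
                # batch_size * batch_multiply.
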
                # display and eval
                if pct >= next_logging_pct:
                    if self.config['display']:
                        loss_to_print = ''
                        for key in losses:
                            if losses[key] is not None and isinstance(losses[key], torch.Tensor):
                                loss_to_print += f'[{key}: {losses[key].item():.4f}]'
                        print(
                            f'[{int(pct)}%][Epoch: {epoch_idx + 1}/{num_epochs}][Iter: {num_batch}/{len(data_loader_tr)}] [time: {time.time() - start_time:.2f}] {loss_to_print}'
                        )

                    next_logging_pct += self.config["next_logging_pct"]

                    if self.config['debugging']:
                        break

                if pct >= next_evaluating_pct:
                    next_evaluating_pct += self.config["next_evaluating_pct"]

                if self.run:
                    if self.config['train_on_dense']:
                        self.run.log(
                            {
                                "Train/dense_loss": losses['dense_loss'],
                                "Train/total_loss": losses['tot_loss'],
                            },
                            step=iter_now
                        )
                    else:
                        self.run.log(
                            {
                                "Train/lm_loss": losses['lm_loss'],
                                "Train/img_loss": losses['img_loss'],
                                "Train/nsp_loss": losses['nsp_loss'],
                                "Train/total_loss": losses['tot_loss'],
                            },
                            step=iter_now
                        )

                    lr_gnn, lr_bert = self.scheduler.get_lr()[0], self.scheduler.get_lr()[1]
                    self.run.log(
                        {
                            "Train/lr_gnn": lr_gnn,
                            "Train/lr_bert": lr_bert,
                        },
                        step=iter_now
                    )

                del losses
                # debug
                torch.cuda.empty_cache()

            if self.config['display']:
                log.info(
                    f'100%,\ttime:\t{time.time() - start_time:.2f}'
                )
            ckpt_path = self.save_ckpt()

            if not self.config['skip_visdial_eval'] and self.epoch_idx % self.config['eval_visdial_every'] == 0:

                iter_now = num_iter_epoch * (epoch_idx + 1)

                if dataset_eval is None:
                    dataset.split = 'val'
                    dataset_to_eval = dataset
                else:
                    dataset_to_eval = dataset_eval
                metrics_results = {}
                metrics_to_maximize, metrics_results['val'] = self.evaluate(
                    dataset_to_eval, iter_now)
                if dataset_eval is None:
                    if self.config['use_trainval']:
                        dataset.split = 'trainval'
                    else:
                        dataset.split = 'train'
                if self.config['display']:
                    self.save_eval_results('val', epoch_idx, metrics_results)

                if self.config['display']:
                    if metrics_to_maximize > self.max_metric:
                        self.max_metric = metrics_to_maximize
                        self.max_metric_epoch_idx = epoch_idx
                        self.copy_best_results('val', epoch_idx)
                        self.copy_best_predictions('val')

                    elif not self.config['parallel'] and epoch_idx - self.max_metric_epoch_idx > self.config["early_stop_epoch"]:
                        log.info('Early stop.')
                        break

                    if self.run:
                        self.run.log(
                            {"Val/metric_best": self.max_metric}, step=iter_now)

            if self.config['parallel']:
                if self.config['dp_type'] == 'dp':
                    gc.collect()
                    torch.cuda.empty_cache()
                else:
                    dist.barrier()
                    log.info('Rank {} passed barrier...'.format(self.gpu_rank))

            if self.config['stop_epochs'] >= 0 and epoch_idx + 1 >= self.config['stop_epochs']:
                if self.config['display']:
                    log.info('Stopping: reached stop_epochs.')
                break

    def evaluate(self, dataset, training_iter=None, eval_visdial=True):
        # create files to save output
        if self.config['predicting']:
            visdial_file_name = None
            if self.config['save_score']:
                visdial_file_name = osp.join(
                    self.config['log_dir'], 'visdial_prediction.pkl')
                if osp.exists(visdial_file_name):
                    dialogs_predicted = load_pickle_lines(
                        visdial_file_name)
                    dialogs_predicted = [d['image_id']
                                         for d in dialogs_predicted]
                else:
                    dialogs_predicted = []
                f_visdial = open(visdial_file_name, 'ab')

            else:
                visdial_file_name = osp.join(
                    self.config['log_dir'], 'visdial_prediction.jsonlines')
                if self.config['parallel'] and self.config['dp_type'] != 'dp':
                    visdial_file_name = visdial_file_name.replace(
                        '.jsonlines', f'_{self.config["rank"]}of{self.config["num_gpus"]}.jsonlines')
                if osp.exists(visdial_file_name):
                    dialogs_predicted_visdial = [json.loads(
                        line)['image_id'] for line in open(visdial_file_name)]
                    f_visdial = open(visdial_file_name, 'a')
                else:
                    dialogs_predicted_visdial = []
                    f_visdial = open(visdial_file_name, 'w')

                dialogs_predicted = dialogs_predicted_visdial

            if len(dialogs_predicted) > 0:
                log.info(f'Found {len(dialogs_predicted)} predicted results.')

            if self.config['display']:
                if visdial_file_name is not None:
                    log.info(
                        f'VisDial predictions saved to {visdial_file_name}')

        elif self.config['display']:
            if self.config['continue_evaluation']:
                predicted_files = os.listdir(
                    osp.join(self.config['visdial_output_dir'], dataset.split))
                dialogs_predicted = [
                    int(re.match(r'(\d+).npz', p).group(1)) for p in predicted_files]
            else:
                if osp.exists(osp.join(self.config['visdial_output_dir'], dataset.split)):
                    shutil.rmtree(
                        osp.join(self.config['visdial_output_dir'], dataset.split))
                os.makedirs(
                    osp.join(self.config['visdial_output_dir'], dataset.split))

                dialogs_predicted = []
            log.info(f'Found {len(dialogs_predicted)} predicted results.')

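        # Per-dialog predictions are cached on disk (.npz files, or a
        # jsonlines/pickle file when predicting), so an interrupted evaluation
        # can be resumed and already-predicted dialogs are skipped (see
        # continue_evaluation above and the skip_batch logic below).
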
        if self.config['parallel'] and self.config['dp_type'] != 'dp':
            sampler_val = tud.distributed.DistributedSampler(
                dataset,
                num_replicas=self.config['num_gpus'],
                rank=self.gpu_rank
            )
            sampler_val.set_epoch(self.epoch_idx)
        else:
            sampler_val = None

        data_loader_val = tud.DataLoader(
            dataset=dataset,
            batch_size=self.config['eval_batch_size'],
            shuffle=False,
            collate_fn=dataset.collate_fn,
            num_workers=self.config['num_workers'],
            sampler=sampler_val
        )
        self.model.eval()

        with torch.no_grad():
            if self.config['display']:
                log.info(f'Evaluating {len(dataset)} samples')

            next_logging_pct = self.config["next_logging_pct"] + .1
            if self.config['parallel'] and self.config['dp_type'] == 'dp':
                num_batch_tot = int(
                    np.ceil(len(dataset) / self.config['eval_batch_size']))
            else:
                num_batch_tot = int(np.ceil(
                    len(dataset) / (self.config['eval_batch_size'] * self.config['num_gpus'])))
            num_batch = 0
            if dataset.split == 'val':
                num_options = self.config["num_options"]
                if self.config['skip_mrr_eval']:
                    num_rounds = 1
                else:
                    num_rounds = 10
            elif dataset.split == 'test':
                num_options = 100
                num_rounds = 1
            if self.gpu_rank == 0:
                start_time = time.time()

            for batch in data_loader_val:
                num_batch += 1
                # skip dialogs that have been predicted
                if self.config['predicting']:
                    image_ids = batch['image_id'].tolist()
                    skip_batch = True
                    for image_id in image_ids:
                        if image_id not in dialogs_predicted:
                            skip_batch = False
                    if skip_batch:
                        continue

                output = self.forward(batch, eval_visdial=eval_visdial)

                # visdial evaluation
                if eval_visdial:
                    img_ids = batch['image_id'].tolist()
                    batch_size = len(img_ids)
                    if not self.config['skip_ndcg_eval']:
                        gt_relevance_round_id = batch['round_id'].tolist()

                    # [batch_size * num_rounds * num_options, 2]
                    nsp_scores = output['nsp_scores']
                    nsp_probs = F.softmax(nsp_scores, dim=1)
                    assert nsp_probs.shape[-1] == 2
                    # num_dim=2, 0 for positive, 1 for negative
                    nsp_probs = nsp_probs[:, 0]
                    nsp_probs = nsp_probs.view(
                        batch_size, num_rounds, num_options)
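
                    # Each answer option gets the probability of the "positive"
                    # NSP class; ranking the num_options candidates by this
                    # probability yields the retrieval metrics (MRR/R@k) and,
                    # for the densely annotated round, the NDCG score below.
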
                    # could be predicting or evaluating
                    if dataset.split == 'val':
                        if self.config['skip_ndcg_eval']:
                            gt_option_inds = batch['gt_option_inds']

                            for b in range(batch_size):
                                filename = osp.join(
                                    self.config['visdial_output_dir'], dataset.split, f'{img_ids[b]}.npz')
                                if not osp.exists(filename):
                                    np.savez(
                                        filename,
                                        nsp_probs=nsp_probs[b].cpu().numpy(),
                                        gt_option_inds=gt_option_inds[b].cpu().numpy()
                                    )
                        else:
                            # [batch_size, num_rounds]
                            gt_option_inds = batch['gt_option_inds']
                            # [batch_size, num_options]
                            gt_relevance = batch['gt_relevance']

                            for b in range(batch_size):
                                filename = osp.join(
                                    self.config['visdial_output_dir'], dataset.split, f'{img_ids[b]}.npz')
                                if not osp.exists(filename):
                                    np.savez(
                                        filename,
                                        nsp_probs=nsp_probs[b].cpu().numpy(),
                                        gt_option_inds=gt_option_inds[b].cpu().numpy(),
                                        gt_relevance=gt_relevance[b].cpu().numpy(),
                                        gt_relevance_round_id=gt_relevance_round_id[b]
                                    )

                    # must be predicting
                    if dataset.split == 'test':
                        if self.config['save_score']:
                            for b in range(batch_size):
                                prediction = {
                                    "image_id": img_ids[b],
                                    "nsp_probs": nsp_probs[b].cpu().numpy(),
                                    "gt_relevance_round_id": gt_relevance_round_id[b]
                                }
                                pickle.dump(prediction, f_visdial)
                        else:
                            # [eval_batch_size, num_rounds, num_options]
                            ranks = scores_to_ranks(nsp_probs)
                            ranks = ranks.squeeze(1)
                            for b in range(batch_size):
                                prediction = {
                                    "image_id": img_ids[b],
                                    "round_id": gt_relevance_round_id[b],
                                    "ranks": ranks[b].tolist()
                                }
                                f_visdial.write(json.dumps(prediction) + '\n')

                # debug
                if self.config['debugging']:
                    break

                pct = num_batch / num_batch_tot * 100
                if pct >= next_logging_pct:
                    if self.config['display'] and self.gpu_rank == 0:
                        log.info(
                            f'{int(pct)}%,\ttime:\t{time.time() - start_time:.2f}'
                        )
                    next_logging_pct += self.config["next_logging_pct"]
                    # debug
                    if self.config['debugging']:
                        break

        if self.config['display'] and self.gpu_rank == 0:
            pct = num_batch / num_batch_tot * 100
            log.info(
                f'{int(pct)}%,\ttime:\t{time.time() - start_time:.2f}'
            )

        if not self.config['validating']:
            self.model.train()

        if self.config['parallel'] and self.config['dp_type'] != 'dp':
            dist.barrier()
            print(f'{self.gpu_rank} passed barrier')

        if self.config['predicting']:
            f_visdial.close()
            if not self.config['save_score']:
                all_visdial_predictions = [json.loads(
                    line) for line in open(visdial_file_name)]
                if self.config['predict_split'] == 'test' and len(all_visdial_predictions) == self.config['num_test_dialogs']:
                    visdial_file_name = visdial_file_name.replace(
                        'jsonlines', 'json')
                    with open(visdial_file_name, 'w') as f_visdial:
                        json.dump(all_visdial_predictions, f_visdial)
                    log.info(
                        f'Prediction for submission saved to {visdial_file_name}.')
            return None, None

        if self.config['display']:
            if dataset.split == 'val' and eval_visdial:
                if not self.config['skip_mrr_eval']:
                    sparse_metrics = SparseGTMetrics()
                if not self.config['skip_ndcg_eval']:
                    ndcg = NDCG()

            if dataset.split == 'val' and eval_visdial:
                visdial_output_filenames = glob.glob(
                    osp.join(self.config['visdial_output_dir'], dataset.split, '*.npz'))
                log.info(
                    f'Calculating visdial metrics for {len(visdial_output_filenames)} dialogs')
                for visdial_output_filename in visdial_output_filenames:
                    output = np.load(visdial_output_filename)
                    nsp_probs = torch.from_numpy(
                        output['nsp_probs']).unsqueeze(0)
                    if not self.config['skip_ndcg_eval']:
                        gt_relevance = torch.from_numpy(output['gt_relevance']).unsqueeze(0)
                    if not self.config['skip_mrr_eval']:
                        gt_option_inds = torch.from_numpy(
                            output['gt_option_inds']).unsqueeze(0)
                        sparse_metrics.observe(nsp_probs, gt_option_inds)
                        if not self.config['skip_ndcg_eval']:
                            gt_relevance_round_id = output['gt_relevance_round_id']
                            nsp_probs_dense = nsp_probs[0, gt_relevance_round_id - 1, :].unsqueeze(0)
                    else:
                        nsp_probs_dense = nsp_probs.squeeze(0)  # [1, 100]
                    if not self.config['skip_ndcg_eval']:
                        ndcg.observe(nsp_probs_dense, gt_relevance)

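            # For NDCG, only the round with dense relevance annotations is
            # scored: gt_relevance_round_id is 1-based, hence the `- 1` when
            # indexing the per-round probabilities above.
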
            # visdial eval output
            visdial_metrics = {}
            if dataset.split == 'val' and eval_visdial:
                if not self.config['skip_mrr_eval']:
                    visdial_metrics.update(sparse_metrics.retrieve(reset=True))
                if not self.config['skip_ndcg_eval']:
                    visdial_metrics.update(ndcg.retrieve(reset=True))

                if self.config['display']:
                    to_print = ''
                    for metric_name, metric_value in visdial_metrics.items():
                        if 'round' not in metric_name:
                            to_print += f"\n{metric_name}: {metric_value}"
                            if training_iter is not None:
                                if self.run:
                                    self.run.log(
                                        {'Val/' + metric_name: metric_value}, step=training_iter)
                    log.info(to_print)

            if self.config['metrics_to_maximize'] in visdial_metrics:
                metrics_to_maximize = visdial_metrics[self.config['metrics_to_maximize']]
            else:
                metrics_to_maximize = None
            torch.cuda.empty_cache()
            return metrics_to_maximize, visdial_metrics
        else:
            torch.cuda.empty_cache()
            return None, None

    def save_eval_results(self, split, epoch_idx, metrics_results):
        metrics_filename = osp.join(
            self.config['log_dir'], f'metrics_epoch_{epoch_idx}.json')
        with open(metrics_filename, 'w') as f:
            json.dump(metrics_results, f)
        log.info(f'Results of metrics saved to {metrics_filename}')

        if self.config["max_ckpt_to_keep"] > 0:
            if len(self.metrics_queue) == self.metrics_queue.maxlen:
                todel = self.metrics_queue.popleft()
                os.remove(todel)
            self.metrics_queue.append(metrics_filename)

        if epoch_idx == 'best':
            self.copy_best_predictions(split)

    def copy_best_results(self, split, epoch_idx):
        to_print = 'Copy '

        if not self.config['skip_saving_ckpt']:
            ckpt_path = osp.join(
                self.config['log_dir'], f'epoch_{epoch_idx}.ckpt')
            best_ckpt_path = ckpt_path.replace(
                f'{epoch_idx}.ckpt', 'best.ckpt')
            shutil.copyfile(ckpt_path, best_ckpt_path)
            to_print += best_ckpt_path + ' '

        metrics_filename = osp.join(
            self.config['log_dir'], f'metrics_epoch_{epoch_idx}.json')
        best_metric_filename = metrics_filename.replace(
            f'{epoch_idx}.json', 'best.json')
        shutil.copyfile(metrics_filename, best_metric_filename)
        to_print += best_metric_filename + ' '

        log.info(to_print)

    def copy_best_predictions(self, split):
        to_print = 'Copy '

        visdial_output_dir = osp.join(self.config['visdial_output_dir'], split)
        if osp.exists(visdial_output_dir):
            dir_best = visdial_output_dir.replace('output', 'output_best')
            if osp.exists(dir_best):
                shutil.rmtree(dir_best)
            shutil.copytree(visdial_output_dir, dir_best)
            to_print += dir_best + ' '

        log.info(to_print)

    def get_ckpt(self):
        ckpt = {
            'epoch_idx': self.epoch_idx,
            'max_metric': self.max_metric,
            'seed': self.config['random_seed'],
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict()
        }
        if self.config['parallel']:
            ckpt['model_state_dict'] = self.model.module.state_dict()
        else:
            ckpt['model_state_dict'] = self.model.state_dict()
        if self.config['dp_type'] == 'apex':
            ckpt['amp'] = amp.state_dict()
        return ckpt

    def set_ckpt(self, ckpt_dict):
        if not self.config['restarts']:
            self.epoch_idx = ckpt_dict.get('epoch_idx', -1) + 1

        if not self.config['resets_max_metric']:
            self.max_metric = ckpt_dict.get('max_metric', -1)

        if self.config['parallel']:
            model = self.model.module
        else:
            model = self.model

        model_state_dict = model.state_dict()
        former_dict = {
            k: v for k, v in ckpt_dict['model_state_dict'].items() if k in model_state_dict}

        if self.config['display']:
            log.info("number of keys transferred: %d" % len(former_dict))
        assert len(former_dict.keys()) > 0

        model_state_dict.update(former_dict)
        model.load_state_dict(model_state_dict)
        if self.config['display']:
            log.info('loaded model')
        del model_state_dict, former_dict

        if not self.config['validating'] and not (self.config['uses_new_optimizer'] or self.config['sets_new_lr']):
            if 'optimizer' in ckpt_dict:
                self.optimizer.load_state_dict(ckpt_dict['optimizer'])
                if self.config['display']:
                    log.info('loaded optimizer')
            if 'scheduler' in ckpt_dict:
                self.scheduler.last_epoch = ckpt_dict['epoch_idx'] * \
                    self.config['num_iter_per_epoch']
                self.scheduler.load_state_dict(ckpt_dict['scheduler'])

        if 'amp' in ckpt_dict and self.config['dp_type'] == 'apex':
            amp.load_state_dict(ckpt_dict['amp'])

        del ckpt_dict
        torch.cuda.empty_cache()

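    # set_ckpt() intentionally performs a partial load: only parameters whose
    # names exist in both the checkpoint and the current model are copied, and
    # everything else keeps its freshly initialized values.
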
    def save_ckpt(self):
        ckpt_path = f'{self.config["log_dir"]}/epoch_{self.epoch_idx}.ckpt'
        log.info(f'saving checkpoint {ckpt_path}')
        ckpt = self.get_ckpt()
        if self.config['skip_saving_ckpt']:
            return ckpt_path
        torch_version = float(torch.__version__[:3])
        if torch_version - 1.4 > 1e-3:
            torch.save(ckpt, f=ckpt_path, _use_new_zipfile_serialization=False)
        else:
            torch.save(ckpt, f=ckpt_path)
        del ckpt

        if not (self.config['parallel'] and self.config['dp_type'] in ['ddp', 'apex']):
            torch.cuda.empty_cache()

        if self.config["max_ckpt_to_keep"] > 0:
            if len(self.checkpoint_queue) == self.checkpoint_queue.maxlen:
                todel = self.checkpoint_queue.popleft()
                os.remove(todel)
            self.checkpoint_queue.append(ckpt_path)

        return ckpt_path

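    # Note on the version check in save_ckpt(): PyTorch >= 1.6 saves checkpoints
    # in a zipfile-based format by default; _use_new_zipfile_serialization=False
    # keeps the old format so the files remain loadable with older torch
    # versions. float(torch.__version__[:3]) is a rough heuristic (it reads
    # '1.10.0' as 1.1), so adjust it if you rely on it for other versions.
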
    def save_ckpt_best(self):
        ckpt_path = f'{self.config["log_dir"]}/epoch_best.ckpt'
        log.info(f'saving checkpoint {ckpt_path}')
        ckpt = self.get_ckpt()
        torch.save(ckpt, f=ckpt_path)
        del ckpt
        return ckpt_path

    def load_ckpt_best(self):
        ckpt_path = f'{osp.dirname(self.config["log_dir"])}/epoch_best.ckpt'
        if not osp.exists(ckpt_path):
            ckpt_paths = [path for path in os.listdir(
                f'{self.config["log_dir"]}/') if path.endswith('.ckpt') and 'best' not in path]
            if len(ckpt_paths) == 0:
                if self.config['display']:
                    log.info(f'No .ckpt found in {self.config["log_dir"]}')
                return

            def sort_func(x): return int(re.search(r"(\d+)", x).groups()[0])
            ckpt_path = f'{self.config["log_dir"]}/{sorted(ckpt_paths, key=sort_func, reverse=True)[0]}'
        if self.config['display']:
            log.info(f'loading checkpoint {ckpt_path}')
        map_location = {'cuda:0': f'cuda:{self.gpu_rank}'}
        self.set_ckpt(torch.load(ckpt_path, map_location=map_location))

    def load_ckpt(self, ckpt_path=None):
        if not ckpt_path:
            if self.config['validating'] or self.config['loads_best_ckpt']:
                ckpt_path = f'{self.config["log_dir"]}/epoch_best.ckpt'
            else:
                ckpt_paths = [path for path in os.listdir(
                    f'{self.config["log_dir"]}/') if path.endswith('.ckpt') and 'best' not in path]
                if len(ckpt_paths) == 0:
                    if self.config['display']:
                        log.info(f'No .ckpt found in {self.config["log_dir"]}')
                    return

                def sort_func(x): return int(
                    re.search(r"(\d+)", x).groups()[0])
                ckpt_path = f'{self.config["log_dir"]}/{sorted(ckpt_paths, key=sort_func, reverse=True)[0]}'

        if self.config['display']:
            log.info(f'loading checkpoint {ckpt_path}')
        epoch_name = osp.split(ckpt_path)[1].split('.')[0]
        if re.search(r"(\d+)", epoch_name):
            self.checkpoint_queue.append(ckpt_path)
            metrics_filename = osp.join(
                self.config['log_dir'], f'metrics_{epoch_name}.json')
            if osp.exists(metrics_filename):
                self.metrics_queue.append(metrics_filename)

        map_location = {'cuda:0': f'cuda:{self.gpu_rank}'}
        self.set_ckpt(torch.load(ckpt_path, map_location=map_location))

    def match_model_key(self, pretrained_dict, model_dict):
        matched_dict = dict()
        for key in pretrained_dict:
            if key in model_dict:
                matched_key = key
            elif key.startswith('encoder.') and key[8:] in model_dict:
                matched_key = key[8:]
            elif key.startswith('module.') and key[7:] in model_dict:
                matched_key = key[7:]
            elif 'encoder.' + key in model_dict:
                matched_key = 'encoder.' + key
            elif 'module.' + key in model_dict:
                matched_key = 'module.' + key
            else:
                # not_found.append(key)
                continue
            matched_dict[matched_key] = pretrained_dict[key]

        not_found = ""
        for k in model_dict:
            if k not in matched_dict:
                not_found += k + '\n'

        log.info("Keys from model_dict that were not found in pretrained_dict:")
        log.info(not_found)
        return matched_dict

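    # match_model_key() reconciles naming differences between a pretrained
    # ViLBERT checkpoint and the current model: a 'module.' prefix (added by
    # DataParallel/DistributedDataParallel wrapping) or an 'encoder.' prefix is
    # stripped or added as needed so that, e.g., a hypothetical checkpoint key
    # 'module.bert.embeddings.word_embeddings.weight' maps onto
    # 'bert.embeddings.word_embeddings.weight'.
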
    def load_pretrained_vilbert(self, start_from=None):
        if start_from is not None:
            self.config["start_path"] = start_from
        if self.config['training'] or self.config['debugging']:
            ckpt_paths = [path for path in os.listdir(
                f'{self.config["log_dir"]}/') if path.endswith('.ckpt') and 'best' not in path]
            if len(ckpt_paths) > 0:
                if self.config['display']:
                    log.info('Continue training')
                return

        if self.config['display']:
            log.info(
                f'Loading pretrained VilBERT from {self.config["start_path"]}')
        map_location = {'cuda:0': f'cuda:{self.gpu_rank}'}
        pretrained_dict = torch.load(
            self.config['start_path'], map_location=map_location)
        if 'model_state_dict' in pretrained_dict:
            pretrained_dict = pretrained_dict['model_state_dict']
        if self.config['parallel']:
            model = self.model.module
        else:
            model = self.model
        model_dict = model.state_dict()

        matched_dict = self.match_model_key(pretrained_dict, model_dict)

        if self.config['display']:
            log.info("number of keys transferred: %d" % len(matched_dict))
        assert len(matched_dict.keys()) > 0
        model_dict.update(matched_dict)
        model.load_state_dict(model_dict)

        del pretrained_dict, model_dict, matched_dict
        if not self.config['parallel'] or self.config['dp_type'] == 'dp':
            torch.cuda.empty_cache()

        if self.config['display']:
            log.info('Pretrained VilBERT loaded')