Code release

Adnen Abdessaied 2023-10-25 15:38:09 +02:00
commit 09fb25e339
29 changed files with 7162 additions and 0 deletions

models/__init__.py (new file, 0 lines)

models/runner.py (new file, 830 lines)
@@ -0,0 +1,830 @@
import os
import os.path as osp
import json
from collections import deque
import time
import re
import shutil
import glob
import pickle
import gc
import numpy as np
import glog as log
try:
from apex import amp
except ModuleNotFoundError:
print('apex not found')
import torch
import torch.utils.data as tud
import torch.nn.functional as F
import torch.distributed as dist
from utils.data_utils import load_pickle_lines
from utils.visdial_metrics import SparseGTMetrics, NDCG, scores_to_ranks
import wandb
class Runner:
def __init__(self, config):
self.config = config
if 'rank' in config:
self.gpu_rank = config['rank']
else:
self.gpu_rank = 0
self.epoch_idx = 0
self.max_metric = 0.
self.max_metric_epoch_idx = 0
self.na_str = 'N/A'
if self.config["max_ckpt_to_keep"] > 0:
self.checkpoint_queue = deque(
[], maxlen=config["max_ckpt_to_keep"])
self.metrics_queue = deque([], maxlen=config["max_ckpt_to_keep"])
self.setup_wandb()
def setup_wandb(self):
if self.gpu_rank == 0:
print("[INFO] Set wandb logging on rank {}".format(0))
run = wandb.init(
project=self.config['wandb_project'], config=self.config, mode=self.config['wandb_mode'])
else:
run = None
self.run = run
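# Overridden by the concrete runners (SparseRunner / DenseRunner in models/vdgr.py):
# run a single batch through the model and return a dict with 'losses' and 'nsp_scores'.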
def forward(self, batch, eval_visdial=False):
raise NotImplementedError
def train(self, dataset, dataset_eval=None):
# wandb.login()
if os.path.exists(self.config['log_dir']) or self.config['loads_ckpt'] or self.config['loads_best_ckpt']:
self.load_ckpt()
if self.config['use_trainval']:
dataset.split = 'trainval'
else:
dataset.split = 'train'
batch_size = self.config['batch_size']
if self.config['parallel'] and self.config['dp_type'] != 'dp':
sampler_tr = tud.distributed.DistributedSampler(
dataset,
num_replicas=self.config['num_gpus'],
rank=self.gpu_rank
)
else:
sampler_tr = None
data_loader_tr = tud.DataLoader(
dataset=dataset,
batch_size=batch_size,
shuffle=self.config['training'] and not self.config['parallel'],
collate_fn=dataset.collate_fn,
num_workers=self.config['num_workers'],
sampler=sampler_tr
)
start_epoch_idx = self.epoch_idx
num_iter_epoch = self.config['num_iter_per_epoch']
if self.config['display']:
log.info(f'{num_iter_epoch} iter per epoch.')
# eval before training
eval_dense_at_first = self.config['train_on_dense'] and self.config['skip_mrr_eval'] and start_epoch_idx == 0
# eval before training under two circumstances:
# - dense fine-tuning: evaluate NDCG before the first epoch
# - MRR training: resuming a run whose last epoch has not been evaluated yet
if (eval_dense_at_first or (self.config['eval_at_start'] and len(self.metrics_queue) == 0 and start_epoch_idx > 0)):
if eval_dense_at_first:
iter_now = 0
else:
iter_now = max(num_iter_epoch * start_epoch_idx, 0)
if dataset_eval is None:
dataset.split = 'val'
dataset_to_eval = dataset
else:
dataset_to_eval = dataset_eval
metrics_results = {}
metrics_to_maximize, metrics_results['val'] = self.evaluate(
dataset_to_eval, iter_now)
if eval_dense_at_first:
self.max_metric = metrics_to_maximize
self.max_metric_epoch_idx = -1
else:
if self.config['display']:
self.save_eval_results(
'val', start_epoch_idx - 1, metrics_results)
if metrics_to_maximize > self.max_metric:
self.max_metric = metrics_to_maximize
self.max_metric_epoch_idx = start_epoch_idx - 1
self.copy_best_results('val', start_epoch_idx - 1)
self.copy_best_predictions('val')
if dataset_eval is None:
if self.config['use_trainval']:
dataset.split = 'trainval'
else:
dataset.split = 'train'
num_epochs = self.config['num_epochs']
for epoch_idx in range(start_epoch_idx, num_epochs):
if self.config['parallel'] and self.config['dp_type'] != 'dp':
sampler_tr.set_epoch(epoch_idx)
self.epoch_idx = epoch_idx
if self.config['display']:
log.info(f'starting epoch {epoch_idx}')
log.info('training')
self.model.train()
num_batch = 0
next_logging_pct = .1
next_evaluating_pct = self.config["next_evaluating_pct"] + .1
start_time = time.time()
self.optimizer.zero_grad()
for batch in data_loader_tr:
if self.config['eval_before_training']:
log.info('Skipping straight to evaluation...')
break
num_batch += 1
pct = num_batch / num_iter_epoch * 100
iter_now = num_iter_epoch * epoch_idx + num_batch
output = self.forward(batch)
losses = output['losses']
# optimizer step
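# gradient accumulation: the loss is scaled by 1/batch_multiply and the
# optimizer only steps every batch_multiply iterations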
losses['tot_loss'] /= self.config['batch_multiply']
# debug
if self.config['debugging']:
log.info('try backward')
if self.config['dp_type'] == 'apex':
with amp.scale_loss(losses['tot_loss'], self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
losses['tot_loss'].backward()
if self.config['debugging']:
log.info('backward done')
if iter_now % self.config['batch_multiply'] == 0:
self.optimizer.step()
self.optimizer.zero_grad()
self.scheduler.step()
# display and eval
if pct >= next_logging_pct:
if self.config['display']:
loss_to_print = ''
for key in losses:
if losses[key] is not None and isinstance(losses[key], torch.Tensor):
loss_to_print += f'[{key}: {losses[key].item():.4f}]'
print(
f'[{int(pct)}%][Epoch: {epoch_idx + 1}/{num_epochs}][Iter : {num_batch}/{len(data_loader_tr)}] [time: {time.time() - start_time:.2f}] {loss_to_print}'
)
next_logging_pct += self.config["next_logging_pct"]
if self.config['debugging']:
break
if pct >= next_evaluating_pct:
next_evaluating_pct += self.config["next_evaluating_pct"]
if self.run:
if self.config['train_on_dense']:
self.run.log(
{
"Train/dense_loss": losses['dense_loss'],
"Train/total_loss": losses['tot_loss'],
},
step=iter_now
)
else:
self.run.log(
{
"Train/lm_loss": losses['lm_loss'],
"Train/img_loss": losses['img_loss'],
"Train/nsp_loss": losses['nsp_loss'],
"Train/total_loss": losses['tot_loss'],
},
step=iter_now
)
lr_gnn, lr_bert = self.scheduler.get_lr()[0], self.scheduler.get_lr()[1]
self.run.log(
{
"Train/lr_gnn": lr_gnn,
"Train/lr_bert": lr_bert,
},
step=iter_now
)
del losses
# debug
torch.cuda.empty_cache()
if self.config['display']:
log.info(
f'100%,\ttime:\t{time.time() - start_time:.2f}'
)
ckpt_path = self.save_ckpt()
if not self.config['skip_visdial_eval'] and self.epoch_idx % self.config['eval_visdial_every'] == 0:
iter_now = num_iter_epoch * (epoch_idx + 1)
if dataset_eval is None:
dataset.split = 'val'
dataset_to_eval = dataset
else:
dataset_to_eval = dataset_eval
metrics_results = {}
metrics_to_maximize, metrics_results['val'] = self.evaluate(
dataset_to_eval, iter_now)
if dataset_eval is None:
if self.config['use_trainval']:
dataset.split = 'trainval'
else:
dataset.split = 'train'
if self.config['display']:
self.save_eval_results('val', epoch_idx, metrics_results)
if self.config['display']:
if metrics_to_maximize > self.max_metric:
self.max_metric = metrics_to_maximize
self.max_metric_epoch_idx = epoch_idx
self.copy_best_results('val', epoch_idx)
self.copy_best_predictions('val')
elif not self.config['parallel'] and epoch_idx - self.max_metric_epoch_idx > self.config["early_stop_epoch"]:
log.info('Early stop.')
break
if self.run:
self.run.log(
{"Val/metric_best": self.max_metric}, step=iter_now)
if self.config['parallel']:
if self.config['dp_type'] == 'dp':
gc.collect()
torch.cuda.empty_cache()
else:
dist.barrier()
log.info('Rank {} passed barrier...'.format(self.gpu_rank))
if self.config['stop_epochs'] >= 0 and epoch_idx + 1 >= self.config['stop_epochs']:
if self.config['display']:
log.info('Stop for reaching stop_epochs.')
break
def evaluate(self, dataset, training_iter=None, eval_visdial=True):
# create files to save output
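# two modes: in 'predicting' mode, ranked answers (or raw scores) are appended to a
# prediction file; otherwise per-dialog scores are cached as .npz files under
# visdial_output_dir and aggregated into metrics below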
if self.config['predicting']:
visdial_file_name = None
if self.config['save_score']:
visdial_file_name = osp.join(
self.config['log_dir'], f'visdial_prediction.pkl')
if osp.exists(visdial_file_name):
dialogs_predicted = load_pickle_lines(
visdial_file_name)
dialogs_predicted = [d['image_id']
for d in dialogs_predicted]
else:
dialogs_predicted = []
f_visdial = open(visdial_file_name, 'ab')
else:
visdial_file_name = osp.join(
self.config['log_dir'], f'visdial_prediction.jsonlines')
if self.config['parallel'] and self.config['dp_type'] != 'dp':
visdial_file_name = visdial_file_name.replace(
'.jsonlines', f'_{self.config["rank"]}of{self.config["num_gpus"]}.jsonlines')
if osp.exists(visdial_file_name):
dialogs_predicted_visdial = [json.loads(
line)['image_id'] for line in open(visdial_file_name)]
f_visdial = open(visdial_file_name, 'a')
else:
dialogs_predicted_visdial = []
f_visdial = open(visdial_file_name, 'w')
dialogs_predicted = dialogs_predicted_visdial
if len(dialogs_predicted) > 0:
log.info(f'Found {len(dialogs_predicted)} predicted results.')
if self.config['display']:
if visdial_file_name is not None:
log.info(
f'VisDial predictions saved to {visdial_file_name}')
elif self.config['display']:
if self.config['continue_evaluation']:
predicted_files = os.listdir(
osp.join(self.config['visdial_output_dir'], dataset.split))
dialogs_predicted = [
int(re.match(r'(\d+).npz', p).group(1)) for p in predicted_files]
else:
if osp.exists(osp.join(self.config['visdial_output_dir'], dataset.split)):
shutil.rmtree(
osp.join(self.config['visdial_output_dir'], dataset.split))
os.makedirs(
osp.join(self.config['visdial_output_dir'], dataset.split))
dialogs_predicted = []
log.info(f'Found {len(dialogs_predicted)} predicted results.')
if self.config['parallel'] and self.config['dp_type'] != 'dp':
sampler_val = tud.distributed.DistributedSampler(
dataset,
num_replicas=self.config['num_gpus'],
rank=self.gpu_rank
)
sampler_val.set_epoch(self.epoch_idx)
else:
sampler_val = None
data_loader_val = tud.DataLoader(
dataset=dataset,
batch_size=self.config['eval_batch_size'],
shuffle=False,
collate_fn=dataset.collate_fn,
num_workers=self.config['num_workers'],
sampler=sampler_val
)
self.model.eval()
with torch.no_grad():
if self.config['display']:
log.info(f'Evaluating {len(dataset)} samples')
next_logging_pct = self.config["next_logging_pct"] + .1
if self.config['parallel'] and self.config['dp_type'] == 'dp':
num_batch_tot = int(
np.ceil(len(dataset) / self.config['eval_batch_size']))
else:
num_batch_tot = int(np.ceil(
len(dataset) / (self.config['eval_batch_size'] * self.config['num_gpus'])))
num_batch = 0
if dataset.split == 'val':
num_options = self.config["num_options"]
if self.config['skip_mrr_eval']:
num_rounds = 1
else:
num_rounds = 10
elif dataset.split == 'test':
num_options = 100
num_rounds = 1
if self.gpu_rank == 0:
start_time = time.time()
for batch in data_loader_val:
num_batch += 1
# skip dialogs that have been predicted
if self.config['predicting']:
image_ids = batch['image_id'].tolist()
skip_batch = True
for image_id in image_ids:
if image_id not in dialogs_predicted:
skip_batch = False
if skip_batch:
continue
output = self.forward(
batch, eval_visdial=eval_visdial)
# visdial evaluation
if eval_visdial:
img_ids = batch['image_id'].tolist()
batch_size = len(img_ids)
if not self.config['skip_ndcg_eval']:
gt_relevance_round_id = batch['round_id'].tolist()
# [batch_size * num_rounds * num_options, 2]
nsp_scores = output['nsp_scores']
nsp_probs = F.softmax(nsp_scores, dim=1)
assert nsp_probs.shape[-1] == 2
# num_dim=2, 0 for positive, 1 for negative
nsp_probs = nsp_probs[:, 0]
nsp_probs = nsp_probs.view(
batch_size, num_rounds, num_options)
# could be predicting or evaluating
if dataset.split == 'val':
if self.config['skip_ndcg_eval']:
gt_option_inds = batch['gt_option_inds']
for b in range(batch_size):
filename = osp.join(
self.config['visdial_output_dir'], dataset.split, f'{img_ids[b]}.npz')
if not osp.exists(filename):
np.savez(
filename,
nsp_probs=nsp_probs[b].cpu().numpy(),
gt_option_inds=gt_option_inds[b].cpu().numpy()
)
else:
# [batch_size, num_rounds]
gt_option_inds = batch['gt_option_inds']
# [batch_size, num_options]
gt_relevance = batch['gt_relevance']
for b in range(batch_size):
filename = osp.join(
self.config['visdial_output_dir'], dataset.split, f'{img_ids[b]}.npz')
if not osp.exists(filename):
np.savez(filename,
nsp_probs=nsp_probs[b].cpu().numpy(),
gt_option_inds=gt_option_inds[b].cpu(
).numpy(),
gt_relevance=gt_relevance[b].cpu(
).numpy(),
gt_relevance_round_id=gt_relevance_round_id[b])
# must be predicting
if dataset.split == 'test':
if self.config['save_score']:
for b in range(batch_size):
prediction = {
"image_id": img_ids[b],
"nsp_probs": nsp_probs[b].cpu().numpy(),
"gt_relevance_round_id": gt_relevance_round_id[b]
}
pickle.dump(prediction, f_visdial)
else:
# [eval_batch_size, num_rounds, num_options]
ranks = scores_to_ranks(nsp_probs)
ranks = ranks.squeeze(1)
for b in range(batch_size):
prediction = {
"image_id": img_ids[b],
"round_id": gt_relevance_round_id[b],
"ranks": ranks[b].tolist()
}
f_visdial.write(json.dumps(prediction) + '\n')
# debug
if self.config['debugging']:
break
pct = num_batch / num_batch_tot * 100
if pct >= next_logging_pct:
if self.config['display'] and self.gpu_rank == 0:
log.info(
f'{int(pct)}%,\ttime:\t{time.time() - start_time:.2f}'
)
next_logging_pct += self.config["next_logging_pct"]
# debug
if self.config['debugging']:
break
if self.config['display'] and self.gpu_rank == 0:
pct = num_batch / num_batch_tot * 100
log.info(
f'{int(pct)}%,\ttime:\t{time.time() - start_time:.2f}'
)
if not self.config['validating']:
self.model.train()
if self.config['parallel'] and self.config['dp_type'] != 'dp':
dist.barrier()
print(f'{self.gpu_rank} passed barrier')
if self.config['predicting']:
f_visdial.close()
if not self.config['save_score']:
all_visdial_predictions = [json.loads(
line) for line in open(visdial_file_name)]
if self.config['predict_split'] == 'test' and len(all_visdial_predictions) == self.config['num_test_dialogs']:
visdial_file_name = visdial_file_name.replace(
'jsonlines', 'json')
with open(visdial_file_name, 'w') as f_visdial:
json.dump(all_visdial_predictions, f_visdial)
log.info(
f'Prediction for submission saved to {visdial_file_name}.')
return None, None
if self.config['display']:
if dataset.split == 'val' and eval_visdial:
if not self.config['skip_mrr_eval']:
sparse_metrics = SparseGTMetrics()
if not self.config['skip_ndcg_eval']:
ndcg = NDCG()
if dataset.split == 'val' and eval_visdial:
visdial_output_filenames = glob.glob(
osp.join(self.config['visdial_output_dir'], dataset.split, '*.npz'))
log.info(
f'Calculating visdial metrics for {len(visdial_output_filenames)} dialogs')
for visdial_output_filename in visdial_output_filenames:
output = np.load(visdial_output_filename)
nsp_probs = torch.from_numpy(
output['nsp_probs']).unsqueeze(0)
if not self.config['skip_ndcg_eval']:
gt_relevance = torch.from_numpy(output['gt_relevance']).unsqueeze(0)
if not self.config['skip_mrr_eval']:
gt_option_inds = torch.from_numpy(
output['gt_option_inds']).unsqueeze(0)
sparse_metrics.observe(nsp_probs, gt_option_inds)
if not self.config['skip_ndcg_eval']:
gt_relevance_round_id = output['gt_relevance_round_id']
nsp_probs_dense = nsp_probs[0, gt_relevance_round_id - 1, :].unsqueeze(0)
else:
nsp_probs_dense = nsp_probs.squeeze(0) # [1, 100]
if not self.config['skip_ndcg_eval']:
ndcg.observe(nsp_probs_dense, gt_relevance)
# visdial eval output
visdial_metrics = {}
if dataset.split == 'val' and eval_visdial:
if not self.config['skip_mrr_eval']:
visdial_metrics.update(sparse_metrics.retrieve(reset=True))
if not self.config['skip_ndcg_eval']:
visdial_metrics.update(ndcg.retrieve(reset=True))
if self.config['display']:
to_print = ''
for metric_name, metric_value in visdial_metrics.items():
if 'round' not in metric_name:
to_print += f"\n{metric_name}: {metric_value}"
if training_iter is not None:
if self.run:
self.run.log(
{'Val/' + metric_name: metric_value}, step=training_iter)
log.info(to_print)
if self.config['metrics_to_maximize'] in visdial_metrics:
metrics_to_maximize = visdial_metrics[self.config['metrics_to_maximize']]
else:
metrics_to_maximize = None
torch.cuda.empty_cache()
return metrics_to_maximize, visdial_metrics
else:
torch.cuda.empty_cache()
return None, None
def save_eval_results(self, split, epoch_idx, metrics_results):
metrics_filename = osp.join(
self.config['log_dir'], f'metrics_epoch_{epoch_idx}.json')
with open(metrics_filename, 'w') as f:
json.dump(metrics_results, f)
log.info(f'Results of metrics saved to {metrics_filename}')
if self.config["max_ckpt_to_keep"] > 0:
if len(self.metrics_queue) == self.metrics_queue.maxlen:
todel = self.metrics_queue.popleft()
os.remove(todel)
self.metrics_queue.append(metrics_filename)
if epoch_idx == 'best':
self.copy_best_predictions(split)
def copy_best_results(self, split, epoch_idx):
to_print = 'Copy '
if not self.config['skip_saving_ckpt']:
ckpt_path = osp.join(
self.config['log_dir'], f'epoch_{epoch_idx}.ckpt')
best_ckpt_path = ckpt_path.replace(
f'{epoch_idx}.ckpt', 'best.ckpt')
shutil.copyfile(ckpt_path, best_ckpt_path)
to_print += best_ckpt_path + ' '
metrics_filename = osp.join(
self.config['log_dir'], f'metrics_epoch_{epoch_idx}.json')
best_metric_filename = metrics_filename.replace(
f'{epoch_idx}.json', 'best.json')
shutil.copyfile(metrics_filename, best_metric_filename)
to_print += best_metric_filename + ' '
log.info(to_print)
def copy_best_predictions(self, split):
to_print = 'Copy '
visdial_output_dir = osp.join(self.config['visdial_output_dir'], split)
if osp.exists(visdial_output_dir):
dir_best = visdial_output_dir.replace('output', 'output_best')
if osp.exists(dir_best):
shutil.rmtree(dir_best)
shutil.copytree(visdial_output_dir, dir_best)
to_print += dir_best + ' '
log.info(to_print)
def get_ckpt(self):
ckpt = {
'epoch_idx': self.epoch_idx,
'max_metric': self.max_metric,
'seed': self.config['random_seed'],
'optimizer': self.optimizer.state_dict(),
'scheduler': self.scheduler.state_dict()
}
if self.config['parallel']:
ckpt['model_state_dict'] = self.model.module.state_dict()
else:
ckpt['model_state_dict'] = self.model.state_dict()
if self.config['dp_type'] == 'apex':
ckpt['amp'] = amp.state_dict()
return ckpt
def set_ckpt(self, ckpt_dict):
if not self.config['restarts']:
self.epoch_idx = ckpt_dict.get('epoch_idx', -1) + 1
if not self.config['resets_max_metric']:
self.max_metric = ckpt_dict.get('max_metric', -1)
if self.config['parallel']:
model = self.model.module
else:
model = self.model
model_state_dict = model.state_dict()
former_dict = {
k: v for k, v in ckpt_dict['model_state_dict'].items() if k in model_state_dict}
if self.config['display']:
log.info("number of keys transferred: %d" % len(former_dict))
assert len(former_dict.keys()) > 0
model_state_dict.update(former_dict)
model.load_state_dict(model_state_dict)
if self.config['display']:
log.info('loaded model')
del model_state_dict, former_dict
if not self.config['validating'] and not (self.config['uses_new_optimizer'] or self.config['sets_new_lr']):
if 'optimizer' in ckpt_dict:
self.optimizer.load_state_dict(ckpt_dict['optimizer'])
if self.config['display']:
log.info('loaded optimizer')
if 'scheduler' in ckpt_dict:
self.scheduler.last_epoch = ckpt_dict['epoch_idx'] * \
self.config['num_iter_per_epoch']
self.scheduler.load_state_dict(ckpt_dict['scheduler'])
if 'amp' in ckpt_dict and self.config['dp_type'] == 'apex':
amp.load_state_dict(ckpt_dict['amp'])
del ckpt_dict
torch.cuda.empty_cache()
def save_ckpt(self):
ckpt_path = f'{self.config["log_dir"]}/epoch_{self.epoch_idx}.ckpt'
log.info(f'saving checkpoint {ckpt_path}')
ckpt = self.get_ckpt()
if self.config['skip_saving_ckpt']:
return ckpt_path
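# PyTorch >= 1.6 saves checkpoints in a zipfile-based format by default; disable it
# below so the files remain loadable with older torch versions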
torch_version = float(torch.__version__[:3])
if torch_version - 1.4 > 1e-3:
torch.save(ckpt, f=ckpt_path, _use_new_zipfile_serialization=False)
else:
torch.save(ckpt, f=ckpt_path)
del ckpt
if not (self.config['parallel'] and self.config['dp_type'] in ['ddp', 'apex']):
torch.cuda.empty_cache()
if self.config["max_ckpt_to_keep"] > 0:
if len(self.checkpoint_queue) == self.checkpoint_queue.maxlen:
todel = self.checkpoint_queue.popleft()
os.remove(todel)
self.checkpoint_queue.append(ckpt_path)
return ckpt_path
def save_ckpt_best(self):
ckpt_path = f'{self.config["log_dir"]}/epoch_best.ckpt'
log.info(f'saving checkpoint {ckpt_path}')
ckpt = self.get_ckpt()
torch.save(ckpt, f=ckpt_path)
del ckpt
return ckpt_path
def load_ckpt_best(self):
ckpt_path = f'{osp.dirname(self.config["log_dir"])}/epoch_best.ckpt'
if not osp.exists(ckpt_path):
ckpt_paths = [path for path in os.listdir(
f'{self.config["log_dir"]}/') if path.endswith('.ckpt') and 'best' not in path]
if len(ckpt_paths) == 0:
if self.config['display']:
log.info(f'No .ckpt found in {self.config["log_dir"]}')
return
def sort_func(x): return int(re.search(r"(\d+)", x).groups()[0])
ckpt_path = f'{self.config["log_dir"]}/{sorted(ckpt_paths, key=sort_func, reverse=True)[0]}'
if self.config['display']:
log.info(f'loading checkpoint {ckpt_path}')
map_location = {'cuda:0': f'cuda:{self.gpu_rank}'}
self.set_ckpt(torch.load(ckpt_path, map_location=map_location))
def load_ckpt(self, ckpt_path=None):
if not ckpt_path:
if self.config['validating'] or self.config['loads_best_ckpt']:
ckpt_path = f'{self.config["log_dir"]}/epoch_best.ckpt'
else:
ckpt_paths = [path for path in os.listdir(
f'{self.config["log_dir"]}/') if path.endswith('.ckpt') and 'best' not in path]
if len(ckpt_paths) == 0:
if self.config['display']:
log.info(f'No .ckpt found in {self.config["log_dir"]}')
return
def sort_func(x): return int(
re.search(r"(\d+)", x).groups()[0])
ckpt_path = f'{self.config["log_dir"]}/{sorted(ckpt_paths, key=sort_func, reverse=True)[0]}'
if self.config['display']:
log.info(f'loading checkpoint {ckpt_path}')
epoch_name = osp.split(ckpt_path)[1].split('.')[0]
if re.search(r"(\d+)", epoch_name):
self.checkpoint_queue.append(ckpt_path)
metrics_filename = osp.join(
self.config['log_dir'], f'metrics_{epoch_name}.json')
if osp.exists(metrics_filename):
self.metrics_queue.append(metrics_filename)
map_location = {'cuda:0': f'cuda:{self.gpu_rank}'}
self.set_ckpt(torch.load(ckpt_path, map_location=map_location))
def match_model_key(self, pretrained_dict, model_dict):
matched_dict = dict()
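# align checkpoint keys with model keys by adding or stripping the 'encoder.' and
# 'module.' prefixes (the latter is introduced by (Distributed)DataParallel wrapping)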
for key in pretrained_dict:
if key in model_dict:
matched_key = key
elif key.startswith('encoder.') and key[8:] in model_dict:
matched_key = key[8:]
elif key.startswith('module.') and key[7:] in model_dict:
matched_key = key[7:]
elif 'encoder.' + key in model_dict:
matched_key = 'encoder.' + key
elif 'module.' + key in model_dict:
matched_key = 'module.' + key
else:
# not_found.append(key)
continue
matched_dict[matched_key] = pretrained_dict[key]
not_found = ""
for k in model_dict:
if k not in matched_dict:
not_found += k + '\n'
log.info("Keys from model_dict that were not found in pretrained_dict:")
log.info(not_found)
return matched_dict
def load_pretrained_vilbert(self, start_from=None):
if start_from is not None:
self.config["start_path"] = start_from
if self.config['training'] or self.config['debugging']:
ckpt_paths = [path for path in os.listdir(
f'{self.config["log_dir"]}/') if path.endswith('.ckpt') and 'best' not in path]
if len(ckpt_paths) > 0:
if self.config['display']:
log.info('Continue training')
return
if self.config['display']:
log.info(
f'Loading pretrained VilBERT from {self.config["start_path"]}')
map_location = {'cuda:0': f'cuda:{self.gpu_rank}'}
pretrained_dict = torch.load(
self.config['start_path'], map_location=map_location)
if 'model_state_dict' in pretrained_dict:
pretrained_dict = pretrained_dict['model_state_dict']
if self.config['parallel']:
model = self.model.module
else:
model = self.model
model_dict = model.state_dict()
matched_dict = self.match_model_key(pretrained_dict, model_dict)
if self.config['display']:
log.info("number of keys transferred: %d" % len(matched_dict))
assert len(matched_dict.keys()) > 0
model_dict.update(matched_dict)
model.load_state_dict(model_dict)
del pretrained_dict, model_dict, matched_dict
if not self.config['parallel'] or self.config['dp_type'] == 'dp':
torch.cuda.empty_cache()
if self.config['display']:
log.info(f'Pretrained VilBERT loaded')

models/vdgr.py (new file, 379 lines)
@@ -0,0 +1,379 @@
import sys
from collections import OrderedDict
import torch
from torch import nn
import torch.nn.functional as F
sys.path.append('../')
from utils.model_utils import listMLE, approxNDCGLoss, listNet, neuralNDCG, neuralNDCG_transposed
from utils.data_utils import sequence_mask
from utils.optim_utils import init_optim
from models.runner import Runner
from models.vilbert_dialog import BertForMultiModalPreTraining, BertConfig
class VDGR(nn.Module):
def __init__(self, config_path, device, use_apex=False, cache_dir=None):
super(VDGR, self).__init__()
config = BertConfig.from_json_file(config_path)
self.bert_pretrained = BertForMultiModalPreTraining.from_pretrained('bert-base-uncased', config, device, use_apex=use_apex, cache_dir=cache_dir)
self.bert_pretrained.train()
def forward(self, input_ids, image_feat, image_loc, image_edge_indices, image_edge_attributes,
question_edge_indices, question_edge_attributes, question_limits,
history_edge_indices, history_sep_indices,
sep_indices=None, sep_len=None, token_type_ids=None,
attention_mask=None, masked_lm_labels=None, next_sentence_label=None,
image_attention_mask=None, image_label=None, image_target=None):
masked_lm_loss = None
masked_img_loss = None
nsp_loss = None
seq_relationship_score = None
if next_sentence_label is not None and masked_lm_labels \
is not None and image_target is not None:
# train mode, output losses
masked_lm_loss, masked_img_loss, nsp_loss, _, _, seq_relationship_score, _ = \
self.bert_pretrained(
input_ids, image_feat, image_loc, image_edge_indices, image_edge_attributes,
question_edge_indices, question_edge_attributes, question_limits,
history_edge_indices, history_sep_indices, sep_indices=sep_indices, sep_len=sep_len, \
token_type_ids=token_type_ids, attention_mask=attention_mask, masked_lm_labels=masked_lm_labels, \
next_sentence_label=next_sentence_label, image_attention_mask=image_attention_mask,\
image_label=image_label, image_target=image_target)
else:
#inference, output scores
_, _, seq_relationship_score, _, _, _ = \
self.bert_pretrained(
input_ids, image_feat, image_loc, image_edge_indices, image_edge_attributes,
question_edge_indices, question_edge_attributes, question_limits,
history_edge_indices, history_sep_indices,
sep_indices=sep_indices, sep_len=sep_len, \
token_type_ids=token_type_ids, attention_mask=attention_mask, masked_lm_labels=masked_lm_labels, \
next_sentence_label=next_sentence_label, image_attention_mask=image_attention_mask,\
image_label=image_label, image_target=image_target)
out = (masked_lm_loss, masked_img_loss, nsp_loss, seq_relationship_score)
return out
class SparseRunner(Runner):
def __init__(self, config):
super(SparseRunner, self).__init__(config)
self.model = VDGR(
self.config['model_config'], self.config['device'],
use_apex=self.config['dp_type'] == 'apex',
cache_dir=self.config['bert_cache_dir'])
self.model.to(self.config['device'])
if not self.config['validating'] or self.config['dp_type'] == 'apex':
self.optimizer, self.scheduler = init_optim(self.model, self.config)
def forward(self, batch, eval_visdial=False):
# load data
for key in batch:
if isinstance(batch[key], torch.Tensor):
batch[key] = batch[key].to(self.config['device'])
elif isinstance(batch[key], list):
if key != 'dialog_info': # Do not send the dialog_info item to the gpu
batch[key] = [x.to(self.config['device']) for x in batch[key]]
tokens = batch['tokens']
segments = batch['segments']
sep_indices = batch['sep_indices']
mask = batch['mask']
hist_len = batch['hist_len']
image_feat = batch['image_feat']
image_loc = batch['image_loc']
image_mask = batch['image_mask']
next_sentence_labels = batch.get('next_sentence_labels', None)
image_target = batch.get('image_target', None)
image_label = batch.get('image_label', None)
# load the graph data
image_edge_indices = batch['image_edge_indices']
image_edge_attributes = batch['image_edge_attributes']
question_edge_indices = batch['question_edge_indices']
question_edge_attributes = batch['question_edge_attributes']
question_limits = batch['question_limits']
history_edge_indices = batch['history_edge_indices']
history_sep_indices = batch['history_sep_indices']
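# per-row sequence length = index of the last [SEP] token (selected via hist_len) + 1;
# it is used to build the LM/NSP attention mask below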
sequence_lengths = torch.gather(sep_indices, 1, hist_len.view(-1, 1)) + 1
sequence_lengths = sequence_lengths.squeeze(1)
attention_mask_lm_nsp = sequence_mask(sequence_lengths, max_len=tokens.shape[1])
sep_len = hist_len + 1
losses = OrderedDict()
if eval_visdial:
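# evaluation expands every dialog into many candidate-answer rows; score them in
# chunks of eval_line_batch_size rows instead of pushing the full expanded batch
# through the model at once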
num_lines = tokens.size(0)
line_batch_size = self.config['eval_line_batch_size']
num_line_batches = num_lines // line_batch_size
if num_lines % line_batch_size > 0:
num_line_batches += 1
nsp_scores = []
for j in range(num_line_batches):
# create chunks of the original batch
chunk_range = range(j*line_batch_size, min((j+1)*line_batch_size, num_lines))
tokens_chunk = tokens[chunk_range]
segments_chunk = segments[chunk_range]
sep_indices_chunk = sep_indices[chunk_range]
mask_chunk = mask[chunk_range]
sep_len_chunk = sep_len[chunk_range]
attention_mask_lm_nsp_chunk = attention_mask_lm_nsp[chunk_range]
image_feat_chunk = image_feat[chunk_range]
image_loc_chunk = image_loc[chunk_range]
image_mask_chunk = image_mask[chunk_range]
image_edge_indices_chunk = image_edge_indices[chunk_range[0]: chunk_range[-1]+1]
image_edge_attributes_chunk = image_edge_attributes[chunk_range[0]: chunk_range[-1]+1]
question_edge_indices_chunk = question_edge_indices[chunk_range[0]: chunk_range[-1]+1]
question_edge_attributes_chunk = question_edge_attributes[chunk_range[0]: chunk_range[-1]+1]
question_limits_chunk = question_limits[chunk_range[0]: chunk_range[-1]+1]
history_edge_indices_chunk = history_edge_indices[chunk_range[0]: chunk_range[-1]+1]
history_sep_indices_chunk = history_sep_indices[chunk_range[0]: chunk_range[-1]+1]
_ , _ , _, nsp_scores_chunk = \
self.model(
tokens_chunk,
image_feat_chunk,
image_loc_chunk,
image_edge_indices_chunk,
image_edge_attributes_chunk,
question_edge_indices_chunk,
question_edge_attributes_chunk,
question_limits_chunk,
history_edge_indices_chunk,
history_sep_indices_chunk,
sep_indices=sep_indices_chunk,
sep_len=sep_len_chunk,
token_type_ids=segments_chunk,
masked_lm_labels=mask_chunk,
attention_mask=attention_mask_lm_nsp_chunk,
image_attention_mask=image_mask_chunk
)
nsp_scores.append(nsp_scores_chunk)
nsp_scores = torch.cat(nsp_scores, 0)
else:
losses['lm_loss'], losses['img_loss'], losses['nsp_loss'], nsp_scores = \
self.model(
tokens,
image_feat,
image_loc,
image_edge_indices,
image_edge_attributes,
question_edge_indices,
question_edge_attributes,
question_limits,
history_edge_indices,
history_sep_indices,
next_sentence_label=next_sentence_labels,
image_target=image_target,
image_label=image_label,
sep_indices=sep_indices,
sep_len=sep_len,
token_type_ids=segments,
masked_lm_labels=mask,
attention_mask=attention_mask_lm_nsp,
image_attention_mask=image_mask
)
losses['tot_loss'] = 0
for key in ['lm_loss', 'img_loss', 'nsp_loss']:
if key in losses and losses[key] is not None:
losses[key] = losses[key].mean()
losses['tot_loss'] += self.config[f'{key}_coeff'] * losses[key]
output = {
'losses': losses,
'nsp_scores': nsp_scores
}
return output
class DenseRunner(Runner):
def __init__(self, config):
super(DenseRunner, self).__init__(config)
self.model = VDGR(
self.config['model_config'], self.config['device'],
use_apex=self.config['dp_type'] == 'apex',
cache_dir=self.config['bert_cache_dir'])
if not(self.config['parallel'] and self.config['dp_type'] == 'dp'):
self.model.to(self.config['device'])
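# loss for dense-annotation fine-tuning: 'ce' is a KL divergence between the predicted
# distribution and the softmaxed gt relevance scores; the remaining options are
# listwise ranking losses from utils.model_utils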
if self.config['dense_loss'] == 'ce':
self.dense_loss = nn.KLDivLoss(reduction='batchmean')
elif self.config['dense_loss'] == 'listmle':
self.dense_loss = listMLE
elif self.config['dense_loss'] == 'listnet':
self.dense_loss = listNet
elif self.config['dense_loss'] == 'approxndcg':
self.dense_loss = approxNDCGLoss
elif self.config['dense_loss'] == 'neural_ndcg':
self.dense_loss = neuralNDCG
elif self.config['dense_loss'] == 'neural_ndcg_transposed':
self.dense_loss = neuralNDCG_transposed
else:
raise ValueError('dense_loss must be one of ce, listmle, listnet, approxndcg, neural_ndcg, neural_ndcg_transposed')
if not self.config['validating'] or self.config['dp_type'] == 'apex':
self.optimizer, self.scheduler = init_optim(self.model, self.config)
def forward(self, batch, eval_visdial=False):
# load data
for key in batch:
if isinstance(batch[key], torch.Tensor):
batch[key] = batch[key].to(self.config['device'])
elif isinstance(batch[key], list):
if key != 'dialog_info': # Do not send the dialog_info item to the gpu
batch[key] = [x.to(self.config['device']) for x in batch[key]]
# get embedding and forward visdial
tokens = batch['tokens']
segments = batch['segments']
sep_indices = batch['sep_indices']
mask = batch['mask']
hist_len = batch['hist_len']
image_feat = batch['image_feat']
image_loc = batch['image_loc']
image_mask = batch['image_mask']
next_sentence_labels = batch.get('next_sentence_labels', None)
image_target = batch.get('image_target', None)
image_label = batch.get('image_label', None)
# load the graph data
image_edge_indices = batch['image_edge_indices']
image_edge_attributes = batch['image_edge_attributes']
question_edge_indices = batch['question_edge_indices']
question_edge_attributes = batch['question_edge_attributes']
question_limits = batch['question_limits']
history_edge_indices = batch['history_edge_indices']
assert history_edge_indices[0].size(0) == 2
history_sep_indices = batch['history_sep_indices']
sequence_lengths = torch.gather(sep_indices, 1, hist_len.view(-1, 1)) + 1
sequence_lengths = sequence_lengths.squeeze(1)
attention_mask_lm_nsp = sequence_mask(sequence_lengths, max_len=tokens.shape[1])
sep_len = hist_len + 1
losses = OrderedDict()
if eval_visdial:
num_lines = tokens.size(0)
line_batch_size = self.config['eval_line_batch_size']
num_line_batches = num_lines // line_batch_size
if num_lines % line_batch_size > 0:
num_line_batches += 1
nsp_scores = []
for j in range(num_line_batches):
# create chunks of the original batch
chunk_range = range(j*line_batch_size, min((j+1)*line_batch_size, num_lines))
tokens_chunk = tokens[chunk_range]
segments_chunk = segments[chunk_range]
sep_indices_chunk = sep_indices[chunk_range]
mask_chunk = mask[chunk_range]
sep_len_chunk = sep_len[chunk_range]
attention_mask_lm_nsp_chunk = attention_mask_lm_nsp[chunk_range]
image_feat_chunk = image_feat[chunk_range]
image_loc_chunk = image_loc[chunk_range]
image_mask_chunk = image_mask[chunk_range]
image_edge_indices_chunk = image_edge_indices[chunk_range[0]: chunk_range[-1]+1]
image_edge_attributes_chunk = image_edge_attributes[chunk_range[0]: chunk_range[-1]+1]
question_edge_indices_chunk = question_edge_indices[chunk_range[0]: chunk_range[-1]+1]
question_edge_attributes_chunk = question_edge_attributes[chunk_range[0]: chunk_range[-1]+1]
question_limits_chunk = question_limits[chunk_range[0]: chunk_range[-1]+1]
history_edge_indices_chunk = history_edge_indices[chunk_range[0]: chunk_range[-1]+1]
history_sep_indices_chunk = history_sep_indices[chunk_range[0]: chunk_range[-1]+1]
_, _, _, nsp_scores_chunk = \
self.model(
tokens_chunk,
image_feat_chunk,
image_loc_chunk,
image_edge_indices_chunk,
image_edge_attributes_chunk,
question_edge_indices_chunk,
question_edge_attributes_chunk,
question_limits_chunk,
history_edge_indices_chunk,
history_sep_indices_chunk,
sep_indices=sep_indices_chunk,
sep_len=sep_len_chunk,
token_type_ids=segments_chunk,
masked_lm_labels=mask_chunk,
attention_mask=attention_mask_lm_nsp_chunk,
image_attention_mask=image_mask_chunk
)
nsp_scores.append(nsp_scores_chunk)
nsp_scores = torch.cat(nsp_scores, 0)
else:
_, _, _, nsp_scores = \
self.model(
tokens,
image_feat,
image_loc,
image_edge_indices,
image_edge_attributes,
question_edge_indices,
question_edge_attributes,
question_limits,
history_edge_indices,
history_sep_indices,
next_sentence_label=next_sentence_labels,
image_target=image_target,
image_label=image_label,
sep_indices=sep_indices,
sep_len=sep_len,
token_type_ids=segments,
masked_lm_labels=mask,
attention_mask=attention_mask_lm_nsp,
image_attention_mask=image_mask
)
if nsp_scores is not None:
nsp_scores_output = nsp_scores.detach().clone()
if not eval_visdial:
nsp_scores = nsp_scores.view(-1, self.config['num_options_dense'], 2)
if 'next_sentence_labels' in batch and self.config['nsp_loss_coeff'] > 0:
next_sentence_labels = batch['next_sentence_labels'].to(self.config['device'])
losses['nsp_loss'] = F.cross_entropy(nsp_scores.view(-1,2), next_sentence_labels.view(-1))
else:
losses['nsp_loss'] = None
if not eval_visdial:
gt_relevance = batch['gt_relevance'].to(self.config['device'])
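# keep the score of the positive class (index 0) for each of the
# num_options_dense candidate answers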
nsp_scores = nsp_scores[:, :, 0]
if self.config['dense_loss'] == 'ce':
losses['dense_loss'] = self.dense_loss(F.log_softmax(nsp_scores, dim=1), F.softmax(gt_relevance, dim=1))
else:
losses['dense_loss'] = self.dense_loss(nsp_scores, gt_relevance)
else:
losses['dense_loss'] = None
else:
nsp_scores_output = None
losses['nsp_loss'] = None
losses['dense_loss'] = None
losses['tot_loss'] = 0
for key in ['nsp_loss', 'dense_loss']:
if key in losses and losses[key] is not None:
losses[key] = losses[key].mean()
losses['tot_loss'] += self.config[f'{key}_coeff'] * losses[key]
output = {
'losses': losses,
'nsp_scores': nsp_scores_output
}
return output

models/vilbert_dialog.py (new file, 2021 lines)

File diff suppressed because it is too large.