Code release
commit 09fb25e339
29 changed files with 7162 additions and 0 deletions
0
models/__init__.py
Normal file
830
models/runner.py
Normal file
@@ -0,0 +1,830 @@
import os
|
||||
import os.path as osp
|
||||
import json
|
||||
from collections import deque
|
||||
import time
|
||||
import re
|
||||
import shutil
|
||||
import glob
|
||||
import pickle
|
||||
import gc
|
||||
import numpy as np
|
||||
import glog as log
|
||||
try:
|
||||
from apex import amp
|
||||
except ModuleNotFoundError:
|
||||
print('apex not found')
|
||||
|
||||
import torch
|
||||
import torch.utils.data as tud
|
||||
import torch.nn.functional as F
|
||||
import torch.distributed as dist
|
||||
|
||||
from utils.data_utils import load_pickle_lines
|
||||
from utils.visdial_metrics import SparseGTMetrics, NDCG, scores_to_ranks
|
||||
import wandb
|
||||
|
||||
|
||||
class Runner:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
if 'rank' in config:
|
||||
self.gpu_rank = config['rank']
|
||||
else:
|
||||
self.gpu_rank = 0
|
||||
|
||||
self.epoch_idx = 0
|
||||
self.max_metric = 0.
|
||||
self.max_metric_epoch_idx = 0
|
||||
self.na_str = 'N/A'
|
||||
|
||||
if self.config["max_ckpt_to_keep"] > 0:
|
||||
self.checkpoint_queue = deque(
|
||||
[], maxlen=config["max_ckpt_to_keep"])
|
||||
self.metrics_queue = deque([], maxlen=config["max_ckpt_to_keep"])
|
||||
|
||||
self.setup_wandb()
|
||||
|
||||
def setup_wandb(self):
|
||||
if self.gpu_rank == 0:
|
||||
print("[INFO] Set wandb logging on rank {}".format(0))
|
||||
run = wandb.init(
|
||||
project=self.config['wandb_project'], config=self.config, mode=self.config['wandb_mode'])
|
||||
else:
|
||||
run = None
|
||||
self.run = run
|
||||
|
||||
def forward(self, batch, eval_visdial=False):
|
||||
raise NotImplementedError
|
||||
|
||||
def train(self, dataset, dataset_eval=None):
|
||||
# wandb.login()
|
||||
if os.path.exists(self.config['log_dir']) or self.config['loads_ckpt'] or self.config['loads_best_ckpt']:
|
||||
self.load_ckpt()
|
||||
|
||||
if self.config['use_trainval']:
|
||||
dataset.split = 'trainval'
|
||||
else:
|
||||
dataset.split = 'train'
|
||||
batch_size = self.config['batch_size']
|
||||
if self.config['parallel'] and self.config['dp_type'] != 'dp':
|
||||
sampler_tr = tud.distributed.DistributedSampler(
|
||||
dataset,
|
||||
num_replicas=self.config['num_gpus'],
|
||||
rank=self.gpu_rank
|
||||
)
|
||||
else:
|
||||
sampler_tr = None
|
||||
|
||||
data_loader_tr = tud.DataLoader(
|
||||
dataset=dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=self.config['training'] and not self.config['parallel'],
|
||||
collate_fn=dataset.collate_fn,
|
||||
num_workers=self.config['num_workers'],
|
||||
sampler=sampler_tr
|
||||
)
|
||||
|
||||
|
||||
start_epoch_idx = self.epoch_idx
|
||||
num_iter_epoch = self.config['num_iter_per_epoch']
|
||||
if self.config['display']:
|
||||
log.info(f'{num_iter_epoch} iter per epoch.')
|
||||
|
||||
# eval before training
|
||||
eval_dense_at_first = self.config['train_on_dense'] and self.config['skip_mrr_eval'] and start_epoch_idx == 0
|
||||
# eval before training under 2 circumstances:
|
||||
# for dense finetuning, eval ndcg before the first epoch
|
||||
# for mrr training, continue training and the last epoch is not evaluated
|
||||
|
||||
if (eval_dense_at_first or (self.config['eval_at_start'] and len(self.metrics_queue) == 0 and start_epoch_idx > 0)):
|
||||
if eval_dense_at_first:
|
||||
iter_now = 0
|
||||
else:
|
||||
iter_now = max(num_iter_epoch * start_epoch_idx, 0)
|
||||
|
||||
if dataset_eval is None:
|
||||
dataset.split = 'val'
|
||||
dataset_to_eval = dataset
|
||||
else:
|
||||
dataset_to_eval = dataset_eval
|
||||
|
||||
metrics_results = {}
|
||||
metrics_to_maximize, metrics_results['val'] = self.evaluate(
|
||||
dataset_to_eval, iter_now)
|
||||
if eval_dense_at_first:
|
||||
self.max_metric = metrics_to_maximize
|
||||
self.max_metric_epoch_idx = -1
|
||||
else:
|
||||
if self.config['display']:
|
||||
self.save_eval_results(
|
||||
'val', start_epoch_idx - 1, metrics_results)
|
||||
if metrics_to_maximize > self.max_metric:
|
||||
self.max_metric = metrics_to_maximize
|
||||
self.max_metric_epoch_idx = start_epoch_idx - 1
|
||||
self.copy_best_results('val', start_epoch_idx - 1)
|
||||
self.copy_best_predictions('val')
|
||||
if dataset_eval is None:
|
||||
if self.config['use_trainval']:
|
||||
dataset.split = 'trainval'
|
||||
else:
|
||||
dataset.split = 'train'
|
||||
|
||||
num_epochs = self.config['num_epochs']
|
||||
|
||||
for epoch_idx in range(start_epoch_idx, num_epochs):
|
||||
if self.config['parallel'] and self.config['dp_type'] != 'dp':
|
||||
sampler_tr.set_epoch(epoch_idx)
|
||||
|
||||
self.epoch_idx = epoch_idx
|
||||
|
||||
if self.config['display']:
|
||||
log.info(f'starting epoch {epoch_idx}')
|
||||
log.info('training')
|
||||
|
||||
self.model.train()
|
||||
|
||||
num_batch = 0
|
||||
next_logging_pct = .1
|
||||
next_evaluating_pct = self.config["next_evaluating_pct"] + .1
|
||||
start_time = time.time()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
for batch in data_loader_tr:
|
||||
if self.config['eval_before_training']:
|
||||
log.info('Skipping straight to evaluation...')
|
||||
break
|
||||
num_batch += 1
|
||||
pct = num_batch / num_iter_epoch * 100
|
||||
iter_now = num_iter_epoch * epoch_idx + num_batch
|
||||
|
||||
output = self.forward(batch)
|
||||
losses = output['losses']
|
||||
|
||||
# optimizer step
|
||||
losses['tot_loss'] /= self.config['batch_multiply']
|
||||
# debug
|
||||
if self.config['debugging']:
|
||||
log.info('try backward')
|
||||
if self.config['dp_type'] == 'apex':
|
||||
with amp.scale_loss(losses['tot_loss'], self.optimizer) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
else:
|
||||
losses['tot_loss'].backward()
|
||||
if self.config['debugging']:
|
||||
log.info('backward done')
|
||||
|
||||
if iter_now % self.config['batch_multiply'] == 0:
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
self.scheduler.step()
|
||||
|
||||
# display and eval
|
||||
if pct >= next_logging_pct:
|
||||
if self.config['display']:
|
||||
loss_to_print = ''
|
||||
for key in losses:
|
||||
if losses[key] is not None and isinstance(losses[key], torch.Tensor):
|
||||
loss_to_print += f'[{key}: {losses[key].item():.4f}]'
|
||||
print(
|
||||
f'[{int(pct)}%][Epoch: {epoch_idx + 1}/{num_epochs}][Iter : {num_batch}/{len(data_loader_tr)}] [time: {time.time() - start_time:.2f}] {loss_to_print}'
|
||||
)
|
||||
|
||||
|
||||
next_logging_pct += self.config["next_logging_pct"]
|
||||
|
||||
if self.config['debugging']:
|
||||
break
|
||||
|
||||
if pct >= next_evaluating_pct:
|
||||
next_evaluating_pct += self.config["next_evaluating_pct"]
|
||||
|
||||
if self.run:
|
||||
if self.config['train_on_dense']:
|
||||
self.run.log(
|
||||
{
|
||||
"Train/dense_loss": losses['dense_loss'],
|
||||
"Train/total_loss": losses['tot_loss'],
|
||||
},
|
||||
step=iter_now
|
||||
)
|
||||
|
||||
else:
|
||||
self.run.log(
|
||||
{
|
||||
"Train/lm_loss": losses['lm_loss'],
|
||||
"Train/img_loss": losses['img_loss'],
|
||||
"Train/nsp_loss": losses['nsp_loss'],
|
||||
"Train/total_loss": losses['tot_loss'],
|
||||
},
|
||||
step=iter_now
|
||||
)
|
||||
|
||||
lr_gnn, lr_bert = self.scheduler.get_lr()[0], self.scheduler.get_lr()[1]
|
||||
self.run.log(
|
||||
{
|
||||
"Train/lr_gnn": lr_gnn,
|
||||
"Train/lr_bert": lr_bert,
|
||||
},
|
||||
step=iter_now
|
||||
)
|
||||
del losses
|
||||
# debug
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if self.config['display']:
|
||||
log.info(
|
||||
f'100%,\ttime:\t{time.time() - start_time:.2f}'
|
||||
)
|
||||
ckpt_path = self.save_ckpt()
|
||||
|
||||
if not self.config['skip_visdial_eval'] and self.epoch_idx % self.config['eval_visdial_every'] == 0:
|
||||
|
||||
iter_now = num_iter_epoch * (epoch_idx + 1)
|
||||
|
||||
if dataset_eval is None:
|
||||
dataset.split = 'val'
|
||||
dataset_to_eval = dataset
|
||||
else:
|
||||
dataset_to_eval = dataset_eval
|
||||
metrics_results = {}
|
||||
metrics_to_maximize, metrics_results['val'] = self.evaluate(
|
||||
dataset_to_eval, iter_now)
|
||||
if dataset_eval is None:
|
||||
if self.config['use_trainval']:
|
||||
dataset.split = 'trainval'
|
||||
else:
|
||||
dataset.split = 'train'
|
||||
if self.config['display']:
|
||||
self.save_eval_results('val', epoch_idx, metrics_results)
|
||||
|
||||
if self.config['display']:
|
||||
|
||||
if metrics_to_maximize > self.max_metric:
|
||||
self.max_metric = metrics_to_maximize
|
||||
self.max_metric_epoch_idx = epoch_idx
|
||||
self.copy_best_results('val', epoch_idx)
|
||||
self.copy_best_predictions('val')
|
||||
|
||||
elif not self.config['parallel'] and epoch_idx - self.max_metric_epoch_idx > self.config["early_stop_epoch"]:
|
||||
log.info('Early stop.')
|
||||
break
|
||||
|
||||
if self.run:
|
||||
self.run.log(
|
||||
{"Val/metric_best": self.max_metric}, step=iter_now)
|
||||
|
||||
if self.config['parallel']:
|
||||
if self.config['dp_type'] == 'dp':
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
else:
|
||||
dist.barrier()
|
||||
log.info('Rank {} passed barrier...'.format(self.gpu_rank))
|
||||
|
||||
if self.config['stop_epochs'] >= 0 and epoch_idx + 1 >= self.config['stop_epochs']:
|
||||
if self.config['display']:
|
||||
log.info('Stop for reaching stop_epochs.')
|
||||
break
|
||||
|
||||
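# Hypothetical illustration (not part of this file): the gradient-accumulation
# pattern used in train() above. The total loss is scaled by 1/batch_multiply and
# the optimizer only steps every batch_multiply iterations, giving an effective
# batch size of batch_size * batch_multiply. A minimal sketch:
#     model = torch.nn.Linear(4, 1)
#     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#     batch_multiply = 4
#     for iter_now in range(1, 13):
#         loss = model(torch.randn(2, 4)).pow(2).mean()
#         (loss / batch_multiply).backward()   # gradients accumulate across iterations
#         if iter_now % batch_multiply == 0:
#             optimizer.step()                 # one update per batch_multiply mini-batches
#             optimizer.zero_grad()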
def evaluate(self, dataset, training_iter=None, eval_visdial=True):
|
||||
# create files to save output
|
||||
if self.config['predicting']:
|
||||
visdial_file_name = None
|
||||
if self.config['save_score']:
|
||||
visdial_file_name = osp.join(
|
||||
self.config['log_dir'], f'visdial_prediction.pkl')
|
||||
if osp.exists(visdial_file_name):
|
||||
dialogs_predicted = load_pickle_lines(
|
||||
visdial_file_name)
|
||||
dialogs_predicted = [d['image_id']
|
||||
for d in dialogs_predicted]
|
||||
else:
|
||||
dialogs_predicted = []
|
||||
f_visdial = open(visdial_file_name, 'ab')
|
||||
|
||||
else:
|
||||
visdial_file_name = osp.join(
|
||||
self.config['log_dir'], f'visdial_prediction.jsonlines')
|
||||
if self.config['parallel'] and self.config['dp_type'] != 'dp':
|
||||
visdial_file_name = visdial_file_name.replace(
|
||||
'.jsonlines', f'_{self.config["rank"]}of{self.config["num_gpus"]}.jsonlines')
|
||||
if osp.exists(visdial_file_name):
|
||||
dialogs_predicted_visdial = [json.loads(
|
||||
line)['image_id'] for line in open(visdial_file_name)]
|
||||
f_visdial = open(visdial_file_name, 'a')
|
||||
else:
|
||||
dialogs_predicted_visdial = []
|
||||
f_visdial = open(visdial_file_name, 'w')
|
||||
|
||||
dialogs_predicted = dialogs_predicted_visdial
|
||||
|
||||
if len(dialogs_predicted) > 0:
|
||||
log.info(f'Found {len(dialogs_predicted)} predicted results.')
|
||||
|
||||
if self.config['display']:
|
||||
if visdial_file_name is not None:
|
||||
log.info(
|
||||
f'VisDial predictions saved to {visdial_file_name}')
|
||||
|
||||
elif self.config['display']:
|
||||
if self.config['continue_evaluation']:
|
||||
predicted_files = os.listdir(
|
||||
osp.join(self.config['visdial_output_dir'], dataset.split))
|
||||
dialogs_predicted = [
|
||||
int(re.match(r'(\d+)\.npz', p).group(1)) for p in predicted_files]
|
||||
else:
|
||||
if osp.exists(osp.join(self.config['visdial_output_dir'], dataset.split)):
|
||||
shutil.rmtree(
|
||||
osp.join(self.config['visdial_output_dir'], dataset.split))
|
||||
os.makedirs(
|
||||
osp.join(self.config['visdial_output_dir'], dataset.split))
|
||||
|
||||
dialogs_predicted = []
|
||||
log.info(f'Found {len(dialogs_predicted)} predicted results.')
|
||||
|
||||
if self.config['parallel'] and self.config['dp_type'] != 'dp':
|
||||
sampler_val = tud.distributed.DistributedSampler(
|
||||
dataset,
|
||||
num_replicas=self.config['num_gpus'],
|
||||
rank=self.gpu_rank
|
||||
)
|
||||
|
||||
sampler_val.set_epoch(self.epoch_idx)
|
||||
else:
|
||||
sampler_val = None
|
||||
|
||||
data_loader_val = tud.DataLoader(
|
||||
dataset=dataset,
|
||||
batch_size=self.config['eval_batch_size'],
|
||||
shuffle=False,
|
||||
collate_fn=dataset.collate_fn,
|
||||
num_workers=self.config['num_workers'],
|
||||
sampler=sampler_val
|
||||
)
|
||||
self.model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
if self.config['display']:
|
||||
log.info(f'Evaluating {len(dataset)} samples')
|
||||
|
||||
next_logging_pct = self.config["next_logging_pct"] + .1
|
||||
if self.config['parallel'] and self.config['dp_type'] == 'dp':
|
||||
num_batch_tot = int(
|
||||
np.ceil(len(dataset) / self.config['eval_batch_size']))
|
||||
else:
|
||||
num_batch_tot = int(np.ceil(
|
||||
len(dataset) / (self.config['eval_batch_size'] * self.config['num_gpus'])))
|
||||
num_batch = 0
|
||||
if dataset.split == 'val':
|
||||
num_options = self.config["num_options"]
|
||||
if self.config['skip_mrr_eval']:
|
||||
num_rounds = 1
|
||||
else:
|
||||
num_rounds = 10
|
||||
elif dataset.split == 'test':
|
||||
num_options = 100
|
||||
num_rounds = 1
|
||||
if self.gpu_rank == 0:
|
||||
start_time = time.time()
|
||||
|
||||
for batch in data_loader_val:
|
||||
num_batch += 1
|
||||
# skip dialogs that have been predicted
|
||||
if self.config['predicting']:
|
||||
image_ids = batch['image_id'].tolist()
|
||||
skip_batch = True
|
||||
for image_id in image_ids:
|
||||
if image_id not in dialogs_predicted:
|
||||
skip_batch = False
|
||||
if skip_batch:
|
||||
continue
|
||||
output = self.forward(
|
||||
batch, eval_visdial=eval_visdial)
|
||||
|
||||
# visdial evaluation
|
||||
if eval_visdial:
|
||||
img_ids = batch['image_id'].tolist()
|
||||
batch_size = len(img_ids)
|
||||
if not self.config['skip_ndcg_eval']:
|
||||
gt_relevance_round_id = batch['round_id'].tolist()
|
||||
|
||||
# [batch_size * num_rounds * num_options, 2]
|
||||
nsp_scores = output['nsp_scores']
|
||||
nsp_probs = F.softmax(nsp_scores, dim=1)
|
||||
assert nsp_probs.shape[-1] == 2
|
||||
# num_dim=2, 0 for positive, 1 for negative
|
||||
nsp_probs = nsp_probs[:, 0]
|
||||
nsp_probs = nsp_probs.view(
|
||||
batch_size, num_rounds, num_options)
|
||||
|
||||
# could be predicting or evaluating
|
||||
if dataset.split == 'val':
|
||||
if self.config['skip_ndcg_eval']:
|
||||
gt_option_inds = batch['gt_option_inds']
|
||||
|
||||
for b in range(batch_size):
|
||||
filename = osp.join(
|
||||
self.config['visdial_output_dir'], dataset.split, f'{img_ids[b]}.npz')
|
||||
if not osp.exists(filename):
|
||||
np.savez(
|
||||
filename,
|
||||
nsp_probs=nsp_probs[b].cpu().numpy(),
|
||||
gt_option_inds=gt_option_inds[b].cpu().numpy()
|
||||
)
|
||||
else:
|
||||
# [batch_size, num_rounds]
|
||||
gt_option_inds = batch['gt_option_inds']
|
||||
# [batch_size, num_options]
|
||||
gt_relevance = batch['gt_relevance']
|
||||
|
||||
for b in range(batch_size):
|
||||
filename = osp.join(
|
||||
self.config['visdial_output_dir'], dataset.split, f'{img_ids[b]}.npz')
|
||||
if not osp.exists(filename):
|
||||
np.savez(filename,
|
||||
nsp_probs=nsp_probs[b].cpu().numpy(),
|
||||
gt_option_inds=gt_option_inds[b].cpu(
|
||||
).numpy(),
|
||||
gt_relevance=gt_relevance[b].cpu(
|
||||
).numpy(),
|
||||
gt_relevance_round_id=gt_relevance_round_id[b])
|
||||
|
||||
# must be predicting
|
||||
if dataset.split == 'test':
|
||||
if self.config['save_score']:
|
||||
for b in range(batch_size):
|
||||
prediction = {
|
||||
"image_id": img_ids[b],
|
||||
"nsp_probs": nsp_probs[b].cpu().numpy(),
|
||||
"gt_relevance_round_id": gt_relevance_round_id[b]
|
||||
}
|
||||
pickle.dump(prediction, f_visdial)
|
||||
else:
|
||||
# [eval_batch_size, num_rounds, num_options]
|
||||
ranks = scores_to_ranks(nsp_probs)
|
||||
ranks = ranks.squeeze(1)
|
||||
for b in range(batch_size):
|
||||
prediction = {
|
||||
"image_id": img_ids[b],
|
||||
"round_id": gt_relevance_round_id[b],
|
||||
"ranks": ranks[b].tolist()
|
||||
}
|
||||
f_visdial.write(json.dumps(prediction) + '\n')
|
||||
|
||||
# debug
|
||||
if self.config['debugging']:
|
||||
break
|
||||
|
||||
pct = num_batch / num_batch_tot * 100
|
||||
if pct >= next_logging_pct:
|
||||
if self.config['display'] and self.gpu_rank == 0:
|
||||
log.info(
|
||||
f'{int(pct)}%,\ttime:\t{time.time() - start_time:.2f}'
|
||||
)
|
||||
next_logging_pct += self.config["next_logging_pct"]
|
||||
# debug
|
||||
if self.config['debugging']:
|
||||
break
|
||||
|
||||
if self.config['display'] and self.gpu_rank == 0:
|
||||
pct = num_batch / num_batch_tot * 100
|
||||
log.info(
|
||||
f'{int(pct)}%,\ttime:\t{time.time() - start_time:.2f}'
|
||||
)
|
||||
|
||||
if not self.config['validating']:
|
||||
self.model.train()
|
||||
|
||||
if self.config['parallel'] and self.config['dp_type'] != 'dp':
|
||||
dist.barrier()
|
||||
|
||||
print(f'{self.gpu_rank} passed barrier')
|
||||
|
||||
if self.config['predicting']:
|
||||
f_visdial.close()
|
||||
if not self.config['save_score']:
|
||||
all_visdial_predictions = [json.loads(
|
||||
line) for line in open(visdial_file_name)]
|
||||
if self.config['predict_split'] == 'test' and len(all_visdial_predictions) == self.config['num_test_dialogs']:
|
||||
visdial_file_name = visdial_file_name.replace(
|
||||
'jsonlines', 'json')
|
||||
with open(visdial_file_name, 'w') as f_visdial:
|
||||
json.dump(all_visdial_predictions, f_visdial)
|
||||
log.info(
|
||||
f'Prediction for submission saved to {visdial_file_name}.')
|
||||
return None, None
|
||||
|
||||
if self.config['display']:
|
||||
if dataset.split == 'val' and eval_visdial:
|
||||
if not self.config['skip_mrr_eval']:
|
||||
sparse_metrics = SparseGTMetrics()
|
||||
if not self.config['skip_ndcg_eval']:
|
||||
ndcg = NDCG()
|
||||
|
||||
if dataset.split == 'val' and eval_visdial:
|
||||
visdial_output_filenames = glob.glob(
|
||||
osp.join(self.config['visdial_output_dir'], dataset.split, '*.npz'))
|
||||
log.info(
|
||||
f'Calculating visdial metrics for {len(visdial_output_filenames)} dialogs')
|
||||
for visdial_output_filename in visdial_output_filenames:
|
||||
output = np.load(visdial_output_filename)
|
||||
nsp_probs = torch.from_numpy(
|
||||
output['nsp_probs']).unsqueeze(0)
|
||||
if not self.config['skip_ndcg_eval']:
|
||||
gt_relevance = torch.from_numpy(output['gt_relevance']).unsqueeze(0)
|
||||
if not self.config['skip_mrr_eval']:
|
||||
gt_option_inds = torch.from_numpy(
|
||||
output['gt_option_inds']).unsqueeze(0)
|
||||
sparse_metrics.observe(nsp_probs, gt_option_inds)
|
||||
if not self.config['skip_ndcg_eval']:
|
||||
gt_relevance_round_id = output['gt_relevance_round_id']
|
||||
nsp_probs_dense = nsp_probs[0, gt_relevance_round_id - 1, :].unsqueeze(0)
|
||||
else:
|
||||
nsp_probs_dense = nsp_probs.squeeze(0) # [1, 100]
|
||||
if not self.config['skip_ndcg_eval']:
|
||||
ndcg.observe(nsp_probs_dense, gt_relevance)
|
||||
|
||||
# visdial eval output
|
||||
visdial_metrics = {}
|
||||
if dataset.split == 'val' and eval_visdial:
|
||||
if not self.config['skip_mrr_eval']:
|
||||
visdial_metrics.update(sparse_metrics.retrieve(reset=True))
|
||||
if not self.config['skip_ndcg_eval']:
|
||||
visdial_metrics.update(ndcg.retrieve(reset=True))
|
||||
|
||||
if self.config['display']:
|
||||
to_print = ''
|
||||
for metric_name, metric_value in visdial_metrics.items():
|
||||
if 'round' not in metric_name:
|
||||
to_print += f"\n{metric_name}: {metric_value}"
|
||||
if training_iter is not None:
|
||||
if self.run:
|
||||
self.run.log(
|
||||
{'Val/' + metric_name: metric_value}, step=training_iter)
|
||||
log.info(to_print)
|
||||
|
||||
if self.config['metrics_to_maximize'] in visdial_metrics:
|
||||
metrics_to_maximize = visdial_metrics[self.config['metrics_to_maximize']]
|
||||
else:
|
||||
metrics_to_maximize = None
|
||||
torch.cuda.empty_cache()
|
||||
return metrics_to_maximize, visdial_metrics
|
||||
else:
|
||||
torch.cuda.empty_cache()
|
||||
return None, None
|
||||
|
||||
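# Hypothetical illustration (not part of this file): the assumed behaviour of
# scores_to_ranks from utils.visdial_metrics used in evaluate() above -- each
# answer option gets its 1-based rank under descending score order, so the
# highest-scoring option receives rank 1. Sketch:
#     scores = torch.tensor([[[0.2, 0.9, 0.5]]])         # [batch, rounds, options]
#     order = scores.argsort(dim=-1, descending=True)
#     ranks = order.argsort(dim=-1) + 1                  # tensor([[[3, 1, 2]]])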
def save_eval_results(self, split, epoch_idx, metrics_results):
|
||||
|
||||
metrics_filename = osp.join(
|
||||
self.config['log_dir'], f'metrics_epoch_{epoch_idx}.json')
|
||||
with open(metrics_filename, 'w') as f:
|
||||
json.dump(metrics_results, f)
|
||||
log.info(f'Results of metrics saved to {metrics_filename}')
|
||||
|
||||
if self.config["max_ckpt_to_keep"] > 0:
|
||||
if len(self.metrics_queue) == self.metrics_queue.maxlen:
|
||||
todel = self.metrics_queue.popleft()
|
||||
os.remove(todel)
|
||||
self.metrics_queue.append(metrics_filename)
|
||||
|
||||
if epoch_idx == 'best':
|
||||
self.copy_best_predictions(split)
|
||||
|
||||
def copy_best_results(self, split, epoch_idx):
|
||||
to_print = 'Copy '
|
||||
|
||||
if not self.config['skip_saving_ckpt']:
|
||||
ckpt_path = osp.join(
|
||||
self.config['log_dir'], f'epoch_{epoch_idx}.ckpt')
|
||||
best_ckpt_path = ckpt_path.replace(
|
||||
f'{epoch_idx}.ckpt', 'best.ckpt')
|
||||
shutil.copyfile(ckpt_path, best_ckpt_path)
|
||||
to_print += best_ckpt_path + ' '
|
||||
|
||||
metrics_filename = osp.join(
|
||||
self.config['log_dir'], f'metrics_epoch_{epoch_idx}.json')
|
||||
best_metric_filename = metrics_filename.replace(
|
||||
f'{epoch_idx}.json', 'best.json')
|
||||
shutil.copyfile(metrics_filename, best_metric_filename)
|
||||
to_print += best_metric_filename + ' '
|
||||
|
||||
log.info(to_print)
|
||||
|
||||
def copy_best_predictions(self, split):
|
||||
to_print = 'Copy '
|
||||
|
||||
visdial_output_dir = osp.join(self.config['visdial_output_dir'], split)
|
||||
if osp.exists(visdial_output_dir):
|
||||
dir_best = visdial_output_dir.replace('output', 'output_best')
|
||||
if osp.exists(dir_best):
|
||||
shutil.rmtree(dir_best)
|
||||
shutil.copytree(visdial_output_dir, dir_best)
|
||||
to_print += dir_best + ' '
|
||||
|
||||
log.info(to_print)
|
||||
|
||||
def get_ckpt(self):
|
||||
ckpt = {
|
||||
'epoch_idx': self.epoch_idx,
|
||||
'max_metric': self.max_metric,
|
||||
'seed': self.config['random_seed'],
|
||||
'optimizer': self.optimizer.state_dict(),
|
||||
'scheduler': self.scheduler.state_dict()
|
||||
}
|
||||
if self.config['parallel']:
|
||||
ckpt['model_state_dict'] = self.model.module.state_dict()
|
||||
else:
|
||||
ckpt['model_state_dict'] = self.model.state_dict()
|
||||
if self.config['dp_type'] == 'apex':
|
||||
ckpt['amp'] = amp.state_dict()
|
||||
return ckpt
|
||||
|
||||
def set_ckpt(self, ckpt_dict):
|
||||
if not self.config['restarts']:
|
||||
self.epoch_idx = ckpt_dict.get('epoch_idx', -1) + 1
|
||||
|
||||
if not self.config['resets_max_metric']:
|
||||
self.max_metric = ckpt_dict.get('max_metric', -1)
|
||||
|
||||
if self.config['parallel']:
|
||||
model = self.model.module
|
||||
else:
|
||||
model = self.model
|
||||
|
||||
model_state_dict = model.state_dict()
|
||||
former_dict = {
|
||||
k: v for k, v in ckpt_dict['model_state_dict'].items() if k in model_state_dict}
|
||||
|
||||
if self.config['display']:
|
||||
log.info("number of keys transferred: %d" % len(former_dict))
|
||||
assert len(former_dict.keys()) > 0
|
||||
|
||||
model_state_dict.update(former_dict)
|
||||
|
||||
model.load_state_dict(model_state_dict)
|
||||
if self.config['display']:
|
||||
log.info('loaded model')
|
||||
del model_state_dict, former_dict
|
||||
|
||||
if not self.config['validating'] and not (self.config['uses_new_optimizer'] or self.config['sets_new_lr']):
|
||||
if 'optimizer' in ckpt_dict:
|
||||
self.optimizer.load_state_dict(ckpt_dict['optimizer'])
|
||||
if self.config['display']:
|
||||
log.info('loaded optimizer')
|
||||
if 'scheduler' in ckpt_dict:
|
||||
self.scheduler.last_epoch = ckpt_dict['epoch_idx'] * \
|
||||
self.config['num_iter_per_epoch']
|
||||
self.scheduler.load_state_dict(ckpt_dict['scheduler'])
|
||||
|
||||
if 'amp' in ckpt_dict and self.config['dp_type'] == 'apex':
|
||||
amp.load_state_dict(ckpt_dict['amp'])
|
||||
|
||||
del ckpt_dict
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def save_ckpt(self):
|
||||
ckpt_path = f'{self.config["log_dir"]}/epoch_{self.epoch_idx}.ckpt'
|
||||
log.info(f'saving checkpoint {ckpt_path}')
|
||||
ckpt = self.get_ckpt()
|
||||
if self.config['skip_saving_ckpt']:
|
||||
return ckpt_path
|
||||
# parse major/minor robustly; float(torch.__version__[:3]) breaks for versions like '1.10'
torch_version = tuple(int(x) for x in torch.__version__.split('.')[:2])
if torch_version > (1, 4):
|
||||
torch.save(ckpt, f=ckpt_path, _use_new_zipfile_serialization=False)
|
||||
else:
|
||||
torch.save(ckpt, f=ckpt_path)
|
||||
del ckpt
|
||||
|
||||
if not (self.config['parallel'] and self.config['dp_type'] in ['ddp', 'apex']):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if self.config["max_ckpt_to_keep"] > 0:
|
||||
if len(self.checkpoint_queue) == self.checkpoint_queue.maxlen:
|
||||
todel = self.checkpoint_queue.popleft()
|
||||
os.remove(todel)
|
||||
self.checkpoint_queue.append(ckpt_path)
|
||||
|
||||
return ckpt_path
|
||||
|
||||
def save_ckpt_best(self):
|
||||
ckpt_path = f'{self.config["log_dir"]}/epoch_best.ckpt'
|
||||
log.info(f'saving checkpoint {ckpt_path}')
|
||||
ckpt = self.get_ckpt()
|
||||
torch.save(ckpt, f=ckpt_path)
|
||||
del ckpt
|
||||
return ckpt_path
|
||||
|
||||
def load_ckpt_best(self):
|
||||
ckpt_path = f'{osp.dirname(self.config["log_dir"])}/epoch_best.ckpt'
|
||||
if not osp.exists(ckpt_path):
|
||||
ckpt_paths = [path for path in os.listdir(
|
||||
f'{self.config["log_dir"]}/') if path.endswith('.ckpt') and 'best' not in path]
|
||||
if len(ckpt_paths) == 0:
|
||||
if self.config['display']:
|
||||
log.info(f'No .ckpt found in {self.config["log_dir"]}')
|
||||
return
|
||||
|
||||
def sort_func(x): return int(re.search(r"(\d+)", x).groups()[0])
|
||||
ckpt_path = f'{self.config["log_dir"]}/{sorted(ckpt_paths, key=sort_func, reverse=True)[0]}'
|
||||
if self.config['display']:
|
||||
log.info(f'loading checkpoint {ckpt_path}')
|
||||
map_location = {'cuda:0': f'cuda:{self.gpu_rank}'}
|
||||
self.set_ckpt(torch.load(ckpt_path, map_location=map_location))
|
||||
|
||||
def load_ckpt(self, ckpt_path=None):
|
||||
if not ckpt_path:
|
||||
if self.config['validating'] or self.config['loads_best_ckpt']:
|
||||
ckpt_path = f'{self.config["log_dir"]}/epoch_best.ckpt'
|
||||
else:
|
||||
ckpt_paths = [path for path in os.listdir(
|
||||
f'{self.config["log_dir"]}/') if path.endswith('.ckpt') and 'best' not in path]
|
||||
if len(ckpt_paths) == 0:
|
||||
if self.config['display']:
|
||||
log.info(f'No .ckpt found in {self.config["log_dir"]}')
|
||||
return
|
||||
|
||||
def sort_func(x): return int(
|
||||
re.search(r"(\d+)", x).groups()[0])
|
||||
ckpt_path = f'{self.config["log_dir"]}/{sorted(ckpt_paths, key=sort_func, reverse=True)[0]}'
|
||||
|
||||
if self.config['display']:
|
||||
log.info(f'loading checkpoint {ckpt_path}')
|
||||
epoch_name = osp.split(ckpt_path)[1].split('.')[0]
|
||||
if re.search(r"(\d+)", epoch_name):
|
||||
self.checkpoint_queue.append(ckpt_path)
|
||||
metrics_filename = osp.join(
|
||||
self.config['log_dir'], f'metrics_{epoch_name}.json')
|
||||
if osp.exists(metrics_filename):
|
||||
self.metrics_queue.append(metrics_filename)
|
||||
|
||||
map_location = {'cuda:0': f'cuda:{self.gpu_rank}'}
|
||||
self.set_ckpt(torch.load(ckpt_path, map_location=map_location))
|
||||
|
||||
def match_model_key(self, pretrained_dict, model_dict):
|
||||
matched_dict = dict()
|
||||
for key in pretrained_dict:
|
||||
if key in model_dict:
|
||||
matched_key = key
|
||||
elif key.startswith('encoder.') and key[8:] in model_dict:
|
||||
matched_key = key[8:]
|
||||
elif key.startswith('module.') and key[7:] in model_dict:
|
||||
matched_key = key[7:]
|
||||
elif 'encoder.' + key in model_dict:
|
||||
matched_key = 'encoder.' + key
|
||||
elif 'module.' + key in model_dict:
|
||||
matched_key = 'module.' + key
|
||||
else:
|
||||
# not_found.append(key)
|
||||
continue
|
||||
matched_dict[matched_key] = pretrained_dict[key]
|
||||
|
||||
not_found = ""
|
||||
for k in model_dict:
|
||||
if k not in matched_dict:
|
||||
not_found += k + '\n'
|
||||
|
||||
log.info("Keys from model_dict that were not found in pretrained_dict:")
|
||||
log.info(not_found)
|
||||
return matched_dict
|
||||
|
||||
def load_pretrained_vilbert(self, start_from=None):
|
||||
if start_from is not None:
|
||||
self.config["start_path"] = start_from
|
||||
if self.config['training'] or self.config['debugging']:
|
||||
ckpt_paths = [path for path in os.listdir(
|
||||
f'{self.config["log_dir"]}/') if path.endswith('.ckpt') and 'best' not in path]
|
||||
if len(ckpt_paths) > 0:
|
||||
if self.config['display']:
|
||||
log.info('Continue training')
|
||||
return
|
||||
|
||||
if self.config['display']:
|
||||
log.info(
|
||||
f'Loading pretrained VilBERT from {self.config["start_path"]}')
|
||||
map_location = {'cuda:0': f'cuda:{self.gpu_rank}'}
|
||||
pretrained_dict = torch.load(
|
||||
self.config['start_path'], map_location=map_location)
|
||||
if 'model_state_dict' in pretrained_dict:
|
||||
pretrained_dict = pretrained_dict['model_state_dict']
|
||||
if self.config['parallel']:
|
||||
model = self.model.module
|
||||
else:
|
||||
model = self.model
|
||||
model_dict = model.state_dict()
|
||||
|
||||
matched_dict = self.match_model_key(pretrained_dict, model_dict)
|
||||
|
||||
if self.config['display']:
|
||||
log.info("number of keys transferred: %d" % len(matched_dict))
|
||||
assert len(matched_dict.keys()) > 0
|
||||
model_dict.update(matched_dict)
|
||||
model.load_state_dict(model_dict)
|
||||
|
||||
del pretrained_dict, model_dict, matched_dict
|
||||
if not self.config['parallel'] or self.config['dp_type'] == 'dp':
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if self.config['display']:
|
||||
log.info('Pretrained VilBERT loaded')
|
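The checkpoint bookkeeping in runner.py pairs save_ckpt() with a bounded deque so that only the newest max_ckpt_to_keep checkpoints stay on disk. A minimal, self-contained sketch of that rotation (the file names and helper below are made up for illustration):

import os
from collections import deque

checkpoint_queue = deque([], maxlen=2)            # mirrors max_ckpt_to_keep = 2

def save_ckpt_sketch(path):
    open(path, 'w').close()                       # stand-in for torch.save(ckpt, path)
    if len(checkpoint_queue) == checkpoint_queue.maxlen:
        os.remove(checkpoint_queue.popleft())     # delete the oldest checkpoint first
    checkpoint_queue.append(path)

for epoch_idx in range(4):
    save_ckpt_sketch(f'epoch_{epoch_idx}.ckpt')
# only epoch_2.ckpt and epoch_3.ckpt remain on disk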
379
models/vdgr.py
Normal file
379
models/vdgr.py
Normal file
@@ -0,0 +1,379 @@
import sys
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
sys.path.append('../')
|
||||
from utils.model_utils import listMLE, approxNDCGLoss, listNet, neuralNDCG, neuralNDCG_transposed
|
||||
|
||||
from utils.data_utils import sequence_mask
|
||||
from utils.optim_utils import init_optim
|
||||
from models.runner import Runner
|
||||
|
||||
from models.vilbert_dialog import BertForMultiModalPreTraining, BertConfig
|
||||
|
||||
|
||||
class VDGR(nn.Module):
|
||||
|
||||
def __init__(self, config_path, device, use_apex=False, cache_dir=None):
|
||||
super(VDGR, self).__init__()
|
||||
config = BertConfig.from_json_file(config_path)
|
||||
|
||||
self.bert_pretrained = BertForMultiModalPreTraining.from_pretrained('bert-base-uncased', config, device, use_apex=use_apex, cache_dir=cache_dir)
|
||||
self.bert_pretrained.train()
|
||||
|
||||
def forward(self, input_ids, image_feat, image_loc, image_edge_indices, image_edge_attributes,
|
||||
question_edge_indices, question_edge_attributes, question_limits,
|
||||
history_edge_indices, history_sep_indices,
|
||||
sep_indices=None, sep_len=None, token_type_ids=None,
|
||||
attention_mask=None, masked_lm_labels=None, next_sentence_label=None,
|
||||
image_attention_mask=None, image_label=None, image_target=None):
|
||||
|
||||
masked_lm_loss = None
|
||||
masked_img_loss = None
|
||||
nsp_loss = None
|
||||
seq_relationship_score = None
|
||||
|
||||
if next_sentence_label is not None and masked_lm_labels \
|
||||
is not None and image_target is not None:
|
||||
# train mode, output losses
|
||||
masked_lm_loss, masked_img_loss, nsp_loss, _, _, seq_relationship_score, _ = \
|
||||
self.bert_pretrained(
|
||||
input_ids, image_feat, image_loc, image_edge_indices, image_edge_attributes,
|
||||
question_edge_indices, question_edge_attributes, question_limits,
|
||||
history_edge_indices, history_sep_indices, sep_indices=sep_indices, sep_len=sep_len, \
|
||||
token_type_ids=token_type_ids, attention_mask=attention_mask, masked_lm_labels=masked_lm_labels, \
|
||||
next_sentence_label=next_sentence_label, image_attention_mask=image_attention_mask,\
|
||||
image_label=image_label, image_target=image_target)
|
||||
else:
|
||||
# inference, output scores
|
||||
_, _, seq_relationship_score, _, _, _ = \
|
||||
self.bert_pretrained(
|
||||
input_ids, image_feat, image_loc, image_edge_indices, image_edge_attributes,
|
||||
question_edge_indices, question_edge_attributes, question_limits,
|
||||
history_edge_indices, history_sep_indices,
|
||||
sep_indices=sep_indices, sep_len=sep_len, \
|
||||
token_type_ids=token_type_ids, attention_mask=attention_mask, masked_lm_labels=masked_lm_labels, \
|
||||
next_sentence_label=next_sentence_label, image_attention_mask=image_attention_mask,\
|
||||
image_label=image_label, image_target=image_target)
|
||||
|
||||
out = (masked_lm_loss, masked_img_loss, nsp_loss, seq_relationship_score)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class SparseRunner(Runner):
|
||||
def __init__(self, config):
|
||||
super(SparseRunner, self).__init__(config)
|
||||
self.model = VDGR(
|
||||
self.config['model_config'], self.config['device'],
|
||||
use_apex=self.config['dp_type'] == 'apex',
|
||||
cache_dir=self.config['bert_cache_dir'])
|
||||
|
||||
self.model.to(self.config['device'])
|
||||
|
||||
if not self.config['validating'] or self.config['dp_type'] == 'apex':
|
||||
self.optimizer, self.scheduler = init_optim(self.model, self.config)
|
||||
|
||||
def forward(self, batch, eval_visdial=False):
|
||||
# load data
|
||||
for key in batch:
|
||||
if isinstance(batch[key], torch.Tensor):
|
||||
batch[key] = batch[key].to(self.config['device'])
|
||||
elif isinstance(batch[key], list):
|
||||
if key != 'dialog_info': # Do not send the dialog_info item to the gpu
|
||||
batch[key] = [x.to(self.config['device']) for x in batch[key]]
|
||||
|
||||
tokens = batch['tokens']
|
||||
segments = batch['segments']
|
||||
sep_indices = batch['sep_indices']
|
||||
mask = batch['mask']
|
||||
hist_len = batch['hist_len']
|
||||
image_feat = batch['image_feat']
|
||||
image_loc = batch['image_loc']
|
||||
image_mask = batch['image_mask']
|
||||
next_sentence_labels = batch.get('next_sentence_labels', None)
|
||||
image_target = batch.get('image_target', None)
|
||||
image_label = batch.get('image_label', None)
|
||||
# load the graph data
|
||||
image_edge_indices = batch['image_edge_indices']
|
||||
image_edge_attributes = batch['image_edge_attributes']
|
||||
question_edge_indices = batch['question_edge_indices']
|
||||
question_edge_attributes = batch['question_edge_attributes']
|
||||
question_limits = batch['question_limits']
|
||||
history_edge_indices = batch['history_edge_indices']
|
||||
history_sep_indices = batch['history_sep_indices']
|
||||
|
||||
sequence_lengths = torch.gather(sep_indices, 1, hist_len.view(-1, 1)) + 1
|
||||
sequence_lengths = sequence_lengths.squeeze(1)
|
||||
attention_mask_lm_nsp = sequence_mask(sequence_lengths, max_len=tokens.shape[1])
|
||||
sep_len = hist_len + 1
|
||||
|
||||
losses = OrderedDict()
|
||||
|
||||
if eval_visdial:
|
||||
num_lines = tokens.size(0)
|
||||
line_batch_size = self.config['eval_line_batch_size']
|
||||
num_line_batches = num_lines // line_batch_size
|
||||
if num_lines % line_batch_size > 0:
|
||||
num_line_batches += 1
|
||||
nsp_scores = []
|
||||
for j in range(num_line_batches):
|
||||
# create chunks of the original batch
|
||||
chunk_range = range(j*line_batch_size, min((j+1)*line_batch_size, num_lines))
|
||||
tokens_chunk = tokens[chunk_range]
|
||||
segments_chunk = segments[chunk_range]
|
||||
sep_indices_chunk = sep_indices[chunk_range]
|
||||
mask_chunk = mask[chunk_range]
|
||||
sep_len_chunk = sep_len[chunk_range]
|
||||
attention_mask_lm_nsp_chunk = attention_mask_lm_nsp[chunk_range]
|
||||
image_feat_chunk = image_feat[chunk_range]
|
||||
image_loc_chunk = image_loc[chunk_range]
|
||||
image_mask_chunk = image_mask[chunk_range]
|
||||
image_edge_indices_chunk = image_edge_indices[chunk_range[0]: chunk_range[-1]+1]
|
||||
image_edge_attributes_chunk = image_edge_attributes[chunk_range[0]: chunk_range[-1]+1]
|
||||
question_edge_indices_chunk = question_edge_indices[chunk_range[0]: chunk_range[-1]+1]
|
||||
question_edge_attributes_chunk = question_edge_attributes[chunk_range[0]: chunk_range[-1]+1]
|
||||
question_limits_chunk = question_limits[chunk_range[0]: chunk_range[-1]+1]
|
||||
history_edge_indices_chunk = history_edge_indices[chunk_range[0]: chunk_range[-1]+1]
|
||||
history_sep_indices_chunk = history_sep_indices[chunk_range[0]: chunk_range[-1]+1]
|
||||
|
||||
_, _, _, nsp_scores_chunk = \
|
||||
self.model(
|
||||
tokens_chunk,
|
||||
image_feat_chunk,
|
||||
image_loc_chunk,
|
||||
image_edge_indices_chunk,
|
||||
image_edge_attributes_chunk,
|
||||
question_edge_indices_chunk,
|
||||
question_edge_attributes_chunk,
|
||||
question_limits_chunk,
|
||||
history_edge_indices_chunk,
|
||||
history_sep_indices_chunk,
|
||||
sep_indices=sep_indices_chunk,
|
||||
sep_len=sep_len_chunk,
|
||||
token_type_ids=segments_chunk,
|
||||
masked_lm_labels=mask_chunk,
|
||||
attention_mask=attention_mask_lm_nsp_chunk,
|
||||
image_attention_mask=image_mask_chunk
|
||||
)
|
||||
nsp_scores.append(nsp_scores_chunk)
|
||||
nsp_scores = torch.cat(nsp_scores, 0)
|
||||
|
||||
else:
|
||||
losses['lm_loss'], losses['img_loss'], losses['nsp_loss'], nsp_scores = \
|
||||
self.model(
|
||||
tokens,
|
||||
image_feat,
|
||||
image_loc,
|
||||
image_edge_indices,
|
||||
image_edge_attributes,
|
||||
question_edge_indices,
|
||||
question_edge_attributes,
|
||||
question_limits,
|
||||
history_edge_indices,
|
||||
history_sep_indices,
|
||||
next_sentence_label=next_sentence_labels,
|
||||
image_target=image_target,
|
||||
image_label=image_label,
|
||||
sep_indices=sep_indices,
|
||||
sep_len=sep_len,
|
||||
token_type_ids=segments,
|
||||
masked_lm_labels=mask,
|
||||
attention_mask=attention_mask_lm_nsp,
|
||||
image_attention_mask=image_mask
|
||||
)
|
||||
|
||||
losses['tot_loss'] = 0
|
||||
for key in ['lm_loss', 'img_loss', 'nsp_loss']:
|
||||
if key in losses and losses[key] is not None:
|
||||
losses[key] = losses[key].mean()
|
||||
losses['tot_loss'] += self.config[f'{key}_coeff'] * losses[key]
|
||||
|
||||
output = {
|
||||
'losses': losses,
|
||||
'nsp_scores': nsp_scores
|
||||
}
|
||||
return output
|
||||
|
||||
|
||||
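# Hypothetical illustration (not part of this file): why the eval chunking above
# indexes tensors directly with chunk_range but slices the graph inputs, which
# are plain Python lists, with chunk_range[0] : chunk_range[-1] + 1. Sketch:
#     chunk_range = range(2, 5)
#     tokens = torch.arange(8)
#     edge_lists = [f'graph_{i}' for i in range(8)]
#     tokens[chunk_range]                               # tensor([2, 3, 4]) -- fancy indexing
#     edge_lists[chunk_range[0]: chunk_range[-1] + 1]   # ['graph_2', 'graph_3', 'graph_4']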
class DenseRunner(Runner):
|
||||
def __init__(self, config):
|
||||
super(DenseRunner, self).__init__(config)
|
||||
self.model = VDGR(
|
||||
self.config['model_config'], self.config['device'],
|
||||
use_apex=self.config['dp_type'] == 'apex',
|
||||
cache_dir=self.config['bert_cache_dir'])
|
||||
|
||||
if not(self.config['parallel'] and self.config['dp_type'] == 'dp'):
|
||||
self.model.to(self.config['device'])
|
||||
|
||||
if self.config['dense_loss'] == 'ce':
|
||||
self.dense_loss = nn.KLDivLoss(reduction='batchmean')
|
||||
elif self.config['dense_loss'] == 'listmle':
|
||||
self.dense_loss = listMLE
|
||||
elif self.config['dense_loss'] == 'listnet':
|
||||
self.dense_loss = listNet
|
||||
elif self.config['dense_loss'] == 'approxndcg':
|
||||
self.dense_loss = approxNDCGLoss
|
||||
elif self.config['dense_loss'] == 'neural_ndcg':
|
||||
self.dense_loss = neuralNDCG
|
||||
elif self.config['dense_loss'] == 'neural_ndcg_transposed':
|
||||
self.dense_loss = neuralNDCG_transposed
|
||||
else:
|
||||
raise ValueError('dense_loss must be one of ce, listmle, listnet, approxndcg, neural_ndcg, neural_ndcg_transposed')
|
||||
|
||||
if not self.config['validating'] or self.config['dp_type'] == 'apex':
|
||||
self.optimizer, self.scheduler = init_optim(self.model, self.config)
|
||||
|
||||
def forward(self, batch, eval_visdial=False):
|
||||
# load data
|
||||
for key in batch:
|
||||
if isinstance(batch[key], torch.Tensor):
|
||||
batch[key] = batch[key].to(self.config['device'])
|
||||
elif isinstance(batch[key], list):
|
||||
if key != 'dialog_info': # Do not send the dialog_info item to the gpu
|
||||
batch[key] = [x.to(self.config['device']) for x in batch[key]]
|
||||
|
||||
# get embedding and forward visdial
|
||||
tokens = batch['tokens']
|
||||
segments = batch['segments']
|
||||
sep_indices = batch['sep_indices']
|
||||
mask = batch['mask']
|
||||
hist_len = batch['hist_len']
|
||||
image_feat = batch['image_feat']
|
||||
image_loc = batch['image_loc']
|
||||
image_mask = batch['image_mask']
|
||||
next_sentence_labels = batch.get('next_sentence_labels', None)
|
||||
image_target = batch.get('image_target', None)
|
||||
image_label = batch.get('image_label', None)
|
||||
|
||||
# load the graph data
|
||||
image_edge_indices = batch['image_edge_indices']
|
||||
image_edge_attributes = batch['image_edge_attributes']
|
||||
question_edge_indices = batch['question_edge_indices']
|
||||
question_edge_attributes = batch['question_edge_attributes']
|
||||
question_limits = batch['question_limits']
|
||||
history_edge_indices = batch['history_edge_indices']
|
||||
assert history_edge_indices[0].size(0) == 2
|
||||
history_sep_indices = batch['history_sep_indices']
|
||||
|
||||
sequence_lengths = torch.gather(sep_indices, 1, hist_len.view(-1, 1)) + 1
|
||||
sequence_lengths = sequence_lengths.squeeze(1)
|
||||
attention_mask_lm_nsp = sequence_mask(sequence_lengths, max_len=tokens.shape[1])
|
||||
sep_len = hist_len + 1
|
||||
|
||||
losses = OrderedDict()
|
||||
|
||||
if eval_visdial:
|
||||
num_lines = tokens.size(0)
|
||||
line_batch_size = self.config['eval_line_batch_size']
|
||||
num_line_batches = num_lines // line_batch_size
|
||||
if num_lines % line_batch_size > 0:
|
||||
num_line_batches += 1
|
||||
nsp_scores = []
|
||||
for j in range(num_line_batches):
|
||||
# create chunks of the original batch
|
||||
chunk_range = range(j*line_batch_size, min((j+1)*line_batch_size, num_lines))
|
||||
tokens_chunk = tokens[chunk_range]
|
||||
segments_chunk = segments[chunk_range]
|
||||
sep_indices_chunk = sep_indices[chunk_range]
|
||||
mask_chunk = mask[chunk_range]
|
||||
sep_len_chunk = sep_len[chunk_range]
|
||||
attention_mask_lm_nsp_chunk = attention_mask_lm_nsp[chunk_range]
|
||||
image_feat_chunk = image_feat[chunk_range]
|
||||
image_loc_chunk = image_loc[chunk_range]
|
||||
image_mask_chunk = image_mask[chunk_range]
|
||||
image_edge_indices_chunk = image_edge_indices[chunk_range[0]: chunk_range[-1]+1]
|
||||
image_edge_attributes_chunk = image_edge_attributes[chunk_range[0]: chunk_range[-1]+1]
|
||||
question_edge_indices_chunk = question_edge_indices[chunk_range[0]: chunk_range[-1]+1]
|
||||
question_edge_attributes_chunk = question_edge_attributes[chunk_range[0]: chunk_range[-1]+1]
|
||||
question_limits_chunk = question_limits[chunk_range[0]: chunk_range[-1]+1]
|
||||
history_edge_indices_chunk = history_edge_indices[chunk_range[0]: chunk_range[-1]+1]
|
||||
history_sep_indices_chunk = history_sep_indices[chunk_range[0]: chunk_range[-1]+1]
|
||||
|
||||
_, _, _, nsp_scores_chunk = \
|
||||
self.model(
|
||||
tokens_chunk,
|
||||
image_feat_chunk,
|
||||
image_loc_chunk,
|
||||
image_edge_indices_chunk,
|
||||
image_edge_attributes_chunk,
|
||||
question_edge_indices_chunk,
|
||||
question_edge_attributes_chunk,
|
||||
question_limits_chunk,
|
||||
history_edge_indices_chunk,
|
||||
history_sep_indices_chunk,
|
||||
sep_indices=sep_indices_chunk,
|
||||
sep_len=sep_len_chunk,
|
||||
token_type_ids=segments_chunk,
|
||||
masked_lm_labels=mask_chunk,
|
||||
attention_mask=attention_mask_lm_nsp_chunk,
|
||||
image_attention_mask=image_mask_chunk
|
||||
)
|
||||
nsp_scores.append(nsp_scores_chunk)
|
||||
nsp_scores = torch.cat(nsp_scores, 0)
|
||||
|
||||
else:
|
||||
_, _, _, nsp_scores = \
|
||||
self.model(
|
||||
tokens,
|
||||
image_feat,
|
||||
image_loc,
|
||||
image_edge_indices,
|
||||
image_edge_attributes,
|
||||
question_edge_indices,
|
||||
question_edge_attributes,
|
||||
question_limits,
|
||||
history_edge_indices,
|
||||
history_sep_indices,
|
||||
next_sentence_label=next_sentence_labels,
|
||||
image_target=image_target,
|
||||
image_label=image_label,
|
||||
sep_indices=sep_indices,
|
||||
sep_len=sep_len,
|
||||
token_type_ids=segments,
|
||||
masked_lm_labels=mask,
|
||||
attention_mask=attention_mask_lm_nsp,
|
||||
image_attention_mask=image_mask
|
||||
)
|
||||
|
||||
|
||||
if nsp_scores is not None:
|
||||
nsp_scores_output = nsp_scores.detach().clone()
|
||||
if not eval_visdial:
|
||||
nsp_scores = nsp_scores.view(-1, self.config['num_options_dense'], 2)
|
||||
if 'next_sentence_labels' in batch and self.config['nsp_loss_coeff'] > 0:
|
||||
next_sentence_labels = batch['next_sentence_labels'].to(self.config['device'])
|
||||
losses['nsp_loss'] = F.cross_entropy(nsp_scores.view(-1,2), next_sentence_labels.view(-1))
|
||||
else:
|
||||
losses['nsp_loss'] = None
|
||||
|
||||
if not eval_visdial:
|
||||
gt_relevance = batch['gt_relevance'].to(self.config['device'])
|
||||
nsp_scores = nsp_scores[:, :, 0]
|
||||
if self.config['dense_loss'] == 'ce':
|
||||
losses['dense_loss'] = self.dense_loss(F.log_softmax(nsp_scores, dim=1), F.softmax(gt_relevance, dim=1))
|
||||
else:
|
||||
losses['dense_loss'] = self.dense_loss(nsp_scores, gt_relevance)
|
||||
else:
|
||||
losses['dense_loss'] = None
|
||||
else:
|
||||
nsp_scores_output = None
|
||||
losses['nsp_loss'] = None
|
||||
losses['dense_loss'] = None
|
||||
|
||||
losses['tot_loss'] = 0
|
||||
for key in ['nsp_loss', 'dense_loss']:
|
||||
if key in losses and losses[key] is not None:
|
||||
losses[key] = losses[key].mean()
|
||||
losses['tot_loss'] += self.config[f'{key}_coeff'] * losses[key]
|
||||
|
||||
output = {
|
||||
'losses': losses,
|
||||
'nsp_scores': nsp_scores_output
|
||||
}
|
||||
|
||||
return output
|
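With dense_loss == 'ce', DenseRunner above matches the predicted answer distribution against the normalized dense relevance annotations using a KL divergence. A minimal sketch that mirrors the 'ce' branch of DenseRunner.forward, with made-up scores and 4 answer options instead of num_options_dense:

import torch
import torch.nn.functional as F
from torch import nn

dense_loss = nn.KLDivLoss(reduction='batchmean')
nsp_scores = torch.tensor([[2.0, 0.5, -1.0, 0.0]])    # positive-class scores, [batch, num_options]
gt_relevance = torch.tensor([[1.0, 0.5, 0.0, 0.0]])   # dense relevance annotations
loss = dense_loss(F.log_softmax(nsp_scores, dim=1), F.softmax(gt_relevance, dim=1))
print(loss)  # scalar KL divergence between the two distributions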
2021
models/vilbert_dialog.py
Normal file
File diff suppressed because it is too large