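"""VDGR model wrapper and its training runners.

Defines the VDGR multi-modal model plus two `Runner` subclasses:
`SparseRunner` for the sparse pre-training objectives and `DenseRunner`
for fine-tuning on the dense relevance annotations.
"""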
import sys
from collections import OrderedDict

import torch
from torch import nn
import torch.nn.functional as F

sys.path.append('../')
from utils.model_utils import listMLE, approxNDCGLoss, listNet, neuralNDCG, neuralNDCG_transposed
from utils.data_utils import sequence_mask
from utils.optim_utils import init_optim
from models.runner import Runner
from models.vilbert_dialog import BertForMultiModalPreTraining, BertConfig


class VDGR(nn.Module):
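    """Thin wrapper around the multi-modal BERT backbone.

    When all labels (`masked_lm_labels`, `next_sentence_label`,
    `image_target`) are given, `forward` returns the masked-LM,
    masked-image, and NSP losses; otherwise it runs in inference
    mode and only the NSP (answer-ranking) scores are returned.
    """
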
    def __init__(self, config_path, device, use_apex=False, cache_dir=None):
        super(VDGR, self).__init__()
        config = BertConfig.from_json_file(config_path)

        self.bert_pretrained = BertForMultiModalPreTraining.from_pretrained(
            'bert-base-uncased', config, device, use_apex=use_apex, cache_dir=cache_dir)
        self.bert_pretrained.train()

    def forward(self, input_ids, image_feat, image_loc, image_edge_indices, image_edge_attributes,
                question_edge_indices, question_edge_attributes, question_limits,
                history_edge_indices, history_sep_indices,
                sep_indices=None, sep_len=None, token_type_ids=None,
                attention_mask=None, masked_lm_labels=None, next_sentence_label=None,
                image_attention_mask=None, image_label=None, image_target=None):

        masked_lm_loss = None
        masked_img_loss = None
        nsp_loss = None
        seq_relationship_score = None

        if next_sentence_label is not None and masked_lm_labels is not None \
                and image_target is not None:
            # train mode, output losses
            masked_lm_loss, masked_img_loss, nsp_loss, _, _, seq_relationship_score, _ = \
                self.bert_pretrained(
                    input_ids, image_feat, image_loc, image_edge_indices, image_edge_attributes,
                    question_edge_indices, question_edge_attributes, question_limits,
                    history_edge_indices, history_sep_indices,
                    sep_indices=sep_indices, sep_len=sep_len,
                    token_type_ids=token_type_ids, attention_mask=attention_mask,
                    masked_lm_labels=masked_lm_labels, next_sentence_label=next_sentence_label,
                    image_attention_mask=image_attention_mask,
                    image_label=image_label, image_target=image_target)
        else:
            # inference, output scores
            _, _, seq_relationship_score, _, _, _ = \
                self.bert_pretrained(
                    input_ids, image_feat, image_loc, image_edge_indices, image_edge_attributes,
                    question_edge_indices, question_edge_attributes, question_limits,
                    history_edge_indices, history_sep_indices,
                    sep_indices=sep_indices, sep_len=sep_len,
                    token_type_ids=token_type_ids, attention_mask=attention_mask,
                    masked_lm_labels=masked_lm_labels, next_sentence_label=next_sentence_label,
                    image_attention_mask=image_attention_mask,
                    image_label=image_label, image_target=image_target)

        out = (masked_lm_loss, masked_img_loss, nsp_loss, seq_relationship_score)

        return out


class SparseRunner(Runner):
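    """Runner for training with the sparse objectives.

    Jointly optimizes the masked-LM, masked-image, and NSP losses;
    their weights come from the `*_loss_coeff` config entries.
    """
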
    def __init__(self, config):
        super(SparseRunner, self).__init__(config)
        self.model = VDGR(
            self.config['model_config'], self.config['device'],
            use_apex=self.config['dp_type'] == 'apex',
            cache_dir=self.config['bert_cache_dir'])

        self.model.to(self.config['device'])

        if not self.config['validating'] or self.config['dp_type'] == 'apex':
            self.optimizer, self.scheduler = init_optim(self.model, self.config)

    def forward(self, batch, eval_visdial=False):
        # load data
        for key in batch:
            if isinstance(batch[key], torch.Tensor):
                batch[key] = batch[key].to(self.config['device'])
            elif isinstance(batch[key], list):
                if key != 'dialog_info':  # do not send the dialog_info item to the gpu
                    batch[key] = [x.to(self.config['device']) for x in batch[key]]

        tokens = batch['tokens']
        segments = batch['segments']
        sep_indices = batch['sep_indices']
        mask = batch['mask']
        hist_len = batch['hist_len']
        image_feat = batch['image_feat']
        image_loc = batch['image_loc']
        image_mask = batch['image_mask']
        next_sentence_labels = batch.get('next_sentence_labels', None)
        image_target = batch.get('image_target', None)
        image_label = batch.get('image_label', None)

        # load the graph data
        image_edge_indices = batch['image_edge_indices']
        image_edge_attributes = batch['image_edge_attributes']
        question_edge_indices = batch['question_edge_indices']
        question_edge_attributes = batch['question_edge_attributes']
        question_limits = batch['question_limits']
        history_edge_indices = batch['history_edge_indices']
        history_sep_indices = batch['history_sep_indices']
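
        # The position of the final [SEP] token (gathered at index hist_len)
        # determines the true sequence length; the LM/NSP attention mask covers
        # exactly those tokens and leaves the padding out.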
        sequence_lengths = torch.gather(sep_indices, 1, hist_len.view(-1, 1)) + 1
        sequence_lengths = sequence_lengths.squeeze(1)
        attention_mask_lm_nsp = sequence_mask(sequence_lengths, max_len=tokens.shape[1])
        sep_len = hist_len + 1

        losses = OrderedDict()
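
        # At evaluation time every dialog round is expanded to all of its
        # candidate answers, so the flattened batch is processed in chunks of
        # 'eval_line_batch_size' rows and the NSP scores are re-assembled below.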
        if eval_visdial:
            num_lines = tokens.size(0)
            line_batch_size = self.config['eval_line_batch_size']
            num_line_batches = num_lines // line_batch_size
            if num_lines % line_batch_size > 0:
                num_line_batches += 1
            nsp_scores = []
            for j in range(num_line_batches):
                # create chunks of the original batch
                chunk_range = range(j * line_batch_size, min((j + 1) * line_batch_size, num_lines))
                tokens_chunk = tokens[chunk_range]
                segments_chunk = segments[chunk_range]
                sep_indices_chunk = sep_indices[chunk_range]
                mask_chunk = mask[chunk_range]
                sep_len_chunk = sep_len[chunk_range]
                attention_mask_lm_nsp_chunk = attention_mask_lm_nsp[chunk_range]
                image_feat_chunk = image_feat[chunk_range]
                image_loc_chunk = image_loc[chunk_range]
                image_mask_chunk = image_mask[chunk_range]
                image_edge_indices_chunk = image_edge_indices[chunk_range[0]:chunk_range[-1] + 1]
                image_edge_attributes_chunk = image_edge_attributes[chunk_range[0]:chunk_range[-1] + 1]
                question_edge_indices_chunk = question_edge_indices[chunk_range[0]:chunk_range[-1] + 1]
                question_edge_attributes_chunk = question_edge_attributes[chunk_range[0]:chunk_range[-1] + 1]
                question_limits_chunk = question_limits[chunk_range[0]:chunk_range[-1] + 1]
                history_edge_indices_chunk = history_edge_indices[chunk_range[0]:chunk_range[-1] + 1]
                history_sep_indices_chunk = history_sep_indices[chunk_range[0]:chunk_range[-1] + 1]

                _, _, _, nsp_scores_chunk = \
                    self.model(
                        tokens_chunk,
                        image_feat_chunk,
                        image_loc_chunk,
                        image_edge_indices_chunk,
                        image_edge_attributes_chunk,
                        question_edge_indices_chunk,
                        question_edge_attributes_chunk,
                        question_limits_chunk,
                        history_edge_indices_chunk,
                        history_sep_indices_chunk,
                        sep_indices=sep_indices_chunk,
                        sep_len=sep_len_chunk,
                        token_type_ids=segments_chunk,
                        masked_lm_labels=mask_chunk,
                        attention_mask=attention_mask_lm_nsp_chunk,
                        image_attention_mask=image_mask_chunk
                    )
                nsp_scores.append(nsp_scores_chunk)
            nsp_scores = torch.cat(nsp_scores, 0)
        else:
            losses['lm_loss'], losses['img_loss'], losses['nsp_loss'], nsp_scores = \
                self.model(
                    tokens,
                    image_feat,
                    image_loc,
                    image_edge_indices,
                    image_edge_attributes,
                    question_edge_indices,
                    question_edge_attributes,
                    question_limits,
                    history_edge_indices,
                    history_sep_indices,
                    next_sentence_label=next_sentence_labels,
                    image_target=image_target,
                    image_label=image_label,
                    sep_indices=sep_indices,
                    sep_len=sep_len,
                    token_type_ids=segments,
                    masked_lm_labels=mask,
                    attention_mask=attention_mask_lm_nsp,
                    image_attention_mask=image_mask
                )
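
        # total loss is the coefficient-weighted sum of the individual losses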
        losses['tot_loss'] = 0
        for key in ['lm_loss', 'img_loss', 'nsp_loss']:
            if key in losses and losses[key] is not None:
                losses[key] = losses[key].mean()
                losses['tot_loss'] += self.config[f'{key}_coeff'] * losses[key]

        output = {
            'losses': losses,
            'nsp_scores': nsp_scores
        }
        return output


class DenseRunner(Runner):
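    """Runner for fine-tuning on the dense relevance annotations.

    Ranks all candidate answers of a round against `gt_relevance`
    with the listwise loss selected by `config['dense_loss']`.
    """
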
    def __init__(self, config):
        super(DenseRunner, self).__init__(config)
        self.model = VDGR(
            self.config['model_config'], self.config['device'],
            use_apex=self.config['dp_type'] == 'apex',
            cache_dir=self.config['bert_cache_dir'])

        if not (self.config['parallel'] and self.config['dp_type'] == 'dp'):
            self.model.to(self.config['device'])
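
        # 'ce' is cross-entropy with soft targets: KL divergence between the
        # log-softmax of the scores and the softmax of the relevance labels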
        if self.config['dense_loss'] == 'ce':
            self.dense_loss = nn.KLDivLoss(reduction='batchmean')
        elif self.config['dense_loss'] == 'listmle':
            self.dense_loss = listMLE
        elif self.config['dense_loss'] == 'listnet':
            self.dense_loss = listNet
        elif self.config['dense_loss'] == 'approxndcg':
            self.dense_loss = approxNDCGLoss
        elif self.config['dense_loss'] == 'neural_ndcg':
            self.dense_loss = neuralNDCG
        elif self.config['dense_loss'] == 'neural_ndcg_transposed':
            self.dense_loss = neuralNDCG_transposed
        else:
            raise ValueError(
                'dense_loss must be one of ce, listmle, listnet, approxndcg, '
                'neural_ndcg, neural_ndcg_transposed')

        if not self.config['validating'] or self.config['dp_type'] == 'apex':
            self.optimizer, self.scheduler = init_optim(self.model, self.config)

    def forward(self, batch, eval_visdial=False):
        # load data
        for key in batch:
            if isinstance(batch[key], torch.Tensor):
                batch[key] = batch[key].to(self.config['device'])
            elif isinstance(batch[key], list):
                if key != 'dialog_info':  # do not send the dialog_info item to the gpu
                    batch[key] = [x.to(self.config['device']) for x in batch[key]]

        # get embedding and forward visdial
        tokens = batch['tokens']
        segments = batch['segments']
        sep_indices = batch['sep_indices']
        mask = batch['mask']
        hist_len = batch['hist_len']
        image_feat = batch['image_feat']
        image_loc = batch['image_loc']
        image_mask = batch['image_mask']
        next_sentence_labels = batch.get('next_sentence_labels', None)
        image_target = batch.get('image_target', None)
        image_label = batch.get('image_label', None)

        # load the graph data
        image_edge_indices = batch['image_edge_indices']
        image_edge_attributes = batch['image_edge_attributes']
        question_edge_indices = batch['question_edge_indices']
        question_edge_attributes = batch['question_edge_attributes']
        question_limits = batch['question_limits']
        history_edge_indices = batch['history_edge_indices']
        # sanity check: edge indices come as [2, num_edges] tensors
        assert history_edge_indices[0].size(0) == 2
        history_sep_indices = batch['history_sep_indices']
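
        # As in SparseRunner.forward: the final [SEP] (gathered at index
        # hist_len) marks the true sequence length for the attention mask.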
        sequence_lengths = torch.gather(sep_indices, 1, hist_len.view(-1, 1)) + 1
        sequence_lengths = sequence_lengths.squeeze(1)
        attention_mask_lm_nsp = sequence_mask(sequence_lengths, max_len=tokens.shape[1])
        sep_len = hist_len + 1

        losses = OrderedDict()
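
        # Same chunked evaluation scheme as in SparseRunner.forward: split the
        # expanded candidate batch into 'eval_line_batch_size' rows at a time.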
        if eval_visdial:
            num_lines = tokens.size(0)
            line_batch_size = self.config['eval_line_batch_size']
            num_line_batches = num_lines // line_batch_size
            if num_lines % line_batch_size > 0:
                num_line_batches += 1
            nsp_scores = []
            for j in range(num_line_batches):
                # create chunks of the original batch
                chunk_range = range(j * line_batch_size, min((j + 1) * line_batch_size, num_lines))
                tokens_chunk = tokens[chunk_range]
                segments_chunk = segments[chunk_range]
                sep_indices_chunk = sep_indices[chunk_range]
                mask_chunk = mask[chunk_range]
                sep_len_chunk = sep_len[chunk_range]
                attention_mask_lm_nsp_chunk = attention_mask_lm_nsp[chunk_range]
                image_feat_chunk = image_feat[chunk_range]
                image_loc_chunk = image_loc[chunk_range]
                image_mask_chunk = image_mask[chunk_range]
                image_edge_indices_chunk = image_edge_indices[chunk_range[0]:chunk_range[-1] + 1]
                image_edge_attributes_chunk = image_edge_attributes[chunk_range[0]:chunk_range[-1] + 1]
                question_edge_indices_chunk = question_edge_indices[chunk_range[0]:chunk_range[-1] + 1]
                question_edge_attributes_chunk = question_edge_attributes[chunk_range[0]:chunk_range[-1] + 1]
                question_limits_chunk = question_limits[chunk_range[0]:chunk_range[-1] + 1]
                history_edge_indices_chunk = history_edge_indices[chunk_range[0]:chunk_range[-1] + 1]
                history_sep_indices_chunk = history_sep_indices[chunk_range[0]:chunk_range[-1] + 1]

                _, _, _, nsp_scores_chunk = \
                    self.model(
                        tokens_chunk,
                        image_feat_chunk,
                        image_loc_chunk,
                        image_edge_indices_chunk,
                        image_edge_attributes_chunk,
                        question_edge_indices_chunk,
                        question_edge_attributes_chunk,
                        question_limits_chunk,
                        history_edge_indices_chunk,
                        history_sep_indices_chunk,
                        sep_indices=sep_indices_chunk,
                        sep_len=sep_len_chunk,
                        token_type_ids=segments_chunk,
                        masked_lm_labels=mask_chunk,
                        attention_mask=attention_mask_lm_nsp_chunk,
                        image_attention_mask=image_mask_chunk
                    )
                nsp_scores.append(nsp_scores_chunk)
            nsp_scores = torch.cat(nsp_scores, 0)
        else:
            _, _, _, nsp_scores = \
                self.model(
                    tokens,
                    image_feat,
                    image_loc,
                    image_edge_indices,
                    image_edge_attributes,
                    question_edge_indices,
                    question_edge_attributes,
                    question_limits,
                    history_edge_indices,
                    history_sep_indices,
                    next_sentence_label=next_sentence_labels,
                    image_target=image_target,
                    image_label=image_label,
                    sep_indices=sep_indices,
                    sep_len=sep_len,
                    token_type_ids=segments,
                    masked_lm_labels=mask,
                    attention_mask=attention_mask_lm_nsp,
                    image_attention_mask=image_mask
                )
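
        # Keep a detached copy of the raw scores for the evaluation output.
        # For dense training the scores are reshaped to
        # (num_rounds, num_options_dense, 2) and column 0 is used as the
        # per-candidate relevance score ranked against gt_relevance.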
        if nsp_scores is not None:
            nsp_scores_output = nsp_scores.detach().clone()
            if not eval_visdial:
                nsp_scores = nsp_scores.view(-1, self.config['num_options_dense'], 2)
            if 'next_sentence_labels' in batch and self.config['nsp_loss_coeff'] > 0:
                next_sentence_labels = batch['next_sentence_labels'].to(self.config['device'])
                losses['nsp_loss'] = F.cross_entropy(nsp_scores.view(-1, 2), next_sentence_labels.view(-1))
            else:
                losses['nsp_loss'] = None

            if not eval_visdial:
                gt_relevance = batch['gt_relevance'].to(self.config['device'])
                nsp_scores = nsp_scores[:, :, 0]
                if self.config['dense_loss'] == 'ce':
                    losses['dense_loss'] = self.dense_loss(
                        F.log_softmax(nsp_scores, dim=1), F.softmax(gt_relevance, dim=1))
                else:
                    losses['dense_loss'] = self.dense_loss(nsp_scores, gt_relevance)
            else:
                losses['dense_loss'] = None
        else:
            nsp_scores_output = None
            losses['nsp_loss'] = None
            losses['dense_loss'] = None

        losses['tot_loss'] = 0
        for key in ['nsp_loss', 'dense_loss']:
            if key in losses and losses[key] is not None:
                losses[key] = losses[key].mean()
                losses['tot_loss'] += self.config[f'{key}_coeff'] * losses[key]

        output = {
            'losses': losses,
            'nsp_scores': nsp_scores_output
        }

        return output
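

# Illustrative usage (hypothetical; the actual config keys and batch layout
# come from the project's training scripts and data loaders):
#
#     runner = SparseRunner(config)              # config: dict of options
#     out = runner.forward(batch)                # batch: dict from the loader
#     out['losses']['tot_loss'].backward()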