VDGR/models/vdgr.py

380 lines
17 KiB
Python
Raw Permalink Normal View History

2023-10-25 15:38:09 +02:00
import sys
from collections import OrderedDict
import torch
from torch import nn
import torch.nn.functional as F
sys.path.append('../')
from utils.model_utils import listMLE, approxNDCGLoss, listNet, neuralNDCG, neuralNDCG_transposed
from utils.data_utils import sequence_mask
from utils.optim_utils import init_optim
from models.runner import Runner
from models.vilbert_dialog import BertForMultiModalPreTraining, BertConfig
class VDGR(nn.Module):
def __init__(self, config_path, device, use_apex=False, cache_dir=None):
super(VDGR, self).__init__()
config = BertConfig.from_json_file(config_path)
self.bert_pretrained = BertForMultiModalPreTraining.from_pretrained('bert-base-uncased', config, device, use_apex=use_apex, cache_dir=cache_dir)
self.bert_pretrained.train()
def forward(self, input_ids, image_feat, image_loc, image_edge_indices, image_edge_attributes,
question_edge_indices, question_edge_attributes, question_limits,
history_edge_indices, history_sep_indices,
sep_indices=None, sep_len=None, token_type_ids=None,
attention_mask=None, masked_lm_labels=None, next_sentence_label=None,
image_attention_mask=None, image_label=None, image_target=None):
masked_lm_loss = None
masked_img_loss = None
nsp_loss = None
seq_relationship_score = None
if next_sentence_label is not None and masked_lm_labels \
is not None and image_target is not None:
# train mode, output losses
masked_lm_loss, masked_img_loss, nsp_loss, _, _, seq_relationship_score, _ = \
self.bert_pretrained(
input_ids, image_feat, image_loc, image_edge_indices, image_edge_attributes,
question_edge_indices, question_edge_attributes, question_limits,
history_edge_indices, history_sep_indices, sep_indices=sep_indices, sep_len=sep_len, \
token_type_ids=token_type_ids, attention_mask=attention_mask, masked_lm_labels=masked_lm_labels, \
next_sentence_label=next_sentence_label, image_attention_mask=image_attention_mask,\
image_label=image_label, image_target=image_target)
else:
#inference, output scores
_, _, seq_relationship_score, _, _, _ = \
self.bert_pretrained(
input_ids, image_feat, image_loc, image_edge_indices, image_edge_attributes,
question_edge_indices, question_edge_attributes, question_limits,
history_edge_indices, history_sep_indices,
sep_indices=sep_indices, sep_len=sep_len, \
token_type_ids=token_type_ids, attention_mask=attention_mask, masked_lm_labels=masked_lm_labels, \
next_sentence_label=next_sentence_label, image_attention_mask=image_attention_mask,\
image_label=image_label, image_target=image_target)
out = (masked_lm_loss, masked_img_loss, nsp_loss, seq_relationship_score)
return out
class SparseRunner(Runner):
def __init__(self, config):
super(SparseRunner, self).__init__(config)
self.model = VDGR(
self.config['model_config'], self.config['device'],
use_apex=self.config['dp_type'] == 'apex',
cache_dir=self.config['bert_cache_dir'])
self.model.to(self.config['device'])
if not self.config['validating'] or self.config['dp_type'] == 'apex':
self.optimizer, self.scheduler = init_optim(self.model, self.config)
def forward(self, batch, eval_visdial=False):
# load data
for key in batch:
if isinstance(batch[key], torch.Tensor):
batch[key] = batch[key].to(self.config['device'])
elif isinstance(batch[key], list):
if key != 'dialog_info': # Do not send the dialog_info item to the gpu
batch[key] = [x.to(self.config['device']) for x in batch[key]]
tokens = batch['tokens']
segments = batch['segments']
sep_indices = batch['sep_indices']
mask = batch['mask']
hist_len = batch['hist_len']
image_feat = batch['image_feat']
image_loc = batch['image_loc']
image_mask = batch['image_mask']
next_sentence_labels = batch.get('next_sentence_labels', None)
image_target = batch.get('image_target', None)
image_label = batch.get('image_label', None)
# load the graph data
image_edge_indices = batch['image_edge_indices']
image_edge_attributes = batch['image_edge_attributes']
question_edge_indices = batch['question_edge_indices']
question_edge_attributes = batch['question_edge_attributes']
question_limits = batch['question_limits']
history_edge_indices = batch['history_edge_indices']
history_sep_indices = batch['history_sep_indices']
sequence_lengths = torch.gather(sep_indices, 1, hist_len.view(-1, 1)) + 1
sequence_lengths = sequence_lengths.squeeze(1)
attention_mask_lm_nsp = sequence_mask(sequence_lengths, max_len=tokens.shape[1])
sep_len = hist_len + 1
losses = OrderedDict()
if eval_visdial:
num_lines = tokens.size(0)
line_batch_size = self.config['eval_line_batch_size']
num_line_batches = num_lines // line_batch_size
if num_lines % line_batch_size > 0:
num_line_batches += 1
nsp_scores = []
for j in range(num_line_batches):
# create chunks of the original batch
chunk_range = range(j*line_batch_size, min((j+1)*line_batch_size, num_lines))
tokens_chunk = tokens[chunk_range]
segments_chunk = segments[chunk_range]
sep_indices_chunk = sep_indices[chunk_range]
mask_chunk = mask[chunk_range]
sep_len_chunk = sep_len[chunk_range]
attention_mask_lm_nsp_chunk = attention_mask_lm_nsp[chunk_range]
image_feat_chunk = image_feat[chunk_range]
image_loc_chunk = image_loc[chunk_range]
image_mask_chunk = image_mask[chunk_range]
image_edge_indices_chunk = image_edge_indices[chunk_range[0]: chunk_range[-1]+1]
image_edge_attributes_chunk = image_edge_attributes[chunk_range[0]: chunk_range[-1]+1]
question_edge_indices_chunk = question_edge_indices[chunk_range[0]: chunk_range[-1]+1]
question_edge_attributes_chunk = question_edge_attributes[chunk_range[0]: chunk_range[-1]+1]
question_limits_chunk = question_limits[chunk_range[0]: chunk_range[-1]+1]
history_edge_indices_chunk = history_edge_indices[chunk_range[0]: chunk_range[-1]+1]
history_sep_indices_chunk = history_sep_indices[chunk_range[0]: chunk_range[-1]+1]
_ , _ , _, nsp_scores_chunk = \
self.model(
tokens_chunk,
image_feat_chunk,
image_loc_chunk,
image_edge_indices_chunk,
image_edge_attributes_chunk,
question_edge_indices_chunk,
question_edge_attributes_chunk,
question_limits_chunk,
history_edge_indices_chunk,
history_sep_indices_chunk,
sep_indices=sep_indices_chunk,
sep_len=sep_len_chunk,
token_type_ids=segments_chunk,
masked_lm_labels=mask_chunk,
attention_mask=attention_mask_lm_nsp_chunk,
image_attention_mask=image_mask_chunk
)
nsp_scores.append(nsp_scores_chunk)
nsp_scores = torch.cat(nsp_scores, 0)
else:
losses['lm_loss'], losses['img_loss'], losses['nsp_loss'], nsp_scores = \
self.model(
tokens,
image_feat,
image_loc,
image_edge_indices,
image_edge_attributes,
question_edge_indices,
question_edge_attributes,
question_limits,
history_edge_indices,
history_sep_indices,
next_sentence_label=next_sentence_labels,
image_target=image_target,
image_label=image_label,
sep_indices=sep_indices,
sep_len=sep_len,
token_type_ids=segments,
masked_lm_labels=mask,
attention_mask=attention_mask_lm_nsp,
image_attention_mask=image_mask
)
losses['tot_loss'] = 0
for key in ['lm_loss', 'img_loss', 'nsp_loss']:
if key in losses and losses[key] is not None:
losses[key] = losses[key].mean()
losses['tot_loss'] += self.config[f'{key}_coeff'] * losses[key]
output = {
'losses': losses,
'nsp_scores': nsp_scores
}
return output
class DenseRunner(Runner):
def __init__(self, config):
super(DenseRunner, self).__init__(config)
self.model = VDGR(
self.config['model_config'], self.config['device'],
use_apex=self.config['dp_type'] == 'apex',
cache_dir=self.config['bert_cache_dir'])
if not(self.config['parallel'] and self.config['dp_type'] == 'dp'):
self.model.to(self.config['device'])
if self.config['dense_loss'] == 'ce':
self.dense_loss = nn.KLDivLoss(reduction='batchmean')
elif self.config['dense_loss'] == 'listmle':
self.dense_loss = listMLE
elif self.config['dense_loss'] == 'listnet':
self.dense_loss = listNet
elif self.config['dense_loss'] == 'approxndcg':
self.dense_loss = approxNDCGLoss
elif self.config['dense_loss'] == 'neural_ndcg':
self.dense_loss = neuralNDCG
elif self.config['dense_loss'] == 'neural_ndcg_transposed':
self.dense_loss = neuralNDCG_transposed
else:
raise ValueError('dense_loss must be one of ce, listmle, listnet, approxndcg, neural_ndcg, neural_ndcg_transposed')
if not self.config['validating'] or self.config['dp_type'] == 'apex':
self.optimizer, self.scheduler = init_optim(self.model, self.config)
def forward(self, batch, eval_visdial=False):
# load data
for key in batch:
if isinstance(batch[key], torch.Tensor):
batch[key] = batch[key].to(self.config['device'])
elif isinstance(batch[key], list):
if key != 'dialog_info': # Do not send the dialog_info item to the gpu
batch[key] = [x.to(self.config['device']) for x in batch[key]]
# get embedding and forward visdial
tokens = batch['tokens']
segments = batch['segments']
sep_indices = batch['sep_indices']
mask = batch['mask']
hist_len = batch['hist_len']
image_feat = batch['image_feat']
image_loc = batch['image_loc']
image_mask = batch['image_mask']
next_sentence_labels = batch.get('next_sentence_labels', None)
image_target = batch.get('image_target', None)
image_label = batch.get('image_label', None)
# load the graph data
image_edge_indices = batch['image_edge_indices']
image_edge_attributes = batch['image_edge_attributes']
question_edge_indices = batch['question_edge_indices']
question_edge_attributes = batch['question_edge_attributes']
question_limits = batch['question_limits']
history_edge_indices = batch['history_edge_indices']
assert history_edge_indices[0].size(0) == 2
history_sep_indices = batch['history_sep_indices']
sequence_lengths = torch.gather(sep_indices, 1, hist_len.view(-1, 1)) + 1
sequence_lengths = sequence_lengths.squeeze(1)
attention_mask_lm_nsp = sequence_mask(sequence_lengths, max_len=tokens.shape[1])
sep_len = hist_len + 1
losses = OrderedDict()
if eval_visdial:
num_lines = tokens.size(0)
line_batch_size = self.config['eval_line_batch_size']
num_line_batches = num_lines // line_batch_size
if num_lines % line_batch_size > 0:
num_line_batches += 1
nsp_scores = []
for j in range(num_line_batches):
# create chunks of the original batch
chunk_range = range(j*line_batch_size, min((j+1)*line_batch_size, num_lines))
tokens_chunk = tokens[chunk_range]
segments_chunk = segments[chunk_range]
sep_indices_chunk = sep_indices[chunk_range]
mask_chunk = mask[chunk_range]
sep_len_chunk = sep_len[chunk_range]
attention_mask_lm_nsp_chunk = attention_mask_lm_nsp[chunk_range]
image_feat_chunk = image_feat[chunk_range]
image_loc_chunk = image_loc[chunk_range]
image_mask_chunk = image_mask[chunk_range]
image_edge_indices_chunk = image_edge_indices[chunk_range[0]: chunk_range[-1]+1]
image_edge_attributes_chunk = image_edge_attributes[chunk_range[0]: chunk_range[-1]+1]
question_edge_indices_chunk = question_edge_indices[chunk_range[0]: chunk_range[-1]+1]
question_edge_attributes_chunk = question_edge_attributes[chunk_range[0]: chunk_range[-1]+1]
question_limits_chunk = question_limits[chunk_range[0]: chunk_range[-1]+1]
history_edge_indices_chunk = history_edge_indices[chunk_range[0]: chunk_range[-1]+1]
history_sep_indices_chunk = history_sep_indices[chunk_range[0]: chunk_range[-1]+1]
_, _, _, nsp_scores_chunk = \
self.model(
tokens_chunk,
image_feat_chunk,
image_loc_chunk,
image_edge_indices_chunk,
image_edge_attributes_chunk,
question_edge_indices_chunk,
question_edge_attributes_chunk,
question_limits_chunk,
history_edge_indices_chunk,
history_sep_indices_chunk,
sep_indices=sep_indices_chunk,
sep_len=sep_len_chunk,
token_type_ids=segments_chunk,
masked_lm_labels=mask_chunk,
attention_mask=attention_mask_lm_nsp_chunk,
image_attention_mask=image_mask_chunk
)
nsp_scores.append(nsp_scores_chunk)
nsp_scores = torch.cat(nsp_scores, 0)
else:
_, _, _, nsp_scores = \
self.model(
tokens,
image_feat,
image_loc,
image_edge_indices,
image_edge_attributes,
question_edge_indices,
question_edge_attributes,
question_limits,
history_edge_indices,
history_sep_indices,
next_sentence_label=next_sentence_labels,
image_target=image_target,
image_label=image_label,
sep_indices=sep_indices,
sep_len=sep_len,
token_type_ids=segments,
masked_lm_labels=mask,
attention_mask=attention_mask_lm_nsp,
image_attention_mask=image_mask
)
if nsp_scores is not None:
nsp_scores_output = nsp_scores.detach().clone()
if not eval_visdial:
nsp_scores = nsp_scores.view(-1, self.config['num_options_dense'], 2)
if 'next_sentence_labels' in batch and self.config['nsp_loss_coeff'] > 0:
next_sentence_labels = batch['next_sentence_labels'].to(self.config['device'])
losses['nsp_loss'] = F.cross_entropy(nsp_scores.view(-1,2), next_sentence_labels.view(-1))
else:
losses['nsp_loss'] = None
if not eval_visdial:
gt_relevance = batch['gt_relevance'].to(self.config['device'])
nsp_scores = nsp_scores[:, :, 0]
if self.config['dense_loss'] == 'ce':
losses['dense_loss'] = self.dense_loss(F.log_softmax(nsp_scores, dim=1), F.softmax(gt_relevance, dim=1))
else:
losses['dense_loss'] = self.dense_loss(nsp_scores, gt_relevance)
else:
losses['dense_loss'] = None
else:
nsp_scores_output = None
losses['nsp_loss'] = None
losses['dense_loss'] = None
losses['tot_loss'] = 0
for key in ['nsp_loss', 'dense_loss']:
if key in losses and losses[key] is not None:
losses[key] = losses[key].mean()
losses['tot_loss'] += self.config[f'{key}_coeff'] * losses[key]
output = {
'losses': losses,
'nsp_scores': nsp_scores_output
}
return output