Code release
commit 09fb25e339
29 changed files with 7162 additions and 0 deletions
0
dataloader/__init__.py
Normal file
269
dataloader/dataloader_base.py
Normal file
@@ -0,0 +1,269 @@
import torch
from torch.utils import data
import json
import os
import glog as log
import pickle

import torch.utils.data as tud
from pytorch_transformers.tokenization_bert import BertTokenizer

from utils.image_features_reader import ImageFeaturesH5Reader


class DatasetBase(data.Dataset):

    def __init__(self, config):

        if config['display']:
            log.info('Initializing dataset')

        # Fetch the correct dataset for evaluation
        if config['validating']:
            assert config.eval_dataset in ['visdial', 'visdial_conv', 'visdial_vispro', 'visdial_v09']
            if config.eval_dataset == 'visdial_conv':
                config['visdial_val'] = config.visdialconv_val
                config['visdial_val_dense_annotations'] = config.visdialconv_val_dense_annotations
            elif config.eval_dataset == 'visdial_vispro':
                config['visdial_val'] = config.visdialvispro_val
                config['visdial_val_dense_annotations'] = config.visdialvispro_val_dense_annotations
            elif config.eval_dataset == 'visdial_v09':
                config['visdial_val_09'] = config.visdial_test_09
                config['visdial_val_dense_annotations'] = None

        self.config = config
        self.numDataPoints = {}

        if not config['dataloader_text_only']:
            self._image_features_reader = ImageFeaturesH5Reader(
                config['visdial_image_feats'],
                config['visdial_image_adj_matrices']
            )

        if self.config['training'] or self.config['validating'] or self.config['predicting']:
            split2data = {'train': 'train', 'val': 'val', 'test': 'test'}
        elif self.config['debugging']:
            split2data = {'train': 'val', 'val': 'val', 'test': 'test'}
        elif self.config['visualizing']:
            split2data = {'train': 'train', 'val': 'train', 'test': 'test'}

        filename = f'visdial_{split2data["train"]}'
        if config['train_on_dense']:
            filename += '_dense'
        if self.config['visdial_version'] == 0.9:
            filename += '_09'

        with open(config[filename]) as f:
            self.visdial_data_train = json.load(f)
            if self.config.num_samples > 0:
                self.visdial_data_train['data']['dialogs'] = self.visdial_data_train['data']['dialogs'][:self.config.num_samples]
        self.numDataPoints['train'] = len(self.visdial_data_train['data']['dialogs'])

        filename = f'visdial_{split2data["val"]}'
        if config['train_on_dense'] and config['training']:
            filename += '_dense'
        if self.config['visdial_version'] == 0.9:
            filename += '_09'

        with open(config[filename]) as f:
            self.visdial_data_val = json.load(f)
            if self.config.num_samples > 0:
                self.visdial_data_val['data']['dialogs'] = self.visdial_data_val['data']['dialogs'][:self.config.num_samples]
        self.numDataPoints['val'] = len(self.visdial_data_val['data']['dialogs'])

        if config['train_on_dense']:
            self.numDataPoints['trainval'] = self.numDataPoints['train'] + self.numDataPoints['val']
        with open(config[f'visdial_{split2data["test"]}']) as f:
            self.visdial_data_test = json.load(f)
        self.numDataPoints['test'] = len(self.visdial_data_test['data']['dialogs'])

        self.rlv_hst_train = None
        self.rlv_hst_val = None
        self.rlv_hst_test = None

        if config['train_on_dense'] or config['predict_dense_round']:
            with open(config[f'visdial_{split2data["train"]}_dense_annotations']) as f:
                self.visdial_data_train_dense = json.load(f)
        if config['train_on_dense']:
            self.subsets = ['train', 'val', 'trainval', 'test']
        else:
            self.subsets = ['train', 'val', 'test']
        self.num_options = config["num_options"]
        self.num_options_dense = config["num_options_dense"]
        if config['visdial_version'] != 0.9:
            with open(config[f'visdial_{split2data["val"]}_dense_annotations']) as f:
                self.visdial_data_val_dense = json.load(f)
        else:
            self.visdial_data_val_dense = None
        self._split = 'train'
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', cache_dir=config['bert_cache_dir'])
        # fetch the token indices of [CLS], [MASK] and [SEP]
        tokens = ['[CLS]', '[MASK]', '[SEP]']
        indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokens)
        self.CLS = indexed_tokens[0]
        self.MASK = indexed_tokens[1]
        self.SEP = indexed_tokens[2]
        self._max_region_num = 37
        self.predict_each_round = self.config['predicting'] and self.config['predict_each_round']

        self.keys_to_expand = ['image_feat', 'image_loc', 'image_mask', 'image_target', 'image_label']
        self.keys_to_flatten_1d = ['hist_len', 'next_sentence_labels', 'round_id', 'image_id']
        self.keys_to_flatten_2d = ['tokens', 'segments', 'sep_indices', 'mask', 'image_mask', 'image_label', 'input_mask', 'question_limits']
        self.keys_to_flatten_3d = ['image_feat', 'image_loc', 'image_target']
        self.keys_other = ['gt_relevance', 'gt_option_inds']
        self.keys_lists_to_flatten = ['image_edge_indices', 'image_edge_attributes', 'question_edge_indices', 'question_edge_attributes', 'history_edge_indices', 'history_sep_indices']
        if config['stack_gr_data']:
            self.keys_to_flatten_3d.extend(self.keys_lists_to_flatten[:-1])
            self.keys_to_flatten_2d.append(self.keys_lists_to_flatten[-1])
            self.keys_to_flatten_1d.extend(['len_image_gr', 'len_question_gr', 'len_history_gr', 'len_history_sep'])
            self.keys_lists_to_flatten = []

        self.keys_to_list = ['tot_len']

        # Load the parse vocab for question graph relationship mapping
        if os.path.isfile(config['visdial_question_parse_vocab']):
            with open(config['visdial_question_parse_vocab'], 'rb') as f:
                self.parse_vocab = pickle.load(f)

    def __len__(self):
        return self.numDataPoints[self._split]

    @property
    def split(self):
        return self._split

    @split.setter
    def split(self, split):
        assert split in self.subsets
        self._split = split

    def tokens2str(self, seq):
        dialog_sequence = ''
        for sentence in seq:
            for word in sentence:
                dialog_sequence += self.tokenizer._convert_id_to_token(word) + " "
            dialog_sequence += ' </end> '
        dialog_sequence = dialog_sequence.encode('utf8')
        return dialog_sequence

    def pruneRounds(self, context, num_rounds):
        start_segment = 1
        len_context = len(context)
        cur_rounds = (len(context) // 2) + 1
        l_index = 0
        if cur_rounds > num_rounds:
            # caption is not part of the final input
            l_index = len_context - (2 * num_rounds)
            start_segment = 0
        return context[l_index:], start_segment

    def tokenize_utterance(self, sent, sentences, tot_len, sentence_count, sentence_map, speakers):
        sentences.extend(sent + ['[SEP]'])
        tokenized_sent = self.tokenizer.convert_tokens_to_ids(sent)
        assert len(sent) == len(tokenized_sent), 'sub-word tokens are not allowed!'

        sent_len = len(tokenized_sent)
        tot_len += sent_len + 1  # the additional 1 is for the sep token
        sentence_count += 1
        sentence_map.extend([sentence_count * 2 - 1] * sent_len)
        sentence_map.append(sentence_count * 2)  # for [SEP]
        speakers.extend([2] * (sent_len + 1))

        return tokenized_sent, sentences, tot_len, sentence_count, sentence_map, speakers

    def __getitem__(self, index):
        raise NotImplementedError

    def collate_fn(self, batch):
        tokens_size = batch[0]['tokens'].size()
        num_rounds, num_samples = tokens_size[0], tokens_size[1]
        merged_batch = {key: [d[key] for d in batch] for key in batch[0]}

        if self.config['stack_gr_data']:
            if len(batch) > 1:
                max_question_gr_len = max([length.max().item() for length in merged_batch['len_question_gr']])
                max_history_gr_len = max([length.max().item() for length in merged_batch['len_history_gr']])
                max_history_sep_len = max([length.max().item() for length in merged_batch['len_history_sep']])
                max_image_gr_len = max([length.max().item() for length in merged_batch['len_image_gr']])

                question_edge_indices_padded = []
                question_edge_attributes_padded = []

                for q_e_idx, q_e_attr in zip(merged_batch['question_edge_indices'], merged_batch['question_edge_attributes']):
                    b_size, edge_dim, orig_len = q_e_idx.size()
                    q_e_idx_padded = torch.zeros((b_size, edge_dim, max_question_gr_len))
                    q_e_idx_padded[:, :, :orig_len] = q_e_idx
                    question_edge_indices_padded.append(q_e_idx_padded)

                    edge_attr_dim = q_e_attr.size(-1)
                    q_e_attr_padded = torch.zeros((b_size, max_question_gr_len, edge_attr_dim))
                    q_e_attr_padded[:, :orig_len, :] = q_e_attr
                    question_edge_attributes_padded.append(q_e_attr_padded)

                merged_batch['question_edge_indices'] = question_edge_indices_padded
                merged_batch['question_edge_attributes'] = question_edge_attributes_padded

                history_edge_indices_padded = []
                for h_e_idx in merged_batch['history_edge_indices']:
                    b_size, _, orig_len = h_e_idx.size()
                    h_edge_idx_padded = torch.zeros((b_size, 2, max_history_gr_len))
                    h_edge_idx_padded[:, :, :orig_len] = h_e_idx
                    history_edge_indices_padded.append(h_edge_idx_padded)
                merged_batch['history_edge_indices'] = history_edge_indices_padded

                history_sep_indices_padded = []
                for hist_sep_idx in merged_batch['history_sep_indices']:
                    b_size, orig_len = hist_sep_idx.size()
                    hist_sep_idx_padded = torch.zeros((b_size, max_history_sep_len))
                    hist_sep_idx_padded[:, :orig_len] = hist_sep_idx
                    history_sep_indices_padded.append(hist_sep_idx_padded)
                merged_batch['history_sep_indices'] = history_sep_indices_padded

                image_edge_indices_padded = []
                image_edge_attributes_padded = []
                for img_e_idx, img_e_attr in zip(merged_batch['image_edge_indices'], merged_batch['image_edge_attributes']):
                    b_size, edge_dim, orig_len = img_e_idx.size()
                    img_e_idx_padded = torch.zeros((b_size, edge_dim, max_image_gr_len))
                    img_e_idx_padded[:, :, :orig_len] = img_e_idx
                    image_edge_indices_padded.append(img_e_idx_padded)

                    edge_attr_dim = img_e_attr.size(-1)
                    img_e_attr_padded = torch.zeros((b_size, max_image_gr_len, edge_attr_dim))
                    img_e_attr_padded[:, :orig_len, :] = img_e_attr
                    image_edge_attributes_padded.append(img_e_attr_padded)

                merged_batch['image_edge_indices'] = image_edge_indices_padded
                merged_batch['image_edge_attributes'] = image_edge_attributes_padded

        out = {}
        for key in merged_batch:
            if key in self.keys_lists_to_flatten:
                temp = []
                for b in merged_batch[key]:
                    for x in b:
                        temp.append(x)
                merged_batch[key] = temp

            elif key in self.keys_to_list:
                pass
            else:
                merged_batch[key] = torch.stack(merged_batch[key], 0)
                if key in self.keys_to_expand:
                    if len(merged_batch[key].size()) == 3:
                        size0, size1, size2 = merged_batch[key].size()
                        expand_size = (size0, num_rounds, num_samples, size1, size2)
                    elif len(merged_batch[key].size()) == 2:
                        size0, size1 = merged_batch[key].size()
                        expand_size = (size0, num_rounds, num_samples, size1)
                    merged_batch[key] = merged_batch[key].unsqueeze(1).unsqueeze(1).expand(expand_size).contiguous()
                if key in self.keys_to_flatten_1d:
                    merged_batch[key] = merged_batch[key].reshape(-1)
                elif key in self.keys_to_flatten_2d:
                    merged_batch[key] = merged_batch[key].reshape(-1, merged_batch[key].shape[-1])
                elif key in self.keys_to_flatten_3d:
                    merged_batch[key] = merged_batch[key].reshape(-1, merged_batch[key].shape[-2], merged_batch[key].shape[-1])
                else:
                    assert key in self.keys_other, f'unrecognized key in collate_fn: {key}'

            out[key] = merged_batch[key]
        return out
615
dataloader/dataloader_visdial.py
Normal file
@@ -0,0 +1,615 @@
import torch
import os
import numpy as np
import random
import pickle

import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))

from utils.data_utils import encode_input, encode_input_with_mask, encode_image_input
from dataloader.dataloader_base import DatasetBase


class VisdialDataset(DatasetBase):

    def __init__(self, config):
        super(VisdialDataset, self).__init__(config)

    def __getitem__(self, index):
        MAX_SEQ_LEN = self.config['max_seq_len']
        cur_data = None
        if self._split == 'train':
            cur_data = self.visdial_data_train['data']
            ques_adj_matrices_dir = os.path.join(self.config['visdial_question_adj_matrices'], 'train')
            hist_adj_matrices_dir = os.path.join(self.config['visdial_history_adj_matrices'], 'train')

        elif self._split == 'val':
            cur_data = self.visdial_data_val['data']
            ques_adj_matrices_dir = os.path.join(self.config['visdial_question_adj_matrices'], 'val')
            hist_adj_matrices_dir = os.path.join(self.config['visdial_history_adj_matrices'], 'val')

        else:
            cur_data = self.visdial_data_test['data']
            ques_adj_matrices_dir = os.path.join(self.config['visdial_question_adj_matrices'], 'test')
            hist_adj_matrices_dir = os.path.join(self.config['visdial_history_adj_matrices'], 'test')

        if self.config['visdial_version'] == 0.9:
            ques_adj_matrices_dir = os.path.join(self.config['visdial_question_adj_matrices'], 'train')
            hist_adj_matrices_dir = os.path.join(self.config['visdial_history_adj_matrices'], 'train')

        self.num_bad_samples = 0
        # number of options to score on
        num_options = self.num_options
        assert num_options > 1 and num_options <= 100
        num_dialog_rounds = 10

        dialog = cur_data['dialogs'][index]
        cur_questions = cur_data['questions']
        cur_answers = cur_data['answers']
        img_id = dialog['image_id']
        graph_idx = dialog.get('dialog_idx', index)

        if self._split == 'train':
            # caption
            sent = dialog['caption'].split(' ')
            sentences = ['[CLS]']
            tot_len = 1  # for the CLS token
            sentence_map = [0]  # for the CLS token
            sentence_count = 0
            speakers = [0]

            tokenized_sent, sentences, tot_len, sentence_count, sentence_map, speakers = \
                self.tokenize_utterance(sent, sentences, tot_len, sentence_count, sentence_map, speakers)

            utterances = [[tokenized_sent]]
            utterances_random = [[tokenized_sent]]

            for rnd, utterance in enumerate(dialog['dialog']):
                cur_rnd_utterance = utterances[-1].copy()
                cur_rnd_utterance_random = utterances[-1].copy()

                # question
                sent = cur_questions[utterance['question']].split(' ')
                tokenized_sent, sentences, tot_len, sentence_count, sentence_map, speakers = \
                    self.tokenize_utterance(sent, sentences, tot_len, sentence_count, sentence_map, speakers)

                cur_rnd_utterance.append(tokenized_sent)
                cur_rnd_utterance_random.append(tokenized_sent)

                # answer
                sent = cur_answers[utterance['answer']].split(' ')
                tokenized_sent, sentences, tot_len, sentence_count, sentence_map, speakers = \
                    self.tokenize_utterance(sent, sentences, tot_len, sentence_count, sentence_map, speakers)
                cur_rnd_utterance.append(tokenized_sent)

                utterances.append(cur_rnd_utterance)

                # randomly select negative (random) utterances for that round
                num_inds = len(utterance['answer_options'])
                gt_option_ind = utterance['gt_index']

                negative_samples = []

                for _ in range(self.config["num_negative_samples"]):

                    all_inds = list(range(100))
                    all_inds.remove(gt_option_ind)
                    all_inds = all_inds[:(num_options - 1)]
                    tokenized_random_utterance = None
                    option_ind = None

                    while len(all_inds):
                        option_ind = random.choice(all_inds)
                        tokenized_random_utterance = self.tokenizer.convert_tokens_to_ids(cur_answers[utterance['answer_options'][option_ind]].split(' '))
                        # the 1 here is for the sep token at the end of each utterance
                        if MAX_SEQ_LEN >= (tot_len + len(tokenized_random_utterance) + 1):
                            break
                        else:
                            all_inds.remove(option_ind)
                    if len(all_inds) == 0:
                        # all the options exceed the max len. Truncate the last utterance in this case.
                        tokenized_random_utterance = tokenized_random_utterance[:len(tokenized_sent)]
                    t = cur_rnd_utterance_random.copy()
                    t.append(tokenized_random_utterance)
                    negative_samples.append(t)

                utterances_random.append(negative_samples)

            # remove the caption at the beginning
            utterances = utterances[1:]
            utterances_random = utterances_random[1:]
            assert len(utterances) == len(utterances_random) == num_dialog_rounds
            assert tot_len <= MAX_SEQ_LEN, '{} {} tot_len = {} > max_seq_len'.format(
                self._split, index, tot_len
            )

            tokens_all = []
            question_limits_all = []
            question_edge_indices_all = []
            question_edge_attributes_all = []
            history_edge_indices_all = []
            history_sep_indices_all = []
            mask_all = []
            segments_all = []
            sep_indices_all = []
            next_labels_all = []
            hist_len_all = []

            # randomly pick several rounds to train
            pos_rounds = sorted(random.sample(range(num_dialog_rounds), self.config['sequences_per_image'] // 2), reverse=True)
            neg_rounds = sorted(random.sample(range(num_dialog_rounds), self.config['sequences_per_image'] // 2), reverse=True)

            tokens_all_rnd = []
            question_limits_all_rnd = []
            mask_all_rnd = []
            segments_all_rnd = []
            sep_indices_all_rnd = []
            next_labels_all_rnd = []
            hist_len_all_rnd = []

            for j in pos_rounds:
                context = utterances[j]
                context, start_segment = self.pruneRounds(context, self.config['visdial_tot_rounds'])
                if j == pos_rounds[0]:  # dialog with positive label and max rounds
                    tokens, segments, sep_indices, mask, input_mask, start_question, end_question = encode_input_with_mask(
                        context, start_segment, self.CLS, self.SEP, self.MASK,
                        max_seq_len=MAX_SEQ_LEN, mask_prob=self.config["mask_prob"])
                else:
                    tokens, segments, sep_indices, mask, start_question, end_question = encode_input(
                        context, start_segment, self.CLS, self.SEP, self.MASK,
                        max_seq_len=MAX_SEQ_LEN, mask_prob=self.config["mask_prob"])
                tokens_all_rnd.append(tokens)
                question_limits_all_rnd.append(torch.tensor([start_question, end_question]))
                mask_all_rnd.append(mask)
                sep_indices_all_rnd.append(sep_indices)
                next_labels_all_rnd.append(torch.LongTensor([0]))
                segments_all_rnd.append(segments)
                hist_len_all_rnd.append(torch.LongTensor([len(context) - 1]))

            tokens_all.append(torch.cat(tokens_all_rnd, 0).unsqueeze(0))
            mask_all.append(torch.cat(mask_all_rnd, 0).unsqueeze(0))
            question_limits_all.extend(question_limits_all_rnd)
            segments_all.append(torch.cat(segments_all_rnd, 0).unsqueeze(0))
            sep_indices_all.append(torch.cat(sep_indices_all_rnd, 0).unsqueeze(0))
            next_labels_all.append(torch.cat(next_labels_all_rnd, 0).unsqueeze(0))
            hist_len_all.append(torch.cat(hist_len_all_rnd, 0).unsqueeze(0))

            assert len(pos_rounds) == 1
            question_graphs = pickle.load(
                open(os.path.join(ques_adj_matrices_dir, f'{graph_idx}.pkl'), 'rb')
            )

            question_graph_pos = question_graphs[pos_rounds[0]]
            question_edge_index_pos = []
            question_edge_attribute_pos = []
            for edge_idx, edge_attr in question_graph_pos:
                question_edge_index_pos.append(edge_idx)
                edge_attr_one_hot = np.zeros((len(self.parse_vocab) + 1,), dtype=np.float32)
                edge_attr_one_hot[self.parse_vocab.get(edge_attr, len(self.parse_vocab))] = 1.0
                question_edge_attribute_pos.append(edge_attr_one_hot)

            question_edge_index_pos = np.array(question_edge_index_pos, dtype=np.float64)
            question_edge_attribute_pos = np.stack(question_edge_attribute_pos, axis=0)

            question_edge_indices_all.append(
                torch.from_numpy(question_edge_index_pos).t().long().contiguous()
            )

            question_edge_attributes_all.append(
                torch.from_numpy(question_edge_attribute_pos)
            )

            history_edge_indices = pickle.load(
                open(os.path.join(hist_adj_matrices_dir, f'{graph_idx}.pkl'), 'rb')
            )

            history_edge_indices_all.append(
                torch.tensor(history_edge_indices[pos_rounds[0]]).t().long().contiguous()
            )
            # Get the [SEP] tokens that will represent the history graph node features
            hist_idx_pos = [i * 2 for i in range(pos_rounds[0] + 1)]
            sep_indices = sep_indices.squeeze(0).numpy()
            history_sep_indices_all.append(torch.from_numpy(sep_indices[hist_idx_pos]))

            if len(neg_rounds) > 0:
                tokens_all_rnd = []
                question_limits_all_rnd = []
                mask_all_rnd = []
                segments_all_rnd = []
                sep_indices_all_rnd = []
                next_labels_all_rnd = []
                hist_len_all_rnd = []

                for j in neg_rounds:

                    negative_samples = utterances_random[j]
                    for context_random in negative_samples:
                        context_random, start_segment = self.pruneRounds(context_random, self.config['visdial_tot_rounds'])
                        tokens_random, segments_random, sep_indices_random, mask_random, start_question, end_question = encode_input(
                            context_random, start_segment, self.CLS, self.SEP, self.MASK,
                            max_seq_len=MAX_SEQ_LEN, mask_prob=self.config["mask_prob"])
                        tokens_all_rnd.append(tokens_random)
                        question_limits_all_rnd.append(torch.tensor([start_question, end_question]))
                        mask_all_rnd.append(mask_random)
                        sep_indices_all_rnd.append(sep_indices_random)
                        next_labels_all_rnd.append(torch.LongTensor([1]))
                        segments_all_rnd.append(segments_random)
                        hist_len_all_rnd.append(torch.LongTensor([len(context_random) - 1]))

                tokens_all.append(torch.cat(tokens_all_rnd, 0).unsqueeze(0))
                mask_all.append(torch.cat(mask_all_rnd, 0).unsqueeze(0))
                question_limits_all.extend(question_limits_all_rnd)
                segments_all.append(torch.cat(segments_all_rnd, 0).unsqueeze(0))
                sep_indices_all.append(torch.cat(sep_indices_all_rnd, 0).unsqueeze(0))
                next_labels_all.append(torch.cat(next_labels_all_rnd, 0).unsqueeze(0))
                hist_len_all.append(torch.cat(hist_len_all_rnd, 0).unsqueeze(0))

                assert len(neg_rounds) == 1

                question_graph_neg = question_graphs[neg_rounds[0]]
                question_edge_index_neg = []
                question_edge_attribute_neg = []
                for edge_idx, edge_attr in question_graph_neg:
                    question_edge_index_neg.append(edge_idx)
                    edge_attr_one_hot = np.zeros((len(self.parse_vocab) + 1,), dtype=np.float32)
                    edge_attr_one_hot[self.parse_vocab.get(edge_attr, len(self.parse_vocab))] = 1.0
                    question_edge_attribute_neg.append(edge_attr_one_hot)

                question_edge_index_neg = np.array(question_edge_index_neg, dtype=np.float64)
                question_edge_attribute_neg = np.stack(question_edge_attribute_neg, axis=0)

                question_edge_indices_all.append(
                    torch.from_numpy(question_edge_index_neg).t().long().contiguous()
                )

                question_edge_attributes_all.append(
                    torch.from_numpy(question_edge_attribute_neg)
                )

                history_edge_indices_all.append(
                    torch.tensor(history_edge_indices[neg_rounds[0]]).t().long().contiguous()
                )

                # Get the [SEP] tokens that will represent the history graph node features
                hist_idx_neg = [i * 2 for i in range(neg_rounds[0] + 1)]
                sep_indices_random = sep_indices_random.squeeze(0).numpy()
                history_sep_indices_all.append(torch.from_numpy(sep_indices_random[hist_idx_neg]))

            tokens_all = torch.cat(tokens_all, 0)  # [2, num_pos, max_len]
            question_limits_all = torch.stack(question_limits_all, 0)  # [2, 2]
            mask_all = torch.cat(mask_all, 0)
            segments_all = torch.cat(segments_all, 0)
            sep_indices_all = torch.cat(sep_indices_all, 0)
            next_labels_all = torch.cat(next_labels_all, 0)
            hist_len_all = torch.cat(hist_len_all, 0)
            input_mask_all = torch.LongTensor(input_mask)  # [max_len]

            item = {}

            item['tokens'] = tokens_all
            item['question_limits'] = question_limits_all
            item['question_edge_indices'] = question_edge_indices_all
            item['question_edge_attributes'] = question_edge_attributes_all

            item['history_edge_indices'] = history_edge_indices_all
            item['history_sep_indices'] = history_sep_indices_all
            item['segments'] = segments_all
            item['sep_indices'] = sep_indices_all
            item['mask'] = mask_all
            item['next_sentence_labels'] = next_labels_all
            item['hist_len'] = hist_len_all
            item['input_mask'] = input_mask_all

            # get image features
            if not self.config['dataloader_text_only']:
                features, num_boxes, boxes, _, image_target, image_edge_indexes, image_edge_attributes = self._image_features_reader[img_id]
                features, spatials, image_mask, image_target, image_label = encode_image_input(features, num_boxes, boxes, image_target, max_regions=self._max_region_num)
            else:
                features = spatials = image_mask = image_target = image_label = torch.tensor([0])

        elif self._split == 'val':
            gt_relevance = None
            gt_option_inds = []
            options_all = []

            # caption
            sent = dialog['caption'].split(' ')
            sentences = ['[CLS]']
            tot_len = 1  # for the CLS token
            sentence_map = [0]  # for the CLS token
            sentence_count = 0
            speakers = [0]

            tokenized_sent, sentences, tot_len, sentence_count, sentence_map, speakers = \
                self.tokenize_utterance(sent, sentences, tot_len, sentence_count, sentence_map, speakers)
            utterances = [[tokenized_sent]]

            for rnd, utterance in enumerate(dialog['dialog']):
                cur_rnd_utterance = utterances[-1].copy()

                # question
                sent = cur_questions[utterance['question']].split(' ')
                tokenized_sent, sentences, tot_len, sentence_count, sentence_map, speakers = \
                    self.tokenize_utterance(sent, sentences, tot_len, sentence_count, sentence_map, speakers)

                cur_rnd_utterance.append(tokenized_sent)

                # current round
                gt_option_ind = utterance['gt_index']
                # first select gt option id, then choose the first num_options inds
                option_inds = []
                option_inds.append(gt_option_ind)
                all_inds = list(range(100))
                all_inds.remove(gt_option_ind)
                all_inds = all_inds[:(num_options - 1)]
                option_inds.extend(all_inds)
                gt_option_inds.append(0)
                cur_rnd_options = []
                answer_options = [utterance['answer_options'][k] for k in option_inds]
                assert len(answer_options) == len(option_inds) == num_options
                assert answer_options[0] == utterance['answer']

                # for evaluation of all options and dense relevance
                if self.visdial_data_val_dense:
                    if rnd == self.visdial_data_val_dense[index]['round_id'] - 1:
                        # only 1 round has gt_relevance for each example
                        if 'relevance' in self.visdial_data_val_dense[index]:
                            gt_relevance = torch.Tensor(self.visdial_data_val_dense[index]['relevance'])
                        else:
                            gt_relevance = torch.Tensor(self.visdial_data_val_dense[index]['gt_relevance'])
                        # shuffle based on new indices
                        gt_relevance = gt_relevance[torch.LongTensor(option_inds)]
                else:
                    gt_relevance = -1

                for answer_option in answer_options:
                    cur_rnd_cur_option = cur_rnd_utterance.copy()
                    cur_rnd_cur_option.append(self.tokenizer.convert_tokens_to_ids(cur_answers[answer_option].split(' ')))
                    cur_rnd_options.append(cur_rnd_cur_option)

                # answer
                sent = cur_answers[utterance['answer']].split(' ')
                tokenized_sent, sentences, tot_len, sentence_count, sentence_map, speakers = \
                    self.tokenize_utterance(sent, sentences, tot_len, sentence_count, sentence_map, speakers)
                cur_rnd_utterance.append(tokenized_sent)

                utterances.append(cur_rnd_utterance)
                options_all.append(cur_rnd_options)

            # encode the input and create batch x 10 x 100 x max_len arrays (batch x num_rounds x num_options)
            tokens_all = []
            question_limits_all = []
            mask_all = []
            segments_all = []
            sep_indices_all = []
            hist_len_all = []
            history_sep_indices_all = []

            for rnd, cur_rnd_options in enumerate(options_all):

                tokens_all_rnd = []
                mask_all_rnd = []
                segments_all_rnd = []
                sep_indices_all_rnd = []
                hist_len_all_rnd = []

                for j, cur_rnd_option in enumerate(cur_rnd_options):

                    cur_rnd_option, start_segment = self.pruneRounds(cur_rnd_option, self.config['visdial_tot_rounds'])
                    if rnd == len(options_all) - 1 and j == 0:  # gt dialog
                        tokens, segments, sep_indices, mask, input_mask, start_question, end_question = encode_input_with_mask(
                            cur_rnd_option, start_segment, self.CLS, self.SEP, self.MASK,
                            max_seq_len=MAX_SEQ_LEN, mask_prob=0)
                    else:
                        tokens, segments, sep_indices, mask, start_question, end_question = encode_input(
                            cur_rnd_option, start_segment, self.CLS, self.SEP, self.MASK,
                            max_seq_len=MAX_SEQ_LEN, mask_prob=0)

                    tokens_all_rnd.append(tokens)
                    mask_all_rnd.append(mask)
                    segments_all_rnd.append(segments)
                    sep_indices_all_rnd.append(sep_indices)
                    hist_len_all_rnd.append(torch.LongTensor([len(cur_rnd_option) - 1]))

                question_limits_all.append(torch.tensor([start_question, end_question]).unsqueeze(0).repeat(100, 1))
                tokens_all.append(torch.cat(tokens_all_rnd, 0).unsqueeze(0))
                mask_all.append(torch.cat(mask_all_rnd, 0).unsqueeze(0))
                segments_all.append(torch.cat(segments_all_rnd, 0).unsqueeze(0))
                sep_indices_all.append(torch.cat(sep_indices_all_rnd, 0).unsqueeze(0))
                hist_len_all.append(torch.cat(hist_len_all_rnd, 0).unsqueeze(0))
                # Get the [SEP] tokens that will represent the history graph node features.
                # They are the same for all answer candidates, as the history does not change
                # per answer.
                hist_idx = [i * 2 for i in range(rnd + 1)]
                history_sep_indices_all.extend(sep_indices.squeeze(0)[hist_idx].contiguous() for _ in range(100))

            tokens_all = torch.cat(tokens_all, 0)  # [10, 100, max_len]
            mask_all = torch.cat(mask_all, 0)
            segments_all = torch.cat(segments_all, 0)
            sep_indices_all = torch.cat(sep_indices_all, 0)
            hist_len_all = torch.cat(hist_len_all, 0)
            input_mask_all = torch.LongTensor(input_mask)  # [max_len]

            # load graph data
            question_limits_all = torch.stack(question_limits_all, 0)  # [10, 100, 2]

            question_graphs = pickle.load(
                open(os.path.join(ques_adj_matrices_dir, f'{graph_idx}.pkl'), 'rb')
            )
            question_edge_indices_all = []  # [10, N]; we do not repeat it 100 times here
            question_edge_attributes_all = []  # [10, N]; we do not repeat it 100 times here

            for q_graph_round in question_graphs:
                question_edge_index = []
                question_edge_attribute = []
                for edge_index, edge_attr in q_graph_round:
                    question_edge_index.append(edge_index)
                    edge_attr_one_hot = np.zeros((len(self.parse_vocab) + 1,), dtype=np.float32)
                    edge_attr_one_hot[self.parse_vocab.get(edge_attr, len(self.parse_vocab))] = 1.0
                    question_edge_attribute.append(edge_attr_one_hot)
                question_edge_index = np.array(question_edge_index, dtype=np.float64)
                question_edge_attribute = np.stack(question_edge_attribute, axis=0)

                question_edge_indices_all.extend(
                    [torch.from_numpy(question_edge_index).t().long().contiguous() for _ in range(100)])
                question_edge_attributes_all.extend(
                    [torch.from_numpy(question_edge_attribute).contiguous() for _ in range(100)])

            _history_edge_incides_all = pickle.load(
                open(os.path.join(hist_adj_matrices_dir, f'{graph_idx}.pkl'), 'rb')
            )
            history_edge_incides_all = []
            for hist_edge_indices_rnd in _history_edge_incides_all:
                history_edge_incides_all.extend(
                    [torch.tensor(hist_edge_indices_rnd).t().long().contiguous() for _ in range(100)]
                )

            item = {}
            item['tokens'] = tokens_all
            item['segments'] = segments_all
            item['sep_indices'] = sep_indices_all
            item['mask'] = mask_all
            item['hist_len'] = hist_len_all
            item['input_mask'] = input_mask_all

            item['gt_option_inds'] = torch.LongTensor(gt_option_inds)

            # return dense annotation data as well
            if self.visdial_data_val_dense:
                item['round_id'] = torch.LongTensor([self.visdial_data_val_dense[index]['round_id']])
                item['gt_relevance'] = gt_relevance

            item['question_limits'] = question_limits_all

            item['question_edge_indices'] = question_edge_indices_all
            item['question_edge_attributes'] = question_edge_attributes_all

            item['history_edge_indices'] = history_edge_incides_all
            item['history_sep_indices'] = history_sep_indices_all

            # get image features
            if not self.config['dataloader_text_only']:
                features, num_boxes, boxes, _, image_target, image_edge_indexes, image_edge_attributes = self._image_features_reader[img_id]
                features, spatials, image_mask, image_target, image_label = encode_image_input(
                    features, num_boxes, boxes, image_target, max_regions=self._max_region_num, mask_prob=0)
            else:
                features = spatials = image_mask = image_target = image_label = torch.tensor([0])

        elif self.split == 'test':
            assert num_options == 100
            cur_rnd_utterance = [self.tokenizer.convert_tokens_to_ids(dialog['caption'].split(' '))]
            options_all = []
            for rnd, utterance in enumerate(dialog['dialog']):
                cur_rnd_utterance.append(self.tokenizer.convert_tokens_to_ids(cur_questions[utterance['question']].split(' ')))
                if rnd != len(dialog['dialog']) - 1:
                    cur_rnd_utterance.append(self.tokenizer.convert_tokens_to_ids(cur_answers[utterance['answer']].split(' ')))
            for answer_option in dialog['dialog'][-1]['answer_options']:
                cur_option = cur_rnd_utterance.copy()
                cur_option.append(self.tokenizer.convert_tokens_to_ids(cur_answers[answer_option].split(' ')))
                options_all.append(cur_option)

            tokens_all = []
            mask_all = []
            segments_all = []
            sep_indices_all = []
            hist_len_all = []

            for j, option in enumerate(options_all):
                option, start_segment = self.pruneRounds(option, self.config['visdial_tot_rounds'])
                tokens, segments, sep_indices, mask = encode_input(
                    option, start_segment, self.CLS, self.SEP, self.MASK,
                    max_seq_len=MAX_SEQ_LEN, mask_prob=0)

                tokens_all.append(tokens)
                mask_all.append(mask)
                segments_all.append(segments)
                sep_indices_all.append(sep_indices)
                hist_len_all.append(torch.LongTensor([len(option) - 1]))

            tokens_all = torch.cat(tokens_all, 0)
            mask_all = torch.cat(mask_all, 0)
            segments_all = torch.cat(segments_all, 0)
            sep_indices_all = torch.cat(sep_indices_all, 0)
            hist_len_all = torch.cat(hist_len_all, 0)
            hist_idx = [i * 2 for i in range(len(dialog['dialog']))]
            history_sep_indices_all = [sep_indices.squeeze(0)[hist_idx].contiguous() for _ in range(num_options)]

            with open(os.path.join(ques_adj_matrices_dir, f'{graph_idx}.pkl'), 'rb') as f:
                question_graphs = pickle.load(f)
            q_graph_last = question_graphs[-1]
            question_edge_index = []
            question_edge_attribute = []
            for edge_index, edge_attr in q_graph_last:
                question_edge_index.append(edge_index)
                edge_attr_one_hot = np.zeros((len(self.parse_vocab) + 1,), dtype=np.float32)
                edge_attr_one_hot[self.parse_vocab.get(edge_attr, len(self.parse_vocab))] = 1.0
                question_edge_attribute.append(edge_attr_one_hot)
            question_edge_index = np.array(question_edge_index, dtype=np.float64)
            question_edge_attribute = np.stack(question_edge_attribute, axis=0)

            question_edge_indices_all = [torch.from_numpy(question_edge_index).t().long().contiguous() for _ in range(num_options)]
            question_edge_attributes_all = [torch.from_numpy(question_edge_attribute).contiguous() for _ in range(num_options)]

            with open(os.path.join(hist_adj_matrices_dir, f'{graph_idx}.pkl'), 'rb') as f:
                _history_edge_incides_all = pickle.load(f)
            _history_edge_incides_last = _history_edge_incides_all[-1]
            history_edge_index_all = [torch.tensor(_history_edge_incides_last).t().long().contiguous() for _ in range(num_options)]

            if self.config['stack_gr_data']:
                question_edge_indices_all = torch.stack(question_edge_indices_all, dim=0)
                question_edge_attributes_all = torch.stack(question_edge_attributes_all, dim=0)
                history_edge_index_all = torch.stack(history_edge_index_all, dim=0)
                history_sep_indices_all = torch.stack(history_sep_indices_all, dim=0)
                len_question_gr = torch.tensor(question_edge_indices_all.size(-1)).unsqueeze(0).repeat(num_options, 1)
                len_history_gr = torch.tensor(history_edge_index_all.size(-1)).repeat(num_options, 1)
                len_history_sep = torch.tensor(history_sep_indices_all.size(-1)).repeat(num_options, 1)

            item = {}
            item['tokens'] = tokens_all.unsqueeze(0)
            item['segments'] = segments_all.unsqueeze(0)
            item['sep_indices'] = sep_indices_all.unsqueeze(0)
            item['mask'] = mask_all.unsqueeze(0)
            item['hist_len'] = hist_len_all.unsqueeze(0)
            item['question_limits'] = question_limits_all
            item['question_edge_indices'] = question_edge_indices_all
            item['question_edge_attributes'] = question_edge_attributes_all

            item['history_edge_indices'] = history_edge_index_all
            item['history_sep_indices'] = history_sep_indices_all

            if self.config['stack_gr_data']:
                item['len_question_gr'] = len_question_gr
                item['len_history_gr'] = len_history_gr
                item['len_history_sep'] = len_history_sep

            item['round_id'] = torch.LongTensor([dialog['round_id']])

            # get image features
            if not self.config['dataloader_text_only']:
                features, num_boxes, boxes, _, image_target, image_edge_indexes, image_edge_attributes = self._image_features_reader[img_id]
                features, spatials, image_mask, image_target, image_label = encode_image_input(features, num_boxes, boxes, image_target, max_regions=self._max_region_num, mask_prob=0)
            else:
                features = spatials = image_mask = image_target = image_label = torch.tensor([0])

        item['image_feat'] = features
        item['image_loc'] = spatials
        item['image_mask'] = image_mask
        item['image_target'] = image_target
        item['image_label'] = image_label
        item['image_id'] = torch.LongTensor([img_id])
        if self._split == 'train':
            # cheap hack to account for the graph data of the positive and negative examples
            item['image_edge_indices'] = [torch.from_numpy(image_edge_indexes).long(), torch.from_numpy(image_edge_indexes).long()]
            item['image_edge_attributes'] = [torch.from_numpy(image_edge_attributes), torch.from_numpy(image_edge_attributes)]
        elif self._split == 'val':
            # cheap hack to account for the graph data of the positive and negative examples
            item['image_edge_indices'] = [torch.from_numpy(image_edge_indexes).contiguous().long() for _ in range(1000)]
            item['image_edge_attributes'] = [torch.from_numpy(image_edge_attributes).contiguous() for _ in range(1000)]

        else:
            # cheap hack to account for the graph data of the positive and negative examples
            item['image_edge_indices'] = [torch.from_numpy(image_edge_indexes).contiguous().long() for _ in range(100)]
            item['image_edge_attributes'] = [torch.from_numpy(image_edge_attributes).contiguous() for _ in range(100)]

        if self.config['stack_gr_data']:
            item['image_edge_indices'] = torch.stack(item['image_edge_indices'], dim=0)
            item['image_edge_attributes'] = torch.stack(item['image_edge_attributes'], dim=0)
            len_image_gr = torch.tensor(item['image_edge_indices'].size(-1)).unsqueeze(0).repeat(num_options)
            item['len_image_gr'] = len_image_gr

        return item
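A minimal sketch of how VisdialDataset is typically wrapped with a PyTorch DataLoader so that the collate_fn defined in DatasetBase produces the flattened batch keys. The pyhocon config path and section name below are assumptions for illustration; the real config must supply the keys read by DatasetBase above (e.g. 'visdial_train', 'max_seq_len', 'num_options'):

# Illustrative sketch only, not part of this commit.
import pyhocon
import torch.utils.data as tud

from dataloader.dataloader_visdial import VisdialDataset

# Hypothetical config file and section; substitute the project's actual config.
config = pyhocon.ConfigFactory.parse_file('config/train.conf')['train']

dataset = VisdialDataset(config)
dataset.split = 'train'  # uses the split setter defined in DatasetBase

loader = tud.DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    num_workers=0,
    collate_fn=dataset.collate_fn,  # collapses rounds/sequences into the leading batch dim
)

for batch in loader:
    # e.g. batch['tokens'] is 2-D: [batch_size * rounds * sequences, max_seq_len]
    break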
313
dataloader/dataloader_visdial_dense.py
Normal file
@@ -0,0 +1,313 @@
import torch
import json
import os
import time
import numpy as np
import random
from tqdm import tqdm
import copy
import pyhocon
import glog as log
from collections import OrderedDict
import argparse
import pickle
import torch.utils.data as tud
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))

from utils.data_utils import encode_input, encode_image_input
from dataloader.dataloader_base import DatasetBase


class VisdialDenseDataset(DatasetBase):

    def __init__(self, config):
        super(VisdialDenseDataset, self).__init__(config)
        with open(config.tr_graph_idx_mapping, 'r') as f:
            self.tr_graph_idx_mapping = json.load(f)

        with open(config.val_graph_idx_mapping, 'r') as f:
            self.val_graph_idx_mapping = json.load(f)

        with open(config.test_graph_idx_mapping, 'r') as f:
            self.test_graph_idx_mapping = json.load(f)

        self.question_gr_paths = {
            'train': os.path.join(self.config['visdial_question_adj_matrices'], 'train'),
            'val': os.path.join(self.config['visdial_question_adj_matrices'], 'val'),
            'test': os.path.join(self.config['visdial_question_adj_matrices'], 'test')
        }

        self.history_gr_paths = {
            'train': os.path.join(self.config['visdial_history_adj_matrices'], 'train'),
            'val': os.path.join(self.config['visdial_history_adj_matrices'], 'val'),
            'test': os.path.join(self.config['visdial_history_adj_matrices'], 'test')
        }

    def __getitem__(self, index):
        MAX_SEQ_LEN = self.config['max_seq_len']
        cur_data = None
        cur_dense_annotations = None

        if self._split == 'train':
            cur_data = self.visdial_data_train['data']
            cur_dense_annotations = self.visdial_data_train_dense
            cur_question_gr_path = self.question_gr_paths['train']
            cur_history_gr_path = self.history_gr_paths['train']
            cur_gr_mapping = self.tr_graph_idx_mapping

            if self.config['rlv_hst_only']:
                cur_rlv_hst = self.rlv_hst_train
        elif self._split == 'val':
            cur_data = self.visdial_data_val['data']
            cur_dense_annotations = self.visdial_data_val_dense
            cur_question_gr_path = self.question_gr_paths['val']
            cur_history_gr_path = self.history_gr_paths['val']
            cur_gr_mapping = self.val_graph_idx_mapping

            if self.config['rlv_hst_only']:
                cur_rlv_hst = self.rlv_hst_val
        elif self._split == 'trainval':
            if index >= self.numDataPoints['train']:
                cur_data = self.visdial_data_val['data']
                cur_dense_annotations = self.visdial_data_val_dense
                cur_gr_mapping = self.val_graph_idx_mapping
                index -= self.numDataPoints['train']
                cur_question_gr_path = self.question_gr_paths['val']
                cur_history_gr_path = self.history_gr_paths['val']
                if self.config['rlv_hst_only']:
                    cur_rlv_hst = self.rlv_hst_val
            else:
                cur_data = self.visdial_data_train['data']
                cur_dense_annotations = self.visdial_data_train_dense
                cur_question_gr_path = self.question_gr_paths['train']
                cur_gr_mapping = self.tr_graph_idx_mapping
                cur_history_gr_path = self.history_gr_paths['train']
                if self.config['rlv_hst_only']:
                    cur_rlv_hst = self.rlv_hst_train
        elif self._split == 'test':
            cur_data = self.visdial_data_test['data']
            cur_question_gr_path = self.question_gr_paths['test']
            cur_history_gr_path = self.history_gr_paths['test']
            if self.config['rlv_hst_only']:
                cur_rlv_hst = self.rlv_hst_test

        # number of options to score on
        num_options = self.num_options_dense
        if self._split == 'test' or self.config['validating'] or self.config['predicting']:
            assert num_options == 100
        else:
            assert num_options >= 1 and num_options <= 100

        dialog = cur_data['dialogs'][index]
        cur_questions = cur_data['questions']
        cur_answers = cur_data['answers']
        img_id = dialog['image_id']
        if self._split != 'test':
            graph_idx = cur_gr_mapping[str(img_id)]
        else:
            graph_idx = index

        if self._split != 'test':
            assert img_id == cur_dense_annotations[index]['image_id']
        if self.config['rlv_hst_only']:
            rlv_hst = cur_rlv_hst[str(img_id)]  # [10 for each round, 10 for caption + first 9 rounds]

        if self._split == 'test':
            cur_rounds = len(dialog['dialog'])  # 1, 2, ..., 10
        else:
            cur_rounds = cur_dense_annotations[index]['round_id']  # 1, 2, ..., 10

        # caption
        cur_rnd_utterance = []
        include_caption = True
        if self.config['rlv_hst_only']:
            if self.config['rlv_hst_dense_round']:
                if rlv_hst[0] == 0:
                    include_caption = False
            elif rlv_hst[cur_rounds - 1][0] == 0:
                include_caption = False
        if include_caption:
            sent = dialog['caption'].split(' ')
            tokenized_sent = self.tokenizer.convert_tokens_to_ids(sent)
            cur_rnd_utterance.append(tokenized_sent)
            # tot_len += len(sent) + 1

        for rnd, utterance in enumerate(dialog['dialog'][:cur_rounds]):
            if self.config['rlv_hst_only'] and rnd < cur_rounds - 1:
                if self.config['rlv_hst_dense_round']:
                    if rlv_hst[rnd + 1] == 0:
                        continue
                elif rlv_hst[cur_rounds - 1][rnd + 1] == 0:
                    continue
            # question
            sent = cur_questions[utterance['question']].split(' ')
            tokenized_sent = self.tokenizer.convert_tokens_to_ids(sent)
            cur_rnd_utterance.append(tokenized_sent)

            # answer
            if rnd != cur_rounds - 1:
                sent = cur_answers[utterance['answer']].split(' ')
                tokenized_sent = self.tokenizer.convert_tokens_to_ids(sent)
                cur_rnd_utterance.append(tokenized_sent)

        if self.config['rlv_hst_only']:
            num_rlv_rnds = len(cur_rnd_utterance) - 1
        else:
            num_rlv_rnds = None

        if self._split != 'test':
            gt_option = dialog['dialog'][cur_rounds - 1]['gt_index']
            if self.config['training'] or self.config['debugging']:
                # first select gt option id, then choose the first num_options inds
                option_inds = []
                option_inds.append(gt_option)
                all_inds = list(range(100))
                all_inds.remove(gt_option)
                # debug
                if num_options < 100:
                    random.shuffle(all_inds)
                all_inds = all_inds[:(num_options - 1)]
                option_inds.extend(all_inds)
                gt_option = 0
            else:
                option_inds = range(num_options)
            answer_options = [dialog['dialog'][cur_rounds - 1]['answer_options'][k] for k in option_inds]
            if 'relevance' in cur_dense_annotations[index]:
                key = 'relevance'
            else:
                key = 'gt_relevance'
            gt_relevance = torch.Tensor(cur_dense_annotations[index][key])
            gt_relevance = gt_relevance[option_inds]
            assert len(answer_options) == len(option_inds) == num_options
        else:
            answer_options = dialog['dialog'][-1]['answer_options']
            assert len(answer_options) == num_options

        options_all = []
        for answer_option in answer_options:
            cur_option = cur_rnd_utterance.copy()
            cur_option.append(self.tokenizer.convert_tokens_to_ids(cur_answers[answer_option].split(' ')))
            options_all.append(cur_option)
            if not self.config['rlv_hst_only']:
                assert len(cur_option) == 2 * cur_rounds + 1

        tokens_all = []
        mask_all = []
        segments_all = []
        sep_indices_all = []
        hist_len_all = []
        tot_len_debug = []

        for opt_id, option in enumerate(options_all):
            option, start_segment = self.pruneRounds(option, self.config['visdial_tot_rounds'])
            tokens, segments, sep_indices, mask, start_question, end_question = encode_input(
                option, start_segment, self.CLS, self.SEP, self.MASK,
                max_seq_len=MAX_SEQ_LEN, mask_prob=0)

            tokens_all.append(tokens)
            mask_all.append(mask)
            segments_all.append(segments)
            sep_indices_all.append(sep_indices)
            hist_len_all.append(torch.LongTensor([len(option) - 1]))

            len_tokens = sum(len(s) for s in option)
            tot_len_debug.append(len_tokens + len(option) + 1)

        tokens_all = torch.cat(tokens_all, 0)
        mask_all = torch.cat(mask_all, 0)
        segments_all = torch.cat(segments_all, 0)
        sep_indices_all = torch.cat(sep_indices_all, 0)
        hist_len_all = torch.cat(hist_len_all, 0)
        question_limits_all = torch.tensor([start_question, end_question]).unsqueeze(0).repeat(num_options, 1)
        if self.config['rlv_hst_only']:
            assert num_rlv_rnds > 0
            hist_idx = [i * 2 for i in range(num_rlv_rnds)]
        else:
            hist_idx = [i * 2 for i in range(cur_rounds)]
        history_sep_indices_all = sep_indices.squeeze(0)[hist_idx].contiguous().unsqueeze(0).repeat(num_options, 1)

        with open(os.path.join(cur_question_gr_path, f'{graph_idx}.pkl'), 'rb') as f:
            question_graphs = pickle.load(f)
        question_graph_round = question_graphs[cur_rounds - 1]
        question_edge_index = []
        question_edge_attribute = []
        for edge_index, edge_attr in question_graph_round:
            question_edge_index.append(edge_index)
            edge_attr_one_hot = np.zeros((len(self.parse_vocab) + 1,), dtype=np.float32)
            edge_attr_one_hot[self.parse_vocab.get(edge_attr, len(self.parse_vocab))] = 1.0
            question_edge_attribute.append(edge_attr_one_hot)
        question_edge_index = np.array(question_edge_index, dtype=np.float64)
        question_edge_attribute = np.stack(question_edge_attribute, axis=0)

        question_edge_indices_all = [torch.from_numpy(question_edge_index).t().long().contiguous() for _ in range(num_options)]
        question_edge_attributes_all = [torch.from_numpy(question_edge_attribute).contiguous() for _ in range(num_options)]

        if self.config['rlv_hst_only']:
            with open(os.path.join(cur_history_gr_path, f'{graph_idx}.pkl'), 'rb') as f:
                _history_edge_incides_round = pickle.load(f)
        else:
            with open(os.path.join(cur_history_gr_path, f'{graph_idx}.pkl'), 'rb') as f:
                _history_edge_incides_all = pickle.load(f)
            _history_edge_incides_round = _history_edge_incides_all[cur_rounds - 1]

        history_edge_index_all = [torch.tensor(_history_edge_incides_round).t().long().contiguous() for _ in range(num_options)]

        if self.config['stack_gr_data']:
            question_edge_indices_all = torch.stack(question_edge_indices_all, dim=0)
            question_edge_attributes_all = torch.stack(question_edge_attributes_all, dim=0)
            history_edge_index_all = torch.stack(history_edge_index_all, dim=0)

        item = {}

        item['tokens'] = tokens_all.unsqueeze(0)  # [1, num_options, max_len]
        item['segments'] = segments_all.unsqueeze(0)
        item['sep_indices'] = sep_indices_all.unsqueeze(0)
        item['mask'] = mask_all.unsqueeze(0)
        item['hist_len'] = hist_len_all.unsqueeze(0)
        item['question_limits'] = question_limits_all
        item['question_edge_indices'] = question_edge_indices_all
        item['question_edge_attributes'] = question_edge_attributes_all
        item['history_edge_indices'] = history_edge_index_all
        item['history_sep_indices'] = history_sep_indices_all

        # add dense annotation fields
        if self._split != 'test':
            item['gt_relevance'] = gt_relevance  # [num_options]
            item['gt_option_inds'] = torch.LongTensor([gt_option])

            # add next sentence labels for training with the nsp loss as well
            nsp_labels = torch.ones(*tokens_all.unsqueeze(0).shape[:-1]).long()
            nsp_labels[:, gt_option] = 0
            item['next_sentence_labels'] = nsp_labels

            item['round_id'] = torch.LongTensor([cur_rounds])
        else:
            if 'round_id' in dialog:
                item['round_id'] = torch.LongTensor([dialog['round_id']])
            else:
                item['round_id'] = torch.LongTensor([cur_rounds])

        # get image features
        if not self.config['dataloader_text_only']:
            features, num_boxes, boxes, _, image_target, image_edge_indexes, image_edge_attributes = self._image_features_reader[img_id]
            features, spatials, image_mask, image_target, image_label = encode_image_input(features, num_boxes, boxes, image_target, max_regions=self._max_region_num, mask_prob=0)
        else:
            features = spatials = image_mask = image_target = image_label = torch.tensor([0])
        item['image_feat'] = features
        item['image_loc'] = spatials
        item['image_mask'] = image_mask
        item['image_id'] = torch.LongTensor([img_id])
        item['tot_len'] = torch.LongTensor(tot_len_debug)

        item['image_edge_indices'] = [torch.from_numpy(image_edge_indexes).contiguous().long() for _ in range(num_options)]
        item['image_edge_attributes'] = [torch.from_numpy(image_edge_attributes).contiguous() for _ in range(num_options)]

        if self.config['stack_gr_data']:
            item['image_edge_indices'] = torch.stack(item['image_edge_indices'], dim=0)
            item['image_edge_attributes'] = torch.stack(item['image_edge_attributes'], dim=0)

        return item
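A comparable sketch for the dense-annotation dataset, e.g. when finetuning on relevance scores. Again, the config file path, its section name, and the graph-index-mapping keys are assumptions; the real config must provide 'num_options_dense', the dense annotation files, and the rlv_hst_* flags read above:

# Illustrative sketch only, not part of this commit.
import pyhocon
import torch.utils.data as tud

from dataloader.dataloader_visdial_dense import VisdialDenseDataset

# Hypothetical config file and section; substitute the project's actual config.
dense_config = pyhocon.ConfigFactory.parse_file('config/train.conf')['finetune_dense']

dense_dataset = VisdialDenseDataset(dense_config)
dense_dataset.split = 'train'

dense_loader = tud.DataLoader(
    dense_dataset,
    batch_size=2,
    shuffle=True,
    collate_fn=dense_dataset.collate_fn,
)

batch = next(iter(dense_loader))
# batch['gt_relevance'] holds the dense relevance scores of the candidate answers,
# and batch['next_sentence_labels'] marks the ground-truth option with 0 (all others are 1),
# matching the fields assembled in __getitem__ above.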