VDGR/dataloader/dataloader_base.py

import json
import os
import pickle

import glog as log
import torch
from torch.utils import data
import torch.utils.data as tud
from pytorch_transformers.tokenization_bert import BertTokenizer

from utils.image_features_reader import ImageFeaturesH5Reader


class DatasetBase(data.Dataset):
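    """Base class for the VDGR VisDial datasets.

    Loads the dialog JSON files (and, where needed, the dense relevance
    annotations), the pre-extracted image features, and the BERT tokenizer,
    and provides the shared collate_fn used by the concrete dataset
    subclasses.
    """
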
    def __init__(self, config):
        if config['display']:
            log.info('Initializing dataset')

        # Fetch the correct dataset for evaluation
        if config['validating']:
            assert config.eval_dataset in ['visdial', 'visdial_conv', 'visdial_vispro', 'visdial_v09']
            if config.eval_dataset == 'visdial_conv':
                config['visdial_val'] = config.visdialconv_val
                config['visdial_val_dense_annotations'] = config.visdialconv_val_dense_annotations
            elif config.eval_dataset == 'visdial_vispro':
                config['visdial_val'] = config.visdialvispro_val
                config['visdial_val_dense_annotations'] = config.visdialvispro_val_dense_annotations
            elif config.eval_dataset == 'visdial_v09':
                config['visdial_val_09'] = config.visdial_test_09
                config['visdial_val_dense_annotations'] = None

        self.config = config
        self.numDataPoints = {}

        if not config['dataloader_text_only']:
            self._image_features_reader = ImageFeaturesH5Reader(
                config['visdial_image_feats'],
                config['visdial_image_adj_matrices']
            )

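        # split2data maps each logical split to the JSON file that backs it,
        # e.g. in debugging mode the smaller val file also stands in for train.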
        if self.config['training'] or self.config['validating'] or self.config['predicting']:
            split2data = {'train': 'train', 'val': 'val', 'test': 'test'}
        elif self.config['debugging']:
            split2data = {'train': 'val', 'val': 'val', 'test': 'test'}
        elif self.config['visualizing']:
            split2data = {'train': 'train', 'val': 'train', 'test': 'test'}

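        # Resolve the config key of each dialog file (optionally the
        # dense-annotation and/or v0.9 variant) and load it.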
        filename = f'visdial_{split2data["train"]}'
        if config['train_on_dense']:
            filename += '_dense'
        if self.config['visdial_version'] == 0.9:
            filename += '_09'
        with open(config[filename]) as f:
            self.visdial_data_train = json.load(f)
        if self.config.num_samples > 0:
            self.visdial_data_train['data']['dialogs'] = self.visdial_data_train['data']['dialogs'][:self.config.num_samples]
        self.numDataPoints['train'] = len(self.visdial_data_train['data']['dialogs'])

        filename = f'visdial_{split2data["val"]}'
        if config['train_on_dense'] and config['training']:
            filename += '_dense'
        if self.config['visdial_version'] == 0.9:
            filename += '_09'
        with open(config[filename]) as f:
            self.visdial_data_val = json.load(f)
        if self.config.num_samples > 0:
            self.visdial_data_val['data']['dialogs'] = self.visdial_data_val['data']['dialogs'][:self.config.num_samples]
        self.numDataPoints['val'] = len(self.visdial_data_val['data']['dialogs'])

        if config['train_on_dense']:
            self.numDataPoints['trainval'] = self.numDataPoints['train'] + self.numDataPoints['val']

        with open(config[f'visdial_{split2data["test"]}']) as f:
            self.visdial_data_test = json.load(f)
        self.numDataPoints['test'] = len(self.visdial_data_test['data']['dialogs'])
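
        # Dense relevance annotations are only required when training on
        # dense labels or when predicting per-round dense scores.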
        self.rlv_hst_train = None
        self.rlv_hst_val = None
        self.rlv_hst_test = None

        if config['train_on_dense'] or config['predict_dense_round']:
            with open(config[f'visdial_{split2data["train"]}_dense_annotations']) as f:
                self.visdial_data_train_dense = json.load(f)

        if config['train_on_dense']:
            self.subsets = ['train', 'val', 'trainval', 'test']
        else:
            self.subsets = ['train', 'val', 'test']

        self.num_options = config["num_options"]
        self.num_options_dense = config["num_options_dense"]

        if config['visdial_version'] != 0.9:
            with open(config[f'visdial_{split2data["val"]}_dense_annotations']) as f:
                self.visdial_data_val_dense = json.load(f)
        else:
            self.visdial_data_val_dense = None

        self._split = 'train'
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', cache_dir=config['bert_cache_dir'])

        # fetch the token indices of [CLS], [MASK] and [SEP]
        tokens = ['[CLS]', '[MASK]', '[SEP]']
        indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokens)
        self.CLS = indexed_tokens[0]
        self.MASK = indexed_tokens[1]
        self.SEP = indexed_tokens[2]

        self._max_region_num = 37
        self.predict_each_round = self.config['predicting'] and self.config['predict_each_round']
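
        # The key groups below tell collate_fn how to treat each batch entry:
        # expand across rounds/options, flatten to 1d/2d/3d tensors, flatten
        # plain Python lists, keep as lists, or pass through unchanged.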
        self.keys_to_expand = ['image_feat', 'image_loc', 'image_mask', 'image_target', 'image_label']
        self.keys_to_flatten_1d = ['hist_len', 'next_sentence_labels', 'round_id', 'image_id']
        self.keys_to_flatten_2d = ['tokens', 'segments', 'sep_indices', 'mask', 'image_mask', 'image_label', 'input_mask', 'question_limits']
        self.keys_to_flatten_3d = ['image_feat', 'image_loc', 'image_target']
        self.keys_other = ['gt_relevance', 'gt_option_inds']
        self.keys_lists_to_flatten = ['image_edge_indices', 'image_edge_attributes', 'question_edge_indices', 'question_edge_attributes', 'history_edge_indices', 'history_sep_indices']
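
        # When the graph data is stacked (i.e. padded to a common length in
        # collate_fn), the graph tensors are flattened like the other dense
        # tensors instead of being kept as Python lists.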
        if config['stack_gr_data']:
            self.keys_to_flatten_3d.extend(self.keys_lists_to_flatten[:-1])
            self.keys_to_flatten_2d.append(self.keys_lists_to_flatten[-1])
            self.keys_to_flatten_1d.extend(['len_image_gr', 'len_question_gr', 'len_history_gr', 'len_history_sep'])
            self.keys_lists_to_flatten = []

        self.keys_to_list = ['tot_len']

        # Load the parse vocab for question graph relationship mapping
        if os.path.isfile(config['visdial_question_parse_vocab']):
            with open(config['visdial_question_parse_vocab'], 'rb') as f:
                self.parse_vocab = pickle.load(f)

    def __len__(self):
        return self.numDataPoints[self._split]

    @property
    def split(self):
        return self._split

    @split.setter
    def split(self, split):
        assert split in self.subsets
        self._split = split

    def tokens2str(self, seq):
        dialog_sequence = ''
        for sentence in seq:
            for word in sentence:
                dialog_sequence += self.tokenizer._convert_id_to_token(word) + " "
            dialog_sequence += ' </end> '
        dialog_sequence = dialog_sequence.encode('utf8')
        return dialog_sequence

    def pruneRounds(self, context, num_rounds):
        start_segment = 1
        len_context = len(context)
        cur_rounds = (len(context) // 2) + 1
        l_index = 0
        if cur_rounds > num_rounds:
            # caption is not part of the final input
            l_index = len_context - (2 * num_rounds)
            start_segment = 0
        return context[l_index:], start_segment
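
    # tokenize_utterance appends `sent` plus a trailing [SEP] to the running
    # transcript and keeps the sentence map and speaker ids in sync; the
    # assert enforces that no word was split into sub-word tokens.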
    def tokenize_utterance(self, sent, sentences, tot_len, sentence_count, sentence_map, speakers):
        sentences.extend(sent + ['[SEP]'])
        tokenized_sent = self.tokenizer.convert_tokens_to_ids(sent)
        assert len(sent) == len(tokenized_sent), 'sub-word tokens are not allowed!'

        sent_len = len(tokenized_sent)
        tot_len += sent_len + 1  # the additional 1 is for the sep token
        sentence_count += 1
        sentence_map.extend([sentence_count * 2 - 1] * sent_len)
        sentence_map.append(sentence_count * 2)  # for [SEP]
        speakers.extend([2] * (sent_len + 1))
        return tokenized_sent, sentences, tot_len, sentence_count, sentence_map, speakers

    def __getitem__(self, index):
        # Implemented by the concrete dataset subclasses.
        raise NotImplementedError

    def collate_fn(self, batch):
        tokens_size = batch[0]['tokens'].size()
        num_rounds, num_samples = tokens_size[0], tokens_size[1]
        merged_batch = {key: [d[key] for d in batch] for key in batch[0]}
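
        # merged_batch: key -> list with one entry (tensor or list) per sample.
        # If the graph data is stacked, zero-pad the per-sample graph tensors
        # to the maximum length in the batch so they can be stacked later.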
        if self.config['stack_gr_data']:
            if len(batch) > 1:
                max_question_gr_len = max([length.max().item() for length in merged_batch['len_question_gr']])
                max_history_gr_len = max([length.max().item() for length in merged_batch['len_history_gr']])
                max_history_sep_len = max([length.max().item() for length in merged_batch['len_history_sep']])
                max_image_gr_len = max([length.max().item() for length in merged_batch['len_image_gr']])
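
                # Pad question-graph edge indices and edge attributes up to
                # the longest edge list in the batch.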
                question_edge_indices_padded = []
                question_edge_attributes_padded = []
                for q_e_idx, q_e_attr in zip(merged_batch['question_edge_indices'], merged_batch['question_edge_attributes']):
                    b_size, edge_dim, orig_len = q_e_idx.size()
                    q_e_idx_padded = torch.zeros((b_size, edge_dim, max_question_gr_len))
                    q_e_idx_padded[:, :, :orig_len] = q_e_idx
                    question_edge_indices_padded.append(q_e_idx_padded)

                    edge_attr_dim = q_e_attr.size(-1)
                    q_e_attr_padded = torch.zeros((b_size, max_question_gr_len, edge_attr_dim))
                    q_e_attr_padded[:, :orig_len, :] = q_e_attr
                    question_edge_attributes_padded.append(q_e_attr_padded)

                merged_batch['question_edge_indices'] = question_edge_indices_padded
                merged_batch['question_edge_attributes'] = question_edge_attributes_padded

                history_edge_indices_padded = []
                for h_e_idx in merged_batch['history_edge_indices']:
                    b_size, _, orig_len = h_e_idx.size()
                    h_edge_idx_padded = torch.zeros((b_size, 2, max_history_gr_len))
                    h_edge_idx_padded[:, :, :orig_len] = h_e_idx
                    history_edge_indices_padded.append(h_edge_idx_padded)
                merged_batch['history_edge_indices'] = history_edge_indices_padded

                history_sep_indices_padded = []
                for hist_sep_idx in merged_batch['history_sep_indices']:
                    b_size, orig_len = hist_sep_idx.size()
                    hist_sep_idx_padded = torch.zeros((b_size, max_history_sep_len))
                    hist_sep_idx_padded[:, :orig_len] = hist_sep_idx
                    history_sep_indices_padded.append(hist_sep_idx_padded)
                merged_batch['history_sep_indices'] = history_sep_indices_padded

                image_edge_indices_padded = []
                image_edge_attributes_padded = []
                for img_e_idx, img_e_attr in zip(merged_batch['image_edge_indices'], merged_batch['image_edge_attributes']):
                    b_size, edge_dim, orig_len = img_e_idx.size()
                    img_e_idx_padded = torch.zeros((b_size, edge_dim, max_image_gr_len))
                    img_e_idx_padded[:, :, :orig_len] = img_e_idx
                    image_edge_indices_padded.append(img_e_idx_padded)

                    edge_attr_dim = img_e_attr.size(-1)
                    img_e_attr_padded = torch.zeros((b_size, max_image_gr_len, edge_attr_dim))
                    img_e_attr_padded[:, :orig_len, :] = img_e_attr
                    image_edge_attributes_padded.append(img_e_attr_padded)

                merged_batch['image_edge_indices'] = image_edge_indices_padded
                merged_batch['image_edge_attributes'] = image_edge_attributes_padded
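
        # Stack or flatten every key according to the groups defined in
        # __init__, expanding the image keys across rounds and answer options.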
        out = {}
        for key in merged_batch:
            if key in self.keys_lists_to_flatten:
                temp = []
                for b in merged_batch[key]:
                    for x in b:
                        temp.append(x)
                merged_batch[key] = temp
            elif key in self.keys_to_list:
                pass
            else:
                merged_batch[key] = torch.stack(merged_batch[key], 0)
                if key in self.keys_to_expand:
                    if len(merged_batch[key].size()) == 3:
                        size0, size1, size2 = merged_batch[key].size()
                        expand_size = (size0, num_rounds, num_samples, size1, size2)
                    elif len(merged_batch[key].size()) == 2:
                        size0, size1 = merged_batch[key].size()
                        expand_size = (size0, num_rounds, num_samples, size1)
                    merged_batch[key] = merged_batch[key].unsqueeze(1).unsqueeze(1).expand(expand_size).contiguous()
                if key in self.keys_to_flatten_1d:
                    merged_batch[key] = merged_batch[key].reshape(-1)
                elif key in self.keys_to_flatten_2d:
                    merged_batch[key] = merged_batch[key].reshape(-1, merged_batch[key].shape[-1])
                elif key in self.keys_to_flatten_3d:
                    merged_batch[key] = merged_batch[key].reshape(-1, merged_batch[key].shape[-2], merged_batch[key].shape[-1])
                else:
                    assert key in self.keys_other, f'unrecognized key in collate_fn: {key}'
            out[key] = merged_batch[key]
        return out
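
    # Sketch of the intended usage (assuming a concrete subclass, here called
    # `VisdialDataset`, that implements __getitem__):
    #
    #   dataset = VisdialDataset(config)
    #   dataset.split = 'train'
    #   loader = tud.DataLoader(dataset, batch_size=8, shuffle=True,
    #                           collate_fn=dataset.collate_fn)
    #   batch = next(iter(loader))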