import torch
from torch.utils import data
import json
import os
import glog as log
import pickle
import torch.utils.data as tud
from pytorch_transformers.tokenization_bert import BertTokenizer

from utils.image_features_reader import ImageFeaturesH5Reader


class DatasetBase(data.Dataset):

    def __init__(self, config):

        if config['display']:
            log.info('Initializing dataset')

        # Fetch the correct dataset for evaluation
        if config['validating']:
            assert config.eval_dataset in ['visdial', 'visdial_conv', 'visdial_vispro', 'visdial_v09']
            if config.eval_dataset == 'visdial_conv':
                config['visdial_val'] = config.visdialconv_val
                config['visdial_val_dense_annotations'] = config.visdialconv_val_dense_annotations
            elif config.eval_dataset == 'visdial_vispro':
                config['visdial_val'] = config.visdialvispro_val
                config['visdial_val_dense_annotations'] = config.visdialvispro_val_dense_annotations
            elif config.eval_dataset == 'visdial_v09':
                config['visdial_val_09'] = config.visdial_test_09
                config['visdial_val_dense_annotations'] = None

        self.config = config
        self.numDataPoints = {}

        if not config['dataloader_text_only']:
            self._image_features_reader = ImageFeaturesH5Reader(
                config['visdial_image_feats'],
                config['visdial_image_adj_matrices']
            )

        if self.config['training'] or self.config['validating'] or self.config['predicting']:
            split2data = {'train': 'train', 'val': 'val', 'test': 'test'}
        elif self.config['debugging']:
            split2data = {'train': 'val', 'val': 'val', 'test': 'test'}
        elif self.config['visualizing']:
            split2data = {'train': 'train', 'val': 'train', 'test': 'test'}

        filename = f'visdial_{split2data["train"]}'
        if config['train_on_dense']:
            filename += '_dense'
        if self.config['visdial_version'] == 0.9:
            filename += '_09'

        with open(config[filename]) as f:
            self.visdial_data_train = json.load(f)
            if self.config.num_samples > 0:
                self.visdial_data_train['data']['dialogs'] = \
                    self.visdial_data_train['data']['dialogs'][:self.config.num_samples]
        self.numDataPoints['train'] = len(self.visdial_data_train['data']['dialogs'])

        filename = f'visdial_{split2data["val"]}'
        if config['train_on_dense'] and config['training']:
            filename += '_dense'
        if self.config['visdial_version'] == 0.9:
            filename += '_09'

        with open(config[filename]) as f:
            self.visdial_data_val = json.load(f)
            if self.config.num_samples > 0:
                self.visdial_data_val['data']['dialogs'] = \
                    self.visdial_data_val['data']['dialogs'][:self.config.num_samples]
        self.numDataPoints['val'] = len(self.visdial_data_val['data']['dialogs'])

        if config['train_on_dense']:
            self.numDataPoints['trainval'] = self.numDataPoints['train'] + self.numDataPoints['val']

        with open(config[f'visdial_{split2data["test"]}']) as f:
            self.visdial_data_test = json.load(f)
        self.numDataPoints['test'] = len(self.visdial_data_test['data']['dialogs'])

        self.rlv_hst_train = None
        self.rlv_hst_val = None
        self.rlv_hst_test = None

        if config['train_on_dense'] or config['predict_dense_round']:
            with open(config[f'visdial_{split2data["train"]}_dense_annotations']) as f:
                self.visdial_data_train_dense = json.load(f)

        if config['train_on_dense']:
            self.subsets = ['train', 'val', 'trainval', 'test']
        else:
            self.subsets = ['train', 'val', 'test']

        self.num_options = config["num_options"]
        self.num_options_dense = config["num_options_dense"]

        if config['visdial_version'] != 0.9:
            with open(config[f'visdial_{split2data["val"]}_dense_annotations']) as f:
                self.visdial_data_val_dense = json.load(f)
        else:
            self.visdial_data_val_dense = None

        self._split = 'train'
        self.tokenizer = BertTokenizer.from_pretrained(
            'bert-base-uncased', cache_dir=config['bert_cache_dir'])
        # fetch the token indices of [CLS], [MASK] and [SEP]
        tokens = ['[CLS]', '[MASK]', '[SEP]']
        indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokens)
        self.CLS = indexed_tokens[0]
        self.MASK = indexed_tokens[1]
        self.SEP = indexed_tokens[2]
        self._max_region_num = 37
        self.predict_each_round = self.config['predicting'] and self.config['predict_each_round']

        self.keys_to_expand = ['image_feat', 'image_loc', 'image_mask', 'image_target', 'image_label']
        self.keys_to_flatten_1d = ['hist_len', 'next_sentence_labels', 'round_id', 'image_id']
        self.keys_to_flatten_2d = ['tokens', 'segments', 'sep_indices', 'mask', 'image_mask', 'image_label',
                                   'input_mask', 'question_limits']
        self.keys_to_flatten_3d = ['image_feat', 'image_loc', 'image_target']
        self.keys_other = ['gt_relevance', 'gt_option_inds']
        self.keys_lists_to_flatten = ['image_edge_indices', 'image_edge_attributes', 'question_edge_indices',
                                      'question_edge_attributes', 'history_edge_indices', 'history_sep_indices']
        if config['stack_gr_data']:
            self.keys_to_flatten_3d.extend(self.keys_lists_to_flatten[:-1])
            self.keys_to_flatten_2d.append(self.keys_lists_to_flatten[-1])
            self.keys_to_flatten_1d.extend(['len_image_gr', 'len_question_gr', 'len_history_gr', 'len_history_sep'])
            self.keys_lists_to_flatten = []

        self.keys_to_list = ['tot_len']

        # Load the parse vocab for question graph relationship mapping
        if os.path.isfile(config['visdial_question_parse_vocab']):
            with open(config['visdial_question_parse_vocab'], 'rb') as f:
                self.parse_vocab = pickle.load(f)

    def __len__(self):
        return self.numDataPoints[self._split]

    @property
    def split(self):
        return self._split

    @split.setter
    def split(self, split):
        assert split in self.subsets
        self._split = split

    def tokens2str(self, seq):
        dialog_sequence = ''
        for sentence in seq:
            for word in sentence:
                dialog_sequence += self.tokenizer._convert_id_to_token(word) + " "
            dialog_sequence += ' '
        dialog_sequence = dialog_sequence.encode('utf8')
        return dialog_sequence

    def pruneRounds(self, context, num_rounds):
        start_segment = 1
        len_context = len(context)
        cur_rounds = (len(context) // 2) + 1
        l_index = 0
        if cur_rounds > num_rounds:
            # caption is not part of the final input
            l_index = len_context - (2 * num_rounds)
            start_segment = 0
        return context[l_index:], start_segment

    def tokenize_utterance(self, sent, sentences, tot_len, sentence_count, sentence_map, speakers):
        sentences.extend(sent + ['[SEP]'])
        tokenized_sent = self.tokenizer.convert_tokens_to_ids(sent)
        assert len(sent) == len(tokenized_sent), 'sub-word tokens are not allowed!'
        sent_len = len(tokenized_sent)
        tot_len += sent_len + 1  # the additional 1 is for the [SEP] token
        sentence_count += 1
        sentence_map.extend([sentence_count * 2 - 1] * sent_len)
        sentence_map.append(sentence_count * 2)  # for [SEP]
        speakers.extend([2] * (sent_len + 1))
        return tokenized_sent, sentences, tot_len, sentence_count, sentence_map, speakers

    def __getitem__(self, index):
        raise NotImplementedError

    def collate_fn(self, batch):
        tokens_size = batch[0]['tokens'].size()
        num_rounds, num_samples = tokens_size[0], tokens_size[1]
        merged_batch = {key: [d[key] for d in batch] for key in batch[0]}

        if self.config['stack_gr_data']:
            if len(batch) > 1:
                max_question_gr_len = max([length.max().item() for length in merged_batch['len_question_gr']])
                max_history_gr_len = max([length.max().item() for length in merged_batch['len_history_gr']])
                max_history_sep_len = max([length.max().item() for length in merged_batch['len_history_sep']])
                max_image_gr_len = max([length.max().item() for length in merged_batch['len_image_gr']])

                question_edge_indices_padded = []
                question_edge_attributes_padded = []
                for q_e_idx, q_e_attr in zip(merged_batch['question_edge_indices'],
                                             merged_batch['question_edge_attributes']):
                    b_size, edge_dim, orig_len = q_e_idx.size()
                    q_e_idx_padded = torch.zeros((b_size, edge_dim, max_question_gr_len))
                    q_e_idx_padded[:, :, :orig_len] = q_e_idx
                    question_edge_indices_padded.append(q_e_idx_padded)

                    edge_attr_dim = q_e_attr.size(-1)
                    q_e_attr_padded = torch.zeros((b_size, max_question_gr_len, edge_attr_dim))
                    q_e_attr_padded[:, :orig_len, :] = q_e_attr
                    question_edge_attributes_padded.append(q_e_attr_padded)

                merged_batch['question_edge_indices'] = question_edge_indices_padded
                merged_batch['question_edge_attributes'] = question_edge_attributes_padded

                history_edge_indices_padded = []
                for h_e_idx in merged_batch['history_edge_indices']:
                    b_size, _, orig_len = h_e_idx.size()
                    h_edge_idx_padded = torch.zeros((b_size, 2, max_history_gr_len))
                    h_edge_idx_padded[:, :, :orig_len] = h_e_idx
                    history_edge_indices_padded.append(h_edge_idx_padded)
                merged_batch['history_edge_indices'] = history_edge_indices_padded

                history_sep_indices_padded = []
                for hist_sep_idx in merged_batch['history_sep_indices']:
                    b_size, orig_len = hist_sep_idx.size()
                    hist_sep_idx_padded = torch.zeros((b_size, max_history_sep_len))
                    hist_sep_idx_padded[:, :orig_len] = hist_sep_idx
                    history_sep_indices_padded.append(hist_sep_idx_padded)
                merged_batch['history_sep_indices'] = history_sep_indices_padded

                image_edge_indices_padded = []
                image_edge_attributes_padded = []
                for img_e_idx, img_e_attr in zip(merged_batch['image_edge_indices'],
                                                 merged_batch['image_edge_attributes']):
                    b_size, edge_dim, orig_len = img_e_idx.size()
                    img_e_idx_padded = torch.zeros((b_size, edge_dim, max_image_gr_len))
                    img_e_idx_padded[:, :, :orig_len] = img_e_idx
                    image_edge_indices_padded.append(img_e_idx_padded)

                    edge_attr_dim = img_e_attr.size(-1)
                    img_e_attr_padded = torch.zeros((b_size, max_image_gr_len, edge_attr_dim))
                    img_e_attr_padded[:, :orig_len, :] = img_e_attr
                    image_edge_attributes_padded.append(img_e_attr_padded)

                merged_batch['image_edge_indices'] = image_edge_indices_padded
                merged_batch['image_edge_attributes'] = image_edge_attributes_padded

        out = {}
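        # Merge every field into the flat layout the model consumes:
        #   keys_lists_to_flatten     -> per-sample Python lists concatenated into one flat list
        #   keys_to_list              -> kept as-is, one entry per example
        #   keys_to_expand            -> image tensors broadcast across rounds and answer options
        #   keys_to_flatten_1d/2d/3d  -> stacked, then reshaped so the leading dimension
        #                                collapses batch, rounds, and options
        #   keys_other                -> stacked only (gt_relevance, gt_option_inds)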
        for key in merged_batch:
            if key in self.keys_lists_to_flatten:
                temp = []
                for b in merged_batch[key]:
                    for x in b:
                        temp.append(x)
                merged_batch[key] = temp
            elif key in self.keys_to_list:
                pass
            else:
                merged_batch[key] = torch.stack(merged_batch[key], 0)
                if key in self.keys_to_expand:
                    if len(merged_batch[key].size()) == 3:
                        size0, size1, size2 = merged_batch[key].size()
                        expand_size = (size0, num_rounds, num_samples, size1, size2)
                    elif len(merged_batch[key].size()) == 2:
                        size0, size1 = merged_batch[key].size()
                        expand_size = (size0, num_rounds, num_samples, size1)
                    merged_batch[key] = merged_batch[key].unsqueeze(1).unsqueeze(1).expand(expand_size).contiguous()
                if key in self.keys_to_flatten_1d:
                    merged_batch[key] = merged_batch[key].reshape(-1)
                elif key in self.keys_to_flatten_2d:
                    merged_batch[key] = merged_batch[key].reshape(-1, merged_batch[key].shape[-1])
                elif key in self.keys_to_flatten_3d:
                    merged_batch[key] = merged_batch[key].reshape(-1, merged_batch[key].shape[-2],
                                                                  merged_batch[key].shape[-1])
                else:
                    assert key in self.keys_other, f'unrecognized key in collate_fn: {key}'

            out[key] = merged_batch[key]

        return out
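
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). DatasetBase is abstract: a concrete
# subclass must implement __getitem__ before it can be used. The subclass
# name (VisdialDataset), the config loader, and the config keys referenced
# below are assumptions for illustration, not part of this module.
#
#     from config import load_config               # hypothetical helper
#
#     config = load_config('config/vdgr.conf')     # hypothetical path
#     dataset = VisdialDataset(config)             # concrete subclass of DatasetBase
#     dataset.split = 'train'                      # one of dataset.subsets
#     loader = tud.DataLoader(
#         dataset,
#         batch_size=config['batch_size'],
#         shuffle=True,
#         collate_fn=dataset.collate_fn,           # custom merge/flatten defined above
#     )
#     for batch in loader:
#         # e.g. batch['tokens'] has a flattened (batch * rounds * options, seq_len) shape
#         pass
# ---------------------------------------------------------------------------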