VDGR/dataloader/dataloader_visdial_dense.py

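"""Dataloader for VisDial dialogs with dense (relevance) annotations.

Each item pairs the tokenized dialog context of the annotated round with every candidate
answer and bundles the pre-computed question/history/image graph structures plus, outside
the test split, the dense relevance targets.
"""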
import torch
import json
import os
import time
import numpy as np
import random
from tqdm import tqdm
import copy
import pyhocon
import glog as log
from collections import OrderedDict
import argparse
import pickle
import torch.utils.data as tud
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from utils.data_utils import encode_input, encode_image_input
from dataloader.dataloader_base import DatasetBase
class VisdialDenseDataset(DatasetBase):
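    """Dense-annotation VisDial dataset built on top of DatasetBase.

    __init__ loads the graph-index mappings and resolves the per-split directories of the
    pre-computed question and history adjacency matrices; __getitem__ assembles one fully
    encoded sample per index.
    """
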
    def __init__(self, config):
        super(VisdialDenseDataset, self).__init__(config)
        with open(config.tr_graph_idx_mapping, 'r') as f:
            self.tr_graph_idx_mapping = json.load(f)
        with open(config.val_graph_idx_mapping, 'r') as f:
            self.val_graph_idx_mapping = json.load(f)
        with open(config.test_graph_idx_mapping, 'r') as f:
            self.test_graph_idx_mapping = json.load(f)

        self.question_gr_paths = {
            'train': os.path.join(self.config['visdial_question_adj_matrices'], 'train'),
            'val': os.path.join(self.config['visdial_question_adj_matrices'], 'val'),
            'test': os.path.join(self.config['visdial_question_adj_matrices'], 'test')
        }
        self.history_gr_paths = {
            'train': os.path.join(self.config['visdial_history_adj_matrices'], 'train'),
            'val': os.path.join(self.config['visdial_history_adj_matrices'], 'val'),
            'test': os.path.join(self.config['visdial_history_adj_matrices'], 'test')
        }

    def __getitem__(self, index):
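        """Build one dense-annotation sample and return it as a dict of tensors."""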
        MAX_SEQ_LEN = self.config['max_seq_len']
        cur_data = None
        cur_dense_annotations = None
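
        # Pick the data, dense annotations, graph-index mapping and graph paths of the current
        # split. For 'trainval', indices beyond the train size fall back to the val split.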
        if self._split == 'train':
            cur_data = self.visdial_data_train['data']
            cur_dense_annotations = self.visdial_data_train_dense
            cur_question_gr_path = self.question_gr_paths['train']
            cur_history_gr_path = self.history_gr_paths['train']
            cur_gr_mapping = self.tr_graph_idx_mapping
            if self.config['rlv_hst_only']:
                cur_rlv_hst = self.rlv_hst_train
        elif self._split == 'val':
            cur_data = self.visdial_data_val['data']
            cur_dense_annotations = self.visdial_data_val_dense
            cur_question_gr_path = self.question_gr_paths['val']
            cur_history_gr_path = self.history_gr_paths['val']
            cur_gr_mapping = self.val_graph_idx_mapping
            if self.config['rlv_hst_only']:
                cur_rlv_hst = self.rlv_hst_val
        elif self._split == 'trainval':
            if index >= self.numDataPoints['train']:
                cur_data = self.visdial_data_val['data']
                cur_dense_annotations = self.visdial_data_val_dense
                cur_gr_mapping = self.val_graph_idx_mapping
                index -= self.numDataPoints['train']
                cur_question_gr_path = self.question_gr_paths['val']
                cur_history_gr_path = self.history_gr_paths['val']
                if self.config['rlv_hst_only']:
                    cur_rlv_hst = self.rlv_hst_val
            else:
                cur_data = self.visdial_data_train['data']
                cur_dense_annotations = self.visdial_data_train_dense
                cur_question_gr_path = self.question_gr_paths['train']
                cur_gr_mapping = self.tr_graph_idx_mapping
                cur_history_gr_path = self.history_gr_paths['train']
                if self.config['rlv_hst_only']:
                    cur_rlv_hst = self.rlv_hst_train
        elif self._split == 'test':
            cur_data = self.visdial_data_test['data']
            cur_question_gr_path = self.question_gr_paths['test']
            cur_history_gr_path = self.history_gr_paths['test']
            if self.config['rlv_hst_only']:
                cur_rlv_hst = self.rlv_hst_test

        # number of options to score on
        num_options = self.num_options_dense
        if self._split == 'test' or self.config['validating'] or self.config['predicting']:
            assert num_options == 100
        else:
            assert num_options >= 1 and num_options <= 100
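
        # Fetch the dialog for this index; graph files are addressed via the image-id mapping
        # for train/val and simply by position for the test split.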
        dialog = cur_data['dialogs'][index]
        cur_questions = cur_data['questions']
        cur_answers = cur_data['answers']
        img_id = dialog['image_id']
        if self._split != 'test':
            graph_idx = cur_gr_mapping[str(img_id)]
        else:
            graph_idx = index

        if self._split != 'test':
            assert img_id == cur_dense_annotations[index]['image_id']

        if self.config['rlv_hst_only']:
            rlv_hst = cur_rlv_hst[str(img_id)]  # [10 per round, 10 for cap + first 9 rounds]

        if self._split == 'test':
            cur_rounds = len(dialog['dialog'])  # 1, 2, ..., 10
        else:
            cur_rounds = cur_dense_annotations[index]['round_id']  # 1, 2, ..., 10
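
        # Build the shared textual context: the caption, the (question, answer) pairs of the
        # earlier rounds and the question of the current round. With rlv_hst_only, utterances
        # marked as irrelevant in rlv_hst are skipped.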
        # caption
        cur_rnd_utterance = []
        include_caption = True
        if self.config['rlv_hst_only']:
            if self.config['rlv_hst_dense_round']:
                if rlv_hst[0] == 0:
                    include_caption = False
            elif rlv_hst[cur_rounds - 1][0] == 0:
                include_caption = False
        if include_caption:
            sent = dialog['caption'].split(' ')
            tokenized_sent = self.tokenizer.convert_tokens_to_ids(sent)
            cur_rnd_utterance.append(tokenized_sent)
            # tot_len += len(sent) + 1

        for rnd, utterance in enumerate(dialog['dialog'][:cur_rounds]):
            if self.config['rlv_hst_only'] and rnd < cur_rounds - 1:
                if self.config['rlv_hst_dense_round']:
                    if rlv_hst[rnd + 1] == 0:
                        continue
                elif rlv_hst[cur_rounds - 1][rnd + 1] == 0:
                    continue
            # question
            sent = cur_questions[utterance['question']].split(' ')
            tokenized_sent = self.tokenizer.convert_tokens_to_ids(sent)
            cur_rnd_utterance.append(tokenized_sent)
            # answer
            if rnd != cur_rounds - 1:
                sent = cur_answers[utterance['answer']].split(' ')
                tokenized_sent = self.tokenizer.convert_tokens_to_ids(sent)
                cur_rnd_utterance.append(tokenized_sent)

        if self.config['rlv_hst_only']:
            num_rlv_rnds = len(cur_rnd_utterance) - 1
        else:
            num_rlv_rnds = None
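
        # Select the candidate answers to score. During training/debugging the ground-truth
        # answer is placed first and (num_options - 1) distractors are sampled, so gt_option
        # becomes index 0; otherwise all 100 candidates are kept in their original order.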
        if self._split != 'test':
            gt_option = dialog['dialog'][cur_rounds - 1]['gt_index']
            if self.config['training'] or self.config['debugging']:
                # first select gt option id, then choose the first num_options inds
                option_inds = []
                option_inds.append(gt_option)
                all_inds = list(range(100))
                all_inds.remove(gt_option)
                # debug
                if num_options < 100:
                    random.shuffle(all_inds)
                    all_inds = all_inds[:(num_options - 1)]
                option_inds.extend(all_inds)
                gt_option = 0
            else:
                option_inds = range(num_options)
            answer_options = [dialog['dialog'][cur_rounds - 1]['answer_options'][k] for k in option_inds]
            if 'relevance' in cur_dense_annotations[index]:
                key = 'relevance'
            else:
                key = 'gt_relevance'
            gt_relevance = torch.Tensor(cur_dense_annotations[index][key])
            gt_relevance = gt_relevance[option_inds]
            assert len(answer_options) == len(option_inds) == num_options
        else:
            answer_options = dialog['dialog'][-1]['answer_options']
            assert len(answer_options) == num_options
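
        # Pair the shared context with every candidate answer.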
        options_all = []
        for answer_option in answer_options:
            cur_option = cur_rnd_utterance.copy()
            cur_option.append(self.tokenizer.convert_tokens_to_ids(cur_answers[answer_option].split(' ')))
            options_all.append(cur_option)
            if not self.config['rlv_hst_only']:
                assert len(cur_option) == 2 * cur_rounds + 1

        tokens_all = []
        mask_all = []
        segments_all = []
        sep_indices_all = []
        hist_len_all = []
        tot_len_debug = []
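
        # Encode every option (context + candidate answer) into BERT-style inputs;
        # mask_prob=0 disables masked-language-model corruption here.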
        for opt_id, option in enumerate(options_all):
            option, start_segment = self.pruneRounds(option, self.config['visdial_tot_rounds'])
            tokens, segments, sep_indices, mask, start_question, end_question = encode_input(
                option, start_segment, self.CLS, self.SEP, self.MASK,
                max_seq_len=MAX_SEQ_LEN, mask_prob=0)

            tokens_all.append(tokens)
            mask_all.append(mask)
            segments_all.append(segments)
            sep_indices_all.append(sep_indices)
            hist_len_all.append(torch.LongTensor([len(option) - 1]))
            len_tokens = sum(len(s) for s in option)
            tot_len_debug.append(len_tokens + len(option) + 1)

        tokens_all = torch.cat(tokens_all, 0)
        mask_all = torch.cat(mask_all, 0)
        segments_all = torch.cat(segments_all, 0)
        sep_indices_all = torch.cat(sep_indices_all, 0)
        hist_len_all = torch.cat(hist_len_all, 0)
        question_limits_all = torch.tensor([start_question, end_question]).unsqueeze(0).repeat(num_options, 1)
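
        # Positions of the SEP tokens that delimit the history utterances, taken from the last
        # encoded option (all options share the same context prefix) and repeated per option.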
        if self.config['rlv_hst_only']:
            assert num_rlv_rnds > 0
            hist_idx = [i * 2 for i in range(num_rlv_rnds)]
        else:
            hist_idx = [i * 2 for i in range(cur_rounds)]
        history_sep_indices_all = sep_indices.squeeze(0)[hist_idx].contiguous().unsqueeze(0).repeat(num_options, 1)
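
        # Load the pre-computed question graph of the current round and one-hot encode its
        # edge labels via self.parse_vocab (unknown labels go to the extra last dimension).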
        with open(os.path.join(cur_question_gr_path, f'{graph_idx}.pkl'), 'rb') as f:
            question_graphs = pickle.load(f)
        question_graph_round = question_graphs[cur_rounds - 1]
        question_edge_index = []
        question_edge_attribute = []
        for edge_index, edge_attr in question_graph_round:
            question_edge_index.append(edge_index)
            edge_attr_one_hot = np.zeros((len(self.parse_vocab) + 1,), dtype=np.float32)
            edge_attr_one_hot[self.parse_vocab.get(edge_attr, len(self.parse_vocab))] = 1.0
            question_edge_attribute.append(edge_attr_one_hot)
        question_edge_index = np.array(question_edge_index, dtype=np.float64)
        question_edge_attribute = np.stack(question_edge_attribute, axis=0)

        question_edge_indices_all = [
            torch.from_numpy(question_edge_index).t().long().contiguous() for _ in range(num_options)]
        question_edge_attributes_all = [
            torch.from_numpy(question_edge_attribute).contiguous() for _ in range(num_options)]
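
        # Load the pre-computed history graph edges; with rlv_hst_only the pickle already holds
        # the edges of the dense round, otherwise the round is indexed from the full list.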
        if self.config['rlv_hst_only']:
            with open(os.path.join(cur_history_gr_path, f'{graph_idx}.pkl'), 'rb') as f:
                _history_edge_incides_round = pickle.load(f)
        else:
            with open(os.path.join(cur_history_gr_path, f'{graph_idx}.pkl'), 'rb') as f:
                _history_edge_incides_all = pickle.load(f)
            _history_edge_incides_round = _history_edge_incides_all[cur_rounds - 1]

        history_edge_index_all = [
            torch.tensor(_history_edge_incides_round).t().long().contiguous() for _ in range(num_options)]
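
        # Optionally stack the per-option graph tensors into single tensors (torch.stack
        # requires identical shapes across options).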
        if self.config['stack_gr_data']:
            question_edge_indices_all = torch.stack(question_edge_indices_all, dim=0)
            question_edge_attributes_all = torch.stack(question_edge_attributes_all, dim=0)
            history_edge_index_all = torch.stack(history_edge_index_all, dim=0)
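
        # Assemble the sample dict.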
        item = {}
        item['tokens'] = tokens_all.unsqueeze(0)  # [1, num_options, max_len]
        item['segments'] = segments_all.unsqueeze(0)
        item['sep_indices'] = sep_indices_all.unsqueeze(0)
        item['mask'] = mask_all.unsqueeze(0)
        item['hist_len'] = hist_len_all.unsqueeze(0)
        item['question_limits'] = question_limits_all
        item['question_edge_indices'] = question_edge_indices_all
        item['question_edge_attributes'] = question_edge_attributes_all
        item['history_edge_indices'] = history_edge_index_all
        item['history_sep_indices'] = history_sep_indices_all

        # add dense annotation fields
        if self._split != 'test':
            item['gt_relevance'] = gt_relevance  # [num_options]
            item['gt_option_inds'] = torch.LongTensor([gt_option])

            # add next sentence labels for training with the nsp loss as well
            nsp_labels = torch.ones(*tokens_all.unsqueeze(0).shape[:-1]).long()
            nsp_labels[:, gt_option] = 0
            item['next_sentence_labels'] = nsp_labels

            item['round_id'] = torch.LongTensor([cur_rounds])
        else:
            if 'round_id' in dialog:
                item['round_id'] = torch.LongTensor([dialog['round_id']])
            else:
                item['round_id'] = torch.LongTensor([cur_rounds])

        # get image features
        if not self.config['dataloader_text_only']:
            features, num_boxes, boxes, _, image_target, image_edge_indexes, image_edge_attributes = \
                self._image_features_reader[img_id]
            features, spatials, image_mask, image_target, image_label = encode_image_input(
                features, num_boxes, boxes, image_target,
                max_regions=self._max_region_num, mask_prob=0)
        else:
            features = spatials = image_mask = image_target = image_label = torch.tensor([0])

        item['image_feat'] = features
        item['image_loc'] = spatials
        item['image_mask'] = image_mask
        item['image_id'] = torch.LongTensor([img_id])
        item['tot_len'] = torch.LongTensor(tot_len_debug)
        item['image_edge_indices'] = [
            torch.from_numpy(image_edge_indexes).contiguous().long() for _ in range(num_options)]
        item['image_edge_attributes'] = [
            torch.from_numpy(image_edge_attributes).contiguous() for _ in range(num_options)]

        if self.config['stack_gr_data']:
            item['image_edge_indices'] = torch.stack(item['image_edge_indices'], dim=0)
            item['image_edge_attributes'] = torch.stack(item['image_edge_attributes'], dim=0)

        return item
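
# A minimal usage sketch (assumptions: `config` is the pyhocon config used elsewhere in this
# repo, containing the keys read above such as 'max_seq_len', 'rlv_hst_only',
# 'dataloader_text_only' and 'stack_gr_data', plus the data/graph/feature paths expected by
# DatasetBase; the way the split is set is hypothetical and may differ in DatasetBase):
#
#     dataset = VisdialDenseDataset(config)
#     dataset.split = 'val'  # hypothetical setter backing the `_split` field used above
#     loader = tud.DataLoader(dataset, batch_size=1, shuffle=False)
#     batch = next(iter(loader))
#     print(batch['tokens'].shape)  # e.g. [1, 1, num_options, max_seq_len]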