# MST-MIXER/custom_datasets/avsd.py
import os
import pickle
import pyhocon
from copy import deepcopy
import json
from tqdm import tqdm
import numpy as np
import torch
from argparse import ArgumentParser
from torch.utils.data import Dataset, DataLoader
from transformers import BartTokenizer
from itertools import chain
ADDITIONAL_SPECIAL_TOKENS = [
'<place_holder>', '<s0>', '<s1>', '<s2>', '<s3>', '<s4>', '<s5>']
SPECIAL_TOKENS_DICT = {
'bos_token': '<s>',
'eos_token': '</s>',
'pad_token': '<pad>',
'additional_special_tokens': ADDITIONAL_SPECIAL_TOKENS
}
S0_TOK = '<s0>' # I3D_flow
S1_TOK = '<s1>' # I3D_rgb
S2_TOK = '<s2>' # sam obj
S3_TOK = '<s3>' # audio
S4_TOK = '<s4>' # history
S5_TOK = '<s5>' # question
def tokenize(obj, tokenizer):
if isinstance(obj, str):
return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
if isinstance(obj, dict):
return dict((n, tokenize(o, tokenizer)) for n, o in obj.items())
return list(tokenize(o, tokenizer) for o in obj)
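# Illustrative usage (the exact ids depend on the BART vocabulary):
#   tokenize('is the man cooking ?', tokenizer)  -> flat list of token ids
#   tokenize([question, answer], tokenizer)      -> list of per-element id lists (applied recursively)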
class AVSDDataset(Dataset):
def __init__(self, config, split):
super().__init__()
self.config = config
self.split = split
self.bart_max_input_len = config['bart_max_input_len']
self.bart_size = config['bart_size']
self.cap_sum = config['cap_sum']
self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-{}'.format(self.bart_size))
self.vocab_size = self.tokenizer.vocab_size
self.tokenizer.add_special_tokens({'additional_special_tokens': ADDITIONAL_SPECIAL_TOKENS})
self.vocab_size += len(ADDITIONAL_SPECIAL_TOKENS)
self.tokenizer.save_pretrained(os.path.join(self.config['log_dir'], 'bart_tokenizer'))
self.processed_dir = os.path.join(self.config['avsd_processed'], 'hist_with_{}_rounds'.format(self.config['n_history']), split)
self.paths = list(map(lambda p: os.path.join(self.processed_dir, p), os.listdir(self.processed_dir)))
if self.config['overfit'] > 0:
self.paths = self.paths[:self.config['overfit_size']]
def __len__(self):
return len(self.paths)
def __getitem__(self, index):
pth = self.paths[index]
with open(pth, 'rb') as f:
item = pickle.load(f)
question_sep = self.tokenizer.convert_tokens_to_ids('<s5>')
input_ids = item['input_ids']
history_end = (input_ids == question_sep).nonzero(as_tuple=True)[0]
history_interval = [0, history_end.item()] # The last token is the question state token (not part of the history)
question_interval = [history_end.item(), input_ids.size(0)]
lm_labels = item['lm_labels']
i3d_rgb = item['i3d_rgb']
i3d_flow = item['i3d_flow']
sam = item['sam']
vgg = item['vgg']
vid = item['vid']
return input_ids, lm_labels, history_interval, question_interval, i3d_rgb, i3d_flow, sam, vgg, vid
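# Note: i3d_rgb / i3d_flow / sam / vgg are returned as raw numpy arrays of variable temporal
# length; they are only subsampled to a common length and converted to tensors in collate_fn.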
def padding(self, seq, pad_token, max_len=None):
if max_len is None:
max_len = max([i.size(0) for i in seq])
if len(seq[0].size()) == 1:
result = torch.ones((len(seq), max_len)).long() * pad_token
else:
result = torch.ones((len(seq), max_len, seq[0].size(-1))).float()
for i in range(len(seq)):
result[i, :seq[i].size(0)] = seq[i]
return result
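# Illustrative behaviour of padding() (assuming two 1-D id tensors of lengths 3 and 5):
#   padding([t1, t2], pad_token=1) -> LongTensor of shape (2, 5), first row right-padded with 1s.
# For 2-D feature tensors the result is a float tensor padded with ones instead.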
def collate_fn(self, batch):
input_ids_list, lm_labels_list, history_interval_list, question_interval_list, i3d_rgb_list, i3d_flow_list, sam_list, vggish_list, vid_ids_list = [], [], [], [], [], [], [], [], []
for i in batch:
input_ids_list.append(i[0])
lm_labels_list.append(i[1])
history_interval_list.append(i[2])
question_interval_list.append(i[3])
i3d_rgb_list.append(i[4])
i3d_flow_list.append(i[5])
sam_list.append(i[6])
vggish_list.append(i[7])
vid_ids_list.append(i[8])
history_intervals = np.array(history_interval_list)
question_intervals = np.array(question_interval_list)
min_len_i3d_flow = min([feat.shape[0] for feat in i3d_flow_list])
min_len_i3d_rgb = min([feat.shape[0] for feat in i3d_rgb_list])
min_len_sam = min([feat.shape[0] for feat in sam_list])
min_len_vggish = min([feat.shape[0] for feat in vggish_list])
min_length = min([self.config['vis_feat_length'], min_len_i3d_flow, min_len_i3d_rgb, min_len_sam, min_len_vggish])
# Subsample min_length equidistantly spaced time steps from each visual/audio feature stream for every sample in the batch
for i in range(len(i3d_rgb_list)):
sample_idx_i3d_rgb = np.round(np.linspace(0, i3d_rgb_list[i].shape[0] - 1, min_length)).astype(int)
i3d_rgb_list[i] = i3d_rgb_list[i][sample_idx_i3d_rgb, :]
i3d_rgb = torch.from_numpy(np.array(i3d_rgb_list)).float()
for i in range(len(i3d_flow_list)):
sample_idx_i3d_flow = np.round(np.linspace(0, i3d_flow_list[i].shape[0] - 1, min_length)).astype(int)
i3d_flow_list[i] = i3d_flow_list[i][sample_idx_i3d_flow, :]
i3d_flow = torch.from_numpy(np.array(i3d_flow_list)).float()
for i in range(len(sam_list)):
sample_idx_sam = np.round(np.linspace(0, sam_list[i].shape[0] - 1, min_length)).astype(int)
sam_list[i] = sam_list[i][sample_idx_sam, :]
sam = torch.from_numpy(np.array(sam_list)).float()
for i in range(len(vggish_list)):
sample_idx_vggish = np.round(np.linspace(0, vggish_list[i].shape[0] - 1, min_length)).astype(int)
vggish_list[i] = vggish_list[i][sample_idx_vggish, :]
vggish = torch.from_numpy(np.array(vggish_list)).float()
pad_token, i3d_flow_sep, i3d_rgb_sep, sam_sep, audio_sep, ph_token = self.tokenizer.convert_tokens_to_ids(
['<pad>', '<s0>', '<s1>', '<s2>', '<s3>', '<place_holder>'])
# None of the visual feature positions are masked, since the visual features are never padded
video_mask = torch.ones((len(batch), min_length*4 + 4)) == 1 # NOTE *4: 4 modalities | +4: the state tokens
# Create dummy placeholder ids for the video tokens; together with the separator (state) tokens they only reserve positions in the input
dummy = torch.ones((len(batch), min_length)) * ph_token
video_place_holder_ids = torch.cat(
[torch.ones((len(batch), 1)) * i3d_rgb_sep, dummy,
torch.ones((len(batch), 1)) * i3d_flow_sep, dummy,
torch.ones((len(batch), 1)) * sam_sep, dummy,
torch.ones((len(batch), 1)) * audio_sep, dummy,
], dim=-1).long()
input_ids = self.padding(input_ids_list, pad_token)
lm_labels = self.padding(lm_labels_list, -100)
text_mask = input_ids != pad_token
input_mask = torch.cat([video_mask, text_mask], dim=1)
# Now we compute the intervals of the visual input tokens
# These intervals are identical for all samples in the batch
i3d_rgb_interval = [0, min_length + 1]  # [start, end): one state token followed by min_length feature placeholders
i3d_flow_interval = [min_length + 1, 2 * min_length + 2]
sam_interval = [2 * min_length + 2, 3 * min_length + 3]
audio_interval = [3 * min_length + 3, 4 * min_length + 4]
vis_state_vector_idx = [i3d_rgb_interval[0], i3d_flow_interval[0], sam_interval[0], audio_interval[0]]
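# Illustrative layout of the visual part of the input (L = min_length):
#   <s1> + L i3d_rgb placeholders   -> interval [0, L+1)
#   <s0> + L i3d_flow placeholders  -> interval [L+1, 2L+2)
#   <s2> + L sam placeholders       -> interval [2L+2, 3L+3)
#   <s3> + L vggish placeholders    -> interval [3L+3, 4L+4)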
# Shift the history and question intervals to the right by the length of the visual input
history_intervals += 4 * min_length + 4
question_intervals += 4 * min_length + 4
history_intervals = history_intervals.tolist()
question_intervals = question_intervals.tolist()
history_state_vector_idx = [x[0] + 1 for x in history_intervals] # +1 because the history starts with <s><s4> .....
question_state_vector_idx = [x[0] for x in question_intervals]  # no +1 needed: the question interval starts directly at the <s5> token
batch = {
'input_ids': input_ids,
'video_place_holder_ids': video_place_holder_ids,
'i3d_rgb': i3d_rgb,
'i3d_flow': i3d_flow,
'sam': sam,
'vggish': vggish,
'lm_labels': lm_labels,
'input_mask': input_mask,
'i3d_rgb_interval': i3d_rgb_interval,
'i3d_flow_interval': i3d_flow_interval,
'sam_interval': sam_interval,
'audio_interval': audio_interval,
'history_intervals': history_intervals,
'question_intervals': question_intervals,
'vis_state_vector_idx': vis_state_vector_idx,
'history_state_vector_idx': history_state_vector_idx,
'question_state_vector_idx': question_state_vector_idx
}
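# Illustrative tensor shapes for a batch of size B, padded text length T and feature length L:
#   input_ids / lm_labels: (B, T), video_place_holder_ids: (B, 4L+4), input_mask: (B, 4L+4+T),
#   i3d_rgb / i3d_flow / sam / vggish: (B, L, feat_dim) with modality-specific feature dimensions.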
return batch
def get_dataset(config, split, tokenizer):
if split != 'test':
dialog_pth = config[f'avsd_{split}']
else:
dialog_pth = config['avsd_test_dstc{}'.format(config['dstc'])]
n_history = config['n_history']
dialog_data = json.load(open(dialog_pth, 'r'))
dialog_list = []
vid_set = set()
undisclosed_only = split == 'test'
pbar = tqdm(dialog_data['dialogs'])
pbar.set_description('[INFO] Generating {} items | DSTC {}'.format(split, config['dstc']))
for dialog in pbar:
if config['dstc'] != 10:
caption = [tokenize(dialog['caption'], tokenizer)] + [tokenize(dialog['summary'], tokenizer)]
else:
caption = [tokenize('no', tokenizer)]
questions = [tokenize(d['question'], tokenizer) for d in dialog['dialog']]
answers = [tokenize(d['answer'], tokenizer) for d in dialog['dialog']]
vid = dialog["image_id"]
vid_set.add(vid)
if undisclosed_only:
it = range(len(questions) - 1, len(questions))
else:
it = range(len(questions))
qalist = []
history = []
if undisclosed_only:
for n in range(len(questions)-1):
qalist.append(questions[n])
qalist.append(answers[n])
history = qalist[max(-len(qalist), -n_history * 2):]
for n in it:
if undisclosed_only:
assert dialog['dialog'][n]['answer'] == '__UNDISCLOSED__'
question = questions[n]
answer = answers[n]
history.append(question)
if n_history == 0:
item = {'vid': vid, 'history': [question], 'answer': answer, 'caption': caption}
else:
item = {'vid': vid, 'history': history, 'answer': answer, 'caption': caption}
dialog_list.append(item)
qalist.append(question)
qalist.append(answer)
history = qalist[max(-len(qalist), -n_history * 2):]
all_features = {}
fea_types = ['vggish', 'i3d_flow', 'i3d_rgb', 'sam']
dataname = '<FeaType>/<ImageID>.npy'
for ftype in fea_types:
if undisclosed_only:
basename = dataname.replace('<FeaType>', ftype+'_testset')
else:
basename = dataname.replace('<FeaType>', ftype)
features = {}
for vid in vid_set:
filename = basename.replace('<ImageID>', vid)
filepath = config['avsd_feature_path'] + filename
features[vid] = filepath
all_features[ftype] = features
return dialog_list, all_features
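# Illustrative structure of the return values of get_dataset():
#   dialog_list[i] = {'vid': <image_id>, 'history': [tokenized q/a turns, ending with the current question],
#                     'answer': [answer ids], 'caption': [[caption ids], [summary ids]]}
#   (for DSTC10 the caption/summary pair is replaced by a dummy 'no' caption)
#   all_features[ftype][vid] = path to the corresponding '<FeaType>/<ImageID>.npy' file.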
def build_input_from_segments(caption, history_orig, reply, tokenizer, add_state_tokens=True, drop_caption=False):
""" Build a sequence of input from 3 segments: caption(caption+summary) history and last reply """
bos, eos, hist_state, ques_state = tokenizer.convert_tokens_to_ids(['<s>', '</s>', '<s4>', '<s5>'])
sep = eos
instance = {}
instance["lm_labels"] = reply + [eos]
caption = list(chain(*caption))
# Add state tokens if applicable
if add_state_tokens:
caption.insert(0, hist_state)
history = deepcopy(history_orig)
history[-1].insert(0, ques_state)
else:
history = history_orig
if not drop_caption:
# sequence = [[bos] + list(chain(*caption))] + history + [reply + ([eos] if with_eos else [])]
# NOTE: it is important not to include the reply in the encoder input --> the decoder would
# simply learn to copy it --> low train/val loss, but no real learning would take place
sequence = [[bos] + caption + [eos]] + [[sep] + s for s in history] + [[eos]]
else:
sequence = [[bos]] + [[hist_state]] + [[sep] + s for s in history] + [[eos]]
instance["input_ids"] = list(chain(*sequence))
return instance
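# Illustrative encoder input produced with add_state_tokens=True and drop_caption=False:
#   <s> <s4> caption summary </s> </s> q_1 </s> a_1 ... </s> <s5> q_n </s>
# With drop_caption=True the caption block is reduced to <s> <s4> before the history turns.
# lm_labels are simply the tokenized answer followed by </s>.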
def parse_args():
parser = ArgumentParser(description='debug dataloader')
parser.add_argument(
'--split',
type=str,
default='train',
help='train or val')
parser.add_argument(
'--model',
type=str,
default='mixer',
help='model name to train or test')
parser.add_argument(
'--log_dataset',
action='store_true',
default=False,
help='Whether or not to log the processed data')
parser.add_argument(
'--add_state_tokens',
action='store_true',
default=True,
help='Whether or not to add state tokens')
parser.add_argument(
'--log_dir',
type=str,
default='processed/avsd',
help='Output directory')
args = parser.parse_args()
return args
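# Illustrative invocation for pre-processing the training split (run from the repo root,
# assuming the default config at config/mst_mixer.conf):
#   python custom_datasets/avsd.py --split train --log_dataset --log_dir processed/avsd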
if __name__ == '__main__':
args = parse_args()
split = args.split
config = pyhocon.ConfigFactory.parse_file(
'config/mst_mixer.conf')[args.model]
config['expand_rnd'] = False
config['debugging'] = False
config['overfit'] = False
args.log_dir = os.path.join(args.log_dir, 'hist_with_{}_rounds'.format(config['n_history']))
if args.log_dataset:
log_dir = os.path.join(args.log_dir, split)
if not os.path.isdir(log_dir):
os.makedirs(log_dir)
tokenizer = BartTokenizer.from_pretrained('facebook/bart-{}'.format(config['bart_size']))
tokenizer.add_special_tokens({'additional_special_tokens': ADDITIONAL_SPECIAL_TOKENS})
dialogs, features = get_dataset(config, split, tokenizer)
pbar = tqdm(dialogs)
pbar.set_description('[{}] Logging processed data'.format(split))
counter = 0
for dialog in pbar:
vid = dialog['vid']
his = dialog['history']
cap = dialog['caption']
ans = dialog['answer']
drop_caption = np.random.rand() < config['caption_drop_rate']
instance = build_input_from_segments(
cap, his, ans, tokenizer, add_state_tokens=args.add_state_tokens, drop_caption=drop_caption)
input_ids = torch.Tensor(instance["input_ids"]).long()
lm_labels = torch.Tensor(instance["lm_labels"]).long()
vgg = np.load(features["vggish"][vid])
i3d_flow = np.load(features["i3d_flow"][vid])
i3d_rgb = np.load(features["i3d_rgb"][vid])
sam = np.load(features["sam"][vid])
item = {
'input_ids': input_ids,
'lm_labels': lm_labels,
'i3d_rgb': i3d_rgb,
'i3d_flow': i3d_flow,
'sam': sam,
'vgg': vgg,
'vid': vid
}
counter += 1
pth = os.path.join(log_dir, str(counter) + '.pkl')
with open(pth, 'wb') as f:
pickle.dump(item, f, protocol=pickle.HIGHEST_PROTOCOL)
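# Each processed sample is written as '<counter>.pkl' under the split-specific log directory
# (presumably the directory pointed to by config['avsd_processed']), where it is later read
# back by AVSDDataset.__getitem__.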
else:
avsd_dataset = AVSDDataset(config, 'val')
avsd_dataloader = DataLoader(avsd_dataset, batch_size=4, shuffle=False, collate_fn=avsd_dataset.collate_fn)
for i, data in enumerate(avsd_dataloader):
print('{}/{}'.format(i, len(avsd_dataloader)))
print('[INFO] Done...')