MST-MIXER/custom_datasets/nextqa.py

import os
import pandas as pd
import h5py
import json
import numpy as np
import torch
from torch.utils.data import Dataset
from transformers import BartTokenizer
from itertools import chain
ADDITIONAL_SPECIAL_TOKENS = [
    '<place_holder>', '<s0>', '<s1>', '<s2>', '<s3>', '<s4>', '<s5>']
SPECIAL_TOKENS_DICT = {
    'bos_token': '<s>',
    'eos_token': '</s>',
    'pad_token': '<pad>',
    'additional_special_tokens': ADDITIONAL_SPECIAL_TOKENS
}
S0_TOK = '<s0>' # frame
S1_TOK = '<s1>' # mot
S2_TOK = '<s2>' # question


def load_file(file_name):
    annos = None
    if os.path.splitext(file_name)[-1] == '.csv':
        return pd.read_csv(file_name)
    with open(file_name, 'r') as fp:
        if os.path.splitext(file_name)[1] == '.txt':
            annos = fp.readlines()
            annos = [line.rstrip() for line in annos]
        if os.path.splitext(file_name)[1] == '.json':
            annos = json.load(fp)
    return annos


def tokenize(obj, tokenizer):
    if isinstance(obj, str):
        return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
    if isinstance(obj, dict):
        return dict((n, tokenize(o, tokenizer)) for n, o in obj.items())
    return list(tokenize(o, tokenizer) for o in obj)
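
# Example (added for illustration, not in the original source):
#   tokenize('what is the man doing', tokenizer) -> flat list of BART token ids;
# dicts and lists are tokenized recursively, element by element.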


class NextQADataset(Dataset):
    def __init__(self, config, split):
        super().__init__()
        self.config = config
        self.split = split
        self.bart_max_input_len = config['bart_max_input_len']
        self.bart_size = config['bart_size']

        self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-{}'.format(self.bart_size))
        self.vocab_size = self.tokenizer.vocab_size
        self.tokenizer.add_special_tokens({'additional_special_tokens': ADDITIONAL_SPECIAL_TOKENS})
        self.vocab_size += len(ADDITIONAL_SPECIAL_TOKENS)
        self.tokenizer.save_pretrained(os.path.join(self.config['log_dir'], 'bart_tokenizer'))

        sample_list_file = os.path.join(self.config['nextqa_root'], '{}.csv'.format(split))
        self.sample_list = load_file(sample_list_file)

        vid_feat_file = os.path.join(self.config['nextqa_vid_feat'], 'app_mot_{}.h5'.format(split))
        print('Load {}...'.format(vid_feat_file))
        self.frame_feats = {}
        self.mot_feats = {}
        with h5py.File(vid_feat_file, 'r') as fp:
            vids = fp['ids']
            feats = fp['feat']
            for vid, feat in zip(vids, feats):
                self.frame_feats[str(vid)] = feat[:, :2048]  # (16, 2048)
                self.mot_feats[str(vid)] = feat[:, 2048:]  # (16, 2048)

        if self.config['overfit_size'] > 0:
            self.sample_list = self.sample_list[:self.config['overfit_size']]

    def __len__(self):
        return len(self.sample_list)

    def get_video_feature(self, video_name):
        """
        :param video_name: video id used as key into the preloaded feature dicts
        :return: appearance and motion features, each a float32 tensor of shape (16, 2048)
        """
        app_feat = self.frame_feats[video_name]
        app_feat = torch.from_numpy(app_feat).type(torch.float32)
        mot_feat = self.mot_feats[video_name]
        mot_feat = torch.from_numpy(mot_feat).type(torch.float32)
        return app_feat, mot_feat

    def __getitem__(self, idx):
        cur_sample = self.sample_list.loc[idx]
        video_name, ques, ans, qid = str(cur_sample['video']), str(cur_sample['question']),\
            str(cur_sample['answer']), str(cur_sample['qid'])
        input_ids = tokenize(ques, self.tokenizer)
        lm_labels = tokenize(ans, self.tokenizer)
        app_feat, mot_feat = self.get_video_feature(video_name)

        bos, eos, ques_state = self.tokenizer.convert_tokens_to_ids(['<s>', '</s>', '<s2>'])

        # Add state tokens
        input_ids.insert(0, ques_state)
        lm_labels.append(eos)

        question_interval = [0, len(input_ids)]

        input_ids = torch.Tensor(input_ids).long()
        lm_labels = torch.Tensor(lm_labels).long()

        return input_ids, lm_labels, app_feat, mot_feat, question_interval, video_name

    def padding(self, seq, pad_token, max_len=None):
        if max_len is None:
            max_len = max([i.size(0) for i in seq])
        if len(seq[0].size()) == 1:
            result = torch.ones((len(seq), max_len)).long() * pad_token
        else:
            result = torch.ones((len(seq), max_len, seq[0].size(-1))).float()
        for i in range(len(seq)):
            result[i, :seq[i].size(0)] = seq[i]
        return result
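
    # Illustrative note (added, not from the original source): for two 1-D
    # token-id sequences of lengths 5 and 8, padding(...) returns a (2, 8)
    # LongTensor whose shorter row is right-padded with `pad_token`; for 2-D
    # feature sequences the unfilled positions stay at their float init of 1.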

    def collate_fn(self, batch):
        input_ids_list, lm_labels_list, app_feat_list, mot_feat_list, question_interval_list, vid_ids_list = [], [], [], [], [], []
        for i in batch:
            input_ids_list.append(i[0])
            lm_labels_list.append(i[1])
            app_feat_list.append(i[2])
            mot_feat_list.append(i[3])
            question_interval_list.append(i[4])
            vid_ids_list.append(i[5])

        app_feats = torch.stack(app_feat_list, dim=0).float()
        mot_feats = torch.stack(mot_feat_list, dim=0).float()
        question_intervals = np.array(question_interval_list)

        pad_token, app_sep, mot_sep, ph_token = self.tokenizer.convert_tokens_to_ids(
            ['<pad>', '<s0>', '<s1>', '<place_holder>'])

        # None of the visual features is masked because we do not perform any padding on them
        video_mask = torch.ones((len(batch), 16*2 + 2)) == 1  # NOTE *2: 2 modalities | +2: the state tokens | each modality has length 16

        # Now we create a dummy input for the video tokens (its sole purpose is to reserve the spots of the separators)
        dummy = torch.ones((len(batch), 16)) * ph_token
        video_place_holder_ids = torch.cat(
            [torch.ones((len(batch), 1)) * app_sep, dummy,
             torch.ones((len(batch), 1)) * mot_sep, dummy,
             ], dim=-1).long()

        input_ids = self.padding(input_ids_list, pad_token)
        lm_labels = self.padding(lm_labels_list, -100)
        text_mask = input_ids != pad_token
        input_mask = torch.cat([video_mask, text_mask], dim=1)

        # Now we get the intervals of the visual input tokens
        # Here the intervals do not change across the batch dimension
        app_interval = [0, 16 + 1]  # the last token is not part of this modality
        mot_interval = [16 + 1, 2 * 16 + 2]
        vis_state_vector_idx = [app_interval[0], mot_interval[0]]

        # Adapt the question intervals -- shifted to the right by the visual input length
        question_intervals += 2 * 16 + 2
        question_intervals = question_intervals.tolist()
        question_state_vector_idx = [x[0] for x in question_intervals]

        batch = {
            'input_ids': input_ids,
            'video_place_holder_ids': video_place_holder_ids,
            'app_feats': app_feats,
            'mot_feats': mot_feats,
            'lm_labels': lm_labels,
            'input_mask': input_mask,
            'app_interval': app_interval,
            'mot_interval': mot_interval,
            'question_intervals': question_intervals,
            'vis_state_vector_idx': vis_state_vector_idx,
            'question_state_vector_idx': question_state_vector_idx
        }
        return batch
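
    # Layout note (added for clarity, not in the original source): the collated
    # model input is ordered as
    #   [<s0>] [16 appearance slots] [<s1>] [16 motion slots] [question tokens ...]
    # which is why app_interval = [0, 17], mot_interval = [17, 34], and the
    # question intervals are shifted right by 2 * 16 + 2 = 34 positions.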


def get_dataset(config, split):
    bart_max_input_len = config['bart_max_input_len']
    bart_size = config['bart_size']

    sample_list_file = os.path.join(config['nextqa_root'], '{}.csv'.format(split))
    sample_list = load_file(sample_list_file)

    vid_feat_file = os.path.join(config['nextqa_vid_feat'], 'app_mot_{}.h5'.format(split))
    print('Load {}...'.format(vid_feat_file))
    app_feats = {}
    mot_feats = {}
    with h5py.File(vid_feat_file, 'r') as fp:
        vids = fp['ids']
        feats = fp['feat']
        for vid, feat in zip(vids, feats):
            app_feats[str(vid)] = feat[:, :2048]  # (16, 2048)
            mot_feats[str(vid)] = feat[:, 2048:]  # (16, 2048)
    return sample_list, app_feats, mot_feats
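

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): shows how the
# dataset and its collate_fn are typically wired into a torch DataLoader.
# The config keys mirror the ones read above; the paths and sizes below are
# placeholder assumptions, not values shipped with the repository.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    _config = {
        'bart_max_input_len': 1024,                   # assumed value
        'bart_size': 'base',                          # 'base' or 'large'
        'log_dir': './logs',                          # assumed output directory
        'nextqa_root': './data/nextqa',               # directory holding {split}.csv
        'nextqa_vid_feat': './data/nextqa/vid_feat',  # directory holding app_mot_{split}.h5
        'overfit_size': 0,
    }
    _dataset = NextQADataset(_config, split='val')
    _loader = DataLoader(_dataset, batch_size=4, shuffle=False,
                         collate_fn=_dataset.collate_fn)
    _batch = next(iter(_loader))
    # Print tensor shapes (and the interval lists as-is) for a quick sanity check
    print({k: (v.shape if torch.is_tensor(v) else v) for k, v in _batch.items()})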