104 lines
3.7 KiB
Python
104 lines
3.7 KiB
Python
|
import glob, os, json, pickle
|
||
|
import numpy as np
|
||
|
from collections import defaultdict
|
||
|
|
||
|
import torch
|
||
|
from torch.utils.data import Dataset
|
||
|
import torchvision.transforms as transforms
|
||
|
|
||
|
from core.data.utils import tokenize, ans_stat, proc_ques, qlen_to_key, ans_to_key
|
||
|
|
||
|
|
||
|
class VideoQA_Dataset(Dataset):
    """Video question-answering dataset over pre-extracted per-video features.

    Serves the QA samples stored in ``__C.QA_PATH[__C.RUN_MODE]``. Each sample
    is paired with the frame and clip feature files (``*.pt``) found under the
    ``__C.FRAMES`` and ``__C.CLIPS`` path prefixes, matched by integer video id.
    The answer vocabulary (the ``__C.NUM_ANS`` most frequent answers) and the
    question token vocabulary are built from the train and val QA files, so
    they are the same in every run mode.

    Config attributes read: NUM_ANS, QA_PATH, RUN_MODE, DATASET_PATH, FRAMES,
    CLIPS, USE_GLOVE, MAX_TOKEN.
    """

    def __init__(self, __C):
        """Load QA samples, index feature files and build the vocabularies.

        Args:
            __C: run configuration object (attributes listed on the class).
        """
        super().__init__()

        self.__C = __C
        self.ans_size = __C.NUM_ANS

        # QA samples served by this dataset instance.
        with open(__C.QA_PATH[__C.RUN_MODE], 'r') as f:
            self.raw_data = json.load(f)
        self.data_size = len(self.raw_data)

        # Feature files are named '<prefix><id>.<ext>'; the prefix is 3
        # characters long for MSVD ('vid') and 5 otherwise ('video').
        prefix_len = 3 if 'msvd' in __C.DATASET_PATH.lower() else 5

        # Index each directory by its own file names. glob() returns files in
        # arbitrary order, so zipping ids parsed from the frame files with an
        # independently globbed clip list could pair a video's frames with
        # another video's clips.
        # NOTE(review): assumes clip files use the same naming as frame files.
        self.frames_dict = self._index_by_video_id(glob.glob(__C.FRAMES + '*.pt'), prefix_len)
        self.clips_dict = self._index_by_video_id(glob.glob(__C.CLIPS + '*.pt'), prefix_len)

        # Gather every question and count answer frequencies over train + val.
        q_list = []
        ans_freq = defaultdict(int)
        for split in ['train', 'val']:
            with open(__C.QA_PATH[split], 'r') as f:
                qa_data = json.load(f)
            for d in qa_data:
                q_list.append(d['question'])
                ans_freq[d['answer']] += 1

        # Answers by descending frequency; sorted() is stable, so ties keep
        # the order in which the answers were first seen.
        top_answers = sorted(ans_freq, key=ans_freq.get, reverse=True)

        self.qlen_bins_to_idx = {
            '1-3': 0,
            '4-8': 1,
            '9-15': 2,
        }

        self.ans_rare_to_idx = {
            '0-99': 0,
            '100-299': 1,
            '300-999': 2,
        }

        self.qtypes_to_idx = {
            'what': 0,
            'who': 1,
            'how': 2,
            'when': 3,
            'where': 4,
        }

        # The vocabularies depend only on the train + val data read above, so
        # build them in every run mode: __getitem__ needs ans_list, ans_to_ix
        # and token_to_ix whatever RUN_MODE is.
        self.ans_list = top_answers[:self.ans_size]
        self.ans_to_ix, self.ix_to_ans = ans_stat(self.ans_list)

        self.token_to_ix, self.pretrained_emb = tokenize(q_list, __C.USE_GLOVE)
        self.token_size = len(self.token_to_ix)
        print('== Question token vocab size:', self.token_size)

        self.idx_to_qtypes = {v: k for k, v in self.qtypes_to_idx.items()}
        self.idx_to_qlen_bins = {v: k for k, v in self.qlen_bins_to_idx.items()}
        self.idx_to_ans_rare = {v: k for k, v in self.ans_rare_to_idx.items()}

    @staticmethod
    def _index_by_video_id(paths, prefix_len):
        """Map integer video id -> feature file path.

        The id is the file name up to its first '.', with the first
        ``prefix_len`` characters removed. Later duplicates overwrite earlier ones.
        """
        return {
            int(os.path.basename(p).split('.')[0][prefix_len:]): p
            for p in paths
        }

    def __getitem__(self, idx):
        """Return the tensors for sample ``idx``.

        Returns:
            tuple ``(question_ids, frames, clips, answer_one_hot, answer_idx,
            q_type, qlen_bin, ans_rarity)``. ``answer_one_hot`` is a float
            tensor of length ``NUM_ANS``; ``frames``/``clips`` are whatever
            the feature files contain; the remaining elements are LongTensors.
        """
        sample = self.raw_data[idx]

        ques = sample['question']
        # The first word selects the question type; a question whose first word
        # is not a key of qtypes_to_idx raises KeyError.
        q_type = self.qtypes_to_idx[ques.split(' ')[0]]
        ques_idx, qlen, _ = proc_ques(ques, self.token_to_ix, self.__C.MAX_TOKEN)
        qlen_bin = self.qlen_bins_to_idx[qlen_to_key(qlen)]

        answer = self.ans_to_ix.get(sample['answer'])
        if answer is None:
            # Out-of-vocabulary answer: fall back to a random in-vocabulary
            # label, drawn only when actually needed.
            answer = np.random.randint(0, high=len(self.ans_list))
        # NOTE(review): presumably ans_to_key bins the answer index (a
        # frequency rank) into one of the ans_rare_to_idx keys - confirm.
        ans_rarity = self.ans_rare_to_idx[ans_to_key(answer)]

        answer_one_hot = torch.zeros(self.ans_size)
        answer_one_hot[answer] = 1.0

        # Load straight from the path so no file handle is leaked, and map all
        # storages to CPU so features saved from a GPU load without one.
        vid_id = sample['video_id']
        frames = torch.load(self.frames_dict[vid_id], map_location='cpu')
        clips = torch.load(self.clips_dict[vid_id], map_location='cpu')

        return (torch.from_numpy(ques_idx).long(), frames, clips, answer_one_hot,
                torch.tensor(answer).long(), torch.tensor(q_type).long(),
                torch.tensor(qlen_bin).long(), torch.tensor(ans_rarity).long())

    def __len__(self):
        """Number of QA samples in the current run mode's split."""
        return self.data_size
|