initial commit
This commit is contained in:
commit
a82bbc593e
129 changed files with 33981 additions and 0 deletions
86
datasets/nextqa_dataset.py
Normal file
86
datasets/nextqa_dataset.py
Normal file
|
@ -0,0 +1,86 @@
|
|||
import os
|
||||
import pandas as pd
|
||||
# import h5py
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from .video_utils import read_frames_decord
|
||||
|
||||
|
||||
def load_file(file_name):
|
||||
annos = None
|
||||
if os.path.splitext(file_name)[-1] == '.csv':
|
||||
return pd.read_csv(file_name)
|
||||
with open(file_name, 'r') as fp:
|
||||
if os.path.splitext(file_name)[1]== '.txt':
|
||||
annos = fp.readlines()
|
||||
annos = [line.rstrip() for line in annos]
|
||||
if os.path.splitext(file_name)[1] == '.json':
|
||||
annos = json.load(fp)
|
||||
return annos
|
||||
|
||||
|
||||
class NextQADataset(Dataset):
|
||||
def __init__(self, config, medium, vis_processor, text_processor, split):
|
||||
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.medium = medium
|
||||
self.vis_processor = vis_processor
|
||||
self.text_processor = text_processor
|
||||
self.split = split
|
||||
|
||||
self.batch_size = config['batch_size_test_{}'.format(medium)] if split == 'test' else config['batch_size_{}'.format(medium)]
|
||||
self.root_vis = config['root_raw_vis_{}_{}'.format(medium, split)]
|
||||
with open(config['vid_mapping_nextqa'], 'r') as f:
|
||||
self.video_mapping = json.load(f)
|
||||
|
||||
self.sample_list = load_file(self.config['anno_nextqa_{}'.format(split)])
|
||||
|
||||
if split == 'test':
|
||||
self.sample_list = self.sample_list[config['start_idx_gen']: config['end_idx_gen']]
|
||||
self.captions = load_file(self.config['next_qa_captions_{}'.format(split)])
|
||||
else:
|
||||
self.captions = None
|
||||
|
||||
num_samples = config['num_samples_{}'.format(self.medium)]
|
||||
if num_samples > 0:
|
||||
self.sample_list = self.sample_list[:num_samples]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sample_list)
|
||||
|
||||
|
||||
def load_vid(self, vid_id):
|
||||
vid_dir_path = os.path.join(self.root_vis, self.video_mapping[vid_id] + '.mp4')
|
||||
|
||||
frames, _, _ = read_frames_decord(vid_dir_path, self.config.num_frames)
|
||||
frames = [self.vis_processor(f).unsqueeze(0) for f in frames]
|
||||
|
||||
vis = torch.cat(frames, dim=0)
|
||||
return vis
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if self.split == 'test':
|
||||
idx += self.config['start_idx_gen']
|
||||
|
||||
cur_sample = self.sample_list.loc[idx]
|
||||
video_id, ques, ans, qid = str(cur_sample['video']), str(cur_sample['question']),\
|
||||
str(cur_sample['answer']), str(cur_sample['qid'])
|
||||
|
||||
history = self.text_processor(ques)
|
||||
answer = self.text_processor(ans)
|
||||
if self.split == 'test':
|
||||
caption = self.text_processor(self.captions[video_id])
|
||||
else:
|
||||
caption = self.text_processor('please answer the following question based on the video')
|
||||
vis = self.load_vid(video_id)
|
||||
|
||||
return vis, caption, history, answer, video_id, qid
|
||||
|
||||
def load_nextqa_dataset(config, vis_processor, text_processor, split):
|
||||
# data_file = config['anno_avsd_{}'.format(split)]
|
||||
# dataset_list = get_dataset(config, split, tokenizer_enc_dec)
|
||||
dataset = NextQADataset(config, 'nextqa', vis_processor, text_processor, split)
|
||||
return dataset
|
Loading…
Add table
Add a link
Reference in a new issue