import json
import os

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset

# import h5py

from .video_utils import read_frames_decord
def load_file(file_name):
    """Load an annotation file, dispatching on its extension.

    Args:
        file_name: Path to a ``.csv``, ``.txt``, or ``.json`` annotation file.

    Returns:
        ``pandas.DataFrame`` for ``.csv``, a list of right-stripped lines for
        ``.txt``, the parsed object for ``.json``, and ``None`` for any other
        extension (preserving the original silent fallback).
    """
    # Compute the extension once; the original re-derived it three times
    # with inconsistent indices ([-1] vs [1]).
    ext = os.path.splitext(file_name)[-1]
    if ext == '.csv':
        return pd.read_csv(file_name)

    annos = None
    # Only open the file for formats we actually parse; the original opened
    # it unconditionally, raising for unsupported paths instead of
    # returning the documented None.
    if ext == '.txt':
        with open(file_name, 'r') as fp:
            # rstrip drops the trailing newline (and trailing whitespace).
            annos = [line.rstrip() for line in fp]
    elif ext == '.json':
        with open(file_name, 'r') as fp:
            annos = json.load(fp)
    return annos
class NextQADataset(Dataset):
    """NExT-QA video question answering dataset.

    Yields one ``(frames, caption, question, answer, video_id, qid)`` tuple
    per annotation row. On the ``test`` split the annotation table is
    sharded by ``[start_idx_gen, end_idx_gen)`` and per-video captions are
    loaded; otherwise a fixed instruction string stands in for the caption.
    """

    def __init__(self, config, medium, vis_processor, text_processor, split):
        """Store processors and load annotations/mappings for ``split``.

        Args:
            config: Dict-like experiment configuration (see keys used below).
            medium: Dataset medium tag used to resolve config keys.
            vis_processor: Callable applied to every decoded video frame.
            text_processor: Callable applied to questions/answers/captions.
            split: Dataset split name; 'test' enables sharding + captions.
        """
        super().__init__()
        self.config = config
        self.medium = medium
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        self.split = split

        # Test-time generation uses its own batch size key.
        size_key = f'batch_size_test_{medium}' if split == 'test' else f'batch_size_{medium}'
        self.batch_size = config[size_key]

        self.root_vis = config[f'root_raw_vis_{medium}_{split}']
        # Maps annotation video ids to on-disk video file names.
        with open(config['vid_mapping_nextqa'], 'r') as f:
            self.video_mapping = json.load(f)

        self.sample_list = load_file(self.config[f'anno_nextqa_{split}'])

        if split == 'test':
            # Shard the test annotations so parallel generation runs
            # can each handle a contiguous slice.
            start = config['start_idx_gen']
            end = config['end_idx_gen']
            self.sample_list = self.sample_list[start:end]
            self.captions = load_file(self.config[f'next_qa_captions_{split}'])
        else:
            self.captions = None

        # Optional cap on the number of samples (<= 0 means "use all").
        limit = config[f'num_samples_{self.medium}']
        if limit > 0:
            self.sample_list = self.sample_list[:limit]

    def __len__(self):
        return len(self.sample_list)

    def load_vid(self, vid_id):
        """Decode one video and return its preprocessed frames.

        Args:
            vid_id: Annotation video id; resolved via ``self.video_mapping``.

        Returns:
            Tensor of processed frames stacked along dim 0.
        """
        clip_path = os.path.join(self.root_vis, self.video_mapping[vid_id] + '.mp4')
        # NOTE(review): config is indexed with [] elsewhere but accessed as an
        # attribute here — assumes an attr-dict style config; confirm.
        raw_frames, _, _ = read_frames_decord(clip_path, self.config.num_frames)
        processed = [self.vis_processor(frame).unsqueeze(0) for frame in raw_frames]
        return torch.cat(processed, dim=0)

    def __getitem__(self, idx):
        if self.split == 'test':
            # idx is shard-local; shift back to the global annotation index
            # because the sliced DataFrame keeps its original labels.
            idx += self.config['start_idx_gen']

        row = self.sample_list.loc[idx]
        video_id = str(row['video'])
        ques = str(row['question'])
        ans = str(row['answer'])
        qid = str(row['qid'])

        history = self.text_processor(ques)
        answer = self.text_processor(ans)
        if self.split == 'test':
            caption = self.text_processor(self.captions[video_id])
        else:
            caption = self.text_processor('please answer the following question based on the video')

        vis = self.load_vid(video_id)
        return vis, caption, history, answer, video_id, qid
def load_nextqa_dataset(config, vis_processor, text_processor, split):
    """Build the NExT-QA dataset for one split.

    Args:
        config: Experiment configuration holding the ``nextqa``-prefixed
            keys read by ``NextQADataset``.
        vis_processor: Callable applied to each decoded video frame.
        text_processor: Callable applied to questions/answers/captions.
        split: Split name (e.g. 'train', 'val', 'test').

    Returns:
        A ``NextQADataset`` instance with medium fixed to ``'nextqa'``.
    """
    # Dead commented-out AVSD-era code removed.
    return NextQADataset(config, 'nextqa', vis_processor, text_processor, split)