OLViT/src/utils/dvd_codebase/data/data_handler.py

"""
Copyright (c) Facebook, Inc. and its affiliates.
All rights reserved.
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
import copy, logging, sys, time, os, pdb, random, glob, json
import pickle as pkl
import numpy as np
from tqdm import tqdm
from collections import Counter
from functools import partial
import nltk
import torch
import torch.utils.data as Data
from src.utils.dvd_codebase.data.dataset import *
from src.utils.dvd_codebase.data.analysis_utils import *
from src.utils.dvd_codebase.data.data_utils import *
from src.utils.dvd_codebase.data.analysis_utils import get_question_subtype, get_question_complexity
from transformers import AutoTokenizer


def load_dials(args, split):
    """Load all dialogue JSON files for a split and map each video key to its feature file."""
    files = []
    for video_split in ['all_actions', 'max2action']:
        files += glob.glob(args.data_dir + '{}_{}_*/*.json'.format(video_split, split))
    files = sorted(files)
    if args.debug:
        files = files[:100]
    all_dials = []
    vid_set = {}
    for file in tqdm(files, total=len(files)):
        dials = json.load(open(file))
        all_dials.extend(dials)
        # all dialogues in one file share the same underlying video
        video_split = dials[0][0]['split']
        vid = dials[0][0]['image'].replace('CLEVR', 'CATER')
        vid_key = '{}-{}'.format(video_split, vid)
        if vid_key not in vid_set:
            vid_set[vid_key] = '{}/{}/{}.pkl'.format(args.fea_dir, video_split, vid)
    return all_dials, vid_set
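
# A minimal usage sketch, assuming hypothetical paths (`data_dir` is expected
# to contain directories such as `all_actions_train_*` holding dialogue JSON
# files; the real `args` object comes from the project's argument parser):
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(data_dir='data/dvd_dialogues/',
#                          fea_dir='data/dvd_features/', debug=True)
#   dials, vid_set = load_dials(args, 'train')
#   # vid_set maps '<video_split>-<video_id>' -> '<fea_dir>/<video_split>/<video_id>.pkl'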


def load_videos(args, vid_set):
    """Load clip-level video features and infer the clip window size and stride."""
    vid_fts = {}
    ft_dims = None
    size, stride = -1, -1
    segment_map = {}
    for vid_key, fea_file in tqdm(vid_set.items(), total=len(vid_set)):
        fea = pkl.load(open(fea_file, 'rb'))
        output = []
        prior_start = None
        for clip_idx, clip in enumerate(fea['clips']):
            # move the first axis of 3D feature maps to the end,
            # e.g. (C, H, W) -> (H, W, C)
            clip_fea = clip['features']
            if len(clip_fea.shape) == 3:
                clip_fea = clip_fea.transpose(1, 2, 0)
            output.append(clip_fea)
            start, end = clip['segment']
            if clip_idx not in segment_map:
                segment_map[clip_idx] = (start, end)
            if size == -1:
                size = end - start + 1
            if clip_idx > 0 and stride == -1:
                stride = start - prior_start
            prior_start = start
        vft = np.asarray(output)
        vid_fts[vid_key] = vft
        if ft_dims is None:
            ft_dims = vft.shape
    return vid_fts, ft_dims, size, stride, segment_map
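
# The feature pickles consumed above are expected to look like the following
# (inferred from the access pattern; no other fields are required):
#
#   {'clips': [{'features': np.ndarray,     # per-clip feature map
#               'segment': (start, end)},   # frame span covered by the clip
#              ...]}
#
# `size` is the clip length in frames and `stride` the offset between
# consecutive clip starts, both inferred from the first two clips.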


def load_video_features(args, vid_set):
    """Load the raw per-video feature pickles without any clip post-processing."""
    vid_fts = {}
    for vid_key, fea_file in tqdm(vid_set.items(), total=len(vid_set)):
        vid_fts[vid_key] = pkl.load(open(fea_file, 'rb'))
    return vid_fts


def get_vocabulary(dials, args, vocab=None):
    """Count word frequencies over questions, answers, and program tokens.

    When an existing vocab is passed in, return the set of out-of-vocabulary
    words instead of building a new vocabulary.
    """
    word_freq = {}
    for dialog in tqdm(dials, total=len(dials)):
        for turn in dialog:
            for word in nltk.word_tokenize(turn['question']):
                if word not in word_freq:
                    word_freq[word] = 0
                word_freq[word] += 1
            answer = str(turn['answer'])
            for word in nltk.word_tokenize(answer):
                if word not in word_freq:
                    word_freq[word] = 0
                word_freq[word] += 1
            # program nodes: count node types and tokenized side inputs
            program = turn['final_all_program']
            for n in program:
                if n['type'] == 'identity':
                    continue
                if n['type'] not in word_freq:
                    word_freq[n['type']] = 0
                word_freq[n['type']] += 1
                if 'side_inputs' in n:
                    for side_input in n['side_inputs']:
                        for word in nltk.word_tokenize(side_input):
                            if word not in word_freq:
                                word_freq[word] = 0
                            word_freq[word] += 1
    if vocab is not None:
        # report words unseen in the provided vocabulary
        unk_words = set()
        for word in word_freq:
            if word not in vocab:
                unk_words.add(word)
        return unk_words
    vocab = {'<unk>': 0, '<blank>': 1, '<sos>': 2, '<eos>': 3, '<eoo>': 4}
    for word in word_freq:
        vocab[word] = len(vocab)
    answer_options = ['0', '1', '10', '2', '3', '4', '5', '6', '7', '8', '9',
                      'False', 'True', 'blue', 'brown', 'cone', 'cube', 'cyan',
                      'cylinder', 'flying', 'flying,rotating',
                      'flying,rotating,sliding', 'flying,sliding', 'gold',
                      'gray', 'green', 'large', 'medium', 'metal', 'no action',
                      'purple', 'red', 'rotating', 'rotating,sliding', 'rubber',
                      'sliding', 'small', 'sphere', 'spl', 'yellow']
    return vocab, answer_options
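
# Both call modes, sketched (`train_dials`/`val_dials` as returned by load_dials):
#
#   vocab, answer_list = get_vocabulary(train_dials, args)   # build a fresh vocab
#   unk = get_vocabulary(val_dials, args, vocab=vocab)       # words missing from it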


def answer_by_question_type(dials):
    """Collect the answer distribution for each question subtype."""
    qa_dist = {}
    for dialog in dials:
        for turn_idx, turn in enumerate(dialog):
            answer = turn['answer']
            template = turn['template']
            if turn_idx > 0:
                prior_template = dialog[turn_idx - 1]['template']
            else:
                prior_template = None
            qtype = get_question_subtype(template, prior_template)
            if qtype not in qa_dist:
                qa_dist[qtype] = {}
            if answer not in qa_dist[qtype]:
                qa_dist[qtype][answer] = 0
            qa_dist[qtype][answer] += 1
    return qa_dist
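
# The result is a nested counter keyed by subtype, then by answer, e.g.:
#
#   {'<subtype_a>': {'blue': 12, 'cube': 7, ...},
#    '<subtype_b>': {'0': 3, '1': 9, ...}}
#
# (subtype names are illustrative; the real ones come from
# get_question_subtype in analysis_utils)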


# Load text data
def create_dials(dials, vocab, answer_list, vft_data, args, tokenizer=None):
    """Flatten each dialogue into per-turn QA items paired with the video
    features visible up to that turn."""
    dialog_list = []
    qa_id = 0
    for dialog in tqdm(dials, total=len(dials)):
        # encode questions/answers either with the project vocab or with a
        # pretrained-LM tokenizer
        if tokenizer is None:
            questions = [words2ids(t['question'], vocab) for t in dialog]
            answers = [words2ids(str(t['answer']), vocab) for t in dialog]
        else:
            questions = [words2ids_pretrained_lm(t['question'], vocab, tokenizer) for t in dialog]
            answers = [words2ids_pretrained_lm(str(t['answer']), vocab, tokenizer) for t in dialog]
        answer_output = [[answer_list.index(str(t['answer']))] for t in dialog]
        qa_pair = [np.concatenate((q, a)).astype(np.int32) for q, a in zip(questions, answers)]

        attribute_dependencies = []
        object_dependencies = []
        temporal_dependencies = []
        spatial_dependencies = []
        q_types = []
        q_complexities = []
        for i, t in enumerate(dialog):
            # determine the type of turn relation
            attribute_dependencies.append(t['turn_dependencies']['attribute'])
            object_dependencies.append(t['turn_dependencies']['object'])
            temporal_dependencies.append(t['turn_dependencies']['temporal'])
            spatial_dependencies.append(t['turn_dependencies']['spatial'])
            # determine the question type based on the template (for analysis)
            if i == 0:
                q_types.append(get_question_type(t['template'], None))
            else:
                q_types.append(get_question_type(t['template'], dialog[i - 1]['template']))
            # get question complexity
            q_complexities.append(get_question_complexity(t, t['template_filename']))
            # get image name
            video_name = t['image']

        vid_cutoffs = [t['template']['cutoff'] for t in dialog]
        gt_vid_periods = [t['template']['used_periods'][-1] for t in dialog]
        programs = [program2ids(t['final_all_program'], vocab) for t in dialog]
        states = [state2ids(t['template']['used_objects'], vocab) for t in dialog]
        vid = dialog[0]['image'].replace('CLEVR', 'CATER')
        vid_split = dialog[0]['split']
        vid_key = '{}-{}'.format(vid_split, vid)
        whole_vft_fea = vft_data[vid_key]

        # cut off the unused video features based on the per-turn cutoffs
        turn_based_vft_fea = []
        for t_cutoff in vid_cutoffs:
            if t_cutoff is not None:
                t_vft_fea = whole_vft_fea[:t_cutoff[3], :, :]
            else:
                t_vft_fea = whole_vft_fea
            turn_based_vft_fea.append(t_vft_fea)

        for n in range(len(questions)):
            # accumulate the full dialogue history up to (excluding) turn n
            start_turn_idx = 0
            history = np.asarray([])
            turns = []
            q_turns = []
            a_turns = []
            for m in range(start_turn_idx, n):
                history = np.append(history, qa_pair[m])
                turns.append(qa_pair[m])
                q_turns.append(questions[m])
                a_turns.append(np.array(answer_output[m]))
            item = [vid_split, vid, qa_id, history, questions[n], answer_output[n],
                    turns, q_turns, a_turns, turn_based_vft_fea[n], gt_vid_periods[n],
                    programs[n], states[n], q_types[n], attribute_dependencies[n],
                    object_dependencies[n], temporal_dependencies[n],
                    spatial_dependencies[n], video_name, q_complexities[n]]
            dialog_list.append(item)
            qa_id += 1
    data = {'dialogs': dialog_list, 'vocab': vocab, 'answer': answer_list, 'features': []}
    return data
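
# Each emitted item is a positional record; its layout matches the `keys` list
# in create_dataset below, e.g. item[0] is the video split and item[9] the
# turn's (possibly truncated) video features.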


def create_dataset(data, vocab, split, args):
    """Wrap the flattened dialogue items in a Dataset and build its DataLoader."""
    keys = ['vid_split', 'vid', 'qa_id', 'history', 'question', 'answer', 'turns',
            'q_turns', 'a_turns', 'vft', 'gt_period',
            'program', 'state', 'q_type', 'attribute_dependency', 'object_dependency',
            'temporal_dependency', 'spatial_dependency', 'video_name', 'q_complexity']
    out = {key: [] for key in keys}
    for dialog in data['dialogs']:
        # each item from create_dials is ordered exactly like `keys`
        for key, value in zip(keys, dialog):
            out[key].append(value)
    dataset = Dataset(out)
    data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                              batch_size=args.batch_size,
                                              shuffle=(split == 'train'),
                                              collate_fn=partial(collate_fn, vocab=vocab),
                                              num_workers=args.num_workers,
                                              pin_memory=True)
    return data_loader, len(out['vid'])
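

if __name__ == '__main__':
    # Minimal end-to-end sketch of this module's pipeline, assuming
    # hypothetical paths and batching options; the real entry point builds
    # `args` with the project's argument parser.
    from types import SimpleNamespace

    args = SimpleNamespace(
        data_dir='data/dvd_dialogues/',  # assumed dialogue JSON root
        fea_dir='data/dvd_features/',    # assumed feature pickle root
        debug=True, batch_size=32, num_workers=0)

    dials, vid_set = load_dials(args, 'train')
    vocab, answer_list = get_vocabulary(dials, args)
    vft_data = load_video_features(args, vid_set)
    data = create_dials(dials, vocab, answer_list, vft_data, args)
    train_loader, n_items = create_dataset(data, vocab, 'train', args)
    print('built train loader with {} QA items'.format(n_items))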