initial commit

commit a82bbc593e
129 changed files with 33,981 additions and 0 deletions
datasets/__init__.py (new file, 0 lines)

datasets/avsd_dataset.py (new file, 205 lines)
@@ -0,0 +1,205 @@
# coding: utf-8
# author: noctli
import json
import os
import pickle
import logging
import numpy as np
from tqdm import tqdm
import torch
import torch.utils.data
from PIL import Image
from torch.utils.data import Dataset
from itertools import chain
from torchvision import transforms
from .utils import type_transform_helper
from .video_utils import read_frames_decord


def tokenize(text, tokenizer, return_tensor=False):
    tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
    if return_tensor:
        return torch.tensor(tokenized_text).long()
    return tokenized_text


def get_dataset(config, split):
    if split != 'test':
        dialog_pth = config[f'anno_avsd_{split}']
    else:
        dialog_pth = config['anno_avsd_test_dstc_{}'.format(config['dstc'])]
    n_history = config['num_hist_turns_avsd']
    undisclosed_only = split == 'test'
    dialog_data = json.load(open(dialog_pth, 'r'))
    dialog_list = []
    vid_set = set()
    pbar = tqdm(dialog_data['dialogs'])
    pbar.set_description('[INFO] Loading AVSD - {}'.format(split))
    for dialog in pbar:
        # if config['dstc'] != 10:
        caption = dialog['caption']
        summary = dialog['summary']
        # else:
        #     caption = 'no'
        #     summary = 'no'

        questions = [d['question'] for d in dialog['dialog']]
        answers = [d['answer'] for d in dialog['dialog']]
        vid = dialog["image_id"]
        vid_set.add(vid)
        if undisclosed_only:
            # at test time, only the last (undisclosed) turn is predicted
            it = range(len(questions) - 1, len(questions))
        else:
            it = range(len(questions))
        qalist = []
        history = []
        if undisclosed_only:
            for n in range(len(questions) - 1):
                qalist.append(questions[n])
                qalist.append(answers[n])
            history = qalist[max(-len(qalist), -n_history * 2):]
        for n in it:
            if undisclosed_only:
                assert dialog['dialog'][n]['answer'] == '__UNDISCLOSED__'
            question = questions[n]
            answer = answers[n]
            history.append(question)
            if n_history == 0:
                item = {'vid': vid, 'history': [question], 'answer': answer, 'caption': caption, 'summary': summary}
            else:
                item = {'vid': vid, 'history': history, 'answer': answer, 'caption': caption, 'summary': summary}

            dialog_list.append(item)
            qalist.append(question)
            qalist.append(answer)
            # keep a sliding window of the last n_history QA turns (2 entries per turn)
            history = qalist[max(-len(qalist), -n_history * 2):]

    return dialog_list
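
# Worked example (illustrative): with n_history=2 and a 3-turn dialog
# (q1, a1, q2, a2, q3, a3), the item built for q3 gets
# history == [q1, a1, q2, a2, q3], i.e. the last two QA turns plus the current question.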


def build_input_from_segments(caption, history, reply, tokenizer, drop_caption=False):
    """Build a sequence of input from 3 segments: caption (caption + summary), history, and the last reply."""

    bos, eos = tokenizer.convert_tokens_to_ids(['<s>', '</s>'])
    sep = eos

    instance = {}
    instance["lm_labels"] = reply + [eos]
    caption = list(chain(*caption))

    if not drop_caption:
        # sequence = [[bos] + list(chain(*caption))] + history + [reply + ([eos] if with_eos else [])]

        # NOTE: it is important not to include the reply in the input of the encoder --> the decoder
        # would just learn to copy it --> low train/val loss but no learning is happening
        sequence = [[bos] + caption + [eos]] + [[sep] + s for s in history] + [[eos]]
    else:
        sequence = [[bos]] + [[sep] + s for s in history] + [[eos]]

    instance["input_ids"] = list(chain(*sequence))
    return instance
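
# Worked example (illustrative token ids, assuming '<s>' -> 0 and '</s>' -> 2):
# caption=[[5, 6]], history=[[7], [8]], reply=[9] yield
#   input_ids == [0, 5, 6, 2, 2, 7, 2, 8, 2]  (the reply never enters the encoder input)
#   lm_labels == [9, 2]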


class AVSDDataSet(Dataset):
    def __init__(self, config, medium, vis_processor, text_processor, split):
        self.config = config
        self.medium = medium
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        self.split = split
        self.batch_size = config['batch_size_test_{}'.format(medium)] if split == 'test' else config['batch_size_{}'.format(medium)]
        self.root_vis = config['root_raw_vis_{}_{}'.format(medium, split)]
        self.dialogs = get_dataset(config, split)

        if split == 'test':
            self.dialogs = self.dialogs[config['start_idx_gen']: config['end_idx_gen']]

        num_samples = config['num_samples_{}'.format(self.medium)]
        if num_samples > 0:
            self.dialogs = self.dialogs[:num_samples]

    def __len__(self):
        return len(self.dialogs)

    def load_vid(self, vid_id):
        vid_dir_path = os.path.join(self.root_vis, vid_id + '.mp4')

        frames, _, _ = read_frames_decord(vid_dir_path, self.config.num_frames)
        frames = [self.vis_processor(f).unsqueeze(0) for f in frames]

        vis = torch.cat(frames, dim=0)  # (num_frames, C, H, W)
        return vis

    def load_vid_old(self, vid_id):
        # legacy loader that reads pre-extracted frames from a directory
        vid_dir_path = os.path.join(self.root_vis, vid_id)
        frame_paths = [os.path.join(vid_dir_path, f) for f in os.listdir(vid_dir_path)]
        frame_paths.sort()
        num_avail_frames = len(frame_paths)
        # guard against a zero step when fewer frames are available than requested
        delta = max(1, num_avail_frames // (self.config['num_frames'] - 1))
        ran = list(range(0, num_avail_frames, delta))
        if len(ran) < self.config['num_frames']:
            ran.extend([num_avail_frames - 1 for _ in range(self.config['num_frames'] - len(ran))])
        if len(ran) > self.config['num_frames']:
            ran = ran[:self.config['num_frames']]
        assert len(ran) == self.config['num_frames'], f"vid {vid_id} - loaded {len(ran)}/{len(frame_paths)} frames"
        frame_paths = [frame_paths[i] for i in ran]
        vis = [Image.open(p).convert('RGB') for p in frame_paths]
        vis = [transforms.PILToTensor()(v).unsqueeze(0) for v in vis]
        vis = torch.cat(vis, dim=0)
        vis = self.trans(vis)
        return vis

    def __getitem__(self, index):
        dialog = self.dialogs[index]
        vid_id = dialog['vid']

        caption = dialog['caption']
        summary = dialog['summary']
        history = dialog['history']
        answer = dialog['answer']

        caption = self.text_processor(caption)
        summary = self.text_processor(summary)
        if self.config.dstc != 10:
            caption = caption + ' ' + summary

        history = [self.text_processor(h) for h in history]
        answer = self.text_processor(answer, remove_period=True)

        if self.config.embed_from_llm:
            if self.config.llm_family in ['llama', 'mistral']:
                cls_tok = ''
                sep_tok = ' '
                bos_tok = '<s>'
                eos_tok = '</s>'
            else:
                cls_tok = '<s>'
                sep_tok = '</s>'
                bos_tok = '<pad>'
                eos_tok = '</s>'
        else:
            cls_tok = '[CLS]'
            sep_tok = '[SEP]'
            bos_tok = '[SEP]'
            eos_tok = '[SEP]'

        caption = cls_tok + caption + sep_tok
        history = sep_tok.join(history)
        history = history + sep_tok

        # load the video frames
        vis = self.load_vid(vid_id)

        return vis, caption, history, answer, vid_id


def load_avsd_dataset(config, vis_processor, text_processor, split):
    dataset = AVSDDataSet(config, 'avsd', vis_processor, text_processor, split)
    return dataset

datasets/champagne_dataset.py (new file, 279 lines)
@@ -0,0 +1,279 @@
# coding: utf-8
# author: noctli
import json
import os
import pickle
import logging
from tqdm import tqdm
import numpy as np
import torch
import torch.utils.data
from PIL import Image
from torch.utils.data import Dataset
from itertools import chain
from torchvision import transforms
from .utils import type_transform_helper


def tokenize(text, tokenizer, return_tensor=False):
    tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
    if return_tensor:
        return torch.tensor(tokenized_text).long()
    return tokenized_text


def get_dataset(config, split):
    dialog_pth = config['anno_visdial_{}'.format(split)]
    dialog_data = json.load(open(dialog_pth, 'r'))['data']
    all_answers = dialog_data['answers']
    all_questions = dialog_data['questions']
    dialog_list = []
    n_history = config['num_hist_turns']
    vid_set = set()

    pbar = tqdm(dialog_data['dialogs'])
    pbar.set_description('[INFO] Loading VisDial - {}'.format(split))
    for dialog in pbar:
        caption = dialog['caption']
        questions = [all_questions[d['question']] for d in dialog['dialog']]
        answers = [all_answers[d['answer']] for d in dialog['dialog']]

        vid = dialog["image_id"]
        vid_set.add(vid)
        # if undisclosed_only:
        #     it = range(len(questions) - 1, len(questions))
        # else:
        it = range(len(questions))
        qalist = []
        history = []
        # if undisclosed_only:
        #     for n in range(len(questions) - 1):
        #         qalist.append(questions[n])
        #         qalist.append(answers[n])
        #     history = qalist[max(-len(qalist), -n_history * 2):]

        for n in it:
            # if undisclosed_only:
            #     assert dialog['dialog'][n]['answer'] == '__UNDISCLOSED__'
            question = questions[n]
            answer = answers[n]
            history.append(question)
            if n_history == 0:
                item = {'vid': vid, 'history': [question], 'answer': answer, 'caption': caption}
            else:
                item = {'vid': vid, 'history': history, 'answer': answer, 'caption': caption}
            dialog_list.append(item)
            qalist.append(question)
            qalist.append(answer)
            history = qalist[max(-len(qalist), -n_history * 2):]
    return dialog_list


class Champagne(Dataset):
    def __init__(self, config, medium, vis_processor, text_processor, split):
        self.config = config
        self.medium = medium
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        self.split = split
        self.batch_size = config['batch_size_{}'.format(medium)]

        self.root_vis = config['root_raw_vis_{}_{}'.format(medium, split)]

        # get the mapping between caption and image/video
        mapping_path = config.get('mapping_path_{}_{}'.format(medium, split), None)
        with open(mapping_path, 'rb') as f:
            self.mapping = pickle.load(f)

        ids = list(self.mapping.keys())
        ids.sort()

        # reserve some samples for validation
        if split == 'train':
            self.ids = ids[config.num_val_samples:]
        elif split == 'val':
            self.ids = ids[:config.num_val_samples]

        num_samples = config['num_samples_{}'.format(self.medium)]
        if num_samples > 0:
            self.ids = self.ids[:num_samples]

    def __len__(self):
        return len(self.ids)

    def padding(self, seq, pad_token, max_len=None):
        if max_len is None:
            max_len = max([i.size(0) for i in seq])
        if len(seq[0].size()) == 1:
            result = torch.ones((len(seq), max_len)).long() * pad_token
        else:
            result = torch.ones((len(seq), max_len, seq[0].size(-1))).float()
        for i in range(len(seq)):
            result[i, :seq[i].size(0)] = seq[i]
        orig_len = [s.size(0) for s in seq]
        return result, orig_len
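
    # Worked example (illustrative): padding([torch.tensor([1, 2, 3]), torch.tensor([4])], pad_token=0)
    # returns (tensor([[1, 2, 3], [4, 0, 0]]), [3, 1]).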

    def __getitem__(self, index):
        item = self.mapping[self.ids[index]]
        # load the videos
        pth = os.path.join(self.root_vis, item['path'])
        f_names = os.listdir(pth)
        if len(f_names) == 0:
            # empty frame directory: fall back to a known-good emergency sample
            with open('/pfss/mlde/workspaces/mlde_wsp_Rohrbach/users/ma35vahy/V2Dial_new/emergency/item.pkl', 'rb') as f:
                item = pickle.load(f)

            # load the videos
            pth = os.path.join(self.root_vis, item['path'])
            f_names = os.listdir(pth)
        f_names.sort()

        if len(f_names) < self.config['num_frames']:
            f_names += [f_names[-1]] * (self.config['num_frames'] - len(f_names))
        elif len(f_names) > self.config['num_frames']:
            f_names = f_names[:self.config['num_frames']]

        pth = [os.path.join(pth, f_name) for f_name in f_names]
        try:
            vis = [Image.open(p).convert('RGB') for p in pth]
        except Exception:
            # corrupt frames: fall back to the emergency sample as well
            with open('/pfss/mlde/workspaces/mlde_wsp_Rohrbach/users/ma35vahy/V2Dial_new/emergency/item.pkl', 'rb') as f:
                item = pickle.load(f)

            # load the videos
            pth = os.path.join(self.root_vis, item['path'])
            f_names = os.listdir(pth)
            f_names.sort()

            pth = [os.path.join(pth, f_name) for f_name in f_names]
            vis = [Image.open(p).convert('RGB') for p in pth]

        vis = [self.vis_processor(v).unsqueeze(0) for v in vis]
        vis = torch.cat(vis, dim=0)

        dialog = item['dialog']

        caption = dialog['caption']
        history = dialog['history']
        answer = dialog['answer']

        caption = self.text_processor(caption)
        history = [self.text_processor(h) for h in history]
        answer = self.text_processor(answer, remove_period=True)

        if self.config.embed_from_llm:
            if self.config.llm_family in ['llama', 'mistral']:
                cls_tok = ''
                sep_tok = ''
                bos_tok = '<s>'
                eos_tok = '</s>'
            else:
                cls_tok = '<s>'
                sep_tok = '</s>'
                bos_tok = '<pad>'
                eos_tok = '</s>'
        else:
            cls_tok = '[CLS]'
            sep_tok = '[SEP]'
            bos_tok = '[SEP]'
            eos_tok = '[SEP]'

        # preprocess the textual data
        caption = cls_tok + caption + sep_tok
        history = sep_tok.join(history)
        history = history + sep_tok
        # if self.config.llm_family == 'flan_t5':
        #     answer = '<s> ' + self.text_processor(answer) + ' </s>'
        # else:
        #     answer = self.text_processor(answer) + eos_tok

        return vis, caption, history, answer

    # (disabled collate_fn, kept for reference)
    # def collate_fn(self, batch):
    #     BOS, EOS, SEP = self.tokenizer_enc_dec.convert_tokens_to_ids(['<s>', '</s>', '</s>'])
    #
    #     vis_list, cap_list, hist_list, ques_list, ans_list, index_list, vid_id_list = [], [], [], [], [], [], []
    #     batch_size = len(batch)
    #     for b in batch:
    #         vis_list.append(b[0])
    #         cap = [BOS] + tokenize(b[1], self.tokenizer_enc_dec) + [EOS]
    #         cap_list.append(torch.tensor(cap))
    #         if len(b[2]) != 0:
    #             hist = [[SEP] + tokenize(s, self.tokenizer_enc_dec) for s in b[2]] + [[EOS]]
    #             hist_list.append(torch.tensor(list(chain(*hist))))
    #         else:
    #             hist = [SEP] + tokenize(b[3], self.tokenizer_enc_dec) + [EOS]
    #             hist_list.append(torch.tensor(hist))
    #
    #         ques = tokenize(b[3], self.tokenizer_enc_dec) + [EOS]
    #         ques_list.append(torch.tensor(ques))
    #         ans = tokenize(b[4], self.tokenizer_enc_dec) + [EOS]
    #         ans_list.append(torch.tensor(ans))
    #         index_list.append(b[5])
    #         vid_id_list.append(b[6])
    #
    #     # pad and keep track of the original lengths
    #     cap_input_ids, cap_orig_lens = self.padding(cap_list, self.tokenizer_experts.pad_token_id)
    #     hist_input_ids, hist_orig_lens = self.padding(hist_list, self.tokenizer_experts.pad_token_id)
    #     ques_input_ids, ques_orig_lens = self.padding(ques_list, self.tokenizer_experts.pad_token_id)
    #     ans_input_ids, _ = self.padding(ans_list, -100)
    #
    #     cap_attention_mask = cap_input_ids != self.tokenizer_experts.pad_token_id
    #     hist_attention_mask = hist_input_ids != self.tokenizer_experts.pad_token_id
    #     ques_attention_mask = ques_input_ids != self.tokenizer_experts.pad_token_id
    #
    #     total_orig_lens = [sum(l) for l in zip(cap_orig_lens, hist_orig_lens, ques_orig_lens)]
    #     max_len = max(total_orig_lens)
    #
    #     dummy_input_ids_enc_dec = torch.full((batch_size, max_len), self.tokenizer_experts.pad_token_id)
    #     enc_dec_attention_mask = torch.zeros_like(dummy_input_ids_enc_dec, dtype=torch.bool)
    #     for i, l in enumerate(total_orig_lens):
    #         enc_dec_attention_mask[i][:l] = True
    #     # add the masking of the visual input
    #     num_query_tok = self.config['num_temporal_query_tokens_{}'.format(self.config['bert_size'])]
    #     if self.medium in ['avsd', 'msrvtt', 'webvid', 'champagne']:
    #         vis_attention_mask = torch.ones((batch_size, 2 * num_query_tok), dtype=torch.bool)  # *2 for spatial and temporal queries
    #     else:
    #         vis_attention_mask = torch.ones((batch_size, num_query_tok), dtype=torch.bool)  # only spatial queries
    #
    #     enc_dec_attention_mask = torch.concat((vis_attention_mask, enc_dec_attention_mask), dim=1)
    #     # now prepare the data
    #     vis = torch.stack(vis_list, dim=0)
    #     cap = {'input_ids': cap_input_ids, 'attention_mask': cap_attention_mask, 'orig_lens': cap_orig_lens}
    #     hist = {'input_ids': hist_input_ids, 'attention_mask': hist_attention_mask, 'orig_lens': hist_orig_lens}
    #     ques = {'input_ids': ques_input_ids, 'attention_mask': ques_attention_mask, 'orig_lens': ques_orig_lens}
    #     ans = {'input_ids': ans_input_ids}
    #     enc_dec_input = {'input_ids': dummy_input_ids_enc_dec, 'attention_mask': enc_dec_attention_mask}
    #
    #     index = torch.tensor(index_list)
    #     return vis, cap, hist, ques, ans, enc_dec_input, index, vid_id_list


def load_champagne_dataset(config, vis_processor, text_processor, split):
    dataset = Champagne(config, 'champagne', vis_processor, text_processor, split)
    return dataset

datasets/dataloader.py (new file, 137 lines)
@@ -0,0 +1,137 @@
"""
|
||||
From https://github.com/klauscc/VindLU/blob/main/dataset/dataloader.py
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, Dataset, ConcatDataset
|
||||
import torch.distributed as dist
|
||||
from utils.dist import *
|
||||
import random
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MetaLoader(object):
|
||||
""" wraps multiple data loader """
|
||||
def __init__(self, name2loader):
|
||||
"""Iterates over multiple dataloaders, it ensures all processes
|
||||
work on data from the same dataloader. This loader will end when
|
||||
the shorter dataloader raises StopIteration exception.
|
||||
|
||||
loaders: Dict, {name: dataloader}
|
||||
"""
|
||||
self.name2loader = name2loader
|
||||
self.name2iter = {name: iter(l) for name, l in name2loader.items()}
|
||||
name2index = {name: idx for idx, (name, l) in enumerate(name2loader.items())}
|
||||
index2name = {v: k for k, v in name2index.items()}
|
||||
|
||||
iter_order = []
|
||||
for n, l in name2loader.items():
|
||||
iter_order.extend([name2index[n]]*len(l))
|
||||
|
||||
random.shuffle(iter_order)
|
||||
iter_order = torch.Tensor(iter_order).to(torch.device("cuda")).to(torch.uint8)
|
||||
|
||||
# sync
|
||||
if is_dist_avail_and_initialized():
|
||||
# make sure all processes have the same order so that
|
||||
# each step they will have data from the same loader
|
||||
dist.broadcast(iter_order, src=0)
|
||||
self.iter_order = [index2name[int(e.item())] for e in iter_order.cpu()]
|
||||
|
||||
logger.info(str(self))
|
||||
|
||||
def __str__(self):
|
||||
output = [f"MetaLoader has {len(self.name2loader)} dataloaders, {len(self)} batches in total"]
|
||||
for idx, (name, loader) in enumerate(self.name2loader.items()):
|
||||
output.append(
|
||||
f"dataloader index={idx} name={name}, batch-size={loader.batch_size} length(#batches)={len(loader)} "
|
||||
)
|
||||
return "\n".join(output)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.iter_order)
|
||||
|
||||
def __iter__(self):
|
||||
""" this iterator will run indefinitely """
|
||||
for name in self.iter_order:
|
||||
_iter = self.name2iter[name]
|
||||
batch = next(_iter)
|
||||
yield name, batch
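
# Usage sketch (assumed training-loop wiring; not part of this file):
#   meta_loader = MetaLoader({'avsd': avsd_loader, 'webvid': webvid_loader})
#   for name, batch in meta_loader:      # one pass = len(meta_loader) batches
#       loss = step_fns[name](batch)     # dispatch on the originating dataset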


def load_dataloaders(config, datasets, split, output_dict=False):
    if isinstance(datasets, dict):
        datasets = list(datasets.values())
    shuffles = [True] * len(datasets) if split == 'train' else [False] * len(datasets)
    if config['distributed'] and split != 'test':
        num_tasks = get_world_size()
        global_rank = get_rank()
        samplers = create_samplers(
            datasets, shuffles, num_tasks, global_rank
        )
    else:
        samplers = [None] * len(datasets)

    batch_size = [dataset.datasets[0].batch_size if isinstance(dataset, ConcatDataset) else dataset.batch_size for dataset in datasets]
    collate_fns = []
    for dataset in datasets:
        if isinstance(dataset, ConcatDataset):
            collate_fns.append(getattr(dataset.datasets[0], 'collate_fn', None))
        else:
            collate_fns.append(getattr(dataset, 'collate_fn', None))

    loaders = create_loader(
        datasets,
        samplers,
        batch_size=batch_size,
        num_workers=[config.num_workers] * len(datasets),
        is_trains=shuffles,
        collate_fns=collate_fns,
    )
    loaders_dict = {}
    if output_dict:
        for l in loaders:
            if isinstance(l.dataset, ConcatDataset):
                loaders_dict[l.dataset.datasets[0].medium] = l
            else:
                loaders_dict[l.dataset.medium] = l
        return loaders_dict
    return loaders


def create_samplers(datasets, shuffles, num_tasks, global_rank):
    samplers = []
    for dataset, shuffle in zip(datasets, shuffles):
        sampler = torch.utils.data.DistributedSampler(
            dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle
        )
        samplers.append(sampler)
    return samplers


def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
    loaders = []
    for dataset, sampler, bs, n_worker, is_train, collate_fn in zip(
        datasets, samplers, batch_size, num_workers, is_trains, collate_fns
    ):
        if is_train:
            shuffle = sampler is None
            drop_last = True
        else:
            shuffle = False
            drop_last = True
        loader = DataLoader(
            dataset,
            batch_size=bs,
            num_workers=n_worker,
            pin_memory=False,
            sampler=sampler,
            shuffle=shuffle,
            collate_fn=collate_fn,
            drop_last=drop_last,
            persistent_workers=True if n_worker > 0 else False,
        )
        loaders.append(loader)
    return loaders
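
# Usage sketch (assumed call site; 'datasets' as returned by the dataset loaders in this package):
#   loaders = load_dataloaders(config, {'avsd': avsd_dataset}, split='train', output_dict=True)
#   train_loader = loaders['avsd']   # keyed by each dataset's .medium attribute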

datasets/nextqa_dataset.py (new file, 86 lines)
@@ -0,0 +1,86 @@
import os
import pandas as pd
# import h5py
import json
import numpy as np
import torch
from torch.utils.data import Dataset
from .video_utils import read_frames_decord


def load_file(file_name):
    annos = None
    if os.path.splitext(file_name)[-1] == '.csv':
        return pd.read_csv(file_name)
    with open(file_name, 'r') as fp:
        if os.path.splitext(file_name)[1] == '.txt':
            annos = fp.readlines()
            annos = [line.rstrip() for line in annos]
        if os.path.splitext(file_name)[1] == '.json':
            annos = json.load(fp)
    return annos


class NextQADataset(Dataset):
    def __init__(self, config, medium, vis_processor, text_processor, split):
        super().__init__()
        self.config = config
        self.medium = medium
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        self.split = split

        self.batch_size = config['batch_size_test_{}'.format(medium)] if split == 'test' else config['batch_size_{}'.format(medium)]
        self.root_vis = config['root_raw_vis_{}_{}'.format(medium, split)]
        with open(config['vid_mapping_nextqa'], 'r') as f:
            self.video_mapping = json.load(f)

        self.sample_list = load_file(self.config['anno_nextqa_{}'.format(split)])

        if split == 'test':
            self.sample_list = self.sample_list[config['start_idx_gen']: config['end_idx_gen']]
            self.captions = load_file(self.config['next_qa_captions_{}'.format(split)])
        else:
            self.captions = None

        num_samples = config['num_samples_{}'.format(self.medium)]
        if num_samples > 0:
            self.sample_list = self.sample_list[:num_samples]

    def __len__(self):
        return len(self.sample_list)

    def load_vid(self, vid_id):
        vid_dir_path = os.path.join(self.root_vis, self.video_mapping[vid_id] + '.mp4')

        frames, _, _ = read_frames_decord(vid_dir_path, self.config.num_frames)
        frames = [self.vis_processor(f).unsqueeze(0) for f in frames]

        vis = torch.cat(frames, dim=0)
        return vis

    def __getitem__(self, idx):
        if self.split == 'test':
            idx += self.config['start_idx_gen']

        cur_sample = self.sample_list.loc[idx]
        video_id, ques, ans, qid = str(cur_sample['video']), str(cur_sample['question']), \
            str(cur_sample['answer']), str(cur_sample['qid'])

        history = self.text_processor(ques)
        answer = self.text_processor(ans)
        if self.split == 'test':
            caption = self.text_processor(self.captions[video_id])
        else:
            caption = self.text_processor('please answer the following question based on the video')
        vis = self.load_vid(video_id)

        return vis, caption, history, answer, video_id, qid


def load_nextqa_dataset(config, vis_processor, text_processor, split):
    dataset = NextQADataset(config, 'nextqa', vis_processor, text_processor, split)
    return dataset

datasets/pretraining.py (new file, 156 lines)
@@ -0,0 +1,156 @@
from torch.utils.data import Dataset
import pickle
import os

import torch
from PIL import Image
import numpy as np
from torchvision import transforms
import random

from .utils import pre_text, type_transform_helper, load_anno, open_img


class CapDataset(Dataset):
    def __init__(self, config, medium, vis_processor, text_processor, split):
        super(CapDataset, self).__init__()
        self.config = config
        self.batch_size = config['batch_size_{}'.format(medium)]
        self.medium = medium  # webvid / cc3m / msrvtt
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        self.split = split  # train / val / test

        self.root_vis = config['root_raw_vis_{}_{}'.format(medium, split)]

        # get the mapping between caption and image/video
        mapping_path = config.get('mapping_path_{}_{}'.format(medium, split), None)
        with open(mapping_path, 'rb') as f:
            self.mapping = pickle.load(f)

        # these are the main ids of the dataset (typically one per image/video)
        self.ids = list(self.mapping.keys())
        num_samples = config['num_samples_{}'.format(self.medium)]
        if num_samples > 0:
            self.ids = self.ids[:num_samples]

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, index):
        item = self.mapping[self.ids[index]]
        ############################# Textual features #############################
        caption = item['caption']
        caption = self.text_processor(caption)
        # add [CLS] token
        caption = '[CLS] ' + caption

        if self.medium == 'cc3m':  # single image
            pth = os.path.join(self.root_vis, item['file'])
            vis = open_img(pth)
            vis = self.vis_processor(vis).unsqueeze(0)
        else:  # video: directory of numbered frames
            pth = os.path.join(self.root_vis, item['file'])
            f_names = os.listdir(pth)
            f_names.sort(key=lambda f_n: int(f_n.split('.')[0]))
            pth = [os.path.join(pth, f_name) for f_name in f_names]
            vis = [Image.open(p).convert('RGB') for p in pth]
            vis = [self.vis_processor(v).unsqueeze(0) for v in vis]
            vis = torch.cat(vis, dim=0)

        # sample a random negative visual example (different index)
        neg_index = random.randint(0, len(self) - 1)
        while neg_index == index:
            neg_index = random.randint(0, len(self) - 1)

        neg_item = self.mapping[self.ids[neg_index]]

        if self.medium == 'cc3m':
            neg_pth = os.path.join(self.root_vis, neg_item['file'])
            neg_vis = open_img(neg_pth)
            neg_vis = self.vis_processor(neg_vis).unsqueeze(0)
        else:
            neg_pth = os.path.join(self.root_vis, neg_item['file'])
            neg_f_names = os.listdir(neg_pth)
            neg_f_names.sort(key=lambda f_n: int(f_n.split('.')[0]))
            neg_pth = [os.path.join(neg_pth, neg_f_name) for neg_f_name in neg_f_names]
            neg_vis = [Image.open(p).convert('RGB') for p in neg_pth]
            neg_vis = [self.vis_processor(v).unsqueeze(0) for v in neg_vis]
            neg_vis = torch.cat(neg_vis, dim=0)

        return vis, caption, neg_vis


class VideoTextRetDataset(Dataset):
    def __init__(self, config, vis_processor, text_processor, medium, split):
        super(VideoTextRetDataset, self).__init__()
        self.config = config
        self.batch_size = config['batch_size_{}'.format(medium)]
        self.medium = medium  # webvid / cc3m / msrvtt
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        self.split = split  # train / val / test

        self.root_vis = config['root_raw_vis_{}_{}'.format(medium, split)]

        anno_path = config['annotation_{}_{}'.format(medium, split)]
        self.raw_anno_list = load_anno(anno_path)
        self.text = []
        self.vis = []
        self.txt2vis = {}
        self.vis2txt = {}
        self.build_data()
        self.anno_list = [dict(vis=v) for v in self.vis]

    def __len__(self):
        return len(self.anno_list)

    def __getitem__(self, index):
        pth = self.anno_list[index]['vis']
        f_names = os.listdir(pth)
        f_names.sort(key=lambda f_n: int(f_n.split('.')[0]))
        pth = [os.path.join(pth, f_name) for f_name in f_names]
        vis = [Image.open(p).convert('RGB') for p in pth]
        vis = [self.vis_processor(v) for v in vis]
        vis = torch.cat(vis, dim=0)

        return vis, index

    def build_data(self):
        """Each image may have multiple ground-truth texts, e.g., COCO and Flickr30K."""
        txt_id = 0
        for vis_id, ann in enumerate(self.raw_anno_list):
            self.vis.append(ann["vis"])
            self.vis2txt[vis_id] = []
            _captions = ann["caption"] \
                if isinstance(ann["caption"], list) else [ann["caption"], ]
            for i, caption in enumerate(_captions):
                self.text.append(self.text_processor(caption))
                self.vis2txt[vis_id].append(txt_id)
                self.txt2vis[txt_id] = vis_id
                txt_id += 1


def load_datasets(config, vis_processor, text_processor, split):
    if config['stage'] == 'stage_1':
        if split != 'test':
            cc3m_dataset = CapDataset(config, 'cc3m', vis_processor, text_processor, split)
            webvid_dataset = CapDataset(config, 'webvid', vis_processor, text_processor, split)
            datasets = {
                'cc3m': cc3m_dataset,
                'webvid': webvid_dataset
            }
        else:  # test with MSRVTT-1k --> video retrieval
            msrvtt_dataset = VideoTextRetDataset(config, vis_processor, text_processor, 'msrvtt', split)
            datasets = {
                'msrvtt': msrvtt_dataset
            }
    return datasets

datasets/utils.py (new file, 83 lines)
@@ -0,0 +1,83 @@
import os
import re
import json
from tqdm import trange
from utils.dist import is_main_process
from torch.utils.data import Dataset, ConcatDataset
from PIL import Image
import numpy as np


def open_img(img_pth):
    try:
        img = Image.open(img_pth).convert('RGB')
        return img
    except Exception:
        # unreadable/corrupt image: fall back to a random RGB image
        # (uint8 is required for an 8-bit RGB PIL image)
        img = np.random.randint(0, high=256, size=(224, 224, 3), dtype=np.uint8)
        img = Image.fromarray(img, 'RGB')
        return img


def pre_text(text, max_l=None):
    text = re.sub(r"(['!?\"()*#:;~])", '', text.lower())
    text = text.replace('-', ' ').replace('/', ' ').replace('<person>', 'person')

    text = re.sub(r"\s{2,}", ' ', text)
    text = text.rstrip('\n').strip(' ')

    if max_l:  # truncate
        words = text.split(' ')
        if len(words) > max_l:
            text = ' '.join(words[:max_l])
    return text
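
# Example (illustrative): pre_text("It's a <person> -- RIDING!") == 'its a person riding'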


def get_datasets_media(dataloaders):
    media = {}
    for dataloader in dataloaders:
        if isinstance(dataloader.dataset, ConcatDataset):
            media[dataloader.dataset.datasets[0].medium] = dataloader
        else:
            media[dataloader.dataset.medium] = dataloader
    return media


def type_transform_helper(x):
    return x.float().div(255.0)


def load_anno(ann_file_list):
    """Load annotation files and attach the resolved data path to each entry.

    Args:
        ann_file_list (List[List[str, str]] or List[str, str]):
            the latter will be automatically converted to the former.
            Each sublist contains [anno_path, image_root] (or [anno_path, video_root, 'video']),
            which specifies the data type, video or image.

    Returns:
        List(dict): each dict is {
            image: str or List[str],  # image path
            caption: str or List[str]  # caption text string
        }
    """
    if isinstance(ann_file_list[0], str):
        ann_file_list = [ann_file_list]

    ann = []
    for d in ann_file_list:
        data_root = d[1]
        fp = d[0]
        is_video = len(d) == 3 and d[2] == "video"
        cur_ann = json.load(open(fp, "r"))
        iterator = trange(len(cur_ann), desc=f"Loading {fp}") \
            if is_main_process() else range(len(cur_ann))
        for idx in iterator:
            key = "video" if is_video else "image"
            # strip the 'video'/'image' prefix and the extension, e.g. 'video9770.mp4' -> '9770'
            video_id = cur_ann[idx][key][5:].split('.')[0]
            # unified to have the same key for the data path
            cur_ann[idx]["vis"] = os.path.join(data_root, video_id)
        ann += cur_ann
    return ann
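
# Example (hypothetical paths): load_anno(['anno/msrvtt_test.json', '/data/msrvtt/frames', 'video'])
# turns an entry {'video': 'video9770.mp4', 'caption': '...'} into
# {'video': 'video9770.mp4', 'caption': '...', 'vis': '/data/msrvtt/frames/9770'}.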

datasets/video_utils.py (new file, 97 lines)
@@ -0,0 +1,97 @@
"""
|
||||
Modified from https://github.com/m-bain/frozen-in-time/blob/22a91d78405ec6032fdf521ae1ff5573358e632f/base/base_dataset.py
|
||||
"""
|
||||
import random
|
||||
import decord
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import math
|
||||
decord.bridge.set_bridge("torch")
|
||||
|
||||
|
||||
def pts_to_secs(pts: int, time_base: float, start_pts: int) -> float:
|
||||
"""
|
||||
Converts a present time with the given time base and start_pts offset to seconds.
|
||||
|
||||
Returns:
|
||||
time_in_seconds (float): The corresponding time in seconds.
|
||||
|
||||
https://github.com/facebookresearch/pytorchvideo/blob/main/pytorchvideo/data/utils.py#L54-L64
|
||||
"""
|
||||
if pts == math.inf:
|
||||
return math.inf
|
||||
|
||||
return int(pts - start_pts) * time_base
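
# Worked example: with a time base of 1/90000 (common for 90 kHz video streams),
# pts_to_secs(90000, 1 / 90000, 0) == 1.0 seconds.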


def get_pyav_video_duration(video_reader):
    video_stream = video_reader.streams.video[0]
    video_duration = pts_to_secs(
        video_stream.duration,
        video_stream.time_base,
        video_stream.start_time
    )
    return float(video_duration)


def get_frame_indices_by_fps():
    pass


def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1):
    if sample in ["rand", "middle"]:
        acc_samples = min(num_frames, vlen)
        # split the video into `acc_samples` intervals, and sample from each interval
        intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
        ranges = []
        for idx, interv in enumerate(intervals[:-1]):
            ranges.append((interv, intervals[idx + 1] - 1))
        if sample == 'rand':
            try:
                frame_indices = [random.choice(range(x[0], x[1])) for x in ranges]
            except Exception:
                # some intervals may be empty for very short videos
                frame_indices = np.random.permutation(vlen)[:acc_samples]
                frame_indices.sort()
                frame_indices = list(frame_indices)
        elif fix_start is not None:
            frame_indices = [x[0] + fix_start for x in ranges]
        elif sample == 'middle':
            frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
        else:
            raise NotImplementedError

        if len(frame_indices) < num_frames:  # pad with the last frame
            padded_frame_indices = [frame_indices[-1]] * num_frames
            padded_frame_indices[:len(frame_indices)] = frame_indices
            frame_indices = padded_frame_indices
    elif "fps" in sample:  # e.g. fps0.5: sequentially sample frames at 0.5 fps
        output_fps = float(sample[3:])
        duration = float(vlen) / input_fps
        delta = 1 / output_fps  # gap between frames; also the clip length each frame represents
        frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
        frame_indices = np.around(frame_seconds * input_fps).astype(int)
        frame_indices = [e for e in frame_indices if e < vlen]
        if max_num_frames > 0 and len(frame_indices) > max_num_frames:
            frame_indices = frame_indices[:max_num_frames]
    else:
        raise ValueError
    return frame_indices
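
# Worked example: num_frames=4, vlen=100, sample='middle' gives intervals
# [0, 25, 50, 75, 100], ranges [(0, 24), (25, 49), (50, 74), (75, 99)], and
# frame_indices [12, 37, 62, 87] (the midpoint of each range).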


def read_frames_decord(video_path, num_frames, sample='rand', fix_start=None, max_num_frames=-1):
    video_reader = decord.VideoReader(video_path, num_threads=1)
    vlen = len(video_reader)
    fps = video_reader.get_avg_fps()
    duration = vlen / float(fps)
    frame_indices = get_frame_indices(
        num_frames, vlen, sample=sample, fix_start=fix_start,
        input_fps=fps, max_num_frames=max_num_frames
    )
    frames = video_reader.get_batch(frame_indices)  # (T, H, W, C), torch.uint8
    frames = frames.split(1, dim=0)

    # PIL expects (H, W, C) arrays, so convert before any channel permute
    frames = [Image.fromarray(f.squeeze(0).numpy(), mode='RGB') for f in frames]
    return frames, frame_indices, duration
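
# Example (hypothetical path): frames, indices, duration = read_frames_decord('clip.mp4', num_frames=8)
# -> a list of 8 PIL RGB frames, the sampled frame indices, and the clip length in seconds.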

datasets/visdial_dataset.py (new file, 183 lines)
@@ -0,0 +1,183 @@
# coding: utf-8
# author: noctli
import json
import os
import pickle
import logging
from tqdm import tqdm
import numpy as np
import torch
import torch.utils.data
from PIL import Image
from torch.utils.data import Dataset
from itertools import chain
from torchvision import transforms
from .utils import type_transform_helper
from .utils import open_img


def tokenize(text, tokenizer, return_tensor=False):
    tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
    if return_tensor:
        return torch.tensor(tokenized_text).long()
    return tokenized_text


def get_dataset(config, split):
    dialog_pth = config['anno_visdial_{}'.format(split)]
    dialog_data = json.load(open(dialog_pth, 'r'))['data']

    all_answers = dialog_data['answers']
    all_questions = dialog_data['questions']
    dialog_list = []
    n_history = config['num_hist_turns_visdial']
    vid_set = set()
    undisclosed_only = False

    pbar = tqdm(dialog_data['dialogs'])
    pbar.set_description('[INFO] Loading VisDial - {}'.format(split))
    for dialog in pbar:
        caption = dialog['caption'] + ' .'
        questions = [all_questions[d['question']] + ' ?' for d in dialog['dialog']]
        answers = [all_answers[d['answer']] + ' .' for d in dialog['dialog']]

        # answer_opts = [[all_answers[key] for key in d['answer_options']] for d in dialog['dialog']]
        # if 'test' in config['anno_visdial_{}'.format(split)]:
        #     gt_indices = [-1 for _ in range(len(questions))]
        # else:
        #     gt_indices = [d['gt_index'] for d in dialog['dialog']]

        vid = dialog['image_id']
        vid_set.add(vid)
        if undisclosed_only:
            it = range(len(questions) - 1, len(questions))
        else:
            it = range(len(questions))

        qalist = []
        history = []
        if undisclosed_only:
            for n in range(len(questions) - 1):
                qalist.append(questions[n])
                qalist.append(answers[n])
            history = qalist[max(-len(qalist), -n_history * 2):]

        for n in it:
            if undisclosed_only:
                assert dialog['dialog'][n]['answer'] == '__UNDISCLOSED__'
            question = questions[n]
            answer = answers[n]
            # answer_opt = answer_opts[n]
            # gt_index = gt_indices[n]
            history.append(question)

            if n_history == 0:
                item = {'vid': vid, 'history': [question], 'answer': answer, 'caption': caption, 'round': n + 1}
            else:
                item = {'vid': vid, 'history': history, 'answer': answer, 'caption': caption, 'round': n + 1}

            dialog_list.append(item)
            qalist.append(question)
            qalist.append(answer)
            history = qalist[max(-len(qalist), -n_history * 2):]

    return dialog_list


class VisDial(Dataset):
    def __init__(self, config, medium, vis_processor, text_processor, split):
        self.config = config
        self.medium = medium
        self.split = split
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        self.batch_size = config['batch_size_test_{}'.format(medium)] if split == 'test' else config['batch_size_{}'.format(medium)]
        self.root_vis = config['root_raw_vis_{}_{}'.format(medium, split)]

        self.dialogs = get_dataset(config, split)

        if split == 'test':
            self.dialogs = self.dialogs[config['start_idx_gen']: config['end_idx_gen']]

        num_samples = config['num_samples_{}'.format(self.medium)]
        if num_samples > 0:
            self.dialogs = self.dialogs[:num_samples]

    def __len__(self):
        return len(self.dialogs)

    def load_img(self, vid_id):
        file_pth = os.path.join(self.root_vis, f'{vid_id}.jpg')
        vis = open_img(file_pth)
        vis = self.vis_processor(vis).unsqueeze(0)
        return vis

    def __getitem__(self, index):
        dialog = self.dialogs[index]

        vid_id = dialog['vid']
        caption = dialog['caption']
        history = dialog['history']
        answer = dialog['answer']
        d_round = dialog['round']

        caption = self.text_processor(caption)
        history = [self.text_processor(h) for h in history]
        answer = self.text_processor(answer, remove_period=True)

        # retrieval-style evaluation data (currently disabled):
        # if self.split == 'test':
        #     answer_opts = dialog['answer_opts']
        #     answer_opts = [self.text_processor(a) for a in answer_opts]
        #     gt_index = dialog['gt_index']
        #     dialog_round = dialog['round']
        #     dense_key = str(vid_id) + '_' + str(dialog_round)
        #     gt_relevance = self.dense_annos.get(dense_key, -1)
        #     # eval_data = (answer_opts, gt_index, gt_relevance)

        if self.config.embed_from_llm:
            if self.config.llm_family in ['llama', 'mistral']:
                cls_tok = ''
                sep_tok = ' '
                bos_tok = '<s>'
                eos_tok = '</s>'
            else:
                cls_tok = '<s>'
                sep_tok = '</s>'
                bos_tok = '<pad>'
                eos_tok = '</s>'
        else:
            cls_tok = '[CLS]'
            sep_tok = '[SEP]'
            bos_tok = '[SEP]'
            eos_tok = '[SEP]'

        caption = cls_tok + caption + sep_tok
        history = sep_tok.join(history)
        history = history + sep_tok

        # load the image
        vis = self.load_img(vid_id)

        # if self.split == 'test':
        #     return vis, caption, history, answer, vid_id, answer_opts, gt_relevance, gt_index
        # else:
        return vis, caption, history, answer, vid_id, d_round


def load_visdial_dataset(config, vis_processor, text_processor, split):
    dataset = VisDial(config, 'visdial', vis_processor, text_processor, split)
    return dataset