initial commit
This commit is contained in:
commit
a82bbc593e
129 changed files with 33981 additions and 0 deletions
279
datasets/champagne_dataset.py
Normal file
279
datasets/champagne_dataset.py
Normal file
|
@ -0,0 +1,279 @@
|
|||
# coding: utf-8
|
||||
# author: noctli
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import logging
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.utils.data
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
from itertools import chain
|
||||
from torchvision import transforms
|
||||
from .utils import type_transform_helper
|
||||
|
||||
|
||||
def tokenize(text, tokenizer, return_tensor=False):
    """Convert *text* into token ids using *tokenizer*.

    Returns a plain list of ids, or a 1-D ``torch.long`` tensor when
    ``return_tensor`` is True.
    """
    ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
    if not return_tensor:
        return ids
    return torch.tensor(ids).long()
|
||||
|
||||
|
||||
def get_dataset(config, split):
    """Load the VisDial annotations for *split* and flatten them into a
    list of per-turn samples.

    Each sample dict carries the image id (``vid``), the caption, the
    dialog ``history`` up to and including the current question
    (truncated to the last ``num_hist_turns`` QA pairs), and the
    ground-truth ``answer``.

    Args:
        config: dict-like configuration; must provide
            ``anno_visdial_<split>`` (path to the annotation json) and
            ``num_hist_turns``.
        split: dataset split name used to resolve the annotation path.

    Returns:
        list[dict]: one entry per dialog turn.
    """
    anno_path = config['anno_visdial_{}'.format(split)]
    # fix: use a context manager instead of json.load(open(...)),
    # which leaked the file handle
    with open(anno_path, 'r') as f:
        data = json.load(f)['data']

    answer_pool = data['answers']
    question_pool = data['questions']
    n_history = config['num_hist_turns']

    dialog_list = []
    vid_set = set()

    pbar = tqdm(data['dialogs'])
    pbar.set_description('[INFO] Loading VisDial - {}'.format(split))
    for dialog in pbar:
        caption = dialog['caption']
        turns = dialog['dialog']
        questions = [question_pool[t['question']] for t in turns]
        answers = [answer_pool[t['answer']] for t in turns]

        vid = dialog["image_id"]
        vid_set.add(vid)

        qalist = []
        history = []
        for question, answer in zip(questions, answers):
            history.append(question)
            # n_history == 0 means "current question only"
            hist = [question] if n_history == 0 else history
            dialog_list.append(
                {'vid': vid, 'history': hist, 'answer': answer, 'caption': caption}
            )
            qalist.append(question)
            qalist.append(answer)
            # rebind (not mutate) so previously stored histories are untouched
            history = qalist[max(-len(qalist), -n_history * 2):]
    return dialog_list
|
||||
|
||||
|
||||
class Champagne(Dataset):
    """Champagne video-dialog dataset.

    ``__getitem__`` yields ``(vis, caption, history, answer)`` where
    ``vis`` is a ``(num_frames, C, H, W)`` float tensor of processed
    frames and the three text fields are strings already wrapped with
    the special tokens expected by the configured language backbone.
    """

    # HACK: hard-coded fallback sample, loaded whenever a video's frame
    # directory is empty or a frame fails to decode.
    # TODO(review): make this path configurable.
    _EMERGENCY_ITEM_PATH = '/pfss/mlde/workspaces/mlde_wsp_Rohrbach/users/ma35vahy/V2Dial_new/emergency/item.pkl'

    def __init__(self, config, medium, vis_processor, text_processor, split):
        """
        Args:
            config: experiment configuration (dict-like with attribute access).
            medium: dataset key used to resolve config entries (e.g. 'champagne').
            vis_processor: callable mapping a PIL image to a tensor.
            text_processor: callable normalizing raw text.
            split: dataset split ('train', 'val', ...).
        """
        self.config = config
        self.medium = medium
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        self.split = split
        self.batch_size = config['batch_size_{}'.format(medium)]

        # root directory of the extracted raw frames
        self.root_vis = config['root_raw_vis_{}_{}'.format(medium, split)]

        # mapping between sample ids and their caption / video-path metadata
        mapping_path = config.get('mapping_path_{}_{}'.format(medium, split), None)
        with open(mapping_path, 'rb') as f:
            self.mapping = pickle.load(f)

        ids = sorted(self.mapping.keys())

        # reserve the first num_val_samples ids for validation
        if split == 'train':
            self.ids = ids[config.num_val_samples:]
        elif split == 'val':
            self.ids = ids[:config.num_val_samples]
        else:
            # fix: previously self.ids was never set for other splits,
            # which crashed __len__; any other split uses the full list
            self.ids = ids

        # optionally cap the dataset size (num_samples <= 0 means "all")
        num_samples = config['num_samples_{}'.format(self.medium)]
        if num_samples > 0:
            self.ids = self.ids[:num_samples]

    def __len__(self):
        return len(self.ids)

    def padding(self, seq, pad_token, max_len=None):
        """Right-pad a list of tensors to a common first-dim length.

        1-D inputs produce a long tensor filled with ``pad_token``;
        higher-rank inputs produce a float tensor (pad value 1.0).

        Returns:
            (padded batch tensor, list of original lengths)
        """
        if max_len is None:
            max_len = max(i.size(0) for i in seq)
        if len(seq[0].size()) == 1:
            result = torch.ones((len(seq), max_len)).long() * pad_token
        else:
            result = torch.ones((len(seq), max_len, seq[0].size(-1))).float()
        for i in range(len(seq)):
            result[i, :seq[i].size(0)] = seq[i]
        orig_len = [s.size(0) for s in seq]
        return result, orig_len

    def _load_emergency_item(self):
        """Load the hard-coded fallback sample and its sorted frame names.

        Returns:
            (item dict, frame directory path, sorted frame file names)
        """
        with open(self._EMERGENCY_ITEM_PATH, 'rb') as f:
            item = pickle.load(f)
        pth = os.path.join(self.root_vis, item['path'])
        f_names = sorted(os.listdir(pth))
        return item, pth, f_names

    def __getitem__(self, index):
        item = self.mapping[self.ids[index]]
        # locate the extracted frames of this video
        pth = os.path.join(self.root_vis, item['path'])
        f_names = os.listdir(pth)
        if len(f_names) == 0:
            # empty frame directory -> fall back to the emergency sample
            item, pth, f_names = self._load_emergency_item()
        else:
            f_names.sort()

        # pad (by repeating the last frame) or truncate to num_frames
        num_frames = self.config['num_frames']
        if len(f_names) < num_frames:
            f_names += [f_names[-1]] * (num_frames - len(f_names))
        elif len(f_names) > num_frames:
            f_names = f_names[:num_frames]

        pth = [os.path.join(pth, f_name) for f_name in f_names]
        try:
            vis = [Image.open(p).convert('RGB') for p in pth]
        except Exception:
            # fix: was a bare `except:`, which also swallowed
            # KeyboardInterrupt / SystemExit.
            # A corrupt frame -> fall back to the emergency sample
            # (note: this path does not re-apply the num_frames clamp,
            # matching the original behavior).
            item, pth, f_names = self._load_emergency_item()
            pth = [os.path.join(pth, f_name) for f_name in f_names]
            vis = [Image.open(p).convert('RGB') for p in pth]

        vis = [self.vis_processor(v).unsqueeze(0) for v in vis]
        vis = torch.cat(vis, dim=0)

        dialog = item['dialog']

        caption = dialog['caption']
        history = dialog['history']
        answer = dialog['answer']

        caption = self.text_processor(caption)
        history = [self.text_processor(h) for h in history]
        answer = self.text_processor(answer, remove_period=True)

        # pick the special tokens matching the language backbone
        if self.config.embed_from_llm:
            if self.config.llm_family in ['llama', 'mistral']:
                cls_tok = ''
                sep_tok = ''
                bos_tok = '<s>'
                eos_tok = '</s>'
            else:
                cls_tok = '<s>'
                sep_tok = '</s>'
                bos_tok = '<pad>'
                eos_tok = '</s>'
        else:
            cls_tok = '[CLS]'
            sep_tok = '[SEP]'
            bos_tok = '[SEP]'
            eos_tok = '[SEP]'
        # NOTE(review): bos_tok / eos_tok are currently unused here; the
        # answer is returned without special tokens and wrapped downstream.

        # wrap the text fields with the special tokens
        caption = cls_tok + caption + sep_tok
        history = sep_tok.join(history)
        history = history + sep_tok

        return vis, caption, history, answer
|
||||
|
||||
|
||||
def load_champagne_dataset(config, vis_processor, text_processor, split):
    """Factory: build the Champagne dataset for the given *split*."""
    return Champagne(config, 'champagne', vis_processor, text_processor, split)
|
Loading…
Add table
Add a link
Reference in a new issue