V2Dial/datasets/champagne_dataset.py
# coding: utf-8
# author: noctli
import json
import os
import pickle
import logging
from tqdm import tqdm
import numpy as np
import torch
import torch.utils.data
from PIL import Image
from torch.utils.data import Dataset
from itertools import chain
from torchvision import transforms
from .utils import type_transform_helper


def tokenize(text, tokenizer, return_tensor=False):
    tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
    if return_tensor:
        return torch.tensor(tokenized_text).long()
    return tokenized_text
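
# Hypothetical usage with a Hugging Face-style tokenizer (no tokenizer is
# constructed in this file, so the name below is an assumption):
#
#   ids = tokenize('a cat on a sofa', tokenizer)                             # list of token ids
#   ids_tensor = tokenize('a cat on a sofa', tokenizer, return_tensor=True)  # 1-D LongTensor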


def get_dataset(config, split):
    dialog_pth = config['anno_visdial_{}'.format(split)]
    dialog_data = json.load(open(dialog_pth, 'r'))['data']
    all_answers = dialog_data['answers']
    all_questions = dialog_data['questions']
    dialog_list = []
    n_history = config['num_hist_turns']
    vid_set = set()

    pbar = tqdm(dialog_data['dialogs'])
    pbar.set_description('[INFO] Loading VisDial - {}'.format(split))
    for dialog in pbar:
        caption = dialog['caption']
        questions = [all_questions[d['question']] for d in dialog['dialog']]
        answers = [all_answers[d['answer']] for d in dialog['dialog']]
        vid = dialog["image_id"]
        vid_set.add(vid)

        # if undisclosed_only:
        #     it = range(len(questions) - 1, len(questions))
        # else:
        it = range(len(questions))
        qalist = []
        history = []
        # if undisclosed_only:
        #     for n in range(len(questions) - 1):
        #         qalist.append(questions[n])
        #         qalist.append(answers[n])
        #     history = qalist[max(-len(qalist), -n_history * 2):]
        for n in it:
            # if undisclosed_only:
            #     assert dialog['dialog'][n]['answer'] == '__UNDISCLOSED__'
            question = questions[n]
            answer = answers[n]
            history.append(question)
            if n_history == 0:
                item = {'vid': vid, 'history': [question], 'answer': answer, 'caption': caption}
            else:
                item = {'vid': vid, 'history': history, 'answer': answer, 'caption': caption}
            dialog_list.append(item)
            qalist.append(question)
            qalist.append(answer)
            history = qalist[max(-len(qalist), -n_history * 2):]
    return dialog_list
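
# A sketch of how get_dataset might be called and the shape of one returned
# entry; the placeholder values are illustrative, not taken from the
# annotation files:
#
#   dialogs = get_dataset(config, 'val')
#   dialogs[0]
#   # {'vid': <image_id>,
#   #  'history': ['<question 1>', ...],   # at most num_hist_turns QA pairs plus the current question
#   #  'answer': '<ground-truth answer>',
#   #  'caption': '<image caption>'}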


class Champagne(Dataset):
    def __init__(self, config, medium, vis_processor, text_processor, split):
        self.config = config
        self.medium = medium
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        self.split = split
        self.batch_size = config['batch_size_{}'.format(medium)]
        self.root_vis = config['root_raw_vis_{}_{}'.format(medium, split)]

        # get the mapping between caption and image/video
        mapping_path = config.get('mapping_path_{}_{}'.format(medium, split), None)
        with open(mapping_path, 'rb') as f:
            self.mapping = pickle.load(f)

        ids = list(self.mapping.keys())
        ids.sort()

        # reserve some samples for validation
        if split == 'train':
            self.ids = ids[config.num_val_samples:]
        elif split == 'val':
            self.ids = ids[:config.num_val_samples]

        num_samples = config['num_samples_{}'.format(self.medium)]
        if num_samples > 0:
            self.ids = self.ids[:num_samples]

    def __len__(self):
        return len(self.ids)

    def padding(self, seq, pad_token, max_len=None):
        if max_len is None:
            max_len = max([i.size(0) for i in seq])
        if len(seq[0].size()) == 1:
            result = torch.ones((len(seq), max_len)).long() * pad_token
        else:
            result = torch.ones((len(seq), max_len, seq[0].size(-1))).float()
        for i in range(len(seq)):
            result[i, :seq[i].size(0)] = seq[i]
        orig_len = [s.size(0) for s in seq]
        return result, orig_len
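
    # A small illustration of what padding() returns (hypothetical tensors, not
    # data from this dataset): two 1-D sequences of lengths 3 and 5, padded with
    # pad_token=0, give a (2, 5) LongTensor plus the original lengths.
    #
    #   seqs = [torch.tensor([4, 8, 15]), torch.tensor([16, 23, 42, 8, 4])]
    #   padded, orig_len = self.padding(seqs, pad_token=0)
    #   # padded.shape == (2, 5); padded[0, 3:] is all zeros; orig_len == [3, 5]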

    def __getitem__(self, index):
        item = self.mapping[self.ids[index]]

        # load the video frames
        pth = os.path.join(self.root_vis, item['path'])
        f_names = os.listdir(pth)
        if len(f_names) == 0:
            # the frame directory is empty -> fall back to a known-good emergency sample
            with open('/pfss/mlde/workspaces/mlde_wsp_Rohrbach/users/ma35vahy/V2Dial_new/emergency/item.pkl', 'rb') as f:
                item = pickle.load(f)
            pth = os.path.join(self.root_vis, item['path'])
            f_names = os.listdir(pth)

        f_names.sort()
        # pad (by repeating the last frame) or truncate to exactly num_frames frames
        if len(f_names) < self.config['num_frames']:
            f_names += [f_names[-1]] * (self.config['num_frames'] - len(f_names))
        elif len(f_names) > self.config['num_frames']:
            f_names = f_names[:self.config['num_frames']]
        pth = [os.path.join(pth, f_name) for f_name in f_names]

        try:
            vis = [Image.open(p).convert('RGB') for p in pth]
        except Exception:
            # a frame could not be decoded -> fall back to the emergency sample
            with open('/pfss/mlde/workspaces/mlde_wsp_Rohrbach/users/ma35vahy/V2Dial_new/emergency/item.pkl', 'rb') as f:
                item = pickle.load(f)
            pth = os.path.join(self.root_vis, item['path'])
            f_names = os.listdir(pth)
            f_names.sort()
            pth = [os.path.join(pth, f_name) for f_name in f_names]
            vis = [Image.open(p).convert('RGB') for p in pth]

        vis = [self.vis_processor(v).unsqueeze(0) for v in vis]
        vis = torch.cat(vis, dim=0)

        dialog = item['dialog']
        caption = dialog['caption']
        history = dialog['history']
        answer = dialog['answer']

        caption = self.text_processor(caption)
        history = [self.text_processor(h) for h in history]
        answer = self.text_processor(answer, remove_period=True)

        # special tokens depend on which text backbone the embeddings come from
        if self.config.embed_from_llm:
            if self.config.llm_family in ['llama', 'mistral']:
                cls_tok = ''
                sep_tok = ''
                bos_tok = '<s>'
                eos_tok = '</s>'
            else:
                cls_tok = '<s>'
                sep_tok = '</s>'
                bos_tok = '<pad>'
                eos_tok = '</s>'
        else:
            cls_tok = '[CLS]'
            sep_tok = '[SEP]'
            bos_tok = '[SEP]'
            eos_tok = '[SEP]'

        # preprocess the textual data
        caption = cls_tok + caption + sep_tok
        history = sep_tok.join(history)
        history = history + sep_tok

        # if self.config.llm_family == 'flan_t5':
        #     answer = '<s> ' + self.text_processor(answer) + ' </s>'
        # else:
        #     answer = self.text_processor(answer) + eos_tok

        return vis, caption, history, answer
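
    # What one sample looks like, assuming vis_processor maps a PIL image to a
    # (C, H, W) float tensor (an assumption; the processor is injected from
    # outside this file):
    #
    #   vis, caption, history, answer = dataset[idx]
    #   # vis:     FloatTensor of shape (num_frames, C, H, W)
    #   # caption: cls_tok + caption + sep_tok
    #   # history: previous turns joined with sep_tok and terminated with sep_tok
    #   # answer:  plain answer string (no special tokens added here)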

    # def collate_fn(self, batch):
    #     BOS, EOS, SEP = self.tokenizer_enc_dec.convert_tokens_to_ids(['<s>', '</s>', '</s>'])
    #     vis_list, cap_list, hist_list, ques_list, ans_list, index_list, vid_id_list = [], [], [], [], [], [], []
    #     batch_size = len(batch)
    #     for b in batch:
    #         vis_list.append(b[0])
    #         cap = [BOS] + tokenize(b[1], self.tokenizer_enc_dec) + [EOS]
    #         cap_list.append(torch.tensor(cap))
    #         if len(b[2]) != 0:
    #             hist = [[SEP] + tokenize(s, self.tokenizer_enc_dec) for s in b[2]] + [[EOS]]
    #             hist_list.append(torch.tensor(list(chain(*hist))))
    #         else:
    #             hist = [SEP] + tokenize(b[3], self.tokenizer_enc_dec) + [EOS]
    #             hist_list.append(torch.tensor(hist))
    #         ques = tokenize(b[3], self.tokenizer_enc_dec) + [EOS]
    #         ques_list.append(torch.tensor(ques))
    #         ans = tokenize(b[4], self.tokenizer_enc_dec) + [EOS]
    #         ans_list.append(torch.tensor(ans))
    #         index_list.append(b[5])
    #         vid_id_list.append(b[6])
    #
    #     # pad and keep track of the original lengths
    #     cap_input_ids, cap_orig_lens = self.padding(cap_list, self.tokenizer_experts.pad_token_id)
    #     hist_input_ids, hist_orig_lens = self.padding(hist_list, self.tokenizer_experts.pad_token_id)
    #     ques_input_ids, ques_orig_lens = self.padding(ques_list, self.tokenizer_experts.pad_token_id)
    #     ans_input_ids, _ = self.padding(ans_list, -100)
    #
    #     cap_attention_mask = cap_input_ids != self.tokenizer_experts.pad_token_id
    #     hist_attention_mask = hist_input_ids != self.tokenizer_experts.pad_token_id
    #     ques_attention_mask = ques_input_ids != self.tokenizer_experts.pad_token_id
    #
    #     total_orig_lens = [sum(l) for l in zip(cap_orig_lens, hist_orig_lens, ques_orig_lens)]
    #     max_len = max(total_orig_lens)
    #     dummy_input_ids_enc_dec = torch.full((batch_size, max_len), self.tokenizer_experts.pad_token_id)
    #     enc_dec_attention_mask = torch.zeros_like(dummy_input_ids_enc_dec, dtype=torch.bool)
    #     for i, l in enumerate(total_orig_lens):
    #         enc_dec_attention_mask[i][:l] = True
    #
    #     # add the masking of the visual input
    #     num_query_tok = self.config['num_temporal_query_tokens_{}'.format(self.config['bert_size'])]
    #     if self.medium in ['avsd', 'msrvtt', 'webvid', 'champagne']:
    #         vis_attention_mask = torch.ones((batch_size, 2 * num_query_tok), dtype=torch.bool)  # *2 for spatial and temporal queries
    #     else:
    #         vis_attention_mask = torch.ones((batch_size, num_query_tok), dtype=torch.bool)  # only spatial queries
    #     enc_dec_attention_mask = torch.concat((vis_attention_mask, enc_dec_attention_mask), dim=1)
    #
    #     # Now prepare the data
    #     vis = torch.stack(vis_list, dim=0)
    #     cap = {
    #         'input_ids': cap_input_ids,
    #         'attention_mask': cap_attention_mask,
    #         'orig_lens': cap_orig_lens
    #     }
    #     hist = {
    #         'input_ids': hist_input_ids,
    #         'attention_mask': hist_attention_mask,
    #         'orig_lens': hist_orig_lens
    #     }
    #     ques = {
    #         'input_ids': ques_input_ids,
    #         'attention_mask': ques_attention_mask,
    #         'orig_lens': ques_orig_lens
    #     }
    #     ans = {
    #         'input_ids': ans_input_ids,
    #     }
    #     enc_dec_input = {
    #         'input_ids': dummy_input_ids_enc_dec,
    #         'attention_mask': enc_dec_attention_mask,
    #     }
    #     index = torch.tensor(index_list)
    #     return vis, cap, hist, ques, ans, enc_dec_input, index, vid_id_list


def load_champagne_dataset(config, vis_processor, text_processor, split):
    dataset = Champagne(config, 'champagne', vis_processor, text_processor, split)
    return dataset
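
# A minimal usage sketch, assuming `config` is the project's attribute/key
# hybrid config object and that `vis_processor` / `text_processor` come from
# the project's processor registry (none of these are defined in this file):
#
#   from torch.utils.data import DataLoader
#
#   train_set = load_champagne_dataset(config, vis_processor, text_processor, 'train')
#   loader = DataLoader(train_set, batch_size=train_set.batch_size,
#                       shuffle=True, num_workers=4)
#   vis, caption, history, answer = next(iter(loader))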