initial commit

Andreas Bulling 2025-06-24 08:38:09 +02:00
commit a82bbc593e
129 changed files with 33981 additions and 0 deletions

datasets/__init__.py (new file, 0 lines)

datasets/avsd_dataset.py (new file, 205 lines)

@@ -0,0 +1,205 @@
# coding: utf-8
# author: noctli
import json
import os
import pickle
import logging
import numpy as np
from tqdm import tqdm
import torch
import torch.utils.data
from PIL import Image
from torch.utils.data import Dataset
from itertools import chain
from torchvision import transforms
from .utils import type_transform_helper
from .video_utils import read_frames_decord
def tokenize(text, tokenizer, return_tensor=False):
tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
if return_tensor:
return torch.tensor(tokenized_text).long()
return tokenized_text
def get_dataset(config, split):
if split != 'test':
dialog_pth = config[f'anno_avsd_{split}']
else:
dialog_pth = config['anno_avsd_test_dstc_{}'.format(config['dstc'])]
n_history = config['num_hist_turns_avsd']
undisclosed_only = split == 'test'
dialog_data = json.load(open(dialog_pth, 'r'))
dialog_list = []
vid_set = set()
pbar = tqdm(dialog_data['dialogs'])
pbar.set_description('[INFO] Loading AVSD - {}'.format(split))
for dialog in pbar:
# if config['dstc'] != 10:
caption = dialog['caption']
summary = dialog['summary']
# else:
# caption = 'no'
# summary = 'no'
questions = [d['question'] for d in dialog['dialog']]
answers = [d['answer'] for d in dialog['dialog']]
vid = dialog["image_id"]
vid_set.add(vid)
if undisclosed_only:
it = range(len(questions) - 1, len(questions))
else:
it = range(len(questions))
qalist=[]
history = []
if undisclosed_only:
for n in range(len(questions)-1):
qalist.append(questions[n])
qalist.append(answers[n])
history=qalist[max(-len(qalist),-n_history*2):]
for n in it:
if undisclosed_only:
assert dialog['dialog'][n]['answer'] == '__UNDISCLOSED__'
question = questions[n]
answer = answers[n]
history.append(question)
if n_history == 0:
item = {'vid': vid, 'history': [question], 'answer': answer, 'caption': caption, 'summary': summary}
else:
item = {'vid': vid, 'history': history, 'answer': answer, 'caption': caption, 'summary': summary}
dialog_list.append(item)
qalist.append(question)
qalist.append(answer)
history=qalist[max(-len(qalist),-n_history*2):]
return dialog_list
def build_input_from_segments(caption, history, reply, tokenizer, drop_caption=False):
""" Build a sequence of input from 3 segments: caption(caption+summary) history and last reply """
bos, eos = tokenizer.convert_tokens_to_ids(['<s>', '</s>'])
sep = eos
instance = {}
instance["lm_labels"] = reply + [eos]
caption = list(chain(*caption))
if not drop_caption:
# sequence = [[bos] + list(chain(*caption))] + history + [reply + ([eos] if with_eos else [])]
# NOTE: it is important not to include the reply in the encoder input --> the decoder would just
# learn to copy it --> low train/val loss, but no real learning (see the worked layout after this function)
sequence = [[bos] + caption + [eos]] + [[sep] + s for s in history] + [[eos]]
else:
sequence = [[bos]] + [[sep] + s for s in history] + [[eos]]
instance["input_ids"] = list(chain(*sequence))
return instance
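# Worked layout with hypothetical token ids (assuming a tokenizer where <s> = 0 and </s> = 2):
#   caption   = [[8774, 19], [3155]]        -> chained to [8774, 19, 3155]
#   history   = [[571, 58], [9, 32]]
#   sequence  = [[0, 8774, 19, 3155, 2], [2, 571, 58], [2, 9, 32], [2]]
#   input_ids = [0, 8774, 19, 3155, 2, 2, 571, 58, 2, 9, 32, 2]
#   lm_labels = reply + [2]  (the reply is only ever seen by the decoder as a target)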
class AVSDDataSet(Dataset):
def __init__(self, config, medium, vis_processor, text_processor, split
# tokenizer, features=None, drop_rate=0.0, train=True
):
self.config = config
self.medium = medium
self.vis_processor = vis_processor
self.text_processor = text_processor
self.split = split
self.batch_size = config['batch_size_test_{}'.format(medium)] if split == 'test' else config['batch_size_{}'.format(medium)]
self.root_vis = config['root_raw_vis_{}_{}'.format(medium, split)]
self.dialogs = get_dataset(config, split)
if split == 'test':
self.dialogs = self.dialogs[config['start_idx_gen']: config['end_idx_gen']]
num_samples = config['num_samples_{}'.format(self.medium)]
if num_samples > 0:
self.dialogs = self.dialogs[:num_samples]
def __len__(self):
return len(self.dialogs)
def load_vid(self, vid_id):
vid_dir_path = os.path.join(self.root_vis, vid_id + '.mp4')
frames, _, _ = read_frames_decord(vid_dir_path, self.config.num_frames)
frames = [self.vis_processor(f).unsqueeze(0) for f in frames]
vis = torch.cat(frames, dim=0)
return vis
def load_vid_old(self, vid_id):
# if vid_id == 'QQM8M':
# print('bla')
vid_dir_path = os.path.join(self.root_vis, vid_id)
frame_paths = [os.path.join(vid_dir_path, f) for f in os.listdir(vid_dir_path)]
frame_paths.sort()
num_avail_frames = len(frame_paths)
# guard against short clips: a step of 0 would make range() fail
delta = max(1, num_avail_frames // max(1, self.config['num_frames'] - 1))
ran = list(range(0, num_avail_frames, delta))
if len(ran) < self.config['num_frames']:
ran.extend([num_avail_frames - 1 for _ in range(self.config['num_frames'] - len(ran))])
if len(ran) > self.config['num_frames']:
ran = ran[:self.config['num_frames']]
assert len(ran) == self.config['num_frames'], f"vid {vid_id} - loaded {len(ran)}/{len(frame_paths)} frames"
frame_paths = [frame_paths[i] for i in ran]
vis = [Image.open(p).convert('RGB') for p in frame_paths]
vis = [transforms.PILToTensor()(v).unsqueeze(0) for v in vis]
vis = torch.cat(vis, dim=0)
# self.trans is not defined on this class; type_transform_helper (imported above) performs the intended uint8 -> [0, 1] float conversion
vis = type_transform_helper(vis)
return vis
def __getitem__(self, index):
dialog = self.dialogs[index]
vid_id = dialog['vid']
caption = dialog['caption']
summary = dialog['summary']
history = dialog['history']
answer = dialog['answer']
caption = self.text_processor(caption)
summary = self.text_processor(summary)
if self.config.dstc != 10:
caption = caption + ' ' + summary
history = [self.text_processor(h) for h in history]
answer = self.text_processor(answer, remove_period=True)
if self.config.embed_from_llm:
if self.config.llm_family in ['llama', 'mistral']:
cls_tok = ''
sep_tok = ' '
bos_tok = '<s>'
eos_tok = '</s>'
else:
cls_tok = '<s>'
sep_tok = '</s>'
bos_tok = '<pad>'
eos_tok = '</s>'
else:
cls_tok = '[CLS]'
sep_tok = '[SEP]'
bos_tok = '[SEP]'
eos_tok = '[SEP]'
caption = cls_tok + caption + sep_tok
history = sep_tok.join(history)
history = history + sep_tok
# load the video frames
vis = self.load_vid(vid_id)
return vis, caption, history, answer, vid_id
def load_avsd_dataset(config, vis_processor, text_processor, split):
# data_file = config['anno_avsd_{}'.format(split)]
# dataset_list = get_dataset(config, split, tokenizer_enc_dec)
dataset = AVSDDataSet(config, 'avsd', vis_processor, text_processor, split)
return dataset
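
For orientation, a minimal usage sketch of the loader defined above. It is an assumption-laden sketch: the keys mirror the ones read in get_dataset and AVSDDataSet.__init__, all paths are placeholders, the mixed config['key'] / config.key access above is taken to imply an OmegaConf-style config, and the stand-in processors only mimic the interface of the project's real ones.

from omegaconf import OmegaConf
from torchvision import transforms
from datasets.avsd_dataset import load_avsd_dataset

config = OmegaConf.create({
    'anno_avsd_train': 'data/avsd/train.json', 'num_hist_turns_avsd': 3, 'dstc': 7,
    'batch_size_avsd': 4, 'root_raw_vis_avsd_train': 'data/charades/videos',
    'num_samples_avsd': 0, 'num_frames': 4, 'embed_from_llm': True, 'llm_family': 'flan_t5',
})
vis_processor = transforms.ToTensor()                                                                   # stand-in
text_processor = lambda s, remove_period=False: s.strip().rstrip('.') if remove_period else s.strip()   # stand-in

dataset = load_avsd_dataset(config, vis_processor, text_processor, split='train')
vis, caption, history, answer, vid_id = dataset[0]
# vis: (num_frames, 3, H, W) float tensor; caption/history: strings wrapped in the chosen cls/sep tokens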


@@ -0,0 +1,279 @@
# coding: utf-8
# author: noctli
import json
import os
import pickle
import logging
from tqdm import tqdm
import numpy as np
import torch
import torch.utils.data
from PIL import Image
from torch.utils.data import Dataset
from itertools import chain
from torchvision import transforms
from .utils import type_transform_helper
def tokenize(text, tokenizer, return_tensor=False):
tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
if return_tensor:
return torch.tensor(tokenized_text).long()
return tokenized_text
def get_dataset(config, split):  # NOTE: loads VisDial-style annotations; not referenced by the Champagne class below
dialog_pth = config['anno_visdial_{}'.format(split)]
dialog_data = json.load(open(dialog_pth, 'r'))['data']
all_answers = dialog_data['answers']
all_questions = dialog_data['questions']
dialog_list = []
n_history = config['num_hist_turns']
vid_set = set()
pbar = tqdm(dialog_data['dialogs'])
pbar.set_description('[INFO] Loading VisDial - {}'.format(split))
for dialog in pbar:
caption = dialog['caption']
questions = [all_questions[d['question']] for d in dialog['dialog']]
answers = [all_answers[d['answer']] for d in dialog['dialog']]
vid = dialog["image_id"]
vid_set.add(vid)
# if undisclosed_only:
# it = range(len(questions) - 1, len(questions))
# else:
it = range(len(questions))
qalist=[]
history = []
# if undisclosed_only:
# for n in range(len(questions)-1):
# qalist.append(questions[n])
# qalist.append(answers[n])
# history=qalist[max(-len(qalist),-n_history*2):]
for n in it:
# if undisclosed_only:
# assert dialog['dialog'][n]['answer'] == '__UNDISCLOSED__'
question = questions[n]
answer = answers[n]
history.append(question)
if n_history == 0:
item = {'vid': vid, 'history': [question], 'answer': answer, 'caption': caption}
else:
item = {'vid': vid, 'history': history, 'answer': answer, 'caption': caption}
dialog_list.append(item)
qalist.append(question)
qalist.append(answer)
history=qalist[max(-len(qalist),-n_history*2):]
return dialog_list
class Champagne(Dataset):
def __init__(self, config, medium, vis_processor, text_processor, split):
self.config = config
self.medium = medium
self.vis_processor = vis_processor
self.text_processor = text_processor
self.split = split
self.batch_size = config['batch_size_{}'.format(medium)]
self.root_vis = config['root_raw_vis_{}_{}'.format(medium, split)]
# get the mapping between caption and image/video
mapping_path = config.get('mapping_path_{}_{}'.format(medium, split), None)
with open(mapping_path, 'rb') as f:
self.mapping = pickle.load(f)
ids = list(self.mapping.keys())
ids.sort()
# reserve some samples for validation
if split == 'train':
self.ids = ids[config.num_val_samples:]
elif split == 'val':
self.ids = ids[:config.num_val_samples]
num_samples = config['num_samples_{}'.format(self.medium)]
if num_samples > 0:
self.ids = self.ids[:num_samples]
def __len__(self):
return len(self.ids)
def padding(self, seq, pad_token, max_len=None):
if max_len is None:
max_len = max([i.size(0) for i in seq])
if len(seq[0].size()) == 1:
result = torch.ones((len(seq), max_len)).long() * pad_token
else:
result = torch.ones((len(seq), max_len, seq[0].size(-1))).float()
for i in range(len(seq)):
result[i, :seq[i].size(0)] = seq[i]
orig_len = [s.size(0) for s in seq]
return result, orig_len
def __getitem__(self, index):
item = self.mapping[self.ids[index]]
# load the videos
pth = os.path.join(self.root_vis, item['path'])
f_names = os.listdir(pth)
if len(f_names) == 0:
with open('/pfss/mlde/workspaces/mlde_wsp_Rohrbach/users/ma35vahy/V2Dial_new/emergency/item.pkl', 'rb') as f:
item = pickle.load(f)
# load the videos
pth = os.path.join(self.root_vis, item['path'])
f_names = os.listdir(pth)
f_names.sort()
if len(f_names) < self.config['num_frames']:
f_names += [f_names[-1]] * (self.config['num_frames'] - len(f_names))
elif len(f_names) > self.config['num_frames']:
f_names = f_names[:self.config['num_frames']]
pth = [os.path.join(pth, f_name) for f_name in f_names]
try:
vis = [Image.open(p).convert('RGB') for p in pth]
except Exception:  # unreadable frames: fall back to the stored emergency sample
with open('/pfss/mlde/workspaces/mlde_wsp_Rohrbach/users/ma35vahy/V2Dial_new/emergency/item.pkl', 'rb') as f:
item = pickle.load(f)
# load the videos
pth = os.path.join(self.root_vis, item['path'])
f_names = os.listdir(pth)
f_names.sort()
pth = [os.path.join(pth, f_name) for f_name in f_names]
vis = [Image.open(p).convert('RGB') for p in pth]
vis = [self.vis_processor(v).unsqueeze(0) for v in vis]
vis = torch.cat(vis, dim=0)
dialog = item['dialog']
caption = dialog['caption']
history = dialog['history']
answer = dialog['answer']
caption = self.text_processor(caption)
history = [self.text_processor(h) for h in history]
answer = self.text_processor(answer, remove_period=True)
if self.config.embed_from_llm:
if self.config.llm_family in ['llama', 'mistral']:
cls_tok = ''
sep_tok = ''
bos_tok = '<s>'
eos_tok = '</s>'
else:
cls_tok = '<s>'
sep_tok = '</s>'
bos_tok = '<pad>'
eos_tok = '</s>'
else:
cls_tok = '[CLS]'
sep_tok = '[SEP]'
bos_tok = '[SEP]'
eos_tok = '[SEP]'
# preprocess the textual data
caption = cls_tok + caption + sep_tok
history = sep_tok.join(history)
history = history + sep_tok
# if self.config.llm_family == 'flan_t5':
# answer = '<s> ' + self.text_processor(answer) + ' </s>'
# else:
# answer = self.text_processor(answer) + eos_tok
return vis, caption, history, answer
# def collate_fn(self, batch):
# BOS, EOS, SEP = self.tokenizer_enc_dec.convert_tokens_to_ids(['<s>', '</s>', '</s>'])
# vis_list, cap_list, hist_list, ques_list, ans_list, index_list, vid_id_list = [], [], [], [], [], [], []
# batch_size = len(batch)
# for b in batch:
# vis_list.append(b[0])
# cap = [BOS] + tokenize(b[1], self.tokenizer_enc_dec) + [EOS]
# cap_list.append(torch.tensor(cap))
# if len(b[2])!=0:
# hist = [[SEP] + tokenize(s, self.tokenizer_enc_dec) for s in b[2]] + [[EOS]]
# hist_list.append(torch.tensor(list(chain(*hist))))
# else:
# hist = [SEP] + tokenize(b[3], self.tokenizer_enc_dec) + [EOS]
# hist_list.append(torch.tensor(hist))
# ques = tokenize(b[3], self.tokenizer_enc_dec) + [EOS]
# ques_list.append(torch.tensor(ques))
# ans = tokenize(b[4], self.tokenizer_enc_dec) + [EOS]
# ans_list.append(torch.tensor(ans))
# index_list.append(b[5])
# vid_id_list.append(b[6])
# # pad and keep track of the original lengths
# cap_input_ids, cap_orig_lens = self.padding(cap_list, self.tokenizer_experts.pad_token_id)
# hist_input_ids, hist_orig_lens = self.padding(hist_list, self.tokenizer_experts.pad_token_id)
# ques_input_ids, ques_orig_lens = self.padding(ques_list, self.tokenizer_experts.pad_token_id)
# ans_input_ids, _ = self.padding(ans_list, -100)
# cap_attention_mask = cap_input_ids != self.tokenizer_experts.pad_token_id
# hist_attention_mask = hist_input_ids != self.tokenizer_experts.pad_token_id
# ques_attention_mask = ques_input_ids != self.tokenizer_experts.pad_token_id
# total_orig_lens = [sum(l) for l in zip(cap_orig_lens, hist_orig_lens, ques_orig_lens)]
# max_len = max(total_orig_lens)
# dummy_input_ids_enc_dec = torch.full((batch_size, max_len), self.tokenizer_experts.pad_token_id)
# enc_dec_attention_mask = torch.zeros_like(dummy_input_ids_enc_dec, dtype=torch.bool)
# for i, l in enumerate(total_orig_lens):
# enc_dec_attention_mask[i][:l] = True
# # add the masking of the visual input
# num_query_tok = self.config['num_temporal_query_tokens_{}'.format(self.config['bert_size'])]
# if self.medium in ['avsd', 'msrvtt', 'webvid', 'champagne']:
# vis_attention_mask = torch.ones((batch_size, 2 * num_query_tok), dtype=torch.bool) # *2 for spatial and temporal queries
# else:
# vis_attention_mask = torch.ones((batch_size, num_query_tok), dtype=torch.bool) # only spatial queries
# enc_dec_attention_mask = torch.concat((vis_attention_mask, enc_dec_attention_mask), dim=1)
# # Now prepare the data
# vis = torch.stack(vis_list, dim=0)
# cap = {
# 'input_ids': cap_input_ids,
# 'attention_mask': cap_attention_mask,
# 'orig_lens': cap_orig_lens
# }
# hist = {
# 'input_ids': hist_input_ids,
# 'attention_mask': hist_attention_mask,
# 'orig_lens': hist_orig_lens
# }
# ques = {
# 'input_ids': ques_input_ids,
# 'attention_mask': ques_attention_mask,
# 'orig_lens': ques_orig_lens
# }
# ans = {
# 'input_ids': ans_input_ids,
# }
# enc_dec_input = {
# 'input_ids': dummy_input_ids_enc_dec,
# 'attention_mask': enc_dec_attention_mask,
# }
# index = torch.tensor(index_list)
# return vis, cap, hist, ques, ans, enc_dec_input, index, vid_id_list
def load_champagne_dataset(config, vis_processor, text_processor, split):
dataset = Champagne(config, 'champagne', vis_processor, text_processor, split)
return dataset
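
A quick standalone check of the padding helper above; the rest of the class needs the CHAMPAGNE frame folders and the mapping pickle, so only the helper is exercised here, and the module name is hypothetical since this file's name is not shown in the diff.

import torch
from datasets.champagne import Champagne   # hypothetical module name for this file

seqs = [torch.tensor([1, 2, 3]), torch.tensor([4])]
padded, orig_len = Champagne.padding(None, seqs, pad_token=0)   # self is unused, so the unbound call works
print(padded.tolist(), orig_len)                                # [[1, 2, 3], [4, 0, 0]] and [3, 1]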

datasets/dataloader.py (new file, 137 lines)

@@ -0,0 +1,137 @@
"""
From https://github.com/klauscc/VindLU/blob/main/dataset/dataloader.py
"""
import torch
from torch.utils.data import DataLoader, Dataset, ConcatDataset
import torch.distributed as dist
from utils.dist import *
import random
import logging
logger = logging.getLogger(__name__)
class MetaLoader(object):
""" wraps multiple data loader """
def __init__(self, name2loader):
"""Iterates over multiple dataloaders, it ensures all processes
work on data from the same dataloader. This loader will end when
the shorter dataloader raises StopIteration exception.
loaders: Dict, {name: dataloader}
"""
self.name2loader = name2loader
self.name2iter = {name: iter(l) for name, l in name2loader.items()}
name2index = {name: idx for idx, (name, l) in enumerate(name2loader.items())}
index2name = {v: k for k, v in name2index.items()}
iter_order = []
for n, l in name2loader.items():
iter_order.extend([name2index[n]]*len(l))
random.shuffle(iter_order)
iter_order = torch.Tensor(iter_order).to(torch.device("cuda")).to(torch.uint8)
# sync
if is_dist_avail_and_initialized():
# make sure all processes have the same order so that
# each step they will have data from the same loader
dist.broadcast(iter_order, src=0)
self.iter_order = [index2name[int(e.item())] for e in iter_order.cpu()]
logger.info(str(self))
def __str__(self):
output = [f"MetaLoader has {len(self.name2loader)} dataloaders, {len(self)} batches in total"]
for idx, (name, loader) in enumerate(self.name2loader.items()):
output.append(
f"dataloader index={idx} name={name}, batch-size={loader.batch_size} length(#batches)={len(loader)} "
)
return "\n".join(output)
def __len__(self):
return len(self.iter_order)
def __iter__(self):
""" this iterator will run indefinitely """
for name in self.iter_order:
_iter = self.name2iter[name]
batch = next(_iter)
yield name, batch
def load_dataloaders(config, datasets, split, output_dict=False):
if isinstance(datasets, dict):
datasets = list(datasets.values())
shuffles = [True] * len(datasets) if split == 'train' else [False] * len(datasets)
if config['distributed'] and split != 'test':
num_tasks = get_world_size()
global_rank = get_rank()
samplers = create_samplers(
datasets, shuffles, num_tasks, global_rank
)
else:
samplers = [None] * len(datasets)
batch_size = [dataset.datasets[0].batch_size if isinstance(dataset, ConcatDataset) else dataset.batch_size for dataset in datasets]
collate_fns = []
for dataset in datasets:
if isinstance(dataset, ConcatDataset):
collate_fns.append(getattr(dataset.datasets[0], 'collate_fn', None))
else:
collate_fns.append(getattr(dataset, 'collate_fn', None))
loaders = create_loader(
datasets,
samplers,
batch_size=batch_size,
num_workers=[config.num_workers] * len(datasets),
is_trains=shuffles,
collate_fns=collate_fns,
) # [0]
loaders_dict = {}
if output_dict:
for l in loaders:
if isinstance(l.dataset, ConcatDataset):
loaders_dict[l.dataset.datasets[0].medium] = l
else:
loaders_dict[l.dataset.medium] = l
return loaders_dict
return loaders
def create_samplers(datasets, shuffles, num_tasks, global_rank):
samplers = []
for dataset, shuffle in zip(datasets, shuffles):
sampler = torch.utils.data.DistributedSampler(
dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle
)
samplers.append(sampler)
return samplers
def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
loaders = []
for dataset, sampler, bs, n_worker, is_train, collate_fn in zip(
datasets, samplers, batch_size, num_workers, is_trains, collate_fns
):
if is_train:
shuffle = sampler is None
drop_last = True
else:
shuffle = False
drop_last = True  # NOTE: the last incomplete batch is dropped for eval loaders as well
loader = DataLoader(
dataset,
batch_size=bs,
num_workers=n_worker,
pin_memory=False,
sampler=sampler,
shuffle=shuffle,
collate_fn=collate_fn,
drop_last=drop_last,
persistent_workers=True if n_worker > 0 else False,
)
loaders.append(loader)
return loaders
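
A minimal sketch of how MetaLoader interleaves two loaders. Note that __init__ above moves the order tensor to CUDA unconditionally, so this assumes a GPU is available (and, under DDP, an initialized process group so every rank sees the same order).

import torch
from torch.utils.data import DataLoader, TensorDataset
from datasets.dataloader import MetaLoader

loader_a = DataLoader(TensorDataset(torch.arange(8)), batch_size=2)   # 4 batches
loader_b = DataLoader(TensorDataset(torch.arange(6)), batch_size=3)   # 2 batches
meta = MetaLoader({'avsd': loader_a, 'webvid': loader_b})

for name, (batch,) in meta:        # 6 steps in total, in a shuffled but fixed order
    print(name, batch.tolist())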


@@ -0,0 +1,86 @@
import os
import pandas as pd
# import h5py
import json
import numpy as np
import torch
from torch.utils.data import Dataset
from .video_utils import read_frames_decord
def load_file(file_name):
annos = None
if os.path.splitext(file_name)[-1] == '.csv':
return pd.read_csv(file_name)
with open(file_name, 'r') as fp:
if os.path.splitext(file_name)[1]== '.txt':
annos = fp.readlines()
annos = [line.rstrip() for line in annos]
if os.path.splitext(file_name)[1] == '.json':
annos = json.load(fp)
return annos
class NextQADataset(Dataset):
def __init__(self, config, medium, vis_processor, text_processor, split):
super().__init__()
self.config = config
self.medium = medium
self.vis_processor = vis_processor
self.text_processor = text_processor
self.split = split
self.batch_size = config['batch_size_test_{}'.format(medium)] if split == 'test' else config['batch_size_{}'.format(medium)]
self.root_vis = config['root_raw_vis_{}_{}'.format(medium, split)]
with open(config['vid_mapping_nextqa'], 'r') as f:
self.video_mapping = json.load(f)
self.sample_list = load_file(self.config['anno_nextqa_{}'.format(split)])
if split == 'test':
self.sample_list = self.sample_list[config['start_idx_gen']: config['end_idx_gen']]
self.captions = load_file(self.config['next_qa_captions_{}'.format(split)])
else:
self.captions = None
num_samples = config['num_samples_{}'.format(self.medium)]
if num_samples > 0:
self.sample_list = self.sample_list[:num_samples]
def __len__(self):
return len(self.sample_list)
def load_vid(self, vid_id):
vid_dir_path = os.path.join(self.root_vis, self.video_mapping[vid_id] + '.mp4')
frames, _, _ = read_frames_decord(vid_dir_path, self.config.num_frames)
frames = [self.vis_processor(f).unsqueeze(0) for f in frames]
vis = torch.cat(frames, dim=0)
return vis
def __getitem__(self, idx):
if self.split == 'test':
idx += self.config['start_idx_gen']
cur_sample = self.sample_list.loc[idx]
video_id, ques, ans, qid = str(cur_sample['video']), str(cur_sample['question']),\
str(cur_sample['answer']), str(cur_sample['qid'])
history = self.text_processor(ques)
answer = self.text_processor(ans)
if self.split == 'test':
caption = self.text_processor(self.captions[video_id])
else:
caption = self.text_processor('please answer the following question based on the video')
vis = self.load_vid(video_id)
return vis, caption, history, answer, video_id, qid
def load_nextqa_dataset(config, vis_processor, text_processor, split):
# data_file = config['anno_avsd_{}'.format(split)]
# dataset_list = get_dataset(config, split, tokenizer_enc_dec)
dataset = NextQADataset(config, 'nextqa', vis_processor, text_processor, split)
return dataset
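
load_file above dispatches on the file extension (.csv via pandas, .json, plain .txt); a small self-contained check follows, with the module name guessed since this file's name is not shown in the diff.

import json, os, tempfile
from datasets.nextqa_dataset import load_file   # hypothetical module name for this file

with tempfile.NamedTemporaryFile('w', suffix='.json', delete=False) as f:
    json.dump({'4882821564': 'a person opens a door'}, f)
print(load_file(f.name))   # {'4882821564': 'a person opens a door'}
os.remove(f.name)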

datasets/pretraining.py (new file, 156 lines)

@@ -0,0 +1,156 @@
from torch.utils.data import Dataset
import pickle
import os
import torch
from PIL import Image
import numpy as np
from torchvision import transforms
import random
from .utils import pre_text, type_transform_helper, load_anno, open_img
class CapDataset(Dataset):
def __init__(self, config, medium, vis_processor, text_processor, split):
super(CapDataset, self).__init__()
self.config = config
self.batch_size = config['batch_size_{}'.format(medium)]
self.medium = medium # "webvid / cc3m / msrvtt"
self.vis_processor = vis_processor
self.text_processor = text_processor
self.split = split # train / val / test
self.root_vis = config['root_raw_vis_{}_{}'.format(medium, split)]
# get the mapping between caption and image/video
mapping_path = config.get('mapping_path_{}_{}'.format(medium, split), None)
with open(mapping_path, 'rb') as f:
self.mapping = pickle.load(f)
# These are the main ids of the dataset (typically one per image/video)
self.ids = list(self.mapping.keys())
num_samples = config['num_samples_{}'.format(self.medium)]
if num_samples > 0:
self.ids = self.ids[:num_samples]
def __len__(self):
return len(self.ids)
def __getitem__(self, index):
item = self.mapping[self.ids[index]]
# _id = self.ids[index]
############################# Textual features #############################
caption = item['caption']
# caption_ = pre_text(caption)
caption = self.text_processor(caption)
# add [CLS] token
caption = '[CLS] ' + caption
if self.medium == 'cc3m':
pth = os.path.join(self.root_vis, item['file'])
vis = open_img(pth)
vis = self.vis_processor(vis).unsqueeze(0)
else:
pth = os.path.join(self.root_vis, item['file'])
f_names = os.listdir(pth)
f_names.sort(key=lambda f_n: int(f_n.split('.')[0]))
pth = [os.path.join(pth, f_name) for f_name in f_names]
vis = [Image.open(p).convert('RGB') for p in pth]
vis = [self.vis_processor(v).unsqueeze(0) for v in vis]
vis = torch.cat(vis, dim=0)
# Get negative vis
neg_index = random.randint(0, len(self) - 1)
while neg_index == index:
neg_index = random.randint(0, len(self) - 1)
neg_item = self.mapping[self.ids[neg_index]]
if self.medium == 'cc3m':
neg_pth = os.path.join(self.root_vis, neg_item['file'])
neg_vis = open_img(neg_pth)
neg_vis = self.vis_processor(neg_vis).unsqueeze(0)
else:
neg_pth = os.path.join(self.root_vis, neg_item['file'])
neg_f_names = os.listdir(neg_pth)
neg_f_names.sort(key=lambda f_n: int(f_n.split('.')[0]))
neg_pth = [os.path.join(neg_pth, neg_f_name) for neg_f_name in neg_f_names]
neg_vis = [Image.open(p).convert('RGB') for p in neg_pth]
neg_vis = [self.vis_processor(v).unsqueeze(0) for v in neg_vis]
neg_vis = torch.cat(neg_vis, dim=0)
# return caption, vis
return vis, caption, neg_vis
class VideoTextRetDataset(Dataset):
def __init__(self, config, vis_processor, text_processor, medium, split):
super(VideoTextRetDataset, self).__init__()
self.config = config
self.batch_size = config['batch_size_{}'.format(medium)]
self.medium = medium # "webvid / cc3m / msrvtt"
self.vis_processor = vis_processor
self.text_processor = text_processor
self.split = split # train / val / test
self.root_vis = config['root_raw_vis_{}_{}'.format(medium, split)]
anno_path = config['annotation_{}_{}'.format(medium, split)]
self.raw_anno_list = load_anno(anno_path)
self.text = []
self.vis = []
self.txt2vis = {}
self.vis2txt = {}
self.build_data()
self.anno_list = [dict(vis=v) for v in self.vis]
# print('bla')
def __len__(self):
return len(self.anno_list)
def __getitem__(self, index):
pth = self.anno_list[index]['vis']
f_names = os.listdir(pth)
f_names.sort(key=lambda f_n: int(f_n.split('.')[0]))
pth = [os.path.join(pth, f_name) for f_name in f_names]
vis = [Image.open(p).convert('RGB') for p in pth]
vis = [self.vis_processor(v) for v in vis]
# vis = [transforms.PILToTensor()(v).unsqueeze(0) for v in vis]
vis = torch.cat(vis, dim=0)
# vis = self.trans(vis)
return vis, index
def build_data(self):
"""each image may have multiple ground_truth text, e.g., COCO and Flickr30K"""
txt_id = 0
for vis_id, ann in enumerate(self.raw_anno_list):
self.vis.append(ann["vis"])
self.vis2txt[vis_id] = []
_captions = ann["caption"] \
if isinstance(ann["caption"], list) else [ann["caption"], ]
for i, caption in enumerate(_captions):
# self.text.append(pre_text(caption))
self.text.append(self.text_processor(caption))
self.vis2txt[vis_id].append(txt_id)
self.txt2vis[txt_id] = vis_id
txt_id += 1
def load_datasets(config, vis_processor, text_processor, split):
if config['stage'] == 'stage_1':
if split != 'test':
cc3m_dataset = CapDataset(config, 'cc3m', vis_processor, text_processor, split)
webvid_dataset = CapDataset(config, 'webvid', vis_processor, text_processor, split)
datasets = {
'cc3m': cc3m_dataset,
'webvid': webvid_dataset
}
else: # test with MSRVTT-1k --> video-text retrieval
msrvtt_dataset = VideoTextRetDataset(config, vis_processor, text_processor, 'msrvtt', split)
datasets = {
'msrvtt': msrvtt_dataset
}
else:
raise NotImplementedError('load_datasets (pretraining) only handles stage_1, got {}'.format(config['stage']))
return datasets
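
A minimal sketch of the stage-1 entry point; it assumes a config in the style of the one sketched after avsd_dataset.py above, extended with the cc3m/webvid keys read in CapDataset.__init__ (batch sizes, mapping pickles, raw-vis roots, num_samples, stage), plus the project's own processors.

from datasets.pretraining import load_datasets

datasets = load_datasets(config, vis_processor, text_processor, split='train')
vis, caption, neg_vis = datasets['webvid'][0]
# vis / neg_vis: stacked frames of the clip and of a randomly drawn negative clip; caption is '[CLS] '-prefixed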

datasets/utils.py (new file, 83 lines)

@@ -0,0 +1,83 @@
import os
import re
import json
from tqdm import trange
from utils.dist import is_main_process
from torch.utils.data import Dataset, ConcatDataset
from PIL import Image
import numpy as np
def open_img(img_pth):
try:
img = Image.open(img_pth).convert('RGB')
return img
except Exception:
# unreadable/corrupted image: return a random RGB placeholder so training can continue
img = np.random.randint(0, high=256, size=(224, 224, 3), dtype=np.uint8)
img = Image.fromarray(img, 'RGB')
return img
def pre_text(text, max_l=None):
text = re.sub(r"(['!?\"()*#:;~])", '', text.lower())
text = text.replace('-', ' ').replace('/', ' ').replace('<person>', 'person')
text = re.sub(r"\s{2,}", ' ', text)
text = text.rstrip('\n').strip(' ')
if max_l: # truncate
words = text.split(' ')
if len(words) > max_l:
text = ' '.join(words[:max_l])
return text
def get_datasets_media(dataloaders):
media = {}
for dataloader in dataloaders:
if isinstance(dataloader.dataset, ConcatDataset):
media[dataloader.dataset.datasets[0].medium] = dataloader
else:
media[dataloader.dataset.medium] = dataloader
# media = [dataloader.dataset.medium for dataloader in dataloaders]
return media
def type_transform_helper(x):
return x.float().div(255.0)
def load_anno(ann_file_list):
"""[summary]
Args:
ann_file_list (List[List[str, str]] or List[str, str]):
the latter will be automatically converted to the former.
Each sublist contains [anno_path, image_root], (or [anno_path, video_root, 'video'])
which specifies the data type, video or image
Returns:
List(dict): each dict is {
image: str or List[str], # image_path,
caption: str or List[str] # caption text string
}
"""
if isinstance(ann_file_list[0], str):
ann_file_list = [ann_file_list]
ann = []
for d in ann_file_list:
data_root = d[1]
fp = d[0]
is_video = len(d) == 3 and d[2] == "video"
cur_ann = json.load(open(fp, "r"))
iterator = trange(len(cur_ann), desc=f"Loading {fp}") \
if is_main_process() else range(len(cur_ann))
for idx in iterator:
key = "video" if is_video else "image"
video_id = cur_ann[idx][key][5:].split('.')[0]
# unified to have the same key for data path
# if isinstance(cur_ann[idx][key], str):
cur_ann[idx]["vis"] = os.path.join(data_root, video_id)
# else: # list
# cur_ann[idx]["vis"] = [os.path.join(data_root, e) for e in cur_ann[idx][key]]
ann += cur_ann
return ann
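
Two quick illustrations of the helpers above; the annotation path in the second snippet is a placeholder.

from datasets.utils import pre_text, load_anno

print(pre_text('A <person> opens -- the "door"!!'))   # 'a person opens the door'

# load_anno takes a [anno_path, vis_root] pair (optionally with a trailing 'video' flag) or a list of such
# pairs and returns a flat list of dicts, each extended with a 'vis' path under vis_root:
anns = load_anno(['anno/msrvtt_test.json', 'data/msrvtt/frames', 'video'])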

datasets/video_utils.py (new file, 97 lines)

@@ -0,0 +1,97 @@
"""
Modified from https://github.com/m-bain/frozen-in-time/blob/22a91d78405ec6032fdf521ae1ff5573358e632f/base/base_dataset.py
"""
import random
import decord
from PIL import Image
import numpy as np
import math
decord.bridge.set_bridge("torch")
def pts_to_secs(pts: int, time_base: float, start_pts: int) -> float:
"""
Converts a presentation timestamp (pts) with the given time base and start_pts offset to seconds.
Returns:
time_in_seconds (float): The corresponding time in seconds.
https://github.com/facebookresearch/pytorchvideo/blob/main/pytorchvideo/data/utils.py#L54-L64
"""
if pts == math.inf:
return math.inf
return int(pts - start_pts) * time_base
def get_pyav_video_duration(video_reader):
video_stream = video_reader.streams.video[0]
video_duration = pts_to_secs(
video_stream.duration,
video_stream.time_base,
video_stream.start_time
)
return float(video_duration)
def get_frame_indices_by_fps():
pass
def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1):
if sample in ["rand", "middle"]:
acc_samples = min(num_frames, vlen)
# split the video into `acc_samples` intervals, and sample from each interval.
intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
ranges = []
for idx, interv in enumerate(intervals[:-1]):
ranges.append((interv, intervals[idx + 1] - 1))
if sample == 'rand':
try:
frame_indices = [random.choice(range(x[0], x[1])) for x in ranges]
except IndexError:  # an interval can be empty when vlen is small; fall back to a random subset of frames
frame_indices = np.random.permutation(vlen)[:acc_samples]
frame_indices.sort()
frame_indices = list(frame_indices)
elif fix_start is not None:
frame_indices = [x[0] + fix_start for x in ranges]
elif sample == 'middle':
frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
else:
raise NotImplementedError
if len(frame_indices) < num_frames: # padded with last frame
padded_frame_indices = [frame_indices[-1]] * num_frames
padded_frame_indices[:len(frame_indices)] = frame_indices
frame_indices = padded_frame_indices
elif "fps" in sample: # fps0.5, sequentially sample frames at 0.5 fps
output_fps = float(sample[3:])
duration = float(vlen) / input_fps
delta = 1 / output_fps # gap between frames, this is also the clip length each frame represents
frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
frame_indices = np.around(frame_seconds * input_fps).astype(int)
frame_indices = [e for e in frame_indices if e < vlen]
if max_num_frames > 0 and len(frame_indices) > max_num_frames:
frame_indices = frame_indices[:max_num_frames]
# frame_indices = np.linspace(0 + delta / 2, duration + delta / 2, endpoint=False, num=max_num_frames)
else:
raise ValueError
return frame_indices
def read_frames_decord(video_path, num_frames, sample='rand', fix_start=None, max_num_frames=-1):
video_reader = decord.VideoReader(video_path, num_threads=1)
vlen = len(video_reader)
fps = video_reader.get_avg_fps()
duration = vlen / float(fps)
frame_indices = get_frame_indices(
num_frames, vlen, sample=sample, fix_start=fix_start,
input_fps=fps, max_num_frames=max_num_frames
)
frames = video_reader.get_batch(frame_indices) # (T, H, W, C), torch.uint8
frames = frames.split(1, dim=0)  # T tensors of shape (1, H, W, C)
# Image.fromarray expects channel-last (H, W, C) arrays, so keep decord's layout rather than permuting to (C, H, W)
frames = [Image.fromarray(f.squeeze(0).numpy(), mode='RGB') for f in frames]
# frames = frames.numpy() # convert to numpy
return frames, frame_indices, duration
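
A minimal sketch, assuming 'clip.mp4' is any local video decord can open; get_frame_indices can also be checked on its own without a video.

from datasets.video_utils import read_frames_decord, get_frame_indices

frames, indices, duration = read_frames_decord('clip.mp4', 8, sample='middle')
print(len(frames), frames[0].size, round(duration, 1))   # 8 PIL frames, their (W, H) size, clip length in seconds

print(get_frame_indices(4, 100, sample='middle'))   # [12, 37, 62, 87]: the middles of 4 equal intervals over 100 frames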

datasets/visdial_dataset.py (new file, 183 lines)

@@ -0,0 +1,183 @@
# coding: utf-8
# author: noctli
import json
import os
import pickle
import logging
from tqdm import tqdm
import numpy as np
import torch
import torch.utils.data
from PIL import Image
from torch.utils.data import Dataset
from itertools import chain
from torchvision import transforms
from .utils import type_transform_helper, open_img
def tokenize(text, tokenizer, return_tensor=False):
tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
if return_tensor:
return torch.tensor(tokenized_text).long()
return tokenized_text
def get_dataset(config, split):
dialog_pth = config['anno_visdial_{}'.format(split)]
dialog_data = json.load(open(dialog_pth, 'r'))['data']
all_answers = dialog_data['answers']
all_questions = dialog_data['questions']
dialog_list = []
n_history = config['num_hist_turns_visdial']
vid_set = set()
undisclosed_only = False
pbar = tqdm(dialog_data['dialogs'])
pbar.set_description('[INFO] Loading VisDial - {}'.format(split))
for dialog in pbar:
caption = dialog['caption'] + ' .'
questions = [all_questions[d['question']] + ' ?' for d in dialog['dialog']]
answers = [all_answers[d['answer']] + ' .' for d in dialog['dialog']]
# answer_opts = [[all_answers[key] for key in d['answer_options']] for d in dialog['dialog']]
# if 'test' in config['anno_visdial_{}'.format(split)]:
# gt_indices = [-1 for _ in range(len(questions))]
# else:
# gt_indices = [d['gt_index'] for d in dialog['dialog']]
vid = dialog['image_id']
vid_set.add(vid)
if undisclosed_only:
it = range(len(questions) - 1, len(questions))
else:
it = range(len(questions))
qalist=[]
history = []
if undisclosed_only:
for n in range(len(questions)-1):
qalist.append(questions[n])
qalist.append(answers[n])
history=qalist[max(-len(qalist),-n_history*2):]
for n in it:
if undisclosed_only:
assert dialog['dialog'][n]['answer'] == '__UNDISCLOSED__'
question = questions[n]
answer = answers[n]
# answer_opt = answer_opts[n]
# gt_index = gt_indices[n]
history.append(question)
# if n_history == 0:
# item = {'vid': vid, 'history': [question], 'answer': answer, 'caption': caption, 'round': n+1, 'answer_opts': answer_opt, 'gt_index': gt_index}
# else:
# item = {'vid': vid, 'history': history, 'answer': answer, 'caption': caption, 'round': n+1, 'answer_opts': answer_opt, 'gt_index': gt_index}
if n_history == 0:
item = {'vid': vid, 'history': [question], 'answer': answer, 'caption': caption, 'round': n+1}
else:
item = {'vid': vid, 'history': history, 'answer': answer, 'caption': caption, 'round': n+1}
dialog_list.append(item)
qalist.append(question)
qalist.append(answer)
history=qalist[max(-len(qalist),-n_history*2):]
return dialog_list
class VisDial(Dataset):
def __init__(self, config, medium, vis_processor, text_processor, split
# tokenizer, features=None, drop_rate=0.0, train=True
):
self.config = config
self.medium = medium
self.split = split
self.vis_processor = vis_processor
self.text_processor = text_processor
self.batch_size = config['batch_size_test_{}'.format(medium)] if split == 'test' else config['batch_size_{}'.format(medium)]
self.root_vis = config['root_raw_vis_{}_{}'.format(medium, split)]
self.dialogs = get_dataset(config, split)
if split == 'test':
self.dialogs = self.dialogs[config['start_idx_gen']: config['end_idx_gen']]
num_samples = config['num_samples_{}'.format(self.medium)]
if num_samples > 0:
self.dialogs = self.dialogs[:num_samples]
def __len__(self):
return len(self.dialogs)
def load_img(self, vid_id):
file_pth = os.path.join(self.root_vis, f'{vid_id}.jpg')
vis = open_img(file_pth)
vis = self.vis_processor(vis).unsqueeze(0)
return vis
def __getitem__(self, index):
dialog = self.dialogs[index]
vid_id = dialog['vid']
caption = dialog['caption']
history = dialog['history']
answer = dialog['answer']
d_round = dialog['round']
caption = self.text_processor(caption)
history = [self.text_processor(h) for h in history]
answer = self.text_processor(answer, remove_period=True)
# if self.split == 'test':
# answer_opts = dialog['answer_opts']
# answer_opts = [self.text_processor(a) for a in answer_opts]
# gt_index = dialog['gt_index']
# dialog_round = dialog['round']
# dense_key = str(vid_id) + '_' + str(dialog_round)
# gt_relevance = self.dense_annos.get(dense_key, -1)
# # eval_data = (answer_opts, gt_index, gt_relevance)
if self.config.embed_from_llm:
if self.config.llm_family in ['llama', 'mistral']:
cls_tok = ''
sep_tok = ' '
bos_tok = '<s>'
eos_tok = '</s>'
else:
cls_tok = '<s>'
sep_tok = '</s>'
bos_tok = '<pad>'
eos_tok = '</s>'
else:
cls_tok = '[CLS]'
sep_tok = '[SEP]'
bos_tok = '[SEP]'
eos_tok = '[SEP]'
caption = cls_tok + caption + sep_tok
history = sep_tok.join(history)
history = history + sep_tok
# load the video frames
vis = self.load_img(vid_id)
# if self.split == 'test':
# return vis, caption, history, answer, vid_id, answer_opts, gt_relevance, gt_index
# else:
return vis, caption, history, answer, vid_id, d_round
def load_visdial_dataset(config, vis_processor, text_processor, split):
dataset = VisDial(config, 'visdial', vis_processor, text_processor, split)
return dataset
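
A minimal sketch of the VisDial loader; it assumes a config in the style of the one sketched after avsd_dataset.py above, with the VisDial keys (anno_visdial_val, num_hist_turns_visdial, batch_size_visdial, root_raw_vis_visdial_val, num_samples_visdial) filled in.

from datasets.visdial_dataset import load_visdial_dataset

dataset = load_visdial_dataset(config, vis_processor, text_processor, split='val')
vis, caption, history, answer, vid_id, d_round = dataset[0]
# vis is (1, 3, H, W) for a single image (unlike the video datasets); d_round is the 1-based dialog round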