# V2Dial/datasets/pretraining.py
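"""Pre-training datasets: captioning pairs (CC3M images, WebVid videos) for
stage-1 training, and an MSRVTT video-text retrieval set for evaluation."""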
from torch.utils.data import Dataset
import pickle
import os
import torch
from PIL import Image
import random

from .utils import load_anno, open_img

class CapDataset(Dataset):
    def __init__(self, config, medium, vis_processor, text_processor, split):
        super(CapDataset, self).__init__()
        self.config = config
        self.batch_size = config['batch_size_{}'.format(medium)]
        self.medium = medium  # webvid / cc3m / msrvtt
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        self.split = split  # train / val / test
        self.root_vis = config['root_raw_vis_{}_{}'.format(medium, split)]

        # Load the mapping between captions and images/videos
        mapping_path = config.get('mapping_path_{}_{}'.format(medium, split), None)
        with open(mapping_path, 'rb') as f:
            self.mapping = pickle.load(f)

        # These are the main ids of the dataset (typically one per image/video)
        self.ids = list(self.mapping.keys())

        # Optionally truncate the dataset to a fixed number of samples
        num_samples = config['num_samples_{}'.format(self.medium)]
        if num_samples > 0:
            self.ids = self.ids[:num_samples]
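
    # Expected structure of self.mapping (hypothetical example):
    #   {'0001': {'caption': 'a dog runs across ...', 'file': 'vid_0001'}, ...}
    # 'file' is a single image file for cc3m and a directory of extracted,
    # numerically named frames (0.jpg, 1.jpg, ...) for video media.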

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, index):
        item = self.mapping[self.ids[index]]

        ############################# Textual features #############################
        caption = item['caption']
        caption = self.text_processor(caption)
        # add [CLS] token
        caption = '[CLS] ' + caption

        ############################# Visual features ##############################
        if self.medium == 'cc3m':
            # Single image per sample
            pth = os.path.join(self.root_vis, item['file'])
            vis = open_img(pth)
            vis = self.vis_processor(vis).unsqueeze(0)
        else:
            # Video sample: frames stored as numbered files, sorted numerically
            pth = os.path.join(self.root_vis, item['file'])
            f_names = os.listdir(pth)
            f_names.sort(key=lambda f_n: int(f_n.split('.')[0]))
            pth = [os.path.join(pth, f_name) for f_name in f_names]
            vis = [Image.open(p).convert('RGB') for p in pth]
            vis = [self.vis_processor(v).unsqueeze(0) for v in vis]
            vis = torch.cat(vis, dim=0)
        # Sample a negative visual example (different index) for matching losses
        neg_index = random.randint(0, len(self) - 1)
        while neg_index == index:
            neg_index = random.randint(0, len(self) - 1)
        neg_item = self.mapping[self.ids[neg_index]]

        if self.medium == 'cc3m':
            neg_pth = os.path.join(self.root_vis, neg_item['file'])
            neg_vis = open_img(neg_pth)
            neg_vis = self.vis_processor(neg_vis).unsqueeze(0)
        else:
            neg_pth = os.path.join(self.root_vis, neg_item['file'])
            neg_f_names = os.listdir(neg_pth)
            neg_f_names.sort(key=lambda f_n: int(f_n.split('.')[0]))
            neg_pth = [os.path.join(neg_pth, neg_f_name) for neg_f_name in neg_f_names]
            neg_vis = [Image.open(p).convert('RGB') for p in neg_pth]
            neg_vis = [self.vis_processor(v).unsqueeze(0) for v in neg_vis]
            neg_vis = torch.cat(neg_vis, dim=0)

        return vis, caption, neg_vis
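
# A minimal usage sketch for CapDataset (illustrative only; it assumes the
# config keys used above, e.g. 'batch_size_cc3m', 'root_raw_vis_cc3m_train',
# 'mapping_path_cc3m_train' and 'num_samples_cc3m', are set by the caller,
# and that vis_processor/text_processor come from the project's builders):
#
#   from torch.utils.data import DataLoader
#   dataset = CapDataset(config, 'cc3m', vis_processor, text_processor, 'train')
#   loader = DataLoader(dataset, batch_size=dataset.batch_size, shuffle=True)
#   vis, caption, neg_vis = next(iter(loader))
#   # vis / neg_vis: (B, 1, C, H, W) for images, (B, T, C, H, W) for videos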

class VideoTextRetDataset(Dataset):
    def __init__(self, config, vis_processor, text_processor, medium, split):
        super(VideoTextRetDataset, self).__init__()
        self.config = config
        self.batch_size = config['batch_size_{}'.format(medium)]
        self.medium = medium  # webvid / cc3m / msrvtt
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        self.split = split  # train / val / test
        self.root_vis = config['root_raw_vis_{}_{}'.format(medium, split)]

        anno_path = config['annotation_{}_{}'.format(medium, split)]
        self.raw_anno_list = load_anno(anno_path)

        self.text = []
        self.vis = []
        self.txt2vis = {}
        self.vis2txt = {}
        self.build_data()
        self.anno_list = [dict(vis=v) for v in self.vis]

    def __len__(self):
        return len(self.anno_list)

    def __getitem__(self, index):
        # Load all frames of the video, sorted by their numeric file names
        pth = self.anno_list[index]['vis']
        f_names = os.listdir(pth)
        f_names.sort(key=lambda f_n: int(f_n.split('.')[0]))
        pth = [os.path.join(pth, f_name) for f_name in f_names]
        vis = [Image.open(p).convert('RGB') for p in pth]
        # Stack frames into a (T, C, H, W) tensor, mirroring CapDataset above
        vis = [self.vis_processor(v).unsqueeze(0) for v in vis]
        vis = torch.cat(vis, dim=0)
        return vis, index

    def build_data(self):
        """Each video may have multiple ground-truth captions (as in COCO and
        Flickr30K); build flat text/vis lists plus the mappings between them."""
        txt_id = 0
        for vis_id, ann in enumerate(self.raw_anno_list):
            self.vis.append(ann["vis"])
            self.vis2txt[vis_id] = []
            _captions = ann["caption"] \
                if isinstance(ann["caption"], list) else [ann["caption"], ]
            for caption in _captions:
                self.text.append(self.text_processor(caption))
                self.vis2txt[vis_id].append(txt_id)
                self.txt2vis[txt_id] = vis_id
                txt_id += 1
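
# Sketch of the index structure produced by build_data (hypothetical values):
#   vis2txt = {0: [0, 1], 1: [2]}   # video 0 has two ground-truth captions
#   txt2vis = {0: 0, 1: 0, 2: 1}
# Retrieval evaluation can use these to map ranked text ids back to their
# ground-truth videos and vice versa.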

def load_datasets(config, vis_processor, text_processor, split):
    if config['stage'] != 'stage_1':
        raise ValueError('Unknown stage: {}'.format(config['stage']))

    if split != 'test':
        cc3m_dataset = CapDataset(config, 'cc3m', vis_processor, text_processor, split)
        webvid_dataset = CapDataset(config, 'webvid', vis_processor, text_processor, split)
        datasets = {
            'cc3m': cc3m_dataset,
            'webvid': webvid_dataset
        }
    else:  # Test with msrvtt_1k --> video retrieval
        msrvtt_dataset = VideoTextRetDataset(config, vis_processor, text_processor, 'msrvtt', split)
        datasets = {
            'msrvtt': msrvtt_dataset
        }
    return datasets
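
# A minimal sketch of how this entry point might be consumed (assuming a
# stage-1 config and processors built elsewhere in the project):
#
#   train_sets = load_datasets(config, vis_processor, text_processor, 'train')
#   for name, dataset in train_sets.items():   # 'cc3m' and 'webvid'
#       loader = DataLoader(dataset, batch_size=dataset.batch_size)
#
#   test_sets = load_datasets(config, vis_processor, text_processor, 'test')
#   msrvtt = test_sets['msrvtt']   # VideoTextRetDataset for retrieval eval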