initial commit
This commit is contained in:
commit
a82bbc593e
129 changed files with 33981 additions and 0 deletions
156
datasets/pretraining.py
Normal file
156
datasets/pretraining.py
Normal file
|
@ -0,0 +1,156 @@
|
|||
from torch.utils.data import Dataset
|
||||
import pickle
|
||||
import os
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from torchvision import transforms
|
||||
import random
|
||||
|
||||
from .utils import pre_text, type_transform_helper, load_anno, open_img
|
||||
|
||||
class CapDataset(Dataset):
    """Caption dataset for image/video-text pretraining.

    Each item yields a processed visual tensor, its caption (normalized by
    ``text_processor`` and prefixed with ``'[CLS]'``), and a randomly drawn
    negative visual from a different index.
    """

    def __init__(self, config, medium, vis_processor, text_processor, split):
        """
        Args:
            config: dict-like config providing per-medium keys
                ``batch_size_<medium>``, ``root_raw_vis_<medium>_<split>``,
                ``mapping_path_<medium>_<split>`` and ``num_samples_<medium>``.
            medium: dataset name, e.g. "webvid" / "cc3m" / "msrvtt".
            vis_processor: callable mapping a PIL image to a tensor.
            text_processor: callable normalizing a raw caption string.
            split: "train" / "val" / "test".
        """
        super(CapDataset, self).__init__()
        self.config = config
        self.batch_size = config['batch_size_{}'.format(medium)]
        self.medium = medium  # "webvid / cc3m / msrvtt"
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        self.split = split  # train / val / test

        self.root_vis = config['root_raw_vis_{}_{}'.format(medium, split)]

        # Mapping between caption records and image/video ids (pickled dict).
        mapping_path = config.get('mapping_path_{}_{}'.format(medium, split), None)
        with open(mapping_path, 'rb') as f:
            self.mapping = pickle.load(f)

        # Main ids of the dataset (typically one per image/video);
        # optionally truncated for debugging / subsampling.
        self.ids = list(self.mapping.keys())
        num_samples = config['num_samples_{}'.format(self.medium)]
        if num_samples > 0:
            self.ids = self.ids[:num_samples]

    def __len__(self):
        return len(self.ids)

    def _load_vis(self, item):
        """Load and preprocess the visual for *item*.

        For 'cc3m' the item is a single image (tensor of shape (1, ...));
        otherwise it is a directory of numerically named frames that are
        stacked along dim 0.
        """
        pth = os.path.join(self.root_vis, item['file'])
        if self.medium == 'cc3m':
            vis = open_img(pth)
            return self.vis_processor(vis).unsqueeze(0)
        # Video: frames named "<index>.<ext>", sorted numerically so that
        # lexicographic listing order (1, 10, 2, ...) does not scramble them.
        f_names = os.listdir(pth)
        f_names.sort(key=lambda f_n: int(f_n.split('.')[0]))
        frame_paths = [os.path.join(pth, f_name) for f_name in f_names]
        frames = [Image.open(p).convert('RGB') for p in frame_paths]
        frames = [self.vis_processor(v).unsqueeze(0) for v in frames]
        return torch.cat(frames, dim=0)

    def __getitem__(self, index):
        """Return ``(vis, caption, neg_vis)`` for the sample at *index*."""
        item = self.mapping[self.ids[index]]

        ############################# Textual features #############################
        caption = self.text_processor(item['caption'])
        # add [CLS] token
        caption = '[CLS] ' + caption

        vis = self._load_vis(item)

        # Draw a negative visual distinct from the current index.
        # The len(self) > 1 guard prevents an infinite loop on a
        # degenerate single-item dataset.
        neg_index = random.randint(0, len(self) - 1)
        while neg_index == index and len(self) > 1:
            neg_index = random.randint(0, len(self) - 1)
        neg_vis = self._load_vis(self.mapping[self.ids[neg_index]])

        return vis, caption, neg_vis
|
||||
|
||||
class VideoTextRetDataset(Dataset):
    """Video-text retrieval evaluation dataset (e.g., MSRVTT-1k).

    ``__getitem__`` yields ``(frames_tensor, index)``; the text side is
    exposed through ``self.text`` plus the ``txt2vis`` / ``vis2txt``
    index mappings built by :meth:`build_data`.
    """

    def __init__(self, config, vis_processor, text_processor, medium, split):
        """
        Args:
            config: dict-like config providing ``batch_size_<medium>``,
                ``root_raw_vis_<medium>_<split>`` and
                ``annotation_<medium>_<split>``.
            vis_processor: callable mapping a PIL image to a tensor.
            text_processor: callable normalizing a raw caption string.
            medium: dataset name, e.g. "webvid" / "cc3m" / "msrvtt".
            split: "train" / "val" / "test".
        """
        super(VideoTextRetDataset, self).__init__()

        self.config = config
        self.batch_size = config['batch_size_{}'.format(medium)]
        self.medium = medium  # "webvid / cc3m / msrvtt"
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        self.split = split  # train / val / test

        self.root_vis = config['root_raw_vis_{}_{}'.format(medium, split)]

        anno_path = config['annotation_{}_{}'.format(medium, split)]
        self.raw_anno_list = load_anno(anno_path)
        self.text = []
        self.vis = []
        self.txt2vis = {}
        self.vis2txt = {}
        self.build_data()
        self.anno_list = [dict(vis=v) for v in self.vis]

    def __len__(self):
        return len(self.anno_list)

    def __getitem__(self, index):
        """Return ``(frames_tensor, index)`` for the video at *index*."""
        pth = self.anno_list[index]['vis']
        # Frames named "<index>.<ext>"; sort numerically, not lexically.
        f_names = os.listdir(pth)
        f_names.sort(key=lambda f_n: int(f_n.split('.')[0]))
        frame_paths = [os.path.join(pth, f_name) for f_name in f_names]
        vis = [Image.open(p).convert('RGB') for p in frame_paths]
        vis = [self.vis_processor(v) for v in vis]
        vis = torch.cat(vis, dim=0)

        return vis, index

    def build_data(self):
        """Build text list and txt<->vis index maps.

        Each visual may have multiple ground-truth texts (e.g., COCO and
        Flickr30K style annotations); ``vis2txt[vis_id]`` lists all text ids
        for a visual and ``txt2vis[txt_id]`` gives the owning visual id.
        """
        txt_id = 0
        for vis_id, ann in enumerate(self.raw_anno_list):
            self.vis.append(ann["vis"])
            self.vis2txt[vis_id] = []
            # Normalize to a list so single-caption entries work uniformly.
            captions = ann["caption"] \
                if isinstance(ann["caption"], list) else [ann["caption"], ]
            for caption in captions:
                self.text.append(self.text_processor(caption))
                self.vis2txt[vis_id].append(txt_id)
                self.txt2vis[txt_id] = vis_id
                txt_id += 1
|
||||
|
||||
def load_datasets(config, vis_processor, text_processor, split):
    """Assemble the per-medium dataset dict for the configured stage.

    Stage-1 training/validation pairs CC3M (images) with WebVid (videos);
    the test split instead uses MSRVTT for video-text retrieval.
    """
    if config['stage'] == 'stage_1':
        if split != 'test':
            # Pretraining: one captioning dataset per medium.
            datasets = {
                medium: CapDataset(config, medium, vis_processor, text_processor, split)
                for medium in ('cc3m', 'webvid')
            }
        else:
            # Test with msrvtt_1k --> video retrieval.
            datasets = {
                'msrvtt': VideoTextRetDataset(
                    config, vis_processor, text_processor, 'msrvtt', split
                ),
            }
    return datasets
|
Loading…
Add table
Add a link
Reference in a new issue