V2Dial/datasets/utils.py
2025-06-24 08:38:09 +02:00

83 lines
No EOL
2.6 KiB
Python

import os
import re
import json
from tqdm import trange
from utils.dist import is_main_process
from torch.utils.data import Dataset, ConcatDataset
from PIL import Image
import numpy as np
def open_img(img_pth):
    """Load an image as RGB, falling back to a random-noise image on failure.

    Args:
        img_pth: Path (or file object) handed to ``PIL.Image.open``.

    Returns:
        PIL.Image.Image: The image converted to mode ``'RGB'``. If opening or
        converting fails, a 224x224 RGB image filled with uniform random
        values in ``[0, 255]`` is returned instead.
    """
    try:
        # The context manager releases the underlying file once ``convert``
        # has produced an independent RGB copy.
        with Image.open(img_pth) as raw:
            return raw.convert('RGB')
    except Exception:
        # Best-effort fallback is kept, but ``except Exception`` (instead of
        # a bare ``except``) lets KeyboardInterrupt/SystemExit propagate.
        #
        # dtype must be uint8: with numpy's default integer dtype the buffer
        # is several bytes per value, and Pillow would reinterpret those raw
        # bytes as RGB channels rather than use the sampled values.
        noise = np.random.randint(
            0, high=256, size=(224, 224, 3), dtype=np.uint8
        )
        # An (H, W, 3) uint8 array is inferred as mode 'RGB' by Pillow.
        return Image.fromarray(noise)
def pre_text(text, max_l=None):
    """Normalize a caption string and optionally truncate it by word count.

    Steps, in order: lowercase; delete the characters ``' ! ? " ( ) * # : ; ~``;
    turn ``-`` and ``/`` into spaces and ``<person>`` into ``person``;
    collapse every run of two or more whitespace characters into one space;
    strip trailing newlines and then surrounding spaces.

    Args:
        text (str): Raw text.
        max_l (int, optional): If truthy, keep at most this many
            space-separated words.

    Returns:
        str: The cleaned (and possibly truncated) text.
    """
    cleaned = re.sub(r"(['!?\"()*#:;~])", '', text.lower())
    for old, new in (('-', ' '), ('/', ' '), ('<person>', 'person')):
        cleaned = cleaned.replace(old, new)
    cleaned = re.sub(r"\s{2,}", ' ', cleaned).rstrip('\n').strip(' ')

    if not max_l:
        return cleaned

    tokens = cleaned.split(' ')
    if len(tokens) <= max_l:
        return cleaned
    return ' '.join(tokens[:max_l])
def get_datasets_media(dataloaders):
    """Index dataloaders by the ``medium`` attribute of their dataset.

    For a dataloader whose dataset is a ``ConcatDataset``, the medium is read
    from the first wrapped dataset. If several dataloaders share a medium,
    the last one wins.

    Args:
        dataloaders: Iterable of objects exposing a ``dataset`` attribute.

    Returns:
        dict: ``{medium: dataloader}``.
    """
    def _medium_of(loader):
        ds = loader.dataset
        if isinstance(ds, ConcatDataset):
            ds = ds.datasets[0]
        return ds.medium

    return {_medium_of(loader): loader for loader in dataloaders}
def type_transform_helper(x):
    """Cast ``x`` to float and divide it by 255.

    Args:
        x: Object providing ``.float()`` whose result provides ``.div()``
            (presumably a torch tensor of 0-255 pixel values -- confirm
            against callers).

    Returns:
        The result of ``x.float().div(255.0)``.
    """
    as_float = x.float()
    return as_float.div(255.0)
def load_anno(ann_file_list):
    """Load annotation JSON files and attach a resolved media path to each entry.

    Args:
        ann_file_list (List[List[str]] or List[str]): Either a list of
            sublists, or a single sublist (which is wrapped into a list
            automatically). Each sublist is ``[anno_path, data_root]`` for
            image data or ``[anno_path, data_root, 'video']`` for video data.
            An empty list yields an empty result.

    Returns:
        List[dict]: Entries of all annotation files, concatenated in the
        order given. Each entry is the dict loaded from JSON plus one added
        key, ``"vis"``: ``data_root`` joined with the entry's ``"image"``
        (or ``"video"`` for video sublists) value after dropping its first
        five characters and everything from the first ``'.'`` onwards.
    """
    # Nothing to load; avoids an IndexError on ``ann_file_list[0]`` below.
    if not ann_file_list:
        return []
    if isinstance(ann_file_list[0], str):
        ann_file_list = [ann_file_list]

    ann = []
    for spec in ann_file_list:
        data_root = spec[1]
        fp = spec[0]
        is_video = len(spec) == 3 and spec[2] == "video"
        # The source key is the same for every entry of this file.
        key = "video" if is_video else "image"

        # ``with`` closes the file deterministically; JSON is read as UTF-8
        # regardless of the platform's default encoding.
        with open(fp, "r", encoding="utf-8") as f:
            cur_ann = json.load(f)

        # Show a progress bar only when ``is_main_process()`` is true.
        if is_main_process():
            indices = trange(len(cur_ann), desc=f"Loading {fp}")
        else:
            indices = range(len(cur_ann))

        for idx in indices:
            entry = cur_ann[idx]
            # Skip the first five characters (presumably a fixed path prefix
            # in the annotation files -- confirm) and cut at the first '.'.
            video_id = entry[key][5:].split('.')[0]
            # Unified key for the media path, independent of image/video.
            entry["vis"] = os.path.join(data_root, video_id)

        ann.extend(cur_ann)
    return ann