V2Dial/datasets/dataloader.py
"""
From https://github.com/klauscc/VindLU/blob/main/dataset/dataloader.py
"""
import torch
from torch.utils.data import DataLoader, Dataset, ConcatDataset
import torch.distributed as dist
from utils.dist import *
import random
import logging

logger = logging.getLogger(__name__)


class MetaLoader(object):
    """Wraps multiple dataloaders."""

    def __init__(self, name2loader):
        """Iterates over multiple dataloaders and ensures that all processes
        work on data from the same dataloader at each step. Every loader
        contributes len(loader) batches to the schedule, so iteration ends
        once all dataloaders have been fully consumed.

        Args:
            name2loader: Dict, {name: dataloader}
        """
        self.name2loader = name2loader
        self.name2iter = {name: iter(l) for name, l in name2loader.items()}
        name2index = {name: idx for idx, (name, l) in enumerate(name2loader.items())}
        index2name = {v: k for k, v in name2index.items()}

        iter_order = []
        for n, l in name2loader.items():
            iter_order.extend([name2index[n]] * len(l))

        random.shuffle(iter_order)
        iter_order = torch.Tensor(iter_order).to(torch.device("cuda")).to(torch.uint8)

        # sync
        if is_dist_avail_and_initialized():
            # make sure all processes have the same order so that
            # each step they will have data from the same loader
            dist.broadcast(iter_order, src=0)
        self.iter_order = [index2name[int(e.item())] for e in iter_order.cpu()]

        logger.info(str(self))

    def __str__(self):
        output = [f"MetaLoader has {len(self.name2loader)} dataloaders, {len(self)} batches in total"]
        for idx, (name, loader) in enumerate(self.name2loader.items()):
            output.append(
                f"dataloader index={idx} name={name}, batch-size={loader.batch_size} length(#batches)={len(loader)} "
            )
        return "\n".join(output)

    def __len__(self):
        return len(self.iter_order)

    def __iter__(self):
        """Iterate once over the synchronized order, yielding (name, batch) pairs."""
        for name in self.iter_order:
            _iter = self.name2iter[name]
            batch = next(_iter)
            yield name, batch
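
# Example usage (a sketch; the loader names "vid"/"img" and the surrounding
# training code are hypothetical, not part of this file):
#
#   meta_loader = MetaLoader({"vid": vid_loader, "img": img_loader})
#   for name, batch in meta_loader:
#       loss = model(batch, media_type=name)  # every rank sees the same loader at each step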

def load_dataloaders(config, datasets, split, output_dict=False):
    if isinstance(datasets, dict):
        datasets = list(datasets.values())

    shuffles = [True] * len(datasets) if split == 'train' else [False] * len(datasets)

    if config['distributed'] and split != 'test':
        num_tasks = get_world_size()
        global_rank = get_rank()
        samplers = create_samplers(datasets, shuffles, num_tasks, global_rank)
    else:
        samplers = [None] * len(datasets)

    batch_size = [
        dataset.datasets[0].batch_size if isinstance(dataset, ConcatDataset) else dataset.batch_size
        for dataset in datasets
    ]

    collate_fns = []
    for dataset in datasets:
        if isinstance(dataset, ConcatDataset):
            collate_fns.append(getattr(dataset.datasets[0], 'collate_fn', None))
        else:
            collate_fns.append(getattr(dataset, 'collate_fn', None))

    loaders = create_loader(
        datasets,
        samplers,
        batch_size=batch_size,
        num_workers=[config.num_workers] * len(datasets),
        is_trains=shuffles,
        collate_fns=collate_fns,
    )  # [0]

    loaders_dict = {}
    if output_dict:
        for l in loaders:
            if isinstance(l.dataset, ConcatDataset):
                loaders_dict[l.dataset.datasets[0].medium] = l
            else:
                loaders_dict[l.dataset.medium] = l
        return loaders_dict
    return loaders
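
# Example usage (a sketch; `config` and `train_datasets` are illustrative, and the
# datasets are expected to carry `batch_size`, `medium`, and optionally `collate_fn`
# attributes, as read above):
#
#   train_loaders = load_dataloaders(config, train_datasets, split='train', output_dict=True)
#   # with output_dict=True the result is keyed by each dataset's `medium`,
#   # e.g. {'vid': <DataLoader>, 'img': <DataLoader>}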

def create_samplers(datasets, shuffles, num_tasks, global_rank):
    samplers = []
    for dataset, shuffle in zip(datasets, shuffles):
        sampler = torch.utils.data.DistributedSampler(
            dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle
        )
        samplers.append(sampler)
    return samplers
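
# NOTE: for training, call `sampler.set_epoch(epoch)` on each returned sampler at
# the start of every epoch so that DistributedSampler reshuffles the data, e.g.:
#
#   for sampler in samplers:
#       sampler.set_epoch(epoch)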

def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
    loaders = []
    for dataset, sampler, bs, n_worker, is_train, collate_fn in zip(
        datasets, samplers, batch_size, num_workers, is_trains, collate_fns
    ):
        if is_train:
            shuffle = sampler is None
            drop_last = True
        else:
            shuffle = False
            # NOTE: incomplete batches are dropped during evaluation as well
            drop_last = True
        loader = DataLoader(
            dataset,
            batch_size=bs,
            num_workers=n_worker,
            pin_memory=False,
            sampler=sampler,
            shuffle=shuffle,
            collate_fn=collate_fn,
            drop_last=drop_last,
            persistent_workers=True if n_worker > 0 else False,
        )
        loaders.append(loader)
    return loaders
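
# Putting it together (a sketch; `config`, `train_datasets`, `model`, and
# `num_epochs` are hypothetical): MetaLoader consumes its iterators and fixes the
# batch order at construction time, so a fresh instance is built every epoch.
#
#   train_loaders = load_dataloaders(config, train_datasets, split='train', output_dict=True)
#   for epoch in range(num_epochs):
#       meta_loader = MetaLoader(train_loaders)  # re-shuffles and re-broadcasts the order
#       for name, batch in meta_loader:
#           loss = model(batch, media_type=name)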