"""
|
|
From https://github.com/klauscc/VindLU/blob/main/dataset/dataloader.py
|
|
"""
|
|
|
|
import logging
import random

import torch
import torch.distributed as dist
from torch.utils.data import ConcatDataset, DataLoader

from utils.dist import get_rank, get_world_size, is_dist_avail_and_initialized

logger = logging.getLogger(__name__)


class MetaLoader(object):
    """Wraps multiple data loaders."""

    def __init__(self, name2loader):
        """Iterates over multiple dataloaders and ensures that, at every step,
        all processes draw data from the same dataloader. The iteration order
        contains each loader's index exactly len(loader) times, so one pass
        visits every batch of every loader once.

        Args:
            name2loader: Dict, {name: dataloader}
        """
        self.name2loader = name2loader
        self.name2iter = {name: iter(l) for name, l in name2loader.items()}
        name2index = {name: idx for idx, name in enumerate(name2loader)}
        index2name = {v: k for k, v in name2index.items()}

        # Build the schedule: one entry per batch, holding that loader's index.
        iter_order = []
        for n, l in name2loader.items():
            iter_order.extend([name2index[n]] * len(l))

        random.shuffle(iter_order)
        # uint8 keeps the broadcast small; it also caps the loader count at 256.
        iter_order = torch.tensor(iter_order, dtype=torch.uint8, device="cuda")

        # Sync: broadcast rank 0's shuffled order so that at each step every
        # process takes its batch from the same loader.
        if is_dist_avail_and_initialized():
            dist.broadcast(iter_order, src=0)
        self.iter_order = [index2name[int(e.item())] for e in iter_order.cpu()]

        logger.info(str(self))

    def __str__(self):
        output = [
            f"MetaLoader has {len(self.name2loader)} dataloaders, "
            f"{len(self)} batches in total"
        ]
        for idx, (name, loader) in enumerate(self.name2loader.items()):
            output.append(
                f"dataloader index={idx} name={name}, "
                f"batch-size={loader.batch_size} length(#batches)={len(loader)}"
            )
        return "\n".join(output)

    def __len__(self):
        return len(self.iter_order)

    def __iter__(self):
        """Yields (name, batch) pairs following the synchronized order."""
        for name in self.iter_order:
            _iter = self.name2iter[name]
            batch = next(_iter)
            yield name, batch
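

# A minimal usage sketch (an addition, not from the original repo). It assumes
# two hypothetical DataLoaders and a hypothetical `model` callable, and shows
# that each step yields the loader's name alongside one of its batches. Note
# that MetaLoader.__init__ places its schedule on a CUDA device.
def _metaloader_usage_sketch(image_loader, video_loader, model):
    meta_loader = MetaLoader({"image": image_loader, "video": video_loader})
    for media_type, batch in meta_loader:  # exactly len(meta_loader) steps
        model(batch, media_type)

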
def load_dataloaders(config, datasets, split, output_dict=False):
    if isinstance(datasets, dict):
        datasets = list(datasets.values())
    shuffles = [split == "train"] * len(datasets)
    if config["distributed"] and split != "test":
        num_tasks = get_world_size()
        global_rank = get_rank()
        samplers = create_samplers(datasets, shuffles, num_tasks, global_rank)
    else:
        samplers = [None] * len(datasets)

    # A ConcatDataset has no batch_size or collate_fn of its own, so read both
    # from its first member dataset.
    batch_size = [
        dataset.datasets[0].batch_size
        if isinstance(dataset, ConcatDataset)
        else dataset.batch_size
        for dataset in datasets
    ]
    collate_fns = []
    for dataset in datasets:
        if isinstance(dataset, ConcatDataset):
            collate_fns.append(getattr(dataset.datasets[0], "collate_fn", None))
        else:
            collate_fns.append(getattr(dataset, "collate_fn", None))

    loaders = create_loader(
        datasets,
        samplers,
        batch_size=batch_size,
        num_workers=[config.num_workers] * len(datasets),
        is_trains=shuffles,
        collate_fns=collate_fns,
    )

    if output_dict:
        # Key each loader by the medium of its (first) underlying dataset.
        loaders_dict = {}
        for l in loaders:
            if isinstance(l.dataset, ConcatDataset):
                loaders_dict[l.dataset.datasets[0].medium] = l
            else:
                loaders_dict[l.dataset.medium] = l
        return loaders_dict
    return loaders
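

# A hedged sketch (an addition) of the expected call site. It assumes a config
# object that supports both item and attribute access (e.g. an EasyDict) with
# `distributed` and `num_workers` set, and train datasets exposing the
# `batch_size` and `medium` attributes that load_dataloaders requires.
def _load_dataloaders_usage_sketch(config, train_datasets):
    name2loader = load_dataloaders(config, train_datasets, split="train", output_dict=True)
    return MetaLoader(name2loader)

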
def create_samplers(datasets, shuffles, num_tasks, global_rank):
    samplers = []
    for dataset, shuffle in zip(datasets, shuffles):
        sampler = torch.utils.data.DistributedSampler(
            dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle
        )
        samplers.append(sampler)
    return samplers
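

# Note (an addition): DistributedSampler only reshuffles across epochs when the
# training loop calls set_epoch. A minimal sketch of that loop:
def _set_epoch_sketch(samplers, num_epochs):
    for epoch in range(num_epochs):
        for sampler in samplers:
            sampler.set_epoch(epoch)  # reseeds the per-epoch shuffle

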
def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
    loaders = []
    for dataset, sampler, bs, n_worker, is_train, collate_fn in zip(
        datasets, samplers, batch_size, num_workers, is_trains, collate_fns
    ):
        if is_train:
            # A DistributedSampler owns the shuffling when supplied, so the
            # DataLoader itself must not shuffle in that case.
            shuffle = sampler is None
            drop_last = True
        else:
            shuffle = False
            drop_last = False  # keep every sample at evaluation time
        loader = DataLoader(
            dataset,
            batch_size=bs,
            num_workers=n_worker,
            pin_memory=False,
            sampler=sampler,
            shuffle=shuffle,
            collate_fn=collate_fn,
            drop_last=drop_last,
            persistent_workers=n_worker > 0,
        )
        loaders.append(loader)
    return loaders
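

# A self-contained smoke test (an addition, not from the original repo): build
# two toy TensorDatasets and check that create_loader wires up batch sizes and
# drop_last as described above. MetaLoader is skipped here because its
# __init__ assumes a CUDA device is available.
if __name__ == "__main__":
    from torch.utils.data import TensorDataset

    toy_a = TensorDataset(torch.arange(10).float())
    toy_b = TensorDataset(torch.arange(20).float())
    toy_loaders = create_loader(
        [toy_a, toy_b],
        samplers=[None, None],
        batch_size=[4, 8],
        num_workers=[0, 0],
        is_trains=[True, True],
        collate_fns=[None, None],
    )
    for loader in toy_loaders:
        print(f"batches={len(loader)} batch_size={loader.batch_size}")  # 2 and 2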