initial commit
This commit is contained in:
commit
a82bbc593e
129 changed files with 33981 additions and 0 deletions
137
datasets/dataloader.py
Normal file
137
datasets/dataloader.py
Normal file
|
@ -0,0 +1,137 @@
|
|||
"""
|
||||
From https://github.com/klauscc/VindLU/blob/main/dataset/dataloader.py
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, Dataset, ConcatDataset
|
||||
import torch.distributed as dist
|
||||
from utils.dist import *
|
||||
import random
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MetaLoader(object):
|
||||
""" wraps multiple data loader """
|
||||
def __init__(self, name2loader):
|
||||
"""Iterates over multiple dataloaders, it ensures all processes
|
||||
work on data from the same dataloader. This loader will end when
|
||||
the shorter dataloader raises StopIteration exception.
|
||||
|
||||
loaders: Dict, {name: dataloader}
|
||||
"""
|
||||
self.name2loader = name2loader
|
||||
self.name2iter = {name: iter(l) for name, l in name2loader.items()}
|
||||
name2index = {name: idx for idx, (name, l) in enumerate(name2loader.items())}
|
||||
index2name = {v: k for k, v in name2index.items()}
|
||||
|
||||
iter_order = []
|
||||
for n, l in name2loader.items():
|
||||
iter_order.extend([name2index[n]]*len(l))
|
||||
|
||||
random.shuffle(iter_order)
|
||||
iter_order = torch.Tensor(iter_order).to(torch.device("cuda")).to(torch.uint8)
|
||||
|
||||
# sync
|
||||
if is_dist_avail_and_initialized():
|
||||
# make sure all processes have the same order so that
|
||||
# each step they will have data from the same loader
|
||||
dist.broadcast(iter_order, src=0)
|
||||
self.iter_order = [index2name[int(e.item())] for e in iter_order.cpu()]
|
||||
|
||||
logger.info(str(self))
|
||||
|
||||
def __str__(self):
|
||||
output = [f"MetaLoader has {len(self.name2loader)} dataloaders, {len(self)} batches in total"]
|
||||
for idx, (name, loader) in enumerate(self.name2loader.items()):
|
||||
output.append(
|
||||
f"dataloader index={idx} name={name}, batch-size={loader.batch_size} length(#batches)={len(loader)} "
|
||||
)
|
||||
return "\n".join(output)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.iter_order)
|
||||
|
||||
def __iter__(self):
|
||||
""" this iterator will run indefinitely """
|
||||
for name in self.iter_order:
|
||||
_iter = self.name2iter[name]
|
||||
batch = next(_iter)
|
||||
yield name, batch
|
||||
|
||||
|
||||
def load_dataloaders(config, datasets, split, output_dict=False):
|
||||
if isinstance(datasets, dict):
|
||||
datasets = list(datasets.values())
|
||||
shuffles = [True] * len(datasets) if split == 'train' else [False] * len(datasets)
|
||||
if config['distributed'] and split != 'test':
|
||||
num_tasks = get_world_size()
|
||||
global_rank = get_rank()
|
||||
samplers = create_samplers(
|
||||
datasets, shuffles, num_tasks, global_rank
|
||||
)
|
||||
else:
|
||||
samplers = [None] * len(datasets)
|
||||
|
||||
batch_size = [dataset.datasets[0].batch_size if isinstance(dataset, ConcatDataset) else dataset.batch_size for dataset in datasets]
|
||||
collate_fns = []
|
||||
for dataset in datasets:
|
||||
if isinstance(dataset, ConcatDataset):
|
||||
collate_fns.append(getattr(dataset.datasets[0], 'collate_fn', None))
|
||||
else:
|
||||
collate_fns.append(getattr(dataset, 'collate_fn', None))
|
||||
|
||||
loaders = create_loader(
|
||||
datasets,
|
||||
samplers,
|
||||
batch_size=batch_size,
|
||||
num_workers=[config.num_workers] * len(datasets),
|
||||
is_trains=shuffles,
|
||||
collate_fns=collate_fns,
|
||||
) # [0]
|
||||
loaders_dict = {}
|
||||
if output_dict:
|
||||
for l in loaders:
|
||||
if isinstance(l.dataset, ConcatDataset):
|
||||
loaders_dict[l.dataset.datasets[0].medium] = l
|
||||
else:
|
||||
loaders_dict[l.dataset.medium] = l
|
||||
return loaders_dict
|
||||
return loaders
|
||||
|
||||
|
||||
def create_samplers(datasets, shuffles, num_tasks, global_rank):
|
||||
samplers = []
|
||||
for dataset, shuffle in zip(datasets, shuffles):
|
||||
sampler = torch.utils.data.DistributedSampler(
|
||||
dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle
|
||||
)
|
||||
samplers.append(sampler)
|
||||
return samplers
|
||||
|
||||
|
||||
def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
|
||||
loaders = []
|
||||
for dataset, sampler, bs, n_worker, is_train, collate_fn in zip(
|
||||
datasets, samplers, batch_size, num_workers, is_trains, collate_fns
|
||||
):
|
||||
if is_train:
|
||||
shuffle = sampler is None
|
||||
drop_last = True
|
||||
else:
|
||||
shuffle = False
|
||||
drop_last = True
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=bs,
|
||||
num_workers=n_worker,
|
||||
pin_memory=False,
|
||||
sampler=sampler,
|
||||
shuffle=shuffle,
|
||||
collate_fn=collate_fn,
|
||||
drop_last=drop_last,
|
||||
persistent_workers=True if n_worker > 0 else False,
|
||||
)
|
||||
loaders.append(loader)
|
||||
return loaders
|
Loading…
Add table
Add a link
Reference in a new issue