"""
|
|
Copyright (c) 2022, salesforce.com, inc.
|
|
All rights reserved.
|
|
SPDX-License-Identifier: BSD-3-Clause
|
|
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
|
"""
|
|
|
|
import datetime
import functools
import os

import torch
import torch.distributed as dist
import timm.models.hub as timm_hub


def setup_for_distributed(is_master):
    """
    Disable printing when not in the master process; individual calls can
    still force output with print(..., force=True).
    """
    import builtins as __builtin__

    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop("force", False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


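# Illustrative usage sketch (not part of the original module): after distributed
# init, silence all ranks except the master and force output where needed.
#
#     setup_for_distributed(is_master=(get_rank() == 0))
#     print("only the master process prints this")
#     print("every process prints this", force=True)

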
def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def init_distributed_mode(args):
    if args.distributed is False:
        print("Not using distributed mode")
        args.rank = 0
        return

    if "LOCAL_RANK" not in os.environ:
        os.environ["LOCAL_RANK"] = str(args.local_rank)

    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        # launched with torchrun / torch.distributed.launch, which export these variables
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ["WORLD_SIZE"])
        args.gpu = int(os.environ["LOCAL_RANK"])
    elif "SLURM_PROCID" in os.environ:
        # launched under SLURM: derive the rank from the SLURM process id
        args.rank = int(os.environ["SLURM_PROCID"])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print("Not using distributed mode")
        args.distributed = False
        args.rank = 0
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = "nccl"
    print(
        "| distributed init (rank {}, world {}): {}".format(
            args.rank, args.world_size, args.dist_url
        ),
        flush=True,
    )
    torch.distributed.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
        # generous timeout so the main process can download and decompress data
        # while the other ranks wait at the first barrier
        timeout=datetime.timedelta(days=365),
    )
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)


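# Illustrative usage sketch (the argparse.Namespace fields mirror the attributes
# read above and are assumptions, not a documented interface):
#
#     import argparse
#     args = argparse.Namespace(
#         distributed=True, dist_url="env://", local_rank=0, world_size=1
#     )
#     init_distributed_mode(args)  # sets args.rank, args.gpu, args.dist_backend

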
def get_dist_info():
    if torch.__version__ < "1.0":
        initialized = dist._initialized
    else:
        initialized = dist.is_initialized()
    if initialized:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:  # non-distributed training
        rank = 0
        world_size = 1
    return rank, world_size


def main_process(func):
    """Decorator that runs the wrapped function on the main process (rank 0) only."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        rank, _ = get_dist_info()
        if rank == 0:
            return func(*args, **kwargs)

    return wrapper


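# Illustrative usage sketch (save_checkpoint is a hypothetical caller, not part
# of this module): only rank 0 writes the file; other ranks get None back.
#
#     @main_process
#     def save_checkpoint(state, path):
#         torch.save(state, path)

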
def download_cached_file(url, check_hash=True, progress=False):
    """
    Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
    If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
    """

    def get_cached_file_path():
        # a hack to sync the file path across processes
        parts = torch.hub.urlparse(url)
        filename = os.path.basename(parts.path)
        cached_file = os.path.join(timm_hub.get_cache_dir(), filename)

        return cached_file

    if is_main_process():
        timm_hub.download_cached_file(url, check_hash, progress)

    if is_dist_avail_and_initialized():
        dist.barrier()

    return get_cached_file_path()


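# Illustrative usage sketch (the URL is a placeholder, not a real checkpoint):
#
#     cached_path = download_cached_file(
#         "https://example.com/checkpoints/model.pth", check_hash=False, progress=True
#     )
#     state_dict = torch.load(cached_path, map_location="cpu")

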
class GatherLayer(torch.autograd.Function):
    """
    Gather tensors from all workers with support for backward propagation:
    this implementation does not cut the gradients as torch.distributed.all_gather does.
    """

    @staticmethod
    def forward(ctx, x):
        output = [
            torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())
        ]
        torch.distributed.all_gather(output, x)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        all_gradients = torch.stack(grads)
        torch.distributed.all_reduce(all_gradients)
        return all_gradients[torch.distributed.get_rank()]


def all_gather_with_grad(tensors):
    """
    Performs all_gather operation on the provided tensors.
    Graph remains connected for backward grad computation.
    """
    world_size = torch.distributed.get_world_size()
    # there is no need for reduction in the single-process case
    if world_size == 1:
        return tensors

    tensor_all = GatherLayer.apply(tensors)

    return torch.cat(tensor_all, dim=0)


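# Illustrative usage sketch (image_feat, text_feat and temp are hypothetical
# per-rank tensors/scalars): gather features across ranks while keeping the
# autograd graph connected, e.g. for a contrastive loss over the global batch.
#
#     all_text_feat = all_gather_with_grad(text_feat)  # (world_size * B, D)
#     logits = image_feat @ all_text_feat.t() / temp

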
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    # if not using distributed training, there is nothing to gather
    if not is_dist_avail_and_initialized():
        return tensor

    tensors_gather = [
        torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
    ]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output


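# Illustrative usage sketch (feat is a hypothetical per-rank tensor): gather
# detached copies from every rank, e.g. to build momentum-encoder negatives.
#
#     all_feat = concat_all_gather(feat)  # (world_size * B, D), no gradient
#
# Use all_gather_with_grad instead when gradients must flow back to each rank.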