42 lines
No EOL
804 B
Python
Executable file
42 lines
No EOL
804 B
Python
Executable file
from typing import List
|
|
from torch import distributed
|
|
|
|
|
|
def barrier() -> None:
    """Synchronize all processes when ``torch.distributed`` is initialized.

    In a single-process run this does nothing. It also does nothing when the
    installed PyTorch build has no distributed support
    (``distributed.is_available()`` is False). On such builds the
    process-group API may be missing, so ``is_initialized`` is only consulted
    after ``is_available`` succeeds.
    """
    if distributed.is_available() and distributed.is_initialized():
        distributed.barrier()
|
|
|
|
|
|
def broadcast(data, src) -> None:
    """Broadcast ``data`` from rank ``src`` when ``torch.distributed`` is initialized.

    Delegates to ``distributed.broadcast(data, src)``, which updates ``data``
    on the receiving ranks. In a single-process run, or on a PyTorch build
    without distributed support, ``data`` is left untouched.

    Args:
        data: Object passed straight to ``distributed.broadcast``.
            Presumably a tensor; confirm against callers.
        src: Rank whose ``data`` is sent to the other processes.
    """
    # Check availability first: builds without distributed support may not
    # define is_initialized at all.
    if distributed.is_available() and distributed.is_initialized():
        distributed.broadcast(data, src)
|
|
|
|
|
|
def all_gather(data: List, src) -> None:
    """Gather ``src`` from every process into the list ``data``.

    When ``torch.distributed`` is available and initialized, this delegates to
    ``distributed.all_gather(data, src)``. Otherwise, in a single-process run
    or on a build without distributed support, the only "gathered" value is
    the local one, so it is stored at ``data[0]``.

    Args:
        data: Output list. Under ``distributed`` it is passed as the gather
            output list. In the fallback path it must have at least one slot.
        src: This process's contribution.

    Raises:
        IndexError: In the fallback path, if ``data`` is empty.
    """
    if distributed.is_available() and distributed.is_initialized():
        distributed.all_gather(data, src)
        return
    # NOTE(review): the fallback rebinds data[0] to ``src`` itself rather than
    # copying into the existing element, so data[0] aliases ``src`` afterwards.
    data[0] = src
|
|
|
|
|
|
def get_rank() -> int:
    """Return this process's rank, or ``0`` outside a distributed run.

    ``0`` is returned when ``torch.distributed`` is not initialized, or when
    the PyTorch build has no distributed support. Availability is checked
    first because such builds may not define ``is_initialized``.
    """
    if distributed.is_available() and distributed.is_initialized():
        return distributed.get_rank()
    return 0
|
|
|
|
|
|
def get_world_size() -> int:
    """Return the number of participating processes, or ``1`` outside a distributed run.

    ``1`` is returned when ``torch.distributed`` is not initialized, or when
    the PyTorch build has no distributed support. Availability is checked
    first because such builds may not define ``is_initialized``.
    """
    if distributed.is_available() and distributed.is_initialized():
        return distributed.get_world_size()
    return 1
|
|
|
|
|
|
def chunk_size(size, rank, world_size):
    """Return how many of ``size`` items are assigned to ``rank``.

    Items are split as evenly as possible across ``world_size`` ranks. The
    first ``size % world_size`` ranks each receive one extra item, so for
    ranks ``0 .. world_size - 1`` the returned sizes sum to ``size``.
    """
    leftover = size % world_size
    per_rank = size // world_size
    # Lower-numbered ranks absorb the leftover items one apiece; the boolean
    # adds as 0 or 1.
    takes_extra = rank < leftover
    return per_rank + takes_extra