42 lines
804 B
Python
42 lines
804 B
Python
|
from typing import List
|
||
|
from torch import distributed
|
||
|
|
||
|
|
||
|
def barrier():
    """Synchronize all processes of the default process group.

    When torch.distributed has not been initialized there is only one
    process, so this returns immediately without doing anything.
    """
    if not distributed.is_initialized():
        return
    distributed.barrier()
|
||
|
|
||
|
|
||
|
def broadcast(data, src):
    """Broadcast ``data`` from rank ``src`` to every rank, in place.

    Without an initialized process group there is a single process that
    already owns ``data``, so it is left untouched.
    """
    if not distributed.is_initialized():
        return
    distributed.broadcast(data, src)
|
||
|
|
||
|
|
||
|
def all_gather(data: List, src):
    """Gather ``src`` from every rank into the preallocated list ``data``.

    Mirrors ``torch.distributed.all_gather(tensor_list, tensor)``: ``data``
    holds one output slot per rank, and slot ``i`` receives rank ``i``'s
    ``src``.

    Without an initialized process group there is a single rank, so
    ``src`` goes into ``data[0]``. If that slot already holds a tensor of
    the same shape and dtype, it is filled in place, as the distributed
    path does. This populates views of a caller-owned buffer and keeps
    ``data[0]`` from aliasing ``src``. Any other placeholder (e.g. ``None``)
    is simply replaced by ``src``.

    Raises:
        IndexError: if ``data`` is empty in the non-distributed case.
    """
    if distributed.is_initialized():
        distributed.all_gather(data, src)
        return

    import torch  # local import: the module itself only imports torch.distributed

    slot = data[0]
    if (
        isinstance(slot, torch.Tensor)
        and isinstance(src, torch.Tensor)
        and slot.shape == src.shape
        and slot.dtype == src.dtype
    ):
        # Copy rather than rebind so the result lands in the caller's
        # storage, matching the in-place semantics of the distributed call.
        slot.copy_(src)
    else:
        data[0] = src
|
||
|
|
||
|
|
||
|
def get_rank():
    """Return this process's rank, or 0 when not running distributed."""
    return distributed.get_rank() if distributed.is_initialized() else 0
|
||
|
|
||
|
|
||
|
def get_world_size():
    """Return the number of participating processes (1 when not distributed)."""
    return distributed.get_world_size() if distributed.is_initialized() else 1
|
||
|
|
||
|
|
||
|
def chunk_size(size, rank, world_size):
    """Return how many of ``size`` items are assigned to ``rank``.

    Items are split as evenly as possible across ``world_size`` ranks.
    Every rank gets ``size // world_size`` items, and ranks below
    ``size % world_size`` get one extra. For ranks 0..world_size-1 the
    counts therefore sum to ``size``.
    """
    base, leftover = divmod(size, world_size)
    return base + (1 if rank < leftover else 0)
|