#用于分布式环境下进行all-reduce的操作
def _reduce(tensor: Tensor) -> Tensor:
if gpc.get_world_size(ParallelMode.TENSOR) == 1:
return tensor
dist.all_reduce(tensor,
op=dist.ReduceOp.SUM,
group=gpc.get_group(ParallelMode.TENSOR),#指定用于通信的进程组,通过gpc.get_group(ParallelMode.TENSOR)获取张量模式下的并行模式组
async_op=False)#同步执行全局求和操作,即在该操作完成之前会阻塞程序继续执行,直到所有进程完成全局求和操作
return tensor
- 对张量进行分割
#用于在分布式环境下将张量沿指定维度分割成多份
def _split(tensor: Tensor