1. Official documentation
torch.distributed.all_gather(): link to the official documentation
all_gather(tensor_list, tensor, group=None, async_op=False)
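tensor_list is the output: it holds one element per rank, and element i receives the data contributed by rank i. tensor is the input tensor held by the current process. Every element of tensor_list must be preallocated with the same shape (and dtype) as the tensor that the corresponding rank passes in.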
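The semantics above can be checked with a small runnable sketch that mirrors the int64 example from the docstring. It assumes a CPU-only setup with the gloo backend, two processes started with the POSIX `fork` start method (so it will not run on Windows), and rendezvous through a temporary file via `init_method="file://..."`. The helpers `_worker` and `run_all_gather_demo` are hypothetical names used only for illustration.

```python
import os
import tempfile
import multiprocessing as mp

import torch
import torch.distributed as dist


def _worker(rank, world_size, init_file, queue):
    # Hypothetical helper: every process joins the same default group (gloo runs on CPU).
    dist.init_process_group(
        backend="gloo",
        init_method=f"file://{init_file}",
        rank=rank,
        world_size=world_size,
    )
    # Same values as the docstring: rank 0 -> [1, 2], rank 1 -> [3, 4].
    tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
    # Output list: one preallocated slot per rank, same shape and dtype as `tensor`.
    tensor_list = [torch.zeros_like(tensor) for _ in range(world_size)]
    dist.all_gather(tensor_list, tensor)
    queue.put((rank, [t.tolist() for t in tensor_list]))
    dist.destroy_process_group()


def run_all_gather_demo(world_size=2):
    # Hypothetical helper: start `world_size` processes and collect each rank's result.
    ctx = mp.get_context("fork")  # fork avoids re-importing __main__ (POSIX only)
    queue = ctx.Queue()
    init_file = os.path.join(tempfile.mkdtemp(), "dist_init")  # must not exist yet
    procs = [
        ctx.Process(target=_worker, args=(r, world_size, init_file, queue))
        for r in range(world_size)
    ]
    for p in procs:
        p.start()
    # Drain the queue before join() so no child blocks while sending its result.
    results = {}
    for _ in range(world_size):
        rank, gathered = queue.get(timeout=60)
        results[rank] = gathered
    for p in procs:
        p.join()
    return results


if __name__ == "__main__":
    for rank, gathered in sorted(run_all_gather_demo().items()):
        print(f"rank {rank}: {gathered}")
```

Both ranks should end up with the same list, `[[1, 2], [3, 4]]`, which is what the docstring example shows: after the call, every rank holds every rank's tensor, ordered by rank.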


Official source code (signature and docstring):
def all_gather(tensor_list,
               tensor,
               group=None,
               async_op=False):
    """
    Gathers tensors from the whole group in a list.
    Complex tensors are supported.
    Args:
        tensor_list (list[Tensor]): Output list. It should contain
            correctly-sized tensors to be used for output of the collective.
        tensor (Tensor): Tensor to be broadcast from current process.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        async_op (bool, optional): Whether this op should be an async op
    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group
    Examples:
        >>> # All tensors below are of torch.int64 dtype.
        >>> # We have 2 process groups, 2 ranks.
        >>> tensor_list = [torch.zeros(2, dtype=torch.int64) for _ in range(2)]
        >>> tensor_list
        [tensor([0, 0]), tensor([0, 0])] # Rank 0 and 1
        >>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
        >>> tensor
        tensor([1, 2]) # Rank 0
        tensor([3, 4]) # Rank 1
        >>> dist.all_gather(tensor_list, tensor)
        >>> tensor_list
        [tensor([1, 2]), tensor([3, 4])] # Rank 0
        [tensor([1, 2]), tensor([3, 4])] # Rank 1
        >>> # All tensors below are of torch.cfloat dtype.
        >>> # We have 2 process groups, 2 ranks.
        >>> tensor_list = [torch.zeros(2, dtype=torch.cfloat) for
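The Returns section of the docstring distinguishes a synchronous call (returns None once the data is in `tensor_list`) from `async_op=True` (returns a work handle immediately). A minimal sketch of both modes, under the same assumptions as a typical CPU test setup: gloo backend, two processes started with the POSIX `fork` start method, and a temporary rendezvous file. `_worker` and `run_async_demo` are hypothetical helper names.

```python
import os
import tempfile
import multiprocessing as mp

import torch
import torch.distributed as dist


def _worker(rank, world_size, init_file, queue):
    # Hypothetical helper: join the default group on CPU with gloo.
    dist.init_process_group("gloo", init_method=f"file://{init_file}",
                            rank=rank, world_size=world_size)
    tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank

    # Synchronous call: blocks, then returns None.
    sync_list = [torch.zeros(2, dtype=torch.int64) for _ in range(world_size)]
    sync_ret = dist.all_gather(sync_list, tensor)

    # Asynchronous call: returns a work handle right away.
    async_list = [torch.zeros(2, dtype=torch.int64) for _ in range(world_size)]
    work = dist.all_gather(async_list, tensor, async_op=True)
    # Independent computation could overlap with communication here.
    work.wait()  # async_list is only safe to read after wait()

    queue.put((rank, (sync_ret is None, [t.tolist() for t in async_list])))
    dist.destroy_process_group()


def run_async_demo(world_size=2):
    # Hypothetical helper: start the processes and collect one payload per rank.
    ctx = mp.get_context("fork")  # POSIX only
    queue = ctx.Queue()
    init_file = os.path.join(tempfile.mkdtemp(), "dist_init")
    procs = [ctx.Process(target=_worker, args=(r, world_size, init_file, queue))
             for r in range(world_size)]
    for p in procs:
        p.start()
    results = {}
    for _ in range(world_size):
        rank, payload = queue.get(timeout=60)
        results[rank] = payload
    for p in procs:
        p.join()
    return results


if __name__ == "__main__":
    print(run_async_demo())
```

The design point is that `async_op=True` only starts the collective; reading the output list before `wait()` may observe incomplete data.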

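This article covers how to use torch.distributed.all_gather in PyTorch to collect data from every process in a distributed job. all_gather does not propagate gradients: it copies data into the preallocated tensors in tensor_list outside of autograd, so the gathered tensors are not connected to the computation graph. That is fine during model testing or evaluation. During training, if gradients must flow back through the gathered tensors, special handling is required. The article gives code examples and strategies for these different scenarios.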
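The gradient behavior described above can be illustrated with a sketch; it is one common approach, not necessarily the method used in the article. It wraps `dist.all_gather` in a custom `torch.autograd.Function` whose backward sums the per-rank gradients with `dist.all_reduce` and keeps the slice belonging to the local rank, which assumes the effective objective is the sum of the per-rank losses. `AllGatherWithGrad`, `_worker` and `run_grad_demo` are hypothetical names; the setup (gloo on CPU, two processes via POSIX `fork`, temporary rendezvous file) is an assumption. PyTorch also ships an autograd-aware `torch.distributed.nn.functional.all_gather`; whether its backward works on a given backend may depend on the PyTorch version.

```python
import os
import tempfile
import multiprocessing as mp

import torch
import torch.distributed as dist


class AllGatherWithGrad(torch.autograd.Function):
    """Hypothetical helper (not a PyTorch API): all_gather that routes gradients back."""

    @staticmethod
    def forward(ctx, tensor):
        out = [torch.empty_like(tensor) for _ in range(dist.get_world_size())]
        dist.all_gather(out, tensor)
        return tuple(out)

    @staticmethod
    def backward(ctx, *grads):
        # grads[j] = d(local loss)/d(gathered[j]) on this rank. The gradient of the
        # summed loss w.r.t. rank j's input is grads[j] summed over all ranks.
        stacked = torch.stack(grads)
        dist.all_reduce(stacked, op=dist.ReduceOp.SUM)
        return stacked[dist.get_rank()]


def _worker(rank, world_size, init_file, queue):
    dist.init_process_group("gloo", init_method=f"file://{init_file}",
                            rank=rank, world_size=world_size)
    x = torch.tensor([float(rank + 1)], requires_grad=True)

    # Evaluation style: plain all_gather fills buffers that are detached from autograd.
    plain = [torch.empty_like(x) for _ in range(world_size)]
    with torch.no_grad():
        dist.all_gather(plain, x)
    plain_requires_grad = any(t.requires_grad for t in plain)

    # Training style: gather through the autograd Function, then backprop.
    gathered = AllGatherWithGrad.apply(x)
    loss = torch.cat(gathered).sum()  # each rank's loss uses every rank's x
    loss.backward()

    queue.put((rank, (plain_requires_grad, x.grad.item())))
    dist.destroy_process_group()


def run_grad_demo(world_size=2):
    ctx = mp.get_context("fork")  # POSIX only
    queue = ctx.Queue()
    init_file = os.path.join(tempfile.mkdtemp(), "dist_init")
    procs = [ctx.Process(target=_worker, args=(r, world_size, init_file, queue))
             for r in range(world_size)]
    for p in procs:
        p.start()
    results = {}
    for _ in range(world_size):
        rank, payload = queue.get(timeout=60)
        results[rank] = payload
    for p in procs:
        p.join()
    return results


if __name__ == "__main__":
    print(run_grad_demo())
```

With two ranks, each rank's x feeds both ranks' losses, so x.grad should be 2.0 on every rank, while the plain all_gather buffers report `requires_grad == False`.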