Pytorch分布式训练print()使用技巧

在分布式训练场景中,有时我们可能会需要使用print函数(虽然大部分情况下大多会用logging进行信息输出)在终端打印相关信息。但由于同时运行多个进程,如果不进行限制,每个进程都会打印信息,不但影响观感,而且可能会造成阻塞。

通常的解决方法是利用if条件语句进行限制,只在主进程中进行打印,如下:

# 当前为主进程
if args.rank == 0:
    print('Train message')

但最近在学习目标检测模型DINO源码时,我发现作者采用重写内置print函数的方式实现了相同的功能,即只在主进程中启用print函数,在其他进程中禁用print函数。

函数源码如下:

def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__

    # 得到内置的print函数
    builtin_print = __builtin__.print

    
    # 重写print函数
    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        # 在主进程或者强制条件下才调用内置print输出
        if is_master or force:
            builtin_print(*args, **kwargs)

    # 用重写后的print函数替换内置的print函数
    __builtin__.print = print

该方法具体的调用位置是在初始化多进程组之后,示例如下:

import torch

args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ['WORLD_SIZE'])
args.dist_backend = 'nccl'
args.dist_url = 'env://'
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
# 只在主进程启用print
setup_for_distributed(args.rank == 0)

实测好用,且思路清奇,果然学习永无止境。在此做一个学习记录,也分享给需要的人。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值