fairseq源码分析(四)—— 训练一个模型fairseq都干了啥之train.py

本文详细介绍了fairseq框架的分布式训练流程,包括单节点多GPU的训练策略。核心在于`torch.multiprocessing.spawn`方法,用于启动多进程进行分布式训练。主要步骤涉及CUDA初始化、任务设定、数据加载、模型与criterion初始化、虚拟batch创建、trainer与dataloader的初始化,以及训练过程中的关键操作。fairseq的训练轮数由lr、max_epoch和max_update参数共同决定。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

框架的执行流程主要如下:
首先被调用的为最外层的train.py,在调用train.py时,fairseq会读取你的命令行参数,并根据命令行参数内的信息,生成相应的分布式策略,其调用流程如下:
在这里插入图片描述

分布式

fairseq共支持以下三种训练方式,分别为单GPU,分布式训练,单节点多GPU。
下面就单节点多GPU的分布式训练策略做一下简单的介绍:

    if args.distributed_init_method is not None:
        # distributed training
        distributed_main(args.device_id, args)
    elif args.distributed_world_size > 1:
        # fallback for single node with multiple GPUs
        port = random.randint(10000, 20000)
        args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port)
        args.distributed_rank = None  # set based on device id
        if max(args.update_freq) > 1 and args.ddp_backend != 'no_c10d':
            print('| NOTE: you may get better performance with: --ddp-backend=no_c10d')
        torch.multiprocessing.spawn(
            fn=distributed_main,
            args=(args, ),
            nprocs=args.distributed_world_size,
        )

其中有关多卡训练最重要的一行为torch.multiprocessing.spawn(),该方法是pytorch多进程处理包中的一个方法,主要用来生成分布式策略,通过创建进程实例并调用join来等待它们完成,可以生成大量子进程来执行某些功能。其参数分别为:

  • fn:函数被称为派生进程的入口点
  • args (tuple) – 传递给 fn 的参数.
  • nprocs (int) – 派生的进程数.
    所以该句话的意思为:将使用nprocs个子进程来分布式运行fn(args)。
    继续向下看可以看到在train.py中这个fn就是指distributed_main,而distributed_main又是通过间接调用main方法实现的,所以这将我们关注的重点指向了main函数的执行。

开始训练

命运的车轮让我们来到了main函数面前,让我们一起看看main函数都干了些什么。
思来想去不知道这一部分该怎么去描述,最后还是决定用一张图来简介明了的来展示给大家,
在这里插入图片描述
其步骤大体为:

  • 初始化cuda,查看设备的可用状态。
    在这里插入图片描述

  • 根据命令行参数设置需要执行的任务,如translation
    在这里插入图片描述

  • 加载数据集并切分
    在这里插入图片描述

  • 初始化分布式训练
    在这里插入图片描述
    这部分加载了socket,其分布式训练是通过进程间通信的方式进行实现,其底层的分布式的运行是通过MPI库实现
    在这里插入图片描述

  • 加载模型与训练标准(criterion)
    在这里插入图片描述

  • 创建虚拟(dummy)的batch
    这里他的作用有两点:1、预热缓存分配器。2、在每个工作进程的batch数不均匀时作为占位符进行分布式数据并行训练。
    在这里插入图片描述

  • 初始化trainer
    在这里插入图片描述
    trainer是一个支持数据并行训练的一个class,这个在稍后的博文中会进行介绍。

  • 初始化dataloder
    在这里插入图片描述
    上面只是加载了数据集,在此处加载模型的dataloader,Dataloader的处理逻辑是先通过Dataset类里面的 getitem 函数获取单个的数据,然后组合成batch,再使用collate_fn所指定的函数对这个batch做一些操作。

  • 加载检查点
    在这里插入图片描述

  • 训练
    在这里插入图片描述
    可以看到影响fairseq训练轮数的主要有三个参数:lr(学习率),max_epoch,max_update,只有这三个参数同时满足训练的要求,模型的训练才会继续。
    在训练过程中模型主要经历了,train,valid,保存价差点等操作,这些在之后的章节中在进行分享。

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值