fairseq框架下模型训练-train

本文详细介绍了fairseq训练框架的核心组成部分,包括task、model、criterion的定义与注册,以及训练流程中的关键步骤如构建模型、设置损失函数、迭代训练与验证。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

fairseq给出的训练框架中,包含几个部分,main()函数,train(), get_traning_stats(),validate(),get_valid_stats()。这个框架是不改变的,我们通过在fairseq/tasks中注册自己的task,fairseq/models中注册自己的model,fairseq/critirion中注册自己的critirion来完成基于fairseq框架的训练。我们来看一下在fairseq框架中如何调用我们自己定义的部分。

(1)首先建立task:

task = tasks.setup_task(args)

而调用的这个task正是我们自己定义的task。
setup_task的主要作用是读入src_dict和tgt_dict

@register_task('guided_translation')
class GuidedTranslationTask(FairseqTask):
    def __init__(self, args, src_dict, tgt_dict):
        super().__init__(args)
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict
    def setup_task(cls, args, **kwargs):
            src_dict = cls.load_dictionary(os.path.join(paths[0], 'dict.{}.txt'.format(args.source_lang)))
            tgt_dict = cls.load_dictionary(os.path.join(paths[0], 'dict.{}.txt'.format(args.target_lang)))
    return cls(args, src_dict, tgt_dict)#返回task的实体

(2)建立model

model = task.build_model(args)

而在task.build_model()中有:

return models.build_model(args, self)

build_model在上一篇叙述model的文章中提及过,

@register_model("guided_transformer")
class GuidedTransformerModel(FairseqEncoderDecoderModel):
    def __init__(self, args, encoder, decoder):
        super().__init__(encoder, decoder)
        self.args = args
        self.supports_align_args = True
    def build_model(cls, args, task):
        encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
        decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
        return cls(args, encoder, decoder)

(3)建立criterion

criterion = task.build_criterion(args)

在guided transformer中,采用的是已经注册的label_smoothed_cross_entropy

(4)建立trainer
在fairseq中有trainer.py 可以根据需求提取其中的功能
trainer.py中class Trainer定义了 get_train_iterator,save_check_poinnt, load_check_point, train_step, valid_step等以及一些参数的接口。

from fairseq.trainer import Trainer
trainer = Trainer(args, task, model, criterion)

(5)读取、保存断点

    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

在load_checkpoint中,做了几项工作
首先,通过class trainer中的load_checkpoint读取checkpoint_last.pt

    extra_state = trainer.load_checkpoint(
        checkpoint_path,
        args.reset_optimizer,
        args.reset_lr_scheduler,
        eval(args.optimizer_overrides),
        reset_meters=args.reset_meters,
    )

其次,通过get_train_iterator获取下一epoch的训练数据

    if extra_state is not None and not args.reset_dataloader:
        # restore iterator from checkpoint
        itr_state = extra_state["train_iterator"]
        epoch_itr = trainer.get_train_iterator(
            epoch=itr_state["epoch"], load_dataset=True, **passthrough_args
        )
        epoch_itr.load_state_dict(itr_state)
    else:
        epoch_itr = trainer.get_train_iterator(
            epoch=0, load_dataset=True, **passthrough_args
        )

(6)使用fairseq.logging 中的meters来对训练进行计时

    train_meter = meters.StopwatchMeter() #"""Computes the sum/avg duration of some event in seconds"""
    train_meter.start()
    train过程
    train_meter.stop()
    logger.info('done training in {:.1f} seconds'.format(train_meter.sum))

(7)开始训练
在以下条件成立时循环进行训练

    while (
        lr > args.min_lr
        and (
            epoch_itr.epoch < max_epoch
            # allow resuming training from the final checkpoint
            or epoch_itr._next_epoch_itr is not None
        )
        and trainer.get_num_updates() < max_update
    ):
train(args, trainer, task, epoch_itr)

在train函数中做如下几步
调用trainer.train_step

    for samples in progress:
        log_output = trainer.train_step(samples)
        num_updates = trainer.get_num_updates()
        if log_output is None:
            continue

在fairseq.trainer.py的train_step中,重点做了以下几步

        self._set_seed()
        self.model.train()
        self.criterion.train()
        self.zero_grad()

调用fairseq task. train_step计算结果以及损失函数

        loss, sample_size_i, logging_output = self.task.train_step(
               sample=sample,
               model=self.model,
               criterion=self.criterion,
               optimizer=self.optimizer,
               update_num=self.get_num_updates(),
               ignore_grad=is_dummy_batch,
                    )

在验证集上评估模型并返回损失值

valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)

调用与train_step相似的valid_step完成valid_loss计算

调用fairseq中的trainer class中的lr_step来更新learning rate
only usue first validation loss to update the learing rate.

lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

保存checkpoint

if epoch_itr.epoch % args.save_interval == 0:
      checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

获取下一个train epoch

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.epoch,
            # sharded data: get train iterator for next epoch
            load_dataset=(os.pathsep in getattr(args, 'data', '')),
        )
### Fairseq 使用教程及项目地址 Fairseq 是由 Facebook AI 研究团队开发的一个高性能序列到序列 (Seq2Seq) 学习框架,广泛应用于机器翻译、文本生成以及其他自然语言处理任务的研究与开发[^2]。 #### 1. 安装方式 Fairseq 提供两种主要的安装方法: - **pip 安装**: 用户可以通过 Python 的包管理工具 `pip` 来快速安装 Fairseq。这种方式适合希望快速上手而不关心底层实现细节的用户[^3]。 ```bash pip install fairseq ``` - **源码安装**: 如果需要自定义功能或者参与项目的开发工作,则推荐通过克隆仓库并编译的方式进行安装。此方法允许开发者修改代码以满足特定需求。 ```bash git clone https://github.com/pytorch/fairseq.git cd fairseq pip install --editable . ``` #### 2. 主要命令工具 Fairseq 配备了一系列强大的命令行工具来支持数据预处理、模型训练以及推理等功能: - **fairseq-preprocess**: 数据预处理工具,用于准备输入数据集以便后续训练过程使用。 - **fairseq-train**: 训练神经网络模型的核心脚本。 - **fairseq-generate**: 利用已训练好的模型生成目标语句或其他形式的结果。 - **fairseq-interactive**: 实现交互式的预测模式,方便调试和测试新样本。 - **fairseq-score**: 对齐两个文件中的句子并对它们评分。 - **fairseq-eval-lm**: 评估语言模型性能。 #### 3. GitHub 项目地址 以下是 Fairseq 和其更新版本 Fairseq2 的官方 GitHub 地址: - Fairseq: [https://gitcode.com/gh_mirrors/fa/fairseq](https://gitcode.com/gh_mirrors/fa/fairseq) - Fairseq2: [https://gitcode.com/gh_mirrors/fa/fairseq2](https://gitcode.com/gh_mirrors/fa/fairseq2)[^1] 这些链接提供了完整的文档说明和技术支持资源,帮助使用者更好地理解和应用该框架。 ```python import torch from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils # 加载预训练模型 model_name_or_path = 'path/to/model' models, cfg, task = checkpoint_utils.load_model_ensemble_and_task([model_name_or_path]) # 设置设备为 GPU 或 CPU device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') for model in models: model.to(device) # 构建字典映射表 generator = task.build_generator(models, cfg.generation) ``` 以上是一个简单的例子展示如何加载 Fairseq 中保存下来的模型,并将其部署至指定硬件平台上运行推断操作。
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值