T5X项目中的辅助任务功能详解
t5x 项目地址: https://gitcode.com/gh_mirrors/t5/t5x
引言
在机器学习模型训练过程中,我们经常会遇到一些需要额外处理任务的场景。T5X框架提供了一套强大的辅助任务(Auxiliary Job)功能,能够帮助开发者高效地处理这些需求。本文将深入解析T5X中辅助任务的工作原理、典型应用场景以及具体实现方法。
辅助任务概述
辅助任务是T5X框架中一个非常实用的功能,它允许在主训练任务之外,自动触发并执行额外的任务。每当一个新的检查点(checkpoint)被保存时,系统就会自动启动一个辅助任务。
典型应用场景
-
耗时评估任务分离:当评估任务(如infer_eval或train_eval)因数据集过大、解码速度慢或多任务评估而耗时过长时,可以将评估任务分离到辅助任务中执行,避免影响主训练任务的速度。
-
持续微调:在训练过程中,希望对每个检查点进行下游任务的微调。
-
自定义评估:当需要运行自定义评估代码,而这些代码不适合放在seqio.Evaluator框架中时。
工作原理
辅助任务启动时,控制器会自动替换四个关键的gin宏:
MODEL_DIR
:模型目录MIXTURE_OR_TASK_NAME
:任务或混合名称(由用户通过标志控制)INITIAL_CHECKPOINT_PATH
:最新检查点路径TRAIN_STEPS
:训练步数
此外,辅助任务可以拥有与主训练任务完全不同的资源配置、优先级甚至运行环境。
实战案例一:分离评估任务
实施步骤
-
选择模型架构:本例使用T5-1.1-Base模型。
-
选择训练和评估任务:使用WMT14英法翻译任务('wmt_enfr14_v003')。
-
编写Gin配置文件:
- 需要两个gin文件:一个用于训练任务,一个用于辅助任务。
- 辅助任务gin文件可以基于现有配置或完全独立。
-
启动实验:使用特定标志控制辅助任务行为。
关键标志说明
auxiliary_job_mixtures
:逗号分隔的任务列表,每个任务都会触发一个辅助任务auxiliary_job_gin_file
:辅助任务使用的gin配置文件replace_gin_file
:设为True时,辅助任务不使用训练任务的任何gin配置auxiliary_job_cell
:辅助任务运行环境auxiliary_job_platform
:辅助任务硬件平台auxiliary_job_build_target
:辅助任务使用的二进制文件final_auxiliary_job_steps
:设为0表示进行持续评估
示例脚本
declare -a ARGS=(
--cell=iz
--platform=jd=2x2
--final_auxiliary_job_steps=0
--replace_gin_file=True
--auxiliary_job_mixtures=wmt14_enfr_v003
--auxiliary_job_gin_file=base_wmt14enfr_eval.gin
--auxiliary_job_cell=iz
--auxiliary_job_platform=jd=2x2
--auxiliary_job_build_target_path=//t5x:eval
--gin_file=base_wmt14enfr_train.gin
)
gxm t5x/google/xm_launch.py "${ARGS[@]}"
实战案例二:持续微调任务
本案例展示如何在C4数据集上预训练模型,并在WMT'14英法翻译任务上持续微调。
实施步骤
-
模型架构:同样使用T5-1.1-Base模型。
-
训练任务:使用C4数据集上的span corruption任务('c4_v220_span_corruption')。
-
Gin配置:
base_c4_pretrain.gin
:训练任务配置base_wmtenfr14_finetune.gin
:辅助任务配置
-
启动实验:关键区别是将
final_auxiliary_job_steps
设为非零值(如200),并使用train.py
二进制文件。
示例脚本
declare -a ARGS=(
--cell=iz
--platform=jd=2x2
--final_auxiliary_job_steps=200
--replace_gin_file=True
--auxiliary_job_mixtures=wmt14_enfr_v003
--auxiliary_job_gin_file=base_wmt14enfr_finetune.gin
--auxiliary_job_cell=iz
--auxiliary_job_platform=jd=2x2
--auxiliary_job_build_target_path=//t5x:train
--gin_file=base_c4_pretrain.gin
)
gxm t5x/google/xm_launch.py "${ARGS[@]}"
常见问题与解决方案
-
未设置auxiliary_mixtures标志:即使gin文件中已指定任务,也必须设置此标志。
-
使用不同二进制文件时未设置replace_gin_file=True:会导致找不到train函数的错误。
-
未记录指标:确保SeqIO评估器正确配置以记录到TensorBoard。
-
train_eval速度慢:考虑将train_eval指标添加到SeqIO任务的metrics_fn参数中,在辅助任务中计算。
-
宏名称混淆:辅助任务使用
INITIAL_CHECKPOINT_PATH
而非CHECKPOINT_PATH
。 -
gin宏被忽略:所有gin宏必须在gin脚本中定义,不能通过标志传递。
-
持续评估设置错误:进行持续评估时必须设置
final_auxiliary_job_steps=0
。
最佳实践建议
-
资源分配:评估任务通常需要比训练任务更少的资源,合理配置可提高资源利用率。
-
任务优先级:根据业务需求设置适当的任务优先级,确保关键任务优先执行。
-
日志管理:为辅助任务配置独立的日志目录,便于问题排查和结果分析。
-
监控机制:建立完善的监控机制,及时发现和处理辅助任务失败的情况。
-
参数调优:根据实际运行情况,不断优化辅助任务的执行频率和资源配置。
通过合理使用T5X的辅助任务功能,开发者可以构建更加灵活、高效的训练流程,显著提升模型开发效率。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考