swift GRPOTrainer类 源码解析

@contextmanager
def unwrap_model_for_generation(
    model,
    accelerator,
    gather_deepspeed3_params=True,
    gather_parameters: List = None,
):
    unwrapped_model = accelerator.unwrap_model(model)
    if accelerator.state.deepspeed_plugin is not None and accelerator.state.deepspeed_plugin.zero_stage == 3:
        if not gather_deepspeed3_params:
            yield accelerator.unwrap_model(model)
        else:
            import deepspeed
            parameters = [
                parameter for name, parameter in model.named_parameters()
                if not gather_parameters or name in gather_parameters
            ]
            with deepspeed.zero.GatheredParameters(parameters):
                from trl.models.utils import remove_hooks
                remove_hooks(model)
                yield accelerator.unwrap_model(model)
                from trl.models.utils import add_hooks
                add_hooks(model)
    else:
        yield unwrapped_model


class GRPOCallback(TrainerCallback):

    def __init__(self, trainer):
        self.trainer = trainer

    # offload original_modules to cpu, to save memory
    def on_train_begin(self, args, state, control, **kwargs):
        self.trainer.queue = self.trainer.train_queue
        train_dataloader = getattr(state, 'train_dataloader', None) or kwargs.get('train_dataloader')
        self.trainer._prefetch(train_dataloader)


@dataclass
class DataCache:
    inputs: List[Dict] = field(default_factory=list)
    outputs: List[Dict] = field(default_factory=list)
    distributed_idx: List[List] = field(default_factory=list)

est

Poe

这段代码是用于大语言模型训练的,使用了 AccelerateDeepSpeedTrainerCallback 以及自定义的数据缓存结构。它涉及模型解包(unwrap)和 DeepSpeed Zero Stage 3 的特殊处理。下面是对代码中每一部分逐行的详细解释:


一、unwrap_model_for_generation 上下文管理器

python

运行复制

@contextmanager
def unwrap_model_for_generation(
    model,
    accelerator,
    gather_deepspeed3_params=True,
    gather_parameters: List = None,
):
  • @contextmanager:表明该函数是一个上下文管理器(可用于 with 语句)。
  • model:传入的模型对象。
  • acceleratorAccelerate 框架的对象,用于处理多设备训练、混合精度等。
  • gather_deepspeed3_params:是否需要收集 DeepSpeed Stage 3 分散的参数。
  • gather_parameters:指定收集哪些参数的名字列表(可选)。

python

运行复制

    unwrapped_model = accelerator.unwrap_model(model)
  • 使用 accelerator.unwrap_model 解包模型,去除封装(如 DDP、FP16、Deepspeed 包装等),得到底层原始模型。

python

运行复制

    if accelerator.state.deepspeed_plugin is not None and accelerator.state.deepspeed_plugin.zero_stage == 3:
  • 判断是否使用了 DeepSpeed,并且是 Zero Stage 3(该阶段会对参数进行分片,不能直接访问完整参数)。

python

运行复制

        if not gather_deepspeed3_params:
            yield accelerator.unwrap_model(model)
  • 如果不需要收集分片参数,直接 yield 解包后的模型。

python

运行复制

        else:
            import deepspeed
  • 否则导入 deepspeed,准备使用其工具收集参数。

python

运行复制

            parameters = [
                parameter for name, parameter in model.named_parameters()
                if not gather_parameters or name in gather_parameters
            ]
  • 根据 gather_parameters 筛选出要收集的参数。如果没有指定,则收集所有参数。

python

运行复制

            with deepspeed.zero.GatheredParameters(parameters):
  • 使用 Deepspeed 的上下文管理器,将 Stage 3 中被分散的参数 收集到当前进程,以便可以正常使用。

python

运行复制

                from trl.models.utils import remove_hooks
                remove_hooks(model)
  • 从模型中暂时移除 hook(比如 forward hook),防止收集参数时引发副作用。

python

运行复制

                yield accelerator.unwrap_model(model)
  • yield 解包后的模型,此时参数是完整的。

python

运行复制

                from trl.models.utils import add_hooks
                add_hooks(model)
  • 在 yield 后恢复之前移除的 hook,保持模型状态的一致性。

python

运行复制

    else:
        yield unwrapped_model
  • 如果不是 DeepSpeed Stage 3,直接 yield 解包后的模型。

二、GRPOCallback 回调类

python

运行复制

class GRPOCallback(TrainerCallback):
  • 自定义的训练回调类,继承自 TrainerCallback(用于训练过程中的事件响应)。

python

运行复制

    def __init__(self, trainer):
        self.trainer = trainer
  • 初始化时将 trainer 对象保存为成员变量,以便后续使用。

python

运行复制

    def on_train_begin(self, args, state, control, **kwargs):
  • 训练开始时回调的函数。

python

运行复制

        self.trainer.queue = self.trainer.train_queue
  • train_queue 赋值给 queue,可能用于数据加载或任务队列处理。

python

运行复制

        train_dataloader = getattr(state, 'train_dataloader', None) or kwargs.get('train_dataloader')
  • 尝试从 statekwargs 中获取训练数据加载器。

python

运行复制

        self.trainer._prefetch(train_dataloader)
  • 调用 _prefetch 方法预加载数据,提高训练效率。

三、DataCache 数据缓存结构

python

运行复制

@dataclass
class DataCache:
  • 使用 dataclass 声明一个简单的数据结构类 DataCache

python

运行复制

    inputs: List[Dict] = field(default_factory=list)
  • 存储输入数据,类型为字典的列表。

python

运行复制

    outputs: List[Dict] = field(default_factory=list)
  • 存储模型输出,类型为字典的列表。

python

运行复制

    distributed_idx: List[List] = field(default_factory=list)
  • 存储分布式训练场景下的索引映射(可能用于还原或聚合)。

总结

这段代码展示了:

  • 如何处理 DeepSpeed Stage 3 参数收集。
  • 如何自定义训练回调以优化内存和数据加载。
  • 如何定义训练过程中的输入输出缓存结构。

它是构建一个高效、可扩展的大模型训练框架的重要组件。

class GRPOTrainer(RLHFTrainerMixin, SwiftMixin, HFGRPOTrainer):
    executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)

    def __init__(self,
                 model: Optional[Union[PreTrainedModel, nn.Module]] = None,
                 ref_model: Optional[Union[PreTrainedModel, nn.Module]] = None,
                 reward_model: Optional[Union[PreTrainedModel, nn.Module]] = None,
                 reward_funcs: Optional[List[Union[str, Callable]]] = None,
                 *_args,
                 **kwargs):
        from swift.trainers.rlhf_arguments import GRPOConfig
        args: GRPOConfig = kwargs['args']
        self.args = args
        self.train_queue = Queue()
        self.eval_queue = Queue()
        self.processing_class = kwargs.get('template').tokenizer
        self.offload_modules = {}
        self.offload_states = {}
        _, _, _, local_world_size = get_dist_setting()

        if not isinstance(reward_funcs, list):
            rewa
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值