@contextmanager
def unwrap_model_for_generation(
model,
accelerator,
gather_deepspeed3_params=True,
gather_parameters: List = None,
):
unwrapped_model = accelerator.unwrap_model(model)
if accelerator.state.deepspeed_plugin is not None and accelerator.state.deepspeed_plugin.zero_stage == 3:
if not gather_deepspeed3_params:
yield accelerator.unwrap_model(model)
else:
import deepspeed
parameters = [
parameter for name, parameter in model.named_parameters()
if not gather_parameters or name in gather_parameters
]
with deepspeed.zero.GatheredParameters(parameters):
from trl.models.utils import remove_hooks
remove_hooks(model)
yield accelerator.unwrap_model(model)
from trl.models.utils import add_hooks
add_hooks(model)
else:
yield unwrapped_model
class GRPOCallback(TrainerCallback):
def __init__(self, trainer):
self.trainer = trainer
# offload original_modules to cpu, to save memory
def on_train_begin(self, args, state, control, **kwargs):
self.trainer.queue = self.trainer.train_queue
train_dataloader = getattr(state, 'train_dataloader', None) or kwargs.get('train_dataloader')
self.trainer._prefetch(train_dataloader)
@dataclass
class DataCache:
inputs: List[Dict] = field(default_factory=list)
outputs: List[Dict] = field(default_factory=list)
distributed_idx: List[List] = field(default_factory=list)
Poe
这段代码是用于大语言模型训练的,使用了 Accelerate、DeepSpeed、TrainerCallback 以及自定义的数据缓存结构。它涉及模型解包(unwrap)和 DeepSpeed Zero Stage 3 的特殊处理。下面是对代码中每一部分逐行的详细解释:
一、unwrap_model_for_generation 上下文管理器
python
运行复制
@contextmanager
def unwrap_model_for_generation(
model,
accelerator,
gather_deepspeed3_params=True,
gather_parameters: List = None,
):
@contextmanager:表明该函数是一个上下文管理器(可用于with语句)。model:传入的模型对象。accelerator:Accelerate框架的对象,用于处理多设备训练、混合精度等。gather_deepspeed3_params:是否需要收集 DeepSpeed Stage 3 分散的参数。gather_parameters:指定收集哪些参数的名字列表(可选)。
python
运行复制
unwrapped_model = accelerator.unwrap_model(model)
- 使用
accelerator.unwrap_model解包模型,去除封装(如 DDP、FP16、Deepspeed 包装等),得到底层原始模型。
python
运行复制
if accelerator.state.deepspeed_plugin is not None and accelerator.state.deepspeed_plugin.zero_stage == 3:
- 判断是否使用了 DeepSpeed,并且是 Zero Stage 3(该阶段会对参数进行分片,不能直接访问完整参数)。
python
运行复制
if not gather_deepspeed3_params:
yield accelerator.unwrap_model(model)
- 如果不需要收集分片参数,直接 yield 解包后的模型。
python
运行复制
else:
import deepspeed
- 否则导入
deepspeed,准备使用其工具收集参数。
python
运行复制
parameters = [
parameter for name, parameter in model.named_parameters()
if not gather_parameters or name in gather_parameters
]
- 根据
gather_parameters筛选出要收集的参数。如果没有指定,则收集所有参数。
python
运行复制
with deepspeed.zero.GatheredParameters(parameters):
- 使用 Deepspeed 的上下文管理器,将 Stage 3 中被分散的参数 收集到当前进程,以便可以正常使用。
python
运行复制
from trl.models.utils import remove_hooks
remove_hooks(model)
- 从模型中暂时移除 hook(比如 forward hook),防止收集参数时引发副作用。
python
运行复制
yield accelerator.unwrap_model(model)
- yield 解包后的模型,此时参数是完整的。
python
运行复制
from trl.models.utils import add_hooks
add_hooks(model)
- 在 yield 后恢复之前移除的 hook,保持模型状态的一致性。
python
运行复制
else:
yield unwrapped_model
- 如果不是 DeepSpeed Stage 3,直接 yield 解包后的模型。
二、GRPOCallback 回调类
python
运行复制
class GRPOCallback(TrainerCallback):
- 自定义的训练回调类,继承自
TrainerCallback(用于训练过程中的事件响应)。
python
运行复制
def __init__(self, trainer):
self.trainer = trainer
- 初始化时将
trainer对象保存为成员变量,以便后续使用。
python
运行复制
def on_train_begin(self, args, state, control, **kwargs):
- 训练开始时回调的函数。
python
运行复制
self.trainer.queue = self.trainer.train_queue
- 将
train_queue赋值给queue,可能用于数据加载或任务队列处理。
python
运行复制
train_dataloader = getattr(state, 'train_dataloader', None) or kwargs.get('train_dataloader')
- 尝试从
state或kwargs中获取训练数据加载器。
python
运行复制
self.trainer._prefetch(train_dataloader)
- 调用
_prefetch方法预加载数据,提高训练效率。
三、DataCache 数据缓存结构
python
运行复制
@dataclass
class DataCache:
- 使用
dataclass声明一个简单的数据结构类DataCache。
python
运行复制
inputs: List[Dict] = field(default_factory=list)
- 存储输入数据,类型为字典的列表。
python
运行复制
outputs: List[Dict] = field(default_factory=list)
- 存储模型输出,类型为字典的列表。
python
运行复制
distributed_idx: List[List] = field(default_factory=list)
- 存储分布式训练场景下的索引映射(可能用于还原或聚合)。
总结
这段代码展示了:
- 如何处理 DeepSpeed Stage 3 参数收集。
- 如何自定义训练回调以优化内存和数据加载。
- 如何定义训练过程中的输入输出缓存结构。
它是构建一个高效、可扩展的大模型训练框架的重要组件。
class GRPOTrainer(RLHFTrainerMixin, SwiftMixin, HFGRPOTrainer):
executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
def __init__(self,
model: Optional[Union[PreTrainedModel, nn.Module]] = None,
ref_model: Optional[Union[PreTrainedModel, nn.Module]] = None,
reward_model: Optional[Union[PreTrainedModel, nn.Module]] = None,
reward_funcs: Optional[List[Union[str, Callable]]] = None,
*_args,
**kwargs):
from swift.trainers.rlhf_arguments import GRPOConfig
args: GRPOConfig = kwargs['args']
self.args = args
self.train_queue = Queue()
self.eval_queue = Queue()
self.processing_class = kwargs.get('template').tokenizer
self.offload_modules = {}
self.offload_states = {}
_, _, _, local_world_size = get_dist_setting()
if not isinstance(reward_funcs, list):
rewa

最低0.47元/天 解锁文章
8956

被折叠的 条评论
为什么被折叠?



