[CLIP-VIT-L + Qwen] 多模态大模型源码阅读 - trainer篇


参考repo:WatchTower-Liu/VLM-learning; url: VLLM-BASE

前情提要

有关多模态大模型架构中的语言模型部分(MQwen.py)的代码请看(多模态大模型源码阅读 - 1多模态大模型源码阅读 - 2多模态大模型源码阅读 - 3多模态大模型源码阅读 - 4),多模态大模型架构中的视觉模型(visual/CLIP-VIT.py)部分请看多模态大模型源码阅读 - 5
本节主要讲的是项目中的多模态Trainer部分,即项目文件trainer.py,该文件中的代码重构了部分transfomers.trainer的成员方法,以适配多模态场景下的模型训练,包括自定义的损失计算,参数保存,优化器配置,支持分布式训练(多卡场景)。

源码阅读

导包

import torch
from transformers import Trainer
from transformers.trainer import (
    is_sagemaker_mp_enabled,
    get_parameter_names,
    has_length,
    ALL_LAYERNORM_LAYERS,
    logger,
)
import os
from peft import get_peft_model_state_dict

逐行解读

import torch
from transformers import Trainer

torch不必赘述,深度学习的核心出装,构建和训练神经网络的必备库,调包调参侠(我)的福音。
Trainer类主要用于NLP和多模态任务,简化模型训练过程,在后续的代码中作为父类使用。

from transformers.trainer import (
    is_sagemaker_mp_enabled,
    get_parameter_names,
    has_length,
    ALL_LAYERNORM_LAYERS,
    logger,
)

is_sagemaker_mp_enabled检验是否在Amazon SageMaker的模型并行环境中运行。模型并行性允许将模型的不同组件分布到多个GPU设备上,用以加速大规模模型的训练。如果是单卡童鞋就不必在意这个设置~
get_parameter_names用以获取模型中的参数名,在设置优化器参数时,可以区分需要权重衰减的参数和不需要的参数。
has_length检测对象是否有长度信息,用于确定训练过程的迭代次数。在项目代码中没有用到。
ALL_LAYERNORM_LAYERS:包含所有LAYERNORM类型的层,用于在优化器配置中排除这些层的权重衰减。
logger:日志记录,输出训练过程中信息和调试信息。

import os
from peft import get_peft_model_state_dict

os:经常使用的库,主要用来创建文件、文件夹,开关文件。
peft(Parameter-Efficient Fine-Tuning),用于高效微调模型,在微调过程中会冻结预训练模型的大部分参数,仅保留少量的可训练参数,以在尽可能少的资源占用和时间下微调模型适配下游任务, 大名鼎鼎的LoRA、Prefix Tuning、Prompt Tuning 等都在这个库中。get_peft_model_state_dict用于获取微调后的adapter状态字典。例如使用LoRA对模型微调后,可以使用这一方法获取微调后的LoRA adapter状态字典。

compute_loss方法(重构)

class MultiModalTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        return model(
            image=inputs["images"],
            input_ids=inputs["input_ids"],
            labels=inputs["labels"],
        ).loss

整体含义

为多模态场景自定义的损失计算重构方法,以适配多模态形式的输入,如image

逐行解读

class MultiModalTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):

自定义MultiModelTrainer类,继承自transfomers.Trainer,拥有其成员变量和方法。
model:可以同时处理图片和文本类型输入
inputs:包含图片输入,文本索引输入和有监督训练需要的标签数据。
return_outputs:指示是否返回模型输出,考虑到这个项目是科研级代码,所以这个参数没啥用(QWQ)。

        return model(
            image=inputs["images"],
            input_ids=inputs["input_ids"],
            labels=inputs["labels"],
        ).loss

将inputs字典中的对应键下的值传递给model,获取其返回值中的损失值,用于后续的模型优化。

save_model函数(重构)

    def save_model(self, output_dir=None, _internal_call=False):
        from transformers.trainer import TRAINING_ARGS_NAME
        
        # Ensure output_dir is not None
        if output_dir is None:
            output_dir = self.args.output_dir
        
        # Create the output directory if it doesn't exist
        os.makedirs(output_dir, exist_ok=True)
        
        # Save training arguments
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
        
        # Access the original model
        model = self.model.module if hasattr(self.model, 'module') else self.model
        
        # Save LLM parameters
        saved_params_LLM = get_peft_model_stat
当出现 `ModuleNotFoundError: No module named 'transformers.modeling_rope_utils'` 错误,表明 Python 解释器在当前环境中找不到 `transformers.modeling_rope_utils` 模块,可能是 `transformers` 库版本不兼容或者未正确安装导致。以下是一些解决办法: ### 1. 升级 `transformers` 库 要保证使用的 `transformers` 库是最新版本,因为 `modeling_rope_utils` 模块可能在较新的版本中才存在。可以使用以下命令升级: ```bash pip install --upgrade transformers ``` ### 2. 检查安装情况 确认 `transformers` 库已经正确安装。可以在 Python 环境中运行以下代码来检查: ```python import transformers print(transformers.__version__) ``` 如果输出了版本号,就说明库已经正确安装;若报错,则需要重新安装: ```bash pip uninstall transformers pip install transformers ``` ### 3. 检查虚拟环境 如果使用了虚拟环境,要确保当前环境中已经安装了 `transformers` 库。可以激活虚拟环境,然后再次执行安装命令。 ### 4. 检查导入路径 有时候,Python 解释器可能无法找到 `transformers` 库的安装路径。可以通过以下代码查看 Python 的模块搜索路径: ```python import sys print(sys.path) ``` 确保 `transformers` 库的安装路径在其中。 ### 示例代码 以下是一个简单的示例,展示如何正确导入 `transformers` 库: ```python import transformers try: from transformers.modeling_rope_utils import rope_config_validation print("Import successful!") except ModuleNotFoundError: print("Module not found. Please check your transformers installation.") ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值