[CLIP-VIT-L + Qwen] 多模态大模型源码阅读 - trainer篇


参考repo:WatchTower-Liu/VLM-learning; url: VLLM-BASE

前情提要

有关多模态大模型架构中的语言模型部分(MQwen.py)的代码请看(多模态大模型源码阅读 - 1多模态大模型源码阅读 - 2多模态大模型源码阅读 - 3多模态大模型源码阅读 - 4),多模态大模型架构中的视觉模型(visual/CLIP-VIT.py)部分请看多模态大模型源码阅读 - 5
本节主要讲的是项目中的多模态Trainer部分,即项目文件trainer.py,该文件中的代码重构了部分transfomers.trainer的成员方法,以适配多模态场景下的模型训练,包括自定义的损失计算,参数保存,优化器配置,支持分布式训练(多卡场景)。

源码阅读

导包

import torch
from transformers import Trainer
from transformers.trainer import (
    is_sagemaker_mp_enabled,
    get_parameter_names,
    has_length,
    ALL_LAYERNORM_LAYERS,
    logger,
)
import os
from peft import get_peft_model_state_dict

逐行解读

import torch
from transformers import Trainer

torch不必赘述,深度学习的核心出装,构建和训练神经网络的必备库,调包调参侠(我)的福音。
Trainer类主要用于NLP和多模态任务,简化模型训练过程,在后续的代码中作为父类使用。

from transformers.trainer import (
    is_sagemaker_mp_enabled,
    get_parameter_names,
    has_length,
    ALL_LAYERNORM_LAYERS,
    logger,
)

is_sagemaker_mp_enabled检验是否在Amazon SageMaker的模型并行环境中运行。模型并行性允许将模型的不同组件分布到多个GPU设备上,用以加速大规模模型的训练。如果是单卡童鞋就不必在意这个设置~
get_parameter_names用以获取模型中的参数名,在设置优化器参数时,可以区分需要权重衰减的参数和不需要的参数。
has_length检测对象是否有长度信息,用于确定训练过程的迭代次数。在项目代码中没有用到。
ALL_LAYERNORM_LAYERS:包含所有LAYERNORM类型的层,用于在优化器配置中排除这些层的权重衰减。
logger:日志记录,输出训练过程中信息和调试信息。

import os
from peft import get_peft_model_state_dict

os:经常使用的库,主要用来创建文件、文件夹,开关文件。
peft(Parameter-Efficient Fine-Tuning),用于高效微调模型,在微调过程中会冻结预训练模型的大部分参数,仅保留少量的可训练参数,以在尽可能少的资源占用和时间下微调模型适配下游任务, 大名鼎鼎的LoRA、Prefix Tuning、Prompt Tuning 等都在这个库中。get_peft_model_state_dict用于获取微调后的adapter状态字典。例如使用LoRA对模型微调后,可以使用这一方法获取微调后的LoRA adapter状态字典。

compute_loss方法(重构)

class MultiModalTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        return model(
            image=inputs["images"],
            input_ids=inputs["input_ids"],
            labels=inputs["labels"],
        ).loss

整体含义

为多模态场景自定义的损失计算重构方法,以适配多模态形式的输入,如image

逐行解读

class MultiModalTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):

自定义MultiModelTrainer类,继承自transfomers.Trainer,拥有其成员变量和方法。
model:可以同时处理图片和文本类型输入
inputs:包含图片输入,文本索引输入和有监督训练需要的标签数据。
return_outputs:指示是否返回模型输出,考虑到这个项目是科研级代码,所以这个参数没啥用(QWQ)。

        return model(
            image=inputs["images"],
            input_ids=inputs["input_ids"],
            labels=inputs["labels"],
        ).loss

将inputs字典中的对应键下的值传递给model,获取其返回值中的损失值,用于后续的模型优化。

save_model函数(重构)

    def save_model(self, output_dir=None, _internal_call=False):
        from transformers.trainer import TRAINING_ARGS_NAME
        
        # Ensure output_dir is not None
        if output_dir is None:
            output_dir = self.args.output_dir
        
        # Create the output directory if it doesn't exist
        os.makedirs(output_dir, exist_ok=True)
        
        # Save training arguments
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
        
        # Access the original model
        model = self.model.module if hasattr(self.model, 'module') 
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值