LoRA微调ViT模型代码实现

以下是 LoRA_ViT_timmLoraConfig + get_peft_model 两种 LoRA 微调 ViT 方式的详细对比、使用方法以及注意事项的整理:


1. 两种方式的对比

特性LoRA_ViT_timmLoraConfig + get_peft_model
实现方式自定义实现使用 PEFT 库标准化实现
LoRA 添加位置手动指定(如 queryvalue 投影层)通过 target_modules 自动指定
分类器处理需要手动修改和设置可训练状态通过 modules_to_save 自动处理
代码复杂度较高较低
依赖库timmpefttransformers
灵活性完全控制 LoRA 的添加位置和实现细节依赖于 PEFT 库的功能和配置
适用场景需要定制化实现的场景快速实验和标准化应用

2. 使用方法

LoRA_ViT_timm 的使用方法


完整的 LoRA_ViT_timm 实现分析

1. LoRA 适配器的定义

_LoRA_qkv_timm 类用于实现 LoRA 适配器。以下是它的定义:

class _LoRA_qkv_timm(nn.Module):
    def __init__(self, w_qkv, w_a_q, w_b_q, w_a_v, w_b_v):
        super().__init__()
        self.w_qkv = w_qkv  # 原始的 qkv 线性层
        self.w_a_q = w_a_q  # LoRA 适配器的 A 矩阵(query)
        self.w_b_q = w_b_q  # LoRA 适配器的 B 矩阵(query)
        self.w_a_v = w_a_v  # LoRA 适配器的 A 矩阵(value)
        self.w_b_v = w_b_v  # LoRA 适配器的 B 矩阵(value)

    def forward(self, x):
        # 原始 qkv 投影
        qkv = self.w_qkv(x)

        new_q = self.linear_b_q(self.linear_a_q(x))
        new_v = self.linear_b_v(self.linear_a_v(x))
       
        # 缩放因子 (self.alpha // self.r) 的作用是控制 LoRA 适配器输出 对原始模型输出的影响程度
        qkv[:, :, : self.dim] += (self.alpha // self.r) * new_q
        qkv[:, :, -self.dim :] += (self.alpha // self.r) * new_v
        
        # qkv[:, :, : self.dim] += new_q
        # qkv[:, :, -self.dim :] += new_v
        
        return qkv

2. 完整的 LoRA_ViT_timm 实现

以下是完整的 LoRA_ViT_timm 实现:

import torch
import torch.nn as nn
from timm.models.vision_transformer import VisionTransformer

class LoRA_ViT_timm(nn.Module):
    def __init__(self, vit_model: VisionTransformer, r: int, lora_layer=None):
        super(LoRA_ViT_timm, self).__init__()

        if r == 0:
            # 如果 r=0,直接冻结整个模型
            for param in vit_model.parameters():
                param.requires_grad = False
            self.lora_vit = vit_model
        else:
            # 如果 r > 0,插入 LoRA 适配器
            if lora_layer:
                self.lora_layer = lora_layer  # 指定需要插入 LoRA 的层
            else:
                self.lora_layer = list(range(len(vit_model.blocks)))  # 默认在所有层插入 LoRA

            # 存储 LoRA 适配器的参数
            self.w_As = []  # A 矩阵
            self.w_Bs = []  # B 矩阵

            # 冻结原始模型的参数
            for param in vit_model.parameters():
                param.requires_grad = False

            # 在指定层插入 LoRA 适配器
            for t_layer_i, blk in enumerate(vit_model.blocks):
                if t_layer_i not in self.lora_layer:
                    continue

                # 获取 qkv 线性层的输入维度
                w_qkv_linear = blk.attn.qkv
                self.dim = w_qkv_linear.in_features

                # 定义 LoRA 适配器的 A 和 B 矩阵
                w_a_linear_q = nn.Linear(self.dim, r, bias=False)
                w_b_linear_q = nn.Linear(r, self.dim, bias=False)
                w_a_linear_v = nn.Linear(self.dim, r, bias=False)
                w_b_linear_v = nn.Linear(r, self.dim, bias=False)

                # 存储 A 和 B 矩阵
                self.w_As.append(w_a_linear_q)
                self.w_Bs.append(w_b_linear_q)
                self.w_As.append(w_a_linear_v)
                self.w_Bs.append(w_b_linear_v)

                # 替换原始的 qkv 线性层为 LoRA 适配器
                blk.attn.qkv = _LoRA_qkv_timm(
                    w_qkv_linear,
                    w_a_linear_q,
                    w_b_linear_q,
                    w_a_linear_v,
                    w_b_linear_v
                )

            # 初始化 LoRA 参数
            self.reset_parameters()
            self.lora_vit = vit_model

    def reset_parameters(self):
        # 初始化 LoRA 参数
        for w_a in self.w_As:
            nn.init.kaiming_uniform_(w_a.weight, a=math.sqrt(5))
        for w_b in self.w_Bs:
            nn.init.zeros_(w_b.weight)

    def forward(self, x):
        return self.lora_vit(x)

3. 使用示例

以下是使用 LoRA_ViT_timm 的完整代码示例:

# 加载预训练的 ViT 模型
num_classes = 100
model = timm.create_model("vit_base_patch16_224", pretrained=True, num_classes=num_classes)

# 封装为 LoRA 模型
lora_r = 8
model = LoRA_ViT_timm(vit_model=model, r=lora_r)

# 验证 LoRA 参数是否被训练
print("验证 LoRA 参数是否被训练:")
for name, param in model.named_parameters():
    if "w_As" in name or "w_Bs" in name:
        print(f"{name}: requires_grad = {param.requires_grad}")

LoraConfig + get_peft_model 的使用方法

代码示例
from transformers import ViTForImageClassification
from peft import LoraConfig, get_peft_model

# 加载预训练模型
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")

# 修改分类器的输出维度
model.classifier = torch.nn.Linear(model.classifier.in_features, 100)

# 配置 LoRA
config = LoraConfig(
    r=8,  # LoRA 的秩
    lora_alpha=16,  # LoRA 的缩放因子
    target_modules=["query", "value"],  # 目标模块
    lora_dropout=0.1,  # Dropout 概率
    bias="none",  # 是否更新偏置
    modules_to_save=["classifier"],  # 指定分类器需要被微调
)

# 封装为 LoRA 模型
lora_model = get_peft_model(model, config)

# 验证分类器是否被微调
print("验证分类器参数是否被训练:")
for name, param in lora_model.named_parameters():
    if "classifier" in name:
        print(f"{name}: requires_grad = {param.requires_grad}")
注意事项
  1. 依赖库:
    • 需要安装 pefttransformers 库。
  2. 分类器处理:
    • 如果分类器的输出维度不匹配,需要手动修改分类器。
    • 通过 modules_to_save=["classifier"] 可以自动将分类器设置为可训练状态。
  3. LoRA 添加位置:
    • 通过 target_modules 指定 LoRA 的添加位置(如 ["query", "value"])。
  4. 以下是 LoraConfig 的所有参数及其含义:
参数名类型默认值含义
rint8LoRA 的秩(低秩分解的维度)。
lora_alphaint16LoRA 的缩放因子(用于控制 LoRA 输出的权重)。
target_modulesList[str]None需要插入 LoRA 的模块名称列表(如 ["query", "value"])。
lora_dropoutfloat0.1LoRA 适配器的 Dropout 概率。
biasstr"none"是否更新偏置参数("none""all""lora_only")。
task_typestrNone任务类型(如 "CAUSAL_LM""SEQ_CLS")。
fan_in_fan_outboolFalse是否启用 fan-in/fan-out 权重初始化。
modules_to_saveList[str]None需要额外更新的模块名称列表(如分类器modules_to_save=[“classifier”])。
init_lora_weightsboolTrue是否初始化 LoRA 权重。
inference_modeboolFalse是否在推理模式下冻结 LoRA 适配器。
  1. task_type=“CAUSAL_LM” 的含义
    • 任务类型:
      • CAUSAL_LM 表示任务是因果语言建模,即生成式语言模型(如 GPT 系列)。
      • 这种任务的特点是模型只能根据前面的单词预测下一个单词,不能看到未来的单词。
    • 其他任务类型:
      • SEQ_CLS: 序列分类任务(如文本分类)。
      • TOKEN_CLS: 标记分类任务(如命名实体识别)。
      • QUESTION_ANS: 问答任务。
      • SEQ_2_SEQ_LM: 序列到序列语言建模任务(如翻译、摘要)。

3. 选择建议

  • 选择 LoRA_ViT_timm:

    • 如果你需要高度定制化的 LoRA 实现,或者对 LoRA 的原理有深入了解。
    • 如果你希望完全控制 LoRA 的添加位置和实现细节。
  • 选择 LoraConfig + get_peft_model:

    • 如果你希望快速实现 LoRA 微调,或者对 LoRA 的实现细节不感兴趣。
    • 如果你需要标准化和易用的解决方案。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值