以下是 LoRA_ViT_timm
和 LoraConfig
+ get_peft_model
两种 LoRA 微调 ViT 方式的详细对比、使用方法以及注意事项的整理:
1. 两种方式的对比
特性 | LoRA_ViT_timm | LoraConfig + get_peft_model |
---|---|---|
实现方式 | 自定义实现 | 使用 PEFT 库标准化实现 |
LoRA 添加位置 | 手动指定(如 query 和 value 投影层) | 通过 target_modules 自动指定 |
分类器处理 | 需要手动修改和设置可训练状态 | 通过 modules_to_save 自动处理 |
代码复杂度 | 较高 | 较低 |
依赖库 | timm | peft 和 transformers |
灵活性 | 完全控制 LoRA 的添加位置和实现细节 | 依赖于 PEFT 库的功能和配置 |
适用场景 | 需要定制化实现的场景 | 快速实验和标准化应用 |
2. 使用方法
LoRA_ViT_timm
的使用方法
完整的 LoRA_ViT_timm
实现分析
1. LoRA 适配器的定义
_LoRA_qkv_timm
类用于实现 LoRA 适配器。以下是它的定义:
class _LoRA_qkv_timm(nn.Module):
def __init__(self, w_qkv, w_a_q, w_b_q, w_a_v, w_b_v):
super().__init__()
self.w_qkv = w_qkv # 原始的 qkv 线性层
self.w_a_q = w_a_q # LoRA 适配器的 A 矩阵(query)
self.w_b_q = w_b_q # LoRA 适配器的 B 矩阵(query)
self.w_a_v = w_a_v # LoRA 适配器的 A 矩阵(value)
self.w_b_v = w_b_v # LoRA 适配器的 B 矩阵(value)
def forward(self, x):
# 原始 qkv 投影
qkv = self.w_qkv(x)
new_q = self.linear_b_q(self.linear_a_q(x))
new_v = self.linear_b_v(self.linear_a_v(x))
# 缩放因子 (self.alpha // self.r) 的作用是控制 LoRA 适配器输出 对原始模型输出的影响程度
qkv[:, :, : self.dim] += (self.alpha // self.r) * new_q
qkv[:, :, -self.dim :] += (self.alpha // self.r) * new_v
# qkv[:, :, : self.dim] += new_q
# qkv[:, :, -self.dim :] += new_v
return qkv
2. 完整的 LoRA_ViT_timm
实现
以下是完整的 LoRA_ViT_timm
实现:
import torch
import torch.nn as nn
from timm.models.vision_transformer import VisionTransformer
class LoRA_ViT_timm(nn.Module):
def __init__(self, vit_model: VisionTransformer, r: int, lora_layer=None):
super(LoRA_ViT_timm, self).__init__()
if r == 0:
# 如果 r=0,直接冻结整个模型
for param in vit_model.parameters():
param.requires_grad = False
self.lora_vit = vit_model
else:
# 如果 r > 0,插入 LoRA 适配器
if lora_layer:
self.lora_layer = lora_layer # 指定需要插入 LoRA 的层
else:
self.lora_layer = list(range(len(vit_model.blocks))) # 默认在所有层插入 LoRA
# 存储 LoRA 适配器的参数
self.w_As = [] # A 矩阵
self.w_Bs = [] # B 矩阵
# 冻结原始模型的参数
for param in vit_model.parameters():
param.requires_grad = False
# 在指定层插入 LoRA 适配器
for t_layer_i, blk in enumerate(vit_model.blocks):
if t_layer_i not in self.lora_layer:
continue
# 获取 qkv 线性层的输入维度
w_qkv_linear = blk.attn.qkv
self.dim = w_qkv_linear.in_features
# 定义 LoRA 适配器的 A 和 B 矩阵
w_a_linear_q = nn.Linear(self.dim, r, bias=False)
w_b_linear_q = nn.Linear(r, self.dim, bias=False)
w_a_linear_v = nn.Linear(self.dim, r, bias=False)
w_b_linear_v = nn.Linear(r, self.dim, bias=False)
# 存储 A 和 B 矩阵
self.w_As.append(w_a_linear_q)
self.w_Bs.append(w_b_linear_q)
self.w_As.append(w_a_linear_v)
self.w_Bs.append(w_b_linear_v)
# 替换原始的 qkv 线性层为 LoRA 适配器
blk.attn.qkv = _LoRA_qkv_timm(
w_qkv_linear,
w_a_linear_q,
w_b_linear_q,
w_a_linear_v,
w_b_linear_v
)
# 初始化 LoRA 参数
self.reset_parameters()
self.lora_vit = vit_model
def reset_parameters(self):
# 初始化 LoRA 参数
for w_a in self.w_As:
nn.init.kaiming_uniform_(w_a.weight, a=math.sqrt(5))
for w_b in self.w_Bs:
nn.init.zeros_(w_b.weight)
def forward(self, x):
return self.lora_vit(x)
3. 使用示例
以下是使用 LoRA_ViT_timm
的完整代码示例:
# 加载预训练的 ViT 模型
num_classes = 100
model = timm.create_model("vit_base_patch16_224", pretrained=True, num_classes=num_classes)
# 封装为 LoRA 模型
lora_r = 8
model = LoRA_ViT_timm(vit_model=model, r=lora_r)
# 验证 LoRA 参数是否被训练
print("验证 LoRA 参数是否被训练:")
for name, param in model.named_parameters():
if "w_As" in name or "w_Bs" in name:
print(f"{name}: requires_grad = {param.requires_grad}")
LoraConfig
+ get_peft_model
的使用方法
代码示例
from transformers import ViTForImageClassification
from peft import LoraConfig, get_peft_model
# 加载预训练模型
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
# 修改分类器的输出维度
model.classifier = torch.nn.Linear(model.classifier.in_features, 100)
# 配置 LoRA
config = LoraConfig(
r=8, # LoRA 的秩
lora_alpha=16, # LoRA 的缩放因子
target_modules=["query", "value"], # 目标模块
lora_dropout=0.1, # Dropout 概率
bias="none", # 是否更新偏置
modules_to_save=["classifier"], # 指定分类器需要被微调
)
# 封装为 LoRA 模型
lora_model = get_peft_model(model, config)
# 验证分类器是否被微调
print("验证分类器参数是否被训练:")
for name, param in lora_model.named_parameters():
if "classifier" in name:
print(f"{name}: requires_grad = {param.requires_grad}")
注意事项
- 依赖库:
- 需要安装
peft
和transformers
库。
- 需要安装
- 分类器处理:
- 如果分类器的输出维度不匹配,需要手动修改分类器。
- 通过
modules_to_save=["classifier"]
可以自动将分类器设置为可训练状态。
- LoRA 添加位置:
- 通过
target_modules
指定 LoRA 的添加位置(如["query", "value"]
)。
- 通过
- 以下是 LoraConfig 的所有参数及其含义:
参数名 | 类型 | 默认值 | 含义 |
---|---|---|---|
r | int | 8 | LoRA 的秩(低秩分解的维度)。 |
lora_alpha | int | 16 | LoRA 的缩放因子(用于控制 LoRA 输出的权重)。 |
target_modules | List[str] | None | 需要插入 LoRA 的模块名称列表(如 ["query", "value"] )。 |
lora_dropout | float | 0.1 | LoRA 适配器的 Dropout 概率。 |
bias | str | "none" | 是否更新偏置参数("none" 、"all" 、"lora_only" )。 |
task_type | str | None | 任务类型(如 "CAUSAL_LM" 、"SEQ_CLS" )。 |
fan_in_fan_out | bool | False | 是否启用 fan-in/fan-out 权重初始化。 |
modules_to_save | List[str] | None | 需要额外更新的模块名称列表(如分类器modules_to_save=[“classifier”])。 |
init_lora_weights | bool | True | 是否初始化 LoRA 权重。 |
inference_mode | bool | False | 是否在推理模式下冻结 LoRA 适配器。 |
- task_type=“CAUSAL_LM” 的含义
- 任务类型:
CAUSAL_LM
表示任务是因果语言建模,即生成式语言模型(如 GPT 系列)。- 这种任务的特点是模型只能根据前面的单词预测下一个单词,不能看到未来的单词。
- 其他任务类型:
SEQ_CLS
: 序列分类任务(如文本分类)。TOKEN_CLS
: 标记分类任务(如命名实体识别)。QUESTION_ANS
: 问答任务。SEQ_2_SEQ_LM
: 序列到序列语言建模任务(如翻译、摘要)。
- 任务类型:
3. 选择建议
-
选择
LoRA_ViT_timm
:- 如果你需要高度定制化的 LoRA 实现,或者对 LoRA 的原理有深入了解。
- 如果你希望完全控制 LoRA 的添加位置和实现细节。
-
选择
LoraConfig
+get_peft_model
:- 如果你希望快速实现 LoRA 微调,或者对 LoRA 的实现细节不感兴趣。
- 如果你需要标准化和易用的解决方案。