External-Attention-pytorch迁移学习:预训练模型微调指南
你是否还在为深度学习模型训练耗时长、数据不足而烦恼?本文将带你使用External-Attention-pytorch库,通过迁移学习(Transfer Learning)技术,利用预训练模型快速解决自己的任务。读完本文后,你将掌握如何加载预训练模型、修改分类层、冻结与解冻参数以及完整的微调流程,即使只有少量数据也能训练出高性能模型。
为什么选择迁移学习?
深度学习模型通常需要大量数据和计算资源才能训练到理想效果。而迁移学习可以将预训练模型(在大规模数据集上训练好的模型)的知识迁移到新任务中,实现:
- 减少数据需求:只需少量标注数据即可获得良好效果
- 降低计算成本:无需从头训练,节省时间和资源
- 提高模型性能:利用预训练特征提取能力,加速收敛并提升精度
External-Attention-pytorch库提供了丰富的预训练模型和注意力机制模块,如ResNet、MobileViT等骨干网络,以及SE、CBAM等注意力机制,为迁移学习提供了强大支持。
准备工作:安装与环境配置
安装库
你可以通过两种方式安装External-Attention-pytorch库:
pip安装(推荐)
pip install fightingcv-attention
源码安装
git clone https://gitcode.com/gh_mirrors/ex/External-Attention-pytorch
cd External-Attention-pytorch
导入必要模块
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from model.backbone.resnet import ResNet50 # ResNet模型定义
from model.attention.SEAttention import SEAttention # 注意力机制模块
迁移学习实战:以ResNet50为例
1. 加载预训练模型
External-Attention-pytorch提供了多种预训练骨干网络,以ResNet50为例:
# 加载ResNet50模型,默认分类数为1000(ImageNet类别)
model = ResNet50(num_classes=1000)
# 加载预训练权重(假设已下载到本地)
pretrained_weights = torch.load("resnet50_pretrained.pth")
model.load_state_dict(pretrained_weights, strict=False) # strict=False允许部分层不匹配
ResNet50的网络结构包含:
- 卷积stem层:用于初步特征提取
- 四个残差块(layer1-layer4):逐步提取高层特征
- 分类头:全连接层输出1000类概率
2. 修改分类层
预训练模型默认用于1000类ImageNet分类任务,我们需要根据自己的任务修改分类层:
# 假设我们的任务是10类分类
num_classes = 10
# 获取原分类层输入特征数
in_features = model.classifier.in_features
# 替换分类层
model.classifier = nn.Linear(in_features, num_classes)
# 初始化新分类层权重
nn.init.normal_(model.classifier.weight, mean=0.0, std=0.01)
nn.init.zeros_(model.classifier.bias)
3. 添加注意力机制(可选)
为进一步提升模型性能,可以在预训练模型中插入注意力机制模块,如SEAttention:
# 在layer4后添加SE注意力模块
class ResNetWithSE(nn.Module):
def __init__(self, base_model, num_classes):
super().__init__()
self.base_model = base_model
self.se_attention = SEAttention(channel=2048) # ResNet50的layer4输出通道为2048
self.classifier = nn.Linear(2048, num_classes)
def forward(self, x):
x = self.base_model.conv1(x)
x = self.base_model.bn1(x)
x = self.base_model.relu(x)
x = self.base_model.maxpool(x)
x = self.base_model.layer1(x)
x = self.base_model.layer2(x)
x = self.base_model.layer3(x)
x = self.base_model.layer4(x)
x = self.se_attention(x) # 应用SE注意力
x = self.base_model.avgpool(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x
# 创建带SE注意力的模型
model = ResNetWithSE(base_model=model, num_classes=num_classes)
SE注意力机制通过自适应地调整通道权重,增强有用特征,抑制无用特征,结构如下:
参数冻结与解冻策略
参数冻结
在微调初期,为保护预训练特征提取能力,通常冻结大部分参数,只训练新添加的层:
# 冻结base_model的所有参数
for param in model.base_model.parameters():
param.requires_grad = False
# 只训练SE注意力层和分类层
for param in model.se_attention.parameters():
param.requires_grad = True
for param in model.classifier.parameters():
param.requires_grad = True
参数解冻
当模型在新任务上收敛后,可以解冻部分高层参数,进行精细调整:
# 解冻layer4的参数(ResNet的最后一个残差块)
for param in model.base_model.layer4.parameters():
param.requires_grad = True
这种策略可以让模型在保持预训练特征的同时,适应新任务的数据分布。
完整微调流程
1. 数据准备
假设我们的数据集有train和val两个子集,使用PyTorch的Dataset和DataLoader加载:
# 数据预处理
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ImageNet标准化参数
])
# 假设我们有自定义的Dataset类
train_dataset = YourDataset("train_data_path", transform=transform)
val_dataset = YourDataset("val_data_path", transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
2. 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
# 只优化requires_grad=True的参数
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)
3. 训练与验证
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
num_epochs = 20
best_val_acc = 0.0
for epoch in range(num_epochs):
# 训练阶段
model.train()
train_loss = 0.0
train_correct = 0
train_total = 0
for inputs, labels in train_loader:
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
train_loss += loss.item() * inputs.size(0)
_, predicted = torch.max(outputs.data, 1)
train_total += labels.size(0)
train_correct += (predicted == labels).sum().item()
train_acc = train_correct / train_total
train_loss = train_loss / train_total
# 验证阶段
model.eval()
val_loss = 0.0
val_correct = 0
val_total = 0
with torch.no_grad():
for inputs, labels in val_loader:
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
loss = criterion(outputs, labels)
val_loss += loss.item() * inputs.size(0)
_, predicted = torch.max(outputs.data, 1)
val_total += labels.size(0)
val_correct += (predicted == labels).sum().item()
val_acc = val_correct / val_total
val_loss = val_loss / val_total
scheduler.step(val_loss)
print(f"Epoch {epoch+1}/{num_epochs}")
print(f"Train Loss: {train_loss:.4f} Acc: {train_acc:.4f}")
print(f"Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}\n")
# 保存最佳模型
if val_acc > best_val_acc:
best_val_acc = val_acc
torch.save(model.state_dict(), "best_model.pth")
print(f"Saved best model with val acc: {best_val_acc:.4f}")
常见问题与解决方案
过拟合问题
当训练数据较少时,模型容易过拟合,可采用以下方法:
- 增加数据增强:如旋转、缩放、裁剪等
- 使用早停(Early Stopping):监控验证集损失,不再提升时停止训练
- 添加正则化:如Dropout、L2正则化
# 添加Dropout
model = nn.Sequential(
model,
nn.Dropout(0.5)
)
# L2正则化
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
lr=1e-4, weight_decay=1e-5)
学习率调整
学习率是影响微调效果的关键超参数:
- 初始学习率不宜过大,通常为1e-4 ~ 1e-5
- 使用学习率调度器动态调整,如StepLR、CosineAnnealingLR
# CosineAnnealing学习率调度器
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-6)
总结与展望
本文详细介绍了使用External-Attention-pytorch进行迁移学习的完整流程,包括模型加载、分类层修改、注意力机制集成、参数冻结与解冻策略以及微调训练流程。通过迁移学习,我们可以充分利用预训练模型的特征提取能力,在少量数据上快速训练出高性能模型。
未来,你还可以尝试:
- 探索不同的注意力机制,如CBAM、ECA等
- 尝试不同的骨干网络,如MobileViT、ConvMixer
- 使用知识蒸馏技术进一步压缩模型,部署到边缘设备
希望本文对你的深度学习项目有所帮助!如果你有任何问题或建议,欢迎在评论区留言交流。别忘了点赞、收藏本文,关注作者获取更多深度学习实用教程!
附录:常用预训练模型加载代码
| 模型名称 | 加载代码 |
|---|---|
| ResNet50 | from model.backbone.resnet import ResNet50; model = ResNet50() |
| MobileViT | from model.backbone.MobileViT import MobileViT; model = MobileViT() |
| ConvMixer | from model.backbone.ConvMixer import ConvMixer; model = ConvMixer() |
| PVT | from model.backbone.PVT import PVT; model = PVT() |
模型详细实现可参考:model/backbone/ 目录下的对应文件。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考





