花卉分类项目:白嫖艺术之迁移学习实战
项目概述:零元购的快乐
这是一个典型的"白嫖"预训练模型的实战项目!咱们穷学生哪来的顶级GPU从头训练深度网络?PyTorch早就看穿了我们的心思,贴心地提供了各种预训练模型任君采撷。本教程教你如何用ResNet-152这个"看起来很贵"的模型,在花卉数据集上耍出91%的准确率!
关键步骤解析:从白嫖到微调
1. 数据准备:伸手党的自我修养
虽然数据集上传不了(毕竟300多M呢),但Ka载速度杠杠的:
- 数据集放在
data/flower
目录下(别的位置记得改路径啊) - 数据增强什么的代码里都有注释,自己看咯
2. 模型选择:只选贵的,不选对的
在ResNet家族里,咱们一眼就相中了resnet152:
model_ft = models.resnet152(pretrained=use_pretrained) # 理直气壮地白嫖
为什么选它?数字大看着牛逼啊!152层不比18层厉害多了?(手动狗头)
3. 调教策略:分阶段白嫖
第一阶段:躺平式学习
- 把前面的层都冻住,美其名曰"保留预训练特征"
for param in model.parameters():
param.requires_grad = False # 冻结!这样显卡就不会冒烟了
- 只训练最后全连接层,学习率可以大点(反正就一层)
第二阶段:重新做人
- 解冻所有层:“这次我要拿回属于我的一切!”
feature_extract = False # 解冻!让显卡燃烧吧
- 学习率调小(步子太大容易扯着蛋)
optimizer_ft = optim.Adam(params_to_update, lr=1e-4) # 小碎步前进
4. 模型改造:偷梁换柱
人家ResNet本来是分1000类的,咱们花卉只有102类,得动个小手术:
model_ft.fc = nn.Sequential(
nn.Linear(num_ftrs, 102), # 强行改成我们的分类数
nn.LogSoftmax(dim=1)) # 输出层整容
5. 训练成果:真香!
- 第一阶段(20轮):32% → 73%(显卡表示毫无压力)
- 第二阶段(15轮):直接飙到91%(显卡开始冒烟)
- 总耗时约50分钟(取决于你的祖传显卡型号)
进阶白嫖指南
-
多试试其他模型:VGG、EfficientNet都嫖一遍,哪个好用哪个
-
学习率耍花样:试试余弦退火之类的骚操作
-
混合精度训练:让显卡少冒点烟,还能跑快点
-
模型可视化:看看网络到底关注花的哪个部位(是不是总盯着花心看?)
-
部署上线:用Flask做个网页应用,假装是自己从头训练的(划掉)
记住咱们读书人的事,能叫白嫖吗?这叫"迁移学习"!PyTorch官方给的模型,不用白不用,用了还想用。祝各位在调参的路上越走越远,早日炼出炼丹师的真火!