第三章: 神经网络原理详解与Pytorch入门
第二部分:深度学习框架PyTorch入门
第六节:Pytorch进阶训练
内容:自定义loss、模型微调、数据增强
一、自定义 Loss 函数
【深度学习】关键技术-损失函数(Loss Function)_slideloss[16]损失函数是由googleai在2022年提出的一种用于深度学习目标检测中的-优快云博客
虽然 PyTorch 提供了许多标准损失函数(如 nn.CrossEntropyLoss
, nn.MSELoss
等),但在一些任务中,我们可能需要根据特定目标自定义损失函数。
自定义Loss函数步骤:
-
继承
nn.Module
-
实现
forward
方法
示例:带权重的MSE Loss
import torch
import torch.nn as nn
class WeightedMSELoss(nn.Module):
def __init__(self, weight):
super(WeightedMSELoss, self).__init__()
self.weight = weight
def forward(self, input, target):
return torch.mean(self.weight * (input - target) ** 2)
# 使用示例
loss_fn = WeightedMSELoss(weight=torch.tensor([1.0, 2.0]))
二、模型微调(Fine-tuning)
微调是利用已经训练好的模型参数(通常是预训练模型),在新任务中进行调整的过程。
常用步骤:
-
加载预训练模型(如
resnet18(pretrained=True)
) -
冻结部分层参数(避免训练)
-
替换输出层(适配新任务)
-
训练新网络结构
示例代码:
import torchvision.models as models
import torch.nn as nn
model = models.resnet18(pretrained=True)
for param in model.parameters():
param.requires_grad = False # 冻结全部参数
# 替换分类层
model.fc = nn.Linear(model.fc.in_features, 10) # 新的10类任务
可选项:仅解冻最后几层微调,速度更快,效果更稳定。
三、数据增强(Data Augmentation)
【漫话机器学习系列】031.数据增强(Dateset augmentation)_数据增强 随机插入case-优快云博客
数据增强是在训练时对图像进行随机变化,以提高模型的泛化能力。
使用 torchvision.transforms
实现常见增强方法:
import torchvision.transforms as transforms
transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(15),
transforms.ColorJitter(brightness=0.2, contrast=0.2),
transforms.ToTensor()
])
常见增强方式:
方法 | 描述 |
---|---|
RandomCrop | 随机裁剪图像 |
RandomRotation | 随机旋转角度 |
ColorJitter | 亮度、对比度、饱和度扰动 |
GaussianBlur | 添加模糊噪声 |
Cutout/Erasing | 局部遮挡模拟鲁棒性 |
注意:验证集与测试集通常不使用数据增强,只需归一化处理。
总结对比表:
项目 | 作用与场景 | 示例 |
---|---|---|
自定义Loss | 适配特定任务,如加权误差或不对称损失 | 自定义MSE |
模型微调 | 利用已有模型加速新任务训练 | ResNet迁移 |
数据增强 | 增强鲁棒性、减少过拟合 | 旋转、翻转 |