【第三章:神经网络原理详解与Pytorch入门】02.深度学习框架PyTorch入门-(6)Pytorch进阶训练(自定义loss、模型微调、数据增强)

第三章: 神经网络原理详解与Pytorch入门

第二部分:深度学习框架PyTorch入门

第六节:Pytorch进阶训练

内容:自定义loss、模型微调、数据增强


一、自定义 Loss 函数

【深度学习】关键技术-损失函数(Loss Function)_slideloss[16]损失函数是由googleai在2022年提出的一种用于深度学习目标检测中的-优快云博客

虽然 PyTorch 提供了许多标准损失函数(如 nn.CrossEntropyLoss, nn.MSELoss 等),但在一些任务中,我们可能需要根据特定目标自定义损失函数。

自定义Loss函数步骤:
  1. 继承 nn.Module

  2. 实现 forward 方法

示例:带权重的MSE Loss
import torch
import torch.nn as nn

class WeightedMSELoss(nn.Module):
    def __init__(self, weight):
        super(WeightedMSELoss, self).__init__()
        self.weight = weight

    def forward(self, input, target):
        return torch.mean(self.weight * (input - target) ** 2)

# 使用示例
loss_fn = WeightedMSELoss(weight=torch.tensor([1.0, 2.0]))


二、模型微调(Fine-tuning)

模型微调(Fine-tuning)详解-优快云博客

微调是利用已经训练好的模型参数(通常是预训练模型),在新任务中进行调整的过程。

常用步骤:
  1. 加载预训练模型(如 resnet18(pretrained=True)

  2. 冻结部分层参数(避免训练)

  3. 替换输出层(适配新任务)

  4. 训练新网络结构

示例代码:
import torchvision.models as models
import torch.nn as nn

model = models.resnet18(pretrained=True)
for param in model.parameters():
    param.requires_grad = False  # 冻结全部参数

# 替换分类层
model.fc = nn.Linear(model.fc.in_features, 10)  # 新的10类任务

可选项:仅解冻最后几层微调,速度更快,效果更稳定。


三、数据增强(Data Augmentation)

【漫话机器学习系列】031.数据增强(Dateset augmentation)_数据增强 随机插入case-优快云博客

数据增强是在训练时对图像进行随机变化,以提高模型的泛化能力。

使用 torchvision.transforms 实现常见增强方法:
import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor()
])
常见增强方式:
方法描述
RandomCrop随机裁剪图像
RandomRotation随机旋转角度
ColorJitter亮度、对比度、饱和度扰动
GaussianBlur添加模糊噪声
Cutout/Erasing局部遮挡模拟鲁棒性

注意:验证集与测试集通常不使用数据增强,只需归一化处理。


总结对比表:

项目作用与场景示例
自定义Loss适配特定任务,如加权误差或不对称损失自定义MSE
模型微调利用已有模型加速新任务训练ResNet迁移
数据增强增强鲁棒性、减少过拟合旋转、翻转

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值