SGD / Ranger21训练se-resnet18 [ cifar-10 ]

本文介绍了一种集成多种最新深度学习组件的优化器Ranger21,并通过Imagenette和CIFAR-10数据集上的实验对比了其与SGD的表现。尽管Ranger21在准确率上有优势,但在训练后期可能会出现验证集损失增大的问题。

这两天看到了一个叫Ranger21(github / arxiv)的训练器,写的是将最新的深度学习组件集成到单个优化器中,以AdamW优化器作为其核心(或可选的MadGrad)、自适应梯度剪裁、梯度中心化、正负动量、稳定权值衰减、线性学习率warm-up、Lookahead、Softplus变换、梯度归一化等,有些技术我也没接触过,反正听着很厉害。

于是在Imagenette(github),Imagenette是Imagenet中10个易于分类的类的子集,训练集每类大概900多张,验证集每类大概400张左右,用Xception试了一下,如下图所示:
在这里插入图片描述
acc方面ranger21可以超过90%,而sgd只有81%(没仔细调参),似乎用起来比sgd更简单一点,不仅快而且泛化性还强(注:二者用一样的学习率,ranger21自带的学习率策略是warmup – stable – warmdown,sgd用的余弦退火),但是几次实验下来发现ranger21总是在训练末期,验证集上的损失会上升,百度了一下可能原因是这个,意思是模型过于极端,在个别预测错误的样本上损失太大,因此拉大了整体损失,但不怎么影响准确度。

之后还是在cifar10上进行了一下实验,模型采用的是pre-activation的resnet18(但其实记得论文说pre-act对浅层用处不大),并加上了squeeze-excitation模块,即se-preact-resnet18代码如下所示:

import torch
from torch import nn


class SEBlock(nn.Module):
    def __init__(self, in_planes, planes, stride=1):
        
        super(SEBlock, self).__init__()
        
        self.residual = nn.Sequential(
            nn.BatchNorm2d(in_planes),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(planes),
            nn.ReLU(inplace=True),
            nn.Conv2d(planes, planes, kernel_size=3
CIFAR-10 数据集上使用 ResNet-18 模型是一个常见的深度学习任务,ResNet-18 是一种轻量级的残差网络,适用于图像分类任务。以下是一个典型的实现方式,基于 PyTorch 框架。 ### 数据预处理 CIFAR-10 数据集的图像尺寸为 32x32,通道数为 3。通常使用 CIFAR-10 的均值和标准差进行归一化处理,以加速模型收敛: ```python import torch import torchvision import torchvision.transforms as transforms transform_train = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) transform_test = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2) testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2) ``` ### ResNet-18 模型定义 使用 PyTorch 提供的 `torchvision.models.resnet18` 并调整最后的全连接层以适配 CIFAR-1010 类输出: ```python import torch.nn as nn import torchvision.models as models class ResNet18CIFAR10(nn.Module): def __init__(self): super(ResNet18CIFAR10, self).__init__() self.model = models.resnet18(weights=None) self.model.fc = nn.Linear(self.model.fc.in_features, 10) def forward(self, x): return self.model(x) ``` ### 训练设置 使用交叉熵损失函数和 SGD 优化器进行训练: ```python import torch.optim as optim import torch.nn as nn device = 'cuda' if torch.cuda.is_available() else 'cpu' net = ResNet18CIFAR10().to(device) criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4) scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1) ``` ### 训练循环 ```python for epoch in range(200): # 假设训练200个epoch net.train() for inputs, targets in trainloader: inputs, targets = inputs.to(device), targets.to(device) optimizer.zero_grad() outputs = net(inputs) loss = criterion(outputs, targets) loss.backward() optimizer.step() scheduler.step() # 每个epoch结束后评估测试集准确率 net.eval() correct = 0 total = 0 with torch.no_grad(): for inputs, targets in testloader: inputs, targets = inputs.to(device), targets.to(device) outputs = net(inputs) _, predicted = torch.max(outputs.data, 1) total += targets.size(0) correct += (predicted == targets).sum().item() print(f'Epoch {epoch+1}, Accuracy: {100 * correct / total:.2f}%') ``` ### GitHub 项目推荐 以下是一些包含 CIFAR-10ResNet-18 实现的开源项目: 1. **[kuangliu/pytorch-cifar](https://github.com/kuangliu/pytorch-cifar)** 该项目提供了多种经典网络(包括 ResNet)在 CIFAR-10/100 上的实现,结构清晰,适合学习和复现。 2. **[huyvnphan/PyTorch_CIFAR10](https://github.com/huyvnphan/PyTorch_CIFAR10)** 提供了训练和评估脚本,并支持多种模型结构,包括 ResNet、DenseNet 等。 3. **[chenyaofo/pytorch-cifar-models](https://github.com/chenyaofo/pytorch-cifar-models)** 该项目提供预训练ResNet、VGG、MobileNet 等模型,适用于快速部署和推理。 这些项目均具备良好的文档和可读性,适合用于教学、研究或部署[^1]。
评论 2
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值