Linux Pytorch ResNet-18 cifar10 实践报告

本文对比了ResNet-v1和ResNet-v2在CIFAR-10数据集上的表现,发现ResNet-v2有略微优势。同时,通过应用MixUp、CutMix和TrivialAugment三种数据增强技术,测试准确率显著提升,其中CutMix表现最佳。实验表明数据增强能有效防止过拟合,增强模型泛化能力。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

硬件资源

cpu: Intel(R) Core(TM) i5-7500 CPU @ 3.40GHz
显卡: 1080Ti
内存: 16G

环境版本

#系统信息
Distributor ID:	Ubuntu
Description:	Ubuntu 16.04.5 LTS
Release:	16.04
Codename:	xenial
#主要依赖		
torch              1.10.0
torchvision        0.11.1
#CUDA信息
~$ nvcc -V
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2019 NVIDIA Corporation
Built on Wed_Oct_23_19:24:38_PDT_2019
Cuda compilation tools, release 10.2, V10.2.89
~$ cat /usr/local/cuda/version.txt
CUDA Version 10.2.89

实验方法

基于PyTorch使用ResNet-18模型训练cifar10数据集

  1. 对比 ResNet-v1 和 ResNet-v2 的测试集准确率
  2. 对比三种当前比较先进的数据增强方法(MixUp、CutMix、TrivialAugment)的测试集准确率

基本参数设置

#训练集 测试集比例
5:1 即训练集50000张,测试集10000张
# 超参数设置
EPOCH = 100  # 遍历数据集次数
BATCH_SIZE = 512  # 批处理尺寸(batch_size)
LR = 0.1  # 学习率

# 基本的随机增强
RandomCrop
RandomHorizontalFlip

# 损失函数
CrossEntropyLoss #交叉熵
# 优化方式
optim.SGD(net.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)
# 学习率迭代策略
#在指定的epoch值,[60, 90]处对学习率进行衰减,lr = lr * gamma
scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[60,90], gamma=0.1)

实验结果

实验测试集准确率
ResNet-v194.13%
ResNet-v294.23%
ResNet-v2 + RandomMixup95.10%
ResNet-v2 + RandomCutmix95.46%
ResNet-v2 + TrivialAugment95.27%

实验结果对比

最终实验结果表明:

  • ResNet-v2的网络模型优于ResNet-v1。
  • 三种当前比较先进的数据增强方法(MixUp、CutMix、TrivialAugment)都有不俗的作用,一定程度上提升了准确率。
  • 针对cifar10数据集,用ResNet-v2的ResNet-18训练时,CutMix的数据增强手段最优。

结果分析

1. ResNet-v1 VS ResNet-v2

ResNet-v1 VS ResNet-v2
由上图可以看出,ResNet-v2相比于ResNet-v1,训练集(蓝线)和测试集准确率(粉红线)都更早更快得达到一个比较好的效果,也就是训练更容易,最终的测试集准确率也超出一点点(94.23% vs 94.13%)。
在这里插入图片描述

ResNet-v2重新设计了一种残差网络基本单元(unit)就是将激活函数(先BN再ReLU)移到权值层之前,形成一种“预激活(pre-activation)”的方式,如上图(b),而不是ResNet-v1中常规的“后激活(post-activation)”方式,如上图(a),并且预激活的单元中的所有权值层的输入都是归一化的信号。这使得网络更易于训练并且泛化性能也得到提升。

2. ResNet-v2 VS ResNet-v2+TrivialAugment

请添加图片描述
从上图可以看到,

  • ResNet-v2训练过程中,训练集准确率(绿线)一直高于测试集准确率(红线),并且最后训练集准确率接近100%,而测试集准确率仅有94.23%。
  • 加上TrivialAugment的数据增强方法后,训练全程测试集准确率(粉红线)一直高于训练集准确率(蓝线),最后训练集准确率达到92%左右,而测试集准确率达到95.27%。
  • 以上结果说明加上TrivialAugment的数据增强方法后,网络过拟合程度大大减小。其原因不难看出是因为强大的数据增强扩充了数据集,增强了网络的泛化能力。

3. MixUp vs CutMix vs TrivialAugment

请添加图片描述
几种数据增强的区别:MixUp vs CutMix vs TrivialAugment

  • MixUp:将随机的两张样本按比例混合,分类的结果按比例分配
    在这里插入图片描述

  • CutMix:将一部分区域cut掉但不填充0像素而是随机填充训练集中的其他数据的区域像素值,分类结果按一定的比例分配
    在这里插入图片描述

  • TrivialAugment:每次随机选择一个图像增强操作,然后随机确定它的增强幅度,并对图像进行增强。由于没有任何超参数,所以不需要任何搜索
    在这里插入图片描述

  • TrivialAugment通过组合多种数据增强手段,随即增强强度,可能会有不稳定因素

  • mixup是将两张图按比例进行插值来混合样本,cutmix是采用cut部分区域再补丁的形式去混合图像,不会有图像混合后不自然的情形

  • cutmix通过要求模型从局部视图识别对象,对cut区域中添加其他样本的信息,能够进一步增强模型的定位能力

  • cutmix不会有图像混合后不自然的情形,能够提升模型分类的表现,最终测试准确率也更高一些

<think>好的,用户现在问的是:“fastai没有直接下载cifar10的接口么?”首先,我需要回忆一下fastai的数据集支持情况。根据之前的了解,fastai确实内置了一些常见数据集,比如MNIST、CIFAR10等,但用户可能需要确认具体的接口名称和用法。 接下来,我应该检查用户之前的对话历史。用户之前询问的是如何使用fastai结合TensorBoard训练CIFAR10模型,当时提供的代码示例中,用户需要自己组织数据目录结构,比如假设数据存储在`./cifar10/train`和`./cifar10/val`。这可能让用户误以为fastai没有直接下载CIFAR10的功能,导致现在这个问题。 现在需要明确回答用户的问题:是的,fastai有直接下载CIFAR10的接口吗?如果有,具体是哪个函数?如果没有,用户应该如何正确加载? 查阅fastai的文档,发现`fastai.data.external`模块中确实有`untar_data`函数,可以用来下载并解压数据集,包括CIFAR10。例如,使用`URLs.CIFAR10`可以获取CIFAR10的URL,然后调用`untar_data`下载。这可能是因为用户之前没有注意到这个功能,或者在示例中没有使用,导致用户认为需要自己准备数据。 接下来,需要解释如何正确使用fastai内置的CIFAR10接口,并提供代码示例。同时,可能需要指出用户之前示例中的方法(手动组织数据)和直接下载的区别,以及哪种方式更推荐。 另外,用户可能在使用时遇到路径问题,或者对数据加载流程不太熟悉,因此需要详细说明步骤,并确保代码示例的准确性。同时,提醒用户可能需要调整数据增强或模型结构以适应CIFAR10的小尺寸图像。 还需要考虑用户可能的后续问题,例如如果不想使用内置接口,如何自行下载和处理CIFAR10数据,但当前问题主要围绕是否有直接接口,因此重点应放在肯定回答并提供使用方法上。 总结回答的结构: 1. 明确回答:是的,fastai有直接下载CIFAR10的接口。 2. 提供具体的函数和代码示例,展示如何下载和加载数据。 3. 对比之前示例中的方法,解释为何用户可能有此疑问。 4. 可能补充一些注意事项,比如数据路径、预处理等。 需要确保回答准确,避免误导用户,同时语言简洁明了,帮助用户快速解决问题。</think>--- 是的,**fastai 实际上内置了直接下载 CIFAR10 的接口**,但需要明确调用对应的数据集加载方法。以下是两种实现方式,分别对应**直接下载**和**手动加载**的完整流程: --- ### **方法 1:使用 fastai 内置接口(推荐)** ```python from fastai.vision.all import * # ✅ 直接下载并加载 CIFAR10(自动解压到 ~/.fastai/data 目录) path = untar_data(URLs.CIFAR_10) # 关键函数:untar_data + URLs.CIFAR_10 # 定义 DataBlock(适配官方 CIFAR10 结构) dblock = DataBlock( blocks=(ImageBlock(cls=PILImageBW), CategoryBlock), # CIFAR10 是灰度图 get_items=get_image_files, splitter=GrandparentSplitter(valid_name='test'), # 官方 test 集作为验证集 get_y=parent_label, item_tfms=Resize(32), batch_tfms=aug_transforms() ) # 创建 DataLoaders dls = dblock.dataloaders(path, bs=128) ``` #### **关键说明**: 1. **数据集路径** - `untar_data(URLs.CIFAR_10)` 会下载数据集到 `~/.fastai/data/cifar10`(Linux/Mac)或 `C:\Users\用户名\.fastai\data\cifar10`(Windows) - 目录结构自动包含 `train`(训练集)和 `test`(测试集) 2. **图像类型适配** - **特别注意**:CIFAR10 本质是 **32x32 彩色图**,但 fastai 的 `URLs.CIFAR_10` 接口返回的是转换为 **灰度图** 的版本 - 若需使用彩色图,需手动下载(见方法2) --- ### **方法 2:手动加载 PyTorchCIFAR10** 若需要原始彩色数据或更灵活的控制,可以通过 `torchvision` 加载: ```python from torchvision.datasets import CIFAR10 from fastai.vision.all import * # ✅ 下载并转换格式 train_ds = CIFAR10(root='data', train=True, download=True) valid_ds = CIFAR10(root='data', train=False, download=True) # 将 PyTorch Dataset 转为 fastai DataLoaders dls = DataLoaders.from_dsets( train_ds, valid_ds, bs=128, after_item=[ToTensor(), Resize(32)], after_batch=[*aug_transforms(), Normalize.from_stats(*imagenet_stats)] ) ``` --- ### **为什么用户会认为没有接口?** 1. **文档可见性** fastai 的 CIFAR10 接口 (`URLs.CIFAR_10`) 在官方教程中较少使用,主流教学更倾向 MNIST 或 Imagenette。 2. **灰度图限制** fastai 内置的 CIFAR10 是灰度图版本,而许多用户期望直接使用彩色图(需通过 `torchvision` 加载)。 --- ### **完整训练代码示例(方法1 + ResNet)** ```python # 模型定义(适配灰度图输入) learn = vision_learner( dls, resnet18, pretrained=False, metrics=accuracy, n_in=1 # 关键!输入通道数设为1(灰度图) ) # 训练配置 learn.fit_one_cycle(10, lr_max=1e-3) ``` --- ### **常见问题** 1. **维度不匹配错误** - 错误信息:`RuntimeError: Given groups=1, weight of size [64, 3, ...], expected input[...] 1 channels` - 解决方案: - 使用 `n_in=1`(灰度图)或 `n_in=3`(彩色图)明确指定输入通道数 2. **数据增强不生效** - 检查 `batch_tfms` 是否包含 `aug_transforms()`,并确保输入尺寸正确 --- ### **最佳实践建议** - **优先使用 torchvision 版本** 通过方法2加载彩色图,更符合主流研究设定: ```python dblock = DataBlock( blocks=(ImageBlock, CategoryBlock), get_items=lambda path: [ (x, str(y)) for x,y in train_ds ], # 自定义加载逻辑 splitter=RandomSplitter(0.2), batch_tfms=[*aug_transforms(), Normalize.from_stats(*cifar_stats)] ) ``` 其中 `cifar_stats` 需预先计算(CIFAR10 的均值和标准差)。 --- 通过上述方法,可以灵活选择 **fastai 内置接口** 或 **torchvision 原生加载** 两种方案,适配不同场景需求。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值