使用PyTorch实现resnet-18分类CIFAR-10

本文介绍了使用PyTorch实现ResNet-18模型对CIFAR-10数据集进行分类的过程,包括模块引入、数据导入、定义Block类和ResNet类,以及训练过程中的精度计算和模型拟合。首次训练结果显示较低的准确率,通过增加训练轮数和优化策略,期望提升模型性能。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

昨天看完了resnet的论文,今天试着来实现一下。

先放一个resnet18的模型图:
在这里插入图片描述

模块引入

# 模块引入

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

数据导入

# 数据导入

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data',train=True,
                                        download=True,transform=transform)
trainloader = torch.utils.data.DataLoader(trainset,batch_size=4,
                                          shuffle=True,num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data',train=False,
                                       download=True,transform=transform)
testloader = torch.utils.data.DataLoader(testset,batch_size=4,
                                        shuffle=False,num_workers=2)

classes = ('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')
class_nums = 10

定义block类

正式准备开始写模型后遇到的第一个问题就是如何表示short connection,一个最直观的想法是,手动编写每一层,然后注意保存输入x,在适当的时候把它进行处理,然后加到输出里。

这个方法实现起来很简单,但它有三个问题:

  1. 如果要对网络进行修改,会比较麻烦。
  2. 18层的resnet这样写是可行的,但是论文中提到了有搭建1000层的resnet,如果这样写的话,1000层写起来就太麻烦了。
  3. 这样写不能体
评论 11
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值