昨天看完了resnet的论文,今天试着来实现一下。
先放一个resnet18的模型图:
模块引入
# 模块引入
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
数据导入
# 数据导入
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data',train=True,
download=True,transform=transform)
trainloader = torch.utils.data.DataLoader(trainset,batch_size=4,
shuffle=True,num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data',train=False,
download=True,transform=transform)
testloader = torch.utils.data.DataLoader(testset,batch_size=4,
shuffle=False,num_workers=2)
classes = ('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')
class_nums = 10
定义block类
正式准备开始写模型后遇到的第一个问题就是如何表示short connection,一个最直观的想法是,手动编写每一层,然后注意保存输入x,在适当的时候把它进行处理,然后加到输出里。
这个方法实现起来很简单,但它有三个问题:
- 如果要对网络进行修改,会比较麻烦。
- 18层的resnet这样写是可行的,但是论文中提到了有搭建1000层的resnet,如果这样写的话,1000层写起来就太麻烦了。
- 这样写不能体