# 本代码参考廖星宇《深度学习入门之PyTorch》中的示例代码，手动拼接复现而来，仅供个人使用，侵删。
#ResNet实现CIFAR10分类
from datetime import datetime
import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Variable
from torch import optim
from torch.utils.data import DataLoader
from torchvision import datasets,transforms
#定义ResNet基本模块-残差模块
def conv3x3(in_channel, out_channel, stride=1):
    """Return a 3x3 ``nn.Conv2d`` layer with padding 1 and no bias term.

    Args:
        in_channel: Number of channels in the input feature map.
        out_channel: Number of channels produced by the convolution.
        stride: Stride of the convolution; defaults to 1.

    Returns:
        An ``nn.Conv2d`` configured with ``kernel_size=3``, ``padding=1``
        and ``bias=False``.
    """
    # With a 3x3 kernel and padding of 1, a stride of 1 keeps the spatial
    # size unchanged; a larger stride downsamples the feature map.
    return nn.Conv2d(
        in_channel,
        out_channel,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
#Residual Block
class residual_block(nn.Module):
def __init__(self, in_channel, out_channel, stride=1, downsample=None):
super(residual_block, self).__init__()
self.conv1 = conv3x3(in_channel, out_channel, stride)
self.bn1 = nn.BatchNorm2d(out_channel)
self.conv2 = conv3x3(out_channel, out_channel)
self.bn2 = nn.BatchNorm2d(out_channel)
self.downsample = downsa