# 运用CNN分析MNIST手写数字分类
import torch
import numpy as np
from torch.utils.data import DataLoader
from torchvision.datasets import mnist
from torch import nn
from torch.autograd import Variable
from torch import optim
from torchvision import transforms
# 定义CNN
class CNN(nn.Module):
def __init__(self):
super(CNN,self).__init__()
self.layer1 = nn.Sequential(
nn.Conv2d(1,16,kernel_size=3), # 16, 26 ,26
nn.BatchNorm2d(16),
nn.ReLU(inplace=True))
self.layer2 = nn.Sequential(
nn.Conv2d(16,32,kernel_size=3),# 32, 24, 24
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2,stride=2)) # 32, 12,12 (24-2) /2 +1
self.layer3 = nn.Sequential(
nn.Conv2d(32,64,kernel_size=3), # 64,10,10
nn.BatchNorm2d(64),
nn.ReLU(
深度学习之PyTorch —— CNN实现MNIST手写数字分类
最新推荐文章于 2025-06-13 21:14:36 发布