本篇博客主要介绍PyTorch中使用CNN网络进行MNIST数据分类。
示例代码:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data as Data
import torchvision
import numpy as np
import matplotlib.pyplot as plt
# 超参数
EPOCH = 1
BATCH_SIZE = 50
LR = 0.001
DOWNLOAD_MNIST = False # 已下载,设置为False,未下载,则设置为True
# 下载MNIST数据
# 训练数据
train_data = torchvision.datasets.MNIST(
root='./mnist/', # 数据保存地址
train=True, # 训练数据,False即为测试数据
transform=torchvision.transforms.ToTensor(), # 将下载的源数据变成Tensor数据,(0,1)
download=DOWNLOAD_MNIST,
)
# 显示一张样本图片
# print(train_data.train_data.size())
# print(train_data.train_labels.size())
# plt.imshow(train_data.train_data[0].numpy(), cmap='gray')
# plt.title