fashion mnist数据获取
根据Fashion Mnist论文给出的网址下载数据集:
https://github.com/zalandoresearch/fashion-mnist
网络结构
包括输入层,两个卷积层,全连接层和输出层,下面是详细信息
Net(
(conv1): Sequential(
(0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(1): ReLU()
(2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(conv2): Sequential(
(0): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(1): ReLU()
(2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(output): Linear(in_features=1568, out_features=10, bias=True)
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-rbSItlYW-1605606095709)(https://github.com/Ryanlzz/Ryanlzz.github.io/blob/main/image/fashion.png)]
代码
可视化数据和制作标签
import os
from skimage import io
import torchvision.datasets.mnist as mnist
root="fashion_mnist/"
train_set = (
mnist.read_image_file(os.path.join(root, 'train-images-idx3-ubyte')),
mnist.read_label_file(os.path.join(root, 'train-labels-idx1-ubyte'))
)
test_set = (
mnist.read_image_file(os.path.join(root, 't10k-images-idx3-ubyte')),
mnist.read_label_file(os.path.join(root, 't10k-labels-idx1-ubyte'))
)
print("training set :"