FashionMNIST_CNN with pytorch (accuracy: 97.58%/90.1%)
接上篇
手工简单复现经典CNN网络测试(LeNet5、AlexNet、VGG16、GoogLeNet、ResNe)
想进行各种CCN经典网络的效果对比,因为个人电脑配置原因失败,不得以科学上网,注册了kaggle进行训练,整体训练分两部分:
1 创建CNN网络训练(网络简单,大概后台运行了5h);
2 使用后台运行产生的模型,进行测试集预测
结果如下:
#1 ----------- first/第一步:kaggle后台训练模型 -------------------------
import torch
import torch.nn as nn
from torch.utils.data import Dataset,DataLoader,TensorDataset
from torch.nn import CrossEntropyLoss
from torch.optim import SGD,Adam
# from torchsummary import summary
from torchvision.datasets import ImageFolder,FashionMNIST
import torchvision.transforms as transforms
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import joblib
# 读入数据
train_data = pd.read_csv('/kaggle/input/fashionmnist/fashion-mnist_train.csv')
train_y = torch.tensor(train_data['label'])
train_x = torch.tensor(train_data.iloc[:,1:].values.reshape(60000,1,28,28))/255
# 创建批量数据加载器
dataset = TensorDataset(train_x,train_y)
batch_size= 5000 # 批次大小
dataloader = DataLoader(dataset=dataset,batch_size=batch_size,shuffle=True)
# 创建模型
class Net(nn.Module):
def __init__(self)

本文通过PyTorch实现了一个简单的CNN网络,用于FashionMNIST数据集的分类任务,并详细记录了从模型训练到测试的过程及结果。最终模型在本地测试集上的准确率达到97.58%,而在Kaggle平台上提供的数据集上则为90.1%。
最低0.47元/天 解锁文章
981

被折叠的 条评论
为什么被折叠?



