训练过程中定义了损失函数,激活函数,优化器,并进行了测试,计算了测试准确率,并用tensorboard进行可视化,数据集采用torchvision的CIFAR10,并运用GPU训练
🚀代码如下:
train.py
from torch import nn
import torch
import torchvision
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, ReLU
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from module import *
# 准备数据集
train_data = torchvision.datasets.CIFAR10(root='./dataset', train=True, transform=torchvision.transforms.ToTensor(),
download=True)
val_data = torchvision.datasets.CIFAR10(root='./dataset', train=False, transform=torchvision.transforms.ToTensor(),
download=True)
train_data_size = len(train_data)
val_data_size = len(val_data)
# 加载数据集
train_loader = DataLoader(train