如何用torch框架训练深度学习模型(详解)
0. 需要的包
import torch
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
1. 数据加载和导入
以MNIST数据集为例
# 1.1 需要设置数据归一化
train_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))])
test_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))])
# 1.2 用dataset.MNIST函数下载和加载训练集与测试集
train_dataset = datasets