pytorch搭建简单的CNN网络进行训练
适合小白入门
1. 导包
import numpy as np
import torch, torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
import torchvision.transforms as transforms
- torchvison:计算机视觉工具包,包含:
torchvison.transforms(常用的图像预处理方法);
torchvision.datasets(常用数据集的dataset实现,MNIST,CIFAR-10,ImageNet等);
torchvison.model(常用的模型预训练,AlexNet,VGG,ResNet,GoogleNet等)。
2. 准备数据集
下载一个公开数据集CIFAR10,简单的普通三通道图片的数据。
ROOT = '.data' # 下载保存路径
train_data = datasets.CIFAR10(ROOT, train=True, download = True)
数据预处理(增强)
计算数据的平均偏差和标准差,以便将其归一化。
数据集是由彩色图像组成的,但是有三个颜色通道(红色、绿色和蓝色)。
独立计算每个颜色通道的均值和标准差。
means = train_data.data.mean(axis = (0,1,2))/ 255
stds = train_data.data.std(axis = (0,1,2))/ 255
数据增强
定义数据预处理(增强)的方式
train_transforms = transforms.Compose([
transforms.RandomRotation(5),
transforms.RandomHorizontalFlip(0.5),
transforms.ToTensor(), # 将一个PIL Image或numpy.ndarray图像转换为一个Tensor数据类型
transforms.Normalize(mean = means, std <

最低0.47元/天 解锁文章





