VGG-16
1. 导入工具包
import torch.optim as optim
import torch
import torch.nn as nn
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torch.optim.lr_scheduler as lr_scheduler
import os
注:
torch: 这是PyTorch框架的基础库,提供了自动求导机制和丰富的张量运算支持,是构建和训练神经网络的基础。
torch.nn: PyTorch的神经网络库,包含多种构建神经网络所需的层结构(如卷积层、全连接层)和激活函数等
torch.utils.data: 提供了数据加载和处理的工具,是加载数据集并进行批处理的重要模块
torchvision.transforms: PyTorch的视觉库中的一个模块,提供了一系列图像处理的变换操作,用于数据增强和预处理
torchvision.datasets: 提供了常见的数据集和相关的数据加载方法,如MNIST、CIFAR-10、ImageNet等
DataLoader: torch.utils.data中的一个类,用于构建可迭代的数据加载器,可以方便地在训练循环中按批次加载数据
torch.optim.lr_scheduler: 提供了学习率调整策略,如学习率衰减,有助于训练过程中改善模型性能和减少过拟合
os: Python的标准库之一,提供了与操作系统交互的功能,如文件创建、路径操作等
2. 判断环境是CPU运行还是GPU
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
注:通过这行代码,程序可以根据当前环境自动选择最佳的计算设备,以提高计算效率和性能。在GPU上运行可以显著加速深度学习模型的训练和推理过程。
3. 数据预处理
# 定义数据预处理
transform = {
'train': transforms.Compose([
transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),
transforms.RandomRotation(degrees=15),
transforms.RandomHorizontalFlip(),
transforms.CenterCrop(size=224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])
]),
'val': transforms.Compose([
transforms.Resize(size=256),
transforms.CenterCrop(size=224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])
])
}
注:
训练数据预处理 (transform[‘train’])
transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)): 随机裁剪图像,裁剪后的图像大小为256x256像素。裁剪区域的大小是原始图像尺寸的0.8到1.0倍之间随机选择
transforms.RandomRotation(degrees=15): 随机旋转图像,旋转角度在-15度到15度之间随机选择
transforms.RandomHorizontalFlip(): 随机水平翻转图像,即有一半的概率会翻转,一半的概率不翻
transforms.CenterCrop(size=224): 从图像中心裁剪出224x224像素的区域
transforms.ToTensor(): 将图像转换为PyTorch张量,并且将像素值从[0, 255]范围缩放到[0, 1]范围
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]): 对图像进行标准化处理。这里使用的是ImageNet数据集的均值和标准差,这些值分别用于每个颜色通道的减均值和除以标准差操作
验证数据预处理 (transform[‘val’])
transforms.Resize(size=256): 将图像大小调整为256x256像素,这里没有使用随机裁剪,而是直接调整大小
transforms.CenterCrop(size=224): 从图像中心裁剪出224x224像素的区域
transforms.ToTensor(): 将图像转换为PyTorch张量,并且将像素值从[0, 255]范围缩放到[0, 1]范围
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]): 对图像进行标准化处理,使用的是ImageNet数据集的均值和标准差
4. 读取数据
dataset = './dataset'
train_directory = os.path.join(dataset, 'train')
valid_directory = os.path.join(dataset, 'val')
注:
valid_directory = os.path.join(dataset, ‘val’):同样地,这行代码将dataset目录和字符串’val’连接起来,创建验证数据集的路径。valid_directory将包含数据集根目录下的验证数据子目录的完整路径。这两行代码通常用于准备