七、整体流程梳理
1. 引入使用的包
用到什么包,临时引入就可以,不用太担心。
import time
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18, ResNet18_Weights
import wandb
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import *
import matplotlib.pyplot as plt
2. 数据
# 下面和以前就一样了
train_dataset = CIFAR10(
root=datapath,
train=True,
download=True,
transform=transform,
)
# 构建训练数据集
train_loader = DataLoader(
#
dataset=train_dataset,
batch_size=batzh_size,
shuffle=True,
num_workers=2,
)
3. 模型
# 再次获取resnet18原始神经网络并对齐fc层进行调整
model = resnet18(weights=None)
in_features = model.fc.in_features
# 重写FC:我们这里做的是10分类
model.fc = nn.Linear(in_features=in_features, out_features=10)
# 需要对权重信息进行处理:要加载我们训练之后最新的权重文件
weights_default = torch.load(weightpath)
weights_default.pop("fc.weight")
weights_default.pop("fc.bias")
# 把权重参数进行同步
new_state_dict = model.state_dict()
weights_default_process = {
k: v for k, v in weights_default.items() if k in new_state_dict
}
new_state_dict.update(weights_default_process)
model.load_state_dict(new_state_dict)
model.to(device)
4. 训练
4.1 数据增强
为了防止过拟合,增加模型的泛化能力,我们会数据增强
transform = transforms.Compose(
[
transforms.RandomRotation(45), # 随机旋转,-45到45度之间随机选
transforms.RandomCrop(32, padding=4), # 随机裁剪
transforms.RandomHorizontalFlip(p=0.5), # 随机水平翻转 选择一个概率概率
transforms.RandomVerticalFlip(p=0.5), # 随机垂直翻转
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616)),
]
)
transformtest = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2471,