在上一篇文章使用预训练模型VGG16识别动物中,我们使用大规模数据集训练了高精度模型,并下载了预训练模型进行推理。但当找不到完全符合需求的预训练模型,且没有足够的大规模数据集从头训练模型时,可以使用 迁移学习(Transfer Learning)。
文章目录
1 项目准备
在进行我们的项目之前,我们先了解一下迁移学习的概念,然后再介绍一下我们将基于迁移学习完成的项目。
1.1 迁移学习
迁移学习的核心思想是利用预训练模型在某些任务上的学习成果,将其重新训练以适应具有一定重叠的新任务。一个类比是:一个擅长绘画的艺术家学会使用木炭作画时,他的绘画经验会非常有用。例如,使用一个擅长识别汽车种类的预训练模型,可以进一步训练其识别摩托车种类。
迁移学习特别适用于数据不足的场景。在这种情况下,直接从头训练的模型容易过拟合,而迁移学习可以在小数据集上训练出准确且鲁棒性强的模型。
迁移学习可以缓解过拟合,原因如下:
- 利用预训练特征:预训练模型(如在 ImageNet 上训练的 VGG16)已经学习到通用的低层次特征(如边缘、纹理),这些特征可以很好地适应许多视觉任务。通过迁移学习,模型无需从头学习这些特征,可以直接在小数据集上调整最后几层以适应特定任务。
- 减少参数需要训练:冻结基础层的参数只更新新层的参数,大幅降低了训练所需的数据量,减少了过拟合的风险。
- 提高泛化能力:预训练模型通过在大规模数据集上训练过,具备了良好的泛化能力,能够在小数据集上表现出更好的性能。
1.2 项目介绍
上一篇文章中,我们用预训练的 ImageNet 模型来识别所有的狗。然而,这次我们希望创建一个只允许特定狗进入的宠物门。以一只名为 Bo 的狗为例,我们有 Bo 的 30 张照片,但预训练模型并未专门训练过这只狗。
如果直接用这 30 张照片从头训练模型,会导致过拟合。通过迁移学习,我们可以利用预训练模型中识别狗的能力,在小数据集上实现对 Bo 的识别。
本篇文章的目标如下:
- 准备一个预训练模型以进行迁移学习。
- 使用自己的小型数据集对预训练模型进行迁移学习。
- 进一步微调模型以提升性能。
2 代码学习
和之前一样,在学习代码之前,我们先加载需要用的库并设置一下GPU:
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.v2 as transforms
import torchvision.io as tv_io
import glob
import json
from PIL import Image
import utils
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.is_available()
import torch._dynamo
torch._dynamo.config.suppress_errors = True
2.1 下载模型和准备工作
ImageNet 提供的 torchvision.models
是计算机视觉迁移学习的理想选择,因为这些模型已在多种图像类型上训练。我们将使用 VGG16 模型,并加载其默认权重。
2.1.1 下载预训练模型
from torchvision.models import vgg16
from torchvision.models import VGG16_Weights
weights = VGG16_Weights.DEFAULT
vgg_model = vgg16(weights=weights)
看一下VGG模型的组成:
vgg_model.to(device)
输出:
VGG(
(features): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU(inplace=True)
(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(6): ReLU(inplace=True)
(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(8): ReLU(inplace=True)
(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): ReLU(inplace=True)
(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(13): ReLU(inplace=True)
(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(15): ReLU(inplace=True)
(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(18): ReLU(inplace=True)
(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(20): ReLU(inplace=True)
(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(22): ReLU(inplace=True)
(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(25): ReLU(inplace=True)
(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(27): ReLU(inplace=True)
(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(29): ReLU(inplace=True)
(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
(classifier): Sequential(
(0): Linear(in_features=25088, out_features=4096, bias=True)
(1): ReLU(inplace=True)
(2): Dropout(p=0.5, inplace=False)
(3): Linear(in_features=4096, out_features=4096, bias=True)
(4): ReLU(inplace=True)
(5): Dropout(p=0.5, inplace=False)
(6): Linear(in_features=4096, out_features=1000, bias=True)
)
)
ImageNet 模型的最后一层是一个包含 1000 个单元的全连接层,表示 1000 个类别。为了识别 Bo,我们需要将最后一层替换为一个新层,用于分类是否为 Bo。
2.1.2 冻结基础模型
在添加新层之前,我们需要冻结预训练模型的基础层。冻结层意味着在训练过程中,这些层的参数不会更新,只会训练新添加的层。
- 前几层学习到的特征(如边缘和纹理)通常是通用的,适用于大多数视觉任务:低层特征如边缘和角点在所有视觉任务中都是基础,几乎不依赖于具体的数据集。
- 最早的几层只关注像素之间的局部关系,与高层语义无关。
冻结层的原因:
- 保留模型在 ImageNet 数据集上的学习成果。
- 防止破坏这些预训练的特性。
通过设置 requires_grad_
为 False
,即可冻结参数:
vgg_model.requires_grad_(False)
如果需要解冻层(例如微调阶段),可以将 requires_grad_
设置为 True
。
2.1.3 添加新层
现在我们可以在预训练模型的基础上添加新的可训练层。新层将基于预训练层提取的特征,在新数据集上生成分类结果。
以下代码将 VGG16 的输出(1000 个单元)连接到一个神经元,用于二分类任务(Bo 或非 Bo):
N_CLASSES = 1
my_model = nn.Sequential(
vgg_model,
nn.Linear(1000, N_CLASSES)
)
my_model.to(device)
输出:
Sequential(
(0): VGG(
(features): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU(inplace=True)
(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(6): ReLU(inplace=True)
(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(8): ReLU(inplace=True)
(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): ReLU(inplace=True)
(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(13): ReLU(inplace=True)
(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(15): ReLU(inplace=True)
(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(18): ReLU(inplace=True)
(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(20): ReLU(inplace=True)
(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(22): ReLU(inplace=True)
(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(25): ReLU(inplace=True)
(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(27): ReLU(inplace=True)
(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(29): ReLU(inplace=True)
(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
(classifier): Sequential(
(0): Linear(in_features=25088, out_features=4096, bias=True)
(1): ReLU(inplace=True)
(2): Dropout(p=0.5, inplace=False)
(3): Linear(in_features=4096, out_features=4096, bias=True)
(4): ReLU(inplace=True)
(5): Dropout(p=0.5, inplace=False)
(6): Linear(in_features=4096, out_features=1000, bias=True)
)
)
(1): Linear(in_features=1000, out_features=1, bias=True)
)
现在我们可以通过循环查看模型参数是否被冻结:
for idx, param in enumerate(my_model.parameters()):
print(idx, param.requires_grad)
输出:
0 False
1 False
2 False
3 False
4 False
5 False
6 False
7 False
8 False
9 False
10 False
11 False
12 False
13 False
14 False
15 False
16 False
17 False
18 False
19 False
20 False
21 False
22 False
23 False
24 False
25 False
26 False
27 False
28 False
29 False
30 False
31 False
32 True
33 True
如果需要解冻 VGG 层进行训练,可以执行以下操作,执行后再循环查看模型参数,全部都会为True。
vgg_model.requires_grad_(True)
这里就不解冻了,因为我们不需要重复训练VGG层。
2.1.4 编译模型
与之前的练习类似,我们需要为模型指定损失函数和评估指标,但本次需要做一些不同的选择。
-
损失函数:由于这是一个二分类问题(Bo 或 not Bo),我们选择二元交叉熵(
binary crossentropy
)作为损失函数,而不是之前用于多分类问题的分类交叉熵(categorical crossentropy
)。 -
评估指标:使用
binary accuracy
代替传统的accuracy
,以适应二分类任务的需求。- 在二分类任务中,输出值是一个连续的分数(通常是
logits
或概率值)。通过一个阈值(通常是 0.5),判断模型预测的类别为 0 或 1。
- 在二分类任务中,输出值是一个连续的分数(通常是
-
输出设置:通过设置
from_logits=True
告知损失函数输出值未经过归一化处理(例如没有使用 softmax)。logits
是神经网络最后一层的输出值,通常未经过归一化处理。使用归一化函数(如sigmoid
或softmax
)会将输出值转化为概率。但在很多情况下(如使用交叉熵损失函数时),直接传入logits
会更高效,因为损失函数内部会自动处理归一化。from_logits=True
告知损失函数输入的是logits
,它会在内部应用归一化。
loss_function = nn.BCEWithLogitsLoss() # 二元交叉熵损失函数
optimizer = Adam(my_model.parameters()) # Adam 优化器
my_model = my_model.to(device) # 将模型移动到设备
BCEWithLogitsLoss
无参数情况下,默认设置from_logits=True
2.2 数据处理
与上一篇文章一样,我们将创建一个自定义 Dataset
类来读取 Bo 和非 Bo 的图像数据。首先,从 VGG 权重中获取pre_trans
。
pre_trans = weights.transforms()
2.2.1 数据集
这里,就不用pandas中的DataFrame
加载图片和标签了,我们直接从图像文件中读取数据。通过文件路径推断标签。假设 data_dir="train/"
,文件夹下有文件夹结构如下:
train/
├── bo/
│ ├── bo1.jpg
│ ├── bo2.jpg
├── not_bo/
├── not_bo1.jpg
├── not_bo2.jpg
- 文件夹下有20张bo的照片(还有10张bo的照片用作validation),另外有120张为其它狗(
not_bo
)的图片
下面是自定义数据集类的代码,具体看注释:
# 定义了两个类别标签,“bo”表示目标类别,“not_bo”表示非目标类别
DATA_LABELS = ["bo", "not_bo"]
# 自定义 PyTorch 数据集类
class MyDataset(Dataset):
def __init__(self, data_dir): # 初始化方法,接收数据目录路径作为参数
self.imgs = [] # 用于存储预处理后的图像
self.labels = [] # 用于存储图像对应的标签
# 遍历每个标签类别及其索引
for l_idx, label in enumerate(DATA_LABELS):
# 使用 glob 查找每个类别文件夹下的所有 .jpg 文件路径
# data_dir 是数据根目录,例如 "train/",每个类别的路径为 train/bo/ 或 train/not_bo/
data_paths = glob.glob(data_dir + label + '/*.jpg', recursive=True)
# 遍历查找到的图像路径列表
for path in data_paths:
# 打开图像文件,将其加载为 PIL 图像对象
img = Image.open(path)
# 对图像进行预处理(如调整大小、标准化等),并将其转换为 PyTorch 张量
self.imgs.append(pre_trans(img).to(device))
# 根据类别索引 l_idx(0 表示 bo,1 表示 not_bo),生成标签张量
self.labels.append(torch.tensor(l_idx).to(device).float())
# 定义数据集中的单个样本提取逻辑
def __getitem__(self, idx):
# 获取指定索引 idx 对应的图像和标签
img = self.imgs[idx] # 图像张量
label = self.labels[idx] # 标签张量
return img, label # 返回图像和标签,供 DataLoader 使用
# 定义数据集的总样本数
def __len__(self):
return len(self.imgs) # 返回数据集中样本的总数
现在我们可以用自定义的 MyDataset
类创建数据加载器:
n = 32 # 批次大小
train_path = "presidential_doggy_door/train/"
train_data = MyDataset(train_path)
train_loader = DataLoader(train_data, batch_size=n, shuffle=True)
train_N = len(train_loader.dataset)
valid_path = "presidential_doggy_door/valid/"
valid_data = MyDataset(valid_path)
valid_loader = DataLoader(valid_data, batch_size=n)
valid_N = len(valid_loader.dataset)
2.2.2 数据增强
为了让模型更好地识别 Bo,我们应用一些数据增强方法。这次的图像是彩色的,因此可以充分利用 ColorJitter
进行颜色调整。
IMG_WIDTH, IMG_HEIGHT = (224, 224)
random_trans = transforms.Compose([
transforms.RandomRotation(25), # 随机旋转
transforms.RandomResizedCrop((IMG_WIDTH, IMG_HEIGHT), scale=(.8, 1), ratio=(1, 1)), # 随机裁剪
transforms.RandomHorizontalFlip(), # 随机水平翻转
transforms.ColorJitter(brightness=.2, contrast=.2, saturation=.2, hue=.2) # 颜色抖动
])
2.3 训练和验证模型
训练循环大部分与之前的相同,但有一些细微的差异:
2.3.1 计算准确率函数
首先来看一下get_batch_accuracy
函数:由于使用二元交叉熵(Binary Cross Entropy)作为损失函数,我们不需要将模型输出通过 sigmoid
激活函数处理。通过观察可以发现:
- 当模型输出大于 0 时,
sigmoid
输出接近 1;小于 0 时,sigmoid
输出接近 0。 - 因此,只需检查模型输出是否大于 0 即可判断类别。
def get_batch_accuracy(output, y, N):
zero_tensor = torch.tensor([0]).to(device) # 创建值为 0 的张量
pred = torch.gt(output, zero_tensor) # 判断模型输出是否大于 0
correct = pred.eq(y.view_as(pred)).sum().item() # 比较预测与真实标签
return correct / N # 计算准确率
2.3.2 训练函数
下面是模型训练的函数,新增部分用于显示模型最后一层的梯度,以验证只有新添加的层在学习。
def train(model, check_grad=False):
loss = 0
accuracy = 0
model.train() # 设置模型为训练模式
for x, y in train_loader:
output = torch.squeeze(model(random_trans(x))) # 模型预测并压缩维度
optimizer.zero_grad() # 梯度清零
batch_loss = loss_function(output, y) # 计算批次损失
batch_loss.backward() # 反向传播
optimizer.step() # 更新模型参数
loss += batch_loss.item() # 累计损失
accuracy += get_batch_accuracy(output, y, train_N) # 累计准确率
# 可选:打印最后一层的梯度
if check_grad:
print('Last Gradient:')
for param in model.parameters():
print(param.grad)
print('Train - Loss: {:.4f} Accuracy: {:.4f}'.format(loss, accuracy))
简单解释一下torch.squeeze(model(random_trans(x)))
,假设 x
的形状为
(
32
,
3
,
224
,
224
)
(32, 3, 224, 224)
(32,3,224,224)(一个批次包含 32 张 RGB 图像,每张大小为
224
×
224
224 \times 224
224×224)。
- 经过
random_trans(x)
图像增强后,图像形状保持不变: ( 32 , 3 , 224 , 224 ) (32, 3, 224, 224) (32,3,224,224)。 - 经过
model(...)
,得到 logits 输出,形状为 ( 32 , 1 ) (32, 1) (32,1)。 - 使用
torch.squeeze()
后,形状变为 ( 32 ) (32) (32),便于后续计算损失或进行分类判断。
由于 VGG16 的最后一层有 1000 个神经元,与下一层的单个神经元连接,因此会输出大量梯度值:
train(my_model, check_grad=True)
输出:
Last Gradient:
None
...略
None
tensor([[-1.7544e-01, -3.2742e-01, -3.5025e-01, 2.8259e-01, 5.8005e-01,
6.0872e-01, 3.2710e-01, 5.9767e-02, 3.4024e-02, 4.9906e-01,
-7.9312e-01, -7.4066e-01, -6.4248e-01, -6.1524e-01, -1.7711e-01,
-5.3027e-01, -5.0837e-01, -6.0457e-01, 1.3494e-01, -4.8378e-01,
-5.0667e-01, 4.3823e-01, 5.3120e-01, 7.9974e-01, -8.5800e-02,
-7.0989e-01, -5.7805e-01, -7.8004e-01, -9.1848e-01, -6.9956e-01,
-6.6547e-01, -5.1254e-01, -7.2260e-01, 3.8824e-01, 3.8325e-01,
-2.9148e-01, 1.7045e-01, -1.3857e+00, -5.2024e-01, -5.3980e-02,
-6.3483e-01, -1.1288e+00, -6.5688e-01, -1.9186e-02, -1.0455e+00,
-5.9199e-01, -6.3787e-01, 1.0210e-01, -7.0408e-01, -4.0892e-01,
6.4533e-02, 8.2972e-01, -1.0984e+00, -1.1901e+00, -4.0869e-01,
-7.1822e-01, -1.1695e+00, -1.2966e+00, -3.8214e-01, -6.7613e-01,
-8.1616e-01, -1.1129e+00, -1.0367e+00, -4.4582e-01, -2.3638e-01,
1.9191e-01, -4.0016e-01, -3.6626e-01, -5.3753e-01, 2.2426e-01,
-9.8173e-01, -4.1114e-01, -6.1928e-01, -4.3241e-01, -5.0240e-01,
-1.1249e+00, -6.7539e-01, -7.8897e-01, -4.9502e-01, -7.2221e-01,
9.9532e-01, -4.8096e-01, -8.4775e-02, -1.2387e-01, -1.2294e-01,
4.9023e-02, -4.8676e-01, -5.4947e-01, -2.4885e-01, 2.4217e-02,
-1.0428e+00, 4.5021e-03, -2.6639e-01, -1.8733e-01, 5.2851e-02,
-8.1037e-01, 4.4159e-02, -1.1009e+00, -1.0536e+00, 2.5601e-01,
5.7655e-01, 8.1418e-01, -9.5628e-01, 3.1177e-01, 2.9675e-01,
-3.3930e-01, 1.4372e-01, 1.3638e-02, -5.7496e-01, 7.0985e-01,
-2.3029e-01, -3.6438e-02, 7.8828e-01, -7.4832e-03, -3.1391e-01,
-3.0802e-01, -1.4903e-01, -1.8410e-01, -5.9939e-01, -4.1235e-01,
-1.0097e+00, -4.9801e-01, -9.2671e-01, -3.5769e-01, -6.8458e-01,
-6.7180e-01, -5.0943e-02, 2.7029e-01, 4.5693e-01, -2.6082e-01,
-4.0161e-01, -4.1856e-02, 2.6322e-01, -4.7891e-01, -3.2046e-02,
1.2394e-01, -2.3324e-01, 3.1666e-02, -2.8982e-01, -9.6604e-01,
-9.0639e-01, -5.0243e-01, -5.8052e-01, -5.9692e-03, 7.8434e-04,
-4.7410e-02, 1.0430e-01, -5.0174e-01, 6.0396e-01, -5.4888e-02,
6.2739e-01, 3.6291e-01, -4.6046e-01, -2.3584e-01, -6.4077e-01,
-1.7806e-01, -7.8082e-02, -2.7355e-01, -5.1460e-01, 2.6626e-01,
3.0177e-01, 2.3867e-01, -2.1095e-01, 5.0674e-01, 2.2907e-01,
6.3274e-02, -4.5064e-01, -2.7952e-01, -9.2401e-02, -6.0889e-01,
1.6224e-01, -4.1656e-01, -3.4775e-01, -1.3679e-01, 9.6529e-02,
4.2382e-02, -2.5797e-01, 1.3562e-01, 2.4578e-01, 7.0888e-01,
-2.2414e-02, 3.2868e-01, -9.1382e-02, 1.3297e+00, 5.1516e-01,
2.1450e-01, 2.7332e-01, -3.1009e-02, 3.6133e-01, 8.1713e-01,
5.3174e-01, 8.7230e-01, 6.7090e-02, -3.6013e-01, -2.1139e-01,
5.1516e-01, 5.2809e-01, 1.4728e+00, 8.8910e-01, 1.1745e+00,
8.9725e-02, -5.2408e-01, 3.3150e-01, 8.0386e-01, -1.7271e-01,
5.3883e-01, 8.5256e-01, -1.4709e-01, 5.7014e-01, 4.0775e-01,
1.1537e-01, 4.2447e-01, 2.3372e-01, 1.4613e-01, 6.2075e-01,
-2.5229e-01, 1.1984e-01, 4.4006e-01, 6.0319e-02, 5.2251e-01,
3.7996e-01, 7.2196e-01, 3.1009e-01, 8.5933e-01, 8.3837e-01,
-3.8384e-02, 7.5061e-01, 4.8474e-01, 2.3170e-01, 9.2325e-02,
-5.4723e-01, -1.6103e-01, -1.5929e-01, 1.2951e+00, 4.6359e-01,
2.3991e-01, 2.5596e-01, 3.2791e-01, 3.1061e-01, 2.2561e-01,
5.7878e-01, 5.0774e-01, 7.9616e-02, 1.8229e-01, -1.0505e-02,
3.3951e-01, 5.7757e-01, 6.9330e-02, -2.0447e-01, 1.6851e-01,
2.4136e-01, -3.3566e-01, 3.2309e-01, 4.5403e-03, 5.9609e-01,
-2.2460e-01, 1.1878e+00, -1.9196e-01, -1.0095e-01, -5.0410e-02,
4.0296e-01, -3.3743e-01, 4.9154e-01, 4.5181e-02, 6.5029e-02,
4.5169e-01, 3.1957e-01, 1.1242e+00, -8.2551e-02, 1.0428e-01,
-2.8535e-01, -1.8505e-01, 5.6668e-02, 9.5634e-02, -2.2666e-01,
-1.4186e-01, -2.3955e-01, -1.7384e-01, -4.2970e-01, -2.5679e-01,
8.6496e-02, -8.3843e-02, -1.3523e-01, -3.4327e-01, -3.0160e-01,
1.5758e-01, 9.4906e-01, 5.4506e-02, 2.5300e-01, -3.3435e-01,
2.5600e-01, -1.0302e-01, -9.5671e-02, 1.6541e-02, 1.1849e+00,
1.6133e+00, 6.2993e-01, 5.1954e-01, 6.8446e-02, 1.4106e-02,
-8.3469e-01, -4.4179e-01, -4.3720e-01, -2.9500e-01, -1.4123e-01,
-4.0103e-01, -2.2307e-01, -4.1705e-01, -2.5477e-01, -8.8754e-02,
1.5583e-01, -6.9193e-01, -4.3465e-01, -3.4695e-01, -4.9969e-01,
-4.9844e-01, -6.3136e-01, -1.1276e+00, -5.8314e-01, -5.6808e-01,
-9.9875e-01, -5.0705e-02, -8.6277e-02, -4.4018e-01, -1.2385e-01,
-2.9588e-01, -3.8404e-01, 1.1681e-02, -9.0235e-03, 1.5771e-01,
1.1791e-01, 1.4297e-01, -6.2983e-01, -9.5536e-01, -2.5565e-01,
-1.5391e-01, 1.2998e-01, 1.9953e-01, -5.2096e-01, -1.4853e-01,
5.1128e-01, 3.8770e-01, 4.6292e-01, 7.8133e-01, -3.4823e-02,
1.1199e+00, 8.2565e-01, 1.4222e+00, 4.8996e-01, 5.9679e-01,
1.5140e+00, 3.5912e-01, 4.9130e-01, 6.8880e-01, 1.3371e+00,
8.1143e-01, 2.2543e-01, 1.6702e-01, -3.1753e-01, -3.3345e-01,
7.3151e-02, 9.9862e-02, -6.7115e-02, -1.2147e-01, 7.2379e-01,
5.6254e-01, 9.0620e-01, 5.7460e-01, 1.9504e-01, 1.1463e+00,
-4.3740e-01, 9.8323e-02, 5.0859e-01, -4.9404e-02, 3.8926e-01,
4.8849e-01, -3.2364e-02, -5.1617e-01, 1.7361e-01, 6.1864e-01,
1.3550e-01, 5.9990e-01, -3.6677e-01, -1.8845e-02, 5.2383e-01,
7.1532e-01, 9.4754e-01, -3.3795e-01, 2.0504e-01, -1.2463e-01,
-4.9789e-01, -4.6760e-01, 3.5805e-01, -4.7438e-01, 1.4768e-02,
5.6401e-02, 3.6824e-01, -3.5621e-01, -6.0544e-01, 1.1242e+00,
5.3035e-01, -1.5544e-01, 8.4991e-03, 1.2774e-01, -2.0913e-02,
6.4669e-01, 1.3271e-01, 2.2336e-01, 4.2172e-01, 8.3584e-01,
7.0911e-01, -1.6691e-01, 2.5150e-01, 1.1925e+00, 3.2783e-01,
-3.8567e-01, -4.2771e-01, 1.3689e+00, -9.3467e-01, -5.9124e-01,
8.7705e-02, -8.3544e-02, -4.1152e-01, -4.7636e-01, -8.4075e-01,
6.4817e-01, 1.3138e-01, -3.8658e-03, 3.2083e-01, -6.7445e-01,
-1.8609e-01, -3.5245e-01, 4.6682e-01, 1.1039e-01, -4.9245e-02,
-1.5886e-01, -3.4515e-01, 1.0884e+00, -6.1701e-02, 2.0826e-01,
-3.2728e-01, -3.7305e-01, 7.6063e-01, 1.2500e-01, -2.3529e-01,
2.9716e-01, -2.9628e-02, 6.9071e-01, 1.0946e+00, 5.0622e-01,
-6.7897e-01, 1.6814e-01, 1.6319e-01, -7.6587e-01, -4.9430e-01,
-6.2659e-01, 7.1690e-01, 2.4527e-01, 5.7834e-01, 6.6461e-01,
4.7335e-01, 2.6489e-01, 4.8830e-01, 4.2461e-01, -5.3125e-01,
8.2322e-01, -1.0523e+00, -9.7149e-01, -1.3763e-01, 4.4326e-01,
5.8212e-01, 6.6231e-01, 3.4307e-01, -5.8996e-01, 1.3648e-01,
2.1469e-01, -3.9137e-01, -8.4347e-01, 1.4832e-01, -2.9151e-01,
-2.1144e-01, -1.5212e-01, -5.6108e-01, 1.4632e+00, 3.7968e-01,
-9.9502e-01, 1.2784e-01, 1.5410e-01, 3.1072e-01, 2.7450e-01,
1.8879e-01, 6.7186e-01, -5.7458e-01, -5.5305e-01, 8.6939e-01,
-8.4140e-01, 1.0098e-01, 9.7234e-01, -3.5749e-01, -3.1363e-01,
1.0251e+00, 1.0649e+00, -7.4344e-01, -9.5872e-01, -2.6811e-01,
-2.5413e-02, -2.6947e-02, -8.0043e-01, -7.3723e-01, -1.3699e+00,
-1.4215e-02, -3.0558e-01, -3.5545e-01, 4.3493e-01, 1.1974e-01,
2.9585e-01, 1.9010e-02, 9.5524e-01, -2.6313e-01, -1.9986e-01,
-6.3724e-01, -6.8761e-01, -6.6525e-03, 2.1994e-01, 3.6776e-01,
5.3228e-01, -5.0982e-01, -1.7214e-01, -7.1451e-01, 4.2411e-01,
3.6325e-01, -5.8008e-03, -4.1071e-01, -1.9090e-01, -5.8616e-01,
-1.3731e+00, -2.3476e-01, 5.0639e-01, 6.2128e-01, -1.2105e-01,
4.5680e-01, -3.6777e-01, -4.8782e-01, -1.2036e-01, -3.5872e-01,
-3.1772e-01, -1.6101e-02, -6.8761e-01, -4.9851e-01, 6.4212e-01,
-1.3627e+00, -1.3873e-01, 3.2438e-01, -9.2687e-02, -1.3069e-01,
-3.5789e-01, -2.1601e-01, 8.5684e-01, 6.1687e-01, -2.6308e-01,
-5.7640e-01, 1.5337e-02, 9.8325e-01, -8.2864e-01, -4.2301e-01,
-1.0946e-01, 2.0482e-03, -4.9955e-01, 5.4178e-01, 5.4901e-01,
6.5611e-01, -8.7986e-01, -1.2359e-01, -7.4628e-01, 5.8283e-01,
6.5500e-01, 5.3964e-01, 2.5499e-01, 3.4695e-01, 2.9082e-01,
3.1107e-01, -9.9775e-01, -6.4548e-01, 3.5271e-01, -2.3684e-01,
-3.0687e-02, 9.6840e-01, -2.9441e-01, -1.7284e-01, -3.3753e-01,
-6.1289e-01, 2.8625e-01, -8.7487e-01, -6.6964e-02, 6.8912e-02,
7.0997e-01, 3.1675e-01, -2.8768e-01, -1.4001e-01, 5.8170e-01,
9.6576e-01, 3.0093e-01, 2.0030e-01, 6.5689e-01, -9.9078e-02,
2.7720e-02, -3.7819e-01, -4.4595e-01, -8.3836e-02, 7.5016e-01,
2.6113e-01, 8.4086e-01, -5.6746e-01, -2.4073e-01, -2.8014e-01,
9.8023e-02, 8.3700e-02, -2.0063e-01, -6.4018e-02, 2.4375e-01,
5.0879e-01, 3.7023e-01, -3.7670e-01, -3.0028e-01, -5.7714e-01,
-3.6847e-01, -3.3820e-01, -1.5884e-01, 1.1988e-01, -6.9787e-01,
-6.1233e-01, -3.3228e-01, 4.5991e-01, 3.0918e-01, -1.4384e-01,
1.3325e-01, 3.4373e-01, -1.6447e-01, -1.7332e-01, 3.6048e-01,
-7.8606e-01, -3.6021e-01, -1.9213e-02, 2.1053e-01, 1.0852e-01,
5.6117e-01, 1.1758e+00, -7.9305e-01, -6.5506e-01, 1.6443e+00,
2.4420e-01, -6.1111e-01, 7.0643e-01, -2.1908e-01, 6.7593e-02,
1.0184e-02, -1.0817e-01, 1.0483e+00, 3.3803e-02, -8.2672e-01,
8.7884e-02, -6.3000e-02, -2.8742e-01, 9.9268e-01, 5.9845e-01,
-3.5003e-01, 3.0594e-01, 5.4412e-01, 9.4359e-01, -4.0021e-01,
5.6490e-02, 9.3258e-02, 5.9996e-01, -6.3922e-01, -2.1804e-01,
-1.4328e-02, 8.8924e-04, 3.2243e-01, 8.6198e-02, 2.4741e-01,
-3.2110e-01, 1.2544e-01, 1.4453e+00, 8.6298e-01, -4.7820e-01,
-9.9051e-01, -9.3073e-01, -1.3451e+00, -4.6813e-02, 1.2195e-01,
5.1505e-01, 1.9792e-01, -1.8230e-01, 5.4837e-01, -2.1014e-01,
1.2490e-01, 2.5163e-01, -5.6847e-01, 6.4455e-01, 6.8344e-01,
1.9394e-01, 1.1413e+00, -4.2705e-01, 4.0593e-01, -1.2405e-01,
-3.1259e-01, 3.5057e-01, -7.4957e-01, 8.5981e-01, -5.8320e-01,
-1.0674e-01, -8.8419e-02, -6.8084e-01, -3.9997e-01, -5.8208e-01,
2.5168e-01, 1.0437e+00, -1.1771e-01, 1.5982e-01, -4.5244e-01,
-7.6478e-01, -1.3536e-01, -6.9634e-01, 1.9466e-01, 8.7299e-01,
-2.1721e-01, -1.0015e+00, 4.1764e-01, 3.7579e-01, -7.7277e-01,
8.0956e-01, 3.5867e-01, -1.4754e-01, 8.2020e-01, -7.7433e-02,
1.0298e+00, -7.1787e-01, -4.7511e-01, 1.0310e+00, -5.7484e-01,
-1.7745e-01, 3.2878e-01, -4.4476e-01, 1.0948e-01, 8.1824e-01,
-2.4893e-03, -5.3822e-01, -2.2379e-01, 8.0561e-02, 9.3057e-02,
-9.5623e-02, -8.8346e-01, -2.7623e-01, -4.8847e-01, 1.1614e-01,
9.6560e-01, 3.5281e-01, 1.5146e-01, 1.5581e-01, 4.1637e-03,
-2.6449e-01, -4.3860e-01, -2.8917e-01, -5.3853e-05, 1.2238e+00,
-5.6862e-02, -9.0667e-01, -3.8450e-01, -3.6530e-01, -1.4925e-01,
-6.2542e-01, -6.2117e-01, 6.4625e-01, -7.5835e-01, -4.4316e-01,
6.1951e-01, 5.3488e-01, -7.2258e-01, -5.6171e-01, 2.3845e-02,
2.5880e-01, -2.2888e-01, 4.7781e-01, -4.4103e-02, -6.0243e-01,
-5.5530e-01, -5.4849e-01, 1.9147e-01, -9.2770e-01, -4.2939e-01,
-4.5498e-01, -7.2349e-02, 1.0604e+00, 8.3773e-03, 1.8364e-03,
6.3580e-01, 8.1394e-01, 2.0768e-01, -1.0080e+00, -3.4177e-01,
-1.0614e+00, 8.0248e-01, 3.7003e-01, 4.6856e-01, -3.0236e-01,
1.8149e-01, 1.8159e-01, 6.4092e-01, -1.7012e-03, -8.0130e-01,
-9.6853e-01, -7.2521e-01, 8.9102e-01, -3.9454e-01, -2.5631e-01,
-9.7604e-02, -6.3021e-01, -9.3553e-01, 9.1216e-01, 2.0788e-02,
4.9450e-01, 1.6375e-01, -2.0362e-01, -6.0301e-01, 6.2432e-01,
7.6297e-01, -2.5092e-01, -5.4573e-01, -6.8181e-01, -2.5224e-01,
5.3633e-01, -4.8921e-01, 1.0590e+00, 4.9470e-01, 3.3533e-01,
5.6461e-01, 2.6504e-01, 1.7099e-01, -6.9919e-02, 8.1155e-01,
7.1162e-01, 3.9807e-01, 8.1624e-01, 1.2990e+00, 3.4553e-01,
-2.0461e-01, 1.7764e-01, 1.2342e+00, -3.8718e-01, -2.0317e-01,
1.6723e-01, 6.6631e-01, 1.0489e-01, 5.4685e-01, -7.5599e-01,
-1.4645e-02, 3.1836e-01, -1.7702e-01, 1.3726e-01, -5.0550e-01,
-8.4961e-01, -1.5202e-01, 5.3466e-01, 8.5211e-01, -1.2047e-01,
-6.7310e-01, 9.9490e-01, -5.2963e-02, -1.7486e-01, -2.8276e-01,
-1.0952e-01, 1.0109e-01, 1.2935e+00, 1.0582e+00, -5.9848e-01,
1.7264e-01, -2.5981e-01, -7.7737e-01, -9.1011e-01, 7.2406e-01,
5.8416e-01, -3.2779e-01, -1.9431e-01, 3.1548e-01, 3.1757e-01,
3.0605e-01, -6.7280e-01, 1.3216e-01, 9.0023e-01, 8.7088e-02,
-2.3721e-01, -9.0640e-01, 6.7392e-01, -2.6667e-01, -2.8760e-01,
4.9407e-01, -3.8842e-01, -8.6777e-02, -7.8174e-02, -1.6123e-01,
4.4536e-01, -5.0251e-01, 3.4807e-01, 2.3365e-02, 1.1720e-01,
-1.7411e-01, -5.8415e-01, -2.9320e-01, -4.2819e-02, -3.0250e-01,
-5.3503e-01, 7.4629e-01, 4.8132e-01, 1.2916e+00, 1.4694e-01,
1.2384e-01, 7.7418e-01, 2.5849e-01, -1.7581e-02, 6.2081e-01,
8.9407e-01, 1.1595e+00, -1.9118e-01, -6.4152e-01, -1.0393e-01,
-6.5908e-01, -1.5315e+00, -9.6489e-01, -5.5077e-01, -1.5004e-01,
-1.3946e+00, -8.9819e-01, -5.9983e-01, -9.7032e-01, -9.9386e-01,
-6.5717e-01, 4.9145e-01, 9.7614e-01, 5.0756e-01, -4.2993e-01,
-4.1004e-01, -1.9161e-01, -8.2460e-01, 3.2447e-02, -7.9599e-02,
-6.7832e-01, -3.0543e-01, -6.7015e-02, 1.5095e-01, -5.0903e-01,
-1.5201e-01, -6.0907e-02, -1.9378e-01, -5.4632e-02, -1.3574e-01,
3.0301e-01, 1.5639e-01, -6.6806e-01, 1.2669e+00, -6.8265e-01,
-4.1642e-01, -9.0397e-01, -7.3187e-01, -1.2412e+00, -1.1058e+00,
-7.9914e-01, -2.9846e-01, -9.6234e-01, -3.2566e-01, -7.7322e-01,
1.5191e+00, 8.5183e-01, 2.0483e+00, 1.3316e+00, 1.6647e-01,
1.4621e+00, 1.3676e+00, 7.7252e-01, 1.3087e+00, 1.2429e+00,
1.2008e+00, -2.9195e-01, 8.4185e-01, 1.2080e+00, 1.2942e+00,
-5.9255e-01, -5.6795e-01, 8.3677e-02, 8.2386e-01, 2.0821e-01,
7.7015e-01, 3.1907e-01, -8.4631e-01, 3.4039e-01, -7.6810e-02,
5.7064e-01, 4.7979e-01, -7.5530e-01, -8.3638e-02, -1.9206e-02]])
tensor([0.2733])
Train - Loss: 2.9310 Accuracy: 0.8417
2.3.3 验证函数
验证阶段的逻辑基本保持不变。设置模型为评估模式,关闭梯度计算以节省内存和计算资源。
def validate(model):
loss = 0
accuracy = 0
model.eval() # 设置模型为评估模式
with torch.no_grad(): # 禁用梯度计算
for x, y in valid_loader:
output = torch.squeeze(model(x)) # 模型预测并压缩维度
loss += loss_function(output, y.float()).item() # 累计损失
accuracy += get_batch_accuracy(output, y, valid_N) # 累计准确率
print('Valid - Loss: {:.4f} Accuracy: {:.4f}'.format(loss, accuracy))
2.4 测试代码
下面就来执行上面写的代码:
epochs = 10
for epoch in range(epochs):
print('Epoch: {}'.format(epoch))
train(my_model, check_grad=False)
validate(my_model)
输出:
Epoch: 0
Train - Loss: 1.9242 Accuracy: 0.8705
Valid - Loss: 0.8323 Accuracy: 0.6667
Epoch: 1
Train - Loss: 1.4564 Accuracy: 0.8489
Valid - Loss: 0.5792 Accuracy: 0.7333
Epoch: 2
Train - Loss: 1.1878 Accuracy: 0.8777
Valid - Loss: 0.4356 Accuracy: 0.7667
Epoch: 3
Train - Loss: 1.2490 Accuracy: 0.8921
Valid - Loss: 0.4100 Accuracy: 0.8000
Epoch: 4
Train - Loss: 1.2821 Accuracy: 0.8849
Valid - Loss: 0.3632 Accuracy: 0.8333
Epoch: 5
Train - Loss: 0.9136 Accuracy: 0.9209
Valid - Loss: 0.3413 Accuracy: 0.8333
Epoch: 6
Train - Loss: 1.1618 Accuracy: 0.9209
Valid - Loss: 0.2976 Accuracy: 0.8333
Epoch: 7
Train - Loss: 1.1632 Accuracy: 0.9281
Valid - Loss: 0.2149 Accuracy: 0.9000
Epoch: 8
Train - Loss: 1.0856 Accuracy: 0.9281
Valid - Loss: 0.1250 Accuracy: 0.9667
Epoch: 9
Train - Loss: 0.8895 Accuracy: 0.9353
Valid - Loss: 0.1047 Accuracy: 0.9667
损失和准确率曲线如下:
从结果可以发现,模型的训练和验证准确率都应该非常高,尽管我们仅使用了一个小型数据集,但由于从 ImageNet 模型中转移了知识,模型能够实现较高的准确率,并很好地泛化。这表明模型已经能够很好地识别 Bo 和其他宠物。
2.5 模型微调
现在模型的新层已经经过训练,我们可以应用一个最终的技巧——微调(Fine-Tuning)。
- 微调的核心:解冻整个模型的预训练层,并使用非常小的学习率重新训练。这将使预训练层进行细微调整,从而进一步提升模型性能。
- 小学习率的作用:VGG16 是一个相对较大的模型,小学习率能够防止在小数据集上训练时发生过拟合,避免对预训练特征进行过度修改。
- 大模型的高容量允许其拟合复杂的模式和细节,但也容易学习到训练数据中的噪声。
为什么要在冻结层训练完成后再进行微调?
- 新添加的线性层是随机初始化的,在训练时需要进行大量更新以正确分类图像。
- 如果直接解冻所有层,在反向传播过程中,最后几层的大更新会导致预训练层的参数也发生大幅度变化,从而破坏重要的预训练特征。
- 现在新层已经训练收敛,解冻后进行小更新(尤其是使用小学习率)不会破坏早期层的特征。
解冻模型并进行微调
以下是解冻预训练层并进行微调的代码:
# 解冻基础模型
vgg_model.requires_grad_(True)
# 使用非常小的学习率
optimizer = Adam(my_model.parameters(), lr=.000001)
# 进行微调训练
epochs = 2
for epoch in range(epochs):
print('Epoch: {}'.format(epoch))
train(my_model, check_grad=False) # 训练模型
validate(my_model) # 验证模型
输出:
Epoch: 0
Train - Loss: 1.0701 Accuracy: 0.8921
Valid - Loss: 0.1145 Accuracy: 0.9667
Epoch: 1
Train - Loss: 1.0529 Accuracy: 0.9137
Valid - Loss: 0.1246 Accuracy: 0.9667
- 在微调阶段,仅需少量 epoch(如 2),以避免过拟合。
2.6 预测测试
现在我们已经训练好了模型,可以开始创建属于 Bo 的宠物门了。首先,我们需要对图像进行与之前相同的预处理,并查看模型的预测结果。
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
# 显示图像
def show_image(image_path):
image = mpimg.imread(image_path)
plt.imshow(image)
# 进行预测
def make_prediction(file_path):
show_image(file_path) # 显示图像
image = Image.open(file_path) # 打开图像文件
image = pre_trans(image).to(device) # 对图像进行预处理
image = image.unsqueeze(0) # 添加批次维度
output = my_model(image) # 模型预测
prediction = output.item() # 提取预测值
return prediction
测试预测功能
make_prediction('presidential_doggy_door/valid/bo/bo_20.jpg')
输出:
make_prediction('presidential_doggy_door/valid/not_bo/121.jpg')
输出:
看起来一个负数的结果意味着它是Bo,而正数的预测意味着不是。所以我们可以继续完成Bo的验证函数:
# 宠物门逻辑
def bo_doggy_door(image_path):
pred = make_prediction(image_path) # 获取预测值
if pred < 0:
print("It's Bo! Let him in!") # 如果是 Bo,允许进入
else:
print("That's not Bo! Stay out!") # 如果不是 Bo,禁止进入
测试宠物门
bo_doggy_door('presidential_doggy_door/valid/not_bo/131.jpg')
输出:
bo_doggy_door('presidential_doggy_door/valid/bo/bo_29.jpg')
输出:
3 总结
通过迁移学习,我们在一个非常小的数据集上构建了一个高准确率的模型。这是一种非常强大的技术,在没有大数据集支持的情况下,能够帮助项目取得成功。