冻结层设置不生效?深度剖析nn.Module.requires_grad陷阱

部署运行你感兴趣的模型镜像

第一章:PyTorch中参数冻结的核心机制

在深度学习模型训练过程中,参数冻结是一种常见且关键的技术手段,尤其在迁移学习和微调(fine-tuning)场景中广泛应用。通过冻结部分网络层的参数,可以防止其在反向传播过程中更新,从而保留预训练模型中的特征提取能力,同时降低计算开销。

参数冻结的基本原理

PyTorch 中每个 nn.Parameter 对象都包含一个 requires_grad 属性,该属性控制是否对该参数进行梯度计算。当设置为 False 时,该参数不会被追踪梯度,也不会参与优化器的更新。
  • 默认情况下,所有模型参数的 requires_grad=True
  • 冻结参数即遍历目标层并将其参数的 requires_grad 设为 False
  • 仅需训练的层保持 requires_grad=True

实现参数冻结的代码示例

以下代码展示如何冻结预训练模型的前几层:
# 加载预训练模型
import torch
import torchvision.models as models

model = models.resnet18(pretrained=True)

# 冻结所有参数
for param in model.parameters():
    param.requires_grad = False

# 解冻最后的全连接层,使其可训练
model.fc = torch.nn.Linear(model.fc.in_features, 10)
model.fc.requires_grad = True
上述代码首先冻结整个 ResNet-18 模型的参数,随后替换最后的分类层,并允许其参与梯度更新。这种策略常用于图像分类任务的迁移学习。

冻结状态的验证方法

可通过以下方式检查模型中可训练参数的数量:
def count_trainable_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"可训练参数数量: {count_trainable_params(model)}")
操作requires_grad 状态是否参与梯度更新
冻结参数False
解冻参数True

第二章:nn.Module参数冻结的理论基础与常见误区

2.1 requires_grad属性的本质与张量追踪原理

在PyTorch中,`requires_grad`是自动微分机制的核心开关。当设置为`True`时,张量的操作将被动态追踪并构建计算图,用于后续梯度回传。
张量追踪机制
只有`requires_grad=True`的张量参与的操作才会被记录。系统通过`grad_fn`属性保存生成该张量的函数,形成反向传播链。
import torch
x = torch.tensor(3.0, requires_grad=True)
y = x ** 2
print(y.grad_fn)  # <PowBackward0 object>
上述代码中,`y`由幂运算生成,其`grad_fn`指向`PowBackward0`,表示反向传播时将使用幂函数的导数规则。
计算图的动态构建
PyTorch采用动态计算图(Dynamic Computation Graph),每次前向传播都会重新构建图结构,灵活性高。
张量requires_grad是否参与追踪
xTrue
y = x**2自动继承
z = y.detach()False

2.2 模型层冻结的正确语义与反模式分析

模型层冻结是深度学习训练中控制参数更新的重要手段,其核心语义在于通过设置 `requires_grad=False` 来阻止特定网络层的梯度计算。
正确使用方式
for param in model.backbone.parameters():
    param.requires_grad = False
上述代码将主干网络参数冻结,仅允许后续层参与梯度更新。该操作应在优化器创建前完成,否则优化器会保留对已冻结参数的引用,导致内存浪费。
常见反模式
  • 在优化器创建后冻结参数,导致冻结失效
  • 误用 `.eval()` 替代冻结,仅停用Dropout/BatchNorm但不阻止梯度传播
  • 部分冻结时未验证参数状态,引发意外更新
冻结策略对比
策略适用场景风险
全层冻结特征提取器复用灵活性差
逐层选择性冻结渐进式微调配置复杂

2.3 参数注册与autograd图构建的关系剖析

在深度学习框架中,参数注册是模型构建的基础步骤。当用户通过 nn.Parameter 将张量标记为可训练参数时,该张量会被自动加入到模型的参数列表中,并触发后续的梯度追踪机制。
参数注册触发梯度追踪
一旦参数被注册,框架会在前向传播过程中将其参与的所有运算记录到 autograd 图中。每个操作都会生成一个对应的 Function 节点,用于反向传播时计算梯度。
import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(3, 3))  # 参数注册
        self.bias = nn.Parameter(torch.zeros(3))

    def forward(self, x):
        return x @ self.weight + self.bias  # 构建计算图
上述代码中,nn.Parameter 的使用使得 weightbias 被自动追踪。在调用 forward 时,矩阵乘法和加法操作将被记录至 autograd 图,形成从输入到输出的完整导数路径。
autograd图的动态构建过程
  • 每次前向传播时,所有涉及 Parameter 的操作都会动态构建计算图;
  • 图中节点保存了操作类型及输入输出关系,用于链式求导;
  • 反向传播时,autograd 引擎依据此图逐层计算并累积梯度。

2.4 子模块继承性问题:为什么freeze不传递?

在 Git 子模块管理中,`git submodule freeze` 并非原生命令,常被误认为可继承状态。实际上,子模块的“冻结”状态(即固定到特定提交)不会自动传递给父仓库的克隆者。
子模块的独立性机制
每个子模块是独立的 Git 仓库,其 HEAD 指向一个具体提交。父仓库仅记录该提交的哈希值。

# 父仓库记录子模块 commit hash
git add modules/utils
git commit -m "Freeze utils to v1.2"
上述操作仅保存子模块的当前提交指针,不锁定远程更新行为。
为何 freeze 不具传递性
  • 子模块克隆后默认处于“分离头指针”状态
  • 执行 git submodule update 会同步最新记录的提交,而非保留原始冻结策略
  • 无元数据字段标记“冻结意图”,故无法传递语义状态
因此,需通过文档或 CI 脚本显式维护子模块一致性。

2.5 冻结策略对优化器参数组的影响机制

在深度学习训练过程中,冻结策略通过控制模型中某些参数是否参与梯度更新,直接影响优化器的参数组管理。
参数组的动态划分
采用冻结策略时,模型参数被划分为可训练与不可训练两部分。优化器仅接收requires_grad=True的参数,从而实现对特定层的隔离更新。
for name, param in model.named_parameters():
    if "encoder" in name:
        param.requires_grad = False  # 冻结编码器
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)
上述代码中,filter函数筛选出需更新的参数,确保优化器不包含冻结层参数,减少内存开销并提升计算效率。
对优化器状态的影响
冻结参数不会初始化优化器中的动量、方差等状态变量,显著降低显存占用。同时,该机制支持分阶段解冻策略,实现渐进式微调。

第三章:典型场景下的冻结实践与调试技巧

3.1 预训练模型特征提取层的冻结实现

在迁移学习中,冻结预训练模型的特征提取层可有效保留其已学习到的通用图像特征,同时减少训练开销。
冻结卷积基的实现方法
以TensorFlow/Keras为例,可通过设置`trainable=False`来冻结模型前缀层:

base_model = tf.keras.applications.ResNet50(weights='imagenet', include_top=False)
base_model.trainable = False  # 冻结所有卷积层

model = tf.keras.Sequential([
    base_model,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(10, activation='softmax')
])
上述代码中,`base_model.trainable = False`会锁定ResNet50的全部权重,仅允许后续全连接层参与梯度更新。该策略适用于小数据集微调,避免破坏预训练特征。
参数影响对比
配置可训练参数量适用场景
全模型训练约25M大数据集,域差异大
仅顶层训练约10K小数据集,域相近

3.2 条件化冻结:动态控制训练中的可微分状态

在复杂模型训练中,条件化冻结技术允许根据运行时指标动态调整网络层的可微分状态,从而优化资源分配与收敛效率。
动态冻结策略
通过监控每层梯度幅值,可设定阈值触发冻结机制。例如,当某层连续三步梯度L2范数低于1e-4时,自动将其参数设置为不可更新。

for param in model.layers[depth].parameters():
    if grad_norm_history[depth][-3:] < [1e-4] * 3:
        param.requires_grad = False  # 冻结该层
    else:
        param.requires_grad = True   # 恢复可微分
上述代码通过维护梯度历史记录实现动态控制。requires_grad标志直接影响反向传播是否计算该层梯度,实现细粒度训练调控。
应用场景对比
场景冻结策略收益
迁移学习浅层冻结保留通用特征
大模型微调条件化解冻节省显存30%

3.3 多任务学习中混合梯度流的隔离方案

在多任务学习中,不同任务的梯度更新可能相互干扰,导致模型收敛不稳定。为解决这一问题,梯度隔离机制应运而生。

梯度掩码策略

通过引入任务专属的梯度掩码,控制反向传播时的参数更新路径:

# 示例:基于掩码的梯度隔离
mask = task_masks[task_id]  # 形状 [D,], D为参数维度
grad = grad * mask.detach() # 隔离共享层中的任务特定梯度
上述代码中,mask 在前向传播中记录任务活跃状态,detach() 确保掩码不参与梯度计算,避免影响优化方向。

参数空间划分方法

  • 硬共享架构中,底层共享权重易受冲突梯度影响;
  • 采用子空间投影,将各任务梯度限制在正交子空间内;
  • 提升模型对任务间语义差异的鲁棒性。

第四章:高级冻结技术与性能优化策略

4.1 使用torch.no_grad与requires_grad的协同控制

在PyTorch中,`torch.no_grad()` 与 `requires_grad` 是控制梯度计算的核心机制。通过合理组合二者,可实现训练与推理阶段的高效切换。
功能分工
  • requires_grad=True:标识张量参与梯度计算,常用于模型参数
  • torch.no_grad():上下文管理器,临时禁用梯度追踪,加速推理并节省内存
典型应用场景
x = torch.tensor([1.0, 2.0], requires_grad=True)
with torch.no_grad():
    y = x * 2  # y.requires_grad 为 False,不记录计算图
该代码中,尽管输入张量 `x` 需要梯度,但在 no_grad 上下文中,操作不会被追踪,确保推理过程无冗余计算。
协同控制策略
场景requires_gradtorch.no_grad效果
训练正向传播True关闭记录计算图
模型评估True开启禁用梯度,提升性能

4.2 自定义参数分组与优化器级别的冻结管理

在复杂模型训练中,对不同参数组应用差异化的优化策略至关重要。通过自定义参数分组,可为网络的不同部分设置独立的学习率或权重衰减。
参数分组配置示例
optimizer = torch.optim.Adam([
    {'params': model.features.parameters(), 'lr': 1e-5},  # 冻结特征提取层
    {'params': model.classifier.parameters(), 'lr': 1e-3}  # 解冻分类头,使用较大学习率
])
上述代码将模型划分为两个参数组:底层特征提取网络使用极低学习率以保留预训练权重,而分类头则允许快速收敛。
动态冻结策略
  • 初期冻结主干网络,仅训练头部分类器
  • 中期逐步解冻靠后的特征块,配合分组学习率微调
  • 后期全量微调,整体优化所有参数
该策略有效避免深层网络在早期训练阶段因梯度剧烈变化而破坏已有特征表达。

4.3 冻结状态持久化:保存与加载时的陷阱规避

在实现对象冻结状态持久化过程中,开发者常因忽略序列化边界条件而引入数据不一致风险。正确处理保存与恢复流程至关重要。
序列化前的状态校验
应确保对象在冻结后不可变性未被破坏。可通过反射机制验证字段完整性:

func (o *Object) Serialize() ([]byte, error) {
    if !o.frozen {
        return nil, errors.New("object not frozen")
    }
    data, err := json.Marshal(o.state)
    if err != nil {
        return nil, fmt.Errorf("marshal failed: %w", err)
    }
    return data, nil
}
该代码确保仅当 frozen 标志为真时才允许序列化,防止中途状态误存。
反序列化中的版本兼容性
使用结构化标签管理 schema 演进:
  • 添加字段时设置默认值
  • 移除字段需保留占位以避免解析失败
  • 关键变更应伴随版本号递增

4.4 混合精度训练中冻结参数的行为一致性保障

在混合精度训练中,冻结参数的梯度计算与更新需严格隔离,以避免因自动微分机制误触发低精度类型操作导致行为不一致。
参数冻结与类型保持同步
通过将冻结参数显式排除在优化器更新范围外,并固定其数据类型为高精度(如FP32),可确保前向与反向传播过程中数值稳定性。

for param in model.parameters():
    if not param.requires_grad:
        param.data = param.data.float()  # 强制保持FP32
上述代码确保冻结参数始终运行在高精度模式下,防止自动转换至FP16引发舍入误差。
优化器参数组隔离
  • 将可训练参数与冻结参数分组管理
  • 仅对requires_grad=True的参数注册优化器更新
  • 使用torch.no_grad()上下文阻止冻结层梯度累积

第五章:总结与最佳实践建议

构建高可用微服务架构的通信策略
在分布式系统中,服务间通信的稳定性至关重要。使用 gRPC 作为内部通信协议时,应启用双向流与超时控制,避免级联故障。

// 设置客户端超时,防止长时间阻塞
conn, err := grpc.Dial(
    "service-address:50051",
    grpc.WithInsecure(),
    grpc.WithTimeout(5*time.Second), // 关键:设置连接超时
    grpc.WithBlock(),
)
if err != nil {
    log.Fatalf("did not connect: %v", err)
}
defer conn.Close()
配置管理的最佳实践
避免将敏感信息硬编码在代码中。推荐使用 HashiCorp Vault 或 Kubernetes Secrets 结合环境变量注入。
  • 开发、测试、生产环境使用独立的配置命名空间
  • 定期轮换密钥并审计访问日志
  • 使用 ConfigMap 管理非敏感配置,提升部署可移植性
性能监控与告警机制
集成 Prometheus 和 Grafana 实现指标可视化。关键指标包括请求延迟 P99、错误率和服务健康状态。
指标名称采集方式告警阈值
HTTP 5xx 错误率Prometheus + Exporter>5% 持续5分钟
服务响应延迟(P99)OpenTelemetry 追踪>1.5s
自动化部署流水线设计
采用 GitOps 模式,通过 ArgoCD 实现 Kubernetes 集群的持续交付。每次提交至 main 分支触发 CI/CD 流水线:
  1. 代码扫描(SonarQube)
  2. 单元测试与覆盖率检查
  3. 镜像构建并推送到私有 Registry
  4. ArgoCD 自动同步部署到预发环境

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

基于PyTorch的自主编写VGG16猫狗分类实验代码如下: # 1. 导入必要库 import torch import torch.nn as nn import torch.optim as optim from torch.optim.lr_scheduler import StepLR from torchvision import transforms from torch.utils.data import DataLoader, ConcatDataset import os from PIL import Image import matplotlib.pyplot as plt from sklearn.metrics import accuracy_score, precision_score, recall_score, confusion_matrix import numpy as np # 2. 超参数与设备配置 batch_size = 16 # GPU内存不足时改为8/4,CPU改为4 epochs = 10 lr = 1e-4 gamma = 0.7 num_classes = 2 # 猫/狗二分类 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") root = r"d:\Users\86192\Desktop\catdog\catdog" # 数据集根路径 # 3. 数据预处理(训练集增强,测试/验证集基础处理) train_transform = transforms.Compose([ transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(p=0.5), transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1), transforms.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 2.0)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) val_test_transform = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) # 自定义数据集:从文件名提取猫/狗标签(文件名含"cat"→猫,含"dog"→狗) class CustomDataset(torch.utils.data.Dataset): def __init__(self, img_dir, transform=None): self.img_dir = img_dir self.transform = transform # 筛选图片文件(支持.jpg/.jpeg/.png) self.img_paths = [ os.path.join(img_dir, f) for f in os.listdir(img_dir) if f.endswith(('.jpg', '.jpeg', '.png')) ] # 从文件名提取标签(0=猫,1=狗) self.labels = [] for path in self.img_paths: filename = os.path.basename(path).lower() if 'cat' in filename: self.labels.append(0) elif 'dog' in filename: self.labels.append(1) else: raise ValueError(f"文件名 {path} 不含'cat'/'dog',无法自动标注!请检查文件名或手动标注。") def __len__(self): return len(self.img_paths) def __getitem__(self, idx): img = Image.open(self.img_paths[idx]).convert('RGB') # 统一转RGB(避免灰度图报错) label = self.labels[idx] if self.transform: img = self.transform(img) return img, label # 4. 加载数据集(使用CustomDataset替代ImageFolder) # 训练集:合并training_set + training_set2 train_dirs = [os.path.join(root, "training_set"), os.path.join(root, "training_set2")] train_datasets = [CustomDataset(dir, transform=train_transform) for dir in train_dirs] train_dataset = ConcatDataset(train_datasets) train_loader = DataLoader( train_dataset, batch_size=batch_size, shuffle=True, num_workers=2 ) # 验证集:training_set3 val_dataset = CustomDataset(os.path.join(root, "training_set3"), transform=val_test_transform) val_loader = DataLoader( val_dataset, batch_size=batch_size, shuffle=False, num_workers=2 ) # 测试集:test_set test_dataset = CustomDataset(os.path.join(root, "test_set"), transform=val_test_transform) test_loader = DataLoader( test_dataset, batch_size=batch_size, shuffle=False, num_workers=2 ) # 小测试集:test_set3 test_small_dataset = CustomDataset(os.path.join(root, "test_set3"), transform=val_test_transform) test_small_loader = DataLoader( test_small_dataset, batch_size=batch_size, shuffle=False, num_workers=2 ) # 打印数据集统计(验证标签提取是否正确) print("=== 数据集统计 ===") for ds, name in zip( [train_datasets[0], train_datasets[1], val_dataset, test_dataset, test_small_dataset], ["training_set", "training_set2", "training_set3", "test_set", "test_set3"] ): cat_num = sum(1 for l in ds.labels if l == 0) dog_num = len(ds.labels) - cat_num print(f"{name}: 总图片数={len(ds)}, 猫={cat_num}, 狗={dog_num}") # 5. 自主编写VGG16模型(完全复现实验报告结构) class VGG16Custom(nn.Module): def __init__(self, num_classes=2): super(VGG16Custom, self).__init__() # 特征提取层(13个卷积 + 5个池化,严格匹配VGG16结构) self.features = nn.Sequential( # 第1组:2×Conv3-64 + MaxPool nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True), nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=2, stride=2), # 第2组:2×Conv3-128 + MaxPool nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True), nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=2, stride=2), # 第3组:3×Conv3-256 + MaxPool nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True), nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True), nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=2, stride=2), # 第4组:3×Conv3-512 + MaxPool nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True), nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True), nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=2, stride=2), # 第5组:3×Conv3-512 + MaxPool nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True), nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True), nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=2, stride=2) ) # 平均池化层(适配全连接层输入) self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) # 分类器(3个全连接层 + Dropout正则化) self.classifier = nn.Sequential( nn.Linear(512 * 7 * 7, 4096), nn.ReLU(inplace=True), nn.Dropout(p=0.5), nn.Linear(4096, 4096), nn.ReLU(inplace=True), nn.Dropout(p=0.5), nn.Linear(4096, num_classes) ) # 实验要求:冻结所有参数,仅解冻features[0]层(第一个卷积层) for param in self.parameters(): param.requires_grad = False for param in self.features[0].parameters(): param.requires_grad = True def forward(self, x): x = self.features(x) # 特征提取 x = self.avgpool(x) # 池化 x = torch.flatten(x, 1) # 展平(batch维度外的维度合并) x = self.classifier(x) # 分类 return x # 实例化模型并转移到设备(GPU/CPU) model = VGG16Custom(num_classes=num_classes).to(device) print(f"模型已加载至{device},仅解冻features[0]层(第一个卷积层)") # 6. 损失函数、优化器、学习率调度器 criterion = nn.CrossEntropyLoss() # 交叉熵损失(适配二分类) # 仅优化解冻的参数(features[0]层 + 分类器层) optimizer = optim.Adam( filter(lambda p: p.requires_grad, model.parameters()), lr=lr ) scheduler = StepLR(optimizer, step_size=3, gamma=gamma) # 每3轮学习率衰减为70% # 7. 训练函数 def train(model, device, train_loader, criterion, optimizer, epoch): model.train() # 启用训练模式(Dropout生效) train_loss = 0.0 correct = 0 total = 0 for batch_idx, (data, target) in enumerate(train_loader): data, target = data.to(device), target.to(device) # 前向传播 optimizer.zero_grad() # 清空梯度 output = model(data) loss = criterion(output, target) # 反向传播 + 参数更新 loss.backward() optimizer.step() # 统计损失与准确率 train_loss += loss.item() * data.size(0) _, predicted = output.max(1) total += target.size(0) correct += predicted.eq(target).sum().item() # 打印批次进度 if batch_idx % 10 == 0: progress = 100. * batch_idx / len(train_loader) print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ' f'({progress:.0f}%)]\tLoss: {loss.item():.6f}') # 计算平均损失与准确率 avg_loss = train_loss / len(train_loader.dataset) acc = 100. * correct / total print(f'Epoch {epoch} | 训练损失: {avg_loss:.6f} | 训练准确率: {acc:.2f}%\n') return avg_loss, acc # 8. 验证函数 def validate(model, device, val_loader, criterion): model.eval() # 启用评估模式(Dropout关闭) val_loss = 0.0 correct = 0 total = 0 with torch.no_grad(): # 禁用梯度计算(节省内存) for data, target in val_loader: data, target = data.to(device), target.to(device) output = model(data) loss = criterion(output, target) val_loss += loss.item() * data.size(0) _, predicted = output.max(1) total += target.size(0) correct += predicted.eq(target).sum().item() # 计算验证损失与准确率 avg_loss = val_loss / len(val_loader.dataset) acc = 100. * correct / total print(f'验证损失: {avg_loss:.6f} | 验证准确率: {acc:.2f}%\n') return avg_loss, acc # 9. 测试函数(输出Accuracy、Precision、Recall) def test(model, device, test_loader, test_name): model.eval() all_preds = [] # 存储所有预测结果 all_targets = [] # 存储所有真实标签 with torch.no_grad(): for data, target in test_loader: data, target = data.to(device), target.to(device) output = model(data) _, predicted = output.max(1) # 转移到CPU并转为numpy(适配sklearn指标计算) all_preds.extend(predicted.cpu().numpy()) all_targets.extend(target.cpu().numpy()) # 计算三大指标(正类=狗(1),负类=猫(0)) accuracy = accuracy_score(all_targets, all_preds) precision = precision_score( all_targets, all_preds, average='binary', zero_division=0 # zero_division处理无正类预测的情况 ) recall = recall_score( all_targets, all_preds, average='binary', zero_division=0 ) # 计算混淆矩阵(辅助分析错误类型) cm = confusion_matrix(all_targets, all_preds) # 打印结果 print(f"=== {test_name} 评价指标 ===") print(f"Accuracy(准确率): {accuracy:.4f} ({accuracy*100:.2f}%)") print(f"Precision(精确率): {precision:.4f}") print(f"Recall(召回率): {recall:.4f}") print(f"混淆矩阵:\n{cm}") print(f"TP(狗预测为狗): {cm[1,1]}, TN(猫预测为猫): {cm[0,0]}") print(f"FP(猫预测为狗): {cm[0,1]}, FN(狗预测为猫): {cm[1,0]}\n") return accuracy, precision, recall, cm # 10. 执行训练与测试 if __name__ == "__main__": # 记录训练过程(用于绘图) train_loss_history = [] train_acc_history = [] val_loss_history = [] val_acc_history = [] print("===== 开始训练 =====") for epoch in range(1, epochs + 1): train_loss, train_acc = train(model, device, train_loader, criterion, optimizer, epoch) val_loss, val_acc = validate(model, device, val_loader, criterion) scheduler.step() # 更新学习率 # 保存训练历史 train_loss_history.append(train_loss) train_acc_history.append(train_acc) val_loss_history.append(val_loss) val_acc_history.append(val_acc) print("===== 开始测试 =====") # 测试主测试集(test_set) test(model, device, test_loader, "主测试集(test_set)") # 测试小测试集(test_set3) test(model, device, test_small_loader, "小测试集(test_set3)") # 绘制训练/验证曲线(保存到数据集根目录) plt.figure(figsize=(12, 5)) # 子图1:损失曲线 plt.subplot(1, 2, 1) plt.plot(range(1, epochs+1), train_loss_history, label="训练损失", color="blue") plt.plot(range(1, epochs+1), val_loss_history, label="验证损失", color="red") plt.xlabel("轮次(Epoch)") plt.ylabel("损失(Loss)") plt.title("VGG16训练与验证损失曲线") plt.legend() plt.grid(alpha=0.3) # 子图2:准确率曲线 plt.subplot(1, 2, 2) plt.plot(range(1, epochs+1), train_acc_history, label="训练准确率", color="blue") plt.plot(range(1, epochs+1), val_acc_history, label="验证准确率", color="red") plt.xlabel("轮次(Epoch)") plt.ylabel("准确率(%)") plt.title("VGG16训练与验证准确率曲线") plt.legend() plt.grid(alpha=0.3) # 保存曲线 curve_path = os.path.join(root, "training_validation_curves.jpg") plt.tight_layout() plt.savefig(curve_path, dpi=300, bbox_inches="tight") plt.show() print(f"训练/验证曲线已保存至: {curve_path}") # 保存模型参数 model_path = os.path.join(root, "vgg16_custom_catdog_model.pth") torch.save(model.state_dict(), model_path) print(f"自主编写VGG16模型参数已保存至: {model_path}") 代码是否存在错误,是否有更优化的代码或者是更符合VGG16模型的代码,请给出我完整的代码
09-12
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值