PyTorch使用技巧1:F.dropout加self.training、F.log_softmax

本文主要介绍 PyTorch 的两个用法。一是 self.training:涉及 dropout 在训练和验证/测试阶段的不同用法,Module 的 training 属性,F.dropout 和 nn.Dropout 的实现方式,以及 training 属性如何受 train() 和 eval() 方法影响;二是 F.log_softmax:介绍其作用和函数调用格式。
部署运行你感兴趣的模型镜像

在这里插入图片描述

1、self.training

dropout 方法在训练时按伯努利分布以概率 p 将输入 Tensor 的元素随机置 0,并把保留下来的元素乘以 1/(1-p),使输出的期望与输入保持一致;具体原理此处不赘述,以后待补。总之,训练时要开启 dropout,验证/测试时要关闭 dropout。

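下面说明 Module 的 training 属性、F(torch.nn.functional).dropout 和 nn(torch.nn).Dropout 中相应操作的实现方式,以及 training 属性受 train() 和 eval() 方法影响而改变的机制:

- Module 的 training 属性:每个 nn.Module 都有一个布尔属性 self.training。model.train() 将其设为 True,model.eval()(等价于 model.train(False))将其设为 False,并且二者都会递归地作用于所有子模块。因此,如果每个 epoch 先训练再调用 model.eval() 做验证,就必须在下一个 epoch 训练开始前重新调用 model.train()。
- nn.Dropout:作为子模块,它的 forward 内部调用 F.dropout(input, self.p, self.training, self.inplace),因此会自动随 train()/eval() 切换。
- F.dropout:只是一个函数,无法得知所在模块的状态,且其 training 参数默认为 True。若在 forward 中只写 F.dropout(x, p),即使调用了 model.eval(),dropout 也仍然生效,所以必须显式传入 training=self.training。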

dropout 方法出自论文 Srivastava 等(2014),《Dropout: A Simple Way to Prevent Neural Networks from Overfitting》:https://www.jmlr.org/papers/volume15/srivastava14a/srivastava14a.pdf

参考:https://blog.youkuaiyun.com/PolarisRisingWar/article/details/117754981

2、F.log_softmax

F.softmax作用:
沿指定维度 dim 做归一化,使该维度上的元素均为非负且和为 1。
F.softmax函数调用格式:

# 对二维输入而言,dim=0 是对每一列做归一化,dim=1 是对每一行做归一化


F.softmax(x,dim=1) 或者 F.softmax(x,dim=0)

F.log_softmax作用:

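在 softmax 的结果上再做一次 log 运算,数学上等价于 log(softmax(x))。但 PyTorch 内部采用数值更稳定的实现,比先调用 F.softmax 再取 torch.log 更快,也不容易出现 log(0) 的问题。它通常与 F.nll_loss 搭配使用,二者组合等价于 F.cross_entropy。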

F.log_softmax函数调用格式:

F.log_softmax(x,dim=1) 或者 F.log_softmax(x,dim=0)

原文链接:https://blog.youkuaiyun.com/m0_51004308/article/details/118001835
参考:https://blog.youkuaiyun.com/m0_51004308/article/details/118001835

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

# Question: please explain the following code.
import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.functional as F


class GCN(nn.Module):
    """Two-layer graph convolutional network for node classification.

    Args:
        nfeat: number of input features per node.
        nhid: width of the hidden layer.
        nclass: number of output classes.
        dropout: drop probability applied between the two graph convolutions.
    """

    def __init__(self, nfeat, nhid, nclass, dropout):
        super().__init__()
        self.gc1 = GraphConvolution(nfeat, nhid)
        self.gc2 = GraphConvolution(nhid, nclass)
        self.dropout = dropout

    def forward(self, x, adj):
        """Return log-probabilities over ``nclass`` classes, one row per node."""
        x = F.relu(self.gc1(x, adj))
        # F.dropout defaults to training=True and knows nothing about the
        # module, so the module's own flag must be forwarded explicitly;
        # otherwise dropout would stay active after model.eval().
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.gc2(x, adj)
        # log_softmax pairs with F.nll_loss (together they equal cross-entropy).
        return F.log_softmax(x, dim=1)


class GraphConvolution(nn.Module):
    """Graph convolution layer computing ``adj @ (input @ weight) + bias``.

    Args:
        in_features: size of each input node feature vector.
        out_features: size of each output node feature vector.
        bias: if True, add a learnable bias (initialised to zero).
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Allocated uninitialised; reset_parameters() below fills them in.
        self.weight = nn.Parameter(torch.empty(in_features, out_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise the weight with Kaiming-uniform and the bias with zeros."""
        nn.init.kaiming_uniform_(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, input, adj):
        """Transform node features, then aggregate them with ``adj``.

        ``adj`` is passed to torch.spmm, so it is presumably a sparse
        (num_nodes x num_nodes) adjacency matrix — confirm with load_data().
        """
        support = torch.mm(input, self.weight)
        output = torch.spmm(adj, support)
        if self.bias is not None:
            return output + self.bias
        return output


def train_epoch(model, optimizer, features, adj, labels, idx_train):
    """Run one optimisation step on the training nodes.

    model.train() is called on every epoch on purpose: the validation pass
    switches the model to eval mode, and without this call dropout would
    remain disabled for every epoch after the first.

    Returns:
        (output, loss): log-probabilities for all nodes and the training loss.
    """
    model.train()
    optimizer.zero_grad()
    output = model(features, adj)
    loss = F.nll_loss(output[idx_train], labels[idx_train])
    loss.backward()
    optimizer.step()
    return output, loss


@torch.no_grad()
def evaluate(model, features, adj, labels, idx):
    """Evaluate in eval mode (dropout off) without building an autograd graph.

    Returns:
        (output, loss): log-probabilities for all nodes and the loss on ``idx``.
    """
    model.eval()
    output = model(features, adj)
    loss = F.nll_loss(output[idx], labels[idx])
    return output, loss


def main(epochs=200):
    """Train and validate a GCN, printing metrics after every epoch.

    ``load_data`` and ``accuracy`` are helpers defined outside this snippet.
    """
    # Load data
    adj, features, labels, idx_train, idx_val, idx_test = load_data()

    # Initialise the model and optimiser
    model = GCN(nfeat=features.shape[1], nhid=16,
                nclass=labels.max().item() + 1, dropout=0.5)
    optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

    for epoch in range(epochs):
        # Train (re-enters train mode every epoch)
        output, loss_train = train_epoch(model, optimizer, features, adj,
                                         labels, idx_train)
        acc_train = accuracy(output[idx_train], labels[idx_train])

        # Validate
        output, loss_val = evaluate(model, features, adj, labels, idx_val)
        acc_val = accuracy(output[idx_val], labels[idx_val])

        print('Epoch: {:04d}'.format(epoch + 1),
              'loss_train: {:.4f}'.format(loss_train.item()),
              'acc_train: {:.4f}'.format(acc_train.item()),
              'loss_val: {:.4f}'.format(loss_val.item()),
              'acc_val: {:.4f}'.format(acc_val.item()))


if __name__ == "__main__":
    main()
最新发布
12-06
# Question: please explain the following code.
import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.functional as F


class GCN(nn.Module):
    """Two-layer graph convolutional network for node classification.

    Args:
        nfeat: number of input features per node.
        nhid: width of the hidden layer.
        nclass: number of output classes.
        dropout: drop probability applied between the two graph convolutions.
    """

    def __init__(self, nfeat, nhid, nclass, dropout):
        super().__init__()
        self.gc1 = GraphConvolution(nfeat, nhid)
        self.gc2 = GraphConvolution(nhid, nclass)
        self.dropout = dropout

    def forward(self, x, adj):
        """Return log-probabilities over ``nclass`` classes, one row per node."""
        x = F.relu(self.gc1(x, adj))
        # F.dropout defaults to training=True and knows nothing about the
        # module, so the module's own flag must be forwarded explicitly;
        # otherwise dropout would stay active after model.eval().
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.gc2(x, adj)
        # log_softmax pairs with F.nll_loss (together they equal cross-entropy).
        return F.log_softmax(x, dim=1)


class GraphConvolution(nn.Module):
    """Graph convolution layer computing ``adj @ (input @ weight) + bias``.

    Args:
        in_features: size of each input node feature vector.
        out_features: size of each output node feature vector.
        bias: if True, add a learnable bias (initialised to zero).
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Allocated uninitialised; reset_parameters() below fills them in.
        self.weight = nn.Parameter(torch.empty(in_features, out_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise the weight with Kaiming-uniform and the bias with zeros."""
        nn.init.kaiming_uniform_(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, input, adj):
        """Transform node features, then aggregate them with ``adj``.

        ``adj`` is passed to torch.spmm, so it is presumably a sparse
        (num_nodes x num_nodes) adjacency matrix — confirm with load_data().
        """
        support = torch.mm(input, self.weight)
        output = torch.spmm(adj, support)
        if self.bias is not None:
            return output + self.bias
        return output


def train_epoch(model, optimizer, features, adj, labels, idx_train):
    """Run one optimisation step on the training nodes.

    model.train() is called on every epoch on purpose: the validation pass
    switches the model to eval mode, and without this call dropout would
    remain disabled for every epoch after the first.

    Returns:
        (output, loss): log-probabilities for all nodes and the training loss.
    """
    model.train()
    optimizer.zero_grad()
    output = model(features, adj)
    loss = F.nll_loss(output[idx_train], labels[idx_train])
    loss.backward()
    optimizer.step()
    return output, loss


@torch.no_grad()
def evaluate(model, features, adj, labels, idx):
    """Evaluate in eval mode (dropout off) without building an autograd graph.

    Returns:
        (output, loss): log-probabilities for all nodes and the loss on ``idx``.
    """
    model.eval()
    output = model(features, adj)
    loss = F.nll_loss(output[idx], labels[idx])
    return output, loss


def main(epochs=200):
    """Train and validate a GCN, printing metrics after every epoch.

    ``load_data`` and ``accuracy`` are helpers defined outside this snippet.
    """
    # Load data
    adj, features, labels, idx_train, idx_val, idx_test = load_data()

    # Initialise the model and optimiser
    model = GCN(nfeat=features.shape[1], nhid=16,
                nclass=labels.max().item() + 1, dropout=0.5)
    optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

    for epoch in range(epochs):
        # Train (re-enters train mode every epoch)
        output, loss_train = train_epoch(model, optimizer, features, adj,
                                         labels, idx_train)
        acc_train = accuracy(output[idx_train], labels[idx_train])

        # Validate
        output, loss_val = evaluate(model, features, adj, labels, idx_val)
        acc_val = accuracy(output[idx_val], labels[idx_val])

        print('Epoch: {:04d}'.format(epoch + 1),
              'loss_train: {:.4f}'.format(loss_train.item()),
              'acc_train: {:.4f}'.format(acc_train.item()),
              'loss_val: {:.4f}'.format(loss_val.item()),
              'acc_val: {:.4f}'.format(acc_val.item()))


if __name__ == "__main__":
    main()
12-06
import os import datetime import numpy as np import matplotlib.pyplot as plt import shutil from glob import glob from tqdm import tqdm from PIL import Image import csv import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as F from torch.utils.data import Dataset, DataLoader from torchvision import transforms from torch.utils.tensorboard import SummaryWriter from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, accuracy_score class AttentionUNet: def __init__(self, input_shape, classes, epochs, result_path, database_path, learning_rate=1e-4, batch_size=8, early_stopping_patience=10, lr_reduction_patience=5, lr_reduction_factor=0.5, dropout_rate=0.3, filters_base=32, kernel_size=3, activation='relu', output_activation='softmax', optimizer='adam', loss='cross_entropy', # 默认使用PyTorch支持的损失函数 metrics=['accuracy'], attention_mechanism='additive'): self.input_shape = input_shape self.num_classes = classes['class_number'] # 解析类别信息 self.background_gray = classes['bg'][0] self.background_name = classes['bg'][1] self.foreground_labels = classes['fg'] # 创建灰度值到类别索引的映射 self.gray_to_index = {self.background_gray: 0} # 背景映射到索引0 # 前景映射到索引1,2,... 
for idx, (gray_val, name) in enumerate(self.foreground_labels.items(), start=1): self.gray_to_index[gray_val] = idx # 验证映射 if len(self.gray_to_index) != self.num_classes: raise ValueError(f"类别数量不匹配: 配置的类别数={self.num_classes}, 实际映射的类别数={len(self.gray_to_index)}") self.epochs = epochs self.result_path = result_path self.database_path = database_path self.learning_rate = learning_rate self.batch_size = batch_size self.early_stopping_patience = early_stopping_patience self.lr_reduction_patience = lr_reduction_patience self.lr_reduction_factor = lr_reduction_factor self.dropout_rate = dropout_rate self.filters_base = filters_base self.kernel_size = kernel_size self.activation = activation self.output_activation = output_activation self.optimizer_type = optimizer self.loss_type = loss self.metrics = metrics self.attention_mechanism = attention_mechanism self.model = None self.history = None self.best_model_path = os.path.join(self.result_path, 'models', 'best_model.pth') self.temp_model_dir = os.path.join(self.result_path, 'models', 'temp_epochs') self.train_log_path = os.path.join(self.result_path, 'diagram', '训练日志.csv') self.best_model_performance_path = os.path.join(self.result_path, 'diagram', '最佳模型性能.txt') self.performance_plots_dir = os.path.join(self.result_path, 'diagram', 'performance_plots') self.tensorboard_dir = os.path.join(self.result_path, 'tensorboard') # 创建结果目录结构 self._create_result_directories() self._print_model_config() # 设置设备 self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"使用设备: {self.device}") def _create_result_directories(self): """创建结果文件目录结构并保存配置。""" dirs_to_create = [ self.result_path, os.path.join(self.result_path, 'models'), os.path.join(self.result_path, 'diagram'), self.temp_model_dir, # 确保包含这个目录 self.performance_plots_dir, self.tensorboard_dir ] for dir_path in dirs_to_create: os.makedirs(dir_path, exist_ok=True) print(f"目录已创建: {dir_path}") print(f"结果文件目录结构已创建:{self.result_path}") # 保存模型配置 config_path = 
os.path.join(self.result_path, 'models', 'configuration.txt') try: with open(config_path, 'w', encoding='utf-8') as f: f.write("Attention U-Net 模型配置 (PyTorch 实现)\n") f.write(f"保存时间: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") f.write("="*50 + "\n\n") f.write("核心参数:\n") f.write(f"输入尺寸: {self.input_shape}\n") f.write(f"类别数: {self.num_classes}\n") f.write(f"背景类别: {self.background_name} (灰度值: {self.background_gray})\n") f.write("前景类别:\n") for gray_val, name in self.foreground_labels.items(): f.write(f" {name}: 灰度值 {gray_val}\n") f.write("\n训练参数:\n") f.write(f"训练轮数: {self.epochs}\n") f.write(f"批次大小: {self.batch_size}\n") f.write(f"学习率: {self.learning_rate}\n") f.write(f"早停耐心值: {self.early_stopping_patience}\n") f.write(f"学习率衰减耐心值: {self.lr_reduction_patience}\n") f.write(f"学习率衰减因子: {self.lr_reduction_factor}\n") f.write("\n模型架构参数:\n") f.write(f"基础卷积核数量: {self.filters_base}\n") f.write(f"卷积核大小: {self.kernel_size}\n") f.write(f"激活函数: {self.activation}\n") f.write(f"输出激活函数: {self.output_activation}\n") f.write(f"Dropout比率: {self.dropout_rate}\n") f.write(f"注意力机制类型: {self.attention_mechanism}\n") f.write("\n优化参数:\n") f.write(f"优化器: {self.optimizer_type}\n") f.write(f"损失函数: {self.loss_type}\n") f.write(f"监控指标: {', '.join(self.metrics)}\n") print(f"模型配置已保存到: {config_path}") except Exception as e: print(f"保存配置时出错: {e}") def _print_model_config(self): """在控制台打印模型配置。""" print("\n" + "="*30) print("模型配置:") print(f"输入尺寸: {self.input_shape}") print(f"类别数: {self.num_classes}") print(f"背景类别: {self.background_name} (灰度值: {self.background_gray} -> 索引: {self.gray_to_index[self.background_gray]})") print("前景类别:") for gray_val, name in self.foreground_labels.items(): print(f" {name} (灰度值: {gray_val} -> 索引: {self.gray_to_index[gray_val]})") print(f"训练轮数: {self.epochs}") print(f"结果文件路径: {self.result_path}") print(f"数据集文件路径: {self.database_path}") print(f"学习率: {self.learning_rate}") print(f"批次大小: {self.batch_size}") print(f"早停耐心值: {self.early_stopping_patience}") 
print(f"学习率衰减耐心值: {self.lr_reduction_patience}") print(f"学习率衰减因子: {self.lr_reduction_factor}") print(f"Dropout比率: {self.dropout_rate}") print(f"基础卷积核数量: {self.filters_base}") print(f"卷积核大小: {self.kernel_size}") print(f"激活函数: {self.activation}") print(f"输出激活函数: {self.output_activation}") print(f"优化器: {self.optimizer_type}") print(f"损失函数: {self.loss_type}") print(f"监控指标: {self.metrics}") print(f"注意力机制类型: {self.attention_mechanism}") print("="*30 + "\n") class AttentionUNetModel(nn.Module): """PyTorch实现的Attention U-Net模型""" def __init__(self, input_channels, num_classes, filters_base=32, kernel_size=3, dropout_rate=0.3, attention_mechanism='additive', output_activation='softmax'): super().__init__() self.input_channels = input_channels self.num_classes = num_classes self.filters_base = filters_base self.kernel_size = kernel_size self.dropout_rate = dropout_rate self.attention_mechanism = attention_mechanism self.output_activation = output_activation # 添输出激活函数属性 # 编码器 (Encoder) self.enc1 = self._conv_block(input_channels, filters_base) self.enc2 = self._conv_block(filters_base, filters_base * 2) self.enc3 = self._conv_block(filters_base * 2, filters_base * 4) self.enc4 = self._conv_block(filters_base * 4, filters_base * 8) self.pool = nn.MaxPool2d(2) self.dropout = nn.Dropout2d(dropout_rate) # Bottleneck self.bottleneck = self._conv_block(filters_base * 8, filters_base * 16) # 解码器 (Decoder) self.up1 = self._up_block(filters_base * 16, filters_base * 8) self.att1 = self._attention_gate(filters_base * 8, filters_base * 8) self.dec1 = self._conv_block(filters_base * 16, filters_base * 8) self.up2 = self._up_block(filters_base * 8, filters_base * 4) self.att2 = self._attention_gate(filters_base * 4, filters_base * 4) self.dec2 = self._conv_block(filters_base * 8, filters_base * 4) self.up3 = self._up_block(filters_base * 4, filters_base * 2) self.att3 = self._attention_gate(filters_base * 2, filters_base * 2) self.dec3 = self._conv_block(filters_base * 4, filters_base * 2) 
self.up4 = self._up_block(filters_base * 2, filters_base) self.att4 = self._attention_gate(filters_base, filters_base) self.dec4 = self._conv_block(filters_base * 2, filters_base) # 输出层 self.out_conv = nn.Conv2d(filters_base, num_classes, kernel_size=1) def _conv_block(self, in_channels, out_channels): """标准卷积块,包含Conv2D, BatchNorm, Activation。""" return nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size=self.kernel_size, padding=self.kernel_size//2), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), nn.Conv2d(out_channels, out_channels, kernel_size=self.kernel_size, padding=self.kernel_size//2), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True) ) def _up_block(self, in_channels, out_channels): """上采样块,包含ConvTranspose2d和Dropout。""" return nn.Sequential( nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2), nn.Dropout2d(self.dropout_rate) ) def _attention_gate(self, g_channels, x_channels): """注意力门 (Attention Gate)。""" if self.attention_mechanism == 'additive': # 性注意力机制 return nn.Sequential( nn.Conv2d(g_channels, x_channels, kernel_size=1), nn.BatchNorm2d(x_channels), nn.ReLU(inplace=True), nn.Conv2d(x_channels, x_channels, kernel_size=1), nn.BatchNorm2d(x_channels), nn.Sigmoid() ) elif self.attention_mechanism == 'multiplicative': # 乘性注意力机制 return nn.Sequential( nn.Conv2d(g_channels + x_channels, x_channels, kernel_size=1), nn.BatchNorm2d(x_channels), nn.ReLU(inplace=True), nn.Conv2d(x_channels, 1, kernel_size=1), nn.BatchNorm2d(1), nn.Sigmoid() ) else: raise ValueError(f"不支持的注意力机制类型: {self.attention_mechanism}") def forward(self, x): # 编码器路径 enc1 = self.enc1(x) enc1_pool = self.pool(enc1) enc1_pool = self.dropout(enc1_pool) enc2 = self.enc2(enc1_pool) enc2_pool = self.pool(enc2) enc2_pool = self.dropout(enc2_pool) enc3 = self.enc3(enc2_pool) enc3_pool = self.pool(enc3) enc3_pool = self.dropout(enc3_pool) enc4 = self.enc4(enc3_pool) enc4_pool = self.pool(enc4) enc4_pool = self.dropout(enc4_pool) # 瓶颈层 bottleneck = 
self.bottleneck(enc4_pool) # 解码器路径 # 上采样块1 up1 = self.up1(bottleneck) if self.attention_mechanism == 'additive': att1 = self.att1(up1) att1 = att1 * enc4 else: att1 = self.att1(torch.cat([up1, enc4], dim=1)) att1 = att1 * enc4 merge1 = torch.cat([up1, att1], dim=1) dec1 = self.dec1(merge1) # 上采样块2 up2 = self.up2(dec1) if self.attention_mechanism == 'additive': att2 = self.att2(up2) att2 = att2 * enc3 else: att2 = self.att2(torch.cat([up2, enc3], dim=1)) att2 = att2 * enc3 merge2 = torch.cat([up2, att2], dim=1) dec2 = self.dec2(merge2) # 上采样块3 up3 = self.up3(dec2) if self.attention_mechanism == 'additive': att3 = self.att3(up3) att3 = att3 * enc2 else: att3 = self.att3(torch.cat([up3, enc2], dim=1)) att3 = att3 * enc2 merge3 = torch.cat([up3, att3], dim=1) dec3 = self.dec3(merge3) # 上采样块4 up4 = self.up4(dec3) if self.attention_mechanism == 'additive': att4 = self.att4(up4) att4 = att4 * enc1 else: att4 = self.att4(torch.cat([up4, enc1], dim=1)) att4 = att4 * enc1 merge4 = torch.cat([up4, att4], dim=1) dec4 = self.dec4(merge4) # 输出层 out = self.out_conv(dec4) # 应用输出激活函数 if self.output_activation == 'softmax': out = F.softmax(out, dim=1) elif self.output_activation == 'sigmoid': out = torch.sigmoid(out) return out def _dice_loss(self, y_pred, y_true, smooth=1e-6): """Dice损失函数实现""" # 展平预测和真实标签 y_pred_flat = y_pred.contiguous().view(-1) y_true_flat = y_true.contiguous().view(-1) # 计算交集和并集 intersection = (y_pred_flat * y_true_flat).sum() union = y_pred_flat.sum() + y_true_flat.sum() # 计算Dice系数 dice = (2. 
* intersection + smooth) / (union + smooth) # 返回Dice损失 return 1 - dice def _focal_loss(self, y_pred, y_true, alpha=0.25, gamma=2.0): """Focal损失函数实现""" # 计算交叉熵 ce_loss = F.binary_cross_entropy(y_pred, y_true, reduction='none') # 计算概率 p_t = y_true * y_pred + (1 - y_true) * (1 - y_pred) # 计算Focal损失 focal_loss = torch.pow(1 - p_t, gamma) * ce_loss # 应用alpha权重 if alpha is not None: alpha_t = y_true * alpha + (1 - y_true) * (1 - alpha) focal_loss = alpha_t * focal_loss return focal_loss.mean() def _jaccard_loss(self, y_pred, y_true, smooth=1e-6): """Jaccard损失函数实现(IoU损失)""" # 展平预测和真实标签 y_pred_flat = y_pred.contiguous().view(-1) y_true_flat = y_true.contiguous().view(-1) # 计算交集和并集 intersection = (y_pred_flat * y_true_flat).sum() total = y_pred_flat.sum() + y_true_flat.sum() union = total - intersection # 计算Jaccard指数(IoU) iou = (intersection + smooth) / (union + smooth) return 1 - iou def build_model(self): """构建Attention U-Net模型。""" input_channels = self.input_shape[2] if len(self.input_shape) == 3 else 3 self.model = self.AttentionUNetModel( input_channels=input_channels, num_classes=self.num_classes, filters_base=self.filters_base, kernel_size=self.kernel_size, dropout_rate=self.dropout_rate, attention_mechanism=self.attention_mechanism, output_activation=self.output_activation # 添输出激活函数参数 ).to(self.device) print("Attention U-Net 模型已成功构建。") print(f"模型参数数量: {sum(p.numel() for p in self.model.parameters() if p.requires_grad):,}") # 选择优化器 if self.optimizer_type.lower() == 'adam': optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate) elif self.optimizer_type.lower() == 'sgd': optimizer = optim.SGD(self.model.parameters(), lr=self.learning_rate, momentum=0.9) elif self.optimizer_type.lower() == 'rmsprop': optimizer = optim.RMSprop(self.model.parameters(), lr=self.learning_rate) else: raise ValueError(f"不支持的优化器类型: {self.optimizer_type}") # 选择损失函数 if self.loss_type.lower() == 'dice_loss': criterion = self._dice_loss elif self.loss_type.lower() == 'focal_loss': 
criterion = self._focal_loss elif self.loss_type.lower() == 'jaccard_loss': criterion = self._jaccard_loss elif self.loss_type.lower() in ['cross_entropy', 'categorical_crossentropy']: # 支持两种名称 criterion = nn.CrossEntropyLoss() else: raise ValueError(f"不支持的损失函数类型: {self.loss_type}") return self.model, optimizer, criterion class SegmentationDataset(Dataset): """图像分割数据集类""" def __init__(self, image_dir, mask_dir, input_shape, gray_to_index, num_classes): self.image_dir = image_dir self.mask_dir = mask_dir self.input_shape = input_shape self.gray_to_index = gray_to_index self.num_classes = num_classes self.image_files = sorted(glob(os.path.join(image_dir, '*'))) self.mask_files = sorted(glob(os.path.join(mask_dir, '*'))) if not self.image_files: raise FileNotFoundError(f"在 {image_dir} 中未找到任何图像文件") if not self.mask_files: raise FileNotFoundError(f"在 {mask_dir} 中未找到任何掩码文件") if len(self.image_files) != len(self.mask_files): print(f"警告: 数据集中图像文件数量 ({len(self.image_files)}) 与掩码文件数量 ({len(self.mask_files)}) 不匹配") self.transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) def __len__(self): return min(len(self.image_files), len(self.mask_files)) def __getitem__(self, idx): # 载图像 img_path = self.image_files[idx] img = Image.open(img_path).convert('RGB') img = img.resize((self.input_shape[1], self.input_shape[0])) img = self.transform(img) # 载掩码 mask_path = self.mask_files[idx] mask = Image.open(mask_path).convert('L') mask = mask.resize((self.input_shape[1], self.input_shape[0]), Image.NEAREST) mask = np.array(mask) # 创建类别索引掩码 (整数类型) processed_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int64) for gray_val, class_idx in self.gray_to_index.items(): processed_mask[mask == gray_val] = class_idx return img, processed_mask def _load_datasets(self): """载训练、验证和测试数据集""" # 定义数据集路径 train_image_dir = os.path.join(self.database_path, 'train', 'images') train_mask_dir = 
os.path.join(self.database_path, 'train', 'masks') val_image_dir = os.path.join(self.database_path, 'val', 'images') val_mask_dir = os.path.join(self.database_path, 'val', 'masks') test_image_dir = os.path.join(self.database_path, 'test', 'images') test_mask_dir = os.path.join(self.database_path, 'test', 'masks') # 创建数据集实例 train_dataset = self.SegmentationDataset( train_image_dir, train_mask_dir, self.input_shape, self.gray_to_index, self.num_classes ) val_dataset = self.SegmentationDataset( val_image_dir, val_mask_dir, self.input_shape, self.gray_to_index, self.num_classes ) test_dataset = self.SegmentationDataset( test_image_dir, test_mask_dir, self.input_shape, self.gray_to_index, self.num_classes ) # 创建数据载器 train_loader = DataLoader( train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=4 ) val_loader = DataLoader( val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=4 ) test_loader = DataLoader( test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=4 ) # 打印数据集信息 print(f"\n数据集载完成:") print(f" 训练集: {len(train_dataset)} 个样本") print(f" 验证集: {len(val_dataset)} 个样本") print(f" 测试集: {len(test_dataset)} 个样本") return train_loader, val_loader, test_loader def _calculate_metrics(self, y_true, y_pred, smooth=1e-6, verbose=True): """ 计算分割性能指标。 Args: y_true (np.array): 真实标签 (类别索引)。 y_pred (np.array): 预测结果 (概率)。 smooth (float): 防止除以零的小常数。 verbose (bool): 是否打印详细指标。 Returns: dict: 包含所有计算的指标。 """ if y_true.size == 0 or y_pred.size == 0: print("警告: 真实标签或预测结果为空,无法计算指标。") return {} # 将预测结果转换为类别索引 y_pred_argmax = np.argmax(y_pred, axis=1) # 展平图像以进行指标计算 y_pred_flat = y_pred_argmax.flatten() y_true_flat = y_true.flatten() metrics = {} epsilon = 1e-7 # 防止除以零 # 1. 
全局Dice系数 (针对所有前景和背景像素点) # 背景索引 bg_index = self.gray_to_index[self.background_gray] # 将所有前景类别合并为一个"前景"类别 y_true_fg_flat = (y_true_flat != bg_index) y_pred_fg_flat = (y_pred_flat != bg_index) # True Positives (Global Foreground): 真实为前景,预测为前景 TP_global = np.sum(y_true_fg_flat & y_pred_fg_flat) # False Positives (Global Foreground): 真实为背景,预测为前景 FP_global = np.sum(~y_true_fg_flat & y_pred_fg_flat) # False Negatives (Global Foreground): 真实为前景,预测为背景 FN_global = np.sum(y_true_fg_flat & ~y_pred_fg_flat) # True Negatives (Global Background): 真实为背景,预测为背景 TN_global = np.sum(~y_true_fg_flat & ~y_pred_fg_flat) dice_global = (2. * TP_global) / (2 * TP_global + FP_global + FN_global + epsilon) metrics['Dice_Global'] = dice_global accuracy_global = (TP_global + TN_global) / (TP_global + TN_global + FP_global + FN_global + epsilon) metrics['Accuracy_Global'] = accuracy_global if verbose: print(f" 全局 Dice 系数: {metrics['Dice_Global']:.4f}") print(f" 全局准确率: {metrics['Accuracy_Global']:.4f}") # 2. 针对每个类别的像素点计算指标 for gray_val, class_idx in self.gray_to_index.items(): class_name = self.background_name if gray_val == self.background_gray else self.foreground_labels[gray_val] # 提取当前类别的TP, FP, FN, TN TP_class = np.sum((y_true_flat == class_idx) & (y_pred_flat == class_idx)) FP_class = np.sum((y_true_flat != class_idx) & (y_pred_flat == class_idx)) FN_class = np.sum((y_true_flat == class_idx) & (y_pred_flat != class_idx)) TN_class = np.sum((y_true_flat != class_idx) & (y_pred_flat != class_idx)) dice_class = (2. 
* TP_class) / (2 * TP_class + FP_class + FN_class + epsilon) metrics[f'{class_name}_Dice'] = dice_class accuracy_class = (TP_class + TN_class) / (TP_class + TN_class + FP_class + FN_class + epsilon) metrics[f'{class_name}_Accuracy'] = accuracy_class iou_class = TP_class / (TP_class + FP_class + FN_class + epsilon) metrics[f'{class_name}_IoU'] = iou_class precision_class = TP_class / (TP_class + FP_class + epsilon) metrics[f'{class_name}_Precision'] = precision_class recall_class = TP_class / (TP_class + FN_class + epsilon) metrics[f'{class_name}_Recall'] = recall_class specificity_class = TN_class / (TN_class + FP_class + epsilon) metrics[f'{class_name}_Specificity'] = specificity_class if verbose: print(f"\n --- 类别: {class_name} (灰度值: {gray_val}, 索引: {class_idx}) ---") print(f" Dice: {metrics[f'{class_name}_Dice']:.4f}") print(f" Accuracy: {metrics[f'{class_name}_Accuracy']:.4f}") print(f" IoU: {metrics[f'{class_name}_IoU']:.4f}") print(f" Precision: {metrics[f'{class_name}_Precision']:.4f}") print(f" Recall: {metrics[f'{class_name}_Recall']:.4f}") print(f" Specificity: {metrics[f'{class_name}_Specificity']:.4f}") return metrics def _select_best_model(self, all_metrics): """ 选择最佳模型的筛选规则(可自定义) Args: all_metrics (list): 包含所有epoch模型性能指标的列表 Returns: int: 最佳模型对应的epoch编号 """ # 默认规则:使用全局Dice系数作为主要指标,选择最高值的模型 best_epoch = -1 best_metric_value = -1 # 遍历所有模型的评估结果 for metrics in all_metrics: # 使用全局Dice系数作为选择标准 current_value = metrics.get('Dice_Global', -1) # 如果当前模型性能更好,更新最佳模型 if current_value > best_metric_value: best_metric_value = current_value best_epoch = metrics['Epoch'] print(f"\n最佳模型筛选结果: Epoch {best_epoch} (全局Dice系数: {best_metric_value:.4f})") return best_epoch def train(self): """ 训练Attention U-Net模型。 """ # 确保所有结果目录都存在 self._create_result_directories() # 确保所有目录已创建 # 构建模型 model, optimizer, criterion = self.build_model() # 载数据集 train_loader, val_loader, test_loader = self._load_datasets() # 初始化TensorBoard writer = SummaryWriter(self.tensorboard_dir) # 初始化变量 
best_val_loss = float('inf') epochs_without_improvement = 0 all_metrics = [] # 确保日志目录存在 os.makedirs(os.path.dirname(self.train_log_path), exist_ok=True) # 准备训练日志CSV with open(self.train_log_path, 'w', newline='', encoding='utf-8') as csvfile: fieldnames = ['Epoch', 'Train_Loss', 'Val_Loss', 'Dice_Global', 'Accuracy_Global'] # 添每个类别的指标字段 for gray_val in self.gray_to_index: class_name = self.background_name if gray_val == self.background_gray else self.foreground_labels[gray_val] fieldnames.extend([ f'{class_name}_Dice', f'{class_name}_Accuracy', f'{class_name}_IoU', f'{class_name}_Precision', f'{class_name}_Recall', f'{class_name}_Specificity' ]) writer_csv = csv.DictWriter(csvfile, fieldnames=fieldnames) writer_csv.writeheader() print("\n开始训练模型...") for epoch in range(1, self.epochs + 1): # 训练阶段 model.train() train_loss = 0.0 for images, masks in tqdm(train_loader, desc=f"Epoch {epoch}/{self.epochs} [训练]"): images = images.to(self.device) masks = masks.to(self.device) # 前向传播 outputs = model(images) # 计算损失 if isinstance(criterion, nn.CrossEntropyLoss): # 对于CrossEntropyLoss,直接使用类别索引 loss = criterion(outputs, masks.long()) else: # 对于其他损失函数,需要将掩码转换为one-hot编码 masks_one_hot = F.one_hot(masks.long(), num_classes=self.num_classes).permute(0, 3, 1, 2).float() loss = criterion(outputs, masks_one_hot) # 反向传播和优化 optimizer.zero_grad() loss.backward() optimizer.step() train_loss += loss.item() * images.size(0) # 计算平均训练损失 train_loss = train_loss / len(train_loader.dataset) # 验证阶段 model.eval() val_loss = 0.0 all_val_preds = [] all_val_masks = [] with torch.no_grad(): for images, masks in tqdm(val_loader, desc=f"Epoch {epoch}/{self.epochs} [验证]"): images = images.to(self.device) masks = masks.to(self.device) # 前向传播 outputs = model(images) # 计算损失 if isinstance(criterion, nn.CrossEntropyLoss): loss = criterion(outputs, masks.long()) else: masks_one_hot = F.one_hot(masks.long(), num_classes=self.num_classes).permute(0, 3, 1, 2).float() loss = criterion(outputs, masks_one_hot) val_loss 
+= loss.item() * images.size(0) # 收集预测结果和真实标签 all_val_preds.append(outputs.cpu().numpy()) all_val_masks.append(masks.cpu().numpy()) # 计算平均验证损失 val_loss = val_loss / len(val_loader.dataset) # 合并所有验证集的预测结果和真实标签 val_preds = np.concatenate(all_val_preds, axis=0) val_masks = np.concatenate(all_val_masks, axis=0) # 计算验证指标 val_metrics = self._calculate_metrics(val_masks, val_preds, verbose=False) val_metrics['Epoch'] = epoch val_metrics['Train_Loss'] = train_loss val_metrics['Val_Loss'] = val_loss # 记录到TensorBoard writer.add_scalar('Loss/Train', train_loss, epoch) writer.add_scalar('Loss/Validation', val_loss, epoch) writer.add_scalar('Metrics/Dice_Global', val_metrics['Dice_Global'], epoch) writer.add_scalar('Metrics/Accuracy_Global', val_metrics['Accuracy_Global'], epoch) # 保存指标到列表 all_metrics.append(val_metrics) # 确保临时目录存在再保存模型 os.makedirs(self.temp_model_dir, exist_ok=True) # 关键修复:确保目录存在 # 保存当前epoch的模型 epoch_model_path = os.path.join(self.temp_model_dir, f'epoch_{epoch:03d}.pth') torch.save({ 'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'train_loss': train_loss, 'val_loss': val_loss, 'val_metrics': val_metrics }, epoch_model_path) print(f"Epoch {epoch:03d}: 模型已保存到 {epoch_model_path}") # 写入CSV日志 - 确保目录存在 os.makedirs(os.path.dirname(self.train_log_path), exist_ok=True) with open(self.train_log_path, 'a', newline='', encoding='utf-8') as csvfile: writer_csv = csv.DictWriter(csvfile, fieldnames=fieldnames) writer_csv.writerow(val_metrics) # 打印训练进度 print(f"Epoch {epoch}/{self.epochs} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, " f"Val Dice Global: {val_metrics['Dice_Global']:.4f}, " f"Val Accuracy Global: {val_metrics['Accuracy_Global']:.4f}") # 学习率调度 if epoch > 1 and val_loss < best_val_loss: best_val_loss = val_loss epochs_without_improvement = 0 else: epochs_without_improvement += 1 if epochs_without_improvement >= self.lr_reduction_patience: # 降低学习率 for param_group in optimizer.param_groups: 
param_group['lr'] *= self.lr_reduction_factor print(f"Epoch {epoch}: 降低学习率至 {optimizer.param_groups[0]['lr']:.2e}") epochs_without_improvement = 0 # 早停检查 if epochs_without_improvement >= self.early_stopping_patience: print(f"Epoch {epoch}: 早停触发,停止训练") break # 关闭TensorBoard写入器 writer.close() # 训练完成后选择最佳模型 self._select_and_save_best_model(all_metrics) self._plot_performance_curves() self._test_best_model(test_loader) def _select_and_save_best_model(self, all_metrics): """选择并保存最佳模型""" best_epoch = self._select_best_model(all_metrics) if best_epoch > 0: # 载最佳模型 best_model_path = os.path.join(self.temp_model_dir, f'epoch_{best_epoch:03d}.pth') if os.path.exists(best_model_path): # 确保目标目录存在 os.makedirs(os.path.dirname(self.best_model_path), exist_ok=True) # 关键修复 # 复制到最终位置 shutil.copyfile(best_model_path, self.best_model_path) print(f"已将最佳模型 (Epoch {best_epoch}) 保存到: {self.best_model_path}") else: print(f"警告: 找不到最佳模型对应的文件 (Epoch {best_epoch})") else: print("警告: 没有找到有效的最佳模型") # 保存最佳模型性能 if best_epoch > 0: best_metrics = next(m for m in all_metrics if m['Epoch'] == best_epoch) with open(self.best_model_performance_path, 'w', encoding='utf-8') as f: f.write("Attention U-Net 最佳模型验证性能\n") f.write(f"保存时间: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") f.write(f"最佳模型路径: {self.best_model_path}\n") f.write(f"对应Epoch: {best_epoch}\n") f.write("-" * 50 + "\n") for key, value in best_metrics.items(): if isinstance(value, float): f.write(f"{key}: {value:.4f}\n") else: f.write(f"{key}: {value}\n") print(f"最佳模型的验证性能参数已记录在 '{self.best_model_performance_path}'") # 删除临时模型目录 if os.path.exists(self.temp_model_dir): print(f"删除临时模型文件目录:{self.temp_model_dir}") shutil.rmtree(self.temp_model_dir) def _test_best_model(self, test_loader): if not os.path.exists(self.best_model_path): print(f"错误:未找到最佳模型文件: {self.best_model_path}。请确保模型已成功训练并保存。") return print("\n正在载最佳模型进行测试...") try: # 修复模型载问题 - 添安全上下文管理器 import torch.serialization from numpy import scalar # 导入需要的NumPy类型 # 创建安全上下文载模型 with 
torch.serialization.safe_globals([scalar]): checkpoint = torch.load( self.best_model_path, map_location=self.device ) # 重建模型结构 input_channels = self.input_shape[2] if len(self.input_shape) == 3 else 3 model = self.AttentionUNetModel( input_channels=input_channels, num_classes=self.num_classes, filters_base=self.filters_base, kernel_size=self.kernel_size, dropout_rate=self.dropout_rate, attention_mechanism=self.attention_mechanism, output_activation=self.output_activation ).to(self.device) # 载状态字典 model.load_state_dict(checkpoint['model_state_dict']) model.eval() except Exception as e: print(f"错误: 无法载最佳模型 {self.best_model_path}. 错误: {e}") return # 进行测试 all_test_preds = [] all_test_masks = [] with torch.no_grad(): for images, masks in tqdm(test_loader, desc="测试最佳模型"): images = images.to(self.device) masks = masks.to(self.device) outputs = model(images) all_test_preds.append(outputs.cpu().numpy()) all_test_masks.append(masks.cpu().numpy()) # 合并结果 test_preds = np.concatenate(all_test_preds, axis=0) test_masks = np.concatenate(all_test_masks, axis=0) # 计算指标 test_metrics = self._calculate_metrics(test_masks, test_preds, verbose=True) # 保存测试性能 with open(self.best_model_performance_path, 'a', encoding='utf-8') as f: f.write("\n" + "=" * 50 + "\n") f.write("测试集性能:\n") for key, value in test_metrics.items(): if isinstance(value, float): f.write(f"{key}: {value:.4f}\n") else: f.write(f"{key}: {value}\n") print(f"最佳模型的测试性能参数已追到 '{self.best_model_performance_path}'") def _plot_performance_curves(self): """ 根据训练日志数据绘制参数-迭代次数曲线 """ print("\n正在绘制性能曲线图...") if not os.path.exists(self.train_log_path): print(f"错误:未找到训练日志文件: {self.train_log_path},无法绘制曲线图。") return # 确保目录存在 os.makedirs(self.performance_plots_dir, exist_ok=True) # 从CSV文件读取数据 epochs = [] data = {} # 读取CSV文件 with open(self.train_log_path, 'r', encoding='utf-8') as csvfile: reader = csv.DictReader(csvfile) for row in reader: epochs.append(int(row['Epoch'])) for key, value in row.items(): if key == 'Epoch': continue if key 
not in data: data[key] = [] try: data[key].append(float(value)) except: data[key].append(0.0) if not epochs: print("训练日志中没有足够的有效数据来绘制曲线。") return # 获取所有类别名称(包括背景和前景) all_classes = set() for key in data.keys(): # 只处理包含下划线且不以'Val'或'Train'开头的键 if '_' in key and not key.startswith(('Val_', 'Train_')): # 提取类别名称(例如:'background_Dice' -> 'background') class_name = key.split('_')[0] # 确保类别名称在已知类别中 if class_name in [self.background_name] + list(self.foreground_labels.values()): all_classes.add(class_name) # 1. 损失曲线 plt.figure(figsize=(10, 6)) plt.plot(epochs, data['Train_Loss'], label='训练损失', color='blue', linewidth=2) plt.plot(epochs, data['Val_Loss'], label='验证损失', color='red', linewidth=2) plt.title('训练和验证损失 vs. 迭代次数', fontsize=14) plt.xlabel('迭代次数', fontsize=12) plt.ylabel('损失值', fontsize=12) plt.xticks(range(min(epochs), max(epochs)+1, max(1, len(epochs)//10))) plt.legend(fontsize=12) plt.grid(True) plt.tight_layout() plt.savefig(os.path.join(self.performance_plots_dir, '损失迭代曲线.png')) plt.close() # 2. 全局Dice和Accuracy曲线 plt.figure(figsize=(10, 6)) plt.plot(epochs, data['Dice_Global'], label='全局Dice系数', color='blue', linewidth=2) plt.plot(epochs, data['Accuracy_Global'], label='全局准确率', color='red', linewidth=2) plt.title('全局Dice系数和准确率 vs. 迭代次数', fontsize=14) plt.xlabel('迭代次数', fontsize=12) plt.ylabel('值', fontsize=12) plt.xticks(range(min(epochs), max(epochs)+1, max(1, len(epochs)//10))) plt.legend(fontsize=12) plt.grid(True) plt.tight_layout() plt.savefig(os.path.join(self.performance_plots_dir, '全局Dice和准确率迭代曲线.png')) plt.close() # 3. 各类别IoU曲线 plt.figure(figsize=(10, 6)) for class_name in all_classes: plt.plot(epochs, data[f'{class_name}_IoU'], label=f'{class_name} IoU', linewidth=2) plt.title('各类IoU vs. 
迭代次数', fontsize=14) plt.xlabel('迭代次数', fontsize=12) plt.ylabel('IoU值', fontsize=12) plt.xticks(range(min(epochs), max(epochs)+1, max(1, len(epochs)//10))) plt.legend(fontsize=12) plt.grid(True) plt.tight_layout() plt.savefig(os.path.join(self.performance_plots_dir, '各类IoU迭代曲线.png')) plt.close() # 4. 各类别Precision曲线 plt.figure(figsize=(10, 6)) for class_name in all_classes: plt.plot(epochs, data[f'{class_name}_Precision'], label=f'{class_name} 精确率', linewidth=2) plt.title('各类精确率 vs. 迭代次数', fontsize=14) plt.xlabel('迭代次数', fontsize=12) plt.ylabel('精确率值', fontsize=12) plt.xticks(range(min(epochs), max(epochs)+1, max(1, len(epochs)//10))) plt.legend(fontsize=12) plt.grid(True) plt.tight_layout() plt.savefig(os.path.join(self.performance_plots_dir, '各类精确率迭代曲线.png')) plt.close() # 5. 各类别Recall曲线 plt.figure(figsize=(10, 6)) for class_name in all_classes: plt.plot(epochs, data[f'{class_name}_Recall'], label=f'{class_name} 召回率', linewidth=2) plt.title('各类召回率 vs. 迭代次数', fontsize=14) plt.xlabel('迭代次数', fontsize=12) plt.ylabel('召回率值', fontsize=12) plt.xticks(range(min(epochs), max(epochs)+1, max(1, len(epochs)//10))) plt.legend(fontsize=12) plt.grid(True) plt.tight_layout() plt.savefig(os.path.join(self.performance_plots_dir, '各类召回率迭代曲线.png')) plt.close() # 6. 各类别Specificity曲线 plt.figure(figsize=(10, 6)) for class_name in all_classes: plt.plot(epochs, data[f'{class_name}_Specificity'], label=f'{class_name} 特异性', linewidth=2) plt.title('各类特异性 vs. 迭代次数', fontsize=14) plt.xlabel('迭代次数', fontsize=12) plt.ylabel('特异性值', fontsize=12) plt.xticks(range(min(epochs), max(epochs)+1, max(1, len(epochs)//10))) plt.legend(fontsize=12) plt.grid(True) plt.tight_layout() plt.savefig(os.path.join(self.performance_plots_dir, '各类特异性迭代曲线.png')) plt.close() # 7. 各类别Accuracy曲线 plt.figure(figsize=(10, 6)) for class_name in all_classes: plt.plot(epochs, data[f'{class_name}_Accuracy'], label=f'{class_name} 准确率', linewidth=2) plt.title('各类准确率 vs. 
迭代次数', fontsize=14) plt.xlabel('迭代次数', fontsize=12) plt.ylabel('准确率值', fontsize=12) plt.xticks(range(min(epochs), max(epochs)+1, max(1, len(epochs)//10))) plt.legend(fontsize=12) plt.grid(True) plt.tight_layout() plt.savefig(os.path.join(self.performance_plots_dir, '各类准确率迭代曲线.png')) plt.close() # 8. 各类别Dice曲线 plt.figure(figsize=(10, 6)) for class_name in all_classes: plt.plot(epochs, data[f'{class_name}_Dice'], label=f'{class_name} Dice', linewidth=2) plt.title('各类Dice系数 vs. 迭代次数', fontsize=14) plt.xlabel('迭代次数', fontsize=12) plt.ylabel('Dice值', fontsize=12) plt.xticks(range(min(epochs), max(epochs)+1, max(1, len(epochs)//10))) plt.legend(fontsize=12) plt.grid(True) plt.tight_layout() plt.savefig(os.path.join(self.performance_plots_dir, '各类Dice迭代曲线.png')) plt.close() # 9. 每个类别的Dice和IoU曲线(n张图) for class_name in all_classes: plt.figure(figsize=(10, 6)) plt.plot(epochs, data[f'{class_name}_Dice'], label=f'{class_name} Dice', color='blue', linewidth=2) plt.plot(epochs, data[f'{class_name}_IoU'], label=f'{class_name} IoU', color='red', linewidth=2) plt.title(f'{class_name}: Dice和IoU vs. 迭代次数', fontsize=14) plt.xlabel('迭代次数', fontsize=12) plt.ylabel('值', fontsize=12) plt.xticks(range(min(epochs), max(epochs)+1, max(1, len(epochs)//10))) plt.legend(fontsize=12) plt.grid(True) plt.tight_layout() plt.savefig(os.path.join(self.performance_plots_dir, f'{class_name}_Dice+IoU迭代曲线.png')) plt.close() # 10. 每个类别的Precision和Recall曲线(n张图) for class_name in all_classes: plt.figure(figsize=(10, 6)) plt.plot(epochs, data[f'{class_name}_Precision'], label=f'{class_name} 精确率', color='blue', linewidth=2) plt.plot(epochs, data[f'{class_name}_Recall'], label=f'{class_name} 召回率', color='red', linewidth=2) plt.title(f'{class_name}: 精确率和召回率 vs. 
迭代次数', fontsize=14) plt.xlabel('迭代次数', fontsize=12) plt.ylabel('值', fontsize=12) plt.xticks(range(min(epochs), max(epochs)+1, max(1, len(epochs)//10))) plt.legend(fontsize=12) plt.grid(True) plt.tight_layout() plt.savefig(os.path.join(self.performance_plots_dir, f'{class_name}_精确率+召回率迭代曲线.png')) plt.close() print(f"性能曲线图已保存到 '{self.performance_plots_dir}'") print(f"总共生成 {8 + 2*len(all_classes)} 张曲线图") print("\n所有任务已完成!") #------------------------------------- 示例用法 -------------------------------------------- # 定义模型基础参数 input_shape = (512, 512, 3) classes = { 'class_number': 3, 'bg': (2, 'background'), 'fg': {0: 'cup', 1: 'disc'} } epochs = 5 result_base_path = '/content/drive/MyDrive/results' database_base_path = '/content/drive/MyDrive/database' # 实例化模型类 unet_model = AttentionUNet( input_shape=input_shape, classes=classes, epochs=epochs, result_path=result_base_path, database_path=database_base_path, learning_rate=1e-4, batch_size=5, early_stopping_patience=3, lr_reduction_patience=2, dropout_rate=0.5, filters_base=16, attention_mechanism='additive', loss='cross_entropy' # 修改为PyTorch支持的损失函数 ) # 构建并训练模型 unet_model.train() #------------------------------------------------------------------------------------------ 上述代码运行时报错: 正在绘制性能曲线图... --------------------------------------------------------------------------- KeyError Traceback (most recent call last) <ipython-input-8-2018641610> in <cell line: 0>() 1143 1144 # 构建并训练模型 -> 1145 unet_model.train() 1146 #------------------------------------------------------------------------------------------ 1 frames <ipython-input-8-2018641610> in _plot_performance_curves(self) 941 reader = csv.DictReader(csvfile) 942 for row in reader: --> 943 epochs.append(int(row['Epoch'])) 944 for key, value in row.items(): 945 if key == 'Epoch': KeyError: 'Epoch' 请修改
06-17
import sys
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from PyQt5.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout,
                             QLabel, QPushButton, QFileDialog, QGroupBox, QScrollArea,
                             QProgressBar, QMessageBox)
from PyQt5.QtGui import QPixmap, QImage, QFont, QIcon
from PyQt5.QtCore import Qt, QThread, pyqtSignal


class Net(nn.Module):
    """Small CNN for 28x28 single-channel digit images; returns per-class log-probabilities."""

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)   # 1 -> 32 channels, 3x3 kernel, stride 1
        self.conv2 = nn.Conv2d(32, 64, 3, 1)  # 32 -> 64 channels, 3x3 kernel, stride 1
        # Channel-wise dropout on the 4-D conv feature maps (25% drop rate).
        self.dropout1 = nn.Dropout2d(0.25)
        # This layer runs after torch.flatten, i.e. on a 2-D (N, 128) tensor.
        # Dropout2d is meant for (N, C, H, W) / (C, H, W) inputs and PyTorch
        # deprecates 2-D input to it, so plain element-wise Dropout is used.
        # Dropout layers hold no parameters, so state_dict keys are unchanged
        # and previously saved 'mnist_cnn.pt' files still load.
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)  # 64 channels * 12 * 12 after conv/conv/pool on 28x28 input
        self.fc2 = nn.Linear(128, 10)    # one output per digit 0-9

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)   # 2x2 max pooling
        x = self.dropout1(x)
        x = torch.flatten(x, 1)  # flatten everything except the batch dimension
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        # log_softmax pairs with F.nll_loss in training; softmax of this output
        # equals softmax of the raw logits, which the GUI relies on.
        return F.log_softmax(x, dim=1)


# Single model instance shared by the GUI thread (inference) and TrainThread (training).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = Net().to(device)
net.eval()  # start in inference mode; TrainThread switches to train() per epoch


class TrainThread(QThread):
    """Background thread that trains the shared `net` on MNIST and saves it to 'mnist_cnn.pt'.

    Completion is reported through QThread's built-in ``finished`` signal, which Qt
    emits after run() returns (on success and on failure alike); ``succeeded``
    tells the receiver which case occurred.
    """

    progress = pyqtSignal(int)  # percent of batches done within the current epoch
    message = pyqtSignal(str)   # status text for the GUI
    # NOTE: no custom `finished` signal here -- redeclaring it would shadow
    # QThread.finished, which already fires once run() has returned.
    succeeded = False

    def run(self):
        self.succeeded = False
        try:
            batch_size = 64
            epochs = 10

            # Preprocessing: to tensor, then normalize with the MNIST mean / std.
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ])

            self.message.emit("正在下载MNIST数据集...")
            train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
            test_dataset = datasets.MNIST('./data', train=False, transform=transform)
            train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
            test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)

            optimizer = optim.Adam(net.parameters())

            self.message.emit("开始训练模型...")
            for epoch in range(1, epochs + 1):
                net.train()  # enable dropout for this training pass
                for batch_idx, (data, target) in enumerate(train_loader):
                    data, target = data.to(device), target.to(device)
                    optimizer.zero_grad()
                    output = net(data)
                    loss = F.nll_loss(output, target)  # net outputs log-probabilities
                    loss.backward()
                    optimizer.step()

                    progress = int(100. * (batch_idx + 1) / len(train_loader))
                    self.progress.emit(progress)

                # Evaluate on the test split with dropout disabled and no autograd graph.
                net.eval()
                correct = 0
                with torch.no_grad():
                    for data, target in test_loader:
                        data, target = data.to(device), target.to(device)
                        output = net(data)
                        pred = output.argmax(dim=1, keepdim=True)
                        correct += pred.eq(target.view_as(pred)).sum().item()

                accuracy = 100. * correct / len(test_loader.dataset)
                self.message.emit(f"Epoch {epoch}/{epochs} - 测试准确率: {accuracy:.2f}%")

            torch.save(net.state_dict(), 'mnist_cnn.pt')
            self.message.emit("模型训练完成并已保存!")
            self.succeeded = True
        except Exception as e:
            # Top-level boundary of the worker thread: report the error to the GUI
            # instead of letting it vanish inside the thread.
            self.message.emit(f"训练出错: {str(e)}")


class HandwritingRecognitionApp(QMainWindow):
    """Main window: image preprocessing, single-image recognition, and background training."""

    def __init__(self):
        super().__init__()
        self.setWindowTitle("手写数字识别系统")
        self.setGeometry(100, 100, 1000, 700)  # x, y, width, height

        # Central widget with a vertical main layout.
        main_widget = QWidget()
        main_layout = QVBoxLayout()
        main_widget.setLayout(main_layout)
        self.setCentralWidget(main_widget)

        # Top control panel.
        control_group = QGroupBox("控制面板")
        control_layout = QHBoxLayout()
        control_group.setLayout(control_layout)

        self.select_folder_btn = QPushButton("选择文件夹并预处理")
        self.select_image_btn = QPushButton("选择图像并识别")
        self.train_model_btn = QPushButton("训练模型")
        control_layout.addWidget(self.select_folder_btn)
        control_layout.addWidget(self.select_image_btn)
        control_layout.addWidget(self.train_model_btn)

        # Training progress bar, hidden until training starts.
        self.progress_bar = QProgressBar()
        self.progress_bar.setVisible(False)

        # Content area: images on the left, results on the right.
        content_layout = QHBoxLayout()

        image_group = QGroupBox("图像预览")
        image_layout = QVBoxLayout()
        image_group.setLayout(image_layout)

        self.original_label = QLabel("原始图像")
        self.original_label.setAlignment(Qt.AlignCenter)
        self.original_label.setMinimumSize(300, 300)
        self.original_label.setStyleSheet("background-color: #f0f0f0; border: 1px solid #ccc;")

        self.processed_label = QLabel("处理后图像")
        self.processed_label.setAlignment(Qt.AlignCenter)
        self.processed_label.setMinimumSize(300, 300)
        self.processed_label.setStyleSheet("background-color: #f0f0f0; border: 1px solid #ccc;")

        image_layout.addWidget(QLabel("原始图像:"))
        image_layout.addWidget(self.original_label)
        image_layout.addWidget(QLabel("处理后图像 (28x28):"))
        image_layout.addWidget(self.processed_label)

        result_group = QGroupBox("识别结果")
        result_layout = QVBoxLayout()
        result_group.setLayout(result_layout)

        self.prediction_label = QLabel("预测结果: -")
        self.prediction_label.setFont(QFont("Arial", 24, QFont.Bold))
        self.prediction_label.setAlignment(Qt.AlignCenter)
        self.prediction_label.setStyleSheet("color: #FF5722;")

        # Scrollable list of per-digit probabilities.
        scroll_area = QScrollArea()
        scroll_widget = QWidget()
        self.prob_layout = QVBoxLayout(scroll_widget)
        scroll_area.setWidget(scroll_widget)
        scroll_area.setWidgetResizable(True)
        scroll_area.setStyleSheet("background-color: white; border: 1px solid #ddd;")

        self.prob_labels = []
        for i in range(10):
            label = QLabel(f"{i}: 0.00%")
            label.setFont(QFont("Arial", 10))
            self.prob_labels.append(label)
            self.prob_layout.addWidget(label)

        result_layout.addWidget(self.prediction_label)
        result_layout.addWidget(QLabel("概率分布:"))
        result_layout.addWidget(scroll_area)

        content_layout.addWidget(image_group, 40)   # stretch factor 40
        content_layout.addWidget(result_group, 60)  # stretch factor 60

        # Status bar.
        self.status_label = QLabel("就绪")
        self.status_label.setStyleSheet("color: #666; padding: 5px; border-top: 1px solid #ddd;")

        # Assemble the main layout.
        main_layout.addWidget(control_group)
        main_layout.addWidget(self.progress_bar)
        main_layout.addLayout(content_layout, 70)  # stretch factor 70
        main_layout.addWidget(self.status_label)

        # Signal/slot wiring.
        self.select_folder_btn.clicked.connect(self.process_folder)
        self.select_image_btn.clicked.connect(self.select_and_recognize)
        self.train_model_btn.clicked.connect(self.start_training)

        self.current_image_path = None
        self.train_thread = None

        self.load_pretrained_model()

    def _training_running(self):
        """Return True while a TrainThread is still executing."""
        return self.train_thread is not None and self.train_thread.isRunning()

    def load_pretrained_model(self):
        """Load weights from 'mnist_cnn.pt' into the shared `net` if the file exists."""
        model_path = 'mnist_cnn.pt'
        if os.path.exists(model_path):
            try:
                net.load_state_dict(torch.load(model_path, map_location=device))
                net.eval()
                self.status_label.setText("预训练模型加载成功!")
            except Exception as e:
                self.status_label.setText(f"模型加载失败: {str(e)}")
        else:
            self.status_label.setText("警告: 未找到预训练模型,请先训练模型")

    def start_training(self):
        """Start a background TrainThread unless one is already running."""
        if self._training_running():
            QMessageBox.warning(self, "警告", "模型训练已在运行中!")
            return

        self.train_thread = TrainThread()
        self.train_thread.progress.connect(self.update_progress)
        self.train_thread.message.connect(self.status_label.setText)
        # QThread's built-in finished signal fires after run() has returned.
        self.train_thread.finished.connect(self.training_finished)

        self.progress_bar.setVisible(True)
        self.progress_bar.setValue(0)
        self.train_model_btn.setEnabled(False)
        self.train_model_btn.setText("训练中...")

        self.train_thread.start()

    def update_progress(self, value):
        """Slot for TrainThread.progress: show the current epoch's progress percentage."""
        self.progress_bar.setValue(value)

    def training_finished(self):
        """Restore the UI after training; reload weights only if training succeeded."""
        self.progress_bar.setVisible(False)
        self.train_model_btn.setEnabled(True)
        self.train_model_btn.setText("训练模型")
        if self.train_thread is not None and self.train_thread.succeeded:
            QMessageBox.information(self, "完成", "模型训练完成!")
            self.load_pretrained_model()
        else:
            # Training may have aborted mid-epoch with dropout still enabled.
            net.eval()
            # Keep the error text on the status bar instead of overwriting it.
            QMessageBox.warning(self, "警告", self.status_label.text())

    def process_folder(self):
        """Invert and resize every image in a chosen folder to 28x28, overwriting the files in place."""
        folder_path = QFileDialog.getExistingDirectory(self, "选择文件夹", os.getcwd())
        if not folder_path:
            return

        self.status_label.setText(f"正在处理文件夹: {folder_path}...")
        QApplication.processEvents()  # repaint the status text before the blocking loop

        try:
            file_list = os.listdir(folder_path)
            processed_count = 0

            for file_name in file_list:
                if not file_name.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp')):
                    continue

                path = os.path.join(folder_path, file_name)
                img = cv2.imread(path, 0)  # read as grayscale
                if img is None:
                    continue

                dst = 255 - img  # invert: dark-on-light -> light-on-dark, like MNIST
                if img.shape[:2] != (28, 28):
                    dst = cv2.resize(dst, (28, 28))

                cv2.imwrite(path, dst)  # overwrites the original file
                processed_count += 1

            self.status_label.setText(f"处理完成! 共处理 {processed_count} 张图像")
        except Exception as e:
            self.status_label.setText(f"处理出错: {str(e)}")

    def select_and_recognize(self):
        """Let the user pick one image, preprocess it like MNIST, and show the predicted digit."""
        # The shared `net` is being mutated (and has dropout enabled) while the
        # training thread runs, so recognition is refused until it finishes.
        if self._training_running():
            QMessageBox.warning(self, "警告", "模型训练已在运行中!")
            return

        file_path, _ = QFileDialog.getOpenFileName(
            self, "选择图像", os.getcwd(), "图像文件 (*.png *.jpg *.jpeg *.bmp)"
        )
        if not file_path:
            return

        self.current_image_path = file_path
        self.status_label.setText(f"正在处理图像: {os.path.basename(file_path)}")
        QApplication.processEvents()

        try:
            self.display_image(file_path, self.original_label)

            img = cv2.imread(file_path, 0)  # read as grayscale
            if img is None:
                self.status_label.setText(f"无法读取图像: {file_path}")
                return

            dst = 255 - img  # invert colours
            if img.shape[:2] != (28, 28):
                dst = cv2.resize(dst, (28, 28))

            self.display_processed_image(dst, self.processed_label)

            # Inference: dropout off, no autograd graph.
            net.eval()
            img_tensor = self.prepare_image(dst)
            with torch.no_grad():
                output = net(img_tensor)
                prob = F.softmax(output, dim=1)
            prob = prob.cpu().numpy()

            pred = np.argmax(prob)  # index of the most probable digit

            self.prediction_label.setText(f"预测结果: {pred}")

            prob_values = prob[0] * 100  # percentages for the single image in the batch
            for i, label in enumerate(self.prob_labels):
                label.setText(f"{i}: {prob_values[i]:.2f}%")
                # Highlight the predicted digit's row.
                label.setStyleSheet(
                    "font-weight: bold; color: #FF5722; background-color: #FFF3E0;" if i == pred else "")

            self.status_label.setText(f"识别完成: {os.path.basename(file_path)} -> 数字 {pred}")
        except Exception as e:
            self.status_label.setText(f"识别出错: {str(e)}")

    def display_image(self, path, label):
        """Show the image file at `path` in `label`, scaled to fit while keeping aspect ratio."""
        pixmap = QPixmap(path)
        if pixmap.isNull():
            label.setText("无法加载图像")
            return

        scaled_pixmap = pixmap.scaled(
            label.width(), label.height(),
            Qt.KeepAspectRatio, Qt.SmoothTransformation
        )
        label.setPixmap(scaled_pixmap)

    def display_processed_image(self, img_array, label):
        """Show a 2-D uint8 grayscale array in `label` (the array buffer is read via .data)."""
        height, width = img_array.shape
        bytes_per_line = width  # one byte per pixel for Format_Grayscale8

        q_img = QImage(img_array.data, width, height, bytes_per_line, QImage.Format_Grayscale8)
        pixmap = QPixmap.fromImage(q_img)

        scaled_pixmap = pixmap.scaled(
            label.width(), label.height(),
            Qt.KeepAspectRatio, Qt.SmoothTransformation
        )
        label.setPixmap(scaled_pixmap)

    def prepare_image(self, img_array):
        """Convert a 28x28 uint8 array into a normalized (1, 1, 28, 28) float tensor on `device`."""
        img_normalized = img_array.astype(np.float32) / 255.0  # scale to [0, 1]
        # Same normalization as training (MNIST mean / std).
        img_normalized = (img_normalized - 0.1307) / 0.3081

        # Add batch and channel dimensions.
        img = np.expand_dims(np.expand_dims(img_normalized, 0), 0)
        img_tensor = torch.from_numpy(img)
        img_tensor = img_tensor.to(device)
        return img_tensor


if __name__ == "__main__":
    app = QApplication(sys.argv)
    app.setStyle("Fusion")
    window = HandwritingRecognitionApp()
    window.show()
    sys.exit(app.exec_())
06-29
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

源代码杀手

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值