Comparing F.binary_cross_entropy and sklearn.metrics.log_loss

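F.binary_cross_entropy and sklearn.metrics.log_loss both compute the cross-entropy (log) loss, but they differ in use case, implementation and features. F.binary_cross_entropy is meant for binary (and element-wise multi-label) targets, while log_loss handles both binary and multiclass problems. The main differences are: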


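1. Library

  • F.binary_cross_entropy

    • Part of PyTorch's torch.nn.functional module.
    • Mainly used to compute the loss while training deep-learning models; it supports automatic differentiation (autograd), so the loss can be backpropagated to optimise the model parameters.
  • log_loss

    • Part of Scikit-learn's sklearn.metrics module.
    • Mainly used to evaluate model performance, usually after training is finished; it does not support automatic differentiation.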

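2. Input format

  • F.binary_cross_entropy

    • Takes tensors as input, so it plugs directly into PyTorch models.
    • The predictions must be probabilities in [0, 1], typically the output of a Sigmoid activation.
    • The targets must also be a tensor with the same shape as the predictions and a floating-point dtype (integer labels raise an error).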
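    import torch
    import torch.nn.functional as F
    
    preds = torch.tensor([0.9, 0.1, 0.8])    # predicted probabilities (e.g. sigmoid outputs)
    targets = torch.tensor([1.0, 0.0, 1.0])  # ground-truth labels: float dtype, same shape as preds
    loss = F.binary_cross_entropy(preds, targets)  # scalar: mean loss over the 3 samples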
    
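  • log_loss

    • Takes NumPy arrays or Python lists (any array-like).
    • The predictions are also probabilities in [0, 1]; for a binary problem they can be a 1-D array holding the probability of the positive class. The labels are usually class labels such as 0 or 1.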
    from sklearn.metrics import log_loss
    
    preds = [0.9, 0.1, 0.8]  # predicted probability of the positive class
    targets = [1, 0, 1]      # ground-truth labels
    loss = log_loss(targets, preds)  # y_true comes first, then y_pred
    

3. Computation

  • F.binary_cross_entropy

    • Computes the binary cross-entropy loss directly.
    • The formula is:
      $$\text{Loss} = -\frac{1}{N} \sum_{i=1}^{N} \left[ y_i \cdot \log(p_i) + (1 - y_i) \cdot \log(1 - p_i) \right]$$
      • $y_i$ is the ground-truth label (0 or 1).
      • $p_i$ is the predicted probability (in [0, 1]).
    • The loss is computed element-wise; with the default reduction='mean' it returns a scalar (the mean loss), while reduction='none' returns the per-element losses.
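  • log_loss

    • Handles both binary and multiclass problems. The labels parameter lets you list the classes explicitly, which is needed when y_true does not contain every class (for example, a small evaluation set).
    • For a binary problem it uses the same formula; y_pred may be a 1-D array of positive-class probabilities or a 2-D array with one column per class.
    • Averages over the samples by default (normalize=True); normalize=False returns the sum instead.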
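
The formula can be checked with a few lines of plain Python. This is a minimal sketch, assuming the probabilities are already in [0, 1]; the helper name binary_cross_entropy and the clipping constant eps are illustrative choices, not part of either library's API. (Scikit-learn also clips probabilities away from 0 and 1 internally, while PyTorch instead clamps each log term at -100.)

```python
import math


def binary_cross_entropy(preds, targets, eps=1e-15):
    """Mean binary cross-entropy over N samples (illustrative helper)."""
    if len(preds) != len(targets):
        raise ValueError("preds and targets must have the same length")
    total = 0.0
    for p, y in zip(preds, targets):
        # Clip p away from 0 and 1 so that log() stays finite.
        p = min(max(p, eps), 1.0 - eps)
        total += y * math.log(p) + (1.0 - y) * math.log(1.0 - p)
    return -total / len(preds)


loss = binary_cross_entropy([0.9, 0.1, 0.8], [1, 0, 1])
print(f"{loss:.4f}")  # 0.1446
```

With the same inputs, F.binary_cross_entropy (with float targets) and log_loss should both return approximately this value.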

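4. Purpose

  • F.binary_cross_entropy

    • Used to compute the loss during model training.
    • Supports autograd, so the loss can be backpropagated to update the model parameters.
    • Can be called on every mini-batch inside the training loop.
  • log_loss

    • Used in the evaluation phase to measure the quality of a model's predictions.
    • Cannot be used directly as a training objective, because it returns a plain float with no gradient.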
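
What "supports autograd" buys you is the gradient of the loss with respect to the model output. For a sigmoid output p_i = σ(z_i), the gradient of the mean BCE with respect to the logit z_i simplifies to (p_i − y_i)/N. The sketch below computes that analytic gradient in plain Python and checks it against a finite-difference estimate; all function names are hypothetical, and this is not how PyTorch computes gradients internally (it uses its autograd engine).

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def bce_from_logits(logits, targets):
    """Mean binary cross-entropy, with probabilities p_i = sigmoid(z_i)."""
    total = 0.0
    for z, y in zip(logits, targets):
        p = sigmoid(z)
        total += y * math.log(p) + (1 - y) * math.log(1 - p)
    return -total / len(logits)


def bce_grad_wrt_logits(logits, targets):
    """Analytic gradient: dLoss/dz_i = (sigmoid(z_i) - y_i) / N."""
    n = len(logits)
    return [(sigmoid(z) - y) / n for z, y in zip(logits, targets)]


def finite_difference_grad(f, xs, h=1e-6):
    """Central-difference estimate of the gradient of f at xs."""
    grad = []
    for i in range(len(xs)):
        plus, minus = list(xs), list(xs)
        plus[i] += h
        minus[i] -= h
        grad.append((f(plus) - f(minus)) / (2 * h))
    return grad


logits = [2.0, -1.5, 0.3]
targets = [1, 0, 1]
analytic = bce_grad_wrt_logits(logits, targets)
numeric = finite_difference_grad(lambda zs: bce_from_logits(zs, targets), logits)
for a, num in zip(analytic, numeric):
    print(f"analytic {a:+.6f}   numeric {num:+.6f}")
```

In PyTorch, calling F.binary_cross_entropy(torch.sigmoid(z), y).backward() on a tensor z created with requires_grad=True fills z.grad with these values, and F.binary_cross_entropy_with_logits fuses the two steps in a numerically more stable way. log_loss returns a plain float, so there is no graph to differentiate.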

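5. Weighting

  • F.binary_cross_entropy

    • The weight argument rescales the loss of each element; it must be broadcastable to the shape of the input, so it can hold one weight per sample or (in multi-label settings) one weight per output column.
      loss = F.binary_cross_entropy(preds, targets, weight=torch.tensor([0.5, 1.0, 1.0]))
      
  • log_loss

    • Supports per-sample weights through the sample_weight argument, e.g. log_loss(targets, preds, sample_weight=[0.5, 1.0, 1.0]).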
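
The two weighting schemes do not normalise the same way, so weighted values from the two libraries are generally not comparable. With reduction='mean', PyTorch averages the weighted element losses over the number of elements N, whereas log_loss with sample_weight computes a weighted average and divides by the sum of the weights. The plain-Python sketch below reproduces both conventions; the helper names are illustrative, and the predictions are assumed to lie strictly inside (0, 1) (no clipping is done).

```python
import math


def element_losses(preds, targets):
    """Unreduced BCE term per sample: -[y*log(p) + (1-y)*log(1-p)]."""
    return [-(y * math.log(p) + (1 - y) * math.log(1 - p))
            for p, y in zip(preds, targets)]


def torch_style_weighted_mean(preds, targets, weights):
    # Mimics F.binary_cross_entropy(..., weight=w, reduction='mean'): sum(w_i * l_i) / N
    losses = element_losses(preds, targets)
    return sum(w * l for w, l in zip(weights, losses)) / len(losses)


def sklearn_style_weighted_mean(preds, targets, weights):
    # Mimics log_loss(..., sample_weight=w): sum(w_i * l_i) / sum(w_i)
    losses = element_losses(preds, targets)
    return sum(w * l for w, l in zip(weights, losses)) / sum(weights)


preds, targets, weights = [0.9, 0.1, 0.8], [1, 0, 1], [0.5, 1.0, 1.0]
print(round(torch_style_weighted_mean(preds, targets, weights), 4))    # 0.1271
print(round(sklearn_style_weighted_mean(preds, targets, weights), 4))  # 0.1525
```

The two conventions coincide only when the weights sum to N, for example when every weight is 1.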

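6. Multiclass support

  • F.binary_cross_entropy

    • Designed for binary targets; because it is applied element-wise, it also covers multi-label problems (several independent 0/1 outputs per sample).
    • For multiclass problems with mutually exclusive classes, use F.cross_entropy, which expects raw logits rather than probabilities.
  • log_loss

    • Supports multiclass problems natively: pass one probability column per class.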
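
For mutually exclusive classes, the multiclass cross-entropy averages −log of the probability assigned to the true class. F.cross_entropy takes raw logits and applies log-softmax internally, whereas log_loss takes probabilities whose rows already sum to 1. The plain-Python sketch below implements both entry points (the function names are hypothetical) and checks that, written as a two-class problem, the binary example reduces to the same value as the binary cross-entropy formula.

```python
import math


def log_softmax(row):
    """Numerically stable log-softmax of one row of logits."""
    m = max(row)
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - log_sum_exp for v in row]


def cross_entropy_from_logits(logit_rows, labels):
    """F.cross_entropy-style input: raw logits and integer class labels."""
    return -sum(log_softmax(row)[t] for row, t in zip(logit_rows, labels)) / len(labels)


def multiclass_log_loss(prob_rows, labels):
    """log_loss-style input: one probability row per sample (each row sums to 1)."""
    return -sum(math.log(row[t]) for row, t in zip(prob_rows, labels)) / len(labels)


# The binary example as a two-class problem: columns are [P(class 0), P(class 1)].
p = [0.9, 0.1, 0.8]
y = [1, 0, 1]
prob_rows = [[1 - pi, pi] for pi in p]
# Log-probabilities are valid logits: their softmax reproduces prob_rows.
logit_rows = [[math.log(q) for q in row] for row in prob_rows]

binary = -sum(yi * math.log(pi) + (1 - yi) * math.log(1 - pi) for pi, yi in zip(p, y)) / len(p)
print(round(binary, 4),
      round(multiclass_log_loss(prob_rows, y), 4),
      round(cross_entropy_from_logits(logit_rows, y), 4))  # 0.1446 0.1446 0.1446
```

Passing the probability rows themselves to cross_entropy_from_logits, the analogue of feeding softmax outputs into F.cross_entropy, applies softmax a second time and returns a noticeably different value.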

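7. Performance and efficiency

  • F.binary_cross_entropy

    • Can run on the GPU, which makes it well suited to large-scale deep-learning workloads.
    • Data must be converted to tensors first.
  • log_loss

    • Runs on the CPU with NumPy; well suited to small datasets and offline model evaluation.
    • Convenient when the data is already in NumPy; PyTorch tensors must first be converted, e.g. with .detach().cpu().numpy().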

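Summary comparison table

| Feature | F.binary_cross_entropy | log_loss |
|---|---|---|
| Library | PyTorch | Scikit-learn |
| Typical use | Model training | Model evaluation |
| Input format | Tensors | NumPy arrays or lists |
| Autograd support | Yes | No |
| Weighting | Yes (weight) | Yes (sample_weight) |
| Multiclass support | No (use F.cross_entropy) | Yes |
| Performance | High (GPU support) | Moderate (mainly CPU) |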

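Recommendations

  • If you are training a deep-learning model in PyTorch, use F.binary_cross_entropy (or F.binary_cross_entropy_with_logits if the model outputs raw logits instead of probabilities).
  • If training is finished and you want to evaluate model performance, use log_loss.
  • For multiclass problems, use F.cross_entropy (PyTorch) or log_loss (Scikit-learn).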
import os
import datetime
import numpy as np
import matplotlib.pyplot as plt
import shutil
from glob import glob
from tqdm import tqdm
from PIL import Image
import csv
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, accuracy_score


class AttentionUNet:
    def __init__(self, input_shape, classes, epochs, result_path, database_path,
                 learning_rate=1e-4, batch_size=8, early_stopping_patience=10,
                 lr_reduction_patience=5, lr_reduction_factor=0.5, dropout_rate=0.3,
                 filters_base=32, kernel_size=3, activation='relu',
                 output_activation='softmax', optimizer='adam',
                 loss='cross_entropy',  # default: a loss function supported by PyTorch
                 metrics=['accuracy'], attention_mechanism='additive'):
        self.input_shape = input_shape
        self.num_classes = classes['class_number']

        # Parse the class information
        self.background_gray = classes['bg'][0]
        self.background_name = classes['bg'][1]
        self.foreground_labels = classes['fg']

        # Map gray values to class indices
        self.gray_to_index = {self.background_gray: 0}  # background -> index 0
        # Foreground classes -> indices 1, 2, ...
        for idx, (gray_val, name) in enumerate(self.foreground_labels.items(), start=1):
            self.gray_to_index[gray_val] = idx

        # Validate the mapping
        if len(self.gray_to_index) != self.num_classes:
            raise ValueError(f"Class count mismatch: configured classes={self.num_classes}, "
                             f"mapped classes={len(self.gray_to_index)}")

        self.epochs = epochs
        self.result_path = result_path
        self.database_path = database_path
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.early_stopping_patience = early_stopping_patience
        self.lr_reduction_patience = lr_reduction_patience
        self.lr_reduction_factor = lr_reduction_factor
        self.dropout_rate = dropout_rate
        self.filters_base = filters_base
        self.kernel_size = kernel_size
        self.activation = activation
        self.output_activation = output_activation
        self.optimizer_type = optimizer
        self.loss_type = loss
        self.metrics = metrics
        self.attention_mechanism = attention_mechanism

        self.model = None
        self.history = None

        self.best_model_path = os.path.join(self.result_path, 'models', 'best_model.pth')
        self.temp_model_dir = os.path.join(self.result_path, 'models', 'temp_epochs')
        self.train_log_path = os.path.join(self.result_path, 'diagram', '训练日志.csv')
        self.best_model_performance_path = os.path.join(self.result_path, 'diagram', '最佳模型性能.txt')
        self.performance_plots_dir = os.path.join(self.result_path, 'diagram', 'performance_plots')
        self.tensorboard_dir = os.path.join(self.result_path, 'tensorboard')

        # Create the result directory structure
        self._create_result_directories()
        self._print_model_config()

        # Select the device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

    def _create_result_directories(self):
        """Create the result directory structure and save the configuration."""
        dirs_to_create = [
            self.result_path,
            os.path.join(self.result_path, 'models'),
            os.path.join(self.result_path, 'diagram'),
            self.temp_model_dir,  # make sure this directory is included
            self.performance_plots_dir,
            self.tensorboard_dir
        ]
        for dir_path in dirs_to_create:
            os.makedirs(dir_path, exist_ok=True)
            print(f"Directory created: {dir_path}")
        print(f"Result directory structure created: {self.result_path}")

        # Save the model configuration
        config_path = os.path.join(self.result_path, 'models', 'configuration.txt')
        try:
            with open(config_path, 'w', encoding='utf-8') as f:
                f.write("Attention U-Net model configuration (PyTorch implementation)\n")
                f.write(f"Saved at: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
                f.write("="*50 + "\n\n")

                f.write("Core parameters:\n")
                f.write(f"Input shape: {self.input_shape}\n")
                f.write(f"Number of classes: {self.num_classes}\n")
                f.write(f"Background class: {self.background_name} (gray value: {self.background_gray})\n")
                f.write("Foreground classes:\n")
                for gray_val, name in self.foreground_labels.items():
                    f.write(f"  {name}: gray value {gray_val}\n")

                f.write("\nTraining parameters:\n")
                f.write(f"Epochs: {self.epochs}\n")
                f.write(f"Batch size: {self.batch_size}\n")
                f.write(f"Learning rate: {self.learning_rate}\n")
                f.write(f"Early-stopping patience: {self.early_stopping_patience}\n")
                f.write(f"LR reduction patience: {self.lr_reduction_patience}\n")
                f.write(f"LR reduction factor: {self.lr_reduction_factor}\n")

                f.write("\nModel architecture parameters:\n")
                f.write(f"Base number of filters: {self.filters_base}\n")
                f.write(f"Kernel size: {self.kernel_size}\n")
                f.write(f"Activation: {self.activation}\n")
                f.write(f"Output activation: {self.output_activation}\n")
                f.write(f"Dropout rate: {self.dropout_rate}\n")
                f.write(f"Attention mechanism: {self.attention_mechanism}\n")

                f.write("\nOptimisation parameters:\n")
                f.write(f"Optimizer: {self.optimizer_type}\n")
                f.write(f"Loss function: {self.loss_type}\n")
                f.write(f"Monitored metrics: {', '.join(self.metrics)}\n")
            print(f"Model configuration saved to: {config_path}")
        except Exception as e:
            print(f"Error while saving the configuration: {e}")

    def _print_model_config(self):
        """Print the model configuration to the console."""
        print("\n" + "="*30)
        print("Model configuration:")
        print(f"Input shape: {self.input_shape}")
        print(f"Number of classes: {self.num_classes}")
        print(f"Background class: {self.background_name} (gray value: {self.background_gray} -> index: {self.gray_to_index[self.background_gray]})")
        print("Foreground classes:")
        for gray_val, name in self.foreground_labels.items():
            print(f"  {name} (gray value: {gray_val} -> index: {self.gray_to_index[gray_val]})")
        print(f"Epochs: {self.epochs}")
        print(f"Result path: {self.result_path}")
        print(f"Dataset path: {self.database_path}")
        print(f"Learning rate: {self.learning_rate}")
        print(f"Batch size: {self.batch_size}")
        print(f"Early-stopping patience: {self.early_stopping_patience}")
        print(f"LR reduction patience: {self.lr_reduction_patience}")
        print(f"LR reduction factor: {self.lr_reduction_factor}")
        print(f"Dropout rate: {self.dropout_rate}")
        print(f"Base number of filters: {self.filters_base}")
        print(f"Kernel size: {self.kernel_size}")
        print(f"Activation: {self.activation}")
        print(f"Output activation: {self.output_activation}")
        print(f"Optimizer: {self.optimizer_type}")
        print(f"Loss function: {self.loss_type}")
        print(f"Monitored metrics: {self.metrics}")
        print(f"Attention mechanism: {self.attention_mechanism}")
        print("="*30 + "\n")

    class AttentionUNetModel(nn.Module):
        """Attention U-Net model implemented in PyTorch."""

        def __init__(self, input_channels, num_classes, filters_base=32, kernel_size=3,
                     dropout_rate=0.3, attention_mechanism='additive', output_activation='softmax'):
            super().__init__()
            self.input_channels = input_channels
            self.num_classes = num_classes
            self.filters_base = filters_base
            self.kernel_size = kernel_size
            self.dropout_rate = dropout_rate
            self.attention_mechanism = attention_mechanism
            self.output_activation = output_activation  # output activation applied in forward()

            # Encoder
            self.enc1 = self._conv_block(input_channels, filters_base)
            self.enc2 = self._conv_block(filters_base, filters_base * 2)
            self.enc3 = self._conv_block(filters_base * 2, filters_base * 4)
            self.enc4 = self._conv_block(filters_base * 4, filters_base * 8)
            self.pool = nn.MaxPool2d(2)
            self.dropout = nn.Dropout2d(dropout_rate)

            # Bottleneck
            self.bottleneck = self._conv_block(filters_base * 8, filters_base * 16)

            # Decoder
            self.up1 = self._up_block(filters_base * 16, filters_base * 8)
            self.att1 = self._attention_gate(filters_base * 8, filters_base * 8)
            self.dec1 = self._conv_block(filters_base * 16, filters_base * 8)

            self.up2 = self._up_block(filters_base * 8, filters_base * 4)
            self.att2 = self._attention_gate(filters_base * 4, filters_base * 4)
            self.dec2 = self._conv_block(filters_base * 8, filters_base * 4)

            self.up3 = self._up_block(filters_base * 4, filters_base * 2)
            self.att3 = self._attention_gate(filters_base * 2, filters_base * 2)
            self.dec3 = self._conv_block(filters_base * 4, filters_base * 2)

            self.up4 = self._up_block(filters_base * 2, filters_base)
            self.att4 = self._attention_gate(filters_base, filters_base)
            self.dec4 = self._conv_block(filters_base * 2, filters_base)

            # Output layer
            self.out_conv = nn.Conv2d(filters_base, num_classes, kernel_size=1)

        def _conv_block(self, in_channels, out_channels):
            """Standard convolution block: two rounds of Conv2d, BatchNorm and ReLU."""
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=self.kernel_size, padding=self.kernel_size//2),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_channels, out_channels, kernel_size=self.kernel_size, padding=self.kernel_size//2),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            )

        def _up_block(self, in_channels, out_channels):
            """Upsampling block: ConvTranspose2d followed by Dropout."""
            return nn.Sequential(
                nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
                nn.Dropout2d(self.dropout_rate)
            )

        def _attention_gate(self, g_channels, x_channels):
            """Attention gate."""
            if self.attention_mechanism == 'additive':
                # Additive attention
                return nn.Sequential(
                    nn.Conv2d(g_channels, x_channels, kernel_size=1),
                    nn.BatchNorm2d(x_channels),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(x_channels, x_channels, kernel_size=1),
                    nn.BatchNorm2d(x_channels),
                    nn.Sigmoid()
                )
            elif self.attention_mechanism == 'multiplicative':
                # Multiplicative attention
                return nn.Sequential(
                    nn.Conv2d(g_channels + x_channels, x_channels, kernel_size=1),
                    nn.BatchNorm2d(x_channels),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(x_channels, 1, kernel_size=1),
                    nn.BatchNorm2d(1),
                    nn.Sigmoid()
                )
            else:
                raise ValueError(f"Unsupported attention mechanism: {self.attention_mechanism}")

        def forward(self, x):
            # Encoder path
            enc1 = self.enc1(x)
            enc1_pool = self.pool(enc1)
            enc1_pool = self.dropout(enc1_pool)

            enc2 = self.enc2(enc1_pool)
            enc2_pool = self.pool(enc2)
            enc2_pool = self.dropout(enc2_pool)

            enc3 = self.enc3(enc2_pool)
            enc3_pool = self.pool(enc3)
            enc3_pool = self.dropout(enc3_pool)

            enc4 = self.enc4(enc3_pool)
            enc4_pool = self.pool(enc4)
            enc4_pool = self.dropout(enc4_pool)

            # Bottleneck
            bottleneck = self.bottleneck(enc4_pool)

            # Decoder path
            # Upsampling block 1
            up1 = self.up1(bottleneck)
            if self.attention_mechanism == 'additive':
                att1 = self.att1(up1)
                att1 = att1 * enc4
            else:
                att1 = self.att1(torch.cat([up1, enc4], dim=1))
                att1 = att1 * enc4
            merge1 = torch.cat([up1, att1], dim=1)
            dec1 = self.dec1(merge1)

            # Upsampling block 2
            up2 = self.up2(dec1)
            if self.attention_mechanism == 'additive':
                att2 = self.att2(up2)
                att2 = att2 * enc3
            else:
                att2 = self.att2(torch.cat([up2, enc3], dim=1))
                att2 = att2 * enc3
            merge2 = torch.cat([up2, att2], dim=1)
            dec2 = self.dec2(merge2)

            # Upsampling block 3
            up3 = self.up3(dec2)
            if self.attention_mechanism == 'additive':
                att3 = self.att3(up3)
                att3 = att3 * enc2
            else:
                att3 = self.att3(torch.cat([up3, enc2], dim=1))
                att3 = att3 * enc2
            merge3 = torch.cat([up3, att3], dim=1)
            dec3 = self.dec3(merge3)

            # Upsampling block 4
            up4 = self.up4(dec3)
            if self.attention_mechanism == 'additive':
                att4 = self.att4(up4)
                att4 = att4 * enc1
            else:
                att4 = self.att4(torch.cat([up4, enc1], dim=1))
                att4 = att4 * enc1
            merge4 = torch.cat([up4, att4], dim=1)
            dec4 = self.dec4(merge4)

            # Output layer
            out = self.out_conv(dec4)

            # Apply the output activation (any other value returns raw logits)
            if self.output_activation == 'softmax':
                out = F.softmax(out, dim=1)
            elif self.output_activation == 'sigmoid':
                out = torch.sigmoid(out)

            return out

    def _dice_loss(self, y_pred, y_true, smooth=1e-6):
        """Dice loss."""
        # Flatten predictions and ground truth
        y_pred_flat = y_pred.contiguous().view(-1)
        y_true_flat = y_true.contiguous().view(-1)

        # Intersection and union
        intersection = (y_pred_flat * y_true_flat).sum()
        union = y_pred_flat.sum() + y_true_flat.sum()

        # Dice coefficient
        dice = (2. * intersection + smooth) / (union + smooth)

        # Dice loss
        return 1 - dice

    def _focal_loss(self, y_pred, y_true, alpha=0.25, gamma=2.0):
        """Focal loss."""
        # Per-element binary cross-entropy
        ce_loss = F.binary_cross_entropy(y_pred, y_true, reduction='none')

        # Probability assigned to the true label
        p_t = y_true * y_pred + (1 - y_true) * (1 - y_pred)

        # Focal modulation
        focal_loss = torch.pow(1 - p_t, gamma) * ce_loss

        # Apply the alpha weighting
        if alpha is not None:
            alpha_t = y_true * alpha + (1 - y_true) * (1 - alpha)
            focal_loss = alpha_t * focal_loss

        return focal_loss.mean()

    def _jaccard_loss(self, y_pred, y_true, smooth=1e-6):
        """Jaccard (IoU) loss."""
        # Flatten predictions and ground truth
        y_pred_flat = y_pred.contiguous().view(-1)
        y_true_flat = y_true.contiguous().view(-1)

        # Intersection and union
        intersection = (y_pred_flat * y_true_flat).sum()
        total = y_pred_flat.sum() + y_true_flat.sum()
        union = total - intersection

        # Jaccard index (IoU)
        iou = (intersection + smooth) / (union + smooth)
        return 1 - iou

    def build_model(self):
        """Build the Attention U-Net model."""
        input_channels = self.input_shape[2] if len(self.input_shape) == 3 else 3
        self.model = self.AttentionUNetModel(
            input_channels=input_channels,
            num_classes=self.num_classes,
            filters_base=self.filters_base,
            kernel_size=self.kernel_size,
            dropout_rate=self.dropout_rate,
            attention_mechanism=self.attention_mechanism,
            output_activation=self.output_activation  # output activation parameter
        ).to(self.device)

        print("Attention U-Net model built successfully.")
        print(f"Trainable parameters: {sum(p.numel() for p in self.model.parameters() if p.requires_grad):,}")

        # Select the optimizer
        if self.optimizer_type.lower() == 'adam':
            optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)
        elif self.optimizer_type.lower() == 'sgd':
            optimizer = optim.SGD(self.model.parameters(), lr=self.learning_rate, momentum=0.9)
        elif self.optimizer_type.lower() == 'rmsprop':
            optimizer = optim.RMSprop(self.model.parameters(), lr=self.learning_rate)
        else:
            raise ValueError(f"Unsupported optimizer type: {self.optimizer_type}")

        # Select the loss function
        if self.loss_type.lower() == 'dice_loss':
            criterion = self._dice_loss
        elif self.loss_type.lower() == 'focal_loss':
            criterion = self._focal_loss
        elif self.loss_type.lower() == 'jaccard_loss':
            criterion = self._jaccard_loss
        elif self.loss_type.lower() in ['cross_entropy', 'categorical_crossentropy']:  # both names accepted
            # Note: nn.CrossEntropyLoss expects raw logits and applies log-softmax internally.
            # With output_activation='softmax' the softmax is effectively applied twice; use an
            # output_activation other than 'softmax'/'sigmoid' (forward() then returns logits).
            criterion = nn.CrossEntropyLoss()
        else:
            raise ValueError(f"Unsupported loss function type: {self.loss_type}")

        return self.model, optimizer, criterion

    class SegmentationDataset(Dataset):
        """Image segmentation dataset."""

        def __init__(self, image_dir, mask_dir, input_shape, gray_to_index, num_classes):
            self.image_dir = image_dir
            self.mask_dir = mask_dir
            self.input_shape = input_shape
            self.gray_to_index = gray_to_index
            self.num_classes = num_classes

            self.image_files = sorted(glob(os.path.join(image_dir, '*')))
            self.mask_files = sorted(glob(os.path.join(mask_dir, '*')))

            if not self.image_files:
                raise FileNotFoundError(f"No image files found in {image_dir}")
            if not self.mask_files:
                raise FileNotFoundError(f"No mask files found in {mask_dir}")
            if len(self.image_files) != len(self.mask_files):
                print(f"Warning: the number of image files ({len(self.image_files)}) and mask files "
                      f"({len(self.mask_files)}) in the dataset do not match")

            self.transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])

        def __len__(self):
            return min(len(self.image_files), len(self.mask_files))

        def __getitem__(self, idx):
            # Load the image
            img_path = self.image_files[idx]
            img = Image.open(img_path).convert('RGB')
            img = img.resize((self.input_shape[1], self.input_shape[0]))
            img = self.transform(img)

            # Load the mask
            mask_path = self.mask_files[idx]
            mask = Image.open(mask_path).convert('L')
            mask = mask.resize((self.input_shape[1], self.input_shape[0]), Image.NEAREST)
            mask = np.array(mask)

            # Build the class-index mask (integer type)
            processed_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int64)
            for gray_val, class_idx in self.gray_to_index.items():
                processed_mask[mask == gray_val] = class_idx

            return img, processed_mask

    def _load_datasets(self):
        """Load the training, validation and test datasets."""
        # Dataset paths
        train_image_dir = os.path.join(self.database_path, 'train', 'images')
        train_mask_dir = os.path.join(self.database_path, 'train', 'masks')
        val_image_dir = os.path.join(self.database_path, 'val', 'images')
        val_mask_dir = os.path.join(self.database_path, 'val', 'masks')
        test_image_dir = os.path.join(self.database_path, 'test', 'images')
        test_mask_dir = os.path.join(self.database_path, 'test', 'masks')

        # Create the dataset instances
        train_dataset = self.SegmentationDataset(
            train_image_dir, train_mask_dir, self.input_shape, self.gray_to_index, self.num_classes
        )
        val_dataset = self.SegmentationDataset(
            val_image_dir, val_mask_dir, self.input_shape, self.gray_to_index, self.num_classes
        )
        test_dataset = self.SegmentationDataset(
            test_image_dir, test_mask_dir, self.input_shape, self.gray_to_index, self.num_classes
        )

        # Create the data loaders
        train_loader = DataLoader(
            train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=4
        )
        val_loader = DataLoader(
            val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=4
        )
        test_loader = DataLoader(
            test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=4
        )

        # Print dataset information
        print(f"\nDatasets loaded:")
        print(f"  Training set: {len(train_dataset)} samples")
        print(f"  Validation set: {len(val_dataset)} samples")
        print(f"  Test set: {len(test_dataset)} samples")

        return train_loader, val_loader, test_loader

    def _calculate_metrics(self, y_true, y_pred, smooth=1e-6, verbose=True):
        """
        Compute segmentation metrics.

        Args:
            y_true (np.array): ground-truth labels (class indices).
            y_pred (np.array): predictions (probabilities).
            smooth (float): small constant to avoid division by zero.
            verbose (bool): whether to print detailed metrics.

        Returns:
            dict: all computed metrics.
        """
        if y_true.size == 0 or y_pred.size == 0:
            print("Warning: ground truth or predictions are empty; cannot compute metrics.")
            return {}

        # Convert predictions to class indices
        y_pred_argmax = np.argmax(y_pred, axis=1)

        # Flatten the images for metric computation
        y_pred_flat = y_pred_argmax.flatten()
        y_true_flat = y_true.flatten()

        metrics = {}
        epsilon = 1e-7  # avoid division by zero

        # 1. Global Dice coefficient (over all foreground and background pixels)
        # Background index
        bg_index = self.gray_to_index[self.background_gray]

        # Merge all foreground classes into a single "foreground" class
        y_true_fg_flat = (y_true_flat != bg_index)
        y_pred_fg_flat = (y_pred_flat != bg_index)

        # True positives (global foreground): true foreground, predicted foreground
        TP_global = np.sum(y_true_fg_flat & y_pred_fg_flat)
        # False positives (global foreground): true background, predicted foreground
        FP_global = np.sum(~y_true_fg_flat & y_pred_fg_flat)
        # False negatives (global foreground): true foreground, predicted background
        FN_global = np.sum(y_true_fg_flat & ~y_pred_fg_flat)
        # True negatives (global background): true background, predicted background
        TN_global = np.sum(~y_true_fg_flat & ~y_pred_fg_flat)

        dice_global = (2. * TP_global) / (2 * TP_global + FP_global + FN_global + epsilon)
        metrics['Dice_Global'] = dice_global

        accuracy_global = (TP_global + TN_global) / (TP_global + TN_global + FP_global + FN_global + epsilon)
        metrics['Accuracy_Global'] = accuracy_global

        if verbose:
            print(f"  Global Dice coefficient: {metrics['Dice_Global']:.4f}")
            print(f"  Global accuracy: {metrics['Accuracy_Global']:.4f}")

        # 2. Per-class pixel metrics
        for gray_val, class_idx in self.gray_to_index.items():
            class_name = self.background_name if gray_val == self.background_gray else self.foreground_labels[gray_val]

            # TP, FP, FN, TN for the current class
            TP_class = np.sum((y_true_flat == class_idx) & (y_pred_flat == class_idx))
            FP_class = np.sum((y_true_flat != class_idx) & (y_pred_flat == class_idx))
            FN_class = np.sum((y_true_flat == class_idx) & (y_pred_flat != class_idx))
            TN_class = np.sum((y_true_flat != class_idx) & (y_pred_flat != class_idx))

            dice_class = (2. * TP_class) / (2 * TP_class + FP_class + FN_class + epsilon)
            metrics[f'{class_name}_Dice'] = dice_class

            accuracy_class = (TP_class + TN_class) / (TP_class + TN_class + FP_class + FN_class + epsilon)
            metrics[f'{class_name}_Accuracy'] = accuracy_class

            iou_class = TP_class / (TP_class + FP_class + FN_class + epsilon)
            metrics[f'{class_name}_IoU'] = iou_class

            precision_class = TP_class / (TP_class + FP_class + epsilon)
            metrics[f'{class_name}_Precision'] = precision_class

            recall_class = TP_class / (TP_class + FN_class + epsilon)
            metrics[f'{class_name}_Recall'] = recall_class

            specificity_class = TN_class / (TN_class + FP_class + epsilon)
            metrics[f'{class_name}_Specificity'] = specificity_class

            if verbose:
                print(f"\n  --- Class: {class_name} (gray value: {gray_val}, index: {class_idx}) ---")
                print(f"    Dice: {metrics[f'{class_name}_Dice']:.4f}")
                print(f"    Accuracy: {metrics[f'{class_name}_Accuracy']:.4f}")
                print(f"    IoU: {metrics[f'{class_name}_IoU']:.4f}")
                print(f"    Precision: {metrics[f'{class_name}_Precision']:.4f}")
                print(f"    Recall: {metrics[f'{class_name}_Recall']:.4f}")
                print(f"    Specificity: {metrics[f'{class_name}_Specificity']:.4f}")

        return metrics

    def _select_best_model(self, all_metrics):
        """
        Rule for selecting the best model (customisable).

        Args:
            all_metrics (list): performance metrics of the model at every epoch.

        Returns:
            int: epoch number of the best model.
        """
        # Default rule: pick the model with the highest global Dice coefficient
        best_epoch = -1
        best_metric_value = -1

        # Iterate over the evaluation results of all models
        for metrics in all_metrics:
            # Use the global Dice coefficient as the selection criterion
            current_value = metrics.get('Dice_Global', -1)

            # Update the best model if the current one performs better
            if current_value > best_metric_value:
                best_metric_value = current_value
                best_epoch = metrics['Epoch']

        print(f"\nBest model selection: epoch {best_epoch} (global Dice: {best_metric_value:.4f})")
        return best_epoch

    def train(self):
        """
        Train the Attention U-Net model.
        """
        # Make sure all result directories exist
        self._create_result_directories()

        # Build the model
        model, optimizer, criterion = self.build_model()

        # Load the datasets
        train_loader, val_loader, test_loader = self._load_datasets()

        # Initialise TensorBoard
        writer = SummaryWriter(self.tensorboard_dir)

        # Initialise the training state
        best_val_loss = float('inf')
        epochs_without_improvement = 0  # drives early stopping
        epochs_since_lr_drop = 0        # drives learning-rate reduction
        all_metrics = []

        # Make sure the log directory exists
        os.makedirs(os.path.dirname(self.train_log_path), exist_ok=True)

        # Prepare the training-log CSV
        with open(self.train_log_path, 'w', newline='', encoding='utf-8') as csvfile:
            fieldnames = ['Epoch', 'Train_Loss', 'Val_Loss', 'Dice_Global', 'Accuracy_Global']
            # Add the per-class metric fields
            for gray_val in self.gray_to_index:
                class_name = self.background_name if gray_val == self.background_gray else self.foreground_labels[gray_val]
                fieldnames.extend([
                    f'{class_name}_Dice',
                    f'{class_name}_Accuracy',
                    f'{class_name}_IoU',
                    f'{class_name}_Precision',
                    f'{class_name}_Recall',
                    f'{class_name}_Specificity'
                ])
            writer_csv = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer_csv.writeheader()

        print("\nStarting training...")
        for epoch in range(1, self.epochs + 1):
            # Training phase
            model.train()
            train_loss = 0.0

            for images, masks in tqdm(train_loader, desc=f"Epoch {epoch}/{self.epochs} [train]"):
                images = images.to(self.device)
                masks = masks.to(self.device)

                # Forward pass
                outputs = model(images)

                # Compute the loss
                if isinstance(criterion, nn.CrossEntropyLoss):
                    # CrossEntropyLoss takes class indices directly
                    loss = criterion(outputs, masks.long())
                else:
                    # The other losses need one-hot encoded masks
                    masks_one_hot = F.one_hot(masks.long(), num_classes=self.num_classes).permute(0, 3, 1, 2).float()
                    loss = criterion(outputs, masks_one_hot)

                # Backward pass and optimisation
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                train_loss += loss.item() * images.size(0)

            # Average training loss
            train_loss = train_loss / len(train_loader.dataset)

            # Validation phase
            model.eval()
            val_loss = 0.0
            all_val_preds = []
            all_val_masks = []

            with torch.no_grad():
                for images, masks in tqdm(val_loader, desc=f"Epoch {epoch}/{self.epochs} [val]"):
                    images = images.to(self.device)
                    masks = masks.to(self.device)

                    # Forward pass
                    outputs = model(images)

                    # Compute the loss
                    if isinstance(criterion, nn.CrossEntropyLoss):
                        loss = criterion(outputs, masks.long())
                    else:
                        masks_one_hot = F.one_hot(masks.long(), num_classes=self.num_classes).permute(0, 3, 1, 2).float()
                        loss = criterion(outputs, masks_one_hot)

                    val_loss += loss.item() * images.size(0)

                    # Collect predictions and ground truth
                    all_val_preds.append(outputs.cpu().numpy())
                    all_val_masks.append(masks.cpu().numpy())

            # Average validation loss
            val_loss = val_loss / len(val_loader.dataset)

            # Concatenate predictions and labels over the whole validation set
            val_preds = np.concatenate(all_val_preds, axis=0)
            val_masks = np.concatenate(all_val_masks, axis=0)

            # Compute the validation metrics
            val_metrics = self._calculate_metrics(val_masks, val_preds, verbose=False)
            val_metrics['Epoch'] = epoch
            val_metrics['Train_Loss'] = train_loss
            val_metrics['Val_Loss'] = val_loss

            # Log to TensorBoard
            writer.add_scalar('Loss/Train', train_loss, epoch)
            writer.add_scalar('Loss/Validation', val_loss, epoch)
            writer.add_scalar('Metrics/Dice_Global', val_metrics['Dice_Global'], epoch)
            writer.add_scalar('Metrics/Accuracy_Global', val_metrics['Accuracy_Global'], epoch)

            # Store the metrics
            all_metrics.append(val_metrics)

            # Make sure the temporary directory exists before saving the model
            os.makedirs(self.temp_model_dir, exist_ok=True)

            # Save this epoch's model
            epoch_model_path = os.path.join(self.temp_model_dir, f'epoch_{epoch:03d}.pth')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'val_metrics': val_metrics
            }, epoch_model_path)
            print(f"Epoch {epoch:03d}: model saved to {epoch_model_path}")

            # Append to the CSV log (make sure the directory exists)
            os.makedirs(os.path.dirname(self.train_log_path), exist_ok=True)
            with open(self.train_log_path, 'a', newline='', encoding='utf-8') as csvfile:
                writer_csv = csv.DictWriter(csvfile, fieldnames=fieldnames)
                writer_csv.writerow(val_metrics)

            # Print training progress
            print(f"Epoch {epoch}/{self.epochs} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
                  f"Val Dice Global: {val_metrics['Dice_Global']:.4f}, "
                  f"Val Accuracy Global: {val_metrics['Accuracy_Global']:.4f}")

            # Learning-rate schedule. Two separate counters are used: resetting one shared
            # counter after every LR reduction would keep early stopping from ever triggering
            # whenever lr_reduction_patience < early_stopping_patience.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                epochs_without_improvement = 0
                epochs_since_lr_drop = 0
            else:
                epochs_without_improvement += 1
                epochs_since_lr_drop += 1
                if epochs_since_lr_drop >= self.lr_reduction_patience:
                    # Reduce the learning rate
                    for param_group in optimizer.param_groups:
                        param_group['lr'] *= self.lr_reduction_factor
                    print(f"Epoch {epoch}: learning rate reduced to {optimizer.param_groups[0]['lr']:.2e}")
                    epochs_since_lr_drop = 0

            # Early-stopping check
            if epochs_without_improvement >= self.early_stopping_patience:
                print(f"Epoch {epoch}: early stopping triggered, training stopped")
                break

        # Close the TensorBoard writer
        writer.close()

        # After training, select the best model
        self._select_and_save_best_model(all_metrics)
        self._plot_performance_curves()
        self._test_best_model(test_loader)

    def _select_and_save_best_model(self, all_metrics):
        """Select and save the best model."""
        best_epoch = self._select_best_model(all_metrics)

        if best_epoch > 0:
            # Locate the best checkpoint
            best_model_path = os.path.join(self.temp_model_dir, f'epoch_{best_epoch:03d}.pth')
            if os.path.exists(best_model_path):
                # Make sure the target directory exists
                os.makedirs(os.path.dirname(self.best_model_path), exist_ok=True)
                # Copy it to its final location
                shutil.copyfile(best_model_path, self.best_model_path)
                print(f"Best model (epoch {best_epoch}) saved to: {self.best_model_path}")
            else:
                print(f"Warning: checkpoint file for the best model (epoch {best_epoch}) not found")
        else:
            print("Warning: no valid best model found")

        # Save the best model's performance
        if best_epoch > 0:
            best_metrics = next(m for m in all_metrics if m['Epoch'] == best_epoch)
            with open(self.best_model_performance_path, 'w', encoding='utf-8') as f:
                f.write("Attention U-Net best-model validation performance\n")
                f.write(f"Saved at: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
                f.write(f"Best model path: {self.best_model_path}\n")
                f.write(f"Epoch: {best_epoch}\n")
                f.write("-" * 50 + "\n")
                for key, value in best_metrics.items():
                    if isinstance(value, float):
                        f.write(f"{key}: {value:.4f}\n")
                    else:
                        f.write(f"{key}: {value}\n")
            print(f"Validation metrics of the best model written to '{self.best_model_performance_path}'")

        # Delete the temporary model directory
        if os.path.exists(self.temp_model_dir):
            print(f"Deleting temporary model directory: {self.temp_model_dir}")
            shutil.rmtree(self.temp_model_dir)

    def _test_best_model(self, test_loader):
        if not os.path.exists(self.best_model_path):
            print(f"Error: best model file not found: {self.best_model_path}. "
                  f"Make sure the model was trained and saved successfully.")
            return

        print("\nLoading the best model for testing...")
        try:
            # The checkpoint stores NumPy scalars in 'val_metrics', which the default
            # weights_only=True loading of recent PyTorch versions rejects. The file was
            # written by this script itself, so it is trusted and loaded with weights_only=False.
            checkpoint = torch.load(
                self.best_model_path,
                map_location=self.device,
                weights_only=False
            )

            # Rebuild the model architecture
            input_channels = self.input_shape[2] if len(self.input_shape) == 3 else 3
            model = self.AttentionUNetModel(
                input_channels=input_channels,
                num_classes=self.num_classes,
                filters_base=self.filters_base,
                kernel_size=self.kernel_size,
                dropout_rate=self.dropout_rate,
                attention_mechanism=self.attention_mechanism,
                output_activation=self.output_activation
            ).to(self.device)

            # Load the state dict
            model.load_state_dict(checkpoint['model_state_dict'])
            model.eval()

        except Exception as e:
            print(f"Error: could not load best model {self.best_model_path}. Error: {e}")
            return

        # Run the test
        all_test_preds = []
        all_test_masks = []

        with torch.no_grad():
            for images, masks in tqdm(test_loader, desc="Testing best model"):
                images = images.to(self.device)
                masks = masks.to(self.device)

                outputs = model(images)

                all_test_preds.append(outputs.cpu().numpy())
                all_test_masks.append(masks.cpu().numpy())

        # Concatenate the results
        test_preds = np.concatenate(all_test_preds, axis=0)
        test_masks = np.concatenate(all_test_masks, axis=0)

        # Compute the metrics
        test_metrics = self._calculate_metrics(test_masks, test_preds, verbose=True)

        # Append the test performance
        with open(self.best_model_performance_path, 'a', encoding='utf-8') as f:
            f.write("\n" + "=" * 50 + "\n")
            f.write("Test-set performance:\n")
            for key, value in test_metrics.items():
                if isinstance(value, float):
                    f.write(f"{key}: {value:.4f}\n")
                else:
                    f.write(f"{key}: {value}\n")
        print(f"Test metrics of the best model appended to '{self.best_model_performance_path}'")

    def _plot_performance_curves(self):
        """
        Plot metric-vs-epoch curves from the training log.
        """
        print("\nPlotting performance curves...")

        if not os.path.exists(self.train_log_path):
            print(f"Error: training log not found: {self.train_log_path}; cannot plot curves.")
            return

        # Make sure the output directory exists
        os.makedirs(self.performance_plots_dir, exist_ok=True)

        # Data read from the CSV file
        epochs = []
        data = {}

        # Read the CSV file
        with open(self.train_log_path, 'r', encoding='utf-8') as csvfile:
            reader = csv.DictReader(csvfile)
            for row in reader:
                epochs.append(int(row['Epoch']))
                for key, value in row.items():
                    if key == 'Epoch':
                        continue
                    if key not in data:
                        data[key] = []
                    try:
                        data[key].append(float(value))
                    except (TypeError, ValueError):
                        data[key].append(0.0)

        if not epochs:
            print("The training log does not contain enough valid data to plot curves.")
            return

        # Collect all class names (background and foreground)
        all_classes = set()
        for key in data.keys():
            # Only keys that contain an underscore and do not start with 'Val_' or 'Train_'
            if '_' in key and not key.startswith(('Val_', 'Train_')):
                # Extract the class name (e.g. 'background_Dice' -> 'background');
                # rsplit keeps class names that themselves contain underscores intact
                class_name = key.rsplit('_', 1)[0]
                # Keep only known class names
                if class_name in [self.background_name] + list(self.foreground_labels.values()):
                    all_classes.add(class_name)

        # 1. Loss curves
        plt.figure(figsize=(10, 6))
        plt.plot(epochs, data['Train_Loss'], label='Training loss', color='blue', linewidth=2)
        plt.plot(epochs, data['Val_Loss'], label='Validation loss', color='red', linewidth=2)
        plt.title('Training and validation loss vs. epoch', fontsize=14)
        plt.xlabel('Epoch', fontsize=12)
        plt.ylabel('Loss', fontsize=12)
        plt.xticks(range(min(epochs), max(epochs)+1, max(1, len(epochs)//10)))
        plt.legend(fontsize=12)
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(os.path.join(self.performance_plots_dir, '损失迭代曲线.png'))
        plt.close()

        # 2. Global Dice and accuracy curves
        plt.figure(figsize=(10, 6))
        plt.plot(epochs, data['Dice_Global'], label='Global Dice', color='blue', linewidth=2)
        plt.plot(epochs, data['Accuracy_Global'], label='Global accuracy', color='red', linewidth=2)
        plt.title('Global Dice coefficient and accuracy vs. epoch', fontsize=14)
        plt.xlabel('Epoch', fontsize=12)
        plt.ylabel('Value', fontsize=12)
        plt.xticks(range(min(epochs), max(epochs)+1, max(1, len(epochs)//10)))
        plt.legend(fontsize=12)
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(os.path.join(self.performance_plots_dir, '全局Dice和准确率迭代曲线.png'))
        plt.close()

        # 3. Per-class IoU curves
        plt.figure(figsize=(10, 6))
        for class_name in all_classes:
            plt.plot(epochs, data[f'{class_name}_IoU'], label=f'{class_name} IoU', linewidth=2)
        plt.title('Per-class IoU vs. epoch', fontsize=14)
        plt.xlabel('Epoch', fontsize=12)
        plt.ylabel('IoU', fontsize=12)
        plt.xticks(range(min(epochs), max(epochs)+1, max(1, len(epochs)//10)))
        plt.legend(fontsize=12)
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(os.path.join(self.performance_plots_dir, '各类IoU迭代曲线.png'))
        plt.close()

        # 4. Per-class precision curves
        plt.figure(figsize=(10, 6))
        for class_name in all_classes:
            plt.plot(epochs, data[f'{class_name}_Precision'], label=f'{class_name} precision', linewidth=2)
        plt.title('Per-class precision vs. epoch', fontsize=14)
        plt.xlabel('Epoch', fontsize=12)
        plt.ylabel('Precision', fontsize=12)
        plt.xticks(range(min(epochs), max(epochs)+1, max(1, len(epochs)//10)))
        plt.legend(fontsize=12)
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(os.path.join(self.performance_plots_dir, '各类精确率迭代曲线.png'))
        plt.close()

        # 5. Per-class recall curves
        plt.figure(figsize=(10, 6))
        for class_name in all_classes:
            plt.plot(epochs, data[f'{class_name}_Recall'], label=f'{class_name} recall', linewidth=2)
        plt.title('Per-class recall vs. epoch', fontsize=14)
        plt.xlabel('Epoch', fontsize=12)
        plt.ylabel('Recall', fontsize=12)
        plt.xticks(range(min(epochs), max(epochs)+1, max(1, len(epochs)//10)))
        plt.legend(fontsize=12)
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(os.path.join(self.performance_plots_dir, '各类召回率迭代曲线.png'))
        plt.close()

        # 6. Per-class specificity curves
        plt.figure(figsize=(10, 6))
        for class_name in all_classes:
            plt.plot(epochs, data[f'{class_name}_Specificity'], label=f'{class_name} specificity', linewidth=2)
        plt.title('Per-class specificity vs. epoch', fontsize=14)
        plt.xlabel('Epoch', fontsize=12)
        plt.ylabel('Specificity', fontsize=12)
        plt.xticks(range(min(epochs), max(epochs)+1, max(1, len(epochs)//10)))
        plt.legend(fontsize=12)
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(os.path.join(self.performance_plots_dir, '各类特异性迭代曲线.png'))
        plt.close()

        # 7. Per-class accuracy curves
        plt.figure(figsize=(10, 6))
        for class_name in all_classes:
            plt.plot(epochs, data[f'{class_name}_Accuracy'], label=f'{class_name} accuracy', linewidth=2)
        plt.title('Per-class accuracy vs. epoch', fontsize=14)
        plt.xlabel('Epoch', fontsize=12)
        plt.ylabel('Accuracy', fontsize=12)
        plt.xticks(range(min(epochs), max(epochs)+1, max(1, len(epochs)//10)))
        plt.legend(fontsize=12)
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(os.path.join(self.performance_plots_dir, '各类准确率迭代曲线.png'))
        plt.close()

        # 8. Per-class Dice curves
        plt.figure(figsize=(10, 6))
        for class_name in all_classes:
            plt.plot(epochs, data[f'{class_name}_Dice'], label=f'{class_name} Dice', linewidth=2)
        plt.title('Per-class Dice coefficient vs. epoch', fontsize=14)
        plt.xlabel('Epoch', fontsize=12)
        plt.ylabel('Dice', fontsize=12)
        plt.xticks(range(min(epochs), max(epochs)+1, max(1, len(epochs)//10)))
        plt.legend(fontsize=12)
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(os.path.join(self.performance_plots_dir, '各类Dice迭代曲线.png'))
        plt.close()

        # 9. Dice and IoU curves for each class (one figure per class)
        for class_name in all_classes:
            plt.figure(figsize=(10, 6))
            plt.plot(epochs, data[f'{class_name}_Dice'], label=f'{class_name} Dice', color='blue', linewidth=2)
            plt.plot(epochs, data[f'{class_name}_IoU'], label=f'{class_name} IoU', color='red', linewidth=2)
            plt.title(f'{class_name}: Dice and IoU vs. epoch', fontsize=14)
            plt.xlabel('Epoch', fontsize=12)
            plt.ylabel('Value', fontsize=12)
            plt.xticks(range(min(epochs), max(epochs)+1, max(1, len(epochs)//10)))
            plt.legend(fontsize=12)
            plt.grid(True)
            plt.tight_layout()
            plt.savefig(os.path.join(self.performance_plots_dir, f'{class_name}_Dice+IoU迭代曲线.png'))
            plt.close()

        # 10. Precision and recall curves for each class (one figure per class)
        for class_name in all_classes:
            plt.figure(figsize=(10, 6))
            plt.plot(epochs, data[f'{class_name}_Precision'], label=f'{class_name} precision', color='blue', linewidth=2)
            plt.plot(epochs, data[f'{class_name}_Recall'], label=f'{class_name} recall', color='red', linewidth=2)
            plt.title(f'{class_name}: precision and recall vs. 
迭代次数', fontsize=14) plt.xlabel('迭代次数', fontsize=12) plt.ylabel('值', fontsize=12) plt.xticks(range(min(epochs), max(epochs)+1, max(1, len(epochs)//10))) plt.legend(fontsize=12) plt.grid(True) plt.tight_layout() plt.savefig(os.path.join(self.performance_plots_dir, f'{class_name}_精确率+召回率迭代曲线.png')) plt.close() print(f"性能曲线图已保存到 '{self.performance_plots_dir}'") print(f"总共生成 {8 + 2*len(all_classes)} 张曲线图") print("\n所有任务已完成!") #------------------------------------- 示例用法 -------------------------------------------- # 定义模型基础参数 input_shape = (512, 512, 3) classes = { 'class_number': 3, 'bg': (2, 'background'), 'fg': {0: 'cup', 1: 'disc'} } epochs = 5 result_base_path = '/content/drive/MyDrive/results' database_base_path = '/content/drive/MyDrive/database' # 实例化模型类 unet_model = AttentionUNet( input_shape=input_shape, classes=classes, epochs=epochs, result_path=result_base_path, database_path=database_base_path, learning_rate=1e-4, batch_size=5, early_stopping_patience=3, lr_reduction_patience=2, dropout_rate=0.5, filters_base=16, attention_mechanism='additive', loss='cross_entropy' # 修改为PyTorch支持的损失函数 ) # 构建并训练模型 unet_model.train() #------------------------------------------------------------------------------------------ 上述代码运行时报错: 正在绘制性能曲线图... --------------------------------------------------------------------------- KeyError Traceback (most recent call last) <ipython-input-8-2018641610> in <cell line: 0>() 1143 1144 # 构建并训练模型 -> 1145 unet_model.train() 1146 #------------------------------------------------------------------------------------------ 1 frames <ipython-input-8-2018641610> in _plot_performance_curves(self) 941 reader = csv.DictReader(csvfile) 942 for row in reader: --> 943 epochs.append(int(row['Epoch'])) 944 for key, value in row.items(): 945 if key == 'Epoch': KeyError: 'Epoch' 请修改
06-17