import os
import datetime
import numpy as np
import matplotlib.pyplot as plt
import shutil
from glob import glob
from tqdm import tqdm
from PIL import Image
import csv
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
class AttentionUNet:
def __init__(self,
input_shape,
classes,
epochs,
result_path,
database_path,
learning_rate=1e-4,
batch_size=8,
early_stopping_patience=10,
lr_reduction_patience=5,
lr_reduction_factor=0.5,
dropout_rate=0.3,
filters_base=32,
kernel_size=3,
activation='relu',
output_activation='softmax',
optimizer='adam',
                 loss='cross_entropy',  # default: a loss PyTorch supports natively
metrics=['accuracy'],
attention_mechanism='additive'):
self.input_shape = input_shape
self.num_classes = classes['class_number']
        # Parse the class specification
        self.background_gray = classes['bg'][0]
        self.background_name = classes['bg'][1]
        self.foreground_labels = classes['fg']
        # Build the gray-value -> class-index mapping
        self.gray_to_index = {self.background_gray: 0}  # background maps to index 0
        # Foreground classes map to indices 1, 2, ...
        for idx, (gray_val, name) in enumerate(self.foreground_labels.items(), start=1):
            self.gray_to_index[gray_val] = idx
        # Validate the mapping
        if len(self.gray_to_index) != self.num_classes:
            raise ValueError(f"Class count mismatch: configured {self.num_classes} classes, "
                             f"but the mapping contains {len(self.gray_to_index)}")
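        # Example: with bg=(2, 'background') and fg={0: 'cup', 1: 'disc'} (as in the
        # usage section below), the mapping is {2: 0, 0: 1, 1: 2}:
        # background -> 0, cup -> 1, disc -> 2.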
self.epochs = epochs
self.result_path = result_path
self.database_path = database_path
self.learning_rate = learning_rate
self.batch_size = batch_size
self.early_stopping_patience = early_stopping_patience
self.lr_reduction_patience = lr_reduction_patience
self.lr_reduction_factor = lr_reduction_factor
self.dropout_rate = dropout_rate
self.filters_base = filters_base
self.kernel_size = kernel_size
self.activation = activation
self.output_activation = output_activation
self.optimizer_type = optimizer
self.loss_type = loss
self.metrics = metrics
self.attention_mechanism = attention_mechanism
self.model = None
self.history = None
self.best_model_path = os.path.join(self.result_path, 'models', 'best_model.pth')
self.temp_model_dir = os.path.join(self.result_path, 'models', 'temp_epochs')
        self.train_log_path = os.path.join(self.result_path, 'diagram', 'training_log.csv')
        self.best_model_performance_path = os.path.join(self.result_path, 'diagram', 'best_model_performance.txt')
self.performance_plots_dir = os.path.join(self.result_path, 'diagram', 'performance_plots')
self.tensorboard_dir = os.path.join(self.result_path, 'tensorboard')
        # Create the results directory structure
        self._create_result_directories()
        self._print_model_config()
        # Select the compute device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
    def _create_result_directories(self):
        """Create the results directory structure and save the configuration."""
        dirs_to_create = [
            self.result_path,
            os.path.join(self.result_path, 'models'),
            os.path.join(self.result_path, 'diagram'),
            self.temp_model_dir,  # make sure the temp-epoch directory is included
            self.performance_plots_dir,
            self.tensorboard_dir
        ]
        for dir_path in dirs_to_create:
            os.makedirs(dir_path, exist_ok=True)
            print(f"Directory created: {dir_path}")
        print(f"Results directory structure created under: {self.result_path}")
        # Save the model configuration
        config_path = os.path.join(self.result_path, 'models', 'configuration.txt')
        try:
            with open(config_path, 'w', encoding='utf-8') as f:
                f.write("Attention U-Net model configuration (PyTorch implementation)\n")
                f.write(f"Saved at: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
                f.write("=" * 50 + "\n\n")
                f.write("Core parameters:\n")
                f.write(f"Input shape: {self.input_shape}\n")
                f.write(f"Number of classes: {self.num_classes}\n")
                f.write(f"Background class: {self.background_name} (gray value: {self.background_gray})\n")
                f.write("Foreground classes:\n")
                for gray_val, name in self.foreground_labels.items():
                    f.write(f"  {name}: gray value {gray_val}\n")
                f.write("\nTraining parameters:\n")
                f.write(f"Epochs: {self.epochs}\n")
                f.write(f"Batch size: {self.batch_size}\n")
                f.write(f"Learning rate: {self.learning_rate}\n")
                f.write(f"Early-stopping patience: {self.early_stopping_patience}\n")
                f.write(f"LR-reduction patience: {self.lr_reduction_patience}\n")
                f.write(f"LR-reduction factor: {self.lr_reduction_factor}\n")
                f.write("\nArchitecture parameters:\n")
                f.write(f"Base number of filters: {self.filters_base}\n")
                f.write(f"Kernel size: {self.kernel_size}\n")
                f.write(f"Activation: {self.activation}\n")
                f.write(f"Output activation: {self.output_activation}\n")
                f.write(f"Dropout rate: {self.dropout_rate}\n")
                f.write(f"Attention mechanism: {self.attention_mechanism}\n")
                f.write("\nOptimization parameters:\n")
                f.write(f"Optimizer: {self.optimizer_type}\n")
                f.write(f"Loss function: {self.loss_type}\n")
                f.write(f"Monitored metrics: {', '.join(self.metrics)}\n")
            print(f"Model configuration saved to: {config_path}")
        except Exception as e:
            print(f"Error while saving the configuration: {e}")
    def _print_model_config(self):
        """Print the model configuration to the console."""
        print("\n" + "=" * 30)
        print("Model configuration:")
        print(f"Input shape: {self.input_shape}")
        print(f"Number of classes: {self.num_classes}")
        print(f"Background class: {self.background_name} (gray value: {self.background_gray} -> index: {self.gray_to_index[self.background_gray]})")
        print("Foreground classes:")
        for gray_val, name in self.foreground_labels.items():
            print(f"  {name} (gray value: {gray_val} -> index: {self.gray_to_index[gray_val]})")
        print(f"Epochs: {self.epochs}")
        print(f"Results path: {self.result_path}")
        print(f"Dataset path: {self.database_path}")
        print(f"Learning rate: {self.learning_rate}")
        print(f"Batch size: {self.batch_size}")
        print(f"Early-stopping patience: {self.early_stopping_patience}")
        print(f"LR-reduction patience: {self.lr_reduction_patience}")
        print(f"LR-reduction factor: {self.lr_reduction_factor}")
        print(f"Dropout rate: {self.dropout_rate}")
        print(f"Base number of filters: {self.filters_base}")
        print(f"Kernel size: {self.kernel_size}")
        print(f"Activation: {self.activation}")
        print(f"Output activation: {self.output_activation}")
        print(f"Optimizer: {self.optimizer_type}")
        print(f"Loss function: {self.loss_type}")
        print(f"Monitored metrics: {self.metrics}")
        print(f"Attention mechanism: {self.attention_mechanism}")
        print("=" * 30 + "\n")
    class AttentionUNetModel(nn.Module):
        """Attention U-Net implemented in PyTorch."""
        def __init__(self, input_channels, num_classes, filters_base=32, kernel_size=3,
                     dropout_rate=0.3, attention_mechanism='additive', output_activation='softmax'):
            super().__init__()
            self.input_channels = input_channels
            self.num_classes = num_classes
            self.filters_base = filters_base
            self.kernel_size = kernel_size
            self.dropout_rate = dropout_rate
            self.attention_mechanism = attention_mechanism
            self.output_activation = output_activation  # activation applied to the logits in forward()
            # Encoder
self.enc1 = self._conv_block(input_channels, filters_base)
self.enc2 = self._conv_block(filters_base, filters_base * 2)
self.enc3 = self._conv_block(filters_base * 2, filters_base * 4)
self.enc4 = self._conv_block(filters_base * 4, filters_base * 8)
self.pool = nn.MaxPool2d(2)
self.dropout = nn.Dropout2d(dropout_rate)
# Bottleneck
self.bottleneck = self._conv_block(filters_base * 8, filters_base * 16)
            # Decoder
self.up1 = self._up_block(filters_base * 16, filters_base * 8)
self.att1 = self._attention_gate(filters_base * 8, filters_base * 8)
self.dec1 = self._conv_block(filters_base * 16, filters_base * 8)
self.up2 = self._up_block(filters_base * 8, filters_base * 4)
self.att2 = self._attention_gate(filters_base * 4, filters_base * 4)
self.dec2 = self._conv_block(filters_base * 8, filters_base * 4)
self.up3 = self._up_block(filters_base * 4, filters_base * 2)
self.att3 = self._attention_gate(filters_base * 2, filters_base * 2)
self.dec3 = self._conv_block(filters_base * 4, filters_base * 2)
self.up4 = self._up_block(filters_base * 2, filters_base)
self.att4 = self._attention_gate(filters_base, filters_base)
self.dec4 = self._conv_block(filters_base * 2, filters_base)
            # Output layer: 1x1 convolution producing num_classes channels
            self.out_conv = nn.Conv2d(filters_base, num_classes, kernel_size=1)
def _conv_block(self, in_channels, out_channels):
"""标准卷积块,包含Conv2D, BatchNorm, Activation。"""
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=self.kernel_size, padding=self.kernel_size//2),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=self.kernel_size, padding=self.kernel_size//2),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def _up_block(self, in_channels, out_channels):
"""上采样块,包含ConvTranspose2d和Dropout。"""
return nn.Sequential(
nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
nn.Dropout2d(self.dropout_rate)
)
        def _attention_gate(self, g_channels, x_channels):
            """Attention gate; returns a module producing a multiplicative mask."""
            if self.attention_mechanism == 'additive':
                # 'additive' variant: per-channel mask computed from the gating signal alone
return nn.Sequential(
nn.Conv2d(g_channels, x_channels, kernel_size=1),
nn.BatchNorm2d(x_channels),
nn.ReLU(inplace=True),
nn.Conv2d(x_channels, x_channels, kernel_size=1),
nn.BatchNorm2d(x_channels),
nn.Sigmoid()
)
            elif self.attention_mechanism == 'multiplicative':
                # 'multiplicative' variant: single-channel mask from the concatenated features
return nn.Sequential(
nn.Conv2d(g_channels + x_channels, x_channels, kernel_size=1),
nn.BatchNorm2d(x_channels),
nn.ReLU(inplace=True),
nn.Conv2d(x_channels, 1, kernel_size=1),
nn.BatchNorm2d(1),
nn.Sigmoid()
)
            else:
                raise ValueError(f"Unsupported attention mechanism: {self.attention_mechanism}")
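        # Note: the 'additive' branch above is a simplified gate that computes its
        # mask from the upsampled (gating) features alone; the canonical additive
        # attention gate of Oktay et al. (2018) instead forms
        # sigmoid(psi(relu(W_g * g + W_x * x))) from both the gating signal g and
        # the skip connection x before rescaling the skip features.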
        def forward(self, x):
            # Encoder path
enc1 = self.enc1(x)
enc1_pool = self.pool(enc1)
enc1_pool = self.dropout(enc1_pool)
enc2 = self.enc2(enc1_pool)
enc2_pool = self.pool(enc2)
enc2_pool = self.dropout(enc2_pool)
enc3 = self.enc3(enc2_pool)
enc3_pool = self.pool(enc3)
enc3_pool = self.dropout(enc3_pool)
enc4 = self.enc4(enc3_pool)
enc4_pool = self.pool(enc4)
enc4_pool = self.dropout(enc4_pool)
            # Bottleneck
bottleneck = self.bottleneck(enc4_pool)
            # Decoder path
            # Up block 1
up1 = self.up1(bottleneck)
if self.attention_mechanism == 'additive':
att1 = self.att1(up1)
att1 = att1 * enc4
else:
att1 = self.att1(torch.cat([up1, enc4], dim=1))
att1 = att1 * enc4
merge1 = torch.cat([up1, att1], dim=1)
dec1 = self.dec1(merge1)
            # Up block 2
up2 = self.up2(dec1)
if self.attention_mechanism == 'additive':
att2 = self.att2(up2)
att2 = att2 * enc3
else:
att2 = self.att2(torch.cat([up2, enc3], dim=1))
att2 = att2 * enc3
merge2 = torch.cat([up2, att2], dim=1)
dec2 = self.dec2(merge2)
            # Up block 3
up3 = self.up3(dec2)
if self.attention_mechanism == 'additive':
att3 = self.att3(up3)
att3 = att3 * enc2
else:
att3 = self.att3(torch.cat([up3, enc2], dim=1))
att3 = att3 * enc2
merge3 = torch.cat([up3, att3], dim=1)
dec3 = self.dec3(merge3)
            # Up block 4
up4 = self.up4(dec3)
if self.attention_mechanism == 'additive':
att4 = self.att4(up4)
att4 = att4 * enc1
else:
att4 = self.att4(torch.cat([up4, enc1], dim=1))
att4 = att4 * enc1
merge4 = torch.cat([up4, att4], dim=1)
dec4 = self.dec4(merge4)
            # Output layer
            out = self.out_conv(dec4)
            # Apply the configured output activation (anything else leaves raw logits)
if self.output_activation == 'softmax':
out = F.softmax(out, dim=1)
elif self.output_activation == 'sigmoid':
out = torch.sigmoid(out)
return out
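        # When trained with nn.CrossEntropyLoss the network should return raw
        # logits; build_model() below therefore overrides output_activation for
        # the cross-entropy path (CrossEntropyLoss applies log-softmax itself).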
    def _dice_loss(self, y_pred, y_true, smooth=1e-6):
        """Dice loss over the flattened batch (expects probabilities and one-hot targets)."""
        # Flatten predictions and targets
        y_pred_flat = y_pred.contiguous().view(-1)
        y_true_flat = y_true.contiguous().view(-1)
        # Intersection and combined mass
        intersection = (y_pred_flat * y_true_flat).sum()
        union = y_pred_flat.sum() + y_true_flat.sum()
        # Dice coefficient
        dice = (2. * intersection + smooth) / (union + smooth)
        # Dice loss
        return 1 - dice
    def _focal_loss(self, y_pred, y_true, alpha=0.25, gamma=2.0):
        """Focal loss (expects probabilities and one-hot targets)."""
        # Per-element binary cross-entropy
        ce_loss = F.binary_cross_entropy(y_pred, y_true, reduction='none')
        # Probability assigned to the true class
        p_t = y_true * y_pred + (1 - y_true) * (1 - y_pred)
        # Focal modulation: down-weight easy examples
        focal_loss = torch.pow(1 - p_t, gamma) * ce_loss
        # Optional alpha class weighting
        if alpha is not None:
            alpha_t = y_true * alpha + (1 - y_true) * (1 - alpha)
            focal_loss = alpha_t * focal_loss
        return focal_loss.mean()
    def _jaccard_loss(self, y_pred, y_true, smooth=1e-6):
        """Jaccard (IoU) loss over the flattened batch."""
        # Flatten predictions and targets
        y_pred_flat = y_pred.contiguous().view(-1)
        y_true_flat = y_true.contiguous().view(-1)
        # Intersection and union
        intersection = (y_pred_flat * y_true_flat).sum()
        total = y_pred_flat.sum() + y_true_flat.sum()
        union = total - intersection
        # Jaccard index (IoU)
        iou = (intersection + smooth) / (union + smooth)
        return 1 - iou
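    # The three custom losses above expect y_pred as probabilities in [0, 1] and
    # y_true as a one-hot tensor of the same shape; train() performs the one-hot
    # conversion before calling them.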
    def build_model(self):
        """Build the Attention U-Net model, optimizer and loss function."""
        input_channels = self.input_shape[2] if len(self.input_shape) == 3 else 3
        # nn.CrossEntropyLoss applies log-softmax internally, so the network must
        # emit raw logits in that case; applying softmax twice would be wrong.
        output_activation = self.output_activation
        if self.loss_type.lower() in ['cross_entropy', 'categorical_crossentropy']:
            output_activation = 'linear'
        self.model = self.AttentionUNetModel(
            input_channels=input_channels,
            num_classes=self.num_classes,
            filters_base=self.filters_base,
            kernel_size=self.kernel_size,
            dropout_rate=self.dropout_rate,
            attention_mechanism=self.attention_mechanism,
            output_activation=output_activation
        ).to(self.device)
        print("Attention U-Net model built successfully.")
        print(f"Trainable parameters: {sum(p.numel() for p in self.model.parameters() if p.requires_grad):,}")
        # Select the optimizer
        if self.optimizer_type.lower() == 'adam':
            optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)
        elif self.optimizer_type.lower() == 'sgd':
            optimizer = optim.SGD(self.model.parameters(), lr=self.learning_rate, momentum=0.9)
        elif self.optimizer_type.lower() == 'rmsprop':
            optimizer = optim.RMSprop(self.model.parameters(), lr=self.learning_rate)
        else:
            raise ValueError(f"Unsupported optimizer type: {self.optimizer_type}")
        # Select the loss function
        if self.loss_type.lower() == 'dice_loss':
            criterion = self._dice_loss
        elif self.loss_type.lower() == 'focal_loss':
            criterion = self._focal_loss
        elif self.loss_type.lower() == 'jaccard_loss':
            criterion = self._jaccard_loss
        elif self.loss_type.lower() in ['cross_entropy', 'categorical_crossentropy']:  # both names accepted
            criterion = nn.CrossEntropyLoss()
        else:
            raise ValueError(f"Unsupported loss type: {self.loss_type}")
        return self.model, optimizer, criterion
    class SegmentationDataset(Dataset):
        """Dataset for image segmentation (paired images and gray-value masks)."""
        def __init__(self, image_dir, mask_dir, input_shape, gray_to_index, num_classes):
            self.image_dir = image_dir
            self.mask_dir = mask_dir
            self.input_shape = input_shape
            self.gray_to_index = gray_to_index
            self.num_classes = num_classes
            self.image_files = sorted(glob(os.path.join(image_dir, '*')))
            self.mask_files = sorted(glob(os.path.join(mask_dir, '*')))
            if not self.image_files:
                raise FileNotFoundError(f"No image files found in {image_dir}")
            if not self.mask_files:
                raise FileNotFoundError(f"No mask files found in {mask_dir}")
            if len(self.image_files) != len(self.mask_files):
                print(f"Warning: number of images ({len(self.image_files)}) does not match number of masks ({len(self.mask_files)})")
            self.transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
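            # Note: the normalization constants above are the standard ImageNet
            # channel statistics; they are a common default, not values fitted
            # to this particular dataset.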
def __len__(self):
return min(len(self.image_files), len(self.mask_files))
        def __getitem__(self, idx):
            # Load the image
            img_path = self.image_files[idx]
            img = Image.open(img_path).convert('RGB')
            img = img.resize((self.input_shape[1], self.input_shape[0]))
            img = self.transform(img)
            # Load the mask (nearest-neighbour resize preserves the label values)
            mask_path = self.mask_files[idx]
            mask = Image.open(mask_path).convert('L')
            mask = mask.resize((self.input_shape[1], self.input_shape[0]), Image.NEAREST)
            mask = np.array(mask)
            # Convert gray values to class indices (integer mask)
            processed_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int64)
            for gray_val, class_idx in self.gray_to_index.items():
                processed_mask[mask == gray_val] = class_idx
            return img, processed_mask
    def _load_datasets(self):
        """Load the train, validation and test datasets."""
        # Dataset paths
        train_image_dir = os.path.join(self.database_path, 'train', 'images')
        train_mask_dir = os.path.join(self.database_path, 'train', 'masks')
        val_image_dir = os.path.join(self.database_path, 'val', 'images')
        val_mask_dir = os.path.join(self.database_path, 'val', 'masks')
        test_image_dir = os.path.join(self.database_path, 'test', 'images')
        test_mask_dir = os.path.join(self.database_path, 'test', 'masks')
        # Dataset instances
train_dataset = self.SegmentationDataset(
train_image_dir, train_mask_dir, self.input_shape, self.gray_to_index, self.num_classes
)
val_dataset = self.SegmentationDataset(
val_image_dir, val_mask_dir, self.input_shape, self.gray_to_index, self.num_classes
)
test_dataset = self.SegmentationDataset(
test_image_dir, test_mask_dir, self.input_shape, self.gray_to_index, self.num_classes
)
        # Data loaders
train_loader = DataLoader(
train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=4
)
val_loader = DataLoader(
val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=4
)
test_loader = DataLoader(
test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=4
)
        # Print dataset sizes
        print(f"\nDatasets loaded:")
        print(f"  train: {len(train_dataset)} samples")
        print(f"  val:   {len(val_dataset)} samples")
        print(f"  test:  {len(test_dataset)} samples")
return train_loader, val_loader, test_loader
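    # Expected layout under database_path:
    #   train/images  train/masks  val/images  val/masks  test/images  test/masks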
    def _calculate_metrics(self, y_true, y_pred, smooth=1e-6, verbose=True):
        """
        Compute segmentation metrics.
        Args:
            y_true (np.array): ground-truth labels (class indices).
            y_pred (np.array): predictions (per-class scores or probabilities).
            smooth (float): small constant to avoid division by zero.
            verbose (bool): whether to print the metrics in detail.
        Returns:
            dict: all computed metrics.
        """
        if y_true.size == 0 or y_pred.size == 0:
            print("Warning: empty labels or predictions; metrics cannot be computed.")
            return {}
        # Convert predictions to class indices
        y_pred_argmax = np.argmax(y_pred, axis=1)
        # Flatten for metric computation
        y_pred_flat = y_pred_argmax.flatten()
        y_true_flat = y_true.flatten()
        metrics = {}
        epsilon = 1e-7  # guards against division by zero
        # 1. Global Dice coefficient (all foreground classes merged vs. background)
        bg_index = self.gray_to_index[self.background_gray]
        # Merge every foreground class into a single "foreground" label
        y_true_fg_flat = (y_true_flat != bg_index)
        y_pred_fg_flat = (y_pred_flat != bg_index)
        # True positives: foreground predicted as foreground
        TP_global = np.sum(y_true_fg_flat & y_pred_fg_flat)
        # False positives: background predicted as foreground
        FP_global = np.sum(~y_true_fg_flat & y_pred_fg_flat)
        # False negatives: foreground predicted as background
        FN_global = np.sum(y_true_fg_flat & ~y_pred_fg_flat)
        # True negatives: background predicted as background
        TN_global = np.sum(~y_true_fg_flat & ~y_pred_fg_flat)
        dice_global = (2. * TP_global) / (2 * TP_global + FP_global + FN_global + epsilon)
        metrics['Dice_Global'] = dice_global
        accuracy_global = (TP_global + TN_global) / (TP_global + TN_global + FP_global + FN_global + epsilon)
        metrics['Accuracy_Global'] = accuracy_global
        if verbose:
            print(f"  Global Dice:     {metrics['Dice_Global']:.4f}")
            print(f"  Global accuracy: {metrics['Accuracy_Global']:.4f}")
        # 2. Per-class pixel metrics
        for gray_val, class_idx in self.gray_to_index.items():
            class_name = self.background_name if gray_val == self.background_gray else self.foreground_labels[gray_val]
            # Per-class TP, FP, FN, TN
            TP_class = np.sum((y_true_flat == class_idx) & (y_pred_flat == class_idx))
            FP_class = np.sum((y_true_flat != class_idx) & (y_pred_flat == class_idx))
            FN_class = np.sum((y_true_flat == class_idx) & (y_pred_flat != class_idx))
            TN_class = np.sum((y_true_flat != class_idx) & (y_pred_flat != class_idx))
            dice_class = (2. * TP_class) / (2 * TP_class + FP_class + FN_class + epsilon)
            metrics[f'{class_name}_Dice'] = dice_class
            accuracy_class = (TP_class + TN_class) / (TP_class + TN_class + FP_class + FN_class + epsilon)
            metrics[f'{class_name}_Accuracy'] = accuracy_class
            iou_class = TP_class / (TP_class + FP_class + FN_class + epsilon)
            metrics[f'{class_name}_IoU'] = iou_class
            precision_class = TP_class / (TP_class + FP_class + epsilon)
            metrics[f'{class_name}_Precision'] = precision_class
            recall_class = TP_class / (TP_class + FN_class + epsilon)
            metrics[f'{class_name}_Recall'] = recall_class
            specificity_class = TN_class / (TN_class + FP_class + epsilon)
            metrics[f'{class_name}_Specificity'] = specificity_class
            if verbose:
                print(f"\n  --- Class: {class_name} (gray value: {gray_val}, index: {class_idx}) ---")
                print(f"    Dice:        {metrics[f'{class_name}_Dice']:.4f}")
                print(f"    Accuracy:    {metrics[f'{class_name}_Accuracy']:.4f}")
                print(f"    IoU:         {metrics[f'{class_name}_IoU']:.4f}")
                print(f"    Precision:   {metrics[f'{class_name}_Precision']:.4f}")
                print(f"    Recall:      {metrics[f'{class_name}_Recall']:.4f}")
                print(f"    Specificity: {metrics[f'{class_name}_Specificity']:.4f}")
return metrics
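    # Reminder: per-class Dice and IoU are tied by Dice = 2*IoU / (1 + IoU), so
    # the Dice and IoU curves plotted later always move together.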
    def _select_best_model(self, all_metrics):
        """
        Selection rule for the best model (customizable).
        Args:
            all_metrics (list): per-epoch metric dictionaries.
        Returns:
            int: the epoch number of the best model.
        """
        # Default rule: pick the epoch with the highest global Dice coefficient
        best_epoch = -1
        best_metric_value = -1
        # Scan every epoch's validation metrics
        for metrics in all_metrics:
            # Use the global Dice coefficient as the selection criterion
            current_value = metrics.get('Dice_Global', -1)
            # Update the running best
            if current_value > best_metric_value:
                best_metric_value = current_value
                best_epoch = metrics['Epoch']
        print(f"\nBest model: epoch {best_epoch} (global Dice: {best_metric_value:.4f})")
        return best_epoch
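    # To change the selection rule, swap 'Dice_Global' for any other key produced
    # by _calculate_metrics (e.g. 'cup_IoU' with the example classes below), or
    # combine several keys into a weighted score.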
    def train(self):
        """
        Train the Attention U-Net model.
        """
        # Make sure every results directory exists
        self._create_result_directories()
        # Build the model
        model, optimizer, criterion = self.build_model()
        # Load the datasets
        train_loader, val_loader, test_loader = self._load_datasets()
        # Initialize TensorBoard
        writer = SummaryWriter(self.tensorboard_dir)
        # Bookkeeping
        best_val_loss = float('inf')
        epochs_without_improvement = 0   # drives early stopping
        epochs_since_lr_reduction = 0    # drives LR scheduling
        all_metrics = []
        # Make sure the log directory exists
        os.makedirs(os.path.dirname(self.train_log_path), exist_ok=True)
        # Build the CSV field names once; they are reused for the per-epoch appends
        fieldnames = ['Epoch', 'Train_Loss', 'Val_Loss', 'Dice_Global', 'Accuracy_Global']
        for gray_val in self.gray_to_index:
            class_name = self.background_name if gray_val == self.background_gray else self.foreground_labels[gray_val]
            fieldnames.extend([
                f'{class_name}_Dice',
                f'{class_name}_Accuracy',
                f'{class_name}_IoU',
                f'{class_name}_Precision',
                f'{class_name}_Recall',
                f'{class_name}_Specificity'
            ])
        self.csv_fieldnames = fieldnames
        # Write the header and CLOSE the file before the training loop starts.
        # Keeping the 'w' handle open while rows are appended through a second
        # 'a' handle leaves the buffered header to be flushed at offset 0 on
        # close, corrupting the first row -- the likely cause of the
        # KeyError: 'Epoch' later raised in _plot_performance_curves.
        with open(self.train_log_path, 'w', newline='', encoding='utf-8') as csvfile:
            writer_csv = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer_csv.writeheader()
        print("\nStarting training...")
        for epoch in range(1, self.epochs + 1):
            # Training phase
            model.train()
            train_loss = 0.0
            for images, masks in tqdm(train_loader, desc=f"Epoch {epoch}/{self.epochs} [train]"):
                images = images.to(self.device)
                masks = masks.to(self.device)
                # Forward pass
                outputs = model(images)
                # Compute the loss
                if isinstance(criterion, nn.CrossEntropyLoss):
                    # CrossEntropyLoss takes class-index targets directly
                    loss = criterion(outputs, masks.long())
                else:
                    # The custom losses need one-hot targets
                    masks_one_hot = F.one_hot(masks.long(), num_classes=self.num_classes).permute(0, 3, 1, 2).float()
                    loss = criterion(outputs, masks_one_hot)
                # Backward pass and optimizer step
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                train_loss += loss.item() * images.size(0)
            # Mean training loss
            train_loss = train_loss / len(train_loader.dataset)
            # Validation phase
            model.eval()
            val_loss = 0.0
            all_val_preds = []
            all_val_masks = []
            with torch.no_grad():
                for images, masks in tqdm(val_loader, desc=f"Epoch {epoch}/{self.epochs} [val]"):
                    images = images.to(self.device)
                    masks = masks.to(self.device)
                    # Forward pass
                    outputs = model(images)
                    # Compute the loss
                    if isinstance(criterion, nn.CrossEntropyLoss):
                        loss = criterion(outputs, masks.long())
                    else:
                        masks_one_hot = F.one_hot(masks.long(), num_classes=self.num_classes).permute(0, 3, 1, 2).float()
                        loss = criterion(outputs, masks_one_hot)
                    val_loss += loss.item() * images.size(0)
                    # Collect predictions and ground truth
                    all_val_preds.append(outputs.cpu().numpy())
                    all_val_masks.append(masks.cpu().numpy())
            # Mean validation loss
            val_loss = val_loss / len(val_loader.dataset)
            # Concatenate all validation predictions and labels
            val_preds = np.concatenate(all_val_preds, axis=0)
            val_masks = np.concatenate(all_val_masks, axis=0)
            # Validation metrics
            val_metrics = self._calculate_metrics(val_masks, val_preds, verbose=False)
            val_metrics['Epoch'] = epoch
            val_metrics['Train_Loss'] = train_loss
            val_metrics['Val_Loss'] = val_loss
            # Log to TensorBoard
            writer.add_scalar('Loss/Train', train_loss, epoch)
            writer.add_scalar('Loss/Validation', val_loss, epoch)
            writer.add_scalar('Metrics/Dice_Global', val_metrics['Dice_Global'], epoch)
            writer.add_scalar('Metrics/Accuracy_Global', val_metrics['Accuracy_Global'], epoch)
            # Keep the metrics for best-model selection
            all_metrics.append(val_metrics)
            # Make sure the temp directory exists before saving
            os.makedirs(self.temp_model_dir, exist_ok=True)
            # Save this epoch's checkpoint
            epoch_model_path = os.path.join(self.temp_model_dir, f'epoch_{epoch:03d}.pth')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'val_metrics': val_metrics
            }, epoch_model_path)
            print(f"Epoch {epoch:03d}: checkpoint saved to {epoch_model_path}")
            # Append this epoch's row to the CSV log (header was written and closed above)
            with open(self.train_log_path, 'a', newline='', encoding='utf-8') as csvfile:
                writer_csv = csv.DictWriter(csvfile, fieldnames=self.csv_fieldnames)
                writer_csv.writerow(val_metrics)
            # Progress report
            print(f"Epoch {epoch}/{self.epochs} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
                  f"Val Dice Global: {val_metrics['Dice_Global']:.4f}, "
                  f"Val Accuracy Global: {val_metrics['Accuracy_Global']:.4f}")
            # LR scheduling and early stopping. Two separate counters are kept:
            # resetting the early-stopping counter whenever the LR is reduced
            # (as the original code did) prevents early stopping from ever
            # firing when lr_reduction_patience < early_stopping_patience.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                epochs_without_improvement = 0
                epochs_since_lr_reduction = 0
            else:
                epochs_without_improvement += 1
                epochs_since_lr_reduction += 1
                if epochs_since_lr_reduction >= self.lr_reduction_patience:
                    # Reduce the learning rate
                    for param_group in optimizer.param_groups:
                        param_group['lr'] *= self.lr_reduction_factor
                    print(f"Epoch {epoch}: learning rate reduced to {optimizer.param_groups[0]['lr']:.2e}")
                    epochs_since_lr_reduction = 0
            # Early stopping check
            if epochs_without_improvement >= self.early_stopping_patience:
                print(f"Epoch {epoch}: early stopping triggered")
                break
        # Close the TensorBoard writer
        writer.close()
        # Post-training: pick the best model, plot curves, evaluate on the test set
        self._select_and_save_best_model(all_metrics)
        self._plot_performance_curves()
        self._test_best_model(test_loader)
    def _select_and_save_best_model(self, all_metrics):
        """Select the best checkpoint and copy it to its final location."""
        best_epoch = self._select_best_model(all_metrics)
        if best_epoch > 0:
            # Locate the best checkpoint
            best_model_path = os.path.join(self.temp_model_dir, f'epoch_{best_epoch:03d}.pth')
            if os.path.exists(best_model_path):
                # Make sure the target directory exists
                os.makedirs(os.path.dirname(self.best_model_path), exist_ok=True)
                # Copy to the final location
                shutil.copyfile(best_model_path, self.best_model_path)
                print(f"Best model (epoch {best_epoch}) saved to: {self.best_model_path}")
            else:
                print(f"Warning: checkpoint file for the best epoch ({best_epoch}) not found")
        else:
            print("Warning: no valid best model found")
        # Record the best model's validation performance
        if best_epoch > 0:
            best_metrics = next(m for m in all_metrics if m['Epoch'] == best_epoch)
            with open(self.best_model_performance_path, 'w', encoding='utf-8') as f:
                f.write("Attention U-Net best-model validation performance\n")
                f.write(f"Saved at: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
                f.write(f"Best model path: {self.best_model_path}\n")
                f.write(f"Epoch: {best_epoch}\n")
                f.write("-" * 50 + "\n")
                for key, value in best_metrics.items():
                    if isinstance(value, float):
                        f.write(f"{key}: {value:.4f}\n")
                    else:
                        f.write(f"{key}: {value}\n")
            print(f"Best-model validation metrics written to '{self.best_model_performance_path}'")
        # Remove the temporary per-epoch checkpoints
        if os.path.exists(self.temp_model_dir):
            print(f"Removing temporary checkpoint directory: {self.temp_model_dir}")
            shutil.rmtree(self.temp_model_dir)
    def _test_best_model(self, test_loader):
        if not os.path.exists(self.best_model_path):
            print(f"Error: best model file not found: {self.best_model_path}. Make sure training finished and the model was saved.")
            return
        print("\nLoading the best model for testing...")
        try:
            # The checkpoint dict contains NumPy values, which newer PyTorch
            # versions refuse to unpickle under the weights_only=True default.
            # This checkpoint was written by this very script, so it is trusted
            # and weights_only=False is safe here. (The previous workaround,
            # `from numpy import scalar`, raises ImportError: numpy exposes no
            # top-level `scalar`.)
            checkpoint = torch.load(
                self.best_model_path,
                map_location=self.device,
                weights_only=False
            )
            # Rebuild the model architecture
            input_channels = self.input_shape[2] if len(self.input_shape) == 3 else 3
            model = self.AttentionUNetModel(
                input_channels=input_channels,
                num_classes=self.num_classes,
                filters_base=self.filters_base,
                kernel_size=self.kernel_size,
                dropout_rate=self.dropout_rate,
                attention_mechanism=self.attention_mechanism,
                output_activation=self.output_activation
            ).to(self.device)
            # Load the weights
            model.load_state_dict(checkpoint['model_state_dict'])
            model.eval()
        except Exception as e:
            print(f"Error: could not load the best model {self.best_model_path}. Error: {e}")
            return
        # Run the test set
        all_test_preds = []
        all_test_masks = []
        with torch.no_grad():
            for images, masks in tqdm(test_loader, desc="Testing best model"):
                images = images.to(self.device)
                masks = masks.to(self.device)
                outputs = model(images)
                all_test_preds.append(outputs.cpu().numpy())
                all_test_masks.append(masks.cpu().numpy())
        # Concatenate results
        test_preds = np.concatenate(all_test_preds, axis=0)
        test_masks = np.concatenate(all_test_masks, axis=0)
        # Compute metrics
        test_metrics = self._calculate_metrics(test_masks, test_preds, verbose=True)
        # Append the test performance to the report
        with open(self.best_model_performance_path, 'a', encoding='utf-8') as f:
            f.write("\n" + "=" * 50 + "\n")
            f.write("Test-set performance:\n")
            for key, value in test_metrics.items():
                if isinstance(value, float):
                    f.write(f"{key}: {value:.4f}\n")
                else:
                    f.write(f"{key}: {value}\n")
        print(f"Best-model test metrics appended to '{self.best_model_performance_path}'")
    def _plot_performance_curves(self):
        """
        Plot metric-vs-epoch curves from the training log.
        """
        print("\nGenerating performance plots...")
        if not os.path.exists(self.train_log_path):
            print(f"Error: training log not found: {self.train_log_path}; cannot plot curves.")
            return
        # Make sure the output directory exists
        os.makedirs(self.performance_plots_dir, exist_ok=True)
        # Read the data back from the CSV log
        epochs = []
        data = {}
        # 'utf-8-sig' tolerates a BOM; validate the header before indexing rows
        with open(self.train_log_path, 'r', encoding='utf-8-sig') as csvfile:
            reader = csv.DictReader(csvfile)
            if reader.fieldnames is None or 'Epoch' not in reader.fieldnames:
                print(f"Error: training log has no 'Epoch' column (header: {reader.fieldnames}); "
                      f"the log file is corrupt or from an incompatible run.")
                return
            for row in reader:
                epochs.append(int(row['Epoch']))
                for key, value in row.items():
                    if key == 'Epoch':
                        continue
                    if key not in data:
                        data[key] = []
                    try:
                        data[key].append(float(value))
                    except (TypeError, ValueError):
                        data[key].append(0.0)
        if not epochs:
            print("The training log has no usable rows; nothing to plot.")
            return
        # Collect all class names (background and foreground)
        all_classes = set()
        known_names = [self.background_name] + list(self.foreground_labels.values())
        for key in data.keys():
            # Per-class keys look like '<class>_<Metric>'; strip the metric suffix
            if '_' in key and not key.startswith(('Val_', 'Train_')):
                class_name = key.rsplit('_', 1)[0]
                if class_name in known_names:
                    all_classes.add(class_name)
        # 1. Loss curves
        plt.figure(figsize=(10, 6))
        plt.plot(epochs, data['Train_Loss'], label='Training loss', color='blue', linewidth=2)
        plt.plot(epochs, data['Val_Loss'], label='Validation loss', color='red', linewidth=2)
        plt.title('Training and Validation Loss vs. Epoch', fontsize=14)
        plt.xlabel('Epoch', fontsize=12)
        plt.ylabel('Loss', fontsize=12)
        plt.xticks(range(min(epochs), max(epochs)+1, max(1, len(epochs)//10)))
        plt.legend(fontsize=12)
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(os.path.join(self.performance_plots_dir, 'loss_vs_epoch.png'))
        plt.close()
        # 2. Global Dice and accuracy curves
        plt.figure(figsize=(10, 6))
        plt.plot(epochs, data['Dice_Global'], label='Global Dice', color='blue', linewidth=2)
        plt.plot(epochs, data['Accuracy_Global'], label='Global accuracy', color='red', linewidth=2)
        plt.title('Global Dice and Accuracy vs. Epoch', fontsize=14)
        plt.xlabel('Epoch', fontsize=12)
        plt.ylabel('Value', fontsize=12)
        plt.xticks(range(min(epochs), max(epochs)+1, max(1, len(epochs)//10)))
        plt.legend(fontsize=12)
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(os.path.join(self.performance_plots_dir, 'global_dice_accuracy_vs_epoch.png'))
        plt.close()
        # 3. Per-class IoU curves
        plt.figure(figsize=(10, 6))
        for class_name in all_classes:
            plt.plot(epochs, data[f'{class_name}_IoU'], label=f'{class_name} IoU', linewidth=2)
        plt.title('Per-class IoU vs. Epoch', fontsize=14)
        plt.xlabel('Epoch', fontsize=12)
        plt.ylabel('IoU', fontsize=12)
        plt.xticks(range(min(epochs), max(epochs)+1, max(1, len(epochs)//10)))
        plt.legend(fontsize=12)
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(os.path.join(self.performance_plots_dir, 'per_class_iou_vs_epoch.png'))
        plt.close()
        # 4. Per-class precision curves
        plt.figure(figsize=(10, 6))
        for class_name in all_classes:
            plt.plot(epochs, data[f'{class_name}_Precision'], label=f'{class_name} precision', linewidth=2)
        plt.title('Per-class Precision vs. Epoch', fontsize=14)
        plt.xlabel('Epoch', fontsize=12)
        plt.ylabel('Precision', fontsize=12)
        plt.xticks(range(min(epochs), max(epochs)+1, max(1, len(epochs)//10)))
        plt.legend(fontsize=12)
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(os.path.join(self.performance_plots_dir, 'per_class_precision_vs_epoch.png'))
        plt.close()
        # 5. Per-class recall curves
        plt.figure(figsize=(10, 6))
        for class_name in all_classes:
            plt.plot(epochs, data[f'{class_name}_Recall'], label=f'{class_name} recall', linewidth=2)
        plt.title('Per-class Recall vs. Epoch', fontsize=14)
        plt.xlabel('Epoch', fontsize=12)
        plt.ylabel('Recall', fontsize=12)
        plt.xticks(range(min(epochs), max(epochs)+1, max(1, len(epochs)//10)))
        plt.legend(fontsize=12)
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(os.path.join(self.performance_plots_dir, 'per_class_recall_vs_epoch.png'))
        plt.close()
        # 6. Per-class specificity curves
        plt.figure(figsize=(10, 6))
        for class_name in all_classes:
            plt.plot(epochs, data[f'{class_name}_Specificity'], label=f'{class_name} specificity', linewidth=2)
        plt.title('Per-class Specificity vs. Epoch', fontsize=14)
        plt.xlabel('Epoch', fontsize=12)
        plt.ylabel('Specificity', fontsize=12)
        plt.xticks(range(min(epochs), max(epochs)+1, max(1, len(epochs)//10)))
        plt.legend(fontsize=12)
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(os.path.join(self.performance_plots_dir, 'per_class_specificity_vs_epoch.png'))
        plt.close()
        # 7. Per-class accuracy curves
        plt.figure(figsize=(10, 6))
        for class_name in all_classes:
            plt.plot(epochs, data[f'{class_name}_Accuracy'], label=f'{class_name} accuracy', linewidth=2)
        plt.title('Per-class Accuracy vs. Epoch', fontsize=14)
        plt.xlabel('Epoch', fontsize=12)
        plt.ylabel('Accuracy', fontsize=12)
        plt.xticks(range(min(epochs), max(epochs)+1, max(1, len(epochs)//10)))
        plt.legend(fontsize=12)
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(os.path.join(self.performance_plots_dir, 'per_class_accuracy_vs_epoch.png'))
        plt.close()
        # 8. Per-class Dice curves
        plt.figure(figsize=(10, 6))
        for class_name in all_classes:
            plt.plot(epochs, data[f'{class_name}_Dice'], label=f'{class_name} Dice', linewidth=2)
        plt.title('Per-class Dice vs. Epoch', fontsize=14)
        plt.xlabel('Epoch', fontsize=12)
        plt.ylabel('Dice', fontsize=12)
        plt.xticks(range(min(epochs), max(epochs)+1, max(1, len(epochs)//10)))
        plt.legend(fontsize=12)
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(os.path.join(self.performance_plots_dir, 'per_class_dice_vs_epoch.png'))
        plt.close()
        # 9. Dice and IoU per class (one figure per class)
        for class_name in all_classes:
            plt.figure(figsize=(10, 6))
            plt.plot(epochs, data[f'{class_name}_Dice'], label=f'{class_name} Dice', color='blue', linewidth=2)
            plt.plot(epochs, data[f'{class_name}_IoU'], label=f'{class_name} IoU', color='red', linewidth=2)
            plt.title(f'{class_name}: Dice and IoU vs. Epoch', fontsize=14)
            plt.xlabel('Epoch', fontsize=12)
            plt.ylabel('Value', fontsize=12)
            plt.xticks(range(min(epochs), max(epochs)+1, max(1, len(epochs)//10)))
            plt.legend(fontsize=12)
            plt.grid(True)
            plt.tight_layout()
            plt.savefig(os.path.join(self.performance_plots_dir, f'{class_name}_dice_iou_vs_epoch.png'))
            plt.close()
        # 10. Precision and recall per class (one figure per class)
        for class_name in all_classes:
            plt.figure(figsize=(10, 6))
            plt.plot(epochs, data[f'{class_name}_Precision'], label=f'{class_name} precision', color='blue', linewidth=2)
            plt.plot(epochs, data[f'{class_name}_Recall'], label=f'{class_name} recall', color='red', linewidth=2)
            plt.title(f'{class_name}: Precision and Recall vs. Epoch', fontsize=14)
            plt.xlabel('Epoch', fontsize=12)
            plt.ylabel('Value', fontsize=12)
            plt.xticks(range(min(epochs), max(epochs)+1, max(1, len(epochs)//10)))
            plt.legend(fontsize=12)
            plt.grid(True)
            plt.tight_layout()
            plt.savefig(os.path.join(self.performance_plots_dir, f'{class_name}_precision_recall_vs_epoch.png'))
            plt.close()
        print(f"Performance plots saved to '{self.performance_plots_dir}'")
        print(f"Generated {8 + 2 * len(all_classes)} figures in total")
        print("\nAll tasks completed!")
#------------------------------------- Example usage --------------------------------------------
# Basic model parameters
input_shape = (512, 512, 3)
classes = {
    'class_number': 3,
    'bg': (2, 'background'),
    'fg': {0: 'cup', 1: 'disc'}
}
epochs = 5
result_base_path = '/content/drive/MyDrive/results'
database_base_path = '/content/drive/MyDrive/database'
# Instantiate the model class
unet_model = AttentionUNet(
    input_shape=input_shape,
    classes=classes,
    epochs=epochs,
    result_path=result_base_path,
    database_path=database_base_path,
    learning_rate=1e-4,
    batch_size=5,
    early_stopping_patience=3,
    lr_reduction_patience=2,
    dropout_rate=0.5,
    filters_base=16,
    attention_mechanism='additive',
    loss='cross_entropy'  # a loss the PyTorch path supports natively
)
# Build and train the model
unet_model.train()
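# After training completes, the best checkpoint is at <result_path>/models/best_model.pth,
# the CSV log at <result_path>/diagram/training_log.csv, and the curves under
# <result_path>/diagram/performance_plots/.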
#------------------------------------------------------------------------------------------
Running the code above produced the following error at runtime:
Generating performance plots...
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
<ipython-input-8-2018641610> in <cell line: 0>()
1143
1144 # 构建并训练模型
-> 1145 unet_model.train()
1146 #------------------------------------------------------------------------------------------
1 frames
<ipython-input-8-2018641610> in _plot_performance_curves(self)
941 reader = csv.DictReader(csvfile)
942 for row in reader:
--> 943 epochs.append(int(row['Epoch']))
944 for key, value in row.items():
945 if key == 'Epoch':
KeyError: 'Epoch'
Please fix this.