import os
import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset
import cv2
import segmentation_models_pytorch as smp
from sklearn.metrics import jaccard_score, f1_score, precision_score, recall_score, average_precision_score
import csv  # added: used to write training metrics to CSV
# Configuration (unchanged)
DATA_PATH = "dataset100_alb"
BATCH_SIZE = 20
EPOCHS = 300
LR = 0.0001
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Custom dataset (unchanged)
class SegmentationDataset(Dataset):
    # __init__, __len__ and __getitem__ are the same as in the original code
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.images = sorted(os.listdir(image_dir))
        self.masks = sorted(os.listdir(mask_dir))
        self.transform = transform
    def __len__(self):
        return len(self.images)
    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir, self.masks[idx])
        # # Print the path of the image being read (useful debugging step)
        # print(f"Reading image {idx}: {img_path}")
        # Read the image and check that it loaded successfully
        image = cv2.imread(img_path)
        if image is None:
            # Raise with the path attached so the failing file can be located immediately
            raise RuntimeError(f"Failed to read image: {img_path}. Check that the file exists, is a supported format, and is not corrupted.")
        image = image.astype(np.float32) / 255.0
        # Same check for the mask
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise RuntimeError(f"Failed to read mask: {mask_path}")
        mask = mask.astype(np.float32) / 1.0  # masks are assumed to already hold 0/1 values
        image = torch.from_numpy(image).permute(2, 0, 1)
        mask = torch.from_numpy(mask).unsqueeze(0)
        # Further transforms could go here...
        return image, mask
# Model initialization (unchanged)
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights=None,
    in_channels=3,
    classes=1,
    activation="sigmoid"
).to(DEVICE)
state_dict = torch.load("resnet34-333f7ec4.pth")
model.encoder.load_state_dict(state_dict)
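# Note (an alternative, not the original author's setup): if the local
# resnet34-333f7ec4.pth file is unavailable, segmentation_models_pytorch can
# download the same ImageNet encoder weights itself, e.g.:
# model = smp.Unet(encoder_name="resnet34", encoder_weights="imagenet",
#                  in_channels=3, classes=1, activation="sigmoid").to(DEVICE)
# (this assumes internet access on the training machine).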
# Data preparation (unchanged)
train_dataset = SegmentationDataset(
    os.path.join(DATA_PATH, "train/images"),
    os.path.join(DATA_PATH, "train/masks")
)
val_dataset = SegmentationDataset(
    os.path.join(DATA_PATH, "val/images"),
    os.path.join(DATA_PATH, "val/masks")
)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
# Loss function and optimizer (unchanged)
criterion = smp.losses.DiceLoss(mode="binary")
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=4, verbose=True)
# Modified evaluation function: adds precision, recall and mAP50
def evaluate_model():
    model.eval()
    total_loss = 0.0
    all_preds = []  # binarized predictions (0/1)
    all_masks = []  # ground-truth labels (0/1)
    all_probs = []  # raw model probabilities (0-1)
    with torch.no_grad():
        for images, masks in val_loader:
            images, masks = images.to(DEVICE), masks.to(DEVICE)
            outputs = model(images)
            loss = criterion(outputs, masks)
            total_loss += loss.item()
            # Collect probabilities (used for the mAP estimate)
            probs = outputs.cpu().numpy().flatten()  # flatten to a 1-D array
            all_probs.extend(probs)
            # Binarize predictions at a 0.5 threshold
            preds = (outputs > 0.5).float().cpu().numpy().flatten()
            masks_np = masks.cpu().numpy().flatten()  # flatten ground truth
            all_preds.extend(preds)
            all_masks.extend(masks_np)
    # Compute metrics
    iou = jaccard_score(all_masks, all_preds, average="macro")
    f1 = f1_score(all_masks, all_preds, average="macro")
    precision = precision_score(all_masks, all_preds, average="binary")  # binary precision
    recall = recall_score(all_masks, all_preds, average="binary")  # binary recall
    # Pixel-level average precision, used here only as an approximation of mAP50
    map50 = average_precision_score(all_masks, all_probs)
    return total_loss / len(val_loader), iou, f1, precision, recall, map50
# Modified training function: adds CSV logging
def train_model():
    best_iou = 0.0
    # Create the CSV file and write the header row
    csv_path = "training_metrics.csv"
    with open(csv_path, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(["epoch", "train_loss", "val_loss", "iou", "f1", "precision", "recall", "map50", "lr"])
    for epoch in range(EPOCHS):
        model.train()
        running_loss = 0.0
        for images, masks in train_loader:
            images, masks = images.to(DEVICE), masks.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        # Validate and collect the new metrics
        val_loss, iou, f1, precision, recall, map50 = evaluate_model()
        scheduler.step(iou)
        current_lr = optimizer.param_groups[0]['lr']
        # Print the log line (including the new metrics)
        print(f"Epoch {epoch+1}/{EPOCHS} | "
              f"Train Loss: {running_loss/len(train_loader):.4f} | "
              f"Val Loss: {val_loss:.4f} | "
              f"IoU: {iou:.4f} | F1: {f1:.4f} | "
              f"Precision: {precision:.4f} | Recall: {recall:.4f} | "
              f"mAP50: {map50:.4f} | LR: {current_lr:.6f}")
        # Append the metrics to the CSV file
        with open(csv_path, 'a', newline='') as f:
            writer = csv.writer(f)
            writer.writerow([epoch + 1,
                             running_loss / len(train_loader),
                             val_loss,
                             iou,
                             f1,
                             precision,
                             recall,
                             map50,
                             current_lr])
        # Save the best model (unchanged)
        if iou > best_iou:
            torch.save(model.state_dict(), "best_model.pth")
            best_iou = iou
# Prediction function (unchanged)
def predict(image_path):
    # Same as in the original code
    model.load_state_dict(torch.load("best_model.pth"))
    model.eval()
    image = cv2.imread(image_path)
    if image is None:
        raise RuntimeError(f"Failed to read image: {image_path}")
    image = image.astype(np.float32) / 255.0
    image_tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        output = model(image_tensor)
        prediction = (output > 0.5).float().cpu().numpy()[0, 0]
    return prediction
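# Example usage of predict (the image path below is hypothetical, shown only to
# illustrate the call; the binary mask is scaled to 0/255 for saving):
# pred = predict("dataset100_alb/val/images/sample.png")
# cv2.imwrite("prediction.png", (pred * 255).astype(np.uint8))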
# Run training (unchanged)
if __name__ == "__main__":
    train_model()
Add a timing feature to this code.
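# A minimal sketch of the requested timing feature (an addition, not the original
# author's code). The least invasive option is to time the whole run with
# time.perf_counter(); per-epoch timing would instead need a perf_counter() call at
# the top of the epoch loop inside train_model() and a print of the difference at
# the bottom of that loop. The function name run_with_timing is an assumption.
import time

def run_with_timing():
    start = time.perf_counter()
    train_model()  # the training function defined above
    elapsed = time.perf_counter() - start
    print(f"Total training time: {elapsed:.1f} s ({elapsed/60:.2f} min)")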