Explain this code in detail:
import os
import json
import torch
import numpy as np
from PIL import Image, ImageDraw
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
import matplotlib.pyplot as plt
# ========== Path configuration ==========
data_dir = r'C:\Users\Administrator\Desktop\vessel'
test_files = ['Gentry_771_xa_010.tif.jpg', 'Gymnostoma nobile 0895.JPG.jpg']
# ========== Custom dataset class ==========
class SegmentationDataset(Dataset):
    def __init__(self, image_paths, img_transform=None, mask_transform=None):
        # Hold out the two named test images from training
        self.image_paths = [p for p in image_paths if os.path.basename(p) not in test_files]
        self.img_transform = img_transform
        self.mask_transform = mask_transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')
        # Each image is paired with a same-named .json annotation file
        json_path = os.path.splitext(img_path)[0] + '.json'
        with open(json_path) as f:
            annotation = json.load(f)
        # Rasterize the annotated polygons into a binary mask
        mask = Image.new('L', image.size, 0)
        for shape in annotation['shapes']:
            points = [(p[0], p[1]) for p in shape['points']]
            ImageDraw.Draw(mask).polygon(points, outline=255, fill=255)
        if self.img_transform:
            image = self.img_transform(image)
        if self.mask_transform:
            mask = self.mask_transform(mask)
            # Binarize only after ToTensor has rescaled values to [0, 1];
            # the comparison would fail on a raw PIL image
            mask = (mask > 0).float()
        return image, mask
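# For reference, the loader above expects a LabelMe-style JSON file next to
# each image. A minimal, hypothetical example of the assumed structure:
#
# {
#     "shapes": [
#         {"label": "vessel", "points": [[10.0, 12.0], [48.0, 15.0], [30.0, 44.0]]}
#     ]
# }
#
# Only the "shapes" list and each shape's "points" ([x, y] pairs in pixel
# coordinates) are actually read by __getitem__.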
# ========== QKV attention module ==========
class QKVAttention(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        # Project Q and K into a reduced channel space to cut compute
        self.inter_channels = in_channels // 8
        self.query = nn.Conv2d(in_channels, self.inter_channels, kernel_size=1)
        self.key = nn.Conv2d(in_channels, self.inter_channels, kernel_size=1)
        self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        # Learnable residual scale, initialized to 0 so the block starts as identity
        self.gamma = nn.Parameter(torch.zeros(1))
    def forward(self, x):
        batch_size, C, H, W = x.size()
        # Generate Q, K, V; flatten the spatial dims to length HW
        q = self.query(x).view(batch_size, self.inter_channels, -1).permute(0, 2, 1)
        k = self.key(x).view(batch_size, self.inter_channels, -1)
        v = self.value(x).view(batch_size, C, -1)
        # Compute attention scores between all spatial positions
        attn = torch.bmm(q, k)  # [batch, HW, HW]
        attn = torch.softmax(attn, dim=-1)
        # Apply attention to the values and reshape back to a feature map
        out = torch.bmm(v, attn.permute(0, 2, 1))
        out = out.view(batch_size, C, H, W)
        return self.gamma * out + x  # gated residual connection
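# Quick shape check (illustrative): output shape always matches the input, and
# because gamma starts at zero the block is an identity mapping at initialization.
_attn_check = QKVAttention(512)
_x = torch.randn(2, 512, 8, 8)
assert _attn_check(_x).shape == (2, 512, 8, 8)
assert torch.allclose(_attn_check(_x), _x)  # holds only before training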
# ========== ResNet-based FCN model ==========
class ResNetFCN(nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()
        # Load ResNet-18; the `weights` API replaces the `pretrained` flag
        # deprecated since torchvision 0.13
        resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT if pretrained else None)
        # Encoder stages reuse the ResNet-18 backbone
        self.encoder1 = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool
        )
        self.encoder2 = resnet.layer1
        self.encoder3 = resnet.layer2
        self.encoder4 = resnet.layer3
        self.encoder5 = resnet.layer4
        # Attention module (applied at the bottleneck)
        self.attention = QKVAttention(512)
        # Decoder
        self.up1 = self._up_block(512, 256)
        self.up2 = self._up_block(256, 128)
        self.up3 = self._up_block(128, 64)
        self.up4 = self._up_block(64, 64)
        # Final 1x1 convolution producing a single-channel mask
        self.final_conv = nn.Conv2d(64, 1, kernel_size=1)

    def _up_block(self, in_channels, out_channels):
        # Conv + ReLU followed by 2x bilinear upsampling
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        )
    def forward(self, x):
        # Encoding path (shapes given as C, H, W for a 256x256 input)
        e1 = self.encoder1(x)   # 64, 64, 64
        e2 = self.encoder2(e1)  # 64, 64, 64
        e3 = self.encoder3(e2)  # 128, 32, 32
        e4 = self.encoder4(e3)  # 256, 16, 16
        e5 = self.encoder5(e4)  # 512, 8, 8
        # Apply QKV attention at the bottleneck
        attn = self.attention(e5)
        # Decoding path
        d1 = self.up1(attn)  # 256, 16, 16
        d2 = self.up2(d1)    # 128, 32, 32
        d3 = self.up3(d2)    # 64, 64, 64
        d4 = self.up4(d3)    # 64, 128, 128
        # Final upsampling back to 256x256
        out = nn.functional.interpolate(d4, size=(256, 256), mode='bilinear', align_corners=True)
        out = self.final_conv(out)
        return torch.sigmoid(out)
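# Quick end-to-end shape check (illustrative; pretrained=False avoids a
# weights download just for the check): a 256x256 RGB batch should yield a
# single-channel 256x256 probability map.
_fcn_check = ResNetFCN(pretrained=False)
assert _fcn_check(torch.randn(1, 3, 256, 256)).shape == (1, 1, 256, 256)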
# ========== Training configuration ==========
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Image preprocessing (resize plus ImageNet normalization)
img_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Masks use nearest-neighbor resizing so labels stay binary
mask_transform = transforms.Compose([
    transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.ToTensor(),
])
# Collect all image paths (the dataset class filters out the test files)
all_images = [os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith('.jpg')]
# Build the dataset and training loader
dataset = SegmentationDataset(all_images, img_transform=img_transform, mask_transform=mask_transform)
train_loader = DataLoader(dataset, batch_size=4, shuffle=True)
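# Optional tweak: on machines with spare CPU cores, loading can be
# parallelized (values below are illustrative, not from the original script):
# train_loader = DataLoader(dataset, batch_size=4, shuffle=True,
#                           num_workers=4, pin_memory=True)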
# Initialize the model, loss, and optimizer
model = ResNetFCN(pretrained=True).to(device)
criterion = nn.BCELoss()  # matches the sigmoid output of the model
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)  # weight decay to curb overfitting
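# Note: a common, more numerically stable alternative (not what this script
# does) is to return raw logits from the model and train with:
# criterion = nn.BCEWithLogitsLoss()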
# ========== Training loop ==========
train_losses = []
for epoch in range(50):  # total number of training epochs
    model.train()
    epoch_loss = 0.0
    for images, masks in train_loader:
        images = images.to(device)
        masks = masks.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        # Weight by batch size so the epoch figure is a per-sample average
        epoch_loss += loss.item() * images.size(0)
    epoch_loss /= len(dataset)
    train_losses.append(epoch_loss)
    print(f'Epoch {epoch + 1}, Loss: {epoch_loss:.4f}')
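# Persist the trained weights so inference can be rerun without retraining
# (the filename is illustrative):
torch.save(model.state_dict(), 'resnet_fcn_vessel.pth')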
# Plot the training loss curve
plt.figure(figsize=(10, 6))
plt.plot(range(1, len(train_losses)+1), train_losses, 'b-o')
plt.title('Training Loss Curve')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True)
plt.savefig('loss_curve.png')
plt.show()
# ========== Model prediction ==========
model.eval()
# Test-time preprocessing (identical to the training transform)
test_transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Denormalization helper to undo the ImageNet normalization for display
def denormalize(tensor):
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
    return tensor * std + mean
for test_file in test_files:
    img_path = os.path.join(data_dir, test_file)
    image = Image.open(img_path).convert('RGB')
    image_tensor = test_transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        pred_mask = model(image_tensor).cpu().squeeze().numpy()
    # Denormalize the input tensor so the original image displays correctly
    orig_img = denormalize(image_tensor.cpu()).squeeze(0).permute(1, 2, 0).numpy()
    orig_img = np.clip(orig_img, 0, 1)
    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.imshow(orig_img)
    plt.title('Original Image')
    plt.subplot(1, 2, 2)
    plt.imshow(pred_mask > 0.5, cmap='gray')  # threshold probabilities at 0.5
    plt.title('Predicted Mask')
    plt.savefig(f'prediction_{test_file}.png')
    plt.show()
# ========== Model evaluation ==========
print("Model evaluation:")
print(f"Final training loss: {train_losses[-1]:.4f}")
print("Overfitting/underfitting analysis:")
if train_losses[-1] < 0.1:
    print("  - The model may be overfitting: training loss is very low, but with only 2 test images generalization cannot be verified")
    print("  - Possible remedies: add data augmentation, insert Dropout layers, or use a smaller model")
elif train_losses[-1] > 0.3:
    print("  - The model may be underfitting: training loss is still high")
    print("  - Possible remedies: train for more epochs, increase model capacity, or tune the learning rate")
else:
    print("  - Training loss is in a reasonable range, but the tiny test set makes reliable assessment difficult")