# Reviewer note: stray prose was fused with the first import ("You are a
# programmer; please check the following code for logic problems").
# Commented out so the file parses.
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import matplotlib.pyplot as plt
import itertools
# ================= Configuration =================
# Dataset roots: one folder per age group, each containing subject folders.
TRAIN_ROOT = "MRI_END/train"
TEST_ROOT = "MRI_END/test"
# MRI modalities, matched against the filename prefix (case-insensitive).
MODALITIES = ['PD', 'T1', 'T2']
# Per-modality normalisation statistics for the single grayscale channel.
# NOTE(review): presumably precomputed over the training set — confirm.
MODALITY_STATS = {
    'PD': {'mean': [0.1138], 'std': [0.1147]},
    'T1': {'mean': [0.1632], 'std': [0.1887]},
    'T2': {'mean': [0.1082], 'std': [0.1121]}
}
# ================= Preprocessing =================
class MedicalTransform:
    """Phase-dependent preprocessing for single-channel (grayscale) MRI slices.

    Normalisation happens exactly once, in __call__, using the per-modality
    statistics from MODALITY_STATS — it must NOT also appear in self.transform.
    """

    def __init__(self, phase='train'):
        if phase == 'train':
            # Augmentation pipeline for training.
            # BUG FIX: the original pipeline ended with
            # Normalize(mean=stats['mean'], ...) where `stats` was undefined
            # (NameError at construction), and __call__ normalised a second
            # time anyway.  Normalisation is now done only in __call__.
            self.transform = transforms.Compose([
                transforms.Resize((128, 128)),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomRotation(30),
                # saturation jitter is effectively a no-op on grayscale input
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
                transforms.ToTensor(),
                transforms.RandomErasing(p=0.3, scale=(0.02, 0.1)),  # simulate occlusion
            ])
        else:
            # Deterministic eval pipeline: resize then centre-crop to 128.
            self.transform = transforms.Compose([
                transforms.Resize(128 + 32),
                transforms.CenterCrop(128),
                transforms.ToTensor(),
            ])

    def __call__(self, img, modality):
        """Transform a PIL image and normalise it with its modality statistics."""
        img = self.transform(img)
        stats = MODALITY_STATS[modality]
        return transforms.Normalize(mean=stats['mean'], std=stats['std'])(img)
# ================= Dataset =================
class KneeMRIDataset(Dataset):
    """Paired-modality knee MRI dataset.

    Expected layout: <base_path>/<age_folder>/<subject_folder>/<MOD_*> files,
    where the age folder name ends in '_<age>' and the subject folder starts
    with 'M_' or 'F_'.  Subjects with fewer than two recognised modalities
    are skipped.  Each sample yields one pair of modality images plus
    age/gender labels and the index of the modality combination.
    """

    def __init__(self, base_path, phase='train'):
        self.phase = phase
        self.samples = []
        self.mod_combinations = []
        self.transform = MedicalTransform(phase)
        # sorted() keeps sample order reproducible across OSes/filesystems
        # (os.listdir order is unspecified).
        for age_folder in sorted(os.listdir(base_path)):
            age_path = os.path.join(base_path, age_folder)
            if not os.path.isdir(age_path):
                continue
            for subject_folder in sorted(os.listdir(age_path)):
                subject_path = os.path.join(age_path, subject_folder)
                if not os.path.isdir(subject_path):
                    continue
                parts = subject_folder.split('_', 1)
                gender = 0 if parts[0] == 'M' else 1  # 0 = male, 1 = female
                # assumes age folder is named like '<prefix>_<age>' — TODO confirm
                age_val = int(age_folder.split('_')[1])
                mod_files = {}
                for fname in sorted(os.listdir(subject_path)):
                    mod_prefix = fname.split('_')[0].upper()
                    if mod_prefix in MODALITIES:
                        mod_files[mod_prefix] = os.path.join(subject_path, fname)
                if len(mod_files) >= 2:
                    # all 2-modality combinations available for this subject
                    valid_combs = list(itertools.combinations(sorted(mod_files.keys()), 2))
                    self.samples.append({
                        'age': age_val,
                        'gender': gender,
                        'mod_files': mod_files,
                        'valid_combs': valid_combs,  # usable modality pairs
                    })
                    self.mod_combinations.extend(valid_combs)
        # Deduplicate and SORT.  BUG FIX: the original used list(set(...)),
        # whose ordering varies between runs (string hash randomisation), so
        # the 'mod_comb' indices baked into a saved checkpoint would not match
        # on reload.
        self.mod_combinations = sorted(set(tuple(sorted(c)) for c in self.mod_combinations))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        if self.phase == 'train':
            # random pair per access acts as modality augmentation
            selected_mods = random.choice(sample['valid_combs'])
        else:
            # BUG FIX: the original only assigned selected_mods in the train
            # branch, so val/test access raised NameError.  Use a
            # deterministic pair for evaluation.
            selected_mods = sample['valid_combs'][0]
        modality_images = {}
        for mod in selected_mods:
            img = Image.open(sample['mod_files'][mod]).convert('L')  # grayscale
            img = self.transform(img, mod)
            modality_images[mod] = img
        return {
            'modality_images': modality_images,
            'available_modes': selected_mods,
            'age': sample['age'],
            'gender': sample['gender'],
            'mod_comb': self.mod_combinations.index(tuple(sorted(selected_mods)))
        }
# ================= Feature fusion modules =================
class ModalitySelector(nn.Module):
    """Encode each available modality with a single shared CNN encoder.

    Sharing one encoder across modalities keeps the parameter count constant
    regardless of how many modalities are supplied at runtime.
    """

    def __init__(self):
        super().__init__()
        # Shared encoder: grayscale input -> 128-channel 16x16 feature map.
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1),   # single-channel medical image
            nn.ReLU(),
            nn.MaxPool2d(2),                  # spatial downsampling
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((16, 16)),   # fixed output size for any input resolution
        )

    def forward(self, input_dict, available_modes):
        """Return {mode: encoded feature map} for every requested modality."""
        return {mode: self.encoder(input_dict[mode]) for mode in available_modes}
class DualAttentionFusion(nn.Module):
    """Fuse two modality feature maps via self- then cross-attention.

    Input:  two (B, C, H, W) feature maps with C == in_channels.
    Output: one (B, C // 2) fused feature vector.
    """

    def __init__(self, in_channels=128):
        super().__init__()
        # Intra-modality refinement (one module, shared by both inputs).
        self.self_attn = nn.MultiheadAttention(in_channels, num_heads=4)
        # Inter-modality interaction.
        self.cross_attn = nn.MultiheadAttention(in_channels, num_heads=4)
        # Channel compression applied after spatial pooling.
        self.compression = nn.Sequential(
            nn.Linear(in_channels, in_channels // 2),
            nn.ReLU()
        )

    def forward(self, feat1, feat2):
        batch, channels = feat1.size(0), feat1.size(1)

        def to_seq(t):
            # (B, C, H, W) -> (H*W, B, C): default MultiheadAttention layout.
            return t.view(batch, channels, -1).permute(2, 0, 1)

        seq1, seq2 = to_seq(feat1), to_seq(feat2)
        # Self-attention on each modality independently (shared weights).
        refined1, _ = self.self_attn(seq1, seq1, seq1)
        refined2, _ = self.self_attn(seq2, seq2, seq2)
        # Cross-modal interaction: modality 1 is Query, modality 2 is Key/Value.
        mixed, _ = self.cross_attn(refined1, refined2, refined2)
        # (L, B, C) -> (B, C): average over spatial positions, then compress.
        pooled = mixed.permute(1, 2, 0).mean(dim=-1)
        return self.compression(pooled)
class GenderFusion(nn.Module):
    """Modulate an image feature vector with a learned gender embedding.

    img_feat: (B, img_feat_dim) float tensor.
    gender_labels: (B,) int64 tensor with values in {0, 1}.
    Returns a re-weighted (B, img_feat_dim) feature.
    """

    def __init__(self, img_feat_dim, gender_dim=32):
        super().__init__()
        # Learned two-class (male/female) embedding.
        self.gender_emb = nn.Embedding(2, gender_dim)
        # Gated attention over the concatenated image + gender features.
        self.attn_gate = nn.Sequential(
            nn.Linear(img_feat_dim + gender_dim, 128),
            nn.ReLU(),
            nn.Linear(128, img_feat_dim),
            nn.Sigmoid()
        )
        # Channel-wise self-modulation via pointwise 1-D convolutions.
        self.feature_modulator = nn.Sequential(
            nn.Conv1d(img_feat_dim, img_feat_dim // 2, 1),
            nn.ReLU(),
            nn.Conv1d(img_feat_dim // 2, img_feat_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, img_feat, gender_labels):
        embedded = self.gender_emb(gender_labels)
        # (B, C) -> (B, C, 1) for Conv1d, squeeze back to (B, C).
        gate = self.feature_modulator(img_feat.unsqueeze(2)).squeeze(2)
        modulated = img_feat * gate
        # Gender-conditioned attention gate over the modulated feature.
        weights = self.attn_gate(torch.cat([modulated, embedded], dim=1))
        return modulated * weights
class EnhancedKneeAgePredictor(nn.Module):
    """Regress age from a pair of MRI modalities plus the subject's gender.

    Pipeline: shared per-modality encoder -> learned per-modality channel
    bias -> dual-attention fusion -> gender gating -> modality-combination
    bias -> linear head producing one scalar age per sample.
    """

    def __init__(self, num_combinations):
        super().__init__()
        # BUG FIX: the original constructed a timm resnet18 here, but `timm`
        # was never imported, `feature_extractor` was referenced without
        # `self.` (NameError), its comment claimed "no pretrained weights"
        # while passing pretrained=True, and the resnet was never used in
        # forward().  The dead, broken code is removed.
        self.mod_selector = ModalitySelector()            # shared encoder
        self.dual_attn_fusion = DualAttentionFusion(128)  # -> (B, 64)
        self.gender_fusion = GenderFusion(64, 32)         # -> (B, 64)
        # Regression head (scalar age).
        self.classifier = nn.Sequential(
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1)
        )
        # Learned bias per modality-pair index.
        self.mod_emb = nn.Embedding(num_combinations, 64)
        # Learned bias per individual modality, added on 128-channel maps.
        self.mod_mapper = nn.Embedding(len(MODALITIES), 128)

    def forward(self, input_dict, gender, mod_comb, available_modes=None):
        # BUG FIX: callers invoked model(inputs, genders, mod_combs) without
        # available_modes (TypeError).  Default to every modality present in
        # input_dict; passing it explicitly still works.
        if available_modes is None:
            available_modes = list(input_dict.keys())
        features = self.mod_selector(input_dict, available_modes)
        mod_keys = list(features.keys())
        feat1 = features[mod_keys[0]]
        feat2 = features[mod_keys[1]]
        # Add a learned per-modality channel bias, broadcast over space.
        mod1_bias = self.mod_mapper(torch.tensor(MODALITIES.index(mod_keys[0]), device=feat1.device))
        mod2_bias = self.mod_mapper(torch.tensor(MODALITIES.index(mod_keys[1]), device=feat1.device))
        feat1 = feat1 + mod1_bias.view(1, -1, 1, 1)
        feat2 = feat2 + mod2_bias.view(1, -1, 1, 1)
        # Dual-modality attention fusion, then gender gating.
        fused = self.dual_attn_fusion(feat1, feat2)
        gender_fused = self.gender_fusion(fused, gender)
        # Bias specific to the modality pair used in this batch.
        final_feat = gender_fused + self.mod_emb(mod_comb)
        return self.classifier(final_feat).squeeze(1)  # (B,) predicted ages
# ================= Training setup =================
# BUG FIX: random_split was used below but never imported anywhere.
from torch.utils.data import random_split

# Dataset initialisation.
train_val_dataset = KneeMRIDataset(TRAIN_ROOT, phase='train')
num_combinations = len(train_val_dataset.mod_combinations)

# 80/20 train/val split with a fixed seed for reproducibility.
# NOTE(review): both subsets share the same underlying dataset object, so the
# validation subset still goes through the *training* augmentations — consider
# a separate phase='test' dataset for validation.
train_size = int(0.8 * len(train_val_dataset))
train_set, val_set = random_split(
    train_val_dataset,
    [train_size, len(train_val_dataset) - train_size],
    generator=torch.Generator().manual_seed(42)
)
test_dataset = KneeMRIDataset(TEST_ROOT, phase='test')

# DataLoader configuration.
# NOTE(review): num_workers > 0 at module top level needs an
# `if __name__ == "__main__":` guard on spawn-based platforms (Windows/macOS).
train_loader = DataLoader(train_set, batch_size=64, shuffle=True,
                          persistent_workers=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_set, batch_size=64, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=64, num_workers=2, pin_memory=True)

# Device, model, optimiser, loss.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = EnhancedKneeAgePredictor(num_combinations).to(device)
optimizer = optim.AdamW(model.parameters(), lr=0.0001, weight_decay=0.05)
criterion = nn.HuberLoss()  # robust regression loss

# Training log.
training_log = {
    'train_loss': [], 'val_loss': [], 'test_loss': [],
    'train_acc': [], 'val_acc': [], 'test_acc': [],
    'best_acc': 0.0
}
# ================= Main training loop =================
TOTAL_EPOCHS = 60  # BUG FIX: used in the progress print below but never defined

for epoch in range(TOTAL_EPOCHS):
    # BUG FIX: removed train_set.dataset.set_epoch(epoch) — KneeMRIDataset
    # defines no such method, so the original crashed on the first epoch.

    # ---- training ----
    model.train()
    epoch_loss, correct_preds = 0.0, 0
    for batch in train_loader:
        # 'modality_images' is a dict {modality: tensor}; a dict has no
        # .to(device), so each tensor is moved individually.
        # NOTE(review): default_collate only succeeds if every sample in the
        # batch drew the SAME modality pair; otherwise a custom collate_fn
        # that groups samples by modality combination is required.
        inputs = {m: t.to(device) for m, t in batch['modality_images'].items()}
        labels = batch['age'].float().to(device)  # regression target (years)
        genders = batch['gender'].to(device)
        mod_combs = batch['mod_comb'].to(device)

        optimizer.zero_grad()
        # BUG FIX: forward() takes available_modes as its 4th argument; the
        # original call omitted it (TypeError).
        outputs = model(inputs, genders, mod_combs, list(inputs.keys()))
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Regression "accuracy": prediction within ±1 year counts as correct.
        correct_preds += (torch.abs(outputs - labels) < 1.0).sum().item()
        epoch_loss += loss.item() * len(labels)

    # ---- validation ----
    # BUG FIX: removed the torch.manual_seed(42) the original ran here —
    # reseeding the global RNG every epoch made all subsequent training
    # augmentation/shuffling repeat identically.
    model.eval()
    val_loss, val_correct = 0.0, 0
    with torch.no_grad():
        for batch in val_loader:
            # BUG FIX: the dataset key is 'modality_images', not 'images'.
            inputs = {m: t.to(device) for m, t in batch['modality_images'].items()}
            labels = batch['age'].float().to(device)
            genders = batch['gender'].to(device)
            mod_combs = batch['mod_comb'].to(device)
            outputs = model(inputs, genders, mod_combs, list(inputs.keys()))
            val_loss += criterion(outputs, labels).item() * len(labels)
            val_correct += (torch.abs(outputs - labels) < 1.0).sum().item()

    # ---- test ----
    # NOTE(review): running the test set every epoch is acceptable for
    # logging, but model selection must rely on val_acc only (it does, below).
    test_loss, test_correct = 0.0, 0
    with torch.no_grad():
        for batch in test_loader:
            inputs = {m: t.to(device) for m, t in batch['modality_images'].items()}
            labels = batch['age'].float().to(device)
            genders = batch['gender'].to(device)
            mod_combs = batch['mod_comb'].to(device)
            outputs = model(inputs, genders, mod_combs, list(inputs.keys()))
            test_loss += criterion(outputs, labels).item() * len(labels)
            # BUG FIX: the original accumulated into val_correct here, so
            # test accuracy was always reported as 0.
            test_correct += (torch.abs(outputs - labels) < 1.0).sum().item()

    # ---- record metrics ----
    train_loss = epoch_loss / len(train_set)
    train_acc = correct_preds / len(train_set)
    val_loss = val_loss / len(val_set)
    val_acc = val_correct / len(val_set)
    test_loss = test_loss / len(test_dataset)
    test_acc = test_correct / len(test_dataset)
    training_log['train_loss'].append(train_loss)
    training_log['train_acc'].append(train_acc)
    training_log['val_loss'].append(val_loss)
    training_log['val_acc'].append(val_acc)
    training_log['test_loss'].append(test_loss)
    training_log['test_acc'].append(test_acc)

    # ---- checkpoint on best validation accuracy ----
    if val_acc > training_log['best_acc']:
        training_log['best_acc'] = val_acc
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': train_loss,
            'accuracy': val_acc
        }, 'best_age_predictor.pth')

    # ---- progress ----
    print(f'Epoch [{epoch+1:02d}/{TOTAL_EPOCHS}]')
    print(f'Train Loss: {train_loss:.4f} | Acc: {train_acc:.4f}')
    print(f'Val Loss: {val_loss:.4f} | Acc: {val_acc:.4f}')
    print(f'Test Loss: {test_loss:.4f} | Acc: {test_acc:.4f}\n')
# ================= Visualise the training curves =================
plt.figure(figsize=(15, 6))

# Left panel: loss curves for all three splits.
plt.subplot(1, 2, 1)
plt.plot(training_log['train_loss'], label='Train', linestyle='--')
plt.plot(training_log['val_loss'], label='Val', linestyle='-.')
plt.plot(training_log['test_loss'], label='Test', linestyle='-')
plt.title('Loss Trajectory')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

# Right panel: ±1-year accuracy curves.
plt.subplot(1, 2, 2)
plt.plot(training_log['train_acc'], label='Train', linestyle='--')
plt.plot(training_log['val_acc'], label='Val', linestyle='-.')
plt.plot(training_log['test_acc'], label='Test', linestyle='-')
plt.title('Accuracy Progress')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()

plt.tight_layout()
plt.savefig('training_metrics.png')
plt.show()
# (stray scrape artifact "最新发布" / "latest release" — commented out so the file parses)