# train.py
import os
import argparse
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_error
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from PIL import Image
import pickle

# Select the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


# Custom dataset
class BoneAgeDataset(Dataset):
    def __init__(self, df, image_folder, transform):
        self.df = df
        self.image_folder = image_folder
        self.transform = transform
        self.ids = df['id'].unique()
        self.data = []
        self.prepare_dataset()

    def prepare_dataset(self):
        print("Preparing dataset...")
        for id_ in tqdm(self.ids):
            boneage = self.df[self.df['id'] == id_]['boneage'].values[0]
            # Collect every PNG whose filename starts with "<id>_"
            image_files = [f for f in os.listdir(self.image_folder)
                           if f.startswith(str(id_) + '_') and f.endswith('.png')]
            for img_file in image_files:
                img_path = os.path.join(self.image_folder, img_file)
                self.data.append((img_path, boneage))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path, boneage = self.data[idx]
        try:
            img = Image.open(img_path).convert('RGB')
        except Exception as e:
            print(f"Failed to load image {img_path}: {e}")
            # If the image cannot be opened, fall back to an all-black image
            img = Image.new('RGB', (224, 224), (0, 0, 0))
        img = self.transform(img)
        return img, torch.tensor(boneage, dtype=torch.float32)


# End-to-end regression model
class EndToEndRegressionModel(nn.Module):
    def __init__(self, pretrained=True):
        super(EndToEndRegressionModel, self).__init__()
        self.backbone = models.resnet50(pretrained=pretrained)
        # Replace the final fully connected layer with a regression head
        num_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Sequential(
            nn.Linear(num_features, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 1)
        )

    def forward(self, x):
        # squeeze(-1) removes only the trailing dimension, so a batch of size 1
        # still yields a 1-D tensor instead of a scalar
        return self.backbone(x).squeeze(-1)


def main():
    parser = argparse.ArgumentParser(description="Train Enhanced Bone Age Regression Model End-to-End")
    parser.add_argument('--csv_path', type=str, required=True, help='Path to boneage CSV file')
    parser.add_argument('--image_folder', type=str, required=True, help='Path to image folder')
    parser.add_argument('--model_path', type=str, default='boneage_model.pth', help='Path to save the trained model')
    parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
    parser.add_argument('--learning_rate', type=float, default=1e-4, help='Initial learning rate')
    parser.add_argument('--weight_decay', type=float, default=1e-5, help='Weight decay for optimizer')
    args = parser.parse_args()

    # Read the CSV file and keep only the columns we need
    df = pd.read_csv(args.csv_path)
    df = df[['id', 'boneage']]

    # Image preprocessing with heavier data augmentation.
    # Note: the same transform (including the random augmentations) is applied
    # to the validation subset below, since both subsets share this dataset.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # Build the dataset
    dataset = BoneAgeDataset(df, args.image_folder, transform)

    # Split into training and validation sets by index
    train_indices, val_indices = train_test_split(
        np.arange(len(dataset)), test_size=0.2, random_state=42, shuffle=True)
    train_subset = torch.utils.data.Subset(dataset, train_indices)
    val_subset = torch.utils.data.Subset(dataset, val_indices)

    # Data loaders; drop_last avoids a size-1 final batch, which would break
    # the BatchNorm1d layers in the regression head during training
    train_loader = DataLoader(train_subset, batch_size=args.batch_size, shuffle=True,
                              num_workers=4, drop_last=True)
    val_loader = DataLoader(val_subset, batch_size=args.batch_size, shuffle=False,
                            num_workers=4)

    print(f"Total samples: {len(dataset)}")
    print(f"Training samples: {len(train_subset)}, validation samples: {len(val_subset)}")

    # Model, loss function and optimizer
    model = EndToEndRegressionModel(pretrained=True).to(device)
    criterion = nn.L1Loss()  # use MAE as the loss function

    # Unfreeze all ResNet50 layers
    for param in model.parameters():
        param.requires_grad = True

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate,
                                 weight_decay=args.weight_decay)

    # ReduceLROnPlateau driven by the validation MAE
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.1, patience=5, verbose=True)

    # Training loop
    best_val_mae = float('inf')
    patience = 15  # early-stopping patience
    trigger_times = 0

    for epoch in range(args.epochs):
        # Training phase
        model.train()
        running_loss = 0.0
        n_train_seen = 0
        for inputs, targets in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{args.epochs} - Training"):
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            n_train_seen += inputs.size(0)

        # Normalize by the samples actually seen (drop_last may skip a few)
        epoch_loss = running_loss / n_train_seen

        # Validation phase
        model.eval()
        val_running_mae = 0.0
        all_preds = []
        all_targets = []
        with torch.no_grad():
            for inputs, targets in tqdm(val_loader, desc=f"Epoch {epoch + 1}/{args.epochs} - Validation"):
                inputs = inputs.to(device)
                targets = targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_running_mae += loss.item() * inputs.size(0)
                all_preds.extend(outputs.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())

        val_mae = val_running_mae / len(val_loader.dataset)
        # Alternatively, compute it with sklearn:
        # val_mae = mean_absolute_error(all_targets, all_preds)

        print(f"Epoch {epoch + 1}/{args.epochs} - Train MAE: {epoch_loss:.4f} - Val MAE: {val_mae:.4f}")

        # Adjust the learning rate based on the validation MAE
        scheduler.step(val_mae)

        # Early stopping based on the validation MAE
        if val_mae < best_val_mae:
            best_val_mae = val_mae
            trigger_times = 0
            torch.save(model.state_dict(), args.model_path)
            print(f"Best model saved (Val MAE: {best_val_mae:.4f})")
        else:
            trigger_times += 1
            print(f"Val MAE did not improve ({trigger_times}/{patience})")
            if trigger_times >= patience:
                print("Early stopping condition met, stopping training.")
                break

    print(f"Training finished, best validation MAE: {best_val_mae:.4f}")
    print(f"Model saved to {args.model_path}")

    # Optional: save training history or other information
    with open('training_history.pkl', 'wb') as f:
        pickle.dump({
            'best_val_mae': best_val_mae,
            'model_path': args.model_path
        }, f)


if __name__ == "__main__":
    main()
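The script only saves the best state_dict to --model_path, so using the trained model afterwards takes a separate loading step. Below is a minimal, hypothetical sketch of single-image inference: the file name predict.py, the checkpoint path boneage_model.pth, and the sample image path are placeholders, and it assumes train.py sits in the same directory so EndToEndRegressionModel and device can be imported from it.

# predict.py (sketch): load the saved weights and predict bone age for one image.
import torch
from PIL import Image
from torchvision import transforms

from train import EndToEndRegressionModel, device  # assumes train.py is importable

# Rebuild the architecture without downloading ImageNet weights, then load the checkpoint
model = EndToEndRegressionModel(pretrained=False).to(device)
model.load_state_dict(torch.load('boneage_model.pth', map_location=device))
model.eval()

# Same preprocessing as training, minus the random augmentations
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

img = Image.open('example_hand_xray.png').convert('RGB')  # placeholder path
with torch.no_grad():
    pred = model(preprocess(img).unsqueeze(0).to(device))
print(f"Predicted bone age: {pred.item():.1f}")

In eval mode the BatchNorm1d layers use their running statistics, so a batch of a single image is fine here even though the same model would fail with batch size 1 during training.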