# train.py
import os
import argparse
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_error
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from PIL import Image
import pickle

# Select the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


# Custom dataset: one averaged ResNet feature vector per patient id
class BoneAgeDataset(Dataset):
    def __init__(self, df, image_folder, transform, feature_extractor):
        self.df = df
        self.image_folder = image_folder
        self.transform = transform
        self.feature_extractor = feature_extractor
        self.ids = df['id'].unique()
        self.features = []
        self.labels = []
        self.prepare_features()

    def prepare_features(self):
        print("Extracting image features...")
        for id_ in tqdm(self.ids):
            boneage = self.df[self.df['id'] == id_]['boneage'].values[0]
            image_files = [os.path.join(self.image_folder, f)
                           for f in os.listdir(self.image_folder)
                           if f.startswith(str(id_) + '_') and f.endswith('.png')]
            feats = []
            for img_path in image_files:
                try:
                    img = Image.open(img_path).convert('RGB')
                    img = self.transform(img).unsqueeze(0).to(device)
                    with torch.no_grad():
                        feat = self.feature_extractor(img).squeeze().cpu().numpy()
                    feats.append(feat)
                except Exception as e:
                    print(f"Could not process image {img_path}: {e}")
            if len(feats) == 0:
                continue
            # Average the features of all images belonging to the same id
            aggregated_feat = np.mean(feats, axis=0)
            self.features.append(aggregated_feat)
            self.labels.append(boneage)

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return (torch.tensor(self.features[idx], dtype=torch.float32),
                torch.tensor(self.labels[idx], dtype=torch.float32))


# Enhanced regression model (MLP head on top of the extracted features)
class EnhancedRegressionModel(nn.Module):
    def __init__(self, input_dim):
        super(EnhancedRegressionModel, self).__init__()
        self.regressor = nn.Sequential(
            nn.Linear(input_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 1)
        )

    def forward(self, x):
        return self.regressor(x).squeeze()


def main():
    parser = argparse.ArgumentParser(description="Train Enhanced Bone Age Regression Model")
    parser.add_argument('--csv_path', type=str, required=True, help='Path to boneage CSV file')
    parser.add_argument('--image_folder', type=str, required=True, help='Path to image folder')
    parser.add_argument('--model_path', type=str, default='boneage_model.pth', help='Path to save the trained model')
    parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs')  # default raised to 100
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
    parser.add_argument('--learning_rate', type=float, default=1e-4, help='Initial learning rate')
    parser.add_argument('--weight_decay', type=float, default=1e-5, help='Weight decay for optimizer')
    args = parser.parse_args()

    # Read the CSV file
    df = pd.read_csv(args.csv_path)
    df = df[['id', 'boneage']]

    # Image preprocessing with stronger data augmentation
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Load a pretrained ResNet50
    resnet = models.resnet50(pretrained=True)
    # 1. Unfreeze all ResNet50 layers. Note: as written this has no effect on training,
    #    because the features are extracted once under torch.no_grad() below and the
    #    backbone parameters are never passed to the optimizer.
    for param in resnet.parameters():
        param.requires_grad = True
    feature_extractor = nn.Sequential(*list(resnet.children())[:-1])  # drop the final FC layer
    feature_extractor.to(device)
    feature_extractor.train()  # training mode, intended for fine-tuning

    # Create the dataset (this precomputes one feature vector per id)
    dataset = BoneAgeDataset(df, args.image_folder, transform, feature_extractor)
    X = np.array(dataset.features)
    y = np.array(dataset.labels)
    print(f"Total samples: {len(dataset)}")

    # Split into training and validation sets
    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)
    print(f"Training samples: {len(X_train)}, validation samples: {len(X_val)}")

    # Convert to tensors
    X_train = torch.tensor(X_train, dtype=torch.float32)
    y_train = torch.tensor(y_train, dtype=torch.float32)
    X_val = torch.tensor(X_val, dtype=torch.float32)
    y_val = torch.tensor(y_val, dtype=torch.float32)

    train_dataset = torch.utils.data.TensorDataset(X_train, y_train)
    val_dataset = torch.utils.data.TensorDataset(X_val, y_val)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)

    # Initialize the model, loss function, and optimizer
    input_dim = X_train.shape[1]
    model = EnhancedRegressionModel(input_dim).to(device)
    criterion = nn.MSELoss()
    # 2. Optimize only the regression head; the ResNet features are precomputed and fixed
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
    # Use ReduceLROnPlateau; a StepLR schedule could also be combined with it
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True)

    # Training loop
    best_val_loss = float('inf')
    patience = 15  # early-stopping patience raised to 15
    trigger_times = 0
    for epoch in range(args.epochs):
        model.train()
        feature_extractor.train()
        running_loss = 0.0
        for inputs, targets in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{args.epochs} - Training"):
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            # To actually fine-tune the ResNet, raw images would have to be forwarded through it here
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * inputs.size(0)
        epoch_loss = running_loss / len(train_loader.dataset)

        # Validation
        model.eval()
        feature_extractor.eval()
        val_running_loss = 0.0
        all_preds = []
        all_targets = []
        with torch.no_grad():
            for inputs, targets in tqdm(val_loader, desc=f"Epoch {epoch + 1}/{args.epochs} - Validation"):
                inputs = inputs.to(device)
                targets = targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_running_loss += loss.item() * inputs.size(0)
                all_preds.extend(outputs.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())
        val_loss = val_running_loss / len(val_loader.dataset)
        val_mae = mean_absolute_error(all_targets, all_preds)
        print(f"Epoch {epoch + 1}/{args.epochs} - Train Loss: {epoch_loss:.4f} - "
              f"Val Loss: {val_loss:.4f} - Val MAE: {val_mae:.4f}")

        # Adjust the learning rate
        scheduler.step(val_loss)

        # Early-stopping check
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            trigger_times = 0
            torch.save(model.state_dict(), args.model_path)
            print(f"Best model saved (Val Loss: {best_val_loss:.4f})")
        else:
            trigger_times += 1
            print(f"Val Loss did not improve ({trigger_times}/{patience})")
            if trigger_times >= patience:
                print("Early-stopping condition met, stopping training.")
                break

    print(f"Training finished, best validation loss: {best_val_loss:.4f}")
    print(f"Model saved to {args.model_path}")

    # Optional: save the features and labels for later use
    with open('features_labels_enhanced.pkl', 'wb') as f:
        pickle.dump({'X_train': X_train, 'y_train': y_train, 'X_val': X_val, 'y_val': y_val}, f)


if __name__ == "__main__":
    main()
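The script is run from the command line with something like `python train.py --csv_path boneage.csv --image_folder images/` (paths here are placeholders). Below is a minimal sketch, not part of the original script, of how the saved regression head might be reused for single-image inference. The image filename is hypothetical, the 2048-dimensional input matches ResNet50's pooled feature size, and since training averaged features over several images per id, a single-image prediction is only an approximation.

```python
# inference_sketch.py -- illustrative only; the image path is an assumption
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image

from train import EnhancedRegressionModel  # regression head defined in train.py

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Rebuild the ResNet50 feature extractor the same way train.py does, in eval mode
resnet = models.resnet50(pretrained=True)
feature_extractor = nn.Sequential(*list(resnet.children())[:-1]).to(device).eval()

# Load the trained regression head (2048 = size of ResNet50's pooled features)
model = EnhancedRegressionModel(input_dim=2048).to(device)
model.load_state_dict(torch.load("boneage_model.pth", map_location=device))
model.eval()

# Deterministic preprocessing -- no augmentation at inference time
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

img = preprocess(Image.open("example_hand_xray.png").convert("RGB")).unsqueeze(0).to(device)
with torch.no_grad():
    feat = feature_extractor(img).flatten(1)  # shape (1, 2048)
    pred = model(feat).item()                 # predicted bone age
print(f"Predicted bone age: {pred:.1f}")
```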