# train.py

import os
import argparse
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_error
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from PIL import Image
import pickle

# Select the compute device: CUDA GPU if available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


# 自定义数据集
class BoneAgeDataset(Dataset):
    """Dataset that precomputes one aggregated CNN feature vector per subject id.

    For each id in ``df`` (expects columns ``id`` and ``boneage``), every image
    named ``<id>_*.png`` inside ``image_folder`` is run through
    ``feature_extractor`` once, and the per-image feature vectors are
    mean-pooled into a single sample. Items are ``(features, boneage)`` as
    float32 tensors. Ids with no readable image are silently dropped.
    """

    def __init__(self, df, image_folder, transform, feature_extractor):
        self.df = df
        self.image_folder = image_folder
        self.transform = transform
        self.feature_extractor = feature_extractor
        self.ids = df['id'].unique()
        self.features = []  # one aggregated feature vector per retained id
        self.labels = []    # bone age matching each entry of self.features
        self.prepare_features()

    def prepare_features(self):
        """Extract and mean-pool image features for every id (runs once)."""
        print("提取图像特征...")
        # List the directory once instead of once per id — the original code
        # re-scanned the folder for every id (O(ids * files)).
        all_files = os.listdir(self.image_folder)
        for id_ in tqdm(self.ids):
            boneage = self.df[self.df['id'] == id_]['boneage'].values[0]
            prefix = str(id_) + '_'
            image_files = [os.path.join(self.image_folder, f) for f in all_files
                           if f.startswith(prefix) and f.endswith('.png')]
            feats = []
            for img_path in image_files:
                try:
                    img = Image.open(img_path).convert('RGB')
                    img = self.transform(img).unsqueeze(0).to(device)
                    with torch.no_grad():
                        feat = self.feature_extractor(img).squeeze().cpu().numpy()
                    feats.append(feat)
                except Exception as e:
                    # Best-effort: skip unreadable/corrupt images, keep going.
                    print(f"无法处理图像 {img_path}: {e}")
            if len(feats) == 0:
                # No usable image for this id: drop the sample entirely.
                continue
            self.features.append(np.mean(feats, axis=0))
            self.labels.append(boneage)

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return (torch.tensor(self.features[idx], dtype=torch.float32),
                torch.tensor(self.labels[idx], dtype=torch.float32))


# 定义增强的回归模型
class EnhancedRegressionModel(nn.Module):
    """MLP regression head mapping a feature vector to a scalar bone age.

    Architecture: input_dim -> 1024 -> 512 -> 128 -> 1, with BatchNorm1d,
    ReLU and Dropout after each hidden linear layer.
    """

    def __init__(self, input_dim):
        super(EnhancedRegressionModel, self).__init__()
        self.regressor = nn.Sequential(
            nn.Linear(input_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 1)
        )

    def forward(self, x):
        # Squeeze only the trailing feature dim: (N, 1) -> (N,).
        # The original bare .squeeze() collapsed a batch of size 1 to a 0-d
        # tensor, which breaks shape alignment with (1,)-shaped targets in
        # MSELoss and in the validation MAE computation.
        return self.regressor(x).squeeze(-1)


def main():
    """Train the bone-age regression head on precomputed ResNet50 features.

    Pipeline: parse CLI args -> read the id/boneage CSV -> extract one
    mean-pooled ResNet50 feature vector per id -> train an MLP regressor with
    MSE loss, LR scheduling and early stopping, saving the best model (by
    validation loss) to ``--model_path``.
    """
    parser = argparse.ArgumentParser(description="Train Enhanced Bone Age Regression Model")
    parser.add_argument('--csv_path', type=str, required=True, help='Path to boneage CSV file')
    parser.add_argument('--image_folder', type=str, required=True, help='Path to image folder')
    parser.add_argument('--model_path', type=str, default='boneage_model.pth', help='Path to save the trained model')
    parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
    parser.add_argument('--learning_rate', type=float, default=1e-4, help='Initial learning rate')
    parser.add_argument('--weight_decay', type=float, default=1e-5, help='Weight decay for optimizer')
    args = parser.parse_args()

    # Load the label table; only the id and boneage columns are used.
    df = pd.read_csv(args.csv_path)
    df = df[['id', 'boneage']]

    # Image preprocessing. NOTE(review): the random augmentations below are
    # applied exactly once, at feature-extraction time, so they inject one-shot
    # noise rather than providing true training-time augmentation — confirm
    # this is intended before relying on it.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # ImageNet-pretrained ResNet50 used purely as a frozen feature extractor.
    # (The original code "unfroze" the backbone and claimed fine-tuning, but
    # features are extracted once under no_grad and the optimizer below only
    # covers the regression head, so no backbone fine-tuning ever happened.)
    resnet = models.resnet50(pretrained=True)
    feature_extractor = nn.Sequential(*list(resnet.children())[:-1])  # drop the final FC layer
    feature_extractor.to(device)
    # BUG FIX: extraction must run in eval mode so BatchNorm layers use their
    # pretrained running statistics instead of per-image batch statistics
    # (train mode also silently mutated the running stats while "extracting").
    feature_extractor.eval()

    # Precompute features and labels for every id.
    dataset = BoneAgeDataset(df, args.image_folder, transform, feature_extractor)
    X = np.array(dataset.features)
    y = np.array(dataset.labels)

    print(f"总样本数: {len(dataset)}")

    # Fixed-seed 80/20 train/validation split.
    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)
    print(f"训练集样本数: {len(X_train)}, 验证集样本数: {len(X_val)}")

    X_train = torch.tensor(X_train, dtype=torch.float32)
    y_train = torch.tensor(y_train, dtype=torch.float32)
    X_val = torch.tensor(X_val, dtype=torch.float32)
    y_val = torch.tensor(y_val, dtype=torch.float32)

    train_dataset = torch.utils.data.TensorDataset(X_train, y_train)
    val_dataset = torch.utils.data.TensorDataset(X_val, y_val)

    # drop_last=True on the training loader: a trailing batch of size 1 would
    # crash nn.BatchNorm1d in train mode ("Expected more than 1 value per
    # channel"). Validation runs in eval mode, so it keeps every sample.
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)

    # Model, loss and optimizer. Only the regression head is trained; the CNN
    # features are fixed inputs at this point.
    input_dim = X_train.shape[1]
    model = EnhancedRegressionModel(input_dim).to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)

    # Halve-by-10 LR schedule driven by validation loss plateaus.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True)

    # Training loop with best-checkpoint saving and early stopping.
    best_val_loss = float('inf')
    patience = 15  # early-stopping patience, in epochs without improvement
    trigger_times = 0

    for epoch in range(args.epochs):
        model.train()
        running_loss = 0.0
        for inputs, targets in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{args.epochs} - Training"):
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # Accumulate sample-weighted loss so the epoch mean is exact even
            # with uneven batch sizes.
            running_loss += loss.item() * inputs.size(0)

        epoch_loss = running_loss / len(train_loader.dataset)

        # Validation pass (no gradients, eval mode for Dropout/BatchNorm).
        model.eval()
        val_running_loss = 0.0
        all_preds = []
        all_targets = []
        with torch.no_grad():
            for inputs, targets in tqdm(val_loader, desc=f"Epoch {epoch + 1}/{args.epochs} - Validation"):
                inputs = inputs.to(device)
                targets = targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_running_loss += loss.item() * inputs.size(0)

                all_preds.extend(outputs.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())

        val_loss = val_running_loss / len(val_loader.dataset)
        val_mae = mean_absolute_error(all_targets, all_preds)
        print(
            f"Epoch {epoch + 1}/{args.epochs} - Train Loss: {epoch_loss:.4f} - Val Loss: {val_loss:.4f} - Val MAE: {val_mae:.4f}")

        # Let the scheduler react to the validation loss.
        scheduler.step(val_loss)

        # Checkpoint on improvement; otherwise count toward early stopping.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            trigger_times = 0
            torch.save(model.state_dict(), args.model_path)
            print(f"最佳模型已保存 (Val Loss: {best_val_loss:.4f})")
        else:
            trigger_times += 1
            print(f"Val Loss 未改善 ({trigger_times}/{patience})")
            if trigger_times >= patience:
                print("满足早停条件,停止训练。")
                break

    print(f"训练完成,最佳验证损失: {best_val_loss:.4f}")
    print(f"模型已保存到 {args.model_path}")

    # Optional: persist the split tensors so features need not be re-extracted.
    with open('features_labels_enhanced.pkl', 'wb') as f:
        pickle.dump({'X_train': X_train, 'y_train': y_train, 'X_val': X_val, 'y_val': y_val}, f)


# Script entry point: run training only when executed directly.
if __name__ == "__main__":
    main()