脑瘤的及时准确诊断对于治疗决策与预后判断至关重要。磁共振成像(MRI
)因其高分辨率的结构特征,在临床诊疗中广泛应用。
CNN因其局部感受野和平移不变性的特性,在结构化的医学影像数据上表现稳定而高效。而Transformer
模型最初应用于自然语言处理领域,近年来逐渐被扩展至计算机视觉任务,以其捕捉全局特征的能力引发研究热潮。本文就用Transformer
来实现脑肿瘤的识别。
- 对于Transformer架构,可以参考我之前写的Transformer架构详解文章来了解
1 准备工作
在我们实现Transformer
模型分类之前,我们需要做一些准备工作。首先我们要了解一下数据集,另外,为了更好地对比Transformer
模型,我已经在这个数据集上实现了基于CNN的RGB/Gray图像的分类模型作为一个参照物。
1.1 数据集与数据预处理
本项目所使用的数据集可通过Kaggle的脑瘤MRI数据集公开获得,该数据集共包含7023张脑瘤MRI图像,分为以下四类:glioma(神经胶质瘤)、meningioma(脑膜瘤)、pituitary(垂体瘤)和notumor(无瘤)
此外,数据集被划分为训练集和测试集,分别包含5712张图像和1311张图像。训练集中各类别图像数量分别为:glioma
(1321 张)、meningioma
(1339 张)、pituitary
(1595 张)和notumor
(1457 张),这使得类别分布相对均衡。
- 这里验证集和测试集一起使用,是为了增加训练和验证的数据量。
数据预处理
我们创建了两类数据集,分别对应不同的预处理方法:
- baseline RGB:用于基于
VGG16
、ResNet18
和AlexNet
迁移学习的CNN模型 - custom RGB/grayscale:用于
Transformer
模型Transformer
这里用灰度图像训练,是因为在医学影像任务中,颜色信息本身并不关键,空间结构和病灶的纹理模式才是模型判断的核心依据。同时还可以减少计算负担。
每种数据加载器的预处理细节如下图所示:
Model | 阶段 | 预处理步骤 |
---|---|---|
Baseline (RGB) | 训练 | • 调整尺寸为 224×224 像素 • 随机水平翻转(50% 概率) • 旋转(±15°) • 颜色抖动(亮度/对比度) • 平移(±10%) • 使用均值 [0.485, 0.456, 0.406] 和标准差 [0.229, 0.224, 0.225] 进行标准化 |
验证 | • 调整尺寸为 224×224 像素 • 使用均值 [0.485, 0.456, 0.406] 和标准差 [0.229, 0.224, 0.225] 进行标准化 | |
Custom (RGB) | 训练 | • 调整尺寸为 224×224 像素 • 随机水平翻转(50% 概率) • 旋转(±15°) • 使用均值 [0.485, 0.456, 0.406] 和标准差 [0.229, 0.224, 0.225] 进行标准化 |
验证 | • 同 Baseline 验证阶段的预处理方式 | |
Custom (Grayscale) | 训练 | • 转换为灰度图 • 调整尺寸为 224×224 像素 • 随机水平翻转(50% 概率) • 旋转(±15°) • 使用均值 [0.5] 和标准差 [0.5] 进行标准化 |
验证 | • 转换为灰度图 • 调整尺寸为 224×224 像素 • 使用均值 [0.5] 和标准差 [0.5] 进行标准化 |
-
所有图像均被调整为 224×224 像素,这是许多预训练模型常用的输入图像尺寸。
-
在验证阶段,预处理步骤与训练阶段保持一致,但不包括增强操作,仅保留图像缩放和标准化,以确保评估过程的一致性与稳定性。
-
Transformer模型未采用色彩扰动与图像平移的增强方法。这是由于模型架构的特点所致:在数据量较小的情况下,过多的数据增强可能导致Transformer模型记住增强带来的噪声或特定模式,反而不利于泛化。
1.2 模型参数和评估指标
1.2.1 模型参数
现在我们来看一下模型使用的一些参数:
Params | Baseline | Transformer (RGB) | Transformer (Gray) |
---|---|---|---|
Normalization | Mean: [0.485, 0.456, 0.406] Std: [0.229, 0.224, 0.225] | Mean: [0.5] Std: [0.5] | |
Batch Size | 128 | 256 | |
Optimizer | Adam (LR: 1e-4, Weight Decay: 1e-5) | ||
Loss Function | Cross-Entropy Loss | ||
Epochs | 100 | 160 | |
Architecture | VGG16, ResNet18, AlexNet | Hybrid Transformer |
- Adam 优化器优势:收敛快、适应性强,适合处理复杂特征和稀疏梯度。
- 过拟合控制策略:采用 L2 正则化,有效提升 Transformer 模型的泛化能力。
- 训练设置:设置 160 个训练轮次,并通过验证集 F1 分数保存最优模型。
- 性能评估指标:使用 F1 分数、精准率、召回率与混淆矩阵综合比较模型表现。
1.2.2 评估指标
以下指标在每个训练轮次中,针对训练集与验证集均有计算,以确保评估的全面性。这些指标可用于跟踪模型性能、收敛性以及泛化能力。通过对结果的分析,我们可以深入了解模型的优势与改进空间。
- F1 分数(F1-Score):F1 分数是精准率与召回率的调和平均数,能平衡评估模型对各类的识别能力。对于类别分布不均衡的情况尤为有效,因为它同时考虑了假阳性与假阴性的影响。
- 精准率(Precision):精准率表示某一类别中,模型预测为正的样本中有多少是真正的正例。它反映了模型减少误报(假阳性)的能力。高精准率意味着模型预测结果更具相关性,错误率更低。
- 召回率 / 灵敏度(Recall / Sensitivity):召回率衡量模型成功识别出真实正类的能力,即所有真实正例中被正确预测为正的比例。高召回率确保模型尽可能覆盖所有相关结果。
- 特异度(Specificity):特异度指模型在所有真实负类中,成功预测为负的比例。它反映模型避免误报的能力,在假阳性代价高或后果严重的应用场景中特别重要。高特异度可增强模型在类别不平衡情境下的判断能力。
- 混淆矩阵(Confusion Matrix):混淆矩阵详细列出了每一类别的真阳性、真阴性、假阳性与假阴性数量。它为误差分析提供基础,有助于了解模型在哪些类别上表现较弱,进而指导后续改进。
- 推理时间(Inference Time):指模型进行一次预测所需的平均时间,是评估模型效率的重要指标,特别是在对比不同架构时。推理时间越短,模型的实际部署和应用性能越高。
1.3 Baseline模型训练结果
我们的Baseline模型使用VGG16
、ResNet18
和AlexNet
三个模型做迁移学习。所有的训练结果放上来篇幅有点大了,而且这篇文章主要以Transformer为重点。所以这里以VGG16的训练结果为例,看一下训练损失、F1 Score和混淆矩阵:
可以看到VGG16模型在第71轮达成训练集F1分数1.0,验证集最佳 F1为0.9985,且混淆矩阵显示几乎无分类错误,表现出极高的准确率与稳定性。下面是三个模型的四个指标,VGG16表现最好,但Alexnet
和Resnet18
的各个指标也都超过了99%:
从训练结果中我们发现三个模型的准确率都非常高,我们的心已经凉一半了,那我们用Transformer
模型的意义在哪?大概率Transformer
的准确率不会超过VGG,因为它非常接近于完美了。这些问题我们等到实现完Transformer
模型后再来讨论。
2 Transformer实现
2.1 自定义架构
这里我们基于Transformer
自定义一个架构,我们要同时考虑局部特征提取能力与全局建模能力,并在计算效率、参数规模与泛化能力之间取得平衡。
2.1.1 自定义模型架构设计与优化策略
这里我们设计了一种融合CNN与Transformer架构优势的自定义模型。CNN计算成本低,擅长提取局部空间特征,但在处理大规模数据时存在性能上限;而Transformer则擅长捕捉全局上下文信息,但计算成本较高。因此,我们对Transformer架构进行了优化,以降低资源消耗的同时保留其建模能力,并利用CNN的归纳偏差加快模型收敛速度。具体设计包括:
- 引入CNN作为前置模块:
- 两个卷积模块用于提取图像初步特征;
- 利用CNN的空间感知归纳偏差(
inductive spatial bias
),提升对局部结构的建模能力; - 控制卷积层数,避免过度提取而损失空间信息,从而保留Transformer对非平移不变特征的学习能力。
- Patch生成与位置编码:
- 卷积提取的特征图根据设定的
patch
大小划分为嵌入块(embedding patches
); - 通过旋转位置编码(
Rotary Positional Embedding
)将相对位置关系编码到输入中,支持任意输入序列并增强模型对图像尺寸变换的适应能力
- 卷积提取的特征图根据设定的
2.1.2 Transformer模块结构与计算优化
为了提升模型在计算效率与参数量上的表现,我们自定义的Transformer部分设计了多项结构级优化:
高效多头注意力机制(Efficient Multi-Head Attention
)
- 使用单一线性投影同时生成
Query
、Key
和Value
,然后分割为三个部分; - 减少参数数量与内存开销,同时保留原始注意力机制的功能;
- 引入
einsum
运算进一步减少中间张量,加速计算流程
Transformer 模块结构(Transformer Block
)
每个Transformer模块遵循标准结构,并做了以下处理:
- 输入先通过
LayerNorm
层进行归一化,以稳定各特征通道的方差,促进训练收敛; - 进入自注意力模块,模型学习不同
patch
之间的重要性与结构关系; - 输出通过残差连接(
residual connection
)与原始输入相加,有助于构建深层模型并缓解梯度消失问题; - 残差输出进入多层感知机(
MLP
),搭配GELU
激活函数,进一步提取与整合关键特征。
输出与分类流程
- 自注意力模块输出通过全局自适应池化层(
Global Adaptive Pooling
)降维,提取整图的全局语义信息; - 最终将池化结果送入全连接分类层,输出最终多类别分类结果。
2.2 代码实现
2.2.1 数据集加载
首先我们定义一下dataloader
,包括RGB的和Gray的。这部分内容对应了我们之前预处理部分的内容。
base_data_dir = '/kaggle/input/brain-tumor-mri-dataset'
# Custom: RGB
train_data_dataloader = transforms.Compose([
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(p=0.2),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
val_data_dataloader = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Custom: Gray
train_data_dataloader_gry = transforms.Compose([
transforms.Grayscale(num_output_channels=1),
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(p=0.2),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5], std=[0.5])
])
val_data_dataloader_gry = transforms.Compose([
transforms.Grayscale(num_output_channels=1),
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5], std=[0.5])
])
# Custom: RGB
train_dataset = ImageFolder(f'{base_data_dir}/Training', transform=train_data_dataloader)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_dataset = ImageFolder(f'{base_data_dir}/Testing', transform=val_data_dataloader)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False)
# Custom: Gray
train_dataset_gry = ImageFolder(f'{base_data_dir}/Training', transform=train_data_dataloader_gry)
train_loader_gry = DataLoader(train_dataset_gry, batch_size=128, shuffle=True)
val_dataset_gry = ImageFolder(f'{base_data_dir}/Testing', transform=val_data_dataloader_gry)
val_loader_gry = DataLoader(val_dataset_gry, batch_size=128, shuffle=False)
2.2.2 日志设置函数
这个函数用于初始化训练过程中的日志记录系统,包括将日志写入文件并实时输出到控制台(如Jupyter notebook 或终端)。它还能根据verbose
参数输出当前使用的设备(GPU/CPU)信息,有助于调试和系统配置检查。
import os
import logging
from datetime import datetime
import torch
def setup_logging(
log_dir: str, # 日志文件存储的目录路径
log_level: int = logging.INFO, # 日志等级,默认为INFO,可选 DEBUG、WARNING 等
verbose: bool = False # 是否在日志中加入设备信息
) -> logging.Logger:
# 若目标日志目录不存在则创建
os.makedirs(log_dir, exist_ok=True)
# 构造日志文件名,格式为 tumor_classification_年月日_时分秒.log
log_filename = os.path.join(
log_dir,
f'tumor_classification_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'
)
# 配置日志系统,定义日志格式和输出方式(写文件和控制台)
logging.basicConfig(
level=log_level, # 设置日志等级
# # 日志格式:时间 - 模块名 - 日志等级 - 内容(这些变量名固定,会由logging模块填充)
format='%(asctime)s - %(name)s - %(levelname)s: %(message)s',
# 处理日志输出方式的列表(可以输出到多个地方)
handlers=[
logging.FileHandler(log_filename), # 输出到文件
logging.StreamHandler() # 同时输出到控制台
]
)
# 创建或获取一个名为 'TumorClassification' 的日志记录器实例
logger = logging.getLogger('TumorClassification')
# 如果设置了 verbose 为 True,记录当前使用的设备信息
if verbose:
# 输出设备名(如果使用GPU则显示GPU型号)
logger.info(f"Device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
# 输出是否启用了CUDA
logger.info(f"CUDA Available: {torch.cuda.is_available()}")
# 返回该 logger 实例,供后续训练过程使用
return logger
通常在训练开始时调用一次,为整个训练过程提供统一的日志输出入口。
2.2.3 评估指标计算、保存和绘图
评估指标计算
该函数用于在多分类任务中计算关键性能指标,包括:F1分数
、Precision
、Recall
、每类的Sensitivity
(敏感性,常用于医学图像分类)和混淆矩阵(confusion matrix
)
输入:
true_labels
: 实际标签(列表或 numpy 数组)pred_labels
: 模型预测标签num_classes
: 分类任务中类别的总数
输出:
- 一个包含上述各类评估指标的字典
from sklearn.metrics import f1_score, precision_score, recall_score, confusion_matrix
import numpy as np
def calculate_metrics(true_labels, pred_labels, num_classes):
# 计算加权 F1 分数,考虑每个类别的样本数权重
f1 = f1_score(true_labels, pred_labels, average='weighted')
# 计算加权精确率(Precision)
precision = precision_score(true_labels, pred_labels, average='weighted')
# 计算加权召回率(Recall)
recall = recall_score(true_labels, pred_labels, average='weighted')
# 构造混淆矩阵(行是真实标签,列是预测标签)
cm = confusion_matrix(true_labels, pred_labels)
# 初始化 sensitivity(敏感性)列表,每个类别一个值
sensitivity = []
# 遍历每一个类别,逐类计算敏感性(即 Recall)
for i in range(num_classes):
tp = cm[i, i] # 真阳性:第 i 类预测正确的个数
fn = cm[i, :].sum() - tp # 假阴性:第 i 类实际为该类但被错分的个数
# 敏感性 = TP / (TP + FN),避免除零
sensitivity.append(tp / (tp + fn) if (tp + fn) > 0 else 0)
# 返回包含所有指标的字典,便于后续分析与记录
return {
'f1_score': f1,
'precision': precision,
'recall': recall,
'sensitivity': sensitivity, # 是一个列表,每个类别一个
'confusion_matrix': cm # 是一个二维矩阵,方便可视化
}
指标保存
这个函数将模型训练过程中每个epoch
的评估指标(训练集 + 验证集)保存到一个 CSV 文件中,便于后续分析、绘图、对比模型等操作。
参数说明:
train_metrics
: 每个 epoch 训练集的评估结果(列表,每个元素是一个字典)val_metrics
: 每个 epoch 验证集的评估结果(列表)csv_path
: CSV 文件保存路径
import os
import csv
def save_metrics_to_csv(train_metrics, val_metrics, csv_path):
# 确保输出路径的目录存在(如 logs/Transformer_metrics.csv 中的 logs/)
os.makedirs(os.path.dirname(csv_path), exist_ok=True)
# 打开目标 CSV 文件,使用写模式 ('w') 并禁用行末多余换行
with open(csv_path, mode='w', newline='') as file:
writer = csv.writer(file)
# 写入 CSV 表头(列名),包含每轮训练的各项指标
headers = ['Epoch', 'Train_F1', 'Train_Precision', 'Train_Recall',
'Train_Sensitivity', 'Train_Confusion_Matrix',
'Val_F1', 'Val_Precision', 'Val_Recall',
'Val_Sensitivity', 'Val_Confusion_Matrix']
writer.writerow(headers)
# 遍历所有 epoch 的训练和验证指标,一一写入 CSV 文件
for epoch, (train, val) in enumerate(zip(train_metrics, val_metrics), start=1):
writer.writerow([
epoch, # 当前 epoch 编号(从1开始)
train['f1_score'], # 训练集 F1 分数
train['precision'], # 训练集精确率
train['recall'], # 训练集召回率
train['sensitivity'], # 训练集敏感性(为一个列表)
str(train['confusion_matrix'].tolist()), # 转成字符串以便存入 CSV
val['f1_score'], # 验证集 F1 分数
val['precision'], # 验证集精确率
val['recall'], # 验证集召回率
val['sensitivity'], # 验证集敏感性
str(val['confusion_matrix'].tolist()) # 转成字符串存入 CSV
])
# 提示保存成功
print(f"Metrics saved to {csv_path}")
将confusion_matrix
转为字符串是因为混淆矩阵是一个numpy二维数组,直接写入CSV会报错。
绘图
这个函数用于可视化并保存训练和验证过程中的评估指标图表,如f1_score
、precision
、recall
和sensitivity
。这些指标图能帮助分析模型训练过程中的表现趋势,判断是否过拟合或欠拟合等。
参数:
train_metrics
: 每个epoch的训练指标,通常是一个 list of dicts。val_metrics
: 每个epoch的验证指标,同样是 list of dicts。save_dir
: 用于保存图表的目录路径。
import os
import numpy as np
import matplotlib.pyplot as plt
def plot_metrics(train_metrics, val_metrics, save_dir):
# 确保用于保存图像的目录存在,如不存在则创建
os.makedirs(save_dir, exist_ok=True)
# 获取 epoch 数,用于 x 轴
epochs = range(1, len(train_metrics) + 1)
# 要绘制的评估指标列表
metrics_to_plot = ['f1_score', 'precision', 'recall', 'sensitivity']
# 对每一个指定的指标进行绘图
for metric in metrics_to_plot:
plt.figure(figsize=(10, 6)) # 创建新的图像画布,大小为10x6英寸
# 取出训练集中每个 epoch 对应的指标值
# 如果是 sensitivity(可能是列表),则取其平均值,否则直接使用
train_values = [
epoch_metrics[metric] if metric != 'sensitivity' else np.mean(epoch_metrics[metric])
for epoch_metrics in train_metrics
]
# 同样地,提取验证集中的指标值
val_values = [
epoch_metrics[metric] if metric != 'sensitivity' else np.mean(epoch_metrics[metric])
for epoch_metrics in val_metrics
]
# 绘制训练和验证指标曲线,添加标记点
plt.plot(epochs, train_values, label=f'Train {metric}', marker='o')
plt.plot(epochs, val_values, label=f'Validation {metric}', marker='o')
# 设置图表标题、坐标轴标签、图例和网格
plt.title(f'{metric.capitalize()} over Epochs')
plt.xlabel('Epochs')
plt.ylabel(metric.capitalize())
plt.legend()
plt.grid(True)
# 设置 y 轴范围为 [0, 1],适用于这些评估指标
plt.ylim(0, 1)
# 构建图像保存路径,并保存图像
save_path = os.path.join(save_dir, f'{metric}.png')
plt.savefig(save_path) # 保存图像为 PNG 文件
print(f"Saved plot for {metric} to {save_path}") # 输出提示信息
plt.close() # 关闭当前图像,释放内存
2.2.4 训练函数
该函数是一个增强版的PyTorch模型训练循环,包含以下功能:
- 支持训练/验证数据分离
- 自动保存性能最好的模型(基于F1分数)
- 使用
TensorBoard
可视化训练过程 - 提供
Early Stopping
和动态学习率调整(使用ReduceLROnPlateau
)
import os
from typing import Optional, Tuple, List, Dict
import torch
import torch.nn as nn
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import ReduceLROnPlateau
def train_model_enhanced(
model: nn.Module,
train_loader: torch.utils.data.DataLoader,
val_loader: torch.utils.data.DataLoader,
criterion: nn.Module,
optimizer: torch.optim.Optimizer,
num_epochs: int = 150,
device: Optional[torch.device] = None,
num_classes: int = 4,
log_dir: str = './logs',
checkpoint_dir: str = '/kaggle/working',
model_name=None,
early_stopping_patience: int = 10,
learning_rate_patience: int = 5
) -> Tuple[nn.Module, List[Dict], List[Dict]]:
# 自动选择 GPU 或 CPU
device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# 创建模型保存目录和日志目录
checkpoint_dir = os.path.join(checkpoint_dir, model_name)
log_dir = os.path.join(log_dir, model_name)
os.makedirs(log_dir, exist_ok=True)
os.makedirs(checkpoint_dir, exist_ok=True)
# 设置日志和 TensorBoard
logger = setup_logging(log_dir, verbose=True)
writer = SummaryWriter(log_dir=log_dir)
# 设置学习率调度器(基于验证集 F1 分数)
scheduler = ReduceLROnPlateau(
optimizer,
mode='max',
factor=0.5, # 学习率下降因子
patience=learning_rate_patience # 如果 n 个 epoch 没提升,就调整 lr
)
# 初始化早停相关变量
best_val_f1 = 0.0
epochs_no_improve = 0
# 用于记录每个 epoch 的指标
train_metrics, val_metrics = [], []
# ======= 主训练循环 =======
for epoch in range(num_epochs):
model.train() # 切换为训练模式
train_loss = 0.0
train_preds, train_true = [], []
for batch_idx, (inputs, labels) in enumerate(train_loader):
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(inputs) # 前向传播
loss = criterion(outputs, labels) # 计算损失
loss.backward() # 反向传播
# 可选:梯度裁剪,防止梯度爆炸
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step() # 更新参数
train_loss += loss.item()
_, predicted = torch.max(outputs, 1) # 获取预测标签
train_preds.extend(predicted.cpu().numpy())
train_true.extend(labels.cpu().numpy())
# ======= 验证阶段 =======
model.eval()
val_loss = 0.0
val_preds, val_true = [], []
with torch.no_grad(): # 验证不需要梯度
for inputs, labels in val_loader:
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
loss = criterion(outputs, labels)
val_loss += loss.item()
_, predicted = torch.max(outputs, 1)
val_preds.extend(predicted.cpu().numpy())
val_true.extend(labels.cpu().numpy())
# ======= 计算评估指标 =======
train_metrics_epoch = calculate_metrics(
np.array(train_true),
np.array(train_preds),
num_classes
)
val_metrics_epoch = calculate_metrics(
np.array(val_true),
np.array(val_preds),
num_classes
)
# ======= TensorBoard 可视化记录 =======
# 当前epoch的平均训练损失
writer.add_scalar('Loss/Train', train_loss / len(train_loader), epoch)
writer.add_scalar('Loss/Validation', val_loss / len(val_loader), epoch)
writer.add_scalar('F1/Train', train_metrics_epoch['f1_score'], epoch)
writer.add_scalar('F1/Validation', val_metrics_epoch['f1_score'], epoch)
writer.add_scalar('Precision/Train', train_metrics_epoch['precision'], epoch)
writer.add_scalar('Precision/Validation', val_metrics_epoch['precision'], epoch)
writer.add_scalar('Recall/Train', train_metrics_epoch['recall'], epoch)
writer.add_scalar('Recall/Validation', val_metrics_epoch['recall'], epoch)
writer.add_scalar('Sensitivity/Train', np.mean(train_metrics_epoch['sensitivity']), epoch)
writer.add_scalar('Sensitivity/Validation', np.mean(val_metrics_epoch['sensitivity']), epoch)
# ======= 学习率调整 =======
scheduler.step(val_metrics_epoch['f1_score'])
print(val_metrics_epoch['f1_score']) # 输出当前 epoch 验证 F1 分数
# ======= 保存最佳模型 =======
if val_metrics_epoch['f1_score'] > best_val_f1 and epoch >= 10:
best_val_f1 = val_metrics_epoch['f1_score']
epochs_no_improve = 0
torch.save({
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'best_f1': best_val_f1,
'epoch': epoch
}, os.path.join(checkpoint_dir, f'{model_name}_{epoch}_f1_{best_val_f1:.4f}.pth'))
else:
epochs_no_improve += 1
# ======= 日志打印 =======
logger.info(f"Epoch {epoch+1}/{num_epochs}")
logger.info(f"Train F1: {train_metrics_epoch['f1_score']:.4f}")
logger.info(f"Val F1: {val_metrics_epoch['f1_score']:.4f}")
# ======= 早停判断 =======
if epochs_no_improve >= early_stopping_patience:
logger.info(f"Early stopping triggered after {epoch+1} epochs")
break
# 保存每轮的指标记录
train_metrics.append(train_metrics_epoch)
val_metrics.append(val_metrics_epoch)
# ======= 清理工作 =======
writer.close() # 关闭 TensorBoard 写入器
# 返回训练完成的模型及指标记录
return model, train_metrics, val_metrics
2.2.5 Transformer架构
本模型结合了卷积神经网络(CNN
)的局部感知能力与Transformer的全局建模能力,通过优化的模块设计,提高了训练效率与模型性能,特别适合用于医学影像分类任务。
2.2.5.1 多头自注意力机制模块
这是一个高效实现的多头自注意力机制(Multi-Head Attention
)模块,属于 Transformer 的核心组件之一,常用于自然语言处理、图像处理等任务中。
此版本对传统的实现做了优化,使用单次线性变换计算 Q、K、V,并使用 einsum
进行矩阵操作,从而提升计算效率。
import torch
import torch.nn as nn
import torch.nn.functional as F
class EfficientMultiHeadAttention(nn.Module):
def __init__(self, embed_dim, num_heads, dropout=0.1):
super().__init__()
# 检查嵌入维度是否能被头数整除(每个头分得一样多的维度)
assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by number of heads"
self.embed_dim = embed_dim # 输入/输出的维度
self.num_heads = num_heads # 注意力头数
self.head_dim = embed_dim // num_heads # 每个头处理的维度
self.scale = self.head_dim ** -0.5 # 缩放因子d_k,用于dot-product attention中防止梯度爆炸
# 合并Q,K,V的线性投影为一次计算:input -> [Q|K|V]
self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim)
# 输出投影(将多头拼接结果映射回原始维度)
self.out_proj = nn.Linear(embed_dim, embed_dim)
self.dropout = nn.Dropout(dropout) # 用于防止过拟合
def forward(self, x, mask=None):
# B: batch size, L: 序列长度, C: 通道数(embed_dim)
B, L, C = x.shape
# 一次性计算Q,K,V,然后拆分为3份,维度为[B, L, 3*embed_dim]
qkv = self.qkv_proj(x).chunk(3, dim=-1)
# 将Q,K,V变形为 [B, num_heads, L, head_dim],方便多头并行处理
q, k, v = map(
lambda t: t.view(B, L, self.num_heads, self.head_dim).transpose(1, 2),
qkv
)
# 使用einsum高效计算注意力打分:[B, heads, query_len, key_len]
attn_weights = torch.einsum('bhqd,bhkd->bhqk', q, k) * self.scale
# 如果提供了mask(将未来单词的信息屏蔽掉),则屏蔽掉无效位置(设为-inf,softmax为0)
if mask is not None:
attn_weights = attn_weights.masked_fill(mask == 0, float('-inf'))
# 归一化注意力得分,形成概率分布(注意力分配)
attn_probs = F.softmax(attn_weights, dim=-1)
# 应用Dropout增强模型鲁棒性
attn_probs = self.dropout(attn_probs)
# 使用注意力分数(每个Query对所有Key的权重)加权求和V(每个Key位置的Value向量)
# 得到每个Query的加权Value输出:[B, heads, query_len, head_dim]
context = torch.einsum('bhqk,bhkd->bhqd', attn_probs, v)
# 还原回原始 shape:[B, L, embed_dim]
context = context.transpose(1, 2).contiguous().view(B, L, C)
# 通过线性映射恢复输出维度(可配合残差连接)
return self.out_proj(context)
1.q
,k
和v
的计算
Transformer的注意力模块是多头的,每个头并行处理输入的一部分信息(不同的子空间),然后合并结果。我们需要把embed_dim
拆成num_heads
个小头,每个处理head_dim
的向量。
embed_dim = num_heads × head_dim
在代码中,假设embed_dim = 128
,num_heads = 8
:
.view(B, L, num_heads, head_dim)
:把每个token的128维嵌入切成8段,每段16维.transpose(1, 2)
:把L
和num_heads
位置对调,方便每个头并行处理整个序列
Transformer中的每个头是独立计算注意力的。每个头处理一个[L, head_dim]
的序列,并对全部token之间进行attention。然后最后再把这num_heads
个结果拼起来。
2.attn_weights
的计算
attn_weights = torch.einsum('bhqd,bhkd->bhqk', q, k) * self.scale
这部分对应实现论文中的Scaled Dot-Product Attention
的打分部分:
Attention
(
Q
,
K
,
V
)
=
softmax
(
Q
K
T
d
k
)
V
\text{Attention}(Q, K, V) = \text{softmax}\left( \frac{QK^T}{\sqrt{d_k}} \right) V
Attention(Q,K,V)=softmax(dkQKT)V
我们重点看前半部分的打分:
s
c
o
r
e
s
=
Q
K
T
d
k
scores = \frac{Q K^T}{\sqrt{d_k}}
scores=dkQKT
在多头注意力中,我们有:
q.shape = [B, h, Lq, d_k]
:每个头的Query,Lq
是query的长度k.shape = [B, h, Lk, d_k]
:每个头的Key,Lk
是key的长度
我们想做的是:在每个头内,对每个query
和所有key
做点积,得到注意力打分矩阵[B, h, Lq, Lk]
einsum
的作用
einsum
是爱因斯坦求和约定(Einstein summation convention
)的简写,在PyTorch中是个很强大的函数,用于写简洁的矩阵运算。你可以把它看成是:
# 类似下面的矩阵乘法操作
attn_weights = torch.matmul(q, k.transpose(-2, -1)) * scale
但是 einsum
更灵活,更直观表达操作方式:
torch.einsum('bhqd,bhkd->bhqk', q, k)
假设:B
为batch size
;h
为number of heads
;q
为query sequence length
(L);k
为key sequence length
(L);d
为head dimension
(head_dim)
这段话可以解读为:
'bhqd'
:q的维度是[batch, heads, query_len, d_k]
'bhkd'
:k的维度是[batch, heads, key_len, d_k]
'->bhqk'
:输出是[batch, heads, query_len, key_len]
,即打分矩阵
意思是:对最后一维d_k
进行点积运算(查询和键之间的相似度),保留query
和key
的所有组合。即每个head
里,第q
个查询和第k
个键做点积,结果是一个[q × k]
的注意力分数矩阵。
代码例子
import torch
# 假设 batch_size=1, num_heads=1, seq_len=2, head_dim=3
# 形状:[B, H, Q, D] 和 [B, H, K, D]
q = torch.tensor([[[[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0]]]]) # shape: [1, 1, 2, 3]
k = torch.tensor([[[[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0]]]]) # shape: [1, 1, 2, 3]
# 计算 attention score: einsum('bhqd,bhkd->bhqk', q, k)
attn_weights = torch.einsum('bhqd,bhkd->bhqk', q, k)
print(attn_weights)
这里:
q[0, 0, 0] = [1, 2, 3]
是第一个query
向量k[0, 0, 0] = [1, 0, 0]
是第一个key
向量k[0, 0, 1] = [0, 1, 0]
是第二个key
向量
对第一个query
去点积第一个和第二个key
:
q[0] ⋅ k[0] = 1×1 + 2×0 + 3×0 = 1
q[0] ⋅ k[1] = 1×0 + 2×1 + 3×0 = 2
对第二个query
去点积第一个和第二个 key:
q[1] ⋅ k[0] = 4×1 + 5×0 + 6×0 = 4
q[1] ⋅ k[1] = 4×0 + 5×1 + 6×0 = 5
所以最终attn_weights
是:
[[[[1.0, 2.0],
[4.0, 5.0]]]]
维度是[1个batch, 1个头, 2个query位置, 每个query对应2个key的打分]
2.2.5.2 TransformerBlock模块
这是一个基本的Transformer子模块(block
),可以参考Transformer架构详解的Transformer编码器
章节,包含:
- 多头注意力机制(
EfficientMultiHeadAttention
) - 两个
LayerNorm
- 一个前馈MLP层(两层全连接)
该模块在输入张量上执行:自注意力 → 残差连接 + LayerNorm → 前馈 MLP → 残差连接 + LayerNorm
class TransformerBlock(nn.Module):
# 初始化 TransformerBlock,包含注意力层、规范化层、前馈网络等
def __init__(self, embed_dim, num_heads, mlp_dim, dropout=0.05):
super().__init__()
# 多头注意力模块,这里使用了刚刚实现的Efficient版本
self.attention = EfficientMultiHeadAttention(embed_dim, num_heads, dropout)
# 第一个LayerNorm,用于attention的输出
self.norm1 = nn.LayerNorm(embed_dim)
# 第二个LayerNorm,用于MLP输入的归一化(采用Pre-LN架构)
self.norm2 = nn.LayerNorm(embed_dim)
# 前馈神经网络(MLP),包含一层线性+GELU激活
self.mlp = nn.Sequential(
nn.Linear(embed_dim, mlp_dim), # 全连接:维度从embed_dim映射到mlp_dim
nn.GELU(), # 使用GELU非线性激活
# nn.Dropout(dropout), # Dropout被注释掉了,可以按需开启
)
# 前向传播定义
def forward(self, x):
# 对输入进行 LayerNorm 后送入注意力层,并加上残差连接
attn_out = self.attention(self.norm1(x))
x = x + attn_out
# 再次进行 LayerNorm,送入前馈网络,并加上残差连接
mlp_out = self.mlp(self.norm2(x))
x = x + mlp_out
# 返回输出结果
return x
Pre-LN架构(Pre-Layer Normalization
) 指的是:在Transformer的每个子模块(比如注意力机制、多层感知机)之前进行LayerNorm
,而不是之后。因为它是一种更稳定、更易训练、更适合深层模型的结构。
(1)训练更稳定
- 在Post-LN架构中,模型在深层时容易出现梯度消失或爆炸。
- Pre-LN把归一化提前,可以减缓梯度消失,尤其在深层网络训练中效果更明显。
(2)残差路径更清晰
- Pre-LN的残差连接路径不包含
LayerNorm
,因此在反向传播中信息传递更顺畅。 - 简单说:残差项梯度不被归一化扭曲,能更快传递回去。
(3)支持更大batch
、学习率
- 实验表明 Pre-LN 架构在大模型、大
batch
、长序列下表现更稳定,不容易训练崩掉。
总得来说,在训练深层模型(如GPT-2、T5),要求收敛快,稳定性强时用Pre-LN,如果训练较浅的网络用Post-LN也可以。
2.2.5.3 RotaryPositionalEmbedding模块
这是一个自定义的正弦-余弦位置编码类,用来为Transformer中的输入添加序列位置信息,核心思路与原始Transformer位置编码一致,但有改进。
这里使用的是RoPE(Rotary Positional Embedding
),是很多现代Transformer(如GPT-NeoX
、BLOOM
)中用的位置编码方法。目的是为Transformer中的每个位置添加旋转式位置编码,增强注意力机制的相对位置感知能力。优点如下:
- 比传统的正余弦位置编码(如
BERT
)更适合长序列; - 可与多头注意力自然结合,不需要额外的位置向量加法。
class RotaryPositionalEmbedding(nn.Module):
def __init__(self, dim, base=10000):
super().__init__()
# 创建倒数频率:base ** (2i / dim),用于生成位置嵌入
# torch.arange(0., dim, 2.) 生成 [0, 2, 4, ..., dim-2]
# 最终形状为 [dim/2]
inv_freq = 1. / (base ** (torch.arange(0., dim, 2.) / dim))
# 将 inv_freq 注册为 buffer(不是模型参数,但会被保存到模型里)
self.register_buffer('inv_freq', inv_freq)
def forward(self, x, seq_dim=1):
# x形状是[batch_size, seq_len, dim],所以x.shape[seq_dim=1]就是取seq_len
# t = [0, 1, 2, ..., seq_len-1]
t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
# 使用爱因斯坦求和公式生成位置与频率组合的正余弦输入矩阵
# sinusoid_inp[i,j] = t[i] * inv_freq[j],
# 结果shape是[seq_len, dim/2], 每一行对应一个位置,每一列对应一个频率维度
sinusoid_inp = torch.einsum("i,j->ij", t, self.inv_freq)
# 将sin和cos拼接成最终位置编码,形状为 [seq_len, dim]
emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
# 返回shape为[1, seq_len, dim],用于后续广播到多头输入
return emb[None, :, :]
1.register_buffer
在PyTorch中,register_buffer
是一个给模型保存常量的方法,这些常量:
- 不是参数(不会被训练)
- 但又需要被保存(比如保存到
.pt
模型文件) - 并且希望在
.cuda()
或.to(device)
时自动搬到显卡上
**例子:**假设你有一个位置编码向量 pos_enc
,你不希望它被模型训练更新(因为它是固定生成的),但又希望:
- 保存模型时能一块保存
- 加载模型后它还能自动被用上
- 模型转到GPU时它也自动转到GPU
class MyModel(nn.Module):
def __init__(self):
super().__init__()
# 注册一个 buffer,不会训练,但会保存
self.register_buffer('pos_enc', torch.tensor([[0, 1, 2], [3, 4, 5]]))
def forward(self, x):
# pos_enc 是只读常量,可以在前向传播中使用
return x + self.pos_enc
使用:
model = MyModel()
print(model.pos_enc) # 可以正常使用
model.to('cuda') # pos_enc 也自动转移到 cuda
torch.save(model.state_dict()) # pos_enc 也被保存了
我们可以把它当作模型内置的只读参考表,比如:位置编码、mask模板、词表掩码等。
2.位置编码中的倒数频率
现在分析一下inv_freq
:
inv_freq = 1. / (base ** (torch.arange(0., dim, 2.) / dim))
令维度为dim
,基数为base
,则倒数频率向量inv_freq
定义为:
inv_freq
i
=
1
base
2
i
dim
其中
i
=
0
,
1
,
2
,
…
,
dim
2
−
1
\text{inv\_freq}_i = \frac{1}{\text{base}^{\frac{2i}{\text{dim}}}} \quad \text{其中} \quad i = 0, 1, 2, \dots, \frac{\text{dim}}{2} - 1
inv_freqi=basedim2i1其中i=0,1,2,…,2dim−1
这些频率随维度指数衰减(频率越来越小,周期越来越长)。这种设计确保正余弦函数的周期差异横跨多个尺度,从而在位置编码中引入多频信息。
例子
我们希望:
- 不同位置的编码彼此不同,模型能区分第1个和第2个
token
- 相邻位置之间的编码要有某种平滑变化,便于模型感知顺序
- 高维的编码要能捕捉全局结构,低维的编码捕捉局部差异
假设嵌入维度dim = 4
,位置pos
取0~3
,我们构造频率如下:
import numpy as np
dim = 4
base = 10000
inv_freq = 1. / (base ** (np.arange(0, dim, 2) / dim))
# [1.0, 0.01] — 对应两个频率通道
现在我们计算前4个位置的位置编码(每个是4维):
pos = np.arange(4).reshape(-1, 1) # 位置 0, 1, 2, 3
freq = inv_freq.reshape(1, -1) # 频率 [1.0, 0.01]
angles = pos * freq # 广播相乘
# sin/cos 拼接
encoding = np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)
结果如下:
Position | sin(1.0 * pos) | sin(0.01 * pos) | cos(1.0 * pos) | cos(0.01 * pos) |
---|---|---|---|---|
0 | 0.0 | 0.0 | 1.0 | 1.0 |
1 | 0.84 | 0.00999983 | 0.54 | 0.99995 |
2 | 0.90 | 0.01999867 | -0.41 | 0.99980 |
3 | 0.14 | 0.0299955 | -0.99 | 0.99955 |
我们发现:
- 第一对频率(
1.0
)是高频:变化大,可以捕捉微小的位置信息(比如token的局部位置差异) - 第二对频率(
0.01
)是低频:变化缓慢,可以反映更全局的位置信息(比如token的大致位置)
如果我们不用倒数,而是频率线性增长,比如 [1, 2, 3, 4]
:
- 高维维度变化太快,容易让不同位置之间的位置编码变得混乱
- sin和cos的周期不同步,可能导致模型看不出顺序规律
2.2.5.4 HybridFPNTransformer & HybridFPNTransformer_gry
现在将CNN特征提取
、Patch Embedding
、Rotary位置编码
与Transformer编码器
相结合,形成我们最终的Transformer
混合模型。流程如下
图像输入 → CNN特征提取(stem) → Patch划分 → Transformer编码 → 图像还原 → 分类输出
下面以HybridFPNTransformer
代码为例讲解。两个类结构唯一的区别是输入通道数不同,对于HybridFPNTransformer_gry
来说,只需要把self.stem
中第一个卷积层的输入通道从3改为1即可。
import torch
import torch.nn as nn
class HybridFPNTransformer(nn.Module):
def __init__(self, num_classes, image_size=224, patch_size=16,
embed_dim=128, num_heads=8, transformer_depth=4):
super().__init__()
# ---------------------------
# CNN Stem:初始卷积提取局部特征(含两次下采样)
# 输出尺寸约为原图的 1/4
# ---------------------------
self.stem = nn.Sequential(
# [B, 3, H, W] → [B, 2*embed_dim, H/2, W/2]
nn.Conv2d(3, embed_dim * 2, kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(embed_dim * 2),
nn.ReLU(inplace=True),
# → [B, embed_dim, H/4, W/4]
nn.Conv2d(embed_dim * 2, embed_dim, kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(embed_dim),
nn.ReLU(inplace=True)
)
# ---------------------------
# Patch Embedding:将特征图分块为 patch,映射到 token 向量
# 使用 conv 实现,效果等同 ViT 中的线性 patch embedding
# ---------------------------
self.patch_embed = nn.Conv2d(
embed_dim, embed_dim,
kernel_size=patch_size,
stride=patch_size
)
# ---------------------------
# 位置编码:Rotary Positional Embedding(RoPE)
# 加入位置感知能力,用于后续 attention 中增强空间信息
# ---------------------------
self.rotary_emb = RotaryPositionalEmbedding(embed_dim)
# ---------------------------
# Transformer 编码器:由多个标准 transformer block 组成
# 每个 block 包含多头注意力 + MLP + 残差结构
# ---------------------------
self.transformer_blocks = nn.ModuleList([
TransformerBlock(embed_dim, num_heads, embed_dim)
for _ in range(transformer_depth)
])
# ---------------------------
# 分类头(Head):全局池化 + 两层全连接 + Dropout + 激活
# 用于输出最终的 num_classes 分类结果
# ---------------------------
self.classifier = nn.Sequential(
nn.AdaptiveAvgPool2d(1), # 将 H×W → 1×1(全局平均池化)
nn.Flatten(), # [B, C, 1, 1] → [B, C]
nn.Linear(embed_dim, 64), # 降维处理
nn.ReLU(inplace=True),
nn.Dropout(0.1),
nn.Linear(64, num_classes) # 最终输出 num_classes 维度
)
def forward(self, x):
# Step 1: 用 CNN 进行初始特征提取
x = self.stem(x) # 输出尺寸约为 [B, embed_dim, H/4, W/4]
# Step 2: Patch Embedding,将图像分割为patch序列
x_patches = self.patch_embed(x) # 输出 [B, embed_dim, H_patch, W_patch]
# Step 3: 转换为序列格式,适配Transformer输入
# 把2D图像[B, C, H_patch, W_patch]拉平成[B, num_patches, embed_dim]的序列。
# 其中num_patches = H_patch × W_patch
B, C, H, W = x_patches.shape
x_patches = x_patches.flatten(2).transpose(1, 2) # [B, num_patches, embed_dim]
# Step 4: 添加旋转位置编码
rotary_pos_emb = self.rotary_emb(x_patches) # [1, num_patches, embed_dim]
# 与原始token相加,形成具有空间顺序感知的序列表示
# 广播相加到[B, num_patches, embed_dim],增强空间位置信息
x_patches = x_patches + rotary_pos_emb
# Step 5: 输入多个Transformer Block进行编码
# 对序列表示进行全局建模,提取跨patch的上下文语义信息
for transformer_block in self.transformer_blocks:
x_patches = transformer_block(x_patches)
# Step 6: 恢复为图像格式[B, C, H, W],适配分类头
# 将[B, num_patches, embed_dim]转换回2D图像形状[B, C, H, W]
x_transformed = x_patches.transpose(1, 2).view(B, -1, H, W)
# Step 7: 分类输出
return self.classifier(x_transformed)
(1)CNN Stem
CNN的局部感受野强,有利于提取边缘、纹理等底层特征。所以对原始图像进行卷积提取局部纹理特征,然后连续两次stride=2
的卷积实现空间下采样(H×W减小到原来的1/4),提升后续Transformer的效率。
(2)Patch Embedding
self.patch_embed = nn.Conv2d(embed_dim, embed_dim, kernel_size=patch_size, stride=patch_size)
这一部分的作用是将CNN提取到的特征图划分为若干个大小为 patch_size × patch_size
的patch,并通过卷积操作将每个patch映射为一个token,从而构成token序列,作为Transformer的输入。这种做法模仿了Vision Transformer
(ViT
)的patch embedding机制,但使用卷积代替了手动切分与线性映射,使计算过程更高效,同时保留了一定的局部感知能力。
这一层的四个参数含义如下:
in_channels=embed_dim
:输入特征图的通道数,与CNN stem输出一致;out_channels=embed_dim
:每个patch映射成的token向量维度;kernel_size=patch_size
:卷积核大小等于patch尺寸,用于一次提取一个patch;stride=patch_size
:步长等于patch尺寸,确保patch不重叠且完整覆盖特征图。
(3)Rotary Positional Embedding(RoPE
)
为patch token添加旋转位置编码,使模型具有空间位置感知能力。
self.rotary_emb = RotaryPositionalEmbedding(embed_dim)
Transformer缺乏位置感知能力,RoPE可注入相对位置信息。相比原始位置编码(sin-cos
),RoPE具备更好的平移不变性与泛化能力。
(4)Transformer Blocks
这里创建一个包含transformer_depth
个Transformer编码器层的堆叠,每一层都由多头注意力机制(用于捕捉token间的长距离依赖)、前馈网络(用于增强表达能力)以及残差连接和LayerNorm
(用于稳定训练)组成。
self.transformer_blocks = nn.ModuleList([
TransformerBlock(embed_dim, num_heads, embed_dim)
for _ in range(transformer_depth)
])
这种多层Transformer编码器的设计理由在于利用Transformer的全局建模能力来弥补卷积神经网络感受野不足的问题,并且通过多层堆叠赋予模型更深层次的特征抽象能力。
(5)分类头
这段代码定义了一个分类器,它首先通过自适应平均池化将Transformer编码后的特征图压缩为1x1
的维度,然后展平并通过两个线性层(中间使用ReLU
激活和Dropout
进行正则化)输出最终的分类结果,其设计理由在于利用自适应池化处理任意尺寸的输入,并通过两层线性层简化分类任务,同时使用Dropout
防止过拟合。
self.classifier = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Linear(embed_dim, 64),
nn.ReLU(inplace=True),
nn.Dropout(0.1),
nn.Linear(64, num_classes)
)
(6)添加旋转位置编码
rotary_pos_emb = self.rotary_emb(x_patches)
x_patches = x_patches + rotary_pos_emb
通常rotary_pos_emb
为小数,位置编码的位置信息就是通过这些有规律的小数点(正余弦值)体现出来的。模型本身不会去理解这些小数代表位置。但它在训练过程中:
- 注意力机制通过
QKᵀ
进行点积,会用到这些带有位置信息的向量 - 多层Transformer堆叠后,模型逐渐学会利用这些规律的差异值来判断谁先谁后,谁靠近谁
- 训练目标(比如分类、翻译)会推动模型去重视某些位置间的关系
- 最终的注意力权重里就包含了位置信息的体现
3 训练模型
执行下面代码训练模型:
EPOCHS_NUM = 160
def main():
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Log directory setup
log_dir = './logs/'
os.makedirs(log_dir, exist_ok=True)
# # CSV to store training time and parameters
time_metrics = []
# # # Save training time and parameter metrics to a CSV
time_metrics_path = os.path.join(log_dir, 'training_time_metrics.csv')
# pd.DataFrame(time_metrics).to_csv(time_metrics_path, index=False)
# HybridFPNTransformer training
hybrid_model = HybridFPNTransformer(num_classes=4).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
hybrid_model.parameters(),
lr=1e-4,
weight_decay=1e-5 # L2 regularization
)
start_time = time.time()
trained_hybrid_model, hybrid_train_metrics, hybrid_val_metrics = train_model_enhanced(
model=hybrid_model,
train_loader=train_loader,
val_loader=val_loader,
criterion=criterion,
optimizer=optimizer,
num_epochs=EPOCHS_NUM,
device=device,
model_name='Transformer',
num_classes=4
)
end_time = time.time()
hybrid_training_time = end_time - start_time
csv_path = './logs/Transformer_metrics.csv'
save_metrics_to_csv(hybrid_train_metrics, hybrid_val_metrics, csv_path)
plot_path = './logs/Transformer_plots.png'
plot_metrics(hybrid_train_metrics, hybrid_val_metrics, plot_path)
# Log HybridFPNTransformer training time and parameters
time_metrics.append({
'Model': 'Transformer',
'Training Time (seconds)': hybrid_training_time,
'Number of Parameters': sum(p.numel() for p in hybrid_model.parameters() if p.requires_grad)
})
# Clear VRAM
del hybrid_model, criterion, optimizer, trained_hybrid_model
torch.cuda.empty_cache()
pd.DataFrame(time_metrics).to_csv(time_metrics_path, index=False)
hybrid_model_gry = HybridFPNTransformer_gry(num_classes=4).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
hybrid_model_gry.parameters(),
lr=1e-4,
weight_decay=1e-5 # L2 regularization
)
start_time = time.time()
hybrid_model_gry, hybrid_train_metrics_gry, hybrid_val_metrics_gry = train_model_enhanced(
model=hybrid_model_gry,
train_loader=train_loader_gry,
val_loader=val_loader_gry,
criterion=criterion,
optimizer=optimizer,
num_epochs=EPOCHS_NUM,
device=device,
model_name='HybridFPNTransformerGRY',
num_classes=4
)
end_time = time.time()
hybrid_gry_training_time = end_time - start_time
csv_path = './logs/HybridFPNTransformerGRY_metrics.csv'
save_metrics_to_csv(hybrid_train_metrics_gry, hybrid_val_metrics, csv_path)
plot_path = './logs/HybridFPNTransformerGRY_plots.png'
plot_metrics(hybrid_val_metrics_gry, hybrid_val_metrics, plot_path)
# Log HybridFPNTransformer training time and parameters
time_metrics.append({
'Model': 'HybridFPNTransformerGRY',
'Training Time (seconds)': hybrid_gry_training_time,
'Number of Parameters': sum(p.numel() for p in hybrid_model_gry.parameters() if p.requires_grad)
})
del hybrid_model_gry, criterion, optimizer
torch.cuda.empty_cache()
pd.DataFrame(time_metrics).to_csv(time_metrics_path, index=False)
4 结果分析
这里我们来看一下Transformer
模型的部分输出:
Transformer RGB的Loss和F1曲线
Transformer Gray的Loss和F1曲线
混淆矩阵
推理时间和参数
4.1 训练与验证损失及F1分数曲线
Baseline模型收敛更快,验证集的F1分数也更高。这是因为Baseline模型采用了预训练权重进行训练。VGG表现最好,这可能归因于其较大的网络结构、卷积块数量、卷积核数量等因素,使其能提取其他模型难以捕捉的分类特征。
相比之下,Transformer模型和灰度版本通常收敛较慢,验证F1分数也较低。这是因为它们未使用预训练权重,并且Transformer本身对数据量要求高。过拟合现象明显:训练F1分数达到100%,而验证集约为98%;而Baseline模型的训练F1分数约为99%~99.99%,验证集F1分数约为99%~99.89%。这说明自定义Transformer模型并未学到所有通用图像特征,而是对训练噪声产生了拟合。这一点从验证损失曲线的升高中也可看出。
4.2 最优验证F1分数模型的总体指标
F1分数、精确率和召回率对检测阳性类别(如癌症)非常关键,而特异性在判断阴性类别(如癌症分期)中更为重要。所有模型在特异性上都表现出色,而F1分数、精确率和召回率则体现了模型间的性能差异。混淆矩阵也做了可视化展示。
总体而言,VGG表现最佳,其次是ResNet-18、AlexNet、Transformer与Transformer_gry。然而,VGG仅比最小模型提高了约1.9%的F1分数,但它的推理速度最慢,说明其架构效率较低。AlexNet推理最快,但参数量仍高于Transformer_gry。ResNet18在参数或推理速度方面无优势。参数越大,占用的显存越多;推理速度越慢,部署延迟越高。
结果表明,参数规模与推理速度之间并无直接关联:Transformer的参数量较少,但推理速度却比参数更大的AlexNet慢。这是由于Transformer中的注意力计算为序列式,不能完全并行化(每个token都要与所有其他token交互);而CNN的卷积操作可以完全并行,因此速度更快。此外,Transformer的计算复杂度为O(N²),意味着输入图像尺寸缩小会显著加快处理速度,正如Transformer_gry所展示的那样。结果也说明,在该简单数据集上,模型变大并不会带来显著性能提升:VGG16参数量是transformer的30倍,F1分数仅提升不到2%。
综合来看,transformer_gry在内存效率、推理速度和F1分数之间达成了最佳平衡。而在对结果要求极高的场景中,AlexNet是综合表现最优的选择。AlexNet也可进一步优化为接受灰度输入,但此举将失去预训练权重,最终性能存在不确定性。若面对更大且更复杂的数据集,Transformer架构或许是效率与准确率的最佳折中选择。
4.3 研究发现
上一章节指出,我们自定义的混合Transformer模型相较于优化后的Baseline模型表现出收敛慢、验证F1分数低的特点。这可能与自定义Transformer架构面临的一些挑战有关。
主要问题包括:缺乏预训练权重,以及Transformer对大规模数据的高度依赖。由于本项目中的Transformer是从零开始训练,没有预训练权重支持,因此难以学习通用图像特征;而CNN模型可借助预训练滤波器捕捉普适模式。
此外,Transformer依赖大数据特性也导致了过拟合现象:训练F1分数达到100%,而验证集停留在98%,说明模型更依赖于数据中的噪声模式,而不是有意义的特征提取。尽管存在这些问题,Transformer在一些方面仍展现出潜力,尤其是在处理灰度图像时,在内存效率与推理速度之间找到了较好平衡,这显示了其在资源受限场景下的应用价值,只要对其架构进一步优化,仍具备可行性。
另一个显著发现是:数据集结构特性对CNN与Transformer的影响差异明显。CNN在空间局部性和位移不变性上表现优异,非常契合脑部MRI图像的结构特征;而Transformer擅长处理复杂、非结构化的关系,但需要更大的数据量。在我们的较小数据集中,由于Transformer缺乏归纳偏置,导致其容易过拟合,无法有效捕捉图像模式。这一现象揭示了两种架构的本质权衡:CNN适用于结构清晰、空间一致性强的任务,而Transformer更适用于关系复杂、模式多变的场景。
4.4 局限性
本项目中的Transformer模型整体表现不如CNN,主要原因如下:
首先是缺乏预训练权重。由于没有可用的预训练模型,我们的Transformer必须从零开始训练,这使得其必须从头学习所有特征。而CNN模型可以使用ImageNet等大型数据上训练好的权重进行微调,这让CNN在面对小样本MRI数据时具有明显优势。如果能有跨领域的预训练Transformer权重,其表现可能会大幅提升。
其次是数据集与模型架构不匹配。脑瘤图像具有明显的局部特征,并且肿瘤的分类并不依赖其在图像中的具体位置(即具有平移不变性)。这与CNN的设计理念天然一致,而Transformer则缺乏这种归纳偏置,必须依赖更多的数据去显式学习这些空间关系。在我们只有约6000张图像的四分类数据集中,Transformer很难有效泛化,因此容易拟合无效模式。
最后是数据集本身的限制。Transformer通常在面对复杂模式、非结构图像、大型数据集时才会展现优势;而我们的数据集不仅类别少(仅有4类:胶质瘤、脑膜瘤、垂体瘤和无肿瘤),且空间结构简单,难以发挥Transformer在复杂关系建模方面的长处。此外,数据集缺乏详细文档说明,也对训练质量构成了潜在影响。
5 总结
本项目围绕脑瘤MRI图像分类任务,系统比较了基于CNN与Transformer的两类深度学习模型的性能与适用性。我们首先构建了以VGG16、ResNet18、AlexNet为代表的CNN基准模型,并引入了一种自定义的HybridFPNTransformer架构,在此基础上探讨其在RGB和灰度图像处理中的表现。
在充分实验与可视化分析的支持下,我们得出以下几点核心结论:
- CNN模型在本任务中表现更优:借助预训练权重与结构先验,CNN模型在小样本医学图像中具备更强的泛化能力。VGG16模型达到了几乎完美的F1分数,远超基于Transformer从零训练的结果。
- Transformer模型存在性能劣势但具备潜力:尽管自定义Transformer模型在准确率上略逊一筹(F1约98%),但其参数更少、推理更快,特别是在使用灰度图时在计算效率方面显示出明显优势。这为在边缘设备或计算资源有限场景中的部署提供了可能。
- 过拟合与数据依赖性是Transformer的关键挑战:在小数据集上,Transformer模型由于缺乏卷积的归纳偏置,更容易记忆噪声而非抽象有效特征。我们观察到其训练集F1迅速达100%,而验证集F1徘徊在较低水平,进一步证实了其对大规模训练数据的依赖。
- 任务结构决定模型优势:本任务中的脑瘤图像具备局部显著性与空间结构一致性,天然契合CNN的归纳偏差;而Transformer更适用于语义关系复杂或长距离依赖显著的任务,如多器官检测或多模态影像处理。因此,模型选择应依据具体任务的视觉特性而定。
- Transformer仍具发展空间:未来若引入适用于医学图像的预训练权重、增加数据量或采用自监督预训练策略,Transformer的表现有望进一步提升。此外,通过改进patch分割策略、引入局部注意力机制、融合多尺度特征等方式,可使其更适应医学影像场景。
综上所述,尽管在当前设定下,CNN在精度表现上占据主导,但Transformer在模型轻量化与全局建模能力方面展现出良好潜力。未来通过预训练、结构优化与大数据支持,Transformer有望在医学影像分类任务中与CNN分庭抗礼,甚至实现超越。