在ResNet50和BERT的交叉注意力计算中,维度变换的核心在于对齐两个模型的特征空间,使其能够通过注意力机制进行交互。以下是具体步骤和原理分析:
1. 特征维度对齐
• ResNet50图像特征处理
ResNet50的典型输出为经过全局平均池化后的2048维特征向量(对应全连接层前的特征)。若需保留空间信息(如中间层特征),其形状可能为 (batch, channels, height, width)
(例如 (B, 2048, H, W)
)。此时需将空间维度展平为序列形式,例如 (B, H*W, 2048)
,并通过线性层(如1×1卷积或全连接层)将通道数从2048投影到768维,与BERT的隐藏层维度一致。
变换后维度:(B, H*W, 768)
。
• BERT文本特征处理
BERT的输出通常为 (B, seq_len, 768)
,其中 seq_len
是文本序列长度。若需与图像特征交互,可直接使用该维度,无需额外调整。
2. 交叉注意力机制实现
交叉注意力的核心是将图像特征作为查询(Query),文本特征作为键(Key)和值(Value)(或反之),通过线性变换实现维度匹配:
• 线性投影:
对图像特征和文本特征分别应用线性变换生成Q、K、V。例如:
• 图像Q:Linear(768 → d_model)
,文本K/V:Linear(768 → d_model)
其中 d_model
是注意力头维度,如BERT的64(总维度768/12头)。
变换后维度:Q为 (B, H*W, d_model)
,K/V为 (B, seq_len, d_model)
。
• 注意力计算:
通过矩阵乘法计算注意力得分 (Q·K^T)/√d_model
,得到形状为 (B, H*W, seq_len)
的权重矩阵。加权求和后,输出特征维度为 (B, H*W, d_model)
或 (B, seq_len, d_model)
(取决于Q的来源)。
3. 输出特征融合
• 多模态融合:
交叉注意力输出可与原始特征相加或拼接。例如,图像特征经注意力增强后,可再通过全连接层恢复为原任务所需维度(如分类任务的类别数)。
• 维度一致性:
若需将结果输入后续网络(如ResNet的下一层或BERT的另一个Transformer层),需确保输出维度与目标模块的输入匹配。例如,图像特征可能需要从768维再次投影回ResNet的通道数(如2048)。
以下是基于ResNet50和BERT进行交叉注意力计算的PyTorch代码实现,结合两者的特征维度对齐与注意力机制设计:
4. 特征维度对齐模块
ResNet50特征处理(空间序列化 + 投影)
import torch
import torch.nn as nn
from torchvision.models import resnet50
class ResNetFeatureExtractor(nn.Module):
def __init__(self, hidden_dim=768):
super().__init__()
# 加载预训练的ResNet50(去除全连接层)
self.resnet = nn.Sequential(*list(resnet50(pretrained=True).children())[:-2])
# 空间展平与维度投影
self.projection = nn.Sequential(
nn.Conv2d(2048, hidden_dim, kernel_size=1), # 1x1卷积将通道数从2048→768
nn