import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    def __init__(self, heads, d_model, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.d_k = d_model // heads  # dimension per head: d_model divided by the number of heads
        self.h = heads               # number of attention heads
        # Linear layers that project the input into the query, key and value spaces
        self.q_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        # Dropout layer, used to regularize the attention weights
        self.dropout = nn.Dropout(dropout)
        # Final linear projection applied after the heads are concatenated
        self.out = nn.Linear(d_model, d_model)
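The listing above only covers the constructor. As a minimal sketch of how the forward pass is typically completed for such a module (the mask handling, the head-splitting layout, and the use of `self.out` follow the common pattern and are assumptions, not something stated in the excerpt), the rest of the class body might look like this:

    def forward(self, q, k, v, mask=None):
        bs = q.size(0)

        # project the inputs and split them into h heads: (batch, heads, seq_len, d_k)
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k).transpose(1, 2)
        k = self.k_linear(k).view(bs, -1, self.h, self.d_k).transpose(1, 2)
        v = self.v_linear(v).view(bs, -1, self.h, self.d_k).transpose(1, 2)

        # scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.d_k ** 0.5)
        if mask is not None:
            # assumed convention: positions where mask == 0 are not attended to
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = self.dropout(scores.softmax(dim=-1))

        # concatenate the heads back to (batch, seq_len, d_model) and apply the output projection
        concat = torch.matmul(attn, v).transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        return self.out(concat)

For example, with d_model = 512 and heads = 8, each head works on 64-dimensional projections, and an input of shape (batch, seq_len, 512) produces an output of the same shape.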