### PyTorch 中 `bmm` 函数的用法
在 PyTorch 中,`bmm` 是 Batch Matrix Multiplication 的缩写,用于执行批量矩阵乘法操作。具体来说,如果输入张量分别为 `(b, m, n)` 和 `(b, n, p)`,那么 `bmm` 将返回形状为 `(b, m, p)` 的结果张量[^1]。
以下是关于 `bmm` 函数的一些重要细节:
#### 输入要求
- 两个输入张量的第一个维度必须相同,表示批处理大小。
- 第二个张量的最后一维应等于第一个张量的倒数第二维,这是矩阵乘法规则的要求。
#### 输出特性
- 结果是一个三维张量,其形状由批处理大小以及前两者的其他维度决定。
#### 示例代码
下面展示了一个简单的例子来说明如何使用 `bmm` 进行批量矩阵运算:
```python
import torch
# 创建两个随机张量 (batch_size=3, dim1=2, dim2=4), (batch_size=3, dim1=4, dim2=5)
tensor_a = torch.randn(3, 2, 4)
tensor_b = torch.randn(3, 4, 5)
# 使用 bmm 执行批量矩阵乘法
result = torch.bmm(tensor_a, tensor_b)
print(result.shape) # 输出应该是 (3, 2, 5),即 batch size 保持不变
```
此代码片段展示了如何利用 `bmm` 对一批次中的每一对矩阵分别计算乘积,并最终得到一个新的批次矩阵集合[^2]。
#### 应用场景
`bmm` 常见于涉及多层感知机(MLP)、循环神经网络(RNNs),特别是 Seq2Seq 模型及其变体中使用的注意力机制部分。例如,在生成注意权重时,通常会先计算查询向量与键向量之间的相似度得分,再通过 `softmax` 转化成概率分布形式[^5]。
---
### 注意力机制中的应用案例
假设我们有一个查询张量 Q 形状为 `[B, T_q, d_k]` ,还有一个键值对 K,V 张形均为 `[B,T_v,d_k]`,其中 B 表示批次大小;T_q 和 T_v 分别代表查询序列长度和值序列长度;d_k 则是每个词嵌入维度。我们可以这样实现基本点积注意力:
```python
import torch.nn.functional as F
def scaled_dot_product_attention(Q, K, V):
"""
计算 Scaled Dot Product Attention.
参数:
Q: 查询张量, shape [B, T_q, d_k].
K: 键张量, shape [B, T_v, d_k].
V: 值张量, shape [B, T_v, d_v].
返回:
context_vector: 上下文向量, shape [B, T_q, d_v].
"""
# Step 1: Compute dot product between queries and keys
scores = torch.bmm(Q, K.transpose(1, 2)) / (K.size(-1)**0.5)
# Step 2: Apply softmax along last dimension to get attention weights
attn_weights = F.softmax(scores, dim=-1)
# Step 3: Multiply with values to compute final output
context_vector = torch.bmm(attn_weights, V)
return context_vector
```
上述函数实现了标准的缩放点积注意力算法,其中第一步就是调用了 `torch.bmm()` 来完成查询Q与转置后的键K之间的大规模并行化的矩阵相乘操作[^4]。
---