牛bMM

### PyTorch 中 `bmm` 函数的用法 在 PyTorch 中,`bmm` 是 Batch Matrix Multiplication 的缩写,用于执行批量矩阵乘法操作。具体来说,如果输入张量分别为 `(b, m, n)` 和 `(b, n, p)`,那么 `bmm` 将返回形状为 `(b, m, p)` 的结果张量[^1]。 以下是关于 `bmm` 函数的一些重要细节: #### 输入要求 - 两个输入张量的第一个维度必须相同,表示批处理大小。 - 第二个张量的最后一维应等于第一个张量的倒数第二维,这是矩阵乘法规则的要求。 #### 输出特性 - 结果是一个三维张量,其形状由批处理大小以及前两者的其他维度决定。 #### 示例代码 下面展示了一个简单的例子来说明如何使用 `bmm` 进行批量矩阵运算: ```python import torch # 创建两个随机张量 (batch_size=3, dim1=2, dim2=4), (batch_size=3, dim1=4, dim2=5) tensor_a = torch.randn(3, 2, 4) tensor_b = torch.randn(3, 4, 5) # 使用 bmm 执行批量矩阵乘法 result = torch.bmm(tensor_a, tensor_b) print(result.shape) # 输出应该是 (3, 2, 5),即 batch size 保持不变 ``` 此代码片段展示了如何利用 `bmm` 对一批次中的每一对矩阵分别计算乘积,并最终得到一个新的批次矩阵集合[^2]。 #### 应用场景 `bmm` 常见于涉及多层感知机(MLP)、循环神经网络(RNNs),特别是 Seq2Seq 模型及其变体中使用的注意力机制部分。例如,在生成注意权重时,通常会先计算查询向量与键向量之间的相似度得分,再通过 `softmax` 转化成概率分布形式[^5]。 --- ### 注意力机制中的应用案例 假设我们有一个查询张量 Q 形状为 `[B, T_q, d_k]` ,还有一个键值对 K,V 张形均为 `[B,T_v,d_k]`,其中 B 表示批次大小;T_q 和 T_v 分别代表查询序列长度和值序列长度;d_k 则是每个词嵌入维度。我们可以这样实现基本点积注意力: ```python import torch.nn.functional as F def scaled_dot_product_attention(Q, K, V): """ 计算 Scaled Dot Product Attention. 参数: Q: 查询张量, shape [B, T_q, d_k]. K: 键张量, shape [B, T_v, d_k]. V: 值张量, shape [B, T_v, d_v]. 返回: context_vector: 上下文向量, shape [B, T_q, d_v]. """ # Step 1: Compute dot product between queries and keys scores = torch.bmm(Q, K.transpose(1, 2)) / (K.size(-1)**0.5) # Step 2: Apply softmax along last dimension to get attention weights attn_weights = F.softmax(scores, dim=-1) # Step 3: Multiply with values to compute final output context_vector = torch.bmm(attn_weights, V) return context_vector ``` 上述函数实现了标准的缩放点积注意力算法,其中第一步就是调用了 `torch.bmm()` 来完成查询Q与转置后的键K之间的大规模并行化的矩阵相乘操作[^4]。 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值