torch.matmul()
函数详解
torch.matmul()
是 PyTorch 中用于执行 矩阵乘法(矩阵点积) 的函数,支持 1D、2D、3D 及更高维度张量的广义矩阵乘法,是深度学习中非常常用的线性代数运算。
1. 基本语法
torch.matmul(tensor1, tensor2) → Tensor
参数 | 说明 |
---|---|
tensor1 , tensor2 | 要相乘的两个张量(维度可以是 1D、2D 或更高) |
返回值 | 两个张量的矩阵乘积,遵循广播规则 |
2. 不同维度输入的行为
2.1 两个 1D 张量(向量点积)
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
result = torch.matmul(a, b)
print(result) # 输出: tensor(32)
计算的是:
1
×
4
+
2
×
5
+
3
×
6
=
32
1 \times 4 + 2 \times 5 + 3 \times 6 = 32
1×4+2×5+3×6=32
2.2 1D × 2D 或 2D × 1D(向量和矩阵)
# 1D × 2D,向量乘矩阵
a = torch.tensor([1.0, 2.0]) # shape: (2,)
b = torch.tensor([[3.0, 4.0],
[5.0, 6.0]]) # shape: (2, 2)
print(torch.matmul(a, b)) # shape: (2,)
# 输出: tensor([13., 16.])
[ 1 , 2 ] × [ 3 4 5 6 ] = [ 13 , 16 ] [1, 2] \times \begin{bmatrix}3 & 4\\5 & 6\end{bmatrix} = [13, 16] [1,2]×[3546]=[13,16]
# 2D × 1D,矩阵乘向量
a = torch.tensor([[3.0, 4.0],
[5.0, 6.0]]) # shape: (2, 2)
b = torch.tensor([1.0, 2.0]) # shape: (2,)
print(torch.matmul(a, b)) # shape: (2,)
# 输出: tensor([11., 17.])
2.3 两个 2D 张量(标准矩阵乘法)
a = torch.tensor([[1, 2], [3, 4]]) # (2, 2)
b = torch.tensor([[5, 6], [7, 8]]) # (2, 2)
print(torch.matmul(a, b))
# 输出:
# tensor([[19, 22],
# [43, 50]])
2.4 批量矩阵乘法(3D 或更高)
a = torch.randn(10, 3, 4) # 表示 10 个 (3×4) 的矩阵
b = torch.randn(10, 4, 5) # 表示 10 个 (4×5) 的矩阵
result = torch.matmul(a, b) # 输出: (10, 3, 5)
广播规则:
对于更高维度的张量,matmul
会按照批次维度进行 广播并执行批矩阵乘法。
3. 与 @
运算符等价
a = torch.randn(2, 3)
b = torch.randn(3, 4)
# 等价写法
result1 = torch.matmul(a, b)
result2 = a @ b
4. 与 torch.mm()
、torch.bmm()
区别
函数 | 适用维度 | 描述 |
---|---|---|
torch.matmul() | 任意维度 | 广义矩阵乘法,推荐使用 |
torch.mm() | 仅 2D × 2D | 标准矩阵乘法 |
torch.bmm() | 仅 3D × 3D | 批矩阵乘法(batch matrix multiplication) |
5. 错误示例
a = torch.randn(3, 4)
b = torch.randn(2, 3)
torch.matmul(a, b) # ❌ 尺寸不匹配 (4 ≠ 2)
注意:
- 要保证
matmul(a, b)
中a
的最后一维 ==b
的倒数第二维。
6. 常见用途
- 实现线性层(如:
y = x @ W.T + b
) - 注意力机制中的 Q·Kᵀ 操作
- Transformer 中的矩阵操作
- 图神经网络中的邻接矩阵乘法
- batch 处理中的多样本矩阵运算
7. 总结
特性 | 说明 |
---|---|
作用 | 执行广义矩阵乘法 |
支持维度 | 1D、2D、3D、甚至更高维 |
支持广播 | 是 |
常用替代 | @ 运算符 |
推荐用途 | 适用于所有矩阵乘法场景 |