Thoroughly Understanding Matrix Multiplication in PyTorch: matmul vs. Dot Product
The Mathematical Distinction
Dot Product
- Definition: defined only for two vectors of the same length
a = [a1, a2, …, an]
b = [b1, b2, …, bn]
dot_product = a·b = Σ a_i * b_i    (the result is a scalar)
- Geometric meaning: the length of one vector times the projected length of the other onto it (equivalently, |a||b|cosθ)
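To make the definition concrete, here is a minimal sketch (the vector values are illustrative, not from the original text) showing that torch.dot matches the elementwise multiply-and-sum form of Σ a_i * b_i:

import torch

a = torch.tensor([1., 2., 3.])
b = torch.tensor([4., 5., 6.])

# Σ a_i * b_i, written out as an elementwise product followed by a sum
manual = (a * b).sum()  # 1*4 + 2*5 + 3*6 = 32.0
print(torch.allclose(manual, torch.dot(a, b)))  # → True

# The geometric reading: a·b = |a|·|b|·cosθ, so the cosine of the angle
# between the two vectors can be recovered from the dot product and the norms
cos_theta = torch.dot(a, b) / (torch.linalg.norm(a) * torch.linalg.norm(b))
print(cos_theta)  # a value in [-1, 1]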
Matrix Multiplication
- Definition: a linear-combination operation between two matrices
A ∈ R^(m×n), B ∈ R^(n×p)
C = AB ∈ R^(m×p)
C[i,j] = Σ A[i,k]·B[k,j]    (k from 1 to n)
- Core requirement: the number of columns of A must equal the number of rows of B
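The entry formula C[i,j] = Σ A[i,k]·B[k,j] maps directly onto a triple loop. Below is a minimal sketch (the helper name naive_matmul and the example shapes are mine, for illustration only), checked against torch.matmul:

import torch

def naive_matmul(A, B):
    # C[i, j] = Σ_k A[i, k] * B[k, j], written as explicit loops
    m, n = A.shape
    n2, p = B.shape
    assert n == n2, "columns of A must equal rows of B"
    C = torch.zeros(m, p)
    for i in range(m):
        for j in range(p):
            for k in range(n):
                C[i, j] += A[i, k] * B[k, j]
    return C

A = torch.randn(2, 3)
B = torch.randn(3, 4)
print(torch.allclose(naive_matmul(A, B), torch.matmul(A, B)))  # → True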
PyTorch Implementation Comparison
Scenario 1: Vector Operations
import torch

# Create two vectors
v1 = torch.tensor([1., 2., 3.])  # shape [3]
v2 = torch.tensor([4., 5., 6.])  # shape [3]
# Dot product (strictly 1-D inputs)
print(torch.dot(v1, v2))  # → 1*4 + 2*5 + 3*6 = 32.0
# matmul on two vectors (automatically treated as a dot product)
print(torch.matmul(v1, v2))  # → 32.0
print(v1 @ v2)  # → 32.0 (@ is an alias for matmul)
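Note that torch.dot is strict: both inputs must be 1-D and of equal length, and no broadcasting is applied. A small sketch (the mismatched vector v3 is my own example):

v3 = torch.tensor([4., 5.])  # shorter than v1
try:
    torch.dot(v1, v3)
except RuntimeError as e:
    print(e)  # the two 1-D tensors must contain the same number of elements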
Scenario 2: Matrix Operations
M1 = torch.tensor([[1., 2.], [3., 4.]])  # 2x2
M2 = torch.tensor([[5., 6.], [7., 8.]])  # 2x2
# Standard matrix multiplication
print(torch.matmul(M1, M2))
# [[1*5+2*7=19, 1*6+2*8=22],
#  [3*5+4*7=43, 3*6+4*8=50]]
# Incorrect attempt at a dot product
try:
    print(torch.dot(M1, M2))
except RuntimeError as e:
    print(e)  # error: only 1-D inputs are accepted
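For strictly 2-D inputs, torch.mm is a narrower alternative to torch.matmul: it accepts only matrices and does not broadcast, so unexpected shapes fail immediately. A minimal sketch reusing M1 and M2 from above:

# torch.mm: 2-D matrices only, no broadcasting
print(torch.mm(M1, M2))  # same result as torch.matmul(M1, M2)
try:
    torch.mm(torch.randn(5, 2, 2), torch.randn(5, 2, 2))  # 3-D (batched) input
except RuntimeError as e:
    print(e)  # mm expects 2-D tensors; use matmul (or bmm) for batched inputs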
Scenario 3: Mixed-Dimension Operations
M = torch.randn(3, 4)  # 3x4 matrix
v = torch.randn(4)     # 4-element vector
# Matrix × vector → vector
result = torch.matmul(M, v)  # result shape [3]
print(result.shape)
# Attempting a dot product raises an error
try:
    print(torch.dot(M, v))
except RuntimeError as e:
    print(e)  # both inputs must be 1-D
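If you want the matrix-vector intent to be explicit, torch.mv covers exactly this 2-D × 1-D case. A brief sketch reusing M and v from above:

# torch.mv: matrix-vector product, equivalent to matmul for a 2-D × 1-D pair
print(torch.allclose(torch.mv(M, v), torch.matmul(M, v)))  # → True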
Differences Under Broadcasting
Case Study: 3-D Tensor Operations
batch_M = torch.randn(5, 3, 4)  # 5 matrices of shape 3x4
batch_v = torch.randn(5, 4)     # 5 vectors with 4 elements each
# matmul handles the leading batch dimension automatically
result = torch.matmul(batch_M, batch_v.unsqueeze(-1))
print(result.shape)  # → [5, 3, 1]
# Equivalent manual per-batch computation
manual_result = []
for i in range(5):
    manual_result.append(batch_M[i] @ batch_v[i].unsqueeze(1))
print(torch.stack(manual_result).shape)  # → [5, 3, 1]
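When both operands already carry the same explicit batch dimension, torch.bmm is the batched counterpart of this pattern. A short sketch reusing batch_M, batch_v, and result from above, with squeeze(-1) to drop the trailing singleton dimension:

# torch.bmm: batched matrix multiply; both inputs must be 3-D with matching batch size
bmm_result = torch.bmm(batch_M, batch_v.unsqueeze(-1))  # shape [5, 3, 1]
print(torch.allclose(bmm_result, result))  # → True
# Remove the trailing dimension of size 1 to obtain one 3-element vector per batch item
print(result.squeeze(-1).shape)  # → [5, 3]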