<think>我们正在讨论PyTorch中的torch.einsum函数。这是一个强大的函数,用于执行爱因斯坦求和约定(Einstein summation convention)。它允许我们以简洁的方式表示各种张量运算(如矩阵乘法、点积、外积、转置等)。用户要求了解其使用方法和示例。
首先,让我们回顾爱因斯坦求和约定:它通过省略求和符号(Σ)来简化表达式。例如,矩阵乘法$C_{ik} = \sum_j A_{ij} B_{jk}$可以简写为$A_{ij} B_{jk}$,并指定输出下标为$ik$。
在torch.einsum中,我们通过一个字符串表达式来指定操作。格式为:`'下标字符串, 张量1, 张量2, ... -> 输出下标字符串'`。其中,下标字符串由字母组成,表示每个维度的标签。如果省略输出下标字符串,则函数会输出求和后的标量(即所有元素的和)。
接下来,我将通过几个常见示例来说明torch.einsum的用法。
### 示例1:矩阵乘法
矩阵乘法$C = AB$,其中$A$是$m \times n$矩阵,$B$是$n \times p$矩阵,则$C_{ik} = \sum_{j} A_{ij} B_{jk}$。
在代码中,我们可以这样写:
```python
import torch
A = torch.randn(2, 3)
B = torch.randn(3, 4)
# 使用einsum进行矩阵乘法
C = torch.einsum('ij,jk->ik', A, B)
# 验证:与torch.matmul比较
assert torch.allclose(C, torch.matmul(A, B))
```
### 示例2:点积
两个向量的点积:$c = \sum_i a_i b_i$。
```python
a = torch.randn(5)
b = torch.randn(5)
c = torch.einsum('i,i->', a, b)
# 验证:与torch.dot比较
assert torch.allclose(c, torch.dot(a, b))
```
### 示例3:外积
两个向量的外积:$C_{ij} = a_i b_j$。
```python
a = torch.randn(5)
b = torch.randn(4)
C = torch.einsum('i,j->ij', a, b)
# 验证:外积的另一种实现
outer = a.unsqueeze(1) * b.unsqueeze(0)
assert torch.allclose(C, outer)
```
### 示例4:批量矩阵乘法
假设我们有一批矩阵,比如A的形状为(b, m, n),B的形状为(b, n, p),那么批量矩阵乘法可以表示为:$C_{bik} = \sum_j A_{bij} B_{bjk}$。
```python
batch_A = torch.randn(10, 2, 3)
batch_B = torch.randn(10, 3, 4)
batch_C = torch.einsum('bij,bjk->bik', batch_A, batch_B)
# 验证:使用torch.bmm
assert torch.allclose(batch_C, torch.bmm(batch_A, batch_B))
```
### 示例5:张量缩并(Tensor Contraction)
假设有一个四维张量A(形状为2x3x4x5)和一个三维张量B(形状为4x5x6),我们想对A的最后两个维度和B的前两个维度进行缩并,得到一个新的张量C(形状为2x3x6):
$C_{abk} = \sum_{ij} A_{abij} B_{ijk}$
```python
A = torch.randn(2, 3, 4, 5)
B = torch.randn(4, 5, 6)
C = torch.einsum('abij,ijk->abk', A, B)
# 另一种实现:将A重塑为(2*3, 4*5),B重塑为(4*5, 6),然后矩阵乘法,再重塑回(2,3,6)
A_reshaped = A.reshape(2*3, -1) # (6, 20)
B_reshaped = B.reshape(-1, 6) # (20, 6)
C_alt = torch.matmul(A_reshaped, B_reshaped).reshape(2,3,6)
assert torch.allclose(C, C_alt)
```
### 示例6:求和
我们可以使用einsum来求和。例如,对矩阵A的所有元素求和:$s = \sum_{ij} A_{ij}$
```python
A = torch.randn(3,4)
s = torch.einsum('ij->', A) # 等价于A.sum()
assert torch.allclose(s, A.sum())
```
### 示例7:按轴求和
对矩阵A的每一行求和(即沿列方向求和):$v_i = \sum_j A_{ij}$
```python
v = torch.einsum('ij->i', A) # 等价于A.sum(dim=1)
assert torch.allclose(v, A.sum(dim=1))
```
### 示例8:元素对应乘法(Hadamard积)再求和
两个相同形状的张量A和B,先逐元素相乘再求和(即内积):
$s = \sum_{ij} A_{ij} B_{ij}$
```python
A = torch.randn(3,4)
B = torch.randn(3,4)
s = torch.einsum('ij,ij->', A, B) # 等价于(A*B).sum()
assert torch.allclose(s, (A*B).sum())
```
### 注意事项:
1. 下标字符串中的字母可以是任意小写字母(或大写字母,但通常用小写),但每个维度都要用一个字母表示。
2. 在输入和输出中重复出现的下标表示该维度会被求和(即缩并),而只在输出中出现的下标则保留。
3. 我们可以使用省略号(...)来处理批处理维度或未知维度。例如,对于两个具有相同批处理维度的矩阵进行批量乘法,但不知道批处理维度的具体数量,可以这样:
```python
# A: (..., m, n), B: (..., n, p)
C = torch.einsum('...ij,...jk->...ik', A, B)
```
### 性能说明:
虽然`torch.einsum`非常灵活,但在某些情况下可能不如专用函数(如`torch.matmul`、`torch.dot`)高效。因此,在性能关键的代码中,如果存在专用函数,建议使用专用函数。但是,对于复杂的张量运算,`einsum`提供了一种清晰且简洁的表达方式。
### 总结
`torch.einsum`是一个强大的工具,通过爱因斯坦求和约定,我们可以用简洁的字符串表达式来描述复杂的张量运算。掌握它可以帮助我们更高效地编写张量操作代码。
§§相关问题§§
1. 如何用`torch.einsum`实现张量转置?
2. 在`torch.einsum`中,如何处理具有多个批处理维度的张量运算?
3. `torch.einsum`和`numpy.einsum`有什么区别?
4. 使用`torch.einsum`时,如何避免常见的错误(如下标不匹配)?</think>### PyTorch `torch.einsum()` 函数详解
`torch.einsum()` 是基于爱因斯坦求和约定(Einstein summation convention)的张量运算函数,它通过简洁的下标标记法实现复杂的张量操作。其核心语法为:
```
torch.einsum('下标表达式', 张量1, 张量2, ...)
```
其中:
- **下标表达式**:定义输入/输出维度的字母序列(如 `ij,jk->ik`)
- **箭头 `->` 左侧**:输入张量的维度标记
- **箭头右侧**:输出张量的维度标记
- **重复下标**:自动执行求和操作(缩并)
---
### 一、基础用法示例
#### 1. 矩阵乘法
```python
import torch
A = torch.tensor([[1, 2], [3, 4]]) # 2×2
B = torch.tensor([[5, 6], [7, 8]]) # 2×2
# 等价于 A @ B
C = torch.einsum('ij,jk->ik', A, B)
# 输出: tensor([[19, 22], [43, 50]])
```
- `ij,jk->ik` 表示:$C_{ik} = \sum_{j} A_{ij} \times B_{jk}$
#### 2. 向量点积
```python
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])
# 等价于 torch.dot(x, y)
dot = torch.einsum('i,i->', x, y)
# 输出: tensor(32)
```
- `i,i->` 表示:$\sum_{i} x_i \times y_i$
#### 3. 矩阵迹(对角线求和)
```python
M = torch.tensor([[1, 2], [3, 4]])
# 等价于 torch.trace(M)
trace = torch.einsum('ii->', M)
# 输出: tensor(5)
```
---
### 二、进阶用法示例
#### 1. 批量矩阵乘法
```python
batch_A = torch.randn(3, 2, 4) # 3个2×4矩阵
batch_B = torch.randn(3, 4, 5) # 3个4×5矩阵
# 等价于 torch.bmm(batch_A, batch_B)
result = torch.einsum('bij,bjk->bik', batch_A, batch_B) # 输出形状: (3,2,5)
```
- `b` 表示批处理维度,$C_{bik} = \sum_{j} A_{bij} \times B_{bjk}$
#### 2. 张量缩并
```python
T = torch.randn(2, 3, 4) # 三维张量
v = torch.randn(4) # 向量
# 对第三维缩并: T[i,j,k] * v[k]
output = torch.einsum('ijk,k->ij', T, v) # 输出形状: (2,3)
```
#### 3. 双线性变换
```python
A = torch.randn(5, 3)
B = torch.randn(5, 4)
x = torch.randn(3)
y = torch.randn(4)
# 计算 x^T A^T B y
result = torch.einsum('ik,jk,i,j->', A, B, x, y)
```
---
### 三、特殊操作
#### 1. 转置
```python
M = torch.tensor([[1, 2], [3, 4]])
# 等价于 M.T
MT = torch.einsum('ij->ji', M)
# 输出: tensor([[1,3], [2,4]])
```
#### 2. 求和特定维度
```python
X = torch.randn(2, 3, 4)
# 沿第二维求和
sum_dim1 = torch.einsum('ijk->ik', X) # 等价于 X.sum(dim=1)
```
#### 3. 外积
```python
a = torch.tensor([1, 2])
b = torch.tensor([3, 4, 5])
# 等价于 torch.outer(a, b)
outer = torch.einsum('i,j->ij', a, b)
# 输出: tensor([[3,4,5], [6,8,10]])
```
---
### 四、使用技巧
1. **省略号语法**:处理可变批次维度
```python
# 适用于任意批次大小的矩阵乘法
result = torch.einsum('...ij,...jk->...ik', batch_A, batch_B)
```
2. **性能提示**:
- 对于简单操作(如矩阵乘法),专用函数 `torch.matmul()` 更快
- 复杂张量操作时 `einsum` 更高效且代码更简洁
3. **错误避免**:
- 下标数量必须与张量维度匹配
- 输出下标需在输入中出现(除非求和)
- 相同下标字母的维度长度必须一致
> 引用说明:爱因斯坦求和约定可简化张量运算表达式[^1],PyTorch 的实现与 NumPy 的 `einsum` 高度兼容。