使用torch普通算子组合替代torch.einsum爱因斯坦求和

1. torch.einsum(‘bnd, bmd->bnm’, x, y)

torch.einsum(‘bnd, bmd->bnm’, x, y) 表示的是对张量 x 和 y 进行特定的求和和维度变换。

具体来说,这个操作的输入是两个形状为 [b, n, d] 和 [b, m, d] 的张量 x 和 y,输出是一个形状为 [b, n, m] 的张量 z。其计算过程可以理解为:对于每个 b,z[b, n, m] 等于 x[b, n, :] 和 y[b, m, :] 之间的点积。

为了用普通的 torch 操作符来替代 einsum,我们可以通过 torch.matmul 函数实现。这个函数可以用来执行批量矩阵乘法,并且能够很好地替代这个 einsum 操作。

具体实现如下:

import torch

# 假设 x 和 y 的形状分别为 (b, n, d) 和 (b, m, d)
x = torch.randn(10, 20, 30)  # 举例
y = torch.randn(10, 15, 30)  # 举例

# einsum: z = torch.einsum('bnd, bmd->bnm', x, y)
# 可以转换为以下操作:
z = torch.matmul(x, y.transpose(-1, -2))  # z 的形状为 (b, n, m)

# 检查 z 的形状是否正确
print(z.shape)

2. torch.einsum(‘ij,jk->ik’, A, B)

可以用普通的矩阵乘法 torch.matmul 替代

具体实现如下:

import torch

A = torch.rand(3, 4)
B = torch.rand(4, 5)

# 使用 einsum
result_einsum = torch.einsum('ij,jk->ik', A, B)

# 使用 matmul
result_matmul = torch.matmul(A, B)

# 验证结果相同
print(torch.allclose(result_einsum, result_matmul))

3. torch.einsum(‘bij,bjk->bik’, A, B)

可以用 torch.bmm 来替代

具体实现如下:

import torch

A = torch.rand(10, 3, 4)
B = torch.rand(10, 4, 5)

# 使用 einsum
result_einsum = torch.einsum('bij,bjk->bik', A, B)

# 使用 bmm
result_bmm = torch.bmm(A, B)

# 验证结果相同
print(torch.allclose(result_einsum, result_bmm))

4. torch.einsum(‘i,i->’, A, B)

向量内积,可以用 torch.dot 来替代

具体实现如下:

import torch

A = torch.rand(4)
B = torch.rand(4)

# 使用 einsum
result_einsum = torch.einsum('i,i->', A, B)

# 使用 dot
result_dot = torch.dot(A, B)

# 验证结果相同
print(torch.allclose(result_einsum, result_dot))

5. torch.einsum(‘i,j->ij’, A, B)

向量外积,可以用 torch.outer 来替代

具体实现如下:

import torch

A = torch.rand(4)
B = torch.rand(5)

# 使用 einsum
result_einsum = torch.einsum('i,j->ij', A, B)

# 使用 outer
result_outer = torch.outer(A, B)

# 验证结果相同
print(torch.allclose(result_einsum, result_outer))

不同的 einsum 表达式会对应不同的替代操作,有时可能需要组合多个普通操作来达到相同的效果。如果某些 einsum 表达式太复杂,使用普通算子替代时会比较繁琐,此时建议继续使用 einsum,因为它不仅更简洁,而且通常性能优化得很好。

后续遇到其余需替换的 op 再进行更新

<think>我们正在讨论PyTorch中的torch.einsum函数。这是一个强大的函数,用于执行爱因斯坦求和约定(Einstein summation convention)。它允许我们以简洁的方式表示各种张量运算(如矩阵乘法、点积、外积、转置等)。用户要求了解其使用方法和示例。 首先,让我们回顾爱因斯坦求和约定:它通过省略求和符号(Σ)来简化表达式。例如,矩阵乘法$C_{ik} = \sum_j A_{ij} B_{jk}$可以简写为$A_{ij} B_{jk}$,并指定输出下标为$ik$。 在torch.einsum中,我们通过一个字符串表达式来指定操作。格式为:`'下标字符串, 张量1, 张量2, ... -> 输出下标字符串'`。其中,下标字符串由字母组成,表示每个维度的标签。如果省略输出下标字符串,则函数会输出求和后的标量(即所有元素的和)。 接下来,我将通过几个常见示例来说明torch.einsum的用法。 ### 示例1:矩阵乘法 矩阵乘法$C = AB$,其中$A$是$m \times n$矩阵,$B$是$n \times p$矩阵,则$C_{ik} = \sum_{j} A_{ij} B_{jk}$。 在代码中,我们可以这样写: ```python import torch A = torch.randn(2, 3) B = torch.randn(3, 4) # 使用einsum进行矩阵乘法 C = torch.einsum('ij,jk->ik', A, B) # 验证:与torch.matmul比较 assert torch.allclose(C, torch.matmul(A, B)) ``` ### 示例2:点积 两个向量的点积:$c = \sum_i a_i b_i$。 ```python a = torch.randn(5) b = torch.randn(5) c = torch.einsum('i,i->', a, b) # 验证:与torch.dot比较 assert torch.allclose(c, torch.dot(a, b)) ``` ### 示例3:外积 两个向量的外积:$C_{ij} = a_i b_j$。 ```python a = torch.randn(5) b = torch.randn(4) C = torch.einsum('i,j->ij', a, b) # 验证:外积的另一种实现 outer = a.unsqueeze(1) * b.unsqueeze(0) assert torch.allclose(C, outer) ``` ### 示例4:批量矩阵乘法 假设我们有一批矩阵,比如A的形状为(b, m, n),B的形状为(b, n, p),那么批量矩阵乘法可以表示为:$C_{bik} = \sum_j A_{bij} B_{bjk}$。 ```python batch_A = torch.randn(10, 2, 3) batch_B = torch.randn(10, 3, 4) batch_C = torch.einsum('bij,bjk->bik', batch_A, batch_B) # 验证:使用torch.bmm assert torch.allclose(batch_C, torch.bmm(batch_A, batch_B)) ``` ### 示例5:张量缩并(Tensor Contraction) 假设有一个四维张量A(形状为2x3x4x5)和一个三维张量B(形状为4x5x6),我们想对A的最后两个维度和B的前两个维度进行缩并,得到一个新的张量C(形状为2x3x6): $C_{abk} = \sum_{ij} A_{abij} B_{ijk}$ ```python A = torch.randn(2, 3, 4, 5) B = torch.randn(4, 5, 6) C = torch.einsum('abij,ijk->abk', A, B) # 另一种实现:将A重塑为(2*3, 4*5),B重塑为(4*5, 6),然后矩阵乘法,再重塑回(2,3,6) A_reshaped = A.reshape(2*3, -1) # (6, 20) B_reshaped = B.reshape(-1, 6) # (20, 6) C_alt = torch.matmul(A_reshaped, B_reshaped).reshape(2,3,6) assert torch.allclose(C, C_alt) ``` ### 示例6:求和 我们可以使用einsum求和。例如,对矩阵A的所有元素求和:$s = \sum_{ij} A_{ij}$ ```python A = torch.randn(3,4) s = torch.einsum('ij->', A) # 等价于A.sum() assert torch.allclose(s, A.sum()) ``` ### 示例7:按轴求和 对矩阵A的每一行求和(即沿列方向求和):$v_i = \sum_j A_{ij}$ ```python v = torch.einsum('ij->i', A) # 等价于A.sum(dim=1) assert torch.allclose(v, A.sum(dim=1)) ``` ### 示例8:元素对应乘法(Hadamard积)再求和 两个相同形状的张量A和B,先逐元素相乘再求和(即内积): $s = \sum_{ij} A_{ij} B_{ij}$ ```python A = torch.randn(3,4) B = torch.randn(3,4) s = torch.einsum('ij,ij->', A, B) # 等价于(A*B).sum() assert torch.allclose(s, (A*B).sum()) ``` ### 注意事项: 1. 下标字符串中的字母可以是任意小写字母(或大写字母,但通常用小写),但每个维度都要用一个字母表示。 2. 在输入和输出中重复出现的下标表示该维度会被求和(即缩并),而只在输出中出现的下标则保留。 3. 我们可以使用省略号(...)来处理批处理维度或未知维度。例如,对于两个具有相同批处理维度的矩阵进行批量乘法,但不知道批处理维度的具体数量,可以这样: ```python # A: (..., m, n), B: (..., n, p) C = torch.einsum('...ij,...jk->...ik', A, B) ``` ### 性能说明: 虽然`torch.einsum`非常灵活,但在某些情况下可能不如专用函数(如`torch.matmul`、`torch.dot`)高效。因此,在性能关键的代码中,如果存在专用函数,建议使用专用函数。但是,对于复杂的张量运算,`einsum`提供了一种清晰且简洁的表达方式。 ### 总结 `torch.einsum`是一个强大的工具,通过爱因斯坦求和约定,我们可以用简洁的字符串表达式来描述复杂的张量运算。掌握它可以帮助我们更高效地编写张量操作代码。 §§相关问题§§ 1. 如何用`torch.einsum`实现张量转置? 2. 在`torch.einsum`中,如何处理具有多个批处理维度的张量运算? 3. `torch.einsum`和`numpy.einsum`有什么区别? 4. 使用`torch.einsum`时,如何避免常见的错误(如下标不匹配)?</think>### PyTorch `torch.einsum()` 函数详解 `torch.einsum()` 是基于爱因斯坦求和约定(Einstein summation convention)的张量运算函数,它通过简洁的下标标记法实现复杂的张量操作。其核心语法为: ``` torch.einsum('下标表达式', 张量1, 张量2, ...) ``` 其中: - **下标表达式**:定义输入/输出维度的字母序列(如 `ij,jk->ik`) - **箭头 `->` 左侧**:输入张量的维度标记 - **箭头右侧**:输出张量的维度标记 - **重复下标**:自动执行求和操作(缩并) --- ### 一、基础用法示例 #### 1. 矩阵乘法 ```python import torch A = torch.tensor([[1, 2], [3, 4]]) # 2×2 B = torch.tensor([[5, 6], [7, 8]]) # 2×2 # 等价于 A @ B C = torch.einsum('ij,jk->ik', A, B) # 输出: tensor([[19, 22], [43, 50]]) ``` - `ij,jk->ik` 表示:$C_{ik} = \sum_{j} A_{ij} \times B_{jk}$ #### 2. 向量点积 ```python x = torch.tensor([1, 2, 3]) y = torch.tensor([4, 5, 6]) # 等价于 torch.dot(x, y) dot = torch.einsum('i,i->', x, y) # 输出: tensor(32) ``` - `i,i->` 表示:$\sum_{i} x_i \times y_i$ #### 3. 矩阵迹(对角线求和) ```python M = torch.tensor([[1, 2], [3, 4]]) # 等价于 torch.trace(M) trace = torch.einsum('ii->', M) # 输出: tensor(5) ``` --- ### 二、进阶用法示例 #### 1. 批量矩阵乘法 ```python batch_A = torch.randn(3, 2, 4) # 3个2×4矩阵 batch_B = torch.randn(3, 4, 5) # 3个4×5矩阵 # 等价于 torch.bmm(batch_A, batch_B) result = torch.einsum('bij,bjk->bik', batch_A, batch_B) # 输出形状: (3,2,5) ``` - `b` 表示批处理维度,$C_{bik} = \sum_{j} A_{bij} \times B_{bjk}$ #### 2. 张量缩并 ```python T = torch.randn(2, 3, 4) # 三维张量 v = torch.randn(4) # 向量 # 对第三维缩并: T[i,j,k] * v[k] output = torch.einsum('ijk,k->ij', T, v) # 输出形状: (2,3) ``` #### 3. 双线性变换 ```python A = torch.randn(5, 3) B = torch.randn(5, 4) x = torch.randn(3) y = torch.randn(4) # 计算 x^T A^T B y result = torch.einsum('ik,jk,i,j->', A, B, x, y) ``` --- ### 三、特殊操作 #### 1. 转置 ```python M = torch.tensor([[1, 2], [3, 4]]) # 等价于 M.T MT = torch.einsum('ij->ji', M) # 输出: tensor([[1,3], [2,4]]) ``` #### 2. 求和特定维度 ```python X = torch.randn(2, 3, 4) # 沿第二维求和 sum_dim1 = torch.einsum('ijk->ik', X) # 等价于 X.sum(dim=1) ``` #### 3. 外积 ```python a = torch.tensor([1, 2]) b = torch.tensor([3, 4, 5]) # 等价于 torch.outer(a, b) outer = torch.einsum('i,j->ij', a, b) # 输出: tensor([[3,4,5], [6,8,10]]) ``` --- ### 四、使用技巧 1. **省略号语法**:处理可变批次维度 ```python # 适用于任意批次大小的矩阵乘法 result = torch.einsum('...ij,...jk->...ik', batch_A, batch_B) ``` 2. **性能提示**: - 对于简单操作(如矩阵乘法),专用函数 `torch.matmul()` 更快 - 复杂张量操作时 `einsum` 更高效且代码更简洁 3. **错误避免**: - 下标数量必须与张量维度匹配 - 输出下标需在输入中出现(除非求和- 相同下标字母的维度长度必须一致 > 引用说明:爱因斯坦求和约定可简化张量运算表达式[^1],PyTorch 的实现与 NumPy 的 `einsum` 高度兼容。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

libo-coder

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值