torch.einsum用法详解
概述
torch.einsum
是PyTorch中的一个函数,用于执行爱因斯坦求和约定(Einstein summation)运算。它提供了一种灵活而强大的方式来执行多维张量的操作和变换。
torch.einsum
的基本语法如下:
torch.einsum(equation, *operands)
其中,equation
是一个字符串,用于指定爱因斯坦求和约定的运算方式,operands
是一个或多个输入张量。
在equation
中,你可以使用大写字母表示张量的维度标识符,使用小写字母表示对应维度的长度。通过指定输入张量和输出张量之间的维度关系,你可以定义所需的运算操作。
下面是一个简单的例子,展示了如何使用torch.einsum
来执行一些基本操作:
import torch
# 两个向量的点积
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
result = torch.einsum('i,i->', a, b)
print(result) # 输出: 32
# 矩阵乘法
A = torch.tensor([[1, 2], [3, 4]])
B = torch.tensor([[5, 6], [7, 8]])
result = torch.einsum('ij,jk->ik', A, B)
print(result) # 输出: tensor([[19, 22], [43, 50]])
# 张量缩并
C = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
result = torch.einsum('ijk->jk', C)
print(result) # 输出: tensor([[6, 8], [10, 12]])
# 张量迹
D = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
result = torch.einsum('ii->', D)
print(result) # 输出: 15
在这些例子中,我们使用torch.einsum
函数执行了不同的运算,包括点积、矩阵乘法、张量缩并和张量迹。你可以通过调整equation
字符串来定义所需的运算操作。
请注意,torch.einsum
函数的性能可能会受到一些限制。在处理大型张量时,它可能不如其他专门优化的函数(如torch.matmul
)高效。因此,在选择使用torch.einsum
时,你需要权衡性能和灵活性之间的权衡。
希望这个例子能够帮助你理解如何使用torch.einsum
函数进行爱因斯坦求和约定运算。详细了解和熟悉equation
字符串的语法和用法,可以让你在处理多维张量时更具表达力和灵活性。
equation的写法
在torch.einsum
中,equation
字符串用于指定张量操作的具体方式。它由两部分组成:输入标记和输出标记,用箭头 “->” 分隔。表示维度的字符只能是26个英文字母 ‘a’ - ‘z’
输入标记描述了输入张量的维度和形状,输出标记描述了输出张量的维度和形状。每个标记都由一个或多个字母组成,用逗号分隔。
equation 中的字符也可以理解为索引,就是输出张量的某个位置的值,是怎么从输入张量中得到的,比如上面矩阵乘法的输出 c 的某个点 c[i, j] 的值是通过 a[i, k] 和 b[k, j] 沿着 k 这个维度做内积得到的。
1. 基本概念
接着介绍两个基本概念,自由索引(Free indices)和求和索引(Summation indices):
自由索引,出现在箭头右边的索引,比如上面的例子就是 i 和 j;
求和索引,只出现在箭头左边的索引,表示中间计算结果需要这个维度上求和之后才能得到输出,比如上面的例子就是 k;
2. 求和准则
- 规则一,equation 箭头左边,在不同输入之间重复出现的索引表示,把输入张量沿着该维度做乘法操作,比如还是以上面矩阵乘法为例, “ik,kj->ij”,k 在输入中重复出现,所以就是把 a 和 b 沿着 k 这个维度作相乘操作;
- 规则二,只出现在 equation 箭头左边的索引,表示中间计算结果需要在这个维度上求和,也就是上面提到的求和索引;
- 规则三,equation 箭头右边的索引顺序可以是任意的,比如上面的 “ik,kj->ij” 如果写成 “ik,kj->ji”,那么就是返回输出结果的转置,用户只需要定义好索引的顺序,转置操作会在 einsum 内部完成。
特殊规则有两条:
- equation 可以不写包括箭头在内的右边部分,那么在这种情况下,输出张量的维度会根据默认规则推导。就是把输入中只出现一次的索引取出来,然后按字母表顺序排列,比如上面的矩阵乘法 “ik,kj->ij” 也可以简化为 “ik,kj”,根据默认规则,输出就是 “ij” 与原来一样;
- equation 中支持 “…” 省略号,用于表示用户并不关心的索引,比如只对一个高维张量的最后两维做转置可以这么写:
a = torch.randn(2,3,5,7,9)
# i = 7, j = 9
b = torch.einsum('...ij->...ji', [a])
3. 示例
下面是一些示例,说明了equation
字符串的常见用法:
- 矩阵乘法:
对于两个矩阵相乘的操作,可以使用equation
字符串 'ij,jk->ik'
。其中 'ij'
表示第一个输入张量的形状为 (i, j)
,'jk'
表示第二个输入张量的形状为 (j, k)
,'ik'
表示输出张量的形状为 (i, k)
。
- 张量缩减和求和:
要对某些维度进行缩减和求和,可以使用equation
字符串 'ijk->'
或 'ijk->i'
。在这里,'ijk'
表示输入张量的形状为 (i, j, k)
,''
表示输出张量为空(即缩减所有维度),而 'i'
表示输出张量的形状为 (i,)
(即只在第一维度上进行求和)。
- 元素级乘法和广播:
要对两个张量进行元素级乘法并进行广播,可以使用equation
字符串 'ij,j->ij'
。在这里,'ij'
表示第一个输入张量的形状为 (i, j)
,'j'
表示第二个输入张量的形状为 (j,)
,'ij'
表示输出张量的形状与第一个输入张量相同。
在equation
字符串中,字母通常用于标识不同的维度。例如,'i'
表示第一维度,'j'
表示第二维度,以此类推。你可以根据需要使用任何合法的字母来标识维度。
重要的是要确保equation
字符串正确地表示所需的张量操作。可以通过检查输入和输出张量的形状来验证结果是否符合预期。
希望这个解释能够帮助你更好地理解equation
字符串的写法和字母的含义。
进阶
当涉及到更复杂的张量操作时,torch.einsum
可以提供更大的灵活性。下面是一些更多的示例,展示了torch.einsum
的一些高级用法:
- 广播和元素级乘法:
import torch
A = torch.tensor([[1, 2], [3, 4]])
B = torch.tensor([5, 6])
result = torch.einsum('ij,j->ij', A, B)
print(result) # 输出: tensor([[ 5, 12], [15, 24]])
在这个例子中,我们使用torch.einsum
进行广播和元素级乘法。通过指定equation
为'ij,j->ij'
,我们将张量A
的每一行与向量B
进行乘法操作,并生成一个相同形状的输出张量。
- 张量的转置和重排:
import torch
A = torch.tensor([[1, 2], [3, 4]])
result = torch.einsum('ij->ji', A)
print(result) # 输出: tensor([[1, 3], [2, 4]])
B = torch.tensor([[1, 2, 3], [4, 5, 6]])
result = torch.einsum('ij->jik', B)
print(result) # 输出: tensor([[[1, 2, 3]], [[4, 5, 6]]])
在这个例子中,我们使用torch.einsum
来执行张量的转置和重排操作。通过调整equation
字符串,我们可以指定所需的维度顺序和形状。
- 批量矩阵乘法:
import torch
A = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
B = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
result = torch.einsum('bij,bjk->bik', A, B)
print(result) # 输出: tensor([[[ 7, 10], [15, 22]], [[47, 58], [67, 82]]])
在这个例子中,我们使用torch.einsum
执行批量矩阵乘法操作。通过在equation
中引入一个额外的维度标识符b
,我们可以指定对每个批次中的矩阵执行矩阵乘法。
这些示例只是展示了torch.einsum
函数的一些用法。实际上,你可以根据需要创建更复杂的equation
字符串,以执行各种张量操作,例如卷积、张量拼接、张量分解等。
关于torch.einsum
的完整语法和用法,请参阅PyTorch官方文档的相关部分:torch.einsum。
希望这些例子能帮助你更深入地理解和应用torch.einsum
函数。它是一个非常强大的工具,可以帮助你在PyTorch中进行高级的张量操作。
当涉及更复杂的张量操作时,torch.einsum
还有一些其他的高级用法。以下是更多示例,展示了torch.einsum
的一些功能:
- 维度缩减和求和:
import torch
A = torch.tensor([[1, 2], [3, 4]])
result = torch.einsum('ij->', A)
print(result) # 输出: 10
B = torch.tensor([[1, 2], [3, 4]])
result = torch.einsum('ij->i', B)
print(result) # 输出: tensor([3, 7])
在这个例子中,我们使用torch.einsum
来执行维度缩减和求和操作。通过指定equation
字符串中的维度标识符,我们可以选择性地缩减维度并求和。
- 张量拼接和分割:
import torch
A = torch.tensor([[1, 2], [3, 4]])
B = torch.tensor([[5, 6]])
result = torch.einsum('ij,ik->ijk', A, B)
print(result) # 输出: tensor([[[ 5, 6], [10, 12]], [[15, 18], [20, 24]]])
C = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
result = torch.einsum('ijk->ik', C)
print(result) # 输出: tensor([[ 4, 6], [12, 14]])
在这个例子中,我们使用torch.einsum
来执行张量的拼接和分割操作。通过在equation
字符串中调整维度标识符,我们可以指定所需的拼接形状或分割方式。
- 张量对齐和约简:
import torch
A = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
B = torch.tensor([[1, 2], [3, 4]])
result = torch.einsum('ijk,kl->ijl', A, B)
print(result) # 输出: tensor([[[ 7, 10], [15, 22]], [[23, 34], [31, 46]]])
在这个例子中,我们使用torch.einsum
来执行张量的对齐和约简操作。通过在equation
字符串中指定相同的维度标识符,我们可以对张量进行逐元素相乘并相加以实现约简。
这些示例展示了torch.einsum
函数的更多用法。你可以根据需要创建更复杂的equation
字符串,以执行各种张量操作,如张量合并、张量分解、张量乘法和其他高级操作。
请注意,torch.einsum
的灵活性和强大性意味着你需要仔细设计和验证equation
字符串,以确保所得到的结果是你期望的。
希望这些例子能够进一步帮助你理解和应用torch.einsum
函数。它是一个非常有用的工具,可以在PyTorch中执行高级的张量操作。
更多阅读
https://pytorch.org/docs/stable/generated/torch.einsum.html
https://zhuanlan.zhihu.com/p/361209187