torch.einsum用法详解

本文详细介绍了PyTorch库中的torch.einsum函数,包括其基本概念、equation字符串的写法以及各种张量操作示例,如点积、矩阵乘法、张量缩并和求和。它强调了如何通过灵活的索引约定进行高级张量运算。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

torch.einsum用法详解

概述

torch.einsum是PyTorch中的一个函数,用于执行爱因斯坦求和约定(Einstein summation)运算。它提供了一种灵活而强大的方式来执行多维张量的操作和变换。

torch.einsum的基本语法如下:

torch.einsum(equation, *operands)

其中,equation是一个字符串,用于指定爱因斯坦求和约定的运算方式,operands是一个或多个输入张量。

equation中,你可以使用大写字母表示张量的维度标识符,使用小写字母表示对应维度的长度。通过指定输入张量和输出张量之间的维度关系,你可以定义所需的运算操作。

下面是一个简单的例子,展示了如何使用torch.einsum来执行一些基本操作:

import torch

# 两个向量的点积
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
result = torch.einsum('i,i->', a, b)
print(result)  # 输出: 32

# 矩阵乘法
A = torch.tensor([[1, 2], [3, 4]])
B = torch.tensor([[5, 6], [7, 8]])
result = torch.einsum('ij,jk->ik', A, B)
print(result)  # 输出: tensor([[19, 22], [43, 50]])

# 张量缩并
C = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
result = torch.einsum('ijk->jk', C)
print(result)  # 输出: tensor([[6, 8], [10, 12]])

# 张量迹
D = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
result = torch.einsum('ii->', D)
print(result)  # 输出: 15

在这些例子中,我们使用torch.einsum函数执行了不同的运算,包括点积、矩阵乘法、张量缩并和张量迹。你可以通过调整equation字符串来定义所需的运算操作。

请注意,torch.einsum函数的性能可能会受到一些限制。在处理大型张量时,它可能不如其他专门优化的函数(如torch.matmul)高效。因此,在选择使用torch.einsum时,你需要权衡性能和灵活性之间的权衡。

希望这个例子能够帮助你理解如何使用torch.einsum函数进行爱因斯坦求和约定运算。详细了解和熟悉equation字符串的语法和用法,可以让你在处理多维张量时更具表达力和灵活性。

equation的写法

torch.einsum中,equation字符串用于指定张量操作的具体方式。它由两部分组成:输入标记和输出标记,用箭头 “->” 分隔。表示维度的字符只能是26个英文字母 ‘a’ - ‘z’

输入标记描述了输入张量的维度和形状,输出标记描述了输出张量的维度和形状。每个标记都由一个或多个字母组成,用逗号分隔。

equation 中的字符也可以理解为索引,就是输出张量的某个位置的值,是怎么从输入张量中得到的,比如上面矩阵乘法的输出 c 的某个点 c[i, j] 的值是通过 a[i, k] 和 b[k, j] 沿着 k 这个维度做内积得到的。

1. 基本概念

接着介绍两个基本概念,自由索引(Free indices)和求和索引(Summation indices):

自由索引,出现在箭头右边的索引,比如上面的例子就是 i 和 j;
求和索引,只出现在箭头左边的索引,表示中间计算结果需要这个维度上求和之后才能得到输出,比如上面的例子就是 k;

2. 求和准则

  • 规则一,equation 箭头左边,在不同输入之间重复出现的索引表示,把输入张量沿着该维度做乘法操作,比如还是以上面矩阵乘法为例, “ik,kj->ij”,k 在输入中重复出现,所以就是把 a 和 b 沿着 k 这个维度作相乘操作;
  • 规则二,只出现在 equation 箭头左边的索引,表示中间计算结果需要在这个维度上求和,也就是上面提到的求和索引;
  • 规则三,equation 箭头右边的索引顺序可以是任意的,比如上面的 “ik,kj->ij” 如果写成 “ik,kj->ji”,那么就是返回输出结果的转置,用户只需要定义好索引的顺序,转置操作会在 einsum 内部完成。

特殊规则有两条:

  • equation 可以不写包括箭头在内的右边部分,那么在这种情况下,输出张量的维度会根据默认规则推导。就是把输入中只出现一次的索引取出来,然后按字母表顺序排列,比如上面的矩阵乘法 “ik,kj->ij” 也可以简化为 “ik,kj”,根据默认规则,输出就是 “ij” 与原来一样;
  • equation 中支持 “…” 省略号,用于表示用户并不关心的索引,比如只对一个高维张量的最后两维做转置可以这么写:
a = torch.randn(2,3,5,7,9)
# i = 7, j = 9
b = torch.einsum('...ij->...ji', [a])

3. 示例

下面是一些示例,说明了equation字符串的常见用法:

  1. 矩阵乘法:

对于两个矩阵相乘的操作,可以使用equation字符串 'ij,jk->ik'。其中 'ij' 表示第一个输入张量的形状为 (i, j)'jk' 表示第二个输入张量的形状为 (j, k)'ik' 表示输出张量的形状为 (i, k)

  1. 张量缩减和求和:

要对某些维度进行缩减和求和,可以使用equation字符串 'ijk->''ijk->i'。在这里,'ijk' 表示输入张量的形状为 (i, j, k)'' 表示输出张量为空(即缩减所有维度),而 'i' 表示输出张量的形状为 (i,)(即只在第一维度上进行求和)。

  1. 元素级乘法和广播:

要对两个张量进行元素级乘法并进行广播,可以使用equation字符串 'ij,j->ij'。在这里,'ij' 表示第一个输入张量的形状为 (i, j)'j' 表示第二个输入张量的形状为 (j,)'ij' 表示输出张量的形状与第一个输入张量相同。

equation字符串中,字母通常用于标识不同的维度。例如,'i' 表示第一维度,'j' 表示第二维度,以此类推。你可以根据需要使用任何合法的字母来标识维度。

重要的是要确保equation字符串正确地表示所需的张量操作。可以通过检查输入和输出张量的形状来验证结果是否符合预期。

希望这个解释能够帮助你更好地理解equation字符串的写法和字母的含义。

进阶

当涉及到更复杂的张量操作时,torch.einsum可以提供更大的灵活性。下面是一些更多的示例,展示了torch.einsum的一些高级用法:

  1. 广播和元素级乘法:
import torch

A = torch.tensor([[1, 2], [3, 4]])
B = torch.tensor([5, 6])
result = torch.einsum('ij,j->ij', A, B)
print(result)  # 输出: tensor([[ 5, 12], [15, 24]])

在这个例子中,我们使用torch.einsum进行广播和元素级乘法。通过指定equation'ij,j->ij',我们将张量A的每一行与向量B进行乘法操作,并生成一个相同形状的输出张量。

  1. 张量的转置和重排:
import torch

A = torch.tensor([[1, 2], [3, 4]])
result = torch.einsum('ij->ji', A)
print(result)  # 输出: tensor([[1, 3], [2, 4]])

B = torch.tensor([[1, 2, 3], [4, 5, 6]])
result = torch.einsum('ij->jik', B)
print(result)  # 输出: tensor([[[1, 2, 3]], [[4, 5, 6]]])

在这个例子中,我们使用torch.einsum来执行张量的转置和重排操作。通过调整equation字符串,我们可以指定所需的维度顺序和形状。

  1. 批量矩阵乘法:
import torch

A = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
B = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
result = torch.einsum('bij,bjk->bik', A, B)
print(result)  # 输出: tensor([[[ 7, 10], [15, 22]], [[47, 58], [67, 82]]])

在这个例子中,我们使用torch.einsum执行批量矩阵乘法操作。通过在equation中引入一个额外的维度标识符b,我们可以指定对每个批次中的矩阵执行矩阵乘法。

这些示例只是展示了torch.einsum函数的一些用法。实际上,你可以根据需要创建更复杂的equation字符串,以执行各种张量操作,例如卷积、张量拼接、张量分解等。

关于torch.einsum的完整语法和用法,请参阅PyTorch官方文档的相关部分:torch.einsum

希望这些例子能帮助你更深入地理解和应用torch.einsum函数。它是一个非常强大的工具,可以帮助你在PyTorch中进行高级的张量操作。

当涉及更复杂的张量操作时,torch.einsum还有一些其他的高级用法。以下是更多示例,展示了torch.einsum的一些功能:

  1. 维度缩减和求和:
import torch

A = torch.tensor([[1, 2], [3, 4]])
result = torch.einsum('ij->', A)
print(result)  # 输出: 10

B = torch.tensor([[1, 2], [3, 4]])
result = torch.einsum('ij->i', B)
print(result)  # 输出: tensor([3, 7])

在这个例子中,我们使用torch.einsum来执行维度缩减和求和操作。通过指定equation字符串中的维度标识符,我们可以选择性地缩减维度并求和。

  1. 张量拼接和分割:
import torch

A = torch.tensor([[1, 2], [3, 4]])
B = torch.tensor([[5, 6]])
result = torch.einsum('ij,ik->ijk', A, B)
print(result)  # 输出: tensor([[[ 5,  6], [10, 12]], [[15, 18], [20, 24]]])

C = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
result = torch.einsum('ijk->ik', C)
print(result)  # 输出: tensor([[ 4,  6], [12, 14]])

在这个例子中,我们使用torch.einsum来执行张量的拼接和分割操作。通过在equation字符串中调整维度标识符,我们可以指定所需的拼接形状或分割方式。

  1. 张量对齐和约简:
import torch

A = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
B = torch.tensor([[1, 2], [3, 4]])
result = torch.einsum('ijk,kl->ijl', A, B)
print(result)  # 输出: tensor([[[ 7, 10], [15, 22]], [[23, 34], [31, 46]]])

在这个例子中,我们使用torch.einsum来执行张量的对齐和约简操作。通过在equation字符串中指定相同的维度标识符,我们可以对张量进行逐元素相乘并相加以实现约简。

这些示例展示了torch.einsum函数的更多用法。你可以根据需要创建更复杂的equation字符串,以执行各种张量操作,如张量合并、张量分解、张量乘法和其他高级操作。

请注意,torch.einsum的灵活性和强大性意味着你需要仔细设计和验证equation字符串,以确保所得到的结果是你期望的。

希望这些例子能够进一步帮助你理解和应用torch.einsum函数。它是一个非常有用的工具,可以在PyTorch中执行高级的张量操作。

更多阅读

https://pytorch.org/docs/stable/generated/torch.einsum.html
https://zhuanlan.zhihu.com/p/361209187

评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值