目录
torch.einsum
是 PyTorch 中用于实现 爱因斯坦求和约定 的接口。它提供了一种简洁的方式来表达张量之间的各种操作(如矩阵乘法、点积、外积、转置等)。其核心思想是通过下标字符串描述张量的操作规则。
常用操作示例:
# 1. 矩阵乘法
A = torch.randn(3, 4)
B = torch.randn(4, 5)
C = torch.einsum("ij,jk->ik", A, B) # 等价于 A @ B
# 2. 点积(内积)
a = torch.randn(5)
b = torch.randn(5)
result = torch.einsum("i,i->", a, b) # 等价于 torch.dot(a, b)
# 3. 外积
a = torch.randn(3)
b = torch.randn(4)
result = torch.einsum("i,j->ij", a, b) # 等价于 torch.outer(a,b)
# 4. 逐元素乘法(Hadamard积)
A = torch.randn(3, 4)
B = torch.randn(3, 4)
result = torch.einsum("ij,ij->ij", A, B) # 等价于 A * B
# 5. 转置
A = torch.randn(3, 4)
result = torch.einsum("ij->ji", A) # 等价于 A.T
# 6. 求和