文章目录
规则
- 𝑡𝑜𝑟𝑐ℎ.𝑒𝑖𝑛𝑠𝑢𝑚(𝑒𝑞𝑢𝑎𝑡𝑖𝑜𝑛,∗∗𝑜𝑝𝑒𝑟𝑎𝑛𝑑𝑠)
- 形式:
"Free indices -> Summation indices"
- einsum允许计算许多常见的多维线性代数阵列运算,方法是基于爱因斯坦求和约定以简写格式表示它们,省略了求和号
- 在箭头左边:用下标标记输入operands的每个维度
- 在箭头右边:定义哪些下标是输出的一部分
- 将operands元素与下标不属于输出的维度的乘积求和来计算输出
- 形式:
- 三条基本规则
- Free indices 中逗号前后重复的下标维度,输入张量沿着该维度做乘法操作,
以矩阵乘法为例:"ik,kj->ij",k 在输入中重复出现,所以就是把 a 和 b 沿着 k 这个维度作相乘操作
- Free indices 中出现而
Summation indices
中没有出现的下标维度,计算结果需要在这个维度上求和 - Summation indices的下标可以改变顺序,
如以矩阵乘法+转置为例,"ik,kj->ji",那么就是返回输出结果的转置
- Free indices 中逗号前后重复的下标维度,输入张量沿着该维度做乘法操作,
- 两条额外规则
- Summation indices 可以不写,einsum会按照一套规则自动推导:把输入中只出现一次的索引取出来,然后按字母表顺序排列
- 支持 “…” 省略号,用于表示用户并不关心的索引
例子
获取对角线元素
torch.einsum('ii->i', torch.randn(4, 4))
等价于torch.diagonal(torch.randn(4, 4), 0)
- PS:
torch.diagonal(A, offset, d1, d2)
表示取矩阵A的d1和d2个维度的二维矩阵的对角元素,offset表示往右上偏移多少的对角元素
- PS:
矩阵的迹
torch.einsum('ii', torch.randn(4, 4))
矩阵转置
torch.einsum('...ij->...ji', A)
求和
torch.einsum('ij->', [a])
- 列求和:
torch.einsum('ij->j', [a])
==torch.sum(a, 0)
矩阵乘法
- 以下操作等价:
c = torch.matmul(a,b)
c = torch.einsum('ik,kj->ij', a, b)
c = a.mm(b)
c = torch.mm(a, b)
c = a @ b
矩阵内积
torch.einsum('ij,ij->', [a, b])
哈达玛积(对应位置点乘)
torch.einsum('ij,ij->ij', [a, b])
矩阵和向量相乘
torch.einsum('ik,k->i', [a, b])
向量相乘
torch.einsum('i,i->', [a, b])
== torch.dot(a, b)
向量外积
torch.einsum('i,j->ij', [a, b])
batch matrix multiplication
torch.einsum('bij,bjk->bik', As, Bs)
==torch.bmm(As, Bs)
multi headed attention
KV^T
:
# q k v均为 [seq_len, batch_size, heads, d_k]
torch.einsum('ibhd,jbhd->ijbh', query, key)
- attention * V
# attn [seq_len, seq_len, batch_size, heads]
# value [seq_len, batch_size, heads, d_k]
x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
# x [seq_len, batch_size, heads, d_k]