einsum is all you needed

部署运行你感兴趣的模型镜像

如果问pytorch中最强大的一个数学函数是什么?

我会说是torch.einsum:爱因斯坦求和函数。

它几乎是一个"万能函数":能实现超过一万种功能的函数。

不仅如此,和其它pytorch中的函数一样,torch.einsum是支持求导和反向传播的,并且计算效率非常高。

einsum 提供了一套既简洁又优雅的规则,可实现包括但不限于:内积,外积,矩阵乘法,转置和张量收缩(tensor contraction)等张量操作,熟练掌握 einsum 可以很方便的实现复杂的张量操作,而且不容易出错。

尤其是在一些包括batch维度的高阶张量的相关计算中,若使用普通的矩阵乘法、求和、转置等算子来实现很容易出现维度匹配等问题,但换成einsum则会特别简单。

套用一句深度学习paper标题当中非常时髦的话术,torch.einsum is all you needed 😋!

公众号后台回复关键词:einsum,获取本文源代码链接。

一,einsum规则原理

顾名思义,einsum这个函数的思想起源于家喻户晓的小爱同学:爱因斯坦~。

很久很久以前,小爱同学在捣鼓广义相对论。广义相对论表述各种物理量用的都是张量。

比如描述时空有一个四维时空度规张量,描述电磁场有一个电磁张量,描述运动的有能量动量张量。

d4e8e9094e267977ac8aac438fa13165.jpeg

在理论物理学家中,小爱同学的数学基础不算特别好,在捣鼓这些张量的时候,他遇到了一个比较头疼的问题:公式太长太复杂了。

有没有什么办法让这些张量运算公式稍微显得对人类友好一些呢,能不能减少一些那种扭曲的求和符号呢?

小爱发现,求和导致维度收缩,因此求和符号操作的指标总是只出现在公式的一边。

例如在我们熟悉的矩阵乘法中

k这个下标被求和了,求和导致了这个维度的消失,所以它只出现在右边而不出现在左边。

这种只出现在张量公式的一边的下标被称之为哑指标,反之为自由指标。

小爱同学脑瓜子滴溜一转,反正这种只出现在一边的哑指标一定是被求和求掉的,干脆把对应的求和符号省略得了。

这就是爱因斯坦求和约定:

只出现在公式一边的指标叫做哑指标,针对哑指标的求和符号可以省略

公式立刻清爽了很多。

这个公式表达的含义如下:

C这个张量的第i行第j列由这个张量的第i行第k列和这个张量的第k行第j列相乘,这样得到的是一个三维张量, 其元素为,然后对在维度k上求和得到。

公式展现形式中除了省去了求和符号,还省去了乘法符号(代数通识)。

借鉴爱因斯坦求和约定表达张量运算的清爽整洁,numpy、tensorflow和 torch等库中都引入了 einsum这个函数。

上述矩阵乘法可以被einsum这个函数表述成

C = torch.einsum("ik,kj->ij",A,B)

这个函数的规则原理非常简洁,3句话说明白。

  • 1,用元素计算公式来表达张量运算。

  • 2,只出现在元素计算公式箭头左边的指标叫做哑指标。

  • 3,省略元素计算公式中对哑指标的求和符号。

import torch 

A = torch.tensor([[1,2],[3,4.0]])
B = torch.tensor([[5,6],[7,8.0]])

C1 = A@B
print(C1)

C2 = torch.einsum("ik,kj->ij",[A,B])
print(C2)
tensor([[19., 22.],
        [43., 50.]])
tensor([[19., 22.],
        [43., 50.]])

二,einsum基础范例

einsum这个函数的精髓实际上是第一条:

用元素计算公式来表达张量运算。

而绝大部分张量运算都可以用元素计算公式很方便地来表达,这也是它为什么会那么神通广大。

例1,张量转置

#例1,张量转置
A = torch.randn(3,4,5)

#B = torch.permute(A,[0,2,1])
B = torch.einsum("ijk->ikj",A) 

print("before:",A.shape)
print("after:",B.shape)
before: torch.Size([3, 4, 5])
after: torch.Size([3, 5, 4])

例2,取对角元

#例2,取对角元
A = torch.randn(5,5)
#B = torch.diagonal(A)
B = torch.einsum("ii->i",A)
print("before:",A.shape)
print("after:",B.shape)
before: torch.Size([5, 5])
after: torch.Size([5])

例3,求和降维

#例3,求和降维
A = torch.randn(4,5)
#B = torch.sum(A,1)
B = torch.einsum("ij->i",A)
print("before:",A.shape)
print("after:",B.shape)
before: torch.Size([4, 5])
after: torch.Size([4])

例4,哈达玛积

#例4,哈达玛积
A = torch.randn(5,5)
B = torch.randn(5,5)
#C=A*B
C = torch.einsum("ij,ij->ij",A,B)
print("before:",A.shape, B.shape)
print("after:",C.shape)
before: torch.Size([5, 5]) torch.Size([5, 5])
after: torch.Size([5, 5])

例5,向量内积

#例5,向量内积
A = torch.randn(10)
B = torch.randn(10)
#C=torch.dot(A,B)
C = torch.einsum("i,i->",A,B)
print("before:",A.shape, B.shape)
print("after:",C.shape)
before: torch.Size([10]) torch.Size([10])
after: torch.Size([])

例6,向量外积

#例6,向量外积
A = torch.randn(10)
B = torch.randn(5)
#C = torch.outer(A,B)
C = torch.einsum("i,j->ij",A,B)
print("before:",A.shape, B.shape)
print("after:",C.shape)
before: torch.Size([10]) torch.Size([5])
after: torch.Size([10, 5])

例7,矩阵乘法

#例7,矩阵乘法
A = torch.randn(5,4)
B = torch.randn(4,6)
#C = torch.matmul(A,B)
C = torch.einsum("ik,kj->ij",A,B)
print("before:",A.shape, B.shape)
print("after:",C.shape)
before: torch.Size([5, 4]) torch.Size([4, 6])
after: torch.Size([5, 6])

例8,张量缩并

#例8,张量缩并
A = torch.randn(3,4,5)
B = torch.randn(4,3,6)
#C = torch.tensordot(A,B,dims=[(0,1),(1,0)])
C = torch.einsum("ijk,jih->kh",A,B)
print("before:",A.shape, B.shape)
print("after:",C.shape)
before: torch.Size([3, 4, 5]) torch.Size([4, 3, 6])
after: torch.Size([5, 6])

三,einsum高级范例

einsum可用于超过两个张量的计算。

例9,bilinear注意力机制

例如:双线性变换。这是向量内积的一种扩展,一种常用的注意力机制实现方式)

不考虑batch维度时,双线性变换的公式如下:

考虑batch维度时,无法用矩阵乘法表示,可以用元素计算公式表达如下:

#例9,bilinear注意力机制

#====不考虑batch维度====
q = torch.randn(10) #query_features
k = torch.randn(10) #key_features
W = torch.randn(5,10,10) #out_features,query_features,key_features
b = torch.randn(5) #out_features

#a = q@W@k.t()+b  
a = torch.bilinear(q,k,W,b)
print("a.shape:",a.shape)


#=====考虑batch维度====
Q = torch.randn(8,10)    #batch_size,query_features
K = torch.randn(8,10)    #batch_size,key_features
W = torch.randn(5,10,10) #out_features,query_features,key_features
b = torch.randn(5)       #out_features

#A = torch.bilinear(Q,K,W,b)
A = torch.einsum('bq,oqk,bk->bo',Q,W,K) + b
print("A.shape:",A.shape)
a.shape: torch.Size([5])
A.shape: torch.Size([8, 5])

例10,scaled-dot-product注意力机制

我们也可以用einsum来实现更常见的scaled-dot-product 形式的 Attention.

不考虑batch维度时,scaled-dot-product形式的Attention用矩阵乘法公式表示如下:

考虑batch维度时,无法用矩阵乘法表示,可以用元素计算公式表达如下:

#例10,scaled-dot-product注意力机制

#====不考虑batch维度====
q = torch.randn(10)  #query_features
k = torch.randn(6,10) #key_size, key_features

d_k = k.shape[-1]
a = torch.softmax(q@k.t()/d_k,-1) 

print("a.shape=",a.shape )

#====考虑batch维度====
Q = torch.randn(8,10)  #batch_size,query_features
K = torch.randn(8,6,10) #batch_size,key_size,key_features

d_k = K.shape[-1]
A = torch.softmax(torch.einsum("in,ijn->ij",Q,K)/d_k,-1) 

print("A.shape=",A.shape )
a.shape= torch.Size([6])
A.shape= torch.Size([8, 6])

以上。

万水千山总是情,点个在看行不行?😋

公众号后台回复关键词:einsum,获取本文源代码链接。

731e439ae994ff5a6b133393ae50d6ab.png

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

使用transfomers库实现encoder-decoder架构的,encoder和decoder都是transformerxl的,使用旋转位置编码的示例代码,旋转编码实现代码如下import torch class RotaryEmbedding(torch.nn.Module): def __init__(self, dim, base=10000): super().__init__() inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer('inv_freq', inv_freq) self.seq_len_cached = 0 self.cos_cached = None self.sin_cached = None def forward(self, x, seq_dim=1): seq_len = x.shape[seq_dim] if seq_len != self.seq_len_cached: #if seq_len > self.seq_len_cached: self.seq_len_cached = seq_len t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) freqs = torch.einsum('i,j->ij', t, self.inv_freq) emb = torch.cat((freqs, freqs), dim=-1).to(x.device) self.cos_cached = emb.cos()[None,:, None, :] self.sin_cached = emb.sin()[None,:, None, :] #else: # cos_return = self.cos_cached[..., :seq_len] # sin_return = self.sin_cached[..., :seq_len] # return cos_return, sin_return return self.cos_cached, self.sin_cached # rotary pos emb helpers: def rotate_half(x): x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions @torch.jit.script def apply_rotary_pos_emb(q, k, cos, sin): return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)from torch.nn import Linear, Module from fast_transformers.attention import AttentionLayer from fast_transformers.events import EventDispatcher, QKVEvent from .rotary import RotaryEmbedding, apply_rotary_pos_emb class RotateAttentionLayer(AttentionLayer): """Rotate attention layer inherits from fast_transformer attention layer. The only thing added is an Embedding encoding, for more information on the attention layer see the fast_transformers code """ def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None, event_dispatcher=""): super(RotateAttentionLayer, self).__init__(attention,d_model, n_heads, d_keys=d_keys, d_values=d_values, event_dispatcher=event_dispatcher) self.rotaryemb = RotaryEmbedding(d_keys) print('Using Rotation Embedding') def forward(self, queries, keys, values, attn_mask, query_lengths, key_lengths): """ Using the same frame work as the fast_Transformers attention layer but injecting rotary information to the queries and the keys after the keys and queries are projected. In the argument description we make use of the following sizes - N: the batch size - L: The maximum length of the queries - S: The maximum length of the keys (the actual length per sequence is given by the length mask) - D: The input feature dimensionality passed in the constructor as 'd_model' Arguments --------- queries: (N, L, D) The tensor containing the queries keys: (N, S, D) The tensor containing the keys values: (N, S, D) The tensor containing the values attn_mask: An implementation of BaseMask that encodes where each query can attend to query_lengths: An implementation of BaseMask that encodes how many queries each sequence in the batch consists of key_lengths: An implementation of BaseMask that encodes how many queries each sequence in the batch consists of Returns ------- The new value for each query as a tensor of shape (N, L, D). """ # Extract the dimensions into local variables N, L, _ = queries.shape _, S, _ = keys.shape H = self.n_heads # Project the queries/keys/values queries = self.query_projection(queries).view(N, L, H, -1) keys = self.key_projection(keys).view(N, S, H, -1) cos, sin = self.rotaryemb(queries) queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin) values = self.value_projection(values).view(N, S, H, -1) # Let the world know of the qkv self.event_dispatcher.dispatch(QKVEvent(self, queries, keys, values)) # Compute the attention new_values = self.inner_attention( queries, keys, values, attn_mask, query_lengths, key_lengths ).view(N, L, -1) # Project the output and return return self.out_projection(new_values)
最新发布
07-21
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值