# Swap the last two axes of a 5-D tensor. The ellipsis stands for all leading
# axes, which are carried through unchanged; only `i` and `j` trade places.
a = torch.zeros([2, 3, 4, 5, 6])
a = torch.einsum('...ij -> ...ji', a)
print(f"the dealed a shape is {a.shape}")
# Expected output:
# the dealed a shape is torch.Size([2, 3, 4, 6, 5])
实例
# Extract the diagonal: repeating the index `i` on the input side selects the
# entries whose row and column indices are equal.
b = torch.arange(16).reshape(4, 4)
print(b)
b = torch.einsum('ii->i', b)
print(b)
# Expected output:
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11],
#         [12, 13, 14, 15]])
# tensor([ 0,  5, 10, 15])

# Sum of all elements: an empty right-hand side reduces over every index.
a = torch.arange(6).reshape(2, 3)
print(a)
a = torch.einsum('ij ->', a)
print(a)
# Expected output:
# tensor([[0, 1, 2],
#         [3, 4, 5]])
# tensor(15)

# Sum by row and by column: the index left out of the output is summed over.
a = torch.arange(6).reshape(2, 3)
print(f"sum example : \na = {a}")
print(torch.einsum('ij -> i', a))  # one value per row
print(torch.einsum('ij -> j', a))  # one value per column
# Expected output:
# sum example :
# a = tensor([[0, 1, 2],
#         [3, 4, 5]])
# tensor([ 3, 12])
# tensor([3, 5, 7])
用法表格
rearrange
import torch
from einops import rearrange

# A batch of 32 images in (batch, height, width, channel) layout.
images = torch.randn((32, 30, 40, 3))

# Identity pattern: shape stays (32, 30, 40, 3).
print(rearrange(images, 'b h w c -> b h w c').shape)
# Merge batch and height into one axis: (960, 40, 3).
print(rearrange(images, 'b h w c -> (b h) w c').shape)
# Move batch next to width and merge them: (30, 1280, 3).
print(rearrange(images, 'b h w c -> h (b w) c').shape)
# Channels-last to channels-first: (32, 3, 30, 40).
print(rearrange(images, 'b h w c -> b c h w').shape)
# Flatten everything except the batch axis: (32, 3600).
print(rearrange(images, 'b h w c -> b (c h w)').shape)

# ---------------------------------------------
# Writing (h h1) (w w1) on the input side splits height and width, so the
# resulting h and w are 1/h1 and 1/w1 of the original sizes.
# Fold the split-off factors into the batch axis: (128, 15, 20, 3).
print(rearrange(images, 'b (h h1) (w w1) c -> (b h1 w1) h w c', h1=2, w1=2).shape)
# Fold the split-off factors into the channel axis: (32, 15, 20, 12).
print(rearrange(images, 'b (h h1) (w w1) c -> b h w (c h1 w1)', h1=2, w1=2).shape)
repeat
import torch
from einops import repeat

# A single 2-D image in (height, width) layout.
image = torch.randn((30, 40))

# Copy the whole image along a new trailing axis: (30, 40, 3).
print(repeat(image, 'h w -> h w c', c=3).shape)
# Repeat along the row axis: (60, 40).
print(repeat(image, 'h w -> (repeat h) w', repeat=2).shape)
# Repeat along the column axis: (30, 120).
# Note: (repeat w) and (w repeat) give different results. The former tiles the
# whole row; the latter repeats each element in place.
print(repeat(image, 'h w -> h (repeat w)', repeat=3).shape)
# Repeat each element h2 times vertically and w2 times horizontally: (60, 80).
print(repeat(image, 'h w -> (h h2) (w w2)', h2=2, w2=2).shape)
reduce
import torch
from einops import reduce

# Max over the channel axis: (3, 5, 5) -> (5, 5).
x = torch.randn(3, 5, 5)
print(reduce(x, 'c h w -> h w', 'max').shape)

x = torch.randn(1, 3, 6, 6)
# 2x2 max-pooling: (1, 3, 6, 6) -> (1, 3, 3, 3).
# Note: an error is raised if h or w is not evenly divisible by h1 / w1.
y1 = reduce(x, 'b c (h h1) (w w1) -> b c h w', 'max', h1=2, w1=2)
print(y1.shape)
# Adaptive max-pooling: keep h1 and w1 as the output grid and reduce over the
# remaining factors, giving (1, 3, 3, 2).
print(reduce(x, 'b c (h h1) (w w1) -> b c h1 w1', 'max', h1=3, w1=2).shape)
# Global average pooling over both spatial axes: (1, 3).
print(reduce(x, 'b c h w -> b c', 'mean').shape)