1. Problem Statement
For a 2D feature map $x \in \mathbb{R}^{M\times N \times d}$ with spatial size $M \times N$, we want to partition it into windows of size $s_p \times s_p$. How can this be expressed in code?
(The question above comes from the paper Focal Self-attention for Local-Global Interactions in Vision Transformers; paper link.)
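Example: the figure shows a 20x20 2D tensor, and the goal is to partition it into 4x4 windows. This gives a 5x5 grid of windows, each of size 4x4, so the desired output is a 5x5x4x4 tensor (a 5D tensor of shape 5x5x4x4xd once a channel dimension d is included).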
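To pin down what "partition into windows" means, here is a naive reference sketch that cuts the map into windows with explicit slicing and stacks them. The function name windows_by_slicing is hypothetical (it is not from the paper), and the input is assumed to carry a channel axis (here d = 1) with M and N divisible by the window size.

```python
import torch

def windows_by_slicing(x, win_size):
    """Hypothetical naive reference: cut x of shape (M, N, C) into
    win_size x win_size windows by explicit slicing and stack them into
    shape (M // win_size, N // win_size, win_size, win_size, C)."""
    M, N, C = x.shape
    rows = []
    for i in range(M // win_size):
        row = []
        for j in range(N // win_size):
            # window (i, j) covers rows i*s .. (i+1)*s and columns j*s .. (j+1)*s
            row.append(x[i * win_size:(i + 1) * win_size,
                         j * win_size:(j + 1) * win_size])
        rows.append(torch.stack(row))   # (N // s, s, s, C)
    return torch.stack(rows)            # (M // s, N // s, s, s, C)

x = torch.arange(20 * 20).reshape(20, 20, 1)  # 20x20 map with d = 1 channel
w = windows_by_slicing(x, 4)
print(w.shape)  # torch.Size([5, 5, 4, 4, 1])
```

The double loop is easy to read but slow; the view/permute approach below produces the same layout without any Python-level loop.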
2. Solution
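# Suppose the window size is 2x2, i.e. win_size = 2.
# windows then has shape (M // win_size, N // win_size, win_size, win_size, C).
import torch

def getWindow(x, win_size):
    # x: (M, N, C); M and N are assumed to be divisible by win_size
    M, N, C = x.shape
    # split each spatial axis into (number of windows, window size)
    x = x.view(M // win_size, win_size, N // win_size, win_size, C)
    # bring the two window-index axes to the front:
    # (M/w, w, N/w, w, C) -> (M/w, N/w, w, w, C)
    windows = x.permute(0, 2, 1, 3, 4).contiguous()
    return windows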
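A quick usage check on the 20x20 example with 4x4 windows, assuming a random input with d = 3 channels (an arbitrary choice). The function is repeated so the snippet runs on its own, and one window is compared with the corresponding slice of x.

```python
import torch

def getWindow(x, win_size):
    # x: (M, N, C); M and N are assumed to be divisible by win_size
    M, N, C = x.shape
    x = x.view(M // win_size, win_size, N // win_size, win_size, C)
    return x.permute(0, 2, 1, 3, 4).contiguous()

d = 3                      # arbitrary channel count for the demo
x = torch.randn(20, 20, d)
windows = getWindow(x, 4)
print(windows.shape)       # torch.Size([5, 5, 4, 4, 3])

# window (i, j) should equal the slice x[4i:4i+4, 4j:4j+4, :]
i, j = 2, 3
print(torch.equal(windows[i, j], x[4 * i:4 * i + 4, 4 * j:4 * j + 4]))  # True
```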
3. Discussion
First, use torch's tensor.view method to reshape the input $x$ into $\hat{x} \in \mathbb{R}^{\frac{M}{s_p}\times s_p \times \frac{N}{s_p} \times s_p \times d}$.
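Why this step is safe: view does not move any data. In row-major order a row index of $x$ can be written as $a s_p + b$ with $0 \le b < s_p$ (and likewise a column index as $c s_p + e$), so the view simply reads $\hat{x}[a, b, c, e] = x[a s_p + b,\; c s_p + e]$. The sketch below checks this identity on small sizes chosen only for illustration.

```python
import torch

M, N, d, sp = 8, 6, 2, 2                  # small illustrative sizes
x = torch.arange(M * N * d).reshape(M, N, d)
x_hat = x.view(M // sp, sp, N // sp, sp, d)

# view only reinterprets the same memory:
# row = a * sp + b, column = c * sp + e
ok = all(
    torch.equal(x_hat[a, b, c, e], x[a * sp + b, c * sp + e])
    for a in range(M // sp) for b in range(sp)
    for c in range(N // sp) for e in range(sp)
)
print(ok)                                 # True
print(x_hat.data_ptr() == x.data_ptr())   # True: no data was copied
```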
Then use torch's tensor.permute method to swap the second and third axes of $\hat{x}$, which rearranges it into a tensor in $\mathbb{R}^{\frac{M}{s_p}\times \frac{N}{s_p} \times s_p \times s_p \times d}$.
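A small sketch of the permute step with the same illustrative sizes. permute only rewrites the strides, so the result is no longer contiguous; .contiguous() then copies it into a fresh row-major layout. After the swap, entry [a, c, b, e] of the window tensor is entry [a, b, c, e] of $\hat{x}$.

```python
import torch

M, N, d, sp = 8, 6, 2, 2
x = torch.arange(M * N * d).reshape(M, N, d)
x_hat = x.view(M // sp, sp, N // sp, sp, d)

# swap axes 1 and 2: (M/sp, sp, N/sp, sp, d) -> (M/sp, N/sp, sp, sp, d)
permuted = x_hat.permute(0, 2, 1, 3, 4)
print(permuted.shape)            # torch.Size([4, 3, 2, 2, 2])
print(permuted.is_contiguous())  # False: permute only changes the strides

windows = permuted.contiguous()  # copy into a fresh row-major layout
print(windows.is_contiguous())   # True

# entry [a, c, b, e] of the result is entry [a, b, c, e] of x_hat
print(torch.equal(windows[1, 2, 0, 1], x_hat[1, 0, 2, 1]))  # True
```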
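At first this seemed almost magical, and I did not understand why these two operations produce the result. On closer thought: view splits $x$ into the required "granularity". Because $x$ is stored in row-major order, a row index can be written as $a s_p + b$, so viewing the $M$ axis as $(\frac{M}{s_p}, s_p)$ just relabels each row by its window index $a$ and its offset $b$ inside the window, without moving any data; the $N$ axis is split the same way. permute then rearranges these granular axes into the desired order, putting the two window indices in front of the two within-window offsets.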
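To see why the reordering is necessary, the sketch below compares a view straight to (number of windows, s, s) with the view-then-permute version on a 4x4 example (sizes chosen only for illustration). Without permute, each "window" is just a run of consecutive memory, i.e. a folded piece of a row, not a 2x2 block.

```python
import torch

H, W, s = 4, 4, 2
x = torch.arange(H * W).reshape(H, W)   # rows: [0..3], [4..7], [8..11], [12..15]

# Wrong: view straight to (num_windows, s, s) without reordering the axes
wrong = x.view(-1, s, s)
# Right: split both axes, move the window indices to the front, then flatten them
right = x.view(H // s, s, W // s, s).permute(0, 2, 1, 3).contiguous().view(-1, s, s)

print(wrong[0])  # values [[0, 1], [2, 3]]: the first row of x folded, not a window
print(right[0])  # values [[0, 1], [4, 5]]: the top-left 2x2 window
```

Note the .contiguous() before the final view: calling view directly on the permuted tensor would raise an error because its strides cannot be merged without a copy.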
Below is a test: suppose the input is an 8x6 2D array and win_size is 2.
import numpy as np
import torch

# Prepare the data
H = 8
W = 6
data = np.arange(H * W)
print(data)
"""
[ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47]
"""
data = torch.tensor(data)
data = data.reshape(H,W)
print(data)
"""
tensor([[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23],
[24, 25, 26, 27, 28, 29],
[30, 31, 32, 33, 34, 35],
[36, 37, 38, 39, 40, 41],
[42, 43, 44, 45, 46, 47]], dtype=torch.int32)
"""
# Split the rows into (H/2, 2) and the columns into (W/2, 2)
data_view = data.view(H // 2, 2, W // 2, 2)
print(data_view)
"""
tensor([[[[ 0, 1],
[ 2, 3],
[ 4, 5]],
[[ 6, 7],
[ 8, 9],
[10, 11]]],
[[[12, 13],
[14, 15],
[16, 17]],
[[18, 19],
[20, 21],
[22, 23]]],
[[[24, 25],
[26, 27],
[28, 29]],
[[30, 31],
[32, 33],
[34, 35]]],
[[[36, 37],
[38, 39],
[40, 41]],
[[42, 43],
[44, 45],
[46, 47]]]], dtype=torch.int32)
"""
# Swap axes 1 and 2: (H/2, 2, W/2, 2) -> (H/2, W/2, 2, 2)
windows = data_view.permute(0, 2, 1, 3).contiguous()
print(windows)
"""
tensor([[[[ 0, 1],
[ 6, 7]],
[[ 2, 3],
[ 8, 9]],
[[ 4, 5],
[10, 11]]],
[[[12, 13],
[18, 19]],
[[14, 15],
[20, 21]],
[[16, 17],
[22, 23]]],
[[[24, 25],
[30, 31]],
[[26, 27],
[32, 33]],
[[28, 29],
[34, 35]]],
[[[36, 37],
[42, 43]],
[[38, 39],
[44, 45]],
[[40, 41],
[46, 47]]]], dtype=torch.int32)
"""
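A note on the difference between PyTorch view and reshape:
https://stackoverflow.com/questions/49643225/whats-the-difference-between-reshape-and-view-in-pytorch
In general, I recommend view: it never copies data, and when the tensor's memory layout cannot support the requested shape (for example, right after permute), it raises an error instead of silently making a copy the way reshape does. This is also why .contiguous() is commonly called after permute, so that later view calls on the result work.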
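A small sketch of the difference on a transposed (hence non-contiguous) tensor: view refuses and raises a RuntimeError, reshape silently returns a copy, and .contiguous().view(...) makes that copy explicit.

```python
import torch

x = torch.arange(12).reshape(3, 4)
t = x.t()                      # transpose: same storage, non-contiguous strides
print(t.is_contiguous())       # False

view_failed = False
try:
    t.view(12)                 # these strides cannot be flattened without a copy
except RuntimeError:
    view_failed = True
print("view failed:", view_failed)   # view failed: True

r = t.reshape(12)              # reshape silently falls back to a copy
print(r.data_ptr() == x.data_ptr())  # False: new memory was allocated
v = t.contiguous().view(12)    # explicit copy, then a zero-copy view
print(torch.equal(r, v))       # True
```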
4. References
https://stackoverflow.com/questions/48915810/pytorch-what-does-contiguous-do
https://jdhao.github.io/2019/07/10/pytorch_view_reshape_transpose_permute/
https://clay-atlas.com/us/blog/2021/08/11/pytorch-en-view-permute-change-dimensions/
https://stackoverflow.com/questions/49643225/whats-the-difference-between-reshape-and-view-in-pytorch