torch.nn.Unfold(kernel_size, dilation=1, padding=0, stride=1)
参数
- kernel_size,滑动窗口的大小
- dilation,和空洞卷积中的dilation一样,控制滑动窗口取值的间隔
- padding,控制输入张量四周padding的大小
- stride,控制滑动窗口滑动的步长
输入输出维度
- 输入维度为 (N, C, *),其中N代表batch_size,C为通道数,*代表任意的空间维度。
- 输出维度为 \((N,C\times \prod (kernel\_size),L)\),其中 \(\prod (kernel\_size)\) 表示滑动窗口的大小,\(L\) 表示滑动窗口在输入张量中可以滑动的所有位置的个数。\(L\) 的大小计算如下
这里 \(L\) 的计算方式和卷积中输出大小的计算是一模一样的
这个函数的作用就是通过一个滑动窗口遍历输入张量,在这个过程中,取出输入张量中滑动窗口对应位置的值,组成输出向量。
这个过程有点类似于卷积,卷积过程中滑动窗口就是卷积核,卷积核滑动到某个位置,取出输入张量中对应位置的值与该卷积核对应位置的值分别相乘最后相加得到单个值的输出。而unfold是滑动到某一位置,直接取出输入张量中对应位置的值,然后flatten成一个一维向量,放到输出张量的对应位置处。
比如下图中的一个输入张量,batch_size=1,shape=(1, C, W, H),我们想沿WxH维度取出一个个局部特征,就可以通过unfold来实现。
示例
如下代码所示,输入x是一个维度为(2, 5, 3, 4)的张量,其中2是batch_size,5是通道数,3,4代表spatial维度的大小,对应上面输入维度中的 *
kernel大小为(2, 3),unfold做的就是在spatial维度滑动该kernel对应的窗口,取出对应的值,注意这里滑动窗口时取窗口内所有通道的值。上面的 \(\prod (kernel\_size)=2\times 3=6\),窗口内共有2x3x5=30个值,对应输出维度中的 \(C\times \prod (kernel\_size)\)。然后把窗口内对应的block展平成一个一维向量,放到输出中。接下来因为默认stride=1,dilation=1,padding=0,因此(2, 3)大小的kernel在维度(3, 4)中两个方向都只能滑动一次,因此滑动窗口只有4个位置,对应上面的 \(L=4\)。
x = torch.randn(2, 5, 3, 4)
output = F.unfold(x, kernel_size=(2, 3))
# each patch contains 30 values (2x3=6 vectors, each of 5 channels)
# 4 blocks (2x3 kernels) in total in the 3x4 input
print(output.size())
# torch.Size([2, 30, 4])
参考
Unfold — PyTorch 1.12 documentation
PyTorch中torch.nn.functional.unfold函数使用详解_咆哮的阿杰的博客-优快云博客_pytorch unfold