torch.nn.functional.grid_sample()
Given an input and a flow-field grid, computes the output using input values and pixel locations from grid.
简单来说就是,提供一个input的Tensor以及一个对应的flow-field网格(比如光流,体素流等),然后根据grid中每个位置提供的坐标信息(这里指input中pixel的坐标),将input中对应位置的像素值填充到grid指定的位置,得到最终的输出。
一、介绍:
torch.nn.functional.grid_sample
是 PyTorch 中的一个函数,用于对输入图像或张量进行采样和插值。该函数通过指定的网格坐标 (grid
) 来重定位图像中的像素,实现图像的变换(如仿射变换、透视变换等)。
二、用法:
部分转载自链接
关于input、grid以及output的尺寸如下所示:(input也可以是5D的Tensor,这里我们只考虑4D的情况)
input : ( N , C , H i n , W i n ) \text{input}:(N,C,H_{in},W_{in}) input:(N,C,Hin,Win)
grid : ( N , H o u t , W o u t , 2 ) \text{grid}:(N,H_{out},W_{out},2) grid:(N,Hout,Wout,2)
output : ( N , C , H o u t , W o u t ) \text{output}:(N,C,H_{out},W_{out}) output:(N,C,Hout,Wout)
这里的 input \text{input} input 和 output \text{output} output 就是输入的图片,或者是网络中的feature map。关键的处理过程在于grid,grid的最后一维的大小为2,即表示 input \text{input} input 中pixel的位置信息 (x,y) ,这里一般会将x和y的取值范围归一化到 [−1,1] 之间, (−1,−1) 表示 input \text{input} input 左上角的像素的坐标,(1,1) 表示 input \text{input} input 右下角的像素的坐标,对于超出这个范围的坐标(x,y),函数将会根据参数 padding_mode \text{padding\_mode} padding_mode 的设定进行不同的处理。
- padding_mode=’zeros’ \text{padding\_mode='zeros'} padding_mode=’zeros’:对于越界的位置在网格中采用pixel value=0进行填充。
- padding_mode=’border’ \text{padding\_mode='border'} padding_mode=’border’对于越界的位置在网格中采用边界的pixel value进行填充。
- padding_mode=’reflection’ \text{padding\_mode='reflection'} padding_mode=’reflection’:对于越界的位置在网格中采用关于边界的对称值进行填充。
对于mode='bilinear’参数,则定义了在
input
\text{input}
input 中指定位置的pixel value中进行插值的方法,为什么需要插值呢?因为前面我们说了,grid中表示的位置信息x和y的取值范围在 [−1,1] 之间,这就意味着我们要根据一个浮点型的坐标值在
input
\text{input}
input 中对pixel value进行采样,mode有nearest和bilinear
两种模式。 nearest就是直接采用与 (x,y) 距离最近处的像素值来填充grid,而bilinear则是采用双线性插值的方法来进行填充,总之其与nearest的区别就是nearest只考虑最近点的pixel value,而bilinear则采用(x,y)周围的四个pixel value进行加权平均值来填充grid。