from torch import nn
import math
import torch
"""
Args:
out_side (tuple): Side length of the square pooled output grid at each pyramid level.
Inputs:
- `input`: the input Tensor to pool ([batch, channel, height, width])
"""
class SpatialPyramidPool2D(nn.Module):
    """Spatial pyramid max-pooling that yields a fixed-length feature vector.

    For every pyramid level ``n`` in ``out_side`` the input feature map is
    max-pooled onto an ``n x n`` grid; the grids of all levels are flattened
    and concatenated along the feature dimension.

    Args:
        out_side (tuple): Side length of the pooled grid at each pyramid
            level, e.g. ``(1, 2, 4)``. Every entry must be a positive whole
            number.

    Inputs:
        - `x`: Tensor of shape ``(batch, channels, height, width)``.

    Outputs:
        Tensor of shape ``(batch, channels * sum(n * n for n in out_side))``,
        independent of the spatial size of ``x``.

    Raises:
        ValueError: if ``out_side`` is empty or contains an entry that is not
            a positive whole number, or if ``forward`` receives a tensor that
            is not 4-D.
    """

    def __init__(self, out_side):
        super(SpatialPyramidPool2D, self).__init__()
        levels = tuple(out_side)
        if not levels:
            raise ValueError("out_side must contain at least one pyramid level")
        for n in levels:
            if n < 1 or int(n) != n:
                raise ValueError(
                    "out_side entries must be positive integers, got %r" % (n,)
                )
        self.out_side = levels

    def forward(self, x):
        """Pool ``x`` at every pyramid level and concatenate the results."""
        if x.dim() != 4:
            raise ValueError(
                "expected a 4-D (batch, channels, height, width) tensor, "
                "got %d dimensions" % x.dim()
            )
        # Adaptive pooling derives the window boundaries from the requested
        # output size, so each level is exactly n x n for any input size
        # (a fixed ceil/floor kernel/stride pair can yield more than n bins,
        # skip trailing rows/columns, or end up with a zero stride).
        pooled = [
            nn.functional.adaptive_max_pool2d(x, int(n)).flatten(1)
            for n in self.out_side
        ]
        return torch.cat(pooled, dim=1)
# Demo pipeline: a 5x5 convolution (padding 2, stride 1 keeps the spatial
# size) followed by ReLU and a three-level spatial pyramid pooling head.
seq = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
    nn.ReLU(),
    SpatialPyramidPool2D(out_side=(1, 2, 4)),
)

# One single-channel 10x20 sample; 16 channels * (1 + 4 + 16) bins = 336
# pooled features per sample.
x = torch.randn(1, 1, 10, 20)
y = seq(x)
print(y.shape)