# References: https://github.com/Coldog2333/pytoflow
# (https://github.com/anchen1011/toflow)
import math
import torch
# import torch.utils.serialization # it was removed in torch v1.0.0 or higher version.
# Name of the pretrained SpyNet weight set. SpyNet builds its checkpoint path as
# SpyNet_model_dir + '/network-' + arguments_strModel + '.pytorch'.
arguments_strModel = 'sintel-final'
SpyNet_model_dir = './models' # The directory of SpyNet's weights
def normalize(tensorInput):
    """Standardize the first three channels of a 4-D batch.

    Channel ``c`` becomes ``(x - mean[c]) / std[c]`` with the fixed per-channel
    constants listed below; channels 0, 1, 2 are treated as red, green, blue.

    :param tensorInput: tensor indexed as [batch, channel, height, width]
    :return: the standardized channels 0..2 concatenated along dim 1
    """
    # (mean, std) for channels 0, 1 and 2 respectively.
    channel_stats = ((0.485, 0.229), (0.456, 0.224), (0.406, 0.225))
    scaled = []
    for index, (mean, std) in enumerate(channel_stats):
        # Slice with a range so the channel axis is kept for the final cat.
        plane = tensorInput[:, index:index + 1, :, :]
        scaled.append((plane - mean) / std)
    return torch.cat(scaled, 1)
def denormalize(tensorInput):
    """Undo the per-channel standardization on the first three channels.

    Channel ``c`` becomes ``x * std[c] + mean[c]`` with the fixed per-channel
    constants listed below; channels 0, 1, 2 are treated as red, green, blue.

    :param tensorInput: tensor indexed as [batch, channel, height, width]
    :return: the rescaled channels 0..2 concatenated along dim 1
    """
    # (mean, std) for channels 0, 1 and 2 respectively.
    channel_stats = ((0.485, 0.229), (0.456, 0.224), (0.406, 0.225))
    restored = []
    for index, (mean, std) in enumerate(channel_stats):
        # Slice with a range so the channel axis is kept for the final cat.
        plane = tensorInput[:, index:index + 1, :, :]
        restored.append((plane * std) + mean)
    return torch.cat(restored, 1)
# Cache of base sampling grids, keyed by (flow size string, cuda flag) so that a
# grid built for one device is never reused for the other.
Backward_tensorGrid = {}


def Backward(tensorInput, tensorFlow, cuda_flag):
    """Warp ``tensorInput`` by sampling it at positions displaced by ``tensorFlow``.

    :param tensorInput: 4-D tensor [batch, channels, height, width] to sample from
    :param tensorFlow: 4-D tensor [batch, 2, height, width]; channel 0 is the
        horizontal and channel 1 the vertical displacement, in pixels
    :param cuda_flag: if truthy, the cached base grid is moved to CUDA with ``.cuda()``
    :return: bilinearly sampled tensor; out-of-range samples use border padding
    """
    cache_key = (str(tensorFlow.size()), bool(cuda_flag))
    if cache_key not in Backward_tensorGrid:
        batch, _, height, width = tensorFlow.size()
        # Identity grid in normalized coordinates: -1 and 1 are the centers of
        # the first and last pixel (the align_corners=True convention).
        horizontal = torch.linspace(-1.0, 1.0, width).view(1, 1, 1, width).expand(batch, -1, height, -1)
        vertical = torch.linspace(-1.0, 1.0, height).view(1, 1, height, 1).expand(batch, -1, -1, width)
        grid = torch.cat([horizontal, vertical], 1)
        Backward_tensorGrid[cache_key] = grid.cuda() if cuda_flag else grid

    # Convert pixel displacements to normalized units. max(..., 1.0) avoids a
    # division by zero when a spatial dimension has size 1.
    half_width = max(tensorInput.size(3) - 1.0, 1.0) / 2.0
    half_height = max(tensorInput.size(2) - 1.0, 1.0) / 2.0
    tensorFlow = torch.cat([tensorFlow[:, 0:1, :, :] / half_width,
                            tensorFlow[:, 1:2, :, :] / half_height], 1)

    # align_corners=True matches the linspace(-1, 1) grid and the (size - 1) / 2
    # scaling above; with the library default (False) a zero flow would not be
    # an identity warp.
    return torch.nn.functional.grid_sample(
        input=tensorInput,
        grid=(Backward_tensorGrid[cache_key] + tensorFlow).permute(0, 2, 3, 1),
        mode='bilinear',
        padding_mode='border',
        align_corners=True)
# SpyNet
class SpyNet(torch.nn.Module):
    """Coarse-to-fine optical-flow estimator made of four small per-level CNNs.

    The constructor loads pretrained weights from
    ``SpyNet_model_dir + '/network-' + arguments_strModel + '.pytorch'``.

    :param cuda_flag: forwarded to ``Backward`` to decide whether the warping
        grid is moved to CUDA
    """

    def __init__(self, cuda_flag):
        super(SpyNet, self).__init__()
        self.cuda_flag = cuda_flag

        class Basic(torch.nn.Module):
            """One pyramid level: five 7x7 convolutions mapping 8 channels to a 2-channel flow."""

            def __init__(self, intLevel):
                super(Basic, self).__init__()
                # Attribute name and layer order determine the state-dict keys
                # of the pretrained weights, so they must not change.
                self.moduleBasic = torch.nn.Sequential(
                    torch.nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3),
                    torch.nn.ReLU(inplace=False),
                    torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3),
                    torch.nn.ReLU(inplace=False),
                    torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3),
                    torch.nn.ReLU(inplace=False),
                    torch.nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3),
                    torch.nn.ReLU(inplace=False),
                    torch.nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3)
                )

            def forward(self, tensorInput):
                return self.moduleBasic(tensorInput)

        self.moduleBasic = torch.nn.ModuleList([Basic(intLevel) for intLevel in range(4)])

        weights_path = SpyNet_model_dir + '/network-' + arguments_strModel + '.pytorch'
        # map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only
        # host; load_state_dict copies the values into the module's parameters.
        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
        state_dict = torch.load(weights_path, map_location='cpu')
        self.load_state_dict(state_dict, strict=False)

    def forward(self, tensorFirst, tensorSecond):
        """Estimate the flow between two image batches.

        :param tensorFirst: 4-D tensor [batch, 3, height, width]
        :param tensorSecond: 4-D tensor with the same shape as ``tensorFirst``
        :return: flow tensor [batch, 2, height, width] at the input resolution
        """
        tensorFirst = [tensorFirst]
        tensorSecond = [tensorSecond]

        # Build the image pyramid, coarsest level first: halve the resolution
        # up to three times while the current coarsest level is larger than 32
        # in either spatial dimension.
        for _ in range(3):
            if tensorFirst[0].size(2) > 32 or tensorFirst[0].size(3) > 32:
                tensorFirst.insert(0, torch.nn.functional.avg_pool2d(input=tensorFirst[0], kernel_size=2, stride=2))
                tensorSecond.insert(0, torch.nn.functional.avg_pool2d(input=tensorSecond[0], kernel_size=2, stride=2))

        # Zero flow at half the coarsest resolution, so the first x2 upsampling
        # lands on (or one pixel short of) the coarsest level's size.
        tensorFlow = tensorFirst[0].new_zeros(tensorFirst[0].size(0), 2,
                                              int(math.floor(tensorFirst[0].size(2) / 2.0)),
                                              int(math.floor(tensorFirst[0].size(3) / 2.0)))

        for intLevel in range(len(tensorFirst)):
            # Doubling the resolution also doubles displacements measured in pixels.
            tensorUpsampled = torch.nn.functional.interpolate(
                input=tensorFlow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0

            # If the upsampled flow does not match this level's size (odd
            # dimension), extend it by one row/column using replicate padding.
            if tensorUpsampled.size(2) != tensorFirst[intLevel].size(2):
                tensorUpsampled = torch.nn.functional.pad(input=tensorUpsampled, pad=[0, 0, 0, 1], mode='replicate')
            if tensorUpsampled.size(3) != tensorFirst[intLevel].size(3):
                tensorUpsampled = torch.nn.functional.pad(input=tensorUpsampled, pad=[0, 1, 0, 0], mode='replicate')

            # Level input: [first image, second image warped by the upsampled
            # flow, upsampled flow]. The level's output is added to the
            # upsampled flow to give the flow at this level.
            warped_second = Backward(tensorInput=tensorSecond[intLevel],
                                     tensorFlow=tensorUpsampled,
                                     cuda_flag=self.cuda_flag)
            level_input = torch.cat([tensorFirst[intLevel], warped_second, tensorUpsampled], 1)
            tensorFlow = self.moduleBasic[intLevel](level_input) + tensorUpsampled

        return tensorFlow
# warp
class warp(torch.nn.Module):
    """Warp a frame with a per-pixel flow field using bilinear sampling.

    :param h: frame height in pixels
    :param w: frame width in pixels
    :param cuda_flag: if truthy, the precomputed coordinate grid is moved to CUDA
    """

    def __init__(self, h, w, cuda_flag):
        super(warp, self).__init__()
        self.height = h
        self.width = w
        # Kept as a plain attribute (not a buffer) so the module's state dict is
        # unchanged; hence the explicit device handling here.
        if cuda_flag:
            self.addterm = self.init_addterm().cuda()
        else:
            self.addterm = self.init_addterm()

    def init_addterm(self):
        """Return the pixel-coordinate grid added to the flow.

        :return: tensor [1, 2, height, width]; channel 0 holds each pixel's
            column index and channel 1 its row index
        """
        columns = torch.FloatTensor(list(range(self.width)))
        rows = torch.FloatTensor(list(range(self.height)))
        horizontal_term = columns.view(1, 1, 1, self.width).expand(1, 1, self.height, self.width)
        vertical_term = rows.view(1, 1, self.height, 1).expand(1, 1, self.height, self.width)
        return torch.cat((horizontal_term, vertical_term), dim=1)

    def forward(self, frame, flow):
        """
        :param frame: frame.shape (batch_size, n_channels, height, width)
        :param flow: flow.shape (batch_size, 2, height, width); channel 0 is the
            horizontal and channel 1 the vertical displacement, in pixels
        :return: reference_frame: warped frame, same shape as ``frame``
        """
        # Absolute sampling positions in pixels; addterm broadcasts over the batch.
        positions = flow + self.addterm

        # Map pixel positions to [-1, 1] where -1 / 1 are the centers of the
        # first / last pixel. max(..., 1) avoids dividing by zero for size 1.
        # Slicing with ranges keeps every batch element (not just the first).
        horizontal = positions[:, 0:1, :, :] * 2 / max(self.width - 1, 1) - 1
        vertical = positions[:, 1:2, :, :] * 2 / max(self.height - 1, 1) - 1

        grid = torch.cat((horizontal, vertical), dim=1).permute(0, 2, 3, 1)
        # align_corners=True matches the (size - 1) normalization above; with
        # the library default (False) a zero flow would not reproduce the frame.
        reference_frame = torch.nn.functional.grid_sample(frame, grid, align_corners=True)
        return reference_frame
class ResNet(torch.nn.Module):
    """
    Small convolutional image-processing head with a residual connection to the
    mean of the input frames.
    reference: https://blog.csdn.net/chenyuping333/article/details/82344334

    :param task: one of 'interp', 'denoise'/'denoising', 'sr'/'super-resolution';
        it selects which layers are applied in ``ResBlock``
    """

    def __init__(self, task):
        super(ResNet, self).__init__()
        self.task = task
        # All layers are created for every task (attribute names and creation
        # order are part of the state-dict / initialization contract); each
        # task uses only a subset of them.
        self.conv_3x2_64_9x9 = torch.nn.Conv2d(in_channels=3 * 2, out_channels=64, kernel_size=9, padding=8 // 2)
        self.conv_3x7_64_9x9 = torch.nn.Conv2d(in_channels=3 * 7, out_channels=64, kernel_size=9, padding=8 // 2)
        self.conv_64_64_9x9 = torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=9, padding=8 // 2)
        self.conv_64_64_1x1 = torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=1)
        self.conv_64_3_1x1 = torch.nn.Conv2d(in_channels=64, out_channels=3, kernel_size=1)

    def ResBlock(self, x, aver):
        """Apply the task-specific layers to ``x`` and add ``aver``.

        :param x: frames stacked along the channel axis, [batch, 3 * n_frames, h, w]
            (6 channels for 'interp', 21 for the other tasks)
        :param aver: tensor [batch, 3, h, w] added to the final 3-channel output
        :return: tensor [batch, 3, h, w]
        :raises NameError: if ``self.task`` is not a supported task name
        """
        if self.task == 'interp':
            hidden_layers = (self.conv_3x2_64_9x9, self.conv_64_64_1x1)
        elif self.task in ('denoise', 'denoising'):
            hidden_layers = (self.conv_3x7_64_9x9, self.conv_64_64_1x1)
        elif self.task in ('sr', 'super-resolution'):
            hidden_layers = (self.conv_3x7_64_9x9, self.conv_64_64_9x9, self.conv_64_64_1x1)
        else:
            raise NameError('Only support: [interp, denoise/denoising, sr/super-resolution]')

        for layer in hidden_layers:
            x = torch.nn.functional.relu(layer(x))
        return self.conv_64_3_1x1(x) + aver

    def forward(self, frames):
        """
        :param frames: tensor [batch, n_frames, 3, h, w]
        :return: tensor [batch, 3, h, w]
        """
        # Mean over the frame axis; used as the residual base.
        aver = frames.mean(dim=1)
        # Stack all frames along the channel axis with a single concatenation
        # (frame 0's channels first), instead of growing the tensor frame by frame.
        x = torch.cat([frames[:, i, :, :, :] for i in range(frames.size(1))], dim=1)
        return self.ResBlock(x, aver)
class TOFlow(torch.nn.Module):
    """Full pipeline: estimate flow with SpyNet, warp the frames, fuse them with ResNet.

    :param h: frame height in pixels (passed to ``warp``)
    :param w: frame width in pixels (passed to ``warp``)
    :param task: one of 'interp', 'denoise'/'denoising', 'sr'/'super-resolution'
    :param cuda_flag: if truthy, intermediate tensors and helper grids are
        created on CUDA
    """

    def __init__(self, h, w, task, cuda_flag):
        super(TOFlow, self).__init__()
        self.height = h
        self.width = w
        self.task = task
        self.cuda_flag = cuda_flag

        self.SpyNet = SpyNet(cuda_flag=self.cuda_flag)  # flow-estimation stage
        # for param in self.SpyNet.parameters(): # fix
        #     param.requires_grad = False

        self.warp = warp(self.height, self.width, cuda_flag=self.cuda_flag)

        self.ResNet = ResNet(task=self.task)

    # frames should be TensorFloat
    def forward(self, frames):
        """
        :param frames: [batch_size, img_num, n_channels=3, h, w]; 'interp' reads
            frames 0 and 1, the other tasks read frames 0..6 with frame 3 as the
            reference. The input tensor is not modified.
        :return: [batch_size, n_channels=3, h, w] denormalized output image
        :raises NameError: if ``self.task`` is not a supported task name
        """
        num_frames = frames.size(1)
        # Normalize into a new tensor rather than writing into the caller's
        # tensor, so the input is left intact and repeated calls on the same
        # input do not normalize it twice.
        frames = torch.stack(
            [normalize(frames[:, i, :, :, :]) for i in range(num_frames)], dim=1)

        batch, height, width = frames.size(0), frames.size(3), frames.size(4)
        opticalflows = torch.zeros(batch, num_frames, 2, height, width)
        warpframes = torch.empty(batch, num_frames, 3, height, width)
        if self.cuda_flag:
            opticalflows = opticalflows.cuda()
            warpframes = warpframes.cuda()

        if self.task == 'interp':
            process_index = [0, 1]
            # The flow between the two frames is halved in each direction
            # before warping.
            opticalflows[:, 1, :, :, :] = self.SpyNet(frames[:, 0, :, :, :], frames[:, 1, :, :, :]) / 2
            opticalflows[:, 0, :, :, :] = self.SpyNet(frames[:, 1, :, :, :], frames[:, 0, :, :, :]) / 2
        elif self.task in ['denoise', 'denoising', 'sr', 'super-resolution']:
            process_index = [0, 1, 2, 4, 5, 6]
            for i in process_index:
                # Flow from the reference frame (index 3) to neighbour i.
                opticalflows[:, i, :, :, :] = self.SpyNet(frames[:, 3, :, :, :], frames[:, i, :, :, :])
            # The reference frame is passed through unwarped.
            warpframes[:, 3, :, :, :] = frames[:, 3, :, :, :]
        else:
            raise NameError('Only support: [interp, denoise/denoising, sr/super-resolution]')

        for i in process_index:
            warpframes[:, i, :, :, :] = self.warp(frames[:, i, :, :, :], opticalflows[:, i, :, :, :])
        # warpframes: [batch_size, img_num, n_channels=3, height, width]

        Img = self.ResNet(warpframes)
        # Img: [batch_size, n_channels=3, h, w]

        Img = denormalize(Img)

        return Img