import torch
import torch.nn as nn
import torch.nn.functional as F
import cv2
class Net(nn.Module):
    """Scale the first three channels of an image tensor by 1/255.

    The scaling is expressed as a depthwise 3x3 convolution whose only
    non-zero tap is the centre (value 1/255). With ``padding=1`` the zero
    neighbour taps contribute nothing, so the result equals
    ``x[:, :3] / 255`` (up to float rounding) and the spatial size is kept.
    """

    def __init__(self):
        super().__init__()
        kernel = torch.tensor(
            [[0.0, 0.0, 0.0],
             [0.0, 1 / 255, 0.0],
             [0.0, 0.0, 0.0]],
            dtype=torch.float32,
        )
        # Shape (out_channels=1, in_channels=1, 3, 3). Kept as a frozen
        # Parameter (requires_grad=False) under the original name "weight".
        self.weight = nn.Parameter(kernel[None, None], requires_grad=False)

    def forward(self, x):
        """Normalise ``x`` and move channels to the last axis.

        Args:
            x: float tensor of shape (N, C, H, W) with C >= 3; only the
                first three channels are used.

        Returns:
            Tensor of shape (H, W, 3) when N == 1 (the batch axis is
            squeezed away), otherwise (N, H, W, 3).
        """
        channels = x[:, :3]
        # One grouped convolution applies the same single-channel kernel to
        # each of the three channels independently -- equivalent to three
        # separate conv2d calls on the unsqueezed channel slices.
        weight = self.weight.repeat(3, 1, 1, 1)
        y = F.conv2d(channels, weight, padding=1, groups=3)
        # Permute while the batch axis is still present so batched inputs
        # keep a valid layout; squeeze(0) then only drops a batch of size 1.
        return y.permute(0, 2, 3, 1).squeeze(0)
def main(image_path="./space_shuttle_224x224.jpg", onnx_path="model.onnx"):
    """Run ``Net`` on a sample image, print the result, and export to ONNX.

    Args:
        image_path: image file loaded with ``cv2.imread`` (defaults to the
            original hard-coded sample image).
        onnx_path: destination file for the exported ONNX model.

    Raises:
        FileNotFoundError: if ``cv2.imread`` cannot read ``image_path``.
    """
    src = cv2.imread(image_path)
    # cv2.imread signals failure by returning None instead of raising.
    if src is None:
        raise FileNotFoundError(f"could not read image: {image_path}")
    print(src.shape)
    print(src)

    net = Net()
    net.eval()
    # (H, W, C) array -> (1, C, H, W) float32 tensor, as Net expects.
    src = torch.from_numpy(src).float().permute(2, 0, 1).unsqueeze(0)
    y = net(src)
    print(y)

    # NOTE(review): for this batch-of-1 trace Net.forward squeezes the batch
    # axis, so output axis 0 is image height rather than batch -- confirm the
    # 'output' dynamic_axes entry below is what consumers expect.
    torch.onnx.export(
        net,
        src,
        onnx_path,
        export_params=True,
        opset_version=10,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={'input': {0: 'batch_size'},
                      'output': {0: 'batch_size'}},
    )
if __name__ == '__main__':
    # Only run the demo/export when executed as a script, not on import.
    main()