神经网络实现图像的线性变换

# This is a sample Python script.

# Press Shift+F10 to execute it or replace it with your code.
# Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings.
import torch
import torch.nn as nn
import torch.nn.functional as F
import cv2
class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        # 卷积核的中心
        kernel = [[0, 0 , 0 ],
                  [0 , 1/255 , 0],
                  [0 , 0 , 0 ]]

        kernel = torch.FloatTensor(kernel).unsqueeze(0).unsqueeze(0)
        self.weight = nn.Parameter(data=kernel, requires_grad=False)

    def forward(self,x):
        x1 = x[:, 0]
        x2 = x[:, 1]
        x3 = x[:, 2]
        print(x1.shape)
        # print(x1)
        x1 = F.conv2d(x1.unsqueeze(1), self.weight, padding=1)
        x2 = F.conv2d(x2.unsqueeze(1), self.weight, padding=1)
        x3 = F.conv2d(x3.unsqueeze(1), self.weight, padding=1)
        x = torch.cat([x1, x2, x3], dim=1)
        print(x.shape)
        x=x.squeeze(0).permute(1,2,0)
        return x

def main():
    src=cv2.imread("./space_shuttle_224x224.jpg")
    print(src.shape)
    print(src)
    net=Net()
    # src = cv2.cvtColor(src, cv2.COLOR_BGR2RGB)
    # print(src)
    src = torch.Tensor(src)
    src = src.permute(2, 0, 1).unsqueeze(0)
    # print(src.shape)
    # print(src)
    y=net(src)
    print(y)
    net.eval()

    # 虚拟输入,这里使用src作为虚拟输入
    dummy_input = src

    # 导出模型
    torch.onnx.export(net,  # 模型
                      dummy_input,  # 模型输入的虚拟张量
                      "model.onnx",  # 输出文件的路径
                      export_params=True,  # 是否导出模型参数
                      opset_version=10,  # ONNX版本
                      do_constant_folding=True,  # 是否执行常量折叠优化
                      input_names=['input'],  # 输入张量的名字
                      output_names=['output'],  # 输出张量的名字
                      dynamic_axes={'input': {0: 'batch_size'},  # 批量大小动态化
                                    'output': {0: 'batch_size'}})



# Press the green button in the gutter to run the script.
if __name__ == '__main__':
    main()

# See PyCharm help at https://www.jetbrains.com/help/pycharm/

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值