目录
0. 背景
百度官方给我们提供了很好的安卓端部署demo,但现阶段很多模型都是pytorch实现的,想要自己复现可能需要花些时间,因此我们可以从pytorch转到onnx再到paddle实现安卓端的部署。
简单起见,我们采用最简单的图像分类模型做演示。
1. 随便写一个pytroch模型转为onnx模型
1.1 随便写一个pytorch图像分类模型
import numpy as np
import torch
import torch.nn as nn
import onnxruntime
class simple_model(nn.Module):
def __init__(self, in_dim, n_class):
super(simple_model, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_dim, 16, 3, stride=1, padding=0),
nn.ReLU(True),
nn.MaxPool2d(2, 2),
nn.Conv2d(16, 32, 3, stride=1, padding=0),
nn.ReLU(True),
nn.MaxPool2d(2, 2),
nn.Conv2d(32, 64, 3, stride=1, padding=0),
nn.MaxPool2d(2, 2)
)
self.fc = nn.Sequential(
nn.Linear(43264, n_class),
# nn.Dropout()
)
def forward(self, x):
out = self.conv(x)
out = out.view(out.size(0), -1)
out = self.fc(out)
return out
1.2 转为onnx模型,并检查输入(没有进行训练)
这一步生成onnx模型,并对比其预测结果与pytoch模型的预测结果
if __name__ == '__main__':
x = torch.randn(1, 3, 224, 224, requires_grad=True)
model = simple_model(3, 2)
out = model(x)
# Export the model
torch.onnx.export(model,