Porting a deep learning model: PyTorch -> ONNX -> ncnn -> Android
I recently needed to run a deep learning model on a phone for a project and hit quite a few pitfalls along the way, so here is a record. MobileNetV3 is used as the running example.
PyTorch -> ONNX
1. Environment setup
PyTorch ships with ONNX export starting from version 1.0, which makes it a good choice. It is also worth installing Caffe2 so you can verify that the exported model produces the same outputs when loaded in a different backend. With conda, this is just a couple of commands:
conda install pytorch-nightly-cpu -c pytorch
conda install -c conda-forge onnx
2. MobileNetV3, adapted from rwightman's code; the pretrained weights were also trained by him. The modified code is available here.
First, run the model on a single image:
import time
import numpy as np
import torch
from PIL import Image
import mobilenetv3
import onnx
import caffe2.python.onnx.backend as backend
src_image = 'cat.jpg'
input_size = 224
mean_vals = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32).view(3,1,1)
std_vals = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32).view(3,1,1)
imagenet_ids = []
with open("synset_words.txt", "r") as f:
    for line in f.readlines():
        imagenet_ids.append(line.strip())
###############################prepare model####################################
infer_device = torch.device('cpu')
res_model = mobilenetv3.MobileNetV3()
state_dict = torch.load('./mobilenetv3x1.0.pth', map_location=lambda storage, loc: storage)
res_model.load_state_dict(state_dict)
res_model.to(infer_device)
res_model.eval()
################################prepare test image#####################################
pil_im = Image.open(src_image).convert('RGB')
pil_im_resize = pil_im.resize((input_size, input_size), Image.BILINEAR)
origin_im_tensor = torch.ByteTensor(torch.ByteStorage.from_buffer(pil_im_resize.tobytes()))
origin_im_tensor = origin_im_tensor.view(input_size, input_size, 3)
origin_im_tensor = origin_im_tensor.transpose(0, 1).transpose(0, 2).contiguous()
origin_im_tensor = (origin_im_tensor.float()/255 - mean_vals)/std_vals
origin_im_tensor = origin_im_tensor.unsqueeze(0)
###########################test################################
t1 = time.time()
with torch.no_grad():
    pred = res_model(origin_im_tensor.to(infer_device))
predidx = torch.argmax(pred, dim=1)
t2 = time.time()
print(t2 - t1)
print("predict result: ", imagenet_ids[predidx])
If all goes well, the prediction should be a cat:
0.03892230987548828
predict result: n02123597 Siamese cat, Siamese
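As an aside, the byte-buffer preprocessing above (HWC -> CHW, scaling to [0, 1], channel-wise normalization) can be written more transparently in NumPy. A sketch using a random array in place of the decoded PIL image:

```python
import numpy as np

# ImageNet mean/std, same values as the torch tensors above
mean_vals = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(3, 1, 1)
std_vals = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(3, 1, 1)

# Stand-in for the resized PIL image: uint8, height x width x channels (HWC)
hwc = np.random.randint(0, 256, size=(224, 224, 3), dtype=np.uint8)

# HWC -> CHW, the same layout the two transpose() calls produce above
chw = hwc.transpose(2, 0, 1)

# Scale to [0, 1], normalize channel-wise, then add a batch dimension
x = (chw.astype(np.float32) / 255.0 - mean_vals) / std_vals
x = x[np.newaxis, ...]

print(x.shape)  # (1, 3, 224, 224)
```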
3. Export to ONNX from PyTorch and verify the result; the code continues from the section above.
##################export###############
output_onnx = 'mobilenetv3.onnx'
x = origin_im_tensor
print("==> Exporting model to ONNX format at '{}'".format(output_onnx))
input_names = ["input0"]
output_names = ["output0"]
torch_out = torch.onnx._export(res_model, x, output_onnx, export_params=True, verbose=False,
                               input_names=input_names, output_names=output_names)