如何部署深度学习模型
使用gradio包来进行简单的网页搭建。
import gradio as gr
第一将调用模型的方法,与网页上点击run图标进行绑定。
在app方法中部署加载模型的html页面。
在api方法中对模型进行初始化。
前置工作需要观察模型的权重文件的类型不同类型使用不同的加载方式。下面的方法加载了后缀为.pt的模型参数。
第二步在调用模型的方法中首先加载模型的参数,然后使用模型来解决任务。
model_fair_7 = torchvision.models.resnet34(pretrained=False)
model_fair_7.fc = nn.Linear(model_fair_7.fc.in_features, 18)
model_fair_7.load_state_dict(
torch.load(path_ckpt))
如果需要使用预训练模型可以在huggingface中考虑下载。
关于模型中图片为什么显示黑色的问题:
imgs[20,1,584,565]通过下述两种方式取到的结果意义是不一样的,主要是和imgs的构造有关,它是由
imgs = np.expand_dims(img,0) if imgs is None else np.concatenate((imgs,np.expand_dims(img,0)))在第一个维度一张张图拼接起来的。
正确的做法如下:
//for循环取出pred_imgs中的元素就是[1,584,565]
for i in range(self.test_imgs.shape[0]):
total_img = concat_result(self.test_imgs[i],self.pred_imgs[i])
concat_result:
//[1,584,565]---->[584,565,1]
pred_res = data = np.transpose(pred_res,(1,2,0))
if ori_img.shape[0]==3:
pred_res = np.repeat((pred_res*255).astype(np.uint8),repeats=3,axis=2)
pred_res = pred_res.astype(np.uint8)
img_pil = Image.fromarray(pred_res)
img_pil.show()
错误的做法:
pred_imgs = pred_imgs[0, :, :, :].astype(np.uint8)
问题的排查方式:
直接排查原始代码关键位置图片是否能够正常显示:
模型预测输出部分、各项后处理操作处理完成之后。