0 参考文章
本文是在《pointnet C++推理部署–libtorch框架》的基础上针对pointnet++做的修改。
下面我以pointnet++的partseg任务为例对修改进行说明,环境配置与上文一致。
1 .pth转.pt脚本
import torch
import pointnet2_part_seg_msg
point_num = 2048
class_num = 16
part_num = 50
normal_channel = False
def to_categorical(y, class_num):
""" 1-hot encodes a tensor """
new_y = torch.eye(class_num)[y.cpu().data.numpy(),]
if (y.is_cuda):
return new_y.cuda()
return new_y
model = pointnet2_part_seg_msg.get_model(part_num, normal_channel)
model = model.cuda() #cpu版本需注释此句
model.eval()
checkpoint = torch.load('best_model.pth')
model.load_state_dict(checkpoint['model_state_dict'])
x = (torch.rand(1, 6, point_num) if normal_channel else torch.rand(1, 3, point_num))
x = x.cuda() #cpu版本需注释此句
#label = torch.randint(0, 1, (1, 1))
#label = label.cuda() #cpu版本需注释此句
#traced_script_module = torch.jit.trace(model, (x, to_categorical(label, class_num)))
traced_script_module = torch.jit.script(model)
traced_script_module.save('best_model_script.pt')
使用script函数的原因是它执行时报错会提示错误信息,而trace函数报错时只显示大量代码而无错误提示。
2 pointnet2_utils.py修改
为了traced_script_module = torch.jit.script(model)
函数正常执行,需要对pointnet2_utils.py文件进行修改。
修改的问题分为3类:
- 函数参数类型未指定;
- ModuleList使用变量下标访问元素;
- 函数返回值类型不固定。
修改后代码:
import torch
import torch.nn as nn
import torch.nn.functional as F
from time import time
import numpy as np
def timeit(tag, t):
print("{}: {}s".format(tag, time() - t))
return time()
def pc_normalize(pc):
l = pc.shape[0]
centroid = np.mean(pc, axis=0)
pc = pc - centroid
m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
pc = pc / m
return pc
def square_distance(src, dst):
"""
Calculate Euclid distance between each two points.
src^T * dst = xn * xm + yn * ym + zn * zm;
sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
= sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
Input:
src: source points, [B, N, C]
dst: target points, [B, M, C]
Output:
dist: per-point square distance, [B, N, M]
"""
B, N, _ = src.shape
_, M, _ = dst.shape
dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
dist += torch.sum(src ** 2, -1).view(B, N, 1)
dist += torch.sum(dst ** 2, -1).view(B, 1, M)
return dist
def index_points(points, idx):
"""
Input:
points: input points data, [B, N, C]
idx: sample index data, [B, S]
Return:
new_points:, indexed points data, [B, S, C]
"""
device = points.device
B = points.shape[0]
view_shape = list(idx.shape)
# view_shape[1:] = [1] * (len(view_shape) - 1)
new_view_shape = [view_shape[0]] + [1] * (len(view_shape) - 1)
view_shape = new_view_shape
repeat_shape = list(idx.shape)
repeat_shape