Pointnet++在Pytorch下模型(.pth)转Libtorch下模型(.pt)

0 参考文章

本文是在《pointnet C++推理部署–libtorch框架》的基础上针对pointnet++做的修改。

下面我以pointnet++的partseg任务为例对修改进行说明,环境配置与上文一致。

1 .pth转.pt脚本

import torch
import pointnet2_part_seg_msg

point_num = 2048
class_num = 16
part_num = 50
normal_channel = False

def to_categorical(y, class_num):
    """ 1-hot encodes a tensor """
    new_y = torch.eye(class_num)[y.cpu().data.numpy(),]
    if (y.is_cuda):
        return new_y.cuda()
    return new_y

model = pointnet2_part_seg_msg.get_model(part_num, normal_channel)
model = model.cuda() #cpu版本需注释此句
model.eval()
checkpoint = torch.load('best_model.pth')
model.load_state_dict(checkpoint['model_state_dict'])

x = (torch.rand(1, 6, point_num) if normal_channel else torch.rand(1, 3, point_num))
x = x.cuda() #cpu版本需注释此句
#label = torch.randint(0, 1, (1, 1))
#label = label.cuda() #cpu版本需注释此句

#traced_script_module = torch.jit.trace(model, (x, to_categorical(label, class_num)))
traced_script_module = torch.jit.script(model)
traced_script_module.save('best_model_script.pt')

使用script函数的原因是它执行时报错会提示错误信息,而trace函数报错时只显示大量代码而无错误提示。

2 pointnet2_utils.py修改

为了traced_script_module = torch.jit.script(model)函数正常执行,需要对pointnet2_utils.py文件进行修改。

修改的问题分为3类:

  1. 函数参数类型未指定;
  2. ModuleList使用变量下标访问元素;
  3. 函数返回值类型不固定。

修改后代码:

import torch
import torch.nn as nn
import torch.nn.functional as F
from time import time
import numpy as np


def timeit(tag, t):
    print("{}: {}s".format(tag, time() - t))
    return time()


def pc_normalize(pc):
    l = pc.shape[0]
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
    pc = pc / m
    return pc


def square_distance(src, dst):
    """
    Calculate Euclid distance between each two points.

    src^T * dst = xn * xm + yn * ym + zn * zm;
    sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
    sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
    dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
         = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst

    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    dist += torch.sum(src ** 2, -1).view(B, N, 1)
    dist += torch.sum(dst ** 2, -1).view(B, 1, M)
    return dist


def index_points(points, idx):
    """

    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S]
    Return:
        new_points:, indexed points data, [B, S, C]
    """
    device = points.device
    B = points.shape[0]
    view_shape = list(idx.shape)
    # view_shape[1:] = [1] * (len(view_shape) - 1)
    new_view_shape = [view_shape[0]] + [1] * (len(view_shape) - 1)
    view_shape = new_view_shape

    repeat_shape = list(idx.shape)
    repeat_shape
评论 18
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值