加载已编译的so文件,定义C函数的参数和返回类型,这里面传入的是torch.tensor数组,需要将数据转成ctypes.c_float格式才能被正确识别,有两种方式可以转换:
第一种:
1:将tensor数组转成list,计算数组长度
tensor = tensor.reshape(-1).tolist()
tensor_len = len(tensor)
2:确定要映射的数据类型,创建c数组
array_type = ctypes.c_float * tensor_len
c_array = array_type(*tensor)# 将 python列表转换为 C 指针
第二种:(速度快)
1: 将tensor数组转换成numpy数组
numpy_array = tensor.float().cpu().numpy()
2:利用numpy对ctypes的支持直接转换指针类型,速度很快
c_array = numpy_array.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
完整代码
def sh_fp16_to_fp8_v2(tensor, ebias):
import ctypes
from ctypes import cdll
# # 加载库文件
fp8_v2_func = cdll.LoadLibrary('/home/liuzhiwen/Desktop/fp8_v2.so')
# 定义C函数的参数和返回类型