pytorch利用 CFFI 进行 C 语言扩展。包括两个基本的步骤(docs):
- 编写 C 代码;
- python 调用 C 代码,实现相应的 Function 或 Module。
C语言拓展需要上文提到的torch.util.ffi
模块,比如官网给出的定义一个加法运算(一个栗子),这里以定义ReLU函数为例:
1. C 代码
pytorch C 的基本数据结构是 THTensor(THFloatTensor、THByteTensor等)。我们以简单的 ReLU 函数为例,示例编写 C 。
y=ReLU(x)=max(x,0)
Function 需要定义前向和后向两个方向的操作,因此,C 代码要实现相应的功能。
1.1 头文件 ext_lib.h
/* ext_lib.h */
int relu_forward(THFloatTensor *input, THFloatTensor *output);
int relu_backward(THFloatTensor *grad_output, THFloatTensor *input, THFloatTensor *grad_input);
1.1 函数实现 ext_lib.c
/* ext_lib.c */
#include <TH/TH.h>
int relu_forward(THFloatTensor *input, THFloatTensor *output)
{
THFloatTensor_resizeAs(output, input);
THFloatTensor_clamp(output, input, 0, INFINITY);
return 1;
}
int relu_backward(THFloatTensor *grad_output, THFloatTensor *input, THFloatTensor *grad_input)
{
THFloatTensor_resizeAs(grad_input, grad_output);
THFloatTensor_zero(grad_input);
THLongStorage* size = THFloatTensor_newSizeOf(grad_output);
THLongStorage *stride = THFloatTensor_newStrideOf(grad_output);
THByteTensor *mask = THByteTensor_newWithSize(size, stride);
THFloatTensor_geValue(mask, input, 0);
THFloatTensor_maskedCopy(grad_input, mask, grad_output);
return 1;
}