文章目录
使用 PyTorch 的 C++(CUDA)扩展实现卷积算子
卷积算子
在实现卷积算子时,将卷积算子等价转换为两个阶段:收集(collecting)和线性变换(linear transform)。在收集阶段,“image to column” 和 “column to image” 两个过程分别完成:在前向传播中将输入 $X\in \mathbb{R}^{B\times C\times H\times W}$ 转换为 $H \in \mathbb{R}^{B\times (C\cdot K_h\cdot K_w)\times H_{out}\times W_{out}}$;在反向传播中将传递得到的 $H_{grad} \in \mathbb{R}^{B\times (C\cdot K_h\cdot K_w)\times H_{out}\times W_{out}}$ 逆向转换为 $X_{grad}\in \mathbb{R}^{B\times C\times H\times W}$。其中输出尺寸为 $H_{out} = \left\lfloor \frac{H + 2p_h - d_h (K_h - 1) - 1}{s_h} \right\rfloor + 1$($W_{out}$ 同理),$p$、$s$、$d$ 分别为填充、步长和膨胀系数。在线性变换阶段,卷积核权重对前向传播中得到的 $H$ 进行加权,或根据 $H_{grad}$ 得到对应的权重梯度 $W_{grad}$。
代码
这里根据设备不同提供两种实现方式,一种是在CPU设备上串行实现,另一种是在GPU设备上并行实现。
CPU 串行版本(Python)
from torch.autograd import Function
# 用于找到对应的数据位置坐标
def data_index_init(offset, clist, *args):
    """Split a flat ``offset`` into per-dimension coordinates (mixed radix).

    ``args`` are the dimension sizes, listed outermost first, e.g.
    ``(c, kernel_h, kernel_w)``.  One coordinate per dimension is appended to
    ``clist``, innermost dimension first, so ``clist[::-1]`` yields the
    coordinates in the same order as ``args``.

    Returns:
        The part of ``offset`` left after dividing out every dimension size
        (0 when ``0 <= offset < prod(args)``; ``offset`` unchanged when no
        sizes are given).
    """
    remaining = offset
    # Peel digits off starting from the fastest-varying (last) dimension.
    for size in reversed(args):
        clist.append(remaining % size)
        remaining = remaining // size
    return remaining
# image to column
def im2col_cpu(img,  # image tensor
               kernel_h, kernel_w,  # kernel size
               pad_h, pad_w,  # zero padding
               stride_h, stride_w,  # stride
               dilation_h, dilation_w):  # kernel dilation
    """Unfold a batch of images into columns (serial reference version).

    Output channel ``col_c`` corresponds to one (input channel, kernel row,
    kernel column) triple, with the input channel varying slowest and the
    kernel column fastest.  ``col[:, col_c, i, j]`` holds the input pixel
    that this kernel tap lands on when the kernel is placed at output
    position ``(i, j)``; taps that fall into the padding stay zero.

    Args:
        img: tensor of shape (N, C, H, W).
        kernel_h, kernel_w: kernel height / width.
        pad_h, pad_w: implicit zero padding added on each side.
        stride_h, stride_w: kernel stride.
        dilation_h, dilation_w: spacing between kernel taps.

    Returns:
        Tensor of shape (N, C*kernel_h*kernel_w, H_out, W_out) with the same
        dtype and device as ``img``, where
        ``H_out = (H + 2*pad_h - dilation_h*(kernel_h - 1) - 1) // stride_h + 1``
        and ``W_out`` is defined analogously.
    """
    b, c, h, w = img.shape
    out_h = (h - (dilation_h * (kernel_h - 1) + 1) + 2 * pad_h) // stride_h + 1
    out_w = (w - (dilation_w * (kernel_w - 1) + 1) + 2 * pad_w) // stride_w + 1
    out_c = c * kernel_h * kernel_w
    # new_zeros keeps img's dtype and allocates directly on img's device,
    # instead of creating a float32 CPU tensor and copying it over.
    col = img.new_zeros((b, out_c, out_h, out_w))
    kernel_area = kernel_h * kernel_w
    for col_c in range(out_c):
        # Split col_c into (input channel, kernel row, kernel column); this is
        # the same decomposition as data_index_init(col_c, clist, c,
        # kernel_h, kernel_w) followed by reversing clist.
        c_im, tap = divmod(col_c, kernel_area)
        offset_h, offset_w = divmod(tap, kernel_w)
        for col_h in range(out_h):
            h_im = col_h * stride_h - pad_h + offset_h * dilation_h
            if not 0 <= h_im < h:
                continue  # the whole output row reads from the padding
            for col_w in range(out_w):
                w_im = col_w * stride_w - pad_w + offset_w * dilation_w
                if 0 <= w_im < w:
                    col[:, col_c, col_h, col_w] = img[:, c_im, h_im, w_im]
    return col
# column to image
def col2im_cpu(col_gd,
c, h, w,
kernel_h, kernel_w,
pad_h, pad_w,
stride_h, stride_w,
dilation_h, dilation_w):
'''
col: (N, C', H', W')
'''
b,out_c,out_h,out_w = col_gd.shape
img_gd = torch.zeros(b, c, h, w)
for col_c in range(out_c):
clist = []
data_index_init(col_c, clist, c, kernel_h, kernel_w)
c_im, offset_h, offset_w = clist[