0x1. 前言
在ResNet中(https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py),关于BatchNorm的调用一共有两种模式,第一种是ReLU接在BN之后:
out = self.bn1(out)
out = self.relu(out)
另外一种模式是残差结构引入的 BNAddReLU 的模式:
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
我们知道在 CUDA 优化中常见的一个技巧是将一些ElementWise的算子融合到之前的计算密集型算子如卷积,矩阵乘等。在OneFlow中针对上述两种情况并且cudnn无法fuse时分别进行了fuse和优化,本篇文章就来解析一下这里的代码实现,体会其中的CUDA优化技巧。这里的源码开源在OneFlow的github仓库:https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/user/kernels/normalization_kernel.cu 。如果本文对你产生了启发,不妨为OneFlow投个star。
0x2. 代码解析
0x2.1 CUDNN BatchNorm算子的实现和局限
我们先来看一下OneFlow中是如何使用CUDNN库实现BatchNorm算子的。代码见:https://github.com/Oneflow-Inc/oneflow/blob/master/oneflow/user/kernels/normalization_kernel.cu#L31-L244 。这段代码中首先实现了一个getCudnnBatchNormMode工具函数:
cudnnBatchNormMode_t getCudnnBatchNormMode(const int64_t dim) {
if (dim == 2) {
return CUDNN_BATCHNORM_PER_ACTIVATION;
} else if (ParseBooleanFromEnv("ONEFLOW_ENABLE_NHWC", false)) {
return CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
} else {
// NOTE(Liang Depeng): The new CUDNN_BATCHNORM_SPATIAL_PERSISTENT mode was
// introduced in CuDNN 7 for performance optimization, but it results in
// accuracy losses in convolution models such as ResNeXt-101 and
// video R(2+1)D. We will fall back to the normal CUDNN_BATCHNORM_SPATIAL
return CUDNN_BATCHNORM_SPATIAL;
}
}
这里的dim表示输入Tensor的维度,比如形状为(1,3,224,224)(1, 3, 224, 224)(1,3,224,224)的输入Tensor,这里的维度就是4。然后这里涉及到三种不同的cudnnBatchNormMode_t,我们看一下CUDNN的文档(https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnBatchNormMode_t):

可以看到 CUDNN_BATCHNORM_PER_ACTIVATION 被用于非卷积层,在OneFlow中只有当输入Tensor的维度为2时才选取这种模式。而CUDNN_BATCHNORM_SPATIAL_PERSISTENT这种模式只有当输入Tensor的数据排布为NHWC方式时才会启用。而对于其它的模式,在OneFlow中一律选取CUDNN_BATCHNORM_SPATIAL模式。
接下来阅读一下 InferDimSizeAndDataFormat 函数:
void InferDimSizeAndDataFormat(const ShapeView& x_shape, const int32_t axis, int32_t* n, int32_t* c,
int32_t* h, int32_t* w, cudnnTensorFormat_t* format) {
if (x_shape.Count(axis + 1) == 1) {
if (axis == 0) {
*n = 1;
*h = 1;
} else {
*n

最低0.47元/天 解锁文章
448

被折叠的 条评论
为什么被折叠?



