SparseConvolution(稀疏卷积)源码阅读

class SparseConvolution(SparseModule):
    """Generic N-dimensional sparse convolution over a ``SparseConvTensor``.

    The weight is stored in "HWIO" order, i.e. with shape
    ``(*kernel_size, in_channels, out_channels)``.

    Args:
        ndim: Number of spatial dimensions (e.g. 2 or 3).
        in_channels: Number of input feature channels.
        out_channels: Number of output feature channels.
        kernel_size, stride, padding, dilation: A single int (broadcast to
            every spatial dimension) or a per-dimension list/tuple. Within a
            dimension, stride and dilation may not both differ from 1.
        groups: Only ``1`` is supported.
        bias: Whether to add a learnable per-output-channel bias.
        subm: Submanifold convolution (not sub-sampling): the output keeps the
            input's spatial shape.
        output_padding: Extra size used by the transposed output-shape rule;
            int or per-dimension list/tuple.
        transposed: Use the transposed ("deconvolution") output-size rule.
        inverse: Undo a paired convolution by reusing the indice pairs it
            cached under ``indice_key``; the paired conv must use the same
            kernel size.
        indice_key: Key under which indice pairs are cached in / looked up
            from ``input.indice_dict`` so other layers can reuse them.
        fused_bn: Use ``ops.fused_indice_conv``, which consumes the bias
            itself; requires ``bias=True``.
    """

    def __init__(
        self,
        ndim,
        in_channels,
        out_channels,
        kernel_size=3,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        subm=False,
        output_padding=0,
        transposed=False,
        inverse=False,
        indice_key=None,
        fused_bn=False,
    ):
        super(SparseConvolution, self).__init__()
        assert groups == 1  # grouped sparse convolution is not implemented

        def _per_dim(value):
            # Broadcast a scalar hyper-parameter to one entry per spatial dim.
            return value if isinstance(value, (list, tuple)) else [value] * ndim

        kernel_size = _per_dim(kernel_size)
        stride = _per_dim(stride)
        padding = _per_dim(padding)
        dilation = _per_dim(dilation)
        output_padding = _per_dim(output_padding)

        # A dimension may be strided or dilated, but not both at once.
        for d, s in zip(dilation, stride):
            assert s == 1 or d == 1, "don't support this."

        self.ndim = ndim
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.conv1x1 = np.prod(kernel_size) == 1
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.transposed = transposed
        self.inverse = inverse
        self.output_padding = output_padding
        self.groups = groups
        self.subm = subm
        self.indice_key = indice_key
        self.fused_bn = fused_bn

        # HWIO layout: (*kernel_size, in_channels, out_channels).
        self.weight = Parameter(torch.empty(*kernel_size, in_channels, out_channels))
        if bias:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise weight and bias with the kaiming-uniform (a=sqrt(5)) bound.

        ``init.kaiming_uniform_`` derives fan-in assuming PyTorch's
        ``(out, in, *kernel)`` weight layout, which does not match the HWIO
        weight used here. Fan-in is therefore taken from
        ``_calculate_fan_in_and_fan_out_hwio`` (assumed, from its name and its
        use for the bias, to understand HWIO) and the same uniform bound that
        ``kaiming_uniform_`` would use, ``gain * sqrt(3 / fan_in)``, is
        applied directly.
        """
        fan_in, _ = _calculate_fan_in_and_fan_out_hwio(self.weight)
        gain = init.calculate_gain("leaky_relu", math.sqrt(5))
        weight_bound = gain * math.sqrt(3.0 / fan_in)
        init.uniform_(self.weight, -weight_bound, weight_bound)
        if self.bias is not None:
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def _is_identity_conv1x1(self):
        """Return True when this is a 1x1 conv that leaves every site in place.

        Only then may it be computed as a dense matmul on ``input.features``
        while reusing ``input.indices`` and ``input.spatial_shape``. A strided,
        padded, or output-padded 1x1 conv changes the output grid and must go
        through the indice-pair path instead.
        """
        if not self.conv1x1:
            return False
        if self.subm or self.inverse:
            # subm keeps the input's spatial shape; inverse 1x1 convs keep the
            # pre-existing fast-path behaviour.
            # NOTE(review): an inverse 1x1 paired with a strided 1x1 conv does
            # not restore the partner's indices here — confirm if needed.
            return True
        unchanged = all(s == 1 for s in self.stride) and all(p == 0 for p in self.padding)
        if self.transposed:
            unchanged = unchanged and all(op == 0 for op in self.output_padding)
        return unchanged

    def _forward_conv1x1(self, input):
        """Apply an in-place 1x1 conv as ``features @ weight`` on the same sites."""
        weight_2d = self.weight.view(self.in_channels, self.out_channels)
        features = torch.mm(input.features, weight_2d)
        if self.bias is not None:
            features += self.bias
        out_tensor = SparseConvTensor(
            features, input.indices, input.spatial_shape, input.batch_size
        )
        out_tensor.indice_dict = input.indice_dict
        out_tensor.grid = input.grid
        return out_tensor

    def forward(self, input):
        """Convolve the sparse tensor ``input``.

        Returns a new ``SparseConvTensor`` that shares ``indice_dict`` and
        ``grid`` with ``input``. Newly computed indice pairs are cached in
        ``input.indice_dict[self.indice_key]`` as
        ``(outids, indices, indice_pairs, indice_pair_num, spatial_shape)``.
        """
        assert isinstance(input, SparseConvTensor)
        features = input.features
        device = features.device
        indices = input.indices
        spatial_shape = input.spatial_shape
        batch_size = input.batch_size

        if self.subm:
            out_spatial_shape = spatial_shape
        elif self.transposed:
            out_spatial_shape = ops.get_deconv_output_size(
                spatial_shape,
                self.kernel_size,
                self.stride,
                self.padding,
                self.dilation,
                self.output_padding,
            )
        else:
            out_spatial_shape = ops.get_conv_output_size(
                spatial_shape, self.kernel_size, self.stride, self.padding, self.dilation
            )

        if self._is_identity_conv1x1():
            return self._forward_conv1x1(input)

        datas = input.find_indice_pair(self.indice_key)
        if self.inverse:
            assert datas is not None and self.indice_key is not None
            # The partner's input indices / spatial shape (slots 1 and 4 of the
            # cached tuple) become this layer's output indices / shape.
            _, outids, indice_pairs, indice_pair_num, out_spatial_shape = datas
            assert indice_pairs.shape[0] == np.prod(
                self.kernel_size
            ), "inverse conv must have same kernel size as its couple conv"
        elif self.indice_key is not None and datas is not None:
            # Reuse pairs cached by an earlier layer sharing this indice_key.
            outids, _, indice_pairs, indice_pair_num, _ = datas
        else:
            outids, indice_pairs, indice_pair_num = ops.get_indice_pairs(
                indices,
                batch_size,
                spatial_shape,
                self.kernel_size,
                self.stride,
                self.padding,
                self.dilation,
                self.output_padding,
                self.subm,
                self.transposed,
                grid=input.grid,
            )
            input.indice_dict[self.indice_key] = (
                outids,
                indices,
                indice_pairs,
                indice_pair_num,
                spatial_shape,
            )

        indice_pairs = indice_pairs.to(device)
        num_out = outids.shape[0]
        if self.fused_bn:
            assert self.bias is not None
            out_features = ops.fused_indice_conv(
                features,
                self.weight,
                self.bias,
                indice_pairs,
                indice_pair_num,
                num_out,
                self.inverse,
                self.subm,
            )
        else:
            if self.subm:
                conv_fn = Fsp.indice_subm_conv
            elif self.inverse:
                conv_fn = Fsp.indice_inverse_conv
            else:
                conv_fn = Fsp.indice_conv
            out_features = conv_fn(features, self.weight, indice_pairs, indice_pair_num, num_out)
            if self.bias is not None:
                out_features += self.bias

        out_tensor = SparseConvTensor(out_features, outids, out_spatial_shape, batch_size)
        out_tensor.indice_dict = input.indice_dict
        out_tensor.grid = input.grid
        return out_tensor
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值