卷积模块重参数化

卷积模块的重参数化实现。
class SeqConv3x3(nn.Module):
def init(self, seq_type, inp_planes, out_planes, depth_multiplier):
super(SeqConv3x3, self).init()

    self.type = seq_type
    self.inp_planes = inp_planes
    self.out_planes = out_planes

    if self.type == 'conv1x1-conv3x3':
        self.mid_planes = int(out_planes * depth_multiplier)
        conv0 = torch.nn.Conv2d(self.inp_planes, self.mid_planes, kernel_size=1, padding=0)
        self.k0 = conv0.weight
        self.b0 = conv0.bias

        conv1 = torch.nn.Conv2d(self.mid_planes, self.out_planes, kernel_size=3)
        self.k1 = conv1.weight
        self.b1 = conv1.bias

    elif self.type == 'conv1x1-sobelx':
        conv0 = torch.nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=1, padding=0)
        self.k0 = conv0.weight
        self.b0 = conv0.bias

        # init scale & bias
        scale = torch.randn(size=(self.out_planes, 1, 1, 1)) * 1e-3
        self.scale = nn.Parameter(scale)
        # bias = 0.0
        # bias = [bias for c in range(self.out_planes)]
        # bias = torch.FloatTensor(bias)
        bias = torch.randn(self.out_planes) * 1e-3
        bias = torch.reshape(bias, (self.out_planes,))
        self.bias = nn.Parameter(bias)
        # init mask
        self.mask = torch.zeros((self.out_planes, 1, 3, 3), dtype=torch.float32)
        for i in range(self.out_planes):
            self.mask[i, 0, 0, 0] = 1.0
            self.mask[i, 0, 1, 0] = 2.0
            self.mask[i, 0, 2, 0] = 1.0
            self.mask[i, 0, 0, 2] = -1.0
            self.mask[i, 0, 1, 2] = -2.0
            self.mask[i, 0, 2, 2] = -1.0
        self.mask = nn.Parameter(data=self.mask, requires_grad=False)

    elif self.type == 'conv1x1-sobely':
        conv0 = torch.nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=1, padding=0)
        sel
评论 1
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值