BiDet代码解析1:随机采样定位(Stochastic Sampling Localization,SSL)

代码库:https://github.com/ZiweiWangTHU/BiDet
源代码(SSD训练为例): ssd/train_bidet_ssd.py

        # sample loc data from predicted miu and sigma
        normal_dist = torch.randn(loc_data.size(0), loc_data.size(1), 4).float().cuda()
        log_sigma_2 = loc_data[:, :, :4]
        miu = loc_data[:, :, 4:]
        sigma = torch.exp(log_sigma_2 / 2.)
        sample_loc_data = normal_dist * sigma * args.sigma + miu
        loc_data = sample_loc_data

这段代码是BiDet算法中的一个重要部分,它实现了随机采样定位(Stochastic Sampling Localization,SSL)的功能。SSL的目的是在训练过程中增加定位分支的随机性,从而提高二值化网络的泛化能力和定位精度。

具体来说,这段代码的作用是从预测的均值(miu)和方差(sigma)中采样出一个定位数据(sample_loc_data),用来替代原始的定位数据(loc_data)。这样做的好处是可以避免二值化网络过拟合到一个固定的定位值,而是让它能够适应不同的定位情况。

这段代码中,normal_dist是一个标准正态分布,log_sigma_2是预测的方差的对数,miu是预测的均值,sigma是预测的方差的平方根,args.sigma是一个超参数,用来控制采样范围。sample_loc_data是根据公式 s a m p l e _ l o c _ d a t a = n o r m a l _ d i s t ∗ s i g m a ∗ a r g s . s i g m a + m i u sample\_loc\_data = normal\_dist * sigma * args.sigma + miu sample_loc_data=normal_distsigmaargs.sigma+miu计算出来的。loc_data是将sample_loc_data赋值给原始的定位数据。

sample_loc_data和loc_data的区别是,sample_loc_data是从预测的均值和方差中随机采样出来的一个定位数据,而loc_data是原始的定位数据,它包含了预测的方差和均值。在训练过程中,用sample_loc_data替代loc_data,可以增加定位分支的随机性,提高二值化网络的泛化能力和定位精度。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值