Online Hard Negative Mining on Pytorch

本文介绍了一种在线硬例挖掘(OHEM)的方法,该方法通过挑选出模型预测较为困难的样本进行损失计算,以提高模型训练的效率和效果。在PyTorch框架下实现了一个NLL_OHEM类,用于自定义损失函数,实现在线硬例挖掘。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

转载自: http://www.erogol.com/online-hard-example-mining-pytorch/

import torch as th                                                                 
                                                                                   
                                                                                   
class NLL_OHEM(th.nn.NLLLoss):                                                     
    """ Online hard example mining. 
    Needs input from nn.LogSotmax() """                                             
                                                                                   
    def __init__(self, ratio):      
        super(NLL_OHEM, self).__init__(None, True)                                 
        self.ratio = ratio                                                         
                                                                                   
    def forward(self, x, y, ratio=None):                                           
        if ratio is not None:                                                      
            self.ratio = ratio                                                     
        num_inst = x.size(0)                                                       
        num_hns = int(self.ratio * num_inst)                                       
        x_ = x.clone()                                                             
        inst_losses = th.autograd.Variable(th.zeros(num_inst)).cuda()              
        for idx, label in enumerate(y.data):                                       
            inst_losses[idx] = -x_.data[idx, label]                                 
        #loss_incs = -x_.sum(1)                                                    
        _, idxs = inst_losses.topk(num_hns)                                        
        x_hn = x.index_select(0, idxs)                                             
        y_hn = y.index_select(0, idxs)                                             
        return th.nn.functional.nll_loss(x_hn, y_hn)     

Q1. 这里的理解应该是挑出不那么好的数据,再计算loss,而不是通过loss的判定,挑出一组好的数据再去计算新的loss?

Q2. 之前在pytorch中将20行:inst_losses[idx] = -x_.data[idx, label] 替换为 criterion 代码直接退出,不会报任何错误

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值