smooth_l1_loss: Python torch.nn.functional module, smooth_l1_loss() example source code - 编程字典

This snippet shows how to compute a detection loss in PyTorch using the functional module's smooth_l1_loss. It combines a SmoothL1Loss between location predictions (loc_preds) and encoded targets (loc_targets) with a CrossEntropyLoss between class-confidence predictions (conf_preds) and targets (conf_targets), in the style of SSD multibox detection.
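For reference, smooth_l1_loss with the default beta of 1 behaves like 0.5·x² when |x| < 1 and like |x| − 0.5 otherwise, which keeps gradients bounded for badly mispredicted boxes. A minimal standalone check of that piecewise behavior:

```python
import torch
import torch.nn.functional as F

# Smooth L1 with beta=1: 0.5*x**2 if |x| < 1, else |x| - 0.5.
pred = torch.tensor([0.5, 3.0])
target = torch.zeros(2)
per_elem = F.smooth_l1_loss(pred, target, reduction='none')
# per_elem[0] = 0.5 * 0.5**2 = 0.125  (quadratic region)
# per_elem[1] = 3.0 - 0.5 = 2.5       (linear region)
```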


# Assumes module-level: import torch; import torch.nn.functional as F
def forward(self, loc_preds, loc_targets, conf_preds, conf_targets):
    '''Compute the loss between (loc_preds, loc_targets) and (conf_preds, conf_targets).

    Args:
      loc_preds: (tensor) predicted locations, sized [batch_size, 8732, 4].
      loc_targets: (tensor) encoded target locations, sized [batch_size, 8732, 4].
      conf_preds: (tensor) predicted class confidences, sized [batch_size, 8732, num_classes].
      conf_targets: (tensor) encoded target classes, sized [batch_size, 8732].

    Returns:
      loss = SmoothL1Loss(loc_preds, loc_targets) + CrossEntropyLoss(conf_preds, conf_targets),
      returned as separate (loc_loss, conf_loss) tensors.
    '''
    batch_size, num_boxes, _ = loc_preds.size()

    pos = conf_targets > 0                  # [N,8732]; True where an anchor box matched a ground truth
    num_matched_boxes = pos.long().sum()
    if num_matched_boxes == 0:
        # No matched boxes in this batch: both losses are zero.
        return loc_preds.new_zeros(1), loc_preds.new_zeros(1)

    ################################################################
    # loc_loss = SmoothL1Loss(pos_loc_preds, pos_loc_targets)
    ################################################################
    pos_mask = pos.unsqueeze(2).expand_as(loc_preds)     # [N,8732,4]
    pos_loc_preds = loc_preds[pos_mask].view(-1, 4)      # [#pos,4]
    pos_loc_targets = loc_targets[pos_mask].view(-1, 4)  # [#pos,4]
    loc_loss = F.smooth_l1_loss(pos_loc_preds, pos_loc_targets, reduction='sum')

    ################################################################
    # conf_loss = CrossEntropyLoss(pos_conf_preds, pos_conf_targets)
    #           + CrossEntropyLoss(neg_conf_preds, neg_conf_targets)
    ################################################################
    # Per-box cross-entropy, used only to rank boxes for hard negative mining.
    conf_loss = self.cross_entropy_loss(conf_preds.view(-1, self.num_classes),
                                        conf_targets.view(-1))  # [N*8732,]
    conf_loss = conf_loss.view(-1, num_boxes)                   # [N,8732]
    neg = self.hard_negative_mining(conf_loss, pos)             # [N,8732]

    pos_mask = pos.unsqueeze(2).expand_as(conf_preds)    # [N,8732,num_classes]
    neg_mask = neg.unsqueeze(2).expand_as(conf_preds)    # [N,8732,num_classes]
    mask = pos_mask | neg_mask
    pos_and_neg = pos | neg
    preds = conf_preds[mask].view(-1, self.num_classes)  # [#pos+#neg,num_classes]
    targets = conf_targets[pos_and_neg]                  # [#pos+#neg,]
    conf_loss = F.cross_entropy(preds, targets, reduction='sum')

    # Normalize both losses by the number of matched (positive) boxes.
    loc_loss /= num_matched_boxes
    conf_loss /= num_matched_boxes
    return loc_loss, conf_loss
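The method above calls a helper, self.hard_negative_mining, whose body is not shown in this snippet. A common SSD-style implementation keeps only the negatives with the largest per-box loss, capped at a 3:1 negative-to-positive ratio per image. The sketch below is a reconstruction under that assumption; the name and signature mirror the call above but are not taken from the original source:

```python
import torch

def hard_negative_mining(conf_loss, pos, neg_ratio=3):
    """Select the hardest negatives, at most neg_ratio * #positives per image.

    conf_loss: [N, #boxes] per-box cross-entropy loss.
    pos: [N, #boxes] bool, True for matched (positive) boxes.
    Returns a [N, #boxes] bool mask of selected negative boxes.
    """
    conf_loss = conf_loss.clone()
    conf_loss[pos] = 0                           # positives are never picked as negatives
    _, idx = conf_loss.sort(1, descending=True)  # boxes ordered by loss, largest first
    _, rank = idx.sort(1)                        # rank[i, j] = position of box j in that order
    num_neg = neg_ratio * pos.long().sum(1, keepdim=True)  # [N,1] cap per image
    return rank < num_neg
```

The double sort is the standard trick for turning "top-k by loss" into a boolean mask without explicit indexing: the second sort recovers each box's rank in the loss ordering.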
