前面的博客讲述了数据流的输入。数据经过网络之后,我们需要计算网络的损失:通过CrossEntropy(交叉熵)计算类别分类损失,通过SmoothL1计算边框回归损失。由于SSD需要考虑前景与背景样本的平衡,因此根据论文我们选择Positive:Negative=1:3。(查询包的位置:site-packages)
1、首先统计每张图中Positive(正样本)的个数,以及整个Batch中Positive的总数。
# Count the positive anchors of every (cls_pred, box_pred, cls_target,
# box_target) group, then add the counts up into one total that is later
# used to normalise the losses.
# NOTE(review): the list is initialised here because the snippet never
# defines it; drop this line if the surrounding code already does.
num_pos = []
for cp, bp, ct, bt in zip(cls_pred, box_pred, cls_target, box_target):
    # A class target greater than 0 marks a positive sample
    # (presumably 0 is the background class -- confirm against the target generator).
    pos_samples = (ct > 0)
    num_pos.append(pos_samples.sum())
# Convert each per-group count to a Python scalar before summing.
num_pos_all = sum(p.asscalar() for p in num_pos)
2、求解每张图片的loss
# Compute the classification loss, the box-regression loss and their weighted
# sum for every (cls_pred, box_pred, cls_target, box_target) group.
cls_losses = []
box_losses = []
sum_losses = []
for cp, bp, ct, bt in zip(cls_pred, box_pred, cls_target, box_target):
    # Log-probabilities over the last (class) axis.
    pred = nd.log_softmax(cp, axis=-1)
    # Mask of positive samples: class target greater than 0.
    pos = ct > 0
    # Negative log-probability of the target class, i.e. the per-sample
    # cross-entropy.
    cls_loss = -nd.pick(pred, ct, axis=-1, keepdims=False)
    # (pos - 1) is 0 for positives and -1 for negatives, so negatives are
    # ordered by descending loss; the double argsort turns that ordering into
    # a rank per element along axis 1 (hard-negative mining, OHEM-like).
    rank = (cls_loss * (pos - 1)).argsort(axis=1).argsort(axis=1)
    # Keep the top-ranked negatives: at least `_min_hard_negatives`, otherwise
    # `_negative_mining_ratio` times the number of positives in that row.
    hard_negative = rank < nd.maximum(
        self._min_hard_negatives,
        pos.sum(axis=1) * self._negative_mining_ratio).expand_dims(-1)
    # Mask out everything that is neither positive nor a hard negative.
    cls_loss = nd.where((pos + hard_negative) > 0, cls_loss, nd.zeros_like(cls_loss))
    # exclude=True: reduce over every axis except axis 0. The max(1., ...)
    # guards against division by zero when there are no positives at all.
    cls_losses.append(nd.sum(cls_loss, axis=0, exclude=True) / max(1., num_pos_all))

    bp = _reshape_like(nd, bp, bt)
    # Absolute difference between predicted and target boxes.
    box_loss = nd.abs(bp - bt)
    # Smooth-L1 with threshold rho: linear above rho, quadratic below it.
    box_loss = nd.where(box_loss > self._rho, box_loss - 0.5 * self._rho,
                        (0.5 / self._rho) * nd.square(box_loss))
    # Box loss only applies to positive samples.
    box_loss = box_loss * pos.expand_dims(axis=-1)
    box_losses.append(nd.sum(box_loss, axis=0, exclude=True) / max(1., num_pos_all))
    # Total loss: classification loss plus lambda-weighted box loss.
    sum_losses.append(cls_losses[-1] + self._lambd * box_losses[-1])