def class_balanced_sigmoid_cross_entropy(logits, label, name='cross_entropy_loss'):
    """
    The class-balanced sigmoid cross entropy loss,
    as in `Holistically-Nested Edge Detection
    <http://arxiv.org/abs/1504.06375>`_.
    This is more numerically stable than class_balanced_cross_entropy.

    With ``beta`` being the fraction of 0 entries in ``label``, the
    per-element loss of entries labelled 1 is weighted by ``beta`` and the
    loss of entries labelled 0 by ``1 - beta``; the result is the mean over
    all elements.

    NOTE: if ``label`` contains no 1 at all, ``1 - beta`` is 0 and the
    positive weight is a division by zero, so the returned value is not
    meaningful. Callers should make sure at least one positive is present.

    :param logits: the logits (pre-sigmoid predictions).
    :param label: the ground truth in {0,1}, of the same shape as logits.
    :param name: name given to the final reduction op.
    :returns: a scalar. class-balanced cross entropy loss
    """
    y = tf.cast(label, tf.float32)

    count_neg = tf.reduce_sum(1. - y)  # the number of 0 in y
    count_pos = tf.reduce_sum(y)       # the number of 1 in y (expected to be the minority)
    beta = count_neg / (count_neg + count_pos)

    # weighted_cross_entropy_with_logits multiplies only the positive term by
    # pos_weight, so beta / (1 - beta) here and the (1 - beta) factor below
    # together give weights beta (positives) and 1 - beta (negatives).
    pos_weight = beta / (1 - beta)

    # The ground truth must be the first argument and the logits the second:
    # since TensorFlow 1.0 the positional order is (targets/labels, logits,
    # pos_weight). Passing `logits` by keyword keeps the call unambiguous, and
    # passing the label positionally works whether the parameter is named
    # `targets` (TF 1.x) or `labels` (TF 2.x).
    cost = tf.nn.weighted_cross_entropy_with_logits(
        y, logits=logits, pos_weight=pos_weight)
    cost = tf.reduce_mean(cost * (1 - beta), name=name)
    return cost
深度学习【14】代价敏感损失函数
最新推荐文章于 2023-12-31 01:10:51 发布