def _listwise_loss(self, logits_in, labels_in, masks):
"""
Copied from Hanshu.
:param logits_in:
:param labels_in:
:param masks: B,B
:return:
"""
_EPSILON = 1e-10
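    # Broadcasting trick: subtracting an all-zero (1, B, 1) tensor from the
    # (B, 1, 1) expansion tiles the input into a (B, B, 1) matrix whose entry
    # [i, j] holds the value of item i; the squeeze drops the trailing axis.
    # The same pattern is applied to the labels below.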
logits_bb = tf.subtract(tf.expand_dims(logits_in, 1), tf.zeros_like(tf.expand_dims(logits_in, 0)))
logits_bb = tf.squeeze(logits_bb, [2])
logits_final = tf.where(masks, logits_bb, tf.log(_EPSILON) * tf.ones_like(logits_bb))
labels_bb = tf.subtract(tf.expand_dims(labels_in, 1), tf.zeros_like(tf.expand_dims(labels_in, 0)))
labels_bb = tf.squeeze(labels_bb, [2])
labels_final = tf.where(masks, labels_bb, tf.zeros_like(labels_bb))
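    # A list containing no positive label would make the normalization below
    # divide by zero, so its entries are padded with a small epsilon before
    # computing the target distribution. Note that nonzero_mask is 1-D, so
    # tf.where selects whole rows; assuming masks is a symmetric same-list
    # indicator, this is equivalent to selecting the affected lists.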
    label_sum = tf.reduce_sum(input_tensor=labels_final, axis=0, keepdims=True)
nonzero_mask = tf.greater(tf.reshape(label_sum, [-1]), 0.0)
padded_labels = tf.where(nonzero_mask, labels_final, _EPSILON * tf.ones_like(labels_final)) * tf.cast(
masks, dtype=tf.float32)
    padded_label_sum = tf.reduce_sum(input_tensor=padded_labels, axis=0, keepdims=True)
normalized_labels = padded_labels / padded_label_sum
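    # Per-column softmax over the scaled logits (the 0.5 factor acts as a
    # fixed temperature). Out-of-list entries were already pushed down to
    # log(_EPSILON) and are zeroed again by the mask, so each column's
    # softmax effectively runs over its own list only.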
exps = tf.exp(0.5 * logits_final) * tf.cast(masks, dtype=tf.float32)
softmax = tf.divide(exps, tf.reduce_sum(exps, axis=0))
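    # Cross-entropy between the normalized label distribution and the model
    # softmax, accumulated per column (one value per item's view of its list).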
losses = -tf.reduce_sum(
normalized_labels * tf.log(softmax + _EPSILON) * tf.cast(masks, dtype=tf.float32), axis=0)
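    # All columns belonging to the same list carry an identical loss value,
    # so dividing each column's loss by its list size and summing yields one
    # loss per list; summing 1/list_size over rows counts the lists.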
per_row_weights = tf.reduce_sum(tf.cast(masks, dtype=tf.float32), axis=1)
list_cnt = tf.reduce_sum(1.0 / per_row_weights)
listwise_loss = tf.reduce_sum(losses / per_row_weights) / list_cnt
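    # Worked example (hypothetical numbers, assuming a symmetric same-list
    # mask): lists {0, 1} and {2}, logits [1.0, 0.0, 0.5], labels
    # [1.0, 0.0, 0.0]. Columns 0 and 1 each give -log(softmax) ≈ 0.474,
    # column 2 gives ~0; per_row_weights = [2, 2, 1], list_cnt = 2, and
    # listwise_loss = (0.474 / 2 + 0.474 / 2 + 0) / 2 ≈ 0.237.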
# summaries.
tf.summary.scalar("list_cnt", list_cnt)
self.debug_dict = {"list_cnt": list_cnt}
return listwise_loss

def rank_loss(self):
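    """Pairwise logistic (RankNet-style) loss over in-batch item pairs."""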
logits = self.logits
labels = self.label
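    # With (B, 1) logits, broadcasting against the transpose yields a (B, B)
    # matrix of score differences: pairwise_logits[i, j] = s_i - s_j.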
pairwise_logits = logits - tf.transpose(logits)
logging.info("[rank_loss] pairwise logits: {}".format(pairwise_logits))
pairwise_mask = tf.greater(labels - tf.transpose(labels), 0)
logging.info("[rank_loss] mask: {}".format(pairwise_mask))
pairwise_logits = tf.boolean_mask(pairwise_logits, pairwise_mask)
logging.info("[rank_loss]: after masking: {}".format(pairwise_logits))
    pairwise_pseudo_labels = tf.ones_like(pairwise_logits)
    rank_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=pairwise_logits,
        labels=pairwise_pseudo_labels
    ))
    # tf.reduce_mean over an empty tensor yields NaN, which happens when no
    # pair satisfies label_i > label_j (e.g. a batch with no positive
    # sample); fall back to a zero loss in that case.
    rank_loss = tf.where(tf.is_nan(rank_loss), tf.zeros_like(rank_loss), rank_loss)
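    # Worked example (hypothetical numbers): logits [2.0, 1.0, 0.0] with
    # labels [1.0, 0.0, 0.0] keep pairs (0, 1) and (0, 2); the surviving
    # differences are [1.0, 2.0], so
    # rank_loss = mean(softplus(-[1.0, 2.0])) ≈ (0.3133 + 0.1269) / 2 ≈ 0.22.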
return rank_loss