TensorFlow 中计算给定 logits 的 sigmoid 交叉熵 tf.sigmoid_cross_entropy_with_logits 的基本用法及实例代码

本文详细解析了TensorFlow中sigmoid交叉熵损失函数的原理与应用,包括其计算公式、参数说明及实例演示,适用于多标签分类任务。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

一、环境

TensorFlow API r1.12

CUDA 9.2 V9.2.148

cudnn64_7.dll

Python 3.6.3

Windows 10

 

二、官方说明

计算输入张量 logits 的 sigmoid 交叉熵

https://tensorflow.google.cn/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits

tf.nn.sigmoid_cross_entropy_with_logits(
    _sentinel=None,
    labels=None,
    logits=None,
    name=None
)

计算离散型分类任务中的概率误差,其中每个类别都是独立的但不是互斥的。

可以用它来计算多标签分类任务,即一幅图片可以同时具有多个类别标签,如大象和狗

其 logistic loss 的计算方式如下,其中 x = logits, z = labels

z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
= z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
= z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
= z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))
= (1 - z) * x + log(1 + exp(-x))
= x - x * z + log(1 + exp(-x))

在 x < 0 的情况下,为了避免计算 exp(-x) 移除,将会按照如下方式计算 logistic 损失

x - x * z + log(1 + exp(-x))
= log(exp(x)) - x * z + log(1 + exp(-x))
= - x * z + log(1 + exp(x))

因此,为了确保计算稳定性并避免移除,最终采用下面的等式来计算 logistic 损失

max(x, 0) - x * z + log(1 + exp(-abs(x)))

参数:

_sentinel:用于保护位置参数,内部的,不使用 

labels:和 logits 具有相同类型和形状的张量

logits:类型为 float 32 或 float64 的张量

name:可选参数,操作的名称

 

返回:

具有分量式的 logistic losses 且形状与 logistic 相同的张量

 

三、实例

>>> import tensorflow as tf

>>> logits = tf.constant(value=[[[0.1,0.2,0.3],[0.4,0.5,0.6],[0.7,0.8,0.9]],[[-0.1,-0.2,-0.3],[-0.4,-0.5,-0.6],[-0.7,-0.8,-0.9]]], dtype=tf.float32)
>>> logits
<tf.Tensor 'Const:0' shape=(2, 3, 3) dtype=float32>


>>> labels = tf.constant(value=[[[1.0,1.0,0.0],[0.0,0.0,1.0],[1.0,1.0,1.0]],[[0.0,0.0,0.0],[0.0,0.0,0.0],[1.0,0.0,1.0]]],dtype=tf.float32)
>>> labels
<tf.Tensor 'Const_2:0' shape=(2, 3, 3) dtype=float32>

>>> results = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
>>> results
<tf.Tensor 'logistic_loss:0' shape=(2, 3, 3) dtype=float32>

>>> sess = tf.InteractiveSession()

>>> print(sess.run(results))
[[[0.6443967  0.59813887 0.8543552 ]
  [0.91301525 0.974077   0.43748793]
  [0.40318602 0.3711007  0.34115386]]

 [[0.6443967  0.59813887 0.5543552 ]
  [0.5130153  0.474077   0.43748793]
  [1.103186   0.3711007  1.2411538 ]]]

>>> sess.close()

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

csdn-WJW

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值