摘要
本文求解 softmax + cross-entropy 在反向传播中的梯度.
相关
配套代码, 请参考文章 :
Python和PyTorch对比实现多标签softmax + cross-entropy交叉熵损失及反向传播
有关 softmax 的详细介绍, 请参考 :
softmax函数详解及反向传播中的梯度求导
有关 cross-entropy 的详细介绍, 请参考 :
通过案例详解cross-entropy交叉熵损失函数
系列文章索引 :
https://blog.youkuaiyun.com/oBrightLamp/article/details/85067981
正文
在大多数教程中, softmax 和 cross-entropy 总是一起出现, 求梯度的时候也是一起考虑.
softmax 和 cross-entropy 的梯度, 已经在上面的两篇文章中分别给出.
1. 题目
考虑一个输入向量 x, 经 softmax 函数归一化处理后得到向量 s 作为预测的概率分布, 已知向量 y 为真实的概率分布, 由 cross-entropy 函数计算得出误差值 error (标量 e ), 求 e 关于 x 的梯度.
x = ( x 1 , x 2 , x 3 , ⋯   , x k ) s = s o f t m a x ( x ) s i = e x i ∑ t = 1 k e x t e = c r o s s E n t r o p y ( s , y ) = − ∑ i = 1 k y i l o g ( s i ) \quad\\ x = (x_1, x_2, x_3, \cdots, x_k)\\ \quad\\ s = softmax(x)\\ \quad\\ s_{i} = \frac{e^{x_{i}}}{ \sum_{t = 1}^{k}e^{x_{t}}} \\ \quad\\ e = crossEntropy(s, y) = -\sum_{i = 1}^{k}y_{i}log(s_{i})\\ x=(x1,x2,x3,⋯,xk)s=softmax(x)si=∑t=1kextexie=crossEntropy(s,y)=−i=1∑kyilog(si)
已知 :
∇ e ( s ) = ∂ e ∂ s = ( ∂ e ∂ s 1 , ∂ e ∂ s 2 , ⋯   , ∂ e ∂ s k ) = ( − y 1 s 1 , − y 2 s 2 , ⋯   , − y k s k )    ∇ s ( x ) = ∂ s ∂ x = ( ∂ s 1 / ∂ x 1 ∂ s 1 / ∂ x 2 ⋯ ∂ s 1 / ∂ x k ∂ s 2 / ∂ x 1 ∂ s 2 / ∂ x 2 ⋯ ∂ s 2 / ∂ x k ⋮ ⋮ ⋱ ⋮ ∂ s k / ∂ x 1 ∂ s k / ∂ x 2 ⋯ ∂ s k / ∂ x k ) = ( − s 1 s 1 + s 1 − s 1 s 2 ⋯ − s 1 s k − s 2 s 1 − s 2 s 2 + s 2 ⋯ − s 2 s k ⋮ ⋮ ⋱ ⋮ − s k s 1 − s k s 2 ⋯ − s k s k + s k ) \nabla e_{(s)}=\frac{\partial e}{\partial s} =(\frac{\partial e}{\partial s_{1}},\frac{\partial e}{\partial s_{2}}, \cdots, \frac{\partial e}{\partial s_{k}}) =( -\frac{y_1}{s_1}, -\frac{y_2}{s_2},\cdots,-\frac{y_k}{s_k}) \\ \;\\ % ---------- \nabla s_{(x)}= \frac{\partial s}{\partial x}= \begin{pmatrix} \partial s_{1}/\partial x_{1}&\partial s_{1}/\partial x_{2}& \cdots&\partial s_{1}/\partial x_{k}\\ \partial s_{2}/\partial x_{1}&\partial s_{2}/\partial x_{2}& \cdots&\partial s_{2}/\partial x_{k}\\ \vdots & \vdots & \ddots & \vdots \\ \partial s_{k}/\partial x_{1}&\partial s_{k}/\partial x_{2}& \cdots&\partial s_{k}/\partial x_{k}\\ \end{pmatrix}= \begin{pmatrix} -s_{1}s_{1} + s_{1} & -s_{1}s_{2} & \cdots & -s_{1}s_{k} \\ -s_{2}s_{1} & -s_{2}s_{2} + s_{2} & \cdots & -s_{2}s_{k} \\ \vdots & \vdots & \ddots & \vdots \\ -s_{k}s_{1} & -s_{k}s_{2} & \cdots & -s_{k}s_{k} + s_{k} \end{pmatrix} \\ \quad\\ ∇e(s)=∂s∂e=(∂s1∂e,∂s2∂e,⋯,∂sk∂e)=