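I recently worked through Tianqi Chen's Deep Learning Systems course, and part of HW2 asks for the derivative of the LogSumExp (LSE) operator.
LSE is used very widely (for example, the Softmax in multi-class classification can be computed through LSE to avoid overflow).
So this post works out the derivative of LSE (though the write-up is a bit long-winded).
It is also a chance to practice LaTeX 😄
Some notation first: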
$$
\begin{aligned}
&\text{input: } z \in \mathbb{R}^n \\
&\operatorname{argmax}(z) = j, \quad \max z = z_j \\
&\hat{z_i} = z_i - \max z = z_i - z_j \\
&\mathrm{LogSumExp}(z) = \log\left(\sum_{k=1}^{n}\exp(z_k - \max z)\right) + \max z = \log\left(\sum_{k=1}^{n}\exp(\hat{z_k})\right) + z_j \\
&\mathrm{LSE} = \mathrm{LogSumExp}
\end{aligned}
$$
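The max-shift in the definition above is what keeps the computation from overflowing. Below is a minimal pure-Python sketch of that idea; the function names `logsumexp_naive` and `logsumexp` are made up for illustration and are not the course's HW2 implementation.

```python
import math

def logsumexp_naive(z):
    # Direct definition log(sum_k exp(z_k)); math.exp overflows for large inputs.
    return math.log(sum(math.exp(v) for v in z))

def logsumexp(z):
    # Shifted form: subtract max z so every exponent is <= 0, then add it back.
    m = max(z)
    return math.log(sum(math.exp(v - m) for v in z)) + m

z = [1000.0, 1001.0, 1002.0]
print("stable LSE:", logsumexp(z))
try:
    logsumexp_naive(z)
except OverflowError as err:
    print("naive LSE failed:", err)
```

For inputs of moderate size the two functions agree up to rounding; they only differ once `exp` would leave the floating-point range.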
- When $i \neq j$ (so $\max z = z_j$ does not depend on $z_i$):
$$
\begin{aligned}
\frac{\partial \mathrm{LSE}}{\partial z_i}
&= \frac{\partial \mathrm{LSE}}{\partial \log\sum_{k=1}^{n}\exp(\hat{z_k})} \cdot \frac{\partial \log\sum_{k=1}^{n}\exp(\hat{z_k})}{\partial z_i} + \frac{\partial \mathrm{LSE}}{\partial \max z} \cdot \frac{\partial \max z}{\partial z_i} \\
&= 1 \cdot \frac{\partial \log\sum_{k=1}^{n}\exp(\hat{z_k})}{\partial \sum_{k=1}^{n}\exp(\hat{z_k})} \cdot \frac{\partial \sum_{k=1}^{n}\exp(\hat{z_k})}{\partial z_i} + 1 \cdot 0 \\
&= \frac{\partial \log\sum_{k=1}^{n}\exp(\hat{z_k})}{\partial \sum_{k=1}^{n}\exp(\hat{z_k})} \cdot \sum_{k=1}^{n}\left(\frac{\partial \exp(\hat{z_k})}{\partial z_i}\right) \\
&= \frac{1}{\sum_{k=1}^{n}\exp(\hat{z_k})} \cdot \sum_{k=1}^{n}\left(\frac{\partial \exp(\hat{z_k})}{\partial \hat{z_k}} \cdot \frac{\partial \hat{z_k}}{\partial z_i}\right) \\
&= \frac{1}{\sum_{k=1}^{n}\exp(\hat{z_k})} \cdot \sum_{k=1}^{n}\left(\exp(\hat{z_k}) \cdot \frac{\partial (z_k - \max z)}{\partial z_i}\right) \\
&= \frac{1}{\sum_{k=1}^{n}\exp(\hat{z_k})} \cdot \sum_{k=1}^{n}\left(\exp(\hat{z_k}) \cdot \mathbb{I}(k = i)\right) \\
&= \frac{1}{\sum_{k=1}^{n}\exp(\hat{z_k})} \cdot \exp(\hat{z_i}) \\
&= \frac{\exp(\hat{z_i})}{\sum_{k=1}^{n}\exp(\hat{z_k})}
\end{aligned}
$$
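One way to sanity-check the closed form for the $i \neq j$ case is to compare it against a central finite difference. The sketch below uses only the standard library; the helper names (`lse_grad_formula`, `lse_grad_numeric`) and the test vector are assumptions chosen for illustration, not part of the homework code.

```python
import math

def logsumexp(z):
    # Stable LSE: shift by max z, then add the shift back.
    m = max(z)
    return math.log(sum(math.exp(v - m) for v in z)) + m

def lse_grad_formula(z, i):
    # Closed form: exp(z_i - max z) / sum_k exp(z_k - max z).
    m = max(z)
    return math.exp(z[i] - m) / sum(math.exp(v - m) for v in z)

def lse_grad_numeric(z, i, eps=1e-6):
    # Central finite difference of LSE along coordinate i.
    hi = list(z)
    lo = list(z)
    hi[i] += eps
    lo[i] -= eps
    return (logsumexp(hi) - logsumexp(lo)) / (2 * eps)

z = [0.5, -1.2, 3.0, 2.1]            # hypothetical input
j = max(range(len(z)), key=lambda k: z[k])  # index of the maximum
for i in range(len(z)):
    if i == j:
        continue                      # only the i != j case is checked here
    print(i, lse_grad_formula(z, i), lse_grad_numeric(z, i))
```

The two columns should agree to several decimal places; a mismatch would point to an error in either the derivation or the implementation.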