Gradient Derivations for the Loss Functions in the SimCLR Paper

Preface

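Many people have already written detailed summaries of the paper "A Simple Framework for Contrastive Learning of Visual Representations"; see, for example,
【深度学习】详解 SimCLR
AI的未来:自监督,谷歌SimCLR-A Simple Framework for Contrastive Learning of Visual Representations
so I will not go over the paper itself here. While reading it, I had questions about how the gradients of the loss functions are derived. I could not find anything relevant online and only worked it out after asking a senior labmate, so I am writing these notes in the hope that they help anyone with the same questions.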

Gradient of NT-Xent

In the algorithm, the loss function is written as
$$\ell(i, j)=-\log \frac{\exp \left(s_{i, j} / \tau\right)}{\sum_{k=1}^{2 N} \mathbb{1}_{[k \neq i]} \exp \left(s_{i, k} / \tau\right)}$$
where $s_{i, j}$ is the similarity measure, shorthand for $\operatorname{sim}(u,v)=u^\top v/(\|u\|\|v\|)$ evaluated on the pair $(i, j)$, and $\mathbb{1}_{[k \neq i]}$ is the indicator function, which equals 1 when $k \neq i$ and 0 otherwise, so the $k = i$ term is excluded from the sum.
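To make the definition concrete, here is a minimal pure-Python sketch of $\ell(i, j)$. The function names `cosine_sim` and `nt_xent_pair` and the toy vectors are assumptions made for illustration; they are not taken from the paper or from any library.

```python
import math

def cosine_sim(u, v):
    """sim(u, v) = u^T v / (||u|| ||v||)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def nt_xent_pair(z, i, j, tau):
    """l(i, j) for the positive pair (i, j) among the 2N vectors in z."""
    # Denominator: sum over every k != i (this is the indicator 1[k != i]).
    denom = sum(math.exp(cosine_sim(z[i], z[k]) / tau)
                for k in range(len(z)) if k != i)
    numer = math.exp(cosine_sim(z[i], z[j]) / tau)
    return -math.log(numer / denom)

# Toy batch with N = 2, i.e. 2N = 4 vectors (made-up values).
z = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [-1.0, 0.2]]
loss_close = nt_xent_pair(z, 0, 1, tau=0.5)  # z[1] points almost the same way as z[0]
loss_far = nt_xent_pair(z, 0, 3, tau=0.5)    # z[3] points almost the opposite way
```

The loss is smaller when the chosen positive is more similar to the anchor, and it is always positive because the numerator is one of the terms of the denominator.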
In the paper's table of loss functions, it is written as
$$\boldsymbol{u}^{T} \boldsymbol{v}^{+} / \tau-\log \sum_{\boldsymbol{v} \in\left\{\boldsymbol{v}^{+}, \boldsymbol{v}^{-}\right\}} \exp \left(\boldsymbol{u}^{T} \boldsymbol{v} / \tau\right)$$
Here the leading minus sign has been dropped (the objective can be optimized by gradient ascent just as well as by gradient descent), and the logarithm has been expanded using $\log (a/b)=\log a-\log b$. We take $\log$ to be the natural logarithm (base $e$).
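As a quick sanity check, the expanded form can be compared numerically with the log of the original fraction. This is only a sketch: `f_table`, `log_fraction`, and the unit-length toy vectors are hypothetical names and values chosen for illustration.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def f_table(u, v_pos, v_negs, tau):
    """u^T v+ / tau - log sum_{v in {v+, v-}} exp(u^T v / tau)."""
    candidates = [v_pos] + v_negs
    log_sum = math.log(sum(math.exp(dot(u, v) / tau) for v in candidates))
    return dot(u, v_pos) / tau - log_sum

def log_fraction(u, v_pos, v_negs, tau):
    """log of exp(u^T v+ / tau) over the sum, i.e. the loss without its minus sign."""
    candidates = [v_pos] + v_negs
    numer = math.exp(dot(u, v_pos) / tau)
    denom = sum(math.exp(dot(u, v) / tau) for v in candidates)
    return math.log(numer / denom)

# Unit-length toy vectors (made-up values).
u = [1.0, 0.0]
v_pos = [0.6, 0.8]
v_negs = [[0.0, 1.0], [-1.0, 0.0]]
expanded = f_table(u, v_pos, v_negs, tau=0.5)
unexpanded = log_fraction(u, v_pos, v_negs, tau=0.5)
```

Both values agree up to floating-point error, and both are negative, since the fraction inside the logarithm is below 1.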
Next, treat this expression as a function $f(\boldsymbol{u})$ of $\boldsymbol{u}$ and differentiate it:
$$f^{\prime}(\boldsymbol{u})=\frac{\boldsymbol{v}^{+}}{\tau}-\frac{\sum_{\boldsymbol{v} \in\left\{\boldsymbol{v}^{+}, \boldsymbol{v}^{-}\right\}} \exp \left(\boldsymbol{u}^{T} \boldsymbol{v} / \tau\right) \boldsymbol{v} / \tau}{\sum_{\boldsymbol{v} \in\left\{\boldsymbol{v}^{+}, \boldsymbol{v}^{-}\right\}} \exp \left(\boldsymbol{u}^{T} \boldsymbol{v} / \tau\right)}$$
The last term requires the chain rule for a three-level composite function:

$$\left(f[g(h(x))]\right)^{\prime}=f^{\prime}[g(h(x))]\, g^{\prime}[h(x)]\, h^{\prime}(x)$$
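As a worked instance of this rule (a sketch, reading the outer function as $\log$, the middle one as the sum of exponentials, and the inner one as $\boldsymbol{u}^{T} \boldsymbol{v} / \tau$, whose derivative with respect to $\boldsymbol{u}$ is $\boldsymbol{v} / \tau$), the log-sum term differentiates as

$$\frac{\partial}{\partial \boldsymbol{u}} \log \sum_{\boldsymbol{v}} \exp \left(\boldsymbol{u}^{T} \boldsymbol{v} / \tau\right)=\underbrace{\frac{1}{\sum_{\boldsymbol{v}} \exp \left(\boldsymbol{u}^{T} \boldsymbol{v} / \tau\right)}}_{\text{outer: } \log} \sum_{\boldsymbol{v}} \underbrace{\exp \left(\boldsymbol{u}^{T} \boldsymbol{v} / \tau\right)}_{\text{middle: } \exp} \underbrace{\frac{\boldsymbol{v}}{\tau}}_{\text{inner}}$$

where every sum runs over $\boldsymbol{v} \in\left\{\boldsymbol{v}^{+}, \boldsymbol{v}^{-}\right\}$.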
The next step is to put the first term over the common denominator:
$$f^{\prime}(\boldsymbol{u})=\frac{\sum_{\boldsymbol{v} \in\left\{\boldsymbol{v}^{+}, \boldsymbol{v}^{-}\right\}} \exp \left(\boldsymbol{u}^{T} \boldsymbol{v} / \tau\right) \boldsymbol{v}^{+} / \tau-\sum_{\boldsymbol{v} \in\left\{\boldsymbol{v}^{+}, \boldsymbol{v}^{-}\right\}} \exp \left(\boldsymbol{u}^{T} \boldsymbol{v} / \tau\right) \boldsymbol{v} / \tau}{\sum_{\boldsymbol{v} \in\left\{\boldsymbol{v}^{+}, \boldsymbol{v}^{-}\right\}} \exp \left(\boldsymbol{u}^{T} \boldsymbol{v} / \tau\right)}$$
Then expand both sums in the numerator into their positive and negative parts:

$$f^{\prime}(\boldsymbol{u})=\frac{\exp \left(\boldsymbol{u}^{T} \boldsymbol{v}^{+} / \tau\right) \boldsymbol{v}^{+} / \tau+\sum_{\boldsymbol{v}^{-}} \exp \left(\boldsymbol{u}^{T} \boldsymbol{v}^{-} / \tau\right) \boldsymbol{v}^{+} / \tau-\exp \left(\boldsymbol{u}^{T} \boldsymbol{v}^{+} / \tau\right) \boldsymbol{v}^{+} / \tau-\sum_{\boldsymbol{v}^{-}} \exp \left(\boldsymbol{u}^{T} \boldsymbol{v}^{-} / \tau\right) \boldsymbol{v}^{-} / \tau}{\sum_{\boldsymbol{v} \in\left\{\boldsymbol{v}^{+}, \boldsymbol{v}^{-}\right\}} \exp \left(\boldsymbol{u}^{T} \boldsymbol{v} / \tau\right)}$$
The first and third terms of the numerator are identical and cancel, which gives
$$f^{\prime}(\boldsymbol{u})=\frac{\sum_{\boldsymbol{v}^{-}} \exp \left(\boldsymbol{u}^{T} \boldsymbol{v}^{-} / \tau\right) \boldsymbol{v}^{+} / \tau-\sum_{\boldsymbol{v}^{-}} \exp \left(\boldsymbol{u}^{T} \boldsymbol{v}^{-} / \tau\right) \boldsymbol{v}^{-} / \tau}{\sum_{\boldsymbol{v} \in\left\{\boldsymbol{v}^{+}, \boldsymbol{v}^{-}\right\}} \exp \left(\boldsymbol{u}^{T} \boldsymbol{v} / \tau\right)}=\frac{\sum_{\boldsymbol{v}^{-}} \exp \left(\boldsymbol{u}^{T} \boldsymbol{v}^{-} / \tau\right) \boldsymbol{v}^{+} / \tau}{\sum_{\boldsymbol{v} \in\left\{\boldsymbol{v}^{+}, \boldsymbol{v}^{-}\right\}} \exp \left(\boldsymbol{u}^{T} \boldsymbol{v} / \tau\right)}-\frac{\sum_{\boldsymbol{v}^{-}} \exp \left(\boldsymbol{u}^{T} \boldsymbol{v}^{-} / \tau\right) \boldsymbol{v}^{-} / \tau}{\sum_{\boldsymbol{v} \in\left\{\boldsymbol{v}^{+}, \boldsymbol{v}^{-}\right\}} \exp \left(\boldsymbol{u}^{T} \boldsymbol{v} / \tau\right)}$$
This is already very close to the form given in the paper. In the numerator of the first term, add and subtract the positive-sample term:
$$f^{\prime}(\boldsymbol{u})=\frac{\sum_{\boldsymbol{v}^{-}} \exp \left(\boldsymbol{u}^{T} \boldsymbol{v}^{-} / \tau\right) \boldsymbol{v}^{+} / \tau+\exp \left(\boldsymbol{u}^{T} \boldsymbol{v}^{+} / \tau\right) \boldsymbol{v}^{+} / \tau-\exp \left(\boldsymbol{u}^{T} \boldsymbol{v}^{+} / \tau\right) \boldsymbol{v}^{+} / \tau}{\sum_{\boldsymbol{v} \in\left\{\boldsymbol{v}^{+}, \boldsymbol{v}^{-}\right\}} \exp \left(\boldsymbol{u}^{T} \boldsymbol{v} / \tau\right)}-\frac{\sum_{\boldsymbol{v}^{-}} \exp \left(\boldsymbol{u}^{T} \boldsymbol{v}^{-} / \tau\right) \boldsymbol{v}^{-} / \tau}{\sum_{\boldsymbol{v} \in\left\{\boldsymbol{v}^{+}, \boldsymbol{v}^{-}\right\}} \exp \left(\boldsymbol{u}^{T} \boldsymbol{v} / \tau\right)}$$
Merging the positive and negative samples in the numerator of the first term gives
$$f^{\prime}(\boldsymbol{u})=\frac{\sum_{\boldsymbol{v} \in\left\{\boldsymbol{v}^{+}, \boldsymbol{v}^{-}\right\}} \exp \left(\boldsymbol{u}^{T} \boldsymbol{v} / \tau\right) \boldsymbol{v}^{+} / \tau-\exp \left(\boldsymbol{u}^{T} \boldsymbol{v}^{+} / \tau\right) \boldsymbol{v}^{+} / \tau}{\sum_{\boldsymbol{v} \in\left\{\boldsymbol{v}^{+}, \boldsymbol{v}^{-}\right\}} \exp \left(\boldsymbol{u}^{T} \boldsymbol{v} / \tau\right)}-\frac{\sum_{\boldsymbol{v}^{-}} \exp \left(\boldsymbol{u}^{T} \boldsymbol{v}^{-} / \tau\right) \boldsymbol{v}^{-} / \tau}{\sum_{\boldsymbol{v} \in\left\{\boldsymbol{v}^{+}, \boldsymbol{v}^{-}\right\}} \exp \left(\boldsymbol{u}^{T} \boldsymbol{v} / \tau\right)}$$
Factoring out the common factor and simplifying gives (note that $\boldsymbol{v}^{-}$ stays inside the sum over negatives, since each negative sample carries its own weight)
$$f^{\prime}(\boldsymbol{u})=\left(\frac{\boldsymbol{v}^{+}}{\tau}\right)\left(1-\frac{\exp \left(\boldsymbol{u}^{T} \boldsymbol{v}^{+} / \tau\right)}{\sum_{\boldsymbol{v} \in\left\{\boldsymbol{v}^{+}, \boldsymbol{v}^{-}\right\}} \exp \left(\boldsymbol{u}^{T} \boldsymbol{v} / \tau\right)}\right)-\sum_{\boldsymbol{v}^{-}}\left(\frac{\exp \left(\boldsymbol{u}^{T} \boldsymbol{v}^{-} / \tau\right)}{\sum_{\boldsymbol{v} \in\left\{\boldsymbol{v}^{+}, \boldsymbol{v}^{-}\right\}} \exp \left(\boldsymbol{u}^{T} \boldsymbol{v} / \tau\right)}\right)\left(\frac{\boldsymbol{v}^{-}}{\tau}\right)$$
Letting
$$Z(\boldsymbol{u})=\sum_{\boldsymbol{v} \in\left\{\boldsymbol{v}^{+}, \boldsymbol{v}^{-}\right\}} \exp \left(\boldsymbol{u}^{T} \boldsymbol{v} / \tau\right)$$
we obtain
$$f^{\prime}(\boldsymbol{u})=\left(\frac{\boldsymbol{v}^{+}}{\tau}\right)\left(1-\frac{\exp \left(\boldsymbol{u}^{T} \boldsymbol{v}^{+} / \tau\right)}{Z(\boldsymbol{u})}\right)-\sum_{\boldsymbol{v}^{-}}\left(\frac{\exp \left(\boldsymbol{u}^{T} \boldsymbol{v}^{-} / \tau\right)}{Z(\boldsymbol{u})}\right)\left(\frac{\boldsymbol{v}^{-}}{\tau}\right)$$
A little rearrangement then yields the final result given in the paper:

$$f^{\prime}(\boldsymbol{u})=\left(1-\frac{\exp \left(\boldsymbol{u}^{T} \boldsymbol{v}^{+} / \tau\right)}{Z(\boldsymbol{u})}\right) / \tau\, \boldsymbol{v}^{+}-\sum_{\boldsymbol{v}^{-}} \frac{\exp \left(\boldsymbol{u}^{T} \boldsymbol{v}^{-} / \tau\right)}{Z(\boldsymbol{u})} / \tau\, \boldsymbol{v}^{-}$$
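The closed-form NT-Xent gradient can be checked against a central finite-difference approximation of $f(\boldsymbol{u})$. The sketch below is an illustration under stated assumptions: the names `nt_xent_f`, `nt_xent_grad`, `numeric_grad`, and the toy vectors are made up, and $\boldsymbol{u}$ is treated as a free variable exactly as in the derivation (it is not re-normalized after being perturbed).

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def nt_xent_f(u, v_pos, v_negs, tau):
    """f(u) = u^T v+ / tau - log Z(u)."""
    z = sum(math.exp(dot(u, v) / tau) for v in [v_pos] + v_negs)
    return dot(u, v_pos) / tau - math.log(z)

def nt_xent_grad(u, v_pos, v_negs, tau):
    """(1 - exp(u^T v+/tau)/Z)/tau * v+  -  sum_{v-} exp(u^T v-/tau)/Z/tau * v-."""
    z = sum(math.exp(dot(u, v) / tau) for v in [v_pos] + v_negs)
    w_pos = (1.0 - math.exp(dot(u, v_pos) / tau) / z) / tau
    grad = [w_pos * x for x in v_pos]
    for v in v_negs:
        w_neg = math.exp(dot(u, v) / tau) / z / tau
        grad = [g - w_neg * x for g, x in zip(grad, v)]
    return grad

def numeric_grad(func, u, eps=1e-6):
    """Central finite differences, one coordinate of u at a time."""
    grad = []
    for d in range(len(u)):
        up = list(u)
        up[d] += eps
        down = list(u)
        down[d] -= eps
        grad.append((func(up) - func(down)) / (2 * eps))
    return grad

# Unit-length toy vectors (made-up values), one positive and two negatives.
u = [0.6, 0.8]
v_pos = [0.8, 0.6]
v_negs = [[-0.6, 0.8], [0.0, -1.0]]
tau = 0.5

g_closed = nt_xent_grad(u, v_pos, v_negs, tau)
g_numeric = numeric_grad(lambda x: nt_xent_f(x, v_pos, v_negs, tau), u)
max_diff = max(abs(a - b) for a, b in zip(g_closed, g_numeric))
```

With these inputs `max_diff` should be at the level of finite-difference error, which supports reading $\boldsymbol{v}^{+}$ and $\boldsymbol{v}^{-}$ as multiplied factors rather than as part of the denominator.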

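Note that in the paper's final result, $\boldsymbol{v}^{-}$ and $\boldsymbol{v}^{+}$ should both be read as factors multiplied on at the end (if they were part of the denominator, parentheses would be required, and the derivation confirms this reading).
In addition, the paper does not seem to define $Z(\boldsymbol{u})$; judging from the derivation, it should be the sum $\sum_{\boldsymbol{v} \in\left\{\boldsymbol{v}^{+}, \boldsymbol{v}^{-}\right\}} \exp \left(\boldsymbol{u}^{T} \boldsymbol{v} / \tau\right)$ used here.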

Gradient of NT-Logistic

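The loss function in the table is
$$\log \sigma\left(\boldsymbol{u}^{T} \boldsymbol{v}^{+} / \tau\right)+\log \sigma\left(-\boldsymbol{u}^{T} \boldsymbol{v}^{-} / \tau\right)$$
Again the leading minus sign is omitted (the objective can be optimized by gradient ascent just as well as by gradient descent). Next, treat the expression as a function of $\boldsymbol{u}$ and differentiate it. We take $\log$ to be the natural logarithm, and the activation $\sigma(\cdot)$ is the sigmoid function $\sigma(x)=1 /\left(1+e^{-x}\right)$.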
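The derivation relies on two standard sigmoid identities: $\sigma^{\prime}(x)=\sigma(x)(1-\sigma(x))$ and $1-\sigma(x)=\sigma(-x)$, which together give $\frac{d}{dx} \log \sigma(x)=\sigma(-x)$. A small numerical sketch (the helper names and the test point `x = 0.7` are arbitrary choices for illustration):

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def log_sigmoid_grad(x):
    # d/dx log sigma(x) = sigma'(x) / sigma(x) = 1 - sigma(x) = sigma(-x)
    return sigmoid(-x)

x, eps = 0.7, 1e-6

# Central finite difference of log sigma at x.
numeric = (math.log(sigmoid(x + eps)) - math.log(sigmoid(x - eps))) / (2 * eps)
analytic = log_sigmoid_grad(x)

# The symmetry 1 - sigma(x) = sigma(-x), checked directly.
symmetry_gap = abs((1.0 - sigmoid(x)) - sigmoid(-x))
```

Both checks should agree up to numerical error, which is all the NT-Logistic derivation needs.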
$$f^{\prime}(\boldsymbol{u})=\frac{1}{\sigma\left(\boldsymbol{u}^{T} \boldsymbol{v}^{+} / \tau\right)} \sigma\left(\boldsymbol{u}^{T} \boldsymbol{v}^{+} / \tau\right)\left(1-\sigma\left(\boldsymbol{u}^{T} \boldsymbol{v}^{+} / \tau\right)\right)\left(\frac{\boldsymbol{v}^{+}}{\tau}\right)+\frac{1}{\sigma\left(-\boldsymbol{u}^{T} \boldsymbol{v}^{-} / \tau\right)} \sigma\left(-\boldsymbol{u}^{T} \boldsymbol{v}^{-} / \tau\right)\left(1-\sigma\left(-\boldsymbol{u}^{T} \boldsymbol{v}^{-} / \tau\right)\right)\left(-\frac{\boldsymbol{v}^{-}}{\tau}\right)$$
Simplifying gives
$$f^{\prime}(\boldsymbol{u})=\left(1-\sigma\left(\boldsymbol{u}^{T} \boldsymbol{v}^{+} / \tau\right)\right)\left(\frac{\boldsymbol{v}^{+}}{\tau}\right)-\left(1-\sigma\left(-\boldsymbol{u}^{T} \boldsymbol{v}^{-} / \tau\right)\right)\left(\frac{\boldsymbol{v}^{-}}{\tau}\right)$$
Substituting the definition of the sigmoid function and putting each factor over a common denominator:
$$f^{\prime}(\boldsymbol{u})=\frac{1+e^{-\boldsymbol{u}^{T} \boldsymbol{v}^{+} / \tau}-1}{1+e^{-\boldsymbol{u}^{T} \boldsymbol{v}^{+} / \tau}}\left(\frac{\boldsymbol{v}^{+}}{\tau}\right)-\frac{1+e^{\boldsymbol{u}^{T} \boldsymbol{v}^{-} / \tau}-1}{1+e^{\boldsymbol{u}^{T} \boldsymbol{v}^{-} / \tau}}\left(\frac{\boldsymbol{v}^{-}}{\tau}\right)$$
Rearranging gives
$$f^{\prime}(\boldsymbol{u})=\frac{e^{-\boldsymbol{u}^{T} \boldsymbol{v}^{+} / \tau}}{1+e^{-\boldsymbol{u}^{T} \boldsymbol{v}^{+} / \tau}}\left(\frac{\boldsymbol{v}^{+}}{\tau}\right)-\frac{e^{\boldsymbol{u}^{T} \boldsymbol{v}^{-} / \tau}}{1+e^{\boldsymbol{u}^{T} \boldsymbol{v}^{-} / \tau}}\left(\frac{\boldsymbol{v}^{-}}{\tau}\right)$$
Multiply the numerator and denominator of the first term by $e^{\boldsymbol{u}^{T} \boldsymbol{v}^{+} / \tau}$, and those of the second term by $e^{-\boldsymbol{u}^{T} \boldsymbol{v}^{-} / \tau}$:
$$f^{\prime}(\boldsymbol{u})=\frac{1}{e^{\boldsymbol{u}^{T} \boldsymbol{v}^{+} / \tau}+1}\left(\frac{\boldsymbol{v}^{+}}{\tau}\right)-\frac{1}{e^{-\boldsymbol{u}^{T} \boldsymbol{v}^{-} / \tau}+1}\left(\frac{\boldsymbol{v}^{-}}{\tau}\right)$$
Finally we get
$$f^{\prime}(\boldsymbol{u})=\sigma\left(-\boldsymbol{u}^{T} \boldsymbol{v}^{+} / \tau\right)\left(\frac{\boldsymbol{v}^{+}}{\tau}\right)-\sigma\left(\boldsymbol{u}^{T} \boldsymbol{v}^{-} / \tau\right)\left(\frac{\boldsymbol{v}^{-}}{\tau}\right)$$
This agrees with the result given in the paper, which is
$$\sigma\left(-\boldsymbol{u}^{T} \boldsymbol{v}^{+} / \tau\right) / \tau\, \boldsymbol{v}^{+}-\sigma\left(\boldsymbol{u}^{T} \boldsymbol{v}^{-} / \tau\right) / \tau\, \boldsymbol{v}^{-}$$
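As a numerical cross-check, the gradient $\sigma(-\boldsymbol{u}^{T} \boldsymbol{v}^{+} / \tau) \boldsymbol{v}^{+} / \tau-\sigma(\boldsymbol{u}^{T} \boldsymbol{v}^{-} / \tau) \boldsymbol{v}^{-} / \tau$ of $\log \sigma(\boldsymbol{u}^{T} \boldsymbol{v}^{+} / \tau)+\log \sigma(-\boldsymbol{u}^{T} \boldsymbol{v}^{-} / \tau)$ can be compared with central finite differences. This is a sketch with hypothetical names (`nt_logistic_f`, `nt_logistic_grad`) and made-up toy vectors, using a single negative sample as in the table.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def nt_logistic_f(u, v_pos, v_neg, tau):
    """log sigma(u^T v+ / tau) + log sigma(-u^T v- / tau)."""
    return (math.log(sigmoid(dot(u, v_pos) / tau))
            + math.log(sigmoid(-dot(u, v_neg) / tau)))

def nt_logistic_grad(u, v_pos, v_neg, tau):
    """sigma(-u^T v+/tau)/tau * v+  -  sigma(u^T v-/tau)/tau * v-."""
    w_pos = sigmoid(-dot(u, v_pos) / tau) / tau
    w_neg = sigmoid(dot(u, v_neg) / tau) / tau
    return [w_pos * p - w_neg * n for p, n in zip(v_pos, v_neg)]

def numeric_grad(func, u, eps=1e-6):
    """Central finite differences, one coordinate of u at a time."""
    grad = []
    for d in range(len(u)):
        up = list(u)
        up[d] += eps
        down = list(u)
        down[d] -= eps
        grad.append((func(up) - func(down)) / (2 * eps))
    return grad

# Unit-length toy vectors (made-up values).
u = [0.6, 0.8]
v_pos = [0.8, 0.6]
v_neg = [-0.6, 0.8]
tau = 0.5

g_closed = nt_logistic_grad(u, v_pos, v_neg, tau)
g_numeric = numeric_grad(lambda x: nt_logistic_f(x, v_pos, v_neg, tau), u)
max_diff = max(abs(a - b) for a, b in zip(g_closed, g_numeric))
```

Note that the weight on $\boldsymbol{v}^{-}$ depends on $\boldsymbol{u}^{T} \boldsymbol{v}^{-}$, not on $\boldsymbol{u}^{T} \boldsymbol{v}^{+}$; swapping the two in `w_neg` makes the check fail for generic inputs.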

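References and Acknowledgments

Parts of this post draw on the following links, with thanks to their authors. Thanks♪(・ω・)ノ
Reference 1: 【深度学习】详解 SimCLR (a detailed explanation of SimCLR)
https://blog.youkuaiyun.com/qq_39478403/article/details/128358529
Reference 2: AI的未来:自监督,谷歌SimCLR-A Simple Framework for Contrastive Learning of Visual Representations (the future of AI: self-supervision, Google's SimCLR)
https://zhuanlan.zhihu.com/p/372073905
Reference 3: 神经网络中的常用激活函数和导数 (common activation functions in neural networks and their derivatives)
https://blog.youkuaiyun.com/lw_power/article/details/90291928
Reference 4: 矩阵求导公式的数学推导(矩阵求导——基础篇) (mathematical derivation of matrix differentiation formulas, basics)
https://zhuanlan.zhihu.com/p/273729929
Reference 5: 向量对向量求导 (differentiating a vector with respect to a vector)
https://zhuanlan.zhihu.com/p/449988999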
