Forward Process
Defined as a Markov chain:

$$
q\left({\bold x}_{1:T}\middle\vert{\bold x}_0\right)=\prod_{t=1}^T{q\left({\bold x}_t\middle\vert{\bold x}_{t-1},{\bold x}_{t-2},\cdots,{\bold x}_0\right)}=\prod_{t=1}^T{q\left({\bold x}_t\middle\vert{\bold x}_{t-1}\right)}
$$

where

$$
q\left({\bold x}_t\middle\vert{\bold x}_{t-1}\right)={\cal N}\left({\bold x}_t;\sqrt{1-\beta_t}\cdot{\bold x}_{t-1},\beta_t{\bold I}\right)
$$
Reparameterization Trick
$$
{\bold x}_t=\sqrt{1-\beta_t}\cdot{\bold x}_{t-1}+\sqrt{\beta_t}\cdot{\boldsymbol\epsilon}_t
$$

where

$$
{\boldsymbol\epsilon}_t\sim{\cal N}\left({\bold 0},{\bold I}\right)
$$
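The reparameterized step above can be sketched in a few lines of NumPy. This is a minimal illustration, not a reference implementation: the function name `forward_step`, the value `beta_t=0.02`, and the one-dimensional toy data are all assumptions made for the example.

```python
import numpy as np

def forward_step(x_prev, beta_t, rng):
    """Draw x_t ~ q(x_t | x_{t-1}) using the reparameterization trick."""
    eps = rng.standard_normal(x_prev.shape)          # eps_t ~ N(0, I)
    return np.sqrt(1.0 - beta_t) * x_prev + np.sqrt(beta_t) * eps

rng = np.random.default_rng(0)
x_prev = rng.standard_normal(100_000)                # toy data with unit variance (assumption)
x_t = forward_step(x_prev, beta_t=0.02, rng=rng)     # beta_t chosen only for illustration
```

Because the squared mean coefficient and the noise variance sum to one, unit-variance input should stay close to unit variance after the step.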
Why $\mu^2+\sigma^2=1$
$$
\begin{aligned}
{\bold x}_t
&=\sqrt{1-\beta_t}\left(\sqrt{1-\beta_{t-1}}\cdot{\bold x}_{t-2}+\sqrt{\beta_{t-1}}\cdot{\boldsymbol\epsilon}_{t-1}\right)+\sqrt{\beta_t}\cdot{\boldsymbol\epsilon}_t \\
&=\sqrt{(1-\beta_t)(1-\beta_{t-1})}\cdot{\bold x}_{t-2}+\sqrt{1-(1-\beta_t)(1-\beta_{t-1})}\cdot{\boldsymbol\epsilon}' \\
&=\cdots \\
&=\sqrt{\prod_{i=1}^{t}\left(1-\beta_i\right)}\cdot{\bold x}_0+\sqrt{1-\prod_{i=1}^{t}\left(1-\beta_i\right)}\cdot{\boldsymbol\epsilon}''
\end{aligned}
$$

where

$$
{\boldsymbol\epsilon}',{\boldsymbol\epsilon}''\sim{\cal N}\left({\bold 0},{\bold I}\right)
$$

let

$$
\alpha_t=1-\beta_t
$$

and

$$
\bar\alpha_t=\prod_{s=1}^{t}\alpha_s
$$

we have

$$
q\left({\bold x}_t\middle\vert{\bold x}_0\right)={\cal N}\left({\bold x}_t;\sqrt{\bar\alpha_t}\cdot{\bold x}_0,(1-\bar\alpha_t){\bold I}\right)
$$
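The closed form can be checked numerically by tracking the mean coefficient and the variance through $t$ single steps and comparing them with $\sqrt{\bar\alpha_t}$ and $1-\bar\alpha_t$. The sketch below assumes a linear schedule from `1e-4` to `0.02` with `T = 1000`; the schedule and the helper name `q_sample` are illustrative choices, not something the derivation depends on.

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)     # assumed linear schedule (illustrative)
alphas = 1.0 - betas
alpha_bar = np.cumprod(alphas)         # alpha_bar[t - 1] stores \bar\alpha_t

def q_sample(x0, t, eps):
    """Draw x_t ~ q(x_t | x_0) in one shot, for t in 1..T."""
    a = alpha_bar[t - 1]
    return np.sqrt(a) * x0 + np.sqrt(1.0 - a) * eps

# Compose t single steps: the coefficient of x_0 is multiplied by sqrt(1 - beta),
# and the accumulated noise variance follows var <- (1 - beta) * var + beta.
t = 200
mean_coef, var = 1.0, 0.0
for b in betas[:t]:
    mean_coef *= np.sqrt(1.0 - b)
    var = (1.0 - b) * var + b
```

After the loop, `mean_coef` matches `np.sqrt(alpha_bar[t - 1])` and `var` matches `1 - alpha_bar[t - 1]`, so `mean_coef**2 + var` stays at one for every `t`.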
Reverse Process
Defined as a Markov chain as well:

$$
p_\theta({\bold x}_{0:T})=p_\theta({\bold x}_T)\prod_{t=1}^T{p_\theta\left({\bold x}_{t-1}\middle\vert{\bold x}_{t}\right)}
$$

where

$$
p_\theta\left({\bold x}_{t-1}\middle\vert{\bold x}_t\right)={\cal N}\left({\bold x}_{t-1};{\boldsymbol\mu}_\theta\left({\bold x}_t,t\right),{\boldsymbol\Sigma}_\theta\left({\bold x}_t,t\right)\right)
$$
From Forward Process
The true reverse conditional $q\left({\bold x}_{t-1}\middle\vert{\bold x}_t\right)$ is intractable, but it becomes tractable once it is also conditioned on ${\bold x}_0$. By Bayes' rule, and since $q\left({\bold x}_t\middle\vert{\bold x}_{t-1},{\bold x}_0\right)=q\left({\bold x}_t\middle\vert{\bold x}_{t-1}\right)$ by the Markov property:

$$
q\left({\bold x}_{t-1}\middle\vert{\bold x}_t,{\bold x}_0\right)=q\left({\bold x}_t\middle\vert{\bold x}_{t-1},{\bold x}_0\right)\cdot\frac{q\left({\bold x}_{t-1}\middle\vert{\bold x}_0\right)}{q\left({\bold x}_t\middle\vert{\bold x}_0\right)}
$$

with Gaussian kernel (normalization constants omitted; $C$ collects the terms that do not depend on ${\bold x}_{t-1}$)

$$
\begin{aligned}
\log{q\left({\bold x}_{t-1}\middle\vert{\bold x}_t,{\bold x}_0\right)}
&=-\frac12\left[\frac{\left({\bold x}_t-\sqrt{\alpha_t}\cdot{\bold x}_{t-1}\right)^2}{\beta_t}+\frac{\left({\bold x}_{t-1}-\sqrt{\bar\alpha_{t-1}}\cdot{\bold x}_0\right)^2}{1-\bar\alpha_{t-1}}-\frac{\left({\bold x}_t-\sqrt{\bar\alpha_t}\cdot{\bold x}_0\right)^2}{1-\bar\alpha_t}\right] \\
&=-\frac12\left[\left(\frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar\alpha_{t-1}}\right){\bold x}_{t-1}^2-\left(\frac{2\sqrt{\alpha_t}}{\beta_t}\cdot{\bold x}_t+\frac{2\sqrt{\bar\alpha_{t-1}}}{1-\bar\alpha_{t-1}}\cdot{\bold x}_0\right){\bold x}_{t-1}+C\right]
\end{aligned}
$$

therefore

$$
\frac{1}{\sigma^2}=\frac{\alpha_t-\bar\alpha_t+\beta_t}{\beta_t\left(1-\bar\alpha_{t-1}\right)}=\frac{1-\bar\alpha_t}{1-\bar\alpha_{t-1}}\cdot\frac{1}{\beta_t}
\Longrightarrow
\sigma^2=\frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\cdot\beta_t\xlongequal[]{\Delta}\tilde\beta_t
$$

$$
\mu=\frac{\sigma^2}{2}\left(\frac{2\sqrt{\alpha_t}}{\beta_t}\cdot{\bold x}_t+\frac{2\sqrt{\bar\alpha_{t-1}}}{1-\bar\alpha_{t-1}}\cdot{\bold x}_0\right)=\frac{\sqrt{\alpha_t}\left(1-\bar\alpha_{t-1}\right)}{1-\bar\alpha_t}\cdot{\bold x}_t+\frac{\beta_t\sqrt{\bar\alpha_{t-1}}}{1-\bar\alpha_t}\cdot{\bold x}_0\xlongequal[]{\Delta}\tilde{\boldsymbol\mu}_t({\bold x}_t,{\bold x}_0)
$$

finally

$$
q\left({\bold x}_{t-1}\middle\vert{\bold x}_t,{\bold x}_0\right)={\cal N}\left({\bold x}_{t-1};\tilde{\boldsymbol\mu}_t({\bold x}_t,{\bold x}_0),\tilde\beta_t{\bold I}\right)
$$
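As a sanity check on the algebra, the closed-form posterior mean and variance can be compared with the values obtained by completing the square directly (precision = coefficient of ${\bold x}_{t-1}^2$, mean = half the linear coefficient divided by the precision). The sketch below assumes a linear schedule and scalar inputs; the function name `posterior` and the test values are hypothetical.

```python
import numpy as np

betas = np.linspace(1e-4, 0.02, 1000)  # assumed linear schedule (illustrative)
alphas = 1.0 - betas
alpha_bar = np.cumprod(alphas)         # alpha_bar[t - 1] stores \bar\alpha_t

def posterior(x_t, x0, t):
    """Mean and variance of q(x_{t-1} | x_t, x_0); valid for 2 <= t <= T."""
    a_t, b_t = alphas[t - 1], betas[t - 1]
    ab_t, ab_prev = alpha_bar[t - 1], alpha_bar[t - 2]
    var = (1.0 - ab_prev) / (1.0 - ab_t) * b_t                     # \tilde\beta_t
    mean = (np.sqrt(a_t) * (1.0 - ab_prev) / (1.0 - ab_t)) * x_t \
         + (b_t * np.sqrt(ab_prev) / (1.0 - ab_t)) * x0            # \tilde\mu_t
    return mean, var

# Cross-check against the quadratic form in x_{t-1}.
t, x_t, x0 = 500, 0.3, -1.2            # arbitrary test values
a_t, b_t, ab_prev = alphas[t - 1], betas[t - 1], alpha_bar[t - 2]
precision = a_t / b_t + 1.0 / (1.0 - ab_prev)
mean_check = (np.sqrt(a_t) / b_t * x_t + np.sqrt(ab_prev) / (1.0 - ab_prev) * x0) / precision
mean, var = posterior(x_t, x0, t)
```

The case `t = 1` is excluded here because $\bar\alpha_0=1$ makes $\tilde\beta_1=0$, and indexing `alpha_bar[t - 2]` would wrap around to the last entry.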
Noise Prediction
For a given noise

$$
{\boldsymbol\epsilon}\sim{\cal N}({\bold 0},{\bold I})
$$

we have

$$
{\bold x}_t({\bold x}_0,{\boldsymbol\epsilon})=\sqrt{\bar\alpha_t}\cdot{\bold x}_0+\sqrt{1-\bar\alpha_t}\cdot{\boldsymbol\epsilon}
\Longrightarrow
{\bold x}_0=\frac{{\bold x}_t({\bold x}_0,{\boldsymbol\epsilon})-\sqrt{1-\bar\alpha_t}\cdot{\boldsymbol\epsilon}}{\sqrt{\bar\alpha_t}}
$$

thus, w.r.t. ${\bold x}_t$ and $\boldsymbol\epsilon$

$$
\begin{aligned}
\tilde{\boldsymbol\mu}_t\left({\bold x}_t,\frac{{\bold x}_t-\sqrt{1-\bar\alpha_t}\cdot{\boldsymbol\epsilon}}{\sqrt{\bar\alpha_t}}\right)
&=\frac{\sqrt{\alpha_t}\left(1-\bar\alpha_{t-1}\right)}{1-\bar\alpha_t}\cdot{\bold x}_t+\frac{\beta_t\sqrt{\bar\alpha_{t-1}}}{1-\bar\alpha_t}\cdot\frac{{\bold x}_t-\sqrt{1-\bar\alpha_t}\cdot{\boldsymbol\epsilon}}{\sqrt{\bar\alpha_t}} \\
&=\frac{\alpha_t-\bar\alpha_t+\beta_t}{\left(1-\bar\alpha_t\right)\sqrt{\alpha_t}}\cdot{\bold x}_t-\frac{\beta_t}{\sqrt{1-\bar\alpha_t}\sqrt{\alpha_t}}\cdot{\boldsymbol\epsilon} \\
&=\frac{1}{\sqrt{\alpha_t}}\left({\bold x}_t-\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\cdot{\boldsymbol\epsilon}\right)
\end{aligned}
$$

parameterize the noise as a neural network

$$
{\boldsymbol\epsilon}={\boldsymbol\epsilon}_\theta({\bold x}_t,t)
$$

finally

$$
{\boldsymbol\mu}_\theta\left({\bold x}_t,t\right)=\frac{1}{\sqrt{\alpha_t}}\left({\bold x}_t-\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\cdot{\boldsymbol\epsilon}_\theta({\bold x}_t,t)\right)
$$
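The substitution above can be verified numerically: build ${\bold x}_t$ from a known ${\bold x}_0$ and ${\boldsymbol\epsilon}$, then confirm that the posterior mean written in terms of ${\bold x}_0$ agrees with the form written in terms of ${\boldsymbol\epsilon}$. This is a small sketch under an assumed linear schedule; `mu_tilde` and `mu_from_eps` are hypothetical helper names, and the true noise stands in for the network output ${\boldsymbol\epsilon}_\theta$.

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)     # assumed linear schedule (illustrative)
alphas = 1.0 - betas
alpha_bar = np.cumprod(alphas)         # alpha_bar[t - 1] stores \bar\alpha_t

def mu_tilde(x_t, x0, t):
    """Posterior mean expressed with x_t and x_0 (2 <= t <= T)."""
    a_t, b_t = alphas[t - 1], betas[t - 1]
    ab_t, ab_prev = alpha_bar[t - 1], alpha_bar[t - 2]
    return (np.sqrt(a_t) * (1.0 - ab_prev) * x_t + b_t * np.sqrt(ab_prev) * x0) / (1.0 - ab_t)

def mu_from_eps(x_t, eps, t):
    """Same mean expressed with x_t and the noise eps."""
    a_t, b_t, ab_t = alphas[t - 1], betas[t - 1], alpha_bar[t - 1]
    return (x_t - b_t / np.sqrt(1.0 - ab_t) * eps) / np.sqrt(a_t)

rng = np.random.default_rng(0)
t = 300
x0 = rng.standard_normal(4)
eps = rng.standard_normal(4)           # true noise, standing in for eps_theta(x_t, t)
x_t = np.sqrt(alpha_bar[t - 1]) * x0 + np.sqrt(1.0 - alpha_bar[t - 1]) * eps
gap = np.max(np.abs(mu_tilde(x_t, x0, t) - mu_from_eps(x_t, eps, t)))
```

The two expressions differ only by floating-point rounding, so `gap` should be negligible.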
Loss Function
Recap

$$
p_\theta\left({\bold x}_{t-1}\middle\vert{\bold x}_t\right)={\cal N}\left({\bold x}_{t-1};{\boldsymbol\mu}_\theta\left({\bold x}_t,t\right),\sigma_t^2{\bold I}\right)
$$

where

$$
\sigma_t^2=\beta_t \ {\rm or} \ \tilde\beta_t
$$

using KL divergence

$$
\begin{aligned}
{\cal L}_{t-1}
&=\mathop{\rm KL}\left(q\left({\bold x}_{t-1}\middle\vert{\bold x}_t,{\bold x}_0\right)\middle\Vert p_\theta\left({\bold x}_{t-1}\middle\vert{\bold x}_t\right)\right) \\
&={\bf E}_q\left[\frac{1}{2\sigma_t^2}\left\Vert\tilde{\boldsymbol\mu}_t({\bold x}_t,{\bold x}_0)-{\boldsymbol\mu}_\theta\left({\bold x}_t,t\right)\right\Vert^2\right] \\
&={\bf E}_{{\bold x}_0,{\boldsymbol\epsilon}}\left[\frac{1}{2\sigma_t^2}\left\Vert\frac{1}{\sqrt{\alpha_t}}\left({\bold x}_t-\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\cdot{\boldsymbol\epsilon}\right)-\frac{1}{\sqrt{\alpha_t}}\left({\bold x}_t-\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\cdot{\boldsymbol\epsilon}_\theta({\bold x}_t,t)\right)\right\Vert^2\right] \\
&={\bf E}_{{\bold x}_0,{\boldsymbol\epsilon}}\left[\frac{\beta_t^2}{2\sigma_t^2\alpha_t(1-\bar\alpha_t)}\left\Vert{\boldsymbol\epsilon}-{\boldsymbol\epsilon}_\theta\left(\sqrt{\bar\alpha_t}\cdot{\bold x}_0+\sqrt{1-\bar\alpha_t}\cdot{\boldsymbol\epsilon},t\right)\right\Vert^2\right]
\end{aligned}
$$

The second equality holds up to an additive constant that does not depend on $\theta$ (the constant vanishes when $\sigma_t^2=\tilde\beta_t$).

a simplified version (w/ no coefficient)

$$
{\cal L}_{\rm simp}={\bf E}_{{\bold x}_0,{\boldsymbol\epsilon}}\left[\left\Vert{\boldsymbol\epsilon}-{\boldsymbol\epsilon}_\theta\left(\sqrt{\bar\alpha_t}\cdot{\bold x}_0+\sqrt{1-\bar\alpha_t}\cdot{\boldsymbol\epsilon},t\right)\right\Vert^2\right]
$$
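A Monte Carlo estimate of ${\cal L}_{\rm simp}$ on a batch might look like the sketch below. Several things here are assumptions rather than part of the derivation: the linear schedule, drawing $t$ uniformly from $\{1,\dots,T\}$ for each sample, the function name `simple_loss`, and the placeholder `zero_model`, which stands in for a real network ${\boldsymbol\epsilon}_\theta$ and simply predicts zero noise.

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)     # assumed linear schedule (illustrative)
alpha_bar = np.cumprod(1.0 - betas)    # alpha_bar[t - 1] stores \bar\alpha_t

def simple_loss(eps_model, x0, rng):
    """One Monte Carlo estimate of L_simp for a batch x0 of shape (B, D)."""
    batch = x0.shape[0]
    t = rng.integers(1, T + 1, size=batch)           # assumed: t uniform on {1..T}
    eps = rng.standard_normal(x0.shape)              # eps ~ N(0, I)
    a = alpha_bar[t - 1][:, None]                    # broadcast over feature dim
    x_t = np.sqrt(a) * x0 + np.sqrt(1.0 - a) * eps   # closed-form forward sample
    err = eps - eps_model(x_t, t)
    return np.mean(np.sum(err ** 2, axis=1))         # squared norm, batch average

def zero_model(x_t, t):
    """Placeholder for eps_theta: always predicts zero noise (not a real model)."""
    return np.zeros_like(x_t)

rng = np.random.default_rng(0)
x0 = rng.standard_normal((4096, 8))                  # toy batch (assumption)
loss = simple_loss(zero_model, x0, rng)
```

With the zero predictor the error is the noise itself, so the loss should sit near the feature dimension (8 here); a trained ${\boldsymbol\epsilon}_\theta$ would drive it lower.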