GAN介绍

本文深入解析了生成对抗网络(GAN)的基本思想与工作原理,包括Generator生成器和Discriminator判别器的相互作用机制,以及如何通过迭代优化实现两者间的纳什均衡。同时,讨论了GAN训练中常见的JS散度缺陷及解决方案,并提供了使用PyTorch实现GAN的代码示例。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

What I cannot create, I do not understand.
我没有创造出它,说明我对它没有理解透彻

GAN的思想

GAN是由一个Generator生成器和Discriminator判别器来组成,基本思想如下:
G ∗ = arg ⁡ min ⁡ G max ⁡ D V ( G , D ) G^{*}=\arg \min _{G} \max _{D} V(G, D) G=argGminDmaxV(G,D)
对于G来说,需要最小化 G ∗ G^* G,希望差异越小越好
对于D来说,需要最大化 G ∗ G^* G,希望差异越大越好
其中,V表示:
V = E x ∼ P d a t a [ log ⁡ D ( x ) ] + E x ∼ P G [ log ⁡ ( 1 − D ( x ) ) ] V = E_{x \sim P_{d a t a}}[\log D(x)]+E_{x \sim P_{G}}[\log (1-D(x))]\\ V=ExPdata[logD(x)]+ExPG[log(1D(x))]
V就代表了 P G P_G PG P d a t a P_data Pdata的分布差异,换成概率积分的形式:
V = ∫ [ P d a t a ( x ) log ⁡ D ( x ) + P G ( x ) log ⁡ ( 1 − D ( x ) ) ] d x V=\int\left[P_{d a t a}(x) \log D(x)+P_{G}(x) \log (1-D(x))\right] d x V=[Pdata(x)logD(x)+PG(x)log(1D(x))]dx
如何确定最优的G和D呢?
首先当给定G时,最优 D ∗ D^* D是:
D ∗ ( x ) = P data ( x ) P data ( x ) + P G ( x ) D^{*}(x)=\frac{P_{\text {data}}(x)}{P_{\text {data}}(x)+P_{G}(x)} D(x)=Pdata(x)+PG(x)Pdata(x)
此时:
max ⁡ D V ( G , D ) = = − 2 log ⁡ 2 + 2 D J S ( P data ( x ) ∥ P G ( x ) ) \max _{D} V(G, D)==-2 \log 2+2 D_{JS}\left(P_{\text {data}}(x) \| P_{G}(x)\right) DmaxV(G,D)=2log2+2DJS(Pdata(x)PG(x))
即代表了两个分布的差异,且当两个概率接近0时,v的值接近于0
然后,当给定D时,确定最优的 G ∗ G^* G:
L ( G , D ∗ ) = 2 D J S ( p r ∥ p g ) − 2 log ⁡ 2 L\left(G, D^{*}\right)=2 D_{J S}\left(p_{r} \| p_{g}\right)-2 \log 2 L(G,D)=2DJS(prpg)2log2
p r = p g p_r=p_g pr=pg时, D J S D_{JS} DJS等于0,此时差异为0,G最优
最后不断更新G和D的参数,使得G和D一代代的进化,最终得到一个很好的G使得D的判断陷入纳什均衡(Nash Equilibrium)

JS散度缺陷

当初始时, p r p_r pr p g p_g pg的分布完全不重合时,此时的 D J S D_{JS} DJS就是很大且保持不变,此时求导会出现梯度离散现象,从而导致GAN在早期可能会出现的训练不稳定的现象。
在这里插入图片描述
解决方法:

  • 采用Wasserstein距离,Wasserstein距离用来度量两个概率分布之间的距离。所以提出了WGAN网络:
    在这里插入图片描述
    但是因为要求;
    ∣ f ( x 1 ) − f ( x 2 ) ∣ ≤ ∣ x 1 − x 2 ∣ \left|f\left(x_{1}\right)-f\left(x_{2}\right)\right| \leq\left|x_{1}-x_{2}\right| f(x1)f(x2)x1x2
    所以考虑增加惩罚项,提出了WAN_GP网络:
    在这里插入图片描述

GAN实际操作

from    torch import nn, optim
import  numpy as np
h_dim = 400
batchsz = 512

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(2, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, 2),
        )
    def forward(self, z):
        output = self.net(z)
        return output
        
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(2, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, 1),
            nn.Sigmoid()
        )
    def forward(self, x):
        output = self.net(x)
        return output.view(-1)

def data_generator():
    scale = 2.
    centers = [
        (1, 0),
        (-1, 0),
        (0, 1),
        (0, -1),
        (1. / np.sqrt(2), 1. / np.sqrt(2)),
        (1. / np.sqrt(2), -1. / np.sqrt(2)),
        (-1. / np.sqrt(2), 1. / np.sqrt(2)),
        (-1. / np.sqrt(2), -1. / np.sqrt(2))
    ]
    centers = [(scale * x, scale * y) for x, y in centers]
    while True:
        dataset = []
        for i in range(batchsz):
            point = np.random.randn(2) * .02
            center = random.choice(centers)
            point[0] += center[0]
            point[1] += center[1]
            dataset.append(point)
        dataset = np.array(dataset, dtype='float32')
        dataset /= 1.414  # stdev
        yield dataset
        
def weights_init(m):
    if isinstance(m, nn.Linear):
        # m.weight.data.normal_(0.0, 0.02)
        nn.init.kaiming_normal_(m.weight)
        m.bias.data.fill_(0)

def main():
    np.random.seed(23)
    G = Generator()
    D = Discriminator()
    G.apply(weights_init)
    D.apply(weights_init)
    optim_G = optim.Adam(G.parameters(), lr=1e-3, betas=(0.5, 0.9))
    optim_D = optim.Adam(D.parameters(), lr=1e-3, betas=(0.5, 0.9))
    data_iter = data_generator()
    for epoch in range(50000):
        # 1. train discriminator for k steps
        for _ in range(5):
            x = next(data_iter)
            xr = torch.from_numpy(x)
            predr = (D(xr))
            lossr = - (predr.mean())
            z = torch.randn(batchsz, 2)
            xf = G(z).detach()
            predf = (D(xf))
            lossf = (predf.mean())
            loss_D = lossr + lossf 
            optim_D.zero_grad()
            loss_D.backward()
            optim_D.step()
        # 2. train Generator
        z = torch.randn(batchsz, 2)
        xf = G(z)
        predf = (D(xf))
        loss_G = - (predf.mean())
        optim_G.zero_grad()
        loss_G.backward()
        optim_G.step()

if __name__ == '__main__':
    main()

AE和GAN的比较

  • GAN只能输入随机噪声来产生特定图片,但AE却可以通过指定不同的输入噪声分布来生成不同类型的图片。
    AE生成的图片只是尽可能的像真的,但实际可能不是真的内容;GAN生成的图片是具有真实内容的
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值