对抗神经网络GAN中d_loss g_loss两种更新参数的图解释

本文探讨了在对抗神经网络GAN中,d_loss和g_loss的反向传播过程中计算图的保留问题。若先计算d_loss,需使用retain_graph=True或detach()避免影响g_loss;反之,若先计算g_loss,detach()fake_img即可不影响d_loss的计算。参考了相关知乎文章进行深入解析。
部署运行你感兴趣的模型镜像

版权归属:

更多关注:

在这里插入图片描述

  • 如果先计算d_loss,在d_loss.backward()后会默认自动释放掉【real_img -> G -> fake_img -> D】这个计算图,但是在执行g_loss.backward()时需要【real_img -> G -> fake_img】这一段的计算图,所以会报告retain graph的错误,解决办法:d_loss.backward(retain_graph=True),或者在计算d_loss时,使用fake_img.detach(),这样在d_loss.backward()时只会释放【fake_img -> D】这段计算图,不会释放【real_img -> G -> fake_img】这段计算图,进而影响g_loss.backward();
  • 如果先计算g_loss,在g_loss.backward()后会默认自动释放掉【real_img -> G -> fake_img】这个计算图,在计算d_loss时只需要将fake_img.detach()【将fake_img与计算图“脱钩”】就不会影响d_loss.backward()的计算,因为更新只需要【fake_img -> D】这一段计算图。

参考

您可能感兴趣的与本文相关的镜像

GPT-oss:20b

GPT-oss:20b

图文对话
Gpt-oss

GPT OSS 是OpenAI 推出的重量级开放模型,面向强推理、智能体任务以及多样化开发场景

### 对抗神经网络 (GAN) 的代码实现 对抗神经网络GAN)的实现通常涉及定义生成器和判别器两个部分,并设置它们之间的对抗训练机制。下面是一个基于 PyTorch 实现的手写数字图片生成的 GAN 示例[^1]。 #### 导入必要的库 ```python import torch from torch import nn, optim from torchvision import datasets, transforms from torch.utils.data import DataLoader ``` #### 定义生成器 生成器的任务是从随机噪声向量生成看似真实的图像。 ```python class Generator(nn.Module): def __init__(self, input_size=100, hidden_dim=128, output_size=(1, 28, 28)): super(Generator, self).__init__() # 构建多层感知机结构 self.model = nn.Sequential( nn.Linear(input_size, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim * 2), nn.ReLU(), nn.Linear(hidden_dim * 2, int(torch.prod(torch.tensor(output_size)))), nn.Tanh() ) def forward(self, z): img_flat = self.model(z) img = img_flat.view(img_flat.size(0), *output_size) return img ``` #### 定义判别器 判别器负责评估传入的数据是否为真实样本或是来自生成器的伪造品。 ```python class Discriminator(nn.Module): def __init__(self, input_size=(1, 28, 28), hidden_dim=128): super(Discriminator, self).__init__() # 设计一个多层感知机架构 self.model = nn.Sequential( nn.Linear(int(torch.prod(torch.tensor(input_size))), hidden_dim*2), nn.LeakyReLU(0.2), nn.Dropout(0.3), nn.Linear(hidden_dim*2, hidden_dim), nn.LeakyReLU(0.2), nn.Dropout(0.3), nn.Linear(hidden_dim, 1), nn.Sigmoid() ) def forward(self, imgs): logits = self.model(imgs.view(imgs.size(0), -1)) return logits.squeeze() ``` #### 训练循环 此段代码展示了如何交替更新生成器和判别器参数,以促进两者的竞争关系。 ```python def train_gan(generator, discriminator, dataloader, num_epochs=200, lr=0.0002): criterion = nn.BCELoss() # 使用二元交叉熵损失函数 optimizer_G = optim.Adam(generator.parameters(), lr=lr) optimizer_D = optim.Adam(discriminator.parameters(), lr=lr) fixed_noise = torch.randn((64, 100)) for epoch in range(num_epochs): for i, (imgs, _) in enumerate(dataloader): real_labels = torch.ones(imgs.shape[0]) fake_labels = torch.zeros(imgs.shape[0]) # 更新 D 网络: max log(D(x)) + log(1 - D(G(z))) optimizer_D.zero_grad() outputs = discriminator(imgs).squeeze() d_loss_real = criterion(outputs, real_labels) real_score = outputs noise = torch.randn(imgs.shape[0], 100) fake_images = generator(noise) outputs = discriminator(fake_images.detach()).squeeze() d_loss_fake = criterion(outputs, fake_labels) fake_score = outputs d_loss = d_loss_real + d_loss_fake d_loss.backward() optimizer_D.step() # 更新 G 网络: min log(1 - D(G(z))) <-> max log(D(G(z)) optimizer_G.zero_grad() outputs = discriminator(fake_images).squeeze() g_loss = criterion(outputs, real_labels) g_loss.backward() optimizer_G.step() print(f'Epoch [{epoch}/{num_epochs}], ' f'd_loss: {d_loss.item():.4f}, ' f'g_loss: {g_loss.item():.4f}') ``` 上述代码片段提供了完整的 GAN 模型构建流程以及基本的训练逻辑[^4]。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值