### Code Implementation of a Generative Adversarial Network (GAN)
Implementing a generative adversarial network (GAN) typically involves defining two components, a generator and a discriminator, and setting up an adversarial training scheme between them. Below is a PyTorch example of a GAN that generates handwritten-digit images[^1].
#### Importing the required libraries
```python
import torch
from torch import nn, optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
```
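The imports above bring in `datasets`, `transforms`, and `DataLoader`, but the data pipeline itself is not shown in the original snippet. Below is a minimal sketch of how an MNIST `DataLoader` could be prepared for this example; the normalization to [-1, 1] matches the generator's `Tanh` output, and the `batch_size` of 64 is an illustrative assumption.
```python
# Normalize MNIST pixels to [-1, 1] so real images match the generator's Tanh range
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

mnist = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(mnist, batch_size=64, shuffle=True)  # batch size is an assumption
```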
#### Defining the generator
The generator's task is to produce realistic-looking images from random noise vectors.
```python
class Generator(nn.Module):
    def __init__(self, input_size=100, hidden_dim=128, output_size=(1, 28, 28)):
        super(Generator, self).__init__()
        self.output_size = output_size  # saved so forward() can reshape the flat output
        # Multilayer perceptron mapping a noise vector to a flattened image
        self.model = nn.Sequential(
            nn.Linear(input_size, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.ReLU(),
            nn.Linear(hidden_dim * 2, int(torch.prod(torch.tensor(output_size)))),
            nn.Tanh()  # squash pixel values into [-1, 1]
        )

    def forward(self, z):
        img_flat = self.model(z)
        img = img_flat.view(img_flat.size(0), *self.output_size)
        return img
```
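As a quick sanity check (not part of the original example), the generator can be driven with a batch of random noise; with the defaults above, a 100-dimensional latent vector yields a 1×28×28 image:
```python
generator = Generator()
z = torch.randn(16, 100)             # batch of 16 latent vectors
samples = generator(z)
print(samples.shape)                 # torch.Size([16, 1, 28, 28])
print(samples.min(), samples.max())  # values lie in [-1, 1] because of Tanh
```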
#### Defining the discriminator
The discriminator evaluates whether an input is a real sample or a fake produced by the generator.
```python
class Discriminator(nn.Module):
    def __init__(self, input_size=(1, 28, 28), hidden_dim=128):
        super(Discriminator, self).__init__()
        # Multilayer perceptron mapping a flattened image to a real/fake probability
        self.model = nn.Sequential(
            nn.Linear(int(torch.prod(torch.tensor(input_size))), hidden_dim * 2),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()  # outputs a probability in (0, 1), not a raw logit
        )

    def forward(self, imgs):
        validity = self.model(imgs.view(imgs.size(0), -1))
        return validity.squeeze()
```
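The discriminator can be exercised the same way; with untrained weights its outputs are essentially arbitrary probabilities near 0.5 (again, a usage sketch rather than part of the original code):
```python
discriminator = Discriminator()
scores = discriminator(samples)  # reuses the generated batch from the sketch above
print(scores.shape)              # torch.Size([16]): one real/fake probability per image
```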
#### Training loop
The following code shows how to alternately update the generator and discriminator parameters, driving the adversarial competition between them.
```python
def train_gan(generator, discriminator, dataloader, num_epochs=200, lr=0.0002):
    criterion = nn.BCELoss()  # binary cross-entropy loss
    optimizer_G = optim.Adam(generator.parameters(), lr=lr)
    optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)
    # Fixed noise can be passed through the generator periodically to inspect progress
    fixed_noise = torch.randn((64, 100))
    for epoch in range(num_epochs):
        for i, (imgs, _) in enumerate(dataloader):
            real_labels = torch.ones(imgs.shape[0])
            fake_labels = torch.zeros(imgs.shape[0])

            # Update D: maximize log(D(x)) + log(1 - D(G(z)))
            optimizer_D.zero_grad()
            outputs = discriminator(imgs)
            d_loss_real = criterion(outputs, real_labels)
            noise = torch.randn(imgs.shape[0], 100)
            fake_images = generator(noise)
            # detach() so this pass does not backpropagate into the generator
            outputs = discriminator(fake_images.detach())
            d_loss_fake = criterion(outputs, fake_labels)
            d_loss = d_loss_real + d_loss_fake
            d_loss.backward()
            optimizer_D.step()

            # Update G: minimize log(1 - D(G(z))), i.e. maximize log(D(G(z)))
            optimizer_G.zero_grad()
            outputs = discriminator(fake_images)
            g_loss = criterion(outputs, real_labels)
            g_loss.backward()
            optimizer_G.step()
        print(f'Epoch [{epoch}/{num_epochs}], '
              f'd_loss: {d_loss.item():.4f}, '
              f'g_loss: {g_loss.item():.4f}')
```
The snippets above provide the full GAN model-construction pipeline along with its basic training logic[^4].
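Putting the pieces together, a complete run could look like the sketch below; it assumes the `dataloader` from the data-loading sketch earlier, and the reduced epoch count is purely illustrative:
```python
if __name__ == '__main__':
    generator = Generator()
    discriminator = Discriminator()
    # `dataloader` as prepared in the MNIST sketch above
    train_gan(generator, discriminator, dataloader, num_epochs=50)

    # After training, sample a grid of digits from fresh noise
    with torch.no_grad():
        digits = generator(torch.randn(64, 100))
```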