生成式对抗网络(GAN,generative adversarial network)的简单理解就是,想想一名伪造者试图伪造一幅毕加索的画作。一开始,伪造者非常不擅长这项任务,他随便画了幅与毕加索真迹放在一起,请鉴定商进行评估,鉴定商鉴定后,将结果反馈给伪造者,并告诉他怎样可以让❀看起来更像毕加索的真迹。伪造者学习后回去重新画,然后再拿给鉴定商鉴定,多次循环后,伪造者已经十分熟练的伪造毕加索的画作了,鉴定商的鉴定能力也有了很大的提高。最后,他们手上拥有了一些优秀的毕加索赝品。
下边的例子是使用GAN模拟二次方程:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
"""超参数"""
BATCH_SIZE = 64 # 每批数据个数
LR_G = 0.0001 # 生成器的学习率(伪造者)
LR_D = 0.0001 # 判别器的学习率(鉴定商)
N_IDEAS = 5 # 随机想法个数
ART_COMPONENTS = 15 # 线段上数据点个数
PAINT_POINTS = np.vstack([np.linspace(-1, 1, ART_COMPONENTS) for _ in range(BATCH_SIZE)]) # h恩坐标范围
"""创建圣经网络"""
def artist_works():
a = np.random.uniform(1, 2, size=BATCH_SIZE)[:, np.newaxis]
paintings = a * np.power(PAINT_POINTS, 2) + (a - 1) # 定义原始一元二次方程(一开始的真品)
paintings = torch.from_numpy(paintings).float() # 转换为torch形式
return paintings
G = nn.Sequential( # 生成器(伪造者)
nn.Linear(N_IDEAS, 128), # 输入随即想法
nn.ReLU(),
nn.Linear(128, ART_COMPONENTS) # 生成15个点连线(创造一个赝品)
)
D = nn.Sequential( # 判别器(鉴定上)
nn.Linear(ART_COMPONENTS, 128), # 接受生成器生成的数据(获得赝品)
nn.ReLU(),
nn.Linear(128, 1), # 判别是否和原始数据相似(鉴定赝品是真是假)
nn.Sigmoid() # 产生百分比,表示是什么数据(表示是真品还是赝品)
)
opt_D = torch.optim.RMSprop(D.parameters(), lr=LR_D) # 优化判别器
opt_G = torch.optim.RMSprop(G.parameters(), lr=LR_G) # 优化生成器
"""训练神经网络"""
plt.ion()
for step in range(5000):
artist_paintings = artist_works() # 先获取原始标准方程(一开始的真品)
G_ideas = torch.randn(BATCH_SIZE, N_IDEAS) # 随机生成数据(想法)
G_paintings = G(G_ideas) # 生成器产生方程(创造赝品)
prob_artist0 = D(artist_paintings) # 计算式标准方程的概率(真品的概率)
prob_artist1 = D(G_paintings) # 计算式伪造方程的概率(赝品的概率)
D_loss = -torch.mean(torch.log(prob_artist0) + torch.log(1 - prob_artist1)) # 增加标准方程的概率
G_loss = torch.mean(torch.log(1 - prob_artist1)) # 增加伪造方程被认为是真方程的概率
opt_D.zero_grad()
D_loss.backward(retain_graph=True)
opt_D.step()
opt_G.zero_grad()
G_loss.backward()
opt_G.step()
"""循环打印"""
if step % 100 == 0: # 每100步打印一次
plt.cla() # 清空上一次的
plt.plot(PAINT_POINTS[0], G_paintings.data.numpy()[0], c='#4AD631', lw=3,
label='Generated painting', ) # 随机生成的方程
plt.plot(PAINT_POINTS[0], 2 * np.power(PAINT_POINTS[0], 2) + 1, c='#74BCFF', lw=3,
label='upper bound') # 标准方程上界
plt.plot(PAINT_POINTS[0], 1 * np.power(PAINT_POINTS[0], 2) + 0, c='#FF9359', lw=3,
label='lower bound') # 标准方程下界
plt.ylim((0, 3))
plt.legend(loc='upper right', fontsize=10)
plt.draw()
plt.pause(0.01)
plt.ioff()
plt.show()