GAN
对抗生成网络
生成器:生成假的数据,骗过判别器
判别器:将假数据判别为假,真实数据判别为真。
流程:
1.导入数据
2.网络架构:
输入层:待生成图像(噪音)和真实数据
生成网络:将噪音图像进行生成
判别网络:
(1)判断真实图像输出结果 1
(2)判断生成图像输出结果 0
3.目标函数:
(1):对于生成网络,要使得 生成结果 通过判别网络为真
(2):对于判别网络,要使得输入为真实图像时判别为真,输入为生成图像时判别为假
判别模型:log(D1(x)) + log(1-D2(G(z)))
生成网络:log(D2(G(z)))
代码
BCELoss:
import torch #交叉熵和BCEloss差不多
from torch import autograd
#随机生成一个数据作为预测值
input = autograd.Variable(torch.tensor([[ 1.9072, 1.1079, 1.4906],
[-0.6584, -0.0512, 0.7608],
[-0.0614, 0.6583, 0.1095]]), requires_grad=True)
print(input)
print('-'*100)
#因为得到的数据有正有负,用sigmoid函数映射一下,变成规范一点的0-1之间的预测结果
from torch import nn
m = nn.Sigmoid()
print(m(input))
print('-'*100)
#构建标签 想办法让判别器判断真时预测1,判别假时预测0
target = torch.FloatTensor([[0, 1, 1], [1, 1, 1], [0, 0, 0]])
print(target)
print('-'*100)
import math
#损失函数 左边当作正确预测类别,右边为错误预测类别
r11 = 0 * math.log(0.8707) + (1-0) * math.log((1 - 0.8707))
r12 = 1 * math.log(0.7517) + (1-1) * math.log((1 - 0.7517))
r13 = 1 * math.log(0.8162) + (1-1) * math.log((1 - 0.8162))
r21 = 1 * math.log(0.3411) + (1-1) * math.log((1 - 0.3411))
r22 = 1 * math.log(0.4872) + (1-1) * math.log((1 - 0.4872))
r23 = 1 * math.log(0.6815) + (1-1) * math.log((1 - 0.6815))
r31 = 0 * math.log(0.4847) + (1-0) * m