GAN基础代码

GAN

对抗生成网络

 

生成器:生成假的数据,骗过判别器

判别器:将假数据判别为假,真实数据判别为真。

流程:

1.导入数据

2.网络架构:

输入层:待生成图像(噪音)和真实数据

生成网络:将噪音图像进行生成

判别网络

(1)判断真实图像输出结果 1

(2)判断生成图像输出结果 0

3.目标函数:

(1):对于生成网络,要使得 生成结果 通过判别网络为真

(2):对于判别网络,要使得输入为真实图像时判别为真,输入为生成图像时判别为假


判别模型:log(D1(x)) + log(1-D2(G(z)))

生成网络:log(D2(G(z)))
 

代码

BCELoss:

import torch    #交叉熵和BCEloss差不多
from torch import autograd
#随机生成一个数据作为预测值
input = autograd.Variable(torch.tensor([[ 1.9072,  1.1079,  1.4906],
        [-0.6584, -0.0512,  0.7608],
        [-0.0614,  0.6583,  0.1095]]), requires_grad=True)
print(input)
print('-'*100)
#因为得到的数据有正有负,用sigmoid函数映射一下,变成规范一点的0-1之间的预测结果
from torch import nn
m = nn.Sigmoid()
print(m(input))
print('-'*100)

#构建标签  想办法让判别器判断真时预测1,判别假时预测0
target = torch.FloatTensor([[0, 1, 1], [1, 1, 1], [0, 0, 0]])
print(target)
print('-'*100)

import math
#损失函数   左边当作正确预测类别,右边为错误预测类别
r11 = 0 * math.log(0.8707) + (1-0) * math.log((1 - 0.8707))
r12 = 1 * math.log(0.7517) + (1-1) * math.log((1 - 0.7517))
r13 = 1 * math.log(0.8162) + (1-1) * math.log((1 - 0.8162))

r21 = 1 * math.log(0.3411) + (1-1) * math.log((1 - 0.3411))
r22 = 1 * math.log(0.4872) + (1-1) * math.log((1 - 0.4872))
r23 = 1 * math.log(0.6815) + (1-1) * math.log((1 - 0.6815))

r31 = 0 * math.log(0.4847) + (1-0) * m
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值