import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
引入需要的库,最后两行是因为在使用Jupyter notebook 出现内核已挂掉的问题,导致无法运行,查找资料后添加了最后两行代码。
transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(0.5,0.5)])
对数据做归一化处理,transforms.Compose()把多个步骤整合在一起,即totensor,Normalize两步骤整合在一起transform.ToTensor()ToTensor()将shape为(H, W, C)的nump.ndarray或img转为shape为(C, H, W)的tensor,其将每一个数值归一化到[0,1],其归一化方法比较简单,直接除以255即可。transforms.Normalize()把数值[0,1]转换成[-1,1]。
train_ds=torchvision.datasets.MNIST('data',train=True,transform=transform,download=True)
下载数据集MNIST
root (string): 表示数据集的根目录,其中根目录存在MNIST/processed/training.pt和MNIST/processed/test.pt的子目录
train (bool, optional): 如果为True,则从training.pt创建数据集,否则从test.pt创建数据集
download (bool, optional): 如果为True,则从internet下载数据集并将其放入根目录。如果数据集已下载,则不会再次下载
transform (callable, optional): 接收PIL图片并返回转换后版本图片的转换函数
dataloader=torch.utils.data.DataLoader(train_ds,batch_size=64,shuffle=True)
dataset:包含所有数据的数据集
batch_size:批量训练数据量的大小
shuffle:洗牌,是否打乱数据。
class Generator(nn.Module):
def __init__(self):
super(Generator,self).__init__()
self.main=nn.Sequential(
nn.Linear(100,256),
nn.ReLU(),
nn.Linear(256,512),
nn.ReLU(),
nn.Linear(512,784),
nn.Tanh()
)
def forward(self,x):
img = self.main(x)
img=img.view(-1,28,28)#转换成图片的形式
return img
上面是生成器的模型:
输入:100的符合正态分布的噪声。
全连接层1:100-256
激活函数
全连接层2:256-512
激活函数
全连接层3:512-784
激活函数
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator,self).__init__()
self.main = nn.Sequential(
nn.Linear(784,512),
nn.LeakyReLU(),
nn.Linear(512,256),
nn.LeakyReLU(),
nn.Linear(256,1),
nn.Sigmoid()
)
def forward(self,x):
x =x.view(-1,784) #展平
x =self.main(x)
return x
上面是判别器:
输入:1*28*28=784的图片,输出:一个概率值
全连接层1:784-512
激活函数
全连接层2:512-256
激活函数
全连接层:256-1
激活函数
device='cuda' if torch.cuda.is_available() else 'cpu'
配置设备:cuda或cpu
#初始化生成器和判别器把他们放到相应的设备上
gen = Generator().to(device)
dis = Discriminator().to(device)
#训练器的优化器
d_optim = torch.optim.Adam(dis.parameters(),lr=0.0001)
#训练生成器的优化器
g_optim = torch.optim.Adam(dis.parameters(),lr=0.0001)
#交叉熵损失函数
loss_fn = torch.nn.BCELoss()
def gen_img_plot(model,test_input):
prediction = np.squeeze(model(test_input).detach().cpu().numpy())
fig = plt.figure(figsize=(4,4))
for i in range(16):
plt.subplot(4,4,i+1)
plt.imshow((prediction[i]+1)/2)
plt.axis('off')
plt.show()
绘制图像
np.squeeze:在机器学习和深度学习中,通常算法的结果是可以表示向量的数组(即包含两对或以上的方括号形式[[]]),如果直接利用这个数组进行画图可能显示界面为空(见后面的示例)。我们可以利用squeeze()函数将表示向量的数组转换为秩为1的数组,这样利用matplotlib库函数画图时,就可以正常的显示结果了。
figure:创建新的图形对象,在屏幕上单独显示的窗口,且窗口中可以输出图形
subplot(numbRow,numbCol,plotNum):numbRow是plot图的行数;numbCol是plot图的列数;plotNum是指第几行第几列的第几幅图 .
test_input=torch.randn(16,100,device=device)
D_loss=[]
G_loss=[]
#训练循环
for epoch in range(20):
#初始化损失值
d_epoch_loss = 0
g_epoch_loss = 0
count = len(dataloader) #返回批次数
#对数据集进行迭代
for step,(img,_) in enumerate(dataloader):
img =img.to(device) #把数据放到设备上
size = img.size(0) #img的第一位是size,获取批次的大小
random_noise = torch.randn(size,100,device=device)
#判别器训练(真实图片的损失和生成图片的损失),损失的构建和优化
d_optim.zero_grad()#梯度归零
#判别器对于真实图片产生的损失
real_output = dis(img) #判别器输入真实的图片,real_output对真实图片的预测结果
d_real_loss = loss_fn(real_output,
torch.ones_like(real_output)
)
d_real_loss.backward()#计算梯度
#在生成器上去计算生成器的损失,优化目标是判别器上的参数
gen_img = gen(random_noise) #得到生成的图片
#因为优化目标是判别器,所以对生成器上的优化目标进行截断
fake_output = dis(gen_img.detach()) #判别器输入生成的图片,fake_output对生成图片的预测;detach会截断梯度,梯度就不会再传递到gen模型中了
#判别器在生成图像上产生的损失
d_fake_loss = loss_fn(fake_output,
torch.zeros_like(fake_output)
)
d_fake_loss.backward()
#判别器损失
d_loss = d_real_loss + d_fake_loss
#判别器优化
d_optim.step()
#生成器上损失的构建和优化
g_optim.zero_grad() #先将生成器上的梯度置零
fake_output = dis(gen_img)
g_loss = loss_fn(fake_output,
torch.ones_like(fake_output)
) #生成器损失
g_loss.backward()
g_optim.step()
#累计每一个批次的loss
with torch.no_grad():
d_epoch_loss +=d_loss
g_epoch_loss +=g_loss
#求平均损失
with torch.no_grad():
d_epoch_loss /=count
g_epoch_loss /=count
D_loss.append(d_epoch_loss)
G_loss.append(g_epoch_loss)
print('Epoch:',epoch)
gen_img_plot(gen,test_input)
在判别器的帮助下,生成器最后的结果图应该越来越接近一个数字,但这个简单的模型并不好,最后结果不明显。