上一节我们了解了自编码器的基础知识,这一节我们在pytorch上创建一个自编码器用于处理mnist数据集。
1.问题的描述
关于mnist数据集在这里就不再描述了,相信大家已经很熟悉了。今天使用自编码器有两个目的,即:
- 使用encoder显示压缩的特征数据。
- 使用decoder生成mnist数据集。
先对一些超参数进行设置:
# 定义一些超参数
EPOCH = 10
BATCH_SIZE = 64
LR = 0.005
DOWNLOAD_MNIST = True # 下过数据的话, 就可以设置成 False
N_TEST_IMG = 5 # 显示5张图片看效果
数据集准备过程和之前的也一样,不过我们这里不需要测试集,只需要训练集,而且只需要训练集的图片不需要标签。
2.AE的模型搭建
模型由两部分组成,编码器(encoder)和解码器(decoder),它们是两个相反的流程。encoder将28×28的数据压缩成3个数据,decoder将3个数据扩展为28×28个数据。
class AutoEncoder(nn.Module):
def __init__(self):
super(AutoEncoder,self).__init__()
# 编码器
self.encoder = nn.Sequential(
nn.Linear(28*28,128),
nn.Tanh(),
nn.Linear(128,64),
nn.Tanh(),
nn.Linear(64,12),
nn.Tanh(),
nn.Linear(12,3) # 压缩成3个特征, 进行 3D 图像可视化
)
# 解码器
self.decoder = nn.Sequential(
nn.Linear(3,12),
nn.Tanh(),
nn.Linear(12,64),
nn.Tanh(),
nn.Linear(64,128),
nn.Tanh(),
nn.Linear(128,28*28),
nn.Sigmoid() # 激励函数让输出值在 (0, 1)
)
def forward(self,x):
encodered = self.encoder(x)
decodered = self.decoder(encodered)
return encodered,decodered
网络模型的最后要输出两个值encodered,decodered,分别为编码器和解码器的输出值。要注意的是在decoder的最后一层激励函数是sigmoid,就是让输出在0-1区间,表示解压后像素的灰度值。
3.AE的结果展示
训练过程还是和之前有些差异的,因为AE进行的是无监督学习,因此它的X,Y数据用的是同一个:
b_x