手撕VAE(变分自编码器) -- Day3 -- train.py

手撕VAE(变分自编码器) – Day3 – train.py

Variational Auto Encoder (VAE) 模型原理

VAE 网络结构图

在这里插入图片描述

在这里插入图片描述

变分自编码器 网络结构

和AE的区别

  • Part1:隐空间(latent)不是直接利用神经网络生成的,而是利用神经网络生成均值和方差来采样生成隐空间。

    • 例如:原本是直接生成(1,10)的向量,但是现在是生成10个均值 u u u和10个标准差 σ \sigma σ(满足正态分布),然后通过采样生成10个向量值。
    • 存在的问题:采样的操作是无法进行梯度反向传播的,因此需要利用重参数化进行反向传播,来进行均值和方差的参数更新。
  • Part2:重参数化技巧,利用正态分布之间的转换关系来链接生成的均值,方差和生成的样本直接的关系。

    • 例子:对于(0,1)的标准正态分布z,和( u u u, σ \sigma σ)的正态分布x,满足z= x − u σ \frac{x-u}{\sigma} σxu,得到 x = u + σ z x=u+\sigma z x=u+σz。所以我们只需要随机生成一个标准正态分布z,乘上标准差,加上均值即可,这样就构建出了 u u u σ \sigma σ 与x之间的关系,也就可以梯度反传了(非常巧妙的方法)。
  • Part3:损失函数添加了一个KL散度,从原来的重建损失,添加了一个分布损失,用来保证生成隐空间的分布,满足标准正态分布,因此该损失主要是衡量生成分布和标准正态分布之间的差距。

    • VAE(变分自编码器)的损失函数由两部分组成:重构损失KL 散度

      1. 重构损失(Reconstruction Loss)

      重构损失衡量了生成的数据 ( \hat{x} ) 与输入数据 ( x ) 之间的差异,通常使用 均方误差(MSE)交叉熵 来衡量。对于连续数据,常用 MSE,对于离散数据,常用交叉熵。
      L reconstruction = E q ( z ∣ x ) [ ∥ x − x ^ ∥ 2 ] \mathcal{L}_{\text{reconstruction}} = \mathbb{E}_{q(z|x)}[\| x - \hat{x} \|^2] Lreconstruction=Eq(zx)[xx^2]
      这里:

      • ( x ^ \hat{x} x^ = p(x|z) ) 是解码器生成的样本。
      • ( x ) 是原始输入数据。
      • ( ∥ x − x ^ ∥ 2 \| x - \hat{x} \|^2 xx^2 ) 是输入数据与生成数据之间的均方误差。

      2. KL 散度(KL Divergence)

      KL 散度项衡量了潜在空间的分布 ( q(z|x) ) 与标准正态分布 ( p(z) ) 之间的差异。通过最小化 KL 散度,我们希望潜在空间的分布接近标准正态分布 ( \mathcal{N}(0, I) ),从而避免过拟合,并保持潜在空间的结构。

      KL 散度的公式为:
      $$

      \mathcal{L}{\text{KL}} = D{\text{KL}}[q(z|x) || p(z)]
      $$

      对于每个潜在维度 ( z_i ),KL 散度可以表示为:
      D KL [ N ( μ ( x ) , σ ( x ) 2 ) ∣ ∣ N ( 0 , 1 ) ] = 1 2 ( μ i 2 ( x ) + σ i 2 ( x ) − 1 − log ⁡ σ i 2 ( x ) ) D_{\text{KL}}[\mathcal{N}(\mu(x), \sigma(x)^2) || \mathcal{N}(0, 1)] = \frac{1}{2} \left( \mu_i^2(x) + \sigma_i^2(x) - 1 - \log \sigma_i^2(x) \right) DKL[N(μ(x),σ(x)2)∣∣N(0,1)]=21(μi2(x)+σi2(x)1logσi2(x))

      这里:

      • μ ( x ) \mu(x) μ(x) σ ( x ) \sigma(x) σ(x) 分别是编码器网络输出的潜在变量的均值和标准差。
      • N ( 0 , 1 ) \mathcal{N}(0, 1) N(0,1) 是标准正态分布。

VAE 训练代码 - 变分自编码器

VAE 训练代码

Part1 引入相关库函数

# 训练模型,该模块主要是为了实现对于模型的训练,
'''
# Part1 引入相关的库函数
'''

import torch
from torch import nn
from dataset import Mnist_dataset
from VAE import VAE
import torch.utils.data as data

Part2 初始化 VAE 的训练参数

'''
初始化一些训练参数
'''
EPOCH = 50
Mnist_dataloader = data.DataLoader(dataset=Mnist_dataset, batch_size=64, shuffle=True)

# 前向传播的模型
net = VAE(img_channel=1, img_size=28, encode_f1_size=400, latent_size=10)

# 计算损失函数,VAE和AE不同的点,还在于,需要计算正态分布之间的KL散度。

# 反向更新参数
lr = 1e-3
optim = torch.optim.Adam(params=net.parameters(), lr=lr)


# 定义VAE的损失函数,主要包含重建损失和KL散度
def vae_loss(rec_x, x, mu, sigma):
    # 首先是重建损失
    loss1 = nn.BCELoss(reduction='sum')
    rec_loss = loss1(rec_x, x)
    # 然后是KL散度的损失,也就是预测出来的均值和方差要满足标准正态分布(所以衡量的是标准正态分布和预测到的分布的差距和),这里是假设log以2为底
    KL_loss = -0.5 * torch.sum(1 + sigma - torch.pow(mu, 2) - torch.pow(sigma, 2))
    return rec_loss + KL_loss

Part3 训练

'''
# 开始训练
'''
# net.train() # 设置为训练模式

for epoch in range(EPOCH):
    n_iter = 0
    for batch_img, _ in Mnist_dataloader:
        # 先进行前向传播
        batch_img_pre, mu, sigma = net(batch_img)  #

        # 计算损失
        loss_cal = vae_loss(batch_img_pre, batch_img, mu, sigma)

        # 清除梯度
        optim.zero_grad()
        # 反向传播
        loss_cal.backward()
        # 更新参数
        optim.step()

        l = loss_cal.item()

        if n_iter % 100 == 0:
            print('此时的epoch为{},iter为{},loss为{}'.format(epoch, n_iter, l))

        n_iter += 1
    if epoch == 20:
        # 注意pt文件是保存整个模型及其参数的,pth文件只是保存参数
        torch.save(net.encode, 'VAE_encoder_eopch_{}.pt'.format(epoch))
        # 注意pt文件是保存整个模型及其参数的,pth文件只是保存参数
        torch.save(net.decode, 'VAE_decoder_eopch_{}.pt'.format(epoch))
        break

参考

视频讲解:DALL·E 2(内含扩散模型介绍)【论文精读】_哔哩哔哩_bilibili

模型原理讲解:自学资料 - Dalle2模型 - 文生图技术-优快云博客

github资料:YanxinTong/VAE_Pytorch: 利用 Pytorch 手撕 VAE 模型

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值