基于VQ-VAE的手写体生成

《DeepSeek大模型高性能核心技术与多模态融合开发(人工智能技术丛书)》(王晓华)【摘要 书评 试读】- 京东图书

对于VQ-VAE的应用,其核心功能在于能够将原本连续的图像数据转化为离散的token表示。这一过程不仅实现了图像的高效压缩与编码,还为后续的图像处理和分析提供了新的视角。通过这些离散的token,我们可以更灵活地处理和操作图像数据,例如进行图像的检索、分类、编辑等任务。

具体来说,VQ-VAE通过学习一个离散的潜在空间来表示图像,这个空间由一系列预定义的token(或称为码字)构成。在训练过程中,VQ-VAE会将输入图像编码到这个离散空间中,选择最接近的token来表示图像的局部特征。这样,原本由像素值构成的连续图像就被转换成了一系列离散的token。

在生成图像时,VQ-VAE则根据这些token来解码并重构图像。由于token的离散性,生成的图像虽然在细节上可能与原图有所差异,但整体上能够保留原图的主要特征和结构。这种离散化的表示方法不仅降低了数据的复杂性,还提高了生成模型的效率和可控性。

此外,VQ-VAE生成的离散token序列还可以作为其他模型(如Transformer等)的输入,从而进一步拓展其在图像处理领域的应用。例如,可以通过对这些token进行序列建模,生成具有特定风格或内容的图像,或者实现图像的补全和修复等功能。

本章我们将完成基于VQ-VAE的图像生成。

13.2.1  图像的准备与超参数设置

本章我们将要完成的图像的准备,即使用MNIST数据集完成编码器VQ-VAE的手写体的生成。首先是图像文本的获取,我们可以直接使用MNIST数据集而对图像内容进行提取,代码如下所示。

from torch.utils.data import Dataset
from torchvision.transforms.v2 import PILToTensor, Compose
import torchvision
from tqdm import tqdm
import torch
import random

# 手写数字
class MNIST(Dataset):
    def __init__(self, is_train=True):
        super().__init__()
        self.ds = torchvision.datasets.MNIST('../../dataset/mnist/', train=is_train, download=True)
        self.img_convert = Compose([
            PILToTensor(),
        ])

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, index):
        img, label = self.ds[index]

        #text = f"现在的数字是:{label}#"
        text = random.sample(sample_texts,1)[0][-10:]
        text = text + str(label) + "#"

        full_tok = tokenizer_emo.encode(text)[-12:]
        full_tok = full_tok + [1] * (12 - len(full_tok))

        inp_tok = full_tok[:-1]
        tgt_tok = full_tok[1:]

        inp_tok = torch.tensor(inp_tok)
        tgt_tok = torch.tensor(tgt_tok)

        """
        torch.Size([1, 28, 28])
        """
        return self.img_convert(img) / 255.0, inp_tok,tgt_tok

上面代码复用了文本生成部分的MNIST数据集,读者在具体使用时,可以根据需要自行对其进行调整。

接下来,我们需要完成图像的超参数设计。在求解VQ-VAE的过程中,我们设计了如下的参数内容:

class Config:
    in_channels = 1
    d_model = 384
    image_size = 28
    patch_size = 4
    num_heads = 6
    num_layers = 3

    token_dim = token_size = 256
    latent_token_vocab_size = num_latent_tokens = 32

    codebook_size = 4096

token_size是指在VQ-VAE模型中量化后每个token(或“码字”)在潜在空间的维度。编码器会将输入数据转化为潜在表示,进而量化至一个离散的潜在空间。此量化步骤是通过将潜在表示里的每个向量用最近的“码字”替代来实现的,这些“码字”源于一个预先设定的码本。token_size参数即确定了这些“码字”的维度大小。例如,若token_size为12,意味着每个“码字”是一个12维的向量。

num_latent_tokens则代表在VQ-VAE的潜在空间中使用的不同“码字”的数量,它决定了码本的大小,即码本中包含离散向量的数目。在模型中,码本是一个可学习的参数集合,存储了表示输入数据特征的“码字”。num_latent_tokens参数控制着码本的大小,并影响模型捕捉输入数据细节的能力。举例来说,若num_latent_tokens设为64,则码本包含64个不同的“码字”,在量化潜在表示时,模型会从这些“码字”中选择一个来替换每个向量。

13.2.2  VQ-VAE的编码器与解码器

本小节讲解VQ-VAE的编码器和解码器。首先是编码器的作用,编码器将连续的图像特征转化为具有特定数目的token表示,而解码器则是将token复原成图像。

1. 编码器

编码器(Encoder)是深度学习模型中常见的一个组件,特别是在处理图像、文本等类型的数据时。它的主要作用是将输入数据(如图像)转换成一个更紧凑、更易于处理的表示形式,通常称为“编码”或“潜在表示”(Latent Representation)。这种表示可以捕捉输入数据的关键特征,并用于后续的任务,如分类、生成等。

其中Latent Representation,即潜在表示,是深度学习中一个核心概念,它指的是模型内部学习到的一种数据表示。这种表示通常不是直接可观察的,而是捕捉了输入数据的内在结构和特征。在上面的示例代码中,潜在表示是通过Encoder模块中的一系列变换得到的,它将原始图像数据转换成一种更高级、更抽象的形式。这种潜在表示有助于模型更好地理解和处理输入数据,进而提升在各种任务上的性能。

完整的编码器代码如下所示。

import einops
import torch
from einops.layers.torch import Rearrange
from torch import nn

class ExtractLatentTokens(torch.nn.Module):
    """
    提取潜在表示(Latent Tokens)的模块。

    这个模块用于从输入张量中提取一部分作为潜在表示,通常用于生成模型中,
    例如变分自编码器(VAE)或者生成对抗网络(GAN)中的潜在空间操作。

    属性:
    grid_size (int): 网格的大小,用于确定从输入张量中提取潜在表示的起始位置。
    """
    def __init__(self, grid_size):
        """
        初始化 ExtractLatentTokens 模块。

        参数:
        grid_size (int): 网格的大小,这个值的平方将用于确定输入张量中
                        从哪个位置开始提取潜在表示。
        """
        super(ExtractLatentTokens, self).__init__()
        self.grid_size = grid_size

    def forward(self, x):
        """
        前向传播方法,从输入张量中提取潜在表示。

        参数:
        x (torch.Tensor): 输入张量,通常包含了数据的完整表示。

        返回:
        torch.Tensor: 提取的潜在表示张量,包含了从输入张量中指定位置开始的
                     所有元素。
        """
        # 计算提取起始位置的索引,即 grid_size 的平方
        start_index = self.grid_size ** 2

        # 返回从 start_index 开始到 x 结束的切片,即提取的潜在表示
        return x[:, start_index:]

import config
from blocks import ResidualAttention

class Encoder(nn.Module):
    def __init__(self,config = config.Config,positional_embedding = None,latent_token_positional_embedding = None):
        super(Encoder, self).__init__()
        in_channels = config.in_channels
        self.image_size = config.image_size
        d_model = self.d_model = self.width = config.d_model
        self.num_heads = config.num_heads
        self.num_layers = config.num_layers
        self.patch_size = config.patch_size

        scale = self.width ** -0.5  # scale by 1/sqrt(d)

        self.token_size =  config.token_size#
        self.patch_embed = torch.nn.Conv2d(in_channels=in_channels,out_channels=self.width,kernel_size= self.patch_size,stride= self.patch_size,bias = True)

        # 图像补丁的位置嵌入     #torch.Size([49, 384])
        self.grid_size = self.image_size // self.patch_size
        #这个position_embedding 是加在图形上,以及加在补丁上
        self.positional_embedding = positional_embedding  # [ 7*7, 384 ]
        self.latent_token_positional_embedding = latent_token_positional_embedding

        self.transformer = nn.ModuleList()
        for _ in range(self.num_layers):
            self.transformer.append(ResidualAttention(d_model=self.d_model, attention_head_num=6))

        self.norm = torch.nn.RMSNorm(d_model)
        self.model = nn.Sequential(*self.transformer,ExtractLatentTokens(grid_size=self.grid_size),self.norm)
        self.encoder_out = nn.Linear(self.width, self.token_size)

    def forward(self, pixel_values, latent_tokens):

        B, _, _, _ = pixel_values.shape
        x = pixel_values
        x = self.patch_embed(x)
        x = einops.rearrange(x, "B C H W -> B (H W) C ")

        x = x + self.positional_embedding.to(x.dtype)

        latent_tokens = einops.repeat(latent_tokens, "T C -> B T C", B=B)
        x = torch.cat([x, latent_tokens], dim=1)
        x = self.model(x)
        x = x + self.latent_token_positional_embedding.to(x.dtype)

        x = self.encoder_out(x)
        return x    #[-1,32,256]

在上面示例代码的Encoder类中,潜在表示是通过多个Transformer层的自注意力机制和非线性变换逐步构建起来的。这些层能够捕捉输入图像中不同位置之间的依赖关系,并将这些信息编码到一个固定大小的向量空间中。这个过程使得模型能够提取出图像的关键特征,并以一种紧凑且高效的方式表达出来,即形成潜在表示。

最后,这种潜在表示在模型的后续部分中发挥着重要作用。它可以被用作分类、生成或其他相关任务的输入,帮助模型更好地完成这些任务。在上面的代码中,潜在表示最终通过线性输出层被转换成了目标大小的编码向量,这个向量可以进一步用于下游任务的处理和分析。

2. 解码器

接下来完成模型的解码器[yx1] [晓王2] 。解码器作为VQ-VAE的另一核心组件,其任务是与编码器相反,即将编码器生成的离散token表示重新映射回原始的图像空间。解码器能够逐步还原出图像的低层次细节,最终生成与输入图像在视觉上相似的重建图像。而且,由于token表示的离散性,解码器在生成图像时具有一定的创造性和多样性,使得VQ-VAE在生成新颖、多样化的图像方面表现出色。解码器代码如下:

import einops
import torch
from einops.layers.torch import Rearrange
from torch import nn
import torch
from torch.nn.functional import embedding

import config
class RemoveLatentTokens(nn.Module):
    def __init__(self, grid_size):
        super().__init__()
        self.grid_size = grid_size

    def forward(self, x):
        return x[:, 0 : self.grid_size**2]

from blocks import ResidualAttention

class Decoder(nn.Module):
    def __init__(self,config = config.Config,positional_embedding = None,latent_token_positional_embedding = None):
        super().__init__()
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.grid_size = self.image_size // self.patch_size
        self.num_latent_tokens = config.num_latent_tokens
        d_model = self.d_model = self.width = config.d_model
        self.token_size = config.token_size
        self.num_heads = config.num_heads
        num_layers = self.num_layers = config.num_layers

        # project token dim to model dim
        self.grid_size = self.image_size // self.patch_size
        self.decoder_embed = nn.Linear(self.token_size, self.width, bias=True)
        scale = config.d_model ** -0.5
        #这个position_embedding 是加在图形上,以及加在补丁上
        self.positional_embedding = positional_embedding
        self.latent_token_positional_embedding = latent_token_positional_embedding

        self.remve_laten_token_layer = RemoveLatentTokens(self.grid_size)

        self.transformer = nn.ModuleList()  # attention layers
        for _ in range(self.num_layers):
            self.transformer.append(ResidualAttention(d_model = self.d_model,attention_head_num=6))

        self.norm = torch.nn.RMSNorm(d_model)
        # FFN to convert mask tokens to image patches
        self.ffn = nn.Sequential(
            nn.Conv2d(self.width, config.in_channels * self.patch_size ** 2, 1, padding=0, bias=True),
            Rearrange(
                "B (P1 P2 C) H W -> B C (H P1) (W P2)",
                P1=self.patch_size,
                P2=self.patch_size,
            ),
        )
        # conv layer on pixel output
        self.conv_out = nn.Conv2d(config.in_channels, config.in_channels, 3, padding=1, bias=True)

        self.model = nn.Sequential(*self.transformer,RemoveLatentTokens(grid_size=self.grid_size),self.norm)

    def forward(self, z_q, latent_tokens):

        x = z_q
        B,T,C = x.shape
        x = self.decoder_embed(x)

        mask_tokens = torch.unsqueeze(latent_tokens,dim=0).repeat(B, 1, 1).to(x.dtype)
        mask_tokens = mask_tokens + self.latent_token_positional_embedding.to(mask_tokens.dtype)

        x = torch.cat([mask_tokens, x], dim=1)
        x = self.model(x)  # decode latent tokens
        x = x + self.positional_embedding

        x = einops.rearrange(x,"B (H W) C -> B C H W", H=self.grid_size, W=self.grid_size)
        x = self.ffn(x)
        x = self.conv_out(x)
        return x

上面示例代码通过对输入的文本进行解码,从而完成了图像的重建任务。

13.2.3  VQ-VAE的模型设计

在完成了编码器与解码器的程序实现后,后面需要完成VQ-VAE的主程序设计。上一节中,我们已经详细阐述了VQ-VAE实现的核心要点。在具体程序设计方面,我们既可以选择手工编写这部分代码,以充分掌握每一个实现细节,也可以利用预设的vector_quantize_pytorch库来简化模型的设计过程。首先,我们需要安装vector_quantize_pytorch库:

pip install vector_quantize_pytorch

之后完成VQ-VAE模型,代码如下所示。

import torch,einops

import encoder
import decoder
import config

from vector_quantize_pytorch import VectorQuantize

class Tokenizer(torch.nn.Module):
    """
    1D Image Tokenizer
    """
    def __init__(self, config = config.Config()):
        super(Tokenizer, self).__init__()
        self.config = config
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.grid_size = self.image_size // self.patch_size

        scale = config.d_model ** -0.5
        self.latent_tokens = torch.nn.Parameter(scale * torch.randn(self.grid_size ** 2, config.d_model))
        self.positional_embedding = torch.nn.Parameter(scale * torch.randn(self.grid_size ** 2, config.d_model))  # [ 7*7, 384 ]
        self.latent_token_positional_embedding = torch.nn.Parameter(scale * torch.randn(self.grid_size ** 2, config.d_model))

        self.encoder = encoder.Encoder(config,self.positional_embedding, self.latent_token_positional_embedding)
        self.decoder = decoder.Decoder(config,self.positional_embedding, self.latent_token_positional_embedding)

        self.vq =  VectorQuantize(dim = config.token_dim,codebook_size = config.codebook_size,decay = 0.8,commitment_weight = 1.)


    #模型训练用
    def forward(self, x):
        z_q,indices, commit_loss = self.encode(x)

        decoded_imaged = self.decoder(z_q,self.latent_tokens)
        return decoded_imaged,commit_loss

    def encode(self, x):
        embedding = self.encoder(x,self.latent_tokens)

        quantized, indices, commit_loss = self.vq(embedding)
        return quantized,indices, commit_loss

    def decode_tokens(self, tokens):
        z_q = self.vq.get_codes_from_indices(tokens)

        return self.decoder(z_q)

在上面代码中,我们定义了解码与编码模块,vq模块从编码器输出中提取对应的向量表示并通过计算得到离散token-indices。commit_loss是贡献的损失函数,其作用是在计算时对内容进行损失计算并反馈结果。

13.2.4  VQ-VAE的训练与预测

本小节将完成VQ-VAE模型的训练,读者可以使用如下代码完成模型的训练:

import tokenizer
import math
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import config

device = "cuda"
model = tokenizer.Tokenizer(config=config.Config())
model.to(device)
save_path = "./saver/ViLT_generator.pth"
model.load_state_dict(torch.load(save_path),strict=False)

BATCH_SIZE = 128
seq_len = 49
import get_data_emotion
#import get_data_emotion_2 as get_data_emotion
train_dataset = get_data_emotion.MNIST()
train_loader = (DataLoader(train_dataset, batch_size=BATCH_SIZE,shuffle=True))

optimizer = torch.optim.AdamW(model.parameters(), lr = 2e-4)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,T_max = 12000,eta_min=2e-7,last_epoch=-1)

for epoch in range(6):
    pbar = tqdm(train_loader,total=len(train_loader))
    for inputs,token_inp,token_tgt in pbar:
        inputs = inputs.to(device)
        token_inp = token_inp.to(device)

        imaged,result_dict = model(inputs)

        #reconstruction_loss = torch.nn.functional.mse_loss(inputs, imaged, reduction="mean")
        reconstruction_loss = torch.nn.functional.smooth_l1_loss(inputs, imaged, reduction="sum")
        quantizer_loss = result_dict#["quantizer_loss"]
        autoencoder_loss = reconstruction_loss + quantizer_loss

        optimizer.zero_grad()
        autoencoder_loss.backward()
        optimizer.step()
        lr_scheduler.step()  # 执行优化器
        pbar.set_description(f"epoch:{epoch +1}, train_loss:{autoencoder_loss.item():.5f}, lr:{lr_scheduler.get_last_lr()[0]*1000:.5f}")
    if (epoch+1) % 2 == 0:
        torch.save(model.state_dict(), save_path)

torch.save(model.state_dict(), save_path)

在上面代码中,我们需要着重注意的是损失函数,这里主要定义了两种损失函数,分别是对码表的损失函数计算以及对图像损失函数的计算;在具体使用上,我们采用直接相加的方式来完成函数的计算。读者可以自行完成训练。

接下来就是模型的预测部分。我们可以使用训练好的模型对输入的内容进行重建,代码如下:

import tokenizer
import math
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import config
import matplotlib.pyplot as plt
device = "cpu"
model = tokenizer.Tokenizer(config=config.Config())

save_path = "./saver/ViLT_generator.pth"
model.load_state_dict(torch.load(save_path),strict=False)

import get_data_emotion
train_dataset = get_data_emotion.MNIST(is_train=False)

img, inp_tok,tgt_tok = train_dataset[929]

plt.imshow(img.permute(1, 2, 0))
plt.show()

image = torch.unsqueeze(img,0)
decoded, result_dict = model(image)
image_decoded = decoded[0].permute(1, 2, 0).detach().numpy()
plt.imshow(image_decoded)
plt.show()
print(image_decoded[0,:5])

img, inp_tok,tgt_tok = train_dataset[2929]
plt.imshow(img.permute(1, 2, 0))
plt.show()
print()
image = torch.unsqueeze(img,0)
decoded, result_dict = model(image)
image_decoded = decoded[0].permute(1, 2, 0).detach().numpy()
plt.imshow(image_decoded)
plt.show()
print(image_decoded[0,:5])

重建结果如图13-4所示,读者可以自行验证。

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值