基于VQ-VAE的手写体生成

《DeepSeek大模型高性能核心技术与多模态融合开发(人工智能技术丛书)》(王晓华)【摘要 书评 试读】- 京东图书

对于VQ-VAE的应用,其核心功能在于能够将原本连续的图像数据转化为离散的token表示。这一过程不仅实现了图像的高效压缩与编码,还为后续的图像处理和分析提供了新的视角。通过这些离散的token,我们可以更灵活地处理和操作图像数据,例如进行图像的检索、分类、编辑等任务。

具体来说,VQ-VAE通过学习一个离散的潜在空间来表示图像,这个空间由一系列预定义的token(或称为码字)构成。在训练过程中,VQ-VAE会将输入图像编码到这个离散空间中,选择最接近的token来表示图像的局部特征。这样,原本由像素值构成的连续图像就被转换成了一系列离散的token。

在生成图像时,VQ-VAE则根据这些token来解码并重构图像。由于token的离散性,生成的图像虽然在细节上可能与原图有所差异,但整体上能够保留原图的主要特征和结构。这种离散化的表示方法不仅降低了数据的复杂性,还提高了生成模型的效率和可控性。

此外,VQ-VAE生成的离散token序列还可以作为其他模型(如Transformer等)的输入,从而进一步拓展其在图像处理领域的应用。例如,可以通过对这些token进行序列建模,生成具有特定风格或内容的图像,或者实现图像的补全和修复等功能。

本章我们将完成基于VQ-VAE的图像生成。

13.2.1  图像的准备与超参数设置

本章我们将要完成的图像的准备,即使用MNIST数据集完成编码器VQ-VAE的手写体的生成。首先是图像文本的获取,我们可以直接使用MNIST数据集而对图像内容进行提取,代码如下所示。

from torch.utils.data import Dataset
from torchvision.transforms.v2 import PILToTensor, Compose
import torchvision
from tqdm import tqdm
import torch
import random

# 手写数字
class MNIST(Dataset):
    def __init__(self, is_train=True):
        super().__init__()
        self.ds = torchvision.datasets.MNIST('../../dataset/mnist/', train=is_train, download=True)
        self.img_convert = Compose([
            PILToTensor(),
        ])

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, index):
        img, label = self.ds[index]

        #text = f"现在的数字是:{label}#"
        text = random.sample(sample_texts,1)[0][-10:]
        text = text + str(label) + "#"

        full_tok = tokenizer_emo.encode(text)[-12:]
        full_tok = full_tok + [1] * (12 - len(full_tok))

        inp_tok = full_tok[:-1]
        tgt_tok = full_tok[1:]

        inp_tok = torch.tensor(inp_tok)
        tgt_tok = torch.tensor(tgt_tok)

        """
        torch.Size([1, 28, 28])
        """
        return self.img_convert(img) / 255.0, inp_tok,tgt_tok

上面代码复用了文本生成部分的MNIST数据集,读者在具体使用时,可以根据需要自行对其进行调整。

接下来,我们需要完成图像的超参数设计。在求解VQ-VAE的过程中,我们设计了如下的参数内容:

class Config:
    in_channels = 1
    d_model = 384
    image_size = 28
    patch_size = 4
    num_heads = 6
    num_layers = 3

    token_dim = token_size = 256
    latent_token_vocab_size = num_latent_tokens = 32

    codebook_size = 4096

token_size是指在VQ-VAE模型中量化后每个token(或“码字”)在潜在空间的维度。编码器会将输入数据转化为潜在表示,进而量化至一个离散的潜在空间。此量化步骤是通过将潜在表示里的每个向量用最近的“码字”替代来实现的,这些“码字”源于一个预先设定的码本。token_size参数即确定了这些“码字”的维度大小。例如,若token_size为12,意味着每个“码字”是一个12维的向量。

num_latent_tokens则代表在VQ-VAE的潜在空间中使用的不同“码字”的数量,它决定了码本的大小,即码本中包含离散向量的数目。在模型中,码本是一个可学习的参数集合,存储了表示输入数据特征的“码字”。num_latent_tokens参数控制着码本的大小,并影响模型捕捉输入数据细节的能力。举例来说,若num_latent_tokens设为64,则码本包含64个不同的“码字”,在量化潜在表

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值