《DeepSeek大模型高性能核心技术与多模态融合开发(人工智能技术丛书)》(王晓华)【摘要 书评 试读】- 京东图书
对于VQ-VAE的应用,其核心功能在于能够将原本连续的图像数据转化为离散的token表示。这一过程不仅实现了图像的高效压缩与编码,还为后续的图像处理和分析提供了新的视角。通过这些离散的token,我们可以更灵活地处理和操作图像数据,例如进行图像的检索、分类、编辑等任务。
具体来说,VQ-VAE通过学习一个离散的潜在空间来表示图像,这个空间由一系列预定义的token(或称为码字)构成。在训练过程中,VQ-VAE会将输入图像编码到这个离散空间中,选择最接近的token来表示图像的局部特征。这样,原本由像素值构成的连续图像就被转换成了一系列离散的token。
在生成图像时,VQ-VAE则根据这些token来解码并重构图像。由于token的离散性,生成的图像虽然在细节上可能与原图有所差异,但整体上能够保留原图的主要特征和结构。这种离散化的表示方法不仅降低了数据的复杂性,还提高了生成模型的效率和可控性。
此外,VQ-VAE生成的离散token序列还可以作为其他模型(如Transformer等)的输入,从而进一步拓展其在图像处理领域的应用。例如,可以通过对这些token进行序列建模,生成具有特定风格或内容的图像,或者实现图像的补全和修复等功能。
本章我们将完成基于VQ-VAE的图像生成。
13.2.1 图像的准备与超参数设置
本章我们将要完成的图像的准备,即使用MNIST数据集完成编码器VQ-VAE的手写体的生成。首先是图像文本的获取,我们可以直接使用MNIST数据集而对图像内容进行提取,代码如下所示。
from torch.utils.data import Dataset
from torchvision.transforms.v2 import PILToTensor, Compose
import torchvision
from tqdm import tqdm
import torch
import random
# 手写数字
class MNIST(Dataset):
def __init__(self, is_train=True):
super().__init__()
self.ds = torchvision.datasets.MNIST('../../dataset/mnist/', train=is_train, download=True)
self.img_convert = Compose([
PILToTensor(),
])
def __len__(self):
return len(self.ds)
def __getitem__(self, index):
img, label = self.ds[index]
#text = f"现在的数字是:{label}#"
text = random.sample(sample_texts,1)[0][-10:]
text = text + str(label) + "#"
full_tok = tokenizer_emo.encode(text)[-12:]
full_tok = full_tok + [1] * (12 - len(full_tok))
inp_tok = full_tok[:-1]
tgt_tok = full_tok[1:]
inp_tok = torch.tensor(inp_tok)
tgt_tok = torch.tensor(tgt_tok)
"""
torch.Size([1, 28, 28])
"""
return self.img_convert(img) / 255.0, inp_tok,tgt_tok
上面代码复用了文本生成部分的MNIST数据集,读者在具体使用时,可以根据需要自行对其进行调整。
接下来,我们需要完成图像的超参数设计。在求解VQ-VAE的过程中,我们设计了如下的参数内容:
class Config:
in_channels = 1
d_model = 384
image_size = 28
patch_size = 4
num_heads = 6
num_layers = 3
token_dim = token_size = 256
latent_token_vocab_size = num_latent_tokens = 32
codebook_size = 4096
token_size是指在VQ-VAE模型中量化后每个token(或“码字”)在潜在空间的维度。编码器会将输入数据转化为潜在表示,进而量化至一个离散的潜在空间。此量化步骤是通过将潜在表示里的每个向量用最近的“码字”替代来实现的,这些“码字”源于一个预先设定的码本。token_size参数即确定了这些“码字”的维度大小。例如,若token_size为12,意味着每个“码字”是一个12维的向量。
num_latent_tokens则代表在VQ-VAE的潜在空间中使用的不同“码字”的数量,它决定了码本的大小,即码本中包含离散向量的数目。在模型中,码本是一个可学习的参数集合,存储了表示输入数据特征的“码字”。num_latent_tokens参数控制着码本的大小,并影响模型捕捉输入数据细节的能力。举例来说,若num_latent_tokens设为64,则码本包含64个不同的“码字”,在量化潜在表示时,模型会从这些“码字”中选择一个来替换每个向量。
13.2.2 VQ-VAE的编码器与解码器
本小节讲解VQ-VAE的编码器和解码器。首先是编码器的作用,编码器将连续的图像特征转化为具有特定数目的token表示,而解码器则是将token复原成图像。
1. 编码器
编码器(Encoder)是深度学习模型中常见的一个组件,特别是在处理图像、文本等类型的数据时。它的主要作用是将输入数据(如图像)转换成一个更紧凑、更易于处理的表示形式,通常称为“编码”或“潜在表示”(Latent Representation)。这种表示可以捕捉输入数据的关键特征,并用于后续的任务,如分类、生成等。
其中Latent Representation,即潜在表示,是深度学习中一个核心概念,它指的是模型内部学习到的一种数据表示。这种表示通常不是直接可观察的,而是捕捉了输入数据的内在结构和特征。在上面的示例代码中,潜在表示是通过Encoder模块中的一系列变换得到的,它将原始图像数据转换成一种更高级、更抽象的形式。这种潜在表示有助于模型更好地理解和处理输入数据,进而提升在各种任务上的性能。
完整的编码器代码如下所示。
import einops
import torch
from einops.layers.torch import Rearrange
from torch import nn
class ExtractLatentTokens(torch.nn.Module):
"""
提取潜在表示(Latent Tokens)的模块。
这个模块用于从输入张量中提取一部分作为潜在表示,通常用于生成模型中,
例如变分自编码器(VAE)或者生成对抗网络(GAN)中的潜在空间操作。
属性:
grid_size (int): 网格的大小,用于确定从输入张量中提取潜在表示的起始位置。
"""
def __init__(self, grid_size):
"""
初始化 ExtractLatentTokens 模块。
参数:
grid_size (int): 网格的大小,这个值的平方将用于确定输入张量中
从哪个位置开始提取潜在表示。
"""
super(ExtractLatentTokens, self).__init__()
self.grid_size = grid_size
def forward(self, x):
"""
前向传播方法,从输入张量中提取潜在表示。
参数:
x (torch.Tensor): 输入张量,通常包含了数据的完整表示。
返回:
torch.Tensor: 提取的潜在表示张量,包含了从输入张量中指定位置开始的
所有元素。
"""
# 计算提取起始位置的索引,即 grid_size 的平方
start_index = self.grid_size ** 2
# 返回从 start_index 开始到 x 结束的切片,即提取的潜在表示
return x[:, start_index:]
import config
from blocks import ResidualAttention
class Encoder(nn.Module):
def __init__(self,config = config.Config,positional_embedding = None,latent_token_positional_embedding = None):
super(Encoder, self).__init__()
in_channels = config.in_channels
self.image_size = config.image_size
d_model = self.d_model = self.width = config.d_model
self.num_heads = config.num_heads
self.num_layers = config.num_layers
self.patch_size = config.patch_size
scale = self.width ** -0.5 # scale by 1/sqrt(d)
self.token_size = config.token_size#
self.patch_embed = torch.nn.Conv2d(in_channels=in_channels,out_channels=self.width,kernel_size= self.patch_size,stride= self.patch_size,bias = True)
# 图像补丁的位置嵌入 #torch.Size([49, 384])
self.grid_size = self.image_size // self.patch_size
#这个position_embedding 是加在图形上,以及加在补丁上
self.positional_embedding = positional_embedding # [ 7*7, 384 ]
self.latent_token_positional_embedding = latent_token_positional_embedding
self.transformer = nn.ModuleList()
for _ in range(self.num_layers):
self.transformer.append(ResidualAttention(d_model=self.d_model, attention_head_num=6))
self.norm = torch.nn.RMSNorm(d_model)
self.model = nn.Sequential(*self.transformer,ExtractLatentTokens(grid_size=self.grid_size),self.norm)
self.encoder_out = nn.Linear(self.width, self.token_size)
def forward(self, pixel_values, latent_tokens):
B, _, _, _ = pixel_values.shape
x = pixel_values
x = self.patch_embed(x)
x = einops.rearrange(x, "B C H W -> B (H W) C ")
x = x + self.positional_embedding.to(x.dtype)
latent_tokens = einops.repeat(latent_tokens, "T C -> B T C", B=B)
x = torch.cat([x, latent_tokens], dim=1)
x = self.model(x)
x = x + self.latent_token_positional_embedding.to(x.dtype)
x = self.encoder_out(x)
return x #[-1,32,256]
在上面示例代码的Encoder类中,潜在表示是通过多个Transformer层的自注意力机制和非线性变换逐步构建起来的。这些层能够捕捉输入图像中不同位置之间的依赖关系,并将这些信息编码到一个固定大小的向量空间中。这个过程使得模型能够提取出图像的关键特征,并以一种紧凑且高效的方式表达出来,即形成潜在表示。
最后,这种潜在表示在模型的后续部分中发挥着重要作用。它可以被用作分类、生成或其他相关任务的输入,帮助模型更好地完成这些任务。在上面的代码中,潜在表示最终通过线性输出层被转换成了目标大小的编码向量,这个向量可以进一步用于下游任务的处理和分析。
2. 解码器
接下来完成模型的解码器[yx1] [晓王2] 。解码器作为VQ-VAE的另一核心组件,其任务是与编码器相反,即将编码器生成的离散token表示重新映射回原始的图像空间。解码器能够逐步还原出图像的低层次细节,最终生成与输入图像在视觉上相似的重建图像。而且,由于token表示的离散性,解码器在生成图像时具有一定的创造性和多样性,使得VQ-VAE在生成新颖、多样化的图像方面表现出色。解码器代码如下:
import einops
import torch
from einops.layers.torch import Rearrange
from torch import nn
import torch
from torch.nn.functional import embedding
import config
class RemoveLatentTokens(nn.Module):
def __init__(self, grid_size):
super().__init__()
self.grid_size = grid_size
def forward(self, x):
return x[:, 0 : self.grid_size**2]
from blocks import ResidualAttention
class Decoder(nn.Module):
def __init__(self,config = config.Config,positional_embedding = None,latent_token_positional_embedding = None):
super().__init__()
self.image_size = config.image_size
self.patch_size = config.patch_size
self.grid_size = self.image_size // self.patch_size
self.num_latent_tokens = config.num_latent_tokens
d_model = self.d_model = self.width = config.d_model
self.token_size = config.token_size
self.num_heads = config.num_heads
num_layers = self.num_layers = config.num_layers
# project token dim to model dim
self.grid_size = self.image_size // self.patch_size
self.decoder_embed = nn.Linear(self.token_size, self.width, bias=True)
scale = config.d_model ** -0.5
#这个position_embedding 是加在图形上,以及加在补丁上
self.positional_embedding = positional_embedding
self.latent_token_positional_embedding = latent_token_positional_embedding
self.remve_laten_token_layer = RemoveLatentTokens(self.grid_size)
self.transformer = nn.ModuleList() # attention layers
for _ in range(self.num_layers):
self.transformer.append(ResidualAttention(d_model = self.d_model,attention_head_num=6))
self.norm = torch.nn.RMSNorm(d_model)
# FFN to convert mask tokens to image patches
self.ffn = nn.Sequential(
nn.Conv2d(self.width, config.in_channels * self.patch_size ** 2, 1, padding=0, bias=True),
Rearrange(
"B (P1 P2 C) H W -> B C (H P1) (W P2)",
P1=self.patch_size,
P2=self.patch_size,
),
)
# conv layer on pixel output
self.conv_out = nn.Conv2d(config.in_channels, config.in_channels, 3, padding=1, bias=True)
self.model = nn.Sequential(*self.transformer,RemoveLatentTokens(grid_size=self.grid_size),self.norm)
def forward(self, z_q, latent_tokens):
x = z_q
B,T,C = x.shape
x = self.decoder_embed(x)
mask_tokens = torch.unsqueeze(latent_tokens,dim=0).repeat(B, 1, 1).to(x.dtype)
mask_tokens = mask_tokens + self.latent_token_positional_embedding.to(mask_tokens.dtype)
x = torch.cat([mask_tokens, x], dim=1)
x = self.model(x) # decode latent tokens
x = x + self.positional_embedding
x = einops.rearrange(x,"B (H W) C -> B C H W", H=self.grid_size, W=self.grid_size)
x = self.ffn(x)
x = self.conv_out(x)
return x
上面示例代码通过对输入的文本进行解码,从而完成了图像的重建任务。
13.2.3 VQ-VAE的模型设计
在完成了编码器与解码器的程序实现后,后面需要完成VQ-VAE的主程序设计。上一节中,我们已经详细阐述了VQ-VAE实现的核心要点。在具体程序设计方面,我们既可以选择手工编写这部分代码,以充分掌握每一个实现细节,也可以利用预设的vector_quantize_pytorch库来简化模型的设计过程。首先,我们需要安装vector_quantize_pytorch库:
pip install vector_quantize_pytorch
之后完成VQ-VAE模型,代码如下所示。
import torch,einops
import encoder
import decoder
import config
from vector_quantize_pytorch import VectorQuantize
class Tokenizer(torch.nn.Module):
"""
1D Image Tokenizer
"""
def __init__(self, config = config.Config()):
super(Tokenizer, self).__init__()
self.config = config
self.image_size = config.image_size
self.patch_size = config.patch_size
self.grid_size = self.image_size // self.patch_size
scale = config.d_model ** -0.5
self.latent_tokens = torch.nn.Parameter(scale * torch.randn(self.grid_size ** 2, config.d_model))
self.positional_embedding = torch.nn.Parameter(scale * torch.randn(self.grid_size ** 2, config.d_model)) # [ 7*7, 384 ]
self.latent_token_positional_embedding = torch.nn.Parameter(scale * torch.randn(self.grid_size ** 2, config.d_model))
self.encoder = encoder.Encoder(config,self.positional_embedding, self.latent_token_positional_embedding)
self.decoder = decoder.Decoder(config,self.positional_embedding, self.latent_token_positional_embedding)
self.vq = VectorQuantize(dim = config.token_dim,codebook_size = config.codebook_size,decay = 0.8,commitment_weight = 1.)
#模型训练用
def forward(self, x):
z_q,indices, commit_loss = self.encode(x)
decoded_imaged = self.decoder(z_q,self.latent_tokens)
return decoded_imaged,commit_loss
def encode(self, x):
embedding = self.encoder(x,self.latent_tokens)
quantized, indices, commit_loss = self.vq(embedding)
return quantized,indices, commit_loss
def decode_tokens(self, tokens):
z_q = self.vq.get_codes_from_indices(tokens)
return self.decoder(z_q)
在上面代码中,我们定义了解码与编码模块,vq模块从编码器输出中提取对应的向量表示并通过计算得到离散token-indices。commit_loss是贡献的损失函数,其作用是在计算时对内容进行损失计算并反馈结果。
13.2.4 VQ-VAE的训练与预测
本小节将完成VQ-VAE模型的训练,读者可以使用如下代码完成模型的训练:
import tokenizer
import math
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import config
device = "cuda"
model = tokenizer.Tokenizer(config=config.Config())
model.to(device)
save_path = "./saver/ViLT_generator.pth"
model.load_state_dict(torch.load(save_path),strict=False)
BATCH_SIZE = 128
seq_len = 49
import get_data_emotion
#import get_data_emotion_2 as get_data_emotion
train_dataset = get_data_emotion.MNIST()
train_loader = (DataLoader(train_dataset, batch_size=BATCH_SIZE,shuffle=True))
optimizer = torch.optim.AdamW(model.parameters(), lr = 2e-4)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,T_max = 12000,eta_min=2e-7,last_epoch=-1)
for epoch in range(6):
pbar = tqdm(train_loader,total=len(train_loader))
for inputs,token_inp,token_tgt in pbar:
inputs = inputs.to(device)
token_inp = token_inp.to(device)
imaged,result_dict = model(inputs)
#reconstruction_loss = torch.nn.functional.mse_loss(inputs, imaged, reduction="mean")
reconstruction_loss = torch.nn.functional.smooth_l1_loss(inputs, imaged, reduction="sum")
quantizer_loss = result_dict#["quantizer_loss"]
autoencoder_loss = reconstruction_loss + quantizer_loss
optimizer.zero_grad()
autoencoder_loss.backward()
optimizer.step()
lr_scheduler.step() # 执行优化器
pbar.set_description(f"epoch:{epoch +1}, train_loss:{autoencoder_loss.item():.5f}, lr:{lr_scheduler.get_last_lr()[0]*1000:.5f}")
if (epoch+1) % 2 == 0:
torch.save(model.state_dict(), save_path)
torch.save(model.state_dict(), save_path)
在上面代码中,我们需要着重注意的是损失函数,这里主要定义了两种损失函数,分别是对码表的损失函数计算以及对图像损失函数的计算;在具体使用上,我们采用直接相加的方式来完成函数的计算。读者可以自行完成训练。
接下来就是模型的预测部分。我们可以使用训练好的模型对输入的内容进行重建,代码如下:
import tokenizer
import math
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import config
import matplotlib.pyplot as plt
device = "cpu"
model = tokenizer.Tokenizer(config=config.Config())
save_path = "./saver/ViLT_generator.pth"
model.load_state_dict(torch.load(save_path),strict=False)
import get_data_emotion
train_dataset = get_data_emotion.MNIST(is_train=False)
img, inp_tok,tgt_tok = train_dataset[929]
plt.imshow(img.permute(1, 2, 0))
plt.show()
image = torch.unsqueeze(img,0)
decoded, result_dict = model(image)
image_decoded = decoded[0].permute(1, 2, 0).detach().numpy()
plt.imshow(image_decoded)
plt.show()
print(image_decoded[0,:5])
img, inp_tok,tgt_tok = train_dataset[2929]
plt.imshow(img.permute(1, 2, 0))
plt.show()
print()
image = torch.unsqueeze(img,0)
decoded, result_dict = model(image)
image_decoded = decoded[0].permute(1, 2, 0).detach().numpy()
plt.imshow(image_decoded)
plt.show()
print(image_decoded[0,:5])
重建结果如图13-4所示,读者可以自行验证。