VQModel Arch
- VQGAN: composed of a standard VQGAN CNN Encoder and a VQGAN CNN Decoder.
### VQGAN encoder decoder definitions ###
self.encoder = Encoder(ch_mult=config.encoder_ch_mult, z_channels=config.z_channels, dropout=config.dropout_p)
if not config.enhanced_decoder:
    print("Using normal pixel decoder")
    self.decoder = Decoder(ch_mult=config.decoder_ch_mult, z_channels=config.z_channels, dropout=config.dropout_p)
else:
    print("Using 2 times enhanced pixel decoder")
    self.decoder = Decoder(
        ch_mult=config.decoder_ch_mult,
        z_channels=config.z_channels,
        dropout=config.dropout_p,
        ch=256,
        num_res_blocks=4,
    )
### VQGAN encoder decoder definitions ###
self.quant_conv = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1)
self.post_quant_conv = nn.Conv2d(config.codebook_embed_dim, config.z_channels, 1)
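The 16x total downsampling of this CNN encoder is what the 2**4 factor in encode() further below relies on. A small arithmetic sketch, assuming a typical 5-level ch_mult (the concrete value is an assumption, not taken from the source):
# Assumed encoder_ch_mult, consistent with the 2**4 factor used in encode() below.
encoder_ch_mult = (1, 1, 2, 2, 4)
downsample_factor = 2 ** (len(encoder_ch_mult) - 1)  # one 2x downsample per level except the last
print(downsample_factor)  # 16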
- VQKD: the Encoder is chosen from ["clipb_224", "vitamin_xlarge_256", "siglip_384"], with the head and 26th-layer weights frozen. The Decoder is a VisionTransformer serving as the VQKD semantic decoder.
### VQKD encoder decoder definition ###
### define semantic encoder, initialize it with pretrained clip ###
if self.teacher == 'clipb_224':
    self.encoder_vqkd = clip.load_clip_vision("ViT-B/16", download_root='./clip_model')
elif self.teacher == 'vitamin_xlarge_256':
    print('Current teacher is vitamin_xlarge_256')
    self.encoder_vqkd = create_model(self.teacher, pretrained=True)
    del self.encoder_vqkd.head
    del self.encoder_vqkd.fc_norm
elif self.teacher == 'siglip_384':
    print('Current teacher is google/siglip-so400m-patch14-384')
    self.encoder_vqkd = SiglipVisionModel.from_pretrained("google/siglip-so400m-patch14-384")
    # Freeze the head, the 26th transformer layer, and the post_layernorm parameters
    for name, param in self.encoder_vqkd.named_parameters():
        if 'head' in name:
            param.requires_grad = False
        if 'encoder.layers.26' in name:
            param.requires_grad = False
        if 'post_layernorm' in name:
            param.requires_grad = False
else:
    raise ValueError(f'teacher {self.teacher} not supported')
print('Final decoder config', decoder_config)
self.decoder_vqkd = VisionTransformer(**decoder_config)
- Task Layer: encode_task_layer projects embed_dim to semantic_code_dim, and decode_task_layer projects embed_dim to decoder_out_dim.
### task layer ###
self.encode_task_layer = nn.Sequential(
    nn.Linear(encoder_config['embed_dim'], encoder_config['embed_dim']),
    nn.Tanh(),
    nn.Linear(encoder_config['embed_dim'], semantic_code_dim)  # for quantize
)
self.decode_task_layer = nn.Sequential(
    nn.Linear(decoder_config['embed_dim'], decoder_config['embed_dim']),
    nn.Tanh(),
    nn.Linear(decoder_config['embed_dim'], self.decoder_out_dim),
)
self.encode_task_layer.apply(self._init_weights)
self.decode_task_layer.apply(self._init_weights)
### task layer ###
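A quick shape check for the task layers (standalone sketch; embed_dim=1152 and 729 tokens match siglip_384, while semantic_code_dim=32 is purely illustrative):
import torch
import torch.nn as nn

embed_dim, semantic_code_dim = 1152, 32          # assumed values for illustration
encode_task_layer = nn.Sequential(
    nn.Linear(embed_dim, embed_dim),
    nn.Tanh(),
    nn.Linear(embed_dim, semantic_code_dim),
)
tokens = torch.randn(2, 729, embed_dim)          # (B, N, embed_dim) ViT patch tokens
print(encode_task_layer(tokens).shape)           # torch.Size([2, 729, 32])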
- Scaling Layer: per-teacher input normalization.
# scaling layers
if self.teacher == 'clipb_224':
    self.scaling_layer = None  # can be changed to ScalingLayerForClip()
elif self.teacher == 'vitamin_xlarge_256':
    self.scaling_layer = ScalingLayerForClip()
elif self.teacher == 'siglip_384':
    self.scaling_layer = ScalingLayerForSigLip()
else:
    raise ValueError(f'teacher {self.teacher} not supported')
class ScalingLayerForSigLip(nn.Module):
    def __init__(self):
        super(ScalingLayerForSigLip, self).__init__()
        self.register_buffer('shift', torch.Tensor([0.5, 0.5, 0.5])[None, :, None, None])
        self.register_buffer('scale', torch.Tensor([0.5, 0.5, 0.5])[None, :, None, None])

    def forward(self, inp):
        inp = ((inp + 1.) * 127.5).clamp(0, 255.) / 255.  # rescale to [0, 1.]
        return (inp - self.shift) / self.scale
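Because shift and scale are both 0.5, this layer maps a [-1, 1] input to [0, 1] and then back to [-1, 1], matching SigLIP's mean=0.5 / std=0.5 preprocessing; apart from the clamp it is effectively an identity. A quick standalone check (uses the class defined above):
import torch

layer = ScalingLayerForSigLip()
x = torch.rand(2, 3, 8, 8) * 2 - 1      # fake image batch already in [-1, 1]
y = layer(x)
print(torch.allclose(x, y, atol=1e-5))  # True: ((x + 1) / 2 - 0.5) / 0.5 == x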
- VQ Module:
# quantizer
self.use_kmeans = config.kmeans
self.quantize = VectorQuantizer(config.codebook_size, self.code_dim,
                                config.commit_loss_beta, config.entropy_loss_ratio,
                                config.codebook_l2_norm, config.codebook_show_usage,
                                split=[self.semantic_code_dim, self.vqgan_code_dim], kmeans=config.kmeans)
VQModel Forward
- encode:
# vqkd
elif self.teacher == 'siglip_384':
    vqkd_x = self.scaling_layer(x)
    vqkd_feature = self.encoder_vqkd(vqkd_x, output_hidden_states=True).hidden_states[-2]
    vqkd_feature = self.encode_task_layer(vqkd_feature.type_as(self.encode_task_layer[-1].weight))
    N = vqkd_feature.shape[1]
    h, w = int(math.sqrt(N)), int(math.sqrt(N))
    vqkd_feature = rearrange(vqkd_feature, 'b (h w) c -> b c h w', h=h, w=w)  # reshape for quantizer
# vqgan
elif self.teacher == 'siglip_384':
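    # Resize the pixel branch so that the CNN encoder's 2**4 = 16x downsampling
    # yields the same spatial grid as the ViT token grid above (27x27 for
    # siglip_384, hence a 432x432 input), so the two feature maps can be
    # concatenated channel-wise below.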
    upscale_reso = 2**4 * vqkd_feature.shape[-1]
    vqgan_x = F.interpolate(x, size=(upscale_reso, upscale_reso), mode='bicubic')
    vqgan_features = self.encoder(vqgan_x)
    vqgan_features = self.quant_conv(vqgan_features)
# concat
h = torch.cat([vqkd_feature, vqgan_features], dim=1)
# shared VQ (when not using VAR, i.e. without residual quantization)
quant, emb_loss, info = self.quantize(h)
- decode:
# split
vqkd_quant, vqgan_quant = torch.split(quant, [self.semantic_code_dim, self.vqgan_code_dim], dim=1)
# vqkd reconstruct semantic feature
vqkd_recon = self.decoder_vqkd(vqkd_quant, return_patch_tokens=True)
vqkd_recon = self.decode_task_layer(vqkd_recon)
# vqgan reconstruct pixel values
vqgan_recon = self.post_quant_conv(vqgan_quant)
vqgan_recon = self.decoder(vqgan_recon)
if self.teacher == 'siglip_384':
    vqgan_recon = F.interpolate(vqgan_recon, size=(384, 384), mode='bicubic')
dec = (vqkd_recon, vqgan_recon)
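Putting encode and decode together, the overall forward pass is roughly the following (a minimal sketch under the assumption that encode() returns the quantizer outputs shown above; not copied from the source):
def forward(self, x):
    # encode(): concat semantic (vqkd) + pixel (vqgan) features, shared quantization
    quant, emb_loss, info = self.encode(x)
    # decode(): split the quantized code and reconstruct semantic features and pixels
    dec = self.decode(quant)   # (vqkd_recon, vqgan_recon)
    return dec, emb_loss, info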
Shared Quantizer
The VectorQuantizer holds two embeddings, one for vqgan and one for vqkd. Lookup uses a shared index mapping that binds the two codebooks' entries pairwise: the same index selects one entry from each codebook.
- init: initialize the embeddings of both codebooks.
self.embedding_vqkd = nn.Embedding(self.n_e, self.split[0])
self.embedding_vqkd.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
if self.l2_norm:
    self.embedding_vqkd.weight.data = F.normalize(self.embedding_vqkd.weight.data, p=2, dim=-1)
self.embedding_vqgan = nn.Embedding(self.n_e, self.split[1])
self.embedding_vqgan.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
if self.l2_norm:
    self.embedding_vqgan.weight.data = F.normalize(self.embedding_vqgan.weight.data, p=2, dim=-1)
- forward, shared-index lookup: compute d_vqkd and d_vqgan separately; the total distance is d = d_vqkd + d_vqgan, and argmin(d) gives min_encoding_indices, the single index shared by both codebooks (see the toy sketch below).
- Note: although TokenFlow has multiple codebooks, there is only one codebook embedding loss, computed after concatenating the z_q obtained from the codebooks!
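Before the actual implementation, a toy illustration of the shared index mapping (standalone sketch with made-up sizes): a single argmin over the summed distances picks the same row from both codebooks, which is what binds their entries pairwise.
import torch

n_e = 4                                    # toy codebook size
emb_vqkd = torch.randn(n_e, 3)             # semantic codebook entries (toy dims)
emb_vqgan = torch.randn(n_e, 2)            # pixel codebook entries (toy dims)
z_vqkd, z_vqgan = torch.randn(3), torch.randn(2)

d = ((z_vqkd - emb_vqkd) ** 2).sum(-1) + ((z_vqgan - emb_vqgan) ** 2).sum(-1)
idx = d.argmin()                           # one shared index for both codebooks
z_q = torch.cat([emb_vqkd[idx], emb_vqgan[idx]])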
# reshape and split the concatenated z for codebook lookup
z = torch.einsum('b c h w -> b h w c', z).contiguous()
z_flattened = z.view(-1, self.e_dim)
z_vqkd, z_vqgan = torch.split(z, split_size_or_sections=self.split, dim=-1)
z_flattened_vqkd, z_flattened_vqgan = torch.split(z_flattened, split_size_or_sections=self.split, dim=-1)
# compute distance for each codebook
d_vqkd = torch.sum(z_flattened_vqkd ** 2, dim=1, keepdim=True) + \
    torch.sum(embedding_vqkd**2, dim=1) - 2 * \
    torch.einsum('bd,dn->bn', z_flattened_vqkd, torch.einsum('n d -> d n', embedding_vqkd))
d_vqgan = torch.sum(z_flattened_vqgan ** 2, dim=1, keepdim=True) + \
    torch.sum(embedding_vqgan**2, dim=1) - 2 * \
    torch.einsum('bd,dn->bn', z_flattened_vqgan, torch.einsum('n d -> d n', embedding_vqgan))
# shared mapping
d = d_vqkd + 1.0 * d_vqgan
min_encoding_indices = torch.argmin(d, dim=1)
# look up z_q from both codebooks using the shared min_encoding_indices
all_embedding = torch.cat([embedding_vqkd, embedding_vqgan], dim=-1)
z_q = all_embedding[min_encoding_indices].view(z.shape)
z_q_vqkd = embedding_vqkd[min_encoding_indices].view(z_vqkd.shape)
z_q_vqgan = embedding_vqgan[min_encoding_indices].view(z_vqgan.shape)
# embedding loss
vq_loss = torch.mean((z_q - z.detach()) ** 2)
commit_loss = self.beta * torch.mean((z_q.detach() - z) ** 2)
entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(-d)
# preserve gradients
z_q = z + (z_q - z).detach()
# reshape back to match original input shape
z_q = torch.einsum('b h w c -> b c h w', z_q)
return z_q, (vq_loss, commit_loss, entropy_loss, codebook_usage, vqkd_d_norm, vqgan_d_norm), (perplexity, min_encodings, min_encoding_indices)
VQLoss
A teacher is initialized here for distillation. Its type matches the VQModel's semantic encoder (one of ["clipb_224", "vitamin_xlarge_256", "siglip_384"]); the difference is that only the final head of the semantic encoder is frozen, whereas all of the teacher's parameters are frozen.
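A minimal sketch of what that fully frozen teacher could look like for the siglip_384 case (the attribute name teacher_model is an assumption; only the freezing behaviour is described above):
from transformers import SiglipVisionModel

# Teacher used only to produce distillation targets; every parameter is frozen.
self.teacher_model = SiglipVisionModel.from_pretrained("google/siglip-so400m-patch14-384")
for param in self.teacher_model.parameters():
    param.requires_grad = False
self.teacher_model.eval()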
# vqkd_recon is the CLIP feature reconstructed by vqkd; reconstructions are the pixel values reconstructed by vqgan
vqkd_recon, reconstructions = reconstructions
- vqkd reconstruction loss: a cosine-similarity loss between the semantic encoder/decoder's VQ reconstruction and the teacher's output (i.e. 1 - cosine similarity of L2-normalized features).
def calculate_clip_rec_loss(self, rec, target):
    target = target / target.norm(dim=-1, keepdim=True)
    rec = rec / rec.norm(dim=-1, keepdim=True)
    rec_loss = (1 - (target * rec).sum(-1)).mean()
    return rec_loss
if vqkd_recon is not None:
    # get teacher feature
    clip_target = self.get_regress_target(inputs)
    vqkd_loss = self.calculate_clip_rec_loss(vqkd_recon, clip_target)
- vqgan reconstruction loss:
# reconstruction loss
rec_loss = self.rec_loss(inputs.contiguous(), reconstructions.contiguous())
- perceptual loss
# perceptual loss
p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
p_loss = torch.mean(p_loss)
- discriminator gen_adv_loss
# discriminator loss
logits_fake = self.discriminator(reconstructions.contiguous())
generator_adv_loss = self.gen_adv_loss(logits_fake)
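These terms are then combined into the generator objective. A hedged sketch of a typical VQGAN-style combination (the weight names and exact weighting are assumptions, not taken from the source; codebook_loss stands for the quantizer's vq/commit/entropy terms):
# Illustrative only: weight names and values are assumptions, not the source's.
loss = (
    rec_loss
    + perceptual_weight * p_loss
    + vqkd_weight * vqkd_loss
    + codebook_loss
    + disc_adaptive_weight * generator_adv_loss
)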