TokenFlow's Shared Mapping Implementation

VQModel Arch

[Figure: VQModel architecture]

  • VQGAN: a standard VQGAN, consisting of a VQGAN CNN Encoder and a VQGAN CNN Decoder.
        ### VQGAN encoder decoder definitions ###
        self.encoder = Encoder(ch_mult=config.encoder_ch_mult, z_channels=config.z_channels, dropout=config.dropout_p)
        if not config.enhanced_decoder:
            print("Using normal pixel decoder")
            self.decoder = Decoder(ch_mult=config.decoder_ch_mult, z_channels=config.z_channels, dropout=config.dropout_p)
        else:
            print("Using 2 times enhanced pixel decoder")
            self.decoder = Decoder(
                ch_mult=config.decoder_ch_mult, 
                z_channels=config.z_channels, 
                dropout=config.dropout_p,
                ch=256,
                num_res_blocks=4,
            )
        ### VQGAN encoder decoder definitions ###
        self.quant_conv = nn.Conv2d(config.z_channels, config.codebook_embed_dim, 1)
        self.post_quant_conv = nn.Conv2d(config.codebook_embed_dim, config.z_channels, 1)     
  • VQKD: the Encoder is one of ["clipb_224", "vitamin_xlarge_256", "siglip_384"]; for siglip_384 the head, the 26th encoder layer, and post_layernorm are frozen (i.e., the parts that come after the hidden_states[-2] features used in encode). The Decoder is a VQKD-style semantic decoder implemented as a VisionTransformer.
        ### VQKD encoder decoder definition ###
        ### define semantic encoder, initialize it with pretrained clip ###
        if self.teacher == 'clipb_224':
            self.encoder_vqkd = clip.load_clip_vision("ViT-B/16", download_root='./clip_model')
        elif self.teacher == 'vitamin_xlarge_256':
            print('Current teacher is vitamin_xlarge_256')
            self.encoder_vqkd = create_model(self.teacher, pretrained=True)
            del self.encoder_vqkd.head
            del self.encoder_vqkd.fc_norm
        elif self.teacher == 'siglip_384':
            print('Current teacher is google/siglip-so400m-patch14-384')
            self.encoder_vqkd = SiglipVisionModel.from_pretrained("google/siglip-so400m-patch14-384")
            # Freeze the head, the 26th encoder layer, and post_layernorm
            for name, param in self.encoder_vqkd.named_parameters():
                if 'head' in name:
                    param.requires_grad = False
                if 'encoder.layers.26' in name:
                    param.requires_grad = False
                if 'post_layernorm' in name:
                    param.requires_grad = False
        else:
            raise ValueError(f'teacher {self.teacher} not supported')
        
        print('Final decoder config', decoder_config)
        self.decoder_vqkd = VisionTransformer(**decoder_config)
  • Task Layer: encode_task_layer maps embed_dim to semantic_code_dim (the dimension fed to the quantizer), and decode_task_layer maps embed_dim to decoder_out_dim.
### task layer ###
        self.encode_task_layer = nn.Sequential(
            nn.Linear(encoder_config['embed_dim'], encoder_config['embed_dim']),
            nn.Tanh(),
            nn.Linear(encoder_config['embed_dim'], semantic_code_dim) # for quantize
        )
        self.decode_task_layer = nn.Sequential(
            nn.Linear(decoder_config['embed_dim'], decoder_config['embed_dim']),
            nn.Tanh(),
            nn.Linear(decoder_config['embed_dim'], self.decoder_out_dim),
        )
        self.encode_task_layer.apply(self._init_weights)
        self.decode_task_layer.apply(self._init_weights)
        ### task layer ###
  • Scaling Layer: normalizes the input for the semantic (teacher) encoder; a quick sanity check for the SigLIP case follows the class below.
# scaling layers
        if self.teacher == 'clipb_224':
            self.scaling_layer = None # can be changed to ScalingLayerForClip()
        elif self.teacher == 'vitamin_xlarge_256':
            self.scaling_layer = ScalingLayerForClip()
        elif self.teacher == 'siglip_384':
            self.scaling_layer = ScalingLayerForSigLip()
        else:
            raise ValueError(f'teacher {self.teacher} not supported')   
class ScalingLayerForSigLip(nn.Module):
    def __init__(self):
        super(ScalingLayerForSigLip, self).__init__()
        self.register_buffer('shift', torch.Tensor([0.5, 0.5, 0.5])[None, :, None, None])
        self.register_buffer('scale', torch.Tensor([0.5, 0.5, 0.5])[None, :, None, None])

    def forward(self, inp):
        inp = ((inp + 1.) * 127.5).clamp(0, 255.) / 255. # rescale to [0, 1.]
        return (inp - self.shift) / self.scale
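A quick sanity check of the SigLIP scaling layer (a sketch, assuming the dataloader feeds images already normalized to [-1, 1] as in typical VQGAN pipelines): the layer first maps [-1, 1] back to [0, 1], then applies SigLIP's mean = std = 0.5 normalization, which maps back to [-1, 1]. Apart from the clamp it is effectively an identity, so its only role is to make the input match what the SigLIP teacher expects.

    import torch

    layer = ScalingLayerForSigLip()
    x = torch.rand(1, 3, 8, 8) * 2 - 1         # fake image in [-1, 1]
    y = layer(x)                               # ((x + 1) / 2 - 0.5) / 0.5 == x
    print(torch.allclose(y, x, atol=1e-6))     # True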

  • VQ Module: a single VectorQuantizer shared by both branches, with split=[semantic_code_dim, vqgan_code_dim].
# quantizer
        self.use_kmeans = config.kmeans
        self.quantize = VectorQuantizer(
            config.codebook_size, self.code_dim,
            config.commit_loss_beta, config.entropy_loss_ratio,
            config.codebook_l2_norm, config.codebook_show_usage,
            split=[self.semantic_code_dim, self.vqgan_code_dim],
            kmeans=config.kmeans,
        )

VQModel Forward

  • encode (excerpted for the siglip_384 branch; a shape walkthrough follows the code below)
# vqkd (excerpt: the siglip_384 branch of encode)
        elif self.teacher == 'siglip_384':
            vqkd_x = self.scaling_layer(x)
            # take the second-to-last hidden state as the semantic feature
            vqkd_feature = self.encoder_vqkd(vqkd_x, output_hidden_states=True).hidden_states[-2]
            vqkd_feature = self.encode_task_layer(vqkd_feature.type_as(self.encode_task_layer[-1].weight))
            N = vqkd_feature.shape[1]
            h, w = int(math.sqrt(N)), int(math.sqrt(N))
            vqkd_feature = rearrange(vqkd_feature, 'b (h w) c -> b c h w', h=h, w=w) # reshape for quantizer

# vqgan (excerpt: the siglip_384 branch)
        elif self.teacher == 'siglip_384':
            # upsample the input so the VQGAN feature map matches the vqkd grid
            upscale_reso = 2**4 * vqkd_feature.shape[-1]
            vqgan_x = F.interpolate(x, size=(upscale_reso, upscale_reso), mode='bicubic')
            vqgan_features = self.encoder(vqgan_x)
            vqgan_features = self.quant_conv(vqgan_features)

# concat along the channel dimension
        h = torch.cat([vqkd_feature, vqgan_features], dim=1)

# shared vq (no residual quantization when VAR is not used)
        quant, emb_loss, info = self.quantize(h)
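A shape walkthrough of the siglip_384 encode path (a sketch; the 16x overall VQGAN downsampling factor is an assumption consistent with the 2**4 factor used above):

    # google/siglip-so400m-patch14-384: 384x384 input, patch size 14 -> 27x27 = 729 tokens
    tokens = (384 // 14) ** 2              # 729
    h = w = int(tokens ** 0.5)             # 27, so vqkd_feature: (b, semantic_code_dim, 27, 27)
    upscale_reso = 2 ** 4 * w              # 432: the pixel input is upsampled to 432x432
    assert upscale_reso // 16 == h         # the 16x-downsampled VQGAN grid matches the vqkd grid
    # after quant_conv: vqgan_features: (b, vqgan_code_dim, 27, 27)
    # concat on dim=1  -> h: (b, semantic_code_dim + vqgan_code_dim, 27, 27)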
  • decode
# split the quantized code back into the two sub-codes
        vqkd_quant, vqgan_quant = torch.split(quant, [self.semantic_code_dim, self.vqgan_code_dim], dim=1)

# vqkd: reconstruct the semantic feature
        vqkd_recon = self.decoder_vqkd(vqkd_quant, return_patch_tokens=True)
        vqkd_recon = self.decode_task_layer(vqkd_recon)

# vqgan: reconstruct pixel values
        vqgan_recon = self.post_quant_conv(vqgan_quant)
        vqgan_recon = self.decoder(vqgan_recon)

        if self.teacher == 'siglip_384':
            vqgan_recon = F.interpolate(vqgan_recon, size=(384, 384), mode='bicubic')

        dec = (vqkd_recon, vqgan_recon)

Shared Quantizer

The VectorQuantizer holds two embeddings, one for vqgan and one for vqkd. Lookup uses a shared index mapping, so the entries of the two codebooks are bound pairwise: entry i of one codebook always co-occurs with entry i of the other.

  • __init__: initialize the embeddings of both codebooks.
            self.embedding_vqkd = nn.Embedding(self.n_e, self.split[0])
            self.embedding_vqkd.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
            if self.l2_norm:
                self.embedding_vqkd.weight.data = F.normalize(self.embedding_vqkd.weight.data, p=2, dim=-1)

            self.embedding_vqgan = nn.Embedding(self.n_e, self.split[1])
            self.embedding_vqgan.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
            if self.l2_norm:
                self.embedding_vqgan.weight.data = F.normalize(self.embedding_vqgan.weight.data, p=2, dim=-1)
  • forward, shared-index lookup: compute d_vqkd and d_vqgan separately, sum them into the total distance d = d_vqkd + d_vqgan, and take argmin(d) to obtain min_encoding_indices, the single index shared by both codebooks.
    • Note: although TokenFlow has multiple codebooks, there is only one codebook embedding loss; it is computed on z_q after concatenating the codes from the two codebooks (a self-contained sketch of this shared lookup follows the code below).
# reshape and split the concatenated z for codebook lookup
        z = torch.einsum('b c h w -> b h w c', z).contiguous()
        z_flattened = z.view(-1, self.e_dim)
        z_vqkd, z_vqgan = torch.split(z, split_size_or_sections=self.split, dim=-1)
        z_flattened_vqkd, z_flattened_vqgan = torch.split(z_flattened, split_size_or_sections=self.split, dim=-1)

# compute squared L2 distances against each codebook
# (embedding_vqkd / embedding_vqgan: the two codebook weight matrices from self.embedding_vqkd / self.embedding_vqgan; not shown here)
        d_vqkd = torch.sum(z_flattened_vqkd ** 2, dim=1, keepdim=True) + \
                torch.sum(embedding_vqkd**2, dim=1) - 2 * \
                torch.einsum('bd,dn->bn', z_flattened_vqkd, torch.einsum('n d -> d n', embedding_vqkd))
        d_vqgan = torch.sum(z_flattened_vqgan ** 2, dim=1, keepdim=True) + \
                torch.sum(embedding_vqgan**2, dim=1) - 2 * \
                torch.einsum('bd,dn->bn', z_flattened_vqgan, torch.einsum('n d -> d n', embedding_vqgan))

# shared mapping: sum the two distances and pick one index shared by both codebooks
        d = d_vqkd + 1.0 * d_vqgan
        min_encoding_indices = torch.argmin(d, dim=1)

# look up z_q for all codebooks with the shared min_encoding_indices
        all_embedding = torch.cat([embedding_vqkd, embedding_vqgan], dim=-1)
        z_q = all_embedding[min_encoding_indices].view(z.shape)
        z_q_vqkd = embedding_vqkd[min_encoding_indices].view(z_vqkd.shape)
        z_q_vqgan = embedding_vqgan[min_encoding_indices].view(z_vqgan.shape)

# a single embedding/commitment loss, computed on the concatenated z_q
        vq_loss = torch.mean((z_q - z.detach()) ** 2)
        commit_loss = self.beta * torch.mean((z_q.detach() - z) ** 2)
        entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(-d)

# preserve gradients (straight-through estimator)
        z_q = z + (z_q - z).detach()
        # reshape back to match original input shape
        z_q = torch.einsum('b h w c -> b c h w', z_q)

        return z_q, (vq_loss, commit_loss, entropy_loss, codebook_usage, vqkd_d_norm, vqgan_d_norm), (perplexity, min_encodings, min_encoding_indices)
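To make the shared mapping concrete, here is a minimal, self-contained sketch (toy dimensions, no k-means init, no entropy loss or usage statistics; the name SharedMappingQuantizerSketch and its arguments are made up for illustration). A single argmin over the summed distances selects paired entries from both codebooks, and one embedding/commitment loss is computed on the concatenated code:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class SharedMappingQuantizerSketch(nn.Module):
        def __init__(self, n_e=16, split=(4, 8), beta=0.25):
            super().__init__()
            self.split = list(split)              # [semantic_code_dim, vqgan_code_dim]
            self.e_dim = sum(split)
            self.beta = beta
            self.embedding_vqkd = nn.Embedding(n_e, split[0])
            self.embedding_vqgan = nn.Embedding(n_e, split[1])

        def forward(self, z):                     # z: (b, c, h, w) with c == e_dim
            z = torch.einsum('b c h w -> b h w c', z).contiguous()
            z_flat = z.view(-1, self.e_dim)
            z_vqkd, z_vqgan = torch.split(z_flat, self.split, dim=-1)

            def dist(x, emb):                     # squared L2 distance to every codebook entry
                return (x ** 2).sum(1, keepdim=True) + (emb ** 2).sum(1) - 2 * x @ emb.t()

            # shared mapping: one index minimizes the sum of both distances
            d = dist(z_vqkd, self.embedding_vqkd.weight) + dist(z_vqgan, self.embedding_vqgan.weight)
            idx = torch.argmin(d, dim=1)

            # look up the paired entries and concatenate them
            all_emb = torch.cat([self.embedding_vqkd.weight, self.embedding_vqgan.weight], dim=-1)
            z_q = all_emb[idx].view(z.shape)

            # a single codebook + commitment loss on the concatenated code
            vq_loss = F.mse_loss(z_q, z.detach())
            commit_loss = self.beta * F.mse_loss(z_q.detach(), z)

            z_q = z + (z_q - z).detach()          # straight-through estimator
            return torch.einsum('b h w c -> b c h w', z_q), vq_loss + commit_loss, idx

    # quick shape check
    q = SharedMappingQuantizerSketch()
    z_q, loss, idx = q(torch.randn(2, 12, 3, 3))
    print(z_q.shape, idx.shape)                   # torch.Size([2, 12, 3, 3]) torch.Size([18])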

VQLoss

A teacher model is initialized here for distillation. Its type matches the VQModel's semantic encoder (one of ["clipb_224", "vitamin_xlarge_256", "siglip_384"]). The difference is that in the semantic encoder only the final head (plus, for siglip_384, the last encoder layer and post_layernorm) is frozen, whereas all of the teacher's parameters are frozen.

# vqkd_recon is the semantic (CLIP-like) feature reconstructed by the vqkd branch; reconstructions is the pixel output reconstructed by the vqgan branch
vqkd_recon, reconstructions = reconstructions
  • vqkd reconstruction loss: the cosine-similarity loss between the VQ reconstruction from the semantic encoder/decoder and the teacher's output.
    def calculate_clip_rec_loss(self, rec, target):  
        target = target / target.norm(dim=-1, keepdim=True)
        rec = rec / rec.norm(dim=-1, keepdim=True)
        rec_loss = (1 - (target * rec).sum(-1)).mean()
        return rec_loss
        
if vqkd_recon is not None:
    # get teacher feature
    clip_target = self.get_regress_target(inputs)

    vqkd_loss = self.calculate_clip_rec_loss(vqkd_recon, clip_target)
  • vqgan reconstruction loss:
 # reconstruction loss
rec_loss = self.rec_loss(inputs.contiguous(), reconstructions.contiguous())
  • perceptual loss
            # perceptual loss
            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
            p_loss = torch.mean(p_loss)

  • discriminator / generator adversarial loss (gen_adv_loss); a hedged sketch of common choices for gen_adv_loss follows the code below.
            # generator adversarial loss on the reconstructions
            logits_fake = self.discriminator(reconstructions.contiguous())
            generator_adv_loss = self.gen_adv_loss(logits_fake)
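The definition of gen_adv_loss is not shown above. As a hedged sketch, two common choices for the generator adversarial loss in VQGAN-style training are the hinge form and the non-saturating form; the actual TokenFlow implementation may use either or something else:

    import torch
    import torch.nn.functional as F

    def hinge_gen_loss(logits_fake):
        # the generator wants the discriminator to score its reconstructions highly
        return -torch.mean(logits_fake)

    def non_saturating_gen_loss(logits_fake):
        # equivalent to BCE-with-logits against an all-ones target
        return torch.mean(F.softplus(-logits_fake))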