Denoise model

far.py

from functools import partial

import numpy as np
from tqdm import tqdm
import scipy.stats as stats
import math
import torch
import torch.nn as nn

import matplotlib.pyplot as plt
from timm.models.vision_transformer import Block

from models.diffloss import DiffLoss

import torch.nn.functional as F
import torchvision.transforms as T
import random


from torchvision.utils import make_grid
from typing import Optional
from PIL import Image
def save_image(images: torch.Tensor, nrow: int = 8, show: bool = True, path: Optional[str] = None, format: Optional[str] = None, to_grayscale: bool = False, **kwargs):
    images = images * 0.5 + 0.5
    grid = make_grid(images, nrow=nrow, **kwargs)  # (channels, height, width)
    # convert to (height, width, channels) uint8 for PIL
    grid = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    im = Image.fromarray(grid)
    if to_grayscale:
        im = im.convert(mode="L")
    if path is not None:
        im.save(path, format=format)
    if show:
        im.show()
    return grid


def mask_by_order(mask_len, order, bsz, seq_len):
    masking = torch.zeros(bsz, seq_len).cuda()
    masking = torch.scatter(masking, dim=-1, index=order[:, :mask_len.long()], src=torch.ones(bsz, seq_len).cuda()).bool()
    return masking
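
# Usage sketch for mask_by_order (assumption: a CUDA device, since the helper allocates
# with .cuda()): for each sample, the first `mask_len` positions of its random order are
# marked as masked (True) and the rest stay False.
#
#   order = torch.argsort(torch.rand(2, 8), dim=-1).cuda()      # random generation orders
#   mask = mask_by_order(torch.tensor(5.), order, bsz=2, seq_len=8)
#   assert mask.shape == (2, 8) and mask.sum(dim=-1).eq(5).all()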


class FAR(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone
    """
    def __init__(self, img_size=256, vae_stride=16, patch_size=1,
                 encoder_embed_dim=1024, encoder_depth=16, encoder_num_heads=16,
                 decoder_embed_dim=1024, decoder_depth=16, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm,
                 vae_embed_dim=16,
                 mask=True,
                 mask_ratio_min=0.7,
                 label_drop_prob=0.1,
                 class_num=1000,
                 attn_dropout=0.1,
                 proj_dropout=0.1,
                 buffer_size=64,
                 diffloss_d=3,
                 diffloss_w=1024,
                 num_sampling_steps='100',
                 diffusion_batch_mul=4
                 ):
        super().__init__()

        # --------------------------------------------------------------------------
        # VAE and patchify specifics
        self.vae_embed_dim = vae_embed_dim

        self.img_size = img_size
        self.vae_stride = vae_stride
        self.patch_size = patch_size
        self.seq_h = self.seq_w = img_size // vae_stride // patch_size
        self.seq_len = self.seq_h * self.seq_w
        self.token_embed_dim = vae_embed_dim * patch_size**2

        # --------------------------------------------------------------------------
        # Class Embedding
        self.num_classes = class_num
        self.class_emb = nn.Embedding(class_num, encoder_embed_dim)
        self.label_drop_prob = label_drop_prob
        # Fake class embedding for CFG's unconditional generation
        self.fake_latent = nn.Parameter(torch.zeros(1, encoder_embed_dim))
        self.loss_weight = [1 + np.sin(math.pi / 2. * (bands + 1) / self.seq_h) for bands in range(self.seq_h)]

        # --------------------------------------------------------------------------
        self.mask = mask
        # FAR variant masking ratio, a left-half truncated Gaussian centered at 100% masking ratio with std 0.25
        self.mask_ratio_generator = stats.truncnorm((mask_ratio_min - 1.0) / 0.25, 0, loc=1.0, scale=0.25)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        
        # --------------------------------------------------------------------------
        # FAR encoder specifics
        self.z_proj = nn.Linear(self.token_embed_dim, encoder_embed_dim, bias=True)
        self.z_proj_ln = nn.LayerNorm(encoder_embed_dim, eps=1e-6)
        self.buffer_size = buffer_size
        self.encoder_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len + self.buffer_size, encoder_embed_dim))

        self.encoder_blocks = nn.ModuleList([
            Block(encoder_embed_dim, encoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
                  proj_drop=proj_dropout, attn_drop=attn_dropout) for _ in range(encoder_depth)])
        self.encoder_norm = norm_layer(encoder_embed_dim)

        # --------------------------------------------------------------------------
        # FAR decoder specifics
        self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
        self.decoder_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len + self.buffer_size, decoder_embed_dim))

        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
                  proj_drop=proj_dropout, attn_drop=attn_dropout) for _ in range(decoder_depth)])

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.diffusion_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len, decoder_embed_dim))

        self.initialize_weights()

        # --------------------------------------------------------------------------
        # Diffusion Loss
        self.diffloss = DiffLoss(
            target_channels=self.token_embed_dim,
            z_channels=decoder_embed_dim,
            width=diffloss_w,
            depth=diffloss_d,
            num_sampling_steps=num_sampling_steps,
        )
        self.diffusion_batch_mul = diffusion_batch_mul

    def initialize_weights(self):
        # parameters
        torch.nn.init.normal_(self.class_emb.weight, std=.02)
        torch.nn.init.normal_(self.fake_latent, std=.02)
        if self.mask:
            torch.nn.init.normal_(self.mask_token, std=.02)
        torch.nn.init.normal_(self.encoder_pos_embed_learned, std=.02)
        torch.nn.init.normal_(self.decoder_pos_embed_learned, std=.02)
        torch.nn.init.normal_(self.diffusion_pos_embed_learned, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def patchify(self, x):
        bsz, c, h, w = x.shape
        p = self.patch_size
        h_, w_ = h // p, w // p

        x = x.reshape(bsz, c, h_, p, w_, p)
        x = torch.einsum('nchpwq->nhwcpq', x)
        x = x.reshape(bsz, h_ * w_, c * p ** 2)
        return x  # [n, l, d]

    def unpatchify(self, x):
        bsz = x.shape[0]
        p = self.patch_size
        c = self.vae_embed_dim
        h_, w_ = self.seq_h, self.seq_w

        x = x.reshape(bsz, h_, w_, c, p, p)
        x = torch.einsum('nhwcpq->nchpwq', x)
        x = x.reshape(bsz, c, h_ * p, w_ * p)
        return x  # [n, c, h, w]

    def sample_orders(self, bsz):
        # generate a batch of random generation orders
        orders = []
        for _ in range(bsz):
            order = np.array(list(range(self.seq_len)))
            np.random.shuffle(order)
            orders.append(order)
        orders = torch.Tensor(np.array(orders)).cuda().long()
        return orders

    def random_masking(self, x, orders, x_index):
        # generate token mask
        bsz, seq_len, embed_dim = x.shape
        stage = x_index[0]
        mask_ratio_min = 0.7 * stage / 16
        mask_rate = stats.truncnorm((mask_ratio_min - 1.0) / 0.25, 0, loc=1.0, scale=0.25).rvs(1)[0]
        num_masked_tokens = int(np.ceil(seq_len * mask_rate))
        mask = torch.zeros(bsz, seq_len, device=x.device)
        mask = torch.scatter(mask, dim=-1, index=orders[:, :num_masked_tokens],
                             src=torch.ones(bsz, seq_len, device=x.device))
        return mask


    def processingpregt_latent(self, imgs):
        # Build the "previous-stage" low-frequency latent: pick a random resolution from
        # latent_core, downsample the latent to that size, then upsample it back.
        B, C, H, W = imgs.shape
        out = torch.zeros_like(imgs)
        latent_core = list(range(H))
        core_index = []

        random_number = torch.randint(0, len(latent_core), (1,)).item()
        for i in range(B):
            chosen_core = latent_core[random_number]
            core_index.append(chosen_core)
            if random_number == 0:
                # stage 0: start from an all-zero latent
                out[i] = torch.zeros(C, H, W).cuda()
            else:
                imgs_resize = F.interpolate(imgs[i].unsqueeze(0), size=(chosen_core, chosen_core), mode='area')
                out[i] = F.interpolate(imgs_resize, size=(H, W), mode='bicubic').squeeze(0)
        core_index = torch.tensor(core_index).to(out.device).half()
        return out, core_index

    
    def forward_mae_encoder(self, x, class_embedding, mask=None):
        x = self.z_proj(x)
        bsz, seq_len, embed_dim = x.shape

        # concat buffer
        x = torch.cat([torch.zeros(bsz, self.buffer_size, embed_dim, device=x.device), x], dim=1)
        
        if self.mask:
            mask_with_buffer = torch.cat([torch.zeros(x.size(0), self.buffer_size, device=x.device), mask], dim=1)

        # random drop class embedding during training
        if self.training:  
            drop_latent_mask = torch.rand(bsz) < self.label_drop_prob
            drop_latent_mask = drop_latent_mask.unsqueeze(-1).cuda().to(x.dtype)
            class_embedding = drop_latent_mask * self.fake_latent + (1 - drop_latent_mask) * class_embedding

        x[:, :self.buffer_size] = class_embedding.unsqueeze(1)

        # encoder position embedding
        x = x + self.encoder_pos_embed_learned
        x = self.z_proj_ln(x)

        # dropping
        if self.mask:
            x = x[(1-mask_with_buffer).nonzero(as_tuple=True)].reshape(bsz, -1, embed_dim)
        
        # apply Transformer blocks
        for blk in self.encoder_blocks:
            x = blk(x)
        x = self.encoder_norm(x)

        return x

    def forward_mae_decoder(self, x, mask=None):
        x = self.decoder_embed(x)
        if self.mask:
            mask_with_buffer = torch.cat([torch.zeros(x.size(0), self.buffer_size, device=x.device), mask], dim=1)
            # pad mask tokens
            mask_tokens = self.mask_token.repeat(mask_with_buffer.shape[0], mask_with_buffer.shape[1], 1).to(x.dtype)
            x_after_pad = mask_tokens.clone()
            x_after_pad[(1 - mask_with_buffer).nonzero(as_tuple=True)] = x.reshape(x.shape[0] * x.shape[1], x.shape[2])
            x = x_after_pad + self.decoder_pos_embed_learned
        else:
            x = x + self.decoder_pos_embed_learned

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        x = x[:, self.buffer_size:]
        x = x + self.diffusion_pos_embed_learned
        return x

    def forward_loss(self, z, target, mask, index, loss_weight=False):
        bsz, seq_len, _ = target.shape
        target = target.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)
        index = index.unsqueeze(1).unsqueeze(-1).repeat(1, seq_len, 1).reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)
        z = z.reshape(bsz*seq_len, -1).repeat(self.diffusion_batch_mul, 1)
        
        if loss_weight:
            loss_weight = loss_weight.unsqueeze(1).repeat(1, seq_len).reshape(bsz * seq_len).repeat(self.diffusion_batch_mul)
        loss = self.diffloss(z=z, target=target, index=index, loss_weight=loss_weight)
        return loss



    def forward(self, imgs, labels, loss_weight=False):
        class_embedding = self.class_emb(labels)
        
        process_imgs, x_index = self.processingpregt_latent(imgs)
        if loss_weight:
            loss_weight = self.loss_weight

        x = self.patchify(process_imgs)         # x.shape: torch.Size([B, 256, 16]))
        gt_latents = self.patchify(imgs)
        
        mask = None
        if self.mask:
            orders = self.sample_orders(bsz=x.size(0))
            mask = self.random_masking(x, orders, x_index)

        x = self.forward_mae_encoder(x, class_embedding, mask)
        z = self.forward_mae_decoder(x, mask)
        loss = self.forward_loss(z=z, target=gt_latents, mask=mask, index=x_index, loss_weight=loss_weight)

        return loss

    def forward_finetune(self, lq_imgs, gt_imgs, freq_level=15):
        """用于图像去噪微调的前向传播函数
        
        Args:
            lq_imgs: 低质量图像的潜在表示
            gt_imgs: 高质量图像的潜在表示
            freq_level: 频率等级,默认随机选择14或15
            
        Returns:
            loss: 扩散损失
        """
        # 获取批次大小
        bsz = lq_imgs.shape[0]
        
        # 将图像转换为补丁表示
        lq_latents = self.patchify(lq_imgs)  # [bsz, seq_len, token_embed_dim],形状为[bsz, 256, 16]
        gt_latents = self.patchify(gt_imgs)  # [bsz, seq_len, token_embed_dim],形状为[bsz, 256, 16]
        
        # 随机生成14或15的频率级别,简化index处理
        random_freq = torch.randint(14, 16, (1,)).item()  # 随机选择14或15
        core_index = [random_freq] * bsz  # 为批次中的每个样本创建相同的频率级别
        x_index = torch.tensor(core_index).to(lq_imgs.device).half()  # 与原始代码保持一致
        
        # 将lq_latents处理为与decoder输出的z完全相同的表示
        # 注意:需要模拟完整的encoder-decoder流程

        # 1. 首先投影到encoder维度
        lq_z = self.z_proj(lq_latents)
        
        # 2. 添加缓冲区,与encoder处理一致
        lq_z = torch.cat([torch.zeros(bsz, self.buffer_size, lq_z.shape[-1], device=lq_z.device), lq_z], dim=1)
        
        # 3. 添加条件嵌入(使用fake_latent作为无条件生成的嵌入)
        lq_z[:, :self.buffer_size] = self.fake_latent.unsqueeze(1)
        
        # 4. 添加encoder位置编码
        lq_z = lq_z + self.encoder_pos_embed_learned
        lq_z = self.z_proj_ln(lq_z)
        
        # 5. 投影到decoder维度
        lq_z = self.decoder_embed(lq_z)
        
        # 6. 添加decoder位置编码
        lq_z = lq_z + self.decoder_pos_embed_learned
        
        # 7. 去除buffer部分
        lq_z = lq_z[:, self.buffer_size:]
        
        # 8. 添加diffusion位置编码
        lq_z = lq_z + self.diffusion_pos_embed_learned
        
        # 直接计算扩散损失
        loss = self.forward_loss(z=lq_z, target=gt_latents, mask=None, index=x_index, loss_weight=False)
        
        return loss
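
    # Usage sketch for forward_finetune (assumptions: `vae`, `lq_imgs`, `gt_imgs` are the
    # KL-16 VAE and a batch of paired images normalized to [-1, 1]; the 0.2325 scaling
    # matches the training code in engine_far.py):
    #
    #   with torch.no_grad():
    #       lq_latent = vae.encode(lq_imgs).sample().mul_(0.2325)
    #       gt_latent = vae.encode(gt_imgs).sample().mul_(0.2325)
    #   loss = model.forward_finetune(lq_latent, gt_latent)
    #   loss.backward()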
    
    def denoise_image(self, lq_img, temperature=1.0, cfg=3.0, freq_levels=None):
        """对单张低质量图像进行去噪处理
        
        Args:
            lq_img: 低质量图像的潜在表示,形状为[1, C, H, W]
            temperature: 扩散采样温度
            cfg: 分类器自由引导系数
            freq_levels: 频率级别,如果为None则使用15
            
        Returns:
            denoised_img: 去噪后的潜在表示,形状为[1, C, H, W]
        """
        self.eval()  # 设置为评估模式
        
        # 使用固定频率级别,一般采用14或15
        freq_level = 15 if freq_levels is None else freq_levels[0]
        
        with torch.no_grad():
            # 获取批次大小
            bsz = lq_img.shape[0]  # 通常为1
            
            # 将图像转换为补丁表示
            lq_latents = self.patchify(lq_img)  # [bsz, seq_len, token_embed_dim]
            
            # 1. 首先投影到encoder维度
            lq_z = self.z_proj(lq_latents)
            
            # 2. 添加缓冲区,与encoder处理一致
            lq_z = torch.cat([torch.zeros(bsz, self.buffer_size, lq_z.shape[-1], device=lq_z.device), lq_z], dim=1)
            
            # 3. 创建条件和无条件嵌入(用于CFG)
            # 为两者都使用fake_latent,它是一个无标签的嵌入
            fake_embedding = self.fake_latent.repeat(bsz, 1)
            
            # 复制输入以同时处理条件和无条件部分
            lq_z_both = torch.cat([lq_z, lq_z], dim=0)
            
            # 复制嵌入以同时处理条件和无条件部分
            fake_embedding_both = torch.cat([fake_embedding, fake_embedding], dim=0)
            fake_embedding_both = fake_embedding_both.unsqueeze(1)
            lq_z_both[:, :self.buffer_size] = fake_embedding_both
            
            # 4. 添加encoder位置编码并执行归一化
            lq_z_both = lq_z_both + self.encoder_pos_embed_learned
            lq_z_both = self.z_proj_ln(lq_z_both)
            
            # 5. 投影到decoder维度
            lq_z_both = self.decoder_embed(lq_z_both)
            
            # 6. 添加decoder位置编码
            lq_z_both = lq_z_both + self.decoder_pos_embed_learned
            
            # 7. 去除buffer部分
            lq_z_both = lq_z_both[:, self.buffer_size:]
            
            # 8. 添加diffusion位置编码
            lq_z_both = lq_z_both + self.diffusion_pos_embed_learned
            
            # 9. 准备频率级别索引
            index = torch.tensor([freq_level] * (bsz * 2), device=lq_img.device).half()
            
            # 10. 处理维度以适应diffloss采样
            B, L, C = lq_z_both.shape
            lq_z_both = lq_z_both.reshape(B * L, -1)  # [B*L, C]
            
            # 扩展索引
            index_expanded = index.unsqueeze(1).unsqueeze(-1).repeat(1, L, 1).reshape(B * L, -1)
            
            # 11. 使用diffloss采样进行去噪
            denoised_tokens = self.diffloss.sample(
                lq_z_both, 
                temperature=temperature,
                cfg=cfg,
                index=index_expanded
            )
            
            # 12. 将结果重新整形,并分离条件和无条件部分
            denoised_tokens = denoised_tokens.reshape(B, L, -1)
            denoised_cond, _ = denoised_tokens.chunk(2, dim=0)  # 只保留条件部分
            
            # 13. 转换回图像形状
            denoised_img = self.unpatchify(denoised_cond)
            
            return denoised_img
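
    # Usage sketch for denoise_image (assumptions: `vae` is the KL-16 VAE used elsewhere in
    # this post and `img_tensor` is a [1, 3, 256, 256] image normalized to [-1, 1]; the full
    # pipeline lives in engine_far.process_single_image):
    #
    #   with torch.no_grad():
    #       lq_latent = vae.encode(img_tensor).sample().mul_(0.2325)
    #       denoised_latent = model.denoise_image(lq_latent, temperature=1.0, cfg=3.0)
    #       denoised_img = vae.decode(denoised_latent / 0.2325)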

    def sample_tokens_mask(self, bsz, num_iter=64, cfg=1.0, cfg_schedule="linear", labels=None, temperature=1.0, progress=False):
        latent_core = [0,12,12,4,5,6,7,8,9,10,11]

        num_iter = len(latent_core)
        mask = torch.ones(bsz, self.seq_len).cuda()      
        tokens = torch.zeros(bsz, self.seq_len, self.token_embed_dim).cuda()
        orders = self.sample_orders(bsz)

        for step in list(range(num_iter)):
            cur_tokens = torch.zeros(bsz, self.seq_len, self.token_embed_dim).cuda()
            if labels is not None:
                class_embedding = self.class_emb(labels)      
            else:
                class_embedding = self.fake_latent.repeat(bsz, 1)
    
            tokens = torch.cat([tokens, tokens], dim=0)
            class_embedding = torch.cat([class_embedding, self.fake_latent.repeat(bsz, 1)], dim=0)
            mask = torch.cat([mask, mask], dim=0)

            x = self.forward_mae_encoder(tokens, class_embedding, mask)
            z = self.forward_mae_decoder(x, mask)    # torch.Size([512, 256, 768])
            B, L, C = z.shape
            z = z.reshape(B * L, -1)

            # mask ratio for the next round, following MaskGIT and MAGE.
            mask_ratio = np.cos(math.pi / 2. * (step + 1) / num_iter)
            mask_len = torch.Tensor([np.floor(self.seq_len * mask_ratio)]).cuda()     # 251, 236, 212, 181, 142, 97, 49, 0 
            cfg_iter = 1 + (cfg - 1) * step / num_iter
            temperature_iter = 0.85 + (temperature - 0.85) * step / num_iter

            index = torch.tensor([latent_core[step]]).unsqueeze(1).unsqueeze(-1).repeat(B, L, 1).reshape(B * L, -1).to(z.device).half()
            z = self.diffloss.sample(z, temperature_iter, cfg_iter, index)     # torch.Size([512, 16])
            z, _ = z.chunk(2, dim=0)  # Remove null class samples.  torch.Size([256, 16])
            
            if step < num_iter-1:
                z = z.reshape(bsz, L, -1).transpose_(1, 2).reshape(bsz, -1, 16, 16)
                
                if step > 0:
                    imgs_resize = F.interpolate(z, size=(latent_core[step+1], latent_core[step+1]), mode='area')
                    z = F.interpolate(imgs_resize, size=(16, 16), mode='bicubic')
                z = z.reshape(bsz, -1, L).transpose_(1, 2).reshape(bsz*L, -1)
                
            
            sampled_token = z.reshape(bsz, L, -1)
            mask_next = mask_by_order(mask_len[0], orders, bsz, self.seq_len)
            mask_to_pred = torch.logical_not(mask_next)
            mask = mask_next
            sampled_token_latent = sampled_token[mask_to_pred.nonzero(as_tuple=True)]
            
            cur_tokens[mask_to_pred.nonzero(as_tuple=True)] = sampled_token_latent
            tokens = cur_tokens.clone()


        tokens = tokens.transpose_(1, 2).reshape(bsz, -1, 16, 16)
        return tokens
    

    def sample_tokens_nomask(self, bsz, num_iter=64, cfg=1.0, cfg_schedule="linear", labels=None, temperature=1.0, progress=False):
        latent_core = [0,2,3,4,5,6,7,8,9,10]
        # latent_core = [0,1,2,3,4,5,6,7,8,9]
        num_iter = len(latent_core)

        # init and sample generation orders
        tokens = torch.zeros(bsz, self.seq_len, self.token_embed_dim).cuda().half()
        indices = list(range(num_iter))
        if progress:
            indices = tqdm(indices)
        
        for step in indices:
            if labels is not None:
                class_embedding = self.class_emb(labels)        # torch.Size([256, 768])
            else:
                class_embedding = self.fake_latent.repeat(bsz, 1)
            
            tokens = torch.cat([tokens, tokens], dim=0)
            class_embedding = torch.cat([class_embedding, self.fake_latent.repeat(bsz, 1)], dim=0)

            # mae encoder
            x = self.forward_mae_encoder(tokens, class_embedding)
            z = self.forward_mae_decoder(x)    # torch.Size([512, 256, 768]); the decoder predicts a high-dimensional (768) condition for every token in one pass, and only a random subset (nge) of these conditions is fed to the diffusion head to generate the corresponding tokens.

            B, L, C = z.shape
            z = z.reshape(B * L, -1)


            cfg_iter = 1 + (cfg - 1) * step / num_iter
            temperature_iter = 0.8 + (1 - np.cos(math.pi / 2. * (step + 1) / num_iter)) * (1-0.8)

            index = torch.tensor([latent_core[step]]).unsqueeze(1).unsqueeze(-1).repeat(B, L, 1).reshape(B * L, -1).to(z.device).half()
            sampled_token_latent = self.diffloss.sample(z, temperature_iter, cfg_iter, index=index)     # torch.Size([512, 16])
            sampled_token_latent, _ = sampled_token_latent.chunk(2, dim=0)  # Remove null class samples.  torch.Size([256, 16])

            sampled_token_latent = sampled_token_latent.reshape(bsz, L, -1).transpose_(1, 2).reshape(bsz, -1, 16, 16)
            if step < num_iter-1:
                sampled_token_latent = F.interpolate(sampled_token_latent, size=(latent_core[step+1], latent_core[step+1]), mode='area')
                sampled_token_latent = F.interpolate(sampled_token_latent, size=(16, 16), mode='bicubic')
                sampled_token_latent = sampled_token_latent.view(bsz, 16, -1).transpose(1, 2)

            tokens = sampled_token_latent.clone()

        return tokens
        


def far_base(**kwargs):
    model = FAR(
        encoder_embed_dim=768, encoder_depth=12, encoder_num_heads=12,
        decoder_embed_dim=768, decoder_depth=12, decoder_num_heads=12,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def far_large(**kwargs):
    model = FAR(
        encoder_embed_dim=1024, encoder_depth=16, encoder_num_heads=16,
        decoder_embed_dim=1024, decoder_depth=16, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def far_huge(**kwargs):
    model = FAR(
        encoder_embed_dim=1280, encoder_depth=20, encoder_num_heads=16,
        decoder_embed_dim=1280, decoder_depth=20, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model
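
# Minimal instantiation sketch (assumptions: defaults from the training script, i.e. a
# 256x256 image with the KL-16 tokenizer so the latent is [B, 16, 16, 16], and a CUDA
# device because several helpers call .cuda() internally):
#
#   model = far_large(diffloss_d=3, diffloss_w=1024, mask=False).cuda()
#   latents = torch.randn(2, 16, 16, 16).cuda()            # stand-in for VAE latents
#   labels = torch.randint(0, 1000, (2,)).cuda()
#   with torch.cuda.amp.autocast():
#       loss = model(latents, labels)                       # training loss
#       tokens = model.sample_tokens_nomask(bsz=2, labels=labels, cfg=3.0)   # sampling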


engine_far.py

import math
import sys
from typing import Iterable

import torch

import util.misc as misc
import util.lr_sched as lr_sched
from models.vae import DiagonalGaussianDistribution
import torch_fidelity
import shutil
import cv2
import numpy as np
import matplotlib.pyplot as plt
import os
import copy
import time
from PIL import Image
import torchvision.transforms as transforms


from torchvision.utils import make_grid
from typing import Optional
from PIL import Image
def save_image(images: torch.Tensor, nrow: int = 8, show: bool = True, path: Optional[str] = None, format: Optional[str] = None, to_grayscale: bool = False, **kwargs):
    images = images * 0.5 + 0.5
    grid = make_grid(images, nrow=nrow, **kwargs)  # (channels, height, width)
    # convert to (height, width, channels) uint8 for PIL
    grid = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    im = Image.fromarray(grid)
    if to_grayscale:
        im = im.convert(mode="L")
    if path is not None:
        im.save(path, format=format)
    if show:
        im.show()
    return grid


def update_ema(target_params, source_params, rate=0.99):
    """
    Update target parameters to be closer to those of source parameters using
    an exponential moving average.

    :param target_params: the target parameter sequence.
    :param source_params: the source parameter sequence.
    :param rate: the EMA rate (closer to 1 means slower).
    """
    for targ, src in zip(target_params, source_params):
        targ.detach().mul_(rate).add_(src, alpha=1 - rate)


def process_single_image(model, vae, input_path, output_path, device, args, 
                       temperature=1.0, cfg=3.0, use_ema=True):
    """
    对单张图像进行去噪处理
    
    Args:
        model: FAR模型
        vae: VAE模型
        input_path: 输入图像路径
        output_path: 输出图像路径
        device: 计算设备
        args: 参数
        temperature: 扩散采样温度
        cfg: 分类器自由引导系数
        use_ema: 是否使用EMA参数
    """
    # make sure the model is in evaluation mode
    model.eval()
    
    # switch to EMA parameters
    if use_ema and hasattr(args, 'resume') and os.path.exists(os.path.join(args.resume, "checkpoint-last.pth")):
        print("Loading EMA parameters...")
        checkpoint = torch.load(os.path.join(args.resume, "checkpoint-last.pth"), map_location='cpu')
        model_state_dict = copy.deepcopy(model.state_dict())
        ema_state_dict = copy.deepcopy(model.state_dict())
        for i, (name, _value) in enumerate(model.named_parameters()):
            if name in checkpoint['model_ema']:
                ema_state_dict[name] = checkpoint['model_ema'][name].cuda()
        model.load_state_dict(ema_state_dict)
    
    # import center_crop_arr to keep preprocessing consistent with training
    from util.crop import center_crop_arr
    
    # image preprocessing - same as training, but without the random augmentations
    transform = transforms.Compose([
        transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.img_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])
    
    # load and preprocess the image
    img = Image.open(input_path).convert('RGB')
    img_tensor = transform(img).unsqueeze(0).to(device)  # add a batch dimension
    
    # VAE encoding
    with torch.no_grad():
        posterior = vae.encode(img_tensor)
        lq_latent = posterior.sample().mul_(0.2325)  # normalize the latent
    
    # frequency level used for denoising
    freq_level = 15  # use a single frequency level
    
    # denoise with the model
    with torch.cuda.amp.autocast():
        denoised_latent = model.denoise_image(
            lq_latent, 
            temperature=temperature, 
            cfg=cfg,
            freq_levels=[freq_level]
        )
    
    # VAE decoding
    with torch.no_grad():
        denoised_img = vae.decode(denoised_latent / 0.2325)  # un-normalize
    
    # save the result
    save_image(denoised_img, show=False, path=output_path)
    
    # build a side-by-side comparison figure
    plt.figure(figsize=(12, 6))
    
    # original image
    img_np = np.array(img)
    plt.subplot(1, 2, 1)
    plt.imshow(img_np)
    plt.title('Input image')
    plt.axis('off')
    
    # denoised image
    denoised_np = denoised_img[0].mul(0.5).add_(0.5).mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
    plt.subplot(1, 2, 2)
    plt.imshow(denoised_np)
    plt.title('Denoised image')
    plt.axis('off')
    
    # save the comparison figure
    compare_path = os.path.splitext(output_path)[0] + '_compare.png'
    plt.tight_layout()
    plt.savefig(compare_path)
    plt.close()
    
    print(f"Done! Result saved to: {output_path}")
    print(f"Comparison figure saved to: {compare_path}")
    
    # restore the original (non-EMA) weights if EMA weights were swapped in
    if use_ema and 'model_state_dict' in locals():
        model.load_state_dict(model_state_dict)
    
    return output_path


def train_one_epoch(model, vae,
                    model_params, ema_params,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler,
                    log_writer=None,
                    args=None):
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    optimizer.zero_grad()

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    for data_iter_step, (samples, labels) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        # we use a per iteration (instead of per epoch) lr scheduler
        lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)

        samples = samples.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)


        with torch.no_grad():
            if args.use_cached:
                moments = samples
                posterior = DiagonalGaussianDistribution(moments)
            else:
                posterior = vae.encode(samples)

            # normalize the std of latent to be 1. Change it if you use a different tokenizer
            x = posterior.sample().mul_(0.2325)

        # forward
        with torch.cuda.amp.autocast():
            loss = model(x, labels, loss_weight=args.loss_weight)
        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        loss_scaler(loss, optimizer, clip_grad=args.grad_clip, parameters=model.parameters(), update_grad=True)
        optimizer.zero_grad()

        torch.cuda.synchronize()

        update_ema(ema_params, model_params, rate=args.ema_rate)

        metric_logger.update(loss=loss_value)


        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None:
            """ We use epoch_1000x as the x-axis in tensorboard.
            This calibrates different curves when batch size changes.
            """
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('lr', lr, epoch_1000x)
        
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


def train_one_epoch_finetune(model, vae,
                           model_params, ema_params,
                           data_loader: Iterable, optimizer: torch.optim.Optimizer,
                           device: torch.device, epoch: int, loss_scaler,
                           log_writer=None,
                           args=None):
    """
    针对图像去噪任务的微调训练循环
    
    Args:
        model: FAR模型
        vae: VAE模型用于编码解码
        model_params: 模型参数
        ema_params: EMA参数
        data_loader: 包含LQ和GT图像对的数据加载器
        optimizer: 优化器
        device: 训练设备
        epoch: 当前训练轮次
        loss_scaler: 损失缩放器
        log_writer: 日志记录器
        args: 训练参数
    """
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Finetune Epoch: [{}]'.format(epoch)
    print_freq = 10

    optimizer.zero_grad()

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    for data_iter_step, (lq_samples, gt_samples, labels) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        # adjust the learning rate per iteration
        lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)

        # move samples to the GPU
        lq_samples = lq_samples.to(device, non_blocking=True)
        gt_samples = gt_samples.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)  # not used, but kept for loader compatibility

        # VAE encoding
        with torch.no_grad():
            if args.use_cached:
                lq_moments = lq_samples
                gt_moments = gt_samples
                lq_posterior = DiagonalGaussianDistribution(lq_moments)
                gt_posterior = DiagonalGaussianDistribution(gt_moments)
            else:
                lq_posterior = vae.encode(lq_samples)
                gt_posterior = vae.encode(gt_samples)
            
            # normalize the latents
            lq_latent = lq_posterior.sample().mul_(0.2325)
            gt_latent = gt_posterior.sample().mul_(0.2325)
        
        # forward pass to compute the loss
        with torch.cuda.amp.autocast():
            # the model internally samples a frequency level of 14 or 15
            loss = model.forward_finetune(lq_latent, gt_latent)
        
        loss_value = loss.item()
        
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)
        
        # backward pass and parameter update
        loss_scaler(loss, optimizer, clip_grad=args.grad_clip, parameters=model.parameters(), update_grad=True)
        optimizer.zero_grad()
        
        torch.cuda.synchronize()
        
        # update the EMA parameters
        update_ema(ema_params, model_params, rate=args.ema_rate)
        
        metric_logger.update(loss=loss_value)
        
        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)
        
        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None:
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar('finetune_loss', loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('lr', lr, epoch_1000x)
    
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


def evaluate(model_without_ddp, vae, ema_params, args, epoch, batch_size=16, log_writer=None, cfg=1.0,
             use_ema=True):
    model_without_ddp.eval()
    num_steps = args.num_images // (batch_size * misc.get_world_size()) + 1
    save_folder = os.path.join(args.output_dir, "ariter{}-diffsteps{}-temp{}-{}cfg{}-image{}".format(args.num_iter,
                                                                                                     args.num_sampling_steps,
                                                                                                     args.temperature,
                                                                                                     args.cfg_schedule,
                                                                                                     cfg,
                                                                                                     args.num_images))
    if use_ema:
        save_folder = save_folder + "_ema"
    if args.evaluate:
        save_folder = save_folder + "_evaluate"
    print("Save to:", save_folder)
    if misc.get_rank() == 0:
        if not os.path.exists(save_folder):
            os.makedirs(save_folder)

    # switch to ema params
    if use_ema:
        model_state_dict = copy.deepcopy(model_without_ddp.state_dict())
        ema_state_dict = copy.deepcopy(model_without_ddp.state_dict())
        for i, (name, _value) in enumerate(model_without_ddp.named_parameters()):
            assert name in ema_state_dict
            ema_state_dict[name] = ema_params[i]
        print("Switch to ema")
        model_without_ddp.load_state_dict(ema_state_dict)

    class_num = args.class_num
    assert args.num_images % class_num == 0  # number of images per class must be the same
    class_label_gen_world = np.arange(0, class_num).repeat(args.num_images // class_num)
    class_label_gen_world = np.hstack([class_label_gen_world, np.zeros(50000)])
    world_size = misc.get_world_size()
    local_rank = misc.get_rank()
    used_time = 0
    gen_img_cnt = 0


    for i in range(num_steps):
        print("Generation step {}/{}".format(i, num_steps))

        labels_gen = class_label_gen_world[world_size * batch_size * i + local_rank * batch_size:
                                                world_size * batch_size * i + (local_rank + 1) * batch_size]
        labels_gen = torch.Tensor(labels_gen).long().cuda()

        torch.cuda.synchronize()
        start_time = time.time()

        # generation
        with torch.no_grad():
            with torch.cuda.amp.autocast():
                if args.mask:
                    sampled_tokens = model_without_ddp.sample_tokens_mask(bsz=batch_size, num_iter=args.num_iter, cfg=cfg,
                                                                    cfg_schedule=args.cfg_schedule, labels=labels_gen,
                                                                    temperature=args.temperature)
                else:
                    sampled_tokens = model_without_ddp.sample_tokens_nomask(bsz=batch_size, num_iter=args.num_iter, cfg=cfg,
                                                                    cfg_schedule=args.cfg_schedule, labels=labels_gen,
                                                                    temperature=args.temperature)
                sampled_images = vae.decode(sampled_tokens / 0.2325)

        # measure speed after the first generation batch
        if i >= 1:
            torch.cuda.synchronize()
            used_time += time.time() - start_time
            gen_img_cnt += batch_size
            print("Generating {} images takes {:.5f} seconds, {:.5f} sec per image".format(gen_img_cnt, used_time, used_time / gen_img_cnt))

        torch.distributed.barrier()
        sampled_images = sampled_images.detach().cpu()

        save_image(sampled_images, nrow=8, show=False, path=os.path.join(args.output_dir, f"epoch{epoch}.png"), to_grayscale=False)
        # NOTE: this early return only saves a preview grid for the first batch and skips
        # the per-image saving and FID/IS computation below; remove it for full evaluation.
        return

        sampled_images = (sampled_images + 1) / 2

        # distributed save
        for b_id in range(sampled_images.size(0)):
            img_id = i * sampled_images.size(0) * world_size + local_rank * sampled_images.size(0) + b_id
            if img_id >= args.num_images:
                break
            gen_img = np.round(np.clip(sampled_images[b_id].numpy().transpose([1, 2, 0]) * 255, 0, 255))
            gen_img = gen_img.astype(np.uint8)  # the two channel flips cancelled out, so keep RGB order
            plt.imsave(os.path.join(save_folder, '{}.png'.format(str(img_id).zfill(5))), gen_img)


    torch.distributed.barrier()
    time.sleep(10)

    # back to no ema
    if use_ema:
        print("Switch back from ema")
        model_without_ddp.load_state_dict(model_state_dict)

    # compute FID and IS
    if log_writer is not None:
        if args.img_size == 256:
            input2 = None
            fid_statistics_file = '/mnt/workspace/workgroup/yuhu/code/FAR/fid_stats/adm_in256_stats.npz'
        else:
            raise NotImplementedError
        metrics_dict = torch_fidelity.calculate_metrics(
            input1=save_folder,
            input2=input2,
            fid_statistics_file=fid_statistics_file,
            cuda=True,
            isc=True,
            fid=True,
            kid=False,
            prc=False,
            verbose=False,
        )
        fid = metrics_dict['frechet_inception_distance']
        inception_score = metrics_dict['inception_score_mean']
        print(fid)
        print(inception_score)
        postfix = ""
        if use_ema:
           postfix = postfix + "_ema"
        if not cfg == 1.0:
           postfix = postfix + "_cfg{}".format(cfg)
        log_writer.add_scalar('fid{}'.format(postfix), fid, epoch)
        log_writer.add_scalar('is{}'.format(postfix), inception_score, epoch)
        print("FID: {:.4f}, Inception Score: {:.4f}".format(fid, inception_score))
        # remove temporal saving folder
        shutil.rmtree(save_folder)

    torch.distributed.barrier()
    time.sleep(10)


def cache_latents(vae,
                  data_loader: Iterable,
                  device: torch.device,
                  args=None):
    metric_logger = misc.MetricLogger(delimiter="  ")
    header = 'Caching: '
    print_freq = 20

    for data_iter_step, (samples, _, paths) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):

        samples = samples.to(device, non_blocking=True)

        with torch.no_grad():
            posterior = vae.encode(samples)
            moments = posterior.parameters
            posterior_flip = vae.encode(samples.flip(dims=[3]))
            moments_flip = posterior_flip.parameters

        for i, path in enumerate(paths):
            save_path = os.path.join(args.cached_path, path + '.npz')
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            np.savez(save_path, moments=moments[i].cpu().numpy(), moments_flip=moments_flip[i].cpu().numpy())

        if misc.is_dist_avail_and_initialized():
            torch.cuda.synchronize()

    return
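
# Reading the cache back (sketch; assumptions: `path` points at one of the .npz files
# written above, and DiagonalGaussianDistribution accepts the stored `moments` the same
# way the use_cached branch of train_one_epoch does):
#
#   data = np.load(path)
#   moments = torch.from_numpy(data['moments']).unsqueeze(0).cuda()
#   posterior = DiagonalGaussianDistribution(moments)
#   x = posterior.sample().mul_(0.2325)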

main_far.py

import argparse
import datetime
import numpy as np
import os
import time
from pathlib import Path

import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as transforms
import torchvision.datasets as datasets

from util.crop import center_crop_arr
import util.misc as misc
from util.misc import NativeScalerWithGradNormCount as NativeScaler
from util.loader import CachedFolder
from PIL import Image

from models.vae import AutoencoderKL
from models import far
from engine_far import train_one_epoch, train_one_epoch_finetune, evaluate, process_single_image
import copy


# Paired image dataset for denoising fine-tuning
class PairedImageDataset(torch.utils.data.Dataset):
    """
    Dataset that loads paired LQ (low-quality) and GT (ground-truth) images.
    """
    def __init__(self, lq_list_path, gt_list_path, transform=None):
        self.lq_list = self.read_list(lq_list_path)
        self.gt_list = self.read_list(gt_list_path)
        self.transform = transform
        assert len(self.lq_list) == len(self.gt_list), "LQ and GT lists have different lengths!"
        
    def read_list(self, list_path):
        with open(list_path, 'r') as f:
            return [line.strip() for line in f.readlines()]
    
    def __len__(self):
        return len(self.lq_list)
    
    def __getitem__(self, idx):
        lq_path = self.lq_list[idx]
        gt_path = self.gt_list[idx]
        
        lq_img = Image.open(lq_path).convert('RGB')
        gt_img = Image.open(gt_path).convert('RGB')
        
        if self.transform:
            lq_img = self.transform(lq_img)
            gt_img = self.transform(gt_img)
        
        # return a dummy label 0 (long dtype) to stay compatible with the original loader format
        return lq_img, gt_img, torch.tensor(0, dtype=torch.long)
    
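
# Usage sketch (assumptions: the two list files contain one image path per line, with the
# LQ and GT lines aligned by index, as expected by --lq_list_path / --gt_list_path):
#
#   dataset = PairedImageDataset('lq.list', 'gt.list', transform=transform_train)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True, num_workers=8)
#   lq_img, gt_img, label = dataset[0]    # two transformed tensors plus a dummy label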
def get_args_parser():
    parser = argparse.ArgumentParser('FAR training with Diffusion Loss', add_help=False)
    parser.add_argument('--batch_size', default=16, type=int,
                        help='Batch size per GPU (effective batch size is batch_size * # gpus')
    parser.add_argument('--epochs', default=400, type=int)

    # Model parameters
    parser.add_argument('--model', default='far_large', type=str, metavar='MODEL',
                        help='Name of model to train')

    # VAE parameters
    parser.add_argument('--img_size', default=256, type=int,
                        help='images input size')
    parser.add_argument('--vae_path', default="pretrained/vae_mar/kl16.ckpt", type=str,
                        help='images input size')
    parser.add_argument('--vae_embed_dim', default=16, type=int,
                        help='vae output embedding dimension')
    parser.add_argument('--vae_stride', default=16, type=int,
                        help='tokenizer stride, default use KL16')
    parser.add_argument('--patch_size', default=1, type=int,
                        help='number of tokens to group as a patch.')

    # Generation parameters
    parser.add_argument('--num_iter', default=64, type=int,
                        help='number of autoregressive iterations to generate an image')
    parser.add_argument('--num_images', default=50000, type=int,
                        help='number of images to generate')
    parser.add_argument('--cfg', default=1.0, type=float, help="classifier-free guidance")
    parser.add_argument('--cfg_schedule', default="linear", type=str)
    parser.add_argument('--label_drop_prob', default=0.1, type=float)
    parser.add_argument('--eval_freq', type=int, default=40, help='evaluation frequency')
    parser.add_argument('--save_last_freq', type=int, default=10, help='save last frequency')
    parser.add_argument('--online_eval', action='store_true')
    parser.add_argument('--evaluate', action='store_true')
    parser.add_argument('--mask', action='store_true')
    parser.add_argument('--eval_bsz', type=int, default=64, help='generation batch size')

    # Single-image inference parameters
    parser.add_argument('--inference', action='store_true',
                        help='enable single-image inference mode')
    parser.add_argument('--input_image', type=str, default='',
                        help='path to the input image')
    parser.add_argument('--output_image', type=str, default='',
                        help='path to the output image')
    parser.add_argument('--inference_temperature', type=float, default=1.0,
                        help='sampling temperature for inference')
    parser.add_argument('--inference_cfg', type=float, default=3.0,
                        help='classifier-free guidance scale for inference')
    parser.add_argument('--use_ema', action='store_true', default=True,
                        help='use EMA parameters for inference')

    # Optimizer parameters
    parser.add_argument('--weight_decay', type=float, default=0.02,
                        help='weight decay (default: 0.02)')

    parser.add_argument('--lr', type=float, default=None, metavar='LR',
                        help='learning rate (absolute lr)')
    parser.add_argument('--blr', type=float, default=1e-4, metavar='LR',
                        help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
    parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0')
    parser.add_argument('--lr_schedule', type=str, default='constant',
                        help='learning rate schedule')
    parser.add_argument('--warmup_epochs', type=int, default=100, metavar='N',
                        help='epochs to warmup LR')
    parser.add_argument('--ema_rate', default=0.9999, type=float)

    # FAR params
    parser.add_argument('--mask_ratio_min', type=float, default=0.7,
                        help='Minimum mask ratio')
    parser.add_argument('--grad_clip', type=float, default=3.0,
                        help='Gradient clip')
    parser.add_argument('--attn_dropout', type=float, default=0.1,
                        help='attention dropout')
    parser.add_argument('--proj_dropout', type=float, default=0.1,
                        help='projection dropout')
    parser.add_argument('--buffer_size', type=int, default=64)
    parser.add_argument('--loss_weight', action='store_true', help='adopt uneven loss weight.')

    # Diffusion Loss params
    parser.add_argument('--diffloss_d', type=int, default=12)
    parser.add_argument('--diffloss_w', type=int, default=1536)
    parser.add_argument('--num_sampling_steps', type=str, default="100")
    parser.add_argument('--diffusion_batch_mul', type=int, default=1)
    parser.add_argument('--temperature', default=1.0, type=float, help='diffusion loss sampling temperature')

    # Dataset parameters
    parser.add_argument('--data_path', default='./data/imagenet', type=str,
                        help='dataset path')
    parser.add_argument('--class_num', default=1000, type=int)

    # Fine-tuning parameters
    parser.add_argument('--finetune', action='store_true', 
                        help='run denoising fine-tuning')
    parser.add_argument('--lq_list_path', type=str, default='',
                        help='path to the list file of low-quality images')
    parser.add_argument('--gt_list_path', type=str, default='',
                        help='path to the list file of ground-truth images')
    parser.add_argument('--freq_level', type=int, default=15,
                        help='frequency level of the LQ images (no longer used; the model samples 14 or 15 internally)')

    parser.add_argument('--output_dir', default='./output_dir',
                        help='path where to save, empty for no saving')
    parser.add_argument('--log_dir', default='./output_dir',
                        help='path where to tensorboard log')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=1, type=int)
    parser.add_argument('--resume', default='',
                        help='resume from checkpoint')

    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--num_workers', default=8, type=int)
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')

    # caching latents
    parser.add_argument('--use_cached', action='store_true', dest='use_cached',
                        help='Use cached latents')
    parser.set_defaults(use_cached=False)
    parser.add_argument('--cached_path', default='', help='path to cached latents')

    return parser


def main(args):
    misc.init_distributed_mode(args)

    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(', ', ',\n'))

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    # single-image inference does not need a distributed setup
    if args.inference:
        global_rank = 0
        num_tasks = 1
    else:
        num_tasks = misc.get_world_size()
        global_rank = misc.get_rank()

    if global_rank == 0 and args.log_dir is not None:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = SummaryWriter(log_dir=args.log_dir)
    else:
        log_writer = None


    ### ImageNet VAE
    vae = AutoencoderKL(embed_dim=args.vae_embed_dim, ch_mult=(1, 1, 2, 2, 4), ckpt_path=args.vae_path).cuda().eval()
    for param in vae.parameters():
        param.requires_grad = False
        
    model = far.__dict__[args.model](
        img_size=args.img_size,
        vae_stride=args.vae_stride,
        patch_size=args.patch_size,
        vae_embed_dim=args.vae_embed_dim,
        mask=args.mask,
        mask_ratio_min=args.mask_ratio_min,
        label_drop_prob=args.label_drop_prob,
        class_num=args.class_num,
        attn_dropout=args.attn_dropout,
        proj_dropout=args.proj_dropout,
        buffer_size=args.buffer_size,
        diffloss_d=args.diffloss_d,
        diffloss_w=args.diffloss_w,
        num_sampling_steps=args.num_sampling_steps,
        diffusion_batch_mul=args.diffusion_batch_mul,
    )

    
    print("Model = %s" % str(model))
    # following timm: set wd as 0 for bias and norm layers
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("Number of trainable parameters: {}M".format(n_params / 1e6))

    model.to(device)
    model_without_ddp = model

    # inference mode
    if args.inference:
        assert args.input_image != '', "An input image path must be specified!"
        assert args.output_image != '', "An output image path must be specified!"
        assert args.resume != '', "A trained model checkpoint path must be provided!"
        
        if not os.path.exists(args.input_image):
            raise FileNotFoundError(f"Input image not found: {args.input_image}")
        
        # make sure the output directory exists
        output_dir = os.path.dirname(args.output_image)
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir)
            
        # load the model weights
        if os.path.exists(os.path.join(args.resume, "checkpoint-last.pth")):
            checkpoint = torch.load(os.path.join(args.resume, "checkpoint-last.pth"), map_location='cpu')
            model_without_ddp.load_state_dict(checkpoint['model'])
            print(f"Loaded model weights from: {args.resume}")
        else:
            raise FileNotFoundError(f"Model checkpoint not found: {args.resume}")
            
        # process the single image
        process_single_image(
            model=model_without_ddp,
            vae=vae,
            input_path=args.input_image,
            output_path=args.output_image,
            device=device,
            args=args,
            temperature=args.inference_temperature,
            cfg=args.inference_cfg,
            use_ema=args.use_ema
        )
        
        return

    eff_batch_size = args.batch_size * misc.get_world_size()

    if args.lr is None:  # only base_lr is specified
        args.lr = args.blr * eff_batch_size / 256

    print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
    print("actual lr: %.2e" % args.lr)
    print("effective batch size: %d" % eff_batch_size)

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    # no weight decay on bias, norm layers, and diffloss MLP
    param_groups = misc.add_weight_decay(model_without_ddp, args.weight_decay)
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
    print(optimizer)
    loss_scaler = NativeScaler()

    
    ### resume training
    if args.resume and os.path.exists(os.path.join(args.resume, "checkpoint-last.pth")):
        checkpoint = torch.load(os.path.join(args.resume, "checkpoint-last.pth"), map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        model_params = list(model_without_ddp.parameters())
        ema_state_dict = checkpoint['model_ema']
        ema_params = [ema_state_dict[name].cuda() for name, _ in model_without_ddp.named_parameters()]
        print("Resume checkpoint %s" % args.resume)

        if 'optimizer' in checkpoint and 'epoch' in checkpoint and not args.finetune:
            optimizer.load_state_dict(checkpoint['optimizer'])
            args.start_epoch = checkpoint['epoch'] + 1
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])
            print("With optim & sched!")
        del checkpoint
    else:
        model_params = list(model_without_ddp.parameters())
        ema_params = copy.deepcopy(model_params)
        print("Training from scratch")
    
    # evaluate FID and IS
    if args.evaluate:
        torch.cuda.empty_cache()
        evaluate(model_without_ddp, vae, ema_params, args, 0, batch_size=args.eval_bsz, log_writer=log_writer, cfg=args.cfg, use_ema=True)
        return


    # augmentation following DiT and ADM
    transform_train = transforms.Compose([
        transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.img_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])

    # choose the dataset depending on the fine-tuning mode
    if args.finetune:
        print("Fine-tuning mode: loading LQ/GT image pairs")
        dataset_train = PairedImageDataset(
            lq_list_path=args.lq_list_path,
            gt_list_path=args.gt_list_path,
            transform=transform_train
        )
    else:
        if args.use_cached:
            dataset_train = CachedFolder(args.cached_path)
        else:
            dataset_train = datasets.ImageFolder(os.path.join(args.data_path, 'train'), transform=transform_train)
            
    print(dataset_train)

    sampler_train = torch.utils.data.DistributedSampler(
        dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
    )
    print("Sampler_train = %s" % str(sampler_train))

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )
    
    # training
    print(f"Start {'finetune' if args.finetune else 'training'} for {args.epochs} epochs")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)

        # pick the training function depending on the fine-tuning mode
        if args.finetune:
            train_stats = train_one_epoch_finetune(
                model, vae,
                model_params, ema_params,
                data_loader_train,
                optimizer, device, epoch, loss_scaler,
                log_writer=log_writer,
                args=args
            )
        else:
            train_stats = train_one_epoch(
                model, vae,
                model_params, ema_params,
                data_loader_train,
                optimizer, device, epoch, loss_scaler,
                log_writer=log_writer,
                args=args
            )

        # save checkpoint
        if epoch % args.save_last_freq == 0 or epoch + 1 == args.epochs:
            misc.save_model(args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
                            loss_scaler=loss_scaler, epoch=epoch, ema_params=ema_params)
        # online evaluation
        if args.online_eval and (epoch % args.eval_freq == 0 or epoch + 1 == args.epochs):
            torch.cuda.empty_cache()
            evaluate(model_without_ddp, vae, ema_params, args, epoch, batch_size=args.eval_bsz, log_writer=log_writer, cfg=1.0, use_ema=True)
            if not (args.cfg == 1.0 or args.cfg == 0.0):
                evaluate(model_without_ddp, vae, ema_params, args, epoch, batch_size=args.eval_bsz // 2, log_writer=log_writer, cfg=args.cfg, use_ema=True)
            torch.cuda.empty_cache()

        if misc.is_main_process():
            if log_writer is not None:
                log_writer.flush()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


if __name__ == '__main__':
    args = get_args_parser()
    args = args.parse_args()

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    args.log_dir = args.output_dir
    main(args)

fine-tune.sh

#!/bin/bash

# Fine-tune the FAR-large model for the image-denoising task
torchrun --nproc_per_node=4 --nnodes=1 --node_rank=0 \
main_far.py \
--img_size 256 --vae_path pretrained_models/vae/kl16.ckpt --vae_embed_dim 16 --vae_stride 16 --patch_size 1 \
--model far_large --diffloss_d 3 --diffloss_w 1024 \
--epochs 100 --warmup_epochs 10 --batch_size 32 --blr 5.0e-5 --diffusion_batch_mul 4 \
--finetune \
--lq_list_path /data/vjuicefs_ai_camera_jgroup_research/public_data/11164225/datasets/Magic_Dictionary/training_list/20250105.6789degq_ShMfwowCaGq3_13_Mfww_r6789_G35l0105.6_6_1.LQ.list \
--gt_list_path /data/vjuicefs_ai_camera_jgroup_research/public_data/11164225/datasets/Magic_Dictionary/training_list/20250105.6789degq_ShMfwowCaGq3_13_Mfww_r6789_G35l0105.6_6_1.GT.list \
--resume pretrained_models/far/far_large \
--output_dir ./output_finetune \
--save_last_freq 5 \
--grad_clip 1.0 

inference.sh

#!/bin/bash

# Single-image denoising inference script for the FAR model

# check input arguments
if [ "$#" -lt 1 ]; then
  echo "Usage: $0 <input image path> [output image path=./output.png]"
  exit 1
fi

# parameters
INPUT_IMAGE="$1"
OUTPUT_IMAGE="${2:-./output.png}"
MODEL_PATH="./output_finetune"  # path to the fine-tuned model

# make sure the input file exists
if [ ! -f "$INPUT_IMAGE" ]; then
  echo "Error: input file does not exist: $INPUT_IMAGE"
  exit 1
fi

# create the output directory (if it does not exist)
OUTPUT_DIR=$(dirname "$OUTPUT_IMAGE")
mkdir -p "$OUTPUT_DIR"

echo "Processing image: $INPUT_IMAGE"
echo "Output path: $OUTPUT_IMAGE"
echo "Using model: $MODEL_PATH"

# run inference
python main_far.py \
--img_size 256 --vae_path pretrained_models/vae/kl16.ckpt --vae_embed_dim 16 --vae_stride 16 --patch_size 1 \
--model far_large --diffloss_d 3 --diffloss_w 1024 \
--inference \
--input_image "$INPUT_IMAGE" \
--output_image "$OUTPUT_IMAGE" \
--resume "$MODEL_PATH" \
--inference_temperature 1.0 \
--inference_cfg 3.0 \
--use_ema

# check the exit status
if [ $? -eq 0 ]; then
  echo "Image processed successfully!"
  echo "Result saved to: $OUTPUT_IMAGE"
else
  echo "An error occurred during processing!"
fi
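
A quick way to run the script (the paths below are placeholders, not files from this post):

bash inference.sh ./samples/noisy.png ./results/denoised.png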
