changemodel

far.py

from functools import partial

import numpy as np
from tqdm import tqdm
import scipy.stats as stats
import math
import torch
import torch.nn as nn

import matplotlib.pyplot as plt
from timm.models.vision_transformer import Block

from models.diffloss import DiffLoss

import torch.nn.functional as F
import torchvision.transforms as T
import random


from torchvision.utils import make_grid
from typing import Optional
from PIL import Image
def save_image(images: torch.Tensor, nrow: int = 8, show: bool = True, path: Optional[str] = None, format: Optional[str] = None, to_grayscale: bool = False, **kwargs):
    images = images * 0.5 + 0.5
    grid = make_grid(images, nrow=nrow, **kwargs)  # (channels, height, width)
    #  (height, width, channels)
    grid = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    im = Image.fromarray(grid)
    if to_grayscale:
        im = im.convert(mode="L")
    if path is not None:
        im.save(path, format=format)
    if show:
        im.show()
    return grid
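# Illustrative call (the file name below is a placeholder): `save_image` expects
# images in [-1, 1], e.g.
#   save_image(torch.randn(16, 3, 64, 64).clamp(-1, 1), nrow=4, show=False,
#              path="preview_grid.png")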


def mask_by_order(mask_len, order, bsz, seq_len):
    masking = torch.zeros(bsz, seq_len).cuda()
    masking = torch.scatter(masking, dim=-1, index=order[:, :mask_len.long()], src=torch.ones(bsz, seq_len).cuda()).bool()
    return masking
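# Illustrative example (assumes a CUDA device, matching the .cuda() calls above):
# with order=[[2, 0, 3, 1]] and mask_len=2, the first two entries of the order
# (positions 2 and 0) are marked as masked:
#   mask_by_order(torch.tensor(2.), torch.tensor([[2, 0, 3, 1]]).cuda(), 1, 4)
#   -> tensor([[ True, False,  True, False]], device='cuda:0')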


class FAR(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone
    """
    def __init__(self, img_size=256, vae_stride=16, patch_size=1,
                 encoder_embed_dim=1024, encoder_depth=16, encoder_num_heads=16,
                 decoder_embed_dim=1024, decoder_depth=16, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm,
                 vae_embed_dim=16,
                 mask=True,
                 mask_ratio_min=0.7,
                 label_drop_prob=0.1,
                 class_num=1000,
                 attn_dropout=0.1,
                 proj_dropout=0.1,
                 buffer_size=64,
                 diffloss_d=3,
                 diffloss_w=1024,
                 num_sampling_steps='100',
                 diffusion_batch_mul=4
                 ):
        super().__init__()

        # --------------------------------------------------------------------------
        # VAE and patchify specifics
        self.vae_embed_dim = vae_embed_dim

        self.img_size = img_size
        self.vae_stride = vae_stride
        self.patch_size = patch_size
        self.seq_h = self.seq_w = img_size // vae_stride // patch_size
        self.seq_len = self.seq_h * self.seq_w
        self.token_embed_dim = vae_embed_dim * patch_size**2

        # --------------------------------------------------------------------------
        # Class Embedding
        self.num_classes = class_num
        self.class_emb = nn.Embedding(class_num, encoder_embed_dim)
        self.label_drop_prob = label_drop_prob
        # Fake class embedding for CFG's unconditional generation
        self.fake_latent = nn.Parameter(torch.zeros(1, encoder_embed_dim))
        self.loss_weight = [1 + np.sin(math.pi / 2. * (bands + 1) / self.seq_h) for bands in range(self.seq_h)]

        # --------------------------------------------------------------------------
        self.mask = mask
        # FAR variant masking ratio, a left-half truncated Gaussian centered at 100% masking ratio with std 0.25
        self.mask_ratio_generator = stats.truncnorm((mask_ratio_min - 1.0) / 0.25, 0, loc=1.0, scale=0.25)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        
        # --------------------------------------------------------------------------
        # FAR encoder specifics
        self.z_proj = nn.Linear(self.token_embed_dim, encoder_embed_dim, bias=True)
        self.z_proj_ln = nn.LayerNorm(encoder_embed_dim, eps=1e-6)
        self.buffer_size = buffer_size
        self.encoder_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len + self.buffer_size, encoder_embed_dim))

        self.encoder_blocks = nn.ModuleList([
            Block(encoder_embed_dim, encoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
                  proj_drop=proj_dropout, attn_drop=attn_dropout) for _ in range(encoder_depth)])
        self.encoder_norm = norm_layer(encoder_embed_dim)

        # --------------------------------------------------------------------------
        # FAR decoder specifics
        self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
        self.decoder_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len + self.buffer_size, decoder_embed_dim))

        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
                  proj_drop=proj_dropout, attn_drop=attn_dropout) for _ in range(decoder_depth)])

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.diffusion_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len, decoder_embed_dim))

        self.initialize_weights()

        # --------------------------------------------------------------------------
        # Diffusion Loss
        self.diffloss = DiffLoss(
            target_channels=self.token_embed_dim,
            z_channels=decoder_embed_dim,
            width=diffloss_w,
            depth=diffloss_d,
            num_sampling_steps=num_sampling_steps,
        )
        self.diffusion_batch_mul = diffusion_batch_mul

    def initialize_weights(self):
        # parameters
        torch.nn.init.normal_(self.class_emb.weight, std=.02)
        torch.nn.init.normal_(self.fake_latent, std=.02)
        if self.mask:
            torch.nn.init.normal_(self.mask_token, std=.02)
        torch.nn.init.normal_(self.encoder_pos_embed_learned, std=.02)
        torch.nn.init.normal_(self.decoder_pos_embed_learned, std=.02)
        torch.nn.init.normal_(self.diffusion_pos_embed_learned, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def patchify(self, x):
        bsz, c, h, w = x.shape
        p = self.patch_size
        h_, w_ = h // p, w // p

        x = x.reshape(bsz, c, h_, p, w_, p)
        x = torch.einsum('nchpwq->nhwcpq', x)
        x = x.reshape(bsz, h_ * w_, c * p ** 2)
        return x  # [n, l, d]

    def unpatchify(self, x):
        bsz = x.shape[0]
        p = self.patch_size
        c = self.vae_embed_dim
        h_, w_ = self.seq_h, self.seq_w

        x = x.reshape(bsz, h_, w_, c, p, p)
        x = torch.einsum('nhwcpq->nchpwq', x)
        x = x.reshape(bsz, c, h_ * p, w_ * p)
        return x  # [n, c, h, w]

    def sample_orders(self, bsz):
        # generate a batch of random generation orders
        orders = []
        for _ in range(bsz):
            order = np.array(list(range(self.seq_len)))
            np.random.shuffle(order)
            orders.append(order)
        orders = torch.Tensor(np.array(orders)).cuda().long()
        return orders

    def random_masking(self, x, orders, x_index):
        # generate token mask
        bsz, seq_len, embed_dim = x.shape
        stage = x_index[0]
        mask_ratio_min = 0.7 * stage / 16
        mask_rate = stats.truncnorm((mask_ratio_min - 1.0) / 0.25, 0, loc=1.0, scale=0.25).rvs(1)[0]
        num_masked_tokens = int(np.ceil(seq_len * mask_rate))
        mask = torch.zeros(bsz, seq_len, device=x.device)
        mask = torch.scatter(mask, dim=-1, index=orders[:, :num_masked_tokens],
                             src=torch.ones(bsz, seq_len, device=x.device))
        return mask


    def processingpregt_latent(self, imgs):
        B, C, H, W = imgs.shape
        out = torch.zeros_like(imgs)
        latent_core = list(range(H))
        core_index = []

        random_number = torch.randint(0, len(latent_core), (1,))
        for i in range(B):
            chosen_core = latent_core[random_number]
            core_index.append(chosen_core)
            if random_number == 0:
                out[i] = torch.zeros(C, H, W).cuda()    # torch.Size([256, 256, 16])
            else:
                imgs_resize = F.interpolate(imgs[i].unsqueeze(0), size=(chosen_core, chosen_core), mode='area')
                out[i] = F.interpolate(imgs_resize, size=(H, W), mode='bicubic').squeeze(0)
        core_index = torch.tensor(core_index).to(out.device).half()
        return out, core_index
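
    # Behaviour sketch (illustrative): for latents of shape [B, 16, 16, 16], one
    # resolution `chosen_core` in [0, 15] is drawn for the whole batch; each latent
    # is area-downsampled to (chosen_core, chosen_core) and bicubic-upsampled back,
    # so only its low-frequency content survives (chosen_core == 0 gives zeros).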

    
    def forward_mae_encoder(self, x, class_embedding, mask=None):
        x = self.z_proj(x)
        bsz, seq_len, embed_dim = x.shape

        # concat buffer
        x = torch.cat([torch.zeros(bsz, self.buffer_size, embed_dim, device=x.device), x], dim=1)
        
        if self.mask:
            mask_with_buffer = torch.cat([torch.zeros(x.size(0), self.buffer_size, device=x.device), mask], dim=1)

        # random drop class embedding during training
        if self.training:  
            drop_latent_mask = torch.rand(bsz) < self.label_drop_prob
            drop_latent_mask = drop_latent_mask.unsqueeze(-1).cuda().to(x.dtype)
            class_embedding = drop_latent_mask * self.fake_latent + (1 - drop_latent_mask) * class_embedding

        x[:, :self.buffer_size] = class_embedding.unsqueeze(1)

        # encoder position embedding
        x = x + self.encoder_pos_embed_learned
        x = self.z_proj_ln(x)

        # dropping
        if self.mask:
            x = x[(1-mask_with_buffer).nonzero(as_tuple=True)].reshape(bsz, -1, embed_dim)
        
        # apply Transformer blocks
        for blk in self.encoder_blocks:
            x = blk(x)
        x = self.encoder_norm(x)

        return x

    def forward_mae_decoder(self, x, mask=None):
        x = self.decoder_embed(x)
        if self.mask:
            mask_with_buffer = torch.cat([torch.zeros(x.size(0), self.buffer_size, device=x.device), mask], dim=1)
            # pad mask tokens
            mask_tokens = self.mask_token.repeat(mask_with_buffer.shape[0], mask_with_buffer.shape[1], 1).to(x.dtype)
            x_after_pad = mask_tokens.clone()
            x_after_pad[(1 - mask_with_buffer).nonzero(as_tuple=True)] = x.reshape(x.shape[0] * x.shape[1], x.shape[2])
            x = x_after_pad + self.decoder_pos_embed_learned
        else:
            x = x + self.decoder_pos_embed_learned

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        x = x[:, self.buffer_size:]
        x = x + self.diffusion_pos_embed_learned
        return x

    def forward_loss(self, z, target, mask, index, loss_weight=False):
            
        bsz, seq_len, _ = target.shape
        
        current_diffusion_batch_mul = self.diffusion_batch_mul  # batch repeat factor for the diffusion loss
        
        target = target.reshape(bsz * seq_len, -1).repeat(current_diffusion_batch_mul, 1)
        index = index.unsqueeze(1).unsqueeze(-1).repeat(1, seq_len, 1).reshape(bsz * seq_len, -1).repeat(current_diffusion_batch_mul, 1)
        z = z.reshape(bsz*seq_len, -1).repeat(current_diffusion_batch_mul, 1)
        
        if loss_weight:
            loss_weight = loss_weight.unsqueeze(1).repeat(1, seq_len).reshape(bsz * seq_len).repeat(current_diffusion_batch_mul)

        loss = self.diffloss(z=z, target=target, index=index, loss_weight=loss_weight)
        return loss
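
    # Shape sketch (illustrative, with bsz=2, seq_len=256, token_embed_dim=16,
    # decoder_embed_dim=1024 and diffusion_batch_mul=4):
    #   target: [2, 256, 16]   -> [512, 16]   -> repeat -> [2048, 16]
    #   z:      [2, 256, 1024] -> [512, 1024] -> repeat -> [2048, 1024]
    #   index:  [2]            -> [512, 1]    -> repeat -> [2048, 1]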



    def forward(self, lq_imgs, gt_imgs, loss_weight=False):
        freq_level = random.randint(14,15)
        freq_embedding = self.fake_latent.repeat(gt_imgs.shape[0], 1)

        gt_latents = self.patchify(gt_imgs)
        lq_latents = self.patchify(lq_imgs)

        if loss_weight:
            loss_weight = self.loss_weight

        x_index = torch.tensor([freq_level]*gt_imgs.size(0)).to(gt_imgs.device).half()
        # record the frequency level used so it can be monitored during training
        self.last_used_freq_level = freq_level

        mask=None

        x = self.forward_mae_encoder(lq_latents, freq_embedding, mask)
        z = self.forward_mae_decoder(x, mask)

        # compute the loss against the original gt_latents
        loss = self.forward_loss(z=z, target=gt_latents, mask=None, index=x_index, loss_weight=loss_weight)
        
        return loss




    def sample_tokens_mask(self, bsz, num_iter=64, cfg=1.0, cfg_schedule="linear", labels=None, temperature=1.0, progress=False):
        latent_core = [0,12,12,4,5,6,7,8,9,10,11]

        num_iter = len(latent_core)
        mask = torch.ones(bsz, self.seq_len).cuda()      
        tokens = torch.zeros(bsz, self.seq_len, self.token_embed_dim).cuda()
        orders = self.sample_orders(bsz)

        for step in list(range(num_iter)):
            cur_tokens = torch.zeros(bsz, self.seq_len, self.token_embed_dim).cuda()
            if labels is not None:
                class_embedding = self.class_emb(labels)      
            else:
                class_embedding = self.fake_latent.repeat(bsz, 1)
    
            tokens = torch.cat([tokens, tokens], dim=0)
            class_embedding = torch.cat([class_embedding, self.fake_latent.repeat(bsz, 1)], dim=0)
            mask = torch.cat([mask, mask], dim=0)

            x = self.forward_mae_encoder(tokens, class_embedding, mask)
            z = self.forward_mae_decoder(x, mask)    # torch.Size([512, 256, 768])
            B, L, C = z.shape
            z = z.reshape(B * L, -1)

            # mask ratio for the next round, following MaskGIT and MAGE.
            mask_ratio = np.cos(math.pi / 2. * (step + 1) / num_iter)
            mask_len = torch.Tensor([np.floor(self.seq_len * mask_ratio)]).cuda()     # 251, 236, 212, 181, 142, 97, 49, 0 
            cfg_iter = 1 + (cfg - 1) * step / num_iter
            temperature_iter = 0.85 + (temperature - 0.85) * step / num_iter

            index = torch.tensor([latent_core[step]]).unsqueeze(1).unsqueeze(-1).repeat(B, L, 1).reshape(B * L, -1).to(z.device).half()
            z = self.diffloss.sample(z, temperature_iter, cfg_iter, index)     # torch.Size([512, 16])
            z, _ = z.chunk(2, dim=0)  # Remove null class samples.  torch.Size([256, 16])
            
            if step < num_iter-1:
                z = z.reshape(bsz, L, -1).transpose_(1, 2).reshape(bsz, -1, 16, 16)
                
                if step > 0:
                    imgs_resize = F.interpolate(z, size=(latent_core[step+1], latent_core[step+1]), mode='area')
                    z = F.interpolate(imgs_resize, size=(16, 16), mode='bicubic')
                z = z.reshape(bsz, -1, L).transpose_(1, 2).reshape(bsz*L, -1)
                
            
            sampled_token = z.reshape(bsz, L, -1)
            mask_next = mask_by_order(mask_len[0], orders, bsz, self.seq_len)
            mask_to_pred = torch.logical_not(mask_next)
            mask = mask_next
            sampled_token_latent = sampled_token[mask_to_pred.nonzero(as_tuple=True)]
            
            cur_tokens[mask_to_pred.nonzero(as_tuple=True)] = sampled_token_latent
            tokens = cur_tokens.clone()


        tokens = tokens.transpose_(1, 2).reshape(bsz, -1, 16, 16)
        return tokens
    

    def sample_tokens_nomask(self, bsz, num_iter=64, cfg=1.0, cfg_schedule="linear", labels=None, temperature=1.0, progress=False):
        latent_core = [0,2,3,4,5,6,7,8,9,10]
        # latent_core = [0,1,2,3,4,5,6,7,8,9]
        num_iter = len(latent_core)

        # init and sample generation orders
        tokens = torch.zeros(bsz, self.seq_len, self.token_embed_dim).cuda().half()
        indices = list(range(num_iter))
        if progress:
            indices = tqdm(indices)
        
        for step in indices:
            if labels is not None:
                class_embedding = self.class_emb(labels)        # torch.Size([256, 768])
            else:
                class_embedding = self.fake_latent.repeat(bsz, 1)
            
            tokens = torch.cat([tokens, tokens], dim=0)
            class_embedding = torch.cat([class_embedding, self.fake_latent.repeat(bsz, 1)], dim=0)

            # mae encoder
            x = self.forward_mae_encoder(tokens, class_embedding)
            z = self.forward_mae_decoder(x)    # torch.Size([512, 256, 768]); the decoder emits a high-dimensional (768) condition for every token in one pass, which is then fed to the diffusion head as the per-token condition for sampling

            B, L, C = z.shape
            z = z.reshape(B * L, -1)


            cfg_iter = 1 + (cfg - 1) * step / num_iter
            temperature_iter = 0.8 + (1 - np.cos(math.pi / 2. * (step + 1) / num_iter)) * (1-0.8)

            index = torch.tensor([latent_core[step]]).unsqueeze(1).unsqueeze(-1).repeat(B, L, 1).reshape(B * L, -1).to(z.device).half()
            sampled_token_latent = self.diffloss.sample(z, temperature_iter, cfg_iter, index=index)     # torch.Size([512, 16])
            sampled_token_latent, _ = sampled_token_latent.chunk(2, dim=0)  # Remove null class samples.  torch.Size([256, 16])

            sampled_token_latent = sampled_token_latent.reshape(bsz, L, -1).transpose_(1, 2).reshape(bsz, -1, 16, 16)
            if step < num_iter-1:
                if step > -1:
                    sampled_token_latent = F.interpolate(sampled_token_latent, size=(latent_core[step+1], latent_core[step+1]), mode='area')
                    sampled_token_latent = F.interpolate(sampled_token_latent, size=(16, 16), mode='bicubic')
                sampled_token_latent = sampled_token_latent.view(bsz, 16, -1).transpose(1, 2)

            tokens = sampled_token_latent.clone()

        return tokens
        

    def denoise(self, lq_img):
        """Inference helper that denoises a batch of low-quality images in latent space."""
        bsz = lq_img.shape[0]
        lq_latents = self.patchify(lq_img)  # [bsz, seq_len, token_embed_dim]
        
        # 1. use a fixed frequency index suitable for full-resolution content
        x_index = torch.ones(bsz, device=lq_img.device).mul(15).half()
        
        # 2. project the LQ latents up to decoder_embed_dim
        # first projection: token_embed_dim -> encoder_embed_dim
        lq_proj = self.z_proj(lq_latents)
        # layer norm for numerical stability
        lq_proj = self.z_proj_ln(lq_proj)
        
        # second projection: encoder_embed_dim -> decoder_embed_dim
        z = self.decoder_embed(lq_proj)
        
        # 3. add the positional embedding used by the diffusion head
        z = z + self.diffusion_pos_embed_learned
        
        # normalize z to keep it in a reasonable range
        z = self.decoder_norm(z)
        
        # apply an extra scaling factor for numerical stability
        scale_factor = 0.1
        z = z * scale_factor
        
        # sample the enhanced latents from z with the diffusion head;
        # a lower temperature makes the output more deterministic
        temperature = 0.8
        cfg = 1.0  # no classifier-free guidance
        
        # sequence length
        seq_len = z.shape[1]
        
        # process all positions in one batch
        tokens_flat = z.reshape(bsz * seq_len, -1)
        indices = x_index.unsqueeze(1).unsqueeze(-1).repeat(1, seq_len, 1).reshape(bsz * seq_len, -1)
        
        # generate the denoised tokens with the diffusion sampler
        enhanced_tokens = self.diffloss.sample(tokens_flat, temperature, cfg, indices)
        
        # debug: print the actual tensor shape
        print(f"diffloss.sample output shape: {enhanced_tokens.shape}")
        
        # reshape the output to the same format as sample_tokens_nomask
        h = w = int(math.sqrt(seq_len))  # typically 16x16
        
        enhanced_tokens = enhanced_tokens.reshape(bsz, seq_len, -1)
        
        # debug: print the reshaped tensor shape
        print(f"shape after reshape: {enhanced_tokens.shape}")
        
        # transpose and reshape into the spatial layout expected by the VAE decoder [B, C, H, W]
        C = enhanced_tokens.shape[2]  # number of channels
        enhanced_tokens = enhanced_tokens.permute(0, 2, 1)  # [bsz, C, seq_len]
        enhanced_tokens = enhanced_tokens.reshape(bsz, C, h, w)  # [bsz, C, h, w]
        
        # debug: print the final output shape
        print(f"final output shape: {enhanced_tokens.shape}")
        
        return enhanced_tokens


def far_base(**kwargs):
    model = FAR(
        encoder_embed_dim=768, encoder_depth=12, encoder_num_heads=12,
        decoder_embed_dim=768, decoder_depth=12, decoder_num_heads=12,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def far_large(**kwargs):
    model = FAR(
        encoder_embed_dim=1024, encoder_depth=16, encoder_num_heads=16,
        decoder_embed_dim=1024, decoder_depth=16, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def far_huge(**kwargs):
    model = FAR(
        encoder_embed_dim=1280, encoder_depth=20, encoder_num_heads=16,
        decoder_embed_dim=1280, decoder_depth=20, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model
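

# Minimal smoke-test sketch (added for illustration, not part of training): it
# exercises the denoising forward pass on random latents. It assumes a CUDA
# device and the models.diffloss dependency are available, and uses mask=False
# because forward() passes mask=None to the encoder/decoder.
if __name__ == "__main__":
    model = far_base(img_size=256, vae_stride=16, patch_size=1, mask=False,
                     diffloss_d=3, diffloss_w=1024).cuda().eval()
    lq = torch.randn(2, 16, 16, 16).cuda()  # stand-in for VAE latents of LQ images
    gt = torch.randn(2, 16, 16, 16).cuda()  # stand-in for VAE latents of GT images
    with torch.no_grad(), torch.cuda.amp.autocast():
        print("diffusion loss:", model(lq, gt).item())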


engine_far.py

import math
import sys
from typing import Iterable
import random

import torch

import util.misc as misc
import util.lr_sched as lr_sched
from models.vae import DiagonalGaussianDistribution
import torch_fidelity
import shutil
import cv2
import numpy as np
import matplotlib.pyplot as plt
import os
import copy
import time
import torchvision.transforms as transforms


from torchvision.utils import make_grid
from typing import Optional
from PIL import Image
def save_image(images: torch.Tensor, nrow: int = 8, show: bool = True, path: Optional[str] = None, format: Optional[str] = None, to_grayscale: bool = False, **kwargs):
    images = images * 0.5 + 0.5
    grid = make_grid(images, nrow=nrow, **kwargs)  # (channels, height, width)
    #  (height, width, channels)
    grid = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    im = Image.fromarray(grid)
    if to_grayscale:
        im = im.convert(mode="L")
    if path is not None:
        im.save(path, format=format)
    if show:
        im.show()
    return grid


def update_ema(target_params, source_params, rate=0.99):
    """
    Update target parameters to be closer to those of source parameters using
    an exponential moving average.

    :param target_params: the target parameter sequence.
    :param source_params: the source parameter sequence.
    :param rate: the EMA rate (closer to 1 means slower).
    """
    for targ, src in zip(target_params, source_params):
        targ.detach().mul_(rate).add_(src, alpha=1 - rate)
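
# Hedged usage sketch (names below are placeholders, not the training loop in
# this file): the EMA copy is kept as a plain list of tensors mirroring
# model.parameters() and nudged toward the live weights after every step.
#   ema_params = copy.deepcopy(list(model.parameters()))
#   for batch in loader:
#       loss = model(batch); loss.backward(); optimizer.step()
#       update_ema(ema_params, list(model.parameters()), rate=0.9999)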


def train_one_epoch(model, vae,
                    model_params, ema_params,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler,
                    log_writer=None,
                    args=None):
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    optimizer.zero_grad()

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))
    
    # track additional statistics
    all_losses = []
    freq_levels_used = []
    param_norms_before = [p.norm().item() for p in model.parameters() if p.requires_grad]

    # count of samples processed in this epoch
    num_samples_processed = 0
    
    for data_iter_step, data in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        # we use a per iteration (instead of per epoch) lr scheduler
        lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)

        # data handling: the finetuning task and the original task use different inputs
        if args.finetune:
            # finetuning task: LQ/GT image pairs
            lq_samples, gt_samples = data
            lq_samples = lq_samples.to(device, non_blocking=True)
            gt_samples = gt_samples.to(device, non_blocking=True)
            
            with torch.no_grad():
                # encode both the LQ and GT images into latent space
                lq_posterior = vae.encode(lq_samples)
                lq_x = lq_posterior.sample().mul_(0.2325)
                
                gt_posterior = vae.encode(gt_samples)
                gt_x = gt_posterior.sample().mul_(0.2325)
                
                # create dummy labels (all zeros) since no class conditioning is used
                fake_labels = torch.zeros(lq_samples.size(0), dtype=torch.long, device=device)
            
            # forward pass: feed the LQ and GT latents into the model
            with torch.cuda.amp.autocast():
                # note: model.forward is expected to accept (lq_latents, gt_latents)
                loss = model(lq_x, gt_x, loss_weight=args.loss_weight)
                
                # record the frequency level used in this step (read back from the model,
                # which stores it as last_used_freq_level in forward)
                if hasattr(model, 'module'):
                    freq_level = model.module.last_used_freq_level if hasattr(model.module, 'last_used_freq_level') else random.randint(14, 15)
                else:
                    freq_level = model.last_used_freq_level if hasattr(model, 'last_used_freq_level') else random.randint(14, 15)
                freq_levels_used.append(freq_level)

        else:
            # original task: single images with class labels
            samples, labels = data
            samples = samples.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            with torch.no_grad():
                if args.use_cached:
                    moments = samples
                    posterior = DiagonalGaussianDistribution(moments)
                else:
                    posterior = vae.encode(samples)

                # normalize the std of latent to be 1. Change it if you use a different tokenizer
                x = posterior.sample().mul_(0.2325)

            # forward
            with torch.cuda.amp.autocast():
                loss = model(x, labels, loss_weight=args.loss_weight)
                
        loss_value = loss.item()
        all_losses.append(loss_value)
        num_samples_processed += lq_samples.size(0) if args.finetune else samples.size(0)

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        loss_scaler(loss, optimizer, clip_grad=args.grad_clip, parameters=model.parameters(), update_grad=True)
        optimizer.zero_grad()

        torch.cuda.synchronize()

        update_ema(ema_params, model_params, rate=args.ema_rate)

        metric_logger.update(loss=loss_value)

        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None:
            """ We use epoch_1000x as the x-axis in tensorboard.
            This calibrates different curves when batch size changes.
            """
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('lr', lr, epoch_1000x)
    
    # statistics on how much the parameters changed
    param_norms_after = [p.norm().item() for p in model.parameters() if p.requires_grad]
    param_changes = [abs(after - before) / (before + 1e-10) for before, after in zip(param_norms_before, param_norms_after)]
    
    # print detailed statistics at the end of the epoch
    if all_losses:
        loss_mean = sum(all_losses) / len(all_losses)
        loss_min = min(all_losses)
        loss_max = max(all_losses)
        loss_std = np.std(all_losses) if len(all_losses) > 1 else 0
        
        # summarize the relative parameter changes
        param_change_mean = sum(param_changes) / len(param_changes) if param_changes else 0
        param_change_max = max(param_changes) if param_changes else 0
        
        # summarize the frequency levels used
        freq_level_counts = {}
        for freq in freq_levels_used:
            freq_level_counts[freq] = freq_level_counts.get(freq, 0) + 1
            
        # print detailed information
        print("\n" + "="*80)
        print(f"Finetuning statistics - Epoch {epoch}")
        print(f"Samples processed: {num_samples_processed}")
        print(f"Loss stats: mean={loss_mean:.6f}, min={loss_min:.6f}, max={loss_max:.6f}, std={loss_std:.6f}")
        
        if freq_levels_used:
            print(f"Frequency level usage: {freq_level_counts}")
            
        print(f"Parameter change: mean relative change={param_change_mean:.6f}, max relative change={param_change_max:.6f}")
        
        # warn if the loss looks unstable
        if loss_std > loss_mean * 0.5:  # a std above half the mean may indicate a problem
            print("Warning: loss fluctuates a lot; consider adjusting the learning rate or checking model stability")
        
        if param_change_mean < 1e-6:
            print("Warning: parameters barely changed; the learning rate may be too low or gradients may be vanishing")
        elif param_change_max > 0.5:
            print("Warning: some parameters changed a lot; the learning rate may be too high or gradients may be exploding")
            
        print("="*80 + "\n")
        
        # log these statistics to tensorboard if available
        if log_writer is not None:
            log_writer.add_scalar('epoch_loss_mean', loss_mean, epoch)
            log_writer.add_scalar('epoch_loss_min', loss_min, epoch)
            log_writer.add_scalar('epoch_loss_max', loss_max, epoch)
            log_writer.add_scalar('epoch_loss_std', loss_std, epoch)
            log_writer.add_scalar('epoch_param_change_mean', param_change_mean, epoch)
            log_writer.add_scalar('epoch_param_change_max', param_change_max, epoch)
        
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


def evaluate(model_without_ddp, vae, ema_params, args, epoch, batch_size=16, log_writer=None, cfg=1.0,
             use_ema=True):
    model_without_ddp.eval()
    num_steps = args.num_images // (batch_size * misc.get_world_size()) + 1
    save_folder = os.path.join(args.output_dir, "ariter{}-diffsteps{}-temp{}-{}cfg{}-image{}".format(args.num_iter,
                                                                                                     args.num_sampling_steps,
                                                                                                     args.temperature,
                                                                                                     args.cfg_schedule,
                                                                                                     cfg,
                                                                                                     args.num_images))
    if use_ema:
        save_folder = save_folder + "_ema"
    if args.evaluate:
        save_folder = save_folder + "_evaluate"
    print("Save to:", save_folder)
    if misc.get_rank() == 0:
        if not os.path.exists(save_folder):
            os.makedirs(save_folder)

    # switch to ema params
    if use_ema:
        model_state_dict = copy.deepcopy(model_without_ddp.state_dict())
        ema_state_dict = copy.deepcopy(model_without_ddp.state_dict())
        for i, (name, _value) in enumerate(model_without_ddp.named_parameters()):
            assert name in ema_state_dict
            ema_state_dict[name] = ema_params[i]
        print("Switch to ema")
        model_without_ddp.load_state_dict(ema_state_dict)

    class_num = args.class_num
    assert args.num_images % class_num == 0  # number of images per class must be the same
    class_label_gen_world = np.arange(0, class_num).repeat(args.num_images // class_num)
    class_label_gen_world = np.hstack([class_label_gen_world, np.zeros(50000)])
    world_size = misc.get_world_size()
    local_rank = misc.get_rank()
    used_time = 0
    gen_img_cnt = 0


    for i in range(num_steps):
        print("Generation step {}/{}".format(i, num_steps))

        labels_gen = class_label_gen_world[world_size * batch_size * i + local_rank * batch_size:
                                                world_size * batch_size * i + (local_rank + 1) * batch_size]
        labels_gen = torch.Tensor(labels_gen).long().cuda()

        torch.cuda.synchronize()
        start_time = time.time()

        # generation
        with torch.no_grad():
            with torch.cuda.amp.autocast():
                if args.mask:
                    sampled_tokens = model_without_ddp.sample_tokens_mask(bsz=batch_size, num_iter=args.num_iter, cfg=cfg,
                                                                    cfg_schedule=args.cfg_schedule, labels=labels_gen,
                                                                    temperature=args.temperature)
                else:
                    sampled_tokens = model_without_ddp.sample_tokens_nomask(bsz=batch_size, num_iter=args.num_iter, cfg=cfg,
                                                                    cfg_schedule=args.cfg_schedule, labels=labels_gen,
                                                                    temperature=args.temperature)
                sampled_images = vae.decode(sampled_tokens / 0.2325)

        # measure speed after the first generation batch
        if i >= 1:
            torch.cuda.synchronize()
            used_time += time.time() - start_time
            gen_img_cnt += batch_size
            print("Generating {} images takes {:.5f} seconds, {:.5f} sec per image".format(gen_img_cnt, used_time, used_time / gen_img_cnt))

        torch.distributed.barrier()
        sampled_images = sampled_images.detach().cpu()

        save_image(sampled_images, nrow=8, show=False, path=os.path.join(args.output_dir, f"epoch{epoch}.png"), to_grayscale=False)
        return

        sampled_images = (sampled_images + 1) / 2

        # distributed save
        for b_id in range(sampled_images.size(0)):
            img_id = i * sampled_images.size(0) * world_size + local_rank * sampled_images.size(0) + b_id
            if img_id >= args.num_images:
                break
            gen_img = np.round(np.clip(sampled_images[b_id].numpy().transpose([1, 2, 0]) * 255, 0, 255))
            gen_img = gen_img.astype(np.uint8)[:, :, ::-1]
            gen_img = gen_img[:, :, ::-1]
            plt.imsave(os.path.join(save_folder, '{}.png'.format(str(img_id).zfill(5))), gen_img)


    torch.distributed.barrier()
    time.sleep(10)

    # back to no ema
    if use_ema:
        print("Switch back from ema")
        model_without_ddp.load_state_dict(model_state_dict)

    # compute FID and IS
    if log_writer is not None:
        if args.img_size == 256:
            input2 = None
            fid_statistics_file = '/mnt/workspace/workgroup/yuhu/code/FAR/fid_stats/adm_in256_stats.npz'
        else:
            raise NotImplementedError
        metrics_dict = torch_fidelity.calculate_metrics(
            input1=save_folder,
            input2=input2,
            fid_statistics_file=fid_statistics_file,
            cuda=True,
            isc=True,
            fid=True,
            kid=False,
            prc=False,
            verbose=False,
        )
        fid = metrics_dict['frechet_inception_distance']
        inception_score = metrics_dict['inception_score_mean']
        print(fid)
        print(inception_score)
        postfix = ""
        if use_ema:
           postfix = postfix + "_ema"
        if not cfg == 1.0:
           postfix = postfix + "_cfg{}".format(cfg)
        log_writer.add_scalar('fid{}'.format(postfix), fid, epoch)
        log_writer.add_scalar('is{}'.format(postfix), inception_score, epoch)
        print("FID: {:.4f}, Inception Score: {:.4f}".format(fid, inception_score))
        # remove temporal saving folder
        shutil.rmtree(save_folder)

    torch.distributed.barrier()
    time.sleep(10)


def cache_latents(vae,
                  data_loader: Iterable,
                  device: torch.device,
                  args=None):
    metric_logger = misc.MetricLogger(delimiter="  ")
    header = 'Caching: '
    print_freq = 20

    for data_iter_step, (samples, _, paths) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):

        samples = samples.to(device, non_blocking=True)

        with torch.no_grad():
            posterior = vae.encode(samples)
            moments = posterior.parameters
            posterior_flip = vae.encode(samples.flip(dims=[3]))
            moments_flip = posterior_flip.parameters

        for i, path in enumerate(paths):
            save_path = os.path.join(args.cached_path, path + '.npz')
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            np.savez(save_path, moments=moments[i].cpu().numpy(), moments_flip=moments_flip[i].cpu().numpy())

        if misc.is_dist_avail_and_initialized():
            torch.cuda.synchronize()

    return


def evaluate_denoising(model, vae, args, device):
    """评估去噪模型在测试集上的性能
    
    Args:
        model: 微调后的FAR模型
        vae: VAE模型用于图像处理
        args: 参数
        device: 设备
    
    Returns:
        包含评估指标的字典
    """
    print("开始评估去噪模型性能...")
    
    # 加载测试数据列表
    with open(args.lq_list_path, 'r') as f:
        lq_paths = [line.strip() for line in f.readlines()]
    
    with open(args.gt_list_path, 'r') as f:
        gt_paths = [line.strip() for line in f.readlines()]
    
    assert len(lq_paths) == len(gt_paths), "LQ和GT图像数量不匹配"
    print(f"共加载 {len(lq_paths)} 对测试图像")
    
    # image transforms
    transform = transforms.Compose([
        transforms.Resize((args.img_size, args.img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    
    # set the model to evaluation mode
    model.eval()
    
    # metric functions for PSNR and SSIM
    from skimage.metrics import peak_signal_noise_ratio as psnr
    from skimage.metrics import structural_similarity as ssim
    
    # metric accumulators
    psnr_values = []
    ssim_values = []
    
    # output directory for sample results
    if args.save_samples:
        os.makedirs(os.path.join(args.output_dir, "test_samples"), exist_ok=True)
    
    # process each image pair
    for idx, (lq_path, gt_path) in enumerate(zip(lq_paths, gt_paths)):
        if idx % 100 == 0:
            print(f"Processing image {idx}/{len(lq_paths)}...")
        
        try:
            # load the LQ and GT images
            lq_img = Image.open(lq_path).convert('RGB')
            gt_img = Image.open(gt_path).convert('RGB')
            
            # apply the transforms
            lq_tensor = transform(lq_img).unsqueeze(0).to(device)
            gt_tensor = transform(gt_img).unsqueeze(0).to(device)
            
            with torch.no_grad():
                # encode to the latent space with the VAE
                lq_posterior = vae.encode(lq_tensor)
                lq_latents = lq_posterior.sample().mul_(0.2325)
                
                # denoise with the model's denoise method
                denoised_latent = model.denoise(lq_latents)
                
                # decode the denoised latents back to an image
                denoised_img = vae.decode(denoised_latent / 0.2325)
                
                # convert images to a format suitable for the metrics:
                # from [-1, 1] to [0, 1]
                lq_np = (lq_tensor + 1) / 2.0
                gt_np = (gt_tensor + 1) / 2.0
                denoised_np = (denoised_img + 1) / 2.0
                
                # move to CPU and convert to numpy arrays
                lq_np = lq_np.squeeze().permute(1, 2, 0).cpu().numpy().clip(0, 1)
                gt_np = gt_np.squeeze().permute(1, 2, 0).cpu().numpy().clip(0, 1)
                denoised_np = denoised_np.squeeze().permute(1, 2, 0).cpu().numpy().clip(0, 1)
                
                # PSNR
                current_psnr = psnr(gt_np, denoised_np, data_range=1.0)
                psnr_values.append(current_psnr)
                
                # SSIM (multi-channel; channel_axis replaces the deprecated multichannel flag)
                current_ssim = ssim(gt_np, denoised_np, data_range=1.0, channel_axis=2)
                ssim_values.append(current_ssim)
                
                # save a few sample images for visual inspection
                if args.save_samples and idx < 10:  # only the first 10 samples
                    save_path = os.path.join(args.output_dir, "test_samples", f"sample_{idx}.png")
                    # comparison grid [LQ, Denoised, GT]
                    comparison = torch.cat([
                        lq_tensor,
                        denoised_img,
                        gt_tensor
                    ], dim=0)
                    
                    save_image(comparison, nrow=3, show=False, path=save_path)
                    print(f"Sample {idx} saved to: {save_path}")
        
        except Exception as e:
            print(f"处理图像 {lq_path} 时出错: {e}")
            continue
    
    # 计算平均指标
    avg_psnr = np.mean(psnr_values)
    avg_ssim = np.mean(ssim_values)
    
    # 打印结果
    print("="*50)
    print("去噪模型评估结果:")
    print(f"平均 PSNR: {avg_psnr:.4f} dB")
    print(f"平均 SSIM: {avg_ssim:.4f}")
    print("="*50)
    
    # 保存结果到文件
    result_path = os.path.join(args.output_dir, "denoising_results.txt")
    with open(result_path, 'w') as f:
        f.write(f"测试图像数量: {len(psnr_values)}\n")
        f.write(f"平均 PSNR: {avg_psnr:.4f} dB\n")
        f.write(f"平均 SSIM: {avg_ssim:.4f}\n")
    
    print(f"详细结果已保存至: {result_path}")
    
    return {
        "PSNR": avg_psnr,
        "SSIM": avg_ssim,
        "样本数": len(psnr_values)
    }

main_far.py

import argparse
import datetime
import numpy as np
import os
import time
from pathlib import Path
import random

import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import Dataset
from PIL import Image

from util.crop import center_crop_arr
import util.misc as misc
from util.misc import NativeScalerWithGradNormCount as NativeScaler
from util.loader import CachedFolder

from models.vae import AutoencoderKL
from models import far
from engine_far import train_one_epoch, evaluate, evaluate_denoising
import copy


# Inference helper for a single image
def inference_single_image(model, vae, input_path, output_path, device, args):
    """Denoise a single image.
    
    Args:
        model: the fine-tuned FAR model
        vae: the VAE used for encoding/decoding images
        input_path: path to the input image
        output_path: path to save the output image
        device: device to run on
        args: runtime arguments
    """
    print(f"Processing image: {input_path}")
    
    # make sure the output directory exists
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    
    # set the model to evaluation mode
    model.eval()
    
    # load and preprocess the image, matching the preprocessing used during finetuning
    transform = transforms.Compose([
        transforms.Resize((args.img_size, args.img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    
    try:
        img = Image.open(input_path).convert('RGB')
        img_tensor = transform(img).unsqueeze(0).to(device)
        
        print(f"Input image shape: {img_tensor.shape}")
        
        with torch.no_grad():
            # 1. encode with the VAE to get the latent representation, as in finetuning
            posterior = vae.encode(img_tensor)
            latent = posterior.sample().mul_(0.2325)  # same scaling factor as finetuning
            
            print(f"Shape after VAE encoding: {latent.shape}")
            
            # 2. apply the model's denoise method, which follows the finetuning pipeline
            print("Applying denoising...")
            denoised_latent = model.denoise(latent)
            
            print(f"Shape after denoise: {denoised_latent.shape}")
            
            # 3. decode the denoised latents with the VAE
            print("Decoding with the VAE...")
            denoised_img = vae.decode(denoised_latent / 0.2325)
            
            print(f"Shape after VAE decoding: {denoised_img.shape}")
            
            # 4. save the output image
            print("Saving the result image...")
            output_img = (denoised_img + 1) / 2.0
            output_img = output_img.clamp(0, 1)
            
            from torchvision.utils import save_image
            save_image(output_img, output_path)
            print(f"Denoised image saved to: {output_path}")
            
            # also save the input image for comparison
            input_save_path = os.path.join(os.path.dirname(output_path), "input.png")
            input_img = (img_tensor + 1) / 2.0
            save_image(input_img, input_save_path)
            print(f"Input image saved to: {input_save_path}")
            
            # additionally save a side-by-side comparison image
            comparison = torch.cat([input_img, output_img], dim=0)
            comparison_path = os.path.join(os.path.dirname(output_path), "comparison.png")
            save_image(comparison, comparison_path, nrow=2)
            print(f"Comparison image saved to: {comparison_path}")
            
    except Exception as e:
        print(f"Error while processing the image: {e}")
        import traceback
        traceback.print_exc()

# Custom dataset for loading LQ/GT image pairs
class DenoiseDataset(Dataset):
    def __init__(self, lq_list_path, gt_list_path, transform=None, img_size=256):
        self.transform = transform
        self.img_size = img_size
        
        # read the LQ and GT image paths
        with open(lq_list_path, 'r') as f:
            self.lq_paths = [line.strip() for line in f.readlines()]
        
        with open(gt_list_path, 'r') as f:
            self.gt_paths = [line.strip() for line in f.readlines()]
        
        assert len(self.lq_paths) == len(self.gt_paths), "Number of LQ and GT images does not match"
        
    def __len__(self):
        return len(self.lq_paths)
    
    def __getitem__(self, idx):
        # load the LQ and GT images
        lq_img = Image.open(self.lq_paths[idx]).convert('RGB')
        gt_img = Image.open(self.gt_paths[idx]).convert('RGB')
        
        # apply the transforms
        if self.transform:
            # build a transform without RandomHorizontalFlip so the flip can be
            # applied manually and kept consistent between the LQ and GT images
            transform_without_flip = transforms.Compose([
                transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, self.img_size)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
            ])
            
            # decide the flip once and apply the same decision to both images
            if random.random() < 0.5:
                lq_img = transforms.functional.hflip(lq_img)
                gt_img = transforms.functional.hflip(gt_img)
            
            # apply the remaining transforms
            lq_img = transform_without_flip(lq_img)
            gt_img = transform_without_flip(gt_img)
        else:
            # no transform given: just convert to tensors
            lq_img = transforms.ToTensor()(lq_img)
            gt_img = transforms.ToTensor()(gt_img)
            
        return lq_img, gt_img
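
# Hedged usage sketch (list files and batch size below are placeholders, not
# the actual configuration): in finetuning mode the dataset yields (lq, gt)
# tensor pairs in [-1, 1], which a DataLoader batches for train_one_epoch.
#   dataset = DenoiseDataset('lists/lq_train.txt', 'lists/gt_train.txt',
#                            transform=transform_train, img_size=256)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True,
#                                        num_workers=8, pin_memory=True, drop_last=True)
#   lq_batch, gt_batch = next(iter(loader))   # each [16, 3, 256, 256]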
    
def get_args_parser():
    parser = argparse.ArgumentParser('FAR training with Diffusion Loss', add_help=False)
    parser.add_argument('--batch_size', default=16, type=int,
                        help='Batch size per GPU (effective batch size is batch_size * # gpus)')
    parser.add_argument('--epochs', default=400, type=int)

    # Model parameters
    parser.add_argument('--model', default='far_large', type=str, metavar='MODEL',
                        help='Name of model to train')

    # VAE parameters
    parser.add_argument('--img_size', default=256, type=int,
                        help='images input size')
    parser.add_argument('--vae_path', default="pretrained/vae_mar/kl16.ckpt", type=str,
                        help='path to the pretrained VAE checkpoint')
    parser.add_argument('--vae_embed_dim', default=16, type=int,
                        help='vae output embedding dimension')
    parser.add_argument('--vae_stride', default=16, type=int,
                        help='tokenizer stride, default use KL16')
    parser.add_argument('--patch_size', default=1, type=int,
                        help='number of tokens to group as a patch.')

    # Generation parameters
    parser.add_argument('--num_iter', default=64, type=int,
                        help='number of autoregressive iterations to generate an image')
    parser.add_argument('--num_images', default=50000, type=int,
                        help='number of images to generate')
    parser.add_argument('--cfg', default=1.0, type=float, help="classifier-free guidance")
    parser.add_argument('--cfg_schedule', default="linear", type=str)
    parser.add_argument('--label_drop_prob', default=0.1, type=float)
    parser.add_argument('--eval_freq', type=int, default=40, help='evaluation frequency')
    parser.add_argument('--save_last_freq', type=int, default=10, help='save last frequency')
    parser.add_argument('--online_eval', action='store_true')
    parser.add_argument('--evaluate', action='store_true')
    parser.add_argument('--mask', action='store_true')
    parser.add_argument('--eval_bsz', type=int, default=64, help='generation batch size')

    # Optimizer parameters
    parser.add_argument('--weight_decay', type=float, default=0.02,
                        help='weight decay (default: 0.02)')

    parser.add_argument('--lr', type=float, default=None, metavar='LR',
                        help='learning rate (absolute lr)')
    parser.add_argument('--blr', type=float, default=1e-4, metavar='LR',
                        help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
    parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0')
    parser.add_argument('--lr_schedule', type=str, default='constant',
                        help='learning rate schedule')
    parser.add_argument('--warmup_epochs', type=int, default=100, metavar='N',
                        help='epochs to warmup LR')
    parser.add_argument('--ema_rate', default=0.9999, type=float)

    # FAR params
    parser.add_argument('--mask_ratio_min', type=float, default=0.7,
                        help='Minimum mask ratio')
    parser.add_argument('--grad_clip', type=float, default=3.0,
                        help='Gradient clip')
    parser.add_argument('--attn_dropout', type=float, default=0.1,
                        help='attention dropout')
    parser.add_argument('--proj_dropout', type=float, default=0.1,
                        help='projection dropout')
    parser.add_argument('--buffer_size', type=int, default=64)
    parser.add_argument('--loss_weight', action='store_true', help='adopt uneven loss weight.')

    # Diffusion Loss params
    parser.add_argument('--diffloss_d', type=int, default=12)
    parser.add_argument('--diffloss_w', type=int, default=1536)
    parser.add_argument('--num_sampling_steps', type=str, default="100")
    parser.add_argument('--diffusion_batch_mul', type=int, default=1)
    parser.add_argument('--temperature', default=1.0, type=float, help='diffusion loss sampling temperature')

    # Dataset parameters
    parser.add_argument('--data_path', default='./data/imagenet', type=str,
                        help='dataset path')
    parser.add_argument('--class_num', default=1000, type=int)

    # parameters for image denoising finetuning
    parser.add_argument('--lq_list_path', default='', type=str,
                        help='Path to the list of low quality images')
    parser.add_argument('--gt_list_path', default='', type=str,
                        help='Path to the list of ground truth images')
    parser.add_argument('--finetune', action='store_true', 
                        help='Enable finetuning for image denoising')
    parser.add_argument('--pretrained_path', default='./pretrained_models/far/far_large/checkpoint-last.pth', 
                        type=str, help='Path to pretrained model for finetuning')
                        
    # parameters for single-image inference
    parser.add_argument('--inference', action='store_true',
                        help='Run inference on a single image')
    parser.add_argument('--input_image', default='./data/test.png', type=str,
                        help='Path to the input image for inference')
    parser.add_argument('--output_image', default='./output_dir_denoise/result.png', type=str,
                        help='Path to save the output image')
    parser.add_argument('--use_ema', action='store_true',
                        help='Use EMA parameters for inference')

    parser.add_argument('--output_dir', default='./output_dir',
                        help='path where to save, empty for no saving')
    parser.add_argument('--log_dir', default='./output_dir',
                        help='path where to tensorboard log')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=1, type=int)
    parser.add_argument('--resume', default='',
                        help='resume from checkpoint')

    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--num_workers', default=8, type=int)
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')

    # caching latents
    parser.add_argument('--use_cached', action='store_true', dest='use_cached',
                        help='Use cached latents')
    parser.set_defaults(use_cached=False)
    parser.add_argument('--cached_path', default='', help='path to cached latents')

    # parameters for evaluating the denoising model
    parser.add_argument('--evaluate_denoising', action='store_true',
                        help='Evaluate denoising performance on test set')
    parser.add_argument('--save_samples', action='store_true',
                        help='Save sample images during evaluation')

    return parser
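
# Example invocation for denoising finetuning (illustrative; paths and
# hyperparameters are placeholders, and launching via torchrun is an assumption):
#   torchrun --nproc_per_node=8 main_far.py --model far_large --finetune \
#       --lq_list_path lists/lq_train.txt --gt_list_path lists/gt_train.txt \
#       --pretrained_path pretrained_models/far/far_large/checkpoint-last.pth \
#       --batch_size 16 --blr 1e-4 --output_dir ./output_dir_denoise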

# main() with data loading adapted for the denoising finetuning task
def main(args):
    misc.init_distributed_mode(args)

    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(', ', ',\n'))

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    num_tasks = misc.get_world_size()
    global_rank = misc.get_rank()


    if global_rank == 0 and args.log_dir is not None and not args.inference:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = SummaryWriter(log_dir=args.log_dir)
    else:
        log_writer = None


    ### ImageNet VAE
    vae = AutoencoderKL(embed_dim=args.vae_embed_dim, ch_mult=(1, 1, 2, 2, 4), ckpt_path=args.vae_path).cuda().eval()
    for param in vae.parameters():
        param.requires_grad = False
        
    model = far.__dict__[args.model](
        img_size=args.img_size,
        vae_stride=args.vae_stride,
        patch_size=args.patch_size,
        vae_embed_dim=args.vae_embed_dim,
        mask=args.mask,
        mask_ratio_min=args.mask_ratio_min,
        label_drop_prob=args.label_drop_prob,
        class_num=args.class_num,
        attn_dropout=args.attn_dropout,
        proj_dropout=args.proj_dropout,
        buffer_size=args.buffer_size,
        diffloss_d=args.diffloss_d,
        diffloss_w=args.diffloss_w,
        num_sampling_steps=args.num_sampling_steps,
        diffusion_batch_mul=args.diffusion_batch_mul,
    )

    
    print("Model = %s" % str(model))
    # following timm: set wd as 0 for bias and norm layers
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("Number of trainable parameters: {}M".format(n_params / 1e6))

    model.to(device)
    model_without_ddp = model

    eff_batch_size = args.batch_size * misc.get_world_size()

    if args.lr is None:  # only base_lr is specified
        args.lr = args.blr * eff_batch_size / 256

    print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
    print("actual lr: %.2e" % args.lr)
    print("effective batch size: %d" % eff_batch_size)

    if args.distributed and not args.inference:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
        model_without_ddp = model.module

    # no weight decay on bias, norm layers, and diffloss MLP
    param_groups = misc.add_weight_decay(model_without_ddp, args.weight_decay)
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
    print(optimizer)
    loss_scaler = NativeScaler()

    
    ### resume training
    if args.finetune or args.inference:
        # load parameters from a pretrained or finetuned checkpoint
        checkpoint_path = args.pretrained_path
        print(f"Loading model from {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        model_params = list(model_without_ddp.parameters())
        ema_state_dict = checkpoint['model_ema']
        ema_params = [ema_state_dict[name].cuda() for name, _ in model_without_ddp.named_parameters()]
        print("Loaded model for inference/finetuning")
        
        # in inference mode, optionally switch to the EMA parameters
        if args.inference and args.use_ema:
            print("Using EMA parameters for inference")
            ema_state_dict = {name: ema_params[i] for i, (name, _) in enumerate(model_without_ddp.named_parameters())}
            model_without_ddp.load_state_dict(ema_state_dict)
            
    elif args.resume and os.path.exists(os.path.join(args.resume, "checkpoint-last.pth")):
        checkpoint = torch.load(os.path.join(args.resume, "checkpoint-last.pth"), map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        model_params = list(model_without_ddp.parameters())
        ema_state_dict = checkpoint['model_ema']
        ema_params = [ema_state_dict[name].cuda() for name, _ in model_without_ddp.named_parameters()]
        print("Resume checkpoint %s" % args.resume)

        if 'optimizer' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            args.start_epoch = checkpoint['epoch'] + 1
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])
            print("With optim & sched!")
        del checkpoint
    else:
        model_params = list(model_without_ddp.parameters())
        ema_params = copy.deepcopy(model_params)
        print("Training from scratch")
    
    # single-image inference
    if args.inference:
        inference_single_image(model_without_ddp, vae, args.input_image, args.output_image, device, args)
        return
        
    # evaluate FID and IS
    if args.evaluate:
        torch.cuda.empty_cache()
        evaluate(model_without_ddp, vae, ema_params, args, 0, batch_size=args.eval_bsz, log_writer=log_writer, cfg=args.cfg, use_ema=True)
        return

    # evaluate denoising performance
    if args.evaluate_denoising:
        print("Evaluating the denoising model on the test set...")
        # optionally evaluate with the EMA parameters
        if args.use_ema:
            print("Using EMA parameters for evaluation")
            ema_state_dict = {name: ema_params[i] for i, (name, _) in enumerate(model_without_ddp.named_parameters())}
            model_without_ddp.load_state_dict(ema_state_dict)

        # run the denoising evaluation
        evaluate_denoising(model_without_ddp, vae, args, device)
        return

    # augmentation following DiT and ADM
    transform_train = transforms.Compose([
        transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.img_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])

    # choose the dataset depending on whether we are finetuning for image denoising
    if args.finetune:
        # use the custom paired denoising dataset
        dataset_train = DenoiseDataset(args.lq_list_path, args.gt_list_path, transform=transform_train, img_size=args.img_size)
        print(f"Loaded denoising dataset with {len(dataset_train)} image pairs")
    elif args.use_cached:
        dataset_train = CachedFolder(args.cached_path)
    else:
        dataset_train = datasets.ImageFolder(os.path.join(args.data_path, 'train'), transform=transform_train)
    print(dataset_train)

    sampler_train = torch.utils.data.DistributedSampler(
        dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
    )
    print("Sampler_train = %s" % str(sampler_train))

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )
    
    # fixed samples for visualization, fetched only once
    fixed_vis_samples = None
    
    # training
    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)

        train_stats = train_one_epoch(
            model, vae,
            model_params, ema_params,
            data_loader_train,
            optimizer, device, epoch, loss_scaler,
            log_writer=log_writer,
            args=args
        )

        # every few epochs, generate sample images to visualize the finetuning progress
        if args.finetune and misc.is_main_process() and (epoch % 5 == 0 or epoch + 1 == args.epochs):
            # create the directory for visualization results
            vis_dir = os.path.join(args.output_dir, 'visualizations')
            os.makedirs(vis_dir, exist_ok=True)

            print(f"Generating sample images for epoch {epoch}...")

            # draw a few image pairs from the training data for visualization
            model_eval = model_without_ddp if args.distributed else model
            model_eval.eval()
            
            # fetch the fixed sample images
            with torch.no_grad():
                # on the first visualization, fetch and cache a fixed batch of samples
                if fixed_vis_samples is None:
                    try:
                        # grab one batch of data for visualization
                        lq_samples, gt_samples = next(iter(data_loader_train))
                        lq_samples = lq_samples[:4].to(device)  # keep only the first 4 samples
                        gt_samples = gt_samples[:4].to(device)
                        # cache them for reuse in later epochs
                        fixed_vis_samples = (lq_samples.clone(), gt_samples.clone())
                        print("Cached fixed samples for visualization")

                        # print the paths of the four images
                        print("\n===== Paths of the four visualization images =====")
                        for i in range(4):
                            # note: these are the first entries of the dataset lists,
                            # which may differ from the shuffled batch drawn above
                            if hasattr(data_loader_train.dataset, 'lq_paths') and hasattr(data_loader_train.dataset, 'gt_paths'):
                                lq_path = data_loader_train.dataset.lq_paths[i]
                                gt_path = data_loader_train.dataset.gt_paths[i]
                                print(f"Image {i+1}:")
                                print(f"  LQ path: {lq_path}")
                                print(f"  GT path: {gt_path}")
                            else:
                                print(f"Image {i+1}: path information unavailable")
                        print("=====================================\n")
                    except Exception as e:
                        print(f"Failed to fetch fixed samples: {e}")
                        # skip only this visualization; fixed_vis_samples stays None, so the
                        # block below is skipped and checkpoint saving still runs this epoch
                        model_eval.train()
                
                # use the cached fixed samples
                if fixed_vis_samples is not None:
                    try:
                        lq_samples, gt_samples = fixed_vis_samples

                        # encode to latent space
                        lq_posterior = vae.encode(lq_samples)
                        lq_latents = lq_posterior.sample().mul_(0.2325)
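                        # (0.2325 is the latent scaling factor used with this KL-16 VAE;
                        #  the decode call below divides it back out)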
                        
                        # generate the enhanced result with the finetuned model
                        enhanced_latents = model_eval.denoise(lq_latents)
                        enhanced_images = vae.decode(enhanced_latents / 0.2325)
                        
                        # build a comparison grid: original LQ images, enhanced images, GT images
                        from torchvision.utils import make_grid, save_image

                        # rescale images from [-1, 1] to [0, 1] for saving
                        lq_vis = (lq_samples + 1) / 2
                        enhanced_vis = (enhanced_images + 1) / 2
                        gt_vis = (gt_samples + 1) / 2

                        # concatenate along the batch dimension (rows: LQ / enhanced / GT)
                        comparison = torch.cat([lq_vis, enhanced_vis, gt_vis], dim=0)
                        grid = make_grid(comparison, nrow=4, padding=2, normalize=False)

                        # save the grid image
                        save_path = os.path.join(vis_dir, f'comparison_epoch_{epoch}.png')
                        save_image(grid, save_path)
                        print(f"Sample images saved to: {save_path}")
                    except Exception as e:
                        print(f"Failed to generate sample images: {e}")
            
            model_eval.train()

        # save checkpoint
        if epoch % args.save_last_freq == 0 or epoch + 1 == args.epochs:
            misc.save_model(args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
                            loss_scaler=loss_scaler, epoch=epoch, ema_params=ema_params)
        # online evaluation
        if args.online_eval and (epoch % args.eval_freq == 0 or epoch + 1 == args.epochs):
            torch.cuda.empty_cache()
            evaluate(model_without_ddp, vae, ema_params, args, epoch, batch_size=args.eval_bsz, log_writer=log_writer, cfg=1.0, use_ema=True)
            if not (args.cfg == 1.0 or args.cfg == 0.0):
                evaluate(model_without_ddp, vae, ema_params, args, epoch, batch_size=args.eval_bsz // 2, log_writer=log_writer, cfg=args.cfg, use_ema=True)
            torch.cuda.empty_cache()

        if misc.is_main_process():
            if log_writer is not None:
                log_writer.flush()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


if __name__ == '__main__':
    args = get_args_parser()
    args = args.parse_args()

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    args.log_dir = args.output_dir
    main(args)
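
main() above expects a DenoiseDataset class that takes (lq_list_path, gt_list_path, transform, img_size), returns (lq, gt) image pairs, and exposes lq_paths / gt_paths attributes. For reference, a minimal sketch consistent with that usage could look like the following; it is not necessarily the actual implementation, and the list-file format of one image path per line is an assumption:

import torch
from PIL import Image

class DenoiseDataset(torch.utils.data.Dataset):
    """Paired LQ/GT dataset driven by two list files (assumed: one image path per line)."""
    def __init__(self, lq_list_path, gt_list_path, transform=None, img_size=256):
        with open(lq_list_path) as f:
            self.lq_paths = [line.strip() for line in f if line.strip()]
        with open(gt_list_path) as f:
            self.gt_paths = [line.strip() for line in f if line.strip()]
        assert len(self.lq_paths) == len(self.gt_paths), "LQ/GT lists must have the same length"
        self.transform = transform
        self.img_size = img_size

    def __len__(self):
        return len(self.lq_paths)

    def __getitem__(self, idx):
        lq = Image.open(self.lq_paths[idx]).convert('RGB')
        gt = Image.open(self.gt_paths[idx]).convert('RGB')
        if self.transform is not None:
            # note: random augmentations (e.g. RandomHorizontalFlip) would need to be applied
            # consistently to both images; that handling is omitted in this sketch
            lq = self.transform(lq)
            gt = self.transform(gt)
        return lq, gt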

inference.sh

### ---------------------------- FAR_Base ----------------------------  ###
# torchrun --nnodes=1 --nproc_per_node=1  main_far.py \
# --img_size 256 --vae_path pretrained/vae_mar/kl16.ckpt --vae_embed_dim 16 --vae_stride 16 --patch_size 1 \
# --model far_base --diffloss_d 6 --diffloss_w 1024 \
# --eval_bsz 32 --num_images 1000 \
# --num_iter 10 --num_sampling_steps 100 --cfg 3.0 --cfg_schedule linear --temperature 1.0 \
# --output_dir pretrained_models/far/far_base \
# --resume pretrained_models/far/far_base \
# --data_path ${IMAGENET_PATH} --evaluate

### ---------------------------- FAR_Large ----------------------------  ###
torchrun --nnodes=1 --nproc_per_node=1  main_far.py \
--img_size 256 --vae_path pretrained/vae_mar/kl16.ckpt --vae_embed_dim 16 --vae_stride 16 --patch_size 1 \
--model far_large --diffloss_d 3 --diffloss_w 1024 \
--eval_bsz 32 --num_images 1000 \
--num_iter 10 --num_sampling_steps 100 --cfg 3.0 --cfg_schedule linear --temperature 1.0 \
--output_dir pretrained_models/far/far_large \
--resume pretrained_models/far/far_large \
--data_path ${IMAGENET_PATH} --evaluate

### ---------------------------- FAR_Huge ----------------------------  ###
# torchrun --nnodes=1 --nproc_per_node=1  main_far.py \
# --img_size 256 --vae_path pretrained/vae_mar/kl16.ckpt --vae_embed_dim 16 --vae_stride 16 --patch_size 1 \
# --model far_huge --diffloss_d 3 --diffloss_w 1024 \
# --eval_bsz 32 --num_images 1000 \
# --num_iter 10 --num_sampling_steps 100 --cfg 3.0 --cfg_schedule linear --temperature 1.0 \
# --output_dir pretrained_models/far/far_huge \
# --resume pretrained_models/far/far_huge \
# --data_path ${IMAGENET_PATH} --evaluate

### ---------------------------- FAR_Large Denoising ----------------------------  ###
# torchrun --nnodes=1 --nproc_per_node=1  main_far.py \
# --img_size 256 --vae_path pretrained/vae_mar/kl16.ckpt --vae_embed_dim 16 --vae_stride 16 --patch_size 1 \
# --model far_large --diffloss_d 3 --diffloss_w 1024 \
# --eval_bsz 32 --num_images 10 \
# --num_iter 10 --num_sampling_steps 100 --cfg 1.0 --cfg_schedule linear --temperature 0.8 \
# --output_dir ./output_dir_denoise \
# --lq_list_path /data/vjuicefs_ai_camera_jgroup_research/public_data/11164225/datasets/Magic_Dictionary/test_list/test.LQ.list \
# --gt_list_path /data/vjuicefs_ai_camera_jgroup_research/public_data/11164225/datasets/Magic_Dictionary/test_list/test.GT.list \
# --finetune --pretrained_path ./output_dir_denoise/checkpoint-last.pth \
# --evaluate

### ---------------------------- FAR_Large Denoising Evaluation ----------------------------  ###
# evaluate the finetuned denoising model
torchrun --nnodes=1 --nproc_per_node=1 main_far.py \
--img_size 256 --vae_path pretrained_models/vae/kl16.ckpt --vae_embed_dim 16 --vae_stride 16 --patch_size 1 \
--model far_large --diffloss_d 3 --diffloss_w 1024 \
--lq_list_path ./data/test/testLQ.list \
--gt_list_path ./data/test/testGT.list \
--output_dir ./output_dir_denoise \
--finetune --pretrained_path ./output_dir_denoise/checkpoint-last.pth \
--evaluate_denoising --save_samples --use_ema

train.sh

### ---------------------------- FAR_Base ----------------------------  ###
# torchrun --nproc_per_node=8 --nnodes=4 --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \
# main_far.py \
# --img_size 256 --vae_path pretrained_models/vae/kl16.ckpt --vae_embed_dim 16 --vae_stride 16 --patch_size 1 \
# --model far_base --diffloss_d 6 --diffloss_w 1024 \
# --epochs 400 --warmup_epochs 100 --batch_size 64 --blr 1.0e-4 --diffusion_batch_mul 4 \
# --output_dir ${OUTPUT_DIR} --resume ${OUTPUT_DIR} \
# --data_path ${IMAGENET_PATH}

### ---------------------------- FAR_Large ----------------------------  ###
# torchrun --nproc_per_node=8 --nnodes=4 --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \
# main_far.py \
# --img_size 256 --vae_path pretrained_models/vae/kl16.ckpt --vae_embed_dim 16 --vae_stride 16 --patch_size 1 \
# --model far_large --diffloss_d 3 --diffloss_w 1024 \
# --epochs 400 --warmup_epochs 100 --batch_size 64 --blr 1.0e-4 --diffusion_batch_mul 4 \
# --output_dir ${OUTPUT_DIR} --resume ${OUTPUT_DIR} \
# --data_path ${IMAGENET_PATH}

### ---------------------------- FAR_Huge ----------------------------  ###
# torchrun --nproc_per_node=8 --nnodes=4 --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \
# main_far.py \
# --img_size 256 --vae_path pretrained_models/vae/kl16.ckpt --vae_embed_dim 16 --vae_stride 16 --patch_size 1 \
# --model far_huge --diffloss_d 3 --diffloss_w 1024 \
# --epochs 400 --warmup_epochs 100 --batch_size 64 --blr 1.0e-4 --diffusion_batch_mul 4 \
# --output_dir ${OUTPUT_DIR} --resume ${OUTPUT_DIR} \
# --data_path ${IMAGENET_PATH}

### ---------------------------- FAR_Large Finetune for Denoising ----------------------------  ###
# set environment variables for a single-node launch
export NODE_RANK=0
export MASTER_ADDR="127.0.0.1"
export MASTER_PORT=29500

# run the finetuning command
torchrun --nproc_per_node=4 --nnodes=1 --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \
main_far.py \
--img_size 256 --vae_path pretrained_models/vae/kl16.ckpt --vae_embed_dim 16 --vae_stride 16 --patch_size 1 \
--model far_large --diffloss_d 3 --diffloss_w 1024 \
--finetune --pretrained_path ./pretrained_models/far/far_large/checkpoint-last.pth \
--lq_list_path /data/vjuicefs_ai_camera_jgroup_research/public_data/11164225/datasets/Magic_Dictionary/training_list/20250105.6789degq_ShMfwowCaGq3_13_Mfww_r6789_G35l0105.6_6_1.LQ.list \
--gt_list_path /data/vjuicefs_ai_camera_jgroup_research/public_data/11164225/datasets/Magic_Dictionary/training_list/20250105.6789degq_ShMfwowCaGq3_13_Mfww_r6789_G35l0105.6_6_1.GT.list \
--epochs 100 --warmup_epochs 10 --batch_size 32 --blr 1.0e-5 --diffusion_batch_mul 4 \
--output_dir ./output_dir_denoise
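
Before launching the finetune, it can help to sanity-check that the LQ and GT list files line up. A minimal check (a sketch assuming one image path per line in each list file; the exact format is an assumption) could be:

import sys

def check_lists(lq_list_path, gt_list_path):
    # read one path per line, ignoring blank lines
    with open(lq_list_path) as f:
        lq = [line.strip() for line in f if line.strip()]
    with open(gt_list_path) as f:
        gt = [line.strip() for line in f if line.strip()]
    if len(lq) != len(gt):
        sys.exit(f"Length mismatch: {len(lq)} LQ entries vs {len(gt)} GT entries")
    print(f"OK: {len(lq)} paired entries")

if __name__ == '__main__':
    check_lists(sys.argv[1], sys.argv[2])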
