far.py
from functools import partial
import numpy as np
from tqdm import tqdm
import scipy.stats as stats
import math
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from timm.models.vision_transformer import Block
from models.diffloss import DiffLoss
import torch.nn.functional as F
import torchvision.transforms as T
import random
from torchvision.utils import make_grid
from typing import Optional
from PIL import Image
def save_image(images: torch.Tensor, nrow: int = 8, show: bool = True, path: Optional[str] = None, format: Optional[str] = None, to_grayscale: bool = False, **kwargs):
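# `images` is expected in [-1, 1]; the line below maps it to [0, 1] before building the grid.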
images = images * 0.5 + 0.5
grid = make_grid(images, nrow=nrow, **kwargs) # (channels, height, width)
# (height, width, channels)
grid = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
im = Image.fromarray(grid)
if to_grayscale:
im = im.convert(mode="L")
if path is not None:
im.save(path, format=format)
if show:
im.show()
return grid
def mask_by_order(mask_len, order, bsz, seq_len):
masking = torch.zeros(bsz, seq_len).cuda()
masking = torch.scatter(masking, dim=-1, index=order[:, :mask_len.long()], src=torch.ones(bsz, seq_len).cuda()).bool()
return masking
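# Example (comments only): with order = [[2, 0, 3, 1]], bsz = 1, seq_len = 4 and
# mask_len = torch.tensor(2), the first two entries of the order (positions 2 and 0)
# are scattered to True, so the returned mask is [[True, False, True, False]].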
class FAR(nn.Module):
""" Masked Autoencoder with VisionTransformer backbone
"""
def __init__(self, img_size=256, vae_stride=16, patch_size=1,
encoder_embed_dim=1024, encoder_depth=16, encoder_num_heads=16,
decoder_embed_dim=1024, decoder_depth=16, decoder_num_heads=16,
mlp_ratio=4., norm_layer=nn.LayerNorm,
vae_embed_dim=16,
mask=True,
mask_ratio_min=0.7,
label_drop_prob=0.1,
class_num=1000,
attn_dropout=0.1,
proj_dropout=0.1,
buffer_size=64,
diffloss_d=3,
diffloss_w=1024,
num_sampling_steps='100',
diffusion_batch_mul=4
):
super().__init__()
# --------------------------------------------------------------------------
# VAE and patchify specifics
self.vae_embed_dim = vae_embed_dim
self.img_size = img_size
self.vae_stride = vae_stride
self.patch_size = patch_size
self.seq_h = self.seq_w = img_size // vae_stride // patch_size
self.seq_len = self.seq_h * self.seq_w
self.token_embed_dim = vae_embed_dim * patch_size**2
# --------------------------------------------------------------------------
# Class Embedding
self.num_classes = class_num
self.class_emb = nn.Embedding(class_num, encoder_embed_dim)
self.label_drop_prob = label_drop_prob
# Fake class embedding for CFG's unconditional generation
self.fake_latent = nn.Parameter(torch.zeros(1, encoder_embed_dim))
self.loss_weight = [1 + np.sin(math.pi / 2. * (bands + 1) / self.seq_h) for bands in range(self.seq_h)]
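# The weights above form a sine ramp over the seq_h frequency bands, from about 1.1 for band 0 up to 2.0 for the highest band.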
# --------------------------------------------------------------------------
self.mask = mask
# FAR variant masking ratio, a left-half truncated Gaussian centered at 100% masking ratio with std 0.25
self.mask_ratio_generator = stats.truncnorm((mask_ratio_min - 1.0) / 0.25, 0, loc=1.0, scale=0.25)
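# With mask_ratio_min=0.7 this draws masking ratios in [0.7, 1.0], concentrated near 1.0 (full masking).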
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
# --------------------------------------------------------------------------
# FAR encoder specifics
self.z_proj = nn.Linear(self.token_embed_dim, encoder_embed_dim, bias=True)
self.z_proj_ln = nn.LayerNorm(encoder_embed_dim, eps=1e-6)
self.buffer_size = buffer_size
self.encoder_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len + self.buffer_size, encoder_embed_dim))
self.encoder_blocks = nn.ModuleList([
Block(encoder_embed_dim, encoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
proj_drop=proj_dropout, attn_drop=attn_dropout) for _ in range(encoder_depth)])
self.encoder_norm = norm_layer(encoder_embed_dim)
# --------------------------------------------------------------------------
# FAR decoder specifics
self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
self.decoder_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len + self.buffer_size, decoder_embed_dim))
self.decoder_blocks = nn.ModuleList([
Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
proj_drop=proj_dropout, attn_drop=attn_dropout) for _ in range(decoder_depth)])
self.decoder_norm = norm_layer(decoder_embed_dim)
self.diffusion_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len, decoder_embed_dim))
self.initialize_weights()
# --------------------------------------------------------------------------
# Diffusion Loss
self.diffloss = DiffLoss(
target_channels=self.token_embed_dim,
z_channels=decoder_embed_dim,
width=diffloss_w,
depth=diffloss_d,
num_sampling_steps=num_sampling_steps,
)
self.diffusion_batch_mul = diffusion_batch_mul
def initialize_weights(self):
# parameters
torch.nn.init.normal_(self.class_emb.weight, std=.02)
torch.nn.init.normal_(self.fake_latent, std=.02)
if self.mask:
torch.nn.init.normal_(self.mask_token, std=.02)
torch.nn.init.normal_(self.encoder_pos_embed_learned, std=.02)
torch.nn.init.normal_(self.decoder_pos_embed_learned, std=.02)
torch.nn.init.normal_(self.diffusion_pos_embed_learned, std=.02)
# initialize nn.Linear and nn.LayerNorm
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
# we use xavier_uniform following official JAX ViT:
torch.nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
if m.bias is not None:
nn.init.constant_(m.bias, 0)
if m.weight is not None:
nn.init.constant_(m.weight, 1.0)
def patchify(self, x):
bsz, c, h, w = x.shape
p = self.patch_size
h_, w_ = h // p, w // p
x = x.reshape(bsz, c, h_, p, w_, p)
x = torch.einsum('nchpwq->nhwcpq', x)
x = x.reshape(bsz, h_ * w_, c * p ** 2)
return x # [n, l, d]
def unpatchify(self, x):
bsz = x.shape[0]
p = self.patch_size
c = self.vae_embed_dim
h_, w_ = self.seq_h, self.seq_w
x = x.reshape(bsz, h_, w_, c, p, p)
x = torch.einsum('nhwcpq->nchpwq', x)
x = x.reshape(bsz, c, h_ * p, w_ * p)
return x # [n, c, h, w]
def sample_orders(self, bsz):
# generate a batch of random generation orders
orders = []
for _ in range(bsz):
order = np.array(list(range(self.seq_len)))
np.random.shuffle(order)
orders.append(order)
orders = torch.Tensor(np.array(orders)).cuda().long()
return orders
def random_masking(self, x, orders, x_index):
# generate token mask
bsz, seq_len, embed_dim = x.shape
stage = x_index[0]
mask_ratio_min = 0.7 * stage / 16
mask_rate = stats.truncnorm((mask_ratio_min - 1.0) / 0.25, 0, loc=1.0, scale=0.25).rvs(1)[0]
num_masked_tokens = int(np.ceil(seq_len * mask_rate))
mask = torch.zeros(bsz, seq_len, device=x.device)
mask = torch.scatter(mask, dim=-1, index=orders[:, :num_masked_tokens],
src=torch.ones(bsz, seq_len, device=x.device))
return mask
def processingpregt_latent(self, imgs):
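# Builds a low-frequency copy of the GT latents: pick a random "core" resolution, area-downsample
# to it, then bicubic-upsample back to (H, W); core index 0 maps to an all-zero latent.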
B, C, H, W = imgs.shape
out = torch.zeros_like(imgs)
latent_core = list(range(H))
core_index = []
random_number = torch.randint(0, len(latent_core), (1,))
for i in range(B):
chosen_core = latent_core[random_number]
core_index.append(chosen_core)
if random_number == 0:
out[i] = torch.zeros(C, H, W).cuda()
else:
imgs_resize = F.interpolate(imgs[i].unsqueeze(0), size=(chosen_core, chosen_core), mode='area')
out[i] = F.interpolate(imgs_resize, size=(H, W), mode='bicubic').squeeze(0)
core_index = torch.tensor(core_index).to(out.device).half()
return out, core_index
def forward_mae_encoder(self, x, class_embedding, mask=None):
x = self.z_proj(x)
bsz, seq_len, embed_dim = x.shape
# concat buffer
x = torch.cat([torch.zeros(bsz, self.buffer_size, embed_dim, device=x.device), x], dim=1)
if self.mask:
mask_with_buffer = torch.cat([torch.zeros(x.size(0), self.buffer_size, device=x.device), mask], dim=1)
# random drop class embedding during training
if self.training:
drop_latent_mask = torch.rand(bsz) < self.label_drop_prob
drop_latent_mask = drop_latent_mask.unsqueeze(-1).cuda().to(x.dtype)
class_embedding = drop_latent_mask * self.fake_latent + (1 - drop_latent_mask) * class_embedding
x[:, :self.buffer_size] = class_embedding.unsqueeze(1)
# encoder position embedding
x = x + self.encoder_pos_embed_learned
x = self.z_proj_ln(x)
# dropping
if self.mask:
x = x[(1-mask_with_buffer).nonzero(as_tuple=True)].reshape(bsz, -1, embed_dim)
# apply Transformer blocks
for blk in self.encoder_blocks:
x = blk(x)
x = self.encoder_norm(x)
return x
def forward_mae_decoder(self, x, mask=None):
x = self.decoder_embed(x)
if self.mask:
mask_with_buffer = torch.cat([torch.zeros(x.size(0), self.buffer_size, device=x.device), mask], dim=1)
# pad mask tokens
mask_tokens = self.mask_token.repeat(mask_with_buffer.shape[0], mask_with_buffer.shape[1], 1).to(x.dtype)
x_after_pad = mask_tokens.clone()
x_after_pad[(1 - mask_with_buffer).nonzero(as_tuple=True)] = x.reshape(x.shape[0] * x.shape[1], x.shape[2])
x = x_after_pad + self.decoder_pos_embed_learned
else:
x = x + self.decoder_pos_embed_learned
# apply Transformer blocks
for blk in self.decoder_blocks:
x = blk(x)
x = self.decoder_norm(x)
x = x[:, self.buffer_size:]
x = x + self.diffusion_pos_embed_learned
return x
def forward_loss(self, z, target, mask, index, loss_weight=False):
bsz, seq_len, _ = target.shape
current_diffusion_batch_mul = self.diffusion_batch_mul # repeat each token this many times for the diffusion loss
target = target.reshape(bsz * seq_len, -1).repeat(current_diffusion_batch_mul, 1)
index = index.unsqueeze(1).unsqueeze(-1).repeat(1, seq_len, 1).reshape(bsz * seq_len, -1).repeat(current_diffusion_batch_mul, 1)
z = z.reshape(bsz*seq_len, -1).repeat(current_diffusion_batch_mul, 1)
if loss_weight:
loss_weight = loss_weight.unsqueeze(1).repeat(1, seq_len).reshape(bsz * seq_len).repeat(current_diffusion_batch_mul)
loss = self.diffloss(z=z, target=target, index=index, loss_weight=loss_weight)
return loss
def forward(self, lq_imgs, gt_imgs, loss_weight=False):
freq_level = random.randint(14,15)
freq_embedding = self.fake_latent.repeat(gt_imgs.shape[0], 1)
gt_latents = self.patchify(gt_imgs)
lq_latents = self.patchify(lq_imgs)
if loss_weight:
loss_weight = self.loss_weight
x_index = torch.tensor([freq_level]*gt_imgs.size(0)).to(gt_imgs.device).half()
# record the frequency level used in this step so it can be monitored during training
self.last_used_freq_level = freq_level
mask=None
x = self.forward_mae_encoder(lq_latents, freq_embedding, mask)
z = self.forward_mae_decoder(x, mask)
# compute the loss against the original gt_latents
loss = self.forward_loss(z=z, target=gt_latents, mask=None, index=x_index, loss_weight=loss_weight)
return loss
def sample_tokens_mask(self, bsz, num_iter=64, cfg=1.0, cfg_schedule="linear", labels=None, temperature=1.0, progress=False):
latent_core = [0,12,12,4,5,6,7,8,9,10,11]
num_iter = len(latent_core)
mask = torch.ones(bsz, self.seq_len).cuda()
tokens = torch.zeros(bsz, self.seq_len, self.token_embed_dim).cuda()
orders = self.sample_orders(bsz)
for step in list(range(num_iter)):
cur_tokens = torch.zeros(bsz, self.seq_len, self.token_embed_dim).cuda()
if labels is not None:
class_embedding = self.class_emb(labels)
else:
class_embedding = self.fake_latent.repeat(bsz, 1)
tokens = torch.cat([tokens, tokens], dim=0)
class_embedding = torch.cat([class_embedding, self.fake_latent.repeat(bsz, 1)], dim=0)
mask = torch.cat([mask, mask], dim=0)
x = self.forward_mae_encoder(tokens, class_embedding, mask)
z = self.forward_mae_decoder(x, mask) # torch.Size([512, 256, 768])
B, L, C = z.shape
z = z.reshape(B * L, -1)
# mask ratio for the next round, following MaskGIT and MAGE.
mask_ratio = np.cos(math.pi / 2. * (step + 1) / num_iter)
mask_len = torch.Tensor([np.floor(self.seq_len * mask_ratio)]).cuda() # 251, 236, 212, 181, 142, 97, 49, 0
cfg_iter = 1 + (cfg - 1) * step / num_iter
temperature_iter = 0.85 + (temperature - 0.85) * step / num_iter
index = torch.tensor([latent_core[step]]).unsqueeze(1).unsqueeze(-1).repeat(B, L, 1).reshape(B * L, -1).to(z.device).half()
z = self.diffloss.sample(z, temperature_iter, cfg_iter, index) # torch.Size([512, 16])
z, _ = z.chunk(2, dim=0) # Remove null class samples. torch.Size([256, 16])
if step < num_iter-1:
z = z.reshape(bsz, L, -1).transpose_(1, 2).reshape(bsz, -1, 16, 16)
if step > 0:
imgs_resize = F.interpolate(z, size=(latent_core[step+1], latent_core[step+1]), mode='area')
z = F.interpolate(imgs_resize, size=(16, 16), mode='bicubic')
z = z.reshape(bsz, -1, L).transpose_(1, 2).reshape(bsz*L, -1)
sampled_token = z.reshape(bsz, L, -1)
mask_next = mask_by_order(mask_len[0], orders, bsz, self.seq_len)
mask_to_pred = torch.logical_not(mask_next)
mask = mask_next
sampled_token_latent = sampled_token[mask_to_pred.nonzero(as_tuple=True)]
cur_tokens[mask_to_pred.nonzero(as_tuple=True)] = sampled_token_latent
tokens = cur_tokens.clone()
tokens = tokens.transpose_(1, 2).reshape(bsz, -1, 16, 16)
return tokens
def sample_tokens_nomask(self, bsz, num_iter=64, cfg=1.0, cfg_schedule="linear", labels=None, temperature=1.0, progress=False):
latent_core = [0,2,3,4,5,6,7,8,9,10]
# latent_core = [0,1,2,3,4,5,6,7,8,9]
num_iter = len(latent_core)
# init the token canvas (no generation-order sampling in the no-mask variant)
tokens = torch.zeros(bsz, self.seq_len, self.token_embed_dim).cuda().half()
indices = list(range(num_iter))
if progress:
indices = tqdm(indices)
for step in indices:
if labels is not None:
class_embedding = self.class_emb(labels) # torch.Size([256, 768])
else:
class_embedding = self.fake_latent.repeat(bsz, 1)
tokens = torch.cat([tokens, tokens], dim=0)
class_embedding = torch.cat([class_embedding, self.fake_latent.repeat(bsz, 1)], dim=0)
# mae encoder
x = self.forward_mae_encoder(tokens, class_embedding)
z = self.forward_mae_decoder(x) # torch.Size([512, 256, 768]); the decoder produces a high-dimensional (768) condition for every token in one pass, and these conditions are fed to the diffusion head to sample the tokens (in the masked variant only a random subset is used per step).
B, L, C = z.shape
z = z.reshape(B * L, -1)
cfg_iter = 1 + (cfg - 1) * step / num_iter
temperature_iter = 0.8 + (1 - np.cos(math.pi / 2. * (step + 1) / num_iter)) * (1-0.8)
index = torch.tensor([latent_core[step]]).unsqueeze(1).unsqueeze(-1).repeat(B, L, 1).reshape(B * L, -1).to(z.device).half()
sampled_token_latent = self.diffloss.sample(z, temperature_iter, cfg_iter, index=index) # torch.Size([512, 16])
sampled_token_latent, _ = sampled_token_latent.chunk(2, dim=0) # Remove null class samples. torch.Size([256, 16])
sampled_token_latent = sampled_token_latent.reshape(bsz, L, -1).transpose_(1, 2).reshape(bsz, -1, 16, 16)
if step < num_iter-1:
if step > -1:
sampled_token_latent = F.interpolate(sampled_token_latent, size=(latent_core[step+1], latent_core[step+1]), mode='area')
sampled_token_latent = F.interpolate(sampled_token_latent, size=(16, 16), mode='bicubic')
sampled_token_latent = sampled_token_latent.view(bsz, 16, -1).transpose(1, 2)
tokens = sampled_token_latent.clone()
return tokens
def denoise(self, lq_img):
"""Denoise a (batch of) low-quality image latents and return the enhanced latents."""
bsz = lq_img.shape[0]
lq_latents = self.patchify(lq_img) # [bsz, seq_len, token_embed_dim]
# 1. Use a fixed frequency index for refinement
x_index = torch.ones(bsz, device=lq_img.device).mul(15).half()
# 2. Project the LQ latents up to decoder_embed_dim
# First projection: token_embed_dim -> encoder_embed_dim
lq_proj = self.z_proj(lq_latents)
# Layer norm for numerical stability
lq_proj = self.z_proj_ln(lq_proj)
# Second projection: encoder_embed_dim -> decoder_embed_dim
z = self.decoder_embed(lq_proj)
# 3. Add the diffusion positional embedding, which the diffusion head expects
z = z + self.diffusion_pos_embed_learned
# Normalize to keep z in a reasonable range
z = self.decoder_norm(z)
# Apply an extra scale factor for numerical stability
scale_factor = 0.1
z = z * scale_factor
# Sample the enhanced latents from the diffusion head conditioned on z;
# a lower temperature makes the output more deterministic
temperature = 0.8
cfg = 1.0 # no classifier-free guidance
# Sequence length
seq_len = z.shape[1]
# Process all positions in a single batch
tokens_flat = z.reshape(bsz * seq_len, -1)
indices = x_index.unsqueeze(1).unsqueeze(-1).repeat(1, seq_len, 1).reshape(bsz * seq_len, -1)
# Sample the denoised tokens with diffloss.sample
enhanced_tokens = self.diffloss.sample(tokens_flat, temperature, cfg, indices)
# Debug: print the raw output shape
print(f"diffloss.sample output shape: {enhanced_tokens.shape}")
# Reshape the result into the same format used by sample_tokens_nomask
h = w = int(math.sqrt(seq_len)) # typically 16x16
enhanced_tokens = enhanced_tokens.reshape(bsz, seq_len, -1)
# Debug: print the reshaped shape
print(f"Reshaped to: {enhanced_tokens.shape}")
# Transpose and reshape into the spatial layout expected by the VAE decoder [B, C, H, W],
# based on the actual shape of enhanced_tokens
C = enhanced_tokens.shape[2] # number of channels
enhanced_tokens = enhanced_tokens.permute(0, 2, 1) # [bsz, C, seq_len]
enhanced_tokens = enhanced_tokens.reshape(bsz, C, h, w) # [bsz, C, h, w]
# Debug: print the final shape
print(f"Final output shape: {enhanced_tokens.shape}")
return enhanced_tokens
def far_base(**kwargs):
model = FAR(
encoder_embed_dim=768, encoder_depth=12, encoder_num_heads=12,
decoder_embed_dim=768, decoder_depth=12, decoder_num_heads=12,
mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
return model
def far_large(**kwargs):
model = FAR(
encoder_embed_dim=1024, encoder_depth=16, encoder_num_heads=16,
decoder_embed_dim=1024, decoder_depth=16, decoder_num_heads=16,
mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
return model
def far_huge(**kwargs):
model = FAR(
encoder_embed_dim=1280, encoder_depth=20, encoder_num_heads=16,
decoder_embed_dim=1280, decoder_depth=20, decoder_num_heads=16,
mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
return model
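# ----------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original pipeline): it only checks that
# patchify/unpatchify are exact inverses and prints the token geometry. It assumes the
# default KL-16 tokenizer settings above and that the DiffLoss dependency imported at
# the top of this file is available.
if __name__ == "__main__":
    model = far_base(img_size=256, vae_stride=16, patch_size=1, vae_embed_dim=16,
                     diffloss_d=3, diffloss_w=1024, num_sampling_steps='100')
    latents = torch.randn(2, 16, 16, 16)  # [B, vae_embed_dim, img_size/vae_stride, img_size/vae_stride]
    tokens = model.patchify(latents)      # [B, seq_len, token_embed_dim] = [2, 256, 16]
    print(tokens.shape, model.seq_len, model.token_embed_dim)
    assert torch.allclose(model.unpatchify(tokens), latents)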
engine_far.py
import math
import sys
from typing import Iterable
import random
import torch
import util.misc as misc
import util.lr_sched as lr_sched
from models.vae import DiagonalGaussianDistribution
import torch_fidelity
import shutil
import cv2
import numpy as np
import matplotlib.pyplot as plt
import os
import copy
import time
from torchvision.utils import make_grid
import torchvision.transforms as transforms
from typing import Optional
from PIL import Image
def save_image(images: torch.Tensor, nrow: int = 8, show: bool = True, path: Optional[str] = None, format: Optional[str] = None, to_grayscale: bool = False, **kwargs):
images = images * 0.5 + 0.5
grid = make_grid(images, nrow=nrow, **kwargs) # (channels, height, width)
# (height, width, channels)
grid = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
im = Image.fromarray(grid)
if to_grayscale:
im = im.convert(mode="L")
if path is not None:
im.save(path, format=format)
if show:
im.show()
return grid
def update_ema(target_params, source_params, rate=0.99):
"""
Update target parameters to be closer to those of source parameters using
an exponential moving average.
:param target_params: the target parameter sequence.
:param source_params: the source parameter sequence.
:param rate: the EMA rate (closer to 1 means slower).
"""
for targ, src in zip(target_params, source_params):
targ.detach().mul_(rate).add_(src, alpha=1 - rate)
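# Worked example: with rate=0.99, a target value of 0.0 and a source value of 1.0,
# a single call moves the target to 0.99 * 0.0 + 0.01 * 1.0 = 0.01, i.e. the target
# drifts 1% of the way toward the source per update.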
def train_one_epoch(model, vae,
model_params, ema_params,
data_loader: Iterable, optimizer: torch.optim.Optimizer,
device: torch.device, epoch: int, loss_scaler,
log_writer=None,
args=None):
model.train(True)
metric_logger = misc.MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
header = 'Epoch: [{}]'.format(epoch)
print_freq = 10
optimizer.zero_grad()
if log_writer is not None:
print('log_dir: {}'.format(log_writer.log_dir))
# track additional statistics for this epoch
all_losses = []
freq_levels_used = []
param_norms_before = [p.norm().item() for p in model.parameters() if p.requires_grad]
# bookkeeping for the number of processed samples
num_samples_processed = 0
for data_iter_step, data in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
# we use a per iteration (instead of per epoch) lr scheduler
lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
# Data handling: pick the processing path depending on whether this is a finetuning run
if args.finetune:
# Finetuning: process LQ/GT image pairs
lq_samples, gt_samples = data
lq_samples = lq_samples.to(device, non_blocking=True)
gt_samples = gt_samples.to(device, non_blocking=True)
with torch.no_grad():
# Encode both the LQ and GT images into the latent space
lq_posterior = vae.encode(lq_samples)
lq_x = lq_posterior.sample().mul_(0.2325)
gt_posterior = vae.encode(gt_samples)
gt_x = gt_posterior.sample().mul_(0.2325)
# Dummy labels (all zeros), since class conditioning is not used here
fake_labels = torch.zeros(lq_samples.size(0), dtype=torch.long, device=device)
# Forward pass with the LQ and GT latents
with torch.cuda.amp.autocast():
# Note: in finetuning mode the model's forward takes (lq_imgs, gt_imgs) latents
loss = model(lq_x, gt_x, loss_weight=args.loss_weight)
# Record the frequency level used in this step (read back from the model)
if hasattr(model, 'module'):
freq_level = model.module.last_used_freq_level if hasattr(model.module, 'last_used_freq_level') else random.randint(14, 15)
else:
freq_level = model.last_used_freq_level if hasattr(model, 'last_used_freq_level') else random.randint(14, 15)
freq_levels_used.append(freq_level)
else:
# Original task: single images with class labels
samples, labels = data
samples = samples.to(device, non_blocking=True)
labels = labels.to(device, non_blocking=True)
with torch.no_grad():
if args.use_cached:
moments = samples
posterior = DiagonalGaussianDistribution(moments)
else:
posterior = vae.encode(samples)
# normalize the std of latent to be 1. Change it if you use a different tokenizer
x = posterior.sample().mul_(0.2325)
# forward
with torch.cuda.amp.autocast():
# NOTE: forward() in far.py was rewritten to take (lq_imgs, gt_imgs) latents for finetuning,
# so this original class-conditional call path is only valid with the unmodified model
loss = model(x, labels, loss_weight=args.loss_weight)
loss_value = loss.item()
all_losses.append(loss_value)
num_samples_processed += lq_samples.size(0) if args.finetune else samples.size(0)
if not math.isfinite(loss_value):
print("Loss is {}, stopping training".format(loss_value))
sys.exit(1)
loss_scaler(loss, optimizer, clip_grad=args.grad_clip, parameters=model.parameters(), update_grad=True)
optimizer.zero_grad()
torch.cuda.synchronize()
update_ema(ema_params, model_params, rate=args.ema_rate)
metric_logger.update(loss=loss_value)
lr = optimizer.param_groups[0]["lr"]
metric_logger.update(lr=lr)
loss_value_reduce = misc.all_reduce_mean(loss_value)
if log_writer is not None:
""" We use epoch_1000x as the x-axis in tensorboard.
This calibrates different curves when batch size changes.
"""
epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)
log_writer.add_scalar('lr', lr, epoch_1000x)
# Statistics on how much the parameters changed over the epoch
param_norms_after = [p.norm().item() for p in model.parameters() if p.requires_grad]
param_changes = [abs(after - before) / (before + 1e-10) for before, after in zip(param_norms_before, param_norms_after)]
# Print detailed statistics at the end of the epoch
if all_losses:
loss_mean = sum(all_losses) / len(all_losses)
loss_min = min(all_losses)
loss_max = max(all_losses)
loss_std = np.std(all_losses) if len(all_losses) > 1 else 0
# Parameter change statistics
param_change_mean = sum(param_changes) / len(param_changes) if param_changes else 0
param_change_max = max(param_changes) if param_changes else 0
# Frequency level statistics
freq_level_counts = {}
for freq in freq_levels_used:
freq_level_counts[freq] = freq_level_counts.get(freq, 0) + 1
# Print the details
print("\n" + "="*80)
print(f"Finetuning statistics - Epoch {epoch}")
print(f"Samples processed: {num_samples_processed}")
print(f"Loss: mean={loss_mean:.6f}, min={loss_min:.6f}, max={loss_max:.6f}, std={loss_std:.6f}")
if freq_levels_used:
print(f"Frequency level usage: {freq_level_counts}")
print(f"Parameter change: mean relative={param_change_mean:.6f}, max relative={param_change_max:.6f}")
# Warn on anomalies
if loss_std > loss_mean * 0.5: # a std above half the mean may indicate a problem
print("Warning: loss variance is high; consider adjusting the learning rate or checking model stability")
if param_change_mean < 1e-6:
print("Warning: parameter change is very small; the learning rate may be too low or gradients may be vanishing")
elif param_change_max > 0.5:
print("Warning: some parameters changed a lot; the learning rate may be too high or gradients may be exploding")
print("="*80 + "\n")
# Log these statistics to tensorboard if available
if log_writer is not None:
log_writer.add_scalar('epoch_loss_mean', loss_mean, epoch)
log_writer.add_scalar('epoch_loss_min', loss_min, epoch)
log_writer.add_scalar('epoch_loss_max', loss_max, epoch)
log_writer.add_scalar('epoch_loss_std', loss_std, epoch)
log_writer.add_scalar('epoch_param_change_mean', param_change_mean, epoch)
log_writer.add_scalar('epoch_param_change_max', param_change_max, epoch)
# gather the stats from all processes
metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
def evaluate(model_without_ddp, vae, ema_params, args, epoch, batch_size=16, log_writer=None, cfg=1.0,
use_ema=True):
model_without_ddp.eval()
num_steps = args.num_images // (batch_size * misc.get_world_size()) + 1
save_folder = os.path.join(args.output_dir, "ariter{}-diffsteps{}-temp{}-{}cfg{}-image{}".format(args.num_iter,
args.num_sampling_steps,
args.temperature,
args.cfg_schedule,
cfg,
args.num_images))
if use_ema:
save_folder = save_folder + "_ema"
if args.evaluate:
save_folder = save_folder + "_evaluate"
print("Save to:", save_folder)
if misc.get_rank() == 0:
if not os.path.exists(save_folder):
os.makedirs(save_folder)
# switch to ema params
if use_ema:
model_state_dict = copy.deepcopy(model_without_ddp.state_dict())
ema_state_dict = copy.deepcopy(model_without_ddp.state_dict())
for i, (name, _value) in enumerate(model_without_ddp.named_parameters()):
assert name in ema_state_dict
ema_state_dict[name] = ema_params[i]
print("Switch to ema")
model_without_ddp.load_state_dict(ema_state_dict)
class_num = args.class_num
assert args.num_images % class_num == 0 # number of images per class must be the same
class_label_gen_world = np.arange(0, class_num).repeat(args.num_images // class_num)
class_label_gen_world = np.hstack([class_label_gen_world, np.zeros(50000)])
world_size = misc.get_world_size()
local_rank = misc.get_rank()
used_time = 0
gen_img_cnt = 0
for i in range(num_steps):
print("Generation step {}/{}".format(i, num_steps))
labels_gen = class_label_gen_world[world_size * batch_size * i + local_rank * batch_size:
world_size * batch_size * i + (local_rank + 1) * batch_size]
labels_gen = torch.Tensor(labels_gen).long().cuda()
torch.cuda.synchronize()
start_time = time.time()
# generation
with torch.no_grad():
with torch.cuda.amp.autocast():
if args.mask:
sampled_tokens = model_without_ddp.sample_tokens_mask(bsz=batch_size, num_iter=args.num_iter, cfg=cfg,
cfg_schedule=args.cfg_schedule, labels=labels_gen,
temperature=args.temperature)
else:
sampled_tokens = model_without_ddp.sample_tokens_nomask(bsz=batch_size, num_iter=args.num_iter, cfg=cfg,
cfg_schedule=args.cfg_schedule, labels=labels_gen,
temperature=args.temperature)
sampled_images = vae.decode(sampled_tokens / 0.2325)
# measure speed after the first generation batch
if i >= 1:
torch.cuda.synchronize()
used_time += time.time() - start_time
gen_img_cnt += batch_size
print("Generating {} images takes {:.5f} seconds, {:.5f} sec per image".format(gen_img_cnt, used_time, used_time / gen_img_cnt))
torch.distributed.barrier()
sampled_images = sampled_images.detach().cpu()
save_image(sampled_images, nrow=8, show=False, path=os.path.join(args.output_dir, f"epoch{epoch}.png"), to_grayscale=False)
# NOTE: this early return saves a single preview grid per evaluation and skips the per-image saving and FID/IS computation below.
return
sampled_images = (sampled_images + 1) / 2
# distributed save
for b_id in range(sampled_images.size(0)):
img_id = i * sampled_images.size(0) * world_size + local_rank * sampled_images.size(0) + b_id
if img_id >= args.num_images:
break
gen_img = np.round(np.clip(sampled_images[b_id].numpy().transpose([1, 2, 0]) * 255, 0, 255))
gen_img = gen_img.astype(np.uint8)
plt.imsave(os.path.join(save_folder, '{}.png'.format(str(img_id).zfill(5))), gen_img)
torch.distributed.barrier()
time.sleep(10)
# back to no ema
if use_ema:
print("Switch back from ema")
model_without_ddp.load_state_dict(model_state_dict)
# compute FID and IS
if log_writer is not None:
if args.img_size == 256:
input2 = None
fid_statistics_file = '/mnt/workspace/workgroup/yuhu/code/FAR/fid_stats/adm_in256_stats.npz'
else:
raise NotImplementedError
metrics_dict = torch_fidelity.calculate_metrics(
input1=save_folder,
input2=input2,
fid_statistics_file=fid_statistics_file,
cuda=True,
isc=True,
fid=True,
kid=False,
prc=False,
verbose=False,
)
fid = metrics_dict['frechet_inception_distance']
inception_score = metrics_dict['inception_score_mean']
print(fid)
print(inception_score)
postfix = ""
if use_ema:
postfix = postfix + "_ema"
if not cfg == 1.0:
postfix = postfix + "_cfg{}".format(cfg)
log_writer.add_scalar('fid{}'.format(postfix), fid, epoch)
log_writer.add_scalar('is{}'.format(postfix), inception_score, epoch)
print("FID: {:.4f}, Inception Score: {:.4f}".format(fid, inception_score))
# remove temporary saving folder
shutil.rmtree(save_folder)
torch.distributed.barrier()
time.sleep(10)
def cache_latents(vae,
data_loader: Iterable,
device: torch.device,
args=None):
metric_logger = misc.MetricLogger(delimiter=" ")
header = 'Caching: '
print_freq = 20
for data_iter_step, (samples, _, paths) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
samples = samples.to(device, non_blocking=True)
with torch.no_grad():
posterior = vae.encode(samples)
moments = posterior.parameters
posterior_flip = vae.encode(samples.flip(dims=[3]))
moments_flip = posterior_flip.parameters
for i, path in enumerate(paths):
save_path = os.path.join(args.cached_path, path + '.npz')
os.makedirs(os.path.dirname(save_path), exist_ok=True)
np.savez(save_path, moments=moments[i].cpu().numpy(), moments_flip=moments_flip[i].cpu().numpy())
if misc.is_dist_avail_and_initialized():
torch.cuda.synchronize()
return
def evaluate_denoising(model, vae, args, device):
"""Evaluate the denoising model on the test set.
Args:
model: the finetuned FAR model
vae: the VAE used for encoding/decoding images
args: command-line arguments
device: torch device
Returns:
a dict with the evaluation metrics
"""
print("Starting denoising evaluation...")
# Load the test file lists
with open(args.lq_list_path, 'r') as f:
lq_paths = [line.strip() for line in f.readlines()]
with open(args.gt_list_path, 'r') as f:
gt_paths = [line.strip() for line in f.readlines()]
assert len(lq_paths) == len(gt_paths), "Number of LQ and GT images must match"
print(f"Loaded {len(lq_paths)} test image pairs")
# Image transforms
transform = transforms.Compose([
transforms.Resize((args.img_size, args.img_size)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
# Switch the model to evaluation mode
model.eval()
# PSNR and SSIM metrics
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
# Metric accumulators
psnr_values = []
ssim_values = []
# Output directory for sample results
if args.save_samples:
os.makedirs(os.path.join(args.output_dir, "test_samples"), exist_ok=True)
# Process each image pair
for idx, (lq_path, gt_path) in enumerate(zip(lq_paths, gt_paths)):
if idx % 100 == 0:
print(f"Processing image {idx}/{len(lq_paths)}...")
try:
# Load the LQ and GT images
lq_img = Image.open(lq_path).convert('RGB')
gt_img = Image.open(gt_path).convert('RGB')
# Apply the transforms
lq_tensor = transform(lq_img).unsqueeze(0).to(device)
gt_tensor = transform(gt_img).unsqueeze(0).to(device)
with torch.no_grad():
# Encode to the latent space with the VAE
lq_posterior = vae.encode(lq_tensor)
lq_latents = lq_posterior.sample().mul_(0.2325)
# Denoise in latent space with the model
denoised_latent = model.denoise(lq_latents)
# Decode the denoised latents with the VAE
denoised_img = vae.decode(denoised_latent / 0.2325)
# Convert the images to a metric-friendly format: from [-1, 1] to [0, 1]
lq_np = (lq_tensor + 1) / 2.0
gt_np = (gt_tensor + 1) / 2.0
denoised_np = (denoised_img + 1) / 2.0
# Move to CPU and convert to numpy arrays
lq_np = lq_np.squeeze().permute(1, 2, 0).cpu().numpy().clip(0, 1)
gt_np = gt_np.squeeze().permute(1, 2, 0).cpu().numpy().clip(0, 1)
denoised_np = denoised_np.squeeze().permute(1, 2, 0).cpu().numpy().clip(0, 1)
# PSNR
current_psnr = psnr(gt_np, denoised_np, data_range=1.0)
psnr_values.append(current_psnr)
# SSIM (multi-channel; the deprecated `multichannel` flag is dropped and only channel_axis is passed)
current_ssim = ssim(gt_np, denoised_np, data_range=1.0, channel_axis=2)
ssim_values.append(current_ssim)
# Save a few sample images for visual inspection
if args.save_samples and idx < 10: # only the first 10 samples
save_path = os.path.join(args.output_dir, "test_samples", f"sample_{idx}.png")
# Build a comparison grid [LQ, Denoised, GT]
comparison = torch.cat([
lq_tensor,
denoised_img,
gt_tensor
], dim=0)
save_image(comparison, nrow=3, show=False, path=save_path)
print(f"Sample {idx} saved to: {save_path}")
except Exception as e:
print(f"Error while processing {lq_path}: {e}")
continue
# Average metrics
avg_psnr = np.mean(psnr_values)
avg_ssim = np.mean(ssim_values)
# Print the results
print("="*50)
print("Denoising evaluation results:")
print(f"Mean PSNR: {avg_psnr:.4f} dB")
print(f"Mean SSIM: {avg_ssim:.4f}")
print("="*50)
# Save the results to a file
result_path = os.path.join(args.output_dir, "denoising_results.txt")
with open(result_path, 'w') as f:
f.write(f"Number of test images: {len(psnr_values)}\n")
f.write(f"Mean PSNR: {avg_psnr:.4f} dB\n")
f.write(f"Mean SSIM: {avg_ssim:.4f}\n")
print(f"Detailed results saved to: {result_path}")
return {
"PSNR": avg_psnr,
"SSIM": avg_ssim,
"num_samples": len(psnr_values)
}
main_far.py
import argparse
import datetime
import numpy as np
import os
import time
from pathlib import Path
import random
import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import Dataset
from PIL import Image
from util.crop import center_crop_arr
import util.misc as misc
from util.misc import NativeScalerWithGradNormCount as NativeScaler
from util.loader import CachedFolder
from models.vae import AutoencoderKL
from models import far
from engine_far import train_one_epoch, evaluate, evaluate_denoising
import copy
# Single-image denoising inference helper
def inference_single_image(model, vae, input_path, output_path, device, args):
"""Denoise a single image.
Args:
model: the finetuned FAR model
vae: the VAE used for encoding/decoding images
input_path: path of the input image
output_path: path for the output image
device: torch device
args: command-line arguments
"""
print(f"Processing image: {input_path}")
# Make sure the output directory exists
os.makedirs(os.path.dirname(output_path), exist_ok=True)
# Switch the model to evaluation mode
model.eval()
# Load and preprocess the image with the same pipeline used during finetuning
transform = transforms.Compose([
transforms.Resize((args.img_size, args.img_size)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
try:
img = Image.open(input_path).convert('RGB')
img_tensor = transform(img).unsqueeze(0).to(device)
print(f"输入图像形状: {img_tensor.shape}")
with torch.no_grad():
# 1. 使用VAE编码获取潜在表示,与微调流程一致
posterior = vae.encode(img_tensor)
latent = posterior.sample().mul_(0.2325) # 使用与微调相同的缩放因子
print(f"VAE编码后形状: {latent.shape}")
# 2. 使用模型的denoise方法,它现在使用与微调相同的流程
print("应用去噪处理...")
denoised_latent = model.denoise(latent)
print(f"Denoise后形状: {denoised_latent.shape}")
# 3. 使用VAE解码获取去噪后的图像
print("VAE解码中...")
denoised_img = vae.decode(denoised_latent / 0.2325)
print(f"VAE解码后形状: {denoised_img.shape}")
# 4. 保存输出图像
print("保存结果图像...")
output_img = (denoised_img + 1) / 2.0
output_img = output_img.clamp(0, 1)
from torchvision.utils import save_image
save_image(output_img, output_path)
print(f"去噪后的图像已保存至: {output_path}")
# 同时保存输入图像以便比较
input_save_path = os.path.join(os.path.dirname(output_path), "input.png")
input_img = (img_tensor + 1) / 2.0
save_image(input_img, input_save_path)
print(f"输入图像已保存至: {input_save_path}")
# 额外保存一个并排比较图像
comparison = torch.cat([input_img, output_img], dim=0)
comparison_path = os.path.join(os.path.dirname(output_path), "comparison.png")
save_image(comparison, comparison_path, nrow=2)
print(f"比较图像已保存至: {comparison_path}")
except Exception as e:
print(f"处理图像时出错: {e}")
import traceback
traceback.print_exc()
# Custom dataset that loads LQ/GT image pairs for denoising finetuning
class DenoiseDataset(Dataset):
def __init__(self, lq_list_path, gt_list_path, transform=None, img_size=256):
self.transform = transform
self.img_size = img_size
# Read the LQ and GT image paths
with open(lq_list_path, 'r') as f:
self.lq_paths = [line.strip() for line in f.readlines()]
with open(gt_list_path, 'r') as f:
self.gt_paths = [line.strip() for line in f.readlines()]
assert len(self.lq_paths) == len(self.gt_paths), "Number of LQ and GT images must match"
def __len__(self):
return len(self.lq_paths)
def __getitem__(self, idx):
# Load the LQ and GT images
lq_img = Image.open(self.lq_paths[idx]).convert('RGB')
gt_img = Image.open(self.gt_paths[idx]).convert('RGB')
# Apply the transforms
if self.transform:
# Build a transform without RandomHorizontalFlip so the flip can be applied
# manually and kept consistent between the LQ and GT images
transform_without_flip = transforms.Compose([
transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, self.img_size)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
# Decide the flip once and apply the same decision to both LQ and GT
if random.random() < 0.5:
lq_img = transforms.functional.hflip(lq_img)
gt_img = transforms.functional.hflip(gt_img)
# Apply the remaining transforms
lq_img = transform_without_flip(lq_img)
gt_img = transform_without_flip(gt_img)
else:
# Without a transform, just convert to tensors
lq_img = transforms.ToTensor()(lq_img)
gt_img = transforms.ToTensor()(gt_img)
return lq_img, gt_img
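# Usage sketch (paths are placeholders, not real files): the dataset yields (lq, gt)
# tensor pairs that train_one_epoch consumes in finetuning mode, e.g.
#   dataset = DenoiseDataset('train.LQ.list', 'train.GT.list', transform=transform_train, img_size=256)
#   lq, gt = dataset[0]  # two [3, 256, 256] tensors normalized to [-1, 1]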
def get_args_parser():
parser = argparse.ArgumentParser('FAR training with Diffusion Loss', add_help=False)
parser.add_argument('--batch_size', default=16, type=int,
help='Batch size per GPU (effective batch size is batch_size * # gpus)')
parser.add_argument('--epochs', default=400, type=int)
# Model parameters
parser.add_argument('--model', default='far_large', type=str, metavar='MODEL',
help='Name of model to train')
# VAE parameters
parser.add_argument('--img_size', default=256, type=int,
help='images input size')
parser.add_argument('--vae_path', default="pretrained/vae_mar/kl16.ckpt", type=str,
help='path to the pretrained VAE checkpoint')
parser.add_argument('--vae_embed_dim', default=16, type=int,
help='vae output embedding dimension')
parser.add_argument('--vae_stride', default=16, type=int,
help='tokenizer stride, default use KL16')
parser.add_argument('--patch_size', default=1, type=int,
help='number of tokens to group as a patch.')
# Generation parameters
parser.add_argument('--num_iter', default=64, type=int,
help='number of autoregressive iterations to generate an image')
parser.add_argument('--num_images', default=50000, type=int,
help='number of images to generate')
parser.add_argument('--cfg', default=1.0, type=float, help="classifier-free guidance")
parser.add_argument('--cfg_schedule', default="linear", type=str)
parser.add_argument('--label_drop_prob', default=0.1, type=float)
parser.add_argument('--eval_freq', type=int, default=40, help='evaluation frequency')
parser.add_argument('--save_last_freq', type=int, default=10, help='save last frequency')
parser.add_argument('--online_eval', action='store_true')
parser.add_argument('--evaluate', action='store_true')
parser.add_argument('--mask', action='store_true')
parser.add_argument('--eval_bsz', type=int, default=64, help='generation batch size')
# Optimizer parameters
parser.add_argument('--weight_decay', type=float, default=0.02,
help='weight decay (default: 0.02)')
parser.add_argument('--lr', type=float, default=None, metavar='LR',
help='learning rate (absolute lr)')
parser.add_argument('--blr', type=float, default=1e-4, metavar='LR',
help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0')
parser.add_argument('--lr_schedule', type=str, default='constant',
help='learning rate schedule')
parser.add_argument('--warmup_epochs', type=int, default=100, metavar='N',
help='epochs to warmup LR')
parser.add_argument('--ema_rate', default=0.9999, type=float)
# FAR params
parser.add_argument('--mask_ratio_min', type=float, default=0.7,
help='Minimum mask ratio')
parser.add_argument('--grad_clip', type=float, default=3.0,
help='Gradient clip')
parser.add_argument('--attn_dropout', type=float, default=0.1,
help='attention dropout')
parser.add_argument('--proj_dropout', type=float, default=0.1,
help='projection dropout')
parser.add_argument('--buffer_size', type=int, default=64)
parser.add_argument('--loss_weight', action='store_true', help='adopt uneven loss weight.')
# Diffusion Loss params
parser.add_argument('--diffloss_d', type=int, default=12)
parser.add_argument('--diffloss_w', type=int, default=1536)
parser.add_argument('--num_sampling_steps', type=str, default="100")
parser.add_argument('--diffusion_batch_mul', type=int, default=1)
parser.add_argument('--temperature', default=1.0, type=float, help='diffusion loss sampling temperature')
# Dataset parameters
parser.add_argument('--data_path', default='./data/imagenet', type=str,
help='dataset path')
parser.add_argument('--class_num', default=1000, type=int)
# Parameters for image-denoising finetuning
parser.add_argument('--lq_list_path', default='', type=str,
help='Path to the list of low quality images')
parser.add_argument('--gt_list_path', default='', type=str,
help='Path to the list of ground truth images')
parser.add_argument('--finetune', action='store_true',
help='Enable finetuning for image denoising')
parser.add_argument('--pretrained_path', default='./pretrained_models/far/far_large/checkpoint-last.pth',
type=str, help='Path to pretrained model for finetuning')
# Parameters for single-image inference
parser.add_argument('--inference', action='store_true',
help='Run inference on a single image')
parser.add_argument('--input_image', default='./data/test.png', type=str,
help='Path to the input image for inference')
parser.add_argument('--output_image', default='./output_dir_denoise/result.png', type=str,
help='Path to save the output image')
parser.add_argument('--use_ema', action='store_true',
help='Use EMA parameters for inference')
parser.add_argument('--output_dir', default='./output_dir',
help='path where to save, empty for no saving')
parser.add_argument('--log_dir', default='./output_dir',
help='path where to tensorboard log')
parser.add_argument('--device', default='cuda',
help='device to use for training / testing')
parser.add_argument('--seed', default=1, type=int)
parser.add_argument('--resume', default='',
help='resume from checkpoint')
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
help='start epoch')
parser.add_argument('--num_workers', default=8, type=int)
parser.add_argument('--pin_mem', action='store_true',
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
parser.set_defaults(pin_mem=True)
# distributed training parameters
parser.add_argument('--world_size', default=1, type=int,
help='number of distributed processes')
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--dist_on_itp', action='store_true')
parser.add_argument('--dist_url', default='env://',
help='url used to set up distributed training')
# caching latents
parser.add_argument('--use_cached', action='store_true', dest='use_cached',
help='Use cached latents')
parser.set_defaults(use_cached=False)
parser.add_argument('--cached_path', default='', help='path to cached latents')
# Parameters for evaluating the denoising model
parser.add_argument('--evaluate_denoising', action='store_true',
help='Evaluate denoising performance on test set')
parser.add_argument('--save_samples', action='store_true',
help='Save sample images during evaluation')
return parser
def main(args):
misc.init_distributed_mode(args)
print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
print("{}".format(args).replace(', ', ',\n'))
device = torch.device(args.device)
# fix the seed for reproducibility
seed = args.seed + misc.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
cudnn.benchmark = True
num_tasks = misc.get_world_size()
global_rank = misc.get_rank()
if global_rank == 0 and args.log_dir is not None and not args.inference:
os.makedirs(args.log_dir, exist_ok=True)
log_writer = SummaryWriter(log_dir=args.log_dir)
else:
log_writer = None
### ImageNet VAE
vae = AutoencoderKL(embed_dim=args.vae_embed_dim, ch_mult=(1, 1, 2, 2, 4), ckpt_path=args.vae_path).cuda().eval()
for param in vae.parameters():
param.requires_grad = False
model = far.__dict__[args.model](
img_size=args.img_size,
vae_stride=args.vae_stride,
patch_size=args.patch_size,
vae_embed_dim=args.vae_embed_dim,
mask=args.mask,
mask_ratio_min=args.mask_ratio_min,
label_drop_prob=args.label_drop_prob,
class_num=args.class_num,
attn_dropout=args.attn_dropout,
proj_dropout=args.proj_dropout,
buffer_size=args.buffer_size,
diffloss_d=args.diffloss_d,
diffloss_w=args.diffloss_w,
num_sampling_steps=args.num_sampling_steps,
diffusion_batch_mul=args.diffusion_batch_mul,
)
print("Model = %s" % str(model))
# following timm: set wd as 0 for bias and norm layers
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Number of trainable parameters: {}M".format(n_params / 1e6))
model.to(device)
model_without_ddp = model
eff_batch_size = args.batch_size * misc.get_world_size()
if args.lr is None: # only base_lr is specified
args.lr = args.blr * eff_batch_size / 256
print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
print("actual lr: %.2e" % args.lr)
print("effective batch size: %d" % eff_batch_size)
if args.distributed and not args.inference:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
model_without_ddp = model.module
# no weight decay on bias, norm layers, and diffloss MLP
param_groups = misc.add_weight_decay(model_without_ddp, args.weight_decay)
optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
print(optimizer)
loss_scaler = NativeScaler()
### resume training
if args.finetune or args.inference:
# Load weights from a pretrained or previously finetuned checkpoint
checkpoint_path = args.pretrained_path
print(f"Loading model from {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location='cpu')
model_without_ddp.load_state_dict(checkpoint['model'])
model_params = list(model_without_ddp.parameters())
ema_state_dict = checkpoint['model_ema']
ema_params = [ema_state_dict[name].cuda() for name, _ in model_without_ddp.named_parameters()]
print("Loaded model for inference/finetuning")
# In inference mode, optionally load the EMA parameters
if args.inference and args.use_ema:
print("Using EMA parameters for inference")
ema_state_dict = {name: ema_params[i] for i, (name, _) in enumerate(model_without_ddp.named_parameters())}
model_without_ddp.load_state_dict(ema_state_dict)
elif args.resume and os.path.exists(os.path.join(args.resume, "checkpoint-last.pth")):
checkpoint = torch.load(os.path.join(args.resume, "checkpoint-last.pth"), map_location='cpu')
model_without_ddp.load_state_dict(checkpoint['model'])
model_params = list(model_without_ddp.parameters())
ema_state_dict = checkpoint['model_ema']
ema_params = [ema_state_dict[name].cuda() for name, _ in model_without_ddp.named_parameters()]
print("Resume checkpoint %s" % args.resume)
if 'optimizer' in checkpoint and 'epoch' in checkpoint:
optimizer.load_state_dict(checkpoint['optimizer'])
args.start_epoch = checkpoint['epoch'] + 1
if 'scaler' in checkpoint:
loss_scaler.load_state_dict(checkpoint['scaler'])
print("With optim & sched!")
del checkpoint
else:
model_params = list(model_without_ddp.parameters())
ema_params = copy.deepcopy(model_params)
print("Training from scratch")
# Single-image inference
if args.inference:
inference_single_image(model_without_ddp, vae, args.input_image, args.output_image, device, args)
return
# evaluate FID and IS
if args.evaluate:
torch.cuda.empty_cache()
evaluate(model_without_ddp, vae, ema_params, args, 0, batch_size=args.eval_bsz, log_writer=log_writer, cfg=args.cfg, use_ema=True)
return
# Evaluate denoising performance
if args.evaluate_denoising:
print("Evaluating the denoising model on the test set...")
# Optionally evaluate with the EMA parameters
if args.use_ema:
print("Using EMA parameters for evaluation")
ema_state_dict = {name: ema_params[i] for i, (name, _) in enumerate(model_without_ddp.named_parameters())}
model_without_ddp.load_state_dict(ema_state_dict)
# Run the evaluation
evaluate_denoising(model_without_ddp, vae, args, device)
return
# augmentation following DiT and ADM
transform_train = transforms.Compose([
transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.img_size)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
# Choose the dataset depending on whether we are finetuning for denoising
if args.finetune:
# Use the custom denoising dataset
dataset_train = DenoiseDataset(args.lq_list_path, args.gt_list_path, transform=transform_train, img_size=args.img_size)
print(f"Loaded denoising dataset with {len(dataset_train)} image pairs")
elif args.use_cached:
dataset_train = CachedFolder(args.cached_path)
else:
dataset_train = datasets.ImageFolder(os.path.join(args.data_path, 'train'), transform=transform_train)
print(dataset_train)
sampler_train = torch.utils.data.DistributedSampler(
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
)
print("Sampler_train = %s" % str(sampler_train))
data_loader_train = torch.utils.data.DataLoader(
dataset_train, sampler=sampler_train,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=args.pin_mem,
drop_last=True,
)
# Fixed samples for visualization, fetched only once
fixed_vis_samples = None
# training
print(f"Start training for {args.epochs} epochs")
start_time = time.time()
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
data_loader_train.sampler.set_epoch(epoch)
train_stats = train_one_epoch(
model, vae,
model_params, ema_params,
data_loader_train,
optimizer, device, epoch, loss_scaler,
log_writer=log_writer,
args=args
)
# After each epoch, generate example images to visualize finetuning progress
if args.finetune and misc.is_main_process() and (epoch % 5 == 0 or epoch + 1 == args.epochs):
# Directory for the visualization results
vis_dir = os.path.join(args.output_dir, 'visualizations')
os.makedirs(vis_dir, exist_ok=True)
print(f"Generating sample images for epoch {epoch}...")
# Pull a few image pairs from the training data for visualization
model_eval = model_without_ddp if args.distributed else model
model_eval.eval()
# Fetch the fixed sample images
with torch.no_grad():
# On the first pass, fetch and cache the fixed samples
if fixed_vis_samples is None:
try:
# Grab one batch for visualization
lq_samples, gt_samples = next(iter(data_loader_train))
lq_samples = lq_samples[:4].to(device) # keep only the first 4 samples
gt_samples = gt_samples[:4].to(device)
# Cache them for later epochs
fixed_vis_samples = (lq_samples.clone(), gt_samples.clone())
print("Cached fixed samples for visualization")
# Print the paths of the four images
print("\n===== Paths of the four visualization images =====")
for i in range(4):
# Look up image paths in the dataset (note: these are the first entries of the lists, not necessarily the sampled batch)
if hasattr(data_loader_train.dataset, 'lq_paths') and hasattr(data_loader_train.dataset, 'gt_paths'):
lq_path = data_loader_train.dataset.lq_paths[i]
gt_path = data_loader_train.dataset.gt_paths[i]
print(f"Image {i+1}:")
print(f" LQ path: {lq_path}")
print(f" GT path: {gt_path}")
else:
print(f"Image {i+1}: path information unavailable")
print("=====================================\n")
except Exception as e:
print(f"Error while fetching fixed samples: {e}")
# Skip this round of visualization on failure
model_eval.train()
continue
# Use the cached fixed samples
if fixed_vis_samples is not None:
try:
lq_samples, gt_samples = fixed_vis_samples
# Encode to the latent space
lq_posterior = vae.encode(lq_samples)
lq_latents = lq_posterior.sample().mul_(0.2325)
# Generate the enhanced results with the finetuned model
enhanced_latents = model_eval.denoise(lq_latents)
enhanced_images = vae.decode(enhanced_latents / 0.2325)
# Build a visualization grid: original LQ images, enhanced images, GT images
from torchvision.utils import make_grid, save_image
# Map the images from [-1, 1] to [0, 1] for saving
lq_vis = (lq_samples + 1) / 2
enhanced_vis = (enhanced_images + 1) / 2
gt_vis = (gt_samples + 1) / 2
# Concatenate along the batch dimension
comparison = torch.cat([lq_vis, enhanced_vis, gt_vis], dim=0)
grid = make_grid(comparison, nrow=4, padding=2, normalize=False)
# Save the grid image
save_path = os.path.join(vis_dir, f'comparison_epoch_{epoch}.png')
save_image(grid, save_path)
print(f"Sample images saved to: {save_path}")
except Exception as e:
print(f"Error while generating sample images: {e}")
model_eval.train()
# save checkpoint
if epoch % args.save_last_freq == 0 or epoch + 1 == args.epochs:
misc.save_model(args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
loss_scaler=loss_scaler, epoch=epoch, ema_params=ema_params)
# online evaluation
if args.online_eval and (epoch % args.eval_freq == 0 or epoch + 1 == args.epochs):
torch.cuda.empty_cache()
evaluate(model_without_ddp, vae, ema_params, args, epoch, batch_size=args.eval_bsz, log_writer=log_writer, cfg=1.0, use_ema=True)
if not (args.cfg == 1.0 or args.cfg == 0.0):
evaluate(model_without_ddp, vae, ema_params, args, epoch, batch_size=args.eval_bsz // 2, log_writer=log_writer, cfg=args.cfg, use_ema=True)
torch.cuda.empty_cache()
if misc.is_main_process():
if log_writer is not None:
log_writer.flush()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
if __name__ == '__main__':
args = get_args_parser()
args = args.parse_args()
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
args.log_dir = args.output_dir
main(args)
inference.sh
### ---------------------------- FAR_Base ---------------------------- ###
# torchrun --nnodes=1 --nproc_per_node=1 main_far.py \
# --img_size 256 --vae_path pretrained/vae_mar/kl16.ckpt --vae_embed_dim 16 --vae_stride 16 --patch_size 1 \
# --model far_base --diffloss_d 6 --diffloss_w 1024 \
# --eval_bsz 32 --num_images 1000 \
# --num_iter 10 --num_sampling_steps 100 --cfg 3.0 --cfg_schedule linear --temperature 1.0 \
# --output_dir pretrained_models/far/far_base \
# --resume pretrained_models/far/far_base \
# --data_path ${IMAGENET_PATH} --evaluate
### ---------------------------- FAR_Large ---------------------------- ###
torchrun --nnodes=1 --nproc_per_node=1 main_far.py \
--img_size 256 --vae_path pretrained/vae_mar/kl16.ckpt --vae_embed_dim 16 --vae_stride 16 --patch_size 1 \
--model far_large --diffloss_d 3 --diffloss_w 1024 \
--eval_bsz 32 --num_images 1000 \
--num_iter 10 --num_sampling_steps 100 --cfg 3.0 --cfg_schedule linear --temperature 1.0 \
--output_dir pretrained_models/far/far_large \
--resume pretrained_models/far/far_large \
--data_path ${IMAGENET_PATH} --evaluate
### ---------------------------- FAR_Huge ---------------------------- ###
# torchrun --nnodes=1 --nproc_per_node=1 main_far.py \
# --img_size 256 --vae_path pretrained/vae_mar/kl16.ckpt --vae_embed_dim 16 --vae_stride 16 --patch_size 1 \
# --model far_huge --diffloss_d 3 --diffloss_w 1024 \
# --eval_bsz 32 --num_images 1000 \
# --num_iter 10 --num_sampling_steps 100 --cfg 3.0 --cfg_schedule linear --temperature 1.0 \
# --output_dir pretrained_models/far/far_huge \
# --resume pretrained_models/far/far_huge \
# --data_path ${IMAGENET_PATH} --evaluate
### ---------------------------- FAR_Large Denoising ---------------------------- ###
# torchrun --nnodes=1 --nproc_per_node=1 main_far.py \
# --img_size 256 --vae_path pretrained/vae_mar/kl16.ckpt --vae_embed_dim 16 --vae_stride 16 --patch_size 1 \
# --model far_large --diffloss_d 3 --diffloss_w 1024 \
# --eval_bsz 32 --num_images 10 \
# --num_iter 10 --num_sampling_steps 100 --cfg 1.0 --cfg_schedule linear --temperature 0.8 \
# --output_dir ./output_dir_denoise \
# --lq_list_path /data/vjuicefs_ai_camera_jgroup_research/public_data/11164225/datasets/Magic_Dictionary/test_list/test.LQ.list \
# --gt_list_path /data/vjuicefs_ai_camera_jgroup_research/public_data/11164225/datasets/Magic_Dictionary/test_list/test.GT.list \
# --finetune --pretrained_path ./output_dir_denoise/checkpoint-last.pth \
# --evaluate
### ---------------------------- FAR_Large Denoising Evaluation ---------------------------- ###
# Evaluate the finetuned denoising model
torchrun --nnodes=1 --nproc_per_node=1 main_far.py \
--img_size 256 --vae_path pretrained_models/vae/kl16.ckpt --vae_embed_dim 16 --vae_stride 16 --patch_size 1 \
--model far_large --diffloss_d 3 --diffloss_w 1024 \
--lq_list_path ./data/test/testLQ.list \
--gt_list_path ./data/test/testGT.list \
--output_dir ./output_dir_denoise \
--finetune --pretrained_path ./output_dir_denoise/checkpoint-last.pth \
--evaluate_denoising --save_samples --use_ema
train.py
### ---------------------------- FAR_Base ---------------------------- ###
# torchrun --nproc_per_node=8 --nnodes=4 --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \
# main_far.py \
# --img_size 256 --vae_path pretrained_models/vae/kl16.ckpt --vae_embed_dim 16 --vae_stride 16 --patch_size 1 \
# --model far_base --diffloss_d 6 --diffloss_w 1024 \
# --epochs 400 --warmup_epochs 100 --batch_size 64 --blr 1.0e-4 --diffusion_batch_mul 4 \
# --output_dir ${OUTPUT_DIR} --resume ${OUTPUT_DIR} \
# --data_path ${IMAGENET_PATH}
### ---------------------------- FAR_Large ---------------------------- ###
# torchrun --nproc_per_node=8 --nnodes=4 --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \
# main_far.py \
# --img_size 256 --vae_path pretrained_models/vae/kl16.ckpt --vae_embed_dim 16 --vae_stride 16 --patch_size 1 \
# --model far_large --diffloss_d 3 --diffloss_w 1024 \
# --epochs 400 --warmup_epochs 100 --batch_size 64 --blr 1.0e-4 --diffusion_batch_mul 4 \
# --output_dir ${OUTPUT_DIR} --resume ${OUTPUT_DIR} \
# --data_path ${IMAGENET_PATH}
### ---------------------------- FAR_Huge ---------------------------- ###
# torchrun --nproc_per_node=8 --nnodes=4 --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \
# main_far.py \
# --img_size 256 --vae_path pretrained_models/vae/kl16.ckpt --vae_embed_dim 16 --vae_stride 16 --patch_size 1 \
# --model far_huge --diffloss_d 3 --diffloss_w 1024 \
# --epochs 400 --warmup_epochs 100 --batch_size 64 --blr 1.0e-4 --diffusion_batch_mul 4 \
# --output_dir ${OUTPUT_DIR} --resume ${OUTPUT_DIR} \
# --data_path ${IMAGENET_PATH}
### ---------------------------- FAR_Large Finetune for Denoising ---------------------------- ###
# Set environment variables
export NODE_RANK=0
export MASTER_ADDR="127.0.0.1"
export MASTER_PORT=29500
# Run the finetuning command
torchrun --nproc_per_node=4 --nnodes=1 --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \
main_far.py \
--img_size 256 --vae_path pretrained_models/vae/kl16.ckpt --vae_embed_dim 16 --vae_stride 16 --patch_size 1 \
--model far_large --diffloss_d 3 --diffloss_w 1024 \
--finetune --pretrained_path ./pretrained_models/far/far_large/checkpoint-last.pth \
--lq_list_path /data/vjuicefs_ai_camera_jgroup_research/public_data/11164225/datasets/Magic_Dictionary/training_list/20250105.6789degq_ShMfwowCaGq3_13_Mfww_r6789_G35l0105.6_6_1.LQ.list \
--gt_list_path /data/vjuicefs_ai_camera_jgroup_research/public_data/11164225/datasets/Magic_Dictionary/training_list/20250105.6789degq_ShMfwowCaGq3_13_Mfww_r6789_G35l0105.6_6_1.GT.list \
--epochs 100 --warmup_epochs 10 --batch_size 32 --blr 1.0e-5 --diffusion_batch_mul 4 \
--output_dir ./output_dir_denoise