sample.sh
#!/bin/bash
# Denoise a single test image with a trained model
torchrun --nproc_per_node=1 main_far_denoise.py \
--img_size 256 --vae_path pretrained_models/vae/kl16.ckpt --vae_embed_dim 16 --vae_stride 16 --patch_size 1 \
--diffloss_d 3 --diffloss_w 1024 \
--lq_freq_min 14 --lq_freq_max 15 --hq_freq 16 \
--num_steps 3 --num_sampling_steps 100 --temperature 1.0 --cfg 2.0 \
--test_image ./data/test.png \
--output_dir ./output_denoise --resume ./output_denoise \
--evaluate
train.sh
#!/bin/bash
# Train the FAR model for image denoising / super-resolution
torchrun --nproc_per_node=8 main_far_denoise.py \
--img_size 256 --vae_path pretrained_models/vae/kl16.ckpt --vae_embed_dim 16 --vae_stride 16 --patch_size 1 \
--diffloss_d 3 --diffloss_w 1024 \
--lq_freq_min 14 --lq_freq_max 15 --hq_freq 16 \
--epochs 400 --warmup_epochs 100 --batch_size 32 --blr 1.0e-4 --diffusion_batch_mul 4 \
--recon_loss_weight 1.0 \
--lq_list_path /data/LQ.list --gt_list_path /data/GT.list \
--output_dir ./output_denoise --resume ./output_denoise
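With 8 GPUs at a per-GPU batch size of 32, the effective batch size is 8 × 32 = 256, so the LR scaling in main_far_denoise.py (absolute_lr = blr × effective_batch_size / 256) leaves the actual learning rate at the base value of 1.0e-4.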
engine_far_denoise.py
import math
import sys
from typing import Iterable
import torch
import torch.nn.functional as F
import util.misc as misc
import util.lr_sched as lr_sched
from models.vae import DiagonalGaussianDistribution
import os
import copy
import time
import numpy as np
import torchvision
from torchvision.utils import make_grid
from typing import Optional
from PIL import Image
import torchvision.transforms as transforms
def save_image(images: torch.Tensor, nrow: int = 8, show: bool = True, path: Optional[str] = None, format: Optional[str] = None, to_grayscale: bool = False, **kwargs):
images = images * 0.5 + 0.5
grid = make_grid(images, nrow=nrow, **kwargs) # (channels, height, width)
grid = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
im = Image.fromarray(grid)
if to_grayscale:
im = im.convert(mode="L")
if path is not None:
im.save(path, format=format)
if show:
im.show()
return grid
def update_ema(target_params, source_params, rate=0.99):
"""
Update target parameters to be closer to those of source parameters using
an exponential moving average.
"""
for targ, src in zip(target_params, source_params):
targ.detach().mul_(rate).add_(src, alpha=1 - rate)
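For reference, update_ema is the standard in-place exponential moving average, targ ← rate · targ + (1 − rate) · src. A minimal sketch with toy one-element parameter lists (relies on the definitions above; not part of the training loop):

src = [torch.tensor([1.0])]   # current model parameter
targ = [torch.tensor([0.0])]  # EMA shadow parameter
update_ema(targ, src, rate=0.99)
print(targ[0])                # tensor([0.0100]) = 0.99 * 0.0 + 0.01 * 1.0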
def train_one_epoch(model, vae,
model_params, ema_params,
data_loader: Iterable, optimizer: torch.optim.Optimizer,
device: torch.device, epoch: int, loss_scaler,
log_writer=None,
args=None):
model.train(True)
metric_logger = misc.MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
header = 'Epoch: [{}]'.format(epoch)
print_freq = 10
optimizer.zero_grad()
if log_writer is not None:
print('log_dir: {}'.format(log_writer.log_dir))
for data_iter_step, (lq_images, gt_images) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        # Adjust the learning rate per iteration
        lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
lq_images = lq_images.to(device, non_blocking=True)
gt_images = gt_images.to(device, non_blocking=True)
        with torch.no_grad():
            # Encode both images into the VAE latent space (0.2325 is the latent scaling factor)
            lq_latents = vae.encode(lq_images).sample().mul_(0.2325)
            gt_latents = vae.encode(gt_images).sample().mul_(0.2325)
        # Forward pass under mixed precision
        with torch.cuda.amp.autocast():
            loss = model(lq_latents, gt_latents, loss_weight=args.loss_weight)
loss_value = loss.item()
if not math.isfinite(loss_value):
print("Loss is {}, stopping training".format(loss_value))
sys.exit(1)
        # Backward pass; loss_scaler handles AMP scaling and gradient clipping
        loss_scaler(loss, optimizer, clip_grad=args.grad_clip, parameters=model.parameters(), update_grad=True)
        optimizer.zero_grad()
        torch.cuda.synchronize()
        # Update the EMA parameters
        update_ema(ema_params, model_params, rate=args.ema_rate)
        # Log loss and learning rate
        metric_logger.update(loss=loss_value)
        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)
loss_value_reduce = misc.all_reduce_mean(loss_value)
if log_writer is not None:
epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)
log_writer.add_scalar('lr', lr, epoch_1000x)
    # Gather stats across all processes
metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
def evaluate(model_without_ddp, vae, ema_params, args, epoch, log_writer=None, use_ema=True):
"""评估模型性能"""
model_without_ddp.eval()
save_dir = os.path.join(args.output_dir, f"eval_epoch{epoch}")
os.makedirs(save_dir, exist_ok=True)
    # Swap in the EMA parameters
if use_ema:
model_state_dict = copy.deepcopy(model_without_ddp.state_dict())
ema_state_dict = copy.deepcopy(model_without_ddp.state_dict())
for i, (name, _value) in enumerate(model_without_ddp.named_parameters()):
assert name in ema_state_dict
ema_state_dict[name] = ema_params[i]
model_without_ddp.load_state_dict(ema_state_dict)
    # Build the test dataset
from datasets.denoise_pair_dataset import DenoisePairDataset
test_dataset = DenoisePairDataset(
lq_list_path=args.lq_list_path,
gt_list_path=args.gt_list_path,
img_size=args.img_size
)
    # Randomly pick a few samples to evaluate
indices = np.random.choice(len(test_dataset), min(10, len(test_dataset)), replace=False)
total_psnr = 0.0
total_samples = 0
for idx in indices:
lq_img, gt_img = test_dataset[idx]
lq_img = lq_img.unsqueeze(0).to(next(model_without_ddp.parameters()).device)
gt_img = gt_img.unsqueeze(0).to(next(model_without_ddp.parameters()).device)
        with torch.no_grad():
            # VAE encoding
            lq_latent = vae.encode(lq_img).sample().mul_(0.2325)
            # Denoise
            enhanced_latent = model_without_ddp.denoise_image(
                lq_latent,
                num_steps=args.num_steps,
                temperature=args.temperature,
                cfg=args.cfg
            )
            # VAE decoding
            enhanced_img = vae.decode(enhanced_latent / 0.2325)
            # PSNR (peak signal-to-noise ratio); images are rescaled to [0, 1], so MAX = 1
            mse = F.mse_loss(enhanced_img * 0.5 + 0.5, gt_img * 0.5 + 0.5).item()
            psnr = -10 * math.log10(mse)
            total_psnr += psnr
            total_samples += 1
            # Save a comparison grid for the first sample only
            if idx == indices[0]:
                torchvision.utils.save_image(
                    torch.cat([lq_img, enhanced_img, gt_img], dim=0),
                    os.path.join(save_dir, "sample.png"),
                    nrow=3,
                    normalize=True,
                    value_range=(-1, 1),  # named `range` in older torchvision versions
                )
    # Restore the original (non-EMA) weights
if use_ema:
model_without_ddp.load_state_dict(model_state_dict)
    # Average PSNR
avg_psnr = total_psnr / total_samples if total_samples > 0 else 0
if log_writer is not None:
log_writer.add_scalar('eval_psnr', avg_psnr, epoch)
print(f"Evaluation at epoch {epoch}: Average PSNR = {avg_psnr:.2f}")
return avg_psnr
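A note on the metric: evaluate computes PSNR as -10 · log10(MSE), which is the standard PSNR = 10 · log10(MAX² / MSE) with MAX = 1, valid because both images are mapped from [-1, 1] back to [0, 1] before taking the MSE. A quick standalone check (toy tensors, values illustrative):

import math
import torch
import torch.nn.functional as F

a = torch.full((1, 3, 8, 8), 0.5)
b = torch.full((1, 3, 8, 8), 0.6)  # uniform error of 0.1 -> MSE = 0.01
mse = F.mse_loss(a, b).item()
print(-10 * math.log10(mse))       # ~20.0 dB, i.e. 10 * log10(1 / 0.01)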
main_far_denoise.py
import argparse
import datetime
import numpy as np
import os
import time
from pathlib import Path
import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
from util.misc import NativeScalerWithGradNormCount as NativeScaler
import util.misc as misc
import util.lr_sched as lr_sched
from models.vae import AutoencoderKL
from models.far_denoise import create_far_denoise
from engine_far_denoise import train_one_epoch, evaluate
from datasets.denoise_pair_dataset import DenoisePairDataset
import copy
import torchvision.transforms as transforms
from PIL import Image
import torchvision
def get_args_parser():
parser = argparse.ArgumentParser('FAR for Image Denoising', add_help=False)
parser.add_argument('--batch_size', default=16, type=int,
                        help='Batch size per GPU (effective batch size is batch_size * # GPUs)')
parser.add_argument('--epochs', default=400, type=int)
# VAE parameters
parser.add_argument('--img_size', default=256, type=int,
help='images input size')
parser.add_argument('--vae_path', default="pretrained/vae_mar/kl16.ckpt", type=str,
help='VAE model path')
parser.add_argument('--vae_embed_dim', default=16, type=int,
help='vae output embedding dimension')
parser.add_argument('--vae_stride', default=16, type=int,
help='tokenizer stride, default use KL16')
parser.add_argument('--patch_size', default=1, type=int,
help='number of tokens to group as a patch.')
# Denoising parameters
parser.add_argument('--lq_freq_min', default=14, type=int,
help='Minimum frequency level for low-quality images')
parser.add_argument('--lq_freq_max', default=15, type=int,
help='Maximum frequency level for low-quality images')
parser.add_argument('--hq_freq', default=16, type=int,
help='Frequency level for high-quality images')
parser.add_argument('--recon_loss_weight', default=1.0, type=float,
help='Weight for reconstruction loss')
# Generation parameters
parser.add_argument('--num_steps', default=3, type=int,
help='number of frequency steps for denoising')
parser.add_argument('--test_image', default='', type=str,
help='Path to a single test image')
parser.add_argument('--num_sampling_steps', default="100", type=str,
help='Number of diffusion sampling steps')
parser.add_argument('--cfg', default=1.0, type=float, help="classifier-free guidance")
parser.add_argument('--cfg_schedule', default="linear", type=str)
parser.add_argument('--label_drop_prob', default=0.1, type=float)
parser.add_argument('--eval_freq', type=int, default=40, help='evaluation frequency')
parser.add_argument('--save_last_freq', type=int, default=10, help='save last frequency')
parser.add_argument('--online_eval', action='store_true')
parser.add_argument('--evaluate', action='store_true')
parser.add_argument('--mask', action='store_true')
parser.add_argument('--eval_bsz', type=int, default=4, help='evaluation batch size')
# Optimizer parameters
parser.add_argument('--weight_decay', type=float, default=0.02,
help='weight decay (default: 0.02)')
parser.add_argument('--lr', type=float, default=None, metavar='LR',
help='learning rate (absolute lr)')
parser.add_argument('--blr', type=float, default=1e-4, metavar='LR',
help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0')
parser.add_argument('--lr_schedule', type=str, default='constant',
help='learning rate schedule')
parser.add_argument('--warmup_epochs', type=int, default=100, metavar='N',
help='epochs to warmup LR')
parser.add_argument('--ema_rate', default=0.9999, type=float)
# FAR params
parser.add_argument('--mask_ratio_min', type=float, default=0.7,
help='Minimum mask ratio')
parser.add_argument('--grad_clip', type=float, default=3.0,
help='Gradient clip')
parser.add_argument('--attn_dropout', type=float, default=0.1,
help='attention dropout')
parser.add_argument('--proj_dropout', type=float, default=0.1,
help='projection dropout')
parser.add_argument('--buffer_size', type=int, default=64)
parser.add_argument('--loss_weight', action='store_true', help='adopt uneven loss weight.')
# Diffusion Loss params
parser.add_argument('--diffloss_d', type=int, default=3)
parser.add_argument('--diffloss_w', type=int, default=1024)
parser.add_argument('--diffusion_batch_mul', type=int, default=4)
parser.add_argument('--temperature', default=1.0, type=float, help='diffusion loss sampling temperature')
# Dataset parameters
parser.add_argument('--lq_list_path', default='/data/LQ.list', type=str,
help='path to low-quality images list')
parser.add_argument('--gt_list_path', default='/data/GT.list', type=str,
help='path to high-quality images list')
parser.add_argument('--output_dir', default='./output_denoise',
help='path where to save, empty for no saving')
parser.add_argument('--log_dir', default='./output_denoise',
help='path where to tensorboard log')
parser.add_argument('--device', default='cuda',
help='device to use for training / testing')
parser.add_argument('--seed', default=1, type=int)
parser.add_argument('--resume', default='',
help='resume from checkpoint')
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
help='start epoch')
parser.add_argument('--num_workers', default=8, type=int)
parser.add_argument('--pin_mem', action='store_true',
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
parser.set_defaults(pin_mem=True)
# distributed training parameters
parser.add_argument('--world_size', default=1, type=int,
help='number of distributed processes')
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--dist_on_itp', action='store_true')
parser.add_argument('--dist_url', default='env://',
help='url used to set up distributed training')
return parser
def main(args):
misc.init_distributed_mode(args)
print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
print("{}".format(args).replace(', ', ',\n'))
device = torch.device(args.device)
# fix the seed for reproducibility
seed = args.seed + misc.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
cudnn.benchmark = True
num_tasks = misc.get_world_size()
global_rank = misc.get_rank()
if global_rank == 0 and args.log_dir is not None:
os.makedirs(args.log_dir, exist_ok=True)
log_writer = SummaryWriter(log_dir=args.log_dir)
else:
log_writer = None
    # Load the pretrained VAE and freeze it
vae = AutoencoderKL(embed_dim=args.vae_embed_dim, ch_mult=(1, 1, 2, 2, 4), ckpt_path=args.vae_path).cuda().eval()
for param in vae.parameters():
param.requires_grad = False
    # Create the FAR denoising model (single configuration)
model = create_far_denoise(
img_size=args.img_size,
vae_stride=args.vae_stride,
patch_size=args.patch_size,
vae_embed_dim=args.vae_embed_dim,
mask=args.mask,
mask_ratio_min=args.mask_ratio_min,
label_drop_prob=args.label_drop_prob,
        class_num=1,  # class conditioning is unused
attn_dropout=args.attn_dropout,
proj_dropout=args.proj_dropout,
buffer_size=args.buffer_size,
diffloss_d=args.diffloss_d,
diffloss_w=args.diffloss_w,
num_sampling_steps=args.num_sampling_steps,
diffusion_batch_mul=args.diffusion_batch_mul,
lq_freq_min=args.lq_freq_min,
lq_freq_max=args.lq_freq_max,
hq_freq=args.hq_freq,
recon_loss_weight=args.recon_loss_weight,
)
print("Model = %s" % str(model))
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Number of trainable parameters: {}M".format(n_params / 1e6))
model.to(device)
model_without_ddp = model
eff_batch_size = args.batch_size * misc.get_world_size()
if args.lr is None: # only base_lr is specified
args.lr = args.blr * eff_batch_size / 256
print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
print("actual lr: %.2e" % args.lr)
print("effective batch size: %d" % eff_batch_size)
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model_without_ddp = model.module
    # Optimizer setup
param_groups = misc.add_weight_decay(model_without_ddp, args.weight_decay)
optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
print(optimizer)
loss_scaler = NativeScaler()
    # Resume training if a checkpoint exists
if args.resume and os.path.exists(os.path.join(args.resume, "checkpoint-last.pth")):
checkpoint = torch.load(os.path.join(args.resume, "checkpoint-last.pth"), map_location='cpu')
model_without_ddp.load_state_dict(checkpoint['model'])
model_params = list(model_without_ddp.parameters())
ema_state_dict = checkpoint['model_ema']
ema_params = [ema_state_dict[name].cuda() for name, _ in model_without_ddp.named_parameters()]
print("Resume checkpoint %s" % args.resume)
if 'optimizer' in checkpoint and 'epoch' in checkpoint:
optimizer.load_state_dict(checkpoint['optimizer'])
args.start_epoch = checkpoint['epoch'] + 1
if 'scaler' in checkpoint:
loss_scaler.load_state_dict(checkpoint['scaler'])
print("With optim & sched!")
del checkpoint
else:
model_params = list(model_without_ddp.parameters())
ema_params = copy.deepcopy(model_params)
print("Training from scratch")
    # Single-image test mode
if args.evaluate and args.test_image:
test_image = Image.open(args.test_image).convert('RGB')
transform = transforms.Compose([
transforms.Resize((args.img_size, args.img_size)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
test_img = transform(test_image).unsqueeze(0).to(device)
        # Use the EMA weights for inference
        if 'ema_params' in locals():
model_state_dict = copy.deepcopy(model_without_ddp.state_dict())
ema_state_dict = copy.deepcopy(model_without_ddp.state_dict())
for i, (name, _value) in enumerate(model_without_ddp.named_parameters()):
assert name in ema_state_dict
ema_state_dict[name] = ema_params[i]
model_without_ddp.load_state_dict(ema_state_dict)
        with torch.no_grad():
            # VAE encoding
            lq_latent = vae.encode(test_img).sample().mul_(0.2325)
            # Denoise
            enhanced_latent = model_without_ddp.denoise_image(
                lq_latent,
                num_steps=args.num_steps,
                temperature=args.temperature,
                cfg=args.cfg
            )
            # VAE decoding
            enhanced_img = vae.decode(enhanced_latent / 0.2325)
        # Save results
        output_dir = os.path.join(args.output_dir, "results")
        os.makedirs(output_dir, exist_ok=True)
        # Restore the original (non-EMA) weights
        if 'model_state_dict' in locals():
            model_without_ddp.load_state_dict(model_state_dict)
        # Save the input, its VAE reconstruction, and the enhanced result side by side
        original_img = vae.decode(lq_latent / 0.2325)
        torchvision.utils.save_image(
            torch.cat([test_img, original_img, enhanced_img], dim=0),
            os.path.join(output_dir, "comparison.png"),
            nrow=3,
            normalize=True,
            value_range=(-1, 1),  # named `range` in older torchvision versions
        )
print(f"Testing completed. Results saved to {output_dir}")
return
    # Build the training dataset and data loader
dataset_train = DenoisePairDataset(
lq_list_path=args.lq_list_path,
gt_list_path=args.gt_list_path,
img_size=args.img_size
)
print(dataset_train)
if args.distributed:
sampler_train = torch.utils.data.DistributedSampler(
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
)
else:
sampler_train = torch.utils.data.RandomSampler(dataset_train)
print("Sampler_train = %s" % str(sampler_train))
data_loader_train = torch.utils.data.DataLoader(
dataset_train, sampler=sampler_train,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=args.pin_mem,
drop_last=True,
)
    # Start training
print(f"Start training for {args.epochs} epochs")
start_time = time.time()
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
data_loader_train.sampler.set_epoch(epoch)
train_one_epoch(
model, vae,
model_params, ema_params,
data_loader_train,
optimizer, device, epoch, loss_scaler,
log_writer=log_writer,
args=args
)
        # Save checkpoint
if epoch % args.save_last_freq == 0 or epoch + 1 == args.epochs:
misc.save_model(args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
loss_scaler=loss_scaler, epoch=epoch, ema_params=ema_params)
        # Online evaluation
if args.online_eval and (epoch % args.eval_freq == 0 or epoch + 1 == args.epochs):
torch.cuda.empty_cache()
evaluate(model_without_ddp, vae, ema_params, args, epoch, log_writer=log_writer, use_ema=True)
torch.cuda.empty_cache()
if misc.is_main_process():
if log_writer is not None:
log_writer.flush()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
if __name__ == '__main__':
args = get_args_parser()
args = args.parse_args()
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
args.log_dir = args.output_dir
main(args)
far_denoise.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import numpy as np
import math
from functools import partial
from models.far import FAR
class FAR_Denoise(FAR):
"""基于FAR的图像去噪与质量提升模型"""
def __init__(self, lq_freq_min=14, lq_freq_max=15, hq_freq=16, recon_loss_weight=1.0, **kwargs):
super().__init__(**kwargs)
        # Frequency parameters
self.lq_freq_min = lq_freq_min
self.lq_freq_max = lq_freq_max
self.hq_freq = hq_freq
self.recon_loss_weight = recon_loss_weight
        # Replace the class embedding with a low-quality image encoder
self.lq_encoder = nn.Sequential(
nn.Linear(self.token_embed_dim, self.encoder_embed_dim),
nn.LayerNorm(self.encoder_embed_dim),
nn.SiLU(),
nn.Linear(self.encoder_embed_dim, self.encoder_embed_dim)
)
        # Initialize weights
self._init_lq_encoder()
def _init_lq_encoder(self):
for m in self.lq_encoder.modules():
if isinstance(m, nn.Linear):
torch.nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
if m.bias is not None:
nn.init.constant_(m.bias, 0)
if m.weight is not None:
nn.init.constant_(m.weight, 1.0)
    def processingpregt_latent(self, imgs, frequency_level=None):
        """Frequency processing with an optional explicit frequency level."""
        B, C, H, W = imgs.shape
        out = torch.zeros_like(imgs)
        # If no frequency level is given, pick one at random
        if frequency_level is None:
            frequency_level = random.randint(self.lq_freq_min, self.lq_freq_max)
        # Build the per-sample frequency-index array
        core_index = torch.full((B,), frequency_level, dtype=torch.half, device=imgs.device)
        # Apply the low-pass degradation: downsample to (f, f), then upsample back
for i in range(B):
if frequency_level == 0:
out[i] = torch.zeros(C, H, W).to(imgs.device)
else:
imgs_resize = F.interpolate(imgs[i].unsqueeze(0), size=(frequency_level, frequency_level), mode='area')
out[i] = F.interpolate(imgs_resize, size=(H, W), mode='bicubic').squeeze(0)
return out, core_index
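The loop above is a simple low-pass filter in latent space: area-downsample to an f×f grid, then bicubic-upsample back to the original size. The same round trip in isolation (toy tensor; the 16×16 grid matches the KL-16 latent for 256px inputs):

import torch
import torch.nn.functional as F

latent = torch.randn(1, 16, 16, 16)                           # (B, C, H, W) toy latent
low = F.interpolate(latent, size=(14, 14), mode='area')       # drop high frequencies
degraded = F.interpolate(low, size=(16, 16), mode='bicubic')  # back to full resolution
print(degraded.shape)                                         # torch.Size([1, 16, 16, 16])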
    def forward(self, lq_imgs, gt_imgs, loss_weight=False):
        """Forward pass on a pair of low-quality / high-quality latents."""
        # Degrade the LQ latents to a random frequency level
        lq_latents, lq_index = self.processingpregt_latent(lq_imgs,
            frequency_level=random.randint(self.lq_freq_min, self.lq_freq_max))
        # Patchify
        lq_patches = self.patchify(lq_latents)
        gt_patches = self.patchify(gt_imgs)
        # LQ feature used as conditioning (replaces the class embedding)
        lq_embedding = self.lq_encoder(lq_patches.mean(dim=1))
        # Encoder-decoder
if self.mask:
orders = self.sample_orders(bsz=lq_patches.size(0))
mask = self.random_masking(lq_patches, orders, lq_index)
x = self.forward_mae_encoder(lq_patches, lq_embedding, mask)
z = self.forward_mae_decoder(x, mask)
else:
x = self.forward_mae_encoder(lq_patches, lq_embedding)
z = self.forward_mae_decoder(x)
        # Diffusion loss
diff_loss = self.forward_loss(z, gt_patches, mask if self.mask else None,
torch.full_like(lq_index, self.hq_freq), loss_weight)
        # Reconstruction loss (L2)
recon_loss = F.mse_loss(self.unpatchify(z), gt_imgs)
        # Total loss
total_loss = diff_loss + self.recon_loss_weight * recon_loss
return total_loss
    def denoise_image(self, lq_img, num_steps=3, temperature=1.0, cfg=1.0):
        """Inference: generate a high-quality latent from a low-quality one."""
        with torch.no_grad():
            # Batch size
            bsz = lq_img.size(0)
            # Preprocess the low-quality input
            lq_latent, lq_index = self.processingpregt_latent(lq_img, frequency_level=self.lq_freq_min)
            lq_patches = self.patchify(lq_latent)
            # LQ conditioning feature
            lq_embedding = self.lq_encoder(lq_patches.mean(dim=1))
            # Start generation from the LQ tokens
            tokens = lq_patches.clone()
            # Progressively raise the frequency level
            freq_steps = list(range(self.lq_freq_min, self.hq_freq + 1))
            if len(freq_steps) > num_steps:
                # If the range has more levels than num_steps, pick evenly spaced frequencies
                freq_steps = [self.lq_freq_min] + [int(f) for f in np.linspace(self.lq_freq_min + 1, self.hq_freq, num_steps - 1)]
            for freq_idx, freq in enumerate(freq_steps):
                # Encode-decode
                x = self.forward_mae_encoder(tokens, lq_embedding)
                z = self.forward_mae_decoder(x)
                # Diffusion sampling over flattened tokens
                B, L, C = z.shape
                z = z.reshape(B * L, -1)
                # Linear CFG and temperature schedules across steps
                cfg_iter = 1 + (cfg - 1) * freq_idx / (len(freq_steps) - 1)
                temperature_iter = 0.9 + (temperature - 0.9) * freq_idx / (len(freq_steps) - 1)
                # Frequency index fed to the diffusion head
                index = torch.tensor([freq]).unsqueeze(1).unsqueeze(-1).repeat(B, L, 1).reshape(B * L, -1).to(z.device).half()
                # Sample new tokens
                tokens_flat = self.diffloss.sample(
                    torch.cat([z, z], dim=0),  # duplicate the input for CFG
                    temperature=temperature_iter,
                    cfg=cfg_iter,
                    index=index
                )
                tokens_flat, _ = tokens_flat.chunk(2, dim=0)  # drop the CFG duplicate
                # Reshape back to image tokens
                tokens = tokens_flat.reshape(bsz, self.seq_len, -1)
                # If not at the final frequency, re-degrade to the next level
                if freq_idx < len(freq_steps) - 1:
                    next_freq = freq_steps[freq_idx + 1]
                    tokens_img = self.unpatchify(tokens)
                    tokens_img = F.interpolate(tokens_img, size=(next_freq, next_freq), mode='area')
                    tokens_img = F.interpolate(tokens_img, size=(16, 16), mode='bicubic')  # 16 = latent grid size for 256px inputs with stride 16
                    tokens = self.patchify(tokens_img)
            # Final latents
            output_latents = self.unpatchify(tokens)
            return output_latents
# Single model configuration (based on FAR-Large)
def create_far_denoise(**kwargs):
model = FAR_Denoise(
encoder_embed_dim=1024, encoder_depth=16, encoder_num_heads=16,
decoder_embed_dim=1024, decoder_depth=16, decoder_num_heads=16,
mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
return model
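For completeness, a hedged sketch of end-to-end inference mirroring the single-image path in main_far_denoise.py; it assumes the omitted create_far_denoise keyword arguments have defaults in the FAR base class and that kl16.ckpt is available locally:

import torch
from models.vae import AutoencoderKL
from models.far_denoise import create_far_denoise

vae = AutoencoderKL(embed_dim=16, ch_mult=(1, 1, 2, 2, 4),
                    ckpt_path='pretrained_models/vae/kl16.ckpt').cuda().eval()
model = create_far_denoise(img_size=256, vae_stride=16, patch_size=1, vae_embed_dim=16,
                           lq_freq_min=14, lq_freq_max=15, hq_freq=16).cuda().eval()

lq = torch.randn(1, 3, 256, 256).cuda()  # stand-in for a normalized LQ image
with torch.no_grad():
    latent = vae.encode(lq).sample().mul_(0.2325)
    enhanced = model.denoise_image(latent, num_steps=3, temperature=1.0, cfg=2.0)
    out = vae.decode(enhanced / 0.2325)  # image-space output, roughly in [-1, 1]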
denoise_pair_dataset.py
import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms
from util.crop import center_crop_arr
class DenoisePairDataset(Dataset):
    """Dataset of paired low-quality / high-quality images."""
    def __init__(self, lq_list_path, gt_list_path, img_size=256, transform=None):
        # Read the LQ and GT image paths
        with open(lq_list_path, 'r') as f:
            self.lq_paths = [line.strip() for line in f.readlines()]
        with open(gt_list_path, 'r') as f:
            self.gt_paths = [line.strip() for line in f.readlines()]
        assert len(self.lq_paths) == len(self.gt_paths), "LQ and GT image counts do not match"
if transform is None:
self.transform = transforms.Compose([
transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, img_size)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
else:
self.transform = transform
def __len__(self):
return len(self.lq_paths)
def __getitem__(self, idx):
        # Load one LQ/GT pair
lq_img = Image.open(self.lq_paths[idx]).convert('RGB')
gt_img = Image.open(self.gt_paths[idx]).convert('RGB')
if self.transform:
lq_img = self.transform(lq_img)
gt_img = self.transform(gt_img)
return lq_img, gt_img
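DenoisePairDataset pairs LQ and GT images purely by line order, so --lq_list_path and --gt_list_path should point to plain-text files with one image path per line; the paths below are illustrative:

# /data/LQ.list (one degraded image per line)
/data/lq/img_0001.png
/data/lq/img_0002.png

# /data/GT.list (matching clean image, same line order)
/data/gt/img_0001.png
/data/gt/img_0002.png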