Samplers and Schedulers: What You Need to Know

In the WebUI, the "Sampling method" you see is a single parameter. In ComfyUI's core sampler node, KSampler, however, it is split into two separate options: Sampler Name and Scheduler.

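To make the split concrete, here is a minimal sketch of a KSampler node as it appears in a ComfyUI API-format workflow; the node IDs and the upstream nodes they point to are hypothetical placeholders, and the point is simply that sampler_name and scheduler are two independent inputs:

```python
# Sketch of a single KSampler node in ComfyUI's API (JSON) workflow format.
# The node IDs ("3", "4", ...) and their connections are hypothetical examples.
ksampler_node = {
    "3": {
        "class_type": "KSampler",
        "inputs": {
            "model": ["4", 0],         # output 0 of a checkpoint-loader node "4"
            "positive": ["6", 0],      # positive conditioning (prompt encoder node "6")
            "negative": ["7", 0],      # negative conditioning (node "7")
            "latent_image": ["5", 0],  # empty latent image (node "5")
            "seed": 123456789,
            "steps": 20,
            "cfg": 8.0,
            "sampler_name": "euler",   # which sampling algorithm to use
            "scheduler": "normal",     # which noise/timestep schedule to use
            "denoise": 1.0,
        },
    }
}
```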
To recap the denoising process by which Stable Diffusion generates an image: SD first produces a random noise image from a random seed, then uses the trained noise predictor (the U-Net), together with conditioning such as the input prompt, to perform conditional denoising, gradually drawing shapes out of that noise until it becomes a newly generated image.

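The loop below is a deliberately simplified sketch of that conditional denoising process, not any particular library's API; the unet, encode_prompt, and step callables are hypothetical placeholders. Starting from pure noise determined by the seed, each iteration asks the noise predictor what noise it sees given the conditioning, then removes a little of it:

```python
import torch

def generate(unet, encode_prompt, step, sigmas, prompt, seed, shape=(1, 4, 64, 64)):
    """Conceptual sketch of conditional denoising (hypothetical callables, not a real API).

    unet(x, sigma, cond)                  -> predicted noise
    encode_prompt(prompt)                 -> conditioning tensor
    step(x, noise_pred, sigma, sigma_nx)  -> latent after one sampler update
    """
    cond = encode_prompt(prompt)
    generator = torch.Generator().manual_seed(seed)
    x = torch.randn(shape, generator=generator) * sigmas[0]   # pure noise from the seed
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):    # high noise -> low noise
        noise_pred = unet(x, sigma, cond)                     # "what noise do you see?"
        x = step(x, noise_pred, sigma, sigma_next)            # remove a bit of it
    return x                                                  # latent for the VAE decoder
```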
The sampling method (Sampler), in the broad sense, is the algorithm that controls this denoising process. Put simply, different algorithms produce different sampling results, and they may also differ slightly in how many sampling steps they need.

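As a concrete example of such an algorithm, the Euler sampler in the k-diffusion formulation (which ComfyUI builds on) moves the latent along a straight line toward the model's current clean-image estimate at each step. This is only a sketch; denoise_model is a hypothetical wrapper that already bakes in the conditioning:

```python
def sample_euler(denoise_model, x, sigmas):
    """Minimal Euler sampler in the k-diffusion style (sketch).

    denoise_model(x, sigma) -> the model's denoised estimate of x
    (a hypothetical wrapper around the U-Net that already includes the prompt).
    """
    for i in range(len(sigmas) - 1):
        denoised = denoise_model(x, sigmas[i])
        d = (x - denoised) / sigmas[i]           # direction from x toward the estimate
        x = x + d * (sigmas[i + 1] - sigmas[i])  # one Euler step to the next noise level
    return x
```

Other samplers (the DPM++ family, ancestral variants, and so on) replace this single first-order step with higher-order or stochastic updates, which is why their results and the number of steps they need can differ.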
If you want to dig deeper into the underlying theory, there is an answer on Reddit written by @ManBearScientist that very thoroughly lays out where each sampling method comes from, what it means, and how the research accumulated, step by step, into what we have today.
What about the Scheduler, then? You can think of it as part of the sampling method: it mainly controls the time steps (noise levels) used during sampling.
Most UIs merge the sampling method and the scheduler into a single "sampling method" option presented to the user, but according to ComfyUI's author @ComfyAnonymous in his reply to an issue: "I decided to make it a separate option because it makes more sense to me."

Following his advice, if you want the samplers to behave roughly the same as they do in other UIs, just use Karras or Normal.
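For reference, the karras option follows the noise-level schedule proposed by Karras et al. (2022), which packs more steps near the low-noise end; the sketch below implements that formula (the sigma_min and sigma_max defaults here are illustrative, since in practice they come from the model itself):

```python
import torch

def karras_sigmas(n_steps, sigma_min=0.03, sigma_max=14.6, rho=7.0):
    """Karras et al. (2022) schedule: interpolate linearly in sigma**(1/rho) space.
    sigma_min / sigma_max are illustrative defaults, not ComfyUI's model-derived values.
    """
    ramp = torch.linspace(0, 1, n_steps)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
    return torch.cat([sigmas, sigmas.new_zeros(1)])  # end the schedule at sigma = 0

print(karras_sigmas(10))
```

The normal scheduler, roughly speaking, instead takes its sigmas from the model's own training schedule at evenly spaced timesteps, which is why these two options track what other UIs do by default most closely.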
