【DA-CLIP】复原过程代码解读

    image_context, degra_context = clip_model.encode_image(img4clip, control=True)
     
    image_context = image_context.float()
        
    degra_context = degra_context.float()
      
    
    LQ_tensor = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)

    ...
    noisy_tensor = sde.noise_state(LQ_tensor)
     
    model.feed_data(noisy_tensor, LQ_tensor, text_context=degra_context, image_context=image_context)
    # 这一行将带有噪声的图像和原始低质量图像以及上下文信息传递给模型,
    model.test(sde)
     
    visuals = model.get_current_visuals(need_GT=False)
     
    output = util.tensor2img(visuals["Output"].squeeze())

    return output[:, :, [2, 1, 0]]

  image_context, degra_context生成过程clip_model.encode_image下次再整理。

第一步:根据退化图像张量生成随机噪声图张量noisy_tensor = sde.noise_state(LQ_tensor)

  def noise_state(self, tensor):
        return tensor + torch.randn_like(tensor) * self.max_sigma
   

sde是IRSDE类实例,定义文件在sde_utils.py

 noise_state 函数是一个自定义的方法,它接收一个名为 tensor 的参数,并返回一个新的张量,该张量是原始输入张量加上一些噪声。

噪声是通过使用 torch.randn_like 函数生成的,该函数创建一个与输入张量形状相同的新张量,其元素是从标准正态分布(均值为0,方差为1)中随机抽取的。

max_sigma在yml中设置为50

第二步:给去噪模型model传入相关参数。model.feed_data(noisy_tensor, LQ_tensor, text_context=degra_context,image_context=image_context)

 model是DenoiseModel类实例,文件定义在denoising_model.py

    def feed_data(self, state, LQ, GT=None, text_context=None, image_context=None):
        self.state = state.to(self.device)    # noisy_state
        self.condition = LQ.to(self.device)  # LQ
        if GT is not None:
            self.state_0 = GT.to(self.device)  # GT
        self.text_context = text_context
        self.image_context = image_context

 将值传给类属性

第三步:调用封装好的test方法,使用IRSDE的逆扩散过程复原。model.test(sde)

    def test(self, sde=None, save_states=False):
        sde.set_mu(self.condition)
        self.model.eval()
        # 将模型设置为评估模式。
        with torch.no_grad():
            self.output = sde.reverse_sde(self.state, save_states=save_states, text_context=self.text_context, image_context=self.image_context)

        self.model.train()

在评估模式下 通过reverse_sde计算复原结果。

  1. self.model.train():当调用此方法时,模型会进入训练模式。在训练模式下,模型会跟踪所有梯度,以便在反向传播过程中更新模型的权重。此外,对于某些类型的层(如Dropout层和BatchNorm层),训练模式会改变它们的行为。例如,Dropout层在训练时会随机丢弃一些神经元,以防止过拟合,而在评估模式下则会保持所有神经元的激活。

  2. self.model.eval():与self.model.train()相反,调用此方法会使模型进入评估模式。在评估模式下,模型不会跟踪梯度,这样可以节省计算资源并提高推理速度。同时,像Dropout和BatchNorm这样的层也会改变它们的行为,以反映模型在实际使用时的状态。例如,Dropout层在评估模式下不会丢弃任何神经元,BatchNorm层会使用在整个训练集上学习到的均值和方差。

第一步condition是刚刚喂入的LQ_tensor。根据传入的sde实例的set_mu设置mu的值。作用未知。

    def set_mu(self, mu):
        self.mu = mu

 计算output,**kwargs 是一个常用的参数,它出现在函数定义中,用于表示一个不定数量的关键字参数(keyword arguments)。kwargs 是 "keyword arguments" 的缩写,它允许函数接收除定义的参数之外的额外参数。

主要模块reverse_sde()

    # 定义逆向SDE过程,从最终状态 xt 逆向模拟回到初始状态
    def reverse_sde(self, xt, T=-1, save_states=False, save_dir='sde_state', **kwargs):
        # 如果传入的 T 为负数,则使用类实例的 sample_T 属性作为逆向模拟的总时间
        T = self.sample_T if T < 0 else T

        # 从输入的最终状态 xt 创建一个副本,用于在逆向模拟过程中更新状态
        x = xt.clone()

        # 使用 tqdm 库创建一个进度条,显示逆向模拟进度
        for t in tqdm(reversed(range(1, T + 1))):
            # 调用 score_fn 方法计算给定状态和时间的评分函数(也称为概率密度函数的梯度)
            score = self.score_fn(x, t, self.sample_scale, **kwargs)
            # 执行逆向SDE步骤,使用评分函数更新状态 x
            x = self.reverse_sde_step(x, score, t)
            # x = self.reverse_sde_step_mean(x, score, t)  # 这行代码被注释掉了,可能表示一个备用的逆向模拟步骤

            # 如果 save_states 为 True,则保存逆向模拟过程中的状态
            if save_states:
                # 计算保存状态的间隔,这里假设只保存100个图像
                interval = self.T // 100
                # 如果当前时间步是保存间隔的整数倍,则保存状态
                if t % interval == 0:
                    # 计算当前状态的索引
                    idx = t // interval
                    # 如果保存目录不存在,则创建它
                    os.makedirs(save_dir, exist_ok=True)
                    # 将当前状态 x 沿第一个维度分成两部分
                    x_L, x_R = x.chunk(2, dim=1)
                    # 将两部分状态沿维度3拼接,并保存为图像文件
                    tvutils.save_image(torch.cat([x_L, x_R], dim=3).data, f'{save_dir}/state_{idx}.png',
                                       normalize=False)

        # 逆向模拟结束后,返回最终的状态 x
        return x
类初始化时定义T 100,sample_T -1
self.sample_T = self.T if sample_T < 0 else sample_T,
判断得sample_T为100

xt是传入的噪声tensor

reversed(range(1, T + 1)) 是一个Python内置函数 reversed 的使用,它接收一个序列并返回一个反向的迭代器。range(1, T + 1) 生成一个从 1 到 T(不包括 T+1)的整数序列。因此,reversed(range(1, T + 1)) 会生成一个从 T 递减到 1 的整数序列。 

说明该函数目的是将图像状态x逐步逆扩散

使用score_fn计算得分

    def score_fn(self, x, t, scale=1.0, **kwargs):
        # need to pre-set mu and score_model
        noise = self.model(x, self.mu, t * scale, **kwargs)
        return self.get_score_from_noise(noise, t)
self.sample_scale = self.T / self.sample_T
计算得sample_scale 1
 model属性由DenoiseModel的model属性传来。类定义

在DenoiseModel的初始化函数中

self.model = networks.define_G(opt).to(self.device)

查看networks.py 

# Generator
def define_G(opt):
    opt_net = opt["network_G"]
    which_model = opt_net["which_model_G"]
    setting = opt_net["setting"]
    netG = getattr(M, which_model)(**setting)
    return netG

yml配置如下 

network_G:
  which_model_G: ConditionalUNet
  setting:
    in_nc: 3
    out_nc: 3
    nf: 64
    ch_mult: [1, 2, 4, 8]
    context_dim: 512
    use_degra_context: true
    use_image_context: true
from models import modules as M

M代指modules包 ,getattr从该模块获取指定ConditionalUNet类。该类被定义在DenoisingUNet.py文件中。**setting - 这是一个参数字典,它使用参数解包语法 ** 来传递关键字参数。

 

 查看ConditionalUNet类初始化函数

 def __init__(self, in_nc, out_nc, nf, ch_mult=[1, 2, 4, 4], 
                 context_dim=512, use_degra_context=True, use_image_context=False, upscale=1):
        # 调用父类构造函数
        super().__init__()
        # 记录网络深度
        self.depth = len(ch_mult)
        # 上采样因子,当前未使用
        self.upscale = upscale
        # 设置上下文维度,如果未提供则设为-1
        self.context_dim = -1 if context_dim is None else context_dim
        # 是否使用图像上下文
        self.use_image_context = use_image_context
        # 是否使用degra上下文
        self.use_degra_context = use_degra_context

        # 设置每个头部的通道数
        num_head_channels = 32
        # 头部维度
        dim_head = num_head_channels

        # 创建ResBlock的快捷方式,使用默认的卷积和非线性激活函数
        block_class = functools.partial(ResBlock, conv=default_conv, act=NonLinearity())

        # 初始化卷积层
        self.init_conv = default_conv(in_nc*2, nf, 7)
        
        # 时间嵌入维度
        time_dim = nf * 4

        # 是否使用随机或学习正弦条件
        self.random_or_learned_sinusoidal_cond = False

        # 如果使用随机或学习正弦条件
        if self.random_or_learned_sinusoidal_cond:
            # 学习正弦条件的维度
            learned_sinusoidal_dim = 16
            # 创建正弦位置嵌入
            sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(learned_sinusoidal_dim, False)
            # 傅里叶维度
            fourier_dim = learned_sinusoidal_dim + 1
        else:
            # 创建正弦位置嵌入
            sinu_pos_emb = SinusoidalPosEmb(nf)
            # 傅里叶维度
            fourier_dim = nf

        # 时间MLP层
        self.time_mlp = nn.Sequential(
            sinu_pos_emb,
            nn.Linear(fourier_dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim)
        )

        # 如果上下文维度大于0且使用degra上下文
        if self.context_dim > 0 and use_degra_context: 
            # 创建文本提示参数
            self.prompt = nn.Parameter(torch.rand(1, time_dim))
            # 创建文本MLP层
            self.text_mlp = nn.Sequential(
                nn.Linear(context_dim, time_dim), NonLinearity(),
                nn.Linear(time_dim, time_dim))
            # 创建提示MLP层
            self.prompt_mlp = nn.Linear(time_dim, time_dim)

        # 构建网络层
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        # 通道倍增序列,添加初始值1
        ch_mult = [1] + ch_mult

        # 遍历网络深度
        for i in range(self.depth):
            # 输入和输出维度
            dim_in = nf * ch_mult[i]
            dim_out = nf * ch_mult[i+1]

            # 输入和输出头部数量
            num_heads_in = dim_in // num_head_channels
            num_heads_out = dim_out // num_head_channels
            # 每个头部的输入维度
            dim_head_in = dim_in // num_heads_in

            # 如果使用图像上下文且上下文维度大于0
            if use_image_context and context_dim > 0:
                # 使用空间变换器或线性注意力机制
                att_down = LinearAttention(dim_in) if i < 3 else SpatialTransformer(dim_in, num_heads_in, dim_head, depth=1, context_dim=context_dim)
                att_up = LinearAttention(dim_out) if i < 3 else SpatialTransformer(dim_out, num_heads_out, dim_head, depth=1, context_dim=context_dim)
            else:
                # 使用线性注意力机制
                att_down = LinearAttention(dim_in) # if i < 2 else Attention(dim_in)
                att_up = LinearAttention(dim_out) # if i < 2 else Attention(dim_out)

            # 下采样模块列表
            self.downs.append(nn.ModuleList([
                block_class(dim_in=dim_in, dim_out=dim_in, time_emb_dim=time_dim),
                block_class(dim_in=dim_in, dim_out=dim_in, time_emb_dim=time_dim),
                Residual(PreNorm(dim_in, att_down)),
                Downsample(dim_in, dim_out) if i != (self.depth-1) else default_conv(dim_in, dim_out)
            ]))

            # 上采样模块列表
            self.ups.insert(0, nn.ModuleList([
                block_class(dim_in=dim_out + dim_in, dim_out=dim_out, time_emb_dim=time_dim),
                block_class(dim_in=dim_out + dim_in, dim_out=dim_out, time_emb_dim=time_dim),
                Residual(PreNorm(dim_out, att_up)),
                Upsample(dim_out, dim_in) if i!=0 else default_conv(dim_out, dim_in)
            ]))

        # 中间维度
        mid_dim = nf * ch_mult[-1]
        # 中间头部数量
        num_heads_mid = mid_dim // num_head_channels
        # 中间块1
        self.mid_block1 = block_class(dim_in=mid_dim, dim_out=mid_dim, time_emb_dim=time_dim)
        # 如果使用图像上下文且上下文维度大于0
        if use_image_context and context_dim > 0:
            # 使用空间变换器
            self.mid_attn = Residual(PreNorm(mid_dim, SpatialTransformer(mid_dim, num_heads_mid, dim_head, depth=1, context_dim=context_dim)))
        else:
            # 使用线性注意力机制
            self.mid_attn = Residual(PreNorm(mid_dim, LinearAttention(mid_dim)))
        # 中间块2
        self.mid_block2 = block_class(dim_in=mid_dim, dim_out=mid_dim, time_emb_dim=time_dim)

        # 最终残差块
        self.final_res_block = block_class(dim_in=nf * 2, dim_out=nf, time_emb_dim=time_dim)
        # 最终卷积层
        self.final_conv = nn.Conv2d(nf, out_nc, 3, 1, 1)
 score_fn计算得分执行self.model的forward()

 noise = self.model(x, self.mu, t * scale, **kwargs)

x是noisy_tensor(self_state),mu是LQ_tensor(cnoditional_tensor),t是T(定义为100)到1值,scale是1

有很多参数别名但最终传下来应该是这样

该参数被传入ConditionalUNet类forward()方法。整理调用过程理解是:

model是一个初始化的netG,netG是包含以上初始化和forward方法的ConditionalUNet类,那么执行self.model(x, self.mu, t * scale, **kwargs)就是在执行该类的forward()

forward 方法是神经网络模型的核心,它定义了数据通过网络的前向传播过程。该方法接收输入张量 xt、条件张量 cond、时间参数 time,以及可选的文本和图像上下文。通过一系列卷积、注意力机制和上/下采样操作,模型提取和融合特征,最终生成输出张量。输出张量经过裁剪,确保其尺寸与输入图像的尺寸一致。

def forward(self, xt, cond, time, text_context=None, image_context=None):
    # 检查输入的时间参数是否为整数或浮点数,如果是,则将其转换为一个单元素张量,并移动到xt所在的设备
    if isinstance(time, int) or isinstance(time, float):
        time = torch.tensor([time]).to(xt.device)
    
    # X=noisy_tensor-LQ_tensor就是文章第一步添加的随机噪声,与LQ_tensor拼接,增加通道维度
    x = xt - cond
    x = torch.cat([x, cond], dim=1)

    # 获取输入张量的空间维度H和W
    H, W = x.shape[2:]
    # 检查并调整输入张量x的空间尺寸以匹配原始图像的尺寸
    x = self.check_image_size(x, H, W)

    # 应用初始卷积层
    x = self.init_conv(x)
    # 克隆x,用于后续操作
    x_ = x.clone()

    # 通过时间MLP处理时间参数
    t = self.time_mlp(time) 
    # 如果上下文维度大于0,并且使用degra上下文,且文本上下文不为空
    if self.context_dim > 0:
        if self.use_degra_context and text_context is not None:
            # 计算文本上下文的嵌入,将其与提示向量结合,并进行处理
            prompt_embedding = torch.softmax(self.text_mlp(text_context), dim=1) * self.prompt
            prompt_embedding = self.prompt_mlp(prompt_embedding)
            # 将处理后的文本上下文嵌入加到时间参数t上
            t = t + prompt_embedding

        # 如果使用图像上下文,且图像上下文不为空
        if self.use_image_context and image_context is not None:
            # 为图像上下文增加一个通道维度
            image_context = image_context.unsqueeze(1)

    # 存储下采样过程中的特征图
    h = []
    # 遍历下采样模块列表
    for b1, b2, attn, downsample in self.downs:
        # 应用第一个残差块和时间参数t
        x = b1(x, t)
        # 存储特征图
        h.append(x)

        # 应用第二个残差块和时间参数t
        x = b2(x, t)
        # 应用注意力机制,如果提供了图像上下文,则使用它
        x = attn(x, context=image_context)
        # 存储特征图
        h.append(x)

        # 应用下采样操作
        x = downsample(x)

    # 应用中间块1和时间参数t
    x = self.mid_block1(x, t)
    # 如果使用图像上下文,则应用注意力机制
    x = self.mid_attn(x, context=image_context) if self.use_image_context else x
    # 应用中间块2和时间参数t
    x = self.mid_block2(x, t)

    # 遍历上采样模块列表
    for b1, b2, attn, upsample in self.ups:
        # 从历史特征图中弹出并拼接特征,与当前特征图拼接
        x = torch.cat([x, h.pop()], dim=1)
        # 应用第一个残差块和时间参数t
        x = b1(x, t)
        
        # 再次从历史特征图中弹出并拼接特征,与当前特征图拼接
        x = torch.cat([x, h.pop()], dim=1)
        # 应用第二个残差块和时间参数t
        x = b2(x, t)

        # 应用注意力机制,如果提供了图像上下文,则使用它
        x = attn(x, context=image_context)
        # 应用上采样操作
        x = upsample(x)

    # 将原始输入xt与当前特征图x拼接,增加通道维度
    x = torch.cat([x, x_], dim=1)

    # 应用最终的残差块和时间参数t
    x = self.final_res_block(x, t)
    # 应用最终的卷积层
    x = self.final_conv(x)

    # 裁剪输出张量x,使其空间尺寸与原始输入图像的尺寸相匹配
    x = x[..., :H, :W].contiguous()
    
    # 返回处理后的输出张量x
    return x
 根据返回的noise计算score
    def get_score_from_noise(self, noise, t):
        return -noise / self.sigma_bar(t)
    def sigma_bar(self, t):
        return self.sigma_bars[t]
      sigma_bars = get_sigma_bars(thetas_cumsum)
thetas_cumsum = get_thetas_cumsum(thetas) - thetas[0] # for that thetas[0] is not 0
def get_sigma_bars(thetas_cumsum):
    return torch.sqrt(max_sigma**2 * (1 - torch.exp(-2 * thetas_cumsum * self.dt)))

 以上完成score计算,逐步逆扩散

使用rever_sde_step

    def reverse_sde_step(self, x, score, t):
        return x - self.sde_reverse_drift(x, score, t) - self.dispersion(x, t)
 计算drift
    def sde_reverse_drift(self, x, score, t):
        return (self.thetas[t] * (self.mu - x) - self.sigmas[t]**2 * score) * self.dt

 thetas

        if schedule == 'cosine':
            thetas = cosine_theta_schedule(T)
      def cosine_theta_schedule(timesteps, s = 0.008):
            """
            cosine schedule
            """
            # print('cosine schedule')
            timesteps = timesteps + 2 # for truncating from 1 to -1
            steps = timesteps + 1
            x = torch.linspace(0, timesteps, steps, dtype=torch.float32)
            alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
            alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
            betas = 1 - alphas_cumprod[1:-1]
            return betas

dt 

        self.dt = -1 / thetas_cumsum[-1] * math.log(eps)
def dispersion(self, x, t):
    return self.sigmas[t] * (torch.randn_like(x) * math.sqrt(self.dt)).to(self.device)

第四步整理复原图像信息

visuals = model.get_current_visuals(need_GT=False)

    def get_current_visuals(self, need_GT=True):
        out_dict = OrderedDict()
        out_dict["Input"] = self.condition.detach()[0].float().cpu()
        out_dict["Output"] = self.output.detach()[0].float().cpu()
        if need_GT:
            out_dict["GT"] = self.state_0.detach()[0].float().cpu()
        return out_dict

output是第三步生成的复原图像tensor,input是LQ_tesnsor

output = util.tensor2img(visuals["Output"].squeeze())

将模型输出的恢复图像张量转换为PIL图像格式。
visuals["Output"]获取了复原图像的张量,squeeze()方法移除了所有单维度的批次维度。

def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
    """
    Converts a torch Tensor into an image Numpy array
    Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
    Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
    """
    tensor = tensor.squeeze().float().cpu().clamp_(*min_max)  # clamp
    tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])  # to range [0,1]
    n_dim = tensor.dim()
    if n_dim == 4:
        n_img = len(tensor)
        img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 3:
        img_np = tensor.numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 2:
        img_np = tensor.numpy()
    else:
        raise TypeError(
            "Only support 4D, 3D and 2D tensor. But received with dimension: {:d}".format(
                n_dim
            )
        )
    if out_type == np.uint8:
        img_np = (img_np * 255.0).round()
        # Important. Unlike matlab, numpy.unit8() WILL NOT round by default.
    return img_np.astype(out_type)

  return output[:, :, [2, 1, 0]]

这一行将图像的通道顺序从RGB转换为BGR,这是大多数图像处理库和显示设备使用的格式,并返回最终的恢复图像。

IRSDE参数作用、代码和公式的对照下次再整理

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值