更新:
- 4.26 第一次整理博文,只简要指出代码整体结构功能,对代码细节未完全理解。
- 5.2 发现交叉注意力部分理解完全有失偏颇,重新整理关于unet和crossattention理论和代码
整体代码
代码地址在
daclip-uir-main/universal-image-restoration/config/daclip-sde/models/modules/DenoisingUNet_arch.py
注意力机制相关代码在同级文件夹下attention.py。
该部分属于使用IR-SDE的复原处理中估计图像噪声得分的步骤。
【DA-CLIP】复原过程代码解读-优快云博客https://blog.youkuaiyun.com/m0_60350022/article/details/137699012?spm=1001.2014.3001.5501这部分在DA-CLIP的论文中描述如下:
We use IR-SDE ( Luo et al. , 2023a ) as the base framework for image restoration. It adapts a U-Net architecture similar to DDPM ( Ho et al. , 2020 ) but removes all self-attention layers. To inject clean content embeddings into the diffusion process, we introduce a cross-attention ( Rombach et al. , 2022 ) mechanism to learn semantic guidance from pre-trained VLMs. Considering the varying input sizes in image restoration tasks and the increasing cost of applying attention to high-resolution features, we only use cross-attention in the bottom blocks of the U-Net for sample efficiency.
我们使用IR-SDE (Luo等人。作为图像恢复的基础框架。它采用了一种类似于DDPM的U-Net架构(Ho等人。但却删除了所有的自我注意层。为了在扩散过程中注入干净的内容嵌入,我们引入了交叉注意(Rombach等。从预先训练过的vlm中学习语义指导的机制。考虑到图像恢复任务中输入大小的变化,以及对高分辨率特征的注意应用成本的增加,我们只在U-Net的底部块中使用交叉注意来提高采样效率。
在下面的代码中可以得到证实。
ConditionalUNet的forward
接收输入张量 xt
、条件张量 cond
、时间参数 time
,以及可选的文本和图像上下文。
其中xt、cond、text_context, image_context分别对应
noisy_tensor, LQ_tensor, degra_context, image_context
根据配置文件初始化
def forward(self, xt, cond, time, text_context=None, image_context=None):
# 检查输入的时间参数是否为整数或浮点数,如果是,则将其转换为一个单元素张量,并移动到xt所在的设备
if isinstance(time, int) or isinstance(time, float):
time = torch.tensor([time]).to(xt.device)
# X=noisy_tensor-LQ_tensor就是文章第一步添加的随机噪声,与LQ_tensor拼接,增加通道维度
x = xt - cond
x = torch.cat([x, cond], dim=1)
# 获取输入张量的空间维度H和W
H, W = x.shape[2:]
# 检查并调整输入张量x的空间尺寸以匹配原始图像的尺寸
x = self.check_image_size(x, H, W)
# 应用初始卷积层
x = self.init_conv(x)
# 克隆x,用于后续操作
x_ = x.clone()
# 通过时间MLP处理时间参数
t = self.time_mlp(time)
# 如果上下文维度大于0,并且使用degra上下文,且文本上下文不为空
if self.context_dim > 0:
if self.use_degra_context and text_context is not None:
# 计算文本上下文的嵌入,将其与提示向量结合,并进行处理
prompt_embedding = torch.softmax(self.text_mlp(text_context), dim=1) * self.prompt
prompt_embedding = self.prompt_mlp(prompt_embedding)
# 将处理后的文本上下文嵌入加到时间参数t上
t = t + prompt_embedding
# 如果使用图像上下文,且图像上下文不为空
if self.use_image_context and image_context is not None:
# 为图像上下文增加一个通道维度
image_context = image_context.unsqueeze(1)
# 存储下采样过程中的特征图
h = []
# 遍历下采样模块列表
for b1, b2, attn, downsample in self.downs:
# 应用第一个残差块和时间参数t
x = b1(x, t)
# 存储特征图
h.append(x)
# 应用第二个残差块和时间参数t
x = b2(x, t)
# 应用注意力机制,如果提供了图像上下文,则使用它
x = attn(x, context=image_context)
# 存储特征图
h.append(x)
# 应用下采样操作
x = downsample(x)
# 应用中间块1和时间参数t
x = self.mid_block1(x, t)
# 如果使用图像上下文,则应用注意力机制
x = self.mid_attn(x, context=image_context) if self.use_image_context else x
# 应用中间块2和时间参数t
x = self.mid_block2(x, t)
# 遍历上采样模块列表
for b1, b2, attn, upsample in self.ups:
# 从历史特征图中弹出并拼接特征,与当前特征图拼接
x = torch.cat([x, h.pop()], dim=1)
# 应用第一个残差块和时间参数t
x = b1(x, t)
# 再次从历史特征图中弹出并拼接特征,与当前特征图拼接
x = torch.cat([x, h.pop()], dim=1)
# 应用第二个残差块和时间参数t
x = b2(x, t)
# 应用注意力机制,如果提供了图像上下文,则使用它
x = attn(x, context=image_context)
# 应用上采样操作
x = upsample(x)
# 将原始输入xt与当前特征图x拼接,增加通道维度
x = torch.cat([x, x_], dim=1)
# 应用最终的残差块和时间参数t
x = self.final_res_block(x, t)
# 应用最终的卷积层
x = self.final_conv(x)
# 裁剪输出张量x,使其空间尺寸与原始输入图像的尺寸相匹配
x = x[..., :H, :W].contiguous()
# 返回处理后的输出张量x
return x
1.预处理
if isinstance(time, int) or isinstance(time, float):
time = torch.tensor([time]).to(xt.device)
# X=noisy_tensor-LQ_tensor就是app.py里添加的随机噪声,与LQ_tensor拼接,增加通道维度
x = xt - cond
x = torch.cat([x, cond], dim=1)
# 获取输入张量的空间维度H和W
H, W = x.shape[2:]
# 检查并调整输入张量x的空间尺寸以匹配原始图像的尺寸
x = self.check_image_size(x, H, W)
# 应用初始卷积层
x = self.init_conv(x)
# 克隆x,用于后续操作
x_ = x.clone()
# 通过时间MLP处理时间参数
t = self.time_mlp(time)
self.init_conv = default_conv(in_nc * 2, nf, 7)
self.time_mlp = nn.Sequential(
# 第一层是一个位置编码层,可能使用正弦和余弦函数来嵌入位置信息
sinu_pos_emb,
# 第二层是一个全连接层,将输入特征从 fourier_dim 维度映射到 time_dim 维度
nn.Linear(fourier_dim, time_dim),
# 第三层是 GELU 激活函数,它将非线性引入到模型中,有助于学习复杂的模式
nn.GELU(),
# 第四层是另一个全连接层,它将特征再次映射回 time_dim 维度
# 这通常用于进一步提取特征和增强模型的非线性能力
nn.Linear(time_dim, time_dim)
)
2.prompt代码
处理过程
if self.context_dim > 0:
if self.use_degra_context and text_context is not None:
# 计算文本上下文的嵌入,将其与提示向量结合,并进行处理
prompt_embedding = torch.softmax(self.text_mlp(text_context), dim=1) * self.prompt
prompt_embedding = self.prompt_mlp(prompt_embedding)
# 将处理后的文本上下文嵌入加到时间参数t上
t = t + prompt_embedding
text_mlp
# 定义一个名为 text_mlp 的属性,它是一个由多个层组成的顺序模型,用于处理文本数据
self.text_mlp = nn.Sequential(
# 第一层是一个全连接层,将输入特征从 context_dim 维度映射到 time_dim 维度
nn.Linear(context_dim, time_dim),
# 第二层是一个非线性激活函数,这里用 NonLinearity() 表示,它可能是 nn.ReLU、nn.GELU 或其他激活函数
# 这个非线性激活函数有助于模型捕捉和学习数据中的复杂关系
NonLinearity(), # 这里 NonLinearity() 应替换为具体的激活函数,例如 nn.ReLU 或 nn.GELU
# 第三层是另一个全连接层,它将特征再次映射回 time_dim 维度
# 这通常用于进一步提取特征和增强模型的非线性能力
nn.Linear(time_dim, time_dim)
)
使用了self.prompt
self.prompt = nn.Parameter(
# 调用 torch.rand 来生成一个随机初始化的张量
# torch.rand(1, time_dim) 生成一个形状为 (1, time_dim) 的张量,其中 time_dim 是之前定义的维度
# 张量中的每个元素都是从 [0, 1) 区间的均匀分布中随机抽取的
torch.rand(1, time_dim)
)
self.prompt_mlp
self.prompt_mlp = nn.Linear(time_dim, time_dim)
# 这个 MLP 只有一个全连接层,它将时间维度的特征映射回相同的时间维度
# 这可能用于进一步处理或规范化提示数据的特征表示
3.Unet部分
下采样
# 存储下采样过程中的特征图
h = []
# 遍历下采样模块列表
for b1, b2, attn, downsample in self.downs:
# 应用第一个残差块和时间参数t
x = b1(x, t)
# 存储特征图
h.append(x)
# 应用第二个残差块和时间参数t
x = b2(x, t)
# 应用注意力机制,如果提供了图像上下文,则使用它
x = attn(x, context=image_context)
# 存储特征图
h.append(x)
# 应用下采样操作
x = downsample(x)
这里需要在定义中查找self.downs的结构,这里结合self.ups一起看结构创建定义,这里self.depth深度为4.
网络结构
# 遍历网络深度
for i in range(self.depth):
# 输入和输出维度
dim_in = nf * ch_mult[i]
dim_out = nf * ch_mult[i + 1]
# 输入和输出头部数量
num_heads_in = dim_in // num_head_channels
num_heads_out = dim_out // num_head_channels
# 每个头部的输入维度
dim_head_in = dim_in // num_heads_in
# 如果使用图像上下文且上下文维度大于0
if use_image_context and context_dim > 0:
# 这得看 i是否小于3才使用
att_down = LinearAttention(dim_in) if i < 3 else SpatialTransformer(dim_in, num_heads_in, dim_head,
depth=1,
context_dim=context_dim)
att_up = LinearAttention(dim_out) if i < 3 else SpatialTransformer(dim_out, num_heads_out, dim_head,
depth=1, context_dim=context_dim)
else:
# 使用线性注意力机制
att_down = LinearAttention(dim_in) # if i < 2 else Attention(dim_in)
att_up = LinearAttention(dim_out) # if i < 2 else Attention(dim_out)
# 下采样模块列表
# self.downs 是一个用于存储处理块序列的列表,在类的初始化方法中定义
self.downs.append(
nn.ModuleList([
# 创建第一个模块,一个自定义的块类(block_class),其输入和输出维度都是 dim_in,时间嵌入维度是 time_dim
block_class(dim_in=dim_in, dim_out=dim_in, time_emb_dim=time_dim),
# 创建第二个模块,与第一个模块相同
block_class(dim_in=dim_in, dim_out=dim_in, time_emb_dim=time_dim),
# 创建第三个模块,一个残差连接(Residual),它包含一个预归一化层(PreNorm)和注意力层(att_down)
Residual(PreNorm(dim_in, att_down)),
# 根据当前的索引 i 是否等于 (self.depth - 1),创建不同的模块
# 如果 i 不等于 (self.depth - 1),则创建一个 Downsample 模块,用于在 dim_in 和 dim_out 之间进行下采样
# 如果 i 等于 (self.depth - 1),则使用 default_conv 创建一个默认的二维卷积层,将 dim_in 维度的特征映射到 dim_out 维度
Downsample(dim_in, dim_out) if i != (self.depth - 1) else default_conv(dim_in, dim_out)
])
)
# 上采样模块列表
self.ups.insert(0, nn.ModuleList([
block_class(dim_in=dim_out + dim_in, dim_out=dim_out, time_emb_dim=time_dim),
block_class(dim_in=dim_out + dim_in, dim_out=dim_out, time_emb_dim=time_dim),
Residual(PreNorm(dim_out, att_up)),
Upsample(dim_out, dim_in) if i != 0 else default_conv(dim_out, dim_in)
]))
这里注意att_down在unet的最bottom处(i==3时)使用了交叉注意力。
SpatialTransformer定义在attention.py文件
class SpatialTransformer(nn.Module):
"""
Transformer block for image-like data.
First, project the input (aka embedding)
and reshape to b, t, d.
Then apply standard transformer action.
Finally, reshape to image
"""
def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None):
super().__init__()
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels)
self.proj_in = nn.Conv2d(in_channels,
inner_dim,
kernel_size=1,
stride=1,
padding=0)
self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
for d in range(depth)]
)
self.proj_out = zero_module(nn.Conv2d(inner_dim,
in_channels,
kernel_size=1,
stride=1,
padding=0))
def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention
b, c, h, w = x.shape
x_in = x
x = self.norm(x)
x = self.proj_in(x)
x = rearrange(x, 'b c h w -> b (h w) c')
for block in self.transformer_blocks:
x = block(x, context=context)
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
x = self.proj_out(x)
return x + x_in
其中使用BasicTransformerBlock类作为self.transformer_blocks
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
super().__init__()
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
def forward(self, x, context=None):
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
def _forward(self, x, context=None):
x = self.attn1(self.norm1(x)) + x
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x
return x
该类的attn1 和attn12都使用了 CrossAttention类
CrossAttention类定义如下:
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head ** -0.5
self.heads = heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim),
nn.Dropout(dropout)
)
def forward(self, x, context=None, mask=None):
h = self.heads
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
if exists(mask):
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
# attention, what we cannot get enough of
attn = sim.softmax(dim=-1)
out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
return self.to_out(out)
中间块
# 应用中间块1和时间参数t
x = self.mid_block1(x, t)
# 如果使用图像上下文,则应用注意力机制
x = self.mid_attn(x, context=image_context) if self.use_image_context else x
# 应用中间块2和时间参数t
x = self.mid_block2(x, t)
# 中间维度
mid_dim = nf * ch_mult[-1]
# 中间头部数量
num_heads_mid = mid_dim // num_head_channels
# 中间块1
self.mid_block1 = block_class(dim_in=mid_dim, dim_out=mid_dim, time_emb_dim=time_dim)
# 如果使用图像上下文且上下文维度大于0
if use_image_context and context_dim > 0:
# 使用空间变换器
self.mid_attn = Residual(PreNorm(mid_dim, SpatialTransformer(mid_dim, num_heads_mid, dim_head, depth=1,
context_dim=context_dim)))
else:
# 使用线性注意力机制
self.mid_attn = Residual(PreNorm(mid_dim, LinearAttention(mid_dim)))
# 中间块2
self.mid_block2 = block_class(dim_in=mid_dim, dim_out=mid_dim, time_emb_dim=time_dim)
上采样
# 遍历上采样模块列表
for b1, b2, attn, upsample in self.ups:
# 从历史特征图中弹出并拼接特征,与当前特征图拼接
x = torch.cat([x, h.pop()], dim=1)
# 应用第一个残差块和时间参数t
x = b1(x, t)
# 再次从历史特征图中弹出并拼接特征,与当前特征图拼接
x = torch.cat([x, h.pop()], dim=1)
# 应用第二个残差块和时间参数t
x = b2(x, t)
# 应用注意力机制,如果提供了图像上下文,则使用它
x = attn(x, context=image_context)
# 应用上采样操作
x = upsample(x)
4.后处理
# 将原始输入xt与当前特征图x拼接,增加通道维度
x = torch.cat([x, x_], dim=1)
# 应用最终的残差块和时间参数t
x = self.final_res_block(x, t)
# 应用最终的卷积层
x = self.final_conv(x)
# 裁剪输出张量x,使其空间尺寸与原始输入图像的尺寸相匹配
x = x[..., :H, :W].contiguous()
# 返回处理后的输出张量x
return x
# 最终残差块
self.final_res_block = block_class(dim_in=nf * 2, dim_out=nf, time_emb_dim=time_dim)
# 最终卷积层
self.final_conv = nn.Conv2d(nf, out_nc, 3, 1, 1)