Diffusion的unet中用到的AttentionBlock详解


Diffusion的unet中用到的AttentionBlock详解

class AttentionBlock(nn.Module):
    __doc__ = r"""Applies QKV self-attention with a residual connection.
    
    Input:
        x: tensor of shape (N, in_channels, H, W)
        norm (string or None): which normalization to use (instance, group, batch, or none). Default: "gn"
        num_groups (int): number of groups used in group normalization. Default: 32
    Output:
        tensor of shape (N, in_channels, H, W)
    Args:
        in_channels (int): number of input channels
    """
    def __init__(self, in_channels, norm="gn", num_groups=32):
        super().__init__()
        
        self.in_channels = in_channels
        self.norm = get_norm(norm, in_channels, num_groups)
        # 为啥这里的QKV并不是一样的???而是把通道数翻了3倍
        self.to_qkv = nn.Conv2d(in_channels, in_channels * 3, 1)
        self.to_out = nn.Conv2d(in_channels, in_channels, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        q, k, v = torch.split(self.to_qkv(self.norm(x)), self.in_channels, dim=1)

        q = q.permute(0, 2, 3, 1).view(b, h * w, c)
        k = k.view(b, c, h * w)
        v = v.permute(0, 2, 3, 1).view(b, h * w, c)

        dot_products = torch.bmm(q, k) * (c ** (-0.5))
        assert dot_products.shape == (b, h * w, h * w)

        attention = torch.softmax(dot_products, dim=-1)
        out = torch.bmm(attention, v)
        assert out.shape == (b, h * w, c)
        out = out.view(b, h, w, c).permute(0, 3, 1, 2)

        return self.to_out(out) + x

x: (batch, channel, h, w)
经过to_qkv操作,变成了(batch, channel*3, h, w)

torch.split

torch.split(tensor, split_size_or_sections, dim=0)
# 作用:将tensor分成块结构
'''
split_size_or_secctions: 即多少个为一组
dim: 对哪个维度进行划分
'''

eg:
q, k, v = torch.split(self.to_qkv(self.norm(x)), self.in_channels, dim=1)
即对大小为(batch, channel*3, h, w)的张量,在dim=1上划分,每channel个为一组
所以,q, k, v 的形状均为(batch, channel, h, w)

torch.split详解

torch中的permute的用法

作用:permute可以对tensor进行转置

import torch
import torch.nn as nn

x = torch.randn(1, 2, 3, 4)
print(x.size())    # torch.Size([1, 2, 3, 4])   
print(x.permute(2, 1, 0, 3).size())# torch.Size([3, 2, 1, 4])   

torch.transpose()

因为torch.transpose 一次只能进行两个维度的转置,如果需要多个维度的转置,那么需要多次调用transpose()。比如上述的tensor[1,2,3,4]转置为tensor[3,4,1,2],使用transpose需要做如下:

x.transpose(0,2).transpose(1,3)

view()

view()函数作用的内存必须是连续的,如果操作数不是连续存储的,必须在操作之前执行contiguous(),把tensor变成在内存中连续分布的形式;view的功能有点像reshape,可以对tensor进行重新塑型

import torch
import torch.nn as nn
import numpy as np

y = np.array([[[1, 2, 3], [4, 5, 6]]]) # 1X2X3
y_tensor = torch.tensor(y)
y_tensor_trans = y_tensor.permute(2, 0, 1) # 3X1X2
print(y_tensor.size())
print(y_tensor_trans.size())

print(y_tensor)
print(y_tensor_trans)
print(y_tensor.view(1, 3, 2)) 


torch.Size([1, 2, 3])
torch.Size([3, 1, 2])
tensor([[[1, 2, 3],
         [4, 5, 6]]])
tensor([[[1, 4]],

        [[2, 5]],

        [[3, 6]]])
tensor([[[1, 2],
         [3, 4],
         [5, 6]]])

permute参考
permute详解参考

torch.bmm

作用:
计算两个tensor的矩阵乘法,torch.bmm(a,b),tensor a 的size为(b,h,w),tensor b的size为(b,w,m) 也就是说两个tensor的第一维是相等的,然后第一个数组的第三维和第二个数组的第二维度要求一样,对于剩下的则不做要求,输出维度 (b,h,m)

torch.bmm要求a,b的维度必须是3维的,不能为2D or 4D

矩阵相乘

softmax(x, dim=-1)

import torch 
a = torch.randn(2,3)
print(a)
tensor([[-8.2976e-01,  5.8105e-04,  1.2218e+00],
        [ 1.9745e-01,  1.2727e+00,  5.9587e-01]])
b = torch.softmax(a, dim=-1)
print(b)
tensor([[0.0903, 0.2072, 0.7025],
        [0.1845, 0.5407, 0.2748]])

softmax(x, dim=-1)

### Diffusion Model中的UNet架构解释 #### UNet 架构概述 UNet 是一种常用于图像处理任务的卷积神经网络结构,最初设计用于生物医学图像分割。该架构由编码器(下采样路径)和解码器(上采样路径)组成,在两者之间通过跳跃连接传递特征图[^1]。 #### 编码器部分 编码器负责逐步提取输入数据的空间信息并减少分辨率。每一层通常包含两个3×3卷积操作,后面跟着ReLU激活函数以及最大池化层来降低空间维度。随着层数加深,通道数量逐渐增加以捕捉更复杂的模式[^2]。 #### 解码器部分 解码器的任务是对低维表示进行重建,恢复原始尺寸的同时保留高层次语义信息。这一步骤利用转置卷积(反卷积)扩大特征映射大小,并且每步都会与相应层次来自编码阶段的信息相融合——即所谓的“跳跃链接”,从而增强局部细节准确性[^3]。 #### 扩散模型特定调整 当应用于扩散模型时, UNet 需要做出一些适应性的改变: - **时间条件嵌入**: 添加额外的时间步长作为条件变量t∈{0,...,T},使得网络能够根据不同噪声水平调整预测行为。 - **多尺度注意力机制**: 引入自注意模块帮助捕获全局依赖关系,尤其对于高分辨率图片生成非常重要。 ```python import torch.nn as nn class TimeEmbedding(nn.Module): def __init__(self, embedding_dim=256): super().__init__() self.fc_1 = nn.Linear(embedding_dim, embedding_dim * 4) self.act_fn = nn.SiLU() self.fc_2 = nn.Linear(embedding_dim * 4, embedding_dim) def forward(self, timesteps): x = self.fc_1(timesteps) x = self.act_fn(x) x = self.fc_2(x) return x def get_timestep_embedding(timesteps, embedding_dim): half_dim = embedding_dim // 2 emb = np.log(10000) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, dtype=torch.float32).to(device=timesteps.device) * -emb) emb = timesteps[:, None].float() * emb[None, :] emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) if embedding_dim % 2 == 1: # zero pad emb = nn.ZeroPad2d((0, 1, 0, 0))(emb) assert emb.shape == (timesteps.shape[0], embedding_dim) return emb ```
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值