大家好,我是刘明,明志科技创始人,华为昇思MindSpore布道师。
技术上主攻前端开发、鸿蒙开发和AI算法研究。
努力为大家带来持续的技术分享,如果你也喜欢我的文章,就点个关注吧
构建Diffusion模型
下面,我们逐步构建Diffusion模型。
首先,我们定义了一些帮助函数和类,这些函数和类将在实现神经网络时使用。
def rearrange(head, inputs):
b, hc, x, y = inputs.shape
c = hc // head
return inputs.reshape((b, head, c, x * y))
def rsqrt(x):
res = ops.sqrt(x)
return ops.inv(res)
def randn_like(x, dtype=None):
if dtype is None:
dtype = x.dtype
res = ops.standard_normal(x.shape).astype(dtype)
return res
def randn(shape, dtype=None):
if dtype is None:
dtype = ms.float32
res = ops.standard_normal(shape).astype(dtype)
return res
def randint(low, high, size, dtype=ms.int32):
res = ops.uniform(size, Tensor(low, dtype), Tensor(high, dtype), dtype=dtype)
return res
def exists(x):
return x is not None
def default(val, d):
if exists(val):
return val
return d() if callable(d) else d
def _check_dtype(d1, d2):
if ms.float32 in (d1, d2):
return ms.float32
if d1 == d2:
return d1
raise ValueError('dtype is not supported.')
class Residual(nn.Cell):
def __init__(self, fn):
super().__init__()
self.fn = fn
def construct(self, x, *args, **kwargs):
return self.fn(x, *args, **kwargs) + x
我们还定义了上采样和下采样操作的别名。
def Upsample(dim):
return nn.Conv2dTranspose(dim, dim, 4, 2, pad_mode="pad", padding=1)
def Downsample(dim):
return nn.Conv2d(dim, dim, 4, 2, pad_mode="pad", padding=1)
位置向量
由于神经网络的参数在时间(噪声水平)上共享,作者使用正弦位置嵌入来编码
,灵感来自Transformer(Vaswani et al., 2017)。对于批处理中的每一张图像,神经网络“知道”它在哪个特定时间步长(噪声水平)上运行。
SinusoidalPositionEmbeddings模块采用(batch_size, 1)形状的张量作为输入(即批处理中几个有噪声图像的噪声水平),并将其转换为(batch_size, dim)形状的张量,其中dim是位置嵌入的尺寸。然后,我们将其添加到每个剩余块中。
class SinusoidalPositionEmbeddings(nn.Cell):
def __init__(self, dim):
super().__init__()
self.dim = dim
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = np.exp(np.arange(half_dim) * - emb)
self.emb = Tensor(emb, ms.float32)
def construct(self, x):
emb = x[:, None] * self.emb[None, :]
emb = ops.concat((ops.sin(emb), ops.cos(emb)), axis=-1)
return emb
ResNet/ConvNeXT块
接下来,我们定义U-Net模型的核心构建块。DDPM作者使用了一个Wide ResNet块(Zagoruyko et al., 2016),但Phil Wang决定添加ConvNeXT(Liu et al., 2022)替换ResNet,因为后者在图像领域取得了巨大成功。
在最终的U-Net架构中,可以选择其中一个或另一个,本文选择ConvNeXT块构建U-Net模型。
class Block(nn.Cell):
def __init__(self, dim, dim_out, groups=1):
super().__init__()
self.proj = nn.Conv2d(dim, dim_out, 3, pad_mode="pad", padding=1)
self.proj = c(dim, dim_out, 3, padding=1, pad_mode='pad')
self.norm = nn.GroupNorm(groups, dim_out)
self.act = nn.SiLU()
def construct(self, x, scale_shift=None):
x = self.proj(x)
x = self.norm(x)
if exists(scale_shift):
scale, shift = scale_shift
x = x * (scale + 1) + shift
x = self.act(x)
return x
class ConvNextBlock(nn.Cell):
def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
super().__init__()
self.mlp = (
nn.SequentialCell(nn.GELU(), nn.Dense(time_emb_dim, dim))
if exists(time_emb_dim)
else None
)
self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, group=dim, pad_mode="pad")
self.net = nn.SequentialCell(
nn.GroupNorm(1, dim) if norm else nn.Identity(),
nn.Conv2d(dim, dim_out * mult, 3, padding=1, pad_mode="pad"),
nn.GELU(),
nn.GroupNorm(1, dim_out * mult),
nn.Conv2d(dim_out * mult, dim_out, 3, padding=1, pad_mode="pad"),
)
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
def construct(self, x, time_emb=None):
h = self.ds_conv(x)
if exists(self.mlp) and exists(time_emb):
assert exists(time_emb), "time embedding must be passed in"
condition = self.mlp(time_emb)
condition = condition.expand_dims(-1).expand_dims(-1)
h = h + condition
h = self.net(h)
return h + self.res_conv(x)
Attention模块
接下来,我们定义Attention模块,DDPM作者将其添加到卷积块之间。Attention是著名的Transformer架构(Vaswani et al., 2017),在人工智能的各个领域都取得了巨大的成功,从NLP到蛋白质折叠。Phil Wang使用了两种注意力变体:一种是常规的multi-head self-attention(如Transformer中使用的),另一种是LinearAttention(Shen et al., 2018),其时间和内存要求在序列长度上线性缩放,而不是在常规注意力中缩放。 要想对Attention机制进行深入的了解,请参照Jay Allamar的精彩的博文。
class Attention(nn.Cell):
def __init__(self, dim, heads=4, dim_head=32):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, pad_mode='valid', has_bias=False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1, pad_mode='valid', has_bias=True)
self.map = ops.Map()
self.partial = ops.Partial()
def construct(self, x):
b, _, h, w = x.shape
qkv = self.to_qkv(x).chunk(3, 1)
q, k, v = self.map(self.partial(rearrange, self.heads), qkv)
q = q * self.scale
# 'b h d i, b h d j -> b h i j'
sim = ops.bmm(q.swapaxes(2, 3), k)
attn = ops.softmax(sim, axis=-1)
# 'b h i j, b h d j -> b h i d'
out = ops.bmm(attn, v.swapaxes(2, 3))
out = out.swapaxes(-1, -2).reshape((b, -1, h, w))
return self.to_out(out)
class LayerNorm(nn.Cell):
def __init__(self, dim):
super().__init__()
self.g = Parameter(initializer('ones', (1, dim, 1, 1)), name='g')
def construct(self, x):
eps = 1e-5
var = x.var(1, keepdims=True)
mean = x.mean(1, keep_dims=True)
return (x - mean) * rsqrt((var + eps)) * self.g
class LinearAttention(nn.Cell):
def __init__(self, dim, heads=4, dim_head=32):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, pad_mode='valid', has_bias=False)
self.to_out = nn.SequentialCell(
nn.Conv2d(hidden_dim, dim, 1, pad_mode='valid', has_bias=True),
LayerNorm(dim)
)
self.map = ops.Map()
self.partial = ops.Partial()
def construct(self, x):
b, _, h, w = x.shape
qkv = self.to_qkv(x).chunk(3, 1)
q, k, v = self.map(self.partial(rearrange, self.heads), qkv)
q = ops.softmax(q, -2)
k = ops.softmax(k, -1)
q = q * self.scale
v = v / (h * w)
# 'b h d n, b h e n -> b h d e'
context = ops.bmm(k, v.swapaxes(2, 3))
# 'b h d e, b h d n -> b h e n'
out = ops.bmm(context.swapaxes(2, 3), q)
out = out.reshape((b, -1, h, w))
return self.to_out(out)
组归一化
DDPM作者将U-Net的卷积/注意层与群归一化(Wu et al., 2018)。下面,我们定义一个PreNorm类,将用于在注意层之前应用groupnorm。
class PreNorm(nn.Cell):
def __init__(self, dim, fn):
super().__init__()
self.fn = fn
self.norm = nn.GroupNorm(1, dim)
def construct(self, x):
x = self.norm(x)
return self.fn(x)
条件U-Net
我们已经定义了所有的构建块(位置嵌入、ResNet/ConvNeXT块、Attention和组归一化),现在需要定义整个神经网络了。请记住,网络
的工作是接收一批噪声图像+噪声水平,并输出添加到输入中的噪声。
更具体的: 网络获取了一批(batch_size, num_channels, height, width)形状的噪声图像和一批(batch_size, 1)形状的噪音水平作为输入,并返回(batch_size, num_channels, height, width)形状的张量。
网络构建过程如下:
首先,将卷积层应用于噪声图像批上,并计算噪声水平的位置
接下来,应用一系列下采样级。每个下采样阶段由2个ResNet/ConvNeXT块 + groupnorm + attention + 残差连接 + 一个下采样操作组成
在网络的中间,再次应用ResNet或ConvNeXT块,并与attention交织
接下来,应用一系列上采样级。每个上采样级由2个ResNet/ConvNeXT块+ groupnorm + attention + 残差连接 + 一个上采样操作组成
最后,应用ResNet/ConvNeXT块,然后应用卷积层
最终,神经网络将层堆叠起来,就像它们是乐高积木一样(但重要的是了解它们是如何工作的)。
class Unet(nn.Cell):
def __init__(
self,
dim,
init_dim=None,
out_dim=None,
dim_mults=(1, 2, 4, 8),
channels=3,
with_time_emb=True,
convnext_mult=2,
):
super().__init__()
self.channels = channels
init_dim = default(init_dim, dim // 3 * 2)
self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3, pad_mode="pad", has_bias=True)
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
block_klass = partial(ConvNextBlock, mult=convnext_mult)
if with_time_emb:
time_dim = dim * 4
self.time_mlp = nn.SequentialCell(
SinusoidalPositionEmbeddings(dim),
nn.Dense(dim, time_dim),
nn.GELU(),
nn.Dense(time_dim, time_dim),
)
else:
time_dim = None
self.time_mlp = None
self.downs = nn.CellList([])
self.ups = nn.CellList([])
num_resolutions = len(in_out)
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (num_resolutions - 1)
self.downs.append(
nn.CellList(
[
block_klass(dim_in, dim_out, time_emb_dim=time_dim),
block_klass(dim_out, dim_out, time_emb_dim=time_dim),
Residual(PreNorm(dim_out, LinearAttention(dim_out))),
Downsample(dim_out) if not is_last else nn.Identity(),
]
)
)
mid_dim = dims[-1]
self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (num_resolutions - 1)
self.ups.append(
nn.CellList(
[
block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
block_klass(dim_in, dim_in, time_emb_dim=time_dim),
Residual(PreNorm(dim_in, LinearAttention(dim_in))),
Upsample(dim_in) if not is_last else nn.Identity(),
]
)
)
out_dim = default(out_dim, channels)
self.final_conv = nn.SequentialCell(
block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1)
)
def construct(self, x, time):
x = self.init_conv(x)
t = self.time_mlp(time) if exists(self.time_mlp) else None
h = []
for block1, block2, attn, downsample in self.downs:
x = block1(x, t)
x = block2(x, t)
x = attn(x)
h.append(x)
x = downsample(x)
x = self.mid_block1(x, t)
x = self.mid_attn(x)
x = self.mid_block2(x, t)
len_h = len(h) - 1
for block1, block2, attn, upsample in self.ups:
x = ops.concat((x, h[len_h]), 1)
len_h -= 1
x = block1(x, t)
x = block2(x, t)
x = attn(x)
x = upsample(x)
return self.final_conv(x)
正向扩散
我们已经知道正向扩散过程在多个时间步长T中,从实际分布逐渐向图像添加噪声,根据差异计划进行正向扩散。
下面,我们定义了T时间步的时间表。
def linear_beta_schedule(timesteps):
beta_start = 0.0001
beta_end = 0.02
return np.linspace(beta_start, beta_end, timesteps).astype(np.float32)