PyTorch生成式人工智能(20)——像素卷积神经网络
0. 前言
像素卷积神经网络 (Pixel Convolutional Neural Network
, PixelCNN
) 是于 2016
年提出的一种图像生成模型,其根据前面的像素预测下一个像素的概率来逐像素地生成图像,模型可以通过自回归的方式进行训练以生成图像。在本节中,将使用 PyTorch
实现 PixelCNN
模型并将其应用于图像数据生成中。
1. PixelCNN 工作原理
为了理解 PixelCNN
,我们需要介绍两个关键技术:掩码卷积层 (Masked Convolutional Layer
) 和残差块 (Residual Block
)。
1.1 掩码卷积层
卷积层可以通过应用一系列卷积核从图像中提取特征。在特定像素点处,卷积层的输出是卷积核权重与以该像素为中心的区域上的值的加权和。通过应用一系列卷积层可以检测到图像中的边缘、纹理以及在更深层的形状和高级特征。
虽然卷积层在特征检测中十分有效,但无法直接用于自回归模型,因为像素之间没有明确的顺序关系。在图像中所有像素均会被平等对待,没有像素会被视为图像的起始或结束点,这与文本数据不同,文本数据中的符号具有明确的顺序性,因此可以方便地应用循环模型,如长短期记忆网络 (Long Short-Term Memory Network
, LSTM
)。
为了能够以自回归的方式下将卷积层应用于图像生成,我们首先必须将像素进行排序,并确保卷积核只能看到前面的像素。然后,通过将由 1
和 0
组成的掩码与卷积核权重矩阵相乘,使得在每个像素处,层的输出仅受到前面像素值的影响,从而逐像素地生成图像,通过将卷积卷积核应用于当前图像来预测下一个像素的值。
首先,需要选择像素的排序方式,一种可行的方法是从左上到右下对像素进行排序,首先沿行移动,然后沿列移动。
然后,我们对卷积核进行掩码处理,以使得每个像素处的层的输出仅受到前面的像素值的影响。为此,我们将由 1
和 0
组成的掩码与卷积核权重矩阵相乘,将目标像素后面的其余像素的值置零。
在 PixelCNN
中实际上有两种不同类型的掩码:
A
型,中心像素的值为掩码像素B
型,中心像素的值不为掩码像素
初始的掩码卷积层(即直接应用于输入图像的层)不能使用中心像素,因为这恰是我们希望网络预测的像素,而后续的层可以使用中心像素,因为它已经由初始输入图像之前的像素信息计算出来。
使用 PyToch
构建掩码卷积层 (MaskedConvLayer
):
class MaskedConv2d(nn.Conv2d):
def __init__(self, mask_type, *args, **kwargs):
super().__init__(*args, **kwargs)
self.register_buffer('mask', self.weight.data.clone())
_, _, h, w = self.mask.shape
self.mask.fill_(0)
self.mask[:, :, :h//2, :] = 1 # 上半部可见
self.mask[:, :, h//2, :w//2 + (mask_type=='B')] = 1 # 中心行左侧
def forward(self, x):
self.weight.data *= self.mask # 应用掩码
return super().forward(x)
需要注意的是,我们假设使用灰度图像(即只有一个通道)。如果我们使用彩色图像,则可以对三个颜色通道进行排序,例如红色通道在蓝色通道之前,蓝色通道在绿色通道之前。
1.2 残差块
我们已经学习了如何对卷积层进行掩码处理,接下来开始构建 PixelCNN
,我们将使用残差块 (Residual Block
) 作为核心构建块。
残差块是一组网络层,包含两个主要部分:
- 主路径 (
Main Path
):由一系列卷积层和激活函数构成,用于学习特征表示 - 跳跃连接 (
Skip Connection
):直接将输入信息绕过一部分主路径,与输出相加。这样可以确保输入信息更容易传播到后续层,并且有助于避免梯度消失问题
也就是说,在残差块中,输入有一条捷径连接到输出,而无需经过中间层。跳跃连接的理论基础可以描述为,如果最优的变换是保持输入不变,那么通过简单地将中间层的权重置零就可以实现;如果没有跳跃连接,网络就必须通过中间层找到一个恒等映射,这显然更加困难。使用 PyTorch
构建残差块 (ResidualBlock
):
class ResidualBlock(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.conv = nn.Sequential(
MaskedConv2d('B', in_channels, in_channels//2, 1),
nn.ReLU(),
MaskedConv2d('B', in_channels//2, in_channels//2, 3, padding=1),
nn.ReLU(),
MaskedConv2d('B', in_channels//2, in_channels, 1)
)
def forward(self, x):
return x + self.conv(x)
2. PixelCNN 分析
在 PixelCNN
中,输出层是一个具有 256
个卷积核的 Conv2D
层,使用 softmax
激活函数。换句话说,网络通过预测正确的像素值来尝试重新创建其输入,类似于编码器。不同之处在于,网络采用了 MaskedConv2D
层,像素预测使用的像素信息并不相同。
使用这种方法,PixelCNN
必须独立学习每个像素的输出值,但像素值 220
与 221
的差异并不明显,这意味着即使对于最简单的数据集,训练速度也可能非常慢。因此,为了加快新图像的生成速度,通常将图像尺寸缩放进行缩放,同时将 PixelCNN
的输出减少到较小的像素级别。
为了生成新图像,我们需要让模型根据前面的所有像素逐个像素地预测下一个像素。与变分自编码器等模型相比,生成过程非常缓慢。对于一个尺寸 32×32
的灰度图像,我们需要使用模型进行 1024
次连续预测,而对于变分自编码器 (Variational Autoencoder
, VAE
),我们只需要进行一次预测。这是自回归模型(如 PixelCNN
)的一个主要缺点,由于采样过程的顺序性,进行采样的速度较慢。
3. 使用混合分布改进 PixelCNN
3.1 模型构建
通常,为了避免训练速度过慢的问题,令网络不必学习 256
个独立像素值上的分布,我们将 PixelCNN
的输出减少到较小的像素级别。但是,这种方法并非最佳解决方案,对于彩色图像,我们不希望仅使用少数几种颜色。
为了解决这一问题,我们可以将网络的输出设为混合分布 (Mixture Distribution
),而不是对 256
个离散像素值使用 softmax
,混合分布简单来说就是两个或多个其他概率分布的混合。例如,我们可以使用一个由五个逻辑分布组成的混合分布,每个分布都有不同的参数。混合分布还需要离散分类分布,用于指示选择混合中包含的每个分布的概率。
要从混合分布中进行采样,我们首先从分类分布中进行采样,选择一个特定的子分布,然后按照正常的方式从该子分布中进行采样。这样,我们可以用相对较少的参数创建复杂的分布:
# PixelCNN模型(混合高斯分布输出)
class PixelCNN(nn.Module):
def __init__(self, n_mix=5):
super().__init__()
self.n_mix = n_mix
# 初始卷积层(Type A掩码)
self.input_conv = MaskedConv2d('A', 3, 64, 7)
# 残差块堆叠
self.res_blocks = nn.Sequential(
*[ResidualBlock(64) for _ in range(N_RESBLOCK)]
)
# 输出层(混合高斯参数)
self.out_conv = nn.Sequential(
nn.ReLU(),
MaskedConv2d('B', 64, 256, 1)
)
# 混合分布参数生成
self.mixture_conv = nn.Conv2d(256, 3*3*n_mix, 1) # 每个像素产生3*n_mix参数(RGB)
def forward(self, x):
# 输入x范围[-1,1],转换为[0,1]
x = (x + 1) / 2
# 前向传播
x = self.input_conv(x)
x = self.res_blocks(x)
x = self.out_conv(x)
# 生成混合分布参数
params = self.mixture_conv(x) # [B, 3*3*n_mix, H, W]
params = params.view(-1, 3, 3*self.n_mix, 32, 32) # [B, 3, 3*n_mix, H, W]
# 拆分参数
logit_probs = params[:, :, :self.n_mix, :, :] # 混合系数
means = params[:, :, self.n_mix:2*self.n_mix, :, :] # 均值
log_scales = params[:, :, 2*self.n_mix:, :, :] # 对数标准差
# 应用约束
log_scales = torch.clamp(log_scales, min=-7.0) # 防止数值不稳定
return logit_probs, means, log_scales
定义损失函数:
# 混合高斯分布损失函数
def mixture_loss(x, logit_probs, means, log_scales):
# 将输入数据转换为[0,1]范围
x = (x + 1) / 2 # 从[-1,1]转换到[0,1]
# 计算概率密度
scales = torch.exp(log_scales)
x = x.unsqueeze(2) # 添加混合维度
# 计算对数概率
log_probs = -0.5 * ((x - means) / scales)**2 - log_scales - 0.5 * math.log(2*math.pi)
# 混合系数处理
log_probs = log_probs + F.log_softmax(logit_probs, dim=2)
# 对数似然
log_likelihood = torch.logsumexp(log_probs, dim=2)
return -log_likelihood.mean()
3.2 模型训练
接下来,使用 CIFAR-10
数据集训练 PixelCNN
模型。
(1) 使用 torchvision
库加载 CIFAR-10
数据集,并进行预处理:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import math
# 数据预处理
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 转换为[-1,1]范围
])
# 加载CIFAR-10数据集
train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
(2) 定义超参数,并实例化 PixelCNN
模型和优化器:
# 超参数配置
BATCH_SIZE = 64
EPOCHS = 100
LR = 3e-4
N_RESBLOCK = 5
N_MIX = 5 # 混合高斯分布数量
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 初始化模型和优化器
model = PixelCNN(n_mix=N_MIX).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
(3) 执行模型训练 100
个 epoch
:
# 训练循环
for epoch in range(EPOCHS):
model.train()
total_loss = 0
for batch, (images, _) in enumerate(train_loader):
images = images.to(DEVICE)
# 前向传播
logit_probs, means, log_scales = model(images)
# 计算损失
loss = mixture_loss(images, logit_probs, means, log_scales)
# 反向传播
optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
total_loss += loss.item()
if batch % 100 == 0:
print(f"Epoch [{epoch+1}/{EPOCHS}] Batch [{batch}/{len(train_loader)}] "
f"Loss: {loss.item():.4f}")
avg_loss = total_loss / len(train_loader)
print(f"Epoch [{epoch+1}/{EPOCHS}] Average Loss: {avg_loss:.4f}")
(4) 为了显示生成结果,定义辅助函数 generate()
:
# 生成函数
def generate(model, n_samples=32):
model.eval()
with torch.no_grad():
samples = torch.zeros(n_samples, 3, 32, 32).to(DEVICE) - 1
for i in range(32):
for j in range(32):
for c in range(3):
# 前向传播获取参数
logit_probs, means, log_scales = model(samples)
# 处理混合系数
channel_probs = F.softmax(logit_probs[:, c, :, i, j], dim=-1)
component = torch.multinomial(channel_probs, 1).squeeze(1)
# 索引参数并保持维度
batch_idx = torch.arange(samples.size(0)).to(DEVICE)
mean = means[batch_idx, c, component, i, j]
scale = torch.exp(log_scales[batch_idx, c, component, i, j])
# 采样并限制范围
pixel = torch.normal(mean, scale)
pixel = torch.clamp(pixel, 0.0, 1.0) * 2 - 1
# 确保正确赋值
samples[:, c, i, j] = pixel
return samples
(5) 使用训练完成的模型生成图像:
generated = generate(model)
torchvision.utils.save_image(generated, "generated_cifar10.png", nrow=16, normalize=True)
生成的图像看起来不自然,虽然生成了有趣的图像,但并没有学习到训练数据集的自然图像的结构。这是由于模型的搜索效率低下的原因,为了便于模型学习,可以使用量化技术,将 CIFAR-10
图像从每个像素的原始 256
个强度值量化为每个像素 8
个强度值。
小结
在本节中,介绍了如何使用 PixelCNN
以自回归的方式生成图像,使用 PyToch
构建 PixelCNN
模型,实现掩码卷积层和残差块,以便信息可以在网络中传递,只有前面的像素可以用于生成当前的像素。
系列链接
PyTorch生成式人工智能实战:从零打造创意引擎
PyTorch生成式人工智能(1)——神经网络与模型训练过程详解
PyTorch生成式人工智能(2)——PyTorch基础
PyTorch生成式人工智能(3)——使用PyTorch构建神经网络
PyTorch生成式人工智能(4)——卷积神经网络详解
PyTorch生成式人工智能(5)——分类任务详解
PyTorch生成式人工智能(6)——生成模型(Generative Model)详解
PyTorch生成式人工智能(7)——生成对抗网络实践详解
PyTorch生成式人工智能(8)——深度卷积生成对抗网络
PyTorch生成式人工智能(9)——Pix2Pix详解与实现
PyTorch生成式人工智能(10)——CyclelGAN详解与实现
PyTorch生成式人工智能(11)——神经风格迁移
PyTorch生成式人工智能(12)——StyleGAN详解与实现
PyTorch生成式人工智能(13)——WGAN详解与实现
PyTorch生成式人工智能(14)——条件生成对抗网络(conditional GAN,cGAN)
PyTorch生成式人工智能(15)——自注意力生成对抗网络(Self-Attention GAN, SAGAN)
PyTorch生成式人工智能(16)——自编码器(AutoEncoder)详解
PyTorch生成式人工智能(17)——变分自编码器详解与实现
PyTorch生成式人工智能(18)——循环神经网络详解与实现
PyTorch生成式人工智能(19)——自回归模型详解与实现