在完成Glow论文学习后,对github上一个复现的仓库进行学习,帮助理解算法实现过程中的一些细节;所选择的复现仓库是基于pytorch实现,链接为https://github.com/rosinality/glow-pytorch。Glow是基于Flow的模型的,其结构很直接,数学原理确定,在定义模块时,要保证模型可以逆向运算。
本仓库中的代码主要在model.py和train.py中,结构、逻辑很清晰,具体实现与论文一致,阅读论文时可直接与代码比对学习,感兴趣的读者可阅读Normalized Glow论文阅读笔记,主要搞清楚其中的表1即可对理解整个代码构造。先主要对上述两个py文件进行注释解析,帮助学习和理解。
model.py
清晰的对论文图2中展示的Glow的各个模块进行定义
import torch
from torch import nn
from torch.nn import functional as F
from math import log, pi, exp
import numpy as np
from scipy import linalg as la
device = "gpu:0" if torch.cuda.is_available() else "cpu"
logabs = lambda x: torch.log(torch.abs(x)) # 自定义求log(|x|)的lambda函数
class ActNorm(nn.Module):
def __init__(self, in_channel, logdet=True):
super().__init__()
# 只对channel维度进行原酸,即本质上loc、scale只有num_channels个数值
self.loc = nn.Parameter(torch.zeros(1, in_channel, 1, 1)) # 相当于论文表1中的b,平移量
self.scale = nn.Parameter(torch.ones(1, in_channel, 1, 1)) # 相当于论文表1中的s,伸缩量
# self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))
# 用于判断是否是第一个batch,如果是,就调用initialize函数计算该batch内数据的均值和方差,再初始化loc和scale
self.initialized = nn.Parameter(torch.tensor(0, dtype=torch.uint8), requires_grad=False)
self.logdet = logdet
def initialize(self, input):
with torch.no_grad():
# flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
# mean = (
# flatten.mean(1)
# .unsqueeze(1)
# .unsqueeze(2)
# .unsqueeze(3)
# .permute(1, 0, 2, 3)
# )
# std = (
# flatten.std(1)
# .unsqueeze(1)
# .unsqueeze(2)
# .unsqueeze(3)
# .permute(1, 0, 2, 3)
# )
# 此处的mean,std可以通过torch.mean(input, dim=(0, 2, 3), keepdim=True)来实现
mean = torch.mean(input, dim=(0, 2, 3), keepdim=True)
std = torch.std(input, dim=(0, 2, 3), keepdim=True)
# 论文中提到的数据以来的初始化 ,目的是第一个batch的数据经过actnorm后变成标准分布,以稳定训练收敛
self.loc.data.copy_(-mean)
self.scale.data.copy_(1 / (std + 1e-6))
def forward(self, input): # 前向过程
_, _, height, width = input.shape # 整体的shape是[batch_size, num_channel, height, width]
if self.initialized.item() == 0: # 最开始为初始化的0,表示第一个batch,要进行数据依赖的初始化;完成后,用1填充,后续就不再执行
self.initialize(input)
self.initialized.fill_(1)
log_abs = logabs(self.scale) # 对scake中的每个元素值求log
logdet = height * width * torch.sum(log_abs) # 对数似然的变化量,是一个标量
if self.logdet: # 是否返回对数似然的增量
return self.scale * (input + self.loc), logdet
else:
return self.scale * (input + self.loc)
def reverse(self, output): # 逆向过程,推理
return output / self.scale - self.loc
# 可逆1X1二维卷积
class InvConv2d(nn.Module):
def __init__(self, in_channel):
super().__init__()
weight = torch.randn(in_channel, in_channel) # 初始化1X1卷积的权重参数,尺寸是[c, c],就相当于在通道上的一个MLP
q, _ = torch.qr(weight) # 对weight进行qr分解,q是一个正交的矩阵,即q的行列式不为0
# 尺寸扩为[c, c, 1, 1],因为在conv2d中要求的weight的尺寸为[input_size, output_size, kernel_size, kernel_size]
weight = q.unsqueeze(2).unsqueeze(3)
self.weight = nn.Parameter(weight)
def forward(self, input): # 前向
_, _, height, width = input.shape
out = F.conv2d(input, self.weight)
logdet = height * width * torch.slogdet(self.weight.squeeze().double())[1].float()
return out, logdet
def reverse(self, output): # 逆向
return F.conv2d(output, self.weight.squeeze().inverse().unsqueeze(2).unsqueeze(3))
# 使用LU分解的快速版可逆二维卷积
class InvConv2dLU(nn.Module):
def __init__(self, in_channel):
super().__init__()
weight = np.random.randn(in_channel, in_channel)
q, _ = la.qr(weight) # 任何矩阵都可进行qr分解 ,分解后q是正交矩阵
# 行列式不为0,LU分解一定存在;q为正交矩阵,行列式值为+1或-1,故一定可以进行LU分解
# A = P L U,P为置换矩阵,L为下三角、U为上三角;L的对角线元素一定为1
w_p, w_l, w_u = la.lu(q.astype(np.float32))
w_s = np.diag(w_u) # 获取对角线元素构成数组,一维向量
w_u = np.triu(w_u, 1) # 只保留第一条对角线的上三角矩阵,其实就是对角线元素表内0
u_mask = np.triu(np.ones_like(w_u), 1) # 只保留第一条对角线的上三角矩阵,对应元素全部为1,其他元素变为了0
l_mask = u_mask.T # 只保留第-1条对角线的下三角矩阵,对应元素为1,其他元素全部为0
# 将上述定义的向量转为张量
w_p = torch.from_numpy(w_p).to(device)
w_l = torch.from_numpy(w_l).to(device)
w_s = torch.from_numpy(w_s).to(device)
w_u = torch.from_numpy(w_u).to(device)
# 将不需要更新的量,将其注册为buffer
self.

本文详细介绍了一个基于PyTorch实现的Glow模型,并对该模型的主要组件进行了深入解析,包括ActNorm层、可逆卷积层、仿射耦合层等。
最低0.47元/天 解锁文章
2660





