Glow-pytorch复现github项目

本文详细介绍了一个基于PyTorch实现的Glow模型,并对该模型的主要组件进行了深入解析,包括ActNorm层、可逆卷积层、仿射耦合层等。

在完成Glow论文学习后,对github上一个复现的仓库进行学习,帮助理解算法实现过程中的一些细节;所选择的复现仓库是基于pytorch实现,链接为https://github.com/rosinality/glow-pytorch。Glow是基于Flow的模型的,其结构很直接,数学原理确定,在定义模块时,要保证模型可以逆向运算。

本仓库中的代码主要在model.py和train.py中,结构、逻辑很清晰,具体实现与论文一致,阅读论文时可直接与代码比对学习,感兴趣的读者可阅读Normalized Glow论文阅读笔记,主要搞清楚其中的表1即可对理解整个代码构造。先主要对上述两个py文件进行注释解析,帮助学习和理解。

model.py

清晰的对论文图2中展示的Glow的各个模块进行定义

import torch
from torch import nn
from torch.nn import functional as F
from math import log, pi, exp
import numpy as np
from scipy import linalg as la

device = "gpu:0" if torch.cuda.is_available() else "cpu"

logabs = lambda x: torch.log(torch.abs(x))  # 自定义求log(|x|)的lambda函数


class ActNorm(nn.Module):
    def __init__(self, in_channel, logdet=True):
        super().__init__()
        # 只对channel维度进行原酸,即本质上loc、scale只有num_channels个数值
        self.loc = nn.Parameter(torch.zeros(1, in_channel, 1, 1))  # 相当于论文表1中的b,平移量
        self.scale = nn.Parameter(torch.ones(1, in_channel, 1, 1))  # 相当于论文表1中的s,伸缩量

        # self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))
        # 用于判断是否是第一个batch,如果是,就调用initialize函数计算该batch内数据的均值和方差,再初始化loc和scale
        self.initialized = nn.Parameter(torch.tensor(0, dtype=torch.uint8), requires_grad=False)

        self.logdet = logdet

    def initialize(self, input):
        with torch.no_grad():
            # flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
            # mean = (
            #     flatten.mean(1)
            #         .unsqueeze(1)
            #         .unsqueeze(2)
            #         .unsqueeze(3)
            #         .permute(1, 0, 2, 3)
            # )
            # std = (
            #     flatten.std(1)
            #         .unsqueeze(1)
            #         .unsqueeze(2)
            #         .unsqueeze(3)
            #         .permute(1, 0, 2, 3)
            # )

            # 此处的mean,std可以通过torch.mean(input, dim=(0, 2, 3), keepdim=True)来实现
            mean = torch.mean(input, dim=(0, 2, 3), keepdim=True)
            std = torch.std(input, dim=(0, 2, 3), keepdim=True)

            # 论文中提到的数据以来的初始化 ,目的是第一个batch的数据经过actnorm后变成标准分布,以稳定训练收敛
            self.loc.data.copy_(-mean)
            self.scale.data.copy_(1 / (std + 1e-6))

    def forward(self, input):  # 前向过程
        _, _, height, width = input.shape  # 整体的shape是[batch_size, num_channel, height, width]

        if self.initialized.item() == 0:  # 最开始为初始化的0,表示第一个batch,要进行数据依赖的初始化;完成后,用1填充,后续就不再执行
            self.initialize(input)
            self.initialized.fill_(1)

        log_abs = logabs(self.scale)  # 对scake中的每个元素值求log

        logdet = height * width * torch.sum(log_abs)  # 对数似然的变化量,是一个标量

        if self.logdet:  # 是否返回对数似然的增量
            return self.scale * (input + self.loc), logdet

        else:
            return self.scale * (input + self.loc)

    def reverse(self, output):  # 逆向过程,推理
        return output / self.scale - self.loc


# 可逆1X1二维卷积
class InvConv2d(nn.Module):
    def __init__(self, in_channel):
        super().__init__()

        weight = torch.randn(in_channel, in_channel)  # 初始化1X1卷积的权重参数,尺寸是[c, c],就相当于在通道上的一个MLP
        q, _ = torch.qr(weight)  # 对weight进行qr分解,q是一个正交的矩阵,即q的行列式不为0
        # 尺寸扩为[c, c, 1, 1],因为在conv2d中要求的weight的尺寸为[input_size, output_size, kernel_size, kernel_size]
        weight = q.unsqueeze(2).unsqueeze(3)
        self.weight = nn.Parameter(weight)

    def forward(self, input):  # 前向
        _, _, height, width = input.shape

        out = F.conv2d(input, self.weight)
        logdet = height * width * torch.slogdet(self.weight.squeeze().double())[1].float()

        return out, logdet

    def reverse(self, output):  # 逆向
        return F.conv2d(output, self.weight.squeeze().inverse().unsqueeze(2).unsqueeze(3))


# 使用LU分解的快速版可逆二维卷积
class InvConv2dLU(nn.Module):
    def __init__(self, in_channel):
        super().__init__()

        weight = np.random.randn(in_channel, in_channel)
        q, _ = la.qr(weight)  # 任何矩阵都可进行qr分解 ,分解后q是正交矩阵

        # 行列式不为0,LU分解一定存在;q为正交矩阵,行列式值为+1或-1,故一定可以进行LU分解
        # A = P L U,P为置换矩阵,L为下三角、U为上三角;L的对角线元素一定为1
        w_p, w_l, w_u = la.lu(q.astype(np.float32))
        w_s = np.diag(w_u)  # 获取对角线元素构成数组,一维向量
        w_u = np.triu(w_u, 1)  # 只保留第一条对角线的上三角矩阵,其实就是对角线元素表内0
        u_mask = np.triu(np.ones_like(w_u), 1)  # 只保留第一条对角线的上三角矩阵,对应元素全部为1,其他元素变为了0
        l_mask = u_mask.T  # 只保留第-1条对角线的下三角矩阵,对应元素为1,其他元素全部为0
        # 将上述定义的向量转为张量
        w_p = torch.from_numpy(w_p).to(device)
        w_l = torch.from_numpy(w_l).to(device)
        w_s = torch.from_numpy(w_s).to(device)
        w_u = torch.from_numpy(w_u).to(device)

        # 将不需要更新的量,将其注册为buffer
        self.
评论 2
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值