ProGAN代码实现细节(二)

这里承接上一篇关于ProGAN的论文精读笔记,记录一下其代码实现细节。

一、初始化

#导入需要的库和包
import argparse
import os
import numpy as np
import time
import datetime
import sys
import torch
import torchvision as tv

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torch.autograd import Variable
def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("Boolean value expected.")

        
parser = argparse.ArgumentParser("Train Progressively grown GAN",
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
parser.add_argument("--rec_dir", action="store", type=str2bool, default=True,
                    help="whether images stored under one folder or has a recursive dir structure",
                    required=False, )
parser.add_argument("--flip_horizontal", action="store", type=str2bool, default=True,
                    help="whether to apply mirror augmentation", required=False, )
parser.add_argument("--depth", action="store", type=int, default=5,
                    help="对应generator和discriminator的深度,depth=5适用于渐进训练到32*32网络也就是cifar-10,高分辨率的以此类推是2的n次方n=depth的分辨率", required=False, )
parser.add_argument("--num_channels", action="store", type=int, default=3,
                    help="输入图像的通道数,默认是RGB三通道", required=False, )
parser.add_argument("--latent_size", action="store", type=int, default=128,
                    help="generator与discriminator中间的隐层神经元节点个数,如果是高分辨率图像可以改为512", required=False, )
parser.add_argument("--use_eql", action="store", type=str2bool, default=True,
                    help="控制是否使用均衡学习率策略", required=False, )
parser.add_argument("--use_ema", action="store", type=str2bool, default=True,
                    help="是否使用指数滑动平均策略", required=False, )
parser.add_argument("--ema_beta", action="store", type=float, default=0.999, help="value of the ema beta",
                    required=False, )
parser.add_argument("--epochs", action="store", type=int, required=False, nargs="+",
                    default=[172 for _ in range(9)], help="Mapper network configuration", )
parser.add_argument("--batch_sizes", action="store", type=int, required=False, nargs="+",
                    default=[512, 256, 128, 128, ], help="Mapper network configuration", )
parser.add_argument("--fade_in_percentages", action="store", type=int, required=False, nargs="+",
                    default=[50 for _ in range(9)], help="百分之多少的epoch采用平滑过渡也就是渐进式训练的那个阿尔法进行平滑过渡", )
args = parser.parse_known_args()[0]
#该组参数适用于cifar-10数据集
num_epochs = [4, 8, 16, 20]  #分别对应渐进式训练的1,2,3,4层epoch数,分辨率越高训练epoch越多
fade_ins = [50, 50, 50, 50]  #平滑过度在所有epoch中占比,这里都是前一半进行过度,后一半epoch阿尔法完全变成1
gen_learning_rate = 0.003
dis_learning_rate = 0.003

二、数据加载:

#加载的数据为 256*256 的建筑正面照片和分割图
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

data_path = "../DATA/"
transforms = tv.transforms.Compose(
    [tv.transforms.ToTensor(), tv.transforms.Normalize([0.5], [0.5])]
)

trainset = tv.datasets.CIFAR10(root=data_path,
                               transform=transforms,
                               download=True)
from torch.utils.data import DataLoader, Dataset

def get_data_loader(dataset: Dataset, batch_size: int, num_workers: int = 3) -> DataLoader:
    """
    generate the data_loader from the given dataset
    Args:
        dataset: Torch dataset object
        batch_size: batch size for training
        num_workers: num of parallel readers for reading the data
    Returns: dataloader for the dataset
    """
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True,)

查看一下训练集中的图像

from torch.autograd import Variable
import matplotlib.pyplot as plt

def show_img(img, trans=True):
    if trans:
        img = np.transpose(img.detach().cpu().numpy(), (1, 2, 0))  # 把channel维度放到最后
        plt.imshow(img / 2 + 0.5)
    else:
        plt.imshow(img)
    plt.show()

print(trainset[0])
for i in range(2):
    sample = trainset[i]
    print(classes[sample[1]])
    show_img(sample[0], trans=True)

三、模型

3.1自定义层

3.1.1学习率均衡

使用动态权重缩放

from torch.nn import Conv2d, ConvTranspose2d, Linear
from torch import Tensor


class EqualizedConv2d(Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, 
                 bias=True, padding_mode="zeros", ) -> None:
        super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias,
                         padding_mode, )
        # make sure that the self.weight and self.bias are initialized according to random normal distribution
        torch.nn.init.normal_(self.weight)
        if bias:
            torch.nn.init.zeros_(self.bias)

        # define the scale for the weights:
        fan_in 
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值