这里承接上一篇关于ProGAN的论文精读笔记,记录一下其代码实现细节。
一、初始化
#导入需要的库和包
import argparse
import os
import numpy as np
import time
import datetime
import sys
import torch
import torchvision as tv
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torch.autograd import Variable
def str2bool(v):
    """Parse a command-line token into a boolean.

    Accepts common truthy/falsy spellings case-insensitively; an actual
    bool passes through unchanged. Any other value raises
    argparse.ArgumentTypeError so argparse emits a clean usage error.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
# Command-line interface for training the progressively-grown GAN.
# Every flag is optional; the defaults reproduce the CIFAR-10 setup
# described in the surrounding notes.
parser = argparse.ArgumentParser(
    "Train Progressively grown GAN",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
_add = parser.add_argument
_add("--rec_dir", action="store", type=str2bool, default=True, required=False,
     help="whether images stored under one folder or has a recursive dir structure")
_add("--flip_horizontal", action="store", type=str2bool, default=True, required=False,
     help="whether to apply mirror augmentation")
_add("--depth", action="store", type=int, default=5, required=False,
     help="对应generator和discriminator的深度,depth=5适用于渐进训练到32*32网络也就是cifar-10,高分辨率的以此类推是2的n次方n=depth的分辨率")
_add("--num_channels", action="store", type=int, default=3, required=False,
     help="输入图像的通道数,默认是RGB三通道")
_add("--latent_size", action="store", type=int, default=128, required=False,
     help="generator与discriminator中间的隐层神经元节点个数,如果是高分辨率图像可以改为512")
_add("--use_eql", action="store", type=str2bool, default=True, required=False,
     help="控制是否使用均衡学习率策略")
_add("--use_ema", action="store", type=str2bool, default=True, required=False,
     help="是否使用指数滑动平均策略")
_add("--ema_beta", action="store", type=float, default=0.999, required=False,
     help="value of the ema beta")
_add("--epochs", action="store", type=int, required=False, nargs="+",
     default=[172 for _ in range(9)],
     help="Mapper network configuration")
_add("--batch_sizes", action="store", type=int, required=False, nargs="+",
     default=[512, 256, 128, 128, ],
     help="Mapper network configuration")
_add("--fade_in_percentages", action="store", type=int, required=False, nargs="+",
     default=[50 for _ in range(9)],
     help="百分之多少的epoch采用平滑过渡也就是渐进式训练的那个阿尔法进行平滑过渡")
# parse_known_args keeps unrecognized flags (e.g. notebook/runner args)
# from aborting the script; we only keep the namespace.
args = parser.parse_known_args()[0]
# This parameter set is tuned for the CIFAR-10 dataset.
num_epochs = [4, 8, 16, 20] # epochs for progressive-training stages 1-4; higher resolutions get more epochs
fade_ins = [50, 50, 50, 50] # percentage of each stage's epochs spent fading in: alpha ramps over the first half, then stays at 1
gen_learning_rate = 0.003
dis_learning_rate = 0.003
二、数据加载:
# Load the CIFAR-10 training set (32*32 RGB images).
# NOTE(review): the original comment claimed 256*256 building-facade
# photos, but the code below clearly loads CIFAR-10 — comment corrected.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')  # label names indexed by class id
data_path = "../DATA/"
# NOTE(review): this assignment shadows the `torchvision.transforms`
# module imported above as `transforms`; consider renaming to `transform`.
transforms = tv.transforms.Compose(
    [tv.transforms.ToTensor(), tv.transforms.Normalize([0.5], [0.5])]  # scale to [0,1], then normalize to [-1,1]
)
trainset = tv.datasets.CIFAR10(root=data_path,
                               transform=transforms,
                               download=True)
from torch.utils.data import DataLoader, Dataset
def get_data_loader(dataset: Dataset, batch_size: int, num_workers: int = 3) -> DataLoader:
    """Wrap *dataset* in a shuffling DataLoader.

    Args:
        dataset: torch Dataset object to draw samples from.
        batch_size: number of samples per training batch.
        num_workers: parallel worker processes used to read the data.

    Returns:
        DataLoader that reshuffles every epoch and drops the final
        incomplete batch (drop_last=True keeps batch shapes constant).
    """
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=True,
    )
    return loader
查看一下训练集中的图像
from torch.autograd import Variable
import matplotlib.pyplot as plt
def show_img(img, trans=True):
    """Display one image with matplotlib.

    Args:
        img: when trans=True, a (C, H, W) tensor as produced by the
            dataset transform; otherwise an array already in a
            displayable layout.
        trans: when True, move channels last and undo the
            Normalize([0.5], [0.5]) step before plotting.
    """
    if not trans:
        plt.imshow(img)
        plt.show()
        return
    hwc = np.transpose(img.detach().cpu().numpy(), (1, 2, 0))  # channels last
    plt.imshow(hwc / 2 + 0.5)  # map [-1, 1] back to [0, 1] for display
    plt.show()
# Peek at the raw dataset: dump the first record, then display the
# first two training images together with their class names.
print(trainset[0])
for idx in range(2):
    image, label = trainset[idx]
    print(classes[label])
    show_img(image, trans=True)
三、模型
3.1自定义层
3.1.1学习率均衡
使用动态权重缩放
from torch.nn import Conv2d, ConvTranspose2d, Linear
from torch import Tensor
class EqualizedConv2d(Conv2d):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1,
bias=True, padding_mode="zeros", ) -> None:
super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias,
padding_mode, )
# make sure that the self.weight and self.bias are initialized according to random normal distribution
torch.nn.init.normal_(self.weight)
if bias:
torch.nn.init.zeros_(self.bias)
# define the scale for the weights:
fan_in