BigGAN代码解读(gpt3.5帮助)——生成器部分

本文介绍了GitHub上点赞最多的BigGAN生成器代码实现,包括谱正则化(SN)和批正则化(BN)的使用,以及根据不同分辨率调整的参数。代码详细展示了如何定义生成器的架构,并根据分辨率、注意力机制和卷积核大小等设置生成不同尺寸的图像。此外,还讨论了潜在空间的处理和网络组件的选择,如使用SNConv2d和SNLinear进行谱正则化,以及在不同分辨率下添加注意力层的逻辑。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

代码来源于Github中点赞最多的BigGAN复现

作者个人学习记录

BigGAN的生成器代码内部引用了代码人员编写的谱正则化(SN)以及批正则化(BN),关于这部分的解读地址在这里:
批正则化
谱正则化

关于生成器,对于生成图片的不同分辨率,代码人员提供了不同的参数,代码如下:

# Architectures for G
# Attention is passed in in the format '32_64' to mean applying an attention
# block at both resolution 32x32 and 64x64. Just '64' will apply at 64x64.
def G_arch(ch=64, attention='64', ksize='333333', dilation='111111'):
  # 字典 不同的分辨率对应不同的结构
  arch = {}
  arch[512] = {'in_channels' :  [ch * item for item in [16, 16, 8, 8, 4, 2, 1]],
               'out_channels' : [ch * item for item in [16,  8, 8, 4, 2, 1, 1]],
               'upsample' : [True] * 7,
               'resolution' : [8, 16, 32, 64, 128, 256, 512],
               'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')])
                              for i in range(3,10)}}
  arch[256] = {'in_channels' :  [ch * item for item in [16, 16, 8, 8, 4, 2]],
               'out_channels' : [ch * item for item in [16,  8, 8, 4, 2, 1]],
               'upsample' : [True] * 6,
               'resolution' : [8, 16, 32, 64, 128, 256],
               'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')])
                              for i in range(3,9)}}
  arch[128] = {'in_channels' :  [ch * item for item in [16, 16, 8, 4, 2]],
               'out_channels' : [ch * item for item in [16, 8, 4, 2, 1]],
               'upsample' : [True] * 5,
               'resolution' : [8, 16, 32, 64, 128],
               'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')])
                              for i in range(3,8)}}
  arch[64]  = {'in_channels' :  [ch * item for item in [16, 16, 8, 4]],
               'out_channels' : [ch * item for item in [16, 8, 4, 2]],
               'upsample' : [True] * 4,
               'resolution' : [8, 16, 32, 64],
               'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')])
                              for i in range(3,7)}}
  arch[32]  = {'in_channels' :  [ch * item for item in [4, 4, 4]],
               'out_channels' : [ch * item for item in [4, 4, 4]],
               'upsample' : [True] * 3,
               'resolution' : [8, 16, 32],
               'attention' : {2**i: (2**i in [int(item) for item in attention.split('_')])
                              for i in range(3,6)}}

  return arch

这段代码定义了一个函数,该函数返回一个字典,字典中的key代表生成图像的分辨率,value则代表不同分辨率所对应的参数,其中upsample代表是否进行上采样,attention代表在哪层使用注意力机制。这个字典将输入到生成器的类中,作为参数。
接下来为生成器部分:
在参数设置阶段代码如下

class Generator(nn.Module):
  def __init__(self, G_ch=64, dim_z=128, bottom_width=4, resolution=128,
               G_kernel_size=3, G_attn='64', n_classes=1000,
               num_G_SVs=1, num_G_SV_itrs=1,
               G_shared=True, shared_dim=0, hier=False,
               cross_replica=False, mybn=False,
               G_activation=nn.ReLU(inplace=False),
               G_lr=5e-5, G_B1=0.0, G_B2=0.999, adam_eps=1e-8,
               BN_eps=1e-5, SN_eps=1e-12, G_mixed_precision=False, G_fp16=False,
               G_init='ortho', skip_init=False, no_optim=False,
               G_param='SN', norm_style='bn',
               **kwargs):
    super(Generator, self).__init__()
    # Channel width mulitplier
    self.ch = G_ch
    # Dimensionality of the latent space
    self.dim_z = dim_z
    # The initial spatial dimensions
    self.bottom_width = bottom_width
    # Resolution of the output
    self.resolution = resolution
    # Kernel size?
    self.kernel_size = G_kernel_size
    # Attention?
    self.attention = G_attn
    # number of classes, for use in categorical conditional generation
    self.n_classes = n_classes
    # Use shared embeddings?
    self.G_shared = G_shared
    # Dimensionality of the shared embedding? Unused if not using G_shared
    self.shared_dim = shared_dim if shared_dim > 0 else dim_z
    # Hierarchical latent space?
    self.hier = hier
    # Cross replica batchnorm?
    self.cross_replica = cross_replica
    # Use my batchnorm?
    self.mybn = mybn
    # nonlinearity for residual blocks
    self.activation = G_activation
    # Initialization style
    self.init = G_init
    # Parameterization style
    self.G_param = G_param
    # Normalization style
    self.norm_style = norm_style
    # Epsilon for BatchNorm?
    self.BN_eps = BN_eps
    # Epsilon for Spectral Norm?
    self.SN_eps = SN_eps
    # fp16?
    self.fp16 = G_fp16
    # Architecture dict
    self.arch = G_arch(self.ch, self.attention)[resolution]

在这里分别定义了再上一阶段引入的字典,决定生成图像的分辨率,卷积核的大小,类别数量、时候进行分层操作(有潜在空间),是否使用注意力机制,使用何种bn层、激活函数、标准化方法等等。
之后如果有潜在空间,则对噪声信号进行处理,因为在论文中提到,输入到生成器的噪声被分割然后与类别信息相组合,输入到不同的层中。代码如下:

    # If using hierarchical latents, adjust z
    if self.hier:
      # Number of places z slots into
      # 有潜在空间的情况下,这个应该就是指多层需要chunk的意思吧
      # 因为要将噪声z分割成多个chunk,这部分用于计算chunk的num与size,以及对应的噪声z维度
      self.num_slots = len(self.arch['in_channels']) + 1 # 层的数量加一
      self.z_chunk_size = (self.dim_z // self.num_slots)
      # Recalculate latent dimensionality for even splitting into chunks
      self.dim_z = self.z_chunk_size *  self.num_slots
    else:
      # 只有一层需要chunk加入,就不需要分割z了
      self.num_slots = 1
      self.z_chunk_size = 0

如果有潜在空间,可以看到这里用到了字典中的value中的’in_channels’,代表每层使用的通道数,因为每层都要有被分割的噪声块(chunk)输入,并且在网络最开始也要输入噪声,所以输入的噪声被分割成为层数+1个chunk,确定数量之后就是确定维度。
如果没有潜在空间。就是在else之后,就有一块,也没有chunk了。

下一步为选择不同的程序组件,bn层、线性层、谱正则化、embedding这些,代码如下:

    # Which convs, batchnorms, and linear layers to use
    # 选择是否使用具有谱正则化的层
    if self.G_param == 'SN':
      self.which_conv = functools.partial(layers.SNConv2d,
                          kernel_size=3, padding=1,
                          num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                          eps=self.SN_eps)
      self.which_linear = functools.partial(layers.SNLinear,
                          num_svs=num_G_SVs, num_itrs=num_G_SV_itrs,
                          eps=self.SN_eps)
    else:
      self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
      self.which_linear = nn.Linear
      
    # We use a non-spectral-normed embedding here regardless;
    # For some reason applying SN to G's embedding seems to randomly cripple(削弱) G
    # 具有SN的Embedding会使G变弱,所以使用无谱正则化的Embedding
    self.which_embedding = nn.Embedding
    bn_linear = (functools.partial(self.which_linear, bias=False) if self.G_shared
                 else self.which_embedding)
    # 选择bn层
    self.which_bn = functools.partial(layers.ccbn,
                          which_linear=bn_linear,
                          cross_replica=self.cross_replica,
                          mybn=self.mybn,
                          input_size=(self.shared_dim + self.z_chunk_size if self.G_shared
                                      else self.n_classes),
                          norm_style=self.norm_style,
                          eps=self.BN_eps)


    # Prepare model
    # If not using shared embeddings, self.shared is just a passthrough
    self.shared = (self.which_embedding(n_classes, self.shared_dim) if G_shared 
                    else layers.identity())
    # First linear layer
    # 第一个线性层,第一个参数为随机噪声被分割成的大小,第二个参数为第一层通道数*最开始的特征图大小
    self.linear = self.which_linear(self.dim_z // self.num_slots,
                                    self.arch['in_channels'][0] * (self.bottom_width **2))

在这部分中,首先选择是否使用具有谱正则化(SN)的线性层、卷积层、embedding层,在这里面从性能角度出发,代码人员对前两者使用了谱正则化操作,对embedding则不使用。接着就是定义了一个线性层和embedding层。
接着就是定义网络结构,代码如下:

    # self.blocks is a doubly-nested list of modules, the outer loop intended
    # to be over blocks at a given resolution (resblocks and/or self-attention)
    # while the inner loop is over a given block
    # 这段注释解释了self.blocks这个变量的结构和用途。self.blocks是一个双层嵌套的模块列表,
    # 外层循环用于处理给定分辨率的块(例如resblocks和自注意力模块),
    # 而内层循环则用于处理给定块的模块。
    # 因此,外层循环按照分辨率逐渐减小,内层循环按照块的顺序逐个处理模块。
    self.blocks = []
    for index in range(len(self.arch['out_channels'])):
      self.blocks += [[layers.GBlock(in_channels=self.arch['in_channels'][index],
                             out_channels=self.arch['out_channels'][index],
                             which_conv=self.which_conv,
                             which_bn=self.which_bn,
                             activation=self.activation,
                             upsample=(functools.partial(F.interpolate, scale_factor=2)
                                       if self.arch['upsample'][index] else None))]]

      # If attention on this block, attach it to the end
      if self.arch['attention'][self.arch['resolution'][index]]:
        # G_arch的参数‘attention’标志了是否在该block添加注意力
        print('Adding attention layer in G at resolution %d' % self.arch['resolution'][index])
        self.blocks[-1] += [layers.Attention(self.arch['out_channels'][index], self.which_conv)]

    # Turn self.blocks into a ModuleList so that it's all properly registered.
    self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

    # output layer: batchnorm-relu-conv.
    # Consider using a non-spectral conv here
    self.output_layer = nn.Sequential(layers.bn(self.arch['out_channels'][-1],
                                                cross_replica=self.cross_replica,
                                                mybn=self.mybn),
                                    self.activation,
                                    self.which_conv(self.arch['out_channels'][-1], 3))

这部分就是通过循环,将每一层添加到列表中,GBlock是代码人员提前定义好的网络结构,之后将列表转换为ModuleLIst,后面的部分为参数初始化以及优化器部分,不进行赘述。
在forward函数中,对噪声信号进行了处理,处理的维度在之前定义过,代码如下:

  # Note on this forward function: we pass in a y vector which has
  # already been passed through G.shared to enable easy class-wise
  # interpolation later. If we passed in the one-hot and then ran it through
  # G.shared in this forward function, it would be harder to handle.
  def forward(self, z, y):
    # If hierarchical, concatenate zs and ys
    if self.hier:
      # 如果分层的情况下,把z分成层数+1的数量,第一个保留,剩下的和类别拼接在一起。
      zs = torch.split(z, self.z_chunk_size, 1)
      z = zs[0]
      ys = [torch.cat([y, item], 1) for item in zs[1:]]
    else:
      ys = [y] * len(self.blocks)
      
    # First linear layer
    # 第一个z的chunk直接给到线性层
    h = self.linear(z)
    # Reshape
    h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width)
    
    # Loop over blocks
    # "Loop over blocks" 意味着对生成器(Generator)的每个分辨率分块进行迭代。

    # 在BigGAN的生成器中,分辨率分块是通过将特征图的空间大小逐渐放大来实现的,每个分辨率级别有一个块。
    # 因此,在代码实现中,这些块存储在 self.blocks 这个属性中,并通过 for 循环在每个块上进行迭代。
    # 每个块本身由多个生成器模块(G-Block)组成,这些模块也是通过 for 循环进行迭代。
    # 通过这种方式,可以在每个分辨率级别逐步生成越来越复杂的图像。
    for index, blocklist in enumerate(self.blocks):
      # Second inner loop in case block has multiple layers
      for block in blocklist:
        h = block(h, ys[index])
        
    # Apply batchnorm-relu-conv-tanh at output
    return torch.tanh(self.output_layer(h))

将分割好的噪声信号,除了第一块,剩余块都与类别信息的one-hout向量组合起来,然后输入到刚刚定义好的模型中,产生生成图像。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值