[AI Notes] Part 43: Implementing the VITGAN Generative Adversarial Network in TF2, ISN Implementation

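This article covers the implementation details of Improved Spectral Normalization (ISN) in VITGAN, which stabilizes GAN training by strengthening Lipschitz continuity. It applies to setups that use Vision Transformers as the discriminator.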

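Network architecture diagram
This section covers the code implementation of the ISN component of the VITGAN generative adversarial network.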

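Contents (links will be added once the articles are published):

  1. Overview of the network architecture
  2. Mapping Network implementation
  3. PositionalEmbedding implementation
  4. MLP implementation
  5. MSA (multi-head self-attention) implementation
  6. SLN (self-modulated LayerNorm) implementation
  7. CoordinatesPositionalEmbedding implementation
  8. ModulatedLinear implementation
  9. Siren implementation
  10. Generator implementation
  11. PatchEmbedding implementation
  12. ISN implementation
  13. Discriminator implementation
  14. VITGAN implementation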

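Introduction to ISN

From the paper
Improved Spectral Normalization. To further strengthen Lipschitz continuity, we also apply spectral normalization (SN) [35] in discriminator training. Standard SN uses power iteration to estimate the spectral norm of the projection matrix of each layer in the network. It then divides the weight matrix by the estimated spectral norm, so the Lipschitz constant of the resulting projection matrix equals 1. We find that Transformer blocks are sensitive to the scale of the Lipschitz constant, and training progresses very slowly when SN is used (cf. Table 3b). Similarly, we find that the R1 gradient penalty cripples GAN training when a ViT-based discriminator is used (cf. Figure 4). [14] suggests that a small Lipschitz constant in the MLP block may cause the Transformer output to collapse to a rank-1 matrix. To resolve this, we propose increasing the spectral norm of the projection matrix.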
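To make the difference between SN and ISN concrete, here is a minimal sketch in plain Python (no TensorFlow, so it runs anywhere). It assumes the update rule that the implementation in this article computes, `self.w_init_sigma * self.w / sigma`, that is W_ISN = sigma(W_init) * W / sigma(W). The helper names `spectral_norm` and `isn_rescale` and the 2x2 toy matrices are illustrative assumptions, not part of the paper or of the layer in this article. The sketch also runs many more power iterations than the layer's default of 1, purely so the toy numbers converge.

```python
import math
import random


def matvec(m, x):
    """Multiply matrix m (a list of rows) by vector x."""
    return [sum(a * b for a, b in zip(row, x)) for row in m]


def transpose(m):
    return [list(col) for col in zip(*m)]


def l2_normalize(x, eps=1e-12):
    norm = math.sqrt(sum(a * a for a in x))
    return [a / max(norm, eps) for a in x]


def spectral_norm(w, iterations=50, seed=0):
    """Estimate the largest singular value of w by power iteration."""
    rng = random.Random(seed)
    u = l2_normalize([rng.gauss(0.0, 1.0) for _ in range(len(w[0]))])
    wt = transpose(w)
    for _ in range(iterations):
        v = l2_normalize(matvec(w, u))   # left singular vector estimate
        u = l2_normalize(matvec(wt, v))  # right singular vector estimate
    # sigma = v^T W u
    return sum(a * b for a, b in zip(v, matvec(w, u)))


def isn_rescale(w, sigma_init, iterations=50):
    """Return sigma_init * w / sigma(w); sigma_init = 1 gives standard SN."""
    sigma = spectral_norm(w, iterations)
    return [[sigma_init * a / sigma for a in row] for row in w]


w_init = [[3.0, 0.0], [0.0, 1.0]]           # toy "initial" weights
sigma_init = spectral_norm(w_init)

w_drifted = [[6.0, 0.0], [0.0, 4.0]]        # hypothetical weights after some updates
w_sn = isn_rescale(w_drifted, 1.0)          # standard SN: spectral norm forced to 1
w_isn = isn_rescale(w_drifted, sigma_init)  # ISN: spectral norm restored to sigma(W_init)

print(spectral_norm(w_sn), spectral_norm(w_isn))
```

In this toy case the SN result has a spectral norm of about 1, while the ISN result keeps the spectral norm the matrix had at initialization, which is the larger scale the paper argues Transformer blocks need.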

Code implementation

import tensorflow as tf

class ISN(tf.keras.layers.Wrapper):
    """Performs improved spectral normalization (ISN) on weights.
    Like standard spectral normalization, this wrapper divides the weight
    matrix by its estimated spectral norm. It then multiplies the result by
    the spectral norm estimated on the first training call, so the layer
    keeps the Lipschitz scale it had at initialization instead of 1.
    See [Spectral Normalization for Generative Adversarial Networks](https://arxiv.org/abs/1802.05957).
    Wrap `tf.keras.layers.Conv2D`:
    >>> import numpy as np
    >>> x = np.random.rand(1, 10, 10, 1)
    >>> conv2d = ISN(tf.keras.layers.Conv2D(2, 2))
    >>> y = conv2d(x)
    >>> y.shape
    TensorShape([1, 9, 9, 2])
    Wrap `tf.keras.layers.Dense`:
    >>> x = np.random.rand(1, 10, 10, 1)
    >>> dense = ISN(tf.keras.layers.Dense(10))
    >>> y = dense(x)
    >>> y.shape
    TensorShape([1, 10, 10, 10])
    Args:
        layer: A `tf.keras.layers.Layer` instance that
            has either a `kernel` or an `embeddings` attribute.
        power_iterations: `int`, the number of power iterations per call.
    Raises:
        AssertionError: If not initialized with a `Layer` instance.
        ValueError: If initialized with a non-positive `power_iterations`.
        AttributeError: If `layer` has neither a `kernel` nor an `embeddings` attribute.
    """

    def __init__(self, layer: tf.keras.layers.Layer, power_iterations: int = 1, **kwargs):
        super().__init__(layer, **kwargs)
        if power_iterations <= 0:
            raise ValueError(
                "`power_iterations` should be greater than zero, got "
                "`power_iterations={}`".format(power_iterations)
            )
        self.power_iterations = power_iterations
        self._initialized = False

    def build(self, input_shape):
        """Build `Layer`"""
        super().build(input_shape)
        input_shape = tf.TensorShape(input_shape)
        self.input_spec = tf.keras.layers.InputSpec(shape=[None] + input_shape[1:])

        if hasattr(self.layer, "kernel"):
            self.w = self.layer.kernel
        elif hasattr(self.layer, "embeddings"):
            self.w = self.layer.embeddings
        else:
            raise AttributeError(
                "{} object has no attribute 'kernel' nor "
                "'embeddings'".format(type(self.layer).__name__)
            )

        self.w_shape = self.w.shape.as_list()

        self.u = self.add_weight(
            shape=(1, self.w_shape[-1]),
            initializer=tf.initializers.TruncatedNormal(stddev=0.02),
            trainable=False,
            name="sn_u",
            dtype=self.w.dtype,
        )

    def call(self, inputs, training=None):
        """Call `Layer`"""
        if training is None:
            training = tf.keras.backend.learning_phase()

        if training:
            self.normalize_weights()

        output = self.layer(inputs)
        return output
        
    def normalize_weights(self):
        """Generate ISN-normalized weights.
        This method updates `self.w` in place with
        `sigma(W_init) * W / sigma(W)`, so that the layer is ready for `call()`.
        """

        w = tf.reshape(self.w, [-1, self.w_shape[-1]])
        u = self.u

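        with tf.name_scope("spectral_normalize"):
            # Power iteration: alternately refine the left (v) and right (u)
            # singular-vector estimates of the reshaped weight matrix.
            for _ in range(self.power_iterations):
                v = tf.math.l2_normalize(tf.matmul(u, w, transpose_b=True))
                u = tf.math.l2_normalize(tf.matmul(v, w))
            u = tf.stop_gradient(u)
            v = tf.stop_gradient(v)
            # sigma = v W u^T approximates the largest singular value of W.
            sigma = tf.matmul(tf.matmul(v, w), u, transpose_b=True)
            if not self._initialized:
                # Cache the first estimate as sigma(W_init). This Python-side
                # caching assumes eager execution.
                self.w_init_sigma = tf.constant(sigma)
                self._initialized = True
            # Persist u so the next call continues the power iteration.
            self.u.assign(tf.cast(u, self.u.dtype))
            # ISN update: W <- sigma(W_init) * W / sigma(W).
            self.w.assign(
                tf.cast(tf.reshape(self.w_init_sigma * self.w / sigma, self.w_shape), self.w.dtype)
            )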

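if __name__ == "__main__":
    # Smoke test: wrap a bias-free Dense layer and call it in both modes.
    layer = ISN(tf.keras.layers.Dense(768, use_bias=False))
    x = tf.random.uniform([2, 1, 768], dtype=tf.float32)
    # First training call: builds the layer and caches the initial sigma.
    o1 = layer(x, training=True)
    tf.print('o1:', tf.shape(o1))
    # Second training call: rescales the kernel again using the cached sigma.
    o1 = layer(x, training=True)
    tf.print('o1:', tf.shape(o1))
    # Inference call: the kernel is used as is, without normalization.
    o1 = layer(x, training=False)
    tf.print('o1:', tf.shape(o1))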
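The wrapper above runs only `power_iterations=1` step per call by default, yet stores `u` in the non-trainable weight `sn_u` between calls. The plain-Python sketch below illustrates why that can be enough: with a persistent `u`, one step per call keeps refining the same estimate. It is a simplified illustration under stated assumptions: the helper `power_step`, the fixed 2x2 matrix, and the 30 "training steps" are hypothetical, and unlike the real layer the weights here do not change between steps.

```python
import math
import random


def normalize(x):
    norm = math.sqrt(sum(a * a for a in x))
    return [a / norm for a in x]


def power_step(w, u):
    """Run one power-iteration step; return (sigma estimate, updated u)."""
    # v = normalize(W u)
    v = normalize([sum(a * b for a, b in zip(row, u)) for row in w])
    # W^T v, whose norm equals v^T W u_new
    wt_v = [sum(row[j] * v[i] for i, row in enumerate(w)) for j in range(len(u))]
    sigma = math.sqrt(sum(a * a for a in wt_v))
    return sigma, normalize(wt_v)


rng = random.Random(0)
w = [[2.0, 1.0], [1.0, 3.0]]  # fixed toy weight matrix
u = normalize([rng.gauss(0.0, 1.0) for _ in range(2)])  # persistent state, like `sn_u`

estimates = []
for _ in range(30):  # 30 hypothetical training steps, one iteration each
    sigma, u = power_step(w, u)
    estimates.append(sigma)

# w is symmetric positive definite, so its largest singular value
# equals its largest eigenvalue, (5 + sqrt(5)) / 2.
exact = (5.0 + math.sqrt(5.0)) / 2.0
print(estimates[0], estimates[-1], exact)
```

The first estimate is only a lower bound on the true spectral norm, and later estimates approach it as `u` aligns with the dominant singular vector. In real training the weights move a little at every step, so the estimate tracks a slowly changing target rather than converging exactly.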
