
该章节介绍VITGAN对抗生成网络中,ISN 部分的代码实现。
目录(文章发布后会补上链接):
- 网络结构简介
- Mapping NetWork 实现
- PositionalEmbedding 实现
- MLP 实现
- MSA多头注意力 实现
- SLN自调制 实现
- CoordinatesPositionalEmbedding 实现
- ModulatedLinear 实现
- Siren 实现
- Generator生成器 实现
- PatchEmbedding 实现
- ISN 实现
- Discriminator鉴别器 实现
- VITGAN 实现
ISN 简介

Improved Spectral Normalization. 为了进一步加强 Lipschitz 连续性,我们还在鉴别器训练中应用了谱归一化(SN)[35]。标准 SN 使用幂迭代来估计神经网络中每一层的投影矩阵的谱范数。然后将权重矩阵除以估计的谱范数,因此得到的投影矩阵的 Lipschitz 常数等于 1。我们发现 Transformer 块对 Lipschitz 常数的尺度很敏感,并且在使用 SN 时训练表现出非常缓慢的进展(c.f.表 3b)。同样,我们发现当使用基于 ViT 的鉴别器时,R1 梯度惩罚会削弱 GAN 训练(参见图 4)。 [14] 表明 MLP 块的小 Lipschitz 常数可能导致 Transformer 的输出崩溃为 rank-1 矩阵。为了解决这个问题,我们建议增加投影矩阵的光谱范数。
代码实现
import tensorflow as tf
class ISN(tf.keras.layers.Wrapper):
"""Performs spectral normalization on weights.
This wrapper controls the Lipschitz constant of the layer by
constraining its spectral norm, which can stabilize the training of GANs.
See [Spectral Normalization for Generative Adversarial Networks](https://arxiv.org/abs/1802.05957).
Wrap `tf.keras.layers.Conv2D`:
>>> x = np.random.rand(1, 10, 10, 1)
>>> conv2d = SpectralNormalization(tf.keras.layers.Conv2D(2, 2))
>>> y = conv2d(x)
>>> y.shape
TensorShape([1, 9, 9, 2])
Wrap `tf.keras.layers.Dense`:
>>> x = np.random.rand(1, 10, 10, 1)
>>> dense = SpectralNormalization(tf.keras.layers.Dense(10))
>>> y = dense(x)
>>> y.shape
TensorShape([1, 10, 10, 10])
Args:
layer: A `tf.keras.layers.Layer` instance that
has either `kernel` or `embeddings` attribute.
power_iterations: `int`, the number of iterations during normalization.
Raises:
AssertionError: If not initialized with a `Layer` instance.
ValueError: If initialized with negative `power_iterations`.
AttributeError: If `layer` does not has `kernel` or `embeddings` attribute.
"""
def __init__(self, layer: tf.keras.layers, power_iterations: int = 1, **kwargs):
super().__init__(layer, **kwargs)
if power_iterations <= 0:
raise ValueError(
"`power_iterations` should be greater than zero, got "
"`power_iterations={}`".format(power_iterations)
)
self.power_iterations = power_iterations
self._initialized = False
def build(self, input_shape):
"""Build `Layer`"""
super().build(input_shape)
input_shape = tf.TensorShape(input_shape)
self.input_spec = tf.keras.layers.InputSpec(shape=[None] + input_shape[1:])
if hasattr(self.layer, "kernel"):
self.w = self.layer.kernel
elif hasattr(self.layer, "embeddings"):
self.w = self.layer.embeddings
else:
raise AttributeError(
"{} object has no attribute 'kernel' nor "
"'embeddings'".format(type(self.layer).__name__)
)
self.w_shape = self.w.shape.as_list()
self.u = self.add_weight(
shape=(1, self.w_shape[-1]),
initializer=tf.initializers.TruncatedNormal(stddev=0.02),
trainable=False,
name="sn_u",
dtype=self.w.dtype,
)
def call(self, inputs, training=None):
"""Call `Layer`"""
if training is None:
training = tf.keras.backend.learning_phase()
if training:
self.normalize_weights()
output = self.layer(inputs)
return output
def normalize_weights(self):
"""Generate spectral normalized weights.
This method will update the value of `self.w` with the
spectral normalized value, so that the layer is ready for `call()`.
"""
w = tf.reshape(self.w, [-1, self.w_shape[-1]])
u = self.u
with tf.name_scope("spectral_normalize"):
for _ in range(self.power_iterations):
v = tf.math.l2_normalize(tf.matmul(u, w, transpose_b=True))
u = tf.math.l2_normalize(tf.matmul(v, w))
u = tf.stop_gradient(u)
v = tf.stop_gradient(v)
sigma = tf.matmul(tf.matmul(v, w), u, transpose_b=True)
if not self._initialized:
self.w_init_sigma = tf.constant(sigma)
self._initialized = True
self.u.assign(tf.cast(u, self.u.dtype))
self.w.assign(
tf.cast(tf.reshape(self.w_init_sigma * self.w / sigma, self.w_shape), self.w.dtype)
)
if __name__ == "__main__":
layer = ISN(tf.keras.layers.Dense(768, use_bias=False))
x = tf.random.uniform([2,1,768], dtype=tf.float32)
o1 = layer(x, training=True)
tf.print('o1:', tf.shape(o1))
o1 = layer(x, training=True)
tf.print('o1:', tf.shape(o1))
o1 = layer(x, training=False)
tf.print('o1:', tf.shape(o1))
本文介绍VITGAN中的Improved Spectral Normalization(ISN)实现细节,通过增强Lipschitz连续性来稳定GAN训练过程。适用于使用Vision Transformers作为鉴别器的场景。

被折叠的 条评论
为什么被折叠?



