本人自用,效果未知。目测有提升但并不大。
1.创建SpectralNorm.py
import tensorflow as tf
class SpectralNorm(tf.keras.constraints.Constraint):
def __init__(self, n_iter=5): # 超参数,据说5次就够了
self.n_iter = n_iter
def call(self, input_weights):
w = tf.reshape(input_weights, (-1, input_weights.shape[-1]))
u = tf.random.normal((w.shape[0], 1))
for _ in range(self.n_iter):
v = tf.matmul(w, u, transpose_a=True)
v /= tf.norm(v)
u = tf.matmul(w, v)
u /= tf.norm(u)
spec_norm = tf.matmul(u, tf.matmul(w, v), transpose_a=True)
return input_weights / spec_norm
2.在主程序中
from SpectralNorm import SpectralNorm
在判别器搭建网络时
def build_discriminator(self):
# 定义判别器网络
model = Sequential()