# The code is as follows:
from tensorflow.keras import Sequential,Model
from tensorflow.keras.layers import Dense, Reshape, Input, Flatten
from tensorflow.keras.layers import LeakyReLU, BatchNormalization
from tensorflow.keras.datasets import mnist
from tensorflow.keras.optimizers import Adam
import numpy as np
import matplotlib.pyplot as plt
class GAN():
def __init__(self):
self.latent_dim = 100
self.img_rows = 28
self.img_cols = 28
self.channel = 1
self.img_shape = (self.img_rows, self.img_cols, self.channel)
self.discriminator = self.build_discriminator()
optimizer = Adam(0.0002, 0.5)
self.discriminator.compile(loss='binary_crossentropy',
optimizer=optimizer,
metrics=['accuracy'])
self.generator = self.build_generator()
self.discriminator.trainable = False
z = Input(shape=(self.latent_dim,))
img = self.generator(z)
validity = self.discriminator(img)
self.combined = Model(z