# 1. Auto-Encoder (AE)
import os
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import Sequential, layers
from PIL import Image
from matplotlib import pyplot as plt
# Seed TensorFlow's and NumPy's random generators for reproducible runs.
tf.random.set_seed(22)
np.random.seed(22)
# Ask TensorFlow's native logging to be less verbose.
# NOTE(review): this runs after `import tensorflow`, so messages emitted while
# importing are presumably not affected -- confirm, and move above the import if needed.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# Stop early unless a TensorFlow 2.x release is installed.
assert tf.__version__.startswith('2.')
def save_images(imgs, name):
    """Tile up to 100 grayscale images into a 280x280 sheet and save it.

    The sheet is a 10x10 grid of 28x28 cells filled column by column:
    ``imgs[0]`` lands in the top-left cell, ``imgs[1]`` directly below it.

    Args:
        imgs: Iterable of 2-D ``uint8`` arrays. Only the first 100 are used;
            with fewer images the remaining cells keep the sheet's initial
            fill.
        name: Path of the image file to write. Missing parent folders are
            created.
    """
    tile, side = 28, 280
    sheet = Image.new('L', (side, side))
    # Paste positions with x as the outer and y as the inner coordinate.
    corners = [(x, y) for x in range(0, side, tile) for y in range(0, side, tile)]
    # zip() stops at the shorter input, so a batch with fewer than 100 images
    # no longer raises IndexError.
    for corner, img in zip(corners, imgs):
        sheet.paste(Image.fromarray(img, mode='L'), corner)
    # Saving into a folder that does not exist fails, so create it on demand.
    out_dir = os.path.dirname(name)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    sheet.save(name)
# Hyper-parameters: latent size, batch size and learning rate.
h_dim, batchsz, lr = 20, 512, 1e-3

# Load Fashion-MNIST and convert the images to float32 divided by 255.
train_split, test_split = keras.datasets.fashion_mnist.load_data()
x_train, y_train = train_split
x_test, y_test = test_split
x_train = x_train.astype(np.float32) / 255.
x_test = x_test.astype(np.float32) / 255.

# Only the images are fed to the model; the labels are not part of the datasets.
train_db = tf.data.Dataset.from_tensor_slices(x_train).shuffle(batchsz * 5).batch(batchsz)
test_db = tf.data.Dataset.from_tensor_slices(x_test).batch(batchsz)

for images, labels in ((x_train, y_train), (x_test, y_test)):
    print(images.shape, labels.shape)
class AE(keras.Model):
    """Fully connected auto-encoder with an ``h_dim``-wide bottleneck.

    The encoder is Dense(256) -> Dense(128) -> Dense(h_dim) and the decoder
    mirrors it with Dense(128) -> Dense(256) -> Dense(784). The last layer of
    each half has no activation.
    """

    def __init__(self):
        super().__init__()
        relu = tf.nn.relu
        # Encoder: compress the input down to the latent code.
        self.encoder = Sequential([
            layers.Dense(256, activation=relu),
            layers.Dense(128, activation=relu),
            layers.Dense(h_dim),
        ])
        # Decoder: expand the latent code back to 784 outputs.
        self.decoder = Sequential([
            layers.Dense(128, activation=relu),
            layers.Dense(256, activation=relu),
            layers.Dense(784),
        ])

    def call(self, inputs, training=None):
        """Encode ``inputs`` and return the decoder output for that code.

        ``training`` is accepted for Keras compatibility and not used.
        """
        return self.decoder(self.encoder(inputs))
# Build the auto-encoder for flattened 784-value inputs and print its layers.
model = AE()
model.build(input_shape=(None, 784))
model.summary()
# `lr=` is a deprecated alias that recent Keras releases reject; the supported
# keyword is `learning_rate=`.
optimizer = tf.optimizers.Adam(learning_rate=lr)
# The per-epoch reconstructions are written here; create the folder up front.
os.makedirs('ae_images', exist_ok=True)

for epoch in range(100):
    for step, x in enumerate(train_db):
        # [b, 28, 28] -> [b, 784]
        x = tf.reshape(x, [-1, 784])
        with tf.GradientTape() as tape:
            x_rec_logits = model(x)
            # The model output is treated as logits, hence from_logits=True.
            rec_loss = tf.losses.binary_crossentropy(x, x_rec_logits, from_logits=True)
            rec_loss = tf.reduce_mean(rec_loss)
        grads = tape.gradient(rec_loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        if step % 100 == 0:
            print(epoch, step, float(rec_loss))

    # After each epoch: reconstruct the first test batch and save it as an image sheet.
    x = next(iter(test_db))
    logits = model(tf.reshape(x, [-1, 784]))
    x_hat = tf.sigmoid(logits)
    # [b, 784] -> [b, 28, 28]
    x_hat = tf.reshape(x_hat, [-1, 28, 28])
    # Only the reconstructions are saved. The original also concatenated the
    # inputs but overwrote the result on the next line, so that dead step is gone.
    x_concat = x_hat.numpy() * 255.
    x_concat = x_concat.astype(np.uint8)
    save_images(x_concat, 'ae_images/rec_epoch_%d.png' % epoch)
# 2. Variational Auto-Encoder (VAE)
import os
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import Sequential, layers
from PIL import Image
from matplotlib import pyplot as plt
# Seed TensorFlow's and NumPy's random generators for reproducible runs.
tf.random.set_seed(22)
np.random.seed(22)
# Ask TensorFlow's native logging to be less verbose.
# NOTE(review): this runs after `import tensorflow`, so messages emitted while
# importing are presumably not affected -- confirm, and move above the import if needed.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# Stop early unless a TensorFlow 2.x release is installed.
assert tf.__version__.startswith('2.')
def save_images(imgs, name):
    """Tile up to 100 grayscale images into a 280x280 sheet and save it.

    The sheet is a 10x10 grid of 28x28 cells filled column by column:
    ``imgs[0]`` lands in the top-left cell, ``imgs[1]`` directly below it.

    Args:
        imgs: Iterable of 2-D ``uint8`` arrays. Only the first 100 are used;
            with fewer images the remaining cells keep the sheet's initial
            fill.
        name: Path of the image file to write. Missing parent folders are
            created.
    """
    tile, side = 28, 280
    sheet = Image.new('L', (side, side))
    # Paste positions with x as the outer and y as the inner coordinate.
    corners = [(x, y) for x in range(0, side, tile) for y in range(0, side, tile)]
    # zip() stops at the shorter input, so a batch with fewer than 100 images
    # no longer raises IndexError.
    for corner, img in zip(corners, imgs):
        sheet.paste(Image.fromarray(img, mode='L'), corner)
    # Saving into a folder that does not exist fails, so create it on demand.
    out_dir = os.path.dirname(name)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    sheet.save(name)
# Hyper-parameters: hidden size, batch size, learning rate and latent size.
h_dim, batchsz, lr = 20, 512, 1e-3
z_dim = 10

# Load Fashion-MNIST and convert the images to float32 divided by 255.
train_split, test_split = keras.datasets.fashion_mnist.load_data()
x_train, y_train = train_split
x_test, y_test = test_split
x_train = x_train.astype(np.float32) / 255.
x_test = x_test.astype(np.float32) / 255.

# Only the images are fed to the model; the labels are not part of the datasets.
train_db = tf.data.Dataset.from_tensor_slices(x_train).shuffle(batchsz * 5).batch(batchsz)
test_db = tf.data.Dataset.from_tensor_slices(x_test).batch(batchsz)

for images, labels in ((x_train, y_train), (x_test, y_test)):
    print(images.shape, labels.shape)
class VAE(keras.Model):
    """Small fully connected variational auto-encoder.

    Encoder: Dense(128) + ReLU, then two parallel Dense(z_dim) heads giving
    the mean and the log-variance. Decoder: Dense(128) + ReLU, then
    Dense(784) with no activation.
    """

    def __init__(self):
        super().__init__()
        # Encoder layers.
        self.fc1 = layers.Dense(128)
        self.fc2 = layers.Dense(z_dim)  # mean head
        self.fc3 = layers.Dense(z_dim)  # log-variance head
        # Decoder layers.
        self.fc4 = layers.Dense(128)
        self.fc5 = layers.Dense(784)

    def encoder(self, x):
        """Return ``(mu, log_var)`` computed from ``x``."""
        h = tf.nn.relu(self.fc1(x))
        mu = self.fc2(h)
        log_var = self.fc3(h)
        return mu, log_var

    def decoder(self, z):
        """Map latent codes ``z`` to 784 un-activated outputs."""
        out = tf.nn.relu(self.fc4(z))
        out = self.fc5(out)
        return out

    def reparameterize(self, mu, log_var):
        """Sample ``z = mu + exp(log_var / 2) * eps`` with ``eps`` standard normal."""
        # Use the dynamic shape: the static `log_var.shape` contains None when
        # the batch size is unknown, which tf.random.normal cannot accept.
        eps = tf.random.normal(tf.shape(log_var), dtype=log_var.dtype)
        std = tf.exp(log_var * 0.5)
        z = mu + std * eps
        return z

    def call(self, inputs, training=None):
        """Return ``(x_hat, mu, log_var)`` for ``inputs``.

        ``training`` is accepted for Keras compatibility and not used.
        """
        mu, log_var = self.encoder(inputs)
        z = self.reparameterize(mu, log_var)
        x_hat = self.decoder(z)
        return x_hat, mu, log_var
model = VAE()
# A concrete batch size is passed here only to create the variables.
model.build(input_shape=(4, 784))
optimizer = tf.optimizers.Adam(learning_rate=lr)
# Both image sheets of every epoch are written here; create the folder up front
# so the first save does not fail.
os.makedirs('vae_images', exist_ok=True)


def _logits_to_uint8(logits):
    """Apply sigmoid, reshape to [b, 28, 28] and scale to uint8."""
    probs = tf.sigmoid(logits)
    scaled = tf.reshape(probs, [-1, 28, 28]).numpy() * 255.
    return scaled.astype(np.uint8)


for epoch in range(1000):
    for step, x in enumerate(train_db):
        # [b, 28, 28] -> [b, 784]
        x = tf.reshape(x, [-1, 784])
        with tf.GradientTape() as tape:
            x_rec_logits, mu, log_var = model(x)
            # Reconstruction term: summed over all elements, averaged over the batch.
            rec_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=x, logits=x_rec_logits)
            rec_loss = tf.reduce_sum(rec_loss) / x.shape[0]
            # KL term of the (mu, log_var) Gaussian against a standard normal.
            kl_div = -0.5 * (log_var + 1 - mu ** 2 - tf.exp(log_var))
            kl_div = tf.reduce_sum(kl_div) / x.shape[0]
            loss = rec_loss + 1. * kl_div
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        if step % 100 == 0:
            print(epoch, step, 'kl div:', float(kl_div), 'rec loss:', float(rec_loss))

    # After each epoch: decode random latent codes ...
    z = tf.random.normal((batchsz, z_dim))
    x_hat = _logits_to_uint8(model.decoder(z))
    save_images(x_hat, 'vae_images/sampled_epoch%d.png' % epoch)

    # ... and reconstruct the first test batch.
    x = next(iter(test_db))
    x = tf.reshape(x, [-1, 784])
    x_hat_logits, _, _ = model(x)
    x_hat = _logits_to_uint8(x_hat_logits)
    save_images(x_hat, 'vae_images/rec_epoch%d.png' % epoch)
# 3. GAN: generating images
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
class Generator(keras.Model):
    """DCGAN-style generator built from five transposed convolutions.

    The latent vector is reshaped to a 1x1 feature map. The first
    'valid' 4x4 layer expands it and the four stride-2 'same' layers each
    upsample further. The last layer has 3 channels and is followed by tanh.
    """

    def __init__(self):
        super().__init__()
        base_filters = 64  # renamed from `filter`, which shadowed the builtin
        self.conv1 = layers.Conv2DTranspose(base_filters * 8, 4, 1, 'valid', use_bias=False)
        self.bn1 = layers.BatchNormalization()
        self.conv2 = layers.Conv2DTranspose(base_filters * 4, 4, 2, 'same', use_bias=False)
        self.bn2 = layers.BatchNormalization()
        self.conv3 = layers.Conv2DTranspose(base_filters * 2, 4, 2, 'same', use_bias=False)
        self.bn3 = layers.BatchNormalization()
        self.conv4 = layers.Conv2DTranspose(base_filters * 1, 4, 2, 'same', use_bias=False)
        self.bn4 = layers.BatchNormalization()
        self.conv5 = layers.Conv2DTranspose(3, 4, 2, 'same', use_bias=False)

    def call(self, inputs, training=None):
        """Turn latent vectors ``[b, z_dim]`` into images with values in [-1, 1].

        Args:
            inputs: 2-D tensor of latent vectors.
            training: Forwarded to the batch-normalization layers.
        """
        # -1 lets TensorFlow infer the batch size, so the model also works when
        # the static batch dimension is None (the old code used x.shape[0]).
        x = tf.reshape(inputs, [-1, 1, 1, inputs.shape[-1]])
        # NOTE(review): this ReLU zeroes the negative half of the latent input;
        # it is kept because it was part of the original forward pass.
        x = tf.nn.relu(x)
        x = tf.nn.relu(self.bn1(self.conv1(x), training=training))
        x = tf.nn.relu(self.bn2(self.conv2(x), training=training))
        x = tf.nn.relu(self.bn3(self.conv3(x), training=training))
        x = tf.nn.relu(self.bn4(self.conv4(x), training=training))
        x = self.conv5(x)
        x = tf.tanh(x)
        return x
class Discriminator(keras.Model):
    """Convolutional discriminator that returns one logit per image.

    Five Conv2D + BatchNormalization + leaky-ReLU stages are followed by
    global average pooling, a flatten and a single-unit dense layer.
    """

    def __init__(self):
        super().__init__()
        width = 64
        self.conv1 = layers.Conv2D(width, 4, 2, 'valid', use_bias=False)
        self.bn1 = layers.BatchNormalization()
        self.conv2 = layers.Conv2D(width * 2, 4, 2, 'valid', use_bias=False)
        self.bn2 = layers.BatchNormalization()
        self.conv3 = layers.Conv2D(width * 4, 4, 2, 'valid', use_bias=False)
        self.bn3 = layers.BatchNormalization()
        self.conv4 = layers.Conv2D(width * 8, 3, 1, 'valid', use_bias=False)
        self.bn4 = layers.BatchNormalization()
        self.conv5 = layers.Conv2D(width * 16, 3, 1, 'valid', use_bias=False)
        self.bn5 = layers.BatchNormalization()
        self.pool = layers.GlobalAveragePooling2D()
        self.flatten = layers.Flatten()
        self.fc = layers.Dense(1)

    def call(self, inputs, training=None):
        """Return the un-activated score for each image in ``inputs``.

        ``training`` is forwarded to the batch-normalization layers.
        """
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
            (self.conv5, self.bn5),
        )
        features = inputs
        for conv, bn in stages:
            features = tf.nn.leaky_relu(bn(conv(features), training=training))
        pooled = self.pool(features)
        return self.fc(self.flatten(pooled))
def main():
    """Smoke test: run random tensors through both networks and print the results."""
    discriminator = Discriminator()
    generator = Generator()
    images = tf.random.normal([2, 64, 64, 3])
    noise = tf.random.normal([2, 100])
    # Discriminator output for the random images.
    print(discriminator(images))
    # Shape of the generator output for the random latent vectors.
    print(generator(noise).shape)


if __name__ == '__main__':
    main()
import os
import numpy as np
import tensorflow as tf
from tensorflow import keras
from scipy.misc import toimage
import glob
from gan import Generator, Discriminator
from dataset import make_anime_dataset
def save_result(val_out, val_block_size, image_path, color_mode):
    """Save a batch of images as one grid image.

    Images are laid out row by row, ``val_block_size`` per row. Images that
    do not fill a complete final row are left out.

    Args:
        val_out: Array of shape ``[n, h, w, c]`` with values in ``[-1, 1]``.
        val_block_size: Number of images per grid row.
        image_path: Path of the image file to write. Missing parent folders
            are created.
        color_mode: Unused; kept so existing callers keep working.

    Raises:
        ValueError: If ``val_out`` holds fewer images than one full row.
    """
    # Map [-1, 1] to [0, 255].
    pixels = ((val_out + 1.0) * 127.5).astype(np.uint8)

    n_rows = pixels.shape[0] // val_block_size
    if n_rows == 0:
        # The original failed later with an unrelated "tuple index out of range".
        raise ValueError(
            'need at least %d images to fill one row, got %d'
            % (val_block_size, pixels.shape[0]))

    rows = [
        np.concatenate(list(pixels[r * val_block_size:(r + 1) * val_block_size]), axis=1)
        for r in range(n_rows)
    ]
    final_image = np.concatenate(rows, axis=0)
    # Single-channel output is saved as a 2-D array.
    if final_image.shape[2] == 1:
        final_image = np.squeeze(final_image, axis=2)

    out_dir = os.path.dirname(image_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # scipy.misc.toimage has been removed from SciPy; Pillow's Image.fromarray
    # handles uint8 arrays directly. NOTE(review): the module-level
    # `from scipy.misc import toimage` still has to be dropped for this script
    # to import on a current SciPy.
    from PIL import Image
    Image.fromarray(final_image).save(image_path)
def celoss_ones(logits):
    """Return the mean binary cross-entropy of ``logits`` against all-ones targets."""
    targets = tf.ones_like(logits)
    per_sample = keras.losses.binary_crossentropy(targets, logits, from_logits=True)
    return tf.reduce_mean(per_sample)
def celoss_zeros(logits):
    """Return the mean binary cross-entropy of ``logits`` against all-zeros targets."""
    targets = tf.zeros_like(logits)
    per_sample = keras.losses.binary_crossentropy(targets, logits, from_logits=True)
    return tf.reduce_mean(per_sample)
def d_loss_fn(generator, discriminator, batch_z, batch_x, is_training):
    """Discriminator loss: generated images scored as 0 plus real images scored as 1.

    Args:
        generator: Model mapping ``batch_z`` to images.
        discriminator: Model mapping images to logits.
        batch_z: Latent vectors for the generated images.
        batch_x: Real images.
        is_training: Passed to both models as their second argument.
    """
    generated = generator(batch_z, is_training)
    logits_on_fake = discriminator(generated, is_training)
    logits_on_real = discriminator(batch_x, is_training)
    return celoss_zeros(logits_on_fake) + celoss_ones(logits_on_real)
def g_loss_fn(generator, discriminator, batch_z, is_training):
    """Generator loss: cross-entropy of the discriminator's logits on generated images against 1.

    ``is_training`` is passed to both models as their second argument.
    """
    generated = generator(batch_z, is_training)
    return celoss_ones(discriminator(generated, is_training))
def main():
    """Train the GAN on the anime-face images.

    Each step performs one discriminator update and one generator update.
    Every 100 steps the losses are printed and recorded and a 10x10 sheet of
    generated images is written to ``gan_images/``. Weights are saved when
    ``epoch % 10000 == 1`` and reloaded at start-up if they exist.

    Raises:
        FileNotFoundError: If no training image matches the glob patterns.
    """
    tf.random.set_seed(3333)
    np.random.seed(3333)
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    assert tf.__version__.startswith('2.')

    z_dim = 100
    epochs = 3000000
    batch_size = 64
    learning_rate = 0.0002
    is_training = True

    img_path = glob.glob(r'C:\Users\z390\Downloads\anime-faces\*\*.jpg') + \
        glob.glob(r'C:\Users\z390\Downloads\anime-faces\*\*.png')
    print('images num:', len(img_path))
    if not img_path:
        # Otherwise the first next(iter(dataset)) below fails with a bare StopIteration.
        raise FileNotFoundError('no training images found')

    dataset, img_shape, _ = make_anime_dataset(img_path, batch_size, resize=64)
    print(dataset, img_shape)
    sample = next(iter(dataset))
    print(sample.shape, tf.reduce_max(sample).numpy(),
          tf.reduce_min(sample).numpy())
    # Repeat without limit: the loop below runs `epochs` steps, so the previous
    # repeat(100) could run out of batches and raise StopIteration.
    dataset = dataset.repeat()
    db_iter = iter(dataset)

    generator = Generator()
    generator.build(input_shape=(4, z_dim))
    discriminator = Discriminator()
    discriminator.build(input_shape=(4, 64, 64, 3))
    g_optimizer = keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
    d_optimizer = keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)

    # Resume only when both checkpoints are present; a first run has none.
    if glob.glob('generator.ckpt*') and glob.glob('discriminator.ckpt*'):
        generator.load_weights('generator.ckpt')
        discriminator.load_weights('discriminator.ckpt')
        print('Loaded chpt!!')

    os.makedirs('gan_images', exist_ok=True)
    d_losses, g_losses = [], []
    for epoch in range(epochs):
        # 1. Discriminator update.
        for _ in range(1):
            batch_z = tf.random.normal([batch_size, z_dim])
            batch_x = next(db_iter)
            with tf.GradientTape() as tape:
                d_loss = d_loss_fn(generator, discriminator, batch_z, batch_x, is_training)
            grads = tape.gradient(d_loss, discriminator.trainable_variables)
            d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))

        # 2. Generator update. It needs no real images, so the batch that was
        #    fetched (and discarded) here before is no longer consumed.
        batch_z = tf.random.normal([batch_size, z_dim])
        with tf.GradientTape() as tape:
            g_loss = g_loss_fn(generator, discriminator, batch_z, is_training)
        grads = tape.gradient(g_loss, generator.trainable_variables)
        g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))

        if epoch % 100 == 0:
            print(epoch, 'd-loss:', float(d_loss), 'g-loss:', float(g_loss))
            z = tf.random.normal([100, z_dim])
            fake_image = generator(z, training=False)
            # Separate name: `img_path` is the list of training files.
            out_path = os.path.join('gan_images', 'gan-%d.png' % epoch)
            save_result(fake_image.numpy(), 10, out_path, color_mode='P')
            d_losses.append(float(d_loss))
            g_losses.append(float(g_loss))

        # Kept at loop level: nested under the `% 100 == 0` branch this
        # condition could never be true.
        if epoch % 10000 == 1:
            generator.save_weights('generator.ckpt')
            discriminator.save_weights('discriminator.ckpt')


if __name__ == '__main__':
    main()
import multiprocessing
import tensorflow as tf
def make_anime_dataset(img_paths, batch_size, resize=64, drop_remainder=True, shuffle=True, repeat=1):
    """Build a batched dataset of resized images scaled to [-1, 1].

    Args:
        img_paths: Sequence of image file paths.
        batch_size: Number of images per batch.
        resize: Side length the images are resized to.
        drop_remainder: Whether a final, smaller batch is dropped.
        shuffle: Whether the dataset is shuffled.
        repeat: Forwarded to the dataset's ``repeat``.

    Returns:
        ``(dataset, img_shape, len_dataset)`` where ``img_shape`` is
        ``(resize, resize, 3)`` and ``len_dataset`` is the number of batches
        in one pass over ``img_paths`` (``repeat`` is not factored in).
    """
    def _map_fn(img):
        img = tf.image.resize(img, [resize, resize])
        img = tf.clip_by_value(img, 0, 255)
        # [0, 255] -> [-1, 1]
        img = img / 127.5 - 1
        return img

    dataset = disk_image_batch_dataset(img_paths,
                                       batch_size,
                                       drop_remainder=drop_remainder,
                                       map_fn=_map_fn,
                                       shuffle=shuffle,
                                       repeat=repeat)
    img_shape = (resize, resize, 3)
    # When the remainder is kept, the partial last batch counts as well;
    # the plain floor division under-reported it by one.
    len_dataset, leftover = divmod(len(img_paths), batch_size)
    if leftover and not drop_remainder:
        len_dataset += 1
    return dataset, img_shape, len_dataset
def batch_dataset(dataset,
                  batch_size,
                  drop_remainder=True,
                  n_prefetch_batch=1,
                  filter_fn=None,
                  map_fn=None,
                  n_map_threads=None,
                  filter_after_map=False,
                  shuffle=True,
                  shuffle_buffer_size=None,
                  repeat=None):
    """Shuffle, filter/map, batch, repeat and prefetch ``dataset``.

    Args:
        dataset: Object offering ``shuffle``, ``filter``, ``map``, ``batch``,
            ``repeat`` and ``prefetch``.
        batch_size: Elements per batch.
        drop_remainder: Forwarded to ``batch``.
        n_prefetch_batch: Forwarded to ``prefetch``.
        filter_fn: Optional predicate for ``filter``.
        map_fn: Optional function for ``map``.
        n_map_threads: Parallel calls for ``map``; defaults to the CPU count.
        filter_after_map: If true, map first and filter second.
        shuffle: Whether to shuffle (done before map/filter).
        shuffle_buffer_size: Shuffle buffer; defaults to
            ``max(batch_size * 128, 2048)``.
        repeat: Forwarded to ``repeat``.

    Returns:
        The transformed dataset.
    """
    workers = multiprocessing.cpu_count() if n_map_threads is None else n_map_threads

    def apply_filter(ds):
        return ds.filter(filter_fn) if filter_fn else ds

    def apply_map(ds):
        return ds.map(map_fn, num_parallel_calls=workers) if map_fn else ds

    if shuffle:
        buffer_size = shuffle_buffer_size
        if buffer_size is None:
            buffer_size = max(batch_size * 128, 2048)
        dataset = dataset.shuffle(buffer_size)

    # Default order is filter -> map; `filter_after_map` swaps the two.
    stages = (apply_map, apply_filter) if filter_after_map else (apply_filter, apply_map)
    for stage in stages:
        dataset = stage(dataset)

    batched = dataset.batch(batch_size, drop_remainder=drop_remainder)
    return batched.repeat(repeat).prefetch(n_prefetch_batch)
def memory_data_batch_dataset(memory_data,
                              batch_size,
                              drop_remainder=True,
                              n_prefetch_batch=1,
                              filter_fn=None,
                              map_fn=None,
                              n_map_threads=None,
                              filter_after_map=False,
                              shuffle=True,
                              shuffle_buffer_size=None,
                              repeat=None):
    """Batch dataset of in-memory data.

    Slices ``memory_data`` into a ``tf.data.Dataset`` and passes it, with all
    remaining arguments unchanged, to ``batch_dataset``.

    Parameters
    ----------
    memory_data : nested structure of tensors/ndarrays/lists
    """
    options = dict(
        drop_remainder=drop_remainder,
        n_prefetch_batch=n_prefetch_batch,
        filter_fn=filter_fn,
        map_fn=map_fn,
        n_map_threads=n_map_threads,
        filter_after_map=filter_after_map,
        shuffle=shuffle,
        shuffle_buffer_size=shuffle_buffer_size,
        repeat=repeat,
    )
    sliced = tf.data.Dataset.from_tensor_slices(memory_data)
    return batch_dataset(sliced, batch_size, **options)
def disk_image_batch_dataset(img_paths,
                             batch_size,
                             labels=None,
                             drop_remainder=True,
                             n_prefetch_batch=1,
                             filter_fn=None,
                             map_fn=None,
                             n_map_threads=None,
                             filter_after_map=False,
                             shuffle=True,
                             shuffle_buffer_size=None,
                             repeat=None):
    """Batch dataset of images stored on disk (PNG and JPEG).

    Each path is read and decoded to a 3-channel image; ``map_fn``, when
    given, is applied to the decoded image (and labels). Everything else is
    delegated to ``memory_data_batch_dataset``.

    Parameters
    ----------
    img_paths : 1d-tensor/ndarray/list of str
    labels : nested structure of tensors/ndarrays/lists
    """
    memory_data = img_paths if labels is None else (img_paths, labels)

    def parse_fn(path, *label):
        raw = tf.io.read_file(path)
        decoded = tf.image.decode_jpeg(raw, channels=3)
        return (decoded,) + label

    def parse_then_map(*args):
        return map_fn(*parse_fn(*args))

    # Decode only, or decode followed by the caller's transformation.
    map_fn_ = parse_then_map if map_fn else parse_fn

    return memory_data_batch_dataset(memory_data,
                                     batch_size,
                                     drop_remainder=drop_remainder,
                                     n_prefetch_batch=n_prefetch_batch,
                                     filter_fn=filter_fn,
                                     map_fn=map_fn_,
                                     n_map_threads=n_map_threads,
                                     filter_after_map=filter_after_map,
                                     shuffle=shuffle,
                                     shuffle_buffer_size=shuffle_buffer_size,
                                     repeat=repeat)
# 4. WGAN
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
class Generator(keras.Model):
    """Generator: a dense projection followed by three transposed convolutions.

    The dense output is reshaped to ``[b, 3, 3, 512]``. The last transposed
    convolution has 3 channels and is followed by tanh.
    """

    def __init__(self):
        super().__init__()
        self.fc = layers.Dense(3 * 3 * 512)
        self.conv1 = layers.Conv2DTranspose(256, 3, 3, 'valid')
        self.bn1 = layers.BatchNormalization()
        self.conv2 = layers.Conv2DTranspose(128, 5, 2, 'valid')
        self.bn2 = layers.BatchNormalization()
        self.conv3 = layers.Conv2DTranspose(3, 4, 3, 'valid')

    def call(self, inputs, training=None):
        """Generate images with values in [-1, 1] from latent vectors.

        ``training`` is forwarded to the batch-normalization layers.
        """
        features = tf.reshape(self.fc(inputs), [-1, 3, 3, 512])
        features = tf.nn.leaky_relu(features)
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            features = tf.nn.leaky_relu(bn(conv(features), training=training))
        return tf.tanh(self.conv3(features))
class Discriminator(keras.Model):
    """Critic: three strided convolutions, a flatten and a single-unit dense layer.

    The first convolution has no batch normalization; the second and third do.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = layers.Conv2D(64, 5, 3, 'valid')
        self.conv2 = layers.Conv2D(128, 5, 3, 'valid')
        self.bn2 = layers.BatchNormalization()
        self.conv3 = layers.Conv2D(256, 5, 3, 'valid')
        self.bn3 = layers.BatchNormalization()
        self.flatten = layers.Flatten()
        self.fc = layers.Dense(1)

    def call(self, inputs, training=None):
        """Return one un-activated score per input image.

        ``training`` is forwarded to the batch-normalization layers.
        """
        features = tf.nn.leaky_relu(self.conv1(inputs))
        for conv, bn in ((self.conv2, self.bn2), (self.conv3, self.bn3)):
            features = tf.nn.leaky_relu(bn(conv(features), training=training))
        return self.fc(self.flatten(features))
def main():
    """Smoke test: run random tensors through both networks and print the results."""
    critic = Discriminator()
    generator = Generator()
    images = tf.random.normal([2, 64, 64, 3])
    noise = tf.random.normal([2, 100])
    # Critic output for the random images.
    print(critic(images))
    # Shape of the generator output for the random latent vectors.
    print(generator(noise).shape)


if __name__ == '__main__':
    main()
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import numpy as np
import tensorflow as tf
from tensorflow import keras
from PIL import Image
import glob
from gan import Generator, Discriminator
from dataset import make_anime_dataset
def save_result(val_out, val_block_size, image_path, color_mode):
    """Save a batch of images as one grid image.

    Images are placed row by row, ``val_block_size`` per row; images that do
    not complete a final row are left out.

    Args:
        val_out: Array of shape ``[n, h, w, c]`` with values in ``[-1, 1]``.
        val_block_size: Number of images per grid row.
        image_path: Path of the image file to write.
        color_mode: Not used by this function.
    """
    # Map [-1, 1] to [0, 255].
    pixels = ((val_out + 1.0) * 127.5).astype(np.uint8)

    row_tiles, grid_rows = [], []
    for idx, tile in enumerate(pixels):
        row_tiles.append(tile)
        # A row is complete after every `val_block_size` images.
        if (idx + 1) % val_block_size == 0:
            grid_rows.append(np.concatenate(row_tiles, axis=1))
            row_tiles = []

    final_image = np.concatenate(grid_rows, axis=0) if grid_rows else np.array([])
    # Single-channel output is saved as a 2-D array.
    if final_image.shape[2] == 1:
        final_image = np.squeeze(final_image, axis=2)
    Image.fromarray(final_image).save(image_path)
def celoss_ones(logits):
    """Return the negated mean of ``logits``."""
    mean_score = tf.reduce_mean(logits)
    return -mean_score
def celoss_zeros(logits):
    """Return the mean of ``logits``."""
    mean_score = tf.reduce_mean(logits)
    return mean_score
def gradient_penalty(discriminator, batch_x, fake_image):
    """Return the mean squared deviation of the critic's gradient norm from 1.

    The gradient is taken with respect to random per-sample interpolations
    between ``batch_x`` and ``fake_image``.

    Args:
        discriminator: Model mapping images to scores.
        batch_x: Real images, 4-D.
        fake_image: Generated images with the same shape as ``batch_x``.
    """
    n_samples = batch_x.shape[0]
    # One mixing coefficient per sample, expanded to the full image shape.
    mix = tf.random.uniform([n_samples, 1, 1, 1])
    mix = tf.broadcast_to(mix, batch_x.shape)
    interpolated = mix * batch_x + (1 - mix) * fake_image

    with tf.GradientTape() as tape:
        # The interpolation is a plain tensor, so it must be watched explicitly.
        tape.watch(interpolated)
        scores = discriminator(interpolated, training=True)
    grads = tape.gradient(scores, interpolated)

    # Flatten each sample's gradient and take its L2 norm.
    flat_grads = tf.reshape(grads, [grads.shape[0], -1])
    norms = tf.norm(flat_grads, axis=1)
    return tf.reduce_mean((norms - 1) ** 2)
def d_loss_fn(generator, discriminator, batch_z, batch_x, is_training):
    """Critic loss with gradient penalty.

    Returns:
        ``(loss, gp)`` where ``loss`` is the real-image term plus the
        generated-image term plus ``10 * gp``.
    """
    generated = generator(batch_z, is_training)
    scores_fake = discriminator(generated, is_training)
    scores_real = discriminator(batch_x, is_training)
    real_term = celoss_ones(scores_real)
    fake_term = celoss_zeros(scores_fake)
    gp = gradient_penalty(discriminator, batch_x, generated)
    return real_term + fake_term + 10. * gp, gp
def g_loss_fn(generator, discriminator, batch_z, is_training):
    """Generator loss: ``celoss_ones`` of the critic's scores on generated images.

    ``is_training`` is passed to both models as their second argument.
    """
    generated = generator(batch_z, is_training)
    return celoss_ones(discriminator(generated, is_training))
def main():
    """Train the WGAN with gradient penalty on the face images.

    Each step performs five critic updates followed by one generator update.
    Every 100 steps the losses are printed and a 10x10 sheet generated from
    a fixed set of latent vectors is written to ``images/``.

    Raises:
        FileNotFoundError: If no training image matches the glob pattern.
    """
    tf.random.set_seed(233)
    np.random.seed(233)
    assert tf.__version__.startswith('2.')

    z_dim = 100
    epochs = 3000000
    batch_size = 512
    learning_rate = 0.0005
    is_training = True

    img_path = glob.glob(r'C:\Users\Jackie\Downloads\faces\*.jpg')
    if not img_path:
        # Explicit check instead of `assert`, which is stripped under -O.
        raise FileNotFoundError('no training images found')

    dataset, img_shape, _ = make_anime_dataset(img_path, batch_size)
    print(dataset, img_shape)
    sample = next(iter(dataset))
    print(sample.shape, tf.reduce_max(sample).numpy(),
          tf.reduce_min(sample).numpy())
    dataset = dataset.repeat()
    db_iter = iter(dataset)

    generator = Generator()
    generator.build(input_shape=(None, z_dim))
    discriminator = Discriminator()
    discriminator.build(input_shape=(None, 64, 64, 3))
    # Fixed latent vectors for the preview sheets, so successive snapshots
    # show the same samples. (Previously created but never used.)
    z_sample = tf.random.normal([100, z_dim])
    g_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
    d_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)

    os.makedirs('images', exist_ok=True)
    for epoch in range(epochs):
        # 1. Five critic updates.
        for _ in range(5):
            batch_z = tf.random.normal([batch_size, z_dim])
            batch_x = next(db_iter)
            with tf.GradientTape() as tape:
                d_loss, gp = d_loss_fn(generator, discriminator, batch_z, batch_x, is_training)
            grads = tape.gradient(d_loss, discriminator.trainable_variables)
            d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))

        # 2. One generator update.
        batch_z = tf.random.normal([batch_size, z_dim])
        with tf.GradientTape() as tape:
            g_loss = g_loss_fn(generator, discriminator, batch_z, is_training)
        grads = tape.gradient(g_loss, generator.trainable_variables)
        g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))

        if epoch % 100 == 0:
            print(epoch, 'd-loss:', float(d_loss), 'g-loss:', float(g_loss),
                  'gp:', float(gp))
            fake_image = generator(z_sample, training=False)
            # Separate name: `img_path` is the list of training files.
            out_path = os.path.join('images', 'wgan-%d.png' % epoch)
            save_result(fake_image.numpy(), 10, out_path, color_mode='P')


if __name__ == '__main__':
    main()