# AutoEncoder in practice: a fully connected autoencoder on Fashion-MNIST.
import os

# The TensorFlow C++ log level has to be in the environment before TensorFlow
# is imported; otherwise messages emitted during import are not filtered.
# '2' hides INFO and WARNING output and keeps errors.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import numpy as np  # noqa: E402
import tensorflow as tf  # noqa: E402
from matplotlib import pyplot as plt  # noqa: E402
from PIL import Image  # noqa: E402
from tensorflow import keras  # noqa: E402
from tensorflow.keras import Sequential, layers  # noqa: E402

# Seed both random number generators so repeated runs are comparable.
tf.random.set_seed(22)
np.random.seed(22)

# The rest of the script is written against the TensorFlow 2.x API.
assert tf.__version__.startswith('2.')
def save_images(imgs, name):
    """Tile images into a 10 x 10 contact sheet and write it to ``name``.

    The sheet is a 280 x 280 single-channel ('L') image made of 28 x 28
    cells. Cells are filled column by column: the first ten images go down
    the left-most column, the next ten down the second column, and so on.

    Args:
        imgs: Sliceable sequence of 2-D image arrays. Only the first 100 are
            used; when fewer are supplied the remaining cells stay black.
            Assumes each array is 28 x 28 uint8, as mode 'L' expects --
            TODO confirm against callers.
        name: Destination file name, passed straight to ``Image.save``.
    """
    cell = 28
    per_side = 10
    sheet = Image.new('L', (cell * per_side, cell * per_side))

    # Slicing caps the loop at the number of cells and also tolerates a
    # shorter batch, which would otherwise raise IndexError before saving.
    for index, im in enumerate(imgs[:per_side * per_side]):
        column, row = divmod(index, per_side)
        tile = Image.fromarray(im, mode='L')
        sheet.paste(tile, (column * cell, row * cell))

    sheet.save(name)
# Hyperparameters.
h_dim = 20      # width of the encoder's final (latent) layer
batchsz = 512   # samples per mini-batch
lr = 1e-3       # learning rate handed to the optimizer

# Load Fashion-MNIST. The labels are loaded for completeness; the visible
# part of this script only uses the images.
(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()

# Convert to float32 and divide by 255 (maps 8-bit pixel values into [0, 1]).
x_train = x_train.astype(np.float32) / 255.
x_test = x_test.astype(np.float32) / 255.

# Training batches are shuffled with a buffer of five batches' worth of
# samples; test batches keep their original order.
train_db = tf.data.Dataset.from_tensor_slices(x_train)
train_db = train_db.shuffle(batchsz * 5).batch(batchsz)
test_db = tf.data.Dataset.from_tensor_slices(x_test)
test_db = test_db.batch(batchsz)
class AE(keras.Model):
    """Fully connected autoencoder.

    The encoder is Dense(256, relu) -> Dense(128, relu) -> Dense(h_dim); the
    decoder mirrors it with Dense(128, relu) -> Dense(256, relu) ->
    Dense(784). Neither final layer has an activation, so the decoder output
    is an unbounded value per output unit rather than a probability.
    """

    def __init__(self):
        super().__init__()

        # Compresses the input down to an h_dim-wide code.
        self.encoder = Sequential([
            layers.Dense(256, activation=tf.nn.relu),
            layers.Dense(128, activation=tf.nn.relu),
            layers.Dense(h_dim),
        ])

        # Expands the code back out to 784 values.
        self.decoder = Sequential([
            layers.Dense(128, activation=tf.nn.relu),
            layers.Dense(256, activation=tf.nn.relu),
            layers.Dense(784),
        ])

    def call(self, inputs, training=None):
        """Encode ``inputs`` and decode the result.

        Args:
            inputs: Batch fed to the encoder.
            training: Accepted for Keras compatibility; not used here.

        Returns:
            The decoder output for the encoded batch (last dimension 784).
        """
        h = self.encoder(inputs)
        x_hat = self.decoder(h)
        return x_hat
# Instantiate the autoencoder and create its weights for flattened inputs of
# width 784, then print the layer/parameter summary.
model = AE()
model.build(input_shape=(None, 784))
model.summary()

# `learning_rate` is the supported keyword; the older `lr` alias is
# deprecated and is warned about or rejected by newer Keras releases.
optimizer = tf.optimizers.Adam(learning_rate=lr)
for epoch in range ( 100 ) :
for step, x in enumerate ( train_db) :
x = tf. reshape( x, [ - 1 , 784 ] )
with tf. GradientTape( ) as tape:
x_rec_logits = model( x)
rec_loss = tf. losses. binary_crossentropy( x, x_rec_logits, from_logits= True )
rec_loss = tf. reduce_mean( rec_loss)
grads = tape. gradient( rec_loss, model. trainable_variables)
optimizer. apply_gradients( zip ( grads, model. trainable_variables) )
if step % 100 == 0 :
print ( epoch