# coding: utf-8
# From the book "TensorFlow 实战" (TensorFlow in Action).
"""Additive Gaussian noise (denoising) autoencoder trained on MNIST (TF 1.x API)."""
import numpy as np
import sklearn.preprocessing as prep
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data


def xavier_init(fan_in, fan_out, constant=1):
    """Return a Xavier/Glorot uniform initial value for a weight matrix.

    Args:
        fan_in: Number of input units (rows of the matrix).
        fan_out: Number of output units (columns of the matrix).
        constant: Multiplier applied to the sampling bound.

    Returns:
        A float32 tensor of shape ``(fan_in, fan_out)`` sampled uniformly
        between ``-constant * sqrt(6 / (fan_in + fan_out))`` and
        ``+constant * sqrt(6 / (fan_in + fan_out))``.
    """
    bound = constant * np.sqrt(6.0 / (fan_in + fan_out))
    return tf.random_uniform(
        (fan_in, fan_out), minval=-bound, maxval=bound, dtype=tf.float32
    )


class AdditiveGaussianNoiseAutoencoder(object):
    """Single-hidden-layer autoencoder that adds Gaussian noise to its input.

    Args:
        n_input: Dimensionality of each input sample.
        n_hidden: Number of hidden units.
        transfer_function: Activation applied to the hidden layer.
        optimizer: A TF 1.x optimizer instance. When ``None`` (the default) a
            fresh ``tf.train.AdamOptimizer()`` is created for this instance,
            so that instances never share one optimizer object.
        scale: Noise coefficient fed to the ``self.scale`` placeholder by the
            methods of this class.
    """

    def __init__(self, n_input, n_hidden, transfer_function=tf.nn.softplus,
                 optimizer=None, scale=0.1):
        if optimizer is None:
            # Build the default per instance rather than once at def time.
            optimizer = tf.train.AdamOptimizer()

        self.n_input = n_input
        self.n_hidden = n_hidden
        self.transfer = transfer_function
        # The noise coefficient is a placeholder so it is supplied at run time.
        self.scale = tf.placeholder(tf.float32)
        self.training_scale = scale
        self.weights = self._initialize_weights()

        # Model: input layer -> hidden layer -> reconstruction layer.
        self.x = tf.placeholder(tf.float32, [None, self.n_input])
        # The noise tensor has shape (n_input,), so a single draw is broadcast
        # across every row of the batch. It is multiplied by the fed
        # ``self.scale`` placeholder (not the constructor argument) so that the
        # value passed through feed_dict actually takes effect.
        noisy_x = self.x + self.scale * tf.random_normal((n_input,))
        self.hidden = self.transfer(
            tf.add(tf.matmul(noisy_x, self.weights['w1']), self.weights['b1'])
        )
        # The output layer has no activation (identity).
        self.reconstruction = tf.add(
            tf.matmul(self.hidden, self.weights['w2']), self.weights['b2']
        )

        # Loss: half the summed squared error between reconstruction and the
        # clean (noise-free) input.
        self.cost = 0.5 * tf.reduce_sum(
            tf.pow(tf.subtract(self.reconstruction, self.x), 2.0)
        )
        # One training step that minimizes the loss.
        self.optimizer = optimizer.minimize(self.cost)

        init = tf.global_variables_initializer()
        self.sess = tf.Session()
        self.sess.run(init)

    def _initialize_weights(self):
        """Create the trainable variables and return them in a dict.

        Keys: ``'w1'``/``'b1'`` (input -> hidden) and ``'w2'``/``'b2'``
        (hidden -> output). Only ``'w1'`` is Xavier-initialized; the rest
        start at zero.
        """
        all_weights = dict()
        # Input-to-hidden weight matrix.
        all_weights['w1'] = tf.Variable(xavier_init(self.n_input, self.n_hidden))
        # Input-to-hidden bias vector.
        all_weights['b1'] = tf.Variable(tf.zeros([self.n_hidden], dtype=tf.float32))
        # Hidden-to-output weight matrix (the decoding counterpart of w1).
        all_weights['w2'] = tf.Variable(
            tf.zeros([self.n_hidden, self.n_input], dtype=tf.float32)
        )
        # Hidden-to-output bias vector.
        all_weights['b2'] = tf.Variable(tf.zeros([self.n_input], dtype=tf.float32))
        return all_weights

    def _feed(self, X):
        """Build the feed_dict shared by every method that consumes input X."""
        return {self.x: X, self.scale: self.training_scale}

    def partial_fit(self, X):
        """Run one optimization step on batch X and return that batch's cost."""
        cost, _ = self.sess.run((self.cost, self.optimizer), feed_dict=self._feed(X))
        return cost

    def calc_total_cost(self, X):
        """Return the cost on X without running a training step."""
        return self.sess.run(self.cost, feed_dict=self._feed(X))

    def transform(self, X):
        """Return the hidden-layer activations for X."""
        return self.sess.run(self.hidden, feed_dict=self._feed(X))

    def generate(self, hidden=None):
        """Decode hidden-layer activations back into input space.

        Args:
            hidden: Array of shape ``(batch, n_hidden)``. When ``None``, a
                single standard-normal sample of shape ``(1, n_hidden)`` is
                drawn and decoded.
        """
        if hidden is None:
            # ``size`` must be a shape, not the bias variable itself.
            hidden = np.random.normal(size=(1, self.n_hidden))
        return self.sess.run(self.reconstruction, feed_dict={self.hidden: hidden})

    def reconstruct(self, X):
        """Encode and decode X in one pass (transform followed by generate)."""
        return self.sess.run(self.reconstruction, feed_dict=self._feed(X))

    def getWeights(self):
        """Return the current value of the input-to-hidden weights ``w1``."""
        return self.sess.run(self.weights['w1'])

    def getBiases(self):
        """Return the current value of the hidden-layer biases ``b1``."""
        return self.sess.run(self.weights['b1'])


mnist = input_data.read_data_sets('MNIST_data', one_hot=True)


def standard_scale(X_train, X_test):
    """Standardize both sets with a StandardScaler fitted on X_train only.

    Returns:
        The tuple ``(X_train, X_test)`` after transformation.
    """
    preprocessor = prep.StandardScaler().fit(X_train)
    X_train = preprocessor.transform(X_train)
    X_test = preprocessor.transform(X_test)
    return X_train, X_test


def get_random_block_from_data(data, batch_size):
    """Return a contiguous slice of ``batch_size`` rows at a random offset.

    ``np.random.randint`` excludes its upper bound, so ``+ 1`` is needed for
    the last valid window to be selectable and for ``len(data) == batch_size``
    to work instead of raising.
    """
    start_index = np.random.randint(0, len(data) - batch_size + 1)
    return data[start_index:(start_index + batch_size)]


X_train, X_test = standard_scale(mnist.train.images, mnist.test.images)

n_samples = int(mnist.train.num_examples)
training_epochs = 20
batch_size = 128
display_step = 1  # Print the cost every `display_step` epochs.

# Create the additive-Gaussian-noise autoencoder instance.
autoencoder = AdditiveGaussianNoiseAutoencoder(
    n_input=784,
    n_hidden=200,
    transfer_function=tf.nn.softplus,
    optimizer=tf.train.AdamOptimizer(learning_rate=0.001),
    scale=0.01,
)

# Training.
for epoch in range(training_epochs):
    avg_cost = 0.
    total_batch = int(n_samples / batch_size)
    # Batches are drawn at random offsets, so an epoch is `total_batch`
    # random blocks rather than one exact pass over the data.
    for i in range(total_batch):
        batch_xs = get_random_block_from_data(X_train, batch_size)
        # Fit on this batch.
        cost = autoencoder.partial_fit(batch_xs)
        # Accumulate the average loss.
        avg_cost += cost / n_samples * batch_size

    # Report the loss.
    if epoch % display_step == 0:
        print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(avg_cost))

# Evaluate the trained model on the test set (squared-error cost).
print("Total cost: " + str(autoencoder.calc_total_cost(X_test)))