Reference book: 《Tensorflow 实战Google深度学习框架》
I think reading Chapter 3 of it gives a clearer picture of how TensorFlow builds and trains a neural network.
1. train.py
This file defines the main function:
def main(args):
Initialization:
    # Initialization for running
    if config.save_model:
        log_dir = utils.create_log_dir(config, config_file)
        summary_writer = tf.summary.FileWriter(log_dir, network.graph)
    if config.restore_model:
        network.restore_model(config.restore_model, config.restore_scopes)

    proc_func = lambda images: preprocess(images, config, True)
    trainset.start_batch_queue(config.batch_size, proc_func=proc_func)
The config settings here all come from WarpGAN\config\default.py.
Dataset reading and initialization are handled in WarpGAN\utils\dataset.py.
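To give a feel for what these settings look like, below is a minimal sketch of the kind of fields default.py would need to satisfy the code above. The field names follow their usage in train.py; the values are illustrative placeholders, not the repo's actual defaults.

name = 'default'
num_epochs = 100                      # epochs in the main loop
epoch_size = 1000                     # steps per epoch
batch_size = 2
summary_interval = 100                # steps between display/summary writes
save_model = True
restore_model = None                  # or a checkpoint path to warm-start from
restore_scopes = None
keep_prob = 1.0                       # dropout keep probability
learning_rate_schedule = {0: 1e-4}    # assumed field: starting step -> rate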
Main loop:
    # Main Loop
    print('\nStart Training\nname: {}\n# epochs: {}\nepoch_size: {}\nbatch_size: {}\n'.format(
        config.name, config.num_epochs, config.epoch_size, config.batch_size))
    global_step = 0
    start_time = time.time()
    for epoch in range(config.num_epochs):
        if epoch == 0: test(network, config, log_dir, global_step)
        # Training
        for step in range(config.epoch_size):
            # Prepare input
            learning_rate = utils.get_updated_learning_rate(global_step, config)
            batch = trainset.pop_batch_queue()
            wl, sm, global_step = network.train(batch['images'], batch['labels'],
                                                batch['is_photo'], learning_rate, config.keep_prob)
            wl['lr'] = learning_rate
            # Display
            if step % config.summary_interval == 0:
                duration = time.time() - start_time
                start_time = time.time()
                utils.display_info(epoch, step, duration, wl)
            if config.save_model:
                summary_writer.add_summary(sm, global_step=global_step)
wl, sm, global_step = network.train(batch['images'], batch['labels'], batch['is_photo'], learning_rate, config.keep_prob)

This line is the key: it runs one training step of the network. As the loop above shows, wl is a dict of scalars that display_info prints (the current learning rate is added to it as wl['lr']), sm is a serialized summary written to TensorBoard, and global_step is the updated step counter.
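One detail worth noting: utils.get_updated_learning_rate recomputes the learning rate from global_step before every step, so the rate can change mid-training. The repo's implementation isn't reproduced here; below is a minimal sketch, assuming config carries a learning_rate_schedule dict mapping a starting step to a rate (the same assumed field as in the config sketch above).

def get_updated_learning_rate(global_step, config):
    # Sketch only: return the rate whose starting step is the largest
    # one not exceeding global_step (a piecewise-constant schedule).
    rate = None
    for start_step in sorted(config.learning_rate_schedule.keys()):
        if global_step >= start_step:
            rate = config.learning_rate_schedule[start_step]
    return rate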
2. warpgan.py
This file defines the WarpGAN network's computation graph: the forward pass and the loss functions.
Training a neural network can be summarized in the following three steps (a toy example follows the list):
1) Define the network structure and the output of the forward pass.
2) Define the loss function (computed from the forward-pass output) and the backpropagation optimization algorithm.
3) Create a session (tf.Session()) and repeatedly run the backpropagation optimization algorithm on the training data.
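As a concrete illustration of the three steps, here is a minimal toy example in the same TF1 style as WarpGAN; it fits y = x1 + x2 with a single linear layer, and every name in it is made up for the illustration:

import numpy as np
import tensorflow as tf

# 1) Network structure and forward pass
x  = tf.placeholder(tf.float32, [None, 2], name='x')
y_ = tf.placeholder(tf.float32, [None, 1], name='label')
w  = tf.Variable(tf.random_normal([2, 1], stddev=1.0))
y  = tf.matmul(x, w)                                  # forward-pass output

# 2) Loss (computed from the forward-pass output) and optimizer
loss = tf.reduce_mean(tf.square(y - y_))
train_op = tf.train.AdamOptimizer(0.001).minimize(loss)

# 3) Session: repeatedly run the optimization on the training data
X = np.random.rand(128, 2)
Y = X.sum(axis=1, keepdims=True)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(1000):
        _, cur_loss = sess.run([train_op, loss], feed_dict={x: X, y_: Y})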
def train(self, images_batch, labels_batch, switch_batch, learning_rate, keep_prob):
    # switch_batch is the boolean is_photo array passed in from train.py;
    # boolean-mask indexing splits the batch into the two domains.
    images_A = images_batch[~switch_batch]
    images_B = images_batch[switch_batch]
    labels_A = labels_batch[~switch_batch]
    labels_B = labels_batch[switch_batch]
    scales_A = np.ones((images_A.shape[0]))
    scales_B = np.ones((images_B.shape[0]))
    feed_dict = {self.images_A: images_A,
                 self.images_B: images_B,
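The split at the top of train is plain NumPy boolean-mask indexing: images_B gathers the samples where switch_batch is True and images_A the rest. A toy illustration:

import numpy as np

images_batch = np.arange(6).reshape(6, 1)          # six dummy "images"
switch_batch = np.array([True, False, True, True, False, False])
images_A = images_batch[~switch_batch]             # rows 1, 4, 5
images_B = images_batch[switch_batch]              # rows 0, 2, 3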