【分布式TensorFlow】MNIST图像分类

分布式TensorFlow——MNIST手写数字图像分类

  简介

       在本节中,我们将讨论 TensorFlow 在分布式计算中的应用。
       通俗来讲,分布式计算指的是使用超过一个计算单元的资源来执行所需的计算或实现目标。
       这背后的思想是,通过使用更多的计算能力,能够更快地训练同样的模型。
       
       我们还是以 MNIST 卷积神经网络模型的分布式训练为示例。
       我们将使用一个参数服务器和三个工作节点任务。
       为了方便实现,我们假定所有的任务都是在一台机器上本地运行的(这很容易通过 IP 地址替换 lacalhost 来适配多机器的设置)。
       

  项目设计

       处理导入值和常量:

	import tensorflow as tf
	from tensorflow.contrib import slim
	from tensorflow.examples.tutorials.mnist import input_data
	
	BATCH_SIZE = 4				# 每个小批量训练使用的示例数量
	TRANING_STEPS = 100			# 训练期间我们将使用的小批量的数量
	PRINT_EVERY = 2				# 打印信息间隔
	LOG_DIR = "./logs/parallelism"		# 日志文件目录
	DATA_DIR = "c:/tmp/data"			# MNIST数据目录

       
       定义集群:
       我们在本地玉兴所有任务。为了使用多台计算机,请使用正确的 IP 地址替换 localhost 。当然,端口 2222-2225 也是任意的(但是当使用单个机器时必然是不同的):

	parameter_servers = ["localhost:2222"]
	workers = ["localhost:2223",
	           "localhost:2224",
	           "localhost:2225"]

	cluster = tf.train.ClusterSpec({"ps": parameter_servers, "worker": workers})

       
       使用 tf.app.flags 机制来定义两个参数,当我们调用每个任务上的程序时,我们将通过命令行提供这两个参数:

	# job_name : 这对于单参数服务器来说可以是 'ps', 对于每个工作任务来说可以是 'worker'
	tf.app.flags.DEFINE_string("job_name", "", "'ps' / 'worker'")
	# task_index : 每个作业类型中的任务索引。因此,参数服务器将使用 task_index = 0 , 对于工作节点,任务索引为 0 , 1 , 2 
	tf.app.flags.DEFINE_integer("task_index", 0, "Index of task")
	FLAGS = tf.app.flags.FLAGS

       
       我们准备好在为了此任务定义的服务器的集群中使用当前任务的标识。注意,这会发生在我们运行的所有的四个任务上。每个人物都知道其身份( job_name , task_index )以及集群中其他任务(由第一个参数提供)的身份:

	server = tf.train.Server(cluster,
	                         job_name=FLAGS.job_name,
	                         task_index=FLAGS.task_index)

       
       在开始训练之前,我们定义卷积网络,并加载要使用的数据。
       这与我们前面的例子所做的相似,所以在这里不再赘述。
       为了代码的简洁,我们使用 TF-Slim :

	mnist = input_data.read_data_sets(DATA_DIR, one_hot=True)
	
	
	def net(x):
	    x_image = tf.reshape(x, [-1, 28, 28, 1])
	    net = slim.layers.conv2d(x_image, 32, [5, 5], scope='conv1')
	    net = slim.layers.max_pool2d(net, [2, 2], scope='pool1')
	    net = slim.layers.conv2d(net, 64, [5, 5], scope='conv2')
	    net = slim.layers.max_pool2d(net, [2, 2], scope='pool2')
	    net = slim.layers.flatten(net, scope='flatten')
	    net = slim.layers.fully_connected(net, 500, scope='fully_connected')
	    net = slim.layers.fully_connected(net, 10, activation_fn=None, scope='pred')
	    return net

       
       训练期间的实际过程取决于任务的类型。
       对于参数服务器,我们希望这个机制能很好地为参数提供服务。
       这需要等待请求并处理它们。

	# 服务器的 **.join()** 方法即使在所有其他任务都终止的情况下也不会终止
	# 因此当不再需要这个进程时,就必须从外部终止它。
	if FLAGS.job_name == "ps":
	    server.join()

       
       在每个工作任务中,我们定义相同的计算图:

	elif FLAGS.job_name == "worker":
	    # 我们使用 tf.train.replica_device_setter 来指定这个部分
	    # 这意味着这些 TensorFlow 变量将通过参数服务器
	    #(也就是允许我们执行分布式计算的机制)进行同步
	    with tf.device(tf.train.replica_device_setter(
	            worker_device="/job:worker/task:%d" % FLAGS.task_index,
	            cluster=cluster)):
	
	        # global_step 变量将保存跨任务训练期间的步骤总数(每个步骤索引仅发生在单个任务上)
	        # 这创建了一个时间线,以便可以从分隔开的每个任务上始终知道我们在总体情况中的位置
	        global_step = tf.get_variable('global_step', [],
	                                      initializer=tf.constant_initializer(0),
	                                      trainable=False)
	
	        # 设置占位符
	        x = tf.placeholder(tf.float32, shape=[None, 784], name="x-input")
	        y_ = tf.placeholder(tf.float32, shape=[None, 10], name="y-input")
	        y = net(x)
	
	        # 计算交叉熵
	        cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
	        # 使用 Adam 优化器
	        train_step = tf.train.AdamOptimizer(1e-4)\
	            .minimize(cross_entropy, global_step=global_step)
	
	        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
	        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
	
	        init_op = tf.global_variables_initializer()

       
       我们设置一个 Supervisor 和一个 managed_session :
       这与我们通常使用的常规会话类似,但它能够处理分布的某些部分。
       变量的初始化只能在一个任务中完成(通过 is_chief 参数指定的 “主要主管”)
       所有其他任务将等待变量初始化完成,然后继续运行。

	sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0),
	                             logdir=LOG_DIR,
	                             global_step=global_step,
	                             init_op=init_op)
	
	with sv.managed_session(server.target, config=config) as sess:

       随着会话启动,我们运行训练:

    with sv.managed_session(server.target, config=config) as sess:
        step = 0

        while not sv.should_stop() and step <= TRANING_STEPS:
            batch_x, batch_y = mnist.train.next_batch(BATCH_SIZE)

            _, acc, step = sess.run([train_step, accuracy, global_step],
                                    feed_dict={x: batch_x, y_: batch_y})

            # 每个打印间隔后,打印当前小批量的准确度
            if step % PRINT_EVERY == 0:
                print("Worker : {}, Step: {}, Accuracy (batch): {}"
                      .format(FLAGS.task_index, step, acc))

        # 最后,我们运行测试的准确度:
        test_acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})
        print("Test-Accuracy: {}".format(test_acc))

       

  完整代码

import tensorflow as tf
from tensorflow.contrib import slim
from tensorflow.examples.tutorials.mnist import input_data

BATCH_SIZE = 4
TRANING_STEPS = 100
PRINT_EVERY = 2
LOG_DIR = "./logs/parallelism"
DATA_DIR = "c:/tmp/data"
parameter_servers = ["localhost:2222"]
workers = ["localhost:2223",
           "localhost:2224",
           "localhost:2225"]
# unknown = ["localhost:2226"]

cluster = tf.train.ClusterSpec({"ps": parameter_servers, "worker": workers})

tf.app.flags.DEFINE_string("job_name", "", "'ps' / 'worker'")
tf.app.flags.DEFINE_integer("task_index", 0, "Index of task")
FLAGS = tf.app.flags.FLAGS

server = tf.train.Server(cluster,
                         job_name=FLAGS.job_name,
                         task_index=FLAGS.task_index)

mnist = input_data.read_data_sets(DATA_DIR, one_hot=True)


def net(x):
    x_image = tf.reshape(x, [-1, 28, 28, 1])
    net = slim.layers.conv2d(x_image, 32, [5, 5], scope='conv1')
    net = slim.layers.max_pool2d(net, [2, 2], scope='pool1')
    net = slim.layers.conv2d(net, 64, [5, 5], scope='conv2')
    net = slim.layers.max_pool2d(net, [2, 2], scope='pool2')
    net = slim.layers.flatten(net, scope='flatten')
    net = slim.layers.fully_connected(net, 500, scope='fully_connected')
    net = slim.layers.fully_connected(net, 10, activation_fn=None, scope='pred')
    return net


if FLAGS.job_name == "ps":
    server.join()

elif FLAGS.job_name == "worker":
    with tf.device(tf.train.replica_device_setter(
            worker_device="/job:worker/task:%d" % FLAGS.task_index,
            cluster=cluster)):

        global_step = tf.get_variable('global_step', [],
                                      initializer=tf.constant_initializer(0),
                                      trainable=False)

        x = tf.placeholder(tf.float32, shape=[None, 784], name="x-input")
        y_ = tf.placeholder(tf.float32, shape=[None, 10], name="y-input")
        y = net(x)

        cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
        train_step = tf.train.AdamOptimizer(1e-4)\
            .minimize(cross_entropy, global_step=global_step)

        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        init_op = tf.global_variables_initializer()

    sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0),
                             logdir=LOG_DIR,
                             global_step=global_step,
                             init_op=init_op)

    config = tf.ConfigProto()
    config.gpu_options.per_process_gpu_memory_fraction = 0.2
    config.gpu_options.allow_growth = True

    with sv.managed_session(server.target, config=config) as sess:
        step = 0

        while not sv.should_stop() and step <= TRANING_STEPS:
            batch_x, batch_y = mnist.train.next_batch(BATCH_SIZE)

            _, acc, step = sess.run([train_step, accuracy, global_step],
                                    feed_dict={x: batch_x, y_: batch_y})

            if step % PRINT_EVERY == 0:
                print("Worker : {}, Step: {}, Accuracy (batch): {}"
                      .format(FLAGS.task_index, step, acc))

        test_acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})
        print("Test-Accuracy: {}".format(test_acc))

    sv.stop()

       

       本文示例参考《TensorFlow学习指南——深度学习系统构建详解》第九章第三节。

       
       欢迎各位大佬交流讨论!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值