tensorflow44《TensorFlow实战》笔记-09-03 分布式并行 code

本文介绍了一个使用TensorFlow进行分布式训练的实例,通过设置不同的任务角色(参数服务器和工作节点),在多个机器上实现模型的并行训练。文章详细展示了如何配置集群、定义模型以及执行训练过程。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

未测试

# 《TensorFlow实战》09 TensorBoard、多GPU并行及分布式并行
# win10 Tensorflow1.0.1 python3.5.3
# CUDA v8.0 cudnn-8.0-windows10-x64-v5.1
# filename:sz09.03.py # 分布式并行

# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/dist_test/python/mnist_replica.py
# tensorflow\tensorflow\tools\dist_test\python\mnist_replica.py

'''
# 未测试
sz09.03.py==distributed.py
测试方法:
异步模式(在 192.168.233.201、192.168.233.202、192.168.233.203 三台机器上各执行下面的一条命令):
python distributed.py --job_name=ps --task_index=0
python distributed.py --job_name=worker --task_index=0
python distributed.py --job_name=worker --task_index=1

同步模式(同样在上述三台机器上各执行下面的一条命令):
python distributed.py --job_name=ps --task_index=0 --sync_replicas=True
python distributed.py --job_name=worker --task_index=0 --sync_replicas=True
python distributed.py --job_name=worker --task_index=1 --sync_replicas=True
'''
import math
import tempfile
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Command-line flags for the distributed MNIST example. Flag names, types and
# help strings are part of the script's command-line interface and are kept
# exactly as they were (including the historical `replicats_to_aggregate`
# spelling, which `main` reads under that same name).
flags = tf.app.flags

# Data location and model / optimisation hyper-parameters.
flags.DEFINE_string("data_dir", "MNIST_data", "Directory for storing mnist data")
flags.DEFINE_integer("hidden_units", 100, "Number of units in the hiddent layers of the NN")
flags.DEFINE_integer("train_steps", 1000000, "Number of (global) training steps to perform")
flags.DEFINE_integer("batch_size", 100, "Training batch size")
flags.DEFINE_float("learning_rate", 0.01, "Learning rate")

# Synchronous-replica training options.
flags.DEFINE_boolean("sync_replicas", False,
                     "Use the sync_replicas (synchronized replicas) mode, "
                     " wherein the parameter updates from workers are "
                     "aggregated before applied to avoid stale gradients")
flags.DEFINE_integer("replicats_to_aggregate", None,
                     "Number of replicats to aggregate before parameter "
                     "update is applied (For sync_replicats mode only; "
                     "default: num_workers)")

# Cluster layout. The parameter server defaults to its own machine
# (192.168.233.201); it previously defaulted to 192.168.233.202:2222, the
# same host:port as the first worker, so two tasks claimed one address.
flags.DEFINE_string("ps_hosts", "192.168.233.201:2222",
                    "Comma-separated list of hostname:port pairs")
flags.DEFINE_string("worker_hosts",
                    "192.168.233.202:2222,192.168.233.203:2222",
                    "Comma-separated list of hostname:port pairs")

# Identity of this process inside the cluster; both must be given explicitly.
flags.DEFINE_string("job_name", None, "job name: worker or ps")
flags.DEFINE_integer("task_index", None, "Worker task index, should be >= 0. task_index=0 is "
                     "the master worker task the performs the variable "
                     "initialization ")

FLAGS = flags.FLAGS

# Side length, in pixels, of the square input images; the input layer has
# IMAGE_PIXELS * IMAGE_PIXELS units.
IMAGE_PIXELS = 28

def main(unused_argv):
    """Run one task of a distributed MNIST training cluster.

    A task started with ``--job_name=ps`` starts its server and joins it.
    A task started with ``--job_name=worker`` builds a one-hidden-layer
    softmax classifier, trains it until the shared global step reaches
    ``--train_steps``, and finally prints the cross entropy on the
    validation set. The worker with ``--task_index=0`` acts as chief.

    Args:
        unused_argv: Command-line arguments forwarded by ``tf.app.run``;
            not used.

    Raises:
        ValueError: If ``--job_name`` is missing/empty or ``--task_index``
            is missing.
    """
    # Validate the flags first so a mis-invoked task fails before doing any
    # work. `task_index` is an integer flag, so only None means "not given"
    # (0 is a valid index and must not be rejected).
    if not FLAGS.job_name:
        raise ValueError("Must specify an explicit `job_name`")
    if FLAGS.task_index is None:
        raise ValueError("Must specify an explicit `task_index`")

    print("job name = %s" % FLAGS.job_name)
    print("task index = %d" % FLAGS.task_index)

    # Describe the cluster and start this task's in-process server.
    ps_spec = FLAGS.ps_hosts.split(",")
    worker_spec = FLAGS.worker_hosts.split(",")
    num_workers = len(worker_spec)
    cluster = tf.train.ClusterSpec({"ps": ps_spec, "worker": worker_spec})
    server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)

    if FLAGS.job_name == "ps":
        # A parameter server builds no graph of its own; it just joins the
        # server. Nothing below applies to it.
        server.join()
        return

    # Only workers consume training data, so it is loaded after the ps branch.
    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

    is_chief = (FLAGS.task_index == 0)
    worker_device = "/job:worker/task:%d/gpu:0" % FLAGS.task_index

    # The whole model must be built inside the replica_device_setter scope;
    # ops created outside it do not get the ps/worker device assignment.
    with tf.device(tf.train.replica_device_setter(worker_device=worker_device,
                                                  ps_device="/job:ps/cpu:0",
                                                  cluster=cluster)):
        global_step = tf.Variable(0, name="global_step", trainable=False)

        # Hidden layer parameters.
        hid_w = tf.Variable(
            tf.truncated_normal([IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units],
                                stddev=1.0 / IMAGE_PIXELS),
            name="hid_w")
        hid_b = tf.Variable(tf.zeros([FLAGS.hidden_units]), name="hid_b")

        # Softmax layer parameters.
        sm_w = tf.Variable(
            tf.truncated_normal([FLAGS.hidden_units, 10],
                                stddev=1.0 / math.sqrt(FLAGS.hidden_units)),
            name="sm_w")
        sm_b = tf.Variable(tf.zeros([10]), name="sm_b")

        # x: flattened images; y_: one-hot labels.
        x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS])
        y_ = tf.placeholder(tf.float32, [None, 10])

        hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
        hid = tf.nn.relu(hid_lin)

        y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))
        # Clip the probabilities so log() never sees 0.
        cross_entropy = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))

        opt = tf.train.AdamOptimizer(FLAGS.learning_rate)
        if FLAGS.sync_replicas:
            # The flag is registered as `replicats_to_aggregate`; read it
            # under that exact name. Default: aggregate over all workers.
            if FLAGS.replicats_to_aggregate is None:
                replicas_to_aggregate = num_workers
            else:
                replicas_to_aggregate = FLAGS.replicats_to_aggregate

            # NOTE(review): `replica_id` belongs to the older
            # SyncReplicasOptimizer signature; confirm it is accepted by the
            # installed TensorFlow version before relying on sync mode.
            opt = tf.train.SyncReplicasOptimizer(opt,
                                                 replicas_to_aggregate=replicas_to_aggregate,
                                                 total_num_replicas=num_workers,
                                                 replica_id=FLAGS.task_index,
                                                 name="mnist_sync_replicas")

        train_step = opt.minimize(cross_entropy, global_step=global_step)

        if FLAGS.sync_replicas and is_chief:
            # Extra ops that only the chief runs in sync mode (used below
            # once the session exists).
            chief_queue_runner = opt.get_chief_queue_runner()
            init_tokens_op = opt.get_init_tokens_op()

        init_op = tf.global_variables_initializer()
        train_dir = tempfile.mkdtemp()
        sv = tf.train.Supervisor(is_chief=is_chief,
                                 logdir=train_dir,
                                 init_op=init_op,
                                 recovery_wait_secs=1,
                                 global_step=global_step)

    # Restrict the session to the ps tasks and this worker's own devices.
    sess_config = tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=False,
        device_filters=["/job:ps", "/job:worker/task:%d" % FLAGS.task_index])

    if is_chief:
        print("Worker %d: Initializing session..." % FLAGS.task_index)
    else:
        print("Worker %d: Waiting for session to be initialized..." % FLAGS.task_index)
    sess = sv.prepare_or_wait_for_session(server.target, config=sess_config)
    print("Worker %d: Session initialization complete." % FLAGS.task_index)

    if FLAGS.sync_replicas and is_chief:
        print("Starting chief queue runner and running init_tokens_op")
        sv.start_queue_runners(sess, [chief_queue_runner])
        sess.run(init_tokens_op)

    time_begin = time.time()
    print("Training begins @%f" % time_begin)

    # local_step counts this worker's own updates; `step` is the global step
    # variable, which the loop compares against --train_steps.
    local_step = 0
    while True:
        batch_xs, batch_ys = mnist.train.next_batch(FLAGS.batch_size)
        # Labels go to the y_ placeholder (not to y, the network output).
        train_feed = {x: batch_xs, y_: batch_ys}
        _, step = sess.run([train_step, global_step], feed_dict=train_feed)
        local_step += 1

        now = time.time()
        print("%f: Worker %d: training step %d done (global step: %d)" % (now, FLAGS.task_index, local_step, step))

        if step >= FLAGS.train_steps:
            break

    time_end = time.time()
    print("Training ends @ %f" % time_end)
    training_time = time_end - time_begin
    print("Training elapsed time: %f s" % training_time)

    # Final evaluation on the validation split.
    val_feed = {x: mnist.validation.images, y_: mnist.validation.labels}
    val_xent = sess.run(cross_entropy, feed_dict=val_feed)
    print("After %d training step(s), validation cross entropy = %g" % (FLAGS.train_steps, val_xent))

# Script entry point: hand control to tf.app.run, which calls `main`.
if __name__ == "__main__":
    tf.app.run(main=main)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值