简单介绍:tensorflow的分布式

介绍
  • RPC框架(远程过程调用):它是一种对底层通信协议的二次封装,通过网络从远程计算机程序上请求的服务框架。可参考:RPC框架
    • tensorflow中的是谷歌的gRPC框架
  • 数据分析设备类型:
    • 一机多卡(普通):一台服务器,多个显卡。
    • 多机多卡(分布式):多台服务器,多个显卡。
  • tensorflow实现分布式的结构:
    • 参数服务器们(tensorflow对于其的命名规范:/job:ps/task:0/cpu:0):用于存储和更新参数W(权重)
    • 工作服务器们(tensorflow对于其的命名规范:/job:worker/task:1/gpu:2):用于计算梯度下降的“变化值”
      ps:工作服务器中有一个主服务器(需要指定),用于定义和运行任务,其它服务器配合计算。
  • tensorflow更新参数的方式(配合图片)
    • 我们先假设参数更新的公式是 w’ = w - t;t = a(方向) — > t 为梯度下降求出来的“变化值”
    • 同步模式更新:
      • 过程:工作服务器计算出梯度下降的变化值:t1,t2…,并将传给参数服务器,然后参数服务器取变化值的平均值使用:w’ = w - tmean 更新权重参数,然后再将 w’ 传入各个工作服务器中。
      • 特点:在这个过程中,工作服务器的计算速度是有差别的,所以计算快的工作服务器,会阻塞等待慢的,这个过程别称之为同步。
    • 异步模式更新(tensorflow默认方式) :
      • 过程:相对同步而言,异步是不会不等待“变化值”的,它会使用 w’ = w - t 直接计算出新的 w’ ,然后使用 w’ 和之前存储的 w 取平均值更新为参数服务器存储的 w 和当前工作服务器的新参数 Wmean。
      • 特点:因为每个工作服务器计算出来的变化值都会被计算然后取平均值,所以这个过程中相对同步而言,是不受个别机器配置的影响的,但是可能会出现梯度消失的问题。
        tensorflow参考图片
案例:分布式做一个简单的乘法运算
  • 由于需要介绍的api太多,不再介绍,仅在代码做简要注释。
  • 需要注意的是:当在配置参数服务器和工作服务器的时候,按顺序,依次为0,1,2…号服务器(可以参考下图)。
    机器学习
  • code:
    #!/usr/local/bin/python3
    # -*- coding: utf-8 -*-
    # Author  : rusi_
    
    import tensorflow as tf
    
    FLAGS = tf.app.flags.FLAGS
    
    tf.app.flags.DEFINE_string("job_name", "worker", "启动服务的类型ps or  worker")
    tf.app.flags.DEFINE_integer("task_index", 0, "指定ps或者worker当中的哪一台服务器")
    
    
    def main(argv):
        # 定义全局计数的op ,给钩子列表当中的训练步数使用
        global_step = tf.contrib.framework.get_or_create_global_step()
    
        # 指定集群描述对象, ps , worker
        clr = tf.train.ClusterSpec({"ps": ["172.18.87.17:7000"], "worker": ["192.168.233.141:70001"]})
    
        # 创建不同的服务, ps, worker
        server = tf.train.Server(clr, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
    
        # 判断服务类型
        if FLAGS.job_name == "ps":
            # 参数服务器只需要等待worker传递参数
            server.join()
        elif FLAGS.job_name == "worker":
            worker_device = "/job:worker/task:0/cpu:0/"
    
            # tf.device(参数:可以指定名字或函数(固定的)指定什么设备去运行什么代码。
            # 当指定函数的时候是为了同步数据的,这种方法也是分布式特有的使用方式。
            # 当指定名字的时候,例如:'/cpu:0'代表的是在当前服务器上使用cpu0运行下面部分的代码
            with tf.device(tf.train.replica_device_setter(
                    worker_device=worker_device,  # 用什么设备
                    cluster=clr  # 和别的服务器建立关系
            )):
                # 简单做一个矩阵乘法运算
                x = tf.Variable([[1, 2, 3, 4]])
                w = tf.Variable([[1], [1], [1], [1]])
                mat = tf.matmul(x, w)
    
            # 创建分布式会话(不需要再初始化变量op了,MonitoredTrainingSession()自动帮助初始化)
            with tf.train.MonitoredTrainingSession(
                    master="grpc://192.168.233.141:70001",  # 指定主worker
                    is_chief=(FLAGS.task_index == 0),  # 判断是否是主worker
                    # config=tf.ConfigProto(log_device_placement=True),  # 打印设备信息
                    hooks=[tf.train.StopAtStepHook(last_step=200)]  # 钩子实现循环多少步,当超过200步的时候会抛出异常(需要配合全局global_step(计数器)使用)。
            ) as m_sess:
                while not m_sess.should_stop():  # should_stop() :会话是否产生了异常(异常时返回false)。
                    print(m_sess.run(mat))
    
    
    if __name__ == "__main__":
        tf.app.run()  # 默认调用main()函数,main函数要求必须有argv参数
    
    
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值