本文通过对TensorFlow的一个入门案例进行整理和分析,介绍了使用TensorFlow模型的一些基本概念和一般性的规律
在进入案例之前首先介绍一下使用TensorFlow编程模型会涉及到的基本概念。
一、TensorFlow编程模型
基本概念
- TensorFlow编程模型将计算表示为一个有向图(computation graph)
- 图的每一个节点(node)代表一个运算操作(operation)
- 图的边有两种含义:
- 一种表示数据流(flow)
- 一种表示依赖控制(contrial dependencies)。这类边约束节点执行的顺序,可以用来进行条件控制
- 在计算图的边中流动(flow)的数据被称为张量(tensor)——tensorflow命名的由来
- 用户与tensorflow交互的接口:Session
- 可通过Seesion的Extend方法向计算图添加新的节点和边
- 可通过Session的Run方法来执行计算图
- Variable:一种特殊运算符,将一些需要保留的tensor储存在内存或者显存中,并在计算时更新
实现原理
涉及到的几个重要概念:
- client:客户端
- session:提供client与master和worker相连的接口
- master:负责指导所有worker按流程执行计算图
- worker:与计算设备(CPU或者GPU)相连,负责管理这些计算设备。每一个worker可以管理多个设备。
这几个概念的关系为:
client通过session的接口与master及多个worker相连,其中每一个worker可以与多个硬件设备(device)相连,比如CPU或者GPU。
TensorFlow的执行模式有两种:
- 单机模式:client, master, worker全部在一台机器上的同一个进程中
- 分布式模式:允许client, master, worker在不同机器的不同进程中。同时由集群调度系统同意管理各项任务
在本文中,我们只涉及到单机模式,可以暂且不去关注分布式模式。
二、案例
该案例是通过一个单层的softmax神经元对MNIST手写数字进行识别。先上代码,该代码可以直接运行,后续我们对代码进行分析。
本文目标是总结TensorFlow的使用,对于softmax神经元的模型、MNIST手写数字以及一些机器学习的先验知识,在此不去详细介绍。
# -*- coding: utf-8 -*-
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data/', one_hot = True)
print(mnist.train.images.shape, mnist.train.labels.shape)
print(mnist.test.images.shape, mnist.test.labels.shape)
#加载MNIST数据集,并输出数据集的信息
sess = tf.InteractiveSession() #创建一个session
x = tf.placeholder(tf.float32,[None, 784]) #用于输入训练数据
W = tf.Variable(tf.zeros([784,10])) #softmax神经元的权值矩阵
b = tf.Variable(tf.zeros([10])) #softmax神经元的偏置矩阵
y = tf.nn.softmax(tf.matmul(x,W )+b) #通过softmax计算出来的label
y_ = tf.placeholder(tf.float32, [Non