TensorFlow2.0神经网络基础
0. 前向传播算法简介
计算神经网络的前向传播结果需要三部分信息:
- 神经网络的输入,这个输入就是从实体中提取的特征向量
- 神经网络的连接结构,神经网络是由神经元构成的,神经网络的结构给出不同神经元之间输入输出的连接关系。
- 每个神经元中的参数。
1. 数据加载
TensorFlow
依托Keras
,预先封装了很多小型的数据集,可以通过接口下载并使用该数据集。包含boston_housing
(波士顿房价回归预测数据集)、cifar10
(Cifar图像分类数据集共10大类)、cifar100
(Cifar图像分类数据集共100小类)、mnist
(手写数字识别数据集)、fashion_mnist
(常用品分类数据集)、imdb
(电影情感评论文本分类数据集)。
调用load_data
方法TensorFlow
会自动从Google
下载暂存的数据集到本地(第一次运行,后面不再下载,需要科学上网),然后以numpy
格式返回训练集和测试集的x和y。
示例
(1)tf.data.Dataset
- 在
TensorFlow
中,data
模块中的Dataset
是一个很重要的对象,作为数据加载的接口。 Dataset
对象的form_tensor_slices
方法可以很方便地从numpy
矩阵等加载数据并进行预处理
(注意Dataset对象的使用必须先取得对应的迭代器)。
示例
Dataset
对象的shuffle
方法可以很方便地打散数据(特征和标签同步打散),一般只需要一个buffer_size
参数,该数值越大,混乱程度越大。
Dataset
对象的map
方法可以很方便地进行数据预处理
或者数据增广
的操作,其功能类似于Python的map
方法,传入一个函数对象,对每个数据调用该函数进行处理。
Dataset
对象的batch
方法可以直接设定每次取出数据的batch_size
(默认为1),这是最核心的功能。
Dataset
对象的repeat
方法可以指定迭代器
迭代的次数(在Python中对可迭代对象一旦取完就会停止取数据,但是训练往往需要很多轮次),默认不指定repeat的参数则会一直迭代下去。
2. 全连接神经网络
经典的神经网络结构是由多个隐藏层的神经元级联形成的全连接神经网络,后来各类针对不同任务的神经网络结构的设计都是基于全连接神经网络的思路,如计算机视觉的卷积神经网络
、自然语言处理的循环神经网络
等。对于全连接神经元的神经网络,由于结构固定,已经做了一定程度的封装。
(1)tf.keras.layers.Dense(units)
可以创建包含参数
w
w
w和偏置
b
b
b的一层全连接层网络,units
表示神经元的数目。注意在创建Dense
对象后,weights
参数是默认没有创建的,需要通过build
方法创建,使用net
实例进行运算时会自动检查参数,若没有创建,则依据参与运算的数据自动创建weights
。
(2)tf.keras.Sequential
Sequential是用于线性堆叠多层网络结构的一个基础容器,其会自动将输入的张量流过多层得到输出。
3. 误差计算
(1)MSE均方差
均方误差的计算公式为:
l
o
s
s
=
1
N
∑
(
y
−
o
u
t
)
2
loss = \frac{1}{N}\sum {(y-out)^2}
loss=N1∑(y−out)2
在TensorFlow中对这类简单的误差函数(损失函数)进行了简单封装。MSE这类损失函数一般都要除以一个样本量N,以保证求得的梯度值不会太大。
(2)CrossEntropy交叉熵
计算公式为:
H
(
p
,
q
)
=
−
∑
p
(
x
)
log
q
(
x
)
H(p,q) = -\sum p(x) \log{q(x)}
H(p,q)=−∑p(x)logq(x)
预测的
q
q
q分布尽量逼近于真实的
p
p
p分布,此时交叉熵函数值最小,故优化该损失函数合理。在TensorFlow中,对交叉熵函数计算进行了封装。
4. 补充内容
- 在TensorFlow中,变量(tf.Variable)的作用就是保存和更新神经网络中的参数。
- 通过满足正态分布的随机数来初始化神经网络中的参数是一个非常常用的方法。
- 在神经网络中,偏置项(bias)通常会使用常数来设置初始值。
- 一个变量的值在被使用之前,这个变量的初始化过程需要被明确地调用。
- 在TensorFlow中,变量的声明函数tf.Variable是一个运算,运算的输出结果就是一个张量,所以变量只是一种特殊的张量。
- 变量的类型是不可改变的,一个变量在构建之后,它的类型就不能再改变了。
- 维度和类型是变量最重要的两个属性,其中维度在程序运行中是有可能改变的,但是需要通过设置参数validate_shape=False。