一、概述
本文的推导参见西瓜书P102~P103,代码参见该网址。主要实现了利用三层神经网络进行手写数字的识别。
二、理论推导
1、参数定义
三层神经网络只有一层隐藏层。参数如下:
x | 输入层输入 |
v | 输入层与隐藏层间的权值 |
α | 隐藏层输入 |
b | 隐藏层输出 |
w |
隐藏层与输出层间的权值 |
输出层输入 | |
输出层输出 |
参数关系如下:
上述等式中fx为激活函数。西瓜书默认激活函数为sigmoid,损失函数为均方根,本文以此为前提进行推导。
2、推导
设损失函数为,则其公式如下:
对求偏导如下,这愚蠢的优快云不支持多行公式编辑,所以只好手写了:
于是我们就得到了的更新公式:
同样的,对求偏导如下:
于是我们就得到了的更新公式:
三、代码实现
优化方法选择SGD。
1、数据集初始化
MNIST数据集可以在TensorFlow中下载到:
import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist = input_data.read_data_