一、概述
本文的推导参见西瓜书P102~P103,代码参见该网址。主要实现了利用三层神经网络进行手写数字的识别。
二、理论推导
1、参数定义

三层神经网络只有一层隐藏层。参数如下:
| x | 输入层输入 |
| v | 输入层与隐藏层间的权值 |
| α | 隐藏层输入 |
| b | 隐藏层输出 |
| w |
隐藏层与输出层间的权值 |
| 输出层输入 | |
| 输出层输出 |
参数关系如下:
上述等式中fx为激活函数。西瓜书默认激活函数为sigmoid,损失函数为均方根,本文以此为前提进行推导。
2、推导
设损失函数为,则其公式如下:
对求偏导如下,这愚蠢的优快云不支持多行公式编辑,所以只好手写了:

于是我们就得到了的更新公式:
同样的,对求偏导如下:

于是我们就得到了的更新公式:
三、代码实现
优化方法选择SGD。
1、数据集初始化
MNIST数据集可以在TensorFlow中下载到:
import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist = input_data.read_data_

最低0.47元/天 解锁文章
6700

被折叠的 条评论
为什么被折叠?



