写在前面的话
不同于Tensorflow官方教程简略的DEMO,我们自己动手实现以下目标
- 从本地文件系统中加载图片、标签
- 对图片和标签预处理
- 创建batch对象以提供随机批次训练
- 构建网络结构
- 训练神经网络
- 在验证集合上评估准确率
- 保存及加载网络参数模型
CNN卷积神经网络的基础知识及简介,我推荐这篇文章。
http://brohrer.github.io/how_convolutional_neural_networks_work.html
训练数据
请下载 https://pan.baidu.com/s/1cdBnbC
训练集合
train.txt 图片文件名-标签
train.rar 图片库,请解压为train文件夹
验证集合
val.txt 图片名-标签
val.rar 图片库,请解压为val文件夹
本程序用到的第三方库
- tensorflow (1.0.1) 谷歌基于DistBelief进行研发的第二代人工智能学习系统
- Pillow (3.2.0) 基本的图像处理功能
- numpy (1.12.1) 开源的数值计算扩展
- matplotlib (1.5.1) Python 的 2D绘图库
干货,代码及讲解
定义输入数据路径
TRAIN_LABEL_PATH = "/path/to/train.txt"
TRAIN_IMAGE_PATH = "/path/to/train/"
VAL_LABEL_PATH = "/path/to/val.txt"
VAL_IMAGE_PATH = "/path/to/val/"
定义函数辅助构建神经网络
def weigth_variable(shape, name):
# 这里使用截断的正态分布,标准差为0.1
initial = tf.truncated_normal(shape, stddev=0.1)
return tf.Variable(initial, name = name)
def bias_variable(shape):
# bias初始化为0.1避免死亡节点
initial = tf.constant(0.1, shape=shape)
return initial
def conv2d(x, W):
# 参数中x是输入,W是卷积的参数,比如[5,5,1,32]:前面两个数字代表卷积核的尺寸;第三个数字代表有多少个channel。因为我们只有灰度单色,所以是1,如果是RGB彩色图片,这里应该是3。
# 最后一个数字代表卷积核的数量,也就是这个卷积层会提取多少类的特征。
# Strides代表卷积模板移动的步长,都是1代表会不遗漏地划过图片的每一个点。
# Padding代表边界的处理方式,这里的SAME代表给边界加上Padding让卷积的输出和输入保持同样(SAME)是尺寸。
return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding="SAME")
def