github资源地址:[Release-x86/x64]
上一篇:轻量级C++神经网络应用库CreativeLus:3、复杂函数逼近。案例:多输入混合逼近。
下一篇:轻量级C++神经网络应用库CreativeLus:5、ResNet残差网咯。案例:(cifar-100)图片分类。
案例4:CNN网络,实现(MNIST)手写数字识别
本案例的任务
经过前面几章的介绍,对CL基本功能应有所了解。CL致力于通过简单的代码,易懂的逻辑,通用化的功能达到快速应用的目的。之前的几个案例,演示功能性居多,实用性不强,但今天,我们就用CL来重现一个经典的CNN案例,用来解决实际应用问题:单色手写数字图片的识别。
- 本章将通过简单代码构建一个经典CNN网络LeNet-5,并介绍以下核心功能:
- 方法一:通过结构定义对象的API,构建卷积神经网络的方法;
- 方法二:通过脚本定义对象,构建卷积神经网络方法;
卷积网络LeNet-5介绍
关于什么是卷积神经网络?相关基础知识,在此不再论述,优快云 和百度都有很多的介绍。今天我们通过构建经典的cnn网络LeNet-5来实现Mnist数据集的训练和手写数字图片的识别。特别说明:本章代码中,对原LeNet-5逻辑有一些小修改,取消了原C3、C5层的链接标记。经测试,该修改对模型的训练及识别能力并无实际影响,若一定要使用,可采用在C3、C5层设置Dropout方案替代。
- 为了礼貌起见,放两张图,供小伙伴们回忆一下cnn网络和LeNet-5。(图片源于网络)
准备工作
1、Mnist数据集的获取和预处理
首先我们下载好本例所用手写数据训练数据集和测试数据集[Mnist:http://yann.lecun.com/exdb/mnist/],他的官网是这样介绍它的:
- MNIST手写数字数据库(可从本页获得)有60000个示例的训练集和10000个示例的测试集。它是NIST提供的更大集合的子集。数字已被规格化,并集中在一个固定大小的图像中。
- 它是一个很好的数据库,供那些想尝试在实际数据上学习技术和模式识别方法,同时在预处理和格式化方面花费最少精力的人使用。
之后使用代码做一下预处理,生成CL样本对集合对象,代码很简单,如下:
#include "CreativeLus.h"
#include "CreativeLusExTools.h"
string pathdatasrc = ("D:\\Documents\\Desktop\\nndata\\mnist\\");
// 第一步: 由Mnist数据预处理,生成模型训练集和预测集---------------------------------------------
BpnnSamSets trainSets, testSets;
if (!trainSets.readFromFile((pathTag + "cnnTrainData.txt").c_str())) {
StdDataHelper::readDataOfMnist(
(pathdatasrc + "train-images.idx3-ubyte").c_str(),
(pathdatasrc + "train-labels.idx1-ubyte").c_str(),
trainSets,
-1, //表示训练输入数据的值区间最小为-1
1, //表示训练输入数据的值区间最大为-1
-0.8, //表示训练输入数据对应的分类标签的值区间最小为-0.8
0.8, //表示训练输入数据对应的分类标签的值区间最大为0.8
2 //对样本输入数据做padding=2的数据扩充,使得28x28的原始数据变为32x32
);
trainSets.writeToFile((pathTag + "cnnTrainData.txt").c_str());
}
if (!testSets.readFromFile((pathTag + "cnnTestData.txt").c_str())) {
StdDataHelper::readDataOfMnist(
(pathdatasrc + "t10k-images.idx3-ubyte").c_str(),
(pathdatasrc + "t10k-labels.idx1-ubyte").c_str(),
testSets, -1, 1, -0.8, 0.8, 2
);
testSets.writeToFile((pathTag + "cnnTestData.txt").c_str());
}
- 其中静态
StdDataHelper::readDataOfMnist()
方法,是包含在头文件"CreativeLusExTools.h"
中,实现快速处理Mnist原始数据,并转换成BpnnSamSets数据集的函数,源代码也可根据需要自行调整。 - 关于LeNet-5网络实现,可参考【卷积神经网络(CNN)的简单实现(MNIST)】,不过要做好思想准备,这是一种纯C基于过程的实现方案,很难修改,很难扩充,代码灵活度不大。光是要构造出网络,代码就麻烦到让人吐血,算法过程代码更是不好理解,对于初学者很不友好。
2、卷积网络组装
基于以上前车之鉴,CL应该有简单明了的建模手段,因此,“8”行代码构建CNN网络LeNet-5的方式就诞生了:(关于 BpnnStructScript
对象使用,详见完整测试代码部分)
//代码片段
BpnnStructScript scp = //构造生成脚本
{
{
{
0,SCP_Conv,{
5,5,6},wi[0],bi,transfunc},}, //标准卷积层
{
{
0,SCP_Pool,{
2,2},{
}, {
},-1,WC_Average},}, //池化层,均值池化
{
{
0,SCP_ConvSep,{
1,1},wi[2],bi, transfunc},},//分割卷积层
{
{
0,SCP_Conv,{
5,5,16},wi[3],bi,transfunc},}, //标准卷积层
{
{
0,SCP_Pool,{
2,2},{
}, {
},-1,WC_Average},}, //池化层,均值池化
{
{
0,SCP_ConvSep,{
1,1},wi[5],bi,transfunc},}, //分割卷积层
{
{
0,SCP_Conv,{
5,5,120},wi[6],bi,transfunc},},//标准卷积层,本层链接了map=1x1的分割卷积层,即等价于一个全连接层
{
{
0,SCP_Fc,{
10},wi[7],bi, transfunc},}, //10个输出神经元的全连接层
};
关于如何做到的?CL是基于对象建模的,按[案例3]介绍的结构定义方式,可以构建任意的自定义的网络,加上些许封装,即可实现通过快速的脚本对象定义,生成复杂的卷积网络(后续还有更复杂的VGG,ResNet等网络,原理均同),这在后续技术附录中介绍。
3、手写数字图片准备
- 接下来准备10张32x32尺寸的手写数字图片,从0到9,均采用单色bitmap格式(windows位图格式)。
- 处理图片,采用头文件
"CreativeLusExTools.h"
中定义的CLBmpHelper::readbmp()
静态方法即可(不需要用OpenCV库来处理这么麻烦了)。关于CLBmpHelper::readbmp()
可自行查看源码,按需修改。
- 读取的数据时记得做一下必要的值域转换:
//代码片段
CLBmpHelper::readbmp(str.c_str(), bmpData);
data.resize(bmpData.size());
for (size_t j = 0, sj = bmpData.size(); j < sj; j++){
//将[0,255]的之间的byte值映射转换到[-1,1]的范围内,至于为什么,请自行思考。
data[j] = bmpData[j] / 255.0 * (1 - (-1)) + (-1);
}
完整测试代码
#include <stdio.h>
#include <string>
#include <vector>
#include <map>
#include "CreativeLus.h"
#include "CreativeLusExTools.h"
using namespace cl;
int main() {
printf("\n\n//案例4:cnn卷积神经网络,实现(Mnist)手写数字识别\n");
string pathTag = ("D:\\Documents\\Desktop\\example_04_cnn_mnist\\");
string pathdatasrc = ("D:\\Documents\\Desktop\\nndata\\mnist\\"); //原始Mnist数据目录
string pathdatabmp = ("D:\\Documents\\Desktop\\nndata\\bmp32x32_0_9\\"); //手写数字图片文件目录
// 第一步: 由Mnist数据预处理,生成模型训练集和预测集---------------------------------------------
BpnnSamSets trainSets, testSets;
if (!trainSets.readFromFile((pathTag + "cnnTrainData.txt").c_str())) {
StdDataHelper::readDataOfMnist(
(pathdatasrc + "train-images.idx3-ubyte").c_str(),
(pathdatasrc + "train-labels.idx1-ubyte").c_str(),
trainSets, -1, 1, -0.8, 0.8, 2
);
trainSets.writeToFile((pathTag + "cnnTrainData.txt").c_str());<