深度学习之tiny-dnn开源代码学习(2)-手写数字识别

本文介绍如何使用 Tiny-DNN 框架搭建并训练 LeNet-5 卷积神经网络,用于 MNIST 手写数字识别任务。文章详细展示了网络构造过程、数据集加载与预处理步骤,并提供了训练及评估网络性能的方法。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

前一篇博客我们大致认识到了tiny-dnn 这个库的整体结构和各个类的继承关系,了解到整体架构后我们这一节通过做一个实际的例子来进一步了解DeepLearning的知识。这篇文章讲怎样用tiny-dnn进行手写数字识别,任何一个深度学习的库的第一个实验应该就是手写数字识别啦。那么我们今天就来进行这个实验,这个实验在tiny-dnn库的example模块中有,我就是解读了那里的源代码才理解了进行手写数字识别的过程的。

数据集-MNIST

这个实验的数据使用的是MNIST手写数字数据集, 详情请见,我们也可以在这个链接直接下载数据集进行使用,这个数据集有60000个训练数据样本,和10000个测试的样本,它是NIST数据集的一部分。

网络结构-LeNet-5

这个网络结构参考于Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. 的 “Gradient-based learning applied to document recognition” 这篇论文应该是非常出名了,这篇论文中详细介绍了怎样建构LeNet-5网络。

Constructing Net

Construcnting net的过程就需要用到上一节提到的nework类和不同层的类。
首先建立一个函数用来建设网络代码如下


static void construct_net(network<sequential> & net)
{
#define  o true
#define X false
static const bool table[] = {
    //0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
    o, X, X, X, o, o, o, X, X, o, o, o, o, X, o, o,
    o, o, X, X, X, o, o, o, X, X, o, o, o, o, X, o,
    o, o, o, X, X, X, o, o, o, X, X, o, X, o, o, o,
    X, o, o, o, X, X, o, o, o, o, X, X, o, X, o, o,
    X, X, o, o, o, X, X, o, o, o, o, X, o, o, X, o,
    X, X, X, o, o, o, X, X, o, o, o, o, X, o, o, o
};
#undef o
#undef X
    core::backend_t backend_type = core::default_engine();

    net << tiny_dnn::convolutional_layer(32, 32, 5, 1, 6, padding::valid, true, 1, 1, backend_type)
        << tanh_layer(28, 28, 6)
        << average_pooling_layer(28, 28, 6, 2)
        << tanh_layer(14, 14, 6)
        << convolutional_layer(14, 14, 5, 6, 16, core::connection_table(table, 6, 16), padding::valid, true, 1, 1, backend_type)
        << tanh_layer(10, 10, 16)
        << average_pooling_layer(10, 10, 16, 2)
        << tanh_layer(5, 5, 16)
        << convolutional_layer(5, 5, 5, 16, 120, padding::valid, true, 1, 1, backend_type)
        << tanh_layer(1, 1, 120)
        << fully_connected_layer(120, 10, true, backend_type)
        << tanh_layer(10);

}

变量table存储的是一种连接关系,整个网络的架构过程如下图所示
网络建构图
这个图就是论文Gradient-based learning applied to document recognition中LeNet-5的网络结构。

training

网络的训练同样写一个函数来实现。

static void train_LeNet(const std::string &data_dir_path)
{
    network<sequential> net;
    adagrad optimizer;
    construct_net(net);
    std::cout << "load models .. " << std::endl;
    // load MNIST dataset
    std::vector<label_t> train_labels, test_label;
    std::vector<vec_t> train_images, test_images;
    parse_mnist_labels(data_dir_path + "/train-labels.idx1-ubyte", &train_labels);
    parse_mnist_images(data_dir_path + "/train-images.idx3-ubyte", &train_images, -1.0, 1.0, 2, 2);
    parse_mnist_labels(data_dir_path + "/t10k-labels.idx1-ubyte", &test_label);
    parse_mnist_images(data_dir_path + "/t10k-images.idx3-ubyte",  &test_images, -1.0, 1.0, 2, 2);
    std::cout << "start training" << std::endl;

    progress_display disp(static_cast<unsigned long>(train_images.size()));
    timer t;
    int minibatch_size = 10;
    int num_epochs = 5;
    optimizer.alpha *= static_cast<tiny_dnn::float_t>(std::sqrt(minibatch_size));

    auto on_enumberate_epoch = [&]()
    {
        std::cout << t.elapsed() << "s elpased." << std::endl;
        tiny_dnn::result res = net.test(test_images, test_label);
        std::cout << res.num_success << "/" << res.num_total << std::endl;
        disp.restart(static_cast<unsigned long>(train_images.size()));
        t.restart();
    };

    auto on_enumberate_minibatch = [&]()
    {
        disp += minibatch_size;
    };

    //training 
    net.train<mse>(optimizer, train_images, train_labels, minibatch_size, num_epochs,
                    on_enumberate_minibatch, on_enumberate_epoch);
    std::cout << "end training" << std::endl;
    net.test(test_images, test_label).print_detail(std::cout);
    net.save("./LeNet-model");
}

main函数

主函数需要输入数据的目录地址

int main(int argc, char **argv)
{
    if (argc != 2)
    {
        std::cerr << "Usage :" << argv[0];
        std::cerr << "path_to_data" << std::endl;
        return -1;
    }
    train_LeNet(argv[1]);
    return 0;
}

工程编译

最后我们用cmake生成makefile文件,CMakeLists.txt 内容设置如下

CMAKE_MINIMUM_REQUIRED(VERSION 3.5.1)
PROJECT(LeNet CXX)

# 隐含定义了两个变量
# LeNet_BINARY_DIR 和 LeNet_SOURCE_DIR 内部工程,这两个变量都指向
#当前文件夹,同时也自动定义了 两个变量
# PROJECT_BINARY_DIR 和 PROJECT_SOURCE_DIR  这两个变量和上面两个变量一样

#SET 指令
#用来显示定义变量 
# 支持C++14标准
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14")
SET(CMAKE_BUILD_TYPE "Debug")  
SET(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -O0 -Wall -g2 -ggdb")  
SET(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O3 -Wall") 
SET(EXECUTABLE_OUTPUT_PATH ${PROJECT_BINARY_DIR}/bin)
SET(LIBRARY_OUTPUT_PATH ${PROJECT_BINARY_DIR}/lib)
SET(CMAKE_MODULE_PATH /usr/local/OpenCV32/share/OpenCV)
FIND_PACKAGE(OpenCV REQUIRED)

MESSAGE(STATUS "CMAKE PREFIX PATH ${CMAKE_PREFIX_PATH}")
MESSAGE(STATUS "OpenCV library status:")
MESSAGE(STATUS " version : ${opencv_VERSION}")
MESSAGE(STATUS " include path: ${OpenCV_INCLUDE_DIRS}")
MESSAGE(STATUS " library: ${OpenCV_LIBS}")

# ADD opencv headers include paths
INCLUDE_DIRECTORIES(${OpenCV_INCLUDE_DIRS})
INCLUDE_DIRECTORIES(./)

#add opencv libs
AUX_SOURCE_DIRECTORY(. DIR_SRCS)

MESSAGE(STATUS " DIR_SRCS  ${DIR_SRCS}")

# add executable file of LeNet
ADD_EXECUTABLE(LeNet ${DIR_SRCS})

TARGET_LINK_LIBRARIES(LeNet ${OpenCV_LIBS} pthread)

接下来就可以训练网络了,漫长的等待。。。
训练中

网络训练好就可以存储下来进行手写数字识别了。

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值