caffe例程——LeNet训练Mnist图片数据

本文详细介绍了如何使用Caffe训练LeNet模型来识别MNIST手写数字。从数据预处理,包括将MNIST数据转化为图片,制作LMDB数据,到编写网络配置文件,生成solver,训练模型,并通过可视化工具分析训练结果。最后,使用训练好的模型进行手写数字识别。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >


参考博文


http://www.cnblogs.com/denny402/p/5684431.html

感谢博主徐其华随笔分类 - caffe系列教程,收获颇丰。在这里记录下自己学习训练和测试模型的过程,可以说相关知识是完全来自参考博文,只是记录一下自己实践中的问题和感悟。

 

一、数据准备

官网提供的mnist数据并不是图片,但我们以后做的实际项目多数可能是图片。因此有些人并不知道该怎么办。在此我将mnist数据进行了转化,变成了一张张的图片,我们练习就从图片开始。

linu下直接用wget下载  wget http://deeplearing.net/data/mnist/mnist.pkl.gz

mnist.pkl.gz其实是数据的训练集、验证集和测试集用pickle导出的文件被压缩为gzip格式,所以用python的gzip模块当成文件就可以读取。其中每个数据集是一个元组,第一个元素存储的是手写数字图片,表示每张图片是长度为28*28=784的一维浮点型numpy数组,这个数组就是单通道灰度图片按行展开得到,最大值为1(白),最小值为0(黑)。元组中的第二个元素是图片对应的标签,是一个一维的整型numpy数组,按照下标位置对应图片中的数字。然后运行 convert_mnist.py 将pickle格式的数据转换为图片。


 
 import os
 import pickle, gzip
 from matplotlib import pyplot
  
 # Load the dataset
 print('Loading data from mnist.pkl.gz ...')
 with gzip.open('mnist.pkl.gz', 'rb') as f:
 train_set, valid_set, test_set = pickle.load(f)
  
 imgs_dir = 'mnist'
 os.system('mkdir -p {}'.format(imgs_dir))
 datasets = {'train': train_set, 'val': valid_set, 'test': test_set}
 for dataname, dataset in datasets.items():
 print('Converting {} dataset ...'.format(dataname))
 data_dir = os.sep.join([imgs_dir, dataname])
 os.system('mkdir -p {}'.format(data_dir))
 for i, (img, label) in enumerate(zip(*dataset)):
 filename = '{:0>6d}_{}.jpg'.format(i, label)
 filepath = os.sep.join([data_dir, filename])
 img = img.reshape((28, 28))
 pyplot.imsave(filepath, img, cmap='gray')
 if (i+1) % 10000 == 0:
 print('{} images converted!'.format(i+1))

该脚本首先创建一个叫mnist的文件夹,然后在mnist下创建3个子文件夹train、val和test,分别用来保存对应3个数据集转换后产生的图片。每个文件的命名规则为第一个字段是序号,第二个字段是数字的值。

二、制作LMDB数据

LMDB差不多是CAFFE用来训练图片最常用的数据格式。caffe提供了专门为图像分类任务将图片转换为LMDB的官方工具,路径为caffe/build/tools/convert_imageset。要使用这个工具,第一步是生成一个图片文件路径的列表,每一行是文件路径和对应标签我,用space键或者制表符(tab)分开。运行gen_caffe_imglist.py,代码如下:

 import os
 import sys
  
 input_path = sys.argv[1].rstrip(os.sep)
 output_path = sys.argv[2]
  
 filenames = os.listdir(input_path)
  
 with open(output_path, 'w') as f:
    for filename in filenames:
      filepath = os.sep.join([input_path, filename])
      label = filename[:filename.rfind('.')].split('_')[1]
      line = '{} {}\n'.format(filepath, label)
      f.write(line)
控制台下执行   python gen_caffe_imglist.py mnist/train train.txt(注意路径)
                       python gen_caffe_imglist.py mnist/val    val.txt
                       python gen_caffe_imglist.py mnist/test   test.txt

这样就生成了3个数据集的文件列表和对应标签。然后直接调用convert_imageset就可以制作lmdb了。

>>/build/tools/convert_imageset ./train.txt train_lmdb --gray --shuffle

>>/build/tools/convert_imageset ./val.txt val_lmdb --gray --shuffle

>>/build/tools/convert_imageset ./test.txt test_lmdb --gray --shuffle

--gray是单通道读取灰度图的选项;--shuffle是个常用的选项,作用是打乱文件列表顺序。

二、训练LeNet-5

网络结构和caffe官方例子的版本没有区别,只是输入的数据层变成了我们自己制作的LMDB。配置文件实际上就是一些txt文档,只是后缀名是prototxt,我们可以直接到编辑器里编写,也可以用代码生成。此处,我用python来生成。

 name: "LeNet"
 layer {
 name: "mnist"
 type: "Data"
 top: "data"
 top: "label"
 include {
 phase: TRAIN
 }
 

transform_param {


#mnist图片,将0~255之间的值缩放到-0.5~0.5,帮助收敛
 mean_value: 128
 scale: 0.00390625
 }
 data_param {
 source: "../data/train_lmdb"
 batch_size: 50
 backend: LMDB
 }
 }
 layer {
 name: "mnist"
 type: "Data"
 top: "data"
 top: "label"
 include {
 phase: TEST
 }
 transform_param {
 mean_value: 128
 scale: 0.00390625
 }
 data_param {
 source: "../data/val_lmdb"
 batch_size: 100
 backend: LMDB
 }
 }
 layer {
 name: "conv1"
 type: "Convolution"
 bottom: "data"
 top: "conv1"
 param {
 lr_mult: 1
 }
 param {
 lr_mult: 2
 }
 convolution_param {
 num_output: 20
 kernel_size: 5
 stride: 1
 weight_filler {
 type: "xavier"
 }
 bias_filler {
 type: "constant"
 }
 }
 }
 layer {
 name: "pool1"
 type: "Pooling"
 bottom: "conv1"
 top: "pool1"
 pooling_param {
 pool: MAX
 kernel_size: 2
 stride: 2
 }
 }
 layer {
 name: "conv2"
 type: "Convolution"
 bottom: "pool1"
 top: "conv2"
 param {
 lr_mult: 1
 }
 param {
 lr_mult: 2
 }
 convolution_param {
 num_output: 50
 kernel_size: 5
 stride: 1
 weight_filler {
 type: "xavier"
 }
 bias_filler {
 type: "constant"
 }
 }
 }
 layer {
 name: "pool2"
 type: "Pooling"
 bottom: "conv2"
 top: "pool2"
 pooling_param {
 pool: MAX
 kernel_size: 2
 stride: 2
 }
 }
 layer {
 name: "ip1"
 type: "InnerProduct"
 bottom: "pool2"
 top: "ip1"
 param {
 lr_mult: 1
 }
 param {
 lr_mult: 2
 }
 inner_product_param {
 num_output: 500
 weight_filler {
 type: "xavier"
 }
 bias_filler {
 type: "constant"
 }
 }
 }
 layer {
 name: "relu1"
 type: "ReLU"
 bottom: "ip1"
 top: "ip1"
 }
 layer {
 name: "ip2"
 type: "InnerProduct"
 bottom: "ip1"
 top: "ip2"
 param {
 lr_mult: 1
 }
 param {
 lr_mult: 2
 }
 inner_product_param {
 num_output: 10
 weight_filler {
 type: "xavier"
 }
 bias_filler {
 type: "constant"
 }
 }
 }
 layer {
 name: "accuracy"
 type: "Accuracy"
 bottom: "ip2"
 bottom: "label"
 top: "accuracy"
 include {
 phase: TEST
 }
 }
 layer {
 name: "loss"
 type: "SoftmaxWithLoss"
 bottom: "ip2"
 bottom: "label"
 top: "loss"
 }

三、生成参数文件solver

 # The train/validate net protocol buffer definition
 net: "lenet_train_val.prototxt"
 # test_iter specifies how many forward passes the test should carry out.
 # In the case of MNIST, we have test batch size 100 and 100 test iterations,
 # covering the full 10,000 testing images.
 test_iter: 100
 # Carry out testing every 500 training iterations.
 test_interval: 500
 # The base learning rate, momentum and the weight decay of the network.
 base_lr: 0.01
 momentum: 0.9
 weight_decay: 0.0005
 # The learning rate policy
 lr_policy: "inv"
 gamma: 0.0001
 power: 0.75
 # Display every 100 iterations
 display: 100
 # The maximum number of iterations
 max_iter: 36000
 # snapshot intermediate results
 snapshot: 5000
 snapshot_prefix: "mnist_lenet"
 # solver mode: CPU or GPU
 solver_mode: CPU


四、开始训练模型

最后,调用下面命令就可以进行训练了。

/build/tools/caffe train -solver lenet_solver.prototxt --log_dir=./
log_dir参数指定log文件的路径。训练完毕就会生成几个以caffemodel和solverstate结尾的文件,也就是模型参数和solver状态在指定迭代次数以及训练结束时的存档;同时生成的log文件,命名是:
caffe.主机名.域名.用户名.log.INFO.年月日-时分秒.微秒

也可以写成以下形式:

/build/tools/caffe train \

      --solver=lenet_solver.prototxt \
      --2>&1 | tee mnist_train.log
文件描述符:0 stdin, 1 stdout, 2 stderr。2>表示将标准出错重定向到某个特定的地方;2>&1意思就是无论标准出错在哪里,都将标准出错重定向到标准输出中。如果没有2>&1,只会有标准输出,没有错误。tee的作用是同时输出到控制台和文件。2>log.txt则表示只将错误写到文件,其他的还是在标准输出。

caffe提供了对输出log文件的解析工具parse_log.py

&CAFFE_ROOT/tools/extra/parse_log.py mnist_train.log   ./

输出两个解析文件:train.log.train   train.log.test
其内容格式是:迭代数-时间-学习率-损失; 迭代数-时间-学习率-top1准确率-top5准确率-损失

根据解析结果,即可绘制train-loss,test-loss和accuracy的变化曲线,参考博文:https://blog.youkuaiyun.com/zziahgf/article/details/79215862

五、可视化

caffe官方提供可视化log文件的工具,在caffe/tools/extra下有个plot_training_log.py.example,把这个文件复制一份命名为plot_training_log.py,就可以用来画图了。另外这个脚本要求log文件必须以.log结尾。用mv命令把log文件名改成mnist_train.log

>>python plot_training_log.py 0 test_acc_vs_iters.png mnist_train.log

>>python plot_training_log.py 2 test_loss_vs_iters.png mnist_train.log

六、识别手写数字

有了训练好的模型,就可以识别手写数字了。测试用的是test数据集的图片和之前生成的列表,recognize_digit.py 

需要安装opencv2。安装opencv后出现ImportError:No module named cv2的错误,找不到cv2的包。这时安装扩展包即可:pip install opencv-python

 import sys
 sys.path.append('/path/to/caffe/python')
 import numpy as np
 import cv2
 import caffe
  
 MEAN = 128
 SCALE = 0.00390625
  
 imglist = sys.argv[1]
  
 caffe.set_mode_gpu()
 caffe.set_device(0)
 net = caffe.Net('lenet.prototxt', 'mnist_lenet_iter_36000.caffemodel', caffe.TEST)
 net.blobs['data'].reshape(1, 1, 28, 28)
  
 with open(imglist, 'r') as f:
 line = f.readline()
 while line:
 imgpath, label = line.split()
 line = f.readline()
 

image = cv2.imread(imgpath, cv2.IMREAD_GRAYSCALE).astype(np.float) - MEAN


#cv2.imread():读入图片,共两个参数,第一个参数为要读入的图片文件名,第二个参数为如何读取图片,包括:cv2.IMAGE_COLOR:读入一副彩色图片;cv2.IMAGE_GRAYSCALE:以灰度模式读入图片;cv2.IMAGE_UNCHANGED:读入一幅图片,并包括alpha通道。
 image *= SCALE
 net.blobs['data'].data[...] = image
 output = net.forward()
 pred_label = np.argmax(output['prob'][0])
 print('Predicted digit for {} is {}'.format(imgpath, pred_label))


 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值