这个实验拿来入门最好了,很轻松,执行作者写好的脚本就可以;
官网链接:http://caffe.berkeleyvision.org/gathered/examples/mnist.html
跟着这个实验,可以观察caffe使用的流程,深入一点,可以看看作者写的shell脚本做了哪些事情;
下面我们一步步来执行,并作简单的分析:
1 首先准备数据集
执行下面的命令。$CAFFE_ROOT是指caffe安装的目录,比如说我的是/home/sloanqin/caffe-master/
cd $CAFFE_ROOT
./data/mnist/get_mnist.sh //执行下载mnist数据集脚本
./examples/mnist/create_mnist.sh //将下载的mnist数据集转换成lmdb格式的数据
我们来看看这两个脚本做了什么:
1.1 执行脚本get_mnist.sh:
get_mnist.sh脚本文件如下:
<pre name="code" class="plain"># This scripts downloads the mnist data and unzips it.
DIR="$( cd "$(dirname "$0")" ; pwd -P )"
cd $DIR
echo "Downloading..."
wget --no-check-certificate http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
wget --no-check-certificate http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
wget --no-check-certificate http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
wget --no-check-certificate http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
echo "Unzipping..."
gunzip train-images-idx3-ubyte.gz
gunzip train-labels-idx1-ubyte.gz
gunzip t10k-images-idx3-ubyte.gz
gunzip t10k-labels-idx1-ubyte.gz
# Creation is split out because leveldb sometimes causes segfault
# and needs to be re-created.
echo "Done."
作用:下载下面四个文件,并解压;
train-images-idx3-ubyte.gz训练集
train-labels-idx1-ubyte.gz训练集的标签
t10k-images-idx3-ubyte.gz测试集
t10k-labels-idx1-ubyte.gz测试集的标签
2 定义caffe的网络结构:.prototxt 文件
caffe的网络结构定义在后缀名为.prototxt的文件中,我们根据自己的需要定义自己的网络结构;
在这个实验中,我们使用作者已经为我们定义好的lenet网络结构,大家可以在下面的目录中找到该文件:
$CAFFE_ROOT/examples/mnist/lenet_train_test.prototxt
在我的电脑上,目录是/home/sloanqin/examples/mnist/lenet_train_test.prototxt
在后续的工作中,定义好自己的网络结构是最关键的,直接决定了性能,这里我们就不多说了;
3 定义caffe运算的时候的一些规则:solver.prototxt 文件
该文件在下面的目录中:
$CAFFE_ROOT/examples/mnist/lenet_solver.prototxt
文件内容如下:作者给出了英文注释,我再给出中文的注释
# The train/test net protocol buffer definition
net: "examples/mnist/lenet_train_test.prototxt"
# test_iter specifies how many forward passes the test should carry out.
# In the case of MNIST, we have test batch size 100 and 100 test iterations,
# covering the full 10,000 testing images.
test_iter: 100
# Carry out testing every 500 training iterations.
test_interval: 500
# The base learning rate, momentum and the weight decay of the network.
base_lr: 0.01
momentum: 0.9
weight_decay: 0.0005
# The learning rate policy
lr_policy: "inv"
gamma: 0.0001
power: 0.75
# Display every 100 iterations
display: 100
# The maximum number of iterations
max_iter: 10000
# snapshot intermediate results
snapshot: 5000
snapshot_prefix: "examples/mnist/lenet"
# solver mode: CPU or GPU
solver_mode: GPU
中文注释版本:
<pre name="code" class="plain"># The train/test net protocol buffer definition
net: "examples/mnist/lenet_train_test.prototxt"
# test_iter specifies how many forward passes the test should carry out.
# In the case of MNIST, we have test batch size 100 and 100 test iterations,
# covering the full 10,000 testing images.
test_iter: 100 // 这个参数指定测试的时候送入多少个
// 这里说明一个知识:GPU在计算的时候,每次迭代是多张图片,我们称为一个batch
// 作者提到:test batch size 100,就是说每个包有100张图片
// 这里设置 st_iter=100,就是测试的时候一共输入100*100=10000张图片
//所以,test_iter 的英文翻译就是:测试时迭代次数
# Carry out testing every 500 training iterations.
test_interval: 500 //定义每500次迭代,做一次测试
# The base learning rate, momentum and the weight decay of the network.
base_lr: 0.01 // 定义了刚开始的学习速率是0.01
momentum: 0.9
weight_decay: 0.0005
# The learning rate policy
lr_policy: "inv"
gamma: 0.0001
power: 0.75
# Display every 100 iterations
display: 100 //每迭代100次,显示一次计算结果
# The maximum number of iterations
max_iter: 10000 //设置最大的迭代次数
# snapshot intermediate results
snapshot: 5000 // 保存中间运行时得到的参数结果,这里设置成每5000次迭代保存一次,这样运行中间断掉了,我们可以从断掉的地方继续开始
snapshot_prefix: "examples/mnist/lenet"
# solver mode: CPU or GPU
solver_mode: GPU //使用CPU还是GPU计算
4 执行命令进行训练
最后一步就是执行脚本开始训练:
cd $CAFFE_ROOT
./examples/mnist/train_lenet.sh
我们打开这个脚本,可以看到特别简单,就一行:
./build/tools/caffe train --solver=examples/mnist/lenet_solver.prototxt
这行代码的意思:调用./build/tools/caffe目录下面的train函数,train函数的输入参数是solver.prototxt文件的路径:--solver=examples/mnist/lenet_solver.prototxt
5 结果
运行的过程中,可以卡到test 的准确率在不断上升;运行结束后,会生成模型文件:lenet_iter_10000.caffemodel
还有一个文件是snapshot 保存的:lenet_iter_10000.solverstate
原文链接:http://write.blog.youkuaiyun.com/postedit/49147935