本文为原创文章转载必须注明本文出处以及附上 本文地址超链接 以及 博主博客地址:http://blog.youkuaiyun.com/qq_20259459 和 作者邮箱( jinweizhi93@gmai.com )。
(如果喜欢本文,欢迎大家关注我的博客或者动手点个赞,有需要可以邮件联系我)
接上一篇文章(阅读上一篇文章:http://blog.youkuaiyun.com/qq_20259459/article/details/54600368 )
(四)cnn_train.m
%调用cnn_train:
% [ net, info ] = cnn_train(net, imdb, @getBatch, opts.train, 'val', find(imdb.images.set == 3)) ;
function [net, stats] = cnn_train(net, imdb, getBatch, varargin)
%% --------------------------------------------------------------
% 函数名:cnn_train
% 功能: 1.用于训练过程
% 2.使用随机梯度下降法(SGD)
% ------------------------------------------------------------------------
%CNN_TRAIN An example implementation of SGD for training CNNs
% CNN_TRAIN() is an example learner implementing stochastic
% gradient descent with momentum to train a CNN. It can be used
% with different datasets and tasks by providing a suitable
% getBatch function.
%
% The function automatically restarts after each training epoch by
% checkpointing.
%
% The function supports training on CPU or on one or more GPUs
% (specify the list of GPU IDs in the `gpus` option).
% Copyright (C) 2014-16 Andrea Vedaldi.
% All rights reserved.
%
% This file is part of the VLFeat library and is made available under
% the terms of the BSD license (see the COPYING file).
% ------------------------------------------------------------------------
%翻译:
%cnn_train是一个学习器的示例,基于SGD算法对CNN进行训练。
%通过适当的getBatch函数,cnn_train可以被用在训练不同的数据集,以实现不同目的的训练。
%cnn_train提供了自动检查上次训练状态并且继续接着训练的能力。
%cnn_train支持使用GPU并且同时支持多个GPU的并行运算
% ------------------------------------------------------------------------
opts.subsetSize = 1e4;
opts.expDir = fullfile('data','exp') ; %选择保存路径
opts.continue = true ; %选择每次重启都是接着上次训练状态开始
opts.batchSize = 256 ; %选择初始化批的大小为256
opts.numSubBatches = 1 ; %选择子批的个数为1(不划分子批)
opts.train = [] ; %初始化训练集索引为空
opts.val = [] ; %初始化验证集索引为空
opts.gpus = [] ; %选择GPU
opts.prefetch = false ; %选择是否预读取下一批次的样本(初始化为否)
opts.numEpochs = 300 ; %选择epoch为300
opts.learningRate = 0.001 ; %选择学习率为0.001
opts.weightDecay = 0.0005 ; %选择权重延迟为0.0005
opts.momentum = 0.9 ; %选择动量为0.9
opts.saveMomentum = true ; %选择存储动量
opts.nesterovUpdate = false ; %选择nesterovUpdate为假
opts.randomSeed = 0 ; %选择随机种子为0
opts.memoryMapFile = fullfile(tempdir, 'matconvnet.bin') ; %选择内存映射文件
opts.profile = false ; %选择profile为假
opts.parameterServer.method = 'mmap' ; %选择参数server的途径为mmap
opts.parameterServer.prefix = 'mcn' ; %选择参数server的词头为m