A model can be trained either with a network you define yourself or by fine-tuning an existing one. This post covers the first case: defining and training a custom model.
System: Ubuntu 14.04
MXNet: 0.904
1. Data preparation
How to generate the MXNet .rec dataset was covered in the previous post. Here we walk through the code:
import mxnet as mx

batch_size = 32  # pick a value that fits your GPU memory

# data paths
train_iter = "/mxnet/tools/train-cat.rec"
val_iter = "/mxnet/tools/train-cat_test.rec"
# feed the .rec files to ImageRecordIter; the mean-image file is omitted here
train_dataiter = mx.io.ImageRecordIter(
    path_imgrec=train_iter,  # training .rec from above
    # mean_img=datadir+"/mean.bin",
    rand_crop=True,
    rand_mirror=True,
    data_shape=(3, 128, 128),
    batch_size=batch_size,
    preprocess_threads=1)
test_dataiter = mx.io.ImageRecordIter(
    path_imgrec=val_iter,  # validation .rec from above
    # mean_img=datadir+"/mean.bin",
    rand_crop=False,
    rand_mirror=False,
    data_shape=(3, 128, 128),
    batch_size=batch_size,
    preprocess_threads=1)
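To verify the iterators are wired up correctly before training, you can pull a single batch and print its shapes. This is only a quick sanity-check sketch; nothing in it is required for training:

# fetch one batch from the training iterator and inspect it
train_dataiter.reset()
batch = train_dataiter.next()
print(batch.data[0].shape)   # expected: (batch_size, 3, 128, 128)
print(batch.label[0].shape)  # expected: (batch_size,)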
2. Defining the network
In MXNet, a network is defined with the Symbol API. The following get_symbol builds an AlexNet-style network:
def get_symbol(num_classes, **kwargs):
    input_data = mx.symbol.Variable(name="data")
    # stage 1
    conv1 = mx.symbol.Convolution(name='conv1',
                                  data=input_data, kernel=(11, 11), stride=(4, 4), num_filter=96)
    relu1 = mx.symbol.Activation(data=conv1, act_type="relu")
    lrn1 = mx.symbol.LRN(data=relu1, alpha=0.0001, beta=0.75, knorm=2, nsize=5)
    pool1 = mx.symbol.Pooling(
        data=lrn1, pool_type="max", kernel=(3, 3), stride=(2, 2))
    # stage 2
    conv2 = mx.symbol.Convolution(name='conv2',
                                  data=pool1, kernel=(5, 5), pad=(2, 2), num_filter=256)
    relu2 = mx.symbol.Activation(data=conv2, act_type="relu")
    lrn2 = mx.symbol.LRN(data=relu2, alpha=0.0001, beta=0.75, knorm=2, nsize=5)
    pool2 = mx.symbol.Pooling(data=lrn2, kernel=(3, 3), stride=(2, 2), pool_type="max")
    # stage 3
    conv3 = mx.symbol.Convolution(name='conv3',
                                  data=pool2, kernel=(3, 3), pad=(1, 1), num_filter=384)
    relu3 = mx.symbol.Activation(data=conv3, act_type="relu")
    conv4 = mx.symbol.Convolution(name='conv4',
                                  data=relu3, kernel=(3, 3), pad=(1, 1), num_filter=384)
    relu4 = mx.symbol.Activation(data=conv4, act_type="relu")
    conv5 = mx.symbol.Convolution(name='conv5',
                                  data=relu4, kernel=(3, 3), pad=(1, 1), num_filter=256)
    relu5 = mx.symbol.Activation(data=conv5, act_type="relu")
    pool3 = mx.symbol.Pooling(data=relu5, kernel=(3, 3), stride=(2, 2), pool_type="max")
    # stage 4
    flatten = mx.symbol.Flatten(data=pool3)
    fc1 = mx.symbol.FullyConnected(name='fc1', data=flatten, num_hidden=4096)
    relu6 = mx.symbol.Activation(data=fc1, act_type="relu")
    dropout1 = mx.symbol.Dropout(data=relu6, p=0.5)
    # stage 5
    fc2 = mx.symbol.FullyConnected(name='fc2', data=dropout1, num_hidden=4096)
    relu7 = mx.symbol.Activation(data=fc2, act_type="relu")
    dropout2 = mx.symbol.Dropout(data=relu7, p=0.5)
    # stage 6
    fc3 = mx.symbol.FullyConnected(name='fc3', data=dropout2, num_hidden=num_classes)
    softmax = mx.symbol.SoftmaxOutput(data=fc3, name='softmax')
    return softmax
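Before training, it can help to confirm the symbol produces the output shape you expect. A minimal sketch using infer_shape, assuming a batch size of 32 purely for the check:

sym = get_symbol(2)
arg_shapes, out_shapes, aux_shapes = sym.infer_shape(data=(32, 3, 128, 128))
print(out_shapes)  # expected: [(32, 2)] -- one softmax row per image, 2 classes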
3. Training the model
import logging

logging.basicConfig(level=logging.DEBUG)  # show per-epoch training/validation metrics

num_epoch = 8
net = get_symbol(2)  # 2 classes: cat and dog
mod = mx.mod.Module(symbol=net,
                    context=mx.gpu(),
                    data_names=['data'],
                    label_names=['softmax_label'])
mod.fit(train_dataiter,
        eval_data=test_dataiter,
        optimizer='sgd',
        optimizer_params={'learning_rate': 0.1},
        eval_metric='acc',
        num_epoch=num_epoch)
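Once fit returns, the trained module can be scored against the validation iterator again. A small sketch using Module.score; the printed accuracy depends entirely on your data:

# evaluate the trained module on the validation set
test_dataiter.reset()
print(mod.score(test_dataiter, 'acc'))  # e.g. [('accuracy', ...)]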
4. Saving the model
# writes ./model-symbol.json and ./model-0008.params
mod.save_checkpoint('./model', num_epoch)
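To use the checkpoint later for prediction, load it back and bind a module for inference. The sketch below assumes the ./model prefix and epoch 8 from above, and uses a placeholder all-ones array where you would supply a real preprocessed 128x128 image:

# load the saved symbol and parameters
sym, arg_params, aux_params = mx.model.load_checkpoint('./model', 8)
pred_mod = mx.mod.Module(symbol=sym, context=mx.gpu(),
                         data_names=['data'], label_names=None)
pred_mod.bind(data_shapes=[('data', (1, 3, 128, 128))], for_training=False)
pred_mod.set_params(arg_params, aux_params)

img = mx.nd.ones((1, 3, 128, 128))  # placeholder; substitute a real preprocessed image
pred_mod.forward(mx.io.DataBatch(data=[img], label=None), is_train=False)
print(pred_mod.get_outputs()[0].asnumpy())  # class probabilities for cat / dog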