Mxnet图片分类(2)训练模型

   训练模型可以利用自定义的模型进行也可以采用fine-tune的方法。这里先介绍如何自定义模型进行训练。

系统: ubuntu14.04
Mxnet: 0.904

1.数据准备

mxnet数据集的生成可以参考上一篇文章

这里从代码来分析:

#数据路径
train_iter = "/mxnet/tools/train-cat.rec"
val_iter = "/mxnet/tools/train-cat_test.rec"
#放到ImageRecordIter,此处省略了均值文件
train_dataiter = mx.io.ImageRecordIter(
            path_imgrec=train_iter,#上面的路径
            #mean_img=datadir+"/mean.bin",
            rand_crop=True,
            rand_mirror=True,
            data_shape=(3,128,128),
            batch_size=batch_size,
            preprocess_threads=1)
test_dataiter = mx.io.ImageRecordIter(
            path_imgrec=val_iter,#上面的路径
            #mean_img=datadir+"/mean.bin",
            rand_crop=False,
            rand_mirror=False,
            data_shape=(3,128,128),
            batch_size=batch_size,
            preprocess_threads=1)

2.准备训练的网络

mxnet的网络通过Symbol来定义。

def get_symbol(num_classes, **kwargs):
    input_data = mx.symbol.Variable(name="data")
    # stage 1
    conv1 = mx.symbol.Convolution(name='conv1',
        data=input_data, kernel=(11, 11), stride=(4, 4), num_filter=96)
    relu1 = mx.symbol.Activation(data=conv1, act_type="relu")
    lrn1 = mx.symbol.LRN(data=relu1, alpha=0.0001, beta=0.75, knorm=2, nsize=5)
    pool1 = mx.symbol.Pooling(
        data=lrn1, pool_type="max", kernel=(3, 3), stride=(2,2))
    # stage 2
    conv2 = mx.symbol.Convolution(name='conv2',
        data=pool1, kernel=(5, 5), pad=(2, 2), num_filter=256)
    relu2 = mx.symbol.Activation(data=conv2, act_type="relu")
    lrn2 = mx.symbol.LRN(data=relu2, alpha=0.0001, beta=0.75, knorm=2, nsize=5)
    pool2 = mx.symbol.Pooling(data=lrn2, kernel=(3, 3), stride=(2, 2), pool_type="max")
    # stage 3
    conv3 = mx.symbol.Convolution(name='conv3',
        data=pool2, kernel=(3, 3), pad=(1, 1), num_filter=384)
    relu3 = mx.symbol.Activation(data=conv3, act_type="relu")
    conv4 = mx.symbol.Convolution(name='conv4',
        data=relu3, kernel=(3, 3), pad=(1, 1), num_filter=384)
    relu4 = mx.symbol.Activation(data=conv4, act_type="relu")
    conv5 = mx.symbol.Convolution(name='conv5',
        data=relu4, kernel=(3, 3), pad=(1, 1), num_filter=256)
    relu5 = mx.symbol.Activation(data=conv5, act_type="relu")
    pool3 = mx.symbol.Pooling(data=relu5, kernel=(3, 3), stride=(2, 2), pool_type="max")
    # stage 4
    flatten = mx.symbol.Flatten(data=pool3)
    fc1 = mx.symbol.FullyConnected(name='fc1', data=flatten, num_hidden=4096)
    relu6 = mx.symbol.Activation(data=fc1, act_type="relu")
    dropout1 = mx.symbol.Dropout(data=relu6, p=0.5)
    # stage 5
    fc2 = mx.symbol.FullyConnected(name='fc2', data=dropout1, num_hidden=4096)
    relu7 = mx.symbol.Activation(data=fc2, act_type="relu")
    dropout2 = mx.symbol.Dropout(data=relu7, p=0.5)
    # stage 6
    fc3 = mx.symbol.FullyConnected(name='fc3', data=dropout2, num_hidden=num_classes)
    softmax = mx.symbol.SoftmaxOutput(data=fc3, name='softmax')
    return softmax

3.模型训练

import logging

net = get_symbol(2) #类别为2 即cat 和dog
mod = mx.mod.Module(symbol=net,
                   context=mx.gpu(),
                   data_names=['data'],
                   label_names=['softmax_label'])
logging.basicConfig(level=logging.DEBUG)
mod.fit(train_dataiter,
       eval_data=test_dataiter,
       optimizer='sgd',
       optimizer_params={'learning_rate':0.1},
       eval_metric='acc',
       num_epoch = 8)

这里写图片描述

4.保存模型

mod.save_checkpoint('./model',num_epoch)

参考文献

[1]http://mxnet.io/api/python/module.html#mxnet.module.BaseModule.forward
[2]http://mxnet.io/tutorials/basic/module.html

环境的安装和数据集的制作可以参考

  1. Mxnet—faster-rcnn环境安装
  2. Mxnet图片分类(1)准备数据集

测试可以参考:

  1. Mxnet图片分类(4)利用训练好的模型进行测试
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值