MXNet实战入门Minist分类

本文详细介绍使用MXNet实现LeNet神经网络对手写数字进行训练和测试的过程。通过定义网络结构,设置训练参数,加载MNIST数据集进行训练,并保存模型。最后,加载训练好的模型对测试图像进行预测。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >


文件目录如下:
在这里插入图片描述

1.训练过程

  • 训练代码:
import mxnet as mx
import argparse
import numpy as np
import gzip
import struct
import logging

def get_network(num_classes):
    """
    LeNet
    """
    data = mx.sym.Variable("data")
    conv1 = mx.sym.Convolution(data=data, kernel=(5,5), num_filter=6,
                               name="conv1")
    relu1 = mx.sym.Activation(data=conv1, act_type="relu", name="relu1")
    pool1 = mx.sym.Pooling(data=relu1, kernel=(2,2), stride=(2,2),
                           pool_type="max", name="pool1")

    conv2 = mx.sym.Convolution(data=pool1, kernel=(5, 5), num_filter=16,
                               name="conv2")
    relu2 = mx.sym.Activation(data=conv2, act_type="relu", name="relu2")
    pool2 = mx.sym.Pooling(data=relu2, kernel=(2, 2), stride=(2, 2),
                           pool_type="max", name="pool2")

    fc1 = mx.sym.FullyConnected(data=pool2, num_hidden=120, name="fc1")
    relu3 = mx.sym.Activation(data=fc1, act_type="relu", name="relu3")

    fc2 = mx.sym.FullyConnected(data=relu3, num_hidden=84, name="fc2")
    relu4 = mx.sym.Activation(data=fc2, act_type="relu", name="relu4")

    fc3 = mx.sym.FullyConnected(data=relu4, num_hidden=num_classes, name="fc3")
    sym = mx.sym.SoftmaxOutput(data=fc3, name="softmax")
    return sym

def get_args():
    parser = argparse.ArgumentParser(description='score a model on a dataset')
    parser.add_argument('--num-classes', type=int, default=10)
    parser.add_argument('--gpus', type=str, default='0')
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--num-epoch', type=int, default=10)
    parser.add_argument('--lr', type=float, default=0.1, help="learning rate")
    parser.add_argument('--save-result', type=str, default='output/')
    parser.add_argument('--save-name', type=str, default='LeNet')
    args = parser.parse_args()
    return args

if __name__ == '__main__':
    args = get_args()
    if args.gpus:
        context = [mx.gpu(int(index)) for index in
                   args.gpus.strip().split(",")]
    else:
        context = mx.cpu()

    # get data
    train_data = mx.io.MNISTIter(
        image='train-images.idx3-ubyte',
        label='train-labels.idx1-ubyte',
        batch_size=args.batch_size,
        shuffle=1)
    val_data = mx.io.MNISTIter(
        image='t10k-images.idx3-ubyte',
        label='t10k-labels.idx1-ubyte',
        batch_size=args.batch_size,
        shuffle=0)

    # get network(symbol)
    sym = get_network(num_classes=args.num_classes)

    optimizer_params = {'learning_rate':args.lr}
    initializer = mx.init.Xavier(rnd_type='gaussian', factor_type='in',
                                 magnitude=2)

    mod = mx.mod.Module(symbol=sym, context=context)
    #训练日志保存
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    stream_handler = logging.StreamHandler()
    logger.addHandler(stream_handler)
    file_handler = logging.FileHandler('output/train.log')
    logger.addHandler(file_handler)
    logger.info(args)
    #默认period=1代表每训练完一次epoch保存一次结果
    checkpoint = mx.callback.do_checkpoint(prefix=args.save_result
                                           +args.save_name, period=2)

    batch_callback = mx.callback.Speedometer(args.batch_size, 1000)
    mod.fit(train_data=train_data,
            eval_data=val_data,
            eval_metric='acc',
            optimizer_params=optimizer_params,
            optimizer='sgd',
            batch_end_callback=batch_callback,
            initializer=initializer,
            num_epoch=args.num_epoch,
            epoch_end_callback=checkpoint)

  • 训练结果:
/home/yuyang/anaconda3/envs/mxnet/bin/python3.5 /home/yuyang/下载/MXNet-Deep-Learning-in-Action-master/demo_4/train_mnist.py
[15:14:28] src/io/iter_mnist.cc:113: MNISTIter: load 60000 images, shuffle=1, shape=[64,1,28,28]
[15:14:28] src/io/iter_mnist.cc:113: MNISTIter: load 10000 images, shuffle=0, shape=[64,1,28,28]
Namespace(batch_size=64, gpus='0', lr=0.1, num_classes=10, num_epoch=10, save_name='LeNet', save_result='output/')
[15:14:29] src/operator/nn/./cudnn/./cudnn_algoreg-inl.h:97: Running performance tests to find the best convolution algorithm, this can take a while... (setting env variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)
Epoch[0] Train-accuracy=0.932597
Epoch[0] Time cost=0.839
Epoch[0] Validation-accuracy=0.980168
Epoch[1] Train-accuracy=0.979156
Epoch[1] Time cost=0.825
Saved checkpoint to "output/LeNet-0002.params"
Epoch[1] Validation-accuracy=0.984575
Epoch[2] Train-accuracy=0.984458
Epoch[2] Time cost=0.884
Epoch[2] Validation-accuracy=0.986178
Epoch[3] Train-accuracy=0.988744
Epoch[3] Time cost=0.894
Saved checkpoint to "output/LeNet-0004.params"
Epoch[3] Validation-accuracy=0.986879
Epoch[4] Train-accuracy=0.991229
Epoch[4] Time cost=0.821
Epoch[4] Validation-accuracy=0.986779
Epoch[5] Train-accuracy=0.993096
Epoch[5] Time cost=0.854
Saved checkpoint to "output/LeNet-0006.params"
Epoch[5] Validation-accuracy=0.986879
Epoch[6] Train-accuracy=0.994697
Epoch[6] Time cost=0.859
Epoch[6] Validation-accuracy=0.987280
Epoch[7] Train-accuracy=0.995748
Epoch[7] Time cost=0.854
Saved checkpoint to "output/LeNet-0008.params"
Epoch[7] Validation-accuracy=0.987280
Epoch[8] Train-accuracy=0.996331
Epoch[8] Time cost=0.833
Epoch[8] Validation-accuracy=0.987480
Epoch[9] Train-accuracy=0.996815
Epoch[9] Time cost=0.888
Saved checkpoint to "output/LeNet-0010.params"
Epoch[9] Validation-accuracy=0.987680

Process finished with exit code 0

2.测试过程

  • 测试代码
import mxnet as mx
import numpy as np

#模型加载
def load_model(model_prefix, index, context, data_shapes, label_shapes):
  #mx.model.load_checkpoint()接口用于导入训练好的模型
  sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, index)
  model = mx.mod.Module(symbol=sym, context=context)
  #将网络结构与输入数据绑定到一起
  model.bind(data_shapes=data_shapes, label_shapes=label_shapes,
             for_training=False)

  model.set_params(arg_params=arg_params, aux_params=aux_params,
                   allow_missing=True) #allow_missing=True表示arg_params和aux_params不必与待初始化网络结构完全一致
  return model

#数据读取
def load_data(data_path):
  data = mx.image.imread(data_path, flag=0)  #flag=1表示3通道
  cla_cast_aug = mx.image.CastAug()          #将数类型转换成float32
  cla_resize_aug = mx.image.ForceResizeAug(size=[28, 28])  #保证输入图像尺寸与网络结构定义图像尺寸一致
  cla_augmenters = [cla_cast_aug, cla_resize_aug]

  for aug in cla_augmenters:
      data = aug(data)
  data = mx.nd.transpose(data, axes=(2, 0, 1)) #(H, W, C)---->(C, H, W)
  data = mx.nd.expand_dims(data, axis=0)       #增加第0维度,将数据扩展成4维
  data = mx.io.DataBatch([data])               #将数据封装成模型能够直接处理的数据结构
  return data

#预测输出
def get_output(model, data):
  #执行前向操作
  model.forward(data)
  #得到模型输出;get_output方法的第一个[0]代表第0个分类任务的输出,本例中只有一个任务
  #model.get_output()[0]是一个2维的NDArray,第0维表示batch_size,本例中只预测一张图片
  #model.get_output()[0][0]得到的是1*N大小的NDArray向量,本例中N=10,该向量每个值代表属于每个类别的概率
  #接着调用asnumpy()方法就得到了Numpy array结构的结果
  cla_prob = model.get_outputs()[0][0].asnumpy()
  #调用Numpy的argmax()接口得到概率最大值所对应的下标index
  cla_label = np.argmax(cla_prob)
  return cla_label

if __name__ == "__main__":
  model_prefix = "output/LeNet"
  index = 10
  context = mx.gpu(0)
  data_shapes = [('data', (1,1,28,28))]
  label_shapes = [('softmax_label', (1,))]
  model = load_model(model_prefix, index, context, data_shapes, label_shapes)

  data_path = "test_image/test1.png"
  data  =load_data(data_path)
  cla_label = get_output(model, data)
  print("Predict result:{}".format(cla_label))

  • 输出结果
/home/yuyang/anaconda3/envs/mxnet/bin/python3.5 /home/yuyang/下载/MXNet-Deep-Learning-in-Action-master/demo_4/test.mnist.py
[17:18:29] src/operator/nn/./cudnn/./cudnn_algoreg-inl.h:97: Running performance tests to find the best convolution algorithm, this can take a while... (setting env variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)
Predict result:5

Process finished with exit code 0
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值