Installing the MXNet GPU Build on Windows 10 with Anaconda (Python 3.7)

This post describes how to install the GPU build of MXNet on Windows 10 using Anaconda. Because the MXNet GPU build does not support Python 3.7, Anaconda automatically downgrades the environment to a compatible Python version. pip is not recommended for the GPU build; follow the officially recommended conda install instead. After installation, verify it with the official MLP test code.

The CPU build of MXNet can be installed directly with pip:

pip install mxnet
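
A quick way to confirm that the CPU build imports correctly (an optional sanity check, not part of the official instructions):

python -c "import mxnet; print(mxnet.__version__)"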

The GPU build of MXNet should not be installed with pip; install it with conda instead:

conda install -c anaconda mxnet-gpu

Running the command produces output like the following:

(base) PS>conda install -c anaconda mxnet-gpu
Collecting package metadata: done
Solving environment: done

## Package Plan ##

  environment location: C:\Users\peter\Anaconda3

  added / updated specs:
    - mxnet-gpu


The following packages will be downloaded:

    package                    |            build
    ---------------------------|-----------------
    joblib-0.13.2              |           py36_0         360 KB  anaconda
    mxnet-1.2.1                |       h8cc8929_0           3 KB  anaconda
    mxnet-gpu-1.2.1            |       hf82a2c8_0           3 KB  anaconda
    navigator-updater-0.2.1    |           py36_0         1.3 MB  anaconda
    python-3.6.8               |       h9f7ef89_7        20.3 MB  anaconda
    python-dateutil-2.8.0      |           py36_0         282 KB  anaconda
    ------------------------------------------------------------
                                           Total:        1.02 GB

The following NEW packages will be INSTALLED:

  _mutex_mxnet       anaconda/win-64::_mutex_mxnet-0.0.20-gpu_mkl
  conda-package-han~ anaconda/win-64::conda-package-handling-1.1.5-py36_0
  joblib             anaconda/win-64::joblib-0.13.2-py36_0
  mxnet              anaconda/win-64::mxnet-1.2.1-h8cc8929_0
  mxnet-gpu          anaconda/win-64::mxnet-gpu-1.2.1-hf82a2c8_0

As you can see, the MXNet GPU build does not support Python 3.7, so Anaconda automatically downgraded Python to 3.6.8 for you. Anaconda handles this quite intelligently.
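
If you would rather keep Python 3.7 in the base environment, one option is to install MXNet into a dedicated conda environment instead (a minimal sketch using standard conda commands; the environment name mxnet-gpu is arbitrary):

conda create -n mxnet-gpu python=3.6
conda activate mxnet-gpu
conda install -c anaconda mxnet-gpu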


Official tutorial: Hand-written Digit Recognition — mxnet documentation

Test the installation once it completes.
The code below is the official MLP test script; it takes a while to run.

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: skip-file
import mxnet as mx
import numpy as np
import os, sys
import pickle as pickle
import logging
from mxnet.test_utils import get_mnist_ubyte

# symbol net
batch_size = 100
data = mx.symbol.Variable('data')
fc1 = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128)
act1 = mx.symbol.Activation(fc1, name='relu1', act_type="relu")
fc2 = mx.symbol.FullyConnected(act1, name = 'fc2', num_hidden = 64)
act2 = mx.symbol.Activation(fc2, name='relu2', act_type="relu")
fc3 = mx.symbol.FullyConnected(act2, name='fc3', num_hidden=10)
softmax = mx.symbol.SoftmaxOutput(fc3, name = 'sm')

def accuracy(label, pred):
    py = np.argmax(pred, axis=1)
    return np.sum(py == label) / float(label.size)

num_epoch = 4
prefix = './mlp'

#check data
get_mnist_ubyte()

train_dataiter = mx.io.MNISTIter(
        image="data/train-images-idx3-ubyte",
        label="data/train-labels-idx1-ubyte",
        data_shape=(784,),
        label_name='sm_label',
        batch_size=batch_size, shuffle=True, flat=True, silent=False, seed=10)
val_dataiter = mx.io.MNISTIter(
        image="data/t10k-images-idx3-ubyte",
        label="data/t10k-labels-idx1-ubyte",
        data_shape=(784,),
        label_name='sm_label',
        batch_size=batch_size, shuffle=True, flat=True, silent=False)

def test_mlp():
    # print logging by default
    logging.basicConfig(level=logging.DEBUG)

    model = mx.model.FeedForward.create(
        softmax,
        X=train_dataiter,
        eval_data=val_dataiter,
        eval_metric=mx.metric.np(accuracy),
        epoch_end_callback=mx.callback.do_checkpoint(prefix),
        ctx=[mx.cpu(i) for i in range(2)],
        num_epoch=num_epoch,
        learning_rate=0.1, wd=0.0004,
        momentum=0.9)

    logging.info('Finish training...')
    prob = model.predict(val_dataiter)
    logging.info('Finish predict...')
    val_dataiter.reset()
    y = np.concatenate([batch.label[0].asnumpy() for batch in val_dataiter]).astype('int')
    py = np.argmax(prob, axis=1)
    acc1 = float(np.sum(py == y)) / len(y)
    logging.info('final accuracy = %f', acc1)
    assert(acc1 > 0.95)

    # predict internal featuremaps
    internals = softmax.get_internals()
    fc2 = internals['fc2_output']
    mfeat = mx.model.FeedForward(symbol=fc2,
                                 arg_params=model.arg_params,
                                 aux_params=model.aux_params,
                                 allow_extra_params=True)
    feat = mfeat.predict(val_dataiter)
    assert feat.shape == (10000, 64)
    # pickle the model
    smodel = pickle.dumps(model)
    model2 = pickle.loads(smodel)
    prob2 = model2.predict(val_dataiter)
    assert np.sum(np.abs(prob - prob2)) == 0

    # load model from checkpoint
    model3 = mx.model.FeedForward.load(prefix, num_epoch)
    prob3 = model3.predict(val_dataiter)
    assert np.sum(np.abs(prob - prob3)) == 0

    # save model explicitly
    model.save(prefix, 128)
    model4 = mx.model.FeedForward.load(prefix, 128)
    prob4 = model4.predict(val_dataiter)
    assert np.sum(np.abs(prob - prob4)) == 0

    for i in range(num_epoch):
        os.remove('%s-%04d.params' % (prefix, i + 1))
    os.remove('%s-symbol.json' % prefix)
    os.remove('%s-0128.params' % prefix)


if __name__ == "__main__":
    test_mlp()

If the script runs without errors, the installation was successful.
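
Note that the official test above trains on CPU contexts (ctx=[mx.cpu(i) for i in range(2)]). To confirm that the GPU build can actually see your GPU, a short smoke test like the one below is enough (a minimal sketch; it assumes CUDA is set up and at least one GPU is visible):

import mxnet as mx

# Allocate a small array directly on the first GPU; this raises an
# MXNetError if the GPU build cannot find a usable CUDA device.
a = mx.nd.ones((2, 3), ctx=mx.gpu(0))
b = (a * 2).asnumpy()  # copy back to the host to force the computation
print(mx.__version__)  # 1.2.1 for the build installed above
print(b)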
