图像分类——基于Deeplearning4j

目录

创建工程

Deeplearning4j

MNIST 数据集

加载数据

创建模型

创建单层回归模型

创建深度信念网络

创建多层卷积网络

完整代码


本节将讨论如何使用Deeplearning4j库实现一些神经网络结构。

创建工程

接着前面的项目,首先导入maven依赖

        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-core</artifactId>
            <version>0.4-rc3.8</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-nlp</artifactId>
            <version>0.4-rc3.8</version>
        </dependency>
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>nd4j-x86</artifactId>
            <version>0.4-rc3.8</version>
        </dependency>
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>canova-nd4j-image</artifactId>
            <version>0.0.0.14</version>
        </dependency>
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>canova-nd4j-codec</artifactId>
            <version>0.0.0.14</version>
        </dependency>

Deeplearning4j

Deeplearning4j库,它是Java与Scala环境下的开源分布式深度学习项目。 Deeplearning4j依赖Spark与Hadoop使用MapReduce框架,并行训练模型,且反复对中心模型中产 生参数进行平均。

MNIST 数据集

MNIST数据集是最著名的数据集之一,由手写数字组成,如图所示。该数据集包含60 000 个训练与10 000个测试图像。

这个数据集通常用在图像识别问题中,以测试算法性能。最差记录的错误率是12%,测试时 使用单层神经网络中的SVM算法,并且没有做预处理。截止到2016年,最低的错误率只有0.21%, 使用的是DropConnect神经网络;紧随其后的是深度卷积网络,错误率为0.23%;然后是深度前 馈网络,错误率是0.35%。
接下来,让我们看看如何加载数据集。

加载数据

Deeplearning4j 提供 了“开箱 即用”的 MNIST 数 据集 加载器。加 载器被初始 化为 DataSetIterator。先导入DataSetIterator类与所有支持的数据集,这些数据集是impl包 的一部分,包含的数据集有Iris、MNIST及其他。

接着定义一些常量,比如28×28个像素组成的图像,有10个目标类与60 000个样本。新初始 化一个MnistDataSetIterator类,用于下载数据集及其标签。参数分别是迭代批大小、总样 本数,以及是否将数据集二值化:

        // 定义常量
        final int numRows = 28; // 输入图像的高度
        final int numColumns = 28; // 输入图像的宽度
        int outputNum = 10; // 输出的类别数量(0-9)
        int numSamples = 60000; // 训练集的样本数量
        int batchSize = 100; // 每批数据的大小
        int iterations = 10; // 训练时的迭代次数
        int seed = 123; // 随机种子
        int listenerFreq = batchSize / 5; // 监听器的频率,每5批次打印一次分数
        // 输出加载数据的信息
        System.out.println("加载数据···");
        // 创建MNIST数据集迭代器,用于读取训练集
        DataSetIterator iterator = new MnistDataSetIterator(batchSize, numSamples, true);

但是这里有个问题,这里是自动从国外拉数据,开了VPN也会失败,所以我们可以写个自定义的数据加载器加载本地数据,不过容易OOM,所以我这里没跑出来。

        // 定义常量
        final int numRows = 28; // 输入图像的高度
        final int numColumns = 28; // 输入图像的宽度
        int outputNum = 10; // 输出的类别数量(0-9)
        int numSamples = 60000; // 训练集的样本数量
        int batchSize = 100; // 每批数据的大小
        int iterations = 10; // 训练时的迭代次数
        int seed = 123; // 随机种子
        int listenerFreq = batchSize / 5; // 监听器的频率,每5批次打印一次分数
        // 输出加载数据的信息
        System.out.println("加载数据···");
        // 创建MNIST数据集迭代器,用于读取训练集
        DataSetIterator iterator = getMnistDataSetIterator(batchSize, true);

    private static DataSetIterator getMnistDataSetIterator(int batchSize, boolean train) throws IOException {
        String imagesFile = train ? "train-images-idx3-ubyte.gz" : "t10k-images-idx3-ubyte.gz";
        String labelsFile = train ? "train-labels-idx1-ubyte.gz" : "t10k-labels-idx1-ubyte.gz";
        INDArray images = loadImages(new File(BASE_PATH, imagesFile));
        INDArray labels = loadLabels(new File(BASE_PATH, labelsFile));
        DataSet dataSet = new DataSet(images, labels);
        // 手动标准化数据
        INDArray mean = images.mean(0);
        INDArray std = images.std(0);
        images.subiRowVector(mean);
        images.diviRowVector(std);
        List<DataSet> list = new ArrayList<>();
        for (int i = 0; i < dataSet.numExamples(); i++) {
            list.add(dataSet.get(i));
        }
        return new ListDataSetIterator(list, batchSize);
    }

    private static INDArray loadImages(File file) throws IOException {
        try (DataInputStream dis = new DataInputStream(new FileInputStream(file))) {
            int magicNumber = dis.readInt();
            //if (magicNumber != 2051) {
            //    throw new IOException("Invalid magic number in image file!");
            //}
            int numImages = dis.readInt();
            int numRows = dis.readInt();
            int numColumns = dis.readInt();
            INDArray images = Nd4j.create(numImages, 1, numRows, numColumns);
            for (int i = 0; i < numImages; i++) {
                for (int r = 0; r < numRows; r++) {
                    for (int c = 0; c < numColumns; c++) {
                        int pixel = dis.readUnsignedByte();
                        images.putScalar(new int[]{i, 0, r, c}, pixel);
                    }
                }
            }
            return images;
        }
    }

    private static INDArray loadLabels(File file) throws IOException {
        try (DataInputStream dis = new DataInputStream(new FileInputStream(file))) {
            int magicNumber = dis.readInt();
            //if (magicNumber != 2049) {
            //    throw new IOException("Invalid magic number in label file!");
            //}
            int numLabels = dis.readInt();
            INDArray labels = Nd4j.zeros(numLabels, 10);
            for (int i = 0; i < numLabels; i++) {
                int label = dis.readUnsignedByte();
                labels.putScalar(new int[]{i, label}, 1.0);
            }
            return labels;
        }
    }

创建模型

本节将讨论如何实际创建一个神经网络模型。先创建一个基本的

deeplearning4j是基于java的深度学习库,当然,它有许多特点,但暂时还没学那么深入,所以就不做介绍了 需要学习dl4j,无从下手,就想着先看看官网的examples,于是,下载了examples程序,结果无法运行,总是出错,如下: 查看一周的错误,也没有成功,马上就要放弃了,结果今天在论坛一大牛指导下,终于成功跑起,下面,将心酸的环境配置过程记录如下,以备自己以后查阅,同时,也希望各种高手可以指点,毕竟,本人还是菜鸟一枚 1.安装JAVA运行环境 该部分,网上有许多教程,这里不再赘述,首先,就是安装一个JDK,然后,再安装一个自己喜欢的IED,这里,以eclispe为例 好了,java的运行环境配置好了,接下来,开始配置dl4j的运行环境,它的官网上给了好复杂的设置步骤,照着做看一些后,发现根本无法进行,结果发现,不需要全部设置完成,就可以运行它的例子了,所以,本人并没有按照官网的教程全部设置,只是设置到了可以运行官网的examples为止,可能存在隐患吧,但本人能力有限,实在无从下手,还期待高手指定 2.按照Maven 按照教程安装Maven,该教程讲述非常详细 (1)下载Maven3,3,3,以win7 64位为例 下载地址:https://maven.apache.org/download.cgi (2)将Maven解压到某个文件夹中,这里以“C:\Program Files\apache-maven-3.3.3”为例 (3)配置环境变量:将maven中的bin的路径添加到system variables的PATH中 (4)测试maven是否安装成功 在命令行中输入mvn -version 如果如下下图所示结果,证明配置正确 3. 下载dl4j的examples,网址为: https://github.com/deeplearning4j/dl4j-0.4-examples 4.打开eclipse,导入刚刚下载的dl4j的examples,具体地: 打开eclipse后->File->import->Maven Existing Maven Projects,在Root Directory中选择examples的文件夹 然后,Finish 这样,examples被成功导入 当然,由于Maven会自动导入程序所需的jar文件(在配置文件pom.xml中所提及),所以,会花费一些时间自动下载这些文件 点击运行,出现如下错误: 这个问题困扰了本人一周,终于解决,是因为系统缺少dll文件所致 5. 下载dll文件,地址为https://www.dropbox.com/s/6p8yn3fcf230rxy/ND4J_Win64_OpenBLAS-v0.2.14.zip?dl=1 下载后,将该文件随意放入一个文件夹中,这里以“C:/BLAS”为例 将所有下载得到的dll文件放入该文件夹,并且,将该路径添加至环境变量Path中 6.此时,再运行刚刚的examples,发现程序终于可以正常运行了!
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

顾北辰20

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值