目录
本节将讨论如何使用Deeplearning4j库实现一些神经网络结构。
创建工程
接着前面的项目,首先导入maven依赖
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-core</artifactId>
<version>0.4-rc3.8</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-nlp</artifactId>
<version>0.4-rc3.8</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-x86</artifactId>
<version>0.4-rc3.8</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>canova-nd4j-image</artifactId>
<version>0.0.0.14</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>canova-nd4j-codec</artifactId>
<version>0.0.0.14</version>
</dependency>
Deeplearning4j
Deeplearning4j库,它是Java与Scala环境下的开源分布式深度学习项目。 Deeplearning4j依赖Spark与Hadoop使用MapReduce框架,并行训练模型,且反复对中心模型中产 生参数进行平均。
MNIST 数据集
MNIST数据集是最著名的数据集之一,由手写数字组成,如图所示。该数据集包含60 000 个训练与10 000个测试图像。
这个数据集通常用在图像识别问题中,以测试算法性能。最差记录的错误率是12%,测试时 使用单层神经网络中的SVM算法,并且没有做预处理。截止到2016年,最低的错误率只有0.21%, 使用的是DropConnect神经网络;紧随其后的是深度卷积网络,错误率为0.23%;然后是深度前 馈网络,错误率是0.35%。
接下来,让我们看看如何加载数据集。
加载数据
Deeplearning4j 提供 了“开箱 即用”的 MNIST 数 据集 加载器。加 载器被初始 化为 DataSetIterator。先导入DataSetIterator类与所有支持的数据集,这些数据集是impl包 的一部分,包含的数据集有Iris、MNIST及其他。
接着定义一些常量,比如28×28个像素组成的图像,有10个目标类与60 000个样本。新初始 化一个MnistDataSetIterator类,用于下载数据集及其标签。参数分别是迭代批大小、总样 本数,以及是否将数据集二值化:
// 定义常量
final int numRows = 28; // 输入图像的高度
final int numColumns = 28; // 输入图像的宽度
int outputNum = 10; // 输出的类别数量(0-9)
int numSamples = 60000; // 训练集的样本数量
int batchSize = 100; // 每批数据的大小
int iterations = 10; // 训练时的迭代次数
int seed = 123; // 随机种子
int listenerFreq = batchSize / 5; // 监听器的频率,每5批次打印一次分数
// 输出加载数据的信息
System.out.println("加载数据···");
// 创建MNIST数据集迭代器,用于读取训练集
DataSetIterator iterator = new MnistDataSetIterator(batchSize, numSamples, true);
但是这里有个问题,这里是自动从国外拉数据,开了VPN也会失败,所以我们可以写个自定义的数据加载器加载本地数据,不过容易OOM,所以我这里没跑出来。
// 定义常量
final int numRows = 28; // 输入图像的高度
final int numColumns = 28; // 输入图像的宽度
int outputNum = 10; // 输出的类别数量(0-9)
int numSamples = 60000; // 训练集的样本数量
int batchSize = 100; // 每批数据的大小
int iterations = 10; // 训练时的迭代次数
int seed = 123; // 随机种子
int listenerFreq = batchSize / 5; // 监听器的频率,每5批次打印一次分数
// 输出加载数据的信息
System.out.println("加载数据···");
// 创建MNIST数据集迭代器,用于读取训练集
DataSetIterator iterator = getMnistDataSetIterator(batchSize, true);
private static DataSetIterator getMnistDataSetIterator(int batchSize, boolean train) throws IOException {
String imagesFile = train ? "train-images-idx3-ubyte.gz" : "t10k-images-idx3-ubyte.gz";
String labelsFile = train ? "train-labels-idx1-ubyte.gz" : "t10k-labels-idx1-ubyte.gz";
INDArray images = loadImages(new File(BASE_PATH, imagesFile));
INDArray labels = loadLabels(new File(BASE_PATH, labelsFile));
DataSet dataSet = new DataSet(images, labels);
// 手动标准化数据
INDArray mean = images.mean(0);
INDArray std = images.std(0);
images.subiRowVector(mean);
images.diviRowVector(std);
List<DataSet> list = new ArrayList<>();
for (int i = 0; i < dataSet.numExamples(); i++) {
list.add(dataSet.get(i));
}
return new ListDataSetIterator(list, batchSize);
}
private static INDArray loadImages(File file) throws IOException {
try (DataInputStream dis = new DataInputStream(new FileInputStream(file))) {
int magicNumber = dis.readInt();
//if (magicNumber != 2051) {
// throw new IOException("Invalid magic number in image file!");
//}
int numImages = dis.readInt();
int numRows = dis.readInt();
int numColumns = dis.readInt();
INDArray images = Nd4j.create(numImages, 1, numRows, numColumns);
for (int i = 0; i < numImages; i++) {
for (int r = 0; r < numRows; r++) {
for (int c = 0; c < numColumns; c++) {
int pixel = dis.readUnsignedByte();
images.putScalar(new int[]{i, 0, r, c}, pixel);
}
}
}
return images;
}
}
private static INDArray loadLabels(File file) throws IOException {
try (DataInputStream dis = new DataInputStream(new FileInputStream(file))) {
int magicNumber = dis.readInt();
//if (magicNumber != 2049) {
// throw new IOException("Invalid magic number in label file!");
//}
int numLabels = dis.readInt();
INDArray labels = Nd4j.zeros(numLabels, 10);
for (int i = 0; i < numLabels; i++) {
int label = dis.readUnsignedByte();
labels.putScalar(new int[]{i, label}, 1.0);
}
return labels;
}
}
创建模型
本节将讨论如何实际创建一个神经网络模型。先创建一个基本的