在我们上期文章中(文章请见《深度学习之图片分类》),我们使用MNIST数据集训练了自己的图片分类模型,并保存在build/model
目录下。接下来,我们将使用上期训练的模型进行预测图片。
加载模型
private static Classifications predict() throws IOException, ModelException, TranslateException {
Image img = ImageFactory.getInstance().fromUrl("https://www.d2lcoder.com/image/0.png");
Mlp mlp = new Mlp(
Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH,
Mnist