项目需要,先train好模型,存至本地。在项目中需要预测用户上传的图像是什么。此过程就应用到了加载本地模型,然后预测用户上传的图像,反馈会结果。
一、将模型存至本地:
//save model
File locationToSave = new File("/home/greg/LungCNNModel.zip");
boolean saveUpdater = true;
ModelSerializer.writeModel(model, locationToSave, saveUpdater);
第一个参数为要保存的模型,第二个参数为模型保存的地方,第三个参数为模型的一些参数,不是权重与偏置,而是模型更新的动量,RMSProp , Adagrad 。这样就可以将模型保存至本地,格式为zip格式。
二、加载本地模型:
MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(locationToSave);
System.out.println("Saved and loaded parameters are equal: " + model.params().equals(restored.params()));
System.out.println("Saved and loaded configurations are equal: " +
model.getLayerWiseConfigurations().equals(restored.getLayerWiseConfigurations()));
从本地把模型加载进来。
比较原始模型与加载进来的模型参数是否一致。
比较原始模型的网络配置是否与加载进来的模型网络配置是否一致。
三、预测位置图像的类别
1、加载图像
File trainData = new File("/home/greg/data/image/");
FileSplit train = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS,randNumGen);
2、制作label生成器
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
3、读入图像病初始化 recordReader
ImageRecordReader recordReader = new ImageRecordReader(height,width,channels,labelMaker);
recordReader.initialize(train);
4、制作图像迭代器,包括图像读取器,批处理,label 索引,输出几个类别
DataSetIterator dataIter = new RecordReaderDataSetIterator(recordReader,batchSize,1,outputNum);
5、对图像像素进行归一化
DataNormalization scaler = new ImagePreProcessingScaler(0,1);
scaler.fit(dataIter);
dataIter.setPreProcessor(scaler);
6、将图像迭代器格式转为DataSet格式
DataSet img = dataIter.next();
7、预测图像的label System.out.println(restored.output(img.getFeatureMatrix() , false));
false代表测试集,true为训练集
getFeatureMatrix为获得图像的数据,通过restored.output预测得到输出。
结果输出即可。
另外一个api ,restored.predicted(dataIter),我本以为这个是预测输出呢,但是这个是输出我标定的label值,也就是我存储图像的文件夹的名字。容易误导人。
注意,不要用predict应该用output。