Overview:
Goal: implement a neural network from scratch, train it on the MNIST dataset, and end up able to take a picture and recognize which digit (0-9) it shows.
Tools: Java 1.8+
Library: ND4J
I. Data Analysis
1. About the MNIST dataset
The MNIST dataset is available at http://yann.lecun.com/exdb/mnist/ and consists of four parts:
- Training set images: train-images-idx3-ubyte.gz (9.9 MB, 47 MB uncompressed, 60,000 samples)
- Training set labels: train-labels-idx1-ubyte.gz (29 KB, 60 KB uncompressed, 60,000 labels)
- Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 7.8 MB uncompressed, 10,000 samples)
- Test set labels: t10k-labels-idx1-ubyte.gz (5 KB, 10 KB uncompressed, 10,000 labels)
MNIST comes from the US National Institute of Standards and Technology (NIST). The training set consists of digits handwritten by 250 different people, 50% of them high-school students and 50% staff of the Census Bureau; the test set contains handwritten digits in the same proportions. Each file uses the simple IDX binary layout: a 4-byte magic number (0x00000803 for image files, 0x00000801 for label files), one big-endian 4-byte integer per dimension (sample count, then rows and columns for images), and finally the raw unsigned-byte pixel or label values. The utility class below checks exactly these magic numbers before parsing.
2. Implementing an MNIST utility class
package com.yan.dl4j.Utils;

import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.*;

public class MnistReadUtil {

    public static final String TRAIN_IMAGES_FILE = "src\\main\\resources\\data\\train-images.idx3-ubyte";
    public static final String TRAIN_LABELS_FILE = "src\\main\\resources\\data\\train-labels.idx1-ubyte";
    public static final String TEST_IMAGES_FILE = "src\\main\\resources\\data\\t10k-images.idx3-ubyte";
    public static final String TEST_LABELS_FILE = "src\\main\\resources\\data\\t10k-labels.idx1-ubyte";

    /**
     * Change bytes into a hex string.
     *
     * @param bytes bytes
     * @return the returned hex string
     */
    public static String bytesToHex(byte[] bytes) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < bytes.length; i++) {
            String hex = Integer.toHexString(bytes[i] & 0xFF);
            if (hex.length() < 2) {
                sb.append(0);
            }
            sb.append(hex);
        }
        return sb.toString();
    }

    /**
     * Get images of 'train' or 'test'.
     *
     * @param fileName the 'train' or 'test' image file
     * @return a matrix in which each row is one picture
     */
    public static double[][] getImages(String fileName) {
        try {
            return getImages(new FileInputStream(fileName));
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        }
        return null;
    }

    public static double[][] getImages(InputStream inputStream) {
        double[][] x;
        try (BufferedInputStream bin = new BufferedInputStream(inputStream)) {
            byte[] bytes = new byte[4];
            bin.read(bytes, 0, 4);
            if (!"00000803".equals(bytesToHex(bytes))) { // check the magic number
                throw new RuntimeException("Please select the correct file!");
            } else {
                bin.read(bytes, 0, 4);
                int number = Integer.parseInt(bytesToHex(bytes), 16); // total number of samples
                bin.read(bytes, 0, 4);
                int xPixel = Integer.parseInt(bytesToHex(bytes), 16); // pixels per row
                bin.read(bytes, 0, 4);
                int yPixel = Integer.parseInt(bytesToHex(bytes), 16); // pixels per column
                x = new double[number][xPixel * yPixel];
                for (int i = 0; i < number; i++) {
                    double[] element = new double[xPixel * yPixel];
                    for (int j = 0; j < xPixel * yPixel; j++) {
                        element[j] = bin.read(); // read pixel values one by one
                        // normalization
                        // element[j] = bin.read() / 255.0;
                    }
                    x[i] = element;
                }
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        return x;
    }

    /**
     * Get labels of 'train' or 'test'.
     *
     * @param fileName the 'train' or 'test' label file
     * @return labels
     */
    public static double[] getLabels(String fileName) {
        try {
            return getLabels(new FileInputStream(fileName));
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        }
        return null;
    }

    public static double[] getLabels(InputStream inputStream) {
        double[] y;
        try (BufferedInputStream bin = new BufferedInputStream(inputStream)) {
            byte[] bytes = new byte[4];
            bin.read(bytes, 0, 4);
            if (!"00000801".equals(bytesToHex(bytes))) {
                throw new RuntimeException("Please select the correct file!");
            } else {
                bin.read(bytes, 0, 4);
                int number = Integer.parseInt(bytesToHex(bytes), 16);
                y = new double[number];
                for (int i = 0; i < number; i++) {
                    y[i] = bin.read();
                }
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        return y;
    }

    public static void drawGrayPicture(double[] pixelValues, String fileName) throws IOException {
        // convert double to int
        int[] res = new int[pixelValues.length];
        for (int i = 0; i < pixelValues.length; i++) {
            res[i] = (int) pixelValues[i];
        }
        // the dataset tells us each picture is 28 rows by 28 columns
        int width = 28;
        int high = 28;
        BufferedImage bufferedImage = new BufferedImage(width, high, BufferedImage.TYPE_INT_RGB);
        for (int i = 0; i < width; i++) {
            for (int j = 0; j < high; j++) {
                int pixel = 255 - res[i * high + j]; // MNIST data is row-major, so i is the row
                int value = pixel + (pixel << 8) + (pixel << 16); // when r = g = b the color is a gray level
                bufferedImage.setRGB(j, i, value); // setRGB(x, y) takes the column first, so pass j, then i
            }
        }
        ImageIO.write(bufferedImage, "JPEG", new File(fileName));
    }

    public static double[] getSizeBlackWhiteImg(File file, int width, int height) throws IOException {
        double[] result = null;
        if (!file.exists()) {
            System.out.println("Picture does not exist");
            return null;
        }
        BufferedImage bufImg = ImageIO.read(file);
        Image _img = bufImg.getScaledInstance(width, height, Image.SCALE_DEFAULT);
        BufferedImage image = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
        Graphics2D graphics = image.createGraphics();
        graphics.drawImage(_img, 0, 0, null);
        graphics.dispose();
        int[] rgb = new int[3];
        result = new double[width * height];
        for (int i = 0; i < width; i++) {
            for (int j = 0; j < height; j++) {
                int r = image.getRGB(i, j) & 0xFFFFFF;
                rgb[0] = (r & 0xff0000) >> 16;
                rgb[1] = (r & 0xff00) >> 8;
                rgb[2] = (r & 0xff);
                // invert the luminance so dark strokes become large values, matching the MNIST convention
                int color = 255 - (int) (rgb[0] * 0.3 + rgb[1] * 0.59 + rgb[2] * 0.11);
                // store row-major (j is the row, i is the column) so the layout matches getImages
                result[j * width + i] = color;
            }
        }
        return result;
    }
}
The getImages method of this utility class converts an MNIST image file into a double[][]. Taking train-images.idx3-ubyte as an example, the result is a double[60000][784]: each double[] is one picture, there are 60,000 pictures in total, and each picture has 28*28 = 784 feature values. getLabels converts an MNIST label file into a double[]. drawGrayPicture turns a double[] back into a picture a human can look at. getSizeBlackWhiteImg takes an arbitrary input picture, converts it to grayscale, scales it down to 28*28, and flattens it into a double[].
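As a quick sanity check, a minimal sketch (assuming the four data files sit at the paths declared above; the output file name is just an example) loads the training set, prints its dimensions, and writes the first sample back out as an image:

double[][] images = MnistReadUtil.getImages(MnistReadUtil.TRAIN_IMAGES_FILE);
double[] labels = MnistReadUtil.getLabels(MnistReadUtil.TRAIN_LABELS_FILE);
System.out.println(images.length + " x " + images[0].length);  // 60000 x 784
System.out.println("first label: " + (int) labels[0]);         // 5, in the standard MNIST ordering
MnistReadUtil.drawGrayPicture(images[0], "first_sample.jpg");   // throws IOException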
II. Building the Neural Network
1. Network fields
private List<Layer> layers = new ArrayList<>();
private LastLayer lastLayer;
private INDArray[] Network_W;
private INDArray[] Network_B;
private int nin;
private int seed = 123;
private double learningrate = 0.01;
private int iteration = 10;
nin is the input dimension. Layer is the abstraction of one network layer; LastLayer is the final layer, kept as a separate type because it also carries the loss function. Network_W holds each layer's weight matrix W and Network_B each layer's bias vector B. learningrate is the learning rate, iteration is the number of passes over the training data, and seed is the random seed used for weight initialization.
2. Step-by-step approach:
(1) Collect the parameters needed to build the network: the input dimension, the number of layers, and the number of neurons in each layer.
public DeepNeuralNetWork(int nin) {
    this.nin = nin;
}

public DeepNeuralNetWork addLayer(Layer layer) {
    layers.add(layer);
    return this;
}

public DeepNeuralNetWork addLastLayer(LastLayer lastLayer) {
    this.lastLayer = lastLayer;
    Init();
    return this;
}
The constructor takes the input dimension, and each call to addLayer contributes one layer's configuration. Layer is an interface exposing, among other things, the layer's activation function and its neuron count. LastLayer is the type of the final layer; unlike the others it also carries a loss function. Once the last layer is added, the parameters can be initialized.
(2) Initialize the network from these parameters.
public void Init() {
    Network_W = new INDArray[layers.size() + 1];
    Network_B = new INDArray[layers.size() + 1];
    for (int i = 0; i < layers.size() + 1; i++) {
        if (i == 0) { // first layer: its input dimension is nin
            if (layers.size() > 0) {
                Network_W[i] = layers.get(i).getWinit().Init(seed, layers.get(i).getNeuralNumber(), nin);
                Network_B[i] = Nd4j.zeros(layers.get(i).getNeuralNumber(), 1);
            } else {
                Network_W[i] = lastLayer.getWinit().Init(seed, lastLayer.getNeuralNumber(), nin);
                Network_B[i] = Nd4j.zeros(lastLayer.getNeuralNumber(), 1);
            }
        } else if (i == layers.size()) { // last layer
            Network_W[i] = lastLayer.getWinit().Init(seed, lastLayer.getNeuralNumber(), layers.get(i - 1).getNeuralNumber());
            Network_B[i] = Nd4j.zeros(lastLayer.getNeuralNumber(), 1);
        } else { // hidden layers
            Network_W[i] = layers.get(i).getWinit().Init(seed, layers.get(i).getNeuralNumber(), layers.get(i - 1).getNeuralNumber());
            Network_B[i] = Nd4j.zeros(layers.get(i).getNeuralNumber(), 1);
        }
    }
}
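As a worked example of the shapes involved, Init() stores one weight matrix and one bias vector per layer,

$$W^{[l]} \in \mathbb{R}^{\,n_l \times n_{l-1}}, \qquad b^{[l]} \in \mathbb{R}^{\,n_l \times 1},$$

so for the 784-input network built in section 3 below (hidden layers of 1000, 500, and 100 neurons plus a 10-neuron output layer) the arrays come out as W[0]: 1000x784, W[1]: 500x1000, W[2]: 100x500, W[3]: 10x100, each with a matching bias column vector.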
(3) Forward propagation
private INDArray linear_forward(INDArray A, INDArray W, INDArray b) {
    return W.mmul(A).addColumnVector(b);
}

private INDArray linear_activate_forward(INDArray A_p, INDArray W, INDArray b, ActivateMethod activate) {
    if (activate != null) {
        return activate.activate_forward(linear_forward(A_p, W, b));
    }
    return linear_forward(A_p, W, b);
}

private INDArray[] forward(INDArray X) {
    INDArray[] res = new INDArray[layers.size() + 1];
    INDArray P_A = X;
    for (int i = 0; i < layers.size(); i++) {
        INDArray A = linear_activate_forward(P_A, Network_W[i], Network_B[i], layers.get(i).getActivateMethod());
        P_A = A;
        res[i] = A;
    }
    // last layer
    INDArray A = linear_activate_forward(P_A, Network_W[layers.size()], Network_B[layers.size()], lastLayer.getActivateMethod());
    res[layers.size()] = A;
    return res;
}
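In formula form, with A^[0] = X holding one sample per column, each layer computes

$$Z^{[l]} = W^{[l]} A^{[l-1]} + b^{[l]}, \qquad A^{[l]} = g^{[l]}\!\left(Z^{[l]}\right),$$

and forward() returns every layer's post-activation A^[l], which the backward pass below reuses.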
(4) Backpropagation
private INDArray LossBackward(INDArray A, INDArray Y) {
    return lastLayer.LastBackward(A, Y);
}

private double LossForward(INDArray A, INDArray Y) {
    return lastLayer.getLossMethod().LossForward(A, Y);
}

// not shown in the original excerpt; a minimal version that delegates to the layer's activation derivative
private INDArray activate_backward(INDArray DA, INDArray A, ActivateMethod activate) {
    if (activate != null) {
        return activate.activate_backward(DA, A);
    }
    return DA;
}

private List<INDArray[]> backward(INDArray[] A_array, INDArray x, INDArray Y) {
    INDArray DZ = LossBackward(A_array[A_array.length - 1], Y);
    INDArray[] DW = new INDArray[A_array.length];
    INDArray[] DB = new INDArray[A_array.length];
    List<INDArray[]> res = new ArrayList<>();
    for (int i = A_array.length - 1; i >= 0; i--) {
        if (i == 0) { // first layer: its input is the raw x
            INDArray dW = DZ.mmul(x.transpose());
            INDArray dB = DZ.mmul(Nd4j.ones(x.shape()[1], 1));
            DW[i] = dW.div(x.shape()[1]);
            DB[i] = dB.div(x.shape()[1]);
        } else {
            INDArray dW = DZ.mmul(A_array[i - 1].transpose());
            INDArray dB = DZ.mmul(Nd4j.ones(x.shape()[1], 1));
            DW[i] = dW.div(x.shape()[1]);
            DB[i] = dB.div(x.shape()[1]);
            DZ = activate_backward(Network_W[i].transpose().mmul(DZ), A_array[i - 1], layers.get(i - 1).getActivateMethod());
        }
    }
    res.add(DW);
    res.add(DB);
    return res;
}
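This implements the standard backpropagation formulas for a batch of m samples (m = x.shape()[1] in the code):

$$dW^{[l]} = \frac{1}{m}\, dZ^{[l]} \left(A^{[l-1]}\right)^{\!T}, \qquad db^{[l]} = \frac{1}{m}\, dZ^{[l]} \mathbf{1}, \qquad dZ^{[l-1]} = \left(W^{[l]}\right)^{\!T} dZ^{[l]} \odot g'\!\left(Z^{[l-1]}\right),$$

where multiplying by the column of ones sums the gradient over the batch, and the elementwise factor g' is supplied by each activation's activate_backward.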
(5) Gradient descent
private void update_parameters(List<INDArray[]> DW_DB) {
    INDArray[] DW = DW_DB.get(0);
    INDArray[] DB = DW_DB.get(1);
    for (int i = 0; i < Network_W.length; i++) {
        Network_W[i] = Network_W[i].sub(DW[i].mul(getLearningrate()));
    }
    for (int i = 0; i < Network_B.length; i++) {
        Network_B[i] = Network_B[i].sub(DB[i].mul(getLearningrate()));
    }
}
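Each parameter then takes one plain gradient-descent step with learning rate α (learningrate in the code):

$$W^{[l]} \leftarrow W^{[l]} - \alpha\, dW^{[l]}, \qquad b^{[l]} \leftarrow b^{[l]} - \alpha\, db^{[l]}.$$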
(6) Computing the loss
private double LossForward(INDArray A, INDArray Y) {
    return lastLayer.getLossMethod().LossForward(A, Y);
}

double loss = LossForward(A[A.length - 1], batch_list.get(j).getY());
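With one-hot labels and the CrossEntropy loss implemented below, this evaluates the average cross-entropy over the m samples of the batch:

$$L = -\frac{1}{m} \sum_{i=1}^{m} \sum_{k=0}^{9} y_k^{(i)} \log a_k^{(i)}.$$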
(7) Putting it all together and training
@Override
public void train(TrainData data) {
    for (int i = 0; i < getIteration(); i++) {
        List<INData> batch_list = data.getBatchList();
        for (int j = 0; j < batch_list.size(); j++) {
            INDArray[] A = forward(batch_list.get(j).getX()); // forward propagation
            List<INDArray[]> DW_DB = backward(A, batch_list.get(j).getX(), batch_list.get(j).getY()); // backpropagation
            update_parameters(DW_DB); // gradient descent
            double loss = LossForward(A[A.length - 1], batch_list.get(j).getY());
            // print progress
            System.out.println("i=" + (i * batch_list.size() + j));
            System.out.println("loss=" + loss);
        }
    }
}
(8) Predicting on new input
@Override
public INDArray predict(INDArray x) {
    INDArray[] A = forward(x);
    return A[A.length - 1];
}
(9) Other supporting classes
// Math utility class
public class MyMathUtil {

    public static double Epow(double x) {
        return Math.pow(Math.E, x); // e^x
    }

    public static INDArray Epow(INDArray value) {
        return FUN_IND(value, v -> Epow(v));
    }

    public static double Normalization(double value, double Max) {
        return value / Max;
    }

    public static double MaxValue(INDArray value) {
        if (value.shape()[0] > 1) {
            double[][] s = value.toDoubleMatrix();
            double my_Max = s[0][0];
            for (double[] si : s) {
                for (double sj : si) {
                    my_Max = Math.max(my_Max, sj);
                }
            }
            return my_Max;
        } else {
            double[] s = value.toDoubleVector();
            double my_Max = s[0];
            for (double si : s) {
                my_Max = Math.max(my_Max, si);
            }
            return my_Max;
        }
    }

    public static INDArray Normalization(INDArray value) {
        double my_Max = MaxValue(value);
        return FUN_IND(value, s -> s / my_Max);
    }

    // subtract each column's maximum from that column (used for numerically stable softmax)
    public static INDArray indArraysubMax(INDArray value) {
        if (value != null) {
            if (value.shape()[0] > 1 && value.shape()[1] > 1) {
                double[][] s = value.transpose().toDoubleMatrix();
                for (int i = 0; i < s.length; i++) {
                    double Max = s[i][0];
                    for (int j = 0; j < s[i].length; j++) {
                        Max = Math.max(Max, s[i][j]);
                    }
                    for (int j = 0; j < s[i].length; j++) {
                        s[i][j] = new BigDecimal(s[i][j]).subtract(new BigDecimal(Max)).doubleValue();
                    }
                }
                return Nd4j.create(s).transpose();
            } else {
                double[] s = value.toDoubleVector();
                double Max = s[0];
                for (int i = 0; i < s.length; i++) {
                    Max = Math.max(Max, s[i]);
                }
                for (int i = 0; i < s.length; i++) {
                    s[i] = new BigDecimal(s[i]).subtract(new BigDecimal(Max)).doubleValue();
                }
                return Nd4j.create(s).reshape(value.shape());
            }
        }
        return null;
    }

    public static INDArray ONEHOT(INDArray value) {
        if (value.isColumnVector()) {
            int[] s = value.toIntVector();
            int Max = 0;
            for (int si : s) {
                Max = Max > si ? Max : si;
            }
            double[][] one_hot_res = new double[s.length][Max + 1];
            for (int i = 0; i < s.length; i++) {
                int val = s[i];
                for (int j = 0; j < Max + 1; j++) {
                    one_hot_res[i][j] = (val == j) ? 1 : 0;
                }
            }
            return Nd4j.create(one_hot_res);
        }
        return null;
    }

    // apply a scalar function elementwise to an INDArray
    public static INDArray FUN_IND(INDArray value, DoubleFunction<Double> doubleFunction) {
        if (value != null) {
            if (value.shape()[0] > 1 && value.shape()[1] > 1) {
                double[][] s = value.toDoubleMatrix();
                for (int i = 0; i < s.length; i++) {
                    for (int j = 0; j < s[i].length; j++) {
                        s[i][j] = doubleFunction.apply(s[i][j]);
                    }
                }
                return Nd4j.create(s);
            } else {
                double[] s = value.toDoubleVector();
                for (int i = 0; i < s.length; i++) {
                    s[i] = doubleFunction.apply(s[i]);
                }
                return Nd4j.create(s).reshape(value.shape());
            }
        }
        return null;
    }

    public static double MysigMoid(double value) {
        double ey = Math.pow(Math.E, -value); // e^(-x)
        return 1 / (1 + ey);
    }

    public static INDArray MysigMoid(INDArray value) {
        return FUN_IND(value, v -> MysigMoid(v));
    }

    public static double Mytanh(double value) {
        double ex = Math.pow(Math.E, value);  // e^x
        double ey = Math.pow(Math.E, -value); // e^(-x)
        double sinhx = ex - ey;
        double coshx = ex + ey;
        return sinhx / coshx;
    }

    public static INDArray Mytanh(INDArray value) {
        return FUN_IND(value, v -> Mytanh(v));
    }

    public static double relu(double value) {
        return Math.max(0, value);
    }

    public static INDArray relu(INDArray value) {
        return FUN_IND(value, v -> relu(v));
    }

    public static double relu_back(double value) {
        return value > 0 ? value : 0;
    }

    public static INDArray relu_back(INDArray value) {
        return FUN_IND(value, v -> relu_back(v));
    }

    public static double Log(double value) {
        return Math.log(value);
    }

    public static INDArray Log(INDArray value) {
        return FUN_IND(value, v -> Log(v));
    }

    public static INDArray sotfmax(INDArray A) {
        if (A != null) {
            A = MyMathUtil.Epow(A); // e.g. A: 10,128
            INDArray sum_A = Nd4j.ones(1, A.shape()[0]).mmul(A); // column sums, e.g. 1,128
            if (A.shape()[0] > 1 && A.shape()[1] > 1) {
                double[][] A_s = A.transpose().toDoubleMatrix(); // e.g. 128,10
                double[] SUM_A_s = sum_A.toDoubleVector();
                for (int i = 0; i < A_s.length; i++) {
                    for (int j = 0; j < A_s[i].length; j++) {
                        A_s[i][j] = A_s[i][j] / SUM_A_s[i];
                    }
                }
                return Nd4j.create(A_s).transpose();
            } else {
                // vector case: A was already exponentiated above
                double[] A_s = A.toDoubleVector();
                double SUM_A_s = A.sumNumber().doubleValue();
                for (int j = 0; j < A_s.length; j++) {
                    A_s[j] = A_s[j] / SUM_A_s;
                }
                return Nd4j.create(A_s).reshape(A.shape());
            }
        }
        return null;
    }

    public static INDArray sotfmax_back(INDArray DA, INDArray A) {
        if (A != null) {
            double[][] da = DA.transpose().toDoubleMatrix();
            double[][] a = A.transpose().toDoubleMatrix();
            double[][] res = new double[da.length][da[0].length];
            for (int i = 0; i < da.length; i++) {
                // with one-hot labels only one component of DA is non-zero; find it,
                // then build the corresponding column of the softmax Jacobian
                int i_order = 0;
                for (int j = 0; j < da[i].length; j++) {
                    if (da[i][j] != 0) { i_order = j; }
                }
                for (int j = 0; j < da[i].length; j++) {
                    if (j == i_order) {
                        res[i][j] = a[i][j] * (1 - a[i][j]);
                    } else {
                        res[i][j] = -a[i][j] * (a[i][i_order]);
                    }
                }
            }
            return Nd4j.create(res).transpose();
        }
        return null;
    }
}
// Activation function abstraction
public interface ActivateMethod {
    INDArray activate_forward(INDArray A);
    INDArray activate_backward(INDArray DA, INDArray A);
}
public class Relu implements ActivateMethod {
    @Override
    public INDArray activate_forward(INDArray A) {
        return MyMathUtil.relu(A);
    }

    @Override
    public INDArray activate_backward(INDArray DA, INDArray A) {
        // ReLU derivative: gate the upstream gradient by where the activation was positive
        return DA.mul(MyMathUtil.FUN_IND(A, v -> v > 0 ? 1.0 : 0.0));
    }
}
public class SoftMax implements ActivateMethod {
    @Override
    public INDArray activate_forward(INDArray A) {
        // subtract the per-column max first for numerical stability
        return MyMathUtil.sotfmax(MyMathUtil.indArraysubMax(A));
    }

    @Override
    public INDArray activate_backward(INDArray DA, INDArray A) {
        return MyMathUtil.sotfmax_back(DA, A);
    }
}
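Subtracting each column's maximum before exponentiating changes nothing mathematically, because softmax is invariant to shifting its inputs, but it keeps e^z from overflowing for large logits:

$$\mathrm{softmax}(z)_k = \frac{e^{z_k}}{\sum_j e^{z_j}} = \frac{e^{z_k - \max(z)}}{\sum_j e^{z_j - \max(z)}}.$$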
public class Tanh implements ActivateMethod {
    @Override
    public INDArray activate_forward(INDArray A) {
        return MyMathUtil.Mytanh(A);
    }

    @Override
    public INDArray activate_backward(INDArray DA, INDArray A) {
        // tanh'(z) = 1 - tanh(z)^2, expressed in terms of the activation A
        return DA.mul(Nd4j.ones(A.shape()).sub(A.mul(A)));
    }
}
// Layer abstraction and implementations
public interface Layer {
    int getNeuralNumber();
    ActivateMethod getActivateMethod();
    Layer setActivateMethod(ActivateMethod activate);
    Layer setWInit(Winit wInit);
    Winit getWinit();
}

public interface LastLayer extends Layer {
    LossMethod getLossMethod();
    LastLayer setLossMethod(LossMethod lossMethod);

    // chain rule at the output: push the loss gradient through the final activation
    default INDArray LastBackward(INDArray A, INDArray Y) {
        return getActivateMethod().activate_backward(getLossMethod().LossBackward(A, Y), A);
    }
}

public class MyLayer implements Layer {

    private int number;
    private ActivateMethod activateMethod;
    private Winit winit;

    public MyLayer(int number, ActivateMethod method) {
        this.number = number;
        this.activateMethod = method;
        this.winit = new XAVIER(); // Xavier initialization by default
    }

    public MyLayer(int number, ActivateMethod method, Winit winit) {
        this.number = number;
        this.activateMethod = method;
        this.winit = winit;
    }

    @Override
    public int getNeuralNumber() { return number; }

    @Override
    public ActivateMethod getActivateMethod() { return activateMethod; }

    @Override
    public Layer setActivateMethod(ActivateMethod activate) {
        this.activateMethod = activate;
        return this;
    }

    @Override
    public Layer setWInit(Winit winit) {
        this.winit = winit;
        return this;
    }

    @Override
    public Winit getWinit() { return winit; }
}

public class MyLastLayer extends MyLayer implements LastLayer {

    private LossMethod lossMethod;

    public MyLastLayer(int number, ActivateMethod method, LossMethod lossMethod) {
        super(number, method);
        this.lossMethod = lossMethod;
    }

    @Override
    public LossMethod getLossMethod() { return lossMethod; }

    @Override
    public LastLayer setLossMethod(LossMethod lossMethod) {
        this.lossMethod = lossMethod;
        return this;
    }
}
// Loss function abstraction and implementations
public interface LossMethod {
    INDArray LossBackward(INDArray A, INDArray Y);
    double LossForward(INDArray A, INDArray Y);
}
public class CrossEntropy implements LossMethod {
    @Override
    public INDArray LossBackward(INDArray A, INDArray Y) {
        return Y.div(A).neg(); // dL/dA = -Y/A
    }

    @Override
    public double LossForward(INDArray A, INDArray Y) {
        INDArray los = Y.mul(MyMathUtil.Log(A));
        return (0 - los.sumNumber().doubleValue()) / Y.shape()[1];
    }
}
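LossBackward is the elementwise derivative of that loss with respect to the prediction (the 1/m factor is applied later, when backward() divides dW and dB by the batch size):

$$\frac{\partial L}{\partial a_k} = -\frac{y_k}{a_k}.$$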
public class MSE implements LossMethod {
    @Override
    public INDArray LossBackward(INDArray A, INDArray Y) {
        return A.sub(Y);
    }

    @Override
    public double LossForward(INDArray A, INDArray Y) {
        return (Y.sub(A).mmul(Y.sub(A).transpose()).sumNumber().doubleValue()) / Y.shape()[1];
    }
}
// Weight initialization
public interface Winit {
    INDArray Init(int seed, int out, int in);
}

public class XAVIER implements Winit {
    @Override
    public INDArray Init(int seed, int out, int in) {
        return Nd4j.randn(out, in, seed).muli(FastMath.sqrt(2.0 / (in + out)));
    }
}

public class RELUWInit implements Winit {
    @Override
    public INDArray Init(int seed, int out, int in) {
        return Nd4j.randn(out, in, seed).muli(FastMath.sqrt(2.0 / in));
    }
}

public class RandWInit implements Winit {
    @Override
    public INDArray Init(int seed, int out, int in) {
        return Nd4j.rand(out, in, seed);
    }
}
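The two non-trivial initializers scale a standard-normal draw so that the weight variance follows the usual heuristics, Xavier/Glorot for tanh-like activations and He initialization for ReLU:

$$\mathrm{Var}\left(W_{\mathrm{Xavier}}\right) = \frac{2}{n_{\mathrm{in}} + n_{\mathrm{out}}}, \qquad \mathrm{Var}\left(W_{\mathrm{He}}\right) = \frac{2}{n_{\mathrm{in}}}.$$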
// Training data classes
public interface INData {
    INDArray getX();
    INDArray getY();
    int getSize();
}

public interface TrainData extends INData {
    List<INData> getBatchList();
}

public class BatchData implements INData {

    private INDArray x;
    private INDArray y;

    public BatchData(INDArray x, INDArray y) {
        this.x = x;
        this.y = y;
    }

    @Override
    public INDArray getX() { return x; }

    @Override
    public INDArray getY() { return y; }

    @Override
    public int getSize() { return getX().columns(); }

    public void setX(INDArray x) { this.x = x; }

    public void setY(INDArray y) { this.y = y; }
}

public class MyTrainData implements TrainData {

    private INDArray x;
    private INDArray y;
    private int batch_size;

    public MyTrainData(INDArray x, INDArray y, int batch_size) {
        this.x = x.transpose();
        this.y = y.transpose();
        this.batch_size = batch_size;
    }

    public MyTrainData(INDArray x, INDArray y) {
        this.x = x.transpose();
        this.y = y.transpose();
        this.batch_size = -1; // -1 means "one batch containing everything"
    }

    @Override
    public INDArray getX() { return x; }

    @Override
    public INDArray getY() { return y; }

    @Override
    public List<INData> getBatchList() {
        List<INData> res = new ArrayList<>();
        shufflecard();
        if (batch_size != -1) {
            int lastColumnOrder = 0;
            for (int i = batch_size; i < getSize(); i = i + batch_size) {
                INDArray BatchColumn_x = x.get(NDArrayIndex.all(), NDArrayIndex.interval(lastColumnOrder, i));
                INDArray BatchColumn_y = y.get(NDArrayIndex.all(), NDArrayIndex.interval(lastColumnOrder, i));
                res.add(new BatchData(BatchColumn_x, BatchColumn_y));
                lastColumnOrder = i;
            }
            if (lastColumnOrder != getSize()) { // the remaining samples form the final (possibly smaller) batch
                INDArray BatchColumn_x = x.get(NDArrayIndex.all(), NDArrayIndex.interval(lastColumnOrder, getSize()));
                INDArray BatchColumn_y = y.get(NDArrayIndex.all(), NDArrayIndex.interval(lastColumnOrder, getSize()));
                res.add(new BatchData(BatchColumn_x, BatchColumn_y));
            }
        } else {
            res.add(new BatchData(getX(), getY()));
        }
        return res;
    }

    @Override
    public int getSize() { return getX().columns(); }

    public void setX(INDArray x) { this.x = x; }

    public void setY(INDArray y) { this.y = y; }

    // random column swaps so every epoch sees the samples in a new order
    public void shufflecard() {
        Random rd = new Random();
        INDArray temp_x;
        INDArray temp_y;
        for (int i = 0; i < getSize(); i++) {
            int j = rd.nextInt(getSize());
            temp_x = x.getColumn(i).add(0); // add(0) forces a copy instead of a view
            temp_y = y.getColumn(i).add(0);
            x.putColumn(i, x.getColumn(j).add(0));
            x.putColumn(j, temp_x);
            y.putColumn(i, y.getColumn(j));
            y.putColumn(j, temp_y);
        }
    }
}
3. Create the network and feed in the training data:
public static final ClassPathResource TRAIN_IMAGES_FILE = new ClassPathResource("data/train-images.idx3-ubyte");
public static final ClassPathResource TRAIN_LABELS_FILE = new ClassPathResource("data/train-labels.idx1-ubyte");
public static final ClassPathResource TEST_IMAGES_FILE = new ClassPathResource("data/t10k-images.idx3-ubyte");
public static final ClassPathResource TEST_LABELS_FILE = new ClassPathResource("data/t10k-labels.idx1-ubyte");
private model pointmodel = new DeepNeuralNetWork(28 * 28)
        .addLayer(new MyLayer(1000, new Tanh()))
        .addLayer(new MyLayer(500, new Tanh()))
        .addLayer(new MyLayer(100, new Tanh()))
        .addLastLayer(new SotfMaxCrossEntropyLastLayer(10))
        .setIteration(10)
        .setLearningrate(0.06);
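SotfMaxCrossEntropyLastLayer is not shown in the excerpts above; a plausible sketch (an assumption, not necessarily the repository's verbatim code) combines the SoftMax activation with the CrossEntropy loss and overrides LastBackward with the standard shortcut: the gradient of cross-entropy composed with softmax, taken with respect to the logits, is simply A - Y.

// Hypothetical sketch of the output layer; the real class in the repository may differ.
public class SotfMaxCrossEntropyLastLayer extends MyLastLayer {
    public SotfMaxCrossEntropyLastLayer(int number) {
        super(number, new SoftMax(), new CrossEntropy());
    }

    @Override
    public INDArray LastBackward(INDArray A, INDArray Y) {
        // d(cross-entropy of softmax)/dZ = A - Y, which sidesteps
        // building the full softmax Jacobian in sotfmax_back
        return A.sub(Y);
    }
}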
double[][] images = MnistReadUtil.getImages(TRAIN_IMAGES_FILE.getInputStream());
double[] labels = MnistReadUtil.getLabels(TRAIN_LABELS_FILE.getInputStream());
INDArray X = Nd4j.create(images);              // 60000,784
INDArray Y = Nd4j.create(labels).transpose();  // 60000,1
INDArray X_I = MyMathUtil.Normalization(X);
INDArray Y_I = MyMathUtil.ONEHOT(Y);           // 60000,10
TrainData data = new MyTrainData(X_I, Y_I, 128);
pointmodel.train(data);
4. Testing the model
double[][] t_images = MnistReadUtil.getImages(TEST_IMAGES_FILE.getInputStream());
double[] t_labels = MnistReadUtil.getLabels(TEST_LABELS_FILE.getInputStream());
INDArray X_t = MyMathUtil.Normalization(Nd4j.create(t_images));
INDArray Y_t = MyMathUtil.ONEHOT(Nd4j.create(t_labels).transpose());
TrainData data_t = new MyTrainData(X_t, Y_t);
INDArray X_P = pointmodel.predict(data_t.getX());
System.out.println("Accuracy: " + scord(X_P, data_t.getY()) + "%");
// Find the class with the highest probability and compare it with the label.
private float scord(INDArray value, INDArray Y) {
    int res = 0;
    int sum = 0;
    double[][] s = value.transpose().toDoubleMatrix();
    double[][] Ys = Y.transpose().toDoubleMatrix();
    for (int i = 0; i < s.length; i++) {
        double Max = -1;
        int order = -1;
        for (int j = 0; j < s[i].length; j++) { // argmax over the 10 output probabilities
            order = Max > s[i][j] ? order : j;
            Max = Max > s[i][j] ? Max : s[i][j];
        }
        // order >= 0, not order > 0: otherwise a predicted digit of 0 would never count as correct
        if (order >= 0 && (int) Ys[i][order] == 1) {
            res++;
        }
        sum++;
    }
    if (sum > 0) {
        return ((float) res / sum) * 100;
    } else {
        return 0;
    }
}
5. Feeding in a picture for prediction
public String predict(@RequestParam(value = "file") MultipartFile file, ModelMap map) {
    if (file.isEmpty()) {
        System.out.println("file is empty");
    }
    try {
        File my_file = File.createTempFile("tmp", null);
        file.transferTo(my_file);
        double[] m = MnistReadUtil.getSizeBlackWhiteImg(my_file, 28, 28);
        INDArray X_t = MyMathUtil.Normalization(Nd4j.create(m));
        INDArray X_P = pointmodel.predict(X_t.transpose()); // 784,1: one sample per column
        int number = getnumber(X_P);
        map.addAttribute("number", number);
        return "freemarker/mnist/predict";
    } catch (Exception e) {
        e.printStackTrace();
    }
    return "freemarker/fail";
}
private int getnumber(INDArray X) {
    double[] s = X.toDoubleVector();
    int res = 0;
    double Max = s[0];
    for (int i = 0; i < s.length; i++) {
        if (Max < s[i]) {
            Max = s[i];
            res = i;
        }
    }
    return res;
}
Full source code: git@github.com:woshiyigebing/my_dl4j.git