这是一个非常有启发的例子,可以扩展到生产环境做一些模型!
public class PredictGenderTrain { public String filePath; public static void main(String args[]) { PredictGenderTrain dg = new PredictGenderTrain();//生成类实例,这种写法忘了叫什么了,故弄玄虚的感觉,谁知道告我一下,我喜欢主类里只有main函数的写法 dg.filePath = System.getProperty("user.dir") + "\\src\\main\\resources\\PredictGender\\Data";//找到数据路径 dg.train();//调用train函数 } /** * This function uses GenderRecordReader and passes it to RecordReaderDataSetIterator for further training. */ public void train() { int seed = 123456; double learningRate = 0.01; int batchSize = 100; int nEpochs = 100; int numInputs = 0; int numOutputs = 0; int numHiddenNodes = 0; try(GenderRecordReader rr = new GenderRecordReader(new ArrayList<String>() { {add("M");add("F");}}))//这个try里面有小括号我也是头一次注意,括号里一般都是输入输出流,训练数据读取器作为临时变量,过后就会被自动回收,这里调用性别读取器类,后面会有这个类的详细解释 { long st = System.currentTimeMillis();//打印当前时间 System.out.println("Preprocessing start time : " + st); rr.initialize(new FileSplit(new File(this.filePath)));//初始化读取器 long et = System.currentTimeMillis();//打印当前时间,处理时间 System.out.println("Preprocessing end time : " + et); System.out.println("time taken to process data : " + (et-st) + " ms"); numInputs = rr.maxLengthName * 5; // multiplied by 5 as for each letter we use five binary digits like 00000//每个字符用5个二进制表示,输入大小就是最长名字的5倍 numOutputs = 2;//输出大小为2 numHiddenNodes = 2 * numInputs + numOutputs;//隐含层大小 GenderRecordReader rr1 = new GenderRecordReader(new ArrayList<String>() { {add("M");add("F");}});//又搞了一个读取器 DataSetIterator trainIter = new RecordReaderDataSetIterator(rr, batchSize, numInputs, 2);//训练迭代器 System.out.println(trainIter); //System.exit(0); DataSetIterator testIter = new RecordReaderDataSetIterator(rr1, batchSize, numInputs, 2);//测试迭代器 MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()//网络还是一样,假装自己是老司机 .seed(seed) .biasInit(1) .regularization(true).l2(1e-4) .iterations(1) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .learningRate(learningRate) .updater(Updater.NESTEROVS).momentum(0.9)//采用梯度修正的参数修正方法 .list() .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes) .weightInit(WeightInit.XAVIER) .activation("relu") .build()) .layer(1, new DenseLayer.Builder().nIn(numHiddenNodes).nOut(numHiddenNodes) .weightInit(WeightInit.XAVIER) .activation("relu") .build()) .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MSE) .weightInit(WeightInit.XAVIER) .activation("softmax") .nIn(numHiddenNodes).nOut(numOutputs).build()) .pretrain(false).backprop(true).build(); MultiLayerNetwork model = new MultiLayerNetwork(conf);