这篇主要是看下如何用神经网络做回归并画图,贴代码
public class CSVPlotter { public static void main( String[] args ) throws IOException, InterruptedException { String filename = new ClassPathResource("/DataExamples/CSVPlotData.csv").getFile().getPath();//获取路径 DataSet ds = readCSVDataset(filename);//csv数据的第一列是输入数据,第二列是输出数据 ArrayList<DataSet> DataSetList = new ArrayList<>();//构造一个arraylist,元素是dataset DataSetList.add(ds);//其实只有一个元素 plotDataset(DataSetList); //Plot the data, make sure we have the right data.//调用plotDataset函数画图 MultiLayerNetwork net =fitStraightline(ds);//调用fitStraightline函数,用数据构建一个网络 // Get the min and max x values, using Nd4j NormalizerMinMaxScaler preProcessor = new NormalizerMinMaxScaler();//用最大值最小值规范化 preProcessor.fit(ds); int nSamples = 50;//样本数50 INDArray x = Nd4j.linspace(preProcessor.getMin().getInt(0),preProcessor.getMax().getInt(0),nSamples).reshape(nSamples, 1);//linspace会生成向量,第一个参数是最小值,第二个参数是最大值,第三个参数是步长,reshape是重新搞成一个几成几的矩阵,这里还是50行1列,x就是最终的x轴数据 INDArray y = net.output(