由于需要做的是时间序列预测,所以先选定了LSTM来做。由于不允许调包(调用现成的库),而我自己对LSTM理解得又不够深入,所以手写了一个LSTM。但它的效果很不好,最后没有用上这个模型;不过毕竟自己手写了这么多行代码,还是发上来和大家分享吧。代码中主要的公式都参考了 https://zybuluo.com/hanbingtao/note/581764 。不过该文的个别公式似乎有纰漏,例如公式(57)-(62)中的向量转置似乎有误。但总的来说,这篇文章写得非常好,介绍得非常详细,在此郑重感谢作者~话不多说,直接上代码吧。
由于最后使用的并不是这段代码,这些代码相当于白写了,但这个过程还是很锻炼人的:我初步了解了神经网络,编程能力也大大提高了。这些代码几乎都是我一个人在两天之内写出来的。我本来就不是科班出身,学的是光学,平时代码量也不大,能做到这样,感觉也算不错了吧,哈哈。
1、类RunLSTM负责整体的前向传播和反向误差传播过程,也就是整个训练过程。
package LSTMUnit;
import java.util.ArrayList;
import java.util.HashMap;
import matrixUnits.MatrixUnits;
import dataUtil.*;
import LSTMUnit.SupportFunction;
public class RunLSTM {//最终在这里执行
/**
 * Training entry point.
 *
 * <p>Reads the training-data file, builds the input matrix (transposed so that rows are input
 * features and columns are time steps, which is how forwardProcess slices it), creates the LSTM
 * cell and the scratch buffers, runs {@code train}, and prints the elapsed wall-clock time.
 *
 * @param args optional; if present, {@code args[0]} is used as the training-data path instead of
 *     the built-in default path
 */
public static void main(String[] args) {
    long startTime=System.currentTimeMillis();
    // The built-in path only exists on the original author's machine, so allow overriding it.
    String defaultFilePath="C:\\Users\\lulu\\FangCloudV2\\个人文件\\华为比赛相关\\20180404lstm\\TrainData.txt";
    String filePath=(args!=null&&args.length>0)?args[0]:defaultFilePath;
    String[] ecsContent=FileUtil.read(filePath, null);
    double[][] allData=DataUtilLstmNew.loadDataFromStringArraynoCPUMEMNoWeek(ecsContent);
    allData=MatrixUnits.matrixT(allData);
    // NOTE(review): presumably appends day-of-week feature rows; confirm against InputDataProcess.
    allData=InputDataProcess.addWeekDay(allData, 3);
    int dx=MatrixUnits.getRow(allData); // length of one input vector x_t
    int days=7;                         // number of time steps fed in per window (indices start at 0)
    int dhc=15;                         // length of the hidden/cell state vectors
    double learningRate=0.01;
    // Scratch tables maintained during training (eight in total).
    // Per-step error terms: deltat, deltaot, deltaft, deltait, deltagt.
    ArrayList<HashMap<String, double[][]>> deltaList=new ArrayList<HashMap<String, double[][]>>();
    // Per-step weight/bias gradients: Wfhgradt, Wihgradt, Wghgradt, Wohgradt, Wfxgradt, Wixgradt,
    // Wgxgradt, Woxgradt, bfgradt, bigradt, bggradt, bogradt.
    ArrayList<HashMap<String, double[][]>> gradList=new ArrayList<HashMap<String, double[][]>>();
    // Hidden and cell states filled by the forward pass: days+1 entries each.
    ArrayList<double[][]> ht_1andLasthtList=new ArrayList<double[][]>();
    ArrayList<double[][]> ct_landLastctList=new ArrayList<double[][]>();
    // Per-step gate activations: ft, it, gt, ot.
    ArrayList<HashMap<String, double[][]>> gateList=new ArrayList<HashMap<String, double[][]>>();
    // Accumulated gradients: Wfhgrad, Wihgrad, Wghgrad, Wohgrad, Wfxgrad, Wixgrad, Wgxgrad, Woxgrad,
    // bfgrad, bigrad, bggrad, bograd.
    HashMap<String, double[][]> finalGrad=new HashMap<String, double[][]>();
    // Per-step weighted gate inputs: netft, netit, netgt, netot.
    ArrayList<HashMap<String, double[][]>> netList=new ArrayList<HashMap<String, double[][]>>();
    // Per-step input vectors.
    ArrayList<double[][]> xtList=new ArrayList<double[][]>();
    SingleCell cell=new SingleCell(dx, dhc);
    initTempList(xtList,netList, gateList, ht_1andLasthtList, ct_landLastctList, deltaList, gradList, finalGrad, days);
    int predictDays=7;
    double[][] predictResult=new double[dhc][0];
    train(xtList, netList, gateList, ht_1andLasthtList, ct_landLastctList, deltaList, gradList,finalGrad,
            allData, days, cell, learningRate,dhc,dx,predictDays,predictResult);
    long endTime=System.currentTimeMillis();
    double excTime=(double)(endTime-startTime)/1000;
    System.out.println("Running "+excTime+"s");
}
/**
 * Initializes the scratch tables used by the forward/backward passes.
 *
 * <p>Appends to each list; nothing already present is removed. Every map that is created or
 * filled has all of its keys present, mapped to {@code null}.
 *
 * <ul>
 *   <li>{@code finalGrad} receives the 12 accumulated-gradient keys.</li>
 *   <li>{@code ht_1andLasthtList} and {@code ct_landLastctList} each get {@code days + 1} null
 *       entries. Slot 0 is the state carried into the window.</li>
 *   <li>{@code xtList} gets {@code days} null entries.</li>
 *   <li>{@code netList}, {@code gateList}, {@code deltaList} and {@code gradList} each get
 *       {@code days} fresh maps pre-populated with their keys.</li>
 * </ul>
 *
 * @param days number of time steps per window
 */
public static void initTempList(
        ArrayList<double[][]> xtList,
        ArrayList<HashMap<String, double[][]>> netList,
        ArrayList<HashMap<String, double[][]>> gateList,
        ArrayList<double[][]> ht_1andLasthtList,
        ArrayList<double[][]> ct_landLastctList,
        ArrayList<HashMap<String, double[][]>> deltaList,
        ArrayList<HashMap<String, double[][]>> gradList,
        HashMap<String, double[][]> finalGrad,int days) {
    String[] finalGradKeys={"Wfhgrad","Wihgrad","Wghgrad","Wohgrad","Wfxgrad","Wixgrad",
            "Wgxgrad","Woxgrad","bfgrad","bigrad","bggrad","bograd"};
    String[] netKeys={"netft","netit","netgt","netot"};
    String[] gateKeys={"ft","it","gt","ot"};
    String[] deltaKeys={"deltat","deltaot","deltaft","deltait","deltagt"};
    String[] gradKeys={"Wfhgradt","Wihgradt","Wghgradt","Wohgradt","Wfxgradt","Wixgradt",
            "Wgxgradt","Woxgradt","bfgradt","bigradt","bggradt","bogradt"};
    for(String key:finalGradKeys) {
        finalGrad.put(key, null);
    }
    // Extra leading slot for the state that enters the window.
    ht_1andLasthtList.add(null);
    ct_landLastctList.add(null);
    for(int i=0;i<days;i++) {
        xtList.add(null);
        // Build each map first, then append it, so the keys always land in the new map even when
        // a list was not empty on entry (indexing with get(i) would hit an older map).
        netList.add(nullValuedMap(netKeys));
        gateList.add(nullValuedMap(gateKeys));
        ht_1andLasthtList.add(null);
        ct_landLastctList.add(null);
        deltaList.add(nullValuedMap(deltaKeys));
        gradList.add(nullValuedMap(gradKeys));
    }
}

/** Returns a new map containing every key in {@code keys}, each mapped to {@code null}. */
private static HashMap<String, double[][]> nullValuedMap(String[] keys) {
    HashMap<String, double[][]> map=new HashMap<String, double[][]>();
    for(String key:keys) {
        map.put(key, null);
    }
    return map;
}
/**
 * Empties the scratch tables used between training windows.
 *
 * <p>Any argument may be {@code null}; null arguments are skipped. A marker line is printed to
 * stdout before and after. NOTE: the printed text says "six", but seven containers are cleared.
 * The text is left unchanged to keep the output identical.
 */
public static void clearTempList(
        ArrayList<HashMap<String, double[][]>> netList,
        ArrayList<HashMap<String, double[][]>> gateList,
        ArrayList<double[][]> ht_1andLasthtList,
        ArrayList<double[][]> ct_landLastctList,
        ArrayList<HashMap<String, double[][]>> deltaList,
        ArrayList<HashMap<String, double[][]>> gradList,
        HashMap<String, double[][]> finalGrad) {
    System.out.println("clean six TempLists start:");
    // Clear in the same order as the parameters are declared.
    ArrayList<?>[] buffers={netList, gateList, ht_1andLasthtList, ct_landLastctList, deltaList, gradList};
    for(ArrayList<?> buffer:buffers) {
        if(buffer==null) {
            continue;
        }
        buffer.clear();
    }
    if(finalGrad!=null) {
        finalGrad.clear();
    }
    System.out.println("clean six TempLists end:");
}
/**
 * Runs the LSTM forward pass over one time window and fills the per-step caches.
 *
 * <p>Input: each column {@code i} of {@code trainCurrentData} is one time step. Its first
 * {@code dx} rows become the input vector x_{i+1} in {@code xtList}, which is cleared and rebuilt.
 *
 * <p>Initial state: slot 0 of the state lists is the state fed into the window.
 * <ul>
 *   <li>If either slot 0 is null, c_0 is drawn from
 *       {@code MatrixUnits.getARandomMatrix(dhc, 1, 0, 1)} and h_0 gets an independent copy of
 *       the same values.</li>
 *   <li>Otherwise slot 0 is overwritten with a copy of the previous call's slot-1 state.</li>
 * </ul>
 *
 * <p>For t = 1..window the step computes:
 * <pre>
 *   f_t = sigmoid(Wfh h_{t-1} + Wfx x_t + bf)
 *   i_t = sigmoid(Wih h_{t-1} + Wix x_t + bi)
 *   g_t = tanh(Wgh h_{t-1} + Wgx x_t + bg)
 *   c_t = f_t * c_{t-1} + i_t * g_t          (element-wise)
 *   o_t = sigmoid(Woh h_{t-1} + Wox x_t + bo)
 *   h_t = o_t * tanh(c_t)                    (element-wise)
 * </pre>
 *
 * <p>c_t and h_t are stored in slot t of the state lists. The gates and their net inputs are
 * stored in slot t-1 of {@code gateList} and {@code netList}. So the state lists need at least
 * {@code window + 1} entries and gateList/netList at least {@code window} (see initTempList).
 */
public static void forwardProcess(
        ArrayList<double[][]> xtList,
        ArrayList<HashMap<String, double[][]>> netList,
        ArrayList<HashMap<String, double[][]>> gateList,
        ArrayList<double[][]> ct_landLastctList,
        ArrayList<double[][]> ht_1andLasthtList,
        double[][] trainCurrentData,SingleCell cell,int dhc,int dx) {
    int window=MatrixUnits.getCol(trainCurrentData);
    // 1. Rebuild the input table, one dx x 1 column vector per time step.
    xtList.clear();
    for(int i=0;i<window;i++) {
        double[][] column=MatrixUnits.getPartOfAMatrix(trainCurrentData, 0, dx-1, i, i);
        xtList.add(deepCopyMatrix(column));
    }
    // 2. Set the state that enters the window (slot 0).
    if(ct_landLastctList.get(0)==null||ht_1andLasthtList.get(0)==null) {
        double[][] initial=MatrixUnits.getARandomMatrix(dhc, 1, 0, 1);
        ct_landLastctList.set(0, initial);
        // double[][].clone() is shallow: c_0 and h_0 would share row arrays. Copy the rows too.
        ht_1andLasthtList.set(0, deepCopyMatrix(initial));
    }
    else {
        // Carry over the state from after the first step of the previous window.
        double[][] carriedCt=deepCopyMatrix(ct_landLastctList.get(1));
        double[][] carriedHt=deepCopyMatrix(ht_1andLasthtList.get(1));
        ct_landLastctList.set(0, carriedCt);
        ht_1andLasthtList.set(0, carriedHt);
    }
    // 3. Step forward through the window, filling the state, net-input and gate tables.
    for(int t=1;t<=window;t++) {
        double[][] xt=xtList.get(t-1);
        double[][] ct_1=ct_landLastctList.get(t-1);
        double[][] ht_1=ht_1andLasthtList.get(t-1);
        // Forget gate
        double[][] netft=MatrixUnits.matrixAdd(MatrixUnits.matrixAdd(MatrixUnits.matrixNormalMul(cell.Wfh, ht_1),
                MatrixUnits.matrixNormalMul(cell.Wfx, xt)),cell.bf);
        double[][] ft=SupportFunction.sigmoid(netft);
        // Input gate
        double[][] netit=MatrixUnits.matrixAdd(MatrixUnits.matrixAdd(MatrixUnits.matrixNormalMul(cell.Wih, ht_1),
                MatrixUnits.matrixNormalMul(cell.Wix, xt)),cell.bi);
        double[][] it=SupportFunction.sigmoid(netit);
        // Candidate cell state from the current input
        double[][] netgt=MatrixUnits.matrixAdd(MatrixUnits.matrixAdd(MatrixUnits.matrixNormalMul(cell.Wgh, ht_1),
                MatrixUnits.matrixNormalMul(cell.Wgx, xt)),cell.bg);
        double[][] gt=SupportFunction.tanh(netgt);
        // New cell state
        double[][] ct=MatrixUnits.matrixAdd(MatrixUnits.matrixHadamardMul(ft, ct_1), MatrixUnits.matrixHadamardMul(it, gt));
        // Output gate
        double[][] netot=MatrixUnits.matrixAdd(MatrixUnits.matrixAdd(MatrixUnits.matrixNormalMul(cell.Woh, ht_1),
                MatrixUnits.matrixNormalMul(cell.Wox, xt)),cell.bo);
        double[][] ot=SupportFunction.sigmoid(netot);
        // Hidden state (cell output)
        double[][] ht=MatrixUnits.matrixHadamardMul(ot, SupportFunction.tanh(ct));
        ct_landLastctList.set(t, ct);
        ht_1andLasthtList.set(t, ht);
        HashMap<String, double[][]> gates=gateList.get(t-1);
        gates.put("ft", ft);
        gates.put("it", it);
        gates.put("gt", gt);
        gates.put("ot", ot);
        HashMap<String, double[][]> nets=netList.get(t-1);
        nets.put("netft", netft);
        nets.put("netit", netit);
        nets.put("netgt", netgt);
        nets.put("netot", netot);
    }
}

/** Returns a copy of {@code m} whose row arrays are independent of the original's rows. */
private static double[][] deepCopyMatrix(double[][] m) {
    double[][] copy=new double[m.length][];
    for(int r=0;r<m.length;r++) {
        copy[r]=m[r].clone();
    }
    return copy;
}
public static void backwardProcess(
ArrayList<double[][]> xtList,
ArrayList<HashMap<String, double[][]>> deltaList,
ArrayList<double[][]> ct_landLastctList,
ArrayList<double[][]> ht_1andLasthtList,
ArrayList<HashMap<String, double[][]>> gateList,
ArrayList<HashMap<String, double[][]>> gradList,
HashMap<String, double[][]> finalGrad,
double[][] trainCurrentData,
SingleCell cell,int dhc,int dx,double[][] nextTime) {
//System.out.println("backwardProcess start:");
int window=MatrixUnits.getCol(trainCurrentData);
//1、首先更新deltaList
double[][] ct=ct_landLastctList.get(window);
double[][] ot=gateList.get(window-1).get("ot");
double[][] ft=gateList.get(window-1).get("ft");
double[][] it=gateList.get(window-1).get("it");
double[][] gt=gateList.get(window-1).get("gt");
double[][] ct_1=ct_landLastctList.get(window-1);
double[][] matrix1=MatrixUnits.getAZeromatrix(dhc, 1);//全1矩阵,这个一直到后面都能用的
double[][] tanhct21_=MatrixUnits.matrixSub(matrix1,//1-tanh^2ct
MatrixUnits.matrixHadamardMul(SupportFunction.tanh(ct), SupportFunction.tanh(ct)));
double[][] deltat=new double[dhc][1];
// //**************这里取的是误差函数1/2(t^2-y^2)的导数的相反数,而且是各个输出都要计算然后相加
// for(int i=0;i<window;i++){
// double[][] yData=ht_1andLasthtList.get(i+1);
// double[][] tempdelta=new double[dhc][1];
// double[][] target=new double[dhc][1];
// if(i<window-1)
// target=xtList.get(i+1);
// else
// target=nextTime.clone();
// for(int j=0;j<dhc;j++)
// tempdelta[j][0]=yData[i][0]*(1-yData[i][0])*(target[i][0]-yData[i][0]);
// deltat=MatrixUnits.matrixAdd(deltat, tempdelta);
// }
// //*************************
// **********这里是只算最后一个输出的delta
double[][] yData=ht_1andLasthtList.get(window);
double[][] target=nextTime.clone();
for(int j=0;j<dhc;j++)
deltat[j][0]=(target[j][0]-yData[j][0])*yData[j][0]*(1-yData[j][0]);
// *************************
deltaList.get(window-1).put("deltat", deltat);
double[][] deltaot=MatrixUnits.matrixHadamardMul(MatrixUnits.matrixHadamardMul(MatrixUnits.
matrixHadamardMul(deltat, SupportFunction.tanh(ct)),ot),MatrixUnits.matrixSub(matrix1,ot));
double[][] deltaft=MatrixUnits.matrixHadamardMul(deltat,MatrixUnits.matrixHadamardMul(ot,
MatrixUnits.matrixHadamardMul(tanhct21_, MatrixUnits.matrixHadamardMul(ct_1, MatrixUnits.
matrixHadamardMul(ft, MatrixUnits.matrixSub(matrix1, ft))))));
double[][] deltait=MatrixUnits.matrixHadamardMul(deltat,MatrixUnits.matrixHadamardMul(ot,
MatrixUnits.matrixHadamardMul(tanhct21_, MatrixUnits.matrixHadamardMul(gt, MatrixUnits.
matrixHadamardMul(it, MatrixUnits.matrixSub(matrix1, it))))));
double[][] deltagt=MatrixUnits.matrixHadamardMul(deltat,MatrixUnits.matrixHadamardMul(ot,
MatrixUnits.matrixHadamardMul(tanhct21_, MatrixUnits.matrixHadamardMul(it,
MatrixUnits.matrixSub(matrix1, MatrixUnits.matrixHadamardMul(gt, gt))))));
deltaList.get(window-1).put("deltaot", deltaot);
deltaList.get(window-1).put("deltaft", deltaft);
deltaList.get(window-1).put("deltait", deltait);
deltaList.get(window-1).put("deltagt", deltagt);
//截止到现在为止,更新完了最新一天的deltat,deltaot,deltaft,deltait,deltagt,然后后面的天就一直依靠他们来不断进行更新
for(int i=window-2;i>=0;i--){
double[][] tempct=ct_landLastctList.get(i+1);
double[][] tempct_1=ct_landLastctList.get(i);
double[][] tempot=gateList.get(i).get("ot");
double[][] tempft=gateList.get(i).get("ft");
double[][] tempit=gateList.get(i).get("it");
double[][]