2018华为软件精英挑战赛系列3——自己手撕的LSTM

本文分享了作者在华为软件精英挑战赛中尝试使用自编LSTM模型进行时间序列预测的经历。虽然最终未采用此模型,但过程中深入理解了神经网络,并提升了代码能力。文中提到参考了一篇详细教程,同时也指出其部分公式可能存在错误。

    由于需要做的是时间预测,所以先选定了LSTM来做。由于不给调包,自己理解不够深入,所以自己手写了一个LSTM,但是这个效果很不好,最后没有用上这个模型,但是也算是自己手撕了这么多行代码,发上来和大家分享吧。这个代码主要的公式都是参考的https://zybuluo.com/hanbingtao/note/581764。但是这个作者的有些公式似乎有点纰漏,如公式(57)-(62)的向量转置似乎有误。但是总的来说,还是写得非常好的文章,介绍地非常详细,也郑重感谢作者~话不多说,直接上代码吧。

    由于最后使用的并不是这段代码,也就是说相当于白写了,但是这个过程还是很能得到锻炼的,算是初步了解了神经网络,而且极大提高了代码能力···这些代码几乎都是一个人在两天之内写出来的。本来就不是科班出身,学的是光学,平时代码量也不大,能做到这样,也算不错了吧感觉,哈哈。

1、类RunLSTM负责的是总的前向传播和后向误差传导过程,就是整个训练的过程

package LSTMUnit;
import java.util.ArrayList;
import java.util.HashMap;
import matrixUnits.MatrixUnits;
import dataUtil.*;
import LSTMUnit.SupportFunction;


public class RunLSTM {//最终在这里执行
	public static void main(String[] args) {
		long startTime=System.currentTimeMillis();
		
		String filePath="C:\\Users\\lulu\\FangCloudV2\\个人文件\\华为比赛相关\\20180404lstm\\TrainData.txt";
		String[] ecsContent=FileUtil.read(filePath, null);
		double[][] allData=DataUtilLstmNew.loadDataFromStringArraynoCPUMEMNoWeek(ecsContent);
		
		allData=MatrixUnits.matrixT(allData);
		allData=InputDataProcess.addWeekDay(allData, 3);
//		MatrixUnits.printMatrix(allData);
//		System.out.println("**********************************");
//		allData=InputDataProcess.log(allData,1, Math.E);
//		MatrixUnits.printMatrix(allData);
//		System.exit(0);
		
		int dx=MatrixUnits.getRow(allData);
		int days=7;//一次输入多少天数据。下标都从0开始
		int dhc=15;//输出的向量的长度
		
		double learningRate=0.01;
		//下面都是计算过程中需要用到的变量,一共要维护8个表
		ArrayList<HashMap<String, double[][]>> deltaList=new ArrayList<HashMap<String, double[][]>>();//存放各天的误差项
		//days个数据每天有deltat,deltaot,deltaft,deltait,deltagt,共5类
		ArrayList<HashMap<String, double[][]>> gradList=new ArrayList<HashMap<String, double[][]>>();//存放各天的权重梯度和偏置项梯度
		//days个数据每天有Wfhgradt,Wihgradt,Wghgradt,Wohgradt,Wfxgradt,Wixgradt,Wgxgradt,Woxgradt,
		//bfgradt,bigradt,bggradt,bogradt,共12类
		ArrayList<double[][]> ht_1andLasthtList=new ArrayList<double[][]>();//这个在前向过程中存数据,一共要存days+1个数据
		ArrayList<double[][]> ct_landLastctList=new ArrayList<double[][]>();//这个在前向过程中存数据,一共要存days+1个数据
		ArrayList<HashMap<String, double[][]>> gateList=new ArrayList<HashMap<String, double[][]>>();//存各种门,days个数据
		//days个数据每天有ft,it,gt,ot共四类
		HashMap<String, double[][]> finalGrad=new HashMap<String, double[][]>();
		//最终的权重梯度,Wfhgrad,Wihgrad,Wghgrad,Wohgrad,Wfxgrad,Wixgrad,Wgxgrad,Woxgrad,
		//bfgrad,bigrad,bggrad,bograd,共12类
		ArrayList<HashMap<String, double[][]>> netList=new ArrayList<HashMap<String, double[][]>>();
		//days个数据,存储每个门的加权输入,分别有netft,netit,netgt,netot
		ArrayList<double[][]> xtList=new ArrayList<double[][]>();//输入表
		
		SingleCell cell=new SingleCell(dx, dhc);
//		System.out.println("firstwoh");
//		MatrixUnits.printMatrix(cell.Woh);
		initTempList(xtList,netList, gateList, ht_1andLasthtList, ct_landLastctList, deltaList, gradList, finalGrad, days);
		int predictDays=7;
		double[][] predictResult=new double[dhc][0];
		train(xtList, netList, gateList, ht_1andLasthtList, ct_landLastctList, deltaList, gradList,finalGrad, 
		allData, days, cell, learningRate,dhc,dx,predictDays,predictResult);
		
//		System.out.println("lasttwoh");
//		MatrixUnits.printMatrix(cell.Woh);
		
		long endTime=System.currentTimeMillis();//记录结束时间  
		double excTime=(double)(endTime-startTime)/1000;  
		System.out.println("Running  "+excTime+"s");
		
		
	}
	
	public static void initTempList(
	ArrayList<double[][]> xtList,
	ArrayList<HashMap<String, double[][]>> netList,
	ArrayList<HashMap<String, double[][]>> gateList,
	ArrayList<double[][]> ht_1andLasthtList,
	ArrayList<double[][]> ct_landLastctList,
	ArrayList<HashMap<String, double[][]>> deltaList,
	ArrayList<HashMap<String, double[][]>> gradList,
	HashMap<String, double[][]> finalGrad,int days) {//对7个存储中间变量的列表进行初始化
		//System.out.println("init six TempLists start:");
		
		finalGrad.put("Wfhgrad", null);
		finalGrad.put("Wihgrad", null);
		finalGrad.put("Wghgrad", null);
		finalGrad.put("Wohgrad", null);
		finalGrad.put("Wfxgrad", null);
		finalGrad.put("Wixgrad", null);
		finalGrad.put("Wgxgrad", null);
		finalGrad.put("Woxgrad", null);
		finalGrad.put("bfgrad", null);
		finalGrad.put("bigrad", null);
		finalGrad.put("bggrad", null);
		finalGrad.put("bograd", null);
		
		ht_1andLasthtList.add(null);
		ct_landLastctList.add(null);
		for(int i=0;i<days;i++) {
			xtList.add(null);
			
			netList.add(new HashMap<String, double[][]>());
			netList.get(netList.size()-1).put("netft", null);
			netList.get(netList.size()-1).put("netit", null);
			netList.get(netList.size()-1).put("netgt", null);
			netList.get(netList.size()-1).put("netot", null);
			
			gateList.add(new HashMap<String, double[][]>());
			gateList.get(gateList.size()-1).put("ft", null);
			gateList.get(gateList.size()-1).put("it", null);
			gateList.get(gateList.size()-1).put("gt", null);
			gateList.get(gateList.size()-1).put("ot", null);
			
			
			ht_1andLasthtList.add(null);
			ct_landLastctList.add(null);
			
			deltaList.add(new HashMap<String, double[][]>());
			gradList.add(new HashMap<String, double[][]>());
			
			deltaList.get(i).put("deltat", null);
			deltaList.get(i).put("deltaot", null);
			deltaList.get(i).put("deltaft", null);
			deltaList.get(i).put("deltait", null);
			deltaList.get(i).put("deltagt", null);
			
			gradList.get(i).put("Wfhgradt", null);
			gradList.get(i).put("Wihgradt", null);
			gradList.get(i).put("Wghgradt", null);
			gradList.get(i).put("Wohgradt", null);
			gradList.get(i).put("Wfxgradt", null);
			gradList.get(i).put("Wixgradt", null);
			gradList.get(i).put("Wgxgradt", null);
			gradList.get(i).put("Woxgradt", null);
			gradList.get(i).put("bfgradt", null);
			gradList.get(i).put("bigradt", null);
			gradList.get(i).put("bggradt", null);
			gradList.get(i).put("bogradt", null);
		}
		//System.out.println("init six TempLists end:");
	}
	public static void clearTempList(
	ArrayList<HashMap<String, double[][]>> netList,
	ArrayList<HashMap<String, double[][]>> gateList,
	ArrayList<double[][]> ht_1andLasthtList,
	ArrayList<double[][]> ct_landLastctList,
	ArrayList<HashMap<String, double[][]>> deltaList,
	ArrayList<HashMap<String, double[][]>> gradList,
	HashMap<String, double[][]> finalGrad) {//对7个存储中间变量的列表进行清除
		System.out.println("clean six TempLists start:");
		if(netList!=null)
			netList.clear();
		if(gateList!=null)
			gateList.clear();
		if(ht_1andLasthtList!=null)
			ht_1andLasthtList.clear();
		if(ct_landLastctList!=null)
			ct_landLastctList.clear();
		if(deltaList!=null)
			deltaList.clear();
		if(gradList!=null)
			gradList.clear();
		if(finalGrad!=null)
			finalGrad.clear();
		System.out.println("clean six TempLists end:");
	}
	public static void forwardProcess(
	ArrayList<double[][]> xtList,
	ArrayList<HashMap<String, double[][]>> netList,
	ArrayList<HashMap<String, double[][]>> gateList,
	ArrayList<double[][]> ct_landLastctList,
	ArrayList<double[][]> ht_1andLasthtList,
	double[][] trainCurrentData,SingleCell cell,int dhc,int dx) {//其实所谓的前向和后向过程,就是利用输入的一个时间窗口,
	//对一次时间窗口 移动更新各种临时变量,最后更新权值矩阵和偏置量
	//对于forwardProcess,就是通过各时刻输入表,更新各时刻ct表、net表、ht表、gate表
		//System.out.println("forwardProcess start:");
		int window=MatrixUnits.getCol(trainCurrentData);
		//1、更新输入表
		xtList.clear();//xtList是可以clear掉的,没啥问题
		for(int i=0;i<window;i++) {
			double[][] temp=MatrixUnits.getPartOfAMatrix(trainCurrentData, 0, dx-1, i, i);
			xtList.add((double[][])temp.clone());
		}
		//2、更新ct和ht表的第一项,第一次循环的时候要随机化
		if(ct_landLastctList.get(0)==null||ht_1andLasthtList.get(0)==null) {
			double[][] temp=MatrixUnits.getARandomMatrix(dhc, 1, 0, 1);
			ct_landLastctList.set(0, temp);
			ht_1andLasthtList.set(0, (double[][])temp.clone());
		}
		else {//以后的循环直接往左推移
			double[][] tempct_1=(double[][])ct_landLastctList.get(1).clone();
			double[][] tempct_2=(double[][])ht_1andLasthtList.get(1).clone();
			//clearTempList(netList, gateList, ht_1andLasthtList, ct_landLastctList, null, null, null);
			ct_landLastctList.set(0, tempct_1);
			ht_1andLasthtList.set(0, tempct_2);
		}
		//3、通过前向计算依次更新ct表,ht表,nett表和gete表,
		for(int i=1;i<window+1;i++) {
			double[][] xt=xtList.get(i-1);
			double[][] ct_1=ct_landLastctList.get(i-1);
			double[][] ht_1=ht_1andLasthtList.get(i-1);
			double[][] netft,netit,netgt,netot,ft,it,gt,ct,ot,ht;
			
			netft=MatrixUnits.matrixAdd(MatrixUnits.matrixAdd(MatrixUnits.matrixNormalMul(cell.Wfh, ht_1),
			MatrixUnits.matrixNormalMul(cell.Wfx, xt)),cell.bf);
			ft=SupportFunction.sigmoid(netft);
			//输入门的计算
			netit=MatrixUnits.matrixAdd(MatrixUnits.matrixAdd(MatrixUnits.matrixNormalMul(cell.Wih, ht_1),
			MatrixUnits.matrixNormalMul(cell.Wix, xt)),cell.bi);
			it=SupportFunction.sigmoid(netit);
			//描述当前输入的单元状态
			netgt=MatrixUnits.matrixAdd(MatrixUnits.matrixAdd(MatrixUnits.matrixNormalMul(cell.Wgh, ht_1),
			MatrixUnits.matrixNormalMul(cell.Wgx, xt)),cell.bg);
			gt=SupportFunction.tanh(netgt);
			//输出单元状态
			ct=MatrixUnits.matrixAdd(MatrixUnits.matrixHadamardMul(ft, ct_1), MatrixUnits.matrixHadamardMul(it, gt));
			//输出门
			netot=MatrixUnits.matrixAdd(MatrixUnits.matrixAdd(MatrixUnits.matrixNormalMul(cell.Woh, ht_1),
			MatrixUnits.matrixNormalMul(cell.Wox, xt)),cell.bo);
			ot=SupportFunction.sigmoid(netot);
			//最终输出
			ht=MatrixUnits.matrixHadamardMul(ot, SupportFunction.tanh(ct));
			
			ct_landLastctList.set(i, ct);
			ht_1andLasthtList.set(i, ht);
			gateList.get(i-1).put("ft", ft);
			gateList.get(i-1).put("it", it);
			gateList.get(i-1).put("gt", gt);
			gateList.get(i-1).put("ot", ot);
			netList.get(i-1).put("netft", netft);
			netList.get(i-1).put("netit", netit);
			netList.get(i-1).put("netgt", netgt);
			netList.get(i-1).put("netot", netot);
		}
		//System.out.println("forwardProcess end:");
	}
	public static void backwardProcess(
	ArrayList<double[][]> xtList,
	ArrayList<HashMap<String, double[][]>> deltaList,
	ArrayList<double[][]> ct_landLastctList,
	ArrayList<double[][]> ht_1andLasthtList,
	ArrayList<HashMap<String, double[][]>> gateList,
	ArrayList<HashMap<String, double[][]>> gradList,
	HashMap<String, double[][]> finalGrad,
	double[][] trainCurrentData,
	SingleCell cell,int dhc,int dx,double[][] nextTime) {
		//System.out.println("backwardProcess start:");
		int window=MatrixUnits.getCol(trainCurrentData);
		//1、首先更新deltaList
		double[][] ct=ct_landLastctList.get(window);
		double[][] ot=gateList.get(window-1).get("ot");
		double[][] ft=gateList.get(window-1).get("ft");
		double[][] it=gateList.get(window-1).get("it");
		double[][] gt=gateList.get(window-1).get("gt");
		double[][] ct_1=ct_landLastctList.get(window-1);
		
		double[][] matrix1=MatrixUnits.getAZeromatrix(dhc, 1);//全1矩阵,这个一直到后面都能用的
		
		double[][] tanhct21_=MatrixUnits.matrixSub(matrix1,//1-tanh^2ct
		MatrixUnits.matrixHadamardMul(SupportFunction.tanh(ct), SupportFunction.tanh(ct)));
		
		double[][] deltat=new double[dhc][1];
//		//**************这里取的是误差函数1/2(t^2-y^2)的导数的相反数,而且是各个输出都要计算然后相加
//		for(int i=0;i<window;i++){
//			double[][] yData=ht_1andLasthtList.get(i+1);
//			double[][] tempdelta=new double[dhc][1];
//			double[][] target=new double[dhc][1];
//			if(i<window-1)
//				target=xtList.get(i+1);
//			else
//				target=nextTime.clone();
//			for(int j=0;j<dhc;j++)
//				tempdelta[j][0]=yData[i][0]*(1-yData[i][0])*(target[i][0]-yData[i][0]);
//			deltat=MatrixUnits.matrixAdd(deltat, tempdelta);
//		}
//		//*************************
//		**********这里是只算最后一个输出的delta
		double[][] yData=ht_1andLasthtList.get(window);
		double[][] target=nextTime.clone();
		for(int j=0;j<dhc;j++)
			deltat[j][0]=(target[j][0]-yData[j][0])*yData[j][0]*(1-yData[j][0]);
//		*************************		
		deltaList.get(window-1).put("deltat", deltat);
		
		
		double[][] deltaot=MatrixUnits.matrixHadamardMul(MatrixUnits.matrixHadamardMul(MatrixUnits.
		matrixHadamardMul(deltat, SupportFunction.tanh(ct)),ot),MatrixUnits.matrixSub(matrix1,ot));
				
		double[][] deltaft=MatrixUnits.matrixHadamardMul(deltat,MatrixUnits.matrixHadamardMul(ot,
		MatrixUnits.matrixHadamardMul(tanhct21_, MatrixUnits.matrixHadamardMul(ct_1, MatrixUnits.
		matrixHadamardMul(ft, MatrixUnits.matrixSub(matrix1, ft))))));
			
		double[][] deltait=MatrixUnits.matrixHadamardMul(deltat,MatrixUnits.matrixHadamardMul(ot,
		MatrixUnits.matrixHadamardMul(tanhct21_, MatrixUnits.matrixHadamardMul(gt, MatrixUnits.
		matrixHadamardMul(it, MatrixUnits.matrixSub(matrix1, it))))));
				
		double[][] deltagt=MatrixUnits.matrixHadamardMul(deltat,MatrixUnits.matrixHadamardMul(ot,
		MatrixUnits.matrixHadamardMul(tanhct21_, MatrixUnits.matrixHadamardMul(it,
		MatrixUnits.matrixSub(matrix1, MatrixUnits.matrixHadamardMul(gt, gt))))));
		
		deltaList.get(window-1).put("deltaot", deltaot);
		deltaList.get(window-1).put("deltaft", deltaft);
		deltaList.get(window-1).put("deltait", deltait);
		deltaList.get(window-1).put("deltagt", deltagt);
		//截止到现在为止,更新完了最新一天的deltat,deltaot,deltaft,deltait,deltagt,然后后面的天就一直依靠他们来不断进行更新
		for(int i=window-2;i>=0;i--){
			double[][] tempct=ct_landLastctList.get(i+1);
			double[][] tempct_1=ct_landLastctList.get(i);
			double[][] tempot=gateList.get(i).get("ot");
			double[][] tempft=gateList.get(i).get("ft");
			double[][] tempit=gateList.get(i).get("it");
			double[][]
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值