gplt-计算图

本文深入探讨了深度学习中计算图的基本概念与实现原理,详细解释了如何利用计算图进行函数值及偏导数的计算,包括加法、减法、乘法、指数、对数和正弦函数等算子的处理。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

原题

计算图”(computational graph)是现代深度学习系统的基础执行引擎,提供了一种表示任意数学表达式的方法,例如用有向无环图表示的神经网络。 图中的节点表示基本操作或输入变量,边表示节点之间的中间值的依赖性。 例如,下图就是一个函数f(x1,x2)=lnx1+x1x2−sinx2f(x_1,x_2)=lnx_1+x_1x_2-sinx_2f(x1,x2)=lnx1+x1x2sinx2的计算图。在这里插入图片描述
现在给定一个计算图,请你根据所有输入变量计算函数值及其偏导数(即梯度)。 例如,给定输入x1=1,x2=5x_1=1,x_2=5x1=1,x2=5,上述计算图获得函数值 f(2,5)=ln(2)+2×5−sin(5)=11.652f(2,5)=ln(2)+2×5−sin(5)=11.652f(2,5)=ln(2)+2×5sin(5)=11.652;并且根据微分链式法则,上图得到的梯度 ▽f=[∂f∂x1,∂f∂x2]=[1x1+x2,x1−cosx2]=[,5.5001.716]\triangledown f=[\frac{\partial f}{\partial x_1},\frac{\partial f}{\partial x_2}]=[\frac1x_1+x_2,x_1-cosx_2]=[,5.5001.716]f=[x1f,x2f]=[x11+x2,x1cosx2]=[,5.5001.716]
知道你已经把微积分忘了,所以这里只要求你处理几个简单的算子:加法、减法、乘法、指数(即编程语言中的 exp(x) 函数)、对数(lnx,即编程语言中的 log(x) 函数)和正弦函数(sinx,即编程语言中的 sin(x) 函数)。
如果你注意观察,可以发现在计算图中,计算函数值是一个从左向右进行的计算,而计算偏导数则正好相反。
输入格式
输入在第一行给出正整数 N≤5×104N\le5\times10^4N5×104,为计算图中的顶点数。
以下 N 行,第 i 行给出第 i 个顶点的信息,其中 i=0,1,⋯,N−1。第一个值是顶点的类型编号,分别为:

  • 0 代表输入变量
  • 1 代表加法,对应x1+x2x_1+x_2x1+x2
  • 2 代表减法,对应x1−x2x_1-x_2x1x2
  • 3 代表乘法,对应x1×x2x_1\times x_2x1×x2
  • 4 代表指数,对应exe^xex
  • 5 代表对数,对应lnxlnxlnx
  • 6 代表正弦,对应sinxsinxsinx
    对于输入变量,后面会跟它的双精度浮点数值;对于单目算子,后面会跟它对应的单个变量的顶点编号(编号从 0 开始);对于双目算子,后面会跟它对应两个变量的顶点编号。
    题目保证只有一个输出顶点(即没有出边的顶点,例如上图最右边的 -),且计算过程不会超过双精度浮点数的计算精度范围。
    输出格式
    首先在第一行输出给定计算图的函数值。在第二行顺序输出函数对于每个变量的偏导数的值,其间以一个空格分隔,行首尾不得有多余空格。偏导数的输出顺序与输入变量的出现顺序相同。输出小数点后 3 位。
    输入样例
    7
    0 2.0
    0 5.0
    5 0
    3 0 1
    6 1
    1 2 3
    2 5 4
    输出样例
    11.652
    5.500 1.716

梯度计算

  首先考虑计算图上的梯度计算。
  注意到顶点储存的是变量和算子,自然而然地,可以考虑用边代表运算过程;具体来说,对于一条从uuu射出,射入vvv的边⟨u,v⟩\langle u,v\rangleu,v,它记录的是v(u)v(u)v(u)(即将uuu作为变量进行vvv运算)求偏导的结果,也就是∂v∂u\frac{\partial v}{\partial u}uv
  同时,此题中包含的算子是单目或双目的,因此射入一个顶点的边不超过两条。另外,由于f:Rn↦Rf:\mathbb R^n\mapsto\mathbb Rf:RnR,因此计算图中一定存在且仅存在一个出度为0的定点(不妨称为输出节点),而入度为0的顶点一定表示变量(不妨称为输入节点)。总上,可以用类似树或k分图的方法处理计算图,换言之,将计算图看作这样一棵特殊的树:以输出节点为根,同一个独立集的深度相同(这是比较粗糙的说法),输入节点是树的叶子节点(同时也是树的全部叶子节点)在这里插入图片描述
  一个朴素的想法如下:从输入节点出发,沿着边的方向DFS,每访问一个顶点,便计算一次偏导数值,根据链式法则,在到达输出节点前,计算的偏导数值需要累乘,而同一个变量以此法计算的偏导数值需要累加。
  以题目的示意输入f(x1,x2)=lnx1+x1x2−sinx2f(x_1,x_2)=lnx_1+x_1x_2-sinx_2f(x1,x2)=lnx1+x1x2sinx2为例,可得∂f∂x1=1x1∗1∗1+x2∗1∗1=1x1+x2∂f∂x2=x1∗1∗1+cosx2∗(−1)=x1−cosx2\frac{\partial f}{\partial x_1}=\frac1{x_1}*1*1+x_2*1*1=\frac1{x_1}+x_2\\\frac{\partial f}{\partial x_2}=x_1*1*1+cosx_2*(-1)=x_1-cosx_2x1f=x1111+x211=x11+x2x2f=x111+cosx2(1)=x1cosx2
  这种想法是可行的,然而我们可以想起从下至上地遍历树,往往是由于题目中没有给出树的全部信息,我们需要建立一棵树(例如最优编码问题),对于已给出树的结构的题目,从上至下遍历往往是更优的方法。
  直觉上我们会发现重复访问难以避免:考虑一个储存双目算子的非根节点,假设射入它的边e1e_1e1被访问n1n_1n1次,e2e_2e2被访问n2n_2n2次,那么射出它的边e3e_3e3将被访问n1+n2n_1+n_2n1+n2次,哪怕在访问双目算子前没有出现重复,在访问双目算子时也会发生重复访问。如果所有算子都是双目算子,那么将产生指数级的复杂度。

  更好的方法:由上而下地进行DFS,如果递归和计算偏导数的时间复杂度都是O(1)O(1)O(1)的,那么将不会产生重复访问导致的额外时间复杂度。将每次BFS最后访问的叶子节点(也就是输入变量)作为偏导数计算结果的标签,当整个计算图访问结束后将相同标签的偏导数值相加,便得到结果。
  事实上,由于f:Rn↦Rf:\mathbb R^n\mapsto\mathbb Rf:RnR,使用向后传播方法(Backpropagation)计算梯度能使时间复杂度降低到算子目数之和。
  另外地,若f:R↦Rnf:\mathbb R\mapsto\mathbb R^nf:RRn,由于f−1:Rn↦Rf^{-1}:\mathbb R^n\mapsto\mathbb Rf1:RnR,因此可以对应地使用向前传播计算。

代码

  首先考虑储存计算图的数据结构,由于算子都是一目或二目的,因此考虑二叉树的储存方法,用四个vector分别储存类型键值左子节点右子节点。主程序如下:


vector<int> type(50000,0),L(50000,-1),R(50000,-1);
vector<double> key(50000,0);
vector<double> val(50000,0),grad(50000,0);	//每个节点的函数值和梯度
vector<int> isv(50000,0),isg(50000,0);		//标记是否计算了某节点的函数值和梯度
vector<int> var;							//变量位置	

int main(void){
	input();	//读取输入数据,并返回根节点;
	getval(root);		//计算函数值
	getgrad(root);		//计算梯度
	output();			//按格式输出;
	return 0;
}

  其中,输入函数为:


int root;   //根节点
void input(){
	int n;
	cin>>n;
	vector<int> isroot(n,1);	//记录是否为根节点
	for(int i=0;i<n;i++){
		int r,l,t;
		double k=0;
		cin>>t;
		if(!t){
			cin>>k;		//读入变量数据
            var.push_back(i);
		}
		else{			//如果是算子
			cin>>l;
			isroot[l]=0;
			L[i]=l;
			if(t<4){	//如果是双目算子
				cin>>r;
				isroot[r]=0;
				R[i]=r;
			}
		}
		type[i]=t;
		key[i]=k;
	}
	root=find(isroot.begin(),isroot.end(),1)-isroot.begin();	//记录根节点
}

  之后的getval()函数需要计算每个节点的函数值

double getval(int index){
	/*
	1: +	4: exp()
	2: -	5: log()
	3: *	6: sin()
	*/
		switch(type[index]){
		case 0:
			val[index]=key[index];
			break;
		case 1:
			val[index]=getval(L[index])+getval(R[index]);
			break;
		case 2:
			val[index]=getval(L[index])-getval(R[index]);
			break;
		case 3:
			val[index]=getval(L[index])*getval(R[index]);
			break;
		case 4:
			val[index]=exp(getval(L[index]));
			break;
		case 5:
			val[index]=log(getval(L[index]));
			break;
		case 6:
			val[index]=sin(getval(L[index]));
			break;
	}
	isv[index]=1;
	return val[index];
}

getgrad()函数用来计算梯度值,需要注意的是用tempgrad储存临时梯度值。


double tempgrad=1;
void getgrad(int index){
	/*
	1: +	4: exp()
	2: -	5: log()
	3: *	6: sin()
	*/
	
	switch(type[index]){
		case 0:
			grad[index]+=tempgrad;
			break;
		case 1:
			tempgrad*=1;
			getgrad(L[index]);
			tempgrad/=1;
			tempgrad*=1;
			getgrad(R[index]);
			tempgrad/=1;
			break;
		case 2:
			tempgrad*=1;
			getgrad(L[index]);
			tempgrad/=1;
			tempgrad*=-1;
			getgrad(R[index]);
			tempgrad/=-1;
			break;
		case 3:
			tempgrad*=val[R[index]];
			getgrad(L[index]);
			tempgrad/=val[R[index]];
			tempgrad*=val[L[index]];
			getgrad(R[index]);
			tempgrad/=val[L[index]];
			break;
		case 4:
			tempgrad*=exp(val[L[index]]);
			getgrad(L[index]);
			tempgrad/=exp(val[L[index]]);
			break;
		case 5:
			tempgrad*=1/val[L[index]];
			getgrad(L[index]);
			tempgrad/=1/val[L[index]];
			break;
		case 6:
			tempgrad*=cos(val[L[index]]);
			getgrad(L[index]);
			tempgrad/=cos(val[L[index]]);
			break;
	}
}

output()函数没什么好说的,不过最后一个测试点是N=0N=0N=0的输入,不太清楚要求输出什么格式。

void output(){
	if(var.size())printf("%.3lf",val[root]);
	for(int i=0;i<var.size();i++)
		printf("%s%.3lf",i==0?"\n":" ",grad[var[i]]);
}

运行结果:在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值