200 行代码,深入分析动态计算图的原理及实现
原文地址:优快云 博客
文章目录
1. 前言
机器学习这几年可是大红大紫,各行各业的人都往这里涌入,硬是在机器学习这一领域里挤出了一片人口红海。而在机器学习领域,神经网络由于自己下限低、上限高的特点,赢得了不少人的青睐。
在神经网络中,却有一件我们经常使用,经常耳闻,但又不太熟悉的东西——BP 算法。入门“炼丹”的小萌新往往会对这个一头雾水,久经沙场的老油条对这个也可能不了解细节。
我在查阅许多文章后,发现大多数文章对 BP 算法的介绍往往是点到为止,更深入者也就在数学公式推导层面止步,涉及到代码层面的博主鲜少,更很少提及 BP 算法在神经网络中的更广泛实现——计算图机制。
于是,秉持着“科普”的原则,笔者就撰写了这篇有关于 BP 算法以及计算图原理的文章,并在其中以笔者自己的代码实现,详细地讲解计算图的工作机制,并最终与成熟的计算框架进行比较。
2. BP 算法
BP 算法,又名反向传播算法,是目前深度学习的理论基石。其原始论文于 1986 年由 D. Rumelhart 发表在 Nature 上1。在其论文中,就已经使用 MSE(Mean Square Error) 均方误差作为训练目标,并使用多层的 MLP 感知机作为模型,进行亲戚关系的分类。

当前时代的神经网络,早已比当时的网络来的更加庞大,几百个万的模型参数比比皆是,GPT-3 甚至已经上千亿的模型,而其最基本的算法,却来自于 40 年前,让人感到不可思议。
对于 BP 算法的理解其实非常简单。假设神经网络的的损失是
L
L
L,
x
\bm{x}
x 是输入向量,
W
i
j
\bm{W}_{ij}
Wij 是第
i
i
i 层的第
j
j
j 个参数,那么根据梯度下降的原理,我们需要得到
L
L
L 对
W
i
j
\bm{W}_{ij}
Wij 偏微分值:
∇
W
i
j
=
∂
L
∂
W
i
j
\nabla\bm{W}_{ij}=\frac{\partial L}{\partial \bm{W}_{ij}}
∇Wij=∂Wij∂L
设
η
\eta
η 为学习率,则最终的参数更新算法为:
W
i
j
t
+
1
=
W
i
j
t
−
η
⋅
∇
W
i
j
\bm{W}_{ij}^{t+1}=\bm{W}_{ij}^{t}-\eta\cdot\nabla\bm{W}_{ij}
Wijt+1=Wijt−η⋅∇Wij
然后问题来了:怎么计算 ∂ L ∂ W i j \frac{\partial L}{\partial \bm{W}_{ij}} ∂Wij∂L?
许多的博文都对这个问题作出众多的解释,大部分人会选择使用数学推导的形式阐述,最终结果或许可能如下:

这串花里胡哨的东西,对数学系的同学来说刚刚好,对笔者来说可不好。讲到底 BP 算法就是一个偏导数的链式法则应用,写这么复杂真的有用吗?
∂
y
∂
x
1
=
∂
y
∂
x
n
⋅
∂
y
∂
x
n
−
1
⋅
⋯
⋅
∂
x
2
∂
x
1
\frac{\partial y}{\partial x_1}=\frac{\partial y}{\partial x_n}\cdot \frac{\partial y}{\partial x_{n-1}}\cdot\dots\cdot\frac{\partial x_2}{\partial x_1}
∂x1∂y=∂xn∂y⋅∂xn−1∂y⋅⋯⋅∂x1∂x2
看吧!如果我把上面这串链式法则的公式,
y
y
y 换成
L
L
L,
x
1
x_1
x1 换为
W
i
j
\bm{W}_{ij}
Wij,剩下的
x
i
x_i
xi 换为神经网络中的一些其他变量,不就把 BP 算法拆成了许多更小的偏导数的乘积吗?
对于 BP 算法的数学机理,了解到这已经足够。下一节,笔者将以程序员的角度,带大家看 BP 算法的另一个视角——计算图机制。
3. 计算图
本章中通过一个实际的例子,给出计算图的详细说明,并引出了计算图反向传播机制的定理。
3.1 计算图定义
计算图是描述计算过程的数据结构,而且通常是 DAG 图(有向无环图)。
在计算图中,每一个节点表示一个变量(值),每一条边表示数据的流动方向,并且每一条边的值被定义为边的首尾节点的偏导数值。例如:

这幅图表示以下的三个算式:
c
=
a
+
b
d
=
b
+
1
e
=
c
×
d
\begin{aligned} c&=a+b\\ d&=b+1\\ e&=c\times d \end{aligned}
cde=a+b=b+1=c×d
在这副计算图中,每个节点都表示着一个变量值,每条边表示数据的流动。在每条边上,笔者提前算出了每条边的末尾节点对起始节点的偏导数,例如边 (b,d) 的偏导数就是
∂
d
∂
b
=
1
\frac{\partial d}{\partial b}=1
∂b∂d=1。
3.2 计算图机制
拥有计算图的定义后,下面来详细介绍一下计算图是如何对应 BP 算法的。
3.2.1 前向传播
对应于 BP 算法的前向传播(Forward Pass)过程,计算图的前向传播其实相同,就是把计算图的每个节点的值都计算出来。
例如,在上面的示例图中,若设
a
=
2
,
b
=
1
a=2,b=1
a=2,b=1,那么前向传播的过程就把其他的节点值都算出来:
c
=
a
+
b
=
3
d
=
b
+
1
=
2
e
=
c
×
d
=
5
\begin{aligned} c&=a+b=3\\ d&=b+1=2\\ e&=c\times d=5 \end{aligned}
cde=a+b=3=b+1=2=c×d=5
前向传播没有理解上的难点,大家一眼就能明白,而难点在于反向传播的过程中。
3.2.2 反向传播
在反向传播的机制中,笔者并不打算引入过于复杂的数学公式来证明,而是选择用更加浅显易懂的大白话,说明计算图在反向传播过程中的工作原理。

对于示例图中,如果想求 ∂ e ∂ b \frac{\partial e}{\partial b} ∂b∂e,该怎么办?首先,从节点 b b b 开始,可以发现 b b b 通过作用于 c c c 和 d d d,进而对节点 e e e 造成了影响。
这个连环影响的现象表达成数学的形式,即为
Δ
e
=
∂
e
∂
c
⋅
Δ
c
+
∂
e
∂
d
⋅
Δ
d
=
∂
e
∂
c
⋅
(
∂
c
∂
b
⋅
Δ
b
)
+
∂
e
∂
d
⋅
(
∂
d
∂
b
⋅
Δ
b
)
\begin{aligned} \Delta e&=\frac{\partial e}{\partial c}\cdot\Delta c + \frac{\partial e}{\partial d}\cdot\Delta d \\ &=\frac{\partial e}{\partial c}\cdot(\frac{\partial c}{\partial b}\cdot \Delta b) + \frac{\partial e}{\partial d}\cdot(\frac{\partial d}{\partial b}\cdot \Delta b) \end{aligned}
Δe=∂c∂e⋅Δc+∂d∂e⋅Δd=∂c∂e⋅(∂b∂c⋅Δb)+∂d∂e⋅(∂b∂d⋅Δb)
上式左右两侧同时除以
Δ
b
\Delta b
Δb,则可以不严谨的得到:
∂
e
∂
b
=
∂
e
∂
c
⋅
∂
c
∂
b
+
∂
e
∂
d
⋅
∂
d
∂
b
\frac{\partial e}{\partial b}=\frac{\partial e}{\partial c}\cdot\frac{\partial c}{\partial b} + \frac{\partial e}{\partial d}\cdot\frac{\partial d}{\partial b}
∂b∂e=∂c∂e⋅∂b∂c+∂d∂e⋅∂b∂d
仔细地观察这个式子,对比下图可以发现:式子的前半部分 ∂ e ∂ c ⋅ ∂ c ∂ b \frac{\partial e}{\partial c}\cdot\frac{\partial c}{\partial b} ∂c∂e⋅∂b∂c,正好是路线 A 的边上梯度值的乘积;同理,式子的后半部分 ∂ e ∂ d ⋅ ∂ d ∂ b \frac{\partial e}{\partial d}\cdot\frac{\partial d}{\partial b} ∂d∂e⋅∂b∂d,也是路线 B 的边上梯度值的乘积。

从这里例子,可以总结出计算图的最终定理。
定理(计算图反向传播机制):计算图上任意两点 x x x 和 y y y,且 y y y 在 x x x 之后,则 ∂ y ∂ x \frac{\partial y}{\partial x} ∂x∂y 的值为点 x x x 到点 y y y 上所有的不重复路径上的边值乘积的总和。
如果觉得这个定理有点难懂,那么其详细的计算过程如下:
- 找到所有从点 x x x 到 y y y 的不重复路径,记作集合 P \mathcal{P} P
- 对任意 p i ∈ P p_i \in \mathcal{P} pi∈P,计算路径 p i p_i pi 上所有边值乘积 M i M_i Mi
- 则 ∂ y ∂ x = ∑ p i ∈ P M i \frac{\partial y}{\partial x}=\sum^{p_i\in \mathcal{P}} M_i ∂x∂y=∑pi∈PMi
对应到这个例子,就是说:从路线 A,得到其路径上的乘积为
d
d
d;从路线 B,得到其路径上的乘积为
c
c
c。那么最终的结果为
∂
e
∂
b
=
d
+
c
=
a
+
2
b
+
1
\frac{\partial e}{\partial b}=d+c=a+2b+1
∂b∂e=d+c=a+2b+1
由于在前向传播的过程中,所有的变量值我们都已经确定,所以算出 ∂ e ∂ b \frac{\partial e}{\partial b} ∂b∂e 的过程也就迎刃而解了。
有兴趣的同学可以试着验证其他的变量,看它们是否符合此规律。此外,笔者更推荐对其他的计算图检查,可以加深对这条规则的理解。
4. 代码实现
下面就是代码实现的部分咯,觉得麻烦的小伙伴可以跳过不看哦,但还是希望能给我的代码点个 star 收藏一下,十分感激!ヾ(≧▽≦*)o
Github 仓库:toy_computational_graph
4.1 Operation 定义
在个人的 200 行代码的实现中,大部分代码用于实现加减乘除的操作,事实上真正涉及反向传播的代码可能不足 30 行。下面是关于 Operation 的基类定义:
class Operation(ABC):
def __init__(self):
super().__init__()
# 反向传播过程中所需要的上下文 ctx
self.ctx: Optional[Dict] = None
# 记录输入的节点
self.inputs: List[Value] = []
def __call__(self, *args) -> Scalar:
self.inputs = list(args)
self.ctx = dict()
ret = self.forward(args, ctx=self.ctx)
ret.op = self
return ret
@staticmethod
@abstractmethod
def forward(inputs: List[Scalar], ctx=None) -> Scalar:
# 进行前向传播,并将反向传播的必要信息存放于 ctx 中
pass
@staticmethod
@abstractmethod
def backward(grad_output: float, ctx=None) -> List[float]:
# 反向传播的过程,返回每条输入边的累积梯度值
# grad_output 是从更加往后的节点传播到此处的累积梯度乘积
pass
可见,每个 Operation 其实就有以下功能:
- 记录输入节点
- 记录前向传播过程中产生的上下文
- 前向传播
- 反向传播
根据这个基类,最终派生出了加减乘除操作的实现类:
class AddOperation(Operation):
@staticmethod
def forward(inputs: List[Scalar], ctx=None) -> Scalar:
x, y = inputs
return Scalar(x.value + y.value)
@staticmethod
def backward(grad_output: float, ctx=None) -> List[float]:
return [grad_output, grad_output]
class SubOperation(Operation):
@staticmethod
def forward(inputs: List[Scalar], ctx=None) -> Scalar:
x, y = inputs
return Scalar(x.value - y.value)
@staticmethod
def backward(grad_output: float, ctx=None) -> List[float]:
return [grad_output, -grad_output]
class MulOperation(Operation):
@staticmethod
def forward(inputs: List[Scalar], ctx=None) -> Scalar:
x, y = inputs
ctx["x"] = x
ctx["y"] = y
return Scalar(x.value * y.value)
@staticmethod
def backward(grad_output: float, ctx=None) -> List[float]:
x, y = ctx["x"].value, ctx["y"].value
return [grad_output * y, grad_output * x]
class DivOperation(Operation):
@staticmethod
def forward(inputs: List[Scalar], ctx=None) -> Scalar:
x, y = inputs
assert y.value != 0, "Division by zero"
ctx["x"] = x
ctx["y"] = y
return Scalar(x.value / y.value)
@staticmethod
def backward(grad_output: float, ctx=None) -> List[float]:
x, y = ctx["x"].value, ctx["y"].value
return [grad_output / y, -x * grad_output / (y ** 2)]
代码简短而且清爽,适合读者学习。
4.2 数值类型
由于这个 codebase 体量不大,因此只允许使用 float 的包装类 Scalar 作为数值类型。其中 Value 类是 Scalar 类的基类,其定义并实现了反向传播的机制,如下:
class Value:
def __init__(self, op: Optional[Operation]):
self.op = op
self.grad = 0.
def zero_grad(self):
# 梯度清零,类似于 PyTorch
self.grad = 0.
def backward(self, grad_output: Optional[float] = None):
# 反向传播的实际执行,就是从此节点,迭代地把累积梯度乘积向更前的节点传播
# 等节点根据所传入的累积梯度乘积,更新完自身的梯度值后,就继续进行此过程
# 注:在保证 DAG 的前提下,此过程相等于遍历图上的所有不同路径
grad_output = grad_output if grad_output is not None else 1.
self.grad += grad_output
if self.op is not None:
prev_grads = self.op.backward(grad_output, ctx=self.op.ctx)
for input, prev_grad in zip(self.op.inputs, prev_grads):
input.backward(prev_grad)
至于 Scalar 类,只是实现了 __add__ 之类的加减乘除的 Dunder 函数的封装类,大致如下:
class Scalar(Value):
def __init__(self, value: numbers.Number, op: Optional[Operation] = None):
super().__init__(op)
self._value = float(value)
def __add__(self, other):
from operation import AddOperation
if isinstance(other, Scalar):
op = AddOperation()
return op(self, other)
elif isinstance(other, numbers.Number):
op = AddOperation()
return op(self, Scalar(other))
else:
raise TypeError("unsupported type")
... ...
由于 Scalar 类并不包括太多实际操作,因此完整代码供有兴趣的读者自行查看。
4.3 运行结果
详细代码可以查看代码仓库中的 example.py,结果如下:
example1:
x=10.0, y=2.0, r=x+2*y=14.0
=> x.grad=1.0, y.grad=2.0
example2:
x=10.0, r=x*x=100.0
=> x.grad=20.0
example3:
x=10.0, r=x*(x+1)=110.0
=> x.grad=21.0
example4:
x=8.0, y=4.0, r=x/y=2.0
=> x.grad=0.25, y.grad=-0.5
example5:
x=3.0, r=1/(x*x+1)=0.1
=> x.grad=-0.06
example6:
x=8.0, y=3.0, r=(x*x+1)/(y*y-1)=8.125
=> x.grad=2.0, y.grad=-6.09375
以上六个例子的运算结果均正确。
5. 杂谈
事实上,我这个 demo 和 PyTorch 一样,采用的是动态计算图的形式,即计算图是在运算的过程中实时产生。相反的,Tensorflow 就是采用静态计算图,其计算图需要在一开始就进行编译并固定。
相较于我这个毫无优化的 demo,PyTorch 对于计算图的优化则是出神入化。首先在这个计算图的迭代过程中,明显可以发现,不同路径之间的乘积是可以并行计算的。
同时,从计算图机制的定理中可以发现,由于各个路径上的梯度最终是相加起来的,因此并行下最好的实现方式就是将各个变量的梯度都初始化为 0,否则梯度相加后会出错。这也是为什么 PyTorch 训练时,会需要 zero_grad() 这一步。当然,笔者的实现中也仿效了这一设计。
6. 总结
本文从程序员的角度,总结出了计算图机制下的运行定理,并给出了约 200 行的代码实现,希望能够帮助所有正在入门机器学习的人。
如果您觉得本文有价值,还希望您能给我的文章点个赞、收藏和关注的三连,我们下期再见!ヾ( ̄▽ ̄)ByeBye
最后的最后,附上本文代码的 repo 地址:toy_computational_graph,希望读者能点几个 star 支持一下!
本文从程序员角度深入剖析动态计算图,通过200行代码实现BP算法,详细讲解计算图的工作原理。文中介绍了计算图的定义、前向传播和反向传播机制,并提供了加减乘除操作的代码实现。通过实例展示了计算图如何进行反向传播,最后探讨了与PyTorch等框架的异同。
14万+

被折叠的 条评论
为什么被折叠?



