NCDE
本文涉及些许NODE相关知识,可以参见另一篇博客:神经常微分方程(NODE)介绍
与NODE的比较
NCDE想解决的问题是NODE在方程求解过程中只利用到了初始状态的值,而忽略了中间时刻系统状态的影响,因此NCDE拟在每一预测时刻都利用到先前所有时刻的系统状态值,给出具有累积意义的预测结果。NODE一般的计算过程如下公式所述,旨在找到从输入数据
x
x
x到输出数据
y
y
y的映射,可以看出其中的积分过程即考虑到了累积时刻的影响,有别于NODE仅仅包含初值项的设置。在目标函数设置合理的情况下,我们可以将
l
θ
1
,
l
θ
2
l_{\theta}^1,l_{\theta}^2
lθ1,lθ2分别视为解码器和编码器,从而实现任意时间点的数据生成。
y
≈
l
θ
1
(
z
T
)
,
w
h
e
r
e
z
t
=
z
0
+
∫
0
t
f
θ
(
z
s
)
d
s
a
n
d
z
0
=
l
θ
2
(
x
)
y\approx l_{\theta}^1(z_T),\ where\ z_t=z_0+\int_0^tf_{\theta}(z_s)ds\ and\ z_0=l^2_\theta(x)
y≈lθ1(zT), where zt=z0+∫0tfθ(zs)ds and z0=lθ2(x)
当网络参数
θ
\theta
θ训练完成后,数据的生成过程只由初值
z
0
z_0
z0所唯一决定,而无法利用到后续时间点的数据。因此在实际应用层面,NODE更多的是作为ResNet的平替,并不能很好的对序列数据进行处理。以NODE的原始代码为例,针对阿基米德螺旋线数据生成过程,在训练过程是利用到了所有时间的数据对网络参数进行训练,但无论是在训练还是测试过程中,不同时间点的隐变量生成都是只使用了初始时刻的隐状态
z
0
z_0
z0,送入ODE求解得到,而其余时刻的状态
z
t
z_t
zt则置之不理。相对的,NCDE则会利用到所有时间点的数据,是对RNN的一种平替。
此外,NODE在数据生成的mini batch方面存在问题,虽然输入的数据确实可以是随机采样的,但同一批次的数据需要是同种采样类型的,否则无法一起训练(从代码来看是这样的)。而NCDE则没有这方面的困扰,同一批次的数据无需缺失方式相同。
NCDE设计
NCDE的数据生成公式如下所示:
z
t
=
z
t
0
+
∫
t
0
t
f
θ
(
z
s
)
d
X
s
f
o
r
t
∈
(
t
0
,
t
n
]
z_t=z_{t_0}+\int^t_{t_0}f_{\theta}(z_s)dX_s\quad for\ t\in(t_0,t_n]
zt=zt0+∫t0tfθ(zs)dXsfor t∈(t0,tn]
可以看出和NODE相比,不同点在于积分变量由时间
t
t
t变成了时序值
X
s
X_s
Xs,原先ODE的解是由时间变量
t
t
t和初值
z
0
z_0
z0所决定,而CDE的解则是受时序值
X
s
X_s
Xs所控制。这一时序值
X
s
X_s
Xs实际上是利用原始输入序列数据对系统进行的一次初步拟合,文章采用了简单的三次样条插值方法作为初次的模拟。也就是对于非等间隔采样的输入时序数据
x
=
(
(
t
0
,
x
0
)
,
(
t
1
,
x
1
)
,
.
.
.
,
(
t
n
,
x
n
)
)
x=((t_0,x_0),(t_1,x_1),...,(t_n,x_n))
x=((t0,x0),(t1,x1),...,(tn,xn)),自然三次样条插值函数
X
s
X_s
Xs是对
x
x
x隐含生成过程的模拟,使节点处有
X
t
i
=
(
x
i
,
t
i
)
X_{t_i}=(x_i,t_i)
Xti=(xi,ti),而其余任意时间的值也可由
x
(
t
)
=
X
(
t
)
x(t)=X(t)
x(t)=X(t)得到。
利用插值函数
X
s
X_s
Xs,我们获得了关于输入时序数据的连续值,从而可以求得任意时刻的
d
X
s
dX_s
dXs,如同将
t
t
t视为积分变量一样。同时由于
X
s
X_s
Xs是可导的,我们可以定义:
g
θ
,
X
(
z
,
s
)
=
f
θ
(
z
)
d
X
d
s
(
s
)
g_{\theta,X}(z,s)=f_{\theta}(z)\frac{dX}{ds}(s)
gθ,X(z,s)=fθ(z)dsdX(s)
从而可以将生成公式转化为如下形式:
z
t
=
z
t
0
+
∫
t
0
t
f
θ
(
z
s
)
d
X
s
=
z
t
0
+
∫
t
0
t
f
θ
(
z
)
d
X
d
s
(
s
)
d
s
=
z
t
0
+
∫
t
0
t
g
θ
,
X
(
z
,
s
)
d
s
z_t=z_{t_0}+\int^t_{t_0}f_{\theta}(z_s)dX_s=z_{t_0}+\int^t_{t_0}f_{\theta}(z)\frac{dX}{ds}(s)d_s=z_{t_0}+\int^t_{t_0}g_{\theta,X}(z,s)ds
zt=zt0+∫t0tfθ(zs)dXs=zt0+∫t0tfθ(z)dsdX(s)ds=zt0+∫t0tgθ,X(z,s)ds
此时保证了计算过程受所有时刻输入数据控制的同时,和NODE的公式具有了相同的形式,只不过被积分函数由
f
θ
(
z
s
)
f_{\theta}(z_s)
fθ(zs)变成了
f
θ
(
z
)
d
X
d
s
(
s
)
f_{\theta}(z)\frac{dX}{ds}(s)
fθ(z)dsdX(s),从而可以直接用NODE的计算工具来计算,包括了adjoint method。
实际运用
由于利用了
X
s
X_s
Xs作为数据的初始拟合,我们可以获得任意时刻的
x
t
x_t
xt,因此即使输入的数据具有各异的非等间隔采样模式,也可以很直接的batch在一起。
利用NCDE生成数据的过程和NODE类似,只是在网络训练前需要去计算自然三次样条插值,同时要求解的微分方程为:
d
y
d
t
=
f
θ
(
z
)
d
X
d
t
(
t
)
\frac{dy}{dt}=f_{\theta}(z)\frac{dX}{dt}(t)
dtdy=fθ(z)dtdX(t)
从而除了对隐状态进行处理外,还要额外乘上当前时间点插值函数的导数。