神经受控微分方程(NCDE)介绍

NCDE(Neural Controlled Differential Equations)旨在解决NODE(Neural Ordinary Differential Equations)在处理序列数据时只依赖初始状态的局限。NODE通常用于静态数据的映射,而NCDE利用所有时间点的数据,更适合作为RNN的替代。NCDE的计算中,积分变量由时间变为时序值,通过自然三次样条插值模拟输入序列,使得模型能够考虑所有时间点的信息。这种方法允许NCDE在处理非等间隔采样的数据时进行批处理,并且可以直接使用NODE的计算工具。NCDE的实际应用中,需要计算插值函数并解特定的微分方程,以处理隐状态和时间点插值函数的导数。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

NCDE

本文涉及些许NODE相关知识,可以参见另一篇博客:神经常微分方程(NODE)介绍

与NODE的比较

 NCDE想解决的问题是NODE在方程求解过程中只利用到了初始状态的值,而忽略了中间时刻系统状态的影响,因此NCDE拟在每一预测时刻都利用到先前所有时刻的系统状态值,给出具有累积意义的预测结果。NODE一般的计算过程如下公式所述,旨在找到从输入数据 x x x到输出数据 y y y的映射,可以看出其中的积分过程即考虑到了累积时刻的影响,有别于NODE仅仅包含初值项的设置。在目标函数设置合理的情况下,我们可以将 l θ 1 , l θ 2 l_{\theta}^1,l_{\theta}^2 lθ1,lθ2分别视为解码器和编码器,从而实现任意时间点的数据生成。
y ≈ l θ 1 ( z T ) ,   w h e r e   z t = z 0 + ∫ 0 t f θ ( z s ) d s   a n d   z 0 = l θ 2 ( x ) y\approx l_{\theta}^1(z_T),\ where\ z_t=z_0+\int_0^tf_{\theta}(z_s)ds\ and\ z_0=l^2_\theta(x) ylθ1(zT), where zt=z0+0tfθ(zs)ds and z0=lθ2(x)
 当网络参数 θ \theta θ训练完成后,数据的生成过程只由初值 z 0 z_0 z0所唯一决定,而无法利用到后续时间点的数据。因此在实际应用层面,NODE更多的是作为ResNet的平替,并不能很好的对序列数据进行处理。以NODE的原始代码为例,针对阿基米德螺旋线数据生成过程,在训练过程是利用到了所有时间的数据对网络参数进行训练,但无论是在训练还是测试过程中,不同时间点的隐变量生成都是只使用了初始时刻的隐状态 z 0 z_0 z0,送入ODE求解得到,而其余时刻的状态 z t z_t zt则置之不理。相对的,NCDE则会利用到所有时间点的数据,是对RNN的一种平替。
 此外,NODE在数据生成的mini batch方面存在问题,虽然输入的数据确实可以是随机采样的,但同一批次的数据需要是同种采样类型的,否则无法一起训练(从代码来看是这样的)。而NCDE则没有这方面的困扰,同一批次的数据无需缺失方式相同。

NCDE设计

 NCDE的数据生成公式如下所示:
z t = z t 0 + ∫ t 0 t f θ ( z s ) d X s f o r   t ∈ ( t 0 , t n ] z_t=z_{t_0}+\int^t_{t_0}f_{\theta}(z_s)dX_s\quad for\ t\in(t_0,t_n] zt=zt0+t0tfθ(zs)dXsfor t(t0,tn]
可以看出和NODE相比,不同点在于积分变量由时间 t t t变成了时序值 X s X_s Xs,原先ODE的解是由时间变量 t t t和初值 z 0 z_0 z0所决定,而CDE的解则是受时序值 X s X_s Xs所控制。这一时序值 X s X_s Xs实际上是利用原始输入序列数据对系统进行的一次初步拟合,文章采用了简单的三次样条插值方法作为初次的模拟。也就是对于非等间隔采样的输入时序数据 x = ( ( t 0 , x 0 ) , ( t 1 , x 1 ) , . . . , ( t n , x n ) ) x=((t_0,x_0),(t_1,x_1),...,(t_n,x_n)) x=((t0,x0),(t1,x1),...,(tn,xn)),自然三次样条插值函数 X s X_s Xs是对 x x x隐含生成过程的模拟,使节点处有 X t i = ( x i , t i ) X_{t_i}=(x_i,t_i) Xti=(xi,ti),而其余任意时间的值也可由 x ( t ) = X ( t ) x(t)=X(t) x(t)=X(t)得到。
 利用插值函数 X s X_s Xs,我们获得了关于输入时序数据的连续值,从而可以求得任意时刻的 d X s dX_s dXs,如同将 t t t视为积分变量一样。同时由于 X s X_s Xs是可导的,我们可以定义:
g θ , X ( z , s ) = f θ ( z ) d X d s ( s ) g_{\theta,X}(z,s)=f_{\theta}(z)\frac{dX}{ds}(s) gθ,X(z,s)=fθ(z)dsdX(s)
从而可以将生成公式转化为如下形式:
z t = z t 0 + ∫ t 0 t f θ ( z s ) d X s = z t 0 + ∫ t 0 t f θ ( z ) d X d s ( s ) d s = z t 0 + ∫ t 0 t g θ , X ( z , s ) d s z_t=z_{t_0}+\int^t_{t_0}f_{\theta}(z_s)dX_s=z_{t_0}+\int^t_{t_0}f_{\theta}(z)\frac{dX}{ds}(s)d_s=z_{t_0}+\int^t_{t_0}g_{\theta,X}(z,s)ds zt=zt0+t0tfθ(zs)dXs=zt0+t0tfθ(z)dsdX(s)ds=zt0+t0tgθ,X(z,s)ds
此时保证了计算过程受所有时刻输入数据控制的同时,和NODE的公式具有了相同的形式,只不过被积分函数由 f θ ( z s ) f_{\theta}(z_s) fθ(zs)变成了 f θ ( z ) d X d s ( s ) f_{\theta}(z)\frac{dX}{ds}(s) fθ(z)dsdX(s),从而可以直接用NODE的计算工具来计算,包括了adjoint method。

实际运用

 由于利用了 X s X_s Xs作为数据的初始拟合,我们可以获得任意时刻的 x t x_t xt,因此即使输入的数据具有各异的非等间隔采样模式,也可以很直接的batch在一起。
 利用NCDE生成数据的过程和NODE类似,只是在网络训练前需要去计算自然三次样条插值,同时要求解的微分方程为:
d y d t = f θ ( z ) d X d t ( t ) \frac{dy}{dt}=f_{\theta}(z)\frac{dX}{dt}(t) dtdy=fθ(z)dtdX(t)
从而除了对隐状态进行处理外,还要额外乘上当前时间点插值函数的导数。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值