NCDE(神经受控微分方程)介绍

492 阅读3分钟

“我正在参加「掘金·启航计划」”

NCDE

与NODE的比较

NODE的介绍可见主页内相关博客

 NCDE想解决的问题是NODE在方程求解过程中只利用到了初始状态的值,NODE一般的计算过程如下公式所述,旨在找到从输入数数据xx到输出数据yy的映射,在目标函数设置合理的情况下,可以将lθ1,lθ2l_{\theta}^1,l_{\theta}^2分别视为解码器和编码器,从而实现任意时间点的数据生成。

ylθ1(zT), where zt=z0+0tfθ(zs)ds and z0=lθ2(x)y\approx l_{\theta}^1(z_T),\ where\ z_t=z_0+\int_0^tf_{\theta}(z_s)ds\ and\ z_0=l^2_\theta(x)

当网络参数θ\theta训练完成后,数据的生成过程只由初值z0z_0所唯一决定,而无法利用到后续时间点的数据。因此在实际应用层面,NODE更多的是作为ResNet的平替,并不能很好的对序列数据进行处理。以NODE的原始代码为例,针对阿基米德螺旋线数据生成过程,在训练过程是利用到了所有时间的数据对网络参数进行训练,但无论是在训练还是测试过程中,不同时间点的隐变量生成都是只使用了初始时刻的隐状态z0z_0,送入ODE求解得到,而其余时刻的状态ztz_t则置之不理。相对的,NCDE则会利用到所有时间点的数据,是对RNN的一种平替。

 此外,NODE在数据生成的mini batch方面存在问题,虽然输入的数据确实可以是随机采样的,但同一批次的数据需要是同种采样类型的,否则无法一起训练(从代码来看是这样的)。而NCDE则没有这方面的困扰,同一批次的数据无需缺失方式相同。

NCDE设计

 NCDE的数据生成公式如下所示:

zt=zt0+t0tfθ(zs)dXsfor t(t0,tn]z_t=z_{t_0}+\int^t_{t_0}f_{\theta}(z_s)dX_s\quad for\ t\in(t_0,t_n]

可以看出和NODE相比,不同点在于积分变量由时间tt变成了时序值XsX_s,原先ODE的解是由时间变量tt和初值z0z_0所决定,而CDE的解则是受时序值XsX_s所控制。这一时序值XsX_s实际上是利用原始输入序列数据对系统进行的一次初步拟合,文章采用了简单的三次样条插值方法作为初次的模拟。也就是对于非等间隔采样的输入时序数据x=((t0,x0),(t1,x1),...,(tn,xn))x=((t_0,x_0),(t_1,x_1),...,(t_n,x_n)),自然三次样条插值函数XsX_s是对xx隐含生成过程的模拟,使节点处有Xti=(xi,ti)X_{t_i}=(x_i,t_i),而其余任意时间的值也可由x(t)=X(t)x(t)=X(t)得到。  利用插值函数XsX_s,我们获得了关于输入时序数据的连续值,从而可以求得任意时刻的dXsdX_s,如同将tt视为积分变量一样。同时由于XsX_s是可导的,我们可以定义:

gθ,X(z,s)=fθ(z)dXds(s)g_{\theta,X}(z,s)=f_{\theta}(z)\frac{dX}{ds}(s)

从而可以将生成公式转化为如下形式:

zt=zt0+t0tfθ(zs)dXs=zt0+t0tfθ(z)dXds(s)ds=zt0+t0tgθ,X(z,s)dsz_t=z_{t_0}+\int^t_{t_0}f_{\theta}(z_s)dX_s=z_{t_0}+\int^t_{t_0}f_{\theta}(z)\frac{dX}{ds}(s)d_s=z_{t_0}+\int^t_{t_0}g_{\theta,X}(z,s)ds

此时保证了计算过程受所有时刻输入数据控制的同时,和NODE的公式具有了相同的形式,只不过被积分函数由fθ(zs)f_{\theta}(z_s)变成了fθ(z)dXds(s)f_{\theta}(z)\frac{dX}{ds}(s),从而可以直接用NODE的计算工具来计算,包括了adjoint method。

实际运用

 由于利用了XsX_s作为数据的初始拟合,我们可以获得任意时刻的xtx_t,因此即使输入的数据具有各异的非等间隔采样模式,也可以很直接的batch在一起。  利用NCDE生成数据的过程和NODE类似,只是在网络训练前需要去计算自然三次样条插值,同时要求解的微分方程为:

dydt=fθ(z)dXdt(t)\frac{dy}{dt}=f_{\theta}(z)\frac{dX}{dt}(t)

从而除了对隐状态进行处理外,还要额外乘上当前时间点插值函数的导数。