“我正在参加「掘金·启航计划」”
NCDE
与NODE的比较
NODE的介绍可见主页内相关博客
NCDE想解决的问题是NODE在方程求解过程中只利用到了初始状态的值,NODE一般的计算过程如下公式所述,旨在找到从输入数数据x到输出数据y的映射,在目标函数设置合理的情况下,可以将lθ1,lθ2分别视为解码器和编码器,从而实现任意时间点的数据生成。
y≈lθ1(zT), where zt=z0+∫0tfθ(zs)ds and z0=lθ2(x)
当网络参数θ训练完成后,数据的生成过程只由初值z0所唯一决定,而无法利用到后续时间点的数据。因此在实际应用层面,NODE更多的是作为ResNet的平替,并不能很好的对序列数据进行处理。以NODE的原始代码为例,针对阿基米德螺旋线数据生成过程,在训练过程是利用到了所有时间的数据对网络参数进行训练,但无论是在训练还是测试过程中,不同时间点的隐变量生成都是只使用了初始时刻的隐状态z0,送入ODE求解得到,而其余时刻的状态zt则置之不理。相对的,NCDE则会利用到所有时间点的数据,是对RNN的一种平替。
此外,NODE在数据生成的mini batch方面存在问题,虽然输入的数据确实可以是随机采样的,但同一批次的数据需要是同种采样类型的,否则无法一起训练(从代码来看是这样的)。而NCDE则没有这方面的困扰,同一批次的数据无需缺失方式相同。
NCDE设计
NCDE的数据生成公式如下所示:
zt=zt0+∫t0tfθ(zs)dXsfor t∈(t0,tn]
可以看出和NODE相比,不同点在于积分变量由时间t变成了时序值Xs,原先ODE的解是由时间变量t和初值z0所决定,而CDE的解则是受时序值Xs所控制。这一时序值Xs实际上是利用原始输入序列数据对系统进行的一次初步拟合,文章采用了简单的三次样条插值方法作为初次的模拟。也就是对于非等间隔采样的输入时序数据x=((t0,x0),(t1,x1),...,(tn,xn)),自然三次样条插值函数Xs是对x隐含生成过程的模拟,使节点处有Xti=(xi,ti),而其余任意时间的值也可由x(t)=X(t)得到。
利用插值函数Xs,我们获得了关于输入时序数据的连续值,从而可以求得任意时刻的dXs,如同将t视为积分变量一样。同时由于Xs是可导的,我们可以定义:
gθ,X(z,s)=fθ(z)dsdX(s)
从而可以将生成公式转化为如下形式:
zt=zt0+∫t0tfθ(zs)dXs=zt0+∫t0tfθ(z)dsdX(s)ds=zt0+∫t0tgθ,X(z,s)ds
此时保证了计算过程受所有时刻输入数据控制的同时,和NODE的公式具有了相同的形式,只不过被积分函数由fθ(zs)变成了fθ(z)dsdX(s),从而可以直接用NODE的计算工具来计算,包括了adjoint method。
实际运用
由于利用了Xs作为数据的初始拟合,我们可以获得任意时刻的xt,因此即使输入的数据具有各异的非等间隔采样模式,也可以很直接的batch在一起。
利用NCDE生成数据的过程和NODE类似,只是在网络训练前需要去计算自然三次样条插值,同时要求解的微分方程为:
dtdy=fθ(z)dtdX(t)
从而除了对隐状态进行处理外,还要额外乘上当前时间点插值函数的导数。