Time Series Models (3): xPatch


xPatch

1. Model Background

The paper notes that attention mechanisms have been widely adopted in time series forecasting in recent years; models such as Crossformer (Zhang and Yan 2022) and PatchTST (Nie et al. 2023) are transformer-based. However, the authors argue that attention struggles to fully exploit the temporal relationships in time series data, and xPatch further hypothesizes that the permutation invariance of attention hurts forecasting performance. xPatch therefore proposes a **dual-stream architecture that utilizes exponential decomposition** to address these issues.

The xPatch architecture models time series data with **exponential moving average (EMA) decomposition**: the series is split into two components, a seasonal component and a trend component, each handled by its own module. The seasonal module captures seasonal variations, while the trend module captures trend variations.

The paper explains the distinction between the two components. The statistical properties of the seasonal component (e.g., mean and variance) are relatively stable over time, so it is treated as stationary. The trend component, in contrast, reflects long-term movement whose mean shifts over time, so it is non-stationary. The two streams are therefore processed differently: the seasonal stream receives non-linear processing, while the trend stream receives linear processing.

2. Model Architecture

The xPatch model architecture is shown in the figure below:

(Figure: xPatch model architecture)

xPatch first applies EMA decomposition to the input series, separating the seasonal and trend components. The seasonal component is patched and fed into the non-linear stream, which consists of depthwise separable convolutions with residual connections. The trend component goes directly into the linear stream and is transformed by purely linear operations. Finally, the outputs of the two streams are concatenated and fused to produce the final forecast.
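
The overall dataflow can be summarized in a few lines. The sketch below is illustrative only (the function and argument names are mine, not the authors'); it assumes an `ema` module for decomposition and two stream modules like the ones implemented in Section 4.

import torch

def xpatch_forward(x, ema, seasonal_stream, trend_stream, fuse):
    # x: [Batch, Input, Channel]
    trend = ema(x)                   # X_T: exponentially smoothed trend
    seasonal = x - trend             # X_S: seasonal residual
    s = seasonal_stream(seasonal)    # patching + depthwise separable CNN (non-linear)
    t = trend_stream(trend)          # stacked linear layers (linear)
    return fuse(torch.cat((s, t), dim=-1))  # concatenate streams and project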

2.1. EMA Decomposition

EMA (Exponential Moving Average) decomposition improves on the traditional SMA (Simple Moving Average) decomposition. Traditional SMA decomposition takes k as the window size and averages the series within each window to obtain the moving average:

$$
s_t = \frac{x_t + x_{t+1} + \ldots + x_{t+k-1}}{k} = \frac{1}{k} \sum_{i=t}^{t+k-1} x_i \\
X_T = \text{AvgPool}(\text{Padding}(X)) \\
X_S = X - X_T \tag{1}
$$
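
For reference, a minimal sketch of this SMA-style decomposition (the replication padding and the window size `k` are assumptions, following the common moving-average decomposition used in DLinear-style models):

import torch
from torch import nn

def sma_decompose(x, k=25):
    # x: [Batch, Input, Channel]; replicate the boundaries so the moving
    # average has the same length as the input
    front = x[:, :1, :].repeat(1, (k - 1) // 2, 1)
    end = x[:, -1:, :].repeat(1, k // 2, 1)
    padded = torch.cat([front, x, end], dim=1)
    # AvgPool1d expects [Batch, Channel, Input], so permute around it
    trend = nn.AvgPool1d(kernel_size=k, stride=1)(padded.permute(0, 2, 1)).permute(0, 2, 1)
    seasonal = x - trend
    return seasonal, trend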

EMA instead uses an exponential moving average, whose weights decay exponentially for older observations; the decay rate is controlled by the hyperparameter α. This makes EMA place more emphasis on capturing short-term trends and recent signal changes.

$$
s_0 = x_0 \\
s_t = \alpha x_t + (1 - \alpha) s_{t-1}, \quad t > 0 \\
X_T = \text{EMA}(X) \\
X_S = X - X_T \tag{2}
$$
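
Unrolling the recursion gives a closed form for $s_t$, which is what the vectorized implementation in Section 4 computes with a single weighted cumulative sum:

$$
s_t = (1 - \alpha)^t x_0 + \alpha \sum_{i=1}^{t} (1 - \alpha)^{t-i} x_i
$$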

2.2. Patching

The paper notes that the patching operation is inspired by the Vision Transformer (ViT). ViT splits an input image into multiple patches; xPatch's patching is analogous, splitting the time series into multiple patches. The goal is to extract recurring seasonal features, attend more closely to repeating patterns, and capture the dependencies between patterns.
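
A minimal sketch of the patching step (the shapes follow the comments in Section 4; the exact `patch_len` and `stride` values are configuration choices):

import torch
from torch import nn

seq_len, patch_len, stride = 96, 16, 8
x = torch.randn(27584, seq_len)                    # [Batch * Channel, Input]
x = nn.ReplicationPad1d((0, stride))(x)            # pad the end by `stride` steps
patches = x.unfold(dimension=-1, size=patch_len, step=stride)
print(patches.shape)                               # [27584, 12, 16] = [B*C, patch_num, patch_len]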

2.3. Depthwise Separable Convolution

The non-linear stream uses a depthwise separable convolution, which factorizes a standard convolution into a depthwise convolution and a pointwise convolution; the depthwise convolution extracts spatial features, while the pointwise convolution extracts temporal features. In the depthwise convolution, each input channel gets its own kernel, and each kernel convolves only over its corresponding channel, so information from different input channels is never mixed. The pointwise convolution then applies 1×1 kernels to the depthwise output to combine features across channels; this step can be viewed as feature fusion plus channel-count adjustment.
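
A minimal sketch of a depthwise separable convolution in PyTorch (the sizes are illustrative and mirror the non-linear stream in Section 4, where the patch dimension plays the role of channels):

import torch
from torch import nn

patch_num, dim, patch_len = 12, 256, 16
x = torch.randn(8, patch_num, dim)                 # [Batch, Channel, Length]

# Depthwise: groups=patch_num gives each channel its own kernel, no channel mixing
depthwise = nn.Conv1d(patch_num, patch_num, kernel_size=patch_len,
                      stride=patch_len, groups=patch_num)
# Pointwise: 1x1 convolution mixes information across channels
pointwise = nn.Conv1d(patch_num, patch_num, kernel_size=1)

y = pointwise(depthwise(x))
print(y.shape)                                     # [8, 12, 16]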

3. Key Features

The xPatch architecture has the following key characteristics:

  • Dual-stream architecture based on exponential decomposition: xPatch uses EMA to decompose the time series into seasonal and trend components, which are processed non-linearly and linearly, respectively.
  • Patching of the seasonal component: analogous to ViT's patching, the seasonal component is split into patches to extract recurring seasonal features.
  • Non-linear stream: the non-linear stream uses depthwise separable convolutions to extract seasonal features.

4. Source Code Walkthrough

Implementation of the EMA decomposition:

import torch
from torch import nn


class EMA(nn.Module):
    """
    Exponential Moving Average (EMA) block to highlight the trend of time series
    """
    def __init__(self, alpha):
        super(EMA, self).__init__()
        # self.alpha = nn.Parameter(alpha)    # Learnable alpha
        self.alpha = alpha


    # Optimized implementation with O(1) time complexity
    def forward(self, x):
        # x: [Batch, Input, Channel]
        # self.alpha.data.clamp_(0, 1)        # Clamp learnable alpha to [0, 1]
        _, t, _ = x.shape
        # [95, 94, 93,..., 1, 0]
        powers = torch.flip(torch.arange(t, dtype=torch.double), dims=(0,))
        # [(1-alpha)^95, (1-alpha)^94, (1-alpha)^93, ..., (1-alpha), 1]
        weights = torch.pow((1 - self.alpha), powers).to(x.device)
        divisor = weights.clone()
        # scale all but the first weight by alpha:
        # [(1-alpha)^95, (1-alpha)^94 * alpha, ..., (1-alpha) * alpha, alpha]
        weights[1:] = weights[1:] * self.alpha
        weights = weights.reshape(1, t, 1)
        divisor = divisor.reshape(1, t, 1)
        # Multiply by the weights, take the cumulative sum along time, and divide by
        # the divisor; this reproduces the closed form of the EMA recursion in Equation (2)
        x = torch.cumsum(x * weights, dim=1)
        x = torch.div(x, divisor)
        return x.to(torch.float32)
    
    # Naive implementation with O(n) time complexity (kept for reference)
    # def forward(self, x):
    #     # self.alpha.data.clamp_(0, 1)        # Clamp learnable alpha to [0, 1]
    #     s = x[:, 0, :]
    #     res = [s.unsqueeze(1)]
    #     for t in range(1, x.shape[1]):
    #         xt = x[:, t, :]
    #         s = self.alpha * xt + (1 - self.alpha) * s
    #         res.append(s.unsqueeze(1))
    #     return torch.cat(res, dim=1)
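
A minimal usage sketch, assuming the `EMA` class above is in scope (the shapes and `alpha` value are assumptions): decompose a batch of series into trend and seasonal parts, mirroring $X_T = \text{EMA}(X)$ and $X_S = X - X_T$.

x = torch.randn(32, 96, 7)          # [Batch, Input, Channel]
ema = EMA(alpha=0.3)                # alpha controls the decay rate
trend = ema(x)                      # X_T: smoothed trend component
seasonal = x - trend                # X_S: seasonal (residual) component
print(trend.shape, seasonal.shape)  # torch.Size([32, 96, 7]) for both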

Implementation of the network architecture:

import torch
from torch import nn


class Network(nn.Module):
    def __init__(self, seq_len, pred_len, patch_len, stride, padding_patch):
        super(Network, self).__init__()


        # Parameters
        self.pred_len = pred_len


        # Non-linear Stream
        # Patching
        self.patch_len = patch_len
        self.stride = stride
        self.padding_patch = padding_patch
        self.dim = patch_len * patch_len
        self.patch_num = (seq_len - patch_len)//stride + 1
        if padding_patch == 'end': # can be modified to general case
            self.padding_patch_layer = nn.ReplicationPad1d((0, stride)) 
            self.patch_num += 1


        # Patch Embedding
        # Embed each patch into the target dimension
        self.fc1 = nn.Linear(patch_len, self.dim)
        self.gelu1 = nn.GELU()
        self.bn1 = nn.BatchNorm1d(self.patch_num)
        
        # CNN Depthwise
        self.conv1 = nn.Conv1d(self.patch_num, self.patch_num,
                               patch_len, patch_len, groups=self.patch_num)
        self.gelu2 = nn.GELU()
        self.bn2 = nn.BatchNorm1d(self.patch_num)


        # Residual Stream
        self.fc2 = nn.Linear(self.dim, patch_len)


        # CNN Pointwise
        self.conv2 = nn.Conv1d(self.patch_num, self.patch_num, 1, 1)
        self.gelu3 = nn.GELU()
        self.bn3 = nn.BatchNorm1d(self.patch_num)


        # Flatten Head
        self.flatten1 = nn.Flatten(start_dim=-2)
        self.fc3 = nn.Linear(self.patch_num * patch_len, pred_len * 2)
        self.gelu4 = nn.GELU()
        self.fc4 = nn.Linear(pred_len * 2, pred_len)


        # Linear Stream
        # MLP
        self.fc5 = nn.Linear(seq_len, pred_len * 4)
        self.avgpool1 = nn.AvgPool1d(kernel_size=2)
        self.ln1 = nn.LayerNorm(pred_len * 2)


        self.fc6 = nn.Linear(pred_len * 2, pred_len)
        self.avgpool2 = nn.AvgPool1d(kernel_size=2)
        self.ln2 = nn.LayerNorm(pred_len // 2)


        self.fc7 = nn.Linear(pred_len // 2, pred_len)


        # Streams Concatenation
        self.fc8 = nn.Linear(pred_len * 2, pred_len)


    def forward(self, s, t):
        # x: [Batch, Input, Channel]
        # s - seasonality
        # t - trend
        
        # Permute both streams so the time dimension is last, which makes patching easier
        # s, t after the permute: [32, 862, 96]
        s = s.permute(0,2,1) # to [Batch, Channel, Input] 
        t = t.permute(0,2,1) # to [Batch, Channel, Input]
        
        # Channel split for channel independence
        B = s.shape[0] # Batch size
        C = s.shape[1] # Channel size
        I = s.shape[2] # Input size
        s = torch.reshape(s, (B*C, I)) # [Batch and Channel, Input]
        t = torch.reshape(t, (B*C, I)) # [Batch and Channel, Input]


        # Non-linear Stream
        # Patching
        if self.padding_patch == 'end':
            s = self.padding_patch_layer(s)
        s = s.unfold(dimension=-1, size=self.patch_len, step=self.stride)
        # s_after_patching: [Batch and Channel, Patch_num, Patch_len] = [27584, 12, 16]
        
        # Patch Embedding
        s = self.fc1(s)
        s = self.gelu1(s)
        s = self.bn1(s)
        # s_after_embedding: [Batch and Channel, Patch_num, Dim] = [27584, 12, 256]
        
        # Keep a copy of the embedding for the residual connection
        res = s


        # CNN Depthwise 
        s = self.conv1(s)
        s = self.gelu2(s)
        s = self.bn2(s)


        # Residual Stream
        res = self.fc2(res)
        s = s + res


        # CNN Pointwise
        s = self.conv2(s)
        s = self.gelu3(s)
        s = self.bn3(s)


        # Flatten Head
        s = self.flatten1(s)
        # s_after_flatten: [Batch and Channel, Dim] = [27584, 256]
        s = self.fc3(s)
        # s_after_head: [Batch and Channel, Pred_len * 2] = [27584, 192]
        s = self.gelu4(s)
        s = self.fc4(s)
        # s_after_output: [Batch and Channel, Pred_len] = [27584, 96]


        # Linear Stream
        # MLP
        t = self.fc5(t)
        t = self.avgpool1(t)
        t = self.ln1(t)


        t = self.fc6(t)
        t = self.avgpool2(t)
        t = self.ln2(t)


        t = self.fc7(t)


        # Streams Concatenation
        x = torch.cat((s, t), dim=1)
        x = self.fc8(x)


        # Channel concatenation
        x = torch.reshape(x, (B, C, self.pred_len)) # [Batch, Channel, Output]
        # Restore the original [Batch, Output, Channel] layout
        x = x.permute(0,2,1) # to [Batch, Output, Channel]


        return x
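
Putting the pieces together, a minimal usage sketch (the hyperparameters and shapes are assumptions; `EMA` is the class from the previous block):

seq_len, pred_len = 96, 96
net = Network(seq_len=seq_len, pred_len=pred_len,
              patch_len=16, stride=8, padding_patch='end')

x = torch.randn(32, seq_len, 7)     # [Batch, Input, Channel]
trend = EMA(alpha=0.3)(x)           # X_T
seasonal = x - trend                # X_S
y = net(seasonal, trend)            # [Batch, Output, Channel] = [32, 96, 7]
print(y.shape)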