xPatch
1. Model Background
The paper observes that attention mechanisms have been widely adopted in time series forecasting in recent years; models such as Crossformer (Zhang and Yan 2022) and PatchTST (Nie et al. 2023) are transformer-based. However, the authors argue that attention struggles to fully exploit the temporal relationships in time series data. xPatch further assumes that the permutation invariance of attention hurts forecasting performance. To address this, xPatch proposes a **dual-stream architecture that utilizes exponential decomposition**.
xPatch models the time series with an **exponential moving average (EMA) decomposition**: the series is split into two parts, handled by a seasonality module and a trend module. The seasonality module captures seasonal variation, while the trend module captures trend variation.
The paper elaborates on the two components. The statistical properties of the seasonal component (such as mean and variance) are relatively stable, so it is treated as stationary. Conversely, the trend component reflects the long-term movement of the series and its mean shifts over time, so it is non-stationary. The two streams are therefore treated differently: the seasonal stream receives non-linear processing, and the trend stream receives linear processing.
2. Model Architecture
The overall architecture of xPatch is shown in the figure below:
xPatch first applies EMA decomposition to the input series, separating the seasonal and trend components. The seasonal component is patched and fed into the non-linear stream, which consists of depthwise separable convolutions with a residual connection. The trend component is fed directly into the linear stream, which produces its output through linear operations. Finally, the outputs of the two streams are concatenated to obtain the final forecast.
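To make this data flow concrete, here is a minimal conceptual sketch. The functions `nonlinear_stream`, `linear_stream`, and `head` are placeholders for illustration only (the actual layers are listed in Section 4), and the seasonal component is assumed to be the residual of the trend, as is standard for moving-average decompositions:

```python
import torch

def xpatch_forward_sketch(x, ema, nonlinear_stream, linear_stream, head):
    """Conceptual data flow of xPatch; x: [Batch, Input, Channel]."""
    trend = ema(x)                      # EMA output is taken as the trend component
    seasonal = x - trend                # residual is taken as the seasonal component
    s_out = nonlinear_stream(seasonal)  # patching + depthwise separable conv + flatten head
    t_out = linear_stream(trend)        # purely linear / MLP processing
    return head(torch.cat((s_out, t_out), dim=-1))  # fuse the two stream outputs
```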
2.1. EMA Decomposition
EMA (Exponential Moving Average) decomposition improves on the traditional SMA (Simple Moving Average) decomposition. SMA uses a window of size k and averages the values inside that window to obtain the moving average.
EMA instead uses exponentially decaying weights, whose decay rate is controlled by the hyperparameter α. This weighting makes EMA emphasize recent observations, so it is better at capturing short-term trends and reacting to short-term signal changes.
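In formulas, with $x_t$ the observation at time $t$, $k$ the window size, and $\alpha \in (0, 1)$ the smoothing factor (a sketch consistent with the description above and with the recursive update used in the naive implementation in Section 4):

$$
\mathrm{SMA}_t = \frac{1}{k} \sum_{i=0}^{k-1} x_{t-i},
\qquad
\mathrm{EMA}_t = \alpha\, x_t + (1-\alpha)\,\mathrm{EMA}_{t-1}, \quad \mathrm{EMA}_0 = x_0 .
$$

A larger $\alpha$ discounts older observations faster, which is why EMA reacts more quickly to recent changes than SMA.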
2.2. Patching
The paper notes that the patching operation is inspired by the Vision Transformer (ViT), which divides an input image into multiple patches. xPatch applies the same idea to time series, splitting the series into patches. The goal is to extract recurring seasonal features, focus on repeating patterns, and capture the dependencies between those patterns.
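As a quick illustration of what patching does to a univariate series (the patch length 16 and stride 8 below are hypothetical values chosen only for this example; the same `unfold` call appears in the source code in Section 4):

```python
import torch

x = torch.randn(32, 96)        # [Batch * Channel, Input], channel-independent view
patch_len, stride = 16, 8      # hypothetical hyperparameters for illustration
patches = x.unfold(dimension=-1, size=patch_len, step=stride)
print(patches.shape)           # torch.Size([32, 11, 16]) -> [Batch * Channel, Patch_num, Patch_len]
```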
2.3. Depthwise Separable Convolution
The non-linear stream uses a depthwise separable convolution, which factorizes a standard convolution into a depthwise convolution and a pointwise convolution. The depthwise convolution extracts spatial features, while the pointwise convolution extracts temporal features. In the depthwise convolution, each input channel is convolved independently with its own kernel, and each kernel operates only on its corresponding channel, so information from different input channels is never mixed. The pointwise convolution then applies 1×1 kernels to the depthwise output to combine features across channels; this step can be viewed as feature fusion and channel-dimension adjustment.
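A small standalone sketch of this depthwise/pointwise split (the channel count and kernel size are hypothetical; in xPatch the channel dimension of these convolutions is the number of patches):

```python
import torch
from torch import nn

channels, length, kernel = 12, 256, 16          # hypothetical sizes for illustration
x = torch.randn(8, channels, length)

# Depthwise: groups == in_channels, so each channel is convolved with its own kernel
depthwise = nn.Conv1d(channels, channels, kernel_size=kernel, stride=kernel, groups=channels)
# Pointwise: 1x1 convolution that mixes information across channels
pointwise = nn.Conv1d(channels, channels, kernel_size=1)

out = pointwise(depthwise(x))
print(out.shape)                                # torch.Size([8, 12, 16])
```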
3. Model Features
The xPatch architecture has the following characteristics:
- Dual-stream architecture based on exponential decomposition: xPatch decomposes the time series with EMA into a seasonal component and a trend component, which are processed non-linearly and linearly, respectively.
- Patching of the seasonal component: similar to the patching in ViT, xPatch splits the seasonal component into patches to extract recurring seasonal features.
- Non-linear stream: the non-linear stream uses a depthwise separable convolution structure to extract seasonal features.
4. Source Code Walkthrough
Implementation of the EMA decomposition:
```python
import torch
from torch import nn


class EMA(nn.Module):
    """
    Exponential Moving Average (EMA) block to highlight the trend of time series
    """
    def __init__(self, alpha):
        super(EMA, self).__init__()
        # self.alpha = nn.Parameter(alpha)  # Learnable alpha
        self.alpha = alpha

    # Optimized implementation with O(1) time complexity
    def forward(self, x):
        # x: [Batch, Input, Channel]
        # self.alpha.data.clamp_(0, 1)  # Clamp learnable alpha to [0, 1]
        _, t, _ = x.shape
        # [95, 94, 93, ..., 1, 0]
        powers = torch.flip(torch.arange(t, dtype=torch.double), dims=(0,))
        # [exp(1-alpha, 95), exp(1-alpha, 94), exp(1-alpha, 93), ..., 1 - alpha, 1]
        weights = torch.pow((1 - self.alpha), powers).to(x.device)
        divisor = weights.clone()
        # [exp(1-alpha, 94) * alpha, exp(1-alpha, 93) * alpha, ..., (1 - alpha) * alpha, alpha]
        weights[1:] = weights[1:] * self.alpha
        weights = weights.reshape(1, t, 1)
        divisor = divisor.reshape(1, t, 1)
        # Multiply the series by the weights, take the cumulative sum along time,
        # and divide by the divisor; this reproduces the EMA recursion from the paper
        x = torch.cumsum(x * weights, dim=1)
        x = torch.div(x, divisor)
        return x.to(torch.float32)

    # Naive implementation with O(n) time complexity
    # def forward(self, x):
    #     # self.alpha.data.clamp_(0, 1)  # Clamp learnable alpha to [0, 1]
    #     s = x[:, 0, :]
    #     res = [s.unsqueeze(1)]
    #     for t in range(1, x.shape[1]):
    #         xt = x[:, t, :]
    #         s = self.alpha * xt + (1 - self.alpha) * s
    #         res.append(s.unsqueeze(1))
    #     return torch.cat(res, dim=1)
```
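A quick shape check for the block above; the seasonal component is obtained here as the residual of the EMA trend, which is how moving-average decompositions are usually split, and the value of `alpha` is arbitrary for this example:

```python
import torch

x = torch.randn(32, 96, 7)             # [Batch, Input, Channel]
ema = EMA(alpha=0.3)                   # hypothetical smoothing factor
trend = ema(x)                         # smoothed series, interpreted as the trend
seasonal = x - trend                   # residual, interpreted as the seasonal component
print(trend.shape, seasonal.shape)     # torch.Size([32, 96, 7]) torch.Size([32, 96, 7])
```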
Implementation of the network architecture:
```python
import torch
from torch import nn


class Network(nn.Module):
    def __init__(self, seq_len, pred_len, patch_len, stride, padding_patch):
        super(Network, self).__init__()

        # Parameters
        self.pred_len = pred_len

        # Non-linear Stream
        # Patching
        self.patch_len = patch_len
        self.stride = stride
        self.padding_patch = padding_patch
        self.dim = patch_len * patch_len
        self.patch_num = (seq_len - patch_len) // stride + 1
        if padding_patch == 'end':  # can be modified to general case
            self.padding_patch_layer = nn.ReplicationPad1d((0, stride))
            self.patch_num += 1

        # Patch Embedding
        # Embed each patch into a higher-dimensional representation
        self.fc1 = nn.Linear(patch_len, self.dim)
        self.gelu1 = nn.GELU()
        self.bn1 = nn.BatchNorm1d(self.patch_num)

        # CNN Depthwise
        self.conv1 = nn.Conv1d(self.patch_num, self.patch_num,
                               patch_len, patch_len, groups=self.patch_num)
        self.gelu2 = nn.GELU()
        self.bn2 = nn.BatchNorm1d(self.patch_num)

        # Residual Stream
        self.fc2 = nn.Linear(self.dim, patch_len)

        # CNN Pointwise
        self.conv2 = nn.Conv1d(self.patch_num, self.patch_num, 1, 1)
        self.gelu3 = nn.GELU()
        self.bn3 = nn.BatchNorm1d(self.patch_num)

        # Flatten Head
        self.flatten1 = nn.Flatten(start_dim=-2)
        self.fc3 = nn.Linear(self.patch_num * patch_len, pred_len * 2)
        self.gelu4 = nn.GELU()
        self.fc4 = nn.Linear(pred_len * 2, pred_len)

        # Linear Stream
        # MLP
        self.fc5 = nn.Linear(seq_len, pred_len * 4)
        self.avgpool1 = nn.AvgPool1d(kernel_size=2)
        self.ln1 = nn.LayerNorm(pred_len * 2)

        self.fc6 = nn.Linear(pred_len * 2, pred_len)
        self.avgpool2 = nn.AvgPool1d(kernel_size=2)
        self.ln2 = nn.LayerNorm(pred_len // 2)

        self.fc7 = nn.Linear(pred_len // 2, pred_len)

        # Streams Concatenation
        self.fc8 = nn.Linear(pred_len * 2, pred_len)

    def forward(self, s, t):
        # s - seasonality, t - trend; both: [Batch, Input, Channel]
        # Reshape both streams so that the temporal dimension comes last,
        # which makes the subsequent patching operation easier
        # s_shape, t_shape after the permute: [Batch, Channel, Input] = [32, 862, 96]
        s = s.permute(0, 2, 1)  # to [Batch, Channel, Input]
        t = t.permute(0, 2, 1)  # to [Batch, Channel, Input]

        # Channel split for channel independence
        B = s.shape[0]  # Batch size
        C = s.shape[1]  # Channel size
        I = s.shape[2]  # Input size
        s = torch.reshape(s, (B * C, I))  # [Batch and Channel, Input]
        t = torch.reshape(t, (B * C, I))  # [Batch and Channel, Input]

        # Non-linear Stream
        # Patching
        if self.padding_patch == 'end':
            s = self.padding_patch_layer(s)
        s = s.unfold(dimension=-1, size=self.patch_len, step=self.stride)
        # s_after_patching: [Batch and Channel, Patch_num, Patch_len] = [27584, 12, 16]

        # Patch Embedding
        s = self.fc1(s)
        s = self.gelu1(s)
        s = self.bn1(s)
        # s_after_embedding: [Batch and Channel, Patch_num, Dim] = [27584, 12, 256]

        # Keep a copy for the residual connection
        res = s

        # CNN Depthwise
        s = self.conv1(s)
        s = self.gelu2(s)
        s = self.bn2(s)

        # Residual Stream
        res = self.fc2(res)
        s = s + res

        # CNN Pointwise
        s = self.conv2(s)
        s = self.gelu3(s)
        s = self.bn3(s)

        # Flatten Head
        s = self.flatten1(s)
        # s_after_flatten: [Batch and Channel, Patch_num * Patch_len] = [27584, 192]
        s = self.fc3(s)
        # s_after_head: [Batch and Channel, Pred_len * 2] = [27584, 192]
        s = self.gelu4(s)
        s = self.fc4(s)
        # s_after_output: [Batch and Channel, Pred_len] = [27584, 96]

        # Linear Stream
        # MLP
        t = self.fc5(t)
        t = self.avgpool1(t)
        t = self.ln1(t)
        t = self.fc6(t)
        t = self.avgpool2(t)
        t = self.ln2(t)
        t = self.fc7(t)

        # Streams Concatenation
        x = torch.cat((s, t), dim=1)
        x = self.fc8(x)

        # Channel concatenation: restore the original [Batch, Output, Channel] layout
        x = torch.reshape(x, (B, C, self.pred_len))  # [Batch, Channel, Output]
        x = x.permute(0, 2, 1)  # to [Batch, Output, Channel]
        return x
```
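Wiring the two modules together gives a minimal end-to-end sketch. The hyperparameter values below are placeholders chosen only so the shapes line up, and this covers just the two modules shown above rather than the full training pipeline:

```python
import torch

x = torch.randn(32, 96, 7)                      # [Batch, Input, Channel]
ema = EMA(alpha=0.3)                            # hypothetical smoothing factor
net = Network(seq_len=96, pred_len=96, patch_len=16,
              stride=8, padding_patch='end')    # hypothetical hyperparameters

trend = ema(x)                                  # trend component
seasonal = x - trend                            # seasonal component (residual)
y = net(seasonal, trend)                        # dual-stream forecast
print(y.shape)                                  # torch.Size([32, 96, 7]) -> [Batch, Output, Channel]
```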