TimesNet
1.模型背景
TimesNet是一种基于深度学习的时序预测模型,其主要特点是能够捕捉时间序列中复杂的模式,并对未来进行准确的预测。传统方法在处理时间序列的复杂变化时面临挑战,因为它们通常试图直接从一维时间序列中捕捉这些变化,而忽略了时间序列的多周期性。TimesNet通过将一维时间序列重塑为二维张量来捕捉时间序列的多周期性,并使用卷积神经网络(CNN)来学习时间序列的模式。
2.模型结构
TimesNet的模型结构如下图所示:
TimesBlock
TimesBlock是TimesNet的核心模块,负责发现时间序列的多周期性,并从转换后的二维张量中提取复杂的时间变化。TimesBlock由以下模块组成:
(1)快速傅里叶变换
FFT是一种快速傅里叶变换算法,可以将一维时间序列转换为二维张量。将输入序列进行FFT变换,可以得到频域信号,即不同周期长度在时域上的频率分布。选取top k个周期长度,就可以将整体时间序列进行分割。
(2)reshape
经过FFT操作后,我们可以得到二维张量,如下图所示:
这种转换使得二维张量的列和行分别反映了周期内变化和周期间变化,其中每一列包含一个周期内的所有时间点,每一行包含不同周期中相同相位的时间点。
(3)Inception块
卷积过程采用参数高效的Inception块来处理二维张量,通过多尺度二维卷积核同时捕捉周期内变化和周期间变化。
(4)reshape back
经过Inception块处理后,最终将二维张量重塑回一维序列,得到最终的预测结果。
(5)Adaptive Aggregation
在处理完二维张量后,TimesBlock将结果重新转换回一维张量,并根据周期的幅度进行自适应聚合。即将不同周期的预测结果根据周期对应频率的振幅进行加权平均,以更好地捕 捉时间序列的多周期性。
3. 模型特点
TimesNet具有以下特点:
- 提出了周期内与周期间的理念。(intraperiod-variation and interperiod-variation)
- 将时序变化基于多周期转换为二维张量。(we extend the analysis of temporal variations into the 2D space by transforming the 1D time series into a set of 2D tensors based on multiple periods.)
4. 源码解析(traffic)
预测任务:
def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
# Normalization from Non-stationary Transformer
means = x_enc.mean(1, keepdim=True).detach()
x_enc = x_enc - means
stdev = torch.sqrt(
torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
x_enc /= stdev
print(x_enc.shape) # [32, 96, 862]
# embedding
enc_out = self.enc_embedding(x_enc, x_mark_enc) # [B,T,C]
print(enc_out.shape) # [32, 96, 512]
enc_out = self.predict_linear(enc_out.permute(0, 2, 1)).permute(
0, 2, 1) # align temporal dimension
print(enc_out.shape) # [32, 192, 512]
# TimesNet
for i in range(self.layer):
enc_out = self.layer_norm(self.model[i](enc_out))
# porject back
dec_out = self.projection(enc_out)
# De-Normalization from Non-stationary Transformer
dec_out = dec_out * \
(stdev[:, 0, :].unsqueeze(1).repeat(
1, self.pred_len + self.seq_len, 1))
dec_out = dec_out + \
(means[:, 0, :].unsqueeze(1).repeat(
1, self.pred_len + self.seq_len, 1))
return dec_out
模型结构:
def FFT_for_Period(x, k=2):
# [B, T, C]
xf = torch.fft.rfft(x, dim=1)
# find period by amplitudes
frequency_list = abs(xf).mean(0).mean(-1)
# 屏蔽直流fenquency
frequency_list[0] = 0
# 获取top k个周期长度
_, top_list = torch.topk(frequency_list, k)
top_list = top_list.detach().cpu().numpy()
# 计算一个时间序列包含的周期数,T // top_list
period = x.shape[1] // top_list
# 返回周期数和周期权重
return period, abs(xf).mean(-1)[:, top_list]
class TimesBlock(nn.Module):
def __init__(self, configs):
super(TimesBlock, self).__init__()
self.seq_len = configs.seq_len
self.pred_len = configs.pred_len
self.k = configs.top_k
# parameter-efficient design
self.conv = nn.Sequential(
Inception_Block_V1(configs.d_model, configs.d_ff,
num_kernels=configs.num_kernels),
nn.GELU(),
Inception_Block_V1(configs.d_ff, configs.d_model,
num_kernels=configs.num_kernels)
)
def forward(self, x):
B, T, N = x.size()
period_list, period_weight = FFT_for_Period(x, self.k)
res = []
for i in range(self.k):
# 获取周期数
period = period_list[i]
# padding
if (self.seq_len + self.pred_len) % period != 0:
length = (
((self.seq_len + self.pred_len) // period) + 1) * period
padding = torch.zeros([x.shape[0], (length - (self.seq_len + self.pred_len)), x.shape[2]]).to(x.device)
out = torch.cat([x, padding], dim=1)
else:
length = (self.seq_len + self.pred_len)
out = x
# reshape, [32, 512, 39, 5]
out = out.reshape(B, length // period, period,
N).permute(0, 3, 1, 2).contiguous()
# 2D conv: from 1d Variation to 2d Variation
out = self.conv(out)
# reshape back
out = out.permute(0, 2, 3, 1).reshape(B, -1, N) # [32, 192, 512]
res.append(out[:, :(self.seq_len + self.pred_len), :])
res = torch.stack(res, dim=-1)
# adaptive aggregation,将不同周期的输出加权求和
period_weight = F.softmax(period_weight, dim=1)
period_weight = period_weight.unsqueeze(
1).unsqueeze(1).repeat(1, T, N, 1)
res = torch.sum(res * period_weight, -1)
# residual connection
res = res + x
return res