时间序列模型(2):TimesNet

385 阅读4分钟

TimesNet

1.模型背景

TimesNet是一种基于深度学习的时序预测模型,其主要特点是能够捕捉时间序列中复杂的模式,并对未来进行准确的预测。传统方法在处理时间序列的复杂变化时面临挑战,因为它们通常试图直接从一维时间序列中捕捉这些变化,而忽略了时间序列的多周期性。TimesNet通过将一维时间序列重塑为二维张量来捕捉时间序列的多周期性,并使用卷积神经网络(CNN)来学习时间序列的模式。

1742523098804.png

2.模型结构

TimesNet的模型结构如下图所示:

1742523342268.png

TimesBlock

TimesBlock是TimesNet的核心模块,负责发现时间序列的多周期性,并从转换后的二维张量中提取复杂的时间变化。TimesBlock由以下模块组成:

(1)快速傅里叶变换

FFT是一种快速傅里叶变换算法,可以将一维时间序列转换为二维张量。将输入序列进行FFT变换,可以得到频域信号,即不同周期长度在时域上的频率分布。选取top k个周期长度,就可以将整体时间序列进行分割。

(2)reshape

经过FFT操作后,我们可以得到二维张量,如下图所示:

1742524526787.png

这种转换使得二维张量的列和行分别反映了周期内变化和周期间变化,其中每一列包含一个周期内的所有时间点,每一行包含不同周期中相同相位的时间点。

(3)Inception块

卷积过程采用参数高效的Inception块来处理二维张量,通过多尺度二维卷积核同时捕捉周期内变化和周期间变化。

(4)reshape back

经过Inception块处理后,最终将二维张量重塑回一维序列,得到最终的预测结果。

(5)Adaptive Aggregation

在处理完二维张量后,TimesBlock将结果重新转换回一维张量,并根据周期的幅度进行自适应聚合。即将不同周期的预测结果根据周期对应频率的振幅进行加权平均,以更好地捕 捉时间序列的多周期性。

3. 模型特点

TimesNet具有以下特点:

  1. 提出了周期内与周期间的理念。(intraperiod-variation and interperiod-variation)
  2. 将时序变化基于多周期转换为二维张量。(we extend  the analysis of temporal variations into the 2D space by transforming the 1D time  series into a set of 2D tensors based on multiple periods.)

4. 源码解析(traffic)

预测任务:

def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
    # Normalization from Non-stationary Transformer
    means = x_enc.mean(1, keepdim=True).detach()
    x_enc = x_enc - means
    stdev = torch.sqrt(
        torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
    x_enc /= stdev
    print(x_enc.shape) # [32, 96, 862]
    # embedding
    enc_out = self.enc_embedding(x_enc, x_mark_enc)  # [B,T,C]
    print(enc_out.shape) # [32, 96, 512]
    enc_out = self.predict_linear(enc_out.permute(0, 2, 1)).permute(
        0, 2, 1)  # align temporal dimension
    print(enc_out.shape) # [32, 192, 512]
    # TimesNet
    for i in range(self.layer):
        enc_out = self.layer_norm(self.model[i](enc_out))
    # porject back
    dec_out = self.projection(enc_out)


    # De-Normalization from Non-stationary Transformer
    dec_out = dec_out * \
                (stdev[:, 0, :].unsqueeze(1).repeat(
                    1, self.pred_len + self.seq_len, 1))
    dec_out = dec_out + \
                (means[:, 0, :].unsqueeze(1).repeat(
                    1, self.pred_len + self.seq_len, 1))
    return dec_out

模型结构:

def FFT_for_Period(x, k=2):
    # [B, T, C]
    xf = torch.fft.rfft(x, dim=1)
    # find period by amplitudes
    frequency_list = abs(xf).mean(0).mean(-1)
    # 屏蔽直流fenquency
    frequency_list[0] = 0
    # 获取top k个周期长度
    _, top_list = torch.topk(frequency_list, k)
    top_list = top_list.detach().cpu().numpy()
    # 计算一个时间序列包含的周期数,T // top_list
    period = x.shape[1] // top_list
    # 返回周期数和周期权重
    return period, abs(xf).mean(-1)[:, top_list]



class TimesBlock(nn.Module):
    def __init__(self, configs):
        super(TimesBlock, self).__init__()
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.k = configs.top_k
        # parameter-efficient design
        self.conv = nn.Sequential(
            Inception_Block_V1(configs.d_model, configs.d_ff,
                               num_kernels=configs.num_kernels),
            nn.GELU(),
            Inception_Block_V1(configs.d_ff, configs.d_model,
                               num_kernels=configs.num_kernels)
        )


    def forward(self, x):
        B, T, N = x.size()
        period_list, period_weight = FFT_for_Period(x, self.k)


        res = []
        for i in range(self.k):
            # 获取周期数
            period = period_list[i]
            # padding
            if (self.seq_len + self.pred_len) % period != 0:
                length = (
                                 ((self.seq_len + self.pred_len) // period) + 1) * period
                padding = torch.zeros([x.shape[0], (length - (self.seq_len + self.pred_len)), x.shape[2]]).to(x.device)
                out = torch.cat([x, padding], dim=1)
            else:
                length = (self.seq_len + self.pred_len)
                out = x
            # reshape, [32, 512, 39, 5]
            out = out.reshape(B, length // period, period,
                              N).permute(0, 3, 1, 2).contiguous()
            # 2D conv: from 1d Variation to 2d Variation
            out = self.conv(out)
            # reshape back
            out = out.permute(0, 2, 3, 1).reshape(B, -1, N) # [32, 192, 512]
            res.append(out[:, :(self.seq_len + self.pred_len), :])
        res = torch.stack(res, dim=-1)
        # adaptive aggregation,将不同周期的输出加权求和
        period_weight = F.softmax(period_weight, dim=1)
        period_weight = period_weight.unsqueeze(
            1).unsqueeze(1).repeat(1, T, N, 1)
        res = torch.sum(res * period_weight, -1)
        # residual connection
        res = res + x
        return res