BasicTS 项目入门指南:从零开始掌握时间序列分析

57 阅读39分钟

时间序列分析是机器学习领域中最具实用价值的方向之一——从股票价格预测、天气预报,到交通流量监控、能源负荷管理,时间序列无处不在。然而,对于初学者而言,这个领域往往面临一个尴尬的困境:论文中的模型五花八门,代码实现各不相同,想要公平地对比不同方法几乎是一场噩梦。

BasicTS 正是为解决这一痛点而生。

作为一个公平、可扩展的时间序列分析基准库,BasicTS 集成了 30+ 种主流深度学习模型(从经典的 Informer、Autoformer 到最新的 iTransformer、TimeMixer)和 20+ 个标准数据集,覆盖长期预测(LTSF)、时空预测(STF)、时序分类、缺失值插补等核心任务。更重要的是,它提供了统一的训练和评估流程,让你可以用三行代码快速启动实验,专注于模型创新而非工程细节。

无论你是刚入门时序分析的新手,还是需要快速复现论文结果的研究者,这篇指南都将带你从安装配置到模型训练,全面掌握 BasicTS 的使用方法。

让我们开始吧。

介绍

BasicTS (Basic Time Series) 是一个面向时间序列分析的基准库和工具箱,现已支持时空预测、长序列预测、分类、插补等多种任务与数据集,涵盖统计模型、机器学习模型、深度学习模型等多类算法,为开发和评估时间序列预测模型提供了理想的工具。

链接:GestaltCogTeam/BasicTS: A Fair and Scalable Time Series Forecasting Benchmark and Toolkit.

📁 项目结构

BasicTS/
├── basicts/
│   ├── configs/          # 配置类
│   │   ├── tsf_config.py     # 预测任务配置
│   │   ├── tsc_config.py     # 分类任务配置
│   │   └── tsi_config.py     # 插补任务配置
│   ├── data/             # 数据集类
│   ├── metrics/          # 评估指标
│   ├── models/           # 模型实现
│   │   ├── DLinear/
│   │   ├── iTransformer/
│   │   └── ...
│   ├── modules/          # 可复用组件
│   │   ├── transformer/      # Transformer 组件
│   │   ├── norm/             # 归一化层
│   │   └── embed/            # 嵌入层
│   ├── runners/          # 执行器
│   │   ├── callback/         # 回调函数
│   │   └── taskflow/         # 任务流
│   └── scaler/           # 数据缩放器
├── datasets/             # 数据集存放目录
├── checkpoints/          # 模型保存目录
└── examples/             # 示例代码

最新动态

🎉 更新(2025年10月):BasicTS 内置支持选择学习(NeurIPS'25),一种有效缓解过拟合,增加模型性能和泛化性的训练策略。用户可以从回调模块中导入并直接使用。使用说明

🎉 更新(2025年10月):BasicTS 1.0版本发布了!新特性:

  • 🚀 三行代码,快速上手:pip install 安装,极简 API 设计,快速实现模型训练与评估。
  • 📦 模块化组件,开箱即用:提供 Transformer、MLP 等即插即用的组件,像搭积木一样构建自己的模型。
  • 🔄 多任务支持:支持时序预测、分类、插补等多个时序分析核心任务。
  • 🔧 高可扩展架构:基于 Taskflow 与 Callback 机制,无需修改 Runner 即可轻松定制。

🎉 更新(2025年5月): BasicTS 现已支持使用 BLAST (KDD'25) 语料库训练通用预测模型(例如 TimeMoE 和 ChronosBolt)。BLAST 能够实现 更快的收敛速度、显著降低计算成本,并且即使在资源有限的情况下也能获得卓越性能。

✨ 主要功能亮点

BasicTS 一方面通过 统一且标准化的流程,为热门的深度学习模型提供了 公平且全面 的复现与对比平台。另一方面,BasicTS 提供了用户 友好且易于扩展 的接口,帮助快速设计和评估新模型。用户只需定义模型结构,便可轻松完成基本操作。

公平的性能评估:通过统一且全面的流程,用户能够公平且充分地对比不同模型在任意数据集上的性能表现。

使用 BasicTS 进行开发你可以:

最简代码实现

  • 用户只需实现关键部分如模型架构、数据预处理和后处理,即可构建自己的深度学习项目。

基于配置文件控制一切

  • 用户可以通过配置文件掌控流程中的所有细节,包括数据加载器的超参数、优化策略以及其他技巧(如课程学习)。

支持所有设备

  • BasicTS 支持 CPU、GPU 以及分布式 GPU 训练(单节点多 GPU 和多节点),依托 EasyTorch 作为后端。用户只需通过设置参数即可使用这些功能,无需修改代码。

保存训练日志

  • BasicTS 提供 logging 日志系统和 Tensorboard 支持,并统一封装接口,用户可以通过简便的接口调用来保存自定义的训练日志。

📦 支持的模型

BasicTS 实现了丰富的基线模型,包括经典模型、时空预测模型、长序列预测模型、通用预测模型等。

这些模型的代码实现可在 baselines 目录中找到。

下表中的代码链接(💻Code) 指向了相关论文的官方实现,感谢各位作者对代码的开源贡献!

通用预测模型

UFM = Universal Forecasting Models(通用预测模型)

这是一类预训练的大规模时间序列基础模型,类似于 NLP 领域的 GPT、BERT。它们的特点是:

特性说明
预训练在海量时间序列数据上预训练
零样本/少样本无需或只需少量微调即可用于新数据集
跨领域可以处理不同领域的时间序列(天气、交通、金融等)
大规模参数量通常在数十亿级别

UFM vs 传统模型的区别

传统模型 (如 DLinear, iTransformer):

数据集 A \rightarrow 训练模型 \rightarrow 预测 A,每个数据集单独训练。

UFM 通用模型:

海量数据预训练 \rightarrow 预训练模型 \rightarrow 直接用于任意数据集,零样本或微调。


📊Baseline📝Title📄Paper💻Code🏛Venue🎯Task
TimeMoETime-MoE: Billion-Scale Time Series Foundation Models with Mixture of ExpertsLinkLinkICLR'25UFM
ChronosBoltChronos: Learning the Language of Time SeriesLinkLinkTMLR'24UFM
MOIRAI (inference)Unified Training of Universal Time Series Forecasting TransformersLinkLinkICML'24UFM

时空预测

什么是 STF?

STF = Spatial-Temporal Forecasting(时空预测)

这是一类需要同时考虑空间关系和时间依赖的预测任务。

核心概念:

普通时间序列预测 (LTSF):
    时间 →
    ┌───┬───┬───┬───┬───┐
    │ t1│ t2│ t3│ t4│ ? │  ← 单个序列,只考虑时间
    └───┴───┴───┴───┴───┘

时空预测 (STF):
    时间 →
    ┌───┬───┬───┬───┬───┐
节点A │ t1│ t2│ t3│ t4│ ? │
    ├───┼───┼───┼───┼───┤
节点B │ t1│ t2│ t3│ t4│ ? │  ← 多个节点 + 节点间有空间关系
    ├───┼───┼───┼───┼───┤
节点C │ t1│ t2│ t3│ t4│ ? │
    └───┴───┴───┴───┴───┘
         ↑
    节点之间通过图(Graph)连接

📊Baseline📝Title📄Paper💻Code🏛Venue🎯Task
STDNSpatiotemporal-aware Trend-Seasonality Decomposition Network for Traffic Flow ForecastingLinkLinkAAAI'25STF
HimNetHeterogeneity-Informed Meta-Parameter Learning for Spatiotemporal Time Series ForecastingLinkLinkSIGKDD'24STF
DFDGCNDynamic Frequency Domain Graph Convolutional Network for Traffic ForecastingLinkLinkICASSP'24STF
STPGNNSpatio-Temporal Pivotal Graph Neural Networks for Traffic Flow ForecastingLinkLinkAAAI'24STF
BigSTLinear Complexity Spatio-Temporal Graph Neural Network for Traffic Forecasting on Large-Scale Road NetworksLinkLinkVLDB'24STF
STDMAESpatio-Temporal-Decoupled Masked Pre-training for Traffic ForecastingLinkLinkIJCAI'24STF
STWaveWhen Spatio-Temporal Meet Wavelets: Disentangled Traffic Forecasting via Efficient Spectral Graph Attention NetworksLinkLinkICDE'23STF
STAEformerSpatio-Temporal Adaptive Embedding Makes Vanilla Transformer SOTA for Traffic ForecastingLinkLinkCIKM'23STF
MegaCRNSpatio-Temporal Meta-Graph Learning for Traffic ForecastingLinkLinkAAAI'23STF
DGCRNDynamic Graph Convolutional Recurrent Network for Traffic Prediction: Benchmark and SolutionLinkLinkACM TKDD'23STF
STIDSpatial-Temporal Identity: A Simple yet Effective Baseline for Multivariate Time Series ForecastingLinkLinkCIKM'22STF
STEPPretraining Enhanced Spatial-temporal Graph Neural Network for Multivariate Time Series ForecastingLinkLinkSIGKDD'22STF
D2STGNNDecoupled Dynamic Spatial-Temporal Graph Neural Network for Traffic ForecastingLinkLinkVLDB'22STF
STNormSpatial and Temporal Normalization for Multi-variate Time Series ForecastingLinkLinkSIGKDD'21STF
STGODESpatial-Temporal Graph ODE Networks for Traffic Flow ForecastingLinkLinkSIGKDD'21STF
GTSDiscrete Graph Structure Learning for Forecasting Multiple Time SeriesLinkLinkICLR'21STF
StemGNNSpectral Temporal Graph Neural Network for Multivariate Time-series ForecastingLinkLinkNeurIPS'20STF
MTGNNConnecting the Dots: Multivariate Time Series Forecasting with Graph Neural NetworksLinkLinkSIGKDD'20STF
AGCRNAdaptive Graph Convolutional Recurrent Network for Traffic ForecastingLinkLinkNeurIPS'20STF
GWNetGraph WaveNet for Deep Spatial-Temporal Graph ModelingLinkLinkIJCAI'19STF
STGCNSpatio-Temporal Graph Convolutional Networks: A Deep Learning Framework for Traffic ForecastingLinkLinkIJCAI'18STF
DCRNNDiffusion Convolutional Recurrent Neural Network: Data-Driven Traffic ForecastingLinkLink1, Link2ICLR'18STF

Long-Term Time Series Forecasting

什么是 LTSF?

LTSF = Long-Term Time Series Forecasting(长期时间序列预测)

这是时间序列预测中最基础、最常见的任务:根据历史数据预测未来较长时间的数值。

LTSF 的特点

特点说明
长期预测长度通常为 96、192、336、720 个时间步
多变量通常同时预测多个变量(如温度、湿度、风速等)
无图结构不考虑变量之间的空间拓扑关系
挑战长期依赖建模、误差累积、非平稳性

📊Baseline📝Title📄Paper💻Code🏛Venue🎯Task
S-D-MambaIs Mamba Effective for Time Series Forecasting?LinkLinkNeuroComputing'24LTSF
Bi-MambaBi-Mamba+: Bidirectional Mamba for Time Series ForecastingLinkLinkarXiv'24LTSF
ModernTCNModernTCN: A Modern Pure Convolution Structure for General Time Series AnalysisLinkLinkICLR'24LTSF
TimeXerTimeXer: Empowering Transformers for Time Series Forecasting with Exogenous VariablesLinkLinkNeurIPS'24LTSF
CARDCARD: Channel Aligned Robust Blend Transformer for Time Series ForecastingLinkLinkICLR'24LTSF
SOFTSSOFTS: Efficient Multivariate Time Series Forecasting with Series-Core FusionLinkLinkNeurIPS'24LTSF
CATSAre Self-Attentions Effective for Time Series Forecasting?LinkLinkNeurIPS'24LTSF
SumbaStructured Matrix Basis for Multivariate Time Series Forecasting with Interpretable DynamicsLinkLinkNeurIPS'24LTSF
GLAFFRethinking the Power of Timestamps for Robust Time Series Forecasting: A Global-Local Fusion PerspectiveLinkLinkNeurIPS'24LTSF
CycleNetCycleNet: Enhancing Time Series Forecasting through Modeling Periodic Patterns ForecastingLinkLinkNeurIPS'24LTSF
FredformerFredformer: Frequency Debiased Transformer for Time Series ForecastingLinkLinkKDD'24LTSF
UMixerAn Unet-Mixer Architecture with Stationarity Correction for Time Series ForecastingLinkLinkAAAI'24LTSF
TimeMixerDecomposable Multiscale Mixing for Time Series ForecastingLinkLinkICLR'24LTSF
Time-LLMTime-LLM: Time Series Forecasting by Reprogramming Large Language ModelsLinkLinkICLR'24LTSF
SparseTSFModeling LTSF with 1k ParametersLinkLinkICML'24LTSF
iTrainsformerInverted Transformers Are Effective for Time Series ForecastingLinkLinkICLR'24LTSF
KoopaLearning Non-stationary Time Series Dynamics with Koopman PredictorsLinkLinkNeurIPS'24LTSF
CrossGNNCrossGNN: Confronting Noisy Multivariate Time Series Via Cross Interaction RefinementLinkLinkNeurIPS'23LTSF
NLinearAre Transformers Effective for Time Series Forecasting?LinkLinkAAAI'23LTSF
CrossformerTransformer Utilizing Cross-Dimension Dependency for Multivariate Time Series ForecastingLinkLinkICLR'23LTSF
DLinearAre Transformers Effective for Time Series Forecasting?LinkLinkAAAI'23LTSF
DSformerA Double Sampling Transformer for Multivariate Time Series Long-term PredictionLinkLinkCIKM'23LTSF
SegRNNSegment Recurrent Neural Network for Long-Term Time Series ForecastingLinkLinkarXivLTSF
MTS-MixersMultivariate Time Series Forecasting via Factorized Temporal and Channel MixingLinkLinkarXivLTSF
LightTSFast Multivariate Time Series Forecasting with Light Sampling-oriented MLPLinkLinkarXivLTSF
ETSformerExponential Smoothing Transformers for Time-series ForecastingLinkLinkarXivLTSF
NHiTSNeural Hierarchical Interpolation for Time Series ForecastingLinkLinkAAAI'23LTSF
PatchTSTA Time Series is Worth 64 Words: Long-term Forecasting with TransformersLinkLinkICLR'23LTSF
TiDELong-term Forecasting with TiDE: Time-series Dense EncoderLinkLinkTMLR'23LTSF
S4Efficiently Modeling Long Sequences with Structured State SpacesLinkLinkICLR'22LTSF
TimesNetTemporal 2D-Variation Modeling for General Time Series AnalysisLinkLinkICLR'23LTSF
TriformerTriangular, Variable-Specific Attentions for Long Sequence Multivariate Time Series ForecastingLinkLinkIJCAI'22LTSF
NSformerExploring the Stationarity in Time Series ForecastingLinkLinkNeurIPS'22LTSF
FiLMFrequency improved Legendre Memory Model for LTSFLinkLinkNeurIPS'22LTSF
FEDformerFrequency Enhanced Decomposed Transformer for Long-term Series ForecastingLinkLinkICML'22LTSF
PyraformerLow complexity pyramidal Attention For Long-range Time Series Modeling and ForecastingLinkLinkICLR'22LTSF
HIHistorical Inertia: A Powerful Baseline for Long Sequence Time-series ForecastingLinkNoneCIKM'21LTSF
AutoformerDecomposition Transformers with Auto-Correlation for Long-Term Series ForecastingLinkLinkNeurIPS'21LTSF
InformerBeyond Efficient Transformer for Long Sequence Time-Series ForecastingLinkLinkAAAI'21LTSF

其他方法

📊Baseline📝Title📄Paper💻Code🏛Venue🎯Task
CatBoostCatboost: unbiased boosting with categorical featuresLinkLinkNeurIPS'18Machine Learning
LightGBMLightGBM: A Highly Efficient Gradient Boosting Decision TreeLinkLinkNeurIPS'17Machine Learning
NBeatsNeural basis expansion analysis for interpretable time series forecastingLinkLink1, Link2ICLR'19Deep Time Series Forecasting
DeepARProbabilistic Forecasting with Autoregressive Recurrent NetworksLinkLink1, Link2, Link3Int. J. Forecast'20Probabilistic Time Series Forecasting
WaveNetWaveNet: A Generative Model for Raw Audio.LinkLink 1, Link 2arXivAudio
ARVII. On a method of investigating periodicities disturbed series, with special reference to Wolfer's sunspot numbersLinkLink1927Local Forecasting
MAOn periodicity in series of related termsLinkLink1931Local Forecasting
ARMASome recent advances in forecasting and controlLinkLinkApplied Statistics'1968Local Forecasting
ARIMAForecasting with exponential smoothing: the state space approachLinkLink2008Local Forecasting
SARIMAForecasting with exponential smoothing: the state space approachLinkLink2008Local Forecasting
ARCHConditional heteroscedasticity in time series of stock returns: Evidence and forecastsLinkLinkJournal of business'1989Local Forecasting
GARCHConditional heteroscedasticity in time series of stock returns: Evidence and forecastsLinkLinkJournal of business'1989Local Forecasting
ETSThe holt-winters forecasting procedureLinkLinkApplied Statistics'1978Local Forecasting
SESThe holt-winters forecasting procedureLinkLinkApplied Statistics'1978Local Forecasting
SVRSupport vector regression machinesLinkLinkNeurIPS'1996Machine Learning
PolySVRA training algorithm for optimal margin classifiersLinkLinkCOLT'1992Machine Learning

📦 支持的数据集

BasicTS 支持多种类型的数据集,涵盖时空预测、长序列预测及大规模数据集。

数据集表格字段说明

字段含义示例
Length时间序列的总长度(时间步数)14400 表示有 14400 个时间点
Time Series Count变量/节点数量7 表示有 7 个变量同时被记录
Graph是否包含图结构(节点间的空间关系)True 表示有邻接矩阵
Freq. (m)采样频率(分钟)60 表示每小时采样一次

例如,ETTh1 数据集:

  • 14400 个时间步 × 60 分钟/步 = 600 天的数据
  • 7 个变量:包括油温、负载等电力变压器相关指标

时空预测

🏷️Name🌐Domain📏Length📊Time Series Count🔄Graph⏱️Freq. (m)🎯Task
METR-LATraffic Speed34272207True5STF
PEMS-BAYTraffic Speed52116325True5STF
PEMS03Traffic Flow26208358True5STF
PEMS04Traffic Flow16992307True5STF
PEMS07Traffic Flow28224883True5STF
PEMS08Traffic Flow17856170True5STF

长序列预测

🏷️Name🌐Domain📏Length📊Time Series Count🔄Graph⏱️Freq. (m)🎯Task
BeijingAirQualityBeijing Air Quality360007False60LTSF
ETTh1Electricity Transformer Temperature144007False60LTSF
ETTh2Electricity Transformer Temperature144007False60LTSF
ETTm1Electricity Transformer Temperature576007False15LTSF
ETTm2Electricity Transformer Temperature576007False15LTSF
ElectricityElectricity Consumption26304321False60LTSF
ExchangeRateExchange Rate75888False1440LTSF
IllnessIlness Data9667False10080LTSF
TrafficRoad Occupancy Rates17544862False60LTSF
WeatherWeather5269621False10LTSF

大规模数据集

🏷️Name🌐Domain📏Length📊Time Series Count🔄Graph⏱️Freq. (m)🎯Task
CATraffic Flow350408600True15Large Scale
GBATraffic Flow350402352True15Large Scale
GLATraffic Flow350403834True15Large Scale
SDTraffic Flow35040716True15Large Scale

Pre-training Corpus

🏷️Name🌐Domain📏Length📊Time Series Count🔄Graph⏱️Freq.🎯Task
BLASTMultiple409620000000FalseMultipleUFM

🔗 EasyTorch

BasicTS 是基于 EasyTorch 开发的,这是一个易于使用且功能强大的开源神经网络训练框架。

EasyTorch 的主要特点:

特点说明
简化训练流程封装了 PyTorch 训练的常见模式,减少样板代码
分布式训练支持内置 DDP(分布式数据并行)支持,一行代码启动多 GPU 训练
设备管理统一的设备管理(CPU/GPU/MLU),自动处理数据迁移
检查点管理提供 save_ckptload_ckptbackup_last_ckpt 等便捷函数
日志系统统一的 Logger 接口,支持 TensorBoard
环境配置随机种子设置、确定性训练、TF32 模式等

EasyTorch 提供的核心功能

设备管理
from easytorch.device import set_device_type, to_device
set_device_type("gpu")  # 设置使用 GPU
data = to_device(data)  # 自动将数据移到正确设备
分布式训练
from easytorch.launcher.dist_wrap import dist_wrap
# 自动包装为分布式训练
train_dist = dist_wrap(
    training_func,
    node_num=1,        # 节点数
    device_num=4,      # GPU 数量
    dist_backend="nccl"
)
train_dist(cfg)
检查点管理
from easytorch.core.checkpoint import save_ckpt, load_ckpt
save_ckpt(model, optimizer, epoch, path)  # 保存检查点
load_ckpt(model, path)                     # 加载检查点
工具函数
from easytorch.utils import get_logger, is_master, master_only
logger = get_logger("MyProject")  # 获取日志器
@master_only  # 只在主进程执行
def save_results():
    ...
if is_master():  # 判断是否为主进程
    print("Main process")
环境设置
from easytorch.utils.env import setup_determinacy, set_tf32_mode

setup_determinacy(seed=42)  # 设置随机种子,确保可复现
set_tf32_mode(True)         # 启用 TF32 加速
总结

EasyTorch 的定位:

PyTorch (底层) → EasyTorch (训练框架) → BasicTS (时间序列工具包)

EasyTorch 负责:设备管理、分布式训练、检查点、日志 BasicTS 负责:时间序列数据、模型、评估指标

快速上手

📦 安装 BasicTS

建议在 Linux 系统(如 Ubuntu 或 CentOS)上,在 Python 3.8 或更高版本上安装 BasicTS:

pip install basicts

推荐使用 Miniconda 或 Anaconda 来创建虚拟 Python 环境。

🔧 安装依赖项

PyTorch

BasicTS 对 PyTorch 版本非常灵活。您可以根据 Python 版本安装 PyTorch。我们建议使用 pip 进行安装。

示例设置

示例 1:Python 3.11 + PyTorch 2.5.1 + CUDA 12.4 (推荐)

# 安装 Python
conda create -n BasicTS python=3.11
conda activate BasicTS
# 安装 PyTorch
pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124

示例 2:Python 3.9 + PyTorch 1.10.0 + CUDA 11.1

# 安装 Python
conda create -n BasicTS python=3.9
conda activate BasicTS
# 安装 PyTorch
pip install torch==1.10.0+cu111 torchvision==0.11.0+cu111 torchaudio==0.10.0 -f https://download.pytorch.org/whl/torch_stable.html
🔍 下载数据集

请先从 Google Drive or 百度网盘 下载 datasets.zip 文件。将文件解压到 datasets/ 目录:

cd /path/to/YourProject # not BasicTS/basicts
unzip /path/to/datasets.zip -d datasets/

这些数据集已预处理完毕,可以直接使用。

data.dat 文件是以 numpy.memmap 格式存储的数组,包含原始时间序列数据,形状为 [L, N, C],其中 L 是时间步数,N 是时间序列数,C 是特征数。

desc.json 文件是一个字典,存储了数据集的元数据,包括数据集名称、领域、频率、特征描述、常规设置和缺失值。

其他文件是可选的,可能包含附加信息,如表示时间序列间预定义图结构的 adj_mx.pkl

如果您对预处理步骤感兴趣,可以参考预处理脚本blob/master/scripts/data_preparation raw_data.zip

🎯 快速教程:三句代码训练并评估您的模型

# train.py

from basicts.models.DLinear import DLinear, DLinearConfig
from basicts.configs import BasicTSForecastingConfig
from basicts import BasicTSLauncher

def main():

    # 1. 配置模型
    model_config = DLinearConfig(input_len=336, output_len=336)

    # 2. 配置任务
    cfg = BasicTSForecastingConfig(
        model=DLinear,
        model_config=model_config,
        dataset_name="ETTh1",
        gpus="0",
        ...
    )

    # 3.启动训练
    BasicTSLauncher.launch_training(cfg)

第一步:配置您的模型 BasicTS在basicts.models中提供了大量常用模型,您可以直接使用。BasicTS使用配置类来配置模型,每个配置类包含了对构造模型需要的每个参数的详细描述。例如,DLinear模型的配置类为 DLinearConfig。您可以在 dinear_config.py 中找到DLinearConfig类。

如果您想要使用自己的模型,则需要遵循 BasicTS 的规范,详情请见下面:🧠模型设计

第二步:配置您的任务

BasicTS 支持多种时间序列任务,包括预测、插补、分类等。任务配置类是 BasicTS的核心,关于 BasicTS 任务的一切信息都囊括在了任务配置类中,几乎全部配置项都有常用的默认值,您只需配置关键的参数(1️⃣模型,2️⃣数据集),并修改少量配置(如batch size,学习率等),就能运行代码。

您可以在basicts/configs 中找到每个 BasicTS 任务的配置类 (例如,预测任务的配置类为BasicTSForecastingConfig),及其每个参数的含义与配置方法。

进一步,在 BasicTS 的配置类中,您还可以指定回调(callbacks)和任务流(taskflow)用于在训练过程中执行额外的操作(如课程学习)和自定义数据处理流程。关于BasicTS配置类的进阶用法,请见:📜 配置设计

第三步:启动训练

BasicTSLauncher.launch_training 是训练的入口点,调用该方法并传入任务配置即可启动训练。

需要注意的是,在DDP模式下,BasicTSLauncher.launch_training 需要被包裹在if __name__ == '__main__':中,以确保每个进程都能正确初始化模型和数据集。

🥳 运行它!

完整可运行示例

# 完整可运行示例
from basicts.models.DLinear import DLinear, DLinearConfig
from basicts.configs import BasicTSForecastingConfig
from basicts import BasicTSLauncher

def main():
    # 1. 配置模型
    model_config = DLinearConfig(input_len=96, output_len=96)

    # 2. 配置任务
    cfg = BasicTSForecastingConfig(
        model=DLinear,
        model_config=model_config,
        dataset_name="ETTh1",
        gpus="0",
        batch_size=32,
        num_epochs=100,
        learning_rate=0.001,
    )

    # 3. 启动训练
    BasicTSLauncher.launch_training(cfg)

if __name__ == "__main__":
    main()

在您项目的目录下,运行以下命令即可启动训练:

python train.py

在训练中,BasicTS会默认将训练好的模型保存到 checkpoints/ 目录下,并在训练完成后执行评估(可以通过配置更改),您也可以选择将评估指标和结果保存到 checkpoints/ 目录下。

如何评估您的模型

当然,您也可以在训练结束后手动评估模型:BasicTSLauncher.launch_evaluation 是评估的入口点,您可以通过执行下面的Python代码来评估您的模型。

BasicTSLauncher.launch_evaluation(cfg, "checkpoints/your_checkpoint.pt")

函数签名

BasicTSLauncher.launch_evaluation(
    cfg,                    # 配置对象
    ckpt_path,              # 检查点文件路径
    gpus=None,              # 可选:使用的 GPU
    batch_size=None         # 可选:评估时的批次大小
)

参数解释

参数类型说明
cfgBasicTSConfig任务配置(与训练时相同)
ckpt_pathstr训练保存的模型检查点文件路径
gpusstr使用哪些 GPU,如 "0""0,1"
batch_sizeint评估时的批次大小(可覆盖配置中的值

基本用法

from basicts.models.DLinear import DLinear, DLinearConfig
from basicts.configs import BasicTSForecastingConfig
from basicts import BasicTSLauncher

# 1. 配置(与训练时相同)
model_config = DLinearConfig(input_len=96, output_len=96)
cfg = BasicTSForecastingConfig(
    model=DLinear,
    model_config=model_config,
    dataset_name="ETTh1",
)

# 2. 评估已训练的模型
BasicTSLauncher.launch_evaluation(
    cfg, 
    ckpt_path="checkpoints/DLinear_ETTh1/best_model.pt"
)

🧑‍💻 进一步探索

本教程为您提供了 BasicTS 的基础知识,但还有更多内容等待您探索。在深入其他主题之前,我们先更详细地了解 BasicTS 的结构:

DesignConvention.jpeg

BasicTS 的核心组件包括 DatasetScalerModelMetricsRunnerConfig

🎯 模型选择指南

场景推荐模型原因
快速基线DLinear, NLinear简单高效,训练快
长序列预测PatchTST, iTransformer专为长序列设计
交通预测STID, D2STGNN支持图结构
资源受限SparseTSF仅 1K 参数
零样本预测TimeMoE, ChronosBolt预训练通用模型
多变量关系Crossformer, TimesNet建模变量间依赖

💡 BasicTS 的整体设计

在时间序列分析流程中,通常包含以下几个关键部分:

  • 数据集 (Dataset):定义读取数据集和生成样本的方式。(位于 basicts.data模块)
  • 数据缩放 (Scaler):处理数据的归一化与反归一化操作,如 Z-score 和 Min-Max 归一化等方法。(位于 basicts.scaler模块)
  • 评估指标 (Metrics):定义模型评估的指标和损失函数,包括 MAE、MSE、MAPE、RMSE 和 WAPE 等。(位于 basicts.metrics模块)
  • 执行器 (Runner):作为 BasicTS 的核心模块,负责协调整个训练过程。执行器集成了数据集、数据缩放、模型架构和评估指标等组件,提供了多种功能支持。(位于 basicts.runner模块)
  • 模型结构 (Model):定义模型架构及其前向传播过程(位于basicts.models模块)

BasicTS 提供了完整且可扩展的组件,用户只需提供模型结构,即可完成大部分任务。

为了简化训练策略配置并方便对比,BasicTS 遵循 一切基于配置 的设计理念:所有选项都集成在配置文件中,用户只需修改配置文件即可轻松设置模型、数据集、归一化方法、评估指标、优化器、学习率、甚至添加扩展功能——一切就像填写表单一样简单。

📦 数据集设计

🎸 新特性

从1.0版本开始,BasicTS采用数据解耦的设计方案,用户可以使用任意数据结构的数据集,只需继承BasicTSDataset基类实现自定义的数据集读取逻辑即可。

从1.0版本开始,BasicTS不再将数据和时间戳整合在一个四维张量中存储([batch_size, seq_len, num_features, num_timestamps + 1]),改为使用两个三维张量,显著降低了显存占用。

  • 时序数据: [batch_size, seq_len, num_features];
  • 时间戳数据:[batch_size, num_features, num_timestamps]。

📊 内置数据集

⏬ 数据下载

要开始使用内置数据集,请先从 Google Drive or 百度网盘 下载 datasets.zip 文件。下载后,将文件解压至 datasets/ 目录:

cd /path/to/project
unzip /path/to/all_data.zip -d datasets/

这是BasicTS默认的数据集保存路径,当然,您也可以将数据集放在任意其他路径下,并在dataset_params中的data_file_path字段显式的提供根路径。

这些数据集已经过预处理,可以直接使用。

未来将会支持在线下载内置数据集,该功能目前正在开发中。

🔬 使用内置数据集

内置数据集通常配合BasicTS内置支持的数据集类使用,内置数据类也是配置中的默认选项。 内置数据集类:

  • 预测任务BasicTSForecastingDataset
  • 分类任务UEADataset
  • 插补任务BasicTSImputationDataset

这些内置数据集类包括以下参数:

  • dataset_name (str): 数据集的名称。
  • input_len (int): 输入序列的长度,即历史数据点的数量。
  • output_len (int): (仅预测任务)输出序列的长度,即需要预测的未来数据点的数量。
  • mode (BasicTSMode | str): 数据集的模式,"TRAIN", "VAL"或"TEST",指示其用于训练、验证还是测试。由runner统一指定,无需手动赋值。
  • use_timestamps (bool): 是否使用时间戳的标志,默认为False。
  • local (bool): 数据集是否在本地。(开发中)
  • data_file_path (str | None): 包含时间序列数据文件的路径。默认为 "datasets/{name}"。
  • memmap (bool): 是否使用内存映射加载数据集的标志。开启时节省内存但会降低训练速度,因此建议仅在数据集极大时使用。默认为False。

通常来说,默认设置下使用内置数据集,只需在配置类中指定dataset_nameinput_len以及output_len(预测任务)即可。

💿 数据格式

在BasicTS中,数据集提供的数据需要遵循标准的格式。__get_item__方法应返回一个包含以下项目的字典:

  • inputs:输入数据,形状为[batch_size, input_len, num_features]的torch.Tensor
  • targets:目标数据(可选)。一个torch.Tensor。对于预测和插补任务,形状为[batch_size, output_len, num_features],对于分类任务,形状为[batch_size, num_classes],对于自监督任务,不需要该键;
  • inputs_timestamps:输入数据的时间戳(可选),形状为[batch_size, input_len, num_timestamps]的torch.Tensor
  • targets_timestamps:输入数据的时间戳(可选),形状为[batch_size, output_len, num_timestamps]的torch.Tensor

🧑‍🍳 如何添加或自定义数据集

您可以通过以下三步使用您自定义的数据集:

  1. 编写数据集类,继承BasicTSDataset基类,基类包含三个字段:dataset_namemodememmap
  2. 自定义实现您的数据读取和预处理逻辑,实现__get_item____len__方法。请注意,虽然数据实际的存储结构可以是任意的,但__get_item__方法返回的数据项应该遵循上文提到的规范。
  3. 如果需要使用缩放器对数据做归一化,还需重写data方法(property)。该方法用于向缩放器提供一个待归一化数据的视图(np.ndarray),使缩放器学习整个训练集的分布。
  4. 在配置类中修改dataset_type字段为您自己的数据集类,并设置相应的dataset_params

🛠️ 数据缩放器设计 (Scaler)

🧐 什么是数据缩放器,为什么需要它?

数据缩放器(简称缩放器)是一个用于处理数据归一化和反归一化的类。在时间序列分析中,原始数据通常具有显著的尺度差异。因此,模型(尤其是深度学习模型)通常不会直接在原始数据上进行操作。相反,缩放器会将数据归一化到一个特定范围内,使其更适合建模。在计算损失函数或评估指标时,也可能将数据反归一化回原始尺度,以确保比较的准确性。

这使得缩放器成为时间序列分析工作流程中的重要组件。

👾 缩放器如何初始化及何时起作用?

缩放器与其他组件一起在执行器中初始化。

例如,Z-Score 缩放器会读取原始数据,并基于训练数据计算均值和标准差。

缩放器在从数据集中提取数据后起作用。数据首先由缩放器归一化,然后传递给模型进行训练。模型处理完数据后,缩放器会将输出反归一化,然后再传递给执行器进行损失计算和指标评估。

在许多时间序列分析研究中,归一化通常在数据预处理中进行,这也是早期 BasicTS 版本的做法。然而,这种方式的可扩展性较差。诸如更改输入/输出长度、应用不同的归一化方法(例如对每个时间序列单独归一化),或更改训练/验证/测试集的比例等调整,都会要求重新预处理数据。为了解决这个问题,BasicTS 采用了“即时归一化”的方式,每次提取数据时都会进行归一化处理。

# 在 runner 中
for data in dataloader:
    data = scaler.transform(data)
    forward_return = forward(data)
    forward_return = scaler.inverse_transform(forward_return)

🧑‍🔧 如何选择或自定义缩放器

BasicTS 提供了几种常见的缩放器,例如 Z-Score 和 Min-Max 缩放器。您可以通过在配置文件中设置 scaler 来轻松切换缩放器。

如果您需要自定义缩放器,可以扩展 BasicTSScaler 类,并实现 transforminverse_transform 方法。或者,您也可以选择不继承该类,但仍然需要实现这两个方法。

常用缩放器对比

缩放器公式特点
Z-Scorex=xμσx' = \frac{x - \mu}{\sigma}标准化到均值0、标准差1,适合正态分布数据
Min-Maxx=xxminxmaxxminx' = \frac{x - x_{min}}{x_{max} - x_{min}}缩放到 [0, 1] 区间,保留原始分布形状
RevIN实例级归一化 + 反归一化处理非平稳时间序列,在模型内部使用

选择建议:

  • 大多数情况下使用 Z-Score(默认)
  • 数据有明确边界时使用 Min-Max
  • 非平稳数据考虑在模型中启用 RevIN

🧠 模型设计

您的模型的 forward 函数应遵循 BasicTS 设定的规范。

🏗️ 构造模型

BasicTS使用配置类/字典构造模型,该配置类/字典应该包含构造模型所需的全部参数。 BasicTS模型配置类的基类为BasicTSModelConfig,其本身是字典的子类。当使用配置类构造模型时,您可以继承这一基类定义您的模型的配置。例如:

@dataclass
class YourModelConfig(BasicTSModelConfig):
	input_len: int
	output_len: int
	num_features: int
	hidden_size: int = 256
	hidden_act: int = "relu"

class YourModel(nn.Module):
	def __init__(config: YourModelConfig):
		...

[!important]

⚠️注意:强烈建议在配置中只使用可以JSON序列化的字段(数值、字符串、布尔、列表、元组、字典等),避免将自定义类作为字段,否则配置文件可能无法被正常保存。

🪴 输入接口

BasicTS 自1.0起,forward函数不再强制要求传入固定的参数(尽管未使用),而是可以按需指定传入的参数。然而,传入参数需要遵守以下规范。

  • 标准模型参数:BasicTS 1.0 标准的forward参数命名如下。模型的主输入为inputs,输出为targets;若使用时间戳,则时间戳数据为inputs_timestampstargets_timestamps;若需要使用mask信息(如计算损失),则掩码数据为inputs_masktargets_mask。此外,还可以传入当前训练的轮(epoch)数和步(step)数。注意,train参数即将被淘汰,可以访问nn.Moduletraining字段实现。

    def forward(
        self,
        inputs: torch.Tensor,
        targets: Optional[torch.Tensor] = None,
        inputs_timestamps: Optional[torch.Tensor] = None,
        targets_timestamps: Optional[torch.Tensor] = None,
        inputs_mask: Optional[torch.Tensor] = None,
        targets_mask: Optional[torch.Tensor] = None,
        epoch: Optional[int] = None,
        step: Optional[int] = None,
        train: Optional[bool] = None
        ,**kwargs 
    ):
    

    假设模型只需要用到输入序列及其时间戳,则:

      class MyModel(nn.Module):
          def forward(self, inputs: torch.Tensor, inputs_timestamps: torch.Tensor):
      		...
    
  • 自定义模型参数:您可以在forward函数中加入任何自定义模型参数,但需要保证数据字典中包含该键。例如:

     # 如果包含extra_flag这个额外的参数,则需要保证传入的数据字典中包含该键:
     # {"inputs": inputs, "extra_flag": extra_flag, ...}
     def forward(self, inputs: torch.Tensor, extra_flag: bool):
     	...
    

    您可以在数据流上游添加或修改数据字典:Dataset或taskflow。 数据流向: Dataset.__get_item__ -> taskflow.preprocess -> model.forward

    • Dataset.__get_item__中添加(推荐):在数据集的__get_item__函数中返回包含该键的字典。 例如:

      class MyDataset(torch.utils.data.Dataset):
      	def __get_item__(self, idx: int):
      		return {
      			"inputs": self.inputs[idx],
      			"targets": self.targets[idx],
      			"extra_flag": self.flag[idx] # <-- add extra_flag
      		}
      
    • taskflow.preprocess改变数据字典:在自定义Taskflow类的preprocess可以修改数据字典。由于涉及对任务逻辑的修改,建议新用户谨慎使用该方法。 例如:

      class MyTaskflow(BasicTSTaskflow):
      	def preprocess(self, data: dict):
      		...
      		data["extra_flag"] = self.extra_flag # <-- add extra_flag
      		return data
      

🌷 输出接口

forward 函数的返回值应该是一个字典或一个torch.Tensor

  • 字典中必须包含键prediction,代表模型的预测结果。
  • 若返回值为一个torch.Tensor,则后续pipeline会自动将其包装成字典{"prediction":...},从而计算损失。
  • 字典中可以添加任意您自定义的键,用于实现自定义逻辑或计算评估指标等。
  • 想要返回在模型内部计算的损失时,必须返回包含键loss的字典(若直接传一个损失的torch.Tensor则会被视作预测结果)。当字典中包含loss时,后续pipeline不会再计算损失,而是直接取用。
  • 想要返回在内部计算的额外损失,并与主损失相加时,须在配置类中使用AddAuxiliaryLoss的callback,并指定额外损失的键名。例如,传递名为freq_losslb_loss的额外损失,使最终损失为MSE + freq_loss + lb_loss:
 # in your_train_script.py
 config=BasicTSConfig(
     loss=masked_mse,
     callback=[AddAuxiliaryLoss([`freq_loss`, `lb_loss`])],
     ...
 )
 
 # in your_model.py
 def forward(...):
     return {
   	  "prediction": prediction,
   	  "freq_loss": freq_loss,
   	  "lb_loss": lb_loss
   	  }

🥳 支持的基线模型

BasicTS 提供了多种内置模型。您可以在models 模块中找到它们,并只需导入对应的模型类和模型配置类即可使用模型。以使用STID为例:

from basicts.models.STID import STID, STIDConfig

task_config = BasicTSForecastingConfig(
	model=STID,
	model_config=STIDConfig,
	...
)

特别地,对于内置的多任务模型,通常包含一个公用的骨干网络(XXXBackbone,XXX为模型名),以及若干个任务特定的模型(XXXForYYY,YYY为任务名)。以TimesNet为例,可以导入TimesNetForForecasting进行预测任务,TimesNetForClassification进行分类任务,TimesNetForReconstruction进行插补任务。这些下游任务公用相同的骨干网络和相同的配置类。

from basicts.models.TimesNet import TimesNetBackbone, TimesNetForForecasting, TimesNetForClassifiction, TimesNetForReconstruction, TimesNetConfig

📉 评估指标设计

接口规范

评估指标是评估模型性能的重要组成部分。在 BasicTS 中,评估指标是接受模型预测值、真实值及其他参数作为输入并返回标量值以评估模型性能的函数。

一个定义良好的评估指标函数应包含以下参数:

  • prediction: 模型的预测值
  • targets: 实际的真实值
  • targets_mask: 可选参数,用于指定在哪些点上计算损失(一般用于掩码缺失值)。

prediction 和 target 是必需参数,而 targets_mask 是可选参数,但强烈建议采纳,以处理时间序列数据中常见的缺失值。

评估指标函数还可以接受其他额外参数,这些参数会从模型的返回值中提取并传递给指标函数。

BasicTS 内置评估指标

BasicTS 提供了多种常用的评估指标,例如 MAE、MSE、RMSE、MAPE 和 WAPE。您可以在 basicts.metrics 模块中找到这些指标的实现。

常用评估指标详解

指标公式适用场景
MAE$\frac{1}{n}\sumy_i - \hat{y}_i$对异常值不敏感,适合一般预测任务
MSE1n(yiy^i)2\frac{1}{n}\sum(y_i - \hat{y}_i)^2对大误差惩罚更重,常用作损失函数
RMSEMSE\sqrt{MSE}与原数据同量纲,便于解释
MAPE$\frac{100%}{n}\sum\frac{y_i - \hat{y}_i}{y_i}$百分比误差,适合比较不同量级数据
WAPE$\frac{\sumy_i - \hat{y}_i}{\sumy_i}$加权百分比误差,避免 MAPE 除零问题

注意:当真实值接近 0 时,MAPE 会趋向无穷大,此时建议使用 WAPE 或 MAE。

如何实现自定义评估指标

根据接口规范中的指南,您可以轻松实现自定义的评估指标。以下是一个示例:

class MyModel:
    def __init__(self):
        # 初始化模型
        ...
    
    def forward(...):
        # 前向计算
        ...
        return {
                'prediction': prediction,
                'targets': target,
                'other_key1': other_value1,
                'other_key2': other_value2,
                'other_key3': other_value3,
                ...
        }

def my_metric_1(prediction, targets, targets_mask=None, other_key1=None, other_key2=None, ...):
    # 计算指标
    ...

def my_metric_2(prediction, targets, targets_mask=None, other_key3=None, ...):
    # 计算指标
    ...

遵循这些规范,您可以灵活地在 BasicTS 中自定义和扩展评估指标,以满足特定需求。

🧮 仪表盘

该节仅涉及细节内容,绝大部分情况下不会影响使用,可以跳过。

在BasicTS中,我们使用仪表盘(Meter类)在训练中维护指标值。BasicTS会默认使用平均仪表盘(AvgMeter类),逐步更新并维护对应指标的均值,这适用于绝大部分指标。

然而,也有一些指标不应该维护均值,例如RMSE,是先求平均再开平方,此时如果逐步累积最后再求平均则会产生错误(虽然一般不影响模型的训练结果)。此时,应该使用特殊的仪表盘,实现正确的增量计算。

🏃‍♂️ 执行器与流程

💿 概述

执行器是 BasicTS 的核心组件,负责管理整个训练和评估过程。它将数据集、数据缩放器、模型、评估指标和配置文件等各个子组件集成在一起,构建一个公平且可扩展的训练和评估流程。

自BasicTS 1.0起,BasicTS只需要一个执行器类BasicTSRunner,并对其进行了全面重构和解耦。您无需再修改任何执行器代码,就能实现任何自定义的扩展功能。

BasicTS训练与评估流程的三层架构:重构后的BasicTS的训练与评估流程可以被分为三个层次。

  • 执行器与通用流程层(BasicTSRunner):集结了一切基础流程中通用的、和具体任务无关的训练流程。用户不应该直接修改该层次的代码。
  • 任务流层(BasicTSTaskflow):定义了基础流程中和任务相关的步骤。当不修改任务流程时,用户应该尽量少地自定义该层的对象。
  • 回调层(BasicTSCallback):定义了基础流程之外的扩展功能,例如早停、梯度裁剪、课程学习等。当想要扩展功能时,用户应该尽可能地通过回调来实现。

⚡️ 通用流程

以训练为例(评估类似),执行器实现的通用流程如下列伪代码所示。 与标准深度学习框架相符,通用流程包括:模型前传、计算损失、损失反传、优化器更新。

def train_loop(self):
	for epoch in range(num_epochs):
		
		# Event 1: on_epoch_start events
		callback_handler.trigger("on_epoch_start")
		
		for data in train_data_loder:
		
			# Event 2: on_step_start events
			callback_handler.trigger("on_step_start")
			
			# Task-specific 1: preprocess data 
			data = taskflow.preprocess(self, data)
			
			# General pipeline 1: model forward
			forward_return = forward()
			
			# Event 3: on_compute_loss events
			callback_handler.trigger("on_compute_loss")
			
			# General pipeline 2: compute loss
			loss = metric_forward(loss_function, forward_return)
			
			# Task-specific 2: get loss weight
			loss_weight = taskflow.get_weight(forward_return)
			
			# Event 4: on_backward events
			callback_handler.trigger("on_backward") # on_backward events

			# General pipeline 3: loss backward
			loss.backward()

			# Event 5: on_optimizer_step events
			callback_handler.trigger("on_optimizer_step")

			# General pipeline 4: optimizer step
			optimizer_step()

			# Task-specific 3: postprocess forward return
			forward_return = taskflow.postprocess(self, forward_return)

			# General pipeline 5: compute metrics
			metric_value = metric_forward(metric_fn, forward_return)

			# Event 6: on_step_end events
			callback_handler.trigger("on_step_end")

		# Event 7: on_epoch_end events
		callback_handler.trigger("on_epoch_end")

💫 任务流

任务流模块位于basicts.runners.taskflow,其基类定义如下:

class BasicTSTaskflow():
	def preprocess(self, runner, data):
		pass
	
	def postprocess(self, runner, forward_return):
		pass
	
	def get_weight(self, forward_return):
		pass
  • preprocess:定义数据在模型前传前的预处理逻辑,包括归一化、生成缺失值掩码等。
  • postprocess:定义数据在计算指标前的后处理逻辑,包括反归一化(预测任务),计算argmax(分类任务)等。
  • get_weight:定义当前批次在全部训练数据中的损失权重,保证数据集的整体损失能被正确计算。例如,分类任务的权重应该是该批次的样本数,预测任务应该是该批次全部有效点的数量。

🪝 回调层

回调模块位于basicts.runners.callback。一个回调类应该包含若干个回调函数,执行器的CallbackHandler对象会在对应的阶段调用这些函数,以实现功能的扩展。

回调基类BasicTSCallback定义了全部可用的回调函数:

class BasicTSCallback:
	# 训练开始时
	def on_train_start(self, runner, *args, **kwargs):
		pass
	# 训练结束时
	def on_train_end(self, runner, *args, **kwargs):
		pass
	# epoch开始时
	def on_epoch_start(self, runner, *args, **kwargs):
		pass
	# epoch结束时
	def on_epoch_end(self, runner, *args, **kwargs):
		pass
	# step开始时
	def on_step_start(self, runner, *args, **kwargs):
		pass
	# step结束时
	def on_step_end(self, runner, *args, **kwargs):
		pass
	# 验证开始时
	def on_validate_start(self, runner, *args, **kwargs):
		pass
	# 验证结束时
	def on_validate_end(self, runner, *args, **kwargs):
		pass
	# 测试开始时
	def on_test_start(self, runner, *args, **kwargs):
		pass
	# 测试结束时
	def on_test_end(self, runner, *args, **kwargs):
		pass
	# 计算损失前
	def on_compute_loss(self, runner, *args, **kwargs):
		pass
	# 反向传播前
	def on_backward(self, runner, *args, **kwargs):
		pass
	# 优化器更新前
	def on_optimizer_step(self, runner, *args, **kwargs):
		pass

常用内置回调示例

1. 早停 (Early Stopping)
from basicts.runners.callback import EarlyStopping

cfg = BasicTSForecastingConfig(
    callbacks=[
        EarlyStopping(patience=10, monitor="val_loss")
    ],
    ...
)
2. 梯度裁剪 (Gradient Clipping)
from basicts.runners.callback import GradientClipping

cfg = BasicTSForecastingConfig(
    callbacks=[
        GradientClipping(max_norm=1.0)
    ],
    ...
)
3. 课程学习 (Curriculum Learning)
from basicts.runners.callback import CurriculumLearning

cfg = BasicTSForecastingConfig(
    callbacks=[
        CurriculumLearning(start_len=12, end_len=96, warmup_epochs=10)
    ],
    ...
)

📜 配置设计

BasicTS 的设计理念是“配置即一切“。BasicTS 的目标是让用户专注于模型和数据,而不用被繁琐的流程构建所困扰。

🎸 新特性

从1.0版本开始,BasicTS不再使用py文件配合命令行指定配置路径的方式,升级为使用配置类进行配置。 配置类的基类为BasicTSConfig,此外,每个具体任务对应一个配置类,包括BasicTSForecastingConfigBasicTSClassificationConfigBasicTSImputationConfigBasicTSFoundationModelConfig等。基类BasicTSConfig定义了公用的字段以及保存/加载/打印配置类的方法,任务特定的配置类则包括执行该任务需要的一切配置参数。您可以灵活地向其中导入模型,并设置所有必要的选项。

配置类通常包含以下部分:

  • 常规选项: 描述一般设置,如配置说明、gpusseed 等。
  • 环境选项: 包括设置如 tf32cudnndeterministic 等。
  • 数据集选项: 指定 dataset_name(数据集名,必须显式指定)、dataset_type(数据集类)、dataset_params(数据集参数)等。
  • 数据缩放器选项: 指定 scaler(缩放器类)、norm_each_channel(通道独立归一化)、rescale(是否反归一化)等。
  • 模型选项: 指定 model(模型类,必须显式指定)、model_config(模型参数,必须显式指定)等。
  • 评估指标选项: 包括 metrics(评估指标函数)、target_metric(目标评估指标)等。
  • 训练选项:
    • 常规: 指定设置如 num_epochs/num_stepsloss 等。
    • 优化器: 指定 optimizer(优化器类)、optimizer_params(优化器参数)等。
    • 调度器: 指定 lr_scheduler(调度器类)、lr_scheduler_params(调度器参数)等。
    • 数据: 指定设置如 batch_sizenum_workerspin_memory等。
  • 验证选项:
    • 常规: 包括验证频率 val_interval
    • 数据: 指定设置如 batch_sizenum_workerspin_memory 等。
  • 测试选项:
    • 常规: 包括测试频率 test_interval
    • 数据: 指定设置如 batch_sizenum_workerspin_memory 等。

Config类字段的metadata提供了每个字段的默认值、详细含义及使用方法。

🏗️ 构造配置类

👥 既是类,也是字典

BasicTS的配置类继承自EasyDict,既可以作为类使用,也可以作为字典使用,方便扩展并且灵活易用。下面介绍两种使用方式:

  1. 像类一样使用:一切参数都是类的字段,可以像访问类的字段一样访问、修改以及添加新的参数。例如:

    model = config.model
    config.gpus = "0"
    config.new_field = new_value
    
  2. 像字典一样使用:一切参数都是字典的键值,可以像访问字典一样访问、修改、添加、删除参数。例如:

    model = config["model"]
    optimizer = config.get("optimizer", Adam)
    config["new_key"] = new_value
    config.pop("not_used")
    
🔨 配置类中的对象

也许你已经发现了,配置中存在许多需要进一步构造的对象,例如模型、数据集、优化器、调度器等。BasicTS在配置中传入对应的类和创建对象所需的参数,而将这些对象的创建延迟到实际执行任务时。

对于这些对象,BasicTS支持两种方式来灵活地配置:

  1. 将构造对象所需的全部参数以字典的形式传入。例如,将参数字典传给dataset_params以供后续创建数据集的实例。

    config = BasicTSForecastingConfig(
    	dataset_params={
    		'input_len': 336,
    		'output_len': 336,
    		...},
    	...)
    
  2. 直接将构造对象需要的参数作为字段传入。例如,可以直接配置Config类的input_lenoutput_len,并将自定义的参数传入(不能直接在构造方法中添加未定义的字段)。

    config = BasicTSForecastingConfig(
    	input_len=336,
    	output_len=336,
    	...)
    config.your_dataset_param_1 = param_1
    config.your_dataset_param_2 = param_2
    

❓ 常见问题 (FAQ)

Q1: 训练时出现 CUDA out of memory 怎么办?

A: 尝试以下方法:

  • 减小 batch_size
  • 减小 input_lenoutput_len
  • 使用 memmap=True 加载大数据集
  • 使用混合精度训练

Q2: 如何使用多 GPU 训练?

A: 修改配置中的 gpus 参数:

cfg = BasicTSForecastingConfig(
    gpus="0,1,2,3",  # 使用 4 张 GPU
    ...
)

Q3: 如何恢复中断的训练?

A: 使用检查点恢复:

cfg.resume_from = "checkpoints/last_checkpoint.pt"

Q4: 如何只进行推理而不训练?

A: 使用 launch_evaluation 并加载预训练权重。

Q5: 自定义数据集需要什么格式?

A: 数据需要是 [L, N, C] 形状的 numpy 数组,其中:

L = 时间步数 N = 变量数 C = 特征数(通常为 1)

声明

本文内容整理自 BasicTS 官方 GitHub 仓库 及其官方文档,结合个人学习理解进行归纳总结,仅供学习交流使用。

文中引用的论文链接、代码链接均指向原作者的官方发布地址,版权归原作者所有。如有任何内容侵犯了您的权益,请通过评论区或私信联系我,我将在核实后第一时间删除相关内容。

感谢 BasicTS 开发团队的开源贡献!

欢迎关注我的公众号:「木子吉星