时间序列分析是机器学习领域中最具实用价值的方向之一——从股票价格预测、天气预报,到交通流量监控、能源负荷管理,时间序列无处不在。然而,对于初学者而言,这个领域往往面临一个尴尬的困境:论文中的模型五花八门,代码实现各不相同,想要公平地对比不同方法几乎是一场噩梦。
BasicTS 正是为解决这一痛点而生。
作为一个公平、可扩展的时间序列分析基准库,BasicTS 集成了 30+ 种主流深度学习模型(从经典的 Informer、Autoformer 到最新的 iTransformer、TimeMixer)和 20+ 个标准数据集,覆盖长期预测(LTSF)、时空预测(STF)、时序分类、缺失值插补等核心任务。更重要的是,它提供了统一的训练和评估流程,让你可以用三行代码快速启动实验,专注于模型创新而非工程细节。
无论你是刚入门时序分析的新手,还是需要快速复现论文结果的研究者,这篇指南都将带你从安装配置到模型训练,全面掌握 BasicTS 的使用方法。
让我们开始吧。
介绍
BasicTS (Basic Time Series) 是一个面向时间序列分析的基准库和工具箱,现已支持时空预测、长序列预测、分类、插补等多种任务与数据集,涵盖统计模型、机器学习模型、深度学习模型等多类算法,为开发和评估时间序列预测模型提供了理想的工具。
链接:GestaltCogTeam/BasicTS: A Fair and Scalable Time Series Forecasting Benchmark and Toolkit.
📁 项目结构
BasicTS/
├── basicts/
│ ├── configs/ # 配置类
│ │ ├── tsf_config.py # 预测任务配置
│ │ ├── tsc_config.py # 分类任务配置
│ │ └── tsi_config.py # 插补任务配置
│ ├── data/ # 数据集类
│ ├── metrics/ # 评估指标
│ ├── models/ # 模型实现
│ │ ├── DLinear/
│ │ ├── iTransformer/
│ │ └── ...
│ ├── modules/ # 可复用组件
│ │ ├── transformer/ # Transformer 组件
│ │ ├── norm/ # 归一化层
│ │ └── embed/ # 嵌入层
│ ├── runners/ # 执行器
│ │ ├── callback/ # 回调函数
│ │ └── taskflow/ # 任务流
│ └── scaler/ # 数据缩放器
├── datasets/ # 数据集存放目录
├── checkpoints/ # 模型保存目录
└── examples/ # 示例代码
最新动态
🎉 更新(2025年10月):BasicTS 内置支持选择学习(NeurIPS'25),一种有效缓解过拟合,增加模型性能和泛化性的训练策略。用户可以从回调模块中导入并直接使用。使用说明
🎉 更新(2025年10月):BasicTS 1.0版本发布了!新特性:
- 🚀 三行代码,快速上手:pip install 安装,极简 API 设计,快速实现模型训练与评估。
- 📦 模块化组件,开箱即用:提供 Transformer、MLP 等即插即用的组件,像搭积木一样构建自己的模型。
- 🔄 多任务支持:支持时序预测、分类、插补等多个时序分析核心任务。
- 🔧 高可扩展架构:基于 Taskflow 与 Callback 机制,无需修改 Runner 即可轻松定制。
🎉 更新(2025年5月): BasicTS 现已支持使用 BLAST (KDD'25) 语料库训练通用预测模型(例如 TimeMoE 和 ChronosBolt)。BLAST 能够实现 更快的收敛速度、显著降低计算成本,并且即使在资源有限的情况下也能获得卓越性能。
✨ 主要功能亮点
BasicTS 一方面通过 统一且标准化的流程,为热门的深度学习模型提供了 公平且全面 的复现与对比平台。另一方面,BasicTS 提供了用户 友好且易于扩展 的接口,帮助快速设计和评估新模型。用户只需定义模型结构,便可轻松完成基本操作。
公平的性能评估:通过统一且全面的流程,用户能够公平且充分地对比不同模型在任意数据集上的性能表现。
使用 BasicTS 进行开发你可以:
最简代码实现
- 用户只需实现关键部分如模型架构、数据预处理和后处理,即可构建自己的深度学习项目。
基于配置文件控制一切
- 用户可以通过配置文件掌控流程中的所有细节,包括数据加载器的超参数、优化策略以及其他技巧(如课程学习)。
支持所有设备
- BasicTS 支持 CPU、GPU 以及分布式 GPU 训练(单节点多 GPU 和多节点),依托 EasyTorch 作为后端。用户只需通过设置参数即可使用这些功能,无需修改代码。
保存训练日志
- BasicTS 提供
logging日志系统和Tensorboard支持,并统一封装接口,用户可以通过简便的接口调用来保存自定义的训练日志。
📦 支持的模型
BasicTS 实现了丰富的基线模型,包括经典模型、时空预测模型、长序列预测模型、通用预测模型等。
这些模型的代码实现可在 baselines 目录中找到。
下表中的代码链接(💻Code) 指向了相关论文的官方实现,感谢各位作者对代码的开源贡献!
通用预测模型
UFM = Universal Forecasting Models(通用预测模型)
这是一类预训练的大规模时间序列基础模型,类似于 NLP 领域的 GPT、BERT。它们的特点是:
| 特性 | 说明 |
|---|---|
| 预训练 | 在海量时间序列数据上预训练 |
| 零样本/少样本 | 无需或只需少量微调即可用于新数据集 |
| 跨领域 | 可以处理不同领域的时间序列(天气、交通、金融等) |
| 大规模 | 参数量通常在数十亿级别 |
UFM vs 传统模型的区别
传统模型 (如 DLinear, iTransformer):
数据集 A 训练模型 预测 A,每个数据集单独训练。
UFM 通用模型:
海量数据预训练 预训练模型 直接用于任意数据集,零样本或微调。
| 📊Baseline | 📝Title | 📄Paper | 💻Code | 🏛Venue | 🎯Task |
|---|---|---|---|---|---|
| TimeMoE | Time-MoE: Billion-Scale Time Series Foundation Models with Mixture of Experts | Link | Link | ICLR'25 | UFM |
| ChronosBolt | Chronos: Learning the Language of Time Series | Link | Link | TMLR'24 | UFM |
| MOIRAI (inference) | Unified Training of Universal Time Series Forecasting Transformers | Link | Link | ICML'24 | UFM |
时空预测
什么是 STF?
STF = Spatial-Temporal Forecasting(时空预测)
这是一类需要同时考虑空间关系和时间依赖的预测任务。
核心概念:
普通时间序列预测 (LTSF):
时间 →
┌───┬───┬───┬───┬───┐
│ t1│ t2│ t3│ t4│ ? │ ← 单个序列,只考虑时间
└───┴───┴───┴───┴───┘
时空预测 (STF):
时间 →
┌───┬───┬───┬───┬───┐
节点A │ t1│ t2│ t3│ t4│ ? │
├───┼───┼───┼───┼───┤
节点B │ t1│ t2│ t3│ t4│ ? │ ← 多个节点 + 节点间有空间关系
├───┼───┼───┼───┼───┤
节点C │ t1│ t2│ t3│ t4│ ? │
└───┴───┴───┴───┴───┘
↑
节点之间通过图(Graph)连接
| 📊Baseline | 📝Title | 📄Paper | 💻Code | 🏛Venue | 🎯Task |
|---|---|---|---|---|---|
| STDN | Spatiotemporal-aware Trend-Seasonality Decomposition Network for Traffic Flow Forecasting | Link | Link | AAAI'25 | STF |
| HimNet | Heterogeneity-Informed Meta-Parameter Learning for Spatiotemporal Time Series Forecasting | Link | Link | SIGKDD'24 | STF |
| DFDGCN | Dynamic Frequency Domain Graph Convolutional Network for Traffic Forecasting | Link | Link | ICASSP'24 | STF |
| STPGNN | Spatio-Temporal Pivotal Graph Neural Networks for Traffic Flow Forecasting | Link | Link | AAAI'24 | STF |
| BigST | Linear Complexity Spatio-Temporal Graph Neural Network for Traffic Forecasting on Large-Scale Road Networks | Link | Link | VLDB'24 | STF |
| STDMAE | Spatio-Temporal-Decoupled Masked Pre-training for Traffic Forecasting | Link | Link | IJCAI'24 | STF |
| STWave | When Spatio-Temporal Meet Wavelets: Disentangled Traffic Forecasting via Efficient Spectral Graph Attention Networks | Link | Link | ICDE'23 | STF |
| STAEformer | Spatio-Temporal Adaptive Embedding Makes Vanilla Transformer SOTA for Traffic Forecasting | Link | Link | CIKM'23 | STF |
| MegaCRN | Spatio-Temporal Meta-Graph Learning for Traffic Forecasting | Link | Link | AAAI'23 | STF |
| DGCRN | Dynamic Graph Convolutional Recurrent Network for Traffic Prediction: Benchmark and Solution | Link | Link | ACM TKDD'23 | STF |
| STID | Spatial-Temporal Identity: A Simple yet Effective Baseline for Multivariate Time Series Forecasting | Link | Link | CIKM'22 | STF |
| STEP | Pretraining Enhanced Spatial-temporal Graph Neural Network for Multivariate Time Series Forecasting | Link | Link | SIGKDD'22 | STF |
| D2STGNN | Decoupled Dynamic Spatial-Temporal Graph Neural Network for Traffic Forecasting | Link | Link | VLDB'22 | STF |
| STNorm | Spatial and Temporal Normalization for Multi-variate Time Series Forecasting | Link | Link | SIGKDD'21 | STF |
| STGODE | Spatial-Temporal Graph ODE Networks for Traffic Flow Forecasting | Link | Link | SIGKDD'21 | STF |
| GTS | Discrete Graph Structure Learning for Forecasting Multiple Time Series | Link | Link | ICLR'21 | STF |
| StemGNN | Spectral Temporal Graph Neural Network for Multivariate Time-series Forecasting | Link | Link | NeurIPS'20 | STF |
| MTGNN | Connecting the Dots: Multivariate Time Series Forecasting with Graph Neural Networks | Link | Link | SIGKDD'20 | STF |
| AGCRN | Adaptive Graph Convolutional Recurrent Network for Traffic Forecasting | Link | Link | NeurIPS'20 | STF |
| GWNet | Graph WaveNet for Deep Spatial-Temporal Graph Modeling | Link | Link | IJCAI'19 | STF |
| STGCN | Spatio-Temporal Graph Convolutional Networks: A Deep Learning Framework for Traffic Forecasting | Link | Link | IJCAI'18 | STF |
| DCRNN | Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting | Link | Link1, Link2 | ICLR'18 | STF |
Long-Term Time Series Forecasting
什么是 LTSF?
LTSF = Long-Term Time Series Forecasting(长期时间序列预测)
这是时间序列预测中最基础、最常见的任务:根据历史数据预测未来较长时间的数值。
LTSF 的特点
| 特点 | 说明 |
|---|---|
| 长期 | 预测长度通常为 96、192、336、720 个时间步 |
| 多变量 | 通常同时预测多个变量(如温度、湿度、风速等) |
| 无图结构 | 不考虑变量之间的空间拓扑关系 |
| 挑战 | 长期依赖建模、误差累积、非平稳性 |
| 📊Baseline | 📝Title | 📄Paper | 💻Code | 🏛Venue | 🎯Task |
|---|---|---|---|---|---|
| S-D-Mamba | Is Mamba Effective for Time Series Forecasting? | Link | Link | NeuroComputing'24 | LTSF |
| Bi-Mamba | Bi-Mamba+: Bidirectional Mamba for Time Series Forecasting | Link | Link | arXiv'24 | LTSF |
| ModernTCN | ModernTCN: A Modern Pure Convolution Structure for General Time Series Analysis | Link | Link | ICLR'24 | LTSF |
| TimeXer | TimeXer: Empowering Transformers for Time Series Forecasting with Exogenous Variables | Link | Link | NeurIPS'24 | LTSF |
| CARD | CARD: Channel Aligned Robust Blend Transformer for Time Series Forecasting | Link | Link | ICLR'24 | LTSF |
| SOFTS | SOFTS: Efficient Multivariate Time Series Forecasting with Series-Core Fusion | Link | Link | NeurIPS'24 | LTSF |
| CATS | Are Self-Attentions Effective for Time Series Forecasting? | Link | Link | NeurIPS'24 | LTSF |
| Sumba | Structured Matrix Basis for Multivariate Time Series Forecasting with Interpretable Dynamics | Link | Link | NeurIPS'24 | LTSF |
| GLAFF | Rethinking the Power of Timestamps for Robust Time Series Forecasting: A Global-Local Fusion Perspective | Link | Link | NeurIPS'24 | LTSF |
| CycleNet | CycleNet: Enhancing Time Series Forecasting through Modeling Periodic Patterns Forecasting | Link | Link | NeurIPS'24 | LTSF |
| Fredformer | Fredformer: Frequency Debiased Transformer for Time Series Forecasting | Link | Link | KDD'24 | LTSF |
| UMixer | An Unet-Mixer Architecture with Stationarity Correction for Time Series Forecasting | Link | Link | AAAI'24 | LTSF |
| TimeMixer | Decomposable Multiscale Mixing for Time Series Forecasting | Link | Link | ICLR'24 | LTSF |
| Time-LLM | Time-LLM: Time Series Forecasting by Reprogramming Large Language Models | Link | Link | ICLR'24 | LTSF |
| SparseTSF | Modeling LTSF with 1k Parameters | Link | Link | ICML'24 | LTSF |
| iTrainsformer | Inverted Transformers Are Effective for Time Series Forecasting | Link | Link | ICLR'24 | LTSF |
| Koopa | Learning Non-stationary Time Series Dynamics with Koopman Predictors | Link | Link | NeurIPS'24 | LTSF |
| CrossGNN | CrossGNN: Confronting Noisy Multivariate Time Series Via Cross Interaction Refinement | Link | Link | NeurIPS'23 | LTSF |
| NLinear | Are Transformers Effective for Time Series Forecasting? | Link | Link | AAAI'23 | LTSF |
| Crossformer | Transformer Utilizing Cross-Dimension Dependency for Multivariate Time Series Forecasting | Link | Link | ICLR'23 | LTSF |
| DLinear | Are Transformers Effective for Time Series Forecasting? | Link | Link | AAAI'23 | LTSF |
| DSformer | A Double Sampling Transformer for Multivariate Time Series Long-term Prediction | Link | Link | CIKM'23 | LTSF |
| SegRNN | Segment Recurrent Neural Network for Long-Term Time Series Forecasting | Link | Link | arXiv | LTSF |
| MTS-Mixers | Multivariate Time Series Forecasting via Factorized Temporal and Channel Mixing | Link | Link | arXiv | LTSF |
| LightTS | Fast Multivariate Time Series Forecasting with Light Sampling-oriented MLP | Link | Link | arXiv | LTSF |
| ETSformer | Exponential Smoothing Transformers for Time-series Forecasting | Link | Link | arXiv | LTSF |
| NHiTS | Neural Hierarchical Interpolation for Time Series Forecasting | Link | Link | AAAI'23 | LTSF |
| PatchTST | A Time Series is Worth 64 Words: Long-term Forecasting with Transformers | Link | Link | ICLR'23 | LTSF |
| TiDE | Long-term Forecasting with TiDE: Time-series Dense Encoder | Link | Link | TMLR'23 | LTSF |
| S4 | Efficiently Modeling Long Sequences with Structured State Spaces | Link | Link | ICLR'22 | LTSF |
| TimesNet | Temporal 2D-Variation Modeling for General Time Series Analysis | Link | Link | ICLR'23 | LTSF |
| Triformer | Triangular, Variable-Specific Attentions for Long Sequence Multivariate Time Series Forecasting | Link | Link | IJCAI'22 | LTSF |
| NSformer | Exploring the Stationarity in Time Series Forecasting | Link | Link | NeurIPS'22 | LTSF |
| FiLM | Frequency improved Legendre Memory Model for LTSF | Link | Link | NeurIPS'22 | LTSF |
| FEDformer | Frequency Enhanced Decomposed Transformer for Long-term Series Forecasting | Link | Link | ICML'22 | LTSF |
| Pyraformer | Low complexity pyramidal Attention For Long-range Time Series Modeling and Forecasting | Link | Link | ICLR'22 | LTSF |
| HI | Historical Inertia: A Powerful Baseline for Long Sequence Time-series Forecasting | Link | None | CIKM'21 | LTSF |
| Autoformer | Decomposition Transformers with Auto-Correlation for Long-Term Series Forecasting | Link | Link | NeurIPS'21 | LTSF |
| Informer | Beyond Efficient Transformer for Long Sequence Time-Series Forecasting | Link | Link | AAAI'21 | LTSF |
其他方法
| 📊Baseline | 📝Title | 📄Paper | 💻Code | 🏛Venue | 🎯Task |
|---|---|---|---|---|---|
| CatBoost | Catboost: unbiased boosting with categorical features | Link | Link | NeurIPS'18 | Machine Learning |
| LightGBM | LightGBM: A Highly Efficient Gradient Boosting Decision Tree | Link | Link | NeurIPS'17 | Machine Learning |
| NBeats | Neural basis expansion analysis for interpretable time series forecasting | Link | Link1, Link2 | ICLR'19 | Deep Time Series Forecasting |
| DeepAR | Probabilistic Forecasting with Autoregressive Recurrent Networks | Link | Link1, Link2, Link3 | Int. J. Forecast'20 | Probabilistic Time Series Forecasting |
| WaveNet | WaveNet: A Generative Model for Raw Audio. | Link | Link 1, Link 2 | arXiv | Audio |
| AR | VII. On a method of investigating periodicities disturbed series, with special reference to Wolfer's sunspot numbers | Link | Link | 1927 | Local Forecasting |
| MA | On periodicity in series of related terms | Link | Link | 1931 | Local Forecasting |
| ARMA | Some recent advances in forecasting and control | Link | Link | Applied Statistics'1968 | Local Forecasting |
| ARIMA | Forecasting with exponential smoothing: the state space approach | Link | Link | 2008 | Local Forecasting |
| SARIMA | Forecasting with exponential smoothing: the state space approach | Link | Link | 2008 | Local Forecasting |
| ARCH | Conditional heteroscedasticity in time series of stock returns: Evidence and forecasts | Link | Link | Journal of business'1989 | Local Forecasting |
| GARCH | Conditional heteroscedasticity in time series of stock returns: Evidence and forecasts | Link | Link | Journal of business'1989 | Local Forecasting |
| ETS | The holt-winters forecasting procedure | Link | Link | Applied Statistics'1978 | Local Forecasting |
| SES | The holt-winters forecasting procedure | Link | Link | Applied Statistics'1978 | Local Forecasting |
| SVR | Support vector regression machines | Link | Link | NeurIPS'1996 | Machine Learning |
| PolySVR | A training algorithm for optimal margin classifiers | Link | Link | COLT'1992 | Machine Learning |
📦 支持的数据集
BasicTS 支持多种类型的数据集,涵盖时空预测、长序列预测及大规模数据集。
数据集表格字段说明
| 字段 | 含义 | 示例 |
|---|---|---|
| Length | 时间序列的总长度(时间步数) | 14400 表示有 14400 个时间点 |
| Time Series Count | 变量/节点数量 | 7 表示有 7 个变量同时被记录 |
| Graph | 是否包含图结构(节点间的空间关系) | True 表示有邻接矩阵 |
| Freq. (m) | 采样频率(分钟) | 60 表示每小时采样一次 |
例如,ETTh1 数据集:
- 14400 个时间步 × 60 分钟/步 = 600 天的数据
- 7 个变量:包括油温、负载等电力变压器相关指标
时空预测
| 🏷️Name | 🌐Domain | 📏Length | 📊Time Series Count | 🔄Graph | ⏱️Freq. (m) | 🎯Task |
|---|---|---|---|---|---|---|
| METR-LA | Traffic Speed | 34272 | 207 | True | 5 | STF |
| PEMS-BAY | Traffic Speed | 52116 | 325 | True | 5 | STF |
| PEMS03 | Traffic Flow | 26208 | 358 | True | 5 | STF |
| PEMS04 | Traffic Flow | 16992 | 307 | True | 5 | STF |
| PEMS07 | Traffic Flow | 28224 | 883 | True | 5 | STF |
| PEMS08 | Traffic Flow | 17856 | 170 | True | 5 | STF |
长序列预测
| 🏷️Name | 🌐Domain | 📏Length | 📊Time Series Count | 🔄Graph | ⏱️Freq. (m) | 🎯Task |
|---|---|---|---|---|---|---|
| BeijingAirQuality | Beijing Air Quality | 36000 | 7 | False | 60 | LTSF |
| ETTh1 | Electricity Transformer Temperature | 14400 | 7 | False | 60 | LTSF |
| ETTh2 | Electricity Transformer Temperature | 14400 | 7 | False | 60 | LTSF |
| ETTm1 | Electricity Transformer Temperature | 57600 | 7 | False | 15 | LTSF |
| ETTm2 | Electricity Transformer Temperature | 57600 | 7 | False | 15 | LTSF |
| Electricity | Electricity Consumption | 26304 | 321 | False | 60 | LTSF |
| ExchangeRate | Exchange Rate | 7588 | 8 | False | 1440 | LTSF |
| Illness | Ilness Data | 966 | 7 | False | 10080 | LTSF |
| Traffic | Road Occupancy Rates | 17544 | 862 | False | 60 | LTSF |
| Weather | Weather | 52696 | 21 | False | 10 | LTSF |
大规模数据集
| 🏷️Name | 🌐Domain | 📏Length | 📊Time Series Count | 🔄Graph | ⏱️Freq. (m) | 🎯Task |
|---|---|---|---|---|---|---|
| CA | Traffic Flow | 35040 | 8600 | True | 15 | Large Scale |
| GBA | Traffic Flow | 35040 | 2352 | True | 15 | Large Scale |
| GLA | Traffic Flow | 35040 | 3834 | True | 15 | Large Scale |
| SD | Traffic Flow | 35040 | 716 | True | 15 | Large Scale |
Pre-training Corpus
| 🏷️Name | 🌐Domain | 📏Length | 📊Time Series Count | 🔄Graph | ⏱️Freq. | 🎯Task |
|---|---|---|---|---|---|---|
| BLAST | Multiple | 4096 | 20000000 | False | Multiple | UFM |
🔗 EasyTorch
BasicTS 是基于 EasyTorch 开发的,这是一个易于使用且功能强大的开源神经网络训练框架。
EasyTorch 的主要特点:
| 特点 | 说明 |
|---|---|
| 简化训练流程 | 封装了 PyTorch 训练的常见模式,减少样板代码 |
| 分布式训练支持 | 内置 DDP(分布式数据并行)支持,一行代码启动多 GPU 训练 |
| 设备管理 | 统一的设备管理(CPU/GPU/MLU),自动处理数据迁移 |
| 检查点管理 | 提供 save_ckpt、load_ckpt、backup_last_ckpt 等便捷函数 |
| 日志系统 | 统一的 Logger 接口,支持 TensorBoard |
| 环境配置 | 随机种子设置、确定性训练、TF32 模式等 |
EasyTorch 提供的核心功能
设备管理
from easytorch.device import set_device_type, to_device
set_device_type("gpu") # 设置使用 GPU
data = to_device(data) # 自动将数据移到正确设备
分布式训练
from easytorch.launcher.dist_wrap import dist_wrap
# 自动包装为分布式训练
train_dist = dist_wrap(
training_func,
node_num=1, # 节点数
device_num=4, # GPU 数量
dist_backend="nccl"
)
train_dist(cfg)
检查点管理
from easytorch.core.checkpoint import save_ckpt, load_ckpt
save_ckpt(model, optimizer, epoch, path) # 保存检查点
load_ckpt(model, path) # 加载检查点
工具函数
from easytorch.utils import get_logger, is_master, master_only
logger = get_logger("MyProject") # 获取日志器
@master_only # 只在主进程执行
def save_results():
...
if is_master(): # 判断是否为主进程
print("Main process")
环境设置
from easytorch.utils.env import setup_determinacy, set_tf32_mode
setup_determinacy(seed=42) # 设置随机种子,确保可复现
set_tf32_mode(True) # 启用 TF32 加速
总结
EasyTorch 的定位:
PyTorch (底层) → EasyTorch (训练框架) → BasicTS (时间序列工具包)
EasyTorch 负责:设备管理、分布式训练、检查点、日志 BasicTS 负责:时间序列数据、模型、评估指标
快速上手
📦 安装 BasicTS
建议在 Linux 系统(如 Ubuntu 或 CentOS)上,在 Python 3.8 或更高版本上安装 BasicTS:
pip install basicts
推荐使用 Miniconda 或 Anaconda 来创建虚拟 Python 环境。
🔧 安装依赖项
PyTorch
BasicTS 对 PyTorch 版本非常灵活。您可以根据 Python 版本安装 PyTorch。我们建议使用 pip 进行安装。
示例设置
示例 1:Python 3.11 + PyTorch 2.5.1 + CUDA 12.4 (推荐)
# 安装 Python
conda create -n BasicTS python=3.11
conda activate BasicTS
# 安装 PyTorch
pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124
示例 2:Python 3.9 + PyTorch 1.10.0 + CUDA 11.1
# 安装 Python
conda create -n BasicTS python=3.9
conda activate BasicTS
# 安装 PyTorch
pip install torch==1.10.0+cu111 torchvision==0.11.0+cu111 torchaudio==0.10.0 -f https://download.pytorch.org/whl/torch_stable.html
🔍 下载数据集
请先从 Google Drive or 百度网盘 下载 datasets.zip 文件。将文件解压到 datasets/ 目录:
cd /path/to/YourProject # not BasicTS/basicts
unzip /path/to/datasets.zip -d datasets/
这些数据集已预处理完毕,可以直接使用。
data.dat文件是以numpy.memmap格式存储的数组,包含原始时间序列数据,形状为 [L, N, C],其中 L 是时间步数,N 是时间序列数,C 是特征数。
desc.json文件是一个字典,存储了数据集的元数据,包括数据集名称、领域、频率、特征描述、常规设置和缺失值。其他文件是可选的,可能包含附加信息,如表示时间序列间预定义图结构的
adj_mx.pkl。
如果您对预处理步骤感兴趣,可以参考预处理脚本
blob/master/scripts/data_preparation和raw_data.zip。
🎯 快速教程:三句代码训练并评估您的模型
# train.py
from basicts.models.DLinear import DLinear, DLinearConfig
from basicts.configs import BasicTSForecastingConfig
from basicts import BasicTSLauncher
def main():
# 1. 配置模型
model_config = DLinearConfig(input_len=336, output_len=336)
# 2. 配置任务
cfg = BasicTSForecastingConfig(
model=DLinear,
model_config=model_config,
dataset_name="ETTh1",
gpus="0",
...
)
# 3.启动训练
BasicTSLauncher.launch_training(cfg)
第一步:配置您的模型
BasicTS在basicts.models中提供了大量常用模型,您可以直接使用。BasicTS使用配置类来配置模型,每个配置类包含了对构造模型需要的每个参数的详细描述。例如,DLinear模型的配置类为 DLinearConfig。您可以在 dinear_config.py 中找到DLinearConfig类。
如果您想要使用自己的模型,则需要遵循 BasicTS 的规范,详情请见下面:🧠模型设计。
第二步:配置您的任务
BasicTS 支持多种时间序列任务,包括预测、插补、分类等。任务配置类是 BasicTS的核心,关于 BasicTS 任务的一切信息都囊括在了任务配置类中,几乎全部配置项都有常用的默认值,您只需配置关键的参数(1️⃣模型,2️⃣数据集),并修改少量配置(如batch size,学习率等),就能运行代码。
您可以在basicts/configs 中找到每个 BasicTS 任务的配置类 (例如,预测任务的配置类为BasicTSForecastingConfig),及其每个参数的含义与配置方法。
进一步,在 BasicTS 的配置类中,您还可以指定回调(callbacks)和任务流(taskflow)用于在训练过程中执行额外的操作(如课程学习)和自定义数据处理流程。关于BasicTS配置类的进阶用法,请见:📜 配置设计。
第三步:启动训练
BasicTSLauncher.launch_training 是训练的入口点,调用该方法并传入任务配置即可启动训练。
需要注意的是,在DDP模式下,BasicTSLauncher.launch_training 需要被包裹在if __name__ == '__main__':中,以确保每个进程都能正确初始化模型和数据集。
🥳 运行它!
完整可运行示例
# 完整可运行示例
from basicts.models.DLinear import DLinear, DLinearConfig
from basicts.configs import BasicTSForecastingConfig
from basicts import BasicTSLauncher
def main():
# 1. 配置模型
model_config = DLinearConfig(input_len=96, output_len=96)
# 2. 配置任务
cfg = BasicTSForecastingConfig(
model=DLinear,
model_config=model_config,
dataset_name="ETTh1",
gpus="0",
batch_size=32,
num_epochs=100,
learning_rate=0.001,
)
# 3. 启动训练
BasicTSLauncher.launch_training(cfg)
if __name__ == "__main__":
main()
在您项目的目录下,运行以下命令即可启动训练:
python train.py
在训练中,BasicTS会默认将训练好的模型保存到 checkpoints/ 目录下,并在训练完成后执行评估(可以通过配置更改),您也可以选择将评估指标和结果保存到 checkpoints/ 目录下。
如何评估您的模型
当然,您也可以在训练结束后手动评估模型:BasicTSLauncher.launch_evaluation 是评估的入口点,您可以通过执行下面的Python代码来评估您的模型。
BasicTSLauncher.launch_evaluation(cfg, "checkpoints/your_checkpoint.pt")
函数签名
BasicTSLauncher.launch_evaluation(
cfg, # 配置对象
ckpt_path, # 检查点文件路径
gpus=None, # 可选:使用的 GPU
batch_size=None # 可选:评估时的批次大小
)
参数解释
| 参数 | 类型 | 说明 |
|---|---|---|
cfg | BasicTSConfig | 任务配置(与训练时相同) |
ckpt_path | str | 训练保存的模型检查点文件路径 |
gpus | str | 使用哪些 GPU,如 "0" 或 "0,1" |
batch_size | int | 评估时的批次大小(可覆盖配置中的值 |
基本用法
from basicts.models.DLinear import DLinear, DLinearConfig
from basicts.configs import BasicTSForecastingConfig
from basicts import BasicTSLauncher
# 1. 配置(与训练时相同)
model_config = DLinearConfig(input_len=96, output_len=96)
cfg = BasicTSForecastingConfig(
model=DLinear,
model_config=model_config,
dataset_name="ETTh1",
)
# 2. 评估已训练的模型
BasicTSLauncher.launch_evaluation(
cfg,
ckpt_path="checkpoints/DLinear_ETTh1/best_model.pt"
)
🧑💻 进一步探索
本教程为您提供了 BasicTS 的基础知识,但还有更多内容等待您探索。在深入其他主题之前,我们先更详细地了解 BasicTS 的结构:

BasicTS 的核心组件包括 Dataset、Scaler、Model、Metrics、Runner 和 Config。
🎯 模型选择指南
| 场景 | 推荐模型 | 原因 |
|---|---|---|
| 快速基线 | DLinear, NLinear | 简单高效,训练快 |
| 长序列预测 | PatchTST, iTransformer | 专为长序列设计 |
| 交通预测 | STID, D2STGNN | 支持图结构 |
| 资源受限 | SparseTSF | 仅 1K 参数 |
| 零样本预测 | TimeMoE, ChronosBolt | 预训练通用模型 |
| 多变量关系 | Crossformer, TimesNet | 建模变量间依赖 |
💡 BasicTS 的整体设计
在时间序列分析流程中,通常包含以下几个关键部分:
- 数据集 (Dataset):定义读取数据集和生成样本的方式。(位于
basicts.data模块) - 数据缩放 (Scaler):处理数据的归一化与反归一化操作,如 Z-score 和 Min-Max 归一化等方法。(位于
basicts.scaler模块) - 评估指标 (Metrics):定义模型评估的指标和损失函数,包括 MAE、MSE、MAPE、RMSE 和 WAPE 等。(位于
basicts.metrics模块) - 执行器 (Runner):作为 BasicTS 的核心模块,负责协调整个训练过程。执行器集成了数据集、数据缩放、模型架构和评估指标等组件,提供了多种功能支持。(位于
basicts.runner模块) - 模型结构 (Model):定义模型架构及其前向传播过程(位于
basicts.models模块)
BasicTS 提供了完整且可扩展的组件,用户只需提供模型结构,即可完成大部分任务。
为了简化训练策略配置并方便对比,BasicTS 遵循 一切基于配置 的设计理念:所有选项都集成在配置文件中,用户只需修改配置文件即可轻松设置模型、数据集、归一化方法、评估指标、优化器、学习率、甚至添加扩展功能——一切就像填写表单一样简单。
📦 数据集设计
🎸 新特性
从1.0版本开始,BasicTS采用数据解耦的设计方案,用户可以使用任意数据结构的数据集,只需继承BasicTSDataset基类实现自定义的数据集读取逻辑即可。
从1.0版本开始,BasicTS不再将数据和时间戳整合在一个四维张量中存储([batch_size, seq_len, num_features, num_timestamps + 1]),改为使用两个三维张量,显著降低了显存占用。
- 时序数据: [batch_size, seq_len, num_features];
- 时间戳数据:[batch_size, num_features, num_timestamps]。
📊 内置数据集
⏬ 数据下载
要开始使用内置数据集,请先从 Google Drive or 百度网盘 下载 datasets.zip 文件。下载后,将文件解压至 datasets/ 目录:
cd /path/to/project
unzip /path/to/all_data.zip -d datasets/
这是BasicTS默认的数据集保存路径,当然,您也可以将数据集放在任意其他路径下,并在dataset_params中的data_file_path字段显式的提供根路径。
这些数据集已经过预处理,可以直接使用。
未来将会支持在线下载内置数据集,该功能目前正在开发中。
🔬 使用内置数据集
内置数据集通常配合BasicTS内置支持的数据集类使用,内置数据类也是配置中的默认选项。 内置数据集类:
- 预测任务:
BasicTSForecastingDataset; - 分类任务:
UEADataset; - 插补任务:
BasicTSImputationDataset。
这些内置数据集类包括以下参数:
dataset_name(str): 数据集的名称。input_len(int): 输入序列的长度,即历史数据点的数量。output_len(int): (仅预测任务)输出序列的长度,即需要预测的未来数据点的数量。mode(BasicTSMode | str): 数据集的模式,"TRAIN", "VAL"或"TEST",指示其用于训练、验证还是测试。由runner统一指定,无需手动赋值。use_timestamps (bool): 是否使用时间戳的标志,默认为False。local (bool): 数据集是否在本地。(开发中)data_file_path(str | None): 包含时间序列数据文件的路径。默认为 "datasets/{name}"。memmap(bool): 是否使用内存映射加载数据集的标志。开启时节省内存但会降低训练速度,因此建议仅在数据集极大时使用。默认为False。
通常来说,默认设置下使用内置数据集,只需在配置类中指定dataset_name、input_len以及output_len(预测任务)即可。
💿 数据格式
在BasicTS中,数据集提供的数据需要遵循标准的格式。__get_item__方法应返回一个包含以下项目的字典:
inputs:输入数据,形状为[batch_size, input_len, num_features]的torch.Tensor;targets:目标数据(可选)。一个torch.Tensor。对于预测和插补任务,形状为[batch_size, output_len, num_features],对于分类任务,形状为[batch_size, num_classes],对于自监督任务,不需要该键;inputs_timestamps:输入数据的时间戳(可选),形状为[batch_size, input_len, num_timestamps]的torch.Tensor;targets_timestamps:输入数据的时间戳(可选),形状为[batch_size, output_len, num_timestamps]的torch.Tensor。
🧑🍳 如何添加或自定义数据集
您可以通过以下三步使用您自定义的数据集:
- 编写数据集类,继承
BasicTSDataset基类,基类包含三个字段:dataset_name,mode,memmap。 - 自定义实现您的数据读取和预处理逻辑,实现
__get_item__和__len__方法。请注意,虽然数据实际的存储结构可以是任意的,但__get_item__方法返回的数据项应该遵循上文提到的规范。 - 如果需要使用缩放器对数据做归一化,还需重写data方法(property)。该方法用于向缩放器提供一个待归一化数据的视图(np.ndarray),使缩放器学习整个训练集的分布。
- 在配置类中修改
dataset_type字段为您自己的数据集类,并设置相应的dataset_params。
🛠️ 数据缩放器设计 (Scaler)
🧐 什么是数据缩放器,为什么需要它?
数据缩放器(简称缩放器)是一个用于处理数据归一化和反归一化的类。在时间序列分析中,原始数据通常具有显著的尺度差异。因此,模型(尤其是深度学习模型)通常不会直接在原始数据上进行操作。相反,缩放器会将数据归一化到一个特定范围内,使其更适合建模。在计算损失函数或评估指标时,也可能将数据反归一化回原始尺度,以确保比较的准确性。
这使得缩放器成为时间序列分析工作流程中的重要组件。
👾 缩放器如何初始化及何时起作用?
缩放器与其他组件一起在执行器中初始化。
例如,Z-Score 缩放器会读取原始数据,并基于训练数据计算均值和标准差。
缩放器在从数据集中提取数据后起作用。数据首先由缩放器归一化,然后传递给模型进行训练。模型处理完数据后,缩放器会将输出反归一化,然后再传递给执行器进行损失计算和指标评估。
在许多时间序列分析研究中,归一化通常在数据预处理中进行,这也是早期 BasicTS 版本的做法。然而,这种方式的可扩展性较差。诸如更改输入/输出长度、应用不同的归一化方法(例如对每个时间序列单独归一化),或更改训练/验证/测试集的比例等调整,都会要求重新预处理数据。为了解决这个问题,BasicTS 采用了“即时归一化”的方式,每次提取数据时都会进行归一化处理。
# 在 runner 中
for data in dataloader:
data = scaler.transform(data)
forward_return = forward(data)
forward_return = scaler.inverse_transform(forward_return)
🧑🔧 如何选择或自定义缩放器
BasicTS 提供了几种常见的缩放器,例如 Z-Score 和 Min-Max 缩放器。您可以通过在配置文件中设置 scaler 来轻松切换缩放器。
如果您需要自定义缩放器,可以扩展 BasicTSScaler 类,并实现 transform 和 inverse_transform 方法。或者,您也可以选择不继承该类,但仍然需要实现这两个方法。
常用缩放器对比
| 缩放器 | 公式 | 特点 |
|---|---|---|
| Z-Score | 标准化到均值0、标准差1,适合正态分布数据 | |
| Min-Max | 缩放到 [0, 1] 区间,保留原始分布形状 | |
| RevIN | 实例级归一化 + 反归一化 | 处理非平稳时间序列,在模型内部使用 |
选择建议:
- 大多数情况下使用 Z-Score(默认)
- 数据有明确边界时使用 Min-Max
- 非平稳数据考虑在模型中启用 RevIN
🧠 模型设计
您的模型的 forward 函数应遵循 BasicTS 设定的规范。
🏗️ 构造模型
BasicTS使用配置类/字典构造模型,该配置类/字典应该包含构造模型所需的全部参数。 BasicTS模型配置类的基类为BasicTSModelConfig,其本身是字典的子类。当使用配置类构造模型时,您可以继承这一基类定义您的模型的配置。例如:
@dataclass
class YourModelConfig(BasicTSModelConfig):
input_len: int
output_len: int
num_features: int
hidden_size: int = 256
hidden_act: int = "relu"
class YourModel(nn.Module):
def __init__(config: YourModelConfig):
...
[!important]
⚠️注意:强烈建议在配置中只使用可以JSON序列化的字段(数值、字符串、布尔、列表、元组、字典等),避免将自定义类作为字段,否则配置文件可能无法被正常保存。
🪴 输入接口
BasicTS 自1.0起,forward函数不再强制要求传入固定的参数(尽管未使用),而是可以按需指定传入的参数。然而,传入参数需要遵守以下规范。
-
标准模型参数:BasicTS 1.0 标准的
forward参数命名如下。模型的主输入为inputs,输出为targets;若使用时间戳,则时间戳数据为inputs_timestamps,targets_timestamps;若需要使用mask信息(如计算损失),则掩码数据为inputs_mask、targets_mask。此外,还可以传入当前训练的轮(epoch)数和步(step)数。注意,train参数即将被淘汰,可以访问nn.Module的training字段实现。def forward( self, inputs: torch.Tensor, targets: Optional[torch.Tensor] = None, inputs_timestamps: Optional[torch.Tensor] = None, targets_timestamps: Optional[torch.Tensor] = None, inputs_mask: Optional[torch.Tensor] = None, targets_mask: Optional[torch.Tensor] = None, epoch: Optional[int] = None, step: Optional[int] = None, train: Optional[bool] = None ,**kwargs ):假设模型只需要用到输入序列及其时间戳,则:
class MyModel(nn.Module): def forward(self, inputs: torch.Tensor, inputs_timestamps: torch.Tensor): ... -
自定义模型参数:您可以在
forward函数中加入任何自定义模型参数,但需要保证数据字典中包含该键。例如:# 如果包含extra_flag这个额外的参数,则需要保证传入的数据字典中包含该键: # {"inputs": inputs, "extra_flag": extra_flag, ...} def forward(self, inputs: torch.Tensor, extra_flag: bool): ...您可以在数据流上游添加或修改数据字典:Dataset或taskflow。 数据流向:
Dataset.__get_item__->taskflow.preprocess->model.forward-
在
Dataset.__get_item__中添加(推荐):在数据集的__get_item__函数中返回包含该键的字典。 例如:class MyDataset(torch.utils.data.Dataset): def __get_item__(self, idx: int): return { "inputs": self.inputs[idx], "targets": self.targets[idx], "extra_flag": self.flag[idx] # <-- add extra_flag } -
在
taskflow.preprocess改变数据字典:在自定义Taskflow类的preprocess可以修改数据字典。由于涉及对任务逻辑的修改,建议新用户谨慎使用该方法。 例如:class MyTaskflow(BasicTSTaskflow): def preprocess(self, data: dict): ... data["extra_flag"] = self.extra_flag # <-- add extra_flag return data
-
🌷 输出接口
forward 函数的返回值应该是一个字典或一个torch.Tensor。
- 字典中必须包含键
prediction,代表模型的预测结果。 - 若返回值为一个
torch.Tensor,则后续pipeline会自动将其包装成字典{"prediction":...},从而计算损失。 - 字典中可以添加任意您自定义的键,用于实现自定义逻辑或计算评估指标等。
- 想要返回在模型内部计算的损失时,必须返回包含键
loss的字典(若直接传一个损失的torch.Tensor则会被视作预测结果)。当字典中包含loss时,后续pipeline不会再计算损失,而是直接取用。 - 想要返回在内部计算的额外损失,并与主损失相加时,须在配置类中使用
AddAuxiliaryLoss的callback,并指定额外损失的键名。例如,传递名为freq_loss和lb_loss的额外损失,使最终损失为MSE + freq_loss + lb_loss:
# in your_train_script.py
config=BasicTSConfig(
loss=masked_mse,
callback=[AddAuxiliaryLoss([`freq_loss`, `lb_loss`])],
...
)
# in your_model.py
def forward(...):
return {
"prediction": prediction,
"freq_loss": freq_loss,
"lb_loss": lb_loss
}
🥳 支持的基线模型
BasicTS 提供了多种内置模型。您可以在models 模块中找到它们,并只需导入对应的模型类和模型配置类即可使用模型。以使用STID为例:
from basicts.models.STID import STID, STIDConfig
task_config = BasicTSForecastingConfig(
model=STID,
model_config=STIDConfig,
...
)
特别地,对于内置的多任务模型,通常包含一个公用的骨干网络(XXXBackbone,XXX为模型名),以及若干个任务特定的模型(XXXForYYY,YYY为任务名)。以TimesNet为例,可以导入TimesNetForForecasting进行预测任务,TimesNetForClassification进行分类任务,TimesNetForReconstruction进行插补任务。这些下游任务公用相同的骨干网络和相同的配置类。
from basicts.models.TimesNet import TimesNetBackbone, TimesNetForForecasting, TimesNetForClassifiction, TimesNetForReconstruction, TimesNetConfig
📉 评估指标设计
接口规范
评估指标是评估模型性能的重要组成部分。在 BasicTS 中,评估指标是接受模型预测值、真实值及其他参数作为输入并返回标量值以评估模型性能的函数。
一个定义良好的评估指标函数应包含以下参数:
- prediction: 模型的预测值
- targets: 实际的真实值
- targets_mask: 可选参数,用于指定在哪些点上计算损失(一般用于掩码缺失值)。
prediction 和 target 是必需参数,而 targets_mask 是可选参数,但强烈建议采纳,以处理时间序列数据中常见的缺失值。
评估指标函数还可以接受其他额外参数,这些参数会从模型的返回值中提取并传递给指标函数。
BasicTS 内置评估指标
BasicTS 提供了多种常用的评估指标,例如 MAE、MSE、RMSE、MAPE 和 WAPE。您可以在 basicts.metrics 模块中找到这些指标的实现。
常用评估指标详解
| 指标 | 公式 | 适用场景 | ||||
|---|---|---|---|---|---|---|
| MAE | $\frac{1}{n}\sum | y_i - \hat{y}_i | $ | 对异常值不敏感,适合一般预测任务 | ||
| MSE | 对大误差惩罚更重,常用作损失函数 | |||||
| RMSE | 与原数据同量纲,便于解释 | |||||
| MAPE | $\frac{100%}{n}\sum | \frac{y_i - \hat{y}_i}{y_i} | $ | 百分比误差,适合比较不同量级数据 | ||
| WAPE | $\frac{\sum | y_i - \hat{y}_i | }{\sum | y_i | }$ | 加权百分比误差,避免 MAPE 除零问题 |
注意:当真实值接近 0 时,MAPE 会趋向无穷大,此时建议使用 WAPE 或 MAE。
如何实现自定义评估指标
根据接口规范中的指南,您可以轻松实现自定义的评估指标。以下是一个示例:
class MyModel:
def __init__(self):
# 初始化模型
...
def forward(...):
# 前向计算
...
return {
'prediction': prediction,
'targets': target,
'other_key1': other_value1,
'other_key2': other_value2,
'other_key3': other_value3,
...
}
def my_metric_1(prediction, targets, targets_mask=None, other_key1=None, other_key2=None, ...):
# 计算指标
...
def my_metric_2(prediction, targets, targets_mask=None, other_key3=None, ...):
# 计算指标
...
遵循这些规范,您可以灵活地在 BasicTS 中自定义和扩展评估指标,以满足特定需求。
🧮 仪表盘
该节仅涉及细节内容,绝大部分情况下不会影响使用,可以跳过。
在BasicTS中,我们使用仪表盘(Meter类)在训练中维护指标值。BasicTS会默认使用平均仪表盘(AvgMeter类),逐步更新并维护对应指标的均值,这适用于绝大部分指标。
然而,也有一些指标不应该维护均值,例如RMSE,是先求平均再开平方,此时如果逐步累积最后再求平均则会产生错误(虽然一般不影响模型的训练结果)。此时,应该使用特殊的仪表盘,实现正确的增量计算。
🏃♂️ 执行器与流程
💿 概述
执行器是 BasicTS 的核心组件,负责管理整个训练和评估过程。它将数据集、数据缩放器、模型、评估指标和配置文件等各个子组件集成在一起,构建一个公平且可扩展的训练和评估流程。
自BasicTS 1.0起,BasicTS只需要一个执行器类BasicTSRunner,并对其进行了全面重构和解耦。您无需再修改任何执行器代码,就能实现任何自定义的扩展功能。
BasicTS训练与评估流程的三层架构:重构后的BasicTS的训练与评估流程可以被分为三个层次。
- 执行器与通用流程层(
BasicTSRunner):集结了一切基础流程中通用的、和具体任务无关的训练流程。用户不应该直接修改该层次的代码。 - 任务流层(
BasicTSTaskflow):定义了基础流程中和任务相关的步骤。当不修改任务流程时,用户应该尽量少地自定义该层的对象。 - 回调层(
BasicTSCallback):定义了基础流程之外的扩展功能,例如早停、梯度裁剪、课程学习等。当想要扩展功能时,用户应该尽可能地通过回调来实现。
⚡️ 通用流程
以训练为例(评估类似),执行器实现的通用流程如下列伪代码所示。 与标准深度学习框架相符,通用流程包括:模型前传、计算损失、损失反传、优化器更新。
def train_loop(self):
for epoch in range(num_epochs):
# Event 1: on_epoch_start events
callback_handler.trigger("on_epoch_start")
for data in train_data_loder:
# Event 2: on_step_start events
callback_handler.trigger("on_step_start")
# Task-specific 1: preprocess data
data = taskflow.preprocess(self, data)
# General pipeline 1: model forward
forward_return = forward()
# Event 3: on_compute_loss events
callback_handler.trigger("on_compute_loss")
# General pipeline 2: compute loss
loss = metric_forward(loss_function, forward_return)
# Task-specific 2: get loss weight
loss_weight = taskflow.get_weight(forward_return)
# Event 4: on_backward events
callback_handler.trigger("on_backward") # on_backward events
# General pipeline 3: loss backward
loss.backward()
# Event 5: on_optimizer_step events
callback_handler.trigger("on_optimizer_step")
# General pipeline 4: optimizer step
optimizer_step()
# Task-specific 3: postprocess forward return
forward_return = taskflow.postprocess(self, forward_return)
# General pipeline 5: compute metrics
metric_value = metric_forward(metric_fn, forward_return)
# Event 6: on_step_end events
callback_handler.trigger("on_step_end")
# Event 7: on_epoch_end events
callback_handler.trigger("on_epoch_end")
💫 任务流
任务流模块位于basicts.runners.taskflow,其基类定义如下:
class BasicTSTaskflow():
def preprocess(self, runner, data):
pass
def postprocess(self, runner, forward_return):
pass
def get_weight(self, forward_return):
pass
preprocess:定义数据在模型前传前的预处理逻辑,包括归一化、生成缺失值掩码等。postprocess:定义数据在计算指标前的后处理逻辑,包括反归一化(预测任务),计算argmax(分类任务)等。get_weight:定义当前批次在全部训练数据中的损失权重,保证数据集的整体损失能被正确计算。例如,分类任务的权重应该是该批次的样本数,预测任务应该是该批次全部有效点的数量。
🪝 回调层
回调模块位于basicts.runners.callback。一个回调类应该包含若干个回调函数,执行器的CallbackHandler对象会在对应的阶段调用这些函数,以实现功能的扩展。
回调基类BasicTSCallback定义了全部可用的回调函数:
class BasicTSCallback:
# 训练开始时
def on_train_start(self, runner, *args, **kwargs):
pass
# 训练结束时
def on_train_end(self, runner, *args, **kwargs):
pass
# epoch开始时
def on_epoch_start(self, runner, *args, **kwargs):
pass
# epoch结束时
def on_epoch_end(self, runner, *args, **kwargs):
pass
# step开始时
def on_step_start(self, runner, *args, **kwargs):
pass
# step结束时
def on_step_end(self, runner, *args, **kwargs):
pass
# 验证开始时
def on_validate_start(self, runner, *args, **kwargs):
pass
# 验证结束时
def on_validate_end(self, runner, *args, **kwargs):
pass
# 测试开始时
def on_test_start(self, runner, *args, **kwargs):
pass
# 测试结束时
def on_test_end(self, runner, *args, **kwargs):
pass
# 计算损失前
def on_compute_loss(self, runner, *args, **kwargs):
pass
# 反向传播前
def on_backward(self, runner, *args, **kwargs):
pass
# 优化器更新前
def on_optimizer_step(self, runner, *args, **kwargs):
pass
常用内置回调示例
1. 早停 (Early Stopping)
from basicts.runners.callback import EarlyStopping
cfg = BasicTSForecastingConfig(
callbacks=[
EarlyStopping(patience=10, monitor="val_loss")
],
...
)
2. 梯度裁剪 (Gradient Clipping)
from basicts.runners.callback import GradientClipping
cfg = BasicTSForecastingConfig(
callbacks=[
GradientClipping(max_norm=1.0)
],
...
)
3. 课程学习 (Curriculum Learning)
from basicts.runners.callback import CurriculumLearning
cfg = BasicTSForecastingConfig(
callbacks=[
CurriculumLearning(start_len=12, end_len=96, warmup_epochs=10)
],
...
)
📜 配置设计
BasicTS 的设计理念是“配置即一切“。BasicTS 的目标是让用户专注于模型和数据,而不用被繁琐的流程构建所困扰。
🎸 新特性
从1.0版本开始,BasicTS不再使用py文件配合命令行指定配置路径的方式,升级为使用配置类进行配置。 配置类的基类为BasicTSConfig,此外,每个具体任务对应一个配置类,包括BasicTSForecastingConfig,BasicTSClassificationConfig,BasicTSImputationConfig,BasicTSFoundationModelConfig等。基类BasicTSConfig定义了公用的字段以及保存/加载/打印配置类的方法,任务特定的配置类则包括执行该任务需要的一切配置参数。您可以灵活地向其中导入模型,并设置所有必要的选项。
配置类通常包含以下部分:
- 常规选项: 描述一般设置,如配置说明、
gpus、seed等。 - 环境选项: 包括设置如
tf32、cudnn、deterministic等。 - 数据集选项: 指定
dataset_name(数据集名,必须显式指定)、dataset_type(数据集类)、dataset_params(数据集参数)等。 - 数据缩放器选项: 指定
scaler(缩放器类)、norm_each_channel(通道独立归一化)、rescale(是否反归一化)等。 - 模型选项: 指定
model(模型类,必须显式指定)、model_config(模型参数,必须显式指定)等。 - 评估指标选项: 包括
metrics(评估指标函数)、target_metric(目标评估指标)等。 - 训练选项:
- 常规: 指定设置如
num_epochs/num_steps、loss等。 - 优化器: 指定
optimizer(优化器类)、optimizer_params(优化器参数)等。 - 调度器: 指定
lr_scheduler(调度器类)、lr_scheduler_params(调度器参数)等。 - 数据: 指定设置如
batch_size、num_workers、pin_memory等。
- 常规: 指定设置如
- 验证选项:
- 常规: 包括验证频率
val_interval。 - 数据: 指定设置如
batch_size、num_workers、pin_memory等。
- 常规: 包括验证频率
- 测试选项:
- 常规: 包括测试频率
test_interval。 - 数据: 指定设置如
batch_size、num_workers、pin_memory等。
- 常规: 包括测试频率
Config类字段的metadata提供了每个字段的默认值、详细含义及使用方法。
🏗️ 构造配置类
👥 既是类,也是字典
BasicTS的配置类继承自EasyDict,既可以作为类使用,也可以作为字典使用,方便扩展并且灵活易用。下面介绍两种使用方式:
-
像类一样使用:一切参数都是类的字段,可以像访问类的字段一样访问、修改以及添加新的参数。例如:
model = config.model config.gpus = "0" config.new_field = new_value -
像字典一样使用:一切参数都是字典的键值,可以像访问字典一样访问、修改、添加、删除参数。例如:
model = config["model"] optimizer = config.get("optimizer", Adam) config["new_key"] = new_value config.pop("not_used")
🔨 配置类中的对象
也许你已经发现了,配置中存在许多需要进一步构造的对象,例如模型、数据集、优化器、调度器等。BasicTS在配置中传入对应的类和创建对象所需的参数,而将这些对象的创建延迟到实际执行任务时。
对于这些对象,BasicTS支持两种方式来灵活地配置:
-
将构造对象所需的全部参数以字典的形式传入。例如,将参数字典传给dataset_params以供后续创建数据集的实例。
config = BasicTSForecastingConfig( dataset_params={ 'input_len': 336, 'output_len': 336, ...}, ...) -
直接将构造对象需要的参数作为字段传入。例如,可以直接配置Config类的
input_len,output_len,并将自定义的参数传入(不能直接在构造方法中添加未定义的字段)。config = BasicTSForecastingConfig( input_len=336, output_len=336, ...) config.your_dataset_param_1 = param_1 config.your_dataset_param_2 = param_2
❓ 常见问题 (FAQ)
Q1: 训练时出现 CUDA out of memory 怎么办?
A: 尝试以下方法:
- 减小
batch_size - 减小
input_len或output_len - 使用
memmap=True加载大数据集 - 使用混合精度训练
Q2: 如何使用多 GPU 训练?
A: 修改配置中的 gpus 参数:
cfg = BasicTSForecastingConfig(
gpus="0,1,2,3", # 使用 4 张 GPU
...
)
Q3: 如何恢复中断的训练?
A: 使用检查点恢复:
cfg.resume_from = "checkpoints/last_checkpoint.pt"
Q4: 如何只进行推理而不训练?
A: 使用 launch_evaluation 并加载预训练权重。
Q5: 自定义数据集需要什么格式?
A: 数据需要是 [L, N, C] 形状的 numpy 数组,其中:
L = 时间步数 N = 变量数 C = 特征数(通常为 1)
声明
本文内容整理自 BasicTS 官方 GitHub 仓库 及其官方文档,结合个人学习理解进行归纳总结,仅供学习交流使用。
文中引用的论文链接、代码链接均指向原作者的官方发布地址,版权归原作者所有。如有任何内容侵犯了您的权益,请通过评论区或私信联系我,我将在核实后第一时间删除相关内容。
感谢 BasicTS 开发团队的开源贡献!
欢迎关注我的公众号:「木子吉星」
