Mitra:混合合成先验增强表格基础模型
生成多样化的合成先验分布,使得表格基础模型能够超越特定任务的基线模型。
表格数据支撑着医疗、金融、电商和科学等领域的關鍵决策。然而,传统用于表格数据的机器学习方法——例如随机森林和 XGBoost——通常会产生针对单个数据集的模型,跨不同分布的迁移能力有限。
受大型语言模型成功的启发,表格基础模型有望改变这一现状:无需为每个任务单独训练模型,单个预训练模型只需通过 conditioning 在中等数量的示例上就能泛化到新任务,这种技术称为上下文学习。
作为某机构自动机器学习框架 AutoGluon 最新版本的一部分,我们在此介绍 Mitra,一个在此类基于上下文学习的范式内训练的表格基础模型。就像大型语言模型在多样化的文本语料库上训练一样,Mitra 在由精心设计的混合先验分布生成的合成数据集上进行预训练。
乍一看,在预训练 Mitra 时未使用任何真实世界数据可能令人惊讶。但现实中的表格数据通常是有限且异构的,具有变化的特征类型、依赖关系和噪声水平。事实证明,模拟覆盖广泛可能数据模式的多样化合成数据集更为实用。
我们发现这些合成先验的质量对模型的泛化能力起着关键作用。有效的先验倾向于:(1) 在真实任务上产生良好性能;(2) 展现多样性,防止过拟合;(3) 提供其他先验中不存在且独特的模式。
基于这些原则,我们构建了一个混合先验,其中包括结构因果模型——它将变量间的因果依赖关系图与描述每个变量值变化对其因变量影响的(概率)方程相结合;以及流行的基于树的方法,如梯度提升、随机森林和决策树。这些先验共同使 Mitra 能够学习鲁棒的表示,并有效地泛化到各种各样的真实世界表格问题上。
Mitra 框架概览。 我们在混合合成数据先验(包括结构因果模型和基于树的模型)上预训练表格基础模型。每个数据集被分割成支撑集和查询集。Mitra 支持跨行和列的 2D 注意力以及 1D 逐行注意力。在推理时,模型以来自真实数据集的支撑示例为条件,使用上下文学习来预测查询标签,无需梯度更新。
我们在选定的先验混合上预训练 Mitra。每个合成任务包含一个支撑集和一个查询集。模型通过学习关注支撑集来预测查询集的标签,无需梯度更新。经过数百万个这样的任务,Mitra 学到了可泛化的推理和自适应模式。其架构基于跨行和特征的 2D 注意力,能够灵活处理不同大小的表格和特征交互。
我们在分类和回归任务上,跨越 TabRepo、TabZilla、AMLB 和 TabArena 等主要表格基准对 Mitra 进行了评估。与强大的表格基础模型(如 TabPFNv2 和 TabICL)以及特定数据集的模型(如 CatBoost、RealMLP 和 AutoGluon 1.3 最佳质量预设)相比,Mitra 展现了最先进的性能。
Mitra 评估结果。 每个评估指标的胜者和亚军分别以绿色和蓝色显示。缩写 +e 表示上下文学习中的集成,+f 表示微调。Elo 评分括号内显示 95% 置信区间。聚合指标列中的数值是相应指标的均值(括号内为标准差)。
Mitra 和基线模型在二维正弦棋盘数据上的决策边界。 Mitra 显示出比 TabPFNv2 更规则、更少碎片化的决策边界。
正如基础模型重塑了计算机视觉和自然语言处理领域,Mitra 为表格数据预测提供了一种更通用、更有效的方法。随着该领域的进步,我们预见到更丰富的先验空间和自适应的混合策略。Mitra 已在 AutoGluon 1.4 版本中开源,可供使用。我们邀请研究者和实践者探索这个用于表格预测的新基础。
了解更多:
- Mitra 分类器
- Mitra 回归器
致谢: Junming Yin, Nick Erickson, Abdul Fatir Ansari, Boran Han, Shuai Zhang, Leman Akoglu, Christos Faloutsos, Michael W. Mahoney, Cuixiong Hu, Huzefa Rangwala, George Karypis, Bernie Wang
研究领域: 机器学习 标签: 表格数据FINISHED