Keras 3.0正式发布:可用于 TensorFlow、JAX 和 PyTorch

2 阅读3分钟

经过 5 个月的更新迭代,Keras 3.0 终于来了。

「大新闻:我们刚刚发布了 Keras 3.0 版本!」Keras 之父 François Chollet 在 X 上激动的表示。「现在你可以在 JAX、TensorFlow 以及 PyTorch 框架上运行 Keras……」

图片

对于这一更新,Keras 官方表示,这一版本足足花了他们 5 个月的时间进行公测才完成。Keras 3.0 是对 Keras 的完全重写,你可以在 JAX、TensorFlow 或 PyTorch 之上运行 Keras 工作流,新版本还具有全新的大模型训练和部署功能。你可以选择最适合自己的框架,也可以根据当前的目标从一种框架切换到另一种框架都没有问题。

图片

Keras 地址:keras.io/keras\\_3/

被 250 多万开发者使用的 Keras,迎来 3.0 版本

Keras API 可用于 JAX、TensorFlow 和 PyTorch。现有的仅使用内置层的 tf.keras 模型可以在 JAX 和 PyTorch 中运行!

图片

Keras 3 可与任何 JAX、TensorFlow 和 PyTorch 工作流无缝协作。Keras 3 不仅适用于以 Keras 为中心的工作流,比如定义 Keras 模型、优化器、损失和度量,它还旨在与 JAX、TensorFlow 和 PyTorch 低级后端本地工作流无缝集成,在训练 Keras 模型时,你可以选择使用 JAX 训练、TensorFlow 训练、PyTorch 训练,也可以将其作为 JAX 或 PyTorch 模型的一部分,上述操作都没有问题。Keras 3 在 JAX 和 PyTorch 中提供了与 tf.keras 在 TensorFlow 中相同程度的低级实现灵活性。

图片

预训练模型。你现在可以在 Keras 3 中使用各种预训练模型。现在已经有 40 个 Keras 应用模型可在后端中使用,此外,KerasCV 和 KerasNLP 中存在的大量预训练模型(例如 BERT、T5、YOLOv8、Whisper 、SAM 等)也适用于所有后端。

Keras 3 高度向后兼容 Keras 2:Keras 3 现在实现了 Keras 2 的公共 API 接口。大多数用户无需更改任何代码即可在 Keras 3 上运行 Keras 脚本。如果你还不习惯使用 Keras 3,可以选择忽略新版本的更新,继续将 Keras 2 与 TensorFlow 结合使用。

Keras 3 支持所有后端的跨框架数据 pipeline。多框架机器学习也意味着多框架数据加载和预处理。Keras 3 模型可以使用各种数据 pipeline 进行训练,无论你使用的是 JAX、PyTorch 还是 TensorFlow 后端:

  • tf.data.Dataset pipelines。

  • torch.utils.data.DataLoader 对象。

  • NumPy 数组和 Pandas 数据帧。

  • Keras 的 keras.utils.PyDataset 对象。

一个新的分布式 API,可用于大规模数据并行和模型并行。目前这一更新仅适用于 JAX 后端,TensorFlow 和 PyTorch 支持即将推出。

至于为何要推出这一更改,Keras 团队表示,近年来,随着模型规模变得越来越大,他们希望为多设备模型分片(sharding)问题提供 Keras 解决方案。该团队设计的 API 使模型定义、训练逻辑和分片配置完全独立,这意味着模型可以像在单个设备上运行一样, 然后,你可以在训练模型时将分片配置添加到任意模型中。

数据并行(在多个设备上相同地复制小模型)只需两行即可处理:

图片

接下来是模型并行。该 API 允许你通过正则表达式配置每个变量和每个输出张量的布局。这使得为整个变量类别快速指定相同的布局变得容易。

图片

最后,Keras 团队还收集了很多大家关心的问题,并予以解答,感兴趣的读者可以前去 Keras 官方网站,了解更多内容。