对于高性能的机器学习研究,Just After eXceution(JAX)是CPU、GPU和TPU上的NumPy,具有出色的自动区分功能。它是一个用于高性能数值计算的Python库,特别是机器学习研究。它的数值API是基于NumPy,一个用于科学计算的函数库。Python和NumPy都是著名的和常用的编程语言,使JAX变得直接、通用和简单的实现。本文将重点介绍JAX的特点和实现,以建立一个深度学习模型。以下是将要涉及的主题。
内容列表
- 使用JAX的原因
- 什么是XLA?
- JAX的生态系统里有什么?
- 用JAX构建ML模型
JAX不是google的官方产品,但它的受欢迎程度却在不断提高,让我们来了解一下受欢迎背后的原因。
使用JAX的原因
尽管JAX为开发加速数字代码提供了一个直接而强大的API,但使用JAX有效地工作,偶尔也需要额外的思考。JAX本质上是一个即时编译器(JIT),专注于生成高效的代码,同时利用纯Python的简单性。 除了NumPy API之外,JAX还包含一套可扩展的、有助于机器学习研究的可组合的函数转换,例如。
-
差异化:基于梯度的优化是机器学习的关键。JAX在正向和反向模式下使用函数转换(如Gradients、Hessian和Jacobian(jacfwd和jacrev))实现了任意数值函数的自动微分。
-
矢量化:在机器学习研究中,一个单一的函数经常被应用于大量的数据,如计算整个批次的损失或评估每个例子的梯度,用于差异化私有学习。JAX中的vmap转换可以实现自动矢量化,从而简化了这种类型的编程。例如,在开发新算法时,研究人员不需要考虑批处理。JAX还可以通过相关的pmap转换实现大规模的数据并行化,它可以优雅地分配对于单个加速器的内存来说过于庞大的数据。
-
及时编译(JIT):XLA被用来在GPU和云TPU加速器上进行JIT编译和运行JAX应用程序。JIT编译与JAX的NumPy一致的API相结合,使之前没有高性能计算经验的研究人员可以随时扩展到一个或多个加速器。
你是否在寻找一个完整的数据科学中使用的Python库。 在此查看.
什么是XLA?
XLA(加速线性代数)是一个特定领域的线性代数编译器,可以加速TensorFlow模型,只需修改少量源代码。
当一个TensorFlow程序被执行时,TensorFlow执行器独立地执行每个操作。执行器为每个TensorFlow操作调度到一个预编译的GPU内核实现。XLA提供了一种额外的模型执行方式,将TensorFlow图编译成一连串特别为指定模型建立的计算内核。因为这些内核是特定于模型的,它们可以使用特定于模型的信息来进行优化。
XLA的结构
XLA的输入语言被称为高级运算(HLO)。将HLO视为编译器的中间表示法是最方便的。因此,HLO代表 "在 "源语言和目标语言之间的程序。
XLA将HLO中描述的图形翻译成多个平台的机器指令。XLA是模块化的,即可以很容易地插入一个替代的后端,以针对一些创新的硬件架构。XLA在与目标无关的阶段之后将HLO计算转移到后端。后端可以进行额外的HLO级优化,这次要考虑到目标特定的数据和要求。
下面的步骤是生成目标特定的代码。XLA捆绑的CPU和GPU后端使用LLVM来进行低级别的中间表示优化和代码创建。这些后端产生有效描述XLA HLO计算所需的LLVM IR,然后使用LLVM从这个LLVM中间表示中发射出本地代码。
使用XLA的理由
使用XLA有四个主要原因:
- 因为根据定义,翻译似乎需要分析和综合。字对字的翻译是无效的。
- 将翻译的复杂挑战分为两个更简单、更容易管理的部分。
- 可以为现有的前端构建一个新的后端,以提供可重定向的编译器,反之亦然。
- 进行独立于机器的优化。
JAX的生态系统里有什么?
该生态系统由五个不同的库组成。
Haiku
处理有状态的对象,如具有可训练参数的神经网络,在JAX的可组合函数转换的编程范式下可能会很困难。Haiku是一个神经网络库,它使用户能够使用传统的面向对象的编程范式,同时利用JAX的纯函数式范式的力量和简单性。
一些外部项目,包括Coax、DeepChem和NumPyro,积极使用Haiku。它扩展了Sonnet的API,我们在TensorFlow中基于模块的神经网络编程模型。
Optax
基于梯度的优化对机器学习很重要。Optax包括一个梯度转换库以及组合运算符(如链),允许在一行代码中开发许多常见的优化器(如RMSProp或Adam)。Optax的组合结构使其很容易在定制的优化器中重新组合相同的基本元素。它还包括用于随机梯度估计和二阶优化的实用程序。
RLax
RLax是一个库,为强化学习(RL)的发展提供了重要的构建模块,也被称为深度强化学习。RLax的组件包括TD-学习、策略梯度、行为准则、MAP、近似策略优化、非线性价值转换、通用价值函数和众多探索方法。
RLax并不意味着是一个开发和部署成熟的RL代理系统的框架。Acme是建立在RLax组件上的全功能代理架构的一个例子。
测试
测试对于软件的可靠性至关重要,研究代码也不例外。从研究试验中得出科学结论,就必须相信你的代码的准确性。Chex是一个测试工具的集合,库的编写者用它来确保通用构件的正确性和适应性,最终用户也用它来验证他们的实验方案。
Chex包括一些工具,如JAX感知单元测试、JAX数据类型属性的断言、模拟和伪造,以及多设备测试环境。
Jraph
Jraph是一个用于在JAX中处理图形神经网络GNN的小库。Jraph为图提供了一个标准化的数据结构,一套用于处理图的工具,以及一套可随时分叉和扩展的图神经网络模型。其他主要功能包括利用硬件加速器的GraphTuple批处理,通过填充和屏蔽对可变形状图的JIT编译支持,以及跨输入分区的损失指定。Jraph,像Optax和我们的其他库一样,对用户选择的神经网络库没有任何限制。
用JAX构建ML模型
在这篇文章中,在TensorFlow平台上构建一个生成对抗网模型,在Jax的Haiku中对MNIST数据集进行训练。
让我们首先安装Haiku和Optax
!pip install dm-haiku
! pip install optax
导入必要的库
import functools
from typing import Any, NamedTuple
import haiku as hk
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
读取数据集
mnist_dataset = tfds.load("mnist")
def make_dataset(batch_size, seed=1):
def _preprocess(sample):
image = tf.image.convert_image_dtype(sample["image"], tf.float32)
return 2.0 * image - 1.0
ds = mnist["train"]
ds = ds.map(map_func=_preprocess,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds = ds.cache()
ds = ds.shuffle(10 * batch_size, seed=seed).repeat().batch(batch_size)
return iter(tfds.as_numpy(ds))
创建生成器和判别器
该模型被用作生成器,以产生来自问题领域的新的合理实例,而该模型被用作判别器,以确定一个实例是真实的(来自领域)还是生成的。
class Generator(hk.Module):
def __init__(self, output_channels=(32, 1), name=None):
super().__init__(name=name)
self.output_channels = output_channels
def __call__(self, x):
x = hk.Linear(7 * 7 * 64)(x)
x = jnp.reshape(x, x.shape[:1] + (7, 7, 64))
for output_channels in self.output_channels:
x = jax.nn.relu(x)
x = hk.Conv2DTranspose(output_channels=output_channels,
kernel_shape=[5, 5],
stride=2,
padding="SAME")(x)
return jnp.tanh(x)
class Discriminator(hk.Module):
def __init__(self,
output_channels=(8, 16, 32, 64, 128),
strides=(2, 1, 2, 1, 2),
name=None):
super().__init__(name=name)
self.output_channels = output_channels
self.strides = strides
def __call__(self, x):
for output_channels, stride in zip(self.output_channels, self.strides):
x = hk.Conv2D(output_channels=output_channels,
kernel_shape=[5, 5],
stride=stride,
padding="SAME")(x)
x = jax.nn.leaky_relu(x, negative_slope=0.2)
x = hk.Flatten()(x)
logits = hk.Linear(2)(x)
return logits
创建GAN算法
import optax
class GAN_algo_basic:
def __init__(self, num_latents):
self.num_latents = num_latents
self.gen_transform = hk.without_apply_rng(
hk.transform(lambda *args: Generator()(*args)))
self.disc_transform = hk.without_apply_rng(
hk.transform(lambda *args: Discriminator()(*args)))
self.optimizers = GANTuple(gen=optax.adam(1e-4, b1=0.5, b2=0.9),
disc=optax.adam(1e-4, b1=0.5, b2=0.9))
@functools.partial(jax.jit, static_argnums=0)
def initial_state(self, rng, batch):
dummy_latents = jnp.zeros((batch.shape[0], self.num_latents))
rng_gen, rng_disc = jax.random.split(rng)
params = GANTuple(gen=self.gen_transform.init(rng_gen, dummy_latents),
disc=self.disc_transform.init(rng_disc, batch))
print("Generator: \n\n{}\n".format(tree_shape(params.gen)))
print("Discriminator: \n\n{}\n".format(tree_shape(params.disc)))
opt_state = GANTuple(gen=self.optimizers.gen.init(params.gen),
disc=self.optimizers.disc.init(params.disc))
return GANState(params=params, opt_state=opt_state)
def sample(self, rng, gen_params, num_samples):
"""Generates images from noise latents."""
latents = jax.random.normal(rng, shape=(num_samples, self.num_latents))
return self.gen_transform.apply(gen_params, latents)
def gen_loss(self, gen_params, rng, disc_params, batch):
fake_batch = self.sample(rng, gen_params, num_samples=batch.shape[0])
fake_logits = self.disc_transform.apply(disc_params, fake_batch)
fake_probs = jax.nn.softmax(fake_logits)[:, 1]
loss = -jnp.log(fake_probs)
return jnp.mean(loss)
def disc_loss(self, disc_params, rng, gen_params, batch):
fake_batch = self.sample(rng, gen_params, num_samples=batch.shape[0])
real_and_fake_batch = jnp.concatenate([batch, fake_batch], axis=0)
real_and_fake_logits = self.disc_transform.apply(disc_params,
real_and_fake_batch)
real_logits, fake_logits = jnp.split(real_and_fake_logits, 2, axis=0)
real_labels = jnp.ones((batch.shape[0],), dtype=jnp.int32)
real_loss = sparse_softmax_cross_entropy(real_logits, real_labels)
fake_labels = jnp.zeros((batch.shape[0],), dtype=jnp.int32)
fake_loss = sparse_softmax_cross_entropy(fake_logits, fake_labels)
return jnp.mean(real_loss + fake_loss)
@functools.partial(jax.jit, static_argnums=0)
def update(self, rng, gan_state, batch):
rng, rng_gen, rng_disc = jax.random.split(rng, 3)
disc_loss, disc_grads = jax.value_and_grad(self.disc_loss)(
gan_state.params.disc,
rng_disc,
gan_state.params.gen,
batch)
disc_update, disc_opt_state = self.optimizers.disc.update(
disc_grads, gan_state.opt_state.disc)
disc_params = optax.apply_updates(gan_state.params.disc, disc_update)
gen_loss, gen_grads = jax.value_and_grad(self.gen_loss)(
gan_state.params.gen,
rng_gen,
gan_state.params.disc,
batch)
gen_update, gen_opt_state = self.optimizers.gen.update(
gen_grads, gan_state.opt_state.gen)
gen_params = optax.apply_updates(gan_state.params.gen, gen_update)
params = GANTuple(gen=gen_params, disc=disc_params)
opt_state = GANTuple(gen=gen_opt_state, disc=disc_opt_state)
gan_state = GANState(params=params, opt_state=opt_state)
log = {
"gen_loss": gen_loss,
"disc_loss": disc_loss,
}
return rng, gan_state, log
训练模型
for step in range(num_steps):
rng, gan_state, log = model.update(rng, gan_state, next(dataset))
if step % log_every == 0:
log = jax.device_get(log)
gen_loss = log["gen_loss"]
disc_loss = log["disc_loss"]
print(f"Step {step}: "
f"gen_loss = {gen_loss:.3f}, disc_loss = {disc_loss:.3f}")
steps.append(step)
gen_losses.append(gen_loss)
disc_losses.append(disc_loss)
由于时间限制,该模型将被训练5000步。这取决于用户对步骤数量的选择。对于5000步,大约需要60分钟。

印度分析杂志
分析发生器和判别器的损失
fig, axes = plt.subplots(1, 2, figsize=(20, 6))
# Plot the discriminator loss.
axes[0].plot(steps, disc_losses, "-")
axes[0].set_title("Discriminator loss", fontsize=20)
# Plot the generator loss.
axes[1].plot(steps, gen_losses, '-')
axes[1].set_title("Generator loss", fontsize=20);

印度分析》杂志
我们可以观察到,在最初的2000步中,发生器的损失相当高,而在3000步之后,鉴别器和发生器的损失平均来说是不变的。
结论
Just After eXceution(JAX)是一种高性能的数值计算,特别是在机器学习研究中。它的数值API是基于NumPy,一个用于科学计算的函数库。通过这篇文章,我们已经了解了JAX的生态系统以及作为该生态系统一部分的Optax和Haiku的实现。