
# JAX Quickstart

```python
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
```

## Multiplying Matrices

```python
key = random.PRNGKey(0)
x = random.normal(key, (10,))
print(x)
```

```
[-0.372111    0.2642311  -0.18252774 -0.7368198  -0.44030386 -0.15214427
 -0.6713536  -0.59086424  0.73168874  0.56730247]
```

```python
size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready()  # runs on the GPU
```

```
489 ms ± 3.98 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```

JAX NumPy functions work on regular NumPy arrays.

```python
import numpy as np
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit jnp.dot(x, x.T).block_until_ready()
```

```
488 ms ± 942 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
```

```python
from jax import device_put

x = np.random.normal(size=(size, size)).astype(np.float32)
x = device_put(x)
%timeit jnp.dot(x, x.T).block_until_ready()
```

```
487 ms ± 9.94 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```

```python
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit np.dot(x, x.T)
```

```
235 ms ± 546 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
```

JAX is much more than just a GPU-backed NumPy. It also comes with a few program transformations that are useful when writing numerical code. For now, there are three main ones: `jit()`, `grad()`, and `vmap()`.
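These three transformations compose freely. As a minimal sketch (the function `f` below is an illustrative example, not from this guide):

```python
import jax.numpy as jnp
from jax import grad, jit, vmap

# grad differentiates a scalar function, vmap maps it over a batch,
# and jit compiles the whole pipeline with XLA.
f = lambda x: x ** 2          # d/dx x**2 = 2*x
batched_grad = jit(vmap(grad(f)))

xs = jnp.arange(4.0)
print(batched_grad(xs))       # the gradient 2*x at each element of xs
```

Each transformation takes a function and returns a function, so they can be nested in any order that makes sense for the computation.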

## Using `jit()` to speed up functions

JAX runs transparently on the GPU (or on the CPU if you don't have one; TPU support is coming!). However, in the example above, JAX dispatched kernels to the GPU one operation at a time. If we have a sequence of operations, we can use the `@jit` decorator to compile multiple operations together with XLA. Let's try that.

```python
def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()
```

```
4.4 ms ± 107 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```

```python
selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()
```

```
860 µs ± 27.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```

## Taking derivatives with `grad()`

```python
def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))
```

```
[0.25       0.19661197 0.10499357]
```

```python
def first_finite_differences(f, x):
    eps = 1e-3
    return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                      for v in jnp.eye(len(x))])

print(first_finite_differences(sum_logistic, x_small))
```

```
[0.24998187 0.1964569  0.10502338]
```

```python
print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))
```

```
-0.035325594
```

```python
from jax import jacfwd, jacrev
def hessian(fun):
    return jit(jacfwd(jacrev(fun)))
```
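As a quick sanity check of the `hessian` helper (a sketch reusing `sum_logistic` and `x_small` from this section): since `sum_logistic` applies the sigmoid elementwise and then sums, its Hessian should be diagonal.

```python
import jax.numpy as jnp
from jax import jit, jacfwd, jacrev

def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

def hessian(fun):
    # forward-over-reverse: jacrev builds the gradient,
    # jacfwd differentiates it again column by column
    return jit(jacfwd(jacrev(fun)))

x_small = jnp.arange(3.)
H = hessian(sum_logistic)(x_small)
print(H.shape)  # (3, 3); off-diagonal entries are zero because the sum decouples
```

Forward-over-reverse is generally an efficient way to form Hessians: reverse mode handles the many-inputs-to-one-output gradient, and forward mode handles the gradient's Jacobian.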

## Auto-vectorization with `vmap()`

JAX has one more transformation in its API that you might find useful: `vmap()`, the vectorizing map. It has the familiar semantics of mapping a function along array axes, but instead of keeping the loop on the outside, it pushes the loop down into the function's primitive operations for better performance. When composed with `jit()`, it can be just as fast as adding the batch dimensions by hand.

```python
mat = random.normal(key, (150, 100))
batched_x = random.normal(key, (10, 100))

def apply_matrix(v):
    return jnp.dot(mat, v)
```

```python
def naively_batched_apply_matrix(v_batched):
    return jnp.stack([apply_matrix(v) for v in v_batched])

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()
```

```
Naively batched
```

```
4.43 ms ± 9.91 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```

```python
@jit
def batched_apply_matrix(v_batched):
    return jnp.dot(v_batched, mat.T)

print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()
```

```
Manually batched
```

```
51.9 µs ± 1.72 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```

```python
@jit
def vmap_batched_apply_matrix(v_batched):
    return vmap(apply_matrix)(v_batched)

print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()
```

```
Auto-vectorized with vmap
```

```
79.7 µs ± 249 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```
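Though their speeds differ, all three batched versions compute the same values. A sketch verifying that the hand-batched and `vmap`-batched results agree (reusing the definitions from this section):

```python
import jax.numpy as jnp
from jax import random, vmap, jit

key = random.PRNGKey(0)
mat = random.normal(key, (150, 100))
batched_x = random.normal(key, (10, 100))

def apply_matrix(v):
    # acts on a single vector of shape (100,)
    return jnp.dot(mat, v)

manual = jnp.dot(batched_x, mat.T)           # hand-batched: (10, 150)
auto = jit(vmap(apply_matrix))(batched_x)    # vmap maps over axis 0 by default

print(jnp.allclose(manual, auto, atol=1e-4))
```

This is the usual workflow: write the per-example function, then let `vmap` handle the batch dimension instead of rewriting the math by hand.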