携手创作,共同成长!这是我参与「掘金日新计划 · 8 月更文挑战」的第13天,点击查看活动详情
前言
为什么要学习 JAX,既然有了 PyTorch、Tensorflow 这样深度学习框架,我们为什么还要学习 JAX 框架呢。
- 首先 JAX 的新鲜事物,个人喜欢尝鲜,显得很 cool
- 想要了解更底层东西,JAX 可以帮助我们进一步了解底层技术
- 自己一来有一个梦想写一个简单深度学习框架,基于 JAX 可以让自己省一些力气
首先聊一聊 JAX,在学习机器学习,numpy 我想一定是大家不可缺少的库,Numpy 的 API 设计让我们体验到用 python 来操作向量和矩阵的轻松,这些良好体验让 python 在大数据或数字科学上站稳了脚跟。不过因为 Numpy 没有对 GPU 以及自动求导方面良好的支持,所以在深度学习领域应用也就是打杂的位置,其实主流框架已经兼容取代了 Numpy 的功能,只不过大家还是对 Numpy 比较熟悉,对于处理一些问题,第一个想到还是 Numpy,所以其还留有一席之地。
估计 JAX API 设计者,想要熟悉 Numpy 的开发者无缝迁移到 JAX 使用,下面例子很好解释了这一点
import jax.numpy as jnp
import numpy as np
在 JAX 提供一个叫 numpy 模块,其模块下 API 基本与 Numpy 一致,也提供 Scipy 不过自己没有用过 Scipy 模块的 API
x_np = np.linspace(0,10,1000)
y_np = 2 * np.sin(x_np) * np.cos(x_np)
plt.plot(x_np,y_np)
x_jnp = jnp.linspace(0,10,1000)
y_jnp = 2 * jnp.sin(x_np) * jnp.cos(x_np)
plt.plot(x_jnp,y_jnp)
函数式编程
在 JAX 支持函数式编程,例如其变量是 immutable 请看下面例子
x = np.arange(size)
print(x)
x[index] = value
print(x)
[0 1 2 3 4 5 6 7 8 9]
[12 1 2 3 4 5 6 7 8 9]
x = jnp.arange(size)
print(x)
x[index] = value
print(x)
当试图直接修改定义好数组 x 会抛出下面错误,提示不支持对数组再次赋值,因为 JAX 是 immutable(不可变数据类型)
object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[]
解决方案
y = x.at[index].set(value)
print(y)
其实并没有更新原有数组 x 而是重新创建一份在此基础上进行修改得到一个新的数组
随机数生成方式
from jax import random, device_put
jax.random包提供了一些用于确定性地生成伪随机数一般流程。
seed = 0
key = random.PRNGKey(seed)
x = random.normal(key,(10,))
x
与 NumPy 和 SciPy 是有状态的伪随机数生成器(PRNG)不同,JAX 随机函数都需要将一个显示地传入一个参数 的PRNG 状态作为生成随机数的第一个参数。随机状态由两个无符号的 32 位整数描述,称之为密钥,可以使用 jax.random.PRNGKey()函数生成密钥
from jax._src.api import block_until_ready
size = 3000
x_jnp = random.normal(key,(size,size),dtype=jnp.float32)
x_np = np.random.normal((size,size)).astype(np.float32)
%timeit jnp.dot(x_jnp,x_jnp.T).block_until_ready()
%timeit np.dot(x_np,x_np.T)
%timeit jnp.dot(x_np,x_np.T).block_until_ready()
x_np_device = device_put(x_np)
%timeit jnp.dot(x_np_device,x_np_device)
在 JAX 默认就是在 GPU 对于 AI 模型进行加速的,而对于 Numpy 是不支持在 GPU 上进行矩阵运算的。
在矩阵运算后加上 block_until_ready 好处是因为 JAX 在 GPU 上运算是异步的,所以如果没有添加 block_util_ready 就不会等待上面运算结束再向下进行,而添加了 block_until_ready 就会等待计算结束再向下进行
def visualize_fn(fn,l=-10,r=10,n=1000):
x = np.linspace(l,r,num=n)
y = fn(x)
plt.plot(x,y)
plt.show()
selu_jit = jit(selu)
visualize_fn(selu)
data = random.normal(key,(1000000,))
print("non-jit version:")
%timeit selu(data).block_until_ready()
print("jit version")
%timeit selu_jit(data).block_until_ready()
def fn_1(x):
return x**2
grad(fn_1)(2.0)
grad(grad(fn_1))(2.0)
#DeviceArray(2., dtype=float32, weak_type=True)
grad(grad(grad(fn_1)))(2.0)
DeviceArray(4., dtype=float32, weak_type=True)
def sum_logistic(x):
return jnp.sum(1.0/(1.0 + jnp.exp(-x)))
x = jnp.arange(3.)
x
loss = sum_logistic
grad_loss = grad(loss)
print(grad_loss(x))