Windows系统下安装JAX的GPU版本

1,602 阅读1分钟

第一步:确认自己的cuda版本

打开NAIDIA控制面板,点击系统信息,点击组件 image.png

image.png

第二步:安装jaxlib

中括号里面的版本号按照你的cuda版本来写,例如我的版本为11.2,这里就是cuda112 在命令行中执行命令

pip install jaxlib[cuda112] -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver

第三步:安装jax

在命令行中执行命令

pip install jax[cuda112] -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver

[参考链接](GitHub - cloudhan/jax-windows-builder: A community supported Windows build for jax.)

[参考链接](GitHub - google/jax: Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more)