第一步:确认自己的cuda版本
打开NAIDIA控制面板,点击系统信息,点击组件
第二步:安装jaxlib
中括号里面的版本号按照你的cuda版本来写,例如我的版本为11.2,这里就是cuda112 在命令行中执行命令
pip install jaxlib[cuda112] -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
第三步:安装jax
在命令行中执行命令
pip install jax[cuda112] -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
[参考链接](GitHub - cloudhan/jax-windows-builder: A community supported Windows build for jax.)