在GPU上安装Jax

-3 投票
0 回答
31 浏览
提问于 2025-04-12 00:59

我正在尝试复现s5模型的结果,链接在这里:https://github.com/lindermanlab/S5。但是在安装jax的时候遇到了一个大问题。作者提供了一个requirements_gpu.txt文件:

flax
torch
torchtext
tensorflow-datasets==4.5.2
pydub==0.25.1
datasets
tqdm
--find-links https://storage.googleapis.com/jax-releases/jax_releases.html
jax[cuda]>=version

这会导致

from jax import linear_util as lu
ImportError: cannot import name 'linear_util' from 'jax'

我也尝试过按照官方指南页面直接安装jax,链接在这里:https://jax.readthedocs.io/en/latest/installation.html

但是还是出现了很多错误。我尝试过降级和升级我的Python版本,还用过不同的显卡(4090搭配cuda12.3,3060搭配cuda11.4,以及2060搭配cuda11.1),但都没有成功 :-(

这是作者在代码中导入jax的地方:

from jax import random
import jax.numpy as np
from jax.scipy.linalg import block_diag
import jax
from jax.nn import one_hot
from flax.training import train_state
from flax import linen as nn
from jax.nn.initializers import lecun_normal, normal
from jax import random
from jax.numpy.linalg import eigh
from flax import linen as nn

我当前的环境有:

absl-py==1.4.0
aiohttp==3.9.3
aiosignal==1.3.1
async-timeout==4.0.3
attrs==23.2.0
certifi==2024.2.2
charset-normalizer==3.3.2
chex==0.1.6
colorama==0.4.6
commonmark==0.9.1
contourpy==1.2.0
cycler==0.12.1
datasets==2.18.0
dill==0.3.8
dm-tree==0.1.8
etils==1.7.0
filelock==3.13.1
flax==0.4.0
fonttools==4.50.0
frozenlist==1.4.1
fsspec==2024.2.0
googleapis-common-protos==1.63.0
grpcio==1.62.1
huggingface-hub==0.21.4
idna==3.6
importlib_resources==6.4.0
jax==0.2.22
jaxlib==0.4.13
Jinja2==3.1.3
joblib==1.3.2
kiwisolver==1.4.5
Markdown==3.6
MarkupSafe==2.1.5
matplotlib==3.8.3
ml-dtypes==0.3.2
mpmath==1.3.0
msgpack==1.0.8
multidict==6.0.5
multiprocess==0.70.16
nest-asyncio==1.6.0
networkx==3.2.1
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.19.3
nvidia-nvjitlink-cu12==12.4.99
nvidia-nvtx-cu12==12.1.105
opt-einsum==3.3.0
optax==0.1.7
orbax-checkpoint==0.5.7
packaging==24.0
pandas==2.2.1
pillow==10.2.0
promise==2.3
protobuf==3.20.3
pyarrow==15.0.2
pyarrow-hotfix==0.6
pydub==0.25.1
Pygments==2.17.2
pyparsing==3.1.2
python-dateutil==2.9.0.post0
pytz==2024.1
PyYAML==6.0.1
requests==2.31.0
rich==11.2.0
scikit-learn==1.4.1.post1
scipy==1.12.0
six==1.16.0
sympy==1.12
tensorboard==2.16.2
tensorboard-data-server==0.7.2
tensorflow-datasets==4.5.2
tensorflow-metadata==1.14.0
tensorstore==0.1.56
termcolor==2.4.0
threadpoolctl==3.4.0
toolz==0.12.1
torch==2.2.1
torchdata==0.7.1
torchtext==0.17.1
torchvision==0.17.1
tqdm==4.66.2
triton==2.2.0
typing_extensions==4.10.0
tzdata==2024.1
urllib3==2.2.1
Werkzeug==3.0.1
xxhash==3.4.1
yarl==1.9.4
zipp==3.18.1

正如我所说,我们的集群里有三种类型的显卡:4090搭配cuda12.3,3060搭配cuda 11.4,以及2060搭配cuda 11.1。我想知道有没有人能告诉我,jax、jaxlib和flax的哪个版本可以在这三台机器中的至少一台上运行。非常感谢你的帮助。如果你需要更多信息,请告诉我。

0 个回答

暂无回答

撰写回答