无法从'jax'导入名称'linear_util
我正在尝试复现S5模型的实验,https://github.com/lindermanlab/S5,但在解决环境问题时遇到了一些麻烦。当我运行这个脚本./run_lra_cifar.sh
时,出现了以下错误
Traceback (most recent call last):
File "/Path/S5/run_train.py", line 3, in <module>
from s5.train import train
File "/Path/S5/s5/train.py", line 7, in <module>
from .train_helpers import create_train_state, reduce_lr_on_plateau,\
File "/Path/train_helpers.py", line 6, in <module>
from flax.training import train_state
File "/Path/miniconda3/lib/python3.12/site-packages/flax/__init__.py", line 19, in <module>
from . import core
File "/Path/miniconda3/lib/python3.12/site-packages/flax/core/__init__.py", line 15, in <module>
from .axes_scan import broadcast
File "/Path/miniconda3/lib/python3.12/site-packages/flax/core/axes_scan.py", line 22, in <module>
from jax import linear_util as lu
ImportError: cannot import name 'linear_util' from 'jax' (/Path/miniconda3/lib/python3.12/site-packages/jax/__init__.py)
我是在一台RTX4090的电脑上运行,CUDA版本是11.8。我的jax版本是0.4.25,jaxlib版本是0.4.25+cuda11.cudnn86
我最开始尝试按照作者的方式安装依赖
pip install -r requirements_gpu.txt
但是在我这儿似乎不太管用,因为我连import jax
都不能执行。所以我按照https://jax.readthedocs.io/en/latest/installation.html上的说明安装了jax,输入了
pip install --upgrade pip
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
到目前为止,我尝试过:
- 使用旧一点的显卡(3060和2070)
- 把python降级到3.9
有没有人知道可能出什么问题了?任何帮助都非常感谢
1 个回答
0
jax.linear_util
在 JAX 版本 0.4.16 中被标记为不再推荐使用,之后在版本 0.4.24 中被彻底移除了。
看起来 flax
是引入 linear_util
的来源,这意味着你正在使用一个旧版的 flax,而它和你当前使用的新版本 jax 不兼容。
要解决这个问题,你需要么安装一个旧版本的 JAX,这个版本里还有 jax.linear_util
,要么更新到一个与新版本 JAX 兼容的新版 flax
。