无法从'jax'导入名称'linear_util

1 投票
1 回答
156 浏览
提问于 2025-04-12 23:52

我正在尝试复现S5模型的实验,https://github.com/lindermanlab/S5,但在解决环境问题时遇到了一些麻烦。当我运行这个脚本./run_lra_cifar.sh时,出现了以下错误

Traceback (most recent call last):
  File "/Path/S5/run_train.py", line 3, in <module>
    from s5.train import train
  File "/Path/S5/s5/train.py", line 7, in <module>
    from .train_helpers import create_train_state, reduce_lr_on_plateau,\
  File "/Path/train_helpers.py", line 6, in <module>
    from flax.training import train_state
  File "/Path/miniconda3/lib/python3.12/site-packages/flax/__init__.py", line 19, in <module>
    from . import core
  File "/Path/miniconda3/lib/python3.12/site-packages/flax/core/__init__.py", line 15, in <module>
    from .axes_scan import broadcast
  File "/Path/miniconda3/lib/python3.12/site-packages/flax/core/axes_scan.py", line 22, in <module>
    from jax import linear_util as lu
ImportError: cannot import name 'linear_util' from 'jax' (/Path/miniconda3/lib/python3.12/site-packages/jax/__init__.py)

我是在一台RTX4090的电脑上运行,CUDA版本是11.8。我的jax版本是0.4.25,jaxlib版本是0.4.25+cuda11.cudnn86

我最开始尝试按照作者的方式安装依赖

pip install -r requirements_gpu.txt

但是在我这儿似乎不太管用,因为我连import jax都不能执行。所以我按照https://jax.readthedocs.io/en/latest/installation.html上的说明安装了jax,输入了

pip install --upgrade pip
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

到目前为止,我尝试过:

  1. 使用旧一点的显卡(3060和2070)
  2. 把python降级到3.9

有没有人知道可能出什么问题了?任何帮助都非常感谢

1 个回答

0

jax.linear_util 在 JAX 版本 0.4.16 中被标记为不再推荐使用,之后在版本 0.4.24 中被彻底移除了。

看起来 flax 是引入 linear_util 的来源,这意味着你正在使用一个旧版的 flax,而它和你当前使用的新版本 jax 不兼容。

要解决这个问题,你需要么安装一个旧版本的 JAX,这个版本里还有 jax.linear_util,要么更新到一个与新版本 JAX 兼容的新版 flax

撰写回答