jax:如何解决错误:pmap请求沿轴0映射其参数,但其秩应至少为1,但仅为0?

1 投票
1 回答
33 浏览
提问于 2025-04-12 05:30

我正在尝试运行这个关于基于分数的生成模型的简单介绍。这个代码使用了flax.optim,但似乎已经转移到了optax(你可以查看这个链接了解更多:https://flax.readthedocs.io/en/latest/guides/converting_and_upgrading/optax_update_guide.html)。

我做了一个代码的副本,并进行了我认为需要的修改(我不太确定如何替换optimizer = flax.jax_utils.replicate(optimizer)这行)。

现在,在训练部分,我遇到了一个错误:

pmap被请求沿着轴0进行映射,这意味着它的秩(rank)应该至少为1,但实际上只有0(它的形状是())

这个错误出现在loss, params, opt_state = train_step_fn(step_rng, x, params, opt_state)这一行。显然,这个问题来自于“定义损失函数”部分的return jax.pmap(step_fn, axis_name='device')

我该如何修复这个错误呢?我在网上搜索过,但不知道问题出在哪里。

1 个回答

1

这个问题发生是因为你把一个单一的值(标量)传给了一个需要处理多个值的函数,也就是pmapped函数。举个例子:

import jax
func = lambda x: x ** 2
pfunc = jax.pmap(func)

pfunc(1.0)
# ValueError: pmap was requested to map its argument along axis 0, which implies
# that its rank should be at least 1, but is only 0 (its shape is ())

如果你想对一个单一的值进行操作,应该直接使用这个函数,而不是用pmap包裹起来:

func(1.0)
# 1.0

另外,如果你想使用pmap,那么你需要传入一个数组,这个数组的第一个维度要和设备的数量相匹配:

num_devices = len(jax.devices())
x = jax.numpy.arange(num_devices)
pfunc(x)
# Array([ 0,  1,  4,  9, 16, 25, 36, 49], dtype=int32)

撰写回答