jax:如何解决错误:pmap请求沿轴0映射其参数,但其秩应至少为1,但仅为0?
我正在尝试运行这个关于基于分数的生成模型的简单介绍。这个代码使用了flax.optim
,但似乎已经转移到了optax
(你可以查看这个链接了解更多:https://flax.readthedocs.io/en/latest/guides/converting_and_upgrading/optax_update_guide.html)。
我做了一个代码的副本,并进行了我认为需要的修改(我不太确定如何替换optimizer = flax.jax_utils.replicate(optimizer)
这行)。
现在,在训练部分,我遇到了一个错误:
pmap被请求沿着轴0进行映射,这意味着它的秩(rank)应该至少为1,但实际上只有0(它的形状是())
这个错误出现在loss, params, opt_state = train_step_fn(step_rng, x, params, opt_state)
这一行。显然,这个问题来自于“定义损失函数”部分的return jax.pmap(step_fn, axis_name='device')
。
我该如何修复这个错误呢?我在网上搜索过,但不知道问题出在哪里。
1 个回答
1
这个问题发生是因为你把一个单一的值(标量)传给了一个需要处理多个值的函数,也就是pmapped函数。举个例子:
import jax
func = lambda x: x ** 2
pfunc = jax.pmap(func)
pfunc(1.0)
# ValueError: pmap was requested to map its argument along axis 0, which implies
# that its rank should be at least 1, but is only 0 (its shape is ())
如果你想对一个单一的值进行操作,应该直接使用这个函数,而不是用pmap
包裹起来:
func(1.0)
# 1.0
另外,如果你想使用pmap
,那么你需要传入一个数组,这个数组的第一个维度要和设备的数量相匹配:
num_devices = len(jax.devices())
x = jax.numpy.arange(num_devices)
pfunc(x)
# Array([ 0, 1, 4, 9, 16, 25, 36, 49], dtype=int32)