用于简单阵列更新的Jax-vmap

2024-06-16 10:47:47 发布

您现在位置:Python中文网/ 问答频道 /正文

我是Jax新手,我正在努力转换其他人的代码,这些代码使用了numba“fastmath”特性,并且依赖于许多嵌套for循环,没有太多性能损失。我试图使用Jax的vmap函数重新创建相同的行为。然而,我目前在一些基本问题上苦苦挣扎。下面是一个简化的示例,说明了我正在尝试使用vmap对哪些内容进行矢量化:

import jax.numpy as jnp
from jax import vmap
import jax.ops

a = jnp.arange(20).reshape((4, 5))
b = jnp.arange(5)
c = jnp.arange(4)
d = jnp.zeros(20)
e = jnp.zeros((4, 5))

for i in range(a.shape[0]):
    for j in range(a.shape[1]):
        a = jax.ops.index_add(a, jax.ops.index[i, j], b[j] + c[i])
        d = jax.ops.index_update(d, jax.ops.index[i*a.shape[1] + j], b[j] * c[i])
        e = jax.ops.index_update(e, jax.ops.index[i, j], 2*b[j])

我如何使用vmap重写这样的代码?虽然这段代码相对容易手动矢量化,但我希望更好地理解vmap的工作原理,并希望任何答案都能帮助我。医生们现在似乎对我没什么帮助。我真的很感激你能提供的任何帮助


Tags: 代码inimportforindexzerosrange矢量化
1条回答
网友
1楼 · 发布于 2024-06-16 10:47:47

下面是如何使用vmap实现大致相同的计算:

from jax import vmap, partial

@partial(vmap, in_axes=(0, None, 0))
@partial(vmap, in_axes=(0, 0, None))
def f(a, b, c):
  return a + b + c, b * c, 2 * b

a, d, e = f(a, b, c)
d = d.ravel()

相关问题 更多 >