如何将jax-vmap用于嵌套循环?

2024-06-16 10:49:18 发布

您现在位置:Python中文网/ 问答频道 /正文

我想使用vmap将此代码矢量化以提高性能

def matrix(dataA, dataB):
    return jnp.array([[func(a, b) for b in dataB] for a in dataA])
matrix(data, data)

我试过这个:

def f(x, y):
    return func(x, y)
mapped = jax.vmap(f)
mapped(data, data)

但这只给出对角线条目

基本上我有一个向量data = [1,2,3,4,5](示例),我想得到一个矩阵,使得矩阵的每个条目(i, j)都是f(data[i], data[j])。因此,得到的矩阵形状将是(len(data), len(data))


Tags: 代码infordatalenreturndef条目
1条回答
网友
1楼 · 发布于 2024-06-16 10:49:18

jax.vmap一次映射一组轴。如果要跨两组独立的轴进行映射,可以通过嵌套两个vmap变换来实现:

mapped = jax.vmap(jax.vmap(f, in_axes=(None, 0)), in_axes=(0, None))
result = mapped(data, data)

相关问题 更多 >