如何在编译的jax代码中执行非整数索引算法?

2024-05-14 18:39:13 发布

您现在位置:Python中文网/ 问答频道 /正文

如果我们对数组索引执行非整数计算(然后转换为int()),那么在jit编译的jax代码中似乎仍然无法将结果用作有效索引。我们如何解决这个问题

下面是一个简单的例子。具体问题:命令jnp.diag_index(d)是否可以在不向fun()传递额外参数的情况下工作

在Jupiter单元格中运行此操作:

import jax.numpy as jnp
from jax import jit

@jit
def fun(t):
    d = jnp.sqrt(t.size**2)
    d = jnp.array(d,int)
    
    jnp.diag_indices(t.size)   # this line works
    jnp.diag_indices(d)        # this line breaks. Comment it out to see that d and t.size have the same dtype=int32 

    return t.size, d
    
fun(jnp.array([1,2]))    

Tags: 代码importsizeline整数数组thisarray
1条回答
网友
1楼 · 发布于 2024-05-14 18:39:13

问题不是d的类型,而是d是jax操作的结果,因此在JIT上下文中被跟踪。在JAX中,数组的形状和大小不能依赖于跟踪的数量,这就是代码导致错误的原因

为了解决这个问题,一个有用的模式是使用np操作而不是jnp操作,以确保d是静态的且不被跟踪:

import jax.numpy as jnp
from jax import jit

@jit
def fun(t):
    d = np.sqrt(t.size**2)
    d = np.array(d, int)
    
    jnp.diag_indices(t.size)
    jnp.diag_indices(d)

    return t.size, d
    
print(fun(jnp.array([1,2])))
# (DeviceArray(2, dtype=int32), DeviceArray(2, dtype=int32))

有关跟踪、静态值和类似主题的简要背景信息,How To Think In JAX文档页面可能会有所帮助

相关问题 更多 >

    热门问题