如果我们对数组索引执行非整数计算(然后转换为int()),那么在jit编译的jax代码中似乎仍然无法将结果用作有效索引。我们如何解决这个问题
下面是一个简单的例子。具体问题:命令jnp.diag_index(d)是否可以在不向fun()传递额外参数的情况下工作
在Jupiter单元格中运行此操作:
import jax.numpy as jnp
from jax import jit
@jit
def fun(t):
d = jnp.sqrt(t.size**2)
d = jnp.array(d,int)
jnp.diag_indices(t.size) # this line works
jnp.diag_indices(d) # this line breaks. Comment it out to see that d and t.size have the same dtype=int32
return t.size, d
fun(jnp.array([1,2]))
问题不是
d
的类型,而是d
是jax操作的结果,因此在JIT上下文中被跟踪。在JAX中,数组的形状和大小不能依赖于跟踪的数量,这就是代码导致错误的原因为了解决这个问题,一个有用的模式是使用
np
操作而不是jnp
操作,以确保d
是静态的且不被跟踪:有关跟踪、静态值和类似主题的简要背景信息,How To Think In JAX文档页面可能会有所帮助
相关问题 更多 >
编程相关推荐