多个布尔条件下 `jax.lax.cond` 的等价实现

1 投票
1 回答
46 浏览
提问于 2025-04-14 17:38

目前,jax.lax.cond 只能处理一个布尔条件。有没有办法让它支持多个布尔条件呢?

举个例子,下面是一个无法追踪的函数:

def func(x):
    if x < 0: return x
    elif (x >= 0) & (x < 1): return 2*x
    else: return 3*x

怎么才能用 JAX 写出一个可以追踪的函数呢?

1 个回答

1

一种简洁的写法是使用 jnp.select

import jax
import jax.numpy as jnp

@jax.jit
def func(x):
  return jnp.select([x < 0, x < 1], [x, 2 * x], default=3 * x)

x = jnp.array([-0.5, 0.5, 1.5])
print(func(x))
# [-0.5  1.   4.5]

撰写回答