我有一个函数compute(x)
,其中x
是一个jnp.ndarray
。现在,我想使用vmap
将它转换为一个函数,该函数需要一批数组x[i]
,然后jit
来加速它compute(x)
类似于:
def compute(x):
# ... some code
y = very_expensive_function(x)
return y
但是,每个数组x[i]
具有不同的长度。我可以很容易地解决这个问题,方法是用尾随零填充数组,使它们都具有相同的长度N
,并且vmap(compute)
可以应用于具有形状(batch_size, N)
的批
但是,这样做会导致对每个数组x[i]
的尾随零也调用very_expensive_function()
。有没有一种方法可以修改compute()
,使得very_expensive_function()
只在x
的片上调用,而不干扰vmap
和jit
使用JAX,当您希望jit函数以加快速度时,给定的批处理参数
x
必须是定义良好的ndarray(即x[i]必须具有相同的形状)。无论您是否使用vmap
,这都是正确的现在,通常的处理方法是填充这些数组。这意味着您在参数中添加了一个掩码,这样填充的值就不会影响结果。例如,如果我想计算形状},我需要“禁用”填充值的效果。以下是一个例子:
(bath_size, max_length)
的填充值softmax
{它不像填充
x
那么简单。您需要在每一步实际更改计算以禁用填充效果。对于softmax,可以通过将填充值设置为接近负无穷大来实现这一点最后,您无法事先真正知道,使用填充+掩蔽还是不使用填充+掩蔽,速度性能是否会更好。根据我的经验,这通常会导致CPU的良好改进,以及GPU的巨大改进。特别是,批次大小的选择对性能有很大的影响,因为较高的
batch_size
在统计上会导致较高的max_length
,因此在填充值上执行的“无用”计算的数量较高相关问题 更多 >
编程相关推荐