不同长度的JAX批处理

2024-06-16 13:08:59 发布

您现在位置:Python中文网/ 问答频道 /正文

我有一个函数compute(x),其中x是一个jnp.ndarray。现在,我想使用vmap将它转换为一个函数,该函数需要一批数组x[i],然后jit来加速它compute(x)类似于:

def compute(x):
    # ... some code
    y = very_expensive_function(x)
    return y

但是,每个数组x[i]具有不同的长度。我可以很容易地解决这个问题,方法是用尾随零填充数组,使它们都具有相同的长度N,并且vmap(compute)可以应用于具有形状(batch_size, N)的批

但是,这样做会导致对每个数组x[i]的尾随零也调用very_expensive_function()。有没有一种方法可以修改compute(),使得very_expensive_function()只在x的片上调用,而不干扰vmapjit


Tags: 方法函数returndefcodefunctionsome数组
1条回答
网友
1楼 · 发布于 2024-06-16 13:08:59

使用JAX,当您希望jit函数以加快速度时,给定的批处理参数x必须是定义良好的ndarray(即x[i]必须具有相同的形状)。无论您是否使用vmap,这都是正确的

现在,通常的处理方法是填充这些数组。这意味着您在参数中添加了一个掩码,这样填充的值就不会影响结果。例如,如果我想计算形状(bath_size, max_length)的填充值softmax{},我需要“禁用”填充值的效果。以下是一个例子:

import jax.numpy as jnp
import jax

PAD = 0
MINUS_INFINITY = -1e6

x = jnp.array([ 
       [1, 2, 3, 4],
       [1, 2, PAD, PAD],
       [1, 2, 3, PAD]
    ])

mask = jnp.array([
           [1, 1, 1, 1],
           [1, 1, 0, 0],
           [1, 1, 1, 0]
       ])
       
masked_sofmax = jax.nn.softmax(x + (1-mask)*MINUS_INFINITY)    

它不像填充x那么简单。您需要在每一步实际更改计算以禁用填充效果。对于softmax,可以通过将填充值设置为接近负无穷大来实现这一点

最后,您无法事先真正知道,使用填充+掩蔽还是不使用填充+掩蔽,速度性能是否会更好。根据我的经验,这通常会导致CPU的良好改进,以及GPU的巨大改进。特别是,批次大小的选择对性能有很大的影响,因为较高的batch_size在统计上会导致较高的max_length,因此在填充值上执行的“无用”计算的数量较高

相关问题 更多 >