jax-vmap:强制正确的形状

2024-06-16 09:27:13 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在使用vmap对部分代码进行矢量化。下面是一个最小的例子,在矢量化之前:

dim = 2
def sum(x):
    a = np.ones((dim,))
    return np.dot(x, a)

num_samples = 100
samples = np.ones((num_samples, dim))

sum(samples[0]) # 2

使用vmap:

sum = vmap(sum)
sum(samples) # DeviceArray of shape (100,), all entries are 2

但在矢量化之后,这可能会出错:

sum(samples[0]) # DeviceArray of shape (2,2), all entries are 1

这里发生的是samples[0]的形状是(2,)。向量化的函数调用将其输入参数沿第一个轴拆分,因此被馈送2个shape(1,)数组。由于使用a广播,结果输出再次具有(2,)形状,并且被堆叠到(2,2)阵列。你知道吗

这对我来说似乎很危险。代码看起来很正常,结果输出很容易被其他一些广播规则所占用,这些规则隐藏了代码的损坏形状。你知道吗

有没有可能强制执行正确的形状?你知道吗


Tags: of代码nponesall矢量化numentries