我正在使用vmap对部分代码进行矢量化。下面是一个最小的例子,在矢量化之前:
dim = 2
def sum(x):
a = np.ones((dim,))
return np.dot(x, a)
num_samples = 100
samples = np.ones((num_samples, dim))
sum(samples[0]) # 2
使用vmap:
sum = vmap(sum)
sum(samples) # DeviceArray of shape (100,), all entries are 2
但在矢量化之后,这可能会出错:
sum(samples[0]) # DeviceArray of shape (2,2), all entries are 1
这里发生的是samples[0]
的形状是(2,)
。向量化的函数调用将其输入参数沿第一个轴拆分,因此被馈送2个shape(1,)
数组。由于使用a
广播,结果输出再次具有(2,)
形状,并且被堆叠到(2,2)
阵列。你知道吗
这对我来说似乎很危险。代码看起来很正常,结果输出很容易被其他一些广播规则所占用,这些规则隐藏了代码的损坏形状。你知道吗
有没有可能强制执行正确的形状?你知道吗
目前没有回答
相关问题 更多 >
编程相关推荐