测试数组是否可以广播到某个形状?

6 投票
7 回答
3751 浏览
提问于 2025-04-18 13:17

怎样才能测试一个数组是否可以调整成特定的形状呢?

用“python风格”的方法去尝试一下并不适合我的情况,因为我希望在执行操作时能够延迟计算。

我想知道如何实现下面的 is_broadcastable 函数:

>>> x = np.ones([2,2,2])
>>> y = np.ones([2,2])
>>> is_broadcastable(x,y)
True
>>> y = np.ones([2,3])
>>> is_broadcastable(x,y)
False

或者更好的是:

>>> is_broadcastable(x.shape, y.shape)

7 个回答

2

numpy.broadcast_shapes 从 numpy 1.20 版本开始可以使用了,所以你可以很简单地这样来实现:

import numpy as np

def is_broadcastable(shp1, shp2):
    try:
        np.broadcast_shapes(shp1, shp2)
        return True
    except ValueError:
        return False

在背后,它使用的是零长度的列表 numpy 数组来调用 broadcast_arrays,具体做法是:

np.empty(shp, dtype=[])

这样做可以避免占用内存。这个方法和 @ChrisB 提出的方案类似,但不依赖于 as_strided 的技巧,我觉得那个有点让人困惑。

3

如果你想检查任意数量的类似数组的对象(而不是传递形状),我们可以利用 np.nditer 来进行广播数组迭代

def is_broadcastable(*arrays):
    try:
        np.nditer(arrays)
        return True
    except ValueError:
        return False

需要注意的是,这个方法只适用于 np.ndarray 或者那些定义了 __array__ 的类(这个方法被调用)。

4

你可以使用 np.broadcast。比如说:

In [47]: x = np.ones([2,2,2])

In [48]: y = np.ones([2,3])

In [49]: try:
   ....:     b = np.broadcast(x, y)
   ....:     print "Result has shape", b.shape
   ....: except ValueError:
   ....:     print "Not compatible for broadcasting"
   ....:     
Not compatible for broadcasting

In [50]: y = np.ones([2,2])

In [51]: try:
   ....:     b = np.broadcast(x, y)
   ....:     print "Result has shape", b.shape
   ....: except ValueError:
   ....:     print "Not compatible for broadcasting"
   ....:
Result has shape (2, 2, 2)

在你实现懒惰求值的时候,np.broadcast_arrays 也可能会对你很有帮助。

6

如果你只是想避免创建一个特定形状的数组,可以使用 as_strided:

import numpy as np
from numpy.lib.stride_tricks import as_strided

def is_broadcastable(shp1, shp2):
    x = np.array([1])
    a = as_strided(x, shape=shp1, strides=[0] * len(shp1))
    b = as_strided(x, shape=shp2, strides=[0] * len(shp2))
    try:
        c = np.broadcast_arrays(a, b)
        return True
    except ValueError:
        return False

is_broadcastable((1000, 1000, 1000), (1000, 1, 1000))  # True
is_broadcastable((1000, 1000, 1000), (3,))  # False

这样做很节省内存,因为 a 和 b 都是由同一个记录支持的。

13

我觉得你们想得太复杂了,为什么不简单点呢?

def is_broadcastable(shp1, shp2):
    for a, b in zip(shp1[::-1], shp2[::-1]):
        if a == 1 or b == 1 or a == b:
            pass
        else:
            return False
    return True

撰写回答