从Numpy矩阵构造Python集合

22 投票
6 回答
46240 浏览
提问于 2025-04-15 17:14

我正在尝试执行以下代码

>> from numpy import *
>> x = array([[3,2,3],[4,4,4]])
>> y = set(x)
TypeError: unhashable type: 'numpy.ndarray'

我该如何简单有效地从Numpy数组中创建一个包含所有元素的集合呢?

6 个回答

9

上面的回答适用于你想从一个ndarray中的元素创建一个集合的情况,但如果你想创建一个包含ndarray对象的集合,或者把ndarray对象当作字典的键使用,那你就需要给它们提供一个可以哈希的包装类。下面的代码是一个简单的例子:

from hashlib import sha1

from numpy import all, array, uint8


class hashable(object):
    r'''Hashable wrapper for ndarray objects.

        Instances of ndarray are not hashable, meaning they cannot be added to
        sets, nor used as keys in dictionaries. This is by design - ndarray
        objects are mutable, and therefore cannot reliably implement the
        __hash__() method.

        The hashable class allows a way around this limitation. It implements
        the required methods for hashable objects in terms of an encapsulated
        ndarray object. This can be either a copied instance (which is safer)
        or the original object (which requires the user to be careful enough
        not to modify it).
    '''
    def __init__(self, wrapped, tight=False):
        r'''Creates a new hashable object encapsulating an ndarray.

            wrapped
                The wrapped ndarray.

            tight
                Optional. If True, a copy of the input ndaray is created.
                Defaults to False.
        '''
        self.__tight = tight
        self.__wrapped = array(wrapped) if tight else wrapped
        self.__hash = int(sha1(wrapped.view(uint8)).hexdigest(), 16)

    def __eq__(self, other):
        return all(self.__wrapped == other.__wrapped)

    def __hash__(self):
        return self.__hash

    def unwrap(self):
        r'''Returns the encapsulated ndarray.

            If the wrapper is "tight", a copy of the encapsulated ndarray is
            returned. Otherwise, the encapsulated ndarray itself is returned.
        '''
        if self.__tight:
            return array(self.__wrapped)

        return self.__wrapped

使用这个包装类其实很简单:

>>> from numpy import arange

>>> a = arange(0, 1024)
>>> d = {}
>>> d[a] = 'foo'
Traceback (most recent call last):
  File "<input>", line 1, in <module>
TypeError: unhashable type: 'numpy.ndarray'
>>> b = hashable(a)
>>> d[b] = 'bar'
>>> d[b]
'bar'
19

数组的不可变版本是元组,所以你可以尝试把数组中的数组转换成元组的数组:

>> from numpy import *
>> x = array([[3,2,3],[4,4,4]])

>> x_hashable = map(tuple, x)

>> y = set(x_hashable)
set([(3, 2, 3), (4, 4, 4)])
36

如果你想要一组元素,这里有另一种可能更快的方法:

y = set(x.flatten())

补充说明: 在对一个10x100的数组进行比较时,我测试了 x.flatx.flatten()x.ravel(),发现它们的速度差不多。对于一个3x3的数组,速度最快的是迭代器版本:

y = set(x.flat)

我推荐这个方法,因为它占用的内存更少(随着数组大小的增加,它的表现也很好)。

再补充一下: 还有一个NumPy的函数可以做类似的事情:

y = numpy.unique(x)

这个函数会生成一个NumPy数组,里面的元素和 set(x.flat) 一样,但它是以NumPy数组的形式存在。这种方法非常快(几乎快了10倍),不过如果你需要一个 set,那么用 set(numpy.unique(x)) 的速度会比其他方法稍慢一些(因为创建一个集合会消耗更多资源)。

撰写回答