import numpy
data = numpy.random.randint(0, 10, (6,8))
test = set(numpy.random.randint(0, 10, 5))
我需要一个值为布尔数组的表达式,其形状为data
(或者至少可以被重塑为相同的形状),它告诉我data
中的对应项是否在set
中。在
例如,如果我想知道data
的哪些元素严格小于6
,我可以使用一个向量化的表达式
它计算6x8
布尔nArray。相反,当我尝试一个显然等价的布尔表达式时
b = data in test
我得到的是一个例外:
TypeError: unhashable type: 'numpy.ndarray'
编辑:由于hpaulj,下面的可能性4给出了错误的结果 还有Divakar让我走上正轨。
这里我比较了四种不同的可能性
np.in1d(data, np.hstack(test))
。在np.in1d(data, np.array(list(test)))
。在np.in1d(data, test)
。这里是Ipython会话,为了避免出现空白行,稍微进行了编辑
In [1]: import numpy as np
In [2]: nr, nc = 100, 100
In [3]: top = 3000
In [4]: data = np.random.randint(0, top, (nr, nc))
In [5]: test = set(np.random.randint(0, top, top//3))
In [6]: %timeit np.in1d(data, np.hstack(test))
100 loops, best of 3: 5.65 ms per loop
In [7]: %timeit np.in1d(data, np.array(list(test)))
1000 loops, best of 3: 1.4 ms per loop
In [8]: %timeit np.in1d(data, np.fromiter(test, int))
1000 loops, best of 3: 1.33 ms per loop
In [9]: %timeit np.in1d(data, test)
1000 loops, best of 3: 687 µs per loop
In [10]: nr, nc = 1000, 1000
In [11]: top = 300000
In [12]: data = np.random.randint(0, top, (nr, nc))
In [13]: test = set(np.random.randint(0, top, top//3))
In [14]: %timeit np.in1d(data, np.hstack(test))
1 loop, best of 3: 706 ms per loop
In [15]: %timeit np.in1d(data, np.array(list(test)))
1 loop, best of 3: 269 ms per loop
In [16]: %timeit np.in1d(data, np.fromiter(test, int))
1 loop, best of 3: 274 ms per loop
In [17]: %timeit np.in1d(data, test)
10 loops, best of 3: 67.9 ms per loop
In [18]:
匿名发帖人的回答给了我们更美好的时光。
原来,匿名发帖者有很好的理由删除他们的答案,结果是错误的!在
正如hpaulj所评论的那样,在in1d
的文档中有一条警告,不要使用set
作为第二个参数,但是如果计算结果可能出错,我更希望显式失败。在
也就是说,使用numpy.fromiter()
的解决方案有最好的数字。。。在
我假设您正在寻找一个布尔数组来检测} 从} 来检测
set
元素在data
数组中的存在。为此,可以使用^{set
提取元素,然后使用^{set
中每个位置的set
中是否存在任何元素,给我们一个与data
大小相同的布尔数组。因为,np.in1d
在处理之前会使输入变平,因此作为最后一步,我们需要将输出从np.in1d
改回其原始的2D
形状。因此,最终实施将是-样本运行-
^{pr2}$表达式
a = data < 6
返回一个新数组,因为<
是一个值比较运算符。在请注意,
in
运算符不在此列表中。可能是因为它的工作方向与大多数操作符相反。在当
a + b
与a.__add__(b)
相同时,a in b
从右到左b.__contains__(a)
。在本例中,python尝试调用set.__contains__()
,它只接受散列/不可变类型。数组是可变的,所以它们不能是集合的成员。在解决方法是直接使用
numpy.vectorize
而不是in
,并对数组中的每个元素调用任何python函数。在它是numpy数组的一种
map()
。在基准
当n较大时,这种方法很快,因为
^{pr2}$set.__contains__()
是一个恒定时间操作。(“大”表示top
>;13000左右)然而,当n很小时,其他解决方案要快得多。在
相关问题 更多 >
编程相关推荐