计算numpy数组的和,同时排除某些值

4 投票
4 回答
17646 浏览
提问于 2025-04-18 15:30

我想要计算一个二维的 numpy 数组的总和。不过,有些特定值的元素我想在计算时排除掉。请问有什么高效的方法可以做到这一点呢?

举个例子,我先创建了一个全是1的二维 numpy 数组,然后把其中几个元素换成了2:

import numpy

data_set = numpy.ones((10, 10))

data_set[4][4] = 2
data_set[5][5] = 2
data_set[6][6] = 2

我该如何在这个二维数组中求和,同时排除所有的2呢?注意,对于一个10乘10的数组,正确的答案应该是97,因为我把三个元素的值改成了2。

我知道可以用嵌套的for循环来实现这个功能。例如:

elements = []
for idx_x in range(data_set.shape[0]):
  for idx_y in range(data_set.shape[1]):
    if data_set[idx_x][idx_y] != 2:
      elements.append(data_set[idx_x][idx_y])

data_set_sum = numpy.sum(elements)

但是在我的实际数据中(数据量非常大),这样做太慢了。那有什么正确的方法可以做到这一点呢?

4 个回答

0

这样做怎么样?我们可以利用numpy的布尔功能。

我们只需要把所有符合条件的值先设为零,然后再进行求和。这样做的好处是,不会改变数组的形状,跟直接从数组中筛选出来的方式不同。

另一个好处是,这样我们在应用过滤后,还可以沿着某个轴进行求和。

import numpy

data_set = numpy.ones((10, 10))

data_set[4][4] = 2
data_set[5][5] = 2
data_set[6][6] = 2

print "Sum", data_set.sum()

another_set = numpy.array(data_set) # Take a copy, we'll need that later

data_set[data_set == 2] = 0  # Set all the values that are 2 to zero
print "Filtered sum", data_set.sum()
print "Along axis", data_set.sum(0), data_set.sum(1)

同样,我们也可以使用其他布尔值来设置我们想要排除在求和之外的数据。

another_set[(another_set > 1) & (another_set < 3)] = 0
print "Another filtered sum", another_set.sum()
1

使用 np.sumwhere= 参数,我们可以避免因为使用高级数组索引而导致的数组复制。

>>> import numpy as np
>>> data_set = np.ones((10,10))
>>> data_set[(4,5,6),(4,5,6)] = 2
>>> np.sum(data_set, where=data_set != 2)
97.0
>>> data_set.sum(where=data_set != 2)
97.0

https://numpy.org/doc/stable/reference/generated/numpy.sum.html

高级索引总是会返回数据的一个副本,而基本切片则返回的是一个视图。

https://numpy.org/doc/stable/user/basics.indexing.html#advanced-indexing

5

如果不使用numpy,这个解决方案其实也没那么复杂:

x = [1,2,3,4,5,6,7]
sum(y for y in x if y != 7)
# 21

对于一个排除值的列表,这个方法也能用:

# set is faster for resolving `in`
exl = set([1,2,3])
sum(y for y in x if y not in exl)
# 22
12

使用numpy的功能,可以通过布尔数组进行索引。在下面的例子中,data_set!=2会生成一个布尔数组,当元素不是2时,这个数组的值为True(并且形状是正确的)。所以,data_set[data_set!=2]是一种快速又方便的方法,可以得到一个不包含特定值的数组。当然,布尔表达式可以更复杂。

In [1]: import numpy as np
In [2]: data_set = np.ones((10, 10))
In [4]: data_set[4,4] = 2
In [5]: data_set[5,5] = 2
In [6]: data_set[6,6] = 2
In [7]: data_set[data_set != 2].sum()
Out[7]: 97.0
In [8]: data_set != 2
Out[8]: 
array([[ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       ...
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True]], dtype=bool)

撰写回答