替换numpy数组中大于限制的值

4 投票
4 回答
8974 浏览
提问于 2025-04-16 14:47

我有一个 n x m 的数组,还有每一列的最大值。除了逐个检查每个元素,还有什么更好的方法来替换那些超过最大值的数吗?

举个例子:

def check_limits(bad_array, maxs):
    good_array = np.copy(bad_array)
    for i_line in xrange(bad_array.shape[0]):
        for i_column in xrange(bad_array.shape[1]):
            if good_array[i_line][i_column] >= maxs[i_column]:
                good_array[i_line][i_column] = maxs[i_column] - 1
    return good_array

有没有什么更快、更简洁的方法来做到这一点呢?

4 个回答

0

如果我们不对 bad_array 的结构做任何假设,那么你的代码在对手的角度来看是最优的。如果我们知道每一列都是按升序排列的,那么一旦我们遇到一个比最大值还大的数,我们就可以确定这一列后面的所有元素也都比这个限制值大。但如果没有这样的假设,我们就得一个一个地检查每个元素。

如果你决定先对每一列进行排序,这样的时间复杂度是 (n 列 * nlogn),这已经比检查每个元素所需的 n*n 时间要长了。

你也可以通过逐个检查并复制元素来创建 good_array,而不是先把 bad_array 中的所有元素都复制过来再检查。这样做大约可以把时间减少一半。

2

另一种方法是使用 clip 函数:

用eumiro的例子:

bad_array = np.array([[ 0,  1,  2,  3],
                      [ 4,  5,  6,  7],
                      [ 8,  9, 10, 11]])
maxs = np.array([7,6,5,4])

good_array = bad_array.clip(max=maxs-1)

或者

bad_array.clip(max=maxs-1, out=good_array)

你还可以通过添加参数 min= 来指定下限。

11

使用 putmask 方法:

import numpy as np

a = np.array([[ 0,  1,  2,  3],
              [ 4,  5,  6,  7],
              [ 8,  9, 10, 11]])
m = np.array([7,6,5,4])

# This is what you need:

np.putmask(a, a >= m, m - 1)

# a is now:

np.array([[0, 1, 2, 3],
          [4, 5, 4, 3],
          [6, 5, 4, 3]])

撰写回答