Numpy 向量化小于/大于比较

-1 投票
2 回答
2462 浏览
提问于 2025-05-01 16:40

我有一段代码,用来把角度和它们所在的圆的象限匹配起来。现在这段代码能给我想要的结果,但我想去掉循环,充分利用numpy的速度。

import numpy as np

angle = np.array([350, 10, 80, 100, 170, 190, 260, 280])
# Center of each quadrant
spawn_angles = np.array([0, 90, 180, 270])

segment_degrees = np.diff(spawn_angles)[0]
lower_bounds = spawn_angles - (segment_degrees / 2)
upper_bounds = spawn_angles + (segment_degrees / 2)
max_upper = upper_bounds.max()
# Wrap angles larger than the upper bound of the last segment
# back to a negative angle
angle[angle > max_upper] -= 360
quadrant = np.zeros_like(angle, dtype=np.float64)
# Want to make sure that quadrants that don't get calculated
# properly get assigned an invalid number, i.e. -1
quadrant.fill(-1)
for segment_num in range(len(spawn_angles)):
    in_segment = ((angle > lower_bounds[segment_num]) & 
                  (angle < upper_bounds[segment_num]))
    quadrant[in_segment] = segment_num

# Expected/current output
quadrant
Out[16]: array([ 0.,  0.,  1.,  1.,  2.,  2.,  3.,  3.])

基本上,我搞不懂怎么在numpy里进行>/<的比较。如果一个角度在lower_bounds[0]upper_bounds[0]之间,那么对应的quadrant就被赋值为0,其他象限也是这样处理的。有没有办法可以同时把角度数组和所有的lower_bound和/或upper_bound进行比较呢?

(如果这段代码看起来有点复杂,那是因为spawn_angles和象限中心不总是[0, 90, 180, 270],它们也可以是比如[45, 135, 225, 315]这样的值。)

暂无标签

2 个回答

1

感谢abarnert提供的关键见解。以下是我重新整理过的向量化代码:

import numpy as np

angle = np.array([350, 10, 80, 100, 170, 190, 260, 280])
# Center of each quadrant
spawn_angles = np.array([0, 90, 180, 270])

segment_degrees = np.diff(spawn_angles)[0]
lower_bounds = spawn_angles - (segment_degrees / 2)
upper_bounds = spawn_angles + (segment_degrees / 2)
max_upper = upper_bounds.max()
# Wrap angles larger than the upper bound of the last segment
# back to a negative angle
angle[angle > max_upper] -= 360
angle_2d = angle.reshape((len(angle), 1))
cmp_array = ((angle_2d > lower_bounds) & 
             (angle_2d < upper_bounds))
quadrant = np.argwhere(cmp_array)[:, 1]
quadrant
Out[29]: array([0, 0, 1, 1, 2, 2, 3, 3], dtype=int64)
1

你需要把所有的东西提升一个维度。你想要一个二维数组,每个角度作为一行,每个段号作为一列。(或者你可能想要转置一下,但如果是这样的话,你应该能从这里自己搞明白。)

如果你只是用 a > b,其中 ab 都是一维数组,那你就是在进行逐个元素的比较。

但是如果 a 是一个二维数组,那你就是在进行笛卡尔积的比较。

换句话说:

>>> array.reshape((8,1)) > lower_bounds
array([[ True,  True,  True,  True],
       [ True, False, False, False],
       [ True,  True, False, False],
       [ True,  True, False, False],
       [ True,  True,  True, False],
       [ True,  True,  True, False],
       [ True,  True,  True,  True],
       [ True,  True,  True,  True]], dtype=bool)

你应该能从这里继续理解下去。

撰写回答