Numpy 向量化小于/大于比较
我有一段代码,用来把角度和它们所在的圆的象限匹配起来。现在这段代码能给我想要的结果,但我想去掉循环,充分利用numpy的速度。
import numpy as np
angle = np.array([350, 10, 80, 100, 170, 190, 260, 280])
# Center of each quadrant
spawn_angles = np.array([0, 90, 180, 270])
segment_degrees = np.diff(spawn_angles)[0]
lower_bounds = spawn_angles - (segment_degrees / 2)
upper_bounds = spawn_angles + (segment_degrees / 2)
max_upper = upper_bounds.max()
# Wrap angles larger than the upper bound of the last segment
# back to a negative angle
angle[angle > max_upper] -= 360
quadrant = np.zeros_like(angle, dtype=np.float64)
# Want to make sure that quadrants that don't get calculated
# properly get assigned an invalid number, i.e. -1
quadrant.fill(-1)
for segment_num in range(len(spawn_angles)):
in_segment = ((angle > lower_bounds[segment_num]) &
(angle < upper_bounds[segment_num]))
quadrant[in_segment] = segment_num
# Expected/current output
quadrant
Out[16]: array([ 0., 0., 1., 1., 2., 2., 3., 3.])
基本上,我搞不懂怎么在numpy里进行>
/<
的比较。如果一个角度在lower_bounds[0]
和upper_bounds[0]
之间,那么对应的quadrant
就被赋值为0,其他象限也是这样处理的。有没有办法可以同时把角度数组和所有的lower_bound
和/或upper_bound
进行比较呢?
(如果这段代码看起来有点复杂,那是因为spawn_angles
和象限中心不总是[0, 90, 180, 270]
,它们也可以是比如[45, 135, 225, 315]
这样的值。)
2 个回答
1
感谢abarnert提供的关键见解。以下是我重新整理过的向量化代码:
import numpy as np
angle = np.array([350, 10, 80, 100, 170, 190, 260, 280])
# Center of each quadrant
spawn_angles = np.array([0, 90, 180, 270])
segment_degrees = np.diff(spawn_angles)[0]
lower_bounds = spawn_angles - (segment_degrees / 2)
upper_bounds = spawn_angles + (segment_degrees / 2)
max_upper = upper_bounds.max()
# Wrap angles larger than the upper bound of the last segment
# back to a negative angle
angle[angle > max_upper] -= 360
angle_2d = angle.reshape((len(angle), 1))
cmp_array = ((angle_2d > lower_bounds) &
(angle_2d < upper_bounds))
quadrant = np.argwhere(cmp_array)[:, 1]
quadrant
Out[29]: array([0, 0, 1, 1, 2, 2, 3, 3], dtype=int64)
1
你需要把所有的东西提升一个维度。你想要一个二维数组,每个角度作为一行,每个段号作为一列。(或者你可能想要转置一下,但如果是这样的话,你应该能从这里自己搞明白。)
如果你只是用 a > b
,其中 a
和 b
都是一维数组,那你就是在进行逐个元素的比较。
但是如果 a
是一个二维数组,那你就是在进行笛卡尔积的比较。
换句话说:
>>> array.reshape((8,1)) > lower_bounds
array([[ True, True, True, True],
[ True, False, False, False],
[ True, True, False, False],
[ True, True, False, False],
[ True, True, True, False],
[ True, True, True, False],
[ True, True, True, True],
[ True, True, True, True]], dtype=bool)
你应该能从这里继续理解下去。