应用指定规则构建数组
考虑一个数组 a
,它的每一列都包含从1、2、3中随机选出的值:
a = np.array([[2, 3, 1, 3],
[3, 2, 1, 3],
[1, 1, 1, 2],
[1, 3, 2, 3],
[3, 3, 1, 3],
[2, 1, 3, 2]])
现在,再考虑一个数组 b
,它的前两列包含从1、2、3中选出的9种可能的值对(这里值对的顺序是很重要的)。b
的第三列则为每一对值关联了一个非负整数。
b = np.array([[1, 1, 6],
[1, 2, 0],
[1, 3, 9],
[2, 1, 6],
[2, 2, 0],
[2, 3, 4],
[3, 1, 1],
[3, 2, 0],
[3, 3, 8]])
我需要帮助来写一段代码,生成一个数组 c
,其中数组 a
中上下相邻的元素会被 b
第三列中对应的值替换。例如,a
的第一列的值是2、3、1、1、3、2,经过替换后,c
的第一列会变成4、1、6、9、0。这个思路同样适用于 a
的每一列。我们可以看到,值对的顺序很重要(从3到1得到的值是1,而从1到3得到的值是9)。
这个小例子的输出结果是:
c = np.array([[4, 0, 6, 8],
[1, 6, 6, 0],
[6, 9, 0, 4],
[9, 8, 6, 8],
[0, 1, 9, 0]])
因为这段代码会被执行很多次,所以我希望能有一个快速的向量化解决方案。
1 个回答
3
因为 b
包含了所有的配对,所以你可以有效地把它变成一个方形的形式,按照行和列的编号来索引,然后用 sliding_window_view
来形成索引对,并索引这个方形的中间结果:
from numpy.lib.stride_tricks import sliding_window_view as swv
s = np.full((b[:, 0].max()+1, b[:, 1].max()+1), -1)
s[b[:, 0], b[:, 1]] = b[:, 2]
v = swv(a, 2, axis=0)
out = s[v[..., 0], v[..., 1]]
使用 0
基索引的变体(会生成一个稍微紧凑一点的中间结果):
s = np.full((b[:, 0].max(), b[:, 1].max()), -1)
s[b[:, 0]-1, b[:, 1]-1] = b[:, 2]
v = swv(a, 2, axis=0)-1
out = s[v[..., 0], v[..., 1]]
输出结果:
array([[4, 0, 6, 8],
[1, 6, 6, 0],
[6, 9, 0, 4],
[9, 8, 6, 8],
[0, 1, 9, 0]])
中间结果 s
:
array([[-1, -1, -1, -1],
[-1, 6, 0, 9],
[-1, 6, 0, 4],
[-1, 1, 0, 8]])
# variant
array([[6, 0, 9],
[6, 0, 4],
[1, 0, 8]])
如果你有一些任意的值,这会让生成一个密集的方形中间结果变得困难,你可以使用 pandas 的 merge
方法:
from numpy.lib.stride_tricks import sliding_window_view as swv
import pandas as pd
out = (pd.DataFrame(swv(a, 2, axis=0).reshape(-1, 2))
.merge(pd.DataFrame(b), how='left')
[2].to_numpy()
.reshape(-1, a.shape[1])
)
输出结果:
array([[4, 0, 6, 8],
[1, 6, 6, 0],
[6, 9, 0, 4],
[9, 8, 6, 8],
[0, 1, 9, 0]])
时间记录
对于一个密集的输入(也就是说,索引的形式是 0 到 n(或 1 到 n),没有缺失的索引),一个方形的 NxN a
,索引为 1->N
(N=1000):
# numpy
16.5 ms ± 941 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# pandas
173 ms ± 17 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
对于一个稀疏的输入(索引是从一个更大的集合中选择的 N
值;这里是从 50,000 个可能值中选择 1000 个值;大约 2% 的密度;在方形形式中大约 0.04% 的密度):
# numpy
5.04 s ± 2.5 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
# pandas
192 ms ± 13.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)