应用指定规则构建数组

1 投票
1 回答
37 浏览
提问于 2025-04-14 16:43

考虑一个数组 a,它的每一列都包含从1、2、3中随机选出的值:

a = np.array([[2, 3, 1, 3],
              [3, 2, 1, 3],
              [1, 1, 1, 2],
              [1, 3, 2, 3],
              [3, 3, 1, 3],
              [2, 1, 3, 2]])

现在,再考虑一个数组 b,它的前两列包含从1、2、3中选出的9种可能的值对(这里值对的顺序是很重要的)。b的第三列则为每一对值关联了一个非负整数。

b = np.array([[1, 1, 6],
              [1, 2, 0],
              [1, 3, 9],
              [2, 1, 6],
              [2, 2, 0],
              [2, 3, 4],
              [3, 1, 1],
              [3, 2, 0],
              [3, 3, 8]])

我需要帮助来写一段代码,生成一个数组 c,其中数组 a 中上下相邻的元素会被 b 第三列中对应的值替换。例如,a 的第一列的值是2、3、1、1、3、2,经过替换后,c 的第一列会变成4、1、6、9、0。这个思路同样适用于 a 的每一列。我们可以看到,值对的顺序很重要(从3到1得到的值是1,而从1到3得到的值是9)。

这个小例子的输出结果是:

c = np.array([[4, 0, 6, 8],
              [1, 6, 6, 0],
              [6, 9, 0, 4],
              [9, 8, 6, 8],
              [0, 1, 9, 0]])

因为这段代码会被执行很多次,所以我希望能有一个快速的向量化解决方案。

1 个回答

3

因为 b 包含了所有的配对,所以你可以有效地把它变成一个方形的形式,按照行和列的编号来索引,然后用 sliding_window_view 来形成索引对,并索引这个方形的中间结果:

from numpy.lib.stride_tricks import sliding_window_view as swv

s = np.full((b[:, 0].max()+1, b[:, 1].max()+1), -1)
s[b[:, 0], b[:, 1]] = b[:, 2]

v = swv(a, 2, axis=0)
out = s[v[..., 0], v[..., 1]]

使用 0 基索引的变体(会生成一个稍微紧凑一点的中间结果):

s = np.full((b[:, 0].max(), b[:, 1].max()), -1)
s[b[:, 0]-1, b[:, 1]-1] = b[:, 2]

v = swv(a, 2, axis=0)-1
out = s[v[..., 0], v[..., 1]]

输出结果:

array([[4, 0, 6, 8],
       [1, 6, 6, 0],
       [6, 9, 0, 4],
       [9, 8, 6, 8],
       [0, 1, 9, 0]])

中间结果 s

array([[-1, -1, -1, -1],
       [-1,  6,  0,  9],
       [-1,  6,  0,  4],
       [-1,  1,  0,  8]])

# variant
array([[6, 0, 9],
       [6, 0, 4],
       [1, 0, 8]])

如果你有一些任意的值,这会让生成一个密集的方形中间结果变得困难,你可以使用 merge 方法:

from numpy.lib.stride_tricks import sliding_window_view as swv
import pandas as pd

out = (pd.DataFrame(swv(a, 2, axis=0).reshape(-1, 2))
         .merge(pd.DataFrame(b), how='left')
         [2].to_numpy()
         .reshape(-1, a.shape[1])
       )

输出结果:

array([[4, 0, 6, 8],
       [1, 6, 6, 0],
       [6, 9, 0, 4],
       [9, 8, 6, 8],
       [0, 1, 9, 0]])

时间记录

对于一个密集的输入(也就是说,索引的形式是 0 到 n(或 1 到 n),没有缺失的索引),一个方形的 NxN a,索引为 1->N(N=1000):

# numpy
16.5 ms ± 941 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

# pandas
173 ms ± 17 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

对于一个稀疏的输入(索引是从一个更大的集合中选择的 N 值;这里是从 50,000 个可能值中选择 1000 个值;大约 2% 的密度;在方形形式中大约 0.04% 的密度):

# numpy
5.04 s ± 2.5 s per loop (mean ± std. dev. of 7 runs, 1 loop each)

# pandas
192 ms ± 13.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

撰写回答