Numpy 查找(映射或点)

10 投票
2 回答
1428 浏览
提问于 2025-04-16 12:02

我有一个很大的numpy数组:

array([[32, 32, 99,  9, 45],  # A
       [99, 45,  9, 45, 32],
       [45, 45, 99, 99, 32],
       [ 9,  9, 32, 45, 99]])

还有一个相对较大的独特值数组,这些值是有特定顺序的:

array([ 99, 32, 45, 9])       # B

我该如何快速地(不使用python字典,不复制A,也不使用python循环)将A中的值替换成B中值的索引呢?:

array([[1, 1, 0, 3, 2],
       [0, 2, 3, 2, 1],
       [2, 2, 0, 0, 1],
       [3, 3, 1, 2, 0]])

我觉得自己真傻,竟然想不出这个方法,也找不到相关的文档。真是简单的问题啊!

2 个回答

7
import numpy as np
A=np.array([[32, 32, 99,  9, 45],  
            [99, 45,  9, 45, 32],
            [45, 45, 99, 99, 32],
            [ 9,  9, 32, 45, 99]])

B=np.array([ 99, 32, 45, 9])

cutoffs=np.sort(B)
print(cutoffs)
# [ 9 32 45 99]

index=cutoffs.searchsorted(A)
print(index)
# [[1 1 3 0 2]
#  [3 2 0 2 1]
#  [2 2 3 3 1]
#  [0 0 1 2 3]]    

index 里存放的是与数组 A 中每个元素对应的 cutoff 的索引。注意,我们需要先对 B 进行排序,因为 np.searchsorted 这个函数要求输入的数组是排好序的。

index 几乎就是我们想要的结果,不过我们还需要进行一些映射。

1-->1
3-->0
0-->3
2-->2

np.argsort 可以帮助我们完成这个映射:

print(np.argsort(B))
# [3 1 2 0]
print(np.argsort(B)[1])
# 1
print(np.argsort(B)[3])
# 0
print(np.argsort(B)[0])
# 3
print(np.argsort(B)[2])
# 2

print(np.argsort(B)[index])
# [[1 1 0 3 2]
#  [0 2 3 2 1]
#  [2 2 0 0 1]
#  [3 3 1 2 0]]

所以,简单来说,答案就是:

np.argsort(B)[np.sort(B).searchsorted(A)]

同时调用 np.sort(B)np.argsort(B) 是不太高效的,因为这两步都是在对 B 进行排序。对于任何一维数组 B

np.sort(B) == B[np.argsort(B)]

所以我们可以用一种更快的方法来计算出想要的结果:

key=np.argsort(B)
result=key[B[key].searchsorted(A)]
8

给你看看

A = array([[32, 32, 99,  9, 45],  # A
   [99, 45,  9, 45, 32],
   [45, 45, 99, 99, 32],
   [ 9,  9, 32, 45, 99]])

B = array([ 99, 32, 45, 9])

ii = np.argsort(B)
C = np.digitize(A.reshape(-1,),np.sort(B)) - 1

最开始我建议:

D = np.choose(C,ii).reshape(A.shape)

但我意识到,当处理更大的数组时,这个方法有些局限。于是,我借鉴了@unutbu的聪明回答:

D = np.argsort(B)[C].reshape(A.shape)

或者这个一行代码

np.argsort(B)[np.digitize(A.reshape(-1,),np.sort(B)) - 1].reshape(A.shape)

我发现这个方法的速度比@unutbu的代码快或慢,具体取决于数组的大小和独特值的数量。

撰写回答