NumPy使用索引列表按行选择特定列索引
我在选择NumPy矩阵中每一行的特定列时遇到了困难。
假设我有一个矩阵,我称之为X
:
[1, 2, 3]
[4, 5, 6]
[7, 8, 9]
我还有一个每行对应的列索引列表,我称之为Y
:
[1, 0, 2]
我需要获取这些值:
[2]
[4]
[9]
除了用Y
这个列表来表示索引,我还可以生成一个和X
形状相同的矩阵,其中每一列都是一个bool
(布尔值)或int
(整数),值在0到1之间,表示这是否是所需的列。
[0, 1, 0]
[1, 0, 0]
[0, 0, 1]
我知道可以通过遍历数组来选择我需要的列值。不过,这个操作会在大数据数组上频繁执行,所以我希望它能尽可能快地运行。
因此,我想知道是否有更好的解决方案?
7 个回答
3
你可以通过使用迭代器来实现。就像这样:
np.fromiter((row[index] for row, index in zip(X, Y)), dtype=int)
时间:
N = 1000
X = np.zeros(shape=(N, N))
Y = np.arange(N)
#@Aशwini चhaudhary
%timeit X[np.arange(len(X)), Y]
10000 loops, best of 3: 30.7 us per loop
#mine
%timeit np.fromiter((row[index] for row, index in zip(X, Y)), dtype=int)
1000 loops, best of 3: 1.15 ms per loop
#mine
%timeit np.diag(X.T[Y])
10 loops, best of 3: 20.8 ms per loop
7
一个简单的方法可能是这样的:
In [1]: a = np.array([[1, 2, 3],
...: [4, 5, 6],
...: [7, 8, 9]])
In [2]: y = [1, 0, 2] #list of indices we want to select from matrix 'a'
range(a.shape[0])
这个代码会返回 array([0, 1, 2])
In [3]: a[range(a.shape[0]), y] #we're selecting y indices from every row
Out[3]: array([2, 4, 9])
41
最近的 numpy
版本增加了一个叫 take_along_axis
(还有 put_along_axis
)的功能,可以更简单地进行这种索引操作。
In [101]: a = np.arange(1,10).reshape(3,3)
In [102]: b = np.array([1,0,2])
In [103]: np.take_along_axis(a, b[:,None], axis=1)
Out[103]:
array([[2],
[4],
[9]])
它的工作方式和下面这个是一样的:
In [104]: a[np.arange(3), b]
Out[104]: array([2, 4, 9])
但它处理轴的方式不同。这个功能特别适合用来应用 argsort
和 argmax
的结果。
56
你可以这样做:
In [7]: a = np.array([[1, 2, 3],
...: [4, 5, 6],
...: [7, 8, 9]])
In [8]: lst = [1, 0, 2]
In [9]: a[np.arange(len(a)), lst]
Out[9]: array([2, 4, 9])
关于多维数组的索引,你可以了解更多信息:http://docs.scipy.org/doc/numpy/user/basics.indexing.html#indexing-multi-dimensional-arrays
158
如果你有一个布尔数组,你可以直接根据这个数组进行选择,方法如下:
>>> a = np.array([True, True, True, False, False])
>>> b = np.array([1,2,3,4,5])
>>> b[a]
array([1, 2, 3])
结合你最开始的例子,你可以这样做:
>>> a = np.array([[1,2,3], [4,5,6], [7,8,9]])
>>> b = np.array([[False,True,False],[True,False,False],[False,False,True]])
>>> a[b]
array([2, 4, 9])
你还可以加入一个 arange
,然后直接在这个基础上进行选择。不过,具体效果可能会根据你生成布尔数组的方式和你的代码结构有所不同。
>>> a = np.array([[1,2,3], [4,5,6], [7,8,9]])
>>> a[np.arange(len(a)), [1,0,2]]
array([2, 4, 9])