Subsetting a 2D numpy array while keeping the row structure

Posted 2024-04-20 13:08:16


I'd like to know the simplest way to do the following.

Suppose we have the following 2D arrays:

>>> import numpy as np
>>> a = np.array([['z', 'z', 'z', 'f', 'z', 'f', 'f'], ['z', 'z', 'z', 'f', 'z', 'f', 'f']])
>>> a
array([['z', 'z', 'z', 'f', 'z', 'f', 'f'],
       ['z', 'z', 'z', 'f', 'z', 'f', 'f']], dtype='<U1')

>>> b = np.array(range(0, 14)).reshape(2, -1)
>>> b
array([[ 0,  1,  2,  3,  4,  5,  6],
       [ 7,  8,  9, 10, 11, 12, 13]])

>>> idxs = list(zip(*np.where(a == 'f')))
>>> idxs
[(0, 3), (0, 5), (0, 6), (1, 3), (1, 5), (1, 6)]
>>> [b[x] for x in idxs]
[3, 5, 6, 10, 12, 13]

However, I want to keep the structure given by the first index (the row), i.e.:

[[3, 5, 6], [10, 12, 13]]

Is there a way to preserve this structure?


3 answers

Use a loop over the rows:

[b[i][a[i] == 'f'] for i in range(len(a))]
# [array([3, 5, 6]), array([10, 12, 13])]
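If you want plain nested Python lists exactly like the desired output, rather than a list of NumPy arrays, a minimal sketch is to call `.tolist()` on each row's selection. This reuses the same per-row boolean indexing as the one-liner above; `np.arange(14)` is used here as an equivalent of `np.array(range(0, 14))`.

```python
import numpy as np

a = np.array([['z', 'z', 'z', 'f', 'z', 'f', 'f'],
              ['z', 'z', 'z', 'f', 'z', 'f', 'f']])
b = np.arange(14).reshape(2, -1)

# boolean-index each row of b with the matching row of the mask,
# then turn the resulting 1D array into a plain Python list of ints
result = [b[i][a[i] == 'f'].tolist() for i in range(len(a))]
print(result)  # [[3, 5, 6], [10, 12, 13]]
```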

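import numpy as np

a = np.array([['z', 'z', 'z', 'f', 'z', 'f', 'f'], ['z', 'z', 'z', 'f', 'z', 'f', 'f']])
b = np.array(range(0, 14)).reshape(2, -1)
idxs = list(zip(*np.where(a == 'f')))

# one bucket per row of a; each (row, col) index is appended to its row's bucket
c = [[] for _ in range(a.shape[0])]
for x in idxs:
    c[x[0]].append(b[x])
print(c)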
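A related sketch, not part of the answer above, groups the `idxs` pairs with `itertools.groupby` instead of pre-allocating buckets. It relies on `np.where` returning matches in row-major (C) order, so all pairs belonging to one row are consecutive. One difference to be aware of: a row with no `'f'` at all produces no group here, whereas the bucket list `c` keeps an empty list for it.

```python
from itertools import groupby

import numpy as np

a = np.array([['z', 'z', 'z', 'f', 'z', 'f', 'f'],
              ['z', 'z', 'z', 'f', 'z', 'f', 'f']])
b = np.array(range(0, 14)).reshape(2, -1)
idxs = list(zip(*np.where(a == 'f')))

# np.where lists matches row by row, so consecutive pairs share a row index;
# group on that row index and look up each (row, col) pair in b
grouped = [[int(b[x]) for x in group]
           for _, group in groupby(idxs, key=lambda x: x[0])]
print(grouped)  # [[3, 5, 6], [10, 12, 13]]
```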

Here is a more involved, but pure NumPy, solution:

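  1. Get the indices (in the flattened version of a) where the value is 'f'.
  2. Get the indices where each new row starts.
  3. Determine which of the indices from step 1 belong to each row.
  4. Split the array at those positions.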

The code looks like this:

>>> indices = np.flatnonzero(a.ravel() == 'f')
>>> rows = np.arange(1, a.shape[0])*a.shape[1]
>>> np.split(b.ravel()[indices], np.searchsorted(indices, rows))
[array([3, 5, 6], dtype=int64), array([10, 12, 13], dtype=int64)]

It is a bit longer than the other solutions, and I'm not sure whether it is faster.¹
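Steps 2 and 3 can also be expressed through per-row match counts; this is a variant of my own, not the answer's code. Boolean indexing `b[mask]` returns the selected values flattened in row-major order, and the running total of matches per row gives the positions where each new row starts. A row with no `'f'` simply yields an empty array.

```python
import numpy as np

a = np.array([['z', 'z', 'z', 'f', 'z', 'f', 'f'],
              ['z', 'z', 'z', 'f', 'z', 'f', 'f']])
b = np.arange(14).reshape(2, -1)

mask = a == 'f'
# selected values of b, flattened in row-major order
flat = b[mask]
# matches per row, e.g. [3, 3]; the cumulative sum without its last entry
# gives the start offset of every row after the first inside `flat`
split_points = np.cumsum(mask.sum(axis=1))[:-1]
parts = np.split(flat, split_points)
print([p.tolist() for p in parts])  # [[3, 5, 6], [10, 12, 13]]
```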

Personally, though, I would use a list comprehension and zip:

[b_row[a_row] for a_row, b_row in zip(a == 'f', b)]

It is much shorter and, according to my timings, performs quite well.
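In the question's example every row contains the same number of `'f'` entries. In that special case the flat selection can be reshaped back into a regular 2D array instead of a list. The helper below is a hypothetical sketch (`rows_matching` is not from any answer or library): it reshapes when the per-row counts are equal and otherwise falls back to the zip comprehension above. It assumes `a` has at least one row.

```python
import numpy as np

def rows_matching(a, b, value):
    """Hypothetical helper: pick b's entries where a == value, row by row.

    Returns a 2D array when every row has the same number of matches,
    otherwise a list of 1D arrays (one per row). Assumes a has >= 1 row.
    """
    mask = a == value
    counts = mask.sum(axis=1)
    if np.all(counts == counts[0]):
        # b[mask] is flattened in row-major order, so equal counts let us
        # fold it back into one output row per input row
        return b[mask].reshape(a.shape[0], -1)
    # ragged case: fall back to the per-row zip comprehension
    return [b_row[a_row] for a_row, b_row in zip(mask, b)]

a = np.array([['z', 'z', 'z', 'f', 'z', 'f', 'f'],
              ['z', 'z', 'z', 'f', 'z', 'f', 'f']])
b = np.arange(14).reshape(2, -1)
print(rows_matching(a, b, 'f'))
```

The 2D result is convenient for further vectorized work, but the list form is the only option once rows have different numbers of matches.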


¹ Timings:

import numpy as np
a = np.array([['z', 'z', 'z', 'f', 'z','f', 'f']]*10000)
b = np.arange(a.size).reshape(-1, a.shape[1])

%%timeit

indices = np.flatnonzero(a.ravel() == 'f')
rows = np.arange(1, a.shape[0])*a.shape[1]
np.split(b.ravel()[indices], np.searchsorted(indices, rows))

123 ms ± 8.25 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

%timeit [b[i][a[i] == 'f'] for i in range(len(a))]

162 ms ± 14 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

But both are much slower than my suggestion at Psidom's answer:

%timeit [b_row[a_row] for a_row, b_row in zip(a == 'f', b)]

44.9 ms ± 1.93 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
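`%timeit` and `%%timeit` are IPython/Jupyter magics. To compare the three approaches from a plain Python script, a minimal sketch with the standard-library `timeit` module could look like this. The function names are hypothetical wrappers around the answer's code, and the timings you get will depend on your machine and NumPy version; no results are implied here.

```python
import timeit

import numpy as np

# same test data as the timings above
a = np.array([['z', 'z', 'z', 'f', 'z', 'f', 'f']] * 10000)
b = np.arange(a.size).reshape(-1, a.shape[1])

def split_version():
    indices = np.flatnonzero(a.ravel() == 'f')
    rows = np.arange(1, a.shape[0]) * a.shape[1]
    return np.split(b.ravel()[indices], np.searchsorted(indices, rows))

def index_loop_version():
    return [b[i][a[i] == 'f'] for i in range(len(a))]

def zip_version():
    return [b_row[a_row] for a_row, b_row in zip(a == 'f', b)]

if __name__ == '__main__':
    for fn in (split_version, index_loop_version, zip_version):
        # best of 3 repeats, each running fn 3 times; reported per call
        best = min(timeit.repeat(fn, number=3, repeat=3)) / 3
        print(f'{fn.__name__}: {best * 1e3:.1f} ms per call')
```

Taking the minimum over repeats is the usual way to reduce noise from other processes when using `timeit.repeat`.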
