在numpy中按值拆分数组

14 投票
4 回答
31839 浏览
提问于 2025-04-16 13:30

我有一个文件,里面的数据格式是这样的:

0.0 x1
0.1 x2
0.2 x3
0.0 x4
0.1 x5
0.2 x6
0.3 x7
...

这些数据包含了多个数据集,每个数据集的第一列都是0(比如x1,x2,x3是一组,x4,x5,x6,x7又是一组)。我需要把每个数据集单独画出来,所以我需要把数据分开。有什么简单的方法可以做到这一点吗?

我知道我可以逐行查看数据,每次遇到第一列是0的时候就把数据分开,但这样做效率似乎不高。

4 个回答

1

你不需要用Python的循环来判断每个分割的位置。只需对第一列的数据做个差值计算,然后找出数值下降的地方就可以了。

import numpy

# read the array
arry = numpy.fromfile(file, dtype=('float, S2'))

# determine where the data "splits" shoule be
col1 = arry['f0']
diff = col1 - numpy.roll(col1,1)
idxs = numpy.where(diff<0)[0]

# only loop thru the "splits"
strts = idxs
stops = list(idxs[1:])+[None]
groups = [data[strt:stop] for strt,stop in zip(strts,stops)]
17

一旦你把数据放进一个很长的numpy数组里,就可以这样做:

import numpy as np

A = np.array([[0.0, 1], [0.1, 2], [0.2, 3], [0.0, 4], [0.1, 5], [0.2, 6], [0.3, 7], [0.0, 8], [0.1, 9], [0.2, 10]])
B = np.split(A, np.argwhere(A[:,0] == 0.0).flatten()[1:])

这样做之后,B里会包含三个数组,分别是 B[0]B[1]B[2](在这个例子中,我加了一个第三个“部分”,是为了证明它确实在正常工作)。

28

我其实挺喜欢本杰明的回答的,下面这个是一个稍微简短一点的解决方案:

B= np.split(A, np.where(A[:, 0]== 0.)[0][1:])

撰写回答