在Python中对数组数组排序

1 投票
3 回答
6433 浏览
提问于 2025-04-18 18:06

我有以下的数据结构:

 [[[   512    520     1 130523]]

 [[   520    614    573   7448]]

 [[   614    616    615    210]]

 [[   616    622    619    269]]

 [[   622    624    623    162]]

 [[   625    770    706   8822]]

 [[   770    776    773    241]]]

我想返回一个形状相同的对象,但只保留第4列中最大的3个值对应的行(如果这样说能让你明白的话)。在这个例子中,就是第1、2和6行。

有没有什么优雅的方法可以做到这一点?

3 个回答

0

我简化了你列表中的列表结构,这样可以更专注于主要问题。你可以使用 sorted() 函数,并配合一个自定义的 compare() 函数来进行排序:

my_list =  [[512, 520, 1, 130523], 
        [520, 614 , 573, 7448],
        [614, 616, 615, 210],
        [616, 622, 619, 269], 
        [622, 624, 623, 162], 
        [625, 770, 706, 8822], 
        [770, 776, 773, 241]]

def sort_by(a):
    return a[3]

sorted(my_list, key=sort_by)
print my_list[0:3] # prints [[512, 520, 1, 130523], [520, 614, 573, 7448], [614, 616, 615, 210]]
5

你可以使用 sorted() 这个函数,并告诉它你想根据第4列来排序:

l = [[[512,    520 ,    1, 130523]],
 [[   520 ,   614  ,  573,   7448]],
 [[   614 ,   616  ,  615,    210]],
 [[   616 ,   622  ,  619,    269]],
 [[   622 ,   624  ,  623,    162]],
 [[   625 ,   770  ,  706,   8822]],
 [[   770 ,   776  ,  773,    241]]]

top3 =  sorted(l, key=lambda x: x[0][3], reverse=True)[:3]

print top3

这样就会得到:

[[[512, 520, 1, 130523]], [[625, 770, 706, 8822]], [[520, 614, 573, 7448]]]
3

你可以对数组进行排序,但从NumPy 1.8开始,有一种更快的方法可以找到最大的N个值,特别是当data很大的时候:

使用numpy.argpartition

import numpy as np
data = np.array([[[ 512,    520,     1, 130523]],
                 [[ 520,    614,    573,   7448]],
                 [[ 614,    616,    615,    210]],
                 [[ 616,    622,    619,    269]],
                 [[ 622,    624,    623,    162]],
                 [[ 625,    770,    706,   8822]],
                 [[ 770,    776,    773,    241]]])

idx = np.argpartition(-data[...,-1].flatten(), 3)
print(data[idx[:3]])

会得到

[[[   520    614    573   7448]]

 [[   512    520      1 130523]]

 [[   625    770    706   8822]]]

np.argpartition执行的是一种部分排序。它返回的是数组的索引,按照部分排序的顺序排列,这样每个kth项就处于它最终的排序位置。实际上,每组k项是相对于其他组进行排序的,但每组内部并没有排序(这样可以节省一些时间)。

注意,返回的3个最高的行并不是按照它们在data中出现的顺序返回的。


为了比较,这里是如何使用np.argsort(它会进行完全排序)来找到3个最高的行:

idx = np.argsort(data[..., -1].flatten())
print(data[idx[-3:]])

会得到

[[[   520    614    573   7448]]

 [[   625    770    706   8822]]

 [[   512    520      1 130523]]]

注意:对于小数组,np.argsort会更快:

In [63]: %timeit idx = np.argsort(data[..., -1].flatten())
100000 loops, best of 3: 2.6 µs per loop

In [64]: %timeit idx = np.argpartition(-data[...,-1].flatten(), 3)
100000 loops, best of 3: 5.61 µs per loop

但对于大数组,np.argpartition会更快:

In [92]: data2 = np.tile(data, (10**3,1,1))
In [93]: data2.shape
Out[93]: (7000, 1, 4)

In [94]: %timeit idx = np.argsort(data2[..., -1].flatten())
10000 loops, best of 3: 164 µs per loop

In [95]: %timeit idx = np.argpartition(-data2[...,-1].flatten(), 3)
10000 loops, best of 3: 49.5 µs per loop

撰写回答