在Python中对数组数组排序
我有以下的数据结构:
[[[ 512 520 1 130523]]
[[ 520 614 573 7448]]
[[ 614 616 615 210]]
[[ 616 622 619 269]]
[[ 622 624 623 162]]
[[ 625 770 706 8822]]
[[ 770 776 773 241]]]
我想返回一个形状相同的对象,但只保留第4列中最大的3个值对应的行(如果这样说能让你明白的话)。在这个例子中,就是第1、2和6行。
有没有什么优雅的方法可以做到这一点?
3 个回答
0
我简化了你列表中的列表结构,这样可以更专注于主要问题。你可以使用 sorted()
函数,并配合一个自定义的 compare()
函数来进行排序:
my_list = [[512, 520, 1, 130523],
[520, 614 , 573, 7448],
[614, 616, 615, 210],
[616, 622, 619, 269],
[622, 624, 623, 162],
[625, 770, 706, 8822],
[770, 776, 773, 241]]
def sort_by(a):
return a[3]
sorted(my_list, key=sort_by)
print my_list[0:3] # prints [[512, 520, 1, 130523], [520, 614, 573, 7448], [614, 616, 615, 210]]
5
你可以使用 sorted()
这个函数,并告诉它你想根据第4列来排序:
l = [[[512, 520 , 1, 130523]],
[[ 520 , 614 , 573, 7448]],
[[ 614 , 616 , 615, 210]],
[[ 616 , 622 , 619, 269]],
[[ 622 , 624 , 623, 162]],
[[ 625 , 770 , 706, 8822]],
[[ 770 , 776 , 773, 241]]]
top3 = sorted(l, key=lambda x: x[0][3], reverse=True)[:3]
print top3
这样就会得到:
[[[512, 520, 1, 130523]], [[625, 770, 706, 8822]], [[520, 614, 573, 7448]]]
3
你可以对数组进行排序,但从NumPy 1.8开始,有一种更快的方法可以找到最大的N个值,特别是当data
很大的时候:
import numpy as np
data = np.array([[[ 512, 520, 1, 130523]],
[[ 520, 614, 573, 7448]],
[[ 614, 616, 615, 210]],
[[ 616, 622, 619, 269]],
[[ 622, 624, 623, 162]],
[[ 625, 770, 706, 8822]],
[[ 770, 776, 773, 241]]])
idx = np.argpartition(-data[...,-1].flatten(), 3)
print(data[idx[:3]])
会得到
[[[ 520 614 573 7448]]
[[ 512 520 1 130523]]
[[ 625 770 706 8822]]]
np.argpartition
执行的是一种部分排序。它返回的是数组的索引,按照部分排序的顺序排列,这样每个kth
项就处于它最终的排序位置。实际上,每组k
项是相对于其他组进行排序的,但每组内部并没有排序(这样可以节省一些时间)。
注意,返回的3个最高的行并不是按照它们在data
中出现的顺序返回的。
为了比较,这里是如何使用np.argsort
(它会进行完全排序)来找到3个最高的行:
idx = np.argsort(data[..., -1].flatten())
print(data[idx[-3:]])
会得到
[[[ 520 614 573 7448]]
[[ 625 770 706 8822]]
[[ 512 520 1 130523]]]
注意:对于小数组,np.argsort
会更快:
In [63]: %timeit idx = np.argsort(data[..., -1].flatten())
100000 loops, best of 3: 2.6 µs per loop
In [64]: %timeit idx = np.argpartition(-data[...,-1].flatten(), 3)
100000 loops, best of 3: 5.61 µs per loop
但对于大数组,np.argpartition
会更快:
In [92]: data2 = np.tile(data, (10**3,1,1))
In [93]: data2.shape
Out[93]: (7000, 1, 4)
In [94]: %timeit idx = np.argsort(data2[..., -1].flatten())
10000 loops, best of 3: 164 µs per loop
In [95]: %timeit idx = np.argpartition(-data2[...,-1].flatten(), 3)
10000 loops, best of 3: 49.5 µs per loop