Python 2.6: optimizing a nested loop
I have a function that takes a dictionary and a number n. Each entry in the dictionary is a collection holding one or more values. The function sorts the dictionary's keys, then extracts and returns n values. Since this function is executed frequently, I'd like to optimize it. Any suggestions?
def select_items(temp_dict, n):
    """Select n items from the dictionary"""
    res = []
    sort_keys = sorted(temp_dict.keys())
    count = 0
    for key in sort_keys:
        for pair in temp_dict[key]:
            if count < n:
                res.append(pair)
                count += 1
            else:
                return res
    return res
In this code I use a counter and an if statement to cap the number of values selected. Is there a way to optimize it, perhaps with one of the itertools functions, or some other approach?
4 Answers
I find a generator expression cleaner and more readable here, and slicing the sorted keys avoids the if condition entirely:
def select_items(dic, n):
    return (dic[key] for key in sorted(dic.keys())[:n])
Regarding speed: I suspect the actual sort call is the biggest bottleneck here, though you probably don't need to worry about it until the dictionary gets large. At that point you should consider keeping the dictionary sorted in the first place: inserts become slightly more involved, but lookup and selection are fast. One example is sorteddict; a tree-based data structure would be another option.
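As a rough sketch of that keep-it-sorted idea, using only the stdlib bisect module rather than sorteddict (the class and method names here are my own invention, purely illustrative):

```python
import bisect

class SortedKeyDict(object):
    """Toy sketch: a dict plus a parallel, always-sorted key list,
    so selection never needs a full sort."""

    def __init__(self):
        self._data = {}
        self._keys = []          # kept sorted at all times

    def insert(self, key, value):
        if key not in self._data:
            # O(log n) search + O(n) list insert
            bisect.insort(self._keys, key)
        self._data[key] = value

    def first_values(self, n):
        # values for the n smallest keys; just a slice, no sorting
        return [self._data[k] for k in self._keys[:n]]

d = SortedKeyDict()
for k in (3, 1, 2):
    d.insert(k, "v%d" % k)
print(d.first_values(2))  # ['v1', 'v2']
```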
Now for benchmarks. Initial setup, taken from David Wolever's excellent answer:
test_dict = dict((x, "a") for x in range(1000))
test_n = 300
Your version:
%timeit select_items(test_dict, test_n)
1000 loops, best of 3: 334 us per loop
This version:
%timeit select_items(test_dict, test_n)
10000 loops, best of 3: 49.1 us per loop
Here is my first attempt (see select_items_faster), which is almost twice as fast:
In [12]: print _11
import itertools

def select_items_original(temp_dict, n):
    """Select n items from the dictionary"""
    res = []
    sort_keys = sorted(temp_dict.keys())
    count = 0
    for key in sort_keys:
        for pair in temp_dict[key]:
            if count < n:
                res.append(pair)
                count += 1
            else:
                return res
    return res

def select_items_faster(temp_dict, n):
    """Select n items from the dictionary"""
    items = temp_dict.items()
    items.sort()
    return list(itertools.chain.from_iterable(val for (_, val) in itertools.islice(items, n)))

test_dict = dict((x, ["a"] * int(x / 500)) for x in range(1000))
test_n = 300
In [13]: %timeit select_items_original(test_dict, test_n)
1000 loops, best of 3: 293 us per loop
In [14]: %timeit select_items_faster(test_dict, test_n)
1000 loops, best of 3: 203 us per loop
Swapping itertools.islice for [:n] doesn't help much:
def select_items_faster_slice(temp_dict, n):
    """Select n items from the dictionary"""
    items = temp_dict.items()
    items.sort()
    return list(itertools.chain.from_iterable(val for (_, val) in items[:n]))
In [16]: %timeit select_items_faster_slice(test_dict, test_n)
1000 loops, best of 3: 210 us per loop
Using sorted doesn't make much difference either:
In [18]: %timeit select_items_faster_sorted(test_dict, test_n)
1000 loops, best of 3: 213 us per loop
In [19]: print _17
def select_items_faster_sorted(temp_dict, n):
    """Select n items from the dictionary"""
    return list(itertools.chain.from_iterable(val for (_, val) in itertools.islice(sorted(temp_dict.items()), n)))
But combining map with __getitem__ is quite a bit faster:
In [22]: %timeit select_items_faster_map_getitem(test_dict, test_n)
10000 loops, best of 3: 90.7 us per loop
In [23]: print _20
def select_items_faster_map_getitem(temp_dict, n):
    """Select n items from the dictionary"""
    keys = temp_dict.keys()
    keys.sort()
    return list(itertools.chain.from_iterable(map(temp_dict.__getitem__, keys[:n])))
And replacing list(itertools.chain.from_iterable(...)) with some trickery gives a decent further speedup:
In [28]: %timeit select_items_faster_map_getitem_list_extend(test_dict, test_n)
10000 loops, best of 3: 74.9 us per loop
In [29]: print _27
def select_items_faster_map_getitem_list_extend(temp_dict, n):
    """Select n items from the dictionary"""
    keys = temp_dict.keys()
    keys.sort()
    result = []
    filter(result.extend, map(temp_dict.__getitem__, keys[:n]))
    return result
And replacing map and the slice with their itertools counterparts is slightly faster still:
In [31]: %timeit select_items_faster_map_getitem_list_extend_iterables(test_dict, test_n)
10000 loops, best of 3: 72.8 us per loop
In [32]: print _30
def select_items_faster_map_getitem_list_extend_iterables(temp_dict, n):
    """Select n items from the dictionary"""
    keys = temp_dict.keys()
    keys.sort()
    result = []
    filter(result.extend, itertools.imap(temp_dict.__getitem__, itertools.islice(keys, n)))
    return result
This is about as fast as I expect it can get: Python function calls are comparatively slow in CPython, and this version minimizes the number of Python function calls in the inner loop.
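Note that the filter(result.extend, ...) and keys.sort() idioms are Python 2 specific: in Python 3, filter is lazy (so the side-effecting extend calls would never run) and dict.keys() is a view with no .sort() method. A straightforward Python 3 port of the same idea (my adaptation, not part of the original answer) would be:

```python
from itertools import islice

def select_items_faster_py3(temp_dict, n):
    """Python 3 port of the map/__getitem__ approach (my adaptation)."""
    result = []
    extend = result.extend  # hoist the bound method out of the loop
    # sorted() replaces keys.sort(); islice limits work to the first n keys
    for key in islice(sorted(temp_dict), n):
        extend(temp_dict[key])
    return result

test_dict = {x: [x, x] for x in range(10)}
print(select_items_faster_py3(test_dict, 3))  # [0, 0, 1, 1, 2, 2]
```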
Notes:
- Since the asker gave no hints about the input data, I had to guess, and I may have guessed wrong; that could significantly change what counts as "fast" here.
- All of my implementations return n - 1 items instead of n.
Edit: profiling J.F. Sebastian's code with the same method:
In [2]: %timeit select_items_heapq(test_dict, test_n)
1000 loops, best of 3: 572 us per loop
In [3]: print _1
from itertools import *
import heapq

def select_items_heapq(temp_dict, n):
    return list(islice(chain.from_iterable(imap(temp_dict.get, heapq.nsmallest(n, temp_dict))), n))
And TokenMacGuy's code:
In [5]: %timeit select_items_tokenmacguy_first(test_dict, test_n)
1000 loops, best of 3: 201 us per loop
In [6]: %timeit select_items_tokenmacguy_second(test_dict, test_n)
1000 loops, best of 3: 730 us per loop
In [7]: print _4
def select_items_tokenmacguy_first(m, n):
    k, v, r = m.keys(), m.values(), range(len(m))
    r.sort(key=k.__getitem__)
    return [v[i] for i in r[:n]]

import heapq

def select_items_tokenmacguy_second(m, n):
    k, v, r = m.keys(), m.values(), range(len(m))
    smallest = heapq.nsmallest(n, r, k.__getitem__)
    for i, ind in enumerate(smallest):
        smallest[i] = v[ind]
    return smallest
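To see what heapq.nsmallest with a key function is doing in that second version, here is a minimal standalone illustration (my own toy data; list() calls added so it also runs on Python 3):

```python
import heapq

m = {"b": 2, "d": 4, "a": 1, "c": 3}
k = list(m.keys())      # list() needed on Python 3, where these are views
v = list(m.values())
# indices of the 2 lexicographically smallest keys, without a full sort
idx = heapq.nsmallest(2, range(len(m)), key=k.__getitem__)
print([v[i] for i in idx])  # [1, 2] -- the values for keys 'a' and 'b'
```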
None of the answers given so far does what the asker actually wants.
The data here is a dictionary of sequences, and the desired result is a list of the first n elements of the dictionary's values, taken in sorted-key order.
So if the data is:
{1: [1, 2, 3], 2: [4, 5, 6]}
then with n = 5 the result should be:
[1, 2, 3, 4, 5]
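That spec can be expressed directly with itertools; a quick sanity sketch (separate from the timing script):

```python
from itertools import chain, islice

data = {2: [4, 5, 6], 1: [1, 2, 3]}
# flatten the values in sorted-key order, then stop after n elements
flat = chain.from_iterable(data[key] for key in sorted(data))
print(list(islice(flat, 5)))  # [1, 2, 3, 4, 5]
```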
With that requirement in mind, here is a script comparing the original function with a (slightly) optimized new version:
from timeit import timeit

def select_items_old(temp_dict, n):
    res = []
    sort_keys = sorted(temp_dict.keys())
    count = 0
    for key in sort_keys:
        for pair in temp_dict[key]:
            if count < n:
                res.append(pair)
                count += 1
            else:
                return res
    return res

def select_items_new(data, limit):
    count = 0
    result = []
    extend = result.extend
    for key in sorted(data.keys()):
        value = data[key]
        extend(value)
        count += len(value)
        if count >= limit:
            break
    return result[:limit]

# dict comprehensions need Python 2.7+; use dict() to stay 2.6-compatible
data = dict((x, range(10)) for x in range(1000))

def compare(*args):
    number = 1000
    for func in args:
        name = func.__name__
        print ('test: %s(data, 12): %r' % (name, func(data, 12)))
        code = '%s(data, %d)' % (name, 300)
        duration = timeit(
            code, 'from __main__ import %s, data' % name, number=number)
        print ('time: %s: %.2f usec/pass\n' % (code, 1000000 * duration / number))

compare(select_items_old, select_items_new)
Output:
test: select_items_old(data, 12): [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1]
time: select_items_old(data, 300): 163.81 usec/pass
test: select_items_new(data, 12): [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1]
time: select_items_new(data, 300): 67.74 usec/pass
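One check worth doing (my own addition, not part of the answer) is that select_items_new matches a straightforward itertools reference, including at the boundaries, since the extend-then-trim approach deliberately over-collects before slicing:

```python
import random
from itertools import chain, islice

def select_items_new(data, limit):
    # same logic as the version benchmarked above
    count = 0
    result = []
    extend = result.extend
    for key in sorted(data.keys()):
        value = data[key]
        extend(value)
        count += len(value)
        if count >= limit:
            break
    return result[:limit]

def reference(data, limit):
    # spec: flatten values in sorted-key order, take the first `limit`
    return list(islice(chain.from_iterable(data[k] for k in sorted(data)), limit))

random.seed(0)
data = dict((k, [random.randrange(100) for _ in range(random.randrange(4))])
            for k in range(50))
for limit in (0, 1, 7, 10 ** 6):  # includes limit > total element count
    assert select_items_new(data, limit) == reference(data, limit)
print("ok")
```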