Python 2.6 优化嵌套循环

2 投票
4 回答
1212 浏览
提问于 2025-04-17 09:24

我有一个函数,它的输入是一个字典和一个数字n。字典里的每一项都是一个包含一个或多个值的集合。这个函数需要对字典的键进行排序,并提取并返回n个值。因为这个函数会被频繁执行,所以我想对它进行优化。有什么建议吗?

def select_items(temp_dict, n):
  """Select n items from the dictionary"""
  res = []
  sort_keys = sorted(temp_dict.keys())
  count = 0

  for key in sort_keys:
    for pair in temp_dict[key]:
      if count < n:
        res.append(pair)
        count += 1
      else:
        return res

  return res

在这段代码中,我有一个计数和一个“如果语句”来控制选择的值的数量。有没有办法通过使用itertools中的某个函数或者其他方法来优化这段代码呢?

4 个回答

2

我觉得用列表推导式和返回生成器的方式更简洁、更易读。使用数组切片可以避免使用if条件。

def select_items(dic, n):
  return (dic[key] for key in sorted(dic.keys())[:n])

关于速度:我认为实际的sort调用可能是这里最大的瓶颈,不过你可能不需要担心这个,直到字典的大小变得很大。在那种情况下,你可能应该考虑一开始就保持字典的有序性——插入时会稍微复杂一些,但查找和选择会很快。一个例子是sorteddict。基于树的数据结构可能是另一个选择。

接下来是基准测试。初始设置,摘自David Wolever的精彩回答:

test_dict = dict((x, "a") for x in range(1000))
test_n = 300

你的版本:

%timeit select_items(test_dict, test_n)
1000 loops, best of 3: 334 us per loop

这个版本:

%timeit select_items(test_dict, test_n)
10000 loops, best of 3: 49.1 us per loop
6

这是我第一次尝试的结果(见 select_items_faster),速度几乎翻倍:

In [12]: print _11
import itertools

def select_items_original(temp_dict, n):
  """Select n items from the dictionary"""
  res = []
  sort_keys = sorted(temp_dict.keys())
  count = 0

  for key in sort_keys:
    for pair in temp_dict[key]:
      if count < n:
        res.append(pair)
        count += 1
      else:
        return res

  return res

def select_items_faster(temp_dict, n):
    """Select n items from the dictionary"""
    items = temp_dict.items()
    items.sort()

    return list(itertools.chain.from_iterable(val for (_, val) in itertools.islice(items, n)))

test_dict = dict((x, ["a"] * int(x / 500)) for x in range(1000))
test_n = 300


In [13]: %timeit select_items_original(test_dict, test_n)
1000 loops, best of 3: 293 us per loop


In [14]: %timeit select_items_faster(test_dict, test_n)
1000 loops, best of 3: 203 us per loop

itertools.islice 换成 [:n] 并没有太大帮助:

def select_items_faster_slice(temp_dict, n):
    """Select n items from the dictionary"""
    items = temp_dict.items()
    items.sort()

    return list(itertools.chain.from_iterable(val for (_, val) in items[:n]))

In [16]: %timeit select_items_faster_slice(test_dict, test_n)
1000 loops, best of 3: 210 us per loop

使用 sorted 也没什么效果:

In [18]: %timeit select_items_faster_sorted(test_dict, test_n)
1000 loops, best of 3: 213 us per loop


In [19]: print _17
def select_items_faster_sorted(temp_dict, n):
    """Select n items from the dictionary"""
    return list(itertools.chain.from_iterable(val for (_, val) in itertools.islice(sorted(temp_dict.items()), n)))

但是把 map__getitem__ 结合起来就快多了:

In [22]: %timeit select_items_faster_map_getitem(test_dict, test_n)
10000 loops, best of 3: 90.7 us per loop

In [23]: print _20
def select_items_faster_map_getitem(temp_dict, n):
    """Select n items from the dictionary"""
    keys = temp_dict.keys()
    keys.sort()
    return list(itertools.chain.from_iterable(map(temp_dict.__getitem__, keys[:n])))

list(itertools.chain.from_iterable) 换成一些神奇的方法,速度提升了不少:

In [28]: %timeit select_items_faster_map_getitem_list_extend(test_dict, test_n)
10000 loops, best of 3: 74.9 us per loop

In 29: print _27
def select_items_faster_map_getitem_list_extend(temp_dict, n):
    """Select n items from the dictionary"""
    keys = temp_dict.keys()
    keys.sort()
    result = []
    filter(result.extend, map(temp_dict.__getitem__, keys[:n]))
    return result

而用 itertools 的函数替换 map 和切片又能再快一点:

In [31]: %timeit select_items_faster_map_getitem_list_extend_iterables(test_dict, test_n)
10000 loops, best of 3: 72.8 us per loop

In [32]: print _30
def select_items_faster_map_getitem_list_extend_iterables(temp_dict, n):
    """Select n items from the dictionary"""
    keys = temp_dict.keys()
    keys.sort()
    result = []
    filter(result.extend, itertools.imap(temp_dict.__getitem__, itertools.islice(keys, n)))
    return result

这大概是我认为能达到的最快速度了,因为在 CPython 中,Python 函数调用比较慢,而这个方法尽量减少了内层循环中 Python 函数的调用次数。

注意

  • 由于提问者没有提供输入数据的任何线索,所以我只能猜测。我可能猜错了,这可能会大大改变“快”的定义。
  • 我所有的实现返回的是 n - 1 个项目,而不是 n。

编辑:使用相同的方法来分析 J.F. Sebastian 的代码:

In [2]: %timeit select_items_heapq(test_dict, test_n)
1000 loops, best of 3: 572 us per loop

In [3]: print _1
from itertools import *
import heapq

def select_items_heapq(temp_dict, n):
    return list(islice(chain.from_iterable(imap(temp_dict.get, heapq.nsmallest(n, temp_dict))),n))

还有 TokenMacGuy 的代码:

In [5]: %timeit select_items_tokenmacguy_first(test_dict, test_n)
1000 loops, best of 3: 201 us per loop

In [6]: %timeit select_items_tokenmacguy_second(test_dict, test_n)
1000 loops, best of 3: 730 us per loop

In [7]: print _4
def select_items_tokenmacguy_first(m, n):
    k, v, r = m.keys(), m.values(), range(len(m))
    r.sort(key=k.__getitem__)
    return [v[i] for i in r[:n]]

import heapq
def select_items_tokenmacguy_second(m, n):
    k, v, r = m.keys(), m.values(), range(len(m))
    smallest = heapq.nsmallest(n, r, k.__getitem__)
    for i, ind in enumerate(smallest):
        smallest[i] = v[ind]
    return smallest
1

到目前为止,给出的答案并没有符合用户的要求。

这里的数据是一个包含序列的字典,而我们想要的结果是按照字典的键排序后,取出字典值的前 n 个元素,形成一个列表。

所以如果数据是:

{1: [1, 2, 3], 2: [4, 5, 6]}

那么,如果 n = 5,结果应该是:

[1, 2, 3, 4, 5]

基于这个要求,这里有一个脚本,用来比较原来的函数和一个(稍微)优化过的新版本:

from timeit import timeit

def select_items_old(temp_dict, n):
  res = []
  sort_keys = sorted(temp_dict.keys())
  count = 0
  for key in sort_keys:
    for pair in temp_dict[key]:
      if count < n:
        res.append(pair)
        count += 1
      else:
        return res
  return res

def select_items_new(data, limit):
    count = 0
    result = []
    extend = result.extend
    for key in sorted(data.keys()):
        value = data[key]
        extend(value)
        count += len(value)
        if count >= limit:
            break
    return result[:limit]

data = {x:range(10) for x in range(1000)}

def compare(*args):
    number = 1000
    for func in args:
        name = func.__name__
        print ('test: %s(data, 12): %r' % (name, func(data, 12)))
        code = '%s(data, %d)' % (name, 300)
        duration = timeit(
            code, 'from __main__ import %s, data' % name, number=number)
        print ('time: %s: %.2f usec/pass\n' % (code, 1000000 * duration/number))

compare(select_items_old, select_items_new)

输出结果:

test: select_items_old(data, 12): [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1]
time: select_items_old(data, 300): 163.81 usec/pass

test: select_items_new(data, 12): [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1]
time: select_items_new(data, 300): 67.74 usec/pass

撰写回答