在不预处理的情况下将生成器分块

81 投票
12 回答
33556 浏览
提问于 2025-04-18 11:50

这个问题和这个以及这个有关,但那些都是在提前遍历生成器,这正是我想避免的。

我想把生成器分成几个小块。具体要求是:

  • 不要填充小块:如果剩下的元素少于小块的大小,最后一块可以小一点。
  • 不要提前遍历生成器:计算元素的过程很耗费资源,应该只由使用这些元素的函数来完成,而不是由分块的函数来做。
  • 这当然意味着:不要在内存中累积(不要使用列表)。

我尝试了以下代码:

def head(iterable, max=10):
    for cnt, el in enumerate(iterable):
        yield el
        if cnt >= max:
            break

def chunks(iterable, size=10):
    i = iter(iterable)
    while True:
        yield head(i, size)

# Sample generator: the real data is much more complex, and expensive to compute
els = xrange(7)

for n, chunk in enumerate(chunks(els, 3)):
    for el in chunk:
        print 'Chunk %3d, value %d' % (n, el)

这个方法在某种程度上是有效的:

Chunk   0, value 0
Chunk   0, value 1
Chunk   0, value 2
Chunk   1, value 3
Chunk   1, value 4
Chunk   1, value 5
Chunk   2, value 6
^CTraceback (most recent call last):
  File "xxxx.py", line 15, in <module>
    for el in chunk:
  File "xxxx.py", line 2, in head
    for cnt, el in enumerate(iterable):
KeyboardInterrupt

但是……它永远不会停止(我得按^C来中断),因为有个while True循环。我希望在生成器被消耗完时停止这个循环,但我不知道怎么检测这种情况。我尝试抛出一个异常:

class NoMoreData(Exception):
    pass

def head(iterable, max=10):
    for cnt, el in enumerate(iterable):
        yield el
        if cnt >= max:
            break
    if cnt == 0 : raise NoMoreData()

def chunks(iterable, size=10):
    i = iter(iterable)
    while True:
        try:
            yield head(i, size)
        except NoMoreData:
            break

# Sample generator: the real data is much more complex, and expensive to compute    
els = xrange(7)

for n, chunk in enumerate(chunks(els, 2)):
    for el in chunk:
        print 'Chunk %3d, value %d' % (n, el)

但这样异常只在消费者的上下文中被抛出,这不是我想要的(我希望保持消费者代码的整洁)。

Chunk   0, value 0
Chunk   0, value 1
Chunk   0, value 2
Chunk   1, value 3
Chunk   1, value 4
Chunk   1, value 5
Chunk   2, value 6
Traceback (most recent call last):
  File "xxxx.py", line 22, in <module>
    for el in chunk:
  File "xxxx.py", line 9, in head
    if cnt == 0 : raise NoMoreData
__main__.NoMoreData()

我该如何在chunks函数中检测生成器是否已经耗尽,而不提前遍历它呢?

12 个回答

4

可以试试用 itertools.islice 这个工具:

import itertools

els = iter(xrange(7))

print list(itertools.islice(els, 2))
print list(itertools.islice(els, 2))
print list(itertools.islice(els, 2))
print list(itertools.islice(els, 2))

这样可以得到:

[0, 1]
[2, 3]
[4, 5]
[6]

下面是一个分块器的实现,还有一些测试:

import itertools
from typing import Iterable


def chunker(iterable: Iterable, size: int) -> Iterable[list]:
    iterable = iter(iterable)
    while True:
        chunk = list(itertools.islice(iterable, size))
        if not chunk:
            break
        yield chunk
    

assert list(chunker(range(10), 3)) == [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
assert list(chunker([], 3)) == []
7

more-itertools 提供了两个功能,分别是 chunkedichunked,可以帮助我们实现一些目标。这些内容在 Python 3 itertools 文档页面 上有提到。

chunked 和 ichunked 示例

16

如果你不能预先处理一个块,那就得把它拆开来做,Python没有完全为你做好的工具

关于时间的比较,请看答案底部。

为了把尽可能多的工作交给C层(这通常是CPython中最快的解决方案),同时确保未使用的块被丢弃,修改过的Moses的解决方案是最佳选择(groupby负责清理之前组中未被使用的元素):

from functools import partial
from itertools import cycle, groupby
from operator import itemgetter

def chunks(iterable, size=10):
    c = cycle((False,) * size + (True,) * size)  # Make a cheap iterator that will group in groups of size elements
    # groupby will pass an element to next as its second argument, but because
    # c is an infinite iterator, the second argument will never get used
    return map(itemgetter(1), groupby(iterable, partial(next, c)))

如果你从来不需要留下未使用的组,使用修改过的tobiak_k的答案可以稍微快一点:

def chunks(iterable, size=10):
    it = iter(iterable)
    make_islice = partial(islice, it, size - 1)
    for first in it:
        yield chain((first,), make_islice())

不过,这种增益很小(见时间比较),所以我建议只有在你想要将第一未使用元素作为下一个组的第一个元素时才使用;否则,还是用基于groupby的解决方案,它能给你更好的保证。

话虽如此,你确定这些组需要懒加载吗?因为如果每个块可以作为tuple生成,还有更好更简单的解决方案,可以急切地消费每个块,但整体输入是懒加载的。以下是按Python版本的最佳选项:

未来最佳的半急切解决方案(需要Python 3.12或更高版本,预计在2023年10月发布),完全不需要手动编写代码,并且在所有情况下都是最快的如果你可以预处理一个

从Python 3.12开始,有一个内置的工具,itertools.batched。参数与下面的chunker配方相反,但行为是一样的(将数据分批成长度为ntuple,最后一批可能不完整):

from itertools import batched

for batch in batched('ABCDEFG', 3):
    print(batch)

将输出:

('A', 'B', 'C')
('D', 'E', 'F')
('G',)

答案底部的时间比较显示,如果你可以依赖3.12+并且可以将组作为tuple生成,那它绝对是最好的解决方案(剧透:它的执行时间是其他等效懒加载解决方案的15-30%,具体取决于组是否被完全消费)。它是在C层实现的,实现利用了Python层无法实现的性能优化,使其轻松超越任何在Python层实现的解决方案。特别是:

  • tupleislice解决方案不同,它总是为每个批次预先分配一个大小为ntuple,并直接填充(而tupleislice涉及逐个构建tuple,每个批次都耗费时间;它使用的是摊销增长,但每个批次仍可能涉及几次重新分配)
  • zip_longest解决方案不同(由于zip_longest的实现,它确实会预先设置tuple的大小,可能是其在大多数情况下更好性能的来源):
    1. (每个批次的浪费) 每个批次只需查找一次.__next__,而不是每个批次查找n次(这是一个小成本,考虑到C层查找,但在每个批次中都要付出),并且
    2. (与最终不完整批次大小相关的一次性浪费) 它不需要为fillvalue创建一个哨兵,也不需要在最后一批后查找它(这至少需要O(log n)的二分查找,3.10之前需要O(n)的线性查找,加上O(n)的切片工作);batched在提取元素时会计算元素数量,作为跟踪插入位置的副作用,因此当迭代器耗尽时,它会立即停止,并可以直接realloctuple以匹配能够提取的元素数量(避免了新的tuple,并且通常避免了任何复制)。

在3.12之前,对于小到中等的n和/或大量批次的最快解决方案(对于3.10-3.11,即使在极端情况下也没有明显变慢):

当块大小通常较小时,最快的解决方案是这个,改编自rhettg的答案

from itertools import takewhile, zip_longest

def chunker(n, iterable):
    '''chunker(3, 'ABCDEFG') --> ('A', 'B', 'C'), ('D', 'E', 'F'),  ('G',)'''
    fillvalue = object()  # Anonymous sentinel object that can't possibly appear in input
    args = (iter(iterable),) * n
    for x in zip_longest(*args, fillvalue=fillvalue):
        if x[-1] is fillvalue:
            # takewhile optimizes a bit for when n is large and the final
            # group is small; at the cost of a little performance, you can
            # avoid the takewhile import and simplify to:
            # yield tuple(v for v in x if v is not fillvalue)
            yield tuple(takewhile(lambda v: v is not fillvalue, x))
        else:
            yield x

如果性能至关重要,特别是当块大小较大且你需要做很多次时,并且你可以依赖3.10+(这是bisect添加了key参数支持的版本),你可以通过将O(n)takewhile替换为O(log n)的二分查找来稍微改进上述内容,添加from bisect import bisect到导入中,并更改:

yield tuple(takewhile(lambda v: v is not fillvalue, x))

为:

yield x[:bisect(x, False, key=lambda v: v is fillvalue)]  # 3.10+ only!

旧的答案(仍然非常快,因为将所有工作推到C层,但在除了极端情况外,略逊于基于zip_longest的解决方案[涉及小批量的巨大块大小],在常见情况下大约慢2倍):

通过使用纯C级内置函数(在CPython中),生成每个块不需要任何Python字节码(除非底层生成器是用Python实现的),这带来了巨大的性能优势。它确实会在返回之前遍历每个,但在返回的块之外不会进行任何预处理:

# Only needed on Py2, to get iterator-based map; Py3's is already iterator-based
from future_builtins import map

from itertools import islice, repeat, starmap, takewhile

# operator.truth is *significantly* faster than bool for the case of
# exactly one positional argument prior to 3.10; in 3.10+, you can
# just use bool (which is trivially faster than truth)
from operator import truth

def chunker(n, iterable):  # n is size of each chunk; last chunk may be smaller
    return takewhile(truth, map(tuple, starmap(islice, repeat((iter(iterable), n)))))

由于这有点复杂,下面是更易理解的版本:

def chunker(n, iterable):
    iterable = iter(iterable)
    while True:
        x = tuple(islice(iterable, n))
        if not x:
            return
        yield x

将对chunker的调用包装在enumerate中,可以让你为块编号,如果需要的话。

时间比较(在CPython 3.12.0,RHEL8上)

懒加载组并确保组的对齐,即使组没有完全消费的解决方案:
>>> from itertools import *
>>> from functools import partial
>>> from operator import itemgetter
>>> def chunks_moses(iterable, size=10):
...     c = count()
...     for _, g in groupby(iterable, lambda _: next(c)//size):
...         yield g
...
>>> def chunks_moses_optimized(iterable, size=10):
...    c = cycle((False,) * size + (True,) * size)
...    return map(itemgetter(1), groupby(iterable, partial(next, c)))
...
>>> %%timeit b = b'\0'*10000
... for grp in chunks_moses(b):
...     next(grp)  # Only consume one element, but rest of group skipped for you
...
1.35 ms ± 31.1 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

>>> %%timeit b = b'\0'*10000
... for grp in chunks_moses_optimized(b):
...     next(grp)  # Only consume one element, but rest of group skipped for you
469 μs ± 5.32 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

因此,将相同的基本工作推到C层,并避免不必要的Python级int操作(创建连续值,除法)大约减少了三分之一的工作量。

懒加载组,不消费未处理元素的解决方案

这些解决方案的行为与之前的解决方案相同,仅当所有组在下一个组开始之前完全消费时,因此我们将它们与之前的完全消费模式的解决方案进行比较:

# Relies on same imports as above
>>> def chunks_tobias(iterable, size=10):
...     iterator = iter(iterable)
...     for first in iterator:
...         yield chain([first], islice(iterator, size - 1))
... 
>>> def chunks_tobias_optimized(iterable, size=10):
...     it = iter(iterable)
...     make_islice = partial(islice, it, size - 1)
...     for first in it:
...         yield chain((first,), make_islice())
... 
>>> %%timeit b = b'\0'*10000; from collections import deque; consume = deque(maxlen=0).extend
... for grp in chunks_moses(b):
...     consume(grp)  # Consume all elements one at a time without storing them
...
1.26 ms ± 4.22 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

>>> %%timeit b = b'\0'*10000; from collections import deque; consume = deque(maxlen=0).extend
... for grp in chunks_moses_optimized(b):
...     consume(grp)
...
480 μs ± 11.2 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

>>> %%timeit b = b'\0'*10000; from collections import deque; consume = deque(maxlen=0).extend
... for grp in chunks_tobias(b):
...     consume(grp)
...
531 μs ± 15.4 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

>>> %%timeit b = b'\0'*10000; from collections import deque; consume = deque(maxlen=0).extend
... for grp in chunks_tobias_optimized(b):
...     consume(grp)
...
461 μs ± 24.3 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

因此,使用chainislice进行分块的速度稍微快于优化后的groupby,但代价是没有正确处理部分消费的组。

急切完成每个完整组后再返回的解决方案(但仍然懒加载整个迭代器)

如果你真的不能在内存中同时存储超过1-2个元素,这些解决方案就不适用了,但通常这不是一个真正的问题,所以首先考虑这些选项:

>>> def chunks_rhettg_optimized(iterable, size=10):
...     fillvalue = object()
...     args = (iter(iterable),) * size
...     for x in zip_longest(*args, fillvalue=fillvalue):
...         if x[-1] is fillvalue:
...             yield tuple(takewhile(lambda v: v is not fillvalue, x))
...         else:
...             yield x
...
>>> %%timeit b = b'\0'*10000; from collections import deque; consume = deque(maxlen=0).extend
... for grp in chunks_rhettg_optimized(b):
...     consume(grp)
...
195 μs ± 8.43 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

>>> %%timeit b = b'\0'*10000; from collections import deque; consume = deque(maxlen=0).extend
... for grp in batched(b, 10):
...     consume(grp)
...
142 μs ± 4.87 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

为了完整起见,当你只检查结果的单个元素时,这些解决方案的时间比较:

>>> %%timeit b = b'\0'*10000
... for grp in chunks_rhettg_optimized(b):
...     grp[0]
...
116 μs ± 712 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

>>> %%timeit b = b'\0'*10000
... for grp in batched(b, 10):
...     grp[0]
...
75 μs ± 186 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

时间信息总结:

  1. 如果你可以承担使用急切生成tuple的解决方案的内存/资源,而不是生成该组的懒加载迭代器,它比任何生成组迭代器的解决方案都快。生成tuple的时间仅需~30-45%,而如果你只使用每个结果的一部分,时间甚至可以低至~15%(当比较只使用每个组部分的懒加载组迭代器解决方案,但仍然跳过下一个组之前未消费的元素时,与只检查部分组的itertools.batched相比)。
  2. 如果你必须使用懒加载组迭代器解决方案,优化使用groupby+cycle+partial+itemgetter+map(这将确保在生成下一个组之前消费最后一个组,即使最后一个组没有被用户完全消费)几乎与优化使用chain+partial+islice一样快,并且保证未完全消费的组在生成下一个组之前仍会被消费(它并不保证如果下一个组没有被请求,组会被消费,但这通常是你想要的行为)。我唯一考虑不保证在下一个组之前完全消费组的简单更快版本的情况是当这是期望的行为时(如果你没有完全消费一个组,你想要下一个组以第一个未消费的元素开始)。
17

另一种创建分组或块的方法,而不是提前遍历生成器,是使用itertools.groupby,并通过一个使用itertools.count对象的键函数来实现。因为这个count对象和可迭代对象是独立的,所以可以轻松生成块,而不需要知道可迭代对象里面有什么。

每次调用groupby时,都会调用count对象的next方法,并通过将当前计数值除以块的大小来生成一个组/块的(后面跟着块中的项目)。

from itertools import groupby, count

def chunks(iterable, size=10):
    c = count()
    for _, g in groupby(iterable, lambda _: next(c)//size):
        yield g

每个由生成器函数产生的组/块g都是一个迭代器。不过,由于groupby对所有组使用了一个共享的迭代器,因此这些组的迭代器不能存储在列表或任何容器中,每个组的迭代器必须在下一个之前被消耗掉。

98

一种方法是先看看第一个元素,如果有的话,然后再创建并返回实际的生成器。

def head(iterable, max=10):
    first = next(iterable)      # raise exception when depleted
    def head_inner():
        yield first             # yield the extracted first element
        for cnt, el in enumerate(iterable):
            yield el
            if cnt + 1 >= max:  # cnt + 1 to include first
                break
    return head_inner()

在你的 chunk 生成器中使用这个,并像你处理自定义异常那样捕获 StopIteration 异常。


更新:这里有另一个版本,使用 itertools.islice 来替代大部分的 head 函数,并且使用了一个 for 循环。这个简单的 for 循环实际上做的事情和原始代码中复杂的 while-try-next-except-break 结构是 完全一样的,所以结果要 可读得多

def chunks(iterable, size=10):
    iterator = iter(iterable)
    for first in iterator:    # stops when iterator is depleted
        def chunk():          # construct generator for next chunk
            yield first       # yield element from for loop
            for more in islice(iterator, size - 1):
                yield more    # yield more elements from the iterator
        yield chunk()         # in outer generator, yield next chunk

而且我们可以做到更简洁,使用 itertools.chain 来替代内部生成器:

def chunks(iterable, size=10):
    iterator = iter(iterable)
    for first in iterator:
        yield chain([first], islice(iterator, size - 1))

撰写回答