如何在Python中使用超时迭代AsyncIterator流而不取消流?

0 投票
1 回答
18 浏览
提问于 2025-04-13 00:10

我正在处理一个对象,它是一个 AsyncIterator[str]。这个对象从网络获取消息,并把这些消息作为字符串输出。我想为这个消息流创建一个包装器,能够缓存这些消息,并在固定的时间间隔内输出它们。

我的代码看起来是这样的:

async def buffer_stream(stream: AsyncIterator[str], buffer_time: Optional[float]) -> AsyncIterator[str]:
    """
    Buffer messages from the stream, and yields them at regular intervals.
    """
    last_sent_at = time.perf_counter()
    buffer = ''

    stop = False
    while not stop:
        time_to_send = False

        timeout = (
            max(buffer_time - (time.perf_counter() - last_sent_at), 0)
            if buffer_time else None
        )
        try:
            buffer += await asyncio.wait_for(
                stream.__anext__(),
                timeout=timeout
            )
        except asyncio.TimeoutError:
            time_to_send = True
        except StopAsyncIteration:
            time_to_send = True
            stop = True
        else:
            if time.perf_counter() - last_sent_at >= buffer_time:
                time_to_send = True

        if not buffer_time or time_to_send:
            if buffer:
                yield buffer
                buffer = ''
            last_sent_at = time.perf_counter()

从我能看出来的逻辑上是没问题的,但一旦到达第一个超时时间,它就会中断这个流,并提前退出,没等流处理完。

我觉得这可能是因为 asyncio.wait_for 特别说明了:

当发生超时时,会取消任务并抛出 TimeoutError。为了避免任务被取消,可以用 shield() 包裹它。

我尝试用 shield 包裹它:

buffer += await asyncio.wait_for(
    shield(stream.__anext__()),
    timeout=timeout
)

但这又因为其他原因出错了:RuntimeError: anext(): asynchronous generator is already running。根据我的理解,这意味着在尝试获取下一个值时,它还在处理上一个 anext(),这就导致了错误。

有没有合适的方法来实现这个呢?

演示链接: https://www.sololearn.com/en/compiler-playground/cBCVnVAD4H7g

1 个回答

0

你可以把 stream.__anext__() 的结果变成一个任务(或者更一般地说,变成一个未来的结果),然后等待它,直到超时或者返回一个结果为止:

async def buffer_stream(stream: AsyncIterator[str], buffer_time: Optional[float]) -> AsyncIterator[str]:
    last_sent_at = time.perf_counter()
    buffer = ''

    stop = False
    await_next = None
    while not stop:
        time_to_send = False

        timeout = (
            max(buffer_time - (time.perf_counter() - last_sent_at), 0)
            if buffer_time else None
        )
        if await_next is None:
            await_next = asyncio.ensure_future(stream.__anext__())
        try:
            buffer += await asyncio.wait_for(
                asyncio.shield(await_next),
                timeout=timeout
            )
        except asyncio.TimeoutError:
            time_to_send = True
        except StopAsyncIteration:
            time_to_send = True
            stop = True
        else:
            await_next = None
            if time.perf_counter() - last_sent_at >= buffer_time:
                time_to_send = True

        if not buffer_time or time_to_send:
            if buffer:
                yield buffer
                buffer = ''
            last_sent_at = time.perf_counter()

撰写回答