Out of memory when using asyncio and pandas.to_sql()

I'm trying to fetch data from an API and save it to a SQL Server table concurrently, but I'm running into memory problems. I think the issue is that I await all of the fetch tasks first and only then await the save tasks. Here's my code:

import asyncio
import polars as pl
import pandas as pd
import itertools


async def save_to_land(transactions, table_name, conn_object, semaphore):

    def process_and_save(transactions):
        df = pl.DataFrame(transactions, infer_schema_length=2000)
        # some processing with polars
        #  ...
        df.to_pandas().to_sql(table_name, conn_object, if_exists='append', index=False)

    if transactions:
        async with semaphore:
            # run sync code in separate thread to not block async event loop
            await asyncio.to_thread(process_and_save, transactions)

async def get_transactions(api, practice_id, year, incremental, with_deleted, date_modified, semaphore):
    async with semaphore:
        transactions = await api.get_transactions(practice_id, year, incremental=incremental, with_deleted=with_deleted, date_modified=date_modified)
        await asyncio.sleep(0.5)
        return transactions
    
async def get_transactions_and_save_to_land(api, practice_ids, table_name, conn_object, start_year, end_year, incremental, with_deleted, date_modified, batch_size=100000):

    # for concurrent api requests
    sem1 = asyncio.Semaphore(4)
    # for concurrent db writes
    sem2 = asyncio.Semaphore(3)

    if not incremental:
        get_tasks = []
        for practice_id, year in itertools.product(practice_ids, range(start_year, end_year + 1)):
            task = asyncio.create_task(get_transactions(api, practice_id, year, incremental=incremental, with_deleted=with_deleted, date_modified=date_modified, semaphore=sem1))
            get_tasks.append(task)
        
        results = await asyncio.gather(*get_tasks) 
        
        save_tasks = []
        batch_transactions = []
        for transactions in results:
            batch_transactions.extend(transactions)
            if len(batch_transactions) >= batch_size:
                task = asyncio.create_task(save_to_land(batch_transactions, table_name, conn_object, semaphore=sem2))
                save_tasks.append(task)
                batch_transactions = []

        # for any remaining transactions
        if batch_transactions:
            save_tasks.append(asyncio.create_task(save_to_land(batch_transactions, table_name, conn_object, semaphore=sem2)))

        await asyncio.gather(*save_tasks) 

How can I fix this? Could someone with more asyncio experience give me some pointers? Also note that I run the pandas.to_sql() part with asyncio.to_thread() because it's blocking code.

1 Answer

Well, you're reading all of the results before you write anything. If you don't have enough memory to hold all of the results, you'll run out.

Consider using asyncio.as_completed. The last part of your code could look like this:

        # delete: results = await asyncio.gather(...)
        save_tasks = []
        batch_transactions = []
        # as_completed yields awaitables, so await each one to get its result
        for coro in asyncio.as_completed(get_tasks):
            transactions = await coro
            batch_transactions.extend(transactions)
            ... rest of code as written ...
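
For reference, here is a minimal, self-contained sketch of that pattern; fetch_one and save_batch are hypothetical stand-ins for your API call and DB write, and the counts are arbitrary:

import asyncio
import random

async def fetch_one(i, semaphore):
    # stand-in for api.get_transactions()
    async with semaphore:
        await asyncio.sleep(random.random())
        return [f"txn-{i}-{n}" for n in range(10)]

async def save_batch(batch, semaphore):
    # stand-in for the blocking polars/pandas to_sql() work
    async with semaphore:
        await asyncio.to_thread(print, f"saving {len(batch)} rows")

async def main(batch_size=25):
    sem_fetch = asyncio.Semaphore(4)
    sem_save = asyncio.Semaphore(3)
    get_tasks = [asyncio.create_task(fetch_one(i, sem_fetch)) for i in range(20)]

    save_tasks, batch = [], []
    # batches are flushed while other fetches are still in flight,
    # instead of holding every result in memory first
    for coro in asyncio.as_completed(get_tasks):
        batch.extend(await coro)
        if len(batch) >= batch_size:
            save_tasks.append(asyncio.create_task(save_batch(batch, sem_save)))
            batch = []
    if batch:
        save_tasks.append(asyncio.create_task(save_batch(batch, sem_save)))
    await asyncio.gather(*save_tasks)

asyncio.run(main())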

There is an alternative, which also solves another problem.

Instead of having your reader return transactions, have it put them onto a queue: await q.put(transactions). (In asyncio code that should be an asyncio.Queue rather than a queue.Queue, so that a full queue suspends the coroutine instead of blocking the event loop.) You can decide whether q is a global or is passed in as a parameter.

Then you can proceed with:

    task_count = len(get_tasks)
    for _ in range(task_count):
        transactions = await q.get()
        q.task_done()
        ... rest of processing the same ...
    ... wait for all tasks to finish ...

This has two benefits.

First, you can define q as asyncio.Queue(maxsize=5). Even with your reads and writes throttled by semaphores, the reads may still far outpace the writes. Forcing the readers to put their results onto a bounded queue means they can never get too far ahead of the writers, so unconsumed results don't pile up in memory.

Second, you no longer need to track the read tasks and the write tasks separately. You can keep them all in one list and gather them at the end. With a little rewriting, you could even use a TaskGroup directly.
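
Putting that together, a sketch of the queue-based version might look like this; it assumes asyncio.Queue (per the note above), TaskGroup from Python 3.11+, and the same hypothetical fetch_one stand-in:

import asyncio

async def fetch_one(i, q, semaphore):
    # stand-in for api.get_transactions()
    async with semaphore:
        await asyncio.sleep(0.1)
        # suspends (without blocking the event loop) whenever the queue is full
        await q.put([f"txn-{i}-{n}" for n in range(10)])

async def consume(q, task_count, batch_size):
    batch = []
    for _ in range(task_count):
        batch.extend(await q.get())
        q.task_done()
        if len(batch) >= batch_size:
            # stand-in for the blocking to_sql() write, run off the event loop
            await asyncio.to_thread(print, f"saving {len(batch)} rows")
            batch = []
    if batch:
        await asyncio.to_thread(print, f"saving {len(batch)} rows")

async def main():
    q = asyncio.Queue(maxsize=5)  # the bounded queue provides the backpressure
    sem = asyncio.Semaphore(4)    # still throttles concurrent API requests
    n = 20
    # TaskGroup waits for every task and propagates any errors
    async with asyncio.TaskGroup() as tg:
        for i in range(n):
            tg.create_task(fetch_one(i, q, sem))
        tg.create_task(consume(q, n, batch_size=25))

asyncio.run(main())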
