Out of memory when using asyncio and pandas.to_sql()
I'm trying to fetch data from an API and save that data to a SQL Server table at the same time, but I'm running into memory problems. I think the issue is that I await all the fetch tasks first and only then await the save tasks. Here's my code:
import asyncio
import polars as pl
import pandas as pd
import itertools


async def save_to_land(transactions, table_name, conn_object, semaphore):
    def process_and_save(transactions):
        df = pl.DataFrame(transactions, infer_schema_length=2000)
        # some processing with polars
        # ...
        df.to_pandas().to_sql(table_name, conn_object, if_exists='append', index=False)

    if transactions:
        async with semaphore:
            # run sync code in separate thread to not block async event loop
            await asyncio.to_thread(process_and_save, transactions)


async def get_transactions(api, practice_id, year, incremental, with_deleted, date_modified, semaphore):
    async with semaphore:
        transactions = await api.get_transactions(practice_id, year, incremental=incremental, with_deleted=with_deleted, date_modified=date_modified)
        await asyncio.sleep(0.5)
        return transactions


async def get_transactions_and_save_to_land(api, practice_ids, table_name, conn_object, start_year, end_year, incremental, with_deleted, date_modified, batch_size=100000):
    # for concurrent api requests
    sem1 = asyncio.Semaphore(4)
    # for concurrent db writes
    sem2 = asyncio.Semaphore(3)

    if not incremental:
        get_tasks = []
        for practice_id, year in itertools.product(practice_ids, range(start_year, end_year + 1)):
            task = asyncio.create_task(get_transactions(api, practice_id, year, incremental=incremental, with_deleted=with_deleted, date_modified=date_modified, semaphore=sem1))
            get_tasks.append(task)

        results = await asyncio.gather(*get_tasks)

        save_tasks = []
        batch_transactions = []
        for transactions in results:
            batch_transactions.extend(transactions)
            if len(batch_transactions) >= batch_size:
                task = asyncio.create_task(save_to_land(batch_transactions, table_name, conn_object, semaphore=sem2))
                save_tasks.append(task)
                batch_transactions = []

        # for any remaining transactions
        if batch_transactions:
            save_tasks.append(asyncio.create_task(save_to_land(batch_transactions, table_name, conn_object, semaphore=sem2)))

        await asyncio.gather(*save_tasks)
How can I fix this? Could someone with more asyncio experience give me some advice? Also note that I run the pandas.to_sql() call with asyncio.to_thread(), since it's blocking code.
1 Answer
Well, you are reading all of the results before you write any of them. If you don't have enough memory to hold all of the results, you're going to blow up.

You might look into asyncio.as_completed. The last part of your code could look like this:
# delete: results = await asyncio.gather(...)
save_tasks = []
batch_transactions = []
for coro in asyncio.as_completed(get_tasks):
    transactions = await coro  # as_completed yields awaitables, not results
    batch_transactions.extend(transactions)
    ... rest of code as written ...
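Filled in against the names in the question, the whole tail section might look like this (a sketch, not tested against your API):

save_tasks = []
batch_transactions = []
for coro in asyncio.as_completed(get_tasks):
    # each result is consumed as soon as its fetch finishes, so memory holds
    # roughly one batch plus the in-flight fetches, not every result at once
    transactions = await coro
    batch_transactions.extend(transactions)
    if len(batch_transactions) >= batch_size:
        save_tasks.append(asyncio.create_task(
            save_to_land(batch_transactions, table_name, conn_object, semaphore=sem2)))
        batch_transactions = []
if batch_transactions:
    save_tasks.append(asyncio.create_task(
        save_to_land(batch_transactions, table_name, conn_object, semaphore=sem2)))
await asyncio.gather(*save_tasks)

Note this still lets save tasks (and their batches) pile up if the writes fall behind the reads, which is what the queue variant below addresses.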
There is an alternative that also fixes another problem.

Instead of having your readers return transactions, have them put the results onto a queue: await q.put(transactions). You can decide whether q is a global or gets passed in as a parameter.

Then you could continue with:
task_count = len(get_tasks)
for _ in range(task_count):
    transactions = await q.get()
    q.task_done()
    ... rest of processing the same ...
... wait for all tasks to finish ...
This gets you two benefits.

First, you can define q as asyncio.Queue(maxsize=5) (use the asyncio queue here rather than queue.Queue, whose blocking put/get would stall the event loop). Even with your semaphores throttling both the reads and the writes, reading can still far outrun writing. Forcing the readers to push their results through a bounded queue means they can never get too far ahead of the writers, as the toy sketch below shows.
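A minimal, self-contained demonstration of that backpressure (the counts and timings are made up):

import asyncio

async def producer(q):
    for i in range(20):
        await q.put(i)            # suspends whenever 5 items are already waiting
        print(f"produced {i}")

async def consumer(q):
    while True:
        await q.get()
        await asyncio.sleep(0.1)  # pretend each "write" is slow
        q.task_done()

async def main():
    q = asyncio.Queue(maxsize=5)
    c = asyncio.create_task(consumer(q))
    await producer(q)
    await q.join()                # block until every queued item is processed
    c.cancel()

asyncio.run(main())

The first five puts succeed immediately; after that the producer only advances as fast as the consumer drains.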
Second, you no longer need to track the read tasks and the write tasks separately. You can just throw them all into one list and gather them at the end. With a little rewriting, you could even use a TaskGroup (Python 3.11+).
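Putting the pieces together, here is one possible shape for the whole pipeline. It reuses save_to_land from the question; fetch_one, consume, and pipeline are hypothetical names, and the incremental/with_deleted/date_modified parameters are dropped for brevity:

import asyncio
import itertools

async def fetch_one(api, practice_id, year, q, semaphore):
    # producer: fetch one (practice_id, year) slice, then queue it
    async with semaphore:
        transactions = await api.get_transactions(practice_id, year)
        await asyncio.sleep(0.5)
        # put while still holding the semaphore: when the queue is full we
        # suspend here and no replacement fetch starts, so at most four
        # results sit in producers plus five in the queue
        await q.put(transactions)

async def consume(q, table_name, conn_object, semaphore, task_count, batch_size):
    # consumer: drain the queue, batch the rows, write each full batch
    batch = []
    for _ in range(task_count):
        batch.extend(await q.get())
        q.task_done()
        if len(batch) >= batch_size:
            await save_to_land(batch, table_name, conn_object, semaphore)
            batch = []
    if batch:  # any remaining transactions
        await save_to_land(batch, table_name, conn_object, semaphore)

async def pipeline(api, practice_ids, years, table_name, conn_object, batch_size=100000):
    sem_api = asyncio.Semaphore(4)  # concurrent api requests
    sem_db = asyncio.Semaphore(3)   # concurrent db writes
    q = asyncio.Queue(maxsize=5)    # the bounded queue provides the backpressure
    pairs = list(itertools.product(practice_ids, years))
    async with asyncio.TaskGroup() as tg:  # Python 3.11+
        for practice_id, year in pairs:
            tg.create_task(fetch_one(api, practice_id, year, q, sem_api))
        tg.create_task(consume(q, table_name, conn_object, sem_db, len(pairs), batch_size))

One deliberate simplification: the consumer awaits each save before pulling more from the queue, so writes are serialized. If you want overlapping writes, you could create the save calls as tasks on the TaskGroup instead, at the cost of holding more batches in memory at once.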