Polars Dataframe 是线程安全的吗?

1 投票
1 回答
94 浏览
提问于 2025-04-13 19:36

Polars 是一个很有前景的 Python 库,它有很多不错的功能,比如可以并行处理和快速计算。这些内容在文档中都有提到。

不过,有一点不是很明显,那就是 Polars 数据框的线程安全性,特别是在修改数值的时候。

我听说 Pandas 数据框不是线程安全的,这意味着你不应该从多个线程去修改一个 Pandas 数据框。

假设有这样的代码:

import threading
import polars as pl

df = pl.DataFrame({'a': range(5), 'b': range(5)})


def work_parallelly():
    for _ in range(100000):
        df[2, 'b'] += 1


thread = threading.Thread(target=work_parallelly)
thread.start()
for _ in range(100000):
    df[2, 'a'] += 1

thread.join()

print(df)

这样做算不算安全呢?结果看起来好像是正常工作的。

shape: (5, 2)
┌────────┬────────┐
│ a      ┆ b      │
│ ---    ┆ ---    │
│ i64    ┆ i64    │
╞════════╪════════╡
│ 0      ┆ 0      │
│ 1      ┆ 1      │
│ 100002 ┆ 100002 │
│ 3      ┆ 3      │
│ 4      ┆ 4      │
└────────┴────────┘

另外,.collect() 这个方法算不算线程安全呢?

1 个回答

1

如果你的代码使用多个Python线程来修改一个数据框(dataframe)中的值,那么你可能会遇到数据损坏的问题。

不过,对于大多数情况来说,这个问题是比较容易解决的。你可以使用简单的线程锁(threading.Lock类),这样可以确保同一时间只有一个Python线程在修改这个数据框。

如果你的Python代码是多线程的,并且需要处理对同一个数据框的多个修改,那么还有一个简单的解决办法。你可以不急于在收集到需要执行的操作信息时就立即应用修改,而是先把这些操作的描述放到一个线程队列(queue.Queue)里。然后,保持一个单独的工作线程来读取这个队列,并按顺序执行这些修改的Python代码。(这样,Polars会自动并行处理每个操作,但如果这些操作来自同一个Python线程,它们会一个接一个地执行。)

不过,这段代码在你的例子中会慢很多,因为每次“+= 1”的操作都需要经过队列中的参数序列化和反序列化,以及执行动作。(另一方面,你在例子中提到的Python循环中的“+=”操作是线程安全的,因为Python会确保同一时间只有一个纯Python的“+=”操作在运行,这要归功于全局解释器锁(GIL)。但是,对于广播和数据框中的其他内部转换,这种安全性就不适用了):

import threading
import polars as pl
from queue import Queue
from operator import iadd

df = pl.DataFrame({'a': range(5), 'b': range(5)})
q = Queue()
_SENTINEL = None

def worker():
    while True:
        op, args = q.get()

        if op is _SENTINEL:
            break
        elif callable(op):
            op(*args)
        elif isinstance(op, str):
            # inplace augmented operators for elements
            # of an array can't be passed as functions: -
            # so we can create a mini protocol, where the
            # operation name comes as a string:
            match op:
                case "iadd":
                    args[0][args[1], args[2]] += args[3]
                case _:
                    pass



def work_parallelly(col):
    for _ in range(100_000):
        q.put(("iadd", (df, 2, col, 1)))


worker_thread = threading.Thread(target=worker)
thread1 = threading.Thread(target=work_parallelly, args=("a",))
thread2 = threading.Thread(target=work_parallelly, args=("b",))
worker_thread.start()
thread1.start()
thread2.start()
thread1.join()
thread2.join()
q.put((None, None))
worker_thread.join()


print(df)

撰写回答