Polars 多进程的 map_batches UDF

5 投票
1 回答
106 浏览
提问于 2025-04-13 02:39

我想对 df 中的每个组应用一个 numba UDF,该函数为每个组返回一个与该组长度相同的向量:

import numba

# Small example frame: group "A" has 3 rows and group "B" has 2 rows, with
# each group's rows stored next to each other.
df = pl.DataFrame({"group": list("AAABB"), "index": [1, 3, 5, 1, 4]})

@numba.jit(nopython=True)
def UDF(array: np.ndarray, threshold: int) -> np.ndarray:
    """Flag positions where a resetting running sum reaches ``threshold``.

    Values of ``array`` are added one by one to a running sum. Whenever the
    sum is >= ``threshold``, that position is flagged with 1 and the sum
    restarts from 0.

    Returns a float array (``np.zeros`` default dtype) of 0/1 flags with the
    same length as ``array``. The polars call below casts it to UInt8.
    """
    n_rows = array.shape[0]
    flags = np.zeros(n_rows)
    running = 0

    for pos in range(n_rows):
        running += array[pos]
        if running >= threshold:
            flags[pos] = 1
            running = 0  # restart the sum after each flagged position

    return flags

# Run the numba UDF separately on every "group" window, cast its float flags
# to UInt8 and store them as "udf". The resulting DataFrame is not assigned
# here, so in a plain script the result would be discarded.
df.with_columns(
    pl.col("index")
    .map_batches(lambda s: UDF(s.to_numpy(), 5))
    .over("group")
    .cast(pl.UInt8)
    .alias("udf")
)

这个想法受到一篇介绍如何使用 multiprocessing 的文章启发。不过在我的例子里,UDF 是通过 over 窗口函数按组应用的。有没有高效的方法可以将上面的操作并行化?

期望的输出:

shape: (5, 3)
┌───────┬───────┬─────┐
│ group ┆ index ┆ udf │
│ ---   ┆ ---   ┆ --- │
│ str   ┆ i64   ┆ u8  │
╞═══════╪═══════╪═════╡
│ A     ┆ 1     ┆ 0   │
│ A     ┆ 3     ┆ 0   │
│ A     ┆ 5     ┆ 1   │
│ B     ┆ 1     ┆ 0   │
│ B     ┆ 4     ┆ 1   │
└───────┴───────┴─────┘

相关问题:

1 个回答

3

这里有一个例子,展示了如何使用 numba 及其并行处理功能(prange)来实现这个目标。该示例假设每个组恰好包含 n 行(这里 n=3),且同一组的行在 DataFrame 中连续排列:

from numba import njit, prange


@njit(parallel=True)
def UDF_nb_parallel(array, n, threshold):
    """Flag threshold crossings independently in consecutive chunks of ``n`` rows.

    The data is expected to be ordered so that every group occupies exactly
    ``n`` consecutive rows; each such chunk gets its own running sum. Whenever
    the running sum reaches ``threshold``, the position is flagged with 1 and
    the sum resets to 0. Chunks are processed in parallel with ``prange``.

    If ``array.size`` is not a multiple of ``n``, the trailing shorter chunk
    is processed as well instead of being silently left as zeros.

    Returns a uint8 array of 0/1 flags with the same length as ``array``.
    """
    result = np.zeros_like(array, dtype="uint8")
    size = array.size
    # Ceiling division so a final partial chunk is not dropped.
    n_chunks = (size + n - 1) // n

    for i in prange(n_chunks):
        accumulator = 0
        for j in range(i * n, (i + 1) * n):
            # Only the last chunk can run past the end of the array; a plain
            # bounds check avoids mixing signed/unsigned ints in min().
            if j >= size:
                break
            value = array[j]
            accumulator += value
            if accumulator >= threshold:
                result[j] = 1
                accumulator = 0

    return result

# One numba call over the whole "index" column (chunks of 3 rows, threshold 5);
# the flags are attached as the "new_udf" column and the frame is printed.
df = df.with_columns(
    pl.Series("new_udf", UDF_nb_parallel(df["index"].to_numpy(), 3, 5))
)
print(df)

输出结果:

shape: (9, 3)
┌───────┬───────┬─────────┐
│ group ┆ index ┆ new_udf │
│ ---   ┆ ---   ┆ ---     │
│ str   ┆ i64   ┆ u8      │
╞═══════╪═══════╪═════════╡
│ A     ┆ 1     ┆ 0       │
│ A     ┆ 3     ┆ 0       │
│ A     ┆ 5     ┆ 1       │
│ B     ┆ 1     ┆ 0       │
│ B     ┆ 4     ┆ 1       │
│ B     ┆ 8     ┆ 1       │
│ C     ┆ 1     ┆ 0       │
│ C     ┆ 1     ┆ 0       │
│ C     ┆ 4     ┆ 1       │
└───────┴───────┴─────────┘

基准测试:

from timeit import timeit

import numpy as np
import polars as pl
from numba import njit, prange


def get_df(N, n):
    """Build a benchmark DataFrame of ``N`` rows split into equal groups of ``n``.

    Group labels are "group_0", "group_1", ..., each repeated ``n`` times in
    consecutive rows. The "index" column holds random uint64 values drawn with
    ``np.random.randint(1, 5)`` (i.e. 1..4).

    ``N`` must be a multiple of ``n`` (checked with ``assert``).
    """
    assert N % n == 0

    labels = [f"group_{i}" for i in range(N // n)]
    return pl.DataFrame(
        {
            "group": [label for label in labels for _ in range(n)],
            "index": np.random.randint(1, 5, size=N, dtype="uint64"),
        }
    )


@njit
def UDF(array: np.ndarray, threshold: int) -> np.ndarray:
    """Flag positions where a resetting running sum reaches ``threshold``.

    Values of ``array`` are added one by one to a running sum. Whenever the
    sum is >= ``threshold``, that position is flagged with 1 and the sum
    restarts from 0.

    Returns a float array (``np.zeros`` default dtype) of 0/1 flags with the
    same length as ``array``. ``get_udf_polars`` casts it to UInt8.
    """
    n_rows = array.shape[0]
    flags = np.zeros(n_rows)
    running = 0

    for pos in range(n_rows):
        running += array[pos]
        if running >= threshold:
            flags[pos] = 1
            running = 0  # restart the sum after each flagged position

    return flags


@njit(parallel=True)
def UDF_nb_parallel(array, n, threshold):
    """Flag threshold crossings independently in consecutive chunks of ``n`` rows.

    The data is expected to be ordered so that every group occupies exactly
    ``n`` consecutive rows; each such chunk gets its own running sum. Whenever
    the running sum reaches ``threshold``, the position is flagged with 1 and
    the sum resets to 0. Chunks are processed in parallel with ``prange``.

    If ``array.size`` is not a multiple of ``n``, the trailing shorter chunk
    is processed as well instead of being silently left as zeros.

    Returns a uint8 array of 0/1 flags with the same length as ``array``.
    """
    result = np.zeros_like(array, dtype="uint8")
    size = array.size
    # Ceiling division so a final partial chunk is not dropped.
    n_chunks = (size + n - 1) // n

    for i in prange(n_chunks):
        accumulator = 0
        for j in range(i * n, (i + 1) * n):
            # Only the last chunk can run past the end of the array; a plain
            # bounds check avoids mixing signed/unsigned ints in min().
            if j >= size:
                break
            value = array[j]
            accumulator += value
            if accumulator >= threshold:
                result[j] = 1
                accumulator = 0

    return result


def get_udf_polars(df):
    """Return ``df`` with a UInt8 "udf" column computed per "group" window.

    Polars calls the serial numba ``UDF`` (threshold 5) once per group via
    ``map_batches(...).over("group")``. This is the baseline being benchmarked.
    """
    def per_group(series):
        return UDF(series.to_numpy(), 5)

    udf_expr = (
        pl.col("index")
        .map_batches(per_group)
        .over("group")
        .cast(pl.UInt8)
        .alias("udf")
    )
    return df.with_columns(udf_expr)


df = get_df(3 * 33_333, 3)  # 99_999 values (33_333 groups of exactly 3 rows)

# Baseline: polars .over() calling the serial numba UDF once per group.
# Calling both implementations here also triggers numba's lazy compilation,
# so the timeit runs below do not include compile time.
df = get_udf_polars(df)

# Parallel numba version over the whole column (chunks of 3 rows, threshold 5).
df = df.with_columns(
    pl.Series(name="new_udf", values=UDF_nb_parallel(df["index"].to_numpy(), 3, 5))
)

# Both approaches must agree before they are timed.
assert np.allclose(df["udf"].to_numpy(), df["new_udf"].to_numpy())


t1 = timeit("get_udf_polars(df)", number=1, globals=globals())
t2 = timeit(
    'df.with_columns(pl.Series(name="new_udf", values=UDF_nb_parallel(df["index"].to_numpy(), 3, 5)))',
    number=1,
    globals=globals(),
)

# Elapsed seconds for one run of each approach.
print(t1)
print(t2)

在我的电脑上(AMD 5700x)输出:

2.7000599699968006
0.00025866299984045327

处理 100_000_000 行/组时,耗时 0.06319052699836902 秒(如果 parallel=False,则耗时 0.2159650030080229 秒)。


编辑:处理可变长度的组(要求同一组的行在 DataFrame 中连续排列):

@njit(parallel=True)
def UDF_nb_parallel_2(array, indices, amount, threshold):
    """Flag threshold crossings per variable-length group, groups in parallel.

    Group ``g`` covers rows ``indices[g]`` up to (excluding)
    ``indices[g] + amount[g]``. Within each group, a running sum is kept;
    when it reaches ``threshold``, the position is flagged with 1 and the sum
    resets to 0. Groups are processed in parallel with ``prange``.

    Returns a uint8 array of 0/1 flags with the same length as ``array``.
    """
    flags = np.zeros_like(array, dtype="uint8")

    for g in prange(indices.size):
        start = indices[g]
        total = 0
        for row in range(start, start + amount[g]):
            total += array[row]
            if total >= threshold:
                flags[row] = 1
                total = 0

    return flags

def get_udf_polars_nb(df):
    """Return ``df`` with a uint8 "new_udf" column computed per group in numba.

    Each group's start row and length are derived from the "group" column.
    The flags (threshold 5) are then computed in parallel by
    ``UDF_nb_parallel_2``.

    Rows of the same group must be contiguous. Groups may have different
    lengths.
    """
    n = df["group"].to_numpy()
    # np.unique returns first-occurrence indices ordered by the *sorted group
    # value*, not by row position (e.g. "group_10" sorts before "group_2").
    # Sort them so consecutive entries are consecutive groups in row order;
    # otherwise np.diff below yields wrong (even negative) group lengths.
    indices = np.sort(np.unique(n, return_index=True)[1])
    amount = np.diff(np.r_[indices, [n.size]])
    return df.with_columns(
        pl.Series(
            name="new_udf",
            values=UDF_nb_parallel_2(df["index"].to_numpy(), indices, amount, 5),
        )
    )

# Usage: attach the per-group "new_udf" column to an existing DataFrame.
df = get_udf_polars_nb(df)

基准测试:

import random
from timeit import timeit

import numpy as np
import polars as pl
from numba import njit, prange


def get_df(N):
    """Build a benchmark DataFrame of ``N`` rows in contiguous groups of random size.

    Group ids are consecutive integers starting at 1. Each group's length is
    drawn with ``random.randint(3, 10)`` (inclusive bounds), and the last group
    is truncated so the frame has exactly ``N`` rows. The "index" column holds
    random uint64 values drawn with ``np.random.randint(1, 5)`` (i.e. 1..4).

    For ``N <= 0`` an empty DataFrame is returned.
    """
    groups = []
    group_no = 1
    # Checking the length before drawing a group size means no row is emitted
    # when N <= 0, so both columns always have exactly N entries. For N > 0
    # the random draws are the same as one randint call per group.
    while len(groups) < N:
        size = random.randint(3, 10)
        groups.extend([group_no] * min(size, N - len(groups)))
        group_no += 1

    df = pl.DataFrame(
        {
            "group": groups,
            "index": np.random.randint(1, 5, size=N, dtype="uint64"),
        }
    )
    return df


@njit
def UDF(array: np.ndarray, threshold: int) -> np.ndarray:
    """Flag positions where a resetting running sum reaches ``threshold``.

    Values of ``array`` are added one by one to a running sum. Whenever the
    sum is >= ``threshold``, that position is flagged with 1 and the sum
    restarts from 0.

    Returns a float array (``np.zeros`` default dtype) of 0/1 flags with the
    same length as ``array``. ``get_udf_polars`` casts it to UInt8.
    """
    n_rows = array.shape[0]
    flags = np.zeros(n_rows)
    running = 0

    for pos in range(n_rows):
        running += array[pos]
        if running >= threshold:
            flags[pos] = 1
            running = 0  # restart the sum after each flagged position

    return flags


@njit(parallel=True)
def UDF_nb_parallel_2(array, indices, amount, threshold):
    """Flag threshold crossings per variable-length group, groups in parallel.

    Group ``g`` covers rows ``indices[g]`` up to (excluding)
    ``indices[g] + amount[g]``. Within each group, a running sum is kept;
    when it reaches ``threshold``, the position is flagged with 1 and the sum
    resets to 0. Groups are processed in parallel with ``prange``.

    Returns a uint8 array of 0/1 flags with the same length as ``array``.
    """
    flags = np.zeros_like(array, dtype="uint8")

    for g in prange(indices.size):
        start = indices[g]
        total = 0
        for row in range(start, start + amount[g]):
            total += array[row]
            if total >= threshold:
                flags[row] = 1
                total = 0

    return flags


def get_udf_polars(df):
    """Return ``df`` with a UInt8 "udf" column computed per "group" window.

    Polars calls the serial numba ``UDF`` (threshold 5) once per group via
    ``map_batches(...).over("group")``. This is the baseline being benchmarked.
    """
    def per_group(series):
        return UDF(series.to_numpy(), 5)

    udf_expr = (
        pl.col("index")
        .map_batches(per_group)
        .over("group")
        .cast(pl.UInt8)
        .alias("udf")
    )
    return df.with_columns(udf_expr)


def get_udf_polars_nb(df):
    """Return ``df`` with a uint8 "new_udf" column computed per group in numba.

    Each group's start row and length are derived from the "group" column.
    The flags (threshold 5) are then computed in parallel by
    ``UDF_nb_parallel_2``.

    Rows of the same group must be contiguous. Groups may have different
    lengths.
    """
    n = df["group"].to_numpy()
    # np.unique returns first-occurrence indices ordered by the *sorted group
    # value*, not by row position (e.g. "group_10" sorts before "group_2").
    # Sort them so consecutive entries are consecutive groups in row order;
    # otherwise np.diff below yields wrong (even negative) group lengths.
    indices = np.sort(np.unique(n, return_index=True)[1])
    amount = np.diff(np.r_[indices, [n.size]])
    return df.with_columns(
        pl.Series(
            name="new_udf",
            values=UDF_nb_parallel_2(df["index"].to_numpy(), indices, amount, 5),
        )
    )


df = get_df(100_000)  # 100_000 values, group lengths 3-10 (last group may be shorter)

# Baseline (polars .over() + serial numba UDF), then the parallel numba
# version. Calling both here also triggers numba's lazy compilation, so the
# timeit runs below do not include compile time.
df = get_udf_polars(df)
df = get_udf_polars_nb(df)

# Both approaches must agree before they are timed.
assert np.allclose(df["udf"].to_numpy(), df["new_udf"].to_numpy())


t1 = timeit("get_udf_polars(df)", number=1, globals=globals())
t2 = timeit("get_udf_polars_nb(df)", number=1, globals=globals())

# Elapsed seconds for one run of each approach.
print(t1)
print(t2)

输出结果:

1.2675148629932664
0.0024339070077985525

撰写回答