Polars 表达式无法访问中间列创建表达式

3 投票
2 回答
61 浏览
提问于 2025-04-14 15:56

我想用整数来编码那些非零的二进制事件。下面是一个示例表格:

import polars as pl
df = pl.DataFrame(
    {
        "event": [0, 1, 1, 0],
        "foo": [1, 2, 3, 4],
        "boo": [2, 3, 4, 5],
    }
)

我希望得到的结果是通过以下方式实现的:

df = df.with_row_index()
events = df.select(pl.col(["index", "event"])).filter(pl.col("event") == 1).with_row_index("event_id").drop("event")
df = df.join(events, on="index", how="left")

out:
shape: (4, 5)
┌───────┬───────┬─────┬─────┬──────────┐
│ index ┆ event ┆ foo ┆ boo ┆ event_id │
│ ---   ┆ ---   ┆ --- ┆ --- ┆ ---      │
│ u32   ┆ i64   ┆ i64 ┆ i64 ┆ u32      │
╞═══════╪═══════╪═════╪═════╪══════════╡
│ 0     ┆ 0     ┆ 1   ┆ 2   ┆ null     │
│ 1     ┆ 1     ┆ 2   ┆ 3   ┆ 0        │
│ 2     ┆ 1     ┆ 3   ┆ 4   ┆ 1        │
│ 3     ┆ 0     ┆ 4   ┆ 5   ┆ null     │
└───────┴───────┴─────┴─────┴──────────┘

我想通过连接表达式来获得预期的结果:

(
    df
    .with_row_index()
    .join(
        df
        .select(pl.col(["index", "event"]))
        .filter(pl.col("event") == 1)
        .with_row_index("event_id")
        .drop("event"),
        on="index",
        how="left",
        )
    )

不过,在.join()表达式中的那些表达式似乎没有添加来自df.with_row_index()操作的index列:

ColumnNotFoundError: index

Error originated just after this operation:
DF ["event", "foo", "boo"]; PROJECT */3 COLUMNS; SELECTION: "None"

2 个回答

1

一种可能的解决办法是使用 := 这个操作符:

df = (df := df.with_row_index()).join(
    df.select(pl.col(["index", "event"]))
    .filter(pl.col("event") == 1)
    .with_row_index("event_id")
    .drop("event"),
    on="index",
    how="left",
)

print(df)

输出结果是:

shape: (4, 5)
┌───────┬───────┬─────┬─────┬──────────┐
│ index ┆ event ┆ foo ┆ boo ┆ event_id │
│ ---   ┆ ---   ┆ --- ┆ --- ┆ ---      │
│ u32   ┆ i64   ┆ i64 ┆ i64 ┆ u32      │
╞═══════╪═══════╪═════╪═════╪══════════╡
│ 0     ┆ 0     ┆ 1   ┆ 2   ┆ null     │
│ 1     ┆ 1     ┆ 2   ┆ 3   ┆ 0        │
│ 2     ┆ 1     ┆ 3   ┆ 4   ┆ 1        │
│ 3     ┆ 0     ┆ 4   ┆ 5   ┆ null     │
└───────┴───────┴─────┴─────┴──────────┘
2

虽然使用海象运算符的解决方案可以工作,但用pl.when().then()结合pl.int_range()来创建event_id,可能会更符合常规写法,也更简洁。

(
    df
    .with_columns(
        pl.when(pl.col("event") == 1)
        .then(pl.int_range(pl.len()))
        .over("event")
        .alias("event_id")
    )
)
shape: (4, 4)
┌───────┬─────┬─────┬──────────┐
│ event ┆ foo ┆ boo ┆ event_id │
│ ---   ┆ --- ┆ --- ┆ ---      │
│ i64   ┆ i64 ┆ i64 ┆ i64      │
╞═══════╪═════╪═════╪══════════╡
│ 0     ┆ 1   ┆ 2   ┆ null     │
│ 1     ┆ 2   ┆ 3   ┆ 0        │
│ 1     ┆ 3   ┆ 4   ┆ 1        │
│ 0     ┆ 4   ┆ 5   ┆ null     │
└───────┴─────┴─────┴──────────┘

撰写回答