如何根据行ID/行号过滤数据框

2 投票
1 回答
54 浏览
提问于 2025-04-13 00:42

我想从一个数据表中根据行号(row_id/row_number)获取一部分行,类似于pyarrow.Table.take这个功能。比如,给定下面这个数据表:

from datetime import datetime

df = pl.DataFrame(
    {
        "integer": [1, 2, 3, 4, 5],
        "date": [
            datetime(2022, 1, 1),
            datetime(2022, 1, 2),
            datetime(2022, 1, 3),
            datetime(2022, 1, 4),
            datetime(2022, 1, 5),
        ],
        "float": [4.0, 5.0, 6.0, 7.0, 8.0],
    }
)

print(df)

shape: (5, 3)
┌─────────┬─────────────────────┬───────┐
│ integer ┆ date                ┆ float │
│ ---     ┆ ---                 ┆ ---   │
│ i64     ┆ datetime[μs]        ┆ f64   │
╞═════════╪═════════════════════╪═══════╡
│ 1       ┆ 2022-01-01 00:00:00 ┆ 4.0   │
│ 2       ┆ 2022-01-02 00:00:00 ┆ 5.0   │
│ 3       ┆ 2022-01-03 00:00:00 ┆ 6.0   │
│ 4       ┆ 2022-01-04 00:00:00 ┆ 7.0   │
│ 5       ┆ 2022-01-05 00:00:00 ┆ 8.0   │
└─────────┴─────────────────────┴───────┘

我想要一个像 df.take([0, 4]) 这样的功能,它能返回下面这个数据表。

shape: (2, 3)
┌─────────┬─────────────────────┬───────┐
│ integer ┆ date                ┆ float │
│ ---     ┆ ---                 ┆ ---   │
│ i64     ┆ datetime[μs]        ┆ f64   │
╞═════════╪═════════════════════╪═══════╡
│ 1       ┆ 2022-01-01 00:00:00 ┆ 4.0   │
│ 5       ┆ 2022-01-05 00:00:00 ┆ 8.0   │
└─────────┴─────────────────────┴───────┘

这些行号是通过其他过程得来的,然后交给我的。我尝试使用 df.select(pl.all().take([take_indices]),但发现它比直接使用过滤器慢,也就是 df.filter(filter_expr)。请注意,我是在处理非常大的数据集(超过1亿行)。

补充:感谢你的回答。使用 df[[take_indices]] 确实有效。不过我还是很好奇,为什么过滤器的速度比 select.gather 和方括号的方法都快。在我有5000万行的数据集上,时间如下:

select.gather: 0.5秒
方括号: 0.32秒 [与mozway的时间一致]
过滤器: 0.18秒

1 个回答

1

df[[0,4]] 这个写法可以用来选择第0和第4个索引的数据。

因为 take 这个方法已经不再推荐使用,所以你可以用 gather 来替代你提到的代码:

df.select(pl.all().gather([0, 4]))

输出结果:

shape: (2, 3)
┌─────────┬─────────────────────┬───────┐
│ integer ┆ date                ┆ float │
│ ---     ┆ ---                 ┆ ---   │
│ i64     ┆ datetime[μs]        ┆ f64   │
╞═════════╪═════════════════════╪═══════╡
│ 1       ┆ 2022-01-01 00:00:00 ┆ 4.0   │
│ 5       ┆ 2022-01-05 00:00:00 ┆ 8.0   │
└─────────┴─────────────────────┴───────┘

处理50万行数据的时间:

# df.select(pl.all().gather([0, 4]))
145 µs ± 9.43 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

# df[[0,4]]
122 µs ± 14.3 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

处理500万行数据的时间:

# df.select(pl.all().gather([0, 4]))
150 µs ± 13.3 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

# df[[0,4]]
117 µs ± 17.7 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

撰写回答