Polars 分组描述扩展

3 投票
1 回答
51 浏览
提问于 2025-04-13 03:01

df 是一个示例的 Polars 数据框:

df = pl.DataFrame(
    {
        "groups": ["A", "A", "A", "B", "B", "B"],
        "values": [1, 2, 3, 4, 5, 6],
        }
)

现在的 group_by.agg() 方法在创建描述性统计时有点不方便:

print(
    df.group_by("groups").agg(
    pl.len().alias("count"),
    pl.col("values").mean().alias("mean"),
    pl.col("values").std().alias("std"),
    pl.col("values").min().alias("min"),
    pl.col("values").quantile(0.25).alias("25%"),
    pl.col("values").quantile(0.5).alias("50%"),
    pl.col("values").quantile(0.75).alias("75%"),
    pl.col("values").max().alias("max"),
    pl.col("values").skew().alias("skew"),
    pl.col("values").kurtosis().alias("kurtosis"),
)
)

out:
shape: (2, 11)
┌────────┬───────┬──────┬─────┬───┬─────┬─────┬──────┬──────────┐
│ groups ┆ count ┆ mean ┆ std ┆ … ┆ 75% ┆ max ┆ skew ┆ kurtosis │
│ ---    ┆ ---   ┆ ---  ┆ --- ┆   ┆ --- ┆ --- ┆ ---  ┆ ---      │
│ str    ┆ u32   ┆ f64  ┆ f64 ┆   ┆ f64 ┆ i64 ┆ f64  ┆ f64      │
╞════════╪═══════╪══════╪═════╪═══╪═════╪═════╪══════╪══════════╡
│ B      ┆ 3     ┆ 5.0  ┆ 1.0 ┆ … ┆ 6.0 ┆ 6   ┆ 0.0  ┆ -1.5     │
│ A      ┆ 3     ┆ 2.0  ┆ 1.0 ┆ … ┆ 3.0 ┆ 3   ┆ 0.0  ┆ -1.5     │
└────────┴───────┴──────┴─────┴───┴─────┴─────┴──────┴──────────┘

我想写一个自定义的 group_by 扩展模块,这样我就可以通过调用来实现相同的结果:

df.describe(by="groups", percentiles=[xxx], skew=True, kurt=True)

或者

df.group_by("groups").describe(percentiles=....)

1 个回答

5

调用这个会输出和你在问题中提到的一样。



class DescribeAccessor:
    def __init__(self, df: pl.DataFrame):
        self._df = df

    def __call__(
            self,
            by: str,
            percentiles: list = [0.25, 0.5, 0.75],
            skew: bool = True,
            kurt: bool = True,
    ) -> pl.DataFrame:
        percentile_exprs = [
            pl.col("values").quantile(p).alias(f"{int(p * 100)}%")
            for p in percentiles
        ]

        aggs = [
            pl.len().alias("count"),
            pl.col("values").mean().alias("mean"),
            pl.col("values").std().alias("std"),
            pl.col("values").min().alias("min"),
            *percentile_exprs,
            pl.col("values").max().alias("max"),
        ]

        if skew:
            aggs.append(pl.col("values").skew().alias("skew"))

        if kurt:
            aggs.append(pl.col("values").kurtosis().alias("kurtosis"))

        return self._df.groupby(by).agg(aggs)


pl.DataFrame.describe = property(lambda self: DescribeAccessor(self))

df = pl.DataFrame(
    {
        "groups": ["A", "A", "A", "B", "B", "B"],
        "values": [1, 2, 3, 4, 5, 6],
    }
)

print(df.describe(by="groups", percentiles=[0.25, 0.5, 0.75], skew=True, kurt=True))

撰写回答