如何更改groupby列以查找符合掩码条件的首行,如果初始groupby未找到?

2 投票
2 回答
73 浏览
提问于 2025-04-12 03:06

这是我的数据表:

import pandas as pd
df = pd.DataFrame(
    {
        'main': ['x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'y', 'y', 'y', 'y', 'y', 'y', 'y'],
        'sub': ['c', 'c', 'c', 'd', 'd', 'e', 'e', 'e', 'e', 'f', 'f', 'f', 'f', 'g', 'g', 'g'],
        'num_1': [10, 9, 80, 80, 99, 101, 110, 222, 90, 1, 7, 10, 2, 10, 95, 10],
        'num_2': [99, 99, 99, 102, 102, 209, 209, 209, 209, 100, 100, 100, 100, 90, 90, 90]
    }
)

这是我期望的输出。我想添加一列 result

   main sub  num_1  num_2  result
0     x   c     10     99     101
1     x   c      9     99     101
2     x   c     80     99     101
3     x   d     80    102     110
4     x   d     99    102     110
5     x   e    101    209     222
6     x   e    110    209     222
7     x   e    222    209     222
8     x   e     90    209     222
9     y   f      1    100     NaN
10    y   f      7    100     NaN
11    y   f     10    100     NaN
12    y   f      2    100     NaN
13    y   g     10     90      95
14    y   g     95     90      95
15    y   g     10     90      95

这个条件筛选叫做“掩码”:

mask = (df.num_1 > df.num_2)

处理过程是这样的:

a) 我用 sub 这一列来分组。

b) 找到每个组中符合掩码条件的第一行。

c)num_1 的值放到 result 这一列里。

如果没有任何行符合掩码条件,那么就把分组列改成 main,去找 mask 的第一行。这个阶段有个条件:

在用 main 作为分组列时,之前的 subs 不应该被考虑。

下面是上述步骤在 d 组中的一个例子:

a) sub 是分组列。

b) 在 d 组中,没有行满足 df.num_1 > df.num_2 的条件。

所以现在要在 d 组中查找它的 main 组。不过 c 组也在这个 main 组里。因为 c 组在 d 组之前,所以在这一步中 c 组不应该算。

在这张图片中,我展示了这些值的来源:

在这里输入图片描述

这是我的尝试。它部分解决了一些组的问题,但并不是所有组都能解决:

def step_a(g):
    mask = (g.num_1 > g.num_2)

    g.loc[mask.cumsum().eq(1) & mask, 'result'] = g.num_1
    g['result'] = g.result.ffill().bfill()
    return g

a = df.groupby('sub').apply(step_a)

2 个回答

1

代码

定义一个自定义函数

def get_result(d):
    cond = d['num_1'].gt(d['num_2'])
    if cond.sum() > 0:
        target = d['num_1'].where(cond).bfill().iloc[0]        
    else:
        df1 = df[df['main'].eq(d['main'].iloc[0])]
        s1 = df1['num_2'].where(df1['sub'].eq(d['sub'].iloc[0])).ffill()
        target = df1['num_1'].where(df1['num_1'].gt(s1)).bfill().iloc[0]

    return pd.Series(target, index=d.index)

创建一个 result

df['result'] = df.groupby(['main', 'sub'], group_keys=False).apply(get_result)

数据框(df):

   main sub  num_1  num_2  result
0     x   c     10     99   101.0
1     x   c      9     99   101.0
2     x   c     80     99   101.0
3     x   d     80    102   110.0
4     x   d     99    102   110.0
5     x   e    101    209   222.0
6     x   e    110    209   222.0
7     x   e    222    209   222.0
8     x   e     90    209   222.0
9     y   f      1    100     NaN
10    y   f      7    100     NaN
11    y   f     10    100     NaN
12    y   f      2    100     NaN
13    y   g     10     90    95.0
14    y   g     95     90    95.0
15    y   g     10     90    95.0
1

如果我理解得没错,你可以利用的广播功能来为每个“主”元素创建一个掩码,然后用这个掩码来找出第一个大于num2的num1,同时只考虑接下来的几个组:

def find(g):
    # get sub as 0,1,2…
    sub = pd.factorize(g['sub'])[0]
    # convert inputs to numpy
    n1 = g['num_1'].to_numpy()
    n2 = g.loc[~g['sub'].duplicated(), 'num_2'].to_numpy()
    # form mask
    # (n1[:, None] > n2) -> num_1 > num_2
    # (sub[:, None] >= np.arange(len(n2))) -> exclude previous groups
    m = (n1[:, None] > n2) & (sub[:, None] >= np.arange(len(n2)))
    # find first True per column
    return pd.Series(np.where(m.any(0), n1[m.argmax(0)], np.nan)[sub],
                     index=g.index)


df['result'] = df.groupby('main', group_keys=False).apply(find)

注意,你可以轻松调整这些掩码,以实现其他逻辑(比如搜索接下来的n个组,排除所有之前的组,除了紧接着的那个,等等)。

输出结果:

# example 1                               # example 2 (from comments)
   main sub  num_1  num_2  result            main sub  num_1  num_2  result
0     x   c     10     99   101.0         0     x   d     10    102   110.0
1     x   c      9     99   101.0         1     x   d      9    102   110.0
2     x   c     80     99   101.0         2     x   c     80     99   101.0
3     x   d     80    102   110.0         3     x   c     80     99   101.0
4     x   d     99    102   110.0         4     x   c     99     99   101.0
5     x   e    101    209   222.0         5     x   e    101    209   222.0
6     x   e    110    209   222.0         6     x   e    110    209   222.0
7     x   e    222    209   222.0         7     x   e    222    209   222.0
8     x   e     90    209   222.0         8     x   e     90    209   222.0
9     y   f      1    100     NaN         9     y   f      1    100     NaN
10    y   f      7    100     NaN         10    y   f      7    100     NaN
11    y   f     10    100     NaN         11    y   f     10    100     NaN
12    y   f      2    100     NaN         12    y   f      2    100     NaN
13    y   g     10     90    95.0         13    y   g     10     90    95.0
14    y   g     95     90    95.0         14    y   g     95     90    95.0
15    y   g     10     90    95.0         15    y   g     10     90    95.0

中间掩码m,这里是第二个例子的情况:

# main == 'x'
#          99    102    209
array([[False, False, False],  #  10
       [False, False, False],  #   9
       [False, False, False],  #  80
       [False, False, False],  #  80
       [False, False, False],  #  99
       [False,  True, False],  # 101
       [ True,  True, False],  # 110
       [ True,  True,  True],  # 222
       [False, False, False]]) #  90
# out:    110    101    222


# main == 'y'
#         100     90
array([[False, False],  #   1
       [False, False],  #   7
       [False, False],  #  10
       [False, False],  #   1
       [False, False],  #  10
       [False,  True],  #  95
       [False, False]]) #  90
# out:    NaN     95

撰写回答