Pandas:使用条件滚动计数时内存过多

2024-05-23 19:17:41 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图对一列中出现的观察值进行滚动计数,给定另一列中指定的固定窗口长度(按组)。通过一个例子可以更好地解释这一点:

df = pd.DataFrame({'B': ['X', 'X' , 'Y', 'X', 'Y', 'Y', 'X', 'X', 'Y', 'Y', 'X', 'Y'],
                   'group': ["IT", "IT", "IT", "MV", "MV", "MV", "IT", "MV", "MV", "IT", "IT", "MV"]})

for i in df['B'].unique():
    df.loc[df['B']==i, 'count'] = df.where(df['B'].eq(i)).groupby(df['group'])['B'].transform(lambda x: x.rolling(3, min_periods=1).count().shift(fill_value=0))
print(df)

    B group  count
0   X    IT    0.0
1   X    IT    1.0
2   Y    IT    0.0
3   X    MV    0.0
4   Y    MV    0.0
5   Y    MV    1.0
6   X    IT    2.0
7   X    MV    1.0
8   Y    MV    2.0
9   Y    IT    1.0
10  X    IT    1.0
11  Y    MV    2.0

如上所述,我们按“组”分组,在窗口长度为3的情况下,对B列中的“X”和“Y”进行滚动计数。如果“X”是当前行,那么我们计算“X”在“group”组中的前3次观察中出现的次数,不包括当前行的计数(因此按时段向后移动=1)

但是,当使用大型数据集时,此代码速度较慢,并且占用了太多内存。感谢您在这方面的改进


Tags: indataframedfforcountgroupitwhere
1条回答
网友
1楼 · 发布于 2024-05-23 19:17:41

我为您的问题开发了另一个解决方案,它基于分组和使用一个热编码(get_dummy

代码如下:

df = pd.DataFrame({'B': ['X', 'X' , 'Y', 'X', 'Y', 'Y', 'X', 'X', 'Y', 'Y', 'X', 'Y'],
                   'group': ["IT", "IT", "IT", "MV", "MV", "MV", "IT", "MV", "MV", "IT", "IT", "MV"]})

# add a one-hot encoding to the dataframe. 
t = pd.concat([df, pd.get_dummies(df.B)], axis=1)

t.index.name = "inx"

# do a rolling sum of 4. It's the past 3, plus 1. 
t = t.groupby("group").rolling(4, min_periods = 1).sum()
t = t.reset_index().set_index("inx").sort_index()

# remove the extra '1' from the rolling result. 
t.loc[:, ["X", "Y"]] = t.loc[:, ["X", "Y"]] - 1

# merge back the results with the original dataframe. 
t = pd.concat([df, t[["X", "Y"]]], axis=1)

# create a 'count' column which is based on the values of 'B'. 
t["count"] = t.lookup(t.index, t.B )

输出为:

     B group    X    Y  count
inx                          
0    X    IT  0.0 -1.0    0.0
1    X    IT  1.0 -1.0    1.0
2    Y    IT  1.0  0.0    0.0
3    X    MV  0.0 -1.0    0.0
4    Y    MV  0.0  0.0    0.0
5    Y    MV  0.0  1.0    1.0
6    X    IT  2.0  0.0    2.0
7    X    MV  1.0  1.0    1.0
8    Y    MV  0.0  2.0    2.0
9    Y    IT  1.0  1.0    1.0
10   X    IT  1.0  1.0    1.0
11   Y    MV  0.0  2.0    2.0

一网打尽:

df['count'] = (pd.concat([df, df['B'].str.get_dummies()], axis=1)
                 .groupby('group')
                 .rolling(4, min_periods=1)
                 .sum()
                 .sort_index(level=1)
                 .reset_index(drop=True)
                 .lookup(df.index, df['B']) - 1)

相关问题 更多 >