计算pandas中滚动交集的大小

3 投票
1 回答
753 浏览
提问于 2025-04-17 22:43

我有一个叫做 dataframe 的东西,它包含了组标签(比如 'B')和每个组的元素(比如 'A')。这些组标签是有顺序的,我想知道在组 i+1 中有多少个来自组 i 的元素。

举个例子:

df= pd.DataFrame({ 'A': ['a','b','c','a','c','a','d'], 'B' : [1,1,1,2,2,3,3]})

   A  B
0  a  1
1  b  1
2  c  1
3  a  2
4  c  2
5  a  3
6  d  3

我想要的结果大概是这样的:

B
1  NaN
2  2
3  1

一种解决方法是先计算组 i 和组 i+1 中所有不同元素的总数,然后再减去每个组中不同元素的数量。我试过这样做:

pd.rolling_apply(grp['A'], lambda x: len(x.unique()),2)

但是这出现了错误:

AttributeError: 'Series' object has no attribute 'type'

我该如何使用 rolling_apply 来解决这个问题,或者有没有更好的方法呢?

1 个回答

2

这里介绍了一种使用集合和移动结果的方法:

首先,我们要对数据进行分组,然后把每个组的A列转换成一个集合:

In [86]: grp = df.groupby('B')
In [87]: s = grp.apply(lambda x : set(x['A']))
In [88]: s
Out[88]: 
B
1    set([a, c, b])
2       set([a, c])
3       set([a, d])
dtype: object

接下来,我们要计算相邻集合之间的交集,所以我们需要制作一个移动版本(我把NaN替换成一个空集合,以便下一步使用):

In [89]: s2 = s.shift(1).fillna(set([]))
In [90]: s2
Out[90]: 
B
1           set([])
2    set([a, c, b])
3       set([a, c])
dtype: object

然后把这两个系列合并,计算交集的长度:

In [91]: s.combine(s2, lambda x, y: len(x.intersection(y)))
Out[91]: 
B
1    0
2    2
3    1
dtype: object

还有一种方法可以完成最后一步(对于集合来说,&表示交集):

df = pd.concat([s, s2], axis=1)
df.apply(lambda x: len(x[0] & x[1]), axis=1)

滚动应用不工作的原因有两个:1)你给它提供的是一个分组对象,而不是一个系列,2)它只适用于数字值。

撰写回答