Computing the distance between the LDA distributions of two rows in a Pandas DataFrame

Posted 2024-04-18 22:34:11


I have a DataFrame that contains LDA topic distribution output along with other demographic information, as shown below:

single_df = pd.DataFrame([{"department": 'marketing', 'LDA_1': 0.252, 'LDA_2':0.002, 'LDA_3':0.50},
                          {"department": 'engineering', 'LDA_1': 0.478, 'LDA_2':0.152, 'LDA_3':0.492},
                          {"department": 'cooperate', 'LDA_1': 0.52, 'LDA_2':0.780, 'LDA_3':0.50},
                          {"department": "marketing", 'LDA_1': 0.352, 'LDA_2':0.052, 'LDA_3':0.20}])


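I want to end up with the final DataFrame below. How can I write a function that computes the Jensen-Shannon distance between every pair of rows (using the columns whose names contain "LDA_") and returns it in this form?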

i j same_department distance_LDA
0 1          0        0.23
0 2          0        0.43
0 3          1        0.26
1 2          0        0.24
1 3          0        0.11
2 3          0        0.29

I have written the code below to compute the JS distance for an individual pair. How can I turn it into a function?

from scipy.spatial import distance

array = single_df.filter(regex='LDA').to_numpy()
distance.jensenshannon(array[0], array[1])

Then, to determine whether two people share a department, I have the following code:

def same_department(i,j):
    if i['department'] == j['department']:
        return 1
    else:
        return 0   

1 Answer
Answer 1 · posted 2024-04-18 22:34:11

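Let's generate all possible row combinations, then merge them into a single DataFrame so that each comparison sits on one row. Then apply the jensenshannon function row by row, using the column suffixes (_x and _y) to pick the two sides of each pair: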

from itertools import combinations
from scipy.spatial.distance import jensenshannon
import pandas as pd

single_df = pd.DataFrame([{"department": 'marketing', 'LDA_1': 0.252,
                           'LDA_2': 0.002, 'LDA_3': 0.50},
                          {"department": 'engineering', 'LDA_1': 0.478,
                           'LDA_2': 0.152, 'LDA_3': 0.492},
                          {"department": 'cooperate', 'LDA_1': 0.52,
                           'LDA_2': 0.780, 'LDA_3': 0.50},
                          {"department": "marketing", 'LDA_1': 0.352,
                           'LDA_2': 0.052, 'LDA_3': 0.20}])

# Merge the 3 LDA Columns Into A Single Column Containing a List
single_df['LDA'] = single_df.filter(regex='^LDA_.*').agg(list, axis=1)
# Get Rid Of The Original LDA_X columns
single_df = single_df.filter(regex='^(?!LDA_.*)')

# Get All Row Combinations
a, b = map(list, zip(*combinations(single_df.index, 2)))

# Merge Together
df = single_df.loc[a].reset_index().merge(
    single_df.loc[b].reset_index(),
    left_index=True,
    right_index=True,
)

# Apply jensenshannon to the LDA_x and LDA_y Lists
df['distance_LDA'] = df.apply(
    lambda x: jensenshannon(x['LDA_x'], x['LDA_y']), axis=1)

# Get If In Same Department
df['same_department'] = df['department_x'].eq(df['department_y']).astype(int)

# Rename and Filter Columns
df = df \
    .rename(columns={'index_x': 'i',
                     'index_y': 'j'})[['i', 'j',
                                       'same_department',
                                       'distance_LDA']]

# For Display
print(df.to_string(index=False))

Output:

i  j  same_department  distance_LDA
0  1                0      0.235849
0  2                0      0.429508
0  3                1      0.264777
1  2                0      0.238155
1  3                0      0.112456
2  3                0      0.299704
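
Since the question asks for a function, the same steps can also be wrapped in one without the merge. The following is only a minimal sketch: the name pairwise_lda_distance and its prefix and group_col parameters are hypothetical, chosen for illustration, and are not part of pandas or SciPy. It loops over row pairs with itertools.combinations and calls jensenshannon directly. Note that the sample rows do not sum to 1; SciPy's jensenshannon normalizes each input vector in that case.

```python
from itertools import combinations

import pandas as pd
from scipy.spatial.distance import jensenshannon


def pairwise_lda_distance(df, prefix="LDA_", group_col="department"):
    """Hypothetical helper: one output row per unordered pair of rows in df."""
    # Topic weights as a 2-D array, one row per DataFrame row
    topics = df.filter(regex=f"^{prefix}").to_numpy()
    groups = df[group_col].to_numpy()
    records = []
    # combinations yields each positional pair (a, b) with a < b exactly once
    for a, b in combinations(range(len(df)), 2):
        records.append({
            "i": df.index[a],
            "j": df.index[b],
            "same_department": int(groups[a] == groups[b]),
            # jensenshannon normalizes a vector that does not sum to 1
            "distance_LDA": jensenshannon(topics[a], topics[b]),
        })
    return pd.DataFrame(
        records, columns=["i", "j", "same_department", "distance_LDA"])


# Sample data copied from the question
single_df = pd.DataFrame([
    {"department": "marketing", "LDA_1": 0.252, "LDA_2": 0.002, "LDA_3": 0.50},
    {"department": "engineering", "LDA_1": 0.478, "LDA_2": 0.152, "LDA_3": 0.492},
    {"department": "cooperate", "LDA_1": 0.52, "LDA_2": 0.780, "LDA_3": 0.50},
    {"department": "marketing", "LDA_1": 0.352, "LDA_2": 0.052, "LDA_3": 0.20},
])

result = pairwise_lda_distance(single_df)
print(result.to_string(index=False))
```

A plain loop like this is easy to read and fine for a few hundred rows. The number of pairs grows quadratically with the number of rows, so a vectorized approach may be worth considering for larger frames.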
