Is there a way to reproduce a SHAP values plot with Plotly?

Asked 2025-04-12 04:49

I am recreating the summary plot from the SHAP library in Plotly. I have two datasets:

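  • A SHAP values dataset, containing the SHAP value of every data point in my original dataset.
  • The original dataset, containing the one-hot encoded values of the features, so every value is either 0 or 1.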
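Everything below pairs each SHAP value with the one-hot value in the same row and column, so the two frames must line up exactly. A minimal sanity check, assuming both are pandas DataFrames with matching column names (this reuses the sample data from the question):

```python
import pandas as pd

shap_values = pd.DataFrame(
    {"A": [-0.065704, -0.096510, 0.062368, 0.062368, 0.063093],
     "B": [-0.168249, -0.173284, -0.168756, -0.168756, -0.169378]})
train = pd.DataFrame({"A": [0, 1, 1, 0, 0], "B": [1, 1, 0, 0, 1]})

# Same features in the same order, and the same rows (one SHAP row per data point)
assert shap_values.columns.equals(train.columns)
assert shap_values.index.equals(train.index)

# The one-hot frame should contain only 0 and 1
assert train.isin([0, 1]).all().all()
```

If any of these fail, the colors in the final plot would be attached to the wrong points.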

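My goal is to create a beeswarm plot of the SHAP values and give each point a different color depending on whether the corresponding value of the one-hot encoded variable is 0 or 1.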

Here is the code I wrote. It produces the plot described, but only for one variable at a time (one figure per feature):

import plotly.express as px

figures = []

# One strip ("bee swarm") figure per feature, colored by the one-hot value
for column in shap_values.columns:
    fig = px.strip(merged_df, x=merged_df[column + '_shap'], color=merged_df[column + '_train'],
                   orientation='h', stripmode='overlay')

    fig.update_layout(
        title=f'Bee swarm plot of the Shapley value for {column}',
        xaxis_title='Shapley value (impact on model output)',
        yaxis_title='Feature'
    )

    figures.append(fig)

Is there a way to combine these figures into a single plot?

Here is a sample of the data:

import pandas as pd

shap_values = pd.DataFrame(
    {"A": [-0.065704, -0.096510, 0.062368, 0.062368, 0.063093],
     "B": [-0.168249, -0.173284, -0.168756, -0.168756, -0.169378]})

train = pd.DataFrame(
    {"A": [0, 1, 1, 0, 0],
     "B": [1, 1, 0, 0, 1]})

# Both frames have columns A and B, so join() needs suffixes to tell them apart
merged_df = shap_values.join(train, lsuffix='_shap', rsuffix='_train')

2 Answers


Thank you, Krish. I modified the code slightly, and it now works:

import pandas as pd
import plotly.express as px

shap_values = pd.DataFrame({
    "A": [-0.065704, -0.096510, 0.062368, 0.062368, 0.063093],
    "B": [-0.168249, -0.173284, -0.168756, -0.168756, -0.169378]
})

train = pd.DataFrame({
    "A": [0, 1, 1, 0, 0],
    "B": [1, 1, 0, 0, 1]
})

# Joining SHAP values and one-hot encoded features
merged_df = shap_values.join(train, lsuffix='_shap', rsuffix='_train')

# Melt the merged DataFrame to long format
melted_df = merged_df.melt(value_vars=[col for col in merged_df.columns if '_shap' in col],
                           var_name='Feature',
                           value_name='SHAP Value')
melted_df['Feature'] = melted_df['Feature'].str.replace('_shap', '', regex=False)

# Directly assign the 'One-hot Value' using a vectorized approach.
# This avoids using apply(), which caused the indexing issue.
# It relies on melt() keeping the original row order within each feature.
for feature in train.columns:
    feature_train = feature + '_train'
    melted_df.loc[melted_df['Feature'] == feature, 'One-hot Value'] = merged_df[feature_train].values

# The new column is created as float (0.0/1.0); convert it so the legend shows 0 and 1
melted_df['One-hot Value'] = melted_df['One-hot Value'].astype(int).astype(str)

# Generate the combined plot: one row per feature on a shared SHAP axis
fig = px.strip(melted_df, x='SHAP Value', y='Feature',
               color='One-hot Value',
               orientation='h', stripmode='overlay',
               title='Bee Swarm Plot of SHAP Values by Feature')

fig.update_layout(xaxis_title='SHAP Value (Impact on Model Output)',
                  yaxis_title='Feature')

fig.show()
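The "indexing issue" with `apply()` comes from `melt`: by default (`ignore_index=True`) it discards the original row labels and numbers the long-format rows 0 to 9, while `merged_df` only has rows 0 to 4, so a lookup such as `merged_df.loc[x.name, ...]` fails with a `KeyError` as soon as it reaches row 5. Passing `ignore_index=False` (available since pandas 1.1) keeps the original labels instead. A small demonstration on the sample data; the variable names here are illustrative:

```python
import pandas as pd

shap_values = pd.DataFrame(
    {"A": [-0.065704, -0.096510, 0.062368, 0.062368, 0.063093],
     "B": [-0.168249, -0.173284, -0.168756, -0.168756, -0.169378]})
train = pd.DataFrame({"A": [0, 1, 1, 0, 0], "B": [1, 1, 0, 0, 1]})
merged_df = shap_values.join(train, lsuffix="_shap", rsuffix="_train")
shap_cols = [col for col in merged_df.columns if col.endswith("_shap")]

# Default: the long frame gets a fresh RangeIndex 0..9,
# so labels 5..9 have no counterpart in merged_df (which only has 0..4)
melted_default = merged_df.melt(value_vars=shap_cols)
print(list(melted_default.index))

lookup_failed = False
try:
    melted_default.apply(
        lambda x: merged_df.loc[x.name, x["variable"].replace("_shap", "_train")],
        axis=1)
except KeyError:
    lookup_failed = True  # raised at the first row labelled 5

# ignore_index=False: every long row keeps the label of the wide row it came from
melted_kept = merged_df.melt(value_vars=shap_cols, ignore_index=False)
print(list(melted_kept.index))

# With the original labels, the row-wise lookup into merged_df is valid
melted_kept["One-hot Value"] = [
    merged_df.loc[row, var.replace("_shap", "_train")]
    for row, var in zip(melted_kept.index, melted_kept["variable"])
]
```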
import pandas as pd
import plotly.express as px

# Sample data
shap_values = pd.DataFrame({
    "A": [-0.065704, -0.096510, 0.062368, 0.062368, 0.063093],
    "B": [-0.168249, -0.173284, -0.168756, -0.168756, -0.169378]
})

train = pd.DataFrame({
    "A": [0, 1, 1, 0, 0],
    "B": [1, 1, 0, 0, 1]
})

# Joining SHAP values and one-hot encoded features
merged_df = shap_values.join(train, lsuffix='_shap', rsuffix='_train')

# Melt the merged DataFrame to long format.
# ignore_index=False keeps the original row labels, so x.name below
# refers to a row that actually exists in merged_df
melted_df = merged_df.melt(value_vars=[col for col in merged_df.columns if '_shap' in col],
                           var_name='Feature',
                           value_name='SHAP Value',
                           ignore_index=False)

# Extract the original feature name and look up the matching
# one-hot encoded value for each row
melted_df['Feature'] = melted_df['Feature'].str.replace('_shap', '', regex=False)
melted_df['One-hot Value'] = melted_df.apply(
    lambda x: merged_df.loc[x.name, x['Feature'] + '_train'], axis=1)

fig = px.strip(melted_df, x='SHAP Value', y='Feature',
               color='One-hot Value',
               orientation='h', stripmode='overlay',
               title='Bee Swarm Plot of SHAP Values by Feature')

fig.update_layout(
    xaxis_title='SHAP Value (Impact on Model Output)',
    yaxis_title='Feature')

fig.show()

