SHAP functions raise exceptions in plotting methods

Posted 2024-04-29 09:34:48


The samples.zip archive contains:

  1. model.pkl
  2. x_test.csv

To reproduce the issues, follow these steps:

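  1. Load the linear regression model with `lin2 = joblib.load('model.pkl')`
  2. Load the test data with `x_test_2 = pd.read_csv('x_test.csv').drop(['Unnamed: 0'], axis=1)`
  3. Run the following code to create the explainer:

```python
import shap

# Explain lin2.predict, using x_test_2 as the background data
explainer_test = shap.Explainer(lin2.predict, x_test_2)
shap_values_test = explainer_test(x_test_2)
```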
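One plausible cause of both errors, assuming `lin2` was fitted on a single-column DataFrame or an `(n, 1)` array (the question does not say how the model was trained): scikit-learn's `LinearRegression.predict` then returns shape `(n, 1)` instead of `(n,)`, so `shap.Explainer` sees a model with one output column and each explanation gets an extra trailing output axis. The sketch below uses synthetic data (not the contents of samples.zip) to compare the two prediction shapes; `predict_1d` is a hypothetical wrapper name.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Synthetic stand-in data (hypothetical; not the contents of samples.zip)
rng = np.random.default_rng(0)
X = rng.normal(size=(50, 3))
y = X @ np.array([1.5, -2.0, 0.5]) + 0.1 * rng.normal(size=50)

# Fitting on a 2-D target of shape (n, 1) makes predict() return (n, 1)
model_2d = LinearRegression().fit(X, y.reshape(-1, 1))
# Fitting on a 1-D target makes predict() return (n,)
model_1d = LinearRegression().fit(X, y)


def predict_1d(data):
    """Hypothetical wrapper: flatten a single-column prediction to shape (n,)."""
    return np.asarray(model_2d.predict(data)).reshape(-1)


print(model_2d.predict(X).shape)  # (50, 1)
print(model_1d.predict(X).shape)  # (50,)
print(predict_1d(X).shape)        # (50,)
```

If `lin2.predict(x_test_2).shape` turns out to be `(n, 1)`, passing a flattening function such as `lambda d: lin2.predict(d).reshape(-1)` to `shap.Explainer` instead of `lin2.predict` should produce explanations without the output axis; this has not been checked against the posted files.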
  4. Then run `partial_dependence_plot` to see the error message:

ValueError: x and y can be no greater than 2-D, but have shapes (2,) and (2, 1, 1)

```python
sample_ind = 3
# Partial dependence of one feature, with the SHAP value of sample 3 overlaid
shap.partial_dependence_plot(
    "new_personal_projection_delta",
    lin2.predict,
    x_test_2,
    model_expected_value=True,
    feature_expected_value=True,
    ice=False,
    shap_values=shap_values_test[sample_ind:sample_ind+1, :],
)
```
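The `ValueError` comes from matplotlib's `plot`, which rejects inputs with more than two dimensions, so an array with one axis too many is reaching it. Assuming `lin2.predict` returns an `(n, 1)` array (for example because the model was fitted on a single-column target; the question does not say), `shap_values_test.values` would have shape `(n_samples, n_features, 1)`. The NumPy sketch below uses a hypothetical stand-in array of that shape to show that the slice from the question keeps the extra axis and that also indexing the last axis drops it.

```python
import numpy as np

# Hypothetical stand-in for shap_values_test.values when the model output
# has shape (n, 1): (n_samples, n_features, n_outputs). Sizes are arbitrary.
n_samples, n_features = 10, 4
values = np.arange(n_samples * n_features, dtype=float).reshape(n_samples, n_features, 1)

sample_ind = 3
# Slicing as in the question keeps the trailing output axis
with_output_axis = values[sample_ind:sample_ind + 1, :]
# Also selecting output 0 removes it, leaving (1, n_features)
without_output_axis = values[sample_ind:sample_ind + 1, :, 0]

print(with_output_axis.shape)     # (1, 4, 1)
print(without_output_axis.shape)  # (1, 4)
```

In SHAP terms the corresponding change would be `shap_values=shap_values_test[sample_ind:sample_ind+1, :, 0]`, together with a model function that returns a 1-D array (for example `lambda d: lin2.predict(d).reshape(-1)`) so that `model_expected_value=True` also produces a scalar. This relies on `shap.Explanation` supporting NumPy-style indexing over the output axis, which recent SHAP versions support; verify it against the installed version.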
  5. Run `shap.plots.waterfall` to draw a waterfall plot and see the error message:

Exception: waterfall_plot requires a scalar base_values of the model output as the first parameter, but you have passed an array as the first parameter! Try shap.waterfall_plot(explainer.base_values[0], values[0], X[0]) or for multi-output models try shap.waterfall_plot(explainer.base_values[0], values[0][0], X[0]).

```python
# Waterfall plot for a single sample
shap.plots.waterfall(shap_values_test[sample_ind], max_display=14)
```
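The waterfall exception says the base value must be a scalar. Assuming `lin2.predict` returns an `(n, 1)` array (the question does not say), `shap_values_test[sample_ind]` still carries the output axis, so its values are 2-D and its base value is a length-1 array rather than a number. The sketch below checks the two conditions the message describes on hypothetical NumPy stand-ins; `is_single_explanation` is an illustrative name, not part of SHAP.

```python
import numpy as np

# Hypothetical stand-ins shaped like an explanation with one output column:
# values (n_samples, n_features, 1), base_values (n_samples, 1)
values = np.zeros((10, 4, 1))
base_values = np.full((10, 1), 2.5)


def is_single_explanation(row_values, row_base):
    """Hypothetical check: 1-D per-feature values and a scalar base value,
    which is what the waterfall error message asks for."""
    return np.ndim(row_values) == 1 and np.ndim(row_base) == 0


sample_ind = 3
# Indexing only the sample keeps the output axis: values (4, 1), base (1,)
print(is_single_explanation(values[sample_ind], base_values[sample_ind]))  # False
# Indexing the sample and output 0 gives values (4,) and a scalar base value
print(is_single_explanation(values[sample_ind, :, 0], base_values[sample_ind, 0]))  # True
```

Under that assumption, the matching SHAP call would be `shap.plots.waterfall(shap_values_test[sample_ind, :, 0], max_display=14)`; flattening the model output when building the explainer would avoid the extra index altogether.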

Questions:

  1. Why can't I run `partial_dependence_plot` and `shap.plots.waterfall`?
  2. What changes do I need to make to the inputs to run these methods?
