Partial dependence plot (PDPbox) error on a LightGBM model: "Wrong number of items passed 3, placement implies"

Posted 2024-05-12 16:12:52

I have code like this:

import lightgbm as lgb
import pandas as pd
from pdpbox import pdp, get_dataset, info_plots
import seaborn as sns
from sklearn.model_selection import train_test_split

# load some data
df = sns.load_dataset("iris")
X = df.iloc[:, :4]                          # the four numeric feature columns
y, mapping = pd.factorize(df["species"])    # species names -> integer labels 0, 1, 2

X_train, X_test, y_train, y_test = train_test_split(
        X, y, train_size=0.80)
lgd_train = lgb.Dataset(X_train, label=y_train)
params = {"objective": "multiclass",
          "num_class": 3}
clf = lgb.train(params, lgd_train)   # native API: returns a Booster, not an sklearn estimator
# plot partial dependence
pdp_dist = pdp.pdp_isolate(
            model=clf, dataset=X_train, model_features=X_train.columns, feature='petal_width'
        )
pdp.pdp_plot(pdp_dist, 'petal_width')

This may be related to the predict output, which is probably one column per iris class (3 in total), but I don't know how to fix it.
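The suspicion above can be checked without PDPbox. Partial dependence averages a model's predictions while one feature is held at each grid value, so a model whose predict returns one probability column per class yields one curve per class instead of one. A minimal numpy sketch: `partial_dependence` and `toy_multiclass_predict` are hypothetical names, and the toy softmax model only stands in for a multiclass `lgb.Booster`, whose `predict` returns an array of shape (n_samples, num_class).

```python
import numpy as np

def partial_dependence(predict_fn, X, feature_idx, grid):
    """Brute-force partial dependence: for each grid value, overwrite one
    feature in every row and average the model's predictions.

    If predict_fn returns shape (n_samples,), the result is (len(grid),);
    if it returns (n_samples, n_classes), the result is (len(grid), n_classes).
    """
    X = np.asarray(X, dtype=float)
    curves = []
    for value in grid:
        X_mod = X.copy()
        X_mod[:, feature_idx] = value            # hold the feature fixed at this value
        curves.append(np.asarray(predict_fn(X_mod)).mean(axis=0))
    return np.array(curves)

# Hypothetical stand-in for a 3-class model (not LightGBM): softmax over
# three scores that depend only on feature 0.
def toy_multiclass_predict(X):
    scores = np.column_stack([-X[:, 0], np.zeros(len(X)), X[:, 0]])
    exp = np.exp(scores - scores.max(axis=1, keepdims=True))
    return exp / exp.sum(axis=1, keepdims=True)

X_demo = np.random.default_rng(0).normal(size=(50, 4))
grid = np.linspace(-2, 2, 5)
pd_curves = partial_dependence(toy_multiclass_predict, X_demo, feature_idx=0, grid=grid)
print(pd_curves.shape)  # (5, 3): three curves, one per class
```

A tool that expects a single prediction column per grid point will therefore receive three, which matches the "3" in the error message.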

If I use the other LightGBM syntax, lgb.LGBMClassifier().fit(X_train, y_train), it keeps returning this error:

[LightGBM] [Fatal] Do not support special JSON characters in feature name.

even though my data definitely contains no special characters. Is there a way around it? Thanks.
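To double-check that the column names really contain no special characters, you can list the offending characters per column. A minimal sketch: `offending_feature_names` is a hypothetical helper, and the rejected set (`"`, `,`, `:`, `[`, `]`, `{` and `}`) is my assumption of what LightGBM's JSON-related feature-name check refuses; confirm it against the LightGBM version you have installed.

```python
# Assumption: characters LightGBM rejects in feature names (verify for your version).
JSON_SPECIAL_CHARS = set('",:[]{}')

def offending_feature_names(columns):
    """Return {column: sorted rejected characters} for every column whose
    string form contains one of JSON_SPECIAL_CHARS."""
    bad = {}
    for col in columns:
        hits = sorted(set(str(col)) & JSON_SPECIAL_CHARS)
        if hits:
            bad[col] = hits
    return bad

# In practice pass X_train.columns; a plain list is used here for illustration.
print(offending_feature_names(["petal_width", "petal:length", "sepal [cm]"]))
# {'petal:length': [':'], 'sepal [cm]': ['[', ']']}
```

An empty dict means none of the assumed characters is present, which would point to a version-specific issue rather than the data.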


Tags: fromtestimportmodelplotasloadtrain
2 Answers

If you switch to LightGBM's sklearn API, the error message goes away. That is the API the PDPbox docs ask for:

model: a fitted sklearn model

Proof:

import lightgbm as lgb
import pandas as pd
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from pdpbox import pdp, get_dataset, info_plots

# load some data
df = sns.load_dataset("iris")
X = df.iloc[:, :4]                          # the four numeric feature columns
y, mapping = pd.factorize(df["species"])    # species names -> integer labels 0, 1, 2

X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.80)

# Native-API objects; LGBMClassifier below does not use them
lgd_train = lgb.Dataset(X_train, label=y_train)
params = {"objective": "multiclass",
          "num_class": 3}

clf = lgb.LGBMClassifier()  # <  Choose sklearn API !!!
clf.fit(X_train, y_train)

pdp_dist = pdp.pdp_isolate(model=clf, dataset=X_train,
                           model_features=X_train.columns,
                           feature='petal_width')
pdp.pdp_plot(pdp_dist, 'petal_width');
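If you would rather keep the native `lgb.train` workflow than switch APIs as recommended above, one possible workaround is a thin wrapper that gives the Booster the sklearn-style surface (`classes_`, `predict_proba`, `predict`). This is a hypothetical sketch, not part of LightGBM or PDPbox: `MulticlassBoosterAdapter` is an invented name, and it assumes PDPbox decides how to call the model by looking for these attributes, which I have not verified across PDPbox versions. Switching to `LGBMClassifier` remains the simpler fix.

```python
import numpy as np

class MulticlassBoosterAdapter:
    """Hypothetical wrapper exposing sklearn-style attributes on a
    multiclass lgb.Booster (or any object whose predict returns an
    (n_samples, n_classes) probability array)."""

    def __init__(self, booster, n_classes):
        self.booster = booster
        self.classes_ = np.arange(n_classes)   # labels 0..n_classes-1, as produced by pd.factorize
        self.n_classes_ = n_classes

    def predict_proba(self, X):
        # For objective="multiclass", Booster.predict already returns
        # class probabilities of shape (n_samples, num_class).
        return np.asarray(self.booster.predict(X))

    def predict(self, X):
        # Predicted label = class with the highest probability.
        return self.classes_[self.predict_proba(X).argmax(axis=1)]
```

Usage, under the same assumption: `pdp.pdp_isolate(model=MulticlassBoosterAdapter(booster, n_classes=3), dataset=X_train, model_features=X_train.columns, feature='petal_width')`, where `booster` is the object returned by `lgb.train(params, lgd_train)`.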

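As for your other error, I can't reproduce it on my machine with LightGBM v2.3.1 installed from conda-forge.

So I believe you have two options:

  • install a LightGBM version that doesn't have this issue, or
  • replace every non-alphanumeric character in the column names: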
X_train.columns = ["".join (c if c.isalnum() else "_" for c in str(x)) for x in X_train.columns]

as suggested here.
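The renaming one-liner from the second option can be wrapped in a small helper that also guards against a side effect: two different names can collapse into the same sanitized name (for example `a b` and `a-b` both become `a_b`). A minimal sketch; `sanitize_feature_names` is a hypothetical name, and the duplicate check is my addition, not part of the original suggestion.

```python
def sanitize_feature_names(columns):
    """Replace every non-alphanumeric character with "_" (the same rule as
    the one-liner above). str() also handles non-string column labels."""
    new = ["".join(c if c.isalnum() else "_" for c in str(x)) for x in columns]
    if len(set(new)) != len(new):
        raise ValueError("sanitizing produced duplicate feature names: %r" % new)
    return new

print(sanitize_feature_names(["sepal length (cm)", "petal_width", 3]))
# ['sepal_length__cm_', 'petal_width', '3']
```

Note that `str.isalnum()` is Unicode-aware, so accented letters such as `é` are kept rather than replaced.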

Run this at the start of your work to strip special characters from the column names:

import re
# delete every run of characters other than letters, digits and underscores
df = df.rename(columns=lambda x: re.sub(r'[^A-Za-z0-9_]+', '', str(x)))
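Unlike the `isalnum()` approach in the other answer, this regex deletes the offending characters instead of replacing them with `_`, so words run together and accidental name collisions become more likely. A minimal demonstration on sklearn-style iris column names (the small DataFrame here is illustrative data, not from the question):

```python
import re
import pandas as pd

df_demo = pd.DataFrame([[5.1, 3.5]], columns=["sepal length (cm)", "sepal width (cm)"])
# Same pattern as above: drop every run of characters outside [A-Za-z0-9_].
df_demo = df_demo.rename(columns=lambda x: re.sub(r"[^A-Za-z0-9_]+", "", str(x)))
print(list(df_demo.columns))  # ['sepallengthcm', 'sepalwidthcm']
```

If readable names matter, substituting `"_"` for `""` in `re.sub` keeps word boundaries.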

相关问题 更多 >