sklearn 逻辑回归 - 重要特征

16 投票
3 回答
44569 浏览
提问于 2025-04-18 10:01

我很确定这个问题之前有人问过,但我找不到答案。

我在Python中使用sklearn运行逻辑回归,能够通过Transform方法将我的数据集转化为最重要的特征。

classf = linear_model.LogisticRegression()
func  = classf.fit(Xtrain, ytrain)
reduced_train = func.transform(Xtrain)

我怎么才能知道哪些特征被选为最重要的呢?更一般来说,我怎么能计算出数据集中每个特征的p值呢?

3 个回答

4

LogisticRegression.transform 这个方法需要一个 threshold 值,用来决定哪些特征(数据的不同方面)要保留。下面是文档中的说明:

阈值:字符串、浮点数或 None,选填(默认是 None) 这个阈值用来选择特征。重要性大于或等于这个值的特征会被保留,而其他的则会被丢弃。如果设置为“median”(中位数)或“mean”(平均数),那么阈值就会是特征重要性的中位数或平均数。你也可以使用一个缩放因子,比如“1.25*mean”。如果设置为 None,并且有可用的对象属性 threshold,那么就会使用这个属性。否则,默认会使用“mean”。

在 LR 估计器中没有 threshold 这个属性,所以默认情况下,只会保留那些绝对值大于平均值的特征(在对所有类别求和之后)。

4

你可以查看拟合模型中的 coef_ 属性,来了解哪些特征是最重要的。(对于LogisticRegression来说,transform 只是用来查看哪些系数的绝对值最大。)

大多数scikit-learn模型并没有提供计算p值的方法。一般来说,这些模型的设计目的是用来实际预测结果,而不是用来分析预测的过程。如果你对p值感兴趣,可以看看 statsmodels,不过它的成熟度比sklearn稍低一些。

15

正如上面评论中提到的,你可以(而且应该)在进行模型训练之前对数据进行缩放,这样可以让模型的系数变得可比。下面是一段小代码,展示了这个过程是如何工作的。我遵循了这个格式来进行比较。

import numpy as np    
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
import pandas as pd
import matplotlib.pyplot as plt

x1 = np.random.randn(100)
x2 = np.random.randn(100)
x3 = np.random.randn(100)

#Make difference in feature dependance
y = (3 + x1 + 2*x2 + 5*x3 + 0.2*np.random.randn()) > 0

X = pd.DataFrame({'x1':x1,'x2':x2,'x3':x3})

#Scale your data
scaler = StandardScaler()
scaler.fit(X) 
X_scaled = pd.DataFrame(scaler.transform(X),columns = X.columns)

clf = LogisticRegression(random_state = 0)
clf.fit(X_scaled, y)

feature_importance = abs(clf.coef_[0])
feature_importance = 100.0 * (feature_importance / feature_importance.max())
sorted_idx = np.argsort(feature_importance)
pos = np.arange(sorted_idx.shape[0]) + .5

featfig = plt.figure()
featax = featfig.add_subplot(1, 1, 1)
featax.barh(pos, feature_importance[sorted_idx], align='center')
featax.set_yticks(pos)
featax.set_yticklabels(np.array(X.columns)[sorted_idx], fontsize=8)
featax.set_xlabel('Relative Feature Importance')

plt.tight_layout()   
plt.show()

撰写回答