如何使用scikit-learn的CountVectorizer根据索引识别决策树中的特征名?

0 投票
2 回答
38 浏览
提问于 2025-04-14 16:56

我有一些数据,用来训练一个模型,目的是判断一句话是关于:

  • 猫或狗
  • 不是关于猫或狗

包含文本列和标签列的数据截图

我运行了以下代码来训练一个 DecisionTreeClassifier() 模型,然后查看树的可视化效果:

import numpy as np
from numpy.random import seed
import random as rn
import os
import pandas as pd
seed_num = 1
os.environ['PYTHONHASHSEED'] = '0'
np.random.seed(seed_num)
rn.seed(seed_num)

from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree

dummy_train = pd.read_csv('dummy_train.csv')

tree_clf = tree.DecisionTreeClassifier()

X_train = dummy_train["text"]
y_train = dummy_train["label"]

dt_tree_pipe = Pipeline([('vect', CountVectorizer(ngram_range=(1,1),
                                                 binary=True)),
                     ('tfidf', TfidfTransformer(use_idf=False)),
                      ('clf', DecisionTreeClassifier(random_state=seed_num,
                                                 class_weight={0:1, 1:1})),
                   ])

tree_model_fold_1 = dt_tree_pipe.fit(X_train, y_train)

tree.plot_tree(dt_tree_pipe["clf"])

...得到的树如下:

决策树可视化的截图

第一节点检查 x[7] 是否小于或等于 0.177我该如何找出 x[7] 代表哪个词呢?

我尝试了以下代码,但输出的词(“describing”和“the”)看起来不太对。我本以为 'cat''dog' 是用来将数据分成正类和负类的两个词。

vect_from_pipe = dt_tree_pipe["vect"]
words = vect_from_pipe.vocabulary_.keys()
print(list(words)[7])
print(list(words)[5])

显示' descripting '和'the'的截图

2 个回答

1

这个 vocabulary_ 属性里的内容并不是按顺序排列的;实际上,这个字典里的值告诉你特征的索引位置:

vocabulary_ : 字典
这是一个术语到特征索引的映射。

因为我们已经对树中的两个特征有了比较清晰的认识,你可以直接检查 vect_from_pipe.vocabulary_['cat'], vect_from_pipe.vocabulary_['dog'],看看它们的值是不是 5 和 7。如果不是,那你就需要反向查找这个字典,找出值为 5 和 7 的键是什么。不过,更简单的方法是直接用 vect_from_pipe.get_feature_names_out(),然后查看第 5 和第 7 个索引的内容。实际上,在 plot_tree 中使用这个方法是非常常见的:

tree.plot_tree(
    dt_tree_pipe[-1],
    feature_names = df_tree_pipe[:-1].get_feature_names_out(),
)
1

scikit-learn 中,你要找的术语是 特征名称。这些特征名称就是在进行任何转换之前的输入。

在你的代码中,你正在访问 CountVectorizervocabulary_ 属性,这个属性会返回一个字典,字典的键是单词,值是它们的索引。当你把这些键转换成列表并访问第7个或第5个元素时,这并不一定对应于特征矩阵中第7个或第5个索引的单词。

如果你想获取某个特定索引对应的特征名称(单词),你应该使用 get_feature_names_out() 方法,这个方法会返回一个按特征矩阵中索引顺序排列的特征名称列表。

可以用下面的代码替换:

vect_from_pipe = dt_tree_pipe["vect"]
feature_names = vect_from_pipe.get_feature_names_out()
print(feature_names[7])
print(feature_names[5])

这样会打印出在你的特征矩阵中索引7和5对应的单词。索引7的单词是在你的决策树第一次分裂时使用的。所以在你的情况下,x[7] 在决策树中对应于 CountVectorizer 中的 feature_names[7]

撰写回答