使用scikit-learn获取错误分类的文档

11 投票
2 回答
10459 浏览
提问于 2025-04-18 18:51

我想知道在scikitlearn这个Python模块里,是否有内置的函数可以用来找出被错误分类的文档。

其实这很简单,我通常是自己写代码,通过比较预测结果和测试向量,从测试文档数组中找出那些文档。但我在想,是否有现成的功能可以用,而不是每次写Python代码时都要重复这个功能。

2 个回答

0

你可以用列表推导式来获取那些被错误分类的东西。除此之外,我不知道在sklearn中还有其他方法可以做到这一点。

from sklearn.cross_validation import train_test_split
from sklearn import datasets
from sklearn import svm


iris = datasets.load_iris()
X_iris, y_iris = iris.data, iris.target
X, y = X_iris[:, :2], y_iris
X_train, X_test, y_train, y_test = train_test_split(X, y)

clf = svm.LinearSVC()
clf.fit(X_train, y_train)

mis_cls = [train 
           for test, truth, train in 
           zip(X_test, y_test, X_train) 
           if clf.predict(test) != truth]
16

如果你有一组真实标签 y_test,比如说 ["ham", "spam", "spam", "ham"],然后你把它转换成一个NumPy数组,那么你可以用一行代码来和预测结果进行比较:

import numpy as np

y_test = np.asarray(y_test)
misclassified = np.where(y_test != clf.predict(X_test))

现在 misclassified 就是一个包含 X_test 中索引的数组。

@eickenberg 说得对,这种功能在scikit-learn里没有实现,因为用户应该对NumPy有足够的了解,能够用几行代码自己完成这个操作。

撰写回答