使用scikit-learn获取错误分类的文档
我想知道在scikitlearn这个Python模块里,是否有内置的函数可以用来找出被错误分类的文档。
其实这很简单,我通常是自己写代码,通过比较预测结果和测试向量,从测试文档数组中找出那些文档。但我在想,是否有现成的功能可以用,而不是每次写Python代码时都要重复这个功能。
2 个回答
0
你可以用列表推导式来获取那些被错误分类的东西。除此之外,我不知道在sklearn中还有其他方法可以做到这一点。
from sklearn.cross_validation import train_test_split
from sklearn import datasets
from sklearn import svm
iris = datasets.load_iris()
X_iris, y_iris = iris.data, iris.target
X, y = X_iris[:, :2], y_iris
X_train, X_test, y_train, y_test = train_test_split(X, y)
clf = svm.LinearSVC()
clf.fit(X_train, y_train)
mis_cls = [train
for test, truth, train in
zip(X_test, y_test, X_train)
if clf.predict(test) != truth]
16
如果你有一组真实标签 y_test
,比如说 ["ham", "spam", "spam", "ham"]
,然后你把它转换成一个NumPy数组,那么你可以用一行代码来和预测结果进行比较:
import numpy as np
y_test = np.asarray(y_test)
misclassified = np.where(y_test != clf.predict(X_test))
现在 misclassified
就是一个包含 X_test
中索引的数组。
@eickenberg 说得对,这种功能在scikit-learn里没有实现,因为用户应该对NumPy有足够的了解,能够用几行代码自己完成这个操作。