在scikit-learn中实现每个对象有3个特征的K近邻分类器

6 投票
1 回答
15253 浏览
提问于 2025-04-17 13:32

我想用scikit-learn模块来实现一个KNeighborsClassifier(K近邻分类器)(http://scikit-learn.org/dev/modules/generated/sklearn.neighbors.KNeighborsClassifier.html)。

我从我的图像中提取了物体的形状、拉伸度和Hu矩特征。那我该如何准备这些数据用于训练和验证呢?我需要为每个从图像中提取的物体创建一个包含这三个特征[Hm, e, s]的列表(因为一张图像可能有多个物体)吗?

我看了这个例子(http://scikit-learn.org/dev/modules/generated/sklearn.neighbors.KNeighborsClassifier.html):

X = [[0], [1], [2], [3]]
y = [0, 0, 1, 1]
from sklearn.neighbors import KNeighborsClassifier
neigh = KNeighborsClassifier(n_neighbors=3)
neigh.fit(X, y) 

print(neigh.predict([[1.1]]))
print(neigh.predict_proba([[0.9]]))

X和y是两个特征吗?

samples = [[0., 0., 0.], [0., .5, 0.], [1., 1., .5]]
from sklearn.neighbors import NearestNeighbors
neigh = NearestNeighbors(n_neighbors=1)
neigh.fit(samples) 

print(neigh.kneighbors([1., 1., 1.])) 

为什么在第一个例子中使用X和y,而现在用sample?

1 个回答

13

你第一段代码定义了一个用于一维数据的分类器。

X 代表特征向量。

[0] is the feature vector of the first data example
[1] is the feature vector of the second data example
....
[[0],[1],[2],[3]] is a list of all data examples, 
  each example has only 1 feature.

y 代表标签。

下面的图展示了这个概念:

这里输入图片描述

  • 绿色节点是标签为0的数据
  • 红色节点是标签为1的数据
  • 灰色节点是标签未知的数据。
    print(neigh.predict([[1.1]]))

这里是在让分类器预测 x=1.1 的标签。

    print(neigh.predict_proba([[0.9]]))

这里是在让分类器给出每个标签的归属概率估计。

因为两个灰色节点离绿色节点比较近,所以下面的输出是合理的。

    [0] # green label
    [[ 0.66666667  0.33333333]]  # green label has greater probability

第二段代码实际上对 scikit-learn 有很好的说明:

在下面的例子中,我们从一个数组构建了一个 NeighborsClassifier 类,并询问谁是离 [1,1,1] 最近的点。

>>> samples = [[0., 0., 0.], [0., .5, 0.], [1., 1., .5]]
>>> from sklearn.neighbors import NearestNeighbors
>>> neigh = NearestNeighbors(n_neighbors=1)
>>> neigh.fit(samples) 
NearestNeighbors(algorithm='auto', leaf_size=30, ...)
>>> print(neigh.kneighbors([1., 1., 1.])) 
(array([[ 0.5]]), array([[2]]...))

这里没有目标值,因为这只是一个 NearestNeighbors 类,它不是分类器,所以不需要标签。

对于你自己的问题:

因为你需要一个分类器,如果想使用 KNN 方法,应该使用 KNeighborsClassifier。你可能想像下面这样构建你的特征向量 X 和标签 y

X = [ [h1, e1, s1], 
      [h2, e2, s2],
      ...
    ]
y = [label1, label2, ..., ]

撰写回答