使用K近邻聚类色彩的高效方法

2 投票
3 回答
1065 浏览
提问于 2025-04-18 12:51

我正在尝试把一张图片里的颜色分成几个预定义的类别(黑色、白色、蓝色、绿色、红色)。我用的代码如下:

import numpy as np
import cv2

src = cv2.imread('objects.png')

colors = np.array([[0x00, 0x00, 0x00],
                   [0xff, 0xff, 0xff],
                   [0xff, 0x00, 0x00],
                   [0x00, 0xff, 0x00],
                   [0x00, 0x00, 0xff]], dtype=np.float32)
classes = np.array([[0], [1], [2], [3], [4]], np.float32)
dst = np.zeros(src.shape, np.float32)

knn = cv2.KNearest()
knn.train(colors, classes)

# This loop is very inefficient!
for i in range(0, src.shape[0]):
    for j in range(0, src.shape[1]):
        sample = np.reshape(src[i,j], (-1,3)).astype(np.float32)
        retval, result, neighbors, dist = knn.find_nearest(sample, 1)
        dst[i,j] = colors[result[0,0]]

cv2.imshow('src', src)
cv2.imshow('dst', dst)
cv2.waitKey()

这段代码运行得不错,结果如下。左边的图片是输入,右边的图片是输出。

src dst

不过,上面的循环效率很低,导致转换速度慢。有没有什么更高效的Numpy操作可以替代这个循环呢?

3 个回答

0

你可以建立一个查找表,这样就能知道每种颜色对应的类别是什么。这个表格不一定要是256x256x256那么大,你可以减少一些分类的数量。

2

如果你想要一个简单的平方差度量(也就是找出最接近的数字),这个方法可以用。

首先,计算差异:

diff = ((src[:,:,:,None] - colors.T)**2).sum(axis=2)

(假设 src 的形状是 y,x,3)

接下来,选择最接近的颜色索引:

index = diff.argmin(axis=2)

然后生成新图像:

out = colors[index]

如果你的颜色值确实是 0 或 0xff,你可以使用类似下面的代码:

out = np.where(src>0x88, 0xff, 0)
0

我成功地用下面的代码去掉了循环。这个代码运行得非常快,几乎和C++版本一样快。

import numpy as np
import cv2

src = cv2.imread('objects.png')
src_flatten = np.reshape(np.ravel(src, 'C'), (-1, 3))
dst = np.zeros(src.shape, np.float32)

colors = np.array([[0x00, 0x00, 0x00],
                   [0xff, 0xff, 0xff],
                   [0xff, 0x00, 0x00],
                   [0x00, 0xff, 0x00],
                   [0x00, 0x00, 0xff]], dtype=np.float32)
classes = np.array([[0], [1], [2], [3], [4]], np.float32)

knn = cv2.KNearest()
knn.train(colors, classes)
retval, result, neighbors, dist = knn.find_nearest(src_flatten.astype(np.float32), 1)

dst = colors[np.ravel(result, 'C').astype(np.uint8)]
dst = dst.reshape(src.shape).astype(np.uint8)

cv2.imshow('src', src)
cv2.imshow('dst', dst)
cv2.waitKey()

这个代码生成的结果和之前一样正确,而且执行时间更快了。

src dst

撰写回答