有没有办法让这个Python kNN函数更高效?

3 投票
2 回答
2077 浏览
提问于 2025-04-27 21:46

在使用MATLAB时遇到了一些麻烦,所以我决定试试Python:

我写了一个函数,用自己的距离计算方法来计算kNN,当样本是我自己定义的类时:

def closestK(sample, otherSamples, distFunc, k):
"Returns the closest k samples to sample based on distFunc"
    n = len(otherSamples)
    d = [distFunc(sample, otherSamples[i]) for i in range(0,n)]
    idx  = sorted(range(0,len(d)), key=lambda k: d[k])
    return idx[1:(k+1)]

def kNN(samples, distFunc, k):
    return [[closestK(samples[i], samples, distFunc, k)] for i in range(len(samples))]

这是我用来计算距离的函数:

@staticmethod    
def distanceRepr(c1, c2):
    r1 = c1.repr
    r2 = c2.repr
    # because cdist needs 2D array
    if r1.ndim == 1:
        r1 = np.vstack([r1,r1])
    if r2.ndim == 1:
        r2 = np.vstack([r2,r2])

    return scipy.spatial.distance.cdist(r1, r2, 'euclidean').min()

但是和“正常”的kNN函数相比,它的运行速度还是慢得惊人,即使使用“暴力”算法。难道我做错了什么吗?

更新

我添加了这个类的构造函数。属性repr包含了一组向量(从1到任意值),距离是计算这两个repr集合之间的最小欧几里得距离。

class myCluster:
    def __init__(self, index = -1, P = np.array([])):
        if index ==-1 :
            self.repr = np.array([])
            self.IDs = np.array([])
            self.n = 0
            self.center = np.array([])
        else:
            self.repr = np.array(P)
            self.IDs = np.array(index)
            self.n = 1
            self.center = np.array(P)

还有其他相关的代码(X是一个矩阵,行是样本,列是变量):

level = [myCluster(i, X[i,:]) for i in range(0,n)]
kNN(level, myCluster.distanceRepr, 3)

更新 2

我做了一些测量,发现耗时最长的那行是

d = [distFunc(sample, otherSamples[i]) for i in range(0,n)]

所以问题出在distFunc上。当我把它改成返回

np.linalg.norm(c1.repr-c2.repr)

也就是“正常”的向量计算,没有排序时,运行时间保持不变。所以问题出在调用这个函数上。使用类会让运行时间增加60倍,这合理吗?

暂无标签

2 个回答

0

我想到了一些要点:

  • 每次你调用closestK的时候,都是在计算一个样本和其他每个样本之间的距离,这样就会重复计算每对样本的距离(比如先计算距离(a,b),再计算距离(b,a)),其实可以只计算一次,省去重复的工作。
  • 你在计算r的时候,可能需要进行比较耗时的vstack操作,这个过程会重复进行2 * (n - 1)次,其中n是样本的数量。你也可以只计算一次,然后把结果存储起来,作为myCluster的一个属性。
  • 你对整个列表进行了排序,但其实你只需要前k个元素(在k个元素之后的就没必要排序了)。
  • 为了计算你集合中点之间的最小距离,你创建了一个包含所有距离的矩阵,然后找出最小值:其实可以有更好的方法。

我的建议是实现一个top-k类,里面有一个insert方法,只有在新元素比当前第k个元素更好时才插入(并且把第k个元素移除)。同时修改myCluster以包含r。这样你的代码可能看起来像:

kNN = {i : TopK() for i in xrange(len(samples))}
for i, sample1 in enumerate(samples):
    for j, sample2 in enumerate(samples[:i]):
        dist = distanceRepr(sample1, sample2)
        kNN[i].insert(j, -dist)
        kNN[j].insert(i, -dist)
return kNN

这里是一个可能的TopK实现:

import heapq

class TopK:
    def __init__(self, k):
        self.k = k
        self.content = []

    def insert (self, key, score):
        if len(self.content) < self.k:
            heapq.heappush(self.content, (score, key))
        else:
            heapq.heappushpop(self.content, (score, key))

    def get_keys(self):
        return [elem[1] for elem in self.content]

对于distanceRepr,你可以使用类似这样的东西:

import scipy.spatial

def distanceRepr(set0 ,set1):
    if len(set0) < len(set1):
        min_set = set0
        max_set = set1
    else:
        min_set = set1
        max_set = set0
    if len(min_set) == 0:
        raise Exception("Empty set")

    min_dist = scipy.inf
    tree = scipy.spatial.cKDTree(max_set)

    for point in min_set:
        distance, _ = tree.query(point, 1, 0., 2, min_dist)
        if min_dist > distance:
            min_dist = min(min_dist, distance)

    return min_dist

对于中等和大型数据来说,这种方法会比你现在的方法快(比如样本1和样本2的大小都超过5000),而且内存使用会小得多,这样就能处理更大的样本(而cdit可能会因为内存不足而无法工作)。

2

你遇到的问题是Python运行得比较慢(其实是CPython这个解释器慢)。根据维基百科的说法:

NumPy是针对CPython这个Python的参考实现而设计的,它是一个不进行优化的字节码编译器/解释器。用这个版本的Python写的数学算法通常运行得比编译过的代码慢很多。NumPy通过提供多维数组和高效操作数组的函数与运算符来解决这个问题。因此,任何主要以数组和矩阵操作为基础的算法,运行速度几乎可以和相应的C代码一样快。

还有来自Scipy常见问题解答的内容:

Python的列表是高效的通用容器。它们支持(相对)高效的插入、删除、追加和连接,并且Python的列表推导式让它们的构造和操作变得简单。然而,它们也有一些限制:不支持“向量化”的操作,比如逐元素的加法和乘法。而且,由于列表可以包含不同类型的对象,Python必须为每个元素存储类型信息,并在操作每个元素时执行类型调度代码。这也意味着很少有列表操作可以通过高效的C循环来完成——每次迭代都需要进行类型检查和其他Python API的管理。

注意,这不仅仅是Python的问题;想了解更多背景,可以参考这个问题这个问题

由于动态类型系统和解释器带来的开销,如果没有办法利用各种编译的C和Fortran库(比如Numpy),Python在高性能数值计算方面就会显得不太有用。此外,还有一些JIT编译器,比如Numba和PyPy,试图让Python代码的执行速度接近静态类型编译代码的速度。

总的来说:你在用普通的Python做的事情太多了,相比之下你交给快速的C代码的工作量不够。我想你需要采用更“数组导向”的编码风格,而不是面向对象的风格,这样才能在使用NumPy时获得更好的性能(在这方面,MATLAB的情况也很相似)。另一方面,如果你使用更高效的算法(可以参考Ara的回答),那么Python的慢速可能就不是个大问题了。

撰写回答