如何提高这个小距离Python函数的性能

2024-03-28 19:58:56 发布

您现在位置:Python中文网/ 问答频道 /正文

我在使用来自sklearn的聚类算法的自定义距离度量函数时遇到了性能瓶颈。在

Run Snake Run显示的结果如下:

enter image description here

显然,问题在于dbscan_metric函数。这个函数看起来非常简单,我不太清楚加速它的最佳方法是:

def dbscan_metric(a,b):
  if a.shape[0] != NUM_FEATURES:
    return np.linalg.norm(a-b)
  else:
    return np.linalg.norm(np.multiply(FTR_WEIGHTS, (a-b)))

任何关于是什么导致它如此缓慢的想法都将是非常感激的。在


Tags: 函数run算法距离normreturn度量np
1条回答
网友
1楼 · 发布于 2024-03-28 19:58:56

我不熟悉这个函数的作用,但是是否有重复计算的可能?如果是这样,您可以记住函数:

cache = {}
def dbscan_metric(a,b):

  diff = a - b

  if a.shape[0] != NUM_FEATURES:
    to_calc = diff
  else:
    to_calc = np.multiply(FTR_WEIGHTS, diff)

  if not cache.get(to_calc): cache[to_calc] = np.linalg.norm(to_calc)

  return cache[to_calc]

相关问题 更多 >