Python中计算余弦距离的优化方法
我写了一个方法,用来计算两个数组之间的余弦距离:
def cosine_distance(a, b):
if len(a) != len(b):
return False
numerator = 0
denoma = 0
denomb = 0
for i in range(len(a)):
numerator += a[i]*b[i]
denoma += abs(a[i])**2
denomb += abs(b[i])**2
result = 1 - numerator / (sqrt(denoma)*sqrt(denomb))
return result
在处理大数组时,这个方法运行起来可能会很慢。有没有更快的优化版本呢?
更新:我尝试了到目前为止所有的建议,包括使用scipy。下面是一个改进后的版本,结合了Mike和Steve的建议:
def cosine_distance(a, b):
if len(a) != len(b):
raise ValueError, "a and b must be same length" #Steve
numerator = 0
denoma = 0
denomb = 0
for i in range(len(a)): #Mike's optimizations:
ai = a[i] #only calculate once
bi = b[i]
numerator += ai*bi #faster than exponent (barely)
denoma += ai*ai #strip abs() since it's squaring
denomb += bi*bi
result = 1 - numerator / (sqrt(denoma)*sqrt(denomb))
return result
8 个回答
如果你要对 a[i]
和 b[i]
进行平方运算,就不需要先取它们的绝对值。
可以把 a[i]
和 b[i]
存放在临时变量里,这样就不用重复索引了。虽然编译器可能会优化这个过程,但也不一定。
看看 **2
这个运算符。它是把平方简化成乘法,还是在用一个通用的幂函数(先取对数,再乘以2,最后再取反对数)?
不要重复计算平方根(虽然这样做的成本很小)。可以直接计算 sqrt(denoma * denomb)
。
(我最开始认为)如果不使用C语言(像numpy或scipy那样)或者不改变计算方式,是很难大幅提高速度的。不过,不管怎样,我会尝试这样做:
from itertools import imap
from math import sqrt
from operator import mul
def cosine_distance(a, b):
assert len(a) == len(b)
return 1 - (sum(imap(mul, a, b))
/ sqrt(sum(imap(mul, a, a))
* sum(imap(mul, b, b))))
在Python 2.6中,处理50万个元素的数组时,速度大约快了两倍。(这是在把map改成imap之后,参考了Jarret Hardie的建议。)
这是对原作者修改后代码的一个调整版本:
from itertools import izip
def cosine_distance(a, b):
assert len(a) == len(b)
ab_sum, a_sum, b_sum = 0, 0, 0
for ai, bi in izip(a, b):
ab_sum += ai * bi
a_sum += ai * ai
b_sum += bi * bi
return 1 - ab_sum / sqrt(a_sum * b_sum)
虽然看起来不太好,但确实快了不少……
补充:还可以试试Psyco!它能让最终版本的速度再提高4倍。我怎么会忘记呢?
如果你可以使用SciPy库,那么可以用spatial.distance
里的cosine
函数:
http://docs.scipy.org/doc/scipy/reference/spatial.distance.html
如果你不能使用SciPy,试着改写你的Python代码,可能会稍微快一点(编辑:但结果并没有我想的那么好,见下文)。
from itertools import izip
from math import sqrt
def cosine_distance(a, b):
if len(a) != len(b):
raise ValueError, "a and b must be same length"
numerator = sum(tup[0] * tup[1] for tup in izip(a,b))
denoma = sum(avalue ** 2 for avalue in a)
denomb = sum(bvalue ** 2 for bvalue in b)
result = 1 - numerator / (sqrt(denoma)*sqrt(denomb))
return result
当a和b的长度不匹配时,最好抛出一个异常。
通过在sum()
函数中使用生成器表达式,你可以让大部分计算工作由Python内部的C代码来完成,这样应该比用for
循环快。
我没有测量过这个,所以不知道具体快多少。但SciPy的代码几乎肯定是用C或C++写的,速度应该是最快的。
如果你在Python中做生物信息学,真的应该使用SciPy。
编辑:Darius Bacon测量了我的代码,发现它更慢。所以我也测了一下,确实是慢。大家要记住:想要加速的时候,不要猜,得测量。
我很困惑,为什么我试图把更多工作交给Python的C内部反而变慢。我试了长度为1000的列表,结果还是慢。
我不能再花时间去聪明地修改Python了。如果你需要更快的速度,建议你试试SciPy。
编辑:我刚手动测试了一下,没有用timeit。我发现对于短的a和b,旧代码更快;对于长的a和b,新代码更快;不过两者的差别都不大。(我现在在想我能不能信任Windows上的timeit;我想在Linux上再试一次。)我不会为了追求速度而去改动已经能用的代码。再一次,我强烈建议你试试SciPy。:-)