距离度量的组合优化
我有一组轨迹,这些轨迹是由沿着轨迹的点组成的,每个点都有对应的坐标。我把这些数据存储在一个三维数组里,结构是(轨迹,点,参数)。我想找到一组r条轨迹,这些轨迹之间的所有可能组合的总距离是最大的。我的第一次尝试,看起来是这样的:
max_dist = 0
for h in itertools.combinations ( xrange(num_traj), r):
for (m,l) in itertools.combinations (h, 2):
accum = 0.
for ( i, j ) in itertools.izip ( range(k), range(k) ):
A = [ (my_mat[m, i, z] - my_mat[l, j, z])**2 \
for z in xrange(k) ]
A = numpy.array( numpy.sqrt (A) ).sum()
accum += A
if max_dist < accum:
selected_trajectories = h
这个方法运行得非常慢,因为轨迹的数量(num_traj)可能在500到1000之间,而r的值可能在5到20之间。k的值是随意的,但通常可以达到50。
为了让代码看起来更聪明,我把所有的内容放进了两个嵌套的列表推导式里,并且大量使用了itertools库:
chunk = [[ numpy.sqrt((my_mat[m, i, :] - my_mat[l, j, :])**2).sum() \
for ((m,l),i,j) in \
itertools.product ( itertools.combinations(h,2), range(k), range(k)) ]\
for h in itertools.combinations(range(num_traj), r) ]
除了代码看起来非常难懂(!!!)之外,它的运行速度也很慢。有没有人能给我一些建议,帮我改进这个方法呢?
5 个回答
这里有一些额外的建议和有趣的观点,除了大家提到的内容。(顺便说一下,mathmike提到的生成所有对之间距离的查找列表的建议,你应该立刻实施。这可以让你的算法复杂度减少到O(r^2))
首先,下面这几行
for ( i, j ) in itertools.izip ( range(k), range(k) ):
A = [ (my_mat[m, i, z] - my_mat[l, j, z])**2 \
for z in xrange(k) ]
可以用下面的代码替换
for i in xrange(k):
A = [ (my_mat[m, i, z] - my_mat[l, i, z])**2 \
for z in xrange(k) ]
因为在每次循环中,i和j都是相同的。这里根本不需要用到izip。
其次,关于这一行
A = numpy.array( numpy.sqrt (A) ).sum()
你确定这是你想要的计算方式吗?可能是这样,但我觉得有点奇怪,因为如果这是在计算向量之间的欧几里得距离,那么这一行应该是:
A = numpy.sqrt (numpy.array( A ).sum())
或者直接用
A = numpy.sqrt(sum(A))
因为我觉得把A转换成numpy数组来使用numpy的sum函数可能会比直接用Python内置的sum函数慢,但我可能错了。此外,如果你确实想要的是欧几里得距离,那这样做会减少开平方的次数。
第三,你知道你可能要遍历多少种组合吗?在最坏的情况下,如果num_traj = 1000和r = 20,那大约有6.79E42种组合。这对于你现在的方法来说是相当难以处理的。即使在最好的情况下,num_traj = 500和r = 5,也有1.28E12种组合,这已经很多了,但还不是不可能。这才是你面临的真正问题,因为如果你采纳mathmike的建议,我提到的前两点就不是那么重要了。
那么你该怎么办呢?你需要更聪明一些。目前我还不清楚什么方法最合适。我猜你可能需要以某种方式让算法变得更灵活。我想到的一个方法是尝试用动态规划的方式结合一些启发式的方法。对于每个轨迹,你可以计算它与其他轨迹配对的距离的总和或平均值,并用这个作为适应度指标。在继续处理三元组之前,可以先把适应度最低的轨迹剔除。然后你可以对三元组做同样的事情:计算所有三元组(在剩下的可能轨迹中)中每个轨迹参与的累计距离的总和或平均值,并用这个作为适应度指标来决定哪些轨迹在继续处理四元组之前被剔除。这并不能保证得到最优解,但应该会相当不错,并且我相信这会大大降低解决方案的时间复杂度。
在计算距离的时候,你可以不必去算平方根……因为最大的和也会有最大的平方和,虽然这样做只是能让速度提升一点点。
与其每次都去重新计算每对轨迹之间的距离,不如先把所有轨迹之间的距离都计算出来。然后把这些距离存储在一个字典里,等需要的时候直接查找就可以了。
这样,你的内层循环 for (i,j) ...
就可以用快速查找来替代,速度会更快。