使用Python Decimal库加速算术运算
我正在尝试运行一个类似于谷歌PageRank算法的函数(当然是出于非商业目的)。这里有一段Python代码;请注意,a[0]
是唯一重要的部分,a[0]
包含一个n x n
的矩阵,比如[[0,1,1],[1,0,1],[1,1,0]]
。你也可以在维基百科上找到我获取这段代码的来源:
def GetNodeRanks(a): # graph, names, size
numIterations = 10
adjacencyMatrix = copy.deepcopy(a[0])
b = [1]*len(adjacencyMatrix)
tmp = [0]*len(adjacencyMatrix)
for i in range(numIterations):
for j in range(len(adjacencyMatrix)):
tmp[j] = 0
for k in range(len(adjacencyMatrix)):
tmp[j] = tmp[j] + adjacencyMatrix[j][k] * b[k]
norm_sq = 0
for j in range(len(adjacencyMatrix)):
norm_sq = norm_sq + tmp[j]*tmp[j]
norm = math.sqrt(norm_sq)
for j in range(len(b)):
b[j] = tmp[j] / norm
print b
return b
当我运行这个实现时(使用的矩阵比3 x 3
的矩阵要大得多),结果的精度不够,无法有效地计算出排名,导致我无法进行有用的比较。所以我尝试了这个:
from decimal import *
getcontext().prec = 5
def GetNodeRanks(a): # graph, names, size
numIterations = 10
adjacencyMatrix = copy.deepcopy(a[0])
b = [Decimal(1)]*len(adjacencyMatrix)
tmp = [Decimal(0)]*len(adjacencyMatrix)
for i in range(numIterations):
for j in range(len(adjacencyMatrix)):
tmp[j] = Decimal(0)
for k in range(len(adjacencyMatrix)):
tmp[j] = Decimal(tmp[j] + adjacencyMatrix[j][k] * b[k])
norm_sq = Decimal(0)
for j in range(len(adjacencyMatrix)):
norm_sq = Decimal(norm_sq + tmp[j]*tmp[j])
norm = Decimal(norm_sq).sqrt
for j in range(len(b)):
b[j] = Decimal(tmp[j] / norm)
print b
return b
即使在这种不太理想的低精度下,代码运行得非常慢,我等了很久也没有完成。之前的代码运行很快,但精度不够。
有没有什么简单合理的方法可以让代码同时快速又精确地运行呢?
1 个回答
0
以下是一些加速代码的小建议:
- 优化循环中的代码
- 如果可以的话,把所有东西都移到内循环外面
- 不要重复计算已经知道的内容,使用变量来存储结果
- 跳过那些不必要的操作,省略它们
- 考虑使用列表推导式,这通常会快一些
- 一旦速度达到可以接受的水平,就停止优化
逐行检查你的代码:
from decimal import *
getcontext().prec = 5
def GetNodeRanks(a): # graph, names, size
# opt: pass in directly a[0], you do not use the rest
numIterations = 10
adjacencyMatrix = copy.deepcopy(a[0])
#opt: why copy.deepcopy? You do not modify adjacencyMatric
b = [Decimal(1)]*len(adjacencyMatrix)
# opt: You often call Decimal(1) and Decimal(0), it takes some time
# do it only once like
# dec_zero = Decimal(0)
# dec_one = Decimal(1)
# prepare also other, repeatedly used data structures
# len_adjacencyMatrix = len(adjacencyMatrix)
# adjacencyMatrix_range = range(len_ajdacencyMatrix)
# Replace code with pre-calculated variables yourself
tmp = [Decimal(0)]*len(adjacencyMatrix)
for i in range(numIterations):
for j in range(len(adjacencyMatrix)):
tmp[j] = Decimal(0)
for k in range(len(adjacencyMatrix)):
tmp[j] = Decimal(tmp[j] + adjacencyMatrix[j][k] * b[k])
norm_sq = Decimal(0)
for j in range(len(adjacencyMatrix)):
norm_sq = Decimal(norm_sq + tmp[j]*tmp[j])
norm = Decimal(norm_sq).sqrt #is this correct? I woudl expect .sqrt()
for j in range(len(b)):
b[j] = Decimal(tmp[j] / norm)
print b
return b
接下来是一些如何在Python中优化列表处理的示例。
使用 sum
,把:
norm_sq = Decimal(0)
for j in range(len(adjacencyMatrix)):
norm_sq = Decimal(norm_sq + tmp[j]*tmp[j])
改成:
norm_sq = sum(val*val for val in tmp)
稍微使用一下列表推导式:
把:
for j in range(len(b)):
b[j] = Decimal(tmp[j] / norm)
改成:
b = [Decimal(tmp_itm / norm) for tmp_itm in tmp]
如果你掌握了这种编码风格,你就能优化最初的循环,并且可能会发现一些预先计算的变量变得没用了。