Speeding up arithmetic with Python's Decimal library

Asked 2025-04-18 03:03

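I'm trying to run a function similar to Google's PageRank algorithm (for non-commercial purposes, of course). Here is the Python code; note that only a[0] matters, and a[0] holds an n x n matrix such as [[0,1,1],[1,0,1],[1,1,0]]. You can also find the source I took this code from on Wikipedia: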

import copy
import math

def GetNodeRanks(a):        # graph, names, size
    numIterations = 10
    adjacencyMatrix = copy.deepcopy(a[0])
    b = [1]*len(adjacencyMatrix)
    tmp = [0]*len(adjacencyMatrix)
    for i in range(numIterations):
        for j in range(len(adjacencyMatrix)):
            tmp[j] = 0
            for k in range(len(adjacencyMatrix)):
                tmp[j] = tmp[j] + adjacencyMatrix[j][k] * b[k]
        norm_sq = 0
        for j in range(len(adjacencyMatrix)):
            norm_sq = norm_sq + tmp[j]*tmp[j]
        norm = math.sqrt(norm_sq)
        for j in range(len(b)):
            b[j] = tmp[j] / norm
    print(b)
    return b

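When I run this implementation (on matrices much larger than 3 x 3), the results are not precise enough to rank the nodes effectively, so I can't make useful comparisons. So I tried this: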

import copy
from decimal import *

getcontext().prec = 5

def GetNodeRanks(a):        # graph, names, size
    numIterations = 10
    adjacencyMatrix = copy.deepcopy(a[0])
    b = [Decimal(1)]*len(adjacencyMatrix)
    tmp = [Decimal(0)]*len(adjacencyMatrix)
    for i in range(numIterations):
        for j in range(len(adjacencyMatrix)):
            tmp[j] = Decimal(0)
            for k in range(len(adjacencyMatrix)):
                tmp[j] = Decimal(tmp[j] + adjacencyMatrix[j][k] * b[k])
        norm_sq = Decimal(0)
        for j in range(len(adjacencyMatrix)):
            norm_sq = Decimal(norm_sq + tmp[j]*tmp[j])
        norm = Decimal(norm_sq).sqrt
        for j in range(len(b)):
            b[j] = Decimal(tmp[j] / norm)
    print(b)
    return b

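Even at this unsatisfyingly low precision, the code runs very slowly; I waited a long time and it still had not finished. The previous version ran quickly but was not precise enough.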
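One thing worth checking before any speed work: `getcontext().prec = 5` keeps five *significant digits*, whereas a Python float (an IEEE 754 double) carries roughly 15 to 17, so this setting is coarser than the float version rather than finer. A minimal comparison, using `localcontext()` so the global context is left untouched; the variable names are only illustrative:

```python
from decimal import Decimal, localcontext

# Five significant digits, the setting used in the question.
with localcontext() as ctx:
    ctx.prec = 5
    third_prec5 = Decimal(1) / Decimal(3)

# Thirty significant digits, for comparison.
with localcontext() as ctx:
    ctx.prec = 30
    third_prec30 = Decimal(1) / Decimal(3)

# A float has a 53-bit significand: roughly 15-17 significant decimal digits.
third_float = 1 / 3

print(third_prec5)         # 0.33333
print(repr(third_float))   # 0.3333333333333333
print(third_prec30)        # 0.333333333333333333333333333333
```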

Is there a simple, reasonable way to make the code both fast and precise?

1 Answer


Here are a few small tips for speeding up code:

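  • Optimize the code inside loops
  • Move everything you can out of the inner loop
  • Don't recompute what you already know; store results in variables
  • Skip unnecessary operations entirely
  • Consider list comprehensions; they are usually somewhat faster
  • Stop optimizing once the speed is acceptable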
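A minimal sketch of the first three tips applied to the matrix-vector product at the heart of the loop. Both function names are hypothetical, and no timing is claimed here; the point is only that the hoisted version does less repeated indexing while returning the same result.

```python
def matvec_indexed(matrix, vec):
    """Index-heavy version, shaped like the question's inner loops."""
    tmp = [0] * len(matrix)
    for j in range(len(matrix)):
        tmp[j] = 0
        for k in range(len(matrix)):
            tmp[j] = tmp[j] + matrix[j][k] * vec[k]
    return tmp


def matvec_hoisted(matrix, vec):
    """Same result, with loop-invariant work moved out of the inner loop."""
    result = []
    for row in matrix:               # fetch each row once, not matrix[j] per k
        acc = 0                      # accumulate in a local, not in tmp[j]
        for m, x in zip(row, vec):   # no range(len(...)) or repeated indexing
            acc += m * x
        result.append(acc)
    return result


M = [[0, 1, 1], [1, 0, 1], [1, 1, 0]]
v = [1, 2, 3]
print(matvec_indexed(M, v))  # [5, 4, 3]
print(matvec_hoisted(M, v))  # [5, 4, 3]
```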

Going through your code line by line:

import copy
from decimal import *

getcontext().prec = 5

def GetNodeRanks(a):        # graph, names, size
    # opt: pass in a[0] directly; you do not use the rest
    numIterations = 10
    adjacencyMatrix = copy.deepcopy(a[0])
    # opt: why copy.deepcopy? You do not modify adjacencyMatrix
    b = [Decimal(1)]*len(adjacencyMatrix)
    # opt: you call Decimal(1) and Decimal(0) often, and that takes time;
    # do it only once, like
    # dec_zero = Decimal(0)
    # dec_one = Decimal(1)
    # also prepare other repeatedly used data structures
    # len_adjacencyMatrix = len(adjacencyMatrix)
    # adjacencyMatrix_range = range(len_adjacencyMatrix)
    # then replace the repeated calls with these precomputed variables yourself

    tmp = [Decimal(0)]*len(adjacencyMatrix)
    for i in range(numIterations):
        for j in range(len(adjacencyMatrix)):
            tmp[j] = Decimal(0)
            for k in range(len(adjacencyMatrix)):
                tmp[j] = Decimal(tmp[j] + adjacencyMatrix[j][k] * b[k])
        norm_sq = Decimal(0)
        for j in range(len(adjacencyMatrix)):
            norm_sq = Decimal(norm_sq + tmp[j]*tmp[j])
        norm = Decimal(norm_sq).sqrt()  # was .sqrt without (); sqrt is a method and must be called
        for j in range(len(b)):
            b[j] = Decimal(tmp[j] / norm)
    print(b)
    return b
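A sketch of the function with the comments above applied: the matrix is passed in directly without a deep copy, `Decimal(0)`, the length, and the index range are each built once, every row is fetched once per outer step, and `sqrt()` is actually called. The name `get_node_ranks` and the `num_iterations` parameter are illustrative, not from the question.

```python
from decimal import Decimal, getcontext

getcontext().prec = 5  # same low precision as the question


def get_node_ranks(adjacency_matrix, num_iterations=10):
    # Hypothetical name; takes the matrix itself instead of the tuple a,
    # and does not deepcopy it because it is only read.
    dec_zero = Decimal(0)              # build the constant once
    n = len(adjacency_matrix)          # precompute the length...
    rng = range(n)                     # ...and reuse one range object
    b = [Decimal(1)] * n
    tmp = [dec_zero] * n
    for _ in range(num_iterations):
        for j in rng:
            row = adjacency_matrix[j]  # fetch the row once per j
            acc = dec_zero
            for k in rng:
                acc += row[k] * b[k]   # int * Decimal gives a Decimal
            tmp[j] = acc
        norm_sq = dec_zero
        for t in tmp:
            norm_sq += t * t
        norm = norm_sq.sqrt()          # sqrt is a method, so call it
        for j in rng:
            b[j] = tmp[j] / norm       # already a Decimal; no wrapper needed
    return b


print(get_node_ranks([[0, 1, 1], [1, 0, 1], [1, 1, 0]]))
```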

Next, some examples of how to optimize list processing in Python.

Use sum. Replace:

        norm_sq = Decimal(0)
        for j in range(len(adjacencyMatrix)):
            norm_sq = Decimal(norm_sq + tmp[j]*tmp[j])

with:

        norm_sq = sum(val*val for val in tmp)
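A quick check that the two forms agree. `sum` starts from the integer `0`, and `0 + Decimal` is a `Decimal`, so the result stays a `Decimal` (pass `Decimal(0)` as the second argument to `sum` if you prefer to be explicit). The sample values are arbitrary.

```python
from decimal import Decimal

tmp = [Decimal("1.5"), Decimal("2"), Decimal("0.5")]

# Explicit loop, shaped like the question's code.
norm_sq_loop = Decimal(0)
for j in range(len(tmp)):
    norm_sq_loop = norm_sq_loop + tmp[j] * tmp[j]

# Generator expression with sum(); no index lookups, no extra Decimal() calls.
norm_sq_sum = sum(val * val for val in tmp)

print(norm_sq_loop, norm_sq_sum)  # 6.50 6.50
```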

Make a little use of list comprehensions. Replace:

        for j in range(len(b)):
            b[j] = Decimal(tmp[j] / norm)

with:

        b = [Decimal(tmp_itm / norm) for tmp_itm in tmp]
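The outer `Decimal(...)` in that comprehension can go too: dividing two `Decimal`s already returns a `Decimal` rounded to the context precision, and constructing a `Decimal` from a `Decimal` copies it exactly, so the wrapper only adds a call. A small check with made-up values (`norm` is 7 here because 2² + 3² + 6² = 49):

```python
from decimal import Decimal, getcontext

getcontext().prec = 5  # same precision as the question

tmp = [Decimal(2), Decimal(3), Decimal(6)]
norm = Decimal(7)

wrapped = [Decimal(tmp_itm / norm) for tmp_itm in tmp]  # answer's version
unwrapped = [tmp_itm / norm for tmp_itm in tmp]         # without the wrapper

print(wrapped == unwrapped)  # True
print(unwrapped)  # [Decimal('0.28571'), Decimal('0.42857'), Decimal('0.85714')]
```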

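Once you master this coding style, you'll be able to optimize the initial loops as well, and you may find that some of the precomputed variables become unnecessary.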
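Carried all the way through, the same style turns the initial triple loop into a few comprehensions. This is a sketch under the question's assumptions (a square list-of-lists matrix, a fixed 10 iterations, `prec = 5`); `zip` pairs each row with `b`, so the index ranges and the preallocated `tmp` list are no longer needed. The function name is illustrative.

```python
from decimal import Decimal, getcontext

getcontext().prec = 5  # same low precision as the question


def get_node_ranks(adjacency_matrix, num_iterations=10):
    """Power iteration with Decimals, loops rewritten as comprehensions."""
    b = [Decimal(1)] * len(adjacency_matrix)
    for _ in range(num_iterations):
        # Matrix-vector product: one sum() per row, pairing entries via zip.
        tmp = [sum(m * x for m, x in zip(row, b)) for row in adjacency_matrix]
        # Euclidean norm of tmp; sqrt() is a Decimal method.
        norm = sum(t * t for t in tmp).sqrt()
        b = [t / norm for t in tmp]
    return b


print(get_node_ranks([[0, 1, 1], [1, 0, 1], [1, 1, 0]]))
```

Whether this is fast enough for the real matrix still has to be measured on that matrix.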
