卡拉楚巴算法递归过多

6 投票
5 回答
4926 浏览
提问于 2025-04-16 23:35

我正在尝试用C++实现Karatsuba乘法算法,但现在我只是想先在Python中让它工作。

这是我的代码:

def mult(x, y, b, m):
    if max(x, y) < b:
        return x * y

    bm = pow(b, m)
    x0 = x / bm
    x1 = x % bm
    y0 = y / bm
    y1 = y % bm

    z2 = mult(x1, y1, b, m)
    z0 = mult(x0, y0, b, m)
    z1 = mult(x1 + x0, y1 + y0, b, m) - z2 - z0

    return mult(z2, bm ** 2, b, m) + mult(z1, bm, b, m) + z0

我不明白的是:z2z1z0应该怎么创建?使用mult函数进行递归调用是正确的吗?如果是这样的话,我可能在某个地方搞错了,因为递归没有停止。

有人能指出错误在哪里吗?

5 个回答

4

Karatsuba 乘法的目的是改进传统的分治乘法算法,通过减少递归调用的次数,从四次减少到三次。因此,在你的代码中,只有计算 z0z1z2 的那几行应该包含递归调用。其他地方如果使用递归,复杂度会变得更糟。而且,在你还没有定义乘法的情况下,也不能用 pow 来计算 bm

这个算法关键在于它使用了位置表示法。如果你有一个数字在基数 b 下的表示 x,那么 x*bm 就是把这个表示的数字的每一位向左移动 m 次。这个移动操作在任何位置表示法中都是“免费”的。这也意味着,如果你想实现这个算法,你需要复现这种位置表示法和“免费”的移动。你可以选择在基数 b=2 下计算,并使用 Python 的位运算符(或者在你的测试平台上使用的其他进制的位运算符),或者你决定为了学习目的实现一个适用于任意 b 的版本,并用字符串、数组或列表来复现这种位置运算

你已经有一个基于列表的解决方案了。我喜欢在 Python 中使用字符串,因为 int(s, base) 可以将字符串 s 作为基数 base 的数字表示转换为整数,这样测试起来很方便。我在这里发布了一个注释非常详细的基于字符串的实现作为 gist,其中包括字符串与数字之间的转换函数。

你可以通过提供带填充的字符串和它们(相等的)长度作为参数来测试 mult

In [169]: mult("987654321","987654321",10,9)

Out[169]: '966551847789971041'

如果你不想自己处理填充或计算字符串长度,可以使用一个填充函数来帮你完成:

In [170]: padding("987654321","2")

Out[170]: ('987654321', '000000002', 9)

当然,它也适用于 b>10

In [171]: mult('987654321', '000000002', 16, 9)

Out[171]: '130eca8642'

(可以通过wolfram alpha进行验证)

4

通常情况下,大数字会被存储为整数数组。每个整数代表一个数字。这种方法允许我们通过简单地向左移动数组来将任何数字乘以基数的幂。

下面是我基于列表的实现(可能有错误):

def normalize(l,b):
    over = 0
    for i,x in enumerate(l):
        over,l[i] = divmod(x+over,b)
    if over: l.append(over)
    return l
def sum_lists(x,y,b):
    l = min(len(x),len(y))
    res = map(operator.add,x[:l],y[:l])
    if len(x) > l: res.extend(x[l:])
    else: res.extend(y[l:])
    return normalize(res,b)
def sub_lists(x,y,b):
    res = map(operator.sub,x[:len(y)],y)
    res.extend(x[len(y):])
    return normalize(res,b)
def lshift(x,n):
    if len(x) > 1 or len(x) == 1 and x[0] != 0:
        return [0 for i in range(n)] + x
    else: return x
def mult_lists(x,y,b):
    if min(len(x),len(y)) == 0: return [0]
    m = max(len(x),len(y))
    if (m == 1): return normalize([x[0]*y[0]],b)
    else: m >>= 1
    x0,x1 = x[:m],x[m:]
    y0,y1 = y[:m],y[m:]
    z0 = mult_lists(x0,y0,b)
    z1 = mult_lists(x1,y1,b)
    z2 = mult_lists(sum_lists(x0,x1,b),sum_lists(y0,y1,b),b)
    t1 = lshift(sub_lists(z2,sum_lists(z1,z0,b),b),m)
    t2 = lshift(z1,m*2)
    return sum_lists(sum_lists(z0,t1,b),t2,b)

sum_listssub_lists 返回的结果没有经过标准化——也就是说,单个数字可能会大于基数的值。normalize 函数解决了这个问题。

所有的函数都期望接收反向顺序的数字列表。例如,数字12在十进制中应该写成 [2,1]。我们来计算9987654321的平方。

» a = [1,2,3,4,5,6,7,8,9]
» res = mult_lists(a,a,10)
» res.reverse()
» res
[9, 7, 5, 4, 6, 1, 0, 5, 7, 7, 8, 9, 9, 7, 1, 0, 4, 1]
5

注意:下面的回复直接回答了提问者关于过度递归的问题,但并没有提供一个正确的Karatsuba算法。其他的回复在这方面提供了更多的信息。

试试这个版本:

def mult(x, y, b, m):
    bm = pow(b, m)

    if min(x, y) <= bm:
        return x * y

    # NOTE the following 4 lines
    x0 = x % bm
    x1 = x / bm
    y0 = y % bm
    y1 = y / bm

    z0 = mult(x0, y0, b, m)
    z2 = mult(x1, y1, b, m)
    z1 = mult(x1 + x0, y1 + y0, b, m) - z2 - z0

    retval = mult(mult(z2, bm, b, m) + z1, bm, b, m) + z0
    assert retval == x * y, "%d * %d == %d != %d" % (x, y, x * y, retval)
    return retval

你版本中最严重的问题是,你计算x0和x1,以及y0和y1时搞错了顺序。此外,当x1y1为0时,这个算法的推导就不成立,因为在这种情况下,一个分解步骤就变得无效。因此,你必须避免这种情况,确保x和y都大于b的m次方。

编辑:修正了代码中的一个错别字;添加了一些说明。

编辑2:

为了更清楚地说明,直接评论一下你原来的版本:

def mult(x, y, b, m):
    # The termination condition will never be true when the recursive 
    # call is either
    #    mult(z2, bm ** 2, b, m)
    # or mult(z1, bm, b, m)
    #
    # Since every recursive call leads to one of the above, you have an
    # infinite recursion condition.
    if max(x, y) < b:
        return x * y

    bm = pow(b, m)

    # Even without the recursion problem, the next four lines are wrong
    x0 = x / bm  # RHS should be x % bm
    x1 = x % bm  # RHS should be x / bm
    y0 = y / bm  # RHS should be y % bm
    y1 = y % bm  # RHS should be y / bm

    z2 = mult(x1, y1, b, m)
    z0 = mult(x0, y0, b, m)
    z1 = mult(x1 + x0, y1 + y0, b, m) - z2 - z0

    return mult(z2, bm ** 2, b, m) + mult(z1, bm, b, m) + z0

撰写回答