math.gcd的算法是什么,为什么它比欧几里得算法更快?

2 投票
1 回答
104 浏览
提问于 2025-04-11 21:52

测试显示,Python的 math.gcd 函数比简单的欧几里得算法快一个数量级:

import math
from timeit import default_timer as timer

def gcd(a, b):
    """Return the greatest common divisor of the integers a and b.

    Plain Euclidean algorithm: repeatedly replace (a, b) with (b, a % b)
    until b becomes 0.  The result is always non-negative, matching
    math.gcd (e.g. gcd(4, -6) == 2), and gcd(0, 0) == 0.
    """
    while b != 0:
        a, b = b, a % b
    # Python's % takes the sign of the divisor, so with a negative input the
    # loop can finish on a negative value; normalise it like math.gcd does.
    return abs(a)

def main(repeat=10000):
    """Time the pure-Python gcd() against math.gcd() on one fixed input pair.

    For each implementation, print the computed GCD and then the average
    wall-clock time per call in seconds, measured over `repeat` calls.

    Args:
        repeat: number of timed calls per implementation (must be >= 1).
            Averaging over many calls smooths out timer resolution and
            noise, which dominate a single microsecond-scale measurement.

    Raises:
        ValueError: if repeat < 1.
    """
    if repeat < 1:
        raise ValueError("repeat must be at least 1")
    a = 28871271685163
    b = 17461204521323
    for func in (gcd, math.gcd):
        # Compute the value to display outside the timed region so that
        # print() (terminal I/O) does not pollute the measurement.
        result = func(a, b)
        start = timer()
        for _ in range(repeat):
            func(a, b)
        end = timer()
        print(result)
        print((end - start) / repeat)

结果是

$ python3 test.py
1
4.816000000573695e-05
1
8.346003596670926e-06

注意两次耗时的指数分别是 e-05 和 e-06。

我猜这可能是因为有一些优化或者使用了其他算法?

1 个回答

7

math.gcd() 其实是 Python 对一个库函数的封装,这个库函数是以机器代码的形式运行的,也就是说它是从 "C" 语言代码编译而来的,而不是由 Python 解释器直接运行的。想了解更多,可以看看这个链接:math.py 和 sys.py 在哪里?

对于 CPython 来说,这就是它的实现:

math_gcd(PyObject *module, PyObject * const *args, Py_ssize_t nargs)

它定义在 mathmodule.c 文件中。

它会调用

_PyLong_GCD(PyObject *aarg, PyObject *barg)

它定义在 longobject.c 文件中。

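对于大整数,这个函数使用的是 Lehmer 的 GCD 算法:每一轮只取 a 的最高两个 digit(以及 b 中对应的位),在机器字上模拟若干步欧几里得算法,再把累积得到的变换一次性作用到整个大整数上,从而减少多精度除法的次数。不过要注意:在常见的构建中 PyLong_SHIFT 为 30,题目中的两个数都小于 2^60,也就是都不超过 2 个 digit,所以函数一开头就直接跳到 simple 标签,用 C 语言实现的普通欧几里得算法来计算。换句话说,在这个例子里,速度差异主要来自 C 代码省去了 Python 解释器的开销,而不是算法本身的不同。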

代码中有很多额外的操作和特殊情况处理,因此复杂了不少,但整体上还是挺干净的。

/* Return gcd(aarg, barg) for two Python int objects as a new reference, or
   NULL on failure (presumably with an exception already set by the failing
   helper -- long_abs, l_mod, _PyLong_Copy or _PyLong_New; TODO confirm).

   The result is non-negative: the multi-digit path works on long_abs()
   copies of the arguments, and the final C-integer path applies Py_ABS.

   Strategy:
     - If both arguments have at most 2 digits, jump straight to the plain
       Euclidean algorithm on C integers (label "simple").
     - Otherwise run Lehmer's algorithm.  Repeatedly take the leading
       2*PyLong_SHIFT bits of a (and the matching bits of b) and simulate
       Euclidean steps on those machine-word values to build a 2x2 cofactor
       matrix (A, B; C, D).  Then apply that matrix to the full-size a and b
       in one linear pass.  If the simulation makes no progress, do one
       full-precision Euclidean step with l_mod instead.
     - Once a fits in 2 digits, finish with the plain Euclidean algorithm. */
PyObject *
_PyLong_GCD(PyObject *aarg, PyObject *barg)
{
    PyLongObject *a, *b, *c = NULL, *d = NULL, *r;
    stwodigits x, y, q, s, t, c_carry, d_carry;
    stwodigits A, B, C, D, T;
    int nbits, k;
    digit *a_digit, *b_digit, *c_digit, *d_digit, *a_end, *b_end;

    a = (PyLongObject *)aarg;
    b = (PyLongObject *)barg;
    /* Fast path: both operands fit in 2 digits, so the C-integer loop at
       "simple" can handle them directly.  Take our own references so that
       "simple" can release a and b the same way on every path.  Note that
       long_abs() is skipped here, so the values may still be negative. */
    if (_PyLong_DigitCount(a) <= 2 && _PyLong_DigitCount(b) <= 2) {
        Py_INCREF(a);
        Py_INCREF(b);
        goto simple;
    }

    /* Initial reduction: make sure that 0 <= b <= a. */
    a = (PyLongObject *)long_abs(a);
    if (a == NULL)
        return NULL;
    b = (PyLongObject *)long_abs(b);
    if (b == NULL) {
        Py_DECREF(a);
        return NULL;
    }
    if (long_compare(a, b) < 0) {
        r = a;
        a = b;
        b = r;
    }
    /* We now own references to a and b */

    /* alloc_a / alloc_b track the digit capacity of the objects currently
       held in a and b.  They decide when b can be reused as an output
       buffer, and when the final result must be copied before returning. */
    Py_ssize_t size_a, size_b, alloc_a, alloc_b;
    alloc_a = _PyLong_DigitCount(a);
    alloc_b = _PyLong_DigitCount(b);
    /* reduce until a fits into 2 digits */
    while ((size_a = _PyLong_DigitCount(a)) > 2) {
        nbits = bit_length_digit(a->long_value.ob_digit[size_a-1]);
        /* extract top 2*PyLong_SHIFT bits of a into x, along with
           corresponding bits of b into y */
        size_b = _PyLong_DigitCount(b);
        assert(size_b <= size_a);
        /* b == 0: the gcd is a.  If a uses fewer digits than were allocated
           for it, return a fresh copy from _PyLong_Copy instead (presumably
           to avoid handing back an over-allocated object).  Release
           everything else we hold. */
        if (size_b == 0) {
            if (size_a < alloc_a) {
                r = (PyLongObject *)_PyLong_Copy(a);
                Py_DECREF(a);
            }
            else
                r = a;
            Py_DECREF(b);
            Py_XDECREF(c);
            Py_XDECREF(d);
            return (PyObject *)r;
        }
        x = (((twodigits)a->long_value.ob_digit[size_a-1] << (2*PyLong_SHIFT-nbits)) |
             ((twodigits)a->long_value.ob_digit[size_a-2] << (PyLong_SHIFT-nbits)) |
             (a->long_value.ob_digit[size_a-3] >> nbits));

        /* Same bit window taken from b.  Digits that b does not have
           (b is shorter than a) contribute 0. */
        y = ((size_b >= size_a - 2 ? b->long_value.ob_digit[size_a-3] >> nbits : 0) |
             (size_b >= size_a - 1 ? (twodigits)b->long_value.ob_digit[size_a-2] << (PyLong_SHIFT-nbits) : 0) |
             (size_b >= size_a ? (twodigits)b->long_value.ob_digit[size_a-1] << (2*PyLong_SHIFT-nbits) : 0));

        /* inner loop of Lehmer's algorithm; A, B, C, D never grow
           larger than PyLong_MASK during the algorithm. */
        /* k counts the simulated Euclidean steps.  The y-C == 0 and s > t
           checks end the simulation; presumably this is the standard
           Lehmer/Collins test that the quotients computed from the
           truncated x, y still match the full-precision ones (verify
           against an algorithm reference before changing). */
        A = 1; B = 0; C = 0; D = 1;
        for (k=0;; k++) {
            if (y-C == 0)
                break;
            q = (x+(A-1))/(y-C);
            s = B+q*D;
            t = x-q*y;
            if (s > t)
                break;
            x = y; y = t;
            t = A+q*C; A = D; B = C; C = s; D = t;
        }

        if (k == 0) {
            /* no progress; do a Euclidean step */
            /* r = a mod b, then (a, b) <- (b, r). */
            if (l_mod(a, b, &r) < 0)
                goto error;
            Py_SETREF(a, b);
            b = r;
            alloc_a = alloc_b;
            alloc_b = _PyLong_DigitCount(b);
            continue;
        }

        /*
          a, b = A*b-B*a, D*a-C*b if k is odd
          a, b = A*a-B*b, D*b-C*a if k is even
        */
        /* For odd k, swap and negate the cofactors so the digit loops below
           can always use the even-k form. */
        if (k&1) {
            T = -A; A = -B; B = T;
            T = -C; C = -D; D = T;
        }
        /* Choose the output buffer c for the new a: reuse c from an earlier
           iteration (resized to size_a digits), or write in place into a
           when we hold its only reference, or allocate a new object. */
        if (c != NULL) {
            assert(size_a >= 0);
            _PyLong_SetSignAndDigitCount(c, 1, size_a);
        }
        else if (Py_REFCNT(a) == 1) {
            c = (PyLongObject*)Py_NewRef(a);
        }
        else {
            alloc_a = size_a;
            c = _PyLong_New(size_a);
            if (c == NULL)
                goto error;
        }

        /* Same for d (the new b).  b can only be reused when its
           allocation has room for size_a digits. */
        if (d != NULL) {
            assert(size_a >= 0);
            _PyLong_SetSignAndDigitCount(d, 1, size_a);
        }
        else if (Py_REFCNT(b) == 1 && size_a <= alloc_b) {
            d = (PyLongObject*)Py_NewRef(b);
            assert(size_a >= 0);
            _PyLong_SetSignAndDigitCount(d, 1, size_a);
        }
        else {
            alloc_b = size_a;
            d = _PyLong_New(size_a);
            if (d == NULL)
                goto error;
        }

        a_end = a->long_value.ob_digit + size_a;
        b_end = b->long_value.ob_digit + size_b;

        /* compute new a and new b in parallel */
        /* c may alias a and d may alias b.  Writing in place is still safe
           because each iteration reads the a/b digits at a position before
           it writes the c/d digits at that same position. */
        a_digit = a->long_value.ob_digit;
        b_digit = b->long_value.ob_digit;
        c_digit = c->long_value.ob_digit;
        d_digit = d->long_value.ob_digit;
        c_carry = 0;
        d_carry = 0;
        while (b_digit < b_end) {
            c_carry += (A * *a_digit) - (B * *b_digit);
            d_carry += (D * *b_digit++) - (C * *a_digit++);
            *c_digit++ = (digit)(c_carry & PyLong_MASK);
            *d_digit++ = (digit)(d_carry & PyLong_MASK);
            c_carry >>= PyLong_SHIFT;
            d_carry >>= PyLong_SHIFT;
        }
        /* Past the end of b, only a's digits contribute. */
        while (a_digit < a_end) {
            c_carry += A * *a_digit;
            d_carry -= C * *a_digit++;
            *c_digit++ = (digit)(c_carry & PyLong_MASK);
            *d_digit++ = (digit)(d_carry & PyLong_MASK);
            c_carry >>= PyLong_SHIFT;
            d_carry >>= PyLong_SHIFT;
        }
        assert(c_carry == 0);
        assert(d_carry == 0);

        /* c and d become the new a and b.  The c/d variables keep their
           own references so the buffers can be reused next iteration.
           long_normalize presumably strips leading zero digits. */
        Py_INCREF(c);
        Py_INCREF(d);
        Py_DECREF(a);
        Py_DECREF(b);
        a = long_normalize(c);
        b = long_normalize(d);
    }
    /* Drop the scratch-buffer references; a and b keep their own. */
    Py_XDECREF(c);
    Py_XDECREF(d);

simple:
    /* a and b each fit in 2 digits here; we own one reference to each. */
    assert(Py_REFCNT(a) > 0);
    assert(Py_REFCNT(b) > 0);
/* Issue #24999: use two shifts instead of ">> 2*PyLong_SHIFT" to avoid
   undefined behaviour when LONG_MAX type is smaller than 60 bits */
#if LONG_MAX >> PyLong_SHIFT >> PyLong_SHIFT

    /* a fits into a long, so b must too */
    x = PyLong_AsLong((PyObject *)a);
    y = PyLong_AsLong((PyObject *)b);
#elif LLONG_MAX >> PyLong_SHIFT >> PyLong_SHIFT
    x = PyLong_AsLongLong((PyObject *)a);
    y = PyLong_AsLongLong((PyObject *)b);
#else
# error "_PyLong_GCD"
#endif
    /* The fast path reaches here without long_abs(), so normalise signs. */
    x = Py_ABS(x);
    y = Py_ABS(y);
    Py_DECREF(a);
    Py_DECREF(b);

    /* usual Euclidean algorithm for longs */
    while (y != 0) {
        t = y;
        y = x % y;
        x = t;
    }
#if LONG_MAX >> PyLong_SHIFT >> PyLong_SHIFT
    return PyLong_FromLong(x);
#elif LLONG_MAX >> PyLong_SHIFT >> PyLong_SHIFT
    return PyLong_FromLongLong(x);
#else
# error "_PyLong_GCD"
#endif

error:
    /* Release every reference we own; c and d may still be NULL. */
    Py_DECREF(a);
    Py_DECREF(b);
    Py_XDECREF(c);
    Py_XDECREF(d);
    return NULL;
}

撰写回答