将优化的埃拉托斯特尼筛法从Python移植到C++

2 投票
3 回答
1193 浏览
提问于 2025-04-16 13:37

不久前,我在Python中使用了一个非常快速的质数筛选工具,叫做primesieve,具体可以在这里找到:列出所有小于N的质数的最快方法

为了更准确地说,这个实现是:

def primes2(n):
    """ Input n>=6, Returns a list of primes, 2 <= p < n """
    n, correction = n-n%6+6, 2-(n%6>1)
    sieve = [True] * (n/3)
    for i in xrange(1,int(n**0.5)/3+1):
      if sieve[i]:
        k=3*i+1|1
        sieve[      k*k/3      ::2*k] = [False] * ((n/6-k*k/6-1)/k+1)
        sieve[k*(k-2*(i&1)+4)/3::2*k] = [False] * ((n/6-k*(k-2*(i&1)+4)/6-1)/k+1)
    return [2,3] + [3*i+1|1 for i in xrange(1,n/3-correction) if sieve[i]]

现在我大致明白了通过自动跳过2、3等的倍数来优化的思路,但当我尝试把这个算法移植到C++时就遇到了困难(我对Python理解得不错,对C++的理解一般/较差,但也足够用)。

我目前自己写的代码是这样的(isqrt()只是一个简单的整数平方根函数):

template <class T>
void primesbelow(T N, std::vector<T> &primes) {
    T sievemax = (N-3 + (1-(N % 2))) / 2;
    T i;
    T sievemaxroot = isqrt(sievemax) + 1;

    boost::dynamic_bitset<> sieve(sievemax);
    sieve.set();

    primes.push_back(2);

    for (i = 0; i <= sievemaxroot; i++) {
        if (sieve[i]) {
            primes.push_back(2*i+3);
            for (T j = 3*i+3; j <= sievemax; j += 2*i+3) sieve[j] = 0; // filter multiples
        }
    }

    for (; i <= sievemax; i++) {
        if (sieve[i]) primes.push_back(2*i+3);
    }
}

这个实现还不错,自动跳过了2的倍数,但如果我能把Python的实现移植过来,我觉得速度会快很多(大约快50%-30%)。

为了比较结果(希望这个问题能得到成功的解答),在一台Q6600的Ubuntu 10.10上,使用N=100000000g++ -O3的当前执行时间是1230毫秒。

我现在希望能得到一些帮助,要么理解上面Python实现的具体做法,要么帮我移植一下(不过后者的帮助性不大)。

编辑

关于我觉得困难的地方,补充一些信息。

我对使用的技术,比如修正变量,以及整体是如何结合在一起的,感到困惑。如果有链接能解释不同的埃拉托斯特尼筛法优化(除了那些简单说“你只需跳过2、3和5的倍数”然后给你一千行C代码的网站)就太好了。

我觉得如果是100%直接和字面上的移植应该不会有问题,但毕竟这是为了学习,那样就完全没用。

编辑

在查看了原始numpy版本的代码后,实际上实现起来相当简单,经过一些思考也不难理解。这是我想到的C++版本。我把它完整地发布在这里,以帮助后来的读者,万一他们需要一个效率不错的质数筛选工具,而不是两百万行的代码。这个质数筛选工具在上面提到的同一台机器上,能在大约415毫秒内找出所有小于100000000的质数。这是3倍的速度提升,比我预期的要好!

#include <vector>
#include <boost/dynamic_bitset.hpp>

// http://vault.embedded.com/98/9802fe2.htm - integer square root
unsigned short isqrt(unsigned long a) {
    unsigned long rem = 0;
    unsigned long root = 0;

    for (short i = 0; i < 16; i++) {
        root <<= 1;
        rem = ((rem << 2) + (a >> 30));
        a <<= 2;
        root++;

        if (root <= rem) {
            rem -= root;
            root++;
        } else root--;

    }

    return static_cast<unsigned short> (root >> 1);
}

// https://stackoverflow.com/questions/2068372/fastest-way-to-list-all-primes-below-n-in-python/3035188#3035188
// https://stackoverflow.com/questions/5293238/porting-optimized-sieve-of-eratosthenes-from-python-to-c/5293492
template <class T>
void primesbelow(T N, std::vector<T> &primes) {
    T i, j, k, l, sievemax, sievemaxroot;

    sievemax = N/3;
    if ((N % 6) == 2) sievemax++;

    sievemaxroot = isqrt(N)/3;

    boost::dynamic_bitset<> sieve(sievemax);
    sieve.set();

    primes.push_back(2);
    primes.push_back(3);

    for (i = 1; i <= sievemaxroot; i++) {
        if (sieve[i]) {
            k = (3*i + 1) | 1;
            l = (4*k-2*k*(i&1)) / 3;

            for (j = k*k/3; j < sievemax; j += 2*k) {
                sieve[j] = 0;
                sieve[j+l] = 0;
            }

            primes.push_back(k);
        }
    }

    for (i = sievemaxroot + 1; i < sievemax; i++) {
        if (sieve[i]) primes.push_back((3*i+1)|1);
    }
}

3 个回答

0

顺便提一下,你可以“近似”得到质数。我们把这种近似的质数叫做P。下面是一些公式:

P = 2*k+1 // 这个数不能被2整除

P = 6*k + {1, 5} // 这个数不能被2和3整除

P = 30*k + {1, 7, 11, 13, 17, 19, 23, 29} // 这个数不能被2、3和5整除

这些公式得到的数字有一个特点,就是P可能不是质数,但所有的质数都在这个P的集合里。也就是说,如果你只在P的集合里检查质数,你不会漏掉任何质数。

你也可以把这些公式改成:

P = X*k + {-i, -j, -k, k, j, i}

如果这样更方便的话。

这里有一些代码,使用了这个技巧,公式得到的P不能被2、3、5和7整除。

这个链接可能展示了这个技巧在实际应用中的潜力。

1

在Howard Hinnant的回答基础上,Howard,你其实不需要测试所有自然数中那些不能被2、3或5整除的数字是否是质数。你只需要把数组中的每个数字(除了1,因为它自己就不算)和它后面的每个数字相乘。这样得到的结果会告诉你数组中所有的非质数,直到你决定停止这个乘法过程为止。比如,数组中的第一个非质数是7的平方,也就是49。第二个非质数是7乘以11,结果是77,依此类推。想了解更多,可以查看这个链接:http://www.primesdemystified.com

3

我尽量解释得简单明了些。sieve 数组的索引方式有点特别;它为每个符合条件的数字存储一个位(bit),这些数字是模 6 余 1 或 5 的。因此,像 6*k + 1 这样的数字会存储在位置 2*k,而 6*k + 5 会存储在位置 2*k + 13*i+1|1 这个操作是反向的:它把形如 2*n 的数字转换成 6*n + 1,把 2*n + 1 转换成 6*n + 5(这里的 +1|1 操作会把 0 变成 1,把 3 变成 5)。主循环会遍历所有符合这种特性的数字 k,从 5 开始(当 i 为 1 时);i 是对应于数字 ksieve 中的索引。第一次对 sieve 的更新会清除所有索引为 k*k/3 + 2*m*k 形式的位(m 是自然数);这些索引对应的数字从 k^2 开始,每次增加 6*k。第二次更新从索引 k*(k-2*(i&1)+4)/3 开始(对于模 6 余 1 的 k,是数字 k * (k+4),否则是 k * (k+2)),同样每次增加 6*k

再试着解释一下:让 candidates 表示所有大于等于 5 的数字集合,这些数字模 6 余 1 或 5。如果你把这个集合中的两个元素相乘,你会得到集合中的另一个元素。对于集合中某个 ksucc(k) 表示 candidates 中比 k 大的下一个元素(按数字顺序)。在这种情况下,sieve 的内部循环基本上是(使用正常的 sieve 索引):

for k in candidates:
  for (l = k; ; l += 6) sieve[k * l] = False
  for (l = succ(k); ; l += 6) sieve[k * l] = False

由于 sieve 中存储的元素有限制,这相当于:

for k in candidates:
  for l in candidates where l >= k:
    sieve[k * l] = False

这将会在某个时刻(要么是之前用当前 k 作为 l,要么是现在用作 k)从 sieve 中移除 candidates 中所有 k 的倍数(除了 k 本身)。

撰写回答