如何在Python中实现高效的无限素数生成器?

72 投票
13 回答
34464 浏览
提问于 2025-04-15 18:57

这不是作业,我只是好奇。

这里的关键词是“无限”。

我想用它来写 for p in primes()。我相信这是Haskell里面的一个内置函数。

所以,答案不能简单到“就用筛法”这么肤浅。

首先,你不知道会消耗多少个连续的质数。假设你一次能搞定100个,你还会用同样的筛法和质数公式吗?

我更喜欢非并发的方法。

谢谢你阅读(和写作 ;))!

13 个回答

51

为了后人留个记录,这里是对Will Ness的漂亮算法在Python 3中的重写。需要做一些改动(迭代器不再有.next()方法了,但有了一个新的next()内置函数)。其他的改动是为了好玩(使用新的yield from <iterable>替代了原来四个yield语句)。还有一些改动是为了让代码更易读(我不太喜欢过多使用单字母的变量名)。

这个版本的速度比原来的快很多,但这并不是因为算法本身的原因。速度提升主要是因为去掉了原来的add()函数,而是直接在代码中处理。

def psieve():
    import itertools
    yield from (2, 3, 5, 7)
    D = {}
    ps = psieve()
    next(ps)
    p = next(ps)
    assert p == 3
    psq = p*p
    for i in itertools.count(9, 2):
        if i in D:      # composite
            step = D.pop(i)
        elif i < psq:   # prime
            yield i
            continue
        else:           # composite, = p*p
            assert i == psq
            step = 2*p
            p = next(ps)
            psq = p*p
        i += step
        while i in D:
            i += step
        D[i] = step
80

因为提问者想要一个高效的实现,所以这里有一个对David Eppstein和Alex Martelli在2002年发布的代码的重大改进(可以在这个链接找到)。这个改进的关键是:在看到一个质数的平方之前,不要把它的信息记录在字典里。这样做可以把空间复杂度降低到小于O(sqrt(n)),而不是O(n),这里的n是生成的质数数量(π(sqrt(n log n)) ~ 2 sqrt(n log n) / log(n log n) ~ 2 sqrt(n / log n))。因此,时间复杂度也得到了改善,也就是说,运行得更快

这个方法创建了一个“滑动筛”,作为每个基础质数当前倍数的字典(也就是在当前生成点的平方根以下),并且包含它们的步长值:

from itertools import count
                                         # ideone.com/aVndFM
def postponed_sieve():                   # postponed sieve, by Will Ness      
    yield 2; yield 3; yield 5; yield 7;  # original code David Eppstein, 
    sieve = {}                           #   Alex Martelli, ActiveState Recipe 2002
    ps = postponed_sieve()               # a separate base Primes Supply:
    p = next(ps) and next(ps)            # (3) a Prime to add to dict
    q = p*p                              # (9) its sQuare 
    for c in count(9,2):                 # the Candidate
        if c in sieve:               # c's a multiple of some base prime
            s = sieve.pop(c)         #     i.e. a composite ; or
        elif c < q:  
             yield c                 # a prime
             continue              
        else:   # (c==q):            # or the next base prime's square:
            s=count(q+2*p,2*p)       #    (9+6, by 6 : 15,21,27,33,...)
            p=next(ps)               #    (5)
            q=p*p                    #    (25)
        for m in s:                  # the next multiple 
            if m not in sieve:       # no duplicates
                break
        sieve[m] = s                 # original test entry: ideone.com/WFv4f

(这里的旧代码经过编辑,加入了Tim Peters的回答中看到的更改)。另外,关于这个话题的相关讨论可以参考这个链接

类似的基于2-3-5-7轮子的代码运行速度约快2.15倍(这接近理论上的提升3/2 * 5/4 * 7/6 = 2.1875)。

2022年更新:我最近偶然发现了这个90年代的“NESL”东西,它实际上使用了相同的sqrt递归技巧。所以在阳光下没有什么是新的。:)

这段代码可以很简单地扩展,以便从给定的值开始枚举质数。这可以在这个基于JS的条目中看到。

85

“如果我看得更远…”

来自食谱的 erat2 函数可以进一步加速(大约提高20-25%):

erat2a

import itertools as it
def erat2a( ):
    D = {  }
    yield 2
    for q in it.islice(it.count(3), 0, None, 2):
        p = D.pop(q, None)
        if p is None:
            D[q*q] = q
            yield q
        else:
            # old code here:
            # x = p + q
            # while x in D or not (x&1):
            #     x += p
            # changed into:
            x = q + 2*p
            while x in D:
                x += 2*p
            D[x] = p

这里的 not (x&1) 检查是用来确认 x 是奇数的。不过,由于 qp 都是奇数,通过加上 2*p,可以避免一半的步骤,同时也省去了检查奇偶性的步骤。

erat3

如果你不介意多一点复杂性,erat2 可以通过以下更改加速35-40%(注意:需要Python 2.7+或Python 3+,因为用到了 itertools.compress 函数):

import itertools as it
def erat3( ):
    D = { 9: 3, 25: 5 }
    yield 2
    yield 3
    yield 5
    MASK= 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0,
    MODULOS= frozenset( (1, 7, 11, 13, 17, 19, 23, 29) )

    for q in it.compress(
            it.islice(it.count(7), 0, None, 2),
            it.cycle(MASK)):
        p = D.pop(q, None)
        if p is None:
            D[q*q] = q
            yield q
        else:
            x = q + 2*p
            while x in D or (x%30) not in MODULOS:
                x += 2*p
            D[x] = p

erat3 函数利用了这样一个事实:除了2、3和5,所有的质数对30取余后只会得到八个数字,这些数字包含在 MODULOS 的不可变集合中。因此,在输出前面三个质数后,我们从7开始,只处理候选数字。
候选数字的筛选使用了 itertools.compress 函数;这里的“魔法”在于 MASK 序列;MASK 有15个元素(每30个数字中有15个奇数,由 itertools.islice 函数选择),每个可能的候选数字从7开始对应一个 1。这个循环会根据 itertools.cycle 函数的规定重复。
引入候选数字筛选需要另一个修改:or (x%30) not in MODULOS 检查。erat2 算法处理所有奇数;而现在 erat3 算法只处理30的候选数字,我们需要确保所有的 D.keys() 只能是这样的——错误——候选数字。

基准测试

结果

在一台Atom 330的Ubuntu 9.10服务器上,版本2.6.4和3.1.1+:

$ testit
up to 8192
==== python2 erat2 ====
100 loops, best of 3: 18.6 msec per loop
==== python2 erat2a ====
100 loops, best of 3: 14.5 msec per loop
==== python2 erat3 ====
Traceback (most recent call last):
…
AttributeError: 'module' object has no attribute 'compress'
==== python3 erat2 ====
100 loops, best of 3: 19.2 msec per loop
==== python3 erat2a ====
100 loops, best of 3: 14.1 msec per loop
==== python3 erat3 ====
100 loops, best of 3: 11.7 msec per loop

在一台AMD Geode LX的Gentoo家庭服务器上,Python 2.6.5和3.1.2:

$ testit
up to 8192
==== python2 erat2 ====
10 loops, best of 3: 104 msec per loop
==== python2 erat2a ====
10 loops, best of 3: 81 msec per loop
==== python2 erat3 ====
Traceback (most recent call last):
…
AttributeError: 'module' object has no attribute 'compress'
==== python3 erat2 ====
10 loops, best of 3: 116 msec per loop
==== python3 erat2a ====
10 loops, best of 3: 82 msec per loop
==== python3 erat3 ====
10 loops, best of 3: 66 msec per loop

基准测试代码

一个名为 primegen.py 的模块包含了 erat2erat2aerat3 函数。以下是测试脚本:

#!/bin/sh
max_num=${1:-8192}
echo up to $max_num
for python_version in python2 python3
do
    for function in erat2 erat2a erat3
    do
        echo "==== $python_version $function ===="
        $python_version -O -m timeit -c \
        -s  "import itertools as it, functools as ft, operator as op, primegen; cmp= ft.partial(op.ge, $max_num)" \
            "next(it.dropwhile(cmp, primegen.$function()))"
    done
done

撰写回答