使用位运算找出 n = 2**x 的指数 [n 的以 2 为底的对数]

15 投票
7 回答
10805 浏览
提问于 2025-04-15 19:13

有没有简单的方法只用位运算从2的幂中提取出指数?

编辑:虽然最开始的问题是关于位运算的,但如果你想知道“在Python中,给定Y = 2,找X的最快方法是什么?”这个讨论也很有意思。

我现在正在尝试优化一个例程(拉宾-米勒素性测试),这个例程需要把一个偶数 N 表示成2**s * d的形式。我可以通过以下方式得到2**s部分:

two_power_s = N & -N

但是我找不到只用位运算提取出s的方法。目前我正在测试的一些变通方法效果都不太好(都比较慢),包括:

  • 使用对数函数
  • 操作2**s的二进制表示(也就是计算尾随的零)
  • 通过不断除以2循环,直到结果为1

我使用的是Python,但我想这个问题的答案应该和编程语言无关。

7 个回答

4

有一个页面上有很多这种技巧和窍门。虽然它是为C语言写的,但很多内容在Python中也能用(当然,性能会有所不同)。你想要的部分可以在这里及之后的内容找到。

你可以试试这个,比如:

register unsigned int r = 0; // result of log2(v) will go here
for (i = 4; i >= 0; i--) // unroll for speed...
{
  if (v & b[i])
  {
    v >>= S[i];
    r |= S[i];
  } 
}

看起来这个代码可以很容易地转换成Python。

7

“与语言无关”和关注性能这两个概念基本上是相互矛盾的。

大多数现代处理器都有一个叫做CLZ的指令,意思是“计算前导零的数量”。在GCC编译器中,你可以用__builtin_clz(x)来调用这个指令(即使在没有CLZ指令的目标上,它也能生成合理的代码,虽然可能不是最快的)。需要注意的是,当输入为零时,这个CLZ的结果是未定义的,所以如果在你的应用中这个情况很重要,你需要额外加一个判断来处理。

在CELT(http://celt-codec.org)中,我们为那些没有CLZ指令的编译器使用的无分支CLZ是由Timothy B. Terriberry编写的:


int ilog(uint32 _v){
  int ret;
  int m;
  ret=!!_v;
  m=!!(_v&0xFFFF0000)<<4;
  _v>>=m;
  ret|=m;
  m=!!(_v&0xFF00)<<3;
  _v>>=m;
  ret|=m;
  m=!!(_v&0xF0)<<2;
  _v>>=m;
  ret|=m;
  m=!!(_v&0xC)<<1;
  _v>>=m;
  ret|=m;
  ret+=!!(_v&0x2);
  return ret;
}

(注释中提到,这种方法比有分支的版本和基于查找表的版本更快)

但是如果性能真的那么重要,你可能不应该用Python来实现这部分代码。

5

简短回答

关于Python来说:

  • 找到2的x次方的最快方法是通过查找一个字典,这个字典的键是2的幂(在代码中可以看到"hashlookup")
  • 最快的位运算方法叫做"unrolled_bitwise"。
  • 前面提到的两种方法都有明确的(但可以扩展的)上限。没有硬编码上限的最快方法是"log_e",它可以根据Python处理数字的能力进行扩展。

前言说明

  1. 下面所有的速度测量都是通过timeit.Timer.repeat(testn, cycles)获得的,其中testn设置为3,cycles由脚本自动调整,以获得秒级的时间(注意:这个自动调整机制之前有个bug,已在2010年2月18日修复)。
  2. 并不是所有方法都能扩展,所以我没有对所有函数进行不同2的幂的测试。
  3. 我没有成功让一些提议的方法工作(函数返回了错误的结果)。我还没有时间进行逐步调试:我把代码(已注释)放在这里,以防有人通过检查发现错误(或者想自己调试)。

结果

func(25)**

hashlookup:          0.13s     100%
lookup:              0.15s     109%
stringcount:         0.29s     220%
unrolled_bitwise:    0.36s     272%
log_e:               0.60s     450%
bitcounter:          0.64s     479%
log_2:               0.69s     515%
ilog:                0.81s     609%
bitwise:             1.10s     821%
olgn:                1.42s    1065%

func(231)**

hashlookup:          0.11s     100%
unrolled_bitwise:    0.26s     229%
log_e:               0.30s     268%
stringcount:         0.30s     270%
log_2:               0.34s     301%
ilog:                0.41s     363%
bitwise:             0.87s     778%
olgn:                1.02s     912%
bitcounter:          1.42s    1264%

func(2128)**

hashlookup:     0.01s     100%
stringcount:    0.03s     264%
log_e:          0.04s     315%
log_2:          0.04s     383%
olgn:           0.18s    1585%
bitcounter:     1.41s   12393%

func(21024)**

log_e:          0.00s     100%
log_2:          0.01s     118%
stringcount:    0.02s     354%
olgn:           0.03s     707%
bitcounter:     1.73s   37695%

代码

import math, sys

def stringcount(v):
    """mac"""    
    return len(bin(v)) - 3

def log_2(v):
    """mac"""    
    return int(round(math.log(v, 2), 0)) # 2**101 generates 100.999999999

def log_e(v):
    """bp on mac"""    
    return int(round(math.log(v)/0.69314718055994529, 0))  # 0.69 == log(2)

def bitcounter(v):
    """John Y on mac"""
    r = 0
    while v > 1 :
        v >>= 1
        r += 1
    return r

def olgn(n) :
    """outis"""
    if n < 1:
        return -1
    low = 0
    high = sys.getsizeof(n)*8 # not the best upper-bound guesstimate, but...
    while True:
        mid = (low+high)//2
        i = n >> mid
        if i == 1:
            return mid
        if i == 0:
            high = mid-1
        else:
            low = mid+1

def hashlookup(v):
    """mac on brone -- limit: v < 2**131"""
#    def prepareTable(max_log2=130) :
#        hash_table = {}
#        for p in range(1, max_log2) :
#            hash_table[2**p] = p
#        return hash_table

    global hash_table
    return hash_table[v] 

def lookup(v):
    """brone -- limit: v < 2**11"""
#    def prepareTable(max_log2=10) :
#        log2s_table=[0]*((1<<max_log2)+1)
#        for i in range(max_log2+1):
#            log2s_table[1<<i]=i
#        return tuple(log2s_table)

    global log2s_table
    return log2s_table[v]

def bitwise(v):
    """Mark Byers -- limit: v < 2**32"""
    b = (0x2, 0xC, 0xF0, 0xFF00, 0xFFFF0000)
    S = (1, 2, 4, 8, 16)
    r = 0
    for i in range(4, -1, -1) :
        if (v & b[i]) :
            v >>= S[i];
            r |= S[i];
    return r

def unrolled_bitwise(v):
    """x4u on Mark Byers -- limit:   v < 2**33"""
    r = 0;
    if v > 0xffff : 
        v >>= 16
        r = 16;
    if v > 0x00ff :
        v >>=  8
        r += 8;
    if v > 0x000f :
        v >>=  4
        r += 4;
    if v > 0x0003 : 
        v >>=  2
        r += 2;
    return r + (v >> 1)

def ilog(v):
    """Gregory Maxwell - (Original code: B. Terriberry) -- limit: v < 2**32"""
    ret = 1
    m = (not not v & 0xFFFF0000) << 4;
    v >>= m;
    ret |= m;
    m = (not not v & 0xFF00) << 3;
    v >>= m;
    ret |= m;
    m = (not not v & 0xF0) << 2;
    v >>= m;
    ret |= m;
    m = (not not v & 0xC) << 1;
    v >>= m;
    ret |= m;
    ret += (not not v & 0x2);
    return ret - 1;


# following table is equal to "return hashlookup.prepareTable()" 
hash_table = {...} # numbers have been cut out to avoid cluttering the post

# following table is equal to "return lookup.prepareTable()" - cached for speed
log2s_table = (...) # numbers have been cut out to avoid cluttering the post

撰写回答