Python 相当于 Bit Twiddling Hacks 的 C 代码?

3 投票
2 回答
1191 浏览
提问于 2025-04-15 19:16

我有一个计算位数的方法,想尽可能让它运行得更快。我想尝试下面这个来自位操作技巧的算法,但我不太懂C语言。这里的'Type T'指的是什么?(T)~(T)0/3在Python中怎么写呢?

这是一个针对位宽最多为128的整数的最佳位计数方法的推广(由类型T参数化):

v = v - ((v >> 1) & (T)~(T)0/3);      // temp 
v = (v & (T)~(T)0/15*3) + ((v >> 2) & (T)~(T)0/15*3);      // temp
v = (v + (v >> 4)) & (T)~(T)0/255*15;                      // temp
c = (T)(v * ((T)~(T)0/255)) >> (sizeof(v) - 1) * CHAR_BIT; // count

2 个回答

2

你复制的内容是一个生成代码的模板。把这个模板直接翻译成另一种语言并指望它能快速运行,这样做并不好。我们来详细看看这个模板。

(T)~(T)0 的意思是“在类型 T 中能容纳的 1 位的数量”。这个算法需要 4 个掩码,我们将为可能感兴趣的不同 T 大小来计算这些掩码。

>>> for N in (8, 16, 32, 64, 128):
...     all_ones = (1 << N) - 1
...     constants = ' '.join([hex(x) for x in [
...         all_ones // 3,
...         all_ones // 15 * 3,
...         all_ones // 255 * 15,
...         all_ones // 255,
...         ]])
...     print N, constants
...
8 0x55 0x33 0xf 0x1
16 0x5555 0x3333 0xf0f 0x101
32 0x55555555L 0x33333333L 0xf0f0f0fL 0x1010101L
64 0x5555555555555555L 0x3333333333333333L 0xf0f0f0f0f0f0f0fL 0x101010101010101L
128 0x55555555555555555555555555555555L 0x33333333333333333333333333333333L 0xf0f0f0f0f0f0f0f0f0f0f0f0f0f0f0fL 0x1010101010101010101010101010101L
>>>

你会注意到,为 32 位情况生成的掩码和硬编码的 32 位 C 代码中的掩码是一样的。实现细节:在 32 位掩码中去掉 L 后缀(Python 2.x 中),在 Python 3.x 中去掉所有 L 后缀。

如你所见,整个模板和 (T)~(T)0 的内容其实只是让人困惑的花言巧语。简单来说,对于一个 k 字节的类型,你需要 4 个掩码:

k bytes each 0x55
k bytes each 0x33
k bytes each 0x0f
k bytes each 0x01

最后的位移只是 N-8(也就是 8*(k-1))位。顺便说一下,我怀疑这个模板代码在 CHAR_BIT 不是 8 的机器上是否真的能工作,但现在这样的机器已经不多了。

更新:还有一个影响从 C 转到 Python 的算法正确性和速度的点。C 的算法通常假设使用无符号整数。在 C 中,对无符号整数的操作是默默地按 2**N 取模。换句话说,只有最低有效的 N 位会被保留。没有溢出异常。许多位操作算法依赖于这一点。然而 (a) Python 的 intlong 是有符号的 (b) 旧版 Python 2.X 会抛出异常,最近的 Python 2.X 会默默地把 int 提升为 long,而 Python 3.x 的 int 相当于 Python 2.x 的 long

正确性问题通常需要在 Python 代码中至少执行一次 register &= all_ones。通常需要仔细分析以确定最小的正确掩码。

使用 long 而不是 int 对效率没有太大帮助。你会发现,32 位的算法即使输入为 0 也会返回一个 long 类型的结果,因为 32 位的 all_ones 是 long 类型。

7

T 是一种整数类型,我猜它是无符号的。因为这是 C 语言,所以它的宽度是固定的,可能是 8、16、32、64 或 128 位中的一种(但不一定)。在代码示例中反复出现的 (T)~(T)0 这个片段,其实就是给出了一个值,计算方式是 2 的 N 次方减去 1,其中 N 是类型 T 的宽度。我怀疑这段代码可能要求 N 是 8 的倍数才能正常工作。

下面是将给定代码直接翻译成 Python 的版本,参数化为 N,即 T 的位宽。

def count_set_bits(v, N=128):
    mask = (1 << N) - 1

    v = v - ((v >> 1) & mask//3)
    v = (v & mask//15*3) + ((v >> 2) & mask//15*3)
    v = (v + (v >> 4)) & mask//255*15
    return (mask & v * (mask//255)) >> (N//8 - 1) * 8

注意事项:

(1) 上面的代码只适用于最大值为 2 的 128 次方的数字。不过,你可能可以将其推广到更大的数字。

(2) 代码中有明显的低效之处:比如 'mask//15' 这个计算被执行了两次。对于 C 语言来说,这没什么问题,因为编译器几乎肯定会在编译时而不是运行时进行除法运算,但 Python 的小优化器可能没有那么聪明。

(3) 在 C 语言中最快的方法不一定能转化为 Python 中最快的方法。为了提高 Python 的速度,你可能需要寻找一种能尽量减少位运算次数的算法。正如亚历山大·盖斯勒所说的:要进行性能分析!

撰写回答