Python位掩码(可变长度)

4 投票
3 回答
6643 浏览
提问于 2025-04-16 11:21

为了研究一个问题,我们需要在Python中组织位掩码的搜索。我们有一份原始数据(可以看作是一串比特),大小大约是1.5GB。我们的目标是找出特定位掩码出现的次数。让我给你举个例子来说明这个情况。

input:    sequence of bits, a bitmask to search(mask length: 12bits)

第一个想法(虽然效率不高)是使用异或(XOR)操作,像这样:

1step: from input we take 12 first bits(position 0 to 11) and make XOR with mask 
2step: from input we take bits from 1 to 12 position and XOR with mask ...

让我们先进行前两步:

input sequence 100100011110101010110110011010100101010110101010
mask to search: 100100011110
step 1: take first 12 bits from input: 100100011110 and XOR it with mask.
step 2: teke bits from 1 to 12position: 001000111101 and XOR it with mask.
...

问题是:我们该如何从输入中提取比特?我们可以提取前12个比特,但如何提取从第1位到第12位的比特,以便进行下一次迭代呢?

之前我们使用过Python的BitString库,但搜索所有掩码所花费的时间太长了。还有一点,掩码的大小可以从12比特到256比特不等。

有没有什么建议?这个任务需要用Python来实现。

3 个回答

1

在字节数据中寻找比特模式比一般的搜索要复杂一些。普通的搜索算法并不总是有效,因为从字节数据中提取每一个比特都有成本,而且只有两个字符可以选择,所以仅凭运气,50%的比较结果可能会匹配(这让很多算法的效率大打折扣)。

你提到尝试过bitstring模块(这是我写的),但觉得速度太慢。这让我并不感到意外,所以如果有人有好的主意来解决这个问题,我会很关注!不过,bitstring的处理方式给你提供了一个可能的加速方法:

为了进行匹配,bitstring会把字节数据的一部分转换成普通的'0'和'1'字符串,然后使用Python的find方法进行快速搜索。大部分时间都花在了数据转换成字符串上,但因为你在同一数据上进行多次搜索,所以这样做可以节省很多时间。

masks = ['0000101010100101', '010100011110110101101', '01010101101']
byte_data_chunk = bytearray('blahblahblah')
# convert to a string with one character per bit
# uses lookup table not given here!
s = ''.join(BYTE_TO_BITS[x] for x in byte_data_chunk)
for mask in masks:
    p = s.find(mask)
    # etc.

关键是,一旦你把数据转换成普通字符串,就可以使用内置的find方法来搜索每一个掩码,而这个方法经过了很好的优化。当你使用bitstring时,它每次都要为每个掩码进行转换,这样会严重影响性能。

3

当你的掩码是8位的倍数时,你的搜索变得相对简单,只需要进行字节比较,任何一种 子串搜索算法 都可以用(我建议把它转换成字符串然后使用内置的搜索,因为你可能会遇到字符验证失败的问题)。

sequence = <list of 8-bit integers>
mask = [0b10010001, 0b01101101]
matches = my_substring_search(sequence, mask)

如果掩码大于8位但不是8的倍数,我建议把掩码截断到8的倍数,然后使用上面提到的相同子串搜索。对于找到的任何匹配项,你可以测试剩余部分。

sequence = <list of 8-bit integers>
mask_a = [0b10010001]
mask_b = 0b01100000
mask_b_pattern = 0b11110000   # relevant bits of mask_b
matches = my_substring_search(sequence, mask_a)

for match in matches:
    if (sequence[match+len(mask_a)] & mask_b_pattern) == mask_b:
        valid_match = True  # or something more useful...

如果 sequence 是一个包含4096个字节的列表,你可能需要考虑不同部分之间的重叠。这可以通过将 sequence 设为一个包含 4096+ceil(mask_bits/8.0) 个字节的列表来轻松实现,但每次读取下一个块时仍然只向前移动4096。


下面是生成和使用这些掩码的演示:

class Mask(object):
    def __init__(self, source, source_mask):
        self._masks = list(self._generate_masks(source, source_mask))

    def match(self, buffer, i, j):
        return any(m.match(buffer, i, j) for m in self._masks)

    class MaskBits(object):
        def __init__(self, pre, pre_mask, match_bytes, post, post_mask):
            self.match_bytes = match_bytes
            self.pre, self.pre_mask = pre, pre_mask
            self.post, self.post_mask = post, post_mask

        def __repr__(self):
            return '(%02x %02x) (%s) (%02x %02x)' % (self.pre, self.pre_mask,
                ', '.join('%02x' % m for m in self.match_bytes),
                self.post, self.post_mask)

        def match(self, buffer, i, j):
            return (buffer[i:j] == self.match_bytes and
                    buffer[i-1] & self.pre_mask == self.pre and
                    buffer[j] & self.post_mask == self.post)

    def _generate_masks(self, src, src_mask):
        pre_mask = 0
        pre = 0
        post_mask = 0
        post = 0
        while pre_mask != 0xFF:
            src_bytes = []
            for i in (24, 16, 8, 0):
                if (src_mask >> i) & 0xFF == 0xFF:
                    src_bytes.append((src >> i) & 0xFF)
                else:
                    post_mask = (src_mask >> i) & 0xFF
                    post = (src >> i) & 0xFF
                    break
            yield self.MaskBits(pre, pre_mask, src_bytes, post, post_mask)
            pre += pre
            pre_mask += pre_mask
            if src & 0x80000000: pre |= 0x00000001
            pre_mask |= 0x00000001
            src = (src & 0x7FFFFFFF) * 2
            src_mask = (src_mask & 0x7FFFFFFF) * 2

这段代码并不是一个完整的搜索算法,它只是验证匹配的一部分。Mask对象是用源值和源掩码构建的,两个都是左对齐的,并且(在这个实现中)都是32位长:

src = 0b11101011011011010101001010100000
src_mask = 0b11111111111111111111111111100000

缓冲区是一个字节值的列表:

buffer_1 = [0x7e, 0xb6, 0xd5, 0x2b, 0x88]

一个Mask对象会生成一个内部的移位掩码列表:

>> m = Mask(src, src_mask)
>> m._masks
[(00 00) (eb, 6d, 52) (a0 e0),
 (01 01) (d6, da, a5) (40 c0),
 (03 03) (ad, b5, 4a) (80 80),
 (07 07) (5b, 6a, 95) (00 00),
 (0e 0f) (b6, d5) (2a fe),
 (1d 1f) (6d, aa) (54 fc),
 (3a 3f) (db, 54) (a8 f8),
 (75 7f) (b6, a9) (50 f0)]

中间的元素是完全匹配的子串(从这个对象中没有简单的方法可以提取出来,但它是 m._masks[i].match_bytes)。一旦你使用高效的算法找到了这个子序列,你可以使用 m.match(buffer, i, j) 来验证周围的位,其中 i 是第一个匹配字节的索引,j 是最后一个匹配字节之后的字节索引(这样 buffer[i:j] == match_bytes)。

在上面的 buffer 中,位序列可以从位5开始找到,这意味着 _masks[4].match_bytes 可以在 buffer[1:3] 中找到。因此:

>> m.match(buffer, 1, 3)
True

(随意使用、改编、修改、出售或以任何可能的方式折磨这段代码。我很享受把它组合在一起的过程——这是一个有趣的问题——不过我不对任何错误负责,所以一定要彻底测试!)

4

你现在用的算法是查找数据中“字符串”的一种简单方法,但幸运的是,还有更好的算法可以使用。一个例子就是KMP算法,当然还有其他算法可能更适合你的需求。

使用更好的算法后,你的计算复杂度可以从O(n*m)降低到O(n+m),这样效率会更高。

撰写回答