根据列表B筛选列表A中的元素,使得A中的每个元素a至少有一个B中的元素b满足a = (a&b)

1 投票
1 回答
99 浏览
提问于 2025-04-14 15:37

我有两个整数列表:A 和 B。我想知道怎么能高效地确保 A 中的每个元素在 B 中至少有一个元素,满足它们进行位与运算后,结果是 A 中的那个元素。

举个例子,假设 B = [14, 13]。

对于 A 中的元素:
a1 = 12 是有效的,因为 12 & 14 = 12。
a2 = 3 是无效的,因为 3 & 14 = 2,3 & 13 = 1。

A 和 B 的元素数量可以达到 10^10。

我正在从 CSV 文件中读取 B,并动态生成 A。然后我用以下方式检查上述条件:

df.applymap(lambda x: (x & a) == a).any().any()

但是随着 B 的大小增加,这个检查成了我的瓶颈。

1 个回答

1

如果速度是一个音乐会,你可以试试用

from numba import njit


@njit
def check(a, b):
    for val_a in a:
        for val_b in b:
            if (val_b & val_a) == val_a:
                return True
    return False


a = np.array([12, 3], dtype=np.uint8)
b = np.array([14, 13], dtype=np.uint8)

print(check(a, b))

输出结果:

True

基准测试:

from statistics import median
from timeit import repeat

np.random.seed(42)


def setup():
    a = np.random.randint(0, 255, size=10_000_000, dtype=np.uint8)
    b = np.random.randint(0, 255, size=10_000_000, dtype=np.uint8)
    return a, b


t = repeat(
    "check(a, b)", setup="a, b = setup()", repeat=1000, number=1, globals=globals()
)

print(f"t={median(t):.8f}")

在我的电脑上(AMD 5700x),这个输出:

t=0.00000496

编辑:对于 np.uint64 的值从 0MAX(np.unit64)

def setup():
    a = np.random.randint(0, np.iinfo(np.uint64).max, size=10_000_000, dtype=np.uint64)
    b = np.random.randint(0, np.iinfo(np.uint64).max, size=10_000_000, dtype=np.uint64)
    return a, b


t = repeat(
    "check(a, b)", setup="a, b = setup()", repeat=50, number=1, globals=globals()
)

print(f"t={median(t):.8f}")

输出结果:

t=0.04230742

编辑2:对 B 数组中的前1000个项目按位数进行排序:

from statistics import median
from timeit import repeat

np.random.seed(42)


# https://stackoverflow.com/a/68943135/10035985
@njit
def bit_count(arr):
    # Make the values type-agnostic (as long as it's integers)
    t = arr.dtype.type
    mask = t(-1)
    s55 = t(0x5555555555555555 & mask)  # Add more digits for 128bit support
    s33 = t(0x3333333333333333 & mask)
    s0F = t(0x0F0F0F0F0F0F0F0F & mask)
    s01 = t(0x0101010101010101 & mask)

    arr = arr - ((arr >> np.uint8(1)) & s55)
    arr = (arr & s33) + ((arr >> np.uint8(2)) & s33)
    arr = (arr + (arr >> np.uint8(4))) & s0F
    return (arr * s01) >> np.uint16((8 * (arr.itemsize - 1)))


@njit
def check(a, b):
    b[:1000] = b[np.argsort(bit_count(b[:1000]))[::-1]]

    for val_a in a:
        for val_b in b:
            if (val_b & val_a) == val_a:
                return True
    return False


def setup():
    a = np.random.randint(0, np.iinfo(np.uint64).max, size=10_000_000, dtype=np.uint64)
    b = np.random.randint(0, np.iinfo(np.uint64).max, size=10_000_000, dtype=np.uint64)
    return a, b


check(np.array([1, 2, 3], dtype=np.uint64), np.array([1, 2, 3], dtype=np.uint64))

t = repeat(
    "check(a, b)", setup="a, b = setup()", repeat=50, number=1, globals=globals()
)

print(f"t={median(t):.8f}")

输出结果:

t=0.04198640

编辑3:如果你想检查从 A 到 B 数组的每个元素:

@njit
def check(a, b):
    for val_a in a:
        for val_b in b:
            if (val_b & val_a) == val_a:
                break
        else:
            return False
    return True


def setup():
    a = np.random.randint(0, np.iinfo(np.uint64).max, size=10_000_000, dtype=np.uint64)
    b = np.random.randint(0, np.iinfo(np.uint64).max, size=10_000_000, dtype=np.uint64)
    return a, b


check(np.array([1, 2, 3], dtype=np.uint64), np.array([1, 2, 3], dtype=np.uint64))

t = repeat(
    "check(a, b)", setup="a, b = setup()", repeat=50, number=1, globals=globals()
)

print(f"t={median(t):.8f}")

输出结果:

t=0.00415215

编辑4:创建掩码的并行版本:

from statistics import median
from timeit import repeat

import numpy as np
from numba import njit, prange

@njit(parallel=True)
def check(a, b):
    out = np.ones_like(a, dtype=np.uint8)

    for i in prange(len(a)):
        val_a = a[i]

        for val_b in b:
            if (val_b & val_a) == val_a:
                break
        else:
            out[i] = 0

    return out


def setup():
    a = np.random.randint(0, np.iinfo(np.uint64).max, size=1_000_000, dtype=np.uint64)
    b = np.random.randint(0, np.iinfo(np.uint64).max, size=1_000_000, dtype=np.uint64)
    return a, b


check(np.array([1, 2, 3], dtype=np.uint64), np.array([1, 2, 3], dtype=np.uint64))

t = repeat("check(a, b)", setup="a, b = setup()", repeat=1, number=1, globals=globals())

print(f"t={median(t):.8f}")

输出结果:

t=79.79640480

撰写回答