根据列表B筛选列表A中的元素,使得A中的每个元素a至少有一个B中的元素b满足a = (a&b)
我有两个整数列表:A 和 B。我想知道怎么能高效地确保 A 中的每个元素在 B 中至少有一个元素,满足它们进行位与运算后,结果是 A 中的那个元素。
举个例子,假设 B = [14, 13]。
对于 A 中的元素:
a1 = 12 是有效的,因为 12 & 14 = 12。
a2 = 3 是无效的,因为 3 & 14 = 2,3 & 13 = 1。
A 和 B 的元素数量可以达到 10^10。
我正在从 CSV 文件中读取 B,并动态生成 A。然后我用以下方式检查上述条件:
df.applymap(lambda x: (x & a) == a).any().any()
但是随着 B 的大小增加,这个检查成了我的瓶颈。
1 个回答
1
如果速度是一个音乐会,你可以试试用 numba:
from numba import njit
@njit
def check(a, b):
for val_a in a:
for val_b in b:
if (val_b & val_a) == val_a:
return True
return False
a = np.array([12, 3], dtype=np.uint8)
b = np.array([14, 13], dtype=np.uint8)
print(check(a, b))
输出结果:
True
基准测试:
from statistics import median
from timeit import repeat
np.random.seed(42)
def setup():
a = np.random.randint(0, 255, size=10_000_000, dtype=np.uint8)
b = np.random.randint(0, 255, size=10_000_000, dtype=np.uint8)
return a, b
t = repeat(
"check(a, b)", setup="a, b = setup()", repeat=1000, number=1, globals=globals()
)
print(f"t={median(t):.8f}")
在我的电脑上(AMD 5700x),这个输出:
t=0.00000496
编辑:对于 np.uint64
的值从 0
到 MAX(np.unit64)
:
def setup():
a = np.random.randint(0, np.iinfo(np.uint64).max, size=10_000_000, dtype=np.uint64)
b = np.random.randint(0, np.iinfo(np.uint64).max, size=10_000_000, dtype=np.uint64)
return a, b
t = repeat(
"check(a, b)", setup="a, b = setup()", repeat=50, number=1, globals=globals()
)
print(f"t={median(t):.8f}")
输出结果:
t=0.04230742
编辑2:对 B
数组中的前1000个项目按位数进行排序:
from statistics import median
from timeit import repeat
np.random.seed(42)
# https://stackoverflow.com/a/68943135/10035985
@njit
def bit_count(arr):
# Make the values type-agnostic (as long as it's integers)
t = arr.dtype.type
mask = t(-1)
s55 = t(0x5555555555555555 & mask) # Add more digits for 128bit support
s33 = t(0x3333333333333333 & mask)
s0F = t(0x0F0F0F0F0F0F0F0F & mask)
s01 = t(0x0101010101010101 & mask)
arr = arr - ((arr >> np.uint8(1)) & s55)
arr = (arr & s33) + ((arr >> np.uint8(2)) & s33)
arr = (arr + (arr >> np.uint8(4))) & s0F
return (arr * s01) >> np.uint16((8 * (arr.itemsize - 1)))
@njit
def check(a, b):
b[:1000] = b[np.argsort(bit_count(b[:1000]))[::-1]]
for val_a in a:
for val_b in b:
if (val_b & val_a) == val_a:
return True
return False
def setup():
a = np.random.randint(0, np.iinfo(np.uint64).max, size=10_000_000, dtype=np.uint64)
b = np.random.randint(0, np.iinfo(np.uint64).max, size=10_000_000, dtype=np.uint64)
return a, b
check(np.array([1, 2, 3], dtype=np.uint64), np.array([1, 2, 3], dtype=np.uint64))
t = repeat(
"check(a, b)", setup="a, b = setup()", repeat=50, number=1, globals=globals()
)
print(f"t={median(t):.8f}")
输出结果:
t=0.04198640
编辑3:如果你想检查从 A 到 B 数组的每个元素:
@njit
def check(a, b):
for val_a in a:
for val_b in b:
if (val_b & val_a) == val_a:
break
else:
return False
return True
def setup():
a = np.random.randint(0, np.iinfo(np.uint64).max, size=10_000_000, dtype=np.uint64)
b = np.random.randint(0, np.iinfo(np.uint64).max, size=10_000_000, dtype=np.uint64)
return a, b
check(np.array([1, 2, 3], dtype=np.uint64), np.array([1, 2, 3], dtype=np.uint64))
t = repeat(
"check(a, b)", setup="a, b = setup()", repeat=50, number=1, globals=globals()
)
print(f"t={median(t):.8f}")
输出结果:
t=0.00415215
编辑4:创建掩码的并行版本:
from statistics import median
from timeit import repeat
import numpy as np
from numba import njit, prange
@njit(parallel=True)
def check(a, b):
out = np.ones_like(a, dtype=np.uint8)
for i in prange(len(a)):
val_a = a[i]
for val_b in b:
if (val_b & val_a) == val_a:
break
else:
out[i] = 0
return out
def setup():
a = np.random.randint(0, np.iinfo(np.uint64).max, size=1_000_000, dtype=np.uint64)
b = np.random.randint(0, np.iinfo(np.uint64).max, size=1_000_000, dtype=np.uint64)
return a, b
check(np.array([1, 2, 3], dtype=np.uint64), np.array([1, 2, 3], dtype=np.uint64))
t = repeat("check(a, b)", setup="a, b = setup()", repeat=1, number=1, globals=globals())
print(f"t={median(t):.8f}")
输出结果:
t=79.79640480