使用numb优化Jaccard距离性能

2024-05-08 22:42:21 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在尝试使用Numba在python中实现jaccard distance的最快版本

@nb.jit()
def nbjaccard(seq1, seq2):
    set1, set2 = set(seq1), set(seq2)
    return 1 - len(set1 & set2) / float(len(set1 | set2))

def jaccard(seq1, seq2):
    set1, set2 = set(seq1), set(seq2)
    return 1 - len(set1 & set2) / float(len(set1 | set2))


%%timeit
nbjaccard("compare this string","compare a different string")

——12.4毫秒

^{pr2}$

-3.87毫秒

为什么麻木版本需要更长的时间?有没有办法加速?在


Tags: 版本stringlenreturndeffloatcompareset
2条回答

当我计时这两个函数时,nbjaccard需要~4.7微秒(在预热jit之后),而普通python函数使用numba0.32.0需要~3.2微秒。也就是说,在这种情况下,我不希望numba给您任何加速,因为目前在nopython模式下基本上没有字符串支持。这意味着您将遍历python对象层,这通常与不使用jit运行没有什么不同,除非numba可以执行一些智能循环提升(即使用纯内部函数而不是python函数编译子块)。除了numba案例中输入的类型检查之外,您可能只需要支付一些小的开销。在

我认为底线是,您正在尝试将numba用于当前未涵盖的用例。Numba真正擅长的地方是处理numpy数组和对数值标量值的操作或者可以推送到GPU的问题。在

在我看来,允许对象模式numba函数(或者如果numba实现了整个函数使用python对象,则没有警告),因为这些函数通常比纯python函数慢一些。在

Numba非常强大(与C扩展或Cython相比,无需类型声明就可以编写python代码的类型分派非常棒),但只有在它支持操作的情况下:

这意味着在“nopython”模式下不支持未列出的任何操作。如果numba不得不回到"object mode"那就小心了:

object mode

A Numba compilation mode that generates code that handles all values as Python objects and uses the Python C API to perform all operations on those objects. Code compiled in object mode will often run no faster than Python interpreted code, unless the Numba compiler can take advantage of loop-jitting.

你的情况就是这样的:你完全是在对象模式下操作:

>>> nbjaccard.inspect_types()

[...]
#  - LINE 3  - 
#   seq1 = arg(0, name=seq1)  :: pyobject
#   seq2 = arg(1, name=seq2)  :: pyobject
#   $0.1 = global(set: <class 'set'>)  :: pyobject
#   $0.3 = call $0.1(seq1)  :: pyobject
#   $0.4 = global(set: <class 'set'>)  :: pyobject
#   $0.6 = call $0.4(seq2)  :: pyobject
#   set1 = $0.3  :: pyobject
#   set2 = $0.6  :: pyobject

set1, set2 = set(seq1), set(seq2)

#  - LINE 4  - 
#   $const0.7 = const(int, 1)  :: pyobject
#   $0.8 = global(len: <built-in function len>)  :: pyobject
#   $0.11 = set1 & set2  :: pyobject
#   $0.12 = call $0.8($0.11)  :: pyobject
#   $0.13 = global(float: <class 'float'>)  :: pyobject
#   $0.14 = global(len: <built-in function len>)  :: pyobject
#   $0.17 = set1 | set2  :: pyobject
#   $0.18 = call $0.14($0.17)  :: pyobject
#   $0.19 = call $0.13($0.18)  :: pyobject
#   $0.20 = $0.12 / $0.19  :: pyobject
#   $0.21 = $const0.7 - $0.20  :: pyobject
#   $0.22 = cast(value=$0.21)  :: pyobject
#   return $0.22

return 1 - len(set1 & set2) / float(len(set1 | set2))

如您所见,每一个操作都在Python对象上操作(如每行末尾的:: pyobject所示)。这是因为numba不支持strs和sets,所以这里绝对没有比这更快的了。但是你知道如何使用numpy数组或齐次列表(数值类型)来解决这个问题。在

在我的电脑上,时间差要大得多(使用numba 0.32.0),但单个计时要快得多-微秒秒(10**-6秒),而不是毫秒秒(10**-3秒):

^{pr2}$

注意,默认情况下jitlazy,因此第一个调用应该在执行计时之前完成,因为它包括编译代码的时间。在


不过,有一个优化你可以做:如果你知道两个集合的交集,你就可以计算联合的长度(正如@Paul Hankin在他的现在删除了答案中提到的那样):

len(union) = len(set1) + len(set2) - len(intersection)

这将导致以下(纯python)代码:

def jaccard2(seq1, seq2):
    set1, set2 = set(seq1), set(seq2)
    num_intersection = len(set1 & set2)
    return 1 - num_intersection / float(len(set1) + len(set2) - num_intersection)

%timeit jaccard2("compare this string","compare a different string")
100000 loops, best of 3: 13.7 µs per loop

不是更快,而是更好。在


如果使用,还有一些改进空间:

%load_ext cython

%%cython
def cyjaccard(seq1, seq2):
    cdef set set1 = set(seq1)
    cdef set set2 = set()

    cdef Py_ssize_t length_intersect = 0

    for char in seq2:
        if char not in set2:
            if char in set1:
                length_intersect += 1
            set2.add(char)

    return 1 - (length_intersect / float(len(set1) + len(set2) - length_intersect))

%timeit cyjaccard("compare this string","compare a different string")
100000 loops, best of 3: 7.97 µs per loop

这里的主要优点是,只需一次迭代,就可以创建set2并计算交集中元素的数量(根本不需要创建交集)!在

相关问题 更多 >

    热门问题