numba的jit和scipy的hypergeom出现断言错误

2024-06-17 12:51:25 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在尝试使用此功能:

@jit
def pv (matrix1, matrix2, vec, n):
    for i in range (n):
        for j in range (n):
            matrix1[i,j] = 1 - sum (hypergeom.pmf(X, n, int(vec[i]), int(vec[j])) for X in range(matrix2[i,j]))

但是我得到了断言错误,很多事情我甚至还没有开始理解(从numba回溯),最后断言错误:在对象模式管道中失败(步骤:对象模式前端)。我怀疑与海尔格姆有关,但我不知道我哪里做错了

编辑:最后我没有找到实现@jit代码的方法,但我找到了函数scipy.stats.hypergeom.cdf,它可以实现以下功能:

sum (hypergeom.pmf(X, n, int(vec[i]), int(vec[j])) for X in range(matrix2[i,j]))
from scipy.stats import hypergeom

hypergeom.cdf(m2[i,j], n, v[i], v[j])

虽然这个解决方案加快了代码的速度,但是for循环仍然非常慢(n=5053需要运行半个多小时)


Tags: 对象in功能for错误range断言jit
1条回答
网友
1楼 · 发布于 2024-06-17 12:51:25

对于这种类型的东西使用Numba是个好主意,但不幸的是,正如您所怀疑的,它不支持hypergeom函数。你没有做错什么,只是不支持,所以我认为你不能在这种情况下使用Numba

支持内容的列表位于https://numba.pydata.org/numba-doc/latest/reference/numpysupported.html

在过去发生类似情况时,我采取的一种方法是尝试使用numpa确实支持的numpy子集编写我自己版本的不受支持函数,但成功率是可变的,它可能会导致一系列全新的问题(将经过调试、测试的库函数替换为您自己的某个实现可能会导致转储程序火灾)。如果不查看hypergeom.pmf的源代码,我不知道这是否是可行的方法

相关问题 更多 >