是否可以在函数内部创建一个键类型为UniTuple的numba dict

2024-05-23 19:32:36 发布

您现在位置:Python中文网/ 问答频道 /正文

我想在一个函数中实例化一个numbaDict,并且我希望键类型是三个浮点的元组。为此,a编写了以下代码:

import numba


@numba.njit
def foo():
    local_dict = numba.typed.Dict.empty(
        key_type=numba.types.UniTuple(numba.float64, 3),
        value_type=numba.float64,
    )
    return 1


if __name__ == '__main__':
    foo()

不幸的是,此代码无法编译(错误消息如下)。
但是,当我在模块级别用完全相同的代码实例化local_dict时,它会成功编译。
我还尝试将密钥类型更改为float64,它成功了,这表明(如错误消息)问题来自UniTuple类型。
所以我的问题是:如何声明一个dict,其中一个UniTuple作为函数内部的键

以下是完整的错误消息:

Traceback (most recent call last):
  File "/home/louis/PycharmProjects/Bac_a_sable/numba_sandbox.py", line 19, in <module>
    foo()
  File "/home/louis/.venvs/Bac_a_sable/lib/python3.9/site-packages/numba/core/dispatcher.py", line 420, in _compile_for_args
    error_rewrite(e, 'typing')
  File "/home/louis/.venvs/Bac_a_sable/lib/python3.9/site-packages/numba/core/dispatcher.py", line 361, in error_rewrite
    raise e.with_traceback(None)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Unknown attribute 'UniTuple' of type Module(<module 'numba.core.types' from '/home/louis/.venvs/Bac_a_sable/lib/python3.9/site-packages/numba/core/types/__init__.py'>)

File "numba_sandbox.py", line 8:
def foo():
    <source elided>
        # key_type=numba.float64, value_type=numba.float64,
        key_type=numba.types.UniTuple(dtype=numba.float64, count=3), value_type=numba.float64,
        ^

During: typing of get attribute at /home/louis/PycharmProjects/Bac_a_sable/numba_sandbox.py (8)

File "numba_sandbox.py", line 8:
def foo():
    <source elided>
        # key_type=numba.float64, value_type=numba.float64,
        key_type=numba.types.UniTuple(dtype=numba.float64, count=3), value_type=numba.float64,
        ^


Process finished with exit code 1

Tags: keypyhomefoovaluetypelinefile
1条回答
网友
1楼 · 发布于 2024-05-23 19:32:36

docs声明“jit函数中不支持类型表达式”

import numba

from numba.types import UniTuple

// declare types _outside_ of function definition
value_float = numba.float64
key_float = UniTuple(numba.float64, 3)

@numba.njit
def foo():
    local_dict = numba.typed.Dict.empty(
        key_type=key_float,
        value_type=value_float
        )
    return local_dict


if __name__ == '__main__':
    print(foo()) // prints: {}

相关问题 更多 >