我在这里对Numba有什么错?

2024-04-16 08:24:08 发布

您现在位置:Python中文网/ 问答频道 /正文

我正在尝试学习如何使用Numba模块。到目前为止,由于与NumPy的接口出现了一些问题,我还没有得到任何工作。这是我正在运行的代码(来自Numba文档)和我得到的错误:

from numba import jit
import numpy as np

x = np.arange(100).reshape(10, 10)

@jit(nopython=True) # Set "nopython" mode for best performance, equivalent to @njit
def go_fast(a): # Function is compiled to machine code when called the first time
    trace = 0.0
    for i in range(a.shape[0]):   # Numba likes loops
        trace += np.tanh(a[i, i]) # Numba likes NumPy functions
    return a + trace              # Numba likes NumPy broadcasting

print(go_fast(x))
    Traceback (most recent call last):
File "C:/Users/JoHn/Documents/Current Classes/MEEN575_Optimization/HW6/Optimal_controller/angle_wrapping.py", line 84, in <module>
print(go_fast(x))
TypeError: expected dtype object, got 'numpy.dtype[float64]'

我从一些搜索中知道,这是或是最近的一个已知错误,与需要更新版本的NumPy或类似版本的NUBA的新版本有关,但据我所知,我有最新的NumPy版本,版本1.20。有关于我做错了什么的提示吗?说清楚一点,我从来没有对如何用python干净地设置环境有过很好的理解,所以很可能我在这里遗漏了一些明显的东西


Tags: toimport版本numpygofor错误np
1条回答
网友
1楼 · 发布于 2024-04-16 08:24:08

更新至0.53.1有效。对我来说,它在0.47.x上也失败了。似乎更多的是numpy问题。解决安装numpy的一种方法>=1.20.0和numba v>;0.52.

有关此问题的更多信息: https://github.com/numba/numba/issues/6041

附言:不确定你是否仍然有这个错误,只是想更新,面临着类似的问题

相关问题 更多 >