将数组参数传递给numba函数
我有一段代码,它把一个数组作为参数传给一个numba函数:
import numpy as np
from numba import njit, float64
A = [( 0.0182286178413157, -1.2904019395416308),
( 0.5228683581098151, 0.2323207738837293),
(-0.6056770113345468, 1.5990251249135883),
(-0.7557841434090988, 1.4641641762952791),
( 0.9882455737412416, -1.1838797980930709),
(-1.2168205368640061, 1.5178083863904257),
(-0.5566781056044838, 0.2160324328998916),
( 0.0671405605855369, -0.4246242749812621),
( 0.4806167193998933, 1.0521631181457611),
( 0.0563547059786364, -0.8223422191733811)]
A = np.array(A)
@njit(float64(float64[:]))
def distance(a):
return a[0]**2 + a[1]**2 + 2*a[0]*a[1]
distance(A)
但是我无法找到让这段代码正常运行的签名字符串(目前只对单个数值参数有效)。我总是遇到这个错误:
TypeError: No matching definition for argument type(s) array(float64, 2d, C)
1 个回答
0
你应该这样重写这个函数:
@njit
def distance(A):
return A[:, 0]**2 + A[:, 1]**2 + 2*A[:, 0]*A[:, 1]