将数组参数传递给numba函数

1 投票
1 回答
37 浏览
提问于 2025-04-13 01:54

我有一段代码,它把一个数组作为参数传给一个numba函数:

import numpy as np
from numba import njit, float64

A = [( 0.0182286178413157, -1.2904019395416308),
 ( 0.5228683581098151,  0.2323207738837293),
 (-0.6056770113345468,  1.5990251249135883),
 (-0.7557841434090988,  1.4641641762952791),
 ( 0.9882455737412416, -1.1838797980930709),
 (-1.2168205368640061,  1.5178083863904257),
 (-0.5566781056044838,  0.2160324328998916),
 ( 0.0671405605855369, -0.4246242749812621),
 ( 0.4806167193998933,  1.0521631181457611),
 ( 0.0563547059786364, -0.8223422191733811)]

A = np.array(A)

@njit(float64(float64[:]))
def distance(a):           
    return a[0]**2 + a[1]**2 + 2*a[0]*a[1]

distance(A)

但是我无法找到让这段代码正常运行的签名字符串(目前只对单个数值参数有效)。我总是遇到这个错误:

TypeError: No matching definition for argument type(s) array(float64, 2d, C)

1 个回答

0

你应该这样重写这个函数:

@njit
def distance(A):
    return A[:, 0]**2 + A[:, 1]**2 + 2*A[:, 0]*A[:, 1]

撰写回答