计算两组向量之间距离矩阵的简单函数,如
import numpy as np
from numba import jit
@jit(signature_or_function='(f8[:,:], f8[:,:])', nopython=True, cache=True, locals={'d': numba.float64[:]})
def foo(prototypes, features):
protonum = prototypes.shape[0]
featurenum = features.shape[0]
dismatrix = np.zeros(shape=(featurenum, protonum), dtype=np.double)
for i in range(featurenum):
feature = features[i,:]
tmp = (prototypes - feature)**2
d = np.sqrt(tmp.sum(axis=1)) # (nproto, 1)
dismatrix[i,:] = d
idx = d.argmin()
return dismatrix
输入参数“prototypes”是带有shape(protonum,dim)的ndarray,“features”是shape(featurenum,dim)的ndarray。如果“nopython”设置为False,则此函数运行良好,但如果“nopython”设置为True,则会引发错误。错误消息是
^{pr2}$似乎numba错误地将“d”的类型推断为“float64”,而不是float64数组,即使由“locals={d”显式指定:浮点数64[:]}'. 有没有我做错的地方?在
Numba版本=0.23.1,Python 3.5。在
目前没有回答
相关问题 更多 >
编程相关推荐