为什么麻木错误地推断类型?

2024-06-16 09:45:59 发布

您现在位置:Python中文网/ 问答频道 /正文

计算两组向量之间距离矩阵的简单函数,如

import numpy as np
from numba import jit

@jit(signature_or_function='(f8[:,:], f8[:,:])',  nopython=True, cache=True, locals={'d': numba.float64[:]})
def foo(prototypes, features):
    protonum   = prototypes.shape[0]
    featurenum = features.shape[0]
    dismatrix = np.zeros(shape=(featurenum, protonum), dtype=np.double)
    for i in range(featurenum):
        feature = features[i,:]
        tmp = (prototypes - feature)**2
        d = np.sqrt(tmp.sum(axis=1))   # (nproto, 1)
        dismatrix[i,:] = d
        idx = d.argmin()
    return dismatrix

输入参数“prototypes”是带有shape(protonum,dim)的ndarray,“features”是shape(featurenum,dim)的ndarray。如果“nopython”设置为False,则此函数运行良好,但如果“nopython”设置为True,则会引发错误。错误消息是

^{pr2}$

似乎numba错误地将“d”的类型推断为“float64”,而不是float64数组,即使由“locals={d”显式指定:浮点数64[:]}'. 有没有我做错的地方?在

Numba版本=0.23.1,Python 3.5。在


Tags: 函数importtrue错误npjitfeaturesshape