使用 numba 类方法时出现 NotImplementedError(dtype)

1 投票
1 回答
2913 浏览
提问于 2025-04-20 05:16

我正在使用numpy 1.8.x和numba。我有一个叫做train的函数,它的定义如下:

@autojit 
def train_function( X, y, H):

这个函数返回一个三维的numpy数组。

然后我有一个类,它会调用这个函数,像这样:

class GentleBoostC(object):
# different methods including init
# and now the train function
def train(self, X, y, H):
     self.g_per_round = train_function(X,y,H)

我接着实例化这个类,并用它来训练一个对象。

# initiate the variables X_train, y_train and boosting_rounds
gentlebooster = gbc.GentleBoostC() # gbc has already been imported
gentlebooster.train(X_train,y_train,boosting_rounds)

但是我遇到了这个错误:

    gentlebooster.train(X_train,y_train,boosting_rounds)
  File "C:\Users\app\Documents\Python Scripts\gentleboost_c_class_jit_v7_nolimit.py", line 299, in train
    self.g_per_round = train_function(self,X, y, H)  
  File "C:\Anaconda\lib\site-packages\numba\dispatcher.py", line 152, in typeof_pyval
    dtype = numpy_support.from_dtype(val.dtype)
  File "C:\Anaconda\lib\site-packages\numba\numpy_support.py", line 61, in from_dtype
    raise NotImplementedError(dtype)
NotImplementedError: object

这里发生了什么问题呢?

编辑

查看文档后,它说:

异常 NotImplementedError

这个异常是从RuntimeError派生出来的。在用户定义的基类中,抽象方法应该在需要派生类重写这个方法时抛出这个异常。

这对我的情况来说是什么意思呢?

编辑

关于我如何调用train函数的更多细节:

#img_hogs and sample_labels have already been populated above, both are numpy arrays
X_train = np.array(img_hogs)
y_train = np.array(sample_labels)
boosting_rounds = 7

gentlebooster = gbc.GentleBoostC()
gentlebooster.train(X_train,y_train,boosting_rounds)

1 个回答

1

我的数组 X_train 是一个包含对象的 numpy 数组,而 numba 不支持这种类型

@Korem 说得对! 我实际上是这样从文件中加载 img_hogs 变量的:

img_hogs = np.array(pickle.load(file("C:\\PATH_TO_FILE")), dtype=object)

我一直没有注意到这一点。 当我最终去掉 dtype=object 这一部分时,它就正常工作了!

撰写回答