numba guvectorize 中的列表索引错误

1 投票
1 回答
961 浏览
提问于 2025-04-18 10:13

我刚接触numba / numbapro。最近我在尝试运行一个例子,关于使用guvectorize的通用Ufuncs:

(这是例子的链接): http://docs.continuum.io/numbapro/quickstart.html#numbapro-guvectorize

import numbapro as numbapro

@numbapro.guvectorize(['void(int32[:], int32[:])'], '(n)->()')
def sum_row(inp, out):
    """
    Sum every row

    function type: two arrays
                   (note: scalar is represented as an array of length 1)
    signature: n elements to scalar
    """
    tmp = 0.
    for i in range(inp.shape[0]):
        tmp += inp[i]
    out[0] = tmp

我遇到了这个错误:

IndexError                                Traceback (most recent call last)
<ipython-input-98-79514a184595> in <module>()
----> 1 @numbapro.guvectorize(['void(int32[:], int32[:])'], '(n)->()')
      2 def sum_row(inp, out):
      3     """
      4     Sum every row
      5 

/users/adelacalle/anaconda_linux/lib/python2.7/site-packages/numba/npyufunc/decorators.pyc in wrap(func)
    117         for fty in ftylist:
    118             guvec.add(fty)
--> 119         return guvec.build_ufunc()
    120 
    121     return wrap

/users/adelacalle/anaconda_linux/lib/python2.7/site-packages/numba/npyufunc/ufuncbuilder.pyc in build_ufunc(self)
    149 
    150         for sig, cres in self.nb_func.overloads.items():
--> 151             dtypenums, ptr = self.build(cres)
    152             dtypelist.append(dtypenums)
    153             ptrlist.append(utils.longint(ptr))

/users/adelacalle/anaconda_linux/lib/python2.7/site-packages/numba/npyufunc/ufuncbuilder.pyc in build(self, cres)
    167         signature = cres.signature
    168         wrapper = build_gufunc_wrapper(ctx, cres.llvm_func, signature,
--> 169                                        self.sin, self.sout)
    170         ctx.engine.add_module(wrapper.module)
    171         ptr = ctx.engine.get_pointer_to_function(wrapper)

/users/adelacalle/anaconda_linux/lib/python2.7/site-packages/numba/npyufunc/wrappers.pyc in build_gufunc_wrapper(context, func, signature, sin, sout)
    143     for i, (typ, sym) in enumerate(zip(signature.args, sin + sout)):
    144         ary = GUArrayArg(context, builder, arg_args, arg_dims, arg_steps, i,
--> 145                          step_offset, typ, sym, sym_dim)
    146         step_offset += ary.ndim
    147         arrays.append(ary)

/users/adelacalle/anaconda_linux/lib/python2.7/site-packages/numba/npyufunc/wrappers.pyc in __init__(self, context, builder, args, dims, steps, i, step_offset, typ, syms, sym_dim)
    207         self.array = arycls(context, builder)
    208         self.array.data = builder.bitcast(self.data, self.array.data.type)
--> 209         self.array.shape = cgutils.pack_array(builder, self.shape)
    210         self.array.strides = cgutils.pack_array(builder, self.strides)
    211         self.array_value = self.array._getpointer()

/users/adelacalle/anaconda_linux/lib/python2.7/site-packages/numba/cgutils.pyc in pack_array(builder, values)
    257 def pack_array(builder, values):
    258     n = len(values)
--> 259     ty = values[0].type
    260     ary = Constant.undef(Type.array(ty, n))
    261     for i, v in enumerate(values):

IndexError: list index out of range

我没有找到比这个链接更多的文档。我是不是做错了什么?我发现当签名中有空括号时会出现这个问题。 我在一台Linux机器上运行,numbapro的版本是0.14.1。

提前谢谢你,

亚历克斯

1 个回答

2

最后我自己得出了一个结论。你可以分配一个实际包含n个元素的数组,所以你可以把函数的签名写成(n)->(n),而不是(n)->()。文档里可能有错误,或者是过时了。

不过,这样做效率不高,因为你需要分配整个数组,这样会浪费内存(虽然这样是可行的!)。更好的方法是使用guvectorize来对数组中的元素进行求和,这样做既整洁又高效,所以我使用了@jit。

撰写回答