使用Numb提取numpy数组中的特定行

2021-06-13 14:19:49 发布

您现在位置:Python中文网/ 问答频道 /正文

我有以下数组:

import numpy as np
from numba import njit


test_array = np.random.rand(4, 10)

我创建了一个“jitted”函数,用于分割数组,然后执行一些操作:

@njit(fastmath = True)
def test_function(array):

   test_array_sliced = test_array[[0,1,3]]

   return test_array_sliced

但是,Numba抛出以下错误:

In definition 11:
    TypeError: unsupported array index type list(int64) in [list(int64)]
    raised from /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/numba/typing/arraydecl.py:71
This error is usually caused by passing an argument of a type that is unsupported by the named function.

变通办法

我尝试使用np.delete删除不需要的行,但由于必须指定axisNumba,因此会引发以下错误:

@njit(fastmath = True)
def test_function(array):

   test_array_sliced = np.delete(test_array, obj = 2, axis = 0)

   return test_array_sliced

In definition 1:
    TypeError: np_delete() got an unexpected keyword argument 'axis'
    raised from /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/numba/typing/templates.py:475
This error is usually caused by passing an argument of a type that is unsupported by the named function.

你知道如何在Numba下提取特定的行吗?你知道吗

2条回答
网友
1楼 ·

Numba不支持numpy fancy索引。我不能100%确定您的实际用例是什么样的,但一个简单的方法是:

import numpy as np
import numba as nb

@nb.njit
def test_func(x):
    idx = (0, 1, 3)
    res = np.empty((len(idx), x.shape[1]), dtype=x.dtype)
    for i, ix in enumerate(idx):
        res[i] = x[ix]

    return res

test_array = np.random.rand(4, 10)
print(test_array)
print()
print(test_func(test_array))

编辑:@kwinkunks是正确的,我最初的回答是错误的,不支持花哨的索引。这是在一系列有限的情况下,包括这一个。你知道吗

网友
2楼 ·

我认为如果使用数组而不是列表进行索引,它将起作用(似乎建议使用soin the docs):

test_array_sliced = array[np.array([0,1,3])]

(我将要切片的数组更改为array,这是传递给函数的内容。也许是故意的,但要小心全局的!)你知道吗

相关问题