账号np.diff有漏洞吗?

2024-05-13 04:17:47 发布

您现在位置:Python中文网/ 问答频道 /正文

我遇到了一个问题np.差异不适用于矩阵的切片。这是虫子还是我做错了什么?你知道吗

import numpy as np
from numba import njit
v = np.ones((2,2))
np.diff(v[:,0])
array([0.])
@njit
def numbadiff(x):
    return np.diff(x)

numbadiff(v[:,0])

最后一个调用返回一个错误,但我不知道为什么。你知道吗


Tags: fromimportnumpyasnponesdiff切片
2条回答

问题是Numba中的np.diff进行了内部整形,这只支持contiguous arrays。您正在制作的片v[:, 0]不是连续的,因此出现了错误。可以使用^{}使其工作,如果给定数组尚未连续,则返回该数组的连续副本:

numbadiff(np.ascontiguousarray(v[:, 0]))

注意,您也可以避免np.diff,并将numbadiff重新定义为:

@njit
def numbadiff(x):
    return x[:-1] - x[1:]

当你遇到错误时,礼貌的做法是表现出错误。有时回溯的完全错误是合适的。对于numba来说,这可能太多了,但是您应该尝试发布一个摘要。它使我们更容易,特别是当我们无法运行您的代码并亲自看到错误时。你甚至可以学到一些东西。你知道吗

我运行了你的例子,得到(部分):

In [428]: numbadiff(np.ones((2,2))[:,0])                                        
                                     -
TypingError    
...
TypeError: reshape() supports contiguous array only
...
    def diff_impl(a, n=1):
        <source elided>
        # To make things easier, normalize input and output into 2d arrays
        a2 = a.reshape((-1, size))
...
TypeError: reshape() supports contiguous array only
....
This is not usually a problem with Numba itself but instead often caused by
the use of unsupported features or an issue in resolving types.

这支持@jdehesa提供的诊断和修复。这不是numba中的bug,而是您的输入有问题。你知道吗

使用numba的一个缺点是错误更难理解。另一个明显的问题是,它对于诸如这个数组视图之类的输入没有那么灵活。如果你真的想要速度上的优势,你需要愿意自己去挖掘错误信息。你知道吗

相关问题 更多 >