Python类型检查numpy数组包括它们的数据类型

2022-05-21 07:40:26 发布

您现在位置:Python中文网/ 问答频道 /正文

我可以使用以下方法验证我的功能是否接收到正确类型的输入:

def foo(x: np.ndarray, y: float):
    return x * y

确保如果我尝试将此函数与不是np.ndarrayx一起使用,即使在运行代码之前,我也会收到一个错误

我不知道的是如何验证数组类型。例如:

 def return_valid_points_only(points: np.ndarray, valid: np.ndarray):
    assert points.shape == valid.shape
    return points[valid]

我想检查一下valid不仅是np.ndarray而且是valid.dtype == bool

对于这个例子,如果valid将提供0和1来表示有效性,那么程序不会失败,我将得到可怕的结果

谢谢


Tags: 方法函数代码功能类型returnfoodef错误npfloatpointsndarrayshapevalid
1条回答
网友
1楼 ·

Python是关于请求原谅,而不是许可的。这意味着,即使在您的第一个定义中,def foo(x: np.ndarray, y: float):也确实依赖于用户来遵守提示,除非您使用类似于mypy的东西

这里有几种方法,通常是串联的。一种是以与传入的输入一起工作的方式编写函数,这可能意味着失败或强制无效输入。另一种方法是仔细记录代码,这样用户就可以做出明智的决定。第二种方法特别重要,但我将重点介绍第一种方法

Numpy为您执行大部分检查。例如,与其期望一个数组,不如强制一个数组:

x = np.asanyarray(x)

^{}通常是array(a, dtype, copy=False, order=order, subok=True)的别名。您可以为y执行类似的操作:

y = np.asanyarray(y).item()

这将允许任何类似的数组,只要它有一个元素,无论是否为标量。另一种方法是尊重numpy一起广播数组的能力,因此如果用户将y作为x.shape[-1]元素的列表传入

对于第二个函数,您有两个选项。一个选择是允许一个奇特的索引。因此,如果用户传入一个索引列表和一个布尔掩码,您可以同时使用这两种方法。另一方面,如果坚持使用布尔掩码,则可以检查或强制数据类型

如果选中,请记住,如果数组大小不匹配,numpy索引操作将为您引发错误。您只需检查类型本身:

points = np.asanyarray(points)
valid = np.asanyarray(valid)
if valid.dtype != bool:
    raise ValueError('valid argument must be a boolean mask')

如果选择强制,则允许用户使用0和1,但不会不必要地复制有效输入:

valid = np.asanyarray(valid, bool)