Python传递列表或Numpy数组
我在想,有没有人能提供一个优雅的解决方案,让我们可以把一个Python列表、一个形状为(n,)的numpy向量,或者一个形状为(n,1)的numpy向量传递给一个函数。我的想法是,让这个函数能够通用,能够接受这三种类型,而不增加复杂性。
初步想法:
1) Use a type checking decorator function and cast to a standard representation.
2) Add type checking logic inline (significantly less ideal than #1).
3) ?
我一般不使用Python内置的数组类型,但我怀疑这个问题的解决方案也应该能支持这些类型。
2 个回答
3
我觉得最简单的方法就是在你的函数开头使用 numpy.atleast_2d
。这样,你的三种情况都会被转换成 x.shape == (n, 1)
的形式,这样可以让你的函数更简单。
举个例子,
def sum(x):
x = np.atleast_2d(x)
return np.dot(x, np.ones((x.shape[0], 1)))
atleast_2d
会返回一个数组的视图,所以如果你传入的已经是一个 ndarray
,那就不会增加太多的负担。不过,如果你打算修改 x
,想要做一个副本的话,可以用 x = np.atleast_2d(np.array(x))
这样来实现。
2
你可以把这三种类型转换成一种“标准”的类型,也就是一个一维数组,方法如下:
arr = np.asarray(arr).ravel()
在这里加一个装饰器:
import numpy as np
import functools
def takes_1dim_array(func):
@functools.wraps(func)
def f(arr, *a, **kw):
arr = np.asarray(arr).ravel()
return func(arr, *a, **kw)
return f
然后:
@takes_1dim_arr
def func(arr):
print arr.shape