numpy ndarray的子类无法按预期工作
大家好。
我发现当我对ndarray进行子类化时,有一些奇怪的行为。
import numpy as np
class fooarray(np.ndarray):
def __new__(cls, input_array, *args, **kwargs):
obj = np.asarray(input_array).view(cls)
return obj
def __init__(self, *args, **kwargs):
return
def __array_finalize__(self, obj):
return
a=fooarray(np.random.randn(3,5))
b=np.random.randn(3,5)
a_sum=np.sum(a,axis=0,keepdims=True)
b_sum=np.sum(b,axis=0, keepdims=True)
print a_sum.ndim #1
print b_sum.ndim #2
正如你们所看到的,keepdims
这个参数在我的子类fooarray
中不起作用。它丢失了一个维度。我该如何避免这个问题呢?或者更一般来说,我该如何正确地对子类numpy的ndarray进行操作呢?
2 个回答
为了更详细地解释一下@mskimm的评论,如果你查看numpy的源代码中的相关部分,core/fromnumeric.py
,就能明白为什么a.sum(..., keepdims=True)
可以正常工作,而np.sum(a, ..., keepdims=True)
却不行:
def sum(a, axis=None, dtype=None, out=None, keepdims=False):
...
if isinstance(a, _gentype):
res = _sum_(a)
if out is not None:
out[...] = res
return out
return res
elif type(a) is not mu.ndarray:
try:
sum = a.sum
except AttributeError:
return _methods._sum(a, axis=axis, dtype=dtype,
out=out, keepdims=keepdims)
# NOTE: Dropping the keepdims parameters here...
return sum(axis=axis, dtype=dtype, out=out)
else:
return _methods._sum(a, axis=axis, dtype=dtype,
out=out, keepdims=keepdims)
...
因为你创建了一个np.ndarray
的子类,所以type(a)
是fooarray
,而不是mu.ndarray
,所以你最终会到达这一行:
# NOTE: Dropping the keepdims parameters here...
return sum(axis=axis, dtype=dtype, out=out)
keepdims
这个参数是ndarrays
比较新的一个功能,但目前对于某些其他类似数组的类,比如np.matrix
或np.ma.masked_array
,这个功能还没有实现,尽管它们也有.sum()
方法。因此,对于非ndarray
的情况,这个参数就会被忽略。
np.sum
可以接受多种类型的输入,不仅仅是 ndarrays(多维数组),还可以是列表、生成器和 np.matrix
等。例如,keepdims
这个参数对于列表或生成器来说显然没有意义。对于 np.matrix
实例来说也不合适,因为 np.matrix
总是有两个维度。如果你查看 np.matrix.sum
的调用方式,你会发现它的 sum
方法没有 keepdims
这个参数:
Definition: np.matrix.sum(self, axis=None, dtype=None, out=None)
所以某些 ndarray
的子类可能有 sum
方法,但这些方法没有 keepdims
参数。这违反了里斯科夫替换原则,也是你遇到问题的根源。
现在如果你查看np.sum 的源代码,你会发现它是一个委托函数,试图根据第一个参数的类型来决定该做什么。
如果第一个参数的类型不是 ndarray
,它就会忽略 keepdims
参数。这样做是因为将 keepdims
参数传给 np.matrix.sum
会引发异常。
因此,由于 np.sum
尝试以最通用的方式进行委托,不对 ndarray
的子类可能接受的参数做任何假设,当传入一个 fooarray
时,它就会忽略 keepdims
参数。
解决方法是不要使用 np.sum
,而是直接调用 a.sum
。这样做更直接,因为 np.sum
只是一个委托函数。
import numpy as np
class fooarray(np.ndarray):
def __new__(cls, input_array, *args, **kwargs):
obj = np.asarray(input_array, *args, **kwargs).view(cls)
return obj
a = fooarray(np.random.randn(3, 5))
b = np.random.randn(3, 5)
a_sum = a.sum(axis=0, keepdims=True)
b_sum = np.sum(b, axis=0, keepdims=True)
print(a_sum.ndim) # 2
print(b_sum.ndim) # 2