numpy ndarray的子类无法按预期工作

6 投票
2 回答
866 浏览
提问于 2025-04-18 06:58

大家好。

我发现当我对ndarray进行子类化时,有一些奇怪的行为。

import numpy as np

class fooarray(np.ndarray):
    def __new__(cls, input_array, *args, **kwargs):
        obj = np.asarray(input_array).view(cls)
        return obj

    def __init__(self, *args, **kwargs):
        return

    def __array_finalize__(self, obj):
        return

a=fooarray(np.random.randn(3,5))
b=np.random.randn(3,5)

a_sum=np.sum(a,axis=0,keepdims=True)
b_sum=np.sum(b,axis=0, keepdims=True)

print a_sum.ndim #1
print b_sum.ndim #2

正如你们所看到的,keepdims这个参数在我的子类fooarray中不起作用。它丢失了一个维度。我该如何避免这个问题呢?或者更一般来说,我该如何正确地对子类numpy的ndarray进行操作呢?

2 个回答

2

为了更详细地解释一下@mskimm的评论,如果你查看numpy的源代码中的相关部分,core/fromnumeric.py,就能明白为什么a.sum(..., keepdims=True)可以正常工作,而np.sum(a, ..., keepdims=True)却不行:

def sum(a, axis=None, dtype=None, out=None, keepdims=False):
    ...
    if isinstance(a, _gentype):
        res = _sum_(a)
        if out is not None:
            out[...] = res
            return out
        return res
    elif type(a) is not mu.ndarray:
        try:
            sum = a.sum
        except AttributeError:
            return _methods._sum(a, axis=axis, dtype=dtype,
                                out=out, keepdims=keepdims)
        # NOTE: Dropping the keepdims parameters here...
        return sum(axis=axis, dtype=dtype, out=out)
    else:
        return _methods._sum(a, axis=axis, dtype=dtype,
                            out=out, keepdims=keepdims)
    ...

因为你创建了一个np.ndarray的子类,所以type(a)fooarray,而不是mu.ndarray,所以你最终会到达这一行:

# NOTE: Dropping the keepdims parameters here...
return sum(axis=axis, dtype=dtype, out=out)

keepdims这个参数是ndarrays比较新的一个功能,但目前对于某些其他类似数组的类,比如np.matrixnp.ma.masked_array,这个功能还没有实现,尽管它们也有.sum()方法。因此,对于非ndarray的情况,这个参数就会被忽略。

5

np.sum 可以接受多种类型的输入,不仅仅是 ndarrays(多维数组),还可以是列表、生成器和 np.matrix 等。例如,keepdims 这个参数对于列表或生成器来说显然没有意义。对于 np.matrix 实例来说也不合适,因为 np.matrix 总是有两个维度。如果你查看 np.matrix.sum 的调用方式,你会发现它的 sum 方法没有 keepdims 这个参数:

Definition: np.matrix.sum(self, axis=None, dtype=None, out=None)

所以某些 ndarray 的子类可能有 sum 方法,但这些方法没有 keepdims 参数。这违反了里斯科夫替换原则,也是你遇到问题的根源。

现在如果你查看np.sum 的源代码,你会发现它是一个委托函数,试图根据第一个参数的类型来决定该做什么。

如果第一个参数的类型不是 ndarray,它就会忽略 keepdims 参数。这样做是因为将 keepdims 参数传给 np.matrix.sum 会引发异常。

因此,由于 np.sum 尝试以最通用的方式进行委托,不对 ndarray 的子类可能接受的参数做任何假设,当传入一个 fooarray 时,它就会忽略 keepdims 参数。

解决方法是不要使用 np.sum,而是直接调用 a.sum。这样做更直接,因为 np.sum 只是一个委托函数。

import numpy as np


class fooarray(np.ndarray):
    def __new__(cls, input_array, *args, **kwargs):
        obj = np.asarray(input_array, *args, **kwargs).view(cls)
        return obj

a = fooarray(np.random.randn(3, 5))
b = np.random.randn(3, 5)

a_sum = a.sum(axis=0, keepdims=True)
b_sum = np.sum(b, axis=0, keepdims=True)

print(a_sum.ndim)  # 2
print(b_sum.ndim)  # 2

撰写回答