如何继承numpy.ndarray的子类

6 投票
1 回答
2256 浏览
提问于 2025-04-17 01:45

我在尝试创建自己子类的numpy.ndarray时遇到了困难。我不太明白问题出在哪里,希望有人能解释一下下面这些情况出错的原因,以及我该如何实现我想要的功能。

我想要实现的目标:

我有一个numpy.ndarray的子类,行为符合我的预期(下面代码中的类A)。我想在这个类A的基础上再创建一个子类B(下面代码中的类B),让B包含额外的信息(名字)和方法(装饰过的.simple_data方法)。

情况1:

import numpy as np

class A(np.ndarray):

    def __new__(cls,data):
        obj = np.asarray(data).view(cls)
        return obj

    def __array_finalize(self,obj):
        if obj is None: return

class B(A):

    def __init__(self,data,name):
        super(B,self).__init__(data)
        self.name = name

    @property
    def simple_data(self):
        return [data[0,:],data[:,0]]

if __name__ == '__main__':
    data = np.arange(20).reshape((4,5))
    b = B(data,'B')
    print type(b)
    print b.simple_data

运行这段代码后得到的输出是:

Traceback (most recent call last):
  File "ndsubclass.py", line 24, in <module>
    b = B(data,'B')
TypeError: __new__() takes exactly 2 arguments (3 given)

我猜这和B的构造函数中的'name'变量有关,因为A是numpy.array的子类,所以A的new方法在B的init方法之前被调用。因此,我认为B也需要一个new方法来正确处理这个额外的参数。

我猜想可以这样做:

def __new__(cls,data,name):
    obj = A(data)
    obj.name = name
    return obj

应该可以解决问题,但我该如何改变obj的类呢?

情况2:

import numpy as np

class A(np.ndarray):

    def __new__(cls,data):
        obj = np.asarray(data).view(cls)
        return obj

    def __array_finalize__(self,obj):
        if obj is None: return

class B(A):

    def __new__(cls,data):
        obj = A(data)
        obj.view(cls)
        return obj

    def __array_finalize__(self,obj):
        if obj is None: return

    @property
    def simple_data(self):
        return [self[0,:],self[:,0]]

if __name__ == '__main__':
    data = np.arange(20).reshape((4,5))
    b = B(data)
    print type(b)
    print b.simple_data()

运行后输出是:

<class '__main__.A'>
Traceback (most recent call last):
  File "ndsubclass.py", line 30, in <module>
    print b.simple_data()
AttributeError: 'A' object has no attribute 'simple_data'

这让我感到惊讶,因为我原本期待的是:

<class '__main__.B'>
[array([0, 1, 2, 3, 4]), array([ 0,  5, 10, 15])]

我猜B.new()中的view()调用没有正确设置obj的类。这是为什么呢?

我对发生的事情感到困惑,如果有人能解释一下,我将非常感激。

1 个回答

4

对于情况 1,最简单的方法是:

class B(A):
    def __new__(cls,data,name):
        obj = A.__new__(cls, data)
        obj.name = name
        return obj

__new__其实是一个静态方法,它的第一个参数是类,而不是类方法,所以你可以直接用你想要创建实例的类来调用它。

对于情况 2view并不会直接在原地修改,你需要把结果赋值给某个东西,最简单的方法是:

class B(A):
    def __new__(cls,data):
        obj = A(data)
        return obj.view(cls)

另外,你在中定义的__array_finalize__是一样的(可能只是个打字错误)——你其实不需要这样做。

撰写回答