子类化 numpy ndarray 问题

12 投票
3 回答
5727 浏览
提问于 2025-04-16 12:42

我想要创建一个新的类,继承自numpy的ndarray(这是一个用来处理数组的工具)。不过,我不能改变这个数组。为什么用self = ...的方式不能改变这个数组呢?谢谢。

import numpy as np

class Data(np.ndarray):

    def __new__(cls, inputarr):
        obj = np.asarray(inputarr).view(cls)
        return obj

    def remove_some(self, t):
        test_cols, test_vals = zip(*t)
        test_cols = self[list(test_cols)]
        test_vals = np.array(test_vals, test_cols.dtype)

        self = self[test_cols != test_vals] # Is this part correct?

        print len(self) # correct result

z = np.array([(1,2,3), (4,5,6), (7,8,9)],
    dtype=[('a', int), ('b', int), ('c', int)])
d = Data(z)
d.remove_some([('a',4)])

print len(d)  # output the same size as original. Why?

3 个回答

3

我也试过这样做,但其实要对ndarray进行子类化真的很复杂。

如果你只是想增加一些功能,我建议你可以创建一个类,把数组作为一个属性来存储。

class Data(object):

    def __init__(self, array):
        self.array = array

    def remove_some(self, t):
        //operate on self.array
        pass

d = Data(z)
print(d.array)
6

你没有得到预期结果的原因是因为你在方法 remove_some 中重新赋值了 self。这实际上是在创建一个新的局部变量 self。如果你的数组形状不变,你可以简单地用 self[:] = ... 来保持对 self 的引用,这样一切都会正常,但你现在是想改变 self 的形状。这就意味着我们需要重新分配一些新的内存,并且在提到 self 时需要改变指向的地方。

我不知道怎么做。我原以为可以通过 __array_finalize____array____array_wrap__ 来实现。但我尝试的所有方法都没有成功。

现在,有另一种方法可以做到这一点,而不需要继承 ndarray。你可以创建一个新类,里面有一个属性是 ndarray,然后重写所有常用的 __add____mul__ 等等。大概是这样的:

Class Data(object):
    def __init__(self, inarr):
        self._array = np.array(inarr)
    def remove_some(x):
        self._array = self._array[x]
    def __add__(self, other):
        return np.add(self._array, other)

你明白我的意思了。重写所有操作符确实很麻烦,但从长远来看,我觉得这样更灵活。

你需要仔细阅读 这篇文档,才能正确地做到这一点。有一些方法,比如 __array_finalize__,需要在合适的时机被调用来进行“清理”。

5

也许可以把这个做成一个函数,而不是一个方法:

import numpy as np

def remove_row(arr,col,val):
    return arr[arr[col]!=val]

z = np.array([(1,2,3), (4,5,6), (7,8,9)],
    dtype=[('a', int), ('b', int), ('c', int)])

z=remove_row(z,'a',4)
print(repr(z))

# array([(1, 2, 3), (7, 8, 9)], 
#       dtype=[('a', '<i4'), ('b', '<i4'), ('c', '<i4')])

或者,如果你想把它作为一个方法,

import numpy as np

class Data(np.ndarray):

    def __new__(cls, inputarr):
        obj = np.asarray(inputarr).view(cls)
        return obj

    def remove_some(self, col, val):
        return self[self[col] != val]

z = np.array([(1,2,3), (4,5,6), (7,8,9)],
    dtype=[('a', int), ('b', int), ('c', int)])
d = Data(z)
d = d.remove_some('a', 4)
print(d)

这里的关键区别是,remove_some 并不试图去改变 self,它只是返回一个新的 Data 实例。

撰写回答