Numpy ndarray子类 - 在__array_finalize__中强制重塑

2 投票
2 回答
953 浏览
提问于 2025-04-17 16:24

我遇到了一些问题:

我想写一个ndarray的子类,并且希望这个子类的任何新实例都能强制设置形状为(-1,3),无论是通过显式构造函数、视图转换还是模板生成。

我尝试了很多方法,但似乎都没有成功。 我觉得我可能还没有完全理解这个过程。任何帮助都非常感谢!

import numpy as np

class test(np.ndarray):
def __new__(cls, *args, **kwargs):
    return np.ndarray.__new__(cls, *args, **kwargs)

def __array_finalize__(self, obj):

#        self.resize(-1,3)
#        self.reshape(-1,3)
#        self=self.reshape(-1,3)
        np.reshape(self,(-1,3))

a=np.array([1,2,3])
b=a.view(test)
c=test(a)
d=a.reshape(-1,3)
print '+++++++'
print a.shape,a
print '+++++++'
print b.shape,b
print '+++++++'
print c.shape,c
print '+++++++'
print d.shape,d

为了更清楚我想做的事情:

我有一些向量场,想把它们通用地当作3D来处理,所以我需要形状为(:,3)和(-1,3)的调整。我希望找到一个纯面向对象的解决方案,来实现一些额外的方法,以补充NumPy自带的功能。

例如,我开始纯粹使用ndarrays来写一些代码,但如果我能直接写

normalizedVector = ndarray.view(my3DVectorClass).normalize()

而不是

normalizedVector = ndarray / ( sum(ndarray**2, axis=1)**0.5 )

那样的话,代码会更易读。

我在第二种方法上遇到的问题:

  • 我希望不必担心我请求的是形状为(3,)还是(:,3)的数组的标准化版本。
  • 我希望能够在类的方法实现中使用纯线性代数的术语,而不必在方法定义中处理索引和错误/维度检查。

我想你可以说只用我的my3DVectorClass的实例来工作,但在使用SciPy的各种工具时,我就得进行反向视图转换,因为如果我没记错的话,它们期望的是ndarray,这样会让代码的某些部分变得有些臃肿。

如果我有哪里理解错了,非常感谢你的建议。我在面向对象编程和SciPy/NumPy的学习上还在摸索阶段。

非常感谢!

Markus

2 个回答

0

reshape 会尝试用新的形状来创建数据的视图,如果不能做到这一点,它就会用新的形状创建数据的副本。但原始的数据对象不会改变。如果你想直接修改形状,可以这样做:

self.shape = (-1, 3)

举个例子:

>>> a = np.arange(9)
>>> a
array([0, 1, 2, 3, 4, 5, 6, 7, 8])
>>> np.reshape(a, (-1, 3)) # creates a view with the new shape
array([[0, 1, 2],
       [3, 4, 5],
       [6, 7, 8]])
>>> a # but the original object is unchanged
array([0, 1, 2, 3, 4, 5, 6, 7, 8])
>>> a.shape = (-1, 3) # this modifies the original object
>>> a
array([[0, 1, 2],
       [3, 4, 5],
       [6, 7, 8]])

不过你要小心,因为如果不能在不复制的情况下改变形状,它会抛出一个 AttributeError 错误:

>>> a = np.arange(36).reshape(6, 6).T
>>> b = np.reshape(a, (-1, 3)) # creates a copy of the data in a
>>> a.shape = (-1, 3) # tries to reshape in-place, and fails
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
AttributeError: incompatible shape for a non-contiguous array
7

你可以看看matrix是怎么实现的。它用了一些类似的技巧来保持ndims=2

不过,我和很多人认为这种技巧带来的麻烦远远超过它的好处。matrix类过去引发了很多问题,因为它只部分地像一个普通的ndarray。建议你考虑写一些函数来解决问题。你上面给出的代码示例,如果改成这样会更容易理解:normalizedVector = normalize(ndarray)。创建更多的子类并不总是最好的设计,即使是在使用面向对象的风格时。

撰写回答