Numpy ndarray子类 - 在__array_finalize__中强制重塑
我遇到了一些问题:
我想写一个ndarray的子类,并且希望这个子类的任何新实例都能强制设置形状为(-1,3),无论是通过显式构造函数、视图转换还是模板生成。
我尝试了很多方法,但似乎都没有成功。 我觉得我可能还没有完全理解这个过程。任何帮助都非常感谢!
import numpy as np
class test(np.ndarray):
def __new__(cls, *args, **kwargs):
return np.ndarray.__new__(cls, *args, **kwargs)
def __array_finalize__(self, obj):
# self.resize(-1,3)
# self.reshape(-1,3)
# self=self.reshape(-1,3)
np.reshape(self,(-1,3))
a=np.array([1,2,3])
b=a.view(test)
c=test(a)
d=a.reshape(-1,3)
print '+++++++'
print a.shape,a
print '+++++++'
print b.shape,b
print '+++++++'
print c.shape,c
print '+++++++'
print d.shape,d
为了更清楚我想做的事情:
我有一些向量场,想把它们通用地当作3D来处理,所以我需要形状为(:,3)和(-1,3)的调整。我希望找到一个纯面向对象的解决方案,来实现一些额外的方法,以补充NumPy自带的功能。
例如,我开始纯粹使用ndarrays来写一些代码,但如果我能直接写
normalizedVector = ndarray.view(my3DVectorClass).normalize()
而不是
normalizedVector = ndarray / ( sum(ndarray**2, axis=1)**0.5 )
那样的话,代码会更易读。
我在第二种方法上遇到的问题:
- 我希望不必担心我请求的是形状为(3,)还是(:,3)的数组的标准化版本。
- 我希望能够在类的方法实现中使用纯线性代数的术语,而不必在方法定义中处理索引和错误/维度检查。
我想你可以说只用我的my3DVectorClass的实例来工作,但在使用SciPy的各种工具时,我就得进行反向视图转换,因为如果我没记错的话,它们期望的是ndarray,这样会让代码的某些部分变得有些臃肿。
如果我有哪里理解错了,非常感谢你的建议。我在面向对象编程和SciPy/NumPy的学习上还在摸索阶段。
非常感谢!
Markus
2 个回答
reshape
会尝试用新的形状来创建数据的视图,如果不能做到这一点,它就会用新的形状创建数据的副本。但原始的数据对象不会改变。如果你想直接修改形状,可以这样做:
self.shape = (-1, 3)
举个例子:
>>> a = np.arange(9)
>>> a
array([0, 1, 2, 3, 4, 5, 6, 7, 8])
>>> np.reshape(a, (-1, 3)) # creates a view with the new shape
array([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
>>> a # but the original object is unchanged
array([0, 1, 2, 3, 4, 5, 6, 7, 8])
>>> a.shape = (-1, 3) # this modifies the original object
>>> a
array([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
不过你要小心,因为如果不能在不复制的情况下改变形状,它会抛出一个 AttributeError
错误:
>>> a = np.arange(36).reshape(6, 6).T
>>> b = np.reshape(a, (-1, 3)) # creates a copy of the data in a
>>> a.shape = (-1, 3) # tries to reshape in-place, and fails
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
AttributeError: incompatible shape for a non-contiguous array
你可以看看matrix
类是怎么实现的。它用了一些类似的技巧来保持ndims=2
。
不过,我和很多人认为这种技巧带来的麻烦远远超过它的好处。matrix
类过去引发了很多问题,因为它只部分地像一个普通的ndarray
。建议你考虑写一些函数来解决问题。你上面给出的代码示例,如果改成这样会更容易理解:normalizedVector = normalize(ndarray)
。创建更多的子类并不总是最好的设计,即使是在使用面向对象的风格时。