numpy子类的数组

4 投票
2 回答
1326 浏览
提问于 2025-04-18 14:22

我遇到了一个问题。问题是:我想创建一个numpy数组的子类,然后用这种类型的对象来创建一个数组。当我引用这个数组中的某个元素时,我希望它仍然是这个子类的实例。但实际上,它却变成了numpy数组的实例。

这里有一个测试代码,但它失败了:

import numpy as np


class ImageWrapper(np.ndarray):

    def __new__(cls, image_data):
        assert image_data.ndim in (2, 3)
        return image_data.view(cls)

    @property
    def n_colours(self): 
        return 1 if self.ndim==2 else self.shape[2]


n_frames = 10
frames = [ImageWrapper(np.random.randint(255, size = (20, 15, 3)).astype('uint8')) for _ in xrange(n_frames)]
video = np.array(frames)
assert video[0].n_colours == 3

结果是:AttributeError: 'numpy.ndarray'对象没有'n_colours'这个属性。

我该怎么才能让它正常工作呢?

我已经尝试过的办法:

  • 在构建视频时设置subok=True - 这个方法只在从单个子类对象构建数组时有效,而不能用于列表。
  • 设置dtype=object或dtype=ImageWrapper也不行。

我知道我可以把视频做成一个列表,但出于其他原因,我更希望它保持为numpy数组。

2 个回答

2

numpy.array 这个函数不够高级,无法处理这种情况。subok=True 是告诉这个函数可以接受子类,但你传给它的不是 ndarray 的子类,而是一个列表(这个列表里面装的是 ndarray 子类的实例)。如果你想得到你期待的结果,可以这样做:

import numpy as np


class ImageWrapper(np.ndarray):

    def __new__(cls, image_data):
        assert 2 <= image_data.ndim <= 4
        return image_data.view(cls)

    @property
    def n_colours(self): 
        return 1 if self.ndim==2 else self.shape[-1]


n_frames = 10
frame_shape = (20, 15, 3)
video = ImageWrapper(np.empty((n_frames,) + frame_shape, dtype='uint8'))
for i in xrange(n_frames):
    video[i] = np.random.randint(255, size=(20, 15, 3))
assert video[0].n_colours == 3

注意,我需要更新 ImageWrapper,以便它可以接受四维数组作为输入。

3

无论你想要实现什么,可能都有比直接继承ndarray更好的方法。不过,如果你真的想这么做,你可以把你的数组设置为object类型,但在创建的时候要小心。下面这个方法是可以的:

>>> video = np.empty((len(frames),), dtype=object)
>>> video[:] = frames
>>> video[0].n_colours
3

但是这个方法就不行:

>>> video = np.array(frames, dtype=object)
>>> video[0].n_colours
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
AttributeError: 'numpy.ndarray' object has no attribute 'n_colours'

撰写回答