如何封装numpy数组类型?
我想创建一个类,继承自numpy数组的基本类型,
class LemmaMatrix(numpy.ndarray):
@classmethod
def init_from_corpus(cls, ...): cls(numpy.empty(...))
但显然,它不支持多维数组类型。有没有什么办法可以解决这个问题?提前谢谢!
ndarray(empty([3, 3]))
TypeError: only length-1 arrays can be converted to Python scalars
1 个回答
5
import numpy as np
class LemmaMatrix(np.ndarray):
def __new__(subtype,data,dtype=None):
subarr=np.empty(data,dtype=dtype)
return subarr
lm=LemmaMatrix([3,3])
print(lm)
# [[ 3.15913337e-260 4.94951870e+173 4.88364603e-309]
# [ 1.63321355e-301 4.80218258e-309 2.05227026e-287]
# [ 2.10277051e-309 2.07088188e+289 7.29366696e-304]]
你可能还想看看这个指南,里面有关于如何创建ndarray
子类的更多信息。