煤油backend.repeat_元素不工作?

2024-04-20 11:39:53 发布

您现在位置:Python中文网/ 问答频道 /正文

我试着在这里创建一个简单的矩阵,在我的批处理中重复每个样本。在

矩阵如下:

balanceMatrix = np.array([[[5,10,10],[1,1,1],[1,1,1]]])
print(balanceMatrix.shape)

balanceMatrix = K.constant(balanceMatrix)
print(K.shape(balanceMatrix).eval())

到目前为止,很好,我得到了预期的矩阵形状(1,3,3)。 现在我想对每批样品重复一次(比如60000个样品)。从kerasdocumentation开始,我要做的就是:

^{pr2}$

但这引发了以下错误,我无法简单理解:

IndexError                                Traceback (most recent call last)
<ipython-input-28-4356baf13de8> in <module>()
     20 balanceMatrix = K.constant(balanceMatrix)
     21 print(K.shape(balanceMatrix).eval())
---> 22 balanceMatrix = K.repeat_elements(balanceMatrix, 60000,axis=0)
     23 print(K.shape(balanceMatrix).eval())
     24 

c:\users\ut65\appdata\local\programs\python\python35\lib\site-packages\keras\backend\theano_backend.py in repeat_elements(x, rep, axis)
    743     if hasattr(x, '_keras_shape'):
    744         y._keras_shape = list(x._keras_shape)
--> 745         repeat_dim = x._keras_shape[axis]
    746         if repeat_dim is not None:
    747                 y._keras_shape[axis] = repeat_dim * rep

IndexError: tuple index out of range

怎么回事?? 我知道,我可以先用np.repeat(balanceMatrix,60000,axis=0)创建keras张量,然后创建keras张量,但是keras选项不也可以吗?在


Tags: inevalnp样品矩阵elementskerasrepeat