使用“展平”或“重塑”在keras中获得未知输入形状的1D输出

2024-04-25 09:52:54 发布

您现在位置:Python中文网/ 问答频道 /正文

我想在模型末尾使用keras层Flatten()或{}来输出一个1D向量,比如[0,0,1,0,0, ... ,0,0,1,0]。在

不幸的是,有一个问题,因为我的未知输入形状是:
input_shape=(4, None, 1)))。在

因此,通常输入形状介于[batch_size, 4, 64, 1][batch_size, 4, 256, 1]之间,输出应该是批处理大小x未知维度(对于上面的第一个示例:[batch_size, 64]和第二个示例[batch_size, 256])。在

我的模型看起来像:

model = Sequential()
model.add(Convolution2D(32, (4, 32), padding='same', input_shape=(4, None, 1)))
model.add(BatchNormalization())
model.add(LeakyReLU())
model.add(Convolution2D(1, (1, 2), strides=(4, 1), padding='same'))
model.add(Activation('sigmoid'))
# model.add(Reshape((-1,))) produces the error
# int() argument must be a string, a bytes-like object or a number, not 'NoneType' 
model.compile(loss='binary_crossentropy', optimizer='adadelta')

所以我当前的输出形状是[batchsize,1,未知维度,1]。 这不允许我使用类权重,例如"ValueError: class_weight not supported for 3+ dimensional targets."。在

当我使用灵活的输入形状时,是否可以使用Flatten()或{}来压缩我在keras(2.0.4和tensorflow后端)中的三维输出?在

非常感谢!在


Tags: 模型noneadd示例inputsizemodelbatch
1条回答
网友
1楼 · 发布于 2024-04-25 09:52:54

您可以尝试K.batch_flatten()包装在Lambda层中。 K.batch_flatten()的输出形状在运行时动态确定。在

model.add(Lambda(lambda x: K.batch_flatten(x)))
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_5 (Conv2D)            (None, 4, None, 32)       4128      
_________________________________________________________________
batch_normalization_3 (Batch (None, 4, None, 32)       128       
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 4, None, 32)       0         
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 1, None, 1)        65        
_________________________________________________________________
activation_3 (Activation)    (None, 1, None, 1)        0         
_________________________________________________________________
lambda_5 (Lambda)            (None, None)              0         
=================================================================
Total params: 4,321
Trainable params: 4,257
Non-trainable params: 64
_________________________________________________________________


X = np.random.rand(32, 4, 256, 1)
print(model.predict(X).shape)
(32, 256)

X = np.random.rand(32, 4, 64, 1)
print(model.predict(X).shape)
(32, 64)

相关问题 更多 >

    热门问题