ValueError:“Concatenate”层需要具有匹配形状的输入

2024-04-25 06:14:49 发布

您现在位置:Python中文网/ 问答频道 /正文

    from fr_utils import *
    from inception_blocks_v2 import *

    def triplet_loss(y_true, y_pred, alpha=0.3):
        """
        Implementation of the triplet loss as defined by formula (3)

        Arguments:
        y_pred -- python list containing three objects:
                anchor -- the encodings for the anchor images, of shape (None, 128)
                positive -- the encodings for the positive images, of shape (None, 128)
                negative -- the encodings for the negative images, of shape (None, 128)

        Returns:
        loss -- real number, value of the loss
        """

        anchor, positive, negative = y_pred[0], y_pred[1], y_pred[2]

        # Step 1: Compute the (encoding) distance between the anchor and the positive, you will need to sum over axis=-1
        pos_dist = tf.reduce_sum(tf.square(tf.subtract(anchor, positive)), axis=-1)
        # Step 2: Compute the (encoding) distance between the anchor and the negative, you will need to sum over axis=-1
        neg_dist = tf.reduce_sum(tf.square(tf.subtract(anchor, negative)), axis=-1)
        # Step 3: subtract the two previous distances and add alpha.
        basic_loss = tf.add(tf.subtract(pos_dist, neg_dist), alpha)
        # Step 4: Take the maximum of basic_loss and 0.0. Sum over the training examples.
        loss = tf.reduce_sum(tf.maximum(basic_loss, 0.0))

        return loss

    def main():
        FRmodel = faceRecoModel(input_shape=(3, 96, 96))
        FRmodel.compile(optimizer='adam', loss=triplet_loss, metrics=['accuracy'])
        FRmodel.save('face-rec_Google.h5')
        print_summary(model)

    main()

此代码中显示的错误如下

已获取输入形状:%s%(input\u shape))

值错误:Concatenate层需要具有匹配形状的输入,但concat轴除外。获取输入形状:[(None,128,12,192),(None,32,12,192),(None,32,12,102),(None,64,12,192)]

我试着在网上查找错误,但没有找到解决办法


Tags: andofthenonedisttfstepsum
1条回答
网友
1楼 · 发布于 2024-04-25 06:14:49

您只需要更改图像的表示。我想,你们的图像是用三维数组表示的,颜色通道在最后一个维度。此代码将把颜色通道移到第一维[channels][rows][cols]

from keras import backend as K
K.set_image_data_format('channels_first')

相关问题 更多 >