如何在Keras中通过多输入批量训练

2024-04-27 04:28:44 发布

您现在位置:Python中文网/ 问答频道 /正文

 ValueError: could not broadcast input array from shape (60,60,2) into shape (1)

我试过在我的代码中用某种方式修改,但仍然有相同的错误。你知道吗

  1. 你知道吗状态.append(np.数组(s) )#标记1 目标\u f_列表.append(np.数组(目标#f))#标记2
  2. 你知道吗self.model.train\u on\u批量([状态],[目标列表])#标记3
  3. 你知道吗self.model.train\u on\u批量(np.数组(州),np.数组(目标#列表)#标记3

这是我的网络Keras:

    input_1 = Input(shape=(60, 60, 2))
    input_2 = Input(shape=(self.action_size, self.action_size))
    x1 = Conv2D(32, (4, 4), strides=(2, 2), padding='Same', activation=LeakyReLU(alpha=self.Beta))(input_1)
    x1 = Conv2D(64, (2, 2), strides=(2, 2), padding='Same', activation=LeakyReLU(alpha=self.Beta))(x1)
    x1 = Conv2D(128, (2, 2), strides=(1, 1), padding='Same', activation=LeakyReLU(alpha=self.Beta))(x1)
    x1 = Flatten()(x1)
    x1 = Dense(128, activation=LeakyReLU(alpha=self.Beta))(x1)
    x1_value = Dense(64, activation=LeakyReLU(alpha=self.Beta))(x1)
    value = Dense(1, activation=LeakyReLU(alpha=self.Beta))(x1_value)
    x1_advantage = Dense(64, activation=LeakyReLU(alpha=self.Beta))(x1)
    advantage = Dense(self.action_size, activation=LeakyReLU(alpha=self.Beta))(x1_advantage)

    A = Dot(axes=1)([input_2, advantage])
    A_subtract = Subtract()([advantage, A])

    Q_value = Add()([value, A_subtract])

    model = Model(inputs=[input_1, input_2], outputs=[Q_value])
    model.compile(optimizer=Adam(lr=self.epsilon_r), loss='mse')

我的职责是训练:

    state = []
    target_f_list = []
    for s, a, r, next_s, done in minibatch:
        if not done:

            ... do calculate target_f ...

            state.append(s)                   # mark 1
            target_f_list.append(target_f)    # mark 2

            # this is fit function i use before and it's worked fine. But i want to train all minibatch add the same time.
            # self.model.fit(s, target_f, epochs=1, verbose=0, batch_size=self.minibatch_size)

    # This is my code has error
    self.model.train_on_batch(state,target_f_list)  # mark 3

谢谢你阅读我的问题。你知道吗


Tags: selfalphatargetinputsizemodelvalueactivation