keras-rl Processor class changes the shape

Posted 2024-04-27 10:32:24


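OK, I'm trying to feed a keras-rl model a list of 10 integers as input. However, since I'm using a new OpenAI Gym environment, I need to set up the Processor class to fit my needs. My Processor class looks like this: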

import numpy as np
from rl.core import Processor

class RecoProcessor(Processor):
    def process_observation(self, observation):
        print("Observation:")
        look_back = 10
        if observation is None:
            # No observation yet: fall back to an all-zero vector.
            X = np.zeros(look_back, dtype='float32')
        else:
            X = np.array(observation, dtype='float32')
        # Padding X with zeros up to look_back entries is not implemented yet.
        print(X.shape)
        return X

    def process_state_batch(self, batch):
        print("Batch:")
        print(batch.shape)
        return batch

    def process_reward(self, reward):
        return reward

    def process_demo_data(self, demo_data):
        for step in demo_data:
            step[0] = self.process_observation(step[0])
            step[2] = self.process_reward(step[2])
        return demo_data
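
The commented-out line in process_observation suggests the observation was meant to be padded to look_back entries. A minimal sketch of that idea, using plain NumPy; the helper name pad_observation is hypothetical and not part of keras-rl:

```python
import numpy as np

def pad_observation(observation, look_back=10):
    """Return a float32 vector with exactly `look_back` entries.

    Hypothetical helper (not part of keras-rl): truncates long
    observations and right-pads short ones with zeros.
    """
    if observation is None:
        return np.zeros(look_back, dtype="float32")
    x = np.asarray(observation, dtype="float32").ravel()[:look_back]
    # Right-pad with zeros when fewer than look_back values are present.
    return np.pad(x, (0, look_back - len(x)), mode="constant")
```

process_observation could return pad_observation(observation, look_back) so that every observation has shape (10,), whatever the environment emits.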

My agent and model look like this:

[code block missing from the source page]

But when I try to run this, I get the following output:

Training for 50000 steps ...
CCCCCCCCCCCC
(10,)
Interval 1 (0 steps performed)
AAAAAAAAAAAAAAA
(1, 1, 10)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-29-4d8fdf0e849e> in <module>
     32 dqn.compile(Adam(lr), metrics=['mae'])
     33 
---> 34 train = dqn.fit(env, nb_steps=50000, visualize=False, verbose=1, nb_max_episode_steps = None)
     35 np.savetxt(fichero_train_history, 
     36            np.array(train.history["episode_reward"]), delimiter=",")

c:\users\eloy.anguiano\src\keras-rl\rl\core.py in fit(self, env, nb_steps, action_repetition, callbacks, verbose, visualize, nb_max_start_steps, start_step_policy, log_interval, nb_max_episode_steps)
    167                 # This is were all of the work happens. We first perceive and compute the action
    168                 # (forward step) and then use the reward to improve (backward step).
--> 169                 action = self.forward(observation)
    170                 if self.processor is not None:
    171                     action = self.processor.process_action(action)

c:\users\eloy.anguiano\src\keras-rl\rl\agents\dqn.py in forward(self, observation)
     87         # Select an action.
     88         state = self.memory.get_recent_state(observation)
---> 89         q_values = self.compute_q_values(state)
     90         if self.training:
     91             action = self.policy.select_action(q_values=q_values)

c:\users\eloy.anguiano\src\keras-rl\rl\agents\dqn.py in compute_q_values(self, state)
     67 
     68     def compute_q_values(self, state):
---> 69         q_values = self.compute_batch_q_values([state]).flatten()
     70         assert q_values.shape == (self.nb_actions,)
     71         return q_values

c:\users\eloy.anguiano\src\keras-rl\rl\agents\dqn.py in compute_batch_q_values(self, state_batch)
     62     def compute_batch_q_values(self, state_batch):
     63         batch = self.process_state_batch(state_batch)
---> 64         q_values = self.model.predict_on_batch(batch)
     65         assert q_values.shape == (len(state_batch), self.nb_actions)
     66         return q_values

~\AppData\Local\Continuum\anaconda3\lib\site-packages\keras-2.2.4-py3.7.egg\keras\engine\training.py in predict_on_batch(self, x)
   1266             Numpy array(s) of predictions.
   1267         
-> 1268         x, _, _ = self._standardize_user_data(x)
   1269         if self._uses_dynamic_learning_phase():
   1270             ins = x + [0.]

~\AppData\Local\Continuum\anaconda3\lib\site-packages\keras-2.2.4-py3.7.egg\keras\engine\training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, check_array_lengths, batch_size)
    749             feed_input_shapes,
    750             check_batch_axis=False,  # Don't enforce the batch size.
--> 751             exception_prefix='input')
    752 
    753         if y is not None:

~\AppData\Local\Continuum\anaconda3\lib\site-packages\keras-2.2.4-py3.7.egg\keras\engine\training_utils.py in standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
    126                         ': expected ' + names[i] + ' to have ' +
    127                         str(len(shape)) + ' dimensions, but got array '
--> 128                         'with shape ' + str(data_shape))
    129                 if not check_batch_axis:
    130                     data_shape = data_shape[1:]

ValueError: Error when checking input: expected embedding_12_input to have 2 dimensions, but got array with shape (1, 1, 10)

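As you can see, what reaches the model is the shape of the batch, (1, 1, 10), rather than the two-dimensional input it expects, and I don't know how to fix it. If you want to experiment, the environment I'm using is RecoGym (version 1).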
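
One possible direction, offered as a sketch rather than a confirmed fix: keras-rl's DQN agent stacks the most recent window_length observations, so process_state_batch receives an array of shape (batch_size, window_length, ...). Assuming window_length is 1 (which matches the (1, 1, 10) shape above, though the agent code is not shown), the extra axis can be dropped before the batch reaches the model. The helper name squeeze_window_axis is hypothetical:

```python
import numpy as np

def squeeze_window_axis(batch):
    """Turn (batch_size, 1, n_features) into (batch_size, n_features).

    Hypothetical helper: assumes the second axis is keras-rl's
    window axis and that window_length == 1. Other shapes are
    returned unchanged.
    """
    batch = np.asarray(batch)
    if batch.ndim == 3 and batch.shape[1] == 1:
        return np.squeeze(batch, axis=1)
    return batch

state_batch = np.zeros((1, 1, 10), dtype="float32")
print(squeeze_window_axis(state_batch).shape)  # (1, 10)
```

Inside RecoProcessor, process_state_batch could return squeeze_window_axis(batch) instead of batch, so a model whose first layer expects two-dimensional input gets (batch_size, 10). With a larger window_length the window axis would need to be handled in the model itself.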

