TF-Agents replay buffer: shape mismatch when adding a trajectory with add_batch



I'm reposting a question that another user posted and then deleted. I ran into the same problem, and I found the answer. Original question:

I'm currently trying to implement a categorical DQN following this tutorial: https://www.tensorflow.org/agents/tutorials/9_c51_tutorial

The following part is giving me a bit of a headache:

random_policy = random_tf_policy.RandomTFPolicy(env.time_step_spec(),
                                                env.action_spec())

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=1,
    max_length=replay_buffer_capacity)  # this is 100

# ...

def collect_step(environment, policy):
  time_step = environment.current_time_step()
  action_step = policy.action(time_step)
  next_time_step = environment.step(action_step.action)
  traj = trajectory.from_transition(time_step, action_step, next_time_step)
  print(traj)

  # Add trajectory to the replay buffer
  replay_buffer.add_batch(traj)

for _ in range(initial_collect_steps):
  collect_step(env, random_policy)

For context, agent.collect_data_spec has the following shapes:

Trajectory(step_type=TensorSpec(shape=(), dtype=tf.int32, name='step_type'), observation=BoundedTensorSpec(shape=(4, 84, 84), dtype=tf.float32, name='screen', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32)), action=BoundedTensorSpec(shape=(), dtype=tf.int32, name='play', minimum=array(0), maximum=array(6)), policy_info=(), next_step_type=TensorSpec(shape=(), dtype=tf.int32, name='step_type'), reward=TensorSpec(shape=(), dtype=tf.float32, name='reward'), discount=BoundedTensorSpec(shape=(), dtype=tf.float32, name='discount', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32)))

And here is what a sample traj looks like:

Trajectory(step_type=<tf.Tensor: shape=(), dtype=int32, numpy=0>, observation=<tf.Tensor: shape=(4, 84, 84), dtype=float32, numpy=array([tensor contents omitted], dtype=float32)>, action=<tf.Tensor: shape=(), dtype=int32, numpy=1>, policy_info=(), next_step_type=<tf.Tensor: shape=(), dtype=int32, numpy=1>, reward=<tf.Tensor: shape=(), dtype=float32, numpy=0.0>, discount=<tf.Tensor: shape=(), dtype=float32, numpy=1.0>)

So everything should check out, right? The environment outputs a tensor of shape [4, 84, 84], which is exactly what the replay buffer expects. But I get the following error:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Must have updates.shape = indices.shape + params.shape[1:] or updates.shape = [], got updates.shape [4,84,84], indices.shape [1], params.shape [100,4,84,84] [Op:ResourceScatterUpdate]

This suggests that it actually expects a tensor of shape [1, 4, 84, 84]. The problem is that if I make my environment output a tensor of that shape, I get a different error telling me that the output doesn't match the spec shape (duh). And if I then change the spec shape to [1, 4, 84, 84], the replay buffer suddenly wants a shape of [1, 1, 4, 84, 84], and so on.
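
To make the three shapes in the error message concrete, here is a rough sketch of what I believe they correspond to (the values are dummies, only the shapes matter):

import tensorflow as tf

params  = tf.zeros((100, 4, 84, 84))  # the buffer's storage table, [100, 4, 84, 84] per the error
indices = tf.constant([0])            # one write slot per batch entry -> shape [1]
updates = tf.zeros((4, 84, 84))       # my single, unbatched observation -> shape [4, 84, 84]
# ResourceScatterUpdate requires updates.shape == indices.shape + params.shape[1:],
# i.e. [1, 4, 84, 84], which is why the unbatched [4, 84, 84] observation is rejected.
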

Finally, for completeness, here are my environment's time_step_spec and action_spec, respectively:

TimeStep(step_type=TensorSpec(shape=(), dtype=tf.int32, name='step_type'), reward=TensorSpec(shape=(), dtype=tf.float32, name='reward'), discount=BoundedTensorSpec(shape=(), dtype=tf.float32, name='discount', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32)), observation=BoundedTensorSpec(shape=(4, 84, 84), dtype=tf.float32, name='screen', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32)))
---
BoundedTensorSpec(shape=(), dtype=tf.int32, name='play', minimum=array(0), maximum=array(6))

I've spent about half of today trying to get the tensors into the right shape, but the trajectory can't simply be modified because it's an attribute, so as a last-ditch effort I'm hoping some kind stranger can tell me what is actually going on here.

Thanks in advance.


1 Answer

It looks like traj in the collect_step function is a single trajectory rather than a batch, so you need to expand its dimensions into a batch of one before adding it. Note that you can't just call tf.expand_dims(traj, 0), because traj is a nested structure; there is a helper function for applying this to every element of the nest:

def collect_step(environment, policy):
    time_step = environment.current_time_step()
    action_step = policy.action(time_step)
    next_time_step = environment.step(action_step.action)
    traj = trajectory.from_transition(time_step, action_step, next_time_step)
    # Expand every tensor in the nest so the trajectory gets a leading batch dimension of 1
    batch = tf.nest.map_structure(lambda t: tf.expand_dims(t, 0), traj)
    # Add trajectory to the replay buffer
    replay_buffer.add_batch(batch)
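
If you'd rather not expand each field by hand, TF-Agents also provides a nest utility that does the same thing; the sketch below assumes nest_utils.batch_nested_tensors is available in your installed version of tf_agents:

from tf_agents.utils import nest_utils

def collect_step(environment, policy):
    time_step = environment.current_time_step()
    action_step = policy.action(time_step)
    next_time_step = environment.step(action_step.action)
    traj = trajectory.from_transition(time_step, action_step, next_time_step)
    # Adds a leading batch dimension to every tensor in the nested trajectory
    batch = nest_utils.batch_nested_tensors(traj)
    replay_buffer.add_batch(batch)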
