Getting CartPole environment frames as observations

Asked 2025-04-14 15:38

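In Python, I am using Stable-Baselines3 and Gymnasium to implement a custom DQN (Deep Q-Network). I tested the agent on Atari games and it works well; now I would also like to test it on an environment like CartPole.

The problem is that this kind of environment returns a plain state vector as its observation rather than image frames.

So I need a way to make CartPole return image frames as observations and to apply the same preprocessing I use for Atari games (such as stacking 4 frames together).

I searched online a lot and, after a few attempts, wrote the code below, but I am running into a problem.

Here is my code: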

from stable_baselines3.common.env_util import make_atari_env, make_vec_env
from stable_baselines3.common.vec_env import VecFrameStack, VecTransposeImage
import gymnasium as gym
from gymnasium import spaces
from gymnasium.envs.classic_control import CartPoleEnv
import numpy as np
import cv2


class CartPoleImageWrapper(gym.Wrapper):
    metadata = {'render.modes': ['rgb_array']}

    def __init__(self, env):
        super(CartPoleImageWrapper, self).__init__(env)
        self.observation_space = spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)

    def _get_image_observation(self):
        # Render the CartPole environment
        cartpole_image = self.render()

        # Resize the image to 84x84 pixels
        resized_image = cv2.resize(cartpole_image, (84, 84))
        # make it grayscale
        resized_image = cv2.cvtColor(resized_image, cv2.COLOR_RGB2GRAY)
        resized_image = np.expand_dims(resized_image, axis=-1)
        return resized_image

    def reset(self):
        self.env.reset()
        return self._get_image_observation()

    def step(self, action):
        observation, reward, terminated, info = self.env.step(action)
        return self._get_image_observation(), reward, terminated, info


env = CartPoleImageWrapper(CartPoleEnv(render_mode='rgb_array'))
vec_env = make_vec_env(lambda: env, n_envs=1)
vec_env = VecTransposeImage(vec_env)
vec_env = VecFrameStack(vec_env, n_stack=4)
obs = vec_env.reset()
print(f"Observation space: {obs.shape}")
#exit()
    
vec_env.close()

When I call vec_env.reset(), I get this error:

Traceback (most recent call last):
    File "/data/g.carfi/rl/tmp.py", line 41, in <module>
        obs = vec_env.reset()
    File "/data/g/.virtualenvs/rl_new/lib/python3.8/site-packages/stable_baselines3/common/vec_env/vec_frame_stack.py", line 41, in reset
        observation = self.venv.reset()
    File "/data/g/.virtualenvs/rl_new/lib/python3.8/site-packages/stable_baselines3/common/vec_env/vec_transpose.py", line 113, in reset
        observations = self.venv.reset()
    File "/data/g/.virtualenvs/rl_new/lib/python3.8/site-packages/stable_baselines3/common/vec_env/dummy_vec_env.py", line 77, in reset
        obs, self.reset_infos[env_idx] = self.envs[env_idx].reset(seed=self._seeds[env_idx], **maybe_options)
    File "/data/g/.virtualenvs/rl_new/lib/python3.8/site-packages/stable_baselines3/common/monitor.py", line 83, in reset
        return self.env.reset(**kwargs)
    TypeError: reset() got an unexpected keyword argument 'seed'

How can I fix this?

2 Answers

0

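The error occurs because the reset() method you override in CartPoleImageWrapper does not accept a seed argument, while the VecEnv passes one internally. (CartPoleEnv.reset() itself does accept seed in Gymnasium; it is the wrapper's override that rejects it.)

To fix this, change reset() in your CartPoleImageWrapper so that it accepts keyword arguments and forwards them to the wrapped environment's reset(). Here is an example of how you could do that: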

class CartPoleImageWrapper(gym.Wrapper):
    # metadata is inherited from the wrapped environment

    def __init__(self, env):
        super(CartPoleImageWrapper, self).__init__(env)
        self.observation_space = spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)

    def _get_image_observation(self):
        # Render the CartPole environment (it must be created with render_mode='rgb_array')
        cartpole_image = self.render()

        # Resize the image to 84x84 pixels
        resized_image = cv2.resize(cartpole_image, (84, 84))
        # Convert to grayscale and add a trailing channel axis: (84, 84, 1)
        resized_image = cv2.cvtColor(resized_image, cv2.COLOR_RGB2GRAY)
        resized_image = np.expand_dims(resized_image, axis=-1)
        return resized_image

    def reset(self, **kwargs):
        # Forward 'seed' and 'options'; Gymnasium's reset() returns (observation, info)
        _, info = self.env.reset(**kwargs)
        return self._get_image_observation(), info

    def step(self, action):
        # Gymnasium's step() returns five values, including 'truncated'
        _, reward, terminated, truncated, info = self.env.step(action)
        return self._get_image_observation(), reward, terminated, truncated, info

With this change, you should no longer get the TypeError about the unexpected seed argument when using CartPoleImageWrapper with VecFrameStack.
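The traceback shows two things about the call that DummyVecEnv makes: it passes seed as a keyword argument, and it unpacks the result of reset() into two values (the observation and an info dict). The following is a minimal sketch of that contract using plain Python stand-ins; InnerEnv, BrokenWrapper, FixedWrapper and call_like_vec_env are hypothetical names for illustration and are not part of Gymnasium or Stable-Baselines3.

```python
class InnerEnv:
    """Stand-in for a Gymnasium env: reset() takes seed/options, returns (obs, info)."""

    def reset(self, *, seed=None, options=None):
        return [0.0, 0.0, 0.0, 0.0], {"seed": seed}


class BrokenWrapper:
    """Mirrors the question's wrapper: reset() accepts no keyword arguments."""

    def __init__(self, env):
        self.env = env

    def reset(self):
        self.env.reset()
        return "frame"


class FixedWrapper:
    """Accepts seed/options, forwards them, and returns (obs, info)."""

    def __init__(self, env):
        self.env = env

    def reset(self, *, seed=None, options=None):
        _, info = self.env.reset(seed=seed, options=options)
        return "frame", info


def call_like_vec_env(env):
    # Same shape as the call in the traceback: keyword seed, two-value unpacking
    obs, info = env.reset(seed=0)
    return obs, info


error_message = None
try:
    call_like_vec_env(BrokenWrapper(InnerEnv()))
except TypeError as exc:
    error_message = str(exc)

obs, info = call_like_vec_env(FixedWrapper(InnerEnv()))
print(error_message)
print(obs, info)
```

The broken variant fails before the wrapped environment is ever reached, which is why the message names reset() and not CartPoleEnv.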

0

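This error occurs because the reset() method overridden in your CartPoleImageWrapper does not accept a seed argument, but the VecEnv passes one. To fix it, you need to modify your CartPoleImageWrapper so that it handles this correctly.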

import cv2
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from gymnasium.envs.classic_control import CartPoleEnv
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecFrameStack, VecTransposeImage


class CartPoleImageWrapper(gym.Wrapper):
    # metadata is inherited from the wrapped environment

    def __init__(self, env):
        super(CartPoleImageWrapper, self).__init__(env)
        self.observation_space = spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)

    def _get_image_observation(self):
        # Render the CartPole environment; the render mode is fixed at construction time
        cartpole_image = self.env.render()

        # Resize the image to 84x84 pixels
        resized_image = cv2.resize(cartpole_image, (84, 84))
        # Convert to grayscale and add a trailing channel axis: (84, 84, 1)
        resized_image = cv2.cvtColor(resized_image, cv2.COLOR_RGB2GRAY)
        resized_image = np.expand_dims(resized_image, axis=-1)
        return resized_image

    def reset(self, **kwargs):
        # Forward 'seed' and 'options'; Gymnasium's reset() returns (observation, info)
        _, info = self.env.reset(**kwargs)
        return self._get_image_observation(), info

    def step(self, action):
        # Gymnasium's step() returns five values, including 'truncated'
        _, reward, terminated, truncated, info = self.env.step(action)
        return self._get_image_observation(), reward, terminated, truncated, info


env = CartPoleImageWrapper(CartPoleEnv(render_mode='rgb_array'))
vec_env = make_vec_env(lambda: env, n_envs=1)
vec_env = VecTransposeImage(vec_env)
vec_env = VecFrameStack(vec_env, n_stack=4)
obs = vec_env.reset()
print(f"Observation shape: {obs.shape}")

vec_env.close()

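Here are the changes I made:

  1. _get_image_observation() now gets the frame from render() in 'rgb_array' mode.

  2. reset() now accepts **kwargs and forwards them to the reset() method of the inner environment (that is, CartPoleEnv).

  3. Removed the unused make_atari_env import.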

After these changes, your wrapper should work correctly with the vectorized environment.
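To see what shape to expect from vec_env.reset() once the wrapper works, the shape arithmetic of the pipeline can be sketched in plain NumPy. This is only an illustration under assumptions, not Stable-Baselines3's implementation: to_channels_first and push_frame are hypothetical helpers, and the zero-filled initial stack reflects my understanding of how VecFrameStack behaves right after a reset.

```python
import numpy as np

N_STACK = 4  # same value as n_stack in the question


def to_channels_first(batch_hwc):
    """(n_envs, H, W, C) -> (n_envs, C, H, W), the layout change a transpose wrapper applies."""
    return np.transpose(batch_hwc, (0, 3, 1, 2))


def push_frame(stack, frame_chw):
    """Shift the stack toward the front along the channel axis and write the newest frame last."""
    n_channels = frame_chw.shape[1]
    stack = np.roll(stack, shift=-n_channels, axis=1)
    stack[:, -n_channels:] = frame_chw
    return stack


# One 84x84 grayscale frame from a single environment, filled with a marker value
frame = np.full((1, 84, 84, 1), 7, dtype=np.uint8)
frame_chw = to_channels_first(frame)  # shape (1, 1, 84, 84)

# Empty stack after a reset, then the first frame is pushed in
stack = np.zeros((1, N_STACK, 84, 84), dtype=np.uint8)
stack = push_frame(stack, frame_chw)

print(stack.shape)  # (1, 4, 84, 84)
```

With n_envs=1 and n_stack=4, the stacked observation in this sketch has shape (1, 4, 84, 84), with the newest frame in the last channel slot.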
