获取CartPole环境的帧作为观察值
在Python中,我正在使用stable-baselines3和gymnasium来实现一个自定义的DQN(深度Q网络)。我用Atari游戏测试了这个智能体,它运行得很好,现在我还想在像CartPole这样的环境中测试它。
问题是,这种环境返回的不是图像帧作为观察结果,而是一个简单的向量。
所以我需要找到一种方法,让CartPole返回图像帧作为观察结果,并且应用我在Atari游戏中使用的相同预处理方法(比如将4帧游戏画面叠加在一起)。
我在网上查了很多资料,经过几次尝试,我写出了这段代码,但遇到了一些问题。
这是我的代码:
from stable_baselines3.common.env_util import make_atari_env, make_vec_env
from stable_baselines3.common.vec_env import VecFrameStack, VecTransposeImage
import gymnasium as gym
from gymnasium import spaces
from gymnasium.envs.classic_control import CartPoleEnv
import numpy as np
import cv2
class CartPoleImageWrapper(gym.Wrapper):
    """Wrapper that swaps the wrapped env's observation for a rendered frame.

    Every observation handed back by this wrapper is the output of
    ``render()``, resized to 84x84, converted to grayscale and given a
    trailing channel axis, i.e. an array of shape ``(84, 84, 1)``.
    """

    metadata = {'render.modes': ['rgb_array']}

    def __init__(self, env):
        super().__init__(env)
        # Advertise the image observation instead of the wrapped env's space.
        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(84, 84, 1),
            dtype=np.uint8,
        )

    def _get_image_observation(self):
        """Render the env and return the frame as an 84x84x1 grayscale array."""
        frame = self.render()
        small_frame = cv2.resize(frame, (84, 84))
        gray_frame = cv2.cvtColor(small_frame, cv2.COLOR_RGB2GRAY)
        # Append the channel axis: (84, 84) -> (84, 84, 1).
        return gray_frame[:, :, np.newaxis]

    def reset(self):
        """Reset the wrapped env and return only the rendered frame.

        Takes no arguments besides ``self``; the wrapped env's own reset
        result is discarded.
        """
        self.env.reset()
        return self._get_image_observation()

    def step(self, action):
        """Step the wrapped env and return ``(frame, reward, terminated, info)``.

        The wrapped env's step result is unpacked into exactly four values
        and its observation is replaced by the rendered frame.
        """
        _, reward, terminated, info = self.env.step(action)
        frame = self._get_image_observation()
        return frame, reward, terminated, info
# Wrap a CartPole instance that renders RGB arrays, vectorize it as a single
# environment, then transpose the images and stack the last 4 frames.
env = CartPoleImageWrapper(CartPoleEnv(render_mode='rgb_array'))
vec_env = VecFrameStack(
    VecTransposeImage(make_vec_env(lambda: env, n_envs=1)),
    n_stack=4,
)

# Reset once and report the shape of the returned observation batch.
obs = vec_env.reset()
print(f"Observation space: {obs.shape}")
vec_env.close()
当我调用vec_env.reset()时,出现了这个错误:
Traceback (most recent call last):
File "/data/g.carfi/rl/tmp.py", line 41, in <module>
obs = vec_env.reset()
File "/data/g/.virtualenvs/rl_new/lib/python3.8/site-packages/stable_baselines3/common/vec_env/vec_frame_stack.py", line 41, in reset
observation = self.venv.reset()
File "/data/g/.virtualenvs/rl_new/lib/python3.8/site-packages/stable_baselines3/common/vec_env/vec_transpose.py", line 113, in reset
observations = self.venv.reset()
File "/data/g/.virtualenvs/rl_new/lib/python3.8/site-packages/stable_baselines3/common/vec_env/dummy_vec_env.py", line 77, in reset
obs, self.reset_infos[env_idx] = self.envs[env_idx].reset(seed=self._seeds[env_idx], **maybe_options)
File "/data/g/.virtualenvs/rl_new/lib/python3.8/site-packages/stable_baselines3/common/monitor.py", line 83, in reset
return self.env.reset(**kwargs)
TypeError: reset() got an unexpected keyword argument 'seed'
我该如何解决这个问题呢?
2 个回答
你遇到的问题是因为你在CartPoleImageWrapper中重写的reset()方法不接受任何关键字参数,而VecEnv(经由Monitor)在内部会以reset(seed=...)的形式调用它,于是抛出了TypeError。
要解决这个问题,你可以修改CartPoleImageWrapper类的reset()方法,让它接受**kwargs,并把这些参数(包括seed)原样转发给被包装环境的reset()方法。下面是你可以这样做的示例:
class CartPoleImageWrapper(gym.Wrapper):
    """Expose a Gymnasium CartPole env through 84x84 grayscale image frames.

    The wrapped env's vector observation is discarded; each observation
    returned by this wrapper is the rendered frame, resized to 84x84,
    converted to grayscale and given a trailing channel axis, matching the
    declared ``Box(0, 255, (84, 84, 1), uint8)`` observation space.

    The wrapped env is expected to be created with
    ``render_mode='rgb_array'`` so that ``render()`` returns an RGB array.
    """

    # Gymnasium looks up supported render modes under the "render_modes" key.
    metadata = {'render_modes': ['rgb_array']}

    def __init__(self, env):
        super().__init__(env)
        self.observation_space = spaces.Box(
            low=0, high=255, shape=(84, 84, 1), dtype=np.uint8
        )

    def _get_image_observation(self):
        """Render the env and return the frame as an 84x84x1 grayscale array."""
        # Render the CartPole environment
        cartpole_image = self.render()
        # Resize the image to 84x84 pixels
        resized_image = cv2.resize(cartpole_image, (84, 84))
        # Make it grayscale, then restore the channel axis dropped by cvtColor
        gray_image = cv2.cvtColor(resized_image, cv2.COLOR_RGB2GRAY)
        return np.expand_dims(gray_image, axis=-1)

    def reset(self, **kwargs):
        """Reset the wrapped env and return ``(image_observation, info)``.

        Keyword arguments (for example ``seed`` and ``options``) are
        forwarded unchanged to the wrapped env. The two-element return value
        is required because the vectorized env unpacks ``obs, info`` from
        every ``reset()`` call.
        """
        _, info = self.env.reset(**kwargs)
        return self._get_image_observation(), info

    def step(self, action):
        """Step the wrapped env, replacing its observation with a frame.

        Returns the Gymnasium five-tuple
        ``(image_observation, reward, terminated, truncated, info)``.
        """
        _, reward, terminated, truncated, info = self.env.step(action)
        return (
            self._get_image_observation(),
            reward,
            terminated,
            truncated,
            info,
        )
通过这个修改,你应该能够在使用CartPoleImageWrapper和VecFrameStack时,不再遇到与意外的seed参数相关的TypeError错误。
这个错误是因为CartPoleImageWrapper中重写的reset()方法不接受seed这个参数,但VecEnv却传递了这个参数。要解决这个问题,你需要修改你的CartPoleImageWrapper,让它正确处理这个情况。
import cv2
import gym
import numpy as np
from gym import spaces
from gym.envs.classic_control import CartPoleEnv
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecFrameStack, VecTransposeImage
class CartPoleImageWrapper(gym.Wrapper):
    """Wrapper that swaps the wrapped env's observation for a rendered frame.

    This variant uses the legacy ``gym`` calling conventions: it requests
    frames with ``render(mode='rgb_array')``, ``reset`` returns a bare
    observation and ``step`` returns a four-tuple.
    """

    metadata = {'render.modes': ['rgb_array']}

    def __init__(self, env):
        super().__init__(env)
        # Advertise the image observation instead of the wrapped env's space.
        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(84, 84, 1),
            dtype=np.uint8,
        )

    def _get_image_observation(self):
        """Render the env and return the frame as an 84x84x1 grayscale array."""
        frame = self.env.render(mode='rgb_array')
        small_frame = cv2.resize(frame, (84, 84))
        gray_frame = cv2.cvtColor(small_frame, cv2.COLOR_RGB2GRAY)
        # Append the channel axis: (84, 84) -> (84, 84, 1).
        return gray_frame[:, :, np.newaxis]

    def reset(self, **kwargs):
        """Reset the wrapped env with ``kwargs`` and return the rendered frame."""
        self.env.reset(**kwargs)
        return self._get_image_observation()

    def step(self, action):
        """Step the wrapped env and return ``(frame, reward, terminated, info)``."""
        _, reward, terminated, info = self.env.step(action)
        frame = self._get_image_observation()
        return frame, reward, terminated, info
# Wrap a default-constructed CartPole, vectorize it as a single environment,
# then transpose the images and stack the last 4 frames.
env = CartPoleImageWrapper(CartPoleEnv())
vec_env = VecFrameStack(
    VecTransposeImage(make_vec_env(lambda: env, n_envs=1)),
    n_stack=4,
)

# Reset once and report the shape of the returned observation batch.
obs = vec_env.reset()
print(f"Observation space: {obs.shape}")
vec_env.close()
这是我做的修改:
把_get_image_observation()方法改成使用render()方法,并设置模式为'rgb_array'。
把reset()方法改成可以接受**kwargs,并把这些参数传递给内部环境的reset()方法(也就是CartPoleEnv)。
去掉了不必要的gym导入。
做了这些修改后,你的包装器(wrapper)应该能和向量化环境正常工作了。