推荐的使用Gymnasium与神经网络的方法以避免model.fit和model.predict的开销
我正在尝试使用Sutton和Barto第10章(第二版)中的“情节半梯度Sarsa”来解决CartPole问题,并且我使用Keras来进行函数逼近。不过,按照算法的要求实现代码让我不得不在fit和predict时使用批量大小为1,这导致代码运行得非常慢。一个替代方案是先运行代码从Gymnasium收集数据,然后再用这些数据离线训练神经网络。这样做可以吗?如果我理解没错的话,这样还是在策略上?或者有没有其他标准的方法可以在不影响性能的情况下将神经网络与Gymnasium结合使用?
这是我当前尝试的概述 -
import gymnasium as gym
from numpy.random import choice as random_choice
from numpy import array, argmax
我把算法写成了以下的Python代码:
env = gym.make('CartPole-v1')
for ep_idx in range(num_episodes):
terminated = False
state, _ = env.reset()
action = env.action_space.sample()
while not terminated:
action_ = policy.take_action(state, qvalue, ep_idx)
state_, reward, terminated, _, _ = env.step(action_)
if terminated:
qvalue.update(state, action, reward, None, None)
else:
qvalue.update(state, action, reward, state_, action_)
state, action = state_, action_
在函数逼近方面,我决定使用Keras。这是在qvalue.update
中实现的,如下所示:
class QValueFunction:
def __init__(self, discount, learning_rate, num_actions, *state_vector_dim):
# not shown here for brevity
def __call__(self, state, action=None):
# not shown here for brevity
def update(self, s, a, r, s_, a_):
model = self._model # instance of keras.models.Model
gamma = self._discount # float
update_targets = self._update_targets # a pre-allocated numpy array
q = self
update_targets[:] = q(s, None)
self._s[:] = s
s = self._s
if s_ is None and a_ is None:
update_targets[0, a] = r
else:
update_targets[0, a] = r + gamma * q(s_, a_)
model.fit(s, update_targets, batch_size=1, verbose=0)
而policy
是EpsilonGreedyPolicy
的一个实例:
class EpsilonGreedyPolicy:
def __init__(self, epsilon):
self.eps = epsilon
def take_action(self, state, qvalue, ep=None):
num_actions = qvalue.num_actions
if callable(self.eps): eps = self.eps(ep+1)
else: eps = self.eps
if rand() < eps:
return random_choice(num_actions)
else:
qvalues = qvalue(state)
return argmax(qvalues)
以上代码在我的笔记本电脑上(仅使用CPU)大约每10秒运行1个情节。为了检查代码实际能多快运行,我尝试使用随机策略(eps=1)生成1000个情节的数据,生成了20000多个元组(s, a, r, s_, a_)
。这大约只需要10秒。接下来,我用这些数据单独训练神经网络,这样做的结果是通过一次性将所有数据传递给Keras的model.predict
和model.fit
,每10000个数据点大约只需1秒。总的来说,按照算法要求使用批量大小为1的model.fit
和model.predict
运行代码需要10000秒,而按照(i)先生成数据(ii)再训练神经网络的方式只需要10到100秒。
有没有推荐的方法可以在使用Gymnasium时避免这么大的开销呢?
0 个回答
暂无回答