Passing a class method as a callback

Published 2024-05-27 12:23:26


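I'm training some reinforcement learning models with Ray Tune (`tune`). I need to use a custom callback, which must be a function that takes a single argument. I'd like to use a method of a custom class as the callback, but I really don't know how to pass the `info` argument without also having to pass `self`.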
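As background, the caller supplies only `info`, so whatever gets registered must already have `self` filled in. The sketch below is plain Python with no Ray dependency. It uses `functools.partial` to pre-bind an instance to the class's plain function. `call_like_rllib` is a hypothetical stand-in for RLlib invoking the callback, and the `info` dict contents are made up.

```python
import functools
from typing import Any, Callable, Dict, List


class Evaluation:
    def __init__(self) -> None:
        self.messages: List[str] = []

    def log(self, msg: Any) -> None:
        # Record instead of printing, so the effect is easy to inspect.
        self.messages.append(str(msg))

    def on_episode_step(self, info: Dict[str, Any]) -> None:
        self.log(info)


def call_like_rllib(callback: Callable[[Dict[str, Any]], None]) -> None:
    # Hypothetical caller: supplies exactly one positional argument.
    callback({"episode": 7})


evaluation = Evaluation()
# partial fixes the first parameter (self), leaving a one-argument callable.
one_arg_callback = functools.partial(Evaluation.on_episode_step, evaluation)
call_like_rllib(one_arg_callback)
print(evaluation.messages)  # ["{'episode': 7}"]
```

For an ordinary, undecorated method, `evaluation.on_episode_step` yields the same one-argument callable. `partial` just makes the binding explicit.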

This is my class at the moment:

```python
import json
from collections import Counter

import termcolor
from ray import tune


class Evaluation:

    def configure_callbacks(self, config):
        config["callbacks"]["on_episode_step"] = self.on_episode_step
        return config

    def log(self, msg, color="white"):
        termcolor.cprint(msg, color)

    @tune.function
    def on_episode_step(self, info):
        self.log(info)
```

I get the following error:

```text
ray.exceptions.RayTaskError: /Users/giulia/anaconda3/envs/dmas/bin/python /Users/giulia/Desktop/mas_traffic/FlowMas/simulation.py (pid=25935, host=Giulias-MacBook-Pro.local)
  File "/Users/giulia/anaconda3/envs/dmas/lib/python3.6/site-packages/ray-0.7.4-py3.6-macosx-10.7-x86_64.egg/ray/local_mode_manager.py", line 55, in execute
    results = function(*copy.deepcopy(args))
  File "/Users/giulia/anaconda3/envs/dmas/lib/python3.6/site-packages/ray-0.7.4-py3.6-macosx-10.7-x86_64.egg/ray/rllib/agents/trainer.py", line 415, in train
    raise e
  File "/Users/giulia/anaconda3/envs/dmas/lib/python3.6/site-packages/ray-0.7.4-py3.6-macosx-10.7-x86_64.egg/ray/rllib/agents/trainer.py", line 401, in train
    result = Trainable.train(self)
  File "/Users/giulia/anaconda3/envs/dmas/lib/python3.6/site-packages/ray-0.7.4-py3.6-macosx-10.7-x86_64.egg/ray/tune/trainable.py", line 171, in train
    result = self._train()
  File "/Users/giulia/anaconda3/envs/dmas/lib/python3.6/site-packages/ray-0.7.4-py3.6-macosx-10.7-x86_64.egg/ray/rllib/agents/trainer_template.py", line 129, in _train
    fetches = self.optimizer.step()
  File "/Users/giulia/anaconda3/envs/dmas/lib/python3.6/site-packages/ray-0.7.4-py3.6-macosx-10.7-x86_64.egg/ray/rllib/optimizers/sync_batch_replay_optimizer.py", line 66, in step
    batches = [self.workers.local_worker().sample()]
  File "/Users/giulia/anaconda3/envs/dmas/lib/python3.6/site-packages/ray-0.7.4-py3.6-macosx-10.7-x86_64.egg/ray/rllib/evaluation/rollout_worker.py", line 472, in sample
    batches = [self.input_reader.next()]
  File "/Users/giulia/anaconda3/envs/dmas/lib/python3.6/site-packages/ray-0.7.4-py3.6-macosx-10.7-x86_64.egg/ray/rllib/evaluation/sampler.py", line 56, in next
    batches = [self.get_data()]
  File "/Users/giulia/anaconda3/envs/dmas/lib/python3.6/site-packages/ray-0.7.4-py3.6-macosx-10.7-x86_64.egg/ray/rllib/evaluation/sampler.py", line 99, in get_data
    item = next(self.rollout_provider)
  File "/Users/giulia/anaconda3/envs/dmas/lib/python3.6/site-packages/ray-0.7.4-py3.6-macosx-10.7-x86_64.egg/ray/rllib/evaluation/sampler.py", line 319, in _env_runner
    soft_horizon, no_done_at_end)
  File "/Users/giulia/anaconda3/envs/dmas/lib/python3.6/site-packages/ray-0.7.4-py3.6-macosx-10.7-x86_64.egg/ray/rllib/evaluation/sampler.py", line 364, in _process_observations
    episode = active_episodes[env_id]
  File "/Users/giulia/anaconda3/envs/dmas/lib/python3.6/site-packages/ray-0.7.4-py3.6-macosx-10.7-x86_64.egg/ray/rllib/evaluation/sampler.py", line 294, in new_episode
    "episode": episode,
  File "/Users/giulia/anaconda3/envs/dmas/lib/python3.6/site-packages/ray-0.7.4-py3.6-macosx-10.7-x86_64.egg/ray/tune/sample.py", line 45, in __call__
    return self.func(*args, **kwargs)
TypeError: on_episode_step() missing 1 required positional argument: 'info'
```

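That part is clear: `on_episode_step` declares two parameters (`self` and `info`), but the callback is invoked with only one.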
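The last frame of the traceback (`ray/tune/sample.py`, `__call__` returning `self.func(*args, **kwargs)`) suggests that `tune.function` wraps the function in an object that simply forwards calls. Such an object is not a descriptor, so `self.on_episode_step` returns the wrapper itself rather than a bound method. The single `info` dict then lands in `self`. The sketch below reproduces this with a hypothetical `FunctionWrapper` class that only mimics that forwarding behaviour. It is not Ray's code.

```python
from typing import Any, Callable, Dict, Optional


class FunctionWrapper:
    """Hypothetical stand-in for a wrapper like tune.function: it stores a
    callable and forwards every call to it. It defines no __get__, so it is
    not a descriptor and never binds `self` the way a plain function does."""

    def __init__(self, func: Callable[..., Any]) -> None:
        self.func = func

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self.func(*args, **kwargs)


class Broken:
    @FunctionWrapper  # wraps the plain (unbound) function at class-definition time
    def on_episode_step(self, info: Dict[str, Any]) -> str:
        return "step"


class Fixed:
    def configure_callbacks(self, config: Dict[str, Any]) -> Dict[str, Any]:
        # Wrap the *bound* method instead, so `self` is already filled in.
        config["callbacks"]["on_episode_step"] = FunctionWrapper(self.on_episode_step)
        return config

    def on_episode_step(self, info: Dict[str, Any]) -> str:
        return "step"


# Broken: attribute lookup returns the wrapper itself, so the single
# info dict is passed as `self` and `info` is missing.
broken_error: Optional[str] = None
try:
    Broken().on_episode_step({"episode": 0})
except TypeError as err:
    broken_error = str(err)
print(broken_error)  # ...missing 1 required positional argument: 'info'

config = Fixed().configure_callbacks({"callbacks": {}})
print(config["callbacks"]["on_episode_step"]({"episode": 0}))  # step
```

Applied to the code in the question, this would presumably mean removing the `@tune.function` decorator and registering `tune.function(self.on_episode_step)` inside `configure_callbacks`. That reading is inferred from the traceback, not checked against the Ray 0.7.4 source.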

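What I'd really like is to keep track of the evaluation's progress, with the class's attributes being updated from time to time.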
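Once a bound method is registered, every call can update attributes on the same instance. That is one way to accumulate progress. The sketch below tallies steps per episode with `collections.Counter`, which the question already imports. `FakeEpisode` is a hypothetical stand-in for the object found in `info["episode"]`. The driving loop only simulates the callback being called; it does not run RLlib.

```python
from collections import Counter
from typing import Any, Dict, NamedTuple


class FakeEpisode(NamedTuple):
    # Hypothetical stand-in for the episode object passed in info["episode"].
    episode_id: int


class Evaluation:
    def __init__(self) -> None:
        self.steps_per_episode: Counter = Counter()
        self.total_steps = 0

    def configure_callbacks(self, config: Dict[str, Any]) -> Dict[str, Any]:
        # Register the bound method; the caller then passes only `info`.
        config["callbacks"]["on_episode_step"] = self.on_episode_step
        return config

    def on_episode_step(self, info: Dict[str, Any]) -> None:
        # Progress lives on the instance and is updated on every call.
        self.steps_per_episode[info["episode"].episode_id] += 1
        self.total_steps += 1


evaluation = Evaluation()
config = evaluation.configure_callbacks({"callbacks": {}})
callback = config["callbacks"]["on_episode_step"]

# Simulate two episodes: three steps in episode 0, two in episode 1.
for episode_id, n_steps in [(0, 3), (1, 2)]:
    for _ in range(n_steps):
        callback({"env": None, "episode": FakeEpisode(episode_id)})

print(evaluation.total_steps)        # 5
print(evaluation.steps_per_episode)  # Counter({0: 3, 1: 2})
```

Whether updates made this way show up on the driver's object depends on where RLlib actually runs the callback. Rollout workers can live in other processes. The traceback also shows Ray's local mode calling `copy.deepcopy(args)`, so the instance being updated may be a copy.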

Is there any way around this?

