딥러닝/강화학습(RL)

[Analyse RLLib] 4. RLlib CallBacks

Rudi 2021. 2. 26. 21:52

아래의 CallBack Method 들은 모두 상황에 따라서 필요한 경우가 다르겠지만, 훈련 중간에 각 부분에서 어떤 처리를 하고 싶다면 반드시 구현해야 하는 부분입니다.

예를 들어서, MultiAgent환경에서 각각의 AgentObservation을 합치고 싶다면, on_postprocess_trajectory()부분을 구현해서 Batch의 형태를 바꿔주면 됩니다.

CallBack Class Functions Description
on_spisode_start() rollout worker에 대해서 episode를 시작하기 전에 불리는 함수입니다.
on_episode_step() episode의 매 step마다 불리는 함수입니다.
on_episode_end() episode가 끝날 때 불리는 함수입니다.
on_postprocess_trajectory() policy에서 policy'postprocess_fn이 불리고 호출되는 함수로, batch를 처리하는 부분입니다. 예를 들어서 MultiAgent에서 다른 Agent의 Observation을 처리하는 부분을 추가할 수 있습니다.
on_sample_end() RolloutWorer.sample()이 끝나고 호출되는 함수입니다.
on_learned_on_batch() Policy.learn_on_batch()의 첫 부분에 호출되는 함수입니다.
on_train_result() Trainer.train()으로 학습을 완료하고 호출되는 함수입니다.

How to write a CallBack Class

CartPole-v0 환경에서 CallBack을 적용해보겠습니다. 아래의 기능들은 밑에 Class MyCallbacks에 구현이 되어 있습니다.

  1. 에피소드가 시작되면 Pole의 각도를 저장할 리스트를 생성합니다.
  2. 에피소드의 매 Timestep 마다 Pole의 각도를 저장합니다.
  3. 에피소트다 끝나면 Pole의 각도를 Timestep에 대해서 평균내고 저장합니다.
  4. Rollout Worker에서 샘플링을 진행하면, 샘플의 사이즈를 출력합니다.
  5. 훈련이 종료되면 Callback이 완료되었음을 저장합니다.
  6. mini batch에 대해서 훈련을 시작하기 전에, action을 평균냅니다.
  7. Policy를 Mapping 후 Batch를 처리할 때, 처리된 개수를 저장합니다.
class MyCallbacks(DefaultCallbacks):
    def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv,
                         policies: Dict[str, Policy],
                         episode: MultiAgentEpisode, env_index: int, **kwargs):
        print("episode {} (env-idx={}) started.".format(
            episode.episode_id, env_index))

        episode.user_data["pole_angles"] = []
        episode.hist_data["pole_angles"] = []

    def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv,
                        episode: MultiAgentEpisode, env_index: int, **kwargs):
        pole_angle = abs(episode.last_observation_for()[2])
        raw_angle = abs(episode.last_raw_obs_for()[2])
        assert pole_angle == raw_angle
        episode.user_data["pole_angles"].append(pole_angle)

    def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv,
                       policies: Dict[str, Policy], episode: MultiAgentEpisode,
                       env_index: int, **kwargs):
        pole_angle = np.mean(episode.user_data["pole_angles"])
        print("episode {} (env-idx={}) ended with length {} and pole "
              "angles {}".format(episode.episode_id, env_index, episode.length,
                                 pole_angle))
        episode.custom_metrics["pole_angle"] = pole_angle
        episode.hist_data["pole_angles"] = episode.user_data["pole_angles"]

    def on_sample_end(self, *, worker: RolloutWorker, samples: SampleBatch,
                      **kwargs):
        print("returned sample batch of size {}".format(samples.count))

    def on_learn_on_batch(self, *, policy: Policy, train_batch: SampleBatch,
                          result: dict = {}, **kwargs) -> None:
        result["sum_actions_in_train_batch"] = np.sum(train_batch["actions"])
        print("policy.learn_on_batch() result: {} -> sum actions: {}".format(
            policy, result["sum_actions_in_train_batch"]))

    def on_postprocess_trajectory(
            self, *, worker: RolloutWorker, episode: MultiAgentEpisode,
            agent_id: str, policy_id: str, policies: Dict[str, Policy],
            postprocessed_batch: SampleBatch,
            original_batches: Dict[str, SampleBatch], **kwargs):
        print("postprocessed {} steps".format(postprocessed_batch.count))
        if "num_batches" not in episode.custom_metrics:
            episode.custom_metrics["num_batches"] = 0
        episode.custom_metrics["num_batches"] += 1

    def on_train_result(self, *, trainer, result: dict, **kwargs):
        print("trainer.train() result: {} -> {} episodes".format(
            trainer, result["episodes_this_iter"]))
        # you can mutate the result dict to add new fields to return
        result["callback_ok"] = True

How to use callbacks

CallBack Class를 정의하고 난 후에는, tune을 하면서 configcallbacks에 넣어주시면 됩니다.

ray.init()
trials = tune.run(
    "PG",
    stop={"training_iteration": args.stop_iters},
    config={
        "env": "CartPole-v0",
        "num_envs_per_worker": 2,
        "callbacks": MyCallbacks, # <-----------------!!!
    }
)

Note : Result에 대해서 에러가 있을 수 있습니다 :(