딥러닝/강화학습(RL)
[Analyse RLLib] 4. RLlib CallBacks
Rudi
2021. 2. 26. 21:52
-
RLlib Callbacks
Trainer에서 Batch에 대해서 훈련을 진행하면서, 진행되는 전처리, 후처리를 모두 처리하는 부분입니다.
Functions
아래의 CallBack Method
들은 모두 상황에 따라서 필요한 경우가 다르겠지만, 훈련 중간에 각 부분에서 어떤 처리를 하고 싶다면 반드시 구현해야 하는 부분입니다.
예를 들어서, MultiAgent
환경에서 각각의 Agent
의 Observation
을 합치고 싶다면, on_postprocess_trajectory()
부분을 구현해서 Batch
의 형태를 바꿔주면 됩니다.
CallBack Class Functions | Description |
---|---|
on_spisode_start() | rollout worker에 대해서 episode를 시작하기 전에 불리는 함수입니다. |
on_episode_step() | episode의 매 step마다 불리는 함수입니다. |
on_episode_end() | episode가 끝날 때 불리는 함수입니다. |
on_postprocess_trajectory() | policy 에서 policy'postprocess_fn 이 불리고 호출되는 함수로, batch 를 처리하는 부분입니다. 예를 들어서 MultiAgent 에서 다른 Agent 의 Observation을 처리하는 부분을 추가할 수 있습니다. |
on_sample_end() | RolloutWorer.sample() 이 끝나고 호출되는 함수입니다. |
on_learned_on_batch() | Policy.learn_on_batch() 의 첫 부분에 호출되는 함수입니다. |
on_train_result() | Trainer.train() 으로 학습을 완료하고 호출되는 함수입니다. |
How to write a CallBack Class
CartPole-v0
환경에서 CallBack
을 적용해보겠습니다. 아래의 기능들은 밑에 Class MyCallbacks
에 구현이 되어 있습니다.
- 에피소드가 시작되면
Pole
의 각도를 저장할 리스트를 생성합니다. - 에피소드의 매 Timestep 마다
Pole
의 각도를 저장합니다. - 에피소트다 끝나면
Pole
의 각도를 Timestep에 대해서 평균내고 저장합니다. Rollout Worker
에서 샘플링을 진행하면, 샘플의 사이즈를 출력합니다.- 훈련이 종료되면
Callback
이 완료되었음을 저장합니다. - mini batch에 대해서 훈련을 시작하기 전에, action을 평균냅니다.
- Policy를 Mapping 후 Batch를 처리할 때, 처리된 개수를 저장합니다.
class MyCallbacks(DefaultCallbacks):
def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv,
policies: Dict[str, Policy],
episode: MultiAgentEpisode, env_index: int, **kwargs):
print("episode {} (env-idx={}) started.".format(
episode.episode_id, env_index))
episode.user_data["pole_angles"] = []
episode.hist_data["pole_angles"] = []
def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv,
episode: MultiAgentEpisode, env_index: int, **kwargs):
pole_angle = abs(episode.last_observation_for()[2])
raw_angle = abs(episode.last_raw_obs_for()[2])
assert pole_angle == raw_angle
episode.user_data["pole_angles"].append(pole_angle)
def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv,
policies: Dict[str, Policy], episode: MultiAgentEpisode,
env_index: int, **kwargs):
pole_angle = np.mean(episode.user_data["pole_angles"])
print("episode {} (env-idx={}) ended with length {} and pole "
"angles {}".format(episode.episode_id, env_index, episode.length,
pole_angle))
episode.custom_metrics["pole_angle"] = pole_angle
episode.hist_data["pole_angles"] = episode.user_data["pole_angles"]
def on_sample_end(self, *, worker: RolloutWorker, samples: SampleBatch,
**kwargs):
print("returned sample batch of size {}".format(samples.count))
def on_learn_on_batch(self, *, policy: Policy, train_batch: SampleBatch,
result: dict = {}, **kwargs) -> None:
result["sum_actions_in_train_batch"] = np.sum(train_batch["actions"])
print("policy.learn_on_batch() result: {} -> sum actions: {}".format(
policy, result["sum_actions_in_train_batch"]))
def on_postprocess_trajectory(
self, *, worker: RolloutWorker, episode: MultiAgentEpisode,
agent_id: str, policy_id: str, policies: Dict[str, Policy],
postprocessed_batch: SampleBatch,
original_batches: Dict[str, SampleBatch], **kwargs):
print("postprocessed {} steps".format(postprocessed_batch.count))
if "num_batches" not in episode.custom_metrics:
episode.custom_metrics["num_batches"] = 0
episode.custom_metrics["num_batches"] += 1
def on_train_result(self, *, trainer, result: dict, **kwargs):
print("trainer.train() result: {} -> {} episodes".format(
trainer, result["episodes_this_iter"]))
# you can mutate the result dict to add new fields to return
result["callback_ok"] = True
How to use callbacks
CallBack Class
를 정의하고 난 후에는, tune
을 하면서 config
의 callbacks
에 넣어주시면 됩니다.
ray.init()
trials = tune.run(
"PG",
stop={"training_iteration": args.stop_iters},
config={
"env": "CartPole-v0",
"num_envs_per_worker": 2,
"callbacks": MyCallbacks, # <-----------------!!!
}
)
Note : Result에 대해서 에러가 있을 수 있습니다 :(