본문 바로가기

딥러닝/강화학습(RL)

[Analyse RLLib] 2. RLlib 기본 훈련 코드 돌리기

RLlib Basic Training

Trainer Class 는 다음과 같은 기능이 있습니다.

  1. Policy Optimizer를 가지고 있고, 외부 환경과 상호작용을 책임집니다.
  2. Policy에 대해서 훈련, Checkpoint, 모델 파라미터 복구, 다음 action을 계산해줍니다. (train, save, restore, compute)
  3. multi-agent환경에서는 여러 개의 Policy를 한번에 Optimize 해줍니다.

trainer는 환경(Env)와 상호작용하면서 처리하는 모든 것들을 해줍니다.

 

 

Ray 와 RLlib 설치방법

pip install 'ray[rllib]'

간단한 훈련(Cmd 창)

pip로 rllib을 설치하셨다면, cmd 창에서 훈련 코드를 바로 돌릴 수 있습니다.

  • Algorithm : DQN(Deep Q-Network)
  • Environment : CartPole-v0 (Gym의 내장 환경)
rllib train --run DQN --env CartPole-v0
result
http://127.0.0.1:8265 포트에서 확인이 가능하지만, windows는 지원하지 않는다.

결과 파일 구조

Command line에서 실행한 RLlib의 결과는 ray_results/에 저장됩니다.

  • params.json : 모델, 환경에 대한 정보를 담고 있다.
  • params.pkl : 모델 파라미터
  • resuls.json :
  • progress : 결과값을 CSV로 저장합니다.
  • events.out.tfevents. : Plot 가능한 값에 대해서 tensorboard로 확인 가능하다. tensorboard --logdir .

테스트 해보기(Cmd 창)

train 대신에 rollout을 사용하면 환경에 대해서 simulation을 해줍니다. 물론 알고리즘에 대해서 학습된 모델이 필요하기 때문에 checkpoint 정보를 줘야 합니다.

rllib rollout \
    ~/ray_results/default/DQN_CartPole-v0_0upjmdgr0/checkpoint_1/checkpoint-1 \
    --run DQN --env CartPole-v0 --steps 10000

Basic Python API(.py File)

Trainer를 직접 생성하는 방법으로 환경에 대해서 모델을 학습시킬 수 있습니다. Trainer안에는 환경과 상호작용으로 배우는 모든 정보가 담겨져 있으므로 high level api로 생각할 수도 있을 것 같습니다.

import ray
import ray.rllib.agents.ppo as ppo
from ray.tune.logger import pretty_print

# === initialize ray === 
ray.init()
config = ppo.DEFAULT_CONFIG.copy()
config["num_gpus"] = 0
config["num_workers"] = 1
config["eager"] = False

# === make trainer(ppo algorithm) ==
# ppo trainer을 생성하면서, 환경에 대한 정보, configuration을 같이 준다. 
trainer = ppo.PPOTrainer(config=config, env="CartPole-v0") 

# === Train === 
for i in range(1000):
   result = trainer.train()
   print(pretty_print(result))
   if i % 100 == 0:
       checkpoint = trainer.save()
       print("checkpoint saved at", checkpoint)

# === import from the checkpoint === 
trainer.import_model("my_weights.h5")

Tune API

Trainer Class를 직접 호출해서 학습을 진행하는데, 만일 하이퍼파리미터를 설정해야 한다면 ray.tune을 이용해서 파라미터를 설정하고 Train을 하는데 필요한 정보들을 추가적으로 설정할 수 있습니다.

import ray
from ray import tune

ray.init()
tune.run(
    "PPO",
    stop={"episode_reward_mean": 200},
    config={
        "env": "CartPole-v0",
        "num_gpus": 0,
        "num_workers": 1,
        "lr": tune.grid_search([0.01, 0.001, 0.0001]),
        "eager": False,
    },
)

References

[1] https://docs.ray.io/en/master/rllib-training.html