RLlib Basic Training
Trainer Class
는 다음과 같은 기능이 있습니다.
- Policy Optimizer를 가지고 있고, 외부 환경과 상호작용을 책임집니다.
- Policy에 대해서 훈련, Checkpoint, 모델 파라미터 복구, 다음 action을 계산해줍니다. (train, save, restore, compute)
- multi-agent환경에서는 여러 개의 Policy를 한번에 Optimize 해줍니다.
trainer
는 환경(Env)와 상호작용하면서 처리하는 모든 것들을 해줍니다.
Ray 와 RLlib 설치방법
pip install 'ray[rllib]'
간단한 훈련(Cmd 창)
pip로 rllib을 설치하셨다면, cmd 창에서 훈련 코드를 바로 돌릴 수 있습니다.
- Algorithm : DQN(Deep Q-Network)
- Environment : CartPole-v0 (Gym의 내장 환경)
rllib train --run DQN --env CartPole-v0
result |
---|
![]() |
http://127.0.0.1:8265 포트에서 확인이 가능하지만, windows는 지원하지 않는다. |
결과 파일 구조
Command line에서 실행한 RLlib의 결과는 ray_results/
에 저장됩니다.
- params.json : 모델, 환경에 대한 정보를 담고 있다.
- params.pkl : 모델 파라미터
- resuls.json :
- progress : 결과값을 CSV로 저장합니다.
- events.out.tfevents. : Plot 가능한 값에 대해서 tensorboard로 확인 가능하다.
tensorboard --logdir .
테스트 해보기(Cmd 창)
train
대신에 rollout
을 사용하면 환경에 대해서 simulation을 해줍니다. 물론 알고리즘에 대해서 학습된 모델이 필요하기 때문에 checkpoint 정보를 줘야 합니다.
rllib rollout \
~/ray_results/default/DQN_CartPole-v0_0upjmdgr0/checkpoint_1/checkpoint-1 \
--run DQN --env CartPole-v0 --steps 10000
Basic Python API(.py File)
Trainer
를 직접 생성하는 방법으로 환경에 대해서 모델을 학습시킬 수 있습니다. Trainer안에는 환경과 상호작용으로 배우는 모든 정보가 담겨져 있으므로 high level api
로 생각할 수도 있을 것 같습니다.
import ray
import ray.rllib.agents.ppo as ppo
from ray.tune.logger import pretty_print
# === initialize ray ===
ray.init()
config = ppo.DEFAULT_CONFIG.copy()
config["num_gpus"] = 0
config["num_workers"] = 1
config["eager"] = False
# === make trainer(ppo algorithm) ==
# ppo trainer을 생성하면서, 환경에 대한 정보, configuration을 같이 준다.
trainer = ppo.PPOTrainer(config=config, env="CartPole-v0")
# === Train ===
for i in range(1000):
result = trainer.train()
print(pretty_print(result))
if i % 100 == 0:
checkpoint = trainer.save()
print("checkpoint saved at", checkpoint)
# === import from the checkpoint ===
trainer.import_model("my_weights.h5")
Tune API
Trainer Class
를 직접 호출해서 학습을 진행하는데, 만일 하이퍼파리미터를 설정해야 한다면 ray.tune
을 이용해서 파라미터를 설정하고 Train을 하는데 필요한 정보들을 추가적으로 설정할 수 있습니다.
import ray
from ray import tune
ray.init()
tune.run(
"PPO",
stop={"episode_reward_mean": 200},
config={
"env": "CartPole-v0",
"num_gpus": 0,
"num_workers": 1,
"lr": tune.grid_search([0.01, 0.001, 0.0001]),
"eager": False,
},
)
References
'딥러닝 > 강화학습(RL)' 카테고리의 다른 글
[Analyse RLLib] 4. RLlib CallBacks (0) | 2021.02.26 |
---|---|
[Analyse RLLib] 3. Train Model with Ray Trainer (0) | 2021.02.26 |
[Analyse RLLib] 1. Ray와 RLlib의 전체적인 구조 (0) | 2021.02.26 |
MARL-DRONE (0) | 2020.12.31 |
MARL - MADDPG 이해하기 (0) | 2020.12.20 |