Train Model with Ray Trainer
DQN with simple environment
주어진 환경에 대해서 빠르게 강화학습 코드를 돌리는 방법은 다음과 같다.
- First Register Environment
- Train the model with
tune.run
필요한 라이브러리를 import
합니다.
import ray
from ray.tune.registry import register_env
from ray.rllib.agents import ppo
from ray import tune
from my_env import MyEnv
gym environment
를 불러오고, ray
에 등록해줘야 합니다. distributed system
을 위해서 반드시 등록을 해줘야 합니다. 또한 Model도 등록을 해줘야 합니다. 물론 CartPole-v0
또는 DQN
처럼 많이 사용되는 환경과 모델은 이미 등록이 되어 있습니다.
Here is a simple way to train already registered model.
# Init ray and register custom env
ray.init()
register_env("my_env", lambda config : MyEnv(config))
# Run tune with DQN
tune.run("DQN",
stop = {"training_iteration": 100},
config= {'env':'my_env',
"timesteps_per_iteration": 1000,
"buffer_size": 10000,
"train_batch_size": 64,
'lr' : 1e-4
}
)
Tainer
위에서 나온 DQN
은 Trainer입니다.
environment로부터 observation을 받아서, 처리하고 action을 compute해주며, 훈련을 진행합니다.
총 4가지 주요 함수가 있는데 마지막 compute_action()
만 살펴보도록 하겠습니다.
- train()
- save()
- restore()
- compute_action()
Compute Action
trainer의 compute_action을 사용하면, Observation에 대해서 어떠한 action을 취해야 하는지 결정할 수 있습니다.
import ray
from my_env import MyEnv
from ray.rllib.agents import ppo
from ray.tune.registry import register_env
# === init ray ===
ray.init()
# === register env ===
register_env("my_env", lambda config : MyEnv(config))
# === agent and env ===
agent = ppo.PPOTrainer(env="my_env")
env = MyEnv()
# === initial values ===
obs = env.reset()
done = False
episode_reward = 0
# === choose actions with agent.compute_action(obs) ===
while not done:
action = agent.compute_action(obs)
obs, reward, done, info = env.step(action)
episode_reward += reward
print(episode_reward)
inside of compute_action()
Trainer의 compute_action(...)
함수는 observation
을 처리하고 내부의 policy
를 이용해서 action
을 계산하는 역할을 합니다.
class Trainer(Trainable):
@publicAPI
def compute_action(self,
observation,
state=None,
prev_action=None,
prev_reward=None,
info=None,
policy_id=DEFAULT_POLICY_ID,
full_fetch=False):
if state is None:
state = []
# === observation에 대해서 전처리를 진행한다 ===
preprocessed = self.workers.local_worker().preprocessors[policy_id].transform(observation)
filtered_obs = self.workers.local_worker().filters[policy_id](preprocessed, update=False)
# === Trainer 안에 있는 policy를 이용해서 action을 계산한다. ===
if state:
return self.get_policy(policy_id).compute_single_action(...)
res = self.get_policy(policy_id).compute_single_action(...)
if full_fetch:
return res
else:
return res[0]
inside policy, MODEL!!(Neural Network)
policy
안에는 neural network
가 있어서, 모델을 통해서 다음 행동을 결정할 수도 있습니다. 만일 Rule Base
로 행동을 결정하고 싶다면 이는 compute_action
에서 이루어지고 model
과는 별개로 진행됩니다.
import ray
import numpy as np
from ray.rllib.agents.ppo import PPOTrainer
## === setup ===
ray.init()
trainer = PPOTrainer(env="CartPole-v0")
policy = trainer.get_policy()
# === feed a minibatch of size 1 ===
logits, _ = policy.model.from_batch({"obs":np.array([[0.1, 0.2, 0.3, 0.4]])})
dist = policy.dist_class(logits, policy.model)
print(type(policy.model)) # <class 'ray.rllib.models.tf.fcnet.FullyConnectedNetwork'>
print(dist.sample()) # Tensor("Squeeze:0", shape=(1,), dtype=int64)
print(dist.logp) # bound method Categorical.logp
print(policy.model.value_function()) # Tensor("Reshape:0", shape=(1,), dtype=float32)
print(policy.model.base_model.summary()) # Model: "functional_1"
Getting Q-values from a DQN
모델을 좀더 세부적으로 파고든다면, input batch
를 넣은 상태에서 값들을 추출해낼 수 있습니다.
import ray
from ray.rllib.agents.dqn import DQNTrainer
import numpy as np
# === setup ===
ray.init()
trainer = DQNTrainer(env="CartPole-v0")
model = trainer.get_policy().model
# === feed to model ===
model_out = model.from_batch({"obs":np.array([[0.1, 0.2, 0.3, 0.4]])})
model_out_dist = model.get_q_value_distributions(model_out)
# === print values ===
print(model_out) # Tuple
print(model_out_dist) # List
print(model.get_state_value(model_out)) # Tensor
"""
(<tf.Tensor 'functional_1/fc_out/Tanh:0' shape=(1, 256) dtype=float32>, [])
[<tf.Tensor 'BiasAdd_1:0' shape=(1, 2) dtype=float32>,
<tf.Tensor 'default_policy/ExpandDims_10:0' shape=(1, 2, 1) dtype=float32>,
<tf.Tensor 'default_policy/ExpandDims_1_1:0' shape=(1, 2, 1) dtype=float32>]
Tensor("BiasAdd_3:0", shape=(1, 1), dtype=float32)
"""
References
'딥러닝 > 강화학습(RL)' 카테고리의 다른 글
[Analyse RLLib] 5. Explore and Curriculum Learning (0) | 2021.02.26 |
---|---|
[Analyse RLLib] 4. RLlib CallBacks (0) | 2021.02.26 |
[Analyse RLLib] 2. RLlib 기본 훈련 코드 돌리기 (0) | 2021.02.26 |
[Analyse RLLib] 1. Ray와 RLlib의 전체적인 구조 (0) | 2021.02.26 |
MARL-DRONE (0) | 2020.12.31 |