본문 바로가기

딥러닝/강화학습(RL)

[Analyse RLLib] 3. Train Model with Ray Trainer

Train Model with Ray Trainer

DQN with simple environment

주어진 환경에 대해서 빠르게 강화학습 코드를 돌리는 방법은 다음과 같다.

  1. First Register Environment
  2. Train the model with tune.run

필요한 라이브러리를 import합니다.

import ray 
from ray.tune.registry import register_env 
from ray.rllib.agents import ppo 
from ray import tune 
from my_env import MyEnv

gym environment를 불러오고, ray에 등록해줘야 합니다. distributed system을 위해서 반드시 등록을 해줘야 합니다. 또한 Model도 등록을 해줘야 합니다. 물론 CartPole-v0또는 DQN처럼 많이 사용되는 환경과 모델은 이미 등록이 되어 있습니다.

Here is a simple way to train already registered model.

# Init ray and register custom env
ray.init()
register_env("my_env", lambda config : MyEnv(config))


# Run tune with DQN
tune.run("DQN", 
        stop = {"training_iteration": 100},
        config= {'env':'my_env',
                "timesteps_per_iteration": 1000,
                "buffer_size": 10000,    
                "train_batch_size": 64,
                'lr' : 1e-4
                }
        )

Tainer

위에서 나온 DQN은 Trainer입니다.

environment로부터 observation을 받아서, 처리하고 action을 compute해주며, 훈련을 진행합니다.

총 4가지 주요 함수가 있는데 마지막 compute_action()만 살펴보도록 하겠습니다.

  1. train()
  2. save()
  3. restore()
  4. compute_action()

Compute Action

trainer의 compute_action을 사용하면, Observation에 대해서 어떠한 action을 취해야 하는지 결정할 수 있습니다.

import ray 
from my_env import MyEnv
from ray.rllib.agents import ppo
from ray.tune.registry import register_env

# === init ray ===
ray.init()

# === register env ===
register_env("my_env", lambda config : MyEnv(config))

# === agent and env ===
agent = ppo.PPOTrainer(env="my_env")
env = MyEnv()

# === initial values === 
obs = env.reset()
done = False 
episode_reward = 0

# === choose actions with agent.compute_action(obs) === 
while not done:
    action = agent.compute_action(obs)
    obs, reward, done, info = env.step(action)
    episode_reward += reward 
    print(episode_reward)

inside of compute_action()

Trainer의 compute_action(...) 함수는 observation을 처리하고 내부의 policy를 이용해서 action을 계산하는 역할을 합니다.

class Trainer(Trainable):
    @publicAPI
    def compute_action(self, 
                        observation, 
                        state=None, 
                        prev_action=None, 
                        prev_reward=None, 
                        info=None, 
                        policy_id=DEFAULT_POLICY_ID, 
                        full_fetch=False):

        if state is None:
            state = [] 
        # === observation에 대해서 전처리를 진행한다 ===
        preprocessed = self.workers.local_worker().preprocessors[policy_id].transform(observation)
        filtered_obs = self.workers.local_worker().filters[policy_id](preprocessed, update=False)

        # === Trainer 안에 있는 policy를 이용해서 action을 계산한다.  ===
        if state:
            return self.get_policy(policy_id).compute_single_action(...)

        res = self.get_policy(policy_id).compute_single_action(...)
        if full_fetch:
            return res
        else:
            return res[0]

inside policy, MODEL!!(Neural Network)

policy안에는 neural network가 있어서, 모델을 통해서 다음 행동을 결정할 수도 있습니다. 만일 Rule Base로 행동을 결정하고 싶다면 이는 compute_action에서 이루어지고 model과는 별개로 진행됩니다.

import ray 
import numpy as np
from ray.rllib.agents.ppo import PPOTrainer 

## === setup ===
ray.init()
trainer = PPOTrainer(env="CartPole-v0")
policy = trainer.get_policy()

# === feed a minibatch of size 1 ===
logits, _ = policy.model.from_batch({"obs":np.array([[0.1, 0.2, 0.3, 0.4]])})
dist = policy.dist_class(logits, policy.model)

print(type(policy.model))  # <class 'ray.rllib.models.tf.fcnet.FullyConnectedNetwork'>
print(dist.sample())       # Tensor("Squeeze:0", shape=(1,), dtype=int64)
print(dist.logp)           # bound method Categorical.logp 
print(policy.model.value_function())     # Tensor("Reshape:0", shape=(1,), dtype=float32)
print(policy.model.base_model.summary()) # Model: "functional_1"

Getting Q-values from a DQN

모델을 좀더 세부적으로 파고든다면, input batch를 넣은 상태에서 값들을 추출해낼 수 있습니다.

import ray 
from ray.rllib.agents.dqn import DQNTrainer 
import numpy as np

# === setup === 
ray.init()
trainer = DQNTrainer(env="CartPole-v0")
model = trainer.get_policy().model 

# === feed to model ===
model_out = model.from_batch({"obs":np.array([[0.1, 0.2, 0.3, 0.4]])})
model_out_dist = model.get_q_value_distributions(model_out)

# === print values ===
print(model_out) # Tuple
print(model_out_dist)  # List
print(model.get_state_value(model_out)) # Tensor 


"""
(<tf.Tensor 'functional_1/fc_out/Tanh:0' shape=(1, 256) dtype=float32>, [])

[<tf.Tensor 'BiasAdd_1:0' shape=(1, 2) dtype=float32>, 
<tf.Tensor 'default_policy/ExpandDims_10:0' shape=(1, 2, 1) dtype=float32>, 
<tf.Tensor 'default_policy/ExpandDims_1_1:0' shape=(1, 2, 1) dtype=float32>]

Tensor("BiasAdd_3:0", shape=(1, 1), dtype=float32)
"""

References

[1] https://docs.ray.io/en/master/rllib.html