Source code for pytorchrl.agent.actors.model_based_planner_actor

import gym
import torch

from pytorchrl.agent.actors.base import Actor
from pytorchrl.agent.actors.world_models import WorldModel
from pytorchrl.agent.actors.world_models.utils import StandardScaler


class ModelBasedPlannerActor(Actor):
    """
    Actor Planner class for MB agents.

    Wraps a world dynamics model (stored as ``self.dynamics_model``) that
    predicts next states and rewards from observations and actions.
    """

    def __init__(self,
                 device,
                 horizon,
                 n_planner,
                 input_space,
                 action_space,
                 algorithm_name,
                 checkpoint=None,
                 world_model_class=None,
                 world_model_kwargs=None):
        """
        Initialize the planner actor and build its world dynamics model.

        Parameters
        ----------
        device : torch.device
            Device forwarded to the base Actor constructor.
        horizon : int
            The horizon of online planning.
        n_planner : int
            Number of parallel planned trajectories.
        input_space : gym.Space
            Environment observation space.
        action_space : gym.Space
            Environment action space. Must be a Discrete or a Box space.
        algorithm_name : str
            Name of the RL algorithm used for learning.
        checkpoint : str
            Path to a previously trained Actor checkpoint, forwarded to the
            base Actor constructor.
        world_model_class : class
            PyTorch nn.Module to approximate world dynamics. Falls back to
            WorldModel when None.
        world_model_kwargs : dict
            Keyword arguments for the world model class. None is treated as
            an empty dict.

        Raises
        ------
        ValueError
            If `action_space` is neither a Discrete nor a Box space.
        """
        super().__init__(
            device=device,
            checkpoint=checkpoint,
            input_space=input_space,
            action_space=action_space)

        # isinstance (rather than an exact type comparison) so that
        # subclasses of the gym spaces are accepted as well.
        if isinstance(action_space, gym.spaces.Discrete):
            self.action_dims = action_space.n
            self.action_type = "discrete"
            self.action_low = None
            self.action_high = None
        elif isinstance(action_space, gym.spaces.Box):
            self.action_dims = action_space.shape[0]
            self.action_type = "continuous"
            self.action_low = action_space.low
            self.action_high = action_space.high
        else:
            raise ValueError("Unknown action space")

        self.horizon = horizon
        self.n_planner = n_planner
        self.algorithm_name = algorithm_name

        # ----- World Model ---------------------------------------------------

        self.create_world_dynamics_model(world_model_class, world_model_kwargs)

    @classmethod
    def create_factory(
            cls,
            input_space,
            action_space,
            algorithm_name,
            horizon,
            n_planner,
            restart_model=None,
            world_model_class=None,
            world_model_kwargs=None):
        """
        Returns a function that creates actor instances.

        Parameters
        ----------
        input_space : gym.Space
            Environment observation space.
        action_space : gym.Space
            Environment action space.
        algorithm_name : str
            Name of the RL algorithm used for learning.
        horizon : int
            The horizon of online planning.
        n_planner : int
            Number of parallel planned trajectories.
        restart_model : str
            Path to a previously trained Actor checkpoint to be loaded.
        world_model_class : class
            PyTorch nn.Module to approximate world dynamics.
        world_model_kwargs : dict
            Keyword arguments for the world model class. None is treated as
            an empty dict.

        Returns
        -------
        create_actor_instance : func
            Function that, given a device, creates a new instance of this
            class on that device.
        """

        def create_actor_instance(device):
            """Create and return an actor instance placed on `device`."""
            policy = cls(device=device,
                         horizon=horizon,
                         n_planner=n_planner,
                         input_space=input_space,
                         action_space=action_space,
                         algorithm_name=algorithm_name,
                         checkpoint=restart_model,
                         world_model_class=world_model_class,
                         world_model_kwargs=world_model_kwargs)
            policy.to(device)

            # Best-effort restore: a RuntimeError raised while loading the
            # checkpoint is deliberately ignored and the freshly created
            # actor is returned as is.
            try:
                policy.try_load_from_checkpoint()
            except RuntimeError:
                pass

            return policy

        return create_actor_instance

    @property
    def is_recurrent(self):
        """Returns True if the actor network are recurrent."""
        return False

    @property
    def recurrent_hidden_state_size(self):
        """Size of policy recurrent hidden state"""
        return 1

    def actor_initial_states(self, obs):
        """
        Returns all actor inputs required to predict initial action.

        Parameters
        ----------
        obs : torch.tensor or dict
            Initial environment observation. When a dict is given, its first
            value determines the number of processes and the device.

        Returns
        -------
        obs : torch.tensor or dict
            Initial environment observation, returned unchanged.
        rhs : dict
            Initial recurrent hidden state, stored under the key
            "world_model_rhs". Contains zeros whenever the policy network
            cannot provide an initial recurrent state.
        done : torch.tensor
            Initial done tensor of zeros with shape (num_proc, 1), indicating
            the environment is not done.
        """
        reference = next(iter(obs.values())) if isinstance(obs, dict) else obs
        num_proc = reference.size(0)
        dev = reference.device

        done = torch.zeros(num_proc, 1, device=dev)

        # NOTE(review): nothing in this class defines `policy_net`, so the
        # fallback below is presumably the usual path. The broad except is
        # kept on purpose to preserve the existing best-effort behavior.
        try:
            rhs = self.policy_net.memory_net.get_initial_recurrent_state(
                num_proc).to(dev)
        except Exception:
            rhs = torch.zeros(
                num_proc, self.recurrent_hidden_state_size, device=dev)

        rhs = {"world_model_rhs": rhs}

        return obs, rhs, done

    def get_prediction(self, obs, act, rhs, done, deterministic=False):
        """
        Predict next states and rewards with the world dynamics model.

        Parameters
        ----------
        obs : torch.tensor
            Current environment observation.
        act : torch.tensor
            Action to take given obs.
        rhs : dict
            Current recurrent hidden states. Unused by this method.
        done : torch.tensor
            Current done tensor, indicating if episode has finished. Unused
            by this method.
        deterministic : bool
            Unused by this method.

        Returns
        -------
        next_states : torch.Tensor
            Next states.
        rewards : torch.Tensor
            Reward prediction.
        """
        next_states, rewards = self.dynamics_model.predict(obs, act)
        return next_states, rewards

    def create_world_dynamics_model(self, world_model_class, world_model_kwargs):
        """
        Create a world model instance and define it as class attribute under
        the name `dynamics_model`.

        Parameters
        ----------
        world_model_class : class
            WorldModel class. Falls back to WorldModel when None.
        world_model_kwargs : dict
            WorldModel class arguments. None is treated as an empty dict.
        """
        if world_model_class is None:
            world_model_class = WorldModel
        if world_model_kwargs is None:
            world_model_kwargs = {}

        data_scaler = StandardScaler(self.device)
        self.dynamics_model = world_model_class(
            self.device,
            self.input_space,
            self.action_space,
            data_scaler,
            **world_model_kwargs).to(self.device)