Model-Based Actor
- class pytorchrl.agent.actors.model_based_planner_actor.ModelBasedPlannerActor(device, horizon, n_planner, input_space, action_space, algorithm_name, checkpoint=None, world_model_class=None, world_model_kwargs={})
Bases: pytorchrl.agent.actors.base.Actor
Actor planner class for model-based (MB) agents.
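This page does not say which planning algorithm the actor runs. The sketch below assumes random shooting, a common baseline for MB planners, only to show how horizon and n_planner typically interact: n_planner candidate action sequences of length horizon are rolled out through a world model, and the first action of the highest-return sequence is executed. All names here (sample_candidates, rollout_return, plan, toy_model) are hypothetical, and plain Python lists stand in for torch tensors.

```python
import random
from typing import Callable, List, Sequence, Tuple

# Hypothetical world-model interface: (state, action) -> (next_state, predicted_reward).
WorldModel = Callable[[List[float], float], Tuple[List[float], float]]


def sample_candidates(horizon: int, n_planner: int, low: float, high: float,
                      rng: random.Random) -> List[List[float]]:
    """Draw n_planner random action sequences, each `horizon` steps long."""
    return [[rng.uniform(low, high) for _ in range(horizon)] for _ in range(n_planner)]


def rollout_return(state: Sequence[float], world_model: WorldModel,
                   actions: Sequence[float]) -> float:
    """Sum the rewards the model predicts when `actions` are applied from `state`."""
    s, total = list(state), 0.0
    for a in actions:
        s, r = world_model(s, a)
        total += r
    return total


def plan(state: Sequence[float], world_model: WorldModel,
         candidates: Sequence[Sequence[float]]) -> float:
    """Return the first action of the candidate sequence with the highest return."""
    best = max(candidates, key=lambda seq: rollout_return(state, world_model, seq))
    # MPC style: only this first action is executed; planning restarts next step.
    return best[0]


def toy_model(state: List[float], action: float) -> Tuple[List[float], float]:
    # 1-D point mass: the action shifts the position; reward is -|position|.
    next_state = [state[0] + action]
    return next_state, -abs(next_state[0])


rng = random.Random(0)
candidates = sample_candidates(horizon=5, n_planner=200, low=-1.0, high=1.0, rng=rng)
print(plan([2.0], toy_model, candidates))
```

Splitting sampling from evaluation keeps plan() deterministic for a fixed candidate set, which makes it easy to swap random shooting for another sampler (for example, CEM) without touching the rollout code.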
- actor_initial_states(obs)
Returns all actor inputs required to predict the initial action.
- Parameters
obs (torch.tensor) – Initial environment observation.
- Returns
obs (torch.tensor) – Initial environment observation.
rhs (dict) – Initial recurrent hidden state (will contain zeroes).
done (torch.tensor) – Initial done tensor, indicating the environment is not done.
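A minimal sketch of the (obs, rhs, done) triple described above, using plain lists in place of torch tensors. The dictionary key and the one-value-per-environment hidden-state layout are hypothetical; the real rhs contents depend on the actor's networks.

```python
from typing import Dict, List, Tuple

Batch = List[List[float]]


def actor_initial_states(obs: Batch) -> Tuple[Batch, Dict[str, List[float]], List[float]]:
    """Bundle the first observation with zeroed hidden states and a 'not done' flag."""
    batch_size = len(obs)
    # Hypothetical key name and layout: one zero per environment as a placeholder.
    rhs = {"placeholder": [0.0] * batch_size}
    # 0.0 means "episode not finished" for every environment in the batch.
    done = [0.0] * batch_size
    return obs, rhs, done


obs, rhs, done = actor_initial_states([[0.1, 0.2], [0.3, 0.4]])
print(rhs, done)
```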
- classmethod create_factory(input_space, action_space, algorithm_name, horizon, n_planner, restart_model=None, world_model_class=None, world_model_kwargs={})
Returns a function that creates ModelBasedPlannerActor instances.
- Parameters
horizon (int) – The horizon of online planning.
n_planner (int) – Number of parallel planned trajectories.
input_space (gym.Space) – Environment observation space.
action_space (gym.Space) – Environment action space.
algorithm_name (str) – Name of the RL algorithm used for learning.
restart_model (str) – Path to a previously trained Actor checkpoint to be loaded.
world_model_class (class) – PyTorch nn.Module to approximate world dynamics.
world_model_kwargs (dict) – Keyword arguments for the world model class.
- Returns
create_actor_instance – creates a new ModelBasedPlannerActor instance.
- Return type
func
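The factory pattern can be sketched with a toy class. Because the constructor takes a device argument that create_factory does not, the sketch assumes the returned function receives the device when it is called (for example, inside each worker); that assumption, the mapping of restart_model to checkpoint, and the ToyPlannerActor name are all hypothetical, not the pytorchrl implementation.

```python
from typing import Callable, Optional


class ToyPlannerActor:
    """Hypothetical stand-in for ModelBasedPlannerActor (no PyTorch required)."""

    def __init__(self, device: str, horizon: int, n_planner: int,
                 checkpoint: Optional[str] = None) -> None:
        self.device = device
        self.horizon = horizon
        self.n_planner = n_planner
        # A real actor would load weights from `checkpoint` here; the toy only records it.
        self.checkpoint = checkpoint

    @classmethod
    def create_factory(cls, horizon: int, n_planner: int,
                       restart_model: Optional[str] = None
                       ) -> Callable[[str], "ToyPlannerActor"]:
        # Capture the configuration now; the device is supplied later, when
        # each worker calls the returned function to build its own copy.
        def create_actor_instance(device: str) -> "ToyPlannerActor":
            return cls(device, horizon, n_planner, checkpoint=restart_model)

        return create_actor_instance


factory = ToyPlannerActor.create_factory(horizon=10, n_planner=50)
actor = factory("cpu")
print(actor.device, actor.horizon, actor.n_planner)
```

Returning a closure rather than an instance lets the same configuration be shipped to several processes, each of which builds an independent actor on its own device.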
- create_world_dynamics_model(world_model_class, world_model_kwargs)
Creates a world model instance and stores it as an attribute named world_model.
- Parameters
world_model_class (class) – WorldModel class
world_model_kwargs (dict) – Keyword arguments for the WorldModel class.
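A sketch of this step, assuming the method simply instantiates world_model_class with world_model_kwargs unpacked as keyword arguments. ToyWorldModel and PlannerSketch are hypothetical; the real method may also pass the observation and action spaces to the world model.

```python
from typing import Any, Dict, Optional, Type


class ToyWorldModel:
    """Hypothetical world model that only records its configuration."""

    def __init__(self, obs_dim: int, act_dim: int, hidden_size: int = 64) -> None:
        self.obs_dim = obs_dim
        self.act_dim = act_dim
        self.hidden_size = hidden_size


class PlannerSketch:
    """Hypothetical actor showing only the world-model construction step."""

    def __init__(self) -> None:
        self.world_model: Optional[Any] = None

    def create_world_dynamics_model(self, world_model_class: Type[Any],
                                    world_model_kwargs: Optional[Dict[str, Any]] = None) -> None:
        # Copy the kwargs so the caller's dict is never mutated; `None` as the
        # default avoids sharing one mutable `{}` across calls.
        kwargs = dict(world_model_kwargs or {})
        self.world_model = world_model_class(**kwargs)


planner = PlannerSketch()
planner.create_world_dynamics_model(ToyWorldModel, {"obs_dim": 3, "act_dim": 1})
print(type(planner.world_model).__name__, planner.world_model.hidden_size)
```

The documented signatures use world_model_kwargs={} as a default; the sketch uses None instead, the usual Python idiom for avoiding a mutable default argument.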
- get_prediction(obs, act, rhs, done, deterministic=False)
Predict and return the next states and rewards that result from taking act in obs.
- Parameters
obs (torch.tensor) – Current environment observation.
act (torch.tensor) – Action to take given obs.
rhs (dict) – Current recurrent hidden states.
done (torch.tensor) – Current done tensor, indicating if episode has finished.
deterministic (bool) – If True, take the mode of the predicted distribution; otherwise, sample from it.
- Returns
next_states (torch.Tensor) – Next states.
rewards (torch.Tensor) – Reward prediction.
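To illustrate the deterministic flag, the sketch below assumes the world model outputs a diagonal Gaussian over the next state plus a scalar reward; that output format and the names toy_gaussian_model and get_prediction here are hypothetical. rhs and done are omitted, and plain lists replace tensors. With deterministic=True the Gaussian's mode (its mean) is returned; otherwise each dimension is sampled.

```python
import random
from typing import List, Optional, Tuple


def toy_gaussian_model(obs: List[float], act: List[float]) -> Tuple[List[float], List[float], float]:
    """Hypothetical world model: Gaussian over the next state plus a scalar reward."""
    mean = [o + a for o, a in zip(obs, act)]   # next state drifts by the action
    std = [0.1] * len(obs)                      # fixed, state-independent noise
    reward = -sum(o * o for o in obs)           # penalize distance from the origin
    return mean, std, reward


def get_prediction(obs: List[float], act: List[float], deterministic: bool = False,
                   rng: Optional[random.Random] = None) -> Tuple[List[float], float]:
    """Return (next_state, reward); rhs and done are omitted in this sketch."""
    mean, std, reward = toy_gaussian_model(obs, act)
    if deterministic:
        # The mode of a Gaussian is its mean, so no randomness is involved.
        return list(mean), reward
    if rng is None:
        rng = random.Random()
    # Stochastic branch: draw each state dimension from N(mean, std^2).
    return [rng.gauss(m, s) for m, s in zip(mean, std)], reward


print(get_prediction([1.0, 2.0], [0.5, -1.0], deterministic=True))  # ([1.5, 1.0], -5.0)
```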
- property is_recurrent
Returns True if the actor networks are recurrent.
Size of policy recurrent hidden state
- training: bool