Model-Based Actor

class pytorchrl.agent.actors.model_based_planner_actor.ModelBasedPlannerActor(device, horizon, n_planner, input_space, action_space, algorithm_name, checkpoint=None, world_model_class=None, world_model_kwargs={})[source]

Bases: pytorchrl.agent.actors.base.Actor

Planner actor class for model-based (MB) agents.

actor_initial_states(obs)[source]

Returns all actor inputs required to predict the initial action.

Parameters

obs (torch.tensor) – Initial environment observation.

Returns

  • obs (torch.tensor) – Initial environment observation.

  • rhs (dict) – Initial recurrent hidden state (will contain zeroes).

  • done (torch.tensor) – Initial done tensor, indicating the environment is not done.

classmethod create_factory(input_space, action_space, algorithm_name, horizon, n_planner, restart_model=None, world_model_class=None, world_model_kwargs={})[source]

Returns a function that creates ModelBasedPlannerActor instances.

Parameters
  • horizon (int) – The horizon of online planning.

  • n_planner (int) – Number of parallel planned trajectories.

  • input_space (gym.Space) – Environment observation space.

  • action_space (gym.Space) – Environment action space.

  • algorithm_name (str) – Name of the RL algorithm used for learning.

  • restart_model (str) – Path to a previously trained Actor checkpoint to be loaded.

  • world_model_class (class) – PyTorch nn.Module to approximate world dynamics.

  • world_model_kwargs (dict) – Keyword arguments for the world model class.

Returns

create_actor_instance – Function that creates a new ModelBasedPlannerActor instance.

Return type

func
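
The factory pattern described above can be sketched in plain Python. This is a minimal illustration, not the library's implementation: ToyPlannerActor and create_factory_sketch are hypothetical stand-ins, only the argument names come from the signature above, and passing the device at call time is an assumption of the sketch.

```python
class ToyPlannerActor:
    """Hypothetical stand-in for ModelBasedPlannerActor (not the real class)."""

    def __init__(self, device, horizon, n_planner, world_model_kwargs=None):
        self.device = device
        self.horizon = horizon
        self.n_planner = n_planner
        # Copy so that instances never share one mutable kwargs dict.
        self.world_model_kwargs = dict(world_model_kwargs or {})


def create_factory_sketch(horizon, n_planner, world_model_kwargs=None):
    """Capture the configuration now, build the actor later."""

    def create_actor_instance(device):
        # Assumption: only the device is supplied when the actor is built.
        return ToyPlannerActor(device, horizon, n_planner, world_model_kwargs)

    return create_actor_instance


factory = create_factory_sketch(horizon=12, n_planner=500)
actor_a = factory("cpu")
actor_b = factory("cpu")
```

Each call to the returned function yields a fresh, independently configured instance, which is convenient when several workers each need their own actor.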

create_world_dynamics_model(world_model_class, world_model_kwargs)[source]

Creates a world model instance and stores it as an attribute named world_model.

Parameters
  • world_model_class (class) – World model class to instantiate.

  • world_model_kwargs (dict) – Keyword arguments passed to the world model class.
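
A minimal sketch of how a class and a keyword-argument dict might be combined into a stored model instance. ToyWorldModel, ToyActor, and the hidden_size and learning_rate arguments are hypothetical; a real world model would be a PyTorch nn.Module, and the attribute name used here is an assumption.

```python
class ToyWorldModel:
    """Hypothetical world model; a real one would be a torch.nn.Module."""

    def __init__(self, hidden_size=64, learning_rate=1e-3):
        self.hidden_size = hidden_size
        self.learning_rate = learning_rate


class ToyActor:
    """Hypothetical actor that owns a world model."""

    def create_world_dynamics_model(self, world_model_class, world_model_kwargs):
        # Instantiate the class with the given keyword arguments and keep the
        # instance on the actor so later predictions can use it.
        self.world_model = world_model_class(**world_model_kwargs)


actor = ToyActor()
actor.create_world_dynamics_model(ToyWorldModel, {"hidden_size": 128})
```

Arguments missing from the dict fall back to the defaults of the model class, so the dict only needs to list overrides.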

get_prediction(obs, act, rhs, done, deterministic=False)[source]

Predicts the next states and rewards that result from taking action act given observation obs.

Parameters
  • obs (torch.tensor) – Current environment observation.

  • act (torch.tensor) – Action to take given obs.

  • rhs (dict) – Current recurrent hidden states.

  • done (torch.tensor) – Current done tensor, indicating if episode has finished.

  • deterministic (bool) – If True, take the mode of the predicted distribution; if False, randomly sample from it.

Returns

  • next_states (torch.Tensor) – Next states.

  • rewards (torch.Tensor) – Reward prediction.
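
To illustrate how next-state and reward predictions can drive online planning with a horizon and a number of parallel trajectories, here is a random-shooting sketch in plain Python. It is an assumption-laden illustration: this reference does not state which planning algorithm the actor uses, toy_world_model and plan_first_action are hypothetical, and scalar floats stand in for tensors.

```python
import random


def toy_world_model(state, action):
    """Hypothetical 1-D dynamics: returns (next_state, reward)."""
    next_state = state + action
    reward = -abs(next_state)  # reward is highest when the state is at 0
    return next_state, reward


def plan_first_action(state, horizon, n_planner, rng):
    """Random shooting: sample n_planner action sequences of length horizon,
    score each one with the model, and return the first action of the best."""
    best_return, best_action = float("-inf"), None
    for _ in range(n_planner):
        actions = [rng.uniform(-1.0, 1.0) for _ in range(horizon)]
        current, total = state, 0.0
        for action in actions:
            current, reward = toy_world_model(current, action)
            total += reward
        if total > best_return:
            best_return, best_action = total, actions[0]
    return best_action


action = plan_first_action(state=2.0, horizon=3, n_planner=200, rng=random.Random(0))
```

Only the first action of the best sequence is returned; replanning from the next observation at every step is the usual way such a planner is run.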

property is_recurrent

Returns True if the actor network is recurrent.

property recurrent_hidden_state_size

Size of the policy recurrent hidden state.

training: bool