On-Policy Actor

class pytorchrl.agent.actors.on_policy_actor.OnPolicyActor(device, input_space, action_space, algorithm_name, checkpoint=None, recurrent_net=None, recurrent_net_kwargs={}, feature_extractor_network=None, feature_extractor_kwargs={}, shared_policy_value_network=True)

Bases: pytorchrl.agent.actors.base.Actor

Actor-critic class for on-policy algorithms.

It contains a policy network to predict the next action and a critic (value) network to predict the value score of a given observation.

Parameters
  • device (torch.device) – CPU or specific GPU where class computations will take place.

  • input_space (gym.Space) – Environment observation space.

  • action_space (gym.Space) – Environment action space.

  • algorithm_name (str) – Name of the RL algorithm used for learning.

  • checkpoint (str) – Path to a previously trained Actor checkpoint to be loaded.

  • recurrent_net (bool) – Whether to use an RNN on top of the feature extractors.

  • recurrent_net_kwargs (dict) – Keyword arguments for the memory network.

  • feature_extractor_network (nn.Module) – PyTorch nn.Module used as the feature extraction block in all networks.

  • feature_extractor_kwargs (dict) – Keyword arguments for the feature extractor network.

  • shared_policy_value_network (bool) – Whether or not to share weights between policy and value networks.

actor_initial_states(obs)

Returns all actor inputs required to predict the initial action.

Parameters

obs (torch.tensor) – Initial environment observation.

Returns

  • obs (torch.tensor) – Initial environment observation.

  • rhs (dict) – Initial recurrent hidden states.

  • done (torch.tensor) – Initial done tensor, indicating the environment is not done.
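A minimal plain-Python sketch of the three return values, assuming one observation row per parallel environment and hypothetical rhs key names ("policy", "value_net1"). Nested lists stand in for the torch tensors the real method creates on device, and the actual keys and shapes may differ.

```python
from typing import Dict, List, Tuple

Matrix = List[List[float]]


def actor_initial_states(
    obs: Matrix,
    hidden_size: int,
    rhs_names: Tuple[str, ...] = ("policy", "value_net1"),  # hypothetical keys
) -> Tuple[Matrix, Dict[str, Matrix], Matrix]:
    """Return (obs, rhs, done) for the first step of every environment."""
    num_envs = len(obs)  # one row of obs per parallel environment
    # Recurrent hidden states start at zero, one row per environment.
    rhs = {name: [[0.0] * hidden_size for _ in range(num_envs)] for name in rhs_names}
    # done is all zeros: no environment has finished yet.
    done = [[0.0] for _ in range(num_envs)]
    return obs, rhs, done


obs0 = [[0.1, 0.2], [0.3, 0.4]]  # 2 environments, 2-dim observations
obs, rhs, done = actor_initial_states(obs0, hidden_size=3)
print(sorted(rhs))       # ['policy', 'value_net1']
print(rhs["policy"][0])  # [0.0, 0.0, 0.0]
print(done)              # [[0.0], [0.0]]
```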

create_critic(name)

Create a critic value network and set it as a class attribute under the name given by name. This actor defines value networks as:

value = obs_feature_extractor + memory_net + v_prediction_layer

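and defines the shared policy-value network as:

policy = obs_feature_extractor + memory_net + action_distribution
value = obs_feature_extractor + memory_net + v_prediction_layer

where obs_feature_extractor and memory_net are shared, so their output feeds both the action_distribution and the v_prediction_layer heads.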

Parameters

name (str) – Critic network name.
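A toy sketch of the composition and naming scheme described above, assuming plain Python callables in place of torch modules (Scale, compose and SketchActor are hypothetical names). It registers the critic with setattr under the given name and shows the practical effect of sharing: when the trunk objects are shared, a change to the feature extractor (for example, from a policy update) also changes the critic's output.

```python
from typing import Callable, List

Vector = List[float]


class Scale:
    """Toy layer with one trainable weight, standing in for an nn.Module."""

    def __init__(self, weight: float = 1.0):
        self.weight = weight

    def __call__(self, x: Vector) -> Vector:
        return [self.weight * v for v in x]


def compose(*layers: Callable[[Vector], Vector]) -> Callable[[Vector], Vector]:
    """Chain layers left to right, like obs_feature_extractor + memory_net + head."""
    def net(x: Vector) -> Vector:
        for layer in layers:
            x = layer(x)
        return x
    return net


class SketchActor:
    """Hypothetical stand-in for OnPolicyActor (no torch, no real RNN)."""

    def __init__(self, shared_policy_value_network: bool = True):
        self.shared = shared_policy_value_network
        self.obs_feature_extractor = Scale(2.0)
        self.memory_net = Scale(1.0)  # identity placeholder for a recurrent net

    def create_critic(self, name: str) -> None:
        if self.shared:
            # Reuse the very same trunk objects the policy uses.
            fe, mem = self.obs_feature_extractor, self.memory_net
        else:
            # Independent copies: later changes to the policy trunk do not reach them.
            fe, mem = Scale(self.obs_feature_extractor.weight), Scale(self.memory_net.weight)

        def v_prediction_layer(x: Vector) -> Vector:
            return [sum(x)]  # toy scalar value head

        # Register the critic as an attribute called `name`, e.g. self.value_net1.
        setattr(self, name, compose(fe, mem, v_prediction_layer))


obs = [1.0, 2.0]
shared, separate = SketchActor(True), SketchActor(False)
shared.create_critic("value_net1")
separate.create_critic("value_net1")
print(shared.value_net1(obs))    # [6.0]
shared.obs_feature_extractor.weight = 3.0    # e.g. the effect of a policy update
separate.obs_feature_extractor.weight = 3.0
print(shared.value_net1(obs))    # [9.0]: the shared critic sees the change
print(separate.value_net1(obs))  # [6.0]: the separate critic does not
```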

classmethod create_factory(input_space, action_space, algorithm_name, restart_model=None, recurrent_net=None, recurrent_net_kwargs={}, feature_extractor_kwargs={}, feature_extractor_network=None, shared_policy_value_network=True)

Returns a function that creates actor-critic instances.

Parameters
  • input_space (gym.Space) – Environment observation space.

  • action_space (gym.Space) – Environment action space.

  • algorithm_name (str) – Name of the RL algorithm used for learning.

  • restart_model (str) – Path to a previously trained Actor checkpoint to be loaded.

  • feature_extractor_network (nn.Module) – PyTorch nn.Module used as the feature extraction block in all networks.

  • feature_extractor_kwargs (dict) – Keyword arguments for the feature extractor network.

  • recurrent_net (nn.Module) – PyTorch nn.Module to use after the feature extractors.

  • recurrent_net_kwargs (dict) – Keyword arguments for the memory network.

  • shared_policy_value_network (bool) – Whether or not to share weights between policy and value networks.

Returns

create_actor_critic_instance – creates a new OnPolicyActor class instance.

Return type

func
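The factory pattern defers construction: create_factory fixes the spaces, the algorithm name and the network options, and the returned function builds a fresh instance every time it is called. A minimal sketch, assuming the returned function takes the device as its only argument (the constructor requires device, but create_factory does not accept one) and omitting restart_model checkpoint loading; ToyActor is a hypothetical stand-in.

```python
from typing import Any, Callable


class ToyActor:
    """Hypothetical stand-in for OnPolicyActor that just records its arguments."""

    def __init__(self, device: str, input_space: Any, action_space: Any,
                 algorithm_name: str, **kwargs: Any):
        self.device = device
        self.input_space = input_space
        self.action_space = action_space
        self.algorithm_name = algorithm_name
        self.kwargs = kwargs

    @classmethod
    def create_factory(cls, input_space: Any, action_space: Any, algorithm_name: str,
                       **kwargs: Any) -> Callable[[str], "ToyActor"]:
        # Everything except the device is fixed now...
        def create_actor_critic_instance(device: str) -> "ToyActor":
            # ...and the device is chosen later, e.g. by each worker process.
            return cls(device, input_space, action_space, algorithm_name, **kwargs)
        return create_actor_critic_instance


factory = ToyActor.create_factory("obs_space", "act_space", "PPO",
                                  shared_policy_value_network=False)
cpu_actor = factory("cpu")
gpu_actor = factory("cuda:0")
print(cpu_actor.device, gpu_actor.device)  # cpu cuda:0
print(cpu_actor is gpu_actor)              # False
```

Passing a factory instead of an instance lets each consumer build its own copy on its own device.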

create_policy(name)

Create a policy network and set it as a class attribute under the name given by name. This actor defines the policy network as:

policy = obs_feature_extractor + memory_net + action_distribution

Parameters

name (str) – Policy network name.

evaluate_actions(obs, rhs, done, action)

Evaluate the log likelihood of action given obs under the current policy network. Also returns the entropy of the action distribution and the distribution itself.

Parameters
  • obs (torch.tensor) – Environment observation.

  • rhs (dict) – Recurrent hidden states.

  • done (torch.tensor) – Done tensor, indicating if episode has finished.

  • action (torch.tensor) – Evaluated action.

Returns

  • logp_action (torch.tensor) – Log probability of action according to the action distribution predicted with the current version of the policy_net.

  • entropy_dist (torch.tensor) – Entropy of the action distribution predicted with the current version of the policy_net.

  • dist (torch.Distribution) – Predicted probability distribution over next action.
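For a discrete action space, the two quantities reduce to a log-softmax lookup and the categorical entropy. A minimal sketch, assuming the policy network has already produced unnormalized logits for a single observation; evaluate_actions_categorical is a hypothetical helper, whereas the real method works on batched torch tensors and also returns the distribution object.

```python
import math
from typing import List, Tuple


def evaluate_actions_categorical(logits: List[float], action: int) -> Tuple[float, float]:
    """Return (logp_action, entropy) of a categorical distribution given logits."""
    # Numerically stable log-softmax: subtract the max logit before exponentiating.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(l - m) for l in logits))
    log_probs = [l - log_z for l in logits]
    logp_action = log_probs[action]
    # Entropy H = -sum_a p(a) * log p(a).
    entropy = -sum(math.exp(lp) * lp for lp in log_probs)
    return logp_action, entropy


logp, ent = evaluate_actions_categorical([0.0, 0.0], action=1)
print(round(logp, 4), round(ent, 4))  # -0.6931 0.6931
```

In PPO-style updates, for example, logp_action is compared with the log probability recorded when the action was collected, and exp(logp_new - logp_old) gives the importance ratio.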

get_action(obs, rhs, done, deterministic=False)

Predict and return next action, along with other information.

Parameters
  • obs (torch.tensor) – Current environment observation.

  • rhs (dict) – Current recurrent hidden states.

  • done (torch.tensor) – Current done tensor, indicating if episode has finished.

  • deterministic (bool) – If True, take the mode of the predicted distribution; otherwise, randomly sample the action from it.

Returns

  • action (torch.tensor) – Next action sampled.

  • clipped_action (torch.tensor) – Next action sampled, but clipped to be within the env action space.

  • logp_action (torch.tensor) – Log probability of action within the predicted action distribution.

  • rhs (dict) – Updated recurrent hidden states.

  • entropy_dist (torch.tensor) – Entropy of the predicted action distribution.

  • dist (torch.Distribution) – Predicted probability distribution over next action.
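For a continuous action space, a common choice is a Gaussian policy. The sketch below assumes a one-dimensional Gaussian with a given std and box bounds low and high (get_action_gaussian and all its inputs are hypothetical). It shows how deterministic switches between the mode and a random sample, and why two actions are returned: only clipped_action is sent to the environment, while logp_action refers to the unclipped action.

```python
import math
import random
from typing import Optional, Tuple


def get_action_gaussian(mean: float, std: float, low: float, high: float,
                        deterministic: bool = False,
                        rng: Optional[random.Random] = None) -> Tuple[float, float, float, float]:
    """Return (action, clipped_action, logp_action, entropy) for a 1-D Gaussian policy."""
    if rng is None:
        rng = random.Random()
    # The mode of a Gaussian is its mean; otherwise draw a random sample.
    action = mean if deterministic else rng.gauss(mean, std)
    # Clip only the copy that is sent to the environment.
    clipped_action = min(max(action, low), high)
    # Log density of the unclipped action under N(mean, std**2).
    logp_action = (-((action - mean) ** 2) / (2 * std ** 2)
                   - math.log(std) - 0.5 * math.log(2 * math.pi))
    # Differential entropy of a Gaussian: 0.5 * log(2 * pi * e * std**2).
    entropy = 0.5 * math.log(2 * math.pi * math.e * std ** 2)
    return action, clipped_action, logp_action, entropy


action, clipped, logp, ent = get_action_gaussian(mean=1.5, std=0.2, low=-1.0, high=1.0,
                                                 deterministic=True)
print(action, clipped)  # 1.5 1.0
```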

get_value(obs, rhs, done)

Return the value scores predicted by every critic for the given observation.

Parameters
  • obs (torch.tensor) – Environment observation.

  • rhs (dict) – Recurrent hidden states.

  • done (torch.tensor) – Done tensor, indicating if episode has finished.

Returns

output – Dict containing the value prediction from each critic under the keys “value_net1”, “value_net2”, etc., as well as the recurrent hidden states under the key “rhs”.

Return type

dict
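A sketch of how the documented output dict could be assembled, assuming each critic maps (obs, rhs) to a value and an updated rhs, and that the hidden states are threaded from one critic to the next. The toy critics (make_toy_critic is hypothetical) and the omission of done are simplifications for illustration.

```python
from typing import Callable, Dict, List, Tuple

# A critic maps (obs, rhs) to (value, updated rhs).
Critic = Callable[[List[float], Dict[str, float]], Tuple[float, Dict[str, float]]]


def get_value(obs: List[float], rhs: Dict[str, float],
              critics: Dict[str, Critic]) -> Dict[str, object]:
    """Query every critic and collect the results in the documented dict layout."""
    output: Dict[str, object] = {}
    for name, critic in critics.items():  # e.g. "value_net1", "value_net2"
        value, rhs = critic(obs, rhs)     # thread the hidden states through
        output[name] = value
    output["rhs"] = rhs
    return output


def make_toy_critic(scale: float) -> Critic:
    # Hypothetical critic: value = scale * sum(obs); rhs is passed through unchanged.
    return lambda obs, rhs: (scale * sum(obs), rhs)


critics = {"value_net1": make_toy_critic(1.0), "value_net2": make_toy_critic(0.5)}
out = get_value([1.0, 3.0], {"value_net1": 0.0}, critics)
print(out)  # {'value_net1': 4.0, 'value_net2': 2.0, 'rhs': {'value_net1': 0.0}}
```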

get_value_specific_net(obs, rhs, done, value_net_name)

Return value score for a single value network.

Parameters
  • obs (torch.tensor) – Environment observation.

  • rhs (dict) – Recurrent hidden states.

  • done (torch.tensor) – Done tensor, indicating if episode has finished.

  • value_net_name (str) – Name of the value network to query (e.g. “value_net1”).

Returns

  • value (torch.tensor) – Predicted value score.

  • rhs (dict) – Updated recurrent hidden states.
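Two details that the parameter list leaves implicit are sketched below under stated assumptions: the critic is looked up by its attribute name (matching how create_critic stores it), and done is used to zero the recurrent hidden state at episode boundaries, the usual convention in recurrent A2C/PPO code. ToyActor, ToyRecurrentCritic and the scalar running-sum memory are hypothetical stand-ins for the real torch modules.

```python
from typing import Dict, Tuple


class ToyRecurrentCritic:
    """Hypothetical critic whose 'memory' is a running sum instead of an RNN."""

    def __call__(self, obs: float, hidden: float, done: float) -> Tuple[float, float]:
        hidden = hidden * (1.0 - done)  # reset memory where an episode ended
        hidden = hidden + obs           # toy recurrent update
        value = 0.5 * hidden            # toy value head
        return value, hidden


class ToyActor:
    """Hypothetical stand-in for OnPolicyActor with a single critic."""

    def __init__(self) -> None:
        self.value_net1 = ToyRecurrentCritic()

    def get_value_specific_net(self, obs: float, rhs: Dict[str, float], done: float,
                               value_net_name: str) -> Tuple[float, Dict[str, float]]:
        net = getattr(self, value_net_name)  # look the critic up by attribute name
        rhs = dict(rhs)                      # do not mutate the caller's dict
        value, rhs[value_net_name] = net(obs, rhs[value_net_name], done)
        return value, rhs


actor = ToyActor()
rhs = {"value_net1": 0.0}
v1, rhs = actor.get_value_specific_net(2.0, rhs, 0.0, "value_net1")  # memory 2.0
v2, rhs = actor.get_value_specific_net(2.0, rhs, 0.0, "value_net1")  # memory 4.0
v3, rhs = actor.get_value_specific_net(2.0, rhs, 1.0, "value_net1")  # reset, then 2.0
print(v1, v2, v3)  # 1.0 2.0 1.0
```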

property is_recurrent

Returns True if the actor networks are recurrent.

property recurrent_hidden_state_size

Size of the policy recurrent hidden state.

training: bool