Off-Policy Actor

class pytorchrl.agent.actors.off_policy_actor.OffPolicyActor(device, input_space, action_space, algorithm_name, noise=None, checkpoint=None, sequence_overlap=0.5, recurrent_net=None, recurrent_net_kwargs={}, obs_feature_extractor=None, obs_feature_extractor_kwargs={}, act_feature_extractor=None, act_feature_extractor_kwargs={}, common_feature_extractor=None, common_feature_extractor_kwargs={}, num_critics=2)

Bases: pytorchrl.agent.actors.base.Actor

Actor-critic class for off-policy algorithms.

It contains a policy network (actor) that predicts the next action and one or more Q networks (critics), as set by num_critics.

Parameters
  • device (torch.device) – CPU or specific GPU where class computations will take place.

  • input_space (gym.Space) – Environment observation space.

  • action_space (gym.Space) – Environment action space.

  • algorithm_name (str) – Name of the RL algorithm used for learning.

  • checkpoint (str) – Path to a previously trained Actor checkpoint to be loaded.

  • noise (str) – Type of exploration noise that will be added to the deterministic actions.

  • obs_feature_extractor (nn.Module) – PyTorch nn.Module to extract features from observation in all networks.

  • obs_feature_extractor_kwargs (dict) – Keyword arguments for the obs extractor network.

  • act_feature_extractor (nn.Module) – PyTorch nn.Module to extract features from actions in all networks.

  • act_feature_extractor_kwargs (dict) – Keyword arguments for the act extractor network.

  • common_feature_extractor (nn.Module) – PyTorch nn.Module to extract joint features from the concatenation of action and observation features.

  • common_feature_extractor_kwargs (dict) – Keyword arguments for the common extractor network.

  • recurrent_net (bool) – Whether to use RNNs as feature extractors.

  • sequence_overlap (float) – From 0.0 to 1.0, how much consecutive rollout sequences will overlap.

  • recurrent_net_kwargs (dict) – Keyword arguments for the memory network.

  • num_critics (int) – Number of Q networks to be instantiated.
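
As a rough illustration of what sequence_overlap could mean in practice, the sketch below cuts one rollout into fixed-length sequences whose start indices advance by about (1 - sequence_overlap) times the sequence length. The helper name sequence_starts, its arguments, and the stride rule are assumptions made for this example and are not taken from the library.

```python
def sequence_starts(rollout_length, sequence_length, sequence_overlap=0.5):
    """Start indices of fixed-length sequences cut from one rollout.

    Hypothetical helper: the stride rule below is an assumption, not pytorchrl code.
    """
    if not 0.0 <= sequence_overlap <= 1.0:
        raise ValueError("sequence_overlap must be between 0.0 and 1.0")
    # Consecutive sequences share roughly sequence_overlap * sequence_length steps.
    # The stride is kept at 1 or more so that an overlap of 1.0 still advances.
    stride = max(1, int(round(sequence_length * (1.0 - sequence_overlap))))
    last_start = rollout_length - sequence_length
    return list(range(0, last_start + 1, stride))


half_overlap = sequence_starts(rollout_length=20, sequence_length=8, sequence_overlap=0.5)
no_overlap = sequence_starts(rollout_length=20, sequence_length=8, sequence_overlap=0.0)
print(half_overlap, no_overlap)
```

With an overlap of 0.5 each sequence starts halfway through the previous one, so every transition except those at the rollout edges appears in two sequences.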


actor_initial_states(obs)

Returns all actor inputs required to predict initial action.

Parameters

obs (torch.tensor) – Initial environment observation.

Returns

  • obs (torch.tensor) – Initial environment observation.

  • rhs (dict) – Initial recurrent hidden state (will contain zeroes).

  • done (torch.tensor) – Initial done tensor, indicating the environment is not done.

burn_in_recurrent_states(data_batch)

Applies a recurrent burn-in phase to data_batch, as described in https://openreview.net/pdf?id=r1lyTjAqYX. The first B steps are used to compute on-policy recurrent hidden states. data_batch is then updated by discarding the first B steps of every tensor.

Parameters

data_batch (dict) – Data batch containing all tensors required to compute the algorithm loss.

Returns

data_batch – Updated data batch after burn-in phase.

Return type

dict
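
The burn-in idea can be sketched without any deep learning library. In the example below a toy scalar recurrence stands in for the memory network, and data_batch is a dict of plain per-step lists. The key names ("obs", "done", "rhs", "rew"), the rnn_step cell, and the burn_in helper are all assumptions made for illustration; the real method operates on tensors and on the actor's own networks.

```python
import math


def rnn_step(hidden, obs, done):
    """Toy scalar recurrent cell standing in for the memory network (assumption)."""
    if done:
        hidden = 0.0  # reset the state at episode boundaries
    return math.tanh(0.5 * hidden + obs)


def burn_in(data_batch, burn_in_steps):
    """Recompute the hidden state over the first B steps, then drop those steps."""
    num_steps = len(data_batch["obs"])
    if not 0 <= burn_in_steps < num_steps:
        raise ValueError("burn_in_steps must be smaller than the sequence length")
    hidden = data_batch["rhs"][0]  # stored state at the start of the sequence
    for t in range(burn_in_steps):
        hidden = rnn_step(hidden, data_batch["obs"][t], data_batch["done"][t])
    # Keep only the steps after the burn-in window, in every per-step list.
    trimmed = {key: values[burn_in_steps:] for key, values in data_batch.items()}
    # The first remaining hidden state is replaced by the freshly computed one.
    trimmed["rhs"] = [hidden] + trimmed["rhs"][1:]
    return trimmed


batch = {
    "obs": [0.1, 0.2, 0.3, 0.4, 0.5],
    "done": [False, False, False, False, False],
    "rhs": [0.0, 0.0, 0.0, 0.0, 0.0],  # stale stored states
    "rew": [1.0, 1.0, 1.0, 1.0, 1.0],
}
out = burn_in(batch, burn_in_steps=2)
print(len(out["obs"]), out["rhs"][0])
```

The loss is then computed only on the remaining steps, starting from a hidden state produced by the current network parameters rather than the stale one stored at collection time.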

create_critic(name)

Create a critic Q network and define it as a class attribute under the name name. Observation and action features are extracted separately and concatenated before the common feature extractor. This actor defines Q networks as:

q = [obs_feature_extractor, act_feature_extractor] + common_feature_extractor + memory_net + q_prediction_layer

Parameters

name (str) – Critic network name.
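
To make the composition above concrete, here is a minimal PyTorch sketch of a Q network assembled in that order. It is not the library's implementation: the class name QNetworkSketch, the layer sizes, the use of linear layers for the extractors, and the choice of a GRU as the memory network are all assumptions.

```python
import torch
import torch.nn as nn


class QNetworkSketch(nn.Module):
    """Illustrative Q network; layer types and sizes are assumptions."""

    def __init__(self, obs_dim, act_dim, hidden=64):
        super().__init__()
        self.obs_feature_extractor = nn.Sequential(nn.Linear(obs_dim, hidden), nn.ReLU())
        self.act_feature_extractor = nn.Sequential(nn.Linear(act_dim, hidden), nn.ReLU())
        # The common extractor sees the concatenated observation and action features.
        self.common_feature_extractor = nn.Sequential(nn.Linear(2 * hidden, hidden), nn.ReLU())
        self.memory_net = nn.GRU(hidden, hidden, batch_first=True)
        self.q_prediction_layer = nn.Linear(hidden, 1)

    def forward(self, obs, act, rhs=None):
        obs_feat = self.obs_feature_extractor(obs)
        act_feat = self.act_feature_extractor(act)
        joint = self.common_feature_extractor(torch.cat([obs_feat, act_feat], dim=-1))
        # Treat each batch element as a sequence of length 1 for the GRU.
        out, rhs = self.memory_net(joint.unsqueeze(1), rhs)
        q = self.q_prediction_layer(out.squeeze(1))
        return q, rhs


net = QNetworkSketch(obs_dim=3, act_dim=2)
q, rhs = net(torch.zeros(4, 3), torch.zeros(4, 2))
print(q.shape, rhs.shape)
```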

classmethod create_factory(input_space, action_space, algorithm_name, noise=None, restart_model=None, sequence_overlap=0.5, recurrent_net_kwargs={}, recurrent_net=None, obs_feature_extractor=None, obs_feature_extractor_kwargs={}, act_feature_extractor=None, act_feature_extractor_kwargs={}, common_feature_extractor=<class 'pytorchrl.agent.actors.feature_extractors.mlp.MLP'>, common_feature_extractor_kwargs={}, num_critics=2)

Returns a function that creates actor critic instances.

Parameters
  • input_space (gym.Space) – Environment observation space.

  • action_space (gym.Space) – Environment action space.

  • algorithm_name (str) – Name of the RL algorithm used for learning.

  • noise (str) – Type of exploration noise that will be added to the deterministic actions.

  • obs_feature_extractor (nn.Module) – PyTorch nn.Module to extract features from observation in all networks.

  • obs_feature_extractor_kwargs (dict) – Keyword arguments for the obs extractor network.

  • act_feature_extractor (nn.Module) – PyTorch nn.Module to extract features from actions in all networks.

  • act_feature_extractor_kwargs (dict) – Keyword arguments for the act extractor network.

  • common_feature_extractor (nn.Module) – PyTorch nn.Module to extract joint features from the concatenation of action and observation features.

  • common_feature_extractor_kwargs (dict) – Keyword arguments for the common extractor network.

  • recurrent_net (bool) – Whether to use RNNs as feature extractors.

  • sequence_overlap (float) – From 0.0 to 1.0, how much consecutive rollout sequences will overlap.

  • recurrent_net_kwargs (dict) – Keyword arguments for the memory network.

  • num_critics (int) – Number of Q networks to be instantiated.

  • restart_model (str) – Path to a previously trained Actor checkpoint to be loaded.

Returns

create_actor_instance – creates a new OffPolicyActor class instance.

Return type

func
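
The factory pattern described here can be sketched with a plain closure. Note that create_factory takes no device argument while the constructor does, which suggests the device is supplied when the returned function is called; that calling convention is an assumption, as are the ToyActor stand-in class and the reduced argument list.

```python
class ToyActor:
    """Stand-in for an actor class; it only stores its constructor arguments."""

    def __init__(self, device, input_space, action_space, num_critics=2):
        self.device = device
        self.input_space = input_space
        self.action_space = action_space
        self.num_critics = num_critics


def create_factory(input_space, action_space, num_critics=2):
    """Capture the configuration now and build actor instances later."""

    def create_actor_instance(device):
        # Assumption: the device is the only argument needed at creation time.
        return ToyActor(device, input_space, action_space, num_critics)

    return create_actor_instance


factory = create_factory("obs_space", "act_space")
actor_a = factory("cpu")
actor_b = factory("cuda:0")
print(actor_a.device, actor_b.device)
```

Deferring construction this way lets several workers build their own copy of the same actor configuration, each on its own device.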

create_policy(name)

Create a policy network and define it as a class attribute under the name name. This actor defines the policy network as:

policy = obs_feature_extractor + common_feature_extractor + memory_net + action distribution

Parameters

name (str) – Policy network name.

evaluate_actions(obs, rhs, done, action)

Evaluate the log likelihood of action given obs under the current policy network. Also returns the entropy of the predicted action distribution.

Parameters
  • obs (torch.tensor) – Environment observation.

  • rhs (dict) – Recurrent hidden states.

  • done (torch.tensor) – Done tensor, indicating if episode has finished.

  • action (torch.tensor) – Evaluated action.

Returns

  • logp_action (torch.tensor) – Log probability of action according to the action distribution predicted with current version of the policy_net.

  • entropy_dist (torch.tensor) – Entropy of the action distribution predicted with current version of the policy_net.

  • dist (torch.Distribution) – Predicted probability distribution over next action.
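
As a worked example of the two scalar quantities returned above, the sketch below computes the log probability and the entropy for a diagonal Gaussian policy using only the standard library. The choice of a diagonal Gaussian is an assumption; the distribution actually used depends on the action space. The function names are hypothetical.

```python
import math


def gaussian_log_prob(action, mean, std):
    """Log density of `action` under independent Gaussians, summed over dimensions."""
    return sum(
        -0.5 * ((a - m) / s) ** 2 - math.log(s) - 0.5 * math.log(2.0 * math.pi)
        for a, m, s in zip(action, mean, std)
    )


def gaussian_entropy(std):
    """Entropy of independent Gaussians, summed over dimensions."""
    return sum(0.5 + 0.5 * math.log(2.0 * math.pi) + math.log(s) for s in std)


logp_action = gaussian_log_prob(action=[0.0, 0.0], mean=[0.0, 0.0], std=[1.0, 1.0])
entropy_dist = gaussian_entropy(std=[1.0, 1.0])
print(logp_action, entropy_dist)
```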

get_action(obs, rhs, done, deterministic=False)

Predict and return next action, along with other information.

Parameters
  • obs (torch.tensor) – Current environment observation.

  • rhs (dict) – Current recurrent hidden states.

  • done (torch.tensor) – Current done tensor, indicating if episode has finished.

  • deterministic (bool) – If True, take the mode of the predicted distribution; if False, randomly sample the action from it.

Returns

  • action (torch.tensor) – Next action sampled.

  • clipped_action (torch.tensor) – Next action sampled, but clipped to be within the env action space.

  • logp_action (torch.tensor) – Log probability of action within the predicted action distribution.

  • rhs (dict) – Updated recurrent hidden states.

  • entropy_dist (torch.tensor) – Entropy of the predicted action distribution.

  • dist (torch.Distribution) – Predicted probability distribution over next action.
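
The difference between action and clipped_action, and the effect of deterministic, can be illustrated with a small standard-library sketch. It assumes a Gaussian policy (whose mode is its mean) and a box-shaped action space given by per-dimension low and high bounds; select_action and its arguments are hypothetical names, not part of the library.

```python
import random


def select_action(mean, std, low, high, deterministic=False, rng=random):
    """Return (action, clipped_action) for a toy Gaussian policy."""
    if deterministic:
        action = list(mean)  # the mode of a Gaussian is its mean
    else:
        action = [rng.gauss(m, s) for m, s in zip(mean, std)]
    # Clip each dimension to the bounds of the action space.
    clipped_action = [min(max(a, lo), hi) for a, lo, hi in zip(action, low, high)]
    return action, clipped_action


action, clipped_action = select_action(
    mean=[2.0, -0.3], std=[0.1, 0.1], low=[-1.0, -1.0], high=[1.0, 1.0], deterministic=True
)
print(action, clipped_action)
```

Keeping both values is useful because the unclipped action is the one the log probability refers to, while the clipped one is what a bounded environment can accept.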

get_q_scores(obs, rhs, done, actions=None)

Return Q scores of the given observations and actions.

Parameters
  • obs (torch.tensor) – Environment observation.

  • rhs (dict) – Current recurrent hidden states.

  • done (torch.tensor) – Current done tensor, indicating if episode has finished.

  • actions (torch.tensor) – Evaluated actions.

Returns

output – Dict containing the value prediction from each critic under the keys “q1”, “q2”, etc., as well as the recurrent hidden states under the key “rhs”.

Return type

dict
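
One common way to consume a dict shaped like this is to take the element-wise minimum over all critics, as done in clipped double-Q learning (used by TD3 and SAC). The sketch below uses plain lists instead of tensors; min_q is a hypothetical helper, and whether a given pytorchrl algorithm combines the critics this way is an assumption.

```python
def min_q(output):
    """Element-wise minimum over every critic entry ("q1", "q2", ...) in `output`."""
    q_keys = [key for key in output if key.startswith("q") and key[1:].isdigit()]
    if not q_keys:
        raise ValueError("no critic predictions found in output")
    # zip(*...) groups the predictions of all critics for the same batch element.
    return [min(values) for values in zip(*(output[key] for key in q_keys))]


output = {"q1": [1.0, 0.5], "q2": [0.8, 0.9], "rhs": {}}
print(min_q(output))
```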

property is_recurrent

Returns True if the actor networks are recurrent.

property recurrent_hidden_state_size

Size of the policy recurrent hidden state.

training: bool