Off-Policy Actor
- class pytorchrl.agent.actors.off_policy_actor.OffPolicyActor(device, input_space, action_space, algorithm_name, noise=None, checkpoint=None, sequence_overlap=0.5, recurrent_net=None, recurrent_net_kwargs={}, obs_feature_extractor=None, obs_feature_extractor_kwargs={}, act_feature_extractor=None, act_feature_extractor_kwargs={}, common_feature_extractor=None, common_feature_extractor_kwargs={}, num_critics=2)
Bases: pytorchrl.agent.actors.base.Actor
Actor critic class for Off-Policy algorithms.
It contains a policy network (actor) that predicts the next action and one or more Q networks (critics), as set by num_critics.
- Parameters
device (torch.device) – CPU or specific GPU where class computations will take place.
input_space (gym.Space) – Environment observation space.
action_space (gym.Space) – Environment action space.
algorithm_name (str) – Name of the RL algorithm used for learning.
checkpoint (str) – Path to a previously trained Actor checkpoint to be loaded.
noise (str) – Type of exploration noise that will be added to the deterministic actions.
obs_feature_extractor (nn.Module) – PyTorch nn.Module to extract features from observation in all networks.
obs_feature_extractor_kwargs (dict) – Keyword arguments for the obs extractor network.
act_feature_extractor (nn.Module) – PyTorch nn.Module to extract features from actions in all networks.
act_feature_extractor_kwargs (dict) – Keyword arguments for the act extractor network.
common_feature_extractor (nn.Module) – PyTorch nn.Module to extract joint features from the concatenation of action and observation features.
common_feature_extractor_kwargs (dict) – Keyword arguments for the common extractor network.
recurrent_net (bool) – Whether to use RNNs as feature extractors.
sequence_overlap (float) – Fraction, from 0.0 to 1.0, by which consecutive rollout sequences overlap.
recurrent_net_kwargs (dict) – Keyword arguments for the memory network.
num_critics (int) – Number of Q networks to be instantiated.
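To make sequence_overlap concrete, the following sketch computes where consecutive fixed-length sequences would start within a rollout. It assumes that the overlap fraction is the share of each sequence's steps that it has in common with the next one; the helper sequence_starts is hypothetical and not part of pytorchrl.

```python
def sequence_starts(rollout_length, seq_len, overlap):
    """Hypothetical helper: start index of each fixed-length sequence in a rollout.

    Assumes `overlap` is the fraction of a sequence's steps shared with the next one.
    """
    shared = int(seq_len * overlap)      # steps shared by neighbouring sequences
    stride = max(1, seq_len - shared)    # always advance at least one step
    # Only sequences that fit completely inside the rollout are kept.
    return list(range(0, rollout_length - seq_len + 1, stride))


print(sequence_starts(20, 8, 0.5))  # [0, 4, 8, 12]
print(sequence_starts(20, 8, 0.0))  # [0, 8]
```

With overlap=0.5, each 8-step sequence shares 4 steps with its successor; with 0.0 the sequences are disjoint, and in this sketch any leftover steps at the end of the rollout are dropped.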
- actor_initial_states(obs)
Returns all actor inputs required to predict initial action.
- Parameters
obs (torch.tensor) – Initial environment observation.
- Returns
obs (torch.tensor) – Initial environment observation.
rhs (dict) – Initial recurrent hidden state (will contain zeroes).
done (torch.tensor) – Initial done tensor, indicating the environment is not done.
- burn_in_recurrent_states(data_batch)
Applies a recurrent burn-in phase to data_batch, as described in https://openreview.net/pdf?id=r1lyTjAqYX. The first B steps are used to compute on-policy recurrent hidden states; data_batch is then updated by discarding the first B steps from all tensors.
- Parameters
data_batch (dict) – Data batch containing all tensors required to compute the algorithm loss.
- Returns
data_batch – Updated data batch after burn-in phase.
- Return type
dict
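The burn-in idea can be sketched in plain Python. The cell toy_rnn_step and the helper burn_in are hypothetical, plain lists stand in for tensors, and done flags (episode boundaries) are ignored for brevity. In an actual implementation the burn-in unroll would typically use the current network without tracking gradients.

```python
def toy_rnn_step(h, x):
    """Hypothetical recurrent cell: an exponential moving average of the inputs."""
    return 0.5 * h + 0.5 * x


def burn_in(data_batch, burn_in_steps, stored_h=0.0):
    """Sketch of a burn-in phase over a dict of per-step sequences (lists)."""
    h = stored_h  # hidden state stored alongside the sequence
    for x in data_batch["obs"][:burn_in_steps]:
        h = toy_rnn_step(h, x)  # unroll only; these steps contribute no loss
    # Drop the burn-in steps from every per-step entry of the batch.
    updated = {key: seq[burn_in_steps:] for key, seq in data_batch.items()}
    updated["rhs"] = h  # refreshed hidden state for the remaining steps
    return updated


batch = {"obs": [1.0, 2.0, 3.0, 4.0, 5.0], "act": [0, 1, 0, 1, 0]}
print(burn_in(batch, burn_in_steps=2))
# {'obs': [3.0, 4.0, 5.0], 'act': [0, 1, 0], 'rhs': 1.25}
```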
- create_critic(name)
Create a critic Q network and define it as a class attribute under the name name. This actor defines Q networks as:
q = (obs_feature_extractor, act_feature_extractor) + common_feature_extractor + memory_net + q_prediction_layer
where the outputs of obs_feature_extractor and act_feature_extractor are concatenated before being passed to common_feature_extractor.
- Parameters
name (str) – Critic network name.
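The composition above can be illustrated with toy stand-in functions. Each function below is a hypothetical placeholder for the corresponding nn.Module (the real extractors are learned networks), and memory_net is modeled as an identity pass-through, mimicking a non-recurrent configuration (an assumption).

```python
def obs_feature_extractor(obs):
    return [2.0 * o for o in obs]            # hypothetical stand-in for a learned module


def act_feature_extractor(act):
    return [a + 1.0 for a in act]            # hypothetical stand-in for a learned module


def common_feature_extractor(x):
    return [max(0.0, v) for v in x]          # ReLU stand-in


def memory_net(x, rhs):
    return x, rhs                            # identity, i.e. no recurrence


def q_prediction_layer(x):
    return sum(x)                            # unit-weight linear head -> scalar Q


def q_network(obs, act, rhs=None):
    # Observation and action features are concatenated before the common extractor.
    joint = obs_feature_extractor(obs) + act_feature_extractor(act)
    features, rhs = memory_net(common_feature_extractor(joint), rhs)
    return q_prediction_layer(features), rhs


q, _ = q_network([1.0, -2.0], [0.0])
print(q)  # 3.0
```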
- classmethod create_factory(input_space, action_space, algorithm_name, noise=None, restart_model=None, sequence_overlap=0.5, recurrent_net_kwargs={}, recurrent_net=None, obs_feature_extractor=None, obs_feature_extractor_kwargs={}, act_feature_extractor=None, act_feature_extractor_kwargs={}, common_feature_extractor=<class 'pytorchrl.agent.actors.feature_extractors.mlp.MLP'>, common_feature_extractor_kwargs={}, num_critics=2)
Returns a function that creates actor critic instances.
- Parameters
input_space (gym.Space) – Environment observation space.
action_space (gym.Space) – Environment action space.
algorithm_name (str) – Name of the RL algorithm used for learning.
noise (str) – Type of exploration noise that will be added to the deterministic actions.
obs_feature_extractor (nn.Module) – PyTorch nn.Module to extract features from observation in all networks.
obs_feature_extractor_kwargs (dict) – Keyword arguments for the obs extractor network.
act_feature_extractor (nn.Module) – PyTorch nn.Module to extract features from actions in all networks.
act_feature_extractor_kwargs (dict) – Keyword arguments for the act extractor network.
common_feature_extractor (nn.Module) – PyTorch nn.Module to extract joint features from the concatenation of action and observation features.
common_feature_extractor_kwargs (dict) – Keyword arguments for the common extractor network.
recurrent_net (bool) – Whether to use RNNs as feature extractors.
sequence_overlap (float) – Fraction, from 0.0 to 1.0, by which consecutive rollout sequences overlap.
recurrent_net_kwargs (dict) – Keyword arguments for the memory network.
num_critics (int) – Number of Q networks to be instantiated.
restart_model (str) – Path to a previously trained Actor checkpoint to be loaded.
- Returns
create_actor_instance – creates a new OffPolicyActor class instance.
- Return type
func
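The factory pattern can be sketched with a hypothetical ToyActor class. Because the constructor takes device while create_factory does not, the sketch assumes the returned function takes the device as its only argument; check the pytorchrl source for the actual signature.

```python
class ToyActor:
    """Hypothetical stand-in for an actor class; not the pytorchrl implementation."""

    def __init__(self, device, input_space, action_space, num_critics=2):
        self.device = device
        self.input_space = input_space
        self.action_space = action_space
        self.num_critics = num_critics

    @classmethod
    def create_factory(cls, input_space, action_space, num_critics=2):
        def create_actor_instance(device):
            # Configuration is captured in the closure; only the device is
            # chosen at call time (an assumption for this sketch).
            return cls(device, input_space, action_space, num_critics=num_critics)

        return create_actor_instance


factory = ToyActor.create_factory("obs-space", "act-space")
actor_a = factory("cpu")
actor_b = factory("cuda:0")
```

A factory like this lets the same configuration be instantiated several times, for example once per worker or per device, without repeating the arguments.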
- create_policy(name)
Create a policy network and define it as a class attribute under the name name. This actor defines the policy network as:
policy = obs_feature_extractor + common_feature_extractor + memory_net + action distribution
- Parameters
name (str) – Policy network name.
- evaluate_actions(obs, rhs, done, action)
Evaluate the log likelihood of action given obs under the current policy network. Also returns the entropy of the predicted action distribution and the distribution itself.
- Parameters
obs (torch.tensor) – Environment observation.
rhs (dict) – Recurrent hidden states.
done (torch.tensor) – Done tensor, indicating if episode has finished.
action (torch.tensor) – Evaluated action.
- Returns
logp_action (torch.tensor) – Log probability of action according to the action distribution predicted with current version of the policy_net.
entropy_dist (torch.tensor) – Entropy of the action distribution predicted with current version of the policy_net.
dist (torch.Distribution) – Predicted probability distribution over next action.
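For a discrete action space, the quantities returned by evaluate_actions reduce to log p(a) and the entropy H(p) = -Σ p_i log p_i. The sketch below assumes a categorical distribution given as a plain list of probabilities; categorical_logp_and_entropy is a hypothetical helper, not a pytorchrl function.

```python
import math


def categorical_logp_and_entropy(probs, action):
    """Hypothetical helper: log p(action) and entropy of a categorical distribution."""
    logp_action = math.log(probs[action])
    # H(p) = -sum_i p_i * log p_i, skipping zero-probability entries (0 * log 0 = 0).
    entropy_dist = -sum(p * math.log(p) for p in probs if p > 0.0)
    return logp_action, entropy_dist


probs = [0.5, 0.25, 0.25]
logp_action, entropy_dist = categorical_logp_and_entropy(probs, action=1)
print(round(logp_action, 4), round(entropy_dist, 4))  # -1.3863 1.0397
```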
- get_action(obs, rhs, done, deterministic=False)
Predict and return next action, along with other information.
- Parameters
obs (torch.tensor) – Current environment observation.
rhs (dict) – Current recurrent hidden states.
done (torch.tensor) – Current done tensor, indicating if episode has finished.
deterministic (bool) – If True, take the mode of the predicted distribution; otherwise, randomly sample the action from it.
- Returns
action (torch.tensor) – Next action sampled.
clipped_action (torch.tensor) – Next action sampled, but clipped to be within the env action space.
logp_action (torch.tensor) – Log probability of action within the predicted action distribution.
rhs (dict) – Updated recurrent hidden states.
entropy_dist (torch.tensor) – Entropy of the predicted action distribution.
dist (torch.Distribution) – Predicted probability distribution over next action.
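The difference between action and clipped_action, and the effect of deterministic, can be sketched for a one-dimensional Gaussian policy head. The Gaussian head is an assumption (the actual distribution depends on the action space and algorithm), and select_action is hypothetical.

```python
import random


def select_action(mean, std, low, high, deterministic=False, rng=random):
    """Hypothetical sketch of get_action-style outputs for a 1-D Gaussian policy."""
    # The mode of a Gaussian is its mean; otherwise draw a random sample.
    action = mean if deterministic else rng.gauss(mean, std)
    # Clip to the environment's bounds; the unclipped action is returned too.
    clipped_action = min(max(action, low), high)
    return action, clipped_action


print(select_action(1.7, 0.3, low=-1.0, high=1.0, deterministic=True))  # (1.7, 1.0)
action, clipped = select_action(0.0, 1.0, low=-1.0, high=1.0, rng=random.Random(0))
```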
- get_q_scores(obs, rhs, done, actions=None)
Return Q scores of the given observations and actions.
- Parameters
obs (torch.tensor) – Environment observation.
rhs (dict) – Current recurrent hidden states.
done (torch.tensor) – Current done tensor, indicating if episode has finished.
actions (torch.tensor) – Evaluated actions.
- Returns
output – Dict containing the value prediction from each critic under keys “q1”, “q2”, etc., as well as the recurrent hidden states under the key “rhs”.
- Return type
dict
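Off-policy algorithms with two critics, such as TD3 and SAC, commonly take the element-wise minimum over the Q estimates (clipped double-Q learning) to reduce overestimation. Assuming the output dict is consumed that way, the hypothetical helper below collects every entry whose key starts with "q" and ignores "rhs"; plain lists stand in for tensors.

```python
def min_over_critics(output):
    """Element-wise minimum over every critic entry ("q1", "q2", ...) of the dict."""
    q_values = [value for key, value in output.items() if key.startswith("q")]
    return [min(per_critic) for per_critic in zip(*q_values)]


output = {"q1": [1.0, 0.2], "q2": [0.8, 0.5], "rhs": {}}
print(min_over_critics(output))  # [0.8, 0.2]
```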
- property is_recurrent
Returns True if the actor networks are recurrent.
Size of the policy recurrent hidden state.
- training: bool