import gym
import numpy as np
import torch
import torch.nn as nn
from collections import OrderedDict
import pytorchrl as prl
from pytorchrl.agent.actors.base import Actor
from pytorchrl.agent.actors.distributions import get_dist
from pytorchrl.agent.actors.utils import Scale, Unscale, init
from pytorchrl.agent.actors.feature_extractors import default_feature_extractor
class OnPolicyActor(Actor):
    """
    Actor critic class for On-Policy algorithms.

    It contains a policy network to predict next actions and one or more
    critic value networks to predict the value score of a given obs.

    Parameters
    ----------
    device: torch.device
        CPU or specific GPU where class computations will take place.
    input_space : gym.Space
        Environment observation space.
    action_space : gym.Space
        Environment action space.
    algorithm_name : str
        Name of the RL algorithm used for learning. It determines how many
        critics are created (``num_critics_ext`` and ``num_critics_int``).
    checkpoint : str
        Path to a previously trained Actor checkpoint to be loaded.
    recurrent_net : nn.Module
        PyTorch nn.Module class used as memory network on top of the feature
        extractors. It is instantiated as
        ``recurrent_net(feature_size, **recurrent_net_kwargs)``. If None (or
        any other falsy value) no memory network is used.
    recurrent_net_kwargs : dict
        Keyword arguments for the memory network. None means no kwargs.
    feature_extractor_network : nn.Module
        PyTorch nn.Module used as the features extraction block in all networks.
    feature_extractor_kwargs : dict
        Keyword arguments for the feature extractor network. None means no kwargs.
    shared_policy_value_network : bool
        Whether or not to share weights between policy and value networks.

    Raises
    ------
    ValueError
        If `algorithm_name` is not one of the supported algorithms and the
        critic counts have not been defined elsewhere.
    """

    def __init__(self,
                 device,
                 input_space,
                 action_space,
                 algorithm_name,
                 checkpoint=None,
                 recurrent_net=None,
                 recurrent_net_kwargs=None,
                 feature_extractor_network=None,
                 feature_extractor_kwargs=None,
                 shared_policy_value_network=True):

        super(OnPolicyActor, self).__init__(
            device=device,
            checkpoint=checkpoint,
            input_space=input_space,
            action_space=action_space)

        # None replaces the former mutable ``{}`` defaults, which were shared
        # between every instance created without explicit kwargs.
        if recurrent_net_kwargs is None:
            recurrent_net_kwargs = {}
        if feature_extractor_kwargs is None:
            feature_extractor_kwargs = {}

        self.recurrent_net = recurrent_net
        self.recurrent_net_kwargs = recurrent_net_kwargs
        self.feature_extractor_network = feature_extractor_network
        self.shared_policy_value_network = shared_policy_value_network
        self.feature_extractor_kwargs = feature_extractor_kwargs

        if algorithm_name in (prl.A2C, prl.PPO):
            self.num_critics_ext = 1
            self.num_critics_int = 0
        elif algorithm_name in (prl.RND_PPO,):
            # The trailing comma matters: ``(prl.RND_PPO)`` is just the
            # constant itself, which turns ``in`` into a substring test
            # when the constant is a string.
            self.num_critics_ext = 1
            self.num_critics_int = 1
        elif not (hasattr(self, "num_critics_ext")
                  and hasattr(self, "num_critics_int")):
            # The base class is not visible from here, so only fail when it
            # has not provided the critic counts either. This replaces an
            # opaque AttributeError raised by the loops below.
            raise ValueError(
                "Unsupported algorithm_name for OnPolicyActor: {!r}".format(
                    algorithm_name))

        # ----- Policy Network ------------------------------------------------

        self.create_policy("policy_net")

        # ----- Value Networks ------------------------------------------------

        for critic_name in self._critic_names():
            self.create_critic(critic_name)

    @classmethod
    def create_factory(
            cls,
            input_space,
            action_space,
            algorithm_name,
            restart_model=None,
            recurrent_net=None,
            recurrent_net_kwargs=None,
            feature_extractor_kwargs=None,
            feature_extractor_network=None,
            shared_policy_value_network=True):
        """
        Returns a function that creates actor critic instances.

        Parameters
        ----------
        input_space : gym.Space
            Environment observation space.
        action_space : gym.Space
            Environment action space.
        algorithm_name : str
            Name of the RL algorithm_name used for learning.
        restart_model : str
            Path to a previously trained Actor checkpoint to be loaded.
        feature_extractor_network : nn.Module
            PyTorch nn.Module used as the features extraction block in all networks.
        feature_extractor_kwargs : dict
            Keyword arguments for the feature extractor network. None means no kwargs.
        recurrent_net : nn.Module
            PyTorch nn.Module to use after the feature extractors.
        recurrent_net_kwargs : dict
            Keyword arguments for the memory network. None means no kwargs.
        shared_policy_value_network : bool
            Whether or not to share weights between policy and value networks.

        Returns
        -------
        create_actor_critic_instance : func
            creates a new OnPolicyActor class instance.
        """
        import logging

        logger = logging.getLogger(__name__)

        def create_actor_critic_instance(device):
            """Create and return an actor critic instance."""
            # Normalise None to a fresh dict here (not only in __init__) so
            # that ``cls`` always receives dicts, as it did before the
            # defaults were changed from ``{}`` to None.
            rnn_kwargs = {} if recurrent_net_kwargs is None else recurrent_net_kwargs
            fe_kwargs = {} if feature_extractor_kwargs is None else feature_extractor_kwargs

            policy = cls(device=device,
                         input_space=input_space,
                         action_space=action_space,
                         algorithm_name=algorithm_name,
                         recurrent_net=recurrent_net,
                         checkpoint=restart_model,
                         recurrent_net_kwargs=rnn_kwargs,
                         feature_extractor_kwargs=fe_kwargs,
                         feature_extractor_network=feature_extractor_network,
                         shared_policy_value_network=shared_policy_value_network)
            policy.to(device)

            try:
                policy.try_load_from_checkpoint()
            except RuntimeError as err:
                # Loading stays best-effort (the instance is still returned),
                # but the failure is reported instead of silently dropped.
                logger.warning(
                    "Could not load actor checkpoint %r: %s", restart_model, err)

            return policy

        return create_actor_critic_instance

    def _critic_names(self):
        """
        Return the attribute names of all critic networks.

        The ``num_critics_ext`` names ("value_net1", ...) come first, followed
        by the ``num_critics_int`` names ("ivalue_net1", ...).
        """
        names = ["value_net{}".format(i + 1) for i in range(self.num_critics_ext)]
        names.extend("ivalue_net{}".format(i + 1) for i in range(self.num_critics_int))
        return names

    def _infer_feature_size(self, feature_extractor):
        """
        Return the size of the last dimension of the output of `feature_extractor`.

        The size is measured by a forward pass on an all-zeros observation
        with batch size 1, shaped after `self.input_space`.
        """
        if isinstance(self.input_space, gym.spaces.Dict):
            dummy_obs = {
                k: torch.zeros(1, *self.input_space[k].shape)
                for k in self.input_space}
        else:
            dummy_obs = torch.zeros(1, *self.input_space.shape)

        features = feature_extractor(dummy_obs)

        # Also valid for nn.Identity extractors, where features is dummy_obs
        # itself; the tensor must never be used as the size.
        return features.shape[-1]

    @property
    def is_recurrent(self):
        """Returns True if the actor network are recurrent."""
        return bool(self.recurrent_net)

    @property
    def recurrent_hidden_state_size(self):
        """Size of policy recurrent hidden state"""
        return self.recurrent_size

    def actor_initial_states(self, obs):
        """
        Returns all actor inputs required to predict initial action.

        Parameters
        ----------
        obs : torch.tensor
            Initial environment observation. A dict of tensors is also
            accepted, in which case its first value defines batch size
            and device.

        Returns
        -------
        obs : torch.tensor
            Initial environment observation.
        rhs : dict
            Initial recurrent hidden states, under the key "policy" plus one
            key per critic network.
        done : torch.tensor
            Initial done tensor, indicating the environment is not done.
        """
        reference = next(iter(obs.values())) if isinstance(obs, dict) else obs
        num_proc = reference.size(0)
        dev = reference.device

        done = torch.zeros(num_proc, 1, device=dev)

        try:
            rhs_policy = self.policy_net.memory_net.get_initial_recurrent_state(
                num_proc).to(dev)
        except Exception:
            # Deliberate best-effort: any failure (e.g. a memory_net without
            # get_initial_recurrent_state, such as nn.Identity) falls back to
            # an all-zeros state.
            rhs_policy = torch.zeros(
                num_proc, self.recurrent_hidden_state_size).to(dev)

        rhs = {"policy": rhs_policy}
        # Every critic starts from its own copy of the policy initial state.
        for critic_name in self._critic_names():
            rhs[critic_name] = rhs_policy.clone()

        return obs, rhs, done

    def get_action(self, obs, rhs, done, deterministic=False):
        """
        Predict and return next action, along with other information.

        Parameters
        ----------
        obs : torch.tensor
            Current environment observation.
        rhs : dict
            Current recurrent hidden states. When the actor is recurrent,
            the entry "policy" is overwritten in place with the new state.
        done : torch.tensor
            Current done tensor, indicating if episode has finished.
        deterministic : bool
            Whether to randomly sample action from predicted distribution or take the mode.

        Returns
        -------
        action : torch.tensor
            Next action sampled.
        clipped_action : torch.tensor
            Next action sampled, but clipped to be within the env action space.
        logp_action : torch.tensor
            Log probability of `action` within the predicted action distribution.
        rhs : dict
            Updated recurrent hidden states.
        entropy_dist : torch.tensor
            Entropy of the predicted action distribution.
        dist : torch.Distribution
            Predicted probability distribution over next action.
        """
        features = self.policy_net.feature_extractor(obs)

        if self.recurrent_net:
            features, rhs["policy"] = self.policy_net.memory_net(
                features, rhs["policy"], done)

        (action, clipped_action, logp_action, entropy_dist, dist) = self.policy_net.dist(
            features, deterministic=deterministic)

        # Cached so that a shared policy-value network can predict values
        # from these features (see get_value_specific_net).
        self.last_action_features = features

        if self.unscale:
            action = self.unscale(action)
            clipped_action = self.unscale(clipped_action)

        return action, clipped_action, logp_action, rhs, entropy_dist, dist

    def evaluate_actions(self, obs, rhs, done, action):
        """
        Evaluate log likelihood of action given obs and the current
        policy network. Returns also entropy distribution.

        Parameters
        ----------
        obs : torch.tensor
            Environment observation.
        rhs : dict
            Recurrent hidden states. When the actor is recurrent, the entry
            "policy" is overwritten in place with the new state.
        done : torch.tensor
            Done tensor, indicating if episode has finished.
        action : torch.tensor
            Evaluated action.

        Returns
        -------
        logp_action : torch.tensor
            Log probability of `action` according to the action distribution
            predicted with current version of the policy_net.
        entropy_dist : torch.tensor
            Entropy of the action distribution predicted with current version
            of the policy_net.
        dist : torch.Distribution
            Predicted probability distribution over next action.
        """
        if self.scale:
            action = self.scale(action)

        features = self.policy_net.feature_extractor(obs)

        if self.recurrent_net:
            features, rhs["policy"] = self.policy_net.memory_net(
                features, rhs["policy"], done)

        logp_action, entropy_dist, dist = self.policy_net.dist.evaluate_pred(
            features, action)

        # Cached so that a shared policy-value network can predict values
        # from these features (see get_value_specific_net).
        self.last_action_features = features

        return logp_action, entropy_dist, dist

    def get_value_specific_net(self, obs, rhs, done, value_net_name):
        """
        Return value score for a single value network.

        Parameters
        ----------
        obs : torch.tensor
            Environment observation.
        rhs : dict
            Recurrent hidden states.
        done : torch.tensor
            Done tensor, indicating if episode has finished.
        value_net_name : str
            Name of the attribute holding the value network to evaluate.

        Returns
        -------
        value : torch.tensor
            Predicted value score.
        rhs : dict
            Updated recurrent hidden states.
        """
        value_net = getattr(self, value_net_name)

        if self.shared_policy_value_network:
            cached = self.last_action_features
            if cached is None or cached.shape[0] != done.shape[0]:
                # No cached policy features for this batch size: recompute
                # them. get_action indexes rhs["policy"] itself, so it must
                # receive the whole dict rather than the policy entry.
                self.get_action(obs, rhs, done)
            value = value_net.predictor(self.last_action_features)
        else:
            features = value_net.feature_extractor(obs)
            if self.recurrent_net:
                features, rhs[value_net_name] = value_net.memory_net(
                    features, rhs[value_net_name], done)
            value = value_net.predictor(features)

        return value, rhs

    def get_value(self, obs, rhs, done):
        """
        Return all value scores of given observation.

        Parameters
        ----------
        obs : torch.tensor
            Environment observation.
        rhs : dict
            Recurrent hidden states.
        done : torch.tensor
            Done tensor, indicating if episode has finished.

        Returns
        -------
        output : dict
            Dict containing value prediction from each critic under keys "value_net1", "value_net2", etc
            as well as the recurrent hidden states under the key "rhs".
        """
        outputs = {}
        for value_net_name in self._critic_names():
            value, rhs = self.get_value_specific_net(obs, rhs, done, value_net_name)
            outputs[value_net_name] = value

        outputs["rhs"] = rhs
        return outputs

    def create_critic(self, name):
        """
        Create a critic value network and define it as class attribute under the name `name`.

        This actor defines value networks as:

            value = obs_feature_extractor + memory_net + v_prediction_layer

        and defines shared policy-value network as:

                                                         action_distribution
            value = obs_feature_extractor + memory_net +
                                                         v_prediction_layer

        In the shared case the critic's own feature extractor and memory
        network are nn.Identity placeholders, and values are predicted from
        the features cached by the policy network.

        Parameters
        ----------
        name : str
            Critic network name.
        """
        if self.shared_policy_value_network:
            value_feature_extractor = nn.Identity()
            value_memory_net = nn.Identity()
            self.last_action_features = None
        else:
            # If feature_extractor_network not defined, take default one based on input_space
            feature_extractor = (
                self.feature_extractor_network
                or default_feature_extractor(self.input_space))

            # ---- 1. Define obs feature extractor ----------------------------

            value_feature_extractor = feature_extractor(
                self.input_space, **self.feature_extractor_kwargs)

            # ---- 2. Define memory network -----------------------------------

            # The dummy forward pass is run even without a memory network,
            # as before.
            feature_size = self._infer_feature_size(value_feature_extractor)

            if self.recurrent_net:
                value_memory_net = self.recurrent_net(
                    feature_size, **self.recurrent_net_kwargs)
            else:
                value_memory_net = nn.Identity()

        # ---- 3. Define value predictor --------------------------------------

        # Input size is the policy's recurrent_size (set by create_policy).
        value_predictor = init(
            nn.Linear(self.recurrent_size, 1),
            nn.init.orthogonal_,
            lambda x: nn.init.constant_(x, 0),
            gain=np.sqrt(0.01))

        # ---- 4. Concatenate all value net modules ---------------------------

        v_net = nn.Sequential(OrderedDict([
            ('feature_extractor', value_feature_extractor),
            ('memory_net', value_memory_net),
            ("predictor", value_predictor),
        ]))

        setattr(self, name, v_net)

    def create_policy(self, name):
        """
        Create a policy network and define it as class attribute under the name `name`.

        This actor defines policy network as:

            policy = obs_feature_extractor + memory_net + action_distribution

        It also sets `self.recurrent_size`, `self.scale` and `self.unscale`.

        Parameters
        ----------
        name : str
            Policy network name.

        Raises
        ------
        NotImplementedError
            If the action space is a gym.spaces.Dict.
        ValueError
            If the action space is of any other unsupported type.
        """
        # If feature_extractor_network not defined, take default one based on input_space
        feature_extractor = (
            self.feature_extractor_network
            or default_feature_extractor(self.input_space))

        # ---- 1. Define obs feature extractor --------------------------------

        policy_feature_extractor = feature_extractor(
            self.input_space, **self.feature_extractor_kwargs)

        # ---- 2. Define memory network ---------------------------------------

        feature_size = self._infer_feature_size(policy_feature_extractor)

        if self.recurrent_net:
            policy_memory_net = self.recurrent_net(
                feature_size, **self.recurrent_net_kwargs)
            self.recurrent_size = policy_memory_net.recurrent_hidden_state_size
        else:
            policy_memory_net = nn.Identity()
            self.recurrent_size = feature_size

        # ---- 3. Define action distribution ----------------------------------

        if isinstance(self.action_space, gym.spaces.Discrete):
            dist = get_dist("Categorical")(self.recurrent_size, self.action_space.n)
            self.scale = None
            self.unscale = None

        elif isinstance(self.action_space, gym.spaces.Box):  # Continuous action space
            dist = get_dist("Gaussian")(self.recurrent_size, self.action_space.shape[0])
            self.scale = Scale(self.action_space)
            self.unscale = Unscale(self.action_space)

        elif isinstance(self.action_space, gym.spaces.Dict):
            raise NotImplementedError

        else:
            raise ValueError("Unrecognized action space")

        # ---- 4. Concatenate all policy modules ------------------------------

        policy_net = nn.Sequential(OrderedDict([
            ('feature_extractor', policy_feature_extractor),
            ('memory_net', policy_memory_net),
            ('dist', dist),
        ]))

        setattr(self, name, policy_net)