import gym
import torch
import torch.nn as nn
import numpy as np
from collections import OrderedDict
import pytorchrl as prl
from pytorchrl.agent.actors.base import Actor
from pytorchrl.agent.actors.distributions import get_dist
from pytorchrl.agent.actors.utils import Scale, Unscale, init
from pytorchrl.agent.actors.feature_extractors import MLP, default_feature_extractor
[docs]class OffPolicyActor(Actor):
"""
Actor critic class for Off-Policy algorithms.
It contains a policy network (actor) to predict next actions and one or two
Q networks.
Parameters
----------
device: torch.device
CPU or specific GPU where class computations will take place.
input_space : gym.Space
Environment observation space.
action_space : gym.Space
Environment action space.
algorithm_name : str
Name of the RL algorithm used for learning.
checkpoint : str
Path to a previously trained Actor checkpoint to be loaded.
noise : str
Type of exploration noise that will be added to the deterministic actions.
obs_feature_extractor : nn.Module
PyTorch nn.Module to extract features from observation in all networks.
obs_feature_extractor_kwargs : dict
Keyword arguments for the obs extractor network.
act_feature_extractor : nn.Module
PyTorch nn.Module to extract features from actions in all networks.
act_feature_extractor_kwargs : dict
Keyword arguments for the act extractor network.
common_feature_extractor : nn.Module
PyTorch nn.Module to extract joint features from the concatenation of
action and observation features.
common_feature_extractor_kwargs : dict
Keyword arguments for the common extractor network.
recurrent_net : bool
Whether to use a RNNs as feature extractors.
sequence_overlap : float
From 0.0 to 1.0, how much consecutive rollout sequences will overlap.
recurrent_net_kwargs : dict
Keyword arguments for the memory network.
num_critics : int
Number of Q networks to be instantiated.
Examples
--------
"""
def __init__(self,
device,
input_space,
action_space,
algorithm_name,
noise=None,
checkpoint=None,
sequence_overlap=0.5,
recurrent_net=None,
recurrent_net_kwargs={},
obs_feature_extractor=None,
obs_feature_extractor_kwargs={},
act_feature_extractor=None,
act_feature_extractor_kwargs={},
common_feature_extractor=None,
common_feature_extractor_kwargs={},
num_critics=2):
super(OffPolicyActor, self).__init__(
device=device,
checkpoint=checkpoint,
input_space=input_space,
action_space=action_space)
self.noise = noise
self.algorithm_name = algorithm_name
self.input_space = input_space
self.action_space = action_space
self.act_feature_extractor = act_feature_extractor
self.act_feature_extractor_kwargs = act_feature_extractor_kwargs
self.obs_feature_extractor = obs_feature_extractor
self.obs_feature_extractor_kwargs = obs_feature_extractor_kwargs
self.common_feature_extractor = common_feature_extractor
self.common_feature_extractor_kwargs = common_feature_extractor_kwargs
self.recurrent_net = recurrent_net
self.recurrent_net_kwargs = recurrent_net_kwargs
self.sequence_overlap = np.clip(sequence_overlap, 0.0, 1.0)
self.num_critics = num_critics
self.deterministic = algorithm_name in [prl.DDPG, prl.TD3]
# ----- Policy Network ----------------------------------------------------
self.create_policy("policy_net")
# ----- Q Networks ----------------------------------------------------
for i in range(num_critics):
self.create_critic("q{}".format(i + 1))
[docs] @classmethod
def create_factory(
cls,
input_space,
action_space,
algorithm_name,
noise=None,
restart_model=None,
sequence_overlap=0.5,
recurrent_net_kwargs={},
recurrent_net=None,
obs_feature_extractor=None,
obs_feature_extractor_kwargs={},
act_feature_extractor=None,
act_feature_extractor_kwargs={},
common_feature_extractor=MLP,
common_feature_extractor_kwargs={},
num_critics=2
):
"""
Returns a function that creates actor critic instances.
Parameters
----------
input_space : gym.Space
Environment observation space.
action_space : gym.Space
Environment action space.
algorithm_name : str
Name of the RL algorithm used for learning.
noise : str
Type of exploration noise that will be added to the deterministic actions.
obs_feature_extractor : nn.Module
PyTorch nn.Module to extract features from observation in all networks.
obs_feature_extractor_kwargs : dict
Keyword arguments for the obs extractor network.
act_feature_extractor : nn.Module
PyTorch nn.Module to extract features from actions in all networks.
act_feature_extractor_kwargs : dict
Keyword arguments for the act extractor network.
common_feature_extractor : nn.Module
PyTorch nn.Module to extract joint features from the concatenation of
action and observation features.
common_feature_extractor_kwargs : dict
Keyword arguments for the common extractor network.
recurrent_net : bool
Whether to use a RNNs as feature extractors.
sequence_overlap : float
From 0.0 to 1.0, how much consecutive rollout sequences will overlap.
recurrent_net_kwargs : dict
Keyword arguments for the memory network.
num_critics : int
Number of Q networks to be instantiated.
restart_model : str
Path to a previously trained Actor checkpoint to be loaded.
Returns
-------
create_actor_instance : func
creates a new OffPolicyActor class instance.
"""
def create_actor_instance(device):
"""Create and return an actor critic instance."""
policy = cls(noise=noise,
device=device,
input_space=input_space,
action_space=action_space,
algorithm_name=algorithm_name,
checkpoint=restart_model,
sequence_overlap=sequence_overlap,
recurrent_net_kwargs=recurrent_net_kwargs,
recurrent_net=recurrent_net,
obs_feature_extractor=obs_feature_extractor,
obs_feature_extractor_kwargs=obs_feature_extractor_kwargs,
act_feature_extractor=act_feature_extractor,
act_feature_extractor_kwargs=act_feature_extractor_kwargs,
common_feature_extractor=common_feature_extractor,
common_feature_extractor_kwargs=common_feature_extractor_kwargs,
num_critics=num_critics)
policy.to(device)
try:
policy.try_load_from_checkpoint()
except RuntimeError:
pass
return policy
return create_actor_instance
@property
def is_recurrent(self):
"""Returns True if the actor network are recurrent."""
return self.recurrent_net
@property
def recurrent_hidden_state_size(self):
"""Size of policy recurrent hidden state"""
return self.recurrent_size
[docs] def actor_initial_states(self, obs):
"""
Returns all actor inputs required to predict initial action.
Parameters
----------
obs : torch.tensor
Initial environment observation.
Returns
-------
obs : torch.tensor
Initial environment observation.
rhs : dict
Initial recurrent hidden state (will contain zeroes).
done : torch.tensor
Initial done tensor, indicating the environment is not done.
"""
if isinstance(obs, dict):
num_proc = list(obs.values())[0].size(0)
dev = list(obs.values())[0].device
else:
num_proc = obs.size(0)
dev = obs.device
done = torch.zeros(num_proc, 1).to(dev)
try:
rhs_act = self.policy_net.memory_net.get_initial_recurrent_state(num_proc).to(dev)
except Exception:
rhs_act = torch.zeros(num_proc, self.recurrent_hidden_state_size).to(dev)
rhs = {"rhs_act": rhs_act}
rhs.update({"rhs_q{}".format(i + 1): rhs_act.clone() for i in range(self.num_critics)})
return obs, rhs, done
[docs] def burn_in_recurrent_states(self, data_batch):
"""
Applies a recurrent burn-in phase to data_batch as described in
(https://openreview.net/pdf?id=r1lyTjAqYX). Initial B steps are used
to compute on-policy recurrent hidden states. data_batch is then
updated, discarding B first steps in all tensors.
Parameters
----------
data_batch : dict
data batch containing all required tensors to compute Algorithm loss.
Returns
-------
data_batch : dict
Updated data batch after burn-in phase.
"""
# (T, N, -1) tensors that have been flatten to (T * N, -1)
N = data_batch[prl.RHS]["rhs_act"].shape[0] # number of sequences
T = int(data_batch[prl.DONE].shape[0] / N) # sequence lengths
B = int(self.sequence_overlap * T) # sequence burn-in length
if B == 0:
return data_batch
# Split tensors into burn-in and no-burn-in
chunk_sizes = [B, T - B] * N
burn_in_data = {k: {} for k in data_batch}
non_burn_in_data = {k: {} for k in data_batch}
for k, v in data_batch.items():
if k in (prl.RHS, prl.RHS2):
burn_in_data[k] = v
continue
if not isinstance(v, (torch.Tensor, dict)):
non_burn_in_data[k] = v
continue
if isinstance(v, dict):
for x, y in v.items():
if not isinstance(y, torch.Tensor):
non_burn_in_data[k][x] = v
continue
sequence_slices = torch.split(y, chunk_sizes)
burn_in_data[k][x] = torch.cat(sequence_slices[0::2])
non_burn_in_data[k][x] = torch.cat(sequence_slices[1::2])
else:
sequence_slices = torch.split(v, chunk_sizes)
burn_in_data[k] = torch.cat(sequence_slices[0::2])
non_burn_in_data[k] = torch.cat(sequence_slices[1::2])
# Do burn-in
with torch.no_grad():
act, _, _, rhs, _, _ = self.get_action(
burn_in_data[prl.OBS], burn_in_data[prl.RHS], burn_in_data[prl.DONE])
act2, _, _, rhs2, _, _ = self.get_action(
burn_in_data[prl.OBS2], burn_in_data[prl.RHS2], burn_in_data[prl.DONE2])
rhs = self.get_q_scores(
burn_in_data[prl.OBS], rhs, burn_in_data[prl.DONE], act).get("rhs")
rhs2 = self.get_q_scores(
burn_in_data[prl.OBS2], rhs2, burn_in_data[prl.DONE2], act2).get("rhs")
for k in rhs:
rhs[k] = rhs[k].detach()
for k in rhs2:
rhs2[k] = rhs2[k].detach()
non_burn_in_data[prl.RHS] = rhs
non_burn_in_data[prl.RHS2] = rhs2
return non_burn_in_data
[docs] def get_action(self, obs, rhs, done, deterministic=False):
"""
Predict and return next action, along with other information.
Parameters
----------
obs : torch.tensor
Current environment observation.
rhs : dict
Current recurrent hidden states.
done : torch.tensor
Current done tensor, indicating if episode has finished.
deterministic : bool
Whether to randomly sample action from predicted distribution or take the mode.
Returns
-------
action : torch.tensor
Next action sampled.
clipped_action : torch.tensor
Next action sampled, but clipped to be within the env action space.
logp_action : torch.tensor
Log probability of `action` within the predicted action distribution.
rhs : dict
Updated recurrent hidden states.
entropy_dist : torch.tensor
Entropy of the predicted action distribution.
dist : torch.Distribution
Predicted probability distribution over next action.
"""
x = self.policy_net.common_feature_extractor(self.policy_net.obs_feature_extractor(obs))
if self.recurrent_net:
x, rhs["rhs_act"] = self.policy_net.memory_net(x, rhs["rhs_act"], done)
(action, clipped_action, logp_action, entropy_dist, dist) = self.policy_net.dist(
x, deterministic=deterministic)
if self.unscale:
action = self.unscale(action)
clipped_action = self.unscale(clipped_action)
return action, clipped_action, logp_action, rhs, entropy_dist, dist
[docs] def evaluate_actions(self, obs, rhs, done, action):
"""
Evaluate log likelihood of action given obs and the current
policy network. Returns also entropy distribution.
Parameters
----------
obs : torch.tensor
Environment observation.
rhs : dict
Recurrent hidden states.
done : torch.tensor
Done tensor, indicating if episode has finished.
action : torch.tensor
Evaluated action.
Returns
-------
logp_action : torch.tensor
Log probability of `action` according to the action distribution
predicted with current version of the policy_net.
entropy_dist : torch.tensor
Entropy of the action distribution predicted with current version
of the policy_net.
dist : torch.Distribution
Predicted probability distribution over next action.
"""
if self.scale:
action = self.scale(action)
features = self.policy_net.common_feature_extractor(self.policy_net.obs_feature_extractor(obs))
if self.recurrent_net:
features, rhs["rhs_act"] = self.policy_net.memory_net(features, rhs["rhs_act"], done)
logp_action, entropy_dist, dist = self.policy_net.dist.evaluate_pred(features, action)
return logp_action, entropy_dist, dist
[docs] def get_q_scores(self, obs, rhs, done, actions=None):
"""
Return Q scores of the given observations and actions.
Parameters
----------
obs : torch.tensor
Environment observation.
rhs : dict
Current recurrent hidden states.
done : torch.tensor
Current done tensor, indicating if episode has finished.
actions : torch.tensor
Evaluated actions.
Returns
-------
output : dict
Dict containing value prediction from each critic under keys "q1", "q2", etc
as well as the recurrent hidden states under the key "rhs".
"""
outputs = {}
for i in range(self.num_critics):
q = getattr(self, "q{}".format(i + 1))
features = q.obs_feature_extractor(obs)
if actions is not None:
act_features = q.act_feature_extractor(actions)
features = torch.cat([features, act_features], -1)
features = q.common_feature_extractor(features)
if self.recurrent_net:
features, rhs["rhs_q{}".format(1 + 1)] = q.memory_net(
features, rhs["rhs_q{}".format(i + 1)], done)
q_scores = q.predictor(features)
outputs["q{}".format(i + 1)] = q_scores
outputs["rhs"] = rhs
return outputs
[docs] def create_critic(self, name):
"""
Create a critic q network and define it as class attribute under the name `name`.
This actor defines defines q networks as:
obs_feature_extractor
q = + common_feature_extractor + memory_net + q_prediction_layer
act_feature_extractor
Parameters
----------
name : str
Critic network name.
"""
if self.obs_feature_extractor:
obs_feature_extractor = self.obs_feature_extractor
else:
obs_feature_extractor = default_feature_extractor(self.input_space)
# ---- 1. Define action feature extractor -----------------------------
act_extractor = self.act_feature_extractor or nn.Identity
q_act_feature_extractor = act_extractor(
self.action_space, **self.act_feature_extractor_kwargs)
# ---- 2. Define obs feature extractor -----------------------------
obs_extractor = obs_feature_extractor or nn.Identity
q_obs_feature_extractor = obs_extractor(
self.input_space, **self.obs_feature_extractor_kwargs)
obs_feature_size = q_obs_feature_extractor(
torch.zeros(1, *self.input_space.shape)).shape[-1]
# ---- 3. Define shared feature extractor -----------------------------
if isinstance(self.action_space, gym.spaces.Discrete):
act_feature_size = 0
q_outputs = self.action_space.n
elif isinstance(self.action_space, gym.spaces.Box):
act_feature_size = q_act_feature_extractor(
torch.zeros(1, *self.action_space.shape)).shape[-1] if self.act_feature_extractor \
else self.action_space.shape[-1]
q_outputs = 1
else:
raise NotImplementedError
feature_size = obs_feature_size + act_feature_size
if self.common_feature_extractor:
q_common_feature_extractor = self.common_feature_extractor(feature_size, **self.common_feature_extractor_kwargs)
feature_size = q_common_feature_extractor(torch.zeros(1, feature_size)).shape[-1]
else:
q_common_feature_extractor = nn.Identity()
# ---- 4. Define memory network ---------------------------------------
q_memory_net = self.recurrent_net(feature_size, **self.recurrent_net_kwargs) if\
self.recurrent_net else nn.Identity()
feature_size = q_memory_net.recurrent_hidden_state_size if self.recurrent_net\
else feature_size
# ---- 5. Define prediction layer -------------------------------------
init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init.constant_(x, 0))
q_predictor = init_(nn.Linear(feature_size, q_outputs))
# ---- 6. Concatenate all q1 net modules ------------------------------
q_net = nn.Sequential(OrderedDict([
('obs_feature_extractor', q_obs_feature_extractor),
('act_feature_extractor', q_act_feature_extractor),
('common_feature_extractor', q_common_feature_extractor),
('memory_net', q_memory_net),
("predictor", q_predictor),
]))
setattr(self, name, q_net)
[docs] def create_policy(self, name):
"""
Create a policy network and define it as class attribute under the name `name`.
This actor defines policy network as:
policy = obs_feature_extractor + common_feature_extractor + memory_net + action distribution
Parameters
----------
name : str
Policy network name.
"""
# ---- 1. Define Obs feature extractor --------------------------------
if self.obs_feature_extractor:
obs_feature_extractor = self.obs_feature_extractor
else:
obs_feature_extractor = default_feature_extractor(self.input_space)
obs_extractor = obs_feature_extractor or nn.Identity
policy_obs_feature_extractor = obs_extractor(
self.input_space, **self.obs_feature_extractor_kwargs)
# ---- 2. Define Common feature extractor -----------------------------
feature_size = policy_obs_feature_extractor(torch.zeros(1, *self.input_space.shape)).shape[-1]
if self.common_feature_extractor:
policy_common_feature_extractor = self.common_feature_extractor(
feature_size, **self.common_feature_extractor_kwargs)
feature_size = policy_common_feature_extractor(torch.zeros(1, feature_size)).shape[-1]
else:
policy_common_feature_extractor = nn.Identity()
# ---- 3. Define memory network --------------------------------------
if self.recurrent_net:
policy_memory_net = self.recurrent_net(feature_size, **self.recurrent_net_kwargs)
self.recurrent_size = policy_memory_net.recurrent_hidden_state_size
else:
policy_memory_net = nn.Identity()
self.recurrent_size = feature_size
# ---- 4. Define action distribution ----------------------------------
if isinstance(self.action_space, gym.spaces.Discrete):
dist = get_dist("Categorical")(self.recurrent_size, self.action_space.n)
self.scale = None
self.unscale = None
elif isinstance(self.action_space, gym.spaces.Box) and not self.deterministic:
if self.algorithm_name in [prl.SAC]:
dist = get_dist("SquashedGaussian")(self.recurrent_size, self.action_space.shape[0])
else:
dist = get_dist("Gaussian")(self.recurrent_size, self.action_space.shape[0])
self.scale = Scale(self.action_space)
self.unscale = Unscale(self.action_space)
elif isinstance(self.action_space, gym.spaces.Box) and self.deterministic:
dist = get_dist("Deterministic")(
self.recurrent_size, self.action_space.shape[0], noise=self.noise)
self.scale = Scale(self.action_space)
self.unscale = Unscale(self.action_space)
else:
raise NotImplementedError
# ---- 5. Concatenate all policy modules ------------------------------
policy_net = nn.Sequential(OrderedDict([
('obs_feature_extractor', policy_obs_feature_extractor),
('common_feature_extractor', policy_common_feature_extractor),
('memory_net', policy_memory_net), ('dist', dist),
]))
setattr(self, name, policy_net)