import itertools
import numpy as np
from copy import deepcopy
from scipy.optimize import minimize
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.kl import kl_divergence
import pytorchrl as prl
from pytorchrl.agent.algorithms.base import Algorithm
from pytorchrl.agent.algorithms.utils import get_gradients, set_gradients, gaussian_kl
from pytorchrl.agent.algorithms.policy_loss_addons import PolicyLossAddOn
class MPO(Algorithm):
    """
    Maximum a Posteriori Policy Optimization algorithm class.

    Algorithm class to execute MPO, from A Abdolmaleki et al.
    (https://arxiv.org/abs/1806.06920). Algorithms are modules generally
    required by multiple workers, so MPO.create_factory(...) returns a function
    that can be passed on to workers to instantiate their own MPO module.

    This code has been adapted from https://github.com/daisatojp/mpo.

    References
    ----------
    [1] A. Abdolmaleki et al. "Maximum a Posteriori Policy Optimisation".
        https://arxiv.org/abs/1806.06920
    [2] A. Abdolmaleki et al. "Relative Entropy Regularized Policy Iteration".
        https://arxiv.org/abs/1812.02256

    Parameters
    ----------
    device : torch.device
        CPU or specific GPU where class computations will take place.
    envs : VecEnv
        Vector of environments instance.
    actor : Actor
        Actor_critic class instance. Must expose a ``q1`` critic and a
        ``policy_net``.
    lr_q : float
        Q-net optimizer learning rate.
    lr_pi : float
        Policy optimizer learning rate.
    gamma : float
        Discount factor parameter.
    polyak : float
        Polyak averaging coefficient for target network updates
        (``targ = polyak * targ + (1 - polyak) * current``). With 1.0 the
        target networks are never changed by the averaging step.
    num_updates : int
        Num consecutive actor updates before data collection continues.
    update_every : int
        Regularity of actor updates in number environment steps.
    test_every : int
        Regularity of test evaluations in actor updates.
    start_steps : int
        Num of initial random environment steps before learning starts.
    mini_batch_size : int
        Size of actor update batches.
    num_test_episodes : int
        Number of episodes to complete in each test phase.
    target_update_interval : float
        Regularity of target nets updates with respect to actor Adam updates.
    dual_constraint : float
        Hard constraint of the dual formulation in the E-step corresponding to [2] p.4 ε.
    kl_mean_constraint : float
        Hard constraint of the mean in the M-step corresponding to [2] p.6 ε_μ for continuous action space.
    kl_var_constraint : float
        Hard constraint of the covariance in the M-step corresponding to [2] p.6 ε_Σ for continuous action space.
    kl_constraint : float
        Hard constraint in the M-step corresponding to [2] p.6 ε_π for discrete action space.
    alpha_scale : float
        Scaling factor of the Lagrange multiplier update in the M-step for discrete action spaces.
    alpha_mean_scale : float
        Scaling factor of the mean Lagrange multiplier update in the M-step for continuous action spaces.
    alpha_var_scale : float
        Scaling factor of the variance Lagrange multiplier update in the M-step for continuous action spaces.
    alpha_mean_max : float
        Upper bound used for clipping the mean Lagrange multiplier in continuous action spaces.
    alpha_var_max : float
        Upper bound used for clipping the variance Lagrange multiplier in continuous action spaces.
    alpha_max : float
        Upper bound used for clipping the Lagrange multiplier in discrete action spaces.
    mstep_iterations : int
        The number of iterations of the M-step.
    sample_action_num : int
        For continuous action spaces, number of samples used to compute expected Q scores.
    max_grad_norm : float
        Gradient clipping parameter.
    policy_loss_addons : PolicyLossAddOn or list or None
        PolicyLossAddOn component(s) adding loss terms to the algorithm policy loss.

    Examples
    --------
    >>> create_algo = MPO.create_factory(
    ...     lr_q=1e-4, lr_pi=1e-4, gamma=0.99, polyak=0.995,
    ...     num_updates=50, update_every=50, test_every=5000, start_steps=20000,
    ...     mini_batch_size=64, num_test_episodes=0, target_update_interval=1)
    """

    def __init__(self,
                 device,
                 envs,
                 actor,
                 lr_q=1e-4,
                 lr_pi=1e-4,
                 gamma=0.99,
                 polyak=1.0,
                 num_updates=1,
                 update_every=50,
                 test_every=1000,
                 start_steps=20000,
                 mini_batch_size=64,
                 num_test_episodes=5,
                 target_update_interval=1,
                 dual_constraint=0.1,
                 kl_mean_constraint=0.01,
                 kl_var_constraint=0.0001,
                 kl_constraint=0.01,
                 alpha_scale=10.0,
                 alpha_mean_scale=1.0,
                 alpha_var_scale=100.0,
                 alpha_mean_max=0.1,
                 alpha_var_max=10.0,
                 alpha_max=1.0,
                 mstep_iterations=5,
                 sample_action_num=64,
                 max_grad_norm=0.1,
                 policy_loss_addons=None):

        # ---- General algo attributes ----------------------------------------

        # Discount factor
        self._gamma = gamma

        # Number of steps collected with initial random policy
        self._start_steps = int(start_steps)

        # Times data in the buffer is re-used before data collection proceeds
        self._num_epochs = int(1)  # Default to 1 for off-policy algorithms

        # Number of data samples collected between network update stages
        self._update_every = int(update_every)

        # Number mini batches per epoch
        self._num_mini_batch = int(num_updates)

        # Size of update mini batches
        self._mini_batch_size = int(mini_batch_size)

        # Number of network updates between test evaluations
        self._test_every = int(test_every)

        # Number of episodes to complete when testing
        self._num_test_episodes = int(num_test_episodes)

        # ---- MPO-specific attributes ----------------------------------------

        self.iter = 0
        self.envs = envs
        self.actor = actor
        self.device = device
        self.polyak = polyak
        self.dual_constraint = dual_constraint
        self.max_grad_norm = max_grad_norm
        self.mstep_iterations = mstep_iterations
        self.sample_action_num = sample_action_num
        self.target_update_interval = target_update_interval

        # For continuous action space
        self.kl_mean_constraint = kl_mean_constraint
        self.kl_var_constraint = kl_var_constraint
        self.alpha_mean_scale = alpha_mean_scale
        self.alpha_var_scale = alpha_var_scale
        self.alpha_mean_max = alpha_mean_max
        self.alpha_var_max = alpha_var_max

        # For discrete action space
        self.alpha_max = alpha_max
        self.kl_constraint = kl_constraint
        self.alpha_scale = alpha_scale

        # Initialize Lagrange multipliers
        self.eta = np.random.rand()  # E-step temperature, re-optimized every update
        self.alpha_mu = 0.0  # lagrangian multiplier for continuous action space in the M-step
        self.alpha_sigma = 0.0  # lagrangian multiplier for continuous action space in the M-step
        self.alpha = 0.0  # lagrangian multiplier for discrete action space in the M-step

        assert hasattr(self.actor, "q1"), "MPO requires q critic (num_critics=1)"

        # Create target networks
        self.actor_targ = deepcopy(actor)

        # Freeze target networks with respect to optimizers
        for p in self.actor_targ.parameters():
            p.requires_grad = False

        # List of parameters for the Q-network (only q1 is trained)
        q_params = self.actor.q1.parameters()

        # List of parameters for policy network
        p_params = itertools.chain(self.actor.policy_net.parameters())

        # ----- Policy Loss Addons --------------------------------------------

        # None stands for "no addons" (avoids a shared mutable default list)
        if policy_loss_addons is None:
            policy_loss_addons = []

        # Sanity check, policy_loss_addons is a PolicyLossAddOn instance
        # or a list of PolicyLossAddOn instances
        assert isinstance(policy_loss_addons, (PolicyLossAddOn, list)),\
            "MPO policy_loss_addons parameter should be a PolicyLossAddOn instance " \
            "or a list of PolicyLossAddOn instances"
        if isinstance(policy_loss_addons, list):
            for addon in policy_loss_addons:
                assert isinstance(addon, PolicyLossAddOn), \
                    "MPO policy_loss_addons parameter should be a PolicyLossAddOn " \
                    "instance or a list of PolicyLossAddOn instances"
            # Copy so later mutation of the caller's list does not affect us
            policy_loss_addons = list(policy_loss_addons)
        else:
            policy_loss_addons = [policy_loss_addons]

        self.policy_loss_addons = policy_loss_addons
        for addon in self.policy_loss_addons:
            addon.setup(self.actor, self.device)

        # ----- Optimizers ----------------------------------------------------

        self.norm_loss_q = nn.SmoothL1Loss()
        self.pi_optimizer = optim.Adam(p_params, lr=lr_pi)
        self.q_optimizer = optim.Adam(q_params, lr=lr_q)

    @classmethod
    def create_factory(cls,
                       lr_q=1e-4,
                       lr_pi=1e-4,
                       gamma=0.99,
                       polyak=0.995,
                       num_updates=50,
                       test_every=5000,
                       update_every=50,
                       start_steps=1000,
                       mini_batch_size=64,
                       num_test_episodes=5,
                       target_update_interval=1.0,
                       dual_constraint=0.1,
                       kl_mean_constraint=0.01,
                       kl_var_constraint=0.0001,
                       kl_constraint=0.01,
                       alpha_scale=10.0,
                       alpha_mean_scale=1.0,
                       alpha_var_scale=100.0,
                       alpha_mean_max=0.1,
                       alpha_var_max=10.0,
                       alpha_max=1.0,
                       mstep_iterations=5,
                       sample_action_num=64,
                       max_grad_norm=0.1,
                       policy_loss_addons=None):
        """
        Returns a function to create new MPO instances.

        Every argument is forwarded unchanged to the MPO constructor.

        Parameters
        ----------
        lr_q : float
            Q-net optimizer learning rate.
        lr_pi : float
            Policy optimizer learning rate.
        gamma : float
            Discount factor parameter.
        polyak : float
            Polyak averaging coefficient for target network updates.
        num_updates : int
            Num consecutive actor updates before data collection continues.
        test_every : int
            Regularity of test evaluations in actor updates.
        update_every : int
            Regularity of actor updates in number environment steps.
        start_steps : int
            Num of initial random environment steps before learning starts.
        mini_batch_size : int
            Size of actor update batches.
        num_test_episodes : int
            Number of episodes to complete in each test phase.
        target_update_interval : float
            Regularity of target nets updates with respect to actor Adam updates.
        dual_constraint : float
            Hard constraint of the dual formulation in the E-step corresponding to [2] p.4 ε.
        kl_mean_constraint : float
            Hard constraint of the mean in the M-step corresponding to [2] p.6 ε_μ for continuous action space.
        kl_var_constraint : float
            Hard constraint of the covariance in the M-step corresponding to [2] p.6 ε_Σ for continuous action space.
        kl_constraint : float
            Hard constraint in the M-step corresponding to [2] p.6 ε_π for discrete action space.
        alpha_scale : float
            Scaling factor of the Lagrange multiplier update in the M-step for discrete action spaces.
        alpha_mean_scale : float
            Scaling factor of the mean Lagrange multiplier update in the M-step for continuous action spaces.
        alpha_var_scale : float
            Scaling factor of the variance Lagrange multiplier update in the M-step for continuous action spaces.
        alpha_mean_max : float
            Upper bound used for clipping the mean Lagrange multiplier in continuous action spaces.
        alpha_var_max : float
            Upper bound used for clipping the variance Lagrange multiplier in continuous action spaces.
        alpha_max : float
            Upper bound used for clipping the Lagrange multiplier in discrete action spaces.
        mstep_iterations : int
            The number of iterations of the M-step.
        sample_action_num : int
            For continuous action spaces, number of samples used to compute expected Q scores.
        max_grad_norm : float
            Gradient clipping parameter.
        policy_loss_addons : PolicyLossAddOn or list or None
            PolicyLossAddOn component(s) adding loss terms to the algorithm policy loss.

        Returns
        -------
        create_algo_instance : func
            Function that creates a new MPO class instance.
        algo_name : str
            Name of the algorithm.
        """

        def create_algo_instance(device, actor, envs):
            return cls(lr_q=lr_q,
                       lr_pi=lr_pi,
                       envs=envs,
                       actor=actor,
                       gamma=gamma,
                       device=device,
                       polyak=polyak,
                       test_every=test_every,
                       start_steps=start_steps,
                       num_updates=num_updates,
                       update_every=update_every,
                       mini_batch_size=mini_batch_size,
                       num_test_episodes=num_test_episodes,
                       target_update_interval=target_update_interval,
                       dual_constraint=dual_constraint,
                       kl_mean_constraint=kl_mean_constraint,
                       kl_var_constraint=kl_var_constraint,
                       kl_constraint=kl_constraint,
                       alpha_scale=alpha_scale,
                       alpha_mean_scale=alpha_mean_scale,
                       alpha_var_scale=alpha_var_scale,
                       alpha_mean_max=alpha_mean_max,
                       alpha_var_max=alpha_var_max,
                       alpha_max=alpha_max,
                       mstep_iterations=mstep_iterations,
                       sample_action_num=sample_action_num,
                       max_grad_norm=max_grad_norm,
                       policy_loss_addons=policy_loss_addons)

        return create_algo_instance, prl.MPO

    @property
    def gamma(self):
        """Returns discount factor gamma."""
        return self._gamma

    @property
    def start_steps(self):
        """Returns the number of steps to collect with initial random policy."""
        return self._start_steps

    @property
    def num_epochs(self):
        """
        Returns the number of times the whole buffer is re-used before data
        collection proceeds.
        """
        return self._num_epochs

    @property
    def update_every(self):
        """
        Returns the number of data samples collected between
        network update stages.
        """
        return self._update_every

    @property
    def num_mini_batch(self):
        """
        Returns the number of mini batches per epoch.
        """
        return self._num_mini_batch

    @property
    def mini_batch_size(self):
        """
        Returns the size of the update mini batches.
        """
        return self._mini_batch_size

    @property
    def test_every(self):
        """Number of network updates between test evaluations."""
        return self._test_every

    @property
    def num_test_episodes(self):
        """
        Returns the number of episodes to complete when testing.
        """
        return self._num_test_episodes

    @property
    def discrete_version(self):
        """Returns True if action_space is discrete."""
        return self.actor.action_space.__class__.__name__ == "Discrete"

    def acting_step(self, obs, rhs, done, deterministic=False):
        """
        MPO acting function.

        Parameters
        ----------
        obs : torch.tensor
            Current world observation
        rhs : torch.tensor
            RNN recurrent hidden state (if policy is not a RNN, rhs will contain zeroes).
        done : torch.tensor
            1.0 if current obs is the last one in the episode, else 0.0.
        deterministic : bool
            Whether to randomly sample action from predicted distribution or taking the mode.

        Returns
        -------
        action : torch.tensor
            Predicted next action.
        clipped_action : torch.tensor
            Predicted next action (clipped to be within action space).
        rhs : torch.tensor
            Policy recurrent hidden state (if policy is not a RNN, rhs will contain zeroes).
        other : dict
            Additional MPO predictions (always empty).
        """

        with torch.no_grad():
            (action, clipped_action, logp_action, rhs,
             entropy_dist, dist) = self.actor.get_action(
                obs, rhs, done, deterministic=deterministic)

        return action, clipped_action, rhs, {}

    def compute_loss_q(self, batch, n_step=1, per_weights=1):
        """
        Calculate MPO Q-net loss.

        The Bellman target uses the expected target Q-value under the target
        policy: an exact expectation over all actions for discrete action
        spaces, a Monte Carlo estimate with `sample_action_num` samples for
        continuous ones.

        Parameters
        ----------
        batch : dict
            Data batch dict containing all required tensors to compute MPO losses.
        n_step : int or float
            Number of future steps used to computed the truncated n-step return value.
        per_weights :
            Prioritized Experience Replay (PER) important sampling weights or 1.0.

        Returns
        -------
        loss_q : torch.tensor
            Q-net loss.
        errors : torch.tensor
            Absolute TD errors.
        """

        # [2] 3 Policy Evaluation (Step 1)
        o, rhs, d = batch[prl.OBS], batch[prl.RHS], batch[prl.DONE]
        a, r = batch[prl.ACT], batch[prl.REW]
        o2, rhs2, d2 = batch[prl.OBS2], batch[prl.RHS2], batch[prl.DONE2]
        bs, ds = o.shape[0], o.shape[-1]

        if self.discrete_version:

            # Q-values for all actions, then keep those of the taken actions
            q1 = self.actor.get_q_scores(o, rhs, d).get("q1")
            q1 = q1.gather(1, a.long())

            # Bellman backup for Q function
            with torch.no_grad():

                # Next-state action distribution from the *target* policy
                _, _, _, _, _, dist = self.actor_targ.get_action(o2, rhs2, d2)
                N = dist.probs.shape[-1]  # num actions
                actions = torch.arange(N)[..., None].expand(-1, bs).to(self.device)  # (N, bs)
                p_a2 = dist.expand((N, bs)).log_prob(actions).exp().transpose(0, 1)  # (bs, N)

                # Expected target Q-value under the target policy
                q1_pi_targ = self.actor_targ.get_q_scores(o2, rhs2, d2).get("q1")
                q_pi_targ = (p_a2 * q1_pi_targ).sum(dim=1, keepdim=True)

                assert r.shape == q_pi_targ.shape
                backup = r + (self.gamma ** n_step) * (1 - d2) * q_pi_targ

        else:

            N = self.sample_action_num
            da = a.shape[-1]  # num action dimensions

            # Q-values of the taken actions
            q1 = self.actor.get_q_scores(o, rhs, d, a).get("q1")

            # Bellman backup for Q function
            with torch.no_grad():

                # Next-state action distribution from the *target* policy
                _, _, _, _, _, dist = self.actor_targ.get_action(o2, rhs2, d2)

                # Monte Carlo estimate of E_{a2 ~ pi_targ}[Q_targ(s2, a2)]
                sampled_actions = dist.sample((N,)).transpose(0, 1)  # (bs, N, da)
                expanded_obs2 = o2[:, None, :].expand(-1, N, -1)  # (bs, N, ds)
                expanded_d2 = d2[:, None, :].expand(-1, N, -1)  # (bs, N, 1)
                expanded_rhs2 = {k: v[:, None, :].expand(-1, N, -1) for k, v in rhs2.items()}
                expanded_reshaped_rhs2 = {k: v.reshape(-1, v.shape[-1]) for k, v in expanded_rhs2.items()}
                next_q1 = self.actor_targ.get_q_scores(
                    expanded_obs2.reshape(-1, ds),  # (bs * N, ds)
                    expanded_reshaped_rhs2,  # expanded rhs
                    expanded_d2.reshape(-1, 1),  # (bs * N, 1)
                    sampled_actions.reshape(-1, da),  # (bs * N, da)
                ).get("q1")  # (bs * N, 1)
                q_pi_targ = next_q1.reshape(bs, N).mean(dim=1, keepdim=True)  # (bs, 1)

                backup = r + (self.gamma ** n_step) * (1 - d2) * q_pi_targ

        # (PER-weighted) squared error against the Bellman backup
        loss_q = 0.5 * (((q1 - backup) ** 2) * per_weights).mean()

        # Absolute TD errors, reported in the info dict when PER is used
        errors = ((q1 - backup).abs()).detach().cpu()

        return loss_q, errors

    def compute_loss_pi(self, batch, per_weights=1):
        """
        Calculate MPO policy loss.

        E-step: solve the dual problem for the temperature `eta` and build
        non-parametric action weights q_ij from target Q-values.
        M-step: fit the policy to those weights subject to KL constraints
        enforced through Lagrange multipliers (updated in place on `self`).

        Parameters
        ----------
        batch : dict
            Data batch dict containing all required tensors to compute MPO losses.
        per_weights :
            Prioritized Experience Replay (PER) important sampling weights or 1.0.
            Not used by the policy loss.

        Returns
        -------
        loss_policy : torch.tensor
            MPO policy loss.
        addons_info : dict
            Information returned by the policy loss addons.
        """

        o, rhs, d, a = batch[prl.OBS], batch[prl.RHS], batch[prl.DONE], batch[prl.ACT]
        bs, ds = o.shape[0], o.shape[-1]

        # E-Step of Policy Improvement
        # [2] 4.1 Finding action weights (Step 2)
        # Only target networks (whose parameters are frozen) are evaluated
        # here, so no autograd graph is needed.
        with torch.no_grad():

            _, _, _, _, _, dist_targ = self.actor_targ.get_action(o, rhs, d)

            if self.discrete_version:
                N = dist_targ.probs.shape[-1]  # num possible actions
                # For each state in the batch, every possible action
                actions = torch.arange(N)[..., None].expand(N, bs).to(self.device)  # (N, bs)
                dist_targ_probs = dist_targ.expand((N, bs)).log_prob(actions).exp()  # (N, bs)
                target_q1 = self.actor_targ.get_q_scores(o, rhs, d).get("q1")  # (bs, N)
                target_q1 = target_q1.transpose(1, 0)  # (N, bs)
                b_prob_np = dist_targ_probs.cpu().transpose(0, 1).numpy()  # (bs, N)
                target_q1_np = target_q1.cpu().transpose(0, 1).numpy()  # (bs, N)

            else:
                N = self.sample_action_num
                da = a.shape[-1]  # num action dimensions
                sampled_actions = dist_targ.sample((N,))  # (N, bs, da)
                expanded_obs = o[None, ...].expand(N, -1, -1)  # (N, bs, ds)
                expanded_d = d[None, ...].expand(N, -1, -1)  # (N, bs, 1)
                expanded_rhs = {k: v[None, ...].expand(N, -1, -1) for k, v in rhs.items()}
                expanded_reshaped_rhs = {k: v.reshape(-1, v.shape[-1]) for k, v in expanded_rhs.items()}
                target_q1 = self.actor_targ.get_q_scores(
                    expanded_obs.reshape(-1, ds),  # (N * bs, ds)
                    expanded_reshaped_rhs,  # expanded rhs
                    expanded_d.reshape(-1, 1),  # (N * bs, 1)
                    sampled_actions.reshape(-1, da),  # (N * bs, da)
                ).get("q1")
                target_q1 = target_q1.reshape(N, bs)  # (N, bs)
                target_q1_np = target_q1.cpu().transpose(0, 1).numpy()  # (bs, N)

        # https://arxiv.org/pdf/1812.02256.pdf
        # [2] 4.1 Finding action weights (Step 2)
        # Using an exponential transformation of the Q-values
        if self.discrete_version:
            def dual(eta):
                """
                dual function of the non-parametric variational
                g(η) = η*ε + η*mean(log(sum(π(a|s)*exp(Q(s, a)/η))))
                We have to multiply π by exp because this is expectation.
                This equation corresponds to the last equation of [2] p.15
                For numerical stabilization, this can be modified to
                Qj = max(Q(s, a), along=a)
                g(η) = η*ε + mean(Qj, along=j) + η*mean(log(sum(π(a|s)*(exp(Q(s, a)-Qj)/η))))
                """
                max_q = np.max(target_q1_np, 1)
                return eta * self.dual_constraint + np.mean(max_q) + eta * np.mean(np.log(np.sum(
                    b_prob_np * np.exp((target_q1_np - max_q[:, None]) / eta), axis=1)))

        else:  # continuous action space
            def dual(eta):
                """
                dual function of the non-parametric variational
                Q = target_q1_np (bs, N)
                g(η) = η*ε + η*mean(log(mean(exp(Q(s, a)/η), along=a)), along=s)
                For numerical stabilization, this can be modified to
                Qj = max(Q(s, a), along=a)
                g(η) = η*ε + mean(Qj, along=j) + η*mean(log(mean(exp((Q(s, a)-Qj)/η), along=a)), along=s)
                """
                max_q = np.max(target_q1_np, 1)
                return eta * self.dual_constraint + np.mean(max_q) + eta * np.mean(np.log(
                    np.mean(np.exp((target_q1_np - max_q[:, None]) / eta), axis=1)))

        # Temperature must stay strictly positive
        bounds = [(1e-6, None)]
        res = minimize(dual, np.array([self.eta]), method='SLSQP', bounds=bounds)
        self.eta = res.x[0]

        # Non-parametric action weights, normalized over actions
        qij = torch.softmax(target_q1 / self.eta, dim=0)  # (N, bs)

        # M-Step of Policy Improvement
        # [2] 4.2 Fitting an improved policy (Step 3)
        # NOTE: no optimizer step is taken inside this loop (gradients are
        # returned and applied elsewhere), so the policy is the same in every
        # iteration; each iteration re-applies the Lagrange multiplier update
        # and only the last loss is returned. At least one iteration is run so
        # that loss_policy is always defined.
        for _ in range(max(1, self.mstep_iterations)):

            if self.discrete_version:

                _, _, _, _, _, dist = self.actor.get_action(o, rhs, d)
                loss_pi = torch.mean(qij * dist.expand((N, bs)).log_prob(actions))
                kl = kl_divergence(dist, dist_targ).mean()

                # Update lagrange multipliers by gradient descent
                # this equation is derived from last eq of [2] p.5,
                # just differentiate with respect to α
                # and update α so that the equation is to be minimized.
                self.alpha -= self.alpha_scale * (self.kl_constraint - kl).detach().item()
                self.alpha = np.clip(self.alpha, 0.0, self.alpha_max)

                # last eq of [2] p.5
                loss_policy = -(loss_pi + self.alpha * (self.kl_constraint - kl))

            else:

                _, _, _, _, _, dist = self.actor.get_action(o, rhs, d)

                # NOTE(review): the dist_targ term carries no gradient (target
                # parameters are frozen), so only the current-policy term drives
                # the update — confirm this is the intended objective.
                loss_pi = torch.mean(
                    qij * (
                        dist_targ.expand((N, bs, da)).log_prob(sampled_actions).sum(-1)  # (N, bs)
                        + dist.expand((N, bs, da)).log_prob(sampled_actions).sum(-1)  # (N, bs)
                    )
                )

                # Diagonal covariance matrices, (bs, da, da)
                cov1 = torch.diag_embed(dist.variance)
                cov2 = torch.diag_embed(dist_targ.variance)

                kl_mu, kl_sigma = gaussian_kl(dist_targ.mean, dist.mean, cov1, cov2)

                if np.isnan(kl_mu.item()):  # This should not happen
                    raise RuntimeError('kl_mu is nan')
                if np.isnan(kl_sigma.item()):  # This should not happen
                    raise RuntimeError('kl_sigma is nan')

                # Update lagrange multipliers by gradient descent
                # this equation is derived from last eq of [2] p.5, just differentiate with
                # respect to α and update α so that the equation is to be minimized.
                # Each multiplier is driven by its own KL term.
                self.alpha_mu -= self.alpha_mean_scale * (self.kl_mean_constraint - kl_mu).detach().item()
                self.alpha_sigma -= self.alpha_var_scale * (self.kl_var_constraint - kl_sigma).detach().item()
                self.alpha_mu = np.clip(self.alpha_mu, 0.0, self.alpha_mean_max)
                self.alpha_sigma = np.clip(self.alpha_sigma, 0.0, self.alpha_var_max)

                # last eq of [2] p.5
                loss_policy = -(loss_pi + self.alpha_mu * (self.kl_mean_constraint - kl_mu) +
                                self.alpha_sigma * (self.kl_var_constraint - kl_sigma))

        addons_info = {}
        for addon in self.policy_loss_addons:
            addon_loss, addons_info = addon.compute_loss_term(batch, dist, addons_info)
            loss_policy += addon_loss

        return loss_policy, addons_info

    def compute_gradients(self, batch, grads_to_cpu=True):
        """
        Compute loss and compute gradients but don't do optimization step,
        return gradients instead.

        Parameters
        ----------
        batch : dict
            data batch containing all required tensors to compute MPO losses.
        grads_to_cpu : bool
            If gradient tensor will be sent to another node, need to be in CPU.

        Returns
        -------
        grads : dict
            Dict with the q1 critic gradients ("q_grads") and the policy
            gradients ("pi_grads").
        info : dict
            Dict containing current MPO iteration information.
        """

        # Recurrent burn-in
        if self.actor.is_recurrent:
            batch = self.actor.burn_in_recurrent_states(batch)

        # N-step returns
        n_step = batch["n_step"] if "n_step" in batch else 1.0

        # PER
        per_weights = batch["per_weights"] if "per_weights" in batch else 1.0

        # Compute Q1 gradients (q1 is the only critic MPO optimizes)
        loss_q, errors = self.compute_loss_q(batch, n_step, per_weights)
        self.q_optimizer.zero_grad()
        loss_q.backward(retain_graph=True)
        nn.utils.clip_grad_norm_(self.actor.q1.parameters(), self.max_grad_norm)
        q_grads = get_gradients(self.actor.q1, grads_to_cpu=grads_to_cpu)

        # Freeze Q-network so you don't waste computational effort
        # computing gradients for it during the policy learning step.
        for p in self.actor.q1.parameters():
            p.requires_grad = False

        # Compute policy gradients
        loss_pi, addons_info = self.compute_loss_pi(batch, per_weights)
        self.pi_optimizer.zero_grad()
        loss_pi.backward()
        nn.utils.clip_grad_norm_(self.actor.policy_net.parameters(), self.max_grad_norm)
        pi_grads = get_gradients(self.actor.policy_net, grads_to_cpu=grads_to_cpu)

        for p in self.actor.q1.parameters():
            p.requires_grad = True

        info = {
            "loss_q": loss_q.detach().item(),
            "loss_pi": loss_pi.detach().item(),
        }

        if "per_weights" in batch:
            info.update({"errors": errors})

        grads = {"q_grads": q_grads, "pi_grads": pi_grads}
        info.update(addons_info)

        return grads, info

    def update_target_networks(self):
        """Update actor critic target networks with polyak averaging."""
        if self.iter % self.target_update_interval == 0:
            with torch.no_grad():
                for p, p_targ in zip(self.actor.parameters(), self.actor_targ.parameters()):
                    p_targ.data.mul_(self.polyak)
                    p_targ.data.add_((1 - self.polyak) * p.data)

    def apply_gradients(self, gradients=None):
        """
        Take an optimization step, previously setting new gradients if provided.
        Update also target networks.

        Parameters
        ----------
        gradients : dict
            Dict with "pi_grads" and "q_grads" entries, as returned by
            `compute_gradients`.
        """
        if gradients:
            set_gradients(
                self.actor.policy_net,
                gradients=gradients["pi_grads"], device=self.device)
            set_gradients(
                self.actor.q1,
                gradients=gradients["q_grads"], device=self.device)

        self.q_optimizer.step()
        self.pi_optimizer.step()

        # Update target networks by polyak averaging.
        self.iter += 1
        self.update_target_networks()

    def set_weights(self, actor_weights):
        """
        Update actor with the given weights. Update also target networks.

        Parameters
        ----------
        actor_weights : dict of tensors
            Dict containing actor weights to be set.
        """
        self.actor.load_state_dict(actor_weights)

        # Update target networks by polyak averaging.
        self.iter += 1
        self.update_target_networks()

    def update_algorithm_parameter(self, parameter_name, new_parameter_value):
        """
        If `parameter_name` is an attribute of the algorithm, change its value
        to `new_parameter_value value`.

        Parameters
        ----------
        parameter_name : str
            Worker.algo attribute name
        new_parameter_value : int or float
            New value for `parameter_name`.
        """
        if hasattr(self, parameter_name):
            setattr(self, parameter_name, new_parameter_value)
        if parameter_name == "lr":
            for param_group in self.pi_optimizer.param_groups:
                param_group['lr'] = new_parameter_value
            for param_group in self.q_optimizer.param_groups:
                param_group['lr'] = new_parameter_value