# Source code for pytorchrl.agent.algorithms.off_policy.sac

import itertools
import numpy as np
from copy import deepcopy

import torch
import torch.nn as nn
import torch.optim as optim

import pytorchrl as prl
from pytorchrl.agent.algorithms.base import Algorithm
from pytorchrl.agent.algorithms.utils import get_gradients, set_gradients
from pytorchrl.agent.algorithms.policy_loss_addons import PolicyLossAddOn


class SAC(Algorithm):
    """
    Soft Actor Critic algorithm class.

    Algorithm class to execute SAC, from Haarnoja et al.
    (https://arxiv.org/abs/1812.05905). Algorithms are modules generally
    required by multiple workers, so SAC.create_factory(...) returns a
    function that can be passed on to workers to instantiate their own SAC
    module.

    Parameters
    ----------
    device : torch.device
        CPU or specific GPU where class computations will take place.
    envs : VecEnv
        Vector of environments instance.
    actor : Actor
        Actor_critic class instance.
    lr_pi : float
        Policy optimizer learning rate.
    lr_q : float
        Q-nets optimizer learning rate.
    lr_alpha : float
        Alpha optimizer learning rate.
    gamma : float
        Discount factor parameter.
    initial_alpha : float
        Initial entropy coefficient value (temperature).
    polyak : float
        SAC polyak averaging parameter.
    num_updates : int
        Num consecutive actor updates before data collection continues.
    update_every : int
        Regularity of actor updates in number environment steps.
    start_steps : int
        Num of initial random environment steps before learning starts.
    mini_batch_size : int
        Size of actor update batches.
    target_update_interval : float
        Regularity of target nets updates with respect to actor Adam updates.
    num_test_episodes : int
        Number of episodes to complete in each test phase.
    test_every : int
        Regularity of test evaluations in actor updates.
    max_grad_norm : float
        Gradient clipping parameter.
    policy_loss_addons : list
        List of PolicyLossAddOn components adding loss terms to the algorithm
        policy loss. A single PolicyLossAddOn instance is also accepted.
        Defaults to no add-ons.

    Examples
    --------
    >>> create_algo, algo_name = SAC.create_factory(
    ...     lr_q=1e-4, lr_pi=1e-4, lr_alpha=1e-4, gamma=0.99, polyak=0.995,
    ...     num_updates=50, update_every=50, test_every=5000,
    ...     start_steps=20000, mini_batch_size=64, initial_alpha=1.0,
    ...     num_test_episodes=0, target_update_interval=1)
    """

    def __init__(self,
                 device,
                 envs,
                 actor,
                 lr_q=1e-4,
                 lr_pi=1e-4,
                 lr_alpha=1e-4,
                 gamma=0.99,
                 polyak=0.995,
                 num_updates=1,
                 update_every=50,
                 test_every=1000,
                 max_grad_norm=0.5,
                 initial_alpha=1.0,
                 start_steps=20000,
                 mini_batch_size=64,
                 num_test_episodes=5,
                 target_update_interval=1,
                 policy_loss_addons=None):

        # ---- General algo attributes ----------------------------------------

        # Discount factor
        self._gamma = gamma

        # Number of steps collected with initial random policy
        self._start_steps = int(start_steps)

        # Times data in the buffer is re-used before data collection proceeds
        self._num_epochs = int(1)  # Default to 1 for off-policy algorithms

        # Number of data samples collected between network update stages
        self._update_every = int(update_every)

        # Number mini batches per epoch
        self._num_mini_batch = int(num_updates)

        # Size of update mini batches
        self._mini_batch_size = int(mini_batch_size)

        # Number of network updates between test evaluations
        self._test_every = int(test_every)

        # Number of episodes to complete when testing
        self._num_test_episodes = int(num_test_episodes)

        # ---- SAC-specific attributes ----------------------------------------

        self.iter = 0
        self.envs = envs
        self.actor = actor
        self.polyak = polyak
        self.device = device
        self.max_grad_norm = max_grad_norm
        self.target_update_interval = target_update_interval

        assert hasattr(self.actor, "q1"), "SAC requires double q critic (num_critics=2)"
        assert hasattr(self.actor, "q2"), "SAC requires double q critic (num_critics=2)"

        # The temperature is optimized in log space; `alpha` is a detached
        # copy used inside the Q and policy losses.
        self.log_alpha = torch.tensor(
            data=[np.log(initial_alpha)], dtype=torch.float32,
            requires_grad=True, device=device)
        self.alpha = self.log_alpha.detach().exp()

        # Compute target entropy
        target_entropy = self.calculate_target_entropy()
        self.target_entropy = torch.tensor(
            data=target_entropy, dtype=torch.float32,
            requires_grad=False, device=device)

        # Create target networks
        self.actor_targ = deepcopy(actor)

        # Freeze target networks with respect to optimizers
        for p in self.actor_targ.parameters():
            p.requires_grad = False

        # List of parameters for both Q-networks
        q_params = itertools.chain(self.actor.q1.parameters(), self.actor.q2.parameters())

        # List of parameters for policy network
        p_params = self.actor.policy_net.parameters()

        # ----- Policy Loss Addons --------------------------------------------

        # None means "no add-ons". A None default avoids sharing one mutable
        # list object between every SAC instance.
        if policy_loss_addons is None:
            policy_loss_addons = []

        # Sanity check, policy_loss_addons is a PolicyLossAddOn instance
        # or a list of PolicyLossAddOn instances
        assert isinstance(policy_loss_addons, (PolicyLossAddOn, list)), \
            "SAC policy_loss_addons parameter should be a PolicyLossAddOn instance " \
            "or a list of PolicyLossAddOn instances"
        if isinstance(policy_loss_addons, list):
            for addon in policy_loss_addons:
                assert isinstance(addon, PolicyLossAddOn), \
                    "SAC policy_loss_addons parameter should be a PolicyLossAddOn " \
                    "instance or a list of PolicyLossAddOn instances"
            # Keep a private copy so later changes to the caller's list do
            # not silently alter this instance.
            policy_loss_addons = list(policy_loss_addons)
        else:
            policy_loss_addons = [policy_loss_addons]

        self.policy_loss_addons = policy_loss_addons
        for addon in self.policy_loss_addons:
            addon.setup(self.actor, self.device)

        # ----- Optimizers ----------------------------------------------------

        self.pi_optimizer = optim.Adam(p_params, lr=lr_pi)
        self.q_optimizer = optim.Adam(q_params, lr=lr_q)
        self.alpha_optimizer = optim.Adam([self.log_alpha], lr=lr_alpha)

    @classmethod
    def create_factory(cls,
                       lr_q=1e-4,
                       lr_pi=1e-4,
                       lr_alpha=1e-4,
                       gamma=0.99,
                       polyak=0.995,
                       num_updates=50,
                       test_every=5000,
                       update_every=50,
                       start_steps=1000,
                       max_grad_norm=0.5,
                       initial_alpha=1.0,
                       mini_batch_size=64,
                       num_test_episodes=5,
                       target_update_interval=1.0,
                       policy_loss_addons=None):
        """
        Returns a function to create new SAC instances.

        Parameters
        ----------
        lr_pi : float
            Policy optimizer learning rate.
        lr_q : float
            Q-nets optimizer learning rate.
        lr_alpha : float
            Alpha optimizer learning rate.
        gamma : float
            Discount factor parameter.
        initial_alpha : float
            Initial entropy coefficient value.
        polyak : float
            SAC polyak averaging parameter.
        num_updates : int
            Num consecutive actor updates before data collection continues.
        update_every : int
            Regularity of actor updates in number environment steps.
        start_steps : int
            Num of initial random environment steps before learning starts.
        mini_batch_size : int
            Size of actor update batches.
        target_update_interval : float
            Regularity of target nets updates with respect to actor Adam updates.
        num_test_episodes : int
            Number of episodes to complete in each test phase.
        test_every : int
            Regularity of test evaluations in actor updates.
        max_grad_norm : float
            Gradient clipping parameter.
        policy_loss_addons : list
            List of PolicyLossAddOn components adding loss terms to the
            algorithm policy loss. Defaults to no add-ons.

        Returns
        -------
        create_algo_instance : func
            Function that creates a new SAC class instance.
        algo_name : str
            Name of the algorithm.
        """

        def create_algo_instance(device, actor, envs):
            return cls(lr_q=lr_q,
                       lr_pi=lr_pi,
                       lr_alpha=lr_alpha,
                       envs=envs,
                       actor=actor,
                       gamma=gamma,
                       device=device,
                       polyak=polyak,
                       test_every=test_every,
                       start_steps=start_steps,
                       num_updates=num_updates,
                       update_every=update_every,
                       initial_alpha=initial_alpha,
                       max_grad_norm=max_grad_norm,
                       mini_batch_size=mini_batch_size,
                       num_test_episodes=num_test_episodes,
                       target_update_interval=target_update_interval,
                       policy_loss_addons=policy_loss_addons)

        return create_algo_instance, prl.SAC

    @property
    def gamma(self):
        """Returns discount factor gamma."""
        return self._gamma

    @property
    def start_steps(self):
        """Returns the number of steps to collect with initial random policy."""
        return self._start_steps

    @property
    def num_epochs(self):
        """
        Returns the number of times the whole buffer is re-used before data
        collection proceeds.
        """
        return self._num_epochs

    @property
    def update_every(self):
        """
        Returns the number of data samples collected between
        network update stages.
        """
        return self._update_every

    @property
    def num_mini_batch(self):
        """Returns the number of mini batches per epoch."""
        return self._num_mini_batch

    @property
    def mini_batch_size(self):
        """Returns the size of the update mini batches."""
        return self._mini_batch_size

    @property
    def test_every(self):
        """Number of network updates between test evaluations."""
        return self._test_every

    @property
    def num_test_episodes(self):
        """Returns the number of episodes to complete when testing."""
        return self._num_test_episodes

    @property
    def discrete_version(self):
        """Returns True if action_space is discrete."""
        return self.actor.action_space.__class__.__name__ == "Discrete"

    def acting_step(self, obs, rhs, done, deterministic=False):
        """
        SAC acting function.

        Parameters
        ----------
        obs : torch.tensor
            Current world observation
        rhs : torch.tensor
            RNN recurrent hidden state (if policy is not a RNN, rhs will
            contain zeroes).
        done : torch.tensor
            1.0 if current obs is the last one in the episode, else 0.0.
        deterministic : bool
            Whether to randomly sample action from predicted distribution
            or taking the mode.

        Returns
        -------
        action : torch.tensor
            Predicted next action.
        clipped_action : torch.tensor
            Predicted next action (clipped to be within action space).
        rhs : torch.tensor
            Policy recurrent hidden state (if policy is not a RNN, rhs will
            contain zeroes).
        other : dict
            Additional SAC predictions, which are not used in other
            algorithms. Always empty for SAC.
        """

        with torch.no_grad():
            action, clipped_action, _, rhs, _, _ = self.actor.get_action(
                obs, rhs, done, deterministic=deterministic)

        return action, clipped_action, rhs, {}

    def _discrete_probs_and_log_probs(self, dist, batch_size):
        """
        Evaluate a discrete policy distribution on every possible action.

        Parameters
        ----------
        dist : torch.distributions.Distribution
            Action distribution returned by the actor. It must expose
            `probs`, `expand` and `log_prob`.
        batch_size : int
            Number of samples in the batch the distribution was computed for.

        Returns
        -------
        probs : torch.tensor
            Probability of each action, transposed so that dim 0 indexes the
            batch and dim 1 the actions.
        log_probs : torch.tensor
            Logarithm of `probs`, with exact zeros offset by 1e-8 so the
            result stays finite.
        """
        num_actions = dist.probs.shape[-1]
        actions = torch.arange(num_actions)[..., None].expand(-1, batch_size).to(self.device)
        probs = dist.expand((num_actions, batch_size)).log_prob(actions).exp().transpose(0, 1)

        # Only exact zeros are offset, avoiding log(0) = -inf
        zero_offset = (probs == 0.0).float() * 1e-8
        log_probs = torch.log(probs + zero_offset)

        return probs, log_probs

    def compute_loss_q(self, batch, n_step=1, per_weights=1):
        """
        Calculate SAC Q-nets loss

        Parameters
        ----------
        batch : dict
            Data batch dict containing all required tensors to compute
            SAC losses.
        n_step : int or float
            Number of future steps used to computed the truncated n-step
            return value.
        per_weights :
            Prioritized Experience Replay (PER) important sampling weights
            or 1.0.

        Returns
        -------
        loss_q1 : torch.tensor
            Q1-net loss.
        loss_q2 : torch.tensor
            Q2-net loss.
        loss_q : torch.tensor
            Weighted average of loss_q1 and loss_q2.
        errors : torch.tensor
            TD errors (mean of the absolute Q1 and Q2 errors), detached and
            moved to CPU.
        """

        o, rhs, d = batch[prl.OBS], batch[prl.RHS], batch[prl.DONE]
        a, r = batch[prl.ACT], batch[prl.REW]
        o2, rhs2, d2 = batch[prl.OBS2], batch[prl.RHS2], batch[prl.DONE2]

        if self.discrete_version:

            # Q-values for all actions, keeping those of the actions taken
            q_scores = self.actor.get_q_scores(o, rhs, d)
            q1 = q_scores.get("q1").gather(1, a.long())
            q2 = q_scores.get("q2").gather(1, a.long())

            # Bellman backup for Q functions
            with torch.no_grad():

                # Target actions come from *current* policy
                _, _, _, _, _, dist = self.actor.get_action(o2, rhs2, d2)
                p_a2, logp_a2 = self._discrete_probs_and_log_probs(dist, o.shape[0])

                # Target Q-values: expectation over all next actions
                q_scores_targ = self.actor_targ.get_q_scores(o2, rhs2, d2)
                q1_targ = q_scores_targ.get("q1")
                q2_targ = q_scores_targ.get("q2")
                q_targ = (p_a2 * (torch.min(q1_targ, q2_targ) - self.alpha * logp_a2)).sum(
                    dim=1, keepdim=True)

                assert r.shape == q_targ.shape
                backup = r + (self.gamma ** n_step) * (1 - d2) * q_targ

        else:

            # Q-values for the actions stored in the batch
            q_scores = self.actor.get_q_scores(o, rhs, d, a)
            q1 = q_scores.get("q1")
            q2 = q_scores.get("q2")

            # Bellman backup for Q functions
            with torch.no_grad():

                # Target actions come from *current* policy
                a2, _, logp_a2, _, _, _ = self.actor.get_action(o2, rhs2, d2)

                # Target Q-values
                q_scores_targ = self.actor_targ.get_q_scores(o2, rhs2, d2, a2)
                q1_targ = q_scores_targ.get("q1")
                q2_targ = q_scores_targ.get("q2")
                q_pi_targ = torch.min(q1_targ, q2_targ)

                backup = r + (self.gamma ** n_step) * (1 - d2) * (q_pi_targ - self.alpha * logp_a2)

        # MSE loss against Bellman backup
        loss_q1 = (((q1 - backup) ** 2) * per_weights).mean()
        loss_q2 = (((q2 - backup) ** 2) * per_weights).mean()
        loss_q = 0.5 * loss_q1 + 0.5 * loss_q2

        # Average the absolute error of both critics
        errors = (0.5 * (q1 - backup).abs() + 0.5 * (q2 - backup).abs()).detach().cpu()

        return loss_q1, loss_q2, loss_q, errors

    def compute_loss_pi(self, batch, per_weights=1):
        """
        Calculate SAC policy loss.

        Parameters
        ----------
        batch : dict
            Data batch dict containing all required tensors to compute
            SAC losses.
        per_weights :
            Prioritized Experience Replay (PER) important sampling weights
            or 1.0.

        Returns
        -------
        loss_pi : torch.tensor
            SAC policy loss, including any policy loss add-on terms.
        logp_pi : torch.tensor
            Log probability of predicted next action. In the discrete
            version, the probability-weighted sum of the log probabilities
            of all actions.
        addons_info : dict
            Information returned by the policy loss add-ons (empty if no
            add-ons are configured).
        """

        o, rhs, d = batch[prl.OBS], batch[prl.RHS], batch[prl.DONE]

        if self.discrete_version:

            _, _, _, _, _, dist = self.actor.get_action(o, rhs, d)

            # Get action log probs
            p_pi, logp_all = self._discrete_probs_and_log_probs(dist, o.shape[0])
            logp_pi = torch.sum(p_pi * logp_all, dim=1, keepdim=True)

            q_scores = self.actor.get_q_scores(o, rhs, d)
            q1 = q_scores.get("q1")
            q2 = q_scores.get("q2")
            q_pi = torch.sum(torch.min(q1, q2) * p_pi, dim=1, keepdim=True)

        else:

            pi, _, logp_pi, _, _, dist = self.actor.get_action(o, rhs, d)
            q_scores = self.actor.get_q_scores(o, rhs, d, pi)
            q1 = q_scores.get("q1")
            q2 = q_scores.get("q2")
            q_pi = torch.min(q1, q2)

        loss_pi = ((self.alpha * logp_pi - q_pi) * per_weights).mean()

        # Extend policy loss with addons. Each add-on receives the batch and
        # the info dict accumulated so far, and returns the updated dict.
        addons_info = {}
        for addon in self.policy_loss_addons:
            addon_loss, addons_info = addon.compute_loss_term(batch, dist, addons_info)
            loss_pi += addon_loss

        return loss_pi, logp_pi, addons_info

    def compute_loss_alpha(self, log_probs, per_weights=1):
        """
        Calculate SAC entropy loss.

        Parameters
        ----------
        log_probs : torch.tensor
            Log probability of predicted next action.
        per_weights :
            Prioritized Experience Replay (PER) important sampling weights
            or 1.0.

        Returns
        -------
        alpha_loss : torch.tensor
            SAC entropy loss.
        """
        # The entropy gap is detached so that only log_alpha gets gradients
        entropy_gap = (log_probs + self.target_entropy).detach()
        alpha_loss = - ((self.log_alpha * entropy_gap) * per_weights).mean()
        return alpha_loss

    def calculate_target_entropy(self):
        """
        Calculate SAC target entropy.

        Returns
        -------
        target : float
            For discrete action spaces, 0.98 times the entropy of a uniform
            distribution over the actions. Otherwise, minus the number of
            elements in the action space shape.
        """
        if self.discrete_version:
            target = - np.log(1.0 / self.actor.action_space.n) * 0.98
        else:
            # np.prod also covers action spaces with more than one dimension
            target = - np.prod(self.actor.action_space.shape)
        return target

    def compute_gradients(self, batch, grads_to_cpu=True):
        """
        Compute loss and compute gradients but don't do optimization step,
        return gradients instead.

        Parameters
        ----------
        batch : dict
            data batch containing all required tensors to compute SAC losses.
        grads_to_cpu : bool
            If gradient tensor will be sent to another node, need to be in CPU.

        Returns
        -------
        grads : dict
            Dict with the Q-nets gradients ("q_grads") and the policy
            gradients ("pi_grads"). The alpha gradient is not included; it
            is left in `self.log_alpha.grad`.
        info : dict
            Dict containing current SAC iteration information.
        """

        # Recurrent burn-in
        if self.actor.is_recurrent:
            batch = self.actor.burn_in_recurrent_states(batch)

        # N-step returns
        n_step = batch["n_step"] if "n_step" in batch else 1.0

        # PER
        per_weights = batch["per_weights"] if "per_weights" in batch else 1.0

        # Materialized once, as the parameters are iterated several times
        q_params = list(itertools.chain(
            self.actor.q1.parameters(), self.actor.q2.parameters()))

        # Compute Q1 and Q2 gradients
        loss_q1, loss_q2, loss_q, errors = self.compute_loss_q(batch, n_step, per_weights)
        self.q_optimizer.zero_grad()
        loss_q.backward(retain_graph=True)
        nn.utils.clip_grad_norm_(q_params, self.max_grad_norm)
        q_grads = get_gradients(self.actor.q1, self.actor.q2, grads_to_cpu=grads_to_cpu)

        # Freeze Q-networks so you don't waste computational effort
        # computing gradients for them during the policy learning step.
        for p in q_params:
            p.requires_grad = False

        try:
            # Compute policy gradients
            loss_pi, logp_pi, addons_info = self.compute_loss_pi(batch, per_weights)
            self.pi_optimizer.zero_grad()
            loss_pi.backward()
            nn.utils.clip_grad_norm_(self.actor.policy_net.parameters(), self.max_grad_norm)
            pi_grads = get_gradients(self.actor.policy_net, grads_to_cpu=grads_to_cpu)
        finally:
            # Always unfreeze, so a failure in the policy step cannot leave
            # the critics permanently frozen.
            for p in q_params:
                p.requires_grad = True

        # Compute alpha gradients
        self.alpha_optimizer.zero_grad()
        loss_alpha = self.compute_loss_alpha(logp_pi, per_weights)
        loss_alpha.backward()

        info = {
            "loss_q1": loss_q1.detach().item(),
            "loss_q2": loss_q2.detach().item(),
            "loss_pi": loss_pi.detach().item(),
            "loss_alpha": loss_alpha.detach().item(),
            "alpha": self.alpha.detach().item(),
        }

        # TD errors are only reported when the batch carries PER weights
        if "per_weights" in batch:
            info.update({"errors": errors})

        grads = {"q_grads": q_grads, "pi_grads": pi_grads}
        info.update(addons_info)

        return grads, info

    def update_target_networks(self):
        """
        Update actor critic target networks with polyak averaging.

        Does nothing unless `self.iter` is a multiple of
        `self.target_update_interval`.
        """
        if self.iter % self.target_update_interval == 0:
            with torch.no_grad():
                for p, p_targ in zip(self.actor.parameters(), self.actor_targ.parameters()):
                    # p_targ <- polyak * p_targ + (1 - polyak) * p
                    p_targ.data.mul_(self.polyak)
                    p_targ.data.add_((1 - self.polyak) * p.data)

    def apply_gradients(self, gradients=None):
        """
        Take an optimization step, previously setting new gradients if
        provided. Update also target networks.

        Parameters
        ----------
        gradients : dict of tensors
            Dict with keys "pi_grads" and "q_grads". If None, the gradients
            currently stored in the networks are used.
        """
        if gradients is not None:
            set_gradients(
                self.actor.policy_net,
                gradients=gradients["pi_grads"], device=self.device)
            set_gradients(
                self.actor.q1, self.actor.q2,
                gradients=gradients["q_grads"], device=self.device)

        self.q_optimizer.step()
        self.pi_optimizer.step()
        self.alpha_optimizer.step()

        # Refresh the detached temperature used in the losses
        self.alpha = self.log_alpha.detach().exp()

        # Update target networks by polyak averaging.
        self.iter += 1
        self.update_target_networks()

    def set_weights(self, actor_weights):
        """
        Update actor with the given weights. Update also target networks.

        Parameters
        ----------
        actor_weights : dict of tensors
            Dict containing actor weights to be set.
        """
        self.actor.load_state_dict(actor_weights)

        # The alpha optimizer still steps here, using whatever gradient is
        # currently stored in `self.log_alpha.grad`.
        self.alpha_optimizer.step()
        self.alpha = self.log_alpha.detach().exp()

        # Update target networks by polyak averaging.
        self.iter += 1
        self.update_target_networks()

    def update_algorithm_parameter(self, parameter_name, new_parameter_value):
        """
        If `parameter_name` is an attribute of the algorithm, change its
        value to `new_parameter_value`.

        Parameters exposed through read-only properties (e.g. `gamma`) are
        updated through their underscore-prefixed backing attribute.
        `parameter_name="lr"` sets the learning rate of the policy, Q-nets
        and alpha optimizers.

        Parameters
        ----------
        parameter_name : str
            Worker.algo attribute name
        new_parameter_value : int or float
            New value for `parameter_name`.
        """
        class_attr = getattr(type(self), parameter_name, None)
        if isinstance(class_attr, property) and class_attr.fset is None:
            # A plain setattr on a read-only property raises AttributeError,
            # so write the backing field instead when there is one.
            backing_name = "_" + parameter_name
            if hasattr(self, backing_name):
                setattr(self, backing_name, new_parameter_value)
        elif hasattr(self, parameter_name):
            setattr(self, parameter_name, new_parameter_value)

        # SAC has no `lr` attribute, so this is handled independently of the
        # attribute update above.
        if parameter_name == "lr":
            for optimizer in (self.pi_optimizer, self.q_optimizer, self.alpha_optimizer):
                for param_group in optimizer.param_groups:
                    param_group['lr'] = new_parameter_value