Source code for pytorchrl.agent.algorithms.on_policy.a2c

import torch
import itertools
import torch.nn as nn
import torch.optim as optim

import pytorchrl as prl
from pytorchrl.agent.algorithms.base import Algorithm
from pytorchrl.agent.algorithms.policy_loss_addons import PolicyLossAddOn
from pytorchrl.agent.algorithms.utils import get_gradients, set_gradients


[docs]class A2C(Algorithm): """ Algorithm class to execute A2C, from Mnih et al. 2016 (https://arxiv.org/pdf/1602.01783.pdf). Parameters ---------- device : torch.device CPU or specific GPU where class computations will take place. envs : VecEnv Vector of environments instance. actor : Actor Actor_critic class instance. lr_v : float Value network learning rate. lr_pi : float Policy network learning rate. gamma : float Discount factor parameter. num_test_episodes : int Number of episodes to complete in each test phase. max_grad_norm : float Gradient clipping parameter. test_every : int Regularity of test evaluations in actor updates. num_test_episodes : int Number of episodes to complete in each test phase. policy_loss_addons : list List of PolicyLossAddOn components adding loss terms to the algorithm policy loss. """ def __init__(self, device, envs, actor, lr_v=1e-4, lr_pi=1e-4, gamma=0.99, test_every=5000, max_grad_norm=0.5, num_test_episodes=5, policy_loss_addons=[]): # ---- General algo attributes ---------------------------------------- # Discount factor self._gamma = gamma # Number of steps collected with initial random policy self._start_steps = int(0) # Default to 0 for On-policy algos # Times data in the buffer is re-used before data collection proceeds self._num_epochs = int(1) # Number of data samples collected between network update stages self._update_every = None # Depends on storage capacity # Number mini batches per epoch self._num_mini_batch = int(1) # Size of update mini batches self._mini_batch_size = None # Depends on storage capacity # Number of network updates between test evaluations self._test_every = int(test_every) # Number of episodes to complete when testing self._num_test_episodes = int(num_test_episodes) # ---- A2C-specific attributes ---------------------------------------- self.iter = 0 self.envs = envs self.actor = actor self.device = device self.max_grad_norm = max_grad_norm assert hasattr(self.actor, "value_net1"), "A2C requires value critic (num_critics=1)" # ----- Policy Loss Addons -------------------------------------------- # Sanity check, policy_loss_addons is a PolicyLossAddOn instance # or a list of PolicyLossAddOn instances assert isinstance(policy_loss_addons, (PolicyLossAddOn, list)),\ "A2C policy_loss_addons parameter should be a PolicyLossAddOn instance " \ "or a list of PolicyLossAddOn instances" if isinstance(policy_loss_addons, list): for addon in policy_loss_addons: assert isinstance(addon, PolicyLossAddOn), \ "A2C policy_loss_addons parameter should be a PolicyLossAddOn" \ " instance or a list of PolicyLossAddOn instances" else: policy_loss_addons = [policy_loss_addons] self.policy_loss_addons = policy_loss_addons for addon in self.policy_loss_addons: addon.setup(self.actor, self.device) # ----- Optimizer ----------------------------------------------------- self.pi_optimizer = optim.Adam(self.actor.policy_net.parameters(), lr=lr_pi) self.v_optimizer = optim.Adam(self.actor.value_net1.parameters(), lr=lr_v)
[docs] @classmethod def create_factory(cls, lr_v=1e-4, lr_pi=1e-4, gamma=0.99, test_every=5000, max_grad_norm=0.5, num_test_episodes=5, policy_loss_addons=[]): """ Returns a function to create new A2C instances. Parameters ---------- lr_v : float Value network learning rate. lr_pi : float Policy network learning rate. gamma : float Discount factor parameter. num_test_episodes : int Number of episodes to complete in each test phase. max_grad_norm : float Gradient clipping parameter. test_every : int Regularity of test evaluations in actor updates. num_test_episodes : int Number of episodes to complete in each test phase. policy_loss_addons : list List of PolicyLossAddOn components adding loss terms to the algorithm policy loss. Returns ------- create_algo_instance : func Function that creates a new A2C class instance. algo_name : str Name of the algorithm. """ def create_algo_instance(device, actor, envs): return cls(lr_pi=lr_pi, lr_v=lr_v, envs=envs, actor=actor, gamma=gamma, device=device, test_every=test_every, max_grad_norm=max_grad_norm, num_test_episodes=num_test_episodes, policy_loss_addons=policy_loss_addons) return create_algo_instance, prl.A2C
@property def gamma(self): """Returns discount factor gamma.""" return self._gamma @property def start_steps(self): """Returns the number of steps to collect with initial random policy.""" return self._start_steps @property def num_epochs(self): """ Returns the number of times the whole buffer is re-used before data collection proceeds. """ return self._num_epochs @property def update_every(self): """ Returns the number of data samples collected between network update stages. """ return self._update_every @property def num_mini_batch(self): """ Returns the number of times the whole buffer is re-used before data collection proceeds. """ return self._num_mini_batch @property def mini_batch_size(self): """ Returns the number of mini batches per epoch. """ return self._mini_batch_size @property def test_every(self): """Number of network updates between test evaluations.""" return self._test_every @property def num_test_episodes(self): """ Returns the number of episodes to complete when testing. """ return self._num_test_episodes
[docs] def acting_step(self, obs, rhs, done, deterministic=False): """ A2C acting function. Parameters ---------- obs: torch.tensor Current world observation rhs: torch.tensor RNN recurrent hidden state (if policy is not a RNN, rhs will contain zeroes). done: torch.tensor 1.0 if current obs is the last one in the episode, else 0.0. deterministic: bool Whether to randomly sample action from predicted distribution or take the mode. Returns ------- action: torch.tensor Predicted next action. clipped_action: torch.tensor Predicted next action (clipped to be within action space). rhs: torch.tensor Policy recurrent hidden state (if policy is not a RNN, rhs will contain zeroes). other: dict Additional A2C predictions, value score and action log probability. """ with torch.no_grad(): (action, clipped_action, logp_action, rhs, entropy_dist, dist) = self.actor.get_action( obs, rhs, done, deterministic) value_dict = self.actor.get_value(obs, rhs, done) value = value_dict.get("value_net1") rhs = value_dict.get("rhs") other = {prl.VAL: value, prl.LOGP: logp_action} return action, clipped_action, rhs, other
[docs] def compute_loss(self, data): """ Calculate A2C loss Parameters ---------- data: dict Data batch dict containing all required tensors to compute A2C loss. Returns ------- loss : torch.tensor A2C loss. """ o, rhs, a, old_v = data[prl.OBS], data[prl.RHS], data[prl.ACT], data[prl.VAL] r, d, old_logp, adv = data[prl.RET], data[prl.DONE], data[prl.LOGP], data[prl.ADV] # Policy loss logp, dist_entropy, dist = self.actor.evaluate_actions(o, rhs, d, a) pi_loss = - (logp * adv).mean() # Extend policy loss with addons addons_info = {} for addon in self.policy_loss_addons: addon_loss, addons_info = addon.compute_loss_term(data, dist, addons_info) pi_loss += addon_loss # Value loss new_v = self.actor.get_value(o, rhs, d).get("value_net1") value_loss = (r - new_v).pow(2).mean() return pi_loss, value_loss, addons_info
[docs] def compute_gradients(self, batch, grads_to_cpu=True): """ Compute loss and compute gradients but don't do optimization step, return gradients instead. Parameters ---------- data: dict data batch containing all required tensors to compute A2C loss. grads_to_cpu: bool If gradient tensor will be sent to another node, need to be in CPU. Returns ------- grads: list of tensors List of actor gradients. info: dict Dict containing current A2C iteration information. """ # Compute A2C losses action_loss, value_loss, addons_info = self.compute_loss(batch) # Compute policy gradients self.pi_optimizer.zero_grad() action_loss.backward(retain_graph=True) for p in self.actor.policy_net.parameters(): p.requires_grad = False # Compute value gradients self.v_optimizer.zero_grad() value_loss.backward() for p in self.actor.policy_net.parameters(): p.requires_grad = True # Clip gradients to max value nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm) pi_grads = get_gradients(self.actor.policy_net, grads_to_cpu=grads_to_cpu) v_grads = get_gradients(self.actor.value_net1, grads_to_cpu=grads_to_cpu) grads = {"pi_grads": pi_grads, "v_grads": v_grads} info = { "value_loss": value_loss.item(), "action_loss": action_loss.item(), } info.update(addons_info) return grads, info
[docs] def apply_gradients(self, gradients=None): """ Take an optimization step, previously setting new gradients if provided. Parameters ---------- gradients: list of tensors List of actor gradients. """ if gradients is not None: set_gradients( self.actor.policy_net, gradients=gradients["pi_grads"], device=self.device) set_gradients( self.actor.value_net1, gradients=gradients["v_grads"], device=self.device) self.pi_optimizer.step() self.v_optimizer.step() self.iter += 1
[docs] def set_weights(self, actor_weights): """ Update actor with the given weights. Parameters ---------- actor_weights: dict of tensors Dict containing actor weights to be set. """ self.actor.load_state_dict(actor_weights) self.iter += 1
[docs] def update_algorithm_parameter(self, parameter_name, new_parameter_value): """ If `parameter_name` is an attribute of the algorithm, change its value to `new_parameter_value value`. Parameters ---------- parameter_name : str Worker.algo attribute name new_parameter_value : int or float New value for `parameter_name`. """ if hasattr(self, parameter_name): setattr(self, parameter_name, new_parameter_value) if parameter_name == "lr_v": for param_group in self.v_optimizer.param_groups: param_group['lr'] = new_parameter_value elif parameter_name == "lr_pi": for param_group in self.pi_optimizer.param_groups: param_group['lr'] = new_parameter_value