Source code for pytorchrl.agent.storages.on_policy.vtrace_buffer

import torch
import pytorchrl as prl
from pytorchrl.agent.storages.on_policy.vanilla_on_policy_buffer import VanillaOnPolicyBuffer as B


[docs]class VTraceBuffer(B): """ Storage class for On-Policy algorithms with off-policy correction method V-trace (https://arxiv.org/abs/1506.02438). Parameters ---------- size : int Storage capacity along time axis. device: torch.device CPU or specific GPU where data tensors will be placed and class computations will take place. Should be the same device where the actor model is located. envs : VecEnv Vector of environments instance. actor : Actor Actor class instance. algorithm : Algorithm Algorithm class instance. """ # Data fields to store in buffer and contained in generated batches storage_tensors = prl.OnPolicyDataKeys def __init__(self, size, device, actor, algorithm, envs): super(VTraceBuffer, self).__init__( size=size, envs=envs, device=device, actor=actor, algorithm=algorithm)
[docs] def before_gradients(self): """ Before updating actor policy model, compute returns and advantages. """ last_tensors = {} step = self.step if self.step != 0 else -1 for k in (prl.OBS, prl.RHS, prl.DONE): if isinstance(self.data[k], dict): last_tensors[k] = {x: self.data[k][x][step] for x in self.data[k]} else: last_tensors[k] = self.data[k][step - 1] with torch.no_grad(): _ = self.actor.get_action(last_tensors[prl.OBS], last_tensors[prl.RHS], last_tensors[prl.DONE]) value_dict = self.actor.get_value(last_tensors[prl.OBS], last_tensors[prl.RHS], last_tensors[prl.DONE]) next_value = value_dict.get("value_net1") next_rhs = value_dict.get("rhs") self.data[prl.RET][step].copy_(next_value) self.data[prl.VAL][step].copy_(next_value) if isinstance(next_rhs, dict): for x in self.data[prl.RHS]: self.data[prl.RHS][x][step].copy_(next_rhs[x]) else: self.data[prl.RHS][step] = next_rhs self.compute_returns() self.compute_vtrace()
[docs] @torch.no_grad() def get_updated_action_log_probs(self): """ Computes new log probabilities of actions stored in `storage` according to current `actor` version. It also uses the current `actor` version to update the value predictions. """ l, num_proc = self.data[prl.DONE].shape[0:2] l = self.step if self.step != 0 else self.max_size # Create batches without shuffling data batches = self.generate_batches( self.algo.num_mini_batch, self.algo.mini_batch_size, num_epochs=1, shuffle=False) # Obtain new value and log probability predictions new_val = [] new_logp = [] for batch in batches: obs, rhs, act, done = batch[prl.OBS], batch[prl.RHS], batch[prl.ACT], batch[prl.DONE] (logp, _, _) = self.actor.evaluate_actions(obs, rhs, done, act) val = self.actor.get_value(obs, rhs, done).get("value_net1") new_val.append(val) new_logp.append(logp) # Concatenate results if self.actor.is_recurrent: new_val = [p.view(l, num_proc // self.algo.num_mini_batch, -1) for p in new_val] self.data[prl.VAL][:-1] = torch.cat(new_val, dim=1) new_logp = [p.view(l, num_proc // self.algo.num_mini_batch, -1) for p in new_logp] new_logp = torch.cat(new_logp, dim=1) else: self.data[prl.VAL][:-1] = torch.cat(new_val, dim=0).view(l, num_proc, 1) new_logp = torch.cat(new_logp, dim=0).view(l, num_proc, 1) return new_logp
[docs] @torch.no_grad() def compute_vtrace(self, clip_rho_thres=1.0, clip_c_thres=1.0): """ Computes V-trace target values and advantage predictions and stores them, along with the updated action log probabilities, in `storage`. Parameters ---------- clip_rho_thres : float V-trace rho threshold parameter. clip_c_thres : float V-trace c threshold parameter. """ l = self.step if self.step != 0 else self.max_size new_action_log_probs = self.get_updated_action_log_probs() log_rhos = (new_action_log_probs - self.data[prl.LOGP][:l]) clipped_rhos = torch.clamp(torch.exp(log_rhos), max=clip_rho_thres) clipped_cs = torch.clamp(torch.exp(log_rhos), max=clip_c_thres) deltas = clipped_rhos * ( self.data[prl.RET][:-1] + self.algo.gamma * self.data[prl.VAL][1:] - self.data[prl.VAL][:-1]) acc = torch.zeros_like(self.data[prl.VAL][-1]) result = [] for i in reversed(range(l)): acc = deltas[i] + self.algo.gamma * clipped_cs[i] * acc * ( 1 - self.data[prl.DONE][i + 1]) result.append(acc) result.reverse() result.append(torch.zeros_like(self.data[prl.VAL][-1])) vs_minus_v_xs = torch.stack(result) vs = torch.add(vs_minus_v_xs, self.data[prl.VAL]) adv = clipped_rhos * (self.data[prl.RET][:-1] + self.algo.gamma * vs[1:] - self.data[prl.VAL][:-1]) self.data[prl.RET] = vs self.data[prl.LOGP] = new_action_log_probs self.data[prl.ADV] = (adv - adv.mean()) / (adv.std() + 1e-8)