import torch
import numpy as np
import pytorchrl as prl
from pytorchrl.agent.storages.off_policy.nstep_buffer import NStepReplayBuffer as B
[docs]def dim0_reshape(tensor, size):
"""
Reshapes tensor so indices are defined like this:
00, 01, 02, 03, 04, 05, 06, 07, 08, 09, size + 1, ..., self.max_size
10, 11, 12, 13, 14, 15, 16, 17, 18, 19, size + 1, ..., self.max_size
20, 21, 22, 23, 24, 25, 26, 27, 28, 29, size + 1, ..., self.max_size
"""
return np.moveaxis(tensor, [0, 1], [1, 0])[:, 0: size].reshape(-1, *tensor.shape[2:])
[docs]class PERBuffer(B):
"""
Storage class for Off-Policy algorithms using PER (https://arxiv.org/abs/1707.01495).
This component extends NStepReplayBuffer, enabling to combine PER with
n step learning. However, default n_step value is 1, which is equivalent
to not using n_step learning at all.
Parameters
----------
size : int
Storage capacity along time axis.
device : torch.device
CPU or specific GPU where data tensors will be placed and class
computations will take place. Should be the same device where the
actor model is located.
envs : VecEnv
Vector of environments instance.
actor : Actor
Actor class instance.
algorithm : Algorithm
Algorithm class instance.
n_step : int or float
Number of future steps used to computed the truncated n-step return value.
epsilon : float
PER epsilon parameter.
alpha : float
PER alpha parameter.
beta : float
PER beta parameter.
default_error : int or float
Default TD error value to use for newly added data samples.
"""
# Data fields to store in buffer and contained in the generated batches
storage_tensors = prl.OffPolicyDataKeys
def __init__(self, size, device, actor, algorithm, envs, n_step=1, epsilon=0.0, alpha=0.0, beta=1.0, default_error=1000000):
super(PERBuffer, self).__init__(
size=size, device=device, actor=actor,
algorithm=algorithm, envs=envs, n_step=n_step)
self.beta = beta
self.alpha = alpha
self.epsilon = epsilon
self.data["priority"] = None
self.error = default_error
[docs] @classmethod
def create_factory(cls, size, n_step=1, epsilon=0.0, alpha=0.0, beta=1.0, default_error=1000000):
"""
Returns a function that creates PERBuffer instances.
Parameters
----------
size : int
Storage capacity along time axis.
n_step : int or float
Number of future steps used to computed the truncated n-step return value.
epsilon : float
PER epsilon parameter.
alpha : float
PER alpha parameter.
beta : float
PER beta parameter.
default_error : int or float
Default TD error value to use for newly added data samples.
Returns
-------
create_buffer_instance : func
creates a new PERBuffer class instance.
"""
def create_buffer(device, actor, algorithm, envs):
"""Create and return a PERBuffer instance."""
return cls(size, device, actor, algorithm, envs, n_step, epsilon, alpha, beta, default_error)
return create_buffer
[docs] def get_priority(self, error):
"""Takes in the error of one or more examples and returns the proportional priority"""
return np.power(error + self.epsilon, self.alpha)
[docs] def get_sequence_priority(self, sequence_data, eta=0.9):
""" Get priority score for a given data sequence. """
term1 = eta * np.max(sequence_data, axis=0)
term2 = (1 - eta) * np.mean(sequence_data, axis=0)
priority = term1 + term2
return priority
[docs] def insert_transition(self, sample):
"""
Store new transition sample.
Parameters
----------
sample : dict
Data sample (containing all tensors of an environment transition)
"""
# Data tensors lazy initialization
if self.size == 0 and prl.OBS not in self.data.keys():
self.init_tensors(sample)
# if recurrent, save fixed-length consecutive overlapping sequences
if self.recurrent_actor and self.step % self.sequence_length == 0 and self.step != 0:
next_seq_overlap = self.get_data_slice(self.step - self.overlap_length, self.step)
self.insert_data_slice(next_seq_overlap)
# Add obs, rhs, done, act and rew to n_step buffer
self.n_step_buffer[prl.OBS].append(sample[prl.OBS])
self.n_step_buffer[prl.REW].append(sample[prl.REW])
self.n_step_buffer[prl.ACT].append(sample[prl.ACT])
self.n_step_buffer[prl.RHS].append(sample[prl.RHS])
self.n_step_buffer[prl.DONE].append(sample[prl.DONE])
if len(self.n_step_buffer[prl.OBS]) == self.n_step:
# Add obs2, rhs2 and done2 directly
for k in (prl.OBS2, prl.RHS2, prl.DONE2):
if not self.recurrent_actor and k == prl.RHS2:
continue
if isinstance(sample[k], dict):
for x, y in sample[k].items():
self.data[k][x][self.step] = y.cpu()
else:
self.data[k][self.step] = sample[k].cpu()
# Compute done and rew
(self.data[prl.REW][self.step],
self.data[prl.DONE][self.step]) = self._nstep_return()
# Get obs, rhs and act from step buffer
for k in (prl.OBS, prl.RHS, prl.ACT):
if not self.recurrent_actor and k == prl.RHS:
continue
tensor = self.n_step_buffer[k].popleft()
if isinstance(tensor, dict):
for x, y in tensor.items():
self.data[k][x][self.step] = y.cpu()
else:
self.data[k][self.step] = tensor.cpu()
# Update
self.step = (self.step + 1) % self.max_size
self.size = min(self.size + 1, self.max_size)
[docs] def before_gradients(self):
"""
Steps required before updating actor policy model.
"""
num_proc = self.data[prl.DONE].shape[1]
if self.data["priority"] is None: # Lazy initialization of priority
if self.recurrent_actor:
default_priority = self.get_sequence_priority(self.error * np.ones((self.sequence_length)))
else:
default_priority = self.error
self.data["priority"] = default_priority * np.ones((self.max_size, num_proc, 1))
[docs] def after_gradients(self, batch, info):
"""
Steps required after updating actor policy model
Parameters
----------
batch : dict
Data batch used to compute the gradients.
info : dict
Additional relevant info from gradient computation.
Returns
-------
info : dict
info dict updated with relevant info from Storage.
"""
if "per_weights" in batch.keys() and isinstance(batch["per_weights"], torch.Tensor):
assert "errors" in info[prl.ALGORITHM].keys(), "TD errors missing!"
if self.recurrent_actor:
endpos = int(self.size + self.sequence_length)
# Get data indices and td errors
idxs = np.array(batch.pop("idxs")).reshape(-1, self.sequence_length)
errors = info[prl.ALGORITHM]["errors"].reshape(-1, self.non_overlap_length)
# Since sequences overlap, update both current sequence and
# start of the next overlapping sequence
idxs = np.concatenate([idxs[:, -self.non_overlap_length:], idxs[:, -self.overlap_length:] + self.overlap_length], axis=1)
errors = torch.cat([errors, errors[:, -self.overlap_length:]], dim=1)
for i, e in zip(idxs, errors): # each sequence in the batch
# Assign priorities to both end of current sequence
# and start of next sequence
dim0_reshape(self.data["priority"], endpos)[i] = \
self.get_priority(e.unsqueeze(1))
# Update current sequence average priority
sequence = dim0_reshape(self.data["priority"], endpos)[
i - self.overlap_length]
pri = self.get_sequence_priority(sequence)
dim0_reshape(self.data["priority"], endpos)[
i - self.overlap_length] = pri * np.ones(sequence.shape)
# Update next sequence average if some overlap
if self.overlap_length > 0:
sequence = dim0_reshape(
self.data["priority"], endpos)[i + self.non_overlap_length]
pri = self.get_sequence_priority(sequence)
dim0_reshape(self.data["priority"], endpos)[
i + self.non_overlap_length] = pri * np.ones(sequence.shape)
else:
self.data["priority"][0:self.size].reshape(-1, *self.data["priority"].shape[2:])[
batch.pop("idxs")] = self.get_priority(info[prl.ALGORITHM]["errors"])
return info
[docs] def generate_batches(self, num_mini_batch, mini_batch_size, num_epochs=1):
"""
Returns a batch iterator to update actor.
Parameters
----------
num_mini_batch : int
Number mini batches per epoch.
mini_batch_size : int
Number of samples contained in each mini batch.
num_epochs : int
Number of epochs.
Yields
------
batch : dict
Generated data batches.
"""
num_proc = self.data[prl.DONE].shape[1]
for _ in range(num_mini_batch):
if self.recurrent_actor: # Batches to feed a recurrent actor
sequences_x_batch = mini_batch_size // self.sequence_length + 1
sequences_x_proc = int(self.size / self.sequence_length)
assert self.size % self.sequence_length == 0, \
"Buffer does not contain an integer number of complete rollout sequences"
# Define batch structure
batch = {k: {} if not isinstance(self.data[k], dict) else
{x: {} for x in self.data[k]} for k in self.data.keys()}
# Select sequences
if self.alpha == 0.0:
seq_idxs = np.random.randint(0, num_proc * sequences_x_proc, size=sequences_x_batch)
per_weigths = 1.0
else:
priors = dim0_reshape(self.data["priority"], self.size)
probs = priors / priors.sum()
per_weigths = np.power(num_proc * self.size * probs, - self.beta)
per_weigths = per_weigths / per_weigths.max()
# Trick to allow updating priorities of next overlapping
# sequences after gradient computation. Insert some "0.0"
# values in per_weights for end-of-row + 1 sequences.
per_weigths = np.split(per_weigths, num_proc)
per_weigths = [np.concatenate([chunk, np.zeros((
self.sequence_length, 1))]) for chunk in per_weigths]
per_weigths = np.concatenate(per_weigths)
# Trick to allow updating priorities of next overlapping
# sequences after gradient computation. Insert some "0.0"
# values in probs for end-of-row + 1 sequences.
probs = probs[self.sequence_length
- 1::self.sequence_length] * self.sequence_length
probs = np.split(probs, num_proc)
probs = [np.concatenate([chunk, np.zeros(
(1, 1))]) for chunk in probs]
ext_probs = np.concatenate(probs).squeeze(1)
ext_probs = ext_probs / ext_probs.sum() # sanity check
seq_idxs = np.random.choice(range(len(ext_probs)), size=sequences_x_batch, p=ext_probs)
# Get data indexes
idxs = []
for idx in seq_idxs:
idxs += range(idx * self.sequence_length, (idx + 1) * self.sequence_length)
if not isinstance(per_weigths, float):
per_weigths = torch.as_tensor(per_weigths[idxs], dtype=torch.float32).to(self.device)
# Fill up batch with data
for k, v in self.data.items():
# Only first recurrent state in each sequence needed
positions = seq_idxs * self.sequence_length if k in (prl.RHS, prl.RHS2) else idxs
if isinstance(v, dict):
for x, y in v.items():
t = dim0_reshape(y, self.size + self.sequence_length)[positions]
batch[k][x] = torch.as_tensor(t, dtype=torch.float32).to(self.device)
else:
t = dim0_reshape(v, self.size + self.sequence_length)[positions]
batch[k] = torch.as_tensor(t, dtype=torch.float32).to(self.device)
batch.update({"per_weights": per_weigths, "n_step": self.n_step, "idxs": idxs})
yield batch
else:
batch = {k: {} for k in self.data.keys()}
if self.alpha == 0.0:
samples = np.random.randint(0, num_proc * self.size, size=mini_batch_size)
per_weigths = 1.0
else:
priors = self.data["priority"][0:self.size].reshape(-1)
probs = priors / priors.sum()
samples = np.random.choice(range(num_proc * self.size), size=mini_batch_size, p=probs)
per_weigths = np.power(num_proc * self.size * probs, -self.beta)
per_weigths = torch.as_tensor(per_weigths / per_weigths.max(), dtype=torch.float32).to(self.device)
per_weigths = per_weigths.view(-1, 1)[samples]
for k, v in self.data.items():
if k in (prl.RHS, prl.RHS2):
size, idxs = 1, np.array([0])
else:
size, idxs = self.size, samples
if isinstance(v, dict):
for x, y in v.items():
batch[k][x] = torch.as_tensor(y[0:self.size].reshape(
-1, *y.shape[2:])[idxs], dtype=torch.float32).to(self.device)
else:
batch[k] = torch.as_tensor(v[0:self.size].reshape(
-1, *v.shape[2:])[idxs], dtype=torch.float32).to(self.device)
batch.update({"per_weights": per_weigths, "n_step": self.n_step, "idxs": idxs})
yield batch