import torch
import numpy as np
from collections import deque
import pytorchrl as prl
from pytorchrl.agent.storages.off_policy.replay_buffer import ReplayBuffer as S
def dim0_reshape(tensor, size):
    """
    Swap the two leading axes of `tensor`, crop, and flatten them into one.

    The first two axes are exchanged, only the first `size` entries along
    the (new) second axis are kept, and the two leading axes are then merged.
    For an input indexed as ``tensor[i, j]`` the rows of the output are
    therefore ordered with ``j`` varying slowest:

        (i=0, j=0), (i=1, j=0), ..., (i=size-1, j=0),
        (i=0, j=1), (i=1, j=1), ..., (i=size-1, j=1),
        ...

    Entries with ``i >= size`` are discarded.

    Parameters
    ----------
    tensor : array_like
        Array with at least two dimensions.
    size : int
        Number of leading entries of the original first axis to keep.

    Returns
    -------
    reshaped : numpy.ndarray
        Array of shape ``(tensor.shape[1] * min(size, tensor.shape[0]),
        *tensor.shape[2:])``.
    """
    # asarray is a no-op for ndarrays and lets plain array-likes through.
    tensor = np.asarray(tensor)
    swapped = np.moveaxis(tensor, [0, 1], [1, 0])
    return swapped[:, 0:size].reshape(-1, *tensor.shape[2:])
class NStepReplayBuffer(S):
    """
    Storage class for Off-Policy with multi step learning (https://arxiv.org/abs/1710.02298).

    Incoming transitions are held in small per-key queues of length
    `n_step`. Once a queue is full, the oldest observation/action is written
    to the buffer together with the truncated n-step return and the newest
    next-observation.

    Parameters
    ----------
    size : int
        Storage capacity along time axis.
    device : torch.device
        CPU or specific GPU where data tensors will be placed and class
        computations will take place. Should be the same device where the
        actor model is located.
    actor : Actor
        Actor class instance.
    algorithm : Algorithm
        Algorithm class instance. Its `gamma` attribute is read and its
        `_update_every` attribute is increased by ``n_step - 1``.
    envs : VecEnv
        Vector of environments instance.
    n_step : int
        Number of future steps used to compute the truncated n-step return
        value. Integral floats (e.g. ``3.0``) are accepted and converted.

    Raises
    ------
    ValueError
        If `n_step` is not a whole number greater than or equal to 1.
    """

    # Data fields to store in buffer and contained in the generated batches
    storage_tensors = prl.OffPolicyDataKeys

    def __init__(self, size, device, actor, algorithm, envs, n_step=1):
        # Validate before touching `algorithm` so a bad value leaves it intact.
        # deque(maxlen=...) and range(...) below both require a true int.
        steps = int(n_step)
        if steps != n_step or steps < 1:
            raise ValueError(
                "n_step must be a whole number >= 1, got {!r}".format(n_step))
        n_step = steps

        # NOTE(review): presumably compensates for the extra insertions needed
        # before an n-step transition can be stored - confirm with algorithm.
        algorithm._update_every += n_step - 1

        super(NStepReplayBuffer, self).__init__(
            size=size, device=device, actor=actor, algorithm=algorithm, envs=envs)

        self.n_step = n_step
        self.gamma = algorithm.gamma
        # One sliding window per data key, holding at most n_step samples.
        self.n_step_buffer = {k: deque(maxlen=n_step) for k in self.storage_tensors}

    @classmethod
    def create_factory(cls, size, n_step=1):
        """
        Returns a function that creates NStepReplayBuffer instances.

        Parameters
        ----------
        size : int
            Storage capacity along time axis.
        n_step : int
            Number of future steps used to compute the truncated n-step
            return value.

        Returns
        -------
        create_buffer_instance : func
            creates a new NStepReplayBuffer class instance.
        """

        def create_buffer(device, actor, algorithm, envs):
            """Create and return a NStepReplayBuffer instance."""
            return cls(size, device, actor, algorithm, envs, n_step)

        return create_buffer

    def _write_at_step(self, key, value):
        """Move `value` (tensor or dict of tensors) to CPU and store it at the current step."""
        if isinstance(value, dict):
            for name, tensor in value.items():
                self.data[key][name][self.step] = tensor.cpu()
        else:
            self.data[key][self.step] = value.cpu()

    def insert_transition(self, sample):
        """
        Store new transition sample.

        The sample is first pushed into the n-step queues. Nothing is written
        to the buffer (and `step`/`size` do not advance) until the queues
        hold `n_step` samples.

        Parameters
        ----------
        sample : dict
            Data sample (containing all tensors of an environment transition)
        """
        # Data tensors lazy initialization
        if self.size == 0 and prl.OBS not in self.data.keys():
            self.init_tensors(sample)

        # If using memory, save fixed length consecutive overlapping sequences
        if self.recurrent_actor and self.step % self.sequence_length == 0 and self.step != 0:
            next_seq_overlap = self.get_data_slice(self.step - self.overlap_length, self.step)
            self.insert_data_slice(next_seq_overlap)

        # Add obs, rhs, done, act and rew to n_step buffer
        for k in (prl.OBS, prl.REW, prl.ACT, prl.RHS, prl.DONE):
            self.n_step_buffer[k].append(sample[k])

        # Not enough history yet to form an n-step transition.
        if len(self.n_step_buffer[prl.OBS]) != self.n_step:
            return

        # Add obs2, rhs2 and done2 directly (taken from the newest sample)
        for k in (prl.OBS2, prl.RHS2, prl.DONE2):
            if k == prl.RHS2 and not self.recurrent_actor:
                continue
            self._write_at_step(k, sample[k])

        # Compute done and rew; this also pops the oldest REW and DONE entries
        ret, done = self._nstep_return()
        self.data[prl.REW][self.step] = ret
        self.data[prl.DONE][self.step] = done

        # Get obs, rhs and act from step buffer (oldest entry of each window)
        for k in (prl.OBS, prl.RHS, prl.ACT):
            if k == prl.RHS and not self.recurrent_actor:
                continue
            self._write_at_step(k, self.n_step_buffer[k].popleft())

        # Update write position (circular) and number of stored steps
        self.step = (self.step + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def _nstep_return(self):
        """
        Computes truncated n-step returns.

        Works backwards through the reward window with
        ``ret_i = rew_i + gamma * (1 - done_{i+1}) * ret_{i+1}``, then removes
        the oldest reward and done entries so the window can slide.

        Returns
        -------
        ret : torch.Tensor
            Next sample returns, to store in buffer (moved to CPU).
        done : torch.Tensor
            Sum of the done flags in the window, to store in buffer (moved
            to CPU).
        """
        rewards = self.n_step_buffer[prl.REW]
        dones = self.n_step_buffer[prl.DONE]

        # clone() so the in-window tensors are never modified in place.
        ret = rewards[self.n_step - 1].clone()
        done = dones[self.n_step - 1].clone()

        for i in reversed(range(self.n_step - 1)):
            ret = ret * self.gamma * (1 - dones[i + 1]) + rewards[i]
            # NOTE(review): this is a sum, so it can exceed 1 when several
            # flags in the window are set - confirm consumers expect that.
            done = done + dones[i]

        rewards.popleft()
        dones.popleft()

        return ret.cpu(), done.cpu()

    def _as_device_tensor(self, array):
        """Convert `array` to a float32 torch tensor placed on `self.device`."""
        return torch.as_tensor(array, dtype=torch.float32).to(self.device)

    def _sample_sequence_batch(self, num_proc, mini_batch_size):
        """Sample a batch of whole sequences, as needed to feed a recurrent actor."""
        seq_len = self.sequence_length
        sequences_x_batch = mini_batch_size // seq_len + 1
        assert self.size % seq_len == 0, \
            "Buffer does not contain an integer number of complete rollout sequences"

        # Randomly select sequences
        seq_idxs = np.random.randint(
            0, num_proc * (self.size // seq_len), size=sequences_x_batch)

        # First flat index of every sequence, and all flat indexes they cover
        seq_starts = seq_idxs * seq_len
        idxs = (seq_starts[:, None] + np.arange(seq_len)).reshape(-1)

        # Fill up batch with data
        batch = {}
        for k, v in self.data.items():
            # Only first recurrent state in each sequence needed
            positions = seq_starts if k in (prl.RHS, prl.RHS2) else idxs
            if isinstance(v, dict):
                batch[k] = {
                    x: self._as_device_tensor(dim0_reshape(y, self.size)[positions])
                    for x, y in v.items()}
            else:
                batch[k] = self._as_device_tensor(dim0_reshape(v, self.size)[positions])
        return batch

    def _sample_transition_batch(self, num_proc, mini_batch_size):
        """Sample a batch of independent transitions, for a non-recurrent actor."""
        samples = np.random.randint(0, num_proc * self.size, size=mini_batch_size)

        batch = {}
        for k, v in self.data.items():
            # Recurrent states are not sampled here: only entry 0 is returned.
            if k in (prl.RHS, prl.RHS2):
                size, idxs = 1, np.array([0])
            else:
                size, idxs = self.size, samples

            if isinstance(v, dict):
                batch[k] = {
                    x: self._as_device_tensor(y[0:size].reshape(-1, *y.shape[2:])[idxs])
                    for x, y in v.items()}
            else:
                batch[k] = self._as_device_tensor(v[0:size].reshape(-1, *v.shape[2:])[idxs])
        return batch

    def generate_batches(self, num_mini_batch, mini_batch_size, num_epochs=1):
        """
        Returns a batch iterator to update actor.

        Parameters
        ----------
        num_mini_batch : int
            Number mini batches per epoch.
        mini_batch_size : int
            Number of samples contained in each mini batch.
        num_epochs : int
            Number of epochs. Not used by this implementation.

        Yields
        ------
        batch : dict
            Generated data batches. Besides the data tensors, every batch
            carries the key "n_step" with the buffer's `n_step` value.
        """
        num_proc = self.data[prl.DONE].shape[1]

        for _ in range(num_mini_batch):
            if self.recurrent_actor:  # Batches to feed a recurrent actor
                batch = self._sample_sequence_batch(num_proc, mini_batch_size)
            else:
                batch = self._sample_transition_batch(num_proc, mini_batch_size)

            batch.update({"n_step": self.n_step})
            yield batch

    def update_storage_parameter(self, parameter_name, new_parameter_value):
        """
        If `parameter_name` is an attribute of this storage, change its value
        to `new_parameter_value`. Unknown names are ignored.

        Parameters
        ----------
        parameter_name : str
            Attribute name
        new_parameter_value : int or float
            New value for `parameter_name`.
        """
        if hasattr(self, parameter_name):
            if parameter_name == "max_size" and self.recurrent_actor:
                # Round down to a whole number of sequences.
                new_parameter_value = (new_parameter_value // self.sequence_length) * self.sequence_length
                # NOTE(review): nesting of the next two statements was
                # reconstructed from a source without indentation; they are
                # assumed to apply only to recurrent "max_size" - confirm.
                new_parameter_value *= 2
                new_parameter_value += self.n_step
            setattr(self, parameter_name, new_parameter_value)