# Source code for pytorchrl.agent.storages.off_policy.replay_buffer

import numpy as np
import torch

import pytorchrl as prl
from pytorchrl.agent.storages.base import Storage as S


def dim0_reshape(tensor, size):
    """
    Flatten the first two axes of `tensor` axis-1-major, keeping only the
    first `size` entries along axis 0.

    Axes 0 and 1 are swapped, axis 1 of the swapped array is truncated to
    `size`, and the two leading axes are merged. For an input indexed as
    (step, process, ...) the flat output index is therefore
    ``process * size + step``:

        process 0 -> rows 0, 1, ..., size - 1
        process 1 -> rows size, size + 1, ..., 2 * size - 1
        ...

    Parameters
    ----------
    tensor : np.ndarray
        Array with at least two dimensions.
    size : int
        Number of leading entries along axis 0 to keep.

    Returns
    -------
    np.ndarray
        Array of shape (tensor.shape[1] * min(size, tensor.shape[0]),
        *tensor.shape[2:]).
    """
    trailing_shape = tensor.shape[2:]
    per_process = np.swapaxes(tensor, 0, 1)[:, :size]
    return per_process.reshape(-1, *trailing_shape)
class ReplayBuffer(S):
    """
    Storage class for Off-Policy algorithms.

    Transitions are kept in a circular buffer of float32 numpy arrays with
    ``max_size`` rows along the time axis. The arrays are allocated lazily
    from the first inserted sample. When the actor is recurrent, data is
    stored as fixed-length sequences of ``sequence_length`` steps. Each new
    sequence starts with a copy of the last ``overlap_length`` steps of the
    previous one.

    Parameters
    ----------
    size : int
        Storage capacity along time axis.
    device : torch.device
        CPU or specific GPU where data tensors will be placed and class
        computations will take place. Should be the same device where the
        actor model is located.
    actor : Actor
        Actor class instance.
    algorithm : Algorithm
        Algorithm class instance
    envs : VecEnv
        Vector of environments instance.
    """

    # Data fields to store in buffer and contained in the generated batches
    storage_tensors = prl.OffPolicyDataKeys

    def __init__(self, size, device, actor, algorithm, envs):

        self.actor = actor
        self.device = device
        self.algo = algorithm
        self.recurrent_actor = actor.is_recurrent
        self.max_size, self.size, self.step = size, 0, 0
        self.data = {}  # lazy init

        if self.recurrent_actor:
            self.sequence_length = algorithm.update_every
            self.overlap_length = int(actor.sequence_overlap * self.sequence_length)
            self.non_overlap_length = self.sequence_length - self.overlap_length
            # Round capacity down to a whole number of sequences...
            self.max_size = (self.max_size // self.sequence_length) * self.sequence_length
            self.max_size *= 2  # to account for consecutive overlapping of sequences
            self.staged_overlap = None

        self.reset()

    @classmethod
    def create_factory(cls, size):
        """
        Returns a function that creates ReplayBuffer instances.

        Parameters
        ----------
        size : int
            Storage capacity along time axis.

        Returns
        -------
        create_buffer_instance : func
            creates a new ReplayBuffer class instance.
        """
        def create_buffer(device, actor, algorithm, envs):
            """Create and return a ReplayBuffer instance."""
            return cls(size, device, actor, algorithm, envs)
        return create_buffer

    def init_tensors(self, sample):
        """
        Lazy initialization of data tensors from a sample.

        Allocate one zero-filled float32 array for every key of
        ``storage_tensors`` that is present (and not None) in `sample`. Each
        array has ``max_size`` rows followed by the sample's own shape, and
        dict-valued entries get one such array per sub-key.

        Parameters
        ----------
        sample : dict
            Data sample (containing all tensors of an environment transition)
        """
        for k, v in sample.items():
            if k not in self.storage_tensors or v is None:
                continue

            # Non-recurrent actors never write recurrent states (see
            # insert_transition), so a single placeholder row is enough.
            if not self.recurrent_actor and k in (prl.RHS, prl.RHS2):
                size = 1
            else:
                size = self.max_size

            if isinstance(v, dict):
                self.data[k] = {
                    x: np.zeros((size, *y.shape), dtype=np.float32)
                    for x, y in v.items()}
            else:
                self.data[k] = np.zeros((size, *v.shape), dtype=np.float32)

    def get_data_slice(self, start_pos, end_pos):
        """
        Makes a copy of all tensors in the buffer between steps `start_pos`
        and `end_pos`.

        Plain numpy slicing is used, so the slice does not wrap around the
        end of the circular buffer.

        Parameters
        ----------
        start_pos : int
            initial slice position.
        end_pos : int
            final slice position.

        Returns
        -------
        data : dict
            data slice copied from the buffer. Keys of ``storage_tensors``
            without stored data map to None.
        """
        copied_data = {k: None for k in self.storage_tensors}
        for k, v in self.data.items():
            if v is None:
                continue
            if isinstance(v, dict):
                copied_data[k] = {x: np.copy(y[start_pos:end_pos]) for x, y in v.items()}
            else:
                copied_data[k] = np.copy(v[start_pos:end_pos])
        return copied_data

    def get_all_buffer_data(self, data_to_cpu=False):
        """
        Return all currently stored data, up to the current step.

        If data_to_cpu, no need to do anything since data tensors are already
        in cpu memory. The returned arrays are slices (views) of the buffer
        arrays, not copies.

        Parameters
        ----------
        data_to_cpu : bool
            Whether or not to move data tensors to cpu memory. For recurrent
            actors it also stages the ending of the last sequence, so that
            ``reset`` can re-insert it at the beginning of the storage.

        Returns
        -------
        data : dict
            data currently stored in the buffer.
        """
        # If recurrent, get only whole sequences
        if self.recurrent_actor:
            idx = int((self.step // self.sequence_length) * self.sequence_length)
        else:
            idx = self.step

        data = {}
        for k, v in self.data.items():
            if v is None:
                data[k] = None
            elif isinstance(v, dict):
                data[k] = {x: y[:idx] for x, y in v.items()}
            else:
                data[k] = v[:idx]

        # If self.actor uses RNNs, save ending of last sequence
        if self.recurrent_actor and data_to_cpu:
            self.staged_overlap = self.get_data_slice(
                self.step - self.overlap_length, self.step)

        return data

    def reset(self):
        """
        Subtract the current step from size and move the write position back
        to zero.

        If self.actor uses RNNs, add overlap slice of last sequence before
        reset at the beginning of the storage.
        """
        self.size -= self.step
        self.step = 0

        if self.recurrent_actor and self.staged_overlap:
            self.insert_data_slice(self.staged_overlap)

    def insert_data_slice(self, new_data):
        """
        Appends new_data to currently stored data.

        Writing starts at the current step and wraps around the end of the
        buffer if needed. Afterwards the step and size counters are advanced
        by the slice length.

        Parameters
        ----------
        new_data : dict
            Dictionary of env transition samples to be added to self.data.
            All non-None tensors must have the same length along axis 0.
        """
        lengths = []
        for k, v in new_data.items():
            if v is None:
                continue
            # Recurrent states are not stored for non-recurrent actors
            if not self.recurrent_actor and k in (prl.RHS, prl.RHS2):
                continue
            if isinstance(v, dict):
                # .get: the entry may not have been allocated yet
                if self.data.get(k) is None:
                    self.data[k] = {i: None for i in v}
                for x, y in v.items():
                    lengths.append(self.insert_single_tensor_slice(self.data[k], x, y))
            else:
                lengths.append(self.insert_single_tensor_slice(self.data, k, v))

        assert len(set(lengths)) == 1, \
            "new_data must contain at least one tensor, all of the same length"
        length = lengths[0]
        self.step = (self.step + length) % self.max_size
        self.size = min(self.size + length, self.max_size)

    def insert_single_tensor_slice(self, tensor_storage, tensor_key, tensor_values):
        """
        Appends tensor_value to buffer dict using tensor_key as key.

        The values are written starting at ``self.step`` and wrap around to
        index 0 when they do not fit before ``max_size``. This method does not
        advance ``self.step``.

        Parameters
        ----------
        tensor_storage : dict
            dict holding the buffer arrays (``self.data`` or one of its
            dict-valued entries). A missing or None entry is allocated with
            ``max_size`` rows.
        tensor_key : str
            key to use to store the tensor.
        tensor_values : np.ndarray
            tensor values.

        Returns
        -------
        l : int
            length (time axis) of the tensor added to the buffer.
        """
        length = tensor_values.shape[0]

        # If not defined (or not yet allocated), initialize tensor
        if tensor_storage.get(tensor_key) is None:
            tensor_storage[tensor_key] = np.zeros(
                (self.max_size, *tensor_values.shape[1:]), dtype=np.float32)
        storage = tensor_storage[tensor_key]

        if self.step + length <= self.max_size:
            # If enough space, add tensor at the end
            storage[self.step:self.step + length] = tensor_values
        else:
            # Circular buffer: fill up to the end, then continue from index 0
            head = self.max_size - self.step
            storage[self.step:self.max_size] = tensor_values[0:head]
            storage[0:length - head] = tensor_values[head:]

        return length

    def insert_transition(self, sample):
        """
        Store new transition sample.

        Parameters
        ----------
        sample : dict
            Data sample (containing all tensors of an environment transition)
        """
        # Data tensors lazy initialization
        if self.size == 0 and prl.OBS not in self.data:
            self.init_tensors(sample)

        # If using memory, save fixed length consecutive overlapping sequences:
        # at each sequence boundary, copy the last overlap_length steps so the
        # next sequence starts with them.
        if (self.recurrent_actor and self.step % self.sequence_length == 0
                and self.step != 0 and self.overlap_length > 0):
            next_seq_overlap = self.get_data_slice(
                self.step - self.overlap_length, self.step)
            self.insert_data_slice(next_seq_overlap)

        # Insert
        for k, v in sample.items():
            if k not in self.storage_tensors:
                continue
            if not self.recurrent_actor and k in (prl.RHS, prl.RHS2):
                continue
            if isinstance(v, dict):
                for x, y in v.items():
                    self.data[k][x][self.step] = y.cpu()
            else:
                self.data[k][self.step] = v.cpu()

        # Update
        self.step = (self.step + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def before_gradients(self):
        """
        Steps required before updating actor policy model.
        """
        pass

    def after_gradients(self, batch, info):
        """
        Steps required after updating actor policy model

        Parameters
        ----------
        batch : dict
            Data batch used to compute the gradients.
        info : dict
            Additional relevant info from gradient computation.

        Returns
        -------
        info : dict
            info dict updated with relevant info from Storage.
        """
        return info

    def generate_batches(self, num_mini_batch, mini_batch_size, num_epochs=1):
        """
        Returns a batch iterator to update actor.

        Behavior depends on the actor type:

        - Non-recurrent actors: each batch holds `mini_batch_size`
          transitions, sampled uniformly with replacement from the first
          ``size`` steps of every process.
        - Recurrent actors: each batch holds
          ``mini_batch_size // sequence_length + 1`` whole sequences, with
          only the first recurrent state of each sequence.

        Parameters
        ----------
        num_mini_batch : int
            Number mini batches per epoch.
        mini_batch_size : int
            Number of samples contained in each mini batch.
        num_epochs : int
            Number of epochs. Currently unused: exactly `num_mini_batch`
            batches are generated.

        Yields
        ------
        batch : dict
            Generated data batches.
        """
        num_proc = self.data[prl.DONE].shape[1]

        for _ in range(num_mini_batch):
            if self.recurrent_actor:
                yield self._sample_sequence_batch(num_proc, mini_batch_size)
            else:
                yield self._sample_transition_batch(num_proc, mini_batch_size)

    def _sample_sequence_batch(self, num_proc, mini_batch_size):
        """Sample a batch of whole, randomly chosen sequences (recurrent actors)."""
        seq_len = self.sequence_length
        sequences_x_batch = mini_batch_size // seq_len + 1
        assert self.size % seq_len == 0, \
            "Buffer does not contain an integer number of complete rollout sequences"

        # Randomly select sequences. Flat indices follow dim0_reshape's layout
        # (process-major), so sequence i covers rows i*seq_len .. (i+1)*seq_len-1.
        seq_idxs = np.random.randint(
            0, num_proc * (self.size // seq_len), size=sequences_x_batch)
        first_steps = seq_idxs * seq_len
        all_steps = (first_steps[:, None] + np.arange(seq_len)).reshape(-1)

        batch = {}
        for k, v in self.data.items():
            # Only first recurrent state in each sequence needed
            positions = first_steps if k in (prl.RHS, prl.RHS2) else all_steps
            if isinstance(v, dict):
                batch[k] = {
                    x: self._to_device_tensor(dim0_reshape(y, self.size)[positions])
                    for x, y in v.items()}
            else:
                batch[k] = self._to_device_tensor(dim0_reshape(v, self.size)[positions])
        return batch

    def _sample_transition_batch(self, num_proc, mini_batch_size):
        """Sample a batch of independent transitions (non-recurrent actors)."""
        # Flat indices over the first `size` steps of all processes, in the
        # time-major order produced by reshape(-1, ...) below.
        samples = np.random.randint(0, num_proc * self.size, size=mini_batch_size)

        batch = {}
        for k, v in self.data.items():
            if k in (prl.RHS, prl.RHS2):
                # Single placeholder row allocated by init_tensors
                size, idxs = 1, np.array([0])
            else:
                size, idxs = self.size, samples
            if isinstance(v, dict):
                batch[k] = {
                    x: self._to_device_tensor(y[0:size].reshape(-1, *y.shape[2:])[idxs])
                    for x, y in v.items()}
            else:
                batch[k] = self._to_device_tensor(v[0:size].reshape(-1, *v.shape[2:])[idxs])
        return batch

    def _to_device_tensor(self, array):
        """Convert a numpy array to a float32 torch tensor on ``self.device``."""
        return torch.as_tensor(array, dtype=torch.float32).to(self.device)

    def update_storage_parameter(self, parameter_name, new_parameter_value):
        """
        If `parameter_name` is an attribute of the storage, change its value
        to `new_parameter_value`.

        For a recurrent actor, a new ``max_size`` is adjusted as in
        ``__init__``: rounded down to a multiple of ``sequence_length`` and
        doubled.

        NOTE(review): arrays already allocated in ``self.data`` are not
        resized here. Increasing ``max_size`` after data has been stored can
        make later writes at ``self.step`` go past the end of those arrays.
        Confirm this is only called before insertion starts.

        Parameters
        ----------
        parameter_name : str
            Attribute name
        new_parameter_value : int or float
            New value for `parameter_name`.
        """
        if not hasattr(self, parameter_name):
            return
        if parameter_name == "max_size" and self.recurrent_actor:
            new_parameter_value = (new_parameter_value // self.sequence_length) * self.sequence_length
            new_parameter_value *= 2
        setattr(self, parameter_name, new_parameter_value)