Source code for pytorchrl.scheme.gradients.g_worker

import os
import math
import time
import queue
import threading
from shutil import copy2
from copy import deepcopy

import ray
import torch

import pytorchrl as prl
from pytorchrl.scheme.base.worker import Worker as W
from pytorchrl.scheme.utils import ray_get_and_free, broadcast_message, pack, unpack

# Maximum number of collected rollouts allowed to wait in a CollectorThread
# queue. Producers stop and sleep while the queue holds this many items,
# which bounds the allowed policy lag.
max_queue_size: int = 10


class GWorker(W):
    """
    Worker class handling gradient computation.

    This class wraps an actor instance, a storage class instance and a worker
    set of remote data collection workers. It receives data from the collection
    workers and computes gradients following a logic defined in function
    self.step(), which will be called from the Learner class.

    Parameters
    ----------
    index_worker : int
        Worker index.
    col_workers_factory : func
        A function that creates a set of data collection workers. It is called
        as ``col_workers_factory(device, initial_weights, index_worker)``.
    col_communication : str
        Communication coordination pattern for data collection.
    compress_grads_to_send : bool
        Whether or not to compress gradients before sending them to the update
        worker.
    col_execution : str
        Execution patterns for data collection.
    col_fraction_workers : float
        Minimum fraction of samples required to stop if collection is
        synchronously coordinated and most workers have finished their
        collection task.
    initial_weights : ray object ID
        Initial model weights.
    device : str
        "cpu" or specific GPU "cuda:number" to use for computation. Only
        honoured when CUDA is available; otherwise "cpu" is used.

    Attributes
    ----------
    index_worker : int
        Index assigned to this worker.
    iter : int
        Number of times gradients have been computed and sent.
    col_communication : str
        Communication coordination pattern for data collection.
    compress_grads_to_send : bool
        Whether or not to compress gradients before sending them to the update
        worker.
    processing_time_start : float or None
        Timestamp of the last time new collected data was pulled, used to
        compute the collection-to-processing time ratio.
    col_workers : CWorkerSet
        A CWorkerSet class instance.
    local_worker : CWorker
        col_workers local worker.
    remote_workers : List of CWorker's
        col_workers remote data collection workers.
    num_remote_workers : int
        Number of remote data collection workers.
    num_collection_workers : int
        Number of workers actually collecting data (1 if there are no remote
        collection workers).
    actor : Actor
        An actor class instance.
    algo : Algo
        An algorithm class instance.
    storage : Storage
        A Storage class instance.
    collector : CollectorThread
        Class handling data collection via col_workers and placing incoming
        rollouts into its input queue `collector.queue`. Only created when
        there are remote collection workers or the local worker has training
        environments.
    """

    def __init__(self,
                 index_worker,
                 col_workers_factory,
                 col_communication=prl.SYNC,
                 compress_grads_to_send=False,
                 col_execution=prl.CENTRAL,
                 col_fraction_workers=1.0,
                 initial_weights=None,
                 device=None):

        self.index_worker = index_worker
        super().__init__(index_worker)

        # Define counters and other attributes
        self.iter = 0
        self.col_communication = col_communication
        self.compress_grads_to_send = compress_grads_to_send
        self.processing_time_start = None

        # Computation device. Parenthesised to make the precedence explicit:
        # the requested device (default "cuda") is only used when CUDA is
        # available, otherwise computation falls back to "cpu".
        dev = (device or "cuda") if torch.cuda.is_available() else "cpu"

        # Create CWorkerSet instance
        self.col_workers = col_workers_factory(dev, initial_weights, index_worker)
        self.local_worker = self.col_workers.local_worker()
        self.remote_workers = self.col_workers.remote_workers()
        self.num_remote_workers = len(self.remote_workers)
        self.num_collection_workers = 1 if self.num_remote_workers == 0 else self.num_remote_workers

        # Get Actor Critic instance
        self.actor = self.local_worker.actor

        # Get Algorithm instance
        self.algo = self.local_worker.algo

        has_local_envs = self.local_worker.envs_train is not None

        # Get storage instance.
        if col_communication == prl.ASYNC and has_local_envs:
            # With asynchronous collection the local worker keeps collecting
            # into its own storage while this worker computes gradients, so
            # both need different storage instances to avoid overwriting data.
            # The collection-side storage is shrunk to the minimum required
            # size to save memory.
            # NOTE(review): a branch that resized remote workers' storages was
            # unreachable here (this condition already requires local envs) and
            # has been removed; confirm whether remote storages should also be
            # shrunk in parallel asynchronous mode.
            size1 = self.local_worker.storage.max_size
            size2 = self.local_worker.algo.update_every
            size2 = size2 * 2 if size2 is not None else float("Inf")
            new_size = min(size1, size2)

            # Deep-copy the storage, temporarily detaching its envs attribute,
            # which can not be deep-copied. Restore it even if the copy fails.
            envs = self.local_worker.storage.envs
            self.local_worker.storage.envs = None
            try:
                self.storage = deepcopy(self.local_worker.storage)
            finally:
                self.local_worker.storage.envs = envs
            self.local_worker.update_storage_parameter("max_size", new_size)
        else:
            self.storage = self.local_worker.storage

        if self.num_remote_workers > 0 or has_local_envs:
            # Create CollectorThread
            self.collector = CollectorThread(
                index_worker=index_worker,
                local_worker=self.local_worker,
                remote_workers=self.remote_workers,
                col_communication=col_communication,
                col_fraction_workers=col_fraction_workers,
                col_execution=col_execution,
                broadcast_interval=1)

        # Print worker information
        self.print_worker_info()

    @property
    def actor_version(self):
        """Number of times Actor has been updated."""
        return self.local_worker.actor_version

    def step(self, distribute_gradients=False):
        """
        Pulls data from `self.collector.queue`, then performs a gradient
        computation step.

        Parameters
        ----------
        distribute_gradients : bool
            If True, gradients will be directly shared across remote workers
            and optimization steps will be executed in a decentralised way.

        Returns
        -------
        data_to_send : tuple or packed tuple
            ``(grads, info)`` where grads is the list of actor gradients (None
            when distribute_gradients is True) and info is a summary dict of
            relevant gradient operation information. Packed with `pack` when
            self.compress_grads_to_send is True.
        """
        self.get_data()
        grads, info = self.get_grads(distribute_gradients)

        if distribute_gradients:
            self.apply_gradients()

        # Encode data if self.compress_grads_to_send is True
        if self.compress_grads_to_send:
            return pack((grads, info))
        return grads, info

    def get_data(self):
        """
        Pulls data from `self.collector.queue` and prepares batches to compute
        gradients.
        """
        try:
            # Try pulling next batch. On the first call self.batches does not
            # exist yet, which raises AttributeError.
            self.next_batch = next(self.batches)
        except (StopIteration, AttributeError):
            # Current batch generator exhausted (or not created yet)
            self._prepare_new_batches()
        else:
            # Only use episode information once
            self.info.pop(prl.EPISODES, None)
            # Only record collected samples once
            self.info[prl.NUMSAMPLES] = 0

    def _prepare_new_batches(self):
        """Load new data if required and create a fresh batch generator."""
        # Check if more batches with the same data are required. Data can only
        # be reused once some has been loaded (self.info exists).
        reuse_data = getattr(self.algo, "reuse_data", False)
        if not reuse_data or not hasattr(self, "info"):
            self._load_new_data()
        else:
            # Only record collected samples once
            self.info[prl.NUMSAMPLES] = 0

        # Generate batches
        self.batches = self.storage.generate_batches(
            self.algo.num_mini_batch, self.algo.mini_batch_size,
            self.algo.num_epochs)
        self.next_batch = next(self.batches)

    def _load_new_data(self):
        """Pull collected data from the collector queue into self.storage."""
        if self.col_communication == prl.SYNC:
            self.collector.step()

        data, self.info = self.collector.queue.get()

        # Calculate and log data collection-to-processing time ratio
        if self.processing_time_start is None:
            self.processing_time_start = time.time()
        else:
            processing_time = time.time() - self.processing_time_start
            self.info[prl.TIME][prl.PROCESSING] = processing_time
            if processing_time > 0:  # avoid ZeroDivisionError on coarse clocks
                self.info[prl.TIME][prl.CPRATIO] = (
                    self.info[prl.TIME][prl.COLLECTION] / self.num_collection_workers
                ) / processing_time
            self.processing_time_start = time.time()

        # Preprocess new data
        self.storage.insert_data_slice(data)
        self.storage.before_gradients()

    def get_grads(self, distribute_gradients=False):
        """
        Perform a gradient computation step.

        Parameters
        ----------
        distribute_gradients : bool
            If True, gradients will be directly shared across remote workers
            and optimization steps will be executed in a decentralised way.

        Returns
        -------
        grads : list of tensors
            List of actor gradients.
        info : dict
            Summary dict of relevant gradient operation information.
        """
        # Get gradients and algorithm-related information
        t = time.time()
        grads, algo_info = self.compute_gradients(self.next_batch, distribute_gradients)
        compute_time = time.time() - t

        # Add extra information to info dict
        self.info[prl.ALGORITHM] = algo_info
        self.info[prl.VERSION][prl.GRADIENT] = self.local_worker.actor_version
        self.info[prl.TIME][prl.GRADIENT] = compute_time

        # Run after gradients data process (if any)
        info = self.storage.after_gradients(self.next_batch, self.info)

        # Update iteration counter
        self.iter += 1

        return grads, info

    def compute_gradients(self, batch, distribute_gradients):
        """
        Calculate actor gradients and update networks.

        Parameters
        ----------
        batch : dict
            data batch containing all required tensors to compute algo loss.
        distribute_gradients : bool
            If True, gradients will be directly shared across remote workers
            and optimization steps will be executed in a decentralised way.

        Returns
        -------
        grads : list of tensors
            List of actor gradients, or None if distribute_gradients is True.
        info : dict
            Summary dict with relevant gradient-related information.
        """
        grads, info = self.algo.compute_gradients(batch, grads_to_cpu=not distribute_gradients)

        if distribute_gradients:

            if torch.cuda.is_available():
                for g in grads:
                    torch.distributed.all_reduce(g, op=torch.distributed.ReduceOp.SUM)
            else:
                torch.distributed.all_reduce_coalesced(grads, op=torch.distributed.ReduceOp.SUM)

            # all_reduce SUMs gradients over every process in the distributed
            # group, so average by the group size (not by the number of data
            # collection workers, which may be 0).
            world_size = torch.distributed.get_world_size()
            for p in self.actor.parameters():
                if p.grad is not None:
                    p.grad /= world_size

            grads = None

        return grads, info

    def apply_gradients(self, gradients=None):
        """Update Actor Critic model"""
        self.local_worker.actor_version += 1
        self.algo.apply_gradients(gradients)
        if self.col_communication == prl.SYNC and self.num_remote_workers > 0:
            self.collector.broadcast_new_weights()

    def set_weights(self, actor_weights):
        """
        Update the worker actor version with provided weights.

        Parameters
        ----------
        actor_weights : dict
            Dict with the actor version under prl.VERSION and the actor
            weights under prl.WEIGHTS.
        """
        self.local_worker.actor_version = actor_weights[prl.VERSION]
        self.local_worker.algo.set_weights(actor_weights[prl.WEIGHTS])

    def update_algorithm_parameter(self, parameter_name, new_parameter_value):
        """
        If `parameter_name` is an attribute of Worker.algo, change its value to
        `new_parameter_value value`.

        Parameters
        ----------
        parameter_name : str
            Algorithm attribute name
        new_parameter_value : object
            New value for the attribute.
        """
        self.local_worker.update_algorithm_parameter(parameter_name, new_parameter_value)
        # self.col_workers.remote_workers() is the same list as
        # self.remote_workers, so each remote worker is updated exactly once.
        for e in self.remote_workers:
            e.update_algorithm_parameter.remote(parameter_name, new_parameter_value)
        # self.algo is self.local_worker.algo.
        # NOTE(review): may be redundant if local_worker.update_algorithm_parameter
        # already forwards to its algo — confirm.
        self.algo.update_algorithm_parameter(parameter_name, new_parameter_value)

    def save_model(self, fname):
        """
        Save current version of actor as a torch loadable checkpoint.

        Parameters
        ----------
        fname : str
            Filename given to the checkpoint.

        Returns
        -------
        save_name : str
            Path to saved file.
        """
        # Write to a temporary file first, then atomically replace the target.
        torch.save(self.local_worker.actor.state_dict(), fname + ".tmp")
        os.replace(fname + '.tmp', fname)
        save_name = fname + ".{}".format(self.local_worker.actor_version)
        copy2(fname, save_name)
        return save_name

    def stop(self):
        """Stop collecting data."""
        if hasattr(self, "collector"):
            self.collector.stopped = True
        self.local_worker.stop()
        for e in self.remote_workers:
            e.stop.remote()
            e.terminate_worker.remote()
class CollectorThread(threading.Thread):
    """
    This class receives data samples from the data collection workers and
    queues them into the data input queue.

    Parameters
    ----------
    index_worker : int
        Index assigned to this worker.
    local_worker : Worker
        Local worker that acts as a parameter server.
    remote_workers : list of Workers
        Set of workers collecting and sending rollouts.
    col_fraction_workers : float
        Minimum fraction of samples required to stop if collection is
        synchronously coordinated and most workers have finished their
        collection task.
    col_communication : str
        Communication coordination pattern for data collection.
    col_execution : str
        Execution patterns for data collection.
    broadcast_interval : int
        After how many central updates, model weights should be broadcasted to
        remote collection workers.

    Attributes
    ----------
    stopped : bool
        Whether or not the thread should stop running.
    queue : queue.SimpleQueue
        Queue to store the data dicts received from data collection workers.
    index_worker : int
        Index assigned to this worker.
    local_worker : CWorker
        col_workers local worker.
    remote_workers : List of CWorker's
        col_workers remote data collection workers.
    num_remote_workers : int
        Number of collection remote workers.
    num_collection_workers : int
        Number of workers actually collecting data (1 if there are no remote
        collection workers).
    broadcast_interval : int
        After how many collection steps model weights should be broadcasted to
        remote collection workers.
    num_sent_since_broadcast : int
        Number of data dicts received since last model weights were broadcasted.
    pending_tasks : dict
        Only in parallel asynchronous mode: maps each pending ray collection
        future to the remote worker that produces it.
    """

    def __init__(self,
                 index_worker,
                 local_worker,
                 remote_workers,
                 col_fraction_workers=1.0,
                 col_communication=prl.SYNC,
                 col_execution=prl.CENTRAL,
                 broadcast_interval=1):

        self.index_worker = index_worker
        threading.Thread.__init__(self)

        self.stopped = False
        self.queue = queue.SimpleQueue()
        self.col_execution = col_execution
        self.col_communication = col_communication
        self.broadcast_interval = broadcast_interval
        self.fraction_workers = col_fraction_workers

        self.local_worker = local_worker
        self.remote_workers = remote_workers
        self.num_remote_workers = len(self.remote_workers)
        self.num_collection_workers = 1 if self.num_remote_workers == 0 else self.num_remote_workers

        # Counters
        self.num_sent_since_broadcast = 0

        # Synchronous modes (central or parallel) start no thread: step() is
        # called directly by the GWorker when it needs new data.
        if col_execution == prl.CENTRAL and col_communication == prl.ASYNC:
            if self.local_worker.envs_train is not None:
                self.start()  # Start CollectorThread

        elif col_execution == prl.PARALLEL and col_communication == prl.ASYNC:
            self.pending_tasks = {}
            self.broadcast_new_weights()
            for w in self.remote_workers:
                future = w.collect_data.remote()
                self.pending_tasks[future] = w
            self.start()

    def run(self):
        """Thread loop: collect data and broadcast weights until stopped."""
        while not self.stopped:

            # First, collect data
            self.step()

            # Then, update counter and broadcast weights to worker if necessary
            self.num_sent_since_broadcast += 1
            if self.should_broadcast():
                self.broadcast_new_weights()

    def _enqueue(self, rollouts):
        """Unpack rollouts if compressed, queue them and apply back-pressure."""
        if isinstance(rollouts, str):
            rollouts = unpack(rollouts)
        self.queue.put(rollouts)
        # Limit the allowed policy lag by waiting while the queue is full. Stop
        # waiting once the thread is stopped: nobody drains the queue anymore,
        # so otherwise the thread would sleep forever and block process exit.
        while self.queue.qsize() >= max_queue_size and not self.stopped:
            time.sleep(0.5)

    def step(self):
        """
        Collects data from remote workers and puts it in the GWorker queue.
        """
        if self.col_execution == prl.CENTRAL and self.col_communication == prl.SYNC:
            self._enqueue(self.local_worker.collect_data(listen_to=["sync"], data_to_cpu=False))

        elif self.col_execution == prl.CENTRAL and self.col_communication == prl.ASYNC:
            self._enqueue(self.local_worker.collect_data(data_to_cpu=False))

        elif self.col_execution == prl.PARALLEL and self.col_communication == prl.SYNC:

            # Start data collection in all workers
            worker_key = "worker_{}".format(self.index_worker)
            broadcast_message(worker_key, b"start-continue")

            pending_samples = [e.collect_data.remote(
                listen_to=["sync", worker_key]) for e in self.remote_workers]

            if self.fraction_workers < 1.0:

                # Block until at least num_remote_workers * fraction_workers
                # workers have finished (for an integer count this is exactly
                # ceil(num_remote_workers * fraction_workers)), without
                # busy-polling.
                num_required = min(
                    math.ceil(self.num_remote_workers * self.fraction_workers),
                    len(pending_samples))
                if num_required > 0:
                    ray.wait(pending_samples, num_returns=num_required)

                # Send stop message to the workers
                broadcast_message(worker_key, b"stop")

            # Compute model updates
            for r in pending_samples:
                self._enqueue(ray_get_and_free(r))

        elif self.col_execution == prl.PARALLEL and self.col_communication == prl.ASYNC:

            # Wait for first worker to finish
            assert len(self.pending_tasks) == len(self.remote_workers)
            ready, _ = ray.wait(list(self.pending_tasks))
            future = ready[0]
            w = self.pending_tasks.pop(future)

            # Retrieve rollouts and add them to queue
            self._enqueue(ray_get_and_free(future))

            # Schedule a new collection task
            future = w.collect_data.remote()
            self.pending_tasks[future] = w

    def should_broadcast(self):
        """Returns whether broadcast() should be called to update weights."""
        return self.num_sent_since_broadcast >= self.broadcast_interval

    def broadcast_new_weights(self):
        """Broadcast a new set of weights from the local worker."""
        if self.num_remote_workers > 0:
            latest_weights = ray.put({
                prl.VERSION: self.local_worker.actor_version,
                prl.WEIGHTS: self.local_worker.get_weights()})
            for e in self.remote_workers:
                e.set_weights.remote(latest_weights)
        self.num_sent_since_broadcast = 0