Source code for pytorchrl.scheme.base.worker_set

import ray
import pytorchrl as prl
from pytorchrl.scheme.base.worker import default_remote_config


[docs]class WorkerSet: """ Class to better handle the operations of ensembles of Workers. Contains common functionality across all worker sets. Parameters ---------- worker : func A function that creates a worker class. worker_params : dict Worker class kwargs. worker_remote_config : dict Ray resource specs for the remote workers. num_workers : int Num workers replicas in the worker_set. add_local_worker : bool Whether or not to include have a non-remote worker in the worker set. Attributes ---------- worker_class : python class Worker class to be instantiated to create Ray remote actors. remote_config : dict Ray resource specs for the remote workers. worker_params : dict Keyword arguments of the worker_class. num_workers : int Number of remote workers in the worker set. """ def __init__(self, worker, worker_params, index_parent_worker, worker_remote_config=default_remote_config, num_workers=1, local_device=None, initial_weights=None, add_local_worker=True, total_parent_workers=None): self.worker_class = worker self.num_workers = num_workers self.worker_params = worker_params self.remote_config = worker_remote_config if add_local_worker: local_params = worker_params.copy() local_params.update( {"device": local_device, "initial_weights": initial_weights}) # If multiple grad workers the collection workers of grad worker with index 0 should not collect if worker.__name__ == "CWorker" and total_parent_workers > 0 and index_parent_worker == 0: self.num_workers = 0 _ = local_params.pop("test_envs_factory") _ = local_params.pop("train_envs_factory") # If multiple col workers, local collection workers don't need to collect elif worker.__name__ == "CWorker" and num_workers > 0: _ = local_params.pop("test_envs_factory") _ = local_params.pop("train_envs_factory") self._local_worker = self._make_worker( self.worker_class, index_worker=0, worker_params=local_params) else: self._local_worker = None self._remote_workers = [] if self.num_workers > 0: self.add_workers(self.num_workers) @staticmethod def _make_worker(cls, index_worker, worker_params): """ Create a single worker. Parameters ---------- index_worker : int Index assigned to remote worker. worker_params : dict Keyword parameters of the worker_class. Returns ------- w : python class An instance of worker class cls """ w = cls(index_worker=index_worker, **worker_params) return w
[docs] def add_workers(self, num_workers): """ Create and add a number of remote workers to this worker set. Parameters ---------- num_workers : int Number of remote workers to create. """ self.worker_params.update({"initial_weights": ray.put( {prl.VERSION: 0, prl.WEIGHTS: self._local_worker.get_weights()})}) cls = self.worker_class.as_remote(**self.remote_config).remote self._remote_workers.extend([ self._make_worker(cls, index_worker=i + 1, worker_params=self.worker_params) for i in range(num_workers)])
[docs] def local_worker(self): """Return local worker""" return self._local_worker
[docs] def remote_workers(self): """Returns list of remote workers""" return self._remote_workers
[docs] def stop(self): """Stop all remote workers""" for w in self.remote_workers(): w.__ray_terminate__.remote()