pytorchrl.scheme.updates package
Submodules
pytorchrl.scheme.updates.u_worker module
- class pytorchrl.scheme.updates.u_worker.UWorker(grad_workers_factory, index_worker=0, col_fraction_workers=1.0, grad_execution='Central', grad_communication='synchronous', decentralized_update_execution=False, local_device=None)[source]
Bases: pytorchrl.scheme.base.worker.Worker
Update worker. Handles actor updates.
This worker receives gradients from gradient workers and then handles actor model updates. Updated weights are broadcast back to gradient workers if required by the training scheme.
- Parameters
grad_workers_factory (func) – A function that creates a set of gradient computation workers.
index_worker (int) – Worker index.
col_fraction_workers (float) – Minimum fraction of samples required to stop if collection is synchronously coordinated and most workers have finished their collection task.
grad_execution (str) – Execution pattern for gradient computation.
grad_communication (str) – Communication coordination pattern for gradient computation workers.
decentralized_update_execution (bool) – Whether the gradients are applied in the update workers (central update) or broadcast to all gradient workers for a decentralized update.
local_device (str) – “cpu” or a specific GPU “cuda:number” to use for computation.
- grad_execution
Execution pattern for gradient computation.
- Type
str
- grad_communication
Communication coordination pattern for gradient computation workers.
- Type
str
- grad_workers
A GWorkerSet class instance.
- Type
GWorkerSet
- remote_workers
Remote workers of grad_workers.
- Type
List
- num_workers
Number of gradient remote workers.
- Type
int
- updater
Class handling updates, calling grad_workers to get gradients, performing update steps and placing update information into the output queue outqueue.
- Type
UpdaterThread
- property actor_version
Number of times Actor has been updated.
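The central update pattern described above (collect gradients from gradient workers, apply one update step, send the new weights back) can be illustrated with a minimal sketch in plain Python. This is not the pytorchrl implementation: the names `average_gradients`, `sgd_step` and `central_update` are hypothetical, gradients are represented as lists of floats, and averaging the workers' gradients before a plain SGD step is an assumption made only for illustration.

```python
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


def average_gradients(grads: Sequence[Vector]) -> Vector:
    """Element-wise mean of the gradients returned by the workers (assumed reduction)."""
    n = len(grads)
    return [sum(component) / n for component in zip(*grads)]


def sgd_step(weights: Vector, grad: Vector, lr: float) -> Vector:
    """Plain gradient-descent step, standing in for the real optimizer."""
    return [w - lr * g for w, g in zip(weights, grad)]


def central_update(
    weights: Vector,
    grad_fns: Sequence[Callable[[Vector], Vector]],
    lr: float = 0.1,
) -> Tuple[Vector, List[Vector]]:
    """One synchronous central update.

    grad_fns are hypothetical stand-ins for gradient workers: each maps the
    current weights to a gradient.
    """
    # 1. Every worker computes a gradient on the current weights.
    grads = [fn(weights) for fn in grad_fns]
    # 2. The update worker reduces the gradients and applies one step.
    new_weights = sgd_step(weights, average_gradients(grads), lr)
    # 3. The updated weights are copied back to every worker.
    broadcast = [list(new_weights) for _ in grad_fns]
    return new_weights, broadcast


# Two toy workers, each minimizing (w - target)^2 for a different target.
workers = [
    lambda w: [2.0 * (w[0] - 1.0)],
    lambda w: [2.0 * (w[0] - 3.0)],
]
new_weights, worker_copies = central_update([0.0], workers, lr=0.1)
```

In a decentralized variant, step 2 would instead run on every gradient worker after the reduced gradient has been broadcast, so no weight copy in step 3 is needed.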
- class pytorchrl.scheme.updates.u_worker.UpdaterThread(local_worker, remote_workers, decentralized_update_execution, col_fraction_workers=1.0, grad_execution='Central', grad_communication='synchronous')[source]
Bases: threading.Thread
This class receives data from the workers and continuously updates the central actor.
- Parameters
local_worker (GWorker) – Local GWorker that acts as a parameter server.
remote_workers (List) – Remote workers of grad_workers.
decentralized_update_execution (bool) – Whether decentralized execution pattern for update steps is enabled or not.
col_fraction_workers (float) – Minimum fraction of samples required to stop if collection is synchronously coordinated and most workers have finished their collection task.
grad_execution (str) – Execution pattern for gradient computation.
grad_communication (str) – Communication coordination pattern for gradient computation workers.
- stopped
Whether or not the thread is running.
- Type
bool
- outqueue
Queue to store the info dicts resulting from the model update operation.
- Type
queue.Queue
- remote_workers
Remote workers of grad_workers.
- Type
List
- num_workers
Number of gradient remote workers.
- Type
int
- decentralized_update_execution
Whether decentralized execution pattern for update steps is enabled or not.
- Type
bool
- col_fraction_workers
Minimum fraction of samples required to stop if collection is synchronously coordinated and most workers have finished their collection task.
- Type
float
- grad_execution
Execution pattern for gradient computation.
- Type
str
- grad_communication
Communication coordination pattern for gradient computation workers.
- Type
str
- run()[source]
Method representing the thread’s activity.
You may override this method in a subclass. The standard run() method invokes the callable object passed to the object’s constructor as the target argument, if any, with sequential and keyword arguments taken from the args and kwargs arguments, respectively.
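As a rough illustration of the pattern this class follows (a `threading.Thread` subclass whose `run()` loops over update steps and places an info dict into `outqueue` after each one), here is a minimal standalone sketch. It does not use pytorchrl: `SketchUpdaterThread`, `compute_info`, `max_updates` and `collect_infos` are hypothetical names, and the bounded loop is an assumption made so the example terminates.

```python
import queue
import threading
from typing import Callable, Dict, List


class SketchUpdaterThread(threading.Thread):
    """Toy background updater: one info dict per update step."""

    def __init__(self, compute_info: Callable[[int], Dict], max_updates: int):
        super().__init__(daemon=True)
        self.compute_info = compute_info  # hypothetical stand-in for an update step
        self.max_updates = max_updates    # assumption: bounded so the sketch ends
        self.outqueue: queue.Queue = queue.Queue()
        self.stopped = False

    def run(self) -> None:
        # Overrides threading.Thread.run(); executed in the new thread by start().
        step = 0
        while not self.stopped and step < self.max_updates:
            step += 1
            # Publish the result of each update for the consumer to read.
            self.outqueue.put(self.compute_info(step))
        self.stopped = True


def collect_infos(num_updates: int) -> List[Dict]:
    """Run the toy updater to completion and drain its output queue."""
    thread = SketchUpdaterThread(lambda step: {"update": step}, num_updates)
    thread.start()
    thread.join()
    infos = []
    while not thread.outqueue.empty():
        infos.append(thread.outqueue.get())
    return infos
```

Setting `stopped = True` from another thread ends the loop after the current step, which is one common way to shut such a thread down cooperatively.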