pytorchrl.scheme.updates package

Submodules

pytorchrl.scheme.updates.u_worker module

class pytorchrl.scheme.updates.u_worker.UWorker(grad_workers_factory, index_worker=0, col_fraction_workers=1.0, grad_execution='Central', grad_communication='synchronous', decentralized_update_execution=False, local_device=None)[source]

Bases: pytorchrl.scheme.base.worker.Worker

Update worker. Handles actor updates.

This worker receives gradients from gradient workers and then handles actor model updates. Updated weights are broadcast back to gradient workers if required by the training scheme.

Parameters
  • grad_workers_factory (func) – A function that creates a set of gradient computation workers.

  • index_worker (int) – Worker index.

  • col_fraction_workers (float) – Minimum fraction of samples required to stop if collection is synchronously coordinated and most workers have finished their collection task.

  • grad_execution (str) – Execution pattern for gradient computation.

  • grad_communication (str) – Communication coordination pattern for gradient computation workers.

  • decentralized_update_execution (bool) – Whether the gradients are applied in the update worker (central update) or broadcast to all gradient workers for a decentralized update.

  • local_device (str) – “cpu” or a specific GPU “cuda:number” to use for computation.
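
The two modes selected by decentralized_update_execution can be illustrated with a toy sketch. This is not pytorchrl code: plain Python lists stand in for weights and gradients, and the helper names (average, sgd_step, central_update, decentralized_update), the averaging of gradients and the learning rate are all assumptions made for illustration.

```python
def average(grads):
    """Element-wise mean of a list of equally sized gradient vectors."""
    n = len(grads)
    return [sum(g[i] for g in grads) / n for i in range(len(grads[0]))]


def sgd_step(weights, grad, lr=0.1):
    """Plain gradient descent step (hypothetical optimizer)."""
    return [w - lr * g for w, g in zip(weights, grad)]


def central_update(central_weights, worker_grads):
    """Apply the averaged gradient once, then broadcast the new weights."""
    new_weights = sgd_step(central_weights, average(worker_grads))
    broadcast = [list(new_weights) for _ in worker_grads]
    return new_weights, broadcast


def decentralized_update(worker_weights, worker_grads):
    """Broadcast the averaged gradient; every worker applies it locally."""
    mean_grad = average(worker_grads)
    return [sgd_step(w, mean_grad) for w in worker_weights]


weights = [1.0, 2.0]
grads = [[1.0, 0.0], [3.0, 2.0]]  # one gradient per worker

central, copies = central_update(weights, grads)
local = decentralized_update([list(weights), list(weights)], grads)
```

In this sketch both modes leave every worker with the same weights; they differ only in whether weights or gradients are sent over the wire.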

grad_execution

Execution pattern for gradient computation.

Type

str

grad_communication

Communication coordination pattern for gradient computation workers.

Type

str

grad_workers

A GWorkerSet class instance.

Type

GWorkerSet

local_worker

grad_workers local worker.

Type

GWorker

remote_workers

grad_workers remote data collection workers.

Type

List

num_workers

Number of remote gradient workers.

Type

int

updater

Class handling updates, calling grad_workers to get gradients, performing update steps and placing update information into the output queue outqueue.

Type

UpdaterThread

property actor_version

Number of times Actor has been updated.

save_model(fname)[source]

Save the current version of the actor as a torch-loadable checkpoint.

Parameters

fname (str) – Filename given to the checkpoint.

Returns

save_name – Path to saved file.

Return type

str

step()[source]

Pulls information about update operations from self.updater.outqueue.

stop()[source]

Stop remote workers.

update_algorithm_parameter(parameter_name, new_parameter_value)[source]

If parameter_name is an attribute of Worker.algo, change its value to new_parameter_value.

Parameters

parameter_name (str) – Algorithm attribute name
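
The behaviour described above can be sketched with hasattr and setattr. This is a minimal illustration rather than the library's implementation: ToyAlgo, its attributes and the boolean return value are hypothetical.

```python
class ToyAlgo:
    """Hypothetical stand-in for the algorithm object held by a worker."""

    def __init__(self):
        self.lr = 1e-3
        self.gamma = 0.99


def update_algorithm_parameter(algo, parameter_name, new_parameter_value):
    """Overwrite an existing attribute; ignore unknown names.

    Returning a bool is an assumption added here to make the effect visible.
    """
    if hasattr(algo, parameter_name):
        setattr(algo, parameter_name, new_parameter_value)
        return True
    return False


algo = ToyAlgo()
changed = update_algorithm_parameter(algo, "lr", 5e-4)
ignored = update_algorithm_parameter(algo, "not_a_parameter", 1.0)
```

Guarding with hasattr means a misspelled name never creates a new, unused attribute on the algorithm object.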

class pytorchrl.scheme.updates.u_worker.UpdaterThread(local_worker, remote_workers, decentralized_update_execution, col_fraction_workers=1.0, grad_execution='Central', grad_communication='synchronous')[source]

Bases: threading.Thread

This class receives data from the workers and continuously updates the central actor.

Parameters
  • local_worker (GWorker) – Local GWorker that acts as a parameter server.

  • remote_workers (List) – grad_workers remote data collection workers.

  • decentralized_update_execution (bool) – Whether decentralized execution pattern for update steps is enabled or not.

  • col_fraction_workers (float) – Minimum fraction of samples required to stop if collection is synchronously coordinated and most workers have finished their collection task.

  • grad_execution (str) – Execution pattern for gradient computation.

  • grad_communication (str) – Communication coordination pattern for gradient computation workers.

stopped

Whether or not the thread is running.

Type

bool

outqueue

Queue to store the info dicts resulting from the model update operation.

Type

queue.Queue

local_worker

Local worker that acts as a parameter server.

Type

Worker

remote_workers

grad_workers remote data collection workers.

Type

List

num_workers

Number of remote gradient workers.

Type

int

decentralized_update_execution

Whether decentralized execution pattern for update steps is enabled or not.

Type

bool

col_fraction_workers

Minimum fraction of samples required to stop if collection is synchronously coordinated and most workers have finished their collection task.

Type

float

grad_execution

Execution pattern for gradient computation.

Type

str

grad_communication

Communication coordination pattern for gradient computation workers.

Type

str

run()[source]

Method representing the thread’s activity.

You may override this method in a subclass. The standard run() method invokes the callable object passed to the object’s constructor as the target argument, if any, with sequential and keyword arguments taken from the args and kwargs arguments, respectively.

step()[source]

Takes a logical optimization step and places output information in the output queue.
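
The interplay of run(), step(), stopped and outqueue can be sketched with the standard library alone. ToyUpdaterThread below is a hypothetical stand-in, not the actual UpdaterThread: the max_steps argument, the version counter and the contents of the info dict are assumptions, and no gradients are computed.

```python
import queue
import threading


class ToyUpdaterThread(threading.Thread):
    """Background thread that repeatedly 'updates' and reports via a queue."""

    def __init__(self, max_steps=3):
        super().__init__(daemon=True)
        self.stopped = False
        self.outqueue = queue.Queue()
        self.max_steps = max_steps  # hypothetical stopping criterion
        self.version = 0

    def step(self):
        # A real step would gather gradients and update the actor here.
        self.version += 1
        self.outqueue.put({"actor_version": self.version})

    def run(self):
        # Keep stepping until the thread is flagged as stopped.
        while not self.stopped:
            self.step()
            if self.version >= self.max_steps:
                self.stopped = True


updater = ToyUpdaterThread(max_steps=3)
updater.start()

# Consumer side: block until each info dict becomes available.
infos = [updater.outqueue.get(timeout=5) for _ in range(3)]
updater.join(timeout=5)
```

Because queue.Queue is thread-safe, the consumer can read the info dicts without additional locking while the updater keeps running.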

sync_weights()[source]

Synchronize gradient worker models with the updater worker model.

Module contents