Scheme
- class pytorchrl.scheme.scheme.Scheme(algo_factory, actor_factory, storage_factory, train_envs_factory, test_envs_factory=<function Scheme.<lambda>>, num_col_workers=1, col_compress_data=False, col_workers_communication='synchronous', col_workers_resources={'num_cpus': 1, 'num_gpus': 0.5}, col_preemption_thresholds={'fraction_samples': 1.0, 'fraction_workers': 1.0}, num_grad_workers=1, grad_compress_data=False, grad_workers_communication='synchronous', grad_workers_resources={'num_cpus': 1, 'num_gpus': 0.5}, local_device=None, decentralized_update_execution=False)[source]
Bases: object
Class to define training schemes and handle creation and operation of its workers.
- Parameters
algo_factory (func) – A function that creates an algorithm class.
actor_factory (func) – A function that creates a policy.
storage_factory (func) – A function that creates a rollouts storage.
train_envs_factory (func) – A function to create train environments.
test_envs_factory (func) – A function to create test environments.
num_col_workers (int) – Number of data collection workers per gradient worker.
col_workers_communication (str) – Communication coordination pattern for data collection.
col_workers_resources (dict) – Ray resource specs for collection remote workers.
col_preemption_thresholds (dict) – Minimum fraction of samples (fraction_samples, in [0.0, 1.0]) and minimum fraction of workers (fraction_workers, in [0.0, 1.0]) required to end a synchronous data collection step early.
num_grad_workers (int) – Number of gradient workers.
grad_workers_communication (str) – Communication coordination pattern for gradient computation workers.
grad_workers_resources (dict) – Ray resource specs for gradient remote workers.
local_device (str) – “cpu” or specific GPU “cuda:number” to use for computation.
decentralized_update_execution (bool) – Whether gradients are applied in the update workers (centralized update) or broadcast to all gradient workers for a decentralized update.
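The constructor takes factory callables rather than instantiated objects, so each remote worker can build its own copies. A minimal sketch of how such factories and the keyword arguments might be assembled; the factory bodies and their parameter names below are hypothetical placeholders, not pytorchrl's real signatures:

```python
# Hypothetical placeholder factories: real code would return pytorchrl
# algorithm, actor, and storage objects instead of these stand-ins.
def algo_factory(device, actor):
    return {"device": device, "actor": actor}  # stand-in for an algorithm

def actor_factory(device):
    return {"device": device}  # stand-in for a policy

def storage_factory(device):
    return []  # stand-in for a rollouts storage

def train_envs_factory(device, index_col_worker, index_grad_worker):
    return f"envs-{index_grad_worker}-{index_col_worker}"  # stand-in

# Keyword arguments mirroring the signature above; the resource dicts follow
# Ray's resource-spec format ({"num_cpus": ..., "num_gpus": ...}).
scheme_kwargs = dict(
    algo_factory=algo_factory,
    actor_factory=actor_factory,
    storage_factory=storage_factory,
    train_envs_factory=train_envs_factory,
    num_col_workers=2,
    col_workers_communication="synchronous",
    col_workers_resources={"num_cpus": 1, "num_gpus": 0.5},
    num_grad_workers=1,
    grad_workers_communication="synchronous",
    local_device="cpu",
)
# scheme = Scheme(**scheme_kwargs)  # requires pytorchrl and a running Ray cluster
```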
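The preemption thresholds can be read as a pair of readiness conditions: a synchronous collection step may be closed out early once both fractions are met, which keeps stragglers from stalling the whole step. A hedged illustration of that logic; the helper below is not part of pytorchrl:

```python
def collection_can_preempt(samples, target_samples, workers_done, total_workers,
                           thresholds=None):
    """Return True once enough samples and enough workers have reported,
    per the fraction_samples / fraction_workers thresholds."""
    if thresholds is None:
        # Defaults match the documented col_preemption_thresholds default:
        # every sample and every worker is required.
        thresholds = {"fraction_samples": 1.0, "fraction_workers": 1.0}
    return (samples >= thresholds["fraction_samples"] * target_samples
            and workers_done >= thresholds["fraction_workers"] * total_workers)

# With the default thresholds of 1.0, nothing can be preempted early.
collection_can_preempt(100, 100, 4, 4)  # True only when fully complete

# Relaxed thresholds allow the step to end once 80% of samples and
# 3 of 4 workers are in.
collection_can_preempt(90, 100, 3, 4,
                       {"fraction_samples": 0.8, "fraction_workers": 0.75})  # True
```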