Scheme

class pytorchrl.scheme.scheme.Scheme(algo_factory, actor_factory, storage_factory, train_envs_factory, test_envs_factory=<function Scheme.<lambda>>, num_col_workers=1, col_compress_data=False, col_workers_communication='synchronous', col_workers_resources={'num_cpus': 1, 'num_gpus': 0.5}, col_preemption_thresholds={'fraction_samples': 1.0, 'fraction_workers': 1.0}, num_grad_workers=1, grad_compress_data=False, grad_workers_communication='synchronous', grad_workers_resources={'num_cpus': 1, 'num_gpus': 0.5}, local_device=None, decentralized_update_execution=False)[source]

Bases: object

Class to define training schemes and handle the creation and operation of their workers.

Parameters
  • algo_factory (func) – A function that creates an algorithm class.

  • actor_factory (func) – A function that creates a policy.

  • storage_factory (func) – A function that creates a rollout storage.

  • train_envs_factory (func) – A function to create train environments.

  • test_envs_factory (func) – A function to create test environments.

  • num_col_workers (int) – Number of data collection workers per gradient worker.

  • col_compress_data (bool) – Whether to compress collected data before sending it between workers.

  • col_workers_communication (str) – Communication coordination pattern for data collection.

  • col_workers_resources (dict) – Ray resource specs for collection remote workers.

  • col_preemption_thresholds (dict) – Specs for the minimum fraction_samples [0.0 - 1.0] and minimum fraction_workers [0.0 - 1.0] required in synchronous data collection.

  • num_grad_workers (int) – Number of gradient workers.

  • grad_compress_data (bool) – Whether to compress gradient data before sending it between workers.

  • grad_workers_communication (str) – Communication coordination pattern for gradient computation workers.

  • grad_workers_resources (dict) – Ray resource specs for gradient remote workers.

  • local_device (str) – "cpu" or a specific GPU "cuda:number" to use for computation.

  • decentralized_update_execution (bool) – Whether gradients are applied centrally by the update worker (False) or broadcast to all gradient workers for a decentralized update (True).

get_agent_components()[source]

Returns class names for each agent component.

update_worker()[source]

Returns the local update worker.
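
For orientation, a minimal sketch of what a method like get_agent_components() typically reports, namely the class name of each component the scheme was built from (MiniScheme below is an illustrative mock, not the pytorchrl implementation):

```python
class MiniScheme:
    """Toy scheme holding already-built components by name."""

    def __init__(self, **components):
        self._components = components

    def get_agent_components(self):
        # Map each component name to the class name of its instance.
        return {name: type(obj).__name__ for name, obj in self._components.items()}

scheme = MiniScheme(algo=object(), actor=object())
print(scheme.get_agent_components())  # {'algo': 'object', 'actor': 'object'}
```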