Source code for pytorchrl.scheme.scheme

import logging
logger = logging.getLogger(__name__)

import pytorchrl as prl
from pytorchrl.scheme.collection.c_worker_set import CWorkerSet
from pytorchrl.scheme.gradients.g_worker_set import GWorkerSet
from pytorchrl.scheme.updates.u_worker import UWorker


class Scheme:
    """
    Class to define training schemes and handle creation and operation
    of its workers.

    Parameters
    ----------
    algo_factory : func
        A function that creates an algorithm class.
    actor_factory : func
        A function that creates a policy.
    storage_factory : func
        A function that creates a rollouts storage.
    train_envs_factory : func
        A function to create train environments.
    test_envs_factory : func
        A function to create test environments. The default factory
        returns None.
    num_col_workers : int
        Number of data collection workers per gradient worker.
    col_compress_data : bool
        Forwarded to the collection worker factory as
        ``compress_data_to_send``. Presumably toggles compression of the
        collected data before it is sent; confirm against ``CWorkerSet``.
    col_workers_communication : str
        Communication coordination pattern for data collection.
        Must be ``prl.SYNC`` or ``prl.ASYNC``.
    col_workers_resources : dict
        Ray resource specs for collection remote workers. If None,
        ``{"num_cpus": 1, "num_gpus": 0.5}`` is used.
    col_preemption_thresholds : dict
        specs about minimum fraction_samples [0 - 1.0] and minimum
        fraction_workers [0 - 1.0] required in synchronous data collection.
        If None, ``{"fraction_samples": 1.0, "fraction_workers": 1.0}``
        is used.
    num_grad_workers : int
        Number of gradient workers.
    grad_compress_data : bool
        Forwarded to the gradient worker factory as
        ``compress_grads_to_send``. Presumably toggles compression of the
        gradients before they are sent; confirm against ``GWorkerSet``.
    grad_workers_communication : str
        Communication coordination pattern for gradient computation workers.
        Must be ``prl.SYNC`` or ``prl.ASYNC``.
    grad_workers_resources : dict
        Ray resource specs for gradient remote workers. If None,
        ``{"num_cpus": 1, "num_gpus": 0.5}`` is used.
    local_device : str
        "cpu" or specific GPU "cuda:`number`" to use for computation.
    decentralized_update_execution : bool
        Whether the gradients are applied in the update workers (central
        update) or broadcasted to all gradient workers for a decentralized
        update.
    """

    def __init__(self,

                 # core
                 algo_factory,
                 actor_factory,
                 storage_factory,
                 train_envs_factory,
                 test_envs_factory=lambda v, x, y, z: None,

                 # collection
                 num_col_workers=1,
                 col_compress_data=False,
                 col_workers_communication=prl.SYNC,
                 col_workers_resources=None,
                 col_preemption_thresholds=None,

                 # gradients
                 num_grad_workers=1,
                 grad_compress_data=False,
                 grad_workers_communication=prl.SYNC,
                 grad_workers_resources=None,

                 # update
                 local_device=None,
                 decentralized_update_execution=False):

        assert col_workers_communication in (prl.SYNC, prl.ASYNC), \
            "col_workers_communication can only be `prl.SYNC` or `prl.ASYNC`"
        assert grad_workers_communication in (prl.SYNC, prl.ASYNC), \
            "grad_workers_communication can only be `prl.SYNC` or `prl.ASYNC`"

        # Dict defaults are built per call rather than in the signature, so
        # separate Scheme instances never share (and cannot cross-mutate)
        # the same default dict object handed to the worker factories.
        if col_workers_resources is None:
            col_workers_resources = {"num_cpus": 1, "num_gpus": 0.5}
        if col_preemption_thresholds is None:
            col_preemption_thresholds = {
                "fraction_samples": 1.0, "fraction_workers": 1.0}
        if grad_workers_resources is None:
            grad_workers_resources = {"num_cpus": 1, "num_gpus": 0.5}

        col_fraction_samples = col_preemption_thresholds.get("fraction_samples")
        col_fraction_workers = col_preemption_thresholds.get("fraction_workers")

        # More than one worker at a level means parallel execution;
        # otherwise that level runs centrally.
        col_execution = prl.PARALLEL if num_col_workers > 1 else prl.CENTRAL
        grad_execution = prl.PARALLEL if num_grad_workers > 1 else prl.CENTRAL

        # A requested count of exactly 1 is passed to the factories as 0;
        # any other count is passed through unchanged.
        col_worker_count = 0 if num_col_workers == 1 else num_col_workers
        grad_worker_count = 0 if num_grad_workers == 1 else num_grad_workers

        col_workers_factory = CWorkerSet.create_factory(

            # core modules
            algo_factory=algo_factory,
            actor_factory=actor_factory,
            storage_factory=storage_factory,
            test_envs_factory=test_envs_factory,
            train_envs_factory=train_envs_factory,

            # col specs
            num_workers=col_worker_count,
            col_worker_resources=col_workers_resources,
            col_fraction_samples=col_fraction_samples,
            compress_data_to_send=col_compress_data,

            # grad specs
            total_parent_workers=grad_worker_count,
        )

        grad_workers_factory = GWorkerSet.create_factory(

            # col specs
            col_execution=col_execution,
            col_communication=col_workers_communication,
            col_workers_factory=col_workers_factory,
            col_fraction_workers=col_fraction_workers,

            # grad_specs
            num_workers=grad_worker_count,
            grad_worker_resources=grad_workers_resources,
            compress_grads_to_send=grad_compress_data,
        )

        self._update_worker = UWorker(

            # col specs
            col_fraction_workers=col_fraction_workers,

            # grad specs
            grad_execution=grad_execution,
            grad_communication=grad_workers_communication,
            grad_workers_factory=grad_workers_factory,

            # update specs
            local_device=local_device,
            decentralized_update_execution=decentralized_update_execution,
        )

        logger.warning("Created training scheme.")

    def update_worker(self):
        """Returns the update worker (``UWorker``) created by this scheme."""
        return self._update_worker

    def get_agent_components(self):
        """Returns class names for each agent component."""
        # The actor, algorithm and storage instances are read from the
        # worker nested two `local_worker` levels below the update worker.
        innermost_worker = self._update_worker.local_worker.local_worker
        return {
            "Actor": innermost_worker.actor.__class__.__name__,
            "Algorithm": innermost_worker.algo.__class__.__name__,
            "Storage": innermost_worker.storage.__class__.__name__,
        }