# Source code for pytorchrl.scheme.base.worker

import logging
import os
import platform

import ray
import torch
from ray._private.services import get_node_ip_address

from .utils import find_free_port

logger = logging.getLogger(__name__)

# Default resource request for a worker. The keys mirror the keyword
# arguments of ``Worker.as_remote``; presumably this dict is unpacked into
# that call elsewhere -- verify against callers. Memory values appear to be
# in bytes (as documented on ``as_remote``).
default_remote_config = dict(
    num_cpus=1,
    num_gpus=0.2,
    memory=5 * 2 ** 30,               # 5 GiB
    object_store_memory=2 * 2 ** 30,  # 2 GiB
)


class Worker:
    """
    Class containing common worker functionality.

    Parameters
    ----------
    index_worker : int
        Worker index.

    Attributes
    ----------
    index_worker : int
        Index assigned to this worker.
    actor : nn.Module
        An actor class instance.
    is_remote : bool
        Instance-level flag, always initialised to False here. Note that
        ``as_remote`` sets ``is_remote = True`` on the object returned by
        ``ray.remote``, not on individual worker instances.
    """

    def __init__(self, index_worker):
        self.index_worker = index_worker
        self.actor = None  # Initialize in inherited class
        self.is_remote = False

    @classmethod
    def as_remote(cls,
                  num_cpus=None,
                  num_gpus=None,
                  memory=None,
                  object_store_memory=None,
                  resources=None):
        """
        Creates a Worker instance as a remote ray actor.

        Parameters
        ----------
        num_cpus : int
            The quantity of CPU cores to reserve for this Worker class.
        num_gpus : float
            The quantity of GPUs to reserve for this Worker class.
        memory : int
            The heap memory quota for this actor (in bytes).
        object_store_memory : int
            The object store memory quota for this actor (in bytes).
        resources: Dict[str, float]
            The default resources required by the actor creation task.

        Returns
        -------
        W : Worker
            A ray remote actor Worker class.
        """
        w = ray.remote(
            num_cpus=num_cpus,
            num_gpus=num_gpus,
            memory=memory,
            object_store_memory=object_store_memory,
            resources=resources)(cls)
        # Marks the ray actor class (the object returned above), not the
        # instances: each instance still runs __init__, which sets False.
        w.is_remote = True
        return w

    def print_worker_info(self):
        """Print information about this worker, including index and resources assigned"""
        s = "Created {} with worker_index {}".format(
            str(type(self).__name__), self.index_worker)
        # Worker 0 only reports its class and index; other workers also
        # report the node IP and the GPU ids ray assigned to them.
        if self.index_worker != 0:
            s += ", in machine {} using gpus {}".format(
                get_node_ip_address(), ray.get_gpu_ids())
        logger.warning(s)

    def get_weights(self):
        """
        Returns current actor.state_dict() weights, with every tensor moved to CPU.

        Note that ``Tensor.cpu()`` returns the original tensor (no copy) when it
        already lives in CPU memory, so for a CPU actor the returned tensors may
        share storage with the actor's live state.

        Returns
        -------
        weights : dict
            Mapping from state_dict keys to CPU tensors.
        """
        return {k: v.cpu() for k, v in self.actor.state_dict().items()}

    def terminate_worker(self):
        """Terminate this ray actor"""
        ray.actor.exit_actor()

    def get_node_ip(self):
        """Returns the IP address of the current node."""
        return get_node_ip_address()

    def find_free_port(self):
        """Returns a free port on the current node."""
        # Resolves to the module-level helper imported from .utils; the
        # method name lives in the class namespace and does not shadow it.
        return find_free_port()

    def setup_torch_data_parallel(self, url, rank, world_size, backend):
        """
        Join a torch process group for distributed SGD.

        Parameters
        ----------
        url : str
            URL specifying how to initialize the process group.
        rank : int
            Rank of the current process.
        world_size : int
            Number of processes participating in the job.
        backend : str
            The pytorch distributed backend to use.
            valid values include mpi, gloo, and nccl.
        """
        torch.distributed.init_process_group(
            backend=backend,
            init_method=url,
            rank=rank,
            world_size=world_size)

    @staticmethod
    def get_host():
        """
        Return node name where this Worker is being executed.

        Uses ``platform.node()``, which on POSIX systems returns the same
        nodename as ``os.uname()[1]`` but, unlike ``os.uname``, also works on
        platforms without ``uname`` (e.g. Windows).
        """
        return platform.node()