Source code for pytorchrl.scheme.utils

import ray
import time
import base64
import lz4.frame
from ray import cloudpickle as pickle

# Module-level state backing ray_get_and_free's batched deletion of object ids.

# Seconds since the last flush after which the next ray_get_and_free call
# flushes the queue, regardless of its size.
FREE_DELAY_S = 10.0
# Queue length above which the next ray_get_and_free call flushes the queue.
MAX_FREE_QUEUE_SIZE = 100
# time.time() value recorded at the most recent flush (0.0 before any flush).
_last_free_time = 0.0
# Object ids that have been fetched and are waiting to be freed.
_to_free = []


def ray_get_and_free(object_ids):
    """
    Call ray.get and then queue the object ids for deletion.

    Use this instead of a plain ``ray.get`` whenever possible, to optimize
    memory usage. The only exception is when an object_id is shared among
    multiple readers.

    Freeing is batched: fetched ids are appended to a module-level queue,
    which is flushed with ``ray.internal.free`` once it holds more than
    ``MAX_FREE_QUEUE_SIZE`` ids or more than ``FREE_DELAY_S`` seconds have
    passed since the previous flush.

    Adapted from https://github.com/ray-project/ray/blob/master/rllib/utils/memory.py

    Parameters
    ----------
    object_ids : ObjectID|List[ObjectID]
        Object ids to fetch and free.

    Returns
    -------
    result : python objects
        The result of ray.get(object_ids).
    """
    global _last_free_time
    global _to_free

    result = ray.get(object_ids)

    # isinstance (not an exact type comparison) so that list subclasses are
    # queued element-wise instead of being nested inside another list.
    if not isinstance(object_ids, list):
        object_ids = [object_ids]
    _to_free.extend(object_ids)

    # Batch calls to free to reduce overheads.
    now = time.time()
    queue_full = len(_to_free) > MAX_FREE_QUEUE_SIZE
    delay_elapsed = now - _last_free_time > FREE_DELAY_S
    if queue_full or delay_elapsed:
        ray.internal.free(_to_free)
        # Rebind rather than clear in place: the list just handed to
        # ray.internal.free is left untouched.
        _to_free = []
        _last_free_time = now

    return result
def broadcast_message(key, message):
    """
    Store a message under a key in the global ray worker's redis client.

    Parameters
    ----------
    key
        Key under which the message is stored.
    message
        Value to store for ``key``.
    """
    redis_client = ray.worker.global_worker.redis_client
    redis_client.set(key, message)
def check_message(key):
    """
    Read the value stored under a key in the global ray worker's redis client.

    Parameters
    ----------
    key
        Key to look up.

    Returns
    -------
    message
        Whatever the redis client's ``get`` returns for ``key``.
    """
    redis_client = ray.worker.global_worker.redis_client
    return redis_client.get(key)
def average_gradients(grads_list):
    """
    Averages gradients coming from distributed workers.

    Position ``i`` of the result is the mean over workers of the gradient at
    position ``i``. If the first worker's gradient at a position is None,
    that position is set to 0.0 instead of being averaged.

    Parameters
    ----------
    grads_list : list of lists of tensors
        List of actor gradients from different workers. Every worker is
        expected to provide the same number of gradients as the first one.

    Returns
    -------
    avg_grads : list of tensors
        Averaged actor gradients. Empty when ``grads_list`` is empty.
    """
    num_workers = len(grads_list)

    # No worker reported gradients: there is nothing to average.
    if num_workers == 0:
        return []

    # The first worker's gradients decide which positions are averaged.
    avg_grads = [
        sum(worker_grads[idx] for worker_grads in grads_list) / num_workers
        if first_grad is not None else 0.0
        for idx, first_grad in enumerate(grads_list[0])
    ]

    return avg_grads
def pack(data):
    """
    Pickle, lz4-compress and base64-encode ``data`` into an ASCII string.

    from https://github.com/ray-project/ray/blob/master/rllib/utils/compression.py
    """
    pickled = pickle.dumps(data)
    compressed = lz4.frame.compress(pickled)
    # TODO(ekl) we shouldn't need to base64 encode this data, but this
    # seems to not survive a transfer through the object store if we don't.
    encoded = base64.b64encode(compressed)
    return encoded.decode("ascii")
def unpack(data):
    """
    Base64-decode, lz4-decompress and unpickle ``data``.

    NOTE: the payload is unpickled, so only pass data from a trusted source.

    from https://github.com/ray-project/ray/blob/master/rllib/utils/compression.py
    """
    compressed = base64.b64decode(data)
    pickled = lz4.frame.decompress(compressed)
    return pickle.loads(pickled)