Source code for pytorchrl.learner

import os
import time
import logging
from functools import partial
from collections import defaultdict, deque
import pytorchrl as prl

logger = logging.getLogger(__name__)


[docs]class Learner: """ Task learner class. Class to manage the training process. It pushes forward the training process by calling the update workers and tracks progress. Parameters ---------- scheme : Scheme Training scheme class instance, handling coordination of workers. target_steps : int Number of environment steps to reach to complete training. log_dir : str Target directory for model checkpoints and logs. """ def __init__(self, scheme, target_steps, log_dir=None): # Input attributes self.log_dir = log_dir self.target_steps = target_steps self.update_worker = scheme.update_worker() # Counters and metrics self.num_samples_collected = 0 self.metrics = {k: defaultdict(partial(deque, maxlen=1)) for k in prl.INFO_KEYS} # Record starting time self.start = time.time() logger.warning("Created Learner.")
[docs] def step(self): """Takes a logical synchronous optimization step.""" # Update step info = self.update_worker.step() actor_version = info[prl.VERSION] # grad_update_lag info[prl.SCHEME] = {} info[prl.SCHEME][prl.FPS] = int(self.num_samples_collected / (time.time() - self.start)) info[prl.SCHEME][prl.PL] = actor_version[prl.GRADIENT] - actor_version[prl.COLLECTION] info[prl.SCHEME][prl.GA] = actor_version[prl.UPDATE] - actor_version[prl.GRADIENT] # Update counters self.num_samples_collected += info.pop(prl.NUMSAMPLES) # Update and log metrics for k in prl.INFO_KEYS: if k in info and isinstance(info[k], dict): for x, y in info[k].items(): if isinstance(y, (float, int)): self.metrics[k][x].append(y)
[docs] def done(self): """ Return True if training has finished (target_steps reached). Returns ------- flag : bool True if training has reached the target number of steps. """ flag = self.num_samples_collected >= self.target_steps if flag: self.update_worker.stop() print("\nTraining finished!") time.sleep(1) return flag
[docs] def get_metrics(self, add_algo_metrics=True, add_episodes_metrics=False, add_scheme_metrics=False, add_time_metrics=False): """Returns current value of tracked metrics.""" m = {} def include_info(key): for k, v in self.metrics[key].items(): m[os.path.join(key, k)] = sum(v) / len(v) if add_algo_metrics: include_info(prl.ALGORITHM) if add_episodes_metrics: include_info(prl.EPISODES) if add_scheme_metrics: include_info(prl.SCHEME) if add_time_metrics: include_info(prl.TIME) return m
[docs] def print_info(self, add_algo_info=True, add_episodes_info=True, add_scheme_info=False, add_time_info=False): """Print relevant information about the training process""" def write_info(msg, key): msg += "\n {}: ".format(key) for k, v in self.metrics[key].items(): msg += "{} {:.4f}, ".format(k, sum(v) / len(v)) msg = msg[:-2] return msg s = "Update {}".format(self.update_worker.actor_version) s += ", num samples collected {}, FPS {}".format(self.num_samples_collected, int(self.num_samples_collected / (time.time() - self.start))) if add_algo_info: s = write_info(s, prl.ALGORITHM) if add_episodes_info: s = write_info(s, prl.EPISODES) if add_scheme_info: s = write_info(s, prl.SCHEME) if add_time_info: s = write_info(s, prl.TIME) print(s, flush=True)
[docs] def update_algorithm_parameter(self, parameter_name, new_parameter_value): """ If `parameter_name` is an attribute of the algorithm used for training, change its value to `new_parameter_value value`. Parameters ---------- parameter_name : str Worker.algo attribute name new_parameter_value : int or float New value for `parameter_name`. """ self.update_worker.update_algorithm_parameter(parameter_name, new_parameter_value)
[docs] def save_model(self, fname="model.state_dict"): """ Save currently learned actor_critic version. Returns ------- save_name : str Path to saved file. """ fname = os.path.join(self.log_dir, fname) save_name = self.update_worker.save_model(fname) return save_name
[docs] def stop(self): """Stop all threads.""" self.update_worker.stop()