Source code for pytorchrl.agent.actors.base

import torch
import torch.nn as nn
from abc import ABC, abstractmethod
from pytorchrl.agent.actors.utils import partially_load_checkpoint


class Actor(nn.Module, ABC):
    """
    Base class for all Actors.

    Parameters
    ----------
    device : torch.device
        CPU or specific GPU where class computations will take place.
    input_space : gym.Space
        Environment observation space.
    action_space : gym.Space
        Environment action space.
    checkpoint : str or dict, optional
        Either a path to a previously trained Actor checkpoint holding the
        full state dict, or a dict mapping each submodule identifier to the
        checkpoint passed to ``partially_load_checkpoint`` for it. Any other
        value (including None) means no weights are loaded.
    *args
        Extra positional arguments; accepted and ignored by this base class.
    """

    def __init__(self, device, input_space, action_space, checkpoint=None, *args):
        super(Actor, self).__init__()
        self.device = device
        self.checkpoint = checkpoint
        self.input_space = input_space
        self.action_space = action_space

    @classmethod
    @abstractmethod
    def create_factory(
            cls, device, input_space, action_space, checkpoint=None, *args):
        """
        Returns a function that creates Actor instances.

        Parameters
        ----------
        device : torch.device
            CPU or specific GPU where class computations will take place.
        input_space : gym.Space
            Environment observation space.
        action_space : gym.Space
            Environment action space.
        checkpoint : str or dict, optional
            Checkpoint specification to be loaded by the created Actor.

        Returns
        -------
        create_actor_instance : func
            Creates a new Actor class instance.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def is_recurrent(self, *args):
        """Returns True if the actor network is recurrent."""
        raise NotImplementedError

    @property
    @abstractmethod
    def recurrent_hidden_state_size(self):
        """Size of policy recurrent hidden state."""
        raise NotImplementedError

    @abstractmethod
    def actor_initial_states(self, obs, *args):
        """
        Returns all policy inputs to predict the environment initial action.

        Parameters
        ----------
        obs : torch.tensor
            Initial environment observation.

        Returns
        -------
        obs : torch.tensor
            Initial environment observation.
        rhs : torch.tensor
            Initial recurrent hidden state.
        done : torch.tensor
            Initial done tensor, indicating the environment is not done.
        """
        raise NotImplementedError

    def try_load_from_checkpoint(self):
        """
        Load weights from a previously saved checkpoint (best effort).

        Behaviour depends on the type of ``self.checkpoint``:

        - str: the file is read with ``torch.load`` (mapped to
          ``self.device``) and applied with ``load_state_dict``.
        - dict: every ``(submodule, checkpoint)`` item is forwarded to
          ``partially_load_checkpoint``.
        - anything else: nothing is loaded.

        Loading failures never propagate. They are reported on stdout
        together with the exception type and message so that the cause of a
        failed load can be diagnosed.
        """
        try:
            if isinstance(self.checkpoint, str):
                print("Loading all model weight from {}".format(self.checkpoint))
                state_dict = torch.load(self.checkpoint, map_location=self.device)
                self.load_state_dict(state_dict)
            elif isinstance(self.checkpoint, dict):
                for submodule, checkpoint in self.checkpoint.items():
                    print("Loading {} model weight from {}".format(
                        submodule, checkpoint))
                    partially_load_checkpoint(
                        self, submodule, checkpoint, map_location=self.device)
            else:
                print("Training model from scratch")
        except Exception as err:
            # Deliberately best-effort: training continues with the current
            # weights, but the reason for the failure is surfaced instead of
            # being discarded.
            print("Error when trying to load checkpoint! ({}: {})".format(
                type(err).__name__, err))