Source code for pytorchrl.agent.actors.memory_networks

from pytorchrl.agent.actors.memory_networks.gru_net import GruNet
from pytorchrl.agent.actors.memory_networks.lstm_net import LstmNet
from pytorchrl.agent.actors.memory_networks.lstm_encoder_decoder_net import LSTMEncoderDecoder


[docs]def get_memory_network(name): """Returns model class from name.""" if name is None: return None elif name == "GRU": return GruNet elif name == "LSTM": return LstmNet elif name == "LSTMEncoderDecoder": return LSTMEncoderDecoder else: raise ValueError("Specified memory net model not found!")