Source code for pytorchrl.agent.actors.feature_extractors

import gym
import torch.nn as nn
from pytorchrl.agent.actors.feature_extractors.mlp import MLP
from pytorchrl.agent.actors.feature_extractors.cnn import CNN
from pytorchrl.agent.actors.feature_extractors.dictnet import DictNet
from pytorchrl.agent.actors.feature_extractors.fixup_cnn import FixupCNN
from pytorchrl.agent.actors.feature_extractors.embedding import Embedding


[docs]def get_feature_extractor(name): """Returns model class from name.""" if name is None: return None elif name == "MLP": return MLP elif name == "CNN": return CNN elif name == "Fixup": return FixupCNN elif name == "DictNet": return DictNet elif name == "Embedding": return Embedding else: raise ValueError("Specified feature extractor model not found!")
[docs]def default_feature_extractor(input_space): """ Returns the default net to use as a feature extractor given input_space. Parameters ---------- input_space : gym.Space Environment observation space. """ if isinstance(input_space, gym.spaces.Dict): net = DictNet elif len(input_space.shape) <= 2: net = nn.Identity elif len(input_space.shape) == 3: net = CNN else: raise NotImplementedError return net