import gym
import numpy as np
import torch
import torch.nn as nn
from pytorchrl.agent.actors.utils import init
from pytorchrl.agent.actors.feature_extractors.utils import get_gain
class CNN(nn.Module):
    """
    Convolutional Neural Network.

    A stack of ``Conv2d`` + activation layers followed by a fully connected
    head. Convolutional and linear weights are initialised orthogonally and
    their biases are set to zero.

    Parameters
    ----------
    input_space : gym.Space or sequence of int
        Environment observation space, or directly its shape. The shape must
        have exactly 3 dimensions; its first entry is used as the number of
        input channels of the first convolution.
    rgb_norm : bool
        Whether or not to divide input by 255.
    activation : func
        Non-linear activation function. Called with no arguments to build
        one activation module per layer.
    final_activation : bool
        Whether or not to apply activation function after last layer.
    dropout : float
        Dropout probability. Applied after hidden linear layers only
        (never after the last one).
    layer_norm : bool
        Use layer normalization after hidden linear layers (never after the
        last one).
    strides : list
        Convolutional layers strides.
    filters : list
        Convolutional layers number of filters.
    kernel_sizes : list
        Convolutional layers kernel sizes.
    output_sizes : list
        Output hidden layers sizes.

    Raises
    ------
    ValueError
        If ``strides``, ``filters`` and ``kernel_sizes`` differ in length, or
        if the input shape does not have exactly 3 dimensions.
    """

    def __init__(self,
                 input_space,
                 rgb_norm=True,
                 activation=nn.ReLU,
                 final_activation=True,
                 dropout=0.0,
                 layer_norm=False,
                 strides=(4, 2, 1),
                 filters=(32, 64, 64),
                 kernel_sizes=(8, 4, 3),
                 output_sizes=(256, 448)):
        super(CNN, self).__init__()
        self.rgb_norm = rgb_norm

        # Work on private list copies so callers may pass any sequence type
        # and so the (immutable) defaults are never shared between instances.
        strides = list(strides)
        filters = list(filters)
        kernel_sizes = list(kernel_sizes)
        output_sizes = list(output_sizes)

        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if not (len(filters) == len(strides) == len(kernel_sizes)):
            raise ValueError(
                "filters, strides and kernel_sizes must have the same length, "
                "got {}, {} and {}".format(
                    len(filters), len(strides), len(kernel_sizes)))

        if isinstance(input_space, gym.Space):
            input_shape = input_space.shape
        else:
            input_shape = input_space

        if len(input_shape) != 3:
            raise ValueError("Trying to extract features with a CNN for an obs space with len(shape) != 3")

        # Define CNN feature extractor
        channels = [input_shape[0]] + filters
        conv_layers = []
        for in_channels, out_channels, stride, kernel_size in zip(
                channels[:-1], channels[1:], strides, kernel_sizes):
            conv_layers += [
                nn.Conv2d(in_channels, out_channels,
                          stride=stride, kernel_size=kernel_size),
                activation(),
            ]
        self.feature_extractor = nn.Sequential(*conv_layers)

        # Probe the extractor with a single dummy observation to discover the
        # flattened feature size. No gradients are needed for this pass.
        with torch.no_grad():
            feature_size = self.feature_extractor(
                torch.randn(1, *input_shape)).numel()

        # Define final MLP layers
        sizes = [feature_size] + output_sizes
        last_index = len(sizes) - 2
        head_layers = []
        for j, (in_features, out_features) in enumerate(zip(sizes[:-1], sizes[1:])):
            is_hidden = j < last_index
            head_layers.append(nn.Linear(in_features, out_features))
            if is_hidden and dropout > 0.0:
                head_layers.append(nn.Dropout(dropout))
            if is_hidden and layer_norm:
                head_layers.append(nn.LayerNorm(out_features))
            if is_hidden or final_activation:
                head_layers.append(activation())
        self.head = nn.Sequential(*head_layers)

        # Orthogonal weights / zero biases for every conv and linear layer.
        # ``self.modules()`` visits the feature extractor before the head, so
        # the initialisation order is convolutions first, then linear layers.
        # NOTE(review): get_gain is called once here on the assumption that it
        # depends only on ``activation`` - confirm against its definition.
        gain = get_gain(activation)
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.orthogonal_(module.weight, gain=gain)
                module.bias.data.zero_()

        self.train()

    def forward(self, inputs):
        """
        Forward pass Neural Network

        Parameters
        ----------
        inputs : torch.tensor
            Input data.

        Returns
        -------
        out : torch.tensor
            Output feature map, with one row per element of the first
            dimension of ``inputs``.
        """
        if self.rgb_norm:
            inputs = inputs / 255.0

        out = self.feature_extractor(inputs)
        # Collapse everything but the leading (batch) dimension; unlike
        # ``view(batch, -1)`` this also works when the batch is empty.
        out = torch.flatten(out, start_dim=1)
        return self.head(out)