Source code for pytorchrl.agent.actors.feature_extractors.ensemble_layer

import torch
import torch.nn as nn


class EnsembleFC(nn.Module):
    """Fully connected layer applied independently by each ensemble member.

    Holds one ``(in_features, out_features)`` weight matrix (and optionally
    one bias vector) per ensemble member and evaluates all members in a
    single batched matrix multiplication.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        ensemble_size: Number of independent members in the ensemble.
        weight_decay: Stored on the module as ``self.weight_decay``; it is
            not used inside this class.
        bias: If ``False``, the layer has no additive bias.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        ensemble_size: int,
        weight_decay: float = 0.,
        bias: bool = True,
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.ensemble_size = ensemble_size
        self.weight_decay = weight_decay

        # Allocate uninitialized storage, then fill it in place.
        # NOTE(review): xavier_uniform_ is applied to the full 3-D tensor, so
        # the fan computation sees the ensemble axis too — confirm the
        # resulting scale is the intended one.
        self.weight = nn.Parameter(
            torch.empty(ensemble_size, in_features, out_features)
        )
        torch.nn.init.xavier_uniform_(self.weight)

        if bias:
            self.bias = nn.Parameter(torch.zeros(ensemble_size, out_features))
        else:
            # Registering None keeps ``self.bias`` defined so callers can test it.
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Do nothing; parameters are initialized in ``__init__``."""
        pass

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply every ensemble member to its own slice of ``input``.

        Args:
            input: Tensor of shape ``(ensemble_size, n, in_features)``, as
                required by ``torch.bmm`` against the weight tensor.

        Returns:
            Tensor of shape ``(ensemble_size, n, out_features)``.
        """
        w_times_x = torch.bmm(input, self.weight)
        if self.bias is None:
            # Layer was built with bias=False: there is nothing to add.
            return w_times_x
        # Insert a middle axis so the bias broadcasts over the n samples.
        return torch.add(w_times_x, self.bias[:, None, :])  # w times x + b

    def extra_repr(self) -> str:
        """Return the layer description shown in the module's ``repr``."""
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )