import torch
import torch.nn as nn
class EnsembleFC(nn.Module):
    """Fully connected layer evaluated independently for each ensemble member.

    The layer holds one weight matrix (and optionally one bias vector) per
    ensemble member and applies them with a single batched matrix multiply.

    Attributes:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        ensemble_size: Number of independent members (leading dimension).
        weight_decay: Value stored on the module; it is not applied anywhere
            in this class (presumably read by external training code --
            confirm against callers).
        weight: Parameter of shape ``(ensemble_size, in_features, out_features)``.
        bias: Parameter of shape ``(ensemble_size, out_features)``, or ``None``
            when the layer was built with ``bias=False``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        ensemble_size: int,
        weight_decay: float = 0.,
        bias: bool = True,
    ) -> None:
        """Create the per-member parameters.

        Args:
            in_features: Size of each input sample.
            out_features: Size of each output sample.
            ensemble_size: Number of ensemble members.
            weight_decay: Stored as ``self.weight_decay``; unused by this class.
            bias: If ``False``, no bias parameter is created and ``forward``
                returns the plain batched product.
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.ensemble_size = ensemble_size
        self.weight_decay = weight_decay

        # Allocate uninitialised storage, then fill it with Xavier-uniform
        # values (same initialisation as before, applied to the 3-D tensor).
        self.weight = nn.Parameter(
            torch.empty(ensemble_size, in_features, out_features)
        )
        nn.init.xavier_uniform_(self.weight)

        if bias:
            self.bias = nn.Parameter(torch.zeros(ensemble_size, out_features))
        else:
            # Registering ``None`` keeps ``self.bias`` a valid attribute.
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Hook kept for API compatibility; currently does nothing.

        Parameters are initialised directly in ``__init__``.
        """
        pass

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply every ensemble member's affine map to its slice of ``input``.

        Args:
            input: 3-D tensor of shape ``(ensemble_size, batch, in_features)``
                as required by ``torch.bmm`` against ``self.weight``.

        Returns:
            Tensor of shape ``(ensemble_size, batch, out_features)`` equal to
            ``input @ weight`` per member, plus the member's bias if present.
        """
        w_times_x = torch.bmm(input, self.weight)
        if self.bias is None:
            # Built with bias=False: indexing ``None`` would raise TypeError.
            return w_times_x
        # Insert a batch axis so the bias broadcasts over every sample.
        return torch.add(w_times_x, self.bias[:, None, :])  # w times x + b

    def extra_repr(self) -> str:
        """Return the configuration string shown in ``repr(module)``."""
        return (
            f"in_features={self.in_features}, "
            f"out_features={self.out_features}, "
            f"ensemble_size={self.ensemble_size}, "
            f"weight_decay={self.weight_decay}, "
            f"bias={self.bias is not None}"
        )