# Source code for pytorchrl.agent.actors.world_models.utils
import torch
[docs]class StandardScaler(object):
def __init__(self, device):
self.input_mu = torch.zeros(1).to(device)
self.input_std = torch.ones(1).to(device)
self.target_mu = torch.zeros(1).to(device)
self.target_std = torch.ones(1).to(device)
self.device = device
[docs] def fit(self, inputs, targets):
"""
Runs two ops, one for assigning the mean of the data to the internal mean, and
another for assigning the standard deviation of the data to the internal standard deviation.
This function must be called within a 'with <session>.as_default()' block.
Parameters
----------
inputs : torch.Tensor
A torch Tensor containing the input
targets : torch.Tensor
A torch Tensor containing the input
"""
self.input_mu = torch.mean(inputs, dim=0, keepdims=True).to(self.device)
self.input_std = torch.std(inputs, dim=0, keepdims=True).to(self.device)
self.input_std[self.input_std < 1e-8] = 1.0
self.target_mu = torch.mean(targets, dim=0, keepdims=True).to(self.device)
self.target_std = torch.std(targets, dim=0, keepdims=True).to(self.device)
self.target_std[self.target_std < 1e-8] = 1.0
[docs] def transform(self, inputs, targets=None):
"""
Transforms the input matrix data using the parameters of this scaler.
Parameters
----------
inputs : torch.Tensor
A torch Tensor containing the points to be transformed.
targets : torch.Tensor
A torch Tensor containing the points to be transformed.
Returns
-------
norm_inputs : torch.Tensor
Normalized inputs
norm_targets : torch.Tensor
Normalized targets
"""
norm_inputs = (inputs - self.input_mu) / self.input_std
norm_targets = None
if targets is not None:
norm_targets = (targets - self.target_mu) / self.target_std
return norm_inputs, norm_targets
[docs] def inverse_transform(self, targets):
"""
Undoes the transformation performed by this scaler.
Parameters
----------
targets : torch.Tensor
A torch Tensor containing the points to be transformed.
Returns
-------
output : torch.Tensor
The transformed dataset.
"""
output = self.target_std * targets + self.target_mu
return output