Source code for pytorchrl.agent.actors.utils

import torch
import torch.nn as nn


class Scale(nn.Module):
    """
    Rescale inputs linearly from [space.low, space.high] to [-1, 1].

    Parameters
    ----------
    space : gym.Space
        Space providing the `low` and `high` bounds to map from.

    Attributes
    ----------
    low : torch.tensor
        Lower bound of the unscaled Space, registered as a buffer.
    high : torch.tensor
        Upper bound of the unscaled Space, registered as a buffer.
    """

    def __init__(self, space):
        super().__init__()
        # Both bounds are converted with torch.from_numpy and stored as
        # buffers (not parameters), lower bound first.
        for bound_name in ("low", "high"):
            bound = torch.from_numpy(getattr(space, bound_name))
            self.register_buffer(bound_name, bound)

    def forward(self, x):
        """
        Map x from [space.low, space.high] to [-1, 1].

        Parameters
        ----------
        x : torch.tensor
            Input to be scaled.

        Returns
        -------
        torch.tensor
            ``2 * (x - low) / (high - low) - 1``.
        """
        span = self.high - self.low
        # Fraction of the way from low to high: 0 at low, 1 at high.
        fraction = (x - self.low) / span
        return 2.0 * fraction - 1.0
class Unscale(nn.Module):
    """
    Rescale inputs linearly from [-1, 1] to [space.low, space.high].

    Parameters
    ----------
    space : gym.Space
        Space providing the `low` and `high` bounds to map to.

    Attributes
    ----------
    low : torch.tensor
        Lower bound of the unscaled Space, registered as a buffer.
    high : torch.tensor
        Upper bound of the unscaled Space, registered as a buffer.
    """

    def __init__(self, space):
        super().__init__()
        # Both bounds are converted with torch.from_numpy and stored as
        # buffers (not parameters), lower bound first.
        for bound_name in ("low", "high"):
            bound = torch.from_numpy(getattr(space, bound_name))
            self.register_buffer(bound_name, bound)

    def forward(self, x):
        """
        Map x from [-1, 1] to [space.low, space.high].

        Parameters
        ----------
        x : torch.tensor
            Input to be unscaled.

        Returns
        -------
        torch.tensor
            ``low + 0.5 * (x + 1) * (high - low)``.
        """
        # Shift [-1, 1] onto [0, 1] before stretching to the target range.
        unit_fraction = 0.5 * (x + 1.0)
        stretched = unit_fraction * (self.high - self.low)
        return self.low + stretched
def init(module, weight_init, bias_init, gain=1):
    """
    Initialize the weight and bias of a module in place.

    Parameters
    ----------
    module : nn.Module
        nn.Module to initialize. Its `weight.data` is passed to
        `weight_init`; its `bias.data` (when a bias exists) is passed to
        `bias_init`.
    weight_init : func
        Function to initialize module weights. Called as
        ``weight_init(tensor, gain=gain)``.
    bias_init : func
        Function to initialize module biases. Called as
        ``bias_init(tensor)``.
    gain : float
        Scaling factor forwarded to `weight_init` as the `gain` keyword.

    Returns
    -------
    module : nn.Module
        The same module instance, initialized.
    """
    # Initialize the weights exactly once, with the requested gain. A second
    # call without `gain` would overwrite the result and silently discard the
    # gain argument.
    weight_init(module.weight.data, gain=gain)

    # Modules built without a bias expose `bias` as None; skip them instead
    # of failing on `None.data`.
    bias = getattr(module, "bias", None)
    if bias is not None:
        bias_init(bias.data)

    return module
def partially_load_checkpoint(module, submodule_name, checkpoint, map_location):
    """
    Load `submodule_name` to `module` from checkpoint.

    Only checkpoint entries whose key starts with `submodule_name` are used;
    every other entry is ignored.

    Parameters
    ----------
    module : nn.Module
        Module whose state is partially overwritten.
    submodule_name : str
        Key prefix selecting which checkpoint entries to copy.
    checkpoint : str or file-like
        Checkpoint location, passed straight to `torch.load`.
        NOTE(review): assumed to hold a state-dict-like mapping of names to
        tensors, since `.items()` is called on the loaded object.
    map_location : str or torch.device
        Forwarded to `torch.load` as `map_location`.
    """
    module_state = module.state_dict()
    loaded_state = torch.load(checkpoint, map_location=map_location)

    for key, value in loaded_state.items():
        if not key.startswith(submodule_name):
            continue
        # In-place copy into the tensor held by the module's state dict; a
        # matching key that is absent from `module_state` raises KeyError.
        module_state[key].copy_(value)