World Models
WorldModel
- class pytorchrl.agent.actors.world_models.world_model.WorldModel(device, input_space, action_space, standard_scaler, hidden_size=64, reward_function=None)
Bases: torch.nn.modules.module.Module
Model-Based Actor class for Model-Based algorithms.
It contains the dynamics network that predicts the next state and, optionally, the reward.
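- Parameters
device – Device on which the model runs.
input_space (gym.Space) – Environment observation space.
action_space (gym.Space) – Environment action space.
standard_scaler (StandardScaler) – StandardScaler instance used for data normalization.
hidden_size (int) – Number of units in the hidden layers of the dynamics network. Defaults to 64.
reward_function (func) – Reward function used to compute rewards for predicted transitions. Defaults to None.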
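The `standard_scaler` parameter suggests that data is z-score normalized before it reaches the dynamics network. The sketch below illustrates that idea with NumPy only; `SimpleStandardScaler` is a hypothetical stand-in, not the pytorchrl `StandardScaler` class, and its `fit`/`transform` method names are assumptions made for illustration.

```python
import numpy as np


class SimpleStandardScaler:
    """Hypothetical stand-in for a StandardScaler: per-feature z-score normalization."""

    def __init__(self):
        self.mu = None
        self.std = None

    def fit(self, data: np.ndarray) -> None:
        # Per-feature mean and standard deviation over the batch dimension.
        self.mu = data.mean(axis=0, keepdims=True)
        self.std = data.std(axis=0, keepdims=True)
        # Guard against division by zero for constant features.
        self.std[self.std < 1e-12] = 1.0

    def transform(self, data: np.ndarray) -> np.ndarray:
        return (data - self.mu) / self.std


# Normalize concatenated (state, action) inputs before they reach a dynamics network.
states = np.array([[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]])
actions = np.array([[1.0], [1.0], [1.0]])
inputs = np.concatenate([states, actions], axis=1)

scaler = SimpleStandardScaler()
scaler.fit(inputs)
normalized = scaler.transform(inputs)
```

Fitting the statistics once on collected data and reusing them for every prediction keeps the network inputs on a comparable scale, regardless of the raw units of each observation or action dimension.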
- create_dynamics()
Creates a dynamics model and stores it as a class attribute under the given name.
- Parameters
name (str) – dynamics model name.
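A minimal sketch of what "create a dynamics model and store it as an attribute named `name`" can look like. It assumes a one-hidden-layer MLP whose input is the concatenated state and action; `ToyWorldModel`, its constructor arguments, and the parameter dictionary layout are all hypothetical and do not reproduce the pytorchrl implementation, which builds a `torch.nn.Module`.

```python
import numpy as np


class ToyWorldModel:
    """Hypothetical minimal world model; not the pytorchrl implementation."""

    def __init__(self, obs_dim: int, act_dim: int, hidden_size: int = 64, seed: int = 0):
        self.obs_dim = obs_dim
        self.act_dim = act_dim
        self.hidden_size = hidden_size
        self.rng = np.random.default_rng(seed)

    def create_dynamics(self, name: str = "dynamics_model") -> None:
        # One-hidden-layer MLP mapping (state, action) -> next state.
        in_dim = self.obs_dim + self.act_dim
        params = {
            "W1": self.rng.normal(0.0, 1.0 / np.sqrt(in_dim), (in_dim, self.hidden_size)),
            "b1": np.zeros(self.hidden_size),
            "W2": self.rng.normal(0.0, 1.0 / np.sqrt(self.hidden_size), (self.hidden_size, self.obs_dim)),
            "b2": np.zeros(self.obs_dim),
        }
        # Store the network as an attribute under the requested name.
        setattr(self, name, params)


model = ToyWorldModel(obs_dim=3, act_dim=1, hidden_size=16)
model.create_dynamics("dynamics_model")
```

Using `setattr` lets the same method build several differently named networks on one object, for example a separate model per ensemble member.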
- predict_given_reward(states: torch.Tensor, actions: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor]
Predicts the next states and computes the rewards with the given reward function.
- Parameters
states (torch.Tensor) – Current state s
actions (torch.Tensor) – Action taken in state s
- Returns
next_states (torch.Tensor) – Next states.
rewards (torch.Tensor) – Calculated reward.
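The sketch below shows the general flow of a "given reward" prediction: the dynamics network predicts the next states, and a known reward function then scores the transitions. Several choices are assumptions for illustration only: the network predicts a state *change* that is added to the current state (a common choice in model-based RL), the reward function takes `(states, actions, next_states)`, and `known_reward` is a made-up reward, not one used by pytorchrl.

```python
import numpy as np


def mlp_forward(params, x):
    # Hidden layer with tanh activation, then a linear output layer.
    h = np.tanh(x @ params["W1"] + params["b1"])
    return h @ params["W2"] + params["b2"]


def known_reward(states, actions, next_states):
    # Hypothetical reward: negative squared distance of the next state from the
    # origin, minus a small penalty on the action magnitude.
    return -np.sum(next_states ** 2, axis=-1) - 0.1 * np.sum(actions ** 2, axis=-1)


def predict_given_reward(params, states, actions, reward_function):
    x = np.concatenate([states, actions], axis=-1)
    # Assumption: the network outputs the state change, so next = s + delta.
    next_states = states + mlp_forward(params, x)
    rewards = reward_function(states, actions, next_states)
    return next_states, rewards


rng = np.random.default_rng(0)
obs_dim, act_dim, hidden = 3, 1, 8
params = {
    "W1": rng.normal(size=(obs_dim + act_dim, hidden)),
    "b1": np.zeros(hidden),
    "W2": np.zeros((hidden, obs_dim)),  # zero output layer: predicted change is 0
    "b2": np.zeros(obs_dim),
}
states = rng.normal(size=(5, obs_dim))
actions = rng.normal(size=(5, act_dim))
next_states, rewards = predict_given_reward(params, states, actions, known_reward)
```

With a zeroed output layer the model predicts "no change", so `next_states` equals `states`, which makes the reward easy to check by hand.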
- predict_learned_reward(states: torch.Tensor, actions: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor]
Predicts the next states and the rewards, the latter with a learned reward function.
- Parameters
states (torch.Tensor) – Current state s
actions (torch.Tensor) – Action taken in state s
- Returns
next_states (torch.Tensor) – Next states.
rewards (torch.Tensor) – Reward prediction.
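When the reward is learned rather than given, one straightforward design is a single network with `obs_dim + 1` outputs: the first `obs_dim` outputs describe the next state and the last one is the reward. The sketch assumes exactly that layout, plus the same "predict the state change" convention as before. Both choices are illustrative assumptions, not a description of how pytorchrl arranges its outputs.

```python
import numpy as np


def predict_learned_reward(params, states, actions):
    x = np.concatenate([states, actions], axis=-1)
    h = np.tanh(x @ params["W1"] + params["b1"])
    out = h @ params["W2"] + params["b2"]  # shape (batch, obs_dim + 1)
    # Assumption: first obs_dim outputs are the state change, the last is the reward.
    delta, rewards = out[:, :-1], out[:, -1:]
    return states + delta, rewards


rng = np.random.default_rng(1)
obs_dim, act_dim, hidden = 3, 2, 8
params = {
    "W1": rng.normal(size=(obs_dim + act_dim, hidden)),
    "b1": np.zeros(hidden),
    "W2": rng.normal(size=(hidden, obs_dim + 1)),
    "b2": np.zeros(obs_dim + 1),
}
states = rng.normal(size=(4, obs_dim))
actions = rng.normal(size=(4, act_dim))
next_states, rewards = predict_learned_reward(params, states, actions)
```

Sharing one hidden representation between the state and reward heads keeps the model compact. The trade-off is that both targets are fitted jointly, so their loss terms may need balancing during training.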
- reinitialize_dynamics_model()
Re-initializes the dynamics model. This can be done before each new model-learning run and may help reduce over-fitting of the model in some environments.
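Re-initialization simply discards the learned weights and draws fresh ones, so the next model-learning run starts from scratch instead of from a possibly over-fitted solution. The NumPy sketch below illustrates this with a hypothetical `ToyWorldModel` and `init_params` helper. In PyTorch, a common way to achieve the same effect is to call `reset_parameters()` on each `nn.Linear` layer; whether pytorchrl does this or rebuilds the network is not stated here.

```python
import numpy as np


def init_params(rng, in_dim, hidden, out_dim):
    # Fresh, scaled Gaussian weights and zero biases.
    return {
        "W1": rng.normal(0.0, 1.0 / np.sqrt(in_dim), (in_dim, hidden)),
        "b1": np.zeros(hidden),
        "W2": rng.normal(0.0, 1.0 / np.sqrt(hidden), (hidden, out_dim)),
        "b2": np.zeros(out_dim),
    }


class ToyWorldModel:
    """Hypothetical minimal world model; not the pytorchrl implementation."""

    def __init__(self, in_dim, hidden, out_dim, seed=0):
        self.dims = (in_dim, hidden, out_dim)
        self.rng = np.random.default_rng(seed)
        self.dynamics_model = init_params(self.rng, *self.dims)

    def reinitialize_dynamics_model(self):
        # Discard learned weights so the next model-learning run starts from scratch.
        self.dynamics_model = init_params(self.rng, *self.dims)


model = ToyWorldModel(in_dim=4, hidden=16, out_dim=3)
model.dynamics_model["W1"] += 5.0  # pretend training drifted the weights
model.dynamics_model["b1"] += 1.0
model.reinitialize_dynamics_model()
```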
- training: bool