World Models

WorldModel

class pytorchrl.agent.actors.world_models.world_model.WorldModel(device, input_space, action_space, standard_scaler, hidden_size=64, reward_function=None)[source]

Bases: torch.nn.modules.module.Module

Actor class for model-based algorithms.

It contains the dynamics network used to predict the next state (and, optionally, the reward).

Parameters
  • input_space (gym.Space) – Environment observation space.

  • action_space (gym.Space) – Environment action space.

  • hidden_size (int) – Hidden layer size.

  • standard_scaler (StandardScaler) – StandardScaler class instance.

  • reward_function (func) – Reward function to be learned.

static check_dynamics_weights(parameter1, parameter2)[source]

create_dynamics()[source]

Create a dynamics model and store it as a class attribute under the name given by name.

Parameters
  • name (str) – Dynamics model name.

predict_given_reward(states: torch.Tensor, actions: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor][source]

Predicts the next state and computes the reward using a given reward function.

Parameters
  • states (torch.Tensor) – Current state s

  • actions (torch.Tensor) – Action taken in state s

Returns

  • next_states (torch.Tensor) – Next states.

  • rewards (torch.Tensor) – Calculated reward.
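
The idea of pairing a learned next-state prediction with a known reward function can be sketched as follows. This is a minimal, self-contained illustration and not the pytorchrl implementation: ToyDynamics, toy_reward, predict_given_reward_sketch and the (states, actions, next_states) reward signature are all hypothetical names and assumptions, and any input scaling the library may apply through standard_scaler is left out.

```python
import torch
import torch.nn as nn


class ToyDynamics(nn.Module):
    """Hypothetical stand-in for a dynamics network (not the pytorchrl class)."""

    def __init__(self, state_dim: int, action_dim: int, hidden_size: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, state_dim),
        )

    def forward(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        # Concatenate state and action along the feature dimension.
        return self.net(torch.cat([states, actions], dim=-1))


def toy_reward(states, actions, next_states):
    # Assumed reward signature: penalize distance of the next state from the origin.
    return -(next_states ** 2).sum(dim=-1, keepdim=True)


def predict_given_reward_sketch(dynamics, reward_function, states, actions):
    # Predict next states with the network, then evaluate the known reward function.
    with torch.no_grad():
        next_states = dynamics(states, actions)
    rewards = reward_function(states, actions, next_states)
    return next_states, rewards


torch.manual_seed(0)
dynamics = ToyDynamics(state_dim=3, action_dim=2)
states = torch.randn(8, 3)
actions = torch.randn(8, 2)
next_states, rewards = predict_given_reward_sketch(dynamics, toy_reward, states, actions)
```

Here the reward is never learned; only the transition model is, which is useful when the task reward is known in closed form.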

predict_learned_reward(states: torch.Tensor, actions: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor][source]

Predicts the next state and the reward with a learned reward function.

Parameters
  • states (torch.Tensor) – Current state s

  • actions (torch.Tensor) – Action taken in state s

Returns

  • next_states (torch.Tensor) – Next states.

  • rewards (torch.Tensor) – Reward prediction.
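
One common way to learn the reward alongside the dynamics is to give the network one extra output and split it off as the reward prediction. The sketch below assumes that layout purely for illustration; the source does not state how the library structures its outputs, and ToyDynamicsWithReward is a hypothetical name.

```python
import torch
import torch.nn as nn


class ToyDynamicsWithReward(nn.Module):
    """Hypothetical network that predicts next state and reward jointly."""

    def __init__(self, state_dim: int, action_dim: int, hidden_size: int = 64):
        super().__init__()
        self.state_dim = state_dim
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_size),
            nn.ReLU(),
            # Assumed layout: state_dim outputs for the next state plus 1 for the reward.
            nn.Linear(hidden_size, state_dim + 1),
        )

    def forward(self, states: torch.Tensor, actions: torch.Tensor):
        out = self.net(torch.cat([states, actions], dim=-1))
        next_states = out[..., : self.state_dim]
        rewards = out[..., self.state_dim :]
        return next_states, rewards


torch.manual_seed(0)
model = ToyDynamicsWithReward(state_dim=3, action_dim=2)
with torch.no_grad():
    pred_next_states, pred_rewards = model(torch.randn(8, 3), torch.randn(8, 2))
```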

reinitialize_dynamics_model()[source]

Re-initializes the dynamics model. This can be done before each new model learning run and might help in some environments to overcome over-fitting of the model.
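
As a rough illustration of what re-initialization means for a PyTorch module, the sketch below resets every submodule that exposes reset_parameters(). It is an assumption about one possible approach, not the library's method (which may instead rebuild the network), and reinitialize_sketch is a hypothetical helper.

```python
import torch
import torch.nn as nn


def reinitialize_sketch(model: nn.Module) -> None:
    # Redraw the weights of every submodule that knows how to reset itself
    # (nn.Linear does; nn.ReLU and containers do not and are skipped).
    for module in model.modules():
        reset = getattr(module, "reset_parameters", None)
        if callable(reset):
            reset()


torch.manual_seed(0)
toy_model = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 3))
before = [p.detach().clone() for p in toy_model.parameters()]

reinitialize_sketch(toy_model)
after = [p.detach().clone() for p in toy_model.parameters()]

# True if at least one parameter tensor changed after the reset.
changed = any(not torch.equal(b, a) for b, a in zip(before, after))
```

Starting each model-learning run from fresh weights discards whatever the previous run fitted, which is the over-fitting remedy the description refers to.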

training: bool