pytorchrl package
Subpackages
- pytorchrl.agent package
- pytorchrl.envs package
- Subpackages
- Submodules
- pytorchrl.envs.common module
- Module contents
- pytorchrl.scheme package
Submodules
pytorchrl.learner module
- class pytorchrl.learner.Learner(scheme, target_steps, log_dir=None)[source]
Bases: object
Task learner class. Manages the training process: it advances training by calling the update workers and tracks progress.
- Parameters
scheme (Scheme) – Training scheme class instance, handling coordination of workers.
target_steps (int) – Number of environment steps to reach to complete training.
log_dir (str) – Target directory for model checkpoints and logs.
- done()[source]
Return True if training has finished (target_steps reached).
- Returns
flag – True if training has reached the target number of steps.
- Return type
bool
- get_metrics(add_algo_metrics=True, add_episodes_metrics=False, add_scheme_metrics=False, add_time_metrics=False)[source]
Return the current values of the tracked metrics.
- print_info(add_algo_info=True, add_episodes_info=True, add_scheme_info=False, add_time_info=False)[source]
Print relevant information about the training process.
- save_model(fname='model.state_dict')[source]
Save the current version of the learned actor_critic.
- Returns
save_name – Path to saved file.
- Return type
str
- update_algorithm_parameter(parameter_name, new_parameter_value)[source]
If parameter_name is an attribute of the algorithm used for training, change its value to new_parameter_value.
- Parameters
parameter_name (str) – Worker.algo attribute name
new_parameter_value (int or float) – New value for parameter_name.
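The behaviour described for update_algorithm_parameter can be illustrated with a minimal sketch. ToyAlgo and update_parameter below are hypothetical stand-ins, not part of pytorchrl. The sketch only assumes that the update amounts to an attribute check followed by an assignment.

```python
class ToyAlgo:
    """Hypothetical stand-in for the algorithm object held by a worker."""

    def __init__(self, lr=3e-4, gamma=0.99):
        self.lr = lr
        self.gamma = gamma


def update_parameter(algo, parameter_name, new_parameter_value):
    """Set algo.<parameter_name> only if that attribute already exists.

    Returns True when the value was changed, False otherwise.
    (The return value is an illustrative addition, not documented behaviour.)
    """
    if hasattr(algo, parameter_name):
        setattr(algo, parameter_name, new_parameter_value)
        return True
    return False


algo = ToyAlgo()
changed = update_parameter(algo, "lr", 1e-4)        # existing attribute: updated
ignored = update_parameter(algo, "not_a_param", 1)  # unknown name: left untouched
```

Guarding the assignment with an attribute check means a misspelled parameter name does not silently create a new, unused attribute.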
pytorchrl.utils module
- class pytorchrl.utils.LoadFromFile(option_strings, dest, nargs=None, const=None, default=None, type=None, choices=None, required=False, help=None, metavar=None)[source]
Bases: argparse.Action
- class pytorchrl.utils.RunningMeanStd(epsilon=0.0001, shape=(), device=device(type='cpu'))[source]
Bases: object
Class to keep track of the running mean and variance of tensor batches.
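One common way to keep running statistics over batches is to merge each batch's mean and variance into the accumulated moments. The sketch below shows that technique on plain Python floats. It is not the pytorchrl implementation, which operates on tensors, and the class name, the initial values (mean 0, variance 1) and the use of epsilon as an initial pseudo-count are all assumptions.

```python
class RunningMeanStdSketch:
    """Running mean/variance over batches of scalars (illustrative only)."""

    def __init__(self, epsilon=1e-4):
        # Assumption: start from mean 0, variance 1 and a tiny pseudo-count,
        # so the first real batch dominates the statistics.
        self.mean = 0.0
        self.var = 1.0
        self.count = epsilon

    def update(self, batch):
        n = len(batch)
        batch_mean = sum(batch) / n
        batch_var = sum((x - batch_mean) ** 2 for x in batch) / n
        delta = batch_mean - self.mean
        total = self.count + n
        # Merge the two sets of moments without revisiting old samples.
        m2 = (self.var * self.count + batch_var * n
              + delta ** 2 * self.count * n / total)
        self.mean += delta * n / total
        self.var = m2 / total
        self.count = total


rms = RunningMeanStdSketch()
rms.update([1.0, 2.0, 3.0, 4.0])
rms.update([5.0, 6.0, 7.0, 8.0])
```

After both updates, rms.mean and rms.var are close to the mean and population variance of the eight samples, offset slightly by the pseudo-count.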
- pytorchrl.utils.cleanup_log_dir(log_dir)[source]
Create log directory and remove old files.
- Parameters
log_dir (str) – Path to log directory.
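A create-then-clean step of this kind can be sketched with the standard library. The function name below is hypothetical, and which files count as "old" is an assumption: this sketch deletes every regular file directly inside the directory and leaves subdirectories alone, which may differ from what pytorchrl does.

```python
import os
import tempfile


def cleanup_log_dir_sketch(log_dir):
    """Create log_dir if needed and delete regular files directly inside it."""
    os.makedirs(log_dir, exist_ok=True)
    for name in os.listdir(log_dir):
        path = os.path.join(log_dir, name)
        if os.path.isfile(path):  # assumption: subdirectories are kept
            os.remove(path)


with tempfile.TemporaryDirectory() as tmp:
    log_dir = os.path.join(tmp, "logs")
    cleanup_log_dir_sketch(log_dir)          # first call creates the directory
    stale = os.path.join(log_dir, "old.csv")
    open(stale, "w").close()                 # simulate a leftover file
    cleanup_log_dir_sketch(log_dir)          # second call removes it
    remaining = os.listdir(log_dir)
```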
- pytorchrl.utils.clip_grad_norm_(parameters, norm_type: float = 2.0)[source]
This is the official clip_grad_norm_ implementation from PyTorch with the max_norm part removed. See https://github.com/pytorch/pytorch/blob/52f2db752d2b29267da356a06ca91e10cd732dbc/torch/nn/utils/clip_grad.py#L9
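Assuming the removed max_norm part is the gradient rescaling step, what remains is the computation of a single norm over all gradients. The sketch below shows that computation with plain Python lists standing in for gradient tensors. The function name total_grad_norm is hypothetical and the code is not taken from pytorchrl or PyTorch.

```python
import math


def total_grad_norm(grads, norm_type=2.0):
    """Return the norm of all gradient entries taken together.

    grads: list of lists of floats, one inner list per parameter.
    """
    flat = [abs(g) for grad in grads for g in grad]
    if not flat:
        return 0.0
    if norm_type == math.inf:
        # Infinity norm: largest absolute entry across all parameters.
        return max(flat)
    return sum(g ** norm_type for g in flat) ** (1.0 / norm_type)


norm = total_grad_norm([[3.0], [4.0]])
inf_norm = total_grad_norm([[3.0, -7.0], [4.0]], norm_type=math.inf)
```

A norm computed this way can be logged as a training diagnostic without altering the gradients themselves.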
Module contents
- class pytorchrl.DataTransition(Observation, RecurrentHiddenStates, Done, Action, Reward, NextObservation, NextRecurrentHiddenStates, NextDone, EnvironmentInformation)
Bases: tuple
- property Action
Alias for field number 3
- property Done
Alias for field number 2
- property EnvironmentInformation
Alias for field number 8
- property NextDone
Alias for field number 7
- property NextObservation
Alias for field number 5
- property NextRecurrentHiddenStates
Alias for field number 6
- property Observation
Alias for field number 0
- property RecurrentHiddenStates
Alias for field number 1
- property Reward
Alias for field number 4
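The field numbers listed above match the positional order in the class signature, so each field can be read either by name or by index. The sketch below rebuilds an equivalent named tuple with collections.namedtuple to show this. It is a stand-in for illustration, not an import of pytorchrl.DataTransition, and the values are hypothetical placeholders.

```python
from collections import namedtuple

DataTransition = namedtuple(
    "DataTransition",
    (
        "Observation", "RecurrentHiddenStates", "Done", "Action", "Reward",
        "NextObservation", "NextRecurrentHiddenStates", "NextDone",
        "EnvironmentInformation",
    ),
)

# Placeholder values, hypothetical and chosen only to show field access.
t = DataTransition(
    Observation=[0.0], RecurrentHiddenStates=None, Done=False, Action=1,
    Reward=0.5, NextObservation=[0.1], NextRecurrentHiddenStates=None,
    NextDone=True, EnvironmentInformation={},
)
reward_by_name = t.Reward   # access by field name
reward_by_index = t[4]      # same field, by its documented index
```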