pytorchrl package

Submodules

pytorchrl.learner module

class pytorchrl.learner.Learner(scheme, target_steps, log_dir=None)

Bases: object

Task learner class. Manages the training process: it advances training by calling the update workers and tracks progress.

Parameters
  • scheme (Scheme) – Training scheme class instance, handling coordination of workers.

  • target_steps (int) – Number of environment steps to reach to complete training.

  • log_dir (str) – Target directory for model checkpoints and logs.

done()

Return True if training has finished (target_steps reached).

Returns

flag – True if training has reached the target number of steps.

Return type

bool

get_metrics(add_algo_metrics=True, add_episodes_metrics=False, add_scheme_metrics=False, add_time_metrics=False)

Return the current values of the tracked metrics.

print_info(add_algo_info=True, add_episodes_info=True, add_scheme_info=False, add_time_info=False)

Print relevant information about the training process.

save_model(fname='model.state_dict')

Save the current version of the learned actor_critic.

Returns

save_name – Path to the saved file.

Return type

str

step()

Take one logical, synchronous optimization step.

stop()

Stop all threads.

update_algorithm_parameter(parameter_name, new_parameter_value)

If parameter_name is an attribute of the algorithm used for training, change its value to new_parameter_value.

Parameters
  • parameter_name (str) – Name of a Worker.algo attribute.

  • new_parameter_value (int or float) – New value for parameter_name.
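
The behaviour described above amounts to a guarded setattr. The sketch below is an assumption about the semantics, not the library's code: it silently ignores names that are not attributes (the documentation does not say whether the real method warns or raises). ToyAlgo, its lr and gamma attributes, the boolean return value and the linear_decay schedule are all hypothetical; the loop shows a common use, annealing a hyperparameter between training steps.

```python
from dataclasses import dataclass


@dataclass
class ToyAlgo:
    """Hypothetical algorithm object with tunable hyperparameters."""
    lr: float = 3e-4
    gamma: float = 0.99


def update_algorithm_parameter(algo, parameter_name, new_parameter_value):
    """Set algo.<parameter_name> only if that attribute already exists."""
    if hasattr(algo, parameter_name):
        setattr(algo, parameter_name, new_parameter_value)
        return True
    return False  # unknown names are ignored rather than created


def linear_decay(initial, step, total_steps):
    """Hypothetical schedule: decay linearly from `initial` to 0."""
    return initial * max(0.0, 1.0 - step / total_steps)


algo = ToyAlgo()
for step in range(0, 1000, 250):
    # In a real run, a learner.step() call would sit between these updates.
    update_algorithm_parameter(algo, "lr", linear_decay(3e-4, step, 1000))
```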

pytorchrl.utils module

class pytorchrl.utils.LoadFromFile(option_strings, dest, nargs=None, const=None, default=None, type=None, choices=None, required=False, help=None, metavar=None)

Bases: argparse.Action
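
Because LoadFromFile subclasses argparse.Action, its __call__ runs when its option appears on the command line. The sketch below assumes it reads option values from a file and merges them into the parsed namespace; the JSON format, the --conf flag and the JsonLoadFromFile name are assumptions for illustration, and the real class may use a different file format.

```python
import argparse
import json


class JsonLoadFromFile(argparse.Action):
    """Hypothetical Action: load option values from a JSON file into the namespace."""

    def __call__(self, parser, namespace, values, option_string=None):
        # `values` is the path given after the flag (no `type=` was set).
        with open(values, "r", encoding="utf-8") as f:
            loaded = json.load(f)
        for key, value in loaded.items():
            setattr(namespace, key, value)


def build_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--conf", action=JsonLoadFromFile, help="JSON config file")
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--num-steps", type=int, default=1000)
    return parser
```

argparse applies actions in command-line order, so `--conf cfg.json --lr 0.5` lets the explicit flag override the file, while `--lr 0.5 --conf cfg.json` lets the file win. Keys in the file must use the dest names (num_steps, not num-steps), and values loaded this way bypass the `type=` conversion.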

class pytorchrl.utils.RunningMeanStd(epsilon=0.0001, shape=(), device=device(type='cpu'))

Bases: object

Class to keep track of the running mean and variance of batches of tensors.

update(x)

update_from_moments(batch_mean, batch_var, batch_count)
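
The method names appear to follow the widely used RunningMeanStd from OpenAI Baselines: update(x) computes a batch's mean, variance and count, and update_from_moments merges them into the running estimate with the parallel-variance formula of Chan et al. The scalar, pure-Python sketch below assumes that formulation, including the initial mean 0, variance 1 and pseudo-count epsilon; the real class operates on torch tensors of the given shape and device, and ScalarRunningMeanStd is a hypothetical name.

```python
import statistics
from typing import Sequence


class ScalarRunningMeanStd:
    """Pure-Python, scalar sketch of a running mean/variance tracker."""

    def __init__(self, epsilon: float = 1e-4):
        self.mean = 0.0
        self.var = 1.0
        self.count = epsilon  # tiny pseudo-count avoids division by zero

    def update(self, x: Sequence[float]) -> None:
        batch_mean = statistics.fmean(x)
        batch_var = statistics.pvariance(x)  # population variance (ddof=0)
        self.update_from_moments(batch_mean, batch_var, len(x))

    def update_from_moments(self, batch_mean: float, batch_var: float,
                            batch_count: int) -> None:
        delta = batch_mean - self.mean
        tot_count = self.count + batch_count
        new_mean = self.mean + delta * batch_count / tot_count
        # Merge the sums of squared deviations of both groups (Chan et al.).
        m2 = (self.var * self.count
              + batch_var * batch_count
              + delta ** 2 * self.count * batch_count / tot_count)
        self.mean = new_mean
        self.var = m2 / tot_count
        self.count = tot_count
```

Merging moments means each batch is touched once and never stored, which is why such trackers are common for observation and reward normalization.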

pytorchrl.utils.cleanup_log_dir(log_dir)

Create log directory and remove old files.

Parameters

log_dir (str) – Path to log directory.
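
The following is a minimal sketch of that behaviour. It assumes "old files" means regular files directly inside log_dir that match a glob pattern; cleanup_log_dir_sketch and its pattern parameter are hypothetical, and the real function may select files differently.

```python
import glob
import os


def cleanup_log_dir_sketch(log_dir: str, pattern: str = "*") -> None:
    """Create log_dir if needed, then delete files in it matching `pattern`."""
    os.makedirs(log_dir, exist_ok=True)
    for path in glob.glob(os.path.join(log_dir, pattern)):
        if os.path.isfile(path):  # leave subdirectories untouched
            os.remove(path)
```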

pytorchrl.utils.clip_grad_norm_(parameters, norm_type: float = 2.0)

This is the official PyTorch clip_grad_norm_ implementation with the max_norm part removed: https://github.com/pytorch/pytorch/blob/52f2db752d2b29267da356a06ca91e10cd732dbc/torch/nn/utils/clip_grad.py#L9
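
With the max_norm part removed, the function presumably only computes the total gradient norm that torch.nn.utils.clip_grad_norm_ compares against max_norm, without rescaling any gradient. That norm is the p-norm of the per-parameter p-norms (equivalently, of all gradient entries flattened together), or the largest absolute entry when norm_type is infinity. The stdlib-only sketch below uses plain lists of floats in place of each parameter's .grad tensor; total_grad_norm is a hypothetical name.

```python
import math
from typing import Iterable, Sequence


def total_grad_norm(grads: Iterable[Sequence[float]], norm_type: float = 2.0) -> float:
    """Norm of all gradients viewed as one flattened vector (no clipping)."""
    grads = [g for g in grads if len(g) > 0]  # like skipping params with grad None
    if not grads:
        return 0.0
    if math.isinf(norm_type):
        return max(abs(v) for g in grads for v in g)
    # Per-parameter p-norms first, then the p-norm of those norms.
    per_param = [sum(abs(v) ** norm_type for v in g) ** (1.0 / norm_type)
                 for g in grads]
    return sum(n ** norm_type for n in per_param) ** (1.0 / norm_type)
```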

pytorchrl.utils.save_argparse(args, filename, exclude=None)
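
No description is given for save_argparse. Judging only from its signature, it writes the parsed arguments to filename and optionally leaves some of them out. The sketch below assumes JSON output and that exclude may be a single name or a list of names; save_argparse_json is a hypothetical name, and the actual file format may differ.

```python
import argparse
import json
from typing import Iterable, Optional, Union


def save_argparse_json(args: argparse.Namespace, filename: str,
                       exclude: Optional[Union[str, Iterable[str]]] = None) -> None:
    """Hypothetical sketch: dump vars(args) to JSON, skipping excluded keys."""
    if exclude is None:
        exclude = []
    elif isinstance(exclude, str):
        exclude = [exclude]
    excluded = set(exclude)
    to_save = {k: v for k, v in vars(args).items() if k not in excluded}
    with open(filename, "w", encoding="utf-8") as f:
        json.dump(to_save, f, indent=2, sort_keys=True)
```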

Module contents

class pytorchrl.DataTransition(Observation, RecurrentHiddenStates, Done, Action, Reward, NextObservation, NextRecurrentHiddenStates, NextDone, EnvironmentInformation)

Bases: tuple

property Action

Alias for field number 3

property Done

Alias for field number 2

property EnvironmentInformation

Alias for field number 8

property NextDone

Alias for field number 7

property NextObservation

Alias for field number 5

property NextRecurrentHiddenStates

Alias for field number 6

property Observation

Alias for field number 0

property RecurrentHiddenStates

Alias for field number 1

property Reward

Alias for field number 4
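
Since DataTransition is a namedtuple, each field can be read either by name or by the index listed above. The stand-in below redeclares it with collections.namedtuple and the same field order so the snippet runs without pytorchrl; the placeholder values are hypothetical (in practice they would typically be tensors).

```python
from collections import namedtuple

# Same field order as pytorchrl.DataTransition (indices 0..8).
DataTransition = namedtuple(
    "DataTransition",
    ["Observation", "RecurrentHiddenStates", "Done", "Action", "Reward",
     "NextObservation", "NextRecurrentHiddenStates", "NextDone",
     "EnvironmentInformation"],
)

transition = DataTransition(
    Observation=[0.0, 1.0],
    RecurrentHiddenStates=None,
    Done=False,
    Action=1,
    Reward=0.5,
    NextObservation=[0.1, 0.9],
    NextRecurrentHiddenStates=None,
    NextDone=False,
    EnvironmentInformation={},
)

print(transition.Action == transition[3])  # True: named and index access agree
obs, _, _, action, reward, *_ = transition  # tuple unpacking also works
updated = transition._replace(Reward=1.0)  # tuples are immutable, so copy instead
```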