Source code for pytorchrl.agent.actors.distributions.categorical

import numpy as np
import torch
import torch.nn as nn
from pytorchrl.agent.actors.utils import init


[docs]class Categorical(nn.Module): """ Categorical probability distribution. Parameters ---------- num_inputs : int Size of input feature maps. num_outputs : int Number of options in output space. """ def __init__(self, num_inputs, num_outputs): super(Categorical, self).__init__() init_ = lambda m: init( m, nn.init.orthogonal_, lambda x: nn.init.constant_(x, 0), gain=np.sqrt(0.01)) self.linear = init_(nn.Linear(num_inputs, num_outputs))
[docs] def forward(self, x, deterministic=False): """ Predict distribution parameters from x (obs features) and return predictions (sampled and clipped), sampled log probability and distribution entropy. Parameters ---------- x : torch.tensor Feature maps extracted from environment observations. deterministic : bool Whether to randomly sample from predicted distribution or take the mode. Returns ------- pred: torch.tensor Predicted value. clipped_pred: torch.tensor Predicted value (clipped to be within [-1, 1] range). logp : torch.tensor Log probability of `pred` according to the predicted distribution. entropy_dist : torch.tensor Entropy of the predicted distribution. dist : torch.Distribution Action probability distribution. """ # Predict distribution parameters x = self.linear(x) # Create distribution and sample dist = torch.distributions.Categorical(logits=x) if deterministic: pred = clipped_pred = dist.probs.argmax(dim=-1, keepdim=True) else: pred = clipped_pred = dist.sample().unsqueeze(-1) # Action log probability # logp = dist.log_prob(pred.squeeze( -1)).unsqueeze(-1) logp = dist.log_prob(pred.squeeze(-1)).view(pred.size(0), -1).sum(-1).unsqueeze(-1) # Distribution entropy entropy_dist = dist.entropy().mean() return pred, clipped_pred, logp, entropy_dist, dist
[docs] def evaluate_pred(self, x, pred): """ Return log prob of `pred` under the distribution generated from x (obs features). Also return entropy of the generated distribution. Parameters ---------- x : torch.tensor obs feature map obtained from a policy_net. pred : torch.tensor Prediction to evaluate. Returns ------- logp : torch.tensor Log probability of `pred` according to the predicted distribution. entropy_dist : torch.tensor Entropy of the predicted distribution. dist : torch.Distribution Action probability distribution. """ # Predict distribution parameters x = self.linear(x) # Create distribution dist = torch.distributions.Categorical(logits=x) # Evaluate log prob of under dist logp = dist.log_prob(pred.squeeze(-1)).unsqueeze(-1).sum(-1, keepdim=True) # Distribution entropy entropy_dist = dist.entropy().mean() return logp, entropy_dist, dist