# Source code for pytorchrl.agent.algorithms.policy_loss_addons.base
from abc import ABC, abstractmethod
import torch
class PolicyLossAddOn(ABC):
    """Abstract interface for extra terms added on top of the policy loss.

    Concrete add-ons subclass this and implement both ``setup`` and
    ``compute_loss_term``. Both methods are declared with
    ``@abstractmethod``, so the class cannot be instantiated until a
    subclass overrides them.
    """

    @abstractmethod
    def setup(self, actor, device):
        """Prepare the add-on before it is used.

        Args:
            actor: The actor the add-on works with. NOTE(review): the
                expected type is not visible here; presumably the
                algorithm's actor module. Confirm against callers.
            device: Device the add-on should use. Presumably a torch device
                or device string; confirm.

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        raise NotImplementedError()

    @abstractmethod
    def compute_loss_term(self, batch, dist_entropy=None):
        """Compute the add-on's term for the policy loss.

        Args:
            batch: Data batch the term is computed from. Its structure is
                not defined here; TODO confirm the expected schema.
            dist_entropy: Optional value, presumably the entropy of the
                policy distribution. Defaults to ``None``.

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        raise NotImplementedError()