Source code for pytorchrl.agent.algorithms.policy_loss_addons.base

from abc import ABC, abstractmethod
import torch


[docs]class PolicyLossAddOn(ABC): """Base class for all add ons to the policy loss."""
[docs] @abstractmethod def setup(self, actor, device): """Initializes the class.""" raise NotImplementedError
[docs] @abstractmethod def compute_loss_term(self, batch, dist_entropy=None): """Calculates addon loss term.""" raise NotImplementedError