Source code for pytorchrl.agent.algorithms.utils

import itertools
import torch


def get_gradients(*nets, grads_to_cpu=False):
    """
    Collect the gradients of every parameter of the given networks.

    Parameters
    ----------
    nets :
        Objects exposing a ``parameters()`` method; their parameters are
        visited in the order the networks are passed.
    grads_to_cpu : bool
        If True, each gradient is moved to CPU and converted to a numpy
        array, and parameters without a gradient contribute ``None`` so the
        result has one entry per parameter. If False, the gradient tensors
        are returned as they are and parameters without a gradient are
        skipped.

    Returns
    -------
    grads : list
        The collected gradients.
    """
    all_params = itertools.chain.from_iterable(
        net.parameters() for net in nets)

    if grads_to_cpu:
        # Keep a None placeholder for missing gradients (one entry per parameter).
        return [
            None if param.grad is None else param.grad.data.cpu().numpy()
            for param in all_params
        ]

    # Tensor mode: parameters without a gradient are left out entirely.
    return [param.grad for param in all_params if param.grad is not None]
def set_gradients(*nets, gradients, device):
    """
    Set the given gradients as the gradient values of all parameters in nets.

    Parameters
    ----------
    nets :
        Objects exposing a ``parameters()`` method; their parameters are
        paired, in order, with the entries of ``gradients``.
    gradients : iterable
        Numpy arrays (or ``None``) to assign, one per parameter. ``None``
        entries leave the corresponding parameter's gradient untouched.
    device :
        Device the gradient tensors are moved to before being assigned.
    """
    all_params = itertools.chain.from_iterable(
        net.parameters() for net in nets)

    for grad_array, param in zip(gradients, all_params):
        if grad_array is None:
            continue
        param.grad = torch.from_numpy(grad_array).to(device)
###### KL DIVERGENCE ###################################################################################################
def bt(m):
    """Batch transpose: swap the last two dimensions of ``m``."""
    swapped = m.transpose(dim0=-2, dim1=-1)
    return swapped
def btr(m):
    """Batch trace: sum the diagonal taken over the last two dimensions of ``m``."""
    diag = m.diagonal(dim1=-2, dim2=-1)
    return diag.sum(-1)
def gaussian_kl(mu1, mu2, cov1, cov2):
    """
    Decoupled KL between two multivariate gaussian distributions.

    C_μ = KL(f(x|μ1,Σ1)||f(x|μ2,Σ1))
    C_Σ = KL(f(x|μ1,Σ1)||f(x|μ1,Σ2))

    Adapted from https://github.com/daisatojp/mpo/blob/master/mpo/mpo.py

    Parameters
    ----------
    mu1: torch.tensor
        Mean distribution 1 - (B, n).
    mu2: torch.tensor
        Mean distribution 2 - (B, n).
    cov1: torch.tensor
        Factor A1 of the covariance of distribution 1 - (B, n, n).
        The covariance actually used is Σ1 = A1 @ A1ᵀ.
    cov2: torch.tensor
        Factor A2 of the covariance of distribution 2 - (B, n, n).
        The covariance actually used is Σ2 = A2 @ A2ᵀ.

    Returns
    -------
    kl_mu: torch.tensor
        Mean term of the KL, averaged over the batch (0-dim tensor).
    kl_sigma: torch.tensor
        Covariance term of the KL, averaged over the batch (0-dim tensor).

    ref : https://stanford.edu/~jduchi/projects/general_notes.pdf page.13
    """
    n = cov2.size(-1)
    mu_diff = mu2.unsqueeze(-1) - mu1.unsqueeze(-1)  # (B, n, 1)

    sigma1 = cov1 @ cov1.transpose(-2, -1)  # (B, n, n)
    sigma2 = cov2 @ cov2.transpose(-2, -1)  # (B, n, n)
    sigma1_inv = sigma1.inverse()  # (B, n, n)
    sigma2_inv = sigma2.inverse()  # (B, n, n)

    # determinant can be minus due to numerical calculation error
    # https://github.com/daisatojp/mpo/issues/11
    sigma1_det = torch.clamp_min(sigma1.det(), 1e-6)  # (B,)
    sigma2_det = torch.clamp_min(sigma2.det(), 1e-6)  # (B,)

    # Mean term: (μ2 - μ1)ᵀ Σ1⁻¹ (μ2 - μ1). Squeeze only the two trailing
    # singleton dims so a batch of size 1 keeps its batch dimension.
    inner_mu = (mu_diff.transpose(-2, -1) @ sigma1_inv @ mu_diff)
    inner_mu = inner_mu.squeeze(-1).squeeze(-1)  # (B,)

    # Covariance term: log(|Σ2| / |Σ1|) - n + tr(Σ2⁻¹ Σ1).
    # It is zero when Σ1 == Σ2, as a KL divergence must be.
    trace_term = (sigma2_inv @ sigma1).diagonal(dim1=-2, dim2=-1).sum(-1)
    inner_sigma = torch.log(sigma2_det / sigma1_det) - n + trace_term  # (B,)

    kl_mu = 0.5 * torch.mean(inner_mu)
    kl_sigma = 0.5 * torch.mean(inner_sigma)
    return kl_mu, kl_sigma