import os
import yaml
import torch
import shutil
import argparse
from torch._six import inf
[docs]def cleanup_log_dir(log_dir):
"""
Create log directory and remove old files.
Parameters
----------
log_dir : str
Path to log directory.
"""
try:
shutil.rmtree(os.path.join(log_dir))
except Exception:
print("Unable to cleanup log_dir...")
os.makedirs(log_dir, exist_ok=True)
[docs]class LoadFromFile(argparse.Action):
# parser.add_argument('--file', type=open, action=LoadFromFile)
def __call__(self, parser, namespace, values, option_string=None):
if values.name.endswith('yaml') or values.name.endswith('yml'):
with values as f:
config = yaml.load(f, Loader=yaml.FullLoader)
for key in config.keys():
if key not in namespace:
raise ValueError(f'Unknown argument in config file: {key}')
namespace.__dict__.update(config)
else:
raise ValueError('Configuration file must end with yaml or yml')
[docs]def save_argparse(args, filename, exclude=None):
if filename.endswith('yaml') or filename.endswith('yml'):
if isinstance(exclude, str):
exclude = [exclude, ]
args = args.__dict__.copy()
for exl in exclude:
del args[exl]
yaml.dump(args, open(filename, 'w'))
else:
raise ValueError("Configuration file should end with yaml or yml")
[docs]class RunningMeanStd:
"""Class to keep track on the running mean and variance of tensors batches."""
def __init__(self, epsilon=1e-4, shape=(), device=torch.device("cpu")):
self.mean = torch.zeros(shape, dtype=torch.float64).to(device)
self.var = torch.ones(shape, dtype=torch.float64).to(device)
self.count = epsilon
[docs] def update(self, x):
batch_mean = torch.mean(x, dim=0)
batch_var = torch.var(x, dim=0)
batch_count = x.shape[0]
self.update_from_moments(batch_mean, batch_var, batch_count)
[docs] def update_from_moments(self, batch_mean, batch_var, batch_count):
delta = batch_mean - self.mean
tot_count = self.count + batch_count
new_mean = self.mean + delta * batch_count / tot_count
m_a = self.var * self.count
m_b = batch_var * batch_count
M2 = m_a + m_b + torch.square(delta) * self.count * batch_count / tot_count
new_var = M2 / tot_count
new_count = tot_count
self.mean, self.var, self.count = new_mean, new_var, new_count
[docs]def clip_grad_norm_(parameters, norm_type: float = 2.0):
"""
This is the official clip_grad_norm implemented in pytorch but the max_norm part has been removed.
https://github.com/pytorch/pytorch/blob/52f2db752d2b29267da356a06ca91e10cd732dbc/torch/nn/utils/clip_grad.py#L9
"""
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = [p for p in parameters if p.grad is not None]
norm_type = float(norm_type)
if len(parameters) == 0:
return torch.tensor(0.)
device = parameters[0].grad.device
if norm_type == inf:
total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
else:
total_norm = torch.norm(
torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
return total_norm