Source code for pytorchrl.agent.env.env_wrappers

import gym
from gym.spaces.box import Box


[docs]class TransposeImagesIfRequired(gym.ObservationWrapper): """ When environment observations are images, this wrapper transposes the axis. It is useful when the images have shape (W,H,C), as they can be transposed "on the fly" to (C,W,H) for PyTorch convolutions to be applied. Parameters ---------- env : gym.Env Original Gym environment, previous to applying the wrapper. op : list New axis ordering. """ def __init__(self, env=None, op=[2, 0, 1]): """Transpose observation space for images""" super(TransposeImagesIfRequired, self).__init__(env) self.op = op if isinstance(self.observation_space, gym.spaces.Box) and \ len(self.observation_space.shape) == 3: obs_shape = self.observation_space.shape self.observation_space = Box( self.observation_space.low[0, 0, 0], self.observation_space.high[0, 0, 0], [obs_shape[self.op[0]], obs_shape[self.op[1]], obs_shape[self.op[2]]], dtype=self.observation_space.dtype) elif isinstance(self.observation_space, gym.spaces.Dict): for k in self.observation_space.spaces: if isinstance(self.observation_space[k], gym.spaces.Box) and \ len(self.observation_space[k].shape) == 3: obs_shape = self.observation_space[k].shape self.observation_space[k] = Box( self.observation_space[k].low[0, 0, 0], self.observation_space[k].high[0, 0, 0], [obs_shape[self.op[0]], obs_shape[self.op[1]], obs_shape[self.op[2]]], dtype=self.observation_space.dtype)
[docs] def observation(self, ob): """Transpose observation""" if isinstance(ob, dict): for k in ob: if len(ob[k].shape) == 3: ob[k] = ob[k].transpose(self.op[0], self.op[1], self.op[2]) else: if len(ob.shape) == 3: ob = ob.transpose(self.op[0], self.op[1], self.op[2]) return ob