import numpy as np
import gym
from gym import spaces
from pytorchrl.envs.obstacle_tower.utils import (
box_is_placed, box_location, place_location, reduced_action_lookup_6,
reduced_action_lookup_7, reduced_action_lookup_8)
class BasicObstacleEnv(gym.Wrapper):
    """Training wrapper that seeds and floor-limits Obstacle Tower episodes.

    On every reset a seed is drawn (from ``seed_list`` if it is non-empty,
    otherwise uniformly from ``[0, 100)``). The wrapped environment is moved
    to ``min_floor`` and reset with ``config={"total-floors": max_floor + 1}``.
    An episode is ended (``done=True``) as soon as ``info["current_floor"]``
    exceeds ``max_floor``.

    Each ``step`` adds ``seed``, ``start`` and ``floor`` entries to ``info``.
    It also updates ``picked_key``, which is True on a step where
    ``info["total_keys"]`` increased.

    Args:
        env: Environment to wrap. ``env.unwrapped`` must provide ``seed()``
            and ``floor()``, and ``env.reset`` must accept a ``config``
            keyword.
        min_floor (int): Floor every episode starts on.
        max_floor (int): Highest floor allowed before the episode is ended.
        seed_list (sequence of int, optional): Seeds to sample from. When
            ``None`` or empty, a random seed in ``[0, 100)`` is used.
    """

    def __init__(self, env, min_floor, max_floor, seed_list=None):
        gym.Wrapper.__init__(self, env)
        self.reached_floor = 0
        # A ``None`` default avoids sharing one mutable list between every
        # instance created without an explicit seed list.
        self._seed_list = [] if seed_list is None else seed_list
        self._min_floor = min_floor
        self._max_floor = max_floor
        self.count = 0  # number of resets performed so far
        # NOTE(review): presumably the initial time budget; it is overwritten
        # with info['time_remaining'] on every step.
        self.last_time = 3000
        self.force_seed = None  # NOTE(review): not read anywhere in this class
        self.start_floor = min_floor
        # Per-episode state, also re-initialised in ``reset``. Setting it here
        # keeps the attributes defined even before the first reset.
        self._previous_keys = 0
        self.picked_key = False
        self.puzzle_solved = False

    def step(self, action):
        """Step the wrapped env, track floor/key progress and extend ``info``.

        Returns:
            tuple: ``(obs, reward, done, info)`` from the wrapped env.
            ``done`` is forced to True once ``info["current_floor"]`` exceeds
            ``max_floor``, and ``info`` gains ``seed``, ``start`` (float) and
            ``floor`` (highest floor reached this episode, float).
        """
        obs, reward, done, info = self.env.step(action)
        current_floor = info['current_floor']
        if current_floor > self.reached_floor:
            self.reached_floor = current_floor
        if current_floor > self._max_floor:
            done = True
        # NOTE(review): ``self.seed`` is an instance attribute set in
        # ``reset`` that shadows ``gym.Wrapper.seed``. Before the first reset
        # it is still the inherited bound method.
        info['seed'] = self.seed
        info['start'] = float(self.start_floor)
        info['floor'] = float(self.reached_floor)
        num_keys = info["total_keys"]
        self.picked_key = num_keys > self._previous_keys
        self._previous_keys = num_keys
        self.last_time = info['time_remaining']
        return obs, reward, done, info

    def reset(self, **kwargs):
        """Choose a seed, move to ``min_floor`` and reset the wrapped env.

        Keyword arguments are forwarded to the wrapped env's ``reset``,
        together with ``config={"total-floors": max_floor + 1}``.

        Returns:
            Whatever the wrapped env's ``reset`` returns.
        """
        self._previous_keys = 0
        self.puzzle_solved = False
        # ``len`` rather than truthiness so numpy arrays are accepted too.
        if len(self._seed_list) > 0:
            self.seed = np.random.choice(self._seed_list)
        else:
            self.seed = np.random.randint(0, 100)
        self.env.unwrapped.seed(self.seed)
        # Record the floor the episode really starts on, so that
        # info['start'] is correct when min_floor is not 0.
        self.start_floor = self._min_floor
        self.env.unwrapped.floor(self.start_floor)
        self.reached_floor = 0
        config = {"total-floors": self._max_floor + 1}
        self.count += 1
        return self.env.reset(config=config, **kwargs)
class BasicObstacleEnvTest(gym.Wrapper):
    """Evaluation wrapper that cycles deterministically through fixed seeds.

    Every reset moves on to the next seed in ``seed_list``, wrapping around
    after the last one. It starts the tower at ``min_floor`` and resets with
    ``config={"total-floors": max_floor + 2}``. This wrapper never forces
    ``done``.

    On each reset after the first, the floor reached on the previous seed is
    printed. After the last seed of a pass, the mean over that pass is printed
    and the history is cleared.

    Args:
        env: Environment to wrap. ``env.unwrapped`` must provide ``seed()``
            and ``floor()``, and ``env.reset`` must accept a ``config``
            keyword.
        min_floor (int): Floor every episode starts on.
        max_floor (int): Used to size the tower (``max_floor + 2`` floors).
        seed_list (sequence of int, optional): Seeds to cycle through.
            Defaults to ``[1001, 1002, 1003, 1004, 1005]``.

    Raises:
        ValueError: If ``seed_list`` is empty.
    """

    def __init__(self, env, min_floor, max_floor, seed_list=None):
        gym.Wrapper.__init__(self, env)
        # A fresh default list per instance avoids a shared mutable default.
        if seed_list is None:
            seed_list = [1001, 1002, 1003, 1004, 1005]
        # An empty list would otherwise only fail later, in ``reset``, with
        # a ZeroDivisionError from the modulo.
        if len(seed_list) == 0:
            raise ValueError("seed_list must contain at least one seed")
        self.reached_floor = None  # None until the first reset
        self._min_floor = min_floor
        self._max_floor = max_floor
        self.last_time = 3000  # overwritten with info['time_remaining'] each step
        self.seed_index = -1  # advanced to 0 by the first reset
        self.seed_list = seed_list
        self.reached_floors = []  # floors reached in the current pass
        # Per-episode state, also set in ``reset``/``step``. ``self.seed`` is
        # intentionally not set here so the inherited ``seed()`` method stays
        # usable until the first reset.
        self.start_floor = min_floor
        self._previous_keys = 0
        self.picked_key = False

    def step(self, action):
        """Step the wrapped env, track floor/key progress and extend ``info``.

        Returns:
            tuple: ``(obs, reward, done, info)`` from the wrapped env, with
            ``info`` extended by ``seed``, ``start`` (float) and ``floor``
            (highest floor reached this episode, float).
        """
        obs, reward, done, info = self.env.step(action)
        if info['current_floor'] > self.reached_floor:
            self.reached_floor = info['current_floor']
        info['seed'] = self.seed
        info['start'] = float(self.start_floor)
        info['floor'] = float(self.reached_floor)
        num_keys = info["total_keys"]
        self.picked_key = num_keys > self._previous_keys
        self._previous_keys = num_keys
        self.last_time = info['time_remaining']
        return obs, reward, done, info

    def reset(self, **kwargs):
        """Report the finished episode, then reset on the next seed.

        Keyword arguments are forwarded to the wrapped env's ``reset``,
        together with ``config={"total-floors": max_floor + 2}``.

        Returns:
            Whatever the wrapped env's ``reset`` returns.
        """
        # reached_floor is None only before the very first episode.
        if self.reached_floor is not None:
            self.reached_floors.append(self.reached_floor)
            print("Seed {}, reached floor {}".format(self.seed, self.reached_floor))
            # The episode that just ended used the last seed, so one full
            # pass over seed_list is complete.
            if self.seed_index == len(self.seed_list) - 1:
                print("Average of all seeds {}".format(np.mean(self.reached_floors)))
                self.reached_floors = []
        self._previous_keys = 0
        self.seed_index = (self.seed_index + 1) % len(self.seed_list)
        self.seed = self.seed_list[self.seed_index]
        self.env.unwrapped.seed(self.seed)
        self.start_floor = self._min_floor
        self.env.unwrapped.floor(self.start_floor)
        self.reached_floor = 0
        config = {"total-floors": self._max_floor + 2}
        return self.env.reset(config=config, **kwargs)
class RewardShapeObstacleEnv(gym.Wrapper):
    """Add shaping bonuses to the Obstacle Tower reward.

    Per step, on top of the wrapped env's reward:
      * +1.01 when the wrapped env reports ``picked_key``;
      * +0.002 when ``box_location(obs)`` finds a box;
      * +0.001 when ``place_location(obs)`` finds a placement spot;
      * +1.5 the first time per episode ``box_is_placed(obs)`` holds while
        the reward accumulated so far this step exceeds 0.08;
      * +0.002 when ``info["time_remaining"]`` exceeds the time remaining
        recorded at the episode's first step.

    Args:
        env: Environment to wrap. It must expose a ``picked_key`` attribute
            updated by ``step``. NOTE(review): presumably BasicObstacleEnv
            or BasicObstacleEnvTest; confirm at the call site.
        killed_reward (float): Stored as ``_killed_reward``.
            NOTE(review): not used anywhere in this class.
    """

    def __init__(self, env, killed_reward=2):
        gym.Wrapper.__init__(self, env)
        self._killed_reward = killed_reward
        self.time_remaining = None  # time at the episode's first step
        # Also reset in ``reset``. Initialising it here keeps ``step`` usable
        # even if it is called before the first reset.
        self.puzzle_solved = False

    def step(self, action):
        """Step the wrapped env and return its transition with a shaped reward.

        Returns:
            tuple: ``(obs, reward, done, info)``. ``reward`` includes the
            shaping bonuses described in the class docstring.
        """
        obs, reward, done, info = self.env.step(action)
        # Picked key
        if self.env.picked_key:
            reward += 1.01
        found_box, _ = box_location(obs)
        if found_box:
            reward += 0.002
        found_place, _ = place_location(obs)
        if found_place:
            reward += 0.001
        # Account for solving the puzzle, at most once per episode.
        if box_is_placed(obs) and (not self.puzzle_solved) and (reward > 0.08):
            self.puzzle_solved = True
            reward += 1.5
        if self.time_remaining is None:
            self.time_remaining = info['time_remaining']
        # NOTE(review): the baseline is never updated after the first step.
        # The bonus is therefore paid on every step where the time exceeds the
        # initial value, not only on steps where the time increased. Confirm
        # this is intended.
        if info['time_remaining'] > self.time_remaining:
            reward += 0.002
        return obs, reward, done, info

    def reset(self, **kwargs):
        """Clear per-episode shaping state and reset the wrapped env."""
        self.puzzle_solved = False
        self.time_remaining = None
        return self.env.reset(**kwargs)
class ReducedActionEnv(gym.Wrapper):
    """Restrict the Obstacle Tower action set to a smaller lookup table.

    Replaces the flattener's ``action_lookup`` on ``env.unwrapped`` with one
    of the ``reduced_action_lookup_{6,7,8}`` tables. It also sets the
    flattener's and the env's action space to a ``spaces.Discrete`` of that
    table's length, and then wraps ``env``.

    Args:
        env: Environment whose ``unwrapped`` exposes ``_flattener`` (with
            ``action_lookup`` and ``action_space``) and ``_action_space``.
        num_actions (int): Size of the reduced table to use: 6, 7 or 8.

    Raises:
        ValueError: If there is no lookup table for ``num_actions``.
    """

    def __init__(self, env, num_actions=8):
        lookups = {
            6: reduced_action_lookup_6,
            7: reduced_action_lookup_7,
            8: reduced_action_lookup_8,
        }
        _action_lookup = lookups.get(num_actions)
        if _action_lookup is None:
            # Raise (not merely construct) the error, so an unsupported size
            # fails here rather than later with an UnboundLocalError.
            raise ValueError("No lookup table for num reduced actions {}".format(
                num_actions))
        flattener = env.unwrapped._flattener
        flattener.action_lookup = _action_lookup
        n_actions = len(flattener.action_lookup)
        flattener.action_space = spaces.Discrete(n_actions)
        env.unwrapped._action_space = flattener.action_space
        # NOTE(review): the env is patched before gym.Wrapper.__init__,
        # presumably so that the wrapper picks up the reduced action space.
        gym.Wrapper.__init__(self, env)