import os

import cv2
import gym
import numpy as np
import torchvision.transforms as T
from atariari.benchmark.wrapper import AtariARIWrapper

SUBGOAL_ORDER = [8, 6, 1, 0, 2, 7, 2, 0, 1, 6, 8, 9]


class ALEEnvironment:

    def __init__(self, args, device='cpu'):
        if 'cuda' in device:
            os.environ['CUDA_VISIBLE_DEVICES'] = device.split(':')[-1]

        self.ale = AtariARIWrapper(gym.make(
            'MontezumaRevenge-v4',
            frameskip=args.frame_skip,
            render_mode='rgb_array',  # render() returns the current RGB frame as an array
            repeat_action_probability=0.0))

        self.histLen = args.sequence_length
        print('History length: ', self.histLen)

        self.screen_width = args.screen_width
        self.screen_height = args.screen_height

        self.actions = np.arange(self.ale.action_space.n)
        print('Action space: ', self.actions)

        self.mode = "train"
        self.life_lost = False

        # Seed the environment for reproducibility
        self.ale.reset(seed=args.random_seed)

        # Get initial screen
        self.init_screen = self.getScreen()
        print(f'Size of screen is {self.init_screen.shape}')

        # Perform a NOOP action to initialize self.done and self.info
        _, reward, self.done, self.info = self.ale.step(0)

        # Use sub-goals obtained from gaze analysis; each row is a flat [x1, y1, x2, y2] box
        self.goalSet = np.loadtxt('subgoals.txt', dtype=int, delimiter=',')
        self.goalCenterLoc = []
        for goal in self.goalSet:
            goalCenter = [float(goal[0] + goal[2]) / 2, float(goal[1] + goal[3]) / 2]
            self.goalCenterLoc.append(goalCenter)

        self.agentOriginLoc = [42, 33]
        self.agentLastX = 42
        self.agentLastY = 33
        self.reachedGoal = [0, 0, 0, 0]

        # Optional object detector for locating the agent (Panama Joe);
        # expected to be assigned externally if used, otherwise remains None.
        self.panamajoe_detector = None

        self.histState = self.initializeHistState()
        print('History state shape: ', self.histState.shape)

        self.to_tensor = T.ToTensor()

    def initializeHistState(self):
        if self.histLen >= 2:
            histState = np.concatenate((self.getState(), self.getState()), axis=2)
            for _ in range(self.histLen - 2):
                histState = np.concatenate((histState, self.getState()), axis=2)
        else:
            histState = self.getState()
        return histState

    def get_input_shape(self):
        return (self.histLen, self.screen_width, self.screen_height)

    def numActions(self):
        return len(self.actions)

    def resetGoalReach(self):
        self.reachedGoal = [0, 0, 0, 0]

    def act(self, action):
        """Perform action and handle results."""
        lives = self.info['lives']
        _, reward, self.done, self.info = self.ale.step(self.actions[action])
        # Agent location from the AtariARI wrapper labels
        self.agentLastX = self.info['labels']['player_x']
        self.agentLastY = 320 - self.info['labels']['player_y']
        self.life_lost = (lives != self.info['lives'])
        currState = self.getState()
        self.histState = np.concatenate((self.histState[:, :, 1:], currState), axis=2)
        return reward

    def restart(self):
        """Restart environment: reset goals and agent location."""
        self.ale.reset()
        self.life_lost = False
        self.reachedGoal = [0, 0, 0, 0]
        for i in range(19):
            self.act(0)  # wait for initialization
        self.histState = self.initializeHistState()
        self.agentLastX = self.agentOriginLoc[0]
        self.agentLastY = self.agentOriginLoc[1]

    def beginNextLife(self):
        """Begin next life without restarting environment: reset goals and agent location."""
        self.life_lost = False
        self.reachedGoal = [0, 0, 0, 0]
        for i in range(19):
            self.act(0)  # wait for initialization
        self.histState = self.initializeHistState()
        self.agentLastX = self.agentOriginLoc[0]
        self.agentLastY = self.agentOriginLoc[1]

    def detect_agent_with_tracker(self):
        img = self.getScreenOrig()
        preds = self.panamajoe_detector.predict(img)[0]
        if len(preds['boxes']) == 1:
            box = preds['boxes'][0].detach().cpu().numpy()
            x, y, width, height = box[0], box[1], box[2] - box[0], box[3] - box[1]
            self.agentLastX, self.agentLastY = x + width / 2, y + height / 2
        return (self.agentLastX, self.agentLastY)

    def goal_reached(self, goal):
        if self.panamajoe_detector:
            x1, y1, x2, y2 = self.goalSet[goal]
            x_mean, y_mean = self.detect_agent_with_tracker()
            return x1 < x_mean < x2 and y1 < y_mean < y2
        else:
            return self.goalReached(goal)

    def get_goal_direction(self, goal_idx):
        if (goal_idx - 1) == -1:
            lastGoalCenter = self.agentOriginLoc
        else:
            lastGoalCenter = self.goalCenterLoc[SUBGOAL_ORDER[goal_idx - 1]]
        goal_direction = np.array(self.goalSet[SUBGOAL_ORDER[goal_idx]]).mean(axis=0) - lastGoalCenter
        norm = np.linalg.norm(goal_direction)
        if norm == 0:
            return goal_direction
        return goal_direction / norm

    def distanceReward(self, lastGoal, goal):
        """Calculate a shaped reward from the distance between the agent and the next sub-goal."""
        if lastGoal is None:
            lastGoalCenter = self.agentOriginLoc
        else:
            lastGoalCenter = self.goalCenterLoc[lastGoal]
        goalCenter = self.goalCenterLoc[goal]
        agentX, agentY = self.agentLastX, self.agentLastY
        dis = np.sqrt((goalCenter[0] - agentX) ** 2 + (goalCenter[1] - agentY) ** 2)
        disLast = np.sqrt((lastGoalCenter[0] - agentX) ** 2 + (lastGoalCenter[1] - agentY) ** 2)
        disGoals = np.sqrt((goalCenter[0] - lastGoalCenter[0]) ** 2 + (goalCenter[1] - lastGoalCenter[1]) ** 2)
        return 0.001 * (disLast - dis) / disGoals

    def getScreen(self):
        """Get processed screen: grayscale and resized."""
        screen = self.ale.render(mode='rgb_array')
        screen = cv2.cvtColor(screen, cv2.COLOR_RGB2GRAY)
        resized = cv2.resize(screen, (self.screen_width, self.screen_height), interpolation=cv2.INTER_AREA)
        return resized

    def getScreenOrig(self):
        """Get original RGB screen."""
        return self.ale.render(mode='rgb_array')

    def getScreenRGB(self):
        """Get resized RGB screen for finding the agent location."""
        screen = self.ale.render(mode='rgb_array')
        resized = cv2.resize(screen, (self.screen_width, self.screen_height), interpolation=cv2.INTER_AREA)
        return resized

    def getState(self):
        """Get current state, i.e. the current screen, grayscaled, resized, and with a color channel added as network input."""
        screen = self.ale.render(mode='rgb_array')
        screen = cv2.cvtColor(screen, cv2.COLOR_RGB2GRAY)
        resized = cv2.resize(screen, (self.screen_width, self.screen_height), interpolation=cv2.INTER_AREA)
        return np.expand_dims(resized, axis=-1)

    def getStackedState(self):
        return self.to_tensor(self.histState)

    def isTerminal(self):
        if self.mode == 'train':
            return self.done or self.life_lost
        return self.done

    def isGameOver(self):
        return self.done

    def isLifeLost(self):
        return self.life_lost

    def reset(self):
        self.ale.reset()
        self.life_lost = False

    def goalReached(self, goal):
        """Detect a reached sub-goal by comparing the goal region against the initial screen."""
        goalPosition = self.goalSet[goal]  # flat [x1, y1, x2, y2]
        goalScreen = self.init_screen
        stateScreen = self.getState()
        count = 0
        for y in range(goalPosition[0], goalPosition[2]):
            for x in range(goalPosition[1], goalPosition[3]):
                if goalScreen[x][y] != stateScreen[x][y]:
                    count = count + 1
        # 30 is the total number of pixels of the agent
        if float(count) / 30 > 0.3:
            self.reachedGoal[goal] = 1
            return True
        return False

    def trueGoalReached(self, goal):
        """With the AtariARI wrapper enabled, the agent location is updated every step and yields the true location."""
        goalPosition = self.goalSet[goal]
        return goalPosition[0] < self.agentLastX < goalPosition[2] and goalPosition[1] < self.agentLastY < goalPosition[3]

    def goalNotReachedBefore(self, goal):
        return self.reachedGoal[goal] != 1
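

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how this environment might be driven, assuming an `args`
# namespace exposing the fields read in __init__ (frame_skip, sequence_length,
# screen_width, screen_height, random_seed) and a `subgoals.txt` file in the
# working directory. The argparse defaults below are assumptions for the sake
# of the sketch, not values prescribed by the original code.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--frame_skip', type=int, default=4)
    parser.add_argument('--sequence_length', type=int, default=4)
    parser.add_argument('--screen_width', type=int, default=84)
    parser.add_argument('--screen_height', type=int, default=84)
    parser.add_argument('--random_seed', type=int, default=0)
    args = parser.parse_args()

    env = ALEEnvironment(args)
    env.restart()

    total_reward = 0.0
    for _ in range(100):
        # Sample a random action; a learned policy would go here instead.
        action = np.random.randint(env.numActions())
        total_reward += env.act(action)
        if env.isGameOver():
            env.restart()
        elif env.isLifeLost():
            env.beginNextLife()

    print('Accumulated reward over 100 random steps:', total_reward)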