import os

import cv2
import gym
import numpy as np
import torchvision.transforms as T

from atariari.benchmark.wrapper import AtariARIWrapper

# Fixed visiting order of the sub-goals, given as indices into the goal set
# loaded from subgoals.txt in ALEEnvironment.__init__.
SUBGOAL_ORDER = [8, 6, 1, 0, 2, 7, 2, 0, 1, 6, 8, 9]

class ALEEnvironment():

    def __init__(self, args, device='cpu'):
        if 'cuda' in device:
            os.environ['CUDA_VISIBLE_DEVICES'] = device.split(':')[-1]

        self.ale = AtariARIWrapper(gym.make(
            'MontezumaRevenge-v4',
            frameskip=args.frame_skip,
            render_mode='rgb_array',  # return the rgb key in step metadata with the current environment RGB frame
            repeat_action_probability=0.0))

        self.histLen = args.sequence_length
        print('History length: ', self.histLen)

        self.screen_width = args.screen_width
        self.screen_height = args.screen_height

        self.actions = np.arange(self.ale.action_space.n)
        print('Action space: ', self.actions)

        self.mode = "train"
        self.life_lost = False

        # Seed the environment for reproducibility
        self.ale.reset(seed=args.random_seed)

        # Get the initial screen
        self.init_screen = self.getScreen()
        print(f'Size of screen is {self.init_screen.shape}')

        # Perform a NOOP action to initialize self.done and self.info
        _, reward, self.done, self.info = self.ale.step(0)

        # Sub-goals from gaze analysis; each row is a bounding box [x1, y1, x2, y2]
        self.goalSet = np.loadtxt('subgoals.txt', dtype=int, delimiter=',')
        self.goalCenterLoc = []

        for goal in self.goalSet:
            goalCenter = [float(goal[0] + goal[2]) / 2, float(goal[1] + goal[3]) / 2]
            self.goalCenterLoc.append(goalCenter)

        self.agentOriginLoc = [42, 33]
        self.agentLastX = 42
        self.agentLastY = 33
        self.reachedGoal = [0, 0, 0, 0]
        self.histState = self.initializeHistState()
        print('History state shape: ', self.histState.shape)

        self.to_tensor = T.ToTensor()

    def initializeHistState(self):
        'Initialize the frame history by stacking histLen copies of the current state'
        if self.histLen >= 2:
            histState = np.concatenate((self.getState(), self.getState()), axis=2)
            for _ in range(self.histLen - 2):
                histState = np.concatenate((histState, self.getState()), axis=2)
        else:
            histState = self.getState()

        return histState

    def get_input_shape(self):
        # Channels-first shape of the stacked state returned by getStackedState()
        return (self.histLen, self.screen_height, self.screen_width)

    def numActions(self):
        return len(self.actions)

    def resetGoalReach(self):
        self.reachedGoal = [0, 0, 0, 0]

    def act(self, action):
        'Perform action and handle results'
        lives = self.info['lives']

        _, reward, self.done, self.info = self.ale.step(self.actions[action])

        # Agent location from the AtariARI wrapper (y is flipped into screen coordinates)
        self.agentLastX, self.agentLastY = self.info['labels']['player_x'], 320 - self.info['labels']['player_y']

        self.life_lost = (lives != self.info['lives'])
        # Slide the frame history window: drop the oldest frame, append the newest
        currState = self.getState()
        self.histState = np.concatenate((self.histState[:, :, 1:], currState), axis=2)
        return reward

    def restart(self):
        'Restart environment: set goals and agent location'
        self.ale.reset()
        self.life_lost = False
        self.reachedGoal = [0, 0, 0, 0]
        for i in range(19):
            self.act(0)  # wait for initialization
        self.histState = self.initializeHistState()
        self.agentLastX = self.agentOriginLoc[0]
        self.agentLastY = self.agentOriginLoc[1]

    def beginNextLife(self):
        'Begin next life without restarting environment: set goals and agent location'
        self.life_lost = False
        self.reachedGoal = [0, 0, 0, 0]
        for i in range(19):
            self.act(0)  # wait for initialization
        self.histState = self.initializeHistState()
        self.agentLastX = self.agentOriginLoc[0]
        self.agentLastY = self.agentOriginLoc[1]

    def detect_agent_with_tracker(self):
        # Assumes an object detector for the agent (Panama Joe) has been attached
        # externally as self.panamajoe_detector; it is not created in __init__.
        img = self.getScreenOrig()
        preds = self.panamajoe_detector.predict(img)[0]
        if len(preds['boxes']) == 1:
            box = preds['boxes'][0].detach().cpu().numpy()
            x, y, width, height = box[0], box[1], box[2] - box[0], box[3] - box[1]
            self.agentLastX, self.agentLastY = x + width / 2, y + height / 2
        return (self.agentLastX, self.agentLastY)

    def goal_reached(self, goal):
        # Use the external detector when one is attached, otherwise fall back to
        # the pixel-difference check in goalReached().
        if getattr(self, 'panamajoe_detector', None):
            x1, y1, x2, y2 = self.goalSet[goal]
            x_mean, y_mean = self.detect_agent_with_tracker()
            return x1 < x_mean < x2 and y1 < y_mean < y2
        else:
            return self.goalReached(goal)

    def get_goal_direction(self, goal_idx):
        'Unit vector pointing from the previous sub-goal center (or the agent origin) to the next sub-goal center'
        if goal_idx == 0:
            lastGoalCenter = self.agentOriginLoc
        else:
            lastGoalCenter = self.goalCenterLoc[SUBGOAL_ORDER[goal_idx - 1]]

        # Center of the next sub-goal box, precomputed in __init__
        goal_direction = np.array(self.goalCenterLoc[SUBGOAL_ORDER[goal_idx]]) - lastGoalCenter
        norm = np.linalg.norm(goal_direction)

        if norm == 0:
            return goal_direction

        return goal_direction / norm

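    # Illustration of get_goal_direction with made-up coordinates (not taken from
    # subgoals.txt): if the previous sub-goal center is (10, 10) and the next one is
    # (40, 50), the difference is (30, 40) with norm 50, so the returned unit
    # direction is (0.6, 0.8).
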
    def distanceReward(self, lastGoal, goal):
        'Calculate distance between agent and next sub-goal'
        if lastGoal is None:
            lastGoalCenter = self.agentOriginLoc
        else:
            lastGoalCenter = self.goalCenterLoc[lastGoal]
        goalCenter = self.goalCenterLoc[goal]
        agentX, agentY = self.agentLastX, self.agentLastY
        dis = np.sqrt((goalCenter[0] - agentX) ** 2 + (goalCenter[1] - agentY) ** 2)
        disLast = np.sqrt((lastGoalCenter[0] - agentX) ** 2 + (lastGoalCenter[1] - agentY) ** 2)
        disGoals = np.sqrt((goalCenter[0] - lastGoalCenter[0]) ** 2 + (goalCenter[1] - lastGoalCenter[1]) ** 2)
        return 0.001 * (disLast - dis) / disGoals

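    # Worked example of distanceReward with made-up coordinates (not taken from the
    # game): with the last goal center at (0, 0) and the next goal center at (30, 40),
    # disGoals = 50.  An agent sitting on the last goal gets 0.001 * (0 - 50) / 50 = -0.001,
    # an agent at the segment midpoint (15, 20) gets 0, and an agent on the next goal
    # center gets 0.001 * (50 - 0) / 50 = +0.001, so the shaping term lies in
    # [-0.001, 0.001] along the path between consecutive sub-goals.
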
    def getScreen(self):
        'Get processed screen: grayscale and resized'
        screen = self.ale.render(mode='rgb_array')
        screen = cv2.cvtColor(screen, cv2.COLOR_RGB2GRAY)
        resized = cv2.resize(screen, (self.screen_width, self.screen_height), interpolation=cv2.INTER_AREA)
        return resized

    def getScreenOrig(self):
        'Get original RGB screen'
        return self.ale.render(mode='rgb_array')

    def getScreenRGB(self):
        'Get resized RGB screen for finding the agent location'
        screen = self.ale.render(mode='rgb_array')
        resized = cv2.resize(screen, (self.screen_width, self.screen_height), interpolation=cv2.INTER_AREA)
        return resized

    def getState(self):
        'Get current state, i.e. the processed current screen, with a channel dimension added for the network input'
        # Same processing as getScreen(), plus a trailing channel axis
        return np.expand_dims(self.getScreen(), axis=-1)

    def getStackedState(self):
        return self.to_tensor(self.histState)

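    # Shape summary (assuming, for illustration, histLen = 4 and an 84x84 resized
    # screen, neither of which is fixed by this file): getState() returns a uint8
    # array of shape (84, 84, 1), histState accumulates to (84, 84, 4), and
    # getStackedState() applies torchvision's ToTensor, yielding a float tensor of
    # shape (4, 84, 84) with values scaled to [0, 1].
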
    def isTerminal(self):
        if self.mode == 'train':
            return self.done or self.life_lost
        return self.done

    def isGameOver(self):
        return self.done

    def isLifeLost(self):
        return self.life_lost

    def reset(self):
        self.ale.reset()
        self.life_lost = False

    def goalReached(self, goal):
        'Check whether the agent overlaps a goal region by comparing the current screen against the initial one'
        # Assumes goalSet rows are flat bounding boxes [x1, y1, x2, y2] in resized-screen
        # coordinates, as in __init__; pixels are indexed as screen[row][column].
        x1, y1, x2, y2 = self.goalSet[goal]
        goalScreen = self.init_screen
        stateScreen = self.getScreen()
        count = 0
        for col in range(x1, x2):
            for row in range(y1, y2):
                if goalScreen[row][col] != stateScreen[row][col]:
                    count = count + 1

        # The agent sprite covers roughly 30 pixels; require more than 30% of them to differ
        if float(count) / 30 > 0.3:
            self.reachedGoal[goal] = 1
            return True

        return False

    def trueGoalReached(self, goal):
        'With the AtariARI wrapper enabled, the agent location is updated every step and reflects the true position'
        goalPosition = self.goalSet[goal]
        return goalPosition[0] < self.agentLastX < goalPosition[2] and goalPosition[1] < self.agentLastY < goalPosition[3]

    def goalNotReachedBefore(self, goal):
        if self.reachedGoal[goal] == 1:
            return False
        return True
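
# A minimal usage sketch, not part of the original module, showing how this class
# might be driven.  The argparse flags and their defaults below are assumptions
# inferred from the attributes read off `args` in __init__ (frame_skip,
# sequence_length, screen_width, screen_height, random_seed); they are not a
# documented command-line interface, and subgoals.txt must exist in the working
# directory.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--frame-skip', type=int, default=4)
    parser.add_argument('--sequence-length', type=int, default=4)
    parser.add_argument('--screen-width', type=int, default=84)
    parser.add_argument('--screen-height', type=int, default=84)
    parser.add_argument('--random-seed', type=int, default=0)
    args = parser.parse_args()

    env = ALEEnvironment(args)
    print('Number of actions:', env.numActions())

    # Take a few random actions and report the environment reward plus the shaped
    # distance reward toward the first sub-goal in the gaze-derived order.
    for _ in range(10):
        action = np.random.randint(env.numActions())
        reward = env.act(action)
        shaped = env.distanceReward(None, SUBGOAL_ORDER[0])
        print(f'reward: {reward}, shaping toward first sub-goal: {shaped:.5f}')
        if env.isTerminal():
            env.restart()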