Int-HRL/agent/atari_env.py

import os
import cv2
import gym
import numpy as np
import torchvision.transforms as T
from atariari.benchmark.wrapper import AtariARIWrapper

# Fixed visiting order of the sub-goals (indices into the goal boxes loaded from subgoals.txt)
SUBGOAL_ORDER = [8, 6, 1, 0, 2, 7, 2, 0, 1, 6, 8, 9]


class ALEEnvironment():

    def __init__(self, args, device='cpu'):
        if 'cuda' in device:
            os.environ['CUDA_VISIBLE_DEVICES'] = device.split(':')[-1]

        self.ale = AtariARIWrapper(gym.make(
            'MontezumaRevenge-v4',
            frameskip=args.frame_skip,
            render_mode='rgb_array',  # return the rgb key in step metadata with the current environment RGB frame
            repeat_action_probability=0.0))

        self.histLen = args.sequence_length
        print(self.histLen)

        self.screen_width = args.screen_width
        self.screen_height = args.screen_height

        self.actions = np.arange(self.ale.action_space.n)
        print('Action space: ', self.actions)

        self.mode = "train"
        self.life_lost = False

        # Set environment with random seed!
        self.ale.reset(seed=args.random_seed)

        # Get init screen
        self.init_screen = self.getScreen()
        print(f'Size of screen is {self.init_screen.shape}')

        # Perform NOOP action to init self.done and self.info
        _, reward, self.done, self.info = self.ale.step(0)

        # Use subgoals from gaze analysis: one [x1, y1, x2, y2] box per row
        self.goalSet = np.loadtxt('subgoals.txt', dtype=int, delimiter=',')
        self.goalCenterLoc = []
        for goal in self.goalSet:
            goalCenter = [float(goal[0] + goal[2]) / 2, float(goal[1] + goal[3]) / 2]
            self.goalCenterLoc.append(goalCenter)

        self.agentOriginLoc = [42, 33]
        self.agentLastX = 42
        self.agentLastY = 33

        self.reachedGoal = [0, 0, 0, 0]
        self.histState = self.initializeHistState()
        print('History of states shape: ', self.histState.shape)
        self.to_tensor = T.ToTensor()
        # Agent (Panama Joe) detector, assumed to be assigned externally when available;
        # with the default None, goal_reached() falls back to goalReached().
        self.panamajoe_detector = None

    def initializeHistState(self):
        if self.histLen >= 2:
            histState = np.concatenate((self.getState(), self.getState()), axis=2)
            for _ in range(self.histLen - 2):
                histState = np.concatenate((histState, self.getState()), axis=2)
        else:
            histState = self.getState()
        return histState
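
    # For example, with histLen = 4 and an 84x84 screen (illustrative sizes only), the
    # history state stacks four grayscale frames along the channel axis,
    # giving shape (screen_height, screen_width, 4).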

    def get_input_shape(self):
        return (self.histLen, self.screen_width, self.screen_height)

    def numActions(self):
        return len(self.actions)

    def resetGoalReach(self):
        self.reachedGoal = [0, 0, 0, 0]

    def act(self, action):
        'Perform action and handle results'
        lives = self.info['lives']
        _, reward, self.done, self.info = self.ale.step(self.actions[action])

        # agent location from AtariARI Wrapper
        self.agentLastX, self.agentLastY = self.info['labels']['player_x'], 320 - self.info['labels']['player_y']

        self.life_lost = (lives != self.info['lives'])
        currState = self.getState()
        self.histState = np.concatenate((self.histState[:, :, 1:], currState), axis=2)
        return reward

    def restart(self):
        'Restart environment: set goals and agent location'
        self.ale.reset()
        self.life_lost = False
        self.reachedGoal = [0, 0, 0, 0]
        for i in range(19):
            self.act(0)  # wait for initialization
        self.histState = self.initializeHistState()
        self.agentLastX = self.agentOriginLoc[0]
        self.agentLastY = self.agentOriginLoc[1]

    def beginNextLife(self):
        'Begin next life without restarting environment: set goals and agent location'
        self.life_lost = False
        self.reachedGoal = [0, 0, 0, 0]
        for i in range(19):
            self.act(0)  # wait for initialization
        self.histState = self.initializeHistState()
        self.agentLastX = self.agentOriginLoc[0]
        self.agentLastY = self.agentOriginLoc[1]

    def detect_agent_with_tracker(self):
        'Estimate the agent position from the external panamajoe_detector (if one is set)'
        img = self.getScreenOrig()
        preds = self.panamajoe_detector.predict(img)[0]
        if len(preds['boxes']) == 1:
            box = preds['boxes'][0].detach().cpu().numpy()
            x, y, width, height = box[0], box[1], box[2] - box[0], box[3] - box[1]
            self.agentLastX, self.agentLastY = x + width / 2, y + height / 2
        return (self.agentLastX, self.agentLastY)

    def goal_reached(self, goal):
        if self.panamajoe_detector:
            # NOTE: this branch unpacks a nested [[x1, y1, x2, y2]] goal entry,
            # unlike the flat rows used by trueGoalReached()
            [[x1, y1, x2, y2]] = self.goalSet[goal]
            x_mean, y_mean = self.detect_agent_with_tracker()
            return x1 < x_mean < x2 and y1 < y_mean < y2
        else:
            return self.goalReached(goal)

    def get_goal_direction(self, goal_idx):
        if (goal_idx - 1) == -1:
            lastGoalCenter = self.agentOriginLoc
        else:
            lastGoalCenter = self.goalCenterLoc[SUBGOAL_ORDER[goal_idx - 1]]

        goal_direction = np.array(self.goalSet[SUBGOAL_ORDER[goal_idx]]).mean(axis=0) - lastGoalCenter
        norm = np.linalg.norm(goal_direction)
        if norm == 0:
            return goal_direction
        return goal_direction / norm

    def distanceReward(self, lastGoal, goal):
        'Calculate distance between agent and next sub-goal'
        if lastGoal is None:
            lastGoalCenter = self.agentOriginLoc
        else:
            lastGoalCenter = self.goalCenterLoc[lastGoal]
        goalCenter = self.goalCenterLoc[goal]
        agentX, agentY = self.agentLastX, self.agentLastY
        dis = np.sqrt((goalCenter[0] - agentX) ** 2 + (goalCenter[1] - agentY) ** 2)
        disLast = np.sqrt((lastGoalCenter[0] - agentX) ** 2 + (lastGoalCenter[1] - agentY) ** 2)
        disGoals = np.sqrt((goalCenter[0] - lastGoalCenter[0]) ** 2 + (goalCenter[1] - lastGoalCenter[1]) ** 2)
        return 0.001 * (disLast - dis) / disGoals
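
    # Illustration of the shaping term above (hypothetical numbers, not from the repo):
    # with lastGoalCenter = (0, 0), goalCenter = (10, 0) and the agent at (4, 0),
    # dis = 6, disLast = 4, disGoals = 10, so the reward is 0.001 * (4 - 6) / 10 = -0.0002;
    # after moving toward the goal to (6, 0), dis = 4, disLast = 6 and the reward
    # becomes 0.001 * (6 - 4) / 10 = +0.0002.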

    def getScreen(self):
        'Get processed screen: grayscale and resized'
        screen = self.ale.render(mode='rgb_array')
        screen = cv2.cvtColor(screen, cv2.COLOR_RGB2GRAY)
        resized = cv2.resize(screen, (self.screen_width, self.screen_height), interpolation=cv2.INTER_AREA)
        return resized

    def getScreenOrig(self):
        'Get original RGB screen'
        return self.ale.render(mode='rgb_array')

    def getScreenRGB(self):
        'Get RGB screen for finding agent location'
        screen = self.ale.render(mode='rgb_array')
        resized = cv2.resize(screen, (self.screen_width, self.screen_height), interpolation=cv2.INTER_AREA)
        return resized

    def getState(self):
        'Get current state, i.e. the current screen: grayscale, resized, with a color channel added for the network input'
        screen = self.ale.render(mode='rgb_array')
        screen = cv2.cvtColor(screen, cv2.COLOR_RGB2GRAY)
        resized = cv2.resize(screen, (self.screen_width, self.screen_height), interpolation=cv2.INTER_AREA)
        return np.expand_dims(resized, axis=-1)

    def getStackedState(self):
        # ToTensor converts the HxWxC uint8 history stack to a CxHxW float tensor in [0, 1]
        return self.to_tensor(self.histState)

    def isTerminal(self):
        if self.mode == 'train':
            return self.done or self.life_lost
        return self.done

    def isGameOver(self):
        return self.done

    def isLifeLost(self):
        return self.life_lost

    def reset(self):
        self.ale.reset()
        self.life_lost = False

    def goalReached(self, goal):
        'Screen-diff fallback: count pixels that changed inside the goal box compared to the initial screen'
        # NOTE: this check indexes the goal as nested [[x1, y1], [x2, y2]] pairs,
        # unlike the flat [x1, y1, x2, y2] rows used by trueGoalReached()
        goalPosition = self.goalSet[goal]
        goalScreen = self.init_screen
        stateScreen = self.getState()
        count = 0
        for y in range(goalPosition[0][0], goalPosition[1][0]):
            for x in range(goalPosition[0][1], goalPosition[1][1]):
                if goalScreen[x][y] != stateScreen[x][y]:
                    count = count + 1
        # 30 is the total number of agent pixels
        if float(count) / 30 > 0.3:
            self.reachedGoal[goal] = 1
            return True
        return False

    def trueGoalReached(self, goal):
        'With the AtariARI Wrapper enabled, agent locations are updated every step and yield the true location'
        goalPosition = self.goalSet[goal]
        return goalPosition[0] < self.agentLastX < goalPosition[2] and goalPosition[1] < self.agentLastY < goalPosition[3]

    def goalNotReachedBefore(self, goal):
        if self.reachedGoal[goal] == 1:
            return False
        return True
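

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal example of how ALEEnvironment might be driven, assuming an argparse-style
# namespace providing the fields read in __init__ (frame_skip, sequence_length,
# screen_width, screen_height, random_seed) and a 'subgoals.txt' file in the working
# directory. The hyperparameter values below are placeholders, not the repo's settings.
if __name__ == '__main__':
    from types import SimpleNamespace

    args = SimpleNamespace(frame_skip=4, sequence_length=4,
                           screen_width=84, screen_height=84,
                           random_seed=0)
    env = ALEEnvironment(args)
    env.restart()

    total_reward = 0.0
    while not env.isTerminal():
        # take random actions just to exercise the step/state API
        total_reward += env.act(np.random.randint(env.numActions()))
    print('Episode finished with reward', total_reward)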