add Int-HRL agent scripts

Anna Penzkofer 2025-03-12 18:46:09 +01:00
parent d410acde26
commit e6f9ace2e3
7 changed files with 1803 additions and 1 deletions

README.md
View file

@@ -1,6 +1,6 @@
# Int-HRL
-This is the official repository for [Int-HRL: Towards Intention-based Hierarchical Reinforcement Learning](https://perceptualui.org/publications/penzkofer23_ala/)<br>
+This is the official repository for [Int-HRL: Towards Intention-based Hierarchical Reinforcement Learning](https://collaborative-ai.org/publications/penzkofer24_ncaa/)<br>
Int-HRL uses eye gaze from human demonstration data on the Atari game Montezuma's Revenge to extract human players' intentions and converts them to sub-goals for Hierarchical Reinforcement Learning (HRL). For further details, take a look at the corresponding paper.
@@ -17,8 +17,16 @@ To pre-process the Atari-HEAD data run [Preprocess_AtariHEAD.ipynb](Preprocess_A
3. [Alignment with Trajectory](TrajectoryMatching.ipynb): run expert trajectory to get order of subgoals
## Intention-based Hierarchical RL Agent
+The Int-HRL agent is based on the hierarchically guided Imitation Learning method (hg-DAgger/Q), with code adapted from [https://github.com/hoangminhle/hierarchical_IL_RL](https://github.com/hoangminhle/hierarchical_IL_RL) <br>
+Thanks to the novel sub-goal extraction pipeline, our agent does not require experts during training and is more than three times as sample efficient as hg-DAgger/Q. <br>
+To run the full agent with 12 separate low-level agents for sub-goal execution, run [agent/run_experiment.py](agent/run_experiment.py); for the single-agent variant (one low-level agent for all sub-goals), run [agent/single_agent_experiment.py](agent/single_agent_experiment.py).
+## Extension to Venture and Hero
under construction
## Citation
Please consider citing these papers if you use Int-HRL or parts of this repository in your research:
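The sub-goals referenced throughout the scripts in this commit are gaze-derived bounding boxes stored in `subgoals.txt` and traversed in the fixed order `SUBGOAL_ORDER`. Below is a minimal sketch of that lookup, assuming the `[x1, y1, x2, y2]` per-row box format implied by `agent/atari_env.py` (the `subgoals.txt` file itself is not part of this commit):

```
import numpy as np

# Illustrative only: subgoals.txt is not included in this commit; the [x1, y1, x2, y2]
# row format and the fixed sub-goal order are inferred from agent/atari_env.py below.
SUBGOAL_ORDER = [8, 6, 1, 0, 2, 7, 2, 0, 1, 6, 8, 9]

boxes = np.loadtxt('subgoals.txt', dtype=int, delimiter=',')    # one bounding box per sub-goal
for step, goal_id in enumerate(SUBGOAL_ORDER):
    x1, y1, x2, y2 = boxes[goal_id]
    center = ((x1 + x2) / 2.0, (y1 + y2) / 2.0)                 # same center used for distance rewards
    print(f'step {step}: sub-goal {goal_id}, box {boxes[goal_id].tolist()}, center {center}')
```

Both `agent/run_experiment.py` and `agent/single_agent_experiment.py` load the same file and follow the same order.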

229
agent/atari_env.py Normal file
View file

@@ -0,0 +1,229 @@
import os
import cv2
import gym
import numpy as np
import torchvision.transforms as T
from atariari.benchmark.wrapper import AtariARIWrapper
SUBGOAL_ORDER = [8, 6, 1, 0, 2, 7, 2, 0, 1, 6, 8, 9]
class ALEEnvironment():
def __init__(self, args, device='cpu'):
if 'cuda' in device:
os.environ['CUDA_VISIBLE_DEVICES']= device.split(':')[-1]
self.ale = AtariARIWrapper(gym.make('MontezumaRevenge-v4',
frameskip=args.frame_skip,
render_mode='rgb_array', # return the rgb key in step metadata with the current environment RGB frame.
repeat_action_probability=0.0))
self.histLen = args.sequence_length
        print('History length:', self.histLen)
self.screen_width = args.screen_width
self.screen_height = args.screen_height
self.actions = np.arange(self.ale.action_space.n)
print('Action space: ', self.actions)
self.mode = "train"
self.life_lost = False
# Set environment with random seed!
self.ale.reset(seed=args.random_seed)
# Get init screen
self.init_screen = self.getScreen()
print(f'Size of screen is {self.init_screen.shape}')
# Perform NOOP action to init self.done and self.info
_, reward, self.done, self.info = self.ale.step(0)
# Use subgoals from gaze analysis
self.goalSet = np.loadtxt('subgoals.txt', dtype=int, delimiter=',')
self.goalCenterLoc = []
for goal in self.goalSet:
goalCenter = [float(goal[0]+goal[2])/2, float(goal[1]+goal[3])/2]
self.goalCenterLoc.append(goalCenter)
self.agentOriginLoc = [42, 33]
self.agentLastX = 42
self.agentLastY = 33
        self.reachedGoal = [0, 0, 0, 0]
        self.histState = self.initializeHistState()
        print('History state shape: ', self.histState.shape)
        self.to_tensor = T.ToTensor()
        # Optional external agent detector (Panama Joe tracker); goal_reached() falls back
        # to the pixel-based goalReached() while this is None.
        self.panamajoe_detector = None
def initializeHistState(self):
if self.histLen >= 2:
histState = np.concatenate((self.getState(), self.getState()), axis = 2)
for _ in range(self.histLen - 2):
histState = np.concatenate((histState, self.getState()), axis = 2)
else:
histState = self.getState()
return histState
def get_input_shape(self):
return (self.histLen, self.screen_width, self.screen_height)
def numActions(self):
return len(self.actions)
def resetGoalReach(self):
        self.reachedGoal = [0, 0, 0, 0]
def act(self, action):
'Perform action and handle results'
lives = self.info['lives']
_, reward, self.done, self.info = self.ale.step(self.actions[action])
# agent location from AtariARI Wrapper
self.agentLastX, self.agentLastY = self.info['labels']['player_x'], 320 - self.info['labels']['player_y']
self.life_lost = (not lives == self.info['lives'])
currState = self.getState()
self.histState = np.concatenate((self.histState[:, :, 1:], currState), axis = 2)
return reward
def restart(self):
'Restart environment: set goals and agent location'
self.ale.reset()
self.life_lost = False
self.reachedGoal = [0, 0, 0, 0]
for i in range(19):
self.act(0) #wait for initialization
self.histState = self.initializeHistState()
self.agentLastX = self.agentOriginLoc[0]
self.agentLastY = self.agentOriginLoc[1]
def beginNextLife(self):
'Begin next life without restarting environment: set goals and agent location'
self.life_lost = False
        self.reachedGoal = [0, 0, 0, 0]
for i in range(19):
self.act(0) #wait for initialization
self.histState = self.initializeHistState()
self.agentLastX = self.agentOriginLoc[0]
self.agentLastY = self.agentOriginLoc[1]
def detect_agent_with_tracker(self):
img = self.getScreenOrig()
preds = self.panamajoe_detector.predict(img)[0]
if len(preds['boxes']) == 1:
box = preds['boxes'][0].detach().cpu().numpy()
x, y, width, height = box[0], box[1], box[2]-box[0], box[3]-box[1]
self.agentLastX, self.agentLastY = x + width / 2, y + height / 2
return (self.agentLastX, self.agentLastY)
def goal_reached(self, goal):
if self.panamajoe_detector:
[[x1, y1, x2, y2]] = self.goalSet[goal]
x_mean, y_mean = self.detect_agent_with_tracker()
return x_mean > x1 and x_mean < x2 and y_mean > y1 and y_mean < y2
else:
return self.goalReached(goal)
def get_goal_direction(self, goal_idx):
if ((goal_idx-1) == -1):
lastGoalCenter = self.agentOriginLoc
else:
lastGoalCenter = self.goalCenterLoc[SUBGOAL_ORDER[goal_idx-1]]
goal_direction = np.array(self.goalSet[SUBGOAL_ORDER[goal_idx]]).mean(axis=0) - lastGoalCenter
norm = np.linalg.norm(goal_direction)
if norm == 0:
return goal_direction
return goal_direction / norm
def distanceReward(self, lastGoal, goal):
'Calculate distance between agent and next sub-goal'
if lastGoal is None:
lastGoalCenter = self.agentOriginLoc
else:
lastGoalCenter = self.goalCenterLoc[lastGoal]
goalCenter = self.goalCenterLoc[goal]
agentX, agentY = self.agentLastX, self.agentLastY
dis = np.sqrt((goalCenter[0] - agentX)*(goalCenter[0] - agentX) + (goalCenter[1]-agentY)*(goalCenter[1]-agentY))
disLast = np.sqrt((lastGoalCenter[0] - agentX)*(lastGoalCenter[0] - agentX) + (lastGoalCenter[1]-agentY)*(lastGoalCenter[1]-agentY))
disGoals = np.sqrt((goalCenter[0]-lastGoalCenter[0])*(goalCenter[0]-lastGoalCenter[0]) + (goalCenter[1]-lastGoalCenter[1])*(goalCenter[1]-lastGoalCenter[1]))
return 0.001 * (disLast - dis) / disGoals
def getScreen(self):
'Get processed screen: grayscale and resized'
screen = self.ale.render(mode='rgb_array')
screen = cv2.cvtColor(screen, cv2.COLOR_RGB2GRAY)
resized = cv2.resize(screen, (self.screen_width, self.screen_height), interpolation=cv2.INTER_AREA)
return resized
def getScreenOrig(self):
'Get original RGB screen'
return self.ale.render(mode='rgb_array')
def getScreenRGB(self):
'Get RGB screen for finding agent location'
screen = self.ale.render(mode='rgb_array')
resized = cv2.resize(screen, (self.screen_width, self.screen_height), interpolation=cv2.INTER_AREA)
return resized
def getState(self):
'Get current state, i.e. current screen. Process screen and add color channel for input of network'
screen = self.ale.render(mode='rgb_array')
screen = cv2.cvtColor(screen, cv2.COLOR_RGB2GRAY)
resized = cv2.resize(screen, (self.screen_width, self.screen_height), interpolation=cv2.INTER_AREA)
return np.expand_dims(resized, axis=-1)
def getStackedState(self):
return self.to_tensor(self.histState)
def isTerminal(self):
if self.mode == 'train':
return self.done or self.life_lost
return self.done
def isGameOver(self):
return self.done
def isLifeLost(self):
return self.life_lost
def reset(self):
self.ale.reset()
self.life_lost = False
def goalReached(self, goal):
goalPosition = self.goalSet[goal]
goalScreen = self.init_screen
stateScreen = self.getState()
count = 0
for y in range (goalPosition[0][0], goalPosition[1][0]):
for x in range (goalPosition[0][1], goalPosition[1][1]):
if goalScreen[x][y] != stateScreen[x][y]:
count = count + 1
# 30 is total number of pixels of agent
if float(count) / 30 > 0.3:
self.reachedGoal[goal] = 1
return True
return False
def trueGoalReached(self, goal):
'With the AtariARI Wrapper enabled agent locations are updated every step and yield true location'
goalPosition = self.goalSet[goal]
return self.agentLastX > goalPosition[0] and self.agentLastX < goalPosition[2] and self.agentLastY > goalPosition[1] and self.agentLastY < goalPosition[3]
def goalNotReachedBefore(self, goal):
if (self.reachedGoal[goal] == 1):
return False
return True

308
agent/ddqn_agent.py Normal file
View file

@@ -0,0 +1,308 @@
from collections import namedtuple
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchinfo import summary
import numpy as np
from tqdm import tqdm
from replay_buffer import PrioritizedReplayBuffer, LinearSchedule, TanHSchedule
prioritized_replay_alpha = 0.6
max_timesteps = 1000000
prioritized_replay_beta0 = 0.4
prioritized_replay_eps = 1e-6
prioritized_replay_beta_iters = max_timesteps * 0.5
beta_schedule = LinearSchedule(prioritized_replay_beta_iters,
initial_p=prioritized_replay_beta0,
final_p=1.0)
Experience = namedtuple('Experience', field_names=['state', 'action', 'reward', 'next_state', 'done'])
# Deep Q Network
class DQN(nn.Module):
def __init__(self, device, input_shape, n_actions, hidden_nodes=512, goal_feature=False):
super(DQN, self).__init__()
self.device = device
self.conv = nn.Sequential(
nn.Conv2d(input_shape[0], 16, kernel_size=5, stride=2),
nn.BatchNorm2d(16),
nn.ReLU(),
nn.Conv2d(16, 32, kernel_size=5, stride=2),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.Conv2d(32, 32, kernel_size=5, stride=2),
nn.BatchNorm2d(32),
nn.ReLU()
)
conv_out_size = self._get_conv_out(input_shape)
if goal_feature:
conv_out_size += 4
# Standard DQN
self.fc = nn.Sequential(
nn.Linear(conv_out_size, hidden_nodes),
nn.LeakyReLU(),
nn.Linear(hidden_nodes, n_actions),
# nn.Softmax(dim=1) # necessary for Boltzman Sampling
)
def _get_conv_out(self, shape):
o = self.conv(torch.zeros(1, *shape))
return int(np.prod(o.size()))
def forward(self, x, goal=None):
x = x.to(self.device)
conv_out = self.conv(x).view(x.size()[0], -1)
if not goal is None:
conv_out = torch.cat((conv_out, goal.squeeze(1)), dim=1)
return self.fc(conv_out)
# Dueling Deep Q Network
class DuelingDQN(nn.Module):
def __init__(self, device, input_shape, n_actions, hidden_nodes=512, goal_feature=False):
super(DuelingDQN, self).__init__()
self.device = device
self.conv = nn.Sequential(
nn.Conv2d(input_shape[0], 16, kernel_size=5, stride=2),
nn.BatchNorm2d(16),
nn.ReLU(),
nn.Conv2d(16, 32, kernel_size=5, stride=2),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.Conv2d(32, 32, kernel_size=5, stride=2),
nn.BatchNorm2d(32),
nn.ReLU()
)
conv_out_size = self._get_conv_out(input_shape)
if goal_feature:
conv_out_size += 4
# Dueling DQN --> two heads
# advantage of actions A(s,a)
self.fc_adv = nn.Sequential(
nn.Linear(conv_out_size, hidden_nodes),
nn.LeakyReLU(),
nn.Linear(hidden_nodes, n_actions)
)
# value of states V(s)
self.fc_val = nn.Sequential(
nn.Linear(conv_out_size, hidden_nodes),
nn.LeakyReLU(),
nn.Linear(hidden_nodes, 1)
)
def _get_conv_out(self, shape):
o = self.conv(torch.zeros(1, *shape))
return int(np.prod(o.size()))
def forward(self, x, goal=None):
x = x.to(self.device)
fx = x.float() / 256
conv_out = self.conv(fx).view(fx.size()[0], -1)
if not goal is None:
conv_out = torch.cat((conv_out, goal.squeeze(1)), dim=1)
val = self.fc_val(conv_out)
adv = self.fc_adv(conv_out)
# Q(s,a) = V(s) + A(s,a) - 1/N sum(A(s,k)) --> expected value of advantage = 0
return val + (adv - adv.mean(dim=1, keepdim=True))
class MaskedHuberLoss(nn.Module):
def __init__(self, delta=1.0):
super(MaskedHuberLoss, self).__init__()
self.criterion = nn.HuberLoss(reduction='none', delta=delta)
def forward(self, inputs, targets, mask):
loss = self.criterion(inputs, targets)
loss *= mask
return loss.sum(dim=-1)
# Agent that performs, remembers and learns actions
class Agent():
def __init__(self, device, goal, input_shape, n_actions, random_play_steps, args, dueling=True, goal_feature=False):
self.device = device
self.n_actions = n_actions
self.goal_feature = goal_feature
if dueling:
self.policy_net = DuelingDQN(device, input_shape, n_actions, args.hidden_nodes, goal_feature).to(self.device)
self.target_net = DuelingDQN(device, input_shape, n_actions, args.hidden_nodes, goal_feature).to(self.device)
else:
self.policy_net = DQN(device, input_shape, n_actions, args.hidden_nodes, goal_feature).to(self.device)
self.target_net = DQN(device, input_shape, n_actions, args.hidden_nodes, goal_feature).to(self.device)
#self.optimizer = optim.RMSprop(self.policy_net.parameters(), lr=args.learning_rate)
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=args.learning_rate)
self.memory = PrioritizedReplayBuffer(args.experience_buffer, alpha=0.6)
self.exploration = LinearSchedule(schedule_timesteps=args.epsilon_steps, initial_p=args.epsilon_start, final_p=args.epsilon_end)
#self.exploration = TanHSchedule(schedule_timesteps=(args.epsilon_steps//10), initial_p=args.epsilon_start, final_p=args.epsilon_end)
self.loss_fn = MaskedHuberLoss(delta=1.0)
self.steps_done = -1
self.batch_size = args.batch_size
self.gamma = args.gamma
self.goal = goal
self.enable_double_dqn = True
self.hard_update = 1000
self.training_done = False
self.verbose = args.verbose
self.random_play_steps = random_play_steps
self.args = args
def remember(self, state, action, reward, next_state, done, goal):
self.memory.add(state, action, reward, next_state, done, goal)
def load_model(self, path):
#summary(self.policy_net, (4, 84, 84))
self.policy_net.load_state_dict(torch.load(path))
self.policy_net.eval()
self.random_play_steps = 100
self.training_done = True
self.exploration = LinearSchedule(schedule_timesteps=1000, initial_p=0.1, final_p=0.005)
print('Loaded DDQN policy network weights from ', path)
def get_epsilon(self):
return self.exploration.value(self.steps_done - self.random_play_steps)
def reset_epsilon_schedule(self):
if isinstance(self.exploration, TanHSchedule):
self.exploration.restart()
else:
print('Exploration schedule is linear and can not be reset!')
def select_action(self, state, goal=None):
#if not self.training_done:
self.steps_done += 1
# Select an action according to an epsilon greedy approach
if self.steps_done < self.random_play_steps or random.random() < self.get_epsilon():
return torch.tensor([[random.randrange(0, self.n_actions)]], device=self.device, dtype=torch.long)
else:
with torch.no_grad():
return self.policy_net(state, goal).max(1)[1].view(1, 1)
        # Alternative (unused): select an action via Boltzmann sampling instead of epsilon-greedy
        # if self.steps_done > self.random_play_steps:
        #     return self.sample(state, temperature=self.get_epsilon())
        # else:
        #     return torch.tensor([[random.randrange(0, self.n_actions)]], device=self.device, dtype=torch.long)
def sample(self, state, temperature=0.1, goal=None):
'Predict action probability vector and use Boltzmann Sampling to get new action'
with torch.no_grad():
prob_vec = self.policy_net(state, goal)
sample_prob = torch.exp(prob_vec / temperature)
sample_prob /= sample_prob.sum()
try:
sample_idx = sample_prob.multinomial(num_samples=1) # random choice with probabilities
except Exception as e:
# print(temperature, prob_vec, sample_prob)
# RuntimeError: probability tensor contains either `inf`, `nan` or element < 0
import logging
logging.basicConfig(filename=f'runs/{self.args.model_name}sample_prob.log', filemode='w', format='%(name)s - %(levelname)s - %(message)s')
logging.warning(f'[Agent {self.goal} - step {self.steps_done}] {e}: {sample_prob}')
sample_idx = torch.tensor([[random.randrange(0, self.n_actions)]], device=self.device, dtype=torch.long)
return sample_idx.view(1,1)
def finish_training(self, model_name):
torch.save(self.policy_net.state_dict(), f"runs/run{model_name}-agent{self.goal}.pt")
self.training_done = True
# Free memory by deleting target net (no longer needed)
del self.target_net
def optimize_model(self):
if len(self.memory) < self.batch_size or self.steps_done < self.random_play_steps or self.training_done:
return
# Sample from prioritized replay memory
sample = self.memory.sample(self.batch_size, beta=beta_schedule.value(self.steps_done))
states, actions, rewards, next_states, dones, goals, importance_v, idx_vector = sample
states_v = states.to(self.device)
next_states_v = next_states.to(self.device)
actions_v = actions.to(self.device)
rewards_v = rewards.to(self.device)
done_mask = dones.to(self.device)
goals_v = goals.to(self.device)
if not self.goal_feature:
goals_v = None
# predicted Q(s, a)
q_values = self.policy_net(states_v, goals_v)
state_action_values = q_values.gather(1, actions_v.squeeze(-1))
if self.enable_double_dqn:
# calculate max_a Q'(s_t+1, argmax_a Q(s_t+1, a))
next_state_actions = self.policy_net(next_states_v, goals_v).max(1)[1] # [1] = argmax
next_state_values = self.target_net(next_states_v, goals_v).gather(1, next_state_actions.unsqueeze(-1)).squeeze(-1)
else:
# calculate max_a Q'(s_t+1, a)
next_state_values = self.target_net(next_states_v, goals_v).max(1)[0] # [0] = max
# Set discounted reward to zero for all states that were terminal
next_state_values[done_mask] = 0.0
# Compute r_t + gamma * max_a Q'(s_t+1, a)
expected_state_action_values = next_state_values.detach() * self.gamma + rewards_v
targets = np.zeros((self.batch_size, self.n_actions))
dummy_targets = np.zeros((self.batch_size,))
masks = np.zeros((self.batch_size, self.n_actions))
for idx, (target, mask, reward, action) in enumerate(zip(targets, masks, expected_state_action_values, actions_v)):
target[action] = reward # update action with estimated accumulated reward
dummy_targets[idx] = reward
mask[action] = 1. # enable loss for this specific action
# Calculate TD [Temporal Difference] error
action_batch = actions_v.squeeze().cpu().detach().numpy()
td_errors = targets[range(self.batch_size), action_batch] - state_action_values.squeeze().cpu().detach().numpy()
# Compute masked loss between predicted state action values and targets
targets = torch.tensor(targets, dtype=torch.float32).to(self.device)
masks = torch.tensor(masks, dtype=torch.float32).to(self.device)
loss = self.loss_fn(q_values, targets , masks)
# Update models and memory only when training is not done
if not self.training_done:
# Update priorities in replay buffer
new_priorities = np.abs(td_errors) + prioritized_replay_eps
self.memory.update_priorities(idx_vector, new_priorities)
self.optimizer.zero_grad()
loss.mean().backward() # https://discuss.pytorch.org/t/loss-backward-raises-error-grad-can-be-implicitly-created-only-for-scalar-outputs/12152/2
for param in self.policy_net.parameters():
param.grad.data.clamp_(-1, 1)
self.optimizer.step()
# Update the target network
if self.steps_done % self.hard_update == 0:
if self.verbose >= 1:
print('\nUpdating target network...\n')
self.target_net.load_state_dict(self.policy_net.state_dict())
return loss
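The one-line comments in `DuelingDQN.forward` and `Agent.optimize_model` above compress the two key update rules. A small self-contained sketch with toy tensors (illustrative values only, not part of the repository's API) spells them out:

```
import torch

# Dueling aggregation: Q(s,a) = V(s) + A(s,a) - mean_a A(s,a)
val = torch.tensor([[1.0]])                        # V(s), shape (batch, 1)
adv = torch.tensor([[0.5, -0.5, 1.0, -1.0]])       # A(s,a), shape (batch, n_actions)
q = val + (adv - adv.mean(dim=1, keepdim=True))    # same formula as DuelingDQN.forward
print(q)                                           # tensor([[1.5, 0.5, 2.0, 0.0]])

# Double DQN target: y = r + gamma * Q_target(s', argmax_a Q_policy(s', a))
gamma, reward = 0.99, torch.tensor([1.0])
q_policy_next = torch.tensor([[0.2, 0.8, 0.1, 0.4]])
q_target_next = torch.tensor([[0.3, 0.6, 0.9, 0.2]])
best_action = q_policy_next.max(1)[1]                                  # argmax under the policy net
target = reward + gamma * q_target_next.gather(1, best_action.unsqueeze(-1)).squeeze(-1)
print(target)                                                          # tensor([1.5940])
```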

158
agent/metacontroller.py Normal file
View file

@@ -0,0 +1,158 @@
from collections import deque
import torch
import torch.nn as nn
import torch.optim as optim
from torchmetrics import Accuracy
import numpy as np
BATCH_SIZE = 32
TRAIN_HIST_SIZE = 10000
P_DROPOUT = 0.5
class MetaNN(nn.Module):
def __init__(self, device, input_shape=(4,84,84), n_goals=4, hidden_nodes=512):
super(MetaNN, self).__init__()
self.device = device
### Setup model architecture ###
self.conv = nn.Sequential(
nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
nn.ReLU(),
nn.Dropout(p=P_DROPOUT),
nn.Conv2d(32, 64, kernel_size=4, stride=2),
nn.ReLU(),
nn.Dropout(p=P_DROPOUT),
nn.Conv2d(64, 64, kernel_size=3, stride=1),
nn.ReLU(),
nn.Dropout(p=P_DROPOUT),
)
conv_out_size = self._get_conv_out(input_shape)
self.fc = nn.Sequential(
nn.Linear(conv_out_size, hidden_nodes), #TODO: copy initialization from meta_net_il.py
nn.ReLU(),
nn.Dropout(p=P_DROPOUT),
nn.Linear(hidden_nodes, n_goals),
nn.Softmax(dim=1)
)
def _get_conv_out(self, shape):
o = self.conv(torch.zeros(1, *shape))
return int(np.prod(o.size()))
def forward(self, x):
x = x.to(self.device)
conv_out = self.conv(x).view(x.size()[0], -1)
return self.fc(conv_out)
class MetaController():
def __init__(self, device, args, input_shape=(4, 84, 84), n_goals=4, hidden_nodes=512) -> None:
self.device = device
self.model = MetaNN(device, input_shape, n_goals, hidden_nodes).to(self.device)
print('Saving init state of MetaNN')
torch.save(self.model.state_dict(), f"runs/meta_init-{args.model_name}.pt")
self.optimizer = optim.RMSprop(self.model.parameters(), lr=0.00025, alpha=0.95, eps=1e-08, weight_decay=0.0)
self.loss_fn = nn.MSELoss()
self.accuracy_fn = Accuracy(num_classes=n_goals, average='macro').to(device)
self.replay_hist = deque([None], maxlen=TRAIN_HIST_SIZE)
self.ind = 0
self.count = 0
self.meta_steps = 0
self.input_shape = input_shape
self.n_goals = n_goals
self.verbose = args.verbose
self.name = args.model_name
def reset(self) -> None:
'Load initial state dictionary from file and reset optimizer'
self.model.load_state_dict(torch.load(f"runs/meta_init-{self.name}.pt"))
self.optimizer = optim.RMSprop(self.model.parameters(), lr=0.00025, alpha=0.95, eps=1e-08, weight_decay=0.0)
def check_training_clock(self) -> bool:
'Only train every <BATCH_SIZE> meta controller steps, i.e. <BATCH_SIZE> new samples in replay buffer.'
if BATCH_SIZE:
return (self.meta_steps % BATCH_SIZE == 0)
else:
return (self.meta_steps % 20 == 0)
def collect(self, processed, expert_a) -> None:
'Collect sample consisting of state (4, 84, 84) and one-hot vector of goal'
if processed is not None:
self.replay_hist.appendleft(tuple([processed, expert_a]))
self.meta_steps += 1
def train(self):
# if not reached TRAIN_HIST_SIZE yet, then get the number of samples
num_samples = min(self.meta_steps, TRAIN_HIST_SIZE)
inputs = torch.stack([self.replay_hist[i][0] for i in range(num_samples)], 0).to(self.device)
labels = torch.stack([torch.tensor(self.replay_hist[i][1], dtype=torch.float32) for i in range(num_samples)], 0).to(self.device)
if self.verbose >= 2.0:
print(f'\nMetacontroller collected {self.meta_steps} samples')
if len(labels) == TRAIN_HIST_SIZE:
print(f'Reached TRAIN_HIST_SIZE = {TRAIN_HIST_SIZE}')
print('Dataset Distribution:')
for goal in labels.unique():
goal = goal.cpu().detach().numpy()
print(f'\nNumber of samples for goal {goal}: {sum(labels).squeeze()[goal]}' )
print(f'--> {sum(labels).squeeze()[goal] / len(labels):.2%}')
print()
if BATCH_SIZE and num_samples >= BATCH_SIZE:
# train one epoch --> noisy convergence more likely to find broader minimum
accumulated_loss = []
accumulated_acc = []
for index in range(0, len(labels) // BATCH_SIZE):
b_inputs, b_labels = inputs[index * BATCH_SIZE: (index + 1) * BATCH_SIZE], labels[index * BATCH_SIZE: (index + 1) * BATCH_SIZE]
# zero the parameter gradients
self.optimizer.zero_grad()
outputs = self.model(b_inputs)
loss = self.loss_fn(outputs, b_labels.squeeze(1))
loss.backward()
self.optimizer.step()
accumulated_loss.append(loss)
accumulated_acc.append(self.accuracy_fn(outputs, b_labels.squeeze(1).type(torch.uint8)))
loss = torch.stack(accumulated_loss).mean()
accuracy = torch.stack(accumulated_acc).mean()
else:
# run once over all samples --> smooth convergence to a deep local minimum
self.optimizer.zero_grad()
outputs = self.model(inputs)
loss = self.loss_fn(outputs, labels.squeeze(1))
accuracy = self.accuracy_fn(outputs, labels.squeeze(1).type(torch.uint8))
loss.backward()
self.optimizer.step()
self.count = 0 # reset the count clock
return loss, accuracy
def predict(self, state, batch_size=1) -> np.ndarray:
'Predict probability distribution of goals with metacontroller model and return as ndarray for sampling'
return self.model.forward(state.unsqueeze(0)).squeeze(0).detach().cpu().numpy()
def sample(self, prob_vec, temperature=0.1) -> int:
prob_pred = np.log(prob_vec) / temperature
dist = np.exp(prob_pred)/np.sum(np.exp(prob_pred))
choices = range(len(prob_pred))
return np.random.choice(choices, p=dist)

408
agent/replay_buffer.py Normal file
View file

@@ -0,0 +1,408 @@
# This script is taken from openAI's baselines implementation and adapted for hInt-RL
# ===================================================================================================================
import random
import operator
from collections import namedtuple
import torch
import numpy as np
Experience = namedtuple('Experience', field_names=['state', 'action', 'reward', 'next_state', 'done', 'goal'])
class LinearSchedule(object):
def __init__(self, schedule_timesteps, final_p, initial_p=1.0):
"""Linear interpolation between initial_p and final_p over
schedule_timesteps. After this many timesteps pass final_p is
returned.
Parameters
----------
schedule_timesteps: int
Number of timesteps for which to linearly anneal initial_p
to final_p
initial_p: float
initial output value
final_p: float
final output value
"""
self.schedule_timesteps = schedule_timesteps
self.final_p = final_p
self.initial_p = initial_p
def value(self, t):
"""See Schedule.value"""
fraction = min(float(t) / self.schedule_timesteps, 1.0)
return self.initial_p + fraction * (self.final_p - self.initial_p)
class TanHSchedule(object):
def __init__(self, schedule_timesteps, final_p, initial_p=1.0):
"""This is a tanh annealing schedule with restarts, i.e. hyperbolic tangent decay.
Parameters
----------
schedule_timesteps: int
Number of overall timesteps for which to anneal initial_p to final_p
initial_p: float
max value
final_p: float
min value
"""
self.schedule_timesteps = schedule_timesteps
self.final_p = final_p
self.initial_p = initial_p
self.tt = 0
def value(self, t):
if t > 0:
# otherwise random play steps are active
self.tt += 1.
tanh = np.tanh(self.tt/(self.schedule_timesteps) - 4.)
return (self.initial_p + self.final_p) / 2. - (self.initial_p - self.final_p) / 2. * tanh
def restart(self):
print('Resetting tanh cycle...')
self.tt = 0
class SegmentTree(object):
def __init__(self, capacity, operation, neutral_element):
"""Build a Segment Tree data structure.
https://en.wikipedia.org/wiki/Segment_tree
Can be used as regular array, but with two
important differences:
a) setting item's value is slightly slower.
It is O(lg capacity) instead of O(1).
b) user has access to an efficient `reduce`
operation which reduces `operation` over
a contiguous subsequence of items in the
array.
        Parameters
        ----------
        capacity: int
            Total size of the array - must be a power of two.
        operation: lambda obj, obj -> obj
            an operation for combining elements (e.g. sum, max);
            must form a mathematical group together with the set of
            possible values for array elements.
neutral_element: obj
neutral element for the operation above. eg. float('-inf')
for max and 0 for sum.
"""
assert capacity > 0 and capacity & (capacity - 1) == 0, "capacity must be positive and a power of 2."
self._capacity = capacity
self._value = [neutral_element for _ in range(2 * capacity)]
self._operation = operation
def _reduce_helper(self, start, end, node, node_start, node_end):
if start == node_start and end == node_end:
return self._value[node]
mid = (node_start + node_end) // 2
if end <= mid:
return self._reduce_helper(start, end, 2 * node, node_start, mid)
else:
if mid + 1 <= start:
return self._reduce_helper(start, end, 2 * node + 1, mid + 1, node_end)
else:
return self._operation(
self._reduce_helper(start, mid, 2 * node, node_start, mid),
self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1, node_end)
)
def reduce(self, start=0, end=None):
"""Returns result of applying `self.operation`
to a contiguous subsequence of the array.
self.operation(arr[start], operation(arr[start+1], operation(... arr[end])))
Parameters
----------
start: int
beginning of the subsequence
end: int
end of the subsequences
Returns
-------
reduced: obj
result of reducing self.operation over the specified range of array elements.
"""
if end is None:
end = self._capacity
if end < 0:
end += self._capacity
end -= 1
return self._reduce_helper(start, end, 1, 0, self._capacity - 1)
def __setitem__(self, idx, val):
# index of the leaf
idx += self._capacity
self._value[idx] = val
idx //= 2
while idx >= 1:
self._value[idx] = self._operation(
self._value[2 * idx],
self._value[2 * idx + 1]
)
idx //= 2
def __getitem__(self, idx):
assert 0 <= idx < self._capacity
return self._value[self._capacity + idx]
class SumSegmentTree(SegmentTree):
def __init__(self, capacity):
super(SumSegmentTree, self).__init__(
capacity=capacity,
operation=operator.add,
neutral_element=0.0
)
def sum(self, start=0, end=None):
"""Returns arr[start] + ... + arr[end]"""
return super(SumSegmentTree, self).reduce(start, end)
def find_prefixsum_idx(self, prefixsum):
"""Find the highest index `i` in the array such that
            sum(arr[0] + arr[1] + ... + arr[i - 1]) <= prefixsum
if array values are probabilities, this function
allows to sample indexes according to the discrete
probability efficiently.
Parameters
----------
        prefixsum: float
upperbound on the sum of array prefix
Returns
-------
idx: int
highest index satisfying the prefixsum constraint
"""
assert 0 <= prefixsum <= self.sum() + 1e-5
idx = 1
while idx < self._capacity: # while non-leaf
if self._value[2 * idx] > prefixsum:
idx = 2 * idx
else:
prefixsum -= self._value[2 * idx]
idx = 2 * idx + 1
return idx - self._capacity
class MinSegmentTree(SegmentTree):
def __init__(self, capacity):
super(MinSegmentTree, self).__init__(
capacity=capacity,
operation=min,
neutral_element=float('inf')
)
def min(self, start=0, end=None):
"""Returns min(arr[start], ..., arr[end])"""
return super(MinSegmentTree, self).reduce(start, end)
class ReplayBuffer(object):
def __init__(self, size):
"""Create Replay buffer.
Parameters
----------
size: int
Max number of transitions to store in the buffer. When the buffer
overflows the old memories are dropped.
"""
self._storage = []
self._maxsize = size
self._next_idx = 0
def __len__(self):
return len(self._storage)
def add(self, obs_t, action, reward, obs_tp1, done, goal):
data = (obs_t, action, reward, obs_tp1, done, goal)
if self._next_idx >= len(self._storage):
self._storage.append(data)
else:
self._storage[self._next_idx] = data
self._next_idx = (self._next_idx + 1) % self._maxsize
def _encode_sample_torch(self, idxes):
sample = [self._storage[i] for i in idxes]
batch = Experience(*zip(*sample))
return torch.stack(batch.state), torch.stack(batch.action), torch.stack(batch.reward).squeeze(), torch.stack(batch.next_state), \
torch.stack([torch.tensor(batch.done)]).squeeze(), torch.stack(batch.goal)
"""
# old implementation from
def _encode_sample(self, idxes):
obses_t, actions, rewards, obses_tp1, dones, goals = [], [], [], [], [], []
for i in idxes:
data = self._storage[i]
obs_t, action, reward, obs_tp1, done, goal = data
obses_t.append(np.array(obs_t, copy=False))
actions.append(np.array(action, copy=False))
rewards.append(reward)
obses_tp1.append(np.array(obs_tp1, copy=False))
dones.append(done)
goals.append(np.array(action, copy=False))
return np.array(obses_t), np.array(actions), np.array(rewards, dtype=np.float32), np.array(obses_tp1), np.array(dones, dtype=np.uint8), np.array(goals, dtype=np.uint8)
"""
def sample(self, batch_size):
"""Sample a batch of experiences.
Parameters
----------
batch_size: int
How many transitions to sample.
Returns
-------
obs_batch: np.array
batch of observations
act_batch: np.array
batch of actions executed given obs_batch
rew_batch: np.array
rewards received as results of executing act_batch
next_obs_batch: np.array
next set of observations seen after executing act_batch
done_mask: np.array
done_mask[i] = 1 if executing act_batch[i] resulted in
the end of an episode and 0 otherwise.
"""
idxes = [random.randint(0, len(self._storage) - 1) for _ in range(batch_size)]
return self._encode_sample_torch(idxes)
class PrioritizedReplayBuffer(ReplayBuffer):
def __init__(self, size, alpha):
"""Create Prioritized Replay buffer.
Parameters
----------
size: int
Max number of transitions to store in the buffer. When the buffer
overflows the old memories are dropped.
alpha: float
how much prioritization is used
(0 - no prioritization, 1 - full prioritization)
See Also
--------
ReplayBuffer.__init__
"""
super(PrioritizedReplayBuffer, self).__init__(size)
assert alpha > 0
self._alpha = alpha
it_capacity = 1
while it_capacity < size:
it_capacity *= 2
self._it_sum = SumSegmentTree(it_capacity)
self._it_min = MinSegmentTree(it_capacity)
self._max_priority = 1.0
def add(self, *args, **kwargs):
"""See ReplayBuffer.store_effect"""
idx = self._next_idx
super(PrioritizedReplayBuffer,self).add(*args, **kwargs)
self._it_sum[idx] = self._max_priority ** self._alpha
self._it_min[idx] = self._max_priority ** self._alpha
def _sample_proportional(self, batch_size):
res = []
for _ in range(batch_size):
# TODO(szymon): should we ensure no repeats?
mass = random.random() * self._it_sum.sum(0, len(self._storage) - 1)
idx = self._it_sum.find_prefixsum_idx(mass)
res.append(idx)
return res
def sample(self, batch_size, beta):
"""Sample a batch of experiences.
compared to ReplayBuffer.sample
it also returns importance weights and idxes
of sampled experiences.
Parameters
----------
batch_size: int
How many transitions to sample.
beta: float
To what degree to use importance weights
(0 - no corrections, 1 - full correction)
Returns
-------
obs_batch: np.array
batch of observations
act_batch: np.array
batch of actions executed given obs_batch
rew_batch: np.array
rewards received as results of executing act_batch
next_obs_batch: np.array
next set of observations seen after executing act_batch
done_mask: np.array
done_mask[i] = 1 if executing act_batch[i] resulted in
the end of an episode and 0 otherwise.
weights: np.array
Array of shape (batch_size,) and dtype np.float32
denoting importance weight of each sampled transition
idxes: np.array
Array of shape (batch_size,) and dtype np.int32
            indexes in buffer of sampled experiences
"""
assert beta > 0
idxes = self._sample_proportional(batch_size)
weights = []
p_min = self._it_min.min() / self._it_sum.sum()
max_weight = (p_min * len(self._storage)) ** (-beta)
for idx in idxes:
p_sample = self._it_sum[idx] / self._it_sum.sum()
weight = (p_sample * len(self._storage)) ** (-beta)
weights.append(weight / max_weight)
weights = np.array(weights)
encoded_sample = self._encode_sample_torch(idxes)
return tuple(list(encoded_sample) + [weights, idxes])
def update_priorities(self, idxes, priorities):
"""Update priorities of sampled transitions.
sets priority of transition at index idxes[i] in buffer
to priorities[i].
Parameters
----------
idxes: [int]
List of idxes of sampled transitions
priorities: [float]
List of updated priorities corresponding to
transitions at the sampled idxes denoted by
variable `idxes`.
"""
assert len(idxes) == len(priorities)
for idx, priority in zip(idxes, priorities):
#print priority
#time.sleep(0.5)
assert priority > 0
assert 0 <= idx < len(self._storage)
self._it_sum[idx] = priority ** self._alpha
self._it_min[idx] = priority ** self._alpha
self._max_priority = max(self._max_priority, priority)
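For context, here is a hedged sketch of how `ddqn_agent.py` exercises this buffer, with dummy tensors and shapes chosen purely for illustration (real transitions come from `ALEEnvironment`):

```
import numpy as np
import torch
from replay_buffer import PrioritizedReplayBuffer, LinearSchedule

# Illustrative usage sketch; buffer size, shapes, and TD errors below are placeholders.
buffer = PrioritizedReplayBuffer(size=1024, alpha=0.6)
beta_schedule = LinearSchedule(schedule_timesteps=1000, initial_p=0.4, final_p=1.0)

for t in range(32):
    state = torch.zeros(4, 84, 84)                  # stacked frames
    next_state = torch.zeros(4, 84, 84)
    action = torch.tensor([[t % 8]])                # action index
    reward = torch.tensor([0.0])
    goal = torch.tensor([[20, 30, 40, 50]])         # sub-goal bounding box feature
    buffer.add(state, action, reward, next_state, False, goal)

states, actions, rewards, next_states, dones, goals, weights, idxes = buffer.sample(8, beta=beta_schedule.value(100))
td_errors = np.random.rand(8)                       # placeholder for |target - prediction|
buffer.update_priorities(idxes, td_errors + 1e-6)   # keep priorities strictly positive
```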

343
agent/run_experiment.py Normal file
View file

@@ -0,0 +1,343 @@
import os
import time
import cv2
import torch
import numpy as np
import imageio
import wandb
from datetime import datetime
from ddqn_agent import Agent
from atari_env import ALEEnvironment
from args import HIntArgumentParser
from metacontroller import MetaController
GPU_DEVICE = 7
VERBOSE = 2
DEBUG = False
TEST = False
STOP_TRAINING_THRESHOLD = 0.90
EPISODES = 1000000
MAX_EPISODE_STEPS = 1000
RANDOM_PLAY_STEPS = 20000
# Use subgoals from gaze analysis
GOAL_MODE = 'full_sequence' # 'full_sequence' 'unique'
ALL_SUBGOALS = np.loadtxt('subgoals.txt', dtype=int, delimiter=',')
SUBGOAL_ORDER = [8, 6, 1, 0, 2, 7, 2, 0, 1, 6, 8, 9]
GOALS = len(SUBGOAL_ORDER) if GOAL_MODE == 'full_sequence' else len(np.unique(SUBGOAL_ORDER))
SUBGOALs = SUBGOAL_ORDER if GOAL_MODE == 'full_sequence' else np.array(SUBGOAL_ORDER)[np.sort(np.unique(SUBGOAL_ORDER, return_index=True)[1])] # unsorted unique -> keep subgoal appearance order
TRAINED_GOALS = [False] * GOALS
def main():
# init random seed
RANDOM_SEED = np.random.randint(100)
print(f'Setting random seed to {RANDOM_SEED}')
os.environ['PYTHONHASHSEED']=str(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
x = datetime.now()
TIME_STAMP = x.strftime("%d%b%y-%H%M%S")
MODEL_NAME = 'hInt-RL-full_' + str(RANDOM_SEED) + '_' + TIME_STAMP
os.environ['CUDA_VISIBLE_DEVICES']=str(GPU_DEVICE)
os.environ['CUDA_LAUNCH_BLOCKING'] = str(1.0)
actionMap = [0, 1, 2, 3, 4, 5, 11, 12]
actionExplain = ['no action', 'jump', 'up', 'right', 'left', 'down', 'jump right', 'jump left']
actionVectors = [[0,0], [0.0,-1.0], [0.0,-1.0], [1.0,0.0], [-1.0,0.0], [0.0,1.0], [0.7071067811865475,-0.7071067811865475], [-0.7071067811865475,-0.7071067811865475]]
inv_label_mapping = {}
for i, l in enumerate(SUBGOALs): # np.unique without sorting the values
inv_label_mapping[l] = i
print(inv_label_mapping)
goalExplain = {8: 'start', 6: 'rope', 1: 'lower right ladder', 0: 'danger zone', 2: 'lower left ladder', 7: 'key', 9:'left door'}
subgoal_success_tracker = [[0] for i in range(GOALS)]
subgoal_success_steps = [[] for i in range(GOALS)]
subgoal_trailing_performance = [0.0 for _ in range(GOALS)]
parser = HIntArgumentParser()
parser.set_common_args()
args = parser.parse_args(['--model_name', MODEL_NAME, '--verbose', str(VERBOSE), '--random_seed', str(RANDOM_SEED),
'--num_goals', str(GOALS), '--goal_mode', GOAL_MODE])
if not DEBUG:
wandb.init(project="hInt-RL", config=vars(args), name=MODEL_NAME)
print(*[f'\n[PARAM] {k}: {v}' for (k,v) in vars(args).items()])
print()
# Setup environment
env = ALEEnvironment(args, device=f"cuda:{GPU_DEVICE}")
input_shape = env.getStackedState().shape
print(f'Input shape {input_shape}\n')
device = torch.device(f"cuda:{GPU_DEVICE}" if torch.cuda.is_available() else "cpu")
# Setup metacontroller
metacontroller = MetaController(device=device, args=args, input_shape=input_shape, n_goals=GOALS, hidden_nodes=args.hidden_nodes)
# Setup agents and their DQN models
agent_list = []
for goal in range(sum(TRAINED_GOALS), GOALS):
print(f'Setting up agent for goal {goalExplain.get(SUBGOALs[goal])}')
agent = Agent(device=device, goal=goal, input_shape=input_shape, n_actions=len(actionMap), random_play_steps=20000, args=args)
agent_list.append(agent)
total_rewards = []
wrong_meta_pred = 0
total_steps = 0
create_video = False
for episode in range(EPISODES):
if args.verbose >= 1:
print(f'\n[Episode {episode}] Starting new episode...')
# Initialize environment and get state
env.restart()
state = env.getStackedState()
episode_steps = 0
subgoal_agent_loss = [0.0 for _ in range(GOALS)]
meta_labels = []
wrong_goal = False
goal_idx = 0
true_goal = goal_idx if GOAL_MODE == 'full_sequence' else inv_label_mapping.get(SUBGOAL_ORDER[goal_idx])
img_array = []
if episode % 10000 == 0:
create_video = True
expert_goal = np.zeros((1, GOALS))
expert_goal[0, true_goal] = 1.0
meta_labels.append((state, expert_goal)) # append, but do not collect yet
if TRAINED_GOALS[true_goal]:
goal = true_goal
# Collect sample for metacontroller because goal will be reached
metacontroller.collect(state, expert_goal)
else:
# Metacontroller predict goal
goal = metacontroller.sample(metacontroller.predict(state))
if goal!= true_goal:
wrong_goal = True
if args.verbose >= 1:
print(f"[Episode {episode}] Metacontroller predicted {goal} instead of {true_goal} as goal \U0001F47E ")
wrong_meta_pred += 1
if wrong_meta_pred % 100 == 0:
print(f'[Episode {episode}] 100 wrong meta choices. Resetting metacontroller... ')
metacontroller.reset()
else:
if args.verbose >= 1:
print(f'[Episode {episode}] Metacontroller predicted goal {goal} as new goal...')
all_step_times = []
step_time = time.time()
while not env.isTerminal() and episode_steps < MAX_EPISODE_STEPS and not wrong_goal:
# Unroll episode until agent reaches goal
while not env.trueGoalReached(SUBGOAL_ORDER[goal_idx]) and not env.isTerminal() and episode_steps < MAX_EPISODE_STEPS and not wrong_goal:
goal_position = torch.tensor(np.array([ALL_SUBGOALS[SUBGOAL_ORDER[goal_idx]]]), device=device)
if create_video:
img = env.getScreenOrig()
cv2.putText(img=img, text='goal : ' + goalExplain.get(SUBGOAL_ORDER[goal_idx]), org=(5, 205), fontFace=cv2.FONT_HERSHEY_SIMPLEX,
fontScale=0.3, color=(255, 255, 255),thickness=1)
img_array.append(img)
action = agent_list[goal].select_action(state.unsqueeze(0))
external_reward = env.act(actionMap[action])
if (external_reward != 0):
external_reward = 1.0
all_step_times.append(time.time() - step_time)
step_time = time.time()
episode_steps += 1
# Calculate intrinsic reward [optionally add distance reward]
reward = 0.0
if env.trueGoalReached(SUBGOAL_ORDER[goal_idx]):
reward += 1.0
if env.isTerminal() or episode_steps==MAX_EPISODE_STEPS:
reward -= 1.0
"""
# Simple direction reward
goal_vec = env.get_goal_direction(goal)
action_vec = actionVectors[action]
direction_reward = np.dot(action_vec, goal_vec) / 100
reward += direction_reward
"""
"""
# Distance Reward
# Query agent location every 20 steps (otherwise training is too slow)
if episode_steps % 20 == 0:
env.detect_agent()
reward += env.distanceReward(lastGoal=(true_goal-1), goal=goal)
"""
total_rewards.append(reward)
reward = torch.tensor([reward], device=device)
next_state = env.getStackedState()
# Store transition and update network parameters
agent_list[goal].remember(state, action, reward, next_state, env.isTerminal(), goal_position)
# Move to the next state
state = next_state
# Optimize the policy network
agent_loss = agent_list[goal].optimize_model()
if agent_loss is not None:
subgoal_agent_loss[goal] += agent_loss
# Update goal
if episode_steps >= MAX_EPISODE_STEPS:
subgoal_success_tracker[goal].append(0)
if args.verbose >= 1:
                    print(f'[Episode {episode}] Reached maximum episode steps: {episode_steps}. Terminate episode.')
break
elif env.trueGoalReached(SUBGOAL_ORDER[goal_idx]):
if args.verbose >= 1:
print(f'[Episode {episode}] Goal reached! \U0001F389 after step #{episode_steps}. ')
subgoal_success_tracker[goal].append(1)
subgoal_success_steps[goal].append(episode_steps)
# Predict new goal and continue if its true goal
goal_idx += 1
if goal_idx == len(SUBGOAL_ORDER):
# Finished all options --> start new episode
wrong_goal = True
print(f"[Episode {episode}] Reached all goals! \U0001F38A \U0001F38A \U0001F38A")
break
true_goal = goal_idx if GOAL_MODE == 'full_sequence' else inv_label_mapping.get(SUBGOAL_ORDER[goal_idx])
expert_goal = np.zeros((1, GOALS))
expert_goal[0, true_goal] = 1.0
meta_labels.append((state, expert_goal)) # append, but do not collect yet
# Metacontroller predict goal
goal = metacontroller.sample(metacontroller.predict(state))
if goal!= true_goal:
wrong_goal = True
if args.verbose >= 1:
print(f"[Episode {episode}] Metacontroller predicted {goal} instead of {true_goal} as goal \U0001F47E ")
wrong_meta_pred += 1
if wrong_meta_pred % 100 == 0:
if args.verbose >= 1:
print(f'[Episode {episode}] 100 wrong meta choices. Resetting metacontroller... ')
metacontroller.reset()
break
else:
# Continue with new goal
if args.verbose >= 1:
print(f'[Episode {episode}] Metacontroller predicted goal {goal} as new goal...')
else:
subgoal_success_tracker[goal].append(0)
if args.verbose >= 1:
print(f'[Episode {episode}] Agent killed after {episode_steps} steps \U0001F47E Terminate episode.')
break
if create_video and img_array:
out_path = f'runs/video_run{TIME_STAMP}-E{episode}.gif'
imageio.mimsave(out_path, img_array)
print('Saved gif to ', out_path)
create_video = False
"""END OF EPISODE"""
# At end of episode: aggregate data for metacontroller
item = meta_labels[-1]
metacontroller.collect(item[0], item[1])
# Log agent losses if agent was running
total_steps += episode_steps
avg_step_time = sum(all_step_times) / len(all_step_times) if episode_steps > 0 else 0
#print(f'[Episode {episode}] episode time: {sum(all_step_times) / 60:.2f}min average step time: {avg_step_time*100:.1f}ms')
if not DEBUG:
wandb.log({'mean_reward': np.array(total_rewards[-episode_steps:]).mean(), 'avg_step_time': avg_step_time*100, 'episode': episode})
for g in range(GOALS):
if len(subgoal_success_tracker[g]) < episode:
# goal not evaluated yet because other goals not reached
subgoal_success_tracker[g].append(0)
# agent_loss = subgoal_agent_loss[g] / episode_steps if episode_steps > 0 else 0
# log_steps = episode_steps if (subgoal_success_tracker[g] and subgoal_success_tracker[g][-1] == 1) else -episode_steps
if not DEBUG:
wandb.log({ f'sub-goal {g}': subgoal_success_tracker[g][-1], 'episode': episode}) #f'agent{g}_loss': agent_loss, f'agent{g}_steps': log_steps,
# Train metacontroller
if metacontroller.check_training_clock():
if args.verbose >= 2:
print('###################################')
print('### Training Metacontroller ###')
print('###################################')
meta_loss, meta_acc = metacontroller.train()
if not DEBUG:
wandb.log({'meta_loss': meta_loss, 'meta_acc': meta_acc, 'episode': episode})
if args.verbose >= 2:
print(f'Metacontroller loss: {meta_loss:.2f}\n')
print('###################################')
if len(subgoal_success_tracker[goal]) > 100:
subgoal_trailing_performance[goal] = sum([v for v in subgoal_success_tracker[goal][-100:] if not v is None])/ 100.0
if args.verbose >= 0:
print(f'[Episode {episode}] Subgoal trailing performance for goal {goal} is {subgoal_trailing_performance[goal]:.2f}')
print(f'[Episode {episode}] Subgoal agent {goal} steps done {agent_list[goal].steps_done}')
print(f'[Episode {episode}] Subgoal agent {goal} epsilon value is {agent_list[goal].get_epsilon():.2f}')
if subgoal_trailing_performance[goal] > STOP_TRAINING_THRESHOLD:
if args.verbose >= 1:
print(f'[Episode {episode}] Training for goal #{goal} completed...')
if not agent_list[goal].training_done:
agent_list[goal].finish_training(TIME_STAMP)
TRAINED_GOALS[goal] = True
if goal_idx == len(SUBGOAL_ORDER):
# Last goal reached trailing performance --> stop training
print(f"\nTraining completed \U0001F38A \U0001F38A \U0001F38A \n")
torch.save(metacontroller.model.state_dict(), f"runs/run{TIME_STAMP}-metacontroller.pt")
break
# After training is completed
for goal in range(GOALS):
if subgoal_success_steps[goal]:
print(f'\nAgent{goal} performance: {np.array(subgoal_success_tracker[goal]).mean():.2%}')
print(f'Reached goal {goal} {sum(subgoal_success_tracker[goal])}x with {np.array(subgoal_success_steps[goal]).mean():.0f} average steps')
if __name__ == "__main__":
main()

348
agent/single_agent_experiment.py Normal file
View file

@@ -0,0 +1,348 @@
import os
import time
import cv2
import torch
import numpy as np
import imageio
import wandb
from datetime import datetime
from ddqn_agent import Agent
from atari_env import ALEEnvironment
from args import HIntArgumentParser
from metacontroller import MetaController
GPU_DEVICE = 4
VERBOSE = 2
DEBUG = False
TEST = False
STOP_TRAINING_THRESHOLD = 0.90
EPISODES = 1000000
MAX_EPISODE_STEPS = 1000
RANDOM_PLAY_STEPS = 20000
# Use subgoals from gaze analysis
GOAL_FEATURE = True
GOAL_MODE = 'single_agent'
ALL_SUBGOALS = np.loadtxt('subgoals.txt', dtype=int, delimiter=',')
SUBGOAL_ORDER = [8, 6, 1, 0, 2, 7, 2, 0, 1, 6, 8, 9]
GOALS = len(SUBGOAL_ORDER) #if GOAL_MODE == 'full_sequence' else len(np.unique(SUBGOAL_ORDER))
SUBGOALs = SUBGOAL_ORDER #if GOAL_MODE == 'full_sequence' else np.array(SUBGOAL_ORDER)[np.sort(np.unique(SUBGOAL_ORDER, return_index=True)[1])] # unsorted unique -> keep subgoal appearance order
TRAINED_GOALS = [False] * GOALS
def main():
# init random seed
RANDOM_SEED = 9 #np.random.randint(100)
print(f'Setting random seed to {RANDOM_SEED}')
os.environ['PYTHONHASHSEED']=str(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
x = datetime.now()
TIME_STAMP = x.strftime("%d%b%y-%H%M%S")
MODEL_NAME = 'hInt-RL-single-dist_' + str(RANDOM_SEED) + '_' + TIME_STAMP
os.environ['CUDA_VISIBLE_DEVICES']=str(GPU_DEVICE)
os.environ['CUDA_LAUNCH_BLOCKING'] = str(1.0)
actionMap = [0, 1, 2, 3, 4, 5, 11, 12]
actionExplain = ['no action', 'jump', 'up', 'right', 'left', 'down', 'jump right', 'jump left']
actionVectors = [[0,0], [0.0,-1.0], [0.0,-1.0], [1.0,0.0], [-1.0,0.0], [0.0,1.0], [0.7071067811865475,-0.7071067811865475], [-0.7071067811865475,-0.7071067811865475]]
goalExplain = {8: 'start', 6: 'rope', 1: 'lower right ladder', 0: 'danger zone', 2: 'lower left ladder', 7: 'key', 9:'left door'}
inv_label_mapping = {}
for i, l in enumerate(SUBGOALs): # np.unique without sorting the values
inv_label_mapping[l] = i
print(inv_label_mapping)
subgoal_success_tracker = [[0] for i in range(GOALS)]
subgoal_success_steps = [[] for i in range(GOALS)]
subgoal_trailing_performance = [0.0 for _ in range(GOALS)]
parser = HIntArgumentParser()
parser.set_common_args()
args = parser.parse_args(['--model_name', MODEL_NAME, '--verbose', str(VERBOSE), '--random_seed', str(RANDOM_SEED),
'--num_goals', str(GOALS), '--goal_mode', GOAL_MODE,
'--epsilon_start', str(1.0), '--epsilon_end', str(0.1), '--epsilon_steps', str(2000000)])
if not DEBUG:
wandb.init(project="hInt-RL", config=vars(args), name=MODEL_NAME)
print(*[f'\n[PARAM] {k}: {v}' for (k,v) in vars(args).items()])
print()
env = ALEEnvironment(args, device=f"cuda:{GPU_DEVICE}")
input_shape = env.getStackedState().shape
#input_shape = np.array([input_shape[2], input_shape[1], input_shape[0]])
print(f'Input shape {input_shape}\n')
device = torch.device(f"cuda:{GPU_DEVICE}" if torch.cuda.is_available() else "cpu")
# Setup metacontroller
metacontroller = MetaController(device=device, args=args, input_shape=input_shape, n_goals=GOALS, hidden_nodes=args.hidden_nodes)
# Setup single agent
agent = Agent(device=device, goal=None, input_shape=input_shape, n_actions=len(actionMap), random_play_steps=RANDOM_PLAY_STEPS, args=args, goal_feature=GOAL_FEATURE)
total_rewards = []
wrong_meta_pred = 0
total_steps = 0
create_video = False
for episode in range(EPISODES):
if args.verbose >= 1:
print(f'\n[Episode {episode}] Starting new episode...')
# Initialize environment and get state
env.restart()
state = env.getStackedState()
episode_steps = 0
subgoal_agent_loss = 0
meta_labels = []
wrong_goal = False
goal_idx = 0
true_goal = goal_idx # if GOAL_MODE == 'full_sequence' else inv_label_mapping.get(SUBGOAL_ORDER[goal_idx])
img_array = []
if episode % 10000 == 0:
create_video = True
expert_goal = np.zeros((1, GOALS))
expert_goal[0, true_goal] = 1.0
meta_labels.append((state, expert_goal)) # append, but do not collect yet
if TRAINED_GOALS[true_goal]:
goal = true_goal
# Collect sample for metacontroller because goal will be reached
metacontroller.collect(state, expert_goal)
else:
# Metacontroller predict goal
goal = metacontroller.sample(metacontroller.predict(state))
if goal!= true_goal:
wrong_goal = True
if args.verbose >= 1:
print(f"[Episode {episode}] Metacontroller predicted {goal} instead of {true_goal} as goal \U0001F47E ")
wrong_meta_pred += 1
if wrong_meta_pred % 100 == 0:
print(f'[Episode {episode}] 100 wrong meta choices. Resetting metacontroller... ')
metacontroller.reset()
else:
if args.verbose >= 1:
print(f'[Episode {episode}] Metacontroller predicted goal {goal} as new goal...')
all_step_times = []
step_time = time.time()
while not env.isTerminal() and episode_steps < MAX_EPISODE_STEPS and not wrong_goal:
# Unroll episode until agent reaches goal
while not env.trueGoalReached(SUBGOAL_ORDER[goal_idx]) and not env.isTerminal() and episode_steps < MAX_EPISODE_STEPS and not wrong_goal:
goal_position = torch.tensor(np.array([ALL_SUBGOALS[SUBGOAL_ORDER[goal_idx]]]), device=device)
if create_video:
img = env.getScreenOrig()
cv2.putText(img=img, text='goal: ' + goalExplain.get(SUBGOAL_ORDER[goal_idx]), org=(5, 205), fontFace=cv2.FONT_HERSHEY_SIMPLEX,
fontScale=0.3, color=(255, 255, 255),thickness=1)
img_array.append(img)
if GOAL_FEATURE:
action = agent.select_action(state.unsqueeze(0), goal_position)
else:
action = agent.select_action(state.unsqueeze(0))
external_reward = env.act(actionMap[action])
if (external_reward != 0):
external_reward = 1.0
all_step_times.append(time.time() - step_time)
step_time = time.time()
episode_steps += 1
# Calculate intrinsic reward [optionally add distance reward]
reward = 0.0 # - 1.0 / 1000 # small negative reward for each step
if env.trueGoalReached(SUBGOAL_ORDER[goal_idx]):
reward += 1.0
if env.isTerminal() or episode_steps==MAX_EPISODE_STEPS:
reward -= 1.0
"""
# Simple direction reward
goal_vec = env.get_goal_direction(SUBGOAL_ORDER[goal_idx])
action_vec = actionVectors[action]
direction_reward = 0.001 * np.dot(action_vec, goal_vec)
reward += direction_reward
"""
# Distance Reward
last_goal = SUBGOAL_ORDER[goal_idx-1] if goal_idx != 0 else None
reward += env.distanceReward(lastGoal=last_goal, goal=SUBGOAL_ORDER[goal_idx])
total_rewards.append(reward)
reward = torch.tensor([reward], device=device)
next_state = env.getStackedState()
# Store transition and update network parameters
agent.remember(state, action, reward, next_state, env.isTerminal(), goal_position)
# Move to the next state
state = next_state
# Optimize the policy network
agent_loss = agent.optimize_model()
if agent_loss is not None:
subgoal_agent_loss += agent_loss
# Update goal
if episode_steps >= MAX_EPISODE_STEPS:
subgoal_success_tracker[goal_idx].append(0)
if args.verbose >= 1:
                    print(f'[Episode {episode}] Reached maximum episode steps: {episode_steps}. Terminate episode.')
break
elif env.trueGoalReached(SUBGOAL_ORDER[goal_idx]):
if args.verbose >= 1:
print(f'[Episode {episode}] Goal reached! \U0001F389 after step #{episode_steps}. ')
subgoal_success_tracker[goal_idx].append(1)
subgoal_success_steps[goal_idx].append(episode_steps)
# Predict new goal and continue if its true goal
goal_idx += 1
"""
if subgoal_trailing_performance[goal_idx] == 0:
# goal has not been seen before --> reset TanH schedule to enforce new exploration
print(episode, episode_steps, goal, 'resetting TanH schedule...')
agent.reset_epsilon_schedule()
"""
if goal_idx == len(SUBGOAL_ORDER):
# Finished all options --> start new episode
wrong_goal = True
print(f"[Episode {episode}] Reached all goals! \U0001F38A \U0001F38A \U0001F38A")
break
true_goal = goal_idx # if GOAL_MODE == 'full_sequence' else inv_label_mapping.get(SUBGOAL_ORDER[goal_idx])
expert_goal = np.zeros((1, GOALS))
expert_goal[0, true_goal] = 1.0
meta_labels.append((state, expert_goal)) # append, but do not collect yet
# Metacontroller predict goal
goal = metacontroller.sample(metacontroller.predict(state))
if goal!= true_goal:
wrong_goal = True
if args.verbose >= 1:
print(f"[Episode {episode}] Metacontroller predicted {goal} instead of {true_goal} as goal \U0001F47E ")
wrong_meta_pred += 1
if wrong_meta_pred % 100 == 0:
if args.verbose >= 1:
print(f'[Episode {episode}] 100 wrong meta choices. Resetting metacontroller... ')
metacontroller.reset()
break
else:
# Continue with new goal
if args.verbose >= 1:
print(f'[Episode {episode}] Metacontroller predicted goal {goal} as new goal...')
else:
subgoal_success_tracker[goal_idx].append(0)
if args.verbose >= 1:
print(f'[Episode {episode}] Agent killed after {episode_steps} steps \U0001F47E Terminate episode.')
break
if create_video and img_array:
out_path = f'runs/video_run{TIME_STAMP}-E{episode}.gif'
imageio.mimsave(out_path, img_array)
print('Saved gif to ', out_path)
create_video = False
"""END OF EPISODE"""
# At end of episode: aggregate data for metacontroller
item = meta_labels[-1]
metacontroller.collect(item[0], item[1])
total_steps += episode_steps
# if agent was running
if all_step_times:
# Print average step time
avg_step_time = sum(all_step_times) / len(all_step_times)
# Log agent losses
agent_loss = subgoal_agent_loss / episode_steps
else:
avg_step_time = 0
agent_loss = 0
if not DEBUG:
wandb.log({'mean_reward': np.array(total_rewards[-episode_steps:]).mean(), 'avg_step_time': avg_step_time*100, 'agent_loss': agent_loss, 'epsilon': agent.get_epsilon(), 'episode': episode})
for g in range(GOALS):
if len(subgoal_success_tracker[g]) < episode:
# goal not evaluated yet because other goals not reached
subgoal_success_tracker[g].append(None)
if not DEBUG:
wandb.log({ f'sub-goal {g}': subgoal_success_tracker[g][-1], 'episode': episode})
# Train metacontroller
if metacontroller.check_training_clock():
if args.verbose >= 2:
print('###################################')
print('### Training Metacontroller ###')
print('###################################')
meta_loss, meta_acc = metacontroller.train()
if not DEBUG:
wandb.log({'meta_loss': meta_loss, 'meta_acc': meta_acc, 'episode': episode})
if args.verbose >= 2:
print(f'Metacontroller loss: {meta_loss:.2f}\n')
print('###################################')
        if len(subgoal_success_tracker[goal]) > 100:
            subgoal_trailing_performance[goal] = sum([v for v in subgoal_success_tracker[goal][-100:] if v is not None]) / 100.0
            if args.verbose >= 0:
                print(f'[Episode {episode}] Subgoal trailing performance for goal {goal} is {subgoal_trailing_performance[goal]:.2f}')
                print(f'[Episode {episode}] Agent steps done {agent.steps_done}')
                print(f'[Episode {episode}] Agent epsilon value is {agent.get_epsilon():.2f}')
            if subgoal_trailing_performance[goal] > STOP_TRAINING_THRESHOLD:
                if args.verbose >= 1:
                    print(f'[Episode {episode}] Training for goal #{goal} completed...')
if true_goal == GOALS:
# Last goal reached trailing performance --> stop training
print(f"\nTraining completed \U0001F38A \U0001F38A \U0001F38A \n")
agent.finish_training(TIME_STAMP)
torch.save(metacontroller.model.state_dict(), f"run{TIME_STAMP}-metacontroller.pt")
break
# After training is completed
for goal in range(GOALS):
if subgoal_success_steps[goal]:
print(f'\nAgent{goal} performance: {np.array(subgoal_success_tracker[goal]).mean():.2%}')
print(f'Reached goal {goal} {sum(subgoal_success_tracker[goal])}x with {np.array(subgoal_success_steps[goal]).mean():.0f} average steps')
if __name__ == "__main__":
main()