From e6f9ace2e3fa21892c6e76aac2f988e6c998fde9 Mon Sep 17 00:00:00 2001 From: penzkofer Date: Wed, 12 Mar 2025 18:46:09 +0100 Subject: [PATCH] add Int-HRL agent scripts --- README.md | 10 +- agent/atari_env.py | 229 +++++++++++++++++ agent/ddqn_agent.py | 308 +++++++++++++++++++++++ agent/metacontroller.py | 158 ++++++++++++ agent/replay_buffer.py | 408 +++++++++++++++++++++++++++++++ agent/run_experiment.py | 343 ++++++++++++++++++++++++++ agent/single_agent_experiment.py | 348 ++++++++++++++++++++++++++ 7 files changed, 1803 insertions(+), 1 deletion(-) create mode 100644 agent/atari_env.py create mode 100644 agent/ddqn_agent.py create mode 100644 agent/metacontroller.py create mode 100644 agent/replay_buffer.py create mode 100644 agent/run_experiment.py create mode 100644 agent/single_agent_experiment.py diff --git a/README.md b/README.md index 4d509fc..84bf95c 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Int-HRL -This is the official repository for [Int-HRL: Towards Intention-based Hierarchical Reinforcement Learning](https://perceptualui.org/publications/penzkofer23_ala/)
+This is the official repository for [Int-HRL: Towards Intention-based Hierarchical Reinforcement Learning](https://collaborative-ai.org/publications/penzkofer24_ncaa/)
Int-HRL uses eye gaze from human demonstration data on the Atari game Montezuma's Revenge to extract the human player's intentions and converts them to sub-goals for Hierarchical Reinforcement Learning (HRL). For further details, take a look at the corresponding paper. @@ -17,8 +17,16 @@ To pre-process the Atari-HEAD data run [Preprocess_AtariHEAD.ipynb](Preprocess_A 3. [Alignment with Trajectory](TrajectoryMatching.ipynb): run expert trajectory to get order of subgoals ## Intention-based Hierarchical RL Agent +The Int-HRL agent is based on the hierarchically guided Imitation Learning method (hg-DAgger/Q), adapting code from [https://github.com/hoangminhle/hierarchical_IL_RL](https://github.com/hoangminhle/hierarchical_IL_RL).
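+
+For orientation, the sketch below condenses how the pieces added in this patch interact during one training episode, closely following [agent/run_experiment.py](agent/run_experiment.py). Wrong-goal handling, logging, video export and the optional reward shaping are omitted, and the helper `run_episode` with its arguments is only for illustration; all other names come from the scripts below:
+
+```python
+# Condensed sketch of one training episode (see run_experiment.py for the full loop).
+import numpy as np
+import torch
+
+from atari_env import SUBGOAL_ORDER
+
+ALL_SUBGOALS = np.loadtxt('subgoals.txt', dtype=int, delimiter=',')  # gaze-derived sub-goal boxes
+actionMap = [0, 1, 2, 3, 4, 5, 11, 12]                               # reduced Montezuma action set
+
+def run_episode(env, metacontroller, agent_list, device, max_steps=1000):
+    env.restart()
+    state = env.getStackedState()
+    goal_idx, steps = 0, 0
+    while not env.isTerminal() and steps < max_steps and goal_idx < len(SUBGOAL_ORDER):
+        # High level: the metacontroller proposes the next sub-goal for the current state
+        goal = metacontroller.sample(metacontroller.predict(state))
+        goal_box = torch.tensor(np.array([ALL_SUBGOALS[SUBGOAL_ORDER[goal_idx]]]), device=device)
+        # Low level: the sub-goal's DDQN agent acts until the sub-goal, a terminal state or the step limit is reached
+        while not env.trueGoalReached(SUBGOAL_ORDER[goal_idx]) and not env.isTerminal() and steps < max_steps:
+            action = agent_list[goal].select_action(state.unsqueeze(0))
+            env.act(actionMap[action])
+            reward = torch.tensor([1.0 if env.trueGoalReached(SUBGOAL_ORDER[goal_idx]) else 0.0], device=device)
+            next_state = env.getStackedState()
+            agent_list[goal].remember(state, action, reward, next_state, env.isTerminal(), goal_box)
+            agent_list[goal].optimize_model()
+            state, steps = next_state, steps + 1
+        if env.trueGoalReached(SUBGOAL_ORDER[goal_idx]):
+            goal_idx += 1  # advance along the gaze-derived sub-goal sequence
+```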
+ +Due to the novel sub-goal extraction pipeline, our agent does not require experts during training and is more than three times as sample efficient as hg-DAgger/Q.
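+
+The experiment scripts described in the next paragraph are plain Python entry points; a typical invocation looks like this (assuming the dependencies imported by the scripts, e.g. `torch`, `gym`, `atariari`, `wandb` and `imageio`, are installed, and that `subgoals.txt` and the `args.py` argument parser used by the scripts are available):
+
+```
+# full agent: one low-level DDQN agent per sub-goal in the sequence
+python agent/run_experiment.py
+
+# single-agent variant: one low-level agent shared across all sub-goals
+python agent/single_agent_experiment.py
+```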
+ +To run the full agent with 12 separate low-level agents for sub-goal execution, run [agent/run_experiment.py](agent/run_experiment.py), for single agents (one low-level agent for all sub-goals) run [agent/single_agent_experiment.py](agent/single_agent_experiment.py). + +## Extension to Venture and Hero under construction + ## Citation Please consider citing these paper if you use Int-HRL or parts of this repository in your research: ``` diff --git a/agent/atari_env.py b/agent/atari_env.py new file mode 100644 index 0000000..0bf66eb --- /dev/null +++ b/agent/atari_env.py @@ -0,0 +1,229 @@ +import os +import cv2 +import gym +import numpy as np +import torchvision.transforms as T + +from atariari.benchmark.wrapper import AtariARIWrapper + +SUBGOAL_ORDER = [8, 6, 1, 0, 2, 7, 2, 0, 1, 6, 8, 9] + +class ALEEnvironment(): + + def __init__(self, args, device='cpu'): + if 'cuda' in device: + os.environ['CUDA_VISIBLE_DEVICES']= device.split(':')[-1] + + self.ale = AtariARIWrapper(gym.make('MontezumaRevenge-v4', + frameskip=args.frame_skip, + render_mode='rgb_array', # return the rgb key in step metadata with the current environment RGB frame. + repeat_action_probability=0.0)) + + self.histLen = args.sequence_length + print(self.histLen) + + self.screen_width = args.screen_width + self.screen_height = args.screen_height + + self.actions = np.arange(self.ale.action_space.n) + + print('Action space: ', self.actions) + + self.mode = "train" + self.life_lost = False + + # Set environment with random seed! + self.ale.reset(seed=args.random_seed) + + # Get init screen + self.init_screen = self.getScreen() + print(f'Size of screen is {self.init_screen.shape}') + + # Perform NOOP action to init self.done and self.info + _, reward, self.done, self.info = self.ale.step(0) + + # Use subgoals from gaze analysis + self.goalSet = np.loadtxt('subgoals.txt', dtype=int, delimiter=',') + self.goalCenterLoc = [] + + for goal in self.goalSet: + goalCenter = [float(goal[0]+goal[2])/2, float(goal[1]+goal[3])/2] + self.goalCenterLoc.append(goalCenter) + + self.agentOriginLoc = [42, 33] + self.agentLastX = 42 + self.agentLastY = 33 + self.reachedGoal = [0, 0, 0,0] + self.histState = self.initializeHistState() + print('Histogram of states: ', self.histState.shape) + + self.to_tensor = T.ToTensor() + + + def initializeHistState(self): + if self.histLen >= 2: + histState = np.concatenate((self.getState(), self.getState()), axis = 2) + for _ in range(self.histLen - 2): + histState = np.concatenate((histState, self.getState()), axis = 2) + else: + histState = self.getState() + + return histState + + def get_input_shape(self): + return (self.histLen, self.screen_width, self.screen_height) + + def numActions(self): + return len(self.actions) + + def resetGoalReach(self): + self.reachedGoal = [0, 0, 0,0] + + def act(self, action): + 'Perform action and handle results' + lives = self.info['lives'] + + _, reward, self.done, self.info = self.ale.step(self.actions[action]) + + # agent location from AtariARI Wrapper + self.agentLastX, self.agentLastY = self.info['labels']['player_x'], 320 - self.info['labels']['player_y'] + + self.life_lost = (not lives == self.info['lives']) + currState = self.getState() + self.histState = np.concatenate((self.histState[:, :, 1:], currState), axis = 2) + return reward + + def restart(self): + 'Restart environment: set goals and agent location' + self.ale.reset() + self.life_lost = False + self.reachedGoal = [0, 0, 0, 0] + for i in range(19): + self.act(0) #wait for initialization + self.histState = 
self.initializeHistState() + self.agentLastX = self.agentOriginLoc[0] + self.agentLastY = self.agentOriginLoc[1] + + def beginNextLife(self): + 'Begin next life without restarting environment: set goals and agent location' + self.life_lost = False + self.reachedGoal = [0, 0, 0,0] + for i in range(19): + self.act(0) #wait for initialization + self.histState = self.initializeHistState() + self.agentLastX = self.agentOriginLoc[0] + self.agentLastY = self.agentOriginLoc[1] + + def detect_agent_with_tracker(self): + img = self.getScreenOrig() + preds = self.panamajoe_detector.predict(img)[0] + if len(preds['boxes']) == 1: + box = preds['boxes'][0].detach().cpu().numpy() + x, y, width, height = box[0], box[1], box[2]-box[0], box[3]-box[1] + self.agentLastX, self.agentLastY = x + width / 2, y + height / 2 + return (self.agentLastX, self.agentLastY) + + def goal_reached(self, goal): + if self.panamajoe_detector: + [[x1, y1, x2, y2]] = self.goalSet[goal] + x_mean, y_mean = self.detect_agent_with_tracker() + return x_mean > x1 and x_mean < x2 and y_mean > y1 and y_mean < y2 + else: + return self.goalReached(goal) + + def get_goal_direction(self, goal_idx): + if ((goal_idx-1) == -1): + lastGoalCenter = self.agentOriginLoc + else: + lastGoalCenter = self.goalCenterLoc[SUBGOAL_ORDER[goal_idx-1]] + + goal_direction = np.array(self.goalSet[SUBGOAL_ORDER[goal_idx]]).mean(axis=0) - lastGoalCenter + norm = np.linalg.norm(goal_direction) + + if norm == 0: + return goal_direction + + return goal_direction / norm + + def distanceReward(self, lastGoal, goal): + 'Calculate distance between agent and next sub-goal' + if lastGoal is None: + lastGoalCenter = self.agentOriginLoc + else: + lastGoalCenter = self.goalCenterLoc[lastGoal] + goalCenter = self.goalCenterLoc[goal] + agentX, agentY = self.agentLastX, self.agentLastY + dis = np.sqrt((goalCenter[0] - agentX)*(goalCenter[0] - agentX) + (goalCenter[1]-agentY)*(goalCenter[1]-agentY)) + disLast = np.sqrt((lastGoalCenter[0] - agentX)*(lastGoalCenter[0] - agentX) + (lastGoalCenter[1]-agentY)*(lastGoalCenter[1]-agentY)) + disGoals = np.sqrt((goalCenter[0]-lastGoalCenter[0])*(goalCenter[0]-lastGoalCenter[0]) + (goalCenter[1]-lastGoalCenter[1])*(goalCenter[1]-lastGoalCenter[1])) + return 0.001 * (disLast - dis) / disGoals + + def getScreen(self): + 'Get processed screen: grayscale and resized' + screen = self.ale.render(mode='rgb_array') + screen = cv2.cvtColor(screen, cv2.COLOR_RGB2GRAY) + resized = cv2.resize(screen, (self.screen_width, self.screen_height), interpolation=cv2.INTER_AREA) + return resized + + def getScreenOrig(self): + 'Get original RGB screen' + return self.ale.render(mode='rgb_array') + + def getScreenRGB(self): + 'Get RGB screen for finding agent location' + screen = self.ale.render(mode='rgb_array') + resized = cv2.resize(screen, (self.screen_width, self.screen_height), interpolation=cv2.INTER_AREA) + return resized + + def getState(self): + 'Get current state, i.e. current screen. 
Process screen and add color channel for input of network' + screen = self.ale.render(mode='rgb_array') + screen = cv2.cvtColor(screen, cv2.COLOR_RGB2GRAY) + resized = cv2.resize(screen, (self.screen_width, self.screen_height), interpolation=cv2.INTER_AREA) + return np.expand_dims(resized, axis=-1) + + def getStackedState(self): + return self.to_tensor(self.histState) + + def isTerminal(self): + if self.mode == 'train': + return self.done or self.life_lost + return self.done + + def isGameOver(self): + return self.done + + def isLifeLost(self): + return self.life_lost + + def reset(self): + self.ale.reset() + self.life_lost = False + + def goalReached(self, goal): + goalPosition = self.goalSet[goal] + goalScreen = self.init_screen + stateScreen = self.getState() + count = 0 + for y in range (goalPosition[0][0], goalPosition[1][0]): + for x in range (goalPosition[0][1], goalPosition[1][1]): + + if goalScreen[x][y] != stateScreen[x][y]: + count = count + 1 + + # 30 is total number of pixels of agent + if float(count) / 30 > 0.3: + self.reachedGoal[goal] = 1 + return True + + return False + + def trueGoalReached(self, goal): + 'With the AtariARI Wrapper enabled agent locations are updated every step and yield true location' + goalPosition = self.goalSet[goal] + return self.agentLastX > goalPosition[0] and self.agentLastX < goalPosition[2] and self.agentLastY > goalPosition[1] and self.agentLastY < goalPosition[3] + + def goalNotReachedBefore(self, goal): + if (self.reachedGoal[goal] == 1): + return False + return True \ No newline at end of file diff --git a/agent/ddqn_agent.py b/agent/ddqn_agent.py new file mode 100644 index 0000000..684eb20 --- /dev/null +++ b/agent/ddqn_agent.py @@ -0,0 +1,308 @@ +from collections import namedtuple +import random + +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +from torchinfo import summary +import numpy as np +from tqdm import tqdm + +from replay_buffer import PrioritizedReplayBuffer, LinearSchedule, TanHSchedule + +prioritized_replay_alpha = 0.6 +max_timesteps=1000000 +prioritized_replay_beta0=0.4 +prioritized_replay_eps=1e-6 +prioritized_replay_beta_iters = max_timesteps*0.5 +beta_schedule = LinearSchedule(prioritized_replay_beta_iters, + initial_p=prioritized_replay_beta0, + final_p=1.0) + +Experience = namedtuple('Experience', field_names=['state', 'action', 'reward', 'next_state', 'done']) + +# Deep Q Network +class DQN(nn.Module): + def __init__(self, device, input_shape, n_actions, hidden_nodes=512, goal_feature=False): + super(DQN, self).__init__() + self.device = device + + self.conv = nn.Sequential( + nn.Conv2d(input_shape[0], 16, kernel_size=5, stride=2), + nn.BatchNorm2d(16), + nn.ReLU(), + nn.Conv2d(16, 32, kernel_size=5, stride=2), + nn.BatchNorm2d(32), + nn.ReLU(), + nn.Conv2d(32, 32, kernel_size=5, stride=2), + nn.BatchNorm2d(32), + nn.ReLU() + ) + + conv_out_size = self._get_conv_out(input_shape) + if goal_feature: + conv_out_size += 4 + + # Standard DQN + self.fc = nn.Sequential( + nn.Linear(conv_out_size, hidden_nodes), + nn.LeakyReLU(), + nn.Linear(hidden_nodes, n_actions), + # nn.Softmax(dim=1) # necessary for Boltzman Sampling + ) + + def _get_conv_out(self, shape): + o = self.conv(torch.zeros(1, *shape)) + return int(np.prod(o.size())) + + def forward(self, x, goal=None): + x = x.to(self.device) + conv_out = self.conv(x).view(x.size()[0], -1) + if not goal is None: + conv_out = torch.cat((conv_out, goal.squeeze(1)), dim=1) + return self.fc(conv_out) + + +# Dueling Deep Q Network 
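+# Splits the estimate into a state-value head V(s) and an advantage head A(s,a), recombined in
+# forward() as Q(s,a) = V(s) + A(s,a) - mean_a A(s,a), so V(s) is updated on every transition
+# regardless of which action was taken.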
+class DuelingDQN(nn.Module): + def __init__(self, device, input_shape, n_actions, hidden_nodes=512, goal_feature=False): + super(DuelingDQN, self).__init__() + self.device = device + + self.conv = nn.Sequential( + nn.Conv2d(input_shape[0], 16, kernel_size=5, stride=2), + nn.BatchNorm2d(16), + nn.ReLU(), + nn.Conv2d(16, 32, kernel_size=5, stride=2), + nn.BatchNorm2d(32), + nn.ReLU(), + nn.Conv2d(32, 32, kernel_size=5, stride=2), + nn.BatchNorm2d(32), + nn.ReLU() + ) + + conv_out_size = self._get_conv_out(input_shape) + if goal_feature: + conv_out_size += 4 + + # Dueling DQN --> two heads + # advantage of actions A(s,a) + self.fc_adv = nn.Sequential( + nn.Linear(conv_out_size, hidden_nodes), + nn.LeakyReLU(), + nn.Linear(hidden_nodes, n_actions) + ) + # value of states V(s) + self.fc_val = nn.Sequential( + nn.Linear(conv_out_size, hidden_nodes), + nn.LeakyReLU(), + nn.Linear(hidden_nodes, 1) + ) + + + def _get_conv_out(self, shape): + o = self.conv(torch.zeros(1, *shape)) + return int(np.prod(o.size())) + + + def forward(self, x, goal=None): + x = x.to(self.device) + fx = x.float() / 256 + conv_out = self.conv(fx).view(fx.size()[0], -1) + if not goal is None: + conv_out = torch.cat((conv_out, goal.squeeze(1)), dim=1) + val = self.fc_val(conv_out) + adv = self.fc_adv(conv_out) + # Q(s,a) = V(s) + A(s,a) - 1/N sum(A(s,k)) --> expected value of advantage = 0 + return val + (adv - adv.mean(dim=1, keepdim=True)) + + +class MaskedHuberLoss(nn.Module): + def __init__(self, delta=1.0): + super(MaskedHuberLoss, self).__init__() + self.criterion = nn.HuberLoss(reduction='none', delta=delta) + + def forward(self, inputs, targets, mask): + loss = self.criterion(inputs, targets) + loss *= mask + return loss.sum(dim=-1) + + +# Agent that performs, remembers and learns actions +class Agent(): + def __init__(self, device, goal, input_shape, n_actions, random_play_steps, args, dueling=True, goal_feature=False): + self.device = device + self.n_actions = n_actions + self.goal_feature = goal_feature + + if dueling: + self.policy_net = DuelingDQN(device, input_shape, n_actions, args.hidden_nodes, goal_feature).to(self.device) + self.target_net = DuelingDQN(device, input_shape, n_actions, args.hidden_nodes, goal_feature).to(self.device) + else: + self.policy_net = DQN(device, input_shape, n_actions, args.hidden_nodes, goal_feature).to(self.device) + self.target_net = DQN(device, input_shape, n_actions, args.hidden_nodes, goal_feature).to(self.device) + #self.optimizer = optim.RMSprop(self.policy_net.parameters(), lr=args.learning_rate) + self.optimizer = optim.Adam(self.policy_net.parameters(), lr=args.learning_rate) + self.memory = PrioritizedReplayBuffer(args.experience_buffer, alpha=0.6) + self.exploration = LinearSchedule(schedule_timesteps=args.epsilon_steps, initial_p=args.epsilon_start, final_p=args.epsilon_end) + #self.exploration = TanHSchedule(schedule_timesteps=(args.epsilon_steps//10), initial_p=args.epsilon_start, final_p=args.epsilon_end) + self.loss_fn = MaskedHuberLoss(delta=1.0) + self.steps_done = -1 + self.batch_size = args.batch_size + self.gamma = args.gamma + self.goal = goal + self.enable_double_dqn = True + self.hard_update = 1000 + self.training_done = False + self.verbose = args.verbose + self.random_play_steps = random_play_steps + self.args = args + + def remember(self, state, action, reward, next_state, done, goal): + self.memory.add(state, action, reward, next_state, done, goal) + + def load_model(self, path): + #summary(self.policy_net, (4, 84, 84)) + 
self.policy_net.load_state_dict(torch.load(path)) + self.policy_net.eval() + self.random_play_steps = 100 + self.training_done = True + self.exploration = LinearSchedule(schedule_timesteps=1000, initial_p=0.1, final_p=0.005) + print('Loaded DDQN policy network weights from ', path) + + def get_epsilon(self): + return self.exploration.value(self.steps_done - self.random_play_steps) + + def reset_epsilon_schedule(self): + if isinstance(self.exploration, TanHSchedule): + self.exploration.restart() + else: + print('Exploration schedule is linear and can not be reset!') + + def select_action(self, state, goal=None): + #if not self.training_done: + self.steps_done += 1 + # Select an action according to an epsilon greedy approach + if self.steps_done < self.random_play_steps or random.random() < self.get_epsilon(): + return torch.tensor([[random.randrange(0, self.n_actions)]], device=self.device, dtype=torch.long) + else: + with torch.no_grad(): + return self.policy_net(state, goal).max(1)[1].view(1, 1) + + # Select an action via Boltzmann Sampling + if self.steps_done > self.random_play_steps: + return self.sample(state, temperature=self.get_epsilon()) + else: + return torch.tensor([[random.randrange(0, self.n_actions)]], device=self.device, dtype=torch.long) + + + """ + else: + with torch.no_grad(): + return self.policy_net(state).max(1)[1].view(1, 1) + """ + + + def sample(self, state, temperature=0.1, goal=None): + 'Predict action probability vector and use Boltzmann Sampling to get new action' + with torch.no_grad(): + prob_vec = self.policy_net(state, goal) + + sample_prob = torch.exp(prob_vec / temperature) + sample_prob /= sample_prob.sum() + + try: + sample_idx = sample_prob.multinomial(num_samples=1) # random choice with probabilities + except Exception as e: + # print(temperature, prob_vec, sample_prob) + # RuntimeError: probability tensor contains either `inf`, `nan` or element < 0 + import logging + logging.basicConfig(filename=f'runs/{self.args.model_name}sample_prob.log', filemode='w', format='%(name)s - %(levelname)s - %(message)s') + logging.warning(f'[Agent {self.goal} - step {self.steps_done}] {e}: {sample_prob}') + sample_idx = torch.tensor([[random.randrange(0, self.n_actions)]], device=self.device, dtype=torch.long) + + return sample_idx.view(1,1) + + def finish_training(self, model_name): + torch.save(self.policy_net.state_dict(), f"runs/run{model_name}-agent{self.goal}.pt") + self.training_done = True + # Free memory by deleting target net (no longer needed) + del self.target_net + + def optimize_model(self): + + if len(self.memory) < self.batch_size or self.steps_done < self.random_play_steps or self.training_done: + return + + # Sample from prioritized replay memory + sample = self.memory.sample(self.batch_size, beta=beta_schedule.value(self.steps_done)) + states, actions, rewards, next_states, dones, goals, importance_v, idx_vector = sample + + states_v = states.to(self.device) + next_states_v = next_states.to(self.device) + actions_v = actions.to(self.device) + rewards_v = rewards.to(self.device) + done_mask = dones.to(self.device) + goals_v = goals.to(self.device) + if not self.goal_feature: + goals_v = None + + # predicted Q(s, a) + q_values = self.policy_net(states_v, goals_v) + state_action_values = q_values.gather(1, actions_v.squeeze(-1)) + + if self.enable_double_dqn: + # calculate max_a Q'(s_t+1, argmax_a Q(s_t+1, a)) + next_state_actions = self.policy_net(next_states_v, goals_v).max(1)[1] # [1] = argmax + next_state_values = self.target_net(next_states_v, 
goals_v).gather(1, next_state_actions.unsqueeze(-1)).squeeze(-1) + else: + # calculate max_a Q'(s_t+1, a) + next_state_values = self.target_net(next_states_v, goals_v).max(1)[0] # [0] = max + + # Set discounted reward to zero for all states that were terminal + next_state_values[done_mask] = 0.0 + + # Compute r_t + gamma * max_a Q'(s_t+1, a) + expected_state_action_values = next_state_values.detach() * self.gamma + rewards_v + + targets = np.zeros((self.batch_size, self.n_actions)) + dummy_targets = np.zeros((self.batch_size,)) + masks = np.zeros((self.batch_size, self.n_actions)) + + for idx, (target, mask, reward, action) in enumerate(zip(targets, masks, expected_state_action_values, actions_v)): + target[action] = reward # update action with estimated accumulated reward + dummy_targets[idx] = reward + mask[action] = 1. # enable loss for this specific action + + # Calculate TD [Temporal Difference] error + action_batch = actions_v.squeeze().cpu().detach().numpy() + td_errors = targets[range(self.batch_size), action_batch] - state_action_values.squeeze().cpu().detach().numpy() + + # Compute masked loss between predicted state action values and targets + targets = torch.tensor(targets, dtype=torch.float32).to(self.device) + masks = torch.tensor(masks, dtype=torch.float32).to(self.device) + loss = self.loss_fn(q_values, targets , masks) + + # Update models and memory only when training is not done + if not self.training_done: + # Update priorities in replay buffer + new_priorities = np.abs(td_errors) + prioritized_replay_eps + self.memory.update_priorities(idx_vector, new_priorities) + + self.optimizer.zero_grad() + loss.mean().backward() # https://discuss.pytorch.org/t/loss-backward-raises-error-grad-can-be-implicitly-created-only-for-scalar-outputs/12152/2 + + for param in self.policy_net.parameters(): + param.grad.data.clamp_(-1, 1) + + self.optimizer.step() + + # Update the target network + if self.steps_done % self.hard_update == 0: + if self.verbose >= 1: + print('\nUpdating target network...\n') + self.target_net.load_state_dict(self.policy_net.state_dict()) + + + return loss diff --git a/agent/metacontroller.py b/agent/metacontroller.py new file mode 100644 index 0000000..4229b8c --- /dev/null +++ b/agent/metacontroller.py @@ -0,0 +1,158 @@ +from collections import deque + +import torch +import torch.nn as nn +import torch.optim as optim +from torchmetrics import Accuracy +import numpy as np + + +BATCH_SIZE = 32 +TRAIN_HIST_SIZE = 10000 +P_DROPOUT = 0.5 + +class MetaNN(nn.Module): + + def __init__(self, device, input_shape=(4,84,84), n_goals=4, hidden_nodes=512): + super(MetaNN, self).__init__() + + self.device = device + + ### Setup model architecture ### + self.conv = nn.Sequential( + nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4), + nn.ReLU(), + nn.Dropout(p=P_DROPOUT), + + nn.Conv2d(32, 64, kernel_size=4, stride=2), + nn.ReLU(), + nn.Dropout(p=P_DROPOUT), + + nn.Conv2d(64, 64, kernel_size=3, stride=1), + nn.ReLU(), + nn.Dropout(p=P_DROPOUT), + ) + + conv_out_size = self._get_conv_out(input_shape) + self.fc = nn.Sequential( + nn.Linear(conv_out_size, hidden_nodes), #TODO: copy initialization from meta_net_il.py + nn.ReLU(), + nn.Dropout(p=P_DROPOUT), + nn.Linear(hidden_nodes, n_goals), + nn.Softmax(dim=1) + ) + + def _get_conv_out(self, shape): + o = self.conv(torch.zeros(1, *shape)) + return int(np.prod(o.size())) + + def forward(self, x): + x = x.to(self.device) + conv_out = self.conv(x).view(x.size()[0], -1) + return self.fc(conv_out) + + + +class MetaController(): + 
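+    # Supervised sub-goal selector: collect() stores (state, one-hot sub-goal) pairs during the
+    # episodes, train() periodically fits the CNN on that replay history with an MSE loss, and
+    # predict() returns the softmax over sub-goals from which sample() draws with a temperature.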
def __init__(self, device, args, input_shape=(4, 84, 84), n_goals=4, hidden_nodes=512) -> None: + self.device = device + self.model = MetaNN(device, input_shape, n_goals, hidden_nodes).to(self.device) + print('Saving init state of MetaNN') + torch.save(self.model.state_dict(), f"runs/meta_init-{args.model_name}.pt") + + self.optimizer = optim.RMSprop(self.model.parameters(), lr=0.00025, alpha=0.95, eps=1e-08, weight_decay=0.0) + self.loss_fn = nn.MSELoss() + self.accuracy_fn = Accuracy(num_classes=n_goals, average='macro').to(device) + + self.replay_hist = deque([None], maxlen=TRAIN_HIST_SIZE) + + self.ind = 0 + self.count = 0 + self.meta_steps = 0 + + self.input_shape = input_shape + self.n_goals = n_goals + self.verbose = args.verbose + self.name = args.model_name + + def reset(self) -> None: + 'Load initial state dictionary from file and reset optimizer' + self.model.load_state_dict(torch.load(f"runs/meta_init-{self.name}.pt")) + self.optimizer = optim.RMSprop(self.model.parameters(), lr=0.00025, alpha=0.95, eps=1e-08, weight_decay=0.0) + + def check_training_clock(self) -> bool: + 'Only train every meta controller steps, i.e. new samples in replay buffer.' + if BATCH_SIZE: + return (self.meta_steps % BATCH_SIZE == 0) + else: + return (self.meta_steps % 20 == 0) + + def collect(self, processed, expert_a) -> None: + 'Collect sample consisting of state (4, 84, 84) and one-hot vector of goal' + if processed is not None: + self.replay_hist.appendleft(tuple([processed, expert_a])) + self.meta_steps += 1 + + def train(self): + # if not reached TRAIN_HIST_SIZE yet, then get the number of samples + num_samples = min(self.meta_steps, TRAIN_HIST_SIZE) + + inputs = torch.stack([self.replay_hist[i][0] for i in range(num_samples)], 0).to(self.device) + labels = torch.stack([torch.tensor(self.replay_hist[i][1], dtype=torch.float32) for i in range(num_samples)], 0).to(self.device) + + if self.verbose >= 2.0: + print(f'\nMetacontroller collected {self.meta_steps} samples') + if len(labels) == TRAIN_HIST_SIZE: + print(f'Reached TRAIN_HIST_SIZE = {TRAIN_HIST_SIZE}') + print('Dataset Distribution:') + for goal in labels.unique(): + goal = goal.cpu().detach().numpy() + print(f'\nNumber of samples for goal {goal}: {sum(labels).squeeze()[goal]}' ) + print(f'--> {sum(labels).squeeze()[goal] / len(labels):.2%}') + print() + + if BATCH_SIZE and num_samples >= BATCH_SIZE: + # train one epoch --> noisy convergence more likely to find broader minimum + accumulated_loss = [] + accumulated_acc = [] + + for index in range(0, len(labels) // BATCH_SIZE): + b_inputs, b_labels = inputs[index * BATCH_SIZE: (index + 1) * BATCH_SIZE], labels[index * BATCH_SIZE: (index + 1) * BATCH_SIZE] + + # zero the parameter gradients + self.optimizer.zero_grad() + + outputs = self.model(b_inputs) + loss = self.loss_fn(outputs, b_labels.squeeze(1)) + + loss.backward() + self.optimizer.step() + + accumulated_loss.append(loss) + accumulated_acc.append(self.accuracy_fn(outputs, b_labels.squeeze(1).type(torch.uint8))) + loss = torch.stack(accumulated_loss).mean() + accuracy = torch.stack(accumulated_acc).mean() + else: + # run once over all samples --> smooth convergence to a deep local minimum + self.optimizer.zero_grad() + outputs = self.model(inputs) + loss = self.loss_fn(outputs, labels.squeeze(1)) + accuracy = self.accuracy_fn(outputs, labels.squeeze(1).type(torch.uint8)) + loss.backward() + self.optimizer.step() + + self.count = 0 # reset the count clock + return loss, accuracy + + def predict(self, state, batch_size=1) -> np.ndarray: 
+ 'Predict probability distribution of goals with metacontroller model and return as ndarray for sampling' + return self.model.forward(state.unsqueeze(0)).squeeze(0).detach().cpu().numpy() + + def sample(self, prob_vec, temperature=0.1) -> int: + prob_pred = np.log(prob_vec) / temperature + dist = np.exp(prob_pred)/np.sum(np.exp(prob_pred)) + choices = range(len(prob_pred)) + return np.random.choice(choices, p=dist) + + + diff --git a/agent/replay_buffer.py b/agent/replay_buffer.py new file mode 100644 index 0000000..6112fea --- /dev/null +++ b/agent/replay_buffer.py @@ -0,0 +1,408 @@ +# This script is taken from openAI's baselines implementation and adapted for hInt-RL +# =================================================================================================================== +import random +import operator +from collections import namedtuple + +import torch +import numpy as np + +Experience = namedtuple('Experience', field_names=['state', 'action', 'reward', 'next_state', 'done', 'goal']) + +class LinearSchedule(object): + def __init__(self, schedule_timesteps, final_p, initial_p=1.0): + """Linear interpolation between initial_p and final_p over + schedule_timesteps. After this many timesteps pass final_p is + returned. + + Parameters + ---------- + schedule_timesteps: int + Number of timesteps for which to linearly anneal initial_p + to final_p + initial_p: float + initial output value + final_p: float + final output value + """ + self.schedule_timesteps = schedule_timesteps + self.final_p = final_p + self.initial_p = initial_p + + def value(self, t): + """See Schedule.value""" + fraction = min(float(t) / self.schedule_timesteps, 1.0) + return self.initial_p + fraction * (self.final_p - self.initial_p) + + +class TanHSchedule(object): + def __init__(self, schedule_timesteps, final_p, initial_p=1.0): + """This is a tanh annealing schedule with restarts, i.e. hyperbolic tangent decay. + + Parameters + ---------- + schedule_timesteps: int + Number of overall timesteps for which to anneal initial_p to final_p + initial_p: float + max value + final_p: float + min value + """ + self.schedule_timesteps = schedule_timesteps + self.final_p = final_p + self.initial_p = initial_p + self.tt = 0 + + def value(self, t): + if t > 0: + # otherwise random play steps are active + self.tt += 1. + tanh = np.tanh(self.tt/(self.schedule_timesteps) - 4.) + return (self.initial_p + self.final_p) / 2. - (self.initial_p - self.final_p) / 2. * tanh + + def restart(self): + print('Resetting tanh cycle...') + self.tt = 0 + + +class SegmentTree(object): + def __init__(self, capacity, operation, neutral_element): + """Build a Segment Tree data structure. + + https://en.wikipedia.org/wiki/Segment_tree + + Can be used as regular array, but with two + important differences: + + a) setting item's value is slightly slower. + It is O(lg capacity) instead of O(1). + b) user has access to an efficient `reduce` + operation which reduces `operation` over + a contiguous subsequence of items in the + array. + + Paramters + --------- + capacity: int + Total size of the array - must be a power of two. + operation: lambda obj, obj -> obj + and operation for combining elements (eg. sum, max) + must for a mathematical group together with the set of + possible values for array elements. + neutral_element: obj + neutral element for the operation above. eg. float('-inf') + for max and 0 for sum. + """ + assert capacity > 0 and capacity & (capacity - 1) == 0, "capacity must be positive and a power of 2." 
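+        # Heap-style array layout: index 0 is unused, internal nodes occupy [1, capacity) and the
+        # leaves occupy [capacity, 2 * capacity); node i combines children 2*i and 2*i+1 via `operation`.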
+ self._capacity = capacity + self._value = [neutral_element for _ in range(2 * capacity)] + self._operation = operation + + def _reduce_helper(self, start, end, node, node_start, node_end): + if start == node_start and end == node_end: + return self._value[node] + mid = (node_start + node_end) // 2 + if end <= mid: + return self._reduce_helper(start, end, 2 * node, node_start, mid) + else: + if mid + 1 <= start: + return self._reduce_helper(start, end, 2 * node + 1, mid + 1, node_end) + else: + return self._operation( + self._reduce_helper(start, mid, 2 * node, node_start, mid), + self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1, node_end) + ) + + def reduce(self, start=0, end=None): + """Returns result of applying `self.operation` + to a contiguous subsequence of the array. + + self.operation(arr[start], operation(arr[start+1], operation(... arr[end]))) + + Parameters + ---------- + start: int + beginning of the subsequence + end: int + end of the subsequences + + Returns + ------- + reduced: obj + result of reducing self.operation over the specified range of array elements. + """ + if end is None: + end = self._capacity + if end < 0: + end += self._capacity + end -= 1 + return self._reduce_helper(start, end, 1, 0, self._capacity - 1) + + def __setitem__(self, idx, val): + # index of the leaf + idx += self._capacity + self._value[idx] = val + idx //= 2 + while idx >= 1: + self._value[idx] = self._operation( + self._value[2 * idx], + self._value[2 * idx + 1] + ) + idx //= 2 + + def __getitem__(self, idx): + assert 0 <= idx < self._capacity + return self._value[self._capacity + idx] + + +class SumSegmentTree(SegmentTree): + def __init__(self, capacity): + super(SumSegmentTree, self).__init__( + capacity=capacity, + operation=operator.add, + neutral_element=0.0 + ) + + def sum(self, start=0, end=None): + """Returns arr[start] + ... + arr[end]""" + return super(SumSegmentTree, self).reduce(start, end) + + def find_prefixsum_idx(self, prefixsum): + """Find the highest index `i` in the array such that + sum(arr[0] + arr[1] + ... + arr[i - i]) <= prefixsum + + if array values are probabilities, this function + allows to sample indexes according to the discrete + probability efficiently. + + Parameters + ---------- + perfixsum: float + upperbound on the sum of array prefix + + Returns + ------- + idx: int + highest index satisfying the prefixsum constraint + """ + assert 0 <= prefixsum <= self.sum() + 1e-5 + idx = 1 + while idx < self._capacity: # while non-leaf + if self._value[2 * idx] > prefixsum: + idx = 2 * idx + else: + prefixsum -= self._value[2 * idx] + idx = 2 * idx + 1 + return idx - self._capacity + + +class MinSegmentTree(SegmentTree): + def __init__(self, capacity): + super(MinSegmentTree, self).__init__( + capacity=capacity, + operation=min, + neutral_element=float('inf') + ) + + def min(self, start=0, end=None): + """Returns min(arr[start], ..., arr[end])""" + + return super(MinSegmentTree, self).reduce(start, end) + + +class ReplayBuffer(object): + def __init__(self, size): + """Create Replay buffer. + + Parameters + ---------- + size: int + Max number of transitions to store in the buffer. When the buffer + overflows the old memories are dropped. 
+ """ + self._storage = [] + self._maxsize = size + self._next_idx = 0 + + def __len__(self): + return len(self._storage) + + def add(self, obs_t, action, reward, obs_tp1, done, goal): + data = (obs_t, action, reward, obs_tp1, done, goal) + + if self._next_idx >= len(self._storage): + self._storage.append(data) + else: + self._storage[self._next_idx] = data + self._next_idx = (self._next_idx + 1) % self._maxsize + + def _encode_sample_torch(self, idxes): + sample = [self._storage[i] for i in idxes] + batch = Experience(*zip(*sample)) + return torch.stack(batch.state), torch.stack(batch.action), torch.stack(batch.reward).squeeze(), torch.stack(batch.next_state), \ + torch.stack([torch.tensor(batch.done)]).squeeze(), torch.stack(batch.goal) + + """ + # old implementation from + def _encode_sample(self, idxes): + obses_t, actions, rewards, obses_tp1, dones, goals = [], [], [], [], [], [] + for i in idxes: + data = self._storage[i] + obs_t, action, reward, obs_tp1, done, goal = data + obses_t.append(np.array(obs_t, copy=False)) + actions.append(np.array(action, copy=False)) + rewards.append(reward) + obses_tp1.append(np.array(obs_tp1, copy=False)) + dones.append(done) + goals.append(np.array(action, copy=False)) + return np.array(obses_t), np.array(actions), np.array(rewards, dtype=np.float32), np.array(obses_tp1), np.array(dones, dtype=np.uint8), np.array(goals, dtype=np.uint8) + """ + + def sample(self, batch_size): + """Sample a batch of experiences. + + Parameters + ---------- + batch_size: int + How many transitions to sample. + + Returns + ------- + obs_batch: np.array + batch of observations + act_batch: np.array + batch of actions executed given obs_batch + rew_batch: np.array + rewards received as results of executing act_batch + next_obs_batch: np.array + next set of observations seen after executing act_batch + done_mask: np.array + done_mask[i] = 1 if executing act_batch[i] resulted in + the end of an episode and 0 otherwise. + """ + idxes = [random.randint(0, len(self._storage) - 1) for _ in range(batch_size)] + return self._encode_sample_torch(idxes) + + +class PrioritizedReplayBuffer(ReplayBuffer): + def __init__(self, size, alpha): + """Create Prioritized Replay buffer. + + Parameters + ---------- + size: int + Max number of transitions to store in the buffer. When the buffer + overflows the old memories are dropped. + alpha: float + how much prioritization is used + (0 - no prioritization, 1 - full prioritization) + + See Also + -------- + ReplayBuffer.__init__ + """ + super(PrioritizedReplayBuffer, self).__init__(size) + assert alpha > 0 + self._alpha = alpha + + it_capacity = 1 + while it_capacity < size: + it_capacity *= 2 + + self._it_sum = SumSegmentTree(it_capacity) + self._it_min = MinSegmentTree(it_capacity) + self._max_priority = 1.0 + + def add(self, *args, **kwargs): + """See ReplayBuffer.store_effect""" + idx = self._next_idx + super(PrioritizedReplayBuffer,self).add(*args, **kwargs) + self._it_sum[idx] = self._max_priority ** self._alpha + self._it_min[idx] = self._max_priority ** self._alpha + + def _sample_proportional(self, batch_size): + res = [] + for _ in range(batch_size): + # TODO(szymon): should we ensure no repeats? + mass = random.random() * self._it_sum.sum(0, len(self._storage) - 1) + idx = self._it_sum.find_prefixsum_idx(mass) + res.append(idx) + return res + + def sample(self, batch_size, beta): + """Sample a batch of experiences. + + compared to ReplayBuffer.sample + it also returns importance weights and idxes + of sampled experiences. 
+ + + Parameters + ---------- + batch_size: int + How many transitions to sample. + beta: float + To what degree to use importance weights + (0 - no corrections, 1 - full correction) + + Returns + ------- + obs_batch: np.array + batch of observations + act_batch: np.array + batch of actions executed given obs_batch + rew_batch: np.array + rewards received as results of executing act_batch + next_obs_batch: np.array + next set of observations seen after executing act_batch + done_mask: np.array + done_mask[i] = 1 if executing act_batch[i] resulted in + the end of an episode and 0 otherwise. + weights: np.array + Array of shape (batch_size,) and dtype np.float32 + denoting importance weight of each sampled transition + idxes: np.array + Array of shape (batch_size,) and dtype np.int32 + idexes in buffer of sampled experiences + """ + assert beta > 0 + + idxes = self._sample_proportional(batch_size) + + weights = [] + p_min = self._it_min.min() / self._it_sum.sum() + max_weight = (p_min * len(self._storage)) ** (-beta) + + for idx in idxes: + p_sample = self._it_sum[idx] / self._it_sum.sum() + weight = (p_sample * len(self._storage)) ** (-beta) + weights.append(weight / max_weight) + weights = np.array(weights) + encoded_sample = self._encode_sample_torch(idxes) + return tuple(list(encoded_sample) + [weights, idxes]) + + def update_priorities(self, idxes, priorities): + """Update priorities of sampled transitions. + + sets priority of transition at index idxes[i] in buffer + to priorities[i]. + + Parameters + ---------- + idxes: [int] + List of idxes of sampled transitions + priorities: [float] + List of updated priorities corresponding to + transitions at the sampled idxes denoted by + variable `idxes`. + """ + assert len(idxes) == len(priorities) + for idx, priority in zip(idxes, priorities): + #print priority + #time.sleep(0.5) + assert priority > 0 + assert 0 <= idx < len(self._storage) + self._it_sum[idx] = priority ** self._alpha + self._it_min[idx] = priority ** self._alpha + + self._max_priority = max(self._max_priority, priority) diff --git a/agent/run_experiment.py b/agent/run_experiment.py new file mode 100644 index 0000000..4a85024 --- /dev/null +++ b/agent/run_experiment.py @@ -0,0 +1,343 @@ +import os +import time + +import cv2 +import torch +import numpy as np +import imageio +import wandb +from datetime import datetime + +from ddqn_agent import Agent +from atari_env import ALEEnvironment +from args import HIntArgumentParser +from metacontroller import MetaController + + +GPU_DEVICE = 7 +VERBOSE = 2 +DEBUG = False +TEST = False + +STOP_TRAINING_THRESHOLD = 0.90 +EPISODES = 1000000 +MAX_EPISODE_STEPS = 1000 +RANDOM_PLAY_STEPS = 20000 + + +# Use subgoals from gaze analysis +GOAL_MODE = 'full_sequence' # 'full_sequence' 'unique' +ALL_SUBGOALS = np.loadtxt('subgoals.txt', dtype=int, delimiter=',') +SUBGOAL_ORDER = [8, 6, 1, 0, 2, 7, 2, 0, 1, 6, 8, 9] +GOALS = len(SUBGOAL_ORDER) if GOAL_MODE == 'full_sequence' else len(np.unique(SUBGOAL_ORDER)) +SUBGOALs = SUBGOAL_ORDER if GOAL_MODE == 'full_sequence' else np.array(SUBGOAL_ORDER)[np.sort(np.unique(SUBGOAL_ORDER, return_index=True)[1])] # unsorted unique -> keep subgoal appearance order +TRAINED_GOALS = [False] * GOALS + +def main(): + + # init random seed + RANDOM_SEED = np.random.randint(100) + print(f'Setting random seed to {RANDOM_SEED}') + os.environ['PYTHONHASHSEED']=str(RANDOM_SEED) + np.random.seed(RANDOM_SEED) + torch.manual_seed(RANDOM_SEED) + + x = datetime.now() + + TIME_STAMP = x.strftime("%d%b%y-%H%M%S") + MODEL_NAME = 
'hInt-RL-full_' + str(RANDOM_SEED) + '_' + TIME_STAMP + + os.environ['CUDA_VISIBLE_DEVICES']=str(GPU_DEVICE) + os.environ['CUDA_LAUNCH_BLOCKING'] = str(1.0) + + actionMap = [0, 1, 2, 3, 4, 5, 11, 12] + actionExplain = ['no action', 'jump', 'up', 'right', 'left', 'down', 'jump right', 'jump left'] + actionVectors = [[0,0], [0.0,-1.0], [0.0,-1.0], [1.0,0.0], [-1.0,0.0], [0.0,1.0], [0.7071067811865475,-0.7071067811865475], [-0.7071067811865475,-0.7071067811865475]] + + inv_label_mapping = {} + for i, l in enumerate(SUBGOALs): # np.unique without sorting the values + inv_label_mapping[l] = i + print(inv_label_mapping) + + goalExplain = {8: 'start', 6: 'rope', 1: 'lower right ladder', 0: 'danger zone', 2: 'lower left ladder', 7: 'key', 9:'left door'} + + subgoal_success_tracker = [[0] for i in range(GOALS)] + subgoal_success_steps = [[] for i in range(GOALS)] + subgoal_trailing_performance = [0.0 for _ in range(GOALS)] + + parser = HIntArgumentParser() + parser.set_common_args() + + args = parser.parse_args(['--model_name', MODEL_NAME, '--verbose', str(VERBOSE), '--random_seed', str(RANDOM_SEED), + '--num_goals', str(GOALS), '--goal_mode', GOAL_MODE]) + + if not DEBUG: + wandb.init(project="hInt-RL", config=vars(args), name=MODEL_NAME) + + print(*[f'\n[PARAM] {k}: {v}' for (k,v) in vars(args).items()]) + print() + + # Setup environment + env = ALEEnvironment(args, device=f"cuda:{GPU_DEVICE}") + input_shape = env.getStackedState().shape + print(f'Input shape {input_shape}\n') + + device = torch.device(f"cuda:{GPU_DEVICE}" if torch.cuda.is_available() else "cpu") + + # Setup metacontroller + metacontroller = MetaController(device=device, args=args, input_shape=input_shape, n_goals=GOALS, hidden_nodes=args.hidden_nodes) + + # Setup agents and their DQN models + agent_list = [] + + for goal in range(sum(TRAINED_GOALS), GOALS): + print(f'Setting up agent for goal {goalExplain.get(SUBGOALs[goal])}') + + agent = Agent(device=device, goal=goal, input_shape=input_shape, n_actions=len(actionMap), random_play_steps=20000, args=args) + agent_list.append(agent) + + + total_rewards = [] + wrong_meta_pred = 0 + total_steps = 0 + create_video = False + + for episode in range(EPISODES): + if args.verbose >= 1: + print(f'\n[Episode {episode}] Starting new episode...') + + # Initialize environment and get state + env.restart() + state = env.getStackedState() + + episode_steps = 0 + subgoal_agent_loss = [0.0 for _ in range(GOALS)] + meta_labels = [] + wrong_goal = False + goal_idx = 0 + true_goal = goal_idx if GOAL_MODE == 'full_sequence' else inv_label_mapping.get(SUBGOAL_ORDER[goal_idx]) + + img_array = [] + if episode % 10000 == 0: + create_video = True + + expert_goal = np.zeros((1, GOALS)) + expert_goal[0, true_goal] = 1.0 + meta_labels.append((state, expert_goal)) # append, but do not collect yet + + if TRAINED_GOALS[true_goal]: + goal = true_goal + # Collect sample for metacontroller because goal will be reached + metacontroller.collect(state, expert_goal) + else: + # Metacontroller predict goal + goal = metacontroller.sample(metacontroller.predict(state)) + + if goal!= true_goal: + wrong_goal = True + if args.verbose >= 1: + print(f"[Episode {episode}] Metacontroller predicted {goal} instead of {true_goal} as goal \U0001F47E ") + wrong_meta_pred += 1 + if wrong_meta_pred % 100 == 0: + print(f'[Episode {episode}] 100 wrong meta choices. Resetting metacontroller... 
') + metacontroller.reset() + else: + if args.verbose >= 1: + print(f'[Episode {episode}] Metacontroller predicted goal {goal} as new goal...') + + all_step_times = [] + step_time = time.time() + + while not env.isTerminal() and episode_steps < MAX_EPISODE_STEPS and not wrong_goal: + + # Unroll episode until agent reaches goal + while not env.trueGoalReached(SUBGOAL_ORDER[goal_idx]) and not env.isTerminal() and episode_steps < MAX_EPISODE_STEPS and not wrong_goal: + + goal_position = torch.tensor(np.array([ALL_SUBGOALS[SUBGOAL_ORDER[goal_idx]]]), device=device) + + if create_video: + img = env.getScreenOrig() + + cv2.putText(img=img, text='goal : ' + goalExplain.get(SUBGOAL_ORDER[goal_idx]), org=(5, 205), fontFace=cv2.FONT_HERSHEY_SIMPLEX, + fontScale=0.3, color=(255, 255, 255),thickness=1) + + img_array.append(img) + + action = agent_list[goal].select_action(state.unsqueeze(0)) + + external_reward = env.act(actionMap[action]) + if (external_reward != 0): + external_reward = 1.0 + + all_step_times.append(time.time() - step_time) + step_time = time.time() + + episode_steps += 1 + + # Calculate intrinsic reward [optionally add distance reward] + reward = 0.0 + if env.trueGoalReached(SUBGOAL_ORDER[goal_idx]): + reward += 1.0 + if env.isTerminal() or episode_steps==MAX_EPISODE_STEPS: + reward -= 1.0 + + """ + # Simple direction reward + goal_vec = env.get_goal_direction(goal) + action_vec = actionVectors[action] + direction_reward = np.dot(action_vec, goal_vec) / 100 + + reward += direction_reward + """ + + """ + # Distance Reward + # Query agent location every 20 steps (otherwise training is too slow) + if episode_steps % 20 == 0: + env.detect_agent() + + reward += env.distanceReward(lastGoal=(true_goal-1), goal=goal) + """ + + total_rewards.append(reward) + reward = torch.tensor([reward], device=device) + + next_state = env.getStackedState() + + # Store transition and update network parameters + agent_list[goal].remember(state, action, reward, next_state, env.isTerminal(), goal_position) + + # Move to the next state + state = next_state + + # Optimize the policy network + agent_loss = agent_list[goal].optimize_model() + if agent_loss is not None: + subgoal_agent_loss[goal] += agent_loss + + # Update goal + if episode_steps >= MAX_EPISODE_STEPS: + subgoal_success_tracker[goal].append(0) + if args.verbose >= 1: + print(f'[Episode {episode}] Reached maximum epsiode steps: {episode_steps}. Terminate episode.') + break + + elif env.trueGoalReached(SUBGOAL_ORDER[goal_idx]): + if args.verbose >= 1: + print(f'[Episode {episode}] Goal reached! \U0001F389 after step #{episode_steps}. ') + + subgoal_success_tracker[goal].append(1) + subgoal_success_steps[goal].append(episode_steps) + + # Predict new goal and continue if its true goal + goal_idx += 1 + + if goal_idx == len(SUBGOAL_ORDER): + # Finished all options --> start new episode + wrong_goal = True + print(f"[Episode {episode}] Reached all goals! 
\U0001F38A \U0001F38A \U0001F38A") + break + + true_goal = goal_idx if GOAL_MODE == 'full_sequence' else inv_label_mapping.get(SUBGOAL_ORDER[goal_idx]) + expert_goal = np.zeros((1, GOALS)) + expert_goal[0, true_goal] = 1.0 + meta_labels.append((state, expert_goal)) # append, but do not collect yet + + # Metacontroller predict goal + goal = metacontroller.sample(metacontroller.predict(state)) + + if goal!= true_goal: + wrong_goal = True + if args.verbose >= 1: + print(f"[Episode {episode}] Metacontroller predicted {goal} instead of {true_goal} as goal \U0001F47E ") + wrong_meta_pred += 1 + if wrong_meta_pred % 100 == 0: + if args.verbose >= 1: + print(f'[Episode {episode}] 100 wrong meta choices. Resetting metacontroller... ') + metacontroller.reset() + break + else: + # Continue with new goal + if args.verbose >= 1: + print(f'[Episode {episode}] Metacontroller predicted goal {goal} as new goal...') + + else: + subgoal_success_tracker[goal].append(0) + if args.verbose >= 1: + print(f'[Episode {episode}] Agent killed after {episode_steps} steps \U0001F47E Terminate episode.') + break + + if create_video and img_array: + out_path = f'runs/video_run{TIME_STAMP}-E{episode}.gif' + imageio.mimsave(out_path, img_array) + print('Saved gif to ', out_path) + create_video = False + + """END OF EPISODE""" + # At end of episode: aggregate data for metacontroller + item = meta_labels[-1] + metacontroller.collect(item[0], item[1]) + + # Log agent losses if agent was running + total_steps += episode_steps + avg_step_time = sum(all_step_times) / len(all_step_times) if episode_steps > 0 else 0 + #print(f'[Episode {episode}] episode time: {sum(all_step_times) / 60:.2f}min average step time: {avg_step_time*100:.1f}ms') + + if not DEBUG: + wandb.log({'mean_reward': np.array(total_rewards[-episode_steps:]).mean(), 'avg_step_time': avg_step_time*100, 'episode': episode}) + + for g in range(GOALS): + if len(subgoal_success_tracker[g]) < episode: + # goal not evaluated yet because other goals not reached + subgoal_success_tracker[g].append(0) + # agent_loss = subgoal_agent_loss[g] / episode_steps if episode_steps > 0 else 0 + # log_steps = episode_steps if (subgoal_success_tracker[g] and subgoal_success_tracker[g][-1] == 1) else -episode_steps + if not DEBUG: + wandb.log({ f'sub-goal {g}': subgoal_success_tracker[g][-1], 'episode': episode}) #f'agent{g}_loss': agent_loss, f'agent{g}_steps': log_steps, + + + # Train metacontroller + if metacontroller.check_training_clock(): + if args.verbose >= 2: + print('###################################') + print('### Training Metacontroller ###') + print('###################################') + + meta_loss, meta_acc = metacontroller.train() + if not DEBUG: + wandb.log({'meta_loss': meta_loss, 'meta_acc': meta_acc, 'episode': episode}) + + if args.verbose >= 2: + print(f'Metacontroller loss: {meta_loss:.2f}\n') + print('###################################') + + if len(subgoal_success_tracker[goal]) > 100: + subgoal_trailing_performance[goal] = sum([v for v in subgoal_success_tracker[goal][-100:] if not v is None])/ 100.0 + + if args.verbose >= 0: + print(f'[Episode {episode}] Subgoal trailing performance for goal {goal} is {subgoal_trailing_performance[goal]:.2f}') + print(f'[Episode {episode}] Subgoal agent {goal} steps done {agent_list[goal].steps_done}') + print(f'[Episode {episode}] Subgoal agent {goal} epsilon value is {agent_list[goal].get_epsilon():.2f}') + + if subgoal_trailing_performance[goal] > STOP_TRAINING_THRESHOLD: + if args.verbose >= 1: + print(f'[Episode 
{episode}] Training for goal #{goal} completed...') + if not agent_list[goal].training_done: + agent_list[goal].finish_training(TIME_STAMP) + TRAINED_GOALS[goal] = True + + if goal_idx == len(SUBGOAL_ORDER): + # Last goal reached trailing performance --> stop training + print(f"\nTraining completed \U0001F38A \U0001F38A \U0001F38A \n") + torch.save(metacontroller.model.state_dict(), f"runs/run{TIME_STAMP}-metacontroller.pt") + break + + # After training is completed + for goal in range(GOALS): + if subgoal_success_steps[goal]: + print(f'\nAgent{goal} performance: {np.array(subgoal_success_tracker[goal]).mean():.2%}') + print(f'Reached goal {goal} {sum(subgoal_success_tracker[goal])}x with {np.array(subgoal_success_steps[goal]).mean():.0f} average steps') + + +if __name__ == "__main__": + main() diff --git a/agent/single_agent_experiment.py b/agent/single_agent_experiment.py new file mode 100644 index 0000000..c3e5661 --- /dev/null +++ b/agent/single_agent_experiment.py @@ -0,0 +1,348 @@ +import os +import time + +import cv2 +import torch +import numpy as np +import imageio +import wandb +from datetime import datetime + +from ddqn_agent import Agent +from atari_env import ALEEnvironment +from args import HIntArgumentParser +from metacontroller import MetaController + + +GPU_DEVICE = 4 +VERBOSE = 2 +DEBUG = False +TEST = False + +STOP_TRAINING_THRESHOLD = 0.90 +EPISODES = 1000000 +MAX_EPISODE_STEPS = 1000 +RANDOM_PLAY_STEPS = 20000 + +# Use subgoals from gaze analysis +GOAL_FEATURE = True +GOAL_MODE = 'single_agent' +ALL_SUBGOALS = np.loadtxt('subgoals.txt', dtype=int, delimiter=',') +SUBGOAL_ORDER = [8, 6, 1, 0, 2, 7, 2, 0, 1, 6, 8, 9] +GOALS = len(SUBGOAL_ORDER) #if GOAL_MODE == 'full_sequence' else len(np.unique(SUBGOAL_ORDER)) +SUBGOALs = SUBGOAL_ORDER #if GOAL_MODE == 'full_sequence' else np.array(SUBGOAL_ORDER)[np.sort(np.unique(SUBGOAL_ORDER, return_index=True)[1])] # unsorted unique -> keep subgoal appearance order +TRAINED_GOALS = [False] * GOALS + +def main(): + + # init random seed + RANDOM_SEED = 9 #np.random.randint(100) + print(f'Setting random seed to {RANDOM_SEED}') + os.environ['PYTHONHASHSEED']=str(RANDOM_SEED) + np.random.seed(RANDOM_SEED) + torch.manual_seed(RANDOM_SEED) + + x = datetime.now() + + TIME_STAMP = x.strftime("%d%b%y-%H%M%S") + MODEL_NAME = 'hInt-RL-single-dist_' + str(RANDOM_SEED) + '_' + TIME_STAMP + + os.environ['CUDA_VISIBLE_DEVICES']=str(GPU_DEVICE) + os.environ['CUDA_LAUNCH_BLOCKING'] = str(1.0) + + actionMap = [0, 1, 2, 3, 4, 5, 11, 12] + actionExplain = ['no action', 'jump', 'up', 'right', 'left', 'down', 'jump right', 'jump left'] + actionVectors = [[0,0], [0.0,-1.0], [0.0,-1.0], [1.0,0.0], [-1.0,0.0], [0.0,1.0], [0.7071067811865475,-0.7071067811865475], [-0.7071067811865475,-0.7071067811865475]] + goalExplain = {8: 'start', 6: 'rope', 1: 'lower right ladder', 0: 'danger zone', 2: 'lower left ladder', 7: 'key', 9:'left door'} + + inv_label_mapping = {} + for i, l in enumerate(SUBGOALs): # np.unique without sorting the values + inv_label_mapping[l] = i + print(inv_label_mapping) + + subgoal_success_tracker = [[0] for i in range(GOALS)] + subgoal_success_steps = [[] for i in range(GOALS)] + subgoal_trailing_performance = [0.0 for _ in range(GOALS)] + + parser = HIntArgumentParser() + parser.set_common_args() + + args = parser.parse_args(['--model_name', MODEL_NAME, '--verbose', str(VERBOSE), '--random_seed', str(RANDOM_SEED), + '--num_goals', str(GOALS), '--goal_mode', GOAL_MODE, + '--epsilon_start', str(1.0), '--epsilon_end', str(0.1), 
'--epsilon_steps', str(2000000)]) + + if not DEBUG: + wandb.init(project="hInt-RL", config=vars(args), name=MODEL_NAME) + + print(*[f'\n[PARAM] {k}: {v}' for (k,v) in vars(args).items()]) + print() + + env = ALEEnvironment(args, device=f"cuda:{GPU_DEVICE}") + + input_shape = env.getStackedState().shape + #input_shape = np.array([input_shape[2], input_shape[1], input_shape[0]]) + print(f'Input shape {input_shape}\n') + + device = torch.device(f"cuda:{GPU_DEVICE}" if torch.cuda.is_available() else "cpu") + + # Setup metacontroller + metacontroller = MetaController(device=device, args=args, input_shape=input_shape, n_goals=GOALS, hidden_nodes=args.hidden_nodes) + + # Setup single agent + agent = Agent(device=device, goal=None, input_shape=input_shape, n_actions=len(actionMap), random_play_steps=RANDOM_PLAY_STEPS, args=args, goal_feature=GOAL_FEATURE) + + total_rewards = [] + wrong_meta_pred = 0 + total_steps = 0 + create_video = False + + for episode in range(EPISODES): + if args.verbose >= 1: + print(f'\n[Episode {episode}] Starting new episode...') + + # Initialize environment and get state + env.restart() + state = env.getStackedState() + + episode_steps = 0 + subgoal_agent_loss = 0 + meta_labels = [] + wrong_goal = False + goal_idx = 0 + true_goal = goal_idx # if GOAL_MODE == 'full_sequence' else inv_label_mapping.get(SUBGOAL_ORDER[goal_idx]) + + img_array = [] + if episode % 10000 == 0: + create_video = True + + expert_goal = np.zeros((1, GOALS)) + expert_goal[0, true_goal] = 1.0 + meta_labels.append((state, expert_goal)) # append, but do not collect yet + + if TRAINED_GOALS[true_goal]: + goal = true_goal + # Collect sample for metacontroller because goal will be reached + metacontroller.collect(state, expert_goal) + else: + # Metacontroller predict goal + goal = metacontroller.sample(metacontroller.predict(state)) + + if goal!= true_goal: + wrong_goal = True + if args.verbose >= 1: + print(f"[Episode {episode}] Metacontroller predicted {goal} instead of {true_goal} as goal \U0001F47E ") + wrong_meta_pred += 1 + if wrong_meta_pred % 100 == 0: + print(f'[Episode {episode}] 100 wrong meta choices. Resetting metacontroller... 
') + metacontroller.reset() + + else: + if args.verbose >= 1: + print(f'[Episode {episode}] Metacontroller predicted goal {goal} as new goal...') + + all_step_times = [] + step_time = time.time() + + while not env.isTerminal() and episode_steps < MAX_EPISODE_STEPS and not wrong_goal: + + # Unroll episode until agent reaches goal + while not env.trueGoalReached(SUBGOAL_ORDER[goal_idx]) and not env.isTerminal() and episode_steps < MAX_EPISODE_STEPS and not wrong_goal: + + goal_position = torch.tensor(np.array([ALL_SUBGOALS[SUBGOAL_ORDER[goal_idx]]]), device=device) + + if create_video: + img = env.getScreenOrig() + + + cv2.putText(img=img, text='goal: ' + goalExplain.get(SUBGOAL_ORDER[goal_idx]), org=(5, 205), fontFace=cv2.FONT_HERSHEY_SIMPLEX, + fontScale=0.3, color=(255, 255, 255),thickness=1) + + img_array.append(img) + + if GOAL_FEATURE: + action = agent.select_action(state.unsqueeze(0), goal_position) + else: + action = agent.select_action(state.unsqueeze(0)) + + external_reward = env.act(actionMap[action]) + if (external_reward != 0): + external_reward = 1.0 + + all_step_times.append(time.time() - step_time) + step_time = time.time() + episode_steps += 1 + + # Calculate intrinsic reward [optionally add distance reward] + reward = 0.0 # - 1.0 / 1000 # small negative reward for each step + if env.trueGoalReached(SUBGOAL_ORDER[goal_idx]): + reward += 1.0 + if env.isTerminal() or episode_steps==MAX_EPISODE_STEPS: + reward -= 1.0 + + """ + # Simple direction reward + goal_vec = env.get_goal_direction(SUBGOAL_ORDER[goal_idx]) + action_vec = actionVectors[action] + direction_reward = 0.001 * np.dot(action_vec, goal_vec) + reward += direction_reward + """ + + # Distance Reward + last_goal = SUBGOAL_ORDER[goal_idx-1] if goal_idx != 0 else None + reward += env.distanceReward(lastGoal=last_goal, goal=SUBGOAL_ORDER[goal_idx]) + + + total_rewards.append(reward) + reward = torch.tensor([reward], device=device) + + next_state = env.getStackedState() + + # Store transition and update network parameters + agent.remember(state, action, reward, next_state, env.isTerminal(), goal_position) + + # Move to the next state + state = next_state + + # Optimize the policy network + agent_loss = agent.optimize_model() + if agent_loss is not None: + subgoal_agent_loss += agent_loss + + # Update goal + if episode_steps >= MAX_EPISODE_STEPS: + subgoal_success_tracker[goal_idx].append(0) + if args.verbose >= 1: + print(f'[Episode {episode}] Reached maximum epsiode steps: {episode_steps}. Terminate episode.') + break + + elif env.trueGoalReached(SUBGOAL_ORDER[goal_idx]): + if args.verbose >= 1: + print(f'[Episode {episode}] Goal reached! \U0001F389 after step #{episode_steps}. ') + + subgoal_success_tracker[goal_idx].append(1) + subgoal_success_steps[goal_idx].append(episode_steps) + + # Predict new goal and continue if its true goal + goal_idx += 1 + + """ + if subgoal_trailing_performance[goal_idx] == 0: + # goal has not been seen before --> reset TanH schedule to enforce new exploration + print(episode, episode_steps, goal, 'resetting TanH schedule...') + agent.reset_epsilon_schedule() + """ + + if goal_idx == len(SUBGOAL_ORDER): + # Finished all options --> start new episode + wrong_goal = True + print(f"[Episode {episode}] Reached all goals! 
\U0001F38A \U0001F38A \U0001F38A") + break + + true_goal = goal_idx # if GOAL_MODE == 'full_sequence' else inv_label_mapping.get(SUBGOAL_ORDER[goal_idx]) + expert_goal = np.zeros((1, GOALS)) + expert_goal[0, true_goal] = 1.0 + meta_labels.append((state, expert_goal)) # append, but do not collect yet + + # Metacontroller predict goal + goal = metacontroller.sample(metacontroller.predict(state)) + + if goal!= true_goal: + wrong_goal = True + if args.verbose >= 1: + print(f"[Episode {episode}] Metacontroller predicted {goal} instead of {true_goal} as goal \U0001F47E ") + wrong_meta_pred += 1 + if wrong_meta_pred % 100 == 0: + if args.verbose >= 1: + print(f'[Episode {episode}] 100 wrong meta choices. Resetting metacontroller... ') + metacontroller.reset() + break + else: + # Continue with new goal + if args.verbose >= 1: + print(f'[Episode {episode}] Metacontroller predicted goal {goal} as new goal...') + + else: + subgoal_success_tracker[goal_idx].append(0) + if args.verbose >= 1: + print(f'[Episode {episode}] Agent killed after {episode_steps} steps \U0001F47E Terminate episode.') + break + + if create_video and img_array: + out_path = f'runs/video_run{TIME_STAMP}-E{episode}.gif' + imageio.mimsave(out_path, img_array) + print('Saved gif to ', out_path) + create_video = False + + """END OF EPISODE""" + # At end of episode: aggregate data for metacontroller + item = meta_labels[-1] + metacontroller.collect(item[0], item[1]) + + total_steps += episode_steps + + # if agent was running + if all_step_times: + # Print average step time + avg_step_time = sum(all_step_times) / len(all_step_times) + + # Log agent losses + agent_loss = subgoal_agent_loss / episode_steps + + else: + avg_step_time = 0 + agent_loss = 0 + + + if not DEBUG: + wandb.log({'mean_reward': np.array(total_rewards[-episode_steps:]).mean(), 'avg_step_time': avg_step_time*100, 'agent_loss': agent_loss, 'epsilon': agent.get_epsilon(), 'episode': episode}) + + for g in range(GOALS): + if len(subgoal_success_tracker[g]) < episode: + # goal not evaluated yet because other goals not reached + subgoal_success_tracker[g].append(None) + if not DEBUG: + wandb.log({ f'sub-goal {g}': subgoal_success_tracker[g][-1], 'episode': episode}) + + # Train metacontroller + if metacontroller.check_training_clock(): + if args.verbose >= 2: + print('###################################') + print('### Training Metacontroller ###') + print('###################################') + + meta_loss, meta_acc = metacontroller.train() + if not DEBUG: + wandb.log({'meta_loss': meta_loss, 'meta_acc': meta_acc, 'episode': episode}) + + if args.verbose >= 2: + print(f'Metacontroller loss: {meta_loss:.2f}\n') + print('###################################') + + if len(subgoal_success_tracker[goal_idx]) > 100: + subgoal_trailing_performance[goal] = sum([v for v in subgoal_success_tracker[goal][-100:] if not v is None])/ 100.0 + + if args.verbose >= 0: + print(f'[Episode {episode}] Subgoal trailing performance for goal {goal_idx} is {subgoal_trailing_performance[goal_idx]:.2f}') + print(f'[Episode {episode}] Agent {goal_idx} steps done {agent.steps_done}') + print(f'[Episode {episode}] Agent epsilon value is {agent.get_epsilon():.2f}') + + if subgoal_trailing_performance[goal] > STOP_TRAINING_THRESHOLD: + if args.verbose >= 1: + print(f'[Episode {episode}] Training for goal #{goal} completed...') + + if true_goal == GOALS: + # Last goal reached trailing performance --> stop training + print(f"\nTraining completed \U0001F38A \U0001F38A \U0001F38A \n") + 
agent.finish_training(TIME_STAMP) + torch.save(metacontroller.model.state_dict(), f"run{TIME_STAMP}-metacontroller.pt") + break + + # After training is completed + for goal in range(GOALS): + if subgoal_success_steps[goal]: + print(f'\nAgent{goal} performance: {np.array(subgoal_success_tracker[goal]).mean():.2%}') + print(f'Reached goal {goal} {sum(subgoal_success_tracker[goal])}x with {np.array(subgoal_success_steps[goal]).mean():.0f} average steps') + + +if __name__ == "__main__": + main()