import os
import time

import cv2
import torch
import numpy as np 
import imageio
import wandb
from datetime import datetime

from ddqn_agent import Agent
from atari_env import ALEEnvironment
from args import HIntArgumentParser
from metacontroller import MetaController
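
# Hierarchical RL training script: a metacontroller predicts the next subgoal
# from the current stacked frame, while a single goal-conditioned DDQN agent
# learns to reach each subgoal of a fixed sequence derived from gaze analysis.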


GPU_DEVICE = 4
VERBOSE = 2
DEBUG = False
TEST = False

STOP_TRAINING_THRESHOLD = 0.90
EPISODES = 1000000
MAX_EPISODE_STEPS = 1000
RANDOM_PLAY_STEPS = 20000

# Use subgoals from gaze analysis 
GOAL_FEATURE = True
GOAL_MODE = 'single_agent' 
ALL_SUBGOALS = np.loadtxt('subgoals.txt', dtype=int, delimiter=',')
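# subgoals.txt is assumed to hold one comma-separated integer row per subgoal
# (e.g. screen coordinates); SUBGOAL_ORDER below indexes into these rows.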
SUBGOAL_ORDER = [8, 6, 1, 0, 2, 7, 2, 0, 1, 6, 8, 9] 
GOALS = len(SUBGOAL_ORDER)  # if GOAL_MODE == 'full_sequence' else len(np.unique(SUBGOAL_ORDER))
SUBGOALS = SUBGOAL_ORDER  # if GOAL_MODE == 'full_sequence' else np.array(SUBGOAL_ORDER)[np.sort(np.unique(SUBGOAL_ORDER, return_index=True)[1])]  # unsorted unique: keeps subgoals in order of first appearance
TRAINED_GOALS = [False] * GOALS
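# TRAINED_GOALS marks goals treated as already learned: for those, the true
# label is used directly instead of a metacontroller prediction (see the top
# of the episode loop) and a training sample is collected immediately.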

def main():

    # init random seed 
    RANDOM_SEED = 9 #np.random.randint(100)
    print(f'Setting random seed to {RANDOM_SEED}')
    os.environ['PYTHONHASHSEED'] = str(RANDOM_SEED)
    np.random.seed(RANDOM_SEED)
    torch.manual_seed(RANDOM_SEED)

    now = datetime.now()
    TIME_STAMP = now.strftime("%d%b%y-%H%M%S")
    MODEL_NAME = 'hInt-RL-single-dist_' + str(RANDOM_SEED) + '_' + TIME_STAMP

    os.environ['CUDA_VISIBLE_DEVICES'] = str(GPU_DEVICE)
    os.environ['CUDA_LAUNCH_BLOCKING'] = '1'  # documented value is the string '1'

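    # Reduced Atari action set with human-readable names, per-action direction
    # vectors (diagonals scaled by 1/sqrt(2), used only by the disabled
    # direction-reward shaping below), and labels for each subgoal id.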
    actionMap = [0, 1, 2, 3, 4, 5, 11, 12]
    actionExplain = ['no action', 'jump', 'up', 'right', 'left', 'down', 'jump right', 'jump left']
    actionVectors = [[0,0], [0.0,-1.0], [0.0,-1.0], [1.0,0.0], [-1.0,0.0], [0.0,1.0], [0.7071067811865475,-0.7071067811865475], [-0.7071067811865475,-0.7071067811865475]]
    goalExplain = {8: 'start', 6: 'rope', 1: 'lower right ladder', 0: 'danger zone', 2: 'lower left ladder', 7: 'key', 9: 'left door'}

    inv_label_mapping = {l: i for i, l in enumerate(SUBGOALS)}  # subgoal id -> goal index, preserving appearance order
    print(inv_label_mapping)

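    # Per-goal bookkeeping: success history (1 = reached, 0 = failed, None = not
    # attempted this episode), steps per success, and success rate over the last
    # 100 attempts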
    subgoal_success_tracker = [[0] for _ in range(GOALS)]
    subgoal_success_steps = [[] for _ in range(GOALS)]
    subgoal_trailing_performance = [0.0 for _ in range(GOALS)]
    
    parser = HIntArgumentParser()
    parser.set_common_args()

    args = parser.parse_args(['--model_name', MODEL_NAME, '--verbose', str(VERBOSE), '--random_seed', str(RANDOM_SEED), 
                              '--num_goals', str(GOALS), '--goal_mode', GOAL_MODE, 
                              '--epsilon_start', str(1.0), '--epsilon_end', str(0.1), '--epsilon_steps', str(2000000)])
    
    if not DEBUG:
        wandb.init(project="hInt-RL", config=vars(args), name=MODEL_NAME)

    print(*[f'\n[PARAM] {k}: {v}' for (k,v) in vars(args).items()])
    print()

    # CUDA_VISIBLE_DEVICES masks all but one GPU, which torch then sees as cuda:0
    env = ALEEnvironment(args, device="cuda:0")

    input_shape = env.getStackedState().shape
    #input_shape = np.array([input_shape[2], input_shape[1], input_shape[0]])
    print(f'Input shape {input_shape}\n')
   
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # the masked GPU appears as cuda:0

    # Setup metacontroller
    metacontroller = MetaController(device=device, args=args, input_shape=input_shape, n_goals=GOALS, hidden_nodes=args.hidden_nodes)
    
    # Setup single agent 
    agent = Agent(device=device, goal=None, input_shape=input_shape, n_actions=len(actionMap), random_play_steps=RANDOM_PLAY_STEPS, args=args, goal_feature=GOAL_FEATURE)
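    # The one agent is shared across all subgoals; with GOAL_FEATURE it also
    # receives the current subgoal's position as an extra input (see the
    # select_action calls below)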

    total_rewards = []
    wrong_meta_pred = 0 
    total_steps = 0 
    create_video = False

    for episode in range(EPISODES):
        if args.verbose >= 1: 
            print(f'\n[Episode {episode}] Starting new episode...')

        # Initialize environment and get state
        env.restart()
        state = env.getStackedState()

        episode_steps = 0
        subgoal_agent_loss = 0
        meta_labels = []
        wrong_goal = False
        goal_idx = 0
        true_goal = goal_idx # if GOAL_MODE == 'full_sequence' else inv_label_mapping.get(SUBGOAL_ORDER[goal_idx])

        img_array = []
        if episode % 10000 == 0:
            create_video = True

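        # One-hot expert label: the goal the metacontroller should predict from this state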
        expert_goal = np.zeros((1, GOALS))
        expert_goal[0, true_goal] = 1.0
        meta_labels.append((state, expert_goal)) # append, but do not collect yet 

        if TRAINED_GOALS[true_goal]:
            goal = true_goal
            # Collect sample for metacontroller because goal will be reached
            metacontroller.collect(state, expert_goal) 
        else:
            # Metacontroller predict goal 
            goal = metacontroller.sample(metacontroller.predict(state))

        if goal != true_goal:
            wrong_goal = True
            if args.verbose >= 1: 
                print(f"[Episode {episode}] Metacontroller predicted {goal} instead of {true_goal} as goal \U0001F47E ")
            wrong_meta_pred += 1
            if wrong_meta_pred % 100 == 0:
                print(f'[Episode {episode}] 100 wrong meta choices. Resetting metacontroller... ')
                metacontroller.reset() 

        else:
            if args.verbose >= 1: 
                print(f'[Episode {episode}] Metacontroller predicted goal {goal} as new goal...')

        all_step_times = []    
        step_time = time.time()

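        # The outer loop walks the subgoal sequence; the inner loop executes
        # primitive actions until the current subgoal is reached, the agent dies,
        # the step budget is exhausted, or a wrong meta prediction ends the episode.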
        while not env.isTerminal() and episode_steps < MAX_EPISODE_STEPS and not wrong_goal:

            # Unroll episode until agent reaches goal
            while not env.trueGoalReached(SUBGOAL_ORDER[goal_idx]) and not env.isTerminal() and episode_steps < MAX_EPISODE_STEPS and not wrong_goal: 
                
                goal_position = torch.tensor(np.array([ALL_SUBGOALS[SUBGOAL_ORDER[goal_idx]]]), device=device)

                if create_video:
                    img = env.getScreenOrig()
                    cv2.putText(img=img, text='goal: ' + goalExplain.get(SUBGOAL_ORDER[goal_idx]), org=(5, 205),
                                fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.3, color=(255, 255, 255), thickness=1)
                    img_array.append(img)

                if GOAL_FEATURE:
                    action = agent.select_action(state.unsqueeze(0), goal_position) 
                else:
                    action = agent.select_action(state.unsqueeze(0)) 
                              
                external_reward = env.act(actionMap[action])
                if external_reward != 0:
                    external_reward = 1.0  # clip to {0, 1}; the external reward is not used for learning below

                all_step_times.append(time.time() - step_time)
                step_time = time.time()
                episode_steps += 1
                
                # Calculate intrinsic reward: +1 for reaching the subgoal, -1 for dying or timing out
                reward = 0.0  # optionally: small negative reward per step, e.g. -1.0 / 1000
                if env.trueGoalReached(SUBGOAL_ORDER[goal_idx]):
                    reward += 1.0
                if env.isTerminal() or episode_steps == MAX_EPISODE_STEPS:
                    reward -= 1.0

                """
                # Simple direction reward
                goal_vec = env.get_goal_direction(SUBGOAL_ORDER[goal_idx])
                action_vec = actionVectors[action]
                direction_reward = 0.001 * np.dot(action_vec, goal_vec)
                reward += direction_reward
                """

                # Distance-based shaping reward toward the current subgoal
                # (lastGoal is None while pursuing the first subgoal)
                last_goal = SUBGOAL_ORDER[goal_idx - 1] if goal_idx != 0 else None
                reward += env.distanceReward(lastGoal=last_goal, goal=SUBGOAL_ORDER[goal_idx])
                

                total_rewards.append(reward)
                reward = torch.tensor([reward], device=device)
                
                next_state = env.getStackedState()
                
                # Store transition in the replay buffer
                agent.remember(state, action, reward, next_state, env.isTerminal(), goal_position)

                # Move to the next state
                state = next_state

                # Optimize the policy network
                agent_loss = agent.optimize_model()
                if agent_loss is not None:
                    subgoal_agent_loss += agent_loss

            # Update goal
            if episode_steps >= MAX_EPISODE_STEPS:
                subgoal_success_tracker[goal_idx].append(0)
                if args.verbose >= 1: 
                    print(f'[Episode {episode}] Reached maximum episode steps: {episode_steps}. Terminate episode.')
                break

            elif env.trueGoalReached(SUBGOAL_ORDER[goal_idx]):
                if args.verbose >= 1: 
                    print(f'[Episode {episode}] Goal reached! \U0001F389 after step #{episode_steps}. ')

                subgoal_success_tracker[goal_idx].append(1)
                subgoal_success_steps[goal_idx].append(episode_steps)
                
                # Advance to the next subgoal and continue only if the metacontroller predicts the true goal
                goal_idx += 1

                """
                if subgoal_trailing_performance[goal_idx] == 0: 
                    # goal has not been seen before --> reset TanH schedule to enforce new exploration
                    print(episode, episode_steps, goal, 'resetting TanH schedule...')
                    agent.reset_epsilon_schedule()
                """

                if goal_idx == len(SUBGOAL_ORDER):
                    # Finished all options --> start new episode
                    wrong_goal = True
                    print(f"[Episode {episode}] Reached all goals! \U0001F38A \U0001F38A \U0001F38A")
                    break 

                true_goal = goal_idx # if GOAL_MODE == 'full_sequence' else inv_label_mapping.get(SUBGOAL_ORDER[goal_idx])
                expert_goal = np.zeros((1, GOALS))
                expert_goal[0, true_goal] = 1.0
                meta_labels.append((state, expert_goal)) # append, but do not collect yet 

                # Metacontroller predict goal 
                goal = metacontroller.sample(metacontroller.predict(state))

                if goal != true_goal:
                    wrong_goal = True
                    if args.verbose >= 1: 
                        print(f"[Episode {episode}] Metacontroller predicted {goal} instead of {true_goal} as goal \U0001F47E ")
                    wrong_meta_pred += 1
                    if wrong_meta_pred % 100 == 0:
                        if args.verbose >= 1: 
                            print(f'[Episode {episode}] 100 wrong meta choices. Resetting metacontroller... ')
                        metacontroller.reset() 
                    break
                else:
                    # Continue with new goal
                    if args.verbose >= 1: 
                        print(f'[Episode {episode}] Metacontroller predicted goal {goal} as new goal...')

            else:
                subgoal_success_tracker[goal_idx].append(0)
                if args.verbose >= 1: 
                    print(f'[Episode {episode}] Agent killed after {episode_steps} steps \U0001F47E Terminate episode.')
                break

        if create_video and img_array:
            out_path = f'runs/video_run{TIME_STAMP}-E{episode}.gif'
            imageio.mimsave(out_path, img_array)
            print(f'Saved gif to {out_path}')
            create_video = False

        """END OF EPISODE"""
        # At end of episode: aggregate data for metacontroller 
        item = meta_labels[-1]
        metacontroller.collect(item[0], item[1]) 

        total_steps += episode_steps

        # If the agent took at least one step
        if all_step_times:
            # Average wall-clock time per environment step
            avg_step_time = sum(all_step_times) / len(all_step_times)
            # Average agent loss per step this episode
            agent_loss = subgoal_agent_loss / episode_steps
        else:
            avg_step_time = 0
            agent_loss = 0

        
        if not DEBUG:
            # Guard episode_steps == 0 (wrong first goal): total_rewards[-0:] would span the whole history
            mean_reward = np.mean(total_rewards[-episode_steps:]) if episode_steps > 0 else 0.0
            wandb.log({'mean_reward': mean_reward, 'avg_step_time': avg_step_time * 100,
                       'agent_loss': agent_loss, 'epsilon': agent.get_epsilon(), 'episode': episode})
            
        for g in range(GOALS):
            if len(subgoal_success_tracker[g]) < episode:
                # Goal not evaluated this episode because earlier goals were not reached
                subgoal_success_tracker[g].append(None)
            if not DEBUG:
                wandb.log({f'sub-goal {g}': subgoal_success_tracker[g][-1], 'episode': episode})

        # Train metacontroller
        if metacontroller.check_training_clock():
            if args.verbose >= 2: 
                print('###################################')
                print('###   Training Metacontroller   ###')
                print('###################################')

            meta_loss, meta_acc = metacontroller.train()
            if not DEBUG:
                wandb.log({'meta_loss': meta_loss, 'meta_acc': meta_acc, 'episode': episode})

            if args.verbose >= 2: 
                print(f'Metacontroller loss: {meta_loss:.2f}\n')
                print('###################################')
        
        # Trailing performance: success rate of the current goal over its last 100
        # attempts. goal_idx can equal GOALS after a fully completed episode, so clamp it.
        last_goal_idx = min(goal_idx, GOALS - 1)
        if len(subgoal_success_tracker[last_goal_idx]) > 100:
            subgoal_trailing_performance[last_goal_idx] = sum(v for v in subgoal_success_tracker[last_goal_idx][-100:] if v is not None) / 100.0

            if args.verbose >= 0:
                print(f'[Episode {episode}] Subgoal trailing performance for goal {last_goal_idx} is {subgoal_trailing_performance[last_goal_idx]:.2f}')
                print(f'[Episode {episode}] Agent steps done {agent.steps_done}')
                print(f'[Episode {episode}] Agent epsilon value is {agent.get_epsilon():.2f}')

            if subgoal_trailing_performance[last_goal_idx] > STOP_TRAINING_THRESHOLD:
                if args.verbose >= 1:
                    print(f'[Episode {episode}] Training for goal #{last_goal_idx} completed...')
                # Mark the goal as learned so the true label is used directly from now on
                TRAINED_GOALS[last_goal_idx] = True

                if last_goal_idx == GOALS - 1:
                    # Last goal reached trailing performance --> stop training
                    print("\nTraining completed \U0001F38A \U0001F38A \U0001F38A \n")
                    agent.finish_training(TIME_STAMP)
                    torch.save(metacontroller.model.state_dict(), f"run{TIME_STAMP}-metacontroller.pt")
                    break
    
    # After training is completed: per-goal summary
    for goal in range(GOALS):
        if subgoal_success_steps[goal]:
            outcomes = [v for v in subgoal_success_tracker[goal] if v is not None]  # drop None placeholders before averaging
            print(f'\nAgent {goal} performance: {np.mean(outcomes):.2%}')
            print(f'Reached goal {goal} {sum(outcomes)}x with {np.mean(subgoal_success_steps[goal]):.0f} average steps')


if __name__ == "__main__":
    main()