1109 lines
129 KiB
Text
1109 lines
129 KiB
Text
|
{
|
|||
|
"cells": [
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"id": "909118c1",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"# Label Subgoals\n",
|
|||
|
"- Label each frame with current subgoal \n",
|
|||
|
"- Extract order of subgoals visited \n",
|
|||
|
"- Only include visited subgoals as true subgoals \n",
|
|||
|
"\n",
|
|||
|
"Result: `subgoal_order = [8, 6, 1, 0, 2, 7, 9]`\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 1,
|
|||
|
"id": "b382d001",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"import os\n",
|
|||
|
"import random\n",
|
|||
|
"\n",
|
|||
|
"import cv2\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"import pandas as pd \n",
|
|||
|
"import matplotlib.pyplot as plt \n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"SIGMA = (210 / 44.6, 160 / 28.5)\n",
|
|||
|
"\n",
|
|||
|
"GPU_DEVICE = 3\n",
|
|||
|
"\n",
|
|||
|
"os.environ['CUDA_VISIBLE_DEVICES']=str(GPU_DEVICE)\n",
|
|||
|
"\n",
|
|||
|
"DATA_PATH = '/datasets/public/anna/montezuma_revenge'"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"id": "f7c3ec86",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"### Load labeled data and extracted subgoals "
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 2,
|
|||
|
"id": "21935cab",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"CPU times: user 11.2 s, sys: 3.94 s, total: 15.2 s\n",
|
|||
|
"Wall time: 15.2 s\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>ID</th>\n",
|
|||
|
" <th>frame_id</th>\n",
|
|||
|
" <th>episode_id</th>\n",
|
|||
|
" <th>score</th>\n",
|
|||
|
" <th>duration(ms)</th>\n",
|
|||
|
" <th>unclipped_reward</th>\n",
|
|||
|
" <th>action</th>\n",
|
|||
|
" <th>gaze_positions</th>\n",
|
|||
|
" <th>img_path</th>\n",
|
|||
|
" <th>level</th>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <th>player_location</th>\n",
|
|||
|
" <th>skull_location</th>\n",
|
|||
|
" <th>num_gaze_positions</th>\n",
|
|||
|
" <th>gaze_duration_ratio</th>\n",
|
|||
|
" <th>angular_gaze_displacement</th>\n",
|
|||
|
" <th>gaze_velocity</th>\n",
|
|||
|
" <th>max_gaze_velocity</th>\n",
|
|||
|
" <th>avg_gaze_velocity</th>\n",
|
|||
|
" <th>time_stamps</th>\n",
|
|||
|
" <th>current_subgoal</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>0</th>\n",
|
|||
|
" <td>284_RZ_5540489_E00</td>\n",
|
|||
|
" <td>RZ_5540489_1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>2817</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>[[80.4, 103.5], [80.34, 103.4], [80.34, 103.3]...</td>\n",
|
|||
|
" <td>/datasets/public/anna/montezuma_revenge/284_RZ...</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>[77, 85]</td>\n",
|
|||
|
" <td>[92, 175]</td>\n",
|
|||
|
" <td>2614.0</td>\n",
|
|||
|
" <td>0.927938</td>\n",
|
|||
|
" <td>[8.986148924666498, 0.00034557292731263054, 0....</td>\n",
|
|||
|
" <td>[96840.02112006705, 3.724096925170927, 2.97928...</td>\n",
|
|||
|
" <td>266.07</td>\n",
|
|||
|
" <td>95.051218</td>\n",
|
|||
|
" <td>[0.0, 0.9279375221867234, 1.8558750443734469, ...</td>\n",
|
|||
|
" <td>8</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1</th>\n",
|
|||
|
" <td>284_RZ_5540489_E00</td>\n",
|
|||
|
" <td>RZ_5540489_2</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>50</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>[[113.66, 98.28], [113.65, 98.42], [113.65, 98...</td>\n",
|
|||
|
" <td>/datasets/public/anna/montezuma_revenge/284_RZ...</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>[77, 85]</td>\n",
|
|||
|
" <td>[92, 175]</td>\n",
|
|||
|
" <td>50.0</td>\n",
|
|||
|
" <td>1.000000</td>\n",
|
|||
|
" <td>[8.986165402283532, 0.00044873289690316374, 0....</td>\n",
|
|||
|
" <td>[89861.65402283533, 4.487328969031638, 0.0, 3....</td>\n",
|
|||
|
" <td>114.40</td>\n",
|
|||
|
" <td>104.255400</td>\n",
|
|||
|
" <td>[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, ...</td>\n",
|
|||
|
" <td>8</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>2</th>\n",
|
|||
|
" <td>284_RZ_5540489_E00</td>\n",
|
|||
|
" <td>RZ_5540489_3</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>51</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>[[87.26, 100.72], [86.4, 100.8], [85.61, 100.8...</td>\n",
|
|||
|
" <td>/datasets/public/anna/montezuma_revenge/284_RZ...</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>[77, 85]</td>\n",
|
|||
|
" <td>[92, 175]</td>\n",
|
|||
|
" <td>51.0</td>\n",
|
|||
|
" <td>1.000000</td>\n",
|
|||
|
" <td>[8.986150521057317, 0.0029062899985003087, 0.0...</td>\n",
|
|||
|
" <td>[89861.50521057317, 29.06289998500309, 21.2663...</td>\n",
|
|||
|
" <td>102.50</td>\n",
|
|||
|
" <td>90.330686</td>\n",
|
|||
|
" <td>[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, ...</td>\n",
|
|||
|
" <td>8</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>3</th>\n",
|
|||
|
" <td>284_RZ_5540489_E00</td>\n",
|
|||
|
" <td>RZ_5540489_4</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>51</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>[[78.41, 101.5], [78.41, 101.5], [78.41, 101.6...</td>\n",
|
|||
|
" <td>/datasets/public/anna/montezuma_revenge/284_RZ...</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>[77, 85]</td>\n",
|
|||
|
" <td>[92, 175]</td>\n",
|
|||
|
" <td>51.0</td>\n",
|
|||
|
" <td>1.000000</td>\n",
|
|||
|
" <td>[8.986146988240504, 0.00041474113941382504, 0....</td>\n",
|
|||
|
" <td>[89861.46988240504, 4.14741139413825, 7.982843...</td>\n",
|
|||
|
" <td>102.88</td>\n",
|
|||
|
" <td>90.392941</td>\n",
|
|||
|
" <td>[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, ...</td>\n",
|
|||
|
" <td>8</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>4</th>\n",
|
|||
|
" <td>284_RZ_5540489_E00</td>\n",
|
|||
|
" <td>RZ_5540489_5</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>55</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>[[78.41, 102.85], [78.42, 102.95], [78.41, 102...</td>\n",
|
|||
|
" <td>/datasets/public/anna/montezuma_revenge/284_RZ...</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>[77, 85]</td>\n",
|
|||
|
" <td>[91, 175]</td>\n",
|
|||
|
" <td>55.0</td>\n",
|
|||
|
" <td>1.000000</td>\n",
|
|||
|
" <td>[8.986147775662209, 7.822521255579121e-05, 0.0...</td>\n",
|
|||
|
" <td>[89861.47775662209, 0.7822521255579121, 2.8601...</td>\n",
|
|||
|
" <td>103.92</td>\n",
|
|||
|
" <td>90.895364</td>\n",
|
|||
|
" <td>[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, ...</td>\n",
|
|||
|
" <td>8</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>5 rows × 21 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" ID frame_id episode_id score duration(ms) \\\n",
|
|||
|
"0 284_RZ_5540489_E00 RZ_5540489_1 0 0 2817 \n",
|
|||
|
"1 284_RZ_5540489_E00 RZ_5540489_2 0 0 50 \n",
|
|||
|
"2 284_RZ_5540489_E00 RZ_5540489_3 0 0 51 \n",
|
|||
|
"3 284_RZ_5540489_E00 RZ_5540489_4 0 0 51 \n",
|
|||
|
"4 284_RZ_5540489_E00 RZ_5540489_5 0 0 55 \n",
|
|||
|
"\n",
|
|||
|
" unclipped_reward action gaze_positions \\\n",
|
|||
|
"0 0 0 [[80.4, 103.5], [80.34, 103.4], [80.34, 103.3]... \n",
|
|||
|
"1 0 0 [[113.66, 98.28], [113.65, 98.42], [113.65, 98... \n",
|
|||
|
"2 0 0 [[87.26, 100.72], [86.4, 100.8], [85.61, 100.8... \n",
|
|||
|
"3 0 0 [[78.41, 101.5], [78.41, 101.5], [78.41, 101.6... \n",
|
|||
|
"4 0 0 [[78.41, 102.85], [78.42, 102.95], [78.41, 102... \n",
|
|||
|
"\n",
|
|||
|
" img_path level ... \\\n",
|
|||
|
"0 /datasets/public/anna/montezuma_revenge/284_RZ... 0 ... \n",
|
|||
|
"1 /datasets/public/anna/montezuma_revenge/284_RZ... 0 ... \n",
|
|||
|
"2 /datasets/public/anna/montezuma_revenge/284_RZ... 0 ... \n",
|
|||
|
"3 /datasets/public/anna/montezuma_revenge/284_RZ... 0 ... \n",
|
|||
|
"4 /datasets/public/anna/montezuma_revenge/284_RZ... 0 ... \n",
|
|||
|
"\n",
|
|||
|
" player_location skull_location num_gaze_positions gaze_duration_ratio \\\n",
|
|||
|
"0 [77, 85] [92, 175] 2614.0 0.927938 \n",
|
|||
|
"1 [77, 85] [92, 175] 50.0 1.000000 \n",
|
|||
|
"2 [77, 85] [92, 175] 51.0 1.000000 \n",
|
|||
|
"3 [77, 85] [92, 175] 51.0 1.000000 \n",
|
|||
|
"4 [77, 85] [91, 175] 55.0 1.000000 \n",
|
|||
|
"\n",
|
|||
|
" angular_gaze_displacement \\\n",
|
|||
|
"0 [8.986148924666498, 0.00034557292731263054, 0.... \n",
|
|||
|
"1 [8.986165402283532, 0.00044873289690316374, 0.... \n",
|
|||
|
"2 [8.986150521057317, 0.0029062899985003087, 0.0... \n",
|
|||
|
"3 [8.986146988240504, 0.00041474113941382504, 0.... \n",
|
|||
|
"4 [8.986147775662209, 7.822521255579121e-05, 0.0... \n",
|
|||
|
"\n",
|
|||
|
" gaze_velocity max_gaze_velocity \\\n",
|
|||
|
"0 [96840.02112006705, 3.724096925170927, 2.97928... 266.07 \n",
|
|||
|
"1 [89861.65402283533, 4.487328969031638, 0.0, 3.... 114.40 \n",
|
|||
|
"2 [89861.50521057317, 29.06289998500309, 21.2663... 102.50 \n",
|
|||
|
"3 [89861.46988240504, 4.14741139413825, 7.982843... 102.88 \n",
|
|||
|
"4 [89861.47775662209, 0.7822521255579121, 2.8601... 103.92 \n",
|
|||
|
"\n",
|
|||
|
" avg_gaze_velocity time_stamps \\\n",
|
|||
|
"0 95.051218 [0.0, 0.9279375221867234, 1.8558750443734469, ... \n",
|
|||
|
"1 104.255400 [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, ... \n",
|
|||
|
"2 90.330686 [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, ... \n",
|
|||
|
"3 90.392941 [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, ... \n",
|
|||
|
"4 90.895364 [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, ... \n",
|
|||
|
"\n",
|
|||
|
" current_subgoal \n",
|
|||
|
"0 8 \n",
|
|||
|
"1 8 \n",
|
|||
|
"2 8 \n",
|
|||
|
"3 8 \n",
|
|||
|
"4 8 \n",
|
|||
|
"\n",
|
|||
|
"[5 rows x 21 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 2,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"%%time \n",
|
|||
# Load the labeled frame table.
# NOTE(review): pd.read_pickle can execute arbitrary code while unpickling;
# only point DATA_PATH at files from a trusted source.
df = pd.read_pickle(os.path.join(DATA_PATH, "all_trials_labeled.pkl"))

# Use the image of the first row as the reference screen for the visualisations below.
first_frame_path = df.iloc[0].img_path
init_screen = cv2.imread(first_frame_path)
if init_screen is None:
    # cv2.imread does not raise on a missing/unreadable file, it returns None.
    raise FileNotFoundError(f"Could not read frame image: {first_frame_path}")
init_screen = cv2.cvtColor(init_screen, cv2.COLOR_BGR2RGB)

# One box per row; ndmin=2 keeps the array two-dimensional even when the file
# contains a single row, so SUBGOALS[:, k] indexing keeps working.
SUBGOALS = np.loadtxt('subgoals.txt', dtype=int, delimiter=',', ndmin=2)

# cv2.resize expects (width, height); the reference screen is doubled in both directions.
dim = (init_screen.shape[1] * 2, init_screen.shape[0] * 2)
img = cv2.resize(init_screen.copy(), dim, interpolation=cv2.INTER_AREA)

df.head()
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"id": "284918a0",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## Visualize extracted order"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 3,
|
|||
|
"id": "e42effb0",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAVkAAAHBCAYAAADDx8j1AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAAA8aElEQVR4nO3dd5wTZeIG8GcmyWY321h2Yem9Kh2xgFIURU4FOc4TsIHSbOhZzp8VldM79TxPz4JYsAOHyokezUKRE6T3Ih22sOwC23fTZn5/zPZkJtlNJslkn+/nsx+SeTNv3gzJk8k777wjyLIMIiLShxjuBhARRTOGLBGRjhiyREQ6YsgSEemIIUtEpCOGLBGRjsxahYIgcHwXEZEfZFkWvC3nniwRkY4YskREOmLIEhHpiCFLRKQjhiwRkY4YskREOmLIEhHpiCFLRKQjhiwRkY4YskREOmLIEhHpSHPuAiLyz913340ePXp4LN+wYQMWLlwYhhZRpGDIEgXBlVdeiREjRngslySJIdvIMWSJAjBlyhQ0b94cHTp0AAAsW7YMu3btwhVXXIEhQ4agb9++eOyxx3D06FEsXrw4vI2l8JBlWfUPgMw//vFP/W/NmjVyXl5e1d+kSZNkAPJjjz1Wa/nChQvD3lb+6funlqM88EVEpCOGLBGRjhiyREQ6YsgSEemIIUtEpCOGLBGRjhiyREQ6YsgSEemIIUsURHFxcUhKSoLVag13UyhC8LRaoiB66aWX8NJLL4W7GRRBuCdLFAC32w2n0wlJkryWS5IEp9MJt9sd4pZRpBAq5ijwXigI6oVEVGXx4sVeZ+H69NNP8ac//SkMLaJQk2VZ8Lac3QVEQVBjUiWiWrgnS0QUBGp7suyTJSLSEbsLGrG3nrsCl/VP91i+5PtjmPOvrT7XF0UBW/4z3mvZlbcuRX6hw2cdE2/ogken9fNYnpFTgjHTlvtcHwDWLhiLxHiLx/I/vfAL1v6a5XP9Dm0S8fXbozyWywAGjvnSrzYQqWHINmImUYDZ7PljxiR4/dXjlbf160MQVNog+t8Gk8l7Hf6+DEHw/jok9rFSELC7gIhIRwxZIiIdsbugEbOvKkPZziKP5Y5j5f5VIMso+9BzfQBAuX8/td37HF7rsJeU+tcGAOWfl8BisXssl066/Fpfzpe8toGdBRQMDNlGTC6TIRd7RonsZ0BChtf1K8v8qsLpvQ651P+Ik4slyBYvdfiXsZAl721gnywFA8fJNiJz5wxFSnL1xCVt4m2IN3l+zxY4HThdVr03ezyzCI+9tBFNm1jxzvNDq5YLAtAtMcnrcx0uLoJbqn77PPv6Zuw/ko8HJvfG4AEtqpanxMSgeWysx/oOScKx4uKq+y63hFv+9CMAYNEbV9d6bJfERK8H6zLLSlHsrE7a/3x/DAu+PYzLL2qB+2/vXbU8xiSiY3yCx/oygN8KC2stu/2Rn2B38BRZ8sQzvghdOiSjeWqcz8elIA4pqH5cZX5ZTCJ6dknx67m6NW9S674tTnmrtUqP96uOWJjQE9WPc7qq5wbo0bkJBD+GDrRFYq37v2w7DQBISojx+3X0bF77cSKPYlA9MWQbEcfKMpTH1//HieOsslcrl8so/8b/vtKapLNKSLq22FGeX/86XDUmYLEvLQXg/xCvqjr2OQEA7gxXg18H/OyCIKrEkG1E3BkuSLH1TwmpUPl5LLtkSCcaljKV/bxSngTJUv863DVC1n3c5deerEcb8pU65JKGvw4eDaP6Yp9sI3J1ixawmkz1Xq/Q6cS6M2cQazJhZIsWvlfw4pfcXJxzODCgaVO0ivPdZVGXDOC/mZkAgOtbt25QGw4XFeFAYSFax8Whf9OmDapjeVYW3DwgRl6o9ckyZImIgoATxBARhQFDlohIRwxZIiIdMWSJiHTEkCUi0hFDlohIRwxZIiIdMWSJiHTEkCUi0hFDlohIRwxZIiIdMWSJiHTEkCUi0hFDlohIRw
xZIiIdMWSJiHTEkCUi0hFDlohIRwxZIiIdMWSJiHTEkCUi0hFDlohIRwxZIiIdMWSJiHTEkCUi0hFDlohIRwxZIiIdMWSJiHTEkCUi0hFDlohIRwxZIiIdMWSJiHTEkCUi0hFDlohIRwxZIiIdMWSJiHTEkCUi0hFDlohIRwxZIiIdMWSJiHTEkCUi0hFDlohIRwxZIiIdMWSJiHTEkCUi0hFDlohIRwxZIiIdMWSJiHTEkCUi0hFDlohIRwxZIiIdMWSJiHTEkCUi0hFDlohIR2atwr4P3hqqdhARRSXNkE3t0y1U7SAiikrsLiAi0hFDlohIRwxZIiIdMWSJiHSkeeArc+2WULWDiMjYpnhfLMiyrLqOIAjqhUREVEWWZcHbcnYXEBHpiCFLRKQjhiwRkY4YskREOmLIEhHpiCFLRKQjzXGyRmCNMcEaY1Itl2QZxSXOELaIiOojId4CUfA6+gkAYHe4YXe4Q9ii4DJ8yM6a0hcPT+uvWp6VU4IB1y0MYYuIqD7WfzkeLZrFq5a/PHcr/vH+jtA1KMgMH7KSJMPplFTLXW71MiIKP5dL+zMsScY+J4pnfBERBQHP+CIiCgPDdxc8OmMAHpraT7U8+0wp+2SJItiO5ROQnmZTLX/l3W3skw03QePIJBFFvmj+DBu+T1YQAFHU/g9yuyP+ZRA1WiaT9udXkmRoxFTEUOuTNXzIEhFFArWQNXx3wawpfXHfHX1Uy0/nlmLoTV+FsEVEVB/rvxyP5hp9sm/M34k3P94VwhYFl+FD1hpjQlJCjGo5z/YiimwJ8RbNz3CMxdiDoAzfXZAYb0Gixn+Q2y0jJ680hC0iovpIT7Np9ssWFjsMsbPEPlkiIh1FbZ/s7eN74LZx3VXLc8+VY9KslSFsERHVx8J/jUJqSqxq+cdfHcBnSw6GsEXBZfiQTU+zoXePNNXyrJySELaGiOqrR5cUzQlimqfGhbA1wWf47oKObZPQsW2Sarnd7sb/tmaHsEVEVB9DLmqpOV3p0ZMFOJ5RFMIWNQz7ZImIdBS1fbLXDmuHa4a2Uy0vKHTgudc3hbBFRFQfzz54MZIS1UcIrVh7EqvWnQxhi4LL8CHbu0caJo1VP/CVlVPCkCWKYDeO6qTZJ5uRXcyQDaf1m7Pg1piY2wjj64gas7c/3Y0Em0W1/Jetp0PYmuBjnywRURBEbZ9snx6pmkO4SsudWLLiaAhbRET18ftrOyMuVj2Kdu3Pw+6DZ0PYouAyfMiOGtbe54UUGbJEkeuZBwb5vJAiQzaMjpwowPc/q3eKn8svD2FriKi+1v2ahZRkq2r50ZOFIWxN8LFPlogoCKK2TzY9zYb0ZupzUTqdbuw/fD6ELSKi+rigawrMZvUzvnJyS5CTVxbCFgWX4UP29vE9fPbJ8kKKRJHrizdG+eyT5YUUw6iwyIFT2ernNZ8x8DcgUWOQnVMKp0t9rHthsSOErQk+9skSEQVB1PbJWmNMmjP4SLLMs76IIlhCvAWixiXB7Q437A53CFsUXIYP2VlT+rJPlsjA1n85nn2ykUySZDid6v05Lo15DYgo/Fwu7c+wJBm715J9skREQaDWJ2vsa+0SEUU4w3cXPDpjAB6a2k+1PPtMKftkiSLYjuUTkJ6mfkLRK+9uY59suAkaRyaJKPJF82fY8H2yggCIovZ/kNsd8S+DqNEymbQ/v5IkQyOmIgYvpEhEpKOoPRlh1pS+uO+OPqrlp3NLMfSmr0LYIiKqj/VfjkdzjT7ZN+bvxJsf7wphi4LL8CFrjTEhKUH9Spc824sosiXEWzQ/wzEWYw+CMnx3QWK8BYkJMcCECcAVQ4F776lV7nbLyMkrBWw2YOdOoGvXMLWUAmE2CVj24XVey8bNXIGSMleIW0TBkp5m0+yXLSx2GGJnKWq7C4runImie+8FkpOB+Hhgjh2YPr36AenpwIFtgCgCnToBe/cCF14YvgZTg7VQmTfY14FPimw5eaXhboKujB2yDz+s/L
VsWbVo0HUD8OJnYwEAGTGpmNLzT0D37tXrdOsW6lYSkYaF/xqF1JRY1fKPvzqAz5YcDGGLgsvYIbtkCdCrFzB5ctW
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 576x576 with 1 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {
|
|||
|
"needs_background": "light"
|
|||
|
},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
# Sequence of subgoal boxes (row indices into SUBGOALS) in the order they are visited.
subgoal_order = np.array([8, 6, 1, 0, 2, 7, 2, 0, 1, 6, 8, 9])

for i, box in enumerate(SUBGOALS):
    if i not in subgoal_order:
        # Boxes that never appear in the visiting order are not drawn.
        continue

    # Box coordinates are doubled before drawing on `img`.
    # NOTE(review): presumably because `img` is the 2x-upscaled reference screen — confirm.
    box = box.copy() * 2
    top_left = (int(box[0]), int(box[1]))
    bottom_right = (int(box[2]), int(box[3]))
    img = cv2.rectangle(img, top_left, bottom_right, (255, 0, 0), 2)

    # Label the box with every position at which it occurs in the order, e.g. "3./7.".
    idx = np.where(subgoal_order == i)[0]
    label = '/'.join([f'{n}.' for n in idx])
    img = cv2.putText(img=img, text=label, org=(top_left[0], top_left[1] - 5),
                      fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                      fontScale=0.4, color=(0, 255, 255), thickness=2)

# savefig does not create missing directories, so make sure the target exists.
os.makedirs('visualizations', exist_ok=True)

fig = plt.figure(figsize=(8, 8))
plt.imshow(img)
plt.axis('off')
plt.savefig('visualizations/subgoals_labeled.png', bbox_inches='tight')
plt.show()
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 4,
|
|||
|
"id": "5af4f01c",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
def agent_in_subgoal(subgoals, agent_x, agent_y):
    """Find the axis-aligned boxes that strictly contain the point (agent_x, agent_y).

    Parameters
    ----------
    subgoals : array-like of shape (n_boxes, 4)
        One box per row, laid out as [x_min, y_min, x_max, y_max].
    agent_x, agent_y : scalar
        Coordinates of the point to test.

    Returns
    -------
    inside_any : numpy.bool_
        True when at least one box contains the point.
    box_indices : numpy.ndarray
        Row indices of all boxes containing the point (empty if there are none).
    """
    # Accept nested lists as well as arrays.
    boxes = np.asarray(subgoals)

    # Strict comparisons: a point lying exactly on a box edge counts as outside.
    inside = (
        (boxes[:, 0] < agent_x) & (boxes[:, 2] > agent_x)
        & (boxes[:, 1] < agent_y) & (boxes[:, 3] > agent_y)
    )

    # The mask is built once and reused for both return values.
    return np.any(inside), np.where(inside)[0]
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 28,
|
|||
|
"id": "7eb6ceb1",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Percentage on subgoals 56.721905%\n",
|
|||
|
"Percentage not on screen 1.566032%\n",
|
|||
|
"Percentage on skull 8.218883%\n",
|
|||
|
"Percentage on agent 32.611321%\n",
|
|||
|
"Percentage on agent vicinity 68.234675%\n",
|
|||
|
"CPU times: user 1min 7s, sys: 1.19 s, total: 1min 8s\n",
|
|||
|
"Wall time: 1min 11s\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"%%time\n",
|
|||
# Screen bounds as [x_min, y_min, x_max, y_max]; the 2-D version is built once
# because agent_in_subgoal expects one box per row.
screen = [0, 0, 160, 210]
screen_box = np.expand_dims(screen, axis=0)

counter = 0        # total number of gaze points inspected
subgoal_count = 0  # gaze points inside at least one subgoal box
out_screen = 0     # gaze points outside the screen bounds

# Gaze hits per region: 0 = skull, 1 = agent, 2 = agent vicinity.
areas_count = {k: 0 for k in range(3)}

# Only frames with room_id == 1 and level == 0 are analysed.
in_target_room = (df.room_id == 1) & (df.level == 0)

for episode in df.ID.unique():

    valid_actions_idx = df.index[(df.ID == episode) & in_target_room]
    if len(valid_actions_idx) == 0:
        continue

    # valid_actions_idx holds index *labels*, so select with .loc; .iloc would
    # only be equivalent for a default 0..n-1 index.
    temp_df = df.loc[valid_actions_idx]

    for i, frame in temp_df.iterrows():
        if frame.gaze_positions is None:
            continue

        # Every region is [x_min, y_min, x_max, y_max], the layout agent_in_subgoal reads.
        skull_x, skull_y = frame.skull_location
        skull_area = [skull_x - 5, skull_y - 15, skull_x + 10, skull_y + 10]

        mean_x, mean_y = frame.player_location
        agent_area = [mean_x - 5, mean_y - 15, mean_x + 10, mean_y + 10]
        agent_vicinity = [mean_x - 10, mean_y - 30, mean_x + 20, mean_y + 20]

        # The regions are constant for all gaze points of a frame: stack them once.
        regions = np.stack([skull_area, agent_area, agent_vicinity])

        for gaze_points in frame.gaze_positions:
            counter += 1

            if agent_in_subgoal(SUBGOALS, gaze_points[0], gaze_points[1])[0]:
                subgoal_count += 1

            if not agent_in_subgoal(screen_box, gaze_points[0], gaze_points[1])[0]:
                out_screen += 1

            # A gaze point may hit several regions at once (the agent box lies
            # inside the vicinity box), so every matching region is counted.
            check, idx = agent_in_subgoal(regions, gaze_points[0], gaze_points[1])
            if check:
                for k in idx:
                    areas_count[k] += 1

if counter == 0:
    # Avoid a ZeroDivisionError when no frame matched the filter.
    print('No gaze points found for the selected frames.')
else:
    print(f'Percentage on subgoals {subgoal_count / counter:%}')
    print(f'Percentage not on screen {out_screen / counter:%}')
    print(f'Percentage on skull {areas_count[0] / counter:%}')
    print(f'Percentage on agent {areas_count[1] / counter:%}')
    print(f'Percentage on agent vicinity {areas_count[2] / counter:%}')
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 31,
|
|||
|
"id": "a7d3b2cb",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"(10829, 691493)"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 31,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"out_screen, counter"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 27,
|
|||
|
"id": "a845983b",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAVkAAAHBCAYAAADDx8j1AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAAA+zElEQVR4nO3deXRcV50v+u8+Q01SabYka/A8xHYGJ3HmhDhAAiGB0M0NIcB70ARYlwsXuu/rvrzu9Op+d/Xwx+uB1d00rzvAJYSGEJyGQAghA85gxwk2seNYHmTLVmzJmqy5VNOpc85+f5Qsx7Z2SS7VUdWRvp+1vFaic+rUr7aqvjq1zz57CykliIjIG1qxCyAiWsgYskREHmLIEhF5iCFLROQhhiwRkYcYskREHjJybhWC47uIiGZDSjHdj3kmS0TkIYYsEZGHcncXvIsDYAyA5V0tRES+YgKoAqDn2GfWITsG4C8A7J9bTUREC8ZlAP4aQEOOfWYdshkAbwPYOceiiIgWihRm/nbPPlkiIg8xZImIPMSQJSLyEEOWiMhDDFkiIg/NenQBEakJMe0dleDKI8SQJZojTdPw0Y9+FNdcc81F23bu3Ilf//rXRaiKSgVDlmiOdF3H1q1b8clPfvK8M1chBBzHwXPPPccz2kWMIUuUJ9M0cc8992DDhg244oor4Loutm/fjt/97ne47bbbcPPNN+P666/H17/+dbz11lt4/vnn4bpuscumecaQJcqTaZq47777cO+99wLI9r++9NJL+Pd//3fouo6bb74ZW7ZswZYtW/CDH/wAL774IkN2EWLIEs3R2Yte03UJqC6I0eLBIVxERB5iyBIReYghS0TkIYYsEZGHGLJERB5iyBIReYghS0TkIYYsUQGtXLkSN998M1pbW4tdCpUI3oxAVCBCCHziE5/ARz/6UUQikWKXQyWCIUuUJ9d10dnZif3796OlpQW1tbUoLy9HeXn51D4DAwPo6elBV1cXJ4lZpNhdQJSndDqNb37zm/jc5z6HHTt2TLvPs88+i89+9rP43ve+B8dx5rlCKgU8kyXKk5QSw8PDGB8fR2dnJ44ePXrRPp2dnTh9+jTPYhcxkfOXL8TUxn4A9wOY/u810eJWX1+PaDR60c9HR0cxNDRUhIpoPmwB8FMArQAg5bSzAfFMdhGrjAYQCV/8FphIZBCbyMz4eMPQUFMZhK6f/95yXYmRsTSszMzT+pVFDFSUBy76eSrtYHQ8jZlOADUBVFcGEQjo52+QwFjMQiJlz1hDKKijqiJw0YxZGdvFyFgajjPzWejAwAAGBgZm3I8WH4bsIqXrAp/8yFpsvaHpom2/3H4S//HUUcwULS2NZfjjz1+FJTXh834ei2fwje/tx8GjIzPW8f6bW/DgR9biwlOAN9vO4J+/fwCpdO5+zPKyAP77Z67AhtXV5/3ccSW++5PD+M2u0zPWcOVltfjK/3E5ghcE9ameCfz9d95C/2ByxmMQqTBkFykBYGl9BBvWVF+0bff+gewOM6RsKKhj9fJKNDeUnffzkbE0ysLmrOqorQ5hw+qqi84i+wcT0LSZ52I1dIFlTdGLXodtu6iqCM6qhmiZiXWrqhAJnf9x0HUB0+S1YZobvoOIiDzEM9nFLCMh0xefrspZ9EECAFwA1jTHsGR222w4k4+/4KRV2pjxTHrKNK9DOhK4lNeRlpBimtfBQQE0RwzZxUoCdlsG6czF/Y3OkcyswkWOu7BeSSEdPf8LkZWy4A7Nbkyoc8KG9VwSF6asfTI9q5CUlkRmdxrprvNfh+O6cLpnWUO/A+vFJDTj/I+DNZyGTDBlaW4YsouIaWrQJvs+TU1AjEm4py+++q5NZPtb3cl8cV2JjO1CE4Bpnrs4FBAaZL8Dd/z8Y7iWA9PVzruQlLFduK6EYQjo2rlQ1pOA0+3gwqWwxIhE0NThTP5cSgkr40IIwDS0qT7coKEDQy5c94IapISRxnk1OI4L25
HQNQHDOFeDaQu4PQ7cCwYoIOYiqL37dWRr4JBXuhQcJ7tIhEM6PnXfOqxfWQkge5/9hopKLA2HL9r3VDyOY7EYzp7O7j88hB8/cxyrl1XgU/etRTiYDZ2oaeLKqmqE9PPTKSNdHBgZxYiVBgDYjsS2Xx3HvkOD+OidK3HzNQ1T+y4vL8ea8ovHlw6mU2gbHYMjs/0OXb1xfP+n7QgGdHz2Y+uxpCYEAAhoOq6oqkJV4PxhYBISR8bH0ZNITP3sxddO47kdXbh1SyM+/L4VOHtdbUkohE2VVdAvSPoJO4O3R0aRdLIBPj6RwWM/a8fJ0xM5WpoWE46TpSmmoWHLpiW46eqGGfddiUqsROW5H0hg27MnsKQ6hPfd2IzySO6RAzqA61A/9f+ZjIude3rxlhDYsLoad90y8wxVDShHA87NAXDw2DCeeKYDkZCBW69pxPLmi4P5QpejFpejdur/3+mO4bkdwLKlUdx1S8vUWb1KJQzchnN/hAZHUvjFi+/gJBiyNHsM2UVCZoDMWxas0dQlP9ZuzwCuhDvkwno1BSt4affgZxwXbr8DSAmnPQMrfOk1ZPosICUhXYnMb9Owqmc3ROzdnBPZM1K324b1UmrGkL2QFU9Bjs32ih5RFkN2sXAk7JMZZOLWJT/U7rMBF3AnXGSOZJAxLq1TMuO6cEddQAJOr42MmUcNYxlIO9s3mzmWQabs0o/hDGT/ODhDDjIHrUsO2Uw6A5lkhyxdGobsIpFyXTzV1YU3Bgcv+bFd8ThsKXEyHse/HT0KU7u04dWuzPaPSgAv9fej6139pLM1bFmIZTJIOQ4eO3ECUfPSz2QPjo0BAN4aGcE/t7dfdJfZTFKOg54k7/6iS8MLX0REeZrNhS/e8UVE5CGGLBGRhxiyREQeYsgSEXmIIUtE5CGGLBGRhxiyREQeYsgSEXmIIUtE5CGGLBGRhxiyREQeYsgSEXmIIUtE5CGGLBGRhxiyREQeYsgSEXmIIUtE5CGGLBGRhxiyREQeYsgSEXmIIUtE5CGGLBGRhxiyREQeYsgSEXmIIUtE5CGGLBGRhxiyREQeYsgSEXmIIUtE5CGGLBGRhxiyREQeYsgSEXmIIUtE5CGGLBGRhxiyREQeYsgSEXmIIUtE5CGGLBGRhxiyREQeYsgSEXmIIUtE5CGGLBGRhxiyREQeYsgSEXmIIUtE5CGGLBGRhxiyREQeYsgSEXmIIUtE5CGGLBGRhxiyREQeYsgSEXmIIUtE5CHjUnYWmoAQwqtaiIh8RUgArptzn1mHrBEKYtkd12F9Y91c6yIiWhBWDI7A2L4biCeV+8w6ZDXTQO2V69C0fkUhaiMi8r26ztPQd+7LGbLskyUi8hBDlojIQwxZIiIPMWSJiDw06wtfjpXBwO42nOo87WU9RES+UTk8Bidl5dxn9iGbtnD6pd3omHNZREQLQzUAe4Z9Lqm7QOZfCxHRgjObTGSfLBGRhxiyREQeYsgSEXmIIUtE5CGGLBGRhy5pqsNS1Lq0HC1Ly5XbR8fTOHpiFI7LsRFEpUbXBdatrEJVRVC5T1fvBLp7J+axqsLyfch+5M6V+NKnr1Bu37GnB3/yt69hIp6Zx6qIaDbKwib+8KHNuOXapcp9vvnY2/i3/2ibx6oKy/chOz5h4XSf+q/c0EgKLs9iiUqSKyWGhlPKz7AEEJvw9wmSkDJHAAkxtbEfwP0Adnhf0yWJlpmIlgeU21NpByNjKeR6mURUHEIANZUhBIO6cp/xCatkv4luAfBTAK0AIOW0y8b4/kzWNHWUhXO/DAHerUZUioQQCAZzf4ZT6ZluXC1tvg/Z3/vAKvzBxzdCtfLY63v78JffeAPxhL9/UUQLUVnYwB9/8WrcsLlx2u1SSnz3iUP43rbD81xZ4fg+ZA1dQzCgHolmGhqEMoKJqKgEYBo6AqrPsMyOQPAz34fsz547jtfe7FVuj8UtJFI8iyUqRYmEjX/49l6Ul6mvqwwMJuaxosLzfchOJD
KQZ9S/BCvjIOfFPSIqGldKjIylc3bnxZOledFrtnwfsvfduQoPfHitcvvetgH8/SP7kEjybJao1JSFTfzRQ5uxedM
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 576x576 with 1 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {
|
|||
|
"needs_background": "light"
|
|||
|
},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
# Sanity check: draw the agent box and the screen bounds on the reference frame.
img = init_screen.copy()

# Derive the box from the same row the reference image was loaded from (df.iloc[0]),
# rather than from whatever `agent_area` an earlier loop happened to leave behind.
# NOTE(review): assumes init_screen is still the image of df.iloc[0] — confirm.
preview_x, preview_y = df.iloc[0].player_location
preview_area = [preview_x - 5, preview_y - 15, preview_x + 10, preview_y + 10]

img = cv2.rectangle(img,
                    (int(preview_area[0]), int(preview_area[1])),
                    (int(preview_area[2]), int(preview_area[3])),
                    (255, 0, 0), 2)
img = cv2.rectangle(img, (0, 0), (160, 210), (255, 0, 0), 2)

fig = plt.figure(figsize=(8, 8))
plt.imshow(img)
plt.axis('off')

plt.show()
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"id": "8cec6e9c",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"### Run Simulation"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 9,
|
|||
|
"id": "7e7cb363",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"A.L.E: Arcade Learning Environment (version 0.7.5+db37282)\n",
|
|||
|
"[Powered by Stella]\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
import gym
from atariari.benchmark.wrapper import AtariARIWrapper

# Create the emulator first, then wrap it.
# NOTE(review): the wrapper presumably supplies the info['labels'] entries read
# further down in this notebook — confirm against the atariari documentation.
base_env = gym.make(
    'MontezumaRevenge-v4',
    frameskip=1,
    render_mode='rgb_array',
    repeat_action_probability=0.0,
)
env = AtariARIWrapper(base_env)

# Seeded reset followed by a single step with action 1.
obs = env.reset(seed=42)
obs, reward, done, info = env.step(1)
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 11,
|
|||
|
"id": "236a44a5",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAV0AAADnCAYAAAC9roUQAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAABI6UlEQVR4nO29d5gc1ZX//b1V1bl7enLQSDNKKEsIAUoEgzDBxjbJy4vhxaxhH3bNyxps7N0fBmwLWNngtMbgXfPDsDisWa8BrxdnkBEIJCQQKKCE0kiTU09P5+6quu8fNUHDzL3VM93qCTqf5+Hh0dzqU7dO1T1176lzz2GccxAEQRCFQRnvDhAEQZxOkNElCIIoIGR0CYIgCggZXYIgiAJCRpcgCKKAaLJGxhiFNhAEQYwSzjkTtdFMlyAIooCQ0SUIgiggZHQJgiAKCBldgiCIAkJGlyAIooBIoxcIYrJxwQUX4OGHH4bf7wcAfPDBB/jyl7+MxsZG6e8YY7jrrrtw8803D2v705/+hPXr1yOVSkllBINBfPOb38SqVauG/J1zjieeeALPPPPMKK+GmJJwzoX/AeD0H/03mf677LLL+JEjR3hnZyfv7OzkmzZt4jNnzhQer2kav/fee/nmzZv5wYMHeUdHB//ud7/LFyxYwP/5n/+Zt7W18SNHjvA33niDf+973+OBQGCYjMrKSv7kk0/yLVu28IaGBt7Y2MjvuOMOvnjxYv6Tn/yEd3R08H379vHNmzfzO+64g/eFYtJ/U/g/mV2lmS4xJbjkkktw9913o7y8HD6fL+vfMcZQU1ODBQsWAAA45+jq6sL+/fuxevVqcM5RVFSEoqIiNDc3Q1XVYTI0TUN9fT3OOOMMAEA8HkdTUxMOHDiAnp4eMMZQUVGBiooKVFZW5ueCiUkL+XSJKUFZWRlWrFiBcDiMW265Bffddx9CodB4d4sghkEzXWJSc+mll+Izn/kM6urqoGkaurq68MYbb6CnpwfpdHq8u0cQwyCjS0xqZs2ahSuvvHJg2b9kyRJ85zvfQTAYRFFREVwuF77xjW/g2LFjeOqpp2w/qBHEqYaMLjEp+ehHP4pLLrkEixcvBmOD29xra2tx3XXXDfzb4/HgE5/4BBoaGvD8889nZXQvvvhiBAIBLFy4EIoyeg+cw+HAjTfeiLVr12LlypWj/j0xtWGycj2U8IaYqKxduxZr1qwZ9ve5c+fiU5/6FFwuFwCgvb0dzz//PBoaGvDb3/4W7e3tQ45XFAWXXnoplixZgnXr1g0L9wKAvXv34ne/+x0OHTqEl156aVjoWCAQwCc/+UnMnj0bV111FWbNmjWknXOOTZs2YcuWLdi+fTtee+21XC+fmODIEt6Q0SWmFPPmzcO1114Lt9sNAGhra8Nzzz2X1Ue1yy67DGvXrh329z179uA3v/kNdF2X/t7j8eDTn/405syZM+TvnHO88sor2Lx58yiuhJjMkNElhCydX4p5s4qH/f1ESxTv7O6AYcofAadDwcozK1FV7h3yd8453v8ghANHemz7UFzkxOqzquDzOIb8PZMxsG1XO1o7ErYy6qb5sWJJBVRl6LMejqSxZUcrYgm5wQRy1wVB9CMzuuTTPY1hDLjs/Bm45br5w9r+sOk4du7rgpE2pDI8bg03Xz0Pa1ZUD/k75xyPPbs7K6NbU+HFXX+7DLVVQ+Nre6Np3Pvtt7IyussWlOG+O1bA5RwaR3vwaA/2HwnZGt186IIgsoGM7mmOGTJgHM0M+zvvMKy9NbYCALNtuAwOgIfNrPrA04DZqMOID5VhJnTweJYyohxmgw5DG3q82WwAwy9vRHLWBUFkARnd0whFYZhZG0CwyGn9G0B5wonMO8PjWYu6VCxfUIa0aRmxeFzHkRO9UBSGOXVFcLmsGaVf0+BtZsjEhsrgHKg23ThrcfnA37pCSZxojqIo4MTM6QEofa6Aep8Pyn4DGddQGaauY3ZxALHFfTNMDj
S2RtHRnURNhRfVlYMujTqPF/q7abAPRRs44iYW1ZWgosoDANB1E0dPRBBP6jnrIqNn90IgiJMhn+5phNet4eHbz8W6FdNG/dtdh7vx5ce3wufW8L0vrMGsmsCoZbzw2lFsePZdXLi8Bg/ffi68rtG98w2T49v/uRPPvXwYn79mEW6/aiGEjjMB3ZEUvvL4Vuxr6MlZF+0he7cHcXpCPl0CAGBmOFpeDuPgu6O/7SciEegxA+kk0PDrLmS8ozc47a0RcBOIHkvh0LMdcI+Qx0CGyTl6jljn7doRwwed7Ta/GE5vJoNEayYvuiCIsUAzXYIgiDxDhSkJgiAmCGR0CYIgCggZXYIgiAJCRpcgCKKAkNElCIIoIGR0CYIgCggZXYIgiAJCRpcgCKKAkNElCIIoIGR0CYIgCggZXYIgiAJCRpcgCKKAkNElCIIoIGR0CYIgCggZXYIgiAJCRpcgCKKAkNElCIIoIGR0CYIgCggZXYIgiAJCRpcgCKKAkNElCIIoIGR0CYIgCggZXYIgiAJCRpcgCKKAkNElCIIoIGR0CYIgCogma3SVBgvVD4IgiNMCqdE97zv3FKofBEEQpwXkXiAIgigg0pmukckUqh8EQRCnBYxzLm5kTNxIEARBjAjnnInayL1AEARRQMjoEgRBFBCpT/fvvnBC2GaaOt589X7s3/OLvHeKIAhiqiI1uowJ3RIQt+SXFau+hLNW3iVsb2vehlf+cAcS8Y4C9Wg4jGlYe9GDWLDkJuExxw7/EZv+8kXomfiI7R5vJS75+I9QVXOuQALHjre+j3e3/SAPPSamErmOEYfDh49c9n3Uz75cKGP/7p/jzU1fA+dGzv0dK1NljEg/pN1+d5Ow0TQyeKMAM13GVDBFFR/ATZimfkr7kA12/eTcBLfpp6JoABN7fLhpjOtDT0xM8jFGmKKBTYJnb7KMEdmHNOlMdyLAuQFu5K5ETWG49+OL8Mkza22P3XqkE/e9uAuRZPbGPB/9nAgvD2LykY9nj5s6JkOo0lQYI1P+Q5rHqWLDtcvw8j0XY19LLy7//qv47XtN8Lm0Yf9tOdyJT/3wNfzm3Ub89z+cjx/ffC4qAq7xvgSCIKYQE36mu3TF7Viy/O8g8iJ3tL2LzRvvRTLRNWI7A+B3aSj3u/Cly+bjHz4yF/+5rQHrvrNx2LFr5pThmc+tQsDtQNDrQHcsBUXi1x44B1Nxztp/wtz51wqPaWz4K7Zs+jp0PTFiu9tThgvWfQvlVcsFEjh273gSe957yrY/xOlFrmNEc3ix5sL1mF5/kfAch/Y/j7e3fHtcl+5TZYxMeJ+uy10Ct7tE2K7rScRjbcKHwetU8a3rzsSF8yrx+MaDeGVfG64+azouWVg17Nh3GkL4+dZjWDItiH/+2EIc6Yjii//1Ltp6k7b99HjL4XQWCdszmRjisXZAsIhjTIXXVwVNcwtlJJPdSCV7bPtCnF7kOkYABq+vEg6HTygjnepFItGZY09zYzKNkUnt000lQ0glQ7kL4kBnNI1jnTGoCsOscv+wQ/a19OJEVwwVfhdMyctoJBLxTiTiY38oOTcQizaP+ffE6UvuY4QjHmvLW39OFVNljEx5ny5BEMREYsLPdOfOvwZz5l8tbA91HcR7bz+OdCpcuE59CMYULFz6WcyYebHwmLaWt7F7x5MwjNSI7S5XEGeeeydKSucJZRza/wIOH/yfnPtLTC1yHSOq6sayFbejsuZsoYwTxzZi3+6fgXMz1+6OmakyRia80e3q3AtT4rxPJrpg6PY+11MJ5xztrTuQTHYLj4lFWqThLrqeRGPDJnS27xYeE+rcn1M/ialJrmPENDNobtqCcPiY8JjenmOQff8pBFNljEx4oxvqOoBQ14Hx7oYNHJ3tu9DZvmvMEgwjheYTm/PYJ+J0IdcxwrmBtubteezRqWGqjJEJb3TzhaIwXLGkGnMq/IildHznT/uGHePUVNyx7gxML/bCrUl2+BAEQYyRCW90q2
tXoXraSmF7LNKMo4f/IMxpcDKGyZExTDAADnX4N0SFAbrBoZujXUYxzKi/CGWVS4RHhEOH0XDkLzDNkRPDaw4vZs3
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 432x288 with 1 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {
|
|||
|
"needs_background": "light"
|
|||
|
},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from IPython import display\n",
"\n",
"# Visual sanity check: replay the first 800 recorded actions of one episode and\n",
"# overlay the agent box, the skull box and any subgoal box the agent is inside.\n",
"obs = env.reset()\n",
"\n",
"screen = plt.imshow(env.render(mode='rgb_array'), aspect='auto')\n",
"plt.axis('off')\n",
"\n",
"all_images = []\n",
"agent_locations = []\n",
"skull_locations = []\n",
"room_ids = []\n",
"subgoal_anno = []\n",
"\n",
"replay_actions = df.loc[df.ID == '285_RZ_5619207_E00'].action.values[:800]\n",
"\n",
"for i, action in enumerate(replay_actions):\n",
"    n_state, reward, done, info = env.step(action)\n",
"    labels = info['labels']\n",
"    img = info['rgb']\n",
"    room_ids.append(labels['room_number'])\n",
"\n",
"    # Agent position; y is mirrored with the constant 320\n",
"    # (presumably to match the SUBGOALS coordinate frame - TODO confirm).\n",
"    agent_x = labels['player_x']\n",
"    agent_y = 320 - labels['player_y']\n",
"    agent_locations.append([agent_x, agent_y])\n",
"\n",
"    check, goal_idx = agent_in_subgoal(SUBGOALS, agent_x, agent_y)\n",
"    if check:\n",
"        box = SUBGOALS[goal_idx][0]\n",
"        # All frames not yet annotated (up to, but excluding, frame i) get this subgoal.\n",
"        subgoal_anno += [goal_idx[0]] * (i - len(subgoal_anno))\n",
"        box_corner_a = (int(box[0]), int(box[1]))\n",
"        box_corner_b = (int(box[2]), int(box[3]))\n",
"        img = cv2.rectangle(img, box_corner_a, box_corner_b, (0,0,255), 1)\n",
"\n",
"    # Box around the agent.\n",
"    img = cv2.rectangle(img, (agent_x - 5, agent_y - 15), (agent_x + 10, agent_y + 10), (0,255,0), 2)\n",
"\n",
"    # Box around the skull; the +35 / -65 offsets are applied to the raw emulator labels.\n",
"    skull_x = labels['enemy_skull_x'] + 35\n",
"    skull_y = labels['enemy_skull_y'] - 65\n",
"    skull_locations.append([skull_x, skull_y])\n",
"    img = cv2.rectangle(img, (skull_x - 5, skull_y - 10), (skull_x + 5, skull_y + 5), (255,0,0), 2)\n",
"\n",
"    caption = 'Room ID: ' + str(labels['room_number']) + ' index: ' + str(i)\n",
"    img = cv2.putText(img=img, text=caption, org=(5, 205), fontFace=cv2.FONT_HERSHEY_SIMPLEX,\n",
"                      fontScale=0.3, color=(255, 255, 255), thickness=1)\n",
"\n",
"    # Update the existing image artist instead of creating a new figure per frame.\n",
"    screen.set_data(img)\n",
"    display.display(plt.gcf())\n",
"    display.clear_output(wait=True)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 14,
|
|||
|
"id": "257d71ee",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"CPU times: user 16.5 s, sys: 182 ms, total: 16.7 s\n",
|
|||
|
"Wall time: 16.7 s\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>ID</th>\n",
|
|||
|
" <th>frame_id</th>\n",
|
|||
|
" <th>episode_id</th>\n",
|
|||
|
" <th>score</th>\n",
|
|||
|
" <th>duration(ms)</th>\n",
|
|||
|
" <th>unclipped_reward</th>\n",
|
|||
|
" <th>action</th>\n",
|
|||
|
" <th>gaze_positions</th>\n",
|
|||
|
" <th>img_path</th>\n",
|
|||
|
" <th>level</th>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <th>player_location</th>\n",
|
|||
|
" <th>skull_location</th>\n",
|
|||
|
" <th>num_gaze_positions</th>\n",
|
|||
|
" <th>gaze_duration_ratio</th>\n",
|
|||
|
" <th>angular_gaze_displacement</th>\n",
|
|||
|
" <th>gaze_velocity</th>\n",
|
|||
|
" <th>max_gaze_velocity</th>\n",
|
|||
|
" <th>avg_gaze_velocity</th>\n",
|
|||
|
" <th>time_stamps</th>\n",
|
|||
|
" <th>current_subgoal</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>0</th>\n",
|
|||
|
" <td>284_RZ_5540489_E00</td>\n",
|
|||
|
" <td>RZ_5540489_1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>2817</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>[[80.4, 103.5], [80.34, 103.4], [80.34, 103.3]...</td>\n",
|
|||
|
" <td>/datasets/public/anna/montezuma_revenge/284_RZ...</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>[77, 85]</td>\n",
|
|||
|
" <td>[92, 175]</td>\n",
|
|||
|
" <td>2614.0</td>\n",
|
|||
|
" <td>0.927938</td>\n",
|
|||
|
" <td>[8.986148924666498, 0.00034557292731263054, 0....</td>\n",
|
|||
|
" <td>[96840.02112006705, 3.724096925170927, 2.97928...</td>\n",
|
|||
|
" <td>266.07</td>\n",
|
|||
|
" <td>95.051218</td>\n",
|
|||
|
" <td>[0.0, 0.9279375221867234, 1.8558750443734469, ...</td>\n",
|
|||
|
" <td>8</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1</th>\n",
|
|||
|
" <td>284_RZ_5540489_E00</td>\n",
|
|||
|
" <td>RZ_5540489_2</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>50</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>[[113.66, 98.28], [113.65, 98.42], [113.65, 98...</td>\n",
|
|||
|
" <td>/datasets/public/anna/montezuma_revenge/284_RZ...</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>[77, 85]</td>\n",
|
|||
|
" <td>[92, 175]</td>\n",
|
|||
|
" <td>50.0</td>\n",
|
|||
|
" <td>1.000000</td>\n",
|
|||
|
" <td>[8.986165402283532, 0.00044873289690316374, 0....</td>\n",
|
|||
|
" <td>[89861.65402283533, 4.487328969031638, 0.0, 3....</td>\n",
|
|||
|
" <td>114.40</td>\n",
|
|||
|
" <td>104.255400</td>\n",
|
|||
|
" <td>[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, ...</td>\n",
|
|||
|
" <td>8</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>2</th>\n",
|
|||
|
" <td>284_RZ_5540489_E00</td>\n",
|
|||
|
" <td>RZ_5540489_3</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>51</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>[[87.26, 100.72], [86.4, 100.8], [85.61, 100.8...</td>\n",
|
|||
|
" <td>/datasets/public/anna/montezuma_revenge/284_RZ...</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>[77, 85]</td>\n",
|
|||
|
" <td>[92, 175]</td>\n",
|
|||
|
" <td>51.0</td>\n",
|
|||
|
" <td>1.000000</td>\n",
|
|||
|
" <td>[8.986150521057317, 0.0029062899985003087, 0.0...</td>\n",
|
|||
|
" <td>[89861.50521057317, 29.06289998500309, 21.2663...</td>\n",
|
|||
|
" <td>102.50</td>\n",
|
|||
|
" <td>90.330686</td>\n",
|
|||
|
" <td>[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, ...</td>\n",
|
|||
|
" <td>8</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>3</th>\n",
|
|||
|
" <td>284_RZ_5540489_E00</td>\n",
|
|||
|
" <td>RZ_5540489_4</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>51</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>[[78.41, 101.5], [78.41, 101.5], [78.41, 101.6...</td>\n",
|
|||
|
" <td>/datasets/public/anna/montezuma_revenge/284_RZ...</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>[77, 85]</td>\n",
|
|||
|
" <td>[92, 175]</td>\n",
|
|||
|
" <td>51.0</td>\n",
|
|||
|
" <td>1.000000</td>\n",
|
|||
|
" <td>[8.986146988240504, 0.00041474113941382504, 0....</td>\n",
|
|||
|
" <td>[89861.46988240504, 4.14741139413825, 7.982843...</td>\n",
|
|||
|
" <td>102.88</td>\n",
|
|||
|
" <td>90.392941</td>\n",
|
|||
|
" <td>[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, ...</td>\n",
|
|||
|
" <td>8</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>4</th>\n",
|
|||
|
" <td>284_RZ_5540489_E00</td>\n",
|
|||
|
" <td>RZ_5540489_5</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>55</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>[[78.41, 102.85], [78.42, 102.95], [78.41, 102...</td>\n",
|
|||
|
" <td>/datasets/public/anna/montezuma_revenge/284_RZ...</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>[77, 85]</td>\n",
|
|||
|
" <td>[91, 175]</td>\n",
|
|||
|
" <td>55.0</td>\n",
|
|||
|
" <td>1.000000</td>\n",
|
|||
|
" <td>[8.986147775662209, 7.822521255579121e-05, 0.0...</td>\n",
|
|||
|
" <td>[89861.47775662209, 0.7822521255579121, 2.8601...</td>\n",
|
|||
|
" <td>103.92</td>\n",
|
|||
|
" <td>90.895364</td>\n",
|
|||
|
" <td>[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, ...</td>\n",
|
|||
|
" <td>8</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>5 rows × 21 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" ID frame_id episode_id score duration(ms) \\\n",
|
|||
|
"0 284_RZ_5540489_E00 RZ_5540489_1 0 0 2817 \n",
|
|||
|
"1 284_RZ_5540489_E00 RZ_5540489_2 0 0 50 \n",
|
|||
|
"2 284_RZ_5540489_E00 RZ_5540489_3 0 0 51 \n",
|
|||
|
"3 284_RZ_5540489_E00 RZ_5540489_4 0 0 51 \n",
|
|||
|
"4 284_RZ_5540489_E00 RZ_5540489_5 0 0 55 \n",
|
|||
|
"\n",
|
|||
|
" unclipped_reward action gaze_positions \\\n",
|
|||
|
"0 0 0 [[80.4, 103.5], [80.34, 103.4], [80.34, 103.3]... \n",
|
|||
|
"1 0 0 [[113.66, 98.28], [113.65, 98.42], [113.65, 98... \n",
|
|||
|
"2 0 0 [[87.26, 100.72], [86.4, 100.8], [85.61, 100.8... \n",
|
|||
|
"3 0 0 [[78.41, 101.5], [78.41, 101.5], [78.41, 101.6... \n",
|
|||
|
"4 0 0 [[78.41, 102.85], [78.42, 102.95], [78.41, 102... \n",
|
|||
|
"\n",
|
|||
|
" img_path level ... \\\n",
|
|||
|
"0 /datasets/public/anna/montezuma_revenge/284_RZ... 0 ... \n",
|
|||
|
"1 /datasets/public/anna/montezuma_revenge/284_RZ... 0 ... \n",
|
|||
|
"2 /datasets/public/anna/montezuma_revenge/284_RZ... 0 ... \n",
|
|||
|
"3 /datasets/public/anna/montezuma_revenge/284_RZ... 0 ... \n",
|
|||
|
"4 /datasets/public/anna/montezuma_revenge/284_RZ... 0 ... \n",
|
|||
|
"\n",
|
|||
|
" player_location skull_location num_gaze_positions gaze_duration_ratio \\\n",
|
|||
|
"0 [77, 85] [92, 175] 2614.0 0.927938 \n",
|
|||
|
"1 [77, 85] [92, 175] 50.0 1.000000 \n",
|
|||
|
"2 [77, 85] [92, 175] 51.0 1.000000 \n",
|
|||
|
"3 [77, 85] [92, 175] 51.0 1.000000 \n",
|
|||
|
"4 [77, 85] [91, 175] 55.0 1.000000 \n",
|
|||
|
"\n",
|
|||
|
" angular_gaze_displacement \\\n",
|
|||
|
"0 [8.986148924666498, 0.00034557292731263054, 0.... \n",
|
|||
|
"1 [8.986165402283532, 0.00044873289690316374, 0.... \n",
|
|||
|
"2 [8.986150521057317, 0.0029062899985003087, 0.0... \n",
|
|||
|
"3 [8.986146988240504, 0.00041474113941382504, 0.... \n",
|
|||
|
"4 [8.986147775662209, 7.822521255579121e-05, 0.0... \n",
|
|||
|
"\n",
|
|||
|
" gaze_velocity max_gaze_velocity \\\n",
|
|||
|
"0 [96840.02112006705, 3.724096925170927, 2.97928... 266.07 \n",
|
|||
|
"1 [89861.65402283533, 4.487328969031638, 0.0, 3.... 114.40 \n",
|
|||
|
"2 [89861.50521057317, 29.06289998500309, 21.2663... 102.50 \n",
|
|||
|
"3 [89861.46988240504, 4.14741139413825, 7.982843... 102.88 \n",
|
|||
|
"4 [89861.47775662209, 0.7822521255579121, 2.8601... 103.92 \n",
|
|||
|
"\n",
|
|||
|
" avg_gaze_velocity time_stamps \\\n",
|
|||
|
"0 95.051218 [0.0, 0.9279375221867234, 1.8558750443734469, ... \n",
|
|||
|
"1 104.255400 [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, ... \n",
|
|||
|
"2 90.330686 [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, ... \n",
|
|||
|
"3 90.392941 [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, ... \n",
|
|||
|
"4 90.895364 [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, ... \n",
|
|||
|
"\n",
|
|||
|
" current_subgoal \n",
|
|||
|
"0 8 \n",
|
|||
|
"1 8 \n",
|
|||
|
"2 8 \n",
|
|||
|
"3 8 \n",
|
|||
|
"4 8 \n",
|
|||
|
"\n",
|
|||
|
"[5 rows x 21 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 14,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"%%time\n",
"# Replay each episode's recorded actions through the emulator and label every\n",
"# replayed frame with the subgoal the agent reaches next (None if it reaches none).\n",
"df['current_subgoal'] = None\n",
"\n",
"for episode in df.ID.unique():\n",
"\n",
"    env.reset()\n",
"\n",
"    # Only frames recorded in room 1 of level 0 are replayed and labeled.\n",
"    episode_mask = (df.ID == episode) & (df.room_id == 1) & (df.level == 0)\n",
"    valid_actions_idx = df.loc[episode_mask].index\n",
"    if len(valid_actions_idx) == 0:\n",
"        continue\n",
"\n",
"    # valid_actions_idx holds index *labels*, so the actions must be selected with\n",
"    # .loc; positional .iloc is only equivalent while df has a default 0..n-1 index.\n",
"    actions = df.loc[valid_actions_idx, 'action'].values\n",
"\n",
"    # NOTE(review): `done` is ignored and frames outside room 1 / level 0 are skipped,\n",
"    # so the replay presumably relies on these frames being one contiguous run from\n",
"    # the start of the episode - confirm.\n",
"    subgoal_anno = []\n",
"    for i, action in enumerate(actions):\n",
"\n",
"        n_state, reward, done, info = env.step(action)\n",
"\n",
"        # Agent position; y is mirrored with the constant 320\n",
"        # (presumably to match the SUBGOALS coordinate frame - TODO confirm).\n",
"        mean_x, mean_y = info['labels']['player_x'], 320 - info['labels']['player_y']\n",
"\n",
"        check, goal_idx = agent_in_subgoal(SUBGOALS, mean_x, mean_y)\n",
"        if check:\n",
"            # Every still-unlabeled frame before frame i was heading towards this subgoal.\n",
"            subgoal_anno += [goal_idx[0]] * (i - len(subgoal_anno))\n",
"\n",
"    # Frames after the last reached subgoal (including the final frame) stay unlabeled.\n",
"    subgoal_anno += [None] * (len(actions) - len(subgoal_anno))\n",
"\n",
"    assert len(valid_actions_idx) == len(subgoal_anno), f'{episode, i} Number of actions: {len(valid_actions_idx)} does not match length of subgoal annotation: { len(subgoal_anno)}'\n",
"\n",
"    df.loc[valid_actions_idx, 'current_subgoal'] = subgoal_anno\n",
"\n",
"df.head()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 15,
|
|||
|
"id": "0374a999",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"(37879, 2)\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAM4AAAD8CAYAAAA/rZtiAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAABIVklEQVR4nO29f4xk2XXf97lbVd3V1dXdtV293bPzg5zZH1yKIq2VuKCIyFHsMHYkIgmtIGDIPyRKIbwSQCI2oCAhJSNRbBiQHVFCggSEKYgQFcikFdCyCIGJRTMObAGmrKW8En9ptbPcWc/sbPds92zNdk11dXfV3Pxx36l33qn7qqp/TVcP3xd4eK9evR/3vXe/9/y4557rvPcUKFDgYHjotAtQoMBZREGcAgUOgYI4BQocAgVxChQ4BAriFChwCBTEKVDgEDgx4jjnfsw594Jz7qpz7hMndZ8CBU4D7iT6cZxzJeAvgL8G3AD+GPiw9/7bx36zAgVOASclcd4DXPXef9d7vwd8AfjACd2rQIH7jvIJXfcCcF39vgH8cN7BztU8NE6oKAUKHBavbXrvH4n9c1LEGQvn3LPAs+HXUrpZoMDU4H9+Je+fk1LVXgUuqd8Xk30DeO8/471/xnv/DNROqBgFCpwMToo4fww86Zy74pybAT4EfOmE7lWgwH3Hiahq3vuec+7jwD8HSsBnvfffOol7FShwGjgxG8d7/2Xgyyd1/QIFThNF5ECBAodAQZwCBQ6BgjgFChwCBXEKFDgECuIUKHAIFMQpUOAQKIhToMAhUBCnQIFDoCBOgQKHQEGcAgUOgYI4BQocAgVxChQ4BAriFChwCJzaCNACJ4XKAY/fP5FSPOgoiPPAQBNm0s/aM+cVJJoUBXHOPGKEqZjfFr2cfRUK8kyGgjhnFpYwlZztGHrq/57ZL+cWBBqFQzsHnHOXnHP/0jn3befct5xzfyvZ/0vOuVedc88ny/uPr7gFAjQxhChlYE4tC8AisGyWxeQ/fay+hpVaBWI4isTpAT/vvf8T59wC8HXn3FeS/37Ne/8rRy9egWFYyWKJUzFrORbCJ9tX6x217pGVQIXqNgqHJo73/jXgtWR72zn3HUIiwgInhjzSzJllkWFJAlnC6KWSrC0K8uThWGwc59xl4AeBPwJ+BPi4c+6ngOcIUumN47hPAYiTZoGQm07UM1nPMUwcIcubybIDbI+4X8yRUODIHaDOuTrwReBve+/fBD4NPA48TZBIn8o571nn3HPOueegc9RifA/A2hyaNGLLrBGE/kWgCeUa1CvQcGGpVpJj5bi15DxNtLmcexXQOJLEcc5VCKT5be/9PwXw3m+o/38d+P3Yud77zwCfCcedL6a+HomYiiY2TY2sE6AJVUIq7jphW75yN1laQKvGMEn21VpXjULqWByaOM45B/wG8B3v/a+q/Y8m9g/ATwDfPFoRC2QRs2tEYiSkOQesEIhTJ0ucNuGYKrDuCGQTh4E4CyrJb02YwtbROIrE+RHgJ4FvOOeeT/b9AvBh59zTgAeuAT97hHsUGCDWV6MJtBx2r5hFSNIjlTaaSC1R33aS6+yQSrQyBVniOIpX7Q8BF/mryN55rIjZGzF1rRIIUieoabJoidNO1j1SadQGetoDJ4t42UT6FNAoIgfONEyEgEiXstrWNk4vsn/QdaO9bwXGoRhWMNUYFzpzANjImp7dv08hWSZH0cTcVxzVrWs/l9gfSYXvmqWtTuupffqYnlxHk2Y/Z7uAoCDOfUEeYbQONck1dBCmnLNP6Afbh3YlGP+ihglZ9HZLLUIstkk7RqVPzXrVCmgUxDlRTDpG5qCfQfpZdNjMm0Az9ZoJUSyJ2oRjNpNt3iSQRa4jhLGSppA8GgVxTgwx+ySPSLGWPa+i6s5JiTvbZuAN6y7COtk+G4HsE3WNDrAF3CaVNnJNuX4hdWIoiHPsyBtYNipc346BsSMzLUQqSIBmhSA5kmv0FmGzkk
qfsjplYNNIrNptAvEkbk06QfNsngJQEOeYMW5w2SjyiCSx7q8YtNSx0kEkxlzon+nZkBpR60RSvWn2aVWtkDZ5KIhzIoiNlYmN0BRY0owikB03o68hUmib7LCC2DHaGaDtGytxCmkTQ0GcY8OoAWZ5JBLMkVZWS4pRkkhXcDlGpIqVcpCVJB2yRNlR/+t7F4ihIM6xIM8RoAMyy4TQGKuu7avf4xwCMckj52ji6Ptb2FGgsbUuV4EYCuIcO6wto0kTU5+0tIlhlJ2hu/+tqieOAw3TYZohjN1fYBQK4hwbYtJGFiFNjDhyrkQnQ7Yil80+u62hiQPx4dD6OEuYUdcuoFEQ58iomG0tbWwkc8zOsZVX/rOqGOZ4izxijUr3VBDmsCiIcyKwRNGLzTyjzzlI5Z3UVTzuuIIwh0FBnPuGcs72YXDY/pWCJMeFgjj3DTZAUxvz4xwD1ojXKMhwGjgycZxz1wg9bn2g571/xjm3DPwT4DJh+PQHv/dSROkeeK26Qdbg125gfU7McLe/C9KcFo5rINtf9d4/7b1/Jvn9CeCr3vsnga8mv7+HYCu3JYXuuY/9tgQqopWnDSc1AvQDwOeS7c8Bf+OE7jNlsB2INn7MDgOQ8BhLon1zLUHRzzItOA4bxwN/4JzzwD9K8qWtqRRR64TMdw8odDol2+eyY46TY2NhMDZFU16PfoFpwHEQ5y977191zq0CX3HO/bn+03vvE1Jl4Jx7Fng2/Fo6hmJMC6w0sAnN8/pTRoXBjLp+gdPAkYnjvX81Wd9yzv0u8B5gQxITOuceBW5FznsAM3lqqROTEPp/zHGx8JciFGZacSQbxzk3n0zxgXNuHvjrhMydXwI+khz2EeD3jnKfswdb+W0of94SU9cK0kwjjipx1oDfDdlwKQP/2Hv//zjn/hj4HefcR4FXgA8e8T5TDrFdIBsqo1PJxmLOUOfY/wq38zTjSMTx3n8X+IHI/i3gfUe59tmGjTMblQ3TkqIgzFlAETlwYhgnZeyxGgVhph0FcU4UB3UhF4Q5KyiIMxUoCHPWUBDn1FCQ5SyjIM59R0GYBwEFce4bCsI8SCiIc+IoCPMgopgf50RRkOZBRUGcE0NBmgcZhap2bCiI8r2EQuIUKHAIFMQpUOAQKIhToMAhUBCnQIFDoCBOgQKHwJQQx3H0qcwLFLh/mBJ3tEvWdvawAgWmE4cmjnPuKUK2TsFjwP8INIC/Cbye7P8F7/2Xx1yNdJ4YyKaLLQhUYPpwaOJ4718AngZwzpWAV4HfBX4G+DXv/a9MfrUS6dwxQpp9potAeRPeFvhexHGpau8DXvLev5Ik7jggSsAC2awwZfN71FR/x41J7a37WaZxyJt9rcBJ4LicAx8CPq9+f9w592fOuc865x4ef3qFkDBnGVhkePYyO9fMccPOYXMWMK7MZ/GZzg6OTBzn3AzwXwD/V7Lr08DjBDXuNeBTOec965x7zjn3HNwlTGxwgSyBNInsbGfHgeOoVPe7UuoylydYYucVOCqOQ1X7ceBPvPcbALIGcM79OvD7sZOymTyf8fAEcDtZ5ghJyd8knR9T52E+qup2FiuQLnNeAzIqi6jsnxab8WzjOIjzYZSaJqlvk58/QcjsORo1YNnBehN6y4SMuRuEj/ymOvAo5BlFlklfw2kkPh9FmDwCaXLYad7h/tuMDx6ORJwk7e1fA35W7f6HzrmnCbMYXDP/xVEHnkmOvuFgc41hVUMQI89hELt23rWkguW16CdRCccRJqa62vJZLyVkyVRIn8PiqJk87wJNs+8nD3yhBeC9hB6gOnCVIH2A4aTko2YoG4dJyBJLim4rWKwVP07kkcbOWm0JpNEj7RvTZbVlL6TPYTAVkQOu3sc/A1RJv2sP2GySnXDJZvA/rLQZRRYbvRB7RTEV6KiVL69MowiT5zTRMx/octoyF9LnsJgK4syVO9x75jbd3jK0yS7dZbLZ/IVE9oMfFHlksbNDa4lnVZ/jIM84wu
htIUuZYBiWGSYPDM+vI+rtKEkp+wvpMwmmgjhVdllbusZ3nngYNl2Yw20F2ATWa6SVQxYRScfVyueRKA/jyMOYco2
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 432x288 with 1 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {
|
|||
|
"needs_background": "light"
|
|||
|
},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from scipy.ndimage import gaussian_filter\n",
"\n",
"# Per-axis Gaussian width (rows, cols); same value as the configuration cell at the top.\n",
"SIGMA = (210 / 44.6, 160 / 28.5)\n",
"\n",
"episode = '284_RZ_5540489_E00'\n",
"episode_mask = (df.ID == episode) & (df.room_id == 1) & (df.level == 0)\n",
"gaze = df.loc[episode_mask, 'gaze_positions']\n",
"\n",
"# Flatten the per-frame gaze lists, skipping frames whose entry is None.\n",
"flat_list = [point for gaze_points in gaze if gaze_points is not None for point in gaze_points]\n",
"\n",
"# One (x, y) pair per row; reshape keeps the array 2-D even when there are no points.\n",
"gaze_xy = np.asarray(flat_list, dtype=float).reshape(-1, 2)\n",
"print(gaze_xy.shape)\n",
"\n",
"saliency_map = np.zeros(init_screen.shape[:2])\n",
"height, width = saliency_map.shape\n",
"\n",
"# Keep only gaze points that fall on the image. Negative coordinates must be\n",
"# rejected explicitly: NumPy would otherwise wrap them around and count them on\n",
"# the opposite edge of the map. NaN coordinates fail every comparison and are dropped.\n",
"on_screen = (\n",
"    (gaze_xy[:, 0] >= 0) & (gaze_xy[:, 0] < width)\n",
"    & (gaze_xy[:, 1] >= 0) & (gaze_xy[:, 1] < height)\n",
")\n",
"cols = gaze_xy[on_screen, 0].astype(int)\n",
"rows = gaze_xy[on_screen, 1].astype(int)\n",
"\n",
"# np.add.at accumulates correctly when the same pixel is hit more than once.\n",
"np.add.at(saliency_map, (rows, cols), 1)\n",
"\n",
"# Construct empirical saliency map\n",
"saliency_map = gaussian_filter(saliency_map, sigma=SIGMA, mode='nearest')\n",
"\n",
"plt.imshow(saliency_map, cmap='jet')\n",
"plt.show()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 16,
|
|||
|
"id": "c6295363",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Saved dataframe to /datasets/public/anna/montezuma_revenge/all_trials_labeled.pkl\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Persist the dataframe, including the new current_subgoal column, next to the raw data.\n",
"path = os.path.join(DATA_PATH, 'all_trials_labeled.pkl')\n",
"df.to_pickle(path)\n",
"print('Saved dataframe to {}'.format(path))"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 17,
|
|||
|
"id": "11aebc9c",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"284_RZ_5540489_E00 [8 6 1 0 2 9 None]\n",
|
|||
|
"285_RZ_5619207_E00 [8 6 1 0 2 7 9 None]\n",
|
|||
|
"285_RZ_5619207_E01 [8 6 1 0 2 9 None]\n",
|
|||
|
"291_RZ_7364933_E00 [8 6 1 0 2 7 9 None]\n",
|
|||
|
"333_RZ_900705_E00 [8 6 1 0 2 9 None]\n",
|
|||
|
"340_RZ_1323550_E00 [8 6 1 0 2 7 9 None]\n",
|
|||
|
"359_RZ_1993616_E00 [8 6 1 0 2 7 9 None]\n",
|
|||
|
"365_RZ_2079996_E00 [8 6 1 0 2 7 9 None]\n",
|
|||
|
"371_RZ_2173469_E00 [8 6 1 0 2 7 9 None]\n",
|
|||
|
"385_RZ_2344725_E00 [8 6 1 0 2 7 9 None]\n",
|
|||
|
"398_RZ_2530473_E00 [8 6 1 0 2 9 None]\n",
|
|||
|
"402_RZ_2603283_E00 [8 6 1 0 2 9 None]\n",
|
|||
|
"416_RZ_2788252_E00 [8 6 1 0 2 7 9 None]\n",
|
|||
|
"429_RZ_2945490_E00 [8 6 1 0 2 7 9 None]\n",
|
|||
|
"436_RZ_3131841_E00 [8 6 1 0 2 7 9 None]\n",
|
|||
|
"459_RZ_3291266_E00 [8 6 1 0 2 7 9 None]\n",
|
|||
|
"469_RZ_3390904_E00 [8 6 1 0 2 7 9 None]\n",
|
|||
|
"480_RZ_3470098_E00 [8 6 1 0 2 7 9 None]\n",
|
|||
|
"493_RZ_3557734_E00 [8 6 1 0 2 7 9 None]\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# For each episode, collect the distinct subgoal labels of its room-1 / level-0 frames\n",
"# in order of first appearance (Series.unique preserves that order).\n",
"all_orders = []\n",
"for episode in df.ID.unique():\n",
"    episode_mask = (df.ID == episode) & (df.room_id == 1) & (df.level == 0)\n",
"    subgoal_order = df.loc[episode_mask, 'current_subgoal'].unique()\n",
"    if len(subgoal_order) == 0:\n",
"        # Episode has no frames in room 1 of level 0.\n",
"        continue\n",
"    all_orders.append(subgoal_order)\n",
"    print(episode, subgoal_order)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"id": "34c2efb3",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## Get majority vote of subgoal order"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 18,
|
|||
|
"id": "6c8f80d7",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"[8, 6, 1, 0, 2, 7, 9, None]"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 18,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from collections import Counter\n",
"\n",
"# Position-wise majority vote: for every position i, each episode that has an i-th\n",
"# entry votes with it, and the most common value wins (ties go to the value seen first).\n",
"max_len = max(len(order) for order in all_orders)\n",
"majority_order = [\n",
"    Counter(order[i] for order in all_orders if i < len(order)).most_common(1)[0][0]\n",
"    for i in range(max_len)\n",
"]\n",
"\n",
"majority_order"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 19,
|
|||
|
"id": "5649bb15",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# None marks frames that were not assigned any subgoal; drop it so that only\n",
"# visited subgoals remain, as stated in the notebook introduction.\n",
"subgoal_order = [goal for goal in majority_order if goal is not None]\n",
"subgoal_order"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"id": "21af5234",
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": []
|
|||
|
}
|
|||
|
],
|
|||
|
"metadata": {
|
|||
|
"kernelspec": {
|
|||
|
"display_name": "msc_env",
|
|||
|
"language": "python",
|
|||
|
"name": "msc_env"
|
|||
|
},
|
|||
|
"language_info": {
|
|||
|
"codemirror_mode": {
|
|||
|
"name": "ipython",
|
|||
|
"version": 3
|
|||
|
},
|
|||
|
"file_extension": ".py",
|
|||
|
"mimetype": "text/x-python",
|
|||
|
"name": "python",
|
|||
|
"nbconvert_exporter": "python",
|
|||
|
"pygments_lexer": "ipython3",
|
|||
|
"version": "3.9.12"
|
|||
|
}
|
|||
|
},
|
|||
|
"nbformat": 4,
|
|||
|
"nbformat_minor": 5
|
|||
|
}
|