{ "cells": [ { "cell_type": "markdown", "id": "909118c1", "metadata": {}, "source": [ "# Label Subgoals\n", "- Label each frame with its current subgoal\n", "- Extract the order in which subgoals are visited\n", "- Keep only visited subgoals as true subgoals\n", "\n", "Result: `subgoal_order = [8, 6, 1, 0, 2, 7, 9]`\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "b382d001", "metadata": {}, "outputs": [], "source": [ "import os\n", "import random\n", "\n", "import cv2\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "\n", "\n", "# NOTE(review): ratio of frame size in pixels (210 x 160) to screen extent (44.6 x 28.5);\n", "# presumably pixels per degree of visual angle - TODO confirm units and document where SIGMA is used.\n", "SIGMA = (210 / 44.6, 160 / 28.5)\n", "\n", "# Seed the stdlib and NumPy global RNGs so any random sampling in later cells is reproducible.\n", "SEED = 0\n", "random.seed(SEED)\n", "np.random.seed(SEED)\n", "\n", "# Expose a single GPU to this kernel; CUDA reads this variable only when it initializes,\n", "# so it must be set before any CUDA-backed library creates a context.\n", "GPU_DEVICE = 3\n", "os.environ['CUDA_VISIBLE_DEVICES'] = str(GPU_DEVICE)\n", "\n", "# Dataset root; can be overridden via the MONTEZUMA_DATA_PATH environment variable.\n", "DATA_PATH = os.environ.get('MONTEZUMA_DATA_PATH', '/datasets/public/anna/montezuma_revenge')" ] }, { "cell_type": "markdown", "id": "f7c3ec86", "metadata": {}, "source": [ "### Load labeled data and extracted subgoals" ] }, { "cell_type": "code", "execution_count": 2, "id": "21935cab", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 11.2 s, sys: 3.94 s, total: 15.2 s\n", "Wall time: 15.2 s\n" ] }, { "data": { "text/html": [ "
\n", " | ID | \n", "frame_id | \n", "episode_id | \n", "score | \n", "duration(ms) | \n", "unclipped_reward | \n", "action | \n", "gaze_positions | \n", "img_path | \n", "level | \n", "... | \n", "player_location | \n", "skull_location | \n", "num_gaze_positions | \n", "gaze_duration_ratio | \n", "angular_gaze_displacement | \n", "gaze_velocity | \n", "max_gaze_velocity | \n", "avg_gaze_velocity | \n", "time_stamps | \n", "current_subgoal | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "284_RZ_5540489_E00 | \n", "RZ_5540489_1 | \n", "0 | \n", "0 | \n", "2817 | \n", "0 | \n", "0 | \n", "[[80.4, 103.5], [80.34, 103.4], [80.34, 103.3]... | \n", "/datasets/public/anna/montezuma_revenge/284_RZ... | \n", "0 | \n", "... | \n", "[77, 85] | \n", "[92, 175] | \n", "2614.0 | \n", "0.927938 | \n", "[8.986148924666498, 0.00034557292731263054, 0.... | \n", "[96840.02112006705, 3.724096925170927, 2.97928... | \n", "266.07 | \n", "95.051218 | \n", "[0.0, 0.9279375221867234, 1.8558750443734469, ... | \n", "8 | \n", "
1 | \n", "284_RZ_5540489_E00 | \n", "RZ_5540489_2 | \n", "0 | \n", "0 | \n", "50 | \n", "0 | \n", "0 | \n", "[[113.66, 98.28], [113.65, 98.42], [113.65, 98... | \n", "/datasets/public/anna/montezuma_revenge/284_RZ... | \n", "0 | \n", "... | \n", "[77, 85] | \n", "[92, 175] | \n", "50.0 | \n", "1.000000 | \n", "[8.986165402283532, 0.00044873289690316374, 0.... | \n", "[89861.65402283533, 4.487328969031638, 0.0, 3.... | \n", "114.40 | \n", "104.255400 | \n", "[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, ... | \n", "8 | \n", "
2 | \n", "284_RZ_5540489_E00 | \n", "RZ_5540489_3 | \n", "0 | \n", "0 | \n", "51 | \n", "0 | \n", "0 | \n", "[[87.26, 100.72], [86.4, 100.8], [85.61, 100.8... | \n", "/datasets/public/anna/montezuma_revenge/284_RZ... | \n", "0 | \n", "... | \n", "[77, 85] | \n", "[92, 175] | \n", "51.0 | \n", "1.000000 | \n", "[8.986150521057317, 0.0029062899985003087, 0.0... | \n", "[89861.50521057317, 29.06289998500309, 21.2663... | \n", "102.50 | \n", "90.330686 | \n", "[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, ... | \n", "8 | \n", "
3 | \n", "284_RZ_5540489_E00 | \n", "RZ_5540489_4 | \n", "0 | \n", "0 | \n", "51 | \n", "0 | \n", "0 | \n", "[[78.41, 101.5], [78.41, 101.5], [78.41, 101.6... | \n", "/datasets/public/anna/montezuma_revenge/284_RZ... | \n", "0 | \n", "... | \n", "[77, 85] | \n", "[92, 175] | \n", "51.0 | \n", "1.000000 | \n", "[8.986146988240504, 0.00041474113941382504, 0.... | \n", "[89861.46988240504, 4.14741139413825, 7.982843... | \n", "102.88 | \n", "90.392941 | \n", "[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, ... | \n", "8 | \n", "
4 | \n", "284_RZ_5540489_E00 | \n", "RZ_5540489_5 | \n", "0 | \n", "0 | \n", "55 | \n", "0 | \n", "0 | \n", "[[78.41, 102.85], [78.42, 102.95], [78.41, 102... | \n", "/datasets/public/anna/montezuma_revenge/284_RZ... | \n", "0 | \n", "... | \n", "[77, 85] | \n", "[91, 175] | \n", "55.0 | \n", "1.000000 | \n", "[8.986147775662209, 7.822521255579121e-05, 0.0... | \n", "[89861.47775662209, 0.7822521255579121, 2.8601... | \n", "103.92 | \n", "90.895364 | \n", "[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, ... | \n", "8 | \n", "
5 rows × 21 columns
\n", "\n", " | ID | \n", "frame_id | \n", "episode_id | \n", "score | \n", "duration(ms) | \n", "unclipped_reward | \n", "action | \n", "gaze_positions | \n", "img_path | \n", "level | \n", "... | \n", "player_location | \n", "skull_location | \n", "num_gaze_positions | \n", "gaze_duration_ratio | \n", "angular_gaze_displacement | \n", "gaze_velocity | \n", "max_gaze_velocity | \n", "avg_gaze_velocity | \n", "time_stamps | \n", "current_subgoal | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "284_RZ_5540489_E00 | \n", "RZ_5540489_1 | \n", "0 | \n", "0 | \n", "2817 | \n", "0 | \n", "0 | \n", "[[80.4, 103.5], [80.34, 103.4], [80.34, 103.3]... | \n", "/datasets/public/anna/montezuma_revenge/284_RZ... | \n", "0 | \n", "... | \n", "[77, 85] | \n", "[92, 175] | \n", "2614.0 | \n", "0.927938 | \n", "[8.986148924666498, 0.00034557292731263054, 0.... | \n", "[96840.02112006705, 3.724096925170927, 2.97928... | \n", "266.07 | \n", "95.051218 | \n", "[0.0, 0.9279375221867234, 1.8558750443734469, ... | \n", "8 | \n", "
1 | \n", "284_RZ_5540489_E00 | \n", "RZ_5540489_2 | \n", "0 | \n", "0 | \n", "50 | \n", "0 | \n", "0 | \n", "[[113.66, 98.28], [113.65, 98.42], [113.65, 98... | \n", "/datasets/public/anna/montezuma_revenge/284_RZ... | \n", "0 | \n", "... | \n", "[77, 85] | \n", "[92, 175] | \n", "50.0 | \n", "1.000000 | \n", "[8.986165402283532, 0.00044873289690316374, 0.... | \n", "[89861.65402283533, 4.487328969031638, 0.0, 3.... | \n", "114.40 | \n", "104.255400 | \n", "[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, ... | \n", "8 | \n", "
2 | \n", "284_RZ_5540489_E00 | \n", "RZ_5540489_3 | \n", "0 | \n", "0 | \n", "51 | \n", "0 | \n", "0 | \n", "[[87.26, 100.72], [86.4, 100.8], [85.61, 100.8... | \n", "/datasets/public/anna/montezuma_revenge/284_RZ... | \n", "0 | \n", "... | \n", "[77, 85] | \n", "[92, 175] | \n", "51.0 | \n", "1.000000 | \n", "[8.986150521057317, 0.0029062899985003087, 0.0... | \n", "[89861.50521057317, 29.06289998500309, 21.2663... | \n", "102.50 | \n", "90.330686 | \n", "[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, ... | \n", "8 | \n", "
3 | \n", "284_RZ_5540489_E00 | \n", "RZ_5540489_4 | \n", "0 | \n", "0 | \n", "51 | \n", "0 | \n", "0 | \n", "[[78.41, 101.5], [78.41, 101.5], [78.41, 101.6... | \n", "/datasets/public/anna/montezuma_revenge/284_RZ... | \n", "0 | \n", "... | \n", "[77, 85] | \n", "[92, 175] | \n", "51.0 | \n", "1.000000 | \n", "[8.986146988240504, 0.00041474113941382504, 0.... | \n", "[89861.46988240504, 4.14741139413825, 7.982843... | \n", "102.88 | \n", "90.392941 | \n", "[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, ... | \n", "8 | \n", "
4 | \n", "284_RZ_5540489_E00 | \n", "RZ_5540489_5 | \n", "0 | \n", "0 | \n", "55 | \n", "0 | \n", "0 | \n", "[[78.41, 102.85], [78.42, 102.95], [78.41, 102... | \n", "/datasets/public/anna/montezuma_revenge/284_RZ... | \n", "0 | \n", "... | \n", "[77, 85] | \n", "[91, 175] | \n", "55.0 | \n", "1.000000 | \n", "[8.986147775662209, 7.822521255579121e-05, 0.0... | \n", "[89861.47775662209, 0.7822521255579121, 2.8601... | \n", "103.92 | \n", "90.895364 | \n", "[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, ... | \n", "8 | \n", "
5 rows × 21 columns
\n", "