add subgoal generation pipeline
This commit is contained in:
parent
1604589c5f
commit
08fb70d368
6 changed files with 2304 additions and 0 deletions
200
Preprocess_AtariHEAD.ipynb
Normal file
200
Preprocess_AtariHEAD.ipynb
Normal file
|
@ -0,0 +1,200 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d094257c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import numpy as np \n",
|
||||
"import cv2 \n",
|
||||
"import pandas as pd\n",
|
||||
"import matplotlib.pyplot as plt \n",
|
||||
"import subprocess\n",
|
||||
"import warnings\n",
|
||||
"\n",
|
||||
"import dataset_utils as utils"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "dbabd791",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"\n",
|
||||
"### Read all Montezuma's Revenge trials"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9e42d862",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"DATA_PATH = 'montezuma_revenge/'\n",
|
||||
"\n",
|
||||
"df = pd.read_csv(os.path.join(DATA_PATH, 'meta_data.csv'))\n",
|
||||
"df = df.loc[df.GameName.str.contains('montezuma_revenge')]\n",
|
||||
"df.sort_values(by=['trial_id'], inplace=True, ascending=True)\n",
|
||||
"df.reset_index(drop=True, inplace=True)\n",
|
||||
"df.head()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e42160ea",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Get folder names of each trial"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a9175b8f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"file_lst = [os.path.join(root, name) for root, dirs, files in os.walk(DATA_PATH) for name in files if 'tar.bz2' in name]\n",
|
||||
"\n",
|
||||
"folder_lst = [f.split('.')[0].split('/')[-1] for f in file_lst]\n",
|
||||
"folder_lst.sort(key=lambda x: int(str(x).split('_')[0]), reverse=False)\n",
|
||||
"\n",
|
||||
"df['trial_folder'] = folder_lst\n",
|
||||
"\n",
|
||||
"df.to_pickle(os.path.join(DATA_PATH, \"all_trials_summary.pkl\")) \n",
|
||||
"df.head()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "28f38343",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Unpack all folders"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c83443d7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for i in df.trial_id:\n",
|
||||
" file = [f for f in file_lst if str(i) + '_' in f]\n",
|
||||
" print(i, *file)\n",
|
||||
" cmd = f'tar -jxf {file[0]} --directory {DATA_PATH}'\n",
|
||||
" subprocess.call(cmd, shell=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b65c1efb",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Genarate Dataframe with all Trials"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "24ec2a88",
|
||||
"metadata": {
|
||||
"scrolled": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%%time\n",
|
||||
"def write_unique_id(frame_id, episode_id, trial_id):\n",
|
||||
" if not pd.isna(episode_id) and not pd.isna(frame_id):\n",
|
||||
" unique_id = str(trial_id)+ '_' + '_'.join(frame_id.split('_')[:2]) + '_E{:01d}'.format(int(episode_id))\n",
|
||||
" elif not pd.isna(frame_id):\n",
|
||||
" unique_id = str(trial_id)+ '_' + '_'.join(frame_id.split('_')[:2]) + '_E0'\n",
|
||||
" else: \n",
|
||||
" unique_id = None\n",
|
||||
" return unique_id\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"path = os.path.join(DATA_PATH, df.iloc[0].trial_folder)\n",
|
||||
"print(path) \n",
|
||||
" \n",
|
||||
"# Read Annotations\n",
|
||||
"trial_df = utils.txt_to_dataframe(path + '.txt') \n",
|
||||
"\n",
|
||||
"# Write unique ID\n",
|
||||
"trial_df['ID'] = trial_df.apply(lambda x: write_unique_id(x['frame_id'], x['episode_id'], df.iloc[0].trial_id), axis=1)\n",
|
||||
"\n",
|
||||
"# Write image paths\n",
|
||||
"trial_df['img_path'] = trial_df.apply(lambda x: os.path.join(path, str(x['frame_id']) + '.png'), axis=1)\n",
|
||||
"\n",
|
||||
"# Reorder columns\n",
|
||||
"cols = ['ID'] + [c for c in trial_df.columns.tolist() if not c=='ID'] \n",
|
||||
"trial_df = trial_df[cols]\n",
|
||||
"\n",
|
||||
"# Cut frames without annotations\n",
|
||||
"trial_df = trial_df[trial_df.ID.notnull()] \n",
|
||||
"\n",
|
||||
"print(f'Episodes: {trial_df.ID.unique()}\\n')\n",
|
||||
"\n",
|
||||
"full_df = trial_df.copy()\n",
|
||||
"\n",
|
||||
"for idx in df.index[1:]:\n",
|
||||
" row = df.iloc[idx]\n",
|
||||
" if row.GameName == 'montezuma_revenge':\n",
|
||||
" path = os.path.join(DATA_PATH, row.trial_folder)\n",
|
||||
" elif row.GameName == 'montezuma_revenge_highscore':\n",
|
||||
" path = os.path.join(DATA_PATH, 'highscore', row.trial_folder)\n",
|
||||
" else: \n",
|
||||
" path = ''\n",
|
||||
" warnings.warn(f\"GameName of row {idx} not recognised! Returning empty path.\")\n",
|
||||
" print(f'Reading {path}')\n",
|
||||
" \n",
|
||||
" # Read Annotations\n",
|
||||
" trial_df = utils.txt_to_dataframe(path + '.txt') \n",
|
||||
" \n",
|
||||
" # Write unique ID \n",
|
||||
" trial_df['ID'] = trial_df.apply(lambda x: write_unique_id(x['frame_id'], x['episode_id'], row.trial_id), axis=1)\n",
|
||||
" \n",
|
||||
" # Write image paths\n",
|
||||
" trial_df['img_path'] = trial_df.apply(lambda x: os.path.join(path, str(x['frame_id']) + '.png'), axis=1)\n",
|
||||
"\n",
|
||||
" # Cut frames without annotations\n",
|
||||
" trial_df = trial_df[trial_df.ID.notnull()] \n",
|
||||
"\n",
|
||||
" print(f'Episodes: {trial_df.ID.unique()}\\n')\n",
|
||||
" full_df = pd.concat([full_df, trial_df], join='inner', ignore_index=True)\n",
|
||||
"\n",
|
||||
"outpath = os.path.join(DATA_PATH, \"all_trials.pkl\")\n",
|
||||
"print(f'Saving dataframe to {outpath}\\n')\n",
|
||||
"\n",
|
||||
"full_df.to_pickle(outpath)\n",
|
||||
"full_df.head()"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.13"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
262
RAMStateLabeling.ipynb
Normal file
262
RAMStateLabeling.ipynb
Normal file
|
@ -0,0 +1,262 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e5e28b09",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Get RAM state of Montezuma's Revenge"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9660bf17",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import random\n",
|
||||
"import cv2\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd \n",
|
||||
"import matplotlib.pyplot as plt \n",
|
||||
"import gym\n",
|
||||
"\n",
|
||||
"from atariari.benchmark.wrapper import AtariARIWrapper\n",
|
||||
"from utils import visualize_sample\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"DATA_PATH = 'montezuma_revenge'\n",
|
||||
"\n",
|
||||
"df = pd.read_pickle(os.path.join(DATA_PATH, \"all_trials.pkl\"))\n",
|
||||
"df.head()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4417d418",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Use AtariARI Wrapper to extract RAM state \n",
|
||||
"```\n",
|
||||
"labels: {'room_number': 15,\n",
|
||||
"'player_x': 46,\n",
|
||||
"'player_y': 235,\n",
|
||||
"'player_direction': 76,\n",
|
||||
"'enemy_skull_x': 58,\n",
|
||||
"'enemy_skull_y': 240,\n",
|
||||
"'key_monster_x': 132,\n",
|
||||
"'key_monster_y': 254,\n",
|
||||
"'level': 0,\n",
|
||||
"'num_lives': 1,\n",
|
||||
"'items_in_inventory_count': 0,\n",
|
||||
"'room_state': 10,\n",
|
||||
"'score_0': 1,\n",
|
||||
"'score_1': 8,\n",
|
||||
"'score_2': 0}```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "19450d3e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env = AtariARIWrapper(gym.make('MontezumaRevenge-v4', \n",
|
||||
" frameskip=1, \n",
|
||||
" render_mode='rgb_array', \n",
|
||||
" repeat_action_probability=0.0))\n",
|
||||
"\n",
|
||||
"#env.unwrapped.ale.getRAM()\n",
|
||||
"obs = env.reset(seed=42)\n",
|
||||
"obs, reward, done, info = env.step(1)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "cc13394b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Visualize AtariHEAD data and RAM state labels \n",
|
||||
"> offset of player and skull locations was discovered manually "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3882fa3f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from IPython import display\n",
|
||||
"obs = env.reset()\n",
|
||||
"\n",
|
||||
"screen = plt.imshow(env.render(mode='rgb_array'), aspect='auto')\n",
|
||||
"plt.axis('off')\n",
|
||||
"\n",
|
||||
"all_images = []\n",
|
||||
"agent_locations = []\n",
|
||||
"skull_locations = []\n",
|
||||
"room_ids = []\n",
|
||||
"\n",
|
||||
"for i, action in enumerate(df.loc[df.ID == '285_RZ_5619207_E00'].action.values): \n",
|
||||
"\n",
|
||||
" n_state, reward, done, info = env.step(action)\n",
|
||||
" img = info['rgb']\n",
|
||||
" room_ids.append(info['labels']['room_number'])\n",
|
||||
" \n",
|
||||
" # agent \n",
|
||||
" mean_x, mean_y = info['labels']['player_x'], 320 - info['labels']['player_y']\n",
|
||||
" agent_locations.append([mean_x, mean_y])\n",
|
||||
" \n",
|
||||
" x1, x2, y1, y2 = mean_x - 5 , mean_x + 10, mean_y - 15, mean_y + 10\n",
|
||||
" img = cv2.rectangle(img, (x1, y1), (x2, y2), (0,255,0), 2)\n",
|
||||
" \n",
|
||||
" # skull\n",
|
||||
" mean_x, mean_y = info['labels']['enemy_skull_x'] + 35, info['labels']['enemy_skull_y'] - 65\n",
|
||||
" skull_locations.append([mean_x, mean_y])\n",
|
||||
" x1, x2, y1, y2 = mean_x - 5, mean_x + 5, mean_y - 10, mean_y + 5\n",
|
||||
" img = cv2.rectangle(img, (x1, y1), (x2, y2), (255,0,0), 2)\n",
|
||||
" \n",
|
||||
" img = cv2.putText(img=img, text='Room ID: ' + str(info['labels']['room_number']) + ' index: ' + str(i), org=(5, 205), fontFace=cv2.FONT_HERSHEY_SIMPLEX, \n",
|
||||
" fontScale=0.3, color=(255, 255, 255),thickness=1)\n",
|
||||
" \n",
|
||||
" screen.set_data(img) # just update the data\n",
|
||||
" display.display(plt.gcf())\n",
|
||||
" display.clear_output(wait=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1031fee6",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Number of actions with correct labeling for environment with random seed = 42 \n",
|
||||
"> Discovered manually through above visualization \n",
|
||||
"[-1: all actions valid, 0: no actions valid]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "868493b1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"test = {'284_RZ_5540489_E00': 11900, '285_RZ_5619207_E00': 2940, '285_RZ_5619207_E01': -1,\n",
|
||||
" '287_RZ_7172481_E00': 0, '291_RZ_7364933_E00': 12000, '324_RZ_452975_E00':0,\n",
|
||||
" '333_RZ_900705_E00': 3000, '340_RZ_1323550_E00': 5950, '359_RZ_1993616_E00': 9000,\n",
|
||||
" '365_RZ_2079996_E00': 9000, '371_RZ_2173469_E00': -1, '385_RZ_2344725_E00': 3500,\n",
|
||||
" '398_RZ_2530473_E00': 1200, '402_RZ_2603283_E00': -1, '416_RZ_2788252_E00': -1,\n",
|
||||
" '429_RZ_2945490_E00': 4500, '436_RZ_3131841_E00': 10500, '459_RZ_3291266_E00': 5400,\n",
|
||||
" '469_RZ_3390904_E00': 14500, '480_RZ_3470098_E00': 8000, '493_RZ_3557734_E00': 10500, \n",
|
||||
" '523_RZ_4091327_E00': 0, '536_RZ_4420664_E00': 0, '548_RZ_4509746_E00': 0,\n",
|
||||
" '561_RZ_4598680_E00': 0, '573_RZ_4680777_E00': 0, '584_RZ_4772014_E00': 0,\n",
|
||||
" '588_RZ_5032278_E00': 0}\n",
|
||||
"\n",
|
||||
"num_frames = 0 \n",
|
||||
"num_labeled_frames = 0\n",
|
||||
"counter = 0 \n",
|
||||
"for episode in test.keys():\n",
|
||||
" counter += 1\n",
|
||||
" num_frames += len(df.loc[df.ID == episode])\n",
|
||||
" num_samples = test.get(episode)\n",
|
||||
" if num_samples == -1:\n",
|
||||
" num_labeled_frames += len(df.loc[df.ID == episode])\n",
|
||||
" else: \n",
|
||||
" num_labeled_frames += num_samples\n",
|
||||
" \n",
|
||||
"print(f'Overall percantage {num_labeled_frames / num_frames:%} for 21 episodes')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7f3c0026",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Label Atari-HEAD data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "fe0dcd7a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%%time \n",
|
||||
"df['level'] = None \n",
|
||||
"df['room_id'] = None\n",
|
||||
"df['player_location'] = None\n",
|
||||
"df['skull_location'] = None \n",
|
||||
"\n",
|
||||
"for episode in df.ID.unique():\n",
|
||||
" \n",
|
||||
" obs = env.reset()\n",
|
||||
" room_ids = []\n",
|
||||
" agent_locations = []\n",
|
||||
" skull_locations = []\n",
|
||||
" level = []\n",
|
||||
" \n",
|
||||
" num_valid_actions = test.get(episode)\n",
|
||||
"\n",
|
||||
" for action in df.loc[df.ID == episode].action.values[:num_valid_actions]: \n",
|
||||
" \n",
|
||||
" n_state, reward, done, info = env.step(action)\n",
|
||||
" room_ids.append(info['labels']['room_number'])\n",
|
||||
" level.append(info['labels']['level'])\n",
|
||||
" \n",
|
||||
" # agent \n",
|
||||
" mean_x, mean_y = info['labels']['player_x'], 320 - info['labels']['player_y']\n",
|
||||
" agent_locations.append([mean_x, mean_y])\n",
|
||||
" \n",
|
||||
" # skull\n",
|
||||
" mean_x, mean_y = info['labels']['enemy_skull_x'] + 35, info['labels']['enemy_skull_y'] - 65\n",
|
||||
" skull_locations.append([mean_x, mean_y])\n",
|
||||
" \n",
|
||||
" index = df.loc[df.ID == episode].index[:num_valid_actions]\n",
|
||||
" df.loc[index, 'level'] = level\n",
|
||||
" df.loc[index, 'room_id'] = room_ids\n",
|
||||
" df.loc[index, 'player_location'] = agent_locations\n",
|
||||
" df.loc[index, 'skull_location'] = skull_locations\n",
|
||||
"\n",
|
||||
"print(f'Percentage of labeled data {len(df[df.room_id.notnull()]) / len(df):%}')\n",
|
||||
"print()\n",
|
||||
"df.to_pickle(os.path.join(DATA_PATH, \"all_trials_labeled.pkl\")) \n",
|
||||
"df.head()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "96ff9136",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "msc_env",
|
||||
"language": "python",
|
||||
"name": "msc_env"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
40
README.md
40
README.md
|
@ -1,2 +1,42 @@
|
|||
# Int-HRL
|
||||
|
||||
This is the official repository for [Int-HRL: Towards Intention-based Hierarchical Reinforcement Learning](https://perceptualui.org/publications/penzkofer23_ala/)<br>
|
||||
|
||||
Int-HRL uses eye gaze from human demonstration data on the Atari game Montezuma's Revenge to extract human player's intentions and converts them to sub-goals for Hierarchical Reinforcement Learning (HRL). For further details take a look at the corresponding paper.
|
||||
|
||||
## Dataset
|
||||
Atari-HEAD: Atari Human Eye-Tracking and Demonstration Dataset available at [https://zenodo.org/record/3451402#.Y5chr-zMK3J](https://zenodo.org/record/3451402#.Y5chr-zMK3J) <br>
|
||||

|
||||
|
||||
To pre-process the Atari-HEAD data run [Preprocess_AtariHEAD.ipynb](Preprocess_AtariHEAD.ipynb), yielding the `all_trials.pkl` file needed for the following steps.
|
||||
|
||||
## Sub-goal Extraction Pipeline
|
||||
|
||||
1. [RAM State Labeling](RAMStateLabeling.ipynb): annotate Atari-HEAD data with room id and level information, as well as agent and skull location
|
||||
2. [Subgoals From Gaze](SubgoalsFromGaze.ipynb): run sub-goal proposal extraction by generating saliency maps
|
||||
3. [Alignment with Trajectory](TrajectoryMatching.ipynb): run expert trajectory to get order of subgoals
|
||||
|
||||
## Intention-based Hierarchical RL Agent
|
||||
under construction
|
||||
|
||||
## Citation
|
||||
Please consider citing these paper if you use Int-HRL or parts of this repository in your research:
|
||||
```
|
||||
@article{penzkofer24_ncaa,
|
||||
author = {Penzkofer, Anna and Schaefer, Simon and Strohm, Florian and Bâce, Mihai and Leutenegger, Stefan and Bulling, Andreas},
|
||||
title = {Int-HRL: Towards Intention-based Hierarchical Reinforcement Learning},
|
||||
journal = {Neural Computing and Applications (NCAA)},
|
||||
year = {2024},
|
||||
pages = {1--7},
|
||||
doi = {10.1007/s00521-024-10596-2},
|
||||
volume = {36}
|
||||
}
|
||||
@inproceedings{penzkofer23_ala,
|
||||
author = {Penzkofer, Anna and Schaefer, Simon and Strohm, Florian and Bâce, Mihai and Leutenegger, Stefan and Bulling, Andreas},
|
||||
title = {Int-HRL: Towards Intention-based Hierarchical Reinforcement Learning},
|
||||
booktitle = {Proc. Adaptive and Learning Agents Workshop (ALA)},
|
||||
year = {2023},
|
||||
doi = {10.48550/arXiv.2306.11483},
|
||||
pages = {1--7}
|
||||
}
|
||||
```
|
||||
|
|
262
SubgoalsFromGaze.ipynb
Normal file
262
SubgoalsFromGaze.ipynb
Normal file
|
@ -0,0 +1,262 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b694995f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Extract Subgoals from Gaze\n",
|
||||
"\n",
|
||||
"### For each episode\n",
|
||||
"- Generate saliency map of first room and threshold\n",
|
||||
"- Draw bounding box around each _salient_ pixel \n",
|
||||
"- Perform Non-Maximum Supression (NMS) on generated bounding boxes with iou threshold \n",
|
||||
"- Merge resulting boxes if there are still overlaps \n",
|
||||
"\n",
|
||||
"### For room 1\n",
|
||||
"- Perform NMS on the filtered and merged bounding box proposals from all episodes\n",
|
||||
"- Merge again "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d1245d4c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import random\n",
|
||||
"\n",
|
||||
"import cv2\n",
|
||||
"import torch\n",
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd \n",
|
||||
"import matplotlib.pyplot as plt \n",
|
||||
"from scipy.ndimage import gaussian_filter\n",
|
||||
"from torchvision.ops import masks_to_boxes\n",
|
||||
"\n",
|
||||
"from dataset_utils import visualize_sample, apply_nms, merge_boxes, get_subgoal_proposals, SIGMA\n",
|
||||
"\n",
|
||||
"DATA_PATH = 'montezuma_revenge'\n",
|
||||
"\n",
|
||||
"df = pd.read_pickle(os.path.join(DATA_PATH, \"all_trials_labeled.pkl\"))\n",
|
||||
"\n",
|
||||
"init_screen = cv2.imread(df.iloc[0].img_path)\n",
|
||||
"init_screen = cv2.cvtColor(init_screen, cv2.COLOR_BGR2RGB)\n",
|
||||
"\n",
|
||||
"df.head()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "27bf55d2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Generate example saliency map "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8b13e060",
|
||||
"metadata": {
|
||||
"scrolled": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Get gaze from one run \n",
|
||||
"episode = '284_RZ_5540489_E00'\n",
|
||||
"gaze = df.loc[df.ID == episode].loc[df.room_id == 1].loc[df.level==0].gaze_positions\n",
|
||||
"\n",
|
||||
"flat_list = []\n",
|
||||
"for gaze_points in gaze:\n",
|
||||
" if gaze_points is not None: \n",
|
||||
" for item in gaze_points:\n",
|
||||
" flat_list.append(item)\n",
|
||||
"\n",
|
||||
"saliency_map = np.zeros(init_screen.shape[:2])\n",
|
||||
"threshold = 0.35\n",
|
||||
"\n",
|
||||
"# Add gaze coordinates to saliency map\n",
|
||||
"for cords in flat_list:\n",
|
||||
" try: \n",
|
||||
" saliency_map[int(cords[1])][int(cords[0])] += 1\n",
|
||||
" except:\n",
|
||||
" # Not all gaze points are on image \n",
|
||||
" continue\n",
|
||||
"\n",
|
||||
"# Construct fixation map \n",
|
||||
"fix_map = saliency_map >= 1.0 \n",
|
||||
"\n",
|
||||
"# Construct empirical saliency map\n",
|
||||
"saliency_map = gaussian_filter(saliency_map, sigma=SIGMA, mode='nearest')\n",
|
||||
"\n",
|
||||
"# Normalize saliency map into range [0, 1]\n",
|
||||
"if not saliency_map.max() == 0:\n",
|
||||
" saliency_map /= saliency_map.max()\n",
|
||||
"\n",
|
||||
"gray = cv2.cvtColor(init_screen, cv2.COLOR_BGR2GRAY) \n",
|
||||
"fov_image = np.multiply(saliency_map, gray) # element-wise product\n",
|
||||
"\n",
|
||||
"visualize_sample(init_screen, saliency_map)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "334fc3b8",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Find good saliency threshold "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8043f4bf",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"plt.hist(saliency_map.flatten())\n",
|
||||
"plt.show()\n",
|
||||
"\n",
|
||||
"mask = saliency_map > 0.4\n",
|
||||
"masked_saliency = saliency_map.copy()\n",
|
||||
"masked_saliency[~mask] = 0 \n",
|
||||
"print('Threshold: 0.4')\n",
|
||||
"visualize_sample(init_screen, masked_saliency)\n",
|
||||
"\n",
|
||||
"mask = saliency_map > 0.35\n",
|
||||
"masked_saliency = saliency_map.copy()\n",
|
||||
"masked_saliency[~mask] = 0 \n",
|
||||
"print('Threshold: 0.35')\n",
|
||||
"visualize_sample(init_screen, masked_saliency)\n",
|
||||
"\n",
|
||||
"mask = saliency_map > 0.2\n",
|
||||
"masked_saliency = saliency_map.copy()\n",
|
||||
"masked_saliency[~mask] = 0 \n",
|
||||
"print('Threshold: 0.2')\n",
|
||||
"visualize_sample(init_screen, masked_saliency)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "9f4a356f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Extract subgoals from all saliency maps"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6aad7e32",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%%time \n",
|
||||
"\n",
|
||||
"SALIENCY_THRESH = 0.3\n",
|
||||
"VISUALIZE = False \n",
|
||||
"ROOM = 1\n",
|
||||
"\n",
|
||||
"init_screen = cv2.imread(df.loc[df.room_id == ROOM].iloc[10].img_path)\n",
|
||||
"init_screen = cv2.cvtColor(init_screen, cv2.COLOR_BGR2RGB)\n",
|
||||
"\n",
|
||||
"subgoal_proposals = get_subgoal_proposals(df, threshold=SALIENCY_THRESH, visualize=VISUALIZE, room=ROOM)\n",
|
||||
"\n",
|
||||
"proposal_df = pd.DataFrame.from_dict(subgoal_proposals, orient='index', columns=['bboxes', 'merged_bboxes'])\n",
|
||||
"proposal_df.to_pickle(os.path.join(DATA_PATH, f\"room1_subgoal_proposals{int(SALIENCY_THRESH * 100)}.pkl\")) \n",
|
||||
"\n",
|
||||
"proposal_df.head()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "cbe23a6a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"img = init_screen.copy()\n",
|
||||
"for proposals in proposal_df.merged_bboxes:\n",
|
||||
" \n",
|
||||
" for box in proposals:\n",
|
||||
" img = cv2.rectangle(img, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (255,0,0), 1)\n",
|
||||
"\n",
|
||||
"fig = plt.figure(figsize=(8,8))\n",
|
||||
"plt.imshow(img)\n",
|
||||
"plt.axis('off')\n",
|
||||
"plt.savefig(f'visualizations/room{ROOM}_all_subgoal_proposals{int(SALIENCY_THRESH * 100)}.png', bbox_inches='tight')\n",
|
||||
"plt.show()\n",
|
||||
"\n",
|
||||
"all_proposals = np.concatenate(proposal_df.merged_bboxes.to_numpy())\n",
|
||||
"print(all_proposals.shape)\n",
|
||||
"\n",
|
||||
"# Non-max suppression \n",
|
||||
"keep = apply_nms(all_proposals, thresh_iou=0.01)\n",
|
||||
"\n",
|
||||
"print('Bounding boxes after non-maximum suppression')\n",
|
||||
"img = init_screen.copy()\n",
|
||||
"for box in keep:\n",
|
||||
" img = cv2.rectangle(img, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (255,0,0), 1)\n",
|
||||
"\n",
|
||||
"fig = plt.figure(figsize=(8,8))\n",
|
||||
"plt.imshow(img)\n",
|
||||
"plt.axis('off')\n",
|
||||
"plt.show()\n",
|
||||
"\n",
|
||||
"merged = merge_boxes(keep)\n",
|
||||
"print('Bounding boxes after merging')\n",
|
||||
"img = init_screen.copy()\n",
|
||||
"for box in merged:\n",
|
||||
" img = cv2.rectangle(img, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (255,0,0), 1)\n",
|
||||
"\n",
|
||||
"fig = plt.figure(figsize=(8,8))\n",
|
||||
"plt.imshow(img)\n",
|
||||
"plt.axis('off')\n",
|
||||
"plt.savefig(f'visualizations/room{ROOM}_final_subgoals{int(SALIENCY_THRESH * 100)}.png', bbox_inches='tight')\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "fdf9db11",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"np.savetxt('subgoals.txt', merged, fmt='%i', delimiter=',')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ad587996",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "msc_env",
|
||||
"language": "python",
|
||||
"name": "msc_env"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
1108
TrajectoryMatching.ipynb
Normal file
1108
TrajectoryMatching.ipynb
Normal file
File diff suppressed because one or more lines are too long
432
dataset_utils.py
Normal file
432
dataset_utils.py
Normal file
|
@ -0,0 +1,432 @@
|
|||
import random
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
from scipy.ndimage import gaussian_filter
|
||||
from tqdm import tqdm
|
||||
from scipy import interpolate
|
||||
from sklearn.preprocessing import normalize
|
||||
|
||||
|
||||
# Atari-HEAD constants
|
||||
SIGMA = (210 / 44.6, 160 / 28.5)
|
||||
SUBJECT_TO_SCREEN = 787
|
||||
|
||||
SCREEN_WIDTH_MM = 646
|
||||
SCREEN_HEIGHT_MM = 400
|
||||
|
||||
SCREEN_WIDTH_PX = 1280
|
||||
SCREEN_HEIGHT_PX = 840
|
||||
|
||||
|
||||
TYPES = {'frame_id': str, 'episode_id': int, 'score': int, 'duration(ms)': int,
|
||||
'unclipped_reward': int, 'action': int, 'gaze_positions': list}
|
||||
|
||||
ALE_ENUMS = {0: 'PLAYER_A_NOOP', 1: 'PLAYER_A_FIRE', 2: 'PLAYER_A_UP', 3: 'PLAYER_A_RIGHT', 4: 'PLAYER_A_LEFT', 5: 'PLAYER_A_DOWN',
|
||||
6: 'PLAYER_A_UPRIGHT', 7: 'PLAYER_A_UPLEFT', 8: 'PLAYER_A_DOWNRIGHT', 9: 'PLAYER_A_DOWNLEFT',
|
||||
10: 'PLAYER_A_UPFIRE', 11: 'PLAYER_A_RIGHTFIRE', 12: 'PLAYER_A_LEFTFIRE', 13: 'PLAYER_A_DOWNFIRE',
|
||||
14: 'PLAYER_A_UPRIGHTFIRE', 15: 'PLAYER_A_UPLEFTFIRE', 16: 'PLAYER_A_DOWNRIGHTFIRE', 17: 'PLAYER_A_DOWNLEFTFIRE'}
|
||||
|
||||
|
||||
def txt_to_dataframe(path: str) -> pd.DataFrame:
|
||||
"""Read txt file with annotations for trial line by line and add to new dataframe.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : str
|
||||
The path to the trial's txt file e.g. 291_RZ_7364933_May-08-20-23-25.txt
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame
|
||||
Dataframe with one frame per row and columns TYPES if available.
|
||||
"""
|
||||
|
||||
file = open(path, 'r')
|
||||
Lines = file.readlines()
|
||||
|
||||
columns = Lines[0].strip().split(',')
|
||||
|
||||
trial_df = pd.DataFrame(columns=columns)
|
||||
|
||||
for line in Lines[1:]:
|
||||
raw_vals = line.strip().split(',')
|
||||
vals = dict()
|
||||
for i, c in enumerate(columns):
|
||||
if not c == 'gaze_positions':
|
||||
try:
|
||||
vals[c] = [TYPES.get(c)((raw_vals[i]))]
|
||||
except:
|
||||
vals[c] = None
|
||||
#print('WARNING', c, raw_vals[i])
|
||||
else:
|
||||
# gaze_positions: x0,y0,x1,y1,...,xn,yn. Gaze positions for the current frame.
|
||||
# Could be null if no gaze. (0,0) is the top-left corner. x: horizontal axis. y: vertical.
|
||||
try:
|
||||
gaze_positions = np.array([float(v) for v in raw_vals[i:]]).reshape(-1, 2)
|
||||
except Exception as e:
|
||||
gaze_positions = None
|
||||
#print(f'WARNING: no gaze data available for frame_id: {vals["frame_id"]} because {e}', raw_vals[i])
|
||||
new_df = pd.DataFrame(vals)
|
||||
new_df['gaze_positions'] = [gaze_positions]
|
||||
|
||||
trial_df = pd.concat([trial_df, new_df], ignore_index=True)
|
||||
|
||||
return trial_df
|
||||
|
||||
|
||||
def get_subgoal_proposals(df, threshold=0.35, visualize=False, room=1) -> dict():
|
||||
|
||||
# Get init screen for visualizations
|
||||
init_screen = cv2.imread(df.iloc[0].img_path)
|
||||
init_screen = cv2.cvtColor(init_screen, cv2.COLOR_BGR2RGB)
|
||||
|
||||
subgoal_proposals = {}
|
||||
|
||||
for episode in df.ID.unique():
|
||||
|
||||
gaze = df.loc[df.ID == episode].loc[df.room_id == room].loc[df.level==0].gaze_positions
|
||||
|
||||
if gaze is None:
|
||||
continue
|
||||
|
||||
# Generate saliency map
|
||||
saliency_map = np.zeros(init_screen.shape[:2])
|
||||
for gaze_points in gaze:
|
||||
if gaze_points is not None:
|
||||
for item in gaze_points:
|
||||
try:
|
||||
saliency_map[int(item[1])][int(item[0])] += 1
|
||||
except:
|
||||
# Not all gaze points are on image
|
||||
continue
|
||||
|
||||
# Construct fixation map
|
||||
fix_map = saliency_map >= 1.0
|
||||
|
||||
# Construct empirical saliency map
|
||||
saliency_map = gaussian_filter(saliency_map, sigma=SIGMA, mode='nearest')
|
||||
|
||||
# Normalize saliency map into range [0, 1]
|
||||
if not saliency_map.max() == 0:
|
||||
saliency_map /= saliency_map.max()
|
||||
|
||||
proposals_y, proposals_x = np.where(saliency_map > threshold)
|
||||
|
||||
bboxes = []
|
||||
scores = []
|
||||
for x, y in zip(proposals_x, proposals_y):
|
||||
# draw bounding box around saliency map peak in panama joe size
|
||||
box = [x - 5, y - 10, x + 5, y + 10]
|
||||
bboxes.append(box)
|
||||
scores.append(saliency_map[y][x])
|
||||
|
||||
if len(bboxes) == 0:
|
||||
continue
|
||||
|
||||
# Non-max suppression
|
||||
keep = apply_nms(np.array(bboxes), np.array(scores), thresh_iou=0.1)
|
||||
|
||||
# Merge boxes with any iou > 0
|
||||
# Note: run might generate new ious > 0
|
||||
merged = merge_boxes(keep)
|
||||
|
||||
subgoal_proposals[episode] = [keep, merged]
|
||||
|
||||
if visualize:
|
||||
print('Episode: ', episode)
|
||||
mask = saliency_map > threshold
|
||||
masked_saliency = saliency_map.copy()
|
||||
masked_saliency[~mask] = 0
|
||||
|
||||
img = masked_saliency.copy()
|
||||
for box in random.choices(bboxes, k=25):
|
||||
img = cv2.rectangle(img, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (1,0,0), 1)
|
||||
|
||||
print('Number of bounding box proposals: ', len(bboxes))
|
||||
fig = plt.figure(figsize=(8,8))
|
||||
plt.imshow(init_screen)
|
||||
plt.imshow(img, cmap='jet', alpha=0.5)
|
||||
plt.axis('off')
|
||||
plt.show()
|
||||
|
||||
print('Bounding boxes after non-maximum suppression')
|
||||
img = init_screen.copy()
|
||||
for box in keep:
|
||||
img = cv2.rectangle(img, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (255,0,0), 1)
|
||||
|
||||
fig = plt.figure(figsize=(8,8))
|
||||
plt.imshow(img)
|
||||
plt.axis('off')
|
||||
plt.show()
|
||||
|
||||
print('Bounding boxes after merging')
|
||||
img = init_screen.copy()
|
||||
for box in keep:
|
||||
img = cv2.rectangle(img, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (255,0,0), 1)
|
||||
|
||||
fig = plt.figure(figsize=(8,8))
|
||||
plt.imshow(img)
|
||||
plt.axis('off')
|
||||
plt.show()
|
||||
|
||||
return subgoal_proposals
|
||||
|
||||
|
||||
def visualize_sample(image, target):
|
||||
|
||||
fig = plt.figure(figsize=(12,6))
|
||||
|
||||
ax1 = fig.add_subplot(131)
|
||||
ax2 = fig.add_subplot(132)
|
||||
ax3 = fig.add_subplot(133)
|
||||
|
||||
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||
fov_image = np.multiply(target, gray) # element-wise product
|
||||
|
||||
ax1.imshow(image)
|
||||
ax1.set_title('Input image')
|
||||
ax1.axis('off')
|
||||
|
||||
ax2.imshow(target, cmap='jet')
|
||||
ax2.set_title('Saliency map')
|
||||
ax2.axis('off')
|
||||
|
||||
ax3.imshow(fov_image, cmap='gray')
|
||||
ax3.set_title('Foveated image')
|
||||
ax3.axis('off')
|
||||
|
||||
plt.show()
|
||||
|
||||
|
||||
def saliency_map_to_image(saliency_map):
|
||||
minimum_value = saliency_map.min()
|
||||
if minimum_value < 0:
|
||||
saliency_map = saliency_map - minimum_value
|
||||
|
||||
saliency_map = saliency_map * 255 / saliency_map.max()
|
||||
|
||||
image_data = np.round(saliency_map).astype(np.uint8)
|
||||
|
||||
return image_data
|
||||
|
||||
|
||||
def apply_nms(boxes: np.ndarray, scores: np.ndarray = None, thresh_iou: float = 0.2) -> np.ndarray:
|
||||
"""
|
||||
adapted from https://learnopencv.com/non-maximum-suppression-theory-and-implementation-in-pytorch/
|
||||
Apply non-maximum suppression to avoid detecting too many
|
||||
overlapping bounding boxes based on iou threshold.
|
||||
"""
|
||||
|
||||
x1 = boxes[:, 0] # x coordinate of the top-left corner
|
||||
y1 = boxes[:, 1] # y coordinate of the top-left corner
|
||||
x2 = boxes[:, 2] # x coordinate of the bottom-right corner
|
||||
y2 = boxes[:, 3] # y coordinate of the bottom-right corner
|
||||
|
||||
# calculate area of every block in boxes
|
||||
areas = (x2 - x1) * (y2 - y1)
|
||||
|
||||
if scores is not None:
|
||||
# sort the prediction boxes according to their confidence scores
|
||||
order = scores.argsort()
|
||||
else:
|
||||
order = y2.argsort()
|
||||
|
||||
# initialise an empty list for filtered prediction boxes
|
||||
keep = []
|
||||
|
||||
while len(order) > 0:
|
||||
|
||||
# extract the index of the prediction with highest score and add to keep list
|
||||
idx = order[-1]
|
||||
keep.append(boxes[idx])
|
||||
order = order[:-1]
|
||||
|
||||
# sanity check
|
||||
if len(order) == 0:
|
||||
break
|
||||
|
||||
# select coordinates of boxes according to the indices in order
|
||||
xx1 = np.take(x1, indices=order, axis=0)
|
||||
xx2 = np.take(x2, indices=order, axis=0)
|
||||
yy1 = np.take(y1, indices=order, axis=0)
|
||||
yy2 = np.take(y2, indices=order, axis=0)
|
||||
|
||||
# find the coordinates of the intersection boxes
|
||||
xx1 = np.maximum(xx1, x1[idx])
|
||||
yy1 = np.maximum(yy1, y1[idx])
|
||||
xx2 = np.minimum(xx2, x2[idx])
|
||||
yy2 = np.minimum(yy2, y2[idx])
|
||||
|
||||
# find out the width and the height of the intersection box
|
||||
w = np.maximum(0, xx2 - xx1)
|
||||
h = np.maximum(0, yy2 - yy1)
|
||||
|
||||
# find the intersection area
|
||||
inter = w*h
|
||||
|
||||
# find the areas of boxes according to indices in order
|
||||
rem_areas = np.take(areas, indices=order, axis=0)
|
||||
|
||||
# find the union of every box with currently selected box
|
||||
union = (rem_areas - inter) + areas[idx]
|
||||
|
||||
# find the IoU of every box with currently selected box
|
||||
IoU = inter / union
|
||||
|
||||
# keep the boxes with IoU less than thresh_iou
|
||||
mask = IoU < thresh_iou
|
||||
order = order[mask]
|
||||
|
||||
return np.array(keep)
|
||||
|
||||
|
||||
def merge_boxes(boxes: np.ndarray) -> np.ndarray:
|
||||
x1 = boxes[:, 0] # x coordinate of the top-left corner
|
||||
y1 = boxes[:, 1] # y coordinate of the top-left corner
|
||||
x2 = boxes[:, 2] # x coordinate of the bottom-right corner
|
||||
y2 = boxes[:, 3] # y coordinate of the bottom-right corner
|
||||
|
||||
# calculate area of every block in boxes
|
||||
areas = (x2 - x1) * (y2 - y1)
|
||||
|
||||
merged = []
|
||||
indices = np.arange(len(boxes))
|
||||
|
||||
while len(indices) > 0:
|
||||
idx = indices[0]
|
||||
# find the coordinates of the intersection boxes
|
||||
xx1 = np.maximum(x1, x1[idx])
|
||||
yy1 = np.maximum(y1, y1[idx])
|
||||
xx2 = np.minimum(x2, x2[idx])
|
||||
yy2 = np.minimum(y2, y2[idx])
|
||||
|
||||
# find out the width and the height of the intersection box
|
||||
w = np.maximum(0, xx2 - xx1)
|
||||
h = np.maximum(0, yy2 - yy1)
|
||||
|
||||
# find the intersection over union of every box with currently selected box
|
||||
inter = w * h
|
||||
union = (areas - inter) + areas[idx]
|
||||
iou = inter / union
|
||||
|
||||
merge_idx = np.where(iou > 0.0)[0]
|
||||
|
||||
# box surrounding all selected boxes --> [min(x1), min(y1)] x [max(x2), max(y2)]
|
||||
big_box = [boxes[merge_idx, 0].min(), boxes[merge_idx, 1].min(),
|
||||
boxes[merge_idx, 2].max(), boxes[merge_idx, 3].max()]
|
||||
|
||||
merged.append(big_box)
|
||||
delete_idx = [np.where(indices == i)[0] for i in merge_idx if len(np.where(indices == i)[0]) > 0]
|
||||
indices = np.delete(indices, delete_idx)
|
||||
|
||||
return np.array(merged)
|
||||
|
||||
|
||||
def pixel_to_3D(gaze_positions):
|
||||
if gaze_positions.shape[0] != 2:
|
||||
gaze_positions = np.moveaxis(gaze_positions, 0, 1)
|
||||
|
||||
x, y = gaze_positions
|
||||
|
||||
x *= SCREEN_WIDTH_MM / SCREEN_WIDTH_PX
|
||||
y *= SCREEN_HEIGHT_MM / SCREEN_HEIGHT_PX
|
||||
|
||||
gaze_positions_3D = np.array([x, y, [SUBJECT_TO_SCREEN] * len(x)])
|
||||
|
||||
return np.moveaxis(gaze_positions_3D, 0, 1)
|
||||
|
||||
def get_velocity_vectorized(gaze: np.ndarray, ratio: float):
|
||||
|
||||
# FIXED: https://stackoverflow.com/questions/52457989/pandas-df-apply-unexpectedly-changes-dataframe-inplace
|
||||
gaze = gaze.copy()
|
||||
|
||||
# pixel coordinates to 3D world coordinates
|
||||
gaze_3D = pixel_to_3D(gaze)
|
||||
|
||||
# vectorize gaze[i], gaze[i+1] by shifting vector by 1
|
||||
u, v = gaze_3D[:-1], gaze_3D[1:]
|
||||
assert len(u) == len(v)
|
||||
|
||||
"""
|
||||
# normalize
|
||||
try:
|
||||
u, v = normalize(u), normalize(v)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
"""
|
||||
# Normalize each vector u and v --> ||u[i]|| = ||v[i]|| = 1
|
||||
norm_mat_u = np.stack([np.linalg.norm(u, axis=1), np.linalg.norm(u, axis=1), np.linalg.norm(u, axis=1)], axis=1)
|
||||
norm_mat_v = np.stack([np.linalg.norm(v, axis=1), np.linalg.norm(v, axis=1), np.linalg.norm(v, axis=1)], axis=1)
|
||||
|
||||
u /= norm_mat_u
|
||||
v /= norm_mat_v
|
||||
|
||||
u_minus_v = np.linalg.norm(u - v, axis=1) # || u - v ||
|
||||
u_plus_v = np.linalg.norm(u + v, axis=1) # || u + v ||
|
||||
|
||||
# angular displacement
|
||||
theta = 2 * np.arctan2(u_minus_v, u_plus_v) * 5.73 # converts the unit from radians to degrees
|
||||
|
||||
# velocity with average fps
|
||||
velocity = (theta / ratio) * 10000 # converts the unit from microsecond to degrees per second
|
||||
|
||||
return theta, velocity
|
||||
|
||||
def interpolate_outliers(gaze, ratio, threshold=800, visualize=False):
|
||||
|
||||
idx = np.where(gaze > threshold)[0][0]
|
||||
|
||||
x = list(np.arange(len(gaze)))
|
||||
x.pop(idx)
|
||||
|
||||
gaze = list(gaze)
|
||||
gaze.pop(idx)
|
||||
|
||||
# outliers on border can't be interpolated -> remove entirely
|
||||
if idx == 0 or idx == len(gaze):
|
||||
return gaze
|
||||
|
||||
else:
|
||||
f = interpolate.interp1d(x, gaze)
|
||||
|
||||
if visualize:
|
||||
xnew = np.arange(0, ratio * (len(gaze) - 1) + 0.1, 0.1)
|
||||
ynew = f(xnew)
|
||||
|
||||
plt.plot(x, gaze, 'o', xnew, ynew, '-', idx, f(idx), '*')
|
||||
plt.show()
|
||||
|
||||
return np.array(gaze[:idx] + [float(f(idx))] + gaze[idx:])
|
||||
|
||||
def get_angle(center, point):
|
||||
pf = [center[0], center[1], SUBJECT_TO_SCREEN]
|
||||
cf = [point[0], point[1], SUBJECT_TO_SCREEN]
|
||||
|
||||
v = np.dot(pf, cf) / np.dot(np.linalg.norm(pf), np.linalg.norm(cf))
|
||||
angle = np.arccos(np.clip(v, a_min=-1, a_max=1))
|
||||
|
||||
return angle * 5.73 * 1000
|
||||
|
||||
def get_idt_dispersion(cfg):
|
||||
# Get dispersion of current fixation group to determine smooth pursuits
|
||||
# see https://github.com/M3stark/Eye_tracking_proj/blob/main/ivdt.py
|
||||
max_x, min_x = max(cfg[:, 0]), min(cfg[:, 0])
|
||||
max_y, min_y = max(cfg[:, 1]), min(cfg[:, 1])
|
||||
|
||||
return (max_x - min_x) + (max_y - min_y)
|
||||
|
||||
|
||||
def agent_in_subgoal(subgoals, agent_x, agent_y):
|
||||
|
||||
test_min_x = subgoals[:, 0] < agent_x
|
||||
test_max_x = subgoals[:, 2] > agent_x
|
||||
|
||||
test_min_y = subgoals[:, 1] < agent_y
|
||||
test_max_y = subgoals[:, 3] > agent_y
|
||||
|
||||
return np.any(test_min_x & test_max_x & test_min_y & test_max_y), np.where(test_min_x & test_max_x & test_min_y & test_max_y)[0]
|
||||
|
||||
|
Loading…
Add table
Reference in a new issue