Int-HRL/SubgoalsFromGaze.ipynb

263 lines
7.3 KiB
Text
Raw Normal View History

2025-03-12 18:20:56 +01:00
{
"cells": [
{
"cell_type": "markdown",
"id": "b694995f",
"metadata": {},
"source": [
"# Extract Subgoals from Gaze\n",
"\n",
"### For each episode\n",
"- Generate saliency map of first room and threshold\n",
"- Draw bounding box around each _salient_ pixel \n",
"- Perform Non-Maximum Suppression (NMS) on the generated bounding boxes with an IoU threshold \n",
"- Merge resulting boxes if there are still overlaps \n",
"\n",
"### For room 1\n",
"- Perform NMS on the filtered and merged bounding box proposals from all episodes\n",
"- Merge again "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d1245d4c",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import random\n",
"\n",
"import cv2\n",
"import torch\n",
"import numpy as np\n",
"import pandas as pd \n",
"import matplotlib.pyplot as plt \n",
"from scipy.ndimage import gaussian_filter\n",
"from torchvision.ops import masks_to_boxes\n",
"\n",
"from dataset_utils import visualize_sample, apply_nms, merge_boxes, get_subgoal_proposals, SIGMA\n",
"\n",
"# Root directory of the gaze dataset, relative to the notebook\n",
"DATA_PATH = 'montezuma_revenge'\n",
"\n",
"# NOTE(review): unpickling can execute arbitrary code; only load files from a trusted source\n",
"df = pd.read_pickle(os.path.join(DATA_PATH, \"all_trials_labeled.pkl\"))\n",
"\n",
"# Frame of the first table row, used as background for the overlays below\n",
"first_img_path = df.iloc[0].img_path\n",
"init_screen = cv2.imread(first_img_path)\n",
"if init_screen is None:\n",
"    # cv2.imread reports a missing/unreadable file by returning None instead of raising\n",
"    raise FileNotFoundError(f'Could not read image: {first_img_path}')\n",
"# OpenCV loads images as BGR; matplotlib expects RGB\n",
"init_screen = cv2.cvtColor(init_screen, cv2.COLOR_BGR2RGB)\n",
"\n",
"df.head()"
]
},
{
"cell_type": "markdown",
"id": "27bf55d2",
"metadata": {},
"source": [
"### Generate example saliency map "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b13e060",
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"# Get gaze from one run \n",
"episode = '284_RZ_5540489_E00'\n",
"episode_rows = df.loc[(df.ID == episode) & (df.room_id == 1) & (df.level == 0)]\n",
"gaze = episode_rows.gaze_positions\n",
"\n",
"# Flatten the per-row gaze samples; rows without gaze data hold None and are skipped\n",
"flat_list = [item for gaze_points in gaze if gaze_points is not None for item in gaze_points]\n",
"\n",
"saliency_map = np.zeros(init_screen.shape[:2])\n",
"threshold = 0.35\n",
"height, width = saliency_map.shape\n",
"\n",
"# Add gaze coordinates to saliency map\n",
"for cords in flat_list:\n",
"    try:\n",
"        col, row = int(cords[0]), int(cords[1])\n",
"    except (TypeError, ValueError, OverflowError, IndexError):\n",
"        # Malformed or non-finite gaze sample\n",
"        continue\n",
"    # Not all gaze points are on image. The bounds are checked explicitly because\n",
"    # NumPy treats negative indices as offsets from the end, which would count an\n",
"    # off-screen sample on the opposite border instead of discarding it.\n",
"    if 0 <= row < height and 0 <= col < width:\n",
"        saliency_map[row, col] += 1\n",
"\n",
"# Construct fixation map \n",
"fix_map = saliency_map >= 1.0\n",
"\n",
"# Construct empirical saliency map\n",
"saliency_map = gaussian_filter(saliency_map, sigma=SIGMA, mode='nearest')\n",
"\n",
"# Normalize saliency map into range [0, 1] (skipped when no gaze point hit the image)\n",
"if saliency_map.max() != 0:\n",
"    saliency_map /= saliency_map.max()\n",
"\n",
"# init_screen was converted to RGB after loading, so the RGB->gray code is the matching one\n",
"gray = cv2.cvtColor(init_screen, cv2.COLOR_RGB2GRAY)\n",
"fov_image = np.multiply(saliency_map, gray)  # element-wise product\n",
"\n",
"visualize_sample(init_screen, saliency_map)"
]
},
{
"cell_type": "markdown",
"id": "334fc3b8",
"metadata": {},
"source": [
"### Find good saliency threshold "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8043f4bf",
"metadata": {},
"outputs": [],
"source": [
"# Distribution of the normalized saliency values\n",
"plt.hist(saliency_map.flatten())\n",
"plt.show()\n",
"\n",
"# Show the saliency map with everything at or below each candidate threshold zeroed out\n",
"for candidate in (0.4, 0.35, 0.2):\n",
"    mask = saliency_map > candidate\n",
"    masked_saliency = saliency_map.copy()\n",
"    masked_saliency[~mask] = 0\n",
"    print(f'Threshold: {candidate}')\n",
"    visualize_sample(init_screen, masked_saliency)"
]
},
{
"cell_type": "markdown",
"id": "9f4a356f",
"metadata": {},
"source": [
"# Extract subgoals from all saliency maps"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6aad7e32",
"metadata": {},
"outputs": [],
"source": [
"%%time \n",
"\n",
"SALIENCY_THRESH = 0.3\n",
"VISUALIZE = False \n",
"ROOM = 1\n",
"\n",
"# Background frame for the selected room\n",
"# NOTE(review): row 10 of the room's frames looks like an arbitrary pick - confirm\n",
"room_img_path = df.loc[df.room_id == ROOM].iloc[10].img_path\n",
"init_screen = cv2.imread(room_img_path)\n",
"if init_screen is None:\n",
"    # cv2.imread reports a missing/unreadable file by returning None instead of raising\n",
"    raise FileNotFoundError(f'Could not read image: {room_img_path}')\n",
"init_screen = cv2.cvtColor(init_screen, cv2.COLOR_BGR2RGB)\n",
"\n",
"subgoal_proposals = get_subgoal_proposals(df, threshold=SALIENCY_THRESH, visualize=VISUALIZE, room=ROOM)\n",
"\n",
"proposal_df = pd.DataFrame.from_dict(subgoal_proposals, orient='index', columns=['bboxes', 'merged_bboxes'])\n",
"# File name follows ROOM so that results for different rooms do not overwrite each other\n",
"proposal_df.to_pickle(os.path.join(DATA_PATH, f\"room{ROOM}_subgoal_proposals{int(SALIENCY_THRESH * 100)}.pkl\"))\n",
"\n",
"proposal_df.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cbe23a6a",
"metadata": {},
"outputs": [],
"source": [
"def draw_boxes(image, boxes):\n",
"    '''Return a copy of image with one red, 1px rectangle per box.\n",
"\n",
"    (box[0], box[1]) and (box[2], box[3]) are used as opposite corner points.\n",
"    '''\n",
"    canvas = image.copy()\n",
"    for box in boxes:\n",
"        corner_a = (int(box[0]), int(box[1]))\n",
"        corner_b = (int(box[2]), int(box[3]))\n",
"        canvas = cv2.rectangle(canvas, corner_a, corner_b, (255, 0, 0), 1)\n",
"    return canvas\n",
"\n",
"\n",
"def show_image(image, save_path=None):\n",
"    '''Display image without axes; if save_path is given, save the figure there first.'''\n",
"    plt.figure(figsize=(8, 8))\n",
"    plt.imshow(image)\n",
"    plt.axis('off')\n",
"    if save_path is not None:\n",
"        plt.savefig(save_path, bbox_inches='tight')\n",
"    plt.show()\n",
"\n",
"\n",
"# plt.savefig does not create missing directories\n",
"os.makedirs('visualizations', exist_ok=True)\n",
"\n",
"# Merged proposals of every episode drawn on top of the room frame\n",
"img = draw_boxes(init_screen, (box for proposals in proposal_df.merged_bboxes for box in proposals))\n",
"show_image(img, f'visualizations/room{ROOM}_all_subgoal_proposals{int(SALIENCY_THRESH * 100)}.png')\n",
"\n",
"all_proposals = np.concatenate(proposal_df.merged_bboxes.to_numpy())\n",
"print(all_proposals.shape)\n",
"\n",
"# Non-max suppression \n",
"keep = apply_nms(all_proposals, thresh_iou=0.01)\n",
"\n",
"print('Bounding boxes after non-maximum suppression')\n",
"img = draw_boxes(init_screen, keep)\n",
"show_image(img)\n",
"\n",
"merged = merge_boxes(keep)\n",
"print('Bounding boxes after merging')\n",
"img = draw_boxes(init_screen, merged)\n",
"show_image(img, f'visualizations/room{ROOM}_final_subgoals{int(SALIENCY_THRESH * 100)}.png')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fdf9db11",
"metadata": {},
"outputs": [],
"source": [
"# Write the merged boxes to disk as comma-separated integers\n",
"np.savetxt(fname='subgoals.txt', X=merged, fmt='%i', delimiter=',')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ad587996",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "msc_env",
"language": "python",
"name": "msc_env"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}