{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b694995f",
   "metadata": {},
   "source": [
    "# Extract Subgoals from Gaze\n",
    "\n",
    "### For each episode\n",
    "- Generate a saliency map of the first room and threshold it\n",
    "- Draw a bounding box around each _salient_ pixel\n",
    "- Perform Non-Maximum Suppression (NMS) on the generated bounding boxes with an IoU threshold\n",
    "- Merge the resulting boxes if there are still overlaps\n",
    "\n",
    "### For room 1\n",
    "- Perform NMS on the filtered and merged bounding box proposals from all episodes\n",
    "- Merge again"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1245d4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import random\n",
    "\n",
    "import cv2\n",
    "import torch\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.ndimage import gaussian_filter\n",
    "from torchvision.ops import masks_to_boxes\n",
    "\n",
    "from dataset_utils import visualize_sample, apply_nms, merge_boxes, get_subgoal_proposals, SIGMA\n",
    "\n",
    "# Directory holding the labeled trials and receiving the proposal pickle\n",
    "DATA_PATH = 'montezuma_revenge'\n",
    "# Directory receiving the figures saved further down\n",
    "VIS_PATH = 'visualizations'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7c31f5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_rgb_screen(img_path):\n",
    "    \"\"\"Read the image at `img_path` and return it as an RGB array.\n",
    "\n",
    "    Raises FileNotFoundError when OpenCV cannot read the file.\n",
    "    \"\"\"\n",
    "    screen = cv2.imread(img_path)\n",
    "    if screen is None:\n",
    "        # cv2.imread reports failure by returning None instead of raising\n",
    "        raise FileNotFoundError(f'Could not read image: {img_path}')\n",
    "    return cv2.cvtColor(screen, cv2.COLOR_BGR2RGB)\n",
    "\n",
    "\n",
    "def draw_boxes(image, boxes, color=(255, 0, 0), thickness=1):\n",
    "    \"\"\"Return a copy of `image` with every box outlined.\n",
    "\n",
    "    Each box is indexed as (box[0], box[1]) and (box[2], box[3]), the two\n",
    "    opposite corner points handed to cv2.rectangle. `image` is not modified.\n",
    "    \"\"\"\n",
    "    canvas = image.copy()\n",
    "    for box in boxes:\n",
    "        first_corner = (int(box[0]), int(box[1]))\n",
    "        second_corner = (int(box[2]), int(box[3]))\n",
    "        canvas = cv2.rectangle(canvas, first_corner, second_corner, color, thickness)\n",
    "    return canvas\n",
    "\n",
    "\n",
    "def show_boxes(image, boxes, save_path=None):\n",
    "    \"\"\"Plot `image` with `boxes` drawn on it and optionally save the figure to `save_path`.\"\"\"\n",
    "    fig, ax = plt.subplots(figsize=(8, 8))\n",
    "    ax.imshow(draw_boxes(image, boxes))\n",
    "    ax.axis('off')\n",
    "    if save_path is not None:\n",
    "        fig.savefig(save_path, bbox_inches='tight')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e2d90b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: unpickling can execute arbitrary code - only load files from a trusted source\n",
    "df = pd.read_pickle(os.path.join(DATA_PATH, 'all_trials_labeled.pkl'))\n",
    "\n",
    "# Background image for the example below: the first frame in the table\n",
    "init_screen = load_rgb_screen(df.iloc[0].img_path)\n",
    "\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27bf55d2",
   "metadata": {},
   "source": [
    "### Generate example saliency map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b13e060",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Get gaze from one run\n",
    "episode = '284_RZ_5540489_E00'\n",
    "# One combined mask instead of chained .loc calls with masks built from the full frame\n",
    "episode_mask = (df.ID == episode) & (df.room_id == 1) & (df.level == 0)\n",
    "gaze = df.loc[episode_mask, 'gaze_positions']\n",
    "\n",
    "# Flatten the per-row gaze lists, skipping rows whose entry is None\n",
    "flat_list = [item for gaze_points in gaze if gaze_points is not None for item in gaze_points]\n",
    "\n",
    "saliency_map = np.zeros(init_screen.shape[:2])\n",
    "height, width = saliency_map.shape\n",
    "\n",
    "# Add gaze coordinates to saliency map (cords[0] indexes columns, cords[1] indexes rows)\n",
    "for cords in flat_list:\n",
    "    try:\n",
    "        col, row = int(cords[0]), int(cords[1])\n",
    "    except (TypeError, ValueError, OverflowError):\n",
    "        # Malformed or non-finite gaze sample\n",
    "        continue\n",
    "    # Not all gaze points are on the image. Check the bounds explicitly:\n",
    "    # a negative index would not raise but silently wrap to the opposite edge.\n",
    "    if 0 <= row < height and 0 <= col < width:\n",
    "        saliency_map[row, col] += 1\n",
    "\n",
    "# Construct fixation map\n",
    "fix_map = saliency_map >= 1.0\n",
    "\n",
    "# Construct empirical saliency map\n",
    "saliency_map = gaussian_filter(saliency_map, sigma=SIGMA, mode='nearest')\n",
    "\n",
    "# Normalize saliency map into range [0, 1]; an all-zero map is left untouched\n",
    "peak = saliency_map.max()\n",
    "if peak > 0:\n",
    "    saliency_map /= peak\n",
    "\n",
    "# init_screen is already RGB, so convert with RGB2GRAY (BGR2GRAY would swap the red/blue weights)\n",
    "gray = cv2.cvtColor(init_screen, cv2.COLOR_RGB2GRAY)\n",
    "fov_image = np.multiply(saliency_map, gray)  # element-wise product\n",
    "\n",
    "visualize_sample(init_screen, saliency_map)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "334fc3b8",
   "metadata": {},
   "source": [
    "### Find good saliency threshold"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8043f4bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.hist(saliency_map.flatten())\n",
    "plt.show()\n",
    "\n",
    "# Compare candidate thresholds: keep values strictly above the threshold, zero the rest\n",
    "for candidate in (0.4, 0.35, 0.2):\n",
    "    masked_saliency = np.where(saliency_map > candidate, saliency_map, 0)\n",
    "    print(f'Threshold: {candidate}')\n",
    "    visualize_sample(init_screen, masked_saliency)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9f4a356f",
   "metadata": {},
   "source": [
    "# Extract subgoals from all saliency maps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6aad7e32",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "SALIENCY_THRESH = 0.3\n",
    "VISUALIZE = False\n",
    "ROOM = 1\n",
    "\n",
    "# Integer percent used in output file names. round() instead of int():\n",
    "# e.g. 0.29 * 100 == 28.999999999999996, which int() would truncate to 28.\n",
    "THRESH_TAG = round(SALIENCY_THRESH * 100)\n",
    "\n",
    "# Background for the plots below: the frame at position 10 among the rows of this room\n",
    "init_screen = load_rgb_screen(df.loc[df.room_id == ROOM].iloc[10].img_path)\n",
    "\n",
    "subgoal_proposals = get_subgoal_proposals(df, threshold=SALIENCY_THRESH, visualize=VISUALIZE, room=ROOM)\n",
    "\n",
    "proposal_df = pd.DataFrame.from_dict(subgoal_proposals, orient='index', columns=['bboxes', 'merged_bboxes'])\n",
    "# File name follows ROOM instead of a hardcoded 'room1'\n",
    "proposal_df.to_pickle(os.path.join(DATA_PATH, f'room{ROOM}_subgoal_proposals{THRESH_TAG}.pkl'))\n",
    "\n",
    "proposal_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbe23a6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# savefig does not create missing directories\n",
    "os.makedirs(VIS_PATH, exist_ok=True)\n",
    "\n",
    "# IoU threshold passed to apply_nms for the room-level suppression\n",
    "NMS_IOU_THRESH = 0.01\n",
    "\n",
    "# Merged proposals of every episode drawn on one screen\n",
    "episode_boxes = [box for proposals in proposal_df.merged_bboxes for box in proposals]\n",
    "show_boxes(\n",
    "    init_screen,\n",
    "    episode_boxes,\n",
    "    save_path=os.path.join(VIS_PATH, f'room{ROOM}_all_subgoal_proposals{THRESH_TAG}.png'),\n",
    ")\n",
    "\n",
    "all_proposals = np.concatenate(proposal_df.merged_bboxes.to_numpy())\n",
    "print(all_proposals.shape)\n",
    "\n",
    "# Non-max suppression\n",
    "keep = apply_nms(all_proposals, thresh_iou=NMS_IOU_THRESH)\n",
    "print('Bounding boxes after non-maximum suppression')\n",
    "show_boxes(init_screen, keep)\n",
    "\n",
    "merged = merge_boxes(keep)\n",
    "print('Bounding boxes after merging')\n",
    "show_boxes(\n",
    "    init_screen,\n",
    "    merged,\n",
    "    save_path=os.path.join(VIS_PATH, f'room{ROOM}_final_subgoals{THRESH_TAG}.png'),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdf9db11",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the final subgoal boxes, one box per line, as comma-separated integers\n",
    "np.savetxt('subgoals.txt', merged, fmt='%i', delimiter=',')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad587996",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "msc_env",
   "language": "python",
   "name": "msc_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}