263 lines
7.3 KiB
Text
263 lines
7.3 KiB
Text
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "b694995f",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# Extract Subgoals from Gaze\n",
|
||
|
"\n",
|
||
|
"### For each episode\n",
|
||
|
"- Generate a saliency map of the first room and threshold it\n",
|
||
|
"- Draw a bounding box around each _salient_ pixel\n",
|
||
|
"- Perform Non-Maximum Suppression (NMS) on the generated bounding boxes with an IoU threshold\n",
|
||
|
"- Merge the resulting boxes if any of them still overlap\n",
|
||
|
"\n",
|
||
|
"### For room 1\n",
|
||
|
"- Perform NMS on the filtered and merged bounding box proposals from all episodes\n",
|
||
|
"- Merge any boxes that still overlap"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"id": "d1245d4c",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
"import os\n",
"import random\n",
"\n",
"import cv2\n",
"import torch\n",
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"from scipy.ndimage import gaussian_filter\n",
"from torchvision.ops import masks_to_boxes\n",
"\n",
"from dataset_utils import visualize_sample, apply_nms, merge_boxes, get_subgoal_proposals, SIGMA\n",
"\n",
"DATA_PATH = 'montezuma_revenge'\n",
"\n",
"# NOTE: pd.read_pickle can execute arbitrary code -- only load pickles from a trusted source\n",
"df = pd.read_pickle(os.path.join(DATA_PATH, \"all_trials_labeled.pkl\"))\n",
"\n",
"# Screenshot of the first row of the dataset, used as the backdrop for the overlays below.\n",
"# cv2.imread returns None (it does not raise) when the file is missing or unreadable.\n",
"init_screen_path = df.iloc[0].img_path\n",
"init_screen = cv2.imread(init_screen_path)\n",
"if init_screen is None:\n",
"    raise FileNotFoundError(f'Could not read screen image: {init_screen_path}')\n",
"# OpenCV loads images as BGR; convert to RGB for matplotlib\n",
"init_screen = cv2.cvtColor(init_screen, cv2.COLOR_BGR2RGB)\n",
"\n",
"df.head()"
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "27bf55d2",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### Generate an example saliency map"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"id": "8b13e060",
|
||
|
"metadata": {
|
||
|
"scrolled": false
|
||
|
},
|
||
|
"outputs": [],
|
||
|
"source": [
"# Get gaze from one run\n",
"episode = '284_RZ_5540489_E00'\n",
"run_mask = (df.ID == episode) & (df.room_id == 1) & (df.level == 0)\n",
"gaze = df.loc[run_mask, 'gaze_positions']\n",
"\n",
"# Flatten the per-frame gaze lists into one list of points (frames without gaze are None)\n",
"flat_list = [item for gaze_points in gaze if gaze_points is not None for item in gaze_points]\n",
"\n",
"height, width = init_screen.shape[:2]\n",
"saliency_map = np.zeros((height, width))\n",
"\n",
"# Add gaze coordinates to saliency map.\n",
"# Bounds are checked explicitly: a negative index does not raise in numpy but wraps\n",
"# around, which would silently count off-screen gaze on the opposite edge of the image.\n",
"n_skipped = 0\n",
"for cords in flat_list:\n",
"    try:\n",
"        col, row = int(cords[0]), int(cords[1])\n",
"    except (TypeError, ValueError, IndexError, OverflowError):\n",
"        # Malformed sample (e.g. NaN or missing coordinate)\n",
"        n_skipped += 1\n",
"        continue\n",
"    if 0 <= row < height and 0 <= col < width:\n",
"        saliency_map[row, col] += 1\n",
"    else:\n",
"        # Not all gaze points are on image\n",
"        n_skipped += 1\n",
"print(f'Skipped {n_skipped} of {len(flat_list)} gaze points (off-screen or invalid)')\n",
"\n",
"# Construct fixation map: pixels hit by at least one gaze sample\n",
"fix_map = saliency_map >= 1.0\n",
"\n",
"# Construct empirical saliency map\n",
"saliency_map = gaussian_filter(saliency_map, sigma=SIGMA, mode='nearest')\n",
"\n",
"# Normalize saliency map into range [0, 1]\n",
"if saliency_map.max() != 0:\n",
"    saliency_map /= saliency_map.max()\n",
"\n",
"# init_screen was already converted to RGB in the setup cell, so convert from RGB here\n",
"gray = cv2.cvtColor(init_screen, cv2.COLOR_RGB2GRAY)\n",
"fov_image = np.multiply(saliency_map, gray)  # element-wise product\n",
"\n",
"visualize_sample(init_screen, saliency_map)"
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "334fc3b8",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"### Find a good saliency threshold"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"id": "8043f4bf",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
"# Most pixels have (near-)zero saliency, so a log-scaled histogram shows the tail better\n",
"fig, ax = plt.subplots()\n",
"ax.hist(saliency_map.ravel(), bins=50, log=True)\n",
"ax.set(title='Normalized saliency values', xlabel='saliency', ylabel='pixel count (log scale)')\n",
"plt.show()\n",
"\n",
"# Candidate thresholds to compare visually; 0.3 is the value used for subgoal extraction below\n",
"CANDIDATE_THRESHOLDS = (0.4, 0.35, 0.3, 0.2)\n",
"\n",
"for thresh in CANDIDATE_THRESHOLDS:\n",
"    # Keep saliency strictly above the threshold, zero everything else\n",
"    masked_saliency = np.where(saliency_map > thresh, saliency_map, 0)\n",
"    print(f'Threshold: {thresh}')\n",
"    visualize_sample(init_screen, masked_saliency)"
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"id": "9f4a356f",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# Extract subgoals from all saliency maps"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"id": "6aad7e32",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
"%%time\n",
"\n",
"SALIENCY_THRESH = 0.3\n",
"VISUALIZE = False\n",
"ROOM = 1\n",
"# Row (among this room's frames) whose screenshot is used as the plot backdrop\n",
"BACKDROP_FRAME = 10\n",
"\n",
"backdrop_path = df.loc[df.room_id == ROOM].iloc[BACKDROP_FRAME].img_path\n",
"init_screen = cv2.imread(backdrop_path)\n",
"if init_screen is None:\n",
"    raise FileNotFoundError(f'Could not read screen image: {backdrop_path}')\n",
"init_screen = cv2.cvtColor(init_screen, cv2.COLOR_BGR2RGB)\n",
"\n",
"subgoal_proposals = get_subgoal_proposals(df, threshold=SALIENCY_THRESH, visualize=VISUALIZE, room=ROOM)\n",
"\n",
"# round() instead of int(): e.g. int(0.29 * 100) == 28 because of floating-point error\n",
"thresh_tag = round(SALIENCY_THRESH * 100)\n",
"proposal_df = pd.DataFrame.from_dict(subgoal_proposals, orient='index', columns=['bboxes', 'merged_bboxes'])\n",
"# File name follows ROOM (was hard-coded to room1) to match the visualization file names\n",
"proposal_df.to_pickle(os.path.join(DATA_PATH, f'room{ROOM}_subgoal_proposals{thresh_tag}.pkl'))\n",
"\n",
"proposal_df.head()"
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"id": "cbe23a6a",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
"def draw_boxes(image, boxes, color=(255, 0, 0), thickness=1):\n",
"    \"\"\"Return a copy of `image` with every (x1, y1, x2, y2) box drawn as a rectangle.\"\"\"\n",
"    canvas = image.copy()\n",
"    for box in boxes:\n",
"        top_left = (int(box[0]), int(box[1]))\n",
"        bottom_right = (int(box[2]), int(box[3]))\n",
"        cv2.rectangle(canvas, top_left, bottom_right, color, thickness)\n",
"    return canvas\n",
"\n",
"\n",
"def show_image(image, save_path=None):\n",
"    \"\"\"Display `image` without axes; if `save_path` is given, also save the figure there.\"\"\"\n",
"    fig, ax = plt.subplots(figsize=(8, 8))\n",
"    ax.imshow(image)\n",
"    ax.axis('off')\n",
"    if save_path is not None:\n",
"        # savefig fails if the target directory does not exist yet\n",
"        os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)\n",
"        fig.savefig(save_path, bbox_inches='tight')\n",
"    plt.show()\n",
"\n",
"\n",
"# round() avoids float truncation in the file-name tag, e.g. int(0.29 * 100) == 28\n",
"thresh_tag = round(SALIENCY_THRESH * 100)\n",
"# IoU threshold passed to apply_nms over the proposals pooled from all episodes\n",
"NMS_IOU_THRESH = 0.01\n",
"\n",
"# Merged proposals of every episode, drawn on top of each other\n",
"episode_boxes = [box for proposals in proposal_df.merged_bboxes for box in proposals]\n",
"show_image(draw_boxes(init_screen, episode_boxes),\n",
"           save_path=f'visualizations/room{ROOM}_all_subgoal_proposals{thresh_tag}.png')\n",
"\n",
"all_proposals = np.concatenate(proposal_df.merged_bboxes.to_numpy())\n",
"print(all_proposals.shape)\n",
"\n",
"# Non-max suppression\n",
"keep = apply_nms(all_proposals, thresh_iou=NMS_IOU_THRESH)\n",
"print('Bounding boxes after non-maximum suppression')\n",
"show_image(draw_boxes(init_screen, keep))\n",
"\n",
"merged = merge_boxes(keep)\n",
"print('Bounding boxes after merging')\n",
"show_image(draw_boxes(init_screen, merged),\n",
"           save_path=f'visualizations/room{ROOM}_final_subgoals{thresh_tag}.png')"
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"id": "fdf9db11",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
"# Write the final merged subgoal boxes as comma-separated integer rows\n",
"np.savetxt(fname='subgoals.txt', X=merged, fmt='%i', delimiter=',')"
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"id": "ad587996",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": []
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "msc_env",
|
||
|
"language": "python",
|
||
|
"name": "msc_env"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.9.12"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 5
|
||
|
}
|