7.3 KiB
7.3 KiB
Extract Subgoals from Gaze¶
For each episode¶
- Generate a saliency map of the first room and threshold it
- Draw bounding box around each salient pixel
- Perform Non-Maximum Suppression (NMS) on the generated bounding boxes with an IoU threshold
- Merge resulting boxes if there are still overlaps
For room 1¶
- Perform NMS on the filtered and merged bounding box proposals from all episodes
- Merge again
In [ ]:
# Standard library
import os
import random

# Third-party
import cv2
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter
from torchvision.ops import masks_to_boxes

# Project-local helpers
from dataset_utils import (
    visualize_sample,
    apply_nms,
    merge_boxes,
    get_subgoal_proposals,
    SIGMA,
)

# Directory that holds the dataset files used by this notebook.
DATA_PATH = 'montezuma_revenge'

# Load the labelled trials table.
labels_file = os.path.join(DATA_PATH, "all_trials_labeled.pkl")
df = pd.read_pickle(labels_file)

# Use the frame referenced by the first row as the background image.
# OpenCV loads images in BGR channel order; convert to RGB because the
# image is later displayed with matplotlib.
first_frame_bgr = cv2.imread(df.iloc[0].img_path)
init_screen = cv2.cvtColor(first_frame_bgr, cv2.COLOR_BGR2RGB)

# Notebook display of the first rows.
df.head()
Generate example saliency map¶
In [ ]:
# Build an empirical saliency map from the gaze samples of a single episode.

# Get gaze from one run, restricted to room 1 of level 0. A single combined
# mask keeps the selection aligned with `df` instead of chaining three
# separate `.loc` calls.
episode = '284_RZ_5540489_E00'
episode_rows = (df.ID == episode) & (df.room_id == 1) & (df.level == 0)
gaze = df.loc[episode_rows].gaze_positions

# Every entry is either None or an iterable of gaze points; flatten them.
flat_list = [
    point
    for gaze_points in gaze
    if gaze_points is not None
    for point in gaze_points
]

height, width = init_screen.shape[:2]
saliency_map = np.zeros((height, width))
threshold = 0.35  # NOTE(review): not used in this cell; kept as a notebook global

# Add gaze coordinates to saliency map. A point is indexed as (x, y), so
# point[1] selects the row and point[0] the column.
for point in flat_list:
    try:
        col, row = int(point[0]), int(point[1])
    except (TypeError, ValueError, IndexError, OverflowError):
        # Malformed sample (missing / NaN / infinite coordinate): skip it.
        continue
    # Not all gaze points are on the image. The bounds are checked explicitly
    # because a negative index would NOT raise: NumPy would wrap it around
    # and count the sample on the opposite side of the frame.
    if 0 <= row < height and 0 <= col < width:
        saliency_map[row, col] += 1

# Construct fixation map (True wherever at least one sample landed).
fix_map = saliency_map >= 1.0

# Construct empirical saliency map by blurring the hit counts.
saliency_map = gaussian_filter(saliency_map, sigma=SIGMA, mode='nearest')

# Normalize saliency map into range [0, 1]; an all-zero map is left as is.
peak = saliency_map.max()
if peak != 0:
    saliency_map /= peak

# `init_screen` is already in RGB order at this point (it is converted right
# after being loaded), so the grayscale conversion must use the RGB code;
# the BGR code would swap the red and blue luminance weights.
gray = cv2.cvtColor(init_screen, cv2.COLOR_RGB2GRAY)
fov_image = np.multiply(saliency_map, gray)  # element-wise product

visualize_sample(init_screen, saliency_map)
Find good saliency threshold¶
In [ ]:
# Look at the distribution of saliency values to help pick a cut-off.
plt.hist(saliency_map.ravel())
plt.show()

# Preview the saliency map for several candidate cut-offs: every value that
# is not strictly greater than the cut-off is zeroed before plotting.
for cutoff in (0.4, 0.35, 0.2):
    mask = saliency_map > cutoff
    masked_saliency = np.where(mask, saliency_map, 0)
    print(f'Threshold: {cutoff}')
    visualize_sample(init_screen, masked_saliency)
Extract subgoals from all saliency maps¶
In [ ]:
%%time SALIENCY_THRESH = 0.3 VISUALIZE = False ROOM = 1 init_screen = cv2.imread(df.loc[df.room_id == ROOM].iloc[10].img_path) init_screen = cv2.cvtColor(init_screen, cv2.COLOR_BGR2RGB) subgoal_proposals = get_subgoal_proposals(df, threshold=SALIENCY_THRESH, visualize=VISUALIZE, room=ROOM) proposal_df = pd.DataFrame.from_dict(subgoal_proposals, orient='index', columns=['bboxes', 'merged_bboxes']) proposal_df.to_pickle(os.path.join(DATA_PATH, f"room1_subgoal_proposals{int(SALIENCY_THRESH * 100)}.pkl")) proposal_df.head()
In [ ]:
def _show_boxes(background, boxes, save_path=None):
    """Draw `boxes` on a copy of `background`, display it and optionally save it.

    Args:
        background: Image array to draw on; it is copied, not modified.
        boxes: Iterable of boxes indexed as (x1, y1, x2, y2).
        save_path: If given, the figure is written there before it is shown.

    Returns:
        The image copy with the rectangles drawn on it.
    """
    canvas = background.copy()
    for box in boxes:
        top_left = (int(box[0]), int(box[1]))
        bottom_right = (int(box[2]), int(box[3]))
        canvas = cv2.rectangle(canvas, top_left, bottom_right, (255, 0, 0), 1)

    plt.figure(figsize=(8, 8))
    plt.imshow(canvas)
    plt.axis('off')
    if save_path is not None:
        plt.savefig(save_path, bbox_inches='tight')
    plt.show()
    return canvas


# plt.savefig does not create missing directories, so make sure the output
# folder exists before the first figure is written.
os.makedirs('visualizations', exist_ok=True)

# Threshold as an integer percentage for the file names. round() is used
# because int() truncates floating-point error (e.g. 0.29 * 100 -> 28).
thresh_pct = round(SALIENCY_THRESH * 100)

# 1) All merged proposals of all rows, drawn on one image.
every_box = [box for proposals in proposal_df.merged_bboxes for box in proposals]
img = _show_boxes(
    init_screen,
    every_box,
    save_path=f'visualizations/room{ROOM}_all_subgoal_proposals{thresh_pct}.png',
)

all_proposals = np.concatenate(proposal_df.merged_bboxes.to_numpy())
print(all_proposals.shape)

# 2) Non-max suppression across the pooled proposals.
keep = apply_nms(all_proposals, thresh_iou=0.01)
print('Bounding boxes after non-maximum suppression')
img = _show_boxes(init_screen, keep)

# 3) Merge the boxes that survived suppression.
merged = merge_boxes(keep)
print('Bounding boxes after merging')
img = _show_boxes(
    init_screen,
    merged,
    save_path=f'visualizations/room{ROOM}_final_subgoals{thresh_pct}.png',
)
In [ ]:
# Write the final merged boxes to a text file: one line per row of `merged`,
# values formatted as integers and separated by commas.
np.savetxt(fname='subgoals.txt', X=merged, fmt='%i', delimiter=',')
In [ ]: