Int-HRL/SubgoalsFromGaze.ipynb

7.3 KiB

Extract Subgoals from Gaze

For each episode

  • Generate the saliency map of the first room and threshold it
  • Draw a bounding box around each salient pixel
  • Perform Non-Maximum Suppression (NMS) on the generated bounding boxes with an IoU threshold
  • Merge the resulting boxes if overlaps remain

For room 1

  • Perform NMS on the filtered and merged bounding-box proposals from all episodes
  • Merge the remaining boxes again
In [ ]:
import os
import random

import cv2
import torch
import numpy as np
import pandas as pd 
import matplotlib.pyplot as plt 
from scipy.ndimage import gaussian_filter
from torchvision.ops import masks_to_boxes

from dataset_utils import visualize_sample, apply_nms, merge_boxes, get_subgoal_proposals, SIGMA

DATA_PATH = 'montezuma_revenge'

# Load the labelled trial data.
# NOTE(review): pd.read_pickle can execute arbitrary code — only load trusted files.
df = pd.read_pickle(os.path.join(DATA_PATH, "all_trials_labeled.pkl"))

# First frame of the dataset, converted from OpenCV's BGR channel order to RGB
# so that matplotlib displays the true colours.
init_screen = cv2.cvtColor(cv2.imread(df.iloc[0].img_path), cv2.COLOR_BGR2RGB)

df.head()

Generate an example saliency map

In [ ]:
# Get gaze from one run: every frame of a single episode in room 1, level 0.
# A single combined mask avoids chained .loc calls that depend on index alignment.
episode = '284_RZ_5540489_E00'
episode_mask = (df.ID == episode) & (df.room_id == 1) & (df.level == 0)
gaze = df.loc[episode_mask, 'gaze_positions']

# Flatten the per-frame gaze lists into one list of points (frames without gaze are None).
flat_list = [item for gaze_points in gaze if gaze_points is not None for item in gaze_points]

saliency_map = np.zeros(init_screen.shape[:2])
threshold = 0.35
height, width = saliency_map.shape

# Add gaze coordinates to the saliency map; each point is read as (x, y) -> map[y, x].
for cords in flat_list:
    try:
        x, y = int(cords[0]), int(cords[1])
    except (TypeError, ValueError, OverflowError, IndexError):
        # Malformed gaze sample (e.g. NaN/inf coordinates) — skip it.
        continue
    # Not all gaze points are on the image. Check bounds explicitly: a negative
    # index would not raise in numpy, it would silently wrap to the opposite edge.
    if 0 <= y < height and 0 <= x < width:
        saliency_map[y, x] += 1

# Construct fixation map (pixels that received at least one gaze sample)
fix_map = saliency_map >= 1.0

# Construct empirical saliency map by blurring the gaze counts
saliency_map = gaussian_filter(saliency_map, sigma=SIGMA, mode='nearest')

# Normalize saliency map into range [0, 1] (skip when no gaze landed on the image)
max_saliency = saliency_map.max()
if max_saliency > 0:
    saliency_map /= max_saliency

# init_screen was converted to RGB when loaded, so convert from RGB (not BGR).
gray = cv2.cvtColor(init_screen, cv2.COLOR_RGB2GRAY)
fov_image = np.multiply(saliency_map, gray)  # element-wise product

visualize_sample(init_screen, saliency_map)

Find a good saliency threshold

In [ ]:
# Candidate thresholds to compare visually (the extraction step below uses SALIENCY_THRESH).
CANDIDATE_THRESHOLDS = (0.4, 0.35, 0.2)

# Distribution of normalized saliency values, to guide the threshold choice.
fig, ax = plt.subplots()
ax.hist(saliency_map.flatten())
ax.set(title='Saliency value distribution', xlabel='Normalized saliency', ylabel='Pixel count')
plt.show()

# For each candidate, zero every pixel whose saliency is not strictly above it and show the result.
for candidate in CANDIDATE_THRESHOLDS:
    masked_saliency = np.where(saliency_map > candidate, saliency_map, 0)
    print(f'Threshold: {candidate}')
    visualize_sample(init_screen, masked_saliency)

Extract subgoals from all saliency maps

In [ ]:
%%time 

SALIENCY_THRESH = 0.3
VISUALIZE = False 
ROOM = 1

init_screen = cv2.imread(df.loc[df.room_id == ROOM].iloc[10].img_path)
init_screen = cv2.cvtColor(init_screen, cv2.COLOR_BGR2RGB)

subgoal_proposals = get_subgoal_proposals(df, threshold=SALIENCY_THRESH, visualize=VISUALIZE, room=ROOM)

proposal_df = pd.DataFrame.from_dict(subgoal_proposals, orient='index', columns=['bboxes', 'merged_bboxes'])
proposal_df.to_pickle(os.path.join(DATA_PATH, f"room1_subgoal_proposals{int(SALIENCY_THRESH * 100)}.pkl"))  

proposal_df.head()
In [ ]:
def draw_boxes(image, boxes, color=(255, 0, 0), thickness=1):
    """Return a copy of ``image`` with every box drawn as a rectangle.

    Each box is read as (x1, y1, x2, y2) from its first four entries.
    """
    canvas = image.copy()
    for box in boxes:
        canvas = cv2.rectangle(canvas, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), color, thickness)
    return canvas


def show_image(image, save_path=None):
    """Show ``image`` in an 8x8 figure without axes; optionally save it to ``save_path`` first."""
    plt.figure(figsize=(8, 8))
    plt.imshow(image)
    plt.axis('off')
    if save_path is not None:
        directory = os.path.dirname(save_path)
        if directory:
            # savefig does not create missing directories.
            os.makedirs(directory, exist_ok=True)
        plt.savefig(save_path, bbox_inches='tight')
    plt.show()


# Same rounding as the extraction cell so file suffixes match (int() truncates float error).
THRESH_TAG = round(SALIENCY_THRESH * 100)

# Episodes without proposals are skipped: np.concatenate cannot mix empty lists with 2-D box arrays.
episode_proposals = [p for p in proposal_df.merged_bboxes if p is not None and len(p) > 0]

# All per-episode proposals overlaid on one frame
all_boxes = [box for proposals in episode_proposals for box in proposals]
show_image(draw_boxes(init_screen, all_boxes),
           save_path=f'visualizations/room{ROOM}_all_subgoal_proposals{THRESH_TAG}.png')

all_proposals = np.concatenate(episode_proposals)
print(all_proposals.shape)

# Non-max suppression across episodes
keep = apply_nms(all_proposals, thresh_iou=0.01)

print('Bounding boxes after non-maximum suppression')
show_image(draw_boxes(init_screen, keep))

merged = merge_boxes(keep)
print('Bounding boxes after merging')
show_image(draw_boxes(init_screen, merged),
           save_path=f'visualizations/room{ROOM}_final_subgoals{THRESH_TAG}.png')
In [ ]:
# Persist the final merged subgoal boxes as comma-separated integers, one row per box.
np.savetxt(fname='subgoals.txt', X=merged, fmt='%i', delimiter=',')
In [ ]: