initial commit
This commit is contained in:
parent
449dff858d
commit
8e0fd07853
10 changed files with 4550 additions and 1 deletions
65
README.md
65
README.md
|
@ -1,3 +1,66 @@
|
|||
# VSA4VQA
|
||||
|
||||
Official code for "VSA4VQA: Scaling a Vector Symbolic Architecture to Visual Question Answering on Natural Images" published at CogSci'24
|
||||
Official code for [VSA4VQA: Scaling a Vector Symbolic Architecture to Visual Question Answering on Natural Images](https://perceptualui.org/publications/penzkofer24_cogsci/) published at CogSci'24
|
||||
|
||||
## Installation
|
||||
```shell
|
||||
# create environment
|
||||
conda create -n ssp_env python=3.9 pip
|
||||
conda activate ssp_env
|
||||
conda install pytorch torchvision pytorch-cuda=11.8 -c pytorch -c nvidia -y
|
||||
|
||||
sudo apt install libmysqlclient-dev
|
||||
|
||||
# install requirements
|
||||
pip install -r requirements.txt
|
||||
|
||||
# install CLIP
|
||||
pip install git+https://github.com/openai/CLIP.git
|
||||
|
||||
# setup jupyter notebook kernel
|
||||
python -m ipykernel install --user --name=ssp_env
|
||||
|
||||
```
|
||||
|
||||
## Get GQA Programs
|
||||
using code by [https://github.com/wenhuchen/Meta-Module-Network](https://github.com/wenhuchen/Meta-Module-Network)<br>
|
||||
- Download github repo MMN
|
||||
- Add `gqa-questions` folder with GQA json files
|
||||
- Run Preprocessing
|
||||
`python preprocess.py create_balanced_programs`
|
||||
- Save generated programs to data folder:
|
||||
```
|
||||
testdev_balanced_inputs.json
|
||||
trainval_balanced_inputs.json
|
||||
testdev_balanced_programs.json
|
||||
trainval_balanced_programs.json
|
||||
```
|
||||
|
||||
> GQA dictionaries: `gqa_all_attributes.json` and `gqa_all_vocab_classes` are also adapted from [https://github.com/wenhuchen/Meta-Module-Network](https://github.com/wenhuchen/Meta-Module-Network)
|
||||
|
||||
## Generate Query Masks
|
||||
- generates full_relations_df.pkl if not already present
|
||||
- generates query masks for all relations with more than 1000 samples
|
||||
```shell
|
||||
python generate_query_masks.py
|
||||
```
|
||||
|
||||
## Pipeline
|
||||
Execute Pipeline for all samples in GQA: train_balanced (with `TEST=False`) or validation_balanced (with `TEST=True`)
|
||||
```shell
|
||||
python run_programs.py
|
||||
```
|
||||
For visualizing samples see [code/GQA_PIPELINE.ipynb](code/GQA_PIPELINE.ipynb) <br>
|
||||
For generating figures see [code/GQA_EVAL.ipynb](code/GQA_EVAL.ipynb) <br>
|
||||
|
||||
## Citation
|
||||
Please consider citing this paper if you use VSA4VQA or parts of this publication in your research:
|
||||
```
|
||||
@inproceedings{penzkofer24_cogsci,
|
||||
author = {Penzkofer, Anna and Shi, Lei and Bulling, Andreas},
|
||||
title = {VSA4VQA: Scaling A Vector Symbolic Architecture To Visual Question Answering on Natural Images},
|
||||
booktitle = {Proc. 46th Annual Meeting of the Cognitive Science Society (CogSci)},
|
||||
year = {2024},
|
||||
pages = {}
|
||||
}
|
||||
```
|
517
VSA4VQA_examples.ipynb
Normal file
517
VSA4VQA_examples.ipynb
Normal file
File diff suppressed because one or more lines are too long
287
dataset.py
Normal file
287
dataset.py
Normal file
|
@ -0,0 +1,287 @@
|
|||
import os
|
||||
import cv2
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.colors as mcolors
|
||||
import matplotlib.patches as patches
|
||||
|
||||
import nengo.spa as spa
|
||||
from utils import encode_point_multidim, ssp_to_loc_multidim, bb_intersection_over_union
|
||||
|
||||
RGB_COLORS = []
|
||||
for name, hex in mcolors.cnames.items():
|
||||
RGB_COLORS.append(mcolors.to_rgb(hex))
|
||||
|
||||
class MNISTQueryDataset(Dataset):
|
||||
"""MNIST spatial query dataset."""
|
||||
|
||||
def __init__(self, mnist_data, num_imgs, img_size=120, visualize=False, transform=None, seed=42):
|
||||
|
||||
# Set random seed for location and mnist image selection
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
self.mnist_data = mnist_data
|
||||
self.mnist_size = mnist_data[0][0].squeeze().numpy().shape[0]
|
||||
# Shuffle MNIST data set according to random seed
|
||||
self.mnist_indices = torch.randperm(len(self.mnist_data))
|
||||
|
||||
self.num_imgs = num_imgs
|
||||
self.img_size = img_size
|
||||
self.border = self.img_size - self.mnist_size
|
||||
self.visualize = visualize
|
||||
|
||||
self.transform = transform
|
||||
|
||||
def __len__(self):
|
||||
return len(self.mnist_data) // self.num_imgs
|
||||
|
||||
def __getitem__(self, idx):
|
||||
|
||||
current_indices = self.mnist_indices[idx: idx + self.num_imgs]
|
||||
|
||||
image = np.zeros((self.img_size, self.img_size))
|
||||
mask = np.ones((self.border, self.border))
|
||||
labels = []
|
||||
mnist_imgs = []
|
||||
|
||||
for i in current_indices:
|
||||
|
||||
mnist, label = self.mnist_data[i]
|
||||
mnist_imgs.append(mnist.squeeze().numpy())
|
||||
|
||||
# find available space
|
||||
indices = np.where(mask == 1)[:2]
|
||||
coords = np.transpose(indices)
|
||||
|
||||
# pick random pixel as x0, y0
|
||||
idx = np.random.randint(len(indices[0]))
|
||||
y_pos, x_pos = coords[idx]
|
||||
|
||||
# add mnist to image
|
||||
image[y_pos: y_pos+self.mnist_size, x_pos: x_pos+self.mnist_size] = mnist.squeeze().numpy()
|
||||
|
||||
# position label = center of mnist image
|
||||
labels.append(dict({label: (x_pos + self.mnist_size // 2, y_pos + self.mnist_size // 2)}))
|
||||
|
||||
# remove available space
|
||||
for x in np.arange(max(0, x_pos-self.mnist_size), min(x_pos+self.mnist_size+1, self.border)):
|
||||
for y in np.arange(max(0, y_pos-self.mnist_size), min(y_pos+self.mnist_size+1, self.border)):
|
||||
mask[y, x] = 0
|
||||
|
||||
# visualize image state and current mask
|
||||
if self.visualize:
|
||||
f, (ax1, ax2) = plt.subplots(1, 2, sharey=False)
|
||||
ax1.imshow(image, cmap='gray')
|
||||
ax2.imshow(mask, cmap='gray')
|
||||
plt.show()
|
||||
|
||||
sample = {'image': image, 'labels': labels, 'mnist_images': mnist_imgs}
|
||||
|
||||
#if self.transform:
|
||||
# sample = self.transform(sample)
|
||||
|
||||
return sample
|
||||
|
||||
|
||||
class GQADataset():
|
||||
|
||||
def __init__(self, questions, programs, scenegraphs, vectors, axes, linspace,
|
||||
path='GQA/images/images/', seed=17, verbose=0, visualize=False):
|
||||
np.random.seed(seed)
|
||||
|
||||
self.questions = questions
|
||||
self.programs = programs
|
||||
self.scenegraphs = scenegraphs
|
||||
self.vis = visualize
|
||||
self.verbose = verbose
|
||||
self.seed = seed
|
||||
self.path = path
|
||||
|
||||
# vector space
|
||||
self.ssp_vectors = vectors
|
||||
self.ssp_axes = axes #[x_axis, y_axis, w_axis, h_axis]
|
||||
self.linspace = linspace # [xs, ys, ws, hs]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.questions)
|
||||
|
||||
def __get_item__(self, idx):
|
||||
|
||||
q_id = self.questions.iloc[idx].questionID
|
||||
temp_df = self.questions.loc[self.questions.questionID == q_id]
|
||||
|
||||
# get image
|
||||
img_id = temp_df.imageId.values[0]
|
||||
img_path = os.path.join(self.path, f'{img_id}.jpg')
|
||||
assert os.path.exists(img_path), f'Image path {img_path} does not exist!'
|
||||
img = cv2.imread(img_path)
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# get question and answer
|
||||
question = temp_df.question.values[0]
|
||||
answer = temp_df.answer.values[0]
|
||||
full_answer = temp_df.fullAnswer.values[0]
|
||||
|
||||
# get program
|
||||
idx = self.programs.loc[self.programs.questionID == q_id].index[0]
|
||||
program = self.programs.iloc[idx].program
|
||||
|
||||
info = {'q_id': q_id, 'img_id': img_id, 'question': question, 'answer': answer,
|
||||
'full_answer': full_answer, 'program': program}
|
||||
|
||||
return img, info
|
||||
|
||||
def encode_item(self, idx, new_size=(25, 25), dim=1024):
|
||||
""" Encode all objects in image into SSP memory = vector space.
|
||||
ensure x- and y-axis of SSP memory have same resolution
|
||||
and fixed width & height axes (10,10), no zero values for width & height
|
||||
and int values instead of decimals, otherwise decoding accuracy degrades
|
||||
"""
|
||||
img, info = self.__get_item__(idx)
|
||||
sg_data = self.scenegraphs.get(str(info['img_id'])).get('objects')
|
||||
|
||||
img_size = img.shape[:2]
|
||||
|
||||
# find orientation and select scale to fit into quadratic vector space
|
||||
if img_size[1] / 2 < img_size[0]:
|
||||
scale = img_size[0] / new_size[0]
|
||||
else:
|
||||
scale = img_size[1] / new_size[1]
|
||||
|
||||
# scale width and height to fixed size of 10
|
||||
w_scale, h_scale = img_size[1] / 10, img_size[0] / 10
|
||||
|
||||
encoded_items = {}
|
||||
encoded_ssps = {}
|
||||
|
||||
rng = np.random.RandomState(seed=self.seed)
|
||||
memory = spa.SemanticPointer(data=np.zeros(dim), rng=rng)
|
||||
name_lst = []
|
||||
|
||||
if self.vis:
|
||||
print(f'Original image {img_size[0]}x{img_size[1]} --> {int(img_size[0] / scale)}x{int(img_size[1] / scale)}')
|
||||
fig, ax = plt.subplots(1,1)
|
||||
ax.imshow(img, interpolation='none', origin='upper', extent=[0, img_size[1] / scale, img_size[0] / scale, 0])
|
||||
plt.axis('off')
|
||||
|
||||
for i, obj in enumerate(sg_data.items()):
|
||||
id_num, obj_dict = obj
|
||||
name = obj_dict.get('name')
|
||||
#name = singularize(name)
|
||||
name_lst.append(name)
|
||||
name += '_' + str(name_lst.count(name))
|
||||
|
||||
# extract ground truth data and scale to fit to SSPs
|
||||
x, y, width, height = obj_dict.get('x'), obj_dict.get('y'), obj_dict.get('w'), obj_dict.get('h')
|
||||
x, y, width, height = x / scale, y / scale, width / w_scale, height / h_scale
|
||||
|
||||
width = width if width >= 1 else 1
|
||||
height = height if height >= 1 else 1
|
||||
|
||||
# Round values to next int (otherwise decoding gets buggy)
|
||||
item = np.round([x, y, width, height], decimals=0).astype(int)
|
||||
encoded_items[name] = item
|
||||
|
||||
pos = encode_point_multidim(list(item), self.ssp_axes)
|
||||
ssp = spa.SemanticPointer(dim)
|
||||
encoded_ssps[name] = ssp
|
||||
|
||||
memory += ssp * pos
|
||||
|
||||
if self.vis:
|
||||
x, y, width, height = item
|
||||
width, height = (width * w_scale) / scale, (height * h_scale) / scale
|
||||
rect = patches.Rectangle((x, y),
|
||||
width, height,
|
||||
linewidth = 2,
|
||||
label = name,
|
||||
edgecolor = RGB_COLORS[i],
|
||||
facecolor = 'none')
|
||||
ax.add_patch(rect)
|
||||
|
||||
if self.vis:
|
||||
plt.show()
|
||||
|
||||
info['encoded_items'] = encoded_items
|
||||
info['encoded_ssps'] = encoded_ssps
|
||||
info['scales'] = [scale, w_scale, h_scale]
|
||||
|
||||
return img, info, memory
|
||||
|
||||
def decode_item(self, img, info, memory):
|
||||
|
||||
img_size = img.shape[:2]
|
||||
scale, w_scale, h_scale = info['scales']
|
||||
|
||||
if self.vis:
|
||||
fig, ax = plt.subplots(1,1)
|
||||
ax.imshow(img, interpolation='none', origin='upper', extent=[0, img_size[1] / scale, img_size[0] / scale, 0])
|
||||
plt.axis('off')
|
||||
|
||||
errors = []
|
||||
iou_lst = []
|
||||
iou_binary_lst = []
|
||||
|
||||
for i, (name, data) in enumerate(info['encoded_items'].items()):
|
||||
ssp_item = info['encoded_ssps'][name]
|
||||
|
||||
item_decoded = memory *~ ssp_item
|
||||
clean_loc = ssp_to_loc_multidim(item_decoded, self.ssp_vectors, self.linspace)
|
||||
x, y, width, height = clean_loc
|
||||
|
||||
mse = np.square(np.subtract(data[:2], clean_loc[:2])).mean()
|
||||
errors.append(mse)
|
||||
|
||||
width, height = width * w_scale / scale, height * h_scale / scale
|
||||
bb_gt = np.array([data[0], data[1], data[0]+(data[2] * w_scale / scale), data[1]+(data[3] * h_scale / scale)])
|
||||
|
||||
iou = bb_intersection_over_union(bb_gt, [x, y, x+width, y+height])
|
||||
iou_lst.append(iou)
|
||||
|
||||
if iou > 0.5:
|
||||
iou_binary_lst.append(1)
|
||||
else:
|
||||
iou_binary_lst.append(0)
|
||||
|
||||
if self.vis:
|
||||
rect = patches.Rectangle((x, y),
|
||||
width, height,
|
||||
linewidth = 2,
|
||||
label = name,
|
||||
edgecolor = RGB_COLORS[i],
|
||||
facecolor = 'none')
|
||||
ax.add_patch(rect)
|
||||
|
||||
if self.vis:
|
||||
plt.legend(loc='upper left', bbox_to_anchor=(1., 1.02))
|
||||
plt.show()
|
||||
|
||||
avg_mse = np.mean(errors)
|
||||
avg_iou = np.mean(iou_lst)
|
||||
|
||||
if self.verbose > 0:
|
||||
print(f'Average mean-squared error of 2D locations: {avg_mse:.4f}')
|
||||
print(f'Average IoU of 4D bounding boxes: {avg_iou:.2f}')
|
||||
print(f'Correct items: {np.sum(iou_binary_lst)} / {len(info["encoded_items"])}')
|
||||
|
||||
return avg_mse, avg_iou, np.sum(iou_binary_lst)
|
||||
|
||||
def print_item(self, idx):
|
||||
_, info = self.__get_item__(idx)
|
||||
|
||||
print(f"Question #{info['q_id']}: \n{info['question']}")
|
||||
print(f"[{info['answer']}] {info['full_answer']}\n")
|
||||
print('Program:')
|
||||
for i, step in enumerate(info['program']):
|
||||
num, func = step.split('=')
|
||||
print(f'{i}. {func}')
|
||||
print()
|
||||
|
||||
def set_visualize(self, visualize):
|
||||
self.vis = visualize
|
||||
|
||||
def set_verbose(self, verbose):
|
||||
self.verbose = verbose
|
300
generate_query_masks.py
Normal file
300
generate_query_masks.py
Normal file
|
@ -0,0 +1,300 @@
|
|||
import os
|
||||
import time
|
||||
import json
|
||||
import queue
|
||||
from multiprocessing import Process, Queue
|
||||
from multiprocessing.pool import Pool
|
||||
from tqdm import tqdm
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
DATA_PATH = 'GQA/'
|
||||
REL_PATH = 'full_relations_df.pkl'
|
||||
IMG_SIZE = (500, 500)
|
||||
NUM_PROCESSES = 20
|
||||
NUM_SAMPLES = 100
|
||||
|
||||
|
||||
def bbox_to_mask(x, y, w, h, img_size=IMG_SIZE, name=None, visualize=False):
|
||||
img = np.zeros(img_size)
|
||||
mask_w = np.ones(np.clip(w, 0, img_size[1]-x))
|
||||
|
||||
for j in range(y, np.clip(y+h, 0, img_size[0])):
|
||||
img[j][x:x+w] = mask_w
|
||||
|
||||
if visualize:
|
||||
fig = plt.figure(figsize=(img_size[0] // 80, img_size[1] // 80))
|
||||
plt.imshow(img, cmap='gray')
|
||||
if name:
|
||||
plt.title(name)
|
||||
plt.axis('off')
|
||||
plt.show()
|
||||
|
||||
return img
|
||||
|
||||
def get_all_relations_df(data):
|
||||
|
||||
print(f'Length of scenegraph data set: {len(data)}')
|
||||
start = time.time()
|
||||
|
||||
df = pd.DataFrame(columns=['image_id', 'relation', 'from', 'to', 'obj_loc', 'obj_w', 'obj_h', 'obj_center',
|
||||
'rel_obj_loc', 'rel_obj_w', 'rel_obj_h'])
|
||||
|
||||
for img_id in data.keys():
|
||||
all_objects = data.get(str(img_id)).get('objects').items()
|
||||
|
||||
# get all object names
|
||||
all_objects_dict = {id_num: (obj_dict.get('name'), obj_dict.get('x'), obj_dict.get('y'), obj_dict.get('w'), obj_dict.get('h'))
|
||||
for (id_num, obj_dict) in all_objects}
|
||||
|
||||
# get all relations
|
||||
for obj in all_objects:
|
||||
id_num, obj_dict = obj
|
||||
name = obj_dict.get('name')
|
||||
x, y, width, height = obj_dict.get('x'), obj_dict.get('y'), obj_dict.get('w'), obj_dict.get('h')
|
||||
center = [x + width / 2, y + height / 2]
|
||||
|
||||
for relation in obj_dict.get('relations'):
|
||||
rel = relation.get('name')
|
||||
rel_obj, rel_x, rel_y, rel_w, rel_h = all_objects_dict.get(relation.get('object'))
|
||||
|
||||
|
||||
temp = pd.DataFrame.from_dict([{'image_id': img_id, 'relation': rel, 'from': name, 'to': rel_obj,
|
||||
'obj_loc': [x, y], 'obj_w': width, 'obj_h': height, 'center': center,
|
||||
'rel_obj_loc': [rel_x, rel_y], 'rel_obj_w': rel_w, 'rel_obj_h': rel_h}])
|
||||
|
||||
df = pd.concat([df, temp], ignore_index=True)
|
||||
#print(f'{df.iloc[-1]["from"]} {df.iloc[-1].relation} {df.iloc[-1].to}')
|
||||
|
||||
out_path = 'all_relations.pkl'
|
||||
df.to_pickle(out_path)
|
||||
print(f'Saved df to {out_path}')
|
||||
|
||||
end = time.time()
|
||||
elapsed = end - start
|
||||
print(f'Took {int(elapsed // 60)}:{int(elapsed % 60)} min:s for all {len(df)} relations --> {elapsed / len(df):.2f}s / relation')
|
||||
|
||||
|
||||
def generate_query_mask(df, rel, i, img_center=np.array([250, 250]), uni_size=np.array([50, 50])):
|
||||
|
||||
# uni_obj only needed for visualization in the end
|
||||
uni_obj = bbox_to_mask(img_center[0] - (uni_size[0] // 2), img_center[1] - (uni_size[1] // 2),
|
||||
50, 50, img_size=(500, 500))
|
||||
|
||||
temp_df = df.loc[df.relation == rel]
|
||||
print(f'[{i}] Number of "{rel}" samples: {len(temp_df)}')
|
||||
|
||||
query_mask = np.zeros((500, 500), dtype=np.uint8)
|
||||
counter = 0
|
||||
num_discard = 0
|
||||
|
||||
for idx in range(len(temp_df)):
|
||||
if counter >= NUM_SAMPLES:
|
||||
print(f'[{i}] Reached {counter} samples for relation "{rel}":')
|
||||
break
|
||||
|
||||
img_id = temp_df.iloc[idx].image_id
|
||||
img_size = (data.get(img_id)['height'], data.get(img_id)['width'])
|
||||
|
||||
# get relative object info and generate binary mask
|
||||
obj_loc = temp_df.iloc[idx].rel_obj_loc
|
||||
width = temp_df.iloc[idx].rel_obj_w
|
||||
height = temp_df.iloc[idx].rel_obj_h
|
||||
|
||||
# get mask info and generate binary mask
|
||||
mask_loc = temp_df.iloc[idx].obj_loc
|
||||
mask_w = temp_df.iloc[idx].obj_w
|
||||
mask_h = temp_df.iloc[idx].obj_h
|
||||
|
||||
if obj_loc[0] > img_size[1] or obj_loc[1] > img_size[0] or mask_loc[0] > img_size[1] or mask_loc[1] > img_size[0]:
|
||||
#print('error in bounding box -- discard sample')
|
||||
continue
|
||||
|
||||
obj = bbox_to_mask(obj_loc[0], obj_loc[1], width, height, img_size=img_size)
|
||||
mask = bbox_to_mask(mask_loc[0], mask_loc[1], mask_w, mask_h, img_size=img_size)
|
||||
|
||||
img = obj*2 + mask
|
||||
img_transformed = np.zeros((1000, 1000), dtype=np.uint8)
|
||||
|
||||
# scale image first
|
||||
scale_x, scale_y = uni_size[0] / width, uni_size[1] / height
|
||||
scale_mat = np.array([[scale_y, 0, 0], [0, scale_x, 0], [0, 0, 1]])
|
||||
|
||||
if scale_x > 5 or scale_y > 5:
|
||||
num_discard += 1
|
||||
#print(f'Scale is too high! x: {scale_x}, y: {scale_y} -- discard sample')
|
||||
continue
|
||||
|
||||
for i, row in enumerate(img):
|
||||
for j, col in enumerate(row):
|
||||
pixel_data = img[i, j]
|
||||
input_coords = np.array([i, j, 1])
|
||||
i_out, j_out, _ = scale_mat @ input_coords
|
||||
|
||||
if i_out > 0 and i_out < 1000 and j_out > 0 and j_out < 1000 and pixel_data > 0:
|
||||
# new indices must be within new image -- discard others
|
||||
img_transformed[int(i_out), int(j_out)] = pixel_data
|
||||
|
||||
if not len(np.where(img_transformed >= 2)[0]) > 0:
|
||||
# no data in transformed image -- discard sample
|
||||
continue
|
||||
|
||||
# find new (x, y) location of object
|
||||
new_loc = sorted([[y, x] for (y, x) in zip(*np.where(img_transformed >= 2))])[0]
|
||||
new_center = [new_loc[0] + uni_size[0] // 2, new_loc[1] + uni_size[1] // 2]
|
||||
|
||||
# move object to center
|
||||
move_x, move_y = img_center - new_center
|
||||
move_mat = np.array([[1, 0, move_x], [0, 1, move_y], [0, 0, 1]])
|
||||
|
||||
img_moved = np.zeros((500, 500), dtype=np.uint8)
|
||||
for i, row in enumerate(img_transformed):
|
||||
for j, col in enumerate(row):
|
||||
pixel_data = img_transformed[i, j]
|
||||
input_coords = np.array([i, j, 1])
|
||||
i_out, j_out, _ = move_mat @ input_coords
|
||||
|
||||
if i_out > 0 and i_out < 500 and j_out > 0 and j_out < 500 and pixel_data > 0:
|
||||
# new indices must be within new image -- discard others
|
||||
img_moved[int(i_out), int(j_out)] = pixel_data
|
||||
|
||||
# extract relative object mask and add to query mask
|
||||
mask_transformed = np.where(img_moved==1, img_moved, 0) + np.where(img_moved==3, img_moved, 0)
|
||||
query_mask += mask_transformed
|
||||
counter += 1
|
||||
|
||||
if counter > 0:
|
||||
query_mask = query_mask / counter
|
||||
rel_name = '_'.join(rel.split(' '))
|
||||
np.save(f'relations/{rel_name}.npy', query_mask)
|
||||
print(f'[{i}] Saved query mask to: relations/{rel_name}.npy')
|
||||
|
||||
if num_discard > 0:
|
||||
print(f'[{i}] Discarded {num_discard} samples, because scaling was too high.')
|
||||
|
||||
plt.figure(figsize=(3,3))
|
||||
plt.imshow(uni_obj*0.1+ query_mask, cmap='gray')
|
||||
plt.title(rel)
|
||||
plt.axis('off')
|
||||
plt.savefig(f'relations/{rel_name}.png', bbox_inches='tight', dpi=300)
|
||||
plt.clf()
|
||||
|
||||
else:
|
||||
print(f'[{i}] Could not generate query mask for "{rel}"')
|
||||
|
||||
def run_process(tasks, df):
|
||||
while True:
|
||||
try:
|
||||
'''
|
||||
try to get task from the queue. get_nowait() function will
|
||||
raise queue.Empty exception if the queue is empty.
|
||||
queue(False) function would do the same task also.
|
||||
'''
|
||||
task = tasks.get_nowait()
|
||||
i = list(df.relation.unique()).index(task)
|
||||
except queue.Empty:
|
||||
break
|
||||
else:
|
||||
''' no exception has been raised '''
|
||||
print(f'[{i}] Starting relation #{i}: {task}')
|
||||
print()
|
||||
generate_query_mask(df, task, i)
|
||||
time.sleep(.5)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
# task executed in a worker process
|
||||
def get_relations_task(img_id):
|
||||
width, height = data.get(str(img_id))['width'], data.get(str(img_id))['height']
|
||||
all_objects = data.get(str(img_id)).get('objects').items()
|
||||
|
||||
# get all object names
|
||||
all_objects_dict = {id_num: (obj_dict.get('name'), obj_dict.get('x'), obj_dict.get('y'), obj_dict.get('w'), obj_dict.get('h'))
|
||||
for (id_num, obj_dict) in all_objects}
|
||||
|
||||
all_relations = []
|
||||
|
||||
# get all relations
|
||||
for obj in all_objects:
|
||||
id_num, obj_dict = obj
|
||||
name = obj_dict.get('name')
|
||||
x, y, obj_w, obj_h = obj_dict.get('x'), obj_dict.get('y'), obj_dict.get('w'), obj_dict.get('h')
|
||||
center = [x + width / 2, y + height / 2]
|
||||
|
||||
for relation in obj_dict.get('relations'):
|
||||
rel = relation.get('name')
|
||||
rel_obj, rel_x, rel_y, rel_w, rel_h = all_objects_dict.get(relation.get('object'))
|
||||
|
||||
all_relations.append({'image_id': img_id, 'width': width, 'height': height, 'relation': rel,
|
||||
'from': name, 'to': rel_obj, 'obj_loc': [x, y], 'obj_w': obj_w, 'obj_h': obj_h,
|
||||
'obj_center': center,'rel_obj_loc': [rel_x, rel_y], 'rel_obj_w': rel_w, 'rel_obj_h': rel_h})
|
||||
|
||||
return all_relations
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
|
||||
path = os.path.join(DATA_PATH, 'train_sceneGraphs.json')
|
||||
assert os.path.exists(path), f'{path} does not exist!'
|
||||
|
||||
with open(os.path.join(DATA_PATH, 'train_sceneGraphs.json'), 'r') as f:
|
||||
data = json.load(f)
|
||||
|
||||
print(f'Length of scenegraph data set: {len(data)}')
|
||||
|
||||
if not os.path.exists(REL_PATH):
|
||||
print('Generating dataframe of all relations...')
|
||||
# generate list of relations pkl -- use multiprocessing!
|
||||
# create and configure the process pool
|
||||
with Pool(processes=NUM_PROCESSES) as pool:
|
||||
|
||||
df = pd.DataFrame(columns=['image_id', 'width', 'height', 'relation', 'from', 'to', 'obj_loc', 'obj_w',
|
||||
'obj_h', 'obj_center', 'rel_obj_loc', 'rel_obj_w', 'rel_obj_h'])
|
||||
|
||||
# execute tasks in order
|
||||
for i, result in enumerate(tqdm(pool.map(get_relations_task, list(data.keys()), chunksize=100))):
|
||||
temp = pd.DataFrame.from_dict(result)
|
||||
df = pd.concat([df, temp], ignore_index=True)
|
||||
if i % 10000 == 0:
|
||||
df.to_pickle('temp_' + REL_PATH)
|
||||
print(f'Saved df to {"temp_" + REL_PATH}')
|
||||
|
||||
df.to_pickle(REL_PATH)
|
||||
print(f'Saved df to {REL_PATH}')
|
||||
else:
|
||||
df = pd.read_pickle(REL_PATH)
|
||||
|
||||
print(f'Number of relations: {len(df.relation.unique())}')
|
||||
print(df.relation.unique())
|
||||
|
||||
# generate query mask for each relation
|
||||
#for i, rel in enumerate(df.relation.unique()):
|
||||
# generate_query_mask(df, rel, i)
|
||||
|
||||
print('Generating a query mask for each relation...')
|
||||
# generate query mask for each relation -- use multiprocessing
|
||||
tasks = Queue()
|
||||
procs = []
|
||||
|
||||
# only use relations with at least 1000 samples
|
||||
rel_lst = df.relation.value_counts()[df.relation.value_counts() > 1000].index.to_list()
|
||||
|
||||
for rel in rel_lst:
|
||||
tasks.put(rel)
|
||||
|
||||
# creating processes -- run only NUM_PROCESSES processes at the same time
|
||||
for _ in range(NUM_PROCESSES):
|
||||
p = Process(target=run_process, args=(tasks, df,))
|
||||
procs.append(p)
|
||||
p.start()
|
||||
|
||||
# completing all processes
|
||||
for p in procs:
|
||||
p.join()
|
||||
|
||||
|
222
gqa_all_attributes.json
Normal file
222
gqa_all_attributes.json
Normal file
|
@ -0,0 +1,222 @@
|
|||
{
|
||||
"color": [
|
||||
"beige",
|
||||
"black",
|
||||
"blond",
|
||||
"blue",
|
||||
"brown",
|
||||
"brunette",
|
||||
"cream colored",
|
||||
"dark",
|
||||
"dark blue",
|
||||
"dark brown",
|
||||
"gold",
|
||||
"gray",
|
||||
"green",
|
||||
"khaki",
|
||||
"light blue",
|
||||
"light brown",
|
||||
"maroon",
|
||||
"orange",
|
||||
"pink",
|
||||
"purple",
|
||||
"red",
|
||||
"silver",
|
||||
"tan",
|
||||
"teal",
|
||||
"white",
|
||||
"yellow"
|
||||
],
|
||||
"pose": [
|
||||
"bending",
|
||||
"brushing tooth",
|
||||
"crouching",
|
||||
"jumping",
|
||||
"lying",
|
||||
"making a face",
|
||||
"pointing",
|
||||
"running",
|
||||
"shaking hand",
|
||||
"sitting",
|
||||
"standing",
|
||||
"taking a photo",
|
||||
"taking a picture",
|
||||
"taking picture",
|
||||
"walking"
|
||||
],
|
||||
"material": [
|
||||
"brick",
|
||||
"concrete",
|
||||
"glass",
|
||||
"leather",
|
||||
"metal",
|
||||
"plastic",
|
||||
"porcelain",
|
||||
"wood"
|
||||
],
|
||||
"activity": [
|
||||
"brushing tooth",
|
||||
"cooking",
|
||||
"drinking",
|
||||
"driving",
|
||||
"eating",
|
||||
"looking down",
|
||||
"looking up",
|
||||
"playing",
|
||||
"posing",
|
||||
"reading",
|
||||
"resting",
|
||||
"sleeping",
|
||||
"staring",
|
||||
"talking",
|
||||
"waiting"
|
||||
],
|
||||
"weather": [
|
||||
"clear",
|
||||
"cloudless",
|
||||
"cloudy",
|
||||
"foggy",
|
||||
"overcast",
|
||||
"partly cloudy",
|
||||
"rainy",
|
||||
"stormy",
|
||||
"sunny"
|
||||
],
|
||||
"size": [
|
||||
"giant",
|
||||
"huge",
|
||||
"large",
|
||||
"little",
|
||||
"small",
|
||||
"tiny"
|
||||
],
|
||||
"fatness": [
|
||||
"fat",
|
||||
"skinny",
|
||||
"thin"
|
||||
],
|
||||
"gender": [
|
||||
"female",
|
||||
"male"
|
||||
],
|
||||
"height": [
|
||||
"short",
|
||||
"tall"
|
||||
],
|
||||
"state": [
|
||||
"calm",
|
||||
"choppy",
|
||||
"rough",
|
||||
"smooth",
|
||||
"still",
|
||||
"wavy"
|
||||
],
|
||||
"hposition": [
|
||||
"left",
|
||||
"right"
|
||||
],
|
||||
"length": [
|
||||
"long",
|
||||
"short"
|
||||
],
|
||||
"shape": [
|
||||
"octagonal",
|
||||
"rectangular",
|
||||
"round",
|
||||
"square",
|
||||
"triangular"
|
||||
],
|
||||
"pattern": [
|
||||
"checkered",
|
||||
"dotted",
|
||||
"striped"
|
||||
],
|
||||
"thickness": [
|
||||
"thick",
|
||||
"thin"
|
||||
],
|
||||
"age": [
|
||||
"little",
|
||||
"old",
|
||||
"young"
|
||||
],
|
||||
"tone": [
|
||||
"light",
|
||||
"dark"
|
||||
],
|
||||
"room": [
|
||||
"attic",
|
||||
"bathroom",
|
||||
"bedroom",
|
||||
"dining room",
|
||||
"kitchen",
|
||||
"living room",
|
||||
"office"
|
||||
],
|
||||
"width": [
|
||||
"narrow",
|
||||
"wide"
|
||||
],
|
||||
"depth": [
|
||||
"deep",
|
||||
"shallow"
|
||||
],
|
||||
"cleanliness": [
|
||||
"clean",
|
||||
"dirty",
|
||||
"stained",
|
||||
"tinted"
|
||||
],
|
||||
"hardness": [
|
||||
"hard",
|
||||
"soft"
|
||||
],
|
||||
"race": [
|
||||
"asian",
|
||||
"caucasian"
|
||||
],
|
||||
"company": [
|
||||
"adida",
|
||||
"nike"
|
||||
],
|
||||
"sportActivity": [
|
||||
"performing trick",
|
||||
"riding",
|
||||
"skateboarding",
|
||||
"skating",
|
||||
"skiing",
|
||||
"snowboarding",
|
||||
"surfing",
|
||||
"swimming"
|
||||
],
|
||||
"sportactivity": [
|
||||
"performing trick",
|
||||
"riding",
|
||||
"skateboarding",
|
||||
"skating",
|
||||
"skiing",
|
||||
"snowboarding",
|
||||
"surfing",
|
||||
"swimming"
|
||||
],
|
||||
"weight": [
|
||||
"heavy",
|
||||
"light"
|
||||
],
|
||||
"texture": [
|
||||
"coarse",
|
||||
"fine"
|
||||
],
|
||||
"flavor": [
|
||||
"chocolate",
|
||||
"strawberry",
|
||||
"vanilla"
|
||||
],
|
||||
"realism": [
|
||||
"fake",
|
||||
"real"
|
||||
],
|
||||
"face expression": [
|
||||
"making a face"
|
||||
]
|
||||
}
|
314
gqa_all_relations_map.json
Normal file
314
gqa_all_relations_map.json
Normal file
|
@ -0,0 +1,314 @@
|
|||
{
|
||||
"to the left of": "to the left of",
|
||||
"to the right of": "to the right of",
|
||||
"on": "on",
|
||||
"wearing": "wearing",
|
||||
"of": "of",
|
||||
"near": "near",
|
||||
"in": "in",
|
||||
"behind": "behind",
|
||||
"in front of": "in front of",
|
||||
"holding": "holding",
|
||||
"on top of": "on top of",
|
||||
"next to": "next to",
|
||||
"above": "above",
|
||||
"with": "with",
|
||||
"below": "below",
|
||||
"by": "by",
|
||||
"sitting on": "sitting on",
|
||||
"under": "under",
|
||||
"on the side of": "on the side of",
|
||||
"beside": "beside",
|
||||
"standing on": "standing on",
|
||||
"inside": "inside",
|
||||
"carrying": "carrying",
|
||||
"at": "at",
|
||||
"walking on": "walking on",
|
||||
"riding": "riding",
|
||||
"standing in": "standing in",
|
||||
"covered by": "covered by",
|
||||
"around": "around",
|
||||
"lying on": "lying on",
|
||||
"hanging on": "hanging on",
|
||||
"eating": "eating",
|
||||
"watching": "watching",
|
||||
"looking at": "looking at",
|
||||
"covering": "covering",
|
||||
"sitting in": "sitting in",
|
||||
"on the front of": "on the front of",
|
||||
|
||||
|
||||
"hanging from": "hanging on",
|
||||
"parked on": "on",
|
||||
"riding on": "on",
|
||||
"using": "holding",
|
||||
"covered in": "covered by",
|
||||
"flying in": "sitting in",
|
||||
"sitting at": "sitting in",
|
||||
"playing with": "holding",
|
||||
"full of": "carrying",
|
||||
"filled with": "carrying",
|
||||
"walking in": "walking on",
|
||||
"crossing": "walking on",
|
||||
"on the back of": "behind",
|
||||
"surrounded by": "inside",
|
||||
"swinging": "sitting in",
|
||||
"standing next to": "next to",
|
||||
"reflected in": "near",
|
||||
"covered with": "covered by",
|
||||
"touching": "holding",
|
||||
"flying": "near",
|
||||
"pulling": "holding",
|
||||
"pulled by": "next to",
|
||||
"contain": "carrying",
|
||||
"hitting": "holding",
|
||||
"leaning on": "next to",
|
||||
"lying in": "lying on",
|
||||
"standing by": "next to",
|
||||
"driving on": "walking on",
|
||||
"throwing": "near",
|
||||
"sitting on top of": "on top of",
|
||||
"surrounding": "around",
|
||||
"underneath": "below",
|
||||
"walking down": "walking on",
|
||||
"parked in": "standing in",
|
||||
"growing in": "sitting in",
|
||||
"standing near": "near",
|
||||
"growing on": "sitting on",
|
||||
"standing behind": "behind",
|
||||
"playing": "holding",
|
||||
"printed on": "on",
|
||||
"mounted on": "on",
|
||||
"beneath": "below",
|
||||
"attached to": "next to",
|
||||
"talking on": "on",
|
||||
"facing": "looking at",
|
||||
"leaning against": "next to",
|
||||
"cutting": "holding",
|
||||
"driving": "sitting in",
|
||||
"worn on": "on",
|
||||
"resting on": "on",
|
||||
"floating in": "in",
|
||||
"lying on top of": "on top of",
|
||||
"catching": "holding",
|
||||
"grazing on": "standing on",
|
||||
"on the bottom of": "below",
|
||||
"drinking": "holding",
|
||||
"standing in front of": "in front of",
|
||||
"topped with": "on top of",
|
||||
"playing in": "inside",
|
||||
"walking with": "with",
|
||||
"swimming in": "inside",
|
||||
"driving down": "walking on",
|
||||
"hanging over": "above",
|
||||
"pushed by": "next to",
|
||||
"pushing": "holding",
|
||||
"playing on": "on",
|
||||
"sitting next to": "next to",
|
||||
"close to": "near",
|
||||
"feeding": "holding",
|
||||
"waiting for": "near",
|
||||
"between": "next to",
|
||||
"running on": "walking on",
|
||||
"tied to": "next to",
|
||||
"on the edge of": "on top of",
|
||||
"talking to": "next to",
|
||||
"holding onto": "holding",
|
||||
"eating from": "holding",
|
||||
"perched on": "on",
|
||||
"reading": "holding",
|
||||
"parked by": "next to",
|
||||
"painted on": "on",
|
||||
"reaching for": "holding",
|
||||
"sleeping on": "lying on",
|
||||
"connected to": "next to",
|
||||
"grazing in": "in",
|
||||
"hanging above": "above",
|
||||
"floating on": "on",
|
||||
"wrapped around": "around",
|
||||
"stacked on": "on",
|
||||
"skiing on": "walking on",
|
||||
"parked at": "next to",
|
||||
"standing at": "next to",
|
||||
"hanging in": "hanging on",
|
||||
"parked near": "near",
|
||||
"walking across": "walking on",
|
||||
"plugged into": "next to",
|
||||
"standing beside": "beside",
|
||||
"parked next to": "next to",
|
||||
"working on": "on",
|
||||
"stuck on": "on",
|
||||
"stuck in": "in",
|
||||
"drinking from": "holding",
|
||||
"seen through": "in front of",
|
||||
"kicking": "near",
|
||||
"sitting by": "by",
|
||||
"sitting in front of": "in front of",
|
||||
"looking out": "behind",
|
||||
"petting": "holding",
|
||||
"parked in front of": "in front of",
|
||||
"wrapped in": "covered by",
|
||||
"flying over": "above",
|
||||
"selling": "holding",
|
||||
"lying inside": "lying on",
|
||||
"coming from": "near",
|
||||
"parked along": "standing on",
|
||||
"serving": "holding",
|
||||
"sitting inside": "inside",
|
||||
"sitting with": "with",
|
||||
"walking by": "by",
|
||||
"standing under": "below",
|
||||
"making": "holding",
|
||||
"walking through": "walking on",
|
||||
"standing on top of": "on top of",
|
||||
"hung on": "below",
|
||||
"walking along": "by",
|
||||
"walking near": "near",
|
||||
"going down": "walking on",
|
||||
"flying through": "near",
|
||||
"running in": "walking on",
|
||||
"leaving": "near",
|
||||
"mounted to": "on top of",
|
||||
"sitting behind": "behind",
|
||||
"on the other side of": "on the side of",
|
||||
"licking": "holding",
|
||||
"riding in": "riding",
|
||||
"followed by": "by",
|
||||
"following": "by",
|
||||
"sniffing": "looking at",
|
||||
"biting": "with",
|
||||
"parked alongside": "by",
|
||||
"flying above": "above",
|
||||
"chasing": "near",
|
||||
"leading": "near",
|
||||
"boarding": "near",
|
||||
"hanging off": "below",
|
||||
"walking behind": "behind",
|
||||
"parked behind": "behind",
|
||||
"sitting near": "near",
|
||||
"helping": "holding",
|
||||
"parked beside": "beside",
|
||||
"growing near": "near",
|
||||
"sitting under": "below",
|
||||
"coming out of": "in front of",
|
||||
"sitting beside": "beside",
|
||||
"hanging out of": "hanging on",
|
||||
"served on": "on",
|
||||
"staring at": "looking at",
|
||||
"walking toward": "near",
|
||||
"hugging": "carrying",
|
||||
"skiing in": "in",
|
||||
"entering": "in front of",
|
||||
"looking in": "looking at",
|
||||
"draped over": "covering",
|
||||
"walking next to": "next to",
|
||||
"tied around": "covering",
|
||||
"growing behind": "behind",
|
||||
"exiting": "in front of",
|
||||
"balancing on": "on",
|
||||
"drawn on": "on",
|
||||
"jumping over": "above",
|
||||
"looking down at": "below",
|
||||
"looking into": "looking at",
|
||||
"reflecting in": "in front of",
|
||||
"posing with": "with",
|
||||
"eating at": "at",
|
||||
"sewn on": "on",
|
||||
"walking up": "walking on",
|
||||
"leaning over": "on the side of",
|
||||
"about to hit": "holding",
|
||||
"reflected on": "in front of",
|
||||
"approaching": "near",
|
||||
"getting on": "on",
|
||||
"observing": "watching",
|
||||
"growing next to": "next to",
|
||||
"traveling on": "on",
|
||||
"walking towards": "near",
|
||||
"growing by": "by",
|
||||
"displayed on": "on",
|
||||
"wading in": "standing in",
|
||||
"growing along": "beside",
|
||||
"mixed with": "covered by",
|
||||
"grabbing": "holding",
|
||||
"jumping on": "walking on",
|
||||
"scattered on": "on",
|
||||
"opening": "holding",
|
||||
"climbing": "walking on",
|
||||
"pointing at": "at",
|
||||
"preparing": "holding",
|
||||
"coming down": "above",
|
||||
"decorated by": "by",
|
||||
"decorating": "on",
|
||||
"taller than": "than",
|
||||
"going into": "standing in",
|
||||
"growing from": "on",
|
||||
"tossing": "holding",
|
||||
"eating in": "in",
|
||||
"sleeping in": "inside",
|
||||
"herding": "near",
|
||||
"chewing": "eating",
|
||||
"washing": "holding",
|
||||
"looking through": "looking at",
|
||||
"picking up": "holding",
|
||||
"trying to catch": "holding",
|
||||
"working in": "in",
|
||||
"slicing": "holding",
|
||||
"skiing down": "walking on",
|
||||
"looking over": "looking at",
|
||||
"standing against": "next to",
|
||||
"typing on": "on",
|
||||
"piled on": "on",
|
||||
"lying next to": "next to",
|
||||
"tying": "standing on",
|
||||
"smiling at": "looking at",
|
||||
"smoking": "holding",
|
||||
"cleaning": "carrying",
|
||||
"shining through": "behind",
|
||||
"guiding": "near",
|
||||
"walking to": "near",
|
||||
"chained to": "next to",
|
||||
"dragging": "carrying",
|
||||
"cooking": "holding",
|
||||
"going through": "holding",
|
||||
"enclosing": "covering",
|
||||
"smelling": "eating",
|
||||
"adjusting": "holding",
|
||||
"photographing": "looking at",
|
||||
"skating on": "walking on",
|
||||
"running through": "walking on",
|
||||
"decorated with": "with",
|
||||
"kissing": "next to",
|
||||
"falling off": "below",
|
||||
"walking into": "in front of",
|
||||
"blowing out": "eating",
|
||||
"walking past": "behind",
|
||||
"towing": "near",
|
||||
"worn around": "covering",
|
||||
"jumping off": "on top of",
|
||||
"sprinkled on": "on top of",
|
||||
"moving": "carrying",
|
||||
"running across": "walking on",
|
||||
"hidden by": "behind",
|
||||
"traveling down": "walking on",
|
||||
"looking toward": "looking at",
|
||||
"splashing": "near",
|
||||
"hang from": "below",
|
||||
"kept in": "inside",
|
||||
"sitting around": "sitting on",
|
||||
"displayed in": "inside",
|
||||
"cooked in": "inside",
|
||||
"sitting atop": "sitting on",
|
||||
"brushing": "holding",
|
||||
"in between": "next to",
|
||||
"buying": "holding",
|
||||
"standing around": "next to",
|
||||
"larger than": "than",
|
||||
"smaller than": "than",
|
||||
"pouring": "holding",
|
||||
"playing at": "at",
|
||||
"longer than": "than",
|
||||
"higher than": "than",
|
||||
"jumping in": "in",
|
||||
"shorter than": "than",
|
||||
"bigger than": "than"
|
||||
}
|
1541
gqa_all_vocab_classes.json
Normal file
1541
gqa_all_vocab_classes.json
Normal file
File diff suppressed because it is too large
Load diff
39
requirements.txt
Normal file
39
requirements.txt
Normal file
|
@ -0,0 +1,39 @@
|
|||
# Requirments for SSP VQA project
|
||||
|
||||
# Standard Libraries
|
||||
opencv-python
|
||||
ipykernel
|
||||
ipywidgets
|
||||
numpy
|
||||
pandas
|
||||
tqdm
|
||||
matplotlib
|
||||
imageio
|
||||
moviepy
|
||||
scikit-learn
|
||||
wandb
|
||||
torchinfo
|
||||
torchmetrics
|
||||
|
||||
# Nengo Libraries
|
||||
nengo
|
||||
nbconvert>=7
|
||||
mistune>=2
|
||||
nengo-spa
|
||||
|
||||
# CLIP requirements
|
||||
ftfy
|
||||
regex
|
||||
tqdm
|
||||
|
||||
# DFOL-VQA Libraries
|
||||
h5py
|
||||
pyyaml
|
||||
mysqlclient
|
||||
pattern
|
||||
|
||||
# NLP Libraries
|
||||
stanza
|
||||
sexpdata
|
||||
nltk
|
||||
svgling
|
968
run_programs.py
Normal file
968
run_programs.py
Normal file
|
@ -0,0 +1,968 @@
|
|||
import os
|
||||
import sys
|
||||
|
||||
os.environ["OPENBLAS_NUM_THREADS"] = "10"
|
||||
|
||||
import json
|
||||
import cv2
|
||||
import re
|
||||
import random
|
||||
import time
|
||||
import torch
|
||||
import clip
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import nengo.spa as spa
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as patches
|
||||
import matplotlib.colors as mcolors
|
||||
from collections import OrderedDict
|
||||
from pattern.text.en import singularize, pluralize
|
||||
|
||||
from dataset import GQADataset
|
||||
from utils import *
|
||||
|
||||
DATA_PATH = '/scratch/penzkofer/GQA'
|
||||
RGB_COLORS = []
|
||||
for name, hex in mcolors.cnames.items():
|
||||
RGB_COLORS.append(mcolors.to_rgb(hex))
|
||||
|
||||
CUDA_DEVICE = 7
|
||||
torch.cuda.set_device(CUDA_DEVICE)
|
||||
|
||||
device = torch.device("cuda:" + str(CUDA_DEVICE))
|
||||
clip_model, preprocess = clip.load("ViT-B/32", device=device)
|
||||
|
||||
with open('gqa_all_relations_map.json') as f:
|
||||
RELATION_DICT = json.load(f)
|
||||
|
||||
with open('gqa_all_vocab_classes.json') as f:
|
||||
CLASS_DICT = json.load(f)
|
||||
|
||||
with open('gqa_all_attributes.json') as f:
|
||||
ATTRIBUTE_DICT = json.load(f)
|
||||
|
||||
SYNONYMS = {'he': ['man', 'boy'], 'she': ['woman', 'girl']}
|
||||
ANSWER_MAP = {'to the right of': 'right', 'to the left of': 'left'}
|
||||
VISUALIZE = False
|
||||
|
||||
|
||||
def plot_heatmap_multidim(sp, xs, ys, heatmap_vectors, name='', vmin=-1, vmax=1, cmap='plasma', invert=False):
|
||||
"""adapted from https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master"""
|
||||
|
||||
assert sp.__class__.__name__ == 'SemanticPointer', \
|
||||
f'Queried object needs to be of type SemanticPointer but is {sp.__class__.__name__}'
|
||||
|
||||
# axes: a list of axes to be summed over, first sequence applying to first tensor, second to second tensor
|
||||
vs = np.tensordot(sp.v, heatmap_vectors, axes=([0], [4]))
|
||||
res = np.unravel_index(np.argmax(vs, axis=None), vs.shape)
|
||||
|
||||
plt.imshow(np.transpose(vs[:, :, res[2], res[3]]), origin='upper', interpolation='none', extent=(xs[-1], xs[0], ys[-1], ys[0]), vmin=vmin, vmax=vmax, cmap=cmap)
|
||||
plt.colorbar()
|
||||
plt.axis('off')
|
||||
plt.title(name)
|
||||
plt.show()
|
||||
|
||||
def select_ssp(name, memory, encoded_ssps, vectors, linspace):
|
||||
"""decode location of object with name from SSP memory"""
|
||||
ssp_item = encoded_ssps[name]
|
||||
item_decoded = memory *~ ssp_item
|
||||
clean_loc = ssp_to_loc_multidim(item_decoded, vectors, linspace)
|
||||
return item_decoded, clean_loc
|
||||
|
||||
|
||||
def clip_query(bbox, img, obj_name, clip_tokens, visualize=False):
|
||||
"""Implements CLIP queries for different attributes"""
|
||||
x, y, w, h = bbox
|
||||
obj_center = (x + w / 2, y + h / 2)
|
||||
masked_img = img.copy()
|
||||
masked_img = cv2.ellipse(masked_img, (int(obj_center[0]), int(obj_center[1])), (int(w*0.7), int(h*0.7)),
|
||||
0, 0, 360, (255, 0, 0), 4)
|
||||
|
||||
if visualize:
|
||||
plt.imshow(masked_img)
|
||||
plt.axis('off')
|
||||
plt.show()
|
||||
|
||||
masked_img = Image.fromarray(np.uint8(masked_img))
|
||||
|
||||
tokens = clip.tokenize(clip_tokens)
|
||||
|
||||
with torch.no_grad():
|
||||
|
||||
text_features=clip_model.encode_text(tokens.to(device))
|
||||
image_features = clip_model.encode_image(preprocess(masked_img).unsqueeze(0).to(device))
|
||||
|
||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
|
||||
indices = torch.max(similarity, 1)[1]
|
||||
|
||||
similarity = similarity.squeeze()
|
||||
scores = [s.item() for s in similarity]
|
||||
pred = clip_tokens[indices.squeeze().item()]
|
||||
|
||||
if visualize:
|
||||
print('CLIP')
|
||||
print(scores)
|
||||
print(pred)
|
||||
return indices.squeeze().item()
|
||||
|
||||
def clip_query_scene(img, clip_tokens, verbose=0):
|
||||
"""Implements CLIP queries for entire scene, i.e. no bounding box selection"""
|
||||
img = Image.fromarray(np.uint8(img))
|
||||
tokens = clip.tokenize(clip_tokens)
|
||||
|
||||
with torch.no_grad():
|
||||
|
||||
text_features=clip_model.encode_text(tokens.to(device))
|
||||
image_features = clip_model.encode_image(preprocess(img).unsqueeze(0).to(device))
|
||||
|
||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
|
||||
indices = torch.max(similarity, 1)[1]
|
||||
|
||||
similarity = similarity.squeeze()
|
||||
scores = [s.item() for s in similarity]
|
||||
pred = clip_tokens[indices.squeeze().item()]
|
||||
|
||||
if verbose > 0:
|
||||
print('CLIP')
|
||||
print(scores)
|
||||
print(pred)
|
||||
|
||||
return indices.squeeze().item()
|
||||
|
||||
|
||||
def clip_choose(bbox1, bbox2, img, attribute, visualize=False):
|
||||
"""Run attribute vs. not attribute check for both subjects
|
||||
-- select clip prediction with higher confidence"""
|
||||
|
||||
x1, y1, w1, h1 = bbox1
|
||||
obj_center = (x1 + w1 / 2, y1 + h1 / 2)
|
||||
masked_img1 = img.copy()
|
||||
masked_img1 = cv2.ellipse(masked_img1, (int(obj_center[0]), int(obj_center[1])), (int(w1*0.7), int(h1*0.7)),
|
||||
0, 0, 360, (255, 0, 0), 4)
|
||||
|
||||
x2, y2, w2, h2 = bbox2
|
||||
obj_center = (x2 + w2 / 2, y2 + h2 / 2)
|
||||
masked_img2 = img.copy()
|
||||
masked_img2 = cv2.ellipse(masked_img2, (int(obj_center[0]), int(obj_center[1])), (int(w2*0.7), int(h2*0.7)),
|
||||
0, 0, 360, (255, 0, 0), 4)
|
||||
|
||||
if visualize:
|
||||
plt.imshow(masked_img1)
|
||||
plt.axis('off')
|
||||
plt.show()
|
||||
plt.imshow(masked_img2)
|
||||
plt.axis('off')
|
||||
plt.show()
|
||||
|
||||
masked_img1 = Image.fromarray(np.uint8(masked_img1))
|
||||
masked_img2 = Image.fromarray(np.uint8(masked_img2))
|
||||
|
||||
tokens = clip.tokenize([attribute, f'not {attribute}'])
|
||||
|
||||
with torch.no_grad():
|
||||
|
||||
text_features=clip_model.encode_text(tokens.to(device))
|
||||
image_features = clip_model.encode_image(preprocess(masked_img1).unsqueeze(0).to(device))
|
||||
|
||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||
similarity1 = (100.0 * image_features @ text_features.T).softmax(dim=-1)
|
||||
similarity1 = similarity1.squeeze()[0]
|
||||
|
||||
with torch.no_grad():
|
||||
|
||||
text_features=clip_model.encode_text(tokens.to(device))
|
||||
image_features = clip_model.encode_image(preprocess(masked_img2).unsqueeze(0).to(device))
|
||||
|
||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||
similarity2 = (100.0 * image_features @ text_features.T).softmax(dim=-1)
|
||||
similarity2 = similarity2.squeeze()[0]
|
||||
|
||||
if visualize:
|
||||
logging.info(f'CLIP prediction: {[similarity1, similarity2]}')
|
||||
|
||||
return 0 if similarity1 > similarity2 else 1
|
||||
|
||||
|
||||
def get_rel_path(rel, verbose=0):
|
||||
"""get correct relation path, map rel to one of the 37 existing query masks"""
|
||||
|
||||
rel = RELATION_DICT.get(rel.strip()) # get synonym of relation if no mask exists
|
||||
rel = '_'.join(rel.split(' ')) if ' ' in rel else rel
|
||||
path = 'relations/' + rel + '.npy'
|
||||
if verbose > 0:
|
||||
logging.info('Loading ', path)
|
||||
return path
|
||||
|
||||
def use_query_mask(obj_pos, info, rel, linspace, axes, dim, memory, verbose=0, visualize=False):
|
||||
"""implements query mask usage: load spatial query mask for relation rel,
|
||||
scale query mask to object, encode mask to SSP region and
|
||||
move to object position in SSP memory, extract object proposals in region"""
|
||||
|
||||
xs, ys, ws, hs = linspace
|
||||
x_axis, y_axis, w_axis, h_axis = axes
|
||||
|
||||
x, y, width, height = obj_pos
|
||||
# 50 pixels was object size in query mask generation -- use pixel scale values for height and width
|
||||
iso_scale = np.mean([(width*info['scales'][1]) / 50, (height*info['scales'][2]) / 50])
|
||||
|
||||
mask = np.load(get_rel_path(rel, verbose))
|
||||
mask = cv2.resize(mask, (100, 100), interpolation = cv2.INTER_AREA)
|
||||
|
||||
# crop new query mask according to scale
|
||||
new_area = int(100 / iso_scale)
|
||||
if verbose > 0:
|
||||
print(iso_scale, new_area)
|
||||
|
||||
resized = mask[max(0, 50-new_area): min(50+new_area, 100), max(0, 50-new_area): min(50+new_area, 100)]
|
||||
resized = cv2.resize(resized, (100,100), interpolation = cv2.INTER_AREA)
|
||||
|
||||
if visualize:
|
||||
fig, axs = plt.subplots(1,2, sharey=True, layout="constrained", figsize=(6, 3))
|
||||
fig.set_tight_layout(True)
|
||||
fig.subplots_adjust(top=1.05)
|
||||
fig.suptitle('Relation: '+ rel)
|
||||
plt.setp(plt.gcf().get_axes(), xticks=[], yticks=[])
|
||||
|
||||
axs[0].imshow(mask, cmap='gray')
|
||||
axs[0].title.set_text('original')
|
||||
axs[1].imshow(resized, cmap='gray')
|
||||
axs[1].title.set_text('resized')
|
||||
plt.show()
|
||||
|
||||
# encode mask to SSP region
|
||||
counter = 0
|
||||
vector = spa.SemanticPointer(data=np.zeros(dim))
|
||||
for (i, j) in zip(*np.where(resized > 0.05)):
|
||||
x, y = xs[i], ys[j]
|
||||
vector += encode_point_multidim([y, x, 1, 1], axes=axes)
|
||||
counter += 1
|
||||
vector.normalize()
|
||||
if verbose > 0:
|
||||
logging.info(f'Resized mask encoded {counter} points')
|
||||
|
||||
if visualize:
|
||||
plot_heatmap_multidim(vector, xs, ys, VECTORS, vmin=-0.2, vmax=0.2, name=f'Encoded Region')
|
||||
|
||||
# get object info and move query mask to position
|
||||
x, y, width, height = obj_pos
|
||||
obj_center = (x + width / 2, y + height / 2)
|
||||
img_center = np.array([xs[50], ys[50]])
|
||||
|
||||
shift = -img_center + (obj_center[1], obj_center[0])
|
||||
encoded_shift = encode_point(shift[1], shift[0], x_axis=x_axis, y_axis=y_axis)
|
||||
|
||||
shifted_pos = vector.convolve(encoded_shift)
|
||||
shifted_pos.normalize()
|
||||
|
||||
if visualize:
|
||||
plot_heatmap_multidim(shifted_pos, xs, ys, VECTORS, vmin=-0.1, vmax=0.1, name=f'Query Region')
|
||||
|
||||
# query region and compare output to vocab = all saved SSPs
|
||||
vocab_vectors = np.zeros((len(info['encoded_ssps']), dim))
|
||||
|
||||
color_lst = []
|
||||
for i, (name, ssp) in enumerate(info['encoded_ssps'].items()):
|
||||
vocab_vectors[i, :] = ssp.v
|
||||
color_lst.append(RGB_COLORS[i])
|
||||
|
||||
similarity = np.tensordot((memory * ~shifted_pos).v, vocab_vectors, axes=([0], [1]))
|
||||
|
||||
d = OrderedDict(zip(list(info['encoded_ssps'].keys()), similarity))
|
||||
res = list(OrderedDict(sorted(d.items(), key=lambda x: x[1], reverse=True)))
|
||||
proposals = np.array(res)[np.array(np.array(sorted(similarity, reverse=True)) > 0).astype(bool)]
|
||||
|
||||
if verbose > 0:
|
||||
print('Proposals: ', *proposals)
|
||||
|
||||
if visualize:
|
||||
fig = plt.Figure()
|
||||
plt.bar(np.arange(len(info['encoded_ssps'])), similarity, color=color_lst, label=info['encoded_items'].keys())
|
||||
plt.title('Similarity')
|
||||
plt.legend(loc='upper left', bbox_to_anchor=(1.0, 1.05))
|
||||
plt.show()
|
||||
|
||||
return proposals
|
||||
|
||||
def select_func(data, img, obj, info, counter, memory, visualize=False):
|
||||
"""implements select function: probe SSP memory with given object name"""
|
||||
|
||||
result = None
|
||||
|
||||
if obj + '_' + str(counter) in info['encoded_ssps'].keys():
|
||||
obj += '_' + str(counter)
|
||||
obj_ssp, obj_pos = select_ssp(obj, memory, info['encoded_ssps'], data.ssp_vectors, data.linspace)
|
||||
result = [obj, obj_ssp, obj_pos]
|
||||
|
||||
else:
|
||||
# test synonyms and plural/singular
|
||||
test = [singularize(obj), pluralize(obj)]
|
||||
|
||||
if SYNONYMS.get(obj):
|
||||
test += SYNONYMS.get(obj)
|
||||
|
||||
if CLASS_DICT.get(obj):
|
||||
test += CLASS_DICT.get(obj)
|
||||
|
||||
for obj in test:
|
||||
obj = str(obj) + '_' + str(counter)
|
||||
if obj in info['encoded_ssps'].keys():
|
||||
obj_ssp, obj_pos = select_ssp(obj, memory, info['encoded_ssps'], data.ssp_vectors, data.linspace)
|
||||
result = [obj, obj_ssp, obj_pos]
|
||||
|
||||
if result is not None and visualize:
|
||||
fig, ax = plt.subplots(1,1)
|
||||
plt.axis('off')
|
||||
ax.imshow(img)
|
||||
rect = patches.Rectangle((obj_pos[0] * info['scales'][0], obj_pos[1] * info['scales'][0]),
|
||||
obj_pos[2] * info['scales'][1], obj_pos[3] * info['scales'][2],
|
||||
linewidth = 2,
|
||||
label = name,
|
||||
edgecolor = 'c',
|
||||
facecolor = 'none')
|
||||
ax.add_patch(rect)
|
||||
plt.show()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def verify_func(data, img, attr, results, info, dim, memory, verbose=0, visualize=False):
|
||||
"""implements all verify functions dependent on length of attributes given"""
|
||||
|
||||
if len(attr) == 2:
|
||||
# verify_color, verify_shape, verify_scene
|
||||
num = int(re.findall(r'\d+', attr[0])[0])
|
||||
|
||||
if results[num] is not None:
|
||||
name, obj_ssp, obj_pos = results[num]
|
||||
x, y = obj_pos[0] * info['scales'][0], obj_pos[1] * info['scales'][0]
|
||||
w, h = obj_pos[2] * info['scales'][1], obj_pos[3] * info['scales'][2]
|
||||
|
||||
clip_tokens = [f'The {name.split("_")[0]} is {attr[1].strip()}',
|
||||
f'The {name.split("_")[0]} is not {attr[1].strip()}']
|
||||
|
||||
pred = clip_query([x, y, w, h], img, name.split('_')[0], clip_tokens, visualize=visualize)
|
||||
|
||||
return True if pred == 0 else False
|
||||
|
||||
else:
|
||||
return False
|
||||
|
||||
elif len(attr) == 3:
|
||||
# verify_rel, verify_rel_inv
|
||||
obj, rel, rel_obj = attr
|
||||
num = int(re.findall(r'\d+', obj)[0])
|
||||
|
||||
proposals = []
|
||||
|
||||
if results[num] is not None:
|
||||
name, obj_ssp, obj_pos = results[num]
|
||||
|
||||
proposals = use_query_mask(obj_pos, info, rel, data.linspace, data.ssp_axes, dim, memory, visualize=visualize)
|
||||
proposals = [str(p).split('_')[0] for p in proposals]
|
||||
|
||||
return True if rel_obj.strip() in proposals else False
|
||||
else:
|
||||
return False
|
||||
|
||||
# verify_f
|
||||
elif len(attr) == 1:
|
||||
|
||||
clip_tokens = [f'The image is {attr[0].strip()}',
|
||||
f'The image is not {attr[0].strip()}',
|
||||
f'The image is a {attr[0].strip()}',
|
||||
f'The image is not a{attr[0].strip()}']
|
||||
|
||||
pred = clip_query_scene(img, clip_tokens, verbose=0)
|
||||
|
||||
return True if pred == 0 or pred == 2 else False
|
||||
|
||||
else:
|
||||
logging.warning('verify_func not implemented')
|
||||
|
||||
return -1
|
||||
|
||||
|
||||
def query_func(func, img, attr, results, info, img_size, dim, verbose=0, visualize=False):
|
||||
"""implements all query functions"""
|
||||
|
||||
if 'query_f(' in func:
|
||||
attr_type = attr[0].strip()
|
||||
|
||||
assert attr_type in CLASS_DICT, f'{attr_type} not found in class dictionary'
|
||||
|
||||
attributes = CLASS_DICT.get(attr_type)
|
||||
clip_tokens = [f'This is a {a} {attr_type}' for a in attributes]
|
||||
pred = clip_query_scene(img, clip_tokens, verbose=0)
|
||||
|
||||
return attributes[pred]
|
||||
|
||||
num = int(re.findall(r'\d+', attr[0])[0])
|
||||
|
||||
if results[num] is not None:
|
||||
name, obj_ssp, obj_pos = results[num]
|
||||
x, y = obj_pos[0] * info['scales'][0], obj_pos[1] * info['scales'][0]
|
||||
w, h = obj_pos[2] * info['scales'][1], obj_pos[3] * info['scales'][2]
|
||||
|
||||
if len(attr) == 1:
|
||||
|
||||
if 'query_n' in func:
|
||||
# query name
|
||||
return name.split('_')[0]
|
||||
|
||||
if 'query_h' in func:
|
||||
# query horizontal position --> x-value
|
||||
if (x + w / 2) >= (img_size[1] / 2):
|
||||
return 'right'
|
||||
else:
|
||||
return 'left'
|
||||
|
||||
if 'query_v' in func:
|
||||
# query vertical position --> y-value
|
||||
if (y + h / 2) >= (img_size[0] / 2):
|
||||
return 'bottom'
|
||||
else:
|
||||
return 'top'
|
||||
|
||||
elif len(attr) == 2:
|
||||
|
||||
attr_type = attr[1].strip()
|
||||
|
||||
assert attr_type in ATTRIBUTE_DICT, f'{attr_type} not found in attribute dictionary'
|
||||
|
||||
attributes = ATTRIBUTE_DICT.get(attr_type)
|
||||
clip_tokens = [f'The {attr_type} of {name.split("_")[0]} is {a}' for a in attributes]
|
||||
pred = clip_query([x, y, w, h], img, name.split('_')[0], clip_tokens, visualize=visualize)
|
||||
|
||||
return attributes[pred]
|
||||
|
||||
else:
|
||||
return None
|
||||
|
||||
logging.warning('query not implemented', func, attr)
|
||||
return -1
|
||||
|
||||
|
||||
def relate_func(data, func, attr, results, info, dim, memory, visualize=False):
|
||||
"""implements all relationship functions"""
|
||||
|
||||
if 'relate_inv_name' in func or 'relate_name' in func:
|
||||
obj, rel, rel_obj = attr
|
||||
num = int(re.findall(r'\d+', attr[0])[0])
|
||||
|
||||
if results[num] is not None:
|
||||
name, obj_ssp, obj_pos = results[num]
|
||||
|
||||
proposals = use_query_mask(obj_pos, info, rel, data.linspace, data.ssp_axes, dim,
|
||||
memory, verbose=0, visualize=visualize)
|
||||
|
||||
selected_obj = proposals[0]
|
||||
|
||||
# use rel_obj to filter proposals
|
||||
rel_obj = rel_obj.strip()
|
||||
if rel_obj == selected_obj.split('_')[0]:
|
||||
if visualize:
|
||||
logging.info('Found perfect match\n')
|
||||
return selected_obj
|
||||
|
||||
elif rel_obj in [str(p).split('_')[0] for p in proposals]:
|
||||
idx = [str(p).split('_')[0] for p in proposals].index(rel_obj)
|
||||
return proposals[idx]
|
||||
|
||||
elif rel_obj in CLASS_DICT.keys():
|
||||
class_lst = CLASS_DICT.get(rel_obj)
|
||||
|
||||
for p in [str(p).split('_')[0] for p in proposals]:
|
||||
if p in class_lst or singularize(p) in class_lst:
|
||||
if visualize:
|
||||
logging.info(f'Found better proposal for {rel_obj}: {p}\n')
|
||||
idx = [str(p).split('_')[0] for p in proposals].index(p)
|
||||
return proposals[idx]
|
||||
else:
|
||||
if visualize:
|
||||
logging.info(f'Did not find {rel_obj} in proposals\n')
|
||||
return None
|
||||
|
||||
elif 'relate_inv' in func or 'relate(' in func:
|
||||
obj, rel = attr
|
||||
num = int(re.findall(r'\d+', attr[0])[0])
|
||||
|
||||
if results[num] is not None:
|
||||
name, obj_ssp, obj_pos = results[num]
|
||||
|
||||
proposals = use_query_mask(obj_pos, info, rel, data.linspace, data.ssp_axes, dim,
|
||||
memory, visualize=visualize)
|
||||
|
||||
return proposals[0]
|
||||
|
||||
else:
|
||||
logging.warning(f'{func} not implemented')
|
||||
return -1
|
||||
|
||||
|
||||
def filter_func(func, img, attr, img_size, results, info, visualize=False):
|
||||
"""implements all filter functions"""
|
||||
|
||||
obj, filter_attr = attr
|
||||
num = int(re.findall(r'\d+', attr[0])[0])
|
||||
|
||||
if results[num] is not None:
|
||||
name, obj_ssp, obj_pos = results[num]
|
||||
x, y = obj_pos[0] * info['scales'][0], obj_pos[1] * info['scales'][0]
|
||||
w, h = obj_pos[2] * info['scales'][1], obj_pos[3] * info['scales'][2]
|
||||
|
||||
# query height --> y-value
|
||||
if 'bottom' in filter_attr or 'top' in filter_attr:
|
||||
if (y + h / 2) >= (img_size[0] / 2):
|
||||
pred_attr = 'bottom'
|
||||
else:
|
||||
pred_attr = 'top'
|
||||
|
||||
return pred_attr == filter_attr.strip()
|
||||
|
||||
# query side --> x-value
|
||||
if 'right' in filter_attr or 'left' in filter_attr:
|
||||
if (x + w / 2) >= (img_size[1] / 2):
|
||||
pred_attr = 'right'
|
||||
else:
|
||||
pred_attr = 'left'
|
||||
|
||||
return pred_attr == filter_attr.strip()
|
||||
|
||||
# filter by attribute: color, shape, activity, material
|
||||
else:
|
||||
|
||||
clip_tokens = [f'The {name.split("_")[0]} is {filter_attr}',
|
||||
f'The {name.split("_")[0]} is not {filter_attr}']
|
||||
|
||||
pred = clip_query([x, y, w, h], img, name.split('_')[0], clip_tokens, visualize=visualize)
|
||||
|
||||
if 'not' in func:
|
||||
return True if pred == 1 else False
|
||||
else:
|
||||
return True if pred == 0 else False
|
||||
|
||||
else:
|
||||
if visualize:
|
||||
logging.info('No object was found in last step -- nothing to filter')
|
||||
return None
|
||||
|
||||
return -1
|
||||
|
||||
def choose_func(data, img, func, attr, img_size, results, info, dim, memory, verbose=0, visualize=False):
|
||||
"""implements all choose functions"""
|
||||
|
||||
if 'choose_f' in func:
|
||||
pred = clip_query_scene(img, attr, verbose=0)
|
||||
return attr[pred]
|
||||
|
||||
num = int(re.findall(r'\d+', attr[0])[0])
|
||||
|
||||
if results[num] is not None:
|
||||
name, obj_ssp, obj_pos = results[num]
|
||||
|
||||
if 'choose_h' in func or 'choose_v' in func:
|
||||
x, y = obj_pos[0] * info['scales'][0], obj_pos[1] * info['scales'][0]
|
||||
w, h = obj_pos[2] * info['scales'][1], obj_pos[3] * info['scales'][2]
|
||||
|
||||
# choose side --> x-value
|
||||
if 'choose_h' in func:
|
||||
if (x + w / 2) >= (img_size[1] / 2):
|
||||
return 'right'
|
||||
else:
|
||||
return 'left'
|
||||
|
||||
# choose vertical alignment --> y-value
|
||||
if 'choose_v' in func:
|
||||
if (y + h / 2) >= (img_size[0] / 2):
|
||||
return 'bottom'
|
||||
else:
|
||||
return 'top'
|
||||
|
||||
elif 'choose_n' in func:
|
||||
obj, name1, name2 = attr
|
||||
|
||||
if name == name1:
|
||||
return name1
|
||||
elif name == name2:
|
||||
return name2
|
||||
else:
|
||||
return None
|
||||
|
||||
elif 'choose_attr' in func:
|
||||
obj, attr_type, attr1, attr2 = attr
|
||||
x, y = obj_pos[0] * info['scales'][0], obj_pos[1] * info['scales'][0]
|
||||
w, h = obj_pos[2] * info['scales'][1], obj_pos[3] * info['scales'][2]
|
||||
|
||||
clip_tokens = [f'The {attr_type} of {name.split("_")[0]} is {attr1}',
|
||||
f'The {attr_type} of {name.split("_")[0]} is {attr2}']
|
||||
|
||||
pred = clip_query([x, y, w, h], img, name.split('_')[0], clip_tokens, visualize=visualize)
|
||||
|
||||
attr_lst = [attr1.strip(), attr2.strip()]
|
||||
|
||||
if visualize:
|
||||
logging.info(f'Choose attribute: {attr_type} of {name} --> clip prediction: {attr_lst[pred]}')
|
||||
|
||||
return attr_lst[pred]
|
||||
|
||||
elif 'choose_rel_inv' in func:
|
||||
obj, rel_obj, attr1, attr2 = attr
|
||||
|
||||
proposals1 = use_query_mask(obj_pos, info, attr1, data.linspace, data.ssp_axes, dim,
|
||||
memory, 0, visualize)
|
||||
proposals1 = [str(p).split('_')[0] for p in proposals1]
|
||||
|
||||
proposals2 = use_query_mask(obj_pos, info, attr2, data.linspace, data.ssp_axes, dim,
|
||||
memory, 0, visualize)
|
||||
proposals2 = [str(p).split('_')[0] for p in proposals2]
|
||||
|
||||
if rel_obj.strip() in proposals1:
|
||||
return ANSWER_MAP.get(attr1.strip()) if attr1.strip() in ANSWER_MAP else attr1.strip()
|
||||
elif rel_obj.strip() in proposals2:
|
||||
return ANSWER_MAP.get(attr2.strip()) if attr2.strip() in ANSWER_MAP else attr2.strip()
|
||||
else:
|
||||
return None
|
||||
|
||||
elif 'choose_subj' in func:
|
||||
subj1, subj2, attribute = attr
|
||||
num1 = int(re.findall(r'\d+', subj1)[0])
|
||||
num2 = int(re.findall(r'\d+', subj2)[0])
|
||||
|
||||
if results[num1] is not None and results[num2] is not None:
|
||||
name1, obj_ssp1, obj_pos1 = results[num1]
|
||||
name2, obj_ssp2, obj_pos2 = results[num2]
|
||||
|
||||
x1, y1 = obj_pos1[0] * info['scales'][0], obj_pos1[1] * info['scales'][0]
|
||||
w1, h1 = obj_pos1[2] * info['scales'][1], obj_pos1[3] * info['scales'][2]
|
||||
x2, y2 = obj_pos2[0] * info['scales'][0], obj_pos2[1] * info['scales'][0]
|
||||
w2, h2 = obj_pos2[2] * info['scales'][1], obj_pos2[3] * info['scales'][2]
|
||||
|
||||
pred = clip_choose([x1, y1, w1, h1], [x2, y2, w2, h2], img, attribute, visualize=visualize)
|
||||
|
||||
return name1.split('_')[0] if pred == 0 else name2.split('_')[0]
|
||||
|
||||
elif 'choose(' in func:
|
||||
obj, attr1, attr2 = attr
|
||||
x, y = obj_pos[0] * info['scales'][0], obj_pos[1] * info['scales'][0]
|
||||
w, h = obj_pos[2] * info['scales'][1], obj_pos[3] * info['scales'][2]
|
||||
|
||||
clip_tokens = [f'The {name.split("_")[0]} is {attr1}',
|
||||
f'The {name.split("_")[0]} is {attr2}']
|
||||
|
||||
pred = clip_query([x, y, w, h], img, name.split('_')[0], clip_tokens, visualize=visualize)
|
||||
|
||||
attr_lst = [attr1.strip(), attr2.strip()]
|
||||
|
||||
if visualize:
|
||||
logging.info(f'Choose {attr1} or {attr2} for {name} --> clip prediction: {attr_lst[pred]}')
|
||||
|
||||
return attr_lst[pred]
|
||||
|
||||
else:
|
||||
logging.warning(func, 'not implemented yet')
|
||||
return -1
|
||||
|
||||
else:
|
||||
if visualize:
|
||||
logging.info('No object was found in last step -- nothing to choose')
|
||||
return None
|
||||
|
||||
return -1
|
||||
|
||||
|
||||
def run_program(data, img, info, counter, memory, dim, verbose=0):
|
||||
""" run program for question on given image:
|
||||
for each step in program select appropriate function
|
||||
"""
|
||||
scale, w_scale, h_scale = info['scales']
|
||||
img_size = img.shape[:2]
|
||||
|
||||
results = []
|
||||
last_step = False
|
||||
last_func = None
|
||||
|
||||
for i, step in enumerate(info['program']):
|
||||
|
||||
if i+1 == len(info['program']):
|
||||
last_step = True
|
||||
|
||||
_, func = step.split('=')
|
||||
attr = func.split('(')[-1].split(')')[0].split(',')
|
||||
|
||||
if verbose > 0:
|
||||
logging.info(f'{i+1}. step: \t {func}')
|
||||
|
||||
if 'select' in func:
|
||||
obj = attr[0].strip()
|
||||
res = select_func(data, img, obj, info, counter, memory, visualize=VISUALIZE)
|
||||
|
||||
results.append(res)
|
||||
|
||||
if res is None:
|
||||
if verbose > 1:
|
||||
logging.info(f'Could not find {obj}')
|
||||
|
||||
|
||||
elif 'relate' in func:
|
||||
|
||||
found_rel_obj = relate_func(data, func, attr, results, info, dim, memory, visualize=VISUALIZE)
|
||||
|
||||
if found_rel_obj is not None:
|
||||
assert found_rel_obj in info['encoded_ssps'], f'Result of {func}: {found_rel_obj} is not encoded'
|
||||
|
||||
selected_ssp = info['encoded_ssps'][found_rel_obj]
|
||||
_, selected_pos = select_ssp(found_rel_obj, memory, info['encoded_ssps'], data.ssp_vectors, data.linspace)
|
||||
results.append([found_rel_obj, selected_ssp, selected_pos])
|
||||
|
||||
if last_step:
|
||||
return 'yes'
|
||||
else:
|
||||
results.append(None)
|
||||
|
||||
if last_step:
|
||||
return 'no'
|
||||
|
||||
elif 'filter' in func:
|
||||
|
||||
last_filter = filter_func(func, img, attr, img_size, results, info, visualize=VISUALIZE)
|
||||
|
||||
if last_filter:
|
||||
results.append(results[-1])
|
||||
else:
|
||||
if results[-1] is None:
|
||||
results.append(None)
|
||||
elif results[-1][0].split("_")[0] + "_" + str(counter+1) in info['encoded_ssps'].keys():
|
||||
counter += 1
|
||||
return None
|
||||
else:
|
||||
last_filter = False
|
||||
results.append(results[-1])
|
||||
|
||||
elif 'verify' in func:
|
||||
pred = verify_func(data, img, attr, results, info, dim, memory, verbose=verbose, visualize=VISUALIZE)
|
||||
|
||||
if 'verify_relation_name' in func or 'verify_relation_inv_name' in func:
|
||||
results.append(results[-1] if pred else None)
|
||||
else:
|
||||
results.append(pred)
|
||||
|
||||
if last_step:
|
||||
return 'yes' if pred else 'no'
|
||||
|
||||
|
||||
elif 'query' in func:
|
||||
|
||||
return query_func(func, img, attr, results, info, img_size, dim, verbose=verbose, visualize=VISUALIZE)
|
||||
|
||||
elif 'exist' in func:
|
||||
|
||||
num = int(re.findall(r'\d+', attr[0])[0])
|
||||
|
||||
if last_step:
|
||||
return 'yes' if results[num] is not None else 'no'
|
||||
else:
|
||||
if results[num] is not None and 'filter' not in last_func:
|
||||
results.append(True)
|
||||
elif results[num] is not None and last_filter:
|
||||
results.append(True)
|
||||
else:
|
||||
results.append(False)
|
||||
|
||||
elif 'or(' in func:
|
||||
attr1 = int(re.findall(r'\d+', attr[0])[0])
|
||||
attr2 = int(re.findall(r'\d+', attr[1])[0])
|
||||
return 'yes' if results[attr1] or results[attr2] else 'no'
|
||||
|
||||
elif 'and(' in func:
|
||||
attr1 = int(re.findall(r'\d+', attr[0])[0])
|
||||
attr2 = int(re.findall(r'\d+', attr[1])[0])
|
||||
return 'yes' if results[attr1] and results[attr2] else 'no'
|
||||
|
||||
elif 'different' in func:
|
||||
|
||||
if len(attr) == 1:
|
||||
logging.warning(f'{func} cannot be computed')
|
||||
return None
|
||||
|
||||
else:
|
||||
pred_attr1 = query_func(f'query_{attr[2].strip()}', img, [attr[0], attr[2]], results, info, img_size, dim)
|
||||
pred_attr2 = query_func(f'query_{attr[2].strip()}', img, [attr[1], attr[2]], results, info, img_size, dim)
|
||||
|
||||
if pred_attr1 != pred_attr2:
|
||||
return 'yes'
|
||||
else:
|
||||
return 'no'
|
||||
|
||||
elif 'same' in func:
|
||||
|
||||
if len(attr) == 1:
|
||||
logging.warning(f'{func} cannot be computed')
|
||||
return None
|
||||
|
||||
pred_attr1 = query_func(f'query_{attr[2].strip()}', img, [attr[0], attr[2]], results, info, img_size, dim)
|
||||
pred_attr2 = query_func(f'query_{attr[2].strip()}', img, [attr[1], attr[2]], results, info, img_size, dim)
|
||||
|
||||
if pred_attr1 == pred_attr2:
|
||||
return 'yes'
|
||||
else:
|
||||
return 'no'
|
||||
|
||||
elif 'choose' in func:
|
||||
return choose_func(data, img, func, attr, img_size, results, info, dim, memory, visualize=VISUALIZE)
|
||||
|
||||
else:
|
||||
logging.warning(f'{func} not implemented')
|
||||
return -1
|
||||
|
||||
last_func = func
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
TEST = True
|
||||
DIM = 2048
|
||||
RANDOM_SEED = 17
|
||||
np.random.seed(RANDOM_SEED)
|
||||
torch.manual_seed(RANDOM_SEED)
|
||||
random.seed(RANDOM_SEED)
|
||||
|
||||
x = datetime.now()
|
||||
TIME_STAMP = x.strftime("%d%b%y-%H%M")
|
||||
log_file = f"logs/run{TIME_STAMP}-{'VAL' if TEST else 'TRAIN'}{RANDOM_SEED}.log"
|
||||
log_file = f"logs/run{TIME_STAMP}-DIM{DIM}-{RANDOM_SEED}.log"
|
||||
|
||||
logging.basicConfig(level=logging.INFO, filename=log_file,
|
||||
filemode="w", format="%(asctime)s %(levelname)s %(message)s")
|
||||
|
||||
print('Logging to ', log_file)
|
||||
DATA_PATH = '/scratch/penzkofer/GQA'
|
||||
CUDA_DEVICE = 7
|
||||
torch.cuda.set_device(CUDA_DEVICE)
|
||||
|
||||
device = torch.device("cuda:" + str(CUDA_DEVICE))
|
||||
clip_model, preprocess = clip.load("ViT-B/32", device=device)
|
||||
|
||||
with open('gqa_all_relations_map.json') as f:
|
||||
RELATION_DICT = json.load(f)
|
||||
|
||||
with open('gqa_vocab_classes.json') as f:
|
||||
CLASS_DICT = json.load(f)
|
||||
|
||||
with open('gqa_all_attributes.json') as f:
|
||||
ATTRIBUTE_DICT = json.load(f)
|
||||
|
||||
start = time.time()
|
||||
res = 100
|
||||
dim = DIM
|
||||
new_size = (25, 25) # size should be smaller than resolution!
|
||||
|
||||
xs = np.linspace(0, new_size[1], res)
|
||||
ys = np.linspace(0, new_size[0], res)
|
||||
ws = np.linspace(1, 10, 10)
|
||||
hs = np.linspace(1, 10, 10)
|
||||
|
||||
rng = np.random.RandomState(seed=RANDOM_SEED)
|
||||
x_axis = make_good_unitary(dim, rng=rng)
|
||||
y_axis = make_good_unitary(dim, rng=rng)
|
||||
w_axis = make_good_unitary(dim, rng=rng)
|
||||
h_axis = make_good_unitary(dim, rng=rng)
|
||||
|
||||
logging.info(f'Size of vector space: {res**2}x{10**2}x{dim}')
|
||||
logging.info(f'x-axis resolution = {len(xs)}, y-axis resolution = {len(ys)}')
|
||||
logging.info(f'width resolution = {len(ws)}, height resolution = {len(hs)}')
|
||||
|
||||
# precompute the vectors
|
||||
VECTORS = get_heatmap_vectors_multidim(xs, ys, ws, hs, x_axis, y_axis, w_axis, h_axis)
|
||||
logging.info(VECTORS.shape)
|
||||
logging.info(f'Took {time.time() - start}seconds to load vectors.\n')
|
||||
|
||||
# load questions, programs and scenegraphs
|
||||
if TEST:
|
||||
questions_path = 'val_balanced_questions.json'
|
||||
programs_path = 'programs/trainval_balanced_programs.json'
|
||||
scene_path = 'val_sceneGraphs.json'
|
||||
else:
|
||||
questions_path = 'train_balanced_questions.json'
|
||||
programs_path = 'programs/trainval_balanced_programs.json'
|
||||
scene_path = 'train_sceneGraphs.json'
|
||||
|
||||
with open(os.path.join(DATA_PATH, questions_path), 'r') as f:
|
||||
questions = json.load(f)
|
||||
|
||||
with open(os.path.join(DATA_PATH, programs_path), 'r') as f:
|
||||
programs = json.load(f)
|
||||
|
||||
with open(os.path.join(DATA_PATH, scene_path), 'r') as f:
|
||||
scenegraphs = json.load(f)
|
||||
|
||||
columns = ['semantic', 'entailed', 'equivalent', 'question', 'imageId', 'isBalanced', 'groups',
|
||||
'answer', 'semanticStr', 'annotations', 'types', 'fullAnswer']
|
||||
|
||||
questions = pd.DataFrame.from_dict(questions, orient='index', columns=columns)
|
||||
questions = questions.reset_index()
|
||||
questions = questions.rename(columns={"index": "questionID"}, errors="raise")
|
||||
|
||||
columns = ['imageID', 'question', 'program', 'questionID', 'answer']
|
||||
programs = pd.DataFrame(programs, columns=columns)
|
||||
|
||||
DATA = GQADataset(questions, programs, scenegraphs, vectors=VECTORS,
|
||||
axes=[x_axis, y_axis, w_axis, h_axis], linspace=[xs, ys, ws, hs])
|
||||
|
||||
logging.info(f'Length of data set: {len(DATA)}')
|
||||
|
||||
VISUALIZE = False
|
||||
DATA.set_visualize(VISUALIZE)
|
||||
DATA.set_verbose(0)
|
||||
|
||||
results_lst = []
|
||||
num_correct = 0
|
||||
|
||||
pbar = tqdm(range(len(DATA)), ncols=115)
|
||||
|
||||
for i, IDX in enumerate(pbar):
|
||||
start = time.time()
|
||||
img, info, memory = DATA.encode_item(IDX, dim=dim)
|
||||
avg_mse, avg_iou, correct_items = DATA.decode_item(img, info, memory)
|
||||
|
||||
try:
|
||||
answer = run_program(DATA, img, info, counter=1, memory=memory, dim=dim, verbose=1)
|
||||
except Exception as e:
|
||||
answer = None
|
||||
logging.error(e)
|
||||
|
||||
if answer == -1:
|
||||
logging.warning(f'[{IDX}] not fully implemented!')
|
||||
|
||||
time_in_sec = time.time() - start
|
||||
correct = answer == info["answer"]
|
||||
num_correct += int(correct)
|
||||
|
||||
results = {'q_id':info['q_id'], 'question':info['question'],'program':info['program'], 'image':info['img_id'],
|
||||
'true_answer': info['answer'], 'pred_answer': answer, 'correct': correct, 'time': time_in_sec,
|
||||
'enc_avg_mse': avg_mse, 'enc_avg_iou': avg_iou, 'enc_correct_items': correct_items,
|
||||
'enc_items': len(info["encoded_items"]), 'q_idx': IDX}
|
||||
|
||||
results_lst.append(results)
|
||||
logging.info(f'[{IDX+1}] {num_correct / (i+1):.2%}')
|
||||
pbar.set_postfix({'correct': f'{num_correct / (i+1):.2%}', 'q_idx': str(IDX+1)})
|
||||
|
||||
results_df = pd.DataFrame(results_lst)
|
||||
logging.info(f'Acurracy: {results_df.correct.sum() / len(results_df):.2%}')
|
||||
|
||||
out_path = os.path.join(DATA_PATH, f'results-DIM{DIM}-{"VAL" if TEST else "TRAIN"}{RANDOM_SEED}.pkl')
|
||||
results_df.to_pickle(out_path)
|
||||
logging.info(f'Saved results to {out_path}.')
|
||||
|
298
utils.py
Normal file
298
utils.py
Normal file
|
@ -0,0 +1,298 @@
|
|||
import json
|
||||
import numpy as np
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as patches
|
||||
from pattern.text.en import singularize
|
||||
import nengo.spa as spa
|
||||
import scipy.integrate as integrate
|
||||
|
||||
RGB_COLORS = []
|
||||
hex_colors = ['#8a3ffc', '#ff7eb6', '#6fdc8c', '#d2a106', '#ba4e00', '#33b1ff', '#570408',
|
||||
'#fa4d56', '#4589ff', '#08bdba', '#d4bbff', '#007d79', '#d12771', '#bae6ff']
|
||||
|
||||
for h in hex_colors:
|
||||
RGB_COLORS.append(matplotlib.colors.to_rgb(h))
|
||||
|
||||
for i, (name, h) in enumerate(matplotlib.colors.cnames.items()):
|
||||
if i > 10:
|
||||
RGB_COLORS.append(matplotlib.colors.to_rgb(h))
|
||||
|
||||
|
||||
f = open('gqa_all_relations_map.json')
|
||||
RELATION_DICT = json.load(f)
|
||||
f.close()
|
||||
|
||||
f = open('gqa_all_vocab_classes.json')
|
||||
CLASS_DICT = json.load(f)
|
||||
f.close()
|
||||
|
||||
f = open('gqa_all_attributes.json')
|
||||
ATTRIBUTE_DICT = json.load(f)
|
||||
f.close()
|
||||
|
||||
|
||||
def bbox_to_mask(x, y, w, h, img_size=(500, 500), name=None, visualize=False):
|
||||
img = np.zeros(img_size)
|
||||
mask_w = np.ones(w)
|
||||
for j in range(y, y+h):
|
||||
img[j][x:x+w] = mask_w
|
||||
|
||||
if visualize:
|
||||
fig = plt.figure(figsize=(img_size[0] // 80, img_size[1] // 80))
|
||||
plt.imshow(img, cmap='gray')
|
||||
if name:
|
||||
plt.title(name)
|
||||
plt.axis('off')
|
||||
plt.show()
|
||||
|
||||
return img
|
||||
|
||||
def make_good_unitary(D, eps=1e-3, rng=np.random):
|
||||
"""from https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master"""
|
||||
a = rng.rand((D - 1) // 2)
|
||||
sign = rng.choice((-1, +1), len(a))
|
||||
phi = sign * np.pi * (eps + a * (1 - 2 * eps))
|
||||
assert np.all(np.abs(phi) >= np.pi * eps)
|
||||
assert np.all(np.abs(phi) <= np.pi * (1 - eps))
|
||||
|
||||
fv = np.zeros(D, dtype='complex64')
|
||||
fv[0] = 1
|
||||
fv[1:(D + 1) // 2] = np.cos(phi) + 1j * np.sin(phi)
|
||||
fv[-1:D // 2:-1] = np.conj(fv[1:(D + 1) // 2])
|
||||
if D % 2 == 0:
|
||||
fv[D // 2] = 1
|
||||
|
||||
assert np.allclose(np.abs(fv), 1)
|
||||
v = np.fft.ifft(fv)
|
||||
# assert np.allclose(v.imag, 0, atol=1e-5)
|
||||
v = v.real
|
||||
assert np.allclose(np.fft.fft(v), fv)
|
||||
assert np.allclose(np.linalg.norm(v), 1)
|
||||
return spa.SemanticPointer(v)
|
||||
|
||||
def get_heatmap_vectors(xs, ys, x_axis_sp, y_axis_sp):
|
||||
"""from https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master:
|
||||
Precompute spatial semantic pointers for every location in the linspace
|
||||
Used to quickly compute heat maps by a simple vectorized dot product (matrix multiplication)
|
||||
"""
|
||||
if x_axis_sp.__class__.__name__ == 'SemanticPointer':
|
||||
dim = len(x_axis_sp.v)
|
||||
else:
|
||||
dim = len(x_axis_sp)
|
||||
x_axis_sp = spa.SemanticPointer(data=x_axis_sp)
|
||||
y_axis_sp = spa.SemanticPointer(data=y_axis_sp)
|
||||
|
||||
vectors = np.zeros((len(xs), len(ys), dim))
|
||||
|
||||
for i, x in enumerate(xs):
|
||||
for j, y in enumerate(ys):
|
||||
p = encode_point(
|
||||
x=x, y=y, x_axis=x_axis_sp, y_axis=y_axis_sp,
|
||||
)
|
||||
vectors[i, j, :] = p.v
|
||||
|
||||
return vectors
|
||||
|
||||
def power(s, e):
|
||||
"""from https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master"""
|
||||
x = np.fft.ifft(np.fft.fft(s.v) ** e).real
|
||||
return spa.SemanticPointer(data=x)
|
||||
|
||||
def encode_point(x, y, x_axis, y_axis):
|
||||
"""from https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master"""
|
||||
return power(x_axis, x) * power(y_axis, y)
|
||||
|
||||
def encode_region(x, y, x_axis, y_axis):
|
||||
"""from https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master"""
|
||||
print(integrate.quad(power(x_axis, x) * power(y_axis, y), x, x+28))
|
||||
return integrate.quad(power(x_axis, x) * power(y_axis, y), x, x+28)
|
||||
|
||||
|
||||
def plot_heatmap(img, img_area, encoded_pos, xs, ys, heatmap_vectors, name='', vmin=-1, vmax=1, invert=False):
|
||||
"""from https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master"""
|
||||
assert encoded_pos.__class__.__name__ == 'SemanticPointer'
|
||||
|
||||
# sp has shape (dim,) and heatmap_vectors have shape (xs, ys, dim) so the result will be (xs, ys)
|
||||
vec_sim = np.tensordot(encoded_pos.v, heatmap_vectors, axes=([0], [2]))
|
||||
|
||||
num_plots = 3 if img_area is not None else 2
|
||||
fig, axs = plt.subplots(1, num_plots, figsize=(4 * num_plots + 3, 3))
|
||||
fig.suptitle(name)
|
||||
|
||||
axs[0].imshow(img)
|
||||
axs[0].axis('off')
|
||||
|
||||
if img_area is not None:
|
||||
axs[1].imshow(img_area, cmap='gray')
|
||||
axs[1].set_xticks(np.arange(0, len(xs), 20), np.arange(0, img.shape[1], img.shape[1] / len(xs)).astype(int)[::20])
|
||||
axs[1].set_yticks(np.arange(0, len(ys), 10), np.arange(0, img.shape[0], img.shape[0] / len(ys)).astype(int)[::10])
|
||||
axs[1].axis('off')
|
||||
|
||||
im = axs[2].imshow(np.transpose(vec_sim), origin='upper', interpolation='none', extent=(xs[-1], xs[0], ys[-1], ys[0]), vmin=vmin, vmax=vmax, cmap='plasma')
|
||||
axs[2].axis('off')
|
||||
|
||||
else:
|
||||
im = axs[1].imshow(np.transpose(vec_sim), origin='upper', interpolation='none', extent=(xs[-1], xs[0], ys[-1], ys[0]), vmin=vmin, vmax=vmax, cmap='plasma')
|
||||
axs[1].axis('off')
|
||||
|
||||
fig.colorbar(im, ax=axs.ravel().tolist())
|
||||
plt.show()
|
||||
|
||||
|
||||
def generate_region_vector(desired, xs, ys, x_axis_sp, y_axis_sp):
|
||||
"""from https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master"""
|
||||
vector = np.zeros_like((x_axis_sp.v))
|
||||
for i, x in enumerate(xs):
|
||||
for j, y in enumerate(ys):
|
||||
if desired[j, i] == 1:
|
||||
vector += encode_point(x, y, x_axis_sp, y_axis_sp).v
|
||||
|
||||
sp = spa.SemanticPointer(data=vector)
|
||||
sp.normalize()
|
||||
return sp
|
||||
|
||||
|
||||
def bb_intersection_over_union(boxA, boxB):
|
||||
"""from https://pyimagesearch.com/2016/11/07/intersection-over-union-iou-for-object-detection/"""
|
||||
# determine the (x, y)-coordinates of the intersection rectangle
|
||||
xA = max(boxA[0], boxB[0])
|
||||
yA = max(boxA[1], boxB[1])
|
||||
xB = min(boxA[2], boxB[2])
|
||||
yB = min(boxA[3], boxB[3])
|
||||
|
||||
# compute the area of intersection rectangle
|
||||
interArea = abs(max((xB - xA, 0)) * max((yB - yA), 0))
|
||||
if interArea == 0:
|
||||
return 0
|
||||
# compute the area of both the prediction and ground-truth
|
||||
# rectangles
|
||||
boxAArea = abs((boxA[2] - boxA[0]) * (boxA[3] - boxA[1]))
|
||||
boxBArea = abs((boxB[2] - boxB[0]) * (boxB[3] - boxB[1]))
|
||||
|
||||
# compute the intersection over union by taking the intersection
|
||||
# area and dividing it by the sum of prediction + ground-truth
|
||||
# areas - the interesection area
|
||||
iou = interArea / float(boxAArea + boxBArea - interArea)
|
||||
|
||||
# return the intersection over union value
|
||||
return iou
|
||||
|
||||
|
||||
def encode_point_multidim(values, axes):
|
||||
""" power(x_axis, x) * power(y_axis, y) for variable dimensions """
|
||||
assert len(values) == len(axes), f'number of values {len(values)} does not match number of axes {len(axes)}'
|
||||
res = 1
|
||||
for v, a in zip(values, axes):
|
||||
res *= power(a, v)
|
||||
return res
|
||||
|
||||
def get_heatmap_vectors_multidim(xs, ys, ws, hs, x_axis, y_axis, w_axis, h_axis):
|
||||
""" adaptation of get_heatmap_vectors for 4 dimensions """
|
||||
assert x_axis.__class__.__name__ == 'SemanticPointer', f'Axes need to be of type SemanticPointer but are {x_axis.__class__.__name__}'
|
||||
|
||||
dim = len(x_axis.v)
|
||||
vectors = np.zeros((len(xs), len(ys), len(ws), len(hs), dim))
|
||||
|
||||
for i, x in enumerate(xs):
|
||||
for j, y in enumerate(ys):
|
||||
for n, w in enumerate(ws):
|
||||
for k, h in enumerate(hs):
|
||||
p = encode_point_multidim(values=[x, y, w, h], axes=[x_axis, y_axis, w_axis, h_axis])
|
||||
vectors[i, j, n, k, :] = p.v
|
||||
|
||||
return vectors
|
||||
|
||||
|
||||
def ssp_to_loc_multidim(sp, heatmap_vectors, linspace):
|
||||
""" adaptation of loc_match from https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master
|
||||
Convert an SSP to the approximate 4-dim location that it represents.
|
||||
Uses the heatmap vectors as a lookup table
|
||||
"""
|
||||
xs, ys, ws, hs = linspace
|
||||
|
||||
assert sp.__class__.__name__ == 'SemanticPointer', \
|
||||
f'Queried object needs to be of type SemanticPointer but is {sp.__class__.__name__}'
|
||||
|
||||
# axes: a list of axes to be summed over, first sequence applying to first tensor, second to second tensor
|
||||
vs = np.tensordot(sp.v, heatmap_vectors, axes=([0], [4]))
|
||||
|
||||
res = np.unravel_index(np.argmax(vs, axis=None), vs.shape)
|
||||
|
||||
x = xs[res[0]]
|
||||
y = ys[res[1]]
|
||||
w = ws[res[2]]
|
||||
h = hs[res[3]]
|
||||
|
||||
|
||||
return np.array([x, y, w, h])
|
||||
|
||||
|
||||
def encode_image_ssp(img, sg_data, axes, new_size, dim, visualize=True):
|
||||
"""encode all objects in an image to an SSP memory"""
|
||||
|
||||
img_size = img.shape[:2]
|
||||
|
||||
if img_size[1] / 2 < img_size[0]:
|
||||
scale = img_size[0] / new_size[0]
|
||||
else:
|
||||
scale = img_size[1] / new_size[1]
|
||||
|
||||
# scale width and height to fixed size of 10
|
||||
w_scale = img_size[1] / 10
|
||||
h_scale = img_size[0] / 10
|
||||
|
||||
|
||||
if visualize:
|
||||
print(f'Original image {img_size[1]}x{img_size[0]} --> {np.array(img_size) / scale}')
|
||||
fig, ax = plt.subplots(1,1)
|
||||
ax.imshow(img, interpolation='none', origin='upper', extent=[0, img_size[1] / scale, img_size[0] / scale, 0])
|
||||
plt.axis('off')
|
||||
|
||||
|
||||
encoded_items = {}
|
||||
encoded_ssps = {}
|
||||
|
||||
memory = spa.SemanticPointer(data=np.zeros(dim))
|
||||
name_lst = []
|
||||
|
||||
for i, obj in enumerate(sg_data.items()):
|
||||
id_num, obj_dict = obj
|
||||
name = obj_dict.get('name')
|
||||
name = singularize(name)
|
||||
name_lst.append(name)
|
||||
name += '_' + str(name_lst.count(name))
|
||||
|
||||
# extract ground truth data and scale to fit to SSPs
|
||||
x, y, width, height = obj_dict.get('x'), obj_dict.get('y'), obj_dict.get('w'), obj_dict.get('h')
|
||||
x, y, width, height = x / scale, y / scale, width / w_scale, height / h_scale
|
||||
|
||||
width = width if width >= 1 else 1
|
||||
height = height if height >= 1 else 1
|
||||
|
||||
# Round values to next int (otherwise decoding gets buggy)
|
||||
item = np.round([x, y, width, height], decimals=0).astype(int)
|
||||
encoded_items[name] = item
|
||||
#print(name, item)
|
||||
|
||||
pos = encode_point_multidim(list(item), axes)
|
||||
ssp = spa.SemanticPointer(dim)
|
||||
encoded_ssps[name] = ssp
|
||||
|
||||
memory += ssp * pos
|
||||
|
||||
if visualize:
|
||||
x, y, width, height = item
|
||||
width, height = (width * w_scale) / scale, (height * h_scale) / scale
|
||||
rect = patches.Rectangle((x, y),
|
||||
width, height,
|
||||
linewidth = 2,
|
||||
label = name,
|
||||
edgecolor = 'c',
|
||||
facecolor = 'none')
|
||||
ax.add_patch(rect)
|
||||
|
||||
if visualize:
|
||||
plt.show()
|
||||
|
||||
return encoded_items, encoded_ssps, memory
|
Loading…
Reference in a new issue