initial commit

Anna Penzkofer 2024-04-29 17:18:10 +02:00
parent 449dff858d
commit 8e0fd07853
10 changed files with 4550 additions and 1 deletion


@@ -1,3 +1,66 @@
# VSA4VQA
Official code for [VSA4VQA: Scaling a Vector Symbolic Architecture to Visual Question Answering on Natural Images](https://perceptualui.org/publications/penzkofer24_cogsci/) published at CogSci'24
## Installation
```shell
# create environment
conda create -n ssp_env python=3.9 pip
conda activate ssp_env
conda install pytorch torchvision pytorch-cuda=11.8 -c pytorch -c nvidia -y
sudo apt install libmysqlclient-dev
# install requirements
pip install -r requirements.txt
# install CLIP
pip install git+https://github.com/openai/CLIP.git
# setup jupyter notebook kernel
python -m ipykernel install --user --name=ssp_env
```
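To check that the main dependencies are importable, you can run a minimal sanity check such as the following (not part of the repository; the GPU is optional and CLIP downloads its weights on first use):
```python
# quick environment check (illustrative only)
import torch
import clip
import nengo.spa as spa

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
print(f"torch {torch.__version__} on {device}, CLIP input resolution: {model.visual.input_resolution}px")
print(f"Random 512-D semantic pointer, norm = {spa.SemanticPointer(512).length():.2f}")
```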
## Get GQA Programs
Using code from [https://github.com/wenhuchen/Meta-Module-Network](https://github.com/wenhuchen/Meta-Module-Network)<br>
- Download the Meta-Module-Network (MMN) GitHub repo
- Add a `gqa-questions` folder containing the GQA question JSON files
- Run the preprocessing:
`python preprocess.py create_balanced_programs`
- Save the generated programs to the data folder (they are loaded as shown in the sketch below):
```
testdev_balanced_inputs.json
trainval_balanced_inputs.json
testdev_balanced_programs.json
trainval_balanced_programs.json
```
> GQA dictionaries: `gqa_all_attributes.json` and `gqa_all_vocab_classes.json` are also adapted from [https://github.com/wenhuchen/Meta-Module-Network](https://github.com/wenhuchen/Meta-Module-Network)
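The pipeline reads these program files as plain JSON; a minimal sketch of how they are loaded in `run_programs.py` (the path is an example, relative to your data folder):
```python
# minimal sketch: load the generated GQA programs (path is an example)
import json
import pandas as pd

with open('programs/trainval_balanced_programs.json') as f:
    programs = json.load(f)

# column layout used in run_programs.py
columns = ['imageID', 'question', 'program', 'questionID', 'answer']
programs = pd.DataFrame(programs, columns=columns)
# each program is a list of steps of the form '<idx>=<function>(<args>)'
print(programs.iloc[0].program)
```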
## Generate Query Masks
- Generates `full_relations_df.pkl` if it is not already present
- Generates query masks for all relations with more than 1,000 samples
```shell
python generate_query_masks.py
```
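Each query mask is stored as a 500x500 NumPy array in the `relations/` folder; a short sketch for inspecting one (the relation name is just an example):
```python
# minimal sketch: load and display one generated query mask
import numpy as np
import matplotlib.pyplot as plt

mask = np.load('relations/to_the_left_of.npy')  # written by generate_query_masks.py
print(mask.shape)  # (500, 500)

plt.imshow(mask, cmap='gray')
plt.title('to the left of')
plt.axis('off')
plt.show()
```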
## Pipeline
Run the pipeline for all samples in GQA: train_balanced (with `TEST=False`) or validation_balanced (with `TEST=True`).
```shell
python run_programs.py
```
For visualizing samples see [code/GQA_PIPELINE.ipynb](code/GQA_PIPELINE.ipynb) <br>
For generating figures see [code/GQA_EVAL.ipynb](code/GQA_EVAL.ipynb) <br>
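`run_programs.py` writes per-question results to a pandas pickle in the data folder; a minimal sketch for evaluating a finished run (the file name follows the pattern used in the script, adjust `DIM`, split, and seed):
```python
# minimal sketch: compute accuracy from a results pickle written by run_programs.py
import pandas as pd

results = pd.read_pickle('results-DIM2048-VAL17.pkl')  # example: DIM=2048, validation split, seed 17
print(f"Accuracy: {results.correct.mean():.2%}")
print(results[['question', 'true_answer', 'pred_answer']].head())
```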
## Citation
Please consider citing this paper if you use VSA4VQA or parts of this publication in your research:
```
@inproceedings{penzkofer24_cogsci,
author = {Penzkofer, Anna and Shi, Lei and Bulling, Andreas},
title = {VSA4VQA: Scaling A Vector Symbolic Architecture To Visual Question Answering on Natural Images},
booktitle = {Proc. 46th Annual Meeting of the Cognitive Science Society (CogSci)},
year = {2024},
pages = {}
}
```

VSA4VQA_examples.ipynb Normal file (517 additions)

File diff suppressed because one or more lines are too long

dataset.py Normal file (287 additions)

@@ -0,0 +1,287 @@
import os
import cv2
import torch
from torch.utils.data import Dataset
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.patches as patches
import nengo.spa as spa
from utils import encode_point_multidim, ssp_to_loc_multidim, bb_intersection_over_union
RGB_COLORS = []
for name, hex in mcolors.cnames.items():
RGB_COLORS.append(mcolors.to_rgb(hex))
class MNISTQueryDataset(Dataset):
"""MNIST spatial query dataset."""
def __init__(self, mnist_data, num_imgs, img_size=120, visualize=False, transform=None, seed=42):
# Set random seed for location and mnist image selection
np.random.seed(seed)
torch.manual_seed(seed)
self.mnist_data = mnist_data
self.mnist_size = mnist_data[0][0].squeeze().numpy().shape[0]
# Shuffle MNIST data set according to random seed
self.mnist_indices = torch.randperm(len(self.mnist_data))
self.num_imgs = num_imgs
self.img_size = img_size
self.border = self.img_size - self.mnist_size
self.visualize = visualize
self.transform = transform
def __len__(self):
return len(self.mnist_data) // self.num_imgs
def __getitem__(self, idx):
current_indices = self.mnist_indices[idx: idx + self.num_imgs]
image = np.zeros((self.img_size, self.img_size))
mask = np.ones((self.border, self.border))
labels = []
mnist_imgs = []
for i in current_indices:
mnist, label = self.mnist_data[i]
mnist_imgs.append(mnist.squeeze().numpy())
# find available space
indices = np.where(mask == 1)[:2]
coords = np.transpose(indices)
# pick a random free pixel as top-left corner (x0, y0)
pos_idx = np.random.randint(len(indices[0]))
y_pos, x_pos = coords[pos_idx]
# add mnist to image
image[y_pos: y_pos+self.mnist_size, x_pos: x_pos+self.mnist_size] = mnist.squeeze().numpy()
# position label = center of mnist image
labels.append(dict({label: (x_pos + self.mnist_size // 2, y_pos + self.mnist_size // 2)}))
# remove available space
for x in np.arange(max(0, x_pos-self.mnist_size), min(x_pos+self.mnist_size+1, self.border)):
for y in np.arange(max(0, y_pos-self.mnist_size), min(y_pos+self.mnist_size+1, self.border)):
mask[y, x] = 0
# visualize image state and current mask
if self.visualize:
f, (ax1, ax2) = plt.subplots(1, 2, sharey=False)
ax1.imshow(image, cmap='gray')
ax2.imshow(mask, cmap='gray')
plt.show()
sample = {'image': image, 'labels': labels, 'mnist_images': mnist_imgs}
#if self.transform:
# sample = self.transform(sample)
return sample
class GQADataset():
def __init__(self, questions, programs, scenegraphs, vectors, axes, linspace,
path='GQA/images/images/', seed=17, verbose=0, visualize=False):
np.random.seed(seed)
self.questions = questions
self.programs = programs
self.scenegraphs = scenegraphs
self.vis = visualize
self.verbose = verbose
self.seed = seed
self.path = path
# vector space
self.ssp_vectors = vectors
self.ssp_axes = axes #[x_axis, y_axis, w_axis, h_axis]
self.linspace = linspace # [xs, ys, ws, hs]
def __len__(self):
return len(self.questions)
def __get_item__(self, idx):
q_id = self.questions.iloc[idx].questionID
temp_df = self.questions.loc[self.questions.questionID == q_id]
# get image
img_id = temp_df.imageId.values[0]
img_path = os.path.join(self.path, f'{img_id}.jpg')
assert os.path.exists(img_path), f'Image path {img_path} does not exist!'
img = cv2.imread(img_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# get question and answer
question = temp_df.question.values[0]
answer = temp_df.answer.values[0]
full_answer = temp_df.fullAnswer.values[0]
# get program
idx = self.programs.loc[self.programs.questionID == q_id].index[0]
program = self.programs.iloc[idx].program
info = {'q_id': q_id, 'img_id': img_id, 'question': question, 'answer': answer,
'full_answer': full_answer, 'program': program}
return img, info
def encode_item(self, idx, new_size=(25, 25), dim=1024):
""" Encode all objects in image into SSP memory = vector space.
ensure x- and y-axis of SSP memory have same resolution
and fixed width & height axes (10,10), no zero values for width & height
and int values instead of decimals, otherwise decoding accuracy degrades
"""
img, info = self.__get_item__(idx)
sg_data = self.scenegraphs.get(str(info['img_id'])).get('objects')
img_size = img.shape[:2]
# find orientation and select scale to fit into quadratic vector space
if img_size[1] / 2 < img_size[0]:
scale = img_size[0] / new_size[0]
else:
scale = img_size[1] / new_size[1]
# scale width and height to fixed size of 10
w_scale, h_scale = img_size[1] / 10, img_size[0] / 10
encoded_items = {}
encoded_ssps = {}
rng = np.random.RandomState(seed=self.seed)
memory = spa.SemanticPointer(data=np.zeros(dim), rng=rng)
name_lst = []
if self.vis:
print(f'Original image {img_size[0]}x{img_size[1]} --> {int(img_size[0] / scale)}x{int(img_size[1] / scale)}')
fig, ax = plt.subplots(1,1)
ax.imshow(img, interpolation='none', origin='upper', extent=[0, img_size[1] / scale, img_size[0] / scale, 0])
plt.axis('off')
for i, obj in enumerate(sg_data.items()):
id_num, obj_dict = obj
name = obj_dict.get('name')
#name = singularize(name)
name_lst.append(name)
name += '_' + str(name_lst.count(name))
# extract ground truth data and scale to fit to SSPs
x, y, width, height = obj_dict.get('x'), obj_dict.get('y'), obj_dict.get('w'), obj_dict.get('h')
x, y, width, height = x / scale, y / scale, width / w_scale, height / h_scale
width = width if width >= 1 else 1
height = height if height >= 1 else 1
# Round values to next int (otherwise decoding gets buggy)
item = np.round([x, y, width, height], decimals=0).astype(int)
encoded_items[name] = item
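# bind a random SSP identity for this object to its encoded (x, y, w, h) pose and superimpose it into the memory vector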
pos = encode_point_multidim(list(item), self.ssp_axes)
ssp = spa.SemanticPointer(dim)
encoded_ssps[name] = ssp
memory += ssp * pos
if self.vis:
x, y, width, height = item
width, height = (width * w_scale) / scale, (height * h_scale) / scale
rect = patches.Rectangle((x, y),
width, height,
linewidth = 2,
label = name,
edgecolor = RGB_COLORS[i],
facecolor = 'none')
ax.add_patch(rect)
if self.vis:
plt.show()
info['encoded_items'] = encoded_items
info['encoded_ssps'] = encoded_ssps
info['scales'] = [scale, w_scale, h_scale]
return img, info, memory
def decode_item(self, img, info, memory):
img_size = img.shape[:2]
scale, w_scale, h_scale = info['scales']
if self.vis:
fig, ax = plt.subplots(1,1)
ax.imshow(img, interpolation='none', origin='upper', extent=[0, img_size[1] / scale, img_size[0] / scale, 0])
plt.axis('off')
errors = []
iou_lst = []
iou_binary_lst = []
for i, (name, data) in enumerate(info['encoded_items'].items()):
ssp_item = info['encoded_ssps'][name]
item_decoded = memory *~ ssp_item
clean_loc = ssp_to_loc_multidim(item_decoded, self.ssp_vectors, self.linspace)
x, y, width, height = clean_loc
mse = np.square(np.subtract(data[:2], clean_loc[:2])).mean()
errors.append(mse)
width, height = width * w_scale / scale, height * h_scale / scale
bb_gt = np.array([data[0], data[1], data[0]+(data[2] * w_scale / scale), data[1]+(data[3] * h_scale / scale)])
iou = bb_intersection_over_union(bb_gt, [x, y, x+width, y+height])
iou_lst.append(iou)
if iou > 0.5:
iou_binary_lst.append(1)
else:
iou_binary_lst.append(0)
if self.vis:
rect = patches.Rectangle((x, y),
width, height,
linewidth = 2,
label = name,
edgecolor = RGB_COLORS[i],
facecolor = 'none')
ax.add_patch(rect)
if self.vis:
plt.legend(loc='upper left', bbox_to_anchor=(1., 1.02))
plt.show()
avg_mse = np.mean(errors)
avg_iou = np.mean(iou_lst)
if self.verbose > 0:
print(f'Average mean-squared error of 2D locations: {avg_mse:.4f}')
print(f'Average IoU of 4D bounding boxes: {avg_iou:.2f}')
print(f'Correct items: {np.sum(iou_binary_lst)} / {len(info["encoded_items"])}')
return avg_mse, avg_iou, np.sum(iou_binary_lst)
def print_item(self, idx):
_, info = self.__get_item__(idx)
print(f"Question #{info['q_id']}: \n{info['question']}")
print(f"[{info['answer']}] {info['full_answer']}\n")
print('Program:')
for i, step in enumerate(info['program']):
num, func = step.split('=')
print(f'{i}. {func}')
print()
def set_visualize(self, visualize):
self.vis = visualize
def set_verbose(self, verbose):
self.verbose = verbose

generate_query_masks.py Normal file (300 additions)

@@ -0,0 +1,300 @@
import os
import time
import json
import queue
from multiprocessing import Process, Queue
from multiprocessing.pool import Pool
from tqdm import tqdm
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
DATA_PATH = 'GQA/'
REL_PATH = 'full_relations_df.pkl'
IMG_SIZE = (500, 500)
NUM_PROCESSES = 20
NUM_SAMPLES = 100
def bbox_to_mask(x, y, w, h, img_size=IMG_SIZE, name=None, visualize=False):
img = np.zeros(img_size)
mask_w = np.ones(np.clip(w, 0, img_size[1]-x))
for j in range(y, np.clip(y+h, 0, img_size[0])):
img[j][x:x+w] = mask_w
if visualize:
fig = plt.figure(figsize=(img_size[0] // 80, img_size[1] // 80))
plt.imshow(img, cmap='gray')
if name:
plt.title(name)
plt.axis('off')
plt.show()
return img
def get_all_relations_df(data):
print(f'Length of scenegraph data set: {len(data)}')
start = time.time()
df = pd.DataFrame(columns=['image_id', 'relation', 'from', 'to', 'obj_loc', 'obj_w', 'obj_h', 'obj_center',
'rel_obj_loc', 'rel_obj_w', 'rel_obj_h'])
for img_id in data.keys():
all_objects = data.get(str(img_id)).get('objects').items()
# get all object names
all_objects_dict = {id_num: (obj_dict.get('name'), obj_dict.get('x'), obj_dict.get('y'), obj_dict.get('w'), obj_dict.get('h'))
for (id_num, obj_dict) in all_objects}
# get all relations
for obj in all_objects:
id_num, obj_dict = obj
name = obj_dict.get('name')
x, y, width, height = obj_dict.get('x'), obj_dict.get('y'), obj_dict.get('w'), obj_dict.get('h')
center = [x + width / 2, y + height / 2]
for relation in obj_dict.get('relations'):
rel = relation.get('name')
rel_obj, rel_x, rel_y, rel_w, rel_h = all_objects_dict.get(relation.get('object'))
temp = pd.DataFrame.from_dict([{'image_id': img_id, 'relation': rel, 'from': name, 'to': rel_obj,
'obj_loc': [x, y], 'obj_w': width, 'obj_h': height, 'center': center,
'rel_obj_loc': [rel_x, rel_y], 'rel_obj_w': rel_w, 'rel_obj_h': rel_h}])
df = pd.concat([df, temp], ignore_index=True)
#print(f'{df.iloc[-1]["from"]} {df.iloc[-1].relation} {df.iloc[-1].to}')
out_path = 'all_relations.pkl'
df.to_pickle(out_path)
print(f'Saved df to {out_path}')
end = time.time()
elapsed = end - start
print(f'Took {int(elapsed // 60)}:{int(elapsed % 60)} min:s for all {len(df)} relations --> {elapsed / len(df):.2f}s / relation')
def generate_query_mask(df, rel, i, img_center=np.array([250, 250]), uni_size=np.array([50, 50])):
# uni_obj only needed for visualization in the end
uni_obj = bbox_to_mask(img_center[0] - (uni_size[0] // 2), img_center[1] - (uni_size[1] // 2),
50, 50, img_size=(500, 500))
temp_df = df.loc[df.relation == rel]
print(f'[{i}] Number of "{rel}" samples: {len(temp_df)}')
query_mask = np.zeros((500, 500), dtype=np.uint8)
counter = 0
num_discard = 0
for idx in range(len(temp_df)):
if counter >= NUM_SAMPLES:
print(f'[{i}] Reached {counter} samples for relation "{rel}":')
break
img_id = temp_df.iloc[idx].image_id
img_size = (data.get(img_id)['height'], data.get(img_id)['width'])
# get relative object info and generate binary mask
obj_loc = temp_df.iloc[idx].rel_obj_loc
width = temp_df.iloc[idx].rel_obj_w
height = temp_df.iloc[idx].rel_obj_h
# get mask info and generate binary mask
mask_loc = temp_df.iloc[idx].obj_loc
mask_w = temp_df.iloc[idx].obj_w
mask_h = temp_df.iloc[idx].obj_h
if obj_loc[0] > img_size[1] or obj_loc[1] > img_size[0] or mask_loc[0] > img_size[1] or mask_loc[1] > img_size[0]:
#print('error in bounding box -- discard sample')
continue
obj = bbox_to_mask(obj_loc[0], obj_loc[1], width, height, img_size=img_size)
mask = bbox_to_mask(mask_loc[0], mask_loc[1], mask_w, mask_h, img_size=img_size)
img = obj*2 + mask
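# combined image: mask-only pixels = 1, relative-object pixels = 2, overlap = 3 (separated again after warping)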
img_transformed = np.zeros((1000, 1000), dtype=np.uint8)
# scale image first
scale_x, scale_y = uni_size[0] / width, uni_size[1] / height
scale_mat = np.array([[scale_y, 0, 0], [0, scale_x, 0], [0, 0, 1]])
if scale_x > 5 or scale_y > 5:
num_discard += 1
#print(f'Scale is too high! x: {scale_x}, y: {scale_y} -- discard sample')
continue
# warp every non-zero pixel with the scale matrix (r/c indices, so the relation index i is kept intact)
for r, row in enumerate(img):
for c, _ in enumerate(row):
pixel_data = img[r, c]
input_coords = np.array([r, c, 1])
r_out, c_out, _ = scale_mat @ input_coords
if r_out > 0 and r_out < 1000 and c_out > 0 and c_out < 1000 and pixel_data > 0:
# new indices must be within new image -- discard others
img_transformed[int(r_out), int(c_out)] = pixel_data
if not len(np.where(img_transformed >= 2)[0]) > 0:
# no data in transformed image -- discard sample
continue
# find new (x, y) location of object
new_loc = sorted([[y, x] for (y, x) in zip(*np.where(img_transformed >= 2))])[0]
new_center = [new_loc[0] + uni_size[0] // 2, new_loc[1] + uni_size[1] // 2]
# move object to center
move_x, move_y = img_center - new_center
move_mat = np.array([[1, 0, move_x], [0, 1, move_y], [0, 0, 1]])
img_moved = np.zeros((500, 500), dtype=np.uint8)
for r, row in enumerate(img_transformed):
for c, _ in enumerate(row):
pixel_data = img_transformed[r, c]
input_coords = np.array([r, c, 1])
r_out, c_out, _ = move_mat @ input_coords
if r_out > 0 and r_out < 500 and c_out > 0 and c_out < 500 and pixel_data > 0:
# new indices must be within new image -- discard others
img_moved[int(r_out), int(c_out)] = pixel_data
# extract relative object mask and add to query mask
mask_transformed = np.where(img_moved==1, img_moved, 0) + np.where(img_moved==3, img_moved, 0)
query_mask += mask_transformed
counter += 1
if counter > 0:
query_mask = query_mask / counter
rel_name = '_'.join(rel.split(' '))
np.save(f'relations/{rel_name}.npy', query_mask)
print(f'[{i}] Saved query mask to: relations/{rel_name}.npy')
if num_discard > 0:
print(f'[{i}] Discarded {num_discard} samples, because scaling was too high.')
plt.figure(figsize=(3,3))
plt.imshow(uni_obj*0.1+ query_mask, cmap='gray')
plt.title(rel)
plt.axis('off')
plt.savefig(f'relations/{rel_name}.png', bbox_inches='tight', dpi=300)
plt.clf()
else:
print(f'[{i}] Could not generate query mask for "{rel}"')
def run_process(tasks, df):
while True:
try:
'''
Try to get a task from the queue. get_nowait() raises a
queue.Empty exception if the queue is empty
(get(False) behaves the same way).
'''
task = tasks.get_nowait()
i = list(df.relation.unique()).index(task)
except queue.Empty:
break
else:
''' no exception has been raised '''
print(f'[{i}] Starting relation #{i}: {task}')
print()
generate_query_mask(df, task, i)
time.sleep(.5)
return True
# task executed in a worker process
def get_relations_task(img_id):
width, height = data.get(str(img_id))['width'], data.get(str(img_id))['height']
all_objects = data.get(str(img_id)).get('objects').items()
# get all object names
all_objects_dict = {id_num: (obj_dict.get('name'), obj_dict.get('x'), obj_dict.get('y'), obj_dict.get('w'), obj_dict.get('h'))
for (id_num, obj_dict) in all_objects}
all_relations = []
# get all relations
for obj in all_objects:
id_num, obj_dict = obj
name = obj_dict.get('name')
x, y, obj_w, obj_h = obj_dict.get('x'), obj_dict.get('y'), obj_dict.get('w'), obj_dict.get('h')
center = [x + width / 2, y + height / 2]
for relation in obj_dict.get('relations'):
rel = relation.get('name')
rel_obj, rel_x, rel_y, rel_w, rel_h = all_objects_dict.get(relation.get('object'))
all_relations.append({'image_id': img_id, 'width': width, 'height': height, 'relation': rel,
'from': name, 'to': rel_obj, 'obj_loc': [x, y], 'obj_w': obj_w, 'obj_h': obj_h,
'obj_center': center,'rel_obj_loc': [rel_x, rel_y], 'rel_obj_w': rel_w, 'rel_obj_h': rel_h})
return all_relations
if __name__ == '__main__':
path = os.path.join(DATA_PATH, 'train_sceneGraphs.json')
assert os.path.exists(path), f'{path} does not exist!'
with open(os.path.join(DATA_PATH, 'train_sceneGraphs.json'), 'r') as f:
data = json.load(f)
print(f'Length of scenegraph data set: {len(data)}')
if not os.path.exists(REL_PATH):
print('Generating dataframe of all relations...')
# generate list of relations pkl -- use multiprocessing!
# create and configure the process pool
with Pool(processes=NUM_PROCESSES) as pool:
df = pd.DataFrame(columns=['image_id', 'width', 'height', 'relation', 'from', 'to', 'obj_loc', 'obj_w',
'obj_h', 'obj_center', 'rel_obj_loc', 'rel_obj_w', 'rel_obj_h'])
# execute tasks in order
for i, result in enumerate(tqdm(pool.map(get_relations_task, list(data.keys()), chunksize=100))):
temp = pd.DataFrame.from_dict(result)
df = pd.concat([df, temp], ignore_index=True)
if i % 10000 == 0:
df.to_pickle('temp_' + REL_PATH)
print(f'Saved df to {"temp_" + REL_PATH}')
df.to_pickle(REL_PATH)
print(f'Saved df to {REL_PATH}')
else:
df = pd.read_pickle(REL_PATH)
print(f'Number of relations: {len(df.relation.unique())}')
print(df.relation.unique())
# generate query mask for each relation
#for i, rel in enumerate(df.relation.unique()):
# generate_query_mask(df, rel, i)
print('Generating a query mask for each relation...')
# generate query mask for each relation -- use multiprocessing
tasks = Queue()
procs = []
# only use relations with at least 1000 samples
rel_lst = df.relation.value_counts()[df.relation.value_counts() > 1000].index.to_list()
for rel in rel_lst:
tasks.put(rel)
# creating processes -- run only NUM_PROCESSES processes at the same time
for _ in range(NUM_PROCESSES):
p = Process(target=run_process, args=(tasks, df,))
procs.append(p)
p.start()
# completing all processes
for p in procs:
p.join()

gqa_all_attributes.json Normal file (222 additions)

@@ -0,0 +1,222 @@
{
"color": [
"beige",
"black",
"blond",
"blue",
"brown",
"brunette",
"cream colored",
"dark",
"dark blue",
"dark brown",
"gold",
"gray",
"green",
"khaki",
"light blue",
"light brown",
"maroon",
"orange",
"pink",
"purple",
"red",
"silver",
"tan",
"teal",
"white",
"yellow"
],
"pose": [
"bending",
"brushing tooth",
"crouching",
"jumping",
"lying",
"making a face",
"pointing",
"running",
"shaking hand",
"sitting",
"standing",
"taking a photo",
"taking a picture",
"taking picture",
"walking"
],
"material": [
"brick",
"concrete",
"glass",
"leather",
"metal",
"plastic",
"porcelain",
"wood"
],
"activity": [
"brushing tooth",
"cooking",
"drinking",
"driving",
"eating",
"looking down",
"looking up",
"playing",
"posing",
"reading",
"resting",
"sleeping",
"staring",
"talking",
"waiting"
],
"weather": [
"clear",
"cloudless",
"cloudy",
"foggy",
"overcast",
"partly cloudy",
"rainy",
"stormy",
"sunny"
],
"size": [
"giant",
"huge",
"large",
"little",
"small",
"tiny"
],
"fatness": [
"fat",
"skinny",
"thin"
],
"gender": [
"female",
"male"
],
"height": [
"short",
"tall"
],
"state": [
"calm",
"choppy",
"rough",
"smooth",
"still",
"wavy"
],
"hposition": [
"left",
"right"
],
"length": [
"long",
"short"
],
"shape": [
"octagonal",
"rectangular",
"round",
"square",
"triangular"
],
"pattern": [
"checkered",
"dotted",
"striped"
],
"thickness": [
"thick",
"thin"
],
"age": [
"little",
"old",
"young"
],
"tone": [
"light",
"dark"
],
"room": [
"attic",
"bathroom",
"bedroom",
"dining room",
"kitchen",
"living room",
"office"
],
"width": [
"narrow",
"wide"
],
"depth": [
"deep",
"shallow"
],
"cleanliness": [
"clean",
"dirty",
"stained",
"tinted"
],
"hardness": [
"hard",
"soft"
],
"race": [
"asian",
"caucasian"
],
"company": [
"adida",
"nike"
],
"sportActivity": [
"performing trick",
"riding",
"skateboarding",
"skating",
"skiing",
"snowboarding",
"surfing",
"swimming"
],
"sportactivity": [
"performing trick",
"riding",
"skateboarding",
"skating",
"skiing",
"snowboarding",
"surfing",
"swimming"
],
"weight": [
"heavy",
"light"
],
"texture": [
"coarse",
"fine"
],
"flavor": [
"chocolate",
"strawberry",
"vanilla"
],
"realism": [
"fake",
"real"
],
"face expression": [
"making a face"
]
}

gqa_all_relations_map.json Normal file (314 additions)

@@ -0,0 +1,314 @@
{
"to the left of": "to the left of",
"to the right of": "to the right of",
"on": "on",
"wearing": "wearing",
"of": "of",
"near": "near",
"in": "in",
"behind": "behind",
"in front of": "in front of",
"holding": "holding",
"on top of": "on top of",
"next to": "next to",
"above": "above",
"with": "with",
"below": "below",
"by": "by",
"sitting on": "sitting on",
"under": "under",
"on the side of": "on the side of",
"beside": "beside",
"standing on": "standing on",
"inside": "inside",
"carrying": "carrying",
"at": "at",
"walking on": "walking on",
"riding": "riding",
"standing in": "standing in",
"covered by": "covered by",
"around": "around",
"lying on": "lying on",
"hanging on": "hanging on",
"eating": "eating",
"watching": "watching",
"looking at": "looking at",
"covering": "covering",
"sitting in": "sitting in",
"on the front of": "on the front of",
"hanging from": "hanging on",
"parked on": "on",
"riding on": "on",
"using": "holding",
"covered in": "covered by",
"flying in": "sitting in",
"sitting at": "sitting in",
"playing with": "holding",
"full of": "carrying",
"filled with": "carrying",
"walking in": "walking on",
"crossing": "walking on",
"on the back of": "behind",
"surrounded by": "inside",
"swinging": "sitting in",
"standing next to": "next to",
"reflected in": "near",
"covered with": "covered by",
"touching": "holding",
"flying": "near",
"pulling": "holding",
"pulled by": "next to",
"contain": "carrying",
"hitting": "holding",
"leaning on": "next to",
"lying in": "lying on",
"standing by": "next to",
"driving on": "walking on",
"throwing": "near",
"sitting on top of": "on top of",
"surrounding": "around",
"underneath": "below",
"walking down": "walking on",
"parked in": "standing in",
"growing in": "sitting in",
"standing near": "near",
"growing on": "sitting on",
"standing behind": "behind",
"playing": "holding",
"printed on": "on",
"mounted on": "on",
"beneath": "below",
"attached to": "next to",
"talking on": "on",
"facing": "looking at",
"leaning against": "next to",
"cutting": "holding",
"driving": "sitting in",
"worn on": "on",
"resting on": "on",
"floating in": "in",
"lying on top of": "on top of",
"catching": "holding",
"grazing on": "standing on",
"on the bottom of": "below",
"drinking": "holding",
"standing in front of": "in front of",
"topped with": "on top of",
"playing in": "inside",
"walking with": "with",
"swimming in": "inside",
"driving down": "walking on",
"hanging over": "above",
"pushed by": "next to",
"pushing": "holding",
"playing on": "on",
"sitting next to": "next to",
"close to": "near",
"feeding": "holding",
"waiting for": "near",
"between": "next to",
"running on": "walking on",
"tied to": "next to",
"on the edge of": "on top of",
"talking to": "next to",
"holding onto": "holding",
"eating from": "holding",
"perched on": "on",
"reading": "holding",
"parked by": "next to",
"painted on": "on",
"reaching for": "holding",
"sleeping on": "lying on",
"connected to": "next to",
"grazing in": "in",
"hanging above": "above",
"floating on": "on",
"wrapped around": "around",
"stacked on": "on",
"skiing on": "walking on",
"parked at": "next to",
"standing at": "next to",
"hanging in": "hanging on",
"parked near": "near",
"walking across": "walking on",
"plugged into": "next to",
"standing beside": "beside",
"parked next to": "next to",
"working on": "on",
"stuck on": "on",
"stuck in": "in",
"drinking from": "holding",
"seen through": "in front of",
"kicking": "near",
"sitting by": "by",
"sitting in front of": "in front of",
"looking out": "behind",
"petting": "holding",
"parked in front of": "in front of",
"wrapped in": "covered by",
"flying over": "above",
"selling": "holding",
"lying inside": "lying on",
"coming from": "near",
"parked along": "standing on",
"serving": "holding",
"sitting inside": "inside",
"sitting with": "with",
"walking by": "by",
"standing under": "below",
"making": "holding",
"walking through": "walking on",
"standing on top of": "on top of",
"hung on": "below",
"walking along": "by",
"walking near": "near",
"going down": "walking on",
"flying through": "near",
"running in": "walking on",
"leaving": "near",
"mounted to": "on top of",
"sitting behind": "behind",
"on the other side of": "on the side of",
"licking": "holding",
"riding in": "riding",
"followed by": "by",
"following": "by",
"sniffing": "looking at",
"biting": "with",
"parked alongside": "by",
"flying above": "above",
"chasing": "near",
"leading": "near",
"boarding": "near",
"hanging off": "below",
"walking behind": "behind",
"parked behind": "behind",
"sitting near": "near",
"helping": "holding",
"parked beside": "beside",
"growing near": "near",
"sitting under": "below",
"coming out of": "in front of",
"sitting beside": "beside",
"hanging out of": "hanging on",
"served on": "on",
"staring at": "looking at",
"walking toward": "near",
"hugging": "carrying",
"skiing in": "in",
"entering": "in front of",
"looking in": "looking at",
"draped over": "covering",
"walking next to": "next to",
"tied around": "covering",
"growing behind": "behind",
"exiting": "in front of",
"balancing on": "on",
"drawn on": "on",
"jumping over": "above",
"looking down at": "below",
"looking into": "looking at",
"reflecting in": "in front of",
"posing with": "with",
"eating at": "at",
"sewn on": "on",
"walking up": "walking on",
"leaning over": "on the side of",
"about to hit": "holding",
"reflected on": "in front of",
"approaching": "near",
"getting on": "on",
"observing": "watching",
"growing next to": "next to",
"traveling on": "on",
"walking towards": "near",
"growing by": "by",
"displayed on": "on",
"wading in": "standing in",
"growing along": "beside",
"mixed with": "covered by",
"grabbing": "holding",
"jumping on": "walking on",
"scattered on": "on",
"opening": "holding",
"climbing": "walking on",
"pointing at": "at",
"preparing": "holding",
"coming down": "above",
"decorated by": "by",
"decorating": "on",
"taller than": "than",
"going into": "standing in",
"growing from": "on",
"tossing": "holding",
"eating in": "in",
"sleeping in": "inside",
"herding": "near",
"chewing": "eating",
"washing": "holding",
"looking through": "looking at",
"picking up": "holding",
"trying to catch": "holding",
"working in": "in",
"slicing": "holding",
"skiing down": "walking on",
"looking over": "looking at",
"standing against": "next to",
"typing on": "on",
"piled on": "on",
"lying next to": "next to",
"tying": "standing on",
"smiling at": "looking at",
"smoking": "holding",
"cleaning": "carrying",
"shining through": "behind",
"guiding": "near",
"walking to": "near",
"chained to": "next to",
"dragging": "carrying",
"cooking": "holding",
"going through": "holding",
"enclosing": "covering",
"smelling": "eating",
"adjusting": "holding",
"photographing": "looking at",
"skating on": "walking on",
"running through": "walking on",
"decorated with": "with",
"kissing": "next to",
"falling off": "below",
"walking into": "in front of",
"blowing out": "eating",
"walking past": "behind",
"towing": "near",
"worn around": "covering",
"jumping off": "on top of",
"sprinkled on": "on top of",
"moving": "carrying",
"running across": "walking on",
"hidden by": "behind",
"traveling down": "walking on",
"looking toward": "looking at",
"splashing": "near",
"hang from": "below",
"kept in": "inside",
"sitting around": "sitting on",
"displayed in": "inside",
"cooked in": "inside",
"sitting atop": "sitting on",
"brushing": "holding",
"in between": "next to",
"buying": "holding",
"standing around": "next to",
"larger than": "than",
"smaller than": "than",
"pouring": "holding",
"playing at": "at",
"longer than": "than",
"higher than": "than",
"jumping in": "in",
"shorter than": "than",
"bigger than": "than"
}

gqa_all_vocab_classes.json Normal file (1541 additions)

File diff suppressed because it is too large

requirements.txt Normal file (39 additions)

@@ -0,0 +1,39 @@
# Requirements for SSP VQA project
# Standard Libraries
opencv-python
ipykernel
ipywidgets
numpy
pandas
tqdm
matplotlib
imageio
moviepy
scikit-learn
wandb
torchinfo
torchmetrics
# Nengo Libraries
nengo
nbconvert>=7
mistune>=2
nengo-spa
# CLIP requirements
ftfy
regex
tqdm
# DFOL-VQA Libraries
h5py
pyyaml
mysqlclient
pattern
# NLP Libraries
stanza
sexpdata
nltk
svgling

run_programs.py Normal file (968 additions)

@@ -0,0 +1,968 @@
import os
import sys
os.environ["OPENBLAS_NUM_THREADS"] = "10"
import json
import cv2
import re
import random
import time
import torch
import clip
import logging
from datetime import datetime
from PIL import Image
from tqdm import tqdm
import pandas as pd
import numpy as np
import nengo.spa as spa
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.colors as mcolors
from collections import OrderedDict
from pattern.text.en import singularize, pluralize
from dataset import GQADataset
from utils import *
DATA_PATH = '/scratch/penzkofer/GQA'
RGB_COLORS = []
for name, hex in mcolors.cnames.items():
RGB_COLORS.append(mcolors.to_rgb(hex))
CUDA_DEVICE = 7
torch.cuda.set_device(CUDA_DEVICE)
device = torch.device("cuda:" + str(CUDA_DEVICE))
clip_model, preprocess = clip.load("ViT-B/32", device=device)
with open('gqa_all_relations_map.json') as f:
RELATION_DICT = json.load(f)
with open('gqa_all_vocab_classes.json') as f:
CLASS_DICT = json.load(f)
with open('gqa_all_attributes.json') as f:
ATTRIBUTE_DICT = json.load(f)
SYNONYMS = {'he': ['man', 'boy'], 'she': ['woman', 'girl']}
ANSWER_MAP = {'to the right of': 'right', 'to the left of': 'left'}
VISUALIZE = False
def plot_heatmap_multidim(sp, xs, ys, heatmap_vectors, name='', vmin=-1, vmax=1, cmap='plasma', invert=False):
"""adapted from https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master"""
assert sp.__class__.__name__ == 'SemanticPointer', \
f'Queried object needs to be of type SemanticPointer but is {sp.__class__.__name__}'
# axes: a list of axes to be summed over, first sequence applying to first tensor, second to second tensor
vs = np.tensordot(sp.v, heatmap_vectors, axes=([0], [4]))
res = np.unravel_index(np.argmax(vs, axis=None), vs.shape)
plt.imshow(np.transpose(vs[:, :, res[2], res[3]]), origin='upper', interpolation='none', extent=(xs[-1], xs[0], ys[-1], ys[0]), vmin=vmin, vmax=vmax, cmap=cmap)
plt.colorbar()
plt.axis('off')
plt.title(name)
plt.show()
def select_ssp(name, memory, encoded_ssps, vectors, linspace):
"""decode location of object with name from SSP memory"""
ssp_item = encoded_ssps[name]
item_decoded = memory *~ ssp_item
clean_loc = ssp_to_loc_multidim(item_decoded, vectors, linspace)
return item_decoded, clean_loc
def clip_query(bbox, img, obj_name, clip_tokens, visualize=False):
"""Implements CLIP queries for different attributes"""
x, y, w, h = bbox
obj_center = (x + w / 2, y + h / 2)
masked_img = img.copy()
masked_img = cv2.ellipse(masked_img, (int(obj_center[0]), int(obj_center[1])), (int(w*0.7), int(h*0.7)),
0, 0, 360, (255, 0, 0), 4)
if visualize:
plt.imshow(masked_img)
plt.axis('off')
plt.show()
masked_img = Image.fromarray(np.uint8(masked_img))
tokens = clip.tokenize(clip_tokens)
with torch.no_grad():
text_features=clip_model.encode_text(tokens.to(device))
image_features = clip_model.encode_image(preprocess(masked_img).unsqueeze(0).to(device))
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
indices = torch.max(similarity, 1)[1]
similarity = similarity.squeeze()
scores = [s.item() for s in similarity]
pred = clip_tokens[indices.squeeze().item()]
if visualize:
print('CLIP')
print(scores)
print(pred)
return indices.squeeze().item()
def clip_query_scene(img, clip_tokens, verbose=0):
"""Implements CLIP queries for entire scene, i.e. no bounding box selection"""
img = Image.fromarray(np.uint8(img))
tokens = clip.tokenize(clip_tokens)
with torch.no_grad():
text_features=clip_model.encode_text(tokens.to(device))
image_features = clip_model.encode_image(preprocess(img).unsqueeze(0).to(device))
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
indices = torch.max(similarity, 1)[1]
similarity = similarity.squeeze()
scores = [s.item() for s in similarity]
pred = clip_tokens[indices.squeeze().item()]
if verbose > 0:
print('CLIP')
print(scores)
print(pred)
return indices.squeeze().item()
def clip_choose(bbox1, bbox2, img, attribute, visualize=False):
"""Run attribute vs. not attribute check for both subjects
-- select clip prediction with higher confidence"""
x1, y1, w1, h1 = bbox1
obj_center = (x1 + w1 / 2, y1 + h1 / 2)
masked_img1 = img.copy()
masked_img1 = cv2.ellipse(masked_img1, (int(obj_center[0]), int(obj_center[1])), (int(w1*0.7), int(h1*0.7)),
0, 0, 360, (255, 0, 0), 4)
x2, y2, w2, h2 = bbox2
obj_center = (x2 + w2 / 2, y2 + h2 / 2)
masked_img2 = img.copy()
masked_img2 = cv2.ellipse(masked_img2, (int(obj_center[0]), int(obj_center[1])), (int(w2*0.7), int(h2*0.7)),
0, 0, 360, (255, 0, 0), 4)
if visualize:
plt.imshow(masked_img1)
plt.axis('off')
plt.show()
plt.imshow(masked_img2)
plt.axis('off')
plt.show()
masked_img1 = Image.fromarray(np.uint8(masked_img1))
masked_img2 = Image.fromarray(np.uint8(masked_img2))
tokens = clip.tokenize([attribute, f'not {attribute}'])
with torch.no_grad():
text_features=clip_model.encode_text(tokens.to(device))
image_features = clip_model.encode_image(preprocess(masked_img1).unsqueeze(0).to(device))
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity1 = (100.0 * image_features @ text_features.T).softmax(dim=-1)
similarity1 = similarity1.squeeze()[0]
with torch.no_grad():
text_features=clip_model.encode_text(tokens.to(device))
image_features = clip_model.encode_image(preprocess(masked_img2).unsqueeze(0).to(device))
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity2 = (100.0 * image_features @ text_features.T).softmax(dim=-1)
similarity2 = similarity2.squeeze()[0]
if visualize:
logging.info(f'CLIP prediction: {[similarity1, similarity2]}')
return 0 if similarity1 > similarity2 else 1
def get_rel_path(rel, verbose=0):
"""get correct relation path, map rel to one of the 37 existing query masks"""
rel = RELATION_DICT.get(rel.strip()) # get synonym of relation if no mask exists
rel = '_'.join(rel.split(' ')) if ' ' in rel else rel
path = 'relations/' + rel + '.npy'
if verbose > 0:
logging.info(f'Loading {path}')
return path
def use_query_mask(obj_pos, info, rel, linspace, axes, dim, memory, verbose=0, visualize=False):
"""implements query mask usage: load spatial query mask for relation rel,
scale query mask to object, encode mask to SSP region and
move to object position in SSP memory, extract object proposals in region"""
xs, ys, ws, hs = linspace
x_axis, y_axis, w_axis, h_axis = axes
x, y, width, height = obj_pos
# 50 pixels was object size in query mask generation -- use pixel scale values for height and width
iso_scale = np.mean([(width*info['scales'][1]) / 50, (height*info['scales'][2]) / 50])
mask = np.load(get_rel_path(rel, verbose))
mask = cv2.resize(mask, (100, 100), interpolation = cv2.INTER_AREA)
# crop new query mask according to scale
new_area = int(100 / iso_scale)
if verbose > 0:
print(iso_scale, new_area)
resized = mask[max(0, 50-new_area): min(50+new_area, 100), max(0, 50-new_area): min(50+new_area, 100)]
resized = cv2.resize(resized, (100,100), interpolation = cv2.INTER_AREA)
if visualize:
fig, axs = plt.subplots(1,2, sharey=True, layout="constrained", figsize=(6, 3))
fig.set_tight_layout(True)
fig.subplots_adjust(top=1.05)
fig.suptitle('Relation: '+ rel)
plt.setp(plt.gcf().get_axes(), xticks=[], yticks=[])
axs[0].imshow(mask, cmap='gray')
axs[0].title.set_text('original')
axs[1].imshow(resized, cmap='gray')
axs[1].title.set_text('resized')
plt.show()
# encode mask to SSP region
counter = 0
vector = spa.SemanticPointer(data=np.zeros(dim))
for (i, j) in zip(*np.where(resized > 0.05)):
x, y = xs[i], ys[j]
vector += encode_point_multidim([y, x, 1, 1], axes=axes)
counter += 1
vector.normalize()
if verbose > 0:
logging.info(f'Resized mask encoded {counter} points')
if visualize:
plot_heatmap_multidim(vector, xs, ys, VECTORS, vmin=-0.2, vmax=0.2, name=f'Encoded Region')
# get object info and move query mask to position
x, y, width, height = obj_pos
obj_center = (x + width / 2, y + height / 2)
img_center = np.array([xs[50], ys[50]])
shift = -img_center + (obj_center[1], obj_center[0])
encoded_shift = encode_point(shift[1], shift[0], x_axis=x_axis, y_axis=y_axis)
shifted_pos = vector.convolve(encoded_shift)
shifted_pos.normalize()
if visualize:
plot_heatmap_multidim(shifted_pos, xs, ys, VECTORS, vmin=-0.1, vmax=0.1, name=f'Query Region')
# query region and compare output to vocab = all saved SSPs
vocab_vectors = np.zeros((len(info['encoded_ssps']), dim))
color_lst = []
for i, (name, ssp) in enumerate(info['encoded_ssps'].items()):
vocab_vectors[i, :] = ssp.v
color_lst.append(RGB_COLORS[i])
similarity = np.tensordot((memory * ~shifted_pos).v, vocab_vectors, axes=([0], [1]))
d = OrderedDict(zip(list(info['encoded_ssps'].keys()), similarity))
res = list(OrderedDict(sorted(d.items(), key=lambda x: x[1], reverse=True)))
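# keep only vocabulary objects with positive similarity, ordered from most to least similar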
proposals = np.array(res)[np.array(np.array(sorted(similarity, reverse=True)) > 0).astype(bool)]
if verbose > 0:
print('Proposals: ', *proposals)
if visualize:
fig = plt.Figure()
plt.bar(np.arange(len(info['encoded_ssps'])), similarity, color=color_lst, label=info['encoded_items'].keys())
plt.title('Similarity')
plt.legend(loc='upper left', bbox_to_anchor=(1.0, 1.05))
plt.show()
return proposals
def select_func(data, img, obj, info, counter, memory, visualize=False):
"""implements select function: probe SSP memory with given object name"""
result = None
if obj + '_' + str(counter) in info['encoded_ssps'].keys():
obj += '_' + str(counter)
obj_ssp, obj_pos = select_ssp(obj, memory, info['encoded_ssps'], data.ssp_vectors, data.linspace)
result = [obj, obj_ssp, obj_pos]
else:
# test synonyms and plural/singular
test = [singularize(obj), pluralize(obj)]
if SYNONYMS.get(obj):
test += SYNONYMS.get(obj)
if CLASS_DICT.get(obj):
test += CLASS_DICT.get(obj)
for obj in test:
obj = str(obj) + '_' + str(counter)
if obj in info['encoded_ssps'].keys():
obj_ssp, obj_pos = select_ssp(obj, memory, info['encoded_ssps'], data.ssp_vectors, data.linspace)
result = [obj, obj_ssp, obj_pos]
if result is not None and visualize:
fig, ax = plt.subplots(1,1)
plt.axis('off')
ax.imshow(img)
rect = patches.Rectangle((obj_pos[0] * info['scales'][0], obj_pos[1] * info['scales'][0]),
obj_pos[2] * info['scales'][1], obj_pos[3] * info['scales'][2],
linewidth = 2,
label = obj,
edgecolor = 'c',
facecolor = 'none')
ax.add_patch(rect)
plt.show()
return result
def verify_func(data, img, attr, results, info, dim, memory, verbose=0, visualize=False):
"""implements all verify functions dependent on length of attributes given"""
if len(attr) == 2:
# verify_color, verify_shape, verify_scene
num = int(re.findall(r'\d+', attr[0])[0])
if results[num] is not None:
name, obj_ssp, obj_pos = results[num]
x, y = obj_pos[0] * info['scales'][0], obj_pos[1] * info['scales'][0]
w, h = obj_pos[2] * info['scales'][1], obj_pos[3] * info['scales'][2]
clip_tokens = [f'The {name.split("_")[0]} is {attr[1].strip()}',
f'The {name.split("_")[0]} is not {attr[1].strip()}']
pred = clip_query([x, y, w, h], img, name.split('_')[0], clip_tokens, visualize=visualize)
return True if pred == 0 else False
else:
return False
elif len(attr) == 3:
# verify_rel, verify_rel_inv
obj, rel, rel_obj = attr
num = int(re.findall(r'\d+', obj)[0])
proposals = []
if results[num] is not None:
name, obj_ssp, obj_pos = results[num]
proposals = use_query_mask(obj_pos, info, rel, data.linspace, data.ssp_axes, dim, memory, visualize=visualize)
proposals = [str(p).split('_')[0] for p in proposals]
return True if rel_obj.strip() in proposals else False
else:
return False
# verify_f
elif len(attr) == 1:
clip_tokens = [f'The image is {attr[0].strip()}',
f'The image is not {attr[0].strip()}',
f'The image is a {attr[0].strip()}',
f'The image is not a {attr[0].strip()}']
pred = clip_query_scene(img, clip_tokens, verbose=0)
return True if pred == 0 or pred == 2 else False
else:
logging.warning('verify_func not implemented')
return -1
def query_func(func, img, attr, results, info, img_size, dim, verbose=0, visualize=False):
"""implements all query functions"""
if 'query_f(' in func:
attr_type = attr[0].strip()
assert attr_type in CLASS_DICT, f'{attr_type} not found in class dictionary'
attributes = CLASS_DICT.get(attr_type)
clip_tokens = [f'This is a {a} {attr_type}' for a in attributes]
pred = clip_query_scene(img, clip_tokens, verbose=0)
return attributes[pred]
num = int(re.findall(r'\d+', attr[0])[0])
if results[num] is not None:
name, obj_ssp, obj_pos = results[num]
x, y = obj_pos[0] * info['scales'][0], obj_pos[1] * info['scales'][0]
w, h = obj_pos[2] * info['scales'][1], obj_pos[3] * info['scales'][2]
if len(attr) == 1:
if 'query_n' in func:
# query name
return name.split('_')[0]
if 'query_h' in func:
# query horizontal position --> x-value
if (x + w / 2) >= (img_size[1] / 2):
return 'right'
else:
return 'left'
if 'query_v' in func:
# query vertical position --> y-value
if (y + h / 2) >= (img_size[0] / 2):
return 'bottom'
else:
return 'top'
elif len(attr) == 2:
attr_type = attr[1].strip()
assert attr_type in ATTRIBUTE_DICT, f'{attr_type} not found in attribute dictionary'
attributes = ATTRIBUTE_DICT.get(attr_type)
clip_tokens = [f'The {attr_type} of {name.split("_")[0]} is {a}' for a in attributes]
pred = clip_query([x, y, w, h], img, name.split('_')[0], clip_tokens, visualize=visualize)
return attributes[pred]
else:
return None
logging.warning(f'query not implemented: {func} {attr}')
return -1
def relate_func(data, func, attr, results, info, dim, memory, visualize=False):
"""implements all relationship functions"""
if 'relate_inv_name' in func or 'relate_name' in func:
obj, rel, rel_obj = attr
num = int(re.findall(r'\d+', attr[0])[0])
if results[num] is not None:
name, obj_ssp, obj_pos = results[num]
proposals = use_query_mask(obj_pos, info, rel, data.linspace, data.ssp_axes, dim,
memory, verbose=0, visualize=visualize)
selected_obj = proposals[0]
# use rel_obj to filter proposals
rel_obj = rel_obj.strip()
if rel_obj == selected_obj.split('_')[0]:
if visualize:
logging.info('Found perfect match\n')
return selected_obj
elif rel_obj in [str(p).split('_')[0] for p in proposals]:
idx = [str(p).split('_')[0] for p in proposals].index(rel_obj)
return proposals[idx]
elif rel_obj in CLASS_DICT.keys():
class_lst = CLASS_DICT.get(rel_obj)
for p in [str(p).split('_')[0] for p in proposals]:
if p in class_lst or singularize(p) in class_lst:
if visualize:
logging.info(f'Found better proposal for {rel_obj}: {p}\n')
idx = [str(p).split('_')[0] for p in proposals].index(p)
return proposals[idx]
else:
if visualize:
logging.info(f'Did not find {rel_obj} in proposals\n')
return None
elif 'relate_inv' in func or 'relate(' in func:
obj, rel = attr
num = int(re.findall(r'\d+', attr[0])[0])
if results[num] is not None:
name, obj_ssp, obj_pos = results[num]
proposals = use_query_mask(obj_pos, info, rel, data.linspace, data.ssp_axes, dim,
memory, visualize=visualize)
return proposals[0]
else:
logging.warning(f'{func} not implemented')
return -1
def filter_func(func, img, attr, img_size, results, info, visualize=False):
"""implements all filter functions"""
obj, filter_attr = attr
num = int(re.findall(r'\d+', attr[0])[0])
if results[num] is not None:
name, obj_ssp, obj_pos = results[num]
x, y = obj_pos[0] * info['scales'][0], obj_pos[1] * info['scales'][0]
w, h = obj_pos[2] * info['scales'][1], obj_pos[3] * info['scales'][2]
# query height --> y-value
if 'bottom' in filter_attr or 'top' in filter_attr:
if (y + h / 2) >= (img_size[0] / 2):
pred_attr = 'bottom'
else:
pred_attr = 'top'
return pred_attr == filter_attr.strip()
# query side --> x-value
if 'right' in filter_attr or 'left' in filter_attr:
if (x + w / 2) >= (img_size[1] / 2):
pred_attr = 'right'
else:
pred_attr = 'left'
return pred_attr == filter_attr.strip()
# filter by attribute: color, shape, activity, material
else:
clip_tokens = [f'The {name.split("_")[0]} is {filter_attr}',
f'The {name.split("_")[0]} is not {filter_attr}']
pred = clip_query([x, y, w, h], img, name.split('_')[0], clip_tokens, visualize=visualize)
if 'not' in func:
return True if pred == 1 else False
else:
return True if pred == 0 else False
else:
if visualize:
logging.info('No object was found in last step -- nothing to filter')
return None
return -1
def choose_func(data, img, func, attr, img_size, results, info, dim, memory, verbose=0, visualize=False):
"""implements all choose functions"""
if 'choose_f' in func:
pred = clip_query_scene(img, attr, verbose=0)
return attr[pred]
num = int(re.findall(r'\d+', attr[0])[0])
if results[num] is not None:
name, obj_ssp, obj_pos = results[num]
if 'choose_h' in func or 'choose_v' in func:
x, y = obj_pos[0] * info['scales'][0], obj_pos[1] * info['scales'][0]
w, h = obj_pos[2] * info['scales'][1], obj_pos[3] * info['scales'][2]
# choose side --> x-value
if 'choose_h' in func:
if (x + w / 2) >= (img_size[1] / 2):
return 'right'
else:
return 'left'
# choose vertical alignment --> y-value
if 'choose_v' in func:
if (y + h / 2) >= (img_size[0] / 2):
return 'bottom'
else:
return 'top'
elif 'choose_n' in func:
obj, name1, name2 = attr
# compare against the plain object name; the parsed attributes still carry leading whitespace
if name.split('_')[0] == name1.strip():
return name1.strip()
elif name.split('_')[0] == name2.strip():
return name2.strip()
else:
return None
elif 'choose_attr' in func:
obj, attr_type, attr1, attr2 = attr
x, y = obj_pos[0] * info['scales'][0], obj_pos[1] * info['scales'][0]
w, h = obj_pos[2] * info['scales'][1], obj_pos[3] * info['scales'][2]
clip_tokens = [f'The {attr_type} of {name.split("_")[0]} is {attr1}',
f'The {attr_type} of {name.split("_")[0]} is {attr2}']
pred = clip_query([x, y, w, h], img, name.split('_')[0], clip_tokens, visualize=visualize)
attr_lst = [attr1.strip(), attr2.strip()]
if visualize:
logging.info(f'Choose attribute: {attr_type} of {name} --> clip prediction: {attr_lst[pred]}')
return attr_lst[pred]
elif 'choose_rel_inv' in func:
obj, rel_obj, attr1, attr2 = attr
proposals1 = use_query_mask(obj_pos, info, attr1, data.linspace, data.ssp_axes, dim,
memory, 0, visualize)
proposals1 = [str(p).split('_')[0] for p in proposals1]
proposals2 = use_query_mask(obj_pos, info, attr2, data.linspace, data.ssp_axes, dim,
memory, 0, visualize)
proposals2 = [str(p).split('_')[0] for p in proposals2]
if rel_obj.strip() in proposals1:
return ANSWER_MAP.get(attr1.strip()) if attr1.strip() in ANSWER_MAP else attr1.strip()
elif rel_obj.strip() in proposals2:
return ANSWER_MAP.get(attr2.strip()) if attr2.strip() in ANSWER_MAP else attr2.strip()
else:
return None
elif 'choose_subj' in func:
subj1, subj2, attribute = attr
num1 = int(re.findall(r'\d+', subj1)[0])
num2 = int(re.findall(r'\d+', subj2)[0])
if results[num1] is not None and results[num2] is not None:
name1, obj_ssp1, obj_pos1 = results[num1]
name2, obj_ssp2, obj_pos2 = results[num2]
x1, y1 = obj_pos1[0] * info['scales'][0], obj_pos1[1] * info['scales'][0]
w1, h1 = obj_pos1[2] * info['scales'][1], obj_pos1[3] * info['scales'][2]
x2, y2 = obj_pos2[0] * info['scales'][0], obj_pos2[1] * info['scales'][0]
w2, h2 = obj_pos2[2] * info['scales'][1], obj_pos2[3] * info['scales'][2]
pred = clip_choose([x1, y1, w1, h1], [x2, y2, w2, h2], img, attribute, visualize=visualize)
return name1.split('_')[0] if pred == 0 else name2.split('_')[0]
elif 'choose(' in func:
obj, attr1, attr2 = attr
x, y = obj_pos[0] * info['scales'][0], obj_pos[1] * info['scales'][0]
w, h = obj_pos[2] * info['scales'][1], obj_pos[3] * info['scales'][2]
clip_tokens = [f'The {name.split("_")[0]} is {attr1}',
f'The {name.split("_")[0]} is {attr2}']
pred = clip_query([x, y, w, h], img, name.split('_')[0], clip_tokens, visualize=visualize)
attr_lst = [attr1.strip(), attr2.strip()]
if visualize:
logging.info(f'Choose {attr1} or {attr2} for {name} --> clip prediction: {attr_lst[pred]}')
return attr_lst[pred]
else:
logging.warning(f'{func} not implemented yet')
return -1
else:
if visualize:
logging.info('No object was found in last step -- nothing to choose')
return None
return -1
def run_program(data, img, info, counter, memory, dim, verbose=0):
""" run program for question on given image:
for each step in program select appropriate function
"""
scale, w_scale, h_scale = info['scales']
img_size = img.shape[:2]
results = []
last_step = False
last_func = None
for i, step in enumerate(info['program']):
if i+1 == len(info['program']):
last_step = True
_, func = step.split('=')
attr = func.split('(')[-1].split(')')[0].split(',')
if verbose > 0:
logging.info(f'{i+1}. step: \t {func}')
if 'select' in func:
obj = attr[0].strip()
res = select_func(data, img, obj, info, counter, memory, visualize=VISUALIZE)
results.append(res)
if res is None:
if verbose > 1:
logging.info(f'Could not find {obj}')
elif 'relate' in func:
found_rel_obj = relate_func(data, func, attr, results, info, dim, memory, visualize=VISUALIZE)
if found_rel_obj is not None:
assert found_rel_obj in info['encoded_ssps'], f'Result of {func}: {found_rel_obj} is not encoded'
selected_ssp = info['encoded_ssps'][found_rel_obj]
_, selected_pos = select_ssp(found_rel_obj, memory, info['encoded_ssps'], data.ssp_vectors, data.linspace)
results.append([found_rel_obj, selected_ssp, selected_pos])
if last_step:
return 'yes'
else:
results.append(None)
if last_step:
return 'no'
elif 'filter' in func:
last_filter = filter_func(func, img, attr, img_size, results, info, visualize=VISUALIZE)
if last_filter:
results.append(results[-1])
else:
if results[-1] is None:
results.append(None)
elif results[-1][0].split("_")[0] + "_" + str(counter+1) in info['encoded_ssps'].keys():
counter += 1
return None
else:
last_filter = False
results.append(results[-1])
elif 'verify' in func:
pred = verify_func(data, img, attr, results, info, dim, memory, verbose=verbose, visualize=VISUALIZE)
if 'verify_relation_name' in func or 'verify_relation_inv_name' in func:
results.append(results[-1] if pred else None)
else:
results.append(pred)
if last_step:
return 'yes' if pred else 'no'
elif 'query' in func:
return query_func(func, img, attr, results, info, img_size, dim, verbose=verbose, visualize=VISUALIZE)
elif 'exist' in func:
num = int(re.findall(r'\d+', attr[0])[0])
if last_step:
return 'yes' if results[num] is not None else 'no'
else:
if results[num] is not None and 'filter' not in last_func:
results.append(True)
elif results[num] is not None and last_filter:
results.append(True)
else:
results.append(False)
elif 'or(' in func:
attr1 = int(re.findall(r'\d+', attr[0])[0])
attr2 = int(re.findall(r'\d+', attr[1])[0])
return 'yes' if results[attr1] or results[attr2] else 'no'
elif 'and(' in func:
attr1 = int(re.findall(r'\d+', attr[0])[0])
attr2 = int(re.findall(r'\d+', attr[1])[0])
return 'yes' if results[attr1] and results[attr2] else 'no'
elif 'different' in func:
if len(attr) == 1:
logging.warning(f'{func} cannot be computed')
return None
else:
pred_attr1 = query_func(f'query_{attr[2].strip()}', img, [attr[0], attr[2]], results, info, img_size, dim)
pred_attr2 = query_func(f'query_{attr[2].strip()}', img, [attr[1], attr[2]], results, info, img_size, dim)
if pred_attr1 != pred_attr2:
return 'yes'
else:
return 'no'
elif 'same' in func:
if len(attr) == 1:
logging.warning(f'{func} cannot be computed')
return None
pred_attr1 = query_func(f'query_{attr[2].strip()}', img, [attr[0], attr[2]], results, info, img_size, dim)
pred_attr2 = query_func(f'query_{attr[2].strip()}', img, [attr[1], attr[2]], results, info, img_size, dim)
if pred_attr1 == pred_attr2:
return 'yes'
else:
return 'no'
elif 'choose' in func:
return choose_func(data, img, func, attr, img_size, results, info, dim, memory, visualize=VISUALIZE)
else:
logging.warning(f'{func} not implemented')
return -1
last_func = func
if __name__ == "__main__":
TEST = True
DIM = 2048
RANDOM_SEED = 17
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
random.seed(RANDOM_SEED)
x = datetime.now()
TIME_STAMP = x.strftime("%d%b%y-%H%M")
log_file = f"logs/run{TIME_STAMP}-{'VAL' if TEST else 'TRAIN'}{RANDOM_SEED}.log"
log_file = f"logs/run{TIME_STAMP}-DIM{DIM}-{RANDOM_SEED}.log"
logging.basicConfig(level=logging.INFO, filename=log_file,
filemode="w", format="%(asctime)s %(levelname)s %(message)s")
print('Logging to ', log_file)
DATA_PATH = '/scratch/penzkofer/GQA'
CUDA_DEVICE = 7
torch.cuda.set_device(CUDA_DEVICE)
device = torch.device("cuda:" + str(CUDA_DEVICE))
clip_model, preprocess = clip.load("ViT-B/32", device=device)
with open('gqa_all_relations_map.json') as f:
RELATION_DICT = json.load(f)
with open('gqa_all_vocab_classes.json') as f:
CLASS_DICT = json.load(f)
with open('gqa_all_attributes.json') as f:
ATTRIBUTE_DICT = json.load(f)
start = time.time()
res = 100
dim = DIM
new_size = (25, 25) # size should be smaller than resolution!
xs = np.linspace(0, new_size[1], res)
ys = np.linspace(0, new_size[0], res)
ws = np.linspace(1, 10, 10)
hs = np.linspace(1, 10, 10)
rng = np.random.RandomState(seed=RANDOM_SEED)
x_axis = make_good_unitary(dim, rng=rng)
y_axis = make_good_unitary(dim, rng=rng)
w_axis = make_good_unitary(dim, rng=rng)
h_axis = make_good_unitary(dim, rng=rng)
logging.info(f'Size of vector space: {res**2}x{10**2}x{dim}')
logging.info(f'x-axis resolution = {len(xs)}, y-axis resolution = {len(ys)}')
logging.info(f'width resolution = {len(ws)}, height resolution = {len(hs)}')
# precompute the vectors
VECTORS = get_heatmap_vectors_multidim(xs, ys, ws, hs, x_axis, y_axis, w_axis, h_axis)
logging.info(VECTORS.shape)
logging.info(f'Took {time.time() - start} seconds to load vectors.\n')
# load questions, programs and scenegraphs
if TEST:
questions_path = 'val_balanced_questions.json'
programs_path = 'programs/trainval_balanced_programs.json'
scene_path = 'val_sceneGraphs.json'
else:
questions_path = 'train_balanced_questions.json'
programs_path = 'programs/trainval_balanced_programs.json'
scene_path = 'train_sceneGraphs.json'
with open(os.path.join(DATA_PATH, questions_path), 'r') as f:
questions = json.load(f)
with open(os.path.join(DATA_PATH, programs_path), 'r') as f:
programs = json.load(f)
with open(os.path.join(DATA_PATH, scene_path), 'r') as f:
scenegraphs = json.load(f)
columns = ['semantic', 'entailed', 'equivalent', 'question', 'imageId', 'isBalanced', 'groups',
'answer', 'semanticStr', 'annotations', 'types', 'fullAnswer']
questions = pd.DataFrame.from_dict(questions, orient='index', columns=columns)
questions = questions.reset_index()
questions = questions.rename(columns={"index": "questionID"}, errors="raise")
columns = ['imageID', 'question', 'program', 'questionID', 'answer']
programs = pd.DataFrame(programs, columns=columns)
DATA = GQADataset(questions, programs, scenegraphs, vectors=VECTORS,
axes=[x_axis, y_axis, w_axis, h_axis], linspace=[xs, ys, ws, hs])
logging.info(f'Length of data set: {len(DATA)}')
VISUALIZE = False
DATA.set_visualize(VISUALIZE)
DATA.set_verbose(0)
results_lst = []
num_correct = 0
pbar = tqdm(range(len(DATA)), ncols=115)
for i, IDX in enumerate(pbar):
start = time.time()
img, info, memory = DATA.encode_item(IDX, dim=dim)
avg_mse, avg_iou, correct_items = DATA.decode_item(img, info, memory)
try:
answer = run_program(DATA, img, info, counter=1, memory=memory, dim=dim, verbose=1)
except Exception as e:
answer = None
logging.error(e)
if answer == -1:
logging.warning(f'[{IDX}] not fully implemented!')
time_in_sec = time.time() - start
correct = answer == info["answer"]
num_correct += int(correct)
results = {'q_id':info['q_id'], 'question':info['question'],'program':info['program'], 'image':info['img_id'],
'true_answer': info['answer'], 'pred_answer': answer, 'correct': correct, 'time': time_in_sec,
'enc_avg_mse': avg_mse, 'enc_avg_iou': avg_iou, 'enc_correct_items': correct_items,
'enc_items': len(info["encoded_items"]), 'q_idx': IDX}
results_lst.append(results)
logging.info(f'[{IDX+1}] {num_correct / (i+1):.2%}')
pbar.set_postfix({'correct': f'{num_correct / (i+1):.2%}', 'q_idx': str(IDX+1)})
results_df = pd.DataFrame(results_lst)
    logging.info(f'Accuracy: {results_df.correct.sum() / len(results_df):.2%}')
out_path = os.path.join(DATA_PATH, f'results-DIM{DIM}-{"VAL" if TEST else "TRAIN"}{RANDOM_SEED}.pkl')
results_df.to_pickle(out_path)
logging.info(f'Saved results to {out_path}.')

298
utils.py Normal file
View file

@@ -0,0 +1,298 @@
import json
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from pattern.text.en import singularize
import nengo.spa as spa
import scipy.integrate as integrate
RGB_COLORS = []
hex_colors = ['#8a3ffc', '#ff7eb6', '#6fdc8c', '#d2a106', '#ba4e00', '#33b1ff', '#570408',
'#fa4d56', '#4589ff', '#08bdba', '#d4bbff', '#007d79', '#d12771', '#bae6ff']
for h in hex_colors:
RGB_COLORS.append(matplotlib.colors.to_rgb(h))
for i, (name, h) in enumerate(matplotlib.colors.cnames.items()):
if i > 10:
RGB_COLORS.append(matplotlib.colors.to_rgb(h))
with open('gqa_all_relations_map.json') as f:
    RELATION_DICT = json.load(f)

with open('gqa_all_vocab_classes.json') as f:
    CLASS_DICT = json.load(f)

with open('gqa_all_attributes.json') as f:
    ATTRIBUTE_DICT = json.load(f)
def bbox_to_mask(x, y, w, h, img_size=(500, 500), name=None, visualize=False):
    """Create a binary mask of size img_size with a w x h rectangle of ones at (x, y)."""
    img = np.zeros(img_size)
    mask_w = np.ones(w)
for j in range(y, y+h):
img[j][x:x+w] = mask_w
if visualize:
fig = plt.figure(figsize=(img_size[0] // 80, img_size[1] // 80))
plt.imshow(img, cmap='gray')
if name:
plt.title(name)
plt.axis('off')
plt.show()
return img
def make_good_unitary(D, eps=1e-3, rng=np.random):
"""from https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master"""
a = rng.rand((D - 1) // 2)
sign = rng.choice((-1, +1), len(a))
phi = sign * np.pi * (eps + a * (1 - 2 * eps))
assert np.all(np.abs(phi) >= np.pi * eps)
assert np.all(np.abs(phi) <= np.pi * (1 - eps))
fv = np.zeros(D, dtype='complex64')
fv[0] = 1
fv[1:(D + 1) // 2] = np.cos(phi) + 1j * np.sin(phi)
fv[-1:D // 2:-1] = np.conj(fv[1:(D + 1) // 2])
if D % 2 == 0:
fv[D // 2] = 1
assert np.allclose(np.abs(fv), 1)
v = np.fft.ifft(fv)
# assert np.allclose(v.imag, 0, atol=1e-5)
v = v.real
assert np.allclose(np.fft.fft(v), fv)
assert np.allclose(np.linalg.norm(v), 1)
return spa.SemanticPointer(v)
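
# Illustrative sketch (not part of the original pipeline): a "good unitary"
# axis vector has unit norm, and binding it with itself via circular
# convolution (the * operator of spa.SemanticPointer) keeps the norm at 1,
# which is what makes fractional powers for spatial encoding well-behaved.
#   rng = np.random.RandomState(42)
#   X = make_good_unitary(512, rng=rng)
#   print(np.linalg.norm(X.v))            # ~1.0
#   print(np.linalg.norm((X * X * X).v))  # still ~1.0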
def get_heatmap_vectors(xs, ys, x_axis_sp, y_axis_sp):
"""from https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master:
Precompute spatial semantic pointers for every location in the linspace
Used to quickly compute heat maps by a simple vectorized dot product (matrix multiplication)
"""
if x_axis_sp.__class__.__name__ == 'SemanticPointer':
dim = len(x_axis_sp.v)
else:
dim = len(x_axis_sp)
x_axis_sp = spa.SemanticPointer(data=x_axis_sp)
y_axis_sp = spa.SemanticPointer(data=y_axis_sp)
vectors = np.zeros((len(xs), len(ys), dim))
for i, x in enumerate(xs):
for j, y in enumerate(ys):
p = encode_point(
x=x, y=y, x_axis=x_axis_sp, y_axis=y_axis_sp,
)
vectors[i, j, :] = p.v
return vectors
def power(s, e):
"""from https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master"""
x = np.fft.ifft(np.fft.fft(s.v) ** e).real
return spa.SemanticPointer(data=x)
def encode_point(x, y, x_axis, y_axis):
"""from https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master"""
return power(x_axis, x) * power(y_axis, y)
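
# Sketch of the 2-D encode/decode round trip (small example sizes; the axis
# vectors and linspaces here are illustrative, not the ones used in the pipeline):
#   rng = np.random.RandomState(0)
#   x_axis, y_axis = make_good_unitary(1024, rng=rng), make_good_unitary(1024, rng=rng)
#   xs, ys = np.linspace(0, 10, 11), np.linspace(0, 10, 11)
#   hv = get_heatmap_vectors(xs, ys, x_axis, y_axis)    # shape (11, 11, 1024)
#   p = encode_point(3, 7, x_axis, y_axis)              # X^3 (*) Y^7
#   sim = np.tensordot(p.v, hv, axes=([0], [2]))
#   i, j = np.unravel_index(np.argmax(sim), sim.shape)
#   print(xs[i], ys[j])                                 # close to (3.0, 7.0)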
def encode_region(x, y, x_axis, y_axis):
    """from https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master
    NOTE: appears unused here -- scipy.integrate.quad expects a callable,
    so this would need to wrap the binding in a lambda over the integration
    variable before it can run."""
    print(integrate.quad(power(x_axis, x) * power(y_axis, y), x, x+28))
    return integrate.quad(power(x_axis, x) * power(y_axis, y), x, x+28)
def plot_heatmap(img, img_area, encoded_pos, xs, ys, heatmap_vectors, name='', vmin=-1, vmax=1, invert=False):
"""from https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master"""
assert encoded_pos.__class__.__name__ == 'SemanticPointer'
# sp has shape (dim,) and heatmap_vectors have shape (xs, ys, dim) so the result will be (xs, ys)
vec_sim = np.tensordot(encoded_pos.v, heatmap_vectors, axes=([0], [2]))
num_plots = 3 if img_area is not None else 2
fig, axs = plt.subplots(1, num_plots, figsize=(4 * num_plots + 3, 3))
fig.suptitle(name)
axs[0].imshow(img)
axs[0].axis('off')
if img_area is not None:
axs[1].imshow(img_area, cmap='gray')
axs[1].set_xticks(np.arange(0, len(xs), 20), np.arange(0, img.shape[1], img.shape[1] / len(xs)).astype(int)[::20])
axs[1].set_yticks(np.arange(0, len(ys), 10), np.arange(0, img.shape[0], img.shape[0] / len(ys)).astype(int)[::10])
axs[1].axis('off')
im = axs[2].imshow(np.transpose(vec_sim), origin='upper', interpolation='none', extent=(xs[-1], xs[0], ys[-1], ys[0]), vmin=vmin, vmax=vmax, cmap='plasma')
axs[2].axis('off')
else:
im = axs[1].imshow(np.transpose(vec_sim), origin='upper', interpolation='none', extent=(xs[-1], xs[0], ys[-1], ys[0]), vmin=vmin, vmax=vmax, cmap='plasma')
axs[1].axis('off')
fig.colorbar(im, ax=axs.ravel().tolist())
plt.show()
def generate_region_vector(desired, xs, ys, x_axis_sp, y_axis_sp):
"""from https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master"""
vector = np.zeros_like((x_axis_sp.v))
for i, x in enumerate(xs):
for j, y in enumerate(ys):
if desired[j, i] == 1:
vector += encode_point(x, y, x_axis_sp, y_axis_sp).v
sp = spa.SemanticPointer(data=vector)
sp.normalize()
return sp
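
# Sketch (example values): build a rectangular mask with bbox_to_mask and turn
# it into a single region SSP; x_axis and y_axis are assumed to come from
# make_good_unitary as in the sketch above.
#   desired = bbox_to_mask(2, 3, 4, 5, img_size=(20, 20))
#   region_sp = generate_region_vector(desired, np.arange(20), np.arange(20), x_axis, y_axis)
#   # region_sp can then be compared against encoded object locations via dot products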
def bb_intersection_over_union(boxA, boxB):
"""from https://pyimagesearch.com/2016/11/07/intersection-over-union-iou-for-object-detection/"""
# determine the (x, y)-coordinates of the intersection rectangle
xA = max(boxA[0], boxB[0])
yA = max(boxA[1], boxB[1])
xB = min(boxA[2], boxB[2])
yB = min(boxA[3], boxB[3])
# compute the area of intersection rectangle
interArea = abs(max((xB - xA, 0)) * max((yB - yA), 0))
if interArea == 0:
return 0
# compute the area of both the prediction and ground-truth
# rectangles
boxAArea = abs((boxA[2] - boxA[0]) * (boxA[3] - boxA[1]))
boxBArea = abs((boxB[2] - boxB[0]) * (boxB[3] - boxB[1]))
# compute the intersection over union by taking the intersection
# area and dividing it by the sum of prediction + ground-truth
# areas - the interesection area
iou = interArea / float(boxAArea + boxBArea - interArea)
# return the intersection over union value
return iou
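
# Quick sanity check (worked example, not from the original code):
#   boxes [0, 0, 10, 10] and [5, 5, 15, 15] overlap in a 5x5 patch,
#   so IoU = 25 / (100 + 100 - 25) = 25 / 175 ≈ 0.143
#   bb_intersection_over_union([0, 0, 10, 10], [5, 5, 15, 15])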
def encode_point_multidim(values, axes):
""" power(x_axis, x) * power(y_axis, y) for variable dimensions """
assert len(values) == len(axes), f'number of values {len(values)} does not match number of axes {len(axes)}'
res = 1
for v, a in zip(values, axes):
res *= power(a, v)
return res
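
# Sketch (example values): a bounding box (x, y, w, h) becomes one SSP by
# multiplying fractional powers of four independent axis vectors.
#   rng = np.random.RandomState(1)
#   axes = [make_good_unitary(512, rng=rng) for _ in range(4)]
#   box_ssp = encode_point_multidim([5, 12, 3, 4], axes)   # X^5 * Y^12 * W^3 * H^4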
def get_heatmap_vectors_multidim(xs, ys, ws, hs, x_axis, y_axis, w_axis, h_axis):
""" adaptation of get_heatmap_vectors for 4 dimensions """
assert x_axis.__class__.__name__ == 'SemanticPointer', f'Axes need to be of type SemanticPointer but are {x_axis.__class__.__name__}'
dim = len(x_axis.v)
vectors = np.zeros((len(xs), len(ys), len(ws), len(hs), dim))
for i, x in enumerate(xs):
for j, y in enumerate(ys):
for n, w in enumerate(ws):
for k, h in enumerate(hs):
p = encode_point_multidim(values=[x, y, w, h], axes=[x_axis, y_axis, w_axis, h_axis])
vectors[i, j, n, k, :] = p.v
return vectors
def ssp_to_loc_multidim(sp, heatmap_vectors, linspace):
""" adaptation of loc_match from https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master
Convert an SSP to the approximate 4-dim location that it represents.
Uses the heatmap vectors as a lookup table
"""
xs, ys, ws, hs = linspace
assert sp.__class__.__name__ == 'SemanticPointer', \
f'Queried object needs to be of type SemanticPointer but is {sp.__class__.__name__}'
# axes: a list of axes to be summed over, first sequence applying to first tensor, second to second tensor
vs = np.tensordot(sp.v, heatmap_vectors, axes=([0], [4]))
res = np.unravel_index(np.argmax(vs, axis=None), vs.shape)
x = xs[res[0]]
y = ys[res[1]]
w = ws[res[2]]
h = hs[res[3]]
return np.array([x, y, w, h])
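
# Sketch of the full store/query/decode round trip (small example sizes and
# names; the actual resolution and dimensionality are configured in the main
# run script):
#   rng = np.random.RandomState(3)
#   axes = [make_good_unitary(512, rng=rng) for _ in range(4)]
#   xs, ys = np.linspace(0, 25, 26), np.linspace(0, 25, 26)
#   ws, hs = np.linspace(1, 10, 10), np.linspace(1, 10, 10)
#   hv = get_heatmap_vectors_multidim(xs, ys, ws, hs, *axes)
#   cat_ssp = spa.SemanticPointer(512)                      # random ID vector for 'cat_1'
#   memory = cat_ssp * encode_point_multidim([4, 8, 2, 3], axes)
#   print(ssp_to_loc_multidim(memory * ~cat_ssp, hv, [xs, ys, ws, hs]))   # ~[4, 8, 2, 3]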
def encode_image_ssp(img, sg_data, axes, new_size, dim, visualize=True):
"""encode all objects in an image to an SSP memory"""
img_size = img.shape[:2]
if img_size[1] / 2 < img_size[0]:
scale = img_size[0] / new_size[0]
else:
scale = img_size[1] / new_size[1]
# scale width and height to fixed size of 10
w_scale = img_size[1] / 10
h_scale = img_size[0] / 10
if visualize:
print(f'Original image {img_size[1]}x{img_size[0]} --> {np.array(img_size) / scale}')
fig, ax = plt.subplots(1,1)
ax.imshow(img, interpolation='none', origin='upper', extent=[0, img_size[1] / scale, img_size[0] / scale, 0])
plt.axis('off')
encoded_items = {}
encoded_ssps = {}
memory = spa.SemanticPointer(data=np.zeros(dim))
name_lst = []
for i, obj in enumerate(sg_data.items()):
id_num, obj_dict = obj
name = obj_dict.get('name')
name = singularize(name)
name_lst.append(name)
name += '_' + str(name_lst.count(name))
# extract ground truth data and scale to fit to SSPs
x, y, width, height = obj_dict.get('x'), obj_dict.get('y'), obj_dict.get('w'), obj_dict.get('h')
x, y, width, height = x / scale, y / scale, width / w_scale, height / h_scale
width = width if width >= 1 else 1
height = height if height >= 1 else 1
# Round values to next int (otherwise decoding gets buggy)
item = np.round([x, y, width, height], decimals=0).astype(int)
encoded_items[name] = item
#print(name, item)
pos = encode_point_multidim(list(item), axes)
ssp = spa.SemanticPointer(dim)
encoded_ssps[name] = ssp
memory += ssp * pos
if visualize:
x, y, width, height = item
width, height = (width * w_scale) / scale, (height * h_scale) / scale
rect = patches.Rectangle((x, y),
width, height,
linewidth = 2,
label = name,
edgecolor = 'c',
facecolor = 'none')
ax.add_patch(rect)
if visualize:
plt.show()
return encoded_items, encoded_ssps, memory
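
# Usage sketch with a hypothetical scene-graph entry (same structure as the
# GQA sceneGraphs objects; img, axes and dim are assumed to be prepared by the caller):
#   sg_data = {'0': {'name': 'cats', 'x': 40, 'y': 60, 'w': 120, 'h': 90}}
#   items, ssps, memory = encode_image_ssp(img, sg_data, axes, new_size=(25, 25),
#                                          dim=512, visualize=False)
#   # items['cat_1'] holds the scaled integer box; unbinding memory * ~ssps['cat_1']
#   # and decoding with ssp_to_loc_multidim recovers it (see the sketch above).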