VSA4VQA/utils.py
2024-04-29 17:18:10 +02:00

298 lines
No EOL
10 KiB
Python

import json
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from pattern.text.en import singularize
import nengo.spa as spa
import scipy.integrate as integrate
# Plotting palette: 14 hand-picked colors first, then matplotlib's named
# colors (``matplotlib.colors.cnames``) minus the first 11 entries of that
# mapping (the reason for skipping them is not stated in this file).
# Comprehensions are used so no loop variables leak into the module namespace.
hex_colors = ['#8a3ffc', '#ff7eb6', '#6fdc8c', '#d2a106', '#ba4e00', '#33b1ff', '#570408',
              '#fa4d56', '#4589ff', '#08bdba', '#d4bbff', '#007d79', '#d12771', '#bae6ff']
RGB_COLORS = [matplotlib.colors.to_rgb(h) for h in hex_colors]
RGB_COLORS += [matplotlib.colors.to_rgb(h) for h in list(matplotlib.colors.cnames.values())[11:]]
def _load_json(path):
    """Load and return the JSON document at ``path``.

    The path is opened as given, i.e. relative to the current working
    directory. The file is always closed, even if parsing fails.
    """
    with open(path, encoding='utf-8') as fh:
        return json.load(fh)


# GQA lookup tables loaded once at import time.
RELATION_DICT = _load_json('gqa_all_relations_map.json')
CLASS_DICT = _load_json('gqa_all_vocab_classes.json')
ATTRIBUTE_DICT = _load_json('gqa_all_attributes.json')
def bbox_to_mask(x, y, w, h, img_size=(500, 500), name=None, visualize=False):
    """Rasterise an axis-aligned bounding box into a binary mask.

    Args:
        x, y: integer pixel coordinates of the box's top-left corner
            (x is the column, y is the row).
        w, h: integer box width and height in pixels.
        img_size: shape of the returned mask, indexed as (rows, cols).
        name: optional title shown when ``visualize`` is True.
        visualize: if True, display the mask with matplotlib.

    Returns:
        float ndarray of shape ``img_size`` with 1.0 inside the box and 0.0
        elsewhere. Any part of the box lying outside the image is clipped
        instead of raising.
    """
    img = np.zeros(img_size)
    rows, cols = img.shape[:2]
    # Clip the box to the image so boxes touching or crossing a border
    # (or starting at negative coordinates) do not raise.
    y0, y1 = max(y, 0), min(y + h, rows)
    x0, x1 = max(x, 0), min(x + w, cols)
    if y1 > y0 and x1 > x0:
        img[y0:y1, x0:x1] = 1
    if visualize:
        # matplotlib's figsize is (width, height) in inches; keep it non-zero.
        plt.figure(figsize=(max(cols // 80, 1), max(rows // 80, 1)))
        plt.imshow(img, cmap='gray')
        if name:
            plt.title(name)
        plt.axis('off')
        plt.show()
    return img
def make_good_unitary(D: int, eps: float = 1e-3, rng=np.random):
    """Generate a random unitary vector of dimension ``D`` as a SemanticPointer.

    From https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master

    The vector is constructed in the Fourier domain. Every coefficient has
    unit magnitude, which the asserts below check. The coefficients are also
    Hermitian-symmetric, so the inverse FFT is real. Phase magnitudes are
    drawn from ``[pi * eps, pi * (1 - eps))``, which keeps them away from
    0 and pi.

    Args:
        D: dimensionality of the returned vector.
        eps: margin, as a fraction of pi, that keeps sampled phases away
            from 0 and pi.
        rng: object providing ``rand`` and ``choice``; defaults to the
            global ``np.random`` module.

    Returns:
        spa.SemanticPointer wrapping the real-valued vector (unit L2 norm,
        asserted below).
    """
    # One random phase per positive frequency, excluding DC and (for even D)
    # the Nyquist bin: (D - 1) // 2 values.
    a = rng.rand((D - 1) // 2)
    sign = rng.choice((-1, +1), len(a))
    phi = sign * np.pi * (eps + a * (1 - 2 * eps))
    assert np.all(np.abs(phi) >= np.pi * eps)
    assert np.all(np.abs(phi) <= np.pi * (1 - eps))
    fv = np.zeros(D, dtype='complex64')
    # DC component fixed to 1.
    fv[0] = 1
    # Positive frequencies: unit-magnitude complex exponentials e^{i*phi}.
    fv[1:(D + 1) // 2] = np.cos(phi) + 1j * np.sin(phi)
    # Negative frequencies (indices D-1 down to D//2+1) are the complex
    # conjugates of the positive ones, so ifft(fv) is real.
    fv[-1:D // 2:-1] = np.conj(fv[1:(D + 1) // 2])
    if D % 2 == 0:
        # Even D has a Nyquist bin at D // 2; fix it to the real value 1.
        fv[D // 2] = 1
    assert np.allclose(np.abs(fv), 1)
    v = np.fft.ifft(fv)
    # The imaginary part should be ~0 by construction. The upstream check
    # below is disabled (presumably because complex64 round-off can exceed
    # the tolerance -- TODO confirm), and the imaginary part is discarded.
    # assert np.allclose(v.imag, 0, atol=1e-5)
    v = v.real
    assert np.allclose(np.fft.fft(v), fv)
    assert np.allclose(np.linalg.norm(v), 1)
    return spa.SemanticPointer(v)
def get_heatmap_vectors(xs, ys, x_axis_sp, y_axis_sp):
    """Precompute the spatial semantic pointer for every (x, y) grid location.

    Adapted from https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master.
    Used to compute heat maps quickly: a query vector is dotted against the
    returned array in a single vectorized ``tensordot``.

    Args:
        xs, ys: 1-D sequences of coordinates to sample.
        x_axis_sp, y_axis_sp: axis vectors. Each may be a SemanticPointer or
            a raw array; raw arrays are wrapped in ``spa.SemanticPointer``.
            Each axis is checked independently.

    Returns:
        ndarray of shape ``(len(xs), len(ys), dim)`` where ``vectors[i, j]``
        is ``encode_point(xs[i], ys[j], x_axis, y_axis).v``.
    """
    def _as_semantic_pointer(axis):
        # Duck-type on the class name, matching the checks used elsewhere in this file.
        if axis.__class__.__name__ == 'SemanticPointer':
            return axis
        return spa.SemanticPointer(data=axis)

    x_axis_sp = _as_semantic_pointer(x_axis_sp)
    y_axis_sp = _as_semantic_pointer(y_axis_sp)
    dim = len(x_axis_sp.v)
    vectors = np.zeros((len(xs), len(ys), dim))
    for i, x in enumerate(xs):
        for j, y in enumerate(ys):
            p = encode_point(x=x, y=y, x_axis=x_axis_sp, y_axis=y_axis_sp)
            vectors[i, j, :] = p.v
    return vectors
def power(s, e):
    """Raise semantic pointer ``s`` to the (possibly fractional) power ``e``.

    From https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master

    Exponentiation is applied element-wise to the Fourier coefficients of
    ``s.v``. The real part of the inverse FFT is wrapped in a new
    SemanticPointer.
    """
    spectrum = np.fft.fft(s.v)
    powered = np.fft.ifft(spectrum ** e)
    return spa.SemanticPointer(data=powered.real)
def encode_point(x, y, x_axis, y_axis):
    """Encode the 2-D point (x, y) as ``x_axis**x * y_axis**y``.

    From https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master
    ``*`` is SemanticPointer multiplication applied to the two fractional
    powers produced by ``power``.
    """
    x_part = power(x_axis, x)
    y_part = power(y_axis, y)
    return x_part * y_part
def encode_region(x, y, x_axis, y_axis, width=28):
    """Integrate point encodings along the x axis over ``[x, x + width]``.

    Adapted from https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master

    The previous implementation had two defects. It passed a SemanticPointer,
    which is not callable, to ``scipy.integrate.quad``, so it always failed.
    It also evaluated the expression a second time just to print it.
    ``quad`` only handles scalar integrands, so the vector-valued integrand
    ``t -> (x_axis**t * y_axis**y).v`` is integrated component-wise with
    ``scipy.integrate.quad_vec``. ``y`` is held fixed.

    NOTE(review): the intended region semantics are not documented in the
    original. This implements the integral its code expressed. The default
    ``width=28`` is the constant hard-coded there.

    Args:
        x, y: start of the integration interval on x, and the fixed y value.
        x_axis, y_axis: axis SemanticPointers passed to ``power``.
        width: length of the integration interval along x.

    Returns:
        ``(integral, abserr)`` as returned by ``quad_vec``, where ``integral``
        is an ndarray with the same length as the axis vectors.
    """
    # The y factor does not depend on the integration variable; compute it once.
    y_part = power(y_axis, y)

    def integrand(t):
        return (power(x_axis, t) * y_part).v

    return integrate.quad_vec(integrand, x, x + width)
def plot_heatmap(img, img_area, encoded_pos, xs, ys, heatmap_vectors, name='', vmin=-1, vmax=1, invert=False):
    """Show an image, an optional area mask, and the similarity heat map of an SSP.

    From https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master

    Args:
        img: image shown in the first panel.
        img_area: optional mask shown in grayscale in a middle panel;
            pass None to omit that panel.
        encoded_pos: SemanticPointer whose similarity to every grid location
            is plotted.
        xs, ys: grid coordinates used to build ``heatmap_vectors``.
        heatmap_vectors: array of shape (len(xs), len(ys), dim).
        name: figure title.
        vmin, vmax: color limits of the heat map.
        invert: accepted for API compatibility; not used.
    """
    assert encoded_pos.__class__.__name__ == 'SemanticPointer'
    # (dim,) . (xs, ys, dim) over the vector axis -> similarity of shape (xs, ys).
    similarity = np.tensordot(encoded_pos.v, heatmap_vectors, axes=([0], [2]))

    show_area = img_area is not None
    n_panels = 3 if show_area else 2
    fig, axs = plt.subplots(1, n_panels, figsize=(4 * n_panels + 3, 3))
    fig.suptitle(name)

    axs[0].imshow(img)
    axs[0].axis('off')

    if show_area:
        area_ax = axs[1]
        area_ax.imshow(img_area, cmap='gray')
        x_labels = np.arange(0, img.shape[1], img.shape[1] / len(xs)).astype(int)[::20]
        area_ax.set_xticks(np.arange(0, len(xs), 20), x_labels)
        y_labels = np.arange(0, img.shape[0], img.shape[0] / len(ys)).astype(int)[::10]
        area_ax.set_yticks(np.arange(0, len(ys), 10), y_labels)
        area_ax.axis('off')

    # The heat map always goes in the last panel.
    heat_ax = axs[-1]
    im = heat_ax.imshow(
        np.transpose(similarity),
        origin='upper',
        interpolation='none',
        extent=(xs[-1], xs[0], ys[-1], ys[0]),
        vmin=vmin,
        vmax=vmax,
        cmap='plasma',
    )
    heat_ax.axis('off')

    fig.colorbar(im, ax=axs.ravel().tolist())
    plt.show()
def generate_region_vector(desired, xs, ys, x_axis_sp, y_axis_sp):
    """Superpose the point SSPs of all grid cells selected by ``desired``.

    From https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master

    Args:
        desired: 2-D mask indexed ``desired[row, col]``, where the row indexes
            ``ys`` and the column indexes ``xs``. Cells equal to 1 are included.
        xs, ys: grid coordinates.
        x_axis_sp, y_axis_sp: axis SemanticPointers passed to ``encode_point``.

    Returns:
        The summed vector wrapped in a SemanticPointer, normalized via
        ``SemanticPointer.normalize``.
    """
    total = np.zeros_like(x_axis_sp.v)
    for col, x in enumerate(xs):
        for row, y in enumerate(ys):
            if desired[row, col] == 1:
                total += encode_point(x, y, x_axis_sp, y_axis_sp).v
    region = spa.SemanticPointer(data=total)
    region.normalize()
    return region
def bb_intersection_over_union(boxA, boxB):
    """Return the intersection-over-union of two boxes given as (x1, y1, x2, y2).

    From https://pyimagesearch.com/2016/11/07/intersection-over-union-iou-for-object-detection/
    Unlike that article, no +1 pixel correction is applied. Only the first
    four elements of each box are read.

    Returns:
        0 (int) when the boxes do not overlap, otherwise
        ``intersection / (areaA + areaB - intersection)`` as a float.
    """
    a_x1, a_y1, a_x2, a_y2 = boxA[0], boxA[1], boxA[2], boxA[3]
    b_x1, b_y1, b_x2, b_y2 = boxB[0], boxB[1], boxB[2], boxB[3]

    # Side lengths of the overlap rectangle, clamped at zero.
    overlap_w = max(min(a_x2, b_x2) - max(a_x1, b_x1), 0)
    overlap_h = max(min(a_y2, b_y2) - max(a_y1, b_y1), 0)
    intersection = abs(overlap_w * overlap_h)
    if intersection == 0:
        return 0

    area_a = abs((a_x2 - a_x1) * (a_y2 - a_y1))
    area_b = abs((b_x2 - b_x1) * (b_y2 - b_y1))
    union = float(area_a + area_b - intersection)
    return intersection / union
def encode_point_multidim(values, axes):
    """Bind ``power(axes[k], values[k])`` over all k (N-dimensional ``encode_point``).

    Args:
        values: coordinates, one per axis.
        axes: axis SemanticPointers, same length as ``values``.

    Returns:
        The product of the powered axes. When both inputs are empty, this is
        the integer 1.
    """
    assert len(values) == len(axes), f'number of values {len(values)} does not match number of axes {len(axes)}'
    bound = 1
    for axis, value in zip(axes, values):
        bound *= power(axis, value)
    return bound
def get_heatmap_vectors_multidim(xs, ys, ws, hs, x_axis, y_axis, w_axis, h_axis):
    """Precompute SSPs for every (x, y, w, h) grid point (4-D ``get_heatmap_vectors``).

    Returns:
        ndarray of shape ``(len(xs), len(ys), len(ws), len(hs), dim)`` where
        entry ``[i, j, n, k]`` is
        ``encode_point_multidim([xs[i], ys[j], ws[n], hs[k]], axes).v``.
    """
    assert x_axis.__class__.__name__ == 'SemanticPointer', f'Axes need to be of type SemanticPointer but are {x_axis.__class__.__name__}'
    dim = len(x_axis.v)
    grids = tuple(list(g) for g in (xs, ys, ws, hs))
    axis_sps = [x_axis, y_axis, w_axis, h_axis]
    vectors = np.zeros(tuple(len(g) for g in grids) + (dim,))
    # np.ndindex walks the 4-D grid in C order (last index fastest), the same
    # order as four nested loops.
    for idx in np.ndindex(*vectors.shape[:4]):
        coords = [grid[k] for grid, k in zip(grids, idx)]
        vectors[idx] = encode_point_multidim(values=coords, axes=axis_sps).v
    return vectors
def ssp_to_loc_multidim(sp, heatmap_vectors, linspace):
    """Decode an SSP to the nearest (x, y, w, h) grid location.

    Adapted from ``loc_match`` in https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master.
    ``heatmap_vectors`` serves as a lookup table. The grid point whose vector
    has the largest dot product with ``sp.v`` is returned.

    Args:
        sp: object of class ``SemanticPointer`` exposing the vector as ``.v``.
        heatmap_vectors: array of shape (len(xs), len(ys), len(ws), len(hs), dim).
        linspace: exactly four sequences ``(xs, ys, ws, hs)``.

    Returns:
        ndarray ``[x, y, w, h]`` holding the coordinates of the best match.
    """
    xs, ys, ws, hs = linspace
    assert sp.__class__.__name__ == 'SemanticPointer', \
        f'Queried object needs to be of type SemanticPointer but is {sp.__class__.__name__}'
    # Contract the vector axis of sp with the last axis of the lookup table.
    similarity = np.tensordot(sp.v, heatmap_vectors, axes=([0], [4]))
    best_flat = np.argmax(similarity, axis=None)
    ix, iy, iw, ih = np.unravel_index(best_flat, similarity.shape)
    return np.array([xs[ix], ys[iy], ws[iw], hs[ih]])
def encode_image_ssp(img, sg_data, axes, new_size, dim, visualize=True):
    """Encode all scene-graph objects of an image into a single SSP memory.

    Each object gets a fresh random ``spa.SemanticPointer(dim)``. That pointer
    is bound to the encoded (x, y, width, height) of the object's bounding
    box, and all bound pairs are summed into ``memory``.

    Args:
        img: image array; ``img.shape[:2]`` is read as (height, width).
        sg_data: mapping of object id -> dict with keys 'name', 'x', 'y', 'w', 'h'.
        axes: four axis SemanticPointers, for x, y, width and height.
        new_size: target size, indexed like ``img.shape`` (height first);
            x and y are divided by the resulting ``scale``.
        dim: dimensionality of the SSPs.
        visualize: if True, print the rescaling and draw the boxes over the image.

    Returns:
        (encoded_items, encoded_ssps, memory):
        - encoded_items: name -> int array [x, y, w, h] after scaling and rounding.
        - encoded_ssps: name -> identity SemanticPointer of that object.
        - memory: sum over objects of ``identity * position``.
        Names are singularized and suffixed with ``_k`` for the k-th occurrence.
    """
    img_size = img.shape[:2]
    img_h, img_w = img_size
    # Scale by height unless the image is more than twice as wide as it is tall.
    scale = img_h / new_size[0] if img_w / 2 < img_h else img_w / new_size[1]
    # Box width/height are mapped to a fixed 0..10 range, independent of `scale`.
    w_scale = img_w / 10
    h_scale = img_h / 10

    if visualize:
        print(f'Original image {img_size[1]}x{img_size[0]} --> {np.array(img_size) / scale}')
        _, ax = plt.subplots(1, 1)
        ax.imshow(img, interpolation='none', origin='upper',
                  extent=[0, img_size[1] / scale, img_size[0] / scale, 0])
        plt.axis('off')

    encoded_items = {}
    encoded_ssps = {}
    memory = spa.SemanticPointer(data=np.zeros(dim))
    occurrences = {}
    for _, obj_dict in sg_data.items():
        base_name = singularize(obj_dict.get('name'))
        occurrences[base_name] = occurrences.get(base_name, 0) + 1
        name = base_name + '_' + str(occurrences[base_name])

        # Scale the ground-truth box into SSP coordinates; width/height are floored at 1.
        box_x = obj_dict.get('x') / scale
        box_y = obj_dict.get('y') / scale
        box_w = obj_dict.get('w') / w_scale
        box_h = obj_dict.get('h') / h_scale
        box_w = box_w if box_w >= 1 else 1
        box_h = box_h if box_h >= 1 else 1
        # Round to the nearest integer (np.round); decoding is unreliable otherwise.
        item = np.round([box_x, box_y, box_w, box_h], decimals=0).astype(int)
        encoded_items[name] = item

        pos = encode_point_multidim(list(item), axes)
        identity = spa.SemanticPointer(dim)
        encoded_ssps[name] = identity
        memory += identity * pos

        if visualize:
            rx, ry, rw, rh = item
            ax.add_patch(patches.Rectangle(
                (rx, ry),
                (rw * w_scale) / scale,
                (rh * h_scale) / scale,
                linewidth=2,
                label=name,
                edgecolor='c',
                facecolor='none',
            ))

    if visualize:
        plt.show()
    return encoded_items, encoded_ssps, memory