VSA4VQA/utils.py

298 lines
10 KiB
Python
Raw Normal View History

2024-04-29 17:18:10 +02:00
import json
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from pattern.text.en import singularize
import nengo.spa as spa
import scipy.integrate as integrate
# Colour palette used for plotting: the hand-picked hex colours come first,
# followed by matplotlib's named colours with the first eleven entries skipped.
hex_colors = ['#8a3ffc', '#ff7eb6', '#6fdc8c', '#d2a106', '#ba4e00', '#33b1ff', '#570408',
              '#fa4d56', '#4589ff', '#08bdba', '#d4bbff', '#007d79', '#d12771', '#bae6ff']
RGB_COLORS = [matplotlib.colors.to_rgb(hex_code) for hex_code in hex_colors]
RGB_COLORS.extend(
    matplotlib.colors.to_rgb(hex_code)
    for position, hex_code in enumerate(matplotlib.colors.cnames.values())
    if position > 10
)
# Load the GQA lookup tables once at import time.
# The paths carry no directory component, so they are resolved against the
# current working directory of the importing process.
# `with` guarantees each handle is closed even when json.load raises.
with open('gqa_all_relations_map.json') as f:
    RELATION_DICT = json.load(f)
with open('gqa_all_vocab_classes.json') as f:
    CLASS_DICT = json.load(f)
with open('gqa_all_attributes.json') as f:
    ATTRIBUTE_DICT = json.load(f)
def bbox_to_mask(x, y, w, h, img_size=(500, 500), name=None, visualize=False):
    """Return a binary mask that is 1 inside a bounding box and 0 elsewhere.

    Args:
        x: column index of the box's left edge.
        y: row index of the box's top edge.
        w: box width in columns.
        h: box height in rows.
        img_size: shape of the mask as (rows, columns); passed to ``np.zeros``.
        name: optional title for the figure shown when ``visualize`` is set.
        visualize: if True, display the mask with matplotlib.

    Returns:
        Float array of shape ``img_size`` with ones inside the box.

    The part of the box that lies outside the image is clipped, so boxes that
    touch or cross the border (or start at negative coordinates) no longer
    raise or wrap around to the opposite edge.
    """
    img = np.zeros(img_size)
    n_rows, n_cols = img.shape[:2]

    # Clamp the box to the image so the slice never wraps or overflows.
    top, bottom = max(y, 0), min(y + h, n_rows)
    left, right = max(x, 0), min(x + w, n_cols)
    if top < bottom and left < right:
        img[top:bottom, left:right] = 1

    if visualize:
        fig = plt.figure(figsize=(img_size[0] // 80, img_size[1] // 80))
        plt.imshow(img, cmap='gray')
        if name:
            plt.title(name)
        plt.axis('off')
        plt.show()
    return img
def make_good_unitary(D, eps=1e-3, rng=np.random):
    """Build a random D-dimensional unitary vector as a SemanticPointer.

    from https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master

    Every Fourier coefficient has magnitude 1. The free phases are drawn with
    a random sign and an absolute value between ``pi * eps`` and
    ``pi * (1 - eps)``; the remaining coefficients are filled with complex
    conjugates (and 1 at index 0 and, for even D, at the middle index).

    Args:
        D: dimensionality of the vector.
        eps: margin keeping the phases away from 0 and +/- pi.
        rng: random source providing ``rand`` and ``choice``.

    Returns:
        ``spa.SemanticPointer`` wrapping the real part of the inverse FFT.
    """
    upper = (D + 1) // 2
    uniform = rng.rand((D - 1) // 2)
    signs = rng.choice((-1, +1), len(uniform))
    phases = signs * np.pi * (eps + uniform * (1 - 2 * eps))
    assert np.all(np.abs(phases) >= np.pi * eps)
    assert np.all(np.abs(phases) <= np.pi * (1 - eps))

    coeffs = np.zeros(D, dtype='complex64')
    coeffs[0] = 1
    coeffs[1:upper] = np.cos(phases) + 1j * np.sin(phases)
    # Mirror the conjugates into the upper half of the spectrum.
    coeffs[-1:D // 2:-1] = np.conj(coeffs[1:upper])
    if D % 2 == 0:
        coeffs[D // 2] = 1
    assert np.allclose(np.abs(coeffs), 1)

    # The imaginary part is dropped without checking that it is close to 0
    # (that check is disabled in the original code).
    vec = np.fft.ifft(coeffs).real
    assert np.allclose(np.fft.fft(vec), coeffs)
    assert np.allclose(np.linalg.norm(vec), 1)
    return spa.SemanticPointer(vec)
def get_heatmap_vectors(xs, ys, x_axis_sp, y_axis_sp):
    """Precompute the encoded point for every (x, y) pair of the two grids.

    from https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master

    The resulting table lets a heat map be computed with one vectorized dot
    product (matrix multiplication).

    Args:
        xs: sequence of x coordinates.
        ys: sequence of y coordinates.
        x_axis_sp: x axis vector, either a SemanticPointer or raw data that
            ``spa.SemanticPointer(data=...)`` accepts.
        y_axis_sp: y axis vector, given in the same form as ``x_axis_sp``.

    Returns:
        Array of shape ``(len(xs), len(ys), dim)``.
    """
    if x_axis_sp.__class__.__name__ != 'SemanticPointer':
        # Raw vectors: wrap both axes (only the x axis type is inspected).
        dim = len(x_axis_sp)
        x_axis_sp = spa.SemanticPointer(data=x_axis_sp)
        y_axis_sp = spa.SemanticPointer(data=y_axis_sp)
    else:
        dim = len(x_axis_sp.v)

    vectors = np.zeros((len(xs), len(ys), dim))
    for ix, x_val in enumerate(xs):
        for iy, y_val in enumerate(ys):
            point = encode_point(x=x_val, y=y_val, x_axis=x_axis_sp, y_axis=y_axis_sp)
            vectors[ix, iy, :] = point.v
    return vectors
def power(s, e):
    """Raise the pointer ``s`` to the exponent ``e`` in the Fourier domain.

    from https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master

    The FFT of ``s.v`` is raised element-wise to ``e``; the real part of the
    inverse FFT is returned as a new ``spa.SemanticPointer``.
    """
    spectrum = np.fft.fft(s.v)
    raised = np.fft.ifft(spectrum ** e)
    return spa.SemanticPointer(data=raised.real)
def encode_point(x, y, x_axis, y_axis):
    """Encode the location (x, y) as ``power(x_axis, x) * power(y_axis, y)``.

    from https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master
    """
    x_term = power(x_axis, x)
    y_term = power(y_axis, y)
    return x_term * y_term
def encode_region(x, y, x_axis, y_axis, width=28):
    """Integrate the point encoding along the x axis over ``[x, x + width]``.

    from https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master

    NOTE(review): the previous body handed an already-evaluated pointer to
    ``integrate.quad``, which needs a callable, so it could not run. The
    integration limits ``x .. x + 28`` suggest the intent was to integrate
    ``power(x_axis, t) * power(y_axis, y)`` over ``t`` -- confirm.

    Args:
        x: lower integration limit along the x axis.
        y: fixed y coordinate.
        x_axis: axis pointer for x.
        y_axis: axis pointer for y.
        width: length of the integration interval (previously hard-coded 28).

    Returns:
        The ``(result, error_estimate)`` tuple from
        ``scipy.integrate.quad_vec``, where ``result`` is a vector.
    """
    # The y factor does not depend on the integration variable.
    y_term = power(y_axis, y)

    def integrand(t):
        # quad_vec needs an array-valued callable, hence the raw vector.
        return (power(x_axis, t) * y_term).v

    return integrate.quad_vec(integrand, x, x + width)
def plot_heatmap(img, img_area, encoded_pos, xs, ys, heatmap_vectors, name='', vmin=-1, vmax=1, invert=False):
    """Show an image, an optional area mask and a similarity heat map.

    from https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master

    Args:
        img: image drawn in the first panel.
        img_area: optional mask drawn in grayscale in a middle panel; when it
            is None only two panels are created.
        encoded_pos: SemanticPointer compared against ``heatmap_vectors``.
        xs: x coordinates of the heat map grid.
        ys: y coordinates of the heat map grid.
        heatmap_vectors: array of shape (xs, ys, dim).
        name: figure title.
        vmin: lower colour limit of the heat map.
        vmax: upper colour limit of the heat map.
        invert: accepted but not used by this function.
    """
    assert encoded_pos.__class__.__name__ == 'SemanticPointer'
    # (dim,) contracted with (xs, ys, dim) gives one similarity per grid cell.
    similarity = np.tensordot(encoded_pos.v, heatmap_vectors, axes=([0], [2]))

    show_area = img_area is not None
    n_panels = 3 if show_area else 2
    fig, axs = plt.subplots(1, n_panels, figsize=(4 * n_panels + 3, 3))
    fig.suptitle(name)

    image_ax = axs[0]
    image_ax.imshow(img)
    image_ax.axis('off')

    if show_area:
        area_ax = axs[1]
        area_ax.imshow(img_area, cmap='gray')
        area_ax.set_xticks(np.arange(0, len(xs), 20),
                           np.arange(0, img.shape[1], img.shape[1] / len(xs)).astype(int)[::20])
        area_ax.set_yticks(np.arange(0, len(ys), 10),
                           np.arange(0, img.shape[0], img.shape[0] / len(ys)).astype(int)[::10])
        area_ax.axis('off')

    # The heat map always occupies the panel after the last one used so far.
    heat_ax = axs[2] if show_area else axs[1]
    im = heat_ax.imshow(np.transpose(similarity), origin='upper', interpolation='none',
                        extent=(xs[-1], xs[0], ys[-1], ys[0]), vmin=vmin, vmax=vmax, cmap='plasma')
    heat_ax.axis('off')

    fig.colorbar(im, ax=axs.ravel().tolist())
    plt.show()
def generate_region_vector(desired, xs, ys, x_axis_sp, y_axis_sp):
    """Sum the encodings of all grid points marked in ``desired``.

    from https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master

    Args:
        desired: 2-D mask indexed as ``desired[y_index, x_index]``; cells
            equal to 1 are included.
        xs: x coordinates of the grid.
        ys: y coordinates of the grid.
        x_axis_sp: SemanticPointer for the x axis.
        y_axis_sp: SemanticPointer for the y axis.

    Returns:
        ``spa.SemanticPointer`` holding the summed encodings, after its
        ``normalize()`` method has been called.
    """
    accumulated = np.zeros_like(x_axis_sp.v)
    for col, x_val in enumerate(xs):
        for row, y_val in enumerate(ys):
            if desired[row, col] != 1:
                continue
            accumulated += encode_point(x_val, y_val, x_axis_sp, y_axis_sp).v
    region = spa.SemanticPointer(data=accumulated)
    region.normalize()
    return region
def bb_intersection_over_union(boxA, boxB):
    """Compute the intersection over union (IoU) of two boxes.

    from https://pyimagesearch.com/2016/11/07/intersection-over-union-iou-for-object-detection/

    Both boxes are indexed as ``[0], [1]`` for one corner and ``[2], [3]`` for
    the opposite corner (x then y).

    Returns:
        0 when the boxes do not overlap, otherwise the intersection area
        divided by the area of the union.
    """
    # Corners of the intersection rectangle.
    left = max(boxA[0], boxB[0])
    top = max(boxA[1], boxB[1])
    right = min(boxA[2], boxB[2])
    bottom = min(boxA[3], boxB[3])

    # Negative extents mean there is no overlap along that axis.
    overlap_w = max(right - left, 0)
    overlap_h = max(bottom - top, 0)
    interArea = abs(overlap_w * overlap_h)
    if interArea == 0:
        return 0

    # Areas of the two input rectangles.
    area_a = abs((boxA[2] - boxA[0]) * (boxA[3] - boxA[1]))
    area_b = abs((boxB[2] - boxB[0]) * (boxB[3] - boxB[1]))

    # Union = sum of both areas minus the part counted twice.
    union = float(area_a + area_b - interArea)
    return interArea / union
def encode_point_multidim(values, axes):
    """Encode a point with any number of dimensions.

    Generalizes ``power(x_axis, x) * power(y_axis, y)``: each axis is raised
    to its value and all factors are multiplied together in order.

    Args:
        values: one coordinate per axis.
        axes: axis pointers, same length as ``values``.

    Returns:
        The product of all ``power(axis, value)`` factors; the plain integer
        1 when both sequences are empty.
    """
    assert len(values) == len(axes), f'number of values {len(values)} does not match number of axes {len(axes)}'
    product = 1
    for exponent, axis in zip(values, axes):
        product *= power(axis, exponent)
    return product
def get_heatmap_vectors_multidim(xs, ys, ws, hs, x_axis, y_axis, w_axis, h_axis):
    """Precompute the encoded point for every (x, y, w, h) combination.

    Adaptation of get_heatmap_vectors for 4 dimensions.

    Each entry is ``power(x_axis, x) * power(y_axis, y) * power(w_axis, w) *
    power(h_axis, h)``, multiplied in that order (the same order that
    encode_point_multidim uses). Per-axis powers and partial products are
    computed once per loop level instead of once per grid cell, which removes
    most of the FFT work from the innermost loop.

    Args:
        xs: grid values along the x axis.
        ys: grid values along the y axis.
        ws: grid values along the w axis.
        hs: grid values along the h axis.
        x_axis: SemanticPointer for the x axis.
        y_axis: SemanticPointer for the y axis.
        w_axis: SemanticPointer for the w axis.
        h_axis: SemanticPointer for the h axis.

    Returns:
        Array of shape ``(len(xs), len(ys), len(ws), len(hs), dim)``.

    Raises:
        AssertionError: if ``x_axis`` is not a SemanticPointer.
    """
    assert x_axis.__class__.__name__ == 'SemanticPointer', f'Axes need to be of type SemanticPointer but are {x_axis.__class__.__name__}'
    dim = len(x_axis.v)
    vectors = np.zeros((len(xs), len(ys), len(ws), len(hs), dim))

    # These powers depend on a single coordinate only, so compute them once.
    y_terms = [power(y_axis, y) for y in ys]
    w_terms = [power(w_axis, w) for w in ws]
    h_terms = [power(h_axis, h) for h in hs]

    for i, x in enumerate(xs):
        x_term = power(x_axis, x)
        for j, y_term in enumerate(y_terms):
            xy_term = x_term * y_term
            for n, w_term in enumerate(w_terms):
                xyw_term = xy_term * w_term
                for k, h_term in enumerate(h_terms):
                    vectors[i, j, n, k, :] = (xyw_term * h_term).v
    return vectors
def ssp_to_loc_multidim(sp, heatmap_vectors, linspace):
    """Convert an SSP to the approximate 4-dim location that it represents.

    Adaptation of loc_match from
    https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master.
    The heatmap vectors act as a lookup table: the grid cell whose vector has
    the largest dot product with ``sp`` is selected.

    Args:
        sp: queried SemanticPointer.
        heatmap_vectors: 5-D array whose last axis (index 4) is the vector
            dimension.
        linspace: the four grids ``(xs, ys, ws, hs)``.

    Returns:
        ``np.array([x, y, w, h])`` with the grid values of the best cell.

    Raises:
        AssertionError: if ``sp`` is not a SemanticPointer.
    """
    xs, ys, ws, hs = linspace
    assert sp.__class__.__name__ == 'SemanticPointer', \
        f'Queried object needs to be of type SemanticPointer but is {sp.__class__.__name__}'

    # Contract the vector axis: one similarity value per 4-D grid cell.
    similarity = np.tensordot(sp.v, heatmap_vectors, axes=([0], [4]))
    best = np.unravel_index(np.argmax(similarity, axis=None), similarity.shape)
    return np.array([grid[idx] for grid, idx in zip((xs, ys, ws, hs), best)])
def encode_image_ssp(img, sg_data, axes, new_size, dim, visualize=True):
    """Encode all objects in an image to an SSP memory.

    Args:
        img: image array; only ``img.shape[:2]`` (rows, columns) is used for
            the encoding.
        sg_data: dict whose values are dicts with the keys 'name', 'x', 'y',
            'w' and 'h'.
        axes: axis pointers passed to encode_point_multidim together with the
            four values (x, y, width, height).
        new_size: target size, indexed with [0] and [1], that determines the
            position scale.
        dim: dimensionality of the pointers and of the memory.
        visualize: if True, print the scaling and draw the encoded boxes on
            top of the image.

    Returns:
        Tuple ``(encoded_items, encoded_ssps, memory)``: the rounded
        ``[x, y, width, height]`` per object name, the pointer created for
        each object name, and the sum of all ``pointer * position`` terms.
    """
    img_size = img.shape[:2]
    # The position scale follows the image height unless the image is more
    # than twice as wide as it is high.
    use_height = img_size[1] / 2 < img_size[0]
    scale = img_size[0] / new_size[0] if use_height else img_size[1] / new_size[1]

    # scale width and height to fixed size of 10
    w_scale = img_size[1] / 10
    h_scale = img_size[0] / 10

    if visualize:
        print(f'Original image {img_size[1]}x{img_size[0]} --> {np.array(img_size) / scale}')
        fig, ax = plt.subplots(1, 1)
        ax.imshow(img, interpolation='none', origin='upper',
                  extent=[0, img_size[1] / scale, img_size[0] / scale, 0])
        plt.axis('off')

    encoded_items = {}
    encoded_ssps = {}
    memory = spa.SemanticPointer(data=np.zeros(dim))
    seen_names = {}

    for obj_dict in sg_data.values():
        # Unique name: singular form plus a running number per base name.
        base_name = singularize(obj_dict.get('name'))
        seen_names[base_name] = seen_names.get(base_name, 0) + 1
        name = base_name + '_' + str(seen_names[base_name])

        # extract ground truth data and scale to fit to SSPs
        x = obj_dict.get('x') / scale
        y = obj_dict.get('y') / scale
        width = obj_dict.get('w') / w_scale
        height = obj_dict.get('h') / h_scale
        if not width >= 1:
            width = 1
        if not height >= 1:
            height = 1

        # Round values to next int (otherwise decoding gets buggy)
        item = np.round([x, y, width, height], decimals=0).astype(int)
        encoded_items[name] = item

        pos = encode_point_multidim(list(item), axes)
        ssp = spa.SemanticPointer(dim)
        encoded_ssps[name] = ssp
        memory += ssp * pos

        if visualize:
            # Map width and height back to the scale used for the drawn image.
            box_x, box_y, box_w, box_h = item
            box_w = (box_w * w_scale) / scale
            box_h = (box_h * h_scale) / scale
            ax.add_patch(patches.Rectangle((box_x, box_y),
                                           box_w, box_h,
                                           linewidth=2,
                                           label=name,
                                           edgecolor='c',
                                           facecolor='none'))

    if visualize:
        plt.show()
    return encoded_items, encoded_ssps, memory