298 lines
No EOL
10 KiB
Python
298 lines
No EOL
10 KiB
Python
import json
|
|
import numpy as np
|
|
import matplotlib
|
|
import matplotlib.pyplot as plt
|
|
import matplotlib.patches as patches
|
|
from pattern.text.en import singularize
|
|
import nengo.spa as spa
|
|
import scipy.integrate as integrate
|
|
|
|
# Drawing palette as (r, g, b) float tuples: the hand-picked hex colours
# first, then every named colour in ``matplotlib.colors.cnames`` except the
# first 11 entries.
hex_colors = ['#8a3ffc', '#ff7eb6', '#6fdc8c', '#d2a106', '#ba4e00', '#33b1ff', '#570408',
              '#fa4d56', '#4589ff', '#08bdba', '#d4bbff', '#007d79', '#d12771', '#bae6ff']

RGB_COLORS = [matplotlib.colors.to_rgb(hex_code) for hex_code in hex_colors]
RGB_COLORS.extend(
    matplotlib.colors.to_rgb(color_value)
    for position, color_value in enumerate(matplotlib.colors.cnames.values())
    if position > 10
)
|
|
|
|
|
|
def _load_json(path):
    """Parse the JSON file at *path*, closing the file even if parsing fails."""
    # JSON text is UTF-8 by specification; do not depend on the platform's
    # default locale encoding (e.g. cp1252 on Windows).
    with open(path, encoding='utf-8') as handle:
        return json.load(handle)


# GQA lookup tables, read at import time from paths relative to the current
# working directory.
RELATION_DICT = _load_json('gqa_all_relations_map.json')
CLASS_DICT = _load_json('gqa_all_vocab_classes.json')
ATTRIBUTE_DICT = _load_json('gqa_all_attributes.json')
|
|
|
|
|
|
def bbox_to_mask(x, y, w, h, img_size=(500, 500), name=None, visualize=False):
    """Rasterise an axis-aligned bounding box into a binary mask.

    Args:
        x, y: Column and row of the box's top-left corner (integers).
        w, h: Box width (columns) and height (rows) in pixels (integers).
        img_size: Shape of the returned mask, ``(rows, cols)``.
        name: Optional plot title, used only when ``visualize`` is True.
        visualize: If True, display the mask with matplotlib.

    Returns:
        A float ``np.ndarray`` of shape ``img_size`` that is 1.0 inside the
        box and 0.0 elsewhere. Any part of the box outside the image is
        clipped away. Previously, a box crossing the right edge raised a
        broadcast error, and a negative ``y`` wrapped to the bottom rows.
    """
    img = np.zeros(img_size)
    # Clamp both slice bounds at 0: a negative index would wrap around to the
    # opposite edge instead of meaning "before the image".
    row_start, row_stop = max(y, 0), max(y + h, 0)
    col_start, col_stop = max(x, 0), max(x + w, 0)
    # NumPy slicing clips the stop indices to the image size for us.
    img[row_start:row_stop, col_start:col_stop] = 1

    if visualize:
        # figsize is (width, height) in inches, whereas img_size is
        # (rows, cols); keep each side at least 1 inch for small masks.
        plt.figure(figsize=(max(img_size[1] // 80, 1), max(img_size[0] // 80, 1)))
        plt.imshow(img, cmap='gray')
        if name:
            plt.title(name)
        plt.axis('off')
        plt.show()

    return img
|
|
|
|
def make_good_unitary(D, eps=1e-3, rng=np.random):
    """Build a random unitary SemanticPointer of dimensionality ``D``.

    Adapted from https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master

    Every Fourier coefficient of the result has magnitude 1. The free phases
    have absolute value in ``[pi * eps, pi * (1 - eps)]``, and the resulting
    vector has unit Euclidean norm (both checked by assertions below).
    """
    n_free = (D - 1) // 2   # number of independently drawn phases
    half = (D + 1) // 2     # end of the positive-frequency slice

    uniform = rng.rand(n_free)
    signs = rng.choice((-1, +1), len(uniform))
    phases = signs * np.pi * (eps + uniform * (1 - 2 * eps))
    magnitudes = np.abs(phases)
    assert np.all(magnitudes >= np.pi * eps)
    assert np.all(magnitudes <= np.pi * (1 - eps))

    spectrum = np.zeros(D, dtype='complex64')
    spectrum[0] = 1
    spectrum[1:half] = np.cos(phases) + 1j * np.sin(phases)
    # Mirror the positive frequencies as complex conjugates so the inverse
    # FFT is (up to rounding) real.
    spectrum[-1:D // 2:-1] = np.conj(spectrum[1:half])
    if D % 2 == 0:
        spectrum[D // 2] = 1  # Nyquist coefficient, present only for even D
    assert np.allclose(np.abs(spectrum), 1)

    vector = np.fft.ifft(spectrum).real
    assert np.allclose(np.fft.fft(vector), spectrum)
    assert np.allclose(np.linalg.norm(vector), 1)
    return spa.SemanticPointer(vector)
|
|
|
|
def get_heatmap_vectors(xs, ys, x_axis_sp, y_axis_sp):
    """Precompute the spatial semantic pointer for every (x, y) grid location.

    Adapted from https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master.
    The table lets heat maps be computed with a single vectorised dot product
    (matrix multiplication) instead of re-encoding each location.

    Args:
        xs, ys: Coordinates along the x and y axes.
        x_axis_sp, y_axis_sp: Axis vectors, either ``SemanticPointer`` objects
            or raw arrays (raw arrays are wrapped in ``spa.SemanticPointer``).

    Returns:
        ``np.ndarray`` of shape ``(len(xs), len(ys), dim)``, where ``[i, j]``
        holds ``encode_point(xs[i], ys[j], ...).v``.
    """
    # Wrap each axis independently. Previously the y-axis was only wrapped
    # when the x-axis was raw, so a SemanticPointer x-axis combined with a
    # raw y-axis failed inside encode_point, and a raw x-axis with an SP
    # y-axis double-wrapped.
    if x_axis_sp.__class__.__name__ != 'SemanticPointer':
        x_axis_sp = spa.SemanticPointer(data=x_axis_sp)
    if y_axis_sp.__class__.__name__ != 'SemanticPointer':
        y_axis_sp = spa.SemanticPointer(data=y_axis_sp)
    dim = len(x_axis_sp.v)

    vectors = np.zeros((len(xs), len(ys), dim))
    for i, x in enumerate(xs):
        for j, y in enumerate(ys):
            p = encode_point(
                x=x, y=y, x_axis=x_axis_sp, y_axis=y_axis_sp,
            )
            vectors[i, j, :] = p.v

    return vectors
|
|
|
|
def power(s, e):
    """Raise SemanticPointer ``s`` to the (possibly fractional) power ``e``.

    Adapted from https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master

    The exponent is applied element-wise to the FFT coefficients of ``s.v``.
    The real part of the inverse FFT is returned as a new
    ``spa.SemanticPointer``.
    """
    coefficients = np.fft.fft(s.v)
    raised = np.fft.ifft(coefficients ** e)
    return spa.SemanticPointer(data=raised.real)
|
|
|
|
def encode_point(x, y, x_axis, y_axis):
    """Encode the 2-D point ``(x, y)`` as a spatial semantic pointer.

    Adapted from https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master

    Returns ``power(x_axis, x)`` combined with ``power(y_axis, y)`` via the
    SemanticPointer ``*`` operator.
    """
    x_part = power(x_axis, x)
    y_part = power(y_axis, y)
    return x_part * y_part
|
|
|
|
def encode_region(x, y, x_axis, y_axis, width=28):
    """Integrate the point encoding along x over ``[x, x + width]`` at fixed ``y``.

    Adapted from https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master

    Args:
        x: Start of the integration interval along the x-axis.
        y: Fixed y coordinate.
        x_axis, y_axis: Axis SemanticPointers.
        width: Length of the interval along x. The default of 28 matches the
            previously hard-coded value.

    Returns:
        The ``(integral, abserr)`` tuple returned by
        ``scipy.integrate.quad_vec``. ``integral`` has the same length as
        ``x_axis.v``.
    """
    # The previous version passed a SemanticPointer (not a callable) to
    # ``integrate.quad``, which only integrates scalar-valued callables, so it
    # always raised. It also printed and recomputed the same call.
    # ``quad_vec`` integrates the vector-valued integrand t -> (X**t * Y**y).v
    # component-wise.
    y_part = power(y_axis, y)  # independent of the integration variable

    def integrand(t):
        return (power(x_axis, t) * y_part).v

    return integrate.quad_vec(integrand, x, x + width)
|
|
|
|
|
|
def plot_heatmap(img, img_area, encoded_pos, xs, ys, heatmap_vectors, name='', vmin=-1, vmax=1, invert=False):
    """Show an image next to the similarity heat map of ``encoded_pos``.

    Adapted from https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master

    Args:
        img: Image drawn in the first panel.
        img_area: Optional region/mask image drawn in a middle panel. When it
            is None, only two panels are drawn.
        encoded_pos: SemanticPointer whose similarity to every grid cell is plotted.
        xs, ys: Grid coordinates that ``heatmap_vectors`` was built from.
        heatmap_vectors: Array of shape ``(len(xs), len(ys), dim)``.
        name: Figure title.
        vmin, vmax: Colour limits of the heat map.
        invert: Currently unused.
    """
    assert encoded_pos.__class__.__name__ == 'SemanticPointer'

    # sp has shape (dim,) and heatmap_vectors has shape (xs, ys, dim), so the
    # result has shape (xs, ys).
    vec_sim = np.tensordot(encoded_pos.v, heatmap_vectors, axes=([0], [2]))

    num_plots = 3 if img_area is not None else 2
    fig, axs = plt.subplots(1, num_plots, figsize=(4 * num_plots + 3, 3))
    fig.suptitle(name)

    axs[0].imshow(img)
    axs[0].axis('off')

    if img_area is not None:
        axs[1].imshow(img_area, cmap='gray')
        # Label grid indices with the image pixel they correspond to. Labels
        # are derived from the tick positions, so both always have the same
        # length (a float ``np.arange`` could produce one extra element and
        # make set_xticks raise). set_*ticklabels also works on
        # matplotlib < 3.5, where set_xticks has no ``labels`` argument.
        xticks = np.arange(0, len(xs), 20)
        yticks = np.arange(0, len(ys), 10)
        axs[1].set_xticks(xticks)
        axs[1].set_xticklabels((xticks * (img.shape[1] / len(xs))).astype(int))
        axs[1].set_yticks(yticks)
        axs[1].set_yticklabels((yticks * (img.shape[0] / len(ys))).astype(int))
        axs[1].axis('off')
        heat_ax = axs[2]
    else:
        heat_ax = axs[1]

    im = heat_ax.imshow(np.transpose(vec_sim), origin='upper', interpolation='none',
                        extent=(xs[-1], xs[0], ys[-1], ys[0]),
                        vmin=vmin, vmax=vmax, cmap='plasma')
    heat_ax.axis('off')

    fig.colorbar(im, ax=axs.ravel().tolist())
    plt.show()
|
|
|
|
|
|
def generate_region_vector(desired, xs, ys, x_axis_sp, y_axis_sp):
    """Sum the point encodings of all selected grid cells into one pointer.

    Adapted from https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master

    ``desired`` is indexed ``[row, col]``: rows follow ``ys`` and columns
    follow ``xs``. Each cell equal to 1 contributes
    ``encode_point(xs[col], ys[row], ...)``. The sum is wrapped in a
    ``spa.SemanticPointer``, and its ``normalize()`` is called before the
    pointer is returned.
    """
    accumulated = np.zeros_like(x_axis_sp.v)
    for col, x_val in enumerate(xs):
        for row, y_val in enumerate(ys):
            if desired[row, col] != 1:
                continue
            accumulated += encode_point(x_val, y_val, x_axis_sp, y_axis_sp).v

    region = spa.SemanticPointer(data=accumulated)
    region.normalize()
    return region
|
|
|
|
|
|
def bb_intersection_over_union(boxA, boxB):
    """Return the intersection-over-union of two axis-aligned boxes.

    Adapted from https://pyimagesearch.com/2016/11/07/intersection-over-union-iou-for-object-detection/

    Boxes are indexed as ``(x1, y1, x2, y2)``. The code assumes x1 <= x2 and
    y1 <= y2 (TODO confirm with callers). Returns the integer ``0`` when the
    overlap area is zero; otherwise returns intersection / union as a float.
    """
    # Corners of the overlapping rectangle (may be "inverted" when disjoint).
    left, top = max(boxA[0], boxB[0]), max(boxA[1], boxB[1])
    right, bottom = min(boxA[2], boxB[2]), min(boxA[3], boxB[3])

    inter_area = abs(max(right - left, 0) * max(bottom - top, 0))
    if inter_area == 0:
        return 0

    def _area(box):
        return abs((box[2] - box[0]) * (box[3] - box[1]))

    union = _area(boxA) + _area(boxB) - inter_area
    return inter_area / float(union)
|
|
|
|
|
|
def encode_point_multidim(values, axes):
    """Bind ``power(axis, value)`` over every (value, axis) pair.

    Generalises ``power(x_axis, x) * power(y_axis, y)`` to any number of
    dimensions: ``values[k]`` is the coordinate along ``axes[k]``. With no
    dimensions, the integer ``1`` is returned.

    Raises:
        AssertionError: if ``values`` and ``axes`` have different lengths.
    """
    assert len(values) == len(axes), f'number of values {len(values)} does not match number of axes {len(axes)}'
    encoded = 1
    for coordinate, axis in zip(values, axes):
        encoded *= power(axis, coordinate)
    return encoded
|
|
|
|
def get_heatmap_vectors_multidim(xs, ys, ws, hs, x_axis, y_axis, w_axis, h_axis):
    """Precompute the encoding of every (x, y, w, h) grid combination.

    Four-dimensional adaptation of ``get_heatmap_vectors``.

    Returns:
        ``np.ndarray`` of shape ``(len(xs), len(ys), len(ws), len(hs), dim)``
        with ``dim = len(x_axis.v)``. Entry ``[i, j, n, k]`` holds
        ``encode_point_multidim([xs[i], ys[j], ws[n], hs[k]], axes).v``.
    """
    assert x_axis.__class__.__name__ == 'SemanticPointer', f'Axes need to be of type SemanticPointer but are {x_axis.__class__.__name__}'

    axis_list = [x_axis, y_axis, w_axis, h_axis]
    table = np.zeros((len(xs), len(ys), len(ws), len(hs), len(x_axis.v)))

    for xi, x_val in enumerate(xs):
        for yi, y_val in enumerate(ys):
            for wi, w_val in enumerate(ws):
                for hi, h_val in enumerate(hs):
                    encoded = encode_point_multidim(values=[x_val, y_val, w_val, h_val], axes=axis_list)
                    table[xi, yi, wi, hi, :] = encoded.v

    return table
|
|
|
|
|
|
def ssp_to_loc_multidim(sp, heatmap_vectors, linspace):
    """Decode an SSP to the grid location ``(x, y, w, h)`` it matches best.

    Adaptation of ``loc_match`` from
    https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master. It uses
    ``heatmap_vectors`` as a lookup table.

    Args:
        sp: SemanticPointer to decode (only ``sp.v`` is read).
        heatmap_vectors: Array indexed ``[x, y, w, h, dim]``.
        linspace: ``(xs, ys, ws, hs)`` coordinates for the first four axes.

    Returns:
        ``np.array([x, y, w, h])`` for the cell with the largest dot product
        with ``sp.v``. Ties go to the first such cell in C order.
    """
    xs, ys, ws, hs = linspace

    assert sp.__class__.__name__ == 'SemanticPointer', \
        f'Queried object needs to be of type SemanticPointer but is {sp.__class__.__name__}'

    # Dot product of sp with every cell: contract sp's only axis with the
    # last (vector) axis of the table.
    similarity = np.tensordot(sp.v, heatmap_vectors, axes=([0], [4]))
    best = np.unravel_index(similarity.argmax(), similarity.shape)

    return np.array([xs[best[0]], ys[best[1]], ws[best[2]], hs[best[3]]])
|
|
|
|
|
|
def encode_image_ssp(img, sg_data, axes, new_size, dim, visualize=True):
    """Encode all objects of a scene graph into a single SSP memory.

    For each object, the box ``(x, y, w, h)`` is rescaled and rounded to
    integers, then encoded with ``encode_point_multidim`` along ``axes``. It
    is bound (``*``) to a new ``spa.SemanticPointer(dim)`` for the object
    (presumably a random vector — confirm with the nengo.spa version in use).
    All bound pairs are summed into ``memory``.

    Args:
        img: Image array. ``img.shape[:2]`` is used for scaling, and the image
            itself is drawn when ``visualize`` is True.
        sg_data: Mapping of object id -> dict with keys 'name', 'x', 'y', 'w', 'h'.
        axes: Axis SemanticPointers passed to ``encode_point_multidim``.
        new_size: Target size, indexed in the same order as ``img.shape[:2]``.
        dim: Dimensionality of the SSPs.
        visualize: If True, print the rescaled size and draw each box over the image.

    Returns:
        ``(encoded_items, encoded_ssps, memory)``, where the two dicts are
        keyed by label. A label is the singularized object name plus
        ``'_<k>'``, with ``k`` counting occurrences of that name from 1.
        ``encoded_items`` maps to the rounded ``[x, y, w, h]``,
        ``encoded_ssps`` maps to the object's pointer, and ``memory`` is the
        summed SemanticPointer.
    """
    img_size = img.shape[:2]

    # Position scale: use the ratio along axis 0, unless axis 1 is at least
    # twice as long as axis 0, in which case use the ratio along axis 1.
    scale = (img_size[0] / new_size[0] if img_size[1] / 2 < img_size[0]
             else img_size[1] / new_size[1])

    # Widths and heights are expressed in units of one tenth of the image size.
    w_scale = img_size[1] / 10
    h_scale = img_size[0] / 10

    if visualize:
        print(f'Original image {img_size[1]}x{img_size[0]} --> {np.array(img_size) / scale}')
        _, ax = plt.subplots(1, 1)
        ax.imshow(img, interpolation='none', origin='upper', extent=[0, img_size[1] / scale, img_size[0] / scale, 0])
        plt.axis('off')

    encoded_items = {}
    encoded_ssps = {}
    memory = spa.SemanticPointer(data=np.zeros(dim))
    occurrences = {}  # singularized name -> how many times it has been seen

    for _obj_id, obj_dict in sg_data.items():
        base = singularize(obj_dict.get('name'))
        occurrences[base] = occurrences.get(base, 0) + 1
        label = base + '_' + str(occurrences[base])

        # Ground-truth box, rescaled to SSP coordinates.
        raw_x, raw_y = obj_dict.get('x'), obj_dict.get('y')
        raw_w, raw_h = obj_dict.get('w'), obj_dict.get('h')
        x, y = raw_x / scale, raw_y / scale
        width, height = raw_w / w_scale, raw_h / h_scale

        # Enforce a minimum size of one unit.
        width = width if width >= 1 else 1
        height = height if height >= 1 else 1

        # Round to integers (decoding is unreliable otherwise).
        item = np.round([x, y, width, height], decimals=0).astype(int)
        encoded_items[label] = item

        pos = encode_point_multidim(list(item), axes)
        ssp = spa.SemanticPointer(dim)
        encoded_ssps[label] = ssp
        memory += ssp * pos

        if visualize:
            box_x, box_y, box_w, box_h = item
            # Convert width/height units back into displayed-image coordinates.
            box_w = (box_w * w_scale) / scale
            box_h = (box_h * h_scale) / scale
            ax.add_patch(patches.Rectangle((box_x, box_y), box_w, box_h,
                                           linewidth=2, label=label,
                                           edgecolor='c', facecolor='none'))

    if visualize:
        plt.show()

    return encoded_items, encoded_ssps, memory