287 lines
10 KiB
Python
287 lines
10 KiB
Python
|
import os
|
||
|
import cv2
|
||
|
import torch
|
||
|
from torch.utils.data import Dataset
|
||
|
import numpy as np
|
||
|
import matplotlib.pyplot as plt
|
||
|
import matplotlib.colors as mcolors
|
||
|
import matplotlib.patches as patches
|
||
|
|
||
|
import nengo.spa as spa
|
||
|
from utils import encode_point_multidim, ssp_to_loc_multidim, bb_intersection_over_union
|
||
|
|
||
|
# RGB tuples for every named colour in matplotlib's ``cnames`` table, in the
# table's iteration order. Used as a palette for bounding-box edge colours.
# Built with a comprehension so no loop variables leak into the module
# namespace (the previous loop variable shadowed the builtin ``hex``).
RGB_COLORS = [mcolors.to_rgb(hex_code) for hex_code in mcolors.cnames.values()]
|
||
|
|
||
|
class MNISTQueryDataset(Dataset):
    """MNIST spatial query dataset.

    Each sample is a blank ``img_size x img_size`` canvas onto which
    ``num_imgs`` digits from ``mnist_data`` are pasted at random,
    non-overlapping positions.

    Args:
        mnist_data: indexable collection of ``(image_tensor, label)`` pairs.
            The digit side length is read from the first image after
            ``squeeze()``; digits are presumably square -- only ``shape[0]``
            is inspected.
        num_imgs: number of digits placed on every canvas.
        img_size: side length of the square canvas in pixels.
        visualize: if True, show canvas and free-space mask after each
            placed digit (blocks on ``plt.show()``).
        transform: stored on the instance; currently NOT applied to samples.
        seed: seeds NumPy's and torch's *global* RNGs, which determine the
            digit shuffle and the placement positions.
    """

    def __init__(self, mnist_data, num_imgs, img_size=120, visualize=False, transform=None, seed=42):
        # Set random seed for location and mnist image selection.
        # NOTE: this reseeds the process-wide NumPy and torch generators.
        np.random.seed(seed)
        torch.manual_seed(seed)

        self.mnist_data = mnist_data
        self.mnist_size = mnist_data[0][0].squeeze().numpy().shape[0]
        # Shuffle MNIST data set according to random seed
        self.mnist_indices = torch.randperm(len(self.mnist_data))

        self.num_imgs = num_imgs
        self.img_size = img_size
        # Top-left corners are drawn from a (border x border) grid so that a
        # pasted digit always fits inside the canvas.
        self.border = self.img_size - self.mnist_size
        self.visualize = visualize

        self.transform = transform

    def __len__(self):
        """Return the number of canvases that can be built from disjoint digit groups."""
        return len(self.mnist_data) // self.num_imgs

    def __getitem__(self, idx):
        """Build the ``idx``-th canvas.

        Returns:
            dict with keys ``'image'`` (2-D float array), ``'labels'`` (list
            of ``{label: (x_center, y_center)}`` dicts, one per digit) and
            ``'mnist_images'`` (list of the pasted digit arrays).

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
            ValueError: if no free position is left for a digit.
        """
        n_samples = len(self)
        if idx < 0:
            idx += n_samples
        if not 0 <= idx < n_samples:
            raise IndexError(f'sample index out of range for dataset of length {n_samples}')

        # Each sample owns its own, non-overlapping window of the shuffled
        # digit indices (a stride of 1 would share digits between samples).
        start = idx * self.num_imgs
        current_indices = self.mnist_indices[start: start + self.num_imgs].tolist()

        image = np.zeros((self.img_size, self.img_size))
        # 1 = top-left corner still available, 0 = would collide with a digit
        mask = np.ones((self.border, self.border))
        labels = []
        mnist_imgs = []
        size = self.mnist_size

        for data_idx in current_indices:
            mnist, label = self.mnist_data[data_idx]
            digit = mnist.squeeze().numpy()
            mnist_imgs.append(digit)

            # find available space
            free_rows, free_cols = np.where(mask == 1)[:2]
            if len(free_rows) == 0:
                raise ValueError(
                    f'no free space left to place {self.num_imgs} digits of size '
                    f'{size} on a {self.img_size}x{self.img_size} canvas')

            # pick random pixel as x0, y0
            pick = np.random.randint(len(free_rows))
            y_pos, x_pos = free_rows[pick], free_cols[pick]

            # add mnist to image
            image[y_pos: y_pos + size, x_pos: x_pos + size] = digit

            # position label = center of mnist image
            labels.append({label: (x_pos + size // 2, y_pos + size // 2)})

            # remove available space: block every corner within one digit
            # size of the chosen corner, clipped to the mask bounds
            x_lo, x_hi = max(0, x_pos - size), min(x_pos + size + 1, self.border)
            y_lo, y_hi = max(0, y_pos - size), min(y_pos + size + 1, self.border)
            mask[y_lo:y_hi, x_lo:x_hi] = 0

            # visualize image state and current mask
            if self.visualize:
                f, (ax1, ax2) = plt.subplots(1, 2, sharey=False)
                ax1.imshow(image, cmap='gray')
                ax2.imshow(mask, cmap='gray')
                plt.show()

        sample = {'image': image, 'labels': labels, 'mnist_images': mnist_imgs}

        # NOTE: self.transform is intentionally not applied here; the call
        # was disabled in the original implementation.

        return sample
|
||
|
|
||
|
|
||
|
class GQADataset():
    """GQA questions with images, programs and an SSP encoding of scene objects.

    Args:
        questions: pandas DataFrame read via ``iloc``/``loc``; the code uses
            the columns ``questionID``, ``imageId``, ``question``, ``answer``
            and ``fullAnswer``.
        programs: pandas DataFrame with columns ``questionID`` and ``program``.
        scenegraphs: mapping ``str(image id) -> {'objects': {...}}`` where each
            object dict provides ``name``, ``x``, ``y``, ``w`` and ``h``.
        vectors: passed to ``ssp_to_loc_multidim`` when decoding.
        axes: passed to ``encode_point_multidim`` when encoding
            ([x_axis, y_axis, w_axis, h_axis]).
        linspace: passed to ``ssp_to_loc_multidim`` when decoding
            ([xs, ys, ws, hs]).
        path: directory containing ``<imageId>.jpg`` files.
        seed: seeds NumPy's global RNG and the per-call RNG in ``encode_item``.
        verbose: if > 0, ``decode_item`` prints summary metrics.
        visualize: if True, ``encode_item``/``decode_item`` draw the boxes.
    """

    def __init__(self, questions, programs, scenegraphs, vectors, axes, linspace,
                 path='GQA/images/images/', seed=17, verbose=0, visualize=False):
        # NOTE: reseeds the process-wide NumPy generator.
        np.random.seed(seed)

        self.questions = questions
        self.programs = programs
        self.scenegraphs = scenegraphs
        self.vis = visualize
        self.verbose = verbose
        self.seed = seed
        self.path = path

        # vector space
        self.ssp_vectors = vectors
        self.ssp_axes = axes  # [x_axis, y_axis, w_axis, h_axis]
        self.linspace = linspace  # [xs, ys, ws, hs]

    def __len__(self):
        """Return the number of question rows."""
        return len(self.questions)

    def __getitem__(self, idx):
        """Enable ``dataset[idx]``; delegates to ``__get_item__``."""
        return self.__get_item__(idx)

    def _get_program(self, q_id):
        """Return the ``program`` of the first ``self.programs`` row whose questionID is *q_id*."""
        matches = self.programs.loc[self.programs.questionID == q_id]
        # Positional access on the filtered frame: correct whatever index
        # labels ``self.programs`` carries.
        return matches.iloc[0].program

    def __get_item__(self, idx):
        """Load the image and question metadata for question row ``idx``.

        Kept under its historical name for existing callers; ``__getitem__``
        forwards here.

        Returns:
            ``(img, info)`` where ``img`` is the image converted from BGR to
            RGB and ``info`` is a dict with keys ``q_id``, ``img_id``,
            ``question``, ``answer``, ``full_answer`` and ``program``.

        Raises:
            FileNotFoundError: if the image file does not exist.
            OSError: if the image file exists but OpenCV cannot read it.
        """
        q_id = self.questions.iloc[idx].questionID
        # First row carrying this question ID (same lookup as before, so
        # duplicate IDs resolve to their first occurrence).
        row = self.questions.loc[self.questions.questionID == q_id]

        # get image
        img_id = row.imageId.values[0]
        img_path = os.path.join(self.path, f'{img_id}.jpg')
        if not os.path.exists(img_path):
            raise FileNotFoundError(f'Image path {img_path} does not exist!')
        img = cv2.imread(img_path)
        if img is None:
            raise OSError(f'Image {img_path} could not be read!')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        info = {
            'q_id': q_id,
            'img_id': img_id,
            # get question and answer
            'question': row.question.values[0],
            'answer': row.answer.values[0],
            'full_answer': row.fullAnswer.values[0],
            # get program
            'program': self._get_program(q_id),
        }

        return img, info

    def encode_item(self, idx, new_size=(25, 25), dim=1024):
        """Encode all scene-graph objects of item ``idx`` into one SSP memory.

        Object positions are divided by a single ``scale`` so x and y share
        the same resolution; widths/heights are rescaled to a fixed 10x10
        range, clamped to at least 1 and rounded to ints (the original author
        notes that zero or fractional values degrade decoding accuracy).

        Args:
            idx: question row index.
            new_size: target (rows, cols) extent used to derive ``scale``.
            dim: dimensionality of the semantic pointers.

        Returns:
            ``(img, info, memory)``; ``info`` additionally holds
            ``encoded_items`` (name -> int array [x, y, w, h]),
            ``encoded_ssps`` (name -> object pointer) and
            ``scales`` ([scale, w_scale, h_scale]).
        """
        img, info = self.__get_item__(idx)
        sg_data = self.scenegraphs.get(str(info['img_id'])).get('objects')

        img_size = img.shape[:2]  # (rows, cols)

        # find orientation and select scale to fit into quadratic vector space
        if img_size[1] / 2 < img_size[0]:
            scale = img_size[0] / new_size[0]
        else:
            scale = img_size[1] / new_size[1]

        # scale width and height to fixed size of 10
        w_scale, h_scale = img_size[1] / 10, img_size[0] / 10

        encoded_items = {}
        encoded_ssps = {}

        rng = np.random.RandomState(seed=self.seed)
        memory = spa.SemanticPointer(data=np.zeros(dim), rng=rng)
        name_counts = {}
        n_colors = len(RGB_COLORS)

        if self.vis:
            print(f'Original image {img_size[0]}x{img_size[1]} --> {int(img_size[0] / scale)}x{int(img_size[1] / scale)}')
            fig, ax = plt.subplots(1, 1)
            ax.imshow(img, interpolation='none', origin='upper', extent=[0, img_size[1] / scale, img_size[0] / scale, 0])
            plt.axis('off')

        for i, obj_dict in enumerate(sg_data.values()):
            base_name = obj_dict.get('name')
            # Suffix with a running per-name count so repeated object names
            # get unique keys ("cup_1", "cup_2", ...).
            name_counts[base_name] = name_counts.get(base_name, 0) + 1
            name = base_name + '_' + str(name_counts[base_name])

            # extract ground truth data and scale to fit to SSPs
            x, y, width, height = obj_dict.get('x'), obj_dict.get('y'), obj_dict.get('w'), obj_dict.get('h')
            x, y, width, height = x / scale, y / scale, width / w_scale, height / h_scale

            width = width if width >= 1 else 1
            height = height if height >= 1 else 1

            # Round values to next int (otherwise decoding gets buggy)
            item = np.round([x, y, width, height], decimals=0).astype(int)
            encoded_items[name] = item

            pos = encode_point_multidim(list(item), self.ssp_axes)
            # Draw the object's random pointer from the seeded generator so
            # encodings are reproducible for a given ``self.seed``.
            ssp = spa.SemanticPointer(dim, rng=rng)
            encoded_ssps[name] = ssp

            memory += ssp * pos

            if self.vis:
                x, y, width, height = item
                # convert the fixed-range width/height back to canvas units
                width, height = (width * w_scale) / scale, (height * h_scale) / scale
                rect = patches.Rectangle((x, y),
                                         width, height,
                                         linewidth=2,
                                         label=name,
                                         # wrap around if there are more objects than colours
                                         edgecolor=RGB_COLORS[i % n_colors],
                                         facecolor='none')
                ax.add_patch(rect)

        if self.vis:
            plt.show()

        info['encoded_items'] = encoded_items
        info['encoded_ssps'] = encoded_ssps
        info['scales'] = [scale, w_scale, h_scale]

        return img, info, memory

    def decode_item(self, img, info, memory):
        """Decode every encoded object from ``memory`` and score it.

        Args:
            img: image array as returned by ``encode_item``.
            info: dict from ``encode_item`` (needs ``scales``,
                ``encoded_items`` and ``encoded_ssps``).
            memory: SSP memory from ``encode_item``.

        Returns:
            ``(avg_mse, avg_iou, n_correct)``: mean squared error of the
            decoded (x, y) locations, mean IoU of the boxes, and the number
            of objects with IoU > 0.5. With no encoded items the two means
            are NaN (``np.mean`` of an empty list).
        """
        img_size = img.shape[:2]
        scale, w_scale, h_scale = info['scales']

        if self.vis:
            fig, ax = plt.subplots(1, 1)
            ax.imshow(img, interpolation='none', origin='upper', extent=[0, img_size[1] / scale, img_size[0] / scale, 0])
            plt.axis('off')

        errors = []
        iou_lst = []
        iou_binary_lst = []
        n_colors = len(RGB_COLORS)

        for i, (name, data) in enumerate(info['encoded_items'].items()):
            ssp_item = info['encoded_ssps'][name]

            # Unbind the object's pointer from the memory.
            item_decoded = memory * ~ssp_item
            # presumably maps the noisy pointer to the closest (x, y, w, h)
            # on the configured grid -- see utils.ssp_to_loc_multidim
            clean_loc = ssp_to_loc_multidim(item_decoded, self.ssp_vectors, self.linspace)
            x, y, width, height = clean_loc

            # location error uses only the first two components (x, y)
            mse = np.square(np.subtract(data[:2], clean_loc[:2])).mean()
            errors.append(mse)

            # bring width/height back to the same units as x/y
            width, height = width * w_scale / scale, height * h_scale / scale
            bb_gt = np.array([data[0], data[1], data[0]+(data[2] * w_scale / scale), data[1]+(data[3] * h_scale / scale)])

            iou = bb_intersection_over_union(bb_gt, [x, y, x+width, y+height])
            iou_lst.append(iou)
            iou_binary_lst.append(1 if iou > 0.5 else 0)

            if self.vis:
                rect = patches.Rectangle((x, y),
                                         width, height,
                                         linewidth=2,
                                         label=name,
                                         # wrap around if there are more objects than colours
                                         edgecolor=RGB_COLORS[i % n_colors],
                                         facecolor='none')
                ax.add_patch(rect)

        if self.vis:
            plt.legend(loc='upper left', bbox_to_anchor=(1., 1.02))
            plt.show()

        avg_mse = np.mean(errors)
        avg_iou = np.mean(iou_lst)

        if self.verbose > 0:
            print(f'Average mean-squared error of 2D locations: {avg_mse:.4f}')
            print(f'Average IoU of 4D bounding boxes: {avg_iou:.2f}')
            print(f'Correct items: {np.sum(iou_binary_lst)} / {len(info["encoded_items"])}')

        return avg_mse, avg_iou, np.sum(iou_binary_lst)

    def print_item(self, idx):
        """Print question, answers and numbered program steps of item ``idx``.

        Each program step is expected to look like ``"<num>=<function>"``;
        only the part after the first ``=`` is printed.
        """
        _, info = self.__get_item__(idx)

        print(f"Question #{info['q_id']}: \n{info['question']}")
        print(f"[{info['answer']}] {info['full_answer']}\n")
        print('Program:')
        for i, step in enumerate(info['program']):
            # split once so '=' inside the function text is preserved
            _, func = step.split('=', 1)
            print(f'{i}. {func}')
        print()

    def set_visualize(self, visualize):
        """Turn plotting in ``encode_item``/``decode_item`` on or off."""
        self.vis = visualize

    def set_verbose(self, verbose):
        """Set the verbosity level (> 0 prints metrics in ``decode_item``)."""
        self.verbose = verbose
|