import os

import cv2
import torch
from torch.utils.data import Dataset
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.patches as patches
import nengo.spa as spa

from utils import encode_point_multidim, ssp_to_loc_multidim, bb_intersection_over_union

# RGB triples for every named matplotlib colour; used to colour bounding boxes.
RGB_COLORS = [mcolors.to_rgb(hex_code) for hex_code in mcolors.cnames.values()]


class MNISTQueryDataset(Dataset):
    """MNIST spatial query dataset.

    Each sample is a blank ``img_size x img_size`` canvas onto which
    ``num_imgs`` MNIST digits are pasted at random, non-overlapping positions.
    """

    def __init__(self, mnist_data, num_imgs, img_size=120, visualize=False, transform=None, seed=42):
        """
        Args:
            mnist_data: indexable dataset of ``(image, label)`` pairs; the image
                must support ``.squeeze().numpy()`` (presumably a torch tensor,
                e.g. torchvision MNIST with ToTensor — confirm with caller).
            num_imgs: number of digits pasted onto each canvas.
            img_size: side length of the square output canvas in pixels.
            visualize: if True, plot canvas and placement mask after each digit.
            transform: stored but not applied (see ``__getitem__``).
            seed: seed for NumPy's and torch's *global* RNGs.
        """
        # Set random seed for location and mnist image selection
        np.random.seed(seed)
        torch.manual_seed(seed)

        self.mnist_data = mnist_data
        # Side length of one digit image (only shape[0] is read, so the digit
        # is assumed to be square).
        self.mnist_size = mnist_data[0][0].squeeze().numpy().shape[0]
        # Shuffle MNIST data set according to random seed
        self.mnist_indices = torch.randperm(len(self.mnist_data))
        self.num_imgs = num_imgs
        self.img_size = img_size
        # Top-left corners of pasted digits are drawn from [0, border) per axis.
        self.border = self.img_size - self.mnist_size
        self.visualize = visualize
        self.transform = transform

    def __len__(self):
        """Number of canvases; each one consumes ``num_imgs`` distinct digits."""
        return len(self.mnist_data) // self.num_imgs

    def __getitem__(self, idx):
        """Compose canvas ``idx`` from its own disjoint chunk of shuffled digits.

        Returns:
            dict with
              'image': ndarray (img_size, img_size), the composed canvas;
              'labels': list of ``{digit_label: (x_center, y_center)}`` dicts;
              'mnist_images': list of the pasted digit arrays.

        Raises:
            ValueError: if no free position is left for the next digit
                (``num_imgs`` too large for ``img_size``).
        """
        # Sample idx uses digits [idx*num_imgs, (idx+1)*num_imgs): samples never
        # share digits and together cover the shuffled dataset, matching __len__.
        start = idx * self.num_imgs
        current_indices = self.mnist_indices[start: start + self.num_imgs]

        image = np.zeros((self.img_size, self.img_size))
        # mask[y, x] == 1 marks a top-left corner where a digit can still be
        # placed without overlapping any digit already on the canvas.
        mask = np.ones((self.border, self.border))
        labels = []
        mnist_imgs = []

        for i in current_indices:
            mnist, label = self.mnist_data[i]
            digit = mnist.squeeze().numpy()
            mnist_imgs.append(digit)

            # find available space as (row, col) == (y, x) pairs
            free_coords = np.argwhere(mask == 1)
            if len(free_coords) == 0:
                raise ValueError(
                    f'No free space left to place {self.num_imgs} digits of size '
                    f'{self.mnist_size} on a {self.img_size}x{self.img_size} canvas.')
            # pick random free pixel as top-left corner (y0, x0)
            y_pos, x_pos = free_coords[np.random.randint(len(free_coords))]

            # add mnist to image
            image[y_pos: y_pos + self.mnist_size, x_pos: x_pos + self.mnist_size] = digit
            # position label = center of mnist image
            labels.append({label: (x_pos + self.mnist_size // 2, y_pos + self.mnist_size // 2)})

            # Remove every top-left corner within mnist_size of this one on both
            # axes, so later digits cannot overlap it.
            x0, x1 = max(0, x_pos - self.mnist_size), min(x_pos + self.mnist_size + 1, self.border)
            y0, y1 = max(0, y_pos - self.mnist_size), min(y_pos + self.mnist_size + 1, self.border)
            mask[y0:y1, x0:x1] = 0

            # visualize image state and current mask
            if self.visualize:
                _, (ax1, ax2) = plt.subplots(1, 2, sharey=False)
                ax1.imshow(image, cmap='gray')
                ax2.imshow(mask, cmap='gray')
                plt.show()

        sample = {'image': image, 'labels': labels, 'mnist_images': mnist_imgs}
        # NOTE(review): self.transform is stored but never applied here —
        # confirm whether it should be (it operates on the whole sample dict).
        return sample


class GQADataset:
    """GQA questions paired with images, programs and SSP-encoded scene graphs.

    Besides loading a question's image and metadata, it can encode the image's
    scene-graph objects into one SSP memory (``encode_item``) and decode and
    score them again (``decode_item``).
    """

    def __init__(self, questions, programs, scenegraphs, vectors, axes, linspace,
                 path='GQA/images/images/', seed=17, verbose=0, visualize=False):
        """
        Args:
            questions: pandas DataFrame with columns questionID, imageId,
                question, answer and fullAnswer.
            programs: pandas DataFrame with columns questionID and program; a
                program is iterated as '<step>=<operation>' strings.
            scenegraphs: mapping from str(imageId) to a dict whose 'objects'
                entry maps object ids to dicts with name, x, y, w, h.
            vectors: SSP vectors passed to ``ssp_to_loc_multidim`` for decoding.
            axes: SSP axis vectors [x_axis, y_axis, w_axis, h_axis].
            linspace: sample grids [xs, ys, ws, hs] used for decoding.
            path: directory containing '<imageId>.jpg' files.
            seed: seeds NumPy's global RNG and the memory's RandomState.
            verbose: if > 0, ``decode_item`` prints summary metrics.
            visualize: if True, encode/decode plot the bounding boxes.
        """
        np.random.seed(seed)
        self.questions = questions
        self.programs = programs
        self.scenegraphs = scenegraphs
        self.vis = visualize
        self.verbose = verbose
        self.seed = seed
        self.path = path

        # vector space
        self.ssp_vectors = vectors
        self.ssp_axes = axes  # [x_axis, y_axis, w_axis, h_axis]
        self.linspace = linspace  # [xs, ys, ws, hs]

    def __len__(self):
        """Number of questions."""
        return len(self.questions)

    def __get_item__(self, idx):
        """Load image and metadata for the question at position ``idx``.

        Returns:
            (img, info): ``img`` is the RGB image array; ``info`` holds
            q_id, img_id, question, answer, full_answer and program.

        Raises:
            AssertionError: if the image file does not exist.
            IndexError: if no program exists for the question.
        """
        q_id = self.questions.iloc[idx].questionID
        temp_df = self.questions.loc[self.questions.questionID == q_id]

        # get image
        img_id = temp_df.imageId.values[0]
        img_path = os.path.join(self.path, f'{img_id}.jpg')
        assert os.path.exists(img_path), f'Image path {img_path} does not exist!'
        img = cv2.imread(img_path)
        # OpenCV loads BGR; convert to RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # get question and answer
        question = temp_df.question.values[0]
        answer = temp_df.answer.values[0]
        full_answer = temp_df.fullAnswer.values[0]

        # get program: take the first matching row positionally. Feeding an
        # index *label* into .iloc is only correct for a default RangeIndex.
        program = self.programs.loc[self.programs.questionID == q_id].program.values[0]

        info = {'q_id': q_id, 'img_id': img_id, 'question': question, 'answer': answer,
                'full_answer': full_answer, 'program': program}
        return img, info

    def encode_item(self, idx, new_size=(25, 25), dim=1024):
        """Encode all scene-graph objects of question ``idx``'s image into one SSP memory.

        x/y are rescaled by a single factor so the image fits the
        ``new_size`` grid (same resolution on both axes); width/height are
        rescaled to a fixed 0..10 range. Widths/heights are clamped to at
        least 1 and all four values are rounded to ints, because zero sizes
        and fractional values degrade decoding accuracy.

        Returns:
            (img, info, memory): ``info`` additionally gets
            'encoded_items' (unique name -> int array [x, y, w, h]),
            'encoded_ssps' (unique name -> SemanticPointer) and
            'scales' ([scale, w_scale, h_scale]).
        """
        img, info = self.__get_item__(idx)
        sg_data = self.scenegraphs.get(str(info['img_id'])).get('objects')
        img_size = img.shape[:2]  # (height, width)

        # find orientation and select scale to fit into quadratic vector space
        # NOTE(review): the `/ 2` means scale follows the height unless the
        # image is more than twice as wide as tall, so x can exceed new_size[1].
        # Confirm this matches the x-range of self.linspace.
        if img_size[1] / 2 < img_size[0]:
            scale = img_size[0] / new_size[0]
        else:
            scale = img_size[1] / new_size[1]

        # scale width and height to fixed size of 10
        w_scale, h_scale = img_size[1] / 10, img_size[0] / 10

        encoded_items = {}
        encoded_ssps = {}
        rng = np.random.RandomState(seed=self.seed)
        memory = spa.SemanticPointer(data=np.zeros(dim), rng=rng)
        name_counts = {}

        if self.vis:
            print(f'Original image {img_size[0]}x{img_size[1]} --> {int(img_size[0] / scale)}x{int(img_size[1] / scale)}')
            _, ax = plt.subplots(1, 1)
            ax.imshow(img, interpolation='none', origin='upper',
                      extent=[0, img_size[1] / scale, img_size[0] / scale, 0])
            plt.axis('off')

        for i, (_, obj_dict) in enumerate(sg_data.items()):
            base_name = obj_dict.get('name')
            name_counts[base_name] = name_counts.get(base_name, 0) + 1
            # Running suffix makes repeated names unique: cup_1, cup_2, ...
            name = base_name + '_' + str(name_counts[base_name])

            # extract ground truth data and scale to fit to SSPs
            x, y, width, height = obj_dict.get('x'), obj_dict.get('y'), obj_dict.get('w'), obj_dict.get('h')
            x, y, width, height = x / scale, y / scale, width / w_scale, height / h_scale
            width = width if width >= 1 else 1
            height = height if height >= 1 else 1

            # Round values to next int (otherwise decoding gets buggy)
            item = np.round([x, y, width, height], decimals=0).astype(int)
            encoded_items[name] = item

            # Bind (`*`) a new pointer for this object to its encoded position
            # and superimpose it into the memory.
            pos = encode_point_multidim(list(item), self.ssp_axes)
            ssp = spa.SemanticPointer(dim)
            encoded_ssps[name] = ssp
            memory += ssp * pos

            if self.vis:
                x, y, width, height = item
                # Map w/h from the 0..10 range back into the plotted frame.
                width, height = (width * w_scale) / scale, (height * h_scale) / scale
                rect = patches.Rectangle((x, y), width, height, linewidth=2, label=name,
                                         edgecolor=RGB_COLORS[i % len(RGB_COLORS)], facecolor='none')
                ax.add_patch(rect)

        if self.vis:
            plt.show()

        info['encoded_items'] = encoded_items
        info['encoded_ssps'] = encoded_ssps
        info['scales'] = [scale, w_scale, h_scale]

        return img, info, memory

    def decode_item(self, img, info, memory):
        """Decode every object from ``memory`` and score it against ground truth.

        Args:
            img, info, memory: the triple returned by ``encode_item``.

        Returns:
            (avg_mse, avg_iou, n_correct): mean squared error of the decoded
            (x, y) positions, mean IoU of the decoded boxes, and the number of
            objects whose IoU exceeds 0.5.
        """
        img_size = img.shape[:2]
        scale, w_scale, h_scale = info['scales']

        if self.vis:
            _, ax = plt.subplots(1, 1)
            ax.imshow(img, interpolation='none', origin='upper',
                      extent=[0, img_size[1] / scale, img_size[0] / scale, 0])
            plt.axis('off')

        errors = []
        iou_lst = []
        iou_binary_lst = []

        for i, (name, data) in enumerate(info['encoded_items'].items()):
            ssp_item = info['encoded_ssps'][name]
            # Unbind this object's pointer (`~` = approximate inverse).
            item_decoded = memory * ~ssp_item
            clean_loc = ssp_to_loc_multidim(item_decoded, self.ssp_vectors, self.linspace)
            x, y, width, height = clean_loc

            mse = np.square(np.subtract(data[:2], clean_loc[:2])).mean()
            errors.append(mse)

            # Map decoded and ground-truth w/h into the x/y frame, then compare
            # the boxes as [x1, y1, x2, y2] with (x, y) as the top-left corner.
            width, height = width * w_scale / scale, height * h_scale / scale
            bb_gt = np.array([data[0], data[1],
                              data[0] + (data[2] * w_scale / scale),
                              data[1] + (data[3] * h_scale / scale)])
            iou = bb_intersection_over_union(bb_gt, [x, y, x + width, y + height])
            iou_lst.append(iou)
            iou_binary_lst.append(1 if iou > 0.5 else 0)

            if self.vis:
                rect = patches.Rectangle((x, y), width, height, linewidth=2, label=name,
                                         edgecolor=RGB_COLORS[i % len(RGB_COLORS)], facecolor='none')
                ax.add_patch(rect)

        if self.vis:
            plt.legend(loc='upper left', bbox_to_anchor=(1., 1.02))
            plt.show()

        avg_mse = np.mean(errors)
        avg_iou = np.mean(iou_lst)
        if self.verbose > 0:
            print(f'Average mean-squared error of 2D locations: {avg_mse:.4f}')
            print(f'Average IoU of 4D bounding boxes: {avg_iou:.2f}')
            print(f'Correct items: {np.sum(iou_binary_lst)} / {len(info["encoded_items"])}')

        return avg_mse, avg_iou, np.sum(iou_binary_lst)

    def print_item(self, idx):
        """Print question ``idx``, its answers and its numbered program steps."""
        _, info = self.__get_item__(idx)
        print(f"Question #{info['q_id']}: \n{info['question']}")
        print(f"[{info['answer']}] {info['full_answer']}\n")
        print('Program:')
        for i, step in enumerate(info['program']):
            # Split on the first '=' only, so operations containing '=' survive.
            _, func = step.split('=', 1)
            print(f'{i}. {func}')
        print()

    def set_visualize(self, visualize):
        """Enable or disable plotting in encode_item/decode_item."""
        self.vis = visualize

    def set_verbose(self, verbose):
        """Set verbosity; values > 0 make decode_item print metrics."""
        self.verbose = verbose