VSA4VQA/dataset.py

287 lines
10 KiB
Python
Raw Normal View History

2024-04-29 17:18:10 +02:00
import os
import cv2
import torch
from torch.utils.data import Dataset
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.patches as patches
import nengo.spa as spa
from utils import encode_point_multidim, ssp_to_loc_multidim, bb_intersection_over_union
# Every named matplotlib colour converted to an RGB tuple, in
# `mcolors.cnames` order; used as edge colours for the drawn bounding boxes.
# A comprehension keeps the loop variable out of module scope (the previous
# loop leaked `name` and shadowed the builtin `hex` module-wide).
RGB_COLORS = [mcolors.to_rgb(hex_code) for hex_code in mcolors.cnames.values()]
class MNISTQueryDataset(Dataset):
    """MNIST spatial query dataset.

    Each sample is a blank ``img_size`` x ``img_size`` canvas onto which
    ``num_imgs`` MNIST digits are pasted at random, non-overlapping
    positions. Sample ``idx`` uses the ``idx``-th group of ``num_imgs``
    entries of a seeded permutation of ``mnist_data``, so different samples
    never share a digit.
    """

    def __init__(self, mnist_data, num_imgs, img_size=120, visualize=False, transform=None, seed=42):
        """
        Args:
            mnist_data: indexable collection of ``(image_tensor, label)`` pairs.
                Only the first dimension of ``image.squeeze()`` is read, so
                images are assumed square — TODO confirm for non-MNIST input.
            num_imgs: number of digits pasted into each sample.
            img_size: side length of the square output canvas.
            visualize: if True, show the canvas and free-space mask after
                every paste.
            transform: stored but currently not applied.
            seed: seeds NumPy's and torch's global RNGs (shuffle + placement).
        """
        # Set random seed for location and mnist image selection
        np.random.seed(seed)
        torch.manual_seed(seed)
        self.mnist_data = mnist_data
        self.mnist_size = mnist_data[0][0].squeeze().numpy().shape[0]
        # Shuffle MNIST data set according to random seed
        self.mnist_indices = torch.randperm(len(self.mnist_data))
        self.num_imgs = num_imgs
        self.img_size = img_size
        # Candidate top-left corners are drawn from [0, border) on each axis.
        # NOTE(review): the last valid offset (img_size - mnist_size) is never
        # used; kept as-is so placements stay comparable with earlier runs.
        self.border = self.img_size - self.mnist_size
        self.visualize = visualize
        self.transform = transform

    def __len__(self):
        # Number of complete, disjoint groups of `num_imgs` digits.
        return len(self.mnist_data) // self.num_imgs

    def __getitem__(self, idx):
        """Build sample ``idx``.

        Returns:
            dict with ``'image'`` (float canvas of shape (img_size, img_size)),
            ``'labels'`` (one ``{label: (x_center, y_center)}`` dict per digit)
            and ``'mnist_images'`` (the pasted digit arrays).

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
            ValueError: if no free position is left for a digit.
        """
        n_samples = len(self)
        if idx < 0:
            idx += n_samples
        if not 0 <= idx < n_samples:
            raise IndexError(f'Sample index {idx} out of range for dataset of length {n_samples}')

        # Disjoint group of digits for this sample.
        start = idx * self.num_imgs
        current_indices = self.mnist_indices[start:start + self.num_imgs]

        size = self.mnist_size
        image = np.zeros((self.img_size, self.img_size))
        # 1 marks a top-left corner where the next digit may still be placed.
        mask = np.ones((self.border, self.border))
        labels = []
        mnist_imgs = []
        for i in current_indices:
            mnist, label = self.mnist_data[i]
            digit = mnist.squeeze().numpy()
            mnist_imgs.append(digit)

            # Find available space; rows of `free` are (y, x) pairs.
            free = np.argwhere(mask == 1)
            if len(free) == 0:
                raise ValueError(
                    f'No free space left on a {self.img_size}x{self.img_size} canvas '
                    f'for {self.num_imgs} digits of size {size}')
            # Pick a random free pixel as the digit's top-left corner.
            y_pos, x_pos = free[np.random.randint(len(free))]

            image[y_pos:y_pos + size, x_pos:x_pos + size] = digit
            # Position label = center of the pasted digit.
            labels.append({label: (x_pos + size // 2, y_pos + size // 2)})

            # Block every corner whose digit would overlap (or touch) this one.
            mask[max(0, y_pos - size):min(y_pos + size + 1, self.border),
                 max(0, x_pos - size):min(x_pos + size + 1, self.border)] = 0

            # Visualize image state and current mask.
            if self.visualize:
                f, (ax1, ax2) = plt.subplots(1, 2, sharey=False)
                ax1.imshow(image, cmap='gray')
                ax2.imshow(mask, cmap='gray')
                plt.show()

        # NOTE: `self.transform` is stored but not applied to the sample.
        return {'image': image, 'labels': labels, 'mnist_images': mnist_imgs}
class GQADataset:
    """GQA question / scene-graph dataset with SSP encoding helpers.

    Wraps question, program and scene-graph tables and provides helpers to
    load an item, encode its annotated objects into a single SSP memory
    vector, and decode that memory back into bounding boxes.
    """

    def __init__(self, questions, programs, scenegraphs, vectors, axes, linspace,
                 path='GQA/images/images/', seed=17, verbose=0, visualize=False):
        """
        Args:
            questions: DataFrame with at least ``questionID``, ``imageId``,
                ``question``, ``answer`` and ``fullAnswer`` columns.
            programs: DataFrame with ``questionID`` and ``program`` columns;
                each ``program`` is iterated as a sequence of
                ``"<n>=<operation>"`` strings.
            scenegraphs: mapping from image id (as str) to a dict whose
                ``'objects'`` entry maps object ids to dicts with ``name``,
                ``x``, ``y``, ``w`` and ``h``.
            vectors, linspace: passed to ``ssp_to_loc_multidim`` when decoding.
            axes: passed to ``encode_point_multidim`` when encoding.
            path: directory containing the ``<imageId>.jpg`` files.
            seed: seeds NumPy's global RNG here and the RandomState created
                in ``encode_item``.
            verbose: if > 0, ``decode_item`` prints summary metrics.
            visualize: if True, encode/decode plot the image with boxes.
        """
        np.random.seed(seed)
        self.questions = questions
        self.programs = programs
        self.scenegraphs = scenegraphs
        self.vis = visualize
        self.verbose = verbose
        self.seed = seed
        self.path = path
        # vector space
        self.ssp_vectors = vectors
        self.ssp_axes = axes  # [x_axis, y_axis, w_axis, h_axis]
        self.linspace = linspace  # [xs, ys, ws, hs]

    def __len__(self):
        return len(self.questions)

    def __getitem__(self, idx):
        """Alias of ``__get_item__`` so the dataset also supports ``ds[idx]``."""
        return self.__get_item__(idx)

    def _load_image(self, img_id):
        """Read ``<path>/<img_id>.jpg`` and return it as an RGB array."""
        img_path = os.path.join(self.path, f'{img_id}.jpg')
        assert os.path.exists(img_path), f'Image path {img_path} does not exist!'
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports unreadable/corrupt files by returning None.
            raise OSError(f'Could not read image {img_path}')
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def __get_item__(self, idx):
        """Load the image and question metadata for row ``idx`` of ``questions``.

        Returns:
            (img, info): the RGB image and a dict with keys ``q_id``,
            ``img_id``, ``question``, ``answer``, ``full_answer``, ``program``.

        Raises:
            IndexError: if ``idx`` is out of range or the question has no program.
        """
        row = self.questions.iloc[idx]
        q_id = row['questionID']
        img_id = row['imageId']
        img = self._load_image(img_id)
        # Take the first *positional* match so the lookup works for any index
        # on `programs`, not only a default RangeIndex.
        program = self.programs.loc[self.programs.questionID == q_id, 'program'].iloc[0]
        info = {'q_id': q_id, 'img_id': img_id, 'question': row['question'],
                'answer': row['answer'], 'full_answer': row['fullAnswer'],
                'program': program}
        return img, info

    def encode_item(self, idx, new_size=(25, 25), dim=1024):
        """Encode all objects of one image into a single SSP memory vector.

        Each scene-graph object gets its own semantic pointer, bound to the
        encoding of its box ``[x, y, w, h]`` and summed into ``memory``.
        To keep decoding accurate:
          * x and y are divided by the same ``scale`` so both axes share one
            resolution;
          * width and height are mapped onto a fixed 0..10 range and clamped
            to at least 1 (zero values degrade decoding);
          * all four values are rounded to integers.

        Args:
            idx: row of ``self.questions`` to encode.
            new_size: target (height, width) of the x/y space used to choose
                ``scale``.
            dim: dimensionality of the semantic pointers.

        Returns:
            (img, info, memory) where ``info`` additionally holds
            ``encoded_items`` (name -> int array [x, y, w, h]),
            ``encoded_ssps`` (name -> item pointer) and
            ``scales`` ([scale, w_scale, h_scale]).
        """
        img, info = self.__get_item__(idx)
        sg_data = self.scenegraphs.get(str(info['img_id'])).get('objects')
        img_size = img.shape[:2]  # (height, width)

        # Find orientation and select one scale for both x and y.
        # NOTE(review): because of the `/ 2`, landscape images with aspect
        # ratio < 2 are scaled by height, so x can exceed new_size[1]; confirm
        # this matches the extent of the x-axis linspace.
        if img_size[1] / 2 < img_size[0]:
            scale = img_size[0] / new_size[0]
        else:
            scale = img_size[1] / new_size[1]
        # Scale width and height to a fixed range of 10.
        w_scale, h_scale = img_size[1] / 10, img_size[0] / 10

        encoded_items = {}
        encoded_ssps = {}
        rng = np.random.RandomState(seed=self.seed)
        memory = spa.SemanticPointer(data=np.zeros(dim), rng=rng)
        # Occurrences per object name, used to build unique keys (cup_1, cup_2, ...).
        name_counts = {}

        if self.vis:
            print(f'Original image {img_size[0]}x{img_size[1]} --> {int(img_size[0] / scale)}x{int(img_size[1] / scale)}')
            fig, ax = plt.subplots(1, 1)
            ax.imshow(img, interpolation='none', origin='upper', extent=[0, img_size[1] / scale, img_size[0] / scale, 0])
            plt.axis('off')

        for i, obj_dict in enumerate(sg_data.values()):
            base_name = obj_dict.get('name')
            name_counts[base_name] = name_counts.get(base_name, 0) + 1
            name = base_name + '_' + str(name_counts[base_name])

            # Extract ground-truth box and scale it into the SSP space.
            x, y, width, height = obj_dict.get('x'), obj_dict.get('y'), obj_dict.get('w'), obj_dict.get('h')
            x, y = x / scale, y / scale
            width, height = width / w_scale, height / h_scale
            width = width if width >= 1 else 1
            height = height if height >= 1 else 1
            # Round to ints (decimal values make decoding unreliable).
            item = np.round([x, y, width, height], decimals=0).astype(int)
            encoded_items[name] = item

            pos = encode_point_multidim(list(item), self.ssp_axes)
            # NOTE(review): `rng` is not passed here, so this pointer presumably
            # comes from NumPy's global RNG rather than self.seed — verify.
            ssp = spa.SemanticPointer(dim)
            encoded_ssps[name] = ssp
            memory += ssp * pos

            if self.vis:
                bx, by, bw, bh = item
                # Convert w/h from the 0..10 range back into the x/y units.
                bw, bh = (bw * w_scale) / scale, (bh * h_scale) / scale
                rect = patches.Rectangle((bx, by), bw, bh,
                                         linewidth=2,
                                         label=name,
                                         edgecolor=RGB_COLORS[i % len(RGB_COLORS)],
                                         facecolor='none')
                ax.add_patch(rect)

        if self.vis:
            plt.show()
        info['encoded_items'] = encoded_items
        info['encoded_ssps'] = encoded_ssps
        info['scales'] = [scale, w_scale, h_scale]
        return img, info, memory

    def decode_item(self, img, info, memory):
        """Decode every encoded object from ``memory`` and score it.

        For each object, the item pointer is unbound from ``memory``, the
        result is cleaned up with ``ssp_to_loc_multidim`` and compared with
        the encoded ground truth.

        Returns:
            (avg_mse, avg_iou, n_correct): mean squared error of the (x, y)
            positions, mean IoU of the boxes, and the number of objects whose
            IoU exceeds 0.5.
        """
        img_size = img.shape[:2]
        scale, w_scale, h_scale = info['scales']
        if self.vis:
            fig, ax = plt.subplots(1, 1)
            ax.imshow(img, interpolation='none', origin='upper', extent=[0, img_size[1] / scale, img_size[0] / scale, 0])
            plt.axis('off')

        errors = []
        iou_lst = []
        iou_binary_lst = []
        for i, (name, data) in enumerate(info['encoded_items'].items()):
            ssp_item = info['encoded_ssps'][name]
            # Unbind this item's pointer from the memory.
            item_decoded = memory * ~ssp_item
            clean_loc = ssp_to_loc_multidim(item_decoded, self.ssp_vectors, self.linspace)
            x, y, width, height = clean_loc
            errors.append(np.square(np.subtract(data[:2], clean_loc[:2])).mean())

            # Map w/h from the 0..10 range back into the x/y units.
            width, height = width * w_scale / scale, height * h_scale / scale
            bb_gt = np.array([data[0], data[1],
                              data[0] + (data[2] * w_scale / scale),
                              data[1] + (data[3] * h_scale / scale)])
            iou = bb_intersection_over_union(bb_gt, [x, y, x + width, y + height])
            iou_lst.append(iou)
            iou_binary_lst.append(1 if iou > 0.5 else 0)

            if self.vis:
                rect = patches.Rectangle((x, y), width, height,
                                         linewidth=2,
                                         label=name,
                                         edgecolor=RGB_COLORS[i % len(RGB_COLORS)],
                                         facecolor='none')
                ax.add_patch(rect)

        if self.vis:
            plt.legend(loc='upper left', bbox_to_anchor=(1., 1.02))
            plt.show()

        avg_mse = np.mean(errors)
        avg_iou = np.mean(iou_lst)
        if self.verbose > 0:
            print(f'Average mean-squared error of 2D locations: {avg_mse:.4f}')
            print(f'Average IoU of 4D bounding boxes: {avg_iou:.2f}')
            print(f'Correct items: {np.sum(iou_binary_lst)} / {len(info["encoded_items"])}')
        return avg_mse, avg_iou, np.sum(iou_binary_lst)

    def print_item(self, idx):
        """Print the question, answer and program steps of row ``idx``."""
        _, info = self.__get_item__(idx)
        print(f"Question #{info['q_id']}: \n{info['question']}")
        print(f"[{info['answer']}] {info['full_answer']}\n")
        print('Program:')
        for i, step in enumerate(info['program']):
            # Split only on the first '=' so operations containing '=' survive.
            _, func = step.split('=', 1)
            print(f'{i}. {func}')
        print()

    def set_visualize(self, visualize):
        self.vis = visualize

    def set_verbose(self, verbose):
        self.verbose = verbose