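# Answer GQA questions by executing their functional programs over a spatial
# semantic pointer (SSP) scene memory. Each scene object is bound into a single
# memory vector (encoding lives in `dataset.GQADataset` and the `utils` helpers),
# spatial relations are resolved via precomputed query masks in `relations/`,
# and visual attributes are decided with zero-shot CLIP prompts. The __main__
# block runs a balanced GQA split and pickles per-question results.
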
import os
import sys

os.environ["OPENBLAS_NUM_THREADS"] = "10"

import json
import cv2
import re
import random
import time
import torch
import clip
import logging
from datetime import datetime
from PIL import Image
from tqdm import tqdm
import pandas as pd
import numpy as np
import nengo.spa as spa
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.colors as mcolors
from collections import OrderedDict
from pattern.text.en import singularize, pluralize

from dataset import GQADataset
from utils import *

DATA_PATH = '/scratch/penzkofer/GQA'
RGB_COLORS = []
for name, hex in mcolors.cnames.items():
    RGB_COLORS.append(mcolors.to_rgb(hex))

CUDA_DEVICE = 7
torch.cuda.set_device(CUDA_DEVICE)

device = torch.device("cuda:" + str(CUDA_DEVICE))
clip_model, preprocess = clip.load("ViT-B/32", device=device)

with open('gqa_all_relations_map.json') as f:
    RELATION_DICT = json.load(f)

with open('gqa_all_vocab_classes.json') as f:
    CLASS_DICT = json.load(f)

with open('gqa_all_attributes.json') as f:
    ATTRIBUTE_DICT = json.load(f)

SYNONYMS = {'he': ['man', 'boy'], 'she': ['woman', 'girl']}
ANSWER_MAP = {'to the right of': 'right', 'to the left of': 'left'}
VISUALIZE = False


def plot_heatmap_multidim(sp, xs, ys, heatmap_vectors, name='', vmin=-1, vmax=1, cmap='plasma', invert=False):
    """adapted from https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master"""

    assert sp.__class__.__name__ == 'SemanticPointer', \
        f'Queried object needs to be of type SemanticPointer but is {sp.__class__.__name__}'

    # axes: a list of axes to be summed over, first sequence applying to first tensor, second to second tensor
    vs = np.tensordot(sp.v, heatmap_vectors, axes=([0], [4]))
    res = np.unravel_index(np.argmax(vs, axis=None), vs.shape)

    plt.imshow(np.transpose(vs[:, :, res[2], res[3]]), origin='upper', interpolation='none',
               extent=(xs[-1], xs[0], ys[-1], ys[0]), vmin=vmin, vmax=vmax, cmap=cmap)
    plt.colorbar()
    plt.axis('off')
    plt.title(name)
    plt.show()


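# SSP memories bind each object's identity vector to its encoded bounding box
# via circular convolution and sum the pairs into one memory vector. A minimal
# sketch, assuming the `encode_point_multidim` helper imported from utils:
#   memory = obj1_sp * encode_point_multidim([x1, y1, w1, h1], axes) \
#          + obj2_sp * encode_point_multidim([x2, y2, w2, h2], axes)
# Unbinding with the pseudo-inverse, `memory *~ obj1_sp`, recovers a noisy
# location vector that `ssp_to_loc_multidim` cleans up against the heatmap grid.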
def select_ssp(name, memory, encoded_ssps, vectors, linspace):
    """decode location of object with name from SSP memory"""
    ssp_item = encoded_ssps[name]
    item_decoded = memory *~ ssp_item
    clean_loc = ssp_to_loc_multidim(item_decoded, vectors, linspace)
    return item_decoded, clean_loc


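# CLIP queries follow the standard zero-shot pattern: the queried object is
# highlighted with a red ellipse (so the full image context is preserved), the
# candidate prompts are tokenized, image and text embeddings are normalised,
# and the softmax over scaled cosine similarities picks the best prompt index.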
def clip_query(bbox, img, obj_name, clip_tokens, visualize=False):
    """Implements CLIP queries for different attributes"""
    x, y, w, h = bbox
    obj_center = (x + w / 2, y + h / 2)
    masked_img = img.copy()
    masked_img = cv2.ellipse(masked_img, (int(obj_center[0]), int(obj_center[1])), (int(w * 0.7), int(h * 0.7)),
                             0, 0, 360, (255, 0, 0), 4)

    if visualize:
        plt.imshow(masked_img)
        plt.axis('off')
        plt.show()

    masked_img = Image.fromarray(np.uint8(masked_img))

    tokens = clip.tokenize(clip_tokens)

    with torch.no_grad():
        text_features = clip_model.encode_text(tokens.to(device))
        image_features = clip_model.encode_image(preprocess(masked_img).unsqueeze(0).to(device))

        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
        indices = torch.max(similarity, 1)[1]

    similarity = similarity.squeeze()
    scores = [s.item() for s in similarity]
    pred = clip_tokens[indices.squeeze().item()]

    if visualize:
        print('CLIP')
        print(scores)
        print(pred)
    return indices.squeeze().item()


def clip_query_scene(img, clip_tokens, verbose=0):
    """Implements CLIP queries for entire scene, i.e. no bounding box selection"""
    img = Image.fromarray(np.uint8(img))
    tokens = clip.tokenize(clip_tokens)

    with torch.no_grad():
        text_features = clip_model.encode_text(tokens.to(device))
        image_features = clip_model.encode_image(preprocess(img).unsqueeze(0).to(device))

        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
        indices = torch.max(similarity, 1)[1]

    similarity = similarity.squeeze()
    scores = [s.item() for s in similarity]
    pred = clip_tokens[indices.squeeze().item()]

    if verbose > 0:
        print('CLIP')
        print(scores)
        print(pred)

    return indices.squeeze().item()


def clip_choose(bbox1, bbox2, img, attribute, visualize=False):
    """Run attribute vs. not attribute check for both subjects
    -- select CLIP prediction with higher confidence"""

    x1, y1, w1, h1 = bbox1
    obj_center = (x1 + w1 / 2, y1 + h1 / 2)
    masked_img1 = img.copy()
    masked_img1 = cv2.ellipse(masked_img1, (int(obj_center[0]), int(obj_center[1])), (int(w1 * 0.7), int(h1 * 0.7)),
                              0, 0, 360, (255, 0, 0), 4)

    x2, y2, w2, h2 = bbox2
    obj_center = (x2 + w2 / 2, y2 + h2 / 2)
    masked_img2 = img.copy()
    masked_img2 = cv2.ellipse(masked_img2, (int(obj_center[0]), int(obj_center[1])), (int(w2 * 0.7), int(h2 * 0.7)),
                              0, 0, 360, (255, 0, 0), 4)

    if visualize:
        plt.imshow(masked_img1)
        plt.axis('off')
        plt.show()
        plt.imshow(masked_img2)
        plt.axis('off')
        plt.show()

    masked_img1 = Image.fromarray(np.uint8(masked_img1))
    masked_img2 = Image.fromarray(np.uint8(masked_img2))

    tokens = clip.tokenize([attribute, f'not {attribute}'])

    with torch.no_grad():
        text_features = clip_model.encode_text(tokens.to(device))
        image_features = clip_model.encode_image(preprocess(masked_img1).unsqueeze(0).to(device))

        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        similarity1 = (100.0 * image_features @ text_features.T).softmax(dim=-1)
        similarity1 = similarity1.squeeze()[0]

    with torch.no_grad():
        text_features = clip_model.encode_text(tokens.to(device))
        image_features = clip_model.encode_image(preprocess(masked_img2).unsqueeze(0).to(device))

        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        similarity2 = (100.0 * image_features @ text_features.T).softmax(dim=-1)
        similarity2 = similarity2.squeeze()[0]

    if visualize:
        logging.info(f'CLIP prediction: {[similarity1, similarity2]}')

    return 0 if similarity1 > similarity2 else 1


def get_rel_path(rel, verbose=0):
    """get correct relation path, map rel to one of the 37 existing query masks"""

    rel = RELATION_DICT.get(rel.strip())  # get synonym of relation if no mask exists
    rel = '_'.join(rel.split(' ')) if ' ' in rel else rel
    path = 'relations/' + rel + '.npy'
    if verbose > 0:
        logging.info(f'Loading {path}')
    return path


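# Relation queries work on precomputed 100x100 spatial masks (one per relation).
# The mask is rescaled to the reference object's size, every active mask cell is
# encoded as an SSP point and summed into a region vector, the region is shifted
# to the object's position by convolving with an encoded offset, and the shifted
# region is then used to probe the memory: objects whose SSPs respond with
# positive similarity become proposals, sorted by similarity.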
def use_query_mask(obj_pos, info, rel, linspace, axes, dim, memory, verbose=0, visualize=False):
    """implements query mask usage: load spatial query mask for relation rel,
    scale query mask to object, encode mask to SSP region and
    move to object position in SSP memory, extract object proposals in region"""

    xs, ys, ws, hs = linspace
    x_axis, y_axis, w_axis, h_axis = axes

    x, y, width, height = obj_pos
    # 50 pixels was object size in query mask generation -- use pixel scale values for height and width
    iso_scale = np.mean([(width * info['scales'][1]) / 50, (height * info['scales'][2]) / 50])

    mask = np.load(get_rel_path(rel, verbose))
    mask = cv2.resize(mask, (100, 100), interpolation=cv2.INTER_AREA)

    # crop new query mask according to scale
    new_area = int(100 / iso_scale)
    if verbose > 0:
        print(iso_scale, new_area)

    resized = mask[max(0, 50 - new_area): min(50 + new_area, 100), max(0, 50 - new_area): min(50 + new_area, 100)]
    resized = cv2.resize(resized, (100, 100), interpolation=cv2.INTER_AREA)

    if visualize:
        fig, axs = plt.subplots(1, 2, sharey=True, layout="constrained", figsize=(6, 3))
        fig.set_tight_layout(True)
        fig.subplots_adjust(top=1.05)
        fig.suptitle('Relation: ' + rel)
        plt.setp(plt.gcf().get_axes(), xticks=[], yticks=[])

        axs[0].imshow(mask, cmap='gray')
        axs[0].title.set_text('original')
        axs[1].imshow(resized, cmap='gray')
        axs[1].title.set_text('resized')
        plt.show()

    # encode mask to SSP region
    counter = 0
    vector = spa.SemanticPointer(data=np.zeros(dim))
    for (i, j) in zip(*np.where(resized > 0.05)):
        x, y = xs[i], ys[j]
        vector += encode_point_multidim([y, x, 1, 1], axes=axes)
        counter += 1
    vector.normalize()
    if verbose > 0:
        logging.info(f'Resized mask encoded {counter} points')

    if visualize:
        plot_heatmap_multidim(vector, xs, ys, VECTORS, vmin=-0.2, vmax=0.2, name='Encoded Region')

    # get object info and move query mask to position
    x, y, width, height = obj_pos
    obj_center = (x + width / 2, y + height / 2)
    img_center = np.array([xs[50], ys[50]])

    shift = -img_center + (obj_center[1], obj_center[0])
    encoded_shift = encode_point(shift[1], shift[0], x_axis=x_axis, y_axis=y_axis)

    shifted_pos = vector.convolve(encoded_shift)
    shifted_pos.normalize()

    if visualize:
        plot_heatmap_multidim(shifted_pos, xs, ys, VECTORS, vmin=-0.1, vmax=0.1, name='Query Region')

    # query region and compare output to vocab = all saved SSPs
    vocab_vectors = np.zeros((len(info['encoded_ssps']), dim))

    color_lst = []
    for i, (name, ssp) in enumerate(info['encoded_ssps'].items()):
        vocab_vectors[i, :] = ssp.v
        color_lst.append(RGB_COLORS[i])

    similarity = np.tensordot((memory * ~shifted_pos).v, vocab_vectors, axes=([0], [1]))

    d = OrderedDict(zip(list(info['encoded_ssps'].keys()), similarity))
    res = list(OrderedDict(sorted(d.items(), key=lambda x: x[1], reverse=True)))
    proposals = np.array(res)[np.array(np.array(sorted(similarity, reverse=True)) > 0).astype(bool)]

    if verbose > 0:
        print('Proposals: ', *proposals)

    if visualize:
        fig = plt.Figure()
        plt.bar(np.arange(len(info['encoded_ssps'])), similarity, color=color_lst, label=info['encoded_items'].keys())
        plt.title('Similarity')
        plt.legend(loc='upper left', bbox_to_anchor=(1.0, 1.05))
        plt.show()

    return proposals


def select_func(data, img, obj, info, counter, memory, visualize=False):
    """implements select function: probe SSP memory with given object name"""

    result = None

    if obj + '_' + str(counter) in info['encoded_ssps'].keys():
        obj += '_' + str(counter)
        obj_ssp, obj_pos = select_ssp(obj, memory, info['encoded_ssps'], data.ssp_vectors, data.linspace)
        result = [obj, obj_ssp, obj_pos]

    else:
        # test synonyms and plural/singular
        test = [singularize(obj), pluralize(obj)]

        if SYNONYMS.get(obj):
            test += SYNONYMS.get(obj)

        if CLASS_DICT.get(obj):
            test += CLASS_DICT.get(obj)

        for obj in test:
            obj = str(obj) + '_' + str(counter)
            if obj in info['encoded_ssps'].keys():
                obj_ssp, obj_pos = select_ssp(obj, memory, info['encoded_ssps'], data.ssp_vectors, data.linspace)
                result = [obj, obj_ssp, obj_pos]

    if result is not None and visualize:
        fig, ax = plt.subplots(1, 1)
        plt.axis('off')
        ax.imshow(img)
        rect = patches.Rectangle((obj_pos[0] * info['scales'][0], obj_pos[1] * info['scales'][0]),
                                 obj_pos[2] * info['scales'][1], obj_pos[3] * info['scales'][2],
                                 linewidth=2,
                                 label=obj,  # label with the matched object name (was `name`, a leftover global)
                                 edgecolor='c',
                                 facecolor='none')
        ax.add_patch(rect)
        plt.show()

    return result


def verify_func(data, img, attr, results, info, dim, memory, verbose=0, visualize=False):
    """implements all verify functions dependent on length of attributes given"""

    if len(attr) == 2:
        # verify_color, verify_shape, verify_scene
        num = int(re.findall(r'\d+', attr[0])[0])

        if results[num] is not None:
            name, obj_ssp, obj_pos = results[num]
            x, y = obj_pos[0] * info['scales'][0], obj_pos[1] * info['scales'][0]
            w, h = obj_pos[2] * info['scales'][1], obj_pos[3] * info['scales'][2]

            clip_tokens = [f'The {name.split("_")[0]} is {attr[1].strip()}',
                           f'The {name.split("_")[0]} is not {attr[1].strip()}']

            pred = clip_query([x, y, w, h], img, name.split('_')[0], clip_tokens, visualize=visualize)

            return True if pred == 0 else False

        else:
            return False

    elif len(attr) == 3:
        # verify_rel, verify_rel_inv
        obj, rel, rel_obj = attr
        num = int(re.findall(r'\d+', obj)[0])

        proposals = []

        if results[num] is not None:
            name, obj_ssp, obj_pos = results[num]

            proposals = use_query_mask(obj_pos, info, rel, data.linspace, data.ssp_axes, dim, memory, visualize=visualize)
            proposals = [str(p).split('_')[0] for p in proposals]

            return True if rel_obj.strip() in proposals else False
        else:
            return False

    # verify_f
    elif len(attr) == 1:

        clip_tokens = [f'The image is {attr[0].strip()}',
                       f'The image is not {attr[0].strip()}',
                       f'The image is a {attr[0].strip()}',
                       f'The image is not a {attr[0].strip()}']

        pred = clip_query_scene(img, clip_tokens, verbose=0)

        return True if pred == 0 or pred == 2 else False

    else:
        logging.warning('verify_func not implemented')

    return -1


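# Positional answers use image coordinates: the origin is the top-left corner,
# so an object whose centre has y >= height / 2 is 'bottom' and one whose centre
# has x >= width / 2 is 'right'. query_func, filter_func and choose_func below
# all share this convention.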
def query_func(func, img, attr, results, info, img_size, dim, verbose=0, visualize=False):
    """implements all query functions"""

    if 'query_f(' in func:
        attr_type = attr[0].strip()

        assert attr_type in CLASS_DICT, f'{attr_type} not found in class dictionary'

        attributes = CLASS_DICT.get(attr_type)
        clip_tokens = [f'This is a {a} {attr_type}' for a in attributes]
        pred = clip_query_scene(img, clip_tokens, verbose=0)

        return attributes[pred]

    num = int(re.findall(r'\d+', attr[0])[0])

    if results[num] is not None:
        name, obj_ssp, obj_pos = results[num]
        x, y = obj_pos[0] * info['scales'][0], obj_pos[1] * info['scales'][0]
        w, h = obj_pos[2] * info['scales'][1], obj_pos[3] * info['scales'][2]

        if len(attr) == 1:

            if 'query_n' in func:
                # query name
                return name.split('_')[0]

            if 'query_h' in func:
                # query horizontal position --> x-value
                if (x + w / 2) >= (img_size[1] / 2):
                    return 'right'
                else:
                    return 'left'

            if 'query_v' in func:
                # query vertical position --> y-value
                if (y + h / 2) >= (img_size[0] / 2):
                    return 'bottom'
                else:
                    return 'top'

        elif len(attr) == 2:

            attr_type = attr[1].strip()

            assert attr_type in ATTRIBUTE_DICT, f'{attr_type} not found in attribute dictionary'

            attributes = ATTRIBUTE_DICT.get(attr_type)
            clip_tokens = [f'The {attr_type} of {name.split("_")[0]} is {a}' for a in attributes]
            pred = clip_query([x, y, w, h], img, name.split('_')[0], clip_tokens, visualize=visualize)

            return attributes[pred]

    else:
        return None

    logging.warning(f'query not implemented: {func} {attr}')
    return -1


def relate_func(data, func, attr, results, info, dim, memory, visualize=False):
    """implements all relationship functions"""

    if 'relate_inv_name' in func or 'relate_name' in func:
        obj, rel, rel_obj = attr
        num = int(re.findall(r'\d+', attr[0])[0])

        if results[num] is not None:
            name, obj_ssp, obj_pos = results[num]

            proposals = use_query_mask(obj_pos, info, rel, data.linspace, data.ssp_axes, dim,
                                       memory, verbose=0, visualize=visualize)

            selected_obj = proposals[0]

            # use rel_obj to filter proposals
            rel_obj = rel_obj.strip()
            if rel_obj == selected_obj.split('_')[0]:
                if visualize:
                    logging.info('Found perfect match\n')
                return selected_obj

            elif rel_obj in [str(p).split('_')[0] for p in proposals]:
                idx = [str(p).split('_')[0] for p in proposals].index(rel_obj)
                return proposals[idx]

            elif rel_obj in CLASS_DICT.keys():
                class_lst = CLASS_DICT.get(rel_obj)

                for p in [str(p).split('_')[0] for p in proposals]:
                    if p in class_lst or singularize(p) in class_lst:
                        if visualize:
                            logging.info(f'Found better proposal for {rel_obj}: {p}\n')
                        idx = [str(p).split('_')[0] for p in proposals].index(p)
                        return proposals[idx]
            else:
                if visualize:
                    logging.info(f'Did not find {rel_obj} in proposals\n')
                return None

    elif 'relate_inv' in func or 'relate(' in func:
        obj, rel = attr
        num = int(re.findall(r'\d+', attr[0])[0])

        if results[num] is not None:
            name, obj_ssp, obj_pos = results[num]

            proposals = use_query_mask(obj_pos, info, rel, data.linspace, data.ssp_axes, dim,
                                       memory, visualize=visualize)

            return proposals[0]

    else:
        logging.warning(f'{func} not implemented')
        return -1


def filter_func(func, img, attr, img_size, results, info, visualize=False):
    """implements all filter functions"""

    obj, filter_attr = attr
    num = int(re.findall(r'\d+', attr[0])[0])

    if results[num] is not None:
        name, obj_ssp, obj_pos = results[num]
        x, y = obj_pos[0] * info['scales'][0], obj_pos[1] * info['scales'][0]
        w, h = obj_pos[2] * info['scales'][1], obj_pos[3] * info['scales'][2]

        # query height --> y-value
        if 'bottom' in filter_attr or 'top' in filter_attr:
            if (y + h / 2) >= (img_size[0] / 2):
                pred_attr = 'bottom'
            else:
                pred_attr = 'top'

            return pred_attr == filter_attr.strip()

        # query side --> x-value
        if 'right' in filter_attr or 'left' in filter_attr:
            if (x + w / 2) >= (img_size[1] / 2):
                pred_attr = 'right'
            else:
                pred_attr = 'left'

            return pred_attr == filter_attr.strip()

        # filter by attribute: color, shape, activity, material
        else:

            clip_tokens = [f'The {name.split("_")[0]} is {filter_attr}',
                           f'The {name.split("_")[0]} is not {filter_attr}']

            pred = clip_query([x, y, w, h], img, name.split('_')[0], clip_tokens, visualize=visualize)

            if 'not' in func:
                return True if pred == 1 else False
            else:
                return True if pred == 0 else False

    else:
        if visualize:
            logging.info('No object was found in last step -- nothing to filter')
        return None

    return -1


def choose_func(data, img, func, attr, img_size, results, info, dim, memory, verbose=0, visualize=False):
    """implements all choose functions"""

    if 'choose_f' in func:
        pred = clip_query_scene(img, attr, verbose=0)
        return attr[pred]

    num = int(re.findall(r'\d+', attr[0])[0])

    if results[num] is not None:
        name, obj_ssp, obj_pos = results[num]

        if 'choose_h' in func or 'choose_v' in func:
            x, y = obj_pos[0] * info['scales'][0], obj_pos[1] * info['scales'][0]
            w, h = obj_pos[2] * info['scales'][1], obj_pos[3] * info['scales'][2]

            # choose side --> x-value
            if 'choose_h' in func:
                if (x + w / 2) >= (img_size[1] / 2):
                    return 'right'
                else:
                    return 'left'

            # choose vertical alignment --> y-value
            if 'choose_v' in func:
                if (y + h / 2) >= (img_size[0] / 2):
                    return 'bottom'
                else:
                    return 'top'

        elif 'choose_n' in func:
            obj, name1, name2 = attr

            if name == name1:
                return name1
            elif name == name2:
                return name2
            else:
                return None

        elif 'choose_attr' in func:
            obj, attr_type, attr1, attr2 = attr
            x, y = obj_pos[0] * info['scales'][0], obj_pos[1] * info['scales'][0]
            w, h = obj_pos[2] * info['scales'][1], obj_pos[3] * info['scales'][2]

            clip_tokens = [f'The {attr_type} of {name.split("_")[0]} is {attr1}',
                           f'The {attr_type} of {name.split("_")[0]} is {attr2}']

            pred = clip_query([x, y, w, h], img, name.split('_')[0], clip_tokens, visualize=visualize)

            attr_lst = [attr1.strip(), attr2.strip()]

            if visualize:
                logging.info(f'Choose attribute: {attr_type} of {name} --> clip prediction: {attr_lst[pred]}')

            return attr_lst[pred]

        elif 'choose_rel_inv' in func:
            obj, rel_obj, attr1, attr2 = attr

            proposals1 = use_query_mask(obj_pos, info, attr1, data.linspace, data.ssp_axes, dim,
                                        memory, 0, visualize)
            proposals1 = [str(p).split('_')[0] for p in proposals1]

            proposals2 = use_query_mask(obj_pos, info, attr2, data.linspace, data.ssp_axes, dim,
                                        memory, 0, visualize)
            proposals2 = [str(p).split('_')[0] for p in proposals2]

            if rel_obj.strip() in proposals1:
                return ANSWER_MAP.get(attr1.strip()) if attr1.strip() in ANSWER_MAP else attr1.strip()
            elif rel_obj.strip() in proposals2:
                return ANSWER_MAP.get(attr2.strip()) if attr2.strip() in ANSWER_MAP else attr2.strip()
            else:
                return None

        elif 'choose_subj' in func:
            subj1, subj2, attribute = attr
            num1 = int(re.findall(r'\d+', subj1)[0])
            num2 = int(re.findall(r'\d+', subj2)[0])

            if results[num1] is not None and results[num2] is not None:
                name1, obj_ssp1, obj_pos1 = results[num1]
                name2, obj_ssp2, obj_pos2 = results[num2]

                x1, y1 = obj_pos1[0] * info['scales'][0], obj_pos1[1] * info['scales'][0]
                w1, h1 = obj_pos1[2] * info['scales'][1], obj_pos1[3] * info['scales'][2]
                x2, y2 = obj_pos2[0] * info['scales'][0], obj_pos2[1] * info['scales'][0]
                w2, h2 = obj_pos2[2] * info['scales'][1], obj_pos2[3] * info['scales'][2]

                pred = clip_choose([x1, y1, w1, h1], [x2, y2, w2, h2], img, attribute, visualize=visualize)

                return name1.split('_')[0] if pred == 0 else name2.split('_')[0]

        elif 'choose(' in func:
            obj, attr1, attr2 = attr
            x, y = obj_pos[0] * info['scales'][0], obj_pos[1] * info['scales'][0]
            w, h = obj_pos[2] * info['scales'][1], obj_pos[3] * info['scales'][2]

            clip_tokens = [f'The {name.split("_")[0]} is {attr1}',
                           f'The {name.split("_")[0]} is {attr2}']

            pred = clip_query([x, y, w, h], img, name.split('_')[0], clip_tokens, visualize=visualize)

            attr_lst = [attr1.strip(), attr2.strip()]

            if visualize:
                logging.info(f'Choose {attr1} or {attr2} for {name} --> clip prediction: {attr_lst[pred]}')

            return attr_lst[pred]

        else:
            logging.warning(f'{func} not implemented yet')
            return -1

    else:
        if visualize:
            logging.info('No object was found in last step -- nothing to choose')
        return None

    return -1


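# Program steps are strings of the form "<id>=<function>(<arg0>, <arg1>, ...)",
# e.g. "2=relate_inv_name(1, to the left of, car)" (illustrative example; the
# exact arguments come from the GQA program annotations). Arguments that refer
# to earlier steps are plain indices into `results`, which is why the handlers
# above extract them with re.findall(r'\d+', ...).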
def run_program(data, img, info, counter, memory, dim, verbose=0):
    """ run program for question on given image:
    for each step in program select appropriate function
    """
    scale, w_scale, h_scale = info['scales']
    img_size = img.shape[:2]

    results = []
    last_step = False
    last_func = None
    last_filter = False  # initialised here so 'exist' steps never hit an unbound local

    for i, step in enumerate(info['program']):

        if i + 1 == len(info['program']):
            last_step = True

        _, func = step.split('=')
        attr = func.split('(')[-1].split(')')[0].split(',')

        if verbose > 0:
            logging.info(f'{i+1}. step: \t {func}')

        if 'select' in func:
            obj = attr[0].strip()
            res = select_func(data, img, obj, info, counter, memory, visualize=VISUALIZE)

            results.append(res)

            if res is None:
                if verbose > 1:
                    logging.info(f'Could not find {obj}')

        elif 'relate' in func:

            found_rel_obj = relate_func(data, func, attr, results, info, dim, memory, visualize=VISUALIZE)

            if found_rel_obj is not None:
                assert found_rel_obj in info['encoded_ssps'], f'Result of {func}: {found_rel_obj} is not encoded'

                selected_ssp = info['encoded_ssps'][found_rel_obj]
                _, selected_pos = select_ssp(found_rel_obj, memory, info['encoded_ssps'], data.ssp_vectors, data.linspace)
                results.append([found_rel_obj, selected_ssp, selected_pos])

                if last_step:
                    return 'yes'
            else:
                results.append(None)

                if last_step:
                    return 'no'

        elif 'filter' in func:

            last_filter = filter_func(func, img, attr, img_size, results, info, visualize=VISUALIZE)

            if last_filter:
                results.append(results[-1])
            else:
                if results[-1] is None:
                    results.append(None)
                elif results[-1][0].split("_")[0] + "_" + str(counter + 1) in info['encoded_ssps'].keys():
                    counter += 1
                    return None
                else:
                    last_filter = False
                    results.append(results[-1])

        elif 'verify' in func:
            pred = verify_func(data, img, attr, results, info, dim, memory, verbose=verbose, visualize=VISUALIZE)

            if 'verify_relation_name' in func or 'verify_relation_inv_name' in func:
                results.append(results[-1] if pred else None)
            else:
                results.append(pred)

            if last_step:
                return 'yes' if pred else 'no'

        elif 'query' in func:

            return query_func(func, img, attr, results, info, img_size, dim, verbose=verbose, visualize=VISUALIZE)

        elif 'exist' in func:

            num = int(re.findall(r'\d+', attr[0])[0])

            if last_step:
                return 'yes' if results[num] is not None else 'no'
            else:
                # guard: last_func is still None if no earlier step was executed
                if results[num] is not None and (last_func is None or 'filter' not in last_func):
                    results.append(True)
                elif results[num] is not None and last_filter:
                    results.append(True)
                else:
                    results.append(False)

        elif 'or(' in func:
            attr1 = int(re.findall(r'\d+', attr[0])[0])
            attr2 = int(re.findall(r'\d+', attr[1])[0])
            return 'yes' if results[attr1] or results[attr2] else 'no'

        elif 'and(' in func:
            attr1 = int(re.findall(r'\d+', attr[0])[0])
            attr2 = int(re.findall(r'\d+', attr[1])[0])
            return 'yes' if results[attr1] and results[attr2] else 'no'

        elif 'different' in func:

            if len(attr) == 1:
                logging.warning(f'{func} cannot be computed')
                return None

            else:
                pred_attr1 = query_func(f'query_{attr[2].strip()}', img, [attr[0], attr[2]], results, info, img_size, dim)
                pred_attr2 = query_func(f'query_{attr[2].strip()}', img, [attr[1], attr[2]], results, info, img_size, dim)

                if pred_attr1 != pred_attr2:
                    return 'yes'
                else:
                    return 'no'

        elif 'same' in func:

            if len(attr) == 1:
                logging.warning(f'{func} cannot be computed')
                return None

            pred_attr1 = query_func(f'query_{attr[2].strip()}', img, [attr[0], attr[2]], results, info, img_size, dim)
            pred_attr2 = query_func(f'query_{attr[2].strip()}', img, [attr[1], attr[2]], results, info, img_size, dim)

            if pred_attr1 == pred_attr2:
                return 'yes'
            else:
                return 'no'

        elif 'choose' in func:
            return choose_func(data, img, func, attr, img_size, results, info, dim, memory, visualize=VISUALIZE)

        else:
            logging.warning(f'{func} not implemented')
            return -1

        last_func = func


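# Experiment entry point: fixes all random seeds, configures file logging,
# builds the 100x100 (x, y) by 10x10 (w, h) SSP heatmap grid from four unitary
# axis vectors, loads the balanced GQA questions/programs/scene graphs, and
# evaluates run_program() on every item, tracking accuracy with tqdm.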
if __name__ == "__main__":

    TEST = True
    DIM = 2048
    RANDOM_SEED = 17
    np.random.seed(RANDOM_SEED)
    torch.manual_seed(RANDOM_SEED)
    random.seed(RANDOM_SEED)

    x = datetime.now()
    TIME_STAMP = x.strftime("%d%b%y-%H%M")
    log_file = f"logs/run{TIME_STAMP}-{'VAL' if TEST else 'TRAIN'}{RANDOM_SEED}.log"
    log_file = f"logs/run{TIME_STAMP}-DIM{DIM}-{RANDOM_SEED}.log"

    logging.basicConfig(level=logging.INFO, filename=log_file,
                        filemode="w", format="%(asctime)s %(levelname)s %(message)s")

    print('Logging to ', log_file)
    DATA_PATH = '/scratch/penzkofer/GQA'
    CUDA_DEVICE = 7
    torch.cuda.set_device(CUDA_DEVICE)

    device = torch.device("cuda:" + str(CUDA_DEVICE))
    clip_model, preprocess = clip.load("ViT-B/32", device=device)

    with open('gqa_all_relations_map.json') as f:
        RELATION_DICT = json.load(f)

    with open('gqa_vocab_classes.json') as f:
        CLASS_DICT = json.load(f)

    with open('gqa_all_attributes.json') as f:
        ATTRIBUTE_DICT = json.load(f)

    start = time.time()
    res = 100
    dim = DIM
    new_size = (25, 25)  # size should be smaller than resolution!

    xs = np.linspace(0, new_size[1], res)
    ys = np.linspace(0, new_size[0], res)
    ws = np.linspace(1, 10, 10)
    hs = np.linspace(1, 10, 10)

    rng = np.random.RandomState(seed=RANDOM_SEED)
    x_axis = make_good_unitary(dim, rng=rng)
    y_axis = make_good_unitary(dim, rng=rng)
    w_axis = make_good_unitary(dim, rng=rng)
    h_axis = make_good_unitary(dim, rng=rng)

    logging.info(f'Size of vector space: {res**2}x{10**2}x{dim}')
    logging.info(f'x-axis resolution = {len(xs)}, y-axis resolution = {len(ys)}')
    logging.info(f'width resolution = {len(ws)}, height resolution = {len(hs)}')

    # precompute the vectors
    VECTORS = get_heatmap_vectors_multidim(xs, ys, ws, hs, x_axis, y_axis, w_axis, h_axis)
    logging.info(VECTORS.shape)
    logging.info(f'Took {time.time() - start} seconds to load vectors.\n')

    # load questions, programs and scenegraphs
    if TEST:
        questions_path = 'val_balanced_questions.json'
        programs_path = 'programs/trainval_balanced_programs.json'
        scene_path = 'val_sceneGraphs.json'
    else:
        questions_path = 'train_balanced_questions.json'
        programs_path = 'programs/trainval_balanced_programs.json'
        scene_path = 'train_sceneGraphs.json'

    with open(os.path.join(DATA_PATH, questions_path), 'r') as f:
        questions = json.load(f)

    with open(os.path.join(DATA_PATH, programs_path), 'r') as f:
        programs = json.load(f)

    with open(os.path.join(DATA_PATH, scene_path), 'r') as f:
        scenegraphs = json.load(f)

    columns = ['semantic', 'entailed', 'equivalent', 'question', 'imageId', 'isBalanced', 'groups',
               'answer', 'semanticStr', 'annotations', 'types', 'fullAnswer']

    questions = pd.DataFrame.from_dict(questions, orient='index', columns=columns)
    questions = questions.reset_index()
    questions = questions.rename(columns={"index": "questionID"}, errors="raise")

    columns = ['imageID', 'question', 'program', 'questionID', 'answer']
    programs = pd.DataFrame(programs, columns=columns)

    DATA = GQADataset(questions, programs, scenegraphs, vectors=VECTORS,
                      axes=[x_axis, y_axis, w_axis, h_axis], linspace=[xs, ys, ws, hs])

    logging.info(f'Length of data set: {len(DATA)}')

    VISUALIZE = False
    DATA.set_visualize(VISUALIZE)
    DATA.set_verbose(0)

    results_lst = []
    num_correct = 0

    pbar = tqdm(range(len(DATA)), ncols=115)

    for i, IDX in enumerate(pbar):
        start = time.time()
        img, info, memory = DATA.encode_item(IDX, dim=dim)
        avg_mse, avg_iou, correct_items = DATA.decode_item(img, info, memory)

        try:
            answer = run_program(DATA, img, info, counter=1, memory=memory, dim=dim, verbose=1)
        except Exception as e:
            answer = None
            logging.error(e)

        if answer == -1:
            logging.warning(f'[{IDX}] not fully implemented!')

        time_in_sec = time.time() - start
        correct = answer == info["answer"]
        num_correct += int(correct)

        results = {'q_id': info['q_id'], 'question': info['question'], 'program': info['program'], 'image': info['img_id'],
                   'true_answer': info['answer'], 'pred_answer': answer, 'correct': correct, 'time': time_in_sec,
                   'enc_avg_mse': avg_mse, 'enc_avg_iou': avg_iou, 'enc_correct_items': correct_items,
                   'enc_items': len(info["encoded_items"]), 'q_idx': IDX}

        results_lst.append(results)
        logging.info(f'[{IDX+1}] {num_correct / (i+1):.2%}')
        pbar.set_postfix({'correct': f'{num_correct / (i+1):.2%}', 'q_idx': str(IDX+1)})

    results_df = pd.DataFrame(results_lst)
    logging.info(f'Accuracy: {results_df.correct.sum() / len(results_df):.2%}')

    out_path = os.path.join(DATA_PATH, f'results-DIM{DIM}-{"VAL" if TEST else "TRAIN"}{RANDOM_SEED}.pkl')
    results_df.to_pickle(out_path)
    logging.info(f'Saved results to {out_path}.')