# VSA4VQA/run_programs.py
# 2024-04-29 17:18:10 +02:00
#
# 968 lines
# 35 KiB
# Python

import os
import sys
os.environ["OPENBLAS_NUM_THREADS"] = "10"
import json
import cv2
import re
import random
import time
import torch
import clip
import logging
from datetime import datetime
from PIL import Image
from tqdm import tqdm
import pandas as pd
import numpy as np
import nengo.spa as spa
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.colors as mcolors
from collections import OrderedDict
from pattern.text.en import singularize, pluralize
from dataset import GQADataset
from utils import *
# Directory that the __main__ block joins with the question, program and
# scene-graph file names.
DATA_PATH = '/scratch/penzkofer/GQA'
# RGB triples of all named matplotlib colors; indexed by position when coloring
# the similarity bar chart in use_query_mask.
RGB_COLORS = []
# NOTE(review): `hex` shadows the builtin, and both loop variables (`name`, `hex`)
# stay bound at module scope after the loop finishes.
for name, hex in mcolors.cnames.items():
    RGB_COLORS.append(mcolors.to_rgb(hex))
# The GPU index is hard-coded; the CLIP model is loaded onto it at import time.
CUDA_DEVICE = 7
torch.cuda.set_device(CUDA_DEVICE)
device = torch.device("cuda:" + str(CUDA_DEVICE))
clip_model, preprocess = clip.load("ViT-B/32", device=device)
# Lookup tables; the paths are relative, so they resolve against the current
# working directory.
# RELATION_DICT: looked up in get_rel_path to map a relation onto the relation
# whose query mask file is loaded.
with open('gqa_all_relations_map.json') as f:
    RELATION_DICT = json.load(f)
# CLASS_DICT: consulted in select_func / relate_func for alternative object names
# and in query_func for candidate answers.
# NOTE(review): the __main__ block reloads CLASS_DICT from 'gqa_vocab_classes.json'.
with open('gqa_all_vocab_classes.json') as f:
    CLASS_DICT = json.load(f)
# ATTRIBUTE_DICT: attribute type -> candidate attribute values (see query_func).
with open('gqa_all_attributes.json') as f:
    ATTRIBUTE_DICT = json.load(f)
# Extra names tried by select_func when an object name is not encoded.
SYNONYMS = {'he': ['man', 'boy'], 'she': ['woman', 'girl']}
# Maps relation phrases onto the answer strings returned by choose_func.
ANSWER_MAP = {'to the right of': 'right', 'to the left of': 'left'}
# Passed as the `visualize` flag to every step executed by run_program.
VISUALIZE = False
def plot_heatmap_multidim(sp, xs, ys, heatmap_vectors, name='', vmin=-1, vmax=1, cmap='plasma', invert=False):
    """Plot the similarity of a semantic pointer to precomputed heatmap vectors.

    adapted from https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master

    ``sp.v`` is contracted with axis 4 of ``heatmap_vectors``. Of the resulting
    4-D similarity array, the 2-D slice over the first two axes that contains the
    global maximum is shown (transposed) with ``plt.imshow``.

    Args:
        sp: object whose class is named ``SemanticPointer``.
        xs, ys: coordinate arrays; only their first and last entries are used,
            to set the image extent.
        heatmap_vectors: array whose axis 4 has the same length as ``sp.v``.
        name: plot title.
        vmin, vmax, cmap: forwarded to ``plt.imshow``.
        invert: unused.

    Raises:
        AssertionError: if ``sp`` is not a ``SemanticPointer``.
    """
    sp_type = sp.__class__.__name__
    assert sp_type == 'SemanticPointer', \
        f'Queried object needs to be of type SemanticPointer but is {sp_type}'

    # contract the vector dimension: axis 0 of sp.v against axis 4 of the heatmap vectors
    similarities = np.tensordot(sp.v, heatmap_vectors, axes=([0], [4]))
    peak = np.unravel_index(np.argmax(similarities, axis=None), similarities.shape)
    # fix the last two axes at the location of the global maximum
    plane = similarities[:, :, peak[2], peak[3]]

    plt.imshow(
        np.transpose(plane),
        origin='upper',
        interpolation='none',
        extent=(xs[-1], xs[0], ys[-1], ys[0]),
        vmin=vmin,
        vmax=vmax,
        cmap=cmap,
    )
    plt.colorbar()
    plt.axis('off')
    plt.title(name)
    plt.show()
def select_ssp(name, memory, encoded_ssps, vectors, linspace):
    """Decode the location of the object called ``name`` from the SSP memory.

    Args:
        name: key into ``encoded_ssps``.
        memory: memory vector that is unbound with the object's SSP.
        encoded_ssps: mapping from object key to its SSP.
        vectors, linspace: forwarded to ``ssp_to_loc_multidim``.

    Returns:
        tuple: ``(decoded, location)`` -- the unbound vector and the location
        that ``ssp_to_loc_multidim`` derives from it.
    """
    # unbind the object's SSP from the memory
    decoded = memory * ~encoded_ssps[name]
    return decoded, ssp_to_loc_multidim(decoded, vectors, linspace)
def clip_query(bbox, img, obj_name, clip_tokens, visualize=False):
    """Return the index of the caption CLIP rates most similar to a marked object.

    An ellipse outline of color (255, 0, 0) is drawn on a copy of the image,
    centered on the bounding box with semi-axes of 0.7 * width and 0.7 * height.
    The whole annotated image is then compared against every caption.

    Args:
        bbox: ``(x, y, w, h)`` of the object in image pixels.
        img: image array; it is copied, never modified.
        obj_name: unused, kept so existing callers keep working.
        clip_tokens: list of candidate captions.
        visualize: if true, show the annotated image and print the scores and
            the winning caption.

    Returns:
        int: index into ``clip_tokens`` of the best-matching caption.
    """
    x, y, w, h = bbox
    center = (int(x + w / 2), int(y + h / 2))
    semi_axes = (int(w * 0.7), int(h * 0.7))
    marked = cv2.ellipse(img.copy(), center, semi_axes, 0, 0, 360, (255, 0, 0), 4)
    if visualize:
        plt.imshow(marked)
        plt.axis('off')
        plt.show()

    marked = Image.fromarray(np.uint8(marked))
    tokens = clip.tokenize(clip_tokens)
    with torch.no_grad():
        text_features = clip_model.encode_text(tokens.to(device))
        image_features = clip_model.encode_image(preprocess(marked).unsqueeze(0).to(device))
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)

    best = torch.max(similarity, 1)[1].squeeze().item()
    if visualize:
        # Scores are only needed for this debug output. Building them here (and
        # via flatten) also keeps a single-caption query from failing on
        # iteration over a 0-d tensor.
        print('CLIP')
        print(similarity.flatten().tolist())
        print(clip_tokens[best])
    return best
def clip_query_scene(img, clip_tokens, verbose=0):
    """Return the index of the caption CLIP rates most similar to the whole image.

    Unlike a per-object query, no bounding box is marked; the unmodified image is
    compared against every caption.

    Args:
        img: image array.
        clip_tokens: list of candidate captions.
        verbose: if greater than 0, print the scores and the winning caption.

    Returns:
        int: index into ``clip_tokens`` of the best-matching caption.
    """
    pil_img = Image.fromarray(np.uint8(img))
    tokens = clip.tokenize(clip_tokens)
    with torch.no_grad():
        text_features = clip_model.encode_text(tokens.to(device))
        image_features = clip_model.encode_image(preprocess(pil_img).unsqueeze(0).to(device))
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)

    best = torch.max(similarity, 1)[1].squeeze().item()
    if verbose > 0:
        # Scores are only needed for this debug output. Building them here (and
        # via flatten) also keeps a single-caption query from failing on
        # iteration over a 0-d tensor.
        print('CLIP')
        print(similarity.flatten().tolist())
        print(clip_tokens[best])
    return best
def clip_choose(bbox1, bbox2, img, attribute, visualize=False):
    """Decide which of two objects CLIP considers more likely to have ``attribute``.

    Each object is marked with an ellipse outline on its own copy of the image.
    Both annotated images are scored against the captions ``attribute`` and
    ``not {attribute}``; the object with the higher probability for the positive
    caption wins.

    Args:
        bbox1, bbox2: ``(x, y, w, h)`` of the two objects in image pixels.
        img: image array; it is copied, never modified.
        attribute: attribute text used to build the two captions.
        visualize: if true, show both annotated images and log the two
            confidences.

    Returns:
        int: 0 if the first object has the strictly higher confidence, else 1.
    """
    def outline(bbox):
        # draw the ellipse marker for one bounding box on a fresh copy of the image
        x, y, w, h = bbox
        center = (int(x + w / 2), int(y + h / 2))
        semi_axes = (int(w * 0.7), int(h * 0.7))
        return cv2.ellipse(img.copy(), center, semi_axes, 0, 0, 360, (255, 0, 0), 4)

    marked_images = [outline(bbox1), outline(bbox2)]
    if visualize:
        for marked in marked_images:
            plt.imshow(marked)
            plt.axis('off')
            plt.show()

    tokens = clip.tokenize([attribute, f'not {attribute}'])
    confidences = []
    with torch.no_grad():
        # the captions are the same for both objects, so encode them only once
        text_features = clip_model.encode_text(tokens.to(device))
        text_features /= text_features.norm(dim=-1, keepdim=True)
        for marked in marked_images:
            pil_img = Image.fromarray(np.uint8(marked))
            image_features = clip_model.encode_image(preprocess(pil_img).unsqueeze(0).to(device))
            image_features /= image_features.norm(dim=-1, keepdim=True)
            similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
            # index 0 is the probability of the positive caption
            confidences.append(similarity.squeeze()[0])

    similarity1, similarity2 = confidences
    if visualize:
        logging.info(f'CLIP prediction: {[similarity1, similarity2]}')
    return 0 if similarity1 > similarity2 else 1
def get_rel_path(rel, verbose=0):
    """Return the path of the query-mask file for relation ``rel``.

    The stripped relation is looked up in ``RELATION_DICT``, which maps a relation
    onto the relation whose mask should be used. A relation without an entry is
    used as is, so a missing mask surfaces as a missing file when the path is
    loaded rather than as a ``TypeError`` here.

    Args:
        rel: relation text; surrounding whitespace is ignored.
        verbose: if greater than 0, log the resulting path.

    Returns:
        str: ``'relations/<relation with spaces replaced by underscores>.npy'``.
    """
    key = rel.strip()
    mapped = RELATION_DICT.get(key) or key
    path = 'relations/' + mapped.replace(' ', '_') + '.npy'
    if verbose > 0:
        # pass the path as a format argument; a bare extra argument breaks formatting
        logging.info('Loading %s', path)
    return path
def use_query_mask(obj_pos, info, rel, linspace, axes, dim, memory, verbose=0, visualize=False):
    """Find the encoded objects that lie in relation ``rel`` to a reference object.

    Steps: load the spatial query mask for ``rel``, crop it according to the
    object's size, encode the remaining mask region as an SSP, shift that region
    to the object's position, and compare the region unbound from ``memory`` with
    all encoded object SSPs.

    Args:
        obj_pos: ``(x, y, width, height)`` of the reference object.
        info: item dictionary; reads ``info['scales']`` (indices 1 and 2 scale
            width and height), ``info['encoded_ssps']`` and, when visualizing,
            ``info['encoded_items']``.
        rel: relation text, resolved to a mask file by ``get_rel_path``.
        linspace: ``(xs, ys, ws, hs)`` coordinate arrays.
        axes: ``(x_axis, y_axis, w_axis, h_axis)`` SSP axis vectors.
        dim: dimensionality of the SSP vectors.
        memory: SSP memory to query.
        verbose: if greater than 0, print/log intermediate values.
        visualize: if true, plot masks, encoded regions and similarities. This
            reads the module-level ``VECTORS``, which is only created by the
            ``__main__`` block.

    Returns:
        numpy.ndarray: keys of ``info['encoded_ssps']`` whose similarity is
        greater than 0, ordered by decreasing similarity.
    """
    xs, ys, ws, hs = linspace
    x_axis, y_axis, w_axis, h_axis = axes
    x, y, width, height = obj_pos

    # 50 pixels was object size in query mask generation -- use pixel scale values for height and width
    iso_scale = np.mean([(width * info['scales'][1]) / 50, (height * info['scales'][2]) / 50])
    mask = np.load(get_rel_path(rel, verbose))
    mask = cv2.resize(mask, (100, 100), interpolation=cv2.INTER_AREA)

    # Crop the mask around its center according to scale. The half-width is
    # clamped to at least one pixel so a very large object cannot produce an
    # empty crop (which cv2.resize rejects).
    new_area = max(1, int(100 / iso_scale))
    if verbose > 0:
        print(iso_scale, new_area)
    lower, upper = max(0, 50 - new_area), min(50 + new_area, 100)
    resized = cv2.resize(mask[lower:upper, lower:upper], (100, 100), interpolation=cv2.INTER_AREA)

    if visualize:
        fig, axs = plt.subplots(1, 2, sharey=True, layout="constrained", figsize=(6, 3))
        fig.set_tight_layout(True)
        fig.subplots_adjust(top=1.05)
        fig.suptitle('Relation: ' + rel)
        plt.setp(plt.gcf().get_axes(), xticks=[], yticks=[])
        axs[0].imshow(mask, cmap='gray')
        axs[0].title.set_text('original')
        axs[1].imshow(resized, cmap='gray')
        axs[1].title.set_text('resized')
        plt.show()

    # encode mask to SSP region: sum the encodings of all mask cells above the threshold
    counter = 0
    vector = spa.SemanticPointer(data=np.zeros(dim))
    for (i, j) in zip(*np.where(resized > 0.05)):
        vector += encode_point_multidim([ys[j], xs[i], 1, 1], axes=axes)
        counter += 1
    vector.normalize()
    if verbose > 0:
        logging.info(f'Resized mask encoded {counter} points')
    if visualize:
        plot_heatmap_multidim(vector, xs, ys, VECTORS, vmin=-0.2, vmax=0.2, name=f'Encoded Region')

    # move the query region from the grid center to the object center
    obj_center = (x + width / 2, y + height / 2)
    img_center = np.array([xs[50], ys[50]])
    shift = -img_center + (obj_center[1], obj_center[0])
    encoded_shift = encode_point(shift[1], shift[0], x_axis=x_axis, y_axis=y_axis)
    shifted_pos = vector.convolve(encoded_shift)
    shifted_pos.normalize()
    if visualize:
        plot_heatmap_multidim(shifted_pos, xs, ys, VECTORS, vmin=-0.1, vmax=0.1, name=f'Query Region')

    # query region and compare output to vocab = all saved SSPs
    names = list(info['encoded_ssps'].keys())
    vocab_vectors = np.zeros((len(names), dim))
    for row, ssp in enumerate(info['encoded_ssps'].values()):
        vocab_vectors[row, :] = ssp.v
    similarity = np.tensordot((memory * ~shifted_pos).v, vocab_vectors, axes=([0], [1]))

    # rank the object keys by similarity (stable, descending) and keep the positive ones
    ranked = sorted(zip(names, similarity), key=lambda pair: pair[1], reverse=True)
    ranked_names = np.array([key for key, _ in ranked])
    ranked_scores = np.array([score for _, score in ranked])
    proposals = ranked_names[ranked_scores > 0]
    if verbose > 0:
        print('Proposals: ', *proposals)

    if visualize:
        # Colors are only needed for the plot. Cycle through the palette so that
        # scenes with more objects than palette entries do not raise IndexError.
        color_lst = [RGB_COLORS[k % len(RGB_COLORS)] for k in range(len(names))]
        plt.bar(np.arange(len(info['encoded_ssps'])), similarity, color=color_lst, label=info['encoded_items'].keys())
        plt.title('Similarity')
        plt.legend(loc='upper left', bbox_to_anchor=(1.0, 1.05))
        plt.show()
    return proposals
def select_func(data, img, obj, info, counter, memory, visualize=False):
    """Implement the select step: probe the SSP memory for an object by name.

    The key ``'<obj>_<counter>'`` is looked up in ``info['encoded_ssps']``. If it
    is missing, the singular and plural forms, the entries of ``SYNONYMS`` and the
    entries of ``CLASS_DICT`` for ``obj`` are tried with the same counter.

    Args:
        data: dataset object providing ``ssp_vectors`` and ``linspace``.
        img: image array, only used for visualization.
        obj: object name without counter suffix.
        info: item dictionary; reads ``info['encoded_ssps']`` and, when
            visualizing, ``info['scales']``.
        counter: instance number appended to the object name.
        memory: SSP memory to query.
        visualize: if true, draw the decoded bounding box on the image.

    Returns:
        list | None: ``[key, decoded_ssp, decoded_position]`` or ``None`` if no
        candidate key is encoded.
    """
    encoded = info['encoded_ssps']
    matched_key = None
    exact_key = obj + '_' + str(counter)
    if exact_key in encoded:
        matched_key = exact_key
    else:
        # test synonyms and plural/singular
        candidates = [singularize(obj), pluralize(obj)]
        candidates += SYNONYMS.get(obj) or []
        candidates += CLASS_DICT.get(obj) or []
        for candidate in candidates:
            key = str(candidate) + '_' + str(counter)
            if key in encoded:
                # NOTE(review): the loop deliberately keeps going, so the LAST
                # matching candidate wins, as in the previous implementation.
                matched_key = key

    if matched_key is None:
        return None

    obj_ssp, obj_pos = select_ssp(matched_key, memory, encoded, data.ssp_vectors, data.linspace)
    result = [matched_key, obj_ssp, obj_pos]

    if visualize:
        fig, ax = plt.subplots(1, 1)
        plt.axis('off')
        ax.imshow(img)
        # label the box with the key that was actually found (the previous code
        # referenced an undefined local `name` here)
        rect = patches.Rectangle((obj_pos[0] * info['scales'][0], obj_pos[1] * info['scales'][0]),
                                 obj_pos[2] * info['scales'][1], obj_pos[3] * info['scales'][2],
                                 linewidth=2,
                                 label=matched_key,
                                 edgecolor='c',
                                 facecolor='none')
        ax.add_patch(rect)
        plt.show()
    return result
def verify_func(data, img, attr, results, info, dim, memory, verbose=0, visualize=False):
    """Implement all verify steps; the variant is chosen by the number of arguments.

    * two arguments ``(obj_ref, attribute)``: ask CLIP whether the referenced
      object has the attribute.
    * three arguments ``(obj_ref, relation, rel_obj)``: check whether ``rel_obj``
      is among the proposals of the relation query mask.
    * one argument ``(scene_attribute,)``: ask CLIP about the whole image.

    Args:
        data: dataset object providing ``linspace`` and ``ssp_axes``.
        img: image array.
        attr: list of argument strings of the program step.
        results: results of the previous steps; the object reference contains the
            index of the step to use.
        info: item dictionary (``info['scales']``, ``info['encoded_ssps']``).
        dim: dimensionality of the SSP vectors.
        memory: SSP memory.
        verbose: accepted for interface compatibility; not used here.
        visualize: forwarded to the CLIP query / query mask.

    Returns:
        bool | int: the verification result, ``False`` if the referenced step
        found no object, or ``-1`` for an unsupported number of arguments.
    """
    if len(attr) == 2:
        # verify_color, verify_shape, verify_scene
        num = int(re.findall(r'\d+', attr[0])[0])
        if results[num] is None:
            return False
        name, obj_ssp, obj_pos = results[num]
        obj_name = name.split('_')[0]
        x, y = obj_pos[0] * info['scales'][0], obj_pos[1] * info['scales'][0]
        w, h = obj_pos[2] * info['scales'][1], obj_pos[3] * info['scales'][2]
        attribute = attr[1].strip()
        clip_tokens = [f'The {obj_name} is {attribute}',
                       f'The {obj_name} is not {attribute}']
        pred = clip_query([x, y, w, h], img, obj_name, clip_tokens, visualize=visualize)
        return pred == 0

    if len(attr) == 3:
        # verify_rel, verify_rel_inv
        obj, rel, rel_obj = attr
        num = int(re.findall(r'\d+', obj)[0])
        if results[num] is None:
            return False
        name, obj_ssp, obj_pos = results[num]
        proposals = use_query_mask(obj_pos, info, rel, data.linspace, data.ssp_axes, dim, memory,
                                   visualize=visualize)
        proposal_names = [str(p).split('_')[0] for p in proposals]
        return rel_obj.strip() in proposal_names

    if len(attr) == 1:
        # verify_f
        scene = attr[0].strip()
        # the last caption previously lacked the space after "a" ("not abeach")
        clip_tokens = [f'The image is {scene}',
                       f'The image is not {scene}',
                       f'The image is a {scene}',
                       f'The image is not a {scene}']
        pred = clip_query_scene(img, clip_tokens, verbose=0)
        # captions 0 and 2 are the two positive phrasings
        return pred in (0, 2)

    logging.warning('verify_func not implemented')
    return -1
def query_func(func, img, attr, results, info, img_size, dim, verbose=0, visualize=False):
    """Implement all query steps.

    * ``query_f(``: choose among the ``CLASS_DICT`` entries of a scene category
      with a CLIP query on the whole image.
    * ``query_n`` / ``query_h`` / ``query_v`` (one argument): return the object
      name, its horizontal side (``'left'``/``'right'``) or its vertical side
      (``'top'``/``'bottom'``), decided by the bounding-box center relative to the
      image center.
    * two arguments ``(obj_ref, attribute_type)``: choose among the
      ``ATTRIBUTE_DICT`` values of the attribute type with a CLIP query on the
      object.

    Args:
        func: text of the program step, used to select the variant.
        img: image array.
        attr: list of argument strings of the program step.
        results: results of the previous steps.
        info: item dictionary; reads ``info['scales']``.
        img_size: ``(height, width)`` of the image.
        dim, verbose: accepted for interface compatibility; not used here.
        visualize: forwarded to the CLIP query.

    Returns:
        The answer string, ``None`` if the referenced step found no object, or
        ``-1`` if the query variant is not implemented.

    Raises:
        AssertionError: if the category / attribute type is unknown.
    """
    if 'query_f(' in func:
        attr_type = attr[0].strip()
        assert attr_type in CLASS_DICT, f'{attr_type} not found in class dictionary'
        attributes = CLASS_DICT.get(attr_type)
        clip_tokens = [f'This is a {a} {attr_type}' for a in attributes]
        pred = clip_query_scene(img, clip_tokens, verbose=0)
        return attributes[pred]

    num = int(re.findall(r'\d+', attr[0])[0])
    if results[num] is None:
        return None

    name, obj_ssp, obj_pos = results[num]
    obj_name = name.split('_')[0]
    x, y = obj_pos[0] * info['scales'][0], obj_pos[1] * info['scales'][0]
    w, h = obj_pos[2] * info['scales'][1], obj_pos[3] * info['scales'][2]

    if len(attr) == 1:
        if 'query_n' in func:
            # query name
            return obj_name
        if 'query_h' in func:
            # query horizontal position --> x-value
            return 'right' if (x + w / 2) >= (img_size[1] / 2) else 'left'
        if 'query_v' in func:
            # query vertical position --> y-value
            return 'bottom' if (y + h / 2) >= (img_size[0] / 2) else 'top'
    elif len(attr) == 2:
        attr_type = attr[1].strip()
        assert attr_type in ATTRIBUTE_DICT, f'{attr_type} not found in attribute dictionary'
        attributes = ATTRIBUTE_DICT.get(attr_type)
        clip_tokens = [f'The {attr_type} of {obj_name} is {a}' for a in attributes]
        pred = clip_query([x, y, w, h], img, obj_name, clip_tokens, visualize=visualize)
        return attributes[pred]

    # extra positional arguments need matching placeholders, otherwise the
    # logging call itself fails while formatting the record
    logging.warning('query not implemented: %s %s', func, attr)
    return -1
def relate_func(data, func, attr, results, info, dim, memory, visualize=False):
    """Implement all relationship steps.

    * ``relate_name`` / ``relate_inv_name`` with ``(obj_ref, relation, rel_obj)``:
      return the proposal of the relation query whose name is ``rel_obj`` or,
      failing that, the first proposal that (directly or in singular form) is
      listed under ``rel_obj`` in ``CLASS_DICT``.
    * ``relate(`` / ``relate_inv`` with ``(obj_ref, relation)``: return the
      best-ranked proposal of the relation query.

    Args:
        data: dataset object providing ``linspace`` and ``ssp_axes``.
        func: text of the program step, used to select the variant.
        attr: list of argument strings of the program step.
        results: results of the previous steps.
        info: item dictionary, forwarded to ``use_query_mask``.
        dim: dimensionality of the SSP vectors.
        memory: SSP memory.
        visualize: if true, log the decision and let the query mask plot.

    Returns:
        The key of the related object, ``None`` if the referenced step found no
        object or nothing suitable was proposed, or ``-1`` if ``func`` is not a
        supported relate variant.
    """
    if 'relate_inv_name' in func or 'relate_name' in func:
        obj, rel, rel_obj = attr
        num = int(re.findall(r'\d+', obj)[0])
        if results[num] is None:
            return None
        name, obj_ssp, obj_pos = results[num]
        proposals = use_query_mask(obj_pos, info, rel, data.linspace, data.ssp_axes, dim,
                                   memory, verbose=0, visualize=visualize)
        rel_obj = rel_obj.strip()
        proposal_names = [str(p).split('_')[0] for p in proposals]

        # use rel_obj to filter proposals; the first occurrence is the best-ranked one
        if rel_obj in proposal_names:
            idx = proposal_names.index(rel_obj)
            if idx == 0 and visualize:
                logging.info('Found perfect match\n')
            return proposals[idx]

        if rel_obj in CLASS_DICT:
            class_lst = CLASS_DICT.get(rel_obj)
            for idx, p in enumerate(proposal_names):
                if p in class_lst or singularize(p) in class_lst:
                    if visualize:
                        logging.info(f'Found better proposal for {rel_obj}: {p}\n')
                    return proposals[idx]
            return None

        # also reached when the query produced no proposals at all
        if visualize:
            logging.info(f'Did not find {rel_obj} in proposals\n')
        return None

    if 'relate_inv' in func or 'relate(' in func:
        obj, rel = attr
        num = int(re.findall(r'\d+', obj)[0])
        if results[num] is None:
            return None
        name, obj_ssp, obj_pos = results[num]
        proposals = use_query_mask(obj_pos, info, rel, data.linspace, data.ssp_axes, dim,
                                   memory, visualize=visualize)
        # an empty proposal list means "no related object", not an IndexError
        return proposals[0] if len(proposals) > 0 else None

    logging.warning(f'{func} not implemented')
    return -1
def filter_func(func, img, attr, img_size, results, info, visualize=False):
    """Implement all filter steps.

    Positional attributes (``top``/``bottom``/``left``/``right``) are decided from
    the bounding-box center relative to the image center. Every other attribute
    (color, shape, activity, material, ...) is decided by a CLIP query on the
    object; if ``'not'`` occurs in ``func`` the negated caption has to win.

    The positional check requires the stripped attribute to equal one of the four
    keywords. A substring test would also capture attributes such as ``bright`` or
    ``laptop`` and then always reject them.

    Args:
        func: text of the program step.
        img: image array.
        attr: ``[obj_ref, filter_attribute]``.
        img_size: ``(height, width)`` of the image.
        results: results of the previous steps.
        info: item dictionary; reads ``info['scales']``.
        visualize: forwarded to the CLIP query; also enables logging.

    Returns:
        bool | None: whether the object passes the filter, or ``None`` if the
        referenced step found no object.
    """
    obj, filter_attr = attr
    num = int(re.findall(r'\d+', obj)[0])
    if results[num] is None:
        if visualize:
            logging.info('No object was found in last step -- nothing to filter')
        return None

    name, obj_ssp, obj_pos = results[num]
    x, y = obj_pos[0] * info['scales'][0], obj_pos[1] * info['scales'][0]
    w, h = obj_pos[2] * info['scales'][1], obj_pos[3] * info['scales'][2]
    wanted = filter_attr.strip()

    # query height --> y-value
    if wanted in ('bottom', 'top'):
        pred_attr = 'bottom' if (y + h / 2) >= (img_size[0] / 2) else 'top'
        return pred_attr == wanted
    # query side --> x-value
    if wanted in ('right', 'left'):
        pred_attr = 'right' if (x + w / 2) >= (img_size[1] / 2) else 'left'
        return pred_attr == wanted

    # filter by attribute: color, shape, activity, material
    obj_name = name.split('_')[0]
    clip_tokens = [f'The {obj_name} is {wanted}',
                   f'The {obj_name} is not {wanted}']
    pred = clip_query([x, y, w, h], img, obj_name, clip_tokens, visualize=visualize)
    # index 1 is the negated caption
    return pred == (1 if 'not' in func else 0)
def choose_func(data, img, func, attr, img_size, results, info, dim, memory, verbose=0, visualize=False):
    """Implement all choose steps.

    Variants, selected by substrings of ``func`` in this order:

    * ``choose_f``: CLIP picks one of the given options for the whole image.
    * ``choose_h`` / ``choose_v``: side of the image the object center lies on.
    * ``choose_n``: which of two names equals the name of the selected object.
    * ``choose_attr``: CLIP picks one of two values of an attribute type.
    * ``choose_rel_inv``: which of two relations proposes ``rel_obj``.
    * ``choose_subj``: which of two objects CLIP finds more likely to have an
      attribute.
    * ``choose(``: CLIP picks one of two attributes for the object.

    Args:
        data: dataset object providing ``linspace`` and ``ssp_axes``.
        img: image array.
        func: text of the program step.
        attr: list of argument strings of the program step.
        img_size: ``(height, width)`` of the image.
        results: results of the previous steps.
        info: item dictionary; reads ``info['scales']``.
        dim: dimensionality of the SSP vectors.
        memory: SSP memory.
        verbose: accepted for interface compatibility; not used here.
        visualize: forwarded to the queries; also enables logging.

    Returns:
        The chosen answer (stripped of surrounding whitespace), ``None`` if a
        referenced object is missing or no option matches, or ``-1`` if the
        variant is not implemented.
    """
    if 'choose_f' in func:
        # strip so that the returned answer carries no whitespace from argument splitting
        options = [option.strip() for option in attr]
        pred = clip_query_scene(img, options, verbose=0)
        return options[pred]

    num = int(re.findall(r'\d+', attr[0])[0])
    if results[num] is None:
        if visualize:
            logging.info('No object was found in last step -- nothing to choose')
        return None

    def to_pixels(pos):
        # x and y share scale index 0; width and height use indices 1 and 2
        scales = info['scales']
        return [pos[0] * scales[0], pos[1] * scales[0], pos[2] * scales[1], pos[3] * scales[2]]

    name, obj_ssp, obj_pos = results[num]
    obj_name = name.split('_')[0]

    if 'choose_h' in func or 'choose_v' in func:
        x, y, w, h = to_pixels(obj_pos)
        # choose side --> x-value
        if 'choose_h' in func:
            return 'right' if (x + w / 2) >= (img_size[1] / 2) else 'left'
        # choose vertical alignment --> y-value
        return 'bottom' if (y + h / 2) >= (img_size[0] / 2) else 'top'

    if 'choose_n' in func:
        obj, name1, name2 = attr
        # Compare the bare object name: the stored key carries a counter suffix
        # and the arguments carry whitespace, so comparing them directly never
        # matches.
        for candidate in (name1.strip(), name2.strip()):
            if obj_name == candidate:
                return candidate
        return None

    if 'choose_attr' in func:
        obj, attr_type, attr1, attr2 = attr
        clip_tokens = [f'The {attr_type} of {obj_name} is {attr1}',
                       f'The {attr_type} of {obj_name} is {attr2}']
        pred = clip_query(to_pixels(obj_pos), img, obj_name, clip_tokens, visualize=visualize)
        attr_lst = [attr1.strip(), attr2.strip()]
        if visualize:
            logging.info(f'Choose attribute: {attr_type} of {name} --> clip prediction: {attr_lst[pred]}')
        return attr_lst[pred]

    if 'choose_rel_inv' in func:
        obj, rel_obj, attr1, attr2 = attr
        proposals1 = use_query_mask(obj_pos, info, attr1, data.linspace, data.ssp_axes, dim,
                                    memory, 0, visualize)
        proposals1 = [str(p).split('_')[0] for p in proposals1]
        proposals2 = use_query_mask(obj_pos, info, attr2, data.linspace, data.ssp_axes, dim,
                                    memory, 0, visualize)
        proposals2 = [str(p).split('_')[0] for p in proposals2]
        for relation, proposals in ((attr1.strip(), proposals1), (attr2.strip(), proposals2)):
            if rel_obj.strip() in proposals:
                return ANSWER_MAP.get(relation, relation)
        return None

    if 'choose_subj' in func:
        subj1, subj2, attribute = attr
        num1 = int(re.findall(r'\d+', subj1)[0])
        num2 = int(re.findall(r'\d+', subj2)[0])
        if results[num1] is None or results[num2] is None:
            # a missing subject is "no object found", not "not implemented"
            return None
        name1, obj_ssp1, obj_pos1 = results[num1]
        name2, obj_ssp2, obj_pos2 = results[num2]
        pred = clip_choose(to_pixels(obj_pos1), to_pixels(obj_pos2), img, attribute, visualize=visualize)
        return name1.split('_')[0] if pred == 0 else name2.split('_')[0]

    if 'choose(' in func:
        obj, attr1, attr2 = attr
        clip_tokens = [f'The {obj_name} is {attr1}',
                       f'The {obj_name} is {attr2}']
        pred = clip_query(to_pixels(obj_pos), img, obj_name, clip_tokens, visualize=visualize)
        attr_lst = [attr1.strip(), attr2.strip()]
        if visualize:
            logging.info(f'Choose {attr1} or {attr2} for {name} --> clip prediction: {attr_lst[pred]}')
        return attr_lst[pred]

    # the message needs a placeholder; passing func as the format string fails
    logging.warning('%s not implemented yet', func)
    return -1
def run_program(data, img, info, counter, memory, dim, verbose=0):
    """Execute the program of a question step by step on the given image.

    Every step has the form ``<target>=<function>(<arguments>)``. The function
    text selects the handler by substring tests in a fixed order; intermediate
    results are collected in a list that later steps index through the number
    contained in their first argument.

    Args:
        data: dataset object, forwarded to the step handlers.
        img: image array; its first two dimensions are used as the image size.
        info: item dictionary; reads ``info['program']``, ``info['scales']`` and
            ``info['encoded_ssps']``.
        counter: instance number used by the select step.
        memory: SSP memory.
        dim: dimensionality of the SSP vectors.
        verbose: if greater than 0, log every step.

    Returns:
        The answer of the first step that yields one, ``-1`` for an unsupported
        step, or ``None`` if no step produced an answer.
    """
    scale, w_scale, h_scale = info['scales']
    img_size = img.shape[:2]
    steps = info['program']
    results = []
    last_func = None

    for idx, step in enumerate(steps):
        last_step = (idx + 1 == len(steps))
        _, func = step.split('=')
        attr = func.split('(')[-1].split(')')[0].split(',')
        if verbose > 0:
            logging.info(f'{idx+1}. step: \t {func}')

        if 'select' in func:
            obj = attr[0].strip()
            res = select_func(data, img, obj, info, counter, memory, visualize=VISUALIZE)
            results.append(res)
            if res is None and verbose > 1:
                logging.info(f'Could not find {obj}')

        elif 'relate' in func:
            found_rel_obj = relate_func(data, func, attr, results, info, dim, memory, visualize=VISUALIZE)
            if found_rel_obj is None:
                results.append(None)
                if last_step:
                    return 'no'
            else:
                assert found_rel_obj in info['encoded_ssps'], f'Result of {func}: {found_rel_obj} is not encoded'
                selected_ssp = info['encoded_ssps'][found_rel_obj]
                _, selected_pos = select_ssp(found_rel_obj, memory, info['encoded_ssps'],
                                             data.ssp_vectors, data.linspace)
                results.append([found_rel_obj, selected_ssp, selected_pos])
                if last_step:
                    return 'yes'

        elif 'filter' in func:
            last_filter = filter_func(func, img, attr, img_size, results, info, visualize=VISUALIZE)
            previous = results[-1]
            if last_filter:
                results.append(previous)
            elif previous is None:
                results.append(None)
            elif previous[0].split("_")[0] + "_" + str(counter+1) in info['encoded_ssps'].keys():
                # NOTE(review): another instance of the object exists, yet the
                # program is abandoned; the increment has no effect before the
                # return -- confirm whether a retry with the next instance was
                # intended.
                counter += 1
                return None
            else:
                last_filter = False
                results.append(previous)

        elif 'verify' in func:
            pred = verify_func(data, img, attr, results, info, dim, memory, verbose=verbose, visualize=VISUALIZE)
            if 'verify_relation_name' in func or 'verify_relation_inv_name' in func:
                results.append(results[-1] if pred else None)
            else:
                results.append(pred)
            if last_step:
                return 'yes' if pred else 'no'

        elif 'query' in func:
            return query_func(func, img, attr, results, info, img_size, dim, verbose=verbose, visualize=VISUALIZE)

        elif 'exist' in func:
            num = int(re.findall(r'\d+', attr[0])[0])
            found = results[num] is not None
            if last_step:
                return 'yes' if found else 'no'
            # an object that only survived a failed filter does not count as existing
            results.append(bool(found and ('filter' not in last_func or last_filter)))

        elif 'or(' in func:
            first = int(re.findall(r'\d+', attr[0])[0])
            second = int(re.findall(r'\d+', attr[1])[0])
            return 'yes' if results[first] or results[second] else 'no'

        elif 'and(' in func:
            first = int(re.findall(r'\d+', attr[0])[0])
            second = int(re.findall(r'\d+', attr[1])[0])
            return 'yes' if results[first] and results[second] else 'no'

        elif 'different' in func or 'same' in func:
            if len(attr) == 1:
                logging.warning(f'{func} cannot be computed')
                return None
            # query the same attribute type for both objects and compare the predictions
            query_name = f'query_{attr[2].strip()}'
            pred_attr1 = query_func(query_name, img, [attr[0], attr[2]], results, info, img_size, dim)
            pred_attr2 = query_func(query_name, img, [attr[1], attr[2]], results, info, img_size, dim)
            if 'different' in func:
                return 'yes' if pred_attr1 != pred_attr2 else 'no'
            return 'yes' if pred_attr1 == pred_attr2 else 'no'

        elif 'choose' in func:
            return choose_func(data, img, func, attr, img_size, results, info, dim, memory, visualize=VISUALIZE)

        else:
            logging.warning(f'{func} not implemented')
            return -1

        last_func = func
if __name__ == "__main__":
    # Evaluation script: encode every item of the data set, run its program and
    # store the predictions together with the encoding quality measures.
    TEST = True  # True: validation questions / scene graphs, False: training split
    DIM = 2048
    RANDOM_SEED = 17
    np.random.seed(RANDOM_SEED)
    torch.manual_seed(RANDOM_SEED)
    random.seed(RANDOM_SEED)

    TIME_STAMP = datetime.now().strftime("%d%b%y-%H%M")
    log_file = f"logs/run{TIME_STAMP}-DIM{DIM}-{RANDOM_SEED}.log"
    # logging.basicConfig does not create missing directories
    os.makedirs(os.path.dirname(log_file), exist_ok=True)
    logging.basicConfig(level=logging.INFO, filename=log_file,
                        filemode="w", format="%(asctime)s %(levelname)s %(message)s")
    print('Logging to ', log_file)

    # DATA_PATH, the CUDA device, the CLIP model, RELATION_DICT and ATTRIBUTE_DICT
    # are already set up at module level with the same values, so they are not
    # loaded a second time here.
    # NOTE(review): this replaces the module-level CLASS_DICT (read from
    # 'gqa_all_vocab_classes.json') with a different file -- confirm which one
    # is intended.
    with open('gqa_vocab_classes.json') as f:
        CLASS_DICT = json.load(f)

    start = time.time()
    res = 100
    dim = DIM
    new_size = (25, 25)  # size should be smaller than resolution!
    xs = np.linspace(0, new_size[1], res)
    ys = np.linspace(0, new_size[0], res)
    ws = np.linspace(1, 10, 10)
    hs = np.linspace(1, 10, 10)

    rng = np.random.RandomState(seed=RANDOM_SEED)
    x_axis = make_good_unitary(dim, rng=rng)
    y_axis = make_good_unitary(dim, rng=rng)
    w_axis = make_good_unitary(dim, rng=rng)
    h_axis = make_good_unitary(dim, rng=rng)

    logging.info(f'Size of vector space: {res**2}x{10**2}x{dim}')
    logging.info(f'x-axis resolution = {len(xs)}, y-axis resolution = {len(ys)}')
    logging.info(f'width resolution = {len(ws)}, height resolution = {len(hs)}')

    # precompute the vectors
    VECTORS = get_heatmap_vectors_multidim(xs, ys, ws, hs, x_axis, y_axis, w_axis, h_axis)
    logging.info(VECTORS.shape)
    logging.info(f'Took {time.time() - start}seconds to load vectors.\n')

    # load questions, programs and scenegraphs
    programs_path = 'programs/trainval_balanced_programs.json'
    if TEST:
        questions_path = 'val_balanced_questions.json'
        scene_path = 'val_sceneGraphs.json'
    else:
        questions_path = 'train_balanced_questions.json'
        scene_path = 'train_sceneGraphs.json'

    with open(os.path.join(DATA_PATH, questions_path), 'r') as f:
        questions = json.load(f)
    with open(os.path.join(DATA_PATH, programs_path), 'r') as f:
        programs = json.load(f)
    with open(os.path.join(DATA_PATH, scene_path), 'r') as f:
        scenegraphs = json.load(f)

    columns = ['semantic', 'entailed', 'equivalent', 'question', 'imageId', 'isBalanced', 'groups',
               'answer', 'semanticStr', 'annotations', 'types', 'fullAnswer']
    questions = pd.DataFrame.from_dict(questions, orient='index', columns=columns)
    questions = questions.reset_index()
    questions = questions.rename(columns={"index": "questionID"}, errors="raise")

    columns = ['imageID', 'question', 'program', 'questionID', 'answer']
    programs = pd.DataFrame(programs, columns=columns)

    DATA = GQADataset(questions, programs, scenegraphs, vectors=VECTORS,
                      axes=[x_axis, y_axis, w_axis, h_axis], linspace=[xs, ys, ws, hs])
    logging.info(f'Length of data set: {len(DATA)}')

    VISUALIZE = False
    DATA.set_visualize(VISUALIZE)
    DATA.set_verbose(0)

    results_lst = []
    num_correct = 0
    pbar = tqdm(range(len(DATA)), ncols=115)
    for i, IDX in enumerate(pbar):
        start = time.time()
        img, info, memory = DATA.encode_item(IDX, dim=dim)
        avg_mse, avg_iou, correct_items = DATA.decode_item(img, info, memory)

        try:
            answer = run_program(DATA, img, info, counter=1, memory=memory, dim=dim, verbose=1)
        except Exception:
            # A failing program counts as a wrong answer; keep the traceback in
            # the log so the failure can be diagnosed afterwards.
            answer = None
            logging.exception('[%s] program execution failed', IDX)

        if answer == -1:
            logging.warning(f'[{IDX}] not fully implemented!')

        time_in_sec = time.time() - start
        correct = answer == info["answer"]
        num_correct += int(correct)

        results = {'q_id': info['q_id'], 'question': info['question'], 'program': info['program'],
                   'image': info['img_id'], 'true_answer': info['answer'], 'pred_answer': answer,
                   'correct': correct, 'time': time_in_sec,
                   'enc_avg_mse': avg_mse, 'enc_avg_iou': avg_iou, 'enc_correct_items': correct_items,
                   'enc_items': len(info["encoded_items"]), 'q_idx': IDX}
        results_lst.append(results)

        logging.info(f'[{IDX+1}] {num_correct / (i+1):.2%}')
        pbar.set_postfix({'correct': f'{num_correct / (i+1):.2%}', 'q_idx': str(IDX+1)})

    results_df = pd.DataFrame(results_lst)
    # num_correct equals the sum of the 'correct' column; guard against an empty data set
    accuracy = num_correct / len(results_lst) if results_lst else 0.0
    logging.info(f'Accuracy: {accuracy:.2%}')

    out_path = os.path.join(DATA_PATH, f'results-DIM{DIM}-{"VAL" if TEST else "TRAIN"}{RANDOM_SEED}.pkl')
    results_df.to_pickle(out_path)
    logging.info(f'Saved results to {out_path}.')