import os
import sys

os.environ["OPENBLAS_NUM_THREADS"] = "10"

import json
import cv2
import re
import random
import time
import torch
import clip
import logging

from datetime import datetime
from PIL import Image
from tqdm import tqdm

import pandas as pd
import numpy as np
import nengo.spa as spa

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.colors as mcolors

from collections import OrderedDict
from pattern.text.en import singularize, pluralize

from dataset import GQADataset
from utils import *

DATA_PATH = '/scratch/penzkofer/GQA'

RGB_COLORS = []
for name, hex in mcolors.cnames.items():
    RGB_COLORS.append(mcolors.to_rgb(hex))

CUDA_DEVICE = 7
torch.cuda.set_device(CUDA_DEVICE)
device = torch.device("cuda:" + str(CUDA_DEVICE))
clip_model, preprocess = clip.load("ViT-B/32", device=device)

with open('gqa_all_relations_map.json') as f:
    RELATION_DICT = json.load(f)

with open('gqa_all_vocab_classes.json') as f:
    CLASS_DICT = json.load(f)

with open('gqa_all_attributes.json') as f:
    ATTRIBUTE_DICT = json.load(f)

SYNONYMS = {'he': ['man', 'boy'], 'she': ['woman', 'girl']}
ANSWER_MAP = {'to the right of': 'right', 'to the left of': 'left'}
VISUALIZE = False


def plot_heatmap_multidim(sp, xs, ys, heatmap_vectors, name='', vmin=-1, vmax=1, cmap='plasma', invert=False):
    """adapted from https://github.com/ctn-waterloo/cogsci2019-ssp/tree/master"""
    assert sp.__class__.__name__ == 'SemanticPointer', \
        f'Queried object needs to be of type SemanticPointer but is {sp.__class__.__name__}'

    # axes: a list of axes to be summed over, first sequence applying to first tensor, second to second tensor
    vs = np.tensordot(sp.v, heatmap_vectors, axes=([0], [4]))
    res = np.unravel_index(np.argmax(vs, axis=None), vs.shape)

    plt.imshow(np.transpose(vs[:, :, res[2], res[3]]), origin='upper', interpolation='none',
               extent=(xs[-1], xs[0], ys[-1], ys[0]), vmin=vmin, vmax=vmax, cmap=cmap)
    plt.colorbar()
    plt.axis('off')
    plt.title(name)
    plt.show()


def select_ssp(name, memory, encoded_ssps, vectors, linspace):
    """decode location of object with name from SSP memory"""
    ssp_item = encoded_ssps[name]
    item_decoded = memory *~ ssp_item
    clean_loc = ssp_to_loc_multidim(item_decoded, vectors, linspace)
    return item_decoded, clean_loc


def clip_query(bbox, img, obj_name, clip_tokens, visualize=False):
    """Implements CLIP queries for different attributes"""
    x, y, w, h = bbox
    obj_center = (x + w / 2, y + h / 2)

    masked_img = img.copy()
    masked_img = cv2.ellipse(masked_img, (int(obj_center[0]), int(obj_center[1])),
                             (int(w * 0.7), int(h * 0.7)), 0, 0, 360, (255, 0, 0), 4)

    if visualize:
        plt.imshow(masked_img)
        plt.axis('off')
        plt.show()

    masked_img = Image.fromarray(np.uint8(masked_img))
    tokens = clip.tokenize(clip_tokens)

    with torch.no_grad():
        text_features = clip_model.encode_text(tokens.to(device))
        image_features = clip_model.encode_image(preprocess(masked_img).unsqueeze(0).to(device))

    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    indices = torch.max(similarity, 1)[1]

    similarity = similarity.squeeze()
    scores = [s.item() for s in similarity]
    pred = clip_tokens[indices.squeeze().item()]

    if visualize:
        print('CLIP')
        print(scores)
        print(pred)

    return indices.squeeze().item()
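
# Note: clip_query (above) as well as clip_query_scene and clip_choose (below) share the
# same zero-shot CLIP pattern: encode a set of natural-language prompts and the
# (optionally ellipse-marked) image, L2-normalise both embeddings, and softmax the scaled
# cosine similarities to pick the best-matching prompt.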


def clip_query_scene(img, clip_tokens, verbose=0):
    """Implements CLIP queries for entire scene, i.e. no bounding box selection"""
    img = Image.fromarray(np.uint8(img))
    tokens = clip.tokenize(clip_tokens)

    with torch.no_grad():
        text_features = clip_model.encode_text(tokens.to(device))
        image_features = clip_model.encode_image(preprocess(img).unsqueeze(0).to(device))

    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    indices = torch.max(similarity, 1)[1]

    similarity = similarity.squeeze()
    scores = [s.item() for s in similarity]
    pred = clip_tokens[indices.squeeze().item()]

    if verbose > 0:
        print('CLIP')
        print(scores)
        print(pred)

    return indices.squeeze().item()


def clip_choose(bbox1, bbox2, img, attribute, visualize=False):
    """Run attribute vs. not attribute check for both subjects -- select clip prediction with higher confidence"""
    x1, y1, w1, h1 = bbox1
    obj_center = (x1 + w1 / 2, y1 + h1 / 2)
    masked_img1 = img.copy()
    masked_img1 = cv2.ellipse(masked_img1, (int(obj_center[0]), int(obj_center[1])),
                              (int(w1 * 0.7), int(h1 * 0.7)), 0, 0, 360, (255, 0, 0), 4)

    x2, y2, w2, h2 = bbox2
    obj_center = (x2 + w2 / 2, y2 + h2 / 2)
    masked_img2 = img.copy()
    masked_img2 = cv2.ellipse(masked_img2, (int(obj_center[0]), int(obj_center[1])),
                              (int(w2 * 0.7), int(h2 * 0.7)), 0, 0, 360, (255, 0, 0), 4)

    if visualize:
        plt.imshow(masked_img1)
        plt.axis('off')
        plt.show()
        plt.imshow(masked_img2)
        plt.axis('off')
        plt.show()

    masked_img1 = Image.fromarray(np.uint8(masked_img1))
    masked_img2 = Image.fromarray(np.uint8(masked_img2))
    tokens = clip.tokenize([attribute, f'not {attribute}'])

    with torch.no_grad():
        text_features = clip_model.encode_text(tokens.to(device))
        image_features = clip_model.encode_image(preprocess(masked_img1).unsqueeze(0).to(device))

    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    similarity1 = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    similarity1 = similarity1.squeeze()[0]

    with torch.no_grad():
        text_features = clip_model.encode_text(tokens.to(device))
        image_features = clip_model.encode_image(preprocess(masked_img2).unsqueeze(0).to(device))

    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    similarity2 = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    similarity2 = similarity2.squeeze()[0]

    if visualize:
        logging.info(f'CLIP prediction: {[similarity1, similarity2]}')

    return 0 if similarity1 > similarity2 else 1


def get_rel_path(rel, verbose=0):
    """get correct relation path, map rel to one of the 37 existing query masks"""
    rel = RELATION_DICT.get(rel.strip())  # get synonym of relation if no mask exists
    rel = '_'.join(rel.split(' ')) if ' ' in rel else rel
    path = 'relations/' + rel + '.npy'
    if verbose > 0:
        logging.info('Loading %s', path)
    return path


def use_query_mask(obj_pos, info, rel, linspace, axes, dim, memory, verbose=0, visualize=False):
    """implements query mask usage: load spatial query mask for relation rel,
    scale query mask to object, encode mask to SSP region and move to object position
    in SSP memory, extract object proposals in region"""
    xs, ys, ws, hs = linspace
    x_axis, y_axis, w_axis, h_axis = axes
    x, y, width, height = obj_pos

    # 50 pixels was object size in query mask generation -- use pixel scale values for height and width
    iso_scale = np.mean([(width * info['scales'][1]) / 50, (height * info['scales'][2]) / 50])

    mask = np.load(get_rel_path(rel, verbose))
    mask = cv2.resize(mask, (100, 100), interpolation=cv2.INTER_AREA)
    # crop new query mask according to scale
    new_area = int(100 / iso_scale)
    if verbose > 0:
        print(iso_scale, new_area)
    resized = mask[max(0, 50 - new_area): min(50 + new_area, 100),
                   max(0, 50 - new_area): min(50 + new_area, 100)]
    resized = cv2.resize(resized, (100, 100), interpolation=cv2.INTER_AREA)

    if visualize:
        fig, axs = plt.subplots(1, 2, sharey=True, layout="constrained", figsize=(6, 3))
        fig.set_tight_layout(True)
        fig.subplots_adjust(top=1.05)
        fig.suptitle('Relation: ' + rel)
        plt.setp(plt.gcf().get_axes(), xticks=[], yticks=[])
        axs[0].imshow(mask, cmap='gray')
        axs[0].title.set_text('original')
        axs[1].imshow(resized, cmap='gray')
        axs[1].title.set_text('resized')
        plt.show()

    # encode mask to SSP region
    counter = 0
    vector = spa.SemanticPointer(data=np.zeros(dim))
    for (i, j) in zip(*np.where(resized > 0.05)):
        x, y = xs[i], ys[j]
        vector += encode_point_multidim([y, x, 1, 1], axes=axes)
        counter += 1
    vector.normalize()

    if verbose > 0:
        logging.info(f'Resized mask encoded {counter} points')
    if visualize:
        plot_heatmap_multidim(vector, xs, ys, VECTORS, vmin=-0.2, vmax=0.2, name=f'Encoded Region')

    # get object info and move query mask to position
    x, y, width, height = obj_pos
    obj_center = (x + width / 2, y + height / 2)
    img_center = np.array([xs[50], ys[50]])
    shift = -img_center + (obj_center[1], obj_center[0])
    encoded_shift = encode_point(shift[1], shift[0], x_axis=x_axis, y_axis=y_axis)

    shifted_pos = vector.convolve(encoded_shift)
    shifted_pos.normalize()

    if visualize:
        plot_heatmap_multidim(shifted_pos, xs, ys, VECTORS, vmin=-0.1, vmax=0.1, name=f'Query Region')

    # query region and compare output to vocab = all saved SSPs
    vocab_vectors = np.zeros((len(info['encoded_ssps']), dim))
    color_lst = []
    for i, (name, ssp) in enumerate(info['encoded_ssps'].items()):
        vocab_vectors[i, :] = ssp.v
        color_lst.append(RGB_COLORS[i])

    similarity = np.tensordot((memory * ~shifted_pos).v, vocab_vectors, axes=([0], [1]))
    d = OrderedDict(zip(list(info['encoded_ssps'].keys()), similarity))
    res = list(OrderedDict(sorted(d.items(), key=lambda x: x[1], reverse=True)))
    proposals = np.array(res)[np.array(np.array(sorted(similarity, reverse=True)) > 0).astype(bool)]

    if verbose > 0:
        print('Proposals: ', *proposals)

    if visualize:
        fig = plt.Figure()
        plt.bar(np.arange(len(info['encoded_ssps'])), similarity, color=color_lst,
                label=info['encoded_items'].keys())
        plt.title('Similarity')
        plt.legend(loc='upper left', bbox_to_anchor=(1.0, 1.05))
        plt.show()

    return proposals


def select_func(data, img, obj, info, counter, memory, visualize=False):
    """implements select function: probe SSP memory with given object name"""
    result = None

    if obj + '_' + str(counter) in info['encoded_ssps'].keys():
        obj += '_' + str(counter)
        obj_ssp, obj_pos = select_ssp(obj, memory, info['encoded_ssps'], data.ssp_vectors, data.linspace)
        result = [obj, obj_ssp, obj_pos]
    else:
        # test synonyms and plural/singular
        test = [singularize(obj), pluralize(obj)]
        if SYNONYMS.get(obj):
            test += SYNONYMS.get(obj)
        if CLASS_DICT.get(obj):
            test += CLASS_DICT.get(obj)

        for obj in test:
            obj = str(obj) + '_' + str(counter)
            if obj in info['encoded_ssps'].keys():
                obj_ssp, obj_pos = select_ssp(obj, memory, info['encoded_ssps'], data.ssp_vectors, data.linspace)
                result = [obj, obj_ssp, obj_pos]

    if result is not None and visualize:
        fig, ax = plt.subplots(1, 1)
        plt.axis('off')
        ax.imshow(img)
        # label with the matched object name (result[0]) rather than the module-level loop variable
        rect = patches.Rectangle((obj_pos[0] * info['scales'][0], obj_pos[1] * info['scales'][0]),
                                 obj_pos[2] * info['scales'][1], obj_pos[3] * info['scales'][2],
                                 linewidth=2, label=result[0], edgecolor='c', facecolor='none')
        ax.add_patch(rect)
        plt.show()

    return result

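
# --- Illustrative sketch (not called anywhere in this script) ---------------------------
# Minimal sketch of the SSP bind/unbind pattern that select_ssp relies on, using only the
# SemanticPointer operations already used above (*, ~, .v). Names, dimensions and the
# helper itself are hypothetical and purely for illustration.
def _ssp_unbind_demo(dim=512, seed=0):
    """Toy demo: memory = obj (*) loc; memory * ~obj recovers an approximation of loc."""
    rng = np.random.RandomState(seed)

    def random_pointer():
        v = rng.randn(dim)
        return spa.SemanticPointer(data=v / np.linalg.norm(v))

    obj, loc = random_pointer(), random_pointer()
    memory = obj * loc                # circular convolution binds object to location
    decoded = memory * ~obj           # unbind with the (approximate) inverse of obj
    return np.dot(decoded.v, loc.v)   # similarity is close to 1 for large dim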

def verify_func(data, img, attr, results, info, dim, memory, verbose=0, visualize=False):
    """implements all verify functions dependent on length of attributes given"""
    if len(attr) == 2:
        # verify_color, verify_shape, verify_scene
        num = int(re.findall(r'\d+', attr[0])[0])

        if results[num] is not None:
            name, obj_ssp, obj_pos = results[num]
            x, y = obj_pos[0] * info['scales'][0], obj_pos[1] * info['scales'][0]
            w, h = obj_pos[2] * info['scales'][1], obj_pos[3] * info['scales'][2]

            clip_tokens = [f'The {name.split("_")[0]} is {attr[1].strip()}',
                           f'The {name.split("_")[0]} is not {attr[1].strip()}']
            pred = clip_query([x, y, w, h], img, name.split('_')[0], clip_tokens, visualize=visualize)
            return True if pred == 0 else False
        else:
            return False

    elif len(attr) == 3:
        # verify_rel, verify_rel_inv
        obj, rel, rel_obj = attr
        num = int(re.findall(r'\d+', obj)[0])
        proposals = []

        if results[num] is not None:
            name, obj_ssp, obj_pos = results[num]
            proposals = use_query_mask(obj_pos, info, rel, data.linspace, data.ssp_axes, dim, memory,
                                       visualize=visualize)
            proposals = [str(p).split('_')[0] for p in proposals]
            return True if rel_obj.strip() in proposals else False
        else:
            return False

    # verify_f
    elif len(attr) == 1:
        clip_tokens = [f'The image is {attr[0].strip()}', f'The image is not {attr[0].strip()}',
                       f'The image is a {attr[0].strip()}', f'The image is not a {attr[0].strip()}']
        pred = clip_query_scene(img, clip_tokens, verbose=0)
        return True if pred == 0 or pred == 2 else False

    else:
        logging.warning('verify_func not implemented')
        return -1


def query_func(func, img, attr, results, info, img_size, dim, verbose=0, visualize=False):
    """implements all query functions"""
    if 'query_f(' in func:
        attr_type = attr[0].strip()
        assert attr_type in CLASS_DICT, f'{attr_type} not found in class dictionary'
        attributes = CLASS_DICT.get(attr_type)
        clip_tokens = [f'This is a {a} {attr_type}' for a in attributes]
        pred = clip_query_scene(img, clip_tokens, verbose=0)
        return attributes[pred]

    num = int(re.findall(r'\d+', attr[0])[0])

    if results[num] is not None:
        name, obj_ssp, obj_pos = results[num]
        x, y = obj_pos[0] * info['scales'][0], obj_pos[1] * info['scales'][0]
        w, h = obj_pos[2] * info['scales'][1], obj_pos[3] * info['scales'][2]

        if len(attr) == 1:
            if 'query_n' in func:
                # query name
                return name.split('_')[0]
            if 'query_h' in func:
                # query horizontal position --> x-value
                if (x + w / 2) >= (img_size[1] / 2):
                    return 'right'
                else:
                    return 'left'
            if 'query_v' in func:
                # query vertical position --> y-value
                if (y + h / 2) >= (img_size[0] / 2):
                    return 'bottom'
                else:
                    return 'top'

        elif len(attr) == 2:
            attr_type = attr[1].strip()
            assert attr_type in ATTRIBUTE_DICT, f'{attr_type} not found in attribute dictionary'
            attributes = ATTRIBUTE_DICT.get(attr_type)
            clip_tokens = [f'The {attr_type} of {name.split("_")[0]} is {a}' for a in attributes]
            pred = clip_query([x, y, w, h], img, name.split('_')[0], clip_tokens, visualize=visualize)
            return attributes[pred]
    else:
        return None

    logging.warning('query not implemented: %s %s', func, attr)
    return -1


def relate_func(data, func, attr, results, info, dim, memory, visualize=False):
    """implements all relationship functions"""
    if 'relate_inv_name' in func or 'relate_name' in func:
        obj, rel, rel_obj = attr
        num = int(re.findall(r'\d+', attr[0])[0])

        if results[num] is not None:
            name, obj_ssp, obj_pos = results[num]
            proposals = use_query_mask(obj_pos, info, rel, data.linspace, data.ssp_axes, dim, memory,
                                       verbose=0, visualize=visualize)
            selected_obj = proposals[0]
            # use rel_obj to filter proposals
            rel_obj = rel_obj.strip()

            if rel_obj == selected_obj.split('_')[0]:
                if visualize:
                    logging.info('Found perfect match\n')
                return selected_obj

            elif rel_obj in [str(p).split('_')[0] for p in proposals]:
                idx = [str(p).split('_')[0] for p in proposals].index(rel_obj)
                return proposals[idx]

            elif rel_obj in CLASS_DICT.keys():
                class_lst = CLASS_DICT.get(rel_obj)
                for p in [str(p).split('_')[0] for p in proposals]:
                    if p in class_lst or singularize(p) in class_lst:
                        if visualize:
                            logging.info(f'Found better proposal for {rel_obj}: {p}\n')
                        idx = [str(p).split('_')[0] for p in proposals].index(p)
                        return proposals[idx]
            else:
                if visualize:
                    logging.info(f'Did not find {rel_obj} in proposals\n')
                return None

    elif 'relate_inv' in func or 'relate(' in func:
        obj, rel = attr
        num = int(re.findall(r'\d+', attr[0])[0])

        if results[num] is not None:
            name, obj_ssp, obj_pos = results[num]
            proposals = use_query_mask(obj_pos, info, rel, data.linspace, data.ssp_axes, dim, memory,
                                       visualize=visualize)
            return proposals[0]
    else:
        logging.warning(f'{func} not implemented')
        return -1


def filter_func(func, img, attr, img_size, results, info, visualize=False):
    """implements all filter functions"""
    obj, filter_attr = attr
    num = int(re.findall(r'\d+', attr[0])[0])

    if results[num] is not None:
        name, obj_ssp, obj_pos = results[num]
        x, y = obj_pos[0] * info['scales'][0], obj_pos[1] * info['scales'][0]
        w, h = obj_pos[2] * info['scales'][1], obj_pos[3] * info['scales'][2]

        # query height --> y-value
        if 'bottom' in filter_attr or 'top' in filter_attr:
            if (y + h / 2) >= (img_size[0] / 2):
                pred_attr = 'bottom'
            else:
                pred_attr = 'top'
            return pred_attr == filter_attr.strip()

        # query side --> x-value
        if 'right' in filter_attr or 'left' in filter_attr:
            if (x + w / 2) >= (img_size[1] / 2):
                pred_attr = 'right'
            else:
                pred_attr = 'left'
            return pred_attr == filter_attr.strip()

        # filter by attribute: color, shape, activity, material
        else:
            clip_tokens = [f'The {name.split("_")[0]} is {filter_attr}',
                           f'The {name.split("_")[0]} is not {filter_attr}']
            pred = clip_query([x, y, w, h], img, name.split('_')[0], clip_tokens, visualize=visualize)

            if 'not' in func:
                return True if pred == 1 else False
            else:
                return True if pred == 0 else False
    else:
        if visualize:
            logging.info('No object was found in last step -- nothing to filter')
        return None

    return -1


def choose_func(data, img, func, attr, img_size, results, info, dim, memory, verbose=0, visualize=False):
    """implements all choose functions"""
    if 'choose_f' in func:
        pred = clip_query_scene(img, attr, verbose=0)
        return attr[pred]

    num = int(re.findall(r'\d+', attr[0])[0])

    if results[num] is not None:
        name, obj_ssp, obj_pos = results[num]

        if 'choose_h' in func or 'choose_v' in func:
            x, y = obj_pos[0] * info['scales'][0], obj_pos[1] * info['scales'][0]
            w, h = obj_pos[2] * info['scales'][1], obj_pos[3] * info['scales'][2]

            # choose side --> x-value
            if 'choose_h' in func:
                if (x + w / 2) >= (img_size[1] / 2):
                    return 'right'
                else:
                    return 'left'

            # choose vertical alignment --> y-value
            if 'choose_v' in func:
                if (y + h / 2) >= (img_size[0] / 2):
                    return 'bottom'
                else:
                    return 'top'

        elif 'choose_n' in func:
            obj, name1, name2 = attr
            # compare base object name (without counter suffix) against the two candidates
            if name.split('_')[0] == name1.strip():
                return name1.strip()
            elif name.split('_')[0] == name2.strip():
                return name2.strip()
            else:
                return None

        elif 'choose_attr' in func:
            obj, attr_type, attr1, attr2 = attr
            x, y = obj_pos[0] * info['scales'][0], obj_pos[1] * info['scales'][0]
            w, h = obj_pos[2] * info['scales'][1], obj_pos[3] * info['scales'][2]
            clip_tokens = [f'The {attr_type} of {name.split("_")[0]} is {attr1}',
                           f'The {attr_type} of {name.split("_")[0]} is {attr2}']
            pred = clip_query([x, y, w, h], img, name.split('_')[0], clip_tokens, visualize=visualize)
            attr_lst = [attr1.strip(), attr2.strip()]

            if visualize:
                logging.info(f'Choose attribute: {attr_type} of {name} --> clip prediction: {attr_lst[pred]}')
            return attr_lst[pred]

        elif 'choose_rel_inv' in func:
            obj, rel_obj, attr1, attr2 = attr

            proposals1 = use_query_mask(obj_pos, info, attr1, data.linspace, data.ssp_axes, dim, memory, 0, visualize)
            proposals1 = [str(p).split('_')[0] for p in proposals1]
            proposals2 = use_query_mask(obj_pos, info, attr2, data.linspace, data.ssp_axes, dim, memory, 0, visualize)
            proposals2 = [str(p).split('_')[0] for p in proposals2]

            if rel_obj.strip() in proposals1:
                return ANSWER_MAP.get(attr1.strip()) if attr1.strip() in ANSWER_MAP else attr1.strip()
            elif rel_obj.strip() in proposals2:
                return ANSWER_MAP.get(attr2.strip()) if attr2.strip() in ANSWER_MAP else attr2.strip()
            else:
                return None

        elif 'choose_subj' in func:
            subj1, subj2, attribute = attr
            num1 = int(re.findall(r'\d+', subj1)[0])
            num2 = int(re.findall(r'\d+', subj2)[0])

            if results[num1] is not None and results[num2] is not None:
                name1, obj_ssp1, obj_pos1 = results[num1]
                name2, obj_ssp2, obj_pos2 = results[num2]

                x1, y1 = obj_pos1[0] * info['scales'][0], obj_pos1[1] * info['scales'][0]
                w1, h1 = obj_pos1[2] * info['scales'][1], obj_pos1[3] * info['scales'][2]
                x2, y2 = obj_pos2[0] * info['scales'][0], obj_pos2[1] * info['scales'][0]
                w2, h2 = obj_pos2[2] * info['scales'][1], obj_pos2[3] * info['scales'][2]

                pred = clip_choose([x1, y1, w1, h1], [x2, y2, w2, h2], img, attribute, visualize=visualize)
                return name1.split('_')[0] if pred == 0 else name2.split('_')[0]

        elif 'choose(' in func:
            obj, attr1, attr2 = attr
            x, y = obj_pos[0] * info['scales'][0], obj_pos[1] * info['scales'][0]
            w, h = obj_pos[2] * info['scales'][1], obj_pos[3] * info['scales'][2]

            clip_tokens = [f'The {name.split("_")[0]} is {attr1}', f'The {name.split("_")[0]} is {attr2}']
            pred = clip_query([x, y, w, h], img, name.split('_')[0], clip_tokens, visualize=visualize)
            attr_lst = [attr1.strip(), attr2.strip()]

            if visualize:
                logging.info(f'Choose {attr1} or {attr2} for {name} --> clip prediction: {attr_lst[pred]}')
            return attr_lst[pred]

        else:
            logging.warning('%s not implemented yet', func)
            return -1
    else:
        if visualize:
            logging.info('No object was found in last step -- nothing to choose')
        return None

    return -1


def run_program(data, img, info, counter, memory, dim, verbose=0):
    """run program for question on given image: for each step in program select appropriate function"""
    scale, w_scale, h_scale = info['scales']
    img_size = img.shape[:2]

    results = []
    last_step = False
    last_func = None

    for i, step in enumerate(info['program']):
        if i + 1 == len(info['program']):
            last_step = True

        _, func = step.split('=')
        attr = func.split('(')[-1].split(')')[0].split(',')

        if verbose > 0:
            logging.info(f'{i+1}. step: \t {func}')
        if 'select' in func:
            obj = attr[0].strip()
            res = select_func(data, img, obj, info, counter, memory, visualize=VISUALIZE)
            results.append(res)

            if res is None:
                if verbose > 1:
                    logging.info(f'Could not find {obj}')

        elif 'relate' in func:
            found_rel_obj = relate_func(data, func, attr, results, info, dim, memory, visualize=VISUALIZE)

            if found_rel_obj is not None:
                assert found_rel_obj in info['encoded_ssps'], f'Result of {func}: {found_rel_obj} is not encoded'
                selected_ssp = info['encoded_ssps'][found_rel_obj]
                _, selected_pos = select_ssp(found_rel_obj, memory, info['encoded_ssps'],
                                             data.ssp_vectors, data.linspace)
                results.append([found_rel_obj, selected_ssp, selected_pos])

                if last_step:
                    return 'yes'
            else:
                results.append(None)
                if last_step:
                    return 'no'

        elif 'filter' in func:
            last_filter = filter_func(func, img, attr, img_size, results, info, visualize=VISUALIZE)

            if last_filter:
                results.append(results[-1])
            else:
                if results[-1] is None:
                    results.append(None)
                elif results[-1][0].split("_")[0] + "_" + str(counter + 1) in info['encoded_ssps'].keys():
                    counter += 1
                    return None
                else:
                    last_filter = False
                    results.append(results[-1])

        elif 'verify' in func:
            pred = verify_func(data, img, attr, results, info, dim, memory, verbose=verbose, visualize=VISUALIZE)

            if 'verify_relation_name' in func or 'verify_relation_inv_name' in func:
                results.append(results[-1] if pred else None)
            else:
                results.append(pred)

            if last_step:
                return 'yes' if pred else 'no'

        elif 'query' in func:
            return query_func(func, img, attr, results, info, img_size, dim, verbose=verbose, visualize=VISUALIZE)

        elif 'exist' in func:
            num = int(re.findall(r'\d+', attr[0])[0])

            if last_step:
                return 'yes' if results[num] is not None else 'no'
            else:
                if results[num] is not None and 'filter' not in last_func:
                    results.append(True)
                elif results[num] is not None and last_filter:
                    results.append(True)
                else:
                    results.append(False)

        elif 'or(' in func:
            attr1 = int(re.findall(r'\d+', attr[0])[0])
            attr2 = int(re.findall(r'\d+', attr[1])[0])
            return 'yes' if results[attr1] or results[attr2] else 'no'

        elif 'and(' in func:
            attr1 = int(re.findall(r'\d+', attr[0])[0])
            attr2 = int(re.findall(r'\d+', attr[1])[0])
            return 'yes' if results[attr1] and results[attr2] else 'no'

        elif 'different' in func:
            if len(attr) == 1:
                logging.warning(f'{func} cannot be computed')
                return None
            else:
                pred_attr1 = query_func(f'query_{attr[2].strip()}', img, [attr[0], attr[2]], results, info, img_size, dim)
                pred_attr2 = query_func(f'query_{attr[2].strip()}', img, [attr[1], attr[2]], results, info, img_size, dim)
                if pred_attr1 != pred_attr2:
                    return 'yes'
                else:
                    return 'no'

        elif 'same' in func:
            if len(attr) == 1:
                logging.warning(f'{func} cannot be computed')
                return None

            pred_attr1 = query_func(f'query_{attr[2].strip()}', img, [attr[0], attr[2]], results, info, img_size, dim)
            pred_attr2 = query_func(f'query_{attr[2].strip()}', img, [attr[1], attr[2]], results, info, img_size, dim)
            if pred_attr1 == pred_attr2:
                return 'yes'
            else:
                return 'no'

        elif 'choose' in func:
            return choose_func(data, img, func, attr, img_size, results, info, dim, memory, visualize=VISUALIZE)

        else:
            logging.warning(f'{func} not implemented')
            return -1

        last_func = func


if __name__ == "__main__":
    TEST = True
    DIM = 2048
    RANDOM_SEED = 17
    np.random.seed(RANDOM_SEED)
    torch.manual_seed(RANDOM_SEED)
    random.seed(RANDOM_SEED)

    x = datetime.now()
    TIME_STAMP = x.strftime("%d%b%y-%H%M")

    log_file = f"logs/run{TIME_STAMP}-{'VAL' if TEST else 'TRAIN'}{RANDOM_SEED}.log"
    log_file = f"logs/run{TIME_STAMP}-DIM{DIM}-{RANDOM_SEED}.log"
    logging.basicConfig(level=logging.INFO, filename=log_file, filemode="w",
                        format="%(asctime)s %(levelname)s %(message)s")
    print('Logging to ', log_file)

    DATA_PATH = '/scratch/penzkofer/GQA'
    CUDA_DEVICE = 7
    torch.cuda.set_device(CUDA_DEVICE)
    device = torch.device("cuda:" + str(CUDA_DEVICE))
    clip_model, preprocess = clip.load("ViT-B/32", device=device)

    with open('gqa_all_relations_map.json') as f:
        RELATION_DICT = json.load(f)

    with open('gqa_vocab_classes.json') as f:
        CLASS_DICT = json.load(f)

    with open('gqa_all_attributes.json') as f:
        ATTRIBUTE_DICT = json.load(f)

    start = time.time()

    res = 100
    dim = DIM
    new_size = (25, 25)  # size should be smaller than resolution!

    xs = np.linspace(0, new_size[1], res)
    ys = np.linspace(0, new_size[0], res)
    ws = np.linspace(1, 10, 10)
    hs = np.linspace(1, 10, 10)

    rng = np.random.RandomState(seed=RANDOM_SEED)
    x_axis = make_good_unitary(dim, rng=rng)
    y_axis = make_good_unitary(dim, rng=rng)
    w_axis = make_good_unitary(dim, rng=rng)
    h_axis = make_good_unitary(dim, rng=rng)

    logging.info(f'Size of vector space: {res**2}x{10**2}x{dim}')
    logging.info(f'x-axis resolution = {len(xs)}, y-axis resolution = {len(ys)}')
    logging.info(f'width resolution = {len(ws)}, height resolution = {len(hs)}')

    # precompute the vectors
    VECTORS = get_heatmap_vectors_multidim(xs, ys, ws, hs, x_axis, y_axis, w_axis, h_axis)
    logging.info(VECTORS.shape)
    logging.info(f'Took {time.time() - start} seconds to load vectors.\n')

    # load questions, programs and scenegraphs
    if TEST:
        questions_path = 'val_balanced_questions.json'
        programs_path = 'programs/trainval_balanced_programs.json'
        scene_path = 'val_sceneGraphs.json'
    else:
        questions_path = 'train_balanced_questions.json'
        programs_path = 'programs/trainval_balanced_programs.json'
        scene_path = 'train_sceneGraphs.json'

    with open(os.path.join(DATA_PATH, questions_path), 'r') as f:
        questions = json.load(f)

    with open(os.path.join(DATA_PATH, programs_path), 'r') as f:
        programs = json.load(f)

    with open(os.path.join(DATA_PATH, scene_path), 'r') as f:
        scenegraphs = json.load(f)

    columns = ['semantic', 'entailed', 'equivalent', 'question', 'imageId', 'isBalanced', 'groups',
               'answer', 'semanticStr', 'annotations', 'types', 'fullAnswer']
    questions = pd.DataFrame.from_dict(questions, orient='index', columns=columns)
    questions = questions.reset_index()
    questions = questions.rename(columns={"index": "questionID"}, errors="raise")

    columns = ['imageID', 'question', 'program', 'questionID', 'answer']
    programs = pd.DataFrame(programs, columns=columns)

    DATA = GQADataset(questions, programs, scenegraphs, vectors=VECTORS,
                      axes=[x_axis, y_axis, w_axis, h_axis], linspace=[xs, ys, ws, hs])
    logging.info(f'Length of data set: {len(DATA)}')

    VISUALIZE = False
    DATA.set_visualize(VISUALIZE)
    DATA.set_verbose(0)

    results_lst = []
    num_correct = 0
    pbar = tqdm(range(len(DATA)), ncols=115)

    for i, IDX in enumerate(pbar):
        start = time.time()
        img, info, memory = DATA.encode_item(IDX, dim=dim)
        avg_mse, avg_iou, correct_items = DATA.decode_item(img, info, memory)

        try:
            answer = run_program(DATA, img, info, counter=1, memory=memory, dim=dim, verbose=1)
        except Exception as e:
            answer = None
            logging.error(e)

        if answer == -1:
            logging.warning(f'[{IDX}] not fully implemented!')

        time_in_sec = time.time() - start
        correct = answer == info["answer"]
        num_correct += int(correct)
        results = {'q_id': info['q_id'], 'question': info['question'], 'program': info['program'],
                   'image': info['img_id'], 'true_answer': info['answer'], 'pred_answer': answer,
                   'correct': correct, 'time': time_in_sec, 'enc_avg_mse': avg_mse,
                   'enc_avg_iou': avg_iou, 'enc_correct_items': correct_items,
                   'enc_items': len(info["encoded_items"]), 'q_idx': IDX}
        results_lst.append(results)

        logging.info(f'[{IDX+1}] {num_correct / (i+1):.2%}')
        pbar.set_postfix({'correct': f'{num_correct / (i+1):.2%}', 'q_idx': str(IDX+1)})

    results_df = pd.DataFrame(results_lst)
    logging.info(f'Accuracy: {results_df.correct.sum() / len(results_df):.2%}')

    out_path = os.path.join(DATA_PATH, f'results-DIM{DIM}-{"VAL" if TEST else "TRAIN"}{RANDOM_SEED}.pkl')
    results_df.to_pickle(out_path)
    logging.info(f'Saved results to {out_path}.')