VSA4VQA/VSA4VQA_examples.ipynb

518 lines
1.3 MiB
Text
Raw Normal View History

2024-04-29 17:18:10 +02:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "f7f10b6e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']\n"
]
}
],
"source": [
"import os \n",
"os.environ[\"OPENBLAS_NUM_THREADS\"] = \"10\"\n",
"\n",
"import json \n",
"import cv2\n",
"import re\n",
"import random\n",
"import time\n",
"from tqdm.notebook import tqdm\n",
"import pandas as pd \n",
"import numpy as np \n",
"import matplotlib\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.patches as patches\n",
"from collections import OrderedDict\n",
"from pattern.text.en import singularize, pluralize\n",
"\n",
"import torch\n",
"import clip\n",
"from PIL import Image\n",
"\n",
"import nengo.spa as spa\n",
"from utils import *\n",
"from dataset import GQADataset\n",
"\n",
"RANDOM_SEED = 17\n",
"\n",
"# seed every random number generator used in this notebook\n",
"np.random.seed(RANDOM_SEED)\n",
"torch.manual_seed(RANDOM_SEED)\n",
"random.seed(RANDOM_SEED)\n",
"\n",
"# machine-specific settings: the defaults below are the original values and can be\n",
"# overridden with environment variables instead of editing the notebook\n",
"DATA_PATH = os.environ.get('GQA_DATA_PATH', '/scratch/penzkofer/GQA')\n",
"CUDA_DEVICE = int(os.environ.get('GQA_CUDA_DEVICE', '7'))\n",
"\n",
"print(clip.available_models())\n",
"\n",
"if torch.cuda.is_available():\n",
"    torch.cuda.set_device(CUDA_DEVICE)\n",
"    device = torch.device(\"cuda:\" + str(CUDA_DEVICE))\n",
"else:\n",
"    # NOTE(review): only the CLIP model loaded here follows this fallback; the imported\n",
"    # project modules may still expect a GPU -- confirm before relying on CPU runs\n",
"    print('[WARNING] CUDA is not available, loading CLIP on the CPU')\n",
"    device = torch.device(\"cpu\")\n",
"\n",
"clip_model, preprocess = clip.load(\"ViT-B/32\", device=device)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "6c6c9d4c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Size of vector space: 10000x100x512\n",
"x-axis resolution = 100, y-axis resolution = 100\n",
"width resolution = 10, height resolution = 10\n",
"(100, 100, 10, 10, 512)\n",
"CPU times: user 7min 11s, sys: 2.2 s, total: 7min 13s\n",
"Wall time: 7min 13s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"# define SSP vector space \n",
"# NOTE(review): the stored output of this cell reports a last dimension of 512 although\n",
"# dim is 1024 here -- the output appears to come from an earlier run; re-run to refresh\n",
"res = 100\n",
"dim = 1024\n",
"new_size = (25, 25) # size should be smaller than resolution\n",
"\n",
"# sample points along the four axes (x, y, width, height)\n",
"xs = np.linspace(0, new_size[1], res)\n",
"ys = np.linspace(0, new_size[0], res)\n",
"ws = np.linspace(1, 10, 10)\n",
"hs = np.linspace(1, 10, 10)\n",
"\n",
"# one axis vector per encoded quantity, all drawn from the same seeded generator\n",
"rng = np.random.RandomState(seed=RANDOM_SEED)\n",
"x_axis = make_good_unitary(dim, rng=rng)\n",
"y_axis = make_good_unitary(dim, rng=rng)\n",
"w_axis = make_good_unitary(dim, rng=rng)\n",
"h_axis = make_good_unitary(dim, rng=rng)\n",
"\n",
"# sizes are derived from the sample arrays so the printout cannot drift from the setup\n",
"print(f'Size of vector space: {len(xs) * len(ys)}x{len(ws) * len(hs)}x{dim}')\n",
"print(f'x-axis resolution = {len(xs)}, y-axis resolution = {len(ys)}')\n",
"print(f'width resolution = {len(ws)}, height resolution = {len(hs)}')\n",
"\n",
"# precompute all vectors -- this will take some time \n",
"VECTORS = get_heatmap_vectors_multidim(xs, ys, ws, hs, x_axis, y_axis, w_axis, h_axis)\n",
"print(VECTORS.shape)"
]
},
{
"cell_type": "markdown",
"id": "6d6ef10a",
"metadata": {},
"source": [
"### Image to SSP memory "
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "dde1bbd1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"943000\n",
"0 02930152 Is the sky dark? 2354786\n",
"1 07333408 What is on the white wall? 2375429\n",
"2 07333405 Is that pipe red? 2375429\n",
"3 15736264 Is the tall clock small or large? 2368326\n",
"4 111007521 Who is wearing a shirt? 2331819\n"
]
}
],
"source": [
"TEST = False\n",
"\n",
"# questions and scene graphs come from the validation split when TEST is set and from\n",
"# the training split otherwise; the program file is the same in both cases\n",
"split = 'val' if TEST else 'train'\n",
"questions_path = f'{split}_balanced_questions.json'\n",
"programs_path = 'programs/trainval_balanced_programs.json'\n",
"scene_path = f'{split}_sceneGraphs.json'\n",
"\n",
"\n",
"def _read_json(relative_path):\n",
"    \"\"\"Parse the JSON file stored at DATA_PATH/relative_path.\"\"\"\n",
"    with open(os.path.join(DATA_PATH, relative_path), 'r') as f:\n",
"        return json.load(f)\n",
"\n",
"\n",
"# load questions, programs and scenegraphs\n",
"questions = _read_json(questions_path)\n",
"programs = _read_json(programs_path)\n",
"scenegraphs = _read_json(scene_path)\n",
"\n",
"# questions: one row per entry of the loaded dict, its key exposed as the questionID column\n",
"question_columns = ['semantic', 'entailed', 'equivalent', 'question', 'imageId', 'isBalanced', 'groups',\n",
"                    'answer', 'semanticStr', 'annotations', 'types', 'fullAnswer']\n",
"questions = (\n",
"    pd.DataFrame.from_dict(questions, orient='index', columns=question_columns)\n",
"    .reset_index()\n",
"    .rename(columns={\"index\": \"questionID\"}, errors=\"raise\")\n",
")\n",
"\n",
"program_columns = ['imageID', 'question', 'program', 'questionID', 'answer']\n",
"programs = pd.DataFrame(programs, columns=program_columns)\n",
"\n",
"DATA = GQADataset(questions, programs, scenegraphs, vectors=VECTORS,\n",
"                  path=os.path.join(DATA_PATH, 'images/images/'),\n",
"                  axes=[x_axis, y_axis, w_axis, h_axis], linspace=[xs, ys, ws, hs], visualize=True)\n",
"\n",
"print(f'Length of {\"test\" if TEST else \"train\"} data set: ', len(DATA))\n",
"\n",
"# preview the first five questions\n",
"for i in range(5):\n",
"    row = questions.iloc[i]\n",
"    print(i, row.questionID, row.question, row.imageId)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "62842206",
"metadata": {},
"outputs": [],
"source": [
"from run_programs import *\n",
"\n",
"def _query_attribute_pair(func, img, attr, results, info, img_size, dim):\n",
"    \"\"\"Query one attribute for two earlier results (shared by 'same' and 'different').\n",
"\n",
"    attr must hold two result references followed by the attribute name. Returns the\n",
"    two query_func predictions as a tuple, or None (after printing a warning) when\n",
"    fewer than three arguments are available.\n",
"    \"\"\"\n",
"    if len(attr) < 3:\n",
"        # attr[2] is needed below, so anything shorter cannot be evaluated\n",
"        print('[WARNING]'+f'{func} cannot be computed')\n",
"        return None\n",
"\n",
"    attribute = attr[2]\n",
"    query_name = f'query_{attribute.strip()}'\n",
"    first = query_func(query_name, img, [attr[0], attribute], results, info, img_size, dim)\n",
"    second = query_func(query_name, img, [attr[1], attribute], results, info, img_size, dim)\n",
"    return first, second\n",
"\n",
"\n",
"def run_program(data, img, info, counter, memory, dim, verbose=0, visualize=None):\n",
"    \"\"\"Run the program of one question on the given image and return the answer.\n",
"\n",
"    Every entry of info['program'] has the form '<target>=<function>(<arguments>)'.\n",
"    The steps are executed in order; each step is routed to a helper function and its\n",
"    result is appended to a list, so that later steps can refer to earlier ones by\n",
"    their position.\n",
"\n",
"    Args:\n",
"        data: dataset object; data.ssp_vectors and data.linspace are read here and the\n",
"            object is forwarded to the helper functions.\n",
"        img: image array; img.shape[:2] is read here.\n",
"        info: dict of the current item; the keys 'program' and 'encoded_ssps' are read here.\n",
"        counter: integer suffix of the object candidate that is currently examined.\n",
"        memory: memory returned by the dataset's encode_item, forwarded to the helpers.\n",
"        dim: vector dimensionality, forwarded to the helpers.\n",
"        verbose: 0 is silent, > 0 prints every step, > 1 also reports objects that\n",
"            select could not find.\n",
"        visualize: forwarded to the helpers. None (default) falls back to the\n",
"            notebook-level VISUALIZE flag if it is defined, otherwise to False.\n",
"\n",
"    Returns:\n",
"        'yes' or 'no' for relate (as last step), verify (as last step), exist (as last\n",
"        step), or, and, same and different; the return value of query_func or\n",
"        choose_func for query and choose steps; None if a filter failed while another\n",
"        candidate '<name>_<counter+1>' is encoded, if same/different lack arguments,\n",
"        or if the program ends without producing an answer; -1 for an unknown function.\n",
"    \"\"\"\n",
"    if visualize is None:\n",
"        # earlier versions read the global directly; keep that behaviour as the default\n",
"        visualize = globals().get('VISUALIZE', False)\n",
"\n",
"    img_size = img.shape[:2]\n",
"    program = info['program']\n",
"\n",
"    results = []\n",
"    last_func = ''        # text of the previous step, '' before the first step\n",
"    last_filter = False   # outcome of the most recent filter step\n",
"\n",
"    # NOTE(review): routing below matches substrings of the whole step text (arguments\n",
"    # included), so the order of the branches matters\n",
"    for i, step in enumerate(program):\n",
"\n",
"        last_step = i + 1 == len(program)\n",
"\n",
"        # split only at the first '=' so that the right-hand side stays intact\n",
"        _, func = step.split('=', 1)\n",
"        attr = func.split('(')[-1].split(')')[0].split(',')\n",
"\n",
"        if verbose > 0:\n",
"            print(f'{i+1}. step: \\t {func}')\n",
"\n",
"        if 'select' in func:\n",
"            obj = attr[0].strip()\n",
"            selection = select_func(data, img, obj, info, counter, memory, visualize=visualize)\n",
"            results.append(selection)\n",
"\n",
"            if selection is None and verbose > 1:\n",
"                print(f'Could not find {obj}')\n",
"\n",
"        elif 'relate' in func:\n",
"            found_rel_obj = relate_func(data, func, attr, results, info, dim, memory, visualize=visualize)\n",
"\n",
"            if found_rel_obj is None:\n",
"                results.append(None)\n",
"                if last_step:\n",
"                    return 'no'\n",
"            else:\n",
"                assert found_rel_obj in info['encoded_ssps'], f'Result of {func}: {found_rel_obj} is not encoded'\n",
"\n",
"                selected_ssp = info['encoded_ssps'][found_rel_obj]\n",
"                _, selected_pos = select_ssp(found_rel_obj, memory, info['encoded_ssps'], data.ssp_vectors, data.linspace)\n",
"                results.append([found_rel_obj, selected_ssp, selected_pos])\n",
"\n",
"                if last_step:\n",
"                    return 'yes'\n",
"\n",
"        elif 'filter' in func:\n",
"            last_filter = filter_func(func, img, attr, img_size, results, info, visualize=visualize)\n",
"\n",
"            if last_filter:\n",
"                results.append(results[-1])\n",
"            elif results[-1] is None:\n",
"                results.append(None)\n",
"            elif results[-1][0].split(\"_\")[0] + \"_\" + str(counter+1) in info['encoded_ssps']:\n",
"                # another candidate with the same name is encoded: stop without an answer\n",
"                # NOTE(review): incrementing the local counter is not visible to the caller;\n",
"                # presumably the caller should retry with counter + 1 -- confirm\n",
"                counter += 1\n",
"                return None\n",
"            else:\n",
"                # the object is kept in results, the failure is remembered in last_filter\n",
"                last_filter = False\n",
"                results.append(results[-1])\n",
"\n",
"        elif 'verify' in func:\n",
"            pred = verify_func(data, img, attr, results, info, dim, memory, verbose=verbose, visualize=visualize)\n",
"\n",
"            if 'verify_relation_name' in func or 'verify_relation_inv_name' in func:\n",
"                results.append(results[-1] if pred else None)\n",
"            else:\n",
"                results.append(pred)\n",
"\n",
"            if last_step:\n",
"                return 'yes' if pred else 'no'\n",
"\n",
"        elif 'query' in func:\n",
"            return query_func(func, img, attr, results, info, img_size, dim, verbose=verbose, visualize=visualize)\n",
"\n",
"        elif 'exist' in func:\n",
"            num = int(re.findall(r'\\d+', attr[0])[0])\n",
"            found = results[num] is not None\n",
"\n",
"            if last_step:\n",
"                # NOTE(review): unlike the intermediate case below, the final answer does not\n",
"                # consult last_filter, so an object kept after a failed filter still yields\n",
"                # 'yes' -- confirm that this is intended\n",
"                return 'yes' if found else 'no'\n",
"\n",
"            # directly after a filter step the object only counts if the filter succeeded\n",
"            results.append(bool(found and ('filter' not in last_func or last_filter)))\n",
"\n",
"        elif 'or(' in func:\n",
"            attr1 = int(re.findall(r'\\d+', attr[0])[0])\n",
"            attr2 = int(re.findall(r'\\d+', attr[1])[0])\n",
"            return 'yes' if results[attr1] or results[attr2] else 'no'\n",
"\n",
"        elif 'and(' in func:\n",
"            attr1 = int(re.findall(r'\\d+', attr[0])[0])\n",
"            attr2 = int(re.findall(r'\\d+', attr[1])[0])\n",
"            return 'yes' if results[attr1] and results[attr2] else 'no'\n",
"\n",
"        elif 'different' in func or 'same' in func:\n",
"            preds = _query_attribute_pair(func, img, attr, results, info, img_size, dim)\n",
"            if preds is None:\n",
"                return None\n",
"\n",
"            pred_attr1, pred_attr2 = preds\n",
"            # 'different' is tested first, matching the original branch order\n",
"            if 'different' in func:\n",
"                return 'yes' if pred_attr1 != pred_attr2 else 'no'\n",
"            return 'yes' if pred_attr1 == pred_attr2 else 'no'\n",
"\n",
"        elif 'choose' in func:\n",
"            return choose_func(data, img, func, attr, img_size, results, info, dim, memory, visualize=visualize)\n",
"\n",
"        else:\n",
"            print('[WARNING]'+f'{func} not implemented')\n",
"            return -1\n",
"\n",
"        last_func = func\n",
" \n",
" "
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "40084203",
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"425457\n",
"Question #16402843: \n",
"Are there any knives or glasses that are black?\n",
"[yes] Yes, the knife is black.\n",
"\n",
"Program:\n",
"0. select(knife)\n",
"1. filter([0], black)\n",
"2. exist([1])\n",
"3. select(glass)\n",
"4. filter([3], black)\n",
"5. exist([4])\n",
"6. or([2],[5])\n",
"\n",
"Average mean-squared error of 2D locations: 41.5312\n",
"Average IoU of 4D bounding boxes: 0.67\n",
"Correct items: 12 / 17\n",
"\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAgMAAAE8CAYAAABdH7KyAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy81sbWrAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9ecxsWZbdh/32PsO9NyK+8U35cqrMrKquSSSrm91ks0ULHCSRNAlbgG2QNEzDMGQLIGDKJgnYMkzLkgyBAAFZ+sOwLVuySQuCLEKkRdIymzZFstktstVV7OrqmofMrBxfvumbIuIOZ9j+48TLrrYFuBtoQn/0O0Ahs15+L76IuPees/baa60tZmY8X8/X8/V8PV/P1/P1m3bpf9lv4Pl6vp6v5+v5er6er/9y13Mw8Hw9X8/X8/V8PV+/yddzMPB8PV/P1/P1fD1fv8nXczDwfD1fz9fz9Xw9X7/J13Mw8Hw9X8/X8/V8PV+/yddzMPB8PV/P1/P1fD1fv8nXczDwfD1fz9fz9Xw9X7/J13Mw8Hw9X8/X8/V8PV+/yZf/tf6gyD/Ot/F8/WZdTj/gt/zW306ZMlodpTgSgBeQQtBCrZEIbDaRqplSZoLrSdWxz5VVKDx6dMHQnRKi0MVMLZ6yOEQLOStZJ0w6njza4kQgVc5vHXF0IuzGyPFRj0rH2z/4Hp947WWOj29TRYlmEBesVFgiAN1RYNzPBFtRrFLshnme2e1G4srY77cM/TnzlKHsYYjEYYVMMN1MTGXm7HxFWXZIOGWz2fD44iHKEV2sWDG8ZGpSrOxJeSH1z3C74mMg5oJZAS045zBzvP/eQ262lX444tOf+Qz3X7xHrJ6vffWXKSVjmrFOuHX7Nl23YtWvsbolpR4lc7RyZITtNOPI+DlxozvK7NlePmaeE9INZHNo3TDdXBCOhZPNCYMEVv2asSSmtFDnAihJjFqEy6tHDMMarQlb9sR+YK4Tjo4YMv1wRq2VikEwpnRNtQGVmbpkojmset778JKCIw6e4+MVtkRChRIqISpnfUfwjn6zIm0ntmXm1tkROWfi5px3HjzgB+99n9dfeok37r/Kd7/7bbbbbftsRMKQOT2/xTwlegeLeTanp1RLOMvM2wnxju28Q2rFaw9AKhNzqQiBVezYrDtEF0oxrCr4wG66oescnVO0GpkZVUUlICLkZYeZoxsG8HtC3/HkYeL9d68JsWPc7Ynek60QVx6vglMjuohkQUWIvTKljIYNa6n4tWPOe3zxiAg1VD68uMB0IBTP+TlQF/LUYzUSYmVYKeO8ZzPcRkSY5oy4DixAgcePH/LCvXssLhBRfKrsdjvmWtDYgTgUh7lEDJ6836MnK+btRM4LvRg+erIt1Aq1CF485IL3noyxVCP0Akth3BaIgoSRPFWEiPkJEeH9D3ckMl2v5BmcLrguYzZgc0G9EDtDiTiMbJByJRVlWBc2R4Gc271aSkLVY1Xo8BzFNR9dJlI30w9C5zzzsmO5UQiC04JYAR9AHAWjJmO6KayLJ9WEj4FcDOccotauVwDBMy+Gj0opBSuOGB1ugO00YtvMnGCsgiWjD4rVBRcU3w/4WgBBMEqt9EcDuSxoEbw5iAuCR5zHzPjuNx7+/92Lf81goK3CMDxGxBCFnCtmgh7+XcSoVbAqQPuZZ/mGXecxjJwKtYKqkHPFqTKsA9N+oWQQNUBAoBooQugcTsGsgjVUYhilGN4HVAwzww5/LkAplVLq4bWkUSDPAI2AVSjZQAwVaX9PBHWKanv/7RcJcng/IoBZewtmmEE1A4Oc6q+8+A8BJ5Efeq32xpFnb8asvejhn2b2Meiq1Q5/7fBd/hehMWm3w//3f3r2G9ufCyJ2uA7PftA+/rstgFKQH/rPP5xJKc9eTX7ldWuph9f/1b/Y/n/+T3vfz15PDj9hQCkvAA4EfHFUqSQpmKt49dRiFDPcMLAsI0PfoV4wNXyM5FwZcyI7Bf
U4t6YCFSMVR8mlvd8izDnhO4f+0IUpGPOciLNSEC63mSEW7ty7xZyNcQbnPYNPLKmSc6Uj4mMg5YnQK5oryIr1MPDk4gPKOJNywPtjkgGhUt2K6CKalDllRlH8akPRQGFh8JXdzZagPSrGsizEMJCKUTAsQNKIRzErhAhKYiEjeJw1IJCKUcxhVMyMruswM370d/12cMY3vvY1UipM25E0LNwabqNZyVnYDB1pSexvrnFuTQgdGgIaPecWySvPyfGa73/r+zAtVFdQVfo+8vTikmlbGMIK7JrN0RGh7+h6T84VnUdKDMQY2T695NbZEZysmJaEd2AOplrI+x1934MVJAmheILrqVRmy4gXxDvoCtPuBp87BhcoUelCj4lSa2bJhZPzE1546QX6EHnv0YfsrnecHJ+zPjrCdz3vvv0ujz+45Auf+AxvfPIL3L19m/14xTd++R2KTIz7HarKVa6shx7vI7kK41QZTm4TvUN3DikJDZ4MaO6QkilLIi0TF7sd3SY2oCaVoR/wEtEi7Rq5ilpHmo0YHZUFNNJ1HtGFED2dOObdFT4Elmpo9KRSiSHAknBdR3QOAVQVp7AaOnBgZLIYQSGEwFwKDiXPGYdjwagCaQFfFa/ge6iSmaaKrwNPrvc4OtQ8Jc04P2FiDJs11+OCdCNLKhx3AyfnK6ZsLBXMIqqevQWcKhJhzoaJosGjFFKdqQgiDuehpIVqiZQFnGII8yxE7xmOO+Z5JC8JcYJSqCkz5YxKwQmUlHE+gCiGpyKoq1ACuni89kRf2NeJjAfNFKvkWihWUAQ7nC2iRi0T81KxA0hwEkgpUYuSilBrYt05nA9UU0oxRAWtgpZMEsOckHMmhK6BUeepeaE6pRj4UFBTqjmkUyw4nl5cg814OlIt5FxQaGDCOVQVqpFSpY8B5xUtBS2CVsWJ4ZyRqyKqUAWVX1sD4NcFBobhEX/8j/8zLGpknrCMnq/8wkd0wzHr2wOVG9793kTNgdXQs98W7r/Y8fnf8jIfvP8h3/vGR6yPT5EoXD285PKp44//8z/FzfXX+ev/t0uObgXC0YRujymMvPSp1/nEJ9asj0aW3TULK4yJtVe+++0H3H3tJe69uGZ5OjXU6nogM883fOUr38P8Bud7Viu41Z0w5Qu6lefm0vHdbz5BcQwrx2o1cPuVI87uCX0fmKdKmhJiyq2jE4oUKoqWTOeFR9sbHn20ZdolJiustOPt716wWd8h5S1FK2YBM8O5RB89eIcYGBVx4J3QOaE6w5ljXyrRRUgFpDKZcf10hxuF2rl2mCk4J9RaERG89yjCqmsbgQg42o3pnLBad3SrQNruqTXj4+pwUy1QgBLJS2KqmdB5nEJQYUoLnR9QHM535HLTbpYgKMKTRxPL0lB3TjTUFoSsDavVmnEIUhuiE3f4mcPKOH7w9lcp5UXMwEptlWvw1AIlJYIGQuwoGARHkcpSE6UueO+RUukRgilQCSEiIoTQYIe49lBbcagLpHnEO6WPvh2ch2tKcawCXF/MuF6498I54z7TYViaMe1QV3FaMArVIFvbUFMqKDcUjgg2oNXonDDOC96AYu3QNEVNqF7RxVCDPC2o8yxlT84dYgPmEzkXgssscyV2ntkgWaG3QBWDnFBpQExyBtcxLYUlZUKM2FjJ1XjzzTexNxM/+Pab3Do7596tOzx8/BFvvPEaBUO8EGIk6Dm1Vj56/BEvnZ9BEVQ84gr7aeLhRx/yI69/mmTCds7cOt8Q8w3Lcs1I4OzkFB8GDMc8Lrz34QNqyszzjPeRUmbGMXFyumKzPsKFjqgFp+1zuV4pOZAnI6VMDIoXcK6nFqXiGGJHtQWhknYjUSPReVwUto8n3IlwvFmxTDOFyvsPbjARjs5u8dKd1xmHiZurK6b9SMmJ1dGKXGa+8d43ubW6x4eP9nR95Q/+kd/PP/iHP8ODR9egHUd9h5XKfnuDOsGZ4/LRFd4VNoOwiGDVWPJCkXagiBNcN4ArpC
mTQmJKExWlzoZEBw6KFmpaEB3IZcL5itOOkhV1Qp4dzieOjo95/8ED1psTlpIOoFoJvkNMsWKYGMF7FEhpxmtFA5R
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"1. step: \t select(knife)\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAgMAAAE8CAYAAABdH7KyAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy81sbWrAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9y69lSZbmh/3WMrO99znn3uvP8IjIjHxnVlZVq6qr2U12UWKLBCRQAAcCJEITcaSBJoIGGmgoQCMJ0Ez/gAYCIUDQVBAnFIEWiS5B7CpWdb0rszIiM97hHu5+X+ecvbeZraXB2h41rUECHJQb4PnwiHvvuefYNvvWt77vW+Luztv1dr1db9fb9Xa9Xf9gl/63/QLerrfr7Xq73q636+36b3e9BQNv19v1dr1db9fb9Q98vQUDb9fb9Xa9XW/X2/UPfL0FA2/X2/V2vV1v19v1D3y9BQNv19v1dr1db9fb9Q98vQUDb9fb9Xa9XW/X2/UPfL0FA2/X2/V2vV1v19v1D3y9BQNv19v1dr1db9fb9Q985b/vv/jP/8W73C/OYV/49jPhvXef8ad//oIn7z3k2z9s/O2ffMYXvxr49neu+JM/+gykUHYr8B4ynFG75ye/9TsM4wu++OSa+7Uz3wjXz2f+o//4H3N3+xX/8v/5nItvGa1l6MZYEn7qyPiQH/2j9/jgu0arZ0SUbieub1ZOZ+eDD56xywM5N9SFZYabZeZXn33KctPwtZAeZC53grvQXTguM19/PsM5c7gaSVNhOXaUxLP3rrh8pEhZ0ayog7hS0sSQBiQrqTjqnbo2Zruh1sqrrzpffHrifO6QjFyUJEpSoQN0YcyZNAmGkbxQGfC1I15JyZCSOZ0r89koDLRaIVXGccBNcIQsiVYXhM7VxY6Lg2JVsJ4pKZFKR7STckEkIVaxBuiADpVpB2Yj8yJ0W+jSGCShZcQEfG64CFoKu5SZa0NVSLkhSfniyzvcBEkOLiSMlDLLUvEOQ8qoKktfERXIhb5W3BpguCtusM6N/bTj4nHG5o5YopvSkNiZ0sniYJkxKbt9xiS+j+rI2pXqhtrC3d3MNFwwDE5OTm8Z7+A4zYyuxjzD6X5BDKYh8/DRhOaE6Mg0TvQVrq9fUSZ49u57lPFAAjQ1kIbXhFgiT4luHVoGV7qfWeuZ0+mM5M5az+S0p3fF6hkZlLzboU1Z71fmunBxNSI+Q7rgcLjk5u41ZhNjAcxJNLwJ3mdqW6ijAAIipFIYuoF3XDopKZD47NMX3N0b43TBT376U97/1rsMnvmLP/1zem+4NHyEx0+fMo579tMB7EhtE0rjYhfv/XFZSN5Ia+VOTvQlcX/9NevaYNjRPSF2YLm/plwJDy4esJPCbtxz7pW5rvhqOEIFenNu716xm3aIrdAWyjCw+kpiYBycYbzCzHBxLHWWdsIZEF+hdrJn5tl48eqIa2Z3KPH9aiYBPjj7IXMxDoxjIaXEfF5gEPbTAKlwavCrzz/G+8oPv/0dBtnz8uXX3N19zfk807uQClxc7dA0oGYgStkdSENCvNHnBTOY+0JvlSQFlUTrK0tvOIkpj1zsR1Lu8TuZxM9fT+QMY1aSQ2dBRFApiAi9nnEXhmmH5DN5Gnj9deOzT25JeWA+nSkp0THKLpFVSApDKkgXFGGYlKV1JB84iJMPytrPpJ4QBCvGl9fXuOwolnn0yMEqbZ7AC2Vwpp0wr2cudk8BWNYOOoJn6PDq1dc8e/aMqoUBIVXjeDyyuqHDCJIQFFKjJKUvC3IxshwXrK+MAqkozRtm4CYkEtKNlBIVpwukDDZ31sVgMFwWrAKikCrnufPy+ownI6ngzUi5IlnwlsGcPEBOSiJRMnSHZTVqF7Q4hwshJYnX4RaXnSuZxCHtOM5w9JW8i/vI2spyNAyJ85qOpIRLxgDrznLfGWvCrJFypjuoarxsdYbiCJm1QioSe6QrpSTSDo7LjN83luqcTfAKYx
bwimYhTzuyG7LdzWbGeDHRrSIdsicoNc5+jSv+53/5/NcHBqoI7z3d8/TBA4bdTC8zw77z8qtbzIxPvmyMD1d+8fGZzoFcGjrsWedbMsLv//u/zTJ/zd/8zXNub4TdNPLy5ZFvf/8Z73xrxx/+wUs0J3x1sitpcG7P8OjRQ37vn3+bp49njjdHtGZUG6fVOB4LP/reD3kwnZhr5bh2uje8d158/pL7lx3XgWlf2GmimyMiiMH8dabfF6ad4lTquvLw6RXvfGvP/tKBlT4LtmRymkhDYRoSRRuVE2ttrOfO/XFlWRe++OSO668rKiOSB+hKM0HHjPkKCZIkam9YU4ahYGsnpRXPnZQTa20k7+ymxFoX5nklM9ANmoEIIE71BkkpQ6FJYt42sqROp1M8Yc0YiyPSGUelykrSEWPkfp4RXdA80OYKPsJQSCIkgVUMKPTFWdSR3EkpDp/WO7423KGkAqKYGdU6pgIKszekOwIkTbS6ouJoUtyFZtARnIRLIvfEKkaVAFFJEtYJQDIUaltJ44jG3YtIptbOTIWUUAqqYEB3wTuY9biozenuJFWUeNgdp/XOslSyOL3F/x+HwsWjHQ7MK7gquwx0p5qRfCCnHe4dyYaq4z4wDjuYX3G/3GHdUN1jkvDUcRlJuZB62l6zINMeywWqMSbnfDyiFJJCrSsljzRLdBzPULWQUMDI2VEalQYkkifcNd5TTziGuzMMA+7OP/n9f4ok+Ms///P4+cczdbfyZPeU1BOtCRe7gXVZOd8f0XSglBHNBR0yj32k7RMPrg58+DcfwrJgycgijGPh9esb5vvOLu+AWy4uLynTSBoTrRm6zrSSKTlxur7l0cMDfjmx1EoCTJ2zNdp8ZhgGxA1tMFoh6UinU+noIKgollZqvWdvB3a5QErkPBAxqgZJefj0MY8fP2KpCy9evYQOl5cPeJAL9/cnPv3Vx6wPG+9868Buv+Ow/01KET775CtefP2Kaid66yxALoWpjKSkLKuTpiv2pTCsJ/p6RpNgIqw2IM1otdLXldtXZ8qUySWDwFhGshSSOXTHU0cZqCuUknCpoIVhSIhWypAZJbGe7sg5s7qjQ6Z1o+QMtZGGgSElhLhskgj73YjUFafRgKJOkcLSOypCW6LgqeKYOG2FbEpJkAYwKsvsJNvx8uZEYkQ8YW1Gs+PAdNhzd16Q4cTajKtxx4PHe+bmrCbAAJKY6SQVKMLSHVdFNSPeqVZxBJG4JHtbcO+sXUAVc8GaUobEmJVlPeEWxYiY0ZaVujZUDHPHzQMEaAIEV6KIq5lkQxRJvrJaY3XFxRGMbrLx48KbLF5Rx61SW6dbxrwjLlhzeofW4xycBiHnjLti292ipmjvdByS0KxR8kDvRk4J7w1LggEpd9Tj62VQvCSub+4xm8kMVHNa63GOpijsVATMac0YSyblOH+1C+6KipPU6a6Aggkib2DDrwkMXF085J1Hjf3YWFPi9vYVjx4m/s3ffsHrV1dcPH2M+R039wsMlTKNnO6Mb3174qe//T6ffvQFH/7NCw5XD5km5/r5LYmR3/8XP+TDn/8Vn3zc2F1lfDQ4ZXoTfvpb3+WDDyZ2uxPXL09UGVFdkdZ5/tU9737vPfRwzevbGlWETiRfeXn9NV++vIVyYBgG9rvMPhW6nEASXz1fefVypuREzoXLqwNPvj1x+TAhoqxzw6qQJXN5MZFLogtgFcep1Xnx8sTd9cJcG0kS6zkxlD0i0FhxwAyWZWYogpIxATTRusPaGZJAMjQl5u6kskOb4xi7w4FaT/RTxUuimaMKSeMyI0EXWHontWAfEENR1tYDiUpi3Bfa8Yw6qAbylpS2m1OYGJm90d2RJpQkuHSGFOBAU6b1c/wyKhSBwzhQqyHm9BoPCkUQlXhtZmgK0GWtxWt2iOTreNjdnC5C7Y53o1vHc8a6Q2sULWjJdHHISqOzmmFWyZpI5uxc4/sK5FxICVKO1yBYFNIkvDmtdhQh5xQPOY
J7oiBQO+upc3hUOFxcUFdDXSm9ImSklHh/PQ5L8w7u9O7QzqhNaBtIrmQ1Wm3QHXVHJTG6kk0gKaoggC0VTUq1md4
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2. step: \t filter([0], black)\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAgMAAAE8CAYAAABdH7KyAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy81sbWrAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9y69lSZbmh/3WMrO99znn3uvP8IjIjHxnVlZVq6qr2U12UWKLBCRQAAcCJEITcaSBJoIGGmgoQCMJ0Ez/gAYCIUDQVBAnFIEWiS5B7CpWdb0rszIiM97hHu5+X+ecvbeZraXB2h41rUECHJQb4PnwiHvvuefYNvvWt77vW+Luztv1dr1db9fb9Xa9Xf9gl/63/QLerrfr7Xq73q636+36b3e9BQNv19v1dr1db9fb9Q98vQUDb9fb9Xa9XW/X2/UPfL0FA2/X2/V2vV1v19v1D3y9BQNv19v1dr1db9fb9Q98vQUDb9fb9Xa9XW/X2/UPfL0FA2/X2/V2vV1v19v1D3y9BQNv19v1dr1db9fb9Q985b/vv/jP/8W73C/OYV/49jPhvXef8ad//oIn7z3k2z9s/O2ffMYXvxr49neu+JM/+gykUHYr8B4ynFG75ye/9TsM4wu++OSa+7Uz3wjXz2f+o//4H3N3+xX/8v/5nItvGa1l6MZYEn7qyPiQH/2j9/jgu0arZ0SUbieub1ZOZ+eDD56xywM5N9SFZYabZeZXn33KctPwtZAeZC53grvQXTguM19/PsM5c7gaSVNhOXaUxLP3rrh8pEhZ0ayog7hS0sSQBiQrqTjqnbo2Zruh1sqrrzpffHrifO6QjFyUJEpSoQN0YcyZNAmGkbxQGfC1I15JyZCSOZ0r89koDLRaIVXGccBNcIQsiVYXhM7VxY6Lg2JVsJ4pKZFKR7STckEkIVaxBuiADpVpB2Yj8yJ0W+jSGCShZcQEfG64CFoKu5SZa0NVSLkhSfniyzvcBEkOLiSMlDLLUvEOQ8qoKktfERXIhb5W3BpguCtusM6N/bTj4nHG5o5YopvSkNiZ0sniYJkxKbt9xiS+j+rI2pXqhtrC3d3MNFwwDE5OTm8Z7+A4zYyuxjzD6X5BDKYh8/DRhOaE6Mg0TvQVrq9fUSZ49u57lPFAAjQ1kIbXhFgiT4luHVoGV7qfWeuZ0+mM5M5az+S0p3fF6hkZlLzboU1Z71fmunBxNSI+Q7rgcLjk5u41ZhNjAcxJNLwJ3mdqW6ijAAIipFIYuoF3XDopKZD47NMX3N0b43TBT376U97/1rsMnvmLP/1zem+4NHyEx0+fMo579tMB7EhtE0rjYhfv/XFZSN5Ia+VOTvQlcX/9NevaYNjRPSF2YLm/plwJDy4esJPCbtxz7pW5rvhqOEIFenNu716xm3aIrdAWyjCw+kpiYBycYbzCzHBxLHWWdsIZEF+hdrJn5tl48eqIa2Z3KPH9aiYBPjj7IXMxDoxjIaXEfF5gEPbTAKlwavCrzz/G+8oPv/0dBtnz8uXX3N19zfk807uQClxc7dA0oGYgStkdSENCvNHnBTOY+0JvlSQFlUTrK0tvOIkpj1zsR1Lu8TuZxM9fT+QMY1aSQ2dBRFApiAi9nnEXhmmH5DN5Gnj9deOzT25JeWA+nSkp0THKLpFVSApDKkgXFGGYlKV1JB84iJMPytrPpJ4QBCvGl9fXuOwolnn0yMEqbZ7AC2Vwpp0wr2cudk8BWNYOOoJn6PDq1dc8e/aMqoUBIVXjeDyyuqHDCJIQFFKjJKUvC3IxshwXrK+MAqkozRtm4CYkEtKNlBIVpwukDDZ31sVgMFwWrAKikCrnufPy+ownI6ngzUi5IlnwlsGcPEBOSiJRMnSHZTVqF7Q4hwshJYnX4RaXnSuZxCHtOM5w9JW8i/vI2spyNAyJ85qOpIRLxgDrznLfGWvCrJFypjuoarxsdYbiCJm1QioSe6QrpSTSDo7LjN83luqcTfAKYx
bwimYhTzuyG7LdzWbGeDHRrSIdsicoNc5+jSv+53/5/NcHBqoI7z3d8/TBA4bdTC8zw77z8qtbzIxPvmyMD1d+8fGZzoFcGjrsWedbMsLv//u/zTJ/zd/8zXNub4TdNPLy5ZFvf/8Z73xrxx/+wUs0J3x1sitpcG7P8OjRQ37vn3+bp49njjdHtGZUG6fVOB4LP/reD3kwnZhr5bh2uje8d158/pL7lx3XgWlf2GmimyMiiMH8dabfF6ad4lTquvLw6RXvfGvP/tKBlT4LtmRymkhDYRoSRRuVE2ttrOfO/XFlWRe++OSO668rKiOSB+hKM0HHjPkKCZIkam9YU4ahYGsnpRXPnZQTa20k7+ymxFoX5nklM9ANmoEIIE71BkkpQ6FJYt42sqROp1M8Yc0YiyPSGUelykrSEWPkfp4RXdA80OYKPsJQSCIkgVUMKPTFWdSR3EkpDp/WO7423KGkAqKYGdU6pgIKszekOwIkTbS6ouJoUtyFZtARnIRLIvfEKkaVAFFJEtYJQDIUaltJ44jG3YtIptbOTIWUUAqqYEB3wTuY9biozenuJFWUeNgdp/XOslSyOL3F/x+HwsWjHQ7MK7gquwx0p5qRfCCnHe4dyYaq4z4wDjuYX3G/3GHdUN1jkvDUcRlJuZB62l6zINMeywWqMSbnfDyiFJJCrSsljzRLdBzPULWQUMDI2VEalQYkkifcNd5TTziGuzMMA+7OP/n9f4ok+Ms///P4+cczdbfyZPeU1BOtCRe7gXVZOd8f0XSglBHNBR0yj32k7RMPrg58+DcfwrJgycgijGPh9esb5vvOLu+AWy4uLynTSBoTrRm6zrSSKTlxur7l0cMDfjmx1EoCTJ2zNdp8ZhgGxA1tMFoh6UinU+noIKgollZqvWdvB3a5QErkPBAxqgZJefj0MY8fP2KpCy9evYQOl5cPeJAL9/cnPv3Vx6wPG+9868Buv+Ow/01KET775CtefP2Kaid66yxALoWpjKSkLKuTpiv2pTCsJ/p6RpNgIqw2IM1otdLXldtXZ8qUySWDwFhGshSSOXTHU0cZqCuUknCpoIVhSIhWypAZJbGe7sg5s7qjQ6Z1o+QMtZGGgSElhLhskgj73YjUFafRgKJOkcLSOypCW6LgqeKYOG2FbEpJkAYwKsvsJNvx8uZEYkQ8YW1Gs+PAdNhzd16Q4cTajKtxx4PHe+bmrCbAAJKY6SQVKMLSHVdFNSPeqVZxBJG4JHtbcO+sXUAVc8GaUobEmJVlPeEWxYiY0ZaVujZUDHPHzQMEaAIEV6KIq5lkQxRJvrJaY3XFxRGMbrLx48KbLF5Rx61SW6dbxrwjLlhzeofW4xycBiHnjLti292ipmjvdByS0KxR8kDvRk4J7w1LggEpd9Tj62VQvCSub+4xm8kMVHNa63GOpijsVATMac0YSyblOH+1C+6KipPU6a6Aggkib2DDrwkMXF085J1Hjf3YWFPi9vYVjx4m/s3ffsHrV1dcPH2M+R039wsMlTKNnO6Mb3174qe//T6ffvQFH/7NCw5XD5km5/r5LYmR3/8XP+TDn/8Vn3zc2F1lfDQ4ZXoTfvpb3+WDDyZ2uxPXL09UGVFdkdZ5/tU9737vPfRwzevbGlWETiRfeXn9NV++vIVyYBgG9rvMPhW6nEASXz1fefVypuREzoXLqwNPvj1x+TAhoqxzw6qQJXN5MZFLogtgFcep1Xnx8sTd9cJcG0kS6zkxlD0i0FhxwAyWZWYogpIxATTRusPaGZJAMjQl5u6kskOb4xi7w4FaT/RTxUuimaMKSeMyI0EXWHontWAfEENR1tYDiUpi3Bfa8Yw6qAbylpS2m1OYGJm90d2RJpQkuHSGFOBAU6b1c/wyKhSBwzhQqyHm9BoPCkUQlXhtZmgK0GWtxWt2iOTreNjdnC5C7Y53o1vHc8a6Q2sULWjJdHHISqOzmmFWyZpI5uxc4/sK5FxICVKO1yBYFNIkvDmtdhQh5xQPOY
J7oiBQO+upc3hUOFxcUFdDXSm9ImSklHh/PQ5L8w7u9O7QzqhNaBtIrmQ1Wm3QHXVHJTG6kk0gKaoggC0VTUq1md4
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"CLIP\n",
"[0.5888671875, 0.4111328125]\n",
"The knife is black\n",
"3. step: \t exist([1])\n",
"4. step: \t select(glass)\n",
"Could not find glass\n",
"5. step: \t filter([3], black)\n",
"6. step: \t exist([4])\n",
"7. step: \t or([2],[5])\n",
"\n",
"Question: Are there any knives or glasses that are black?\n",
"Answer: yes\n",
"Answer \"yes\" is correct\n",
"\n",
"🎊 🎊 🎊 \n",
"\n",
"Took 1.90 seconds\n"
]
}
],
"source": [
"IDX = np.random.randint(len(DATA))\n",
"VISUALIZE = True\n",
"\n",
"start = time.time()\n",
"\n",
"# print the sampled item, then encode it and check how well it decodes again\n",
"print(IDX)\n",
"DATA.print_item(IDX)\n",
"DATA.set_visualize(False)\n",
"DATA.set_verbose(1)\n",
"img, info, memory = DATA.encode_item(IDX, dim=dim)\n",
"avg_mse, avg_iou, correct_items = DATA.decode_item(img, info, memory)\n",
"print()\n",
"\n",
"# the three entries of info['scales'] rescale the image extent, box widths and box heights\n",
"scales = info['scales']\n",
"scale, w_scale, h_scale = scales[0], scales[1], scales[2]\n",
"\n",
"img_size = np.array(img.shape[:2])\n",
"fig, ax = plt.subplots(1,1)\n",
"\n",
"ax.imshow(img, interpolation='none', origin='upper', vmin=0, vmax=1, cmap='gray',\n",
"          extent=[0, img_size[1] / scale, img_size[0] / scale, 0])\n",
"\n",
"# one labelled rectangle per encoded item; the label is the key up to the first underscore\n",
"for i, (item_key, (x, y, width, height)) in enumerate(info['encoded_items'].items()):\n",
"    label = item_key.split('_')[0]\n",
"    box_width = (width * w_scale) / scale\n",
"    box_height = (height * h_scale) / scale\n",
"    box = patches.Rectangle((x, y), box_width, box_height,\n",
"                            linewidth=2, label=label,\n",
"                            edgecolor=RGB_COLORS[i], facecolor='none')\n",
"    ax.add_patch(box)\n",
"    plt.text(x+0.2, y+1, label, fontsize='medium', color=RGB_COLORS[i])\n",
"\n",
"plt.axis('off')\n",
"plt.show()\n",
"\n",
"# execute the question's program and compare the result with the ground-truth answer\n",
"answer = run_program(DATA, img, info, counter=1, memory=memory, dim=dim, verbose=2)\n",
"\n",
"time_in_sec = time.time() - start\n",
"correct = answer == info[\"answer\"]\n",
"verdict = \"correct\" if correct else \"incorrect\"\n",
"\n",
"print()\n",
"print(f'Question: {info[\"question\"]}')\n",
"print(f'Answer: {answer}')\n",
"print(f'Answer \"{answer}\" is {verdict}\\n')\n",
"if correct:\n",
"    print('\\U0001F38A \\U0001F38A \\U0001F38A \\n')\n",
"\n",
"print(f'Took {time_in_sec:.2f} seconds')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "187f1a83",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "ssp_env",
"language": "python",
"name": "ssp_env"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
}
},
"nbformat": 4,
"nbformat_minor": 5
}