VSA4VQA/generate_query_masks.py

301 lines
11 KiB
Python
Raw Normal View History

2024-04-29 17:18:10 +02:00
import os
import time
import json
import queue
from multiprocessing import Process, Queue
from multiprocessing.pool import Pool
from tqdm import tqdm
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
# --- file locations ---------------------------------------------------------
DATA_PATH: str = 'GQA/'  # directory that contains train_sceneGraphs.json
REL_PATH: str = 'full_relations_df.pkl'  # cached relations dataframe (pickle)
# --- processing parameters --------------------------------------------------
NUM_PROCESSES: int = 20  # number of worker processes / pool size
NUM_SAMPLES: int = 100  # max samples averaged into one relation's query mask
IMG_SIZE: tuple = (500, 500)  # default (rows, cols) canvas for bbox_to_mask
def bbox_to_mask(x, y, w, h, img_size=IMG_SIZE, name=None, visualize=False):
    """Rasterise an axis-aligned bounding box into a binary mask.

    Args:
        x, y: Top-left corner of the box as (column, row) pixel coordinates.
        w, h: Box width and height in pixels.
        img_size: Shape of the mask as (rows, cols), i.e. (height, width).
        name: Optional plot title; only used when ``visualize`` is True.
        visualize: If True, display the mask with matplotlib.

    Returns:
        A float ``np.ndarray`` of shape ``img_size`` that is 1 inside the box
        (cropped to the image bounds) and 0 elsewhere.
    """
    img = np.zeros(img_size)
    n_rows, n_cols = img_size[0], img_size[1]
    # Crop the box to the image: negative starts would otherwise wrap around
    # to the opposite edge, and an x past the right edge used to raise.
    row0, row1 = max(int(y), 0), min(int(y + h), n_rows)
    col0, col1 = max(int(x), 0), min(int(x + w), n_cols)
    # An empty slice (row1 <= row0 or col1 <= col0) simply assigns nothing.
    img[row0:row1, col0:col1] = 1
    if visualize:
        # matplotlib's figsize is (width, height) in inches -> cols first.
        plt.figure(figsize=(n_cols // 80, n_rows // 80))
        plt.imshow(img, cmap='gray')
        if name:
            plt.title(name)
        plt.axis('off')
        plt.show()
    return img
def get_all_relations_df(data):
    """Flatten all scene-graph relations into one dataframe and pickle it (single process).

    Args:
        data: Mapping of image id -> scene graph. Each graph has an ``objects``
            dict whose entries carry ``name``, ``x``, ``y``, ``w``, ``h`` and a
            ``relations`` list of ``{'name': ..., 'object': <object id>}``.

    Returns:
        The relations dataframe (one row per object/relation pair). It is also
        written to ``all_relations.pkl`` in the working directory.
    """
    print(f'Length of scenegraph data set: {len(data)}')
    start = time.time()
    columns = ['image_id', 'relation', 'from', 'to', 'obj_loc', 'obj_w', 'obj_h', 'obj_center',
               'rel_obj_loc', 'rel_obj_w', 'rel_obj_h']
    # Collect plain dicts and build the dataframe once; concatenating one row
    # at a time copies the whole frame on every append (quadratic).
    records = []
    for img_id, graph in data.items():
        objects = graph.get('objects')
        # object id -> (name, x, y, w, h), used to resolve relation targets
        boxes = {id_num: (o.get('name'), o.get('x'), o.get('y'), o.get('w'), o.get('h'))
                 for id_num, o in objects.items()}
        for obj_dict in objects.values():
            name = obj_dict.get('name')
            x, y, width, height = obj_dict.get('x'), obj_dict.get('y'), obj_dict.get('w'), obj_dict.get('h')
            center = [x + width / 2, y + height / 2]
            for relation in obj_dict.get('relations', []):
                rel = relation.get('name')
                rel_obj, rel_x, rel_y, rel_w, rel_h = boxes.get(relation.get('object'))
                # key must match the declared 'obj_center' column (was 'center')
                records.append({'image_id': img_id, 'relation': rel, 'from': name, 'to': rel_obj,
                                'obj_loc': [x, y], 'obj_w': width, 'obj_h': height, 'obj_center': center,
                                'rel_obj_loc': [rel_x, rel_y], 'rel_obj_w': rel_w, 'rel_obj_h': rel_h})
    df = pd.DataFrame(records, columns=columns)
    out_path = 'all_relations.pkl'
    df.to_pickle(out_path)
    print(f'Saved df to {out_path}')
    elapsed = time.time() - start
    # guard against an empty data set (no relations -> division by zero)
    per_relation = elapsed / len(df) if len(df) else 0.0
    print(f'Took {int(elapsed // 60)}:{int(elapsed % 60):02d} min:s for all {len(df)} relations --> {per_relation:.2f}s / relation')
    return df
def _scale_nearest(img, scale_y, scale_x, out_shape):
    """Resample ``img`` by (scale_y, scale_x) onto a zero uint8 canvas of ``out_shape``.

    Uses inverse nearest-neighbour mapping: output pixel (r, c) copies source
    pixel (floor(r / scale_y), floor(c / scale_x)). Upscaling therefore fills
    the target area instead of leaving holes between scattered pixels. Output
    pixels whose source lies outside ``img`` stay 0.
    """
    out = np.zeros(out_shape, dtype=np.uint8)
    src_rows = np.floor(np.arange(out_shape[0]) / scale_y).astype(int)
    src_cols = np.floor(np.arange(out_shape[1]) / scale_x).astype(int)
    # Source indices are non-decreasing, so the in-bounds ones form a prefix.
    n_rows = int(np.count_nonzero(src_rows < img.shape[0]))
    n_cols = int(np.count_nonzero(src_cols < img.shape[1]))
    out[:n_rows, :n_cols] = img[np.ix_(src_rows[:n_rows], src_cols[:n_cols])]
    return out


def _shift(img, d_row, d_col, out_shape):
    """Copy ``img`` onto a zero canvas of ``out_shape`` moved by whole pixels, cropping overflow.

    Pixel (r, c) of ``img`` lands at (r + d_row, c + d_col).
    """
    out = np.zeros(out_shape, dtype=img.dtype)
    src_r, dst_r = max(-d_row, 0), max(d_row, 0)
    src_c, dst_c = max(-d_col, 0), max(d_col, 0)
    n_rows = min(img.shape[0] - src_r, out_shape[0] - dst_r)
    n_cols = min(img.shape[1] - src_c, out_shape[1] - dst_c)
    if n_rows > 0 and n_cols > 0:
        out[dst_r:dst_r + n_rows, dst_c:dst_c + n_cols] = \
            img[src_r:src_r + n_rows, src_c:src_c + n_cols]
    return out


def generate_query_mask(df, rel, i, img_center=np.array([250, 250]), uni_size=np.array([50, 50]),
                        scene_graphs=None):
    """Average the subject's position relative to a normalised related object for one relation.

    Processes up to ``NUM_SAMPLES`` rows of ``df`` with ``relation == rel``.
    For each row:
    1. Scale the related ("to") object box to ``uni_size``.
    2. Move the box so its centre lies at ``img_center`` on a 500x500 canvas.
    3. Apply the same transform to the subject ("from") box and accumulate
       its binary mask.

    The mean mask is saved to ``relations/<rel>.npy``, with a ``.png`` preview.

    Samples are skipped in these cases:
    - A box starts outside the image.
    - The required scale exceeds 5, or the related box has zero size. These
      are counted as discards.
    - The scaled related object falls outside the intermediate 1000x1000 canvas.

    Args:
        df: Relations dataframe with ``relation``, ``image_id``, subject box
            (``obj_loc``, ``obj_w``, ``obj_h``) and related box
            (``rel_obj_loc``, ``rel_obj_w``, ``rel_obj_h``) columns.
        rel: Relation name to process.
        i: Index of the relation; only used as a log prefix.
        img_center: (x, y) position of the normalised object centre on the output canvas.
        uni_size: (width, height) the related object is scaled to.
        scene_graphs: Mapping image id -> scene graph providing ``width``/``height``.
            Defaults to the module-level ``data`` loaded in the ``__main__`` block.
    """
    if scene_graphs is None:
        scene_graphs = data
    out_shape = (500, 500)
    canvas_shape = (1000, 1000)
    uni_w, uni_h = int(uni_size[0]), int(uni_size[1])
    center_x, center_y = int(img_center[0]), int(img_center[1])
    # uni_obj only needed for visualization in the end
    uni_obj = bbox_to_mask(center_x - uni_w // 2, center_y - uni_h // 2, uni_w, uni_h, img_size=out_shape)
    temp_df = df.loc[df.relation == rel]
    print(f'[{i}] Number of "{rel}" samples: {len(temp_df)}')
    # Float accumulator: a uint8 one wraps around once a pixel is hit > 255 times.
    query_mask = np.zeros(out_shape, dtype=np.float64)
    counter = 0
    num_discard = 0
    columns = ['image_id', 'rel_obj_loc', 'rel_obj_w', 'rel_obj_h', 'obj_loc', 'obj_w', 'obj_h']
    for img_id, obj_loc, width, height, mask_loc, mask_w, mask_h in \
            temp_df[columns].itertuples(index=False, name=None):
        if counter >= NUM_SAMPLES:
            print(f'[{i}] Reached {counter} samples for relation "{rel}":')
            break
        graph = scene_graphs.get(img_id)
        img_size = (graph['height'], graph['width'])
        if obj_loc[0] > img_size[1] or obj_loc[1] > img_size[0] or mask_loc[0] > img_size[1] or mask_loc[1] > img_size[0]:
            # error in bounding box -- discard sample
            continue
        # A zero-sized related object would need an infinite scale factor.
        if width <= 0 or height <= 0:
            num_discard += 1
            continue
        scale_x, scale_y = uni_w / width, uni_h / height
        if scale_x > 5 or scale_y > 5:
            num_discard += 1
            continue
        obj = bbox_to_mask(obj_loc[0], obj_loc[1], width, height, img_size=img_size)
        mask = bbox_to_mask(mask_loc[0], mask_loc[1], mask_w, mask_h, img_size=img_size)
        # Encode both boxes in one image: value bit 1 = related object, bit 0 = subject.
        img = obj * 2 + mask
        img_scaled = _scale_nearest(img, scale_y, scale_x, canvas_shape)
        # Row-major order: first hit is the top-most, then left-most object pixel.
        obj_pixels = np.argwhere(img_scaled >= 2)
        if len(obj_pixels) == 0:
            # no data in transformed image -- discard sample
            continue
        new_row, new_col = int(obj_pixels[0][0]), int(obj_pixels[0][1])
        # move the scaled object's centre onto img_center
        d_row = center_y - (new_row + uni_h // 2)
        d_col = center_x - (new_col + uni_w // 2)
        img_moved = _shift(img_scaled, d_row, d_col, out_shape)
        # Subject pixels are the odd values (1 = subject only, 3 = overlap);
        # add them as a binary mask so the mean stays within [0, 1].
        query_mask += (img_moved % 2 == 1)
        counter += 1
    if counter == 0:
        print(f'[{i}] Could not generate query mask for "{rel}"')
        return
    query_mask = query_mask / counter
    rel_name = '_'.join(rel.split(' '))
    os.makedirs('relations', exist_ok=True)
    np.save(f'relations/{rel_name}.npy', query_mask)
    print(f'[{i}] Saved query mask to: relations/{rel_name}.npy')
    if num_discard > 0:
        print(f'[{i}] Discarded {num_discard} samples, because scaling was too high.')
    fig = plt.figure(figsize=(3, 3))
    plt.imshow(uni_obj * 0.1 + query_mask, cmap='gray')
    plt.title(rel)
    plt.axis('off')
    plt.savefig(f'relations/{rel_name}.png', bbox_inches='tight', dpi=300)
    # close (not just clear) so figures do not pile up across relations
    plt.close(fig)
def run_process(tasks, df):
    """Worker loop: take relation names from ``tasks`` and build a query mask for each.

    Args:
        tasks: Queue of relation names. The loop stops once no task arrives
            within the timeout.
        df: Relations dataframe, passed through to ``generate_query_mask``.

    Returns:
        True once the queue has been drained.
    """
    # Relation index is only used as a log prefix; compute the lookup list once
    # instead of re-running unique() over the whole dataframe for every task.
    relations = list(df.relation.unique())
    while True:
        try:
            # A short timeout instead of get_nowait(): objects put() by the parent
            # are flushed to the underlying pipe by a background thread, so a
            # freshly started worker can see a spuriously empty queue and quit.
            task = tasks.get(timeout=1)
        except queue.Empty:
            break
        i = relations.index(task)
        print(f'[{i}] Starting relation #{i}: {task}')
        print()
        generate_query_mask(df, task, i)
        time.sleep(.5)
    return True
# task executed in a worker process
def get_relations_task(img_id, scene_graphs=None):
    """Collect every relation of one image's scene graph as a list of row dicts.

    Args:
        img_id: Scene-graph key of the image (looked up as ``str(img_id)``).
        scene_graphs: Mapping image id -> scene graph. Defaults to the
            module-level ``data`` loaded in the ``__main__`` block.

    Returns:
        One dict per (object, relation) pair, keyed like the columns of the
        relations dataframe built in ``__main__``.
    """
    if scene_graphs is None:
        scene_graphs = data
    graph = scene_graphs.get(str(img_id))
    width, height = graph['width'], graph['height']
    objects = graph.get('objects')
    # object id -> (name, x, y, w, h); used to resolve each relation's target
    boxes = {id_num: (o.get('name'), o.get('x'), o.get('y'), o.get('w'), o.get('h'))
             for id_num, o in objects.items()}
    all_relations = []
    for obj_dict in objects.values():
        name = obj_dict.get('name')
        x, y = obj_dict.get('x'), obj_dict.get('y')
        obj_w, obj_h = obj_dict.get('w'), obj_dict.get('h')
        # Centre of the object's own box (was offset by half the *image* size).
        center = [x + obj_w / 2, y + obj_h / 2]
        for relation in obj_dict.get('relations', []):
            rel = relation.get('name')
            rel_obj, rel_x, rel_y, rel_w, rel_h = boxes.get(relation.get('object'))
            all_relations.append({'image_id': img_id, 'width': width, 'height': height, 'relation': rel,
                                  'from': name, 'to': rel_obj, 'obj_loc': [x, y], 'obj_w': obj_w, 'obj_h': obj_h,
                                  'obj_center': center, 'rel_obj_loc': [rel_x, rel_y], 'rel_obj_w': rel_w,
                                  'rel_obj_h': rel_h})
    return all_relations
if __name__ == '__main__':
    path = os.path.join(DATA_PATH, 'train_sceneGraphs.json')
    # explicit check (an assert would be stripped under `python -O`)
    if not os.path.exists(path):
        raise FileNotFoundError(f'{path} does not exist!')
    # `data` stays a module-level global: get_relations_task and
    # generate_query_mask read it from inside the worker processes.
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    print(f'Length of scenegraph data set: {len(data)}')
    if not os.path.exists(REL_PATH):
        print('Generating dataframe of all relations...')
        columns = ['image_id', 'width', 'height', 'relation', 'from', 'to', 'obj_loc', 'obj_w',
                   'obj_h', 'obj_center', 'rel_obj_loc', 'rel_obj_w', 'rel_obj_h']
        img_ids = list(data.keys())
        records = []
        # generate list of relations pkl -- use multiprocessing!
        with Pool(processes=NUM_PROCESSES) as pool:
            # imap yields results in input order as they complete, so tqdm shows
            # real progress (pool.map would block until everything is done).
            results = pool.imap(get_relations_task, img_ids, chunksize=100)
            for i, result in enumerate(tqdm(results, total=len(img_ids))):
                records.extend(result)
                if i % 10000 == 0:
                    pd.DataFrame(records, columns=columns).to_pickle('temp_' + REL_PATH)
                    print(f'Saved df to {"temp_" + REL_PATH}')
        # Build the dataframe once; concatenating per image is quadratic.
        df = pd.DataFrame(records, columns=columns)
        df.to_pickle(REL_PATH)
        print(f'Saved df to {REL_PATH}')
    else:
        df = pd.read_pickle(REL_PATH)
    relations = df.relation.unique()
    print(f'Number of relations: {len(relations)}')
    print(relations)
    print('Generating a query mask for each relation...')
    # generate query mask for each relation -- use multiprocessing
    # only use relations with more than 1000 samples
    counts = df.relation.value_counts()
    rel_lst = counts[counts > 1000].index.to_list()
    tasks = Queue()
    for rel in rel_lst:
        tasks.put(rel)
    # NUM_PROCESSES workers share the task queue until it is drained
    procs = [Process(target=run_process, args=(tasks, df)) for _ in range(NUM_PROCESSES)]
    for p in procs:
        p.start()
    # completing all processes
    for p in procs:
        p.join()