300 lines
11 KiB
Python
300 lines
11 KiB
Python
import os
|
|
import time
|
|
import json
|
|
import queue
|
|
from multiprocessing import Process, Queue
|
|
from multiprocessing.pool import Pool
|
|
from tqdm import tqdm
|
|
|
|
import pandas as pd
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
|
|
# Directory holding the GQA scene graphs; the __main__ block reads train_sceneGraphs.json from it.
DATA_PATH: str = 'GQA/'
# Pickle cache for the relations dataframe; it is only regenerated when this file is missing.
REL_PATH: str = 'full_relations_df.pkl'
# Default (rows, cols) canvas shape used by bbox_to_mask.
IMG_SIZE: tuple = (500, 500)
# Worker count for both the relation-extraction Pool and the query-mask processes.
NUM_PROCESSES: int = 20
# Upper bound on the number of samples averaged into one relation's query mask.
NUM_SAMPLES: int = 100
|
|
|
|
|
|
def bbox_to_mask(x, y, w, h, img_size=IMG_SIZE, name=None, visualize=False):
    """Rasterise an axis-aligned bounding box into a mask.

    Args:
        x, y: Top-left corner of the box as (column, row) pixel coordinates.
        w, h: Box width and height in pixels.
        img_size: ``(rows, cols)`` shape of the returned mask.
        name: Optional title shown when ``visualize`` is True.
        visualize: If True, display the mask with matplotlib.

    Returns:
        A float ``np.ndarray`` of shape ``img_size`` holding 1.0 inside the box and
        0.0 elsewhere. Any part of the box outside the canvas is clipped away.
    """
    img = np.zeros(img_size)
    rows, cols = img_size[0], img_size[1]

    # Clip the box to the canvas on all four sides. Without the lower clamp a negative
    # start would be interpreted by slicing as "from the end" and wrap the box around
    # to the opposite edge; without the upper clamp an x past the right edge would
    # produce a negative length.
    row0, row1 = max(y, 0), min(y + h, rows)
    col0, col1 = max(x, 0), min(x + w, cols)
    if row0 < row1 and col0 < col1:
        img[row0:row1, col0:col1] = 1

    if visualize:
        # matplotlib's figsize is (width, height) while img_size is (rows, cols);
        # keep both sides at least 1 inch so tiny canvases still render.
        plt.figure(figsize=(max(cols // 80, 1), max(rows // 80, 1)))
        plt.imshow(img, cmap='gray')
        if name:
            plt.title(name)
        plt.axis('off')
        plt.show()

    return img
|
|
|
|
def get_all_relations_df(data):
    """Build a dataframe with one row per (subject, relation, object) triple and pickle it.

    Args:
        data: Scene-graph dict keyed by image id. Each value has an ``'objects'`` dict
            mapping object id -> ``{'name', 'x', 'y', 'w', 'h', 'relations'}``, where each
            relation is ``{'name': <relation>, 'object': <target object id>}``.

    Returns:
        The dataframe that was written to ``all_relations.pkl`` in the current directory.
        Columns: image_id, relation, from, to, obj_loc, obj_w, obj_h, obj_center,
        rel_obj_loc, rel_obj_w, rel_obj_h.
    """
    print(f'Length of scenegraph data set: {len(data)}')
    start = time.time()

    columns = ['image_id', 'relation', 'from', 'to', 'obj_loc', 'obj_w', 'obj_h', 'obj_center',
               'rel_obj_loc', 'rel_obj_w', 'rel_obj_h']
    rows = []

    for img_id, scene in data.items():
        objects = scene.get('objects')

        # object id -> (name, x, y, w, h), used to resolve each relation's target object
        lookup = {obj_id: (o.get('name'), o.get('x'), o.get('y'), o.get('w'), o.get('h'))
                  for obj_id, o in objects.items()}

        for obj_dict in objects.values():
            name = obj_dict.get('name')
            x, y, width, height = obj_dict.get('x'), obj_dict.get('y'), obj_dict.get('w'), obj_dict.get('h')
            center = [x + width / 2, y + height / 2]

            for relation in obj_dict.get('relations') or []:
                rel_obj, rel_x, rel_y, rel_w, rel_h = lookup.get(relation.get('object'))
                # key must match the declared 'obj_center' column, otherwise pandas adds a
                # separate column and leaves obj_center empty
                rows.append({'image_id': img_id, 'relation': relation.get('name'), 'from': name,
                             'to': rel_obj, 'obj_loc': [x, y], 'obj_w': width, 'obj_h': height,
                             'obj_center': center, 'rel_obj_loc': [rel_x, rel_y],
                             'rel_obj_w': rel_w, 'rel_obj_h': rel_h})

    # Build the frame once: concatenating one-row frames inside the loop copies the whole
    # frame for every relation (quadratic in the number of relations).
    df = pd.DataFrame(rows, columns=columns)

    out_path = 'all_relations.pkl'
    df.to_pickle(out_path)
    print(f'Saved df to {out_path}')

    elapsed = time.time() - start
    # avoid ZeroDivisionError when the data contains no relations at all
    per_relation = elapsed / len(df) if len(df) else 0.0
    print(f'Took {int(elapsed // 60)}:{int(elapsed % 60):02d} min:s for all {len(df)} relations --> {per_relation:.2f}s / relation')

    return df
|
|
|
|
|
|
def _scale_nearest(src, scale_y, scale_x, out_shape):
    """Nearest-neighbour rescale of ``src`` onto a zero canvas of shape ``out_shape``.

    Output pixel ``(r, c)`` takes ``src[floor(r / scale_y), floor(c / scale_x)]``; output
    pixels whose source position lies outside ``src`` stay 0. Sampling backwards from
    the output (instead of pushing each source pixel forward) leaves no gaps when
    upscaling.
    """
    src_rows = np.floor(np.arange(out_shape[0]) / scale_y).astype(np.int64)
    src_cols = np.floor(np.arange(out_shape[1]) / scale_x).astype(np.int64)
    keep_rows = src_rows < src.shape[0]
    keep_cols = src_cols < src.shape[1]
    out = np.zeros(out_shape, dtype=src.dtype)
    out[np.ix_(keep_rows, keep_cols)] = src[np.ix_(src_rows[keep_rows], src_cols[keep_cols])]
    return out


def _shift(src, d_row, d_col, out_shape):
    """Copy ``src`` translated by (d_row, d_col) onto a zero canvas of ``out_shape``; pixels landing outside are dropped."""
    out = np.zeros(out_shape, dtype=src.dtype)
    row0, row1 = max(d_row, 0), min(src.shape[0] + d_row, out_shape[0])
    col0, col1 = max(d_col, 0), min(src.shape[1] + d_col, out_shape[1])
    if row0 < row1 and col0 < col1:
        out[row0:row1, col0:col1] = src[row0 - d_row:row1 - d_row, col0 - d_col:col1 - d_col]
    return out


def generate_query_mask(df, rel, i, img_center=np.array([250, 250]), uni_size=np.array([50, 50])):
    """Average where the subject box lies relative to a normalised, centred object box.

    For up to ``NUM_SAMPLES`` rows of relation ``rel``, the subject ("from", ``obj_*``
    columns) and related object ("to", ``rel_obj_*`` columns) boxes are rasterised on
    the image canvas. The image is rescaled so the related box becomes ``uni_size``,
    then shifted so that box is centred at ``img_center`` on a 500x500 canvas. The
    binary subject masks of all kept samples are summed and divided by the number of
    samples. The result is saved to ``relations/<rel>.npy`` plus a preview PNG.

    Samples are skipped when a box starts outside the image, when the related box is
    empty or would need more than 5x upscaling (counted as "discarded"), or when the
    rescaled related box falls off the intermediate 1000x1000 canvas.

    Args:
        df: Relations dataframe. If it has ``width``/``height`` columns they give the
            image size; otherwise the module-level ``data`` scene graphs are used.
        rel: Relation name to process.
        i: Index of the relation, used only as a log prefix.
        img_center: ``(x, y)`` position the related box is centred on in the output.
        uni_size: ``(w, h)`` the related box is scaled to.
    """
    out_shape = (500, 500)        # resolution of the saved query mask
    canvas_shape = (1000, 1000)   # intermediate canvas for the rescaled image
    uni_w, uni_h = int(uni_size[0]), int(uni_size[1])
    center_x, center_y = int(img_center[0]), int(img_center[1])

    # uni_obj only needed for visualization in the end
    uni_obj = bbox_to_mask(center_x - uni_w // 2, center_y - uni_h // 2, uni_w, uni_h, img_size=out_shape)

    temp_df = df.loc[df.relation == rel]
    print(f'[{i}] Number of "{rel}" samples: {len(temp_df)}')

    # Accumulate in float: a uint8 accumulator wraps around once the sum exceeds 255.
    query_mask = np.zeros(out_shape, dtype=np.float64)
    counter = 0
    num_discard = 0
    has_size_cols = 'width' in temp_df.columns and 'height' in temp_df.columns

    for _, sample in temp_df.iterrows():
        if counter >= NUM_SAMPLES:
            print(f'[{i}] Reached {counter} samples for relation "{rel}":')
            break

        if has_size_cols:
            img_size = (int(sample['height']), int(sample['width']))
        else:
            # fallback for frames without size columns: relies on the global scene graphs
            img_size = (data.get(sample['image_id'])['height'], data.get(sample['image_id'])['width'])

        # related ("to") object: the box that gets normalised and centred
        obj_loc = sample['rel_obj_loc']
        width, height = sample['rel_obj_w'], sample['rel_obj_h']

        # subject ("from") object: the box accumulated into the query mask
        mask_loc = sample['obj_loc']
        mask_w, mask_h = sample['obj_w'], sample['obj_h']

        if obj_loc[0] > img_size[1] or obj_loc[1] > img_size[0] or mask_loc[0] > img_size[1] or mask_loc[1] > img_size[0]:
            # error in bounding box -- discard sample
            continue

        # Check the scale before doing any pixel work; an empty box would need infinite scaling.
        if width <= 0 or height <= 0:
            num_discard += 1
            continue
        scale_x, scale_y = uni_w / width, uni_h / height
        if scale_x > 5 or scale_y > 5:
            num_discard += 1
            continue

        obj = bbox_to_mask(obj_loc[0], obj_loc[1], width, height, img_size=img_size)
        mask = bbox_to_mask(mask_loc[0], mask_loc[1], mask_w, mask_h, img_size=img_size)
        # 1 = subject only, 2 = related object only, 3 = both overlap
        img = (obj * 2 + mask).astype(np.uint8)

        img_scaled = _scale_nearest(img, scale_y, scale_x, canvas_shape)

        obj_rows, obj_cols = np.nonzero(img_scaled >= 2)
        if obj_rows.size == 0:
            # no data in transformed image -- discard sample
            continue

        # np.nonzero is row-major, so the first hit is the top-left pixel of the scaled object
        new_row = int(obj_rows[0]) + uni_h // 2
        new_col = int(obj_cols[0]) + uni_w // 2
        img_moved = _shift(img_scaled, center_y - new_row, center_x - new_col, out_shape)

        # binary subject mask (values 1 and 3) so every sample contributes equally
        query_mask += (img_moved == 1) | (img_moved == 3)
        counter += 1

    if counter > 0:
        query_mask = query_mask / counter
        rel_name = '_'.join(rel.split(' '))
        os.makedirs('relations', exist_ok=True)
        np.save(f'relations/{rel_name}.npy', query_mask)
        print(f'[{i}] Saved query mask to: relations/{rel_name}.npy')

        if num_discard > 0:
            print(f'[{i}] Discarded {num_discard} samples, because scaling was too high.')

        fig = plt.figure(figsize=(3, 3))
        plt.imshow(uni_obj * 0.1 + query_mask, cmap='gray')
        plt.title(rel)
        plt.axis('off')
        plt.savefig(f'relations/{rel_name}.png', bbox_inches='tight', dpi=300)
        # close (not just clear) the figure so repeated calls do not accumulate figures
        plt.close(fig)
    else:
        print(f'[{i}] Could not generate query mask for "{rel}"')
|
|
|
|
def run_process(tasks, df, timeout=1.0):
    """Worker loop: pull relation names from ``tasks`` and build a query mask for each.

    Args:
        tasks: Queue of relation names, filled before the workers start.
        df: Relations dataframe passed through to ``generate_query_mask``.
        timeout: Seconds to wait for a task before treating the queue as drained.
            ``multiprocessing.Queue.get_nowait`` may raise ``queue.Empty`` while the
            feeder thread is still flushing items, which could end a worker early.

    Returns:
        True once the queue is exhausted.
    """
    import traceback

    # relation -> position in df.relation.unique(); computed once instead of per task
    rel_index = {name: idx for idx, name in enumerate(df.relation.unique())}

    while True:
        try:
            task = tasks.get(timeout=timeout)
        except queue.Empty:
            break

        i = rel_index[task]
        print(f'[{i}] Starting relation #{i}: {task}')
        print()
        try:
            generate_query_mask(df, task, i)
        except Exception:
            # one failing relation should not kill the worker along with its remaining tasks
            print(f'[{i}] Failed to generate query mask for "{task}":')
            traceback.print_exc()
        time.sleep(.5)

    return True
|
|
|
|
|
|
# task executed in a worker process
def get_relations_task(img_id, scene_graphs=None):
    """Extract every relation of one image as a list of row dicts.

    Args:
        img_id: Image id; looked up as ``str(img_id)`` and stored as given.
        scene_graphs: Scene-graph dict keyed by image id. Defaults to the module-level
            ``data`` loaded in ``__main__``. Worker processes only see that global when
            they inherit the parent's globals (the 'fork' start method).

    Returns:
        One dict per relation with keys image_id, width, height (image size), relation,
        from, to, obj_loc, obj_w, obj_h, obj_center, rel_obj_loc, rel_obj_w, rel_obj_h.
    """
    if scene_graphs is None:
        scene_graphs = data
    scene = scene_graphs.get(str(img_id))
    width, height = scene['width'], scene['height']
    objects = scene.get('objects')

    # object id -> (name, x, y, w, h), used to resolve each relation's target object
    lookup = {obj_id: (o.get('name'), o.get('x'), o.get('y'), o.get('w'), o.get('h'))
              for obj_id, o in objects.items()}

    all_relations = []
    for obj_dict in objects.values():
        name = obj_dict.get('name')
        x, y, obj_w, obj_h = obj_dict.get('x'), obj_dict.get('y'), obj_dict.get('w'), obj_dict.get('h')
        # centre of the subject's own box: uses obj_w/obj_h, not the image width/height
        center = [x + obj_w / 2, y + obj_h / 2]

        for relation in obj_dict.get('relations') or []:
            rel_obj, rel_x, rel_y, rel_w, rel_h = lookup.get(relation.get('object'))
            all_relations.append({'image_id': img_id, 'width': width, 'height': height,
                                  'relation': relation.get('name'), 'from': name, 'to': rel_obj,
                                  'obj_loc': [x, y], 'obj_w': obj_w, 'obj_h': obj_h,
                                  'obj_center': center, 'rel_obj_loc': [rel_x, rel_y],
                                  'rel_obj_w': rel_w, 'rel_obj_h': rel_h})

    return all_relations
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    path = os.path.join(DATA_PATH, 'train_sceneGraphs.json')
    # raise instead of assert: asserts are stripped under `python -O`
    if not os.path.exists(path):
        raise FileNotFoundError(f'{path} does not exist!')

    # NOTE: get_relations_task and generate_query_mask read ``data`` as a module-level
    # global inside worker processes; they only see it when workers inherit the parent's
    # globals (the 'fork' start method). Under 'spawn' this block does not run in children.
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    print(f'Length of scenegraph data set: {len(data)}')

    columns = ['image_id', 'width', 'height', 'relation', 'from', 'to', 'obj_loc', 'obj_w',
               'obj_h', 'obj_center', 'rel_obj_loc', 'rel_obj_w', 'rel_obj_h']

    if not os.path.exists(REL_PATH):
        print('Generating dataframe of all relations...')
        img_ids = list(data.keys())
        rows = []

        # generate list of relations pkl -- use multiprocessing!
        with Pool(processes=NUM_PROCESSES) as pool:
            # imap (not map) so the progress bar advances while workers run; order is preserved
            results = pool.imap(get_relations_task, img_ids, chunksize=100)
            for i, result in enumerate(tqdm(results, total=len(img_ids))):
                rows.extend(result)
                if i % 10000 == 0:
                    # periodic checkpoint of everything collected so far
                    pd.DataFrame(rows, columns=columns).to_pickle('temp_' + REL_PATH)
                    print(f'Saved df to {"temp_" + REL_PATH}')

        # build the frame once; concatenating inside the loop is quadratic
        df = pd.DataFrame(rows, columns=columns)
        df.to_pickle(REL_PATH)
        print(f'Saved df to {REL_PATH}')
    else:
        df = pd.read_pickle(REL_PATH)

    relations = df.relation.unique()
    print(f'Number of relations: {len(relations)}')
    print(relations)

    print('Generating a query mask for each relation...')

    # only use relations with at least 1000 samples
    counts = df.relation.value_counts()
    rel_lst = counts[counts >= 1000].index.to_list()

    # generate_query_mask saves its .npy/.png outputs under relations/
    os.makedirs('relations', exist_ok=True)

    # generate query mask for each relation -- use multiprocessing
    tasks = Queue()
    for rel in rel_lst:
        tasks.put(rel)

    # NUM_PROCESSES workers share the queue; each pulls relations until it is drained
    procs = [Process(target=run_process, args=(tasks, df)) for _ in range(NUM_PROCESSES)]
    for p in procs:
        p.start()

    # completing all processes
    for p in procs:
        p.join()
|
|
|
|
|