neuro-symbolic-visual-dialog/clevr_utils.py

225 lines
6.9 KiB
Python
Raw Permalink Normal View History

2022-08-10 16:49:55 +02:00
"""Utilities for CLEVR-Dialog dataset generation.
Author: Satwik Kottur
"""
import copy
def pretty_print_templates(templates, verbosity=1):
"""Pretty prints templates.
Args:
templates: Templates to print
verbosity: 1 to print name and type of the templates
"""
# Verbosity 1: Name and type.
print('-'*70)
for ii in templates:
print('[Name: %s] [Type: %s]' % (ii['name'], ii['type']))
print('-'*70)
print('Total of %s templates..' % len(templates))
print('-'*70)
def pretty_print_scene_objects(scene):
"""Pretty prints scene objects.
Args:
scene: Scene graph containing list of objects
"""
for index, ii in enumerate(scene['objects']):
print_args = (index, ii['shape'], ii['color'],
ii['size'], ii['material'])
print('\t%d : %s-%s-%s-%s' % print_args)
def pretty_print_dialogs(dialogs):
"""Pretty prints generated dialogs.
Args:
dialogs: Generated dialogs to print
"""
for scene_id, dialog_datum in enumerate(dialogs):
for dialog in dialog_datum['dialogs']:
print(dialog['caption'])
for round_id, ii in enumerate(dialog['dialog']):
coref_id = dialog['graph']['history'][round_id+1]['dependence']
in_tuple = (round_id, ii['question'], str(ii['answer']),
ii['template'], str(coref_id))
print('\t[Q-%d: %s] [A: %s] [%s] [%s]' % in_tuple)
def merge_update_scene_graph(orig_graph, graph_item):
"""Merges two scene graphs into one.
Args:
orig_graph: Original scene graph
graph_item: New graph item to add to the scene graph
Returns:
graph: Deep copy of the original scene graph after merging
"""
graph = copy.deepcopy(orig_graph)
# Local alias.
objects = graph['objects']
# If not mergeable, return the same scene graph.
if not graph_item['mergeable']:
return graph
# 1. Go through each new object
# 2. Find its batch in objects
# a. If found, assert for a clash of attributes, update
# b. If novel, just add the object as is
for new_obj in graph_item['objects']:
match_found = False
obj = objects.get(new_obj['id'], None)
if obj:
# Assert for existing entries.
for attr in new_obj:
try:
assert new_obj[attr] == obj.get(attr, new_obj[attr]),\
'Some of the attributes do not match!'
except:
pdb.set_trace()
# Add additional keys.
objects[new_obj['id']].update(new_obj)
else:
# Add the new object.
objects[new_obj['id']] = new_obj
# if a relation, update it
if 'relation' in graph_item:
rel = graph_item['relation']
# update it with object 2 id
id1 = graph_item['objects'][0]['id']
id2 = graph_item['objects'][1]['id']
rel_objs = graph['relationships'][rel][id1]
rel_objs.append(id2)
graph['relationships'][rel][id1] = rel_objs
# update objects in graph
graph['objects'] = objects
return graph
def add_object_ids(scenes):
"""Adds object ids field for input scenes.
Args:
scenes: List of CLEVR scene graphs
Returns:
scenes: Adds object_id field for the objects in the scene graph inplace
"""
for scene_id, scene in enumerate(scenes['scenes']):
for obj_id, _ in enumerate(scene['objects']):
scenes['scenes'][scene_id]['objects'][obj_id]['id'] = obj_id
return scenes
def clean_object_attributes(scenes):
"""Cleans attributes for objects, keeping only attributes and id.
Args:
scenes: Scene graph to clean
Returns:
scenes: Cleaned up scene graphs inplace
"""
keys = ['shape', 'size', 'material', 'color', 'id']
for scene_id, scene in enumerate(scenes['scenes']):
for obj_id, obj in enumerate(scene['objects']):
new_obj = {key: obj[key] for key in keys}
scenes['scenes'][scene_id]['objects'][obj_id] = new_obj
return scenes
def pretty_print_corefs(dialog, coref_groups):
"""Prints coreferences for a dialog, higlighting different groups in colors.
Args:
dialog: Generated dialogs to print
coref_groups: Coreference groups for dialogs
"""
colorama.init()
# Mapping of group_id -> color_ids for (foreground, background)
color_map = {}
groups = coref_groups.get(0, [])
colored, color_map = pretty_print_coref_sentence(dialog['caption'], groups,
color_map)
print('\n\nC: %s' % colored)
for round_id, round_datum in enumerate(dialog['dialog']):
question = round_datum['question']
groups = coref_groups.get(round_id + 1, [])
colored, color_map = pretty_print_coref_sentence(question, groups,
color_map)
print('%d: %s' % (round_id, colored))
def pretty_print_coref_sentence(sentence, groups, color_map):
"""Prints a sentence containing difference coreference groups.
Args:
sentence: Text sentence
groups: List of coreference groups with spans
color_map: List of groups and associated color maps
Returns:
sentence: Text sentence with colors inserted
color_map: Updated, if new groups in the current sentence
"""
fore_colors = ['RED', 'GREEN', 'YELLOW', 'BLUE', 'MAGENTA']
back_colors = ['BLACK', 'YELLOW', 'CYAN']
insertions = []
for group in groups:
group_id = group['group_id']
if group_id in color_map:
forecolor_id, backcolor_id = color_map[group_id]
else:
num_groups = len(color_map)
forecolor_id = num_groups % len(fore_colors)
backcolor_id = num_groups // len(fore_colors)
color_map[group_id] = (forecolor_id, backcolor_id)
forecolor = fore_colors[forecolor_id]
backcolor = back_colors[backcolor_id]
insertions.append(
(group['span'][0], getattr(colorama.Fore, forecolor)))
insertions.append(
(group['span'][0], getattr(colorama.Back, backcolor)))
insertions.append((group['span'][1],
getattr(colorama.Style, 'RESET_ALL')))
# Perform insertions.
sentence = insert_into_sentence(sentence, insertions)
return sentence, color_map
def insert_into_sentence(sentence, insertions):
"""Sorts and performs insertions from right.
Args:
sentence: Sentence to perform insertions into
insertions: List of insertions, format: (position, text_insert)
Returns:
sentence: Inplace inserted sentence
"""
insertions = sorted(insertions, key=lambda x: x[0], reverse=True)
for position, text in insertions:
sentence = sentence[:position] + text + sentence[position:]
return sentence