make code public

This commit is contained in:
Adnen Abdessaied 2022-08-10 16:49:55 +02:00
commit 9d8b93db26
26 changed files with 11937 additions and 0 deletions

148
README.md Normal file

@@ -0,0 +1,148 @@
# NSVD
This repository contains the official code of the paper:
## Neuro-Symbolic Visual Dialog [[PDF](TODO)]
[Adnen Abdessaied](https://adnenabdessaied.de), [Mihai Bace](https://perceptualui.org/people/bace/), [Andreas Bulling](https://perceptualui.org/people/bulling/)
**Oral Presentation / Poster**
International Conference on Computational Linguistics (COLING), 2022 / Gyeongju, Republic of Korea.
If you find our code useful or use it in your own projects, please cite our paper:
``TODO``
# Abstract
We propose Neuro-Symbolic Visual Dialog (NSVD), the first method to combine deep learning and symbolic program execution for multi-round visually-grounded reasoning. NSVD significantly outperforms existing purely-connectionist methods on two key challenges inherent to visual dialog: long-distance co-reference resolution and vanishing question-answering performance. We demonstrate the latter by proposing a more realistic and stricter evaluation scheme in which we use predicted answers for the full dialog history when calculating accuracy. We describe two variants of our model and show that, using this new scheme, our best model achieves an accuracy of 99.72% on CLEVR-Dialog, a relative improvement of more than 10% over the state of the art, while only requiring a fraction of training data. Moreover, we demonstrate that our neuro-symbolic models have a higher mean first failure round, are more robust against incomplete dialog histories, and generalise better not only to dialogs that are up to three times longer than those seen during training but also to unseen question types and scenes.
# Method
<figure>
<p align="center"><img src="misc/method_overview.png" alt="missing"/></p>
<figcaption>Overview of our method NSVD.</figcaption>
</figure>
<figure>
<p align="center"><img src="misc/method_smaller.png" alt="missing"/></p>
<figcaption>Overview of concat and stack encoders.</figcaption>
</figure>
# Requirements
- PyTorch 1.3.1
- Python 3.6
- Ubuntu 18.04
# Raw Data
## Scene Data
We used CLEVR and Minecraft images in this project. The raw images have a large memory footprint, so we do not upload them here. However, we provide their JSON files as well as their derendered versions. They can be found in:
- ``data/scenes/raw``
- ``data/scenes/derendered``
## Dialog Data
The dialog data we used can be found in ``data/dialogs``.
You can also create your own data using the ``generate_dataset.py`` script.
# Preprocessing
## Scenes
The derendered scenes do not need any further preprocessing and can be directly used with our neuro-symbolic executor.
## Dialogs
To preprocess the dialogs, follow these steps:
- ``cd preprocess_dialogs``
For the stack encoder, execute
- ``python preprocess.py --input_dialogs_json <path_to_raw_dialog_file> --input_vocab_json '' --output_vocab_json <path_where_to_save_the_vocab> --output_h5_file <path_of_the_output_file> --split <train/val/test> --mode stack``
For the concat encoder, execute
- ``python preprocess.py --input_dialogs_json <path_to_raw_dialog_file> --input_vocab_json '' --output_vocab_json <path_where_to_save_the_vocab> --output_h5_file <path_of_the_output_file> --split <train/val/test> --mode concat``
# Training
First, change directory
- ``cd ../prog_generator``
## Caption Program Parser
To train the caption parser, execute
- ``python train_caption_parser.py --mode train --run_dir <experiment_dir> --res_path <path_to_store_results> --dataPathTr <path_to_preprocessed_training_data> --dataPathVal <path_to_preprocessed_val_data> --dataPathTest <path_to_preprocessed_test_data> --vocab_path <path_where_to_save_the_vocab>``
## Question Program Parser
To train the question program parser with the stack encoder, execute
- ``python train_question_parser.py --mode train --run_dir <experiment_dir> --text_log_dir <log_dir_path> --dataPathTr <path_to_preprocessed_training_data> --dataPathVal <path_to_preprocessed_val_data> --dataPathTest <path_to_preprocessed_test_data> --scenePath <path_to_derendered_scenes> --vocab_path <path_where_to_save_the_vocab> --encoder_type 2``
To train the question program parser with the concat encoder, execute
- ``python train_question_parser.py --mode train --run_dir <experiment_dir> --text_log_dir <log_dir_path> --dataPathTr <path_to_preprocessed_training_data> --dataPathVal <path_to_preprocessed_val_data> --dataPathTest <path_to_preprocessed_test_data> --scenePath <path_to_derendered_scenes> --vocab_path <path_where_to_save_the_vocab> --encoder_type 1``
## Baselines
- [MAC-XXX](https://github.com/ahmedshah1494/clevr-dialog-mac-net/tree/dialog-macnet)
- [HCN](https://github.com/jojonki/Hybrid-Code-Networks)
# Evaluation
To evaluate using the *Hist+GT* scheme, execute
- ``python train_question_parser.py --mode test_with_gt --run_dir <experiment_dir> --text_log_dir <log_dir_path> --dataPathTr <path_to_preprocessed_training_data> --dataPathVal <path_to_preprocessed_val_data> --dataPathTest <path_to_preprocessed_test_data> --scenePath <path_to_derendered_scenes> --vocab_path <path_where_to_save_the_vocab> --encoder_type <1/2> --questionNetPath <path_to_pretrained_question_parser> --captionNetPath <path_to_pretrained_caption_parser> --dialogLen <total_number_of_dialog_rounds> --last_n_rounds <number_of_last_rounds_to_consider_in_history>``
To evaluate using the *Hist+Pred* scheme, execute
- ``python train_question_parser.py --mode test_with_pred --run_dir <experiment_dir> --text_log_dir <log_dir_path> --dataPathTr <path_to_preprocessed_training_data> --dataPathVal <path_to_preprocessed_val_data> --dataPathTest <path_to_preprocessed_test_data> --scenePath <path_to_derendered_scenes> --vocab_path <path_where_to_save_the_vocab> --encoder_type <1/2> --questionNetPath <path_to_pretrained_question_parser> --captionNetPath <path_to_pretrained_caption_parser> --dialogLen <total_number_of_dialog_rounds> --last_n_rounds <number_of_last_rounds_to_consider_in_history>``
# Results
We achieve new state-of-the-art performance on CLEVR-Dialog.
## Hist+GT
| <center>Model</center> | <center>Accuracy</center> | <center>NFFR</center> |
| :---: | :---: | :---: |
| MAC-CQ | 97.34 | 0.92 |
| + CAA | 97.87 | 0.94 |
| + MTM | 97.58 | 0.92 |
| HCN | 75.88 | 0.34 |
| **NSVD-concat (Ours)** | 99.59 | 0.98 |
| **NSVD-stack (Ours)** | **99.72** | **0.99** |
## Hist+Pred
| <center>Model</center> | <center>Accuracy</center> | <center>NFFR</center> |
| :---: | :---: | :---: |
| MAC-CQ | 41.10 | 0.15 |
| + CAA | 89.39 | 0.75 |
| + MTM | 70.39 | 0.46 |
| HCN | 74.42 | 0.32 |
| **NSVD-concat (Ours)** | 99.59 | 0.98 |
| **NSVD-stack (Ours)** | **99.72** | **0.99** |
We refer to our paper for the results of the other experiments.
# Acknowledgements
We thank [Ahmed Shah](https://www.linkedin.com/in/mahmedshah/) for his MAC-XXX implementation, [Junki Ohmura](https://www.linkedin.com/in/junki/) for his HCN implementation, [Jiayuan Mao](https://jiayuanm.com/) for providing us with the Minecraft images, and finally [Satwik Kottur](https://satwikkottur.github.io/) for his CLEVR-Dialog [codebase](https://github.com/satwikkottur/clevr-dialog).
# Contributors
- [Adnen Abdessaied](https://adnenabdessaied.de)
For any questions or enquiries, do not hesitate to contact the above contributor.

224
clevr_utils.py Normal file

@@ -0,0 +1,224 @@
"""Utilities for CLEVR-Dialog dataset generation.
Author: Satwik Kottur
"""
import copy
import pdb

import colorama
def pretty_print_templates(templates, verbosity=1):
"""Pretty prints templates.
Args:
templates: Templates to print
verbosity: 1 to print name and type of the templates
"""
# Verbosity 1: Name and type.
print('-'*70)
for ii in templates:
print('[Name: %s] [Type: %s]' % (ii['name'], ii['type']))
print('-'*70)
print('Total of %s templates..' % len(templates))
print('-'*70)
def pretty_print_scene_objects(scene):
"""Pretty prints scene objects.
Args:
scene: Scene graph containing list of objects
"""
for index, ii in enumerate(scene['objects']):
print_args = (index, ii['shape'], ii['color'],
ii['size'], ii['material'])
print('\t%d : %s-%s-%s-%s' % print_args)
def pretty_print_dialogs(dialogs):
"""Pretty prints generated dialogs.
Args:
dialogs: Generated dialogs to print
"""
for scene_id, dialog_datum in enumerate(dialogs):
for dialog in dialog_datum['dialogs']:
print(dialog['caption'])
for round_id, ii in enumerate(dialog['dialog']):
coref_id = dialog['graph']['history'][round_id+1]['dependence']
in_tuple = (round_id, ii['question'], str(ii['answer']),
ii['template'], str(coref_id))
print('\t[Q-%d: %s] [A: %s] [%s] [%s]' % in_tuple)
def merge_update_scene_graph(orig_graph, graph_item):
"""Merges two scene graphs into one.
Args:
orig_graph: Original scene graph
graph_item: New graph item to add to the scene graph
Returns:
graph: Deep copy of the original scene graph after merging
"""
graph = copy.deepcopy(orig_graph)
# Local alias.
objects = graph['objects']
# If not mergeable, return the same scene graph.
if not graph_item['mergeable']:
return graph
# 1. Go through each new object
# 2. Find its match in objects
# a. If found, assert for a clash of attributes, update
# b. If novel, just add the object as is
for new_obj in graph_item['objects']:
match_found = False
obj = objects.get(new_obj['id'], None)
if obj:
# Assert for existing entries.
for attr in new_obj:
try:
assert new_obj[attr] == obj.get(attr, new_obj[attr]),\
'Some of the attributes do not match!'
except AssertionError:
pdb.set_trace()
# Add additional keys.
objects[new_obj['id']].update(new_obj)
else:
# Add the new object.
objects[new_obj['id']] = new_obj
# if a relation, update it
if 'relation' in graph_item:
rel = graph_item['relation']
# update it with object 2 id
id1 = graph_item['objects'][0]['id']
id2 = graph_item['objects'][1]['id']
rel_objs = graph['relationships'][rel][id1]
rel_objs.append(id2)
graph['relationships'][rel][id1] = rel_objs
# update objects in graph
graph['objects'] = objects
return graph
def add_object_ids(scenes):
"""Adds object ids field for input scenes.
Args:
scenes: List of CLEVR scene graphs
Returns:
scenes: Adds an id field for the objects in the scene graph inplace
"""
for scene_id, scene in enumerate(scenes['scenes']):
for obj_id, _ in enumerate(scene['objects']):
scenes['scenes'][scene_id]['objects'][obj_id]['id'] = obj_id
return scenes
def clean_object_attributes(scenes):
"""Cleans attributes for objects, keeping only attributes and id.
Args:
scenes: Scene graph to clean
Returns:
scenes: Cleaned up scene graphs inplace
"""
keys = ['shape', 'size', 'material', 'color', 'id']
for scene_id, scene in enumerate(scenes['scenes']):
for obj_id, obj in enumerate(scene['objects']):
new_obj = {key: obj[key] for key in keys}
scenes['scenes'][scene_id]['objects'][obj_id] = new_obj
return scenes
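The attribute cleanup above is a plain dictionary filter over a whitelist of keys; a minimal self-contained sketch (the scene dict below is hypothetical example data, not taken from the dataset) behaves like this:

```python
# Hedged sketch of the attribute filter used in clean_object_attributes.
keys = ['shape', 'size', 'material', 'color', 'id']

scene = {'objects': [
    {'shape': 'cube', 'size': 'large', 'material': 'metal',
     'color': 'red', 'id': 0, 'rotation': 42.0, 'pixel_coords': [10, 20]},
]}

for obj_id, obj in enumerate(scene['objects']):
    # Keep only the whitelisted keys; extras like 'rotation' are dropped.
    scene['objects'][obj_id] = {key: obj[key] for key in keys}

print(scene['objects'][0])
```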
def pretty_print_corefs(dialog, coref_groups):
"""Prints coreferences for a dialog, higlighting different groups in colors.
Args:
dialog: Generated dialogs to print
coref_groups: Coreference groups for dialogs
"""
colorama.init()
# Mapping of group_id -> color_ids for (foreground, background)
color_map = {}
groups = coref_groups.get(0, [])
colored, color_map = pretty_print_coref_sentence(dialog['caption'], groups,
color_map)
print('\n\nC: %s' % colored)
for round_id, round_datum in enumerate(dialog['dialog']):
question = round_datum['question']
groups = coref_groups.get(round_id + 1, [])
colored, color_map = pretty_print_coref_sentence(question, groups,
color_map)
print('%d: %s' % (round_id, colored))
def pretty_print_coref_sentence(sentence, groups, color_map):
"""Prints a sentence containing difference coreference groups.
Args:
sentence: Text sentence
groups: List of coreference groups with spans
color_map: List of groups and associated color maps
Returns:
sentence: Text sentence with colors inserted
color_map: Updated, if new groups in the current sentence
"""
fore_colors = ['RED', 'GREEN', 'YELLOW', 'BLUE', 'MAGENTA']
back_colors = ['BLACK', 'YELLOW', 'CYAN']
insertions = []
for group in groups:
group_id = group['group_id']
if group_id in color_map:
forecolor_id, backcolor_id = color_map[group_id]
else:
num_groups = len(color_map)
forecolor_id = num_groups % len(fore_colors)
backcolor_id = num_groups // len(fore_colors)
color_map[group_id] = (forecolor_id, backcolor_id)
forecolor = fore_colors[forecolor_id]
backcolor = back_colors[backcolor_id]
insertions.append(
(group['span'][0], getattr(colorama.Fore, forecolor)))
insertions.append(
(group['span'][0], getattr(colorama.Back, backcolor)))
insertions.append((group['span'][1],
getattr(colorama.Style, 'RESET_ALL')))
# Perform insertions.
sentence = insert_into_sentence(sentence, insertions)
return sentence, color_map
def insert_into_sentence(sentence, insertions):
"""Sorts and performs insertions from right.
Args:
sentence: Sentence to perform insertions into
insertions: List of insertions, format: (position, text_insert)
Returns:
sentence: Inplace inserted sentence
"""
insertions = sorted(insertions, key=lambda x: x[0], reverse=True)
for position, text in insertions:
sentence = sentence[:position] + text + sentence[position:]
return sentence
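Sorting in reverse position order is what keeps earlier offsets valid while text is inserted; a self-contained copy of the routine, with plain-text markers in place of colorama escape codes, illustrates this:

```python
def insert_into_sentence(sentence, insertions):
    # Apply insertions right-to-left so positions computed on the
    # original string are not shifted by already-inserted text.
    insertions = sorted(insertions, key=lambda x: x[0], reverse=True)
    for position, text in insertions:
        sentence = sentence[:position] + text + sentence[position:]
    return sentence

# Wrap the span (11, 19), i.e. "red cube", in bracket markers.
marked = insert_into_sentence("There is a red cube", [(11, "["), (19, "]")])
print(marked)  # There is a [red cube]
```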

1049
constraints.py Normal file

File diff suppressed because it is too large

1055
constraints_minecraft.py Normal file

File diff suppressed because it is too large

1055
constraints_splitA.py Normal file

File diff suppressed because it is too large

1055
constraints_splitB.py Normal file

File diff suppressed because it is too large

0
executor/__init__.py Normal file

47
executor/clevr_statics.py Normal file

@@ -0,0 +1,47 @@
"""
author: Adnen Abdessaied
maintainer: "Adnen Abdessaied"
website: adnenabdessaied.de
version: 1.0.1
"""
COLORS = ["blue", "brown", "cyan", "gray", "green", "purple", "red", "yellow"]
MATERIALS = ["rubber", "metal"]
SHAPES = ["cube", "cylinder", "sphere"]
SIZES = ["large", "small"]
ATTRIBUTES_ALL = COLORS + MATERIALS + SHAPES + SIZES
ANSWER_CANDIDATES = {
# Count questions
"count-all": ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10"],
"count-other": ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"],
"count-all-group": ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10"],
"count-attribute": ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10"],
"count-attribure-group": ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10"],
"count-obj-rel-imm": ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"],
"count-obj-rel-imm2": ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"],
"count-obj-rel-early": ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"],
"count-obj-exclude-imm": ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"],
"count-obj-exclude-early": ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"],
# Existence questions
"exist-other": ["yes", "no"],
"exist-attribute": ["yes", "no"],
"exist-attribute-group": ["yes", "no"],
"exist-obj-rel-imm": ["yes", "no"],
"exist-obj-rel-imm2": ["yes", "no"],
"exist-obj-rel-early": ["yes", "no"],
"exist-obj-exclude-imm": ["yes", "no"],
"exist-obj-exclude-early": ["yes", "no"],
# Seek questions
"seek-attr-imm": ATTRIBUTES_ALL,
"seek-attr-imm2": ATTRIBUTES_ALL,
"seek-attr-early": ATTRIBUTES_ALL,
"seek-attr-sim-early": ATTRIBUTES_ALL,
"seek-attr-rel-imm": ATTRIBUTES_ALL,
"seek-attr-rel-early": ATTRIBUTES_ALL,
}


@@ -0,0 +1,44 @@
"""
author: Adnen Abdessaied
maintainer: "Adnen Abdessaied"
website: adnenabdessaied.de
version: 1.0.1
"""
CLASSES = ["pig", "cow", "sheep", "chicken", "wolf", "horse", "villager", "treeA", "treeB", "armorstand", "boat", "minecart"]
DIRECTIONS = ["facing_forward", "facing_backward", "facing_right", "facing_left"]
NATURES = ["animal", "human", "plant", "inanimated_object"]
ATTRIBUTES_ALL = CLASSES + DIRECTIONS + NATURES
ANSWER_CANDIDATES = {
# Count questions
"count-all": ["0", "1", "2", "3", "4", "5", "6"],
"count-other": ["0", "1", "2", "3", "4", "5", "6"],
"count-all-group": ["0", "1", "2", "3", "4", "5", "6"],
"count-attribute": ["0", "1", "2", "3", "4", "5", "6"],
"count-attribure-group": ["0", "1", "2", "3", "4", "5", "6"],
"count-obj-rel-imm": ["0", "1", "2", "3", "4", "5", "6"],
"count-obj-rel-imm2": ["0", "1", "2", "3", "4", "5", "6"],
"count-obj-rel-early": ["0", "1", "2", "3", "4", "5", "6"],
"count-obj-exclude-imm": ["0", "1", "2", "3", "4", "5", "6"],
"count-obj-exclude-early": ["0", "1", "2", "3", "4", "5", "6"],
# Existence questions
"exist-other": ["yes", "no"],
"exist-attribute": ["yes", "no"],
"exist-attribute-group": ["yes", "no"],
"exist-obj-rel-imm": ["yes", "no"],
"exist-obj-rel-imm2": ["yes", "no"],
"exist-obj-rel-early": ["yes", "no"],
"exist-obj-exclude-imm": ["yes", "no"],
"exist-obj-exclude-early": ["yes", "no"],
# Seek questions
"seek-attr-imm": ATTRIBUTES_ALL,
"seek-attr-imm2": ATTRIBUTES_ALL,
"seek-attr-early": ATTRIBUTES_ALL,
"seek-attr-sim-early": ATTRIBUTES_ALL,
"seek-attr-rel-imm": ATTRIBUTES_ALL,
"seek-attr-rel-early": ATTRIBUTES_ALL,
}

File diff suppressed because it is too large

952
generate_dataset.py Normal file

@@ -0,0 +1,952 @@
r"""Generates CLEVR-Dialog dataset.
Needs access to the following files:
synonyms: Contains several synonyms for each word in the question/caption.
caption templates: List of caption templates.
question templates: List of question templates.
metainfo: Meta-information related to attributes and values of CLEVR objects.
Usage:
python -u generate_dataset.py \
--scene_path="data/scenes/CLEVR_train_scenes.json" \
--num_beams=100 \
--num_workers=12 \
--save_path="data/clevr_train_raw.json"
Author: Satwik Kottur
"""
import copy
import collections
import json
import multiprocessing
import os
import random
import re
import time
from absl import flags
from absl import app
import numpy as np
from tqdm import tqdm as progressbar
import clevr_utils as utils
import global_vars as gvars
# import constraints_splitB as constraints
import constraints
FLAGS = flags.FLAGS
flags.DEFINE_string('synonym_path', '/projects/abdessaied/clevr-dialog/templates/synonyms.json',
'Path to synonyms file')
flags.DEFINE_string('metainfo_path', '/projects/abdessaied/clevr-dialog/templates/metainfo.json',
'Path to meta information file')
flags.DEFINE_string('caption_template_root', '/projects/abdessaied/clevr-dialog/templates/captions/',
'Root to folder with caption templates')
flags.DEFINE_string('question_template_root', '/projects/abdessaied/clevr-dialog/templates/questions/',
'Root to folder with question templates')
flags.DEFINE_string('scene_path',
# '/projects/abdessaied/clevr-dialog/output/result_clevr_oroginal_test.json',
'/projects/abdessaied/clevr-dataset-gen/output_finetune_20_objs_with_masks_many_attr/CLEVR_scenes.json',
'Path to CLEVR scene path json file')
flags.DEFINE_string('scene_id_file', '',
'Path to specific CLEVR scene ids to generate dialogs')
flags.DEFINE_string('save_path', '/projects/abdessaied/clevr-dialog/output/raw_data_modified/dialogs_finetune_20_objects_10_rounds.json',
'Path to save the dataset json')
flags.DEFINE_integer('num_beams', 100, 'Number of beams in dialog search')
flags.DEFINE_integer('num_workers', 64, 'Number of workers to use in search')
flags.DEFINE_integer('captions_per_image', 5, 'Number of captions per image')
flags.DEFINE_integer('num_images', -1,
'Number of images to generate dialogs. -1 for all.')
flags.DEFINE_integer('num_rounds', 10, 'Number of rounds in each dialog')
# Number of beams and distribution of question types.
# Start cutting down beams after 5th round.
# Heuristics (for round 4):
# A. count <= 2 1 <= seek <= 3 exist <= 2
# B. count + exist <= 3
# C. Independent questions <= 1
# Heuristics (for round 5):
# A. count <= 2 2 <= seek <= 4 exist <= 2
# B. count + exist <= 3
# C. Independent questions <= 1
ranges = {3: {'indep': [0, 1], 'seek': [1, 4], 'exist': [0, 1],
'count': [0, 1], 'exist+count': [0, 2]},
4: {'indep': [0, 1], 'seek': [2, 4], 'exist': [0, 1],
'count': [0, 1], 'exist+count': [0, 2]},
5: {'indep': [0, 1], 'seek': [2, 5], 'exist': [0, 2],
'count': [0, 2], 'exist+count': [0, 3]},
6: {'indep': [0, 1], 'seek': [2, 5], 'exist': [0, 2],
'count': [0, 2], 'exist+count': [0, 3]},
7: {'indep': [0, 2], 'seek': [3, 5], 'exist': [0, 2],
'count': [0, 2], 'exist+count': [0, 3]},
8: {'indep': [0, 2], 'seek': [3, 6], 'exist': [0, 3],
'count': [0, 3], 'exist+count': [0, 3]},
9: {'indep': [0, 2], 'seek': [3, 6], 'exist': [0, 3],
'count': [0, 3], 'exist+count': [0, 4]}}
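These ranges act as per-round acceptance tests on a beam's question-type histogram: each count must fall inside its [lo, hi] interval, with exist and count also bounded jointly. A hypothetical checker (the names `ROUND_RANGES` and `beam_ok` are illustrative, not from this file) could look like:

```python
# Hedged sketch of how the per-round ranges prune beams.
ROUND_RANGES = {
    5: {'indep': [0, 1], 'seek': [2, 5], 'exist': [0, 2],
        'count': [0, 2], 'exist+count': [0, 3]},
}

def beam_ok(round_id, counts):
    """True if every question-type count falls inside its [lo, hi] range."""
    counts = dict(counts)
    # The joint 'exist+count' bound is derived from the two base counts.
    counts['exist+count'] = counts.get('exist', 0) + counts.get('count', 0)
    return all(lo <= counts.get(key, 0) <= hi
               for key, (lo, hi) in ROUND_RANGES[round_id].items())

print(beam_ok(5, {'indep': 1, 'seek': 3, 'exist': 1, 'count': 1}))  # True
print(beam_ok(5, {'indep': 0, 'seek': 1, 'exist': 2, 'count': 2}))  # False: seek too low, exist+count > 3
```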
def mapping(tag):
"""Maps tag to attribute.
Args:
tag: An input tag
Returns:
tag_label: Label for the input tag
"""
return gvars.METAINFO['tag_map'][tag.replace('1', '')]
def inv_mapping(attribute, arg_id=0):
"""Inverse maps attribute to tag.
Args:
attribute: Name of the attribute
arg_id: Argument id to use. Append 1 if arg_id is 1, else nothing
Returns:
base_tag: The string for the tag
"""
base_tag = gvars.METAINFO['tag_inv_map'][attribute]
if arg_id > 0:
base_tag = base_tag[:-1] + str(arg_id) + base_tag[-1]
return base_tag
def get_tag_group(tag):
"""Gets the group id from tag string.
For example, tag string of <S> is 0, <S1> is 1.
Assumes single digit group id.
Args:
tag: Tag string
Returns:
group_id: Return extracted group id
"""
group_id = 0 if len(tag) <= 3 else int(tag[-2])
return group_id
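A quick standalone check of the single-digit assumption documented above (self-contained copy of the one-liner for illustration):

```python
def get_tag_group(tag):
    # '<S>' (length 3) is implicitly group 0; longer tags like '<S1>'
    # carry a single-digit group id as the second-to-last character.
    return 0 if len(tag) <= 3 else int(tag[-2])

assert get_tag_group('<S>') == 0
assert get_tag_group('<S1>') == 1
assert get_tag_group('<C2>') == 2
```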
def replace_attribute(text, tag, obj_group, eliminate=False):
"""Replaces the attribute tags in text using available object properties.
NOTE: If shape is to be replaced, we use 'thing' in its place.
Args:
text: The text template to perform replacement
tag: The tags to replace in the text
obj_group: Available object properties to replace with
eliminate: Eliminate the remaining attribute tags
Returns:
replaced_text: The replaced text
"""
group = get_tag_group(tag)
if mapping(tag) == 'relation':
# Actual relation tag, else position tag.
if tag == '<R>':
relation_list = gvars.METAINFO['relation_phrases'][obj_group['relation']]
relation_cand = random.choice(relation_list)
else:
relation_cand = obj_group['relation']
return text.replace(tag, relation_cand)
if mapping(tag) == 'shape':
if eliminate:
replacer = 'thing'
else:
replacer = str(obj_group['objects'][group][mapping(tag)])
# Plural forms for groups.
if obj_group.get('count', 1) > 1 or obj_group.get('use_plural', False):
replacer += 's'
elif mapping(tag) == 'count':
if eliminate:
replacer = ''
else:
replacer = str(obj_group['count'])
else:
if eliminate:
replacer = ''
else:
replacer = str(obj_group['objects'][group][mapping(tag)])
return text.replace(tag, replacer)
def realize_text_and_extract_scene(scene, template, filter_objs):
"""Samples attributes for template using filtered objects.
In addition, creates scene graph for the new information added.
Args:
scene: Current scene graph
template: Text template to use to generate questions
filter_objs: Set of objects satisfying constraints of current template
Returns:
sample: Contains the text realization and scene graph
"""
def default_list(): return collections.defaultdict(list)
graph = {'relationships': collections.defaultdict(default_list),
'counts': {}, 'exists': {}, 'history': [], 'objects': {}}
# number of inputs
n_inputs = template.get('inputs', 1)
# sample a text template
text_sample = random.choice(template['text'])
text_sample_index = template['text'].index(text_sample)
# extract attribute tags and get them into groups
tags = re.findall(r'(<[\d\w]*>)', text_sample)
tag_groups = collections.defaultdict(list)
for tag in tags:
group_id = get_tag_group(tag)
tag_groups[group_id].append(tag)
# sample a random element from filtered
arg_sample = random.choice(filter_objs)
# scene information obtained from the current round
graph_item = arg_sample['graph']
# remove tags from text not allowed by filter_objs
for arg_ind in range(n_inputs):
obj_sample = arg_sample['objects'][arg_ind]
avail_attrs = obj_sample['optional'] + obj_sample['required']
for ii in tag_groups[arg_ind][::-1]:
if mapping(ii) not in avail_attrs:
tag_groups[arg_ind].remove(ii)
text_sample = replace_attribute(
text_sample, ii, arg_sample, True)
# assert that all required attributes are present as tags
for attribute in obj_sample['required']:
required_tag = inv_mapping(attribute, arg_ind)
if required_tag not in tag_groups[arg_ind]:
print("required_tag: {}".format(required_tag))
print("template: {}".format(template))
assert required_tag in tag_groups[arg_ind], \
'A required attribute is missing in template!'
# start compiling tags to keep
tags_to_keep = [inv_mapping(ii, arg_ind)
for ii in obj_sample['required']]
# filter out those not present in text template
optional_tags = [inv_mapping(ii, arg_ind)
for ii in obj_sample['optional']]
optional_tags = [
ii for ii in optional_tags if ii in tag_groups[arg_ind]]
# if tags_to_keep is empty, sample from optional with 1:70 2:25 3:5
if len(optional_tags) > 0:
if len(tags_to_keep) > 0:
n_tags_sample = [0, 1, 2]
else:
n_tags_sample = [1, 2, 3]
n_sample = np.random.choice(n_tags_sample, 1,
p=gvars.METAINFO['probabilities'],
replace=False)
# lower cap at the length of optional
n_sample = min(n_sample[0], len(optional_tags))
if n_sample > 0:
tags_to_keep += random.sample(optional_tags, n_sample)
# now create a dictionary of placeholders with actual attribute values
for tag in tag_groups[arg_ind]:
remove = tag not in tags_to_keep
text_sample = replace_attribute(
text_sample, tag, arg_sample, remove)
# remove attributes from objects not included in tags_to_keep
if 'objects' in graph_item:
for ii in gvars.METAINFO['attributes']:
if inv_mapping(ii, arg_ind) not in tags_to_keep:
if ii in graph_item['objects'][arg_ind]:
del graph_item['objects'][arg_ind][ii]
# record the caption info
# Record info and merge scene graphs.
args = []
# if "unique-obj" == template['label']:
# print('yey')
for obj in arg_sample['objects']:
if obj is None:
continue
else:
for k in obj['required']:
arg = obj.get(k, None)
if arg is not None:
if arg not in args: # and type(arg) == str:
args.append(arg)
else:
arg = arg_sample.get(k, None)
if arg is not None and arg not in args and type(arg) == str:
args.append(arg)
arg = obj.get('attribute', None)
if arg is not None and arg not in args:
args.append(arg)
if template['label'] == 'obj-relation':
args.append(arg_sample['relation'])
if template['label'] == "count-att-no":
template['label'] = "count-att"
graph_item['round'] = 0
sample = {}
sample['template_info'] = [copy.deepcopy(template)]
sample['args'] = args
del sample['template_info'][-1]['text']
sample['template_info'][-1]['index'] = text_sample_index
sample['caption'] = text_sample
sample['template'] = template['label']
sample['dialog'] = []
# append history, update scene graph, and save the new scene graph
graph['history'].append(graph_item)
sample['graph'] = utils.merge_update_scene_graph(graph, graph_item)
return sample
def realize_question(dialog, template, filter_objs):
"""Samples attributes for template using filtered objects.
In addition, creates scene graph for the new information added.
Args:
dialog: Dialog generated so far
template: Text template to use to generate questions
filter_objs: Set of objects satisfying constraints of current template
Returns:
sample: Contains the text realization and scene graph
"""
# Number of inputs.
n_inputs = template.get('inputs', 0)
# Sample a text template.
text_sample = random.choice(template['text'])
text_sample_index = template['text'].index(text_sample)
# Extract attribute tags and get them into groups.
tags = re.findall(r'(<[\d\w]*>)', text_sample)
tag_groups = collections.defaultdict(list)
for tag in tags:
group_id = get_tag_group(tag)
tag_groups[group_id].append(tag)
# Sample a random element from filtered.
arg_sample = random.choice(filter_objs)
# Remove tags from text not allowed by filter_objs.
for arg_ind in range(n_inputs):
obj_sample = arg_sample['objects'][arg_ind]
avail_attrs = obj_sample['optional'] + obj_sample['required']
for ii in tag_groups[arg_ind][::-1]:
if mapping(ii) not in avail_attrs:
tag_groups[arg_ind].remove(ii)
text_sample = replace_attribute(
text_sample, ii, arg_sample, True)
# Assert that all required attributes are present as tags.
for attribute in obj_sample['required']:
required_tag = inv_mapping(attribute, arg_ind)
# Make an exception for <R> and <P>
if required_tag == '<R>' and '<P>' in tag_groups[arg_ind]:
continue
assert required_tag in tag_groups[arg_ind], \
'A required attribute is missing in template!'
# Start compiling tags to keep.
tags_to_keep = [inv_mapping(ii, arg_ind)
for ii in obj_sample['required']]
# Filter out those not present in text template.
optional_tags = [inv_mapping(ii, arg_ind)
for ii in obj_sample['optional']]
optional_tags = [
ii for ii in optional_tags if ii in tag_groups[arg_ind]]
# If tags_to_keep is empty, sample from optional with (1:70, 2:25, 3:5).
if len(optional_tags) > 0:
if len(tags_to_keep) > 0:
n_tags_sample = [0, 1, 2]
else:
n_tags_sample = [1, 2, 3]
n_sample = np.random.choice(n_tags_sample, 1,
p=gvars.METAINFO['probabilities'],
replace=False)
# Lower cap at the length of optional.
n_sample = min(n_sample[0], len(optional_tags))
if n_sample > 0:
tags_to_keep += random.sample(optional_tags, n_sample)
# Now create a dictionary of placeholders with actual attribute values.
for tag in tag_groups[arg_ind]:
remove = tag not in tags_to_keep
text_sample = replace_attribute(
text_sample, tag, arg_sample, remove)
# Record info and merge scene graphs.
args = []
for obj in arg_sample['objects']:
if obj is None:
continue
else:
for k in obj['required']:
arg = obj.get(k, None)
if arg is not None:
if arg not in args:
args.append(arg)
else:
arg = arg_sample.get(k, None)
if arg is not None:
args.append(arg)
arg = obj.get('attribute', None)
if arg is not None and arg not in args:
args.append(arg)
# req_att_keys = [k for obj in arg_sample['objects'] for k in obj['required'] if obj is not None]
dialog_datum = {'question': text_sample, 'answer': arg_sample['answer'],
'template': template['label'], 'args': args}
dialog['template_info'].append(template.copy())
del dialog['template_info'][-1]['text']
dialog['template_info'][-1]['index'] = text_sample_index
dialog['dialog'].append(dialog_datum)
graph_item = arg_sample['graph']
# If mergeable, add it to the objects list.
dialog['graph'] = utils.merge_update_scene_graph(
dialog['graph'], graph_item)
# If there are volatile objects in the graph item, remove them.
for obj in graph_item['objects'][::-1]:
if obj.get('volatile', False):
graph_item['objects'].remove(obj)
dialog['graph']['history'].append(graph_item)
return dialog
def clean_text_subroutine(text, thing, suffix):
"""Cleans the text and substitutes thing with object (subroutine).
Args:
text: Text string to be cleaned
thing: Whether to use 'thing' or 'object'
suffix: Either '?' (question) or '.' (caption)
Returns:
clean_text: Text string after cleaning procedure
"""
# Synonyms + skipping optional part of the sentence
clean_text = skip_and_replace_phrases(text)
# Remove full stop, empty spaces, capitalize the start letter.
clean_text = re.sub(' +', ' ', clean_text.replace(suffix, '').strip(' '))
# First replace 'a thing' -> 'an object'.
# Then perform remaining actions.
if thing == 'object':
clean_text = clean_text.replace('a thing', 'an object')
clean_text = clean_text.replace('thing', thing)
clean_text = clean_text[0].upper() + clean_text[1:] + suffix
return clean_text
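Leaving out the synonym pass (which needs the metainfo file), the remaining string cleanup is self-contained; a sketch of just those operations under that assumption (`cleanup` is an illustrative name, not from this file):

```python
import re

def cleanup(text, thing, suffix):
    # Strip the terminal punctuation and collapse repeated spaces.
    clean = re.sub(' +', ' ', text.replace(suffix, '').strip(' '))
    if thing == 'object':
        # Fix the article before the generic replacement below.
        clean = clean.replace('a thing', 'an object')
    clean = clean.replace('thing', thing)
    # Capitalize the first letter and re-attach the suffix.
    return clean[0].upper() + clean[1:] + suffix

print(cleanup('there is a thing  left of the red cube .', 'object', '.'))
# There is an object left of the red cube.
```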
def clean_dialog_text(dialogs):
"""Cleans the dialog texts.
Args:
dialogs: Generated dialogs to perform text cleaning
Returns:
dialogs: Return the dialogs after cleaning the text inplace
"""
# Replace thing with object throughout with probability 0.5.
thing = 'thing' if random.random() > 0.5 else 'object'
for index, dialog_datum in enumerate(dialogs):
# Clean the caption.
text = dialog_datum['caption']
dialogs[index]['caption'] = clean_text_subroutine(text, thing, '.')
for r_id, dialog in enumerate(dialog_datum['dialog']):
# Clean the question.
text = dialog['question']
text = clean_text_subroutine(text, thing, '?')
dialogs[index]['dialog'][r_id]['question'] = text
return dialogs
def skip_and_replace_phrases(text):
"""Substitutes synonyms and skips optional parts stochastically.
Args:
text: Text string
Returns:
text: Text string with synonyms replaced and optional parts skipped
"""
# For each phrase in square brackets, drop it with probability 0.5.
matches = re.findall(r'(\[[ \w]*\])', text)
for match in matches:
if random.uniform(0, 1) > 0.5:
text = text.replace(match, '')
else:
text = text.replace(match, match[1:-1])
# Remove empty spaces, if any.
text = re.sub(' +', ' ', text)
# Search for synonyms, replace at uniformly random.
text = text.lower()
for key, values in gvars.METAINFO['synonym_keys']:
if key in text:
text = text.replace(key, random.choice(values))
return text
def generate_captions(scenes, templates):
"""Wrapper generates captions.
Args:
scenes: List of scene graphs for which to generate captions
templates: List of available caption templates
Returns:
generated_content: Captions generated for the input scenes
"""
template_dictionary = {ii['label']: ii for ii in templates}
generated_content = []
for scene in scenes['scenes'][0:FLAGS.num_images]:
content = {}
# Copy over image_index, split, image_filename from scene.
for key in ['image_index', 'split', 'image_filename']:
content[key] = scene[key]
content['dialogs'] = []
# Filter objects based on constraints.
filter_objs = constraints.caption(scene, templates)
for filter_obj in filter_objs:
# Realize the text, and return the partial scene knowledge (q).
template = template_dictionary[filter_obj[0]['graph']['template']]
sample = realize_text_and_extract_scene(
scene, template, filter_obj)
# Add it to the list of dialogs.
content['dialogs'].append(sample)
generated_content.append(content)
return generated_content
def generate_questions(scenes, dialogs, templates, params):
"""Wrapper generates questions.
Args:
scenes: List of scene graphs to generate questions
dialogs: Contains already generated captions for scenes graphs
templates: List of available question templates
params: Beam search parameters for question generation
Returns:
new_dialogs: Generated raw dialogs with captions and questions
"""
new_dialogs = []
for scene_id, dialog_datum in enumerate(dialogs):
image_dialogs = copy.deepcopy(dialog_datum)
image_dialogs['dialogs'] = []
for dialog in dialog_datum['dialogs']:
flag = False
iter_count = 0
while not flag:
# Pick a template at random.
template = random.choice(templates)
# Filter objects based on constraints.
filter_objs = constraints.question(scenes['scenes'][scene_id],
dialog, template)
flag = len(filter_objs) != 0
# Give up if no valid template is found after several attempts.
iter_count += 1
if iter_count > 10:
break
# Realize the question.
if flag:
deep_copy = copy.deepcopy(dialog)
gen_dialog = realize_question(deep_copy, template, filter_objs)
image_dialogs['dialogs'].append(copy.deepcopy(gen_dialog))
new_dialogs.append(image_dialogs)
return new_dialogs
def worker(scenes, cap_templates, ques_templates, worker_id, out_q):
"""Worker method generates dialogs (caption + questions) for pool of scenes.
Args:
scenes: List of CLEVR scenes to generate dialogs
cap_templates: Templates for caption generation
ques_templates: Templates for question generation
worker_id: Id for the current worker
out_q: Output queue to save generated dialogs from different sources
Returns:
Puts the generated dialogs into the output queue, keyed by worker id.
"""
dialogs = []
for index, scene in enumerate(scenes):
cur_time = time.strftime('%a-%d%b%y-%X', time.gmtime())
print('Generating [ %s ] [ Worker: %d, Progress: %d/%d Scene: %d ]' %
(cur_time, worker_id, index, len(scenes), scene['image_index']))
try:
gen_dialog = generate_dialog_bfs(
scene, cap_templates, ques_templates)
dialogs.append(json.loads(json.dumps(gen_dialog)))
except Exception:
print('NOTE: Missing data for %d' % scene['image_index'])
out_q.put({worker_id: dialogs})
def generate_dialog_bfs(scene, cap_templates, ques_templates):
"""Perform approximate breadth-first-search (BFS) to generate dialogs.
Args:
scene: Scene graph for the CLEVR image
cap_templates: List of caption templates
ques_templates: List of question templates
Returns:
bundle: List of dialogs generated for the input scene graph
"""
bundle = {}
# Generate captions for the scene.
# Copy over image_index, split, image_filename from scene.
for key in ['image_index', 'split', 'image_filename']:
bundle[key] = scene[key]
template_dictionary = {ii['label']: ii for ii in cap_templates}
content = {}
# Filter objects based on constraints on captions.
filter_objs = constraints.caption(scene, cap_templates)
for filter_obj in filter_objs:
# Realize the text, and return the partial scene knowledge (q).
template = template_dictionary[filter_obj[0]['graph']['template']]
sample = realize_text_and_extract_scene(scene, template, filter_obj)
# Add it to the list of dialogs.
content[template['label']] = [sample]
# Now generate questions.
# Group templates, exist/count of similar type together.
ques_groups = collections.defaultdict(list)
labels = [ii['label'] for ii in ques_templates]
for index, ii in enumerate(ques_templates):
if 'exist' in ii['label'] or 'count' in ii['label']:
# Dropping the first four characters maps 'exist-X' and 'count-X'
# to the same key ('t-X'), so both types share one group.
ques_groups[labels[index][4:]].append(ii)
else:
ques_groups[labels[index]].append(ii)
for round_id in range(FLAGS.num_rounds):
new_content = {}
# For each group.
for cap_label, cap_dialogs in content.items():
cur_pool = []
for dialog_datum in cap_dialogs:
for _, group in ques_groups.items():
template = random.choice(group)
# Make a copy.
datum_copy = copy.deepcopy(dialog_datum)
# Filter objects based on constraints.
filter_objs = constraints.question(
scene, datum_copy, template)
if len(filter_objs) == 0:
continue
# Realize the question.
gen_dialog = realize_question(
datum_copy, template, filter_objs)
cur_pool.append(gen_dialog)
if round_id in ranges:
for d_id, dialog in enumerate(cur_pool):
n_types = {'indep': 0, 'seek': 0, 'exist': 0, 'count': 0}
keep_dialog = True
labels = [ii['label']
for ii in dialog['template_info'][1:]]
for label in labels:
if label in gvars.METAINFO['independent_questions']:
n_types['indep'] += 1
label_type = label.split('-')[0]
n_types[label_type] += 1
# Heuristic A, C
for q_type, count in n_types.items():
limit = ranges[round_id][q_type]
if limit[0] > count or count > limit[1]:
keep_dialog = False
break
# Heuristic B
limit = ranges[round_id]['exist+count']
if n_types['count'] + n_types['exist'] > limit[1]:
keep_dialog = False
if not keep_dialog:
cur_pool[d_id] = None
cur_pool = [ii for ii in cur_pool if ii is not None]
# Keep limited number of beams (for speed).
if len(cur_pool) > FLAGS.num_beams:
cur_pool = sample_beams(cur_pool)[:FLAGS.num_beams]
new_content[cap_label] = cur_pool
content = copy.deepcopy(new_content)
# Get dialogs with sim, imm2, early questions.
for cap_label, cap_dialogs in content.items():
# Sample beams.
content[cap_label] = sample_beams(cap_dialogs)
# Remove keys that are empty.
empty_keys = [key for key, val in content.items() if len(val) == 0]
for key in empty_keys:
del content[key]
# For each caption, sample one.
sampled_dialogs = []
for cap_label, cap_dialogs in content.items():
if len(cap_dialogs) > 0:
sampled_dialogs.append(cap_dialogs.pop())
# Get five dialogs per image; compensate by sampling from other entries.
content_keys = list(content.keys())
while len(sampled_dialogs) < 5:
random_label = random.choice(content_keys)
sampled_dialogs.append(content[random_label].pop())
# Finally, make the dialog text readable.
sampled_dialogs = clean_dialog_text(sampled_dialogs)
# Generate the coreference chain.
for dialog_id, dialog in enumerate(sampled_dialogs):
sampled_dialogs[dialog_id] = identify_coref_chains(dialog)
bundle['dialogs'] = sampled_dialogs
return bundle
def sample_beams(dialogs):
"""Samples beams based on the number of constraints satisfied.
Args:
dialogs: Generated dialogs to sample beams
Returns:
sampled_dialogs: List of sampled dialogs based on the constraints
"""
num_constraints = []
for d_id, dialog in enumerate(dialogs):
satisfied = 0
labels = [ii['label'] for ii in dialog['template_info'][1:]]
# Prefer dialogs containing 'imm2' questions.
satisfied += np.sum(['imm2' in ii for ii in labels])
# Prefer dialogs containing 'sim' questions.
satisfied += np.sum(['sim' in ii for ii in labels])
# Prefer 'early' questions, counting at most four.
satisfied += min(4, np.sum(['early' in ii for ii in labels]))
# Add it with the number of constraints it satisfies.
num_constraints.append((satisfied, d_id))
# Then order.
def sort_key(x): return (x[0], random.random())
ids = sorted(num_constraints, key=sort_key, reverse=True)
sampled_dialogs = [dialogs[ii[1]] for ii in ids]
return sampled_dialogs
def identify_coref_chains(dialog):
"""Identifies the coreference chains in generated dialog.
Args:
dialog: Generated dialogs for which coreference chains to be identified
Returns:
dialog: A copy of dialog, with coreference chains annotated
"""
for r_id, datum in enumerate(dialog['dialog']):
label = datum['template']
if label in gvars.METAINFO['independent_questions']:
dialog['graph']['history'][r_id + 1]['dependence'] = None
continue
if (label == 'exist-attribute-group' or label == 'count-attribute-group' or
label == 'count-all-group'):
dialog['graph']['history'][r_id + 1]['dependence'] = r_id - 1
continue
if 'imm' in label:
dialog['graph']['history'][r_id + 1]['dependence'] = r_id - 1
continue
if 'early' in label:
# Go over previous history.
cur_history = dialog['graph']['history'][r_id + 1]
assert 'focus_id' in cur_history and 'focus_desc' in cur_history,\
'History round must contain exactly one focus object!'
focus_id = cur_history['focus_id']
for attr in gvars.METAINFO['attributes']:
if attr in cur_history['focus_desc']:
break
history = dialog['graph']['history'][:r_id + 1]
for hist_id, hist_datum in enumerate(history):
for obj in hist_datum['objects']:
if obj['id'] == focus_id and attr in obj:
dialog['graph']['history'][r_id + 1]['dependence'] = hist_id - 1
break
return dialog
def main(unused_argv):
"""Main method generates the CLEVR-Dialog dataset.
"""
# Read the scene file.
with open(FLAGS.scene_path, 'r') as file_id:
scenes = json.load(file_id)
# Read the synonyms file.
with open(FLAGS.synonym_path, 'r') as file_id:
synonyms = json.load(file_id)
def sorter(x): return len(x[0].split(' '))
# Read the metainformation file.
with open(FLAGS.metainfo_path, 'r') as file_id:
gvars.METAINFO = json.load(file_id)
tag_inv_map = {attr: tag for tag, attr in gvars.METAINFO['tag_map'].items()
if tag != '<P>'}
gvars.METAINFO['tag_inv_map'] = tag_inv_map
gvars.METAINFO['synonym_keys'] = sorted(synonyms.items(),
key=sorter, reverse=True)
# Add ids to objects.
scenes = utils.add_object_ids(scenes)
scenes = utils.clean_object_attributes(scenes)
# Read the caption templates.
template_paths = os.listdir(FLAGS.caption_template_root)
cap_templates = []
for ii in template_paths:
with open(os.path.join(FLAGS.caption_template_root, ii), 'r') as file_id:
cur_templates = json.load(file_id)
cap_templates.extend(cur_templates)
# utils.pretty_print_templates(cap_templates, 1)
# Read the question templates.
template_paths = os.listdir(FLAGS.question_template_root)
ques_templates = []
for ii in template_paths:
with open(os.path.join(FLAGS.question_template_root, ii), 'r') as file_id:
cur_templates = json.load(file_id)
ques_templates.extend(cur_templates)
# utils.pretty_print_templates(ques_templates, 1)
# 1. Check if a scene_id_file is specified.
# 2. Otherwise, check whether num_images is -1 (use all scenes).
if FLAGS.scene_id_file != '':
with open(FLAGS.scene_id_file, 'r') as file_id:
missing_ids = [int(ii.strip('\n')) for ii in file_id.readlines()]
print('Dialogs missing for scenes: %d' % len(missing_ids))
# Create a image_index -> scenes list index dictionary
image_list_id_dict = {ii['image_index']: index
for index, ii in enumerate(scenes['scenes'])}
scenes_subset = [scenes['scenes'][image_list_id_dict[scene_id]]
for scene_id in missing_ids]
elif FLAGS.num_images == -1:
scenes_subset = scenes['scenes']
else:
scenes_subset = scenes['scenes'][0: FLAGS.num_images]
# BFS for each scene.
if FLAGS.num_workers == 1:
# Single thread version.
dialogs = []
for index, scene in enumerate(scenes_subset):
cur_time = time.strftime('%a-%d%b%y-%X', time.gmtime())
print('Generating [ %s ] [ Worker: %d, Progress: %d/%d Scene: %d ]' %
(cur_time, 0, index, len(scenes_subset), scene['image_index']))
gen_dialog = generate_dialog_bfs(
scene, cap_templates, ques_templates)
dialogs.append(gen_dialog)
else:
# Multithread version.
output_q = multiprocessing.Queue()
jobs = []
for worker_id in range(FLAGS.num_workers):
allotment = scenes_subset[worker_id::FLAGS.num_workers]
inputs = (allotment, cap_templates, ques_templates)
inputs += (worker_id, output_q)
process = multiprocessing.Process(target=worker, args=inputs)
jobs.append(process)