make code public
This commit is contained in:
commit
9d8b93db26
26 changed files with 11937 additions and 0 deletions
148
README.md
Normal file
148
README.md
Normal file
|
@ -0,0 +1,148 @@
|
||||||
|
|
||||||
|
# NSVD
|
||||||
|
|
||||||
|
This repository contains the official code of the paper:
|
||||||
|
|
||||||
|
## Neuro-Symbolic Visual Dialog [[PDF](TODO)]
|
||||||
|
|
||||||
|
[Adnen Abdessaied](https://adnenabdessaied.de), [Mihai Bace](https://perceptualui.org/people/bace/), [Andreas Bulling](https://perceptualui.org/people/bulling/)
|
||||||
|
**Oral Presentaion / Poster**
|
||||||
|
International Conferenc on Computational Linguistics (COLING), 2022 / Gyeongju, Republic of Korea.
|
||||||
|
|
||||||
|
If you find our code useful or use it in your own projects, please cite our paper:
|
||||||
|
|
||||||
|
``TODO``
|
||||||
|
|
||||||
|
# Abstract
|
||||||
|
|
||||||
|
We propose Neuro-Symbolic Visual Dialog (NSVD) —the first method to combine deep learning and symbolic program execution for multi-round visually-grounded reasoning. NSVD significantly outperforms existing purely-connectionist methods on two key challenges inherent to visual dialog: long-distance co-reference resolution as well as vanishing question-answering performance. We demonstrate the latter by proposing a more realistic and stricter evaluation scheme in which we use predicted answers for the full dialog history when calculating accuracy. We describe two variants of our model and show that using this new scheme, our best model achieves an accuracy of 99.72% on CLEVR-Dialog —a relative improvement of more than 10% over the state
|
||||||
|
of the art —while only requiring a fraction of training data. Moreover, we demonstrate that our neuro-symbolic models have a higher mean first failure round, are more robust against incomplete dialog histories, and generalise better not only to dialogs that are up to three times longer than those seen during training but also to unseen question types and scenes.
|
||||||
|
|
||||||
|
# Method
|
||||||
|
|
||||||
|
<figure>
|
||||||
|
<p align="center"><img src="misc/method_overview.png" alt="missing"/></
|
||||||
|
<figcaption>Overview of our method NSVD.</figcaption>
|
||||||
|
</figure>
|
||||||
|
|
||||||
|
<figure>
|
||||||
|
<p align="center"><img src="misc/method_smaller.png" alt="missing"/></
|
||||||
|
<figcaption>Overview of concat and stack encoders.</figcaption>
|
||||||
|
</figure>
|
||||||
|
|
||||||
|
# Requirements
|
||||||
|
|
||||||
|
- PyTorch 1.3.1
|
||||||
|
- Python 3.6
|
||||||
|
- Ubuntu 18.04
|
||||||
|
|
||||||
|
# Raw Data
|
||||||
|
|
||||||
|
## Scene Data
|
||||||
|
|
||||||
|
We used CLEVR and Minecraft images in this project. The raw images have a large footprint and we won't upload them. However, we provide their json file as well as their derendedred versions. They can be found in :
|
||||||
|
|
||||||
|
- ``data/scenes/raw``
|
||||||
|
- ``data/scenes/derendered``
|
||||||
|
|
||||||
|
## Dialog Data
|
||||||
|
|
||||||
|
The dialog data we used can be found in ``data/dialogs``.
|
||||||
|
You can also create your own data using the ``generate_dataset.py`` script.
|
||||||
|
|
||||||
|
# Preprocessing
|
||||||
|
|
||||||
|
## Scenes
|
||||||
|
|
||||||
|
The derendered scenes do not need any further preprocessing and can be diretly used with our neuro-symbolic executor.
|
||||||
|
|
||||||
|
## Dialogs
|
||||||
|
|
||||||
|
To preprocess the dialogs, follow these steps:
|
||||||
|
|
||||||
|
- ``cd preprocess_dialogs``
|
||||||
|
|
||||||
|
For the stack encoder, execute
|
||||||
|
|
||||||
|
- ``python preprocess.py --input_dialogs_json <path_to_raw_dialog_file> --input_vocab_json '' --output_vocab_json <path_where_to_save_the_vocab> --output_h5_file <path_of_the_output_file> --split <train/val/test> --mode stack``
|
||||||
|
|
||||||
|
For the concat encoder, execute
|
||||||
|
|
||||||
|
- ``python preprocess.py --input_dialogs_json <path_to_raw_dialog_file> --input_vocab_json '' --output_vocab_json <path_where_to_save_the_vocab> --output_h5_file <path_of_the_output_file> --split <train/val/test> --mode concat``
|
||||||
|
|
||||||
|
# Training
|
||||||
|
|
||||||
|
First, change directory
|
||||||
|
|
||||||
|
- ``cd ../prog_generator``
|
||||||
|
|
||||||
|
## Caption Program Parser
|
||||||
|
|
||||||
|
To train the caption parser, execute
|
||||||
|
|
||||||
|
- ``python train_caption_parser.py --mode train --run_dir <experiment_dir> --res_path <path_to_store_results> --dataPathTr <path_to_preprocessed_training_data> --dataPathVal <path_to_preprocessed_val_data> --dataPathTest <path_to_preprocessed_test_data> --vocab_path <path_where_to_save_the_vocab>``
|
||||||
|
|
||||||
|
## Question Program Parser
|
||||||
|
|
||||||
|
To train the question program parser with the stack encoder, execute
|
||||||
|
|
||||||
|
- ``python train_question_parser.py --mode train --run_dir <experiment_dir> --text_log_dir <log_dir_path> --dataPathTr <path_to_preprocessed_training_data> --dataPathVal <path_to_preprocessed_val_data> --dataPathTest <path_to_preprocessed_test_data> --scenePath <path_to_derendered_scenes> --vocab_path <path_where_to_save_the_vocab> --encoder_type 2``
|
||||||
|
|
||||||
|
To train the question program parser with the concat encoder, execute
|
||||||
|
|
||||||
|
- ``python train_question_parser.py --mode train --run_dir <experiment_dir> --text_log_dir <log_dir_path> --dataPathTr <path_to_preprocessed_training_data> --dataPathVal <path_to_preprocessed_val_data> --dataPathTest <path_to_preprocessed_test_data> --scenePath <path_to_derendered_scenes> --vocab_path <path_where_to_save_the_vocab> --encoder_type 1``
|
||||||
|
|
||||||
|
## Baselines
|
||||||
|
|
||||||
|
- [MAC-XXX](https://github.com/ahmedshah1494/clevr-dialog-mac-net/tree/dialog-macnet)
|
||||||
|
|
||||||
|
- [HCN](https://github.com/jojonki/Hybrid-Code-Networks)
|
||||||
|
|
||||||
|
# Evaluation
|
||||||
|
|
||||||
|
To evaluate using the *Hist+GT* scheme, execute
|
||||||
|
|
||||||
|
- ``python train_question_parser.py --mode test_with_gt --run_dir <experiment_dir> --text_log_dir <log_dir_path> --dataPathTr <path_to_preprocessed_training_data> --dataPathVal <path_to_preprocessed_val_data> --dataPathTest <path_to_preprocessed_test_data> --scenePath <path_to_derendered_scenes> --vocab_path <path_where_to_save_the_vocab> --encoder_type <1/2> --questionNetPath <path_to_pretrained_question_parser> --captionNetPath <path_to_pretrained_caption_parser> --dialogLen <total_number_of_dialog_rounds> --last_n_rounds <number_of_last_rounds_to_considered_in_history>``
|
||||||
|
|
||||||
|
To evaluate using the *Hist+Pred* scheme, execute
|
||||||
|
|
||||||
|
- ``python train_question_parser.py --mode test_with_pred --run_dir <experiment_dir> --text_log_dir <log_dir_path> --dataPathTr <path_to_preprocessed_training_data> --dataPathVal <path_to_preprocessed_val_data> --dataPathTest <path_to_preprocessed_test_data> --scenePath <path_to_derendered_scenes> --vocab_path <path_where_to_save_the_vocab> --encoder_type <1/2> --questionNetPath <path_to_pretrained_question_parser> --captionNetPath <path_to_pretrained_caption_parser> --dialogLen <total_number_of_dialog_rounds> --last_n_rounds <number_of_last_rounds_to_considered_in_history>``
|
||||||
|
|
||||||
|
# Results
|
||||||
|
|
||||||
|
We achieve new state-of-the-art performance on clevr-dialog.
|
||||||
|
|
||||||
|
## Hist+GT
|
||||||
|
|
||||||
|
| <center>Model</center> | <center>Accurcy</center> | <center>NFFR</center> |
|
||||||
|
| :---: | :---: | :---: |
|
||||||
|
| MAC-CQ | 97.34 | 0.92 |
|
||||||
|
| + CAA | 97.87 | 0.94 |
|
||||||
|
| + MTM | 97.58 | 0.92 |
|
||||||
|
| HCN | 75.88 | 0.34 |
|
||||||
|
| **NSVD-concat (Ours)** | 99.59 | 0.98 |
|
||||||
|
| **NSVD-stack (Ours)** | **99.72** | **0.99** |
|
||||||
|
|
||||||
|
## Hist+Pred
|
||||||
|
|
||||||
|
| <center>Model</center> | <center>Accurcy</center> | <center>NFFR</center> |
|
||||||
|
| :---: | :---: | :---: |
|
||||||
|
| MAC-CQ | 41.10 | 0.15 |
|
||||||
|
| + CAA | 89.39 | 0.75 |
|
||||||
|
| + MTM | 70.39 | 0.46 |
|
||||||
|
| HCN | 74.42 | 0.32 |
|
||||||
|
| **NSVD-concat (Ours)** | 99.59 | 0.98 |
|
||||||
|
| **NSVD-stack (Ours)** | **99.72** | **0.99** |
|
||||||
|
|
||||||
|
We refer to our paper for the results of the other experiments.
|
||||||
|
|
||||||
|
# Acknowledgements
|
||||||
|
|
||||||
|
We thank [Ahmed Shah](https://www.linkedin.com/in/mahmedshah/) for his MAC-XXX implemetation,[Junki Ohmura](https://www.linkedin.com/in/junki/) for his HCN implemantation, [Jiayuan Mao](https://jiayuanm.com/) for providing us with the minecraft images, and finally [Satwik Kottur](https://satwikkottur.github.io/) for his clevr-dialog [codebase](https://github.com/satwikkottur/clevr-dialog).
|
||||||
|
|
||||||
|
# Contributors
|
||||||
|
|
||||||
|
- [Adnen Abdessaied](https://adnenabdessaied.de)
|
||||||
|
|
||||||
|
For any questions or enquiries, don't not hesitate to contact the above contributor.
|
||||||
|
|
224
clevr_utils.py
Normal file
224
clevr_utils.py
Normal file
|
@ -0,0 +1,224 @@
|
||||||
|
"""Utilities for CLEVR-Dialog dataset generation.
|
||||||
|
|
||||||
|
Author: Satwik Kottur
|
||||||
|
"""
|
||||||
|
|
||||||
|
import copy
|
||||||
|
|
||||||
|
|
||||||
|
def pretty_print_templates(templates, verbosity=1):
|
||||||
|
"""Pretty prints templates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
templates: Templates to print
|
||||||
|
verbosity: 1 to print name and type of the templates
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Verbosity 1: Name and type.
|
||||||
|
print('-'*70)
|
||||||
|
for ii in templates:
|
||||||
|
print('[Name: %s] [Type: %s]' % (ii['name'], ii['type']))
|
||||||
|
print('-'*70)
|
||||||
|
print('Total of %s templates..' % len(templates))
|
||||||
|
print('-'*70)
|
||||||
|
|
||||||
|
|
||||||
|
def pretty_print_scene_objects(scene):
|
||||||
|
"""Pretty prints scene objects.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scene: Scene graph containing list of objects
|
||||||
|
"""
|
||||||
|
|
||||||
|
for index, ii in enumerate(scene['objects']):
|
||||||
|
print_args = (index, ii['shape'], ii['color'],
|
||||||
|
ii['size'], ii['material'])
|
||||||
|
print('\t%d : %s-%s-%s-%s' % print_args)
|
||||||
|
|
||||||
|
|
||||||
|
def pretty_print_dialogs(dialogs):
|
||||||
|
"""Pretty prints generated dialogs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dialogs: Generated dialogs to print
|
||||||
|
"""
|
||||||
|
|
||||||
|
for scene_id, dialog_datum in enumerate(dialogs):
|
||||||
|
for dialog in dialog_datum['dialogs']:
|
||||||
|
print(dialog['caption'])
|
||||||
|
for round_id, ii in enumerate(dialog['dialog']):
|
||||||
|
coref_id = dialog['graph']['history'][round_id+1]['dependence']
|
||||||
|
in_tuple = (round_id, ii['question'], str(ii['answer']),
|
||||||
|
ii['template'], str(coref_id))
|
||||||
|
print('\t[Q-%d: %s] [A: %s] [%s] [%s]' % in_tuple)
|
||||||
|
|
||||||
|
|
||||||
|
def merge_update_scene_graph(orig_graph, graph_item):
|
||||||
|
"""Merges two scene graphs into one.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
orig_graph: Original scene graph
|
||||||
|
graph_item: New graph item to add to the scene graph
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
graph: Deep copy of the original scene graph after merging
|
||||||
|
"""
|
||||||
|
|
||||||
|
graph = copy.deepcopy(orig_graph)
|
||||||
|
# Local alias.
|
||||||
|
objects = graph['objects']
|
||||||
|
|
||||||
|
# If not mergeable, return the same scene graph.
|
||||||
|
if not graph_item['mergeable']:
|
||||||
|
return graph
|
||||||
|
|
||||||
|
# 1. Go through each new object
|
||||||
|
# 2. Find its batch in objects
|
||||||
|
# a. If found, assert for a clash of attributes, update
|
||||||
|
# b. If novel, just add the object as is
|
||||||
|
for new_obj in graph_item['objects']:
|
||||||
|
match_found = False
|
||||||
|
obj = objects.get(new_obj['id'], None)
|
||||||
|
|
||||||
|
if obj:
|
||||||
|
# Assert for existing entries.
|
||||||
|
for attr in new_obj:
|
||||||
|
try:
|
||||||
|
assert new_obj[attr] == obj.get(attr, new_obj[attr]),\
|
||||||
|
'Some of the attributes do not match!'
|
||||||
|
except:
|
||||||
|
pdb.set_trace()
|
||||||
|
|
||||||
|
# Add additional keys.
|
||||||
|
objects[new_obj['id']].update(new_obj)
|
||||||
|
else:
|
||||||
|
# Add the new object.
|
||||||
|
objects[new_obj['id']] = new_obj
|
||||||
|
|
||||||
|
# if a relation, update it
|
||||||
|
if 'relation' in graph_item:
|
||||||
|
rel = graph_item['relation']
|
||||||
|
# update it with object 2 id
|
||||||
|
id1 = graph_item['objects'][0]['id']
|
||||||
|
id2 = graph_item['objects'][1]['id']
|
||||||
|
rel_objs = graph['relationships'][rel][id1]
|
||||||
|
rel_objs.append(id2)
|
||||||
|
graph['relationships'][rel][id1] = rel_objs
|
||||||
|
|
||||||
|
# update objects in graph
|
||||||
|
graph['objects'] = objects
|
||||||
|
return graph
|
||||||
|
|
||||||
|
|
||||||
|
def add_object_ids(scenes):
|
||||||
|
"""Adds object ids field for input scenes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scenes: List of CLEVR scene graphs
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
scenes: Adds object_id field for the objects in the scene graph inplace
|
||||||
|
"""
|
||||||
|
|
||||||
|
for scene_id, scene in enumerate(scenes['scenes']):
|
||||||
|
for obj_id, _ in enumerate(scene['objects']):
|
||||||
|
scenes['scenes'][scene_id]['objects'][obj_id]['id'] = obj_id
|
||||||
|
return scenes
|
||||||
|
|
||||||
|
|
||||||
|
def clean_object_attributes(scenes):
|
||||||
|
"""Cleans attributes for objects, keeping only attributes and id.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scenes: Scene graph to clean
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
scenes: Cleaned up scene graphs inplace
|
||||||
|
"""
|
||||||
|
|
||||||
|
keys = ['shape', 'size', 'material', 'color', 'id']
|
||||||
|
for scene_id, scene in enumerate(scenes['scenes']):
|
||||||
|
for obj_id, obj in enumerate(scene['objects']):
|
||||||
|
new_obj = {key: obj[key] for key in keys}
|
||||||
|
scenes['scenes'][scene_id]['objects'][obj_id] = new_obj
|
||||||
|
return scenes
|
||||||
|
|
||||||
|
|
||||||
|
def pretty_print_corefs(dialog, coref_groups):
|
||||||
|
"""Prints coreferences for a dialog, higlighting different groups in colors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dialog: Generated dialogs to print
|
||||||
|
coref_groups: Coreference groups for dialogs
|
||||||
|
"""
|
||||||
|
|
||||||
|
colorama.init()
|
||||||
|
# Mapping of group_id -> color_ids for (foreground, background)
|
||||||
|
color_map = {}
|
||||||
|
groups = coref_groups.get(0, [])
|
||||||
|
colored, color_map = pretty_print_coref_sentence(dialog['caption'], groups,
|
||||||
|
color_map)
|
||||||
|
print('\n\nC: %s' % colored)
|
||||||
|
for round_id, round_datum in enumerate(dialog['dialog']):
|
||||||
|
question = round_datum['question']
|
||||||
|
groups = coref_groups.get(round_id + 1, [])
|
||||||
|
colored, color_map = pretty_print_coref_sentence(question, groups,
|
||||||
|
color_map)
|
||||||
|
print('%d: %s' % (round_id, colored))
|
||||||
|
|
||||||
|
|
||||||
|
def pretty_print_coref_sentence(sentence, groups, color_map):
|
||||||
|
"""Prints a sentence containing difference coreference groups.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sentence: Text sentence
|
||||||
|
groups: List of coreference groups with spans
|
||||||
|
color_map: List of groups and associated color maps
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
sentence: Text sentence with colors inserted
|
||||||
|
color_map: Updated, if new groups in the current sentence
|
||||||
|
"""
|
||||||
|
|
||||||
|
fore_colors = ['RED', 'GREEN', 'YELLOW', 'BLUE', 'MAGENTA']
|
||||||
|
back_colors = ['BLACK', 'YELLOW', 'CYAN']
|
||||||
|
insertions = []
|
||||||
|
for group in groups:
|
||||||
|
group_id = group['group_id']
|
||||||
|
if group_id in color_map:
|
||||||
|
forecolor_id, backcolor_id = color_map[group_id]
|
||||||
|
else:
|
||||||
|
num_groups = len(color_map)
|
||||||
|
forecolor_id = num_groups % len(fore_colors)
|
||||||
|
backcolor_id = num_groups // len(fore_colors)
|
||||||
|
color_map[group_id] = (forecolor_id, backcolor_id)
|
||||||
|
|
||||||
|
forecolor = fore_colors[forecolor_id]
|
||||||
|
backcolor = back_colors[backcolor_id]
|
||||||
|
insertions.append(
|
||||||
|
(group['span'][0], getattr(colorama.Fore, forecolor)))
|
||||||
|
insertions.append(
|
||||||
|
(group['span'][0], getattr(colorama.Back, backcolor)))
|
||||||
|
insertions.append((group['span'][1],
|
||||||
|
getattr(colorama.Style, 'RESET_ALL')))
|
||||||
|
|
||||||
|
# Perform insertions.
|
||||||
|
sentence = insert_into_sentence(sentence, insertions)
|
||||||
|
return sentence, color_map
|
||||||
|
|
||||||
|
|
||||||
|
def insert_into_sentence(sentence, insertions):
|
||||||
|
"""Sorts and performs insertions from right.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sentence: Sentence to perform insertions into
|
||||||
|
insertions: List of insertions, format: (position, text_insert)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
sentence: Inplace inserted sentence
|
||||||
|
"""
|
||||||
|
|
||||||
|
insertions = sorted(insertions, key=lambda x: x[0], reverse=True)
|
||||||
|
for position, text in insertions:
|
||||||
|
sentence = sentence[:position] + text + sentence[position:]
|
||||||
|
return sentence
|
1049
constraints.py
Normal file
1049
constraints.py
Normal file
File diff suppressed because it is too large
Load diff
1055
constraints_minecraft.py
Normal file
1055
constraints_minecraft.py
Normal file
File diff suppressed because it is too large
Load diff
1055
constraints_splitA.py
Normal file
1055
constraints_splitA.py
Normal file
File diff suppressed because it is too large
Load diff
1055
constraints_splitB.py
Normal file
1055
constraints_splitB.py
Normal file
File diff suppressed because it is too large
Load diff
0
executor/__init__.py
Normal file
0
executor/__init__.py
Normal file
47
executor/clevr_statics.py
Normal file
47
executor/clevr_statics.py
Normal file
|
@ -0,0 +1,47 @@
|
||||||
|
"""
|
||||||
|
author: Adnen Abdessaied
|
||||||
|
maintainer: "Adnen Abdessaied"
|
||||||
|
website: adnenabdessaied.de
|
||||||
|
version: 1.0.1
|
||||||
|
"""
|
||||||
|
|
||||||
|
COLORS = ["blue", "brown", "cyan", "gray", "green", "purple", "red", "yellow"]
|
||||||
|
MATERIALS = ["rubber", "metal"]
|
||||||
|
SHAPES = ["cube", "cylinder", "sphere"]
|
||||||
|
SIZES = ["large", "small"]
|
||||||
|
|
||||||
|
ATTRIBUTES_ALL = COLORS + MATERIALS + SHAPES + SIZES
|
||||||
|
|
||||||
|
ANSWER_CANDIDATES = {
|
||||||
|
# Count questions
|
||||||
|
"count-all": ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10"],
|
||||||
|
"count-other": ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"],
|
||||||
|
"count-all-group": ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10"],
|
||||||
|
"count-attribute": ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10"],
|
||||||
|
"count-attribure-group": ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10"],
|
||||||
|
"count-obj-rel-imm": ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"],
|
||||||
|
"count-obj-rel-imm2": ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"],
|
||||||
|
"count-obj-rel-early": ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"],
|
||||||
|
"count-obj-exclude-imm": ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"],
|
||||||
|
"count-obj-exclude-early": ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"],
|
||||||
|
|
||||||
|
# Existence questions
|
||||||
|
"exist-other": ["yes", "no"],
|
||||||
|
"exist-attribute": ["yes", "no"],
|
||||||
|
"exist-attribute-group": ["yes", "no"],
|
||||||
|
"exist-obj-rel-imm": ["yes", "no"],
|
||||||
|
"exist-obj-rel-imm2": ["yes", "no"],
|
||||||
|
"exist-obj-rel-early": ["yes", "no"],
|
||||||
|
"exist-obj-exclude-imm": ["yes", "no"],
|
||||||
|
"exist-obj-exclude-early": ["yes", "no"],
|
||||||
|
|
||||||
|
# Seek questions
|
||||||
|
"seek-attr-imm": ATTRIBUTES_ALL,
|
||||||
|
"seek-attr-imm2": ATTRIBUTES_ALL,
|
||||||
|
"seek-attr-early": ATTRIBUTES_ALL,
|
||||||
|
"seek-attr-sim-early": ATTRIBUTES_ALL,
|
||||||
|
"seek-attr-rel-imm": ATTRIBUTES_ALL,
|
||||||
|
"seek-attr-rel-early": ATTRIBUTES_ALL,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
44
executor/minecraft_statics.py
Normal file
44
executor/minecraft_statics.py
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
"""
|
||||||
|
author: Adnen Abdessaied
|
||||||
|
maintainer: "Adnen Abdessaied"
|
||||||
|
website: adnenabdessaied.de
|
||||||
|
version: 1.0.1
|
||||||
|
"""
|
||||||
|
|
||||||
|
CLASSES = ["pig", "cow", "sheep", "chicken", "wolf", "horse", "villager", "treeA", "treeB", "armorstand", "boat", "minecart"]
|
||||||
|
DIRECTIONS = ["facing_forward", "facing_backward", "facing_right", "facing_left"]
|
||||||
|
NATURES = ["animal", "human", "plant", "inanimated_object"]
|
||||||
|
|
||||||
|
ATTRIBUTES_ALL = CLASSES + DIRECTIONS + NATURES
|
||||||
|
|
||||||
|
ANSWER_CANDIDATES = {
|
||||||
|
# Count questions
|
||||||
|
"count-all": ["0", "1", "2", "3", "4", "5", "6"],
|
||||||
|
"count-other": ["0", "1", "2", "3", "4", "5", "6"],
|
||||||
|
"count-all-group": ["0", "1", "2", "3", "4", "5", "6"],
|
||||||
|
"count-attribute": ["0", "1", "2", "3", "4", "5", "6"],
|
||||||
|
"count-attribure-group": ["0", "1", "2", "3", "4", "5", "6"],
|
||||||
|
"count-obj-rel-imm": ["0", "1", "2", "3", "4", "5", "6"],
|
||||||
|
"count-obj-rel-imm2": ["0", "1", "2", "3", "4", "5", "6"],
|
||||||
|
"count-obj-rel-early": ["0", "1", "2", "3", "4", "5", "6"],
|
||||||
|
"count-obj-exclude-imm": ["0", "1", "2", "3", "4", "5", "6"],
|
||||||
|
"count-obj-exclude-early": ["0", "1", "2", "3", "4", "5", "6"],
|
||||||
|
|
||||||
|
# Existence questions
|
||||||
|
"exist-other": ["yes", "no"],
|
||||||
|
"exist-attribute": ["yes", "no"],
|
||||||
|
"exist-attribute-group": ["yes", "no"],
|
||||||
|
"exist-obj-rel-imm": ["yes", "no"],
|
||||||
|
"exist-obj-rel-imm2": ["yes", "no"],
|
||||||
|
"exist-obj-rel-early": ["yes", "no"],
|
||||||
|
"exist-obj-exclude-imm": ["yes", "no"],
|
||||||
|
"exist-obj-exclude-early": ["yes", "no"],
|
||||||
|
|
||||||
|
# Seek questions
|
||||||
|
"seek-attr-imm": ATTRIBUTES_ALL,
|
||||||
|
"seek-attr-imm2": ATTRIBUTES_ALL,
|
||||||
|
"seek-attr-early": ATTRIBUTES_ALL,
|
||||||
|
"seek-attr-sim-early": ATTRIBUTES_ALL,
|
||||||
|
"seek-attr-rel-imm": ATTRIBUTES_ALL,
|
||||||
|
"seek-attr-rel-early": ATTRIBUTES_ALL,
|
||||||
|
}
|
1678
executor/symbolic_executor.py
Normal file
1678
executor/symbolic_executor.py
Normal file
File diff suppressed because it is too large
Load diff
952
generate_dataset.py
Normal file
952
generate_dataset.py
Normal file
|
@ -0,0 +1,952 @@
|
||||||
|
r"""Generates CLEVR-Dialog dataset.
|
||||||
|
|
||||||
|
Needs access to the following files:
|
||||||
|
synonyms: Contains several synonyms for each word in the question/caption.
|
||||||
|
caption templates: List of caption templates.
|
||||||
|
question templates: List of question templates.
|
||||||
|
metainfo: Meta-information related to attributes and values of CLEVR objects.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python -u generate_dataset.py \
|
||||||
|
--scene_path="data/scenes/CLEVR_train_scenes.json" \
|
||||||
|
--num_beams=100 \
|
||||||
|
--num_workers=12 \
|
||||||
|
--save_path="data/clevr_train_raw.json"
|
||||||
|
|
||||||
|
Author: Satwik Kottur
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import collections
|
||||||
|
import json
|
||||||
|
import multiprocessing
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
from absl import flags
|
||||||
|
from absl import app
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm as progressbar
|
||||||
|
|
||||||
|
import clevr_utils as utils
|
||||||
|
import global_vars as gvars
|
||||||
|
# import constraints_splitB as constraints
|
||||||
|
import constraints
|
||||||
|
|
||||||
|
FLAGS = flags.FLAGS
|
||||||
|
flags.DEFINE_string('synonym_path', '/projects/abdessaied/clevr-dialog/templates/synonyms.json',
|
||||||
|
'Path to synonyms file')
|
||||||
|
flags.DEFINE_string('metainfo_path', '/projects/abdessaied/clevr-dialog/templates/metainfo.json',
|
||||||
|
'Path to meta information file')
|
||||||
|
flags.DEFINE_string('caption_template_root', '/projects/abdessaied/clevr-dialog/templates/captions/',
|
||||||
|
'Root to folder with caption templates')
|
||||||
|
flags.DEFINE_string('question_template_root', '/projects/abdessaied/clevr-dialog/templates/questions/',
|
||||||
|
'Root to folder with question templates')
|
||||||
|
flags.DEFINE_string('scene_path',
|
||||||
|
# '/projects/abdessaied/clevr-dialog/output/result_clevr_oroginal_test.json',
|
||||||
|
'/projects/abdessaied/clevr-dataset-gen/output_finetune_20_objs_with_masks_many_attr/CLEVR_scenes.json',
|
||||||
|
'Path to CLEVR scene path json file')
|
||||||
|
flags.DEFINE_string('scene_id_file', '',
|
||||||
|
'Path to specific CLEVR scene ids to generate dialogs')
|
||||||
|
flags.DEFINE_string('save_path', '/projects/abdessaied/clevr-dialog/output/raw_data_modified/dialogs_finetune_20_objects_10_rounds.json',
|
||||||
|
'Path to save the dataset json')
|
||||||
|
flags.DEFINE_integer('num_beams', 100, 'Number of beams in dialog search')
|
||||||
|
flags.DEFINE_integer('num_workers', 64, 'Number of workers to use in search')
|
||||||
|
flags.DEFINE_integer('captions_per_image', 5, 'Number of captions per image')
|
||||||
|
flags.DEFINE_integer('num_images', -1,
|
||||||
|
'Number of images to generate dialogs. -1 for all.')
|
||||||
|
flags.DEFINE_integer('num_rounds', 10, 'Number of rounds in each dialog')
|
||||||
|
|
||||||
|
|
||||||
|
# Number of beams and distribution of question types.
|
||||||
|
# Start cutting down beams after 5th round.
|
||||||
|
# Heuristics (for round 4):
|
||||||
|
# A. count <= 2 1 <= seek <= 3 exist <= 2
|
||||||
|
# B. count + exist <= 3
|
||||||
|
# C. Independent questions <= 1
|
||||||
|
# Heuristics (for round 5):
|
||||||
|
# A. count <= 2 2 <= seek <= 4 exist <= 2
|
||||||
|
# B. count + exist <= 3
|
||||||
|
# C. Independent questions <= 1
|
||||||
|
ranges = {3: {'indep': [0, 1], 'seek': [1, 4], 'exist': [0, 1],
|
||||||
|
'count': [0, 1], 'exist+count': [0, 2]},
|
||||||
|
4: {'indep': [0, 1], 'seek': [2, 4], 'exist': [0, 1],
|
||||||
|
'count': [0, 1], 'exist+count': [0, 2]},
|
||||||
|
5: {'indep': [0, 1], 'seek': [2, 5], 'exist': [0, 2],
|
||||||
|
'count': [0, 2], 'exist+count': [0, 3]},
|
||||||
|
6: {'indep': [0, 1], 'seek': [2, 5], 'exist': [0, 2],
|
||||||
|
'count': [0, 2], 'exist+count': [0, 3]},
|
||||||
|
7: {'indep': [0, 2], 'seek': [3, 5], 'exist': [0, 2],
|
||||||
|
'count': [0, 2], 'exist+count': [0, 3]},
|
||||||
|
8: {'indep': [0, 2], 'seek': [3, 6], 'exist': [0, 3],
|
||||||
|
'count': [0, 3], 'exist+count': [0, 3]},
|
||||||
|
9: {'indep': [0, 2], 'seek': [3, 6], 'exist': [0, 3],
|
||||||
|
'count': [0, 3], 'exist+count': [0, 4]}}
|
||||||
|
|
||||||
|
|
||||||
|
def mapping(tag):
|
||||||
|
"""Maps tag to attribute.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tag: An input tag
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tag_label: Label for the input tag
|
||||||
|
"""
|
||||||
|
|
||||||
|
return gvars.METAINFO['tag_map'][tag.replace('1', '')]
|
||||||
|
|
||||||
|
|
||||||
|
def inv_mapping(attribute, arg_id=0):
|
||||||
|
"""Inverse maps attribute to tag.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
attribute: Name of the attribute
|
||||||
|
arg_id: Argument id to use. Append 1 if arg_id is 1, else nothing
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
base_tag: The string for the tag
|
||||||
|
"""
|
||||||
|
|
||||||
|
base_tag = gvars.METAINFO['tag_inv_map'][attribute]
|
||||||
|
if arg_id > 0:
|
||||||
|
base_tag = base_tag[:-1] + str(arg_id) + base_tag[-1]
|
||||||
|
|
||||||
|
return base_tag
|
||||||
|
|
||||||
|
|
||||||
|
def get_tag_group(tag):
|
||||||
|
"""Gets the group id from tag string.
|
||||||
|
|
||||||
|
For example, tag string of <S> is 0, <S1> is 1.
|
||||||
|
Assumes single digit group id.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tag: Tag string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
group_id: Return extracted group id
|
||||||
|
"""
|
||||||
|
|
||||||
|
group_id = 0 if len(tag) <= 3 else int(tag[-2])
|
||||||
|
return group_id
|
||||||
|
|
||||||
|
|
||||||
|
def replace_attribute(text, tag, obj_group, eliminate=False):
|
||||||
|
"""Replaces the attribute tags in text using available object properties.
|
||||||
|
|
||||||
|
NOTE: If shape is to be replaced, we use 'thing' in its place.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text template to perform replacement
|
||||||
|
tag: The tags to replace in the text
|
||||||
|
obj_group: Available object properties to replace with
|
||||||
|
eliminate: Eliminate the remaining attribute tags
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
replaced_text: The replaced text
|
||||||
|
"""
|
||||||
|
|
||||||
|
group = get_tag_group(tag)
|
||||||
|
if mapping(tag) == 'relation':
|
||||||
|
# Actual relation tag, else position tag.
|
||||||
|
if tag == '<R>':
|
||||||
|
relation_list = gvars.METAINFO['relation_phrases'][obj_group['relation']]
|
||||||
|
relation_cand = random.choice(relation_list)
|
||||||
|
else:
|
||||||
|
relation_cand = obj_group['relation']
|
||||||
|
|
||||||
|
return text.replace(tag, relation_cand)
|
||||||
|
|
||||||
|
if mapping(tag) == 'shape':
|
||||||
|
if eliminate:
|
||||||
|
replacer = 'thing'
|
||||||
|
else:
|
||||||
|
replacer = str(obj_group['objects'][group][mapping(tag)])
|
||||||
|
|
||||||
|
# Plural forms for groups.
|
||||||
|
if obj_group.get('count', 1) > 1 or obj_group.get('use_plural', False):
|
||||||
|
replacer += 's'
|
||||||
|
elif mapping(tag) == 'count':
|
||||||
|
if eliminate:
|
||||||
|
replacer = ''
|
||||||
|
else:
|
||||||
|
replacer = str(obj_group['count'])
|
||||||
|
else:
|
||||||
|
if eliminate:
|
||||||
|
replacer = ''
|
||||||
|
else:
|
||||||
|
replacer = str(obj_group['objects'][group][mapping(tag)])
|
||||||
|
return text.replace(tag, replacer)
|
||||||
|
|
||||||
|
|
||||||
|
def realize_text_and_extract_scene(scene, template, filter_objs):
|
||||||
|
"""Samples attributes for template using filtered objects.
|
||||||
|
|
||||||
|
In addition, creates scene graph for the new information added.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scene: Current scene graph
|
||||||
|
template: Text template to use to generate questions
|
||||||
|
filter_objs: Set of objects satisfying constraints of current template
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
sample: Contains the text realization and scene graph
|
||||||
|
"""
|
||||||
|
|
||||||
|
def default_list(): return collections.defaultdict(list)
|
||||||
|
graph = {'relationships': collections.defaultdict(default_list),
|
||||||
|
'counts': {}, 'exists': {}, 'history': [], 'objects': {}}
|
||||||
|
|
||||||
|
# number of inputs
|
||||||
|
n_inputs = template.get('inputs', 1)
|
||||||
|
# sample a text template
|
||||||
|
text_sample = random.choice(template['text'])
|
||||||
|
text_sample_index = template['text'].index(text_sample)
|
||||||
|
|
||||||
|
# extract attribute tags and get them into groups
|
||||||
|
tags = re.findall('(<[\d\w]*>)', text_sample)
|
||||||
|
|
||||||
|
tag_groups = collections.defaultdict(list)
|
||||||
|
for tag in tags:
|
||||||
|
group_id = get_tag_group(tag)
|
||||||
|
tag_groups[group_id].append(tag)
|
||||||
|
|
||||||
|
# sample a random element from filtered
|
||||||
|
arg_sample = random.choice(filter_objs)
|
||||||
|
# scene information obtained from the current round
|
||||||
|
graph_item = arg_sample['graph']
|
||||||
|
|
||||||
|
# remove tags from text not allowed by filter_objs
|
||||||
|
for arg_ind in range(n_inputs):
|
||||||
|
obj_sample = arg_sample['objects'][arg_ind]
|
||||||
|
avail_attrs = obj_sample['optional'] + obj_sample['required']
|
||||||
|
|
||||||
|
for ii in tag_groups[arg_ind][::-1]:
|
||||||
|
if mapping(ii) not in avail_attrs:
|
||||||
|
tag_groups[arg_ind].remove(ii)
|
||||||
|
text_sample = replace_attribute(
|
||||||
|
text_sample, ii, arg_sample, True)
|
||||||
|
|
||||||
|
# assert that all required attributes are present as tags
|
||||||
|
for attribute in obj_sample['required']:
|
||||||
|
required_tag = inv_mapping(attribute, arg_ind)
|
||||||
|
if required_tag not in tag_groups[arg_ind]:
|
||||||
|
print("required_tag: {}".format(required_tag))
|
||||||
|
print("template: {}".format(template))
|
||||||
|
assert required_tag in tag_groups[arg_ind], \
|
||||||
|
'A required attribute is missing in template!'
|
||||||
|
|
||||||
|
# start compiling tags to keep
|
||||||
|
tags_to_keep = [inv_mapping(ii, arg_ind)
|
||||||
|
for ii in obj_sample['required']]
|
||||||
|
|
||||||
|
# filter out those not present in text template
|
||||||
|
optional_tags = [inv_mapping(ii, arg_ind)
|
||||||
|
for ii in obj_sample['optional']]
|
||||||
|
optional_tags = [
|
||||||
|
ii for ii in optional_tags if ii in tag_groups[arg_ind]]
|
||||||
|
|
||||||
|
# if tags_to_keep is empty, sample from optional with 1:70 2:25 3:5
|
||||||
|
if len(optional_tags) > 0:
|
||||||
|
if len(tags_to_keep) > 0:
|
||||||
|
n_tags_sample = [0, 1, 2]
|
||||||
|
else:
|
||||||
|
n_tags_sample = [1, 2, 3]
|
||||||
|
n_sample = np.random.choice(n_tags_sample, 1,
|
||||||
|
p=gvars.METAINFO['probabilities'],
|
||||||
|
replace=False)
|
||||||
|
# lower cap at the length of optional
|
||||||
|
n_sample = min(n_sample[0], len(optional_tags))
|
||||||
|
if n_sample > 0:
|
||||||
|
tags_to_keep += random.sample(optional_tags, n_sample)
|
||||||
|
|
||||||
|
# now create a dictionary of placeholders with actual attribute values
|
||||||
|
for tag in tag_groups[arg_ind]:
|
||||||
|
remove = tag not in tags_to_keep
|
||||||
|
text_sample = replace_attribute(
|
||||||
|
text_sample, tag, arg_sample, remove)
|
||||||
|
|
||||||
|
# remove attributes from objects not included in tags_to_keep
|
||||||
|
if 'objects' in graph_item:
|
||||||
|
for ii in gvars.METAINFO['attributes']:
|
||||||
|
if inv_mapping(ii, arg_ind) not in tags_to_keep:
|
||||||
|
if ii in graph_item['objects'][arg_ind]:
|
||||||
|
del graph_item['objects'][arg_ind][ii]
|
||||||
|
|
||||||
|
# record the caption info
|
||||||
|
# Record info and merge scene graphs.
|
||||||
|
args = []
|
||||||
|
# if "unique-obj" == template['label']:
|
||||||
|
# print('yey')
|
||||||
|
for obj in arg_sample['objects']:
|
||||||
|
if obj is None:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
for k in obj['required']:
|
||||||
|
arg = obj.get(k, None)
|
||||||
|
if arg is not None:
|
||||||
|
if arg not in args: # and type(arg) == str:
|
||||||
|
args.append(arg)
|
||||||
|
else:
|
||||||
|
arg = arg_sample.get(k, None)
|
||||||
|
if arg is not None and arg not in args and type(arg) == str:
|
||||||
|
args.append(arg)
|
||||||
|
arg = obj.get('attribute', None)
|
||||||
|
if arg is not None and arg not in args:
|
||||||
|
args.append(arg)
|
||||||
|
if template['label'] == 'obj-relation':
|
||||||
|
args.append(arg_sample['relation'])
|
||||||
|
|
||||||
|
if template['label'] == "count-att-no":
|
||||||
|
template['label'] = "count-att"
|
||||||
|
|
||||||
|
graph_item['round'] = 0
|
||||||
|
sample = {}
|
||||||
|
sample['template_info'] = [copy.deepcopy(template)]
|
||||||
|
sample['args'] = args
|
||||||
|
del sample['template_info'][-1]['text']
|
||||||
|
sample['template_info'][-1]['index'] = text_sample_index
|
||||||
|
sample['caption'] = text_sample
|
||||||
|
sample['template'] = template['label']
|
||||||
|
|
||||||
|
sample['dialog'] = []
|
||||||
|
|
||||||
|
# append history, update scene graph, and save the new scene graph
|
||||||
|
graph['history'].append(graph_item)
|
||||||
|
sample['graph'] = utils.merge_update_scene_graph(graph, graph_item)
|
||||||
|
return sample
|
||||||
|
|
||||||
|
|
||||||
|
def realize_question(dialog, template, filter_objs):
|
||||||
|
"""Samples attributes for template using filtered objects.
|
||||||
|
|
||||||
|
In addition, creates scene graph for the new information added.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scene: Current scene graph
|
||||||
|
template: Text template to use to generate questions
|
||||||
|
filter_objs: Set of objects satisfying constraints of current template
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
sample: Contains the text realization and scene graph
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Number of inputs.
|
||||||
|
n_inputs = template.get('inputs', 0)
|
||||||
|
# Sample a text template.
|
||||||
|
text_sample = random.choice(template['text'])
|
||||||
|
text_sample_index = template['text'].index(text_sample)
|
||||||
|
|
||||||
|
# Extract attribute tags and get them into groups.
|
||||||
|
tags = re.findall('(<[\d\w]*>)', text_sample)
|
||||||
|
tag_groups = collections.defaultdict(list)
|
||||||
|
for tag in tags:
|
||||||
|
group_id = get_tag_group(tag)
|
||||||
|
tag_groups[group_id].append(tag)
|
||||||
|
|
||||||
|
# Sample a random element from filtered.
|
||||||
|
arg_sample = random.choice(filter_objs)
|
||||||
|
|
||||||
|
# Remove tags from text not allowed by filter_objs.
|
||||||
|
for arg_ind in range(n_inputs):
|
||||||
|
obj_sample = arg_sample['objects'][arg_ind]
|
||||||
|
avail_attrs = obj_sample['optional'] + obj_sample['required']
|
||||||
|
|
||||||
|
for ii in tag_groups[arg_ind][::-1]:
|
||||||
|
if mapping(ii) not in avail_attrs:
|
||||||
|
tag_groups[arg_ind].remove(ii)
|
||||||
|
text_sample = replace_attribute(
|
||||||
|
text_sample, ii, arg_sample, True)
|
||||||
|
|
||||||
|
# Assert that all required attributes are present as tags.
|
||||||
|
for attribute in obj_sample['required']:
|
||||||
|
required_tag = inv_mapping(attribute, arg_ind)
|
||||||
|
# Make an exception for <R> and <P>
|
||||||
|
if required_tag == '<R>' and '<P>' in tag_groups[arg_ind]:
|
||||||
|
continue
|
||||||
|
assert required_tag in tag_groups[arg_ind], \
|
||||||
|
'A required attribute is missing in template!'
|
||||||
|
|
||||||
|
# Start compiling tags to keep.
|
||||||
|
tags_to_keep = [inv_mapping(ii, arg_ind)
|
||||||
|
for ii in obj_sample['required']]
|
||||||
|
# Filter out those not present in text template.
|
||||||
|
optional_tags = [inv_mapping(ii, arg_ind)
|
||||||
|
for ii in obj_sample['optional']]
|
||||||
|
optional_tags = [
|
||||||
|
ii for ii in optional_tags if ii in tag_groups[arg_ind]]
|
||||||
|
|
||||||
|
# If tags_to_keep is empty, sample from optional with (1:70, 2:25, 3:5).
|
||||||
|
if len(optional_tags) > 0:
|
||||||
|
if len(tags_to_keep) > 0:
|
||||||
|
n_tags_sample = [0, 1, 2]
|
||||||
|
else:
|
||||||
|
n_tags_sample = [1, 2, 3]
|
||||||
|
n_sample = np.random.choice(n_tags_sample, 1,
|
||||||
|
p=gvars.METAINFO['probabilities'],
|
||||||
|
replace=False)
|
||||||
|
# Lower cap at the length of optional.
|
||||||
|
n_sample = min(n_sample[0], len(optional_tags))
|
||||||
|
if n_sample > 0:
|
||||||
|
tags_to_keep += random.sample(optional_tags, n_sample)
|
||||||
|
|
||||||
|
# Now create a dictionary of placeholders with actual attribute values.
|
||||||
|
for tag in tag_groups[arg_ind]:
|
||||||
|
remove = tag not in tags_to_keep
|
||||||
|
text_sample = replace_attribute(
|
||||||
|
text_sample, tag, arg_sample, remove)
|
||||||
|
|
||||||
|
# Record info and merge scene graphs.
|
||||||
|
args = []
|
||||||
|
# if template['label'] == 'seek-attr-early':
|
||||||
|
# print('yey')
|
||||||
|
for obj in arg_sample['objects']:
|
||||||
|
if obj is None:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
for k in obj['required']:
|
||||||
|
arg = obj.get(k, None)
|
||||||
|
if arg is not None:
|
||||||
|
if arg not in args:
|
||||||
|
args.append(arg)
|
||||||
|
else:
|
||||||
|
arg = arg_sample.get(k, None)
|
||||||
|
if arg is not None:
|
||||||
|
args.append(arg)
|
||||||
|
arg = obj.get('attribute', None)
|
||||||
|
if arg is not None and arg not in args:
|
||||||
|
args.append(arg)
|
||||||
|
|
||||||
|
# req_att_keys = [k for obj in arg_sample['objects'] for k in obj['required'] if obj is not None]
|
||||||
|
dialog_datum = {'question': text_sample, 'answer': arg_sample['answer'],
|
||||||
|
'template': template['label'], 'args': args}
|
||||||
|
dialog['template_info'].append(template.copy())
|
||||||
|
del dialog['template_info'][-1]['text']
|
||||||
|
dialog['template_info'][-1]['index'] = text_sample_index
|
||||||
|
if 'unique' in template['label']:
|
||||||
|
print('voila')
|
||||||
|
dialog['dialog'].append(dialog_datum)
|
||||||
|
graph_item = arg_sample['graph']
|
||||||
|
|
||||||
|
# If mergeable, add it to the objects list.
|
||||||
|
dialog['graph'] = utils.merge_update_scene_graph(
|
||||||
|
dialog['graph'], graph_item)
|
||||||
|
|
||||||
|
# If there are volatile objects in the graph item, remove them.
|
||||||
|
for obj in graph_item['objects'][::-1]:
|
||||||
|
if obj.get('volatile', False):
|
||||||
|
graph_item['objects'].remove(obj)
|
||||||
|
dialog['graph']['history'].append(graph_item)
|
||||||
|
return dialog
|
||||||
|
|
||||||
|
|
||||||
|
def clean_text_subroutine(text, thing, suffix):
|
||||||
|
"""Cleans the text and substitutes thing with object (subroutine).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Text string to be cleaned
|
||||||
|
thing: Whether to use 'thing' or 'object'
|
||||||
|
suffix: Either '?' (question) or '.' (caption)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
clean_text: Text string after cleaning procedure
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Synonyms + skipping optional part of the sentence
|
||||||
|
clean_text = skip_and_replace_phrases(text)
|
||||||
|
|
||||||
|
# Remove full stop, empty spaces, capitalize the start letter.
|
||||||
|
clean_text = re.sub(' +', ' ', clean_text.replace(suffix, '').strip(' '))
|
||||||
|
# First replace 'a thing' -> 'an object'.
|
||||||
|
# Then perform remaining actions.
|
||||||
|
if thing == 'object':
|
||||||
|
clean_text = clean_text.replace('a thing', 'an object')
|
||||||
|
clean_text = clean_text.replace('thing', thing)
|
||||||
|
clean_text = clean_text[0].upper() + clean_text[1:] + suffix
|
||||||
|
return clean_text
|
||||||
|
|
||||||
|
|
||||||
|
def clean_dialog_text(dialogs):
|
||||||
|
"""Cleans the dialog texts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dialogs: Generated dialogs to perform text cleaning
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dialogs: Return the dialogs after cleaning the text inplace
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Replace thing with object throughout with probability 0.5.
|
||||||
|
thing = 'thing' if random.random() > 0.5 else 'object'
|
||||||
|
for index, dialog_datum in enumerate(dialogs):
|
||||||
|
# Clean the caption.
|
||||||
|
text = dialog_datum['caption']
|
||||||
|
dialogs[index]['caption'] = clean_text_subroutine(text, thing, '.')
|
||||||
|
|
||||||
|
for r_id, dialog in enumerate(dialog_datum['dialog']):
|
||||||
|
# Clean the question.
|
||||||
|
text = dialog['question']
|
||||||
|
text = clean_text_subroutine(text, thing, '?')
|
||||||
|
dialogs[index]['dialog'][r_id]['question'] = text
|
||||||
|
return dialogs
|
||||||
|
|
||||||
|
|
||||||
|
def skip_and_replace_phrases(text):
|
||||||
|
"""Substitutes synonyms and skips optional parts stochastically.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Text string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
text: Text string with synonyms replaced and optional parts skipped
|
||||||
|
"""
|
||||||
|
|
||||||
|
# For each text in [], replace it with '' with probability 0.5.
|
||||||
|
matches = re.findall('(\[[ \w]*\])', text)
|
||||||
|
for match in matches:
|
||||||
|
if random.uniform(0, 1) > 0.5:
|
||||||
|
text = text.replace(match, '')
|
||||||
|
else:
|
||||||
|
text = text.replace(match, match[1:-1])
|
||||||
|
|
||||||
|
# Remove empty spaces, if any.
|
||||||
|
text = re.sub(' +', ' ', text)
|
||||||
|
# Search for synonyms, replace at uniformly random.
|
||||||
|
text = text.lower()
|
||||||
|
for key, values in gvars.METAINFO['synonym_keys']:
|
||||||
|
if key in text:
|
||||||
|
text = text.replace(key, random.choice(values))
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def generate_captions(scenes, templates):
|
||||||
|
"""Wrapper generates captions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scenes: List of scene graphs for which to generate captions
|
||||||
|
templates: List of available caption templates
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
generated_content: Captions generated for the input scenes
|
||||||
|
"""
|
||||||
|
|
||||||
|
template_dictionary = {ii['label']: ii for ii in templates}
|
||||||
|
generated_content = []
|
||||||
|
for scene in scenes['scenes'][0:FLAGS.num_images]:
|
||||||
|
content = {}
|
||||||
|
# Copy over image_index, split, image_filename from scene.
|
||||||
|
for key in ['image_index', 'split', 'image_filename']:
|
||||||
|
content[key] = scene[key]
|
||||||
|
|
||||||
|
content['dialogs'] = []
|
||||||
|
# Filter objects based on constraints.
|
||||||
|
filter_objs = constraints.caption(scene, templates)
|
||||||
|
for filter_obj in filter_objs:
|
||||||
|
# Realize the text, and return the partial scene knowledge (q).
|
||||||
|
template = template_dictionary[filter_obj[0]['graph']['template']]
|
||||||
|
sample = realize_text_and_extract_scene(
|
||||||
|
scene, template, filter_obj)
|
||||||
|
# Add it to the list of dialogs.
|
||||||
|
content['dialogs'].append(sample)
|
||||||
|
generated_content.append(content)
|
||||||
|
return generated_content
|
||||||
|
|
||||||
|
|
||||||
|
def generate_questions(scenes, dialogs, templates, params):
|
||||||
|
"""Wrapper generates questions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scenes: List of scene graphs to generate questions
|
||||||
|
dialogs: Contains already generated captions for scenes graphs
|
||||||
|
templates: List of available question templates
|
||||||
|
params: Beam search parameters for question generation
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
new_dialogs: Generated raw dialogs with captions and questions
|
||||||
|
"""
|
||||||
|
|
||||||
|
new_dialogs = []
|
||||||
|
for scene_id, dialog_datum in enumerate(dialogs):
|
||||||
|
image_dialogs = copy.deepcopy(dialog_datum)
|
||||||
|
image_dialogs['dialogs'] = []
|
||||||
|
|
||||||
|
for dialog in dialog_datum['dialogs']:
|
||||||
|
# Pick a template at random.
|
||||||
|
flag = False
|
||||||
|
iter_count = 0
|
||||||
|
while not flag:
|
||||||
|
# Pick a template at random.
|
||||||
|
template = random.choice(templates)
|
||||||
|
|
||||||
|
# Filter objects based on constraints.
|
||||||
|
filter_objs = constraints.question(scenes['scenes'][scene_id],
|
||||||
|
dialog, template)
|
||||||
|
flag = len(filter_objs) != 0
|
||||||
|
|
||||||
|
# Extreme case -- exit
|
||||||
|
iter_count += 1
|
||||||
|
if iter_count > 10:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Realize q question.
|
||||||
|
if flag:
|
||||||
|
deep_copy = copy.deepcopy(dialog)
|
||||||
|
gen_dialog = realize_question(deep_copy, template, filter_objs)
|
||||||
|
image_dialogs['dialogs'].append(copy.deepcopy(gen_dialog))
|
||||||
|
new_dialogs.append(image_dialogs)
|
||||||
|
|
||||||
|
return new_dialogs
|
||||||
|
|
||||||
|
|
||||||
|
def worker(scenes, cap_templates, ques_templates, worker_id, out_q):
|
||||||
|
"""Worker method generates dialogs (caption + questions) for pool of scenes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scenes: List of CLEVR scenes to generate dialogs
|
||||||
|
cap_templates: Templates for caption generation
|
||||||
|
ques_templates: Templates for question generation
|
||||||
|
worker_id: Id for the current worker
|
||||||
|
out_q: Output queue to save generated dialogs from different sources
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Adds dialogs against the worker id in the output queue.
|
||||||
|
"""
|
||||||
|
|
||||||
|
dialogs = []
|
||||||
|
for index, scene in enumerate(scenes):
|
||||||
|
cur_time = time.strftime('%a-%d%b%y-%X', time.gmtime())
|
||||||
|
print('Generating [ %s ] [ Worker: %d, Progress: %d/%d Scene: %d ]' %
|
||||||
|
(cur_time, worker_id, index, len(scenes), scene['image_index']))
|
||||||
|
try:
|
||||||
|
gen_dialog = generate_dialog_bfs(
|
||||||
|
scene, cap_templates, ques_templates)
|
||||||
|
dialogs.append(json.loads(json.dumps(gen_dialog)))
|
||||||
|
except:
|
||||||
|
print('NOTE: Missing data for %d' % scene['image_index'])
|
||||||
|
out_q.put({worker_id: dialogs})
|
||||||
|
|
||||||
|
|
||||||
|
def generate_dialog_bfs(scene, cap_templates, ques_templates):
|
||||||
|
"""Perform approximate breadth-first-search (BFS) to generate dialogs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scene: Scene graph for the CLEVR image
|
||||||
|
cap_templates: List of caption templates
|
||||||
|
ques_templates: List of question templates
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bundle: List of dialogs generated for the input scene graph
|
||||||
|
"""
|
||||||
|
|
||||||
|
bundle = {}
|
||||||
|
# Generate captions for the scene.
|
||||||
|
# Copy over image_index, split, image_filename from scene.
|
||||||
|
for key in ['image_index', 'split', 'image_filename']:
|
||||||
|
bundle[key] = scene[key]
|
||||||
|
|
||||||
|
template_dictionary = {ii['label']: ii for ii in cap_templates}
|
||||||
|
content = {}
|
||||||
|
|
||||||
|
# Filter objects based on constraints on captions.
|
||||||
|
filter_objs = constraints.caption(scene, cap_templates)
|
||||||
|
|
||||||
|
for filter_obj in filter_objs:
|
||||||
|
# Realize the text, and return the partial scene knowledge (q).
|
||||||
|
template = template_dictionary[filter_obj[0]['graph']['template']]
|
||||||
|
sample = realize_text_and_extract_scene(scene, template, filter_obj)
|
||||||
|
# Add it to the list of dialogs.
|
||||||
|
content[template['label']] = [sample]
|
||||||
|
|
||||||
|
# Now generate questions.
|
||||||
|
# Group templates, exist/count of similar type together.
|
||||||
|
ques_groups = collections.defaultdict(list)
|
||||||
|
|
||||||
|
labels = [ii['label'] for ii in ques_templates]
|
||||||
|
# print('\n'.join(labels))
|
||||||
|
for index, ii in enumerate(ques_templates):
|
||||||
|
if 'exist' in ii['label'] or 'count' in ii['label']:
|
||||||
|
ques_groups[labels[index][4:]].append(ii)
|
||||||
|
else:
|
||||||
|
ques_groups[labels[index]].append(ii)
|
||||||
|
|
||||||
|
for round_id in range(FLAGS.num_rounds):
|
||||||
|
new_content = {}
|
||||||
|
|
||||||
|
# For each group.
|
||||||
|
for cap_label, cap_dialogs in content.items():
|
||||||
|
cur_pool = []
|
||||||
|
for dialog_datum in cap_dialogs:
|
||||||
|
for _, group in ques_groups.items():
|
||||||
|
template = random.choice(group)
|
||||||
|
|
||||||
|
# Make a copy.
|
||||||
|
datum_copy = copy.deepcopy(dialog_datum)
|
||||||
|
|
||||||
|
# Filter objects based on constraints.
|
||||||
|
filter_objs = constraints.question(
|
||||||
|
scene, datum_copy, template)
|
||||||
|
|
||||||
|
if len(filter_objs) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Realize q question.
|
||||||
|
gen_dialog = realize_question(
|
||||||
|
datum_copy, template, filter_objs)
|
||||||
|
cur_pool.append(gen_dialog)
|
||||||
|
|
||||||
|
if round_id in ranges:
|
||||||
|
for d_id, dialog in enumerate(cur_pool):
|
||||||
|
n_types = {'indep': 0, 'seek': 0, 'exist': 0, 'count': 0}
|
||||||
|
keep_dialog = True
|
||||||
|
|
||||||
|
labels = [ii['label']
|
||||||
|
for ii in dialog['template_info'][1:]]
|
||||||
|
for label in labels:
|
||||||
|
if label in gvars.METAINFO['independent_questions']:
|
||||||
|
n_types['indep'] += 1
|
||||||
|
|
||||||
|
label_type = label.split('-')[0]
|
||||||
|
n_types[label_type] += 1
|
||||||
|
|
||||||
|
# Heuristic A, C
|
||||||
|
for q_type, count in n_types.items():
|
||||||
|
limit = ranges[round_id][q_type]
|
||||||
|
if limit[0] > count or count > limit[1]:
|
||||||
|
keep_dialog = False
|
||||||
|
break
|
||||||
|
|
||||||
|
# Heuristic B
|
||||||
|
limit = ranges[round_id]['exist+count']
|
||||||
|
if n_types['count'] + n_types['exist'] > limit[1]:
|
||||||
|
keep_dialog = False
|
||||||
|
if not keep_dialog:
|
||||||
|
cur_pool[d_id] = None
|
||||||
|
cur_pool = [ii for ii in cur_pool if ii is not None]
|
||||||
|
|
||||||
|
# Keep limited number of beams (for speed).
|
||||||
|
if len(cur_pool) > FLAGS.num_beams:
|
||||||
|
cur_pool = sample_beams(cur_pool)[:FLAGS.num_beams]
|
||||||
|
new_content[cap_label] = cur_pool
|
||||||
|
content = copy.deepcopy(new_content)
|
||||||
|
|
||||||
|
# Get dialogs with sim, imm2, early questions.
|
||||||
|
for cap_label, cap_dialogs in content.items():
|
||||||
|
# Sample beams.
|
||||||
|
content[cap_label] = sample_beams(cap_dialogs)
|
||||||
|
|
||||||
|
# Remove keys that are empty.
|
||||||
|
empty_keys = [key for key, val in content.items() if len(val) == 0]
|
||||||
|
for key in empty_keys:
|
||||||
|
del content[key]
|
||||||
|
|
||||||
|
# For each caption, sample one.
|
||||||
|
sampled_dialogs = []
|
||||||
|
for cap_label, cap_dialogs in content.items():
|
||||||
|
if len(cap_dialogs) > 0:
|
||||||
|
sampled_dialogs.append(cap_dialogs.pop())
|
||||||
|
|
||||||
|
# Get 5 per image, compensate by taking from other entries.
|
||||||
|
content_keys = [ii for ii in content.keys()]
|
||||||
|
while len(sampled_dialogs) < 5:
|
||||||
|
random_label = random.choice(content_keys)
|
||||||
|
sampled_dialogs.append(cap_dialogs.pop())
|
||||||
|
|
||||||
|
# Finally, make the dialog text readable.
|
||||||
|
sampled_dialogs = clean_dialog_text(sampled_dialogs)
|
||||||
|
|
||||||
|
# Generate the coreference chain.
|
||||||
|
for dialog_id, dialog in enumerate(sampled_dialogs):
|
||||||
|
sampled_dialogs[dialog_id] = identify_coref_chains(dialog)
|
||||||
|
bundle['dialogs'] = sampled_dialogs
|
||||||
|
return bundle
|
||||||
|
|
||||||
|
|
||||||
|
def sample_beams(dialogs):
|
||||||
|
"""Samples beams based on the number of constraints satisfied.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dialogs: Generated dialogs to sample beams
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
sampled_dialogs: List of sampled dialogs based on the constraints
|
||||||
|
"""
|
||||||
|
|
||||||
|
num_constraints = []
|
||||||
|
for d_id, dialog in enumerate(dialogs):
|
||||||
|
satisfied = 0
|
||||||
|
labels = [ii['label'] for ii in dialog['template_info'][1:]]
|
||||||
|
|
||||||
|
# Have a imm2 for sure
|
||||||
|
satisfied += np.sum(['imm2' in ii for ii in labels])
|
||||||
|
# Have a imm2 for sure
|
||||||
|
satisfied += np.sum(['sim' in ii for ii in labels])
|
||||||
|
# Have 'early'
|
||||||
|
satisfied += min(4, np.sum(['early' in ii for ii in labels]))
|
||||||
|
|
||||||
|
# Add it with the number of constraints it satisfies.
|
||||||
|
num_constraints.append((satisfied, d_id))
|
||||||
|
|
||||||
|
# Then order.
|
||||||
|
def sort_key(x): return (x[0], random.random())
|
||||||
|
ids = sorted(num_constraints, key=sort_key, reverse=True)
|
||||||
|
sampled_dialogs = [dialogs[ii[1]] for ii in ids]
|
||||||
|
return sampled_dialogs
|
||||||
|
|
||||||
|
|
||||||
|
def identify_coref_chains(dialog):
|
||||||
|
"""Identifies the coreference chains in generated dialog.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dialog: Generated dialogs for which coreference chains to be identified
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dialog: A copy of dialog, with coreference chains annotated
|
||||||
|
"""
|
||||||
|
|
||||||
|
for r_id, datum in enumerate(dialog['dialog']):
|
||||||
|
label = datum['template']
|
||||||
|
if label in gvars.METAINFO['independent_questions']:
|
||||||
|
dialog['graph']['history'][r_id + 1]['dependence'] = None
|
||||||
|
continue
|
||||||
|
|
||||||
|
if (label == 'exist-attribute-group' or label == 'count-attribute-group' or
|
||||||
|
label == 'count-all-group'):
|
||||||
|
dialog['graph']['history'][r_id + 1]['dependence'] = r_id - 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
if 'imm' in label:
|
||||||
|
dialog['graph']['history'][r_id + 1]['dependence'] = r_id - 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
if 'early' in label:
|
||||||
|
# Go over previous history.
|
||||||
|
cur_history = dialog['graph']['history'][r_id + 1]
|
||||||
|
assert 'focus_id' in cur_history and 'focus_desc' in cur_history,\
|
||||||
|
'More focus objects than one, no focus objects!'
|
||||||
|
focus_id = cur_history['focus_id']
|
||||||
|
for attr in gvars.METAINFO['attributes']:
|
||||||
|
if attr in cur_history['focus_desc']:
|
||||||
|
break
|
||||||
|
|
||||||
|
history = dialog['graph']['history'][:r_id + 1]
|
||||||
|
for hist_id, hist_datum in enumerate(history):
|
||||||
|
for obj in hist_datum['objects']:
|
||||||
|
if obj['id'] == focus_id and attr in obj:
|
||||||
|
dialog['graph']['history'][r_id +
|
||||||
|
1]['dependence'] = hist_id - 1
|
||||||
|
break
|
||||||
|
return dialog
|
||||||
|
|
||||||
|
|
||||||
|
def main(unused_argv):
|
||||||
|
"""Main method generates the CLEVR-Dialog dataset.
|
||||||
|
"""
|
||||||
|
# Read the scene file.
|
||||||
|
with open(FLAGS.scene_path, 'r') as file_id:
|
||||||
|
scenes = json.load(file_id)
|
||||||
|
|
||||||
|
# Read the synonyms file.
|
||||||
|
with open(FLAGS.synonym_path, 'r') as file_id:
|
||||||
|
synonyms = json.load(file_id)
|
||||||
|
|
||||||
|
def sorter(x): return len(x[0].split(' '))
|
||||||
|
|
||||||
|
# Read the metainformation file.
|
||||||
|
with open(FLAGS.metainfo_path, 'r') as file_id:
|
||||||
|
gvars.METAINFO = json.load(file_id)
|
||||||
|
tag_inv_map = {attr: tag for tag, attr in gvars.METAINFO['tag_map'].items()
|
||||||
|
if tag != '<P>'}
|
||||||
|
gvars.METAINFO['tag_inv_map'] = tag_inv_map
|
||||||
|
gvars.METAINFO['synonym_keys'] = sorted(synonyms.items(),
|
||||||
|
key=sorter, reverse=True)
|
||||||
|
|
||||||
|
# Add ids to objects.
|
||||||
|
scenes = utils.add_object_ids(scenes)
|
||||||
|
scenes = utils.clean_object_attributes(scenes)
|
||||||
|
|
||||||
|
# Read the caption templates.
|
||||||
|
template_paths = os.listdir(FLAGS.caption_template_root)
|
||||||
|
cap_templates = []
|
||||||
|
for ii in template_paths:
|
||||||
|
with open(os.path.join(FLAGS.caption_template_root, ii), 'r') as file_id:
|
||||||
|
cur_templates = json.load(file_id)
|
||||||
|
cap_templates.extend(cur_templates)
|
||||||
|
# utils.pretty_print_templates(cap_templates, 1)
|
||||||
|
|
||||||
|
# Read the question templates.
|
||||||
|
template_paths = os.listdir(FLAGS.question_template_root)
|
||||||
|
ques_templates = []
|
||||||
|
for ii in template_paths:
|
||||||
|
with open(os.path.join(FLAGS.question_template_root, ii), 'r') as file_id:
|
||||||
|
cur_templates = json.load(file_id)
|
||||||
|
ques_templates.extend(cur_templates)
|
||||||
|
# utils.pretty_print_templates(ques_templates, 1)
|
||||||
|
|
||||||
|
# 1. Check if there a scene_id_file specified.
|
||||||
|
# 2. Check if num_images is -1
|
||||||
|
if FLAGS.scene_id_file != '':
|
||||||
|
with open(FLAGS.scene_id_file, 'r') as file_id:
|
||||||
|
missing_ids = [int(ii.strip('\n')) for ii in file_id.readlines()]
|
||||||
|
print('Dialogs missing for scenes: %d' % len(missing_ids))
|
||||||
|
|
||||||
|
# Create a image_index -> scenes list index dictionary
|
||||||
|
image_list_id_dict = {ii['image_index']: index
|
||||||
|
for index, ii in enumerate(scenes['scenes'])}
|
||||||
|
scenes_subset = [scenes['scenes'][image_list_id_dict[scene_id]]
|
||||||
|
for scene_id in missing_ids]
|
||||||
|
|
||||||
|
elif FLAGS.num_images == -1:
|
||||||
|
scenes_subset = scenes['scenes']
|
||||||
|
|
||||||
|
else:
|
||||||
|
scenes_subset = scenes['scenes'][0: FLAGS.num_images]
|
||||||
|
|
||||||
|
# BFS for each scene.
|
||||||
|
if FLAGS.num_workers == 1:
|
||||||
|
# Single thread version.
|
||||||
|
dialogs = []
|
||||||
|
for index, scene in enumerate(scenes_subset):
|
||||||
|
cur_time = time.strftime('%a-%d%b%y-%X', time.gmtime())
|
||||||
|
print('Generating [ %s ] [ Worker: %d, Progress: %d/%d Scene: %d ]' %
|
||||||
|
(cur_time, 0, index, len(scenes_subset), scene['image_index']))
|
||||||
|
gen_dialog = generate_dialog_bfs(
|
||||||
|
scene, cap_templates, ques_templates)
|
||||||
|
dialogs.append(gen_dialog)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Multithread version.
|
||||||
|
output_q = multiprocessing.Queue()
|
||||||
|
jobs = []
|
||||||
|
for worker_id in range(FLAGS.num_workers):
|
||||||
|
allotment = scenes_subset[worker_id::FLAGS.num_workers]
|
||||||
|
inputs = (allotment, cap_templates, ques_templates)
|
||||||
|
inputs += (worker_id, output_q)
|
||||||
|
|
||||||
|
process = multiprocessing.Process(target=worker, args=inputs)
|
||||||
|
jobs.append(process)
|
||||||
|
process.start()
|
||||||
|
|
||||||
|
# Wait for all the jobs to finish and collect the output.
|
||||||
|
final_results = {}
|
||||||
|
for _ in jobs:
|
||||||
|
final_results.update(output_q.get())
|
||||||
|
for job in jobs:
|
||||||
|
job.join()
|
||||||
|
|
||||||
|
# Flatten and sort.
|
||||||
|
final_results = [jj for _, ii in final_results.items() for jj in ii]
|
||||||
|
dialogs = sorted(final_results, key=lambda x: x['image_index'])
|
||||||
|
# utils.pretty_print_dialogs(dialogs)
|
||||||
|
|
||||||
|
# Save the dialogs.
|
||||||
|
print('Saving dialog at: %s' % FLAGS.save_path)
|
||||||
|
with open(FLAGS.save_path, 'w') as file_id:
|
||||||
|
json.dump(dialogs, file_id)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
gvars.initialize()
|
||||||
|
app.run(main)
|
1069
generate_dataset_minecraft.py
Normal file
1069
generate_dataset_minecraft.py
Normal file
File diff suppressed because it is too large
Load diff
10
global_vars.py
Normal file
10
global_vars.py
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
"""Global variables (avoid as much as possible).
|
||||||
|
Author: Satwik Kottur
|
||||||
|
"""
|
||||||
|
|
||||||
|
def initialize():
|
||||||
|
"""Sets up global variables.
|
||||||
|
"""
|
||||||
|
|
||||||
|
global METAINFO
|
||||||
|
METAINFO = {}
|
224
minecraft_utils.py
Normal file
224
minecraft_utils.py
Normal file
|
@ -0,0 +1,224 @@
|
||||||
|
"""Utilities for CLEVR-Dialog dataset generation.
|
||||||
|
|
||||||
|
Author: Satwik Kottur
|
||||||
|
"""
|
||||||
|
|
||||||
|
import copy
|
||||||
|
|
||||||
|
|
||||||
|
def pretty_print_templates(templates, verbosity=1):
|
||||||
|
"""Pretty prints templates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
templates: Templates to print
|
||||||
|
verbosity: 1 to print name and type of the templates
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Verbosity 1: Name and type.
|
||||||
|
print('-'*70)
|
||||||
|
for ii in templates:
|
||||||
|
print('[Name: %s] [Type: %s]' % (ii['name'], ii['type']))
|
||||||
|
print('-'*70)
|
||||||
|
print('Total of %s templates..' % len(templates))
|
||||||
|
print('-'*70)
|
||||||
|
|
||||||
|
|
||||||
|
def pretty_print_scene_objects(scene):
|
||||||
|
"""Pretty prints scene objects.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scene: Scene graph containing list of objects
|
||||||
|
"""
|
||||||
|
|
||||||
|
for index, ii in enumerate(scene['objects']):
|
||||||
|
print_args = (index, ii['shape'], ii['color'],
|
||||||
|
ii['size'], ii['material'])
|
||||||
|
print('\t%d : %s-%s-%s-%s' % print_args)
|
||||||
|
|
||||||
|
|
||||||
|
def pretty_print_dialogs(dialogs):
|
||||||
|
"""Pretty prints generated dialogs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dialogs: Generated dialogs to print
|
||||||
|
"""
|
||||||
|
|
||||||
|
for scene_id, dialog_datum in enumerate(dialogs):
|
||||||
|
for dialog in dialog_datum['dialogs']:
|
||||||
|
print(dialog['caption'])
|
||||||
|
for round_id, ii in enumerate(dialog['dialog']):
|
||||||
|
coref_id = dialog['graph']['history'][round_id+1]['dependence']
|
||||||
|
in_tuple = (round_id, ii['question'], str(ii['answer']),
|
||||||
|
ii['template'], str(coref_id))
|
||||||
|
print('\t[Q-%d: %s] [A: %s] [%s] [%s]' % in_tuple)
|
||||||
|
|
||||||
|
|
||||||
|
def merge_update_scene_graph(orig_graph, graph_item):
|
||||||
|
"""Merges two scene graphs into one.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
orig_graph: Original scene graph
|
||||||
|
graph_item: New graph item to add to the scene graph
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
graph: Deep copy of the original scene graph after merging
|
||||||
|
"""
|
||||||
|
|
||||||
|
graph = copy.deepcopy(orig_graph)
|
||||||
|
# Local alias.
|
||||||
|
objects = graph['objects']
|
||||||
|
|
||||||
|
# If not mergeable, return the same scene graph.
|
||||||
|
if not graph_item['mergeable']:
|
||||||
|
return graph
|
||||||
|
|
||||||
|
# 1. Go through each new object
|
||||||
|
# 2. Find its batch in objects
|
||||||
|
# a. If found, assert for a clash of attributes, update
|
||||||
|
# b. If novel, just add the object as is
|
||||||
|
for new_obj in graph_item['objects']:
|
||||||
|
match_found = False
|
||||||
|
obj = objects.get(new_obj['id'], None)
|
||||||
|
|
||||||
|
if obj:
|
||||||
|
# Assert for existing entries.
|
||||||
|
for attr in new_obj:
|
||||||
|
try:
|
||||||
|
assert new_obj[attr] == obj.get(attr, new_obj[attr]),\
|
||||||
|
'Some of the attributes do not match!'
|
||||||
|
except:
|
||||||
|
pdb.set_trace()
|
||||||
|
|
||||||
|
# Add additional keys.
|
||||||
|
objects[new_obj['id']].update(new_obj)
|
||||||
|
else:
|
||||||
|
# Add the new object.
|
||||||
|
objects[new_obj['id']] = new_obj
|
||||||
|
|
||||||
|
# if a relation, update it
|
||||||
|
if 'relation' in graph_item:
|
||||||
|
rel = graph_item['relation']
|
||||||
|
# update it with object 2 id
|
||||||
|
id1 = graph_item['objects'][0]['id']
|
||||||
|
id2 = graph_item['objects'][1]['id']
|
||||||
|
rel_objs = graph['relationships'][rel][id1]
|
||||||
|
rel_objs.append(id2)
|
||||||
|
graph['relationships'][rel][id1] = rel_objs
|
||||||
|
|
||||||
|
# update objects in graph
|
||||||
|
graph['objects'] = objects
|
||||||
|
return graph
|
||||||
|
|
||||||
|
|
||||||
|
def add_object_ids(scenes):
|
||||||
|
"""Adds object ids field for input scenes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scenes: List of CLEVR scene graphs
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
scenes: Adds object_id field for the objects in the scene graph inplace
|
||||||
|
"""
|
||||||
|
|
||||||
|
for scene_id, scene in enumerate(scenes['scenes']):
|
||||||
|
for obj_id, _ in enumerate(scene['objects']):
|
||||||
|
scenes['scenes'][scene_id]['objects'][obj_id]['id'] = obj_id
|
||||||
|
return scenes
|
||||||
|
|
||||||
|
|
||||||
|
def clean_object_attributes(scenes):
|
||||||
|
"""Cleans attributes for objects, keeping only attributes and id.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scenes: Scene graph to clean
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
scenes: Cleaned up scene graphs inplace
|
||||||
|
"""
|
||||||
|
|
||||||
|
keys = ['class', 'direction', 'nature', 'id']
|
||||||
|
for scene_id, scene in enumerate(scenes['scenes']):
|
||||||
|
for obj_id, obj in enumerate(scene['objects']):
|
||||||
|
new_obj = {key: obj[key] for key in keys}
|
||||||
|
scenes['scenes'][scene_id]['objects'][obj_id] = new_obj
|
||||||
|
return scenes
|
||||||
|
|
||||||
|
|
||||||
|
def pretty_print_corefs(dialog, coref_groups):
|
||||||
|
"""Prints coreferences for a dialog, higlighting different groups in colors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dialog: Generated dialogs to print
|
||||||
|
coref_groups: Coreference groups for dialogs
|
||||||
|
"""
|
||||||
|
|
||||||
|
colorama.init()
|
||||||
|
# Mapping of group_id -> color_ids for (foreground, background)
|
||||||
|
color_map = {}
|
||||||
|
groups = coref_groups.get(0, [])
|
||||||
|
colored, color_map = pretty_print_coref_sentence(dialog['caption'], groups,
|
||||||
|
color_map)
|
||||||
|
print('\n\nC: %s' % colored)
|
||||||
|
for round_id, round_datum in enumerate(dialog['dialog']):
|
||||||
|
question = round_datum['question']
|
||||||
|
groups = coref_groups.get(round_id + 1, [])
|
||||||
|
colored, color_map = pretty_print_coref_sentence(question, groups,
|
||||||
|
color_map)
|
||||||
|
print('%d: %s' % (round_id, colored))
|
||||||
|
|
||||||
|
|
||||||
|
def pretty_print_coref_sentence(sentence, groups, color_map):
|
||||||
|
"""Prints a sentence containing difference coreference groups.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sentence: Text sentence
|
||||||
|
groups: List of coreference groups with spans
|
||||||
|
color_map: List of groups and associated color maps
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
sentence: Text sentence with colors inserted
|
||||||
|
color_map: Updated, if new groups in the current sentence
|
||||||
|
"""
|
||||||
|
|
||||||
|
fore_colors = ['RED', 'GREEN', 'YELLOW', 'BLUE', 'MAGENTA']
|
||||||
|
back_colors = ['BLACK', 'YELLOW', 'CYAN']
|
||||||
|
insertions = []
|
||||||
|
for group in groups:
|
||||||
|
group_id = group['group_id']
|
||||||
|
if group_id in color_map:
|
||||||
|
forecolor_id, backcolor_id = color_map[group_id]
|
||||||
|
else:
|
||||||
|
num_groups = len(color_map)
|
||||||
|
forecolor_id = num_groups % len(fore_colors)
|
||||||
|
backcolor_id = num_groups // len(fore_colors)
|
||||||
|
color_map[group_id] = (forecolor_id, backcolor_id)
|
||||||
|
|
||||||
|
forecolor = fore_colors[forecolor_id]
|
||||||
|
backcolor = back_colors[backcolor_id]
|
||||||
|
insertions.append(
|
||||||
|
(group['span'][0], getattr(colorama.Fore, forecolor)))
|
||||||
|
insertions.append(
|
||||||
|
(group['span'][0], getattr(colorama.Back, backcolor)))
|
||||||
|
insertions.append((group['span'][1],
|
||||||
|
getattr(colorama.Style, 'RESET_ALL')))
|
||||||
|
|
||||||
|
# Perform insertions.
|
||||||
|
sentence = insert_into_sentence(sentence, insertions)
|
||||||
|
return sentence, color_map
|
||||||
|
|
||||||
|
|
||||||
|
def insert_into_sentence(sentence, insertions):
|
||||||
|
"""Sorts and performs insertions from right.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sentence: Sentence to perform insertions into
|
||||||
|
insertions: List of insertions, format: (position, text_insert)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
sentence: Inplace inserted sentence
|
||||||
|
"""
|
||||||
|
|
||||||
|
insertions = sorted(insertions, key=lambda x: x[0], reverse=True)
|
||||||
|
for position, text in insertions:
|
||||||
|
sentence = sentence[:position] + text + sentence[position:]
|
||||||
|
return sentence
|
BIN
misc/method_overview.png
Normal file
BIN
misc/method_overview.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 100 KiB |
BIN
misc/method_smaller.png
Normal file
BIN
misc/method_smaller.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 66 KiB |
735
preprocess_dialogs/preprocess.py
Normal file
735
preprocess_dialogs/preprocess.py
Normal file
|
@ -0,0 +1,735 @@
|
||||||
|
"""
|
||||||
|
author: Adnen Abdessaied
|
||||||
|
maintainer: "Adnen Abdessaied"
|
||||||
|
website: adnenabdessaied.de
|
||||||
|
version: 1.0.1
|
||||||
|
"""
|
||||||
|
|
||||||
|
# This script preprocesses clevr-dialog questions
|
||||||
|
|
||||||
|
from copy import deepcopy
|
||||||
|
from tqdm import tqdm
|
||||||
|
import numpy as np
|
||||||
|
import h5py
|
||||||
|
import json
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--input_dialogs_json',
|
||||||
|
help='The path of the raw dialog json file.',
|
||||||
|
required=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# '/projects/abdessaied/ns-vqa/output/clevr_vocab.json')
|
||||||
|
parser.add_argument(
|
||||||
|
'--input_vocab_json',
|
||||||
|
help='The path of the generated vocab.',
|
||||||
|
required=True
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--output_vocab_json',
|
||||||
|
help='The path to save the generated vocab.',
|
||||||
|
required=True
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--output_h5_file',
|
||||||
|
help='The path of the output h5 file.',
|
||||||
|
required=True
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--mode',
|
||||||
|
help='The preprocessing strategy.',
|
||||||
|
choices=['stack', 'concat'],
|
||||||
|
required=True
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--split',
|
||||||
|
help='The split type of the data.',
|
||||||
|
choices=['train', 'val', 'test'],
|
||||||
|
required=True
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--percentage',
|
||||||
|
default=1.0,
|
||||||
|
type=int,
|
||||||
|
help='The percentage of data to use in training.'
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--num_rounds',
|
||||||
|
type=int,
|
||||||
|
default=10,
|
||||||
|
help='The total number of rounds in one dialog.'
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--val_size',
|
||||||
|
type=int,
|
||||||
|
help='The size of the validation set.',
|
||||||
|
required=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
SPECIAL_TOKENS = {
|
||||||
|
'<NULL>': 0,
|
||||||
|
'<START>': 1,
|
||||||
|
'<END>': 2,
|
||||||
|
'<UNK>': 3,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def tokenize(s, delim=' ',
|
||||||
|
add_start_token=True, add_end_token=True,
|
||||||
|
punct_to_keep=None, punct_to_remove=None):
|
||||||
|
"""
|
||||||
|
Tokenize a sequence, converting a string s into a list of (string) tokens by
|
||||||
|
splitting on the specified delimiter. Optionally keep or remove certain
|
||||||
|
punctuation marks and add start and end tokens.
|
||||||
|
"""
|
||||||
|
if punct_to_keep is not None:
|
||||||
|
for p in punct_to_keep:
|
||||||
|
s = s.replace(p, '%s%s' % (delim, p))
|
||||||
|
|
||||||
|
if punct_to_remove is not None:
|
||||||
|
for p in punct_to_remove:
|
||||||
|
s = s.replace(p, '')
|
||||||
|
|
||||||
|
tokens = s.split(delim)
|
||||||
|
if add_start_token:
|
||||||
|
tokens.insert(0, '<START>')
|
||||||
|
if add_end_token:
|
||||||
|
tokens.append('<END>')
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
|
||||||
|
def build_vocab(sequences, min_token_count=1, delim=' ',
|
||||||
|
punct_to_keep=None, punct_to_remove=None):
|
||||||
|
token_to_count = {}
|
||||||
|
tokenize_kwargs = {
|
||||||
|
'delim': delim,
|
||||||
|
'punct_to_keep': punct_to_keep,
|
||||||
|
'punct_to_remove': punct_to_remove,
|
||||||
|
}
|
||||||
|
for seq in sequences:
|
||||||
|
seq_tokens = tokenize(seq, **tokenize_kwargs,
|
||||||
|
add_start_token=False, add_end_token=False)
|
||||||
|
for token in seq_tokens:
|
||||||
|
if token not in token_to_count:
|
||||||
|
token_to_count[token] = 0
|
||||||
|
token_to_count[token] += 1
|
||||||
|
|
||||||
|
token_to_idx = {}
|
||||||
|
for token, idx in SPECIAL_TOKENS.items():
|
||||||
|
token_to_idx[token] = idx
|
||||||
|
for token, count in sorted(token_to_count.items()):
|
||||||
|
if count >= min_token_count:
|
||||||
|
token_to_idx[token] = len(token_to_idx)
|
||||||
|
|
||||||
|
return token_to_idx
|
||||||
|
|
||||||
|
|
||||||
|
def encode(seq_tokens, token_to_idx, allow_unk=False):
|
||||||
|
seq_idx = []
|
||||||
|
for token in seq_tokens:
|
||||||
|
if token not in token_to_idx:
|
||||||
|
if allow_unk:
|
||||||
|
token = '<UNK>'
|
||||||
|
else:
|
||||||
|
raise KeyError('Token "%s" not in vocab' % token)
|
||||||
|
seq_idx.append(token_to_idx[token])
|
||||||
|
return seq_idx
|
||||||
|
|
||||||
|
|
||||||
|
def decode(seq_idx, idx_to_token, delim=None, stop_at_end=True):
|
||||||
|
tokens = []
|
||||||
|
for idx in seq_idx:
|
||||||
|
tokens.append(idx_to_token[idx])
|
||||||
|
if stop_at_end and tokens[-1] == '<END>':
|
||||||
|
break
|
||||||
|
if delim is None:
|
||||||
|
return tokens
|
||||||
|
else:
|
||||||
|
return delim.join(tokens)
|
||||||
|
|
||||||
|
|
||||||
|
def concat(allDialogs, vocab, percentage, split="train", num_rounds=10):
|
||||||
|
pbar = tqdm(allDialogs)
|
||||||
|
pbar.set_description("[INFO] Encoding data ...")
|
||||||
|
|
||||||
|
captions = []
|
||||||
|
captionProgs = []
|
||||||
|
captionImgIdx = []
|
||||||
|
|
||||||
|
questions = []
|
||||||
|
questionProgs = []
|
||||||
|
questionImgIdx = []
|
||||||
|
questionRounds = []
|
||||||
|
|
||||||
|
histories = []
|
||||||
|
historiesProg = []
|
||||||
|
|
||||||
|
answers = []
|
||||||
|
maxQ = vocab["maxQ"]
|
||||||
|
# maxC = vocab["maxC"]
|
||||||
|
maxP = vocab["maxP"]
|
||||||
|
maxH = maxQ + (num_rounds-1)*(maxQ - 1)
|
||||||
|
maxHistProg = num_rounds * maxP
|
||||||
|
|
||||||
|
questionBins = {}
|
||||||
|
captionBins = {}
|
||||||
|
# k=0
|
||||||
|
for imgDialogs in pbar:
|
||||||
|
# k+= 1
|
||||||
|
# if k>2:
|
||||||
|
# break
|
||||||
|
for dialog in imgDialogs["dialogs"]:
|
||||||
|
if split == "train":
|
||||||
|
if dialog["template"] not in captionBins:
|
||||||
|
captionBins[dialog["template"]] = {
|
||||||
|
"captions": [],
|
||||||
|
"captionProgs": []
|
||||||
|
}
|
||||||
|
|
||||||
|
caption = tokenize(dialog["caption"], punct_to_keep=[
|
||||||
|
';', ','], punct_to_remove=['?', '.'])
|
||||||
|
|
||||||
|
# if len(caption) < maxQ:
|
||||||
|
while len(caption) < maxQ:
|
||||||
|
caption.append(vocab["text_token_to_idx"]["<NULL>"])
|
||||||
|
caption = encode(
|
||||||
|
caption, vocab["text_token_to_idx"], allow_unk=True)
|
||||||
|
history = caption[:-1] # removes <END> token
|
||||||
|
|
||||||
|
captions.append(caption)
|
||||||
|
|
||||||
|
progC = [dialog["template"]] + \
|
||||||
|
list(map(lambda a: "_".join(a.split(" ")), dialog["args"]))
|
||||||
|
progC = " ".join(progC)
|
||||||
|
progC = tokenize(progC)
|
||||||
|
progC = encode(progC, vocab["prog_token_to_idx"], allow_unk=True)
|
||||||
|
while len(progC) < maxP:
|
||||||
|
progC.append(vocab["prog_token_to_idx"]["<NULL>"])
|
||||||
|
|
||||||
|
captionProgs.append(progC)
|
||||||
|
imgIdx = imgDialogs["image_index"]
|
||||||
|
captionImgIdx.append(imgIdx)
|
||||||
|
|
||||||
|
if split == "train":
|
||||||
|
captionBins[dialog["template"]]["captions"].append(caption)
|
||||||
|
captionBins[dialog["template"]]["captionProgs"].append(progC)
|
||||||
|
while len(history) < maxQ - 1:
|
||||||
|
history.append(vocab["text_token_to_idx"]["<NULL>"])
|
||||||
|
|
||||||
|
histoyProg = progC
|
||||||
|
# qRounds = []
|
||||||
|
for i, _round in enumerate(dialog["dialog"]):
|
||||||
|
question = tokenize(_round["question"], punct_to_keep=[
|
||||||
|
';', ','], punct_to_remove=['?', '.'])
|
||||||
|
question = encode(
|
||||||
|
question, vocab["text_token_to_idx"], allow_unk=True)
|
||||||
|
questionH = question[1:-1] # Delete <END> token
|
||||||
|
|
||||||
|
# if len(question) < maxQ:
|
||||||
|
# if len(question) < maxQ:
|
||||||
|
# print("q < {}".format(maxQ))
|
||||||
|
# else:
|
||||||
|
# print("q >= {}".format(maxQ))
|
||||||
|
|
||||||
|
while len(question) < maxQ:
|
||||||
|
question.append(vocab["text_token_to_idx"]["<NULL>"])
|
||||||
|
# else:
|
||||||
|
# question = question[:maxQ]
|
||||||
|
|
||||||
|
prog = [_round["template"]] + \
|
||||||
|
list(map(lambda a: "_".join(a.split(" ")), _round["args"]))
|
||||||
|
prog = " ".join(prog)
|
||||||
|
prog = tokenize(prog, punct_to_keep=[
|
||||||
|
';', ','], punct_to_remove=['?', '.'])
|
||||||
|
prog = encode(prog, vocab["prog_token_to_idx"], allow_unk=True)
|
||||||
|
|
||||||
|
while len(prog) < maxP:
|
||||||
|
prog.append(vocab["prog_token_to_idx"]["<NULL>"])
|
||||||
|
|
||||||
|
answer = tokenize("_".join(str(_round["answer"]).split(" ")), punct_to_keep=[
|
||||||
|
';', ','], punct_to_remove=['?', '.'])
|
||||||
|
answer = encode(
|
||||||
|
answer, vocab["text_token_to_idx"], allow_unk=True)
|
||||||
|
assert len(answer) == 3 # answer = <START> ans <END>
|
||||||
|
answer = answer[1]
|
||||||
|
historyPadded = deepcopy(history)
|
||||||
|
|
||||||
|
while len(historyPadded) < maxH - 1:
|
||||||
|
historyPadded.append(vocab["text_token_to_idx"]["<NULL>"])
|
||||||
|
|
||||||
|
historyProgPadded = deepcopy(histoyProg)
|
||||||
|
while len(historyProgPadded) < maxHistProg:
|
||||||
|
historyProgPadded.append(
|
||||||
|
vocab["prog_token_to_idx"]["<NULL>"])
|
||||||
|
|
||||||
|
if split == "train":
|
||||||
|
questionTypeIdx = _round["template"]
|
||||||
|
if questionTypeIdx not in questionBins:
|
||||||
|
questionBins[questionTypeIdx] = {
|
||||||
|
"questions": [],
|
||||||
|
"questionProgs": [],
|
||||||
|
"questionImgIdx": [],
|
||||||
|
"questionRounds": [],
|
||||||
|
|
||||||
|
"histories": [],
|
||||||
|
"historiesProg": [],
|
||||||
|
"answers": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
questionBins[questionTypeIdx]["questions"].append(question)
|
||||||
|
questionBins[questionTypeIdx]["questionProgs"].append(prog)
|
||||||
|
questionBins[questionTypeIdx]["questionImgIdx"].append(
|
||||||
|
imgIdx)
|
||||||
|
questionBins[questionTypeIdx]["questionRounds"].append(i+1)
|
||||||
|
|
||||||
|
questionBins[questionTypeIdx]["histories"].append(
|
||||||
|
historyPadded)
|
||||||
|
questionBins[questionTypeIdx]["historiesProg"].append(
|
||||||
|
historyProgPadded)
|
||||||
|
questionBins[questionTypeIdx]["answers"].append(answer)
|
||||||
|
else:
|
||||||
|
questions.append(question)
|
||||||
|
questionProgs.append(prog)
|
||||||
|
histories.append(historyPadded)
|
||||||
|
historiesProg.append(historyProgPadded)
|
||||||
|
answers.append(answer)
|
||||||
|
questionImgIdx.append(imgIdx)
|
||||||
|
questionRounds.append(i+1)
|
||||||
|
|
||||||
|
while len(questionH) < maxQ-2:
|
||||||
|
questionH.append(vocab["text_token_to_idx"]["<NULL>"])
|
||||||
|
qaPair = questionH + [answer]
|
||||||
|
history.extend(qaPair)
|
||||||
|
histoyProg.extend(prog)
|
||||||
|
|
||||||
|
if split == "train":
|
||||||
|
captions = []
|
||||||
|
captionProgs = []
|
||||||
|
|
||||||
|
questions = []
|
||||||
|
questionProgs = []
|
||||||
|
questionImgIdx = []
|
||||||
|
questionRounds = []
|
||||||
|
|
||||||
|
histories = []
|
||||||
|
historiesProg = []
|
||||||
|
answers = []
|
||||||
|
|
||||||
|
for ctype in captionBins:
|
||||||
|
numTrSamples = int(percentage * len(captionBins[ctype]["captions"]))
|
||||||
|
|
||||||
|
captions.extend(captionBins[ctype]["captions"][:numTrSamples])
|
||||||
|
captionProgs.extend(
|
||||||
|
captionBins[ctype]["captionProgs"][:numTrSamples])
|
||||||
|
|
||||||
|
for qtype in questionBins:
|
||||||
|
numTrSamples = int(percentage *
|
||||||
|
len(questionBins[qtype]["questions"]))
|
||||||
|
|
||||||
|
questions.extend(questionBins[qtype]["questions"][:numTrSamples])
|
||||||
|
questionProgs.extend(
|
||||||
|
questionBins[qtype]["questionProgs"][:numTrSamples])
|
||||||
|
questionImgIdx.extend(
|
||||||
|
questionBins[qtype]["questionImgIdx"][:numTrSamples])
|
||||||
|
questionRounds.extend(
|
||||||
|
questionBins[qtype]["questionRounds"][:numTrSamples])
|
||||||
|
|
||||||
|
histories.extend(questionBins[qtype]["histories"][:numTrSamples])
|
||||||
|
historiesProg.extend(
|
||||||
|
questionBins[qtype]["historiesProg"][:numTrSamples])
|
||||||
|
|
||||||
|
answers.extend(questionBins[qtype]["answers"][:numTrSamples])
|
||||||
|
|
||||||
|
result = {
|
||||||
|
split: {
|
||||||
|
"captions": captions,
|
||||||
|
"captionProgs": captionProgs,
|
||||||
|
# "captionImgIdx": captionImgIdx,
|
||||||
|
|
||||||
|
"questions": questions,
|
||||||
|
"questionProgs": questionProgs,
|
||||||
|
"questionImgIdx": questionImgIdx,
|
||||||
|
"questionRounds": questionRounds,
|
||||||
|
|
||||||
|
"histories": histories,
|
||||||
|
"historiesProg": historiesProg,
|
||||||
|
"answers": answers,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def stack(allDialogs, vocab, percentage, split="train", num_rounds=10):
|
||||||
|
pbar = tqdm(allDialogs)
|
||||||
|
pbar.set_description("[INFO] Encoding data ...")
|
||||||
|
|
||||||
|
captions = []
|
||||||
|
captionProgs = []
|
||||||
|
captionImgIdx = []
|
||||||
|
|
||||||
|
questions = []
|
||||||
|
questionProgs = []
|
||||||
|
questionImgIdx = []
|
||||||
|
questionRounds = []
|
||||||
|
|
||||||
|
histories = []
|
||||||
|
historiesProg = []
|
||||||
|
|
||||||
|
answers = []
|
||||||
|
|
||||||
|
maxQ = vocab["maxQ"]
|
||||||
|
# maxC = vocab["maxC"]
|
||||||
|
maxP = vocab["maxP"]
|
||||||
|
maxHistProg = num_rounds * maxP
|
||||||
|
questionBins = {}
|
||||||
|
captionBins = {}
|
||||||
|
|
||||||
|
for imgDialogs in pbar:
|
||||||
|
for dialog in imgDialogs["dialogs"]:
|
||||||
|
if split == "train":
|
||||||
|
if dialog["template"] not in captionBins:
|
||||||
|
captionBins[dialog["template"]] = {
|
||||||
|
"captions": [],
|
||||||
|
"captionProgs": []
|
||||||
|
}
|
||||||
|
|
||||||
|
caption = tokenize(dialog["caption"], punct_to_keep=[
|
||||||
|
';', ','], punct_to_remove=['?', '.'])
|
||||||
|
caption = encode(
|
||||||
|
caption, vocab["text_token_to_idx"], allow_unk=True)
|
||||||
|
while len(caption) < maxQ:
|
||||||
|
caption.append(vocab["text_token_to_idx"]["<NULL>"])
|
||||||
|
captions.append(caption)
|
||||||
|
|
||||||
|
progC = [dialog["template"]] + \
|
||||||
|
list(map(lambda a: "_".join(a.split(" ")), dialog["args"]))
|
||||||
|
progC = " ".join(progC)
|
||||||
|
progC = tokenize(progC)
|
||||||
|
progC = encode(progC, vocab["prog_token_to_idx"], allow_unk=True)
|
||||||
|
while len(progC) < maxP:
|
||||||
|
progC.append(vocab["prog_token_to_idx"]["<NULL>"])
|
||||||
|
|
||||||
|
captionProgs.append(progC)
|
||||||
|
imgIdx = imgDialogs["image_index"]
|
||||||
|
captionImgIdx.append(imgIdx)
|
||||||
|
|
||||||
|
if split == "train":
|
||||||
|
captionBins[dialog["template"]]["captions"].append(caption)
|
||||||
|
captionBins[dialog["template"]]["captionProgs"].append(progC)
|
||||||
|
|
||||||
|
while len(caption) < maxQ + 1:
|
||||||
|
caption.append(vocab["text_token_to_idx"]["<NULL>"])
|
||||||
|
|
||||||
|
history = np.zeros((num_rounds, maxQ + 1))
|
||||||
|
history[0, :] = caption
|
||||||
|
histoyProg = progC
|
||||||
|
# qRounds = []
|
||||||
|
for i, _round in enumerate(dialog["dialog"]):
|
||||||
|
question = tokenize(_round["question"], punct_to_keep=[
|
||||||
|
';', ','], punct_to_remove=['?', '.'])
|
||||||
|
question = encode(
|
||||||
|
question, vocab["text_token_to_idx"], allow_unk=True)
|
||||||
|
questionH = question[0:-1] # Delete <END> token
|
||||||
|
|
||||||
|
if len(question) < maxQ:
|
||||||
|
while len(question) < maxQ:
|
||||||
|
question.append(vocab["text_token_to_idx"]["<NULL>"])
|
||||||
|
else:
|
||||||
|
question = question[:maxQ]
|
||||||
|
|
||||||
|
prog = [_round["template"]] + \
|
||||||
|
list(map(lambda a: "_".join(a.split(" ")), _round["args"]))
|
||||||
|
prog = " ".join(prog)
|
||||||
|
prog = tokenize(prog, punct_to_keep=[
|
||||||
|
';', ','], punct_to_remove=['?', '.'])
|
||||||
|
prog = encode(prog, vocab["prog_token_to_idx"], allow_unk=True)
|
||||||
|
|
||||||
|
while len(prog) < maxP:
|
||||||
|
prog.append(vocab["prog_token_to_idx"]["<NULL>"])
|
||||||
|
|
||||||
|
historyProgPadded = deepcopy(histoyProg)
|
||||||
|
while len(historyProgPadded) < maxHistProg:
|
||||||
|
historyProgPadded.append(
|
||||||
|
vocab["prog_token_to_idx"]["<NULL>"])
|
||||||
|
|
||||||
|
answer = tokenize("_".join(str(_round["answer"]).split(" ")), punct_to_keep=[
|
||||||
|
';', ','], punct_to_remove=['?', '.'])
|
||||||
|
answer = encode(
|
||||||
|
answer, vocab["text_token_to_idx"], allow_unk=True)
|
||||||
|
assert len(answer) == 3 # answer = <START> ans <END>
|
||||||
|
answer = answer[1]
|
||||||
|
|
||||||
|
if split == "train":
|
||||||
|
questionTypeIdx = _round["template"]
|
||||||
|
if questionTypeIdx not in questionBins:
|
||||||
|
questionBins[questionTypeIdx] = {
|
||||||
|
"questions": [],
|
||||||
|
"questionProgs": [],
|
||||||
|
"questionImgIdx": [],
|
||||||
|
"questionRounds": [],
|
||||||
|
|
||||||
|
"histories": [],
|
||||||
|
"historiesProg": [],
|
||||||
|
"answers": [],
|
||||||
|
|
||||||
|
}
|
||||||
|
questionBins[questionTypeIdx]["questions"].append(question)
|
||||||
|
questionBins[questionTypeIdx]["questionProgs"].append(prog)
|
||||||
|
questionBins[questionTypeIdx]["questionImgIdx"].append(
|
||||||
|
imgIdx)
|
||||||
|
questionBins[questionTypeIdx]["questionRounds"].append(i+1)
|
||||||
|
|
||||||
|
questionBins[questionTypeIdx]["histories"].append(
|
||||||
|
deepcopy(history))
|
||||||
|
questionBins[questionTypeIdx]["historiesProg"].append(
|
||||||
|
historyProgPadded)
|
||||||
|
questionBins[questionTypeIdx]["answers"].append(answer)
|
||||||
|
else:
|
||||||
|
questions.append(question)
|
||||||
|
questionProgs.append(prog)
|
||||||
|
histories.append(deepcopy(history))
|
||||||
|
historiesProg.append(historyProgPadded)
|
||||||
|
answers.append(answer)
|
||||||
|
questionImgIdx.append(imgIdx)
|
||||||
|
questionRounds.append(i+1)
|
||||||
|
|
||||||
|
while len(questionH) < maxQ-1:
|
||||||
|
questionH.append(vocab["text_token_to_idx"]["<NULL>"])
|
||||||
|
qaPair = questionH + [answer] + \
|
||||||
|
[vocab["text_token_to_idx"]["<END>"]]
|
||||||
|
if i < num_rounds - 1:
|
||||||
|
history[i+1, :] = qaPair
|
||||||
|
histoyProg.extend(prog)
|
||||||
|
# questionRounds.append(qRounds)
|
||||||
|
|
||||||
|
if split == "train":
|
||||||
|
captions = []
|
||||||
|
captionProgs = []
|
||||||
|
|
||||||
|
questions = []
|
||||||
|
questionProgs = []
|
||||||
|
questionImgIdx = []
|
||||||
|
questionRounds = []
|
||||||
|
|
||||||
|
histories = []
|
||||||
|
historiesProg = []
|
||||||
|
answers = []
|
||||||
|
|
||||||
|
for ctype in captionBins:
|
||||||
|
numTrSamples = int(
|
||||||
|
percentage * len(captionBins[ctype]["captions"]))
|
||||||
|
|
||||||
|
captions.extend(captionBins[ctype]["captions"][:numTrSamples])
|
||||||
|
captionProgs.extend(
|
||||||
|
captionBins[ctype]["captionProgs"][:numTrSamples])
|
||||||
|
|
||||||
|
for qtype in questionBins:
|
||||||
|
numTrSamples = int(
|
||||||
|
percentage * len(questionBins[qtype]["questions"]))
|
||||||
|
|
||||||
|
questions.extend(questionBins[qtype]["questions"][:numTrSamples])
|
||||||
|
questionProgs.extend(
|
||||||
|
questionBins[qtype]["questionProgs"][:numTrSamples])
|
||||||
|
questionImgIdx.extend(
|
||||||
|
questionBins[qtype]["questionImgIdx"][:numTrSamples])
|
||||||
|
questionRounds.extend(
|
||||||
|
questionBins[qtype]["questionRounds"][:numTrSamples])
|
||||||
|
|
||||||
|
histories.extend(questionBins[qtype]["histories"][:numTrSamples])
|
||||||
|
historiesProg.extend(
|
||||||
|
questionBins[qtype]["historiesProg"][:numTrSamples])
|
||||||
|
|
||||||
|
answers.extend(questionBins[qtype]["answers"][:numTrSamples])
|
||||||
|
|
||||||
|
result = {
|
||||||
|
split: {
|
||||||
|
"captions": captions,
|
||||||
|
"captionProgs": captionProgs,
|
||||||
|
|
||||||
|
"questions": questions,
|
||||||
|
"questionProgs": questionProgs,
|
||||||
|
"questionImgIdx": questionImgIdx,
|
||||||
|
"questionRounds": questionRounds,
|
||||||
|
|
||||||
|
"histories": histories,
|
||||||
|
"historiesProg": historiesProg,
|
||||||
|
"answers": answers,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
assert not((args.input_vocab_json == "")
|
||||||
|
and (args.output_vocab_json == ""))
|
||||||
|
|
||||||
|
print("[INFO] Loading data ...")
|
||||||
|
with open(args.input_dialogs_json, "r") as f:
|
||||||
|
allDialogs = json.load(f)
|
||||||
|
|
||||||
|
# Either create the vocab or load it from disk
|
||||||
|
if args.input_vocab_json == "":
|
||||||
|
maxQ = 0
|
||||||
|
maxP = 0
|
||||||
|
text = []
|
||||||
|
programs = []
|
||||||
|
answers = []
|
||||||
|
pbar = tqdm(allDialogs)
|
||||||
|
pbar.set_description("[INFO] Building vocab ...")
|
||||||
|
for imgDialogs in pbar:
|
||||||
|
for dialog in imgDialogs["dialogs"]:
|
||||||
|
text.append(dialog["caption"])
|
||||||
|
tokenized_cap = tokenize(
|
||||||
|
dialog["caption"], punct_to_keep=[
|
||||||
|
';', ','], punct_to_remove=['?', '.'])
|
||||||
|
if len(tokenized_cap) > maxQ:
|
||||||
|
maxQ = len(tokenized_cap)
|
||||||
|
|
||||||
|
prog = [dialog["template"]] + \
|
||||||
|
list(map(lambda a: "_".join(a.split(" ")), dialog["args"]))
|
||||||
|
prog = " ".join(prog)
|
||||||
|
programs.append(prog)
|
||||||
|
for _round in dialog["dialog"]:
|
||||||
|
text.append(_round["question"])
|
||||||
|
tokenized_quest = tokenize(
|
||||||
|
_round["question"], punct_to_keep=[
|
||||||
|
';', ','], punct_to_remove=['?', '.'])
|
||||||
|
if len(tokenized_quest) > maxQ:
|
||||||
|
maxQ = len(tokenized_quest)
|
||||||
|
|
||||||
|
prog = [_round["template"]] + \
|
||||||
|
list(map(lambda a: "_".join(
|
||||||
|
a.split(" ")), _round["args"]))
|
||||||
|
prog = " ".join(prog)
|
||||||
|
|
||||||
|
programs.append(prog)
|
||||||
|
answers.append("_".join(str(_round["answer"]).split(" ")))
|
||||||
|
|
||||||
|
# print("longest question has {} tokens".format(maxQ))
|
||||||
|
answers = list(set(answers))
|
||||||
|
text.extend(answers)
|
||||||
|
answer_token_to_idx = build_vocab(
|
||||||
|
answers, punct_to_keep=[';', ','], punct_to_remove=['?', '.'])
|
||||||
|
text_token_to_idx = build_vocab(
|
||||||
|
text, punct_to_keep=[';', ','], punct_to_remove=['?', '.'])
|
||||||
|
prog_token_to_idx = build_vocab(programs, punct_to_keep=[
|
||||||
|
';', ','], punct_to_remove=['?', '.'])
|
||||||
|
|
||||||
|
idx_answer_to_token = {v: k for k, v in answer_token_to_idx.items()}
|
||||||
|
idx_text_to_token = {v: k for k, v in text_token_to_idx.items()}
|
||||||
|
idx_prog_to_token = {v: k for k, v in prog_token_to_idx.items()}
|
||||||
|
|
||||||
|
vocab = {
|
||||||
|
"text_token_to_idx": text_token_to_idx,
|
||||||
|
"prog_token_to_idx": prog_token_to_idx,
|
||||||
|
"answer_token_to_idx": answer_token_to_idx,
|
||||||
|
"idx_answer_to_token": idx_answer_to_token,
|
||||||
|
"idx_text_to_token": idx_text_to_token,
|
||||||
|
"idx_prog_to_token": idx_prog_to_token,
|
||||||
|
"maxQ": maxQ,
|
||||||
|
"maxP": 6,
|
||||||
|
}
|
||||||
|
|
||||||
|
else:
|
||||||
|
print("[INFO] Loading vocab ...")
|
||||||
|
|
||||||
|
with open(args.input_vocab_json, 'r') as f:
|
||||||
|
vocab = json.load(f)
|
||||||
|
print("[INFO] Vocab loaded from {} ...".format(args.input_vocab_json))
|
||||||
|
|
||||||
|
if args.output_vocab_json != "":
|
||||||
|
if not os.path.isdir(os.path.dirname(args.output_vocab_json)):
|
||||||
|
os.makedirs(os.path.dirname(args.output_vocab_json))
|
||||||
|
with open(args.output_vocab_json, 'w') as f:
|
||||||
|
json.dump(vocab, f)
|
||||||
|
print("[INFO] Vocab saved to {} ...".format(args.output_vocab_json))
|
||||||
|
|
||||||
|
# Encode all questions and programs
|
||||||
|
if args.split == "train":
|
||||||
|
if args.mode == "stack":
|
||||||
|
result = stack(allDialogs[args.val_size:], vocab, args.percentage,
|
||||||
|
split=args.split, num_rounds=args.num_rounds)
|
||||||
|
elif args.mode == "concat":
|
||||||
|
result = concat(allDialogs[args.val_size:], vocab, args.percentage,
|
||||||
|
split=args.split, num_rounds=args.num_rounds)
|
||||||
|
else:
|
||||||
|
print("[ERROR] {} is not supported. Choose between 'concat' and 'stack'".format(
|
||||||
|
args.mode))
|
||||||
|
raise ValueError
|
||||||
|
elif args.split == "val":
|
||||||
|
if args.mode == "stack":
|
||||||
|
result = stack(allDialogs[:args.val_size], vocab, 1.0,
|
||||||
|
split=args.split, num_rounds=args.num_rounds)
|
||||||
|
elif args.mode == "concat":
|
||||||
|
result = concat(allDialogs[:args.val_size], vocab, 1.0,
|
||||||
|
split=args.split, num_rounds=args.num_rounds)
|
||||||
|
else:
|
||||||
|
print("[ERROR] {} is not supported. Choose between 'concat' and 'stack'".format(
|
||||||
|
args.mode))
|
||||||
|
raise ValueError
|
||||||
|
elif args.split == "test":
|
||||||
|
if args.mode == "stack":
|
||||||
|
result = stack(allDialogs, vocab, args.percentage,
|
||||||
|
split=args.split, num_rounds=args.num_rounds)
|
||||||
|
elif args.mode == "concat":
|
||||||
|
result = concat(allDialogs, vocab, args.percentage,
|
||||||
|
split=args.split, num_rounds=args.num_rounds)
|
||||||
|
else:
|
||||||
|
print("[ERROR] {} is not supported. Choose between 'concat' and 'stack'".format(
|
||||||
|
args.mode))
|
||||||
|
raise ValueError
|
||||||
|
elif args.split == "finetune":
|
||||||
|
if args.mode == "stack":
|
||||||
|
result = stack(allDialogs, vocab, args.percentage,
|
||||||
|
split=args.split, num_rounds=args.num_rounds)
|
||||||
|
elif args.mode == "concat":
|
||||||
|
result = concat(allDialogs, vocab, args.percentage,
|
||||||
|
split=args.split, num_rounds=args.num_rounds)
|
||||||
|
else:
|
||||||
|
print("[ERROR] {} is not supported. Choose between 'concat' and 'stack'".format(
|
||||||
|
args.mode))
|
||||||
|
raise ValueError
|
||||||
|
else:
|
||||||
|
print("[ERROR] {} is not supported. Choose between 'train', 'val', and 'test'".format(
|
||||||
|
args.mode))
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
|
print("[INFO] Writing output ...")
|
||||||
|
|
||||||
|
if not os.path.isdir(os.path.dirname(args.output_h5_file)):
|
||||||
|
os.makedirs(os.path.dirname(args.output_h5_file))
|
||||||
|
|
||||||
|
for split in result:
|
||||||
|
if split != "train":
|
||||||
|
args.percentage = 1.0
|
||||||
|
with h5py.File(args.output_h5_file.format(split, args.num_rounds, args.percentage), 'w') as f:
|
||||||
|
for dataName in result[split]:
|
||||||
|
try:
|
||||||
|
data = np.asarray(result[split][dataName], dtype=np.int32)
|
||||||
|
f.create_dataset(dataName, data=data)
|
||||||
|
except ValueError as e:
|
||||||
|
print("[INFO] Error raise by {} ...".format(dataName))
|
||||||
|
raise e
|
||||||
|
|
||||||
|
print("[INFO] Done ...")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
args = parser.parse_args()
|
||||||
|
main(args)
|
94
prog_generator/clevrDialog_dataset.py
Normal file
94
prog_generator/clevrDialog_dataset.py
Normal file
|
@ -0,0 +1,94 @@
|
||||||
|
"""
|
||||||
|
author: Adnen Abdessaied
|
||||||
|
maintainer: "Adnen Abdessaied"
|
||||||
|
website: adnenabdessaied.de
|
||||||
|
version: 1.0.1
|
||||||
|
"""
|
||||||
|
|
||||||
|
import h5py
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
|
|
||||||
|
def invertDict(_dict):
|
||||||
|
return {v: k for k, v in _dict.items()}
|
||||||
|
|
||||||
|
|
||||||
|
class ClevrDialogDataset(Dataset):
|
||||||
|
def __init__(self, dataPath, vocabPath, split, indStart=0, indEnd=-1):
|
||||||
|
super(ClevrDialogDataset, self).__init__()
|
||||||
|
self.data = h5py.File(dataPath, "r")
|
||||||
|
with open(vocabPath, "r") as f:
|
||||||
|
self.vocab = json.load(f)
|
||||||
|
self.vocab["idx_text_to_token"] = invertDict(self.vocab["text_token_to_idx"])
|
||||||
|
self.vocab["idx_prog_to_token"] = invertDict(self.vocab["prog_token_to_idx"])
|
||||||
|
self.vocab["idx_prog_to_token"] = invertDict(self.vocab["prog_token_to_idx"])
|
||||||
|
self.lenVocabText = len(self.vocab["text_token_to_idx"])
|
||||||
|
self.lenVocabProg = len(self.vocab["prog_token_to_idx"])
|
||||||
|
|
||||||
|
self.split = split
|
||||||
|
self.indStart = indStart
|
||||||
|
self.indEnd = indEnd
|
||||||
|
self.maxSamples = indEnd - indStart
|
||||||
|
self.maxLenProg = 6
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class ClevrDialogCaptionDataset(ClevrDialogDataset):
|
||||||
|
def __init__(self, dataPath, vocabPath, split, name, indStart=0, indEnd=-1):
|
||||||
|
super(ClevrDialogCaptionDataset, self).__init__(dataPath, vocabPath, split, indStart=indStart, indEnd=indEnd)
|
||||||
|
self.captions = torch.LongTensor(np.asarray(self.data["captions"], dtype=np.int64)[indStart: indEnd])
|
||||||
|
self.captionsPrgs = torch.LongTensor(np.asarray(self.data["captionProgs"], dtype=np.int64)[indStart: indEnd])
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.captions)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
assert idx < len(self)
|
||||||
|
caption = self.captions[idx][:16]
|
||||||
|
captionPrg = self.captionsPrgs[idx]
|
||||||
|
return caption, captionPrg
|
||||||
|
|
||||||
|
|
||||||
|
class ClevrDialogQuestionDataset(ClevrDialogDataset):
|
||||||
|
def __init__(self, dataPath, vocabPath, split, name, train=True, indStart=0, indEnd=-1):
|
||||||
|
super(ClevrDialogQuestionDataset, self).__init__(dataPath, vocabPath, split, indStart=indStart, indEnd=indEnd)
|
||||||
|
self.questions = torch.LongTensor(np.asarray(self.data["questions"], dtype=np.int64)[indStart: indEnd])
|
||||||
|
self.quesProgs = torch.LongTensor(np.asarray(self.data["questionProgs"], dtype=np.int64)[indStart: indEnd])
|
||||||
|
self.questionRounds = torch.LongTensor(np.asarray(self.data["questionRounds"], dtype=np.int64)[indStart: indEnd])
|
||||||
|
self.questionImgIdx = torch.LongTensor(np.asarray(self.data["questionImgIdx"], dtype=np.int64)[indStart: indEnd])
|
||||||
|
self.histories = torch.LongTensor(np.asarray(self.data["histories"], dtype=np.int64)[indStart: indEnd])
|
||||||
|
self.historiesProgs = torch.LongTensor(np.asarray(self.data["historiesProg"], dtype=np.int64)[indStart: indEnd])
|
||||||
|
|
||||||
|
self.answers = torch.LongTensor(np.asarray(self.data["answers"], dtype=np.int64)[indStart: indEnd])
|
||||||
|
self.name = name
|
||||||
|
self.train = train
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.questions)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
assert idx < len(self)
|
||||||
|
question = self.questions[idx]
|
||||||
|
questionPrg = self.quesProgs[idx]
|
||||||
|
questionImgIdx = self.questionImgIdx[idx]
|
||||||
|
questionRound = self.questionRounds[idx]
|
||||||
|
|
||||||
|
history = self.histories[idx]
|
||||||
|
historiesProg = self.historiesProgs[idx]
|
||||||
|
|
||||||
|
answer = self.answers[idx]
|
||||||
|
if self.train:
|
||||||
|
return question, history, questionPrg, questionRound, answer
|
||||||
|
else:
|
||||||
|
return question, questionPrg, questionImgIdx, questionRound, history, historiesProg, answer
|
476
prog_generator/models.py
Normal file
476
prog_generator/models.py
Normal file
|
@ -0,0 +1,476 @@
|
||||||
|
"""
|
||||||
|
author: Adnen Abdessaied
|
||||||
|
maintainer: "Adnen Abdessaied"
|
||||||
|
website: adnenabdessaied.de
|
||||||
|
version: 1.0.1
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class FC(nn.Module):
|
||||||
|
def __init__(self, in_size, out_size, dropout_r=0., use_relu=True):
|
||||||
|
super(FC, self).__init__()
|
||||||
|
self.dropout_r = dropout_r
|
||||||
|
self.use_relu = use_relu
|
||||||
|
|
||||||
|
self.linear = nn.Linear(in_size, out_size)
|
||||||
|
|
||||||
|
if use_relu:
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
if dropout_r > 0:
|
||||||
|
self.dropout = nn.Dropout(dropout_r)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.linear(x)
|
||||||
|
|
||||||
|
if self.use_relu:
|
||||||
|
x = self.relu(x)
|
||||||
|
|
||||||
|
if self.dropout_r > 0:
|
||||||
|
x = self.dropout(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class MLP(nn.Module):
|
||||||
|
def __init__(self, in_size, mid_size, out_size, dropout_r=0., use_relu=True):
|
||||||
|
super(MLP, self).__init__()
|
||||||
|
|
||||||
|
self.fc = FC(in_size, mid_size, dropout_r=dropout_r, use_relu=use_relu)
|
||||||
|
self.linear = nn.Linear(mid_size, out_size)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.linear(self.fc(x))
|
||||||
|
|
||||||
|
|
||||||
|
class LayerNorm(nn.Module):
|
||||||
|
def __init__(self, size, eps=1e-6):
|
||||||
|
super(LayerNorm, self).__init__()
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
self.a_2 = nn.Parameter(torch.ones(size))
|
||||||
|
self.b_2 = nn.Parameter(torch.zeros(size))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
mean = x.mean(-1, keepdim=True)
|
||||||
|
std = x.std(-1, keepdim=True)
|
||||||
|
|
||||||
|
return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
|
||||||
|
|
||||||
|
|
||||||
|
class MHAtt(nn.Module):
|
||||||
|
def __init__(self, opts):
|
||||||
|
super(MHAtt, self).__init__()
|
||||||
|
self.opts = opts
|
||||||
|
|
||||||
|
self.linear_v = nn.Linear(opts.hiddenDim, opts.hiddenDim)
|
||||||
|
self.linear_k = nn.Linear(opts.hiddenDim, opts.hiddenDim)
|
||||||
|
self.linear_q = nn.Linear(opts.hiddenDim, opts.hiddenDim)
|
||||||
|
self.linear_merge = nn.Linear(opts.hiddenDim, opts.hiddenDim)
|
||||||
|
|
||||||
|
self.dropout = nn.Dropout(opts.dropout)
|
||||||
|
|
||||||
|
def forward(self, v, k, q, mask):
|
||||||
|
n_batches = q.size(0)
|
||||||
|
|
||||||
|
v = self.linear_v(v).view(
|
||||||
|
n_batches,
|
||||||
|
-1,
|
||||||
|
self.opts.multiHead,
|
||||||
|
self.opts.hiddenSizeHead
|
||||||
|
).transpose(1, 2)
|
||||||
|
|
||||||
|
k = self.linear_k(k).view(
|
||||||
|
n_batches,
|
||||||
|
-1,
|
||||||
|
self.opts.multiHead,
|
||||||
|
self.opts.hiddenSizeHead
|
||||||
|
).transpose(1, 2)
|
||||||
|
|
||||||
|
q = self.linear_q(q).view(
|
||||||
|
n_batches,
|
||||||
|
-1,
|
||||||
|
self.opts.multiHead,
|
||||||
|
self.opts.hiddenSizeHead
|
||||||
|
).transpose(1, 2)
|
||||||
|
|
||||||
|
atted = self.att(v, k, q, mask)
|
||||||
|
atted = atted.transpose(1, 2).contiguous().view(
|
||||||
|
n_batches,
|
||||||
|
-1,
|
||||||
|
self.opts.hiddenDim
|
||||||
|
)
|
||||||
|
|
||||||
|
atted = self.linear_merge(atted)
|
||||||
|
|
||||||
|
return atted
|
||||||
|
|
||||||
|
def att(self, value, key, query, mask):
|
||||||
|
d_k = query.size(-1)
|
||||||
|
|
||||||
|
scores = torch.matmul(
|
||||||
|
query, key.transpose(-2, -1)
|
||||||
|
) / math.sqrt(d_k)
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
scores = scores.masked_fill(mask, -1e9)
|
||||||
|
|
||||||
|
att_map = F.softmax(scores, dim=-1)
|
||||||
|
att_map = self.dropout(att_map)
|
||||||
|
|
||||||
|
return torch.matmul(att_map, value)
|
||||||
|
|
||||||
|
class FFN(nn.Module):
|
||||||
|
def __init__(self, opts):
|
||||||
|
super(FFN, self).__init__()
|
||||||
|
|
||||||
|
self.mlp = MLP(
|
||||||
|
in_size=opts.hiddenDim,
|
||||||
|
mid_size=opts.FeedForwardSize,
|
||||||
|
out_size=opts.hiddenDim,
|
||||||
|
dropout_r=opts.dropout,
|
||||||
|
use_relu=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.mlp(x)
|
||||||
|
|
||||||
|
|
||||||
|
class SA(nn.Module):
|
||||||
|
def __init__(self, opts):
|
||||||
|
super(SA, self).__init__()
|
||||||
|
self.mhatt = MHAtt(opts)
|
||||||
|
self.ffn = FFN(opts)
|
||||||
|
|
||||||
|
self.dropout1 = nn.Dropout(opts.dropout)
|
||||||
|
self.norm1 = LayerNorm(opts.hiddenDim)
|
||||||
|
|
||||||
|
self.dropout2 = nn.Dropout(opts.dropout)
|
||||||
|
self.norm2 = LayerNorm(opts.hiddenDim)
|
||||||
|
|
||||||
|
def forward(self, x, x_mask):
|
||||||
|
x = self.norm1(x + self.dropout1(
|
||||||
|
self.mhatt(x, x, x, x_mask)
|
||||||
|
))
|
||||||
|
|
||||||
|
x = self.norm2(x + self.dropout2(
|
||||||
|
self.ffn(x)
|
||||||
|
))
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class AttFlat(nn.Module):
|
||||||
|
def __init__(self, opts):
|
||||||
|
super(AttFlat, self).__init__()
|
||||||
|
self.opts = opts
|
||||||
|
|
||||||
|
self.mlp = MLP(
|
||||||
|
in_size=opts.hiddenDim,
|
||||||
|
mid_size=opts.FlatMLPSize,
|
||||||
|
out_size=opts.FlatGlimpses,
|
||||||
|
dropout_r=opts.dropout,
|
||||||
|
use_relu=True
|
||||||
|
)
|
||||||
|
# FLAT_GLIMPSES = 1
|
||||||
|
self.linear_merge = nn.Linear(
|
||||||
|
opts.hiddenDim * opts.FlatGlimpses,
|
||||||
|
opts.FlatOutSize
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, x_mask):
|
||||||
|
att = self.mlp(x)
|
||||||
|
att = att.masked_fill(
|
||||||
|
x_mask.squeeze(1).squeeze(1).unsqueeze(2),
|
||||||
|
-1e9
|
||||||
|
)
|
||||||
|
att = F.softmax(att, dim=1)
|
||||||
|
|
||||||
|
att_list = []
|
||||||
|
for i in range(self.opts.FlatGlimpses):
|
||||||
|
att_list.append(
|
||||||
|
torch.sum(att[:, :, i: i + 1] * x, dim=1)
|
||||||
|
)
|
||||||
|
|
||||||
|
x_atted = torch.cat(att_list, dim=1)
|
||||||
|
x_atted = self.linear_merge(x_atted)
|
||||||
|
|
||||||
|
return x_atted
|
||||||
|
|
||||||
|
class CaptionEncoder(nn.Module):
|
||||||
|
def __init__(self, opts, textVocabSize):
|
||||||
|
super(CaptionEncoder, self).__init__()
|
||||||
|
self.embedding = nn.Embedding(textVocabSize, opts.embedDim)
|
||||||
|
bidirectional = opts.bidirectional > 0
|
||||||
|
self.lstmC = nn.LSTM(
|
||||||
|
input_size=opts.embedDim,
|
||||||
|
hidden_size=opts.hiddenDim,
|
||||||
|
num_layers=opts.numLayers,
|
||||||
|
batch_first=True,
|
||||||
|
bidirectional=bidirectional
|
||||||
|
)
|
||||||
|
if bidirectional:
|
||||||
|
opts.hiddenDim *= 2
|
||||||
|
opts.hiddenSizeHead *= 2
|
||||||
|
opts.FlatOutSize *= 2
|
||||||
|
|
||||||
|
self.attCap = nn.ModuleList([SA(opts) for _ in range(opts.layers)])
|
||||||
|
self.attFlatCap = AttFlat(opts)
|
||||||
|
self.fc = nn.Linear(opts.hiddenDim, opts.hiddenDim)
|
||||||
|
|
||||||
|
def forward(self, cap, hist=None):
|
||||||
|
capMask = self.make_mask(cap.unsqueeze(2))
|
||||||
|
cap = self.embedding(cap)
|
||||||
|
cap, (_, _) = self.lstmC(cap)
|
||||||
|
capO = cap.detach().clone()
|
||||||
|
|
||||||
|
for attC in self.attCap:
|
||||||
|
cap = attC(cap, capMask)
|
||||||
|
# (batchSize, 512)
|
||||||
|
cap = self.attFlatCap(cap, capMask)
|
||||||
|
encOut = self.fc(cap)
|
||||||
|
return encOut, capO
|
||||||
|
|
||||||
|
class QuestEncoder_1(nn.Module):
|
||||||
|
"""
|
||||||
|
Concat encoder
|
||||||
|
"""
|
||||||
|
def __init__(self, opts, textVocabSize):
|
||||||
|
super(QuestEncoder_1, self).__init__()
|
||||||
|
bidirectional = opts.bidirectional > 0
|
||||||
|
|
||||||
|
self.embedding = nn.Embedding(textVocabSize, opts.embedDim)
|
||||||
|
self.lstmQ = nn.LSTM(
|
||||||
|
input_size=opts.embedDim,
|
||||||
|
hidden_size=opts.hiddenDim,
|
||||||
|
num_layers=opts.numLayers,
|
||||||
|
bidirectional=bidirectional,
|
||||||
|
batch_first=True
|
||||||
|
)
|
||||||
|
|
||||||
|
self.lstmH = nn.LSTM(
|
||||||
|
input_size=opts.embedDim,
|
||||||
|
hidden_size=opts.hiddenDim,
|
||||||
|
num_layers=opts.numLayers,
|
||||||
|
bidirectional=bidirectional,
|
||||||
|
batch_first=True)
|
||||||
|
|
||||||
|
if bidirectional:
|
||||||
|
opts.hiddenDim *= 2
|
||||||
|
opts.hiddenSizeHead *= 2
|
||||||
|
opts.FlatOutSize *= 2
|
||||||
|
self.attQues = nn.ModuleList([SA(opts) for _ in range(opts.layers)])
|
||||||
|
self.attHist = nn.ModuleList([SA(opts) for _ in range(opts.layers)])
|
||||||
|
|
||||||
|
self.attFlatQuest = AttFlat(opts)
|
||||||
|
self.fc = nn.Linear(2 * opts.hiddenDim, opts.hiddenDim)
|
||||||
|
|
||||||
|
def forward(self, quest, hist):
|
||||||
|
questMask = self.make_mask(quest.unsqueeze(2))
|
||||||
|
histMask = self.make_mask(hist.unsqueeze(2))
|
||||||
|
|
||||||
|
# quest = F.tanh(self.embedding(quest))
|
||||||
|
quest = self.embedding(quest)
|
||||||
|
|
||||||
|
quest, (_, _) = self.lstmQ(quest)
|
||||||
|
questO = quest.detach().clone()
|
||||||
|
|
||||||
|
hist = self.embedding(hist)
|
||||||
|
hist, (_, _) = self.lstmH(hist)
|
||||||
|
|
||||||
|
for attQ, attH in zip(self.attQues, self.attHist):
|
||||||
|
quest = attQ(quest, questMask)
|
||||||
|
hist = attH(hist, histMask)
|
||||||
|
# (batchSize, 512)
|
||||||
|
quest = self.attFlatQuest(quest, questMask)
|
||||||
|
|
||||||
|
# hist: (batchSize, length, 512)
|
||||||
|
attWeights = torch.sum(torch.mul(hist, quest.unsqueeze(1)), -1)
|
||||||
|
attWeights = torch.softmax(attWeights, -1)
|
||||||
|
hist = torch.sum(torch.mul(hist, attWeights.unsqueeze(2)), 1)
|
||||||
|
encOut = self.fc(torch.cat([quest, hist], -1))
|
||||||
|
|
||||||
|
return encOut, questO
|
||||||
|
|
||||||
|
# Masking
|
||||||
|
def make_mask(self, feature):
|
||||||
|
return (torch.sum(
|
||||||
|
torch.abs(feature),
|
||||||
|
dim=-1
|
||||||
|
) == 0).unsqueeze(1).unsqueeze(2)
|
||||||
|
|
||||||
|
|
||||||
|
class QuestEncoder_2(nn.Module):
|
||||||
|
"""
|
||||||
|
Stack encoder
|
||||||
|
"""
|
||||||
|
def __init__(self, opts, textVocabSize):
|
||||||
|
super(QuestEncoder_2, self).__init__()
|
||||||
|
bidirectional = opts.bidirectional > 0
|
||||||
|
self.embedding = nn.Embedding(textVocabSize, opts.embedDim)
|
||||||
|
self.lstmQ = nn.LSTM(
|
||||||
|
input_size=opts.embedDim,
|
||||||
|
hidden_size=opts.hiddenDim,
|
||||||
|
num_layers=opts.numLayers,
|
||||||
|
batch_first=True,
|
||||||
|
bidirectional=bidirectional,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.lstmH = nn.LSTM(
|
||||||
|
input_size=opts.embedDim,
|
||||||
|
hidden_size=opts.hiddenDim,
|
||||||
|
num_layers=opts.numLayers,
|
||||||
|
batch_first=True,
|
||||||
|
bidirectional=bidirectional,
|
||||||
|
)
|
||||||
|
if bidirectional:
|
||||||
|
opts.hiddenDim *= 2
|
||||||
|
|
||||||
|
self.fc = nn.Linear(2 * opts.hiddenDim, opts.hiddenDim)
|
||||||
|
|
||||||
|
def forward(self, quest, hist):
|
||||||
|
|
||||||
|
quest = F.tanh(self.embedding(quest))
|
||||||
|
quest, (questH, _) = self.lstmQ(quest)
|
||||||
|
|
||||||
|
# concatenate the last hidden states from the forward and backward pass
|
||||||
|
# of the bidirectional lstm
|
||||||
|
lastHiddenForward = questH[1:2, :, :].squeeze(0)
|
||||||
|
lastHiddenBackward = questH[3:4, :, :].squeeze(0)
|
||||||
|
|
||||||
|
# questH: (batchSize, 512)
|
||||||
|
questH = torch.cat([lastHiddenForward, lastHiddenBackward], -1)
|
||||||
|
|
||||||
|
questO = quest.detach().clone()
|
||||||
|
|
||||||
|
hist = F.tanh(self.embedding(hist))
|
||||||
|
numRounds = hist.size(1)
|
||||||
|
histFeat = []
|
||||||
|
for i in range(numRounds):
|
||||||
|
round_i = hist[:, i, :, :]
|
||||||
|
_, (round_i_h, _) = self.lstmH(round_i)
|
||||||
|
|
||||||
|
#Same as before
|
||||||
|
lastHiddenForward = round_i_h[1:2, :, :].squeeze(0)
|
||||||
|
lastHiddenBackward = round_i_h[3:4, :, :].squeeze(0)
|
||||||
|
histFeat.append(torch.cat([lastHiddenForward, lastHiddenBackward], -1))
|
||||||
|
|
||||||
|
# hist: (batchSize, rounds, 512)
|
||||||
|
histFeat = torch.stack(histFeat, 1)
|
||||||
|
attWeights = torch.sum(torch.mul(histFeat, questH.unsqueeze(1)), -1)
|
||||||
|
attWeights = torch.softmax(attWeights, -1)
|
||||||
|
histFeat = torch.sum(torch.mul(histFeat, attWeights.unsqueeze(2)), 1)
|
||||||
|
encOut = self.fc(torch.cat([questH, histFeat], -1))
|
||||||
|
return encOut, questO
|
||||||
|
|
||||||
|
|
||||||
|
class Decoder(nn.Module):
|
||||||
|
def __init__(self, opts, progVocabSize, maxLen, startID=1, endID=2):
|
||||||
|
super(Decoder, self).__init__()
|
||||||
|
self.numLayers = opts.numLayers
|
||||||
|
self.bidirectional = opts.bidirectional > 0
|
||||||
|
self.maxLen = maxLen
|
||||||
|
self.startID = startID
|
||||||
|
self.endID = endID
|
||||||
|
|
||||||
|
self.embedding = nn.Embedding(progVocabSize, opts.embedDim)
|
||||||
|
self.lstmProg = nn.LSTM(
|
||||||
|
input_size=opts.embedDim,
|
||||||
|
hidden_size=2*opts.hiddenDim if self.bidirectional else opts.hiddenDim,
|
||||||
|
num_layers=opts.numLayers,
|
||||||
|
batch_first=True,
|
||||||
|
# bidirectional=self.bidirectional,
|
||||||
|
)
|
||||||
|
hiddenDim = opts.hiddenDim
|
||||||
|
if self.bidirectional:
|
||||||
|
hiddenDim *= 2
|
||||||
|
|
||||||
|
self.fcAtt = nn.Linear(2*hiddenDim, hiddenDim)
|
||||||
|
self.fcOut = nn.Linear(hiddenDim, progVocabSize)
|
||||||
|
|
||||||
|
def initPrgHidden(self, encOut):
|
||||||
|
hidden = [encOut for _ in range(self.numLayers)]
|
||||||
|
hidden = torch.stack(hidden, 0).contiguous()
|
||||||
|
return hidden, hidden
|
||||||
|
|
||||||
|
def forwardStep(self, prog, progH, questO):
|
||||||
|
batchSize = prog.size(0)
|
||||||
|
inputDim = questO.size(1)
|
||||||
|
prog = self.embedding(prog)
|
||||||
|
outProg, progH = self.lstmProg(prog, progH)
|
||||||
|
|
||||||
|
att = torch.bmm(outProg, questO.transpose(1, 2))
|
||||||
|
att = F.softmax(att.view(-1, inputDim), 1).view(batchSize, -1, inputDim)
|
||||||
|
context = torch.bmm(att, questO)
|
||||||
|
# (batchSize, progLength, hiddenDim)
|
||||||
|
out = F.tanh(self.fcAtt(torch.cat([outProg, context], dim=-1)))
|
||||||
|
|
||||||
|
# (batchSize, progLength, progVocabSize)
|
||||||
|
out = self.fcOut(out)
|
||||||
|
predSoftmax = F.log_softmax(out, 2)
|
||||||
|
return predSoftmax, progH
|
||||||
|
|
||||||
|
def forward(self, prog, encOut, questO):
|
||||||
|
progH = self.initPrgHidden(encOut)
|
||||||
|
predSoftmax, progH = self.forwardStep(prog, progH, questO)
|
||||||
|
|
||||||
|
return predSoftmax, progH
|
||||||
|
|
||||||
|
def sample(self, encOut, questO):
|
||||||
|
batchSize = encOut.size(0)
|
||||||
|
cudaFlag = encOut.is_cuda
|
||||||
|
progH = self.initPrgHidden(encOut)
|
||||||
|
# prog = progCopy[:, 0:3]
|
||||||
|
prog = torch.LongTensor(batchSize, 1).fill_(self.startID)
|
||||||
|
# prog = torch.cat((progStart, progEnd), -1)
|
||||||
|
if cudaFlag:
|
||||||
|
prog = prog.cuda()
|
||||||
|
outputLogProbs = []
|
||||||
|
outputTokens = []
|
||||||
|
|
||||||
|
def decode(i, output):
|
||||||
|
tokens = output.topk(1, dim=-1)[1].view(batchSize, -1)
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
for i in range(self.maxLen):
|
||||||
|
predSoftmax, progH = self.forwardStep(prog, progH, questO)
|
||||||
|
prog = decode(i, predSoftmax)
|
||||||
|
|
||||||
|
return outputTokens, outputLogProbs
|
||||||
|
|
||||||
|
|
||||||
|
class SeqToSeqC(nn.Module):
|
||||||
|
def __init__(self, encoder, decoder):
|
||||||
|
super(SeqToSeqC, self).__init__()
|
||||||
|
self.encoder = encoder
|
||||||
|
self.decoder = decoder
|
||||||
|
|
||||||
|
def forward(self, cap, imgFeat, prog):
|
||||||
|
encOut, capO = self.encoder(cap, imgFeat)
|
||||||
|
predSoftmax, progHC = self.decoder(prog, encOut, capO)
|
||||||
|
return predSoftmax, progHC
|
||||||
|
|
||||||
|
|
||||||
|
class SeqToSeqQ(nn.Module):
|
||||||
|
def __init__(self, encoder, decoder):
|
||||||
|
super(SeqToSeqQ, self).__init__()
|
||||||
|
self.encoder = encoder
|
||||||
|
self.decoder = decoder
|
||||||
|
|
||||||
|
def forward(self, quest, hist, prog):
|
||||||
|
encOut, questO = self.encoder(quest, hist)
|
||||||
|
predSoftmax, progHC = self.decoder(prog, encOut, questO)
|
||||||
|
return predSoftmax, progHC
|
||||||
|
|
||||||
|
def sample(self, quest, hist):
|
||||||
|
with torch.no_grad():
|
||||||
|
encOut, questO = self.encoder(quest, hist)
|
||||||
|
outputTokens, outputLogProbs = self.decoder.sample(encOut, questO)
|
||||||
|
outputTokens = torch.stack(outputTokens, 0).transpose(0, 1)
|
||||||
|
return outputTokens
|
79
prog_generator/optim.py
Normal file
79
prog_generator/optim.py
Normal file
|
@ -0,0 +1,79 @@
|
||||||
|
"""
|
||||||
|
author: Adnen Abdessaied
|
||||||
|
maintainer: "Adnen Abdessaied"
|
||||||
|
website: adnenabdessaied.de
|
||||||
|
version: 1.0.1
|
||||||
|
"""
|
||||||
|
|
||||||
|
# --------------------------------------------------------
|
||||||
|
# adapted from https://github.com/MILVLG/mcan-vqa/blob/master/core/model/optim.py
|
||||||
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.optim as Optim
|
||||||
|
|
||||||
|
|
||||||
|
class WarmupOptimizer(object):
|
||||||
|
def __init__(self, lr_base, optimizer, data_size, batch_size):
|
||||||
|
self.optimizer = optimizer
|
||||||
|
self._step = 0
|
||||||
|
self.lr_base = lr_base
|
||||||
|
self._rate = 0
|
||||||
|
self.data_size = data_size
|
||||||
|
self.batch_size = batch_size
|
||||||
|
|
||||||
|
def step(self):
|
||||||
|
self._step += 1
|
||||||
|
|
||||||
|
rate = self.rate()
|
||||||
|
for p in self.optimizer.param_groups:
|
||||||
|
p['lr'] = rate
|
||||||
|
self._rate = rate
|
||||||
|
|
||||||
|
self.optimizer.step()
|
||||||
|
|
||||||
|
def zero_grad(self):
|
||||||
|
self.optimizer.zero_grad()
|
||||||
|
|
||||||
|
def rate(self, step=None):
|
||||||
|
if step is None:
|
||||||
|
step = self._step
|
||||||
|
|
||||||
|
if step <= int(self.data_size / self.batch_size * 1):
|
||||||
|
r = self.lr_base * 1/2.
|
||||||
|
else:
|
||||||
|
r = self.lr_base
|
||||||
|
|
||||||
|
return r
|
||||||
|
|
||||||
|
|
||||||
|
def get_optim(opts, model, data_size, lr_base=None):
|
||||||
|
if lr_base is None:
|
||||||
|
lr_base = opts.lr
|
||||||
|
|
||||||
|
if opts.optim == 'adam':
|
||||||
|
optim = Optim.Adam(
|
||||||
|
filter(lambda p: p.requires_grad, model.parameters()),
|
||||||
|
lr=0,
|
||||||
|
betas=opts.betas,
|
||||||
|
eps=opts.eps,
|
||||||
|
|
||||||
|
)
|
||||||
|
elif opts.optim == 'rmsprop':
|
||||||
|
optim = Optim.RMSprop(
|
||||||
|
filter(lambda p: p.requires_grad, model.parameters()),
|
||||||
|
lr=0,
|
||||||
|
eps=opts.eps,
|
||||||
|
weight_decay=opts.weight_decay
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError('{} optimizer is not supported'.fromat(opts.optim))
|
||||||
|
return WarmupOptimizer(
|
||||||
|
lr_base,
|
||||||
|
optim,
|
||||||
|
data_size,
|
||||||
|
opts.batch_size
|
||||||
|
)
|
||||||
|
|
||||||
|
def adjust_lr(optim, decay_r):
|
||||||
|
optim.lr_base *= decay_r
|
283
prog_generator/options_caption_parser.py
Normal file
283
prog_generator/options_caption_parser.py
Normal file
|
@ -0,0 +1,283 @@
|
||||||
|
|
||||||
|
"""
|
||||||
|
author: Adnen Abdessaied
|
||||||
|
maintainer: "Adnen Abdessaied"
|
||||||
|
website: adnenabdessaied.de
|
||||||
|
version: 1.0.1
|
||||||
|
"""
|
||||||
|
# --------------------------------------------------------
|
||||||
|
# adapted from https://github.com/kexinyi/ns-vqa/blob/master/scene_parse/attr_net/options.py
|
||||||
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import utils
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class Options():
|
||||||
|
def __init__(self):
|
||||||
|
self.parser = argparse.ArgumentParser()
|
||||||
|
self.initialized = False
|
||||||
|
|
||||||
|
def initialize(self):
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--mode',
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
choices=['train', 'test'],
|
||||||
|
help='The mode of the experiment')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--run_dir',
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
help='The experiment directory')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--load_checkpoint_path',
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
help='The path the the pretrained CaptionNet')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--res_path',
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
help='Path where to log the predicted caption programs')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--gpu_ids',
|
||||||
|
default='0',
|
||||||
|
type=str,
|
||||||
|
help='Id of the gpu to be used')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--seed',
|
||||||
|
default=42,
|
||||||
|
type=int,
|
||||||
|
help='The seed used in training')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--dataPathTr',
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
help='Path to the h5 file of the Clevr-Dialog preprocessed training data')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--dataPathVal',
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
help='Path to the h5 file of the Clevr-Dialog preprocessed validation data')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--dataPathTest',
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
help='Path to the h5 file of the Clevr-Dialog preprocessed test data')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--vocabPath',
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
help='Path to the generated vocabulary')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--batch_size',
|
||||||
|
default=64,
|
||||||
|
type=int,
|
||||||
|
help='Batch size')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--num_workers',
|
||||||
|
default=0,
|
||||||
|
type=int,
|
||||||
|
help='Number of workers for loading')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--num_iters',
|
||||||
|
default=5000,
|
||||||
|
type=int,
|
||||||
|
help='Total number of iterations')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--display_every',
|
||||||
|
default=5,
|
||||||
|
type=int,
|
||||||
|
help='Display training information every N iterations')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--debug_every',
|
||||||
|
default=100,
|
||||||
|
type=int,
|
||||||
|
help='Display debug message every N iterations')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--validate_every',
|
||||||
|
default=1000,
|
||||||
|
type=int,
|
||||||
|
help='Validate every N iterations')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--shuffle_data',
|
||||||
|
default=1,
|
||||||
|
type=int,
|
||||||
|
help='Activate to shuffle the training data')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--optim',
|
||||||
|
default='adam',
|
||||||
|
type=str,
|
||||||
|
help='The name of the optimizer to be used')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--lr',
|
||||||
|
default=1e-3,
|
||||||
|
type=float,
|
||||||
|
help='Base learning rate')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--betas',
|
||||||
|
default='0.9, 0.98',
|
||||||
|
type=str,
|
||||||
|
help='Adam optimizer\'s betas')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--eps',
|
||||||
|
default='1e-9',
|
||||||
|
type=float,
|
||||||
|
help='Adam optimizer\'s epsilon')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--lr_decay_marks',
|
||||||
|
default='50000, 55000',
|
||||||
|
type=str,
|
||||||
|
help='Learing rate decay marks')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--lr_decay_factor',
|
||||||
|
default=0.5,
|
||||||
|
type=float,
|
||||||
|
help='Learning rate decay factor')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--weight_decay',
|
||||||
|
default=1e-6,
|
||||||
|
type=float,
|
||||||
|
help='Weight decay')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--embedDim',
|
||||||
|
default=300,
|
||||||
|
type=int,
|
||||||
|
help='Embedding dimension')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--hiddenDim',
|
||||||
|
default=512,
|
||||||
|
type=int,
|
||||||
|
help='LSTM hidden dimension')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--numLayers',
|
||||||
|
default=2,
|
||||||
|
type=int,
|
||||||
|
help='Number of hidden LSTM layers')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--dropout',
|
||||||
|
default=0.1,
|
||||||
|
type=float,
|
||||||
|
help='Dropout value')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--multiHead',
|
||||||
|
default=8,
|
||||||
|
type=int,
|
||||||
|
help='Number of attention heads')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--hiddenSizeHead',
|
||||||
|
default=64,
|
||||||
|
type=int,
|
||||||
|
help='Dimension of each attention head')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--FeedForwardSize',
|
||||||
|
default=2048,
|
||||||
|
type=int,
|
||||||
|
help='Dimension of the feed forward layer')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--FlatMLPSize',
|
||||||
|
default=512,
|
||||||
|
type=int,
|
||||||
|
help='MLP flatten size')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--FlatGlimpses',
|
||||||
|
default=1,
|
||||||
|
type=int,
|
||||||
|
help='Number of flatten glimpses')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--FlatOutSize',
|
||||||
|
default=512,
|
||||||
|
type=int,
|
||||||
|
help='Final attention reduction dimension')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--layers',
|
||||||
|
default=6,
|
||||||
|
type=int,
|
||||||
|
help='Number of self attention layers')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--bidirectional',
|
||||||
|
default=1,
|
||||||
|
type=int,
|
||||||
|
help='Activate to use bidirectional LSTMs')
|
||||||
|
|
||||||
|
self.initialized = True
|
||||||
|
|
||||||
|
def parse(self):
|
||||||
|
# initialize parser
|
||||||
|
if not self.initialized:
|
||||||
|
self.initialize()
|
||||||
|
self.opts = self.parser.parse_args()
|
||||||
|
|
||||||
|
# parse gpu id list
|
||||||
|
str_gpu_ids = self.opts.gpu_ids.split(',')
|
||||||
|
self.opts.gpu_ids = []
|
||||||
|
for str_id in str_gpu_ids:
|
||||||
|
if str_id.isdigit() and int(str_id) >= 0:
|
||||||
|
self.opts.gpu_ids.append(int(str_id))
|
||||||
|
if len(self.opts.gpu_ids) > 0 and torch.cuda.is_available():
|
||||||
|
print('\n[INFO] Using {} CUDA device(s) ...'.format(len(self.opts.gpu_ids)))
|
||||||
|
else:
|
||||||
|
print('\n[INFO] Using cpu ...')
|
||||||
|
self.opts.gpu_ids = []
|
||||||
|
|
||||||
|
# parse the optimizer's betas and lr decay marks
|
||||||
|
self.opts.betas = [float(beta) for beta in self.opts.betas.split(',')]
|
||||||
|
lr_decay_marks = [int(m) for m in self.opts.lr_decay_marks.split(',')]
|
||||||
|
for i in range(1, len(lr_decay_marks)):
|
||||||
|
assert lr_decay_marks[i] > lr_decay_marks[i-1]
|
||||||
|
self.opts.lr_decay_marks = lr_decay_marks
|
||||||
|
|
||||||
|
# print and save options
|
||||||
|
args = vars(self.opts)
|
||||||
|
print('\n ' + 30*'-' + 'Opts' + 30*'-')
|
||||||
|
for k, v in args.items():
|
||||||
|
print('%s: %s' % (str(k), str(v)))
|
||||||
|
|
||||||
|
if not os.path.isdir(self.opts.run_dir):
|
||||||
|
os.makedirs(self.opts.run_dir)
|
||||||
|
filename = 'opts.txt'
|
||||||
|
file_path = os.path.join(self.opts.run_dir, filename)
|
||||||
|
with open(file_path, 'wt') as fout:
|
||||||
|
fout.write('| options\n')
|
||||||
|
for k, v in sorted(args.items()):
|
||||||
|
fout.write('%s: %s\n' % (str(k), str(v)))
|
||||||
|
return self.opts
|
326
prog_generator/options_question_parser.py
Normal file
326
prog_generator/options_question_parser.py
Normal file
|
@ -0,0 +1,326 @@
|
||||||
|
"""
|
||||||
|
author: Adnen Abdessaied
|
||||||
|
maintainer: "Adnen Abdessaied"
|
||||||
|
website: adnenabdessaied.de
|
||||||
|
version: 1.0.1
|
||||||
|
"""
|
||||||
|
# --------------------------------------------------------
|
||||||
|
# adapted from https://github.com/kexinyi/ns-vqa/blob/master/scene_parse/attr_net/options.py
|
||||||
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import utils
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class Options():
|
||||||
|
def __init__(self):
|
||||||
|
self.parser = argparse.ArgumentParser()
|
||||||
|
self.initialized = False
|
||||||
|
|
||||||
|
def initialize(self):
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--mode',
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
choices=['train', 'test_with_gt', 'test_with_pred'],
|
||||||
|
help='The mode of the experiment')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--run_dir',
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
help='The experiment directory')
|
||||||
|
|
||||||
|
# self.parser.add_argument('--dataset', default='clevr', type=str, help='dataset')
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--text_log_dir',
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
help='File to save the logged text')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--questionNetPath',
|
||||||
|
default='',
|
||||||
|
type=str,
|
||||||
|
help='Path to the pretrained QuestionNet that will be used for testing.')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--captionNetPath',
|
||||||
|
default='',
|
||||||
|
type=str,
|
||||||
|
help='Path to the pretrained CaptionNet that will be used for testing.')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--dialogLen',
|
||||||
|
default=10,
|
||||||
|
type=int,
|
||||||
|
help='Length of the dialogs to be used for testing. We used 10, 15, and 20 in our experiments.')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--last_n_rounds',
|
||||||
|
default=10,
|
||||||
|
type=int,
|
||||||
|
help='Number of the last rounds to consider in the history. We used 1, 2, 3, 4, and 10 in our experiments. ')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--encoderType',
|
||||||
|
required=True,
|
||||||
|
type=int,
|
||||||
|
choices=[1, 2],
|
||||||
|
help='Type of the encoder: 1 --> Concat, 2 --> Stack')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--load_checkpoint_path',
|
||||||
|
default='None',
|
||||||
|
type=str,
|
||||||
|
help='Path to a QestionNet checkpoint path to resume training')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--gpu_ids',
|
||||||
|
default='0',
|
||||||
|
type=str,
|
||||||
|
help='Id of the gpu to be used')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--seed',
|
||||||
|
default=42,
|
||||||
|
type=int,
|
||||||
|
help='The seed used in training')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--dataPathTr',
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
help='Path to the h5 file of the Clevr-Dialog preprocessed training data')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--dataPathVal',
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
help='Path to the h5 file of the Clevr-Dialog preprocessed validation data')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--dataPathTest',
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
help='Path to the h5 file of the Clevr-Dialog preprocessed test data')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--scenesPath',
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
help='Path to the derendered clevr-dialog scenes')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--vocabPath',
|
||||||
|
required=True,
|
||||||
|
type=str,
|
||||||
|
help='Path to the generated vocabulary')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--batch_size',
|
||||||
|
default=64,
|
||||||
|
type=int,
|
||||||
|
help='Batch size')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--countFirstFailueRound',
|
||||||
|
default=0,
|
||||||
|
type=int,
|
||||||
|
help='If activated, we count the first failure round')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--maxSamples',
|
||||||
|
default=-1,
|
||||||
|
type=int,
|
||||||
|
help='Maximum number of training samples')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--num_workers',
|
||||||
|
default=0,
|
||||||
|
type=int,
|
||||||
|
help='Number of workers for loading')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--num_iters',
|
||||||
|
default=5000,
|
||||||
|
type=int,
|
||||||
|
help='Total number of iterations')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--display_every',
|
||||||
|
default=5,
|
||||||
|
type=int,
|
||||||
|
help='Display training information every N iterations')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--validate_every',
|
||||||
|
default=1000,
|
||||||
|
type=int,
|
||||||
|
help='Validate every N iterations')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--shuffle_data',
|
||||||
|
default=1,
|
||||||
|
type=int,
|
||||||
|
help='Activate to shuffle the training data')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--optim',
|
||||||
|
default='adam',
|
||||||
|
type=str,
|
||||||
|
help='The name of the optimizer to be used')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--lr',
|
||||||
|
default=1e-3,
|
||||||
|
type=float,
|
||||||
|
help='Base learning rate')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--betas',
|
||||||
|
default='0.9, 0.98',
|
||||||
|
type=str,
|
||||||
|
help='Adam optimizer\'s betas')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--eps',
|
||||||
|
default='1e-9',
|
||||||
|
type=float,
|
||||||
|
help='Adam optimizer\'s epsilon')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--lr_decay_marks',
|
||||||
|
default='50000, 55000',
|
||||||
|
type=str,
|
||||||
|
help='Learing rate decay marks')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--lr_decay_factor',
|
||||||
|
default=0.5,
|
||||||
|
type=float,
|
||||||
|
help='Learning rate decay factor')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--weight_decay',
|
||||||
|
default=1e-6,
|
||||||
|
type=float,
|
||||||
|
help='Weight decay')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--embedDim',
|
||||||
|
default=300,
|
||||||
|
type=int,
|
||||||
|
help='Embedding dimension')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--hiddenDim',
|
||||||
|
default=512,
|
||||||
|
type=int,
|
||||||
|
help='LSTM hidden dimension')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--numLayers',
|
||||||
|
default=2,
|
||||||
|
type=int,
|
||||||
|
help='Number of hidden LSTM layers')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--dropout',
|
||||||
|
default=0.1,
|
||||||
|
type=float,
|
||||||
|
help='Dropout value')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--multiHead',
|
||||||
|
default=8,
|
||||||
|
type=int,
|
||||||
|
help='Number of attention heads')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--hiddenSizeHead',
|
||||||
|
default=64,
|
||||||
|
type=int,
|
||||||
|
help='Dimension of each attention head')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--FeedForwardSize',
|
||||||
|
default=2048,
|
||||||
|
type=int,
|
||||||
|
help='Dimension of the feed forward layer')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--FlatMLPSize',
|
||||||
|
default=512,
|
||||||
|
type=int,
|
||||||
|
help='MLP flatten size')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--FlatGlimpses',
|
||||||
|
default=1,
|
||||||
|
type=int,
|
||||||
|
help='Number of flatten glimpses')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--FlatOutSize',
|
||||||
|
default=512,
|
||||||
|
type=int,
|
||||||
|
help='Final attention reduction dimension')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--layers',
|
||||||
|
default=6,
|
||||||
|
type=int,
|
||||||
|
help='Number of self attention layers')
|
||||||
|
|
||||||
|
self.parser.add_argument(
|
||||||
|
'--bidirectional',
|
||||||
|
default=1,
|
||||||
|
type=int,
|
||||||
|
help='Activate to use bidirectional LSTMs')
|
||||||
|
|
||||||
|
self.initialized = True
|
||||||
|
|
||||||
|
def parse(self):
|
||||||
|
# initialize parser
|
||||||
|
if not self.initialized:
|
||||||
|
self.initialize()
|
||||||
|
self.opts = self.parser.parse_args()
|
||||||
|
|
||||||
|
# parse gpu id list
|
||||||
|
str_gpu_ids = self.opts.gpu_ids.split(',')
|
||||||
|
self.opts.gpu_ids = []
|
||||||
|
for str_id in str_gpu_ids:
|
||||||
|
if str_id.isdigit() and int(str_id) >= 0:
|
||||||
|
self.opts.gpu_ids.append(int(str_id))
|
||||||
|
if len(self.opts.gpu_ids) > 0 and torch.cuda.is_available():
|
||||||
|
print('\n[INFO] Using {} CUDA device(s) ...'.format(
|
||||||
|
len(self.opts.gpu_ids)))
|
||||||
|
else:
|
||||||
|
print('\n[INFO] Using cpu ...')
|
||||||
|
self.opts.gpu_ids = []
|
||||||
|
|
||||||
|
# parse the optimizer's betas and lr decay marks
|
||||||
|
self.opts.betas = [float(beta) for beta in self.opts.betas.split(',')]
|
||||||
|
lr_decay_marks = [int(m) for m in self.opts.lr_decay_marks.split(',')]
|
||||||
|
for i in range(1, len(lr_decay_marks)):
|
||||||
|
assert lr_decay_marks[i] > lr_decay_marks[i-1]
|
||||||
|
self.opts.lr_decay_marks = lr_decay_marks
|
||||||
|
|
||||||
|
# print and save options
|
||||||
|
args = vars(self.opts)
|
||||||
|
print('\n ' + 30*'-' + 'Opts' + 30*'-')
|
||||||
|
for k, v in args.items():
|
||||||
|
print('%s: %s' % (str(k), str(v)))
|
||||||
|
|
||||||
|
if not os.path.isdir(self.opts.run_dir):
|
||||||
|
os.makedirs(self.opts.run_dir)
|
||||||
|
filename = 'opts.txt'
|
||||||
|
file_path = os.path.join(self.opts.run_dir, filename)
|
||||||
|
with open(file_path, 'wt') as fout:
|
||||||
|
fout.write('| options\n')
|
||||||
|
for k, v in sorted(args.items()):
|
||||||
|
fout.write('%s: %s\n' % (str(k), str(v)))
|
||||||
|
return self.opts
|
280
prog_generator/train_caption_parser.py
Normal file
280
prog_generator/train_caption_parser.py
Normal file
|
@ -0,0 +1,280 @@
|
||||||
|
"""
|
||||||
|
author: Adnen Abdessaied
|
||||||
|
maintainer: "Adnen Abdessaied"
|
||||||
|
website: adnenabdessaied.de
|
||||||
|
version: 1.0.1
|
||||||
|
"""
|
||||||
|
|
||||||
|
from clevrDialog_dataset import ClevrDialogCaptionDataset
|
||||||
|
from models import SeqToSeqC, CaptionEncoder, Decoder
|
||||||
|
from optim import get_optim, adjust_lr
|
||||||
|
from options_caption_parser import Options
|
||||||
|
import os, json, torch, pickle, copy, time
|
||||||
|
import numpy as np
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.utils.data as Data
|
||||||
|
from tensorboardX import SummaryWriter
|
||||||
|
|
||||||
|
|
||||||
|
class Execution:
|
||||||
|
def __init__(self, opts):
|
||||||
|
self.opts = opts
|
||||||
|
|
||||||
|
self.loss_fn = torch.nn.NLLLoss().cuda()
|
||||||
|
|
||||||
|
print("[INFO] Loading dataset ...")
|
||||||
|
|
||||||
|
self.dataset_tr = ClevrDialogCaptionDataset(
|
||||||
|
opts.dataPathTr, opts.vocabPath, "train", "Captions Tr")
|
||||||
|
|
||||||
|
self.dataset_val = ClevrDialogCaptionDataset(
|
||||||
|
opts.dataPathVal, opts.vocabPath, "val", "Captions Val")
|
||||||
|
|
||||||
|
self.dataset_test = ClevrDialogCaptionDataset(
|
||||||
|
opts.dataPathTest, opts.vocabPath, "test", "Captions Test")
|
||||||
|
|
||||||
|
tb_path = os.path.join(opts.run_dir, "tb_logdir")
|
||||||
|
if not os.path.isdir(tb_path):
|
||||||
|
os.makedirs(tb_path)
|
||||||
|
|
||||||
|
self.ckpt_path = os.path.join(opts.run_dir, "ckpt_dir")
|
||||||
|
if not os.path.isdir(self.ckpt_path):
|
||||||
|
os.makedirs(self.ckpt_path)
|
||||||
|
|
||||||
|
self.writer = SummaryWriter(tb_path)
|
||||||
|
self.iter_val = 0
|
||||||
|
self.bestValAcc = float("-inf")
|
||||||
|
self.bestValIter = -1
|
||||||
|
|
||||||
|
def constructNet(self, lenVocabText, lenVocabProg, maxLenProg, ):
|
||||||
|
decoder = Decoder(self.opts, lenVocabProg, maxLenProg)
|
||||||
|
encoder = CaptionEncoder(self.opts, lenVocabText)
|
||||||
|
net = SeqToSeqC(encoder, decoder)
|
||||||
|
return net
|
||||||
|
|
||||||
|
def train(self, dataset, dataset_val=None):
|
||||||
|
# Obtain needed information
|
||||||
|
lenVocabText = dataset.lenVocabText
|
||||||
|
lenVocabProg = dataset.lenVocabProg
|
||||||
|
maxLenProg = dataset.maxLenProg
|
||||||
|
net = self.constructNet(lenVocabText, lenVocabProg, maxLenProg)
|
||||||
|
|
||||||
|
net.cuda()
|
||||||
|
net.train()
|
||||||
|
|
||||||
|
# Define the multi-gpu training if needed
|
||||||
|
if len(self.opts.gpu_ids) > 1:
|
||||||
|
net = nn.DataParallel(net, device_ids=self.opts.gpu_ids)
|
||||||
|
|
||||||
|
# Load checkpoint if resume training
|
||||||
|
if self.opts.load_checkpoint_path is not None:
|
||||||
|
print("[INFO] Resume trainig from ckpt {} ...".format(
|
||||||
|
self.opts.load_checkpoint_path
|
||||||
|
))
|
||||||
|
|
||||||
|
# Load the network parameters
|
||||||
|
ckpt = torch.load(self.opts.load_checkpoint_path)
|
||||||
|
print("[INFO] Checkpoint successfully loaded ...")
|
||||||
|
net.load_state_dict(ckpt['state_dict'])
|
||||||
|
|
||||||
|
# Load the optimizer paramters
|
||||||
|
optim = get_optim(self.opts, net, len(dataset), lr_base=ckpt['lr_base'])
|
||||||
|
optim.optimizer.load_state_dict(ckpt['optimizer'])
|
||||||
|
|
||||||
|
else:
|
||||||
|
optim = get_optim(self.opts, net, len(dataset))
|
||||||
|
_iter = 0
|
||||||
|
epoch = 0
|
||||||
|
|
||||||
|
# Define dataloader
|
||||||
|
dataloader = Data.DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_size=self.opts.batch_size,
|
||||||
|
shuffle=self.opts.shuffle_data,
|
||||||
|
num_workers=self.opts.num_workers,
|
||||||
|
)
|
||||||
|
_iterCur = 0
|
||||||
|
_totalCur = len(dataloader)
|
||||||
|
# Training loop
|
||||||
|
while _iter < self.opts.num_iters:
|
||||||
|
# Learning Rate Decay
|
||||||
|
if _iter in self.opts.lr_decay_marks:
|
||||||
|
adjust_lr(optim, self.opts.lr_decay_factor)
|
||||||
|
|
||||||
|
time_start = time.time()
|
||||||
|
# Iteration
|
||||||
|
for caption, captionPrg in dataloader:
|
||||||
|
if _iter >= self.opts.num_iters:
|
||||||
|
break
|
||||||
|
caption = caption.cuda()
|
||||||
|
captionPrg = captionPrg.cuda()
|
||||||
|
captionPrgTarget = captionPrg.clone()
|
||||||
|
optim.zero_grad()
|
||||||
|
|
||||||
|
predSoftmax, _ = net(caption, captionPrg)
|
||||||
|
|
||||||
|
loss = self.loss_fn(
|
||||||
|
predSoftmax[:, :-1, :].contiguous().view(-1, predSoftmax.size(2)),
|
||||||
|
captionPrgTarget[:, 1:].contiguous().view(-1))
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
# logging
|
||||||
|
self.writer.add_scalar(
|
||||||
|
'train/loss',
|
||||||
|
loss.cpu().data.numpy(),
|
||||||
|
global_step=_iter)
|
||||||
|
|
||||||
|
self.writer.add_scalar(
|
||||||
|
'train/lr',
|
||||||
|
optim._rate,
|
||||||
|
global_step=_iter)
|
||||||
|
if _iter % self.opts.display_every == 0:
|
||||||
|
print("\r[CLEVR-Dialog - %s (%d/%4d)][epoch %2d][iter %4d/%4d] loss: %.4f, lr: %.2e" % (
|
||||||
|
dataset.name,
|
||||||
|
_iterCur,
|
||||||
|
_totalCur,
|
||||||
|
epoch,
|
||||||
|
_iter,
|
||||||
|
self.opts.num_iters,
|
||||||
|
loss.cpu().data.numpy(),
|
||||||
|
optim._rate,
|
||||||
|
), end=' ')
|
||||||
|
optim.step()
|
||||||
|
_iter += 1
|
||||||
|
_iterCur += 1
|
||||||
|
|
||||||
|
if _iter % self.opts.validate_every == 0:
|
||||||
|
if dataset_val is not None:
|
||||||
|
valAcc = self.eval(
|
||||||
|
net,
|
||||||
|
dataset_val,
|
||||||
|
valid=True,
|
||||||
|
)
|
||||||
|
if valAcc > self.bestValAcc:
|
||||||
|
self.bestValAcc = valAcc
|
||||||
|
self.bestValIter = _iter
|
||||||
|
|
||||||
|
print("[INFO] Checkpointing model @ iter {}".format(_iter))
|
||||||
|
state = {
|
||||||
|
'state_dict': net.state_dict(),
|
||||||
|
'optimizer': optim.optimizer.state_dict(),
|
||||||
|
'lr_base': optim.lr_base,
|
||||||
|
'optim': optim.lr_base,
|
||||||
|
'last_iter': _iter,
|
||||||
|
'last_epoch': epoch,
|
||||||
|
}
|
||||||
|
# checkpointing
|
||||||
|
torch.save(
|
||||||
|
state,
|
||||||
|
os.path.join(self.ckpt_path, 'ckpt_iter' + str(_iter) + '.pkl')
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print("[INFO] No validation dataset available")
|
||||||
|
|
||||||
|
time_end = time.time()
|
||||||
|
print('Finished epoch in {}s'.format(int(time_end-time_start)))
|
||||||
|
epoch += 1
|
||||||
|
|
||||||
|
print("[INFO] Training done. Best model had val acc. {} @ iter {}...".format(self.bestValAcc, self.bestValIter))
|
||||||
|
|
||||||
|
# Evaluation
|
||||||
|
def eval(self, net, dataset, valid=False):
|
||||||
|
net = net.eval()
|
||||||
|
data_size = len(dataset)
|
||||||
|
dataloader = Data.DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_size=self.opts.batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=self.opts.num_workers,
|
||||||
|
pin_memory=False
|
||||||
|
)
|
||||||
|
allPredictedProgs = []
|
||||||
|
numAllProg = 0
|
||||||
|
falsePred = 0
|
||||||
|
for step, (caption, captionPrg) in enumerate(dataloader):
|
||||||
|
print("\rEvaluation: [step %4d/%4d]" % (
|
||||||
|
step,
|
||||||
|
int(data_size / self.opts.batch_size),
|
||||||
|
), end=' ')
|
||||||
|
caption = caption.cuda()
|
||||||
|
captionPrg = captionPrg.cuda()
|
||||||
|
tokens = net.sample(caption)
|
||||||
|
targetProgs = decodeProg(captionPrg, dataset.vocab["idx_prog_to_token"], target=True)
|
||||||
|
predProgs = decodeProg(tokens, dataset.vocab["idx_prog_to_token"])
|
||||||
|
allPredictedProgs.extend(list(map(lambda s: "( {} ( {} ) ) \n".format(s[0], ", ".join(s[1:])), predProgs)))
|
||||||
|
numAllProg += len(targetProgs)
|
||||||
|
for targetProg, predProg in zip(targetProgs, predProgs):
|
||||||
|
mainMod = targetProg[0] == predProg[0]
|
||||||
|
sameLength = len(targetProg) == len(predProg)
|
||||||
|
sameArgs = False
|
||||||
|
if sameLength:
|
||||||
|
sameArgs = True
|
||||||
|
for argTarget in targetProg[1:]:
|
||||||
|
if argTarget not in predProg[1:]:
|
||||||
|
sameArgs = False
|
||||||
|
break
|
||||||
|
|
||||||
|
if not (mainMod and sameArgs):
|
||||||
|
falsePred += 1
|
||||||
|
val_acc = (1 - (falsePred / numAllProg)) * 100.0
|
||||||
|
print("Acc: {}".format(val_acc))
|
||||||
|
net = net.train()
|
||||||
|
if not valid:
|
||||||
|
with open(self.opts.res_path, "w") as f:
|
||||||
|
f.writelines(allPredictedProgs)
|
||||||
|
print("[INFO] Predicted caption programs logged into {}".format(self.opts.res_path))
|
||||||
|
return val_acc
|
||||||
|
|
||||||
|
def run(self, run_mode):
|
||||||
|
self.set_seed(self.opts.seed)
|
||||||
|
if run_mode == 'train':
|
||||||
|
self.train(self.dataset_tr, self.dataset_val)
|
||||||
|
|
||||||
|
elif run_mode == 'test':
|
||||||
|
lenVocabText = self.dataset_test.lenVocabText
|
||||||
|
lenVocabProg = self.dataset_test.lenVocabProg
|
||||||
|
maxLenProg = self.dataset_test.maxLenProg
|
||||||
|
net = self.constructNet(lenVocabText, lenVocabProg, maxLenProg)
|
||||||
|
|
||||||
|
print('Loading ckpt {}'.format(self.opts.load_checkpoint_path))
|
||||||
|
state_dict = torch.load(self.opts.load_checkpoint_path)['state_dict']
|
||||||
|
net.load_state_dict(state_dict)
|
||||||
|
net.cuda()
|
||||||
|
self.eval(net, self.dataset_test)
|
||||||
|
|
||||||
|
else:
|
||||||
|
exit(-1)
|
||||||
|
|
||||||
|
def set_seed(self, seed):
|
||||||
|
"""Sets the seed for reproducibility.
|
||||||
|
Args:
|
||||||
|
seed (int): The seed used
|
||||||
|
"""
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
torch.backends.cudnn.deterministic = True
|
||||||
|
torch.backends.cudnn.benchmark = False
|
||||||
|
np.random.seed(seed)
|
||||||
|
print('[INFO] Seed set to {}...'.format(seed))
|
||||||
|
|
||||||
|
|
||||||
|
def decodeProg(tokens, prgIdxToToken, target=False):
|
||||||
|
tokensBatch = tokens.tolist()
|
||||||
|
progsBatch = []
|
||||||
|
for tokens in tokensBatch:
|
||||||
|
prog = []
|
||||||
|
for tok in tokens:
|
||||||
|
if tok == 2: # <END> has index 2
|
||||||
|
break
|
||||||
|
prog.append(prgIdxToToken.get(tok))
|
||||||
|
if target:
|
||||||
|
prog = prog[1:]
|
||||||
|
progsBatch.append(prog)
|
||||||
|
return progsBatch
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
opts = Options().parse()
|
||||||
|
exe = Execution(opts)
|
||||||
|
exe.run(opts.mode)
|
||||||
|
print("[INFO] Done ...")
|
912
prog_generator/train_question_parser.py
Normal file
912
prog_generator/train_question_parser.py
Normal file
|
@ -0,0 +1,912 @@
|
||||||
|
"""
|
||||||
|
author: Adnen Abdessaied
|
||||||
|
maintainer: "Adnen Abdessaied"
|
||||||
|
website: adnenabdessaied.de
|
||||||
|
version: 1.0.1
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import json, torch, pickle, copy, time
|
||||||
|
import numpy as np
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.utils.data as Data
|
||||||
|
from tensorboardX import SummaryWriter
|
||||||
|
from copy import deepcopy
|
||||||
|
from clevrDialog_dataset import ClevrDialogQuestionDataset
|
||||||
|
import pickle
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
|
from executor.symbolic_executor import SymbolicExecutorClevr, SymbolicExecutorMinecraft
|
||||||
|
from models import SeqToSeqQ, QuestEncoder_1, QuestEncoder_2, Decoder, CaptionEncoder, SeqToSeqC
|
||||||
|
from optim import get_optim, adjust_lr
|
||||||
|
from options_caption_parser import Options as OptionsC
|
||||||
|
from options_question_parser import Options as OptionsQ
|
||||||
|
|
||||||
|
|
||||||
|
class Execution:
|
||||||
|
def __init__(self, optsQ, optsC):
|
||||||
|
self.opts = deepcopy(optsQ)
|
||||||
|
if self.opts.useCuda > 0 and torch.cuda.is_available():
|
||||||
|
self.device = torch.device("cuda:0")
|
||||||
|
print("[INFO] Using GPU {} ...".format(torch.cuda.get_device_name(0)))
|
||||||
|
else:
|
||||||
|
print("[INFO] Using CPU ...")
|
||||||
|
self.device = torch.device("cpu")
|
||||||
|
|
||||||
|
self.loss_fn = torch.nn.NLLLoss().to(self.device)
|
||||||
|
|
||||||
|
print("[INFO] Loading dataset ...")
|
||||||
|
|
||||||
|
self.datasetTr = ClevrDialogQuestionDataset(
|
||||||
|
self.opts.dataPathTr, self.opts.vocabPath, "train", "All tr data")
|
||||||
|
|
||||||
|
self.datasetVal = ClevrDialogQuestionDataset(
|
||||||
|
self.opts.dataPathVal, self.opts.vocabPath, "val", "All val data", train=False)
|
||||||
|
|
||||||
|
self.datasetTest = ClevrDialogQuestionDataset(
|
||||||
|
self.opts.dataPathTest, self.opts.vocabPath, "test", "All val data", train=False)
|
||||||
|
|
||||||
|
self.QuestionNet = constructQuestionNet(
|
||||||
|
self.opts,
|
||||||
|
self.datasetTr.lenVocabText,
|
||||||
|
self.datasetTr.lenVocabProg,
|
||||||
|
self.datasetTr.maxLenProg,
|
||||||
|
)
|
||||||
|
|
||||||
|
if os.path.isfile(self.opts.captionNetPath):
|
||||||
|
self.CaptionNet = constructCaptionNet(
|
||||||
|
optsC,
|
||||||
|
self.datasetTr.lenVocabText,
|
||||||
|
self.datasetTr.lenVocabProg,
|
||||||
|
self.datasetTr.maxLenProg
|
||||||
|
)
|
||||||
|
print('Loading CaptionNet from {}'.format(self.opts.captionNetPath))
|
||||||
|
state_dict = torch.load(self.opts.captionNetPath)['state_dict']
|
||||||
|
self.CaptionNet.load_state_dict(state_dict)
|
||||||
|
self.CaptionNet.to(self.device)
|
||||||
|
total_params_cap = sum(p.numel() for p in self.CaptionNet.parameters() if p.requires_grad)
|
||||||
|
print("The caption encoder has {} trainable parameters".format(total_params_cap))
|
||||||
|
|
||||||
|
self.QuestionNet.to(self.device)
|
||||||
|
# if os.path.isfile(self.opts.load_checkpoint_path):
|
||||||
|
# print('Loading QuestionNet from {}'.format(optsQ.load_checkpoint_path))
|
||||||
|
# state_dict = torch.load(self.opts.load_checkpoint_path)['state_dict']
|
||||||
|
# self.QuestionNet.load_state_dict(state_dict)
|
||||||
|
total_params_quest = sum(p.numel() for p in self.QuestionNet.parameters() if p.requires_grad)
|
||||||
|
print("The question encoder has {} trainable parameters".format(total_params_quest))
|
||||||
|
|
||||||
|
if "minecraft" in self.opts.scenesPath:
|
||||||
|
self.symbolicExecutor = SymbolicExecutorMinecraft(self.opts.scenesPath)
|
||||||
|
else:
|
||||||
|
self.symbolicExecutor = SymbolicExecutorClevr(self.opts.scenesPath)
|
||||||
|
|
||||||
|
tb_path = os.path.join(self.opts.run_dir, "tb_logdir")
|
||||||
|
if not os.path.isdir(tb_path):
|
||||||
|
os.makedirs(tb_path)
|
||||||
|
|
||||||
|
self.ckpt_path = os.path.join(self.opts.run_dir, "ckpt_dir")
|
||||||
|
if not os.path.isdir(self.ckpt_path):
|
||||||
|
os.makedirs(self.ckpt_path)
|
||||||
|
if not os.path.isdir(self.opts.text_log_dir):
|
||||||
|
os.makedirs(self.opts.text_log_dir)
|
||||||
|
|
||||||
|
self.writer = SummaryWriter(tb_path)
|
||||||
|
self.iter_val = 0
|
||||||
|
|
||||||
|
if os.path.isfile(self.opts.dependenciesPath):
|
||||||
|
with open(self.opts.dependenciesPath, "rb") as f:
|
||||||
|
self.dependencies = pickle.load(f)
|
||||||
|
|
||||||
|
def train(self):
|
||||||
|
self.QuestionNet.train()
|
||||||
|
|
||||||
|
# Define the multi-gpu training if needed
|
||||||
|
if len(self.opts.gpu_ids) > 1:
|
||||||
|
self.QuestionNet = nn.DataParallel(self.QuestionNet, device_ids=self.opts.gpu_ids)
|
||||||
|
|
||||||
|
# Load checkpoint if resume training
|
||||||
|
if os.path.isfile(self.opts.load_checkpoint_path):
|
||||||
|
print("[INFO] Resume trainig from ckpt {} ...".format(
|
||||||
|
self.opts.load_checkpoint_path
|
||||||
|
))
|
||||||
|
|
||||||
|
# Load the network parameters
|
||||||
|
ckpt = torch.load(self.opts.load_checkpoint_path)
|
||||||
|
print("[INFO] Checkpoint successfully loaded ...")
|
||||||
|
self.QuestionNet.load_state_dict(ckpt['state_dict'])
|
||||||
|
|
||||||
|
# Load the optimizer paramters
|
||||||
|
optim = get_optim(self.opts, self.QuestionNet, len(self.datasetTr)) # , ckpt['optim'], lr_base=ckpt['lr_base'])
|
||||||
|
# optim._step = int(data_size / self.__C.BATCH_SIZE * self.__C.CKPT_EPOCH)
|
||||||
|
optim.optimizer.load_state_dict(ckpt['optimizer'])
|
||||||
|
_iter = 0 # ckpt['last_iter']
|
||||||
|
epoch = 0 # ckpt['last_epoch']
|
||||||
|
|
||||||
|
else:
|
||||||
|
optim = get_optim(self.opts, self.QuestionNet, len(self.datasetTr))
|
||||||
|
_iter = 0
|
||||||
|
epoch = 0
|
||||||
|
|
||||||
|
trainTime = 0
|
||||||
|
bestValAcc = float("-inf")
|
||||||
|
bestCkp = 0
|
||||||
|
# Training loop
|
||||||
|
while _iter < self.opts.num_iters:
|
||||||
|
|
||||||
|
# Learning Rate Decay
|
||||||
|
if _iter in self.opts.lr_decay_marks:
|
||||||
|
adjust_lr(optim, self.opts.lr_decay_factor)
|
||||||
|
|
||||||
|
# Define multi-thread dataloader
|
||||||
|
dataloader = Data.DataLoader(
|
||||||
|
self.datasetTr,
|
||||||
|
batch_size=self.opts.batch_size,
|
||||||
|
shuffle=self.opts.shuffle_data,
|
||||||
|
num_workers=self.opts.num_workers,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Iteration
|
||||||
|
time_start = 0
|
||||||
|
time_end = 0
|
||||||
|
for batch_iter, (quest, hist, prog, questionRound, _) in enumerate(dataloader):
|
||||||
|
time_start = time.time()
|
||||||
|
if _iter >= self.opts.num_iters:
|
||||||
|
break
|
||||||
|
quest = quest.to(self.device)
|
||||||
|
if self.opts.last_n_rounds < 10:
|
||||||
|
last_n_rounds_batch = []
|
||||||
|
for i, r in enumerate(questionRound.tolist()):
|
||||||
|
startIdx = max(r - self.opts.last_n_rounds, 0)
|
||||||
|
endIdx = max(r, self.opts.last_n_rounds)
|
||||||
|
if hist.dim() == 3:
|
||||||
|
assert endIdx - startIdx == self.opts.last_n_rounds
|
||||||
|
histBatch = hist[i, :, :]
|
||||||
|
last_n_rounds_batch.append(histBatch[startIdx:endIdx, :])
|
||||||
|
elif hist.dim() == 2:
|
||||||
|
startIdx *= 20
|
||||||
|
endIdx *= 20
|
||||||
|
histBatch = hist[i, :]
|
||||||
|
temp = histBatch[startIdx:endIdx].cpu()
|
||||||
|
if r > self.opts.last_n_rounds:
|
||||||
|
last_n_rounds_batch.append(torch.cat([torch.tensor([1]), temp, torch.tensor([2])], 0))
|
||||||
|
else:
|
||||||
|
last_n_rounds_batch.append(torch.cat([temp, torch.tensor([2, 0])], 0))
|
||||||
|
hist = torch.stack(last_n_rounds_batch, dim=0)
|
||||||
|
hist = hist.to(self.device)
|
||||||
|
prog = prog.to(self.device)
|
||||||
|
progTarget = prog.clone()
|
||||||
|
optim.zero_grad()
|
||||||
|
|
||||||
|
predSoftmax, _ = self.QuestionNet(quest, hist, prog[:, :-1])
|
||||||
|
loss = self.loss_fn(
|
||||||
|
# predSoftmax[:, :-1, :].contiguous().view(-1, predSoftmax.size(2)),
|
||||||
|
predSoftmax.contiguous().view(-1, predSoftmax.size(2)),
|
||||||
|
progTarget[:, 1:].contiguous().view(-1))
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
if _iter % self.opts.validate_every == 0 and _iter > 0:
|
||||||
|
valAcc = self.val()
|
||||||
|
if valAcc > bestValAcc:
|
||||||
|
bestValAcc = valAcc
|
||||||
|
bestCkp = _iter
|
||||||
|
print("\n[INFO] Checkpointing model @ iter {} with val accuracy {}\n".format(_iter, valAcc))
|
||||||
|
state = {
|
||||||
|
'state_dict': self.QuestionNet.state_dict(),
|
||||||
|
'optimizer': optim.optimizer.state_dict(),
|
||||||
|
'lr_base': optim.lr_base,
|
||||||
|
'optim': optim.lr_base,
|
||||||
|
'last_iter': _iter,
|
||||||
|
'last_epoch': epoch,
|
||||||
|
}
|
||||||
|
# checkpointing
|
||||||
|
torch.save(
|
||||||
|
state,
|
||||||
|
os.path.join(self.ckpt_path, 'ckpt_iter' + str(_iter) + '.pkl')
|
||||||
|
)
|
||||||
|
|
||||||
|
# logging
|
||||||
|
self.writer.add_scalar(
|
||||||
|
'train/loss',
|
||||||
|
loss.cpu().data.numpy(),
|
||||||
|
global_step=_iter)
|
||||||
|
|
||||||
|
self.writer.add_scalar(
|
||||||
|
'train/lr',
|
||||||
|
optim._rate,
|
||||||
|
global_step=_iter)
|
||||||
|
if _iter % self.opts.display_every == 0:
|
||||||
|
time_end = time.time()
|
||||||
|
trainTime += time_end-time_start
|
||||||
|
|
||||||
|
print("\r[CLEVR-Dialog - %s (%d | %d)][epoch %2d][iter %4d/%4d][runtime %4f] loss: %.4f, lr: %.2e" % (
|
||||||
|
self.datasetTr.name,
|
||||||
|
batch_iter,
|
||||||
|
len(dataloader),
|
||||||
|
epoch,
|
||||||
|
_iter,
|
||||||
|
self.opts.num_iters,
|
||||||
|
trainTime,
|
||||||
|
loss.cpu().data.numpy(),
|
||||||
|
optim._rate,
|
||||||
|
), end=' ')
|
||||||
|
|
||||||
|
optim.step()
|
||||||
|
_iter += 1
|
||||||
|
|
||||||
|
epoch += 1
|
||||||
|
print("[INFO] Avg. epoch time: {} s".format(trainTime / epoch))
|
||||||
|
print("[INFO] Best model achieved val acc. {} @ iter {}".format(bestValAcc, bestCkp))
|
||||||
|
|
||||||
|
def val(self):
|
||||||
|
self.QuestionNet.eval()
|
||||||
|
|
||||||
|
total_correct = 0
|
||||||
|
total = 0
|
||||||
|
|
||||||
|
if len(self.opts.gpu_ids) > 1:
|
||||||
|
self.QuestionNet = nn.DataParallel(self.QuestionNet, device_ids=self.opts.gpu_ids)
|
||||||
|
self.QuestionNet = self.QuestionNet.eval()
|
||||||
|
dataloader = Data.DataLoader(
|
||||||
|
self.datasetVal,
|
||||||
|
batch_size=self.opts.batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
num_workers=self.opts.num_workers,
|
||||||
|
pin_memory=False
|
||||||
|
)
|
||||||
|
_iterCur = 0
|
||||||
|
_totalCur = len(dataloader)
|
||||||
|
|
||||||
|
for step, (question, questionPrg, questionImgIdx, questionRounds, history, historiesProg, answer) in enumerate(dataloader):
|
||||||
|
# print("\rEvaluation: [step %4d/%4d]" % (
|
||||||
|
print("\rEvaluation: [step %4d/%4d]" % (
|
||||||
|
step,
|
||||||
|
int(len(dataloader)),
|
||||||
|
), end=' ')
|
||||||
|
|
||||||
|
question = question.to(self.device)
|
||||||
|
|
||||||
|
if history.dim() == 3:
|
||||||
|
caption = history.detach()
|
||||||
|
caption = caption[:, 0, :]
|
||||||
|
caption = caption[:, :16].to(self.device)
|
||||||
|
elif history.dim() == 2:
|
||||||
|
caption = history.detach()
|
||||||
|
caption = caption[:, :16].to(self.device)
|
||||||
|
if self.opts.last_n_rounds is not None:
|
||||||
|
last_n_rounds_batch = []
|
||||||
|
for i, r in enumerate(questionRounds.tolist()):
|
||||||
|
startIdx = max(r - self.opts.last_n_rounds, 0)
|
||||||
|
endIdx = max(r, self.opts.last_n_rounds)
|
||||||
|
if history.dim() == 3:
|
||||||
|
assert endIdx - startIdx == self.opts.last_n_rounds
|
||||||
|
histBatch = history[i, :, :]
|
||||||
|
last_n_rounds_batch.append(histBatch[startIdx:endIdx, :])
|
||||||
|
elif history.dim() == 2:
|
||||||
|
startIdx *= 20
|
||||||
|
endIdx *= 20
|
||||||
|
histBatch = history[i, :]
|
||||||
|
temp = histBatch[startIdx:endIdx]
|
||||||
|
if r > self.opts.last_n_rounds:
|
||||||
|
last_n_rounds_batch.append(torch.cat([torch.tensor([1]), temp, torch.tensor([2])], 0))
|
||||||
|
else:
|
||||||
|
last_n_rounds_batch.append(torch.cat([temp, torch.tensor([2, 0])], 0))
|
||||||
|
history = torch.stack(last_n_rounds_batch, dim=0)
|
||||||
|
history = history.to(self.device)
|
||||||
|
questionPrg = questionPrg.to(self.device)
|
||||||
|
|
||||||
|
questProgsToksPred = self.QuestionNet.sample(question, history)
|
||||||
|
questProgsPred = decodeProg(questProgsToksPred, self.datasetVal.vocab["idx_prog_to_token"])
|
||||||
|
targetProgs = decodeProg(questionPrg, self.datasetVal.vocab["idx_prog_to_token"], target=True)
|
||||||
|
|
||||||
|
correct = [1 if pred == gt else 0 for (pred, gt) in zip(questProgsPred, targetProgs)]
|
||||||
|
|
||||||
|
correct = sum(correct)
|
||||||
|
total_correct += correct
|
||||||
|
total += len(targetProgs)
|
||||||
|
self.QuestionNet.train()
|
||||||
|
|
||||||
|
return 100.0 * (total_correct / total)
|
||||||
|
|
||||||
|
# Evaluation
|
||||||
|
def eval_with_gt(self):
|
||||||
|
# Define the multi-gpu training if needed
|
||||||
|
all_pred_answers = []
|
||||||
|
all_gt_answers = []
|
||||||
|
all_question_types = []
|
||||||
|
all_penalties = []
|
||||||
|
all_pred_programs = []
|
||||||
|
all_gt_programs = []
|
||||||
|
|
||||||
|
first_failure_round = 0
|
||||||
|
total_correct = 0
|
||||||
|
total_acc_pen = 0
|
||||||
|
total = 0
|
||||||
|
total_quest_prog_correct = 0
|
||||||
|
|
||||||
|
if len(self.opts.gpu_ids) > 1:
|
||||||
|
self.QuestionNet = nn.DataParallel(self.QuestionNet, device_ids=self.opts.gpu_ids)
|
||||||
|
self.QuestionNet = self.QuestionNet.eval()
|
||||||
|
self.CaptionNet = self.CaptionNet.eval()
|
||||||
|
if self.opts.batch_size != self.opts.dialogLen:
|
||||||
|
print("[INFO] Changed batch size from {} to {}".format(self.opts.batch_size, self.opts.dialogLen))
|
||||||
|
self.opts.batch_size = self.opts.dialogLen
|
||||||
|
dataloader = Data.DataLoader(
|
||||||
|
self.datasetTest,
|
||||||
|
batch_size=self.opts.batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=self.opts.num_workers,
|
||||||
|
pin_memory=False
|
||||||
|
)
|
||||||
|
_iterCur = 0
|
||||||
|
_totalCur = len(dataloader)
|
||||||
|
|
||||||
|
for step, (question, questionPrg, questionImgIdx, questionRounds, history, historiesProg, answer) in enumerate(dataloader):
|
||||||
|
# print("\rEvaluation: [step %4d/%4d]" % (
|
||||||
|
# step + 1,
|
||||||
|
# int(data_size / self.opts.batch_size),
|
||||||
|
# ), end=' ')
|
||||||
|
# if step >= 5000:
|
||||||
|
# break
|
||||||
|
batchSize = question.size(0)
|
||||||
|
question = question.to(self.device)
|
||||||
|
# dependecy = self.dependencies[step*batchSize:(step+1)*batchSize]
|
||||||
|
|
||||||
|
if history.dim() == 3:
|
||||||
|
caption = history.detach()
|
||||||
|
caption = caption[:, 0, :]
|
||||||
|
caption = caption[:, :16].to(self.device)
|
||||||
|
elif history.dim() == 2:
|
||||||
|
caption = history.detach()
|
||||||
|
caption = caption[:, :16].to(self.device)
|
||||||
|
if self.opts.last_n_rounds < 10:
|
||||||
|
last_n_rounds_batch = []
|
||||||
|
for i, r in enumerate(questionRounds.tolist()):
|
||||||
|
startIdx = max(r - self.opts.last_n_rounds, 0)
|
||||||
|
endIdx = max(r, self.opts.last_n_rounds)
|
||||||
|
if history.dim() == 3:
|
||||||
|
assert endIdx - startIdx == self.opts.last_n_rounds
|
||||||
|
histBatch = history[i, :, :]
|
||||||
|
last_n_rounds_batch.append(histBatch[startIdx:endIdx, :])
|
||||||
|
elif history.dim() == 2:
|
||||||
|
startIdx *= 20
|
||||||
|
endIdx *= 20
|
||||||
|
histBatch = history[i, :]
|
||||||
|
temp = histBatch[startIdx:endIdx]
|
||||||
|
if r > self.opts.last_n_rounds:
|
||||||
|
last_n_rounds_batch.append(torch.cat([torch.tensor([1]), temp, torch.tensor([2])], 0))
|
||||||
|
else:
|
||||||
|
last_n_rounds_batch.append(torch.cat([temp, torch.tensor([2, 0])], 0))
|
||||||
|
history = torch.stack(last_n_rounds_batch, dim=0)
|
||||||
|
|
||||||
|
history = history.to(self.device)
|
||||||
|
questionPrg = questionPrg.to(self.device)
|
||||||
|
historiesProg = historiesProg.tolist()
|
||||||
|
questionRounds = questionRounds.tolist()
|
||||||
|
answer = answer.tolist()
|
||||||
|
answers = list(map(lambda a: self.datasetTest.vocab["idx_text_to_token"][a], answer))
|
||||||
|
questionImgIdx = questionImgIdx.tolist()
|
||||||
|
# if "minecraft" in self.opts.scenesPath:
|
||||||
|
# questionImgIdx = [idx - 1 for idx in questionImgIdx]
|
||||||
|
questProgsToksPred = self.QuestionNet.sample(question, history)
|
||||||
|
capProgsToksPred = self.CaptionNet.sample(caption)
|
||||||
|
|
||||||
|
questProgsPred = decodeProg(questProgsToksPred, self.datasetTest.vocab["idx_prog_to_token"])
|
||||||
|
capProgsPred = decodeProg(capProgsToksPred, self.datasetTest.vocab["idx_prog_to_token"])
|
||||||
|
|
||||||
|
targetProgs = decodeProg(questionPrg, self.datasetTest.vocab["idx_prog_to_token"], target=True)
|
||||||
|
questionTypes = [targetProg[0] for targetProg in targetProgs]
|
||||||
|
# progHistories = getProgHistories(historiesProg[0], dataset.vocab["idx_prog_to_token"])
|
||||||
|
progHistories = [getProgHistories(progHistToks, self.datasetTest.vocab["idx_prog_to_token"]) for progHistToks in historiesProg]
|
||||||
|
pred_answers = []
|
||||||
|
all_pred_programs.append([capProgsPred[0]] + questProgsPred)
|
||||||
|
all_gt_programs.append([progHistories[0]] + (targetProgs))
|
||||||
|
|
||||||
|
for i in range(batchSize):
|
||||||
|
# if capProgsPred[i][0] == "extreme-center":
|
||||||
|
# print("bla")
|
||||||
|
# print("idx = {}".format(questionImgIdx[i]))
|
||||||
|
ans = self.getPrediction(
|
||||||
|
questProgsPred[i],
|
||||||
|
capProgsPred[i],
|
||||||
|
progHistories[i],
|
||||||
|
questionImgIdx[i]
|
||||||
|
)
|
||||||
|
# if ans == "Error":
|
||||||
|
# print(capProgsPred[i])
|
||||||
|
pred_answers.append(ans)
|
||||||
|
# print(pred_answers)
|
||||||
|
correct = [1 if pred == ans else 0 for (pred, ans) in zip(pred_answers, answers)]
|
||||||
|
correct_prog = [1 if pred == ans else 0 for (pred, ans) in zip(questProgsPred, targetProgs)]
|
||||||
|
idx_false = np.argwhere(np.array(correct) == 0).squeeze(-1)
|
||||||
|
if idx_false.shape[-1] > 0:
|
||||||
|
first_failure_round += idx_false[0] + 1
|
||||||
|
else:
|
||||||
|
first_failure_round += self.opts.dialogLen + 1
|
||||||
|
|
||||||
|
correct = sum(correct)
|
||||||
|
correct_prog = sum(correct_prog)
|
||||||
|
total_correct += correct
|
||||||
|
total_quest_prog_correct += correct_prog
|
||||||
|
total += len(answers)
|
||||||
|
all_pred_answers.append(pred_answers)
|
||||||
|
all_gt_answers.append(answers)
|
||||||
|
all_question_types.append(questionTypes)
|
||||||
|
penalty = np.zeros_like(penalty)
|
||||||
|
all_penalties.append(penalty)
|
||||||
|
_iterCur += 1
|
||||||
|
if _iterCur % self.opts.display_every == 0:
|
||||||
|
print("[Evaluation] step {0} / {1} | acc. = {2:.2f}".format(
|
||||||
|
_iterCur, _totalCur, 100.0 * (total_correct / total)))
|
||||||
|
|
||||||
|
ffr = 1.0 * (first_failure_round/_totalCur)/(self.opts.dialogLen + 1)
|
||||||
|
|
||||||
|
textOut = "\n --------------- Average First Failure Round --------------- \n"
|
||||||
|
textOut += "{} / {}".format(ffr, self.opts.dialogLen)
|
||||||
|
|
||||||
|
# print(total_correct, total)
|
||||||
|
accuracy = total_correct / total
|
||||||
|
vd_acc = total_acc_pen / total
|
||||||
|
quest_prog_acc = total_quest_prog_correct / total
|
||||||
|
textOut += "\n --------------- Overall acc. --------------- \n"
|
||||||
|
textOut += "{}".format(100.0 * accuracy)
|
||||||
|
textOut += "\n --------------- Overall VD acc. --------------- \n"
|
||||||
|
textOut += "{}".format(100.0 * vd_acc)
|
||||||
|
textOut += "\n --------------- Question Prog. Acc --------------- \n"
|
||||||
|
textOut += "{}".format(100.0 * quest_prog_acc)
|
||||||
|
textOut += get_per_round_acc(
|
||||||
|
all_pred_answers, all_gt_answers, all_penalties)
|
||||||
|
|
||||||
|
textOut += get_per_question_type_acc(
|
||||||
|
all_pred_answers, all_gt_answers, all_question_types, all_penalties)
|
||||||
|
|
||||||
|
# textOut += get_per_dependency_type_acc(
|
||||||
|
# all_pred_answers, all_gt_answers, all_penalties)
|
||||||
|
|
||||||
|
textOut += "\n --------------- Done --------------- \n"
|
||||||
|
print(textOut)
|
||||||
|
fname = self.opts.questionNetPath.split("/")[-3] + "results_{}_{}.txt".format(self.opts.last_n_rounds, self.opts.dialogLen)
|
||||||
|
pred_answers_fname = self.opts.questionNetPath.split("/")[-3] + "_pred_answers_{}_{}.pkl".format(self.opts.last_n_rounds, self.opts.dialogLen)
|
||||||
|
pred_answers_fname = os.path.join("/projects/abdessaied/clevr-dialog/output/pred_answers", pred_answers_fname)
|
||||||
|
model_name = "NSVD_stack" if "stack" in self.opts.questionNetPath else "NSVD_concat"
|
||||||
|
experiment_name = "minecraft"
|
||||||
|
# experiment_name += "_{}".format(self.opts.dialogLen)
|
||||||
|
prog_output_fname = os.path.join("/projects/abdessaied/clevr-dialog/output/prog_output/{}_{}.pkl".format(model_name, experiment_name))
|
||||||
|
|
||||||
|
fpath = os.path.join(self.opts.text_log_dir, fname)
|
||||||
|
with open(fpath, "w") as f:
|
||||||
|
f.writelines(textOut)
|
||||||
|
with open(pred_answers_fname, "wb") as f:
|
||||||
|
pickle.dump(all_pred_answers, f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||||
|
with open(prog_output_fname, "wb") as f:
|
||||||
|
pickle.dump((all_gt_programs, all_pred_programs, all_pred_answers), f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||||
|
|
||||||
|
# Evaluation
|
||||||
|
def eval_with_pred(self):
|
||||||
|
# Define the multi-gpu training if needed
|
||||||
|
all_pred_answers = []
|
||||||
|
all_gt_answers = []
|
||||||
|
all_question_types = []
|
||||||
|
all_penalties = []
|
||||||
|
|
||||||
|
first_failure_round = 0
|
||||||
|
total_correct = 0
|
||||||
|
total_acc_pen = 0
|
||||||
|
total = 0
|
||||||
|
|
||||||
|
samples = {}
|
||||||
|
|
||||||
|
if len(self.opts.gpu_ids) > 1:
|
||||||
|
self.QuestionNet = nn.DataParallel(self.QuestionNet, device_ids=self.opts.gpu_ids)
|
||||||
|
self.QuestionNet = self.QuestionNet.eval()
|
||||||
|
self.CaptionNet = self.CaptionNet.eval()
|
||||||
|
if self.opts.batch_size != self.opts.dialogLen:
|
||||||
|
print("[INFO] Changed batch size from {} to {}".format(self.opts.batch_size, self.opts.dialogLen))
|
||||||
|
self.opts.batch_size = self.opts.dialogLen
|
||||||
|
dataloader = Data.DataLoader(
|
||||||
|
self.datasetTest,
|
||||||
|
batch_size=self.opts.batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=self.opts.num_workers,
|
||||||
|
pin_memory=False
|
||||||
|
)
|
||||||
|
_iterCur = 0
|
||||||
|
_totalCur = len(dataloader)
|
||||||
|
step = 0
|
||||||
|
for step, (question, questionPrg, questionImgIdx, questionRounds, history, historiesProg, answer) in enumerate(dataloader):
|
||||||
|
question = question.tolist()
|
||||||
|
questions = decode(question, self.datasetTest.vocab["idx_text_to_token"], target=True)
|
||||||
|
questions = list(map(lambda q: " ".join(q), questions))
|
||||||
|
targetProgs = decode(questionPrg, self.datasetTest.vocab["idx_prog_to_token"], target=True)
|
||||||
|
|
||||||
|
questionTypes = [targetProg[0] for targetProg in targetProgs]
|
||||||
|
targetProgs = list(map(lambda q: " ".join(q), targetProgs))
|
||||||
|
|
||||||
|
historiesProg = historiesProg.tolist()
|
||||||
|
progHistories = [getProgHistories(progHistToks, self.datasetTest.vocab["idx_prog_to_token"]) for progHistToks in historiesProg]
|
||||||
|
|
||||||
|
answer = answer.tolist()
|
||||||
|
answers = list(map(lambda a: self.datasetTest.vocab["idx_text_to_token"][a], answer))
|
||||||
|
questionImgIdx = questionImgIdx.tolist()
|
||||||
|
|
||||||
|
if self.opts.encoderType == 2:
|
||||||
|
histories_eval = [history[0, 0, :].tolist()]
|
||||||
|
caption = history.detach()
|
||||||
|
caption = caption[0, 0, :].unsqueeze(0)
|
||||||
|
caption = caption[:, :16].to(self.device)
|
||||||
|
elif self.opts.encoderType == 1:
|
||||||
|
caption = history.detach()
|
||||||
|
histories_eval = [history[0, :20].tolist()]
|
||||||
|
caption = caption[0, :16].unsqueeze(0).to(self.device)
|
||||||
|
cap = decode(caption, self.datasetTest.vocab["idx_text_to_token"], target=False)
|
||||||
|
capProgToksPred = self.CaptionNet.sample(caption)
|
||||||
|
capProgPred = decode(capProgToksPred, self.datasetTest.vocab["idx_prog_to_token"])[0]
|
||||||
|
|
||||||
|
pred_answers = []
|
||||||
|
pred_quest_prog = []
|
||||||
|
for i, (q, prog_hist, img_idx) in enumerate(zip(question, progHistories, questionImgIdx)):
|
||||||
|
_round = i + 1
|
||||||
|
if _round <= self.opts.last_n_rounds:
|
||||||
|
start = 0
|
||||||
|
else:
|
||||||
|
start = _round - self.opts.last_n_rounds
|
||||||
|
end = len(histories_eval)
|
||||||
|
|
||||||
|
quest = torch.tensor(q).unsqueeze(0).to(self.device)
|
||||||
|
if self.opts.encoderType == 3:
|
||||||
|
hist = torch.stack([torch.tensor(h) for h in histories_eval[start:end]], dim=0).unsqueeze(0).to(self.device)
|
||||||
|
elif self.opts.encoderType == 1:
|
||||||
|
histories_eval_copy = deepcopy(histories_eval)
|
||||||
|
histories_eval_copy[-1].append(self.datasetTest.vocab["text_token_to_idx"]["<END>"])
|
||||||
|
hist = torch.cat([torch.tensor(h) for h in histories_eval_copy[start:end]], dim=-1).unsqueeze(0).to(self.device)
|
||||||
|
|
||||||
|
questProgsToksPred = self.QuestionNet.sample(quest, hist)
|
||||||
|
questProgsPred = decode(questProgsToksPred, self.datasetTest.vocab["idx_prog_to_token"])[0]
|
||||||
|
pred_quest_prog.append(" ".join(questProgsPred))
|
||||||
|
ans = self.getPrediction(
|
||||||
|
questProgsPred,
|
||||||
|
capProgPred,
|
||||||
|
prog_hist,
|
||||||
|
img_idx
|
||||||
|
)
|
||||||
|
ans_idx = self.datasetTest.vocab["text_token_to_idx"].get(
|
||||||
|
ans, self.datasetTest.vocab["text_token_to_idx"]["<UNK>"])
|
||||||
|
q[q.index(self.datasetTest.vocab["text_token_to_idx"]["<END>"])] = self.datasetTest.vocab["text_token_to_idx"]["<NULL>"]
|
||||||
|
q[-1] = self.datasetTest.vocab["text_token_to_idx"]["<END>"]
|
||||||
|
q.insert(-1, ans_idx)
|
||||||
|
if self.opts.encoderType == 3:
|
||||||
|
histories_eval.append(copy.deepcopy(q))
|
||||||
|
elif self.opts.encoderType == 0:
|
||||||
|
del q[0]
|
||||||
|
del q[-1]
|
||||||
|
histories_eval.append(copy.deepcopy(q))
|
||||||
|
|
||||||
|
pred_answers.append(ans)
|
||||||
|
|
||||||
|
correct = [1 if pred == ans else 0 for (pred, ans) in zip(pred_answers, answers)]
|
||||||
|
idx_false = np.argwhere(np.array(correct) == 0).squeeze(-1)
|
||||||
|
if idx_false.shape[-1] > 0:
|
||||||
|
first_failure_round += idx_false[0] + 1
|
||||||
|
else:
|
||||||
|
first_failure_round += self.opts.dialogLen + 1
|
||||||
|
|
||||||
|
correct = sum(correct)
|
||||||
|
total_correct += correct
|
||||||
|
total += len(answers)
|
||||||
|
all_pred_answers.append(pred_answers)
|
||||||
|
all_gt_answers.append(answers)
|
||||||
|
all_question_types.append(questionTypes)
|
||||||
|
_iterCur += 1
|
||||||
|
if _iterCur % self.opts.display_every == 0:
|
||||||
|
print("[Evaluation] step {0} / {1} | acc. = {2:.2f}".format(
|
||||||
|
_iterCur, _totalCur, 100.0 * (total_correct / total)
|
||||||
|
))
|
||||||
|
samples["{}_{}".format(questionImgIdx[0], (step % 5) + 1)] = {
|
||||||
|
"caption": " ".join(cap[0]),
|
||||||
|
"cap_prog_gt": " ".join(progHistories[0][0]),
|
||||||
|
"cap_prog_pred": " ".join(capProgPred),
|
||||||
|
|
||||||
|
"questions": questions,
|
||||||
|
"quest_progs_gt": targetProgs,
|
||||||
|
"quest_progs_pred": pred_quest_prog,
|
||||||
|
|
||||||
|
|
||||||
|
"answers": answers,
|
||||||
|
"preds": pred_answers,
|
||||||
|
"acc": correct,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
ffr = 1.0 * self.opts.dialogLen * (first_failure_round/total)
|
||||||
|
|
||||||
|
textOut = "\n --------------- Average First Failure Round --------------- \n"
|
||||||
|
textOut += "{} / {}".format(ffr, self.opts.dialogLen)
|
||||||
|
|
||||||
|
# print(total_correct, total)
|
||||||
|
accuracy = total_correct / total
|
||||||
|
vd_acc = total_acc_pen / total
|
||||||
|
textOut += "\n --------------- Overall acc. --------------- \n"
|
||||||
|
textOut += "{}".format(100.0 * accuracy)
|
||||||
|
textOut += "\n --------------- Overall VD acc. --------------- \n"
|
||||||
|
textOut += "{}".format(100.0 * vd_acc)
|
||||||
|
|
||||||
|
textOut += get_per_round_acc(
|
||||||
|
all_pred_answers, all_gt_answers, all_penalties)
|
||||||
|
|
||||||
|
textOut += get_per_question_type_acc(
|
||||||
|
all_pred_answers, all_gt_answers, all_question_types, all_penalties)
|
||||||
|
|
||||||
|
textOut += "\n --------------- Done --------------- \n"
|
||||||
|
print(textOut)
|
||||||
|
if step >= len(dataloader):
|
||||||
|
fname = self.opts.questionNetPath.split("/")[-3] + "_results_{}_{}_{}.txt".format(self.opts.last_n_rounds, self.opts.dialogLen, self.acc_type)
|
||||||
|
pred_answers_fname = self.opts.questionNetPath.split("/")[-3] + "_pred_answers_{}_{}.pkl".format(self.opts.last_n_rounds, self.opts.dialogLen)
|
||||||
|
pred_answers_fname = os.path.join("/projects/abdessaied/clevr-dialog/output/pred_answers", pred_answers_fname)
|
||||||
|
|
||||||
|
fpath = os.path.join(self.opts.text_log_dir, fname)
|
||||||
|
with open(fpath, "w") as f:
|
||||||
|
f.writelines(textOut)
|
||||||
|
with open(pred_answers_fname, "wb") as f:
|
||||||
|
pickle.dump(all_pred_answers, f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||||
|
|
||||||
|
def getPrediction(self, questProgPred, capProgPred, historyProg, imgIndex):
|
||||||
|
self.symbolicExecutor.reset(imgIndex)
|
||||||
|
# if round one, execute the predicted caption program first then answer the question
|
||||||
|
if len(historyProg) == 1:
|
||||||
|
captionFuncLabel = capProgPred[0]
|
||||||
|
captionFuncArgs = capProgPred[1:]
|
||||||
|
|
||||||
|
questionFuncLabel = questProgPred[0]
|
||||||
|
questionFuncArgs = questProgPred[1:]
|
||||||
|
|
||||||
|
try:
|
||||||
|
_ = self.symbolicExecutor.execute(captionFuncLabel, captionFuncArgs)
|
||||||
|
except:
|
||||||
|
return "Error"
|
||||||
|
|
||||||
|
try:
|
||||||
|
predAnswer = self.symbolicExecutor.execute(questionFuncLabel, questionFuncArgs)
|
||||||
|
except:
|
||||||
|
return "Error"
|
||||||
|
|
||||||
|
# If it is not the first round, we have to execute the program history and
|
||||||
|
# then answer the question.
|
||||||
|
else:
|
||||||
|
questionFuncLabel = questProgPred[0]
|
||||||
|
questionFuncArgs = questProgPred[1:]
|
||||||
|
for prg in historyProg:
|
||||||
|
# prg = prg.split(" ")
|
||||||
|
FuncLabel = prg[0]
|
||||||
|
FuncArgs = prg[1:]
|
||||||
|
try:
|
||||||
|
_ = self.symbolicExecutor.execute(FuncLabel, FuncArgs)
|
||||||
|
except:
|
||||||
|
return "Error"
|
||||||
|
|
||||||
|
try:
|
||||||
|
predAnswer = self.symbolicExecutor.execute(questionFuncLabel, questionFuncArgs)
|
||||||
|
except:
|
||||||
|
return "Error"
|
||||||
|
return str(predAnswer)
|
||||||
|
|
||||||
|
def run(self, run_mode, epoch=None):
|
||||||
|
self.set_seed(self.opts.seed)
|
||||||
|
if run_mode == 'train':
|
||||||
|
self.train()
|
||||||
|
|
||||||
|
elif run_mode == 'test_with_gt':
|
||||||
|
print('Testing with gt answers in history')
|
||||||
|
print('Loading ckpt {}'.format(self.opts.questionNetPath))
|
||||||
|
state_dict = torch.load(self.opts.questionNetPath)['state_dict']
|
||||||
|
self.QuestionNet.load_state_dict(state_dict)
|
||||||
|
self.eval_with_gt()
|
||||||
|
|
||||||
|
elif run_mode == 'test_with_pred':
|
||||||
|
print('Testing with predicted answers in history')
|
||||||
|
print('Loading ckpt {}'.format(self.opts.questionNetPath))
|
||||||
|
state_dict = torch.load(self.opts.questionNetPath)['state_dict']
|
||||||
|
self.QuestionNet.load_state_dict(state_dict)
|
||||||
|
self.eval_with_pred()
|
||||||
|
else:
|
||||||
|
exit(-1)
|
||||||
|
|
||||||
|
def set_seed(self, seed):
|
||||||
|
"""Sets the seed for reproducibility.
|
||||||
|
Args:
|
||||||
|
seed (int): The seed used
|
||||||
|
"""
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
torch.backends.cudnn.deterministic = True
|
||||||
|
torch.backends.cudnn.benchmark = False
|
||||||
|
np.random.seed(seed)
|
||||||
|
print('[INFO] Seed set to {}...'.format(seed))
|
||||||
|
|
||||||
|
|
||||||
|
def constructQuestionNet(opts, lenVocabText, lenVocabProg, maxLenProg):
|
||||||
|
decoder = Decoder(opts, lenVocabProg, maxLenProg)
|
||||||
|
if opts.encoderType == 1:
|
||||||
|
encoder = QuestEncoder_1(opts, lenVocabText)
|
||||||
|
elif opts.encoderType == 2:
|
||||||
|
encoder = QuestEncoder_2(opts, lenVocabText)
|
||||||
|
|
||||||
|
net = SeqToSeqQ(encoder, decoder)
|
||||||
|
return net
|
||||||
|
|
||||||
|
|
||||||
|
def constructCaptionNet(opts, lenVocabText, lenVocabProg, maxLenProg):
|
||||||
|
decoder = Decoder(opts, lenVocabProg, maxLenProg)
|
||||||
|
encoder = CaptionEncoder(opts, lenVocabText)
|
||||||
|
net = SeqToSeqC(encoder, decoder)
|
||||||
|
return net
|
||||||
|
|
||||||
|
|
||||||
|
def getProgHistories(progHistToks, prgIdxToToken):
|
||||||
|
progHist = []
|
||||||
|
temp = []
|
||||||
|
for tok in progHistToks:
|
||||||
|
if tok not in [0, 1, 2]:
|
||||||
|
temp.append(prgIdxToToken[tok])
|
||||||
|
# del progHistToks[i]
|
||||||
|
if tok == 2:
|
||||||
|
# del progHistToks[i]
|
||||||
|
# progHist.append(" ".join(temp))
|
||||||
|
progHist.append(temp)
|
||||||
|
temp = []
|
||||||
|
return progHist
|
||||||
|
|
||||||
|
|
||||||
|
def getHistoriesFromStack(histToks, textIdxToToken):
|
||||||
|
histories = "\n"
|
||||||
|
temp = []
|
||||||
|
for i, roundToks in enumerate(histToks):
|
||||||
|
for tok in roundToks:
|
||||||
|
if tok not in [0, 1, 2]:
|
||||||
|
temp.append(textIdxToToken[tok])
|
||||||
|
# del progHistToks[i]
|
||||||
|
if tok == 2:
|
||||||
|
# del progHistToks[i]
|
||||||
|
if i == 0:
|
||||||
|
histories += " ".join(temp) + ".\n"
|
||||||
|
else:
|
||||||
|
histories += " ".join(temp[:-1]) + "? | {}.\n".format(temp[-1])
|
||||||
|
# histories.append(temp)
|
||||||
|
temp = []
|
||||||
|
break
|
||||||
|
return histories
|
||||||
|
|
||||||
|
|
||||||
|
def getHistoriesFromConcat(histToks, textIdxToToken):
|
||||||
|
histories = []
|
||||||
|
temp = []
|
||||||
|
for tok in histToks:
|
||||||
|
if tok not in [0, 1, 2]:
|
||||||
|
temp.append(textIdxToToken[tok])
|
||||||
|
# del progHistToks[i]
|
||||||
|
if tok == 2:
|
||||||
|
# del progHistToks[i]
|
||||||
|
histories.append(" ".join(temp[:-1]) + "? | {}".format(temp[-1]))
|
||||||
|
# histories.append(temp)
|
||||||
|
temp = []
|
||||||
|
return histories
|
||||||
|
|
||||||
|
|
||||||
|
def decodeProg(tokens, prgIdxToToken, target=False):
|
||||||
|
tokensBatch = tokens.tolist()
|
||||||
|
progsBatch = []
|
||||||
|
for tokens in tokensBatch:
|
||||||
|
prog = []
|
||||||
|
for tok in tokens:
|
||||||
|
if tok == 2: # <END> has index 2
|
||||||
|
break
|
||||||
|
prog.append(prgIdxToToken.get(tok))
|
||||||
|
if target:
|
||||||
|
prog = prog[1:]
|
||||||
|
# progsBatch.append(" ".join(prog))
|
||||||
|
progsBatch.append(prog)
|
||||||
|
return progsBatch
|
||||||
|
|
||||||
|
|
||||||
|
def printPred(predSoftmax, gts, prgIdxToToken):
|
||||||
|
assert predSoftmax.size(0) == gts.size(0)
|
||||||
|
tokens = predSoftmax.topk(1)[1].squeeze(-1)
|
||||||
|
tokens = tokens.tolist()
|
||||||
|
gts = gts.tolist()
|
||||||
|
message = "\n ------------------------ \n"
|
||||||
|
for token, gt in zip(tokens, gts):
|
||||||
|
message += "Prediction: "
|
||||||
|
for tok in token:
|
||||||
|
message += prgIdxToToken.get(tok) + " "
|
||||||
|
message += "\n Target : "
|
||||||
|
for tok in gt:
|
||||||
|
message += prgIdxToToken.get(tok) + " "
|
||||||
|
message += "\n ------------------------ \n"
|
||||||
|
return message
|
||||||
|
|
||||||
|
|
||||||
|
def get_per_round_acc(preds, gts, penalties):
|
||||||
|
res = {}
|
||||||
|
for img_preds, img_gt, img_pen in zip(preds, gts, penalties):
|
||||||
|
img_preds = list(img_preds)
|
||||||
|
img_gt = list(img_gt)
|
||||||
|
img_pen = list(img_pen)
|
||||||
|
for i, (pred, gt, pen) in enumerate(zip(img_preds, img_gt, img_pen)):
|
||||||
|
_round = str(i + 1)
|
||||||
|
if _round not in res:
|
||||||
|
res[_round] = {
|
||||||
|
"correct": 0,
|
||||||
|
"all": 0
|
||||||
|
}
|
||||||
|
res[_round]["all"] += 1
|
||||||
|
if pred == gt:
|
||||||
|
res[_round]["correct"] += 0.5**pen
|
||||||
|
|
||||||
|
textOut = "\n --------------- Per round Acc --------------- \n"
|
||||||
|
for k in res:
|
||||||
|
textOut += "{}: {} %\n".format(k, 100.0 * (res[k]["correct"]/res[k]["all"]))
|
||||||
|
return textOut
|
||||||
|
|
||||||
|
|
||||||
|
def get_per_question_type_acc(preds, gts, qtypes, penalties):
|
||||||
|
res1 = {}
|
||||||
|
res2 = {}
|
||||||
|
|
||||||
|
for img_preds, img_gt, img_qtypes, img_pen in zip(preds, gts, qtypes, penalties):
|
||||||
|
# img_preds = list(img_preds)
|
||||||
|
# img_gt = list(img_gt)
|
||||||
|
img_pen = list(img_pen)
|
||||||
|
for pred, gt, temp, pen in zip(img_preds, img_gt, img_qtypes, img_pen):
|
||||||
|
if temp not in res1:
|
||||||
|
res1[temp] = {
|
||||||
|
"correct": 0,
|
||||||
|
"all": 0
|
||||||
|
}
|
||||||
|
temp_cat = temp.split("-")[0]
|
||||||
|
if temp_cat not in res2:
|
||||||
|
res2[temp_cat] = {
|
||||||
|
"correct": 0,
|
||||||
|
"all": 0
|
||||||
|
}
|
||||||
|
res1[temp]["all"] += 1
|
||||||
|
res2[temp_cat]["all"] += 1
|
||||||
|
|
||||||
|
if pred == gt:
|
||||||
|
res1[temp]["correct"] += 0.5**pen
|
||||||
|
res2[temp_cat]["correct"] += 0.5**pen
|
||||||
|
|
||||||
|
textOut = "\n --------------- Per question Type Acc --------------- \n"
|
||||||
|
for k in res1:
|
||||||
|
textOut += "{}: {} %\n".format(k, 100.0 * (res1[k]["correct"]/res1[k]["all"]))
|
||||||
|
|
||||||
|
textOut += "\n --------------- Per question Category Acc --------------- \n"
|
||||||
|
for k in res2:
|
||||||
|
textOut += "{}: {} %\n".format(k, 100.0 * (res2[k]["correct"]/res2[k]["all"]))
|
||||||
|
return textOut
|
||||||
|
|
||||||
|
|
||||||
|
def decode(tokens, prgIdxToToken, target=False):
|
||||||
|
if type(tokens) != list:
|
||||||
|
tokens = tokens.tolist()
|
||||||
|
|
||||||
|
progsBatch = []
|
||||||
|
for token in tokens:
|
||||||
|
prog = []
|
||||||
|
for tok in token:
|
||||||
|
if tok == 2: # <END> has index 2
|
||||||
|
break
|
||||||
|
prog.append(prgIdxToToken.get(tok))
|
||||||
|
if target:
|
||||||
|
prog = prog[1:]
|
||||||
|
# progsBatch.append(" ".join(prog))
|
||||||
|
progsBatch.append(prog)
|
||||||
|
return progsBatch
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
optsC = OptionsC().parse()
|
||||||
|
optsQ = OptionsQ().parse()
|
||||||
|
|
||||||
|
exe = Execution(optsQ, optsC)
|
||||||
|
exe.run("test")
|
||||||
|
print("[INFO] Done ...")
|
80
utils.py
Normal file
80
utils.py
Normal file
|
@ -0,0 +1,80 @@
|
||||||
|
import json
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def merge_captions_question_programs(path_cap, path_ques, caption_first=True):
|
||||||
|
with open(path_cap, "r"):
|
||||||
|
c_progs = path_cap.readlines()
|
||||||
|
with open(path_ques, "r"):
|
||||||
|
q_progs = path_ques.readlines()
|
||||||
|
|
||||||
|
all_merged_progs = []
|
||||||
|
i = 0
|
||||||
|
while i < len(q_progs):
|
||||||
|
cap_idx = i % 11 if caption_first else i % 10
|
||||||
|
start_idx_p = i + 1 if caption_first else i
|
||||||
|
end_idx_p = start_idx_p + 12 if caption_first else start_idx_p + 11
|
||||||
|
temp = c_progs[cap_idx] + q_progs[start_idx_p, end_idx_p]
|
||||||
|
all_merged_progs.append(temp)
|
||||||
|
i = end_idx_p
|
||||||
|
|
||||||
|
|
||||||
|
def load_clevr_scenes(scenes_json):
|
||||||
|
with open(scenes_json) as f:
|
||||||
|
scenes_raw = json.load(f)
|
||||||
|
if type(scenes_raw) == dict:
|
||||||
|
scenes_raw = scenes_raw["scenes"]
|
||||||
|
|
||||||
|
scenes = []
|
||||||
|
for s in scenes_raw:
|
||||||
|
table = []
|
||||||
|
for i, o in enumerate(s['objects']):
|
||||||
|
item = {}
|
||||||
|
item['id'] = '%d-%d' % (s['image_index'], i)
|
||||||
|
if '3d_coords' in o:
|
||||||
|
item['position'] = [np.dot(o['3d_coords'], s['directions']['right']),
|
||||||
|
np.dot(o['3d_coords'], s['directions']['front']),
|
||||||
|
o['3d_coords'][2]]
|
||||||
|
else:
|
||||||
|
item['position'] = o['position']
|
||||||
|
item['color'] = o['color']
|
||||||
|
item['material'] = o['material']
|
||||||
|
item['shape'] = o['shape']
|
||||||
|
item['size'] = o['size']
|
||||||
|
table.append(item)
|
||||||
|
scenes.append(table)
|
||||||
|
return scenes
|
||||||
|
|
||||||
|
|
||||||
|
def load_minecraft_scenes(scenes_json):
|
||||||
|
with open(scenes_json) as f:
|
||||||
|
scenes_raw = json.load(f)
|
||||||
|
if type(scenes_raw) == dict:
|
||||||
|
scenes_raw = scenes_raw["scenes"]
|
||||||
|
|
||||||
|
scenes = []
|
||||||
|
for s in scenes_raw:
|
||||||
|
table = []
|
||||||
|
for i, o in enumerate(s['objects']):
|
||||||
|
item = {}
|
||||||
|
item['id'] = '%d-%d' % (s['image_index'], i)
|
||||||
|
if '3d_coords' in o:
|
||||||
|
item['position'] = [np.dot(o['3d_coords'], s['directions']['right']),
|
||||||
|
np.dot(o['3d_coords'], s['directions']['front']),
|
||||||
|
o['3d_coords'][2]]
|
||||||
|
else:
|
||||||
|
item['position'] = o['position']
|
||||||
|
item['nature'] = o['nature']
|
||||||
|
item['class'] = o['class']
|
||||||
|
item['direction'] = "facing_"
|
||||||
|
if o['direction'] == "front":
|
||||||
|
item['direction'] += "forward"
|
||||||
|
elif o['direction'] == "back":
|
||||||
|
item['direction'] += "backward"
|
||||||
|
elif o['direction'] == "right":
|
||||||
|
item['direction'] += "right"
|
||||||
|
elif o['direction'] == "left":
|
||||||
|
item['direction'] += "left"
|
||||||
|
table.append(item)
|
||||||
|
scenes.append(table)
|
||||||
|
return scenes
|
62
utils_preprocess.py
Normal file
62
utils_preprocess.py
Normal file
|
@ -0,0 +1,62 @@
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def mkdirs(paths):
|
||||||
|
if isinstance(paths, list):
|
||||||
|
for path in paths:
|
||||||
|
if not os.path.exists(path):
|
||||||
|
os.makedirs(path)
|
||||||
|
else:
|
||||||
|
if not os.path.exists(paths):
|
||||||
|
os.makedirs(paths)
|
||||||
|
|
||||||
|
|
||||||
|
def invert_dict(d):
|
||||||
|
return {v: k for k, v in d.items()}
|
||||||
|
|
||||||
|
|
||||||
|
def load_vocab(path):
|
||||||
|
with open(path, 'r') as f:
|
||||||
|
vocab = json.load(f)
|
||||||
|
vocab['question_idx_to_token'] = invert_dict(vocab['question_token_to_idx'])
|
||||||
|
vocab['program_idx_to_token'] = invert_dict(vocab['program_token_to_idx'])
|
||||||
|
vocab['answer_idx_to_token'] = invert_dict(vocab['answer_token_to_idx'])
|
||||||
|
# Sanity check: make sure <NULL>, <START>, and <END> are consistent
|
||||||
|
assert vocab['question_token_to_idx']['<NULL>'] == 0
|
||||||
|
assert vocab['question_token_to_idx']['<START>'] == 1
|
||||||
|
assert vocab['question_token_to_idx']['<END>'] == 2
|
||||||
|
assert vocab['program_token_to_idx']['<NULL>'] == 0
|
||||||
|
assert vocab['program_token_to_idx']['<START>'] == 1
|
||||||
|
assert vocab['program_token_to_idx']['<END>'] == 2
|
||||||
|
return vocab
|
||||||
|
|
||||||
|
|
||||||
|
def load_scenes(scenes_json):
|
||||||
|
with open(scenes_json) as f:
|
||||||
|
scenes_dict = json.load(f)['scenes']
|
||||||
|
scenes = []
|
||||||
|
for s in scenes_dict:
|
||||||
|
table = []
|
||||||
|
for i, o in enumerate(s['objects']):
|
||||||
|
item = {}
|
||||||
|
item['id'] = '%d-%d' % (s['image_index'], i)
|
||||||
|
if '3d_coords' in o:
|
||||||
|
item['position'] = [np.dot(o['3d_coords'], s['directions']['right']),
|
||||||
|
np.dot(o['3d_coords'], s['directions']['front']),
|
||||||
|
o['3d_coords'][2]]
|
||||||
|
else:
|
||||||
|
item['position'] = o['position']
|
||||||
|
item['color'] = o['color']
|
||||||
|
item['material'] = o['material']
|
||||||
|
item['shape'] = o['shape']
|
||||||
|
item['size'] = o['size']
|
||||||
|
table.append(item)
|
||||||
|
scenes.append(table)
|
||||||
|
return scenes
|
||||||
|
|
||||||
|
|
||||||
|
def load_embedding(path):
|
||||||
|
return torch.Tensor(np.load(path))
|
Loading…
Reference in a new issue