OLViT/src/utils/simmc2_dataset/format_data.py

151 lines
5.1 KiB
Python

#! /usr/bin/env python
"""
Copyright (c) Facebook, Inc. and its affiliates.
All rights reserved.
This source code is licensed under the license found in the LICENSE file in the
root directory of this source tree.
Reads SIMMC 2.1 dataset, creates train, devtest, dev formats for ambiguous candidates.
Author(s): Satwik Kottur
"""
from __future__ import absolute_import, division, print_function, unicode_literals
import argparse
import copy
import json
import os
SPLITS = ["train", "dev", "devtest", "teststd"]
def get_image_name(scene_ids, turn_ind):
"""Given scene ids and turn index, get the image name.
"""
sorted_scene_ids = sorted(
((int(key), val) for key, val in scene_ids.items()),
key=lambda x: x[0],
reverse=True
)
# NOTE: Hardcoded to only two scenes.
if turn_ind >= sorted_scene_ids[0][0]:
scene_label = sorted_scene_ids[0][1]
else:
scene_label = sorted_scene_ids[1][1]
image_label = scene_label
if "m_" in scene_label:
image_label = image_label.replace("m_", "")
return f"{image_label}.png", scene_label
def get_object_mapping(scene_label, args):
"""Get the object mapping for a given scene.
"""
scene_json_path = os.path.join(
args["scene_json_folder"], f"{scene_label}_scene.json"
)
with open(scene_json_path, "r") as file_id:
scene_objects = json.load(file_id)["scenes"][0]["objects"]
object_map = [ii["index"] for ii in scene_objects]
return object_map
def main(args):
for split in SPLITS:
read_path = args[f"simmc_{split}_json"]
print(f"Reading: {read_path}")
with open(read_path, "r") as file_id:
dialogs = json.load(file_id)
# Reformat into simple strings with positive and negative labels.
# (dialog string, label)
ambiguous_candidates_data = []
for dialog_id, dialog_datum in enumerate(dialogs["dialogue_data"]):
turns = []
q_turns = []
a_turns = []
for turn_ind, turn_datum in enumerate(dialog_datum["dialogue"]):
query = [turn_datum["transcript"]]
answer = [turn_datum["system_transcript"]]
#annotations = turn_datum["transcript_annotated"]
#if annotations.get("disambiguation_label", False):
#label = annotations["disambiguation_candidates"]
image_name, scene_label = get_image_name(
dialog_datum["scene_ids"], turn_ind
)
# If dialog contains multiple scenes, map it accordingly.
object_map = get_object_mapping(scene_label, args)
new_datum = {
"query": query,
"answer": answer,
"q_turns": copy.deepcopy(q_turns),
"a_turns": copy.deepcopy(a_turns),
"turns": copy.deepcopy(turns),
"dialog_id": dialog_datum["dialogue_idx"],
"turn_id": turn_ind,
#"input_text": copy.deepcopy(history),
#"ambiguous_candidates": label,
"image_name": image_name,
"object_map": object_map,
}
ambiguous_candidates_data.append(new_datum)
turns.append([turn_datum["transcript"] + turn_datum["system_transcript"]])
q_turns.append(query)
a_turns.append(answer)
# Ignore if system_transcript is not found (last round teststd).
# if turn_datum.get("system_transcript", None):
# history.append(turn_datum["system_transcript"])
print(f"# instances [{split}]: {len(ambiguous_candidates_data)}")
save_path = os.path.join(
args["ambiguous_candidates_save_path"],
f"simmc2.1_ambiguous_candidates_dstc11_{split}.json"
)
print(f"Saving: {save_path}")
with open(save_path, "w") as file_id:
json.dump(
{
"source_path": read_path,
"split": split,
"data": ambiguous_candidates_data,
},
file_id
)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--simmc_train_json", default=None, help="Path to SIMMC 2.1 train"
)
parser.add_argument(
"--simmc_dev_json", default=None, help="Path to SIMMC 2.1 dev"
)
parser.add_argument(
"--simmc_devtest_json", default=None, help="Path to SIMMC 2.1 devtest"
)
parser.add_argument(
"--simmc_teststd_json", default=None, help="Path to SIMMC 2.1 teststd (public)"
)
parser.add_argument(
"--scene_json_folder", default=None, help="Path to SIMMC scene jsons"
)
parser.add_argument(
"--ambiguous_candidates_save_path",
required=True,
help="Path to save SIMMC disambiguate JSONs",
)
try:
parsed_args = vars(parser.parse_args())
except (IOError) as msg:
parser.error(str(msg))
main(parsed_args)