commit efbd43fed12f18ed32a9526461293708173050db Author: abdessaied Date: Tue Feb 20 16:31:21 2024 +0100 release code base diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..e27162d --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +*.tar.gz filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md new file mode 100644 index 0000000..16447cf --- /dev/null +++ b/README.md @@ -0,0 +1,90 @@ +
+# OLViT: Multi-Modal State Tracking via Attention-Based Embeddings for Video-Grounded Dialog
+ +**[Adnen Abdessaied][4],   [Manuel von Hochmeister][5],   [Andreas Bulling][6]**

+**COLING 2024**, Turin, Italy
+**[[Paper][7]]** +---------------- +

+ +
+ +# Table of Contents +* [Setup and Dependencies](#Setup-and-Dependencies) +* [Download Data](#Download-Data) +* [Training](#Training) +* [Testing](#Testing) +* [Results](#Results) +* [Acknowledgements](#Acknowledgements) + +# Setup and Dependencies +We implemented our model using Python 3.7, PyTorch 1.11.0 (CUDA 11.3, cuDNN 8.3.2), and PyTorch Lightning. We recommend setting up a virtual environment with Anaconda.
+1. Install [git lfs][1] on your system +2. Clone our repository to download a checkpoint of our best model and our code + ```shell + git lfs install + git clone this_repo.git + ``` +3. Create a conda environment and install dependencies + ```shell + conda create -n olvit python=3.7 + conda activate olvit + conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch + pip install pytorch-lightning==1.6.3 + pip install transformers==4.19.2 + pip install torchtext==0.12.0 + pip install wandb nltk pandas + ``` +# Download Data +1. [DVD][2] and [SIMMC 2.1][3] data are included in this repository and will be downloaded using git lfs +2. Set up the data by executing + ```shell + chmod u+x setup_data.sh + ./setup_data.sh + ``` +3. This will unpack all the necessary data into ```data/dvd/``` and ```data/simmc/``` + +# Training +We trained our model on 3 Nvidia Tesla V100-32GB GPUs. The default hyperparameters need to be adjusted if your setup differs from ours. +## DVD +1. Adjust the config file for DVD according to your hardware specifications in ```config/dvd.json``` +2. Execute +```shell +CUDA_VISIBLE_DEVICES=0,1,2 python train.py --cfg_path config/dvd.json +``` +3. Checkpoints will be saved in ```checkpoints/dvd/``` + +## SIMMC 2.1 +1. Adjust the config file for SIMMC 2.1 according to your hardware specifications in ```config/simmc.json``` +2. Execute +```shell +CUDA_VISIBLE_DEVICES=0,1,2 python train.py --cfg_path config/simmc.json +``` +3. Checkpoints will be saved in ```checkpoints/simmc/``` + +# Testing +1. Execute +```shell +CUDA_VISIBLE_DEVICES=0 python test.py --ckpt_path <path_to_checkpoint> --cfg_path <path_to_config> +``` + +# Results +Training using the default config and a similar hardware setup as ours will result in the following performance: + +## DVD +
+![DVD results](misc/results_dvd.png)
+ +## SIMMC 2.1 +
+![SIMMC 2.1 results](misc/results_simmc.png)
+ +# Acknowledgements +Our work relied on the codebases of [DVD][2] and [SIMMC][3]. Thanks to the authors for sharing their code. + + +[1]: https://git-lfs.com/ +[2]: https://github.com/facebookresearch/DVDialogues/ +[3]: https://github.com/facebookresearch/simmc2/ +[4]: https://perceptualui.org/people/abdessaied/ +[5]: https://www.linkedin.com/in/manuel-von-hochmeister-285416202/ +[6]: https://www.perceptualui.org/people/bulling/ +[7]: https://drive.google.com/file/d/1sDFfGpQ9E9NahT5gw8UjknWt3sNdxM7p/view?usp=sharing diff --git a/checkpoints/dvd/.gitkeep b/checkpoints/dvd/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/checkpoints/simmc/.gitkeep b/checkpoints/simmc/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/config/__init__.py b/config/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/config/config.py b/config/config.py new file mode 100644 index 0000000..0d51c41 --- /dev/null +++ b/config/config.py @@ -0,0 +1,26 @@ +import json +import os + +def read_default_config(): + dirpath = os.path.dirname(__file__) + path = os.path.join(dirpath, "default.json") + with open(path) as config_file: + config = json.load(config_file) + return config + +def read_config(path): + with open(path) as config_file: + config = json.load(config_file) + return config + +def update_nested_dicts(old_dict, update_dict): + for key in update_dict: + if key in old_dict: + old_dict[key].update(update_dict[key]) + else: + old_dict[key] = update_dict[key] + return old_dict + + + + \ No newline at end of file diff --git a/config/default.json b/config/default.json new file mode 100644 index 0000000..ad62aff --- /dev/null +++ b/config/default.json @@ -0,0 +1,43 @@ +{ + "wandb": { + "entity": "TO_BE_DEFINED", + "name": "", + "group": "", + "tags": [], + "project": "olvit" + + }, + "model": { + "model_type": "base_model", + "feature_type": "none", + "freeze_roberta": true, + "v_emb_dim": 16, + "dim_feedforward": 400, + "n_heads": 9, + "fc_dim": 128, + "dropout_p": 0.1, + "sample_rate_video": 10, + "n_encoder_layers": 6, + "add_choices_as_context": false, + "use_pretrained_lm": false, + "projection_as_in_aloe": false, + "pretrained_lm_name": "" + }, + "training": { + "lr": 1e-4, + "total_steps": 200000, + "warmup_steps": 4000, + "accumulate_grad_batches": 1, + "batch_size": 128, + "epochs": 40, + "seed": null + }, + "datamodule": { + "fea_dir": "data/dvd/monet_feats/", + "data_dir": "data/dvd/dialogs/" + }, + "checkpoint": { + "checkpoint_folder": "checkpoints/", + "checkpoint_file_name": "olvit" + } +} \ No newline at end of file diff --git a/config/dvd.json b/config/dvd.json new file mode 100644 index 0000000..e9f6e02 --- /dev/null +++ b/config/dvd.json @@ -0,0 +1,49 @@ +{ + "wandb": { + "name": "olvit", + "group": "dvd", + "tags": [], + "project": "olvit" + + }, + "model": { + "model_type": "discriminative", + "n_heads": 6, + "v_emb_dim": 36, + "dim_feedforward": 200, + "dropout_p": 0.1, + "fc_dim": 512, + "sample_rate_video": 20, + "n_transf_layers": 4, + "use_pretrained_lm": true, + "projection_as_in_aloe": true, + "pretrained_lm_name": "distilroberta-base", + "dataset": "dvd" + }, + "extended_model": { + "hist_len_for_state_gen": 7, + "number_of_relevant_emb": 2, + "num_layers_v_state": 2, + "num_layers_d_state": 2, + "combiner_option": "OptionA", + "state_tracker_type": "Transformer", + "use_v_state": true, + "use_d_state": true, + "n_heads_combiner_transformer": 8, + "n_heads_state_tracker": 6, + "dim_feedforward_v_transformer": 140, + 
"dim_feedforward_d_transformer": 60 + }, + "training": { + "lr": 1e-4, + "warmup_steps": 4000, + "total_steps": 200000, + "batch_size": 128, + "seed": 12345, + "epochs": 1000 + }, + "checkpoint": { + "checkpoint_folder": "checkpoints/dvd", + "checkpoint_file_name": "olvit" + } +} \ No newline at end of file diff --git a/config/simmc.json b/config/simmc.json new file mode 100644 index 0000000..a17feb8 --- /dev/null +++ b/config/simmc.json @@ -0,0 +1,61 @@ +{ + "wandb": { + "name": "olvit", + "group": "simmc2", + "tags": [], + "project": "olvit" + + }, + "model": { + "model_type": "generative", + "dataset": "simmc2", + "feature_type": "object_text_features", + "object_feature_generator_dim": 50, + "n_object_feature_generator_layers": 2, + "n_heads": 6, + "v_emb_dim": 516, + "emb_dim": 216, + "dim_feedforward": 200, + "dropout_p": 0.1, + "fc_dim": 512, + "sample_rate_video": 1, + "n_encoder_layers": 4, + "n_decoder_layers": 4, + "use_pretrained_lm": true, + "vocab_size": 50265, + "projection_as_in_aloe": false, + "pretrained_lm_name": "distilroberta-base" + }, + "extended_model": { + "hist_len_for_state_gen": 3, + "number_of_relevant_emb": 2, + "num_layers_v_state": 2, + "num_layers_d_state": 2, + "combiner_option": "OptionA", + "state_tracker_type": "Transformer", + "use_v_state": true, + "use_d_state": true, + "n_heads_combiner_transformer": 8, + "n_heads_state_tracker": 6, + "dim_feedforward_v_transformer": 140, + "dim_feedforward_d_transformer": 60 + }, + "training": { + "lr": 1e-4, + "warmup_steps": 4000, + "total_steps": 200000, + "batch_size": 8, + "seed": 12345, + "epochs": 1000 + }, + "datamodule": { + "fea_dir": "data/simmc/visual_features_resnet50_simmc2.1.pt", + "data_dir": "data/simmc/dialogs" + }, + "checkpoint": { + "checkpoint_folder": "checkpoints/simmc/", + "checkpoint_file_name": "olvit", + "output_path": "output/simmc/", + "checkpoint_path": "TO_BE_DETERMINED" + } +} \ No newline at end of file diff --git a/data/dvd/dialogs.tar.gz b/data/dvd/dialogs.tar.gz new file mode 100644 index 0000000..b8e7d16 --- /dev/null +++ b/data/dvd/dialogs.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b1b58ee7af90b402eddbde8470dc0333b83ae293a90a93d26af3b8c39c2d9b0e +size 395953476 diff --git a/data/dvd/monet_feats_part00.tar.gz b/data/dvd/monet_feats_part00.tar.gz new file mode 100644 index 0000000..c70d3e7 --- /dev/null +++ b/data/dvd/monet_feats_part00.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:933c88dbf854d11fca34c388b1b566096b4f9733abd2ded0a1d381b4b1c6a379 +size 1582620496 diff --git a/data/dvd/monet_feats_part01.tar.gz b/data/dvd/monet_feats_part01.tar.gz new file mode 100644 index 0000000..56f6e5e --- /dev/null +++ b/data/dvd/monet_feats_part01.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c07f88af54843010899ed1149d16343b9aeb38dbd2cb4e1977bb4c2436d461ec +size 1582620496 diff --git a/data/simmc/dialogs.tar.gz b/data/simmc/dialogs.tar.gz new file mode 100644 index 0000000..6a31bd6 --- /dev/null +++ b/data/simmc/dialogs.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:65ed3852c6bbe9f3135558f1bfd3900e8c37ae9af7b8338b3535987408086ca6 +size 12956266 diff --git a/data/simmc/visual_features_resnet50_simmc2.1.pt b/data/simmc/visual_features_resnet50_simmc2.1.pt new file mode 100644 index 0000000..6d5a900 --- /dev/null +++ b/data/simmc/visual_features_resnet50_simmc2.1.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:7f7aa24ce312e0cdbdb69021ce593aa985074e3ec88a737bc7af8060ff61d6a8 +size 81394479 diff --git a/misc/.gitkeep b/misc/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/misc/italy.png b/misc/italy.png new file mode 100644 index 0000000..cab3939 Binary files /dev/null and b/misc/italy.png differ diff --git a/misc/results_dvd.png b/misc/results_dvd.png new file mode 100644 index 0000000..2b645aa Binary files /dev/null and b/misc/results_dvd.png differ diff --git a/misc/results_simmc.png b/misc/results_simmc.png new file mode 100644 index 0000000..51c2733 Binary files /dev/null and b/misc/results_simmc.png differ diff --git a/misc/teaser.pdf b/misc/teaser.pdf new file mode 100644 index 0000000..73d8147 Binary files /dev/null and b/misc/teaser.pdf differ diff --git a/misc/teaser.png b/misc/teaser.png new file mode 100644 index 0000000..fc64b5e Binary files /dev/null and b/misc/teaser.png differ diff --git a/output/.gitkeep b/output/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/setup_data.sh b/setup_data.sh new file mode 100644 index 0000000..1efdb01 --- /dev/null +++ b/setup_data.sh @@ -0,0 +1,16 @@ +cd data/dvd + +tar -xvzf dialogs.tar.gz +cat monet_feats_part* > monet_feats.tar.gz +tar -xvzf monet_feats.tar.gz + +rm dialogs.tar.gz +rm monet_feats.tar.gz +rm monet_feats_part00.tar.gz +rm monet_feats_part01.tar.gz + +cd ../simmc +tar -xvzf dialogs.tar.gz +rm dialogs.tar.gz + +cd ../.. diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/combiner/option_a.py b/src/combiner/option_a.py new file mode 100644 index 0000000..8dfcf1c --- /dev/null +++ b/src/combiner/option_a.py @@ -0,0 +1,25 @@ +import pytorch_lightning as pl +import torch + +class CombinerOptionA(pl.LightningModule): + def __init__(self, config=None, model_input_dim=None, use_v_state=False, use_d_state=False): + super().__init__() + self.use_v_state = use_v_state + self.use_d_state = use_d_state + + def forward(self, vision_emb, language_emb, language_emb_mask, v_state, d_state, dummy_word=None): + if v_state is not None \ + and d_state is not None \ + and self.use_v_state \ + and self.use_d_state: + output = torch.concat([v_state, d_state, vision_emb, language_emb], axis=1) + elif d_state is not None and self.use_d_state: + output = torch.concat([d_state, vision_emb, language_emb], axis=1) + elif v_state is not None and self.use_v_state: + output = torch.concat([v_state, vision_emb, language_emb], axis=1) + else: + output = torch.concat([vision_emb, language_emb], axis=1) + if dummy_word is not None: + output = torch.concat([dummy_word, output], axis=1) + + return output diff --git a/src/combiner/option_b.py b/src/combiner/option_b.py new file mode 100644 index 0000000..86be3ac --- /dev/null +++ b/src/combiner/option_b.py @@ -0,0 +1,38 @@ +import pytorch_lightning as pl +import torch + +class CombinerOptionB(pl.LightningModule): + def __init__(self, config=None, model_input_dim=None, use_v_state=False, use_d_state=False): + super().__init__() + self.use_v_state = use_v_state + self.use_d_state = use_d_state + + + def append_state_to_emb(self, tensor, state): + tiling_vector = [1, tensor.shape[1], 1] + state_tensor_for_concatenation = torch.tile(state, tiling_vector) + result = torch.concat([tensor, state_tensor_for_concatenation], axis=2) + return result + + + def forward(self, dummy_word, video_emb, language_emb, language_emb_mask, v_state, d_state): + # concatenate the video emb with the video state and the language emb with the dialogue 
state + # if the stat is not used, concatenate itself + if v_state is not None \ + and d_state is not None \ + and self.use_v_state \ + and self.use_d_state: + video_emb = self.append_state_to_emb(video_emb, v_state) + language_emb = self.append_state_to_emb(language_emb, d_state) + elif d_state is not None and self.use_d_state: + video_emb = self.append_state_to_emb(video_emb, video_emb) + language_emb = self.append_state_to_emb(language_emb, d_state) + elif v_state is not None and self.use_v_state: + video_emb = self.append_state_to_emb(video_emb, v_state) + language_emb = self.append_state_to_emb(language_emb, language_emb) + else: + video_emb = self.append_state_to_emb(video_emb, video_emb) + language_emb = self.append_state_to_emb(language_emb, language_emb) + + output = torch.concat([dummy_word, video_emb, language_emb], axis=1) + return output diff --git a/src/combiner/option_c.py b/src/combiner/option_c.py new file mode 100644 index 0000000..b4db9c1 --- /dev/null +++ b/src/combiner/option_c.py @@ -0,0 +1,69 @@ +import pytorch_lightning as pl +import torch +from torch import nn + +class CombinerOptionC(pl.LightningModule): + def __init__(self, config, model_input_dim, use_v_state, use_d_state): + super().__init__() + self.config = config + self.use_v_state = use_v_state + self.use_d_state = use_d_state + + self.encoder_layer_d = nn.TransformerEncoderLayer( + d_model=model_input_dim, + dim_feedforward=self.config['dim_feedforward_d_transformer'], + batch_first=True, + nhead=self.config['n_heads_combiner_transformer'] + ) + self.encoder_layer_v = nn.TransformerEncoderLayer( + d_model=model_input_dim, + dim_feedforward=self.config['dim_feedforward_v_transformer'], + batch_first=True, + nhead=self.config['n_heads_combiner_transformer'] + ) + + + def prepare_inputs_for_transformers(self, video_emb, language_emb, language_emb_mask, v_state, d_state): + # create masks for the language inputs (video seq should all be 301 frames long and dont need padding) + d_input_mask = ~language_emb_mask # emb for pytorch needs to be True for masked tokens (opposite to huggingface mask) + # if the dialogue state is used, add a column of Falses at the beeginngin of the tensor (state should be attended -> no mask) + if d_state is not None and self.use_d_state: + zero_column = torch.zeros((d_input_mask.shape[0], 1), dtype=torch.bool, device=self.device) + d_input_mask = torch.concat([zero_column, d_input_mask],axis=1) + + # prepare the input tensors for the different transformer layers depending on which state vectors should be used + if v_state is not None \ + and d_state is not None \ + and self.use_v_state \ + and self.use_d_state: + v_input = torch.concat([v_state, video_emb], axis=1) + d_input = torch.concat([d_state, language_emb], axis=1) + elif d_state is not None and self.use_d_state: + v_input = video_emb + d_input = torch.concat([d_state, language_emb], axis=1) + elif v_state is not None and self.use_v_state: + v_input = torch.concat([v_state, video_emb], axis=1) + d_input = language_emb + else: + v_input = video_emb + d_input = language_emb + + return v_input, d_input, d_input_mask + + + def forward(self, dummy_word, video_emb, language_emb, language_emb_mask, v_state, d_state): + # prepare the input tensors for the different transformer layers depending on which state vectors should be used + v_input, d_input, d_input_mask = self.prepare_inputs_for_transformers(video_emb, language_emb, language_emb_mask, v_state, d_state) + + # apply the v transformer to the v input and the d transformer to the 
d input + v_emb = self.encoder_layer_v(v_input) + d_emb = self.encoder_layer_d(d_input, src_key_padding_mask=d_input_mask) + + # combine the output of the first 2 transformers and add the dummy word (cls token) + # put the embedded video and dialog states at the beginning of the combined input + v_state_emb = v_emb[:, 0, :].unsqueeze(1) + d_state_emb = d_emb[:, 0, :].unsqueeze(1) + combined_input = torch.concat([dummy_word, v_state_emb, d_state_emb, v_emb[:, 1:, :], d_emb[:, 1:, :]], axis=1) + + # create combined_input_mask based on the language_emb_mask + return combined_input \ No newline at end of file diff --git a/src/data_modules/__init__.py b/src/data_modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/data_modules/dvd_data.py b/src/data_modules/dvd_data.py new file mode 100644 index 0000000..cc8f423 --- /dev/null +++ b/src/data_modules/dvd_data.py @@ -0,0 +1,55 @@ +import pytorch_lightning as pl +import src.utils.dvd_codebase.data.data_handler as dh +from src.utils.dvd_codebase.configs.configs import * +from transformers import AutoTokenizer +import os + +class DVDData(pl.LightningDataModule): + def __init__(self, config): + super().__init__() + args.batch_size = config['training']['batch_size'] + args.fea_dir = config['datamodule']['fea_dir'] + args.data_dir = config['datamodule']['data_dir'] + pretrained_lm_name = config['model']['pretrained_lm_name'] + + # load dialogues + self.train_dials, self.train_vids = dh.load_dials(args, "train") + self.val_dials, self.val_vids = dh.load_dials(args, "val") + self.test_dials, self.test_vids = dh.load_dials(args, "test") + + # get vocabulary + self.vocab, self.answer_list = dh.get_vocabulary(self.train_dials, args) + # self.answer_list = ['0', '1', '10', '2', '3', '4', '5', '6', '7', '8', '9', 'False', 'True', 'blue', 'brown', 'cone', 'cube', 'cyan', 'cylinder', 'flying', 'flying,rotating', 'flying,rotating,sliding', 'flying,sliding', 'gold', 'gray', 'green', 'large', 'medium', 'metal', 'no action', 'purple', 'red', 'rotating', 'rotating,sliding', 'rubber', 'sliding', 'small', 'sphere', 'spl', 'yellow'] + + train_vft = dh.load_video_features(args, self.train_vids) + val_vft = dh.load_video_features(args, self.val_vids) + test_vft = dh.load_video_features(args, self.test_vids) + + # create tokenizer + if pretrained_lm_name != '': + tokenizer = AutoTokenizer.from_pretrained(pretrained_lm_name) + pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token) + self.vocab[''] = pad_token_id + os.environ["TOKENIZERS_PARALLELISM"] = "false" + else: + tokenizer = None + + # load data + self.train_dials = dh.create_dials(self.train_dials, self.vocab, self.answer_list, train_vft, args, tokenizer=tokenizer) + self.val_dials = dh.create_dials(self.val_dials, self.vocab, self.answer_list, val_vft, args, tokenizer=tokenizer) + self.test_dials = dh.create_dials(self.test_dials, self.vocab, self.answer_list, test_vft, args, tokenizer=tokenizer) + + + def train_dataloader(self): + dl, _ = dh.create_dataset(self.train_dials, self.vocab, "train", args) + return dl + + def val_dataloader(self): + dl, _ = dh.create_dataset(self.val_dials, self.vocab, "val", args) + return dl + + def test_dataloader(self): + dl, _ = dh.create_dataset(self.test_dials, self.vocab, "test", args) + return dl + + diff --git a/src/data_modules/simmc2_data.py b/src/data_modules/simmc2_data.py new file mode 100644 index 0000000..acff6ee --- /dev/null +++ b/src/data_modules/simmc2_data.py @@ -0,0 +1,95 @@ +import pytorch_lightning as pl +from 
src.utils.simmc2_dataset.dataloader_dvd_model import Simmc2Dataset, VisualFeatureLoader +from transformers import AutoTokenizer +import argparse +import os +from torch.utils.data import DataLoader + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument("--train_file", default='', help="Path to train file") + parser.add_argument("--dev_file", default='', help="Path to dev file") + parser.add_argument("--devtest_file", default='', help="Path to devtest file") + parser.add_argument( + "--visual_feature_path", default=None, help="Path to visual features" + ) + parser.add_argument( + "--visual_feature_size", + type=int, + default=516, + help="Size of the visual features", + ) + parser.add_argument( + "--max_turns", type=int, default=5, help="Number of turns in history" + ) + parser.add_argument( + "--max_length", type=int, default=512, help="Maximum length in utterance" + ) + parser.add_argument("--use_gpu", dest="use_gpu", action="store_true", default=True) + args = parser.parse_args() + return args + + + +class Simmc2Data(pl.LightningDataModule): + def __init__(self, config): + super().__init__() + self.args = parse_arguments() + self.args.train_file = os.path.join(config['datamodule']['data_dir'], 'simmc2.1_ambiguous_candidates_dstc11_train.json') + self.args.dev_file = os.path.join(config['datamodule']['data_dir'], 'simmc2.1_ambiguous_candidates_dstc11_dev.json') + self.args.devtest_file = os.path.join(config['datamodule']['data_dir'], 'simmc2.1_ambiguous_candidates_dstc11_devtest.json') + self.args.teststd_file = os.path.join(config['datamodule']['data_dir'], 'simmc2.1_dials_dstc11_dev.json') + self.args.visual_feature_path = config['datamodule']['fea_dir'] + pretrained_lm_name = config['model']['pretrained_lm_name'] + self.tokenizer = AutoTokenizer.from_pretrained(pretrained_lm_name) + self.feature_loader = VisualFeatureLoader( + feature_path=self.args.visual_feature_path, + feature_size=self.args.visual_feature_size + ) + self.config = config + + def train_dataloader(self): + dataset = Simmc2Dataset( + tokenizer=self.tokenizer, + feature_loader=self.feature_loader, + load_path=self.args.train_file, + args=self.args + ) + dl = DataLoader( + dataset, + batch_size=self.config['training']['batch_size'], + shuffle=True, + collate_fn=dataset.collate_fn, + ) + return dl + + def val_dataloader(self): + dataset = Simmc2Dataset( + tokenizer=self.tokenizer, + feature_loader=self.feature_loader, + load_path=self.args.dev_file, + args=self.args, + ) + dl = DataLoader( + dataset, + batch_size=self.config['training']['batch_size'], + shuffle=False, + collate_fn=dataset.collate_fn, + ) + return dl + + def test_dataloader(self): + dataset = Simmc2Dataset( + tokenizer=self.tokenizer, + feature_loader=self.feature_loader, + load_path=self.args.devtest_file, + args=self.args, + ) + dl = DataLoader( + dataset, + batch_size=self.config['training']['batch_size'], + shuffle=False, + collate_fn=dataset.collate_fn, + ) + return dl diff --git a/src/models/__init__.py b/src/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/models/base_model.py b/src/models/base_model.py new file mode 100644 index 0000000..a6930f7 --- /dev/null +++ b/src/models/base_model.py @@ -0,0 +1,179 @@ +import pytorch_lightning as pl +import torch +from torch import nn +from torch.optim import AdamW +from src.utils.positional_encoding import PositionalEncoding +from src.object_description_encoder.object_description_encoder import ObjectDescriptionEncoder +import torchmetrics as 
metrics +from transformers import get_cosine_schedule_with_warmup +from transformers import AutoModel +from src.combiner.option_a import CombinerOptionA +from transformers import AutoTokenizer + + +class TransformerModel(pl.LightningModule): + def __init__(self, config, output_path=None): + super().__init__() + self.output_path = output_path + self.config = config['model'] + self.train_config = config['training'] + + self.train_acc = metrics.Accuracy('multiclass', num_classes=40) + self.val_acc = metrics.Accuracy('multiclass', num_classes=40) + self.test_acc = metrics.Accuracy('multiclass', num_classes=40) + + self.best_val_acc = 0 + self.loss_for_best_val_acc = 0 + self.best_train_acc = 0 + + + self.combiner = CombinerOptionA() + self.initialize_text_encoder_and_feature_mapping() + + self.positional_encoder = PositionalEncoding( + d_model=self.model_input_dim, dropout=self.config['dropout_p'], max_len=self.config['dim_feedforward'] + ) + + encoder_layer = nn.TransformerEncoderLayer( + d_model=self.model_input_dim, + batch_first=True, + dropout=self.config['dropout_p'], + dim_feedforward=self.config['dim_feedforward'], + nhead=self.config['n_heads'] + ) + self.encoder = nn.TransformerEncoder( + encoder_layer=encoder_layer, + num_layers=self.config['n_encoder_layers'], + ) + + self.loss = nn.CrossEntropyLoss() + + if self.config['feature_type'] == 'object_text_features': + self.object_description_encoder = ObjectDescriptionEncoder( + d_model=self.config['v_emb_dim'], + config=self.config + ) + # maps the output from the pretrained lm to as smaller size used for the encoding of the object description (reduces transformer size) + self.linear_projection_object_description = nn.Linear(self.pretrained_lm.config.hidden_size, self.config['v_emb_dim']) + + + # tokenizer for translation from ids to text + self.tokenizer = AutoTokenizer.from_pretrained(self.config['pretrained_lm_name']) + + + def initialize_text_encoder_and_feature_mapping(self): + if self.config['use_pretrained_lm']: + self.pretrained_lm = AutoModel.from_pretrained( + self.config['pretrained_lm_name'], + add_pooling_layer=False + ) + self.pretrained_lm.eval() + # don't train the paramteres of the pretrained lm + self.pretrained_lm.config.training = True + # for param in self.pretrained_lm.parameters(): + # param.requires_grad = False + + # initialize the projection layers to map the embeddings to the correct input dim + # either use the emb_dim as done in aloe (v_emb_dim * n_heads) or the emb_dim specified in the config + if self.config['projection_as_in_aloe']: + self.model_input_dim = self.config['n_heads'] * self.config['v_emb_dim'] + self.linear_projection_video = nn.Linear(self.config['v_emb_dim'], self.model_input_dim - 2) + self.linear_projection_text = nn.Linear(self.pretrained_lm.config.hidden_size, self.model_input_dim - 2) + else: + # take embedding size from config and map the video features from their size to the chose emb size + self.linear_projection_video = nn.Linear(self.config['v_emb_dim'], self.config['emb_dim'] - 2) + self.linear_projection_text = nn.Linear(self.pretrained_lm.config.hidden_size, self.config['emb_dim'] - 2) + self.model_input_dim = self.config['emb_dim'] + else: + # either use the emb_dim as done in aloe (v_emb_dim * n_heads) or the video_emb_dim (2 is either added or subtracted because of the input ids) + if self.config['projection_as_in_aloe']: + self.model_input_dim = self.config['n_heads'] * self.config['v_emb_dim'] + else: + self.model_input_dim = self.config['emb_dim'] + 
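+            # note (added): features are projected to model_input_dim - 2 because append_ids() later appends a 2-dim modality id (video vs. language) to every token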
self.linear_projection_video = nn.Linear(self.config['v_emb_dim'], self.model_input_dim - 2) + self.embed = nn.Embedding(num_embeddings=self.config['vocab_size'], embedding_dim=self.model_input_dim - 2) + + + def append_ids(self, tensor, id_vector, axis): + id_vector = torch.tensor(id_vector, device=self.device) + for a in range(len(tensor.shape)): + if a != axis: + id_vector = torch.unsqueeze(id_vector, axis=a) + tiling_vector = [s if i != axis else 1 for i, s in enumerate(tensor.shape)] + id_tensor = torch.tile(id_vector, tiling_vector) + return torch.concat([tensor, id_tensor], axis=axis) + + + def downsample_video_emb(self, video_emb): + return video_emb[:, ::self.config['sample_rate_video'], :, :] + + + def unroll_video_emb(self, video_emb): + video_emb = video_emb.permute(0, 1, 3, 2) + return torch.reshape(video_emb, (video_emb.shape[0], -1, video_emb.shape[3])) + + + def apply_pretrained_lm(self, query, query_mask): + output = self.pretrained_lm( + input_ids=query, + attention_mask=query_mask + ) + return output['last_hidden_state'] + + + def prepare_lang_emb(self, query, query_mask): + # set maximum query length TODO ------ set param in config + if query.shape[1] > 100: + query = query[:, :100] + query_mask = query_mask[:, :100] + + # apply pretrained language model to embed the query if specified + if self.config['use_pretrained_lm']: + lang_emb = self.apply_pretrained_lm(query, query_mask) + else: + lang_emb = self.embed(query) + + # Aloe uses an emb_dim of v_emb_dim * n_heads. Or use the emb_dim specified in the config + if self.config['use_pretrained_lm']: + lang_emb = self.linear_projection_text(lang_emb) + + lang_emb = self.append_ids(lang_emb, [1, 0], 2) + lang_emb = self.positional_encoder(lang_emb) + return lang_emb + + + def prepare_video_emb(self, video_emb): + # shape: [batch, frames, v_emb_dim, objects] + video_emb = self.downsample_video_emb(video_emb) + + # unroll time dimension in object dimension (only take every _ frame) - shape: [batch, objects x frames, v_emb_dim + 2] + video_emb = self.unroll_video_emb(video_emb) + + # video_emb need to be projected to either the size of the language emb or the emb_size given by v_emb_dim * n_heads (As done in the Aloe paper) + #if self.config['use_pretrained_lm'] or self.config['projection_as_in_aloe']: + video_emb = self.linear_projection_video(video_emb) + + video_emb = self.append_ids(video_emb, [0, 1], 2) + video_emb = self.positional_encoder(video_emb) + return video_emb + + + def forward(self, batch): + output = self.answer_query(batch.query, batch.query_mask, batch.vft) + return output + + + def configure_optimizers(self): + opt = AdamW(self.parameters(), lr=self.train_config['lr']) + sched = get_cosine_schedule_with_warmup( + opt, + num_warmup_steps=self.train_config['warmup_steps'], + num_training_steps=self.train_config['total_steps'], + ) + return { + 'optimizer': opt, + 'lr_scheduler': { + 'scheduler': sched, + 'interval': 'step' + } + } \ No newline at end of file diff --git a/src/models/discriminative_model.py b/src/models/discriminative_model.py new file mode 100644 index 0000000..12e112a --- /dev/null +++ b/src/models/discriminative_model.py @@ -0,0 +1,137 @@ +from src.models.state_tracker_model import StateTrackerModel +import torch +from torch import nn +from src.utils.text_utils import translate_from_ids_to_text +import pandas as pd + + +class DiscriminativeModel(StateTrackerModel): + def __init__(self, config, output_path=None): + super().__init__(config, output_path=output_path) + + self.fc = 
nn.Linear(self.model_input_dim, self.config["fc_dim"]) + self.relu = nn.ReLU() + self.output = nn.Linear(self.config["fc_dim"], 40) + + + def apply_model(self, language_emb, language_emb_mask, video_emb, v_state=None, d_state=None, answer_emb=None, answer_mask=None, state_generation_mode=None): + # analogous to the CLS token from BERT models + dummy_word = torch.zeros(self.model_input_dim, requires_grad=True, device=self.device) + dummy_word = torch.tile(dummy_word, (language_emb.shape[0], 1, 1)) + + # combine state and embeddings + input = self.combiner( + video_emb, + language_emb, + language_emb_mask, + v_state, + d_state, + dummy_word + ) + # create input mask based on the language_emb_mask (complete video is unmasked) + input_mask = torch.zeros((input.shape[0], input.shape[1]), device=self.device) + offset = 1 + if v_state is not None: offset += 1 + if d_state is not None: offset += 1 + # offset is caused by cls token and state vectors + if self.config['model_type'] == 'extended_model': + # set offset to 1 if combiner B is used -> no state vectors as input. Instead concatenated with embeddings + if self.ext_config['combiner_option'] == 'OptionB': + offset = 1 + input_mask[:, video_emb.shape[1] + offset:] = ~language_emb_mask + + x = self.encoder(input, src_key_padding_mask=input_mask) + # only pass transformed dummy word to the dense layers + x = self.fc(x[:, 0, :]) + x = self.relu(x) + output = self.output(x) + return output + + + def answer_query(self, query, query_mask, vft, v_state=None, d_state=None, answer=None, answer_mask=None, state_generation_mode=False): + video_emb = self.prepare_video_emb(vft) + lang_emb = self.prepare_lang_emb(query, query_mask) + if answer is not None and answer_mask is not None: + answer_emb = self.prepare_lang_emb(answer, answer_mask) + else: + answer_emb = None + output = self.apply_model(lang_emb, query_mask, video_emb, v_state, d_state, answer_emb, answer_mask, state_generation_mode) + return output + + + def training_step(self, train_batch, batch_idx): + train_batch.move_to_cuda() + label = torch.squeeze(train_batch.answer) + out = self.forward(train_batch) + loss = self.loss(out, label) + tr_acc = self.train_acc(out.softmax(dim=1), label) + if tr_acc > self.best_train_acc: + self.best_train_acc = tr_acc + self.log("train_acc", tr_acc, prog_bar=True, on_step=False, on_epoch=True, batch_size=train_batch.query.shape[0]) + self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True, batch_size=train_batch.query.shape[0]) + print('train_loss: {} | train_acc: {}'.format(loss, tr_acc)) + return loss + + + def validation_step(self, val_batch, batch_idx): + val_batch.move_to_cuda() + label = torch.squeeze(val_batch.answer) + out = self.forward(val_batch) + loss = self.loss(out, label) + self.val_acc(out.softmax(dim=1), label) + self.log("val_acc", self.val_acc, prog_bar=True, on_step=False, on_epoch=True, batch_size=val_batch.query.shape[0]) + self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True, batch_size=val_batch.query.shape[0]) + return {'val_loss': loss, 'val_acc': self.val_acc.compute()} + + + def test_step(self, test_batch, batch_idx): + test_batch.move_to_cuda() + label = torch.squeeze(test_batch.answer) + out = self.forward(test_batch) + loss = self.loss(out, label) + self.test_acc(out.softmax(dim=1), label) + self.log("test_acc", self.test_acc, prog_bar=True, on_step=False, on_epoch=True, batch_size=test_batch.query.shape[0]) + self.log("test_loss", loss, prog_bar=True, on_step=False, on_epoch=True, 
batch_size=test_batch.query.shape[0]) + + # save the results into a dictionary + out = torch.argmax(out, dim=1) + + question_as_text = [] + for i in range(test_batch.query.shape[0]): + question_ids = test_batch.query[i, :] + question_as_text.append(translate_from_ids_to_text(question_ids, self.tokenizer)) + + self.results['question'].extend(question_as_text) + self.results['video_name'].extend(test_batch.video_name) + + self.results['qa_id'].extend(test_batch.qa_ids) + self.results['q_type'].extend(test_batch.q_type) + self.results['label'].extend(label.tolist()) + self.results['output'].extend(out.tolist()) + self.results['attribute_dependency'].extend(test_batch.attribute_dependency) + self.results['object_dependency'].extend(test_batch.object_dependency) + self.results['temporal_dependency'].extend(test_batch.temporal_dependency) + self.results['spatial_dependency'].extend(test_batch.spatial_dependency) + self.results['q_complexity'].extend(test_batch.q_complexity) + + + def on_test_start(self): + self.results = { + 'qa_id': [], + 'q_type': [], + 'label': [], + 'output': [], + 'attribute_dependency': [], + 'object_dependency': [], + 'temporal_dependency': [], + 'spatial_dependency': [], + 'q_complexity': [], + # only needed for input output analysis + 'question': [], + 'video_name': [] + } + + + def on_test_end(self): + df = pd.DataFrame.from_dict(self.results) + df.to_pickle(self.output_path) diff --git a/src/models/generative_model.py b/src/models/generative_model.py new file mode 100644 index 0000000..f5c1264 --- /dev/null +++ b/src/models/generative_model.py @@ -0,0 +1,350 @@ +# code is partly inspired from https://pytorch.org/tutorials/beginner/translation_transformer.html + +from unittest import result +from src.models.state_tracker_model import StateTrackerModel +from src.utils.batch_interfaces import batch_interface_simmc2_to_dvd, batch_interface_avsd_to_dvd +from dataclasses import dataclass +import torch +from torch import nn +from torchtext.data.metrics import bleu_score +import json +import os +from transformers import AutoTokenizer +import nltk +import numpy as np +from src.utils.text_utils import normalize_sentence, translate_from_ids_to_text + + + + +class GenerativeModel(StateTrackerModel): + def __init__(self, config, output_path=None): + super().__init__(config, output_path=output_path) + + self.transformer = nn.Transformer( + d_model=self.model_input_dim, + batch_first=True, + dropout=self.config['dropout_p'], + dim_feedforward=self.config['dim_feedforward'], + nhead=self.config['n_heads'], + num_encoder_layers=self.config['n_encoder_layers'], + num_decoder_layers=self.config['n_decoder_layers'], + custom_encoder=self.encoder + ) + self.prob_generator = nn.Linear(self.model_input_dim, self.config['vocab_size']) + + self.pad_id = 1 + self.unk_id = 3 + self.loss = nn.CrossEntropyLoss(ignore_index=self.pad_id) + + + # tokenizer for translation from ids to text + self.tokenizer = AutoTokenizer.from_pretrained(self.config['pretrained_lm_name']) + + # ---TODO: Remove ------ + self.results = {} + self.epoch_count = 0 + + + # ----------------------- + self.batch_interface = batch_interface_simmc2_to_dvd + + + def encode_object_descriptions(self, vft): + #embed the object descriptions using bert and then create the object token using transformer layers + if self.config['feature_type'] == "object_text_features": + object_features = [] + for i in range(vft.shape[1]): + object_description = vft[:, i, :] + object_description_mask = (object_description != 1) + 
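+                # note (added): token id 1 is RoBERTa's padding id (cf. self.pad_id), so this mask marks the real, non-padded description tokens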
embedded_object_description = self.apply_pretrained_lm(object_description, object_description_mask) + + #map embeddings to a smaller size (motivation: reduce transformer sice of object description encoder) + embedded_object_description = self.linear_projection_object_description(embedded_object_description) + + #apply transformer to encode the object description + object_token = self.object_description_encoder(embedded_object_description) + object_features.append(object_token) + object_features = torch.concat(object_features, dim=1) + #add frame dimension (only one frame in this cas) + object_features = object_features.unsqueeze(1) + #bring the data to the format [batch_size x frames x emb_dim (desc_text_len) x obj_number] + vft = object_features.permute(0, 1, 3, 2) + + return vft + + + def create_target_mask(self, size): + mask = torch.triu(torch.ones((size,size), device=self.device), 1) + mask = mask.masked_fill(mask == 1, float('-inf')) + return mask + + + def generate_prob_for_next_tokens(self, input, answer_emb, tgt_mask, input_mask, answer_mask): + x = self.transformer.encoder(input, src_key_padding_mask=input_mask) + dec_out = self.transformer.decoder(answer_emb, x, tgt_mask) + probs = self.prob_generator(dec_out) + + + return probs + + + def generate_complete_answers(self, input, input_mask): + # encode the complete batch of questions + memory = self.transformer.encoder(input, src_key_padding_mask=input_mask) + generated_answers = torch.ones(memory.shape[0], 40, dtype=torch.int) # 20 = max answer length, use unknown token () + + # generate the answers for each individual question from the batch + for i in range(memory.shape[0]): + memory_i = memory[i, :, :] + memory_i = memory_i.unsqueeze(0) + answer_i = torch.zeros((1,1), dtype=torch.int, device=self.device) # Pass start token to decoder as first input. 
From roberta vocab: ": 0, "": 2 + + for j in range(40): # 20 = max answer length + + answer_i_emb = self.prepare_lang_emb(answer_i, torch.ones((1, answer_i.shape[0]), device=self.device, dtype=torch.int16)) + tgt_mask = self.create_target_mask(answer_i.shape[1]) + decoder_output = self.transformer.decoder(answer_i_emb, memory_i, tgt_mask) + prob = self.prob_generator(decoder_output[:, -1, :]) + next_word = prob.argmax() + + answer_i = torch.concat([answer_i, next_word.unsqueeze(0).unsqueeze(0)], dim=1) + if next_word.item() == 2: # eos token in roberta vocab "": 2 + break + + generated_answers[i, :answer_i.shape[1] - 1] = answer_i[0, 1:] + + return generated_answers + + + def apply_model(self, language_emb, language_emb_mask, video_emb, v_state=None, d_state=None, answer_emb=None, answer_mask=None, state_generation_mode=False): + # combine state and embeddings + input = self.combiner( + video_emb, + language_emb, + language_emb_mask, + v_state, + d_state + ) + # create input mask based on the language_emb_mask (complete video is unmasked) + input_mask = torch.zeros((input.shape[0], input.shape[1]), device=self.device) + offset = 0 + if v_state is not None: offset += 1 + if d_state is not None: offset += 1 + # offset is caused by state vectors + input_mask[:, video_emb.shape[1] + offset:] = ~language_emb_mask + tgt_mask = self.create_target_mask(answer_emb.shape[1]) + + #-------TODO: Mask padded object embeddings when text based object embeddings are used ------------- + + if self.mode == 'train' or state_generation_mode: + probs = self.generate_prob_for_next_tokens(input, answer_emb, tgt_mask, input_mask, answer_mask) + return probs + elif self.mode == 'val': + generated_answers = self.generate_complete_answers(input, input_mask) + return generated_answers + + + def prepare_answer_emb_and_mask(self, answer, answer_mask): + mask = torch.tril(torch.ones((answer.shape[1], answer.shape[1]), device=self.device)) + mask = mask.unsqueeze(0) + mask = mask.expand(answer.shape[0], -1, -1) + answer_emb = self.apply_pretrained_lm(answer, mask) + + answer_emb = self.linear_projection_text(answer_emb) + answer_emb = self.append_ids(answer_emb, [1, 0], 2) + answer_emb = self.positional_encoder(answer_emb) + + # pytorch interprets True in a mask as padding + answer_mask = ~answer_mask + answer_emb_final = answer_emb[:, :-1].detach() + answer_mask_final = answer_mask[:, :-1].detach() + + return answer_emb_final, answer_mask_final + + + def answer_query(self, query, query_mask, vft, v_state=None, d_state=None, answer=None, answer_mask=None, state_generation_mode=False): + video_emb = self.prepare_video_emb(vft) + lang_emb = self.prepare_lang_emb(query, query_mask) + answer_emb, answer_mask = self.prepare_answer_emb_and_mask(answer, answer_mask) + output = self.apply_model(lang_emb, query_mask, video_emb, v_state, d_state, answer_emb, answer_mask, state_generation_mode) + return output + + + def training_step(self, train_batch, batch_idx): + train_batch = self.batch_interface(train_batch, feature_type=self.config['feature_type']) + if self.config['feature_type'] == "object_text_features": + train_batch.vft = self.encode_object_descriptions(train_batch.vft) + + logits = self.forward(train_batch) + logits = logits.permute(0, 2, 1) + + # replace any unknown token (id = 3) with a padding token in order to also ignore them -> avoid model which outputs unk tokens + train_batch.answer[train_batch.answer == 3] = 1 + loss = self.loss(logits, train_batch.answer[:, 1:]) # ignore padding + self.log("train_loss", loss, 
prog_bar=True, on_step=False, on_epoch=True, batch_size=train_batch.query.shape[0]) + return loss + + + def get_next_token_pred_as_text_and_logits(self, batch): + # set mode to train to get the logits instead of completely generated sentences + self.mode = 'train' + logits = self.forward(batch) + logits = logits.permute(0, 2, 1) + predicted_tokens = [] + for j in range(logits.shape[0]): + l = logits[j, :, :] + ids = [l[:, i].argmax().item() for i in range(l.shape[1])] + text = translate_from_ids_to_text(ids, self.tokenizer) + predicted_tokens.append(text) + # set mode back to val + self.mode = 'val' + + return predicted_tokens, logits + + + def calculate_bleu_score(self, generated_answer_ids, correct_answer): + # calculate bleu score for the generated answers compared to the provided correct answers + bleu4_scores = [] + all_generated_answers = [] + for i in range(generated_answer_ids.shape[0]): + generated_answer = generated_answer_ids[i, :].tolist() + generated_answer_text = translate_from_ids_to_text(generated_answer, self.tokenizer) + all_generated_answers.append(generated_answer_text) + correct_answer_text_i = correct_answer[i] + score4 = nltk.translate.bleu_score.sentence_bleu( + [normalize_sentence(correct_answer_text_i)], + normalize_sentence(generated_answer_text), + smoothing_function=nltk.translate.bleu_score.SmoothingFunction().method7 + ) + bleu4_scores.append(score4) + bleu4_score = np.mean(bleu4_scores) + return bleu4_score, all_generated_answers + + + def translate_answer_ids_to_text(self, answer): + correct_answer_text = [] + for i in range(answer.shape[0]): + correct_answer_i = answer[i, :].tolist() + correct_answer_text_i = translate_from_ids_to_text(correct_answer_i, self.tokenizer) + correct_answer_text.append(correct_answer_text_i) + return correct_answer_text + + + def validation_step(self, val_batch, batch_idx): + val_batch = self.batch_interface(val_batch, feature_type=self.config['feature_type']) + if self.config['feature_type'] == "object_text_features": + val_batch.vft = self.encode_object_descriptions(val_batch.vft) + + correct_answer_text = self.translate_answer_ids_to_text(val_batch.answer) + generated_answer_ids = self.forward(val_batch) + + # calculate and log bleu score for the generated answers compared to the provided correct answers + bleu4_score, generated_answers_text = self.calculate_bleu_score(generated_answer_ids, correct_answer_text) + self.log('bleu4', bleu4_score, prog_bar=True, on_step=False, on_epoch=True, batch_size=generated_answer_ids.shape[0]) + + # calculate and log the validation loss based on the results from next token predicition (train mode needed) + predicted_tokens, logits = self.get_next_token_pred_as_text_and_logits(val_batch) + loss = self.loss(logits, val_batch.answer[:, 1:]) # ignore padding + self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True, batch_size=val_batch.query.shape[0]) + + return {'next_token_predictions': predicted_tokens, 'generated_answers': generated_answers_text, 'correct_answers': correct_answer_text} + + + def test_step(self, test_batch, batch_idx): + dialog_id = test_batch['dialog_id'] + turn_id = test_batch['turn_id'] + test_batch = self.batch_interface(test_batch, feature_type=self.config['feature_type']) + if self.config['feature_type'] == "object_text_features": + test_batch.vft = self.encode_object_descriptions(test_batch.vft) + + correct_answer_text = self.translate_answer_ids_to_text(test_batch.answer) + generated_answer_ids = self.forward(test_batch) + + # calculate and log bleu 
score for the generated answers compared to the provided correct answers + bleu4_score, generated_answers_text = self.calculate_bleu_score(generated_answer_ids, correct_answer_text) + self.log('bleu4', bleu4_score, prog_bar=True, on_step=False, on_epoch=True, batch_size=generated_answer_ids.shape[0]) + + # calculate and log the validation loss based on the results from next token predicition (train mode needed) + predicted_tokens, logits = self.get_next_token_pred_as_text_and_logits(test_batch) + loss = self.loss(logits, test_batch.answer[:, 1:]) # ignore padding + self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True, batch_size=test_batch.query.shape[0]) + + return {'turn_id': turn_id, 'next_token_predictions': predicted_tokens, 'dialog_id': dialog_id, 'generated_answers': generated_answers_text, 'correct_answers': correct_answer_text} + + + def test_epoch_end(self, outputs): + + if self.config['output_format'] == 'submission': + responses = [] + for output in outputs: + for t_id, d_id, answer in zip(output['turn_id'], output['dialog_id'], output['generated_answers']): + sample = { + 'dialog_id': d_id, + 'predictions': [ + { + 'turn_id': t_id, + 'response': answer + } + ] + } + responses.append(sample) + name = 'dstc11-simmc-devtest-pred-subtask-4-generation.json' + with open(os.path.join(self.output_path, name), 'w') as file: + json.dump(responses, file) + + else: + result_idx = 0 + for output in outputs: + for j in range(len(output['next_token_predictions'])): + pred = " " + corr = " " + gen = " " + self.results[result_idx] = { + 'next_token_pred': pred.join(output['next_token_predictions'][j]), + 'generated_ans': gen.join(output['generated_answers'][j]), + 'correct': corr.join(output['correct_answers'][j]) + } + result_idx += 1 + + name = f'epoch_{self.epoch_count}.json' + with open(os.path.join(self.output_path, name), 'w') as file: + json.dump(self.results, file) + + + def validation_epoch_end(self, outputs): + result_idx = 0 + for output in outputs: + for j in range(len(output['next_token_predictions'])): + pred = " " + corr = " " + gen = " " + self.results[result_idx] = { + 'next_token_pred': pred.join(output['next_token_predictions'][j]), + 'generated_ans': gen.join(output['generated_answers'][j]), + 'correct': corr.join(output['correct_answers'][j]) + } + result_idx += 1 + + name = f'epoch_{self.epoch_count}.json' + with open(os.path.join(self.output_path, name), 'w') as file: + json.dump(self.results, file) + + self.results = {} + self.epoch_count += 1 + + + def on_train_epoch_start(self): + self.mode = 'train' + + + def on_validation_epoch_start(self): + self.mode = 'val' + + + def on_test_epoch_start(self): + self.mode = 'val' + + + + diff --git a/src/models/state_tracker_model.py b/src/models/state_tracker_model.py new file mode 100644 index 0000000..707b9dc --- /dev/null +++ b/src/models/state_tracker_model.py @@ -0,0 +1,167 @@ +import pytorch_lightning as pl +import torch +from torch import nn +from src.models.base_model import TransformerModel +from src.utils.save_attention_weights import SaveOutput +from src.utils.custom_transformer_encoder_layer import CustomTransformerEncoderLayer +from src.state_trackers.video_state_tracker import VstLSTM +from src.state_trackers.dialogue_state_tracker import DstLSTM +from src.state_trackers.vst_transformer_based import VstTransformer +from src.state_trackers.dst_transformer_based import DstTransformer +from src.combiner.option_a import CombinerOptionA +from src.combiner.option_b import CombinerOptionB +from 
src.combiner.option_c import CombinerOptionC + + +class StateTrackerModel(TransformerModel): + def __init__(self, config, output_path=None): + super().__init__(config, output_path=output_path) + self.config = config['model'] + self.ext_config = config['extended_model'] + + combine_state_and_emb_options = { + 'OptionA': CombinerOptionA, + 'OptionB': CombinerOptionB, + 'OptionC': CombinerOptionC, + } + state_tracker_options = { + 'Transformer': { + 'vst': VstTransformer, + 'dst': DstTransformer + }, + 'LSTM': { + 'vst': VstLSTM, + 'dst': DstLSTM + } + } + + # if option b is used the state vector is appended to each embedding -> input size for the transformers needs to double + if self.ext_config['combiner_option'] == 'OptionB': + self.model_input_dim *= 2 + # replace fc layer with a fitting one for the larger embeddings + self.fc = nn.Linear(self.model_input_dim, self.config["fc_dim"]) + + self.combiner = combine_state_and_emb_options[self.ext_config['combiner_option']]( + config = self.ext_config, + model_input_dim = self.model_input_dim, + use_v_state=self.ext_config['use_v_state'], + use_d_state=self.ext_config['use_d_state'] + ) + encoder_layer = CustomTransformerEncoderLayer( + d_model=self.model_input_dim, + batch_first=True, + dropout=self.config['dropout_p'], + dim_feedforward=self.config['dim_feedforward'], + nhead=self.config['n_heads'] + ) + self.encoder = nn.TransformerEncoder( + encoder_layer=encoder_layer, + num_layers=self.config['n_encoder_layers'], + ) + self.save_output = SaveOutput() + self.hook_handle = self.encoder.layers[-1].self_attn.register_forward_hook(self.save_output) + if self.ext_config['use_v_state']: + self.video_state_tracker = state_tracker_options[self.ext_config['state_tracker_type']]['vst']( + self.model_input_dim, + self.config['dropout_p'], + self.ext_config + ) + if self.ext_config['use_d_state']: + self.dial_state_tracker = state_tracker_options[self.ext_config['state_tracker_type']]['dst']( + self.model_input_dim, + self.config['dropout_p'], + self.ext_config + ) + self.video_emb_start_idx = self.calculate_video_emb_start_idx() + + + def calculate_video_emb_start_idx(self): + video_emb_start_idx = 0 + if self.config['model_type'] == 'discriminative': video_emb_start_idx += 1 + if self.ext_config['use_v_state']: video_emb_start_idx += 1 + if self.ext_config['use_d_state']: video_emb_start_idx += 1 + return video_emb_start_idx + + + def determine_relevant_obj_emb(self, attention_weights, vft): + # determine index of maximum values + obj_emb = self.prepare_video_emb(vft) + _, relevant_emb_indices = attention_weights[:, self.video_emb_start_idx:obj_emb.shape[1] + self.video_emb_start_idx].topk(k=self.ext_config['number_of_relevant_emb'], dim=1) + relevant_emb = torch.zeros((obj_emb.shape[0], self.ext_config['number_of_relevant_emb'], obj_emb.shape[2]), device=self.device) + for i in range(attention_weights.shape[0]): + relevant_emb[i, :, :] = obj_emb[i, relevant_emb_indices[i, :]] + + return relevant_emb + + + def get_attention_weights(self, n_vid_emb): + if self.config['model_type'] in ['generative', 'ranking']: + # get the attention weights from the query tokens and sum all of them + query_start_idx = self.video_emb_start_idx + n_vid_emb + attention_weights = self.save_output.outputs[1][:, query_start_idx:, :] + attention_weights = attention_weights.sum(dim=1) + elif self.config['model_type'] == 'discriminative': + # get only the attention weights of the cls token + attention_weights = self.save_output.outputs[1][:, 0, :] + return attention_weights + + 
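+    # forward(): replay the most recent `hist_len_for_state_gen` dialogue turns to build the video and
+    # dialogue state vectors with the state trackers, then answer the current query conditioned on them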
+ def forward(self, batch): + # initialize the state vectors - initialize as none if we dont want to use them + if self.ext_config['use_v_state']: + video_state = torch.zeros((batch.query.shape[0], 1, self.model_input_dim), device=self.device) + else: + video_state = None + if self.ext_config['use_d_state']: + dial_state = torch.zeros((batch.query.shape[0], 1, self.model_input_dim), device=self.device) + else: + dial_state = None + + # create the state vectors based on the previous n most recent dialogue turns + hist_start_turn_state_gen = batch.turns.shape[1] - 1 - self.ext_config["hist_len_for_state_gen"] + for dialogue_round in range(max(0, hist_start_turn_state_gen), batch.turns.shape[1]): + question = batch.q_turns[:, dialogue_round, :] + + question_mask = batch.q_turns_mask[:, dialogue_round, :] + qa_pair = batch.turns[:, dialogue_round, :] + qa_pair_mask = batch.turns_mask[:, dialogue_round, :] + + # pass correct answer tokens to the decoder for training a generative model + if self.config['model_type'] in ['generative', 'ranking']: + answer = batch.a_turns[:, dialogue_round, :] + answer_mask = batch.a_turns_mask[:, dialogue_round, :] + # the answer is not used, only the attention weights are relevant for state creation + _ = self.answer_query(question, question_mask, batch.vft, video_state, dial_state, answer, answer_mask, state_generation_mode=True) + else: + _ = self.answer_query(question, question_mask, batch.vft, video_state, dial_state) + + + # update the states + if self.ext_config['use_v_state']: + # get the attention weights from the last "answer_query" call and determine the relevant obj + attention_weights = self.get_attention_weights(n_vid_emb=batch.vft.shape[1]) + relevant_obj_emb = self.determine_relevant_obj_emb(attention_weights, batch.vft) + # add ids to match the input size of the main transformer block + video_state = self.video_state_tracker(relevant_obj_emb) + if self.ext_config['use_d_state']: + qa_pair_emb = self.prepare_lang_emb(qa_pair, qa_pair_mask) + # add ids to match the input size of the main transformer block + dial_state = self.dial_state_tracker(qa_pair_emb) + + # delete state of the state tracker + if self.ext_config['use_v_state']: + self.video_state_tracker.reset() + if self.ext_config['use_d_state']: + self.dial_state_tracker.reset() + + # answer the actual question + # pass correct answer tokens to the decoder for training a generative model + if self.config['model_type'] in ['generative', 'ranking']: + output = self.answer_query(batch.query, batch.query_mask, batch.vft, video_state, dial_state, batch.answer, batch.answer_mask) + else: + output = self.answer_query(batch.query, batch.query_mask, batch.vft, video_state, dial_state) + + return output + + + diff --git a/src/object_description_encoder/object_description_encoder.py b/src/object_description_encoder/object_description_encoder.py new file mode 100644 index 0000000..c05660c --- /dev/null +++ b/src/object_description_encoder/object_description_encoder.py @@ -0,0 +1,29 @@ +import pytorch_lightning as pl +from torch import nn +import torch + + +class ObjectDescriptionEncoder(pl.LightningModule): + def __init__(self, d_model, config): + super().__init__() + self.d_model = d_model + encoder_layer = nn.TransformerEncoderLayer( + d_model=d_model, + batch_first=True, + dropout=config['dropout_p'], + dim_feedforward=config['object_feature_generator_dim'], + nhead=config['n_heads'] + ) + self.encoder = nn.TransformerEncoder( + encoder_layer=encoder_layer, + 
num_layers=config['n_object_feature_generator_layers'] + ) + + def forward(self, input): + object_description_embedding = torch.zeros((input.shape[0], 1, self.d_model), device=self.device) + input = torch.concat([object_description_embedding, input], dim=1) + output = self.encoder(input) + object_description_embedding = output[:, 0, :] + object_description_embedding = object_description_embedding.unsqueeze(1) + return object_description_embedding + diff --git a/src/state_trackers/dialogue_state_tracker.py b/src/state_trackers/dialogue_state_tracker.py new file mode 100644 index 0000000..6d363f7 --- /dev/null +++ b/src/state_trackers/dialogue_state_tracker.py @@ -0,0 +1,32 @@ +import pytorch_lightning as pl +from torch import nn +import torch + + +class DstLSTM(pl.LightningModule): + def __init__(self, emb_dim, dropout, config): + super().__init__() + self.lstm_layer = nn.LSTM( + input_size=emb_dim, + hidden_size=emb_dim, + num_layers=config['num_layers_d_state'], + batch_first=True, + dropout=dropout + ) + self.h = None + self.c = None + + def forward(self, input): + if self.h is None: + _, (self.h, self.c) = self.lstm_layer(input) + else: + _, (self.h, self.c) = self.lstm_layer(input, (self.h, self.c)) + + output = torch.permute(self.h, (1, 0, 2)) + output = output[:, -1, :] + output = output.unsqueeze(1) + return output + + def reset(self): + self.h = None + self.c = None diff --git a/src/state_trackers/dst_transformer_based.py b/src/state_trackers/dst_transformer_based.py new file mode 100644 index 0000000..a81e27b --- /dev/null +++ b/src/state_trackers/dst_transformer_based.py @@ -0,0 +1,36 @@ +import pytorch_lightning as pl +from torch import nn +import torch + + +class DstTransformer(pl.LightningModule): + def __init__(self, emb_dim, dropout, config): + super().__init__() + self.emb_dim = emb_dim + encoder_layer = nn.TransformerEncoderLayer( + d_model=emb_dim, + batch_first=True, + dropout=dropout, + dim_feedforward=config['dim_feedforward_d_transformer'], + nhead=config['n_heads_state_tracker'] + ) + self.encoder = nn.TransformerEncoder( + encoder_layer=encoder_layer, + num_layers=config['num_layers_d_state'] + ) + self.state_vector = None + + + def forward(self, input): + if self.state_vector is None: + self.state_vector = torch.zeros((input.shape[0], 1, self.emb_dim), device=self.device) + + input = torch.concat([self.state_vector, input], dim=1) + output = self.encoder(input) + self.state_vector = output[:, 0, :] + self.state_vector = self.state_vector.unsqueeze(1) + return self.state_vector + + + def reset(self): + self.state_vector = None diff --git a/src/state_trackers/video_state_tracker.py b/src/state_trackers/video_state_tracker.py new file mode 100644 index 0000000..7778fb7 --- /dev/null +++ b/src/state_trackers/video_state_tracker.py @@ -0,0 +1,36 @@ +import pytorch_lightning as pl +from torch import nn +import torch + + +class VstLSTM(pl.LightningModule): + def __init__(self, emb_dim, dropout, config): + super().__init__() + self.lstm_layer = nn.LSTM( + input_size=emb_dim, + hidden_size=emb_dim, + num_layers=config['num_layers_v_state'], + batch_first=True, + dropout=dropout + ) + self.h = None + self.c = None + + def forward(self, input): + if self.h is None: + _, (self.h, self.c) = self.lstm_layer(input) + else: + _, (self.h, self.c) = self.lstm_layer(input, (self.h, self.c)) + + output = torch.permute(self.h, (1,0,2)) + output = output[:, -1, :] + output = output.unsqueeze(1) + return output + + def reset(self): + self.h = None + self.c = None + + + + \ No newline at 
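# Illustrative usage sketch (not part of the release; config values are assumed): the
# VstLSTM / DstLSTM trackers keep their (h, c) state across forward calls, so feeding one
# turn after another refines the same state vector, and reset() clears it between dialogues.
if __name__ == "__main__":
    cfg = {"num_layers_v_state": 2}
    tracker = VstLSTM(emb_dim=36, dropout=0.1, config=cfg)
    turn_1 = torch.randn(4, 2, 36)   # [batch, n_relevant_emb, emb_dim]
    turn_2 = torch.randn(4, 2, 36)
    state = tracker(turn_1)          # [4, 1, 36], h/c initialised on the first call
    state = tracker(turn_2)          # [4, 1, 36], h/c carried over from the first turn
    tracker.reset()                  # forget this dialogue before processing the next one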
end of file diff --git a/src/state_trackers/vst_transformer_based.py b/src/state_trackers/vst_transformer_based.py new file mode 100644 index 0000000..af67a68 --- /dev/null +++ b/src/state_trackers/vst_transformer_based.py @@ -0,0 +1,39 @@ +import pytorch_lightning as pl +from torch import nn +import torch + + +class VstTransformer(pl.LightningModule): + def __init__(self, emb_dim, dropout, config): + super().__init__() + self.emb_dim = emb_dim + encoder_layer = nn.TransformerEncoderLayer( + d_model=emb_dim, + batch_first=True, + dropout=dropout, + dim_feedforward=1 + config['number_of_relevant_emb'], + nhead=config['n_heads_state_tracker'] + ) + self.encoder = nn.TransformerEncoder( + encoder_layer=encoder_layer, + num_layers=config['num_layers_v_state'] + ) + self.state_vector = None + + + def forward(self, input): + if self.state_vector is None: + self.state_vector = torch.zeros((input.shape[0], 1, self.emb_dim), device=self.device) + + input = torch.concat([self.state_vector, input], dim=1) + output = self.encoder(input) + self.state_vector = output[:, 0, :] + self.state_vector = self.state_vector.unsqueeze(1) + return self.state_vector + + + def reset(self): + self.state_vector = None + + + \ No newline at end of file diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utils/batch_interfaces.py b/src/utils/batch_interfaces.py new file mode 100644 index 0000000..f543b70 --- /dev/null +++ b/src/utils/batch_interfaces.py @@ -0,0 +1,106 @@ +import torch +from dataclasses import dataclass +from typing import Optional + +@dataclass +class Batch: + query: torch.Tensor + query_mask: torch.Tensor + vft: torch.Tensor + turns: torch.Tensor + turns_mask: torch.Tensor + q_turns: torch.Tensor + q_turns_mask: torch.Tensor + a_turns: torch.Tensor + a_turns_mask: torch.Tensor + answer: torch.Tensor + answer_mask: torch.Tensor + answer_candidates: Optional[torch.Tensor] = None + answer_candidates_mask: Optional[torch.Tensor] = None + + +# ---- TODO: Replace with function for the Mask RCNN features ---- +def create_monet_like_vft(vft): + target_dim = 36 + remainder = vft.shape[1] % target_dim + vft = vft[:, :-remainder].reshape((vft.shape[0], -1, target_dim)) + vft = vft.unsqueeze(3) + return vft + + +def batch_interface_simmc2_to_dvd(batch, feature_type): + if feature_type == 'resnet50': + vft = batch['features'] + vft = vft.unsqueeze(3) + elif feature_type == "object_text_features": + vft = batch['object_features'] + # add frame dimension (only one frame in this cas) + #vft = vft.unsqueeze(1) + # bring the data to the format [batch_size x frames x emb_dim (desc_text_len) x obj_number] + #vft = vft.permute(0, 1, 3, 2) + + batch_in_dvd_format = Batch( + query=batch['query'], + query_mask=(batch['query'] != 1), + vft=vft, + turns=batch['turns'], + turns_mask=(batch['turns'] != 1), + q_turns=batch['q_turns'], + q_turns_mask=(batch['q_turns'] != 1), + a_turns=batch['a_turns'], + a_turns_mask=(batch['a_turns'] != 1), + answer=batch['answer'].type(torch.int64), + answer_mask=(batch['answer'] != 1), + answer_candidates=batch['answer_candidates'], + answer_candidates_mask=(batch['answer_candidates'] != 1) + ) + return batch_in_dvd_format + + + +def batch_interface_avsd_to_dvd(batch, feature_type): + # map question to query + query = batch['ques'][:,-1, :] + query_mask = (query != 1) + + # map vid_feat to vft + # TODO: Use other video features ------!!!------- + if feature_type == 'i3d': + vft = create_monet_like_vft(batch['vid_feat']) + else: 
+ vft = batch['vid_feat'] + + + q_turns = batch['ques'][:, :9, :] + q_turns_mask = (q_turns != 1) + + index_tensor = batch['ans_ind'].unsqueeze(2) + index_tensor = index_tensor.repeat(1,1,20) + index_tensor = index_tensor.unsqueeze(2) + a_turns = batch['opt'].gather(2, index_tensor) + a_turns = a_turns.squeeze(2) + + # turns should only contain the previous questions (first 9 turns) + a_turns, answer = a_turns.split([9, 1], dim=1) + answer = answer.squeeze(1) + a_turns_mask = (a_turns != 1) + answer_mask = (answer != 1) + + # concat questions and a_turns to create turns tensor + turns = torch.concat((q_turns, a_turns), 2) + turns_mask = (turns != 1) + + batch_in_dvd_format = Batch( + query, + query_mask, + vft, + turns, + turns_mask, + q_turns, + q_turns_mask, + a_turns, + a_turns_mask, + answer, + answer_mask + ) + return batch_in_dvd_format \ No newline at end of file diff --git a/src/utils/custom_transformer_encoder_layer.py b/src/utils/custom_transformer_encoder_layer.py new file mode 100644 index 0000000..9e88479 --- /dev/null +++ b/src/utils/custom_transformer_encoder_layer.py @@ -0,0 +1,84 @@ +# https://pytorch.org/docs/stable/_modules/torch/nn/modules/transformer.html#TransformerEncoderLayer +from typing import Optional, Any, Union, Callable + +from torch import nn + +import torch +from torch import Tensor +from torch.nn import functional as F +from torch.nn.modules import Module +from torch.nn import MultiheadAttention +#from nn.container import ModuleList +#from ..init import xavier_uniform_ +from torch.nn import Dropout +from torch.nn import Linear +from torch.nn import LayerNorm + + +class CustomTransformerEncoderLayer(Module): + + __constants__ = ['batch_first', 'norm_first'] + + def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False, + device=None, dtype=None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, + **factory_kwargs) + # Implementation of Feedforward model + self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs) + self.dropout = Dropout(dropout) + self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs) + + self.norm_first = norm_first + self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + self.dropout1 = Dropout(dropout) + self.dropout2 = Dropout(dropout) + + def __setstate__(self, state): + if 'activation' not in state: + state['activation'] = F.relu + super().__setstate__(state) + + def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor: + r"""Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + see the docs in Transformer class. + """ + + # see Fig. 
1 of https://arxiv.org/pdf/2002.04745v1.pdf + + x = src + if self.norm_first: + x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask) + x = x + self._ff_block(self.norm2(x)) + else: + x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask)) + x = self.norm2(x + self._ff_block(x)) + + return x + + + # self-attention block + def _sa_block(self, x: Tensor, + attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor: + x = self.self_attn(x, x, x, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + need_weights=True)[0] + return self.dropout1(x) + + # feed forward block + def _ff_block(self, x: Tensor) -> Tensor: + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + return self.dropout2(x) + diff --git a/src/utils/dvd_codebase/__init__.py b/src/utils/dvd_codebase/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utils/dvd_codebase/configs/configs.py b/src/utils/dvd_codebase/configs/configs.py new file mode 100644 index 0000000..8ce0d58 --- /dev/null +++ b/src/utils/dvd_codebase/configs/configs.py @@ -0,0 +1,39 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. +All rights reserved. +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import argparse +import logging +import random +import numpy as np + +parser = argparse.ArgumentParser() +parser.add_argument('--debug', default=0, type=int, help='') + +# Data /projects/hochmeister/CATER-videos/features/per_video' +#parser.add_argument('--fea-dir', default='/scratch/hochmeister/CATER-videos/features/monet_pretrained_on_clevr/per_video', type=str, help='Image feature files (.pkl)') +#parser.add_argument('--data-dir', default='/scratch/hochmeister/DVDData/small_subset/', type=str,help='Path to training feature files') +parser.add_argument('--output-dir', default='/scratch/abdessaied/projects/olvit/msc2022_hochmeister/checkpoints/avsd_code_testing', type=str,help='output path of model and params') +parser.add_argument('--num-workers', default=20, type=int, help='') +parser.add_argument('--device', default='0', type=str, help='') + +# Training +parser.add_argument('--num-epochs', '-e', default=15, type=int,help='Number of epochs') +#parser.add_argument('--batch-size', '-b', default=85, type=int,help='Batch size in training') +# others +parser.add_argument('--verbose', '-v', default=0, type=int,help='verbose level') + +args, unknown = parser.parse_known_args() + +print(args) + +# Presetting +if args.verbose >= 1: + logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s') +else: + logging.basicConfig(level=logging.INFO, + format='%(asctime)s %(levelname)s: %(message)s') diff --git a/src/utils/dvd_codebase/data/__init__.py b/src/utils/dvd_codebase/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utils/dvd_codebase/data/analysis_utils.py b/src/utils/dvd_codebase/data/analysis_utils.py new file mode 100644 index 0000000..a25ee02 --- /dev/null +++ b/src/utils/dvd_codebase/data/analysis_utils.py @@ -0,0 +1,282 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. +All rights reserved. +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. 
+""" + +import glob, json, pdb +from tqdm import tqdm +import pandas as pd +import copy, os + +def get_question_type(template, prior_template): + last_node_type = template['nodes'][-1]['type'] + text = template['text'][0].lower() + if 'same set of activities' in text: + qtype = 'compare action set' + elif 'same sequence of activities' in text: + qtype = 'compare action sequence' + elif 'frequently' in text: + qtype = 'compare int' + elif 'how many times' in text: + qtype = 'action count' + elif 'how many' in text or 'what number' in text: + qtype = 'obj count' + elif 'is there' in text: + qtype = 'obj exist' + elif 'what color' in text or 'what material' in text or 'what shape' in text or 'what size' in text: + qtype = 'attr query' + elif 'what type of action' in text or 'what is the' in text or 'what types of action' in text: + qtype = 'action query' + else: + assert 'what about' in text + qtype = get_question_type(prior_template, None) + return qtype + +def get_question_subtype(template, prior_template): + last_node_type = template['nodes'][-1]['type'] + text = template['text'][0].lower() + if 'same set of activities' in text: + if 'how many' in text: + qtype = 'compare action set (count)' + else: + qtype = 'compare action set (exist)' + elif 'same sequence of activities' in text: + if 'how many' in text: + qtype = 'compare action seq (count)' + else: + qtype = 'compare action seq (exist)' + elif 'frequently' in text: + if 'as frequently' in text: + qtype = 'compare int (equal)' + elif 'less frequently' in text: + qtype = 'compare int (less)' + elif 'more frequently' in text: + qtype = 'compare int (more)' + elif 'how many times' in text: + qtype = 'action count' + elif 'how many' in text or 'what number' in text: + qtype = 'obj count' + elif 'is there' in text: + qtype = 'obj exist' + elif 'what color' in text or 'what about its color' in text: + qtype = 'attr query (color)' + elif 'what material' in text or 'what about its material'in text: + qtype = 'attr query (material)' + elif 'what shape' in text or 'what about its shape' in text: + qtype = 'attr query (shape)' + elif 'what size' in text or 'what about its size' in text: + qtype = 'attr query (size)' + elif 'what type of action' in text or 'what is the' in text or 'what types of action' in text: + if '' in text: + qtype = 'action query (by order)' + elif '' in text: + qtype = 'ation query (by freq)' + else: + qtype = 'action query (all actions)' + else: + assert 'what about' in text + assert 'color' not in text and 'size' not in text and \ + 'shape' not in text and 'material' not in text + qtype = get_question_subtype(prior_template, None) + return qtype + +def get_question_complexity(turn, template_fn): + template = turn['template'] + interval_type = template['interval_type'] + last_node_type = template['nodes'][-1]['type'] + second_last_node_type = template['nodes'][-2]['type'] + + if interval_type == 'none': + return 'none' + elif interval_type == 'atomic': + if 'one_hop' in template_fn: + return 'atomic (spatial)' + else: + return 'atomic (non-spatial)' + #return 'atomic' + elif interval_type == 'compositional': + return 'compositional' + +def get_accuracies_by_type(all_types, models, all_answers, all_results, output_file): + types = sorted(set(all_types)) + accuracies = {} + for t in types: + accuracies[t] = [] + for model in models: + nb_corrects = 0 + count = 0 + results = all_results[model] + for a_idx, a in enumerate(all_answers): + curr_type = all_types[a_idx] + if curr_type != t: continue + pred = results[a_idx] + if 
str(pred) == str(a): + nb_corrects += 1 + count += 1 + acc = nb_corrects/count + accuracies[t].append(acc) + df = copy.deepcopy(accuracies) + df['model'] = models + df = pd.DataFrame(data=df, columns=['model'] + list(accuracies.keys())) + df.to_csv(output_file) + return types, accuracies, df + +def get_transfer_accuracies(all_types, models, all_answers, all_results, output_file, is_video_update=False, is_all=False): + accuracies = [] + for model in models: + results = all_results[model] + nb_corrects = 0 + count = 0 + for a_idx, a in enumerate(all_answers): + if is_all: + is_single_turn = True + for k,v in all_types.items(): + if v[a_idx] != 'none': + is_single_turn = False + break + if is_single_turn: continue + else: + curr_type = all_types[a_idx] + if is_video_update: + if curr_type != 'video_update': continue + else: + if curr_type != 'yes': continue + prior_pred_a = results[a_idx-1] + prior_gt_a = all_answers[a_idx-1] + if str(prior_pred_a) != str(prior_gt_a): continue + pred_a = results[a_idx] + gt_a = all_answers[a_idx] + if str(pred_a) == str(gt_a): + nb_corrects += 1 + count += 1 + if count == 0: + acc = 0 + else: + #pdb.set_trace() + acc = nb_corrects/count + accuracies.append(acc) + df = {} + df['accuracies'] = accuracies + df['model'] = models + df = pd.DataFrame(data=df, columns=['model', 'accuracies']) + df.to_csv(output_file) + return df + +def get_start_end_time(period): + start, end = period + if start is None: + start = 0 + else: + start = start[-1] + if end is None: + end = 301 + else: + end = end[-1] + return start, end + +def get_period_size(period): + if period is None: + return 0 + start, end = get_start_end_time(period) + return end - start + +def get_overlap_period(curr_period, last_period, ratio=False): + if curr_period is None: + return -1 + if last_period is None: + return 0 + s1, e1 = get_start_end_time(curr_period) + s2, e2 = get_start_end_time(last_period) + if s2n: + bin = n + else: + break + return bin + +def get_obj_turn_dist(used_objects, dependencies, template, turn_idx): + all_dists = [0] + + if dependencies['object'] != 'none': + if dependencies['object'] == 'earlier_unique': + obj_id = str(template['earlier_unique_obj']) + if obj_id not in used_objects: + pdb.set_trace() + turn_dist = turn_idx - used_objects[obj_id]['original_turn'] + 1 + all_dists.append(turn_dist) + + if dependencies['temporal'] != 'none': + if 'earlier_unique' in dependencies['temporal']: + obj_id = str(template['temporal_obj_id']) + if obj_id not in used_objects: + pdb.set_trace() + turn_dist = turn_idx - used_objects[obj_id]['original_turn'] + 1 + all_dists.append(turn_dist) + + return max(all_dists) + +def get_stats(dials): + videos = set() + questions = set() + for dial in dials: + for turn in dial: + question = turn['question'] + video = '{}-{}'.format(turn['split'], turn['image_filename']) + videos.add(video) + questions.add(question) + print('# videos: {}'.format(len(videos))) + print("# dialogues: {}".format(len(dials))) + print("# unique questions: {}".format(len(questions))) + output = { + '#videos': len(videos), + '#dialogues': len(dials), + '#unique questions': len(questions) + } + return output + +def find_video_end_range(end_time): + ranges = [0, 30, 60, 90, 120, 150, 180, 210, 240, 270] + if end_time is None: + return 9 + for idx, r in enumerate(ranges): + if end_time[-1] > r: + curr_r = idx + else: + return curr_r + return 9 + +def find_video_start_range(start_time): + ranges = [400, 270, 240, 210, 180, 150, 120, 90, 60, 30] + if start_time is None: + return 0 + for 
idx, r in enumerate(ranges): + if start_time[-1] <= r: + curr_r = 9-idx + else: + return curr_r + return 0 diff --git a/src/utils/dvd_codebase/data/data_handler.py b/src/utils/dvd_codebase/data/data_handler.py new file mode 100644 index 0000000..c9574e6 --- /dev/null +++ b/src/utils/dvd_codebase/data/data_handler.py @@ -0,0 +1,264 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. +All rights reserved. +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import copy, logging, sys, time, os, pdb, random, glob, json +import pickle as pkl +import numpy as np +from tqdm import tqdm +from collections import Counter +from functools import partial +import nltk +import torch +import torch.utils.data as Data +from src.utils.dvd_codebase.data.dataset import * +from src.utils.dvd_codebase.data.analysis_utils import * +from src.utils.dvd_codebase.data.data_utils import * +from src.utils.dvd_codebase.data.analysis_utils import get_question_subtype, get_question_complexity +from transformers import AutoTokenizer + + +def load_dials(args, split): + files = [] + for video_split in ['all_actions', 'max2action']: + files += glob.glob(args.data_dir + '{}_{}_*/*.json'.format(video_split, split)) + files = sorted(files) # [:50] + if args.debug: + files = files[:100] + all_dials = [] + vid_set = {} + for file in tqdm(files, total=len(files)): + dials = json.load(open(file)) + all_dials.extend(dials) + video_split = dials[0][0]['split'] + vid = dials[0][0]['image'].replace('CLEVR', 'CATER') + vid_key = '{}-{}'.format(video_split, vid) + if vid_key not in vid_set: + vid_set[vid_key] = '{}/{}/{}.pkl'.format(args.fea_dir, video_split, vid) + return all_dials, vid_set + +def load_videos(args, vid_set): + vid_fts = {} + ft_dims = None + size, stride = -1, -1 + segment_map = {} + for vid_key, fea_file in tqdm(vid_set.items(), total=len(vid_set)): + #fea_file = '{}/{}.pkl'.format(args.fea_dir, vid) + fea = pkl.load(open(fea_file, 'rb')) + output = [] + for clip_idx, clip in enumerate(fea['clips']): + fea = clip['features'] + if len(fea.shape)==3: + fea = fea.transpose(1, 2, 0) + output.append(fea) + start, end = clip['segment'] + if clip_idx not in segment_map: + segment_map[clip_idx] = (start, end) + if size == -1: + size = end - start + 1 + if clip_idx>0 and stride == -1: + stride = start - prior_start + prior_start, prior_end = start, end + vft = np.asarray(output) + vid_fts[vid_key] = vft + if ft_dims is None: + ft_dims = vft.shape + return vid_fts, ft_dims, size, stride, segment_map + +def load_video_features(args, vid_set): + vid_fts = {} + for vid_key, fea_file in tqdm(vid_set.items(), total=len(vid_set)): + #fea_file = '{}/{}.pkl'.format(args.fea_dir, vid) + fea = pkl.load(open(fea_file, 'rb')) + vid_fts[vid_key] = fea + return vid_fts + +def get_vocabulary(dials, args, vocab=None): + #answer_options = set() + word_freq = {} + for dialog in tqdm(dials, total=len(dials)): + for turn in dialog: + for word in nltk.word_tokenize(turn['question']): + if word not in word_freq: word_freq[word] = 0 + word_freq[word] += 1 + answer = str(turn['answer']) + #answer_options.add(answer) + for word in nltk.word_tokenize(answer): + if word not in word_freq: word_freq[word] = 0 + word_freq[word] += 1 + program = turn['final_all_program'] + for n in program: + if n['type'] == 'identity': continue + if n['type'] not in word_freq: word_freq[n['type']] = 0 + word_freq[n['type']] += 1 + if 'side_inputs' in n: + for side_input in n['side_inputs']: + for 
word in nltk.word_tokenize(side_input): + if word not in word_freq: word_freq[word] = 0 + word_freq[word] += 1 + if vocab is not None: + unk_words = set() + for word, freq in word_freq.items(): + if word not in vocab: + unk_words.add(word) + return unk_words + vocab = {'':0, '':1, '':2, '':3, '': 4} + for word, freq in word_freq.items(): + vocab[word] = len(vocab) + answer_options = ['0', '1', '10', '2', '3', '4', '5', '6', '7', '8', '9', 'False', 'True', 'blue', 'brown', 'cone', 'cube', 'cyan', 'cylinder', 'flying', 'flying,rotating', 'flying,rotating,sliding', 'flying,sliding', 'gold', 'gray', 'green', 'large', 'medium', 'metal', 'no action', 'purple', 'red', 'rotating', 'rotating,sliding', 'rubber', 'sliding', 'small', 'sphere', 'spl', 'yellow'] + return vocab, answer_options + +def answer_by_question_type(dials): + qa_dist = {} + for dialog in dials: + for turn_idx, turn in enumerate(dialog): + answer = turn['answer'] + template = turn['template'] + if turn_idx > 0: + prior_template = dialog[turn_idx-1]['template'] + else: + prior_template = None + qtype = get_question_subtype(template, prior_template) + if qtype not in qa_dist: + qa_dist[qtype] = {} + if answer not in qa_dist[qtype]: + qa_dist[qtype][answer] = 0 + qa_dist[qtype][answer] += 1 + return qa_dist + + +# Load text data +def create_dials(dials, vocab, answer_list, vft_data, args, tokenizer=None): + dialog_list = [] + qa_id = 0 + for dialog in tqdm(dials, total=len(dials)): + if tokenizer is None: + questions = [words2ids(t['question'], vocab) for t in dialog] + answers = [words2ids(str(t['answer']), vocab) for t in dialog] + else: + questions = [words2ids_pretrained_lm(t['question'], vocab, tokenizer) for t in dialog] + answers = [words2ids_pretrained_lm(str(t['answer']), vocab, tokenizer) for t in dialog] + answer_output = [[answer_list.index(str(t['answer']))] for t in dialog] + qa_pair = [np.concatenate((q,a)).astype(np.int32) for q,a in zip(questions, answers)] + + attribute_dependencies = [] + object_dependencies = [] + temporal_dependencies = [] + spatial_dependencies = [] + q_types = [] + q_complexities = [] + for i, t in enumerate(dialog): + # determine the type of turn relation + attribute_dependencies.append(t['turn_dependencies']['attribute']) + object_dependencies.append(t['turn_dependencies']['object']) + temporal_dependencies.append(t['turn_dependencies']['temporal']) + spatial_dependencies.append(t['turn_dependencies']['spatial']) + + # determine the question type based on the template for analysis reasons + if i == 0: + q_types.append(get_question_type(t['template'], None)) + else: + q_types.append(get_question_type(t['template'], dialog[i-1]['template'])) + + # get question complexity + q_complexities.append(get_question_complexity(t, t['template_filename'] )) + + # get image name + video_name = t['image'] + + vid_cutoffs = [t['template']['cutoff'] for t in dialog] + gt_vid_periods = [t['template']['used_periods'][-1] for t in dialog] + programs = [program2ids(t['final_all_program'], vocab) for t in dialog] + states = [state2ids(t['template']['used_objects'], vocab) for t in dialog] + vid = dialog[0]['image'].replace('CLEVR', 'CATER') + vid_split = dialog[0]['split'] + vid_key = '{}-{}'.format(vid_split, vid) + whole_vft_fea = vft_data[vid_key] + turn_based_vft_fea = [] + + # cutoff the unused vft data based on the vid_cutoffs + for t_idx, t_cutoff in enumerate(vid_cutoffs): + if t_cutoff is not None: + t_vft_fea = whole_vft_fea[:t_cutoff[3], :, :] + else: + t_vft_fea = whole_vft_fea + 
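        # t_cutoff comes from template['cutoff']; when it is set, only the clips up to index
        # t_cutoff[3] along the segment axis are kept, so (presumably) each turn only sees the
        # part of the video that the dialogue has covered so far.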
turn_based_vft_fea.append(t_vft_fea) + + for n in range(len(questions)): + start_turn_idx = 0 + history = np.asarray([]) + turns = [] + q_turns = [] + a_turns = [] + for m in range(start_turn_idx, n): + history = np.append(history, qa_pair[m]) + turns.append(qa_pair[m]) + q_turns.append(questions[m]) + a_turns.append(np.array(answer_output[m])) + + question = questions[n] + answer = answer_output[n] + program = programs[n] + state = states[n] + gt_period = gt_vid_periods[n] + q_type = q_types[n] + attribute_dependency = attribute_dependencies[n] + object_dependency = object_dependencies[n] + temporal_dependency = temporal_dependencies[n] + spatial_dependency = spatial_dependencies[n] + q_complexity = q_complexities[n] + vft_feat = turn_based_vft_fea[n] + + item = [vid_split, vid, qa_id, history, question, answer, turns, + q_turns, a_turns, vft_feat, gt_period, + program, state, q_type, attribute_dependency, object_dependency, + temporal_dependency, spatial_dependency, video_name, q_complexity] + + dialog_list.append(item) + qa_id += 1 + + data = {'dialogs': dialog_list, 'vocab': vocab, 'answer': answer_list, 'features': []} + return data + + +def create_dataset(data, vocab, split, args): + out = {} + keys = ['vid_split', 'vid', 'qa_id', 'history', 'question', 'answer', 'turns', + 'q_turns', 'a_turns', 'vft', 'gt_period', + 'program', 'state', 'q_type', 'attribute_dependency', 'object_dependency', + 'temporal_dependency', 'spatial_dependency', 'video_name', 'q_complexity'] + for key in keys: + out[key] = [] + for dialog in data['dialogs']: + out['vid_split'].append(dialog[0]) + out['vid'].append(dialog[1]) + out['qa_id'].append(dialog[2]) + out['history'].append(dialog[3]) + out['question'].append(dialog[4]) + out['answer'].append(dialog[5]) + out['turns'].append(dialog[6]) + out['q_turns'].append(dialog[7]) + out['a_turns'].append(dialog[8]) + out['vft'].append(dialog[9]) + out['gt_period'].append(dialog[10]) + out['program'].append(dialog[11]) + out['state'].append(dialog[12]) + out['q_type'].append(dialog[13]) + out['attribute_dependency'].append(dialog[14]) + out['object_dependency'].append(dialog[15]) + out['temporal_dependency'].append(dialog[16]) + out['spatial_dependency'].append(dialog[17]) + out['video_name'].append(dialog[18]) + out['q_complexity'].append(dialog[19]) + + dataset = Dataset(out) + data_loader = torch.utils.data.DataLoader(dataset=dataset, + batch_size=args.batch_size, + shuffle=(split=='train'), + collate_fn=partial(collate_fn, vocab=vocab), + num_workers=args.num_workers, + pin_memory=True) + return data_loader, len(out['vid']) diff --git a/src/utils/dvd_codebase/data/data_utils.py b/src/utils/dvd_codebase/data/data_utils.py new file mode 100644 index 0000000..4b902b0 --- /dev/null +++ b/src/utils/dvd_codebase/data/data_utils.py @@ -0,0 +1,169 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. +All rights reserved. +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import copy +import logging +import sys +import time +import os +import six +import pickle +import json +import numpy as np +import pdb +from tqdm import tqdm +import torch +import nltk + +def subsequent_mask(size): + "Mask out subsequent positions." 
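    # Example (illustrative): subsequent_mask(3) returns a [1, 3, 3] boolean mask
    #   [[[ True, False, False],
    #     [ True,  True, False],
    #     [ True,  True,  True]]]
    # i.e. position i may attend only to positions <= i (a causal/autoregressive mask).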
+ attn_shape = (1, size, size) + subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8') + return torch.from_numpy(subsequent_mask) == 0 + +def get_npy_shape(filename): + with open(filename, 'rb') as f: + if filename.endswith('.pkl'): + shape = pickle.load(f).shape + else: + pdb.set_trace() + major, minor = np.lib.format.read_magic(f) + shape, fortran, dtype = np.lib.format.read_array_header_1_0(f) + return shape + +def words2ids(str_in, vocab): + words = nltk.word_tokenize(str_in) + sentence = np.ndarray(len(words)+2, dtype=np.int32) + sentence[0]=vocab[''] + for i,w in enumerate(words): + if w in vocab: + sentence[i+1] = vocab[w] + else: + sentence[i+1] = vocab[''] + sentence[-1]=vocab[''] + return sentence + +def words2ids_pretrained_lm(str_in, vocab, tokenizer): + # based on: https://medium.com/@dhartidhami/understanding-bert-word-embeddings-7dc4d2ea54ca + text = tokenizer.cls_token + str_in + tokenizer.eos_token + tokenized_text = tokenizer.tokenize(text) + indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text) + token_array = np.array([indexed_tokens]) + token_array = np.reshape(token_array, (-1,)) + return token_array + + +def program2ids(program, vocab): + sentence = [] + return np.asarray(sentence, dtype=np.int32) + for n in program: + t = n['type'] + if t == 'identity': continue + if t not in vocab: + print(t) + pdb.set_trace() + #else: + # t = new_nodes[t] + sentence.append(vocab[t]) + if 'side_inputs' in n: + if len(n['side_inputs'])!=1: + assert type(n['side_inputs']) == str + words = n['side_inputs'] + else: + words = n['side_inputs'][0] + words = nltk.word_tokenize(words) + for word in words: + if word in vocab: + sentence.append(vocab[word]) + else: + sentence.append(vocab['']) + #if len(sentence)==0: + # pdb.set_trace() + # sentence=np.asarray([vocab['']]) + return np.asarray(sentence, dtype=np.int32) + +def state2ids_dot(state, dot_vocab, max_dot_size=10): + ordered_attrs = ['', '', '', ''] + ids = {} + for a in ordered_attrs: + ids[a] = [] + for o in range(max_dot_size): + ids[a].append(dot_vocab[a]['']) + if len(state)==0: + return ids + sorted_state = {k: v for k, v in sorted(state.items(), key=lambda item: item[1]['original_turn'])} + state_idx = 0 + for k,v in sorted_state.items(): + for a in ordered_attrs: + if a in v: + ids[a][state_idx] = dot_vocab[a][v[a]] + state_idx += 1 + ids = {k:np.asarray(v, dtype=np.int32) for k,v in ids.items()} + return ids + +def state2ids(state, vocab): + return np.asarray([], dtype=np.int32) + if len(state)==0: + return np.asarray([vocab['']], dtype=np.int32) + sentence = [] + ordered_attrs = ['', '', '', ''] + #print(state) + sorted_state = {k: v for k, v in sorted(state.items(), key=lambda item: item[1]['original_turn'])} + + for k,v in sorted_state.items(): + found_obj = False + for a in ordered_attrs: + if a in v: + sentence.append(vocab[v[a]]) + found_obj = True + if found_obj: + sentence.append(vocab['']) + if len(sentence)==0: + return np.asarray([vocab['']], dtype=np.int32) + return np.asarray(sentence, dtype=np.int32) + +def get_vft_size_by_timestamp(time, segment_map, event_type='end', threshold=5): + if time is None: + if event_type == 'end': + return len(segment_map)-1 + else: + return 0 + + if event_type == 'end': + segment_idx = -1 + for idx in range(len(segment_map)): + segment_range = segment_map[idx] + if segment_range[1]>time[-1]: + segment_idx = idx-1 + break + if segment_idx == -1: + segment_idx = 0 + return segment_idx + + else: + segment_idx = -1 + for idx in range(len(segment_map)): + 
segment_range = segment_map[idx] + if segment_range[0]>=time[-1]: + segment_idx = idx + break + if segment_idx == -1: + segment_idx = len(segment_map)-1 + return segment_idx + + +def get_vft_range_by_period(period, segment_map, eov): + if period is None: + return (0, eov) + else: + start_time, end_time = period + start_vft = get_vft_size_by_timestamp(start_time, segment_map, 'start') + end_vft = get_vft_size_by_timestamp(end_time, segment_map, 'end') + if start_vft > end_vft: + start_vft, end_vft = end_vft, start_vft + return (start_vft, end_vft) + diff --git a/src/utils/dvd_codebase/data/dataset.py b/src/utils/dvd_codebase/data/dataset.py new file mode 100644 index 0000000..992ee3b --- /dev/null +++ b/src/utils/dvd_codebase/data/dataset.py @@ -0,0 +1,255 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. +All rights reserved. +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import copy +import logging +import sys +import time +import os +import six +import pickle +import json +import numpy as np +import pdb +from tqdm import tqdm +import torch +import torch.utils.data as Data +from torch.autograd import Variable +from src.utils.dvd_codebase.data.data_utils import * + +class Dataset(Data.Dataset): + def __init__(self, data_info): + self.vid_split = data_info['vid_split'] + self.vid = data_info['vid'] + self.qa_id = data_info['qa_id'] + self.history = data_info['history'] + self.question = data_info['question'] + self.answer = data_info['answer'] + self.turns = data_info['turns'] + self.q_turns = data_info['q_turns'] + self.a_turns = data_info['a_turns'] + self.vft = data_info['vft'] + self.gt_period = data_info['gt_period'] + self.program = data_info['program'] + self.state = data_info['state'] + self.q_type = data_info['q_type'] + self.attribute_dependency = data_info['attribute_dependency'] + self.object_dependency = data_info['object_dependency'] + self.temporal_dependency = data_info['temporal_dependency'] + self.spatial_dependency = data_info['spatial_dependency'] + self.video_name = data_info['video_name'] + self.q_complexity = data_info['q_complexity'] + + def __getitem__(self, index): + item_info = { + 'vid_split': self.vid_split[index], + 'vid':self.vid[index], + 'qa_id': self.qa_id[index], + 'history': self.history[index], + 'turns': self.turns[index], + 'q_turns': self.q_turns[index], + 'a_turns': self.a_turns[index], + 'question': self.question[index], + 'answer': self.answer[index], + 'vft': self.vft[index], + 'gt_period': self.gt_period[index], + 'program': self.program[index], + 'state': self.state[index], + 'q_type': self.q_type[index], + 'attribute_dependency': self.attribute_dependency[index], + 'object_dependency': self.object_dependency[index], + 'temporal_dependency': self.temporal_dependency[index], + 'spatial_dependency': self.spatial_dependency[index], + 'video_name': self.video_name[index], + 'q_complexity': self.q_complexity[index] + } + return item_info + + def __len__(self): + return len(self.vid) + +class Batch: + def __init__(self, vft, his, query, his_query, turns, + q_turns, a_turns, + answer, vid_splits, vids, qa_ids, + query_lens, his_lens, his_query_lens, + dial_lens, turn_lens, + program, program_lens, state, state_lens, + vocab, q_type, attribute_dependency, object_dependency, + temporal_dependency, spatial_dependency, video_name, q_complexity): + self.vid_splits = vid_splits + self.vids = vids + self.qa_ids = qa_ids + self.size = len(self.vids) + + self.query = query + 
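        # query / his / turns and friends hold padded token-id tensors produced by collate_fn
        # further below; the *_mask attributes are built by comparing against the vocabulary's
        # padding index, so True marks real tokens and False marks padding.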
self.query_lens = query_lens + self.his = his + self.his_lens = his_lens + self.his_query = his_query + self.his_query_lens = his_query_lens + self.answer = answer + self.vft = vft + self.turns = turns + self.q_turns = q_turns + self.a_turns = a_turns + self.dial_lens = dial_lens + self.turn_lens = turn_lens + self.q_type = q_type + self.attribute_dependency = attribute_dependency + self.object_dependency = object_dependency + self.temporal_dependency = temporal_dependency + self.spatial_dependency = spatial_dependency + self.video_name = video_name + self.q_complexity = q_complexity + + pad = vocab[''] + self.his_query_mask = (his_query != pad).unsqueeze(-2) + self.query_mask = (query != pad) + self.his_mask = (his != pad).unsqueeze(-2) + self.q_turns_mask = (q_turns != pad) + self.turns_mask = (turns != pad) + + self.program = program + self.program_lens = program_lens + self.state = state + self.state_lens = state_lens + + @staticmethod + def make_std_mask(tgt, pad): + tgt_mask = (tgt != pad).unsqueeze(-2) + tgt_mask = tgt_mask & Variable(subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data)) + return tgt_mask + + def move_to_cuda(self): + self.query = self.query.to('cuda', non_blocking=True) + self.his = self.his.to('cuda', non_blocking=True) + self.his_query = self.his_query.to('cuda', non_blocking=True) + self.query_mask = self.query_mask.to('cuda', non_blocking=True) + self.his_mask = self.his_mask.to('cuda', non_blocking=True) + self.his_query_mask = self.his_query_mask.to('cuda', non_blocking=True) + self.answer = self.answer.to('cuda', non_blocking=True) + self.vft = self.vft.to('cuda', non_blocking=True) + self.turns = self.turns.to('cuda', non_blocking=True) + self.turns_mask = self.turns_mask.to('cuda', non_blocking=True) + self.q_turns = self.q_turns.to('cuda', non_blocking=True) + self.q_turns_mask = self.q_turns_mask.to('cuda', non_blocking=True) + self.a_turns = self.a_turns.to('cuda', non_blocking=True) + self.program = self.program.to('cuda', non_blocking=True) + self.state = self.state.to('cuda', non_blocking=True) + + def to_cuda(self, tensor): + return tensor.cuda() + +def collate_fn(data, vocab): + def pad_monet_videos(seqs, pad_token): + lengths = [s.shape[0] for s in seqs] + max_length = max(lengths) + output = [] + for seq in seqs: + result = torch.ones((max_length, seq.shape[1], seq.shape[2])) * pad_token + result[:seq.shape[0]] = seq + output.append(result) + return output + + def pad_seq(seqs, pad_token, return_lens=False, is_vft=False): + lengths = [s.shape[0] for s in seqs] + max_length = max(lengths) + output = [] + for seq in seqs: + if is_vft: + if len(seq.shape)==4: # spatio-temporal feature + result = np.ones((max_length, seq.shape[1], seq.shape[2], seq.shape[3]), dtype=seq.dtype)*pad_token + else: + result = np.ones((max_length, seq.shape[-1]), dtype=seq.dtype)*pad_token + else: + result = np.ones(max_length, dtype=seq.dtype)*pad_token + result[:seq.shape[0]] = seq + output.append(result) + if return_lens: + return lengths, output + return output + + def pad_2d_seq(seqs, pad_token, return_lens=False, is_vft=False): + lens1 = [len(s) for s in seqs] + max_len1 = max(lens1) + all_seqs = [] + for seq in seqs: + all_seqs.extend(seq) + lens2 = [len(s) for s in all_seqs] + max_len2 = max(lens2) + output = [] + all_lens = [] + for seq in seqs: + if is_vft: + result = np.ones((max_len1, max_len2, seq[0].shape[-1]))*pad_token + else: + result = np.ones((max_len1, max_len2))*pad_token + turn_lens = np.ones(max_len1).astype(int) + offset = max_len1 - len(seq) + for 
turn_idx, turn in enumerate(seq): + #result[turn_idx,:turn.shape[0]] = turn + # padding should be at the first turn idxs (Reason: result of last n turns is used for state creation) + result[turn_idx + offset,:turn.shape[0]] = turn + turn_lens[turn_idx] = turn.shape[0] + output.append(result) + all_lens.append(turn_lens) + all_lens = np.asarray(all_lens) + if return_lens: + return lens1, all_lens, output + return output + + def prepare_data(seqs, is_float=False): + if is_float: + return torch.from_numpy(np.asarray(seqs)).float() + return torch.from_numpy(np.asarray(seqs)).long() + + item_info = {} + for key in data[0].keys(): + item_info[key] = [d[key] for d in data] + pad_token = vocab[''] + h_lens, h_padded = pad_seq(item_info['history'], pad_token, return_lens=True) + h_batch = prepare_data(h_padded) + q_lens, q_padded = pad_seq(item_info['question'], pad_token, return_lens=True) + q_batch = prepare_data(q_padded) + + hq = [np.concatenate([q,h]) for q,h in zip(item_info['history'], item_info['question'])] + hq_lens, hq_padded = pad_seq(hq, pad_token, return_lens=True) + hq_batch = prepare_data(hq_padded) + + dial_lens, turn_lens, turns_padded = pad_2d_seq(item_info['turns'], pad_token, return_lens=True) + _, _, q_turns_padded = pad_2d_seq(item_info['q_turns'], pad_token, return_lens=True) + turns_batch = prepare_data(turns_padded) + q_turns_batch = prepare_data(q_turns_padded) + + a_turns_padded = pad_2d_seq(item_info['a_turns'], pad_token) + a_turns_batch = prepare_data(a_turns_padded) + + a_batch = prepare_data(item_info['answer']) + + #vft_lens, vft_padded = pad_seq(item_info['vft'], 0, return_lens=True, is_vft=True) + #vft_batch = prepare_data(vft_padded, is_float=True) + vft_batch = item_info['vft'] + vft_batch_padded = pad_monet_videos(vft_batch, 0) + vft_batch_padded = torch.stack(vft_batch_padded) + + p_lens, p_padded = pad_seq(item_info['program'], pad_token, return_lens=True) + p_batch = prepare_data(p_padded) + + s_lens, s_padded = pad_seq(item_info['state'], pad_token, return_lens=True) + s_batch = prepare_data(s_padded) + + batch = Batch(vft_batch_padded, + h_batch, q_batch, hq_batch, turns_batch, q_turns_batch, a_turns_batch, a_batch, + item_info['vid_split'], item_info['vid'], item_info['qa_id'], + q_lens, h_lens, hq_lens, + dial_lens, turn_lens, + p_batch, p_lens, s_batch, s_lens, + vocab, item_info['q_type'], item_info['attribute_dependency'], item_info['object_dependency'], + item_info['temporal_dependency'], item_info['spatial_dependency'], item_info['video_name'], + item_info['q_complexity']) + return batch diff --git a/src/utils/dvd_codebase/exps_test/baseline/dvd.conf b/src/utils/dvd_codebase/exps_test/baseline/dvd.conf new file mode 100644 index 0000000..5112a13 Binary files /dev/null and b/src/utils/dvd_codebase/exps_test/baseline/dvd.conf differ diff --git a/src/utils/dvd_codebase/exps_test/baseline/dvd_params.txt b/src/utils/dvd_codebase/exps_test/baseline/dvd_params.txt new file mode 100644 index 0000000..8473b9a --- /dev/null +++ b/src/utils/dvd_codebase/exps_test/baseline/dvd_params.txt @@ -0,0 +1,9 @@ +debug=1 +fea_dir=/workspace/hungle/data/dvd/video-classification-3d-cnn-pytorch/outputs/resnext_101/ +data_dir=/workspace/hungle/cater-dialog/question_generation/output/ +output_dir=exps_test//baseline/dvd +num_workers=0 +device=0 +num_epochs=3 +batch_size=32 +verbose=0 diff --git a/src/utils/dvd_codebase/main.py b/src/utils/dvd_codebase/main.py new file mode 100755 index 0000000..8dc6558 --- /dev/null +++ b/src/utils/dvd_codebase/main.py @@ -0,0 +1,86 @@ +""" 
+Copyright (c) Facebook, Inc. and its affiliates. +All rights reserved. +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +#!/usr/bin/env python +import math +import sys +import time +import os +import json +import numpy as np +import pickle as pkl +import threading +import pdb +from tqdm import tqdm +import torch +import torch.nn as nn + +from project.dvd_codebase.configs.configs import * +import project.dvd_codebase.data.data_handler as dh + +def run_epoch(loader, epoch): + it = tqdm(enumerate(loader),total=len(loader), desc="epoch {}/{}".format(epoch+1, args.num_epochs), ncols=0) + for j, batch in it: + batch.move_to_cuda() + pdb.set_trace() + +# load dialogues +logging.info('Loading dialogues from {}'.format(args.data_dir)) +train_dials, train_vids = dh.load_dials(args, 'train') +logging.info('#train dials = {} # train videos = {}'.format(len(train_dials), len(train_vids))) +val_dials, val_vids = dh.load_dials(args, 'val') +logging.info('#val dials = {} # val videos = {}'.format(len(val_dials), len(val_vids))) + +# load video features +logging.info('Loading video features from {}'.format(args.fea_dir)) +train_vft, vft_dims, clip_size, clip_stride, segment_map = dh.load_videos(args, train_vids) +val_vft, _, _, _, _ = dh.load_videos(args, val_vids) +logging.info('#video ft dims = {} clip size {} clip stride {}'.format(vft_dims, clip_size, clip_stride)) + +# get vocabulary +logging.info('Extracting vocabulary') +vocab, answer_list = dh.get_vocabulary(train_dials, args) +logging.info('#vocab = {} #answer candidates = {}'. + format(len(vocab), len(answer_list))) +logging.info('All answer candidates: {}'.format(answer_list)) +unk_words = dh.get_vocabulary(val_dials, args, vocab=vocab) +logging.info('{} unknown words in val split: {}'.format(len(unk_words), unk_words)) + +# question-answer distribution +qa_dist = dh.answer_by_question_type(train_dials) + +# save meta parameters +path = args.output_dir + '.conf' +with open(path, 'wb') as f: + pkl.dump((vocab, answer_list, qa_dist, args), f, -1) +path2 = args.output_dir + '_params.txt' +with open(path2, "w") as f: + for arg in vars(args): + f.write("{}={}\n".format(arg, getattr(args, arg))) + +# load data +logging.info('Creating training instances') +train_dials = dh.create_dials(train_dials, vocab, answer_list, segment_map, train_vft, args) +logging.info('Creating validation instances') +valid_dials = dh.create_dials(val_dials, vocab, answer_list, segment_map, val_vft, args) + +# make dataloaders +train_dataloader, train_samples = dh.create_dataset(train_dials, vocab, 'train', args) +logging.info('#train sample = {} # train batch = {}'.format(train_samples, len(train_dataloader))) +valid_dataloader, valid_samples = dh.create_dataset(valid_dials, vocab, 'val', args) +logging.info('#train sample = {} # train batch = {}'.format(valid_samples, len(valid_dataloader))) + +epoch_counts = 0 +for epoch in range(args.num_epochs): + # train on training data + logging.info('-------training--------') + train_losses = run_epoch(train_dataloader, epoch) + + # test on validation data + logging.info('-------validation--------') + valid_losses = run_epoch(valid_dataloader, epoch) + diff --git a/src/utils/dvd_codebase/run.sh b/src/utils/dvd_codebase/run.sh new file mode 100755 index 0000000..6dad31f --- /dev/null +++ b/src/utils/dvd_codebase/run.sh @@ -0,0 +1,43 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. +All rights reserved. 
+This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +#input choices +device=$1 +debug=$2 # true: test run with small datasets OR false: run with real datasets + +num_epochs=50 +batch_size=32 +nb_workers=16 + +# data setting +data_dir=/workspace/hungle/cater-dialog/question_generation/output/ +fea_dir=/workspace/hungle/data/dvd/video-classification-3d-cnn-pytorch/outputs/resnext_101/ + +# output folder name +expid=baseline + +if [ $debug = 1 ]; then + expdir=exps_test/$task/${expid} + num_epochs=3 + nb_workers=0 + report_interval=10 +else + expdir=exps/$task/${expid} +fi +echo stage: $stage debug? $debug task: $task exp_dir: $expdir + +# training phase +mkdir -p $expdir +CUDA_VISIBLE_DEVICES=$device python main.py \ + --debug $debug \ + --fea-dir $fea_dir \ + --data-dir $data_dir \ + --output-dir $expdir/dvd \ + --num-epochs $num_epochs \ + --batch-size $batch_size \ + --num-workers $nb_workers \ + diff --git a/src/utils/positional_encoding.py b/src/utils/positional_encoding.py new file mode 100644 index 0000000..3bc1ea9 --- /dev/null +++ b/src/utils/positional_encoding.py @@ -0,0 +1,27 @@ +# https://github.com/pytorch/pytorch/issues/68407 +from torch import nn +from torch import Tensor +import torch +import math + + +class PositionalEncoding(nn.Module): + def __init__(self, d_model, dropout=0.1, max_len=5000): + super(PositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0).transpose(0, 1) + self.register_parameter('pe', nn.Parameter(pe, requires_grad=False)) + + def forward(self, x): + # positional encoding expects shape (seq_len, batch_size, emb_dim), (batch_size, seq_len, emb_dim) is given + x = x.permute(1,0,2) + x = x + self.pe[:x.size(0), :] + x = x.permute(1,0,2) + return self.dropout(x) + diff --git a/src/utils/save_attention_weights.py b/src/utils/save_attention_weights.py new file mode 100644 index 0000000..5ce8d98 --- /dev/null +++ b/src/utils/save_attention_weights.py @@ -0,0 +1,8 @@ +#https://gist.github.com/airalcorn2/50ec06517ce96ecc143503e21fa6cb91 + +class SaveOutput: + def __init__(self): + self.outputs = None + + def __call__(self, module, module_in, module_out): + self.outputs = module_out diff --git a/src/utils/simmc2_dataset/dataloader_dvd_model.py b/src/utils/simmc2_dataset/dataloader_dvd_model.py new file mode 100644 index 0000000..fd29632 --- /dev/null +++ b/src/utils/simmc2_dataset/dataloader_dvd_model.py @@ -0,0 +1,233 @@ +#! /usr/bin/env python +""" +Copyright (c) Facebook, Inc. and its affiliates. +All rights reserved. +This source code is licensed under the license found in the LICENSE file in the +root directory of this source tree. + +Dataloader for ambiguous candidates identification task on SIMMC 2.1. 
+ +Author(s): Satwik Kottur +""" + +from __future__ import absolute_import, division, print_function, unicode_literals + +import json + +import numpy as np +import torch +from torch.utils.data import Dataset +from torch.nn.utils.rnn import pad_sequence + + +def pad_seq(seqs, pad_token, return_lens=False, is_vft=False): + lengths = [s.shape[1] for s in seqs] + max_length = max(lengths) + output = [] + for seq in seqs: + if is_vft: + if len(seq.shape)==4: # spatio-temporal feature + result = torch.ones(((1, max_length), seq.shape[1], seq.shape[2], seq.shape[3]), dtype=seq.dtype)*pad_token + else: + result = torch.ones(((1, max_length), seq.shape[-1]), dtype=seq.dtype)*pad_token + else: + result = torch.ones((1, max_length), dtype=seq.dtype)*pad_token + result[0, :seq.shape[1]] = seq + output.append(result) + if return_lens: + return lengths, output + return output + + +def pad_2d_seq(seqs, pad_token, return_lens=False, is_vft=False): + lens1 = [len(s) for s in seqs] + max_len1 = max(lens1) + all_seqs = [] + for seq in seqs: + all_seqs.extend(seq) + lens2 = [s.shape[1] for s in all_seqs] + max_len2 = max(lens2) + output = [] + all_lens = [] + for seq in seqs: + if is_vft: + result = torch.ones((max_len1, max_len2, seq[0].shape[-1]))*pad_token + else: + result = torch.ones((1, max_len1, max_len2))*pad_token + #turn_lens = torch.ones(max_len1, dtype=np.int) + offset = max_len1 - len(seq) + for turn_idx, turn in enumerate(seq): + #result[turn_idx,:turn.shape[0]] = turn + # padding should be at the first turn idxs (Reason: result of last n turns is used for state creation) + result[0, turn_idx + offset,:turn.shape[1]] = turn + #turn_lens[turn_idx] = turn.shape[0] + output.append(result) + return output + + +class Simmc2Dataset(Dataset): + def __init__(self, tokenizer, feature_loader, load_path, args, hidden_labels=False): + self._tokenizer = tokenizer + self._features = feature_loader + self._args = args + self._hidden_labels = hidden_labels + print("Loading: {}".format(load_path)) + with open(load_path, "r") as file_id: + self._raw_data = json.load(file_id) + # Also read the source data for evaluation. 
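        # Note (descriptive): the preprocessed JSON is expected to hold "source_path" (a pointer
        # back to the original SIMMC 2.1 annotations used for evaluation) and "data", the list of
        # per-turn samples indexed by __getitem__ below.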
+ with open(self._raw_data["source_path"], "r") as file_id: + self.source_data = json.load(file_id) + self._data = self._raw_data["data"] + + self.num_utterances = 2 * args.max_turns + 1 + self.num_instances = len(self._data) + self.device = torch.cuda if args.use_gpu else torch + + def get_random_batch(self, batch_size): + indices = np.random.randint(0, self.num_instances, batch_size) + return self.get_indexed_data(indices) + + def get_entire_batch(self, batch_size): + all_indices = np.arange(self.num_instances) + for start in all_indices[::batch_size]: + batch_indices = all_indices[start : start + batch_size] + yield self.get_indexed_data(batch_indices) + + + def __len__(self): + return len(self._data) + + + def collate_fn(self, batch): + merged_batch = {key: [d[key] for d in batch] for key in batch[0]} + out = {} + for key in merged_batch: + if key in ['query', 'answer']: + seq = pad_seq(merged_batch[key], pad_token=1) + out[key] = torch.concat(seq, dim=0) + elif key in ['q_turns', 'a_turns', 'turns', 'object_features', 'answer_candidates']: + if merged_batch[key][0] is not None: + seq = pad_2d_seq(merged_batch[key], pad_token=1) + out[key] = torch.concat(seq, dim=0).type(torch.int) + else: + out[key] = None + + elif key in ['features']: + #features = [f.unsqueeze(1) for f in merged_batch[key]] + # pad video featues + features = pad_sequence(merged_batch[key], batch_first=True) + out[key] = features + else: + out[key] = merged_batch[key] + + + return out + + + def encode_turns(self, turns): + encoded_turns = [] + for turn in turns: + encoded_turn = self._tokenizer( + turn, + padding=True, + max_length=self._args.max_length, + return_tensors="pt", + truncation=True, + ) + encoded_turns.append(encoded_turn['input_ids'].type(torch.int)) + return encoded_turns + + + def __getitem__(self, index): + text_labels = [] + text_inputs = [] + dialog_ids = [] + turn_ids = [] + features = [] + object_maps = [] + # Add and tokens. + dialog_datum = self._data[index] + #dialog = self._data[index]["input_text"] + query = self._data[index]["query"] + answer = self._data[index]["answer"] + turns = self._data[index]["turns"] + q_turns = self._data[index]["q_turns"] + a_turns = self._data[index]["a_turns"] + object_features = self._data[index]["object_metadata"] + if "answer_candidates" in self._data[index].keys(): + answer_candidates = self._data[index]["answer_candidates"] + else: + answer_candidates = None + + if self._features: + feature = self._features[dialog_datum["image_name"]] + + encoded_query = self._tokenizer( + query, + padding=True, + max_length=self._args.max_length, + return_tensors="pt", + truncation=True, + )['input_ids'].type(torch.int) + encoded_answer = self._tokenizer( + answer, + padding=True, + max_length=self._args.max_length, + return_tensors="pt", + truncation=True, + )['input_ids'].type(torch.int) + encoded_q_turns = self.encode_turns(q_turns) + encoded_a_turns = self.encode_turns(a_turns) + encoded_turns = self.encode_turns(turns) + encoded_object_features = self.encode_turns(object_features) + if "answer_candidates" in self._data[index].keys(): + encoded_answer_candidates = self.encode_turns(answer_candidates) + else: + encoded_answer_candidates = None + + + # Pack the sample. 
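            # Note: `feature` is only bound inside the `if self._features:` branch above, so
            # packing the sample assumes a visual feature loader was passed to the dataset.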
+ sample = { + "query": encoded_query, + "answer": encoded_answer, + "answer_candidates": encoded_answer_candidates, + "turns": encoded_turns, + "q_turns": encoded_q_turns, + "a_turns": encoded_a_turns, + "object_features": encoded_object_features, + "dialog_id": dialog_datum["dialog_id"], + "turn_id": dialog_datum["turn_id"], + "features": feature, + } + return sample + + +class VisualFeatureLoader: + """Loads visual features for SIMMC 2.1 ambiguous candidate identification.""" + + UNAVAILABLE_IMAGES = [ + "cloth_store_1416238_woman_20_6.png", + "cloth_store_1416238_woman_19_0.png", + "cloth_store_1416238_woman_4_8.png", + ] + + def __init__(self, feature_path, feature_size): + """Read the features from the path.""" + self._features = torch.load(feature_path) + self._feature_size = feature_size + self._zero_feature = torch.zeros((1, self._feature_size), dtype=torch.float) + + def __getitem__(self, label): + """Get the feature given image label.""" + assert ( + label in self._features or label in self.UNAVAILABLE_IMAGES + ), f"{label} not found!" + if label in self.UNAVAILABLE_IMAGES: + return self._zero_feature + return self._features[label] + + def cuda(self): + """Move the features to cuda.""" + self._zero_feature = self._zero_feature.cuda() + for key, val in self._features.items(): + self._features[key] = val.cuda() diff --git a/src/utils/simmc2_dataset/dataloader_finetune_mlm.py b/src/utils/simmc2_dataset/dataloader_finetune_mlm.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utils/simmc2_dataset/dataloader_mlm_nsp.py b/src/utils/simmc2_dataset/dataloader_mlm_nsp.py new file mode 100644 index 0000000..64f0e34 --- /dev/null +++ b/src/utils/simmc2_dataset/dataloader_mlm_nsp.py @@ -0,0 +1,277 @@ +#! /usr/bin/env python +""" +Copyright (c) Facebook, Inc. and its affiliates. +All rights reserved. +This source code is licensed under the license found in the LICENSE file in the +root directory of this source tree. + +Dataloader for ambiguous candidates identification task on SIMMC 2.1. 
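+Variant that additionally prepares masked language modeling (MLM) and next sentence
+prediction (NSP) targets; see Simmc2DatasetMlmNsp.conduct_mask (adapted from VD-BERT)
+and the "next_sentence_label" field produced by format_data_subtask4_mlm_nsp.py.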
+ +Author(s): Satwik Kottur +""" + +from __future__ import absolute_import, division, print_function, unicode_literals + +import json + +import numpy as np +import torch +from torch.utils.data import Dataset +from torch.nn.utils.rnn import pad_sequence +from random import shuffle +from random import random as rand +#from src.utils.vd_bert.loader_utils import get_random_word + + +def pad_seq(seqs, pad_token, return_lens=False, is_vft=False): + lengths = [s.shape[1] for s in seqs] + max_length = max(lengths) + output = [] + for seq in seqs: + if is_vft: + if len(seq.shape)==4: # spatio-temporal feature + result = torch.ones(((1, max_length), seq.shape[1], seq.shape[2], seq.shape[3]), dtype=seq.dtype)*pad_token + else: + result = torch.ones(((1, max_length), seq.shape[-1]), dtype=seq.dtype)*pad_token + else: + result = torch.ones((1, max_length), dtype=seq.dtype)*pad_token + result[0, :seq.shape[1]] = seq + output.append(result) + if return_lens: + return lengths, output + return output + + +def pad_2d_seq(seqs, pad_token, return_lens=False, is_vft=False): + lens1 = [len(s) for s in seqs] + max_len1 = max(lens1) + all_seqs = [] + for seq in seqs: + all_seqs.extend(seq) + lens2 = [s.shape[1] for s in all_seqs] + max_len2 = max(lens2) + output = [] + all_lens = [] + for seq in seqs: + if is_vft: + result = torch.ones((max_len1, max_len2, seq[0].shape[-1]))*pad_token + else: + result = torch.ones((1, max_len1, max_len2))*pad_token + #turn_lens = torch.ones(max_len1, dtype=np.int) + offset = max_len1 - len(seq) + for turn_idx, turn in enumerate(seq): + #result[turn_idx,:turn.shape[0]] = turn + # padding should be at the first turn idxs (Reason: result of last n turns is used for state creation) + result[0, turn_idx + offset,:turn.shape[1]] = turn + #turn_lens[turn_idx] = turn.shape[0] + output.append(result) + return output + + +class Simmc2DatasetMlmNsp(Dataset): + def __init__(self, tokenizer, feature_loader, load_path, args, hidden_labels=False): + self._tokenizer = tokenizer + self._features = feature_loader + self._args = args + self._hidden_labels = hidden_labels + print("Loading: {}".format(load_path)) + with open(load_path, "r") as file_id: + self._raw_data = json.load(file_id) + # Also read the source data for evaluation. 
+ with open(self._raw_data["source_path"], "r") as file_id: + self.source_data = json.load(file_id) + self._data = self._raw_data["data"] + + self.num_utterances = 2 * args.max_turns + 1 + self.num_instances = len(self._data) + self.device = torch.cuda if args.use_gpu else torch + + + def conduct_mask(self, tokens, effective_length, start_id, end_id): + # taken from https://github.com/salesforce/VD-BERT + # For masked Language Models + cand_pos = [] + special_pos = set() + + n_pred = min(self._args.max_n_masked, max( + 1, int(round(effective_length * self._args.p_mask)))) + + # candidate positions of masked tokens + for i, tk in enumerate(tokens): + # only mask tokens_b (target sequence) + # we will mask [SEP] as an ending symbol + if (i >= start_id) and (tk != '[CLS]') and (tk != '[PAD]') and (i < end_id): + cand_pos.append(i) + else: + special_pos.add(i) + + shuffle(cand_pos) + masked_pos = cand_pos[:n_pred] + + masked_tokens = [tokens[pos] for pos in masked_pos] + for pos in masked_pos: + if self._args.finetune: + tokens[pos] = '[MASK]' + continue + if rand() < 0.8: # 80% + tokens[pos] = '[MASK]' + #elif rand() < 0.5: # 10% + # tokens[pos] = get_random_word(self.vocab_words) + # when n_pred < max_pred, we only calculate loss within n_pred + masked_weights = [1] * len(masked_tokens) + + # Token Indexing + input_ids = self._tokenizer.convert_tokens_to_ids(tokens) + masked_ids = self._tokenizer.convert_tokens_to_ids(masked_tokens) + + if self._args.max_n_masked > n_pred: + n_pad = self._args.max_n_masked - n_pred + masked_ids.extend([0] * n_pad) + masked_pos.extend([0] * n_pad) + masked_weights.extend([0] * n_pad) + + assert len(masked_ids) == len(masked_pos) == len(masked_weights) == self._args.max_n_masked, \ + "[masked] id: %d, pos: %d, weights: %d" % (len(masked_ids), len(masked_pos), len(masked_weights)) + + return input_ids, masked_ids, masked_pos, masked_weights + + + def get_random_batch(self, batch_size): + indices = np.random.randint(0, self.num_instances, batch_size) + return self.get_indexed_data(indices) + + def get_entire_batch(self, batch_size): + all_indices = np.arange(self.num_instances) + for start in all_indices[::batch_size]: + batch_indices = all_indices[start : start + batch_size] + yield self.get_indexed_data(batch_indices) + + + def __len__(self): + return len(self._data) + + + def collate_fn(self, batch): + merged_batch = {key: [d[key] for d in batch] for key in batch[0]} + out = {} + for key in merged_batch: + if key in ['qa_pair', 'q_len', 'q_turns_len', 'masked_pos', 'mask_labels', 'next_sentence_label', 'masked_weights']: + seq = pad_seq(merged_batch[key], pad_token=1) + out[key] = torch.concat(seq, dim=0) + elif key in ['qa_turns']: + if merged_batch[key][0] is not None: + seq = pad_2d_seq(merged_batch[key], pad_token=1) + out[key] = torch.concat(seq, dim=0).type(torch.int) + else: + out[key] = None + + elif key in ['features']: + #features = [f.unsqueeze(1) for f in merged_batch[key]] + # pad video featues + features = pad_sequence(merged_batch[key], batch_first=True) + out[key] = features + else: + out[key] = merged_batch[key] + + + return out + + + def encode_turns(self, turns): + encoded_turns = [] + for turn in turns: + encoded_turn = self._tokenizer( + turn, + padding=True, + max_length=self._args.max_length, + return_tensors="pt", + truncation=True, + ) + # without cls and sep token + encoded_turns.append(encoded_turn['input_ids'][:, 1:-1].type(torch.int)) + return encoded_turns + + + def __getitem__(self, index): + dialog_datum = self._data[index] + 
qa_pair = self._data[index]["qa_pair"] + qa_turns = self._data[index]["qa_turns"] + q_turns = self._data[index]["q_turns"] + next_sentence_label = self._data[index]["next_sentence_label"] + + if self._features: + feature = self._features[dialog_datum["image_name"]] + + # mask the qa_pair + qa_pair_as_tokens = self._tokenizer.tokenize(qa_pair[0]) + if next_sentence_label[0] == 0: + end_id = qa_pair_as_tokens.index('[SEP_1]') + effective_length = end_id + 1 + start_id = 0 + else: + end_id = len(qa_pair_as_tokens) - 1 + effective_length = len(qa_pair_as_tokens) + start_id = 0 + + if self._args.only_mask_ans: + effective_length = len(qa_pair_as_tokens) - qa_pair_as_tokens.index('[SEP_1]') + start_id = qa_pair_as_tokens.index('[SEP_1]') + + # get length of current and prv questions + q_len = [qa_pair_as_tokens.index('[SEP_1]')] + q_turns_len = [len(self._tokenizer.tokenize(q[0])) for q in q_turns] + + qa_pair_ids, masked_ids, masked_pos, masked_weights = self.conduct_mask( + tokens=qa_pair_as_tokens, + effective_length=effective_length, + start_id = start_id, + end_id=end_id + ) + + qa_turns_ids = self.encode_turns(qa_turns) + + + # Pack the sample. + sample = { + "qa_pair": torch.tensor(qa_pair_ids).unsqueeze(0), + "qa_turns": qa_turns_ids, + "features": feature, + "q_len": torch.tensor(q_len).unsqueeze(0), + "q_turns_len": torch.tensor(q_turns_len).unsqueeze(0), + "masked_pos": torch.tensor(masked_pos).unsqueeze(0), + "mask_labels": torch.tensor(masked_ids).unsqueeze(0), + "masked_weights": torch.tensor(masked_weights).unsqueeze(0), + "next_sentence_label": torch.tensor(next_sentence_label).unsqueeze(0) + } + return sample + + +class VisualFeatureLoader: + """Loads visual features for SIMMC 2.1 ambiguous candidate identification.""" + + UNAVAILABLE_IMAGES = [ + "cloth_store_1416238_woman_20_6.png", + "cloth_store_1416238_woman_19_0.png", + "cloth_store_1416238_woman_4_8.png", + ] + + def __init__(self, feature_path, feature_size): + """Read the features from the path.""" + self._features = torch.load(feature_path) + self._feature_size = feature_size + self._zero_feature = torch.zeros((1, self._feature_size), dtype=torch.float) + + def __getitem__(self, label): + """Get the feature given image label.""" + assert ( + label in self._features or label in self.UNAVAILABLE_IMAGES + ), f"{label} not found!" + if label in self.UNAVAILABLE_IMAGES: + return self._zero_feature + return self._features[label] + + def cuda(self): + """Move the features to cuda.""" + self._zero_feature = self._zero_feature.cuda() + for key, val in self._features.items(): + self._features[key] = val.cuda() diff --git a/src/utils/simmc2_dataset/dataloader_test_gen.py b/src/utils/simmc2_dataset/dataloader_test_gen.py new file mode 100644 index 0000000..03d2f39 --- /dev/null +++ b/src/utils/simmc2_dataset/dataloader_test_gen.py @@ -0,0 +1,253 @@ +#! /usr/bin/env python +""" +Copyright (c) Facebook, Inc. and its affiliates. +All rights reserved. +This source code is licensed under the license found in the LICENSE file in the +root directory of this source tree. + +Dataloader for ambiguous candidates identification task on SIMMC 2.1. 
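+Test-time variant: the current QA pair is returned unmasked together with the raw
+answer string, so generated responses can later be scored against the ground truth.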
+ +Author(s): Satwik Kottur +""" + +from __future__ import absolute_import, division, print_function, unicode_literals + +import json + +import numpy as np +import torch +from torch.utils.data import Dataset +from torch.nn.utils.rnn import pad_sequence +from random import shuffle +from random import random as rand +#from src.utils.vd_bert.loader_utils import get_random_word + + +def pad_seq(seqs, pad_token, return_lens=False, is_vft=False): + lengths = [s.shape[1] for s in seqs] + max_length = max(lengths) + output = [] + for seq in seqs: + if is_vft: + if len(seq.shape)==4: # spatio-temporal feature + result = torch.ones(((1, max_length), seq.shape[1], seq.shape[2], seq.shape[3]), dtype=seq.dtype)*pad_token + else: + result = torch.ones(((1, max_length), seq.shape[-1]), dtype=seq.dtype)*pad_token + else: + result = torch.ones((1, max_length), dtype=seq.dtype)*pad_token + result[0, :seq.shape[1]] = seq + output.append(result) + if return_lens: + return lengths, output + return output + + +def pad_2d_seq(seqs, pad_token, return_lens=False, is_vft=False): + lens1 = [len(s) for s in seqs] + max_len1 = max(lens1) + all_seqs = [] + for seq in seqs: + all_seqs.extend(seq) + lens2 = [s.shape[1] for s in all_seqs] + max_len2 = max(lens2) + output = [] + all_lens = [] + for seq in seqs: + if is_vft: + result = torch.ones((max_len1, max_len2, seq[0].shape[-1]))*pad_token + else: + result = torch.ones((1, max_len1, max_len2))*pad_token + #turn_lens = torch.ones(max_len1, dtype=np.int) + offset = max_len1 - len(seq) + for turn_idx, turn in enumerate(seq): + #result[turn_idx,:turn.shape[0]] = turn + # padding should be at the first turn idxs (Reason: result of last n turns is used for state creation) + result[0, turn_idx + offset,:turn.shape[1]] = turn + #turn_lens[turn_idx] = turn.shape[0] + output.append(result) + return output + + +class Simmc2DatasetTest(Dataset): + def __init__(self, tokenizer, feature_loader, load_path, args, hidden_labels=False): + self._tokenizer = tokenizer + self._features = feature_loader + self._args = args + self._hidden_labels = hidden_labels + print("Loading: {}".format(load_path)) + with open(load_path, "r") as file_id: + self._raw_data = json.load(file_id) + # Also read the source data for evaluation. 
+ with open(self._raw_data["source_path"], "r") as file_id: + self.source_data = json.load(file_id) + self._data = self._raw_data["data"] + + self.num_utterances = 2 * args.max_turns + 1 + self.num_instances = len(self._data) + self.device = torch.cuda if args.use_gpu else torch + + + def conduct_mask(self, tokens, effective_length, start_id, end_id): + # taken from https://github.com/salesforce/VD-BERT + # For masked Language Models + cand_pos = [] + special_pos = set() + + n_pred = min(self._args.max_n_masked, max( + 1, int(round(effective_length * self._args.p_mask)))) + + # candidate positions of masked tokens + for i, tk in enumerate(tokens): + # only mask tokens_b (target sequence) + # we will mask [SEP] as an ending symbol + if (i >= start_id) and (tk != '[CLS]') and (tk != '[PAD]') and (i < end_id): + cand_pos.append(i) + else: + special_pos.add(i) + + shuffle(cand_pos) + masked_pos = cand_pos[:n_pred] + + masked_tokens = [tokens[pos] for pos in masked_pos] + for pos in masked_pos: + if self._args.finetune: + tokens[pos] = '[MASK]' + continue + if rand() < 0.8: # 80% + tokens[pos] = '[MASK]' + #elif rand() < 0.5: # 10% + # tokens[pos] = get_random_word(self.vocab_words) + # when n_pred < max_pred, we only calculate loss within n_pred + masked_weights = [1] * len(masked_tokens) + + # Token Indexing + input_ids = self._tokenizer.convert_tokens_to_ids(tokens) + masked_ids = self._tokenizer.convert_tokens_to_ids(masked_tokens) + + if self._args.max_n_masked > n_pred: + n_pad = self._args.max_n_masked - n_pred + masked_ids.extend([0] * n_pad) + masked_pos.extend([0] * n_pad) + masked_weights.extend([0] * n_pad) + + assert len(masked_ids) == len(masked_pos) == len(masked_weights) == self._args.max_n_masked, \ + "[masked] id: %d, pos: %d, weights: %d" % (len(masked_ids), len(masked_pos), len(masked_weights)) + + return input_ids, masked_ids, masked_pos, masked_weights + + + def get_random_batch(self, batch_size): + indices = np.random.randint(0, self.num_instances, batch_size) + return self.get_indexed_data(indices) + + def get_entire_batch(self, batch_size): + all_indices = np.arange(self.num_instances) + for start in all_indices[::batch_size]: + batch_indices = all_indices[start : start + batch_size] + yield self.get_indexed_data(batch_indices) + + + def __len__(self): + return len(self._data) + + + def collate_fn(self, batch): + merged_batch = {key: [d[key] for d in batch] for key in batch[0]} + out = {} + for key in merged_batch: + if key in ['qa_pair', 'masked_pos', 'mask_labels', 'next_sentence_label', 'masked_weights', 'q_len']: + seq = pad_seq(merged_batch[key], pad_token=1) + out[key] = torch.concat(seq, dim=0) + elif key in ['qa_turns']: + if merged_batch[key][0] is not None: + seq = pad_2d_seq(merged_batch[key], pad_token=1) + out[key] = torch.concat(seq, dim=0).type(torch.int) + else: + out[key] = None + elif key in ['answer']: + out[key] = merged_batch[key] + elif key in ['features']: + #features = [f.unsqueeze(1) for f in merged_batch[key]] + # pad video featues + features = pad_sequence(merged_batch[key], batch_first=True) + out[key] = features + else: + out[key] = merged_batch[key] + + + return out + + + def encode_turns(self, turns): + encoded_turns = [] + for turn in turns: + encoded_turn = self._tokenizer( + turn, + padding=True, + max_length=self._args.max_length, + return_tensors="pt", + truncation=True, + ) + # without cls and sep token + encoded_turns.append(encoded_turn['input_ids'][:, 1:-1].type(torch.int)) + return encoded_turns + + + def __getitem__(self, 
index): + dialog_datum = self._data[index] + qa_pair = self._data[index]["qa_pair"] + qa_turns = self._data[index]["qa_turns"] + answer = self._data[index]["answer"] + next_sentence_label = self._data[index]["next_sentence_label"] + + if self._features: + feature = self._features[dialog_datum["image_name"]] + + qa_pair_as_tokens = self._tokenizer.tokenize(qa_pair[0]) + q_len = [qa_pair_as_tokens.index('[SEP_1]')] + + + qa_pair_ids = self._tokenizer.convert_tokens_to_ids(qa_pair_as_tokens) + qa_turns_ids = self.encode_turns(qa_turns) + + + # Pack the sample. + sample = { + "answer": answer, + "qa_pair": torch.tensor(qa_pair_ids).unsqueeze(0), + "q_len": torch.tensor(q_len).unsqueeze(0), + "qa_turns": qa_turns_ids, + "features": feature + } + return sample + + +class VisualFeatureLoader: + """Loads visual features for SIMMC 2.1 ambiguous candidate identification.""" + + UNAVAILABLE_IMAGES = [ + "cloth_store_1416238_woman_20_6.png", + "cloth_store_1416238_woman_19_0.png", + "cloth_store_1416238_woman_4_8.png", + ] + + def __init__(self, feature_path, feature_size): + """Read the features from the path.""" + self._features = torch.load(feature_path) + self._feature_size = feature_size + self._zero_feature = torch.zeros((1, self._feature_size), dtype=torch.float) + + def __getitem__(self, label): + """Get the feature given image label.""" + assert ( + label in self._features or label in self.UNAVAILABLE_IMAGES + ), f"{label} not found!" + if label in self.UNAVAILABLE_IMAGES: + return self._zero_feature + return self._features[label] + + def cuda(self): + """Move the features to cuda.""" + self._zero_feature = self._zero_feature.cuda() + for key, val in self._features.items(): + self._features[key] = val.cuda() diff --git a/src/utils/simmc2_dataset/format_data.py b/src/utils/simmc2_dataset/format_data.py new file mode 100644 index 0000000..1c770a2 --- /dev/null +++ b/src/utils/simmc2_dataset/format_data.py @@ -0,0 +1,150 @@ +#! /usr/bin/env python +""" +Copyright (c) Facebook, Inc. and its affiliates. +All rights reserved. +This source code is licensed under the license found in the LICENSE file in the +root directory of this source tree. + +Reads SIMMC 2.1 dataset, creates train, devtest, dev formats for ambiguous candidates. + +Author(s): Satwik Kottur +""" + +from __future__ import absolute_import, division, print_function, unicode_literals + +import argparse +import copy +import json +import os + + +SPLITS = ["train", "dev", "devtest", "teststd"] + + +def get_image_name(scene_ids, turn_ind): + """Given scene ids and turn index, get the image name. + """ + sorted_scene_ids = sorted( + ((int(key), val) for key, val in scene_ids.items()), + key=lambda x: x[0], + reverse=True + ) + # NOTE: Hardcoded to only two scenes. + if turn_ind >= sorted_scene_ids[0][0]: + scene_label = sorted_scene_ids[0][1] + else: + scene_label = sorted_scene_ids[1][1] + image_label = scene_label + if "m_" in scene_label: + image_label = image_label.replace("m_", "") + return f"{image_label}.png", scene_label + + +def get_object_mapping(scene_label, args): + """Get the object mapping for a given scene. 
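+ Returns the indices of the objects listed in the first scene of
+ <scene_label>_scene.json.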
+ """ + scene_json_path = os.path.join( + args["scene_json_folder"], f"{scene_label}_scene.json" + ) + with open(scene_json_path, "r") as file_id: + scene_objects = json.load(file_id)["scenes"][0]["objects"] + object_map = [ii["index"] for ii in scene_objects] + return object_map + + +def main(args): + for split in SPLITS: + read_path = args[f"simmc_{split}_json"] + print(f"Reading: {read_path}") + with open(read_path, "r") as file_id: + dialogs = json.load(file_id) + + # Reformat into simple strings with positive and negative labels. + # (dialog string, label) + ambiguous_candidates_data = [] + for dialog_id, dialog_datum in enumerate(dialogs["dialogue_data"]): + turns = [] + q_turns = [] + a_turns = [] + for turn_ind, turn_datum in enumerate(dialog_datum["dialogue"]): + query = [turn_datum["transcript"]] + answer = [turn_datum["system_transcript"]] + + #annotations = turn_datum["transcript_annotated"] + #if annotations.get("disambiguation_label", False): + #label = annotations["disambiguation_candidates"] + image_name, scene_label = get_image_name( + dialog_datum["scene_ids"], turn_ind + ) + # If dialog contains multiple scenes, map it accordingly. + object_map = get_object_mapping(scene_label, args) + new_datum = { + "query": query, + "answer": answer, + "q_turns": copy.deepcopy(q_turns), + "a_turns": copy.deepcopy(a_turns), + "turns": copy.deepcopy(turns), + "dialog_id": dialog_datum["dialogue_idx"], + "turn_id": turn_ind, + #"input_text": copy.deepcopy(history), + #"ambiguous_candidates": label, + "image_name": image_name, + "object_map": object_map, + } + + ambiguous_candidates_data.append(new_datum) + + turns.append([turn_datum["transcript"] + turn_datum["system_transcript"]]) + q_turns.append(query) + a_turns.append(answer) + + + # Ignore if system_transcript is not found (last round teststd). 
+ # if turn_datum.get("system_transcript", None): + # history.append(turn_datum["system_transcript"]) + + print(f"# instances [{split}]: {len(ambiguous_candidates_data)}") + save_path = os.path.join( + args["ambiguous_candidates_save_path"], + f"simmc2.1_ambiguous_candidates_dstc11_{split}.json" + ) + print(f"Saving: {save_path}") + with open(save_path, "w") as file_id: + json.dump( + { + "source_path": read_path, + "split": split, + "data": ambiguous_candidates_data, + }, + file_id + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--simmc_train_json", default=None, help="Path to SIMMC 2.1 train" + ) + parser.add_argument( + "--simmc_dev_json", default=None, help="Path to SIMMC 2.1 dev" + ) + parser.add_argument( + "--simmc_devtest_json", default=None, help="Path to SIMMC 2.1 devtest" + ) + parser.add_argument( + "--simmc_teststd_json", default=None, help="Path to SIMMC 2.1 teststd (public)" + ) + parser.add_argument( + "--scene_json_folder", default=None, help="Path to SIMMC scene jsons" + ) + parser.add_argument( + "--ambiguous_candidates_save_path", + required=True, + help="Path to save SIMMC disambiguate JSONs", + ) + + try: + parsed_args = vars(parser.parse_args()) + except (IOError) as msg: + parser.error(str(msg)) + main(parsed_args) diff --git a/src/utils/simmc2_dataset/format_data.sh b/src/utils/simmc2_dataset/format_data.sh new file mode 100755 index 0000000..14881fd --- /dev/null +++ b/src/utils/simmc2_dataset/format_data.sh @@ -0,0 +1,8 @@ +#!/bin/bash +DATA_FOLDER="../../data/" +python format_data.py \ + --simmc_train_json "/scratch/hochmeister/simmc2/data/simmc2.1_dials_dstc11_train.json" \ + --simmc_dev_json "/scratch/hochmeister/simmc2/data/simmc2.1_dials_dstc11_dev.json" \ + --simmc_devtest_json "/scratch/hochmeister/simmc2/data/simmc2.1_dials_dstc11_devtest.json" \ + --scene_json_folder "/scratch/hochmeister/simmc2/data/public/" \ + --ambiguous_candidates_save_path "/scratch/hochmeister/simmc2/data/ambiguous_candidates/" \ No newline at end of file diff --git a/src/utils/simmc2_dataset/format_data_subtask4_b.py b/src/utils/simmc2_dataset/format_data_subtask4_b.py new file mode 100644 index 0000000..df0f205 --- /dev/null +++ b/src/utils/simmc2_dataset/format_data_subtask4_b.py @@ -0,0 +1,224 @@ +#! /usr/bin/env python +""" +Copyright (c) Facebook, Inc. and its affiliates. +All rights reserved. +This source code is licensed under the license found in the LICENSE file in the +root directory of this source tree. + +Reads SIMMC 2.1 dataset, creates train, devtest, dev formats for ambiguous candidates. + +Author(s): Satwik Kottur +""" + +from __future__ import absolute_import, division, print_function, unicode_literals + +import argparse +import copy +import json +import os +import random + + +SPLITS = ["train", "dev", "devtest", "teststd"] + + +def get_image_name(scene_ids, turn_ind): + """Given scene ids and turn index, get the image name. + """ + sorted_scene_ids = sorted( + ((int(key), val) for key, val in scene_ids.items()), + key=lambda x: x[0], + reverse=True + ) + # NOTE: Hardcoded to only two scenes. + if turn_ind >= sorted_scene_ids[0][0]: + scene_label = sorted_scene_ids[0][1] + else: + scene_label = sorted_scene_ids[1][1] + image_label = scene_label + if "m_" in scene_label: + image_label = image_label.replace("m_", "") + return f"{image_label}.png", scene_label + + +def get_object_mapping(scene_label, args): + """Get the object mapping for a given scene. 
+ """ + scene_json_path = os.path.join( + args["scene_json_folder"], f"{scene_label}_scene.json" + ) + with open(scene_json_path, "r") as file_id: + scene_objects = json.load(file_id)["scenes"][0]["objects"] + object_map = [ii["index"] for ii in scene_objects] + return object_map + + +def dictionary_to_string(dictionary): + result = "" + for k, v in dictionary.items(): + result += k + ":" + result += str(v) + " " + return result + + +def get_all_answers(dialogs): + all_answers = [] + for dialog_id, dialog_datum in enumerate(dialogs["dialogue_data"]): + for turn_ind, turn_datum in enumerate(dialog_datum["dialogue"]): + all_answers.append(turn_datum["system_transcript"]) + return all_answers + + +def main(args): + for split in SPLITS: + read_path = args[f"simmc_{split}_json"] + print(f"Reading: {read_path}") + with open(read_path, "r") as file_id: + dialogs = json.load(file_id) + + + # load the metadata files + with open(args["furniture_prefab_metadata"], "r") as file: + furniture_metadata = json.load(file) + + with open(args["fashion_prefab_metadata"], "r") as file: + fashion_metadata = json.load(file) + + # get all answer fromm all dialogues to sample answer candidates from for each dialogue iteration + all_answers = get_all_answers(dialogs) + + + # Reformat into simple strings with positive and negative labels. + # (dialog string, label) + ambiguous_candidates_data = [] + for dialog_id, dialog_datum in enumerate(dialogs["dialogue_data"]): + turns = [] + q_turns = [] + a_turns = [] + + for turn_ind, turn_datum in enumerate(dialog_datum["dialogue"]): + query = [turn_datum["transcript"]] + answer = [turn_datum["system_transcript"]] + answer_candidates = [] + + # sample random answers from the list of all answers as answer candidates + # sample n_answer_candidates - 1 wrong answer candidates from the list of all answers + for _ in range(int(args["n_answer_candidates"]) - 1): + random_idx = random.randint(0, len(all_answers) - 1) + answer_candidates.append([all_answers[random_idx]]) + answer_candidates.insert(0, answer) + #random.shuffle(answer_candidates) + + #annotations = turn_datum["transcript_annotated"] + #if annotations.get("disambiguation_label", False): + #label = annotations["disambiguation_candidates"] + image_name, scene_id = get_image_name( + dialog_datum["scene_ids"], turn_ind + ) + + # load the scene files and get all the prefab pahts to get the object descriptions for each scene + prefab_paths = [] + scene_path = os.path.join(args["scene_json_folder"], f"{scene_id}_scene.json") + with open(scene_path, "r") as scene_file: + scene_data = json.load(scene_file) + for scene in scene_data["scenes"]: + for object in scene["objects"]: + prefab_paths.append(object["prefab_path"]) + + # get the metadata for all objects of the scene (prefab_paths) + object_metadata = [] + for prefab_path in prefab_paths: + if scene_id[:11] in ["cloth_store", "m_cloth_sto"]: + object_dict = fashion_metadata[prefab_path] + elif scene_id[:7] == "wayfair": + object_dict = furniture_metadata[prefab_path] + object_str = dictionary_to_string(object_dict) + object_metadata.append([object_str]) + + + # If dialog contains multiple scenes, map it accordingly. 
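+ # NOTE: the ground-truth answer stays at index 0 of answer_candidates because the
+ # shuffle above is commented out, so position 0 marks the positive candidate.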
+ #object_map = get_object_mapping(scene_label, args) + new_datum = { + "query": query, + "answer": answer, + "answer_candidates": answer_candidates, + "q_turns": copy.deepcopy(q_turns), + "a_turns": copy.deepcopy(a_turns), + "turns": copy.deepcopy(turns), + "object_metadata": object_metadata, + "dialog_id": dialog_datum["dialogue_idx"], + "turn_id": turn_ind, + #"input_text": copy.deepcopy(history), + #"ambiguous_candidates": label, + "image_name": image_name, + #"object_map": object_map, + } + + ambiguous_candidates_data.append(new_datum) + + turns.append([turn_datum["transcript"] + turn_datum["system_transcript"]]) + q_turns.append(query) + a_turns.append(answer) + + + # Ignore if system_transcript is not found (last round teststd). + # if turn_datum.get("system_transcript", None): + # history.append(turn_datum["system_transcript"]) + + print(f"# instances [{split}]: {len(ambiguous_candidates_data)}") + save_path = os.path.join( + args["ambiguous_candidates_save_path"], + f"simmc2.1_ambiguous_candidates_dstc11_{split}.json" + ) + print(f"Saving: {save_path}") + with open(save_path, "w") as file_id: + json.dump( + { + "source_path": read_path, + "split": split, + "data": ambiguous_candidates_data, + }, + file_id + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--simmc_train_json", default=None, help="Path to SIMMC 2.1 train" + ) + parser.add_argument( + "--simmc_dev_json", default=None, help="Path to SIMMC 2.1 dev" + ) + parser.add_argument( + "--simmc_devtest_json", default=None, help="Path to SIMMC 2.1 devtest" + ) + parser.add_argument( + "--simmc_teststd_json", default=None, help="Path to SIMMC 2.1 teststd (public)" + ) + parser.add_argument( + "--scene_json_folder", default=None, help="Path to SIMMC scene jsons" + ) + parser.add_argument( + "--ambiguous_candidates_save_path", + required=True, + help="Path to save SIMMC disambiguate JSONs", + ) + parser.add_argument( + "--fashion_prefab_metadata", required=True, + help="Path to the file with all metadata for fashion objects" + ) + parser.add_argument( + "--furniture_prefab_metadata", required=True, + help="Path to the file with all metadata for fashion objects" + ) + parser.add_argument( + "--n_answer_candidates", required=True, + help="number of answer candidates for the ranking task" + ) + + try: + parsed_args = vars(parser.parse_args()) + except (IOError) as msg: + parser.error(str(msg)) + main(parsed_args) diff --git a/src/utils/simmc2_dataset/format_data_subtask4_b.sh b/src/utils/simmc2_dataset/format_data_subtask4_b.sh new file mode 100755 index 0000000..352ea5e --- /dev/null +++ b/src/utils/simmc2_dataset/format_data_subtask4_b.sh @@ -0,0 +1,11 @@ +#!/bin/bash +DATA_FOLDER="../../data/" +python format_data_with_object_descriptions.py \ + --simmc_train_json "/scratch/hochmeister/simmc2/data/simmc2.1_dials_dstc11_train.json" \ + --simmc_dev_json "/scratch/hochmeister/simmc2/data/simmc2.1_dials_dstc11_dev.json" \ + --simmc_devtest_json "/scratch/hochmeister/simmc2/data/simmc2.1_dials_dstc11_devtest.json" \ + --scene_json_folder "/scratch/hochmeister/simmc2/data/public/" \ + --ambiguous_candidates_save_path "/scratch/hochmeister/simmc2/data/subtask_4_b_data/" + --fashion_prefab_metadata "/scratch/hochmeister/simmc2/data/fashion_prefab_metadata_all.json" + --furniture_prefab_metadata "/scratch/hochmeister/simmc2/data/furniture_prefab_metadata_all.json" + --n_answer_candidates 10 \ No newline at end of file diff --git 
a/src/utils/simmc2_dataset/format_data_subtask4_mlm_nsp.py b/src/utils/simmc2_dataset/format_data_subtask4_mlm_nsp.py new file mode 100644 index 0000000..f7bac83 --- /dev/null +++ b/src/utils/simmc2_dataset/format_data_subtask4_mlm_nsp.py @@ -0,0 +1,207 @@ +#! /usr/bin/env python +""" +Copyright (c) Facebook, Inc. and its affiliates. +All rights reserved. +This source code is licensed under the license found in the LICENSE file in the +root directory of this source tree. + +Reads SIMMC 2.1 dataset, creates train, devtest, dev formats for ambiguous candidates. + +Author(s): Satwik Kottur +""" + +from __future__ import absolute_import, division, print_function, unicode_literals + +import argparse +import copy +import json +import os +import random + + +SPLITS = ["train", "dev", "devtest", "teststd"] + + +def get_image_name(scene_ids, turn_ind): + """Given scene ids and turn index, get the image name. + """ + sorted_scene_ids = sorted( + ((int(key), val) for key, val in scene_ids.items()), + key=lambda x: x[0], + reverse=True + ) + # NOTE: Hardcoded to only two scenes. + if turn_ind >= sorted_scene_ids[0][0]: + scene_label = sorted_scene_ids[0][1] + else: + scene_label = sorted_scene_ids[1][1] + image_label = scene_label + if "m_" in scene_label: + image_label = image_label.replace("m_", "") + return f"{image_label}.png", scene_label + + +def get_object_mapping(scene_label, args): + """Get the object mapping for a given scene. + """ + scene_json_path = os.path.join( + args["scene_json_folder"], f"{scene_label}_scene.json" + ) + with open(scene_json_path, "r") as file_id: + scene_objects = json.load(file_id)["scenes"][0]["objects"] + object_map = [ii["index"] for ii in scene_objects] + return object_map + + +def dictionary_to_string(dictionary): + result = "" + for k, v in dictionary.items(): + result += k + ":" + result += str(v) + " " + return result + + +def get_all_answers(dialogs): + all_answers = [] + for dialog_id, dialog_datum in enumerate(dialogs["dialogue_data"]): + for turn_ind, turn_datum in enumerate(dialog_datum["dialogue"]): + all_answers.append(turn_datum["system_transcript"]) + return all_answers + + +def main(args): + for split in SPLITS: + read_path = args[f"simmc_{split}_json"] + print(f"Reading: {read_path}") + with open(read_path, "r") as file_id: + dialogs = json.load(file_id) + + # get all answer fromm all dialogues to sample answer candidates from for each dialogue iteration + all_answers = get_all_answers(dialogs) + + ambiguous_candidates_data = [] + for dialog_id, dialog_datum in enumerate(dialogs["dialogue_data"]): + q_turns = [] + a_turns = [] + qa_turns = [] + + for turn_ind, turn_datum in enumerate(dialog_datum["dialogue"]): + query = turn_datum["transcript"] + answer = turn_datum["system_transcript"] + + # wrong answer is used to create false sample for nsp + wrong_answer = random.choice(all_answers) + + qa_pair = query + '[SEP_1]' + answer + '[SEP]' + wrong_qa_pair = query + '[SEP_1]' + wrong_answer + '[SEP]' + + image_name, scene_id = get_image_name( + dialog_datum["scene_ids"], turn_ind + ) + + # load the scene files and get all the prefab pahts to get the object descriptions for each scene + prefab_paths = [] + scene_path = os.path.join(args["scene_json_folder"], f"{scene_id}_scene.json") + with open(scene_path, "r") as scene_file: + scene_data = json.load(scene_file) + for scene in scene_data["scenes"]: + for object in scene["objects"]: + prefab_paths.append(object["prefab_path"]) + + # for each dialogue round add a sample with the correct answer and one 
with a random answer for nsp + new_datum_correct_answer = { + "query": [query], + "answer": [answer], + "qa_pair": [qa_pair], + "next_sentence_label": [1], + "q_turns": copy.deepcopy(q_turns), + "a_turns": copy.deepcopy(a_turns), + "qa_turns": copy.deepcopy(qa_turns), + "dialog_id": dialog_datum["dialogue_idx"], + "turn_id": turn_ind, + "image_name": image_name, + } + new_datum_wrong_answer = { + "query": [query], + "answer": [wrong_answer], + "qa_pair": [wrong_qa_pair], + "next_sentence_label": [0], + "q_turns": copy.deepcopy(q_turns), + "a_turns": copy.deepcopy(a_turns), + "qa_turns": copy.deepcopy(qa_turns), + "dialog_id": dialog_datum["dialogue_idx"], + "turn_id": turn_ind, + "image_name": image_name, + } + + ambiguous_candidates_data.append(new_datum_correct_answer) + + if args['create_false_samples_for_nsp']: + ambiguous_candidates_data.append(new_datum_wrong_answer) + + q_turns.append([query]) + a_turns.append([answer]) + qa_turns.append([qa_pair]) + + + # Ignore if system_transcript is not found (last round teststd). + # if turn_datum.get("system_transcript", None): + # history.append(turn_datum["system_transcript"]) + + print(f"# instances [{split}]: {len(ambiguous_candidates_data)}") + save_path = os.path.join( + args["ambiguous_candidates_save_path"], + f"simmc2.1_ambiguous_candidates_dstc11_{split}.json" + ) + print(f"Saving: {save_path}") + with open(save_path, "w") as file_id: + json.dump( + { + "source_path": read_path, + "split": split, + "data": ambiguous_candidates_data, + }, + file_id + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--simmc_train_json", default=None, help="Path to SIMMC 2.1 train" + ) + parser.add_argument( + "--simmc_dev_json", default=None, help="Path to SIMMC 2.1 dev" + ) + parser.add_argument( + "--simmc_devtest_json", default=None, help="Path to SIMMC 2.1 devtest" + ) + parser.add_argument( + "--simmc_teststd_json", default=None, help="Path to SIMMC 2.1 teststd (public)" + ) + parser.add_argument( + "--scene_json_folder", default=None, help="Path to SIMMC scene jsons" + ) + parser.add_argument( + "--ambiguous_candidates_save_path", + required=True, + help="Path to save SIMMC disambiguate JSONs", + ) + parser.add_argument( + "--fashion_prefab_metadata", required=True, + help="Path to the file with all metadata for fashion objects" + ) + parser.add_argument( + "--furniture_prefab_metadata", required=True, + help="Path to the file with all metadata for fashion objects" + ) + parser.add_argument( + "--create_false_samples_for_nsp", action='store_true', + help="if set, for each correct sample a wrong one is added" + ) + + try: + parsed_args = vars(parser.parse_args()) + except (IOError) as msg: + parser.error(str(msg)) + main(parsed_args) diff --git a/src/utils/simmc2_dataset/format_data_with_obj_descriptions.py b/src/utils/simmc2_dataset/format_data_with_obj_descriptions.py new file mode 100644 index 0000000..5d10542 --- /dev/null +++ b/src/utils/simmc2_dataset/format_data_with_obj_descriptions.py @@ -0,0 +1,202 @@ +#! /usr/bin/env python +""" +Copyright (c) Facebook, Inc. and its affiliates. +All rights reserved. +This source code is licensed under the license found in the LICENSE file in the +root directory of this source tree. + +Reads SIMMC 2.1 dataset, creates train, devtest, dev formats for ambiguous candidates. 
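+In addition, every sample gets a textual description of each object in its scene,
+built from the fashion/furniture prefab metadata files.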
+ +Author(s): Satwik Kottur +""" + +from __future__ import absolute_import, division, print_function, unicode_literals + +import argparse +import copy +import json +import os + + +SPLITS = ["train", "dev", "devtest", "teststd"] + + +def get_image_name(scene_ids, turn_ind): + """Given scene ids and turn index, get the image name. + """ + sorted_scene_ids = sorted( + ((int(key), val) for key, val in scene_ids.items()), + key=lambda x: x[0], + reverse=True + ) + # NOTE: Hardcoded to only two scenes. + if turn_ind >= sorted_scene_ids[0][0]: + scene_label = sorted_scene_ids[0][1] + else: + scene_label = sorted_scene_ids[1][1] + image_label = scene_label + if "m_" in scene_label: + image_label = image_label.replace("m_", "") + return f"{image_label}.png", scene_label + + +def get_object_mapping(scene_label, args): + """Get the object mapping for a given scene. + """ + scene_json_path = os.path.join( + args["scene_json_folder"], f"{scene_label}_scene.json" + ) + with open(scene_json_path, "r") as file_id: + scene_objects = json.load(file_id)["scenes"][0]["objects"] + object_map = [ii["index"] for ii in scene_objects] + return object_map + + +def dictionary_to_string(dictionary): + result = "" + for k, v in dictionary.items(): + if k in ['assetType', 'color', 'pattern', 'sleeveLength', 'type']: + continue + result += k + ":" + result += str(v) + " " + return result + + +def main(args): + for split in SPLITS: + read_path = args[f"simmc_{split}_json"] + print(f"Reading: {read_path}") + with open(read_path, "r") as file_id: + dialogs = json.load(file_id) + + + # load the metadata files + with open(args["furniture_prefab_metadata"], "r") as file: + furniture_metadata = json.load(file) + + with open(args["fashion_prefab_metadata"], "r") as file: + fashion_metadata = json.load(file) + + + # Reformat into simple strings with positive and negative labels. + # (dialog string, label) + ambiguous_candidates_data = [] + for dialog_id, dialog_datum in enumerate(dialogs["dialogue_data"]): + turns = [] + q_turns = [] + a_turns = [] + + for turn_ind, turn_datum in enumerate(dialog_datum["dialogue"]): + query = [turn_datum["transcript"]] + if "system_transcript" not in turn_datum.keys(): + continue + answer = [turn_datum["system_transcript"]] + + #annotations = turn_datum["transcript_annotated"] + #if annotations.get("disambiguation_label", False): + #label = annotations["disambiguation_candidates"] + image_name, scene_id = get_image_name( + dialog_datum["scene_ids"], turn_ind + ) + + # load the scene files and get all the prefab pahts to get the object descriptions for each scene + prefab_paths = [] + scene_path = os.path.join(args["scene_json_folder"], f"{scene_id}_scene.json") + with open(scene_path, "r") as scene_file: + scene_data = json.load(scene_file) + for scene in scene_data["scenes"]: + for object in scene["objects"]: + prefab_paths.append(object["prefab_path"]) + + # get the metadata for all objects of the scene (prefab_paths) + object_metadata = [] + for prefab_path in prefab_paths: + if scene_id[:11] in ["cloth_store", "m_cloth_sto"]: + object_dict = fashion_metadata[prefab_path] + elif scene_id[:7] == "wayfair": + object_dict = furniture_metadata[prefab_path] + object_str = dictionary_to_string(object_dict) + object_metadata.append([object_str]) + + + # If dialog contains multiple scenes, map it accordingly. 
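+ # NOTE: dictionary_to_string above skips assetType, color, pattern, sleeveLength and
+ # type, so only the remaining metadata fields end up in the object descriptions.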
+ #object_map = get_object_mapping(scene_label, args) + new_datum = { + "query": query, + "answer": answer, + "q_turns": copy.deepcopy(q_turns), + "a_turns": copy.deepcopy(a_turns), + "turns": copy.deepcopy(turns), + "object_metadata": object_metadata, + "dialog_id": dialog_datum["dialogue_idx"], + "turn_id": turn_ind, + #"input_text": copy.deepcopy(history), + #"ambiguous_candidates": label, + "image_name": image_name, + #"object_map": object_map, + } + + ambiguous_candidates_data.append(new_datum) + + turns.append([turn_datum["transcript"] + turn_datum["system_transcript"]]) + q_turns.append(query) + a_turns.append(answer) + + + # Ignore if system_transcript is not found (last round teststd). + # if turn_datum.get("system_transcript", None): + # history.append(turn_datum["system_transcript"]) + + print(f"# instances [{split}]: {len(ambiguous_candidates_data)}") + save_path = os.path.join( + args["ambiguous_candidates_save_path"], + f"simmc2.1_ambiguous_candidates_dstc11_{split}.json" + ) + print(f"Saving: {save_path}") + with open(save_path, "w") as file_id: + json.dump( + { + "source_path": read_path, + "split": split, + "data": ambiguous_candidates_data, + }, + file_id + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--simmc_train_json", default=None, help="Path to SIMMC 2.1 train" + ) + parser.add_argument( + "--simmc_dev_json", default=None, help="Path to SIMMC 2.1 dev" + ) + parser.add_argument( + "--simmc_devtest_json", default=None, help="Path to SIMMC 2.1 devtest" + ) + parser.add_argument( + "--simmc_teststd_json", default=None, help="Path to SIMMC 2.1 teststd (public)" + ) + parser.add_argument( + "--scene_json_folder", default=None, help="Path to SIMMC scene jsons" + ) + parser.add_argument( + "--ambiguous_candidates_save_path", + required=True, + help="Path to save SIMMC disambiguate JSONs", + ) + parser.add_argument( + "--fashion_prefab_metadata", required=True, + help="Path to the file with all metadata for fashion objects" + ) + parser.add_argument( + "--furniture_prefab_metadata", required=True, + help="Path to the file with all metadata for fashion objects" + ) + + try: + parsed_args = vars(parser.parse_args()) + except (IOError) as msg: + parser.error(str(msg)) + main(parsed_args) diff --git a/src/utils/simmc2_dataset/format_data_with_obj_descriptions.sh b/src/utils/simmc2_dataset/format_data_with_obj_descriptions.sh new file mode 100755 index 0000000..f5f0b75 --- /dev/null +++ b/src/utils/simmc2_dataset/format_data_with_obj_descriptions.sh @@ -0,0 +1,10 @@ +#!/bin/bash +DATA_FOLDER="../../data/" +python format_data_with_object_descriptions.py \ + --simmc_train_json "/scratch/hochmeister/simmc2/data/simmc2.1_dials_dstc11_train.json" \ + --simmc_dev_json "/scratch/hochmeister/simmc2/data/simmc2.1_dials_dstc11_dev.json" \ + --simmc_devtest_json "/scratch/hochmeister/simmc2/data/simmc2.1_dials_dstc11_devtest.json" \ + --scene_json_folder "/scratch/hochmeister/simmc2/data/public/" \ + --ambiguous_candidates_save_path "/scratch/hochmeister/simmc2/data/ambiguous_candidates/" + --fashion_prefab_metadata "/scratch/hochmeister/simmc2/data/fashion_prefab_metadata_all.json" + --furniture_prefab_metadata "/scratch/hochmeister/simmc2/data/furniture_prefab_metadata_all.json" \ No newline at end of file diff --git a/src/utils/simmc2_dataset/format_data_with_obj_descriptions_devtest10.py b/src/utils/simmc2_dataset/format_data_with_obj_descriptions_devtest10.py new file mode 100644 index 0000000..e8374f0 --- /dev/null 
+++ b/src/utils/simmc2_dataset/format_data_with_obj_descriptions_devtest10.py @@ -0,0 +1,206 @@ +#! /usr/bin/env python +""" +Copyright (c) Facebook, Inc. and its affiliates. +All rights reserved. +This source code is licensed under the license found in the LICENSE file in the +root directory of this source tree. + +Reads SIMMC 2.1 dataset, creates train, devtest, dev formats for ambiguous candidates. + +Author(s): Satwik Kottur +""" + +from __future__ import absolute_import, division, print_function, unicode_literals + +import argparse +import copy +import json +import os + + +SPLITS = ["train", "dev", "devtest", "teststd"] + + +def get_image_name(scene_ids, turn_ind): + """Given scene ids and turn index, get the image name. + """ + sorted_scene_ids = sorted( + ((int(key), val) for key, val in scene_ids.items()), + key=lambda x: x[0], + reverse=True + ) + # NOTE: Hardcoded to only two scenes. + if turn_ind >= sorted_scene_ids[0][0]: + scene_label = sorted_scene_ids[0][1] + else: + scene_label = sorted_scene_ids[1][1] + image_label = scene_label + if "m_" in scene_label: + image_label = image_label.replace("m_", "") + return f"{image_label}.png", scene_label + + +def get_object_mapping(scene_label, args): + """Get the object mapping for a given scene. + """ + scene_json_path = os.path.join( + args["scene_json_folder"], f"{scene_label}_scene.json" + ) + with open(scene_json_path, "r") as file_id: + scene_objects = json.load(file_id)["scenes"][0]["objects"] + object_map = [ii["index"] for ii in scene_objects] + return object_map + + +def dictionary_to_string(dictionary): + result = "" + for k, v in dictionary.items(): + if k in ['assetType', 'color', 'pattern', 'sleeveLength', 'type']: + continue + result += k + ":" + result += str(v) + " " + return result + + +def main(args): + for split in SPLITS: + read_path = args[f"simmc_{split}_json"] + print(f"Reading: {read_path}") + with open(read_path, "r") as file_id: + dialogs = json.load(file_id) + + + # load the metadata files + with open(args["furniture_prefab_metadata"], "r") as file: + furniture_metadata = json.load(file) + + with open(args["fashion_prefab_metadata"], "r") as file: + fashion_metadata = json.load(file) + + + # Reformat into simple strings with positive and negative labels. 
+ # (dialog string, label) + ambiguous_candidates_data = [] + for dialog_id, dialog_datum in enumerate(dialogs["dialogue_data"]): + turns = [] + q_turns = [] + a_turns = [] + dial_len = len(dialog_datum['dialogue']) + + for turn_ind, turn_datum in enumerate(dialog_datum["dialogue"]): + query = [turn_datum["transcript"]] + if "system_transcript" not in turn_datum.keys(): + answer = "" + else: + answer = [turn_datum["system_transcript"]] + + #annotations = turn_datum["transcript_annotated"] + #if annotations.get("disambiguation_label", False): + #label = annotations["disambiguation_candidates"] + image_name, scene_id = get_image_name( + dialog_datum["scene_ids"], turn_ind + ) + + # load the scene files and get all the prefab pahts to get the object descriptions for each scene + prefab_paths = [] + scene_path = os.path.join(args["scene_json_folder"], f"{scene_id}_scene.json") + with open(scene_path, "r") as scene_file: + scene_data = json.load(scene_file) + for scene in scene_data["scenes"]: + for object in scene["objects"]: + prefab_paths.append(object["prefab_path"]) + + # get the metadata for all objects of the scene (prefab_paths) + object_metadata = [] + for prefab_path in prefab_paths: + if scene_id[:11] in ["cloth_store", "m_cloth_sto"]: + object_dict = fashion_metadata[prefab_path] + elif scene_id[:7] == "wayfair": + object_dict = furniture_metadata[prefab_path] + object_str = dictionary_to_string(object_dict) + object_metadata.append([object_str]) + + + # If dialog contains multiple scenes, map it accordingly. + #object_map = get_object_mapping(scene_label, args) + new_datum = { + "query": query, + "answer": answer, + "q_turns": copy.deepcopy(q_turns), + "a_turns": copy.deepcopy(a_turns), + "turns": copy.deepcopy(turns), + "object_metadata": object_metadata, + "dialog_id": dialog_datum["dialogue_idx"], + "turn_id": turn_ind, + #"input_text": copy.deepcopy(history), + #"ambiguous_candidates": label, + "image_name": image_name, + #"object_map": object_map, + } + + # only the last dialogue turns are used as samples for the test set + if turn_ind == dial_len - 1: + ambiguous_candidates_data.append(new_datum) + else: + turns.append([turn_datum["transcript"] + turn_datum["system_transcript"]]) + q_turns.append(query) + a_turns.append(answer) + + + # Ignore if system_transcript is not found (last round teststd). 
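+ # NOTE: in this variant only the final turn of each dialog becomes a test sample;
+ # all earlier turns are accumulated into the history instead.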
+ # if turn_datum.get("system_transcript", None): + # history.append(turn_datum["system_transcript"]) + + print(f"# instances [{split}]: {len(ambiguous_candidates_data)}") + save_path = os.path.join( + args["ambiguous_candidates_save_path"], + f"simmc2.1_ambiguous_candidates_dstc11_{split}.json" + ) + print(f"Saving: {save_path}") + with open(save_path, "w") as file_id: + json.dump( + { + "source_path": read_path, + "split": split, + "data": ambiguous_candidates_data, + }, + file_id + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--simmc_train_json", default=None, help="Path to SIMMC 2.1 train" + ) + parser.add_argument( + "--simmc_dev_json", default=None, help="Path to SIMMC 2.1 dev" + ) + parser.add_argument( + "--simmc_devtest_json", default=None, help="Path to SIMMC 2.1 devtest" + ) + parser.add_argument( + "--simmc_teststd_json", default=None, help="Path to SIMMC 2.1 teststd (public)" + ) + parser.add_argument( + "--scene_json_folder", default=None, help="Path to SIMMC scene jsons" + ) + parser.add_argument( + "--ambiguous_candidates_save_path", + required=True, + help="Path to save SIMMC disambiguate JSONs", + ) + parser.add_argument( + "--fashion_prefab_metadata", required=True, + help="Path to the file with all metadata for fashion objects" + ) + parser.add_argument( + "--furniture_prefab_metadata", required=True, + help="Path to the file with all metadata for fashion objects" + ) + + try: + parsed_args = vars(parser.parse_args()) + except (IOError) as msg: + parser.error(str(msg)) + main(parsed_args) diff --git a/src/utils/text_utils.py b/src/utils/text_utils.py new file mode 100644 index 0000000..2379544 --- /dev/null +++ b/src/utils/text_utils.py @@ -0,0 +1,15 @@ +import nltk + +def normalize_sentence(sentence): + return nltk.tokenize.word_tokenize(sentence.lower()) + + +def translate_from_ids_to_text(ids, tokenizer): + text = tokenizer.decode(ids) + if '' in text: + text, pad = text.split('', 1) + if '' in text: + text = text[3:] + + #text_as_list = text.split(' ') + return text \ No newline at end of file diff --git a/test.py b/test.py new file mode 100644 index 0000000..03a9a91 --- /dev/null +++ b/test.py @@ -0,0 +1,71 @@ +from src.models.discriminative_model import DiscriminativeModel +from src.models.generative_model import GenerativeModel +from src.data_modules.dvd_data import DVDData +from src.data_modules.simmc2_data import Simmc2Data +from src.data_modules.avsd_data import AvsdData +from pytorch_lightning import Trainer +import pytorch_lightning as pl +from pytorch_lightning.loggers import WandbLogger +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor +import wandb +from config.config import read_default_config, read_config, update_nested_dicts +import argparse + + +parser = argparse.ArgumentParser(description='Test script for OLViT') +parser.add_argument( + '--ckpt_path', + type=str, + help='Path to the checkpoint to be tested') + +parser.add_argument( + '--cfg_path', + type=str, + help='Path to the config file of the selected checkpoint') + + +if __name__ == '__main__': + wandb.finish() + args = parser.parse_args() + + chkpt_path = args.ckpt_path + + # read the default conifg and update the values with the experiment specific config + config = read_default_config() + experiment_config = read_config(args.cfg_path) + config = update_nested_dicts(old_dict=config, update_dict=experiment_config) + + if 'output_path' not in 
config['checkpoint'].keys(): + raise Exception('no output path provided in config (full path for disc model only path to output folder for gen. model)') + + available_models = { + 'discriminative': DiscriminativeModel, + 'generative': GenerativeModel + } + data_modules = { + 'dvd': DVDData, + 'simmc2': Simmc2Data, + } + + wandb_logger = WandbLogger( + entity=config['wandb']['entity'], + name=config['wandb']['name'], + group=config['wandb']['group'], + tags=config['wandb']['tags'], + project=config['wandb']['project'], + config=config + ) + + if config['training']['seed'] != None: + pl.seed_everything(config['training']['seed']) + + trainer = Trainer( + logger=wandb_logger, + accelerator='gpu', + devices=[0] + ) + data = data_modules[config['model']['dataset']](config=config) + + model = available_models[config['model']['model_type']](config=config, output_path=config['checkpoint']['output_path']) + trainer.test(model=model, ckpt_path=chkpt_path, dataloaders=data) diff --git a/train.py b/train.py new file mode 100644 index 0000000..2632df2 --- /dev/null +++ b/train.py @@ -0,0 +1,95 @@ +from src.models.discriminative_model import DiscriminativeModel +from src.models.generative_model import GenerativeModel +from src.data_modules.dvd_data import DVDData +from src.data_modules.simmc2_data import Simmc2Data +from pytorch_lightning import Trainer +import pytorch_lightning as pl +from pytorch_lightning.loggers import WandbLogger +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor +import wandb +from config.config import read_default_config, read_config, update_nested_dicts +import argparse + +parser = argparse.ArgumentParser(description='Train script for OLViT') + +parser.add_argument( + '--cfg_path', + default='config/dvd.json', + type=str, + help='Path to the config file of the selected checkpoint') + + +if __name__ == '__main__': + wandb.finish() + args = parser.parse_args() + # read the default conifg and update the values with the experiment specific config + config = read_default_config() + experiment_config = read_config(args.cfg_path) + config = update_nested_dicts(old_dict=config, update_dict=experiment_config) + + available_models = { + 'discriminative': DiscriminativeModel, + 'generative': GenerativeModel + } + data_modules = { + 'dvd': DVDData, + 'simmc2': Simmc2Data, + } + + monitor_score = { + 'discriminative': 'val_acc', + 'generative': 'bleu4' + } + + checkpoint_cb = pl.callbacks.ModelCheckpoint( + monitor=monitor_score[config['model']['model_type']], mode="max", + save_top_k=1, + dirpath=config["checkpoint"]["checkpoint_folder"], + filename=config["checkpoint"]["checkpoint_file_name"], + every_n_epochs=1 + ) + + lr_monitor_cb = LearningRateMonitor( + logging_interval='step' + ) + + callbacks = [] + callbacks.append(checkpoint_cb) + callbacks.append(lr_monitor_cb) + + wandb_logger = WandbLogger( + offline=True, + entity=config['wandb']['entity'], + name=config['wandb']['name'], + group=config['wandb']['group'], + tags=config['wandb']['tags'], + project=config['wandb']['project'], + config=config + ) + + if config['training']['seed'] != None: + pl.seed_everything(config['training']['seed']) + + trainer = Trainer( + logger=wandb_logger, + # detect_anomaly=True, + accelerator='gpu', + devices=[0], + fast_dev_run=False, + max_epochs=config['training']['epochs'], + check_val_every_n_epoch=1, + log_every_n_steps=1, + strategy=pl.strategies.ddp.DDPStrategy(find_unused_parameters=False), + 
accumulate_grad_batches=config['training']['accumulate_grad_batches'], + precision=32, + callbacks=callbacks + ) + data = data_modules[config['model']['dataset']](config=config) + + if 'output_path' in config['checkpoint'].keys(): + model = available_models[config['model']['model_type']](config=config, output_path=config['checkpoint']['output_path']) + else: + model = available_models[config['model']['model_type']](config=config) + + trainer.fit(model, data)
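+ # NOTE: the ModelCheckpoint callback above keeps only the single best checkpoint
+ # (monitored metric: val_acc for the discriminative model, bleu4 for the generative
+ # one) under config["checkpoint"]["checkpoint_folder"].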