release code base

This commit is contained in commit efbd43fed1
70 changed files with 4923 additions and 0 deletions
2  .gitattributes  vendored  Normal file
@@ -0,0 +1,2 @@
*.tar.gz filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
90  README.md  Normal file
@@ -0,0 +1,90 @@
<div align="center">
<h1> OLViT: Multi-Modal State Tracking via Attention-Based Embeddings for Video-Grounded Dialog </h1>

**[Adnen Abdessaied][4], [Manuel von Hochmeister][5], [Andreas Bulling][6]** <br> <br>
**COLING 2024**, Turin, Italy <img src="misc/italy.png" width="3%" align="center"> <br>
**[[Paper][7]]**

----------------
<img src="misc/teaser.png" width="40%" align="middle"><br><br>
</div>

# Table of Contents
* [Setup and Dependencies](#Setup-and-Dependencies)
* [Download Data](#Download-Data)
* [Training](#Training)
* [Testing](#Testing)
* [Results](#Results)
* [Acknowledgements](#Acknowledgements)

# Setup and Dependencies
We implemented our model using Python 3.7, PyTorch 1.11.0 (CUDA 11.3, cuDNN 8.3.2), and PyTorch Lightning. We recommend setting up a virtual environment using Anaconda. <br>
1. Install [git lfs][1] on your system
2. Clone our repository to download a checkpoint of our best model and our code
```shell
git lfs install
git clone this_repo.git
```
3. Create a conda environment and install dependencies
```shell
conda create -n olvit python=3.7
conda activate olvit
conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch
pip install pytorch-lightning==1.6.3
pip install transformers==4.19.2
pip install torchtext==0.12.0
pip install wandb nltk pandas
```

# Download Data
1. [DVD][2] and [SIMMC 2.1][3] data are included in this repository and will be downloaded using git lfs
2. Set up the data by executing
```shell
chmod u+x setup_data.sh
./setup_data.sh
```
3. This will unpack all the necessary data into ```data/dvd/``` and ```data/simmc/```

# Training
We trained our model on 3 Nvidia Tesla V100-32GB GPUs. The default hyperparameters need to be adjusted if your setup differs from ours.

## DVD
1. Adjust the config file for DVD according to your hardware specifications in ```config/dvd.json```
2. Execute
```shell
CUDA_VISIBLE_DEVICES=0,1,2 python train.py --cfg_path config/dvd.json
```
3. Checkpoints will be saved in ```checkpoints/dvd/```

## SIMMC 2.1
1. Adjust the config file for SIMMC 2.1 according to your hardware specifications in ```config/simmc.json```
2. Execute
```shell
CUDA_VISIBLE_DEVICES=0,1,2 python train.py --cfg_path config/simmc.json
```
3. Checkpoints will be saved in ```checkpoints/simmc/```

# Testing
1. Execute
```shell
CUDA_VISIBLE_DEVICES=0 python test.py --ckpt_path <PATH_TO_TRAINED_MODEL> --cfg_path <PATH_TO_CONFIG_OF_TRAINED_MODEL>
```

# Results
Training with the default config and a hardware setup similar to ours will result in the following performance.

## DVD
<img src="misc/results_dvd.png" width="100%" align="middle"><br><br>

## SIMMC 2.1
<img src="misc/results_simmc.png" width="50%" align="middle"><br><br>

# Acknowledgements
Our work relied on the codebases of [DVD][2] and [SIMMC][3]. Thanks to the authors for sharing their code.

[1]: https://git-lfs.com/
[2]: https://github.com/facebookresearch/DVDialogues/
[3]: https://github.com/facebookresearch/simmc2/
[4]: https://perceptualui.org/people/abdessaied/
[5]: https://www.linkedin.com/in/manuel-von-hochmeister-285416202/
[6]: https://www.perceptualui.org/people/bulling/
[7]: https://drive.google.com/file/d/1sDFfGpQ9E9NahT5gw8UjknWt3sNdxM7p/view?usp=sharing
0  checkpoints/dvd/.gitkeep  Normal file
0  checkpoints/simmc/.gitkeep  Normal file
0  config/__init__.py  Normal file
26  config/config.py  Normal file
@@ -0,0 +1,26 @@
import json
import os


def read_default_config():
    dirpath = os.path.dirname(__file__)
    path = os.path.join(dirpath, "default.json")
    with open(path) as config_file:
        config = json.load(config_file)
    return config


def read_config(path):
    with open(path) as config_file:
        config = json.load(config_file)
    return config


def update_nested_dicts(old_dict, update_dict):
    for key in update_dict:
        if key in old_dict:
            old_dict[key].update(update_dict[key])
        else:
            old_dict[key] = update_dict[key]
    return old_dict
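These helpers keep the dataset configs small: an experiment file such as `config/dvd.json` only has to override the keys it cares about. A minimal usage sketch (illustrative only; the actual training entry point is not included in this excerpt):

```python
from config.config import read_default_config, read_config, update_nested_dicts

# start from config/default.json, then apply experiment-specific overrides
config = read_default_config()
config = update_nested_dicts(config, read_config("config/dvd.json"))

# the merge is one level deep: overrides replace keys inside a section,
# untouched keys in that section keep their default values
print(config["model"]["n_heads"])                  # 6, from dvd.json
print(config["training"]["accumulate_grad_batches"])  # 1, from the defaults
```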
43  config/default.json  Normal file
@@ -0,0 +1,43 @@
{
    "wandb": {
        "entity": "TO_BE_DEFINED",
        "name": "",
        "group": "",
        "tags": [],
        "project": "olvit"
    },
    "model": {
        "model_type": "base_model",
        "feature_type": "none",
        "freeze_roberta": true,
        "v_emb_dim": 16,
        "dim_feedforward": 400,
        "n_heads": 9,
        "fc_dim": 128,
        "dropout_p": 0.1,
        "sample_rate_video": 10,
        "n_encoder_layers": 6,
        "add_choices_as_context": false,
        "use_pretrained_lm": false,
        "projection_as_in_aloe": false,
        "pretrained_lm_name": ""
    },
    "training": {
        "lr": 1e-4,
        "total_steps": 200000,
        "warmup_steps": 4000,
        "accumulate_grad_batches": 1,
        "batch_size": 128,
        "epochs": 40,
        "seed": null
    },
    "datamodule": {
        "fea_dir": "data/dvd/monet_feats/",
        "data_dir": "data/dvd/dialogs/"
    },
    "checkpoint": {
        "checkpoint_folder": "checkpoints/",
        "checkpoint_file_name": "olvit"
    }
}
49  config/dvd.json  Normal file
@@ -0,0 +1,49 @@
{
    "wandb": {
        "name": "olvit",
        "group": "dvd",
        "tags": [],
        "project": "olvit"
    },
    "model": {
        "model_type": "discriminative",
        "n_heads": 6,
        "v_emb_dim": 36,
        "dim_feedforward": 200,
        "dropout_p": 0.1,
        "fc_dim": 512,
        "sample_rate_video": 20,
        "n_transf_layers": 4,
        "use_pretrained_lm": true,
        "projection_as_in_aloe": true,
        "pretrained_lm_name": "distilroberta-base",
        "dataset": "dvd"
    },
    "extended_model": {
        "hist_len_for_state_gen": 7,
        "number_of_relevant_emb": 2,
        "num_layers_v_state": 2,
        "num_layers_d_state": 2,
        "combiner_option": "OptionA",
        "state_tracker_type": "Transformer",
        "use_v_state": true,
        "use_d_state": true,
        "n_heads_combiner_transformer": 8,
        "n_heads_state_tracker": 6,
        "dim_feedforward_v_transformer": 140,
        "dim_feedforward_d_transformer": 60
    },
    "training": {
        "lr": 1e-4,
        "warmup_steps": 4000,
        "total_steps": 200000,
        "batch_size": 128,
        "seed": 12345,
        "epochs": 1000
    },
    "checkpoint": {
        "checkpoint_folder": "checkpoints/dvd",
        "checkpoint_file_name": "olvit"
    }
}
61  config/simmc.json  Normal file
@@ -0,0 +1,61 @@
{
    "wandb": {
        "name": "olvit",
        "group": "simmc2",
        "tags": [],
        "project": "olvit"
    },
    "model": {
        "model_type": "generative",
        "dataset": "simmc2",
        "feature_type": "object_text_features",
        "object_feature_generator_dim": 50,
        "n_object_feature_generator_layers": 2,
        "n_heads": 6,
        "v_emb_dim": 516,
        "emb_dim": 216,
        "dim_feedforward": 200,
        "dropout_p": 0.1,
        "fc_dim": 512,
        "sample_rate_video": 1,
        "n_encoder_layers": 4,
        "n_decoder_layers": 4,
        "use_pretrained_lm": true,
        "vocab_size": 50265,
        "projection_as_in_aloe": false,
        "pretrained_lm_name": "distilroberta-base"
    },
    "extended_model": {
        "hist_len_for_state_gen": 3,
        "number_of_relevant_emb": 2,
        "num_layers_v_state": 2,
        "num_layers_d_state": 2,
        "combiner_option": "OptionA",
        "state_tracker_type": "Transformer",
        "use_v_state": true,
        "use_d_state": true,
        "n_heads_combiner_transformer": 8,
        "n_heads_state_tracker": 6,
        "dim_feedforward_v_transformer": 140,
        "dim_feedforward_d_transformer": 60
    },
    "training": {
        "lr": 1e-4,
        "warmup_steps": 4000,
        "total_steps": 200000,
        "batch_size": 8,
        "seed": 12345,
        "epochs": 1000
    },
    "datamodule": {
        "fea_dir": "data/simmc/visual_features_resnet50_simmc2.1.pt",
        "data_dir": "data/simmc/dialogs"
    },
    "checkpoint": {
        "checkpoint_folder": "checkpoints/simmc/",
        "checkpoint_file_name": "olvit",
        "output_path": "output/simmc/",
        "checkpoint_path": "TO_BE_DETERMINED"
    }
}
3  data/dvd/dialogs.tar.gz  Normal file
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b1b58ee7af90b402eddbde8470dc0333b83ae293a90a93d26af3b8c39c2d9b0e
size 395953476

3  data/dvd/monet_feats_part00.tar.gz  Normal file
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:933c88dbf854d11fca34c388b1b566096b4f9733abd2ded0a1d381b4b1c6a379
size 1582620496

3  data/dvd/monet_feats_part01.tar.gz  Normal file
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c07f88af54843010899ed1149d16343b9aeb38dbd2cb4e1977bb4c2436d461ec
size 1582620496

3  data/simmc/dialogs.tar.gz  Normal file
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:65ed3852c6bbe9f3135558f1bfd3900e8c37ae9af7b8338b3535987408086ca6
size 12956266

3  data/simmc/visual_features_resnet50_simmc2.1.pt  Normal file
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7f7aa24ce312e0cdbdb69021ce593aa985074e3ec88a737bc7af8060ff61d6a8
size 81394479
0  misc/.gitkeep  Normal file
BIN  misc/italy.png  Normal file  (binary file not shown; 3.9 KiB)
BIN  misc/results_dvd.png  Normal file  (binary file not shown; 324 KiB)
BIN  misc/results_simmc.png  Normal file  (binary file not shown; 34 KiB)
BIN  misc/teaser.pdf  Normal file  (binary file not shown)
BIN  misc/teaser.png  Normal file  (binary file not shown; 6.5 MiB)
0  output/.gitkeep  Normal file
16  setup_data.sh  Normal file
@@ -0,0 +1,16 @@
cd data/dvd

tar -xvzf dialogs.tar.gz
cat monet_feats_part* > monet_feats.tar.gz
tar -xvzf monet_feats.tar.gz

rm dialogs.tar.gz
rm monet_feats.tar.gz
rm monet_feats_part00.tar.gz
rm monet_feats_part01.tar.gz

cd ../simmc
tar -xvzf dialogs.tar.gz
rm dialogs.tar.gz

cd ../..
0  src/__init__.py  Normal file
25  src/combiner/option_a.py  Normal file
@@ -0,0 +1,25 @@
import pytorch_lightning as pl
import torch


class CombinerOptionA(pl.LightningModule):
    def __init__(self, config=None, model_input_dim=None, use_v_state=False, use_d_state=False):
        super().__init__()
        self.use_v_state = use_v_state
        self.use_d_state = use_d_state

    def forward(self, vision_emb, language_emb, language_emb_mask, v_state, d_state, dummy_word=None):
        if v_state is not None \
                and d_state is not None \
                and self.use_v_state \
                and self.use_d_state:
            output = torch.concat([v_state, d_state, vision_emb, language_emb], axis=1)
        elif d_state is not None and self.use_d_state:
            output = torch.concat([d_state, vision_emb, language_emb], axis=1)
        elif v_state is not None and self.use_v_state:
            output = torch.concat([v_state, vision_emb, language_emb], axis=1)
        else:
            output = torch.concat([vision_emb, language_emb], axis=1)
        if dummy_word is not None:
            output = torch.concat([dummy_word, output], axis=1)

        return output
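For intuition, a small standalone shape check of the combiner above; the dimensions are made up for illustration and do not correspond to any config in this commit:

```python
import torch
from src.combiner.option_a import CombinerOptionA

combiner = CombinerOptionA(use_v_state=True, use_d_state=True)

B, D = 2, 16                             # batch size, model input dim (illustrative)
vision_emb   = torch.randn(B, 10, D)     # 10 video tokens
language_emb = torch.randn(B, 7, D)      # 7 text tokens
v_state      = torch.randn(B, 1, D)      # video state token
d_state      = torch.randn(B, 1, D)      # dialogue state token
dummy_word   = torch.randn(B, 1, D)      # CLS-like token

out = combiner(vision_emb, language_emb, None, v_state, d_state, dummy_word)
# sequence order along dim 1: [dummy, v_state, d_state, vision, language]
print(out.shape)  # torch.Size([2, 20, 16])
```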
38  src/combiner/option_b.py  Normal file
@@ -0,0 +1,38 @@
import pytorch_lightning as pl
import torch


class CombinerOptionB(pl.LightningModule):
    def __init__(self, config=None, model_input_dim=None, use_v_state=False, use_d_state=False):
        super().__init__()
        self.use_v_state = use_v_state
        self.use_d_state = use_d_state

    def append_state_to_emb(self, tensor, state):
        tiling_vector = [1, tensor.shape[1], 1]
        state_tensor_for_concatenation = torch.tile(state, tiling_vector)
        result = torch.concat([tensor, state_tensor_for_concatenation], axis=2)
        return result

    def forward(self, dummy_word, video_emb, language_emb, language_emb_mask, v_state, d_state):
        # concatenate the video emb with the video state and the language emb with the dialogue state
        # if a state is not used, concatenate the embedding with itself
        if v_state is not None \
                and d_state is not None \
                and self.use_v_state \
                and self.use_d_state:
            video_emb = self.append_state_to_emb(video_emb, v_state)
            language_emb = self.append_state_to_emb(language_emb, d_state)
        elif d_state is not None and self.use_d_state:
            video_emb = self.append_state_to_emb(video_emb, video_emb)
            language_emb = self.append_state_to_emb(language_emb, d_state)
        elif v_state is not None and self.use_v_state:
            video_emb = self.append_state_to_emb(video_emb, v_state)
            language_emb = self.append_state_to_emb(language_emb, language_emb)
        else:
            video_emb = self.append_state_to_emb(video_emb, video_emb)
            language_emb = self.append_state_to_emb(language_emb, language_emb)

        output = torch.concat([dummy_word, video_emb, language_emb], axis=1)
        return output
69  src/combiner/option_c.py  Normal file
@@ -0,0 +1,69 @@
import pytorch_lightning as pl
import torch
from torch import nn


class CombinerOptionC(pl.LightningModule):
    def __init__(self, config, model_input_dim, use_v_state, use_d_state):
        super().__init__()
        self.config = config
        self.use_v_state = use_v_state
        self.use_d_state = use_d_state

        self.encoder_layer_d = nn.TransformerEncoderLayer(
            d_model=model_input_dim,
            dim_feedforward=self.config['dim_feedforward_d_transformer'],
            batch_first=True,
            nhead=self.config['n_heads_combiner_transformer']
        )
        self.encoder_layer_v = nn.TransformerEncoderLayer(
            d_model=model_input_dim,
            dim_feedforward=self.config['dim_feedforward_v_transformer'],
            batch_first=True,
            nhead=self.config['n_heads_combiner_transformer']
        )

    def prepare_inputs_for_transformers(self, video_emb, language_emb, language_emb_mask, v_state, d_state):
        # create masks for the language inputs (video sequences are all 301 frames long and don't need padding)
        d_input_mask = ~language_emb_mask  # PyTorch expects True for masked tokens (opposite of the Hugging Face mask)
        # if the dialogue state is used, add a column of False at the beginning of the tensor (the state should be attended to -> no mask)
        if d_state is not None and self.use_d_state:
            zero_column = torch.zeros((d_input_mask.shape[0], 1), dtype=torch.bool, device=self.device)
            d_input_mask = torch.concat([zero_column, d_input_mask], axis=1)

        # prepare the input tensors for the different transformer layers depending on which state vectors should be used
        if v_state is not None \
                and d_state is not None \
                and self.use_v_state \
                and self.use_d_state:
            v_input = torch.concat([v_state, video_emb], axis=1)
            d_input = torch.concat([d_state, language_emb], axis=1)
        elif d_state is not None and self.use_d_state:
            v_input = video_emb
            d_input = torch.concat([d_state, language_emb], axis=1)
        elif v_state is not None and self.use_v_state:
            v_input = torch.concat([v_state, video_emb], axis=1)
            d_input = language_emb
        else:
            v_input = video_emb
            d_input = language_emb

        return v_input, d_input, d_input_mask

    def forward(self, dummy_word, video_emb, language_emb, language_emb_mask, v_state, d_state):
        # prepare the input tensors for the different transformer layers depending on which state vectors should be used
        v_input, d_input, d_input_mask = self.prepare_inputs_for_transformers(video_emb, language_emb, language_emb_mask, v_state, d_state)

        # apply the v transformer to the v input and the d transformer to the d input
        v_emb = self.encoder_layer_v(v_input)
        d_emb = self.encoder_layer_d(d_input, src_key_padding_mask=d_input_mask)

        # combine the output of the first 2 transformers and add the dummy word (cls token)
        # put the embedded video and dialog states at the beginning of the combined input
        v_state_emb = v_emb[:, 0, :].unsqueeze(1)
        d_state_emb = d_emb[:, 0, :].unsqueeze(1)
        combined_input = torch.concat([dummy_word, v_state_emb, d_state_emb, v_emb[:, 1:, :], d_emb[:, 1:, :]], axis=1)

        # create combined_input_mask based on the language_emb_mask
        return combined_input
0  src/data_modules/__init__.py  Normal file
55  src/data_modules/dvd_data.py  Normal file
@@ -0,0 +1,55 @@
import pytorch_lightning as pl
import src.utils.dvd_codebase.data.data_handler as dh
from src.utils.dvd_codebase.configs.configs import *
from transformers import AutoTokenizer
import os


class DVDData(pl.LightningDataModule):
    def __init__(self, config):
        super().__init__()
        args.batch_size = config['training']['batch_size']
        args.fea_dir = config['datamodule']['fea_dir']
        args.data_dir = config['datamodule']['data_dir']
        pretrained_lm_name = config['model']['pretrained_lm_name']

        # load dialogues
        self.train_dials, self.train_vids = dh.load_dials(args, "train")
        self.val_dials, self.val_vids = dh.load_dials(args, "val")
        self.test_dials, self.test_vids = dh.load_dials(args, "test")

        # get vocabulary
        self.vocab, self.answer_list = dh.get_vocabulary(self.train_dials, args)
        # self.answer_list = ['0', '1', '10', '2', '3', '4', '5', '6', '7', '8', '9', 'False', 'True', 'blue', 'brown', 'cone', 'cube', 'cyan', 'cylinder', 'flying', 'flying,rotating', 'flying,rotating,sliding', 'flying,sliding', 'gold', 'gray', 'green', 'large', 'medium', 'metal', 'no action', 'purple', 'red', 'rotating', 'rotating,sliding', 'rubber', 'sliding', 'small', 'sphere', 'spl', 'yellow']

        train_vft = dh.load_video_features(args, self.train_vids)
        val_vft = dh.load_video_features(args, self.val_vids)
        test_vft = dh.load_video_features(args, self.test_vids)

        # create tokenizer
        if pretrained_lm_name != '':
            tokenizer = AutoTokenizer.from_pretrained(pretrained_lm_name)
            pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
            self.vocab['<blank>'] = pad_token_id
            os.environ["TOKENIZERS_PARALLELISM"] = "false"
        else:
            tokenizer = None

        # load data
        self.train_dials = dh.create_dials(self.train_dials, self.vocab, self.answer_list, train_vft, args, tokenizer=tokenizer)
        self.val_dials = dh.create_dials(self.val_dials, self.vocab, self.answer_list, val_vft, args, tokenizer=tokenizer)
        self.test_dials = dh.create_dials(self.test_dials, self.vocab, self.answer_list, test_vft, args, tokenizer=tokenizer)

    def train_dataloader(self):
        dl, _ = dh.create_dataset(self.train_dials, self.vocab, "train", args)
        return dl

    def val_dataloader(self):
        dl, _ = dh.create_dataset(self.val_dials, self.vocab, "val", args)
        return dl

    def test_dataloader(self):
        dl, _ = dh.create_dataset(self.test_dials, self.vocab, "test", args)
        return dl
95  src/data_modules/simmc2_data.py  Normal file
@@ -0,0 +1,95 @@
import pytorch_lightning as pl
from src.utils.simmc2_dataset.dataloader_dvd_model import Simmc2Dataset, VisualFeatureLoader
from transformers import AutoTokenizer
import argparse
import os
from torch.utils.data import DataLoader


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_file", default='', help="Path to train file")
    parser.add_argument("--dev_file", default='', help="Path to dev file")
    parser.add_argument("--devtest_file", default='', help="Path to devtest file")
    parser.add_argument(
        "--visual_feature_path", default=None, help="Path to visual features"
    )
    parser.add_argument(
        "--visual_feature_size",
        type=int,
        default=516,
        help="Size of the visual features",
    )
    parser.add_argument(
        "--max_turns", type=int, default=5, help="Number of turns in history"
    )
    parser.add_argument(
        "--max_length", type=int, default=512, help="Maximum length in utterance"
    )
    parser.add_argument("--use_gpu", dest="use_gpu", action="store_true", default=True)
    args = parser.parse_args()
    return args


class Simmc2Data(pl.LightningDataModule):
    def __init__(self, config):
        super().__init__()
        self.args = parse_arguments()
        self.args.train_file = os.path.join(config['datamodule']['data_dir'], 'simmc2.1_ambiguous_candidates_dstc11_train.json')
        self.args.dev_file = os.path.join(config['datamodule']['data_dir'], 'simmc2.1_ambiguous_candidates_dstc11_dev.json')
        self.args.devtest_file = os.path.join(config['datamodule']['data_dir'], 'simmc2.1_ambiguous_candidates_dstc11_devtest.json')
        self.args.teststd_file = os.path.join(config['datamodule']['data_dir'], 'simmc2.1_dials_dstc11_dev.json')
        self.args.visual_feature_path = config['datamodule']['fea_dir']
        pretrained_lm_name = config['model']['pretrained_lm_name']
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_lm_name)
        self.feature_loader = VisualFeatureLoader(
            feature_path=self.args.visual_feature_path,
            feature_size=self.args.visual_feature_size
        )
        self.config = config

    def train_dataloader(self):
        dataset = Simmc2Dataset(
            tokenizer=self.tokenizer,
            feature_loader=self.feature_loader,
            load_path=self.args.train_file,
            args=self.args
        )
        dl = DataLoader(
            dataset,
            batch_size=self.config['training']['batch_size'],
            shuffle=True,
            collate_fn=dataset.collate_fn,
        )
        return dl

    def val_dataloader(self):
        dataset = Simmc2Dataset(
            tokenizer=self.tokenizer,
            feature_loader=self.feature_loader,
            load_path=self.args.dev_file,
            args=self.args,
        )
        dl = DataLoader(
            dataset,
            batch_size=self.config['training']['batch_size'],
            shuffle=False,
            collate_fn=dataset.collate_fn,
        )
        return dl

    def test_dataloader(self):
        dataset = Simmc2Dataset(
            tokenizer=self.tokenizer,
            feature_loader=self.feature_loader,
            load_path=self.args.devtest_file,
            args=self.args,
        )
        dl = DataLoader(
            dataset,
            batch_size=self.config['training']['batch_size'],
            shuffle=False,
            collate_fn=dataset.collate_fn,
        )
        return dl
0  src/models/__init__.py  Normal file
179  src/models/base_model.py  Normal file
@@ -0,0 +1,179 @@
import pytorch_lightning as pl
import torch
from torch import nn
from torch.optim import AdamW
from src.utils.positional_encoding import PositionalEncoding
from src.object_description_encoder.object_description_encoder import ObjectDescriptionEncoder
import torchmetrics as metrics
from transformers import get_cosine_schedule_with_warmup
from transformers import AutoModel
from src.combiner.option_a import CombinerOptionA
from transformers import AutoTokenizer


class TransformerModel(pl.LightningModule):
    def __init__(self, config, output_path=None):
        super().__init__()
        self.output_path = output_path
        self.config = config['model']
        self.train_config = config['training']

        self.train_acc = metrics.Accuracy('multiclass', num_classes=40)
        self.val_acc = metrics.Accuracy('multiclass', num_classes=40)
        self.test_acc = metrics.Accuracy('multiclass', num_classes=40)

        self.best_val_acc = 0
        self.loss_for_best_val_acc = 0
        self.best_train_acc = 0

        self.combiner = CombinerOptionA()
        self.initialize_text_encoder_and_feature_mapping()

        self.positional_encoder = PositionalEncoding(
            d_model=self.model_input_dim, dropout=self.config['dropout_p'], max_len=self.config['dim_feedforward']
        )

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.model_input_dim,
            batch_first=True,
            dropout=self.config['dropout_p'],
            dim_feedforward=self.config['dim_feedforward'],
            nhead=self.config['n_heads']
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer=encoder_layer,
            num_layers=self.config['n_encoder_layers'],
        )

        self.loss = nn.CrossEntropyLoss()

        if self.config['feature_type'] == 'object_text_features':
            self.object_description_encoder = ObjectDescriptionEncoder(
                d_model=self.config['v_emb_dim'],
                config=self.config
            )
            # maps the output from the pretrained lm to a smaller size used for the encoding of the object description (reduces transformer size)
            self.linear_projection_object_description = nn.Linear(self.pretrained_lm.config.hidden_size, self.config['v_emb_dim'])

        # tokenizer for translation from ids to text
        self.tokenizer = AutoTokenizer.from_pretrained(self.config['pretrained_lm_name'])

    def initialize_text_encoder_and_feature_mapping(self):
        if self.config['use_pretrained_lm']:
            self.pretrained_lm = AutoModel.from_pretrained(
                self.config['pretrained_lm_name'],
                add_pooling_layer=False
            )
            self.pretrained_lm.eval()
            # don't train the parameters of the pretrained lm
            self.pretrained_lm.config.training = True
            # for param in self.pretrained_lm.parameters():
            #     param.requires_grad = False

            # initialize the projection layers to map the embeddings to the correct input dim
            # either use the emb_dim as done in aloe (v_emb_dim * n_heads) or the emb_dim specified in the config
            if self.config['projection_as_in_aloe']:
                self.model_input_dim = self.config['n_heads'] * self.config['v_emb_dim']
                self.linear_projection_video = nn.Linear(self.config['v_emb_dim'], self.model_input_dim - 2)
                self.linear_projection_text = nn.Linear(self.pretrained_lm.config.hidden_size, self.model_input_dim - 2)
            else:
                # take the embedding size from the config and map the video features from their size to the chosen emb size
                self.linear_projection_video = nn.Linear(self.config['v_emb_dim'], self.config['emb_dim'] - 2)
                self.linear_projection_text = nn.Linear(self.pretrained_lm.config.hidden_size, self.config['emb_dim'] - 2)
                self.model_input_dim = self.config['emb_dim']
        else:
            # either use the emb_dim as done in aloe (v_emb_dim * n_heads) or the video_emb_dim (2 is either added or subtracted because of the input ids)
            if self.config['projection_as_in_aloe']:
                self.model_input_dim = self.config['n_heads'] * self.config['v_emb_dim']
            else:
                self.model_input_dim = self.config['emb_dim']
            self.linear_projection_video = nn.Linear(self.config['v_emb_dim'], self.model_input_dim - 2)
            self.embed = nn.Embedding(num_embeddings=self.config['vocab_size'], embedding_dim=self.model_input_dim - 2)

    def append_ids(self, tensor, id_vector, axis):
        id_vector = torch.tensor(id_vector, device=self.device)
        for a in range(len(tensor.shape)):
            if a != axis:
                id_vector = torch.unsqueeze(id_vector, axis=a)
        tiling_vector = [s if i != axis else 1 for i, s in enumerate(tensor.shape)]
        id_tensor = torch.tile(id_vector, tiling_vector)
        return torch.concat([tensor, id_tensor], axis=axis)

    def downsample_video_emb(self, video_emb):
        return video_emb[:, ::self.config['sample_rate_video'], :, :]

    def unroll_video_emb(self, video_emb):
        video_emb = video_emb.permute(0, 1, 3, 2)
        return torch.reshape(video_emb, (video_emb.shape[0], -1, video_emb.shape[3]))

    def apply_pretrained_lm(self, query, query_mask):
        output = self.pretrained_lm(
            input_ids=query,
            attention_mask=query_mask
        )
        return output['last_hidden_state']

    def prepare_lang_emb(self, query, query_mask):
        # set maximum query length TODO ------ set param in config
        if query.shape[1] > 100:
            query = query[:, :100]
            query_mask = query_mask[:, :100]

        # apply pretrained language model to embed the query if specified
        if self.config['use_pretrained_lm']:
            lang_emb = self.apply_pretrained_lm(query, query_mask)
        else:
            lang_emb = self.embed(query)

        # Aloe uses an emb_dim of v_emb_dim * n_heads. Or use the emb_dim specified in the config
        if self.config['use_pretrained_lm']:
            lang_emb = self.linear_projection_text(lang_emb)

        lang_emb = self.append_ids(lang_emb, [1, 0], 2)
        lang_emb = self.positional_encoder(lang_emb)
        return lang_emb

    def prepare_video_emb(self, video_emb):
        # shape: [batch, frames, v_emb_dim, objects]
        video_emb = self.downsample_video_emb(video_emb)

        # unroll time dimension in object dimension (only take every _ frame) - shape: [batch, objects x frames, v_emb_dim + 2]
        video_emb = self.unroll_video_emb(video_emb)

        # video_emb needs to be projected to either the size of the language emb or the emb_size given by v_emb_dim * n_heads (as done in the Aloe paper)
        # if self.config['use_pretrained_lm'] or self.config['projection_as_in_aloe']:
        video_emb = self.linear_projection_video(video_emb)

        video_emb = self.append_ids(video_emb, [0, 1], 2)
        video_emb = self.positional_encoder(video_emb)
        return video_emb

    def forward(self, batch):
        output = self.answer_query(batch.query, batch.query_mask, batch.vft)
        return output

    def configure_optimizers(self):
        opt = AdamW(self.parameters(), lr=self.train_config['lr'])
        sched = get_cosine_schedule_with_warmup(
            opt,
            num_warmup_steps=self.train_config['warmup_steps'],
            num_training_steps=self.train_config['total_steps'],
        )
        return {
            'optimizer': opt,
            'lr_scheduler': {
                'scheduler': sched,
                'interval': 'step'
            }
        }
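The `- 2` in the projection sizes above comes from `append_ids`, which concatenates a two-dimensional modality id onto every token embedding. A standalone copy for illustration (same logic as the method above, without the Lightning module or device handling; shapes are made up):

```python
import torch

def append_ids(tensor, id_vector, axis):
    # append a fixed id vector to every position of `tensor` along `axis`
    id_vector = torch.tensor(id_vector, dtype=tensor.dtype)
    for a in range(len(tensor.shape)):
        if a != axis:
            id_vector = torch.unsqueeze(id_vector, a)
    tiling_vector = [s if i != axis else 1 for i, s in enumerate(tensor.shape)]
    id_tensor = torch.tile(id_vector, tiling_vector)
    return torch.concat([tensor, id_tensor], axis=axis)

lang_emb = torch.randn(2, 7, 14)             # [batch, tokens, emb_dim - 2]
lang_emb = append_ids(lang_emb, [1, 0], 2)   # text gets id (1, 0); video gets (0, 1)
print(lang_emb.shape)                        # torch.Size([2, 7, 16])
print(lang_emb[0, 0, -2:])                   # tensor([1., 0.])
```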
137  src/models/discriminative_model.py  Normal file
@@ -0,0 +1,137 @@
from src.models.state_tracker_model import StateTrackerModel
import torch
from torch import nn
from src.utils.text_utils import translate_from_ids_to_text
import pandas as pd


class DiscriminativeModel(StateTrackerModel):
    def __init__(self, config, output_path=None):
        super().__init__(config, output_path=output_path)

        self.fc = nn.Linear(self.model_input_dim, self.config["fc_dim"])
        self.relu = nn.ReLU()
        self.output = nn.Linear(self.config["fc_dim"], 40)

    def apply_model(self, language_emb, language_emb_mask, video_emb, v_state=None, d_state=None, answer_emb=None, answer_mask=None, state_generation_mode=None):
        # analogous to the CLS token from BERT models
        dummy_word = torch.zeros(self.model_input_dim, requires_grad=True, device=self.device)
        dummy_word = torch.tile(dummy_word, (language_emb.shape[0], 1, 1))

        # combine state and embeddings
        input = self.combiner(
            video_emb,
            language_emb,
            language_emb_mask,
            v_state,
            d_state,
            dummy_word
        )
        # create input mask based on the language_emb_mask (complete video is unmasked)
        input_mask = torch.zeros((input.shape[0], input.shape[1]), device=self.device)
        offset = 1
        if v_state is not None: offset += 1
        if d_state is not None: offset += 1
        # offset is caused by cls token and state vectors
        if self.config['model_type'] == 'extended_model':
            # set offset to 1 if combiner B is used -> no state vectors as input. Instead concatenated with embeddings
            if self.ext_config['combiner_option'] == 'OptionB':
                offset = 1
        input_mask[:, video_emb.shape[1] + offset:] = ~language_emb_mask

        x = self.encoder(input, src_key_padding_mask=input_mask)
        # only pass the transformed dummy word to the dense layers
        x = self.fc(x[:, 0, :])
        x = self.relu(x)
        output = self.output(x)
        return output

    def answer_query(self, query, query_mask, vft, v_state=None, d_state=None, answer=None, answer_mask=None, state_generation_mode=False):
        video_emb = self.prepare_video_emb(vft)
        lang_emb = self.prepare_lang_emb(query, query_mask)
        if answer is not None and answer_mask is not None:
            answer_emb = self.prepare_lang_emb(answer, answer_mask)
        else:
            answer_emb = None
        output = self.apply_model(lang_emb, query_mask, video_emb, v_state, d_state, answer_emb, answer_mask, state_generation_mode)
        return output

    def training_step(self, train_batch, batch_idx):
        train_batch.move_to_cuda()
        label = torch.squeeze(train_batch.answer)
        out = self.forward(train_batch)
        loss = self.loss(out, label)
        tr_acc = self.train_acc(out.softmax(dim=1), label)
        if tr_acc > self.best_train_acc:
            self.best_train_acc = tr_acc
        self.log("train_acc", tr_acc, prog_bar=True, on_step=False, on_epoch=True, batch_size=train_batch.query.shape[0])
        self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True, batch_size=train_batch.query.shape[0])
        print('train_loss: {} | train_acc: {}'.format(loss, tr_acc))
        return loss

    def validation_step(self, val_batch, batch_idx):
        val_batch.move_to_cuda()
        label = torch.squeeze(val_batch.answer)
        out = self.forward(val_batch)
        loss = self.loss(out, label)
        self.val_acc(out.softmax(dim=1), label)
        self.log("val_acc", self.val_acc, prog_bar=True, on_step=False, on_epoch=True, batch_size=val_batch.query.shape[0])
        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True, batch_size=val_batch.query.shape[0])
        return {'val_loss': loss, 'val_acc': self.val_acc.compute()}

    def test_step(self, test_batch, batch_idx):
        test_batch.move_to_cuda()
        label = torch.squeeze(test_batch.answer)
        out = self.forward(test_batch)
        loss = self.loss(out, label)
        self.test_acc(out.softmax(dim=1), label)
        self.log("test_acc", self.test_acc, prog_bar=True, on_step=False, on_epoch=True, batch_size=test_batch.query.shape[0])
        self.log("test_loss", loss, prog_bar=True, on_step=False, on_epoch=True, batch_size=test_batch.query.shape[0])

        # save the results into a dictionary
        out = torch.argmax(out, dim=1)

        question_as_text = []
        for i in range(test_batch.query.shape[0]):
            question_ids = test_batch.query[i, :]
            question_as_text.append(translate_from_ids_to_text(question_ids, self.tokenizer))

        self.results['question'].extend(question_as_text)
        self.results['video_name'].extend(test_batch.video_name)

        self.results['qa_id'].extend(test_batch.qa_ids)
        self.results['q_type'].extend(test_batch.q_type)
        self.results['label'].extend(label.tolist())
        self.results['output'].extend(out.tolist())
        self.results['attribute_dependency'].extend(test_batch.attribute_dependency)
        self.results['object_dependency'].extend(test_batch.object_dependency)
        self.results['temporal_dependency'].extend(test_batch.temporal_dependency)
        self.results['spatial_dependency'].extend(test_batch.spatial_dependency)
        self.results['q_complexity'].extend(test_batch.q_complexity)

    def on_test_start(self):
        self.results = {
            'qa_id': [],
            'q_type': [],
            'label': [],
            'output': [],
            'attribute_dependency': [],
            'object_dependency': [],
            'temporal_dependency': [],
            'spatial_dependency': [],
            'q_complexity': [],
            # only needed for input output analysis
            'question': [],
            'video_name': []
        }

    def on_test_end(self):
        df = pd.DataFrame.from_dict(self.results)
        df.to_pickle(self.output_path)
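The training entry point (`train.py`) is not part of this excerpt. The following is only a hypothetical wiring sketch of how the pieces shown in this commit could fit together with PyTorch Lightning; the class and module paths are taken from the files above, while the Trainer arguments are placeholders:

```python
import pytorch_lightning as pl

from config.config import read_default_config, read_config, update_nested_dicts
from src.data_modules.dvd_data import DVDData
from src.models.discriminative_model import DiscriminativeModel

# merge the defaults with the DVD experiment config
config = update_nested_dicts(read_default_config(), read_config("config/dvd.json"))

datamodule = DVDData(config)          # builds the DVD dataloaders
model = DiscriminativeModel(config)   # discriminative OLViT variant
trainer = pl.Trainer(max_epochs=config['training']['epochs'], accelerator="gpu", devices=3)

trainer.fit(model, datamodule=datamodule)
```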
350
src/models/generative_model.py
Normal file
350
src/models/generative_model.py
Normal file
|
@ -0,0 +1,350 @@
|
||||||
|
# code is partly inspired from https://pytorch.org/tutorials/beginner/translation_transformer.html
|
||||||
|
|
||||||
|
from unittest import result
|
||||||
|
from src.models.state_tracker_model import StateTrackerModel
|
||||||
|
from src.utils.batch_interfaces import batch_interface_simmc2_to_dvd, batch_interface_avsd_to_dvd
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torchtext.data.metrics import bleu_score
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
import nltk
|
||||||
|
import numpy as np
|
||||||
|
from src.utils.text_utils import normalize_sentence, translate_from_ids_to_text
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class GenerativeModel(StateTrackerModel):
|
||||||
|
def __init__(self, config, output_path=None):
|
||||||
|
super().__init__(config, output_path=output_path)
|
||||||
|
|
||||||
|
self.transformer = nn.Transformer(
|
||||||
|
d_model=self.model_input_dim,
|
||||||
|
batch_first=True,
|
||||||
|
dropout=self.config['dropout_p'],
|
||||||
|
dim_feedforward=self.config['dim_feedforward'],
|
||||||
|
nhead=self.config['n_heads'],
|
||||||
|
num_encoder_layers=self.config['n_encoder_layers'],
|
||||||
|
num_decoder_layers=self.config['n_decoder_layers'],
|
||||||
|
custom_encoder=self.encoder
|
||||||
|
)
|
||||||
|
self.prob_generator = nn.Linear(self.model_input_dim, self.config['vocab_size'])
|
||||||
|
|
||||||
|
self.pad_id = 1
|
||||||
|
self.unk_id = 3
|
||||||
|
self.loss = nn.CrossEntropyLoss(ignore_index=self.pad_id)
|
||||||
|
|
||||||
|
|
||||||
|
# tokenizer for translation from ids to text
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(self.config['pretrained_lm_name'])
|
||||||
|
|
||||||
|
# ---TODO: Remove ------
|
||||||
|
self.results = {}
|
||||||
|
self.epoch_count = 0
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------
|
||||||
|
self.batch_interface = batch_interface_simmc2_to_dvd
|
||||||
|
|
||||||
|
|
||||||
|
def encode_object_descriptions(self, vft):
|
||||||
|
#embed the object descriptions using bert and then create the object token using transformer layers
|
||||||
|
if self.config['feature_type'] == "object_text_features":
|
||||||
|
object_features = []
|
||||||
|
for i in range(vft.shape[1]):
|
||||||
|
object_description = vft[:, i, :]
|
||||||
|
object_description_mask = (object_description != 1)
|
||||||
|
embedded_object_description = self.apply_pretrained_lm(object_description, object_description_mask)
|
||||||
|
|
||||||
|
#map embeddings to a smaller size (motivation: reduce transformer sice of object description encoder)
|
||||||
|
embedded_object_description = self.linear_projection_object_description(embedded_object_description)
|
||||||
|
|
||||||
|
#apply transformer to encode the object description
|
||||||
|
object_token = self.object_description_encoder(embedded_object_description)
|
||||||
|
object_features.append(object_token)
|
||||||
|
object_features = torch.concat(object_features, dim=1)
|
||||||
|
#add frame dimension (only one frame in this cas)
|
||||||
|
object_features = object_features.unsqueeze(1)
|
||||||
|
#bring the data to the format [batch_size x frames x emb_dim (desc_text_len) x obj_number]
|
||||||
|
vft = object_features.permute(0, 1, 3, 2)
|
||||||
|
|
||||||
|
return vft
|
||||||
|
|
||||||
|
|
||||||
|
def create_target_mask(self, size):
|
||||||
|
mask = torch.triu(torch.ones((size,size), device=self.device), 1)
|
||||||
|
mask = mask.masked_fill(mask == 1, float('-inf'))
|
||||||
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
def generate_prob_for_next_tokens(self, input, answer_emb, tgt_mask, input_mask, answer_mask):
|
||||||
|
x = self.transformer.encoder(input, src_key_padding_mask=input_mask)
|
||||||
|
dec_out = self.transformer.decoder(answer_emb, x, tgt_mask)
|
||||||
|
probs = self.prob_generator(dec_out)
|
||||||
|
|
||||||
|
|
||||||
|
return probs
|
||||||
|
|
||||||
|
|
||||||
|
def generate_complete_answers(self, input, input_mask):
|
||||||
|
# encode the complete batch of questions
|
||||||
|
memory = self.transformer.encoder(input, src_key_padding_mask=input_mask)
|
||||||
|
generated_answers = torch.ones(memory.shape[0], 40, dtype=torch.int) # 20 = max answer length, use unknown token ()
|
||||||
|
|
||||||
|
# generate the answers for each individual question from the batch
|
||||||
|
for i in range(memory.shape[0]):
|
||||||
|
memory_i = memory[i, :, :]
|
||||||
|
memory_i = memory_i.unsqueeze(0)
|
||||||
|
answer_i = torch.zeros((1,1), dtype=torch.int, device=self.device) # Pass start token <s> to decoder as first input. From roberta vocab: <s>": 0, "</s>": 2
|
||||||
|
|
||||||
|
for j in range(40): # 20 = max answer length
|
||||||
|
|
||||||
|
answer_i_emb = self.prepare_lang_emb(answer_i, torch.ones((1, answer_i.shape[0]), device=self.device, dtype=torch.int16))
|
||||||
|
tgt_mask = self.create_target_mask(answer_i.shape[1])
|
||||||
|
decoder_output = self.transformer.decoder(answer_i_emb, memory_i, tgt_mask)
|
||||||
|
prob = self.prob_generator(decoder_output[:, -1, :])
|
||||||
|
next_word = prob.argmax()
|
||||||
|
|
||||||
|
answer_i = torch.concat([answer_i, next_word.unsqueeze(0).unsqueeze(0)], dim=1)
|
||||||
|
if next_word.item() == 2: # eos token in roberta vocab "</s>": 2
|
||||||
|
break
|
||||||
|
|
||||||
|
generated_answers[i, :answer_i.shape[1] - 1] = answer_i[0, 1:]
|
||||||
|
|
||||||
|
return generated_answers
|
||||||
|
|
||||||
|
|
||||||
|
def apply_model(self, language_emb, language_emb_mask, video_emb, v_state=None, d_state=None, answer_emb=None, answer_mask=None, state_generation_mode=False):
|
||||||
|
# combine state and embeddings
|
||||||
|
input = self.combiner(
|
||||||
|
video_emb,
|
||||||
|
language_emb,
|
||||||
|
language_emb_mask,
|
||||||
|
v_state,
|
||||||
|
d_state
|
||||||
|
)
|
||||||
|
# create input mask based on the language_emb_mask (complete video is unmasked)
|
||||||
|
input_mask = torch.zeros((input.shape[0], input.shape[1]), device=self.device)
|
||||||
|
offset = 0
|
||||||
|
if v_state is not None: offset += 1
|
||||||
|
if d_state is not None: offset += 1
|
||||||
|
# offset is caused by state vectors
|
||||||
|
input_mask[:, video_emb.shape[1] + offset:] = ~language_emb_mask
|
||||||
|
tgt_mask = self.create_target_mask(answer_emb.shape[1])
|
||||||
|
|
||||||
|
#-------TODO: Mask padded object embeddings when text based object embeddings are used -------------
|
||||||
|
|
||||||
|
if self.mode == 'train' or state_generation_mode:
|
||||||
|
probs = self.generate_prob_for_next_tokens(input, answer_emb, tgt_mask, input_mask, answer_mask)
|
||||||
|
return probs
|
||||||
|
elif self.mode == 'val':
|
||||||
|
generated_answers = self.generate_complete_answers(input, input_mask)
|
||||||
|
return generated_answers
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_answer_emb_and_mask(self, answer, answer_mask):
|
||||||
|
mask = torch.tril(torch.ones((answer.shape[1], answer.shape[1]), device=self.device))
|
||||||
|
mask = mask.unsqueeze(0)
|
||||||
|
mask = mask.expand(answer.shape[0], -1, -1)
|
||||||
|
answer_emb = self.apply_pretrained_lm(answer, mask)
|
||||||
|
|
||||||
|
answer_emb = self.linear_projection_text(answer_emb)
|
||||||
|
answer_emb = self.append_ids(answer_emb, [1, 0], 2)
|
||||||
|
answer_emb = self.positional_encoder(answer_emb)
|
||||||
|
|
||||||
|
# pytorch interprets True in a mask as padding
|
||||||
|
answer_mask = ~answer_mask
|
||||||
|
answer_emb_final = answer_emb[:, :-1].detach()
|
||||||
|
answer_mask_final = answer_mask[:, :-1].detach()
|
||||||
|
|
||||||
|
return answer_emb_final, answer_mask_final
|
||||||
|
|
||||||
|
|
||||||
|
def answer_query(self, query, query_mask, vft, v_state=None, d_state=None, answer=None, answer_mask=None, state_generation_mode=False):
|
||||||
|
video_emb = self.prepare_video_emb(vft)
|
||||||
|
lang_emb = self.prepare_lang_emb(query, query_mask)
|
||||||
|
answer_emb, answer_mask = self.prepare_answer_emb_and_mask(answer, answer_mask)
|
||||||
|
output = self.apply_model(lang_emb, query_mask, video_emb, v_state, d_state, answer_emb, answer_mask, state_generation_mode)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def training_step(self, train_batch, batch_idx):
|
||||||
|
train_batch = self.batch_interface(train_batch, feature_type=self.config['feature_type'])
|
||||||
|
if self.config['feature_type'] == "object_text_features":
|
||||||
|
train_batch.vft = self.encode_object_descriptions(train_batch.vft)
|
||||||
|
|
||||||
|
logits = self.forward(train_batch)
|
||||||
|
logits = logits.permute(0, 2, 1)
|
||||||
|
|
||||||
|
# replace any unknown token (id = 3) with a padding token in order to also ignore them -> avoid model which outputs unk tokens
|
||||||
|
train_batch.answer[train_batch.answer == 3] = 1
|
||||||
|
loss = self.loss(logits, train_batch.answer[:, 1:]) # ignore padding
|
||||||
|
self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True, batch_size=train_batch.query.shape[0])
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def get_next_token_pred_as_text_and_logits(self, batch):
|
||||||
|
# set mode to train to get the logits instead of completely generated sentences
|
||||||
|
self.mode = 'train'
|
||||||
|
logits = self.forward(batch)
|
||||||
|
logits = logits.permute(0, 2, 1)
|
||||||
|
predicted_tokens = []
|
||||||
|
for j in range(logits.shape[0]):
|
||||||
|
l = logits[j, :, :]
|
||||||
|
ids = [l[:, i].argmax().item() for i in range(l.shape[1])]
|
||||||
|
text = translate_from_ids_to_text(ids, self.tokenizer)
|
||||||
|
predicted_tokens.append(text)
|
||||||
|
# set mode back to val
|
||||||
|
self.mode = 'val'
|
||||||
|
|
||||||
|
return predicted_tokens, logits
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_bleu_score(self, generated_answer_ids, correct_answer):
|
||||||
|
# calculate bleu score for the generated answers compared to the provided correct answers
|
||||||
|
bleu4_scores = []
|
||||||
|
all_generated_answers = []
|
||||||
|
for i in range(generated_answer_ids.shape[0]):
|
||||||
|
generated_answer = generated_answer_ids[i, :].tolist()
|
||||||
|
            generated_answer_text = translate_from_ids_to_text(generated_answer, self.tokenizer)
            all_generated_answers.append(generated_answer_text)
            correct_answer_text_i = correct_answer[i]
            score4 = nltk.translate.bleu_score.sentence_bleu(
                [normalize_sentence(correct_answer_text_i)],
                normalize_sentence(generated_answer_text),
                smoothing_function=nltk.translate.bleu_score.SmoothingFunction().method7
            )
            bleu4_scores.append(score4)
        bleu4_score = np.mean(bleu4_scores)
        return bleu4_score, all_generated_answers

    def translate_answer_ids_to_text(self, answer):
        correct_answer_text = []
        for i in range(answer.shape[0]):
            correct_answer_i = answer[i, :].tolist()
            correct_answer_text_i = translate_from_ids_to_text(correct_answer_i, self.tokenizer)
            correct_answer_text.append(correct_answer_text_i)
        return correct_answer_text

    def validation_step(self, val_batch, batch_idx):
        val_batch = self.batch_interface(val_batch, feature_type=self.config['feature_type'])
        if self.config['feature_type'] == "object_text_features":
            val_batch.vft = self.encode_object_descriptions(val_batch.vft)

        correct_answer_text = self.translate_answer_ids_to_text(val_batch.answer)
        generated_answer_ids = self.forward(val_batch)

        # calculate and log the BLEU score of the generated answers against the provided correct answers
        bleu4_score, generated_answers_text = self.calculate_bleu_score(generated_answer_ids, correct_answer_text)
        self.log('bleu4', bleu4_score, prog_bar=True, on_step=False, on_epoch=True, batch_size=generated_answer_ids.shape[0])

        # calculate and log the validation loss based on the results from next token prediction (train mode needed)
        predicted_tokens, logits = self.get_next_token_pred_as_text_and_logits(val_batch)
        loss = self.loss(logits, val_batch.answer[:, 1:])  # ignore padding
        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True, batch_size=val_batch.query.shape[0])

        return {'next_token_predictions': predicted_tokens, 'generated_answers': generated_answers_text, 'correct_answers': correct_answer_text}

    def test_step(self, test_batch, batch_idx):
        dialog_id = test_batch['dialog_id']
        turn_id = test_batch['turn_id']
        test_batch = self.batch_interface(test_batch, feature_type=self.config['feature_type'])
        if self.config['feature_type'] == "object_text_features":
            test_batch.vft = self.encode_object_descriptions(test_batch.vft)

        correct_answer_text = self.translate_answer_ids_to_text(test_batch.answer)
        generated_answer_ids = self.forward(test_batch)

        # calculate and log the BLEU score of the generated answers against the provided correct answers
        bleu4_score, generated_answers_text = self.calculate_bleu_score(generated_answer_ids, correct_answer_text)
        self.log('bleu4', bleu4_score, prog_bar=True, on_step=False, on_epoch=True, batch_size=generated_answer_ids.shape[0])

        # calculate and log the loss based on the results from next token prediction (train mode needed)
        predicted_tokens, logits = self.get_next_token_pred_as_text_and_logits(test_batch)
        loss = self.loss(logits, test_batch.answer[:, 1:])  # ignore padding
        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True, batch_size=test_batch.query.shape[0])

        return {'turn_id': turn_id, 'next_token_predictions': predicted_tokens, 'dialog_id': dialog_id, 'generated_answers': generated_answers_text, 'correct_answers': correct_answer_text}

    def test_epoch_end(self, outputs):
        if self.config['output_format'] == 'submission':
            responses = []
            for output in outputs:
                for t_id, d_id, answer in zip(output['turn_id'], output['dialog_id'], output['generated_answers']):
                    sample = {
                        'dialog_id': d_id,
                        'predictions': [
                            {
                                'turn_id': t_id,
                                'response': answer
                            }
                        ]
                    }
                    responses.append(sample)
            name = 'dstc11-simmc-devtest-pred-subtask-4-generation.json'
            with open(os.path.join(self.output_path, name), 'w') as file:
                json.dump(responses, file)
        else:
            result_idx = 0
            for output in outputs:
                for j in range(len(output['next_token_predictions'])):
                    pred = " "
                    corr = " "
                    gen = " "
                    self.results[result_idx] = {
                        'next_token_pred': pred.join(output['next_token_predictions'][j]),
                        'generated_ans': gen.join(output['generated_answers'][j]),
                        'correct': corr.join(output['correct_answers'][j])
                    }
                    result_idx += 1

            name = f'epoch_{self.epoch_count}.json'
            with open(os.path.join(self.output_path, name), 'w') as file:
                json.dump(self.results, file)

    def validation_epoch_end(self, outputs):
        result_idx = 0
        for output in outputs:
            for j in range(len(output['next_token_predictions'])):
                pred = " "
                corr = " "
                gen = " "
                self.results[result_idx] = {
                    'next_token_pred': pred.join(output['next_token_predictions'][j]),
                    'generated_ans': gen.join(output['generated_answers'][j]),
                    'correct': corr.join(output['correct_answers'][j])
                }
                result_idx += 1

        name = f'epoch_{self.epoch_count}.json'
        with open(os.path.join(self.output_path, name), 'w') as file:
            json.dump(self.results, file)

        self.results = {}
        self.epoch_count += 1

    def on_train_epoch_start(self):
        self.mode = 'train'

    def on_validation_epoch_start(self):
        self.mode = 'val'

    def on_test_epoch_start(self):
        self.mode = 'val'
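For reference, the BLEU-4 computation above boils down to NLTK's sentence-level BLEU with smoothing method 7. A minimal, self-contained sketch (using plain whitespace tokenisation in place of the repository's `normalize_sentence` helper):

```python
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

reference = "the small red cube is rotating".split()
hypothesis = "the red cube is rotating".split()

# same smoothing method as in calculate_bleu_score above
score = sentence_bleu([reference], hypothesis,
                      smoothing_function=SmoothingFunction().method7)
print(f"BLEU-4: {score:.3f}")
```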
167
src/models/state_tracker_model.py
Normal file
@@ -0,0 +1,167 @@
import pytorch_lightning as pl
import torch
from torch import nn
from src.models.base_model import TransformerModel
from src.utils.save_attention_weights import SaveOutput
from src.utils.custom_transformer_encoder_layer import CustomTransformerEncoderLayer
from src.state_trackers.video_state_tracker import VstLSTM
from src.state_trackers.dialogue_state_tracker import DstLSTM
from src.state_trackers.vst_transformer_based import VstTransformer
from src.state_trackers.dst_transformer_based import DstTransformer
from src.combiner.option_a import CombinerOptionA
from src.combiner.option_b import CombinerOptionB
from src.combiner.option_c import CombinerOptionC


class StateTrackerModel(TransformerModel):
    def __init__(self, config, output_path=None):
        super().__init__(config, output_path=output_path)
        self.config = config['model']
        self.ext_config = config['extended_model']

        combine_state_and_emb_options = {
            'OptionA': CombinerOptionA,
            'OptionB': CombinerOptionB,
            'OptionC': CombinerOptionC,
        }
        state_tracker_options = {
            'Transformer': {
                'vst': VstTransformer,
                'dst': DstTransformer
            },
            'LSTM': {
                'vst': VstLSTM,
                'dst': DstLSTM
            }
        }

        # if option B is used, the state vector is appended to each embedding -> the input size of the transformer needs to double
        if self.ext_config['combiner_option'] == 'OptionB':
            self.model_input_dim *= 2
            # replace the fc layer with one that fits the larger embeddings
            self.fc = nn.Linear(self.model_input_dim, self.config["fc_dim"])

        self.combiner = combine_state_and_emb_options[self.ext_config['combiner_option']](
            config=self.ext_config,
            model_input_dim=self.model_input_dim,
            use_v_state=self.ext_config['use_v_state'],
            use_d_state=self.ext_config['use_d_state']
        )
        encoder_layer = CustomTransformerEncoderLayer(
            d_model=self.model_input_dim,
            batch_first=True,
            dropout=self.config['dropout_p'],
            dim_feedforward=self.config['dim_feedforward'],
            nhead=self.config['n_heads']
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer=encoder_layer,
            num_layers=self.config['n_encoder_layers'],
        )
        self.save_output = SaveOutput()
        self.hook_handle = self.encoder.layers[-1].self_attn.register_forward_hook(self.save_output)
        if self.ext_config['use_v_state']:
            self.video_state_tracker = state_tracker_options[self.ext_config['state_tracker_type']]['vst'](
                self.model_input_dim,
                self.config['dropout_p'],
                self.ext_config
            )
        if self.ext_config['use_d_state']:
            self.dial_state_tracker = state_tracker_options[self.ext_config['state_tracker_type']]['dst'](
                self.model_input_dim,
                self.config['dropout_p'],
                self.ext_config
            )
        self.video_emb_start_idx = self.calculate_video_emb_start_idx()

    def calculate_video_emb_start_idx(self):
        video_emb_start_idx = 0
        if self.config['model_type'] == 'discriminative': video_emb_start_idx += 1
        if self.ext_config['use_v_state']: video_emb_start_idx += 1
        if self.ext_config['use_d_state']: video_emb_start_idx += 1
        return video_emb_start_idx

    def determine_relevant_obj_emb(self, attention_weights, vft):
        # determine the indices of the maximum attention values
        obj_emb = self.prepare_video_emb(vft)
        _, relevant_emb_indices = attention_weights[:, self.video_emb_start_idx:obj_emb.shape[1] + self.video_emb_start_idx].topk(k=self.ext_config['number_of_relevant_emb'], dim=1)
        relevant_emb = torch.zeros((obj_emb.shape[0], self.ext_config['number_of_relevant_emb'], obj_emb.shape[2]), device=self.device)
        for i in range(attention_weights.shape[0]):
            relevant_emb[i, :, :] = obj_emb[i, relevant_emb_indices[i, :]]

        return relevant_emb

    def get_attention_weights(self, n_vid_emb):
        if self.config['model_type'] in ['generative', 'ranking']:
            # get the attention weights of the query tokens and sum over them
            query_start_idx = self.video_emb_start_idx + n_vid_emb
            attention_weights = self.save_output.outputs[1][:, query_start_idx:, :]
            attention_weights = attention_weights.sum(dim=1)
        elif self.config['model_type'] == 'discriminative':
            # get only the attention weights of the cls token
            attention_weights = self.save_output.outputs[1][:, 0, :]
        return attention_weights

    def forward(self, batch):
        # initialize the state vectors - initialize as None if they are not used
        if self.ext_config['use_v_state']:
            video_state = torch.zeros((batch.query.shape[0], 1, self.model_input_dim), device=self.device)
        else:
            video_state = None
        if self.ext_config['use_d_state']:
            dial_state = torch.zeros((batch.query.shape[0], 1, self.model_input_dim), device=self.device)
        else:
            dial_state = None

        # create the state vectors based on the n most recent dialogue turns
        hist_start_turn_state_gen = batch.turns.shape[1] - 1 - self.ext_config["hist_len_for_state_gen"]
        for dialogue_round in range(max(0, hist_start_turn_state_gen), batch.turns.shape[1]):
            question = batch.q_turns[:, dialogue_round, :]
            question_mask = batch.q_turns_mask[:, dialogue_round, :]
            qa_pair = batch.turns[:, dialogue_round, :]
            qa_pair_mask = batch.turns_mask[:, dialogue_round, :]

            # pass the correct answer tokens to the decoder for training a generative model
            if self.config['model_type'] in ['generative', 'ranking']:
                answer = batch.a_turns[:, dialogue_round, :]
                answer_mask = batch.a_turns_mask[:, dialogue_round, :]
                # the answer is not used; only the attention weights are relevant for state creation
                _ = self.answer_query(question, question_mask, batch.vft, video_state, dial_state, answer, answer_mask, state_generation_mode=True)
            else:
                _ = self.answer_query(question, question_mask, batch.vft, video_state, dial_state)

            # update the states
            if self.ext_config['use_v_state']:
                # get the attention weights from the last "answer_query" call and determine the relevant objects
                attention_weights = self.get_attention_weights(n_vid_emb=batch.vft.shape[1])
                relevant_obj_emb = self.determine_relevant_obj_emb(attention_weights, batch.vft)
                # add ids to match the input size of the main transformer block
                video_state = self.video_state_tracker(relevant_obj_emb)
            if self.ext_config['use_d_state']:
                qa_pair_emb = self.prepare_lang_emb(qa_pair, qa_pair_mask)
                # add ids to match the input size of the main transformer block
                dial_state = self.dial_state_tracker(qa_pair_emb)

        # reset the internal state of the state trackers
        if self.ext_config['use_v_state']:
            self.video_state_tracker.reset()
        if self.ext_config['use_d_state']:
            self.dial_state_tracker.reset()

        # answer the actual question
        # pass the correct answer tokens to the decoder for training a generative model
        if self.config['model_type'] in ['generative', 'ranking']:
            output = self.answer_query(batch.query, batch.query_mask, batch.vft, video_state, dial_state, batch.answer, batch.answer_mask)
        else:
            output = self.answer_query(batch.query, batch.query_mask, batch.vft, video_state, dial_state)

        return output
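The constructor above only consumes two sub-dictionaries of the experiment config. The following dictionary is an illustrative sketch of their expected shape; the key names are taken from the code, but all values are hypothetical and the authoritative settings live in `config/dvd.json` and `config/simmc.json`:

```python
# Hypothetical values for illustration only; see config/dvd.json and config/simmc.json.
config = {
    "model": {
        "model_type": "generative",          # or "discriminative" / "ranking"
        "dropout_p": 0.1,
        "n_heads": 8,
        "n_encoder_layers": 6,
        "dim_feedforward": 2048,
        "fc_dim": 512,
    },
    "extended_model": {
        "combiner_option": "OptionA",        # OptionB doubles model_input_dim
        "state_tracker_type": "Transformer",  # or "LSTM"
        "use_v_state": True,
        "use_d_state": True,
        "number_of_relevant_emb": 5,
        "hist_len_for_state_gen": 3,
        "n_heads_state_tracker": 4,
        "num_layers_v_state": 1,
        "num_layers_d_state": 1,
        "dim_feedforward_d_transformer": 512,
    },
}
```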
29
src/object_description_encoder/object_description_encoder.py
Normal file
@@ -0,0 +1,29 @@
import pytorch_lightning as pl
from torch import nn
import torch


class ObjectDescriptionEncoder(pl.LightningModule):
    def __init__(self, d_model, config):
        super().__init__()
        self.d_model = d_model
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            batch_first=True,
            dropout=config['dropout_p'],
            dim_feedforward=config['object_feature_generator_dim'],
            nhead=config['n_heads']
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer=encoder_layer,
            num_layers=config['n_object_feature_generator_layers']
        )

    def forward(self, input):
        # prepend a zero token; its encoded output is used as the object description embedding (CLS-style)
        object_description_embedding = torch.zeros((input.shape[0], 1, self.d_model), device=self.device)
        input = torch.concat([object_description_embedding, input], dim=1)
        output = self.encoder(input)
        object_description_embedding = output[:, 0, :]
        object_description_embedding = object_description_embedding.unsqueeze(1)
        return object_description_embedding
32
src/state_trackers/dialogue_state_tracker.py
Normal file
@@ -0,0 +1,32 @@
import pytorch_lightning as pl
from torch import nn
import torch


class DstLSTM(pl.LightningModule):
    def __init__(self, emb_dim, dropout, config):
        super().__init__()
        self.lstm_layer = nn.LSTM(
            input_size=emb_dim,
            hidden_size=emb_dim,
            num_layers=config['num_layers_d_state'],
            batch_first=True,
            dropout=dropout
        )
        self.h = None
        self.c = None

    def forward(self, input):
        if self.h is None:
            _, (self.h, self.c) = self.lstm_layer(input)
        else:
            _, (self.h, self.c) = self.lstm_layer(input, (self.h, self.c))

        output = torch.permute(self.h, (1, 0, 2))
        output = output[:, -1, :]
        output = output.unsqueeze(1)
        return output

    def reset(self):
        self.h = None
        self.c = None
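A minimal usage sketch of the LSTM-based tracker (with illustrative dimensions): the hidden and cell states are kept across calls, so consecutive turns refine the same dialogue state until `reset()` is called between dialogues:

```python
import torch

# illustrative sizes, not taken from the configs
tracker = DstLSTM(emb_dim=128, dropout=0.0, config={'num_layers_d_state': 1})
turn_1 = torch.randn(2, 12, 128)   # batch of 2, 12 token embeddings per turn
turn_2 = torch.randn(2, 12, 128)

state = tracker(turn_1)            # -> shape (2, 1, 128)
state = tracker(turn_2)            # conditioned on turn_1 via the stored (h, c)
tracker.reset()                    # start a fresh dialogue
```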
36
src/state_trackers/dst_transformer_based.py
Normal file
@@ -0,0 +1,36 @@
import pytorch_lightning as pl
from torch import nn
import torch


class DstTransformer(pl.LightningModule):
    def __init__(self, emb_dim, dropout, config):
        super().__init__()
        self.emb_dim = emb_dim
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=emb_dim,
            batch_first=True,
            dropout=dropout,
            dim_feedforward=config['dim_feedforward_d_transformer'],
            nhead=config['n_heads_state_tracker']
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer=encoder_layer,
            num_layers=config['num_layers_d_state']
        )
        self.state_vector = None

    def forward(self, input):
        if self.state_vector is None:
            self.state_vector = torch.zeros((input.shape[0], 1, self.emb_dim), device=self.device)

        input = torch.concat([self.state_vector, input], dim=1)
        output = self.encoder(input)
        self.state_vector = output[:, 0, :]
        self.state_vector = self.state_vector.unsqueeze(1)
        return self.state_vector

    def reset(self):
        self.state_vector = None
36
src/state_trackers/video_state_tracker.py
Normal file
@@ -0,0 +1,36 @@
import pytorch_lightning as pl
from torch import nn
import torch


class VstLSTM(pl.LightningModule):
    def __init__(self, emb_dim, dropout, config):
        super().__init__()
        self.lstm_layer = nn.LSTM(
            input_size=emb_dim,
            hidden_size=emb_dim,
            num_layers=config['num_layers_v_state'],
            batch_first=True,
            dropout=dropout
        )
        self.h = None
        self.c = None

    def forward(self, input):
        if self.h is None:
            _, (self.h, self.c) = self.lstm_layer(input)
        else:
            _, (self.h, self.c) = self.lstm_layer(input, (self.h, self.c))

        output = torch.permute(self.h, (1, 0, 2))
        output = output[:, -1, :]
        output = output.unsqueeze(1)
        return output

    def reset(self):
        self.h = None
        self.c = None
39
src/state_trackers/vst_transformer_based.py
Normal file
@@ -0,0 +1,39 @@
import pytorch_lightning as pl
from torch import nn
import torch


class VstTransformer(pl.LightningModule):
    def __init__(self, emb_dim, dropout, config):
        super().__init__()
        self.emb_dim = emb_dim
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=emb_dim,
            batch_first=True,
            dropout=dropout,
            dim_feedforward=1 + config['number_of_relevant_emb'],
            nhead=config['n_heads_state_tracker']
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer=encoder_layer,
            num_layers=config['num_layers_v_state']
        )
        self.state_vector = None

    def forward(self, input):
        if self.state_vector is None:
            self.state_vector = torch.zeros((input.shape[0], 1, self.emb_dim), device=self.device)

        input = torch.concat([self.state_vector, input], dim=1)
        output = self.encoder(input)
        self.state_vector = output[:, 0, :]
        self.state_vector = self.state_vector.unsqueeze(1)
        return self.state_vector

    def reset(self):
        self.state_vector = None
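In contrast to the LSTM variants, the transformer-based trackers follow a CLS-style pattern: the previous state vector is prepended as an extra token and its encoded output becomes the new state. A small sketch with illustrative dimensions:

```python
import torch

# illustrative config values
cfg = {'number_of_relevant_emb': 5, 'n_heads_state_tracker': 4, 'num_layers_v_state': 1}
tracker = VstTransformer(emb_dim=128, dropout=0.1, config=cfg)

relevant_obj_emb = torch.randn(2, 5, 128)   # top-k object embeddings of the current turn
state = tracker(relevant_obj_emb)           # -> shape (2, 1, 128)
state = tracker(torch.randn(2, 5, 128))     # next turn, conditioned on the stored state
tracker.reset()
```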
0
src/utils/__init__.py
Normal file
106
src/utils/batch_interfaces.py
Normal file
@@ -0,0 +1,106 @@
import torch
from dataclasses import dataclass
from typing import Optional

@dataclass
class Batch:
    query: torch.Tensor
    query_mask: torch.Tensor
    vft: torch.Tensor
    turns: torch.Tensor
    turns_mask: torch.Tensor
    q_turns: torch.Tensor
    q_turns_mask: torch.Tensor
    a_turns: torch.Tensor
    a_turns_mask: torch.Tensor
    answer: torch.Tensor
    answer_mask: torch.Tensor
    answer_candidates: Optional[torch.Tensor] = None
    answer_candidates_mask: Optional[torch.Tensor] = None


# ---- TODO: Replace with a function for the Mask R-CNN features ----
def create_monet_like_vft(vft):
    target_dim = 36
    remainder = vft.shape[1] % target_dim
    vft = vft[:, :-remainder].reshape((vft.shape[0], -1, target_dim))
    vft = vft.unsqueeze(3)
    return vft


def batch_interface_simmc2_to_dvd(batch, feature_type):
    if feature_type == 'resnet50':
        vft = batch['features']
        vft = vft.unsqueeze(3)
    elif feature_type == "object_text_features":
        vft = batch['object_features']
        # add a frame dimension (only one frame in this case)
        #vft = vft.unsqueeze(1)
        # bring the data to the format [batch_size x frames x emb_dim (desc_text_len) x obj_number]
        #vft = vft.permute(0, 1, 3, 2)

    batch_in_dvd_format = Batch(
        query=batch['query'],
        query_mask=(batch['query'] != 1),
        vft=vft,
        turns=batch['turns'],
        turns_mask=(batch['turns'] != 1),
        q_turns=batch['q_turns'],
        q_turns_mask=(batch['q_turns'] != 1),
        a_turns=batch['a_turns'],
        a_turns_mask=(batch['a_turns'] != 1),
        answer=batch['answer'].type(torch.int64),
        answer_mask=(batch['answer'] != 1),
        answer_candidates=batch['answer_candidates'],
        answer_candidates_mask=(batch['answer_candidates'] != 1)
    )
    return batch_in_dvd_format


def batch_interface_avsd_to_dvd(batch, feature_type):
    # map the last question to the query
    query = batch['ques'][:, -1, :]
    query_mask = (query != 1)

    # map vid_feat to vft
    # TODO: Use other video features ------!!!-------
    if feature_type == 'i3d':
        vft = create_monet_like_vft(batch['vid_feat'])
    else:
        vft = batch['vid_feat']

    q_turns = batch['ques'][:, :9, :]
    q_turns_mask = (q_turns != 1)

    index_tensor = batch['ans_ind'].unsqueeze(2)
    index_tensor = index_tensor.repeat(1, 1, 20)
    index_tensor = index_tensor.unsqueeze(2)
    a_turns = batch['opt'].gather(2, index_tensor)
    a_turns = a_turns.squeeze(2)

    # turns should only contain the previous questions (first 9 turns)
    a_turns, answer = a_turns.split([9, 1], dim=1)
    answer = answer.squeeze(1)
    a_turns_mask = (a_turns != 1)
    answer_mask = (answer != 1)

    # concatenate the questions and a_turns to create the turns tensor
    turns = torch.concat((q_turns, a_turns), 2)
    turns_mask = (turns != 1)

    batch_in_dvd_format = Batch(
        query,
        query_mask,
        vft,
        turns,
        turns_mask,
        q_turns,
        q_turns_mask,
        a_turns,
        a_turns_mask,
        answer,
        answer_mask
    )
    return batch_in_dvd_format
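Both interfaces rely on the convention that token id 1 is the padding id, so every `*_mask` is simply `tensor != 1`. A quick sketch with dummy values:

```python
import torch

query = torch.tensor([[2, 17, 43, 1, 1]])   # one query, right-padded with id 1
query_mask = (query != 1)
print(query_mask)   # tensor([[ True,  True,  True, False, False]])
```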
84
src/utils/custom_transformer_encoder_layer.py
Normal file
@@ -0,0 +1,84 @@
# https://pytorch.org/docs/stable/_modules/torch/nn/modules/transformer.html#TransformerEncoderLayer
from typing import Optional, Any, Union, Callable

from torch import nn

import torch
from torch import Tensor
from torch.nn import functional as F
from torch.nn.modules import Module
from torch.nn import MultiheadAttention
#from nn.container import ModuleList
#from ..init import xavier_uniform_
from torch.nn import Dropout
from torch.nn import Linear
from torch.nn import LayerNorm


class CustomTransformerEncoderLayer(Module):

    __constants__ = ['batch_first', 'norm_first']

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
                                            **factory_kwargs)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        # keep a reference to the activation; needed by _ff_block below
        self.activation = activation

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super().__setstate__(state)

    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """

        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf

        x = src
        if self.norm_first:
            x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
            x = x + self._ff_block(self.norm2(x))
        else:
            x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
            x = self.norm2(x + self._ff_block(x))

        return x

    # self-attention block
    def _sa_block(self, x: Tensor,
                  attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:
        x = self.self_attn(x, x, x,
                           attn_mask=attn_mask,
                           key_padding_mask=key_padding_mask,
                           need_weights=True)[0]
        return self.dropout1(x)

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)
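Because `_sa_block` calls `self.self_attn` with `need_weights=True`, the attention weights can be captured with a standard forward hook on the attention module (this is what `SaveOutput` does in `StateTrackerModel`). A minimal sketch with toy dimensions:

```python
import torch

layer = CustomTransformerEncoderLayer(d_model=32, nhead=4, batch_first=True)
captured = {}

def save_attention(module, inputs, output):
    # MultiheadAttention returns (attn_output, attn_weights) when need_weights=True
    captured['weights'] = output[1]          # shape (batch, tgt_len, src_len)

handle = layer.self_attn.register_forward_hook(save_attention)
_ = layer(torch.randn(2, 10, 32))
print(captured['weights'].shape)             # torch.Size([2, 10, 10])
handle.remove()
```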
0
src/utils/dvd_codebase/__init__.py
Normal file
39
src/utils/dvd_codebase/configs/configs.py
Normal file
@@ -0,0 +1,39 @@
"""
Copyright (c) Facebook, Inc. and its affiliates.
All rights reserved.
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""

import argparse
import logging
import random
import numpy as np

parser = argparse.ArgumentParser()
parser.add_argument('--debug', default=0, type=int, help='')

# Data /projects/hochmeister/CATER-videos/features/per_video'
#parser.add_argument('--fea-dir', default='/scratch/hochmeister/CATER-videos/features/monet_pretrained_on_clevr/per_video', type=str, help='Image feature files (.pkl)')
#parser.add_argument('--data-dir', default='/scratch/hochmeister/DVDData/small_subset/', type=str, help='Path to training feature files')
parser.add_argument('--output-dir', default='/scratch/abdessaied/projects/olvit/msc2022_hochmeister/checkpoints/avsd_code_testing', type=str, help='output path of model and params')
parser.add_argument('--num-workers', default=20, type=int, help='')
parser.add_argument('--device', default='0', type=str, help='')

# Training
parser.add_argument('--num-epochs', '-e', default=15, type=int, help='Number of epochs')
#parser.add_argument('--batch-size', '-b', default=85, type=int, help='Batch size in training')
# others
parser.add_argument('--verbose', '-v', default=0, type=int, help='verbose level')

args, unknown = parser.parse_known_args()

print(args)

# Presetting
if args.verbose >= 1:
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s')
else:
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s %(levelname)s: %(message)s')
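Note that the parser uses `parse_known_args()`, so flags it does not define (for example ones consumed by other parts of the pipeline) are tolerated rather than raising an error. A quick illustration with a hypothetical unknown flag:

```python
# hypothetical command line, for illustration only
args, unknown = parser.parse_known_args(['--num-epochs', '5', '--some-unknown-flag', '1'])
print(args.num_epochs)   # 5
print(unknown)           # ['--some-unknown-flag', '1']
```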
0
src/utils/dvd_codebase/data/__init__.py
Normal file
282
src/utils/dvd_codebase/data/analysis_utils.py
Normal file
@@ -0,0 +1,282 @@
"""
Copyright (c) Facebook, Inc. and its affiliates.
All rights reserved.
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""

import glob, json, pdb
from tqdm import tqdm
import pandas as pd
import copy, os

def get_question_type(template, prior_template):
    last_node_type = template['nodes'][-1]['type']
    text = template['text'][0].lower()
    if 'same set of activities' in text:
        qtype = 'compare action set'
    elif 'same sequence of activities' in text:
        qtype = 'compare action sequence'
    elif 'frequently' in text:
        qtype = 'compare int'
    elif 'how many times' in text:
        qtype = 'action count'
    elif 'how many' in text or 'what number' in text:
        qtype = 'obj count'
    elif 'is there' in text:
        qtype = 'obj exist'
    elif 'what color' in text or 'what material' in text or 'what shape' in text or 'what size' in text:
        qtype = 'attr query'
    elif 'what type of action' in text or 'what is the' in text or 'what types of action' in text:
        qtype = 'action query'
    else:
        assert 'what about' in text
        qtype = get_question_type(prior_template, None)
    return qtype

def get_question_subtype(template, prior_template):
    last_node_type = template['nodes'][-1]['type']
    text = template['text'][0].lower()
    if 'same set of activities' in text:
        if 'how many' in text:
            qtype = 'compare action set (count)'
        else:
            qtype = 'compare action set (exist)'
    elif 'same sequence of activities' in text:
        if 'how many' in text:
            qtype = 'compare action seq (count)'
        else:
            qtype = 'compare action seq (exist)'
    elif 'frequently' in text:
        if 'as frequently' in text:
            qtype = 'compare int (equal)'
        elif 'less frequently' in text:
            qtype = 'compare int (less)'
        elif 'more frequently' in text:
            qtype = 'compare int (more)'
    elif 'how many times' in text:
        qtype = 'action count'
    elif 'how many' in text or 'what number' in text:
        qtype = 'obj count'
    elif 'is there' in text:
        qtype = 'obj exist'
    elif 'what color' in text or 'what about its color' in text:
        qtype = 'attr query (color)'
    elif 'what material' in text or 'what about its material' in text:
        qtype = 'attr query (material)'
    elif 'what shape' in text or 'what about its shape' in text:
        qtype = 'attr query (shape)'
    elif 'what size' in text or 'what about its size' in text:
        qtype = 'attr query (size)'
    elif 'what type of action' in text or 'what is the' in text or 'what types of action' in text:
        if '<o>' in text:
            qtype = 'action query (by order)'
        elif '<f>' in text:
            qtype = 'action query (by freq)'
        else:
            qtype = 'action query (all actions)'
    else:
        assert 'what about' in text
        assert 'color' not in text and 'size' not in text and \
            'shape' not in text and 'material' not in text
        qtype = get_question_subtype(prior_template, None)
    return qtype

def get_question_complexity(turn, template_fn):
    template = turn['template']
    interval_type = template['interval_type']
    last_node_type = template['nodes'][-1]['type']
    second_last_node_type = template['nodes'][-2]['type']

    if interval_type == 'none':
        return 'none'
    elif interval_type == 'atomic':
        if 'one_hop' in template_fn:
            return 'atomic (spatial)'
        else:
            return 'atomic (non-spatial)'
        #return 'atomic'
    elif interval_type == 'compositional':
        return 'compositional'

def get_accuracies_by_type(all_types, models, all_answers, all_results, output_file):
    types = sorted(set(all_types))
    accuracies = {}
    for t in types:
        accuracies[t] = []
        for model in models:
            nb_corrects = 0
            count = 0
            results = all_results[model]
            for a_idx, a in enumerate(all_answers):
                curr_type = all_types[a_idx]
                if curr_type != t: continue
                pred = results[a_idx]
                if str(pred) == str(a):
                    nb_corrects += 1
                count += 1
            acc = nb_corrects/count
            accuracies[t].append(acc)
    df = copy.deepcopy(accuracies)
    df['model'] = models
    df = pd.DataFrame(data=df, columns=['model'] + list(accuracies.keys()))
    df.to_csv(output_file)
    return types, accuracies, df

def get_transfer_accuracies(all_types, models, all_answers, all_results, output_file, is_video_update=False, is_all=False):
    accuracies = []
    for model in models:
        results = all_results[model]
        nb_corrects = 0
        count = 0
        for a_idx, a in enumerate(all_answers):
            if is_all:
                is_single_turn = True
                for k, v in all_types.items():
                    if v[a_idx] != 'none':
                        is_single_turn = False
                        break
                if is_single_turn: continue
            else:
                curr_type = all_types[a_idx]
                if is_video_update:
                    if curr_type != 'video_update': continue
                else:
                    if curr_type != 'yes': continue
            prior_pred_a = results[a_idx-1]
            prior_gt_a = all_answers[a_idx-1]
            if str(prior_pred_a) != str(prior_gt_a): continue
            pred_a = results[a_idx]
            gt_a = all_answers[a_idx]
            if str(pred_a) == str(gt_a):
                nb_corrects += 1
            count += 1
        if count == 0:
            acc = 0
        else:
            #pdb.set_trace()
            acc = nb_corrects/count
        accuracies.append(acc)
    df = {}
    df['accuracies'] = accuracies
    df['model'] = models
    df = pd.DataFrame(data=df, columns=['model', 'accuracies'])
    df.to_csv(output_file)
    return df

def get_start_end_time(period):
    start, end = period
    if start is None:
        start = 0
    else:
        start = start[-1]
    if end is None:
        end = 301
    else:
        end = end[-1]
    return start, end

def get_period_size(period):
    if period is None:
        return 0
    start, end = get_start_end_time(period)
    return end - start

def get_overlap_period(curr_period, last_period, ratio=False):
    if curr_period is None:
        return -1
    if last_period is None:
        return 0
    s1, e1 = get_start_end_time(curr_period)
    s2, e2 = get_start_end_time(last_period)
    if s2 < e1 and s1 < e2:
        if ratio:
            return get_period_ratio_bin((min(e1, e2)-max(s1, s2))/(e2-s2))
        else:
            return (min(e1, e2)-max(s1, s2))
    else:
        return 0

def get_period_distance(curr_period, last_period, point='start'):
    if curr_period is None:
        return -1
    if last_period is None:
        return -1
    s1, e1 = get_start_end_time(curr_period)
    s2, e2 = get_start_end_time(last_period)
    if point == 'start':
        return abs(s1-s2)
    elif point == 'end':
        return abs(e1-e2)

def get_period_ratio_bin(ratio):
    if ratio == 0:
        return 0
    for n in range(0, 10):
        if ratio*10 > n:
            bin = n
        else:
            break
    return bin

def get_obj_turn_dist(used_objects, dependencies, template, turn_idx):
    all_dists = [0]

    if dependencies['object'] != 'none':
        if dependencies['object'] == 'earlier_unique':
            obj_id = str(template['earlier_unique_obj'])
            if obj_id not in used_objects:
                pdb.set_trace()
            turn_dist = turn_idx - used_objects[obj_id]['original_turn'] + 1
            all_dists.append(turn_dist)

    if dependencies['temporal'] != 'none':
        if 'earlier_unique' in dependencies['temporal']:
            obj_id = str(template['temporal_obj_id'])
            if obj_id not in used_objects:
                pdb.set_trace()
            turn_dist = turn_idx - used_objects[obj_id]['original_turn'] + 1
            all_dists.append(turn_dist)

    return max(all_dists)

def get_stats(dials):
    videos = set()
    questions = set()
    for dial in dials:
        for turn in dial:
            question = turn['question']
            video = '{}-{}'.format(turn['split'], turn['image_filename'])
            videos.add(video)
            questions.add(question)
    print('# videos: {}'.format(len(videos)))
    print("# dialogues: {}".format(len(dials)))
    print("# unique questions: {}".format(len(questions)))
    output = {
        '#videos': len(videos),
        '#dialogues': len(dials),
        '#unique questions': len(questions)
    }
    return output

def find_video_end_range(end_time):
    ranges = [0, 30, 60, 90, 120, 150, 180, 210, 240, 270]
    if end_time is None:
        return 9
    for idx, r in enumerate(ranges):
        if end_time[-1] > r:
            curr_r = idx
        else:
            return curr_r
    return 9

def find_video_start_range(start_time):
    ranges = [400, 270, 240, 210, 180, 150, 120, 90, 60, 30]
    if start_time is None:
        return 0
    for idx, r in enumerate(ranges):
        if start_time[-1] <= r:
            curr_r = 9-idx
        else:
            return curr_r
    return 0
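A toy illustration (with a hypothetical, minimal template dict) of the keyword-based typing implemented above:

```python
# hypothetical template structure, reduced to the fields get_question_type reads
template = {
    'nodes': [{'type': 'exist'}],
    'text': ['Is there a red cube that is sliding?'],
}
print(get_question_type(template, None))   # -> 'obj exist'
```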
264
src/utils/dvd_codebase/data/data_handler.py
Normal file
@@ -0,0 +1,264 @@
"""
Copyright (c) Facebook, Inc. and its affiliates.
All rights reserved.
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""

import copy, logging, sys, time, os, pdb, random, glob, json
import pickle as pkl
import numpy as np
from tqdm import tqdm
from collections import Counter
from functools import partial
import nltk
import torch
import torch.utils.data as Data
from src.utils.dvd_codebase.data.dataset import *
from src.utils.dvd_codebase.data.analysis_utils import *
from src.utils.dvd_codebase.data.data_utils import *
from src.utils.dvd_codebase.data.analysis_utils import get_question_subtype, get_question_complexity
from transformers import AutoTokenizer


def load_dials(args, split):
    files = []
    for video_split in ['all_actions', 'max2action']:
        files += glob.glob(args.data_dir + '{}_{}_*/*.json'.format(video_split, split))
    files = sorted(files)  # [:50]
    if args.debug:
        files = files[:100]
    all_dials = []
    vid_set = {}
    for file in tqdm(files, total=len(files)):
        dials = json.load(open(file))
        all_dials.extend(dials)
        video_split = dials[0][0]['split']
        vid = dials[0][0]['image'].replace('CLEVR', 'CATER')
        vid_key = '{}-{}'.format(video_split, vid)
        if vid_key not in vid_set:
            vid_set[vid_key] = '{}/{}/{}.pkl'.format(args.fea_dir, video_split, vid)
    return all_dials, vid_set


def load_videos(args, vid_set):
    vid_fts = {}
    ft_dims = None
    size, stride = -1, -1
    segment_map = {}
    for vid_key, fea_file in tqdm(vid_set.items(), total=len(vid_set)):
        #fea_file = '{}/{}.pkl'.format(args.fea_dir, vid)
        fea = pkl.load(open(fea_file, 'rb'))
        output = []
        for clip_idx, clip in enumerate(fea['clips']):
            fea = clip['features']
            if len(fea.shape) == 3:
                fea = fea.transpose(1, 2, 0)
            output.append(fea)
            start, end = clip['segment']
            if clip_idx not in segment_map:
                segment_map[clip_idx] = (start, end)
            if size == -1:
                size = end - start + 1
            if clip_idx > 0 and stride == -1:
                stride = start - prior_start
            prior_start, prior_end = start, end
        vft = np.asarray(output)
        vid_fts[vid_key] = vft
        if ft_dims is None:
            ft_dims = vft.shape
    return vid_fts, ft_dims, size, stride, segment_map


def load_video_features(args, vid_set):
    vid_fts = {}
    for vid_key, fea_file in tqdm(vid_set.items(), total=len(vid_set)):
        #fea_file = '{}/{}.pkl'.format(args.fea_dir, vid)
        fea = pkl.load(open(fea_file, 'rb'))
        vid_fts[vid_key] = fea
    return vid_fts


def get_vocabulary(dials, args, vocab=None):
    #answer_options = set()
    word_freq = {}
    for dialog in tqdm(dials, total=len(dials)):
        for turn in dialog:
            for word in nltk.word_tokenize(turn['question']):
                if word not in word_freq: word_freq[word] = 0
                word_freq[word] += 1
            answer = str(turn['answer'])
            #answer_options.add(answer)
            for word in nltk.word_tokenize(answer):
                if word not in word_freq: word_freq[word] = 0
                word_freq[word] += 1
            program = turn['final_all_program']
            for n in program:
                if n['type'] == 'identity': continue
                if n['type'] not in word_freq: word_freq[n['type']] = 0
                word_freq[n['type']] += 1
                if 'side_inputs' in n:
                    for side_input in n['side_inputs']:
                        for word in nltk.word_tokenize(side_input):
                            if word not in word_freq: word_freq[word] = 0
                            word_freq[word] += 1
    if vocab is not None:
        unk_words = set()
        for word, freq in word_freq.items():
            if word not in vocab:
                unk_words.add(word)
        return unk_words
    vocab = {'<unk>': 0, '<blank>': 1, '<sos>': 2, '<eos>': 3, '<eoo>': 4}
    for word, freq in word_freq.items():
        vocab[word] = len(vocab)
    answer_options = ['0', '1', '10', '2', '3', '4', '5', '6', '7', '8', '9', 'False', 'True', 'blue', 'brown', 'cone', 'cube', 'cyan', 'cylinder', 'flying', 'flying,rotating', 'flying,rotating,sliding', 'flying,sliding', 'gold', 'gray', 'green', 'large', 'medium', 'metal', 'no action', 'purple', 'red', 'rotating', 'rotating,sliding', 'rubber', 'sliding', 'small', 'sphere', 'spl', 'yellow']
    return vocab, answer_options


def answer_by_question_type(dials):
    qa_dist = {}
    for dialog in dials:
        for turn_idx, turn in enumerate(dialog):
            answer = turn['answer']
            template = turn['template']
            if turn_idx > 0:
                prior_template = dialog[turn_idx-1]['template']
            else:
                prior_template = None
            qtype = get_question_subtype(template, prior_template)
            if qtype not in qa_dist:
                qa_dist[qtype] = {}
            if answer not in qa_dist[qtype]:
                qa_dist[qtype][answer] = 0
            qa_dist[qtype][answer] += 1
    return qa_dist


# Load text data
def create_dials(dials, vocab, answer_list, vft_data, args, tokenizer=None):
    dialog_list = []
    qa_id = 0
    for dialog in tqdm(dials, total=len(dials)):
        if tokenizer is None:
            questions = [words2ids(t['question'], vocab) for t in dialog]
            answers = [words2ids(str(t['answer']), vocab) for t in dialog]
        else:
            questions = [words2ids_pretrained_lm(t['question'], vocab, tokenizer) for t in dialog]
            answers = [words2ids_pretrained_lm(str(t['answer']), vocab, tokenizer) for t in dialog]
        answer_output = [[answer_list.index(str(t['answer']))] for t in dialog]
        qa_pair = [np.concatenate((q, a)).astype(np.int32) for q, a in zip(questions, answers)]

        attribute_dependencies = []
        object_dependencies = []
        temporal_dependencies = []
        spatial_dependencies = []
        q_types = []
        q_complexities = []
        for i, t in enumerate(dialog):
            # determine the type of turn relation
            attribute_dependencies.append(t['turn_dependencies']['attribute'])
            object_dependencies.append(t['turn_dependencies']['object'])
            temporal_dependencies.append(t['turn_dependencies']['temporal'])
            spatial_dependencies.append(t['turn_dependencies']['spatial'])

            # determine the question type based on the template for analysis reasons
            if i == 0:
                q_types.append(get_question_type(t['template'], None))
            else:
                q_types.append(get_question_type(t['template'], dialog[i-1]['template']))

            # get question complexity
            q_complexities.append(get_question_complexity(t, t['template_filename']))

            # get image name
            video_name = t['image']

        vid_cutoffs = [t['template']['cutoff'] for t in dialog]
        gt_vid_periods = [t['template']['used_periods'][-1] for t in dialog]
        programs = [program2ids(t['final_all_program'], vocab) for t in dialog]
        states = [state2ids(t['template']['used_objects'], vocab) for t in dialog]
        vid = dialog[0]['image'].replace('CLEVR', 'CATER')
        vid_split = dialog[0]['split']
        vid_key = '{}-{}'.format(vid_split, vid)
        whole_vft_fea = vft_data[vid_key]
        turn_based_vft_fea = []

        # cut off the unused vft data based on the vid_cutoffs
        for t_idx, t_cutoff in enumerate(vid_cutoffs):
            if t_cutoff is not None:
                t_vft_fea = whole_vft_fea[:t_cutoff[3], :, :]
            else:
                t_vft_fea = whole_vft_fea
            turn_based_vft_fea.append(t_vft_fea)

        for n in range(len(questions)):
            start_turn_idx = 0
            history = np.asarray([])
            turns = []
            q_turns = []
            a_turns = []
            for m in range(start_turn_idx, n):
                history = np.append(history, qa_pair[m])
                turns.append(qa_pair[m])
                q_turns.append(questions[m])
                a_turns.append(np.array(answer_output[m]))

            question = questions[n]
            answer = answer_output[n]
            program = programs[n]
            state = states[n]
            gt_period = gt_vid_periods[n]
            q_type = q_types[n]
            attribute_dependency = attribute_dependencies[n]
            object_dependency = object_dependencies[n]
            temporal_dependency = temporal_dependencies[n]
            spatial_dependency = spatial_dependencies[n]
            q_complexity = q_complexities[n]
            vft_feat = turn_based_vft_fea[n]

            item = [vid_split, vid, qa_id, history, question, answer, turns,
                    q_turns, a_turns, vft_feat, gt_period,
                    program, state, q_type, attribute_dependency, object_dependency,
                    temporal_dependency, spatial_dependency, video_name, q_complexity]

            dialog_list.append(item)
            qa_id += 1

    data = {'dialogs': dialog_list, 'vocab': vocab, 'answer': answer_list, 'features': []}
    return data


def create_dataset(data, vocab, split, args):
    out = {}
    keys = ['vid_split', 'vid', 'qa_id', 'history', 'question', 'answer', 'turns',
            'q_turns', 'a_turns', 'vft', 'gt_period',
            'program', 'state', 'q_type', 'attribute_dependency', 'object_dependency',
            'temporal_dependency', 'spatial_dependency', 'video_name', 'q_complexity']
    for key in keys:
        out[key] = []
    for dialog in data['dialogs']:
        out['vid_split'].append(dialog[0])
        out['vid'].append(dialog[1])
        out['qa_id'].append(dialog[2])
        out['history'].append(dialog[3])
        out['question'].append(dialog[4])
        out['answer'].append(dialog[5])
        out['turns'].append(dialog[6])
        out['q_turns'].append(dialog[7])
        out['a_turns'].append(dialog[8])
        out['vft'].append(dialog[9])
        out['gt_period'].append(dialog[10])
        out['program'].append(dialog[11])
        out['state'].append(dialog[12])
        out['q_type'].append(dialog[13])
        out['attribute_dependency'].append(dialog[14])
        out['object_dependency'].append(dialog[15])
        out['temporal_dependency'].append(dialog[16])
        out['spatial_dependency'].append(dialog[17])
        out['video_name'].append(dialog[18])
        out['q_complexity'].append(dialog[19])

    dataset = Dataset(out)
    data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                              batch_size=args.batch_size,
                                              shuffle=(split == 'train'),
                                              collate_fn=partial(collate_fn, vocab=vocab),
                                              num_workers=args.num_workers,
                                              pin_memory=True)
    return data_loader, len(out['vid'])
169
src/utils/dvd_codebase/data/data_utils.py
Normal file
169
src/utils/dvd_codebase/data/data_utils.py
Normal file
|
@ -0,0 +1,169 @@
|
||||||
|
"""
|
||||||
|
Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
All rights reserved.
|
||||||
|
This source code is licensed under the license found in the
|
||||||
|
LICENSE file in the root directory of this source tree.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import os
|
||||||
|
import six
|
||||||
|
import pickle
|
||||||
|
import json
|
||||||
|
import numpy as np
|
||||||
|
import pdb
|
||||||
|
from tqdm import tqdm
|
||||||
|
import torch
|
||||||
|
import nltk
|
||||||
|
|
||||||
|
def subsequent_mask(size):
|
||||||
|
"Mask out subsequent positions."
|
||||||
|
attn_shape = (1, size, size)
|
||||||
|
subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
|
||||||
|
return torch.from_numpy(subsequent_mask) == 0
|
||||||
|
|
||||||
|
def get_npy_shape(filename):
|
||||||
|
with open(filename, 'rb') as f:
|
||||||
|
if filename.endswith('.pkl'):
|
||||||
|
shape = pickle.load(f).shape
|
||||||
|
else:
|
||||||
|
pdb.set_trace()
|
||||||
|
major, minor = np.lib.format.read_magic(f)
|
||||||
|
shape, fortran, dtype = np.lib.format.read_array_header_1_0(f)
|
||||||
|
return shape
|
||||||
|
|
||||||
|
def words2ids(str_in, vocab):
|
||||||
|
words = nltk.word_tokenize(str_in)
|
||||||
|
sentence = np.ndarray(len(words)+2, dtype=np.int32)
|
||||||
|
sentence[0]=vocab['<sos>']
|
||||||
|
for i,w in enumerate(words):
|
||||||
|
if w in vocab:
|
||||||
|
sentence[i+1] = vocab[w]
|
||||||
|
else:
|
||||||
|
sentence[i+1] = vocab['<unk>']
|
||||||
|
sentence[-1]=vocab['<eos>']
|
||||||
|
return sentence
|
||||||
|
|
||||||
|
def words2ids_pretrained_lm(str_in, vocab, tokenizer):
|
||||||
|
# based on: https://medium.com/@dhartidhami/understanding-bert-word-embeddings-7dc4d2ea54ca
|
||||||
|
text = tokenizer.cls_token + str_in + tokenizer.eos_token
|
||||||
|
tokenized_text = tokenizer.tokenize(text)
|
||||||
|
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
|
||||||
|
token_array = np.array([indexed_tokens])
|
||||||
|
token_array = np.reshape(token_array, (-1,))
|
||||||
|
return token_array
|
||||||
|
|
||||||
|
|
||||||
|
def program2ids(program, vocab):
|
||||||
|
sentence = []
|
||||||
|
return np.asarray(sentence, dtype=np.int32)
|
||||||
|
for n in program:
|
||||||
|
t = n['type']
|
||||||
|
if t == 'identity': continue
|
||||||
|
if t not in vocab:
|
||||||
|
print(t)
|
||||||
|
pdb.set_trace()
|
||||||
|
#else:
|
||||||
|
# t = new_nodes[t]
|
||||||
|
sentence.append(vocab[t])
|
||||||
|
if 'side_inputs' in n:
|
||||||
|
if len(n['side_inputs'])!=1:
|
||||||
|
assert type(n['side_inputs']) == str
|
||||||
|
words = n['side_inputs']
|
||||||
|
else:
|
||||||
|
words = n['side_inputs'][0]
|
||||||
|
words = nltk.word_tokenize(words)
|
||||||
|
for word in words:
|
||||||
|
if word in vocab:
|
||||||
|
sentence.append(vocab[word])
|
||||||
|
else:
|
||||||
|
sentence.append(vocab['<unk>'])
|
||||||
|
#if len(sentence)==0:
|
||||||
|
# pdb.set_trace()
|
||||||
|
# sentence=np.asarray([vocab['<eop>']])
|
||||||
|
return np.asarray(sentence, dtype=np.int32)
|
||||||
|
|
||||||
|
def state2ids_dot(state, dot_vocab, max_dot_size=10):
|
||||||
|
ordered_attrs = ['<Z>', '<C>', '<M>', '<S>']
|
||||||
|
ids = {}
|
||||||
|
for a in ordered_attrs:
|
||||||
|
ids[a] = []
|
||||||
|
for o in range(max_dot_size):
|
||||||
|
ids[a].append(dot_vocab[a]['<blank>'])
|
||||||
|
if len(state)==0:
|
||||||
|
return ids
|
||||||
|
sorted_state = {k: v for k, v in sorted(state.items(), key=lambda item: item[1]['original_turn'])}
|
||||||
|
state_idx = 0
|
||||||
|
for k,v in sorted_state.items():
|
||||||
|
for a in ordered_attrs:
|
||||||
|
if a in v:
|
||||||
|
ids[a][state_idx] = dot_vocab[a][v[a]]
|
||||||
|
state_idx += 1
|
||||||
|
ids = {k:np.asarray(v, dtype=np.int32) for k,v in ids.items()}
|
||||||
|
return ids
|
||||||
|
|
||||||
|
def state2ids(state, vocab):
|
||||||
|
return np.asarray([], dtype=np.int32)
|
||||||
|
if len(state)==0:
|
||||||
|
return np.asarray([vocab['<eoo>']], dtype=np.int32)
|
||||||
|
sentence = []
|
||||||
|
ordered_attrs = ['<Z>', '<C>', '<M>', '<S>']
|
||||||
|
#print(state)
|
||||||
|
sorted_state = {k: v for k, v in sorted(state.items(), key=lambda item: item[1]['original_turn'])}
|
||||||
|
|
||||||
|
for k,v in sorted_state.items():
|
||||||
|
found_obj = False
|
||||||
|
for a in ordered_attrs:
|
||||||
|
if a in v:
|
||||||
|
sentence.append(vocab[v[a]])
|
||||||
|
found_obj = True
|
||||||
|
if found_obj:
|
||||||
|
sentence.append(vocab['<eoo>'])
|
||||||
|
if len(sentence)==0:
|
||||||
|
return np.asarray([vocab['<eoo>']], dtype=np.int32)
|
||||||
|
return np.asarray(sentence, dtype=np.int32)
|
||||||
|
|
||||||
|
def get_vft_size_by_timestamp(time, segment_map, event_type='end', threshold=5):
|
||||||
|
if time is None:
|
||||||
|
if event_type == 'end':
|
||||||
|
return len(segment_map)-1
|
||||||
|
else:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
if event_type == 'end':
|
||||||
|
segment_idx = -1
|
||||||
|
for idx in range(len(segment_map)):
|
||||||
|
segment_range = segment_map[idx]
|
||||||
|
if segment_range[1]>time[-1]:
|
||||||
|
segment_idx = idx-1
|
||||||
|
break
|
||||||
|
if segment_idx == -1:
|
||||||
|
segment_idx = 0
|
||||||
|
return segment_idx
|
||||||
|
|
||||||
|
else:
|
||||||
|
segment_idx = -1
|
||||||
|
for idx in range(len(segment_map)):
|
||||||
|
segment_range = segment_map[idx]
|
||||||
|
if segment_range[0]>=time[-1]:
|
||||||
|
segment_idx = idx
|
||||||
|
break
|
||||||
|
if segment_idx == -1:
|
||||||
|
segment_idx = len(segment_map)-1
|
||||||
|
return segment_idx
|
||||||
|
|
||||||
|
|
||||||
|
def get_vft_range_by_period(period, segment_map, eov):
    if period is None:
        return (0, eov)
    else:
        start_time, end_time = period
        start_vft = get_vft_size_by_timestamp(start_time, segment_map, 'start')
        end_vft = get_vft_size_by_timestamp(end_time, segment_map, 'end')
        if start_vft > end_vft:
            start_vft, end_vft = end_vft, start_vft
        return (start_vft, end_vft)
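
# Minimal usage sketch (illustrative values, not taken from the dataset): segment_map is
# assumed to hold one (start_frame, end_frame) pair per video feature segment, and a period
# is a (start_time, end_time) pair whose entries are list-like, as used above.
#
#   segment_map = [(0, 16), (16, 32), (32, 48), (48, 64)]
#   eov = len(segment_map) - 1
#   get_vft_range_by_period(None, segment_map, eov)           # -> (0, 3): the whole video
#   get_vft_range_by_period(([10], [40]), segment_map, eov)   # -> (1, 1) for this toy map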

src/utils/dvd_codebase/data/dataset.py (new file, 255 lines)
"""
|
||||||
|
Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
All rights reserved.
|
||||||
|
This source code is licensed under the license found in the
|
||||||
|
LICENSE file in the root directory of this source tree.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import os
|
||||||
|
import six
|
||||||
|
import pickle
|
||||||
|
import json
|
||||||
|
import numpy as np
|
||||||
|
import pdb
|
||||||
|
from tqdm import tqdm
|
||||||
|
import torch
|
||||||
|
import torch.utils.data as Data
|
||||||
|
from torch.autograd import Variable
|
||||||
|
from src.utils.dvd_codebase.data.data_utils import *
|
||||||
|
|
||||||
|
class Dataset(Data.Dataset):
|
||||||
|
def __init__(self, data_info):
|
||||||
|
self.vid_split = data_info['vid_split']
|
||||||
|
self.vid = data_info['vid']
|
||||||
|
self.qa_id = data_info['qa_id']
|
||||||
|
self.history = data_info['history']
|
||||||
|
self.question = data_info['question']
|
||||||
|
self.answer = data_info['answer']
|
||||||
|
self.turns = data_info['turns']
|
||||||
|
self.q_turns = data_info['q_turns']
|
||||||
|
self.a_turns = data_info['a_turns']
|
||||||
|
self.vft = data_info['vft']
|
||||||
|
self.gt_period = data_info['gt_period']
|
||||||
|
self.program = data_info['program']
|
||||||
|
self.state = data_info['state']
|
||||||
|
self.q_type = data_info['q_type']
|
||||||
|
self.attribute_dependency = data_info['attribute_dependency']
|
||||||
|
self.object_dependency = data_info['object_dependency']
|
||||||
|
self.temporal_dependency = data_info['temporal_dependency']
|
||||||
|
self.spatial_dependency = data_info['spatial_dependency']
|
||||||
|
self.video_name = data_info['video_name']
|
||||||
|
self.q_complexity = data_info['q_complexity']
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
item_info = {
|
||||||
|
'vid_split': self.vid_split[index],
|
||||||
|
'vid':self.vid[index],
|
||||||
|
'qa_id': self.qa_id[index],
|
||||||
|
'history': self.history[index],
|
||||||
|
'turns': self.turns[index],
|
||||||
|
'q_turns': self.q_turns[index],
|
||||||
|
'a_turns': self.a_turns[index],
|
||||||
|
'question': self.question[index],
|
||||||
|
'answer': self.answer[index],
|
||||||
|
'vft': self.vft[index],
|
||||||
|
'gt_period': self.gt_period[index],
|
||||||
|
'program': self.program[index],
|
||||||
|
'state': self.state[index],
|
||||||
|
'q_type': self.q_type[index],
|
||||||
|
'attribute_dependency': self.attribute_dependency[index],
|
||||||
|
'object_dependency': self.object_dependency[index],
|
||||||
|
'temporal_dependency': self.temporal_dependency[index],
|
||||||
|
'spatial_dependency': self.spatial_dependency[index],
|
||||||
|
'video_name': self.video_name[index],
|
||||||
|
'q_complexity': self.q_complexity[index]
|
||||||
|
}
|
||||||
|
return item_info
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.vid)
|
||||||
|
|
||||||
|
class Batch:
|
||||||
|
def __init__(self, vft, his, query, his_query, turns,
|
||||||
|
q_turns, a_turns,
|
||||||
|
answer, vid_splits, vids, qa_ids,
|
||||||
|
query_lens, his_lens, his_query_lens,
|
||||||
|
dial_lens, turn_lens,
|
||||||
|
program, program_lens, state, state_lens,
|
||||||
|
vocab, q_type, attribute_dependency, object_dependency,
|
||||||
|
temporal_dependency, spatial_dependency, video_name, q_complexity):
|
||||||
|
self.vid_splits = vid_splits
|
||||||
|
self.vids = vids
|
||||||
|
self.qa_ids = qa_ids
|
||||||
|
self.size = len(self.vids)
|
||||||
|
|
||||||
|
self.query = query
|
||||||
|
self.query_lens = query_lens
|
||||||
|
self.his = his
|
||||||
|
self.his_lens = his_lens
|
||||||
|
self.his_query = his_query
|
||||||
|
self.his_query_lens = his_query_lens
|
||||||
|
self.answer = answer
|
||||||
|
self.vft = vft
|
||||||
|
self.turns = turns
|
||||||
|
self.q_turns = q_turns
|
||||||
|
self.a_turns = a_turns
|
||||||
|
self.dial_lens = dial_lens
|
||||||
|
self.turn_lens = turn_lens
|
||||||
|
self.q_type = q_type
|
||||||
|
self.attribute_dependency = attribute_dependency
|
||||||
|
self.object_dependency = object_dependency
|
||||||
|
self.temporal_dependency = temporal_dependency
|
||||||
|
self.spatial_dependency = spatial_dependency
|
||||||
|
self.video_name = video_name
|
||||||
|
self.q_complexity = q_complexity
|
||||||
|
|
||||||
|
pad = vocab['<blank>']
|
||||||
|
self.his_query_mask = (his_query != pad).unsqueeze(-2)
|
||||||
|
self.query_mask = (query != pad)
|
||||||
|
self.his_mask = (his != pad).unsqueeze(-2)
|
||||||
|
self.q_turns_mask = (q_turns != pad)
|
||||||
|
self.turns_mask = (turns != pad)
|
||||||
|
|
||||||
|
self.program = program
|
||||||
|
self.program_lens = program_lens
|
||||||
|
self.state = state
|
||||||
|
self.state_lens = state_lens
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def make_std_mask(tgt, pad):
|
||||||
|
tgt_mask = (tgt != pad).unsqueeze(-2)
|
||||||
|
tgt_mask = tgt_mask & Variable(subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data))
|
||||||
|
return tgt_mask
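# Descriptive note: make_std_mask combines a padding mask (tgt != pad) with the causal
# subsequent_mask helper (presumably pulled in via the data_utils import above), so that
# position i can only attend to non-pad positions <= i.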
|
||||||
|
|
||||||
|
def move_to_cuda(self):
|
||||||
|
self.query = self.query.to('cuda', non_blocking=True)
|
||||||
|
self.his = self.his.to('cuda', non_blocking=True)
|
||||||
|
self.his_query = self.his_query.to('cuda', non_blocking=True)
|
||||||
|
self.query_mask = self.query_mask.to('cuda', non_blocking=True)
|
||||||
|
self.his_mask = self.his_mask.to('cuda', non_blocking=True)
|
||||||
|
self.his_query_mask = self.his_query_mask.to('cuda', non_blocking=True)
|
||||||
|
self.answer = self.answer.to('cuda', non_blocking=True)
|
||||||
|
self.vft = self.vft.to('cuda', non_blocking=True)
|
||||||
|
self.turns = self.turns.to('cuda', non_blocking=True)
|
||||||
|
self.turns_mask = self.turns_mask.to('cuda', non_blocking=True)
|
||||||
|
self.q_turns = self.q_turns.to('cuda', non_blocking=True)
|
||||||
|
self.q_turns_mask = self.q_turns_mask.to('cuda', non_blocking=True)
|
||||||
|
self.a_turns = self.a_turns.to('cuda', non_blocking=True)
|
||||||
|
self.program = self.program.to('cuda', non_blocking=True)
|
||||||
|
self.state = self.state.to('cuda', non_blocking=True)
|
||||||
|
|
||||||
|
def to_cuda(self, tensor):
|
||||||
|
return tensor.cuda()
|
||||||
|
|
||||||
|
def collate_fn(data, vocab):
|
||||||
|
def pad_monet_videos(seqs, pad_token):
|
||||||
|
lengths = [s.shape[0] for s in seqs]
|
||||||
|
max_length = max(lengths)
|
||||||
|
output = []
|
||||||
|
for seq in seqs:
|
||||||
|
result = torch.ones((max_length, seq.shape[1], seq.shape[2])) * pad_token
|
||||||
|
result[:seq.shape[0]] = seq
|
||||||
|
output.append(result)
|
||||||
|
return output
|
||||||
|
|
||||||
|
def pad_seq(seqs, pad_token, return_lens=False, is_vft=False):
|
||||||
|
lengths = [s.shape[0] for s in seqs]
|
||||||
|
max_length = max(lengths)
|
||||||
|
output = []
|
||||||
|
for seq in seqs:
|
||||||
|
if is_vft:
|
||||||
|
if len(seq.shape)==4: # spatio-temporal feature
|
||||||
|
result = np.ones((max_length, seq.shape[1], seq.shape[2], seq.shape[3]), dtype=seq.dtype)*pad_token
|
||||||
|
else:
|
||||||
|
result = np.ones((max_length, seq.shape[-1]), dtype=seq.dtype)*pad_token
|
||||||
|
else:
|
||||||
|
result = np.ones(max_length, dtype=seq.dtype)*pad_token
|
||||||
|
result[:seq.shape[0]] = seq
|
||||||
|
output.append(result)
|
||||||
|
if return_lens:
|
||||||
|
return lengths, output
|
||||||
|
return output
|
||||||
|
|
||||||
|
def pad_2d_seq(seqs, pad_token, return_lens=False, is_vft=False):
|
||||||
|
lens1 = [len(s) for s in seqs]
|
||||||
|
max_len1 = max(lens1)
|
||||||
|
all_seqs = []
|
||||||
|
for seq in seqs:
|
||||||
|
all_seqs.extend(seq)
|
||||||
|
lens2 = [len(s) for s in all_seqs]
|
||||||
|
max_len2 = max(lens2)
|
||||||
|
output = []
|
||||||
|
all_lens = []
|
||||||
|
for seq in seqs:
|
||||||
|
if is_vft:
|
||||||
|
result = np.ones((max_len1, max_len2, seq[0].shape[-1]))*pad_token
|
||||||
|
else:
|
||||||
|
result = np.ones((max_len1, max_len2))*pad_token
|
||||||
|
turn_lens = np.ones(max_len1).astype(int)
|
||||||
|
offset = max_len1 - len(seq)
|
||||||
|
for turn_idx, turn in enumerate(seq):
|
||||||
|
#result[turn_idx,:turn.shape[0]] = turn
|
||||||
|
# padding should be at the first turn idxs (Reason: result of last n turns is used for state creation)
|
||||||
|
result[turn_idx + offset,:turn.shape[0]] = turn
|
||||||
|
turn_lens[turn_idx] = turn.shape[0]
|
||||||
|
output.append(result)
|
||||||
|
all_lens.append(turn_lens)
|
||||||
|
all_lens = np.asarray(all_lens)
|
||||||
|
if return_lens:
|
||||||
|
return lens1, all_lens, output
|
||||||
|
return output
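# Illustrative example of the left padding above (plain lists shown for readability; the
# real inputs are numpy arrays): with two dialogs of 1 and 3 turns and pad_token = 0,
# max_len1 = 3 and max_len2 = 2, so the shorter dialog gets two leading all-pad turns:
#   seqs = [[[5, 6]], [[1], [2, 3], [4]]]
#   pad_2d_seq(seqs, 0) ->
#     [[[0, 0], [0, 0], [5, 6]],
#      [[1, 0], [2, 3], [4, 0]]]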
|
||||||
|
|
||||||
|
def prepare_data(seqs, is_float=False):
|
||||||
|
if is_float:
|
||||||
|
return torch.from_numpy(np.asarray(seqs)).float()
|
||||||
|
return torch.from_numpy(np.asarray(seqs)).long()
|
||||||
|
|
||||||
|
item_info = {}
|
||||||
|
for key in data[0].keys():
|
||||||
|
item_info[key] = [d[key] for d in data]
|
||||||
|
pad_token = vocab['<blank>']
|
||||||
|
h_lens, h_padded = pad_seq(item_info['history'], pad_token, return_lens=True)
|
||||||
|
h_batch = prepare_data(h_padded)
|
||||||
|
q_lens, q_padded = pad_seq(item_info['question'], pad_token, return_lens=True)
|
||||||
|
q_batch = prepare_data(q_padded)
|
||||||
|
|
||||||
|
hq = [np.concatenate([q,h]) for q,h in zip(item_info['history'], item_info['question'])]
|
||||||
|
hq_lens, hq_padded = pad_seq(hq, pad_token, return_lens=True)
|
||||||
|
hq_batch = prepare_data(hq_padded)
|
||||||
|
|
||||||
|
dial_lens, turn_lens, turns_padded = pad_2d_seq(item_info['turns'], pad_token, return_lens=True)
|
||||||
|
_, _, q_turns_padded = pad_2d_seq(item_info['q_turns'], pad_token, return_lens=True)
|
||||||
|
turns_batch = prepare_data(turns_padded)
|
||||||
|
q_turns_batch = prepare_data(q_turns_padded)
|
||||||
|
|
||||||
|
a_turns_padded = pad_2d_seq(item_info['a_turns'], pad_token)
|
||||||
|
a_turns_batch = prepare_data(a_turns_padded)
|
||||||
|
|
||||||
|
a_batch = prepare_data(item_info['answer'])
|
||||||
|
|
||||||
|
#vft_lens, vft_padded = pad_seq(item_info['vft'], 0, return_lens=True, is_vft=True)
|
||||||
|
#vft_batch = prepare_data(vft_padded, is_float=True)
|
||||||
|
vft_batch = item_info['vft']
|
||||||
|
vft_batch_padded = pad_monet_videos(vft_batch, 0)
|
||||||
|
vft_batch_padded = torch.stack(vft_batch_padded)
|
||||||
|
|
||||||
|
p_lens, p_padded = pad_seq(item_info['program'], pad_token, return_lens=True)
|
||||||
|
p_batch = prepare_data(p_padded)
|
||||||
|
|
||||||
|
s_lens, s_padded = pad_seq(item_info['state'], pad_token, return_lens=True)
|
||||||
|
s_batch = prepare_data(s_padded)
|
||||||
|
|
||||||
|
batch = Batch(vft_batch_padded,
|
||||||
|
h_batch, q_batch, hq_batch, turns_batch, q_turns_batch, a_turns_batch, a_batch,
|
||||||
|
item_info['vid_split'], item_info['vid'], item_info['qa_id'],
|
||||||
|
q_lens, h_lens, hq_lens,
|
||||||
|
dial_lens, turn_lens,
|
||||||
|
p_batch, p_lens, s_batch, s_lens,
|
||||||
|
vocab, item_info['q_type'], item_info['attribute_dependency'], item_info['object_dependency'],
|
||||||
|
item_info['temporal_dependency'], item_info['spatial_dependency'], item_info['video_name'],
|
||||||
|
item_info['q_complexity'])
|
||||||
|
return batch
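# Minimal usage sketch (hypothetical names; `dataset` and `vocab` come from the surrounding
# data handling code): collate_fn needs the vocabulary, so it is typically bound with
# functools.partial before being passed to a PyTorch DataLoader.
#
#   from functools import partial
#   loader = Data.DataLoader(dataset, batch_size=32, shuffle=True,
#                            collate_fn=partial(collate_fn, vocab=vocab))
#   for batch in loader:
#       batch.move_to_cuda()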

src/utils/dvd_codebase/exps_test/baseline/dvd.conf (new binary file, not shown)

src/utils/dvd_codebase/exps_test/baseline/dvd_params.txt (new file, 9 lines)
debug=1
fea_dir=/workspace/hungle/data/dvd/video-classification-3d-cnn-pytorch/outputs/resnext_101/
data_dir=/workspace/hungle/cater-dialog/question_generation/output/
output_dir=exps_test//baseline/dvd
num_workers=0
device=0
num_epochs=3
batch_size=32
verbose=0
src/utils/dvd_codebase/main.py (new executable file, 86 lines)
"""
|
||||||
|
Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
All rights reserved.
|
||||||
|
This source code is licensed under the license found in the
|
||||||
|
LICENSE file in the root directory of this source tree.
|
||||||
|
"""
|
||||||
|
|
||||||
|
#!/usr/bin/env python
|
||||||
|
import math
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import numpy as np
|
||||||
|
import pickle as pkl
|
||||||
|
import threading
|
||||||
|
import pdb
|
||||||
|
from tqdm import tqdm
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from project.dvd_codebase.configs.configs import *
|
||||||
|
import project.dvd_codebase.data.data_handler as dh
|
||||||
|
|
||||||
|
def run_epoch(loader, epoch):
|
||||||
|
it = tqdm(enumerate(loader),total=len(loader), desc="epoch {}/{}".format(epoch+1, args.num_epochs), ncols=0)
|
||||||
|
for j, batch in it:
|
||||||
|
batch.move_to_cuda()
|
||||||
|
pdb.set_trace()
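# NOTE: as written, run_epoch only iterates over the loader and drops into pdb for
# inspection; it does not compute or return losses, even though its result is assigned
# to train_losses / valid_losses below.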
|
||||||
|
|
||||||
|
# load dialogues
|
||||||
|
logging.info('Loading dialogues from {}'.format(args.data_dir))
|
||||||
|
train_dials, train_vids = dh.load_dials(args, 'train')
|
||||||
|
logging.info('#train dials = {} # train videos = {}'.format(len(train_dials), len(train_vids)))
|
||||||
|
val_dials, val_vids = dh.load_dials(args, 'val')
|
||||||
|
logging.info('#val dials = {} # val videos = {}'.format(len(val_dials), len(val_vids)))
|
||||||
|
|
||||||
|
# load video features
|
||||||
|
logging.info('Loading video features from {}'.format(args.fea_dir))
|
||||||
|
train_vft, vft_dims, clip_size, clip_stride, segment_map = dh.load_videos(args, train_vids)
|
||||||
|
val_vft, _, _, _, _ = dh.load_videos(args, val_vids)
|
||||||
|
logging.info('#video ft dims = {} clip size {} clip stride {}'.format(vft_dims, clip_size, clip_stride))
|
||||||
|
|
||||||
|
# get vocabulary
|
||||||
|
logging.info('Extracting vocabulary')
|
||||||
|
vocab, answer_list = dh.get_vocabulary(train_dials, args)
|
||||||
|
logging.info('#vocab = {} #answer candidates = {}'.
|
||||||
|
format(len(vocab), len(answer_list)))
|
||||||
|
logging.info('All answer candidates: {}'.format(answer_list))
|
||||||
|
unk_words = dh.get_vocabulary(val_dials, args, vocab=vocab)
|
||||||
|
logging.info('{} unknown words in val split: {}'.format(len(unk_words), unk_words))
|
||||||
|
|
||||||
|
# question-answer distribution
|
||||||
|
qa_dist = dh.answer_by_question_type(train_dials)
|
||||||
|
|
||||||
|
# save meta parameters
|
||||||
|
path = args.output_dir + '.conf'
|
||||||
|
with open(path, 'wb') as f:
|
||||||
|
pkl.dump((vocab, answer_list, qa_dist, args), f, -1)
|
||||||
|
path2 = args.output_dir + '_params.txt'
|
||||||
|
with open(path2, "w") as f:
|
||||||
|
for arg in vars(args):
|
||||||
|
f.write("{}={}\n".format(arg, getattr(args, arg)))
|
||||||
|
|
||||||
|
# load data
|
||||||
|
logging.info('Creating training instances')
|
||||||
|
train_dials = dh.create_dials(train_dials, vocab, answer_list, segment_map, train_vft, args)
|
||||||
|
logging.info('Creating validation instances')
|
||||||
|
valid_dials = dh.create_dials(val_dials, vocab, answer_list, segment_map, val_vft, args)
|
||||||
|
|
||||||
|
# make dataloaders
|
||||||
|
train_dataloader, train_samples = dh.create_dataset(train_dials, vocab, 'train', args)
|
||||||
|
logging.info('#train sample = {} # train batch = {}'.format(train_samples, len(train_dataloader)))
|
||||||
|
valid_dataloader, valid_samples = dh.create_dataset(valid_dials, vocab, 'val', args)
|
||||||
|
logging.info('#val sample = {} # val batch = {}'.format(valid_samples, len(valid_dataloader)))
|
||||||
|
|
||||||
|
epoch_counts = 0
|
||||||
|
for epoch in range(args.num_epochs):
|
||||||
|
# train on training data
|
||||||
|
logging.info('-------training--------')
|
||||||
|
train_losses = run_epoch(train_dataloader, epoch)
|
||||||
|
|
||||||
|
# test on validation data
|
||||||
|
logging.info('-------validation--------')
|
||||||
|
valid_losses = run_epoch(valid_dataloader, epoch)
|
||||||
|
|
src/utils/dvd_codebase/run.sh (new executable file, 43 lines)
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# input choices
device=$1
debug=$2  # 1: test run with small datasets, 0: run with the full datasets

num_epochs=50
batch_size=32
nb_workers=16

# data setting
data_dir=/workspace/hungle/cater-dialog/question_generation/output/
fea_dir=/workspace/hungle/data/dvd/video-classification-3d-cnn-pytorch/outputs/resnext_101/

# output folder name
expid=baseline

if [ $debug = 1 ]; then
    expdir=exps_test/$task/${expid}
    num_epochs=3
    nb_workers=0
    report_interval=10
else
    expdir=exps/$task/${expid}
fi
echo stage: $stage debug? $debug task: $task exp_dir: $expdir

# training phase
mkdir -p $expdir
CUDA_VISIBLE_DEVICES=$device python main.py \
    --debug $debug \
    --fea-dir $fea_dir \
    --data-dir $data_dir \
    --output-dir $expdir/dvd \
    --num-epochs $num_epochs \
    --batch-size $batch_size \
    --num-workers $nb_workers \

src/utils/positional_encoding.py (new file, 27 lines)
# https://github.com/pytorch/pytorch/issues/68407
from torch import nn
from torch import Tensor
import torch
import math


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_parameter('pe', nn.Parameter(pe, requires_grad=False))

    def forward(self, x):
        # positional encoding expects shape (seq_len, batch_size, emb_dim), (batch_size, seq_len, emb_dim) is given
        x = x.permute(1, 0, 2)
        x = x + self.pe[:x.size(0), :]
        x = x.permute(1, 0, 2)
        return self.dropout(x)
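
# Minimal usage sketch (hypothetical sizes): the module adds sinusoidal position information
# to a (batch_size, seq_len, d_model) tensor and applies dropout; the output keeps the shape.
#
#   pos_enc = PositionalEncoding(d_model=128, dropout=0.1)
#   x = torch.zeros(4, 20, 128)   # (batch_size, seq_len, emb_dim)
#   y = pos_enc(x)                # same shape, with positions encoded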
src/utils/save_attention_weights.py (new file, 8 lines)
# https://gist.github.com/airalcorn2/50ec06517ce96ecc143503e21fa6cb91

class SaveOutput:
    def __init__(self):
        self.outputs = None

    def __call__(self, module, module_in, module_out):
        self.outputs = module_out
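
# Minimal usage sketch (hypothetical layer name): SaveOutput is a forward hook, so it can be
# registered on any submodule, e.g. an attention layer, to capture its most recent output.
#
#   save_output = SaveOutput()
#   handle = model.some_attention_layer.register_forward_hook(save_output)
#   _ = model(inputs)
#   captured = save_output.outputs   # module output from the last forward pass
#   handle.remove()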
src/utils/simmc2_dataset/dataloader_dvd_model.py (new file, 233 lines)
#! /usr/bin/env python
|
||||||
|
"""
|
||||||
|
Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
All rights reserved.
|
||||||
|
This source code is licensed under the license found in the LICENSE file in the
|
||||||
|
root directory of this source tree.
|
||||||
|
|
||||||
|
Dataloader for ambiguous candidates identification task on SIMMC 2.1.
|
||||||
|
|
||||||
|
Author(s): Satwik Kottur
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
|
|
||||||
|
def pad_seq(seqs, pad_token, return_lens=False, is_vft=False):
    lengths = [s.shape[1] for s in seqs]
    max_length = max(lengths)
    output = []
    for seq in seqs:
        if is_vft:
            if len(seq.shape)==4: # spatio-temporal feature
                result = torch.ones((1, max_length, seq.shape[1], seq.shape[2], seq.shape[3]), dtype=seq.dtype)*pad_token
            else:
                result = torch.ones((1, max_length, seq.shape[-1]), dtype=seq.dtype)*pad_token
        else:
            result = torch.ones((1, max_length), dtype=seq.dtype)*pad_token
        result[0, :seq.shape[1]] = seq
        output.append(result)
    if return_lens:
        return lengths, output
    return output
|
||||||
|
|
||||||
|
|
||||||
|
def pad_2d_seq(seqs, pad_token, return_lens=False, is_vft=False):
|
||||||
|
lens1 = [len(s) for s in seqs]
|
||||||
|
max_len1 = max(lens1)
|
||||||
|
all_seqs = []
|
||||||
|
for seq in seqs:
|
||||||
|
all_seqs.extend(seq)
|
||||||
|
lens2 = [s.shape[1] for s in all_seqs]
|
||||||
|
max_len2 = max(lens2)
|
||||||
|
output = []
|
||||||
|
all_lens = []
|
||||||
|
for seq in seqs:
|
||||||
|
if is_vft:
|
||||||
|
result = torch.ones((max_len1, max_len2, seq[0].shape[-1]))*pad_token
|
||||||
|
else:
|
||||||
|
result = torch.ones((1, max_len1, max_len2))*pad_token
|
||||||
|
#turn_lens = torch.ones(max_len1, dtype=np.int)
|
||||||
|
offset = max_len1 - len(seq)
|
||||||
|
for turn_idx, turn in enumerate(seq):
|
||||||
|
#result[turn_idx,:turn.shape[0]] = turn
|
||||||
|
# padding should be at the first turn idxs (Reason: result of last n turns is used for state creation)
|
||||||
|
result[0, turn_idx + offset,:turn.shape[1]] = turn
|
||||||
|
#turn_lens[turn_idx] = turn.shape[0]
|
||||||
|
output.append(result)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class Simmc2Dataset(Dataset):
|
||||||
|
def __init__(self, tokenizer, feature_loader, load_path, args, hidden_labels=False):
|
||||||
|
self._tokenizer = tokenizer
|
||||||
|
self._features = feature_loader
|
||||||
|
self._args = args
|
||||||
|
self._hidden_labels = hidden_labels
|
||||||
|
print("Loading: {}".format(load_path))
|
||||||
|
with open(load_path, "r") as file_id:
|
||||||
|
self._raw_data = json.load(file_id)
|
||||||
|
# Also read the source data for evaluation.
|
||||||
|
with open(self._raw_data["source_path"], "r") as file_id:
|
||||||
|
self.source_data = json.load(file_id)
|
||||||
|
self._data = self._raw_data["data"]
|
||||||
|
|
||||||
|
self.num_utterances = 2 * args.max_turns + 1
|
||||||
|
self.num_instances = len(self._data)
|
||||||
|
self.device = torch.cuda if args.use_gpu else torch
|
||||||
|
|
||||||
|
def get_random_batch(self, batch_size):
|
||||||
|
indices = np.random.randint(0, self.num_instances, batch_size)
|
||||||
|
return self.get_indexed_data(indices)
|
||||||
|
|
||||||
|
def get_entire_batch(self, batch_size):
|
||||||
|
all_indices = np.arange(self.num_instances)
|
||||||
|
for start in all_indices[::batch_size]:
|
||||||
|
batch_indices = all_indices[start : start + batch_size]
|
||||||
|
yield self.get_indexed_data(batch_indices)
|
||||||
|
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self._data)
|
||||||
|
|
||||||
|
|
||||||
|
def collate_fn(self, batch):
|
||||||
|
merged_batch = {key: [d[key] for d in batch] for key in batch[0]}
|
||||||
|
out = {}
|
||||||
|
for key in merged_batch:
|
||||||
|
if key in ['query', 'answer']:
|
||||||
|
seq = pad_seq(merged_batch[key], pad_token=1)
|
||||||
|
out[key] = torch.concat(seq, dim=0)
|
||||||
|
elif key in ['q_turns', 'a_turns', 'turns', 'object_features', 'answer_candidates']:
|
||||||
|
if merged_batch[key][0] is not None:
|
||||||
|
seq = pad_2d_seq(merged_batch[key], pad_token=1)
|
||||||
|
out[key] = torch.concat(seq, dim=0).type(torch.int)
|
||||||
|
else:
|
||||||
|
out[key] = None
|
||||||
|
|
||||||
|
elif key in ['features']:
|
||||||
|
#features = [f.unsqueeze(1) for f in merged_batch[key]]
|
||||||
|
# pad video features
|
||||||
|
features = pad_sequence(merged_batch[key], batch_first=True)
|
||||||
|
out[key] = features
|
||||||
|
else:
|
||||||
|
out[key] = merged_batch[key]
|
||||||
|
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def encode_turns(self, turns):
|
||||||
|
encoded_turns = []
|
||||||
|
for turn in turns:
|
||||||
|
encoded_turn = self._tokenizer(
|
||||||
|
turn,
|
||||||
|
padding=True,
|
||||||
|
max_length=self._args.max_length,
|
||||||
|
return_tensors="pt",
|
||||||
|
truncation=True,
|
||||||
|
)
|
||||||
|
encoded_turns.append(encoded_turn['input_ids'].type(torch.int))
|
||||||
|
return encoded_turns
|
||||||
|
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
text_labels = []
|
||||||
|
text_inputs = []
|
||||||
|
dialog_ids = []
|
||||||
|
turn_ids = []
|
||||||
|
features = []
|
||||||
|
object_maps = []
|
||||||
|
# Add <USER> and <SYS> tokens.
|
||||||
|
dialog_datum = self._data[index]
|
||||||
|
#dialog = self._data[index]["input_text"]
|
||||||
|
query = self._data[index]["query"]
|
||||||
|
answer = self._data[index]["answer"]
|
||||||
|
turns = self._data[index]["turns"]
|
||||||
|
q_turns = self._data[index]["q_turns"]
|
||||||
|
a_turns = self._data[index]["a_turns"]
|
||||||
|
object_features = self._data[index]["object_metadata"]
|
||||||
|
if "answer_candidates" in self._data[index].keys():
|
||||||
|
answer_candidates = self._data[index]["answer_candidates"]
|
||||||
|
else:
|
||||||
|
answer_candidates = None
|
||||||
|
|
||||||
|
if self._features:
|
||||||
|
feature = self._features[dialog_datum["image_name"]]
|
||||||
|
|
||||||
|
encoded_query = self._tokenizer(
|
||||||
|
query,
|
||||||
|
padding=True,
|
||||||
|
max_length=self._args.max_length,
|
||||||
|
return_tensors="pt",
|
||||||
|
truncation=True,
|
||||||
|
)['input_ids'].type(torch.int)
|
||||||
|
encoded_answer = self._tokenizer(
|
||||||
|
answer,
|
||||||
|
padding=True,
|
||||||
|
max_length=self._args.max_length,
|
||||||
|
return_tensors="pt",
|
||||||
|
truncation=True,
|
||||||
|
)['input_ids'].type(torch.int)
|
||||||
|
encoded_q_turns = self.encode_turns(q_turns)
|
||||||
|
encoded_a_turns = self.encode_turns(a_turns)
|
||||||
|
encoded_turns = self.encode_turns(turns)
|
||||||
|
encoded_object_features = self.encode_turns(object_features)
|
||||||
|
if "answer_candidates" in self._data[index].keys():
|
||||||
|
encoded_answer_candidates = self.encode_turns(answer_candidates)
|
||||||
|
else:
|
||||||
|
encoded_answer_candidates = None
|
||||||
|
|
||||||
|
|
||||||
|
# Pack the sample.
|
||||||
|
sample = {
|
||||||
|
"query": encoded_query,
|
||||||
|
"answer": encoded_answer,
|
||||||
|
"answer_candidates": encoded_answer_candidates,
|
||||||
|
"turns": encoded_turns,
|
||||||
|
"q_turns": encoded_q_turns,
|
||||||
|
"a_turns": encoded_a_turns,
|
||||||
|
"object_features": encoded_object_features,
|
||||||
|
"dialog_id": dialog_datum["dialog_id"],
|
||||||
|
"turn_id": dialog_datum["turn_id"],
|
||||||
|
"features": feature,
|
||||||
|
}
|
||||||
|
return sample
|
||||||
|
|
||||||
|
|
||||||
|
class VisualFeatureLoader:
|
||||||
|
"""Loads visual features for SIMMC 2.1 ambiguous candidate identification."""
|
||||||
|
|
||||||
|
UNAVAILABLE_IMAGES = [
|
||||||
|
"cloth_store_1416238_woman_20_6.png",
|
||||||
|
"cloth_store_1416238_woman_19_0.png",
|
||||||
|
"cloth_store_1416238_woman_4_8.png",
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self, feature_path, feature_size):
|
||||||
|
"""Read the features from the path."""
|
||||||
|
self._features = torch.load(feature_path)
|
||||||
|
self._feature_size = feature_size
|
||||||
|
self._zero_feature = torch.zeros((1, self._feature_size), dtype=torch.float)
|
||||||
|
|
||||||
|
def __getitem__(self, label):
|
||||||
|
"""Get the feature given image label."""
|
||||||
|
assert (
|
||||||
|
label in self._features or label in self.UNAVAILABLE_IMAGES
|
||||||
|
), f"{label} not found!"
|
||||||
|
if label in self.UNAVAILABLE_IMAGES:
|
||||||
|
return self._zero_feature
|
||||||
|
return self._features[label]
|
||||||
|
|
||||||
|
def cuda(self):
|
||||||
|
"""Move the features to cuda."""
|
||||||
|
self._zero_feature = self._zero_feature.cuda()
|
||||||
|
for key, val in self._features.items():
|
||||||
|
self._features[key] = val.cuda()
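
# Minimal usage sketch (hypothetical path and feature size): the loader maps an image file
# name to its pre-extracted feature tensor and returns a zero vector for the known
# unavailable images listed above.
#
#   feature_loader = VisualFeatureLoader("visual_features.pt", feature_size=512)
#   feat = feature_loader["cloth_store_1416238_woman_4_8.png"]   # zero feature (unavailable)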
|
src/utils/simmc2_dataset/dataloader_finetune_mlm.py (new empty file)

src/utils/simmc2_dataset/dataloader_mlm_nsp.py (new file, 277 lines)
|
||||||
|
#! /usr/bin/env python
|
||||||
|
"""
|
||||||
|
Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
All rights reserved.
|
||||||
|
This source code is licensed under the license found in the LICENSE file in the
|
||||||
|
root directory of this source tree.
|
||||||
|
|
||||||
|
Dataloader for ambiguous candidates identification task on SIMMC 2.1.
|
||||||
|
|
||||||
|
Author(s): Satwik Kottur
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
from random import shuffle
|
||||||
|
from random import random as rand
|
||||||
|
#from src.utils.vd_bert.loader_utils import get_random_word
|
||||||
|
|
||||||
|
|
||||||
|
def pad_seq(seqs, pad_token, return_lens=False, is_vft=False):
    lengths = [s.shape[1] for s in seqs]
    max_length = max(lengths)
    output = []
    for seq in seqs:
        if is_vft:
            if len(seq.shape)==4: # spatio-temporal feature
                result = torch.ones((1, max_length, seq.shape[1], seq.shape[2], seq.shape[3]), dtype=seq.dtype)*pad_token
            else:
                result = torch.ones((1, max_length, seq.shape[-1]), dtype=seq.dtype)*pad_token
        else:
            result = torch.ones((1, max_length), dtype=seq.dtype)*pad_token
        result[0, :seq.shape[1]] = seq
        output.append(result)
    if return_lens:
        return lengths, output
    return output
|
||||||
|
|
||||||
|
|
||||||
|
def pad_2d_seq(seqs, pad_token, return_lens=False, is_vft=False):
|
||||||
|
lens1 = [len(s) for s in seqs]
|
||||||
|
max_len1 = max(lens1)
|
||||||
|
all_seqs = []
|
||||||
|
for seq in seqs:
|
||||||
|
all_seqs.extend(seq)
|
||||||
|
lens2 = [s.shape[1] for s in all_seqs]
|
||||||
|
max_len2 = max(lens2)
|
||||||
|
output = []
|
||||||
|
all_lens = []
|
||||||
|
for seq in seqs:
|
||||||
|
if is_vft:
|
||||||
|
result = torch.ones((max_len1, max_len2, seq[0].shape[-1]))*pad_token
|
||||||
|
else:
|
||||||
|
result = torch.ones((1, max_len1, max_len2))*pad_token
|
||||||
|
#turn_lens = torch.ones(max_len1, dtype=np.int)
|
||||||
|
offset = max_len1 - len(seq)
|
||||||
|
for turn_idx, turn in enumerate(seq):
|
||||||
|
#result[turn_idx,:turn.shape[0]] = turn
|
||||||
|
# padding should be at the first turn idxs (Reason: result of last n turns is used for state creation)
|
||||||
|
result[0, turn_idx + offset,:turn.shape[1]] = turn
|
||||||
|
#turn_lens[turn_idx] = turn.shape[0]
|
||||||
|
output.append(result)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class Simmc2DatasetMlmNsp(Dataset):
|
||||||
|
def __init__(self, tokenizer, feature_loader, load_path, args, hidden_labels=False):
|
||||||
|
self._tokenizer = tokenizer
|
||||||
|
self._features = feature_loader
|
||||||
|
self._args = args
|
||||||
|
self._hidden_labels = hidden_labels
|
||||||
|
print("Loading: {}".format(load_path))
|
||||||
|
with open(load_path, "r") as file_id:
|
||||||
|
self._raw_data = json.load(file_id)
|
||||||
|
# Also read the source data for evaluation.
|
||||||
|
with open(self._raw_data["source_path"], "r") as file_id:
|
||||||
|
self.source_data = json.load(file_id)
|
||||||
|
self._data = self._raw_data["data"]
|
||||||
|
|
||||||
|
self.num_utterances = 2 * args.max_turns + 1
|
||||||
|
self.num_instances = len(self._data)
|
||||||
|
self.device = torch.cuda if args.use_gpu else torch
|
||||||
|
|
||||||
|
|
||||||
|
def conduct_mask(self, tokens, effective_length, start_id, end_id):
|
||||||
|
# taken from https://github.com/salesforce/VD-BERT
|
||||||
|
# For masked Language Models
|
||||||
|
cand_pos = []
|
||||||
|
special_pos = set()
|
||||||
|
|
||||||
|
n_pred = min(self._args.max_n_masked, max(
|
||||||
|
1, int(round(effective_length * self._args.p_mask))))
|
||||||
|
|
||||||
|
# candidate positions of masked tokens
|
||||||
|
for i, tk in enumerate(tokens):
|
||||||
|
# only mask tokens_b (target sequence)
|
||||||
|
# we will mask [SEP] as an ending symbol
|
||||||
|
if (i >= start_id) and (tk != '[CLS]') and (tk != '[PAD]') and (i < end_id):
|
||||||
|
cand_pos.append(i)
|
||||||
|
else:
|
||||||
|
special_pos.add(i)
|
||||||
|
|
||||||
|
shuffle(cand_pos)
|
||||||
|
masked_pos = cand_pos[:n_pred]
|
||||||
|
|
||||||
|
masked_tokens = [tokens[pos] for pos in masked_pos]
|
||||||
|
for pos in masked_pos:
|
||||||
|
if self._args.finetune:
|
||||||
|
tokens[pos] = '[MASK]'
|
||||||
|
continue
|
||||||
|
if rand() < 0.8: # 80%
|
||||||
|
tokens[pos] = '[MASK]'
|
||||||
|
#elif rand() < 0.5: # 10%
|
||||||
|
# tokens[pos] = get_random_word(self.vocab_words)
|
||||||
|
# when n_pred < max_pred, we only calculate loss within n_pred
|
||||||
|
masked_weights = [1] * len(masked_tokens)
|
||||||
|
|
||||||
|
# Token Indexing
|
||||||
|
input_ids = self._tokenizer.convert_tokens_to_ids(tokens)
|
||||||
|
masked_ids = self._tokenizer.convert_tokens_to_ids(masked_tokens)
|
||||||
|
|
||||||
|
if self._args.max_n_masked > n_pred:
|
||||||
|
n_pad = self._args.max_n_masked - n_pred
|
||||||
|
masked_ids.extend([0] * n_pad)
|
||||||
|
masked_pos.extend([0] * n_pad)
|
||||||
|
masked_weights.extend([0] * n_pad)
|
||||||
|
|
||||||
|
assert len(masked_ids) == len(masked_pos) == len(masked_weights) == self._args.max_n_masked, \
|
||||||
|
"[masked] id: %d, pos: %d, weights: %d" % (len(masked_ids), len(masked_pos), len(masked_weights))
|
||||||
|
|
||||||
|
return input_ids, masked_ids, masked_pos, masked_weights
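# Worked example (illustrative numbers): with p_mask = 0.15, max_n_masked = 5 and an
# effective_length of 20, n_pred = min(5, max(1, round(20 * 0.15))) = 3, so three candidate
# positions between start_id and end_id are masked ([MASK] always when finetuning, with
# probability 0.8 otherwise), and masked_ids / masked_pos / masked_weights are zero-padded
# up to max_n_masked = 5.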
|
||||||
|
|
||||||
|
|
||||||
|
def get_random_batch(self, batch_size):
|
||||||
|
indices = np.random.randint(0, self.num_instances, batch_size)
|
||||||
|
return self.get_indexed_data(indices)
|
||||||
|
|
||||||
|
def get_entire_batch(self, batch_size):
|
||||||
|
all_indices = np.arange(self.num_instances)
|
||||||
|
for start in all_indices[::batch_size]:
|
||||||
|
batch_indices = all_indices[start : start + batch_size]
|
||||||
|
yield self.get_indexed_data(batch_indices)
|
||||||
|
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self._data)
|
||||||
|
|
||||||
|
|
||||||
|
def collate_fn(self, batch):
|
||||||
|
merged_batch = {key: [d[key] for d in batch] for key in batch[0]}
|
||||||
|
out = {}
|
||||||
|
for key in merged_batch:
|
||||||
|
if key in ['qa_pair', 'q_len', 'q_turns_len', 'masked_pos', 'mask_labels', 'next_sentence_label', 'masked_weights']:
|
||||||
|
seq = pad_seq(merged_batch[key], pad_token=1)
|
||||||
|
out[key] = torch.concat(seq, dim=0)
|
||||||
|
elif key in ['qa_turns']:
|
||||||
|
if merged_batch[key][0] is not None:
|
||||||
|
seq = pad_2d_seq(merged_batch[key], pad_token=1)
|
||||||
|
out[key] = torch.concat(seq, dim=0).type(torch.int)
|
||||||
|
else:
|
||||||
|
out[key] = None
|
||||||
|
|
||||||
|
elif key in ['features']:
|
||||||
|
#features = [f.unsqueeze(1) for f in merged_batch[key]]
|
||||||
|
# pad video features
|
||||||
|
features = pad_sequence(merged_batch[key], batch_first=True)
|
||||||
|
out[key] = features
|
||||||
|
else:
|
||||||
|
out[key] = merged_batch[key]
|
||||||
|
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def encode_turns(self, turns):
|
||||||
|
encoded_turns = []
|
||||||
|
for turn in turns:
|
||||||
|
encoded_turn = self._tokenizer(
|
||||||
|
turn,
|
||||||
|
padding=True,
|
||||||
|
max_length=self._args.max_length,
|
||||||
|
return_tensors="pt",
|
||||||
|
truncation=True,
|
||||||
|
)
|
||||||
|
# without cls and sep token
|
||||||
|
encoded_turns.append(encoded_turn['input_ids'][:, 1:-1].type(torch.int))
|
||||||
|
return encoded_turns
|
||||||
|
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
dialog_datum = self._data[index]
|
||||||
|
qa_pair = self._data[index]["qa_pair"]
|
||||||
|
qa_turns = self._data[index]["qa_turns"]
|
||||||
|
q_turns = self._data[index]["q_turns"]
|
||||||
|
next_sentence_label = self._data[index]["next_sentence_label"]
|
||||||
|
|
||||||
|
if self._features:
|
||||||
|
feature = self._features[dialog_datum["image_name"]]
|
||||||
|
|
||||||
|
# mask the qa_pair
|
||||||
|
qa_pair_as_tokens = self._tokenizer.tokenize(qa_pair[0])
|
||||||
|
if next_sentence_label[0] == 0:
|
||||||
|
end_id = qa_pair_as_tokens.index('[SEP_1]')
|
||||||
|
effective_length = end_id + 1
|
||||||
|
start_id = 0
|
||||||
|
else:
|
||||||
|
end_id = len(qa_pair_as_tokens) - 1
|
||||||
|
effective_length = len(qa_pair_as_tokens)
|
||||||
|
start_id = 0
|
||||||
|
|
||||||
|
if self._args.only_mask_ans:
|
||||||
|
effective_length = len(qa_pair_as_tokens) - qa_pair_as_tokens.index('[SEP_1]')
|
||||||
|
start_id = qa_pair_as_tokens.index('[SEP_1]')
|
||||||
|
|
||||||
|
# get length of current and prv questions
|
||||||
|
q_len = [qa_pair_as_tokens.index('[SEP_1]')]
|
||||||
|
q_turns_len = [len(self._tokenizer.tokenize(q[0])) for q in q_turns]
|
||||||
|
|
||||||
|
qa_pair_ids, masked_ids, masked_pos, masked_weights = self.conduct_mask(
|
||||||
|
tokens=qa_pair_as_tokens,
|
||||||
|
effective_length=effective_length,
|
||||||
|
start_id = start_id,
|
||||||
|
end_id=end_id
|
||||||
|
)
|
||||||
|
|
||||||
|
qa_turns_ids = self.encode_turns(qa_turns)
|
||||||
|
|
||||||
|
|
||||||
|
# Pack the sample.
|
||||||
|
sample = {
|
||||||
|
"qa_pair": torch.tensor(qa_pair_ids).unsqueeze(0),
|
||||||
|
"qa_turns": qa_turns_ids,
|
||||||
|
"features": feature,
|
||||||
|
"q_len": torch.tensor(q_len).unsqueeze(0),
|
||||||
|
"q_turns_len": torch.tensor(q_turns_len).unsqueeze(0),
|
||||||
|
"masked_pos": torch.tensor(masked_pos).unsqueeze(0),
|
||||||
|
"mask_labels": torch.tensor(masked_ids).unsqueeze(0),
|
||||||
|
"masked_weights": torch.tensor(masked_weights).unsqueeze(0),
|
||||||
|
"next_sentence_label": torch.tensor(next_sentence_label).unsqueeze(0)
|
||||||
|
}
|
||||||
|
return sample
|
||||||
|
|
||||||
|
|
||||||
|
class VisualFeatureLoader:
|
||||||
|
"""Loads visual features for SIMMC 2.1 ambiguous candidate identification."""
|
||||||
|
|
||||||
|
UNAVAILABLE_IMAGES = [
|
||||||
|
"cloth_store_1416238_woman_20_6.png",
|
||||||
|
"cloth_store_1416238_woman_19_0.png",
|
||||||
|
"cloth_store_1416238_woman_4_8.png",
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self, feature_path, feature_size):
|
||||||
|
"""Read the features from the path."""
|
||||||
|
self._features = torch.load(feature_path)
|
||||||
|
self._feature_size = feature_size
|
||||||
|
self._zero_feature = torch.zeros((1, self._feature_size), dtype=torch.float)
|
||||||
|
|
||||||
|
def __getitem__(self, label):
|
||||||
|
"""Get the feature given image label."""
|
||||||
|
assert (
|
||||||
|
label in self._features or label in self.UNAVAILABLE_IMAGES
|
||||||
|
), f"{label} not found!"
|
||||||
|
if label in self.UNAVAILABLE_IMAGES:
|
||||||
|
return self._zero_feature
|
||||||
|
return self._features[label]
|
||||||
|
|
||||||
|
def cuda(self):
|
||||||
|
"""Move the features to cuda."""
|
||||||
|
self._zero_feature = self._zero_feature.cuda()
|
||||||
|
for key, val in self._features.items():
|
||||||
|
self._features[key] = val.cuda()
|
src/utils/simmc2_dataset/dataloader_test_gen.py (new file, 253 lines)
|
||||||
|
#! /usr/bin/env python
|
||||||
|
"""
|
||||||
|
Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
All rights reserved.
|
||||||
|
This source code is licensed under the license found in the LICENSE file in the
|
||||||
|
root directory of this source tree.
|
||||||
|
|
||||||
|
Dataloader for ambiguous candidates identification task on SIMMC 2.1.
|
||||||
|
|
||||||
|
Author(s): Satwik Kottur
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
from random import shuffle
|
||||||
|
from random import random as rand
|
||||||
|
#from src.utils.vd_bert.loader_utils import get_random_word
|
||||||
|
|
||||||
|
|
||||||
|
def pad_seq(seqs, pad_token, return_lens=False, is_vft=False):
    lengths = [s.shape[1] for s in seqs]
    max_length = max(lengths)
    output = []
    for seq in seqs:
        if is_vft:
            if len(seq.shape)==4: # spatio-temporal feature
                result = torch.ones((1, max_length, seq.shape[1], seq.shape[2], seq.shape[3]), dtype=seq.dtype)*pad_token
            else:
                result = torch.ones((1, max_length, seq.shape[-1]), dtype=seq.dtype)*pad_token
        else:
            result = torch.ones((1, max_length), dtype=seq.dtype)*pad_token
        result[0, :seq.shape[1]] = seq
        output.append(result)
    if return_lens:
        return lengths, output
    return output
|
||||||
|
|
||||||
|
|
||||||
|
def pad_2d_seq(seqs, pad_token, return_lens=False, is_vft=False):
|
||||||
|
lens1 = [len(s) for s in seqs]
|
||||||
|
max_len1 = max(lens1)
|
||||||
|
all_seqs = []
|
||||||
|
for seq in seqs:
|
||||||
|
all_seqs.extend(seq)
|
||||||
|
lens2 = [s.shape[1] for s in all_seqs]
|
||||||
|
max_len2 = max(lens2)
|
||||||
|
output = []
|
||||||
|
all_lens = []
|
||||||
|
for seq in seqs:
|
||||||
|
if is_vft:
|
||||||
|
result = torch.ones((max_len1, max_len2, seq[0].shape[-1]))*pad_token
|
||||||
|
else:
|
||||||
|
result = torch.ones((1, max_len1, max_len2))*pad_token
|
||||||
|
#turn_lens = torch.ones(max_len1, dtype=np.int)
|
||||||
|
offset = max_len1 - len(seq)
|
||||||
|
for turn_idx, turn in enumerate(seq):
|
||||||
|
#result[turn_idx,:turn.shape[0]] = turn
|
||||||
|
# padding should be at the first turn idxs (Reason: result of last n turns is used for state creation)
|
||||||
|
result[0, turn_idx + offset,:turn.shape[1]] = turn
|
||||||
|
#turn_lens[turn_idx] = turn.shape[0]
|
||||||
|
output.append(result)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class Simmc2DatasetTest(Dataset):
|
||||||
|
def __init__(self, tokenizer, feature_loader, load_path, args, hidden_labels=False):
|
||||||
|
self._tokenizer = tokenizer
|
||||||
|
self._features = feature_loader
|
||||||
|
self._args = args
|
||||||
|
self._hidden_labels = hidden_labels
|
||||||
|
print("Loading: {}".format(load_path))
|
||||||
|
with open(load_path, "r") as file_id:
|
||||||
|
self._raw_data = json.load(file_id)
|
||||||
|
# Also read the source data for evaluation.
|
||||||
|
with open(self._raw_data["source_path"], "r") as file_id:
|
||||||
|
self.source_data = json.load(file_id)
|
||||||
|
self._data = self._raw_data["data"]
|
||||||
|
|
||||||
|
self.num_utterances = 2 * args.max_turns + 1
|
||||||
|
self.num_instances = len(self._data)
|
||||||
|
self.device = torch.cuda if args.use_gpu else torch
|
||||||
|
|
||||||
|
|
||||||
|
def conduct_mask(self, tokens, effective_length, start_id, end_id):
|
||||||
|
# taken from https://github.com/salesforce/VD-BERT
|
||||||
|
# For masked Language Models
|
||||||
|
cand_pos = []
|
||||||
|
special_pos = set()
|
||||||
|
|
||||||
|
n_pred = min(self._args.max_n_masked, max(
|
||||||
|
1, int(round(effective_length * self._args.p_mask))))
|
||||||
|
|
||||||
|
# candidate positions of masked tokens
|
||||||
|
for i, tk in enumerate(tokens):
|
||||||
|
# only mask tokens_b (target sequence)
|
||||||
|
# we will mask [SEP] as an ending symbol
|
||||||
|
if (i >= start_id) and (tk != '[CLS]') and (tk != '[PAD]') and (i < end_id):
|
||||||
|
cand_pos.append(i)
|
||||||
|
else:
|
||||||
|
special_pos.add(i)
|
||||||
|
|
||||||
|
shuffle(cand_pos)
|
||||||
|
masked_pos = cand_pos[:n_pred]
|
||||||
|
|
||||||
|
masked_tokens = [tokens[pos] for pos in masked_pos]
|
||||||
|
for pos in masked_pos:
|
||||||
|
if self._args.finetune:
|
||||||
|
tokens[pos] = '[MASK]'
|
||||||
|
continue
|
||||||
|
if rand() < 0.8: # 80%
|
||||||
|
tokens[pos] = '[MASK]'
|
||||||
|
#elif rand() < 0.5: # 10%
|
||||||
|
# tokens[pos] = get_random_word(self.vocab_words)
|
||||||
|
# when n_pred < max_pred, we only calculate loss within n_pred
|
||||||
|
masked_weights = [1] * len(masked_tokens)
|
||||||
|
|
||||||
|
# Token Indexing
|
||||||
|
input_ids = self._tokenizer.convert_tokens_to_ids(tokens)
|
||||||
|
masked_ids = self._tokenizer.convert_tokens_to_ids(masked_tokens)
|
||||||
|
|
||||||
|
if self._args.max_n_masked > n_pred:
|
||||||
|
n_pad = self._args.max_n_masked - n_pred
|
||||||
|
masked_ids.extend([0] * n_pad)
|
||||||
|
masked_pos.extend([0] * n_pad)
|
||||||
|
masked_weights.extend([0] * n_pad)
|
||||||
|
|
||||||
|
assert len(masked_ids) == len(masked_pos) == len(masked_weights) == self._args.max_n_masked, \
|
||||||
|
"[masked] id: %d, pos: %d, weights: %d" % (len(masked_ids), len(masked_pos), len(masked_weights))
|
||||||
|
|
||||||
|
return input_ids, masked_ids, masked_pos, masked_weights
|
||||||
|
|
||||||
|
|
||||||
|
def get_random_batch(self, batch_size):
|
||||||
|
indices = np.random.randint(0, self.num_instances, batch_size)
|
||||||
|
return self.get_indexed_data(indices)
|
||||||
|
|
||||||
|
def get_entire_batch(self, batch_size):
|
||||||
|
all_indices = np.arange(self.num_instances)
|
||||||
|
for start in all_indices[::batch_size]:
|
||||||
|
batch_indices = all_indices[start : start + batch_size]
|
||||||
|
yield self.get_indexed_data(batch_indices)
|
||||||
|
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self._data)
|
||||||
|
|
||||||
|
|
||||||
|
def collate_fn(self, batch):
|
||||||
|
merged_batch = {key: [d[key] for d in batch] for key in batch[0]}
|
||||||
|
out = {}
|
||||||
|
for key in merged_batch:
|
||||||
|
if key in ['qa_pair', 'masked_pos', 'mask_labels', 'next_sentence_label', 'masked_weights', 'q_len']:
|
||||||
|
seq = pad_seq(merged_batch[key], pad_token=1)
|
||||||
|
out[key] = torch.concat(seq, dim=0)
|
||||||
|
elif key in ['qa_turns']:
|
||||||
|
if merged_batch[key][0] is not None:
|
||||||
|
seq = pad_2d_seq(merged_batch[key], pad_token=1)
|
||||||
|
out[key] = torch.concat(seq, dim=0).type(torch.int)
|
||||||
|
else:
|
||||||
|
out[key] = None
|
||||||
|
elif key in ['answer']:
|
||||||
|
out[key] = merged_batch[key]
|
||||||
|
elif key in ['features']:
|
||||||
|
#features = [f.unsqueeze(1) for f in merged_batch[key]]
|
||||||
|
# pad video features
|
||||||
|
features = pad_sequence(merged_batch[key], batch_first=True)
|
||||||
|
out[key] = features
|
||||||
|
else:
|
||||||
|
out[key] = merged_batch[key]
|
||||||
|
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def encode_turns(self, turns):
|
||||||
|
encoded_turns = []
|
||||||
|
for turn in turns:
|
||||||
|
encoded_turn = self._tokenizer(
|
||||||
|
turn,
|
||||||
|
padding=True,
|
||||||
|
max_length=self._args.max_length,
|
||||||
|
return_tensors="pt",
|
||||||
|
truncation=True,
|
||||||
|
)
|
||||||
|
# without cls and sep token
|
||||||
|
encoded_turns.append(encoded_turn['input_ids'][:, 1:-1].type(torch.int))
|
||||||
|
return encoded_turns
|
||||||
|
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
dialog_datum = self._data[index]
|
||||||
|
qa_pair = self._data[index]["qa_pair"]
|
||||||
|
qa_turns = self._data[index]["qa_turns"]
|
||||||
|
answer = self._data[index]["answer"]
|
||||||
|
next_sentence_label = self._data[index]["next_sentence_label"]
|
||||||
|
|
||||||
|
if self._features:
|
||||||
|
feature = self._features[dialog_datum["image_name"]]
|
||||||
|
|
||||||
|
qa_pair_as_tokens = self._tokenizer.tokenize(qa_pair[0])
|
||||||
|
q_len = [qa_pair_as_tokens.index('[SEP_1]')]
|
||||||
|
|
||||||
|
|
||||||
|
qa_pair_ids = self._tokenizer.convert_tokens_to_ids(qa_pair_as_tokens)
|
||||||
|
qa_turns_ids = self.encode_turns(qa_turns)
|
||||||
|
|
||||||
|
|
||||||
|
# Pack the sample.
|
||||||
|
sample = {
|
||||||
|
"answer": answer,
|
||||||
|
"qa_pair": torch.tensor(qa_pair_ids).unsqueeze(0),
|
||||||
|
"q_len": torch.tensor(q_len).unsqueeze(0),
|
||||||
|
"qa_turns": qa_turns_ids,
|
||||||
|
"features": feature
|
||||||
|
}
|
||||||
|
return sample
|
||||||
|
|
||||||
|
|
||||||
|
class VisualFeatureLoader:
|
||||||
|
"""Loads visual features for SIMMC 2.1 ambiguous candidate identification."""
|
||||||
|
|
||||||
|
UNAVAILABLE_IMAGES = [
|
||||||
|
"cloth_store_1416238_woman_20_6.png",
|
||||||
|
"cloth_store_1416238_woman_19_0.png",
|
||||||
|
"cloth_store_1416238_woman_4_8.png",
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self, feature_path, feature_size):
|
||||||
|
"""Read the features from the path."""
|
||||||
|
self._features = torch.load(feature_path)
|
||||||
|
self._feature_size = feature_size
|
||||||
|
self._zero_feature = torch.zeros((1, self._feature_size), dtype=torch.float)
|
||||||
|
|
||||||
|
def __getitem__(self, label):
|
||||||
|
"""Get the feature given image label."""
|
||||||
|
assert (
|
||||||
|
label in self._features or label in self.UNAVAILABLE_IMAGES
|
||||||
|
), f"{label} not found!"
|
||||||
|
if label in self.UNAVAILABLE_IMAGES:
|
||||||
|
return self._zero_feature
|
||||||
|
return self._features[label]
|
||||||
|
|
||||||
|
def cuda(self):
|
||||||
|
"""Move the features to cuda."""
|
||||||
|
self._zero_feature = self._zero_feature.cuda()
|
||||||
|
for key, val in self._features.items():
|
||||||
|
self._features[key] = val.cuda()
|
src/utils/simmc2_dataset/format_data.py (new file, 150 lines)
|
||||||
|
#! /usr/bin/env python
|
||||||
|
"""
|
||||||
|
Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
All rights reserved.
|
||||||
|
This source code is licensed under the license found in the LICENSE file in the
|
||||||
|
root directory of this source tree.
|
||||||
|
|
||||||
|
Reads SIMMC 2.1 dataset, creates train, devtest, dev formats for ambiguous candidates.
|
||||||
|
|
||||||
|
Author(s): Satwik Kottur
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import copy
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
SPLITS = ["train", "dev", "devtest", "teststd"]
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_name(scene_ids, turn_ind):
|
||||||
|
"""Given scene ids and turn index, get the image name.
|
||||||
|
"""
|
||||||
|
sorted_scene_ids = sorted(
|
||||||
|
((int(key), val) for key, val in scene_ids.items()),
|
||||||
|
key=lambda x: x[0],
|
||||||
|
reverse=True
|
||||||
|
)
|
||||||
|
# NOTE: Hardcoded to only two scenes.
|
||||||
|
if turn_ind >= sorted_scene_ids[0][0]:
|
||||||
|
scene_label = sorted_scene_ids[0][1]
|
||||||
|
else:
|
||||||
|
scene_label = sorted_scene_ids[1][1]
|
||||||
|
image_label = scene_label
|
||||||
|
if "m_" in scene_label:
|
||||||
|
image_label = image_label.replace("m_", "")
|
||||||
|
return f"{image_label}.png", scene_label
|
||||||
|
|
||||||
|
|
||||||
|
def get_object_mapping(scene_label, args):
|
||||||
|
"""Get the object mapping for a given scene.
|
||||||
|
"""
|
||||||
|
scene_json_path = os.path.join(
|
||||||
|
args["scene_json_folder"], f"{scene_label}_scene.json"
|
||||||
|
)
|
||||||
|
with open(scene_json_path, "r") as file_id:
|
||||||
|
scene_objects = json.load(file_id)["scenes"][0]["objects"]
|
||||||
|
object_map = [ii["index"] for ii in scene_objects]
|
||||||
|
return object_map
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
for split in SPLITS:
|
||||||
|
read_path = args[f"simmc_{split}_json"]
|
||||||
|
print(f"Reading: {read_path}")
|
||||||
|
with open(read_path, "r") as file_id:
|
||||||
|
dialogs = json.load(file_id)
|
||||||
|
|
||||||
|
# Reformat into simple strings with positive and negative labels.
|
||||||
|
# (dialog string, label)
|
||||||
|
ambiguous_candidates_data = []
|
||||||
|
for dialog_id, dialog_datum in enumerate(dialogs["dialogue_data"]):
|
||||||
|
turns = []
|
||||||
|
q_turns = []
|
||||||
|
a_turns = []
|
||||||
|
for turn_ind, turn_datum in enumerate(dialog_datum["dialogue"]):
|
||||||
|
query = [turn_datum["transcript"]]
|
||||||
|
answer = [turn_datum["system_transcript"]]
|
||||||
|
|
||||||
|
#annotations = turn_datum["transcript_annotated"]
|
||||||
|
#if annotations.get("disambiguation_label", False):
|
||||||
|
#label = annotations["disambiguation_candidates"]
|
||||||
|
image_name, scene_label = get_image_name(
|
||||||
|
dialog_datum["scene_ids"], turn_ind
|
||||||
|
)
|
||||||
|
# If dialog contains multiple scenes, map it accordingly.
|
||||||
|
object_map = get_object_mapping(scene_label, args)
|
||||||
|
new_datum = {
|
||||||
|
"query": query,
|
||||||
|
"answer": answer,
|
||||||
|
"q_turns": copy.deepcopy(q_turns),
|
||||||
|
"a_turns": copy.deepcopy(a_turns),
|
||||||
|
"turns": copy.deepcopy(turns),
|
||||||
|
"dialog_id": dialog_datum["dialogue_idx"],
|
||||||
|
"turn_id": turn_ind,
|
||||||
|
#"input_text": copy.deepcopy(history),
|
||||||
|
#"ambiguous_candidates": label,
|
||||||
|
"image_name": image_name,
|
||||||
|
"object_map": object_map,
|
||||||
|
}
|
||||||
|
|
||||||
|
ambiguous_candidates_data.append(new_datum)
|
||||||
|
|
||||||
|
turns.append([turn_datum["transcript"] + turn_datum["system_transcript"]])
|
||||||
|
q_turns.append(query)
|
||||||
|
a_turns.append(answer)
|
||||||
|
|
||||||
|
|
||||||
|
# Ignore if system_transcript is not found (last round teststd).
|
||||||
|
# if turn_datum.get("system_transcript", None):
|
||||||
|
# history.append(turn_datum["system_transcript"])
|
||||||
|
|
||||||
|
print(f"# instances [{split}]: {len(ambiguous_candidates_data)}")
|
||||||
|
save_path = os.path.join(
|
||||||
|
args["ambiguous_candidates_save_path"],
|
||||||
|
f"simmc2.1_ambiguous_candidates_dstc11_{split}.json"
|
||||||
|
)
|
||||||
|
print(f"Saving: {save_path}")
|
||||||
|
with open(save_path, "w") as file_id:
|
||||||
|
json.dump(
|
||||||
|
{
|
||||||
|
"source_path": read_path,
|
||||||
|
"split": split,
|
||||||
|
"data": ambiguous_candidates_data,
|
||||||
|
},
|
||||||
|
file_id
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description=__doc__)
|
||||||
|
parser.add_argument(
|
||||||
|
"--simmc_train_json", default=None, help="Path to SIMMC 2.1 train"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--simmc_dev_json", default=None, help="Path to SIMMC 2.1 dev"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--simmc_devtest_json", default=None, help="Path to SIMMC 2.1 devtest"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--simmc_teststd_json", default=None, help="Path to SIMMC 2.1 teststd (public)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--scene_json_folder", default=None, help="Path to SIMMC scene jsons"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ambiguous_candidates_save_path",
|
||||||
|
required=True,
|
||||||
|
help="Path to save SIMMC disambiguate JSONs",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
parsed_args = vars(parser.parse_args())
|
||||||
|
except (IOError) as msg:
|
||||||
|
parser.error(str(msg))
|
||||||
|
main(parsed_args)
|
8
src/utils/simmc2_dataset/format_data.sh
Executable file
|
@ -0,0 +1,8 @@
#!/bin/bash
DATA_FOLDER="../../data/"
python format_data.py \
    --simmc_train_json "/scratch/hochmeister/simmc2/data/simmc2.1_dials_dstc11_train.json" \
    --simmc_dev_json "/scratch/hochmeister/simmc2/data/simmc2.1_dials_dstc11_dev.json" \
    --simmc_devtest_json "/scratch/hochmeister/simmc2/data/simmc2.1_dials_dstc11_devtest.json" \
    --scene_json_folder "/scratch/hochmeister/simmc2/data/public/" \
    --ambiguous_candidates_save_path "/scratch/hochmeister/simmc2/data/ambiguous_candidates/"
224
src/utils/simmc2_dataset/format_data_subtask4_b.py
Normal file
|
@ -0,0 +1,224 @@
|
||||||
|
#! /usr/bin/env python
|
||||||
|
"""
|
||||||
|
Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
All rights reserved.
|
||||||
|
This source code is licensed under the license found in the LICENSE file in the
|
||||||
|
root directory of this source tree.
|
||||||
|
|
||||||
|
Reads the SIMMC 2.1 dataset and creates train, dev, devtest splits for answer-candidate ranking (subtask 4b).
|
||||||
|
|
||||||
|
Author(s): Satwik Kottur
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import copy
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
|
||||||
|
|
||||||
|
SPLITS = ["train", "dev", "devtest", "teststd"]
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_name(scene_ids, turn_ind):
|
||||||
|
"""Given scene ids and turn index, get the image name.
|
||||||
|
"""
|
||||||
|
sorted_scene_ids = sorted(
|
||||||
|
((int(key), val) for key, val in scene_ids.items()),
|
||||||
|
key=lambda x: x[0],
|
||||||
|
reverse=True
|
||||||
|
)
|
||||||
|
# NOTE: Hardcoded to only two scenes.
|
||||||
|
if turn_ind >= sorted_scene_ids[0][0]:
|
||||||
|
scene_label = sorted_scene_ids[0][1]
|
||||||
|
else:
|
||||||
|
scene_label = sorted_scene_ids[1][1]
|
||||||
|
image_label = scene_label
|
||||||
|
if "m_" in scene_label:
|
||||||
|
image_label = image_label.replace("m_", "")
|
||||||
|
return f"{image_label}.png", scene_label
|
||||||
|
|
||||||
|
|
||||||
|
def get_object_mapping(scene_label, args):
|
||||||
|
"""Get the object mapping for a given scene.
|
||||||
|
"""
|
||||||
|
scene_json_path = os.path.join(
|
||||||
|
args["scene_json_folder"], f"{scene_label}_scene.json"
|
||||||
|
)
|
||||||
|
with open(scene_json_path, "r") as file_id:
|
||||||
|
scene_objects = json.load(file_id)["scenes"][0]["objects"]
|
||||||
|
object_map = [ii["index"] for ii in scene_objects]
|
||||||
|
return object_map
|
||||||
|
|
||||||
|
|
||||||
|
def dictionary_to_string(dictionary):
|
||||||
|
result = ""
|
||||||
|
for k, v in dictionary.items():
|
||||||
|
result += k + ":"
|
||||||
|
result += str(v) + " "
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_answers(dialogs):
|
||||||
|
all_answers = []
|
||||||
|
for dialog_id, dialog_datum in enumerate(dialogs["dialogue_data"]):
|
||||||
|
for turn_ind, turn_datum in enumerate(dialog_datum["dialogue"]):
|
||||||
|
all_answers.append(turn_datum["system_transcript"])
|
||||||
|
return all_answers
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
for split in SPLITS:
|
||||||
|
read_path = args[f"simmc_{split}_json"]
|
||||||
|
print(f"Reading: {read_path}")
|
||||||
|
with open(read_path, "r") as file_id:
|
||||||
|
dialogs = json.load(file_id)
|
||||||
|
|
||||||
|
|
||||||
|
# load the metadata files
|
||||||
|
with open(args["furniture_prefab_metadata"], "r") as file:
|
||||||
|
furniture_metadata = json.load(file)
|
||||||
|
|
||||||
|
with open(args["fashion_prefab_metadata"], "r") as file:
|
||||||
|
fashion_metadata = json.load(file)
|
||||||
|
|
||||||
|
# collect all answers from all dialogues to sample answer candidates from for each dialogue turn
|
||||||
|
all_answers = get_all_answers(dialogs)
|
||||||
|
|
||||||
|
|
||||||
|
# Reformat into simple strings with positive and negative labels.
|
||||||
|
# (dialog string, label)
|
||||||
|
ambiguous_candidates_data = []
|
||||||
|
for dialog_id, dialog_datum in enumerate(dialogs["dialogue_data"]):
|
||||||
|
turns = []
|
||||||
|
q_turns = []
|
||||||
|
a_turns = []
|
||||||
|
|
||||||
|
for turn_ind, turn_datum in enumerate(dialog_datum["dialogue"]):
|
||||||
|
query = [turn_datum["transcript"]]
|
||||||
|
answer = [turn_datum["system_transcript"]]
|
||||||
|
answer_candidates = []
|
||||||
|
|
||||||
|
# sample random answers from the list of all answers as answer candidates
|
||||||
|
# sample n_answer_candidates - 1 wrong answer candidates from the list of all answers
|
||||||
|
for _ in range(int(args["n_answer_candidates"]) - 1):
|
||||||
|
random_idx = random.randint(0, len(all_answers) - 1)
|
||||||
|
answer_candidates.append([all_answers[random_idx]])
|
||||||
|
answer_candidates.insert(0, answer)
|
||||||
|
#random.shuffle(answer_candidates)
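# With shuffling disabled above, the ground-truth answer always sits at index 0 of
# answer_candidates; the ranking model downstream presumably relies on this convention.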
|
||||||
|
|
||||||
|
#annotations = turn_datum["transcript_annotated"]
|
||||||
|
#if annotations.get("disambiguation_label", False):
|
||||||
|
#label = annotations["disambiguation_candidates"]
|
||||||
|
image_name, scene_id = get_image_name(
|
||||||
|
dialog_datum["scene_ids"], turn_ind
|
||||||
|
)
|
||||||
|
|
||||||
|
# load the scene file and collect all prefab paths to get the object descriptions for this scene
|
||||||
|
prefab_paths = []
|
||||||
|
scene_path = os.path.join(args["scene_json_folder"], f"{scene_id}_scene.json")
|
||||||
|
with open(scene_path, "r") as scene_file:
|
||||||
|
scene_data = json.load(scene_file)
|
||||||
|
for scene in scene_data["scenes"]:
|
||||||
|
for object in scene["objects"]:
|
||||||
|
prefab_paths.append(object["prefab_path"])
|
||||||
|
|
||||||
|
# get the metadata for all objects of the scene (prefab_paths)
|
||||||
|
object_metadata = []
|
||||||
|
for prefab_path in prefab_paths:
|
||||||
|
if scene_id[:11] in ["cloth_store", "m_cloth_sto"]:
|
||||||
|
object_dict = fashion_metadata[prefab_path]
|
||||||
|
elif scene_id[:7] == "wayfair":
|
||||||
|
object_dict = furniture_metadata[prefab_path]
|
||||||
|
object_str = dictionary_to_string(object_dict)
|
||||||
|
object_metadata.append([object_str])
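# The metadata source is selected from the scene id prefix: "cloth_store"/"m_cloth_sto"
# scenes use the fashion metadata, "wayfair" scenes the furniture metadata. Any other
# prefix would leave object_dict unbound and raise a NameError, so these prefixes are
# assumed to cover all SIMMC 2.1 scenes.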
|
||||||
|
|
||||||
|
|
||||||
|
# If dialog contains multiple scenes, map it accordingly.
|
||||||
|
#object_map = get_object_mapping(scene_label, args)
|
||||||
|
new_datum = {
|
||||||
|
"query": query,
|
||||||
|
"answer": answer,
|
||||||
|
"answer_candidates": answer_candidates,
|
||||||
|
"q_turns": copy.deepcopy(q_turns),
|
||||||
|
"a_turns": copy.deepcopy(a_turns),
|
||||||
|
"turns": copy.deepcopy(turns),
|
||||||
|
"object_metadata": object_metadata,
|
||||||
|
"dialog_id": dialog_datum["dialogue_idx"],
|
||||||
|
"turn_id": turn_ind,
|
||||||
|
#"input_text": copy.deepcopy(history),
|
||||||
|
#"ambiguous_candidates": label,
|
||||||
|
"image_name": image_name,
|
||||||
|
#"object_map": object_map,
|
||||||
|
}
|
||||||
|
|
||||||
|
ambiguous_candidates_data.append(new_datum)
|
||||||
|
|
||||||
|
turns.append([turn_datum["transcript"] + turn_datum["system_transcript"]])
|
||||||
|
q_turns.append(query)
|
||||||
|
a_turns.append(answer)
|
||||||
|
|
||||||
|
|
||||||
|
# Ignore if system_transcript is not found (last round teststd).
|
||||||
|
# if turn_datum.get("system_transcript", None):
|
||||||
|
# history.append(turn_datum["system_transcript"])
|
||||||
|
|
||||||
|
print(f"# instances [{split}]: {len(ambiguous_candidates_data)}")
|
||||||
|
save_path = os.path.join(
|
||||||
|
args["ambiguous_candidates_save_path"],
|
||||||
|
f"simmc2.1_ambiguous_candidates_dstc11_{split}.json"
|
||||||
|
)
|
||||||
|
print(f"Saving: {save_path}")
|
||||||
|
with open(save_path, "w") as file_id:
|
||||||
|
json.dump(
|
||||||
|
{
|
||||||
|
"source_path": read_path,
|
||||||
|
"split": split,
|
||||||
|
"data": ambiguous_candidates_data,
|
||||||
|
},
|
||||||
|
file_id
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description=__doc__)
|
||||||
|
parser.add_argument(
|
||||||
|
"--simmc_train_json", default=None, help="Path to SIMMC 2.1 train"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--simmc_dev_json", default=None, help="Path to SIMMC 2.1 dev"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--simmc_devtest_json", default=None, help="Path to SIMMC 2.1 devtest"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--simmc_teststd_json", default=None, help="Path to SIMMC 2.1 teststd (public)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--scene_json_folder", default=None, help="Path to SIMMC scene jsons"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ambiguous_candidates_save_path",
|
||||||
|
required=True,
|
||||||
|
help="Path to save SIMMC disambiguate JSONs",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--fashion_prefab_metadata", required=True,
|
||||||
|
help="Path to the file with all metadata for fashion objects"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--furniture_prefab_metadata", required=True,
|
||||||
|
help="Path to the file with all metadata for fashion objects"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--n_answer_candidates", required=True,
|
||||||
|
help="number of answer candidates for the ranking task"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
parsed_args = vars(parser.parse_args())
|
||||||
|
except (IOError) as msg:
|
||||||
|
parser.error(str(msg))
|
||||||
|
main(parsed_args)
|
11
src/utils/simmc2_dataset/format_data_subtask4_b.sh
Executable file
|
@ -0,0 +1,11 @@
#!/bin/bash
DATA_FOLDER="../../data/"
python format_data_subtask4_b.py \
    --simmc_train_json "/scratch/hochmeister/simmc2/data/simmc2.1_dials_dstc11_train.json" \
    --simmc_dev_json "/scratch/hochmeister/simmc2/data/simmc2.1_dials_dstc11_dev.json" \
    --simmc_devtest_json "/scratch/hochmeister/simmc2/data/simmc2.1_dials_dstc11_devtest.json" \
    --scene_json_folder "/scratch/hochmeister/simmc2/data/public/" \
    --ambiguous_candidates_save_path "/scratch/hochmeister/simmc2/data/subtask_4_b_data/" \
    --fashion_prefab_metadata "/scratch/hochmeister/simmc2/data/fashion_prefab_metadata_all.json" \
    --furniture_prefab_metadata "/scratch/hochmeister/simmc2/data/furniture_prefab_metadata_all.json" \
    --n_answer_candidates 10
207
src/utils/simmc2_dataset/format_data_subtask4_mlm_nsp.py
Normal file
|
@ -0,0 +1,207 @@
|
||||||
|
#! /usr/bin/env python
|
||||||
|
"""
|
||||||
|
Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
All rights reserved.
|
||||||
|
This source code is licensed under the license found in the LICENSE file in the
|
||||||
|
root directory of this source tree.
|
||||||
|
|
||||||
|
Reads the SIMMC 2.1 dataset and creates train, dev, devtest splits for MLM/NSP-style response selection.
|
||||||
|
|
||||||
|
Author(s): Satwik Kottur
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import copy
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
|
||||||
|
|
||||||
|
SPLITS = ["train", "dev", "devtest", "teststd"]
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_name(scene_ids, turn_ind):
|
||||||
|
"""Given scene ids and turn index, get the image name.
|
||||||
|
"""
|
||||||
|
sorted_scene_ids = sorted(
|
||||||
|
((int(key), val) for key, val in scene_ids.items()),
|
||||||
|
key=lambda x: x[0],
|
||||||
|
reverse=True
|
||||||
|
)
|
||||||
|
# NOTE: Hardcoded to only two scenes.
|
||||||
|
if turn_ind >= sorted_scene_ids[0][0]:
|
||||||
|
scene_label = sorted_scene_ids[0][1]
|
||||||
|
else:
|
||||||
|
scene_label = sorted_scene_ids[1][1]
|
||||||
|
image_label = scene_label
|
||||||
|
if "m_" in scene_label:
|
||||||
|
image_label = image_label.replace("m_", "")
|
||||||
|
return f"{image_label}.png", scene_label
|
||||||
|
|
||||||
|
|
||||||
|
def get_object_mapping(scene_label, args):
|
||||||
|
"""Get the object mapping for a given scene.
|
||||||
|
"""
|
||||||
|
scene_json_path = os.path.join(
|
||||||
|
args["scene_json_folder"], f"{scene_label}_scene.json"
|
||||||
|
)
|
||||||
|
with open(scene_json_path, "r") as file_id:
|
||||||
|
scene_objects = json.load(file_id)["scenes"][0]["objects"]
|
||||||
|
object_map = [ii["index"] for ii in scene_objects]
|
||||||
|
return object_map
|
||||||
|
|
||||||
|
|
||||||
|
def dictionary_to_string(dictionary):
|
||||||
|
result = ""
|
||||||
|
for k, v in dictionary.items():
|
||||||
|
result += k + ":"
|
||||||
|
result += str(v) + " "
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_answers(dialogs):
|
||||||
|
all_answers = []
|
||||||
|
for dialog_id, dialog_datum in enumerate(dialogs["dialogue_data"]):
|
||||||
|
for turn_ind, turn_datum in enumerate(dialog_datum["dialogue"]):
|
||||||
|
all_answers.append(turn_datum["system_transcript"])
|
||||||
|
return all_answers
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
for split in SPLITS:
|
||||||
|
read_path = args[f"simmc_{split}_json"]
|
||||||
|
print(f"Reading: {read_path}")
|
||||||
|
with open(read_path, "r") as file_id:
|
||||||
|
dialogs = json.load(file_id)
|
||||||
|
|
||||||
|
# collect all answers from all dialogues to sample answer candidates from for each dialogue turn
|
||||||
|
all_answers = get_all_answers(dialogs)
|
||||||
|
|
||||||
|
ambiguous_candidates_data = []
|
||||||
|
for dialog_id, dialog_datum in enumerate(dialogs["dialogue_data"]):
|
||||||
|
q_turns = []
|
||||||
|
a_turns = []
|
||||||
|
qa_turns = []
|
||||||
|
|
||||||
|
for turn_ind, turn_datum in enumerate(dialog_datum["dialogue"]):
|
||||||
|
query = turn_datum["transcript"]
|
||||||
|
answer = turn_datum["system_transcript"]
|
||||||
|
|
||||||
|
# a randomly sampled wrong answer is used to create a false sample for NSP
|
||||||
|
wrong_answer = random.choice(all_answers)
|
||||||
|
|
||||||
|
qa_pair = query + '[SEP_1]' + answer + '[SEP]'
|
||||||
|
wrong_qa_pair = query + '[SEP_1]' + wrong_answer + '[SEP]'
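# '[SEP_1]' separates the user query from the system answer and '[SEP]' closes the pair;
# both are assumed to be registered as special tokens in the downstream tokenizer,
# otherwise they would simply be split into word pieces.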
|
||||||
|
|
||||||
|
image_name, scene_id = get_image_name(
|
||||||
|
dialog_datum["scene_ids"], turn_ind
|
||||||
|
)
|
||||||
|
|
||||||
|
# load the scene file and collect all prefab paths to get the object descriptions for this scene
|
||||||
|
prefab_paths = []
|
||||||
|
scene_path = os.path.join(args["scene_json_folder"], f"{scene_id}_scene.json")
|
||||||
|
with open(scene_path, "r") as scene_file:
|
||||||
|
scene_data = json.load(scene_file)
|
||||||
|
for scene in scene_data["scenes"]:
|
||||||
|
for object in scene["objects"]:
|
||||||
|
prefab_paths.append(object["prefab_path"])
|
||||||
|
|
||||||
|
# for each dialogue round, add a sample with the correct answer and one with a random answer for NSP
|
||||||
|
new_datum_correct_answer = {
|
||||||
|
"query": [query],
|
||||||
|
"answer": [answer],
|
||||||
|
"qa_pair": [qa_pair],
|
||||||
|
"next_sentence_label": [1],
|
||||||
|
"q_turns": copy.deepcopy(q_turns),
|
||||||
|
"a_turns": copy.deepcopy(a_turns),
|
||||||
|
"qa_turns": copy.deepcopy(qa_turns),
|
||||||
|
"dialog_id": dialog_datum["dialogue_idx"],
|
||||||
|
"turn_id": turn_ind,
|
||||||
|
"image_name": image_name,
|
||||||
|
}
|
||||||
|
new_datum_wrong_answer = {
|
||||||
|
"query": [query],
|
||||||
|
"answer": [wrong_answer],
|
||||||
|
"qa_pair": [wrong_qa_pair],
|
||||||
|
"next_sentence_label": [0],
|
||||||
|
"q_turns": copy.deepcopy(q_turns),
|
||||||
|
"a_turns": copy.deepcopy(a_turns),
|
||||||
|
"qa_turns": copy.deepcopy(qa_turns),
|
||||||
|
"dialog_id": dialog_datum["dialogue_idx"],
|
||||||
|
"turn_id": turn_ind,
|
||||||
|
"image_name": image_name,
|
||||||
|
}
|
||||||
|
|
||||||
|
ambiguous_candidates_data.append(new_datum_correct_answer)
|
||||||
|
|
||||||
|
if args['create_false_samples_for_nsp']:
|
||||||
|
ambiguous_candidates_data.append(new_datum_wrong_answer)
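# next_sentence_label follows the usual NSP convention: 1 for the genuine answer, 0 for
# the randomly sampled one. Without --create_false_samples_for_nsp only positive samples
# are written out.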
|
||||||
|
|
||||||
|
q_turns.append([query])
|
||||||
|
a_turns.append([answer])
|
||||||
|
qa_turns.append([qa_pair])
|
||||||
|
|
||||||
|
|
||||||
|
# Ignore if system_transcript is not found (last round teststd).
|
||||||
|
# if turn_datum.get("system_transcript", None):
|
||||||
|
# history.append(turn_datum["system_transcript"])
|
||||||
|
|
||||||
|
print(f"# instances [{split}]: {len(ambiguous_candidates_data)}")
|
||||||
|
save_path = os.path.join(
|
||||||
|
args["ambiguous_candidates_save_path"],
|
||||||
|
f"simmc2.1_ambiguous_candidates_dstc11_{split}.json"
|
||||||
|
)
|
||||||
|
print(f"Saving: {save_path}")
|
||||||
|
with open(save_path, "w") as file_id:
|
||||||
|
json.dump(
|
||||||
|
{
|
||||||
|
"source_path": read_path,
|
||||||
|
"split": split,
|
||||||
|
"data": ambiguous_candidates_data,
|
||||||
|
},
|
||||||
|
file_id
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description=__doc__)
|
||||||
|
parser.add_argument(
|
||||||
|
"--simmc_train_json", default=None, help="Path to SIMMC 2.1 train"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--simmc_dev_json", default=None, help="Path to SIMMC 2.1 dev"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--simmc_devtest_json", default=None, help="Path to SIMMC 2.1 devtest"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--simmc_teststd_json", default=None, help="Path to SIMMC 2.1 teststd (public)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--scene_json_folder", default=None, help="Path to SIMMC scene jsons"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ambiguous_candidates_save_path",
|
||||||
|
required=True,
|
||||||
|
help="Path to save SIMMC disambiguate JSONs",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--fashion_prefab_metadata", required=True,
|
||||||
|
help="Path to the file with all metadata for fashion objects"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--furniture_prefab_metadata", required=True,
|
||||||
|
help="Path to the file with all metadata for fashion objects"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--create_false_samples_for_nsp", action='store_true',
|
||||||
|
help="if set, for each correct sample a wrong one is added"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
parsed_args = vars(parser.parse_args())
|
||||||
|
except (IOError) as msg:
|
||||||
|
parser.error(str(msg))
|
||||||
|
main(parsed_args)
|
202
src/utils/simmc2_dataset/format_data_with_obj_descriptions.py
Normal file
|
@ -0,0 +1,202 @@
|
||||||
|
#! /usr/bin/env python
|
||||||
|
"""
|
||||||
|
Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
All rights reserved.
|
||||||
|
This source code is licensed under the license found in the LICENSE file in the
|
||||||
|
root directory of this source tree.
|
||||||
|
|
||||||
|
Reads SIMMC 2.1 dataset, creates train, devtest, dev formats for ambiguous candidates.
|
||||||
|
|
||||||
|
Author(s): Satwik Kottur
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import copy
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
SPLITS = ["train", "dev", "devtest", "teststd"]
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_name(scene_ids, turn_ind):
|
||||||
|
"""Given scene ids and turn index, get the image name.
|
||||||
|
"""
|
||||||
|
sorted_scene_ids = sorted(
|
||||||
|
((int(key), val) for key, val in scene_ids.items()),
|
||||||
|
key=lambda x: x[0],
|
||||||
|
reverse=True
|
||||||
|
)
|
||||||
|
# NOTE: Hardcoded to only two scenes.
|
||||||
|
if turn_ind >= sorted_scene_ids[0][0]:
|
||||||
|
scene_label = sorted_scene_ids[0][1]
|
||||||
|
else:
|
||||||
|
scene_label = sorted_scene_ids[1][1]
|
||||||
|
image_label = scene_label
|
||||||
|
if "m_" in scene_label:
|
||||||
|
image_label = image_label.replace("m_", "")
|
||||||
|
return f"{image_label}.png", scene_label
|
||||||
|
|
||||||
|
|
||||||
|
def get_object_mapping(scene_label, args):
|
||||||
|
"""Get the object mapping for a given scene.
|
||||||
|
"""
|
||||||
|
scene_json_path = os.path.join(
|
||||||
|
args["scene_json_folder"], f"{scene_label}_scene.json"
|
||||||
|
)
|
||||||
|
with open(scene_json_path, "r") as file_id:
|
||||||
|
scene_objects = json.load(file_id)["scenes"][0]["objects"]
|
||||||
|
object_map = [ii["index"] for ii in scene_objects]
|
||||||
|
return object_map
|
||||||
|
|
||||||
|
|
||||||
|
def dictionary_to_string(dictionary):
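# Flattens an object's metadata dict into a "key:value " string. In this variant the
# purely visual attributes (assetType, color, pattern, sleeveLength, type) are skipped,
# presumably so the model has to recover them from the image features rather than the text.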
|
||||||
|
result = ""
|
||||||
|
for k, v in dictionary.items():
|
||||||
|
if k in ['assetType', 'color', 'pattern', 'sleeveLength', 'type']:
|
||||||
|
continue
|
||||||
|
result += k + ":"
|
||||||
|
result += str(v) + " "
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
for split in SPLITS:
|
||||||
|
read_path = args[f"simmc_{split}_json"]
|
||||||
|
print(f"Reading: {read_path}")
|
||||||
|
with open(read_path, "r") as file_id:
|
||||||
|
dialogs = json.load(file_id)
|
||||||
|
|
||||||
|
|
||||||
|
# load the metadata files
|
||||||
|
with open(args["furniture_prefab_metadata"], "r") as file:
|
||||||
|
furniture_metadata = json.load(file)
|
||||||
|
|
||||||
|
with open(args["fashion_prefab_metadata"], "r") as file:
|
||||||
|
fashion_metadata = json.load(file)
|
||||||
|
|
||||||
|
|
||||||
|
# Reformat into simple strings with positive and negative labels.
|
||||||
|
# (dialog string, label)
|
||||||
|
ambiguous_candidates_data = []
|
||||||
|
for dialog_id, dialog_datum in enumerate(dialogs["dialogue_data"]):
|
||||||
|
turns = []
|
||||||
|
q_turns = []
|
||||||
|
a_turns = []
|
||||||
|
|
||||||
|
for turn_ind, turn_datum in enumerate(dialog_datum["dialogue"]):
|
||||||
|
query = [turn_datum["transcript"]]
|
||||||
|
if "system_transcript" not in turn_datum.keys():
|
||||||
|
continue
|
||||||
|
answer = [turn_datum["system_transcript"]]
|
||||||
|
|
||||||
|
#annotations = turn_datum["transcript_annotated"]
|
||||||
|
#if annotations.get("disambiguation_label", False):
|
||||||
|
#label = annotations["disambiguation_candidates"]
|
||||||
|
image_name, scene_id = get_image_name(
|
||||||
|
dialog_datum["scene_ids"], turn_ind
|
||||||
|
)
|
||||||
|
|
||||||
|
# load the scene file and collect all prefab paths to get the object descriptions for this scene
|
||||||
|
prefab_paths = []
|
||||||
|
scene_path = os.path.join(args["scene_json_folder"], f"{scene_id}_scene.json")
|
||||||
|
with open(scene_path, "r") as scene_file:
|
||||||
|
scene_data = json.load(scene_file)
|
||||||
|
for scene in scene_data["scenes"]:
|
||||||
|
for object in scene["objects"]:
|
||||||
|
prefab_paths.append(object["prefab_path"])
|
||||||
|
|
||||||
|
# get the metadata for all objects of the scene (prefab_paths)
|
||||||
|
object_metadata = []
|
||||||
|
for prefab_path in prefab_paths:
|
||||||
|
if scene_id[:11] in ["cloth_store", "m_cloth_sto"]:
|
||||||
|
object_dict = fashion_metadata[prefab_path]
|
||||||
|
elif scene_id[:7] == "wayfair":
|
||||||
|
object_dict = furniture_metadata[prefab_path]
|
||||||
|
object_str = dictionary_to_string(object_dict)
|
||||||
|
object_metadata.append([object_str])
|
||||||
|
|
||||||
|
|
||||||
|
# If dialog contains multiple scenes, map it accordingly.
|
||||||
|
#object_map = get_object_mapping(scene_label, args)
|
||||||
|
new_datum = {
|
||||||
|
"query": query,
|
||||||
|
"answer": answer,
|
||||||
|
"q_turns": copy.deepcopy(q_turns),
|
||||||
|
"a_turns": copy.deepcopy(a_turns),
|
||||||
|
"turns": copy.deepcopy(turns),
|
||||||
|
"object_metadata": object_metadata,
|
||||||
|
"dialog_id": dialog_datum["dialogue_idx"],
|
||||||
|
"turn_id": turn_ind,
|
||||||
|
#"input_text": copy.deepcopy(history),
|
||||||
|
#"ambiguous_candidates": label,
|
||||||
|
"image_name": image_name,
|
||||||
|
#"object_map": object_map,
|
||||||
|
}
|
||||||
|
|
||||||
|
ambiguous_candidates_data.append(new_datum)
|
||||||
|
|
||||||
|
turns.append([turn_datum["transcript"] + turn_datum["system_transcript"]])
|
||||||
|
q_turns.append(query)
|
||||||
|
a_turns.append(answer)
|
||||||
|
|
||||||
|
|
||||||
|
# Ignore if system_transcript is not found (last round teststd).
|
||||||
|
# if turn_datum.get("system_transcript", None):
|
||||||
|
# history.append(turn_datum["system_transcript"])
|
||||||
|
|
||||||
|
print(f"# instances [{split}]: {len(ambiguous_candidates_data)}")
|
||||||
|
save_path = os.path.join(
|
||||||
|
args["ambiguous_candidates_save_path"],
|
||||||
|
f"simmc2.1_ambiguous_candidates_dstc11_{split}.json"
|
||||||
|
)
|
||||||
|
print(f"Saving: {save_path}")
|
||||||
|
with open(save_path, "w") as file_id:
|
||||||
|
json.dump(
|
||||||
|
{
|
||||||
|
"source_path": read_path,
|
||||||
|
"split": split,
|
||||||
|
"data": ambiguous_candidates_data,
|
||||||
|
},
|
||||||
|
file_id
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description=__doc__)
|
||||||
|
parser.add_argument(
|
||||||
|
"--simmc_train_json", default=None, help="Path to SIMMC 2.1 train"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--simmc_dev_json", default=None, help="Path to SIMMC 2.1 dev"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--simmc_devtest_json", default=None, help="Path to SIMMC 2.1 devtest"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--simmc_teststd_json", default=None, help="Path to SIMMC 2.1 teststd (public)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--scene_json_folder", default=None, help="Path to SIMMC scene jsons"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ambiguous_candidates_save_path",
|
||||||
|
required=True,
|
||||||
|
help="Path to save SIMMC disambiguate JSONs",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--fashion_prefab_metadata", required=True,
|
||||||
|
help="Path to the file with all metadata for fashion objects"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--furniture_prefab_metadata", required=True,
|
||||||
|
help="Path to the file with all metadata for fashion objects"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
parsed_args = vars(parser.parse_args())
|
||||||
|
except (IOError) as msg:
|
||||||
|
parser.error(str(msg))
|
||||||
|
main(parsed_args)
|
10
src/utils/simmc2_dataset/format_data_with_obj_descriptions.sh
Executable file
|
@ -0,0 +1,10 @@
#!/bin/bash
DATA_FOLDER="../../data/"
python format_data_with_obj_descriptions.py \
    --simmc_train_json "/scratch/hochmeister/simmc2/data/simmc2.1_dials_dstc11_train.json" \
    --simmc_dev_json "/scratch/hochmeister/simmc2/data/simmc2.1_dials_dstc11_dev.json" \
    --simmc_devtest_json "/scratch/hochmeister/simmc2/data/simmc2.1_dials_dstc11_devtest.json" \
    --scene_json_folder "/scratch/hochmeister/simmc2/data/public/" \
    --ambiguous_candidates_save_path "/scratch/hochmeister/simmc2/data/ambiguous_candidates/" \
    --fashion_prefab_metadata "/scratch/hochmeister/simmc2/data/fashion_prefab_metadata_all.json" \
    --furniture_prefab_metadata "/scratch/hochmeister/simmc2/data/furniture_prefab_metadata_all.json"
|
@ -0,0 +1,206 @@
|
||||||
|
#! /usr/bin/env python
|
||||||
|
"""
|
||||||
|
Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
All rights reserved.
|
||||||
|
This source code is licensed under the license found in the LICENSE file in the
|
||||||
|
root directory of this source tree.
|
||||||
|
|
||||||
|
Reads SIMMC 2.1 dataset, creates train, devtest, dev formats for ambiguous candidates.
|
||||||
|
|
||||||
|
Author(s): Satwik Kottur
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import copy
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
SPLITS = ["train", "dev", "devtest", "teststd"]
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_name(scene_ids, turn_ind):
|
||||||
|
"""Given scene ids and turn index, get the image name.
|
||||||
|
"""
|
||||||
|
sorted_scene_ids = sorted(
|
||||||
|
((int(key), val) for key, val in scene_ids.items()),
|
||||||
|
key=lambda x: x[0],
|
||||||
|
reverse=True
|
||||||
|
)
|
||||||
|
# NOTE: Hardcoded to only two scenes.
|
||||||
|
if turn_ind >= sorted_scene_ids[0][0]:
|
||||||
|
scene_label = sorted_scene_ids[0][1]
|
||||||
|
else:
|
||||||
|
scene_label = sorted_scene_ids[1][1]
|
||||||
|
image_label = scene_label
|
||||||
|
if "m_" in scene_label:
|
||||||
|
image_label = image_label.replace("m_", "")
|
||||||
|
return f"{image_label}.png", scene_label
|
||||||
|
|
||||||
|
|
||||||
|
def get_object_mapping(scene_label, args):
|
||||||
|
"""Get the object mapping for a given scene.
|
||||||
|
"""
|
||||||
|
scene_json_path = os.path.join(
|
||||||
|
args["scene_json_folder"], f"{scene_label}_scene.json"
|
||||||
|
)
|
||||||
|
with open(scene_json_path, "r") as file_id:
|
||||||
|
scene_objects = json.load(file_id)["scenes"][0]["objects"]
|
||||||
|
object_map = [ii["index"] for ii in scene_objects]
|
||||||
|
return object_map
|
||||||
|
|
||||||
|
|
||||||
|
def dictionary_to_string(dictionary):
|
||||||
|
result = ""
|
||||||
|
for k, v in dictionary.items():
|
||||||
|
if k in ['assetType', 'color', 'pattern', 'sleeveLength', 'type']:
|
||||||
|
continue
|
||||||
|
result += k + ":"
|
||||||
|
result += str(v) + " "
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
for split in SPLITS:
|
||||||
|
read_path = args[f"simmc_{split}_json"]
|
||||||
|
print(f"Reading: {read_path}")
|
||||||
|
with open(read_path, "r") as file_id:
|
||||||
|
dialogs = json.load(file_id)
|
||||||
|
|
||||||
|
|
||||||
|
# load the metadata files
|
||||||
|
with open(args["furniture_prefab_metadata"], "r") as file:
|
||||||
|
furniture_metadata = json.load(file)
|
||||||
|
|
||||||
|
with open(args["fashion_prefab_metadata"], "r") as file:
|
||||||
|
fashion_metadata = json.load(file)
|
||||||
|
|
||||||
|
|
||||||
|
# Reformat into simple strings with positive and negative labels.
|
||||||
|
# (dialog string, label)
|
||||||
|
ambiguous_candidates_data = []
|
||||||
|
for dialog_id, dialog_datum in enumerate(dialogs["dialogue_data"]):
|
||||||
|
turns = []
|
||||||
|
q_turns = []
|
||||||
|
a_turns = []
|
||||||
|
dial_len = len(dialog_datum['dialogue'])
|
||||||
|
|
||||||
|
for turn_ind, turn_datum in enumerate(dialog_datum["dialogue"]):
|
||||||
|
query = [turn_datum["transcript"]]
|
||||||
|
if "system_transcript" not in turn_datum.keys():
|
||||||
|
answer = ""
|
||||||
|
else:
|
||||||
|
answer = [turn_datum["system_transcript"]]
|
||||||
|
|
||||||
|
#annotations = turn_datum["transcript_annotated"]
|
||||||
|
#if annotations.get("disambiguation_label", False):
|
||||||
|
#label = annotations["disambiguation_candidates"]
|
||||||
|
image_name, scene_id = get_image_name(
|
||||||
|
dialog_datum["scene_ids"], turn_ind
|
||||||
|
)
|
||||||
|
|
||||||
|
# load the scene file and collect all prefab paths to get the object descriptions for this scene
|
||||||
|
prefab_paths = []
|
||||||
|
scene_path = os.path.join(args["scene_json_folder"], f"{scene_id}_scene.json")
|
||||||
|
with open(scene_path, "r") as scene_file:
|
||||||
|
scene_data = json.load(scene_file)
|
||||||
|
for scene in scene_data["scenes"]:
|
||||||
|
for object in scene["objects"]:
|
||||||
|
prefab_paths.append(object["prefab_path"])
|
||||||
|
|
||||||
|
# get the metadata for all objects of the scene (prefab_paths)
|
||||||
|
object_metadata = []
|
||||||
|
for prefab_path in prefab_paths:
|
||||||
|
if scene_id[:11] in ["cloth_store", "m_cloth_sto"]:
|
||||||
|
object_dict = fashion_metadata[prefab_path]
|
||||||
|
elif scene_id[:7] == "wayfair":
|
||||||
|
object_dict = furniture_metadata[prefab_path]
|
||||||
|
object_str = dictionary_to_string(object_dict)
|
||||||
|
object_metadata.append([object_str])
|
||||||
|
|
||||||
|
|
||||||
|
# If dialog contains multiple scenes, map it accordingly.
|
||||||
|
#object_map = get_object_mapping(scene_label, args)
|
||||||
|
new_datum = {
|
||||||
|
"query": query,
|
||||||
|
"answer": answer,
|
||||||
|
"q_turns": copy.deepcopy(q_turns),
|
||||||
|
"a_turns": copy.deepcopy(a_turns),
|
||||||
|
"turns": copy.deepcopy(turns),
|
||||||
|
"object_metadata": object_metadata,
|
||||||
|
"dialog_id": dialog_datum["dialogue_idx"],
|
||||||
|
"turn_id": turn_ind,
|
||||||
|
#"input_text": copy.deepcopy(history),
|
||||||
|
#"ambiguous_candidates": label,
|
||||||
|
"image_name": image_name,
|
||||||
|
#"object_map": object_map,
|
||||||
|
}
|
||||||
|
|
||||||
|
# only the last dialogue turn of each dialog is used as a sample for the test set
|
||||||
|
if turn_ind == dial_len - 1:
|
||||||
|
ambiguous_candidates_data.append(new_datum)
|
||||||
|
else:
|
||||||
|
turns.append([turn_datum["transcript"] + turn_datum["system_transcript"]])
|
||||||
|
q_turns.append(query)
|
||||||
|
a_turns.append(answer)
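# Earlier rounds are only folded into the dialog history; a sample is emitted for the
# final round only (see the check above), so each dialog contributes exactly one
# instance in this variant.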
|
||||||
|
|
||||||
|
|
||||||
|
# Ignore if system_transcript is not found (last round teststd).
|
||||||
|
# if turn_datum.get("system_transcript", None):
|
||||||
|
# history.append(turn_datum["system_transcript"])
|
||||||
|
|
||||||
|
print(f"# instances [{split}]: {len(ambiguous_candidates_data)}")
|
||||||
|
save_path = os.path.join(
|
||||||
|
args["ambiguous_candidates_save_path"],
|
||||||
|
f"simmc2.1_ambiguous_candidates_dstc11_{split}.json"
|
||||||
|
)
|
||||||
|
print(f"Saving: {save_path}")
|
||||||
|
with open(save_path, "w") as file_id:
|
||||||
|
json.dump(
|
||||||
|
{
|
||||||
|
"source_path": read_path,
|
||||||
|
"split": split,
|
||||||
|
"data": ambiguous_candidates_data,
|
||||||
|
},
|
||||||
|
file_id
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description=__doc__)
|
||||||
|
parser.add_argument(
|
||||||
|
"--simmc_train_json", default=None, help="Path to SIMMC 2.1 train"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--simmc_dev_json", default=None, help="Path to SIMMC 2.1 dev"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--simmc_devtest_json", default=None, help="Path to SIMMC 2.1 devtest"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--simmc_teststd_json", default=None, help="Path to SIMMC 2.1 teststd (public)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--scene_json_folder", default=None, help="Path to SIMMC scene jsons"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ambiguous_candidates_save_path",
|
||||||
|
required=True,
|
||||||
|
help="Path to save SIMMC disambiguate JSONs",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--fashion_prefab_metadata", required=True,
|
||||||
|
help="Path to the file with all metadata for fashion objects"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--furniture_prefab_metadata", required=True,
|
||||||
|
help="Path to the file with all metadata for fashion objects"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
parsed_args = vars(parser.parse_args())
|
||||||
|
except (IOError) as msg:
|
||||||
|
parser.error(str(msg))
|
||||||
|
main(parsed_args)
|
15
src/utils/text_utils.py
Normal file
|
@ -0,0 +1,15 @@
import nltk

def normalize_sentence(sentence):
    return nltk.tokenize.word_tokenize(sentence.lower())


def translate_from_ids_to_text(ids, tokenizer):
    text = tokenizer.decode(ids)
    if '</s>' in text:
        text, pad = text.split('</s>', 1)
    if '<s>' in text:
        text = text[3:]

    #text_as_list = text.split(' ')
    return text
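# Usage sketch, assuming a HuggingFace-style tokenizer whose decode() keeps the
# '<s>'/'</s>' markers (e.g. BART); the names below are illustrative only:
#
#   from transformers import BartTokenizer
#   tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
#   ids = tokenizer("a red jacket")["input_ids"]
#   print(translate_from_ids_to_text(ids, tokenizer))   # -> "a red jacket"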
71
test.py
Normal file
|
@ -0,0 +1,71 @@
|
||||||
|
from src.models.discriminative_model import DiscriminativeModel
|
||||||
|
from src.models.generative_model import GenerativeModel
|
||||||
|
from src.data_modules.dvd_data import DVDData
|
||||||
|
from src.data_modules.simmc2_data import Simmc2Data
|
||||||
|
from src.data_modules.avsd_data import AvsdData
|
||||||
|
from pytorch_lightning import Trainer
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
from pytorch_lightning.loggers import WandbLogger
|
||||||
|
|
||||||
|
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
|
||||||
|
import wandb
|
||||||
|
from config.config import read_default_config, read_config, update_nested_dicts
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description='Test script for OLViT')
|
||||||
|
parser.add_argument(
|
||||||
|
'--ckpt_path',
|
||||||
|
type=str,
|
||||||
|
help='Path to the checkpoint to be tested')
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--cfg_path',
|
||||||
|
type=str,
|
||||||
|
help='Path to the config file of the selected checkpoint')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
wandb.finish()
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
chkpt_path = args.ckpt_path
|
||||||
|
|
||||||
|
# read the default config and update its values with the experiment-specific config
|
||||||
|
config = read_default_config()
|
||||||
|
experiment_config = read_config(args.cfg_path)
|
||||||
|
config = update_nested_dicts(old_dict=config, update_dict=experiment_config)
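# update_nested_dicts overrides the default values with those from the experiment config
# (presumably merging nested dicts key by key), so options missing from --cfg_path fall
# back to the defaults.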
|
||||||
|
|
||||||
|
if 'output_path' not in config['checkpoint'].keys():
|
||||||
|
raise Exception('no output path provided in config (full file path for the discriminative model, a folder path for the generative model)')
|
||||||
|
|
||||||
|
available_models = {
|
||||||
|
'discriminative': DiscriminativeModel,
|
||||||
|
'generative': GenerativeModel
|
||||||
|
}
|
||||||
|
data_modules = {
|
||||||
|
'dvd': DVDData,
|
||||||
|
'simmc2': Simmc2Data,
|
||||||
|
}
|
||||||
|
|
||||||
|
wandb_logger = WandbLogger(
|
||||||
|
entity=config['wandb']['entity'],
|
||||||
|
name=config['wandb']['name'],
|
||||||
|
group=config['wandb']['group'],
|
||||||
|
tags=config['wandb']['tags'],
|
||||||
|
project=config['wandb']['project'],
|
||||||
|
config=config
|
||||||
|
)
|
||||||
|
|
||||||
|
if config['training']['seed'] is not None:
|
||||||
|
pl.seed_everything(config['training']['seed'])
|
||||||
|
|
||||||
|
trainer = Trainer(
|
||||||
|
logger=wandb_logger,
|
||||||
|
accelerator='gpu',
|
||||||
|
devices=[0]
|
||||||
|
)
|
||||||
|
data = data_modules[config['model']['dataset']](config=config)
|
||||||
|
|
||||||
|
model = available_models[config['model']['model_type']](config=config, output_path=config['checkpoint']['output_path'])
|
||||||
|
trainer.test(model=model, ckpt_path=chkpt_path, dataloaders=data)
|
95
train.py
Normal file
|
@ -0,0 +1,95 @@
|
||||||
|
from src.models.discriminative_model import DiscriminativeModel
|
||||||
|
from src.models.generative_model import GenerativeModel
|
||||||
|
from src.data_modules.dvd_data import DVDData
|
||||||
|
from src.data_modules.simmc2_data import Simmc2Data
|
||||||
|
from pytorch_lightning import Trainer
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
from pytorch_lightning.loggers import WandbLogger
|
||||||
|
|
||||||
|
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
|
||||||
|
import wandb
|
||||||
|
from config.config import read_default_config, read_config, update_nested_dicts
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description='Train script for OLViT')
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
'--cfg_path',
|
||||||
|
default='config/dvd.json',
|
||||||
|
type=str,
|
||||||
|
help='Path to the config file of the selected checkpoint')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
wandb.finish()
|
||||||
|
args = parser.parse_args()
|
||||||
|
# read the default config and update its values with the experiment-specific config
|
||||||
|
config = read_default_config()
|
||||||
|
experiment_config = read_config(args.cfg_path)
|
||||||
|
config = update_nested_dicts(old_dict=config, update_dict=experiment_config)
|
||||||
|
|
||||||
|
available_models = {
|
||||||
|
'discriminative': DiscriminativeModel,
|
||||||
|
'generative': GenerativeModel
|
||||||
|
}
|
||||||
|
data_modules = {
|
||||||
|
'dvd': DVDData,
|
||||||
|
'simmc2': Simmc2Data,
|
||||||
|
}
|
||||||
|
|
||||||
|
monitor_score = {
|
||||||
|
'discriminative': 'val_acc',
|
||||||
|
'generative': 'bleu4'
|
||||||
|
}
|
||||||
|
|
||||||
|
checkpoint_cb = pl.callbacks.ModelCheckpoint(
|
||||||
|
monitor=monitor_score[config['model']['model_type']], mode="max",
|
||||||
|
save_top_k=1,
|
||||||
|
dirpath=config["checkpoint"]["checkpoint_folder"],
|
||||||
|
filename=config["checkpoint"]["checkpoint_file_name"],
|
||||||
|
every_n_epochs=1
|
||||||
|
)
|
||||||
|
|
||||||
|
lr_monitor_cb = LearningRateMonitor(
|
||||||
|
logging_interval='step'
|
||||||
|
)
|
||||||
|
|
||||||
|
callbacks = []
|
||||||
|
callbacks.append(checkpoint_cb)
|
||||||
|
callbacks.append(lr_monitor_cb)
|
||||||
|
|
||||||
|
wandb_logger = WandbLogger(
|
||||||
|
offline=True,
|
||||||
|
entity=config['wandb']['entity'],
|
||||||
|
name=config['wandb']['name'],
|
||||||
|
group=config['wandb']['group'],
|
||||||
|
tags=config['wandb']['tags'],
|
||||||
|
project=config['wandb']['project'],
|
||||||
|
config=config
|
||||||
|
)
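# offline=True keeps the W&B run on local disk; upload it later with `wandb sync` if
# online tracking is wanted.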
|
||||||
|
|
||||||
|
if config['training']['seed'] is not None:
|
||||||
|
pl.seed_everything(config['training']['seed'])
|
||||||
|
|
||||||
|
trainer = Trainer(
|
||||||
|
logger=wandb_logger,
|
||||||
|
# detect_anomaly=True,
|
||||||
|
accelerator='gpu',
|
||||||
|
devices=[0],
|
||||||
|
fast_dev_run=False,
|
||||||
|
max_epochs=config['training']['epochs'],
|
||||||
|
check_val_every_n_epoch=1,
|
||||||
|
log_every_n_steps=1,
|
||||||
|
strategy=pl.strategies.ddp.DDPStrategy(find_unused_parameters=False),
|
||||||
|
accumulate_grad_batches=config['training']['accumulate_grad_batches'],
|
||||||
|
precision=32,
|
||||||
|
callbacks=callbacks
|
||||||
|
)
|
||||||
|
data = data_modules[config['model']['dataset']](config=config)
|
||||||
|
|
||||||
|
if 'output_path' in config['checkpoint'].keys():
|
||||||
|
model = available_models[config['model']['model_type']](config=config, output_path=config['checkpoint']['output_path'])
|
||||||
|
else:
|
||||||
|
model = available_models[config['model']['model_type']](config=config)
|
||||||
|
|
||||||
|
trainer.fit(model, data)
|