Make code public
commit 8e03ef1c38
49 changed files with 545354 additions and 0 deletions
3  .gitattributes  vendored  Normal file
@@ -0,0 +1,3 @@
*.pkl filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.csv filter=lfs diff=lfs merge=lfs -text
21  LICENSE  Normal file
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 Anonymous

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
174  README.md  Normal file
@@ -0,0 +1,174 @@
<div align="center">
<h1> MST-MIXER <img src="misc/mixer.png" width="3%" align="bottom">: Multi-Modal Video Dialog State Tracking in the Wild </h1>

**[Adnen Abdessaied][16], [Lei Shi][17], [Andreas Bulling][18]** <br> <br>
**ECCV 2024, Milan, Italy <img src="misc/italy.png" width="3%" align="center">** <br>
**[[Paper][19]]**

---------------------------
<img src="misc/teaser.png" width="70%" align="middle"><br><br>

</div>

# Citation
If you find our code useful or use it in your own projects, please cite our paper:

```bibtex
@InProceedings{Abdessaied_2024_eccv,
  author    = {Abdessaied, Adnen and Shi, Lei and Bulling, Andreas},
  title     = {{Multi-Modal Video Dialog State Tracking in the Wild}},
  booktitle = {Proceedings of the European Conference on Computer Vision (ECCV)},
  year      = {2024}
}
```

# Table of Contents
* [Setup and Dependencies](#Setup-and-Dependencies)
* [Download Data](#Download-Data)
* [Training](#Training)
* [Response Generation](#Response-Generation)
* [Results](#Results)
* [Acknowledgements](#Acknowledgements)

# Setup and Dependencies
We implemented our model using Python 3.7 and PyTorch 1.12.0 (CUDA 11.3, CuDNN 8.3.2). We recommend setting up a virtual environment using Anaconda. <br>
1. Install [git lfs][1] on your system
2. Clone our repository to download a checkpoint of our best model and our code
```shell
git lfs install
git clone this_repo.git
```
3. Create a conda environment and install dependencies
```shell
conda create -n mst_mixer python=3.7
conda activate mst_mixer
conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.3 -c pytorch
conda install pyg -c pyg
conda install pytorch-scatter -c pyg  # pytorch >= 1.8.0
conda install pytorch-sparse -c pyg  # pytorch >= 1.8.0
conda install -c huggingface transformers
pip install evaluate wandb glog pyhocon attrs
```
# Download Data
## AVSD
1. Download the [AVSD-DSTC7][2], [AVSD-DSTC8][3] and [AVSD-DSTC10][10] data
2. Place the raw json files in ```raw_data/``` and the features in ```features/```
3. Preprocess and save the input features for faster training as indicated in ```custom_datasets/``` (a quick check of the expected layout is sketched below)
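
A minimal layout check (a sketch, not part of the released code): the json file names come from ```config/mst_mixer.conf```, and the feature sub-folders follow how ```custom_datasets/avsd.py``` resolves features; adjust the paths if yours differ.

```python
# Sketch: verify the expected AVSD data layout relative to the repository root.
# custom_datasets/avsd.py resolves features as features/<FeaType>/<ImageID>.npy
# (with a *_testset suffix for the test split).
import os

raw_files = [
    'raw_data/train_set4DSTC7-AVSD.json',
    'raw_data/valid_set4DSTC7-AVSD.json',
    'raw_data/test_set4DSTC7-AVSD.json',
    'raw_data/test_set4DSTC8-AVSD.json',
    'raw_data/test_set4DSTC10-AVSD.json',
]
feature_dirs = [
    'features/i3d_rgb', 'features/i3d_flow', 'features/vggish', 'features/sam',
    'features/i3d_rgb_testset', 'features/i3d_flow_testset',
    'features/vggish_testset', 'features/sam_testset',
]

for path in raw_files + feature_dirs:
    print('{:<35} {}'.format(path, 'ok' if os.path.exists(path) else 'MISSING'))
```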
## NExT-QA
1. For convenience, we include the features/data in this git repo.

# Training
We trained our model on 8 Nvidia Tesla V100-32GB GPUs. The default hyperparameters in ```config/mst_mixer.conf``` need to be adjusted if your setup differs from ours.
## AVSD
1. Set ```task=avsd``` in ```config/mst_mixer.conf```
2. ```shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
   --mode train \
   --tag mst_mixer_avsd \
   --wandb_mode online \
   --wandb_project mst_mixer_avsd
```
To deactivate [wandb][4] logging, use ```--wandb_mode disabled```.
On a setup similar to ours, training takes roughly 20h to complete.

## NExT-QA
1. Set ```task=nextqa``` in ```config/mst_mixer.conf```
2. ```shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
   --mode train \
   --tag mst_mixer_nextqa \
   --wandb_mode online \
   --wandb_project mst_mixer_nextqa
```

# Response Generation
## AVSD-DSTC7
1. Set ```dstc=7``` in the ```.conf``` file of your trained network. In the default setting, this file is located at ```logs/unique_training_tag/code/config/mst_mixer.conf```
2. Generate the responses
```shell
./generate_parallel_avsd.sh mst_mixer/mixer results_avsd_dstc7 generate logs/mst_mixer_avsd 7
```
3. All responses will be saved in ```output/dstc7/```
## AVSD-DSTC8
1. Set ```dstc=8``` in the ```.conf``` file of your trained network. In the default setting, this file is located at ```logs/unique_training_tag/code/config/mst_mixer.conf```
2. Generate the responses
```shell
./generate_parallel_avsd.sh mst_mixer/mixer results_avsd_dstc8 generate logs/mst_mixer_avsd 8
```
3. All responses will be saved in ```output/dstc8/```

## AVSD-DSTC10
1. Set ```dstc=10``` in the ```.conf``` file of your trained network. In the default setting, this file is located at ```logs/unique_training_tag/code/config/mst_mixer.conf```
2. Generate the responses
```shell
./generate_parallel_avsd.sh mst_mixer/mixer results_avsd_dstc10 generate logs/mst_mixer_avsd 10
```
3. All responses will be saved in ```output/dstc10/```

## NExT-QA
1. Generate the responses
```shell
./generate_parallel_nextqa.sh mst_mixer/mixer results_nextqa generate logs/mst_mixer_nextqa
```
2. All responses will be saved in ```output/nextqa/```
3. Evaluate using this [script][15]

# Results
To evaluate our best model on the different benchmarks:
## AVSD-DSTC7
Executing the [eval_tool][7] of AVSD-DSTC7 on the generated responses will output the following metrics

| Model | BLEU-1 | BLEU-2 | BLEU-3 | BLEU-4 | METEOR | ROUGE-L | CIDEr |
|:--------:|:------:|:------:|:------:|:------:|:------:|:-------:|:-----:|
| Prev. SOTA | 78.2 | 65.5 | 55.2 | 46.9 | 30.8 | 61.9 | 135.2 |
| MST_MIXER | **78.7** | **66.5** | **56.3** | **47.6** | **31.3** | **62.5** | **138.8** |

## AVSD-DSTC8
1. Set ```dstc=8``` in ```ckpt/code/mst_mixer.conf```
2. Run
```shell
./generate_parallel_avsd.sh mst_mixer/mixer results_avsd_dstc8_best_model generate ckpt/avsd 8
```
3. The responses will be saved in ```output/dstc8/```
4. Executing the [eval_tool][7] of AVSD-DSTC8 on the generated responses will output the following metrics

| Model | BLEU-1 | BLEU-2 | BLEU-3 | BLEU-4 | METEOR | ROUGE-L | CIDEr |
|:--------:|:------:|:------:|:------:|:------:|:------:|:-------:|:-----:|
| Prev. SOTA | 76.4 | 64.1 | 54.3 | 46.0 | 30.1 | 61.0 | 130.4 |
| MST_MIXER | **77.5** | **66.0** | **56.1** | **47.7** | **30.6** | **62.4** | **135.4** |

## AVSD-DSTC10
Executing the [eval_tool][11] of AVSD-DSTC10 on the generated responses will output the following metrics

| Model | BLEU-1 | BLEU-2 | BLEU-3 | BLEU-4 | METEOR | ROUGE-L | CIDEr |
|:--------:|:------:|:------:|:------:|:------:|:------:|:-------:|:-----:|
| Prev. SOTA | 69.3 | 55.6 | 45.0 | 37.2 | 24.9 | 53.6 | 91.2 |
| MST_MIXER | **70.0** | **57.4** | **47.6** | **40.0** | **25.7** | **54.5** | **99.8** |

## NExT-QA
Executing the [eval script][15] of NExT-QA on the generated responses will output the following metrics

| Model | WUPS_C | WUPS_T | WUPS_D | WUPS |
|:--------:|:------:|:------:|:------:|:------:|
| Prev. SOTA | 17.98 | 17.95 | 50.84 | 28.40 |
| MST_MIXER | **22.12** | **22.20** | **55.64** | **29.50** |

# Acknowledgements
We thank the authors of [RLM][8] for providing their [code][9], which greatly influenced this work.

[1]: https://git-lfs.com/
[2]: https://github.com/hudaAlamri/DSTC7-Audio-Visual-Scene-Aware-Dialog-AVSD-Challenge
[3]: https://github.com/dialogtekgeek/DSTC8-AVSD_official
[4]: https://wandb.ai/site
[5]: https://drive.google.com/drive/folders/1SlZTySJAk_2tiMG5F8ivxCfOl_OWwd_Q
[7]: https://drive.google.com/file/d/1EKfPtrNBQ5ciKRl6XggImweGRP84XuPi/view?usp=sharing
[8]: https://arxiv.org/abs/2002.00163
[9]: https://github.com/ictnlp/DSTC8-AVSD
[10]: https://drive.google.com/file/d/1zvC6FuPRVRiLQCXZcYpzYUI9r1tiWls6/view
[11]: https://github.com/ankitshah009/AVSD-DSTC10_baseline
[15]: https://github.com/doc-doc/NExT-OE/blob/main/eval_oe.py
[16]: https://adnenabdessaied.de/
[17]: https://perceptualui.org/people/shi/
[18]: https://perceptualui.org/people/bulling/
[19]: https://arxiv.org/abs/2407.02218
118  config/avsd_bart_base.json  Normal file
@@ -0,0 +1,118 @@
{
  "_name_or_path": "bart-base",
  "activation_dropout": 0.1,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "BartModel"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 0,
  "classif_dropout": 0.1,
  "classifier_dropout": 0.0,
  "d_model": 768,
  "decoder_attention_heads": 12,
  "decoder_ffn_dim": 3072,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 6,
  "decoder_start_token_id": 2,
  "dropout": 0.1,
  "early_stopping": true,
  "encoder_attention_heads": 12,
  "encoder_ffn_dim": 3072,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 6,
  "eos_token_id": 2,
  "forced_eos_token_id": 2,
  "forced_bos_token_id": 0,
  "gradient_checkpointing": false,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2"
  },
  "init_std": 0.02,
  "is_encoder_decoder": true,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2
  },
  "max_position_embeddings": 1024,
  "model_type": "bart",
  "no_repeat_ngram_size": 3,
  "normalize_before": false,
  "normalize_embedding": true,
  "num_beams": 4,
  "num_hidden_layers": 6,
  "pad_token_id": 1,
  "scale_embedding": false,
  "task_specific_params": {
    "summarization": {
      "length_penalty": 1.0,
      "max_length": 128,
      "min_length": 12,
      "num_beams": 4
    },
    "summarization_cnn": {
      "length_penalty": 2.0,
      "max_length": 142,
      "min_length": 56,
      "num_beams": 4
    },
    "summarization_xsum": {
      "length_penalty": 1.0,
      "max_length": 62,
      "min_length": 11,
      "num_beams": 6
    }
  },
  "torch_dtype": "float32",
  "transformers_version": "4.12.0.dev0",
  "use_cache": true,
  "vocab_size": 50265,

  "d_i3d_flow": 2048,
  "d_i3d_rgb": 2048,
  "d_sam": 512,
  "d_audio": 128,
  "top_k": 10,
  "num_nn": 4,
  "gnn_type": "appnp",
  "use_random_graphs": false,
  "integrate_all_gnn_features": true,
  "use_elbo_local": true,
  "use_elbo_global": true,
  "use_non_linear": true,
  "gnn_K": 2,
  "gnn_alpha": 0.1,
  "num_modalities": 6,
  "local_gnn_d_hidden": 768,
  "global_gnn_d_hidden": 768,
  "num_local_gnn_heads": 2,
  "num_global_gnn_heads": 4,
  "local_gnn_dropout": 0.1,
  "global_gnn_dropout": 0.1,
  "local_fc_dropout": 0.1,
  "global_fc_dropout": 0.1,
  "num_local_gnn_layers": 2,
  "num_global_gnn_layers": 2,
  "num_local_fc_layers": 2,
  "num_global_fc_layers": 2,
  "use_local_gnn_bn": true,
  "use_global_gnn_bn": true,
  "use_local_fc_bn": true,
  "use_global_fc_bn": true,
  "local_gnn_concat": true,
  "global_gnn_concat": true,
  "num_local_gr_learner_heads": 8,
  "num_global_gr_learner_heads": 8,
  "init_adj_ratio": 0.5,
  "adj_ratio": 0.5,
  "alpha": 0.9,
  "gnns_every": 2,
  "num_layers_state_fc_decoder": 2,
  "dropout_state_fc_decoder": 0.3
}
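
The standard BART fields in this file follow the usual Hugging Face layout, while the remaining keys (```d_i3d_flow```, ```d_sam```, ```num_modalities```, the ```gnn_*```/```local_*```/```global_*``` options, ...) configure the MST-MIXER graph components. A minimal sketch of reading the file, assuming it is loaded through ```transformers```' ```BartConfig``` (the repository's own model-loading code is not part of this excerpt):

```python
# Illustrative sketch only: standard BART fields map onto BartConfig, and unknown
# keys such as d_i3d_flow or num_modalities are kept as extra attributes on the
# resulting config object.
from transformers import BartConfig

config = BartConfig.from_json_file('config/avsd_bart_base.json')
print(config.d_model)         # 768, a regular BART field
print(config.d_i3d_flow)      # 2048, custom MST-MIXER key stored as an extra attribute
print(config.num_modalities)  # 6, matching the six input types handled in custom_datasets/avsd.py
                              # (I3D rgb, I3D flow, SAM objects, audio, history, question)
```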
116  config/avsd_bart_large.json  Normal file
@@ -0,0 +1,116 @@
{
  "activation_dropout": 0.1,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "BartModel"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 0,
  "classif_dropout": 0.1,
  "classifier_dropout": 0.0,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 12,
  "decoder_start_token_id": 2,
  "dropout": 0.1,
  "early_stopping": true,
  "encoder_attention_heads": 16,
  "encoder_ffn_dim": 4096,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 12,
  "eos_token_id": 2,
  "forced_eos_token_id": 2,
  "forced_bos_token_id": 0,
  "gradient_checkpointing": false,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2"
  },
  "init_std": 0.02,
  "is_encoder_decoder": true,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2
  },
  "max_position_embeddings": 1024,
  "model_type": "bart",
  "no_repeat_ngram_size": 3,
  "normalize_before": false,
  "num_beams": 4,
  "num_hidden_layers": 12,
  "pad_token_id": 1,
  "scale_embedding": false,
  "task_specific_params": {
    "summarization": {
      "length_penalty": 1.0,
      "max_length": 128,
      "min_length": 12,
      "num_beams": 4
    },
    "summarization_cnn": {
      "length_penalty": 2.0,
      "max_length": 142,
      "min_length": 56,
      "num_beams": 4
    },
    "summarization_xsum": {
      "length_penalty": 1.0,
      "max_length": 62,
      "min_length": 11,
      "num_beams": 6
    }
  },
  "transformers_version": "4.7.0.dev0",
  "use_cache": true,
  "vocab_size": 50265,

  "d_i3d_flow": 2048,
  "d_i3d_rgb": 2048,
  "d_sam": 512,
  "d_audio": 128,
  "top_k": 10,
  "num_nn": 4,
  "gnn_type": "appnp",
  "use_random_graphs": false,
  "integrate_all_gnn_features": true,
  "use_elbo_local": true,
  "use_elbo_global": true,
  "use_non_linear": true,
  "gnn_K": 2,
  "gnn_alpha": 0.1,
  "num_modalities": 6,
  "local_gnn_d_hidden": 1024,
  "global_gnn_d_hidden": 1024,
  "num_local_gnn_heads": 2,
  "num_global_gnn_heads": 4,
  "local_gnn_dropout": 0.1,
  "global_gnn_dropout": 0.1,
  "local_fc_dropout": 0.1,
  "global_fc_dropout": 0.1,
  "num_local_gnn_layers": 1,
  "num_global_gnn_layers": 1,
  "num_local_fc_layers": 1,
  "num_global_fc_layers": 1,
  "use_local_gnn_bn": true,
  "use_global_gnn_bn": true,
  "use_local_fc_bn": true,
  "use_global_fc_bn": true,
  "local_gnn_concat": true,
  "global_gnn_concat": true,
  "num_local_gr_learner_heads": 8,
  "num_global_gr_learner_heads": 8,
  "init_adj_ratio": 0.5,
  "adj_ratio": 0.5,
  "alpha": 0.9,
  "gnns_every": 4,
  "num_layers_state_fc_decoder": 2,
  "dropout_state_fc_decoder": 0.3
}
94  config/mst_mixer.conf  Normal file
@@ -0,0 +1,94 @@
mixer {
    task = avsd
    #################################################################################
    # datasets
    # avsd

    avsd_processed = features/
    avsd_train = raw_data/train_set4DSTC7-AVSD.json
    avsd_val = raw_data/valid_set4DSTC7-AVSD.json
    avsd_test_dstc7 = raw_data/test_set4DSTC7-AVSD.json
    avsd_test_dstc8 = raw_data/test_set4DSTC8-AVSD.json
    avsd_test_dstc10 = raw_data/test_set4DSTC10-AVSD.json
    avsd_feature_path = features/
    avsd_i3d_rgb = features/i3d_rgb
    avsd_i3d_rgb_test = features/i3d_rgb_testset
    avsd_i3d_flow = features/i3d_flow_all
    avsd_i3d_flow_test = features/i3d_flow_testset
    avsd_audio = features/vggish_all
    avsd_audio_test = features/vggish_testset
    avsd_objects = features/sam
    avsd_objects_test = features/sam_testset

    dstc = 7

    # NextQA
    nextqa_root = processed/next_qa/annotations
    nextqa_vid_feat = processed/next_qa/vid_feat
    #################################################################################
    # Model
    bart_size = large  # base, large
    avsd_bart_base_config = config/avsd_bart_base.json
    avsd_bart_large_config = config/avsd_bart_large.json
    nextqa_bart_large_config = config/nextqa_bart_large.json

    #################################################################################
    # Logging & Checkpointing
    log_dir = logs
    output_dir_dstc7 = output/dstc7
    output_dir_dstc8 = output/dstc8
    output_dir_dstc10 = output/dstc10
    output_dir_nextqa = output/nextqa
    max_ckpt_to_keep = 5
    start_ckpt_for_generating = none
    loads_start_path = false
    next_logging_pct = 0.1
    save_ckpt = true
    skip_saving_ckpt = false
    stop_epochs = -1
    resets_min_val_loss = false
    restarts = false
    uses_new_optimizer = true
    sets_new_lr = false
    ################################################################################
    # Data processing
    expand_rnd = false
    cap_sum = cap_sum
    add_state_tokens = true
    bart_max_input_len = 1024
    num_workers = 0
    n_history = 3
    caption_drop_rate = 0.0
    vis_feat_length = 36
    #################################################################################
    # Training
    dp_type = ddp
    batch_size = 16
    num_epochs = 12
    warmup_ratio = 0.1
    batch_multiply = 1
    skip_eval = false
    stop_epoch = -1
    random_seed = 54
    learning_rate_bart = 1e-5
    learning_rate_other = 1e-4
    min_lr = 0
    clip_grad_value = 1.0
    print_output = false
    eval_first = false
    overfit_size = -1
    elbo_global_coeff = 100
    elbo_local_coeff = 100
    gen_coeff = 1
    #################################################################################
    # Generation
    gen_batch_size = 1
    beam_depth = 5
    max_generation_length = 20
    min_generation_length = 1
    length_penalty = 0.3
    #################################################################################
    # Misc.
    master_port = 5101
    use_cpu = false
}
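
The dataset script ```custom_datasets/avsd.py``` reads this file with ```pyhocon``` and selects the ```mixer``` block by name; a minimal sketch of inspecting (or overriding) values before training:

```python
# Minimal sketch mirroring custom_datasets/avsd.py: parse the HOCON file and
# select the `mixer` block, then read a few hyperparameters.
import pyhocon

config = pyhocon.ConfigFactory.parse_file('config/mst_mixer.conf')['mixer']
print(config['task'])                                   # 'avsd' or 'nextqa'
print(config['batch_size'])                             # 16 by default; adjust to your GPU memory
print(config['learning_rate_bart'], config['learning_rate_other'])
```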
116  config/nextqa_bart_large.json  Normal file
@@ -0,0 +1,116 @@
{
  "activation_dropout": 0.1,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "BartModel"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 0,
  "classif_dropout": 0.1,
  "classifier_dropout": 0.0,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 12,
  "decoder_start_token_id": 2,
  "dropout": 0.1,
  "early_stopping": true,
  "encoder_attention_heads": 16,
  "encoder_ffn_dim": 4096,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 12,
  "eos_token_id": 2,
  "forced_eos_token_id": 2,
  "forced_bos_token_id": 0,
  "gradient_checkpointing": false,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2"
  },
  "init_std": 0.02,
  "is_encoder_decoder": true,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2
  },
  "max_position_embeddings": 1024,
  "model_type": "bart",
  "no_repeat_ngram_size": 3,
  "normalize_before": false,
  "num_beams": 4,
  "num_hidden_layers": 12,
  "pad_token_id": 1,
  "scale_embedding": false,
  "task_specific_params": {
    "summarization": {
      "length_penalty": 1.0,
      "max_length": 128,
      "min_length": 12,
      "num_beams": 4
    },
    "summarization_cnn": {
      "length_penalty": 2.0,
      "max_length": 142,
      "min_length": 56,
      "num_beams": 4
    },
    "summarization_xsum": {
      "length_penalty": 1.0,
      "max_length": 62,
      "min_length": 11,
      "num_beams": 6
    }
  },
  "transformers_version": "4.7.0.dev0",
  "use_cache": true,
  "vocab_size": 50265,

  "d_i3d_flow": 2048,
  "d_i3d_rgb": 2048,
  "d_sam": 512,
  "d_audio": 128,
  "top_k": 10,
  "num_nn": 4,
  "gnn_type": "appnp",
  "use_random_graphs": false,
  "integrate_all_gnn_features": true,
  "use_elbo_local": true,
  "use_elbo_global": true,
  "use_non_linear": true,
  "gnn_K": 2,
  "gnn_alpha": 0.1,
  "num_modalities": 3,
  "local_gnn_d_hidden": 1024,
  "global_gnn_d_hidden": 1024,
  "num_local_gnn_heads": 2,
  "num_global_gnn_heads": 4,
  "local_gnn_dropout": 0.1,
  "global_gnn_dropout": 0.1,
  "local_fc_dropout": 0.1,
  "global_fc_dropout": 0.1,
  "num_local_gnn_layers": 1,
  "num_global_gnn_layers": 1,
  "num_local_fc_layers": 1,
  "num_global_fc_layers": 1,
  "use_local_gnn_bn": true,
  "use_global_gnn_bn": true,
  "use_local_fc_bn": true,
  "use_global_fc_bn": true,
  "local_gnn_concat": true,
  "global_gnn_concat": true,
  "num_local_gr_learner_heads": 8,
  "num_global_gr_learner_heads": 8,
  "init_adj_ratio": 0.5,
  "adj_ratio": 0.5,
  "alpha": 0.9,
  "gnns_every": 4,
  "num_layers_state_fc_decoder": 2,
  "dropout_state_fc_decoder": 0.3
}
20  custom_datasets/README.md  Normal file
@@ -0,0 +1,20 @@
1. Download the raw [Charades train/val](https://prior.allenai.org/projects/charades) data
2. Download the raw [Charades test](https://ai2-public-datasets.s3-us-west-2.amazonaws.com/charades/Charades_vu17_test_480.tar) data
3. Install [SAM](https://github.com/facebookresearch/segment-anything.git)
4. Segment the frames
```shell
python segment.py --sam_ckpt path_to_sam_ckpt --avsd_root path_to_charades_trval_frames --crop_root path_to_save_the_trval_crops --mode segment --start start_idx --end end_idx
python segment.py --sam_ckpt path_to_sam_ckpt --avsd_root path_to_charades_test_frames --crop_root path_to_save_the_test_crops --mode segment --start start_idx --end end_idx
```
5. Embed the crops (a quick shape check of the saved embeddings is sketched after this list)
```shell
python segment.py --sam_ckpt path_to_sam_ckpt --crop_root path_to_save_the_trval_crops --mode embed --embed_root ../features/sam --start start_idx --end end_idx
python segment.py --sam_ckpt path_to_sam_ckpt --crop_root path_to_save_the_test_crops --mode embed --embed_root ../features/sam_testset --start start_idx --end end_idx
```
6. Preprocess and log the data
```shell
python dataset.py --split train
python dataset.py --split val
```
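
A quick shape check of the object embeddings produced in step 5 (a sketch; the video id below is hypothetical, use any id that exists in your output folder):

```python
# Sketch: inspect one saved object-embedding file from segment.py's embed mode.
import numpy as np

emb = np.load('../features/sam/SOME_VIDEO_ID.npy')  # hypothetical id
# embed_objects() pools the SAM image embedding of each crop and of its horizontal
# flip and concatenates them, so each row should be 512-dimensional, matching
# d_sam = 512 in the model configs.
print(emb.shape)  # (num_detected_objects <= 36, 512)
```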
0  custom_datasets/__init__.py  Normal file
401  custom_datasets/avsd.py  Normal file
@@ -0,0 +1,401 @@
import os
import pickle
import pyhocon
from copy import deepcopy
import json
from tqdm import tqdm
import numpy as np
import torch
from argparse import ArgumentParser
from torch.utils.data import Dataset, DataLoader
from transformers import BartTokenizer
from itertools import chain


ADDITIONAL_SPECIAL_TOKENS = [
    '<place_holder>', '<s0>', '<s1>', '<s2>', '<s3>', '<s4>', '<s5>']

SPECIAL_TOKENS_DICT = {
    'bos_token': '<s>',
    'eos_token': '</s>',
    'pad_token': '<pad>',
    'additional_special_tokens': ADDITIONAL_SPECIAL_TOKENS
}

S0_TOK = '<s0>'  # I3D_flow
S1_TOK = '<s1>'  # I3D_rgb
S2_TOK = '<s2>'  # sam obj
S3_TOK = '<s3>'  # audio
S4_TOK = '<s4>'  # history
S5_TOK = '<s5>'  # question


def tokenize(obj, tokenizer):
    if isinstance(obj, str):
        return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
    if isinstance(obj, dict):
        return dict((n, tokenize(o, tokenizer)) for n, o in obj.items())
    return list(tokenize(o, tokenizer) for o in obj)


class AVSDDataset(Dataset):
    def __init__(self, config, split):

        super().__init__()
        self.config = config
        self.split = split
        self.bart_max_input_len = config['bart_max_input_len']
        self.bart_size = config['bart_size']
        self.cap_sum = config['cap_sum']
        self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-{}'.format(self.bart_size))
        self.vocab_size = self.tokenizer.vocab_size

        self.tokenizer.add_special_tokens({'additional_special_tokens': ADDITIONAL_SPECIAL_TOKENS})
        self.vocab_size += len(ADDITIONAL_SPECIAL_TOKENS)
        self.tokenizer.save_pretrained(os.path.join(self.config['log_dir'], 'bart_tokenizer'))
        self.processed_dir = os.path.join(self.config['avsd_processed'], 'hist_with_{}_rounds'.format(self.config['n_history']), split)
        self.paths = list(map(lambda p: os.path.join(self.processed_dir, p), os.listdir(self.processed_dir)))

        if self.config['overfit'] > 0:
            self.paths = self.paths[:self.config['overfit_size']]

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        pth = self.paths[index]
        with open(pth, 'rb') as f:
            item = pickle.load(f)

        question_sep = self.tokenizer.convert_tokens_to_ids('<s5>')

        input_ids = item['input_ids']
        history_end = (input_ids == question_sep).nonzero(as_tuple=True)[0]

        history_interval = [0, history_end.item()]  # The last token is the question state token (not part of the history)
        question_interval = [history_end.item(), input_ids.size(0)]

        lm_labels = item['lm_labels']
        i3d_rgb = item['i3d_rgb']
        i3d_flow = item['i3d_flow']
        sam = item['sam']
        vgg = item['vgg']
        vid = item['vid']

        return input_ids, lm_labels, history_interval, question_interval, i3d_rgb, i3d_flow, sam, vgg, vid

    def padding(self, seq, pad_token, max_len=None):
        if max_len is None:
            max_len = max([i.size(0) for i in seq])
        if len(seq[0].size()) == 1:
            result = torch.ones((len(seq), max_len)).long() * pad_token
        else:
            result = torch.ones((len(seq), max_len, seq[0].size(-1))).float()
        for i in range(len(seq)):
            result[i, :seq[i].size(0)] = seq[i]
        return result

    def collate_fn(self, batch):
        input_ids_list, lm_labels_list, history_interval_list, question_interval_list, i3d_rgb_list, i3d_flow_list, sam_list, vggish_list, vid_ids_list = [], [], [], [], [], [], [], [], []
        for i in batch:
            input_ids_list.append(i[0])
            lm_labels_list.append(i[1])
            history_interval_list.append(i[2])
            question_interval_list.append(i[3])
            i3d_rgb_list.append(i[4])
            i3d_flow_list.append(i[5])
            sam_list.append(i[6])
            vggish_list.append(i[7])
            vid_ids_list.append(i[8])

        history_intervals = np.array(history_interval_list)
        question_intervals = np.array(question_interval_list)

        min_len_i3d_flow = min([feat.shape[0] for feat in i3d_flow_list])
        min_len_i3d_rgb = min([feat.shape[0] for feat in i3d_rgb_list])
        min_len_sam = min([feat.shape[0] for feat in sam_list])
        min_len_vggish = min([feat.shape[0] for feat in vggish_list])

        min_length = min([self.config['vis_feat_length'], min_len_i3d_flow, min_len_i3d_rgb, min_len_sam, min_len_vggish])

        # Sample equally-distant features from the visual features for each sample within the batch
        for i in range(len(i3d_rgb_list)):
            sample_idx_i3d_rgb = np.round(np.linspace(0, i3d_rgb_list[i].shape[0] - 1, min_length)).astype(int)
            i3d_rgb_list[i] = i3d_rgb_list[i][sample_idx_i3d_rgb, :]
        i3d_rgb = torch.from_numpy(np.array(i3d_rgb_list)).float()

        for i in range(len(i3d_flow_list)):
            sample_idx_i3d_flow = np.round(np.linspace(0, i3d_flow_list[i].shape[0] - 1, min_length)).astype(int)
            i3d_flow_list[i] = i3d_flow_list[i][sample_idx_i3d_flow, :]
        i3d_flow = torch.from_numpy(np.array(i3d_flow_list)).float()

        for i in range(len(sam_list)):
            sample_idx_sam = np.round(np.linspace(0, sam_list[i].shape[0] - 1, min_length)).astype(int)
            sam_list[i] = sam_list[i][sample_idx_sam, :]
        sam = torch.from_numpy(np.array(sam_list)).float()

        for i in range(len(vggish_list)):
            sample_idx_vggish = np.round(np.linspace(0, vggish_list[i].shape[0] - 1, min_length)).astype(int)
            vggish_list[i] = vggish_list[i][sample_idx_vggish, :]
        vggish = torch.from_numpy(np.array(vggish_list)).float()

        pad_token, i3d_flow_sep, i3d_rgb_sep, sam_sep, audio_sep, ph_token = self.tokenizer.convert_tokens_to_ids(
            ['<pad>', '<s0>', '<s1>', '<s2>', '<s3>', '<place_holder>'])

        # All the visual features will not be masked because we do not perform any padding on them
        video_mask = torch.ones((len(batch), min_length * 4 + 4)) == 1  # NOTE *4: 4 modalities | +4: the state tokens
        # Now we create a dummy input for the video tokens (sole purpose is to reserve the spot of the separators)
        dummy = torch.ones((len(batch), min_length)) * ph_token
        video_place_holder_ids = torch.cat(
            [torch.ones((len(batch), 1)) * i3d_rgb_sep, dummy,
             torch.ones((len(batch), 1)) * i3d_flow_sep, dummy,
             torch.ones((len(batch), 1)) * sam_sep, dummy,
             torch.ones((len(batch), 1)) * audio_sep, dummy,
             ], dim=-1).long()

        input_ids = self.padding(input_ids_list, pad_token)
        lm_labels = self.padding(lm_labels_list, -100)
        text_mask = input_ids != pad_token
        input_mask = torch.cat([video_mask, text_mask], dim=1)

        # Now we get the intervals of the visual input tokens
        # Here the intervals do not change across the batch dimension
        i3d_rgb_interval = [0, min_length + 1]  # the last token is not part of this modality
        i3d_flow_interval = [min_length + 1, 2 * min_length + 2]
        sam_interval = [2 * min_length + 2, 3 * min_length + 3]
        audio_interval = [3 * min_length + 3, 4 * min_length + 4]

        vis_state_vector_idx = [i3d_rgb_interval[0], i3d_flow_interval[0], sam_interval[0], audio_interval[0]]

        # adapt the question and history intervals -- shifted to the right by the visual input length
        history_intervals += 4 * min_length + 4
        question_intervals += 4 * min_length + 4
        history_intervals = history_intervals.tolist()
        question_intervals = question_intervals.tolist()

        history_state_vector_idx = [x[0] + 1 for x in history_intervals]  # +1 because the history starts with <s><s4> ...
        question_state_vector_idx = [x[0] for x in question_intervals]  # no shift: the question interval already starts at its state token <s5>

        batch = {
            'input_ids': input_ids,
            'video_place_holder_ids': video_place_holder_ids,
            'i3d_rgb': i3d_rgb,
            'i3d_flow': i3d_flow,
            'sam': sam,
            'vggish': vggish,
            'lm_labels': lm_labels,
            'input_mask': input_mask,
            'i3d_rgb_interval': i3d_rgb_interval,
            'i3d_flow_interval': i3d_flow_interval,
            'sam_interval': sam_interval,
            'audio_interval': audio_interval,
            'history_intervals': history_intervals,
            'question_intervals': question_intervals,
            'vis_state_vector_idx': vis_state_vector_idx,
            'history_state_vector_idx': history_state_vector_idx,
            'question_state_vector_idx': question_state_vector_idx
        }
        return batch


def get_dataset(config, split, tokenizer):
    if split != 'test':
        dialog_pth = config[f'avsd_{split}']
    else:
        dialog_pth = config['avsd_test_dstc{}'.format(config['dstc'])]
    n_history = config['n_history']
    dialog_data = json.load(open(dialog_pth, 'r'))
    dialog_list = []
    vid_set = set()
    undisclosed_only = split == 'test'
    pbar = tqdm(dialog_data['dialogs'])

    pbar.set_description('[INFO] Generating {} items | DSTC {}'.format(split, config['dstc']))
    for dialog in pbar:
        if config['dstc'] != 10:
            caption = [tokenize(dialog['caption'], tokenizer)] + [tokenize(dialog['summary'], tokenizer)]
        else:
            caption = [tokenize('no', tokenizer)]

        questions = [tokenize(d['question'], tokenizer) for d in dialog['dialog']]
        answers = [tokenize(d['answer'], tokenizer) for d in dialog['dialog']]
        vid = dialog["image_id"]
        vid_set.add(vid)
        if undisclosed_only:
            it = range(len(questions) - 1, len(questions))
        else:
            it = range(len(questions))
        qalist = []
        history = []
        if undisclosed_only:
            for n in range(len(questions) - 1):
                qalist.append(questions[n])
                qalist.append(answers[n])
            history = qalist[max(-len(qalist), -n_history * 2):]
        for n in it:
            if undisclosed_only:
                assert dialog['dialog'][n]['answer'] == '__UNDISCLOSED__'
            question = questions[n]
            answer = answers[n]
            history.append(question)
            if n_history == 0:
                item = {'vid': vid, 'history': [question], 'answer': answer, 'caption': caption}
            else:
                item = {'vid': vid, 'history': history, 'answer': answer, 'caption': caption}
            dialog_list.append(item)
            qalist.append(question)
            qalist.append(answer)
            history = qalist[max(-len(qalist), -n_history * 2):]

    all_features = {}
    fea_types = ['vggish', 'i3d_flow', 'i3d_rgb', 'sam']

    dataname = '<FeaType>/<ImageID>.npy'
    for ftype in fea_types:
        if undisclosed_only:
            basename = dataname.replace('<FeaType>', ftype + '_testset')
        else:
            basename = dataname.replace('<FeaType>', ftype)
        features = {}
        for vid in vid_set:
            filename = basename.replace('<ImageID>', vid)
            filepath = config['avsd_feature_path'] + filename
            features[vid] = filepath
        all_features[ftype] = features
    return dialog_list, all_features


def build_input_from_segments(caption, history_orig, reply, tokenizer, add_state_tokens=True, drop_caption=False):
    """ Build a sequence of input from 3 segments: caption (caption + summary), history and last reply """

    bos, eos, hist_state, ques_state = tokenizer.convert_tokens_to_ids(['<s>', '</s>', '<s4>', '<s5>'])
    sep = eos

    instance = {}
    instance["lm_labels"] = reply + [eos]
    caption = list(chain(*caption))

    # Add state tokens if applicable
    if add_state_tokens:
        caption.insert(0, hist_state)
        history = deepcopy(history_orig)
        history[-1].insert(0, ques_state)
    else:
        history = history_orig

    if not drop_caption:
        # sequence = [[bos] + list(chain(*caption))] + history + [reply + ([eos] if with_eos else [])]

        # NOTE It is important not to include the reply in the input of the encoder --> the decoder will just
        # learn to copy it --> low train/val loss but no learning is happening
        sequence = [[bos] + caption + [eos]] + [[sep] + s for s in history] + [[eos]]
    else:
        sequence = [[bos]] + [[hist_state]] + [[sep] + s for s in history] + [[eos]]

    instance["input_ids"] = list(chain(*sequence))
    return instance


def parse_args():
    parser = ArgumentParser(description='debug dataloader')
    parser.add_argument(
        '--split',
        type=str,
        default='train',
        help='train or val')

    parser.add_argument(
        '--model',
        type=str,
        default='mixer',
        help='model name to train or test')

    parser.add_argument(
        '--log_dataset',
        action='store_true',
        default=False,
        help='Whether or not to log the processed data')

    parser.add_argument(
        '--add_state_tokens',
        action='store_true',
        default=True,
        help='Whether or not to add state tokens')

    parser.add_argument(
        '--log_dir',
        type=str,
        default='processed/avsd',
        help='Output directory')

    args = parser.parse_args()

    return args


if __name__ == '__main__':
    args = parse_args()
    split = args.split

    config = pyhocon.ConfigFactory.parse_file(
        'config/mst_mixer.conf')[args.model]
    config['expand_rnd'] = False
    config['debugging'] = False
    config['overfit'] = False
    args.log_dir = os.path.join(args.log_dir, 'hist_with_{}_rounds'.format(config['n_history']))
    if args.log_dataset:
        log_dir = os.path.join(args.log_dir, split)
        if not os.path.isdir(log_dir):
            os.makedirs(log_dir)

        tokenizer = BartTokenizer.from_pretrained('facebook/bart-{}'.format(config['bart_size']))
        tokenizer.add_special_tokens({'additional_special_tokens': ADDITIONAL_SPECIAL_TOKENS})
        dialogs, features = get_dataset(config, split, tokenizer)
        pbar = tqdm(dialogs)
        pbar.set_description('[{}] Logging processed data'.format(split))
        counter = 0
        for dialog in pbar:
            vid = dialog['vid']
            his = dialog['history']
            cap = dialog['caption']
            ans = dialog['answer']

            if np.random.rand() < config['caption_drop_rate']:
                instance = build_input_from_segments(
                    cap, his, ans, tokenizer, add_state_tokens=args.add_state_tokens, drop_caption=True)
            else:
                instance = build_input_from_segments(
                    cap, his, ans, tokenizer, add_state_tokens=args.add_state_tokens, drop_caption=False)

            input_ids = torch.Tensor(instance["input_ids"]).long()
            lm_labels = torch.Tensor(instance["lm_labels"]).long()

            vgg = np.load(features["vggish"][vid])
            i3d_flow = np.load(features["i3d_flow"][vid])
            i3d_rgb = np.load(features["i3d_rgb"][vid])
            sam = np.load(features["sam"][vid])

            item = {
                'input_ids': input_ids,
                'lm_labels': lm_labels,
                'i3d_rgb': i3d_rgb,
                'i3d_flow': i3d_flow,
                'sam': sam,
                'vgg': vgg,
                'vid': vid
            }
            counter += 1
            pth = os.path.join(log_dir, str(counter) + '.pkl')
            with open(pth, 'wb') as f:
                pickle.dump(item, f, protocol=pickle.HIGHEST_PROTOCOL)
    else:
        avsd_dataset = AVSDDataset(config, 'val')
        avsd_dataloader = DataLoader(avsd_dataset, batch_size=4, shuffle=False, collate_fn=avsd_dataset.collate_fn)

        for i, data in enumerate(avsd_dataloader):
            print('{}/{}'.format(i, len(avsd_dataloader)))
        # print(avsd_dataset.max_len)  # NOTE: AVSDDataset does not define max_len

    print('[INFO] Done...')
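
A small sketch that makes the text layout built by ```build_input_from_segments``` concrete (it assumes ```custom_datasets/avsd.py``` is importable from the repository root and that the BART tokenizer can be downloaded; the dialog content is made up):

```python
# Sketch: inspect the encoder input produced by build_input_from_segments above.
from transformers import BartTokenizer

from custom_datasets.avsd import (ADDITIONAL_SPECIAL_TOKENS,
                                  build_input_from_segments, tokenize)

tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
tokenizer.add_special_tokens({'additional_special_tokens': ADDITIONAL_SPECIAL_TOKENS})

caption = [tokenize('a man opens a door', tokenizer), tokenize('he then walks away', tokenizer)]
history = [tokenize('what is the man doing ?', tokenizer),
           tokenize('he is opening a door', tokenizer),
           tokenize('does he say anything ?', tokenizer)]  # last entry = current question
answer = tokenize('no he does not', tokenizer)

instance = build_input_from_segments(caption, history, answer, tokenizer)
print(tokenizer.decode(instance['input_ids']))
# Roughly: <s><s4> caption summary </s></s> q1 </s> a1 </s><s5> current question </s>
print(tokenizer.decode(instance['lm_labels']))  # the answer followed by </s>
```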
211  custom_datasets/nextqa.py  Normal file
@@ -0,0 +1,211 @@
import os
import pandas as pd
import h5py
import json
import numpy as np
import torch
from torch.utils.data import Dataset
from transformers import BartTokenizer
from itertools import chain


ADDITIONAL_SPECIAL_TOKENS = [
    '<place_holder>', '<s0>', '<s1>', '<s2>', '<s3>', '<s4>', '<s5>']

SPECIAL_TOKENS_DICT = {
    'bos_token': '<s>',
    'eos_token': '</s>',
    'pad_token': '<pad>',
    'additional_special_tokens': ADDITIONAL_SPECIAL_TOKENS
}

S0_TOK = '<s0>'  # frame
S1_TOK = '<s1>'  # mot
S2_TOK = '<s2>'  # question


def load_file(file_name):
    annos = None
    if os.path.splitext(file_name)[-1] == '.csv':
        return pd.read_csv(file_name)
    with open(file_name, 'r') as fp:
        if os.path.splitext(file_name)[1] == '.txt':
            annos = fp.readlines()
            annos = [line.rstrip() for line in annos]
        if os.path.splitext(file_name)[1] == '.json':
            annos = json.load(fp)

    return annos


def tokenize(obj, tokenizer):
    if isinstance(obj, str):
        return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
    if isinstance(obj, dict):
        return dict((n, tokenize(o, tokenizer)) for n, o in obj.items())
    return list(tokenize(o, tokenizer) for o in obj)


class NextQADataset(Dataset):
    def __init__(self, config, split):

        super().__init__()
        self.config = config
        self.split = split
        self.bart_max_input_len = config['bart_max_input_len']
        self.bart_size = config['bart_size']
        self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-{}'.format(self.bart_size))
        self.vocab_size = self.tokenizer.vocab_size

        self.tokenizer.add_special_tokens({'additional_special_tokens': ADDITIONAL_SPECIAL_TOKENS})
        self.vocab_size += len(ADDITIONAL_SPECIAL_TOKENS)
        self.tokenizer.save_pretrained(os.path.join(self.config['log_dir'], 'bart_tokenizer'))

        sample_list_file = os.path.join(self.config['nextqa_root'], '{}.csv'.format(split))
        self.sample_list = load_file(sample_list_file)

        vid_feat_file = os.path.join(self.config['nextqa_vid_feat'], 'app_mot_{}.h5'.format(split))
        print('Load {}...'.format(vid_feat_file))
        self.frame_feats = {}
        self.mot_feats = {}
        with h5py.File(vid_feat_file, 'r') as fp:
            vids = fp['ids']
            feats = fp['feat']
            for vid, feat in zip(vids, feats):
                self.frame_feats[str(vid)] = feat[:, :2048]  # (16, 2048)
                self.mot_feats[str(vid)] = feat[:, 2048:]  # (16, 2048)

        if self.config['overfit_size'] > 0:
            self.sample_list = self.sample_list[:self.config['overfit_size']]

    def __len__(self):
        return len(self.sample_list)

    def get_video_feature(self, video_name):
        """
        :param video_name:
        :return:
        """

        app_feat = self.frame_feats[video_name]
        app_feat = torch.from_numpy(app_feat).type(torch.float32)

        mot_feat = self.mot_feats[video_name]
        mot_feat = torch.from_numpy(mot_feat).type(torch.float32)

        return app_feat, mot_feat

    def __getitem__(self, idx):
        cur_sample = self.sample_list.loc[idx]
        video_name, ques, ans, qid = str(cur_sample['video']), str(cur_sample['question']),\
            str(cur_sample['answer']), str(cur_sample['qid'])

        input_ids = tokenize(ques, self.tokenizer)
        lm_labels = tokenize(ans, self.tokenizer)

        app_feat, mot_feat = self.get_video_feature(video_name)

        bos, eos, ques_state = self.tokenizer.convert_tokens_to_ids(['<s>', '</s>', '<s2>'])

        # Add state tokens
        input_ids.insert(0, ques_state)
        lm_labels.append(eos)
        question_interval = [0, len(input_ids)]

        input_ids = torch.Tensor(input_ids).long()
        lm_labels = torch.Tensor(lm_labels).long()

        return input_ids, lm_labels, app_feat, mot_feat, question_interval, video_name

    def padding(self, seq, pad_token, max_len=None):
        if max_len is None:
            max_len = max([i.size(0) for i in seq])
        if len(seq[0].size()) == 1:
            result = torch.ones((len(seq), max_len)).long() * pad_token
        else:
            result = torch.ones((len(seq), max_len, seq[0].size(-1))).float()
        for i in range(len(seq)):
            result[i, :seq[i].size(0)] = seq[i]
        return result

    def collate_fn(self, batch):
        input_ids_list, lm_labels_list, app_feat_list, mot_feat_list, question_interval_list, vid_ids_list = [], [], [], [], [], []
        for i in batch:
            input_ids_list.append(i[0])
            lm_labels_list.append(i[1])
            app_feat_list.append(i[2])
            mot_feat_list.append(i[3])
            question_interval_list.append(i[4])
            vid_ids_list.append(i[5])

        app_feats = torch.stack(app_feat_list, dim=0).float()
        mot_feats = torch.stack(mot_feat_list, dim=0).float()

        question_intervals = np.array(question_interval_list)

        pad_token, app_sep, mot_sep, ph_token = self.tokenizer.convert_tokens_to_ids(
            ['<pad>', '<s0>', '<s1>', '<place_holder>'])

        # All the visual features will not be masked because we do not perform any padding on them
        video_mask = torch.ones((len(batch), 16 * 2 + 2)) == 1  # NOTE *2: 2 modalities | +2: the state tokens | each modality has length 16
        # Now we create a dummy input for the video tokens (sole purpose is to reserve the spot of the separators)
        dummy = torch.ones((len(batch), 16)) * ph_token
        video_place_holder_ids = torch.cat(
            [torch.ones((len(batch), 1)) * app_sep, dummy,
             torch.ones((len(batch), 1)) * mot_sep, dummy,
             ], dim=-1).long()

        input_ids = self.padding(input_ids_list, pad_token)
        lm_labels = self.padding(lm_labels_list, -100)
        text_mask = input_ids != pad_token
        input_mask = torch.cat([video_mask, text_mask], dim=1)

        # Now we get the intervals of the visual input tokens
        # Here the intervals do not change across the batch dimension
        app_interval = [0, 16 + 1]  # the last token is not part of this modality
        mot_interval = [16 + 1, 2 * 16 + 2]
        vis_state_vector_idx = [app_interval[0], mot_interval[0]]

        # adapt the question interval -- shifted to the right by the visual input length
        question_intervals += 2 * 16 + 2
        question_intervals = question_intervals.tolist()

        question_state_vector_idx = [x[0] for x in question_intervals]

        batch = {
            'input_ids': input_ids,
            'video_place_holder_ids': video_place_holder_ids,
            'app_feats': app_feats,
            'mot_feats': mot_feats,
            'lm_labels': lm_labels,
            'input_mask': input_mask,
            'app_interval': app_interval,
            'mot_interval': mot_interval,
            'question_intervals': question_intervals,
            'vis_state_vector_idx': vis_state_vector_idx,
            'question_state_vector_idx': question_state_vector_idx
        }
        return batch


def get_dataset(config, split):

    bart_max_input_len = config['bart_max_input_len']
    bart_size = config['bart_size']

    sample_list_file = os.path.join(config['nextqa_root'], '{}.csv'.format(split))
    sample_list = load_file(sample_list_file)

    vid_feat_file = os.path.join(config['nextqa_vid_feat'], 'app_mot_{}.h5'.format(split))
    print('Load {}...'.format(vid_feat_file))
    app_feats = {}
    mot_feats = {}
    with h5py.File(vid_feat_file, 'r') as fp:
        vids = fp['ids']
        feats = fp['feat']
        for vid, feat in zip(vids, feats):
            app_feats[str(vid)] = feat[:, :2048]  # (16, 2048)
            mot_feats[str(vid)] = feat[:, 2048:]  # (16, 2048)

    return sample_list, app_feats, mot_feats
custom_datasets/segment.py
Normal file
179
custom_datasets/segment.py
Normal file
|
@ -0,0 +1,179 @@
|
|||
from segment_anything import SamPredictor, SamAutomaticMaskGenerator, sam_model_registry
|
||||
from tqdm import tqdm
|
||||
from argparse import ArgumentParser
|
||||
import pickle
|
||||
import cv2
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
'--sam_ckpt',
|
||||
type=str,
|
||||
help='SAM checkpoint to be used'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--avsd_root',
|
||||
type=str,
|
||||
help='Directory where the individual AVSD frames are located'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--crop_root',
|
||||
type=str,
|
||||
help='Directory where the individual crops (objects) will be saved'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--embed_root',
|
||||
type=str,
|
||||
help='Directory where the individual embeddings will be saved'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--mode',
|
||||
type=str,
|
||||
choices=['segment', 'embed'],
|
||||
help='segment: segment the image into regions | embed: embed the image crops detected during segmentation'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--start',
|
||||
type=int,
|
||||
default=0,
|
||||
help='Start index of the partition'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--end',
|
||||
type=int,
|
||||
default=1968,
|
||||
help='End index of the partition'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def partition_ids(avsd_ids, start, end):
|
||||
avsd_ids.sort()
|
||||
assert start < end
|
||||
assert start >= 0 and end <= len(avsd_ids)
|
||||
avsd_ids_partition = avsd_ids[start:end]
|
||||
return avsd_ids_partition
|
||||
|
||||
|
||||
def get_middle_frames(avsd_ids_partition, avsd_root):
|
||||
pbar = tqdm(avsd_ids_partition)
|
||||
pbar.set_description('[INFO] Preparing frames of {} videos'.format(len(avsd_ids_partition)))
|
||||
path_list = []
|
||||
for avsd_id in pbar:
|
||||
frames = os.listdir(os.path.join(avsd_root, avsd_id))
|
||||
if 'test' in avsd_root:
|
||||
frames.sort(key=lambda f: int(f.split('_')[-1].split('.')[0]))
|
||||
else:
|
||||
frames.sort(key=lambda f: int(f.split('-')[-1].split('.')[0]))
|
||||
middle_frame = frames[int(len(frames)/2)]
|
||||
middle_frame = os.path.join(avsd_root, avsd_id, middle_frame)
|
||||
path_list.append(middle_frame)
|
||||
return path_list
|
||||
|
||||
|
||||
def segment_images(sam, path_list, crop_root):
|
||||
mask_generator = SamAutomaticMaskGenerator(sam)
|
||||
pbar = tqdm(path_list)
|
||||
pbar.set_description('Detecting Objects')
|
||||
for pth in pbar:
|
||||
vid_id = pth.split('/')[-2]
|
||||
crop_dir = os.path.join(crop_root, vid_id)
|
||||
if not os.path.isdir(crop_dir):
|
||||
os.makedirs(crop_dir)
|
||||
|
||||
image = cv2.imread(pth)
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
masks = mask_generator.generate(image)
|
||||
masks.sort(key=lambda e: e['stability_score'], reverse=True)
|
||||
if len(masks) > 36:
|
||||
masks = masks[:36]
|
||||
for i, mask in enumerate(masks):
|
||||
crop = image[
|
||||
int(mask['bbox'][1]):int(mask['bbox'][1] + mask['bbox'][3] + 1),
|
||||
int(mask['bbox'][0]):int(mask['bbox'][0] + mask['bbox'][2] + 1),
|
||||
:
|
||||
]
|
||||
crop_flipped = cv2.flip(crop, 1) # Horizontal flip
|
||||
cv2.imwrite(os.path.join(crop_dir, f'obj_{i}.jpg'), crop)
|
||||
cv2.imwrite(os.path.join(crop_dir, f'obj_{i}_flipped.jpg'), crop_flipped)
|
||||
|
||||
print('[INFO] Done...')
|
||||
|
||||
|
||||
def embed_objects(sam, crop_ids, crop_root, embed_root):
|
||||
predictor = SamPredictor(sam)
|
||||
pbar = tqdm(crop_ids)
|
||||
pbar.set_description('Embedding Objects')
|
||||
for vid_id in pbar:
|
||||
embeds = []
|
||||
crop_dir = os.path.join(crop_root, vid_id)
|
||||
crop_paths = list(map(lambda p: os.path.join(crop_dir, p), os.listdir(crop_dir)))
|
||||
crop_paths = list(filter(lambda p: 'flipped' not in p, crop_paths))
|
||||
crop_paths.sort(key=lambda p: int(p.split('_')[-1].split('.')[0]))
|
||||
for cp in crop_paths:
|
||||
crop = cv2.imread(cp)
|
||||
crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
|
||||
predictor.set_image(crop)
|
||||
embed_crop = predictor.get_image_embedding()
|
||||
embed_crop = embed_crop.mean(-1).mean(-1)
|
||||
|
||||
crop_flipped = cv2.flip(crop, 1)
|
||||
predictor.set_image(crop_flipped)
|
||||
embed_crop_flipped = predictor.get_image_embedding()
|
||||
embed_crop_flipped = embed_crop_flipped.mean(-1).mean(-1)
|
||||
|
||||
embed = torch.cat((embed_crop, embed_crop_flipped), dim=-1)
|
||||
# embed = embed.copy().cpu()
|
||||
embeds.append(embed)
|
||||
|
||||
embeds = torch.cat(embeds, 0).cpu().numpy()
|
||||
np.save(os.path.join(embed_root, f'{vid_id}.npy'), embeds)
|
||||
|
||||
print('[INFO] Done...')
|
||||
|
||||
|
||||
def segment(args, sam):
|
||||
avsd_ids = os.listdir(args.avsd_root)
|
||||
avsd_ids.sort()
|
||||
avsd_ids_partition = partition_ids(avsd_ids, args.start, args.end)
|
||||
path_list = get_middle_frames(avsd_ids_partition, args.avsd_root)
|
||||
|
||||
if not os.path.isdir(args.crop_root):
|
||||
os.makedirs(args.crop_root)
|
||||
segment_images(sam, path_list, args.crop_root)
|
||||
|
||||
|
||||
def embed(args, sam):
|
||||
crop_ids = os.listdir(args.crop_root)
|
||||
crop_ids.sort()
|
||||
crop_ids_partition = partition_ids(crop_ids, args.start, args.end)
|
||||
if not os.path.isdir(args.embed_root):
|
||||
os.makedirs(args.embed_root)
|
||||
embed_objects(sam, crop_ids_partition, args.crop_root, args.embed_root)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_args()
|
||||
sam = sam_model_registry['vit_h |