Make code public
commit 8e03ef1c38
49 changed files with 545354 additions and 0 deletions
.gitattributes (vendored, Normal file, 3 lines)
@@ -0,0 +1,3 @@
*.pkl filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.csv filter=lfs diff=lfs merge=lfs -text
LICENSE (Normal file, 21 lines)
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 Anonymous

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md (Normal file, 174 lines)
@@ -0,0 +1,174 @@
<div align="center">
<h1> MST-MIXER <img src="misc/mixer.png" width="3%" align="bottom">: Multi-Modal Video Dialog State Tracking in the Wild </h1>

**[Adnen Abdessaied][16], [Lei Shi][17], [Andreas Bulling][18]** <br> <br>
**ECCV 2024, Milan, Italy <img src="misc/italy.png" width="3%" align="center">** <br>
**[[Paper][19]]**

---------------------------
<img src="misc/teaser.png" width="70%" align="middle"><br><br>

</div>

# Citation
If you find our code useful or use it in your own projects, please cite our paper:

```bibtex
@InProceedings{Abdessaied_2024_eccv,
    author    = {Abdessaied, Adnen and Shi, Lei and Bulling, Andreas},
    title     = {{Multi-Modal Video Dialog State Tracking in the Wild}},
    booktitle = {Proceedings of the European Conference on Computer Vision (ECCV)},
    year      = {2024}
}
```

# Table of Contents
* [Setup and Dependencies](#Setup-and-Dependencies)
* [Download Data](#Download-Data)
* [Training](#Training)
* [Response Generation](#Response-Generation)
* [Results](#Results)
* [Acknowledgements](#Acknowledgements)

# Setup and Dependencies
We implemented our model using Python 3.7 and PyTorch 1.12.0 (CUDA 11.3, CuDNN 8.3.2). We recommend setting up a virtual environment using Anaconda. <br>
1. Install [git lfs][1] on your system
2. Clone our repository to download a checkpoint of our best model and our code
   ```shell
   git lfs install
   git clone this_repo.git
   ```
3. Create a conda environment and install dependencies
   ```shell
   conda create -n mst_mixer python=3.7
   conda activate mst_mixer
   conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.3 -c pytorch
   conda install pyg -c pyg
   conda install pytorch-scatter -c pyg  # pytorch >= 1.8.0
   conda install pytorch-sparse -c pyg   # pytorch >= 1.8.0
   conda install -c huggingface transformers
   pip install evaluate wandb glog pyhocon attrs
   ```

# Download Data
## AVSD
1. Download the [AVSD-DSTC7][2], [AVSD-DSTC8][3] and [AVSD-DSTC10][10] data
2. Place the raw json files in ```raw_data/``` and the features in ```features/```
3. Preprocess and save the input features for faster training as indicated in ```custom_datasets/```

## NExT-QA
1. For convenience, we included the features/data in this git repo (see the sketch below for the feature layout).
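
As a minimal sketch of how the bundled NExT-QA features are organized (inferred from ```custom_datasets/nextqa.py``` later in this commit; the path follows the defaults in ```config/mst_mixer.conf```, adjust it if your checkout differs): each video has one (16, 4096) matrix, split into appearance and motion halves.

```python
# Minimal sketch, assuming the h5 layout read by custom_datasets/nextqa.py:
# first 2048 columns per row are appearance features, last 2048 are motion.
import h5py

with h5py.File('processed/next_qa/vid_feat/app_mot_train.h5', 'r') as fp:
    for vid, feat in zip(fp['ids'], fp['feat']):
        app, mot = feat[:, :2048], feat[:, 2048:]
        print(vid, app.shape, mot.shape)  # e.g. <id> (16, 2048) (16, 2048)
        break
```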
# Training
We trained our model on 8 Nvidia Tesla V100-32GB GPUs. The default hyperparameters in ```config/mst_mixer.conf``` need to be adjusted if your setup differs from ours.
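As a minimal sketch of how these settings are consumed (based only on what ```custom_datasets/avsd.py``` in this commit does, namely reading the same file with pyhocon under the top-level key ```mixer```), you can inspect them programmatically before launching a run:

```python
# Minimal sketch: load and inspect the training configuration.
# Assumption: the file is HOCON and the settings live under the key "mixer",
# exactly as custom_datasets/avsd.py reads it below.
import pyhocon

config = pyhocon.ConfigFactory.parse_file('config/mst_mixer.conf')['mixer']
print(config['task'], config['batch_size'], config['learning_rate_bart'])
# e.g. avsd, 16, 1e-05 with the defaults shipped in this commit
```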
## AVSD
1. Set ```task=avsd``` in ```config/mst_mixer.conf```
2. ```shell
   CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
       --mode train \
       --tag mst_mixer_avsd \
       --wandb_mode online \
       --wandb_project mst_mixer_avsd
   ```
To deactivate [wandb][4] logging, use ```--wandb_mode disabled```.
On a similar setup to ours, this will take roughly 20h to complete.

## NExT-QA
1. Set ```task=nextqa``` in ```config/mst_mixer.conf```
2. ```shell
   CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
       --mode train \
       --tag mst_mixer_nextqa \
       --wandb_mode online \
       --wandb_project mst_mixer_nextqa
   ```

# Response Generation
## AVSD-DSTC7
1. Set ```dstc=7``` in the ```.conf``` file of your trained network. In the default setting, you can find it under ```logs/unique_training_tag/code/config/mst_mixer.conf```
2. Generate the responses
   ```shell
   ./generate_parallel_avsd.sh mst_mixer/mixer results_avsd_dstc7 generate logs/mst_mixer_avsd 7
   ```
3. All responses will be saved in ```output/dstc7/```

## AVSD-DSTC8
1. Set ```dstc=8``` in the ```.conf``` file of your trained network. In the default setting, you can find it under ```logs/unique_training_tag/code/config/mst_mixer.conf```
2. Generate the responses
   ```shell
   ./generate_parallel_avsd.sh mst_mixer/mixer results_avsd_dstc8 generate logs/mst_mixer_avsd 8
   ```
3. All responses will be saved in ```output/dstc8/```

## AVSD-DSTC10
1. Set ```dstc=10``` in the ```.conf``` file of your trained network. In the default setting, you can find it under ```logs/unique_training_tag/code/config/mst_mixer.conf```
2. Generate the responses
   ```shell
   ./generate_parallel_avsd.sh mst_mixer/mixer results_avsd_dstc10 generate logs/mst_mixer_avsd 10
   ```
3. All responses will be saved in ```output/dstc10/```

## NExT-QA
1. Generate the responses
   ```shell
   ./generate_parallel_nextqa.sh mst_mixer/mixer results_nextqa generate logs/mst_mixer_nextqa
   ```
2. All responses will be saved in ```output/nextqa/```
3. Evaluate using this [script][15]

# Results
To evaluate our best model on the different benchmarks:

## AVSD-DSTC7
Executing the [eval_tool][7] of AVSD-DSTC7 using the generated responses will output the following metrics

| Model | BLEU-1 | BLEU-2 | BLEU-3 | BLEU-4 | METEOR | ROUGE-L | CIDEr |
|:--------:|:------:|:------:|:------:|:------:|:------:|:-------:|:-----:|
| Prev. SOTA | 78.2 | 65.5 | 55.2 | 46.9 | 30.8 | 61.9 | 135.2 |
| MST_MIXER | **78.7** | **66.5** | **56.3** | **47.6** | **31.3** | **62.5** | **138.8**|

## AVSD-DSTC8
1. Set ```dstc=8``` in ```ckpt/code/mst_mixer.conf```
2. Run
   ```shell
   ./generate_parallel_avsd.sh mst_mixer/mixer results_avsd_dstc8_best_model generate ckpt/avsd 8
   ```
3. The responses will be saved in ```output/dstc8/```
4. Executing the [eval_tool][7] of AVSD-DSTC8 using the generated responses will output the following metrics

| Model | BLEU-1 | BLEU-2 | BLEU-3 | BLEU-4 | METEOR | ROUGE-L | CIDEr |
|:--------:|:------:|:------:|:------:|:------:|:------:|:-------:|:-----:|
| Prev. SOTA | 76.4 | 64.1 | 54.3 | 46.0 | 30.1 | 61.0 | 130.4 |
| MST_MIXER | **77.5** | **66.0** | **56.1** | **47.7** | **30.6** | **62.4** | **135.4**|

## AVSD-DSTC10
Executing the [eval_tool][11] of AVSD-DSTC10 using the generated responses will output the following metrics

| Model | BLEU-1 | BLEU-2 | BLEU-3 | BLEU-4 | METEOR | ROUGE-L | CIDEr |
|:--------:|:------:|:------:|:------:|:------:|:------:|:-------:|:-----:|
| Prev. SOTA | 69.3 | 55.6 | 45.0 | 37.2 | 24.9 | 53.6 | 91.2 |
| MST_MIXER | **70.0** | **57.4** | **47.6** | **40.0** | **25.7** | **54.5** | **99.8**|

## NExT-QA
Executing the [eval script][15] of NExT-QA using the generated responses will output the following metrics

| Model | WUPS_C | WUPS_T | WUPS_D | WUPS |
|:--------:|:------:|:------:|:------:|:------:|
| Prev. SOTA | 17.98| 17.95 | 50.84 | 28.40 |
| MST_MIXER | **22.12** | **22.20** | **55.64** | **29.50** |

# Acknowledgements
We thank the authors of [RLM][8] for providing their [code][9] that greatly influenced this work.

[1]: https://git-lfs.com/
[2]: https://github.com/hudaAlamri/DSTC7-Audio-Visual-Scene-Aware-Dialog-AVSD-Challenge
[3]: https://github.com/dialogtekgeek/DSTC8-AVSD_official
[4]: https://wandb.ai/site
[5]: https://drive.google.com/drive/folders/1SlZTySJAk_2tiMG5F8ivxCfOl_OWwd_Q
[7]: https://drive.google.com/file/d/1EKfPtrNBQ5ciKRl6XggImweGRP84XuPi/view?usp=sharing
[8]: https://arxiv.org/abs/2002.00163
[9]: https://github.com/ictnlp/DSTC8-AVSD
[10]: https://drive.google.com/file/d/1zvC6FuPRVRiLQCXZcYpzYUI9r1tiWls6/view
[11]: https://github.com/ankitshah009/AVSD-DSTC10_baseline
[15]: https://github.com/doc-doc/NExT-OE/blob/main/eval_oe.py
[16]: https://adnenabdessaied.de/
[17]: https://perceptualui.org/people/shi/
[18]: https://perceptualui.org/people/bulling/
[19]: https://arxiv.org/abs/2407.02218
config/avsd_bart_base.json (Normal file, 118 lines)
@@ -0,0 +1,118 @@
{
  "_name_or_path": "bart-base",
  "activation_dropout": 0.1,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "BartModel"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 0,
  "classif_dropout": 0.1,
  "classifier_dropout": 0.0,
  "d_model": 768,
  "decoder_attention_heads": 12,
  "decoder_ffn_dim": 3072,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 6,
  "decoder_start_token_id": 2,
  "dropout": 0.1,
  "early_stopping": true,
  "encoder_attention_heads": 12,
  "encoder_ffn_dim": 3072,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 6,
  "eos_token_id": 2,
  "forced_eos_token_id": 2,
  "forced_bos_token_id": 0,
  "gradient_checkpointing": false,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2"
  },
  "init_std": 0.02,
  "is_encoder_decoder": true,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2
  },
  "max_position_embeddings": 1024,
  "model_type": "bart",
  "no_repeat_ngram_size": 3,
  "normalize_before": false,
  "normalize_embedding": true,
  "num_beams": 4,
  "num_hidden_layers": 6,
  "pad_token_id": 1,
  "scale_embedding": false,
  "task_specific_params": {
    "summarization": {
      "length_penalty": 1.0,
      "max_length": 128,
      "min_length": 12,
      "num_beams": 4
    },
    "summarization_cnn": {
      "length_penalty": 2.0,
      "max_length": 142,
      "min_length": 56,
      "num_beams": 4
    },
    "summarization_xsum": {
      "length_penalty": 1.0,
      "max_length": 62,
      "min_length": 11,
      "num_beams": 6
    }
  },
  "torch_dtype": "float32",
  "transformers_version": "4.12.0.dev0",
  "use_cache": true,
  "vocab_size": 50265,

  "d_i3d_flow": 2048,
  "d_i3d_rgb": 2048,
  "d_sam": 512,
  "d_audio": 128,
  "top_k": 10,
  "num_nn": 4,
  "gnn_type": "appnp",
  "use_random_graphs": false,
  "integrate_all_gnn_features": true,
  "use_elbo_local": true,
  "use_elbo_global": true,
  "use_non_linear": true,
  "gnn_K": 2,
  "gnn_alpha": 0.1,
  "num_modalities": 6,
  "local_gnn_d_hidden": 768,
  "global_gnn_d_hidden": 768,
  "num_local_gnn_heads": 2,
  "num_global_gnn_heads": 4,
  "local_gnn_dropout": 0.1,
  "global_gnn_dropout": 0.1,
  "local_fc_dropout": 0.1,
  "global_fc_dropout": 0.1,
  "num_local_gnn_layers": 2,
  "num_global_gnn_layers": 2,
  "num_local_fc_layers": 2,
  "num_global_fc_layers": 2,
  "use_local_gnn_bn": true,
  "use_global_gnn_bn": true,
  "use_local_fc_bn": true,
  "use_global_fc_bn": true,
  "local_gnn_concat": true,
  "global_gnn_concat": true,
  "num_local_gr_learner_heads": 8,
  "num_global_gr_learner_heads": 8,
  "init_adj_ratio": 0.5,
  "adj_ratio": 0.5,
  "alpha": 0.9,
  "gnns_every": 2,
  "num_layers_state_fc_decoder": 2,
  "dropout_state_fc_decoder": 0.3
}
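The keys after ```vocab_size``` are MST-MIXER-specific additions (feature dimensions, graph-learner settings) riding along on the standard BART config. A minimal inspection sketch, assuming the usual Hugging Face behaviour that unknown JSON keys are kept as attributes on the loaded config (how ```main.py``` actually loads this file is not part of this excerpt):

```python
# Minimal sketch: read the extended BART config and its custom fields.
# Assumption: PretrainedConfig stores unknown JSON keys as plain attributes.
from transformers import BartConfig

cfg = BartConfig.from_json_file('config/avsd_bart_base.json')
print(cfg.d_model, cfg.d_i3d_rgb, cfg.gnn_type)  # 768 2048 appnp
```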
config/avsd_bart_large.json (Normal file, 116 lines)
@@ -0,0 +1,116 @@
{
  "activation_dropout": 0.1,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "BartModel"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 0,
  "classif_dropout": 0.1,
  "classifier_dropout": 0.0,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 12,
  "decoder_start_token_id": 2,
  "dropout": 0.1,
  "early_stopping": true,
  "encoder_attention_heads": 16,
  "encoder_ffn_dim": 4096,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 12,
  "eos_token_id": 2,
  "forced_eos_token_id": 2,
  "forced_bos_token_id": 0,
  "gradient_checkpointing": false,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2"
  },
  "init_std": 0.02,
  "is_encoder_decoder": true,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2
  },
  "max_position_embeddings": 1024,
  "model_type": "bart",
  "no_repeat_ngram_size": 3,
  "normalize_before": false,
  "num_beams": 4,
  "num_hidden_layers": 12,
  "pad_token_id": 1,
  "scale_embedding": false,
  "task_specific_params": {
    "summarization": {
      "length_penalty": 1.0,
      "max_length": 128,
      "min_length": 12,
      "num_beams": 4
    },
    "summarization_cnn": {
      "length_penalty": 2.0,
      "max_length": 142,
      "min_length": 56,
      "num_beams": 4
    },
    "summarization_xsum": {
      "length_penalty": 1.0,
      "max_length": 62,
      "min_length": 11,
      "num_beams": 6
    }
  },
  "transformers_version": "4.7.0.dev0",
  "use_cache": true,
  "vocab_size": 50265,

  "d_i3d_flow": 2048,
  "d_i3d_rgb": 2048,
  "d_sam": 512,
  "d_audio": 128,
  "top_k": 10,
  "num_nn": 4,
  "gnn_type": "appnp",
  "use_random_graphs": false,
  "integrate_all_gnn_features": true,
  "use_elbo_local": true,
  "use_elbo_global": true,
  "use_non_linear": true,
  "gnn_K": 2,
  "gnn_alpha": 0.1,
  "num_modalities": 6,
  "local_gnn_d_hidden": 1024,
  "global_gnn_d_hidden": 1024,
  "num_local_gnn_heads": 2,
  "num_global_gnn_heads": 4,
  "local_gnn_dropout": 0.1,
  "global_gnn_dropout": 0.1,
  "local_fc_dropout": 0.1,
  "global_fc_dropout": 0.1,
  "num_local_gnn_layers": 1,
  "num_global_gnn_layers": 1,
  "num_local_fc_layers": 1,
  "num_global_fc_layers": 1,
  "use_local_gnn_bn": true,
  "use_global_gnn_bn": true,
  "use_local_fc_bn": true,
  "use_global_fc_bn": true,
  "local_gnn_concat": true,
  "global_gnn_concat": true,
  "num_local_gr_learner_heads": 8,
  "num_global_gr_learner_heads": 8,
  "init_adj_ratio": 0.5,
  "adj_ratio": 0.5,
  "alpha": 0.9,
  "gnns_every": 4,
  "num_layers_state_fc_decoder": 2,
  "dropout_state_fc_decoder": 0.3
}
config/mst_mixer.conf (Normal file, 94 lines)
@@ -0,0 +1,94 @@
mixer {
    task = avsd
    #################################################################################
    # datasets
    # avsd

    avsd_processed = features/
    avsd_train = raw_data/train_set4DSTC7-AVSD.json
    avsd_val = raw_data/valid_set4DSTC7-AVSD.json
    avsd_test_dstc7 = raw_data/test_set4DSTC7-AVSD.json
    avsd_test_dstc8 = raw_data/test_set4DSTC8-AVSD.json
    avsd_test_dstc10 = raw_data/test_set4DSTC10-AVSD.json
    avsd_feature_path = features/
    avsd_i3d_rgb = features/i3d_rgb
    avsd_i3d_rgb_test = features/i3d_rgb_testset
    avsd_i3d_flow = features/i3d_flow_all
    avsd_i3d_flow_test = features/i3d_flow_testset
    avsd_audio = features/vggish_all
    avsd_audio_test = features/vggish_testset
    avsd_objects = features/sam
    avsd_objects_test = features/sam_testset

    dstc = 7

    # NextQA
    nextqa_root = processed/next_qa/annotations
    nextqa_vid_feat = processed/next_qa/vid_feat
    #################################################################################
    # Model
    bart_size = large  # base, large
    avsd_bart_base_config = config/avsd_bart_base.json
    avsd_bart_large_config = config/avsd_bart_large.json
    nextqa_bart_large_config = config/nextqa_bart_large.json

    #################################################################################
    # Logging & Checkpointing
    log_dir = logs
    output_dir_dstc7 = output/dstc7
    output_dir_dstc8 = output/dstc8
    output_dir_dstc10 = output/dstc10
    output_dir_nextqa = output/nextqa
    max_ckpt_to_keep = 5
    start_ckpt_for_generating = none
    loads_start_path = false
    next_logging_pct = 0.1
    save_ckpt = true
    skip_saving_ckpt = false
    stop_epochs = -1
    resets_min_val_loss = false
    restarts = false
    uses_new_optimizer = true
    sets_new_lr = false
    ################################################################################
    # Data processing
    expand_rnd = false
    cap_sum = cap_sum
    add_state_tokens = true
    bart_max_input_len = 1024
    num_workers = 0
    n_history = 3
    caption_drop_rate = 0.0
    vis_feat_length = 36
    #################################################################################
    # Training
    dp_type = ddp
    batch_size = 16
    num_epochs = 12
    warmup_ratio = 0.1
    batch_multiply = 1
    skip_eval = false
    stop_epoch = -1
    random_seed = 54
    learning_rate_bart = 1e-5
    learning_rate_other = 1e-4
    min_lr = 0
    clip_grad_value = 1.0
    print_output = false
    eval_first = false
    overfit_size = -1
    elbo_global_coeff = 100
    elbo_local_coeff = 100
    gen_coeff = 1
    #################################################################################
    # Generation
    gen_batch_size = 1
    beam_depth = 5
    max_generation_length = 20
    min_generation_length = 1
    length_penalty = 0.3
    #################################################################################
    # Misc.
    master_port = 5101
    use_cpu = false
}
config/nextqa_bart_large.json (Normal file, 116 lines)
@@ -0,0 +1,116 @@
{
  "activation_dropout": 0.1,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "BartModel"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 0,
  "classif_dropout": 0.1,
  "classifier_dropout": 0.0,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 12,
  "decoder_start_token_id": 2,
  "dropout": 0.1,
  "early_stopping": true,
  "encoder_attention_heads": 16,
  "encoder_ffn_dim": 4096,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 12,
  "eos_token_id": 2,
  "forced_eos_token_id": 2,
  "forced_bos_token_id": 0,
  "gradient_checkpointing": false,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2"
  },
  "init_std": 0.02,
  "is_encoder_decoder": true,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2
  },
  "max_position_embeddings": 1024,
  "model_type": "bart",
  "no_repeat_ngram_size": 3,
  "normalize_before": false,
  "num_beams": 4,
  "num_hidden_layers": 12,
  "pad_token_id": 1,
  "scale_embedding": false,
  "task_specific_params": {
    "summarization": {
      "length_penalty": 1.0,
      "max_length": 128,
      "min_length": 12,
      "num_beams": 4
    },
    "summarization_cnn": {
      "length_penalty": 2.0,
      "max_length": 142,
      "min_length": 56,
      "num_beams": 4
    },
    "summarization_xsum": {
      "length_penalty": 1.0,
      "max_length": 62,
      "min_length": 11,
      "num_beams": 6
    }
  },
  "transformers_version": "4.7.0.dev0",
  "use_cache": true,
  "vocab_size": 50265,

  "d_i3d_flow": 2048,
  "d_i3d_rgb": 2048,
  "d_sam": 512,
  "d_audio": 128,
  "top_k": 10,
  "num_nn": 4,
  "gnn_type": "appnp",
  "use_random_graphs": false,
  "integrate_all_gnn_features": true,
  "use_elbo_local": true,
  "use_elbo_global": true,
  "use_non_linear": true,
  "gnn_K": 2,
  "gnn_alpha": 0.1,
  "num_modalities": 3,
  "local_gnn_d_hidden": 1024,
  "global_gnn_d_hidden": 1024,
  "num_local_gnn_heads": 2,
  "num_global_gnn_heads": 4,
  "local_gnn_dropout": 0.1,
  "global_gnn_dropout": 0.1,
  "local_fc_dropout": 0.1,
  "global_fc_dropout": 0.1,
  "num_local_gnn_layers": 1,
  "num_global_gnn_layers": 1,
  "num_local_fc_layers": 1,
  "num_global_fc_layers": 1,
  "use_local_gnn_bn": true,
  "use_global_gnn_bn": true,
  "use_local_fc_bn": true,
  "use_global_fc_bn": true,
  "local_gnn_concat": true,
  "global_gnn_concat": true,
  "num_local_gr_learner_heads": 8,
  "num_global_gr_learner_heads": 8,
  "init_adj_ratio": 0.5,
  "adj_ratio": 0.5,
  "alpha": 0.9,
  "gnns_every": 4,
  "num_layers_state_fc_decoder": 2,
  "dropout_state_fc_decoder": 0.3
}
custom_datasets/README.md (Normal file, 20 lines)
@@ -0,0 +1,20 @@
1. Download the raw [Charades train/val](https://prior.allenai.org/projects/charades) data
2. Download the raw [Charades test](https://ai2-public-datasets.s3-us-west-2.amazonaws.com/charades/Charades_vu17_test_480.tar) data
3. Install [SAM](https://github.com/facebookresearch/segment-anything.git)
4. Segment the frames
```shell
python segment.py --sam_ckpt path_to_sam_ckpt --avsd_root path_to_charades_trval_frames --crop_root path_to_save_the_trval_crops --mode segment --start start_idx --end end_idx
python segment.py --sam_ckpt path_to_sam_ckpt --avsd_root path_to_charades_test_frames --crop_root path_to_save_the_test_crops --mode segment --start start_idx --end end_idx
```
5. Embed the crops (see the sketch after this list for the resulting file format)
```shell
python segment.py --sam_ckpt path_to_sam_ckpt --crop_root path_to_save_the_trval_crops --mode embed --embed_root ../features/sam --start start_idx --end end_idx
python segment.py --sam_ckpt path_to_sam_ckpt --crop_root path_to_save_the_test_crops --mode embed --embed_root ../features/sam_testset --start start_idx --end end_idx
```
6. Preprocess and log the data
```shell
python dataset.py --split train
python dataset.py --split val
```
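A minimal sketch for verifying the output of step 5, based on ```segment.py``` later in this commit: each video yields up to 36 object crops, and each crop is embedded as the concatenation of the spatially pooled SAM embedding of the crop and of its horizontal flip (256 + 256 = 512, matching ```d_sam``` in the model configs). The video id below is a placeholder:

```python
# Minimal sketch (VIDEO_ID is hypothetical): inspect one embedding file
# written by segment.py in --mode embed.
import numpy as np

embeds = np.load('../features/sam/VIDEO_ID.npy')
print(embeds.shape)  # (num_objects <= 36, 512)
```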
custom_datasets/__init__.py (Normal file, 0 lines)
custom_datasets/avsd.py (Normal file, 401 lines)
@@ -0,0 +1,401 @@
import os
import pickle
import pyhocon
from copy import deepcopy
import json
from tqdm import tqdm
import numpy as np
import torch
from argparse import ArgumentParser
from torch.utils.data import Dataset, DataLoader
from transformers import BartTokenizer
from itertools import chain


ADDITIONAL_SPECIAL_TOKENS = [
    '<place_holder>', '<s0>', '<s1>', '<s2>', '<s3>', '<s4>', '<s5>']

SPECIAL_TOKENS_DICT = {
    'bos_token': '<s>',
    'eos_token': '</s>',
    'pad_token': '<pad>',
    'additional_special_tokens': ADDITIONAL_SPECIAL_TOKENS
}

S0_TOK = '<s0>'  # I3D_flow
S1_TOK = '<s1>'  # I3D_rgb
S2_TOK = '<s2>'  # sam obj
S3_TOK = '<s3>'  # audio
S4_TOK = '<s4>'  # history
S5_TOK = '<s5>'  # question


def tokenize(obj, tokenizer):
    if isinstance(obj, str):
        return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
    if isinstance(obj, dict):
        # NOTE: the tokenizer must be passed along in the recursive calls
        return dict((n, tokenize(o, tokenizer)) for n, o in obj.items())
    return list(tokenize(o, tokenizer) for o in obj)


class AVSDDataset(Dataset):
    def __init__(self, config, split):

        super().__init__()
        self.config = config
        self.split = split
        self.bart_max_input_len = config['bart_max_input_len']
        self.bart_size = config['bart_size']
        self.cap_sum = config['cap_sum']
        self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-{}'.format(self.bart_size))
        self.vocab_size = self.tokenizer.vocab_size

        self.tokenizer.add_special_tokens({'additional_special_tokens': ADDITIONAL_SPECIAL_TOKENS})
        self.vocab_size += len(ADDITIONAL_SPECIAL_TOKENS)
        self.tokenizer.save_pretrained(os.path.join(self.config['log_dir'], 'bart_tokenizer'))
        self.processed_dir = os.path.join(self.config['avsd_processed'], 'hist_with_{}_rounds'.format(self.config['n_history']), split)
        self.paths = list(map(lambda p: os.path.join(self.processed_dir, p), os.listdir(self.processed_dir)))

        if self.config['overfit_size'] > 0:
            self.paths = self.paths[:self.config['overfit_size']]

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        pth = self.paths[index]
        with open(pth, 'rb') as f:
            item = pickle.load(f)

        question_sep = self.tokenizer.convert_tokens_to_ids('<s5>')

        input_ids = item['input_ids']
        history_end = (input_ids == question_sep).nonzero(as_tuple=True)[0]

        history_interval = [0, history_end.item()]  # The last token is the question state token (not part of the history)
        question_interval = [history_end.item(), input_ids.size(0)]

        lm_labels = item['lm_labels']
        i3d_rgb = item['i3d_rgb']
        i3d_flow = item['i3d_flow']
        sam = item['sam']
        vgg = item['vgg']
        vid = item['vid']

        return input_ids, lm_labels, history_interval, question_interval, i3d_rgb, i3d_flow, sam, vgg, vid

    def padding(self, seq, pad_token, max_len=None):
        if max_len is None:
            max_len = max([i.size(0) for i in seq])
        if len(seq[0].size()) == 1:
            result = torch.ones((len(seq), max_len)).long() * pad_token
        else:
            result = torch.ones((len(seq), max_len, seq[0].size(-1))).float()
        for i in range(len(seq)):
            result[i, :seq[i].size(0)] = seq[i]
        return result

    def collate_fn(self, batch):
        input_ids_list, lm_labels_list, history_interval_list, question_interval_list, i3d_rgb_list, i3d_flow_list, sam_list, vggish_list, vid_ids_list = [], [], [], [], [], [], [], [], []
        for i in batch:
            input_ids_list.append(i[0])
            lm_labels_list.append(i[1])
            history_interval_list.append(i[2])
            question_interval_list.append(i[3])
            i3d_rgb_list.append(i[4])
            i3d_flow_list.append(i[5])
            sam_list.append(i[6])
            vggish_list.append(i[7])
            vid_ids_list.append(i[8])

        history_intervals = np.array(history_interval_list)
        question_intervals = np.array(question_interval_list)

        min_len_i3d_flow = min([feat.shape[0] for feat in i3d_flow_list])
        min_len_i3d_rgb = min([feat.shape[0] for feat in i3d_rgb_list])
        min_len_sam = min([feat.shape[0] for feat in sam_list])
        min_len_vggish = min([feat.shape[0] for feat in vggish_list])

        min_length = min([self.config['vis_feat_length'], min_len_i3d_flow, min_len_i3d_rgb, min_len_sam, min_len_vggish])

        # Sample equally-distant features from the visual features for each sample within the batch
        for i in range(len(i3d_rgb_list)):
            sample_idx_i3d_rgb = np.round(np.linspace(0, i3d_rgb_list[i].shape[0] - 1, min_length)).astype(int)
            i3d_rgb_list[i] = i3d_rgb_list[i][sample_idx_i3d_rgb, :]
        i3d_rgb = torch.from_numpy(np.array(i3d_rgb_list)).float()

        for i in range(len(i3d_flow_list)):
            sample_idx_i3d_flow = np.round(np.linspace(0, i3d_flow_list[i].shape[0] - 1, min_length)).astype(int)
            i3d_flow_list[i] = i3d_flow_list[i][sample_idx_i3d_flow, :]
        i3d_flow = torch.from_numpy(np.array(i3d_flow_list)).float()

        for i in range(len(sam_list)):
            sample_idx_sam = np.round(np.linspace(0, sam_list[i].shape[0] - 1, min_length)).astype(int)
            sam_list[i] = sam_list[i][sample_idx_sam, :]
        sam = torch.from_numpy(np.array(sam_list)).float()

        for i in range(len(vggish_list)):
            sample_idx_vggish = np.round(np.linspace(0, vggish_list[i].shape[0] - 1, min_length)).astype(int)
            vggish_list[i] = vggish_list[i][sample_idx_vggish, :]
        vggish = torch.from_numpy(np.array(vggish_list)).float()

        pad_token, i3d_flow_sep, i3d_rgb_sep, sam_sep, audio_sep, ph_token = self.tokenizer.convert_tokens_to_ids(
            ['<pad>', '<s0>', '<s1>', '<s2>', '<s3>', '<place_holder>'])

        # None of the visual features will be masked because we do not perform any padding on them
        video_mask = torch.ones((len(batch), min_length*4 + 4)) == 1  # NOTE *4: 4 modalities | +4: the state tokens
        # Now we create a dummy input for the video tokens (its sole purpose is to reserve the spots of the separators)
        dummy = torch.ones((len(batch), min_length)) * ph_token
        video_place_holder_ids = torch.cat(
            [torch.ones((len(batch), 1)) * i3d_rgb_sep, dummy,
             torch.ones((len(batch), 1)) * i3d_flow_sep, dummy,
             torch.ones((len(batch), 1)) * sam_sep, dummy,
             torch.ones((len(batch), 1)) * audio_sep, dummy,
             ], dim=-1).long()

        input_ids = self.padding(input_ids_list, pad_token)
        lm_labels = self.padding(lm_labels_list, -100)
        text_mask = input_ids != pad_token
        input_mask = torch.cat([video_mask, text_mask], dim=1)

        # Now we get the intervals of the visual input tokens
        # Here the intervals do not change across the batch dimension
        i3d_rgb_interval = [0, min_length + 1]  # the last token is not part of this modality
        i3d_flow_interval = [min_length + 1, 2 * min_length + 2]
        sam_interval = [2 * min_length + 2, 3 * min_length + 3]
        audio_interval = [3 * min_length + 3, 4 * min_length + 4]

        vis_state_vector_idx = [i3d_rgb_interval[0], i3d_flow_interval[0], sam_interval[0], audio_interval[0]]

        # Adapt the question and history intervals -- shifted to the right by the visual input length
        history_intervals += 4 * min_length + 4
        question_intervals += 4 * min_length + 4
        history_intervals = history_intervals.tolist()
        question_intervals = question_intervals.tolist()

        history_state_vector_idx = [x[0] + 1 for x in history_intervals]  # +1 because the history starts with <s><s4>
        question_state_vector_idx = [x[0] for x in question_intervals]  # no offset: the question interval starts directly at <s5>

        batch = {
            'input_ids': input_ids,
            'video_place_holder_ids': video_place_holder_ids,
            'i3d_rgb': i3d_rgb,
            'i3d_flow': i3d_flow,
            'sam': sam,
            'vggish': vggish,
            'lm_labels': lm_labels,
            'input_mask': input_mask,
            'i3d_rgb_interval': i3d_rgb_interval,
            'i3d_flow_interval': i3d_flow_interval,
            'sam_interval': sam_interval,
            'audio_interval': audio_interval,
            'history_intervals': history_intervals,
            'question_intervals': question_intervals,
            'vis_state_vector_idx': vis_state_vector_idx,
            'history_state_vector_idx': history_state_vector_idx,
            'question_state_vector_idx': question_state_vector_idx
        }
        return batch


def get_dataset(config, split, tokenizer):
    if split != 'test':
        dialog_pth = config[f'avsd_{split}']
    else:
        dialog_pth = config['avsd_test_dstc{}'.format(config['dstc'])]
    n_history = config['n_history']
    dialog_data = json.load(open(dialog_pth, 'r'))
    dialog_list = []
    vid_set = set()
    undisclosed_only = split == 'test'
    pbar = tqdm(dialog_data['dialogs'])

    pbar.set_description('[INFO] Generating {} items | DSTC {}'.format(split, config['dstc']))
    for dialog in pbar:
        if config['dstc'] != 10:
            caption = [tokenize(dialog['caption'], tokenizer)] + [tokenize(dialog['summary'], tokenizer)]
        else:
            caption = [tokenize('no', tokenizer)]

        questions = [tokenize(d['question'], tokenizer) for d in dialog['dialog']]
        answers = [tokenize(d['answer'], tokenizer) for d in dialog['dialog']]
        vid = dialog["image_id"]
        vid_set.add(vid)
        if undisclosed_only:
            it = range(len(questions) - 1, len(questions))
        else:
            it = range(len(questions))
        qalist = []
        history = []
        if undisclosed_only:
            for n in range(len(questions) - 1):
                qalist.append(questions[n])
                qalist.append(answers[n])
            history = qalist[max(-len(qalist), -n_history * 2):]
        for n in it:
            if undisclosed_only:
                assert dialog['dialog'][n]['answer'] == '__UNDISCLOSED__'
            question = questions[n]
            answer = answers[n]
            history.append(question)
            if n_history == 0:
                item = {'vid': vid, 'history': [question], 'answer': answer, 'caption': caption}
            else:
                item = {'vid': vid, 'history': history, 'answer': answer, 'caption': caption}
            dialog_list.append(item)
            qalist.append(question)
            qalist.append(answer)
            history = qalist[max(-len(qalist), -n_history * 2):]

    all_features = {}
    fea_types = ['vggish', 'i3d_flow', 'i3d_rgb', 'sam']

    dataname = '<FeaType>/<ImageID>.npy'
    for ftype in fea_types:
        if undisclosed_only:
            basename = dataname.replace('<FeaType>', ftype + '_testset')
        else:
            basename = dataname.replace('<FeaType>', ftype)
        features = {}
        for vid in vid_set:
            filename = basename.replace('<ImageID>', vid)
            filepath = config['avsd_feature_path'] + filename
            features[vid] = filepath
        all_features[ftype] = features
    return dialog_list, all_features


def build_input_from_segments(caption, history_orig, reply, tokenizer, add_state_tokens=True, drop_caption=False):
    """ Build a sequence of input from 3 segments: caption (caption + summary), history and last reply """

    bos, eos, hist_state, ques_state = tokenizer.convert_tokens_to_ids(['<s>', '</s>', '<s4>', '<s5>'])
    sep = eos

    instance = {}
    instance["lm_labels"] = reply + [eos]
    caption = list(chain(*caption))

    # Add state tokens if applicable
    if add_state_tokens:
        caption.insert(0, hist_state)
        history = deepcopy(history_orig)
        history[-1].insert(0, ques_state)
    else:
        history = history_orig

    if not drop_caption:
        # sequence = [[bos] + list(chain(*caption))] + history + [reply + ([eos] if with_eos else [])]

        # NOTE It is important not to include the reply in the input of the encoder --> the decoder would just
        # learn to copy it --> low train/val loss but no learning is happening
        sequence = [[bos] + caption + [eos]] + [[sep] + s for s in history] + [[eos]]
    else:
        sequence = [[bos]] + [[hist_state]] + [[sep] + s for s in history] + [[eos]]

    instance["input_ids"] = list(chain(*sequence))
    return instance


def parse_args():
    parser = ArgumentParser(description='debug dataloader')
    parser.add_argument(
        '--split',
        type=str,
        default='train',
        help='train or val')

    parser.add_argument(
        '--model',
        type=str,
        default='mixer',
        help='model name to train or test')

    parser.add_argument(
        '--log_dataset',
        action='store_true',
        default=False,
        help='Whether or not to log the processed data')

    parser.add_argument(
        '--add_state_tokens',
        action='store_true',
        default=True,
        help='Whether or not to add state tokens')

    parser.add_argument(
        '--log_dir',
        type=str,
        default='processed/avsd',
        help='Output directory')

    args = parser.parse_args()

    return args


if __name__ == '__main__':
    args = parse_args()
    split = args.split

    config = pyhocon.ConfigFactory.parse_file(
        'config/mst_mixer.conf')[args.model]
    config['expand_rnd'] = False
    config['debugging'] = False
    config['overfit_size'] = -1
    args.log_dir = os.path.join(args.log_dir, 'hist_with_{}_rounds'.format(config['n_history']))
    if args.log_dataset:
        log_dir = os.path.join(args.log_dir, split)
        if not os.path.isdir(log_dir):
            os.makedirs(log_dir)

        tokenizer = BartTokenizer.from_pretrained('facebook/bart-{}'.format(config['bart_size']))
        tokenizer.add_special_tokens({'additional_special_tokens': ADDITIONAL_SPECIAL_TOKENS})
        dialogs, features = get_dataset(config, split, tokenizer)
        pbar = tqdm(dialogs)
        pbar.set_description('[{}] Logging processed data'.format(split))
        counter = 0
        for dialog in pbar:
            vid = dialog['vid']
            his = dialog['history']
            cap = dialog['caption']
            ans = dialog['answer']

            if np.random.rand() < config['caption_drop_rate']:
                instance = build_input_from_segments(
                    cap, his, ans, tokenizer, add_state_tokens=args.add_state_tokens, drop_caption=True)
            else:
                instance = build_input_from_segments(
                    cap, his, ans, tokenizer, add_state_tokens=args.add_state_tokens, drop_caption=False)

            input_ids = torch.Tensor(instance["input_ids"]).long()
            lm_labels = torch.Tensor(instance["lm_labels"]).long()

            vgg = np.load(features["vggish"][vid])
            i3d_flow = np.load(features["i3d_flow"][vid])
            i3d_rgb = np.load(features["i3d_rgb"][vid])
            sam = np.load(features["sam"][vid])

            item = {
                'input_ids': input_ids,
                'lm_labels': lm_labels,
                'i3d_rgb': i3d_rgb,
                'i3d_flow': i3d_flow,
                'sam': sam,
                'vgg': vgg,
                'vid': vid
            }
            counter += 1
            pth = os.path.join(log_dir, str(counter) + '.pkl')
            with open(pth, 'wb') as f:
                pickle.dump(item, f, protocol=pickle.HIGHEST_PROTOCOL)
    else:
        avsd_dataset = AVSDDataset(config, 'val')
        avsd_dataloader = DataLoader(avsd_dataset, batch_size=4, shuffle=False, collate_fn=avsd_dataset.collate_fn)

        for i, data in enumerate(avsd_dataloader):
            print('{}/{}'.format(i, len(avsd_dataloader)))

    print('[INFO] Done...')
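The equally-spaced subsampling in ```AVSDDataset.collate_fn``` above is the only length normalization applied to the visual streams: every modality in a batch is cut down to the same ```min_length```. A minimal standalone sketch of the idea (the sequence length 53 is made up for illustration):

```python
# Minimal sketch of the frame subsampling used in AVSDDataset.collate_fn:
# pick min_length indices spread evenly over a variable-length sequence.
import numpy as np

feat = np.random.rand(53, 2048)   # e.g. one I3D-RGB stream (length varies)
min_length = 36                   # vis_feat_length in config/mst_mixer.conf
idx = np.round(np.linspace(0, feat.shape[0] - 1, min_length)).astype(int)
print(feat[idx, :].shape)         # (36, 2048)
```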
custom_datasets/nextqa.py (Normal file, 211 lines)
@@ -0,0 +1,211 @@
import os
import pandas as pd
import h5py
import json
import numpy as np
import torch
from torch.utils.data import Dataset
from transformers import BartTokenizer
from itertools import chain


ADDITIONAL_SPECIAL_TOKENS = [
    '<place_holder>', '<s0>', '<s1>', '<s2>', '<s3>', '<s4>', '<s5>']

SPECIAL_TOKENS_DICT = {
    'bos_token': '<s>',
    'eos_token': '</s>',
    'pad_token': '<pad>',
    'additional_special_tokens': ADDITIONAL_SPECIAL_TOKENS
}

S0_TOK = '<s0>'  # frame
S1_TOK = '<s1>'  # mot
S2_TOK = '<s2>'  # question


def load_file(file_name):
    annos = None
    if os.path.splitext(file_name)[-1] == '.csv':
        return pd.read_csv(file_name)
    with open(file_name, 'r') as fp:
        if os.path.splitext(file_name)[1] == '.txt':
            annos = fp.readlines()
            annos = [line.rstrip() for line in annos]
        if os.path.splitext(file_name)[1] == '.json':
            annos = json.load(fp)

    return annos


def tokenize(obj, tokenizer):
    if isinstance(obj, str):
        return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
    if isinstance(obj, dict):
        # NOTE: the tokenizer must be passed along in the recursive calls
        return dict((n, tokenize(o, tokenizer)) for n, o in obj.items())
    return list(tokenize(o, tokenizer) for o in obj)


class NextQADataset(Dataset):
    def __init__(self, config, split):

        super().__init__()
        self.config = config
        self.split = split
        self.bart_max_input_len = config['bart_max_input_len']
        self.bart_size = config['bart_size']
        self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-{}'.format(self.bart_size))
        self.vocab_size = self.tokenizer.vocab_size

        self.tokenizer.add_special_tokens({'additional_special_tokens': ADDITIONAL_SPECIAL_TOKENS})
        self.vocab_size += len(ADDITIONAL_SPECIAL_TOKENS)
        self.tokenizer.save_pretrained(os.path.join(self.config['log_dir'], 'bart_tokenizer'))

        sample_list_file = os.path.join(self.config['nextqa_root'], '{}.csv'.format(split))
        self.sample_list = load_file(sample_list_file)

        vid_feat_file = os.path.join(self.config['nextqa_vid_feat'], 'app_mot_{}.h5'.format(split))
        print('Load {}...'.format(vid_feat_file))
        self.frame_feats = {}
        self.mot_feats = {}
        with h5py.File(vid_feat_file, 'r') as fp:
            vids = fp['ids']
            feats = fp['feat']
            for vid, feat in zip(vids, feats):
                self.frame_feats[str(vid)] = feat[:, :2048]  # (16, 2048)
                self.mot_feats[str(vid)] = feat[:, 2048:]  # (16, 2048)

        if self.config['overfit_size'] > 0:
            self.sample_list = self.sample_list[:self.config['overfit_size']]

    def __len__(self):
        return len(self.sample_list)

    def get_video_feature(self, video_name):
        """Return the appearance and motion features of a video.

        :param video_name: id of the video
        :return: (app_feat, mot_feat) float32 tensors, each of shape (16, 2048)
        """

        app_feat = self.frame_feats[video_name]
        app_feat = torch.from_numpy(app_feat).type(torch.float32)

        mot_feat = self.mot_feats[video_name]
        mot_feat = torch.from_numpy(mot_feat).type(torch.float32)

        return app_feat, mot_feat

    def __getitem__(self, idx):
        cur_sample = self.sample_list.loc[idx]
        video_name, ques, ans, qid = str(cur_sample['video']), str(cur_sample['question']),\
                                     str(cur_sample['answer']), str(cur_sample['qid'])

        input_ids = tokenize(ques, self.tokenizer)
        lm_labels = tokenize(ans, self.tokenizer)

        app_feat, mot_feat = self.get_video_feature(video_name)

        bos, eos, ques_state = self.tokenizer.convert_tokens_to_ids(['<s>', '</s>', '<s2>'])

        # Add state tokens
        input_ids.insert(0, ques_state)
        lm_labels.append(eos)
        question_interval = [0, len(input_ids)]

        input_ids = torch.Tensor(input_ids).long()
        lm_labels = torch.Tensor(lm_labels).long()

        return input_ids, lm_labels, app_feat, mot_feat, question_interval, video_name

    def padding(self, seq, pad_token, max_len=None):
        if max_len is None:
            max_len = max([i.size(0) for i in seq])
        if len(seq[0].size()) == 1:
            result = torch.ones((len(seq), max_len)).long() * pad_token
        else:
            result = torch.ones((len(seq), max_len, seq[0].size(-1))).float()
        for i in range(len(seq)):
            result[i, :seq[i].size(0)] = seq[i]
        return result

    def collate_fn(self, batch):
        input_ids_list, lm_labels_list, app_feat_list, mot_feat_list, question_interval_list, vid_ids_list = [], [], [], [], [], []
        for i in batch:
            input_ids_list.append(i[0])
            lm_labels_list.append(i[1])
            app_feat_list.append(i[2])
            mot_feat_list.append(i[3])
            question_interval_list.append(i[4])
            vid_ids_list.append(i[5])

        app_feats = torch.stack(app_feat_list, dim=0).float()
        mot_feats = torch.stack(mot_feat_list, dim=0).float()

        question_intervals = np.array(question_interval_list)

        pad_token, app_sep, mot_sep, ph_token = self.tokenizer.convert_tokens_to_ids(
            ['<pad>', '<s0>', '<s1>', '<place_holder>'])

        # None of the visual features will be masked because we do not perform any padding on them
        video_mask = torch.ones((len(batch), 16*2 + 2)) == 1  # NOTE *2: 2 modalities | +2: the state tokens | each modality has length 16
        # Now we create a dummy input for the video tokens (its sole purpose is to reserve the spots of the separators)
        dummy = torch.ones((len(batch), 16)) * ph_token
        video_place_holder_ids = torch.cat(
            [torch.ones((len(batch), 1)) * app_sep, dummy,
             torch.ones((len(batch), 1)) * mot_sep, dummy,
             ], dim=-1).long()

        input_ids = self.padding(input_ids_list, pad_token)
        lm_labels = self.padding(lm_labels_list, -100)
        text_mask = input_ids != pad_token
        input_mask = torch.cat([video_mask, text_mask], dim=1)

        # Now we get the intervals of the visual input tokens
        # Here the intervals do not change across the batch dimension
        app_interval = [0, 16 + 1]  # the last token is not part of this modality
        mot_interval = [16 + 1, 2 * 16 + 2]
        vis_state_vector_idx = [app_interval[0], mot_interval[0]]

        # Adapt the question intervals -- shifted to the right by the visual input length
        question_intervals += 2 * 16 + 2
        question_intervals = question_intervals.tolist()

        question_state_vector_idx = [x[0] for x in question_intervals]

        batch = {
            'input_ids': input_ids,
            'video_place_holder_ids': video_place_holder_ids,
            'app_feats': app_feats,
            'mot_feats': mot_feats,
            'lm_labels': lm_labels,
            'input_mask': input_mask,
            'app_interval': app_interval,
            'mot_interval': mot_interval,
            'question_intervals': question_intervals,
            'vis_state_vector_idx': vis_state_vector_idx,
            'question_state_vector_idx': question_state_vector_idx
        }
        return batch


def get_dataset(config, split):

    bart_max_input_len = config['bart_max_input_len']
    bart_size = config['bart_size']

    sample_list_file = os.path.join(config['nextqa_root'], '{}.csv'.format(split))
    sample_list = load_file(sample_list_file)

    vid_feat_file = os.path.join(config['nextqa_vid_feat'], 'app_mot_{}.h5'.format(split))
    print('Load {}...'.format(vid_feat_file))
    app_feats = {}
    mot_feats = {}
    with h5py.File(vid_feat_file, 'r') as fp:
        vids = fp['ids']
        feats = fp['feat']
        for vid, feat in zip(vids, feats):
            app_feats[str(vid)] = feat[:, :2048]  # (16, 2048)
            mot_feats[str(vid)] = feat[:, 2048:]  # (16, 2048)

    return sample_list, app_feats, mot_feats
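A minimal usage sketch mirroring the AVSD debug path above (assumptions: run from the repository root with the bundled NExT-QA data in place; the import path follows the package layout of this commit):

```python
# Minimal sketch: batch NExT-QA samples with the dataset's own collate_fn.
import pyhocon
from torch.utils.data import DataLoader
from custom_datasets.nextqa import NextQADataset

config = pyhocon.ConfigFactory.parse_file('config/mst_mixer.conf')['mixer']
dataset = NextQADataset(config, 'val')
loader = DataLoader(dataset, batch_size=4, shuffle=False,
                    collate_fn=dataset.collate_fn)
batch = next(iter(loader))
print(batch['input_ids'].shape, batch['app_feats'].shape)  # app_feats: (4, 16, 2048)
```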
179
custom_datasets/segment.py
Normal file
179
custom_datasets/segment.py
Normal file
|
@ -0,0 +1,179 @@
|
||||||
|
from segment_anything import SamPredictor, SamAutomaticMaskGenerator, sam_model_registry
from tqdm import tqdm
from argparse import ArgumentParser
import pickle
import cv2
import os
import torch
import numpy as np


def parse_args():
    parser = ArgumentParser()

    parser.add_argument(
        '--sam_ckpt',
        type=str,
        help='SAM checkpoint to be used'
    )

    parser.add_argument(
        '--avsd_root',
        type=str,
        help='Directory where the individual AVSD frames are located'
    )

    parser.add_argument(
        '--crop_root',
        type=str,
        help='Directory where the individual crops (objects) will be saved'
    )

    parser.add_argument(
        '--embed_root',
        type=str,
        help='Directory where the individual embeddings will be saved'
    )

    parser.add_argument(
        '--mode',
        type=str,
        choices=['segment', 'embed'],
        help='segment: segment the image into regions | embed: embed the image crops detected during segmentation'
    )

    parser.add_argument(
        '--start',
        type=int,
        default=0,
        help='Start index of the partition'
    )

    parser.add_argument(
        '--end',
        type=int,
        default=1968,
        help='End index of the partition'
    )

    args = parser.parse_args()
    return args


def partition_ids(avsd_ids, start, end):
    avsd_ids.sort()
    assert start < end
    assert start >= 0 and end <= len(avsd_ids)
    avsd_ids_partition = avsd_ids[start:end]
    return avsd_ids_partition


def get_middle_frames(avsd_ids_partition, avsd_root):
    pbar = tqdm(avsd_ids_partition)
    pbar.set_description('[INFO] Preparing frames of {} videos'.format(len(avsd_ids_partition)))
    path_list = []
    for avsd_id in pbar:
        frames = os.listdir(os.path.join(avsd_root, avsd_id))
        # test frames use '_' as the index separator, train/val use '-'
        if 'test' in avsd_root:
            frames.sort(key=lambda f: int(f.split('_')[-1].split('.')[0]))
        else:
            frames.sort(key=lambda f: int(f.split('-')[-1].split('.')[0]))
        middle_frame = frames[int(len(frames) / 2)]
        middle_frame = os.path.join(avsd_root, avsd_id, middle_frame)
        path_list.append(middle_frame)
    return path_list


def segment_images(sam, path_list, crop_root):
    mask_generator = SamAutomaticMaskGenerator(sam)
    pbar = tqdm(path_list)
    pbar.set_description('Detecting Objects')
    for pth in pbar:
        vid_id = pth.split('/')[-2]
        crop_dir = os.path.join(crop_root, vid_id)
        if not os.path.isdir(crop_dir):
            os.makedirs(crop_dir)

        image = cv2.imread(pth)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        masks = mask_generator.generate(image)
        # keep at most the 36 most stable detections
        masks.sort(key=lambda e: e['stability_score'], reverse=True)
        if len(masks) > 36:
            masks = masks[:36]
        for i, mask in enumerate(masks):
            crop = image[
                int(mask['bbox'][1]):int(mask['bbox'][1] + mask['bbox'][3] + 1),
                int(mask['bbox'][0]):int(mask['bbox'][0] + mask['bbox'][2] + 1),
                :
            ]
            crop_flipped = cv2.flip(crop, 1)  # Horizontal flip
            cv2.imwrite(os.path.join(crop_dir, f'obj_{i}.jpg'), crop)
            cv2.imwrite(os.path.join(crop_dir, f'obj_{i}_flipped.jpg'), crop_flipped)

    print('[INFO] Done...')


def embed_objects(sam, crop_ids, crop_root, embed_root):
    predictor = SamPredictor(sam)
    pbar = tqdm(crop_ids)
    pbar.set_description('Embedding Objects')
    for vid_id in pbar:
        embeds = []
        crop_dir = os.path.join(crop_root, vid_id)
        crop_paths = list(map(lambda p: os.path.join(crop_dir, p), os.listdir(crop_dir)))
        crop_paths = list(filter(lambda p: 'flipped' not in p, crop_paths))
        crop_paths.sort(key=lambda p: int(p.split('_')[-1].split('.')[0]))
        for cp in crop_paths:
            crop = cv2.imread(cp)
            crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
            predictor.set_image(crop)
            embed_crop = predictor.get_image_embedding()
            embed_crop = embed_crop.mean(-1).mean(-1)  # spatial mean pool: (1, 256, 64, 64) -> (1, 256)

            crop_flipped = cv2.flip(crop, 1)
            predictor.set_image(crop_flipped)
            embed_crop_flipped = predictor.get_image_embedding()
            embed_crop_flipped = embed_crop_flipped.mean(-1).mean(-1)

            embed = torch.cat((embed_crop, embed_crop_flipped), dim=-1)
            # embed = embed.copy().cpu()
            embeds.append(embed)

        embeds = torch.cat(embeds, 0).cpu().numpy()
        np.save(os.path.join(embed_root, f'{vid_id}.npy'), embeds)

    print('[INFO] Done...')


def segment(args, sam):
    avsd_ids = os.listdir(args.avsd_root)
    avsd_ids.sort()
    avsd_ids_partition = partition_ids(avsd_ids, args.start, args.end)
    path_list = get_middle_frames(avsd_ids_partition, args.avsd_root)

    if not os.path.isdir(args.crop_root):
        os.makedirs(args.crop_root)
    segment_images(sam, path_list, args.crop_root)


def embed(args, sam):
    crop_ids = os.listdir(args.crop_root)
    crop_ids.sort()
    crop_ids_partition = partition_ids(crop_ids, args.start, args.end)
    if not os.path.isdir(args.embed_root):
        os.makedirs(args.embed_root)
    embed_objects(sam, crop_ids_partition, args.crop_root, args.embed_root)


if __name__ == '__main__':
    args = parse_args()
    sam = sam_model_registry['vit_h'](
        checkpoint=args.sam_ckpt)
    device = 'cuda'
    sam.to(device=device)

    assert args.mode in ['segment', 'embed']
    if args.mode == 'segment':
        segment(args, sam)
    else:
        embed(args, sam)
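As a sanity check on what `embed_objects` writes out: `SamPredictor.get_image_embedding()` returns a `(1, 256, 64, 64)` feature map for the ViT-H encoder, so after spatial mean pooling and concatenating the horizontally flipped view, each object crop contributes one 512-dim row. A small sketch (the path is a placeholder, not a repository default):

```python
import numpy as np

# Placeholder path: one .npy file is produced per video by embed_objects above.
embeds = np.load('sam_embeds/ABC123.npy')
# One row per detected object crop (at most 36); 256 dims from the crop plus
# 256 dims from its horizontal flip.
print(embeds.shape)  # e.g. (36, 512)
```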
0
features/.gitkeep
Normal file
63
generate_parallel_avsd.sh
Executable file
@@ -0,0 +1,63 @@
export MODEL=$1
export TAG=$2
export MODE=$3
export EVAL_DIR=$4
export DSTC=$5


# >>> conda initialize >>>
# !! Contents within this block are managed by 'conda init' !!
__conda_setup="$('/opt/anaconda3/bin/conda' 'shell.bash' 'hook' 2> /dev/null)"
if [ $? -eq 0 ]; then
    eval "$__conda_setup"
else
    if [ -f "/opt/anaconda3/etc/profile.d/conda.sh" ]; then
        . "/opt/anaconda3/etc/profile.d/conda.sh"
    else
        export PATH="/opt/anaconda3/bin:$PATH"
    fi
fi
unset __conda_setup
# <<< conda initialize <<<

conda activate mst_mixer

if [ $DSTC -eq 10 ]; then
    export CUDA_VISIBLE_DEVICES=0; python main.py --start_idx_gen 0000 --end_idx_gen 0112 --gen_subset_num 01 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
    export CUDA_VISIBLE_DEVICES=1; python main.py --start_idx_gen 0112 --end_idx_gen 0224 --gen_subset_num 02 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
    export CUDA_VISIBLE_DEVICES=2; python main.py --start_idx_gen 0224 --end_idx_gen 0336 --gen_subset_num 03 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
    export CUDA_VISIBLE_DEVICES=3; python main.py --start_idx_gen 0336 --end_idx_gen 0448 --gen_subset_num 04 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
    export CUDA_VISIBLE_DEVICES=4; python main.py --start_idx_gen 0448 --end_idx_gen 0560 --gen_subset_num 05 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
    export CUDA_VISIBLE_DEVICES=5; python main.py --start_idx_gen 0560 --end_idx_gen 0672 --gen_subset_num 06 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
    export CUDA_VISIBLE_DEVICES=6; python main.py --start_idx_gen 0672 --end_idx_gen 0784 --gen_subset_num 07 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
    export CUDA_VISIBLE_DEVICES=7; python main.py --start_idx_gen 0784 --end_idx_gen 0896 --gen_subset_num 08 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
    export CUDA_VISIBLE_DEVICES=0; python main.py --start_idx_gen 0896 --end_idx_gen 1008 --gen_subset_num 09 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
    export CUDA_VISIBLE_DEVICES=1; python main.py --start_idx_gen 1008 --end_idx_gen 1120 --gen_subset_num 10 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
    export CUDA_VISIBLE_DEVICES=2; python main.py --start_idx_gen 1120 --end_idx_gen 1232 --gen_subset_num 11 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
    export CUDA_VISIBLE_DEVICES=3; python main.py --start_idx_gen 1232 --end_idx_gen 1344 --gen_subset_num 12 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
    export CUDA_VISIBLE_DEVICES=4; python main.py --start_idx_gen 1344 --end_idx_gen 1456 --gen_subset_num 13 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
    export CUDA_VISIBLE_DEVICES=5; python main.py --start_idx_gen 1456 --end_idx_gen 1568 --gen_subset_num 14 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
    export CUDA_VISIBLE_DEVICES=6; python main.py --start_idx_gen 1568 --end_idx_gen 1680 --gen_subset_num 15 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
    export CUDA_VISIBLE_DEVICES=7; python main.py --start_idx_gen 1680 --end_idx_gen 1804 --gen_subset_num 16 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
else
    export CUDA_VISIBLE_DEVICES=0; python main.py --start_idx_gen 0000 --end_idx_gen 0107 --gen_subset_num 01 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
    export CUDA_VISIBLE_DEVICES=1; python main.py --start_idx_gen 0107 --end_idx_gen 0214 --gen_subset_num 02 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
    export CUDA_VISIBLE_DEVICES=2; python main.py --start_idx_gen 0214 --end_idx_gen 0321 --gen_subset_num 03 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
    export CUDA_VISIBLE_DEVICES=3; python main.py --start_idx_gen 0321 --end_idx_gen 0428 --gen_subset_num 04 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
    export CUDA_VISIBLE_DEVICES=4; python main.py --start_idx_gen 0428 --end_idx_gen 0535 --gen_subset_num 05 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
    export CUDA_VISIBLE_DEVICES=5; python main.py --start_idx_gen 0535 --end_idx_gen 0642 --gen_subset_num 06 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
    export CUDA_VISIBLE_DEVICES=6; python main.py --start_idx_gen 0642 --end_idx_gen 0749 --gen_subset_num 07 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
    export CUDA_VISIBLE_DEVICES=7; python main.py --start_idx_gen 0749 --end_idx_gen 0856 --gen_subset_num 08 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
    export CUDA_VISIBLE_DEVICES=0; python main.py --start_idx_gen 0856 --end_idx_gen 0963 --gen_subset_num 09 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
    export CUDA_VISIBLE_DEVICES=1; python main.py --start_idx_gen 0963 --end_idx_gen 1070 --gen_subset_num 10 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
    export CUDA_VISIBLE_DEVICES=2; python main.py --start_idx_gen 1070 --end_idx_gen 1177 --gen_subset_num 11 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
    export CUDA_VISIBLE_DEVICES=3; python main.py --start_idx_gen 1177 --end_idx_gen 1284 --gen_subset_num 12 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
    export CUDA_VISIBLE_DEVICES=4; python main.py --start_idx_gen 1284 --end_idx_gen 1391 --gen_subset_num 13 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
    export CUDA_VISIBLE_DEVICES=5; python main.py --start_idx_gen 1391 --end_idx_gen 1498 --gen_subset_num 14 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
    export CUDA_VISIBLE_DEVICES=6; python main.py --start_idx_gen 1498 --end_idx_gen 1605 --gen_subset_num 15 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
    export CUDA_VISIBLE_DEVICES=7; python main.py --start_idx_gen 1605 --end_idx_gen 1710 --gen_subset_num 16 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
fi

wait

python merge_pred_avsd.py --dstc $DSTC
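The hard-coded index ranges above amount to a fixed-width split of the test set across 16 worker processes (two per GPU), with the remainder folded into the last chunk. A sketch of the arithmetic, assuming 1804 DSTC10 test dialogs with width 112 (the DSTC7/8 branch uses width 107 over 1710 dialogs in the same way):

```python
def make_partitions(num_items, num_workers, width):
    """Fixed-width chunks; the remainder is folded into the last chunk."""
    bounds = [(i * width, (i + 1) * width) for i in range(num_workers)]
    bounds[-1] = (bounds[-1][0], num_items)
    return bounds

# DSTC10: 16 chunks of width 112, the last one stretched to 1804.
print(make_partitions(1804, 16, 112)[:2])  # [(0, 112), (112, 224)]
print(make_partitions(1804, 16, 112)[-1])  # (1680, 1804)
```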
67
generate_parallel_nextqa.sh
Executable file
@@ -0,0 +1,67 @@
export MODEL=$1
export TAG=$2
export MODE=$3
export EVAL_DIR=$4
export DSTC=$5


# >>> conda initialize >>>
# !! Contents within this block are managed by 'conda init' !!
__conda_setup="$('/opt/anaconda3/bin/conda' 'shell.bash' 'hook' 2> /dev/null)"
if [ $? -eq 0 ]; then
    eval "$__conda_setup"
else
    if [ -f "/opt/anaconda3/etc/profile.d/conda.sh" ]; then
        . "/opt/anaconda3/etc/profile.d/conda.sh"
    else
        export PATH="/opt/anaconda3/bin:$PATH"
    fi
fi
unset __conda_setup
# <<< conda initialize <<<

conda activate mst_mixer

export CUDA_VISIBLE_DEVICES=0; python main.py --start_idx_gen 0000 --end_idx_gen 0285 --gen_subset_num 01 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=0; python main.py --start_idx_gen 0285 --end_idx_gen 0570 --gen_subset_num 02 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=0; python main.py --start_idx_gen 0570 --end_idx_gen 0855 --gen_subset_num 03 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=0; python main.py --start_idx_gen 0855 --end_idx_gen 1140 --gen_subset_num 04 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \

export CUDA_VISIBLE_DEVICES=1; python main.py --start_idx_gen 1140 --end_idx_gen 1425 --gen_subset_num 05 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=1; python main.py --start_idx_gen 1425 --end_idx_gen 1710 --gen_subset_num 06 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=1; python main.py --start_idx_gen 1710 --end_idx_gen 1995 --gen_subset_num 07 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=1; python main.py --start_idx_gen 1995 --end_idx_gen 2280 --gen_subset_num 08 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \

export CUDA_VISIBLE_DEVICES=2; python main.py --start_idx_gen 2280 --end_idx_gen 2565 --gen_subset_num 09 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=2; python main.py --start_idx_gen 2565 --end_idx_gen 2850 --gen_subset_num 10 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=2; python main.py --start_idx_gen 2850 --end_idx_gen 3135 --gen_subset_num 11 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=2; python main.py --start_idx_gen 3135 --end_idx_gen 3420 --gen_subset_num 12 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \

export CUDA_VISIBLE_DEVICES=3; python main.py --start_idx_gen 3420 --end_idx_gen 3705 --gen_subset_num 13 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=3; python main.py --start_idx_gen 3705 --end_idx_gen 3990 --gen_subset_num 14 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=3; python main.py --start_idx_gen 3990 --end_idx_gen 4275 --gen_subset_num 15 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=3; python main.py --start_idx_gen 4275 --end_idx_gen 4560 --gen_subset_num 16 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \

export CUDA_VISIBLE_DEVICES=4; python main.py --start_idx_gen 4560 --end_idx_gen 4845 --gen_subset_num 17 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=4; python main.py --start_idx_gen 4845 --end_idx_gen 5130 --gen_subset_num 18 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=4; python main.py --start_idx_gen 5130 --end_idx_gen 5415 --gen_subset_num 19 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=4; python main.py --start_idx_gen 5415 --end_idx_gen 5700 --gen_subset_num 20 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \

export CUDA_VISIBLE_DEVICES=5; python main.py --start_idx_gen 5700 --end_idx_gen 5985 --gen_subset_num 21 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=5; python main.py --start_idx_gen 5985 --end_idx_gen 6270 --gen_subset_num 22 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=5; python main.py --start_idx_gen 6270 --end_idx_gen 6555 --gen_subset_num 23 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=5; python main.py --start_idx_gen 6555 --end_idx_gen 6840 --gen_subset_num 24 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \

export CUDA_VISIBLE_DEVICES=6; python main.py --start_idx_gen 6840 --end_idx_gen 7125 --gen_subset_num 25 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=6; python main.py --start_idx_gen 7125 --end_idx_gen 7410 --gen_subset_num 26 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=6; python main.py --start_idx_gen 7410 --end_idx_gen 7695 --gen_subset_num 27 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=6; python main.py --start_idx_gen 7695 --end_idx_gen 7980 --gen_subset_num 28 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \

export CUDA_VISIBLE_DEVICES=7; python main.py --start_idx_gen 7980 --end_idx_gen 8265 --gen_subset_num 29 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=7; python main.py --start_idx_gen 8265 --end_idx_gen 8550 --gen_subset_num 30 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=7; python main.py --start_idx_gen 8550 --end_idx_gen 8835 --gen_subset_num 31 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=7; python main.py --start_idx_gen 8835 --end_idx_gen 9178 --gen_subset_num 32 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \

wait

python merge_pred_nextqa.py
153
init_utils.py
Normal file
@@ -0,0 +1,153 @@
import os
import torch
import random
import pyhocon
import datetime
import json
import subprocess
import itertools
import glob
import glog as log
import sys
import re
from os import path as osp
import numpy as np

from custom_datasets.avsd import AVSDDataset
from custom_datasets.nextqa import NextQADataset

from runners.runner_avsd import AVSDRunner
from runners.runner_nextqa import NEXTQARunner


def load_runner(config, tokenizer, vocab_size):
    if config['task'] == 'avsd':
        return AVSDRunner(config, tokenizer, vocab_size)
    elif config['task'] == 'nextqa':
        return NEXTQARunner(config, tokenizer, vocab_size)
    else:
        raise ValueError(f"Unknown task: {config['task']}")


def load_dataset(config):
    if config['task'] == 'avsd':
        dataset = AVSDDataset(config, 'train')
        dataset_eval = AVSDDataset(config, 'val')
    elif config['task'] == 'nextqa':
        dataset = NextQADataset(config, 'train')
        dataset_eval = NextQADataset(config, 'val')
    else:
        raise ValueError(f"Unknown task: {config['task']}")
    return dataset, dataset_eval


def set_random_seed(random_seed):
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    random.seed(random_seed)
    np.random.seed(random_seed)


def copy_file_to_log(log_dir):
    dirs_to_cp = ['.', 'config', 'datasets', 'runners', 'models']
    files_to_cp = ['*.py', '*.json', '*.sh', '*.conf']
    for dir_name in dirs_to_cp:
        dir_name = osp.join(log_dir, 'code', dir_name)
        if not osp.exists(dir_name):
            os.makedirs(dir_name)
    for dir_name, file_name in itertools.product(dirs_to_cp, files_to_cp):
        filename = osp.join(dir_name, file_name)
        if len(glob.glob(filename)) > 0:
            os.system(f'cp {filename} {osp.join(log_dir, "code", dir_name)}')
    log.info(f'Files copied to {osp.join(log_dir, "code")}')


def set_log_file(fname, file_only=False):
    # if fname already exists, find all log files under the log dir,
    # and name the current log file with a new number
    if osp.exists(fname):
        prefix, suffix = osp.splitext(fname)
        log_files = glob.glob(prefix + '*' + suffix)
        count = 0
        for log_file in log_files:
            num = re.search(r'(\d+)', log_file)
            if num is not None:
                num = int(num.group(0))
                count = max(num, count)
        fname = fname.replace(suffix, str(count + 1) + suffix)
    # set log file
    # simple tricks for duplicating the logging destination in the logging module, such as:
    # logging.getLogger().addHandler(logging.FileHandler(filename))
    # do NOT work well here, because Python traceback messages (not sent via the logging module) would not reach the file.
    # The following solution (copied from https://stackoverflow.com/questions/616645) is a little bit
    # more complicated, but it simulates exactly the "tee" command in a Linux shell and redirects everything.
    if file_only:
        # we only output messages to the file, and stdout/stderr receive nothing.
        # this feature is designed for executing the script via ssh:
        # since ssh has a windowing kind of flow control, i.e., if the controller does not read data from a
        # ssh channel and its buffer fills up, the execution machine will not be able to write anything into the
        # channel and the process will be set to sleeping (S) status until someone reads all data from the channel.
        # this is not desired since we do not want to read stdout/stderr from the controller machine.
        # so, here we use a simple solution: disable output to stdout/stderr and only output messages to the log file.
        log.logger.handlers[0].stream = log.handler.stream = sys.stdout = sys.stderr = f = open(fname, 'w', buffering=1)
    else:
        # we output messages to both the file and stdout/stderr
        tee = subprocess.Popen(['tee', fname], stdin=subprocess.PIPE)
        os.dup2(tee.stdin.fileno(), sys.stdout.fileno())
        os.dup2(tee.stdin.fileno(), sys.stderr.fileno())


def initialize_from_env(model, mode, model_type, eval_dir, tag=''):
    if "GPU" in os.environ:
        os.environ["CUDA_VISIBLE_DEVICES"] = os.environ['GPU']
    if mode in ['train', 'debug']:
        config = pyhocon.ConfigFactory.parse_file(f"config/{model_type}.conf")[model]
    else:
        path_config = os.path.join(eval_dir, 'code', f"config/{model_type}.conf")
        config = pyhocon.ConfigFactory.parse_file(path_config)[model]
        config['log_dir'] = eval_dir

    config['model_type'] = model_type
    if "CUDA_VISIBLE_DEVICES" in os.environ:
        config['num_gpus'] = len(os.environ["CUDA_VISIBLE_DEVICES"].split(','))
        # multi-gpu setting
        if config['num_gpus'] > 1:
            os.environ['MASTER_ADDR'] = 'localhost'
            os.environ["MASTER_PORT"] = str(config['master_port'])

    if mode == 'debug':
        model += '_debug'

    if tag:
        model += '-' + tag
    if mode != 'generate':
        config["log_dir"] = os.path.join(config["log_dir"], model)
        if not os.path.exists(config["log_dir"]):
            os.makedirs(config["log_dir"])

    config['timestamp'] = datetime.datetime.now().strftime('%m%d-%H%M%S')

    # Choose the correct config file and add the BART json file to the config
    if mode in ['train', 'debug']:
        config['bart_config'] = config['{}_bart_{}_config'.format(
            config['task'], config['bart_size'])]
    else:
        config['bart_config'] = os.path.join(eval_dir, 'code', 'config/{}_bart_{}.json'.format(
            config['task'], config['bart_size']))

    config['bart_config_json'] = json.load(open(config['bart_config'], 'r'))

    config['overfit'] = config['overfit_size'] > 0
    return config


def set_training_steps(config, num_samples):
    if config['parallel'] and config['dp_type'] == 'dp':
        config['num_iter_per_epoch'] = int(np.ceil(num_samples / config['batch_size']))
    else:
        config['num_iter_per_epoch'] = int(np.ceil(num_samples / (config['batch_size'] * config['num_gpus'])))
    if 'train_steps' not in config:
        config['train_steps'] = config['num_iter_per_epoch'] * config['num_epochs']
    if 'warmup_steps' not in config:
        config['warmup_steps'] = int(config['train_steps'] * config['warmup_ratio'])
    return config
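As a concrete illustration of the step arithmetic in `set_training_steps` (all numbers here are hypothetical, not repository defaults):

```python
# Hypothetical config: only the keys that set_training_steps() reads.
config = {
    'parallel': True, 'dp_type': 'ddp', 'num_gpus': 8,
    'batch_size': 16, 'num_epochs': 12, 'warmup_ratio': 0.1,
}
config = set_training_steps(config, num_samples=76590)
# ddp: each of the 8 GPUs consumes batch_size samples per step, so
# ceil(76590 / (16 * 8)) = 599 iterations per epoch.
print(config['num_iter_per_epoch'], config['train_steps'], config['warmup_steps'])
# 599 7188 718
```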
225
main.py
Normal file
@@ -0,0 +1,225 @@
from init_utils import (
    load_runner,
    load_dataset,
    set_random_seed,
    set_training_steps,
    initialize_from_env,
    set_log_file,
    copy_file_to_log
)
import os
import sys
import argparse
import pyhocon
import glog as log
import socket
import getpass

import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.distributed as dist
from transformers import BartTokenizer

from custom_datasets.avsd import get_dataset as get_avsd_dataset
from custom_datasets.nextqa import get_dataset as get_nextqa_dataset


parser = argparse.ArgumentParser(description='Main script for MST-MIXER')
parser.add_argument(
    '--model',
    type=str,
    default='mst_mixer/mixer',
    help='model name to train or test')

parser.add_argument(
    '--mode',
    type=str,
    default='train',
    help='train, generate or debug'
)

parser.add_argument(
    '--eval_dir',
    type=str,
    default='ckpt/avsd'
)

parser.add_argument(
    '--wandb_project',
    type=str,
    default='mst_mixer'
)

parser.add_argument(
    '--wandb_mode',
    type=str,
    default='offline',
    choices=['online', 'offline', 'disabled', 'run', 'dryrun']
)

parser.add_argument(
    '--tag',
    type=str,
    default='full_model',
    help='Tag to differentiate the models'
)

parser.add_argument(
    '--start_idx_gen',
    type=int,
    default=0,
    help='The start index for generation'
)

parser.add_argument(
    '--end_idx_gen',
    type=int,
    default=10,
    help='The end index for generation'
)

parser.add_argument(
    '--gen_subset_num',
    type=int,
    default=1,
    help='The index of the test split for generation'
)

parser.add_argument('--ssh', action='store_true',
                    help='whether or not we are executing the command via ssh. '
                         'If set to True, we will not log.info anything to screen and only redirect it to the log file')


def main(gpu, config, args):
    config['training'] = args.mode == 'train'
    config['debugging'] = args.mode == 'debug'
    config['generating'] = args.mode == 'generate'
    config['wandb_project'] = args.wandb_project
    config['wandb_mode'] = 'disabled'
    if config['training']:
        config['wandb_mode'] = args.wandb_mode

    # When generating, only use 1 GPU
    if config['generating']:
        assert config['num_gpus'] == 1, 'When generating, only use 1 GPU!'

    if config['parallel'] and config['dp_type'] != 'dp':
        config['rank'] = gpu
        dist.init_process_group(
            backend='nccl',
            # init_method='env://',
            world_size=config['num_gpus'],
            rank=gpu
        )
        config['display'] = gpu == 0
        torch.cuda.set_device(gpu)
    else:
        config['display'] = True
    if config['debugging'] or (config['parallel'] and config['dp_type'] != 'dp'):
        config['num_workers'] = 0

    # set logs
    if config['training']:
        log_file = os.path.join(config["log_dir"], f'{args.mode}.log')
        set_log_file(log_file, file_only=args.ssh)

    # print environment info
    if config['display']:
        log.info('Host: {}, user: {}, CUDA_VISIBLE_DEVICES: {}, cwd: {}'.format(
            socket.gethostname(), getpass.getuser(), os.environ.get('CUDA_VISIBLE_DEVICES', ''), os.getcwd()))
        log.info('Command line is: {}'.format(' '.join(sys.argv)))

        if config['parallel'] and config['dp_type'] != 'dp':
            log.info(f'World_size: {config["num_gpus"]}, cur rank: {config["rank"]}')
        log.info(f"Running experiment: {args.model}")
        log.info(f"Results saved to {config['log_dir']}")

    # initialization
    if config['display'] and config['training']:
        copy_file_to_log(config['log_dir'])
    set_random_seed(config['random_seed'])

    device = torch.device(f"cuda:{gpu}")
    if config["use_cpu"]:
        device = torch.device("cpu")
    config['device'] = device

    # prepare datasets (train and validation)
    dataset, dataset_eval = load_dataset(config)

    # set training steps
    if not config['generating'] or config['parallel']:
        config = set_training_steps(config, len(dataset))

    if config['display']:
        log.info(pyhocon.HOCONConverter.convert(config, "hocon"))

    # load runner
    runner = load_runner(config, dataset.tokenizer, dataset.vocab_size)

    # parallel
    if config['parallel']:
        if config['dp_type'] == 'ddp':
            torch.cuda.set_device(gpu)
            runner.model = runner.model.to(gpu)
            runner.model = nn.parallel.DistributedDataParallel(
                runner.model,
                device_ids=[gpu],
                output_device=gpu,
                find_unused_parameters=True
            )
        else:
            raise ValueError(f'Unrecognized dp_type: {config["dp_type"]}')

    if config['training'] or config['debugging']:
        ckpt_path = config.get('start_path', None)
        runner.load_ckpt(ckpt_path=ckpt_path)
        runner.train(dataset, dataset_eval)

    elif config['generating']:
        if config['loads_start_path']:
            runner.load_ckpt(config['start_ckpt_for_generating'])
        else:
            runner.load_ckpt_best()
        assert args.gen_subset_num > 0
        # Load the data
        if config['task'] == 'avsd':
            # load the saved tokenizer
            tokenizer = BartTokenizer.from_pretrained(os.path.join(config['log_dir'], 'bart_tokenizer'))
            test_dataset, _ = get_avsd_dataset(config, 'test', tokenizer)
            assert args.start_idx_gen >= 0 and args.end_idx_gen <= len(test_dataset) and args.start_idx_gen < args.end_idx_gen
            test_dataset = test_dataset[args.start_idx_gen:args.end_idx_gen]
            runner.generate(
                test_dataset, args.tag, tokenizer, gen_subset_num=args.gen_subset_num
            )

        elif config['task'] == 'nextqa':
            # load the saved tokenizer
            tokenizer = BartTokenizer.from_pretrained(os.path.join(config['log_dir'], 'bart_tokenizer'))
            test_dataset, app_feats, mot_feats = get_nextqa_dataset(config, 'test')
            assert args.start_idx_gen >= 0 and args.end_idx_gen <= len(test_dataset) and args.start_idx_gen < args.end_idx_gen
            test_dataset = test_dataset[args.start_idx_gen:args.end_idx_gen]
            runner.generate(
                test_dataset, app_feats, mot_feats, args.tag, tokenizer, args.start_idx_gen, args.end_idx_gen, gen_subset_num=args.gen_subset_num
            )
        else:
            raise ValueError(f'Unknown task: {config["task"]}')

    if config['parallel']:
        dist.destroy_process_group()


if __name__ == '__main__':
    args = parser.parse_args()

    # initialization
    model_type, model_name = args.model.split('/')
    config = initialize_from_env(model_name, args.mode, model_type, args.eval_dir, tag=args.tag)
    if config['num_gpus'] > 1:
        config['parallel'] = True
        mp.spawn(main, nprocs=config['num_gpus'], args=(config, args))
    else:
        config['parallel'] = False
        main(0, config, args)
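Taken together with init_utils.py above: `--model` is parsed as `<config-file>/<experiment-name>`, so the default `mst_mixer/mixer` reads the `mixer` entry of `config/mst_mixer.conf`. A training launch would then look like `python main.py --model mst_mixer/mixer --mode train --tag full_model` (an illustration built from the argparse defaults above, not a prescribed recipe), while generation is meant to be driven by the `generate_parallel_*.sh` wrappers, which fill in `--start_idx_gen`/`--end_idx_gen` and `--gen_subset_num` per worker process.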
61
merge_pred_avsd.py
Normal file
@@ -0,0 +1,61 @@
import os
import json
import argparse

parser = argparse.ArgumentParser(description='Main script for MST-MIXER')
parser.add_argument(
    '--dstc',
    type=int,
    default=8,
    choices=[7, 8, 10],
    help='DSTC challenge identifier')

args = parser.parse_args()

assert args.dstc in [7, 8, 10]
if args.dstc == 7:
    output_dir = 'output/dstc7'
    raw_data_path = 'raw_data/test_set4DSTC7-AVSD.json'

elif args.dstc == 8:
    output_dir = 'output/dstc8'
    raw_data_path = 'raw_data/test_set4DSTC8-AVSD.json'
else:
    output_dir = 'output/dstc10'
    raw_data_path = 'raw_data/test_set4DSTC10-AVSD.json'

with open(raw_data_path, 'r') as f:
    raw_dialogs = json.load(f)['dialogs']

file_paths = os.listdir(output_dir)
file_paths = list(filter(lambda f: 'part' in f, file_paths))
name = file_paths[0]
file_paths = list(map(lambda f: os.path.join(output_dir, f), file_paths))

dialogs = {}
for pth in file_paths:
    with open(pth, 'r') as f:
        data = json.load(f)

    for dialog in data['dialogs']:
        vid_id = dialog['image_id']
        dialogs[vid_id] = dialog
    # dialogs.extend(data['dialogs'])
    os.remove(pth)

# Now, re-establish the original order of the dialogs
res = []
for dialog in raw_dialogs:
    vid_id = dialog['image_id']
    res.append(dialogs[vid_id])

res = {
    'dialogs': res
}

name = "".join(name.split('-')[:-1]) + '.json'
output_path = os.path.join(output_dir, name)
with open(output_path, 'w') as f:
    json.dump(res, f, indent=4)

print('[INFO] Files merged and saved in {}'.format(output_path))
34
merge_pred_nextqa.py
Normal file
@@ -0,0 +1,34 @@
import os
import json
import argparse

parser = argparse.ArgumentParser(description='Main script for MST-MIXER')

args = parser.parse_args()

output_dir = 'output/nextqa'

file_paths = os.listdir(output_dir)
file_paths = list(filter(lambda f: 'part' in f, file_paths))
name = file_paths[0]
file_paths = list(map(lambda f: os.path.join(output_dir, f), file_paths))

results = {}
for pth in file_paths:
    with open(pth, 'r') as f:
        data = json.load(f)
    for video_id in data:
        if video_id not in results:
            results[video_id] = data[video_id]
        else:
            for qid in data[video_id]:
                if qid not in results[video_id]:
                    results[video_id][qid] = data[video_id][qid]
    os.remove(pth)

name = "".join(name.split('-')[:-1]) + '.json'
output_path = os.path.join(output_dir, name)
with open(output_path, 'w') as f:
    json.dump(results, f, indent=4)

print('[INFO] Files merged and saved in {}'.format(output_path))
BIN
misc/italy.png
Normal file
Binary file not shown. (new file, 3.9 KiB)
BIN
misc/mixer.png
Normal file
Binary file not shown. (new file, 4.9 MiB)
BIN
misc/teaser.png
Normal file
Binary file not shown. (new file, 197 KiB)
0
models/__init__.py
Normal file
1438
models/avsd_bart.py
Normal file
File diff suppressed because it is too large
801
models/gnns.py
Normal file
@@ -0,0 +1,801 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn.dense import DenseGATConv, DenseGCNConv, DenseSAGEConv
from torch.nn.parameter import Parameter
from typing import Optional, Tuple
from .utils import get_knn_graph
import torch_sparse


class BartAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `key_value_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.reshape(*proj_shape)
        value_states = value_states.reshape(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value
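A shape check for the attention block above, with hypothetical sizes (BART-large dimensions; not a repository config):

```python
import torch

attn = BartAttention(embed_dim=1024, num_heads=16, dropout=0.1)
x = torch.randn(2, 7, 1024)                      # (batch, seq, channels)
out, weights, past_kv = attn(x, output_attentions=True)
print(out.shape)       # torch.Size([2, 7, 1024])
print(weights.shape)   # torch.Size([2, 16, 7, 7]) per-head attention maps
print(past_kv)         # None: this is not a decoder layer, so no k/v cache
```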
class MLPModule(nn.Module):
    def __init__(self, d_in, d_hidden, d_out, num_layers=3, dropout=0.3, use_non_linear=False, use_batch_norm=False):
        super(MLPModule, self).__init__()

        self.use_batch_norm = use_batch_norm
        self.dropout = dropout

        self.fcs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()

        if num_layers == 1:
            self.fcs.append(nn.Linear(d_in, d_out))
        else:
            self.fcs.append(nn.Linear(d_in, d_hidden))
            self.batch_norms.append(nn.BatchNorm1d(d_hidden))
            for _ in range(num_layers - 2):
                self.fcs.append(nn.Linear(d_hidden, d_hidden))
                self.batch_norms.append(nn.BatchNorm1d(d_hidden))

            self.fcs.append(nn.Linear(d_hidden, d_out))

        self.act_fn = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.use_non_linear = use_non_linear

    def reset_parameters(self):
        for fc in self.fcs:
            fc.reset_parameters()
        for bn in self.batch_norms:
            bn.reset_parameters()

    def forward(self, X):
        for fc, bn in zip(self.fcs[:-1], self.batch_norms):
            X = fc(X)
            X = self.act_fn(X)
            if self.use_batch_norm:
                # BatchNorm1d expects (batch, channels, length) for 3D inputs
                if X.dim() > 2:
                    X = X.transpose(1, 2)
                X = bn(X)
                if X.dim() > 2:
                    X = X.transpose(1, 2)
            X = self.dropout(X)
        X = self.fcs[-1](X)
        return X


class GATModule(nn.Module):
    def __init__(self, d_in, d_hidden, d_out, num_layers=3, dropout=0.3, concat=True, heads=2, use_non_linear=False, use_batch_norm=False):
        super(GATModule, self).__init__()
        self.gnns = nn.ModuleList()
        if concat:
            # heads are concatenated, so shrink per-head widths to keep the output size
            d_hidden = d_hidden // heads
            d_out = d_out // heads

        self.gnns.append(DenseGATConv(d_in, d_hidden, heads=heads, concat=concat, dropout=dropout))

        self.batch_norms = nn.ModuleList()
        self.batch_norms.append(nn.BatchNorm1d(d_hidden * heads if concat else d_hidden))

        for _ in range(num_layers - 2):
            self.gnns.append(DenseGATConv(
                d_hidden * heads if concat else d_hidden, d_hidden,
                heads=heads,
                concat=concat,
                dropout=dropout)
            )
            self.batch_norms.append(nn.BatchNorm1d(d_hidden * heads if concat else d_hidden))

        self.gnns.append(DenseGATConv(
            d_hidden * heads if concat else d_hidden, d_out,
            heads=heads,
            concat=concat,
            dropout=dropout)
        )

        self.dropout = nn.Dropout(dropout)
        self.non_linear = nn.GELU()
        self.use_batch_norm = use_batch_norm
        self.use_non_linear = use_non_linear

    def reset_parameters(self):
        for gnn in self.gnns:
            gnn.reset_parameters()
        for batch_norm in self.batch_norms:
            batch_norm.reset_parameters()

    def forward(self, X, A):
        Z = self.dropout(X)
        for i in range(len(self.gnns) - 1):
            Z = self.gnns[i](Z, A)
            if self.use_batch_norm:
                Z = Z.transpose(1, 2)
                Z = self.batch_norms[i](Z)
                Z = Z.transpose(1, 2)
            if self.use_non_linear:
                Z = self.non_linear(Z)
            Z = self.dropout(Z)
        Z = self.gnns[-1](Z, A)
        return Z


class GCNModule(nn.Module):
    def __init__(self, d_in, d_hidden, d_out, num_layers=3, dropout=0.3, use_non_linear=False, use_batch_norm=False):
        super(GCNModule, self).__init__()
        self.gnns = nn.ModuleList()

        self.gnns.append(DenseGCNConv(d_in, d_hidden))

        self.batch_norms = nn.ModuleList()
        self.batch_norms.append(nn.BatchNorm1d(d_hidden))

        for _ in range(num_layers - 2):
            self.gnns.append(DenseGCNConv(
                d_hidden, d_hidden)
            )
            self.batch_norms.append(nn.BatchNorm1d(d_hidden))

        self.gnns.append(DenseGCNConv(
            d_hidden, d_out)
        )

        self.dropout = nn.Dropout(dropout)
        self.non_linear = nn.GELU()
        self.use_batch_norm = use_batch_norm
        self.use_non_linear = use_non_linear

    def reset_parameters(self):
        for gnn in self.gnns:
            gnn.reset_parameters()
        for batch_norm in self.batch_norms:
            batch_norm.reset_parameters()

    def forward(self, X, A):
        Z = self.dropout(X)
        for i in range(len(self.gnns) - 1):
            Z = self.gnns[i](Z, A)
            if self.use_batch_norm:
                Z = Z.transpose(1, 2)
                Z = self.batch_norms[i](Z)
                Z = Z.transpose(1, 2)
            if self.use_non_linear:
                Z = self.non_linear(Z)
            Z = self.dropout(Z)
        Z = self.gnns[-1](Z, A)
        return Z


class SAGEModule(nn.Module):
    def __init__(self, d_in, d_hidden, d_out, num_layers=3, dropout=0.3, use_non_linear=False, use_batch_norm=False):
        super(SAGEModule, self).__init__()
        self.gnns = nn.ModuleList()

        self.gnns.append(DenseSAGEConv(d_in, d_hidden))

        self.batch_norms = nn.ModuleList()
        self.batch_norms.append(nn.BatchNorm1d(d_hidden))

        for _ in range(num_layers - 2):
            self.gnns.append(DenseSAGEConv(
                d_hidden, d_hidden)
            )
            self.batch_norms.append(nn.BatchNorm1d(d_hidden))

        self.gnns.append(DenseSAGEConv(
            d_hidden, d_out)
        )

        self.dropout = nn.Dropout(dropout)
        self.non_linear = nn.GELU()
        self.use_batch_norm = use_batch_norm
        self.use_non_linear = use_non_linear

    def reset_parameters(self):
        for gnn in self.gnns:
            gnn.reset_parameters()
        for batch_norm in self.batch_norms:
            batch_norm.reset_parameters()

    def forward(self, X, A):
        Z = self.dropout(X)
        for i in range(len(self.gnns) - 1):
            Z = self.gnns[i](Z, A)
            if self.use_batch_norm:
                Z = Z.transpose(1, 2)
                Z = self.batch_norms[i](Z)
                Z = Z.transpose(1, 2)
            if self.use_non_linear:
                Z = self.non_linear(Z)
            Z = self.dropout(Z)
        Z = self.gnns[-1](Z, A)
        return Z


class GlobalGraphLearner(nn.Module):
    def __init__(self, d_in, num_heads, random=False):
        super(GlobalGraphLearner, self).__init__()
        self.random = random
        if not self.random:
            w = torch.Tensor(num_heads, d_in)
            self.w = Parameter(nn.init.xavier_uniform_(w), requires_grad=True)

    def reset_parameters(self):
        if not self.random:
            self.w = Parameter(nn.init.xavier_uniform_(self.w))

    def forward(self, Z):
        if self.random:
            att_global = torch.randn((Z.size(0), Z.size(1), Z.size(1))).to(Z.device)
        else:
            w_expanded = self.w.unsqueeze(1).unsqueeze(1)
            Z = Z.unsqueeze(0) * w_expanded
            Z = F.normalize(Z, p=2, dim=-1)
            att_global = torch.matmul(Z, Z.transpose(-1, -2)).mean(0)
            # keep only non-negative similarities -> sparse, non-negative adjacency
            mask_global = (att_global > 0).detach().float()
            att_global = att_global * mask_global

        return att_global


class DenseAPPNP(nn.Module):
    def __init__(self, K, alpha):
        super().__init__()
        self.K = K
        self.alpha = alpha

    def forward(self, x, adj_t):
        h = x
        for _ in range(self.K):
            if adj_t.is_sparse:
                x = torch_sparse.spmm(adj_t, x)
            else:
                x = torch.matmul(adj_t, x)
            x = x * (1 - self.alpha)
            x += self.alpha * h
        # final scaling by 1/K (specific to this implementation; vanilla APPNP omits it)
        x /= self.K
        return x
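For reference, the update implemented by `DenseAPPNP` corresponds to APPNP-style personalized-PageRank propagation (Klicpera et al.), plus the final division by K noted in the code as reconstructed above:

```latex
Z^{(0)} = H, \qquad
Z^{(k+1)} = (1-\alpha)\,\hat{A}\,Z^{(k)} + \alpha\,H, \quad k = 0,\dots,K-1, \qquad
Z_{\mathrm{out}} = \tfrac{1}{K}\,Z^{(K)}
```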
class Dense_APPNP_Net(nn.Module):
    def __init__(self, d_in, d_hidden, d_out, dropout=.5, K=10, alpha=.1):
        super(Dense_APPNP_Net, self).__init__()
        self.lin1 = nn.Linear(d_in, d_hidden)
        self.lin2 = nn.Linear(d_hidden, d_out)
        self.prop1 = DenseAPPNP(K, alpha)
        self.dropout = dropout

    def reset_parameters(self):
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, x, adj_t):
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lin2(x)
        x = self.prop1(x, adj_t)
        return x


class MMGraphLearner(nn.Module):
    def __init__(self, d_in, num_heads, random=False):
        super(MMGraphLearner, self).__init__()
        self.random = random
        if not self.random:
            w = torch.Tensor(num_heads, d_in)
            self.w = Parameter(nn.init.xavier_uniform_(w), requires_grad=True)

        self.fc = nn.Linear(d_in, d_in)

    def reset_parameters(self):
        if not self.random:
            self.fc.reset_parameters()
            self.w = Parameter(nn.init.xavier_uniform_(self.w), requires_grad=True)

    def forward(self, features):
        if self.random:
            att = torch.randn((features.size(0), features.size(1), features.size(1))).to(features.device)
        else:
            features = self.fc(features)
            w_expanded = self.w.unsqueeze(1).unsqueeze(1)
            features = features.unsqueeze(0) * w_expanded
            features = F.normalize(features, p=2, dim=-1)
            att = torch.matmul(features, features.transpose(-1, -2)).mean(0)
            mask = (att > 0).detach().float()
            att = att * mask
        return att
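A minimal sketch of the learned-adjacency step (all sizes are hypothetical): each of the `num_heads` weight vectors rescales the projected node features, the cosine similarities of the rescaled features are averaged over heads, and negative entries are masked out so the learned graph stays sparse and non-negative.

```python
import torch

learner = MMGraphLearner(d_in=512, num_heads=4)
nodes = torch.randn(2, 36, 512)       # (batch, nodes, features)
adj = learner(nodes)
print(adj.shape)                      # torch.Size([2, 36, 36])
print((adj >= 0).all().item())        # True: negative similarities are zeroed
```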
|
||||||
|
|
||||||
|
class QNetLocal(nn.Module):

    def __init__(self, config):
        super(QNetLocal, self).__init__()
        self.config = config

        self.mm_gnn_modules = nn.ModuleList()
        self.mm_graph_learners_1 = nn.ModuleList()
        self.mm_graph_learners_2 = nn.ModuleList()

        for _ in range(self.config.num_modalities):
            if self.config.gnn_type == 'gat':
                self.mm_gnn_modules.append(GATModule(
                    self.config.d_model, self.config.d_model, self.config.d_model,
                    num_layers=self.config.num_local_gnn_layers,
                    heads=self.config.num_local_gnn_heads,
                    dropout=self.config.local_gnn_dropout,
                    concat=self.config.local_gnn_concat,
                    use_batch_norm=self.config.use_local_gnn_bn,
                    use_non_linear=self.config.use_non_linear
                ))
            elif self.config.gnn_type == 'appnp':
                self.mm_gnn_modules.append(Dense_APPNP_Net(
                    self.config.d_model, self.config.d_model, self.config.d_model, dropout=self.config.local_gnn_dropout,
                    K=self.config.gnn_K, alpha=self.config.gnn_alpha
                ))
            elif self.config.gnn_type == 'gcn':
                self.mm_gnn_modules.append(GCNModule(
                    self.config.d_model, self.config.d_model, self.config.d_model,
                    num_layers=self.config.num_local_gnn_layers,
                    dropout=self.config.local_gnn_dropout,
                    use_batch_norm=self.config.use_local_gnn_bn,
                    use_non_linear=self.config.use_non_linear
                ))
            elif self.config.gnn_type == 'sage':
                self.mm_gnn_modules.append(SAGEModule(
                    self.config.d_model, self.config.d_model, self.config.d_model,
                    num_layers=self.config.num_local_gnn_layers,
                    dropout=self.config.local_gnn_dropout,
                    use_batch_norm=self.config.use_local_gnn_bn,
                    use_non_linear=self.config.use_non_linear
                ))
            else:
                raise ValueError('Unknown gnn_type: {}'.format(self.config.gnn_type))
            self.mm_graph_learners_1.append(MMGraphLearner(self.config.d_model, self.config.num_local_gr_learner_heads, random=self.config.use_random_graphs))
            self.mm_graph_learners_2.append(MMGraphLearner(self.config.d_model * 2, self.config.num_local_gr_learner_heads, random=self.config.use_random_graphs))

    def reset_parameters(self):
        for i in range(self.config.num_modalities):
            self.mm_gnn_modules[i].reset_parameters()
            self.mm_graph_learners_1[i].reset_parameters()
            self.mm_graph_learners_2[i].reset_parameters()

    def forward(self, features, A_tildes=None):
        mm_Xs = features
        device = features[0].device

        if A_tildes is None:
            A_tildes = []
            for mm_X in mm_Xs:
                A_tildes.append(get_knn_graph(mm_X, self.config.num_nn, device))

        ################# Multi-modal graph learner (upper branch) #################
        A_primes = []
        for i, mm_X in enumerate(mm_Xs):  # iterate over the modalities
            A_primes.append(self.mm_graph_learners_1[i](mm_X))

        # Linear combination of A_primes with A_tildes
        A_primes = [(1 - self.config.init_adj_ratio) * A_prime + self.config.init_adj_ratio * A_tilde for A_prime, A_tilde in zip(A_primes, A_tildes)]

        ################# Multi-modal gnn (upper branch) #################
        Z_primes = []
        for i, (mm_X, A_prime) in enumerate(zip(mm_Xs, A_primes)):
            Z_primes.append(self.mm_gnn_modules[i](mm_X, A_prime))

        ################# Multi-modal gnn (lower branch) #################
        Z_double_primes = []
        for i, (mm_X, A_tilde) in enumerate(zip(mm_Xs, A_tildes)):
            Z_double_primes.append(self.mm_gnn_modules[i](mm_X, A_tilde))

        Z_concats = [torch.cat([Z_1, Z_2], dim=-1) for Z_1, Z_2 in zip(Z_primes, Z_double_primes)]

        ################# Multi-modal graph learner (lower branch) #################
        A_double_primes = []
        for i, Z_concat in enumerate(Z_concats):
            A_double_primes.append(self.mm_graph_learners_2[i](Z_concat))

        A_double_primes = [(1 - self.config.init_adj_ratio) * A_double_prime + self.config.init_adj_ratio * A_tilde for A_double_prime, A_tilde in zip(A_double_primes, A_tildes)]

        As = [(1 - self.config.adj_ratio) * A_prime + self.config.adj_ratio * A_double_prime for A_prime, A_double_prime in zip(A_primes, A_double_primes)]

        ################## Average across branches (per modality) ##################
        Zs = [0.5 * Z1 + 0.5 * Z2 for Z1, Z2 in zip(Z_primes, Z_double_primes)]
        return As, Zs
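
# A sketch of the two-branch mixing above (not part of the original file),
# writing r0 = init_adj_ratio and r = adj_ratio:
#   A'  <- (1 - r0) * learner_1(X)         + r0 * A_tilde
#   A'' <- (1 - r0) * learner_2([Z'; Z''])  + r0 * A_tilde
#   A   <- (1 - r)  * A' + r * A''
#   Z   <- 0.5 * Z' + 0.5 * Z''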
class PNetLocal(nn.Module):

    def __init__(self, config):
        super(PNetLocal, self).__init__()
        self.config = config
        self.mm_gnn_modules = nn.ModuleList()
        self.mm_mlp_modules = nn.ModuleList()

        self.mm_graph_learners_1 = nn.ModuleList()
        self.mm_graph_learners_2 = nn.ModuleList()

        for _ in range(self.config.num_modalities):
            if self.config.gnn_type == 'gat':
                self.mm_gnn_modules.append(GATModule(
                    self.config.d_model, self.config.d_model, self.config.d_model,
                    num_layers=self.config.num_local_gnn_layers,
                    heads=self.config.num_local_gnn_heads,
                    dropout=self.config.local_gnn_dropout,
                    concat=self.config.local_gnn_concat,
                    use_batch_norm=self.config.use_local_gnn_bn,
                    use_non_linear=self.config.use_non_linear
                ))
            elif self.config.gnn_type == 'appnp':
                self.mm_gnn_modules.append(Dense_APPNP_Net(
                    self.config.d_model, self.config.d_model, self.config.d_model, dropout=self.config.local_gnn_dropout,
                    K=self.config.gnn_K, alpha=self.config.gnn_alpha
                ))
            elif self.config.gnn_type == 'gcn':
                self.mm_gnn_modules.append(GCNModule(
                    self.config.d_model, self.config.d_model, self.config.d_model,
                    num_layers=self.config.num_local_gnn_layers,
                    dropout=self.config.local_gnn_dropout,
                    use_batch_norm=self.config.use_local_gnn_bn,
                    use_non_linear=self.config.use_non_linear
                ))
            elif self.config.gnn_type == 'sage':
                self.mm_gnn_modules.append(SAGEModule(
                    self.config.d_model, self.config.d_model, self.config.d_model,
                    num_layers=self.config.num_local_gnn_layers,
                    dropout=self.config.local_gnn_dropout,
                    use_batch_norm=self.config.use_local_gnn_bn,
                    use_non_linear=self.config.use_non_linear
                ))
            else:
                raise ValueError('Unknown gnn_type: {}'.format(self.config.gnn_type))
            self.mm_mlp_modules.append(MLPModule(
                self.config.d_model, self.config.d_model, self.config.d_model,
                num_layers=self.config.num_local_fc_layers,
                dropout=self.config.local_fc_dropout,
                use_batch_norm=self.config.use_local_fc_bn,
                use_non_linear=self.config.use_non_linear
            ))

            self.mm_graph_learners_1.append(MMGraphLearner(self.config.d_model, self.config.num_local_gr_learner_heads, random=self.config.use_random_graphs))
            self.mm_graph_learners_2.append(MMGraphLearner(self.config.d_model * 2, self.config.num_local_gr_learner_heads, random=self.config.use_random_graphs))

    def reset_parameters(self):
        for i in range(self.config.num_modalities):
            self.mm_gnn_modules[i].reset_parameters()
            self.mm_mlp_modules[i].reset_parameters()
            self.mm_graph_learners_1[i].reset_parameters()
            self.mm_graph_learners_2[i].reset_parameters()

    def forward(self, features):
        mm_Xs = features

        ################# Multi-modal graph learner (upper branch) #################
        A_primes = []
        for i, mm_X in enumerate(mm_Xs):  # iterate over the modalities
            A_primes.append(self.mm_graph_learners_1[i](mm_X))

        ################# Multi-modal gnn (upper branch) #################
        Z_primes = []
        for i, (mm_X, A_prime) in enumerate(zip(mm_Xs, A_primes)):
            Z_primes.append(self.mm_gnn_modules[i](mm_X, A_prime))

        ################# Multi-modal mlp (lower branch) #################
        Z_double_primes = []
        for i, mm_X in enumerate(mm_Xs):
            Z_double_primes.append(self.mm_mlp_modules[i](mm_X))

        Z_concats = [torch.cat([Z_1, Z_2], dim=-1) for Z_1, Z_2 in zip(Z_primes, Z_double_primes)]

        ################# Multi-modal graph learner (lower branch) #################
        A_double_primes = []
        for i, Z_concat in enumerate(Z_concats):
            A_double_primes.append(self.mm_graph_learners_2[i](Z_concat))

        As = [(1 - self.config.adj_ratio) * A_prime + self.config.adj_ratio * A_double_prime for A_prime, A_double_prime in zip(A_primes, A_double_primes)]

        ################## Average across branches (per modality) ##################
        Zs = [0.5 * Z1 + 0.5 * Z2 for Z1, Z2 in zip(Z_primes, Z_double_primes)]

        return As, Zs
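
# Note (not part of the original file): PNetLocal mirrors QNetLocal but never
# sees the initial kNN graphs A_tilde; its lower branch therefore uses a
# per-modality MLP instead of a GNN over A_tilde, which is consistent with its
# role as the P-network paired against QNetLocal in the ELBO loss below.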
class QNetGlobal(nn.Module):

    def __init__(self, config):
        super(QNetGlobal, self).__init__()
        self.config = config
        if self.config.gnn_type == 'gat':
            self.gnn = GATModule(
                self.config.d_model, self.config.d_model, self.config.d_model,
                num_layers=self.config.num_global_gnn_layers,
                heads=self.config.num_global_gnn_heads,
                dropout=self.config.global_gnn_dropout,
                concat=self.config.global_gnn_concat,
                use_batch_norm=self.config.use_global_gnn_bn,
                use_non_linear=self.config.use_non_linear
            )
        elif self.config.gnn_type == 'appnp':
            self.gnn = Dense_APPNP_Net(
                self.config.d_model, self.config.d_model, self.config.d_model, dropout=self.config.local_gnn_dropout,
                K=self.config.gnn_K, alpha=self.config.gnn_alpha
            )
        elif self.config.gnn_type == 'gcn':
            self.gnn = GCNModule(
                self.config.d_model, self.config.d_model, self.config.d_model,
                num_layers=self.config.num_global_gnn_layers,
                dropout=self.config.global_gnn_dropout,
                use_batch_norm=self.config.use_global_gnn_bn,
                use_non_linear=self.config.use_non_linear
            )
        elif self.config.gnn_type == 'sage':
            self.gnn = SAGEModule(
                self.config.d_model, self.config.d_model, self.config.d_model,
                num_layers=self.config.num_global_gnn_layers,
                dropout=self.config.global_gnn_dropout,
                use_batch_norm=self.config.use_global_gnn_bn,
                use_non_linear=self.config.use_non_linear
            )
        else:
            raise ValueError('Unknown gnn_type: {}'.format(self.config.gnn_type))

        self.graph_learner_1 = GlobalGraphLearner(self.config.d_model, self.config.num_global_gr_learner_heads, self.config.use_random_graphs)
        self.graph_learner_2 = GlobalGraphLearner(self.config.d_model * 2, self.config.num_global_gr_learner_heads, self.config.use_random_graphs)

    def reset_parameters(self):
        self.gnn.reset_parameters()
        self.graph_learner_1.reset_parameters()
        self.graph_learner_2.reset_parameters()

    def forward(self, Z, A):

        ################# Graph learner (upper branch) #################
        A_prime = self.graph_learner_1(Z)
        A_prime = (1 - self.config.init_adj_ratio) * A_prime + self.config.init_adj_ratio * A

        ################# Gnn (upper branch) #################
        Z_prime = self.gnn(Z, A_prime)

        ################# Gnn (lower branch) #################
        Z_double_prime = self.gnn(Z, A)
        Z_concat = torch.cat([Z_prime, Z_double_prime], dim=-1)

        ################# Graph learner (lower branch) #################
        A_double_prime = self.graph_learner_2(Z_concat)
        A_double_prime = (1 - self.config.init_adj_ratio) * A_double_prime + self.config.init_adj_ratio * A

        ################## Average across branches ##################
        A_global = (1 - self.config.adj_ratio) * A_prime + self.config.adj_ratio * A_double_prime
        Z_global = 0.5 * Z_prime + 0.5 * Z_double_prime
        return A_global, Z_global
class PNetGlobal(nn.Module):

    def __init__(self, config):
        super(PNetGlobal, self).__init__()
        self.config = config

        if self.config.gnn_type == 'gat':
            self.gnn = GATModule(
                self.config.d_model, self.config.d_model, self.config.d_model,
                num_layers=self.config.num_global_gnn_layers,
                heads=self.config.num_global_gnn_heads,
                dropout=self.config.global_gnn_dropout,
                concat=self.config.global_gnn_concat,
                use_batch_norm=self.config.use_global_gnn_bn,
                use_non_linear=self.config.use_non_linear
            )
        elif self.config.gnn_type == 'appnp':
            self.gnn = Dense_APPNP_Net(
                self.config.d_model, self.config.d_model, self.config.d_model, dropout=self.config.local_gnn_dropout,
                K=self.config.gnn_K, alpha=self.config.gnn_alpha
            )
        elif self.config.gnn_type == 'gcn':
            self.gnn = GCNModule(
                self.config.d_model, self.config.d_model, self.config.d_model,
                num_layers=self.config.num_global_gnn_layers,
                dropout=self.config.global_gnn_dropout,
                use_batch_norm=self.config.use_global_gnn_bn,
                use_non_linear=self.config.use_non_linear
            )
        elif self.config.gnn_type == 'sage':
            self.gnn = SAGEModule(
                self.config.d_model, self.config.d_model, self.config.d_model,
                num_layers=self.config.num_global_gnn_layers,
                dropout=self.config.global_gnn_dropout,
                use_batch_norm=self.config.use_global_gnn_bn,
                use_non_linear=self.config.use_non_linear
            )
        else:
            raise ValueError('Unknown gnn_type: {}'.format(self.config.gnn_type))

        self.mlp = MLPModule(
            self.config.d_model, self.config.d_model, self.config.d_model,
            num_layers=self.config.num_global_fc_layers,
            dropout=self.config.global_fc_dropout,
            use_batch_norm=self.config.use_global_fc_bn,
            use_non_linear=self.config.use_non_linear
        )

        self.graph_learner_1 = GlobalGraphLearner(self.config.d_model, self.config.num_global_gr_learner_heads, random=self.config.use_random_graphs)
        self.graph_learner_2 = GlobalGraphLearner(self.config.d_model * 2, self.config.num_global_gr_learner_heads, random=self.config.use_random_graphs)

    def reset_parameters(self):
        self.gnn.reset_parameters()
        self.mlp.reset_parameters()
        self.graph_learner_1.reset_parameters()
        self.graph_learner_2.reset_parameters()

    def forward(self, Z, A):

        ################# Graph learner (upper branch) #################
        A_prime = self.graph_learner_1(Z)

        ################# Gnn (upper branch) #################
        Z_prime = self.gnn(Z, A_prime)

        ################# mlp (lower branch) #################
        Z_double_prime = self.mlp(Z)
        Z_concat = torch.cat([Z_prime, Z_double_prime], dim=-1)

        ################# Graph learner (lower branch) #################
        A_double_prime = self.graph_learner_2(Z_concat)
        # A_double_prime = (1-self.config.init_adj_ratio) * A_double_prime + self.config.init_adj_ratio * A

        ################## Average across branches ##################
        A_global = (1 - self.config.adj_ratio) * A_prime + self.config.adj_ratio * A_double_prime
        Z_global = 0.5 * Z_prime + 0.5 * Z_double_prime
        return A_global, Z_global
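
# Note (not part of the original file): QNetGlobal and PNetGlobal form the
# Q/P pair over the global graph. The Q-net conditions both of its branches on
# the initial adjacency A, while the P-net replaces that lower branch with an
# MLP that only sees the node states Z, matching the ELBO pairing below.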
1397
models/nextqa_bart.py
Normal file
File diff suppressed because it is too large
249
models/utils.py
Normal file
@@ -0,0 +1,249 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.utils import ModelOutput
from typing import Optional, Tuple


class ELBO(nn.Module):

    def __init__(self):
        super(ELBO, self).__init__()

    def forward(self, QA, PA):
        QA_flattened = QA.view(-1).unsqueeze(-1)
        PA_flattened = PA.view(-1).unsqueeze(-1)

        # Turn every edge weight into a two-class logit (no edge vs. edge)
        QA_flattened = torch.cat([torch.zeros_like(QA_flattened), QA_flattened], dim=-1)
        PA_flattened = torch.cat([torch.zeros_like(PA_flattened), PA_flattened], dim=-1)

        log_QA = F.log_softmax(QA_flattened, dim=1)
        log_PA = F.log_softmax(PA_flattened, dim=1)

        QA_dist = torch.exp(log_QA)

        loss_QA = torch.mean(log_QA * QA_dist)
        loss_PA = torch.mean(log_PA * QA_dist)

        loss = loss_QA - loss_PA

        return loss
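
# The loss above is (up to the mean reduction) the KL divergence between the
# per-edge Bernoulli distributions induced by QA and PA:
#   loss = E_Q[log Q] - E_Q[log P] = KL(Q || P)
# Illustrative call (a sketch, not part of the original file):
#   elbo = ELBO()
#   kl = elbo(torch.rand(2, 5, 5), torch.rand(2, 5, 5))   # scalar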
def seperate_nextqa_input_modalities(
        features, i3d_rgb_interval, i3d_flow_interval, question_intervals,
        vis_state_vector_idx, question_state_vector_idx,
        attention_values=None):
    """We separate the multimodal input hidden states. The state token embeddings are left out (+1 while indexing).

    Args:
        features: encoder hidden states of the full input sequence.
        i3d_rgb_interval: (start, end) positions of the I3D-RGB tokens.
        i3d_flow_interval: (start, end) positions of the I3D-flow tokens.
        question_intervals: per-sample (start, end) positions of the question tokens.
        vis_state_vector_idx: positions of the visual state tokens.
        question_state_vector_idx: per-sample positions of the question state token.
        attention_values: optional encoder self-attention used to score the tokens.

    Returns:
        A list of per-modality hidden states and a list of per-modality attention scores.
    """
    features_copy = features.clone()  # .detach()
    i3d_rgb_hidden = features_copy[:, i3d_rgb_interval[0]+1:i3d_rgb_interval[1], :]
    i3d_flow_hidden = features_copy[:, i3d_flow_interval[0]+1:i3d_flow_interval[1], :]

    question_hidden = []
    features_split = torch.split(features_copy, 1, dim=0)
    for ques_inter, feat in zip(question_intervals, features_split):
        ques_idx = torch.arange(ques_inter[0]+1, ques_inter[1]).unsqueeze(0).unsqueeze(-1).repeat(1, 1, feat.size(-1)).to(feat.device)
        question_hidden.append(torch.gather(feat, 1, ques_idx))

    if attention_values is None:
        i3d_rgb_att = None
        i3d_flow_att = None
        question_att = None
    else:
        attention_values = attention_values.mean(1)  # average over the attention heads
        i3d_rgb_att = attention_values[:, vis_state_vector_idx[0], vis_state_vector_idx[0]+1:vis_state_vector_idx[1]]
        i3d_flow_att = attention_values[:, vis_state_vector_idx[1], vis_state_vector_idx[1]+1:question_state_vector_idx[0]]
        question_att = [attention_values[i, question_state_vector_idx[i], question_intervals[i][0] + 1: question_intervals[i][1]] for i in range(len(question_state_vector_idx))]

    features_list = [i3d_rgb_hidden, i3d_flow_hidden, question_hidden]
    att = [i3d_rgb_att, i3d_flow_att, question_att]

    return features_list, att
def seperate_input_modalities(
        features, i3d_rgb_interval, i3d_flow_interval, sam_interval, audio_interval, history_intervals, question_intervals,
        vis_state_vector_idx, history_state_vector_idx, question_state_vector_idx,
        attention_values=None):
    """We separate the multimodal input hidden states. The state token embeddings are left out (+1 while indexing).

    Args:
        features: encoder hidden states of the full input sequence.
        i3d_rgb_interval: (start, end) positions of the I3D-RGB tokens.
        i3d_flow_interval: (start, end) positions of the I3D-flow tokens.
        sam_interval: (start, end) positions of the SAM object tokens.
        audio_interval: (start, end) positions of the audio (VGGish) tokens.
        history_intervals: per-sample (start, end) positions of the history tokens.
        question_intervals: per-sample (start, end) positions of the question tokens.
        vis_state_vector_idx: positions of the visual state tokens.
        history_state_vector_idx: per-sample positions of the history state token.
        question_state_vector_idx: per-sample positions of the question state token.
        attention_values: optional encoder self-attention used to score the tokens.

    Returns:
        A list of per-modality hidden states and a list of per-modality attention scores.
    """
    features_copy = features.clone()  # .detach()
    i3d_rgb_hidden = features_copy[:, i3d_rgb_interval[0]+1:i3d_rgb_interval[1], :]
    i3d_flow_hidden = features_copy[:, i3d_flow_interval[0]+1:i3d_flow_interval[1], :]
    sam_hidden = features_copy[:, sam_interval[0]+1:sam_interval[1], :]
    audio_hidden = features_copy[:, audio_interval[0]+1:audio_interval[1], :]

    history_hidden = []
    question_hidden = []
    features_split = torch.split(features_copy, 1, dim=0)
    for hist_inter, ques_inter, feat in zip(history_intervals, question_intervals, features_split):
        hist_idx = torch.arange(hist_inter[0]+1, hist_inter[1]).unsqueeze(0).unsqueeze(-1).repeat(1, 1, feat.size(-1)).to(feat.device)
        history_hidden.append(torch.gather(feat, 1, hist_idx))

        ques_idx = torch.arange(ques_inter[0]+1, ques_inter[1]).unsqueeze(0).unsqueeze(-1).repeat(1, 1, feat.size(-1)).to(feat.device)
        question_hidden.append(torch.gather(feat, 1, ques_idx))

    if attention_values is None:
        i3d_rgb_att = None
        i3d_flow_att = None
        sam_att = None
        audio_att = None
        history_att = None
        question_att = None
    else:
        attention_values = attention_values.mean(1)  # average over the attention heads
        i3d_rgb_att = attention_values[:, vis_state_vector_idx[0], vis_state_vector_idx[0]+1:vis_state_vector_idx[1]]
        i3d_flow_att = attention_values[:, vis_state_vector_idx[1], vis_state_vector_idx[1]+1:vis_state_vector_idx[2]]
        sam_att = attention_values[:, vis_state_vector_idx[2], vis_state_vector_idx[2]+1:vis_state_vector_idx[3]]
        audio_att = attention_values[:, vis_state_vector_idx[3], vis_state_vector_idx[3]+1:history_state_vector_idx[0] - 1]
        history_att = [attention_values[i, history_state_vector_idx[i], history_intervals[i][0] + 1: history_intervals[i][1]] for i in range(len(history_state_vector_idx))]
        question_att = [attention_values[i, question_state_vector_idx[i], question_intervals[i][0] + 1: question_intervals[i][1]] for i in range(len(question_state_vector_idx))]

    features_list = [i3d_rgb_hidden, i3d_flow_hidden, sam_hidden, audio_hidden, history_hidden, question_hidden]
    att = [i3d_rgb_att, i3d_flow_att, sam_att, audio_att, history_att, question_att]

    return features_list, att
def get_knn_graph(features, num_nn, device):
    # Pairwise cosine similarities between the nodes of each sample
    features = features.permute((1, 2, 0))
    cosine_sim_pairwise = F.cosine_similarity(features, features.unsqueeze(1), dim=-2)
    cosine_sim_pairwise = cosine_sim_pairwise.permute((2, 0, 1))
    # Connect every node to its num_nn most similar neighbors
    num_nn = min(num_nn, cosine_sim_pairwise.size(-1))
    adj_mat = torch.zeros_like(cosine_sim_pairwise).to(device)
    _, to_keep = torch.topk(cosine_sim_pairwise, num_nn, dim=-1, sorted=False)
    adj_mat = adj_mat.scatter(-1, to_keep, torch.ones_like(adj_mat).to(device))
    return adj_mat
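
# Illustrative usage (a sketch, not part of the original file):
#   feats = torch.randn(2, 10, 64)                                  # (batch, num_nodes, d)
#   A_tilde = get_knn_graph(feats, num_nn=4, device=feats.device)   # (2, 10, 10), binary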
def track_features_vis(features, att, top_k, device, node_idx=None):
    """Selects the tracked top-k constituents of the i3d, audio, and sam input
    modalities based on the state-token attention scores. If no attention is
    available, the tracked constituents are chosen at random. The tracked
    constituents are used to build the initial nearest-neighbor graphs
    (A_tilde in the paper).
    """
    features = features.clone().detach()
    top_k = min(features.size(1), top_k)
    if att is None:
        node_idx = torch.randint(low=0, high=features.size(1), size=(features.size(0), top_k))
    else:
        _, node_idx = torch.topk(att, top_k, dim=-1, sorted=False)

    node_idx = node_idx.unsqueeze(-1).repeat(1, 1, features.size(-1)).to(device)

    selected_features = torch.gather(features, 1, node_idx)

    return selected_features, node_idx
def track_features_text(features, att, top_k, device, node_idx=None):
    """Selects the tracked top-k constituents of the history and question
    inputs based on the state-token attention scores. If no attention is
    available, the tracked constituents are chosen at random. The tracked
    constituents are used to build the initial nearest-neighbor graphs
    (A_tilde in the paper).
    """
    hidden_dim = features[0].size(-1)
    min_len = min([feat.size(1) for feat in features])
    top_k = min(min_len, top_k)
    if att is None:
        node_idx = [torch.randint(low=0, high=feat.size(1), size=(feat.size(0), top_k)) for feat in features]
    else:
        node_idx = [torch.topk(a, top_k, dim=-1, sorted=False)[-1] for a in att]

    node_idx = [idx.unsqueeze(-1).repeat(1, 1, hidden_dim).to(device) for idx in node_idx]

    selected_features = [torch.gather(feat, 1, idx) for feat, idx in zip(features, node_idx)]
    selected_features = torch.cat(selected_features, dim=0)

    return selected_features, node_idx
def diag_tensor(tensors):
    # Stack the per-modality adjacency matrices into one block-diagonal matrix
    device = tensors[0].device
    n = sum([t.size(-1) for t in tensors])
    bsz = tensors[0].size(0)
    diag_tensor = torch.zeros((bsz, n, n)).float().to(device)
    delimiter = 0
    delimiters = [0]
    for t in tensors:
        diag_tensor[:, delimiter:delimiter+t.size(-1), delimiter:delimiter+t.size(-1)] = t
        delimiter += t.size(-1)
        delimiters.append(delimiter)

    return diag_tensor, delimiters
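
# Illustrative usage (a sketch, not part of the original file):
#   A1 = torch.ones(2, 3, 3); A2 = torch.ones(2, 2, 2)
#   D, delims = diag_tensor([A1, A2])   # D: (2, 5, 5) block-diagonal, delims: [0, 3, 5]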
def embed_graphs(features, delimiters):
    # Mean-pool the node features of each block to get one state vector per graph
    state_vectors = []
    for i in range(len(delimiters) - 1):
        state_vectors.append(features[:, delimiters[i]:delimiters[i+1], :].mean(dim=1))
    return state_vectors
class AVSDEncoderOutput(ModelOutput):
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    QAs_local = None
    PAs_local = None
    QA_global = None
    PA_global = None
    state_vectors = None


class AVSDSeq2SeqModelOutput(ModelOutput):

    last_hidden_state: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    QAs_local = None
    PAs_local = None
    QA_global = None
    PA_global = None
    state_vectors = None


class AVSDSeq2SeqLMOutput(ModelOutput):

    gen_loss: Optional[torch.FloatTensor] = None
    elbo_loss_global: Optional[torch.FloatTensor] = None
    elbo_loss_local: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_QAs_local = None
    encoder_PAs_local = None
    encoder_QA_global = None
    encoder_PA_global = None
    encoder_state_vectors = None
106
optim_utils.py
Normal file
@@ -0,0 +1,106 @@
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim import AdamW


class WarmupLinearScheduleNonZero(_LRScheduler):
    """ Linear warmup and then linear decay.
    Linearly increases the learning rate from 0 to the base lr over `warmup_steps` training steps.
    Then linearly decreases it to min_lr over the remaining `t_total - warmup_steps` steps.
    """
    def __init__(self, optimizer, warmup_steps, t_total, min_lr=1e-5, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.t_total = t_total
        self.min_lr = min_lr
        super(WarmupLinearScheduleNonZero, self).__init__(optimizer, last_epoch=last_epoch)

    def get_lr(self):
        step = self.last_epoch
        if step < self.warmup_steps:
            lr_factor = float(step) / float(max(1, self.warmup_steps))
        else:
            lr_factor = max(0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps)))

        return [base_lr * lr_factor if (base_lr * lr_factor) > self.min_lr else self.min_lr for base_lr in self.base_lrs]
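
# Illustrative schedule (a sketch, not part of the original file), with
# base_lr=1e-4, warmup_steps=100, t_total=1000, min_lr=1e-5:
#   step 50  -> lr = 1e-4 * 50/100          = 5e-5
#   step 100 -> lr = 1e-4                   (end of warmup)
#   step 550 -> lr = 1e-4 * (1000-550)/900  = 5e-5
#   step 990 -> lr clamped to min_lr        = 1e-5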
def init_optim(model, config):
    encoder_params_with_weight_decay = []
    encoder_params_without_weight_decay = []
    decoder_params_with_weight_decay = []
    decoder_params_without_weight_decay = []
    other_params_with_weight_decay = []
    other_params_without_weight_decay = []

    exclude_from_weight_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

    # Our model shares (embedding) parameters between the encoder and decoder.
    # We want to include such parameters only in one parameter group.
    # So we keep track of the unique ids of each parameter.
    params_ids = []

    for module_name, module in model.named_children():
        for param_name, param in module.named_parameters():
            if id(param) not in params_ids:
                params_ids.append(id(param))
            else:
                continue
            if param.requires_grad:
                if 'encoder' in param_name:
                    if any(ex in param_name for ex in exclude_from_weight_decay):
                        encoder_params_without_weight_decay.append(param)
                    else:
                        encoder_params_with_weight_decay.append(param)

                elif 'decoder' in param_name:
                    if any(ex in param_name for ex in exclude_from_weight_decay):
                        decoder_params_without_weight_decay.append(param)
                    else:
                        decoder_params_with_weight_decay.append(param)
                else:
                    if any(ex in param_name for ex in exclude_from_weight_decay):
                        other_params_without_weight_decay.append(param)
                    else:
                        other_params_with_weight_decay.append(param)

    optimizer_grouped_parameters = [
        {
            'params': encoder_params_with_weight_decay,
            'weight_decay': 0.01,
            'lr': config['learning_rate_bart']
        },
        {
            'params': encoder_params_without_weight_decay,
            'weight_decay': 0.0,
            'lr': config['learning_rate_bart']
        },
        {
            'params': decoder_params_with_weight_decay,
            'weight_decay': 0.01,
            'lr': config['learning_rate_bart']
        },
        {
            'params': decoder_params_without_weight_decay,
            'weight_decay': 0.0,
            'lr': config['learning_rate_bart']
        },
        {
            'params': other_params_with_weight_decay,
            'weight_decay': 0.01,
            'lr': config['learning_rate_other']
        },
        {
            'params': other_params_without_weight_decay,
            'weight_decay': 0.0,
            'lr': config['learning_rate_other']
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=config['learning_rate_bart'])

    scheduler = WarmupLinearScheduleNonZero(
        optimizer,
        warmup_steps=config['warmup_steps'],
        t_total=config['train_steps'],
        min_lr=config['min_lr']
    )

    return optimizer, scheduler
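
# Illustrative usage (a sketch, not part of the original file); the config is a
# plain dict with the keys referenced above:
#   config = {'learning_rate_bart': 1e-5, 'learning_rate_other': 1e-4,
#             'warmup_steps': 1000, 'train_steps': 100000, 'min_lr': 1e-6}
#   optimizer, scheduler = init_optim(model, config)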
0
processed/.gitkeep
Normal file
0
processed/avsd/.gitkeep
Normal file
0
processed/nextqa/.gitkeep
Normal file
5634
processed/nextqa/annotations/add_reference_answer_test.json
Normal file
File diff suppressed because it is too large
BIN
processed/nextqa/annotations/glove_embed.npy
Normal file
Binary file not shown.
BIN
processed/nextqa/annotations/test.csv
(Stored with Git LFS)
Normal file
Binary file not shown.
BIN
processed/nextqa/annotations/train.csv
(Stored with Git LFS)
Normal file
Binary file not shown.
BIN
processed/nextqa/annotations/val.csv
(Stored with Git LFS)
Normal file
Binary file not shown.
BIN
processed/nextqa/annotations/vocab.pkl
(Stored with Git LFS)
Normal file
Binary file not shown.
BIN
processed/nextqa/vid_feat/app_mot_test.h5
(Stored with Git LFS)
Normal file
Binary file not shown.
BIN
processed/nextqa/vid_feat/app_mot_train.h5
(Stored with Git LFS)
Normal file
Binary file not shown.
BIN
processed/nextqa/vid_feat/app_mot_val.h5
(Stored with Git LFS)
Normal file
Binary file not shown.
0
raw_data/.gitkeep
Normal file
1
raw_data/test_set4DSTC10-AVSD.json
Normal file
File diff suppressed because one or more lines are too long
38956
raw_data/test_set4DSTC7-AVSD.json
Normal file
File diff suppressed because it is too large
49596
raw_data/test_set4DSTC8-AVSD.json
Normal file
File diff suppressed because it is too large
359979
raw_data/train_set4DSTC7-AVSD.json
Normal file
File diff suppressed because it is too large
83995
raw_data/valid_set4DSTC7-AVSD.json
Normal file
File diff suppressed because it is too large
488
runners/runner.py
Normal file
@@ -0,0 +1,488 @@
import wandb
import os
import os.path as osp
import json
from collections import deque, OrderedDict
import time
import re
import shutil
import glob
import pickle
import gc
import numpy as np
import glog as log

import torch
import torch.utils.data as tud
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.utils import clip_grad_value_


class Runner:
    def __init__(self, config):
        self.config = config
        if 'rank' in config:
            self.gpu_rank = config['rank']
        else:
            self.gpu_rank = 0

        self.epoch_idx = 0
        self.min_gen_val_loss = float('inf')
        self.best_epoch_idx = 0

        if self.config["max_ckpt_to_keep"] > 0:
            self.checkpoint_queue = deque([], maxlen=config["max_ckpt_to_keep"])
            self.metrics_queue = deque([], maxlen=config["max_ckpt_to_keep"])

        self.setup_wandb()

    def setup_wandb(self):
        if self.gpu_rank == 0:
            print("[INFO] Set wandb logging on rank {}".format(0))
            run = wandb.init(
                project=self.config['wandb_project'], config=self.config, mode=self.config['wandb_mode'])
        else:
            run = None
        self.run = run

    def forward(self, batch, eval=False):
        raise NotImplementedError
    def train(self, dataset, dataset_eval):
        batch_size = self.config['batch_size']
        if self.config['parallel'] and self.config['dp_type'] != 'dp':
            sampler = tud.distributed.DistributedSampler(
                dataset,
                num_replicas=self.config['num_gpus'],
                rank=self.gpu_rank
            )
        else:
            sampler = None

        data_loader = tud.DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=self.config['training'] and not self.config['parallel'],
            collate_fn=dataset.collate_fn,
            num_workers=self.config['num_workers'],
            sampler=sampler
        )

        start_epoch_idx = self.epoch_idx
        num_iter_epoch = self.config['num_iter_per_epoch']
        if self.config['display']:
            log.info(f'{num_iter_epoch} iter per epoch.')

        num_epochs = self.config['num_epochs']

        # Perform validation before training
        if self.config['eval_first']:
            _ = self.val(dataset_eval)

        for epoch_idx in range(start_epoch_idx, num_epochs):
            if self.config['parallel'] and self.config['dp_type'] != 'dp':
                sampler.set_epoch(epoch_idx)
            self.epoch_idx = epoch_idx

            if self.config['display']:
                log.info(f'starting epoch {epoch_idx}')
                log.info('training')

            self.model.train()

            num_batch = 0
            next_logging_pct = .1
            start_time = time.time()
            self.optimizer.zero_grad()

            for batch in data_loader:
                num_batch += 1
                pct = num_batch / num_iter_epoch * 100
                iter_now = num_iter_epoch * epoch_idx + num_batch

                output = self.forward(batch)

                losses = output['losses']

                # optimizer step
                losses['tot_loss'] /= self.config['batch_multiply']
                # debug
                if self.config['debugging']:
                    log.info('try backward')

                losses['tot_loss'].backward()
                if self.config['clip_grad_value'] > 0:
                    clip_grad_value_(self.model.parameters(), self.config['clip_grad_value'])
                if self.config['debugging']:
                    log.info('backward done')

                if iter_now % self.config['batch_multiply'] == 0:
                    self.optimizer.step()
                    self.optimizer.zero_grad()

                self.scheduler.step()

                # display and eval
                if pct >= next_logging_pct:
                    if self.config['display']:
                        loss_to_print = ''
                        for key in losses:
                            if losses[key] is not None and isinstance(losses[key], torch.Tensor):
                                loss_to_print += f'[{key}: {losses[key].item():.4f}]'
                        print(
                            f'[{int(pct)}%][Epoch: {epoch_idx + 1}/{num_epochs}][Iter : {num_batch}/{len(data_loader)}] [time: {time.time() - start_time:.2f}] {loss_to_print}'
                        )
                    if self.config['print_output']:
                        print(10 * '-' + 'responses' + 10 * '-')
                        print(output['responses'])
                        print(10 * '-' + 'gt' + 10 * '-')
                        print(output['gt'])

                    next_logging_pct += self.config["next_logging_pct"]

                if self.config['debugging']:
                    break

            lr_bart, lr_other = self.scheduler.get_lr()[0], self.scheduler.get_lr()[-1]

            elbo_global_key = 'elbo_loss_global (x{})'.format(self.config['elbo_global_coeff'])
            elbo_local_key = 'elbo_loss_local (x{})'.format(self.config['elbo_local_coeff'])
            gen_key = 'gen_loss (x{})'.format(self.config['gen_coeff'])
            if self.run:
                self.run.log(
                    {
                        f"Train/{gen_key}": losses[gen_key].item(),
                        f"Train/{elbo_global_key}": losses[elbo_global_key].item(),
                        f"Train/{elbo_local_key}": losses[elbo_local_key].item(),
                        "Train/total_loss": losses['tot_loss'].item(),
                    },
                    step=iter_now
                )

                self.run.log(
                    {"Train/lr_bart": lr_bart, "Train/lr_other": lr_other},
                    step=iter_now
                )
            del losses
            del output

            if self.config['display']:
                log.info(
                    f'100%,\ttime:\t{time.time() - start_time:.2f}'
                )
            if not self.config['overfit'] and self.run:
                self.save_ckpt()

            if not self.config['skip_eval']:

                iter_now = num_iter_epoch * (epoch_idx + 1)
                val_losses = self.val(dataset_eval)

                if self.config['display']:
                    log.info('#'*100)
                    for k in val_losses:
                        log.info('Average val {} (epoch {}) = {}'.format(k, self.epoch_idx, val_losses[k]))
                    log.info('#'*100)

                gen_val_loss = val_losses[gen_key]

                if gen_val_loss < self.min_gen_val_loss:
                    self.min_gen_val_loss = gen_val_loss
                    self.best_epoch_idx = epoch_idx
                    # Log the best model w.r.t. the validation data
                    if self.run and self.config['save_ckpt']:
                        self.save_ckpt_best()

                if self.run:

                    self.run.log(
                        {
                            f"Val/{gen_key}": val_losses[gen_key],
                            f"Val/{elbo_global_key}": val_losses[elbo_global_key],
                            f"Val/{elbo_local_key}": val_losses[elbo_local_key],
                            "Val/total_loss": val_losses['tot_loss'],
                            "Val/min_gen_loss": self.min_gen_val_loss
                        },
                        step=iter_now
                    )

            if self.config['parallel']:
                if self.config['dp_type'] == 'dp':
                    gc.collect()
                    torch.cuda.empty_cache()
                else:
                    dist.barrier()
                    torch.cuda.empty_cache()

            if self.config['stop_epochs'] >= 0 and epoch_idx + 1 >= self.config['stop_epochs']:
                if self.config['display']:
                    log.info('Stop for reaching stop_epochs.')
                break
        if self.config['display']:
            log.info(f'Best validation loss was reached at epoch {self.best_epoch_idx}.')
    def val(self, dataset):
        total_loss_val = 0.0
        total_gen_loss_val = 0.0
        total_elbo_global_val = 0.0
        total_elbo_local_val = 0.0
        num_batch_val = 0
        next_logging_pct_val = 0.05

        elbo_global_key = 'elbo_loss_global (x{})'.format(self.config['elbo_global_coeff'])
        elbo_local_key = 'elbo_loss_local (x{})'.format(self.config['elbo_local_coeff'])
        gen_key = 'gen_loss (x{})'.format(self.config['gen_coeff'])

        # Prepare the dataloader
        if self.config['parallel'] and self.config['dp_type'] != 'dp':
            sampler_val = tud.distributed.DistributedSampler(
                dataset,
                num_replicas=self.config['num_gpus'],
                rank=self.gpu_rank
            )
            sampler_val.set_epoch(self.epoch_idx)
        else:
            sampler_val = None

        data_loader_val = tud.DataLoader(
            dataset=dataset,
            batch_size=self.config['batch_size'],
            shuffle=False,
            collate_fn=dataset.collate_fn,
            num_workers=self.config['num_workers'],
            sampler=sampler_val
        )

        if self.config['parallel'] and self.config['dp_type'] == 'dp':
            num_iter_per_epoch_val = int(np.ceil(len(dataset) / self.config['batch_size']))
        else:
            num_iter_per_epoch_val = int(np.ceil(len(dataset) / (self.config['batch_size'] * self.config['num_gpus'])))

        self.model.eval()

        if self.gpu_rank == 0:
            start_time = time.time()

        for batch in data_loader_val:
            num_batch_val += 1

            pct = num_batch_val / num_iter_per_epoch_val * 100

            with torch.no_grad():
                output = self.forward(batch)

            losses = output['losses']

            losses['tot_loss'] /= self.config['batch_multiply']
            losses[elbo_global_key] /= self.config['batch_multiply']
            losses[elbo_local_key] /= self.config['batch_multiply']
            losses[gen_key] /= self.config['batch_multiply']

            total_loss_val += losses['tot_loss'].item()
            total_gen_loss_val += losses[gen_key].item()
            total_elbo_global_val += losses[elbo_global_key].item()
            total_elbo_local_val += losses[elbo_local_key].item()

            # display and eval
            if pct >= next_logging_pct_val:
                if self.config['display']:
                    loss_to_print = ''
                    for key in losses:
                        if losses[key] is not None and isinstance(losses[key], torch.Tensor):
                            loss_to_print += f'[{key}: {losses[key].item():.4f}]'
                    print(
                        f'[{int(pct)}%][Validating][Iter : {num_batch_val}/{num_iter_per_epoch_val}] [time: {time.time() - start_time:.2f}] {loss_to_print}'
                    )

                next_logging_pct_val += self.config["next_logging_pct"]
        loss_val = total_loss_val / num_batch_val
        gen_loss_val = total_gen_loss_val / num_batch_val
        elbo_global_val = total_elbo_global_val / num_batch_val
        elbo_local_val = total_elbo_local_val / num_batch_val

        losses_val = {
            'tot_loss': loss_val,
            elbo_global_key: elbo_global_val,
            elbo_local_key: elbo_local_val,
            gen_key: gen_loss_val
        }
        self.model.train()
        return losses_val
    def save_eval_results(self, split, epoch_idx, metrics_results):

        metrics_filename = osp.join(self.config['log_dir'], f'metrics_epoch_{epoch_idx}.json')
        with open(metrics_filename, 'w') as f:
            json.dump(metrics_results, f)
        log.info(f'Results of metrics saved to {metrics_filename}')

        if self.config["max_ckpt_to_keep"] > 0:
            if len(self.metrics_queue) == self.metrics_queue.maxlen:
                todel = self.metrics_queue.popleft()
                os.remove(todel)
            self.metrics_queue.append(metrics_filename)

        if epoch_idx == 'best':
            self.copy_best_predictions(split)

    def copy_best_results(self, split, epoch_idx):
        to_print = 'Copy '

        if not self.config['skip_saving_ckpt']:
            ckpt_path = osp.join(self.config['log_dir'], f'epoch_{epoch_idx}.ckpt')
            best_ckpt_path = ckpt_path.replace(f'{epoch_idx}.ckpt', 'best.ckpt')
            shutil.copyfile(ckpt_path, best_ckpt_path)
            to_print += best_ckpt_path + ' '

        metrics_filename = osp.join(self.config['log_dir'], f'metrics_epoch_{epoch_idx}.json')
        best_metric_filename = metrics_filename.replace(f'{epoch_idx}.json', 'best.json')
        shutil.copyfile(metrics_filename, best_metric_filename)
        to_print += best_metric_filename + ' '

        log.info(to_print)
    def set_ckpt(self, ckpt_dict):
        if self.config['parallel']:
            model = self.model.module
        else:
            model = self.model

        model_state_dict = model.state_dict()

        former_dict = {k: v for k, v in ckpt_dict['model_state_dict'].items() if k in model_state_dict}

        if self.config['display']:
            log.info("number of keys transferred: %d" % len(former_dict))
        assert len(former_dict.keys()) > 0

        model_state_dict.update(former_dict)

        model.load_state_dict(model_state_dict)
        if self.config['display']:
            log.info('loaded model')
        del model_state_dict, former_dict

        # if not self.config['uses_new_optimizer']:
        if not self.config['generating'] and not (self.config['uses_new_optimizer'] or self.config['sets_new_lr']):
            if not self.config['restarts']:
                self.epoch_idx = ckpt_dict['epoch_idx'] + 1

            if not self.config['resets_min_val_loss']:
                self.min_gen_val_loss = ckpt_dict['min_gen_val_loss']

            self.optimizer.load_state_dict(ckpt_dict['optimizer'])
            if self.config['display']:
                log.info('loaded optimizer')
            if 'scheduler' in ckpt_dict:
                self.scheduler.last_epoch = ckpt_dict['epoch_idx'] * self.config['num_iter_per_epoch']
                self.scheduler.load_state_dict(ckpt_dict['scheduler'])

        del ckpt_dict

        torch.cuda.empty_cache()
    def save_ckpt(self):
        ckpt_path = f'{self.config["log_dir"]}/epoch_{self.epoch_idx}.ckpt'
        log.info(f'saving checkpoint {ckpt_path}')
        ckpt = self.get_ckpt()
        if self.config['skip_saving_ckpt']:
            return ckpt_path
        # PyTorch versions newer than 1.4 default to zipfile serialization;
        # keep the legacy format for backward compatibility.
        torch_version = tuple(int(v) for v in torch.__version__.split('.')[:2])
        if torch_version > (1, 4):
            torch.save(ckpt, f=ckpt_path, _use_new_zipfile_serialization=False)
        else:
            torch.save(ckpt, f=ckpt_path)
        del ckpt
        if not self.config['parallel']:
            torch.cuda.empty_cache()

        if self.config["max_ckpt_to_keep"] > 0:
            if len(self.checkpoint_queue) == self.checkpoint_queue.maxlen:
                todel = self.checkpoint_queue.popleft()
                os.remove(todel)
            self.checkpoint_queue.append(ckpt_path)

    def save_ckpt_best(self):
        ckpt_path = f'{self.config["log_dir"]}/epoch_best.ckpt'
        log.info(f'saving checkpoint {ckpt_path}')
        ckpt = self.get_ckpt()
        torch.save(ckpt, f=ckpt_path)
        del ckpt
        return ckpt_path
    def load_ckpt_best(self):
        ckpt_path = f'{self.config["log_dir"]}/epoch_best.ckpt'
        if not osp.exists(ckpt_path):
            # Fall back to the most recent numbered checkpoint
            ckpt_paths = [path for path in os.listdir(f'{self.config["log_dir"]}/') if path.endswith('.ckpt') and 'best' not in path]
            if len(ckpt_paths) == 0:
                if self.config['display']:
                    log.info(f'No .ckpt found in {self.config["log_dir"]}')
                return
            sort_func = lambda x: int(re.search(r"(\d+)", x).groups()[0])
            ckpt_path = f'{self.config["log_dir"]}/{sorted(ckpt_paths, key=sort_func, reverse=True)[0]}'
        if self.config['display']:
            log.info(f'loading checkpoint {ckpt_path}')
        map_location = {'cuda:0': f'cuda:{self.gpu_rank}'}
        self.set_ckpt(torch.load(ckpt_path, map_location=map_location))

    def get_ckpt(self):
        ckpt = {
            'epoch_idx': self.epoch_idx,
            'min_gen_val_loss': self.min_gen_val_loss,
            'seed': self.config['random_seed'],
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict()
        }
        if self.config['parallel']:
            ckpt['model_state_dict'] = self.model.module.state_dict()
        else:
            ckpt['model_state_dict'] = self.model.state_dict()
        return ckpt
    def load_ckpt(self, ckpt_path=None):
        if not ckpt_path:
            if self.config['generating']:  # or self.config['start_ckpt_for_generating']:
                ckpt_path = f'{self.config["log_dir"]}/epoch_best.ckpt'
            else:
                ckpt_paths = [path for path in os.listdir(f'{self.config["log_dir"]}/') if path.endswith('.ckpt') and 'best' not in path]
                if len(ckpt_paths) == 0:
                    if self.config['display']:
                        log.info(f'No .ckpt found in {self.config["log_dir"]}')
                    return
                sort_func = lambda x: int(re.search(r"(\d+)", x).groups()[0])
                ckpt_path = f'{self.config["log_dir"]}/{sorted(ckpt_paths, key=sort_func, reverse=True)[0]}'

        if self.config['display']:
            log.info(f'loading checkpoint {ckpt_path}')
        epoch_name = osp.split(ckpt_path)[1].split('.')[0]
        if re.search(r"(\d+)", epoch_name):
            self.checkpoint_queue.append(ckpt_path)
            metrics_filename = osp.join(self.config['log_dir'], f'metrics_{epoch_name}.json')
            if osp.exists(metrics_filename):
                self.metrics_queue.append(metrics_filename)

        map_location = {'cuda:0': f'cuda:{self.gpu_rank}'}
        self.set_ckpt(torch.load(ckpt_path, map_location=map_location))

    def match_model_key(self, pretrained_dict, model_dict):
        matched_dict = dict()
        not_found = []
        for key in pretrained_dict:
            if key in model_dict:
                matched_key = key
            elif key.startswith('encoder.') and key[8:] in model_dict:
                matched_key = key[8:]
            elif key.startswith('module.') and key[7:] in model_dict:
                matched_key = key[7:]
            elif 'encoder.' + key in model_dict:
                matched_key = 'encoder.' + key
            elif 'module.' + key in model_dict:
                matched_key = 'module.' + key
            else:
                not_found.append(key)
                continue
            matched_dict[matched_key] = pretrained_dict[key]
        print("Keys from pretrained_dict that were not found in model_dict:\n", not_found)
        return matched_dict
337
runners/runner_avsd.py
Normal file
@@ -0,0 +1,337 @@
import time
|
||||||
|
import os
|
||||||
|
import glog as log
|
||||||
|
import numpy as np
|
||||||
|
import json
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from runners.runner import Runner
|
||||||
|
from copy import deepcopy
|
||||||
|
from optim_utils import init_optim
|
||||||
|
from transformers.models.bart.configuration_bart import BartConfig
|
||||||
|
from models.avsd_bart import AVSDBart
|
||||||
|
|
||||||
|
from custom_datasets.avsd import build_input_from_segments
|
||||||
|
from time import time
|
||||||
|
|
||||||
|
|
||||||
|
class AVSDRunner(Runner):
|
||||||
|
def __init__(self, config, tokenizer, vocab_size):
|
||||||
|
super(AVSDRunner, self).__init__(config)
|
||||||
|
bart_config = BartConfig.from_json_file(self.config['bart_config'])
|
||||||
|
|
||||||
|
self.model = AVSDBart.from_pretrained(
|
||||||
|
'facebook/bart-{}'.format(self.config['bart_size']), config=bart_config)
|
||||||
|
|
||||||
|
# Resize the embedding to match the vocab with additional special toks
|
||||||
|
# This takes care of resizing weights of related parts of the network
|
||||||
|
# pytorch_total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
|
||||||
|
# print(pytorch_total_params)
|
||||||
|
|
||||||
|
if vocab_size != bart_config.vocab_size:
|
||||||
|
self.model.resize_token_embeddings(vocab_size)
|
||||||
|
|
||||||
|
self.model.to(self.config['device'])
|
||||||
|
if not self.config['generating']:
|
||||||
|
self.optimizer, self.scheduler = init_optim(self.model, self.config)
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
|
||||||
|
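
    # forward() runs one training/validation step: tensors are moved to the GPU,
    # the multi-modal inputs (I3D RGB/flow, SAM object features, VGGish audio,
    # dialog history and question) are passed to the BART backbone together with
    # their interval indices, and the weighted losses are collected for logging.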
    def forward(self, batch):
        for key in batch:
            if isinstance(batch[key], torch.Tensor):
                batch[key] = batch[key].cuda()

        ########################################################
        input_ids = batch['input_ids']
        video_place_holder_ids = batch['video_place_holder_ids']
        i3d_rgb = batch['i3d_rgb']
        i3d_flow = batch['i3d_flow']
        sam = batch['sam']
        vggish = batch['vggish']
        lm_labels = batch['lm_labels']
        input_mask = batch['input_mask']

        i3d_rgb_interval = batch['i3d_rgb_interval']
        i3d_flow_interval = batch['i3d_flow_interval']
        sam_interval = batch['sam_interval']
        audio_interval = batch['audio_interval']
        history_intervals = batch['history_intervals']
        question_intervals = batch['question_intervals']
        vis_state_vector_idx = batch['vis_state_vector_idx']
        history_state_vector_idx = batch['history_state_vector_idx']
        question_state_vector_idx = batch['question_state_vector_idx']

        ########################################################
        bart_output = self.model(
            input_ids=input_ids,
            video_place_holder_ids=video_place_holder_ids,
            i3d_rgb=i3d_rgb,
            i3d_flow=i3d_flow,
            sam=sam,
            vggish=vggish,
            attention_mask=input_mask,
            labels=lm_labels,
            i3d_rgb_interval=i3d_rgb_interval,
            i3d_flow_interval=i3d_flow_interval,
            sam_interval=sam_interval,
            audio_interval=audio_interval,
            history_intervals=history_intervals,
            question_intervals=question_intervals,
            vis_state_vector_idx=vis_state_vector_idx,
            history_state_vector_idx=history_state_vector_idx,
            question_state_vector_idx=question_state_vector_idx,
            output_attentions=True,
            return_dict=True
        )

        output = {}

        if self.config['print_output']:
            logits = bart_output['logits']
            probs = F.softmax(logits, dim=-1)
            preds = torch.topk(probs, 1)[1].squeeze(-1)
            preds = preds.tolist()
            lm_labels_list = lm_labels[:, 1:].tolist()
            lm_labels_list = [[s for s in label if s != -1] for label in lm_labels_list]
            reponses = ''
            labels = ''
            for pred, label in zip(preds, lm_labels_list):
                reponses += self.tokenizer.decode(pred) + '\n'
                labels += self.tokenizer.decode(label) + '\n'

            output['reponses'] = reponses
            output['gt'] = labels
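
        # The training objective is a weighted sum of the generation loss and the
        # global/local ELBO terms; either ELBO loss may be None (its component
        # not active), in which case it contributes 0 to the total.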
        gen_key = 'gen_loss (x{})'.format(self.config['gen_coeff'])
        gen_loss = bart_output['gen_loss']
        gen_loss = self.config['gen_coeff'] * gen_loss

        elbo_global_key = 'elbo_loss_global (x{})'.format(self.config['elbo_global_coeff'])
        if bart_output['elbo_loss_global'] is not None:
            elbo_global_loss = bart_output['elbo_loss_global']
            elbo_global_loss = self.config['elbo_global_coeff'] * elbo_global_loss
        else:
            elbo_global_loss = torch.tensor(0.0)

        elbo_local_key = 'elbo_loss_local (x{})'.format(self.config['elbo_local_coeff'])
        if bart_output['elbo_loss_local'] is not None:
            elbo_local_loss = bart_output['elbo_loss_local']
            elbo_local_loss = self.config['elbo_local_coeff'] * elbo_local_loss
        else:
            elbo_local_loss = torch.tensor(0.0)

        total_loss = gen_loss + elbo_global_loss + elbo_local_loss

        output['losses'] = {
            gen_key: gen_loss,
            elbo_local_key: elbo_local_loss,
            elbo_global_key: elbo_global_loss,
            'tot_loss': total_loss
        }
        del bart_output
        return output
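
    # generate() decodes a response for every test dialog: the four feature
    # streams are loaded from disk, uniformly subsampled to a common length,
    # framed by the modality separator tokens <s0>..<s3>, and handed to beam search.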
    def generate(self, dataset, tag, tokenizer, gen_subset_num=None):
        self.model.eval()
        responses = {}
        i3d_flow_sep, i3d_rgb_sep, sam_sep, audio_sep, ph_token = tokenizer.convert_tokens_to_ids(
            ['<s0>', '<s1>', '<s2>', '<s3>', '<place_holder>'])

        # Generate the response for each round
        log.info('[INFO] Generating responses for {} samples'.format(len(dataset)))
        with torch.no_grad():
            for counter, dialog in enumerate(dataset):
                start_time = time()
                vid = dialog['vid']

                i3d_rgb = np.load(os.path.join(self.config['avsd_i3d_rgb_test'], vid + '.npy'))
                i3d_flow = np.load(os.path.join(self.config['avsd_i3d_flow_test'], vid + '.npy'))
                sam = np.load(os.path.join(self.config['avsd_objects_test'], vid + '.npy'))
                vggish = np.load(os.path.join(self.config['avsd_audio_test'], vid + '.npy'))

                min_length = min([self.config['vis_feat_length'], i3d_rgb.shape[0], i3d_flow.shape[0], sam.shape[0], vggish.shape[0]])
                sample_idx_i3d_rgb = np.round(np.linspace(0, i3d_rgb.shape[0] - 1, min_length)).astype(int)
                sample_idx_i3d_flow = np.round(np.linspace(0, i3d_flow.shape[0] - 1, min_length)).astype(int)
                sample_idx_sam = np.round(np.linspace(0, sam.shape[0] - 1, min_length)).astype(int)
                sample_idx_vggish = np.round(np.linspace(0, vggish.shape[0] - 1, min_length)).astype(int)

                i3d_rgb = torch.from_numpy(i3d_rgb[sample_idx_i3d_rgb, :]).float()
                i3d_flow = torch.from_numpy(i3d_flow[sample_idx_i3d_flow, :]).float()
                sam = torch.from_numpy(sam[sample_idx_sam, :]).float()
                vggish = torch.from_numpy(vggish[sample_idx_vggish, :]).float()

                dummy = torch.ones((1, min_length)) * ph_token
                video_place_holder_ids = torch.cat(
                    [torch.ones((1, 1)) * i3d_rgb_sep, dummy,
                     torch.ones((1, 1)) * i3d_flow_sep, dummy,
                     torch.ones((1, 1)) * sam_sep, dummy,
                     torch.ones((1, 1)) * audio_sep, dummy,
                     ], dim=-1).long()
                # Now we get the intervals of the visual input tokens
                # Here the intervals do not change across the batch dimension
                i3d_rgb_interval = [0, min_length + 1]  # the last token is not part of this modality
                i3d_flow_interval = [min_length + 1, 2 * min_length + 2]
                sam_interval = [2 * min_length + 2, 3 * min_length + 3]
                audio_interval = [3 * min_length + 3, 4 * min_length + 4]
                vis_state_vector_idx = [i3d_rgb_interval[0], i3d_flow_interval[0], sam_interval[0], audio_interval[0]]

                response = self.beam_search_generation(
                    dialog['caption'], dialog['history'],
                    i3d_rgb, i3d_flow, sam, vggish,
                    i3d_rgb_interval, i3d_flow_interval, sam_interval, audio_interval,
                    vis_state_vector_idx, video_place_holder_ids, tokenizer)

                # Decode the response
                response = self.tokenizer.decode(response)
                responses[vid] = response
                # all_graphs[vid] = graphs
                time_elapsed = int(time() - start_time)
                print('Generating response {} / {} -- took {}s'.format(counter + 1, len(dataset), time_elapsed))

        # Create a file with all responses
        with open(self.config['avsd_test_dstc{}'.format(self.config['dstc'])], 'r') as f:
            test_data = json.load(f)
        test_dialogs = deepcopy(test_data['dialogs'])
        # Filter the predicted dialogs
        test_dialogs = list(filter(lambda diag: diag['image_id'] in responses, test_dialogs))

        for i, dialog in enumerate(test_dialogs):
            vid_id = dialog['image_id']
            gen_response = responses[vid_id]
            round_num_to_answer = len(dialog['dialog']) - 1
            assert dialog['dialog'][round_num_to_answer]['answer'] == '__UNDISCLOSED__'
            dialog['dialog'][round_num_to_answer]['answer'] = gen_response
            test_dialogs[i] = dialog

        # Log the file
        file_name = 'results_dstc{}_beam_depth_{}'.format(self.config['dstc'], self.config['beam_depth'])
        if gen_subset_num is not None:
            file_name += f'-part_{gen_subset_num}'
        file_name = f'{tag}_' + file_name
        output_path = os.path.join(self.config['output_dir_dstc{}'.format(self.config['dstc'])], file_name + '.json')
        with open(output_path, 'w') as f:
            json.dump({'dialogs': test_dialogs}, f, indent=4)
        log.info('Results logged to {}'.format(output_path))
        print(os.getcwd())
        # Switch back to training mode
        self.model.train()
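
    # Each beam hypothesis is a tuple (tokens_so_far, accumulated_log_prob,
    # decoder_input_state); completed hypotheses move to comp_hyplist together
    # with a length-penalised score, and the best-scoring one is returned.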
    def beam_search_generation(
            self, caption, history,
            i3d_rgb, i3d_flow, sam, vggish,
            i3d_rgb_interval, i3d_flow_interval, sam_interval, audio_interval,
            vis_state_vector_idx, video_place_holder_ids, tokenizer):

        eos_token = tokenizer.eos_token_id
        unk_token = tokenizer.unk_token_id
        question_sep = tokenizer.convert_tokens_to_ids('<s5>')

        gen_ans = [eos_token]
        hyplist = [([], 0.0, [eos_token])]
        best_state = None
        comp_hyplist = []

        i3d_rgb = i3d_rgb.unsqueeze(0).cuda()
        i3d_flow = i3d_flow.unsqueeze(0).cuda()
        sam = sam.unsqueeze(0).cuda()
        vggish = vggish.unsqueeze(0).cuda()
        video_place_holder_ids = video_place_holder_ids.cuda()
        text_shift_len = video_place_holder_ids.size(-1)

        drop_caption = self.config['dstc'] == 10
        instance = build_input_from_segments(caption, history, gen_ans, tokenizer, drop_caption=drop_caption)

        input_ids = torch.tensor(instance['input_ids'])
        history_end = (input_ids == question_sep).nonzero(as_tuple=True)[0]
        history_intervals = [[0 + text_shift_len, history_end.item() + text_shift_len]]  # The last token is the question state token (not part of the history)
        question_intervals = [[history_end.item() + text_shift_len, input_ids.size(0) + text_shift_len]]

        history_state_vector_idx = [x[0] + 1 for x in history_intervals]  # +1 because the history starts with <s><s4> .....
        question_state_vector_idx = [x[0] for x in question_intervals]

        input_ids = input_ids.long().cuda().unsqueeze(0)
        encoder_outputs = None
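
        # The full encoder pass runs only on the first decoding step; its cached
        # outputs (hidden states, attentions, mixer graph tensors, state vectors)
        # are reused for every subsequent step and beam.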
        for i in range(self.config['max_generation_length']):
            new_hyplist = []
            argmin = 0
            for out, lp, st in hyplist:
                decoder_input_ids = torch.tensor(st).long().cuda().unsqueeze(0)

                bart_output = self.model(
                    input_ids=input_ids,
                    video_place_holder_ids=video_place_holder_ids,
                    i3d_rgb=i3d_rgb,
                    i3d_flow=i3d_flow,
                    sam=sam,
                    vggish=vggish,
                    encoder_outputs=encoder_outputs,
                    decoder_input_ids=decoder_input_ids,
                    i3d_rgb_interval=i3d_rgb_interval,
                    i3d_flow_interval=i3d_flow_interval,
                    sam_interval=sam_interval,
                    audio_interval=audio_interval,
                    history_intervals=history_intervals,
                    question_intervals=question_intervals,
                    vis_state_vector_idx=vis_state_vector_idx,
                    history_state_vector_idx=history_state_vector_idx,
                    question_state_vector_idx=question_state_vector_idx,
                    output_attentions=True,
                    generate=True,
                    return_dict=True
                )

                if encoder_outputs is None:
                    encoder_outputs = [
                        bart_output['encoder_last_hidden_state'],
                        bart_output['encoder_hidden_states'],
                        bart_output['encoder_attentions'],
                        bart_output['encoder_QAs_local'],
                        bart_output['encoder_PAs_local'],
                        bart_output['encoder_QA_global'],
                        bart_output['encoder_PA_global'],
                        bart_output['encoder_state_vectors']
                    ]

                logits = bart_output['logits'][:, -1, :].squeeze()  # get the logits of the last token
                logp = F.log_softmax(logits, dim=0)
                lp_vec = logp.cpu().data.numpy() + lp
                if i >= self.config['min_generation_length']:
                    new_lp = lp_vec[eos_token] + self.config['length_penalty'] * (len(out) + 1)
                    comp_hyplist.append((out, new_lp))
                    if best_state is None or best_state < new_lp:
                        best_state = new_lp
                count = 1
                for o in np.argsort(lp_vec)[::-1]:  # reverse the order
                    if o in [eos_token, unk_token]:
                        continue
                    new_lp = lp_vec[o]
                    if len(new_hyplist) == self.config['beam_depth']:
                        if new_hyplist[argmin][1] < new_lp:
                            new_st = deepcopy(st)
                            new_st.append(int(o))
                            new_hyplist[argmin] = (out + [o], new_lp, new_st)
                            argmin = min(enumerate(new_hyplist), key=lambda h: h[1][1])[0]
                        else:
                            break
                    else:
                        new_st = deepcopy(st)
                        new_st.append(int(o))
                        new_hyplist.append((out + [o], new_lp, new_st))
                        if len(new_hyplist) == self.config['beam_depth']:
                            argmin = min(enumerate(new_hyplist), key=lambda h: h[1][1])[0]
                    count += 1
            hyplist = new_hyplist

        if len(comp_hyplist) > 0:
            maxhyps = sorted(comp_hyplist, key=lambda h: -h[1])[:1]
            return maxhyps[0][0]
        else:
            return []
300
runners/runner_nextqa.py
Normal file
@ -0,0 +1,300 @@
import os
import glog as log
import numpy as np
import json
import torch
import torch.nn.functional as F
from runners.runner import Runner
from copy import deepcopy
from optim_utils import init_optim
from transformers.models.bart.configuration_bart import BartConfig
from models.nextqa_bart import AVSDBart
from time import time


def tokenize(obj, tokenizer):
    if isinstance(obj, str):
        return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
    if isinstance(obj, dict):
        return dict((n, tokenize(o, tokenizer)) for n, o in obj.items())
    return list(tokenize(o, tokenizer) for o in obj)


class NEXTQARunner(Runner):
    def __init__(self, config, tokenizer, vocab_size):
        super(NEXTQARunner, self).__init__(config)
        bart_config = BartConfig.from_json_file(self.config['bart_config'])

        self.model = AVSDBart.from_pretrained(
            'facebook/bart-{}'.format(self.config['bart_size']), config=bart_config)

        # Resize the embedding to match the vocab with additional special toks
        # This takes care of resizing weights of related parts of the network

        if vocab_size != bart_config.vocab_size:
            self.model.resize_token_embeddings(vocab_size)

        self.model.to(self.config['device'])
        if not self.config['generating']:
            self.optimizer, self.scheduler = init_optim(self.model, self.config)
        self.tokenizer = tokenizer
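
    # The NExT-QA runner reuses the AVSD BART interface: appearance and motion
    # features are routed through the i3d_rgb / i3d_flow slots, and there is no
    # dialog history, so only question intervals and state vectors are passed.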
    def forward(self, batch):
        for key in batch:
            if isinstance(batch[key], torch.Tensor):
                batch[key] = batch[key].cuda()

        ########################################################
        input_ids = batch['input_ids']
        video_place_holder_ids = batch['video_place_holder_ids']
        app_feats = batch['app_feats']
        mot_feats = batch['mot_feats']
        lm_labels = batch['lm_labels']
        input_mask = batch['input_mask']

        app_interval = batch['app_interval']
        mot_interval = batch['mot_interval']
        question_intervals = batch['question_intervals']
        vis_state_vector_idx = batch['vis_state_vector_idx']
        question_state_vector_idx = batch['question_state_vector_idx']
        ########################################################

        bart_output = self.model(
            input_ids=input_ids,
            video_place_holder_ids=video_place_holder_ids,
            i3d_rgb=app_feats,
            i3d_flow=mot_feats,
            attention_mask=input_mask,
            labels=lm_labels,
            i3d_rgb_interval=app_interval,
            i3d_flow_interval=mot_interval,
            question_intervals=question_intervals,
            vis_state_vector_idx=vis_state_vector_idx,
            question_state_vector_idx=question_state_vector_idx,
            output_attentions=True,
            return_dict=True
        )

        output = {}

        if self.config['print_output']:
            logits = bart_output['logits']
            probs = F.softmax(logits, dim=-1)
            preds = torch.topk(probs, 1)[1].squeeze(-1)
            preds = preds.tolist()
            lm_labels_list = lm_labels[:, 1:].tolist()
            lm_labels_list = [[s for s in label if s != -1] for label in lm_labels_list]
            reponses = ''
            labels = ''
            for pred, label in zip(preds, lm_labels_list):
                reponses += self.tokenizer.decode(pred) + '\n'
                labels += self.tokenizer.decode(label) + '\n'

            output['reponses'] = reponses
            output['gt'] = labels

        gen_key = 'gen_loss (x{})'.format(self.config['gen_coeff'])
        gen_loss = bart_output['gen_loss']
        gen_loss = self.config['gen_coeff'] * gen_loss

        elbo_global_key = 'elbo_loss_global (x{})'.format(self.config['elbo_global_coeff'])
        if bart_output['elbo_loss_global'] is not None:
            elbo_global_loss = bart_output['elbo_loss_global']
            elbo_global_loss = self.config['elbo_global_coeff'] * elbo_global_loss
        else:
            elbo_global_loss = torch.tensor(0.0)

        elbo_local_key = 'elbo_loss_local (x{})'.format(self.config['elbo_local_coeff'])
        if bart_output['elbo_loss_local'] is not None:
            elbo_local_loss = bart_output['elbo_loss_local']
            elbo_local_loss = self.config['elbo_local_coeff'] * elbo_local_loss
        else:
            elbo_local_loss = torch.tensor(0.0)

        total_loss = gen_loss + elbo_global_loss + elbo_local_loss

        output['losses'] = {
            gen_key: gen_loss,
            elbo_local_key: elbo_local_loss,
            elbo_global_key: elbo_global_loss,
            'tot_loss': total_loss
        }
        del bart_output
        return output
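
    # generate() answers the questions with indices in [start_idx_gen, end_idx_gen):
    # each video contributes 16 placeholder tokens per modality, so the visual
    # intervals are fixed rather than derived from the feature length.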
    def generate(self, dataset, app_feats, mot_feats, tag, tokenizer, start_idx_gen, end_idx_gen, gen_subset_num=None):
        self.model.eval()
        results = {}
        app_sep, mot_sep, ph_token = tokenizer.convert_tokens_to_ids(
            ['<s0>', '<s1>', '<place_holder>'])

        # Generate the response for each round
        log.info('[INFO] Generating responses for {} samples'.format(len(dataset)))
        with torch.no_grad():
            counter = 0
            for idx in range(start_idx_gen, end_idx_gen):
                start_time = time()
                cur_sample = dataset.loc[idx]
                video_name, ques, ans, qid = str(cur_sample['video']), str(cur_sample['question']),\
                    str(cur_sample['answer']), str(cur_sample['qid'])
                if video_name not in results:
                    results[video_name] = {}

                input_ids = tokenize(ques, tokenizer)

                app_feat = app_feats[video_name]
                app_feat = torch.from_numpy(app_feat).type(torch.float32)

                mot_feat = mot_feats[video_name]
                mot_feat = torch.from_numpy(mot_feat).type(torch.float32)

                bos, eos, ques_state = self.tokenizer.convert_tokens_to_ids(['<s>', '</s>', '<s2>'])

                # Add state tokens
                input_ids.insert(0, ques_state)

                input_ids = torch.Tensor(input_ids).long()

                dummy = torch.ones((1, 16)) * ph_token
                video_place_holder_ids = torch.cat(
                    [torch.ones((1, 1)) * app_sep, dummy,
                     torch.ones((1, 1)) * mot_sep, dummy,
                     ], dim=-1).long()

                # Now we get the intervals of the visual input tokens
                # Here the intervals do not change across the batch dimension
                app_interval = [0, 16 + 1]  # the last token is not part of this modality
                mot_interval = [16 + 1, 2 * 16 + 2]
                vis_state_vector_idx = [app_interval[0], mot_interval[0]]

                response = self.beam_search_generation(
                    input_ids,
                    app_feat, mot_feat,
                    app_interval, mot_interval,
                    vis_state_vector_idx, video_place_holder_ids, tokenizer)

                # Decode the response
                response = self.tokenizer.decode(response)

                results[video_name][qid] = response
                time_elapsed = int(time() - start_time)
                print('Generating response {} / {} -- took {}s'.format(counter + 1, len(dataset), time_elapsed))
                counter += 1

        # Create a file with all responses
        file_name = 'results_nextqa_beam_depth_{}'.format(self.config['beam_depth'])
        if gen_subset_num is not None:
            file_name += f'-part_{gen_subset_num}'
        file_name = f'{tag}_' + file_name
        output_path = os.path.join(self.config['output_dir_nextqa'], file_name + '.json')
        with open(output_path, 'w') as f:
            json.dump(results, f, indent=4)
        log.info('Results logged to {}'.format(output_path))
        print(os.getcwd())
        # Switch back to training mode
        self.model.train()
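
    # Same beam-search procedure as in the AVSD runner, but conditioned only on
    # the question (no caption or history) and on the two visual feature streams.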
    def beam_search_generation(
            self, input_ids,
            app_feat, mot_feat,
            app_interval, mot_interval,
            vis_state_vector_idx, video_place_holder_ids, tokenizer):

        eos_token = tokenizer.eos_token_id
        unk_token = tokenizer.unk_token_id
        question_sep = tokenizer.convert_tokens_to_ids('<s2>')

        gen_ans = [eos_token]
        hyplist = [([], 0.0, [eos_token])]
        best_state = None
        comp_hyplist = []

        app_feat = app_feat.unsqueeze(0).cuda()
        mot_feat = mot_feat.unsqueeze(0).cuda()
        video_place_holder_ids = video_place_holder_ids.cuda()
        text_shift_len = video_place_holder_ids.size(-1)

        question_intervals = [[0 + text_shift_len, input_ids.size(0) + text_shift_len]]

        question_state_vector_idx = [x[0] for x in question_intervals]

        input_ids = input_ids.long().cuda().unsqueeze(0)
        encoder_outputs = None

        for i in range(self.config['max_generation_length']):
            new_hyplist = []
            argmin = 0
            for out, lp, st in hyplist:
                decoder_input_ids = torch.tensor(st).long().cuda().unsqueeze(0)

                bart_output = self.model(
                    input_ids=input_ids,
                    video_place_holder_ids=video_place_holder_ids,
                    i3d_rgb=app_feat,
                    i3d_flow=mot_feat,
                    encoder_outputs=encoder_outputs,
                    decoder_input_ids=decoder_input_ids,
                    i3d_rgb_interval=app_interval,
                    i3d_flow_interval=mot_interval,
                    question_intervals=question_intervals,
                    vis_state_vector_idx=vis_state_vector_idx,
                    question_state_vector_idx=question_state_vector_idx,
                    output_attentions=True,
                    generate=True,
                    return_dict=True
                )

                if encoder_outputs is None:
                    encoder_outputs = [
                        bart_output['encoder_last_hidden_state'],
                        bart_output['encoder_hidden_states'],
                        bart_output['encoder_attentions'],
                        bart_output['encoder_QAs_local'],
                        bart_output['encoder_PAs_local'],
                        bart_output['encoder_QA_global'],
                        bart_output['encoder_PA_global'],
                        bart_output['encoder_state_vectors']
                    ]

                logits = bart_output['logits'][:, -1, :].squeeze()  # get the logits of the last token
                logp = F.log_softmax(logits, dim=0)
                lp_vec = logp.cpu().data.numpy() + lp
                if i >= self.config['min_generation_length']:
                    new_lp = lp_vec[eos_token] + self.config['length_penalty'] * (len(out) + 1)
                    comp_hyplist.append((out, new_lp))
                    if best_state is None or best_state < new_lp:
                        best_state = new_lp
                count = 1
                for o in np.argsort(lp_vec)[::-1]:  # reverse the order
                    if o in [eos_token, unk_token]:
                        continue
                    new_lp = lp_vec[o]
                    if len(new_hyplist) == self.config['beam_depth']:
                        if new_hyplist[argmin][1] < new_lp:
                            new_st = deepcopy(st)
                            new_st.append(int(o))
                            new_hyplist[argmin] = (out + [o], new_lp, new_st)
                            argmin = min(enumerate(new_hyplist), key=lambda h: h[1][1])[0]
                        else:
                            break
                    else:
                        new_st = deepcopy(st)
                        new_st.append(int(o))
                        new_hyplist.append((out + [o], new_lp, new_st))
                        if len(new_hyplist) == self.config['beam_depth']:
                            argmin = min(enumerate(new_hyplist), key=lambda h: h[1][1])[0]
                    count += 1
            hyplist = new_hyplist

        if len(comp_hyplist) > 0:
            maxhyps = sorted(comp_hyplist, key=lambda h: -h[1])[:1]
            return maxhyps[0][0]
        else:
            return []