Make code public

This commit is contained in:
Adnen Abdessaied 2024-07-08 11:41:28 +02:00
commit 8e03ef1c38
49 changed files with 545354 additions and 0 deletions

3
.gitattributes vendored Normal file
View file

@ -0,0 +1,3 @@
*.pkl filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.csv filter=lfs diff=lfs merge=lfs -text

21
LICENSE Normal file
View file

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 Anonymous
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

174
README.md Normal file
View file

@ -0,0 +1,174 @@
<div align="center">
<h1> MST-MIXER <img src="misc/mixer.png" width="3%" align="bottom">: Multi-Modal Video Dialog State Tracking in the Wild </h1>
**[Adnen Abdessaied][16], &nbsp; [Lei Shi][17], &nbsp; [Andreas Bulling][18]** <br> <br>
**ECCV 2024, Milan, Italy <img src="misc/italy.png" width="3%" align="center">** <br>
**[[Paper][19]]**
---------------------------
<img src="misc/teaser.png" width="70%" align="middle"><br><br>
</div>
# Citation
If you find our code useful or use it in your own projects, please cite our paper:
```bibtex
@InProceedings{Abdessaied_2024_eccv,
author = {Abdessaied, Adnen and Shi, Lei and Bulling, Andreas},
title = {{Multi-Modal Video Dialog State Tracking in the Wild}},
booktitle = {Proceedings of the European Conference on Computer Vision (ECCV)},
year = {2024}
}
```
# Table of Contents
* [Setup and Dependencies](#Setup-and-Dependencies)
* [Download Data](#Download-Data)
* [Training](#Training)
* [Response Generation](#Response-Generation)
* [Results](#Results)
* [Acknowledgements](#Acknowledgements)
# Setup and Dependencies
We implemented our model using Python 3.7 and PyTorch 1.12.0 (CUDA 11.3, CuDNN 8.3.2). We recommend setting up a virtual environment using Anaconda. <br>
1. Install [git lfs][1] on your system
2. Clone our repository to download a checkpoint of our best model and our code
```shell
git lfs install
git clone this_repo.git
```
3. Create a conda environment and install dependencies (an optional sanity check is sketched right after this list)
```shell
conda create -n mst_mixer python=3.7
conda activate mst_mixer
conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.3 -c pytorch
conda install pyg -c pyg
conda install pytorch-scatter -c pyg # pytorch >= 1.8.0
conda install pytorch-sparse -c pyg # pytorch >= 1.8.0
conda install -c huggingface transformers
pip install evaluate wandb glog pyhocon attrs
```
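Optionally, you can verify the freshly created environment with a quick sanity check (a minimal sketch; it only uses the packages installed above):
```python
# Minimal environment sanity check (optional).
import torch
import torch_geometric
import transformers

print("torch:", torch.__version__, "| CUDA:", torch.version.cuda, "| GPU available:", torch.cuda.is_available())
print("torch_geometric:", torch_geometric.__version__)
print("transformers:", transformers.__version__)
```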
# Download Data
## AVSD
1. Download the [AVSD-DSTC7][2], [AVSD-DSTC8][3] and [AVSD-DSTC10][10] data
2. Place the raw json files in ```raw_data/``` and the features in ```features/```
3. Preprocess and save the input features for faster training as described in ```custom_datasets/``` (a short sketch for inspecting the resulting sample files follows this list)
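Each preprocessed sample is saved as a pickle file holding the tokenized dialog and the visual features. A minimal sketch for inspecting one such file (the path below is only a hypothetical example; it depends on where you stored the preprocessed data and on ```n_history```):
```python
import pickle

# Hypothetical path to one preprocessed AVSD sample
# (one .pkl file per dialog round, produced via custom_datasets/avsd.py).
path = "features/hist_with_3_rounds/train/1.pkl"
with open(path, "rb") as f:
    item = pickle.load(f)

# Expected keys: 'input_ids', 'lm_labels', 'i3d_rgb', 'i3d_flow', 'sam', 'vgg', 'vid'
print(item["vid"], item["input_ids"].shape, item["i3d_rgb"].shape, item["sam"].shape)
```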
## NExT-QA
1. For convenience, we included the features/data in this git repo.
# Training
We trained our model on 8 Nvidia Tesla V100-32GB GPUs. The default hyperparameters in ```config/mst_mixer.conf``` need to be adjusted if your setup differs from ours.
## AVSD
1. Set ```task=avsd``` in ```config/mst_mixer.conf```
2. ```shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
--mode train \
--tag mst_mixer_avsd \
--wandb_mode online \
--wandb_project mst_mixer_avsd
```
To deactivate [wandb][4] logging, use ```--wandb_mode disabled```.
On a similar setup to ours, this will take roughly 20h to complete.
## NExT-QA
1. Set ```task=nextqa``` in ```config/mst_mixer.conf```
2. ```shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
--mode train \
--tag mst_mixer_nextqa \
--wandb_mode online \
--wandb_project mst_mixer_nextqa
```
# Response Generation
## AVSD-DSTC7
1. Set ```dstc=7``` in the ```.conf``` file of your trained network. With the default settings, this file can be found under ```logs/unique_training_tag/code/config/mst_mixer.conf```
2. Generate the responses
```shell
./generate_parallel_avsd.sh mst_mixer/mixer results_avsd_dstc7 generate logs/mst_mixer_avsd 7
```
3. All responses will be saved in ```output/dstc7/```
## AVSD-DSTC8
1. Set ```dstc=8``` in the ```.conf``` file of your trained network. With the default settings, this file can be found under ```logs/unique_training_tag/code/config/mst_mixer.conf```
2. Generate the responses
```shell
./generate_parallel_avsd.sh mst_mixer/mixer results_avsd_dstc8 generate logs/mst_mixer_avsd 8
```
3. All responses will be saved in ```output/dstc8/```
## AVSD-DSTC10
1. Set ```dstc=10``` in the ```.conf``` file of your trained network. With the default settings, this file can be found under ```logs/unique_training_tag/code/config/mst_mixer.conf```
2. Generate the responses
```shell
./generate_parallel_avsd.sh mst_mixer/mixer results_avsd_dstc10 generate logs/mst_mixer_avsd 10
```
3. All responses will be saved in ```output/dstc10/```
## NExT-QA
1. Generate the responses
```shell
./generate_parallel_nextqa.sh mst_mixer/mixer results_nextqa generate logs/mst_mixer_nextqa
```
2. All responses will be saved in ```output/nextqa/```
3. Evaluate the responses using this [script][15]
# Results
To evaluate our best model on the different benchmarks:
## AVSD-DSTC7
Executing the [eval_tool][7] of AVSD-DSTC7 using the generated responses will output the following metrics
| Model | BLEU-1 | BLEU-2 | BLEU-3 | BLEU-4 | METEOR | ROUGE-L | CIDEr |
|:--------:|:------:|:------:|:------:|:------:|:------:|:-------:|:-----:|
| Prev. SOTA | 78.2 | 65.5 | 55.2 | 46.9 | 30.8 | 61.9 | 135.2 |
| MST_MIXER | **78.7** | **66.5** | **56.3** | **47.6** | **31.3** | **62.5** | **138.8**|
## AVSD-DSTC8
1. Set ```dstc=8``` in ```ckpt/code/mst_mixer.conf```
2. Run
```shell
./generate_parallel_avsd.sh mst_mixer/mixer results_avsd_dstc8_best_model generate ckpt/avsd 8
```
3. The responses will be saved in ```output/dstc8/```
4. Executing the [eval_tool][7] of AVSD-DSTC8 using the generated responses will output the following metrics
| Model | BLEU-1 | BLEU-2 | BLEU-3 | BLEU-4 | METEOR | ROUGE-L | CIDEr |
|:--------:|:------:|:------:|:------:|:------:|:------:|:-------:|:-----:|
| Prev. SOTA | 76.4 | 64.1 | 54.3 | 46.0 | 30.1 | 61.0 | 130.4 |
| MST_MIXER | **77.5** | **66.0** | **56.1** | **47.7** | **30.6** | **62.4** | **135.4**|
## AVSD-DSTC10
Executing the [eval_tool][11] of AVSD-DSTC10 using the generated responses will output the following metrics
| Model | BLEU-1 | BLEU-2 | BLEU-3 | BLEU-4 | METEOR | ROUGE-L | CIDEr |
|:--------:|:------:|:------:|:------:|:------:|:------:|:-------:|:-----:|
| Prev. SOTA | 69.3 | 55.6 | 45.0 | 37.2 | 24.9 | 53.6 | 91.2 |
| MST_MIXER | **70.0** | **57.4** | **47.6** | **40.0** | **25.7** | **54.5** | **99.8**|
## NExT-QA
Executing the [eval script][15] of NExT-QA using the generated responses will output the following metrics
| Model | WUPS_C | WUPS_T | WUPS_D | WUPS |
|:--------:|:------:|:------:|:------:|:------:|
| Prev. SOTA | 17.98| 17.95 | 50.84 | 28.40 |
| MST_MIXER | **22.12** | **22.20** | **55.64** | **29.50** |
# Acknowledgements
We thank the authors of [RLM][8] for providing their [code][9] that greatly influenced this work.
[1]: https://git-lfs.com/
[2]: https://github.com/hudaAlamri/DSTC7-Audio-Visual-Scene-Aware-Dialog-AVSD-Challenge
[3]: https://github.com/dialogtekgeek/DSTC8-AVSD_official
[4]: https://wandb.ai/site
[5]: https://drive.google.com/drive/folders/1SlZTySJAk_2tiMG5F8ivxCfOl_OWwd_Q
[7]: https://drive.google.com/file/d/1EKfPtrNBQ5ciKRl6XggImweGRP84XuPi/view?usp=sharing
[8]: https://arxiv.org/abs/2002.00163
[9]: https://github.com/ictnlp/DSTC8-AVSD
[10]: https://drive.google.com/file/d/1zvC6FuPRVRiLQCXZcYpzYUI9r1tiWls6/view
[11]: https://github.com/ankitshah009/AVSD-DSTC10_baseline
[15]: https://github.com/doc-doc/NExT-OE/blob/main/eval_oe.py
[16]: https://adnenabdessaied.de/
[17]: https://perceptualui.org/people/shi/
[18]: https://perceptualui.org/people/bulling/
[19]: https://arxiv.org/abs/2407.02218

118
config/avsd_bart_base.json Normal file
View file

@ -0,0 +1,118 @@
{
"_name_or_path": "bart-base",
"activation_dropout": 0.1,
"activation_function": "gelu",
"add_bias_logits": false,
"add_final_layer_norm": false,
"architectures": [
"BartModel"
],
"attention_dropout": 0.1,
"bos_token_id": 0,
"classif_dropout": 0.1,
"classifier_dropout": 0.0,
"d_model": 768,
"decoder_attention_heads": 12,
"decoder_ffn_dim": 3072,
"decoder_layerdrop": 0.0,
"decoder_layers": 6,
"decoder_start_token_id": 2,
"dropout": 0.1,
"early_stopping": true,
"encoder_attention_heads": 12,
"encoder_ffn_dim": 3072,
"encoder_layerdrop": 0.0,
"encoder_layers": 6,
"eos_token_id": 2,
"forced_eos_token_id": 2,
"forced_bos_token_id": 0,
"gradient_checkpointing": false,
"id2label": {
"0": "LABEL_0",
"1": "LABEL_1",
"2": "LABEL_2"
},
"init_std": 0.02,
"is_encoder_decoder": true,
"label2id": {
"LABEL_0": 0,
"LABEL_1": 1,
"LABEL_2": 2
},
"max_position_embeddings": 1024,
"model_type": "bart",
"no_repeat_ngram_size": 3,
"normalize_before": false,
"normalize_embedding": true,
"num_beams": 4,
"num_hidden_layers": 6,
"pad_token_id": 1,
"scale_embedding": false,
"task_specific_params": {
"summarization": {
"length_penalty": 1.0,
"max_length": 128,
"min_length": 12,
"num_beams": 4
},
"summarization_cnn": {
"length_penalty": 2.0,
"max_length": 142,
"min_length": 56,
"num_beams": 4
},
"summarization_xsum": {
"length_penalty": 1.0,
"max_length": 62,
"min_length": 11,
"num_beams": 6
}
},
"torch_dtype": "float32",
"transformers_version": "4.12.0.dev0",
"use_cache": true,
"vocab_size": 50265,
"d_i3d_flow": 2048,
"d_i3d_rgb": 2048,
"d_sam": 512,
"d_audio": 128,
"top_k": 10,
"num_nn": 4,
"gnn_type": "appnp",
"use_random_graphs": false,
"integrate_all_gnn_features": true,
"use_elbo_local": true,
"use_elbo_global": true,
"use_non_linear": true,
"gnn_K": 2,
"gnn_alpha": 0.1,
"num_modalities": 6,
"local_gnn_d_hidden": 768,
"global_gnn_d_hidden": 768,
"num_local_gnn_heads": 2,
"num_global_gnn_heads": 4,
"local_gnn_dropout": 0.1,
"global_gnn_dropout": 0.1,
"local_fc_dropout": 0.1,
"global_fc_dropout": 0.1,
"num_local_gnn_layers": 2,
"num_global_gnn_layers": 2,
"num_local_fc_layers": 2,
"num_global_fc_layers": 2,
"use_local_gnn_bn": true,
"use_global_gnn_bn": true,
"use_local_fc_bn": true,
"use_global_fc_bn": true,
"local_gnn_concat": true,
"global_gnn_concat": true,
"num_local_gr_learner_heads": 8,
"num_global_gr_learner_heads":8,
"init_adj_ratio": 0.5,
"adj_ratio": 0.5,
"alpha": 0.9,
"gnns_every": 2,
"num_layers_state_fc_decoder": 2,
"dropout_state_fc_decoder": 0.3
}
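The file above follows the Hugging Face BART configuration format, extended with MST-MIXER-specific fields (feature dimensions, graph learner settings, etc.). A minimal sketch of loading it with ```transformers``` (how the training code consumes it is not shown in this excerpt, so treat this as an assumption):
```python
from transformers import BartConfig

# Load the extended BART config; the custom fields (d_i3d_rgb, gnn_type, ...)
# are attached as extra attributes on the config object.
config = BartConfig.from_json_file("config/avsd_bart_base.json")
print(config.d_model, config.num_modalities, config.gnn_type)
```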

116
config/avsd_bart_large.json Normal file
View file

@ -0,0 +1,116 @@
{
"activation_dropout": 0.1,
"activation_function": "gelu",
"add_bias_logits": false,
"add_final_layer_norm": false,
"architectures": [
"BartModel"
],
"attention_dropout": 0.1,
"bos_token_id": 0,
"classif_dropout": 0.1,
"classifier_dropout": 0.0,
"d_model": 1024,
"decoder_attention_heads": 16,
"decoder_ffn_dim": 4096,
"decoder_layerdrop": 0.0,
"decoder_layers": 12,
"decoder_start_token_id": 2,
"dropout": 0.1,
"early_stopping": true,
"encoder_attention_heads": 16,
"encoder_ffn_dim": 4096,
"encoder_layerdrop": 0.0,
"encoder_layers": 12,
"eos_token_id": 2,
"forced_eos_token_id": 2,
"forced_bos_token_id": 0,
"gradient_checkpointing": false,
"id2label": {
"0": "LABEL_0",
"1": "LABEL_1",
"2": "LABEL_2"
},
"init_std": 0.02,
"is_encoder_decoder": true,
"label2id": {
"LABEL_0": 0,
"LABEL_1": 1,
"LABEL_2": 2
},
"max_position_embeddings": 1024,
"model_type": "bart",
"no_repeat_ngram_size": 3,
"normalize_before": false,
"num_beams": 4,
"num_hidden_layers": 12,
"pad_token_id": 1,
"scale_embedding": false,
"task_specific_params": {
"summarization": {
"length_penalty": 1.0,
"max_length": 128,
"min_length": 12,
"num_beams": 4
},
"summarization_cnn": {
"length_penalty": 2.0,
"max_length": 142,
"min_length": 56,
"num_beams": 4
},
"summarization_xsum": {
"length_penalty": 1.0,
"max_length": 62,
"min_length": 11,
"num_beams": 6
}
},
"transformers_version": "4.7.0.dev0",
"use_cache": true,
"vocab_size": 50265,
"d_i3d_flow": 2048,
"d_i3d_rgb": 2048,
"d_sam": 512,
"d_audio": 128,
"top_k": 10,
"num_nn": 4,
"gnn_type": "appnp",
"use_random_graphs": false,
"integrate_all_gnn_features": true,
"use_elbo_local": true,
"use_elbo_global": true,
"use_non_linear": true,
"gnn_K": 2,
"gnn_alpha": 0.1,
"num_modalities": 6,
"local_gnn_d_hidden": 1024,
"global_gnn_d_hidden": 1024,
"num_local_gnn_heads": 2,
"num_global_gnn_heads": 4,
"local_gnn_dropout": 0.1,
"global_gnn_dropout": 0.1,
"local_fc_dropout": 0.1,
"global_fc_dropout": 0.1,
"num_local_gnn_layers": 1,
"num_global_gnn_layers": 1,
"num_local_fc_layers": 1,
"num_global_fc_layers": 1,
"use_local_gnn_bn": true,
"use_global_gnn_bn": true,
"use_local_fc_bn": true,
"use_global_fc_bn": true,
"local_gnn_concat": true,
"global_gnn_concat": true,
"num_local_gr_learner_heads": 8,
"num_global_gr_learner_heads": 8,
"init_adj_ratio": 0.5,
"adj_ratio": 0.5,
"alpha": 0.9,
"gnns_every": 4,
"num_layers_state_fc_decoder": 2,
"dropout_state_fc_decoder": 0.3
}

94
config/mst_mixer.conf Normal file
View file

@ -0,0 +1,94 @@
mixer {
task = avsd
#################################################################################
# datasets
# avsd
avsd_processed = features/
avsd_train = raw_data/train_set4DSTC7-AVSD.json
avsd_val = raw_data/valid_set4DSTC7-AVSD.json
avsd_test_dstc7 = raw_data/test_set4DSTC7-AVSD.json
avsd_test_dstc8 = raw_data/test_set4DSTC8-AVSD.json
avsd_test_dstc10 = raw_data/test_set4DSTC10-AVSD.json
avsd_feature_path = features/
avsd_i3d_rgb = features/i3d_rgb
avsd_i3d_rgb_test = features/i3d_rgb_testset
avsd_i3d_flow = features/i3d_flow_all
avsd_i3d_flow_test = features/i3d_flow_testset
avsd_audio = features/vggish_all
avsd_audio_test = features/vggish_testset
avsd_objects = features/sam
avsd_objects_test = features/sam_testset
dstc = 7
# NextQA
nextqa_root = processed/next_qa/annotations
nextqa_vid_feat = processed/next_qa/vid_feat
#################################################################################
# Model
bart_size = large # base, large
avsd_bart_base_config = config/avsd_bart_base.json
avsd_bart_large_config = config/avsd_bart_large.json
nextqa_bart_large_config = config/nextqa_bart_large.json
#################################################################################
# Logging & Checkpointing
log_dir = logs
output_dir_dstc7 = output/dstc7
output_dir_dstc8 = output/dstc8
output_dir_dstc10 = output/dstc10
output_dir_nextqa = output/nextqa
max_ckpt_to_keep = 5
start_ckpt_for_generating = none
loads_start_path = false
next_logging_pct = 0.1
save_ckpt=true
skip_saving_ckpt = false
stop_epochs = -1
resets_min_val_loss = false
restarts = false
uses_new_optimizer = true
sets_new_lr = false
################################################################################
# Data processing
expand_rnd = false
cap_sum = cap_sum
add_state_tokens = true
bart_max_input_len = 1024
num_workers = 0
n_history = 3
caption_drop_rate = 0.0
vis_feat_length = 36
#################################################################################
# Training
dp_type = ddp
batch_size = 16
num_epochs = 12
warmup_ratio = 0.1
batch_multiply = 1
skip_eval = false
stop_epoch = -1
random_seed = 54
learning_rate_bart = 1e-5
learning_rate_other = 1e-4
min_lr = 0
clip_grad_value = 1.0
print_output = false
eval_first = false
overfit_size = -1
elbo_global_coeff = 100
elbo_local_coeff = 100
gen_coeff = 1
#################################################################################
# Generation
gen_batch_size = 1
beam_depth = 5
max_generation_length = 20
min_generation_length = 1
length_penalty = 0.3
#################################################################################
# Misc.
master_port = 5101
use_cpu = false
}
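The scripts read this HOCON file with ```pyhocon```; a minimal sketch of parsing the ```mixer``` block (mirroring the loading code in ```custom_datasets/avsd.py```):
```python
import pyhocon

# Parse the HOCON config and select the 'mixer' block used throughout the code.
config = pyhocon.ConfigFactory.parse_file("config/mst_mixer.conf")["mixer"]
print(config["task"], config["bart_size"], config["batch_size"])
```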

116
config/nextqa_bart_large.json Normal file
View file

@ -0,0 +1,116 @@
{
"activation_dropout": 0.1,
"activation_function": "gelu",
"add_bias_logits": false,
"add_final_layer_norm": false,
"architectures": [
"BartModel"
],
"attention_dropout": 0.1,
"bos_token_id": 0,
"classif_dropout": 0.1,
"classifier_dropout": 0.0,
"d_model": 1024,
"decoder_attention_heads": 16,
"decoder_ffn_dim": 4096,
"decoder_layerdrop": 0.0,
"decoder_layers": 12,
"decoder_start_token_id": 2,
"dropout": 0.1,
"early_stopping": true,
"encoder_attention_heads": 16,
"encoder_ffn_dim": 4096,
"encoder_layerdrop": 0.0,
"encoder_layers": 12,
"eos_token_id": 2,
"forced_eos_token_id": 2,
"forced_bos_token_id": 0,
"gradient_checkpointing": false,
"id2label": {
"0": "LABEL_0",
"1": "LABEL_1",
"2": "LABEL_2"
},
"init_std": 0.02,
"is_encoder_decoder": true,
"label2id": {
"LABEL_0": 0,
"LABEL_1": 1,
"LABEL_2": 2
},
"max_position_embeddings": 1024,
"model_type": "bart",
"no_repeat_ngram_size": 3,
"normalize_before": false,
"num_beams": 4,
"num_hidden_layers": 12,
"pad_token_id": 1,
"scale_embedding": false,
"task_specific_params": {
"summarization": {
"length_penalty": 1.0,
"max_length": 128,
"min_length": 12,
"num_beams": 4
},
"summarization_cnn": {
"length_penalty": 2.0,
"max_length": 142,
"min_length": 56,
"num_beams": 4
},
"summarization_xsum": {
"length_penalty": 1.0,
"max_length": 62,
"min_length": 11,
"num_beams": 6
}
},
"transformers_version": "4.7.0.dev0",
"use_cache": true,
"vocab_size": 50265,
"d_i3d_flow": 2048,
"d_i3d_rgb": 2048,
"d_sam": 512,
"d_audio": 128,
"top_k": 10,
"num_nn": 4,
"gnn_type": "appnp",
"use_random_graphs": false,
"integrate_all_gnn_features": true,
"use_elbo_local": true,
"use_elbo_global": true,
"use_non_linear": true,
"gnn_K": 2,
"gnn_alpha": 0.1,
"num_modalities": 3,
"local_gnn_d_hidden": 1024,
"global_gnn_d_hidden": 1024,
"num_local_gnn_heads": 2,
"num_global_gnn_heads": 4,
"local_gnn_dropout": 0.1,
"global_gnn_dropout": 0.1,
"local_fc_dropout": 0.1,
"global_fc_dropout": 0.1,
"num_local_gnn_layers": 1,
"num_global_gnn_layers": 1,
"num_local_fc_layers": 1,
"num_global_fc_layers": 1,
"use_local_gnn_bn": true,
"use_global_gnn_bn": true,
"use_local_fc_bn": true,
"use_global_fc_bn": true,
"local_gnn_concat": true,
"global_gnn_concat": true,
"num_local_gr_learner_heads": 8,
"num_global_gr_learner_heads": 8,
"init_adj_ratio": 0.5,
"adj_ratio": 0.5,
"alpha": 0.9,
"gnns_every": 4,
"num_layers_state_fc_decoder": 2,
"dropout_state_fc_decoder": 0.3
}

20
custom_datasets/README.md Normal file
View file

@ -0,0 +1,20 @@
1. Download the raw [Charades train/val](https://prior.allenai.org/projects/charades) data
2. Download the raw [Charades test](https://ai2-public-datasets.s3-us-west-2.amazonaws.com/charades/Charades_vu17_test_480.tar) data
3. Install [SAM](https://github.com/facebookresearch/segment-anything.git)
4. Segment the frames
```shell
python segment.py --sam_ckpt path_to_sam_ckpt --avsd_root path_to_charades_trval_frames --crop_root path_to_save_the_trval_crops --mode segment --start start_idx --end end_idx
python segment.py --sam_ckpt path_to_sam_ckpt --avsd_root path_to_charades_test_frames --crop_root path_to_save_the_test_crops --mode segment --start start_idx --end end_idx
```
5. Embed the crops (a quick shape check of the resulting embeddings is sketched after this list)
```shell
python segment.py --sam_ckpt path_to_sam_ckpt --crop_root path_to_save_the_trval_crops --mode embed --embed_root ../features/sam --start start_idx --end end_idx
python segment.py --sam_ckpt path_to_sam_ckpt --crop_root path_to_save_the_test_crops --mode embed --embed_root ../features/sam_testset --start start_idx --end end_idx
```
6. Preprocess and log the data
```shell
python dataset.py --split train
python dataset.py --split val
```
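To sanity-check the SAM object embeddings produced in step 5, you can load one of the saved ```.npy``` files (a minimal sketch; the video id is a hypothetical placeholder). Each video yields an ```N x 512``` array with ```N <= 36``` objects, since the 256-d SAM embeddings of a crop and its horizontal flip are concatenated:
```python
import numpy as np

# Hypothetical video id: replace with any Charades clip processed in step 5.
emb = np.load("../features/sam/ABCDE.npy")
print(emb.shape)  # expected: (num_objects <= 36, 512)
```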


401
custom_datasets/avsd.py Normal file
View file

@ -0,0 +1,401 @@
import os
import pickle
import pyhocon
from copy import deepcopy
import json
from tqdm import tqdm
import numpy as np
import torch
from argparse import ArgumentParser
from torch.utils.data import Dataset, DataLoader
from transformers import BartTokenizer
from itertools import chain
ADDITIONAL_SPECIAL_TOKENS = [
'<place_holder>', '<s0>', '<s1>', '<s2>', '<s3>', '<s4>', '<s5>']
SPECIAL_TOKENS_DICT = {
'bos_token': '<s>',
'eos_token': '</s>',
'pad_token': '<pad>',
'additional_special_tokens': ADDITIONAL_SPECIAL_TOKENS
}
S0_TOK = '<s0>' # I3D_flow
S1_TOK = '<s1>' # I3D_rgb
S2_TOK = '<s2>' # sam obj
S3_TOK = '<s3>' # audio
S4_TOK = '<s4>' # history
S5_TOK = '<s5>' # question
def tokenize(obj, tokenizer):
if isinstance(obj, str):
return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
if isinstance(obj, dict):
return dict((n, tokenize(o, tokenizer)) for n, o in obj.items())
return list(tokenize(o, tokenizer) for o in obj)
class AVSDDataset(Dataset):
def __init__(self, config, split):
super().__init__()
self.config = config
self.split = split
self.bart_max_input_len = config['bart_max_input_len']
self.bart_size = config['bart_size']
self.cap_sum = config['cap_sum']
self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-{}'.format(self.bart_size))
self.vocab_size = self.tokenizer.vocab_size
self.tokenizer.add_special_tokens({'additional_special_tokens': ADDITIONAL_SPECIAL_TOKENS})
self.vocab_size += len(ADDITIONAL_SPECIAL_TOKENS)
self.tokenizer.save_pretrained(os.path.join(self.config['log_dir'], 'bart_tokenizer'))
self.processed_dir = os.path.join(self.config['avsd_processed'], 'hist_with_{}_rounds'.format(self.config['n_history']), split)
self.paths = list(map(lambda p: os.path.join(self.processed_dir, p), os.listdir(self.processed_dir)))
if self.config['overfit_size'] > 0:
self.paths = self.paths[:self.config['overfit_size']]
def __len__(self):
return len(self.paths)
def __getitem__(self, index):
pth = self.paths[index]
with open(pth, 'rb') as f:
item = pickle.load(f)
question_sep = self.tokenizer.convert_tokens_to_ids('<s5>')
input_ids = item['input_ids']
history_end = (input_ids == question_sep).nonzero(as_tuple=True)[0]
history_interval = [0, history_end.item()] # The last token is the question state token (not part of the history)
question_interval = [history_end.item(), input_ids.size(0)]
lm_labels = item['lm_labels']
i3d_rgb = item['i3d_rgb']
i3d_flow = item['i3d_flow']
sam = item['sam']
vgg = item['vgg']
vid = item['vid']
return input_ids, lm_labels, history_interval, question_interval, i3d_rgb, i3d_flow, sam, vgg, vid
def padding(self, seq, pad_token, max_len=None):
if max_len is None:
max_len = max([i.size(0) for i in seq])
if len(seq[0].size()) == 1:
result = torch.ones((len(seq), max_len)).long() * pad_token
else:
result = torch.ones((len(seq), max_len, seq[0].size(-1))).float()
for i in range(len(seq)):
result[i, :seq[i].size(0)] = seq[i]
return result
def collate_fn(self, batch):
input_ids_list, lm_labels_list, history_interval_list, question_interval_list, i3d_rgb_list, i3d_flow_list, sam_list, vggish_list, vid_ids_list = [], [], [], [], [], [], [], [], []
for i in batch:
input_ids_list.append(i[0])
lm_labels_list.append(i[1])
history_interval_list.append(i[2])
question_interval_list.append(i[3])
i3d_rgb_list.append(i[4])
i3d_flow_list.append(i[5])
sam_list.append(i[6])
vggish_list.append(i[7])
vid_ids_list.append(i[8])
history_intervals = np.array(history_interval_list)
question_intervals = np.array(question_interval_list)
min_len_i3d_flow = min([feat.shape[0] for feat in i3d_flow_list])
min_len_i3d_rgb = min([feat.shape[0] for feat in i3d_rgb_list])
min_len_sam = min([feat.shape[0] for feat in sam_list])
min_len_vggish = min([feat.shape[0] for feat in vggish_list])
min_length = min([self.config['vis_feat_length'], min_len_i3d_flow, min_len_i3d_rgb, min_len_sam, min_len_vggish])
# Sample equally-distant features from the visual features for each sample within the batch
for i in range(len(i3d_rgb_list)):
sample_idx_i3d_rgb = np.round(np.linspace(0, i3d_rgb_list[i].shape[0] - 1, min_length)).astype(int)
i3d_rgb_list[i] = i3d_rgb_list[i][sample_idx_i3d_rgb, :]
i3d_rgb = torch.from_numpy(np.array(i3d_rgb_list)).float()
for i in range(len(i3d_flow_list)):
sample_idx_i3d_flow = np.round(np.linspace(0, i3d_flow_list[i].shape[0] - 1, min_length)).astype(int)
i3d_flow_list[i] = i3d_flow_list[i][sample_idx_i3d_flow, :]
i3d_flow = torch.from_numpy(np.array(i3d_flow_list)).float()
for i in range(len(sam_list)):
sample_idx_sam = np.round(np.linspace(0, sam_list[i].shape[0] - 1, min_length)).astype(int)
sam_list[i] = sam_list[i][sample_idx_sam, :]
sam = torch.from_numpy(np.array(sam_list)).float()
for i in range(len(vggish_list)):
sample_idx_vggish = np.round(np.linspace(0, vggish_list[i].shape[0] - 1, min_length)).astype(int)
vggish_list[i] = vggish_list[i][sample_idx_vggish, :]
vggish = torch.from_numpy(np.array(vggish_list)).float()
pad_token, i3d_flow_sep, i3d_rgb_sep, sam_sep, audio_sep, ph_token = self.tokenizer.convert_tokens_to_ids(
['<pad>', '<s0>', '<s1>', '<s2>', '<s3>', '<place_holder>'])
# None of the visual features are masked because we do not pad them
video_mask = torch.ones((len(batch), min_length*4 + 4)) == 1 # NOTE *4: 4 modalities | +4: the state tokens
# Now we create a dummy input for the video tokens (its sole purpose is to reserve the spots of the separators)
dummy = torch.ones((len(batch), min_length)) * ph_token
video_place_holder_ids = torch.cat(
[torch.ones((len(batch), 1)) * i3d_rgb_sep, dummy,
torch.ones((len(batch), 1)) * i3d_flow_sep, dummy,
torch.ones((len(batch), 1)) * sam_sep, dummy,
torch.ones((len(batch), 1)) * audio_sep, dummy,
], dim=-1).long()
input_ids = self.padding(input_ids_list, pad_token)
lm_labels = self.padding(lm_labels_list, -100)
text_mask = input_ids != pad_token
input_mask = torch.cat([video_mask, text_mask], dim=1)
# Now we get the intervals of the visual input tokens
# These intervals do not change across the batch dimension
i3d_rgb_interval = [0, min_length + 1] # the last token is not part of this modality
i3d_flow_interval = [min_length + 1, 2 * min_length + 2]
sam_interval = [2 * min_length + 2, 3 * min_length + 3]
audio_interval = [3 * min_length + 3, 4 * min_length + 4]
vis_state_vector_idx = [i3d_rgb_interval[0], i3d_flow_interval[0], sam_interval[0], audio_interval[0]]
# adapt the question and history interval -- shifted to the right by the visual input length
history_intervals += 4 * min_length + 4
question_intervals += 4 * min_length + 4
history_intervals = history_intervals.tolist()
question_intervals = question_intervals.tolist()
history_state_vector_idx = [x[0] + 1 for x in history_intervals] # +1 because the history starts with <s><s4> .....
question_state_vector_idx = [x[0] for x in question_intervals] # no +1 needed: the question segment starts directly with its state token <s5>
batch = {
'input_ids': input_ids,
'video_place_holder_ids': video_place_holder_ids,
'i3d_rgb': i3d_rgb,
'i3d_flow': i3d_flow,
'sam': sam,
'vggish': vggish,
'lm_labels': lm_labels,
'input_mask': input_mask,
'i3d_rgb_interval': i3d_rgb_interval,
'i3d_flow_interval': i3d_flow_interval,
'sam_interval': sam_interval,
'audio_interval': audio_interval,
'history_intervals': history_intervals,
'question_intervals': question_intervals,
'vis_state_vector_idx': vis_state_vector_idx,
'history_state_vector_idx': history_state_vector_idx,
'question_state_vector_idx': question_state_vector_idx
}
return batch
def get_dataset(config, split, tokenizer):
if split != 'test':
dialog_pth = config[f'avsd_{split}']
else:
dialog_pth = config['avsd_test_dstc{}'.format(config['dstc'])]
n_history = config['n_history']
dialog_data = json.load(open(dialog_pth, 'r'))
dialog_list = []
vid_set = set()
undisclosed_only = split == 'test'
pbar = tqdm(dialog_data['dialogs'])
pbar.set_description('[INFO] Generating {} items | DSTC {}'.format(split, config['dstc']))
for dialog in pbar:
if config['dstc'] != 10:
caption = [tokenize(dialog['caption'], tokenizer)] + [tokenize(dialog['summary'], tokenizer)]
else:
caption = [tokenize('no', tokenizer)]
questions = [tokenize(d['question'], tokenizer) for d in dialog['dialog']]
answers = [tokenize(d['answer'], tokenizer) for d in dialog['dialog']]
vid = dialog["image_id"]
vid_set.add(vid)
if undisclosed_only:
it = range(len(questions) - 1, len(questions))
else:
it = range(len(questions))
qalist=[]
history = []
if undisclosed_only:
for n in range(len(questions)-1):
qalist.append(questions[n])
qalist.append(answers[n])
history=qalist[max(-len(qalist),-n_history*2):]
for n in it:
if undisclosed_only:
assert dialog['dialog'][n]['answer'] == '__UNDISCLOSED__'
question = questions[n]
answer = answers[n]
history.append(question)
if n_history == 0:
item = {'vid': vid, 'history': [question], 'answer': answer, 'caption': caption}
else:
item = {'vid': vid, 'history': history, 'answer': answer, 'caption': caption}
dialog_list.append(item)
qalist.append(question)
qalist.append(answer)
history=qalist[max(-len(qalist),-n_history*2):]
all_features = {}
fea_types = ['vggish', 'i3d_flow', 'i3d_rgb', 'sam']
dataname = '<FeaType>/<ImageID>.npy'
for ftype in fea_types:
if undisclosed_only:
basename = dataname.replace('<FeaType>', ftype+'_testset')
else:
basename = dataname.replace('<FeaType>', ftype)
features = {}
for vid in vid_set:
filename = basename.replace('<ImageID>', vid)
filepath = config['avsd_feature_path'] + filename
features[vid] = filepath
all_features[ftype] = features
return dialog_list, all_features
def build_input_from_segments(caption, history_orig, reply, tokenizer, add_state_tokens=True, drop_caption=False):
""" Build a sequence of input from 3 segments: caption(caption+summary) history and last reply """
bos, eos, hist_state, ques_state = tokenizer.convert_tokens_to_ids(['<s>', '</s>', '<s4>', '<s5>'])
sep = eos
instance = {}
instance["lm_labels"] = reply + [eos]
caption = list(chain(*caption))
# Add state tokens if applicable
if add_state_tokens:
caption.insert(0, hist_state)
history = deepcopy(history_orig)
history[-1].insert(0, ques_state)
else:
history = history_orig
if not drop_caption:
# sequence = [[bos] + list(chain(*caption))] + history + [reply + ([eos] if with_eos else [])]
# NOTE It is important not to include the reply in the input of the encoder -- > the decoder will just
# learn to copy it --> low train/val loss but no learning is happening
sequence = [[bos] + caption + [eos]] + [[sep] + s for s in history] + [[eos]]
else:
sequence = [[bos]] + [[hist_state]] + [[sep] + s for s in history] + [[eos]]
instance["input_ids"] = list(chain(*sequence))
return instance
def parse_args():
parser = ArgumentParser(description='debug dataloader')
parser.add_argument(
'--split',
type=str,
default='train',
help='train or val')
parser.add_argument(
'--model',
type=str,
default='mixer',
help='model name to train or test')
parser.add_argument(
'--log_dataset',
action='store_true',
default=False,
help='Whether or not to log the processed data')
parser.add_argument(
'--add_state_tokens',
action='store_true',
default=True,
help='Whether or not to add state tokens')
parser.add_argument(
'--log_dir',
type=str,
default='processed/avsd',
help='Output directory')
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
split = args.split
config = pyhocon.ConfigFactory.parse_file(
'config/mst_mixer.conf')[args.model]
config['expand_rnd'] = False
config['debugging'] = False
config['overfit'] = False
args.log_dir = os.path.join(args.log_dir, 'hist_with_{}_rounds'.format(config['n_history']) )
if args.log_dataset:
log_dir = os.path.join(args.log_dir, split)
if not os.path.isdir(log_dir):
os.makedirs(log_dir)
tokenizer = BartTokenizer.from_pretrained('facebook/bart-{}'.format(config['bart_size']))
tokenizer.add_special_tokens({'additional_special_tokens': ADDITIONAL_SPECIAL_TOKENS})
dialogs, features = get_dataset(config, split, tokenizer)
pbar = tqdm(dialogs)
pbar.set_description('[{}] Logging processed data'.format(split))
counter = 0
for dialog in pbar:
vid = dialog['vid']
his = dialog['history']
cap = dialog['caption']
ans = dialog['answer']
if np.random.rand() < config['caption_drop_rate']:
instance = build_input_from_segments(
cap, his, ans, tokenizer, add_state_tokens=args.add_state_tokens, drop_caption=True)
else:
instance = build_input_from_segments(
cap, his, ans, tokenizer, add_state_tokens=args.add_state_tokens, drop_caption=False)
input_ids = torch.Tensor(instance["input_ids"]).long()
lm_labels = torch.Tensor(instance["lm_labels"]).long()
vgg = np.load(features["vggish"][vid])
i3d_flow = np.load(features["i3d_flow"][vid])
i3d_rgb = np.load(features["i3d_rgb"][vid])
sam = np.load(features["sam"][vid])
item = {
'input_ids': input_ids,
'lm_labels': lm_labels,
'i3d_rgb': i3d_rgb,
'i3d_flow': i3d_flow,
'sam': sam,
'vgg': vgg,
'vid': vid
}
counter += 1
pth = os.path.join(log_dir, str(counter) + '.pkl')
with open(pth, 'wb') as f:
pickle.dump(item, f, protocol=pickle.HIGHEST_PROTOCOL)
else:
avsd_dataset = AVSDDataset(config, 'val')
avsd_dataloader = DataLoader(avsd_dataset, batch_size=4, shuffle=False, collate_fn=avsd_dataset.collate_fn)
for i, data in enumerate(avsd_dataloader):
print('{}/{}'.format(i, len(avsd_dataloader)))
print('[INFO] Done...')

211
custom_datasets/nextqa.py Normal file
View file

@ -0,0 +1,211 @@
import os
import pandas as pd
import h5py
import json
import numpy as np
import torch
from torch.utils.data import Dataset
from transformers import BartTokenizer
from itertools import chain
ADDITIONAL_SPECIAL_TOKENS = [
'<place_holder>', '<s0>', '<s1>', '<s2>', '<s3>', '<s4>', '<s5>']
SPECIAL_TOKENS_DICT = {
'bos_token': '<s>',
'eos_token': '</s>',
'pad_token': '<pad>',
'additional_special_tokens': ADDITIONAL_SPECIAL_TOKENS
}
S0_TOK = '<s0>' # frame
S1_TOK = '<s1>' # mot
S2_TOK = '<s2>' # question
def load_file(file_name):
annos = None
if os.path.splitext(file_name)[-1] == '.csv':
return pd.read_csv(file_name)
with open(file_name, 'r') as fp:
if os.path.splitext(file_name)[1]== '.txt':
annos = fp.readlines()
annos = [line.rstrip() for line in annos]
if os.path.splitext(file_name)[1] == '.json':
annos = json.load(fp)
return annos
def tokenize(obj, tokenizer):
if isinstance(obj, str):
return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
if isinstance(obj, dict):
return dict((n, tokenize(o, tokenizer)) for n, o in obj.items())
return list(tokenize(o, tokenizer) for o in obj)
class NextQADataset(Dataset):
def __init__(self, config, split):
super().__init__()
self.config = config
self.split = split
self.bart_max_input_len = config['bart_max_input_len']
self.bart_size = config['bart_size']
self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-{}'.format(self.bart_size))
self.vocab_size = self.tokenizer.vocab_size
self.tokenizer.add_special_tokens({'additional_special_tokens': ADDITIONAL_SPECIAL_TOKENS})
self.vocab_size += len(ADDITIONAL_SPECIAL_TOKENS)
self.tokenizer.save_pretrained(os.path.join(self.config['log_dir'], 'bart_tokenizer'))
sample_list_file = os.path.join(self.config['nextqa_root'], '{}.csv'.format(split))
self.sample_list = load_file(sample_list_file)
vid_feat_file = os.path.join(self.config['nextqa_vid_feat'], 'app_mot_{}.h5'.format(split))
print('Load {}...'.format(vid_feat_file))
self.frame_feats = {}
self.mot_feats = {}
with h5py.File(vid_feat_file, 'r') as fp:
vids = fp['ids']
feats = fp['feat']
for vid, feat in zip(vids, feats):
self.frame_feats[str(vid)] = feat[:, :2048] # (16, 2048)
self.mot_feats[str(vid)] = feat[:, 2048:] # (16, 2048)
if self.config['overfit_size'] > 0:
self.sample_list = self.sample_list[:self.config['overfit_size']]
def __len__(self):
return len(self.sample_list)
def get_video_feature(self, video_name):
"""
:param video_name:
:return:
"""
app_feat = self.frame_feats[video_name]
app_feat = torch.from_numpy(app_feat).type(torch.float32)
mot_feat = self.mot_feats[video_name]
mot_feat = torch.from_numpy(mot_feat).type(torch.float32)
return app_feat, mot_feat
def __getitem__(self, idx):
cur_sample = self.sample_list.loc[idx]
video_name, ques, ans, qid = str(cur_sample['video']), str(cur_sample['question']),\
str(cur_sample['answer']), str(cur_sample['qid'])
input_ids = tokenize(ques, self.tokenizer)
lm_labels = tokenize(ans, self.tokenizer)
app_feat, mot_feat = self.get_video_feature(video_name)
bos, eos, ques_state = self.tokenizer.convert_tokens_to_ids(['<s>', '</s>', '<s2>'])
# Add state tokens
input_ids.insert(0, ques_state)
lm_labels.append(eos)
question_interval = [0, len(input_ids)]
input_ids = torch.Tensor(input_ids).long()
lm_labels = torch.Tensor(lm_labels).long()
return input_ids, lm_labels, app_feat, mot_feat, question_interval, video_name
def padding(self, seq, pad_token, max_len=None):
if max_len is None:
max_len = max([i.size(0) for i in seq])
if len(seq[0].size()) == 1:
result = torch.ones((len(seq), max_len)).long() * pad_token
else:
result = torch.ones((len(seq), max_len, seq[0].size(-1))).float()
for i in range(len(seq)):
result[i, :seq[i].size(0)] = seq[i]
return result
def collate_fn(self, batch):
input_ids_list, lm_labels_list, app_feat_list, mot_feat_list, question_interval_list, vid_ids_list = [], [], [], [], [], []
for i in batch:
input_ids_list.append(i[0])
lm_labels_list.append(i[1])
app_feat_list.append(i[2])
mot_feat_list.append(i[3])
question_interval_list.append(i[4])
vid_ids_list.append(i[5])
app_feats = torch.stack(app_feat_list, dim=0).float()
mot_feats = torch.stack(mot_feat_list, dim=0).float()
question_intervals = np.array(question_interval_list)
pad_token, app_sep, mot_sep, ph_token = self.tokenizer.convert_tokens_to_ids(
['<pad>', '<s0>', '<s1>', '<place_holder>'])
# None of the visual features are masked because we do not pad them
video_mask = torch.ones((len(batch), 16*2 + 2)) == 1 # NOTE *2: 2 modalities | +2: the state tokens | each modality has length 16
# Now we create a dummy input for the video tokens (its sole purpose is to reserve the spots of the separators)
dummy = torch.ones((len(batch), 16)) * ph_token
video_place_holder_ids = torch.cat(
[torch.ones((len(batch), 1)) * app_sep, dummy,
torch.ones((len(batch), 1)) * mot_sep, dummy,
], dim=-1).long()
input_ids = self.padding(input_ids_list, pad_token)
lm_labels = self.padding(lm_labels_list, -100)
text_mask = input_ids != pad_token
input_mask = torch.cat([video_mask, text_mask], dim=1)
# Now we get the intervals of the visual input tokens
# These intervals do not change across the batch dimension
app_interval = [0, 16 + 1] # the last token is not part of this modality
mot_interval = [16 + 1, 2 * 16 + 2]
vis_state_vector_idx = [app_interval[0], mot_interval[0]]
# adapt the question and history interval -- shifted to the right by the visual input length
question_intervals += 2 * 16 + 2
question_intervals = question_intervals.tolist()
question_state_vector_idx = [x[0] for x in question_intervals]
batch = {
'input_ids': input_ids,
'video_place_holder_ids': video_place_holder_ids,
'app_feats': app_feats,
'mot_feats': mot_feats,
'lm_labels': lm_labels,
'input_mask': input_mask,
'app_interval': app_interval,
'mot_interval': mot_interval,
'question_intervals': question_intervals,
'vis_state_vector_idx': vis_state_vector_idx,
'question_state_vector_idx': question_state_vector_idx
}
return batch
def get_dataset(config, split):
bart_max_input_len = config['bart_max_input_len']
bart_size = config['bart_size']
sample_list_file = os.path.join(config['nextqa_root'], '{}.csv'.format(split))
sample_list = load_file(sample_list_file)
vid_feat_file = os.path.join(config['nextqa_vid_feat'], 'app_mot_{}.h5'.format(split))
print('Load {}...'.format(vid_feat_file))
app_feats = {}
mot_feats = {}
with h5py.File(vid_feat_file, 'r') as fp:
vids = fp['ids']
feats = fp['feat']
for vid, feat in zip(vids, feats):
app_feats[str(vid)] = feat[:, :2048] # (16, 2048)
mot_feats[str(vid)] = feat[:, 2048:] # (16, 2048)
return sample_list, app_feats, mot_feats
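For reference, a minimal sketch of the feature-file layout that NextQADataset and get_dataset expect (paths as in config/mst_mixer.conf, shapes as documented in the comments above):
```python
import h5py

# Each clip comes with 16 frame-level vectors of size 4096:
# the first 2048 dims are appearance features, the last 2048 dims are motion features.
with h5py.File("processed/next_qa/vid_feat/app_mot_val.h5", "r") as fp:
    vids = fp["ids"][:]
    feats = fp["feat"][:]

app = feats[0][:, :2048]  # appearance stream, shape (16, 2048)
mot = feats[0][:, 2048:]  # motion stream, shape (16, 2048)
print(len(vids), feats.shape, app.shape, mot.shape)
```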

179
custom_datasets/segment.py Normal file
View file

@ -0,0 +1,179 @@
from segment_anything import SamPredictor, SamAutomaticMaskGenerator, sam_model_registry
from tqdm import tqdm
from argparse import ArgumentParser
import pickle
import cv2
import os
import torch
import numpy as np
def parse_args():
parser = ArgumentParser()
parser.add_argument(
'--sam_ckpt',
type=str,
help='SAM checkpoint to be used'
)
parser.add_argument(
'--avsd_root',
type=str,
help='Directory where the individual AVSD frames are located'
)
parser.add_argument(
'--crop_root',
type=str,
help='Directory where the individual crops (objects) will be saved'
)
parser.add_argument(
'--embed_root',
type=str,
help='Directory where the individual embeddings will be saved'
)
parser.add_argument(
'--mode',
type=str,
choices=['segment', 'embed'],
help='segment: segment the image into regions | embed: embed the image crops detected during segmentation'
)
parser.add_argument(
'--start',
type=int,
default=0,
help='Start index of the partition'
)
parser.add_argument(
'--end',
type=int,
default=1968,
help='End index of the partition'
)
args = parser.parse_args()
return args
def partition_ids(avsd_ids, start, end):
avsd_ids.sort()
assert start < end
assert start >= 0 and end <= len(avsd_ids)
avsd_ids_partition = avsd_ids[start:end]
return avsd_ids_partition
def get_middle_frames(avsd_ids_partition, avsd_root):
pbar = tqdm(avsd_ids_partition)
pbar.set_description('[INFO] Preparing frames of {} videos'.format(len(avsd_ids_partition)))
path_list = []
for avsd_id in pbar:
frames = os.listdir(os.path.join(avsd_root, avsd_id))
if 'test' in avsd_root:
frames.sort(key=lambda f: int(f.split('_')[-1].split('.')[0]))
else:
frames.sort(key=lambda f: int(f.split('-')[-1].split('.')[0]))
middle_frame = frames[int(len(frames)/2)]
middle_frame = os.path.join(avsd_root, avsd_id, middle_frame)
path_list.append(middle_frame)
return path_list
def segment_images(sam, path_list, crop_root):
mask_generator = SamAutomaticMaskGenerator(sam)
pbar = tqdm(path_list)
pbar.set_description('Detecting Objects')
for pth in pbar:
vid_id = pth.split('/')[-2]
crop_dir = os.path.join(crop_root, vid_id)
if not os.path.isdir(crop_dir):
os.makedirs(crop_dir)
image = cv2.imread(pth)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
masks = mask_generator.generate(image)
masks.sort(key=lambda e: e['stability_score'], reverse=True)
if len(masks) > 36:
masks = masks[:36]
for i, mask in enumerate(masks):
crop = image[
int(mask['bbox'][1]):int(mask['bbox'][1] + mask['bbox'][3] + 1),
int(mask['bbox'][0]):int(mask['bbox'][0] + mask['bbox'][2] + 1),
:
]
crop_flipped = cv2.flip(crop, 1) # Horizontal flip
cv2.imwrite(os.path.join(crop_dir, f'obj_{i}.jpg'), crop)
cv2.imwrite(os.path.join(crop_dir, f'obj_{i}_flipped.jpg'), crop_flipped)
print('[INFO] Done...')
def embed_objects(sam, crop_ids, crop_root, embed_root):
predictor = SamPredictor(sam)
pbar = tqdm(crop_ids)
pbar.set_description('Embedding Objects')
for vid_id in pbar:
embeds = []
crop_dir = os.path.join(crop_root, vid_id)
crop_paths = list(map(lambda p: os.path.join(crop_dir, p), os.listdir(crop_dir)))
crop_paths = list(filter(lambda p: 'flipped' not in p, crop_paths))
crop_paths.sort(key=lambda p: int(p.split('_')[-1].split('.')[0]))
for cp in crop_paths:
crop = cv2.imread(cp)
crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
predictor.set_image(crop)
embed_crop = predictor.get_image_embedding()
embed_crop = embed_crop.mean(-1).mean(-1)
crop_flipped = cv2.flip(crop, 1)
predictor.set_image(crop_flipped)
embed_crop_flipped = predictor.get_image_embedding()
embed_crop_flipped = embed_crop_flipped.mean(-1).mean(-1)
embed = torch.cat((embed_crop, embed_crop_flipped), dim=-1)
# embed = embed.copy().cpu()
embeds.append(embed)
embeds = torch.cat(embeds, 0).cpu().numpy()
np.save(os.path.join(embed_root, f'{vid_id}.npy'), embeds)
print('[INFO] Done...')
def segment(args, sam):
avsd_ids = os.listdir(args.avsd_root)
avsd_ids.sort()
avsd_ids_partition = partition_ids(avsd_ids, args.start, args.end)
path_list = get_middle_frames(avsd_ids_partition, args.avsd_root)
if not os.path.isdir(args.crop_root):
os.makedirs(args.crop_root)
segment_images(sam, path_list, args.crop_root)
def embed(args, sam):
crop_ids = os.listdir(args.crop_root)
crop_ids.sort()
crop_ids_partition = partition_ids(crop_ids, args.start, args.end)
if not os.path.isdir(args.embed_root):
os.makedirs(args.embed_root)
embed_objects(sam, crop_ids_partition, args.crop_root, args.embed_root)
if __name__ == '__main__':
args = parse_args()
sam = sam_model_registry['vit_h'](
checkpoint=args.sam_ckpt)
device = 'cuda'
sam.to(device=device)
assert args.mode in ['segment', 'embed']
if args.mode == 'segment':
segment(args, sam)
else:
embed(args, sam)

0
features/.gitkeep Normal file
View file

63
generate_parallel_avsd.sh Executable file
View file

@ -0,0 +1,63 @@
export MODEL=$1
export TAG=$2
export MODE=$3
export EVAL_DIR=$4
export DSTC=$5
# >>> conda initialize >>>
# !! Contents within this block are managed by 'conda init' !!
__conda_setup="$('/opt/anaconda3/bin/conda' 'shell.bash' 'hook' 2> /dev/null)"
if [ $? -eq 0 ]; then
eval "$__conda_setup"
else
if [ -f "/opt/anaconda3/etc/profile.d/conda.sh" ]; then
. "/opt/anaconda3/etc/profile.d/conda.sh"
else
export PATH="/opt/anaconda3/bin:$PATH"
fi
fi
unset __conda_setup
# <<< conda initialize <<<
conda activate mst_mixer
if [ $DSTC -eq 10 ]; then
export CUDA_VISIBLE_DEVICES=0; python main.py --start_idx_gen 0000 --end_idx_gen 0112 --gen_subset_num 01 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=1; python main.py --start_idx_gen 0112 --end_idx_gen 0224 --gen_subset_num 02 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=2; python main.py --start_idx_gen 0224 --end_idx_gen 0336 --gen_subset_num 03 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=3; python main.py --start_idx_gen 0336 --end_idx_gen 0448 --gen_subset_num 04 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=4; python main.py --start_idx_gen 0448 --end_idx_gen 0560 --gen_subset_num 05 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=5; python main.py --start_idx_gen 0560 --end_idx_gen 0672 --gen_subset_num 06 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=6; python main.py --start_idx_gen 0672 --end_idx_gen 0784 --gen_subset_num 07 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=7; python main.py --start_idx_gen 0784 --end_idx_gen 0896 --gen_subset_num 08 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=0; python main.py --start_idx_gen 0896 --end_idx_gen 1008 --gen_subset_num 09 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=1; python main.py --start_idx_gen 1008 --end_idx_gen 1120 --gen_subset_num 10 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=2; python main.py --start_idx_gen 1120 --end_idx_gen 1232 --gen_subset_num 11 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=3; python main.py --start_idx_gen 1232 --end_idx_gen 1344 --gen_subset_num 12 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=4; python main.py --start_idx_gen 1344 --end_idx_gen 1456 --gen_subset_num 13 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=5; python main.py --start_idx_gen 1456 --end_idx_gen 1568 --gen_subset_num 14 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=6; python main.py --start_idx_gen 1568 --end_idx_gen 1680 --gen_subset_num 15 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=7; python main.py --start_idx_gen 1680 --end_idx_gen 1804 --gen_subset_num 16 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
else
export CUDA_VISIBLE_DEVICES=0; python main.py --start_idx_gen 0000 --end_idx_gen 0107 --gen_subset_num 01 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=1; python main.py --start_idx_gen 0107 --end_idx_gen 0214 --gen_subset_num 02 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=2; python main.py --start_idx_gen 0214 --end_idx_gen 0321 --gen_subset_num 03 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=3; python main.py --start_idx_gen 0321 --end_idx_gen 0428 --gen_subset_num 04 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=4; python main.py --start_idx_gen 0428 --end_idx_gen 0535 --gen_subset_num 05 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=5; python main.py --start_idx_gen 0535 --end_idx_gen 0642 --gen_subset_num 06 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=6; python main.py --start_idx_gen 0642 --end_idx_gen 0749 --gen_subset_num 07 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=7; python main.py --start_idx_gen 0749 --end_idx_gen 0856 --gen_subset_num 08 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=0; python main.py --start_idx_gen 0856 --end_idx_gen 0963 --gen_subset_num 09 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=1; python main.py --start_idx_gen 0963 --end_idx_gen 1070 --gen_subset_num 10 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=2; python main.py --start_idx_gen 1070 --end_idx_gen 1177 --gen_subset_num 11 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=3; python main.py --start_idx_gen 1177 --end_idx_gen 1284 --gen_subset_num 12 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=4; python main.py --start_idx_gen 1284 --end_idx_gen 1391 --gen_subset_num 13 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=5; python main.py --start_idx_gen 1391 --end_idx_gen 1498 --gen_subset_num 14 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=6; python main.py --start_idx_gen 1498 --end_idx_gen 1605 --gen_subset_num 15 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=7; python main.py --start_idx_gen 1605 --end_idx_gen 1710 --gen_subset_num 16 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
fi
wait
python merge_pred_avsd.py --dstc $DSTC

67
generate_parallel_nextqa.sh Executable file
View file

@ -0,0 +1,67 @@
export MODEL=$1
export TAG=$2
export MODE=$3
export EVAL_DIR=$4
export DSTC=$5
# >>> conda initialize >>>
# !! Contents within this block are managed by 'conda init' !!
__conda_setup="$('/opt/anaconda3/bin/conda' 'shell.bash' 'hook' 2> /dev/null)"
if [ $? -eq 0 ]; then
eval "$__conda_setup"
else
if [ -f "/opt/anaconda3/etc/profile.d/conda.sh" ]; then
. "/opt/anaconda3/etc/profile.d/conda.sh"
else
export PATH="/opt/anaconda3/bin:$PATH"
fi
fi
unset __conda_setup
# <<< conda initialize <<<
conda activate mst_mixer
export CUDA_VISIBLE_DEVICES=0; python main.py --start_idx_gen 0000 --end_idx_gen 0285 --gen_subset_num 01 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=0; python main.py --start_idx_gen 0285 --end_idx_gen 0570 --gen_subset_num 02 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=0; python main.py --start_idx_gen 0570 --end_idx_gen 0855 --gen_subset_num 03 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=0; python main.py --start_idx_gen 0855 --end_idx_gen 1140 --gen_subset_num 04 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=1; python main.py --start_idx_gen 1140 --end_idx_gen 1425 --gen_subset_num 05 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=1; python main.py --start_idx_gen 1425 --end_idx_gen 1710 --gen_subset_num 06 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=1; python main.py --start_idx_gen 1710 --end_idx_gen 1995 --gen_subset_num 07 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=1; python main.py --start_idx_gen 1995 --end_idx_gen 2280 --gen_subset_num 08 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=2; python main.py --start_idx_gen 2280 --end_idx_gen 2565 --gen_subset_num 09 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=2; python main.py --start_idx_gen 2565 --end_idx_gen 2850 --gen_subset_num 10 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=2; python main.py --start_idx_gen 2850 --end_idx_gen 3135 --gen_subset_num 11 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=2; python main.py --start_idx_gen 3135 --end_idx_gen 3420 --gen_subset_num 12 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=3; python main.py --start_idx_gen 3420 --end_idx_gen 3705 --gen_subset_num 13 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=3; python main.py --start_idx_gen 3705 --end_idx_gen 3990 --gen_subset_num 14 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=3; python main.py --start_idx_gen 3990 --end_idx_gen 4275 --gen_subset_num 15 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=3; python main.py --start_idx_gen 4275 --end_idx_gen 4560 --gen_subset_num 16 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=4; python main.py --start_idx_gen 4560 --end_idx_gen 4845 --gen_subset_num 17 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=4; python main.py --start_idx_gen 4845 --end_idx_gen 5130 --gen_subset_num 18 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=4; python main.py --start_idx_gen 5130 --end_idx_gen 5415 --gen_subset_num 19 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=4; python main.py --start_idx_gen 5415 --end_idx_gen 5700 --gen_subset_num 20 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=5; python main.py --start_idx_gen 5700 --end_idx_gen 5985 --gen_subset_num 21 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=5; python main.py --start_idx_gen 5985 --end_idx_gen 6270 --gen_subset_num 22 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=5; python main.py --start_idx_gen 6270 --end_idx_gen 6555 --gen_subset_num 23 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=5; python main.py --start_idx_gen 6555 --end_idx_gen 6840 --gen_subset_num 24 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=6; python main.py --start_idx_gen 6840 --end_idx_gen 7125 --gen_subset_num 25 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=6; python main.py --start_idx_gen 7125 --end_idx_gen 7410 --gen_subset_num 26 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=6; python main.py --start_idx_gen 7410 --end_idx_gen 7695 --gen_subset_num 27 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=6; python main.py --start_idx_gen 7695 --end_idx_gen 7980 --gen_subset_num 28 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=7; python main.py --start_idx_gen 7980 --end_idx_gen 8265 --gen_subset_num 29 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=7; python main.py --start_idx_gen 8265 --end_idx_gen 8550 --gen_subset_num 30 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=7; python main.py --start_idx_gen 8550 --end_idx_gen 8835 --gen_subset_num 31 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
export CUDA_VISIBLE_DEVICES=7; python main.py --start_idx_gen 8835 --end_idx_gen 9178 --gen_subset_num 32 --model $MODEL --mode $MODE --eval_dir $EVAL_DIR --tag $TAG & \
wait
python merge_pred_nextqa.py
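
For reference, the index ranges in the script above simply tile the NExT-QA test indices in contiguous chunks of 285 (the final chunk, 8835-9178, absorbs the remainder), with four chunks per visible GPU, before `merge_pred_nextqa.py` stitches the per-chunk outputs back together. Below is a minimal sketch of how such (start, end, GPU) triples can be derived; the helper itself is illustrative only and not part of the repository.

```python
# Illustrative sketch: derive the (start_idx, end_idx, gpu) triples used by the
# generation script above. Chunk size (285), number of subsets (32), GPUs (8)
# and the final index 9178 are read off the script; the helper is hypothetical.
def make_shards(num_samples=9178, num_subsets=32, num_gpus=8, chunk=285):
    shards = []
    for k in range(num_subsets):
        start = k * chunk
        # the last subset absorbs the remainder (8835 -> 9178 above)
        end = num_samples if k == num_subsets - 1 else (k + 1) * chunk
        gpu = k // (num_subsets // num_gpus)  # four consecutive subsets per GPU
        shards.append((start, end, gpu))
    return shards

if __name__ == "__main__":
    for start, end, gpu in make_shards():
        print(f"CUDA_VISIBLE_DEVICES={gpu} python main.py "
              f"--start_idx_gen {start} --end_idx_gen {end}")
```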

153
init_utils.py Normal file
View file

@ -0,0 +1,153 @@
import os
import torch
import random
import pyhocon
import datetime
import json
import subprocess
import itertools
import glob
import glog as log
import sys
import re
from os import path as osp
import numpy as np
from custom_datasets.avsd import AVSDDataset
from custom_datasets.nextqa import NextQADataset
from runners.runner_avsd import AVSDRunner
from runners.runner_nextqa import NEXTQARunner
def load_runner(config, tokenizer, vocab_size):
if config['task'] == 'avsd':
return AVSDRunner(config, tokenizer, vocab_size)
elif config['task'] == 'nextqa':
return NEXTQARunner(config, tokenizer, vocab_size)
else:
raise ValueError
def load_dataset(config):
if config['task'] == 'avsd':
dataset = AVSDDataset(config, 'train')
dataset_eval = AVSDDataset(config, 'val')
elif config['task'] == 'nextqa':
dataset = NextQADataset(config, 'train')
dataset_eval = NextQADataset(config, 'val')
else:
raise ValueError
return dataset, dataset_eval
def set_random_seed(random_seed):
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
random.seed(random_seed)
np.random.seed(random_seed)
def copy_file_to_log(log_dir):
dirs_to_cp = ['.', 'config', 'datasets', 'runners', 'models']
files_to_cp = ['*.py', '*.json', '*.sh', '*.conf']
for dir_name in dirs_to_cp:
dir_name = osp.join(log_dir, 'code', dir_name)
if not osp.exists(dir_name):
os.makedirs(dir_name)
for dir_name, file_name in itertools.product(dirs_to_cp, files_to_cp):
filename = osp.join(dir_name, file_name)
if len(glob.glob(filename)) > 0:
os.system(f'cp {filename} {osp.join(log_dir, "code", dir_name)}')
log.info(f'Files copied to {osp.join(log_dir, "code")}')
def set_log_file(fname, file_only=False):
# if fname already exists, find all existing log files under the log dir
# and give the current log file the next free number
if osp.exists(fname):
prefix, suffix = osp.splitext(fname)
log_files = glob.glob(prefix + '*' + suffix)
count = 0
for log_file in log_files:
num = re.search(r'(\d+)', log_file)
if num is not None:
num = int(num.group(0))
count = max(num, count)
fname = fname.replace(suffix, str(count + 1) + suffix)
# set log file
# Simple tricks for duplicating the logging destination within the logging module, such as
# logging.getLogger().addHandler(logging.FileHandler(filename)),
# do NOT work well here because Python traceback messages (which bypass the logging module) are not sent to the file.
# The following solution (copied from https://stackoverflow.com/questions/616645) is a bit more involved,
# but it simulates the Linux "tee" command exactly and redirects everything.
if file_only:
# Only write messages to the file; stdout/stderr receive nothing.
# This mode is intended for executing the script via ssh: ssh uses window-based flow control, i.e. if the
# controller does not read data from the ssh channel and its buffer fills up, the executing machine can no
# longer write anything into the channel and the process is put into sleeping (S) status until someone
# drains the channel. Since we do not want to have to read stdout/stderr from the controller machine,
# we simply disable output to stdout/stderr and write messages to the log file only.
log.logger.handlers[0].stream = log.handler.stream = sys.stdout = sys.stderr = f = open(fname, 'w', buffering=1)
else:
# we output messages to both file and stdout/stderr
tee = subprocess.Popen(['tee', fname], stdin=subprocess.PIPE)
os.dup2(tee.stdin.fileno(), sys.stdout.fileno())
os.dup2(tee.stdin.fileno(), sys.stderr.fileno())
def initialize_from_env(model, mode, model_type, eval_dir, tag=''):
if "GPU" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ['GPU']
if mode in ['train', 'debug']:
config = pyhocon.ConfigFactory.parse_file(f"config/{model_type}.conf")[model]
else:
path_config = os.path.join(eval_dir, 'code', f"config/{model_type}.conf")
config = pyhocon.ConfigFactory.parse_file(path_config)[model]
config['log_dir'] = eval_dir
config['model_type'] = model_type
if "CUDA_VISIBLE_DEVICES" in os.environ:
config['num_gpus'] = len(os.environ["CUDA_VISIBLE_DEVICES"].split(','))
# multi-gpu setting
if config['num_gpus'] > 1:
os.environ['MASTER_ADDR'] = 'localhost'
os.environ["MASTER_PORT"] = str(config['master_port'])
if mode == 'debug':
model += '_debug'
if tag:
model += '-' + tag
if mode != 'generate':
config["log_dir"] = os.path.join(config["log_dir"], model)
if not os.path.exists(config["log_dir"]):
os.makedirs(config["log_dir"])
config['timestamp'] = datetime.datetime.now().strftime('%m%d-%H%M%S')
# Choose the correct config file and add the BART json file to config
if mode in ['train', 'debug']:
config['bart_config'] = config['{}_bart_{}_config'.format(
config['task'], config['bart_size'])]
else:
config['bart_config'] = os.path.join(eval_dir, 'code', 'config/{}_bart_{}.json'.format(
config['task'], config['bart_size']))
config['bart_config_json'] = json.load(open(config['bart_config'], 'r'))
config['overfit'] = config['overfit_size'] > 0
return config
def set_training_steps(config, num_samples):
if config['parallel'] and config['dp_type'] == 'dp':
config['num_iter_per_epoch'] = int(np.ceil(num_samples / config['batch_size']))
else:
config['num_iter_per_epoch'] = int(np.ceil(num_samples / (config['batch_size'] * config['num_gpus'])))
if 'train_steps' not in config:
config['train_steps'] = config['num_iter_per_epoch'] * config['num_epochs']
if 'warmup_steps' not in config:
config['warmup_steps'] = int(config['train_steps'] * config['warmup_ratio'])
return config
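
As a quick sanity check of the arithmetic in `set_training_steps` above (DDP branch, where the effective batch size is `batch_size * num_gpus`), here is a small worked example; all numbers are hypothetical and chosen only for illustration.

```python
import numpy as np

# Hypothetical values, only to illustrate the arithmetic in set_training_steps.
num_samples, batch_size, num_gpus = 76590, 16, 4
num_epochs, warmup_ratio = 12, 0.1

num_iter_per_epoch = int(np.ceil(num_samples / (batch_size * num_gpus)))  # 1197
train_steps = num_iter_per_epoch * num_epochs                             # 14364
warmup_steps = int(train_steps * warmup_ratio)                            # 1436
print(num_iter_per_epoch, train_steps, warmup_steps)
```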

225
main.py Normal file
View file

@ -0,0 +1,225 @@
from init_utils import (
load_runner,
load_dataset,
set_random_seed,
set_training_steps,
initialize_from_env,
set_log_file,
copy_file_to_log
)
import os
import sys
import argparse
import pyhocon
import glog as log
import socket
import getpass
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.distributed as dist
from transformers import BartTokenizer
from custom_datasets.avsd import get_dataset as get_avsd_dataset
from custom_datasets.nextqa import get_dataset as get_nextqa_dataset
parser = argparse.ArgumentParser(description='Main script for MST-MIXER')
parser.add_argument(
'--model',
type=str,
default='mst_mixer/mixer',
help='model name to train or test')
parser.add_argument(
'--mode',
type=str,
default='train',
help='train, generate or debug'
)
parser.add_argument(
'--eval_dir',
type=str,
default='ckpt/avsd'
)
parser.add_argument(
'--wandb_project',
type=str,
default='mst_mixer'
)
parser.add_argument(
'--wandb_mode',
type=str,
default='offline',
choices=['online', 'offline', 'disabled', 'run', 'dryrun']
)
parser.add_argument(
'--tag',
type=str,
default='full_model',
help="Tag to differentiate the models"
)
parser.add_argument(
'--start_idx_gen',
type=int,
default=0,
help="The start index for generation"
)
parser.add_argument(
'--end_idx_gen',
type=int,
default=10,
help="The end index for generation"
)
parser.add_argument(
'--gen_subset_num',
type=int,
default=1,
help="The index of the test split for generation"
)
parser.add_argument('--ssh', action='store_true',
help='whether or not the command is executed via ssh. '
'If set, nothing is printed to the screen and all log messages are redirected to the log file only')
def main(gpu, config, args):
config['training'] = args.mode == 'train'
config['debugging'] = args.mode == 'debug'
config['generating'] = args.mode == 'generate'
config['wandb_project'] = args.wandb_project
config['wandb_mode'] = 'disabled'
if config['training']:
config['wandb_mode'] = args.wandb_mode
# When generating, only use 1 GPU
if config['generating']:
assert config['num_gpus'] == 1, 'When generating, only use 1 GPU!'
if config['parallel'] and config['dp_type'] != 'dp':
config['rank'] = gpu
dist.init_process_group(
backend='nccl',
# init_method='env://',
world_size=config['num_gpus'],
rank=gpu
)
config['display'] = gpu == 0
torch.cuda.set_device(gpu)
else:
config['display'] = True
if config['debugging'] or (config['parallel'] and config['dp_type'] != 'dp'):
config['num_workers'] = 0
# set logs
if config['training']:
log_file = os.path.join(config["log_dir"], f'{args.mode}.log')
set_log_file(log_file, file_only=args.ssh)
# print environment info
if config['display']:
log.info('Host: {}, user: {}, CUDA_VISIBLE_DEVICES: {}, cwd: {}'.format(
socket.gethostname(), getpass.getuser(), os.environ.get('CUDA_VISIBLE_DEVICES', ''), os.getcwd()))
log.info('Command line is: {}'.format(' '.join(sys.argv)))
if config['parallel'] and config['dp_type'] != 'dp':
log.info(f'World_size: {config["num_gpus"]}, cur rank: {config["rank"]}')
log.info(f"Running experiment: {args.model}")
log.info(f"Results saved to {config['log_dir']}")
# initialization
if config['display'] and config['training']:
copy_file_to_log(config['log_dir'])
set_random_seed(config['random_seed'])
device = torch.device(f"cuda:{gpu}")
if config["use_cpu"]:
device = torch.device("cpu")
config['device'] = device
# prepare datasets (train and validation)
dataset, dataset_eval = load_dataset(config)
# set training steps
if not config['generating'] or config['parallel']:
config = set_training_steps(config, len(dataset))
if config['display']:
log.info(pyhocon.HOCONConverter.convert(config, "hocon"))
# load runner
runner = load_runner(config, dataset.tokenizer, dataset.vocab_size)
# parallel
if config['parallel']:
if config['dp_type'] == 'ddp':
torch.cuda.set_device(gpu)
runner.model = runner.model.to(gpu)
runner.model = nn.parallel.DistributedDataParallel(
runner.model,
device_ids=[gpu],
output_device=gpu,
find_unused_parameters=True
)
else:
raise ValueError(f'Unrecognized dp_type: {config["dp_type"]}')
if config['training'] or config['debugging']:
ckpt_path = config.get('start_path', None)
runner.load_ckpt(ckpt_path=ckpt_path)
runner.train(dataset, dataset_eval)
elif config['generating']:
if config['loads_start_path']:
runner.load_ckpt(config['start_ckpt_for_generating'])
else:
runner.load_ckpt_best()
assert args.gen_subset_num > 0
# Load the data
if config['task'] == 'avsd':
# load the saved tokenizer
tokenizer = BartTokenizer.from_pretrained(os.path.join(config['log_dir'], 'bart_tokenizer'))
test_dataset, _ = get_avsd_dataset(config, 'test', tokenizer)
assert args.start_idx_gen >= 0 and args.end_idx_gen <= len(test_dataset) and args.start_idx_gen < args.end_idx_gen
test_dataset = test_dataset[args.start_idx_gen:args.end_idx_gen]
runner.generate(
test_dataset, args.tag, tokenizer, gen_subset_num=args.gen_subset_num
)
elif config['task'] == 'nextqa':
# load the saved tokenizer
tokenizer = BartTokenizer.from_pretrained(os.path.join(config['log_dir'], 'bart_tokenizer'))
test_dataset, app_feats, mot_feats = get_nextqa_dataset(config, 'test')
assert args.start_idx_gen >= 0 and args.end_idx_gen <= len(test_dataset) and args.start_idx_gen < args.end_idx_gen
test_dataset = test_dataset[args.start_idx_gen:args.end_idx_gen]
runner.generate(
test_dataset, app_feats, mot_feats, args.tag, tokenizer, args.start_idx_gen, args.end_idx_gen, gen_subset_num=args.gen_subset_num
)
else:
raise ValueError
if config['parallel']:
dist.destroy_process_group()
if __name__ == '__main__':
args = parser.parse_args()
# initialization
model_type, model_name = args.model.split('/')
config = initialize_from_env(model_name, args.mode, model_type, args.eval_dir, tag=args.tag)
if config['num_gpus'] > 1:
config['parallel'] = True
mp.spawn(main, nprocs=config['num_gpus'], args=(config, args))
else:
config['parallel'] = False
main(0, config, args)

61
merge_pred_avsd.py Normal file
View file

@ -0,0 +1,61 @@
import os
import json
import argparse
parser = argparse.ArgumentParser(description='Main script for MST-MIXER')
parser.add_argument(
'--dstc',
type=int,
default=8,
choices=[7, 8, 10],
help='DSTC challenge identifier')
args = parser.parse_args()
assert args.dstc in [7, 8, 10]
if args.dstc == 7:
output_dir = 'output/dstc7'
raw_data_path = 'raw_data/test_set4DSTC7-AVSD.json'
elif args.dstc == 8:
output_dir = 'output/dstc8'
raw_data_path = 'raw_data/test_set4DSTC8-AVSD.json'
else:
output_dir = 'output/dstc10'
raw_data_path = 'raw_data/test_set4DSTC10-AVSD.json'
with open(raw_data_path, 'r') as f:
raw_dialogs = json.load(f)['dialogs']
file_paths = os.listdir(output_dir)
file_paths = list(filter(lambda f: 'part' in f , file_paths))
name = file_paths[0]
file_paths = list(map(lambda f: os.path.join(output_dir, f), file_paths))
dialogs = {}
for pth in file_paths:
with open(pth, 'r') as f:
data = json.load(f)
for dialog in data['dialogs']:
vid_id = dialog['image_id']
dialogs[vid_id] = dialog
# dialogs.extend(data['dialogs'])
os.remove(pth)
# Now, re-establish the original order of the dialogs
res = []
for dialog in raw_dialogs:
vid_id = dialog['image_id']
res.append(dialogs[vid_id])
res = {
'dialogs': res
}
name = "".join(name.split('-')[:-1]) + '.json'
output_path = os.path.join(output_dir, name)
with open(output_path, 'w') as f:
json.dump(res, f, indent=4)
print('[INFO] Files merged and saved in {}'.format(output_path))

34
merge_pred_nextqa.py Normal file
View file

@ -0,0 +1,34 @@
import os
import json
import argparse
parser = argparse.ArgumentParser(description='Main script for MST-MIXER')
args = parser.parse_args()
output_dir = 'output/nextqa'
file_paths = os.listdir(output_dir)
file_paths = list(filter(lambda f: 'part' in f , file_paths))
name = file_paths[0]
file_paths = list(map(lambda f: os.path.join(output_dir, f), file_paths))
results = {}
for pth in file_paths:
with open(pth, 'r') as f:
data = json.load(f)
for video_id in data:
if video_id not in results:
results[video_id] = data[video_id]
else:
for qid in data[video_id]:
if qid not in results[video_id]:
results[video_id][qid] = data[video_id][qid]
os.remove(pth)
name = "".join(name.split('-')[:-1]) + '.json'
output_path = os.path.join(output_dir, name)
with open(output_path, 'w') as f:
json.dump(results, f, indent=4)
print('[INFO] Files merged and saved in {}'.format(output_path))

BIN
misc/italy.png Normal file (3.9 KiB)

Binary file not shown.

BIN
misc/mixer.png Normal file (4.9 MiB)

Binary file not shown.

BIN
misc/teaser.png Normal file (197 KiB)

Binary file not shown.

0
models/__init__.py Normal file
View file

1438
models/avsd_bart.py Normal file

File diff suppressed because it is too large Load diff

801
models/gnns.py Normal file
View file

@ -0,0 +1,801 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn.dense import DenseGATConv, DenseGCNConv, DenseSAGEConv
from torch.nn.parameter import Parameter
from typing import Optional, Tuple
from .utils import get_knn_graph
import torch_sparse
class BartAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
is_decoder: bool = False,
bias: bool = True,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads})."
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and past_key_value is not None
and past_key_value[0].shape[2] == key_value_states.shape[1]
):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.reshape(*proj_shape)
value_states = value_states.reshape(*proj_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if output_attentions:
# this operation is a bit awkward, but it's required to
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to be reshaped
# twice and have to be reused in the following
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
else:
attn_weights_reshaped = None
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.bmm(attn_probs, value_states)
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights_reshaped, past_key_value
class MLPModule(nn.Module):
def __init__(self, d_in, d_hidden, d_out, num_layers=3, dropout=0.3, use_non_linear=False, use_batch_norm=False):
super(MLPModule, self).__init__()
self.use_batch_norm = use_batch_norm
self.dropout = dropout
self.fcs = nn.ModuleList()
self.batch_norms = nn.ModuleList()
if num_layers == 1:
self.fcs.append(nn.Linear(d_in, d_out))
else:
self.fcs.append(nn.Linear(d_in, d_hidden))
self.batch_norms.append(nn.BatchNorm1d(d_hidden))
for _ in range(num_layers - 2):
self.fcs.append(nn.Linear(d_hidden, d_hidden))
self.batch_norms.append(nn.BatchNorm1d(d_hidden))
self.fcs.append(nn.Linear(d_hidden, d_out))
self.act_fn = nn.GELU()
self.dropout = nn.Dropout(dropout)
self.use_non_linear=use_non_linear
def reset_parameters(self):
for fc in self.fcs:
fc.reset_parameters()
for bn in self.batch_norms:
bn.reset_parameters()
def forward(self, X):
for fc, bn in zip(self.fcs[:-1], self.batch_norms):
X = fc(X)
X = self.act_fn(X)
if self.use_batch_norm:
if X.dim() > 2:
X = X.transpose(1, 2)
X = bn(X)
if X.dim() > 2:
X = X.transpose(1, 2)
X = self.dropout(X)
X = self.fcs[-1](X)
return X
class GATModule(nn.Module):
def __init__(self, d_in, d_hidden, d_out, num_layers=3, dropout=0.3, concat=True, heads=2, use_non_linear=False, use_batch_norm=False):
super(GATModule, self).__init__()
self.gnns = nn.ModuleList()
if concat:
d_hidden = d_hidden // heads
d_out = d_out // heads
self.gnns.append(DenseGATConv(d_in, d_hidden, heads=heads, concat=concat, dropout=dropout))
self.batch_norms = nn.ModuleList()
self.batch_norms.append(nn.BatchNorm1d(d_hidden * heads if concat else d_hidden))
for _ in range(num_layers - 2):
self.gnns.append(DenseGATConv(
d_hidden * heads if concat else d_hidden, d_hidden,
heads=heads,
concat=concat,
dropout=dropout)
)
self.batch_norms.append(nn.BatchNorm1d(d_hidden * heads if concat else d_hidden))
self.gnns.append(DenseGATConv(
d_hidden * heads if concat else d_hidden, d_out,
heads=heads,
concat=concat,
dropout=dropout)
)
self.dropout = nn.Dropout(dropout)
self.non_linear = nn.GELU()
self.use_batch_norm = use_batch_norm
self.use_non_linear = use_non_linear
def reset_parameters(self):
for gnn in self.gnns:
gnn.reset_parameters()
for batch_norm in self.batch_norms:
batch_norm.reset_parameters()
def forward(self, X, A):
Z = self.dropout(X)
for i in range(len(self.gnns) - 1):
Z = self.gnns[i](Z, A)
if self.use_batch_norm:
Z = Z.transpose(1, 2)
Z = self.batch_norms[i](Z)
Z = Z.transpose(1, 2)
if self.use_non_linear:
Z = self.non_linear(Z)
Z = self.dropout(Z)
Z = self.gnns[-1](Z, A)
return Z
class GCNModule(nn.Module):
def __init__(self, d_in, d_hidden, d_out, num_layers=3, dropout=0.3, use_non_linear=False, use_batch_norm=False):
super(GCNModule, self).__init__()
self.gnns = nn.ModuleList()
self.gnns.append(DenseGCNConv(d_in, d_hidden))
self.batch_norms = nn.ModuleList()
self.batch_norms.append(nn.BatchNorm1d(d_hidden))
for _ in range(num_layers - 2):
self.gnns.append(DenseGCNConv(
d_hidden, d_hidden)
)
self.batch_norms.append(nn.BatchNorm1d(d_hidden))
self.gnns.append(DenseGCNConv(
d_hidden, d_out)
)
self.dropout = nn.Dropout(dropout)
self.non_linear = nn.GELU()
self.use_batch_norm = use_batch_norm
self.use_non_linear = use_non_linear
def reset_parameters(self):
for gnn in self.gnns:
gnn.reset_parameters()
for batch_norm in self.batch_norms:
batch_norm.reset_parameters()
def forward(self, X, A):
Z = self.dropout(X)
for i in range(len(self.gnns) - 1):
Z = self.gnns[i](Z, A)
if self.use_batch_norm:
Z = Z.transpose(1, 2)
Z = self.batch_norms[i](Z)
Z = Z.transpose(1, 2)
if self.use_non_linear:
Z = self.non_linear(Z)
Z = self.dropout(Z)
Z = self.gnns[-1](Z, A)
return Z
class SAGEModule(nn.Module):
def __init__(self, d_in, d_hidden, d_out, num_layers=3, dropout=0.3, use_non_linear=False, use_batch_norm=False):
super(SAGEModule, self).__init__()
self.gnns = nn.ModuleList()
self.gnns.append(DenseSAGEConv(d_in, d_hidden))
self.batch_norms = nn.ModuleList()
self.batch_norms.append(nn.BatchNorm1d(d_hidden))
for _ in range(num_layers - 2):
self.gnns.append(DenseSAGEConv(
d_hidden, d_hidden)
)
self.batch_norms.append(nn.BatchNorm1d(d_hidden))
self.gnns.append(DenseSAGEConv(
d_hidden, d_out)
)
self.dropout = nn.Dropout(dropout)
self.non_linear = nn.GELU()
self.use_batch_norm = use_batch_norm
self.use_non_linear = use_non_linear
def reset_parameters(self):
for gnn in self.gnns:
gnn.reset_parameters()
for batch_norm in self.batch_norms:
batch_norm.reset_parameters()
def forward(self, X, A):
Z = self.dropout(X)
for i in range(len(self.gnns) - 1):
Z = self.gnns[i](Z, A)
if self.use_batch_norm:
Z = Z.transpose(1, 2)
Z = self.batch_norms[i](Z)
Z = Z.transpose(1, 2)
if self.use_non_linear:
Z = self.non_linear(Z)
Z = self.dropout(Z)
Z = self.gnns[-1](Z, A)
return Z
class GlobalGraphLearner(nn.Module):
def __init__(self, d_in, num_heads, random=False):
super(GlobalGraphLearner, self).__init__()
self.random = random
if not self.random:
w = torch.Tensor(num_heads, d_in)
self.w = Parameter(nn.init.xavier_uniform_(w), requires_grad=True)
def reset_parameters(self):
if not self.random:
self.w = Parameter(nn.init.xavier_uniform_(self.w))
def forward(self, Z):
if self.random:
att_global = torch.randn((Z.size(0), Z.size(1), Z.size(1))).to(Z.device)
else:
w_expanded = self.w.unsqueeze(1).unsqueeze(1)
Z = Z.unsqueeze(0) * w_expanded
Z = F.normalize(Z, p=2, dim=-1)
att_global = torch.matmul(Z, Z.transpose(-1, -2)).mean(0)
mask_global = (att_global > 0).detach().float()
att_global = att_global * mask_global
return att_global
class DenseAPPNP(nn.Module):
def __init__(self, K, alpha):
super().__init__()
self.K = K
self.alpha = alpha
def forward(self, x, adj_t):
h = x
for _ in range(self.K):
if adj_t.is_sparse:
x = torch_sparse.spmm(adj_t, x)
else:
x = torch.matmul(adj_t, x)
x = x * (1 - self.alpha)
x += self.alpha * h
x /= self.K
return x
class Dense_APPNP_Net(nn.Module):
def __init__(self, d_in, d_hidden, d_out, dropout=.5, K=10, alpha=.1):
super(Dense_APPNP_Net, self).__init__()
self.lin1 = nn.Linear(d_in, d_hidden)
self.lin2 = nn.Linear(d_hidden, d_out)
self.prop1 = DenseAPPNP(K, alpha)
self.dropout = dropout
def reset_parameters(self):
self.lin1.reset_parameters()
self.lin2.reset_parameters()
def forward(self, x, adj_t):
x = F.dropout(x, p=self.dropout, training=self.training)
x = F.relu(self.lin1(x))
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.lin2(x)
x = self.prop1(x, adj_t)
return x
class MMGraphLearner(nn.Module):
def __init__(self, d_in, num_heads, random=False):
super(MMGraphLearner, self).__init__()
self.random = random
if not self.random:
w = torch.Tensor(num_heads, d_in)
self.w = Parameter(nn.init.xavier_uniform_(w), requires_grad=True)
self.fc = nn.Linear(d_in, d_in)
def reset_parameters(self):
if not self.random:
self.fc.reset_parameters()
self.w = Parameter(nn.init.xavier_uniform_(self.w), requires_grad=True)
def forward(self, features):
if self.random:
att = torch.randn((features.size(0), features.size(1), features.size(1))).to(features.device)
else:
features = self.fc(features)
w_expanded = self.w.unsqueeze(1).unsqueeze(1)
features = features.unsqueeze(0) * w_expanded
features = F.normalize(features, p=2, dim=-1)
att = torch.matmul(features, features.transpose(-1, -2)).mean(0)
mask = (att > 0).detach().float()
att = att * mask
return att
class QNetLocal(nn.Module):
def __init__(self, config):
super(QNetLocal, self).__init__()
self.config=config
self.mm_gnn_modules = nn.ModuleList()
self.mm_graph_learners_1 = nn.ModuleList()
self.mm_graph_learners_2 = nn.ModuleList()
for _ in range(self.config.num_modalities):
if self.config.gnn_type == 'gat':
self.mm_gnn_modules.append(GATModule(
self.config.d_model, self.config.d_model, self.config.d_model,
num_layers=self.config.num_local_gnn_layers,
heads=self.config.num_local_gnn_heads,
dropout=self.config.local_gnn_dropout,
concat=self.config.local_gnn_concat,
use_batch_norm=self.config.use_local_gnn_bn,
use_non_linear=self.config.use_non_linear
)
)
elif self.config.gnn_type == 'appnp':
self.mm_gnn_modules.append(Dense_APPNP_Net(
self.config.d_model, self.config.d_model, self.config.d_model, dropout=self.config.local_gnn_dropout,
K=self.config.gnn_K, alpha=self.config.gnn_alpha
)
)
elif self.config.gnn_type == 'gcn':
self.mm_gnn_modules.append(GCNModule(
self.config.d_model, self.config.d_model, self.config.d_model,
num_layers=self.config.num_local_gnn_layers,
dropout=self.config.local_gnn_dropout,
use_batch_norm=self.config.use_local_gnn_bn,
use_non_linear=self.config.use_non_linear
)
)
elif self.config.gnn_type == 'sage':
self.mm_gnn_modules.append(SAGEModule(
self.config.d_model, self.config.d_model, self.config.d_model,
num_layers=self.config.num_local_gnn_layers,
dropout=self.config.local_gnn_dropout,
use_batch_norm=self.config.use_local_gnn_bn,
use_non_linear=self.config.use_non_linear
)
)
else:
raise ValueError
self.mm_graph_learners_1.append(MMGraphLearner(self.config.d_model, self.config.num_local_gr_learner_heads, random=self.config.use_random_graphs))
self.mm_graph_learners_2.append(MMGraphLearner(self.config.d_model * 2, self.config.num_local_gr_learner_heads, random=self.config.use_random_graphs))
def reset_parameters(self):
for i in range(self.config.num_modalities):
self.mm_gnn_modules[i].reset_parameters()
self.mm_graph_learners_1[i].reset_parameters()
self.mm_graph_learners_2[i].reset_parameters()
def forward(self, features, A_tildes=None):
mm_Xs = features
device = features[0].device
if A_tildes is None:
A_tildes = []
for mm_X in mm_Xs:
A_tildes.append(get_knn_graph(mm_X, self.config.num_nn, device))
################# Multi-modal graph learner (upper branch) #################
A_primes = []
for i, mm_X in enumerate(mm_Xs): # iterate over the modalities
A_primes.append(self.mm_graph_learners_1[i](mm_X))
# Linear combination of A_primes with A_tildes
A_primes = [(1 - self.config.init_adj_ratio) * A_prime + self.config.init_adj_ratio * A_tilde for A_prime, A_tilde in zip(A_primes, A_tildes)]
################# Multi-modal gnn (upper branch) #################
Z_primes = []
for i, (mm_X, A_prime) in enumerate(zip(mm_Xs, A_primes)):
Z_primes.append(self.mm_gnn_modules[i](mm_X, A_prime))
################# Multi-modal gnn (lower branch) #################
Z_double_primes = []
for i, (mm_X, A_tilde) in enumerate(zip(mm_Xs, A_tildes)):
Z_double_primes.append(self.mm_gnn_modules[i](mm_X, A_tilde))
Z_concats = [torch.cat([Z_1, Z_2], dim=-1) for Z_1, Z_2 in zip(Z_primes, Z_double_primes)]
################# Multi-modal graph learner (lower branch) #################
A_double_primes = []
for i, Z_concat in enumerate(Z_concats):
A_double_primes.append(self.mm_graph_learners_2[i](Z_concat))
A_double_primes = [(1 - self.config.init_adj_ratio) * A_double_prime + self.config.init_adj_ratio * A_tilde for A_double_prime, A_tilde in zip(A_double_primes, A_tildes)]
As = [(1 - self.config.adj_ratio) * A_prime + self.config.adj_ratio * A_double_prime for A_prime, A_double_prime in zip(A_primes, A_double_primes)]
################## Average across all multimodal inputs ##################
Zs = [0.5 * Z1 + 0.5 * Z2 for Z1, Z2 in zip(Z_primes, Z_double_primes)]
return As, Zs
class PNetLocal(nn.Module):
def __init__(self, config):
super(PNetLocal, self).__init__()
self.config = config
self.mm_gnn_modules = nn.ModuleList()
self.mm_mlp_modules = nn.ModuleList()
self.mm_graph_learners_1 = nn.ModuleList()
self.mm_graph_learners_2 = nn.ModuleList()
for _ in range(self.config.num_modalities):
if self.config.gnn_type == 'gat':
self.mm_gnn_modules.append(GATModule(
self.config.d_model, self.config.d_model, self.config.d_model,
num_layers=self.config.num_local_gnn_layers,
heads=self.config.num_local_gnn_heads,
dropout=self.config.local_gnn_dropout,
concat=self.config.local_gnn_concat,
use_batch_norm=self.config.use_local_gnn_bn,
use_non_linear=self.config.use_non_linear
)
)
elif self.config.gnn_type == 'appnp':
self.mm_gnn_modules.append(Dense_APPNP_Net(
self.config.d_model, self.config.d_model, self.config.d_model, dropout=self.config.local_gnn_dropout,
K=self.config.gnn_K, alpha=self.config.gnn_alpha
)
)
elif self.config.gnn_type == 'gcn':
self.mm_gnn_modules.append(GCNModule(
self.config.d_model, self.config.d_model, self.config.d_model,
num_layers=self.config.num_local_gnn_layers,
dropout=self.config.local_gnn_dropout,
use_batch_norm=self.config.use_local_gnn_bn,
use_non_linear=self.config.use_non_linear
)
)
elif self.config.gnn_type == 'sage':
self.mm_gnn_modules.append(SAGEModule(
self.config.d_model, self.config.d_model, self.config.d_model,
num_layers=self.config.num_local_gnn_layers,
dropout=self.config.local_gnn_dropout,
use_batch_norm=self.config.use_local_gnn_bn,
use_non_linear=self.config.use_non_linear
)
)
else:
raise ValueError
self.mm_mlp_modules.append(MLPModule(
self.config.d_model, self.config.d_model, self.config.d_model,
num_layers=self.config.num_local_fc_layers,
dropout=self.config.local_fc_dropout,
use_batch_norm=self.config.use_local_fc_bn,
use_non_linear=self.config.use_non_linear
))
self.mm_graph_learners_1.append(MMGraphLearner(self.config.d_model, self.config.num_local_gr_learner_heads, random=self.config.use_random_graphs))
self.mm_graph_learners_2.append(MMGraphLearner(self.config.d_model * 2, self.config.num_local_gr_learner_heads, random=self.config.use_random_graphs))
def reset_parameters(self):
for i in range(self.config.num_modalities):
self.mm_gnn_modules[i].reset_parameters()
self.mm_mlp_modules[i].reset_parameters()
self.mm_graph_learners_1[i].reset_parameters()
self.mm_graph_learners_2[i].reset_parameters()
def forward(self, features):
mm_Xs = features
################# Multi-modal graph learner (upper branch) #################
A_primes = []
for i, mm_X in enumerate(mm_Xs): # iterate over the modalities
A_primes.append(self.mm_graph_learners_1[i](mm_X))
################# Multi-modal gnn (upper branch) #################
Z_primes = []
for i, (mm_X, A_prime) in enumerate(zip(mm_Xs, A_primes)):
Z_primes.append(self.mm_gnn_modules[i](mm_X, A_prime))
################# Multi-modal gnn (lower branch) #################
Z_double_primes = []
for i, mm_X in enumerate(mm_Xs):
Z_double_primes.append(self.mm_mlp_modules[i](mm_X))
Z_concats = [torch.cat([Z_1, Z_2], dim=-1) for Z_1, Z_2 in zip(Z_primes, Z_double_primes)]
################# Multi-modal graph learner (lower branch) #################
A_double_primes = []
for i, Z_concat in enumerate(Z_concats):
A_double_primes.append(self.mm_graph_learners_2[i](Z_concat))
As = [(1 - self.config.adj_ratio) * A_prime + self.config.adj_ratio * A_double_prime for A_prime, A_double_prime in zip(A_primes, A_double_primes)]
################## Average across all multimodal inputs ##################
Zs = [0.5 * Z1 + 0.5 * Z2 for Z1, Z2 in zip(Z_primes, Z_double_primes)]
return As, Zs
class QNetGlobal(nn.Module):
def __init__(self, config):
super(QNetGlobal, self).__init__()
self.config = config
if self.config.gnn_type == 'gat':
self.gnn = GATModule(
self.config.d_model, self.config.d_model, self.config.d_model,
num_layers=self.config.num_global_gnn_layers,
heads=self.config.num_global_gnn_heads,
dropout=self.config.global_gnn_dropout,
concat=self.config.global_gnn_concat,
use_batch_norm=self.config.use_global_gnn_bn,
use_non_linear=self.config.use_non_linear
)
elif self.config.gnn_type == 'appnp':
self.gnn = Dense_APPNP_Net(
self.config.d_model, self.config.d_model, self.config.d_model, dropout=self.config.local_gnn_dropout,
K=self.config.gnn_K, alpha=self.config.gnn_alpha
)
elif self.config.gnn_type == 'gcn':
self.gnn = GCNModule(
self.config.d_model, self.config.d_model, self.config.d_model,
num_layers=self.config.num_global_gnn_layers,
dropout=self.config.global_gnn_dropout,
use_batch_norm=self.config.use_global_gnn_bn,
use_non_linear=self.config.use_non_linear
)
elif self.config.gnn_type == 'sage':
self.gnn = SAGEModule(
self.config.d_model, self.config.d_model, self.config.d_model,
num_layers=self.config.num_global_gnn_layers,
dropout=self.config.global_gnn_dropout,
use_batch_norm=self.config.use_global_gnn_bn,
use_non_linear=self.config.use_non_linear
)
else:
raise ValueError
self.graph_learner_1 = GlobalGraphLearner(self.config.d_model, self.config.num_global_gr_learner_heads, self.config.use_random_graphs)
self.graph_learner_2 = GlobalGraphLearner(self.config.d_model * 2, self.config.num_global_gr_learner_heads, self.config.use_random_graphs)
def reset_parameters(self):
self.gnn.reset_parameters()
self.graph_learner_1.reset_parameters()
self.graph_learner_2.reset_parameters()
def forward(self, Z, A):
################# Graph learner (upper branch) #################
A_prime = self.graph_learner_1(Z)
A_prime = (1-self.config.init_adj_ratio) * A_prime + self.config.init_adj_ratio * A
################# Gnn (upper branch) #################
Z_prime = self.gnn(Z, A_prime)
################# Gnn (lower branch) #################
Z_double_prime = self.gnn(Z, A)
Z_concat = torch.cat([Z_prime, Z_double_prime], dim=-1)
################# Graph learner (lower branch) #################
A_double_prime = self.graph_learner_2(Z_concat)
A_double_prime = (1-self.config.init_adj_ratio) * A_double_prime + self.config.init_adj_ratio * A
################## Average across branches ##################
A_global = (1 - self.config.adj_ratio) * A_prime + self.config.adj_ratio * A_double_prime
Z_global = 0.5 * Z_prime + 0.5 * Z_double_prime
return A_global, Z_global
class PNetGlobal(nn.Module):
def __init__(self, config):
super(PNetGlobal, self).__init__()
self.config = config
if self.config.gnn_type == 'gat':
self.gnn = GATModule(
self.config.d_model, self.config.d_model, self.config.d_model,
num_layers=self.config.num_global_gnn_layers,
heads=self.config.num_global_gnn_heads,
dropout=self.config.global_gnn_dropout,
concat=self.config.global_gnn_concat,
use_batch_norm=self.config.use_global_gnn_bn,
use_non_linear=self.config.use_non_linear
)
elif self.config.gnn_type == 'appnp':
self.gnn = Dense_APPNP_Net(
self.config.d_model, self.config.d_model, self.config.d_model, dropout=self.config.local_gnn_dropout,
K=self.config.gnn_K, alpha=self.config.gnn_alpha
)
elif self.config.gnn_type == 'gcn':
self.gnn = GCNModule(
self.config.d_model, self.config.d_model, self.config.d_model,
num_layers=self.config.num_global_gnn_layers,
dropout=self.config.global_gnn_dropout,
use_batch_norm=self.config.use_global_gnn_bn,
use_non_linear=self.config.use_non_linear
)
elif self.config.gnn_type == 'sage':
self.gnn = SAGEModule(
self.config.d_model, self.config.d_model, self.config.d_model,
num_layers=self.config.num_global_gnn_layers,
dropout=self.config.global_gnn_dropout,
use_batch_norm=self.config.use_global_gnn_bn,
use_non_linear=self.config.use_non_linear
)
else:
raise ValueError
self.mlp = MLPModule(
self.config.d_model, self.config.d_model, self.config.d_model,
num_layers=self.config.num_global_fc_layers,
dropout=self.config.global_fc_dropout,
use_batch_norm=self.config.use_global_fc_bn,
use_non_linear=self.config.use_non_linear
)
self.graph_learner_1 = GlobalGraphLearner(self.config.d_model, self.config.num_global_gr_learner_heads, random=self.config.use_random_graphs)
self.graph_learner_2 = GlobalGraphLearner(self.config.d_model * 2, self.config.num_global_gr_learner_heads, random=self.config.use_random_graphs)
def reset_parameters(self):
self.gnn.reset_parameters()
self.mlp.reset_parameters()
self.graph_learner_1.reset_parameters()
self.graph_learner_2.reset_parameters()
def forward(self, Z, A):
################# Graph learner (upper branch) #################
A_prime = self.graph_learner_1(Z)
################# Gnn (upper branch) #################
Z_prime = self.gnn(Z, A_prime)
################# mlp (lower branch) #################
Z_double_prime = self.mlp(Z)
Z_concat = torch.cat([Z_prime, Z_double_prime], dim=-1)
################# Graph learner (lower branch) #################
A_double_prime = self.graph_learner_2(Z_concat)
# A_double_prime = (1-self.config.init_adj_ratio) * A_double_prime + self.config.init_adj_ratio * A
################## Average across branches ##################
A_global = (1 - self.config.adj_ratio) * A_prime + self.config.adj_ratio * A_double_prime
Z_global = 0.5 * Z_prime + 0.5 * Z_double_prime
return A_global, Z_global
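
To make the graph-learner computation above concrete, here is a minimal standalone sketch of the head-averaged, weighted cosine-similarity adjacency produced by `GlobalGraphLearner` (and by `MMGraphLearner`, which additionally passes the features through a linear layer first). The tensor sizes are arbitrary, and the snippet mirrors the forward pass rather than reusing the repository classes.

```python
import torch
import torch.nn.functional as F

# Standalone sketch of the adjacency computed by the graph learners above.
# bsz, n_nodes, d_in and num_heads are arbitrary illustration values.
bsz, n_nodes, d_in, num_heads = 2, 5, 8, 4
X = torch.randn(bsz, n_nodes, d_in)   # node features of one modality
w = torch.randn(num_heads, d_in)      # per-head feature weights (learned in the module)

Xw = X.unsqueeze(0) * w.unsqueeze(1).unsqueeze(1)      # (heads, bsz, n, d_in)
Xw = F.normalize(Xw, p=2, dim=-1)
att = torch.matmul(Xw, Xw.transpose(-1, -2)).mean(0)   # (bsz, n, n): head-averaged cosine sim
att = att * (att > 0).float()                          # keep only positive similarities
print(att.shape)  # torch.Size([2, 5, 5])
```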

1397
models/nextqa_bart.py Normal file

File diff suppressed because it is too large Load diff

249
models/utils.py Normal file
View file

@ -0,0 +1,249 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.utils import ModelOutput
from typing import Optional, Tuple
class ELBO(nn.Module):
def __init__(self):
super(ELBO, self).__init__()
def forward(self, QA, PA):
QA_flattened = QA.view(-1).unsqueeze(-1)
PA_flattened = PA.view(-1).unsqueeze(-1)
QA_flattened = torch.cat([torch.zeros_like(QA_flattened), QA_flattened], dim=-1)
PA_flattened = torch.cat([torch.zeros_like(PA_flattened), PA_flattened], dim=-1)
log_QA = F.log_softmax(QA_flattened, dim=1)
log_PA = F.log_softmax(PA_flattened, dim=1)
QA_dist = torch.exp(log_QA)
loss_QA = torch.mean(log_QA * QA_dist)
loss_PA = torch.mean(log_PA * QA_dist)
loss = loss_QA - loss_PA
return loss
def seperate_nextqa_input_modalities(
features, i3d_rgb_interval, i3d_flow_interval, question_intervals,
vis_state_vector_idx, question_state_vector_idx,
attention_values=None):
""" We separate the multimodal input hidden states. The state token embeddings are left out (+1 while indexing)
Args:
features (_type_): _description_
i3d_rgb_interval (_type_): _description_
i3d_flow_interval (_type_): _description_
sam_interval (_type_): _description_
audio_interval (_type_): _description_
history_intervals (_type_): _description_
question_intervals (_type_): _description_
Returns:
_type_: _description_
"""
features_copy = features.clone() # .detach()
i3d_rgb_hidden = features_copy[:, i3d_rgb_interval[0]+1:i3d_rgb_interval[1], :]
i3d_flow_hidden = features_copy[:, i3d_flow_interval[0]+1:i3d_flow_interval[1], :]
question_hidden = []
features_split = torch.split(features_copy, 1, dim=0)
for ques_inter, feat in zip(question_intervals, features_split):
ques_idx = torch.arange(ques_inter[0]+1, ques_inter[1]).unsqueeze(0).unsqueeze(-1).repeat(1, 1, feat.size(-1)).to(feat.device)
question_hidden.append(torch.gather(feat, 1, ques_idx))
if attention_values is None:
i3d_rgb_att = None
i3d_flow_att = None
question_att = None
else:
attention_values = attention_values.mean(1)
i3d_rgb_att = attention_values[:, vis_state_vector_idx[0], vis_state_vector_idx[0]+1:vis_state_vector_idx[1]]
i3d_flow_att = attention_values[:, vis_state_vector_idx[1], vis_state_vector_idx[1]+1:question_state_vector_idx[0]]
question_att = [attention_values[i, question_state_vector_idx[i], question_intervals[i][0] + 1: question_intervals[i][1]] for i in range(len(question_state_vector_idx))]
features_list = [i3d_rgb_hidden, i3d_flow_hidden, question_hidden]
att = [i3d_rgb_att, i3d_flow_att, question_att]
return features_list, att
def seperate_input_modalities(
features, i3d_rgb_interval, i3d_flow_interval, sam_interval, audio_interval, history_intervals, question_intervals,
vis_state_vector_idx, history_state_vector_idx, question_state_vector_idx,
attention_values=None):
""" We separate the multimodal input hidden states. The state token embeddings are left out (+1 while indexing)
Args:
features (_type_): _description_
i3d_rgb_interval (_type_): _description_
i3d_flow_interval (_type_): _description_
sam_interval (_type_): _description_
audio_interval (_type_): _description_
history_intervals (_type_): _description_
question_intervals (_type_): _description_
Returns:
_type_: _description_
"""
features_copy = features.clone() # .detach()
i3d_rgb_hidden = features_copy[:, i3d_rgb_interval[0]+1:i3d_rgb_interval[1], :]
i3d_flow_hidden = features_copy[:, i3d_flow_interval[0]+1:i3d_flow_interval[1], :]
sam_hidden = features_copy[:, sam_interval[0]+1:sam_interval[1], :]
audio_hidden = features_copy[:, audio_interval[0]+1:audio_interval[1], :]
history_hidden = []
question_hidden = []
features_split = torch.split(features_copy, 1, dim=0)
for hist_inter, ques_inter, feat in zip(history_intervals, question_intervals, features_split):
hist_idx = torch.arange(hist_inter[0]+1, hist_inter[1]).unsqueeze(0).unsqueeze(-1).repeat(1, 1, feat.size(-1)).to(feat.device)
history_hidden.append(torch.gather(feat, 1, hist_idx))
ques_idx = torch.arange(ques_inter[0]+1, ques_inter[1]).unsqueeze(0).unsqueeze(-1).repeat(1, 1, feat.size(-1)).to(feat.device)
question_hidden.append(torch.gather(feat, 1, ques_idx))
if attention_values is None:
i3d_rgb_att = None
i3d_flow_att = None
sam_att = None
audio_att = None
history_att = None
question_att = None
else:
attention_values = attention_values.mean(1)
i3d_rgb_att = attention_values[:, vis_state_vector_idx[0], vis_state_vector_idx[0]+1:vis_state_vector_idx[1]]
i3d_flow_att = attention_values[:, vis_state_vector_idx[1], vis_state_vector_idx[1]+1:vis_state_vector_idx[2]]
sam_att = attention_values[:, vis_state_vector_idx[2], vis_state_vector_idx[2]+1:vis_state_vector_idx[3]]
audio_att = attention_values[:, vis_state_vector_idx[3], vis_state_vector_idx[3]+1:history_state_vector_idx[0] - 1]
history_att = [attention_values[i, history_state_vector_idx[i], history_intervals[i][0] + 1 : history_intervals[i][1]] for i in range(len(history_state_vector_idx))]
question_att = [attention_values[i, question_state_vector_idx[i], question_intervals[i][0] + 1: question_intervals[i][1]] for i in range(len(question_state_vector_idx))]
features_list = [i3d_rgb_hidden, i3d_flow_hidden, sam_hidden, audio_hidden, history_hidden, question_hidden]
att = [i3d_rgb_att, i3d_flow_att, sam_att, audio_att, history_att, question_att]
return features_list, att
def get_knn_graph(features, num_nn, device):
features = features.permute((1, 2, 0))
cosine_sim_pairwise = F.cosine_similarity(features, features.unsqueeze(1), dim=-2)
cosine_sim_pairwise = cosine_sim_pairwise.permute((2, 0, 1))
num_nn = min(num_nn, cosine_sim_pairwise.size(-1))
adj_mat = torch.zeros_like(cosine_sim_pairwise).to(device)
_, to_keep = torch.topk(cosine_sim_pairwise, num_nn, dim=-1, sorted=False)
adj_mat = adj_mat.scatter(-1, to_keep, torch.ones_like(adj_mat).to(device))
return adj_mat
def track_features_vis(features, att, top_k, device, node_idx=None):
"""Computes an adjacency matrix based on the nearset neighbor similiarity for
the i3d, audio, and sam input modalities. The tracked constituents of each modality
are randomly chosen (A_tilde in the paper).
"""
features = features.clone().detach()
top_k = min(features.size(1), top_k)
if att is None:
node_idx = torch.randint(low=0, high=features.size(1), size=(features.size(0), top_k))
else:
_, node_idx = torch.topk(att, top_k, dim=-1, sorted=False)
node_idx = node_idx.unsqueeze(-1).repeat(1, 1, features.size(-1)).to(device)
selected_features = torch.gather(features, 1, node_idx)
return selected_features, node_idx
def track_features_text(features, att, top_k, device, node_idx=None):
"""Computes an adjacency matrix based on the nearset neighbor similiarity for
the history and question inputs. The tracked constituents of each modality
are randomly chosen (A_tilde in the paper).
"""
hidden_dim = features[0].size(-1)
min_len = min([feat.size(1) for feat in features])
top_k = min(min_len, top_k)
if att is None:
node_idx = [torch.randint(low=0, high=feat.size(1), size=(feat.size(0), top_k)) for feat in features]
else:
node_idx = [torch.topk(a, top_k, dim=-1, sorted=False)[-1] for a in att]
node_idx = [idx.unsqueeze(-1).repeat(1, 1, hidden_dim).to(device) for idx in node_idx]
selected_features = [torch.gather(feat, 1, idx) for feat, idx in zip(features, node_idx)]
selected_features = torch.cat(selected_features, dim=0)
return selected_features, node_idx
def diag_tensor(tensors):
device = tensors[0].device
n = sum([t.size(-1) for t in tensors])
bsz = tensors[0].size(0)
diag_tensor = torch.zeros((bsz, n, n)).float().to(device)
delimiter = 0
delimiters = [0]
for t in tensors:
diag_tensor[:, delimiter:delimiter+t.size(-1), delimiter:delimiter+t.size(-1)] = t
delimiter += t.size(-1)
delimiters.append(delimiter)
return diag_tensor, delimiters
def embed_graphs(features, delimiters):
state_vectors = []
for i in range(len(delimiters) - 1):
state_vectors.append(features[:, delimiters[i]:delimiters[i+1], :].mean(dim=1))
return state_vectors
class AVSDEncoderOutput(ModelOutput):
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
QAs_local = None
PAs_local = None
QA_global = None
PA_global = None
state_vectors = None
class AVSDSeq2SeqModelOutput(ModelOutput):
last_hidden_state: torch.FloatTensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
QAs_local = None
PAs_local = None
QA_global = None
PA_global = None
state_vectors = None
class AVSDSeq2SeqLMOutput(ModelOutput):
gen_loss: Optional[torch.FloatTensor] = None
elbo_loss_global: Optional[torch.FloatTensor] = None
elbo_loss_local: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_QAs_local = None
encoder_PAs_local = None
encoder_QA_global = None
encoder_PA_global = None
encoder_state_vectors = None
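
The local branches in `models/gnns.py` are initialised with the k-NN graphs returned by `get_knn_graph` above (A_tilde). A small usage sketch showing the expected shapes; the sizes are arbitrary illustration values.

```python
import torch

from models.utils import get_knn_graph  # path as in this repository

# Arbitrary sizes for illustration: batch of 2, 6 nodes per sample, 16-dim features.
feats = torch.randn(2, 6, 16)
adj = get_knn_graph(feats, num_nn=3, device=feats.device)
print(adj.shape)    # torch.Size([2, 6, 6])
print(adj.sum(-1))  # each node keeps its 3 most cosine-similar nodes (itself included)
```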

106
optim_utils.py Normal file
View file

@ -0,0 +1,106 @@
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim import AdamW
class WarmupLinearScheduleNonZero(_LRScheduler):
""" Linear warmup and then linear decay.
Linearly increases learning rate from 0 to max_lr over `warmup_steps` training steps.
Linearly decreases learning rate linearly to min_lr over remaining `t_total - warmup_steps` steps.
"""
def __init__(self, optimizer, warmup_steps, t_total, min_lr=1e-5, last_epoch=-1):
self.warmup_steps = warmup_steps
self.t_total = t_total
self.min_lr = min_lr
super(WarmupLinearScheduleNonZero, self).__init__(optimizer, last_epoch=last_epoch)
def get_lr(self):
step = self.last_epoch
if step < self.warmup_steps:
lr_factor = float(step) / float(max(1, self.warmup_steps))
else:
lr_factor = max(0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps)))
return [base_lr * lr_factor if (base_lr * lr_factor) > self.min_lr else self.min_lr for base_lr in self.base_lrs]
def init_optim(model, config):
encoder_params_with_weight_decay = []
encoder_params_without_weight_decay = []
decoder_params_with_weight_decay = []
decoder_params_without_weight_decay = []
other_params_with_weight_decay = []
other_params_without_weight_decay = []
exclude_from_weight_decay=['bias', 'LayerNorm.bias', 'LayerNorm.weight']
# Our model shares (embedding) parameters between the encoder and decoder.
# We want to include such parameters only in one parameter group.
# So we keep track of the unique ids of each parameter.
params_ids = []
for module_name, module in model.named_children():
for param_name, param in module.named_parameters():
if id(param) not in params_ids:
params_ids.append(id(param))
else:
continue
if param.requires_grad:
if 'encoder' in param_name:
if any(ex in param_name for ex in exclude_from_weight_decay):
encoder_params_without_weight_decay.append(param)
else:
encoder_params_with_weight_decay.append(param)
elif 'decoder' in param_name:
if any(ex in param_name for ex in exclude_from_weight_decay):
decoder_params_without_weight_decay.append(param)
else:
decoder_params_with_weight_decay.append(param)
else:
if any(ex in param_name for ex in exclude_from_weight_decay):
other_params_without_weight_decay.append(param)
else:
other_params_with_weight_decay.append(param)
optimizer_grouped_parameters = [
{
'params': encoder_params_with_weight_decay,
'weight_decay': 0.01,
'lr': config['learning_rate_bart']
},
{
'params': encoder_params_without_weight_decay,
'weight_decay': 0.0,
'lr': config['learning_rate_bart']
},
{
'params': decoder_params_with_weight_decay,
'weight_decay': 0.01,
'lr': config['learning_rate_bart']
},
{
'params': decoder_params_without_weight_decay,
'weight_decay': 0.0,
'lr': config['learning_rate_bart']
},
{
'params': other_params_with_weight_decay,
'weight_decay': 0.01,
'lr': config['learning_rate_other']
},
{
'params': other_params_without_weight_decay,
'weight_decay': 0.0,
'lr': config['learning_rate_other']
},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=config['learning_rate_bart'])
scheduler = WarmupLinearScheduleNonZero(
optimizer,
warmup_steps=config['warmup_steps'],
t_total=config['train_steps'],
min_lr=config['min_lr']
)
return optimizer, scheduler
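
A quick illustration of the schedule implemented by `WarmupLinearScheduleNonZero` above: the learning rate rises linearly to the base lr during warmup, decays linearly afterwards, and is clamped from below by `min_lr`. The optimizer and all numbers below are toy values used only for illustration.

```python
import torch
from optim_utils import WarmupLinearScheduleNonZero  # as defined above

# Toy parameter/optimizer just to drive the scheduler; the numbers are illustrative.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-4)
scheduler = WarmupLinearScheduleNonZero(optimizer, warmup_steps=10, t_total=100, min_lr=1e-6)

lrs = []
for _ in range(100):
    optimizer.step()
    scheduler.step()
    lrs.append(optimizer.param_groups[0]['lr'])
# lrs ramps up to the base lr (1e-4) by step 10, then decays linearly down to min_lr (1e-6)
print(lrs[0], lrs[9], lrs[-1])
```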

0
processed/.gitkeep Normal file
View file

0
processed/avsd/.gitkeep Normal file
View file

View file

File diff suppressed because it is too large Load diff

Binary file not shown.

BIN
processed/nextqa/annotations/test.csv (Stored with Git LFS) Normal file

Binary file not shown.
1 version https://git-lfs.github.com/spec/v1
2 oid sha256:c73c7db32ed1c2addd7d81e8fd92849c7468d7778d4218af191739a18e09dfec
3 size 944376

BIN
processed/nextqa/annotations/train.csv (Stored with Git LFS) Normal file

Binary file not shown.
1 version https://git-lfs.github.com/spec/v1
2 oid sha256:b787d5b954022727e9d6851dc7e2e15c97f68070d2630101de4b797722413f38
3 size 3943605

BIN
processed/nextqa/annotations/val.csv (Stored with Git LFS) Normal file

Binary file not shown.
1 version https://git-lfs.github.com/spec/v1
2 oid sha256:7123332ec67319b61dcb3b05cad640538094751c1990bac0d9ce8bc6d103e70d
3 size 554521

BIN
processed/nextqa/annotations/vocab.pkl (Stored with Git LFS) Normal file

Binary file not shown.

BIN
processed/nextqa/vid_feat/app_mot_test.h5 (Stored with Git LFS) Normal file

Binary file not shown.

BIN
processed/nextqa/vid_feat/app_mot_train.h5 (Stored with Git LFS) Normal file

Binary file not shown.

BIN
processed/nextqa/vid_feat/app_mot_val.h5 (Stored with Git LFS) Normal file

Binary file not shown.

0
raw_data/.gitkeep Normal file
View file

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

488
runners/runner.py Normal file
View file

@ -0,0 +1,488 @@
import wandb
import os
import os.path as osp
import json
from collections import deque, OrderedDict
import time
import re
import shutil
import glob
import pickle
import gc
import numpy as np
import glog as log
import torch
import torch.utils.data as tud
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.utils import clip_grad_value_
class Runner:
def __init__(self, config):
self.config = config
if 'rank' in config:
self.gpu_rank = config['rank']
else:
self.gpu_rank = 0
self.epoch_idx = 0
self.min_gen_val_loss = float('inf')
self.best_epoch_idx = 0
if self.config["max_ckpt_to_keep"] > 0:
self.checkpoint_queue = deque([], maxlen=config["max_ckpt_to_keep"])
self.metrics_queue = deque([], maxlen=config["max_ckpt_to_keep"])
self.setup_wandb()
def setup_wandb(self):
if self.gpu_rank == 0:
print("[INFO] Set wandb logging on rank {}".format(0))
run = wandb.init(
project=self.config['wandb_project'], config=self.config, mode=self.config['wandb_mode'])
else:
run = None
self.run = run
def forward(self, batch, eval=False):
return NotImplementedError
def train(self, dataset, dataset_eval):
batch_size = self.config['batch_size']
if self.config['parallel'] and self.config['dp_type'] != 'dp':
sampler = tud.distributed.DistributedSampler(
dataset,
num_replicas=self.config['num_gpus'],
rank=self.gpu_rank
)
else:
sampler = None
data_loader = tud.DataLoader(
dataset=dataset,
batch_size=batch_size,
shuffle=self.config['training'] and not self.config['parallel'],
collate_fn=dataset.collate_fn,
num_workers=self.config['num_workers'],
sampler=sampler
)
start_epoch_idx = self.epoch_idx
num_iter_epoch = self.config['num_iter_per_epoch']
if self.config['display']:
log.info(f'{num_iter_epoch} iter per epoch.')
num_epochs = self.config['num_epochs']
# Perform validation before training
if self.config['eval_first']:
_ = self.val(dataset_eval)
for epoch_idx in range(start_epoch_idx, num_epochs):
if self.config['parallel'] and self.config['dp_type'] != 'dp':
sampler.set_epoch(epoch_idx)
self.epoch_idx = epoch_idx
if self.config['display']:
log.info(f'starting epoch {epoch_idx}')
log.info('training')
self.model.train()
num_batch = 0
next_logging_pct = .1
start_time = time.time()
self.optimizer.zero_grad()
for batch in data_loader:
num_batch += 1
pct = num_batch / num_iter_epoch * 100
iter_now = num_iter_epoch * epoch_idx + num_batch
output = self.forward(batch)
losses = output['losses']
# optimizer step
losses['tot_loss'] /= self.config['batch_multiply']
# debug
if self.config['debugging']:
log.info('try backward')
losses['tot_loss'].backward()
if self.config['clip_grad_value'] > 0:
clip_grad_value_(self.model.parameters(), self.config['clip_grad_value'])
if self.config['debugging']:
log.info('backward done')
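# Gradient accumulation: the loss above was already divided by 'batch_multiply',
# and parameters are only updated every 'batch_multiply' iterations.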
if iter_now % self.config['batch_multiply'] == 0:
self.optimizer.step()
self.optimizer.zero_grad()
self.scheduler.step()
# display and eval
if pct >= next_logging_pct:
if self.config['display']:
loss_to_print = ''
for key in losses:
if losses[key] is not None and isinstance(losses[key], torch.Tensor):
loss_to_print += f'[{key}: {losses[key].item():.4f}]'
print(
f'[{int(pct)}%][Epoch: {epoch_idx + 1}/{num_epochs}][Iter : {num_batch}/{len(data_loader)}] [time: {time.time() - start_time:.2f}] {loss_to_print}'
)
if self.config['print_output']:
print(10 * '-' + 'responses' + 10 * '-')
print(output['responses'])
print(10 * '-' + 'gt' + 10 * '-')
print(output['gt'])
next_logging_pct += self.config["next_logging_pct"]
if self.config['debugging']:
break
lr_bart, lr_other = self.scheduler.get_lr()[0], self.scheduler.get_lr()[-1]
elbo_global_key = 'elbo_loss_global (x{})'.format(self.config['elbo_global_coeff'])
elbo_local_key = 'elbo_loss_local (x{})'.format(self.config['elbo_local_coeff'])
gen_key = 'gen_loss (x{})'.format(self.config['gen_coeff'])
if self.run:
self.run.log(
{
f"Train/{gen_key}": losses[gen_key].item(),
f"Train/{elbo_global_key}": losses[elbo_global_key].item(),
f"Train/{elbo_local_key}": losses[elbo_local_key].item(),
"Train/total_loss": losses['tot_loss'].item(),
},
step=iter_now
)
self.run.log(
{"Train/lr_bart": lr_bart, "Train/lr_other": lr_other},
step=iter_now
)
del losses
del output
if self.config['display']:
log.info(
f'100%,\ttime:\t{time.time() - start_time:.2f}'
)
if not self.config['overfit'] and self.run:
self.save_ckpt()
if not self.config['skip_eval']:
iter_now = num_iter_epoch * (epoch_idx + 1)
val_losses = self.val(dataset_eval)
if self.config['display']:
log.info('#'*100)
for k in val_losses:
log.info('Average val {} (epoch {}) = {}'.format(k, self.epoch_idx, val_losses[k]))
log.info('#'*100)
gen_val_loss = val_losses[gen_key]
if gen_val_loss < self.min_gen_val_loss:
self.min_gen_val_loss = gen_val_loss
self.best_epoch_idx = epoch_idx
# Log the best model w.r.t. the validation data
if self.run and self.config['save_ckpt']:
self.save_ckpt_best()
if self.run:
self.run.log(
{
f"Val/{gen_key}": val_losses[gen_key],
f"Val/{elbo_global_key}": val_losses[elbo_global_key],
f"Val/{elbo_local_key}": val_losses[elbo_local_key],
"Val/total_loss": val_losses['tot_loss'],
"Val/min_gen_loss": self.min_gen_val_loss
},
step=iter_now
)
if self.config['parallel']:
if self.config['dp_type'] == 'dp':
gc.collect()
torch.cuda.empty_cache()
else:
dist.barrier()
torch.cuda.empty_cache()
if self.config['stop_epochs'] >= 0 and epoch_idx + 1 >= self.config['stop_epochs']:
if self.config['display']:
log.info('Stop for reaching stop_epochs.')
break
if self.config['display']:
log.info(f'Best validation loss was reached at epoch {self.best_epoch_idx}.')
def val(self, dataset):
total_loss_val = 0.0
total_gen_loss_val = 0.0
total_elbo_global_val = 0.0
total_elbo_local_val = 0.0
num_batch_val = 0
next_logging_pct_val = 0.05
elbo_global_key = 'elbo_loss_global (x{})'.format(self.config['elbo_global_coeff'])
elbo_local_key = 'elbo_loss_local (x{})'.format(self.config['elbo_local_coeff'])
gen_key = 'gen_loss (x{})'.format(self.config['gen_coeff'])
# Prepare the dataloader
if self.config['parallel'] and self.config['dp_type'] != 'dp':
sampler_val = tud.distributed.DistributedSampler(
dataset,
num_replicas=self.config['num_gpus'],
rank=self.gpu_rank
)
sampler_val.set_epoch(self.epoch_idx)
else:
sampler_val = None
data_loader_val = tud.DataLoader(
dataset=dataset,
batch_size=self.config['batch_size'],
shuffle=False,
collate_fn=dataset.collate_fn,
num_workers=self.config['num_workers'],
sampler=sampler_val
)
if self.config['parallel'] and self.config['dp_type'] == 'dp':
num_iter_per_epoch_val = int(np.ceil(len(dataset) / self.config['batch_size']))
else:
num_iter_per_epoch_val = int(np.ceil(len(dataset) / (self.config['batch_size'] * self.config['num_gpus'])))
self.model.eval()
if self.gpu_rank == 0:
start_time = time.time()
for batch in data_loader_val:
num_batch_val += 1
pct = num_batch_val / num_iter_per_epoch_val * 100
with torch.no_grad():
output = self.forward(batch)
losses = output['losses']
losses['tot_loss'] /= self.config['batch_multiply']
losses[elbo_global_key] /= self.config['batch_multiply']
losses[elbo_local_key] /= self.config['batch_multiply']
losses[gen_key] /= self.config['batch_multiply']
total_loss_val += losses['tot_loss'].item()
total_gen_loss_val += losses[gen_key].item()
total_elbo_global_val += losses[elbo_global_key].item()
total_elbo_local_val += losses[elbo_local_key].item()
# display and eval
if pct >= next_logging_pct_val:
if self.config['display']:
loss_to_print = ''
for key in losses:
if losses[key] is not None and isinstance(losses[key], torch.Tensor):
loss_to_print += f'[{key}: {losses[key].item():.4f}]'
print(
f'[{int(pct)}%][Validating][Iter : {num_batch_val}/{num_iter_per_epoch_val}] [time: {time.time() - start_time:.2f}] {loss_to_print}'
)
next_logging_pct_val += self.config["next_logging_pct"]
loss_val = total_loss_val / num_batch_val
gen_loss_val = total_gen_loss_val / num_batch_val
elbo_global_val = total_elbo_global_val / num_batch_val
elbo_local_val = total_elbo_local_val / num_batch_val
losses_val = {
'tot_loss': loss_val,
elbo_global_key: elbo_global_val,
elbo_local_key: elbo_local_val,
gen_key: gen_loss_val
}
self.model.train()
return losses_val
def save_eval_results(self, split, epoch_idx, metrics_results):
metrics_filename = osp.join(self.config['log_dir'], f'metrics_epoch_{epoch_idx}.json')
with open(metrics_filename, 'w') as f:
json.dump(metrics_results, f)
log.info(f'Results of metrics saved to {metrics_filename}')
if self.config["max_ckpt_to_keep"] > 0:
if len(self.metrics_queue) == self.metrics_queue.maxlen:
todel = self.metrics_queue.popleft()
os.remove(todel)
self.metrics_queue.append(metrics_filename)
if epoch_idx == 'best':
self.copy_best_predictions(split)
def copy_best_results(self, split, epoch_idx):
to_print = 'Copy '
if not self.config['skip_saving_ckpt']:
ckpt_path = osp.join(self.config['log_dir'], f'epoch_{epoch_idx}.ckpt')
best_ckpt_path = ckpt_path.replace(f'{epoch_idx}.ckpt', 'best.ckpt')
shutil.copyfile(ckpt_path, best_ckpt_path)
to_print += best_ckpt_path + ' '
metrics_filename = osp.join(self.config['log_dir'], f'metrics_epoch_{epoch_idx}.json')
best_metric_filename = metrics_filename.replace(f'{epoch_idx}.json', 'best.json')
shutil.copyfile(metrics_filename, best_metric_filename)
to_print += best_metric_filename + ' '
log.info(to_print)
def set_ckpt(self, ckpt_dict):
if self.config['parallel']:
model = self.model.module
else:
model = self.model
model_state_dict = model.state_dict()
former_dict = {k: v for k, v in ckpt_dict['model_state_dict'].items() if k in model_state_dict}
if self.config['display']:
log.info("number of keys transferred: %d" % len(former_dict))
assert len(former_dict.keys()) > 0
model_state_dict.update(former_dict)
model.load_state_dict(model_state_dict)
if self.config['display']:
log.info('loaded model')
del model_state_dict, former_dict
# if not self.config['uses_new_optimizer']:
if not self.config['generating'] and not (self.config['uses_new_optimizer'] or self.config['sets_new_lr']):
if not self.config['restarts']:
self.epoch_idx = ckpt_dict['epoch_idx'] + 1
if not self.config['resets_min_val_loss']:
self.min_gen_val_loss = ckpt_dict['min_gen_val_loss']
self.optimizer.load_state_dict(ckpt_dict['optimizer'])
if self.config['display']:
log.info('loaded optimizer')
if 'scheduler' in ckpt_dict:
self.scheduler.last_epoch = ckpt_dict['epoch_idx'] * self.config['num_iter_per_epoch']
self.scheduler.load_state_dict(ckpt_dict['scheduler'])
del ckpt_dict
torch.cuda.empty_cache()
def save_ckpt(self):
ckpt_path = f'{self.config["log_dir"]}/epoch_{self.epoch_idx}.ckpt'
log.info(f'saving checkpoint {ckpt_path}')
ckpt = self.get_ckpt()
if self.config['skip_saving_ckpt']:
return ckpt_path
torch_version = float(torch.__version__[:3])
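# Newer PyTorch releases (>= 1.6) default to the zipfile-based checkpoint format;
# the legacy serialization is used here, presumably to keep checkpoints loadable
# with older torch versions.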
if torch_version - 1.4 > 1e-3:
torch.save(ckpt, f=ckpt_path, _use_new_zipfile_serialization=False)
else:
torch.save(ckpt, f=ckpt_path)
del ckpt
if not self.config['parallel']:
torch.cuda.empty_cache()
if self.config["max_ckpt_to_keep"] > 0:
if len(self.checkpoint_queue) == self.checkpoint_queue.maxlen:
todel = self.checkpoint_queue.popleft()
os.remove(todel)
self.checkpoint_queue.append(ckpt_path)
def save_ckpt_best(self):
ckpt_path = f'{self.config["log_dir"]}/epoch_best.ckpt'
log.info(f'saving checkpoint {ckpt_path}')
ckpt = self.get_ckpt()
torch.save(ckpt, f=ckpt_path)
del ckpt
return ckpt_path
def load_ckpt_best(self):
ckpt_path = f'{self.config["log_dir"]}/epoch_best.ckpt'
if not osp.exists(ckpt_path):
ckpt_paths = [path for path in os.listdir(f'{self.config["log_dir"]}/') if path.endswith('.ckpt') and 'best' not in path]
if len(ckpt_paths) == 0:
if self.config['display']:
log.info(f'No .ckpt found in {self.config["log_dir"]}')
return
sort_func = lambda x:int(re.search(r"(\d+)", x).groups()[0])
ckpt_path = f'{self.config["log_dir"]}/{sorted(ckpt_paths, key=sort_func, reverse=True)[0]}'
if self.config['display']:
log.info(f'loading checkpoint {ckpt_path}')
map_location = {'cuda:0': f'cuda:{self.gpu_rank}'}
self.set_ckpt(torch.load(ckpt_path, map_location=map_location))
def get_ckpt(self):
ckpt = {
'epoch_idx': self.epoch_idx,
'min_gen_val_loss': self.min_gen_val_loss,
'seed': self.config['random_seed'],
'optimizer': self.optimizer.state_dict(),
'scheduler': self.scheduler.state_dict()
}
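# NOTE: this assumes the model is wrapped (e.g. in DistributedDataParallel), hence
# 'self.model.module'; set_ckpt() handles both the wrapped and unwrapped case.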
ckpt['model_state_dict'] = self.model.module.state_dict()
return ckpt
def load_ckpt(self, ckpt_path=None):
if not ckpt_path:
if self.config['generating']: # or self.config['start_ckpt_for_generating']:
ckpt_path = f'{self.config["log_dir"]}/epoch_best.ckpt'
else:
ckpt_paths = [path for path in os.listdir(f'{self.config["log_dir"]}/') if path.endswith('.ckpt') and 'best' not in path]
if len(ckpt_paths) == 0:
if self.config['display']:
log.info(f'No .ckpt found in {self.config["log_dir"]}')
return
sort_func = lambda x:int(re.search(r"(\d+)", x).groups()[0])
ckpt_path = f'{self.config["log_dir"]}/{sorted(ckpt_paths, key=sort_func, reverse=True)[0]}'
if self.config['display']:
log.info(f'loading checkpoint {ckpt_path}')
epoch_name = osp.split(ckpt_path)[1].split('.')[0]
if re.search(r"(\d+)", epoch_name):
self.checkpoint_queue.append(ckpt_path)
metrics_filename = osp.join(self.config['log_dir'], f'metrics_{epoch_name}.json')
if osp.exists(metrics_filename):
self.metrics_queue.append(metrics_filename)
map_location = {'cuda:0': f'cuda:{self.gpu_rank}'}
self.set_ckpt(torch.load(ckpt_path, map_location=map_location))
def match_model_key(self, pretrained_dict, model_dict):
matched_dict = dict()
not_found = []
for key in pretrained_dict:
if key in model_dict:
matched_key = key
elif key.startswith('encoder.') and key[8:] in model_dict:
matched_key = key[8:]
elif key.startswith('module.') and key[7:] in model_dict:
matched_key = key[7:]
elif 'encoder.' + key in model_dict:
matched_key = 'encoder.' + key
elif 'module.' + key in model_dict:
matched_key = 'module.' + key
else:
not_found.append(key)
continue
matched_dict[matched_key] = pretrained_dict[key]
print("Keys from pretrained_dict that were not found in model_dict:\n", not_found)
return matched_dict
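# Hedged usage sketch (match_model_key is not called elsewhere in this file): it can
# align the keys of a pretrained checkpoint with the current model before loading.
# 'pretrained.ckpt' below is a placeholder path, not a file shipped with this repo.
#
#   pretrained = torch.load('pretrained.ckpt', map_location='cpu')  # assumed raw state_dict
#   model_dict = self.model.state_dict()
#   model_dict.update(self.match_model_key(pretrained, model_dict))
#   self.model.load_state_dict(model_dict)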

337
runners/runner_avsd.py Normal file
View file

@@ -0,0 +1,337 @@
import time
import os
import glog as log
import numpy as np
import json
import torch
import torch.nn.functional as F
from runners.runner import Runner
from copy import deepcopy
from optim_utils import init_optim
from transformers.models.bart.configuration_bart import BartConfig
from models.avsd_bart import AVSDBart
from custom_datasets.avsd import build_input_from_segments
from time import time
class AVSDRunner(Runner):
def __init__(self, config, tokenizer, vocab_size):
super(AVSDRunner, self).__init__(config)
bart_config = BartConfig.from_json_file(self.config['bart_config'])
self.model = AVSDBart.from_pretrained(
'facebook/bart-{}'.format(self.config['bart_size']), config=bart_config)
# Resize the embedding to match the vocab with additional special toks
# This takes care of resizing weights of related parts of the network
# pytorch_total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
# print(pytorch_total_params)
if vocab_size != bart_config.vocab_size:
self.model.resize_token_embeddings(vocab_size)
self.model.to(self.config['device'])
if not self.config['generating']:
self.optimizer, self.scheduler = init_optim(self.model, self.config)
self.tokenizer = tokenizer
def forward(self, batch):
for key in batch:
if isinstance(batch[key], torch.Tensor):
batch[key] = batch[key].cuda()
########################################################
input_ids = batch['input_ids']
video_place_holder_ids = batch['video_place_holder_ids']
i3d_rgb = batch['i3d_rgb']
i3d_flow = batch['i3d_flow']
sam = batch['sam']
vggish = batch['vggish']
lm_labels = batch['lm_labels']
input_mask = batch['input_mask']
i3d_rgb_interval = batch['i3d_rgb_interval']
i3d_flow_interval = batch['i3d_flow_interval']
sam_interval = batch['sam_interval']
audio_interval = batch['audio_interval']
history_intervals = batch['history_intervals']
question_intervals = batch['question_intervals']
vis_state_vector_idx = batch['vis_state_vector_idx']
history_state_vector_idx = batch['history_state_vector_idx']
question_state_vector_idx = batch['question_state_vector_idx']
########################################################
bart_output = self.model(
input_ids=input_ids,
video_place_holder_ids=video_place_holder_ids,
i3d_rgb=i3d_rgb,
i3d_flow=i3d_flow,
sam=sam,
vggish=vggish,
attention_mask=input_mask,
labels=lm_labels,
i3d_rgb_interval=i3d_rgb_interval,
i3d_flow_interval=i3d_flow_interval,
sam_interval=sam_interval,
audio_interval=audio_interval,
history_intervals=history_intervals,
question_intervals=question_intervals,
vis_state_vector_idx=vis_state_vector_idx,
history_state_vector_idx=history_state_vector_idx,
question_state_vector_idx=question_state_vector_idx,
output_attentions=True,
return_dict=True
)
output = {}
if self.config['print_output']:
logits = bart_output['logits']
probs = F.softmax(logits, dim=-1)
preds = torch.topk(probs, 1)[1].squeeze(-1)
preds = preds.tolist()
lm_labels_list = lm_labels[:, 1:].tolist()
lm_labels_list = [[s for s in label if s != -1] for label in lm_labels_list]
responses = ''
labels = ''
for pred, label in zip(preds, lm_labels_list):
responses += self.tokenizer.decode(pred) + '\n'
labels += self.tokenizer.decode(label) + '\n'
output['responses'] = responses
output['gt'] = labels
gen_key = 'gen_loss (x{})'.format(self.config['gen_coeff'])
gen_loss = bart_output['gen_loss']
gen_loss = self.config['gen_coeff'] * gen_loss
elbo_global_key = 'elbo_loss_global (x{})'.format(self.config['elbo_global_coeff'])
if bart_output['elbo_loss_global'] is not None:
elbo_global_loss = bart_output['elbo_loss_global']
elbo_global_loss = self.config['elbo_global_coeff'] * elbo_global_loss
else:
elbo_global_loss = torch.tensor(0.0)
elbo_local_key = 'elbo_loss_local (x{})'.format(self.config['elbo_local_coeff'])
if bart_output['elbo_loss_local'] is not None:
elbo_local_loss = bart_output['elbo_loss_local']
elbo_local_loss = self.config['elbo_local_coeff'] * elbo_local_loss
else:
elbo_local_loss = torch.tensor(0.0)
total_loss = gen_loss + elbo_global_loss + elbo_local_loss
output['losses'] = {
gen_key: gen_loss,
elbo_local_key: elbo_local_loss,
elbo_global_key: elbo_global_loss,
'tot_loss': total_loss
}
del bart_output
return output
def generate(self, dataset, tag, tokenizer, gen_subset_num=None):
self.model.eval()
responses = {}
i3d_flow_sep, i3d_rgb_sep, sam_sep, audio_sep, ph_token = tokenizer.convert_tokens_to_ids(
['<s0>', '<s1>', '<s2>', '<s3>', '<place_holder>'])
# Generate the response for each round
log.info('[INFO] Generating responses for {} samples'.format(len(dataset)))
with torch.no_grad():
for counter, dialog in enumerate(dataset):
start_time = time()
vid = dialog['vid']
i3d_rgb = np.load(os.path.join(self.config['avsd_i3d_rgb_test'], vid + '.npy'))
i3d_flow = np.load(os.path.join(self.config['avsd_i3d_flow_test'], vid + '.npy'))
sam = np.load(os.path.join(self.config['avsd_objects_test'], vid + '.npy'))
vggish = np.load(os.path.join(self.config['avsd_audio_test'], vid + '.npy'))
min_length = min([self.config['vis_feat_length'], i3d_rgb.shape[0], i3d_flow.shape[0], sam.shape[0], vggish.shape[0]])
sample_idx_i3d_rgb = np.round(np.linspace(0, i3d_rgb.shape[0] - 1, min_length)).astype(int)
sample_idx_i3d_flow = np.round(np.linspace(0, i3d_flow.shape[0] - 1, min_length)).astype(int)
sample_idx_sam = np.round(np.linspace(0, sam.shape[0] - 1, min_length)).astype(int)
sample_idx_vggish = np.round(np.linspace(0, vggish.shape[0] - 1, min_length)).astype(int)
i3d_rgb = torch.from_numpy(i3d_rgb[sample_idx_i3d_rgb, :]).float()
i3d_flow = torch.from_numpy(i3d_flow[sample_idx_i3d_flow, :]).float()
sam = torch.from_numpy(sam[sample_idx_sam, :]).float()
vggish = torch.from_numpy(vggish[sample_idx_vggish, :]).float()
dummy = torch.ones((1, min_length)) * ph_token
video_place_holder_ids = torch.cat(
[torch.ones((1, 1)) * i3d_rgb_sep, dummy,
torch.ones((1, 1)) * i3d_flow_sep, dummy,
torch.ones((1, 1)) * sam_sep, dummy,
torch.ones((1, 1)) * audio_sep, dummy,
], dim=-1).long()
# Now we get the intervals of the visual input tokens
# Here the intervals do not change across the batch dimension
i3d_rgb_interval = [0, min_length + 1] # the last token is not part of this modality
i3d_flow_interval = [min_length + 1, 2 * min_length + 2]
sam_interval = [2 * min_length + 2, 3 * min_length + 3]
audio_interval = [3 * min_length + 3, 4 * min_length + 4]
vis_state_vector_idx = [i3d_rgb_interval[0], i3d_flow_interval[0], sam_interval[0], audio_interval[0]]
response = self.beam_search_generation(
dialog['caption'], dialog['history'],
i3d_rgb, i3d_flow, sam, vggish,
i3d_rgb_interval, i3d_flow_interval, sam_interval, audio_interval,
vis_state_vector_idx, video_place_holder_ids, tokenizer)
# Decode the response
response = self.tokenizer.decode(response)
responses[vid] = response
# all_graphs[vid] = graphs
time_elapsed = int(time() - start_time)
print('Generating response {} / {} -- took {}s'.format(counter + 1, len(dataset), time_elapsed))
# Create a file with all responses
with open(self.config['avsd_test_dstc{}'.format(self.config['dstc'])], 'r') as f:
test_data = json.load(f)
test_dialogs = deepcopy(test_data['dialogs'])
# Filter the predicted dialogs
test_dialogs = list(filter(lambda diag: diag['image_id'] in responses, test_dialogs))
for i, dialog in enumerate(test_dialogs):
vid_id = dialog['image_id']
gen_response = responses[vid_id]
round_num_to_answer = len(dialog['dialog'])-1
assert dialog['dialog'][round_num_to_answer]['answer'] == '__UNDISCLOSED__'
dialog['dialog'][round_num_to_answer]['answer'] = gen_response
test_dialogs[i] = dialog
# Log the file
file_name = 'results_dstc{}_beam_depth_{}'.format(self.config['dstc'], self.config['beam_depth'])
if gen_subset_num is not None:
file_name += f'-part_{gen_subset_num}'
file_name = f'{tag}_' + file_name
output_path = os.path.join(self.config['output_dir_dstc{}'.format(self.config['dstc'])], file_name + '.json')
with open(output_path, 'w') as f:
json.dump({'dialogs': test_dialogs}, f, indent=4)
log.info('Results logged to {}'.format(output_path))
print(os.getcwd())
# Switch back to training mode
self.model.train()
def beam_search_generation(
self, caption, history,
i3d_rgb, i3d_flow, sam, vggish,
i3d_rgb_interval, i3d_flow_interval, sam_interval, audio_interval,
vis_state_vector_idx, video_place_holder_ids, tokenizer):
eos_token = tokenizer.eos_token_id
unk_token = tokenizer.unk_token_id
question_sep = tokenizer.convert_tokens_to_ids('<s5>')
gen_ans = [eos_token]
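# Each beam hypothesis is a tuple (generated token ids, cumulative log-prob, decoder
# input ids); the search starts from one empty hypothesis seeded with the eos token,
# and finished hypotheses are collected in comp_hyplist.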
hyplist = [([], 0.0, [eos_token])]
best_state = None
comp_hyplist = []
i3d_rgb = i3d_rgb.unsqueeze(0).cuda()
i3d_flow = i3d_flow.unsqueeze(0).cuda()
sam = sam.unsqueeze(0).cuda()
vggish = vggish.unsqueeze(0).cuda()
video_place_holder_ids = video_place_holder_ids.cuda()
text_shift_len = video_place_holder_ids.size(-1)
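# All text intervals below are shifted by the length of the video placeholder block,
# since the visual tokens precede the text tokens in the encoder input.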
drop_caption = self.config['dstc'] == 10
instance = build_input_from_segments(caption, history, gen_ans, tokenizer, drop_caption=drop_caption)
input_ids = torch.tensor(instance['input_ids'])
history_end = (input_ids == question_sep).nonzero(as_tuple=True)[0]
history_intervals = [[0 + text_shift_len, history_end.item() + text_shift_len]] # The last token is the question state token (not part of the history)
question_intervals = [[history_end.item() + text_shift_len, input_ids.size(0) + text_shift_len]]
history_state_vector_idx = [x[0] + 1 for x in history_intervals] # +1 because the history starts with <s><s4> .....
question_state_vector_idx = [x[0] for x in question_intervals] # no offset needed: the question interval starts directly at its state token <s5>
input_ids = input_ids.long().cuda().unsqueeze(0)
encoder_outputs = None
for i in range(self.config['max_generation_length']):
new_hyplist = []
argmin = 0
for out, lp, st in hyplist:
decoder_input_ids = torch.tensor(st).long().cuda().unsqueeze(0)
bart_output = self.model(
input_ids=input_ids,
video_place_holder_ids=video_place_holder_ids,
i3d_rgb=i3d_rgb,
i3d_flow=i3d_flow,
sam=sam,
vggish=vggish,
encoder_outputs=encoder_outputs,
decoder_input_ids=decoder_input_ids,
i3d_rgb_interval=i3d_rgb_interval,
i3d_flow_interval=i3d_flow_interval,
sam_interval=sam_interval,
audio_interval=audio_interval,
history_intervals=history_intervals,
question_intervals=question_intervals,
vis_state_vector_idx=vis_state_vector_idx,
history_state_vector_idx=history_state_vector_idx,
question_state_vector_idx=question_state_vector_idx,
output_attentions=True,
generate=True,
return_dict=True
)
if encoder_outputs is None:
encoder_outputs = [
bart_output['encoder_last_hidden_state'],
bart_output['encoder_hidden_states'],
bart_output['encoder_attentions'],
bart_output['encoder_QAs_local'],
bart_output['encoder_PAs_local'],
bart_output['encoder_QA_global'],
bart_output['encoder_PA_global'],
bart_output['encoder_state_vectors']
]
logits = bart_output['logits'][:,-1,:].squeeze() # get the logits of the last token
logp = F.log_softmax(logits, dim=0)
lp_vec = logp.cpu().data.numpy() + lp
if i >= self.config['min_generation_length']:
new_lp = lp_vec[eos_token] + self.config['length_penalty'] * (len(out) + 1)
comp_hyplist.append((out, new_lp))
if best_state is None or best_state < new_lp:
best_state = new_lp
count = 1
for o in np.argsort(lp_vec)[::-1]: # reverse the order
if o in [eos_token, unk_token]:
continue
new_lp = lp_vec[o]
if len(new_hyplist) == self.config['beam_depth']:
if new_hyplist[argmin][1] < new_lp:
new_st = deepcopy(st)
new_st.append(int(o))
new_hyplist[argmin] = (out + [o], new_lp, new_st)
argmin = min(enumerate(new_hyplist), key=lambda h: h[1][1])[0]
else:
break
else:
new_st = deepcopy(st)
new_st.append(int(o))
new_hyplist.append((out + [o], new_lp, new_st))
if len(new_hyplist) == self.config['beam_depth']:
argmin = min(enumerate(new_hyplist), key=lambda h: h[1][1])[0]
count += 1
hyplist = new_hyplist
if len(comp_hyplist) > 0:
maxhyps = sorted(comp_hyplist, key=lambda h: -h[1])[:1]
return maxhyps[0][0]
else:
return []

300
runners/runner_nextqa.py Normal file
View file

@@ -0,0 +1,300 @@
import time
import os
import glog as log
import numpy as np
import json
import torch
import torch.nn.functional as F
from runners.runner import Runner
from copy import deepcopy
from optim_utils import init_optim
from transformers.models.bart.configuration_bart import BartConfig
from models.nextqa_bart import AVSDBart
from time import time
def tokenize(obj, tokenizer):
if isinstance(obj, str):
return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
if isinstance(obj, dict):
return dict((n, tokenize(o, tokenizer)) for n, o in obj.items())
return list(tokenize(o, tokenizer) for o in obj)
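# Example: tokenize('what is the man doing', tokenizer) returns a flat list of token
# ids, and nested containers are handled recursively, e.g.
# tokenize({'question': 'what is the man doing'}, tokenizer) -> {'question': [ids, ...]}.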
class NEXTQARunner(Runner):
def __init__(self, config, tokenizer, vocab_size):
super(NEXTQARunner, self).__init__(config)
bart_config = BartConfig.from_json_file(self.config['bart_config'])
self.model = AVSDBart.from_pretrained(
'facebook/bart-{}'.format(self.config['bart_size']), config=bart_config)
# Resize the embedding to match the vocab with additional special toks
# This takes care of resizing weights of related parts of the network
if vocab_size != bart_config.vocab_size:
self.model.resize_token_embeddings(vocab_size)
self.model.to(self.config['device'])
if not self.config['generating']:
self.optimizer, self.scheduler = init_optim(self.model, self.config)
self.tokenizer = tokenizer
def forward(self, batch):
for key in batch:
if isinstance(batch[key], torch.Tensor):
batch[key] = batch[key].cuda()
########################################################
input_ids = batch['input_ids']
video_place_holder_ids = batch['video_place_holder_ids']
app_feats = batch['app_feats']
mot_feats = batch['mot_feats']
lm_labels = batch['lm_labels']
input_mask = batch['input_mask']
app_interval = batch['app_interval']
mot_interval = batch['mot_interval']
question_intervals = batch['question_intervals']
vis_state_vector_idx = batch['vis_state_vector_idx']
question_state_vector_idx = batch['question_state_vector_idx']
########################################################
bart_output = self.model(
input_ids=input_ids,
video_place_holder_ids=video_place_holder_ids,
i3d_rgb=app_feats,
i3d_flow=mot_feats,
attention_mask=input_mask,
labels=lm_labels,
i3d_rgb_interval=app_interval,
i3d_flow_interval=mot_interval,
question_intervals=question_intervals,
vis_state_vector_idx=vis_state_vector_idx,
question_state_vector_idx=question_state_vector_idx,
output_attentions=True,
return_dict=True
)
output = {}
if self.config['print_output']:
logits = bart_output['logits']
probs = F.softmax(logits, dim=-1)
preds = torch.topk(probs, 1)[1].squeeze(-1)
preds = preds.tolist()
lm_labels_list = lm_labels[:, 1:].tolist()
lm_labels_list = [[s for s in label if s != -1] for label in lm_labels_list]
responses = ''
labels = ''
for pred, label in zip(preds, lm_labels_list):
responses += self.tokenizer.decode(pred) + '\n'
labels += self.tokenizer.decode(label) + '\n'
output['responses'] = responses
output['gt'] = labels
gen_key = 'gen_loss (x{})'.format(self.config['gen_coeff'])
gen_loss = bart_output['gen_loss']
gen_loss = self.config['gen_coeff'] * gen_loss
elbo_global_key = 'elbo_loss_global (x{})'.format(self.config['elbo_global_coeff'])
if bart_output['elbo_loss_global'] is not None:
elbo_global_loss = bart_output['elbo_loss_global']
elbo_global_loss = self.config['elbo_global_coeff'] * elbo_global_loss
else:
elbo_global_loss = torch.tensor(0.0)
elbo_local_key = 'elbo_loss_local (x{})'.format(self.config['elbo_local_coeff'])
if bart_output['elbo_loss_local'] is not None:
elbo_local_loss = bart_output['elbo_loss_local']
elbo_local_loss = self.config['elbo_local_coeff'] * elbo_local_loss
else:
elbo_local_loss = torch.tensor(0.0)
total_loss = gen_loss + elbo_global_loss + elbo_local_loss
output['losses'] = {
gen_key: gen_loss,
elbo_local_key: elbo_local_loss,
elbo_global_key: elbo_global_loss,
'tot_loss': total_loss
}
del bart_output
return output
def generate(self, dataset, app_feats, mot_feats, tag, tokenizer, start_idx_gen, end_idx_gen, gen_subset_num=None):
self.model.eval()
results = {}
app_sep, mot_sep, ph_token = tokenizer.convert_tokens_to_ids(
['<s0>', '<s1>', '<place_holder>'])
# Generate the response for each question
log.info('[INFO] Generating responses for {} samples'.format(len(dataset)))
with torch.no_grad():
counter = 0
for idx in range(start_idx_gen, end_idx_gen):
start_time = time()
cur_sample = dataset.loc[idx]
video_name, ques, ans, qid = str(cur_sample['video']), str(cur_sample['question']),\
str(cur_sample['answer']), str(cur_sample['qid'])
if video_name not in results:
results[video_name] = {}
input_ids = tokenize(ques, tokenizer)
app_feat = app_feats[video_name]
app_feat = torch.from_numpy(app_feat).type(torch.float32)
mot_feat = mot_feats[video_name]
mot_feat = torch.from_numpy(mot_feat).type(torch.float32)
bos, eos, ques_state = self.tokenizer.convert_tokens_to_ids(['<s>', '</s>', '<s2>'])
# Add state tokens
input_ids.insert(0, ques_state)
input_ids = torch.Tensor(input_ids).long()
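# 16 placeholder tokens per visual stream (appearance and motion); this assumes the
# pre-extracted clip-level features in app_mot_*.h5 also have length 16.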
dummy = torch.ones((1, 16)) * ph_token
video_place_holder_ids = torch.cat(
[torch.ones((1, 1)) * app_sep, dummy,
torch.ones((1, 1)) * mot_sep, dummy,
], dim=-1).long()
# Now we get the intervals of the visual input tokens
# Here the intervals do not change across the batch dimension
app_interval = [0, 16 + 1] # the last token is not part of this modality
mot_interval = [16 + 1, 2 * 16 + 2]
vis_state_vector_idx = [app_interval[0], mot_interval[0]]
response = self.beam_search_generation(
input_ids,
app_feat, mot_feat,
app_interval, mot_interval,
vis_state_vector_idx, video_place_holder_ids, tokenizer)
# Decode the response
response = self.tokenizer.decode(response)
results[video_name][qid] = response
time_elapsed = int(time() - start_time)
print('Generating response {} / {} -- took {}s'.format(counter + 1, len(dataset), time_elapsed))
counter += 1
# Create a file with all responses
file_name = 'results_nextqa_beam_depth_{}'.format(self.config['beam_depth'])
if gen_subset_num is not None:
file_name += f'-part_{gen_subset_num}'
file_name = f'{tag}_' + file_name
output_path = os.path.join(self.config['output_dir_nextqa'], file_name + '.json')
with open(output_path, 'w') as f:
json.dump(results, f, indent=4)
log.info('Results logged to {}'.format(output_path))
print(os.getcwd())
# Switch back to training mode
self.model.train()
def beam_search_generation(
self, input_ids,
app_feat, mot_feat,
app_interval, mot_interval,
vis_state_vector_idx, video_place_holder_ids, tokenizer):
eos_token = tokenizer.eos_token_id
unk_token = tokenizer.unk_token_id
question_sep = tokenizer.convert_tokens_to_ids('<s2>')
gen_ans = [eos_token]
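# Same beam bookkeeping as in AVSDRunner.beam_search_generation: each hypothesis is
# (generated token ids, cumulative log-prob, decoder input ids).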
hyplist = [([], 0.0, [eos_token])]
best_state = None
comp_hyplist = []
app_feat = app_feat.unsqueeze(0).cuda()
mot_feat = mot_feat.unsqueeze(0).cuda()
video_place_holder_ids = video_place_holder_ids.cuda()
text_shift_len = video_place_holder_ids.size(-1)
question_intervals = [[0 + text_shift_len, input_ids.size(0) + text_shift_len]] # the interval covers the whole question, which starts with its state token <s2>
question_state_vector_idx = [x[0] for x in question_intervals]
input_ids = input_ids.long().cuda().unsqueeze(0)
encoder_outputs = None
for i in range(self.config['max_generation_length']):
new_hyplist = []
argmin = 0
for out, lp, st in hyplist:
decoder_input_ids = torch.tensor(st).long().cuda().unsqueeze(0)
bart_output = self.model(
input_ids=input_ids,
video_place_holder_ids=video_place_holder_ids,
i3d_rgb=app_feat,
i3d_flow=mot_feat,
encoder_outputs=encoder_outputs,
decoder_input_ids=decoder_input_ids,
i3d_rgb_interval=app_interval,
i3d_flow_interval=mot_interval,
question_intervals=question_intervals,
vis_state_vector_idx=vis_state_vector_idx,
question_state_vector_idx=question_state_vector_idx,
output_attentions=True,
generate=True,
return_dict=True
)
if encoder_outputs is None:
encoder_outputs = [
bart_output['encoder_last_hidden_state'],
bart_output['encoder_hidden_states'],
bart_output['encoder_attentions'],
bart_output['encoder_QAs_local'],
bart_output['encoder_PAs_local'],
bart_output['encoder_QA_global'],
bart_output['encoder_PA_global'],
bart_output['encoder_state_vectors']
]
logits = bart_output['logits'][:,-1,:].squeeze() # get the logits of the last token
logp = F.log_softmax(logits, dim=0)
lp_vec = logp.cpu().data.numpy() + lp
if i >= self.config['min_generation_length']:
new_lp = lp_vec[eos_token] + self.config['length_penalty'] * (len(out) + 1)
comp_hyplist.append((out, new_lp))
if best_state is None or best_state < new_lp:
best_state = new_lp
count = 1
for o in np.argsort(lp_vec)[::-1]: # reverse the order
if o in [eos_token, unk_token]:
continue
new_lp = lp_vec[o]
if len(new_hyplist) == self.config['beam_depth']:
if new_hyplist[argmin][1] < new_lp:
new_st = deepcopy(st)
new_st.append(int(o))
new_hyplist[argmin] = (out + [o], new_lp, new_st)
argmin = min(enumerate(new_hyplist), key=lambda h: h[1][1])[0]
else:
break
else:
new_st = deepcopy(st)
new_st.append(int(o))
new_hyplist.append((out + [o], new_lp, new_st))
if len(new_hyplist) == self.config['beam_depth']:
argmin = min(enumerate(new_hyplist), key=lambda h: h[1][1])[0]
count += 1
hyplist = new_hyplist
if len(comp_hyplist) > 0:
maxhyps = sorted(comp_hyplist, key=lambda h: -h[1])[:1]
return maxhyps[0][0]
else:
return []