Make code public

This commit is contained in:
Adnen Abdessaied 2024-07-08 11:41:28 +02:00
commit 8e03ef1c38
49 changed files with 545354 additions and 0 deletions

3
.gitattributes vendored Normal file
View file

@ -0,0 +1,3 @@
*.pkl filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.csv filter=lfs diff=lfs merge=lfs -text

21
LICENSE Normal file
View file

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 Anonymous
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

174
README.md Normal file
View file

@ -0,0 +1,174 @@
<div align="center">
<h1> MST-MIXER <img src="misc/mixer.png" width="3%" align="bottom">: Multi-Modal Video Dialog State Tracking in the Wild </h1>
**[Adnen Abdessaied][16], &nbsp; [Lei Shi][17], &nbsp; [Andreas Bulling][18]** <br> <br>
**ECCV 2024, Milan, Italy <img src="misc/italy.png" width="3%" align="center">** <br>
**[[Paper][19]]**
---------------------------
<img src="misc/teaser.png" width="70%" align="middle"><br><br>
</div>
# Citation
If you find our code useful or use it in your own projects, please cite our paper:
```bibtex
@InProceedings{Abdessaied_2024_eccv,
author = {Abdessaied, Adnen and Shi, Lei and Bulling, Andreas},
title = {{Multi-Modal Video Dialog State Tracking in the Wild}},
booktitle = {Proceedings of the European Conference on Computer Vision (ECCV)},
year = {2024}
}
```
# Table of Contents
* [Setup and Dependencies](#Setup-and-Dependencies)
* [Download Data](#Download-Data)
* [Training](#Training)
* [Response Generation](#Response-Generation)
* [Results](#Results)
* [Acknowledgements](#Acknowledgements)
# Setup and Dependencies
We implemented our model using Python 3.7 and PyTorch 1.12.0 (CUDA 11.3, CuDNN 8.3.2). We recommend setting up a virtual environment using Anaconda. <br>
1. Install [git lfs][1] on your system
2. Clone our repository to download a checkpoint of our best model and our code
```shell
git lfs install
git clone this_repo.git
```
3. Create a conda environment and install dependencies
```shell
conda create -n mst_mixer python=3.7
conda activate mst_mixer
conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.3 -c pytorch
conda install pyg -c pyg
conda install pytorch-scatter -c pyg # pytorch >= 1.8.0
conda install pytorch-sparse -c pyg # pytorch >= 1.8.0
conda install -c huggingface transformers
pip install evaluate wandb glog pyhocon attrs
```
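To verify that the core dependencies resolve correctly, you can run a quick check (a minimal sketch assuming the versions above):
```python
# Environment sanity check (expected: PyTorch 1.12.0 built against CUDA 11.3)
import torch
import torch_geometric
import transformers

print(torch.__version__, torch.version.cuda)
print('CUDA available:', torch.cuda.is_available())
print(torch_geometric.__version__, transformers.__version__)
```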
# Download Data
## AVSD
1. Download the [AVSD-DSTC7][2], [AVSD-DSTC8][3] and [AVSD-DSTC10][10] data
2. Place the raw json files in ```raw_data/``` and the features in ```features/```
3. Preprocess and save the input features for faster training as indicated in ```custom_datasets/``` (see the sketch below)
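Each preprocessed item bundles one tokenized dialog turn with its pre-extracted visual and audio features. A minimal sketch to inspect a single item (paths assume the default ```features/``` root and ```n_history = 3```):
```python
# Inspect one preprocessed AVSD training item (assumed default preprocessing paths)
import pickle

with open('features/hist_with_3_rounds/train/1.pkl', 'rb') as f:
    item = pickle.load(f)

# expected keys: input_ids, lm_labels, i3d_rgb, i3d_flow, sam, vgg, vid
print(sorted(item.keys()))
print(item['input_ids'].shape, item['i3d_rgb'].shape, item['sam'].shape)
```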
## NExT-QA
1. For convenience, we included the features/data in this git repo.
# Training
We trained our model on 8 Nvidia Tesla V100-32GB GPUs. The default hyperparameters in ```config/mst_mixer.conf``` need to be adjusted if your setup differs from ours.
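The config file uses the pyhocon format, so you can also inspect or override hyperparameters programmatically before launching a run (a sketch; adapt the values to your own hardware):
```python
# Load the training configuration (pyhocon) and adapt it to a smaller setup
import pyhocon

config = pyhocon.ConfigFactory.parse_file('config/mst_mixer.conf')['mixer']
print(config['task'], config['batch_size'], config['learning_rate_bart'])

# e.g. reduce the per-GPU batch size when training on fewer or smaller GPUs
config['batch_size'] = 8
```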
## AVSD
1. Set ```task=avsd``` in ```config/mst_mixer.conf```
2. ```shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
--mode train \
--tag mst_mixer_avsd \
--wandb_mode online \
--wandb_project mst_mixer_avsd
```
To deactivate [wandb][4] logging, use ```--wandb_mode disabled```.
On a similar setup to ours, this will take roughly 20h to complete.
## NExT-QA
1. Set ```task=nextqa``` in ```config/mst_mixer.conf```
2. ```shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
--mode train \
--tag mst_mixer_nextqa \
--wandb_mode online \
--wandb_project mst_mixer_nextqa
```
# Response Generation
## AVSD-DSTC7
1. Set ```dstc=7``` in the ```.conf``` file of your trained network. In the default setting, you can find this file under ```logs/unique_training_tag/code/config/mst_mixer.conf```
2. Generate the responses
```shell
./generate_parallel_avsd.sh mst_mixer/mixer results_avsd_dstc7 generate logs/mst_mixer_avsd 7
```
3. All responses will be saved in ```output/dstc7/```
## AVSD-DSTC8
1. Set ```dstc=8``` in the ```.conf``` file of your trained network. In the default setting, you can find this file under ```logs/unique_training_tag/code/config/mst_mixer.conf```
2. Generate the responses
```shell
./generate_parallel_avsd.sh mst_mixer/mixer results_avsd_dstc8 generate logs/mst_mixer_avsd 8
```
3. All responses will be saved in ```output/dstc8/```
## AVSD-DSTC10
1. Set ```dstc=10``` in the ```.conf``` file of your trained network. In the default setting, you can find this file under ```logs/unique_training_tag/code/config/mst_mixer.conf```
2. Generate the responses
```shell
./generate_parallel_avsd.sh mst_mixer/mixer results_avsd_dstc10 generate logs/mst_mixer_avsd 10
```
3. All responses will be saved in ```output/dstc10/```
## NExT-QA
1. Generate the responses
```shell
./generate_parallel_nextqa.sh mst_mixer/mixer results_nextqa generate logs/mst_mixer_nextqa
```
2. All responses will be saved in ```output/nextqa/```
3. Evaluate using this [script][15]
# Results
To evaluate our best model on the different benchmarks:
## AVSD-DSTC7
Executing the [eval_tool][7] of AVSD-DSTC7 using the generated responses will output the following metrics
| Model | BLEU-1 | BLEU-2 | BLEU-3 | BLEU-4 | METEOR | ROUGE-L | CIDEr |
|:--------:|:------:|:------:|:------:|:------:|:------:|:-------:|:-----:|
| Prev. SOTA | 78.2 | 65.5 | 55.2 | 46.9 | 30.8 | 61.9 | 135.2 |
| MST_MIXER | **78.7** | **66.5** | **56.3** | **47.6** | **31.3** | **62.5** | **138.8**|
## AVSD-DSTC8
1. Set ```dstc=8``` in ```ckpt/code/mst_mixer.conf```
2. Run
```shell
./generate_parallel_avsd.sh mst_mixer/mixer results_avsd_dstc8_best_model generate ckpt/avsd 8
```
3. The responses will be saved in ```output/dstc8/```
4. Executing the [eval_tool][7] of AVSD-DSTC8 using the generated responses will output the following metrics
| Model | BLEU-1 | BLEU-2 | BLEU-3 | BLEU-4 | METEOR | ROUGE-L | CIDEr |
|:--------:|:------:|:------:|:------:|:------:|:------:|:-------:|:-----:|
| Prev. SOTA | 76.4 | 64.1 | 54.3 | 46.0 | 30.1 | 61.0 | 130.4 |
| MST_MIXER | **77.5** | **66.0** | **56.1** | **47.7** | **30.6** | **62.4** | **135.4**|
## AVSD-DSTC10
Executing the [eval_tool][11] of AVSD-DSTC10 using the generated responses will output the following metrics
| Model | BLEU-1 | BLEU-2 | BLEU-3 | BLEU-4 | METEOR | ROUGE-L | CIDEr |
|:--------:|:------:|:------:|:------:|:------:|:------:|:-------:|:-----:|
| Prev. SOTA | 69.3 | 55.6 | 45.0 | 37.2 | 24.9 | 53.6 | 91.2 |
| MST_MIXER | **70.0** | **57.4** | **47.6** | **40.0** | **25.7** | **54.5** | **99.8**|
## NExT-QA
Executing the [eval script][15] of NExT-QA using the generated responses will output the following metrics
| Model | WUPS_C | WUPS_T | WUPS_D | WUPS |
|:--------:|:------:|:------:|:------:|:------:|
| Prev. SOTA | 17.98| 17.95 | 50.84 | 28.40 |
| MST_MIXER | **22.12** | **22.20** | **55.64** | **29.50** |
# Acknowledgements
We thank the authors of [RLM][8] for providing their [code][9] that greatly influenced this work.
[1]: https://git-lfs.com/
[2]: https://github.com/hudaAlamri/DSTC7-Audio-Visual-Scene-Aware-Dialog-AVSD-Challenge
[3]: https://github.com/dialogtekgeek/DSTC8-AVSD_official
[4]: https://wandb.ai/site
[5]: https://drive.google.com/drive/folders/1SlZTySJAk_2tiMG5F8ivxCfOl_OWwd_Q
[7]: https://drive.google.com/file/d/1EKfPtrNBQ5ciKRl6XggImweGRP84XuPi/view?usp=sharing
[8]: https://arxiv.org/abs/2002.00163
[9]: https://github.com/ictnlp/DSTC8-AVSD
[10]: https://drive.google.com/file/d/1zvC6FuPRVRiLQCXZcYpzYUI9r1tiWls6/view
[11]: https://github.com/ankitshah009/AVSD-DSTC10_baseline
[15]: https://github.com/doc-doc/NExT-OE/blob/main/eval_oe.py
[16]: https://adnenabdessaied.de/
[17]: https://perceptualui.org/people/shi/
[18]: https://perceptualui.org/people/bulling/
[19]: https://arxiv.org/abs/2407.02218

118
config/avsd_bart_base.json Normal file
View file

@ -0,0 +1,118 @@
{
"_name_or_path": "bart-base",
"activation_dropout": 0.1,
"activation_function": "gelu",
"add_bias_logits": false,
"add_final_layer_norm": false,
"architectures": [
"BartModel"
],
"attention_dropout": 0.1,
"bos_token_id": 0,
"classif_dropout": 0.1,
"classifier_dropout": 0.0,
"d_model": 768,
"decoder_attention_heads": 12,
"decoder_ffn_dim": 3072,
"decoder_layerdrop": 0.0,
"decoder_layers": 6,
"decoder_start_token_id": 2,
"dropout": 0.1,
"early_stopping": true,
"encoder_attention_heads": 12,
"encoder_ffn_dim": 3072,
"encoder_layerdrop": 0.0,
"encoder_layers": 6,
"eos_token_id": 2,
"forced_eos_token_id": 2,
"forced_bos_token_id": 0,
"gradient_checkpointing": false,
"id2label": {
"0": "LABEL_0",
"1": "LABEL_1",
"2": "LABEL_2"
},
"init_std": 0.02,
"is_encoder_decoder": true,
"label2id": {
"LABEL_0": 0,
"LABEL_1": 1,
"LABEL_2": 2
},
"max_position_embeddings": 1024,
"model_type": "bart",
"no_repeat_ngram_size": 3,
"normalize_before": false,
"normalize_embedding": true,
"num_beams": 4,
"num_hidden_layers": 6,
"pad_token_id": 1,
"scale_embedding": false,
"task_specific_params": {
"summarization": {
"length_penalty": 1.0,
"max_length": 128,
"min_length": 12,
"num_beams": 4
},
"summarization_cnn": {
"length_penalty": 2.0,
"max_length": 142,
"min_length": 56,
"num_beams": 4
},
"summarization_xsum": {
"length_penalty": 1.0,
"max_length": 62,
"min_length": 11,
"num_beams": 6
}
},
"torch_dtype": "float32",
"transformers_version": "4.12.0.dev0",
"use_cache": true,
"vocab_size": 50265,
"d_i3d_flow": 2048,
"d_i3d_rgb": 2048,
"d_sam": 512,
"d_audio": 128,
"top_k": 10,
"num_nn": 4,
"gnn_type": "appnp",
"use_random_graphs": false,
"integrate_all_gnn_features": true,
"use_elbo_local": true,
"use_elbo_global": true,
"use_non_linear": true,
"gnn_K": 2,
"gnn_alpha": 0.1,
"num_modalities": 6,
"local_gnn_d_hidden": 768,
"global_gnn_d_hidden": 768,
"num_local_gnn_heads": 2,
"num_global_gnn_heads": 4,
"local_gnn_dropout": 0.1,
"global_gnn_dropout": 0.1,
"local_fc_dropout": 0.1,
"global_fc_dropout": 0.1,
"num_local_gnn_layers": 2,
"num_global_gnn_layers": 2,
"num_local_fc_layers": 2,
"num_global_fc_layers": 2,
"use_local_gnn_bn": true,
"use_global_gnn_bn": true,
"use_local_fc_bn": true,
"use_global_fc_bn": true,
"local_gnn_concat": true,
"global_gnn_concat": true,
"num_local_gr_learner_heads": 8,
"num_global_gr_learner_heads":8,
"init_adj_ratio": 0.5,
"adj_ratio": 0.5,
"alpha": 0.9,
"gnns_every": 2,
"num_layers_state_fc_decoder": 2,
"dropout_state_fc_decoder": 0.3
}

116
config/avsd_bart_large.json Normal file
View file

@ -0,0 +1,116 @@
{
"activation_dropout": 0.1,
"activation_function": "gelu",
"add_bias_logits": false,
"add_final_layer_norm": false,
"architectures": [
"BartModel"
],
"attention_dropout": 0.1,
"bos_token_id": 0,
"classif_dropout": 0.1,
"classifier_dropout": 0.0,
"d_model": 1024,
"decoder_attention_heads": 16,
"decoder_ffn_dim": 4096,
"decoder_layerdrop": 0.0,
"decoder_layers": 12,
"decoder_start_token_id": 2,
"dropout": 0.1,
"early_stopping": true,
"encoder_attention_heads": 16,
"encoder_ffn_dim": 4096,
"encoder_layerdrop": 0.0,
"encoder_layers": 12,
"eos_token_id": 2,
"forced_eos_token_id": 2,
"forced_bos_token_id": 0,
"gradient_checkpointing": false,
"id2label": {
"0": "LABEL_0",
"1": "LABEL_1",
"2": "LABEL_2"
},
"init_std": 0.02,
"is_encoder_decoder": true,
"label2id": {
"LABEL_0": 0,
"LABEL_1": 1,
"LABEL_2": 2
},
"max_position_embeddings": 1024,
"model_type": "bart",
"no_repeat_ngram_size": 3,
"normalize_before": false,
"num_beams": 4,
"num_hidden_layers": 12,
"pad_token_id": 1,
"scale_embedding": false,
"task_specific_params": {
"summarization": {
"length_penalty": 1.0,
"max_length": 128,
"min_length": 12,
"num_beams": 4
},
"summarization_cnn": {
"length_penalty": 2.0,
"max_length": 142,
"min_length": 56,
"num_beams": 4
},
"summarization_xsum": {
"length_penalty": 1.0,
"max_length": 62,
"min_length": 11,
"num_beams": 6
}
},
"transformers_version": "4.7.0.dev0",
"use_cache": true,
"vocab_size": 50265,
"d_i3d_flow": 2048,
"d_i3d_rgb": 2048,
"d_sam": 512,
"d_audio": 128,
"top_k": 10,
"num_nn": 4,
"gnn_type": "appnp",
"use_random_graphs": false,
"integrate_all_gnn_features": true,
"use_elbo_local": true,
"use_elbo_global": true,
"use_non_linear": true,
"gnn_K": 2,
"gnn_alpha": 0.1,
"num_modalities": 6,
"local_gnn_d_hidden": 1024,
"global_gnn_d_hidden": 1024,
"num_local_gnn_heads": 2,
"num_global_gnn_heads": 4,
"local_gnn_dropout": 0.1,
"global_gnn_dropout": 0.1,
"local_fc_dropout": 0.1,
"global_fc_dropout": 0.1,
"num_local_gnn_layers": 1,
"num_global_gnn_layers": 1,
"num_local_fc_layers": 1,
"num_global_fc_layers": 1,
"use_local_gnn_bn": true,
"use_global_gnn_bn": true,
"use_local_fc_bn": true,
"use_global_fc_bn": true,
"local_gnn_concat": true,
"global_gnn_concat": true,
"num_local_gr_learner_heads": 8,
"num_global_gr_learner_heads": 8,
"init_adj_ratio": 0.5,
"adj_ratio": 0.5,
"alpha": 0.9,
"gnns_every": 4,
"num_layers_state_fc_decoder": 2,
"dropout_state_fc_decoder": 0.3
}

94
config/mst_mixer.conf Normal file
View file

@ -0,0 +1,94 @@
mixer {
task = avsd
#################################################################################
# datasets
# avsd
avsd_processed = features/
avsd_train = raw_data/train_set4DSTC7-AVSD.json
avsd_val = raw_data/valid_set4DSTC7-AVSD.json
avsd_test_dstc7 = raw_data/test_set4DSTC7-AVSD.json
avsd_test_dstc8 = raw_data/test_set4DSTC8-AVSD.json
avsd_test_dstc10 = raw_data/test_set4DSTC10-AVSD.json
avsd_feature_path = features/
avsd_i3d_rgb = features/i3d_rgb
avsd_i3d_rgb_test = features/i3d_rgb_testset
avsd_i3d_flow = features/i3d_flow_all
avsd_i3d_flow_test = features/i3d_flow_testset
avsd_audio = features/vggish_all
avsd_audio_test = features/vggish_testset
avsd_objects = features/sam
avsd_objects_test = features/sam_testset
dstc = 7
# NextQA
nextqa_root = processed/next_qa/annotations
nextqa_vid_feat = processed/next_qa/vid_feat
#################################################################################
# Model
bart_size = large # base, large
avsd_bart_base_config = config/avsd_bart_base.json
avsd_bart_large_config = config/avsd_bart_large.json
nextqa_bart_large_config = config/nextqa_bart_large.json
#################################################################################
# Logging & Checkpointing
log_dir = logs
output_dir_dstc7 = output/dstc7
output_dir_dstc8 = output/dstc8
output_dir_dstc10 = output/dstc10
output_dir_nextqa = output/nextqa
max_ckpt_to_keep = 5
start_ckpt_for_generating = none
loads_start_path = false
next_logging_pct = 0.1
save_ckpt=true
skip_saving_ckpt = false
stop_epochs = -1
resets_min_val_loss = false
restarts = false
uses_new_optimizer = true
sets_new_lr = false
################################################################################
# Data processing
expand_rnd = false
cap_sum = cap_sum
add_state_tokens = true
bart_max_input_len = 1024
num_workers = 0
n_history = 3
caption_drop_rate = 0.0
vis_feat_length = 36
#################################################################################
# Training
dp_type = ddp
batch_size = 16
num_epochs = 12
warmup_ratio = 0.1
batch_multiply = 1
skip_eval = false
stop_epoch = -1
random_seed = 54
learning_rate_bart = 1e-5
learning_rate_other = 1e-4
min_lr = 0
clip_grad_value = 1.0
print_output = false
eval_first = false
overfit_size = -1
elbo_global_coeff = 100
elbo_local_coeff = 100
gen_coeff = 1
#################################################################################
# Generation
gen_batch_size = 1
beam_depth = 5
max_generation_length = 20
min_generation_length = 1
length_penalty = 0.3
#################################################################################
# Misc.
master_port = 5101
use_cpu = false
}

116
config/nextqa_bart_large.json Normal file
View file

@ -0,0 +1,116 @@
{
"activation_dropout": 0.1,
"activation_function": "gelu",
"add_bias_logits": false,
"add_final_layer_norm": false,
"architectures": [
"BartModel"
],
"attention_dropout": 0.1,
"bos_token_id": 0,
"classif_dropout": 0.1,
"classifier_dropout": 0.0,
"d_model": 1024,
"decoder_attention_heads": 16,
"decoder_ffn_dim": 4096,
"decoder_layerdrop": 0.0,
"decoder_layers": 12,
"decoder_start_token_id": 2,
"dropout": 0.1,
"early_stopping": true,
"encoder_attention_heads": 16,
"encoder_ffn_dim": 4096,
"encoder_layerdrop": 0.0,
"encoder_layers": 12,
"eos_token_id": 2,
"forced_eos_token_id": 2,
"forced_bos_token_id": 0,
"gradient_checkpointing": false,
"id2label": {
"0": "LABEL_0",
"1": "LABEL_1",
"2": "LABEL_2"
},
"init_std": 0.02,
"is_encoder_decoder": true,
"label2id": {
"LABEL_0": 0,
"LABEL_1": 1,
"LABEL_2": 2
},
"max_position_embeddings": 1024,
"model_type": "bart",
"no_repeat_ngram_size": 3,
"normalize_before": false,
"num_beams": 4,
"num_hidden_layers": 12,
"pad_token_id": 1,
"scale_embedding": false,
"task_specific_params": {
"summarization": {
"length_penalty": 1.0,
"max_length": 128,
"min_length": 12,
"num_beams": 4
},
"summarization_cnn": {
"length_penalty": 2.0,
"max_length": 142,
"min_length": 56,
"num_beams": 4
},
"summarization_xsum": {
"length_penalty": 1.0,
"max_length": 62,
"min_length": 11,
"num_beams": 6
}
},
"transformers_version": "4.7.0.dev0",
"use_cache": true,
"vocab_size": 50265,
"d_i3d_flow": 2048,
"d_i3d_rgb": 2048,
"d_sam": 512,
"d_audio": 128,
"top_k": 10,
"num_nn": 4,
"gnn_type": "appnp",
"use_random_graphs": false,
"integrate_all_gnn_features": true,
"use_elbo_local": true,
"use_elbo_global": true,
"use_non_linear": true,
"gnn_K": 2,
"gnn_alpha": 0.1,
"num_modalities": 3,
"local_gnn_d_hidden": 1024,
"global_gnn_d_hidden": 1024,
"num_local_gnn_heads": 2,
"num_global_gnn_heads": 4,
"local_gnn_dropout": 0.1,
"global_gnn_dropout": 0.1,
"local_fc_dropout": 0.1,
"global_fc_dropout": 0.1,
"num_local_gnn_layers": 1,
"num_global_gnn_layers": 1,
"num_local_fc_layers": 1,
"num_global_fc_layers": 1,
"use_local_gnn_bn": true,
"use_global_gnn_bn": true,
"use_local_fc_bn": true,
"use_global_fc_bn": true,
"local_gnn_concat": true,
"global_gnn_concat": true,
"num_local_gr_learner_heads": 8,
"num_global_gr_learner_heads": 8,
"init_adj_ratio": 0.5,
"adj_ratio": 0.5,
"alpha": 0.9,
"gnns_every": 4,
"num_layers_state_fc_decoder": 2,
"dropout_state_fc_decoder": 0.3
}

20
custom_datasets/README.md Normal file
View file

@ -0,0 +1,20 @@
1. Download the raw [Charades train/val](https://prior.allenai.org/projects/charades) data
2. Download the raw [Charades test](https://ai2-public-datasets.s3-us-west-2.amazonaws.com/charades/Charades_vu17_test_480.tar) data
3. Install [SAM](https://github.com/facebookresearch/segment-anything.git)
4. Segment the frames
```shell
python segment.py --sam_ckpt path_to_sam_ckpt --avsd_root path_to_charades_trval_frames --crop_root path_to_save_the_trval_crops --mode segment --start start_idx --end end_idx
python segment.py --sam_ckpt path_to_sam_ckpt --avsd_root path_to_charades_test_frames --crop_root path_to_save_the_test_crops --mode segment --start start_idx --end end_idx
```
5. Embed the crops (see the sanity check after this list)
```shell
python segment.py --sam_ckpt path_to_sam_ckpt --crop_root path_to_save_the_trval_crops --mode embed --embed_root ../features/sam --start start_idx --end end_idx
python segment.py --sam_ckpt path_to_sam_ckpt --crop_root path_to_save_the_test_crops --mode embed --embed_root ../features/sam_testset --start start_idx --end end_idx
```
6. Preprocess and log the data
```shell
python dataset.py --split train
python dataset.py --split val
```
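After step 5, there should be one ```.npy``` file per video in which each row is the 512-d embedding of one object crop (the 256-d SAM embedding of the crop concatenated with that of its horizontal flip). A quick sanity check (sketch; the video id is a placeholder):
```python
# Check the shape of one saved SAM object-embedding file (output of step 5)
import numpy as np

emb = np.load('../features/sam/SOME_CHARADES_ID.npy')  # placeholder video id
print(emb.shape)  # (num_objects <= 36, 512)
```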


401
custom_datasets/avsd.py Normal file
View file

@ -0,0 +1,401 @@
import os
import pickle
import pyhocon
from copy import deepcopy
import json
from tqdm import tqdm
import numpy as np
import torch
from argparse import ArgumentParser
from torch.utils.data import Dataset, DataLoader
from transformers import BartTokenizer
from itertools import chain
ADDITIONAL_SPECIAL_TOKENS = [
'<place_holder>', '<s0>', '<s1>', '<s2>', '<s3>', '<s4>', '<s5>']
SPECIAL_TOKENS_DICT = {
'bos_token': '<s>',
'eos_token': '</s>',
'pad_token': '<pad>',
'additional_special_tokens': ADDITIONAL_SPECIAL_TOKENS
}
S0_TOK = '<s0>' # I3D_flow
S1_TOK = '<s1>' # I3D_rgb
S2_TOK = '<s2>' # sam obj
S3_TOK = '<s3>' # audio
S4_TOK = '<s4>' # history
S5_TOK = '<s5>' # question
def tokenize(obj, tokenizer):
if isinstance(obj, str):
return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
if isinstance(obj, dict):
return dict((n, tokenize(o, tokenizer)) for n, o in obj.items())
return list(tokenize(o, tokenizer) for o in obj)
class AVSDDataset(Dataset):
def __init__(self, config, split):
super().__init__()
self.config = config
self.split = split
self.bart_max_input_len = config['bart_max_input_len']
self.bart_size = config['bart_size']
self.cap_sum = config['cap_sum']
self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-{}'.format(self.bart_size))
self.vocab_size = self.tokenizer.vocab_size
self.tokenizer.add_special_tokens({'additional_special_tokens': ADDITIONAL_SPECIAL_TOKENS})
self.vocab_size += len(ADDITIONAL_SPECIAL_TOKENS)
self.tokenizer.save_pretrained(os.path.join(self.config['log_dir'], 'bart_tokenizer'))
self.processed_dir = os.path.join(self.config['avsd_processed'], 'hist_with_{}_rounds'.format(self.config['n_history']), split)
self.paths = list(map(lambda p: os.path.join(self.processed_dir, p), os.listdir(self.processed_dir)))
if self.config['overfit'] > 0:
self.paths = self.paths[:self.config['overfit_size']]
def __len__(self):
return len(self.paths)
def __getitem__(self, index):
pth = self.paths[index]
with open(pth, 'rb') as f:
item = pickle.load(f)
question_sep = self.tokenizer.convert_tokens_to_ids('<s5>')
input_ids = item['input_ids']
history_end = (input_ids == question_sep).nonzero(as_tuple=True)[0]
history_interval = [0, history_end.item()] # The last token is the question state token (not part of the history)
question_interval = [history_end.item(), input_ids.size(0)]
lm_labels = item['lm_labels']
i3d_rgb = item['i3d_rgb']
i3d_flow = item['i3d_flow']
sam = item['sam']
vgg = item['vgg']
vid = item['vid']
return input_ids, lm_labels, history_interval, question_interval, i3d_rgb, i3d_flow, sam, vgg, vid
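# Right-pads a list of 1-D token tensors (with pad_token) or 2-D feature tensors to a common length along the first dimension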
def padding(self, seq, pad_token, max_len=None):
if max_len is None:
max_len = max([i.size(0) for i in seq])
if len(seq[0].size()) == 1:
result = torch.ones((len(seq), max_len)).long() * pad_token
else:
result = torch.ones((len(seq), max_len, seq[0].size(-1))).float()
for i in range(len(seq)):
result[i, :seq[i].size(0)] = seq[i]
return result
def collate_fn(self, batch):
input_ids_list, lm_labels_list, history_interval_list, question_interval_list, i3d_rgb_list, i3d_flow_list, sam_list, vggish_list, vid_ids_list = [], [], [], [], [], [], [], [], []
for i in batch:
input_ids_list.append(i[0])
lm_labels_list.append(i[1])
history_interval_list.append(i[2])
question_interval_list.append(i[3])
i3d_rgb_list.append(i[4])
i3d_flow_list.append(i[5])
sam_list.append(i[6])
vggish_list.append(i[7])
vid_ids_list.append(i[8])
history_intervals = np.array(history_interval_list)
question_intervals = np.array(question_interval_list)
min_len_i3d_flow = min([feat.shape[0] for feat in i3d_flow_list])
min_len_i3d_rgb = min([feat.shape[0] for feat in i3d_rgb_list])
min_len_sam = min([feat.shape[0] for feat in sam_list])
min_len_vggish = min([feat.shape[0] for feat in vggish_list])
min_length = min([self.config['vis_feat_length'], min_len_i3d_flow, min_len_i3d_rgb, min_len_sam, min_len_vggish])
# Sample equally-distant features from the visual features for each sample within the batch
for i in range(len(i3d_rgb_list)):
sample_idx_i3d_rgb = np.round(np.linspace(0, i3d_rgb_list[i].shape[0] - 1, min_length)).astype(int)
i3d_rgb_list[i] = i3d_rgb_list[i][sample_idx_i3d_rgb, :]
i3d_rgb = torch.from_numpy(np.array(i3d_rgb_list)).float()
for i in range(len(i3d_flow_list)):
sample_idx_i3d_flow = np.round(np.linspace(0, i3d_flow_list[i].shape[0] - 1, min_length)).astype(int)
i3d_flow_list[i] = i3d_flow_list[i][sample_idx_i3d_flow, :]
i3d_flow = torch.from_numpy(np.array(i3d_flow_list)).float()
for i in range(len(sam_list)):
sample_idx_sam = np.round(np.linspace(0, sam_list[i].shape[0] - 1, min_length)).astype(int)
sam_list[i] = sam_list[i][sample_idx_sam, :]
sam = torch.from_numpy(np.array(sam_list)).float()
for i in range(len(vggish_list)):
sample_idx_vggish = np.round(np.linspace(0, vggish_list[i].shape[0] - 1, min_length)).astype(int)
vggish_list[i] = vggish_list[i][sample_idx_vggish, :]
vggish = torch.from_numpy(np.array(vggish_list)).float()
pad_token, i3d_flow_sep, i3d_rgb_sep, sam_sep, audio_sep, ph_token = self.tokenizer.convert_tokens_to_ids(
['<pad>', '<s0>', '<s1>', '<s2>', '<s3>', '<place_holder>'])
# None of the visual features are masked because we do not pad them
video_mask = torch.ones((len(batch), min_length*4 + 4)) == 1 # NOTE *4: 4 modalities | +4: the state tokens
# Now we create a dummy input for the video tokens (sole purpose is to reserve the spots of the separators)
dummy = torch.ones((len(batch), min_length)) * ph_token
video_place_holder_ids = torch.cat(
[torch.ones((len(batch), 1)) * i3d_rgb_sep, dummy,
torch.ones((len(batch), 1)) * i3d_flow_sep, dummy,
torch.ones((len(batch), 1)) * sam_sep, dummy,
torch.ones((len(batch), 1)) * audio_sep, dummy,
], dim=-1).long()
input_ids = self.padding(input_ids_list, pad_token)
lm_labels = self.padding(lm_labels_list, -100)
text_mask = input_ids != pad_token
input_mask = torch.cat([video_mask, text_mask], dim=1)
# Now we get the intervals of the visual input tokens
# Here the intervals do not change across the batch dimension
i3d_rgb_interval = [0, min_length + 1] # the last token is not part of this modality
i3d_flow_interval = [min_length + 1, 2 * min_length + 2]
sam_interval = [2 * min_length + 2, 3 * min_length + 3]
audio_interval = [3 * min_length + 3, 4 * min_length + 4]
vis_state_vector_idx = [i3d_rgb_interval[0], i3d_flow_interval[0], sam_interval[0], audio_interval[0]]
# adapt the question and history interval -- shifted to the right by the visual input length
history_intervals += 4 * min_length + 4
question_intervals += 4 * min_length + 4
history_intervals = history_intervals.tolist()
question_intervals = question_intervals.tolist()
history_state_vector_idx = [x[0] + 1 for x in history_intervals] # +1 because the history starts with <s><s4> .....
question_state_vector_idx = [x[0] for x in question_intervals] # no +1 needed: each question interval starts directly at its <s5> state token
batch = {
'input_ids': input_ids,
'video_place_holder_ids': video_place_holder_ids,
'i3d_rgb': i3d_rgb,
'i3d_flow': i3d_flow,
'sam': sam,
'vggish': vggish,
'lm_labels': lm_labels,
'input_mask': input_mask,
'i3d_rgb_interval': i3d_rgb_interval,
'i3d_flow_interval': i3d_flow_interval,
'sam_interval': sam_interval,
'audio_interval': audio_interval,
'history_intervals': history_intervals,
'question_intervals': question_intervals,
'vis_state_vector_idx': vis_state_vector_idx,
'history_state_vector_idx': history_state_vector_idx,
'question_state_vector_idx': question_state_vector_idx
}
return batch
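# Flattens the raw AVSD json into per-turn items (vid, history, answer, caption) and maps each video id to its feature file paths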
def get_dataset(config, split, tokenizer):
if split != 'test':
dialog_pth = config[f'avsd_{split}']
else:
dialog_pth = config['avsd_test_dstc{}'.format(config['dstc'])]
n_history = config['n_history']
dialog_data = json.load(open(dialog_pth, 'r'))
dialog_list = []
vid_set = set()
undisclosed_only = split == 'test'
pbar = tqdm(dialog_data['dialogs'])
pbar.set_description('[INFO] Generating {} items | DSTC {}'.format(split, config['dstc']))
for dialog in pbar:
if config['dstc'] != 10:
caption = [tokenize(dialog['caption'], tokenizer)] + [tokenize(dialog['summary'], tokenizer)]
else:
caption = [tokenize('no', tokenizer)]
questions = [tokenize(d['question'], tokenizer) for d in dialog['dialog']]
answers = [tokenize(d['answer'], tokenizer) for d in dialog['dialog']]
vid = dialog["image_id"]
vid_set.add(vid)
if undisclosed_only:
it = range(len(questions) - 1, len(questions))
else:
it = range(len(questions))
qalist=[]
history = []
if undisclosed_only:
for n in range(len(questions)-1):
qalist.append(questions[n])
qalist.append(answers[n])
history=qalist[max(-len(qalist),-n_history*2):]
for n in it:
if undisclosed_only:
assert dialog['dialog'][n]['answer'] == '__UNDISCLOSED__'
question = questions[n]
answer = answers[n]
history.append(question)
if n_history == 0:
item = {'vid': vid, 'history': [question], 'answer': answer, 'caption': caption}
else:
item = {'vid': vid, 'history': history, 'answer': answer, 'caption': caption}
dialog_list.append(item)
qalist.append(question)
qalist.append(answer)
history=qalist[max(-len(qalist),-n_history*2):]
all_features = {}
fea_types = ['vggish', 'i3d_flow', 'i3d_rgb', 'sam']
dataname = '<FeaType>/<ImageID>.npy'
for ftype in fea_types:
if undisclosed_only:
basename = dataname.replace('<FeaType>', ftype+'_testset')
else:
basename = dataname.replace('<FeaType>', ftype)
features = {}
for vid in vid_set:
filename = basename.replace('<ImageID>', vid)
filepath = config['avsd_feature_path'] + filename
features[vid] = filepath
all_features[ftype] = features
return dialog_list, all_features
def build_input_from_segments(caption, history_orig, reply, tokenizer, add_state_tokens=True, drop_caption=False):
""" Build a sequence of input from 3 segments: caption(caption+summary) history and last reply """
bos, eos, hist_state, ques_state = tokenizer.convert_tokens_to_ids(['<s>', '</s>', '<s4>', '<s5>'])
sep = eos
instance = {}
instance["lm_labels"] = reply + [eos]
caption = list(chain(*caption))
# Add state tokens if applicable
if add_state_tokens:
caption.insert(0, hist_state)
history = deepcopy(history_orig)
history[-1].insert(0, ques_state)
else:
history = history_orig
if not drop_caption:
# sequence = [[bos] + list(chain(*caption))] + history + [reply + ([eos] if with_eos else [])]
# NOTE It is important not to include the reply in the input of the encoder -- > the decoder will just
# learn to copy it --> low train/val loss but no learning is happening
sequence = [[bos] + caption + [eos]] + [[sep] + s for s in history] + [[eos]]
else:
sequence = [[bos]] + [[hist_state]] + [[sep] + s for s in history] + [[eos]]
instance["input_ids"] = list(chain(*sequence))
return instance
def parse_args():
parser = ArgumentParser(description='debug dataloader')
parser.add_argument(
'--split',
type=str,
default='train',
help='train or val')
parser.add_argument(
'--model',
type=str,
default='mixer',
help='model name to train or test')
parser.add_argument(
'--log_dataset',
action='store_true',
default=False,
help='Whether or not to log the processed data')
parser.add_argument(
'--add_state_tokens',
action='store_true',
default=True,
help='Whether or not to add state tokens')
parser.add_argument(
'--log_dir',
type=str,
default='processed/avsd',
help='Output directory')
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
split = args.split
config = pyhocon.ConfigFactory.parse_file(
'config/mst_mixer.conf')[args.model]
config['expand_rnd'] = False
config['debugging'] = False
config['overfit'] = False
args.log_dir = os.path.join(args.log_dir, 'hist_with_{}_rounds'.format(config['n_history']) )
if args.log_dataset:
log_dir = os.path.join(args.log_dir, split)
if not os.path.isdir(log_dir):
os.makedirs(log_dir)
tokenizer = BartTokenizer.from_pretrained('facebook/bart-{}'.format(config['bart_size']))
tokenizer.add_special_tokens({'additional_special_tokens': ADDITIONAL_SPECIAL_TOKENS})
dialogs, features = get_dataset(config, split, tokenizer)
pbar = tqdm(dialogs)
pbar.set_description('[{}] Logging processed data'.format(split))
counter = 0
for dialog in pbar:
vid = dialog['vid']
his = dialog['history']
cap = dialog['caption']
ans = dialog['answer']
if np.random.rand() < config['caption_drop_rate']:
instance = build_input_from_segments(
cap, his, ans, tokenizer, add_state_tokens=args.add_state_tokens, drop_caption=True)
else:
instance = build_input_from_segments(
cap, his, ans, tokenizer, add_state_tokens=args.add_state_tokens, drop_caption=False)
input_ids = torch.Tensor(instance["input_ids"]).long()
lm_labels = torch.Tensor(instance["lm_labels"]).long()
vgg = np.load(features["vggish"][vid])
i3d_flow = np.load(features["i3d_flow"][vid])
i3d_rgb = np.load(features["i3d_rgb"][vid])
sam = np.load(features["sam"][vid])
item = {
'input_ids': input_ids,
'lm_labels': lm_labels,
'i3d_rgb': i3d_rgb,
'i3d_flow': i3d_flow,
'sam': sam,
'vgg': vgg,
'vid': vid
}
counter += 1
pth = os.path.join(log_dir, str(counter) + '.pkl')
with open(pth, 'wb') as f:
pickle.dump(item, f, protocol=pickle.HIGHEST_PROTOCOL)
else:
avsd_dataset = AVSDDataset(config, 'val')
avsd_dataloader = DataLoader(avsd_dataset, batch_size=4, shuffle=False, collate_fn=avsd_dataset.collate_fn)
for i, data in enumerate(avsd_dataloader):
print('{}/{}'.format(i, len(avsd_dataloader)))
print('[INFO] Done...')

211
custom_datasets/nextqa.py Normal file
View file

@ -0,0 +1,211 @@
import os
import pandas as pd
import h5py
import json
import numpy as np
import torch
from torch.utils.data import Dataset
from transformers import BartTokenizer
from itertools import chain
ADDITIONAL_SPECIAL_TOKENS = [
'<place_holder>', '<s0>', '<s1>', '<s2>', '<s3>', '<s4>', '<s5>']
SPECIAL_TOKENS_DICT = {
'bos_token': '<s>',
'eos_token': '</s>',
'pad_token': '<pad>',
'additional_special_tokens': ADDITIONAL_SPECIAL_TOKENS
}
S0_TOK = '<s0>' # frame
S1_TOK = '<s1>' # mot
S2_TOK = '<s2>' # question
def load_file(file_name):
annos = None
if os.path.splitext(file_name)[-1] == '.csv':
return pd.read_csv(file_name)
with open(file_name, 'r') as fp:
if os.path.splitext(file_name)[1]== '.txt':
annos = fp.readlines()
annos = [line.rstrip() for line in annos]
if os.path.splitext(file_name)[1] == '.json':
annos = json.load(fp)
return annos
def tokenize(obj, tokenizer):
if isinstance(obj, str):
return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
if isinstance(obj, dict):
return dict((n, tokenize(o, tokenizer)) for n, o in obj.items())
return list(tokenize(o, tokenizer) for o in obj)
class NextQADataset(Dataset):
def __init__(self, config, split):
super().__init__()
self.config = config
self.split = split
self.bart_max_input_len = config['bart_max_input_len']
self.bart_size = config['bart_size']
self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-{}'.format(self.bart_size))
self.vocab_size = self.tokenizer.vocab_size
self.tokenizer.add_special_tokens({'additional_special_tokens': ADDITIONAL_SPECIAL_TOKENS})
self.vocab_size += len(ADDITIONAL_SPECIAL_TOKENS)
self.tokenizer.save_pretrained(os.path.join(self.config['log_dir'], 'bart_tokenizer'))
sample_list_file = os.path.join(self.config['nextqa_root'], '{}.csv'.format(split))
self.sample_list = load_file(sample_list_file)
vid_feat_file = os.path.join(self.config['nextqa_vid_feat'], 'app_mot_{}.h5'.format(split))
print('Load {}...'.format(vid_feat_file))
self.frame_feats = {}
self.mot_feats = {}
with h5py.File(vid_feat_file, 'r') as fp:
vids = fp['ids']
feats = fp['feat']
for vid, feat in zip(vids, feats):
self.frame_feats[str(vid)] = feat[:, :2048] # (16, 2048)
self.mot_feats[str(vid)] = feat[:, 2048:] # (16, 2048)
if self.config['overfit_size'] > 0:
self.sample_list = self.sample_list[:self.config['overfit_size']]
def __len__(self):
return len(self.sample_list)
def get_video_feature(self, video_name):
"""
:param video_name:
:return:
"""
app_feat = self.frame_feats[video_name]
app_feat = torch.from_numpy(app_feat).type(torch.float32)
mot_feat = self.mot_feats[video_name]
mot_feat = torch.from_numpy(mot_feat).type(torch.float32)
return app_feat, mot_feat
def __getitem__(self, idx):
cur_sample = self.sample_list.loc[idx]
video_name, ques, ans, qid = str(cur_sample['video']), str(cur_sample['question']),\
str(cur_sample['answer']), str(cur_sample['qid'])
input_ids = tokenize(ques, self.tokenizer)
lm_labels = tokenize(ans, self.tokenizer)
app_feat, mot_feat = self.get_video_feature(video_name)
bos, eos, ques_state = self.tokenizer.convert_tokens_to_ids(['<s>', '</s>', '<s2>'])
# Add state tokens
input_ids.insert(0, ques_state)
lm_labels.append(eos)
question_interval = [0, len(input_ids)]
input_ids = torch.Tensor(input_ids).long()
lm_labels = torch.Tensor(lm_labels).long()
return input_ids, lm_labels, app_feat, mot_feat, question_interval, video_name
def padding(self, seq, pad_token, max_len=None):
if max_len is None:
max_len = max([i.size(0) for i in seq])
if len(seq[0].size()) == 1:
result = torch.ones((len(seq), max_len)).long() * pad_token
else:
result = torch.ones((len(seq), max_len, seq[0].size(-1))).float()
for i in range(len(seq)):
result[i, :seq[i].size(0)] = seq[i]
return result
def collate_fn(self, batch):
input_ids_list, lm_labels_list, app_feat_list, mot_feat_list, question_interval_list, vid_ids_list = [], [], [], [], [], []
for i in batch:
input_ids_list.append(i[0])
lm_labels_list.append(i[1])
app_feat_list.append(i[2])
mot_feat_list.append(i[3])
question_interval_list.append(i[4])
vid_ids_list.append(i[5])
app_feats = torch.stack(app_feat_list, dim=0).float()
mot_feats = torch.stack(mot_feat_list, dim=0).float()
question_intervals = np.array(question_interval_list)
pad_token, app_sep, mot_sep, ph_token = self.tokenizer.convert_tokens_to_ids(
['<pad>', '<s0>', '<s1>', '<place_holder>'])
# None of the visual features are masked because we do not pad them
video_mask = torch.ones((len(batch), 16*2 + 2)) == 1 # NOTE *2: 2 modalities | +2: the state tokens | each modality has length 16
# Now we create a dummy input for the video tokens (sole purpose is to reserve the spots of the separators)
dummy = torch.ones((len(batch), 16)) * ph_token
video_place_holder_ids = torch.cat(
[torch.ones((len(batch), 1)) * app_sep, dummy,
torch.ones((len(batch), 1)) * mot_sep, dummy,
], dim=-1).long()
input_ids = self.padding(input_ids_list, pad_token)
lm_labels = self.padding(lm_labels_list, -100)
text_mask = input_ids != pad_token
input_mask = torch.cat([video_mask, text_mask], dim=1)
# Now we get the intervals of the visual input tokens
# Here the intervals do not change across the batch dimension
app_interval = [0, 16 + 1] # the last token is not part of this modality
mot_interval = [16 + 1, 2 * 16 + 2]
vis_state_vector_idx = [app_interval[0], mot_interval[0]]
# adapt the question and history interval -- shifted to the right by the visual input length
question_intervals += 2 * 16 + 2
question_intervals = question_intervals.tolist()
question_state_vector_idx = [x[0] for x in question_intervals]
batch = {
'input_ids': input_ids,
'video_place_holder_ids': video_place_holder_ids,
'app_feats': app_feats,
'mot_feats': mot_feats,
'lm_labels': lm_labels,
'input_mask': input_mask,
'app_interval': app_interval,
'mot_interval': mot_interval,
'question_intervals': question_intervals,
'vis_state_vector_idx': vis_state_vector_idx,
'question_state_vector_idx': question_state_vector_idx
}
return batch
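# Loads the NExT-QA sample list and splits the precomputed per-video features into appearance (first 2048 dims) and motion (last 2048 dims) streams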
def get_dataset(config, split):
bart_max_input_len = config['bart_max_input_len']
bart_size = config['bart_size']
sample_list_file = os.path.join(config['nextqa_root'], '{}.csv'.format(split))
sample_list = load_file(sample_list_file)
vid_feat_file = os.path.join(config['nextqa_vid_feat'], 'app_mot_{}.h5'.format(split))
print('Load {}...'.format(vid_feat_file))
app_feats = {}
mot_feats = {}
with h5py.File(vid_feat_file, 'r') as fp:
vids = fp['ids']
feats = fp['feat']
for vid, feat in zip(vids, feats):
app_feats[str(vid)] = feat[:, :2048] # (16, 2048)
mot_feats[str(vid)] = feat[:, 2048:] # (16, 2048)
return sample_list, app_feats, mot_feats

179
custom_datasets/segment.py Normal file
View file

@ -0,0 +1,179 @@
from segment_anything import SamPredictor, SamAutomaticMaskGenerator, sam_model_registry
from tqdm import tqdm
from argparse import ArgumentParser
import pickle
import cv2
import os
import torch
import numpy as np
def parse_args():
parser = ArgumentParser()
parser.add_argument(
'--sam_ckpt',
type=str,
help='SAM checkpoint to be used'
)
parser.add_argument(
'--avsd_root',
type=str,
help='Directory where the individual AVSD frames are located'
)
parser.add_argument(
'--crop_root',
type=str,
help='Directory where the individual crops (objects) will be saved'
)
parser.add_argument(
'--embed_root',
type=str,
help='Directory where the individual embeddings will be saved'
)
parser.add_argument(
'--mode',
type=str,
choices=['segment', 'embed'],
help='segment: segment the image into regions | embed: embed the image crops detected during segmentation'
)
parser.add_argument(
'--start',
type=int,
default=0,
help='Start index of the partition'
)
parser.add_argument(
'--end',
type=int,
default=1968,
help='End index of the partition'
)
args = parser.parse_args()
return args
def partition_ids(avsd_ids, start, end):
avsd_ids.sort()
assert start < end
assert start >= 0 and end <= len(avsd_ids)
avsd_ids_partition = avsd_ids[start:end]
return avsd_ids_partition
def get_middle_frames(avsd_ids_partition, avsd_root):
pbar = tqdm(avsd_ids_partition)
pbar.set_description('[INFO] Preparing frames of {} videos'.format(len(avsd_ids_partition)))
path_list = []
for avsd_id in pbar:
frames = os.listdir(os.path.join(avsd_root, avsd_id))
if 'test' in avsd_root:
frames.sort(key=lambda f: int(f.split('_')[-1].split('.')[0]))
else:
frames.sort(key=lambda f: int(f.split('-')[-1].split('.')[0]))
middle_frame = frames[int(len(frames)/2)]
middle_frame = os.path.join(avsd_root, avsd_id, middle_frame)
path_list.append(middle_frame)
return path_list
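# Runs SAM automatic mask generation on the middle frame of each video and saves up to 36 object crops (plus horizontally flipped copies)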
def segment_images(sam, path_list, crop_root):
mask_generator = SamAutomaticMaskGenerator(sam)
pbar = tqdm(path_list)
pbar.set_description('Detecting Objects')
for pth in pbar:
vid_id = pth.split('/')[-2]
crop_dir = os.path.join(crop_root, vid_id)
if not os.path.isdir(crop_dir):
os.makedirs(crop_dir)
image = cv2.imread(pth)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
masks = mask_generator.generate(image)
masks.sort(key=lambda e: e['stability_score'], reverse=True)
if len(masks) > 36:
masks = masks[:36]
for i, mask in enumerate(masks):
crop = image[
int(mask['bbox'][1]):int(mask['bbox'][1] + mask['bbox'][3] + 1),
int(mask['bbox'][0]):int(mask['bbox'][0] + mask['bbox'][2] + 1),
:
]
crop_flipped = cv2.flip(crop, 1) # Horizontal flip
cv2.imwrite(os.path.join(crop_dir, f'obj_{i}.jpg'), crop)
cv2.imwrite(os.path.join(crop_dir, f'obj_{i}_flipped.jpg'), crop_flipped)
print('[INFO] Done...')
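# Embeds every saved crop and its horizontal flip with the SAM image encoder and stores one (num_objects, 512) array per video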
def embed_objects(sam, crop_ids, crop_root, embed_root):
predictor = SamPredictor(sam)
pbar = tqdm(crop_ids)
pbar.set_description('Embedding Objects')
for vid_id in pbar:
embeds = []
crop_dir = os.path.join(crop_root, vid_id)
crop_paths = list(map(lambda p: os.path.join(crop_dir, p), os.listdir(crop_dir)))
crop_paths = list(filter(lambda p: 'flipped' not in p, crop_paths))
crop_paths.sort(key=lambda p: int(p.split('_')[-1].split('.')[0]))
for cp in crop_paths:
crop = cv2.imread(cp)
crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)
predictor.set_image(crop)
embed_crop = predictor.get_image_embedding()
embed_crop = embed_crop.mean(-1).mean(-1)
crop_flipped = cv2.flip(crop, 1)
predictor.set_image(crop_flipped)
embed_crop_flipped = predictor.get_image_embedding()
embed_crop_flipped = embed_crop_flipped.mean(-1).mean(-1)
embed = torch.cat((embed_crop, embed_crop_flipped), dim=-1)
# embed = embed.copy().cpu()
embeds.append(embed)
embeds = torch.cat(embeds, 0).cpu().numpy()
np.save(os.path.join(embed_root, f'{vid_id}.npy'), embeds)
print('[INFO] Done...')
def segment(args, sam):
avsd_ids = os.listdir(args.avsd_root)
avsd_ids.sort()
avsd_ids_partition = partition_ids(avsd_ids, args.start, args.end)
path_list = get_middle_frames(avsd_ids_partition, args.avsd_root)
if not os.path.isdir(args.crop_root):
os.makedirs(args.crop_root)
segment_images(sam, path_list, args.crop_root)
def embed(args, sam):
crop_ids = os.listdir(args.crop_root)
crop_ids.sort()
crop_ids_partition = partition_ids(crop_ids, args.start, args.end)
if not os.path.isdir(args.embed_root):
os.makedirs(args.embed_root)
embed_objects(sam, crop_ids_partition, args.crop_root, args.embed_root)
if __name__ == '__main__':
args = parse_args()
sam = sam_model_registry['vit_h