Code release
commit 09fb25e339
29 changed files with 7162 additions and 0 deletions

277
README.md
Normal file
@ -0,0 +1,277 @@
<div align="center">

<h1> VD-GR: Boosting Visual Dialog with Cascaded Spatial-Temporal Multi-Modal GRaphs </h1>

**[Adnen Abdessaied][5], [Lei Shi][6], [Andreas Bulling][7]** <br> <br>
**WACV'24, Hawaii, USA** <img src="misc/usa.png" width="3%" align="center"> <br>
**[[Paper][8]]**

-------------------
<img src="misc/teaser_1.png" width="100%" align="middle"><br><br>

</div>
# Table of Contents
* [Setup and Dependencies](#Setup-and-Dependencies)
* [Download Data](#Download-Data)
* [Pre-trained Checkpoints](#Pre-trained-Checkpoints)
* [Training](#Training)
* [Results](#Results)
# Setup and Dependencies
We implemented our model using Python 3.7 and PyTorch 1.11.0 (CUDA 11.3, CuDNN 8.2.0). We recommend setting up a virtual environment using Anaconda. <br>
1. Install [git lfs][1] on your system
2. Clone our repository to download the data, checkpoints, and code
   ```shell
   git lfs install
   git clone this_repo.git
   ```
3. Create a conda environment and install dependencies
   ```shell
   conda create -n vdgr python=3.7
   conda activate vdgr
   conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch
   conda install pyg -c pyg  # 2.1.0
   pip install pytorch-transformers
   pip install pytorch_pretrained_bert
   pip install pyhocon glog wandb lmdb
   ```
4. If you wish to speed up training, we recommend installing [apex][2]
   ```shell
   git clone https://github.com/NVIDIA/apex
   cd apex
   # if pip >= 23.1 (ref: https://pip.pypa.io/en/stable/news/#v23-1) which supports multiple `--config-settings` with the same key...
   pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
   # otherwise
   pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --global-option="--cpp_ext" --global-option="--cuda_ext" ./
   cd ..
   ```

# Download Data
1. Download the extracted visual features of [VisDial][3] and set up all files we used in our work. We provide a shell script for convenience:
   ```shell
   ./setup_data.sh  # Please make sure you have enough disk space
   ```
If everything was set up correctly, the ```data/``` folder should look like this:
```
├── history_adj_matrices
│   ├── test
│   │   └── *.pkl
│   ├── train
│   │   └── *.pkl
│   └── val
│       └── *.pkl
├── question_adj_matrices
│   ├── test
│   │   └── *.pkl
│   ├── train
│   │   └── *.pkl
│   └── val
│       └── *.pkl
├── img_adj_matrices
│   └── *.pkl
├── parse_vocab.pkl
├── test_dense_mapping.json
├── tr_dense_mapping.json
├── val_dense_mapping.json
├── visdial_0.9_test.json
├── visdial_0.9_train.json
├── visdial_0.9_val.json
├── visdial_1.0_test.json
├── visdial_1.0_train_dense_annotations.json
├── visdial_1.0_train_dense.json
├── visdial_1.0_train.json
├── visdial_1.0_val_dense_annotations.json
├── visdial_1.0_val.json
├── visdialconv_dense_annotations.json
├── visdialconv.json
├── vispro_dense_annotations.json
└── vispro.json
```
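
The graph files are pickled adjacency structures consumed by the dataloaders (see ```dataloader/dataloader_visdial.py```). A minimal sketch for peeking at one of them; the file name below is hypothetical and the exact schema is defined by the dataloader code:

```python
import pickle

# Each dialog has one pickled graph file per split; "123.pkl" is a placeholder
# for an actual graph index produced by setup_data.sh.
with open("data/history_adj_matrices/val/123.pkl", "rb") as f:
    history_graphs = pickle.load(f)  # one entry of edge indices per dialog round
print(len(history_graphs))
```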
# Pre-trained Checkpoints
For convenience, we provide checkpoints of our model after the warm-up training stage in ```ckpt/``` for both VisDial v1.0 and VisDial v0.9. <br>
These checkpoints will be downloaded with the code if you use ```git lfs```.

# Training
We trained our model on 8 Nvidia Tesla V100-32GB GPUs. The default hyperparameters in ```config/vdgr.conf``` and ```config/bert_base_6layer_6conect.json``` need to be adjusted if your setup differs from ours.
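
Since ```config/vdgr.conf``` is plain HOCON, you can sanity-check your adjustments before launching a run. A minimal sketch using ```pyhocon``` from the dependency list above; the printed keys follow the shipped config:

```python
from pyhocon import ConfigFactory

# Parse the training config; P2_CE extends P1 via HOCON ${P1} substitution.
config = ConfigFactory.parse_file("config/vdgr.conf")

print(config["P1"]["batch_size"])         # per-GPU batch size (default: 8)
print(config["P1"]["learning_rate_gnn"])  # GNN learning rate (default: 2e-4)
print(config["P2_CE"]["dense_loss"])      # inherited block resolves to "ce"
```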
## Phase 1
### Training
1. In this phase, we train our model on VisDial v1.0 via
   ```shell
   CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
       --model vdgr/P1 \
       --mode train \
       --tag K2_v1.0 \
       --wandb_mode online \
       --wandb_project your_wandb_project_name
   ```
⚠️ On a similar setup to ours, this will take roughly 20h to complete using apex for training.

2. To train on VisDial v0.9:
   * Set ```visdial_version = 0.9``` in ```config/vdgr.conf```
   * Set ```start_path = ckpt/vdgr_visdial_v0.9_after_warmup_K2.ckpt``` in ```config/vdgr.conf```
   * Run
   ```shell
   CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
       --model vdgr/P1 \
       --mode train \
       --tag K2_v0.9 \
       --wandb_mode online \
       --wandb_project your_wandb_project_name
   ```
### Inference
1. For inference on VisDial v1.0 val, VisDialConv, or VisPro:
   * Set ```eval_dataset = {visdial, visdial_conv, visdial_vispro}``` in ```logs/vdgr/P1_K2_v1.0/code/config/vdgr.conf```
   * Run
   ```shell
   CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
       --model vdgr/P1 \
       --mode eval \
       --eval_dir logs/vdgr/P1_K2_v1.0 \
       --wandb_mode offline
   ```
2. For inference on VisDial v0.9:
   * Set ```eval_dataset = visdial``` in ```logs/vdgr/P1_K2_v0.9/code/config/vdgr.conf```
   * Run
   ```shell
   CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
       --model vdgr/P1 \
       --mode eval \
       --eval_dir logs/vdgr/P1_K2_v0.9 \
       --wandb_mode offline
   ```
⚠️ This might take some time to finish, as the test data of VisDial v0.9 is large.

3. For inference on the VisDial v1.0 ```test``` split:
   * Run
   ```shell
   CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
       --model vdgr/P1 \
       --mode predict \
       --eval_dir logs/vdgr/P1_K2_v1.0 \
       --wandb_mode offline
   ```
   * The output file will be saved in ```output/```

## Phase 2
In this phase, we finetune on dense annotations to improve the NDCG score (only supported for VisDial v1.0).
1. Run
   ```shell
   CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
       --model vdgr/P2_CE \
       --mode train \
       --tag K2_v1.0_CE \
       --wandb_mode online \
       --wandb_project your_wandb_project_name
   ```
⚠️ This will take roughly 3-4 hours to complete using the same setup as before and [DP][4] for training.

2. For inference on VisDial v1.0:
   * Run
   ```shell
   CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
       --model vdgr/P2_CE \
       --mode predict \
       --eval_dir logs/vdgr/P1_K2_v1.0_CE \
       --wandb_mode offline
   ```
   * The output file will be saved in ```output/```

## Phase 3
### Training
In the final phase, we train an ensemble method comprising 8 models, using ```K={1,2,3,4}``` and ```dense_loss={ce, listnet}```.
For ```K=k```:
1. Set the value of ```num_v_gnn_layers, num_q_gnn_layers, num_h_gnn_layers``` to ```k```
2. Set ```start_path = ckpt/vdgr_visdial_v1.0_after_warmup_K[k].ckpt``` in ```config/vdgr.conf``` (P1)
3. Phase 1 training:
   ```shell
   CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
       --model vdgr/P1 \
       --mode train \
       --tag K[k]_v1.0 \
       --wandb_mode online \
       --wandb_project your_wandb_project_name
   ```
4. Set ```start_path = logs/vdgr/P1_K[k]_v1.0/epoch_best.ckpt``` in ```config/vdgr.conf``` (P2)
5. Phase 2 training:
   * Fine-tune with CE:
   ```shell
   CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
       --model vdgr/P2_CE \
       --mode train \
       --tag K[k]_v1.0_CE \
       --wandb_mode online \
       --wandb_project your_wandb_project_name
   ```
   * Fine-tune with LISTNET:
   ```shell
   CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
       --model vdgr/P2_LISTNET \
       --mode train \
       --tag K[k]_v1.0_LISTNET \
       --wandb_mode online \
       --wandb_project your_wandb_project_name
   ```
### Inference
1. For inference on VisDial v1.0 test:
   ```shell
   CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
       --model vdgr/P2_[CE,LISTNET] \
       --mode predict \
       --eval_dir logs/vdgr/P2_K[1,2,3,4]_v1.0_[CE,LISTNET] \
       --wandb_mode offline
   ```
2. Finally, merge the outputs of all models (a sketch of the underlying idea follows below)
   ```shell
   python ensemble.py \
       --exp test \
       --mode predict
   ```
The output file will be saved in ```output/```

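Conceptually, ensembling of this kind typically averages each model's per-option scores before re-ranking. A minimal sketch of that idea, assuming each run dumped a pickled ```{image_id: option_scores}``` mapping; the file names and format here are hypothetical, and the actual merging logic lives in ```ensemble.py```:

```python
import pickle
import numpy as np

# Hypothetical per-model score dumps, one per (K, dense_loss) combination.
paths = [f"output/P2_K{k}_v1.0_{loss}.pkl" for k in (1, 2, 3, 4) for loss in ("CE", "LISTNET")]

merged = {}
for path in paths:
    with open(path, "rb") as f:
        scores = pickle.load(f)  # {image_id: np.ndarray of 100 option scores}
    for image_id, option_scores in scores.items():
        merged.setdefault(image_id, []).append(option_scores)

# Average the option scores of all 8 models, then re-rank the 100 candidates.
ranks = {img: np.argsort(-np.mean(s, axis=0)) for img, s in merged.items()}
```
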
# Results
## VisDial v0.9
| Model | MRR | R@1 | R@5 | R@10 | Mean |
|:--------:|:---:|:---:|:---:|:----:|:----:|
| Prev. SOTA | 71.99 | 59.41 | 87.92 | 94.59 | 2.87 |
| VD-GR | **74.50** | **62.10** | **90.49** | **96.37** | **2.45** |

## VisDialConv
| Model | NDCG | MRR | R@1 | R@5 | R@10 | Mean |
|:--------:|:----:|:---:|:---:|:---:|:----:|:----:|
| Prev. SOTA | 61.72 | 61.79 | 48.95 | 77.50 | 86.71 | 4.72 |
| VD-GR | **67.09** | **66.82** | **54.47** | **81.71** | **91.44** | **3.54** |

## VisPro
| Model | NDCG | MRR | R@1 | R@5 | R@10 | Mean |
|:--------:|:----:|:---:|:---:|:---:|:----:|:----:|
| Prev. SOTA | 59.30 | 62.29 | 48.35 | 80.10 | 88.87 | 4.37 |
| VD-GR | **60.35** | **69.89** | **57.21** | **85.97** | **92.68** | **3.15** |

## VisDial v1.0 Val
| Model | NDCG | MRR | R@1 | R@5 | R@10 | Mean |
|:--------:|:----:|:---:|:---:|:---:|:----:|:----:|
| Prev. SOTA | 65.47 | 69.71 | 56.79 | 85.82 | 93.64 | 3.15 |
| VD-GR | 64.32 | **69.91** | **57.01** | **86.14** | **93.74** | **3.13** |

## VisDial v1.0 Test
| Model | NDCG | MRR | R@1 | R@5 | R@10 | Mean |
|:--------:|:----:|:---:|:---:|:---:|:----:|:----:|
| Prev. SOTA | 64.91 | 68.73 | 55.73 | 85.38 | 93.53 | 3.21 |
| VD-GR | 63.49 | 68.65 | 55.33 | **85.58** | **93.85** | **3.20** |
| ♣️ Prev. SOTA | 75.92 | 56.18 | 45.32 | 68.05 | 80.98 | 5.42 |
| ♣️ VD-GR | **75.95** | **58.30** | **46.55** | **71.45** | 84.52 | **5.32** |
| ♣️♦️ Prev. SOTA | 76.17 | 56.42 | 44.75 | 70.23 | 84.52 | 5.47 |
| ♣️♦️ VD-GR | **76.43** | 56.35 | **45.18** | 68.13 | 82.18 | 5.79 |

♣️ = Finetuning on dense annotations, ♦️ = Ensemble model

[1]: https://git-lfs.com/
[2]: https://github.com/NVIDIA/apex
[3]: https://visualdialog.org/
[4]: https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html
[5]: https://adnenabdessaied.de
[6]: https://www.perceptualui.org/people/shi/
[7]: https://www.perceptualui.org/people/bulling/
[8]: https://drive.google.com/file/d/1GT0WDinA_z5FdwVc_bWtyB-cwQkGIf7C/view?usp=sharing
0
ckpt/.gitkeep
Normal file

40
config/bert_base_6layer_6conect.json
Normal file
@ -0,0 +1,40 @@
{
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 2,
  "vocab_size": 30522,
  "v_feature_size": 2048,
  "v_target_size": 1601,
  "v_hidden_size": 1024,
  "v_num_hidden_layers": 6,
  "v_num_attention_heads": 8,
  "v_intermediate_size": 1024,
  "bi_hidden_size": 1024,
  "bi_num_attention_heads": 8,
  "bi_intermediate_size": 1024,
  "bi_attention_type": 1,
  "v_attention_probs_dropout_prob": 0.1,
  "v_hidden_act": "gelu",
  "v_hidden_dropout_prob": 0.1,
  "v_initializer_range": 0.02,
  "pooling_method": "mul",
  "v_biattention_id": [0, 1, 2, 3, 4, 5],
  "t_biattention_id": [6, 7, 8, 9, 10, 11],
  "gnn_act": "gelu",
  "num_v_gnn_layers": 2,
  "num_q_gnn_layers": 2,
  "num_h_gnn_layers": 2,
  "num_gnn_attention_heads": 4,
  "gnn_dropout_prob": 0.1,
  "v_gnn_edge_dim": 12,
  "q_gnn_edge_dim": 48,
  "v_gnn_ids": [0, 1, 2, 3, 4, 5],
  "t_gnn_ids": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
}
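
A quick way to inspect this file (a minimal sketch; the grouping of keys into streams is inferred from their names, with ```v_*``` configuring the visual stream, ```bi_*``` the cross-modal co-attention, and ```*_gnn_*``` the graph layers):

```python
import json

with open("config/bert_base_6layer_6conect.json") as f:
    model_config = json.load(f)

# The visual and textual streams co-attend at paired layer ids.
assert len(model_config["v_biattention_id"]) == len(model_config["t_biattention_id"])
print(model_config["num_v_gnn_layers"])  # K, the number of GNN layers (here 2)
```
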
33
config/ensemble.conf
Normal file
@ -0,0 +1,33 @@
test = {
  split = test
  skip_mrr_eval = true
  # data
  visdial_test_data = data/visdial_1.0_test.json

  # directory
  log_dir = logs/vdgr_ensemble
  pred_dir = logs/vdgr
  visdial_output_dir = visdial_output

  processed = [
    false,
    false,
    false,
    false,
    false,
    false,
    false,
    false
  ]

  models = [
    "P2_K1_v1.0_CE",
    "P2_K2_v1.0_CE",
    "P2_K3_v1.0_CE",
    "P2_K4_v1.0_CE",
    "P2_K1_v1.0_LISTNET",
    "P2_K2_v1.0_LISTNET",
    "P2_K3_v1.0_LISTNET",
    "P2_K4_v1.0_LISTNET",
  ]
}
188
config/vdgr.conf
Normal file
@ -0,0 +1,188 @@
# Phase 1
P1 {
  use_cpu = false
  visdial_version = 1.0
  train_on_dense = false
  metrics_to_maximize = mrr

  # visdial data
  visdial_image_feats = data/visdial_img_feat.lmdb

  visdial_image_adj_matrices = data/img_adj_matrices
  visdial_question_adj_matrices = data/question_adj_matrices
  visdial_history_adj_matrices = data/history_adj_matrices

  visdial_train = data/visdial_1.0_train.json
  visdial_val = data/visdial_1.0_val.json
  visdial_test = data/visdial_1.0_test.json
  visdial_val_dense_annotations = data/visdial_1.0_val_dense_annotations.json

  visdial_train_09 = data/visdial_0.9_train.json
  visdial_val_09 = data/visdial_0.9_val.json
  visdial_test_09 = data/visdial_0.9_test.json

  visdialconv_val = data/visdial_conv.json
  visdialconv_val_dense_annotations = data/visdialconv_dense_annotations.json

  visdialvispro_val = data/vispro.json
  visdialvispro_val_dense_annotations = data/vispro_dense_annotations.json

  visdial_question_parse_vocab = data/parse_vocab.pkl

  # init
  start_path = ckpt/vdgr_visdial_v1.0_after_warmup_K2.ckpt
  model_config = config/bert_base_6layer_6conect.json

  # visdial training
  freeze_vilbert = false
  visdial_tot_rounds = 11
  num_negative_samples = 1
  sequences_per_image = 2
  batch_size = 8
  lm_loss_coeff = 1
  nsp_loss_coeff = 1
  img_loss_coeff = 1
  batch_multiply = 1
  use_trainval = false
  dense_loss = ce
  dense_loss_coeff = 0
  dataloader_text_only = false
  rlv_hst_only = false
  rlv_hst_dense_round = false

  # visdial model
  mask_prob = 0.1
  image_mask_prob = 0.1
  max_seq_len = 256
  num_options = 100
  num_options_dense = 100
  use_embedding = joint

  # visdial evaluation
  eval_visdial_on_test = true
  eval_batch_size = 1
  eval_line_batch_size = 200
  skip_mrr_eval = false
  skip_ndcg_eval = false
  skip_visdial_eval = false
  eval_visdial_every = 1
  eval_dataset = visdial  # choices = [visdial, visdial_conv, visdial_vispro]

  continue_evaluation = false
  eval_at_start = false
  eval_before_training = false
  initializer = normal
  bert_cased = false

  # restore ckpt
  loads_best_ckpt = false
  loads_ckpt = false
  restarts = false
  resets_max_metric = false
  uses_new_optimizer = false
  sets_new_lr = false
  loads_start_path = false

  # logging
  random_seed = 42
  next_logging_pct = 1.0
  next_evaluating_pct = 50.0
  max_ckpt_to_keep = 1
  num_epochs = 20
  early_stop_epoch = 5
  skip_saving_ckpt = false
  dp_type = apex
  stack_gr_data = false
  master_port = 5122
  stop_epochs = -1
  train_each_round = false
  drop_last_answer = false
  num_samples = -1

  # predicting
  predict_split = test
  predict_each_round = false
  predict_dense_round = false
  num_test_dialogs = 8000
  num_val_dialogs = 2064
  save_score = false

  # optimizer
  reset_optim = none
  learning_rate_bert = 5e-6
  learning_rate_gnn = 2e-4
  gnn_weight_decay = 0.01
  use_diff_lr_gnn = true
  min_lr = 0
  decay_method_bert = linear
  decay_method_gnn = linear
  decay_exp = 2
  max_grad_norm = 1.0
  task_optimizer = adam
  warmup_ratio = 0.1

  # directory
  log_dir = logs/vdgr
  data_dir = data
  visdial_output_dir = visdial_output
  bert_cache_dir = transformers

  # keep track of other hparams in bert json
  v_gnn_edge_dim = 12  # 11 classes + hub_node
  q_gnn_edge_dim = 48  # 47 classes + hub_node
  num_v_gnn_layers = 2
  num_q_gnn_layers = 2
  num_h_gnn_layers = 2
  num_gnn_attention_heads = 4
  v_gnn_ids = [0, 1, 2, 3, 4, 5]
  t_gnn_ids = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
}

# Phase 2
P2_CE = ${P1} {
  # basic
  train_on_dense = true
  use_trainval = true
  metrics_to_maximize = ndcg

  visdial_train_dense = data/visdial_1.0_train_dense.json
  visdial_train_dense_annotations = data/visdial_1.0_train_dense_annotations.json
  visdial_val_dense = data/visdial_1.0_val.json

  tr_graph_idx_mapping = data/tr_dense_mapping.json
  val_graph_idx_mapping = data/val_dense_mapping.json
  test_graph_idx_mapping = data/test_dense_mapping.json

  visdial_val = data/visdial_1.0_val.json
  visdial_val_dense_annotations = data/visdial_1.0_val_dense_annotations.json

  # data
  start_path = logs/vdgr/P1_K2_v1.0/epoch_best.ckpt
  rlv_hst_only = false

  # visdial training
  nsp_loss_coeff = 0
  dense_loss_coeff = 1
  batch_multiply = 10
  batch_size = 1

  # visdial model
  num_options_dense = 100

  # visdial evaluation
  eval_batch_size = 1
  eval_line_batch_size = 100
  skip_mrr_eval = true

  # training
  stop_epochs = 3
  dp_type = dp
  dense_loss = ce

  # optimizer
  learning_rate_bert = 1e-4
}

P2_LISTNET = ${P2_CE} {
  dense_loss = listnet
}
0
data/.gitkeep
Normal file

0
dataloader/__init__.py
Normal file

269
dataloader/dataloader_base.py
Normal file
@ -0,0 +1,269 @@
import torch
from torch.utils import data
import json
import os
import glog as log
import pickle

import torch.utils.data as tud
from pytorch_transformers.tokenization_bert import BertTokenizer

from utils.image_features_reader import ImageFeaturesH5Reader


class DatasetBase(data.Dataset):

    def __init__(self, config):

        if config['display']:
            log.info('Initializing dataset')

        # Fetch the correct dataset for evaluation
        if config['validating']:
            assert config.eval_dataset in ['visdial', 'visdial_conv', 'visdial_vispro', 'visdial_v09']
            if config.eval_dataset == 'visdial_conv':
                config['visdial_val'] = config.visdialconv_val
                config['visdial_val_dense_annotations'] = config.visdialconv_val_dense_annotations
            elif config.eval_dataset == 'visdial_vispro':
                config['visdial_val'] = config.visdialvispro_val
                config['visdial_val_dense_annotations'] = config.visdialvispro_val_dense_annotations
            elif config.eval_dataset == 'visdial_v09':
                config['visdial_val_09'] = config.visdial_test_09
                config['visdial_val_dense_annotations'] = None

        self.config = config
        self.numDataPoints = {}

        if not config['dataloader_text_only']:
            self._image_features_reader = ImageFeaturesH5Reader(
                config['visdial_image_feats'],
                config['visdial_image_adj_matrices']
            )

        if self.config['training'] or self.config['validating'] or self.config['predicting']:
            split2data = {'train': 'train', 'val': 'val', 'test': 'test'}
        elif self.config['debugging']:
            split2data = {'train': 'val', 'val': 'val', 'test': 'test'}
        elif self.config['visualizing']:
            split2data = {'train': 'train', 'val': 'train', 'test': 'test'}

        filename = f'visdial_{split2data["train"]}'
        if config['train_on_dense']:
            filename += '_dense'
        if self.config['visdial_version'] == 0.9:
            filename += '_09'

        with open(config[filename]) as f:
            self.visdial_data_train = json.load(f)
            if self.config.num_samples > 0:
                self.visdial_data_train['data']['dialogs'] = self.visdial_data_train['data']['dialogs'][:self.config.num_samples]
        self.numDataPoints['train'] = len(self.visdial_data_train['data']['dialogs'])

        filename = f'visdial_{split2data["val"]}'
        if config['train_on_dense'] and config['training']:
            filename += '_dense'
        if self.config['visdial_version'] == 0.9:
            filename += '_09'

        with open(config[filename]) as f:
            self.visdial_data_val = json.load(f)
            if self.config.num_samples > 0:
                self.visdial_data_val['data']['dialogs'] = self.visdial_data_val['data']['dialogs'][:self.config.num_samples]
        self.numDataPoints['val'] = len(self.visdial_data_val['data']['dialogs'])

        if config['train_on_dense']:
            self.numDataPoints['trainval'] = self.numDataPoints['train'] + self.numDataPoints['val']
        with open(config[f'visdial_{split2data["test"]}']) as f:
            self.visdial_data_test = json.load(f)
        self.numDataPoints['test'] = len(self.visdial_data_test['data']['dialogs'])

        self.rlv_hst_train = None
        self.rlv_hst_val = None
        self.rlv_hst_test = None

        if config['train_on_dense'] or config['predict_dense_round']:
            with open(config[f'visdial_{split2data["train"]}_dense_annotations']) as f:
                self.visdial_data_train_dense = json.load(f)
        if config['train_on_dense']:
            self.subsets = ['train', 'val', 'trainval', 'test']
        else:
            self.subsets = ['train', 'val', 'test']
        self.num_options = config["num_options"]
        self.num_options_dense = config["num_options_dense"]
        if config['visdial_version'] != 0.9:
            with open(config[f'visdial_{split2data["val"]}_dense_annotations']) as f:
                self.visdial_data_val_dense = json.load(f)
        else:
            self.visdial_data_val_dense = None
        self._split = 'train'
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', cache_dir=config['bert_cache_dir'])
        # fetching token indices of [CLS], [MASK] and [SEP]
        tokens = ['[CLS]', '[MASK]', '[SEP]']
        indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokens)
        self.CLS = indexed_tokens[0]
        self.MASK = indexed_tokens[1]
        self.SEP = indexed_tokens[2]
        self._max_region_num = 37
        self.predict_each_round = self.config['predicting'] and self.config['predict_each_round']

        self.keys_to_expand = ['image_feat', 'image_loc', 'image_mask', 'image_target', 'image_label']
        self.keys_to_flatten_1d = ['hist_len', 'next_sentence_labels', 'round_id', 'image_id']
        self.keys_to_flatten_2d = ['tokens', 'segments', 'sep_indices', 'mask', 'image_mask', 'image_label', 'input_mask', 'question_limits']
        self.keys_to_flatten_3d = ['image_feat', 'image_loc', 'image_target']
        self.keys_other = ['gt_relevance', 'gt_option_inds']
        self.keys_lists_to_flatten = ['image_edge_indices', 'image_edge_attributes', 'question_edge_indices', 'question_edge_attributes', 'history_edge_indices', 'history_sep_indices']
        if config['stack_gr_data']:
            self.keys_to_flatten_3d.extend(self.keys_lists_to_flatten[:-1])
            self.keys_to_flatten_2d.append(self.keys_lists_to_flatten[-1])
            self.keys_to_flatten_1d.extend(['len_image_gr', 'len_question_gr', 'len_history_gr', 'len_history_sep'])
            self.keys_lists_to_flatten = []

        self.keys_to_list = ['tot_len']

        # Load the parse vocab for question graph relationship mapping
        if os.path.isfile(config['visdial_question_parse_vocab']):
            with open(config['visdial_question_parse_vocab'], 'rb') as f:
                self.parse_vocab = pickle.load(f)

    def __len__(self):
        return self.numDataPoints[self._split]

    @property
    def split(self):
        return self._split

    @split.setter
    def split(self, split):
        assert split in self.subsets
        self._split = split

    def tokens2str(self, seq):
        dialog_sequence = ''
        for sentence in seq:
            for word in sentence:
                dialog_sequence += self.tokenizer._convert_id_to_token(word) + " "
        dialog_sequence += ' </end> '
        dialog_sequence = dialog_sequence.encode('utf8')
        return dialog_sequence

    def pruneRounds(self, context, num_rounds):
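        # Keep at most `num_rounds` rounds of context: when the dialog history
        # is longer, drop the caption and the oldest rounds and start the
        # segment ids at 0 instead of 1.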
        start_segment = 1
        len_context = len(context)
        cur_rounds = (len(context) // 2) + 1
        l_index = 0
        if cur_rounds > num_rounds:
            # caption is not part of the final input
            l_index = len_context - (2 * num_rounds)
            start_segment = 0
        return context[l_index:], start_segment

    def tokenize_utterance(self, sent, sentences, tot_len, sentence_count, sentence_map, speakers):
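        # sentence_map assigns each sentence's word tokens an odd id and the
        # [SEP] that follows it the next even id; all tokens are marked as
        # speaker 2.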
        sentences.extend(sent + ['[SEP]'])
        tokenized_sent = self.tokenizer.convert_tokens_to_ids(sent)
        assert len(sent) == len(tokenized_sent), 'sub-word tokens are not allowed!'

        sent_len = len(tokenized_sent)
        tot_len += sent_len + 1  # the additional 1 is for the sep token
        sentence_count += 1
        sentence_map.extend([sentence_count * 2 - 1] * sent_len)
        sentence_map.append(sentence_count * 2)  # for [SEP]
        speakers.extend([2] * (sent_len + 1))

        return tokenized_sent, sentences, tot_len, sentence_count, sentence_map, speakers

    def __getitem__(self, index):
        raise NotImplementedError

    def collate_fn(self, batch):
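        # Merge per-sample dicts into batched tensors: list-valued graph keys
        # are flattened into one long list, everything else is stacked and then
        # expanded/flattened according to the key groups defined in __init__.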
        tokens_size = batch[0]['tokens'].size()
        num_rounds, num_samples = tokens_size[0], tokens_size[1]
        merged_batch = {key: [d[key] for d in batch] for key in batch[0]}

        if self.config['stack_gr_data']:
            if len(batch) > 1:
                max_question_gr_len = max([length.max().item() for length in merged_batch['len_question_gr']])
                max_history_gr_len = max([length.max().item() for length in merged_batch['len_history_gr']])
                max_history_sep_len = max([length.max().item() for length in merged_batch['len_history_sep']])
                max_image_gr_len = max([length.max().item() for length in merged_batch['len_image_gr']])

                question_edge_indices_padded = []
                question_edge_attributes_padded = []

                for q_e_idx, q_e_attr in zip(merged_batch['question_edge_indices'], merged_batch['question_edge_attributes']):
                    b_size, edge_dim, orig_len = q_e_idx.size()
                    q_e_idx_padded = torch.zeros((b_size, edge_dim, max_question_gr_len))
                    q_e_idx_padded[:, :, :orig_len] = q_e_idx
                    question_edge_indices_padded.append(q_e_idx_padded)

                    edge_attr_dim = q_e_attr.size(-1)
                    q_e_attr_padded = torch.zeros((b_size, max_question_gr_len, edge_attr_dim))
                    q_e_attr_padded[:, :orig_len, :] = q_e_attr
                    question_edge_attributes_padded.append(q_e_attr_padded)

                merged_batch['question_edge_indices'] = question_edge_indices_padded
                merged_batch['question_edge_attributes'] = question_edge_attributes_padded

                history_edge_indices_padded = []
                for h_e_idx in merged_batch['history_edge_indices']:
                    b_size, _, orig_len = h_e_idx.size()
                    h_edge_idx_padded = torch.zeros((b_size, 2, max_history_gr_len))
                    h_edge_idx_padded[:, :, :orig_len] = h_e_idx
                    history_edge_indices_padded.append(h_edge_idx_padded)
                merged_batch['history_edge_indices'] = history_edge_indices_padded

                history_sep_indices_padded = []
                for hist_sep_idx in merged_batch['history_sep_indices']:
                    b_size, orig_len = hist_sep_idx.size()
                    hist_sep_idx_padded = torch.zeros((b_size, max_history_sep_len))
                    hist_sep_idx_padded[:, :orig_len] = hist_sep_idx
                    history_sep_indices_padded.append(hist_sep_idx_padded)
                merged_batch['history_sep_indices'] = history_sep_indices_padded

                image_edge_indices_padded = []
                image_edge_attributes_padded = []
                for img_e_idx, img_e_attr in zip(merged_batch['image_edge_indices'], merged_batch['image_edge_attributes']):
                    b_size, edge_dim, orig_len = img_e_idx.size()
                    img_e_idx_padded = torch.zeros((b_size, edge_dim, max_image_gr_len))
                    img_e_idx_padded[:, :, :orig_len] = img_e_idx
                    image_edge_indices_padded.append(img_e_idx_padded)

                    edge_attr_dim = img_e_attr.size(-1)
                    img_e_attr_padded = torch.zeros((b_size, max_image_gr_len, edge_attr_dim))
                    img_e_attr_padded[:, :orig_len, :] = img_e_attr
                    image_edge_attributes_padded.append(img_e_attr_padded)

                merged_batch['image_edge_indices'] = image_edge_indices_padded
                merged_batch['image_edge_attributes'] = image_edge_attributes_padded

        out = {}
        for key in merged_batch:
            if key in self.keys_lists_to_flatten:
                temp = []
                for b in merged_batch[key]:
                    for x in b:
                        temp.append(x)
                merged_batch[key] = temp
            elif key in self.keys_to_list:
                pass
            else:
                merged_batch[key] = torch.stack(merged_batch[key], 0)
                if key in self.keys_to_expand:
                    if len(merged_batch[key].size()) == 3:
                        size0, size1, size2 = merged_batch[key].size()
                        expand_size = (size0, num_rounds, num_samples, size1, size2)
                    elif len(merged_batch[key].size()) == 2:
                        size0, size1 = merged_batch[key].size()
                        expand_size = (size0, num_rounds, num_samples, size1)
                    merged_batch[key] = merged_batch[key].unsqueeze(1).unsqueeze(1).expand(expand_size).contiguous()
                if key in self.keys_to_flatten_1d:
                    merged_batch[key] = merged_batch[key].reshape(-1)
                elif key in self.keys_to_flatten_2d:
                    merged_batch[key] = merged_batch[key].reshape(-1, merged_batch[key].shape[-1])
                elif key in self.keys_to_flatten_3d:
                    merged_batch[key] = merged_batch[key].reshape(-1, merged_batch[key].shape[-2], merged_batch[key].shape[-1])
                else:
                    assert key in self.keys_other, f'unrecognized key in collate_fn: {key}'

            out[key] = merged_batch[key]
        return out
615
dataloader/dataloader_visdial.py
Normal file
@ -0,0 +1,615 @@
import torch
import os
import numpy as np
import random
import pickle

import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))

from utils.data_utils import encode_input, encode_input_with_mask, encode_image_input
from dataloader.dataloader_base import DatasetBase


class VisdialDataset(DatasetBase):

    def __init__(self, config):
        super(VisdialDataset, self).__init__(config)

    def __getitem__(self, index):
        MAX_SEQ_LEN = self.config['max_seq_len']
        cur_data = None
        if self._split == 'train':
            cur_data = self.visdial_data_train['data']
            ques_adj_matrices_dir = os.path.join(self.config['visdial_question_adj_matrices'], 'train')
            hist_adj_matrices_dir = os.path.join(self.config['visdial_history_adj_matrices'], 'train')

        elif self._split == 'val':
            cur_data = self.visdial_data_val['data']
            ques_adj_matrices_dir = os.path.join(self.config['visdial_question_adj_matrices'], 'val')
            hist_adj_matrices_dir = os.path.join(self.config['visdial_history_adj_matrices'], 'val')

        else:
            cur_data = self.visdial_data_test['data']
            ques_adj_matrices_dir = os.path.join(self.config['visdial_question_adj_matrices'], 'test')
            hist_adj_matrices_dir = os.path.join(self.config['visdial_history_adj_matrices'], 'test')

        if self.config['visdial_version'] == 0.9:
            ques_adj_matrices_dir = os.path.join(self.config['visdial_question_adj_matrices'], 'train')
            hist_adj_matrices_dir = os.path.join(self.config['visdial_history_adj_matrices'], 'train')

        self.num_bad_samples = 0
        # number of options to score on
        num_options = self.num_options
        assert num_options > 1 and num_options <= 100
        num_dialog_rounds = 10

        dialog = cur_data['dialogs'][index]
        cur_questions = cur_data['questions']
        cur_answers = cur_data['answers']
        img_id = dialog['image_id']
        graph_idx = dialog.get('dialog_idx', index)

        if self._split == 'train':
            # caption
            sent = dialog['caption'].split(' ')
            sentences = ['[CLS]']
            tot_len = 1  # for the CLS token
            sentence_map = [0]  # for the CLS token
            sentence_count = 0
            speakers = [0]

            tokenized_sent, sentences, tot_len, sentence_count, sentence_map, speakers = \
                self.tokenize_utterance(sent, sentences, tot_len, sentence_count, sentence_map, speakers)

            utterances = [[tokenized_sent]]
            utterances_random = [[tokenized_sent]]

            for rnd, utterance in enumerate(dialog['dialog']):
                cur_rnd_utterance = utterances[-1].copy()
                cur_rnd_utterance_random = utterances[-1].copy()

                # question
                sent = cur_questions[utterance['question']].split(' ')
                tokenized_sent, sentences, tot_len, sentence_count, sentence_map, speakers = \
                    self.tokenize_utterance(sent, sentences, tot_len, sentence_count, sentence_map, speakers)

                cur_rnd_utterance.append(tokenized_sent)
                cur_rnd_utterance_random.append(tokenized_sent)

                # answer
                sent = cur_answers[utterance['answer']].split(' ')
                tokenized_sent, sentences, tot_len, sentence_count, sentence_map, speakers = \
                    self.tokenize_utterance(sent, sentences, tot_len, sentence_count, sentence_map, speakers)
                cur_rnd_utterance.append(tokenized_sent)

                utterances.append(cur_rnd_utterance)

                # randomly select one random utterance in that round
                num_inds = len(utterance['answer_options'])
                gt_option_ind = utterance['gt_index']

                negative_samples = []

                for _ in range(self.config["num_negative_samples"]):

                    all_inds = list(range(100))
                    all_inds.remove(gt_option_ind)
                    all_inds = all_inds[:(num_options - 1)]
                    tokenized_random_utterance = None
                    option_ind = None

                    while len(all_inds):
                        option_ind = random.choice(all_inds)
                        tokenized_random_utterance = self.tokenizer.convert_tokens_to_ids(cur_answers[utterance['answer_options'][option_ind]].split(' '))
                        # the 1 here is for the sep token at the end of each utterance
                        if MAX_SEQ_LEN >= (tot_len + len(tokenized_random_utterance) + 1):
                            break
                        else:
                            all_inds.remove(option_ind)
                    if len(all_inds) == 0:
                        # all the options exceed the max len. Truncate the last utterance in this case.
                        tokenized_random_utterance = tokenized_random_utterance[:len(tokenized_sent)]
                    t = cur_rnd_utterance_random.copy()
                    t.append(tokenized_random_utterance)
                    negative_samples.append(t)

                utterances_random.append(negative_samples)

            # removing the caption in the beginning
            utterances = utterances[1:]
            utterances_random = utterances_random[1:]
            assert len(utterances) == len(utterances_random) == num_dialog_rounds
            assert tot_len <= MAX_SEQ_LEN, '{} {} tot_len = {} > max_seq_len'.format(
                self._split, index, tot_len
            )

            tokens_all = []
            question_limits_all = []
            question_edge_indices_all = []
            question_edge_attributes_all = []
            history_edge_indices_all = []
            history_sep_indices_all = []
            mask_all = []
            segments_all = []
            sep_indices_all = []
            next_labels_all = []
            hist_len_all = []

            # randomly pick several rounds to train
            pos_rounds = sorted(random.sample(range(num_dialog_rounds), self.config['sequences_per_image'] // 2), reverse=True)
            neg_rounds = sorted(random.sample(range(num_dialog_rounds), self.config['sequences_per_image'] // 2), reverse=True)

            tokens_all_rnd = []
            question_limits_all_rnd = []
            mask_all_rnd = []
            segments_all_rnd = []
            sep_indices_all_rnd = []
            next_labels_all_rnd = []
            hist_len_all_rnd = []

            for j in pos_rounds:
                context = utterances[j]
                context, start_segment = self.pruneRounds(context, self.config['visdial_tot_rounds'])
                if j == pos_rounds[0]:  # dialog with positive label and max rounds
                    tokens, segments, sep_indices, mask, input_mask, start_question, end_question = encode_input_with_mask(
                        context, start_segment, self.CLS, self.SEP, self.MASK, max_seq_len=MAX_SEQ_LEN, mask_prob=self.config["mask_prob"])
                else:
                    tokens, segments, sep_indices, mask, start_question, end_question = encode_input(
                        context, start_segment, self.CLS, self.SEP, self.MASK, max_seq_len=MAX_SEQ_LEN, mask_prob=self.config["mask_prob"])
                tokens_all_rnd.append(tokens)
                question_limits_all_rnd.append(torch.tensor([start_question, end_question]))
                mask_all_rnd.append(mask)
                sep_indices_all_rnd.append(sep_indices)
                next_labels_all_rnd.append(torch.LongTensor([0]))
                segments_all_rnd.append(segments)
                hist_len_all_rnd.append(torch.LongTensor([len(context) - 1]))

            tokens_all.append(torch.cat(tokens_all_rnd, 0).unsqueeze(0))
            mask_all.append(torch.cat(mask_all_rnd, 0).unsqueeze(0))
            question_limits_all.extend(question_limits_all_rnd)
            segments_all.append(torch.cat(segments_all_rnd, 0).unsqueeze(0))
            sep_indices_all.append(torch.cat(sep_indices_all_rnd, 0).unsqueeze(0))
            next_labels_all.append(torch.cat(next_labels_all_rnd, 0).unsqueeze(0))
            hist_len_all.append(torch.cat(hist_len_all_rnd, 0).unsqueeze(0))

            assert len(pos_rounds) == 1
            question_graphs = pickle.load(
                open(os.path.join(ques_adj_matrices_dir, f'{graph_idx}.pkl'), 'rb')
            )

            question_graph_pos = question_graphs[pos_rounds[0]]
            question_edge_index_pos = []
            question_edge_attribute_pos = []
            for edge_idx, edge_attr in question_graph_pos:
                question_edge_index_pos.append(edge_idx)
                edge_attr_one_hot = np.zeros((len(self.parse_vocab) + 1,), dtype=np.float32)
                edge_attr_one_hot[self.parse_vocab.get(edge_attr, len(self.parse_vocab))] = 1.0
                question_edge_attribute_pos.append(edge_attr_one_hot)

            question_edge_index_pos = np.array(question_edge_index_pos, dtype=np.float64)
            question_edge_attribute_pos = np.stack(question_edge_attribute_pos, axis=0)

            question_edge_indices_all.append(
                torch.from_numpy(question_edge_index_pos).t().long().contiguous()
            )

            question_edge_attributes_all.append(
                torch.from_numpy(question_edge_attribute_pos)
            )

            history_edge_indices = pickle.load(
                open(os.path.join(hist_adj_matrices_dir, f'{graph_idx}.pkl'), 'rb')
            )

            history_edge_indices_all.append(
                torch.tensor(history_edge_indices[pos_rounds[0]]).t().long().contiguous()
            )
            # Get the [SEP] tokens that will represent the history graph node features
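            # (separator 0 closes the caption and the following [SEP]s alternate
            # question/answer, so every other index is a history-turn boundary)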
hist_idx_pos = [i * 2 for i in range(pos_rounds[0] + 1)]
|
||||||
|
sep_indices = sep_indices.squeeze(0).numpy()
|
||||||
|
history_sep_indices_all.append(torch.from_numpy(sep_indices[hist_idx_pos]))
|
||||||
|
|
||||||
|
if len(neg_rounds) > 0:
|
||||||
|
tokens_all_rnd = []
|
||||||
|
question_limits_all_rnd = []
|
||||||
|
mask_all_rnd = []
|
||||||
|
segments_all_rnd = []
|
||||||
|
sep_indices_all_rnd = []
|
||||||
|
next_labels_all_rnd = []
|
||||||
|
hist_len_all_rnd = []
|
||||||
|
|
||||||
|
for j in neg_rounds:
|
||||||
|
|
||||||
|
negative_samples = utterances_random[j]
|
||||||
|
for context_random in negative_samples:
|
||||||
|
context_random, start_segment = self.pruneRounds(context_random, self.config['visdial_tot_rounds'])
|
||||||
|
tokens_random, segments_random, sep_indices_random, mask_random, start_question, end_question = encode_input(context_random, start_segment, self.CLS,
|
||||||
|
self.SEP, self.MASK, max_seq_len=MAX_SEQ_LEN, mask_prob=self.config["mask_prob"])
|
||||||
|
tokens_all_rnd.append(tokens_random)
|
||||||
|
question_limits_all_rnd.append(torch.tensor([start_question, end_question]))
|
||||||
|
mask_all_rnd.append(mask_random)
|
||||||
|
sep_indices_all_rnd.append(sep_indices_random)
|
||||||
|
next_labels_all_rnd.append(torch.LongTensor([1]))
|
||||||
|
segments_all_rnd.append(segments_random)
|
||||||
|
hist_len_all_rnd.append(torch.LongTensor([len(context_random)-1]))
|
||||||
|
|
||||||
|
tokens_all.append(torch.cat(tokens_all_rnd,0).unsqueeze(0))
|
||||||
|
mask_all.append(torch.cat(mask_all_rnd,0).unsqueeze(0))
|
||||||
|
question_limits_all.extend(question_limits_all_rnd)
|
||||||
|
segments_all.append(torch.cat(segments_all_rnd, 0).unsqueeze(0))
|
||||||
|
sep_indices_all.append(torch.cat(sep_indices_all_rnd, 0).unsqueeze(0))
|
||||||
|
next_labels_all.append(torch.cat(next_labels_all_rnd, 0).unsqueeze(0))
|
||||||
|
hist_len_all.append(torch.cat(hist_len_all_rnd,0).unsqueeze(0))
|
||||||
|
|
||||||
|
assert len(neg_rounds) == 1
|
||||||
|
|
||||||
|
question_graph_neg = question_graphs[neg_rounds[0]]
|
||||||
|
question_edge_index_neg = []
|
||||||
|
question_edge_attribute_neg = []
|
||||||
|
for edge_idx, edge_attr in question_graph_neg:
|
||||||
|
question_edge_index_neg.append(edge_idx)
|
||||||
|
edge_attr_one_hot = np.zeros((len(self.parse_vocab) + 1,), dtype=np.float32)
|
||||||
|
edge_attr_one_hot[self.parse_vocab.get(edge_attr, len(self.parse_vocab))] = 1.0
|
||||||
|
question_edge_attribute_neg.append(edge_attr_one_hot)
|
||||||
|
|
||||||
|
question_edge_index_neg = np.array(question_edge_index_neg, dtype=np.float64)
|
||||||
|
question_edge_attribute_neg = np.stack(question_edge_attribute_neg, axis=0)
|
||||||
|
|
||||||
|
question_edge_indices_all.append(
|
||||||
|
torch.from_numpy(question_edge_index_neg).t().long().contiguous()
|
||||||
|
)
|
||||||
|
|
||||||
|
question_edge_attributes_all.append(
|
||||||
|
torch.from_numpy(question_edge_attribute_neg)
|
||||||
|
)
|
||||||
|
|
||||||
|
history_edge_indices_all.append(
|
||||||
|
torch.tensor(history_edge_indices[neg_rounds[0]]).t().long().contiguous()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get the [SEP] tokens that will represent the history graph node features
|
||||||
|
hist_idx_neg = [i * 2 for i in range(neg_rounds[0] + 1)]
|
||||||
|
sep_indices_random = sep_indices_random.squeeze(0).numpy()
|
||||||
|
history_sep_indices_all.append(torch.from_numpy(sep_indices_random[hist_idx_neg]))
|
||||||
|
|
||||||
|
tokens_all = torch.cat(tokens_all, 0) # [2, num_pos, max_len]
|
||||||
|
question_limits_all = torch.stack(question_limits_all, 0) # [2, 2]
|
||||||
|
mask_all = torch.cat(mask_all,0)
|
||||||
|
segments_all = torch.cat(segments_all, 0)
|
||||||
|
sep_indices_all = torch.cat(sep_indices_all, 0)
|
||||||
|
next_labels_all = torch.cat(next_labels_all, 0)
|
||||||
|
hist_len_all = torch.cat(hist_len_all, 0)
|
||||||
|
input_mask_all = torch.LongTensor(input_mask) # [max_len]
|
||||||
|
|
||||||
|
item = {}
|
||||||
|
|
||||||
|
item['tokens'] = tokens_all
|
||||||
|
item['question_limits'] = question_limits_all
|
||||||
|
item['question_edge_indices'] = question_edge_indices_all
|
||||||
|
item['question_edge_attributes'] = question_edge_attributes_all
|
||||||
|
|
||||||
|
item['history_edge_indices'] = history_edge_indices_all
|
||||||
|
item['history_sep_indices'] = history_sep_indices_all
|
||||||
|
item['segments'] = segments_all
|
||||||
|
item['sep_indices'] = sep_indices_all
|
||||||
|
item['mask'] = mask_all
|
||||||
|
item['next_sentence_labels'] = next_labels_all
|
||||||
|
item['hist_len'] = hist_len_all
|
||||||
|
item['input_mask'] = input_mask_all
|
||||||
|
|
||||||
|
# get image features
|
||||||
|
if not self.config['dataloader_text_only']:
|
||||||
|
features, num_boxes, boxes, _ , image_target, image_edge_indexes, image_edge_attributes = self._image_features_reader[img_id]
|
||||||
|
features, spatials, image_mask, image_target, image_label = encode_image_input(features, num_boxes, boxes, image_target, max_regions=self._max_region_num)
|
||||||
|
else:
|
||||||
|
features = spatials = image_mask = image_target = image_label = torch.tensor([0])
|
||||||
|
|
||||||
|
elif self._split == 'val':
|
||||||
|
gt_relevance = None
|
||||||
|
gt_option_inds = []
|
||||||
|
options_all = []
|
||||||
|
|
||||||
|
# caption
|
||||||
|
sent = dialog['caption'].split(' ')
|
||||||
|
sentences = ['[CLS]']
|
||||||
|
tot_len = 1 # for the CLS token
|
||||||
|
sentence_map = [0] # for the CLS token
|
||||||
|
sentence_count = 0
|
||||||
|
speakers = [0]
|
||||||
|
|
||||||
|
tokenized_sent, sentences, tot_len, sentence_count, sentence_map, speakers = \
|
||||||
|
self.tokenize_utterance(sent, sentences, tot_len, sentence_count, sentence_map, speakers)
|
||||||
|
utterances = [[tokenized_sent]]
|
||||||
|
|
||||||
|
for rnd, utterance in enumerate(dialog['dialog']):
|
||||||
|
cur_rnd_utterance = utterances[-1].copy()
|
||||||
|
|
||||||
|
# question
|
||||||
|
sent = cur_questions[utterance['question']].split(' ')
|
||||||
|
tokenized_sent, sentences, tot_len, sentence_count, sentence_map, speakers = \
|
||||||
|
self.tokenize_utterance(sent, sentences, tot_len, sentence_count, sentence_map, speakers)
|
||||||
|
|
||||||
|
cur_rnd_utterance.append(tokenized_sent)
|
||||||
|
|
||||||
|
# current round
|
||||||
|
gt_option_ind = utterance['gt_index']
|
||||||
|
# first select gt option id, then choose the first num_options inds
|
||||||
|
option_inds = []
|
||||||
|
option_inds.append(gt_option_ind)
|
||||||
|
all_inds = list(range(100))
|
||||||
|
all_inds.remove(gt_option_ind)
|
||||||
|
all_inds = all_inds[:(num_options-1)]
|
||||||
|
option_inds.extend(all_inds)
|
||||||
|
gt_option_inds.append(0)
|
||||||
|
cur_rnd_options = []
|
||||||
|
answer_options = [utterance['answer_options'][k] for k in option_inds]
|
||||||
|
assert len(answer_options) == len(option_inds) == num_options
|
||||||
|
assert answer_options[0] == utterance['answer']
|
||||||
|
|
||||||
|
# for evaluation of all options and dense relevance
|
||||||
|
if self.visdial_data_val_dense:
|
||||||
|
if rnd == self.visdial_data_val_dense[index]['round_id'] - 1:
|
||||||
|
# only 1 round has gt_relevance for each example
|
||||||
|
if 'relevance' in self.visdial_data_val_dense[index]:
|
||||||
|
gt_relevance = torch.Tensor(self.visdial_data_val_dense[index]['relevance'])
|
||||||
|
else:
|
||||||
|
gt_relevance = torch.Tensor(self.visdial_data_val_dense[index]['gt_relevance'])
|
||||||
|
# shuffle based on new indices
|
||||||
|
gt_relevance = gt_relevance[torch.LongTensor(option_inds)]
|
||||||
|
else:
|
||||||
|
gt_relevance = -1
|
||||||
|
|
||||||
|
for answer_option in answer_options:
|
||||||
|
cur_rnd_cur_option = cur_rnd_utterance.copy()
|
||||||
|
cur_rnd_cur_option.append(self.tokenizer.convert_tokens_to_ids(cur_answers[answer_option].split(' ')))
|
||||||
|
cur_rnd_options.append(cur_rnd_cur_option)
|
||||||
|
|
||||||
|
# answer
|
||||||
|
sent = cur_answers[utterance['answer']].split(' ')
|
||||||
|
tokenized_sent, sentences, tot_len, sentence_count, sentence_map, speakers = \
|
||||||
|
self.tokenize_utterance(sent, sentences, tot_len, sentence_count, sentence_map, speakers)
|
||||||
|
cur_rnd_utterance.append(tokenized_sent)
|
||||||
|
|
||||||
|
utterances.append(cur_rnd_utterance)
|
||||||
|
options_all.append(cur_rnd_options)
|
||||||
|
|
||||||
|
# encode the input and create batch x 10 x 100 * max_len arrays (batch x num_rounds x num_options)
|
||||||
|
tokens_all = []
|
||||||
|
question_limits_all = []
|
||||||
|
mask_all = []
|
||||||
|
segments_all = []
|
||||||
|
sep_indices_all = []
|
||||||
|
hist_len_all = []
|
||||||
|
history_sep_indices_all = []
|
||||||
|
|
||||||
|
for rnd, cur_rnd_options in enumerate(options_all):
|
||||||
|
|
||||||
|
tokens_all_rnd = []
|
||||||
|
mask_all_rnd = []
|
||||||
|
segments_all_rnd = []
|
||||||
|
sep_indices_all_rnd = []
|
||||||
|
hist_len_all_rnd = []
|
||||||
|
|
||||||
|
for j, cur_rnd_option in enumerate(cur_rnd_options):
|
||||||
|
|
||||||
|
cur_rnd_option, start_segment = self.pruneRounds(cur_rnd_option, self.config['visdial_tot_rounds'])
|
||||||
|
if rnd == len(options_all) - 1 and j == 0: # gt dialog
|
||||||
|
tokens, segments, sep_indices, mask, input_mask, start_question, end_question = encode_input_with_mask(cur_rnd_option, start_segment, self.CLS,
|
||||||
|
self.SEP, self.MASK, max_seq_len=MAX_SEQ_LEN, mask_prob=0)
|
||||||
|
else:
|
||||||
|
tokens, segments, sep_indices, mask, start_question, end_question = encode_input(cur_rnd_option, start_segment,self.CLS,
|
||||||
|
self.SEP, self.MASK ,max_seq_len=MAX_SEQ_LEN, mask_prob=0)
|
||||||
|
|
||||||
|
tokens_all_rnd.append(tokens)
|
||||||
|
mask_all_rnd.append(mask)
|
||||||
|
segments_all_rnd.append(segments)
|
||||||
|
sep_indices_all_rnd.append(sep_indices)
|
||||||
|
hist_len_all_rnd.append(torch.LongTensor([len(cur_rnd_option)-1]))
|
||||||
|
|
||||||
|
question_limits_all.append(torch.tensor([start_question, end_question]).unsqueeze(0).repeat(100, 1))
|
||||||
|
tokens_all.append(torch.cat(tokens_all_rnd,0).unsqueeze(0))
|
||||||
|
mask_all.append(torch.cat(mask_all_rnd,0).unsqueeze(0))
|
||||||
|
segments_all.append(torch.cat(segments_all_rnd,0).unsqueeze(0))
|
||||||
|
sep_indices_all.append(torch.cat(sep_indices_all_rnd,0).unsqueeze(0))
|
||||||
|
hist_len_all.append(torch.cat(hist_len_all_rnd,0).unsqueeze(0))
|
||||||
|
# Get the [SEP] tokens that will represent the history graph node features
|
||||||
|
# It will be the same for all answer candidates as the history does not change
|
||||||
|
# for each answer
|
||||||
|
hist_idx = [i * 2 for i in range(rnd + 1)]
|
||||||
|
history_sep_indices_all.extend(sep_indices.squeeze(0)[hist_idx].contiguous() for _ in range(100))
|
||||||
|
|
||||||
|
tokens_all = torch.cat(tokens_all, 0) # [10, 100, max_len]
|
||||||
|
mask_all = torch.cat(mask_all, 0)
|
||||||
|
segments_all = torch.cat(segments_all, 0)
|
||||||
|
sep_indices_all = torch.cat(sep_indices_all, 0)
|
||||||
|
hist_len_all = torch.cat(hist_len_all, 0)
|
||||||
|
input_mask_all = torch.LongTensor(input_mask) # [max_len]
            # load graph data
            question_limits_all = torch.stack(question_limits_all, 0)  # [10, 100, 2]

            question_graphs = pickle.load(
                open(os.path.join(ques_adj_matrices_dir, f'{graph_idx}.pkl'), 'rb')
            )
            question_edge_indices_all = []  # [10, N] we do not repeat it 100 times here
            question_edge_attributes_all = []  # [10, N] we do not repeat it 100 times here

            for q_graph_round in question_graphs:
                question_edge_index = []
                question_edge_attribute = []
                for edge_index, edge_attr in q_graph_round:
                    question_edge_index.append(edge_index)
                    edge_attr_one_hot = np.zeros((len(self.parse_vocab) + 1,), dtype=np.float32)
                    edge_attr_one_hot[self.parse_vocab.get(edge_attr, len(self.parse_vocab))] = 1.0
                    question_edge_attribute.append(edge_attr_one_hot)
                question_edge_index = np.array(question_edge_index, dtype=np.float64)
                question_edge_attribute = np.stack(question_edge_attribute, axis=0)

                question_edge_indices_all.extend(
                    [torch.from_numpy(question_edge_index).t().long().contiguous() for _ in range(100)])
                question_edge_attributes_all.extend(
                    [torch.from_numpy(question_edge_attribute).contiguous() for _ in range(100)])
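The edge labels coming from the question parse are encoded as one-hot vectors of size `len(parse_vocab) + 1`; the extra slot is an out-of-vocabulary bucket reached through `dict.get`'s default. A toy illustration with a hypothetical vocabulary:

```python
import numpy as np

parse_vocab = {'nsubj': 0, 'det': 1, 'amod': 2}  # hypothetical label -> index map

def one_hot_edge_attr(edge_attr):
    vec = np.zeros((len(parse_vocab) + 1,), dtype=np.float32)
    # unknown labels fall into the last (OOV) slot
    vec[parse_vocab.get(edge_attr, len(parse_vocab))] = 1.0
    return vec

print(one_hot_edge_attr('det'))      # [0. 1. 0. 0.]
print(one_hot_edge_attr('obscure'))  # [0. 0. 0. 1.] -> OOV bucket
```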
            _history_edge_incides_all = pickle.load(
                open(os.path.join(hist_adj_matrices_dir, f'{graph_idx}.pkl'), 'rb')
            )
            history_edge_incides_all = []
            for hist_edge_indices_rnd in _history_edge_incides_all:
                history_edge_incides_all.extend(
                    [torch.tensor(hist_edge_indices_rnd).t().long().contiguous() for _ in range(100)]
                )

            item = {}
            item['tokens'] = tokens_all
            item['segments'] = segments_all
            item['sep_indices'] = sep_indices_all
            item['mask'] = mask_all
            item['hist_len'] = hist_len_all
            item['input_mask'] = input_mask_all

            item['gt_option_inds'] = torch.LongTensor(gt_option_inds)

            # return dense annotation data as well
            if self.visdial_data_val_dense:
                item['round_id'] = torch.LongTensor([self.visdial_data_val_dense[index]['round_id']])
                item['gt_relevance'] = gt_relevance

            item['question_limits'] = question_limits_all

            item['question_edge_indices'] = question_edge_indices_all
            item['question_edge_attributes'] = question_edge_attributes_all

            item['history_edge_indices'] = history_edge_incides_all
            item['history_sep_indices'] = history_sep_indices_all

            # get image features
            if not self.config['dataloader_text_only']:
                features, num_boxes, boxes, _, image_target, image_edge_indexes, image_edge_attributes = self._image_features_reader[img_id]
                features, spatials, image_mask, image_target, image_label = encode_image_input(
                    features, num_boxes, boxes, image_target, max_regions=self._max_region_num, mask_prob=0)
            else:
                features = spatials = image_mask = image_target = image_label = torch.tensor([0])
        elif self.split == 'test':
            assert num_options == 100
            cur_rnd_utterance = [self.tokenizer.convert_tokens_to_ids(dialog['caption'].split(' '))]
            options_all = []
            for rnd, utterance in enumerate(dialog['dialog']):
                cur_rnd_utterance.append(self.tokenizer.convert_tokens_to_ids(cur_questions[utterance['question']].split(' ')))
                if rnd != len(dialog['dialog']) - 1:
                    cur_rnd_utterance.append(self.tokenizer.convert_tokens_to_ids(cur_answers[utterance['answer']].split(' ')))
            for answer_option in dialog['dialog'][-1]['answer_options']:
                cur_option = cur_rnd_utterance.copy()
                cur_option.append(self.tokenizer.convert_tokens_to_ids(cur_answers[answer_option].split(' ')))
                options_all.append(cur_option)

            tokens_all = []
            mask_all = []
            segments_all = []
            sep_indices_all = []
            hist_len_all = []

            for j, option in enumerate(options_all):
                option, start_segment = self.pruneRounds(option, self.config['visdial_tot_rounds'])
                tokens, segments, sep_indices, mask = encode_input(option, start_segment, self.CLS,
                                                                   self.SEP, self.MASK, max_seq_len=MAX_SEQ_LEN, mask_prob=0)

                tokens_all.append(tokens)
                mask_all.append(mask)
                segments_all.append(segments)
                sep_indices_all.append(sep_indices)
                hist_len_all.append(torch.LongTensor([len(option) - 1]))

            tokens_all = torch.cat(tokens_all, 0)
            mask_all = torch.cat(mask_all, 0)
            segments_all = torch.cat(segments_all, 0)
            sep_indices_all = torch.cat(sep_indices_all, 0)
            hist_len_all = torch.cat(hist_len_all, 0)
            hist_idx = [i * 2 for i in range(len(dialog['dialog']))]
            history_sep_indices_all = [sep_indices.squeeze(0)[hist_idx].contiguous() for _ in range(num_options)]
            with open(os.path.join(ques_adj_matrices_dir, f'{graph_idx}.pkl'), 'rb') as f:
                question_graphs = pickle.load(f)
            q_graph_last = question_graphs[-1]
            question_edge_index = []
            question_edge_attribute = []
            for edge_index, edge_attr in q_graph_last:
                question_edge_index.append(edge_index)
                edge_attr_one_hot = np.zeros((len(self.parse_vocab) + 1,), dtype=np.float32)
                edge_attr_one_hot[self.parse_vocab.get(edge_attr, len(self.parse_vocab))] = 1.0
                question_edge_attribute.append(edge_attr_one_hot)
            question_edge_index = np.array(question_edge_index, dtype=np.float64)
            question_edge_attribute = np.stack(question_edge_attribute, axis=0)

            question_edge_indices_all = [torch.from_numpy(question_edge_index).t().long().contiguous() for _ in range(num_options)]
            question_edge_attributes_all = [torch.from_numpy(question_edge_attribute).contiguous() for _ in range(num_options)]

            with open(os.path.join(hist_adj_matrices_dir, f'{graph_idx}.pkl'), 'rb') as f:
                _history_edge_incides_all = pickle.load(f)
            _history_edge_incides_last = _history_edge_incides_all[-1]
            history_edge_index_all = [torch.tensor(_history_edge_incides_last).t().long().contiguous() for _ in range(num_options)]

            if self.config['stack_gr_data']:
                question_edge_indices_all = torch.stack(question_edge_indices_all, dim=0)
                question_edge_attributes_all = torch.stack(question_edge_attributes_all, dim=0)
                history_edge_index_all = torch.stack(history_edge_index_all, dim=0)
                history_sep_indices_all = torch.stack(history_sep_indices_all, dim=0)
                len_question_gr = torch.tensor(question_edge_indices_all.size(-1)).unsqueeze(0).repeat(num_options, 1)
                len_history_gr = torch.tensor(history_edge_index_all.size(-1)).repeat(num_options, 1)
                len_history_sep = torch.tensor(history_sep_indices_all.size(-1)).repeat(num_options, 1)
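When `stack_gr_data` is enabled, the per-option graph tensors are stacked into single tensors and their last-dimension sizes are recorded (`len_question_gr`, `len_history_gr`, `len_history_sep`), presumably so that a collate step can pad across samples and later slice each graph back to its true size. A sketch of that inverse slicing, under the assumption that edges are padded along the last dimension:

```python
import torch

# Hypothetical batch: edge-index tensors stacked and padded to a common width,
# with the true number of edges per sample kept alongside.
padded = torch.zeros(2, 2, 7, dtype=torch.long)   # [batch, 2, max_num_edges]
lengths = torch.tensor([5, 7])                    # true edge counts per sample

graphs = [padded[b, :, :int(lengths[b])] for b in range(padded.size(0))]
print([g.shape for g in graphs])  # [torch.Size([2, 5]), torch.Size([2, 7])]
```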
            item = {}
            item['tokens'] = tokens_all.unsqueeze(0)
            item['segments'] = segments_all.unsqueeze(0)
            item['sep_indices'] = sep_indices_all.unsqueeze(0)
            item['mask'] = mask_all.unsqueeze(0)
            item['hist_len'] = hist_len_all.unsqueeze(0)
            item['question_limits'] = question_limits_all
            item['question_edge_indices'] = question_edge_indices_all
            item['question_edge_attributes'] = question_edge_attributes_all

            item['history_edge_indices'] = history_edge_index_all
            item['history_sep_indices'] = history_sep_indices_all

            if self.config['stack_gr_data']:
                item['len_question_gr'] = len_question_gr
                item['len_history_gr'] = len_history_gr
                item['len_history_sep'] = len_history_sep

            item['round_id'] = torch.LongTensor([dialog['round_id']])

            # get image features
            if not self.config['dataloader_text_only']:
                features, num_boxes, boxes, _, image_target, image_edge_indexes, image_edge_attributes = self._image_features_reader[img_id]
                features, spatials, image_mask, image_target, image_label = encode_image_input(
                    features, num_boxes, boxes, image_target, max_regions=self._max_region_num, mask_prob=0)
            else:
                features = spatials = image_mask = image_target = image_label = torch.tensor([0])
        item['image_feat'] = features
        item['image_loc'] = spatials
        item['image_mask'] = image_mask
        item['image_target'] = image_target
        item['image_label'] = image_label
        item['image_id'] = torch.LongTensor([img_id])
        if self._split == 'train':
            # cheap hack to account for the graph data of the positive and negative examples
            item['image_edge_indices'] = [torch.from_numpy(image_edge_indexes).long(), torch.from_numpy(image_edge_indexes).long()]
            item['image_edge_attributes'] = [torch.from_numpy(image_edge_attributes), torch.from_numpy(image_edge_attributes)]
        elif self._split == 'val':
            # cheap hack to account for the graph data of the positive and negative examples
            item['image_edge_indices'] = [torch.from_numpy(image_edge_indexes).contiguous().long() for _ in range(1000)]
            item['image_edge_attributes'] = [torch.from_numpy(image_edge_attributes).contiguous() for _ in range(1000)]
        else:
            # cheap hack to account for the graph data of the positive and negative examples
            item['image_edge_indices'] = [torch.from_numpy(image_edge_indexes).contiguous().long() for _ in range(100)]
            item['image_edge_attributes'] = [torch.from_numpy(image_edge_attributes).contiguous() for _ in range(100)]

        if self.config['stack_gr_data']:
            item['image_edge_indices'] = torch.stack(item['image_edge_indices'], dim=0)
            item['image_edge_attributes'] = torch.stack(item['image_edge_attributes'], dim=0)
            len_image_gr = torch.tensor(item['image_edge_indices'].size(-1)).unsqueeze(0).repeat(num_options)
            item['len_image_gr'] = len_image_gr

        return item
313
dataloader/dataloader_visdial_dense.py
Normal file
@@ -0,0 +1,313 @@
import torch
import json
import os
import time
import numpy as np
import random
from tqdm import tqdm
import copy
import pyhocon
import glog as log
from collections import OrderedDict
import argparse
import pickle
import torch.utils.data as tud
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))

from utils.data_utils import encode_input, encode_image_input
from dataloader.dataloader_base import DatasetBase

class VisdialDenseDataset(DatasetBase):

    def __init__(self, config):
        super(VisdialDenseDataset, self).__init__(config)
        with open(config.tr_graph_idx_mapping, 'r') as f:
            self.tr_graph_idx_mapping = json.load(f)

        with open(config.val_graph_idx_mapping, 'r') as f:
            self.val_graph_idx_mapping = json.load(f)

        with open(config.test_graph_idx_mapping, 'r') as f:
            self.test_graph_idx_mapping = json.load(f)

        self.question_gr_paths = {
            'train': os.path.join(self.config['visdial_question_adj_matrices'], 'train'),
            'val': os.path.join(self.config['visdial_question_adj_matrices'], 'val'),
            'test': os.path.join(self.config['visdial_question_adj_matrices'], 'test')
        }

        self.history_gr_paths = {
            'train': os.path.join(self.config['visdial_history_adj_matrices'], 'train'),
            'val': os.path.join(self.config['visdial_history_adj_matrices'], 'val'),
            'test': os.path.join(self.config['visdial_history_adj_matrices'], 'test')
        }
    def __getitem__(self, index):
        MAX_SEQ_LEN = self.config['max_seq_len']
        cur_data = None
        cur_dense_annotations = None

        if self._split == 'train':
            cur_data = self.visdial_data_train['data']
            cur_dense_annotations = self.visdial_data_train_dense
            cur_question_gr_path = self.question_gr_paths['train']
            cur_history_gr_path = self.history_gr_paths['train']
            cur_gr_mapping = self.tr_graph_idx_mapping

            if self.config['rlv_hst_only']:
                cur_rlv_hst = self.rlv_hst_train
        elif self._split == 'val':
            cur_data = self.visdial_data_val['data']
            cur_dense_annotations = self.visdial_data_val_dense
            cur_question_gr_path = self.question_gr_paths['val']
            cur_history_gr_path = self.history_gr_paths['val']
            cur_gr_mapping = self.val_graph_idx_mapping

            if self.config['rlv_hst_only']:
                cur_rlv_hst = self.rlv_hst_val
        elif self._split == 'trainval':
            if index >= self.numDataPoints['train']:
                cur_data = self.visdial_data_val['data']
                cur_dense_annotations = self.visdial_data_val_dense
                cur_gr_mapping = self.val_graph_idx_mapping
                index -= self.numDataPoints['train']
                cur_question_gr_path = self.question_gr_paths['val']
                cur_history_gr_path = self.history_gr_paths['val']
                if self.config['rlv_hst_only']:
                    cur_rlv_hst = self.rlv_hst_val
            else:
                cur_data = self.visdial_data_train['data']
                cur_dense_annotations = self.visdial_data_train_dense
                cur_question_gr_path = self.question_gr_paths['train']
                cur_gr_mapping = self.tr_graph_idx_mapping
                cur_history_gr_path = self.history_gr_paths['train']
                if self.config['rlv_hst_only']:
                    cur_rlv_hst = self.rlv_hst_train
        elif self._split == 'test':
            cur_data = self.visdial_data_test['data']
            cur_question_gr_path = self.question_gr_paths['test']
            cur_history_gr_path = self.history_gr_paths['test']
            if self.config['rlv_hst_only']:
                cur_rlv_hst = self.rlv_hst_test
        # number of options to score on
        num_options = self.num_options_dense
        if self._split == 'test' or self.config['validating'] or self.config['predicting']:
            assert num_options == 100
        else:
            assert num_options >= 1 and num_options <= 100

        dialog = cur_data['dialogs'][index]
        cur_questions = cur_data['questions']
        cur_answers = cur_data['answers']
        img_id = dialog['image_id']
        if self._split != 'test':
            graph_idx = cur_gr_mapping[str(img_id)]
        else:
            graph_idx = index

        if self._split != 'test':
            assert img_id == cur_dense_annotations[index]['image_id']
        if self.config['rlv_hst_only']:
            rlv_hst = cur_rlv_hst[str(img_id)]  # [10 for each round, 10 for cap + first 9 rounds]

        if self._split == 'test':
            cur_rounds = len(dialog['dialog'])  # 1, 2, ..., 10
        else:
            cur_rounds = cur_dense_annotations[index]['round_id']  # 1, 2, ..., 10

        # caption
        cur_rnd_utterance = []
        include_caption = True
        if self.config['rlv_hst_only']:
            if self.config['rlv_hst_dense_round']:
                if rlv_hst[0] == 0:
                    include_caption = False
            elif rlv_hst[cur_rounds - 1][0] == 0:
                include_caption = False
        if include_caption:
            sent = dialog['caption'].split(' ')
            tokenized_sent = self.tokenizer.convert_tokens_to_ids(sent)
            cur_rnd_utterance.append(tokenized_sent)
            # tot_len += len(sent) + 1
        for rnd, utterance in enumerate(dialog['dialog'][:cur_rounds]):
            if self.config['rlv_hst_only'] and rnd < cur_rounds - 1:
                if self.config['rlv_hst_dense_round']:
                    if rlv_hst[rnd + 1] == 0:
                        continue
                elif rlv_hst[cur_rounds - 1][rnd + 1] == 0:
                    continue
            # question
            sent = cur_questions[utterance['question']].split(' ')
            tokenized_sent = self.tokenizer.convert_tokens_to_ids(sent)
            cur_rnd_utterance.append(tokenized_sent)

            # answer
            if rnd != cur_rounds - 1:
                sent = cur_answers[utterance['answer']].split(' ')
                tokenized_sent = self.tokenizer.convert_tokens_to_ids(sent)
                cur_rnd_utterance.append(tokenized_sent)

        if self.config['rlv_hst_only']:
            num_rlv_rnds = len(cur_rnd_utterance) - 1
        else:
            num_rlv_rnds = None
        if self._split != 'test':
            gt_option = dialog['dialog'][cur_rounds - 1]['gt_index']
            if self.config['training'] or self.config['debugging']:
                # first select the gt option id, then choose the first num_options inds
                option_inds = []
                option_inds.append(gt_option)
                all_inds = list(range(100))
                all_inds.remove(gt_option)
                # debug
                if num_options < 100:
                    random.shuffle(all_inds)
                    all_inds = all_inds[:(num_options - 1)]
                option_inds.extend(all_inds)
                gt_option = 0
            else:
                option_inds = range(num_options)
            answer_options = [dialog['dialog'][cur_rounds - 1]['answer_options'][k] for k in option_inds]
            if 'relevance' in cur_dense_annotations[index]:
                key = 'relevance'
            else:
                key = 'gt_relevance'
            gt_relevance = torch.Tensor(cur_dense_annotations[index][key])
            gt_relevance = gt_relevance[option_inds]
            assert len(answer_options) == len(option_inds) == num_options
        else:
            answer_options = dialog['dialog'][-1]['answer_options']
            assert len(answer_options) == num_options

        options_all = []
        for answer_option in answer_options:
            cur_option = cur_rnd_utterance.copy()
            cur_option.append(self.tokenizer.convert_tokens_to_ids(cur_answers[answer_option].split(' ')))
            options_all.append(cur_option)
            if not self.config['rlv_hst_only']:
                assert len(cur_option) == 2 * cur_rounds + 1
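During training on dense annotations, the candidate list is re-ordered so the ground-truth answer sits at index 0 and the remaining indices are shuffled negatives, which is why `gt_option` is reset to 0 above. A compact sketch of that sampling scheme with hypothetical values:

```python
import random

num_options, gt_option = 20, 57  # hypothetical: score 20 of 100 candidates

option_inds = [gt_option]          # ground truth always kept, placed first
all_inds = list(range(100))
all_inds.remove(gt_option)
random.shuffle(all_inds)           # the rest become random negatives
option_inds.extend(all_inds[:num_options - 1])

gt_option = 0  # the ground truth now sits at position 0 of the subsample
assert len(option_inds) == num_options
```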
        tokens_all = []
        mask_all = []
        segments_all = []
        sep_indices_all = []
        hist_len_all = []
        tot_len_debug = []

        for opt_id, option in enumerate(options_all):
            option, start_segment = self.pruneRounds(option, self.config['visdial_tot_rounds'])
            tokens, segments, sep_indices, mask, start_question, end_question = encode_input(
                option, start_segment, self.CLS,
                self.SEP, self.MASK, max_seq_len=MAX_SEQ_LEN, mask_prob=0)

            tokens_all.append(tokens)
            mask_all.append(mask)
            segments_all.append(segments)
            sep_indices_all.append(sep_indices)
            hist_len_all.append(torch.LongTensor([len(option) - 1]))

            len_tokens = sum(len(s) for s in option)
            tot_len_debug.append(len_tokens + len(option) + 1)

        tokens_all = torch.cat(tokens_all, 0)
        mask_all = torch.cat(mask_all, 0)
        segments_all = torch.cat(segments_all, 0)
        sep_indices_all = torch.cat(sep_indices_all, 0)
        hist_len_all = torch.cat(hist_len_all, 0)
        question_limits_all = torch.tensor([start_question, end_question]).unsqueeze(0).repeat(num_options, 1)
        if self.config['rlv_hst_only']:
            assert num_rlv_rnds > 0
            hist_idx = [i * 2 for i in range(num_rlv_rnds)]
        else:
            hist_idx = [i * 2 for i in range(cur_rounds)]
        history_sep_indices_all = sep_indices.squeeze(0)[hist_idx].contiguous().unsqueeze(0).repeat(num_options, 1)
        with open(os.path.join(cur_question_gr_path, f'{graph_idx}.pkl'), 'rb') as f:
            question_graphs = pickle.load(f)
        question_graph_round = question_graphs[cur_rounds - 1]
        question_edge_index = []
        question_edge_attribute = []
        for edge_index, edge_attr in question_graph_round:
            question_edge_index.append(edge_index)
            edge_attr_one_hot = np.zeros((len(self.parse_vocab) + 1,), dtype=np.float32)
            edge_attr_one_hot[self.parse_vocab.get(edge_attr, len(self.parse_vocab))] = 1.0
            question_edge_attribute.append(edge_attr_one_hot)
        question_edge_index = np.array(question_edge_index, dtype=np.float64)
        question_edge_attribute = np.stack(question_edge_attribute, axis=0)

        question_edge_indices_all = [torch.from_numpy(question_edge_index).t().long().contiguous() for _ in range(num_options)]
        question_edge_attributes_all = [torch.from_numpy(question_edge_attribute).contiguous() for _ in range(num_options)]

        if self.config['rlv_hst_only']:
            with open(os.path.join(cur_history_gr_path, f'{graph_idx}.pkl'), 'rb') as f:
                _history_edge_incides_round = pickle.load(f)
        else:
            with open(os.path.join(cur_history_gr_path, f'{graph_idx}.pkl'), 'rb') as f:
                _history_edge_incides_all = pickle.load(f)
            _history_edge_incides_round = _history_edge_incides_all[cur_rounds - 1]

        history_edge_index_all = [torch.tensor(_history_edge_incides_round).t().long().contiguous() for _ in range(num_options)]

        if self.config['stack_gr_data']:
            question_edge_indices_all = torch.stack(question_edge_indices_all, dim=0)
            question_edge_attributes_all = torch.stack(question_edge_attributes_all, dim=0)
            history_edge_index_all = torch.stack(history_edge_index_all, dim=0)
        item = {}

        item['tokens'] = tokens_all.unsqueeze(0)  # [1, num_options, max_len]
        item['segments'] = segments_all.unsqueeze(0)
        item['sep_indices'] = sep_indices_all.unsqueeze(0)
        item['mask'] = mask_all.unsqueeze(0)
        item['hist_len'] = hist_len_all.unsqueeze(0)
        item['question_limits'] = question_limits_all
        item['question_edge_indices'] = question_edge_indices_all
        item['question_edge_attributes'] = question_edge_attributes_all
        item['history_edge_indices'] = history_edge_index_all
        item['history_sep_indices'] = history_sep_indices_all

        # add dense annotation fields
        if self._split != 'test':
            item['gt_relevance'] = gt_relevance  # [num_options]
            item['gt_option_inds'] = torch.LongTensor([gt_option])

            # add next sentence labels for training with the nsp loss as well
            nsp_labels = torch.ones(*tokens_all.unsqueeze(0).shape[:-1]).long()
            nsp_labels[:, gt_option] = 0
            item['next_sentence_labels'] = nsp_labels
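The next-sentence-prediction labels follow the BERT convention used here: 0 marks the (dialog, answer) pair that actually belongs together, 1 marks mismatched pairs. A minimal sketch:

```python
import torch

num_options, gt_option = 5, 0  # gt_option is 0 after the reordering above

nsp_labels = torch.ones(1, num_options).long()  # 1 = "not the true answer"
nsp_labels[:, gt_option] = 0                    # 0 marks the positive pair
print(nsp_labels)  # tensor([[0, 1, 1, 1, 1]])
```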
            item['round_id'] = torch.LongTensor([cur_rounds])
        else:
            if 'round_id' in dialog:
                item['round_id'] = torch.LongTensor([dialog['round_id']])
            else:
                item['round_id'] = torch.LongTensor([cur_rounds])

        # get image features
        if not self.config['dataloader_text_only']:
            features, num_boxes, boxes, _, image_target, image_edge_indexes, image_edge_attributes = self._image_features_reader[img_id]
            features, spatials, image_mask, image_target, image_label = encode_image_input(
                features, num_boxes, boxes, image_target, max_regions=self._max_region_num, mask_prob=0)
        else:
            features = spatials = image_mask = image_target = image_label = torch.tensor([0])
        item['image_feat'] = features
        item['image_loc'] = spatials
        item['image_mask'] = image_mask
        item['image_id'] = torch.LongTensor([img_id])
        item['tot_len'] = torch.LongTensor(tot_len_debug)

        item['image_edge_indices'] = [torch.from_numpy(image_edge_indexes).contiguous().long() for _ in range(num_options)]
        item['image_edge_attributes'] = [torch.from_numpy(image_edge_attributes).contiguous() for _ in range(num_options)]

        if self.config['stack_gr_data']:
            item['image_edge_indices'] = torch.stack(item['image_edge_indices'], dim=0)
            item['image_edge_attributes'] = torch.stack(item['image_edge_attributes'], dim=0)

        return item
114
ensemble.py
Normal file
@@ -0,0 +1,114 @@
import os
import os.path as osp
import numpy as np
import json
import argparse
import pyhocon
import glog as log
import torch
from tqdm import tqdm

from utils.data_utils import load_pickle_lines
from utils.init_utils import set_log_file  # needed for set_log_file below (defined in utils.init_utils, as in main.py)
from utils.visdial_metrics import scores_to_ranks


parser = argparse.ArgumentParser(description='Ensemble for VisDial')
parser.add_argument('--exp', type=str, default='test',
                    help='experiment name from .conf')
parser.add_argument('--mode', type=str, default='predict', choices=['eval', 'predict'],
                    help='eval or predict')
parser.add_argument('--ssh', action='store_true',
                    help='whether or not we are executing the command via ssh. '
                         'If set to True, we will not log.info anything to screen and only redirect it to the log file')

if __name__ == '__main__':
    args = parser.parse_args()

    # initialization
    config = pyhocon.ConfigFactory.parse_file("config/ensemble.conf")[args.exp]
    config["log_dir"] = os.path.join(config["log_dir"], args.exp)
    if not os.path.exists(config["log_dir"]):
        os.makedirs(config["log_dir"])

    # set logs
    log_file = os.path.join(config["log_dir"], f'{args.mode}.log')
    set_log_file(log_file, file_only=args.ssh)

    # print environment info
    log.info(f"Running experiment: {args.exp}")
    log.info(f"Results saved to {config['log_dir']}")
    log.info(pyhocon.HOCONConverter.convert(config, "hocon"))

    if isinstance(config['processed'], list):
        assert len(config['models']) == len(config['processed'])
        processed = {model: pcd for model, pcd in zip(config['models'], config['processed'])}
    else:
        processed = {model: config['processed'] for model in config['models']}

    if config['split'] == 'test' and np.any(config['processed']):
        test_data = json.load(open(config['visdial_test_data']))['data']['dialogs']
        imid2rndid = {t['image_id']: len(t['dialog']) for t in test_data}
        del test_data

    # load prediction files
    visdial_outputs = dict()
    if args.mode == 'eval':
        metrics = {}
    for model in config['models']:
        pred_filename = osp.join(config['pred_dir'], model, 'visdial_prediction.pkl')
        pred_dict = {p['image_id']: p for p in load_pickle_lines(pred_filename)}
        log.info(f'Loading {len(pred_dict)} predictions from {pred_filename}')
        visdial_outputs[model] = pred_dict
        if args.mode == 'eval':
            assert len(visdial_outputs[model]) >= num_dialogs
            metric = json.load(open(osp.join(config['pred_dir'], model, "metrics_epoch_best.json")))
            metrics[model] = metric['val']

    image_ids = visdial_outputs[model].keys()
    predictions = []

    # for each dialog
    for image_id in tqdm(image_ids):
        scores = []
        round_id = None

        for model in config['models']:
            pred = visdial_outputs[model][image_id]

            if config['split'] == 'test' and processed[model]:
                # if predicting on processed data, the first few rounds are deleted from some dialogs,
                # so the original round ids can only be found in the original test data
                round_id_in_pred = imid2rndid[image_id]
            else:
                round_id_in_pred = pred['gt_relevance_round_id']

            if not isinstance(round_id_in_pred, int):
                round_id_in_pred = int(round_id_in_pred)
            if round_id is None:
                round_id = round_id_in_pred
            else:
                # make sure all models have the same round_id
                assert round_id == round_id_in_pred
            scores.append(torch.from_numpy(pred['nsp_probs']).unsqueeze(0))

        # ensemble scores
        scores = torch.cat(scores, 0)  # [n_model, num_rounds, num_options]
        scores = torch.sum(scores, dim=0, keepdim=True)  # [1, num_rounds, num_options]
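The ensemble simply sums the per-model positive-class probabilities before ranking; since ranking is invariant to positive scaling, summing is equivalent to averaging here. The repository's `scores_to_ranks` performs the final conversion; the argsort-based stand-in below only illustrates the idea and is not the repo's implementation:

```python
import torch

# Two hypothetical models scoring 1 round x 4 answer options each.
model_a = torch.tensor([[[0.1, 0.6, 0.2, 0.1]]])
model_b = torch.tensor([[[0.2, 0.3, 0.4, 0.1]]])

scores = torch.cat([model_a, model_b], 0)        # [n_model, num_rounds, num_options]
scores = torch.sum(scores, dim=0, keepdim=True)  # [1, num_rounds, num_options]

# Higher summed probability -> better rank (1 is best).
ranks = scores.argsort(dim=-1, descending=True).argsort(dim=-1) + 1
print(ranks)  # tensor([[[3, 1, 2, 4]]])
```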
        if scores.size(0) > 1:
            scores = scores[round_id - 1].unsqueeze(0)
        ranks = scores_to_ranks(scores)  # [eval_batch_size, num_rounds, num_options]
        ranks = ranks.squeeze(1)
        prediction = {
            "image_id": image_id,
            "round_id": round_id,
            "ranks": ranks[0].tolist()
        }
        predictions.append(prediction)

    filename = osp.join(config['log_dir'], f'{config["split"]}_ensemble_preds.json')
    with open(filename, 'w') as f:
        json.dump(predictions, f)
    log.info(f'{len(predictions)} predictions saved to {filename}')
199
main.py
Normal file
@@ -0,0 +1,199 @@
from utils.init_utils import load_runner, load_dataset, set_random_seed, set_training_steps, initialize_from_env, set_log_file, copy_file_to_log
import torch.distributed as dist
import torch.nn as nn
import torch.multiprocessing as mp
import torch
import os
import sys
import argparse
import pyhocon
import glog as log
import socket
import getpass

try:
    from apex.parallel import DistributedDataParallel as DDP
    from apex import amp
except ModuleNotFoundError:
    print('apex not found')

parser = argparse.ArgumentParser(description='Main script for VD-GR')
parser.add_argument(
    '--model',
    type=str,
    default='vdgr/P1',
    help='model name to train or test')

parser.add_argument(
    '--mode',
    type=str,
    default='train',
    help='train, eval, predict or debug')

parser.add_argument(
    '--wandb_project',
    type=str,
    default='VD-GR'
)

parser.add_argument(
    '--wandb_mode',
    type=str,
    default='online',
    choices=['online', 'offline', 'disabled', 'run', 'dryrun']
)

parser.add_argument(
    '--tag',
    type=str,
    default='K2',
    help="Tag to differentiate the different runs"
)

parser.add_argument(
    '--eval_dir',
    type=str,
    default='',
    help="Directory of a trained model to evaluate"
)

parser.add_argument('--ssh', action='store_true',
                    help='whether or not we are executing the command via ssh. '
                         'If set to True, we will not log.info anything to screen and only redirect it to the log file')

def main(gpu, config, args):
    config['training'] = args.mode == 'train'
    config['validating'] = args.mode == 'eval'
    config['debugging'] = args.mode == 'debug'
    config['predicting'] = args.mode == 'predict'
    config['wandb_project'] = args.wandb_project
    config['wandb_mode'] = args.wandb_mode

    if config['parallel'] and config['dp_type'] != 'dp':
        config['rank'] = gpu
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(config['master_port'])
        dist.init_process_group(
            backend='nccl',
            world_size=config['num_gpus'],
            rank=gpu
        )
        config['display'] = gpu == 0
        if config['dp_type'] == 'apex':
            torch.cuda.set_device(gpu)
    else:
        config['display'] = True
    if config['debugging'] or (config['parallel'] and config['dp_type'] != 'dp'):
        config['num_workers'] = 0
    else:
        config['num_workers'] = 0
    # set logs
    log_file = os.path.join(config["log_dir"], f'{args.mode}.log')
    set_log_file(log_file, file_only=args.ssh)

    # print environment info
    if config['display']:
        log.info('Host: {}, user: {}, CUDA_VISIBLE_DEVICES: {}, cwd: {}'.format(
            socket.gethostname(), getpass.getuser(), os.environ.get('CUDA_VISIBLE_DEVICES', ''), os.getcwd()))
        log.info('Command line is: {}'.format(' '.join(sys.argv)))

        if config['parallel'] and config['dp_type'] != 'dp':
            log.info(
                f'World_size: {config["num_gpus"]}, cur rank: {config["rank"]}')
        log.info(f"Running experiment: {args.model}")
        log.info(f"Results saved to {config['log_dir']}")

    # initialization
    if config['display'] and config['training']:
        copy_file_to_log(config['log_dir'])
    set_random_seed(config['random_seed'])

    device = torch.device(f"cuda:{gpu}")
    if config["use_cpu"]:
        device = torch.device("cpu")
    config['device'] = device

    # prepare dataset
    dataset, dataset_eval = load_dataset(config)

    # set training steps
    if not config['validating'] or config['parallel']:
        config = set_training_steps(config, len(dataset))

    if config['display']:
        log.info(pyhocon.HOCONConverter.convert(config, "hocon"))

    # load runner
    runner = load_runner(config)
    # apex
    if config['dp_type'] == 'apex':
        runner.model, runner.optimizer = amp.initialize(runner.model,
                                                        runner.optimizer,
                                                        opt_level="O1")
    # parallel
    if config['parallel']:
        if config['dp_type'] == 'dp':
            runner.model = nn.DataParallel(runner.model)
            runner.model.to(config['device'])
        elif config['dp_type'] == 'apex':
            runner.model = DDP(runner.model)
        elif config['dp_type'] == 'ddp':
            torch.cuda.set_device(gpu)
            runner.model = runner.model.to(gpu)
            runner.model = nn.parallel.DistributedDataParallel(
                runner.model,
                device_ids=[gpu],
                output_device=gpu,
                find_unused_parameters=True)
        else:
            raise ValueError(f'Unrecognized dp_type: {config["dp_type"]}')
    if config['training'] or config['debugging']:
        runner.load_pretrained_vilbert()
        runner.train(dataset, dataset_eval)
    else:
        if config['loads_start_path']:
            runner.load_pretrained_vilbert()
        else:
            runner.load_ckpt_best()

        metrics_results = {}
        if config['predicting']:
            eval_splits = [config['predict_split']]
        else:
            eval_splits = ['val']
            if config['model_type'] == 'conly' and not config['train_each_round']:
                eval_splits.append('test')
        for split in eval_splits:
            if config['display']:
                log.info(f'Results on {split} split of the best epoch')
            if dataset_eval is None:
                dataset_to_eval = dataset
            else:
                dataset_to_eval = dataset_eval
            dataset_to_eval.split = split
            _, metrics_results[split] = runner.evaluate(
                dataset_to_eval, eval_visdial=True)
            if not config['predicting'] and config['display']:
                runner.save_eval_results(split, 'best', metrics_results)

    if config['parallel'] and config['dp_type'] != 'dp':
        dist.destroy_process_group()


if __name__ == '__main__':
    args = parser.parse_args()
    # initialization
    model_type, model_name = args.model.split('/')
    config = initialize_from_env(
        model_name, args.mode, args.eval_dir, model_type, tag=args.tag)
    if config['num_gpus'] > 1:
        config['parallel'] = True
        if config['dp_type'] == 'dp':
            main(0, config, args)
        else:
            mp.spawn(main, nprocs=config['num_gpus'], args=(config, args))
    else:
        config['parallel'] = False
        main(0, config, args)
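Note the calling convention that makes `main(gpu, config, args)` work in both branches: `mp.spawn` prepends the process index to `args`, while the single-process paths pass `gpu=0` explicitly. A stripped-down illustration:

```python
import torch.multiprocessing as mp

def worker(rank, msg):
    # rank is injected by mp.spawn: 0 .. nprocs-1
    print(f'process {rank}: {msg}')

if __name__ == '__main__':
    # runs worker(0, 'hello') and worker(1, 'hello') in two processes
    mp.spawn(worker, nprocs=2, args=('hello',))
```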
0
misc/.gitkeep
Normal file
BIN
misc/teaser_1.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 2.2 MiB |
BIN
misc/teaser_2.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 3.6 MiB |
BIN
misc/usa.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 11 KiB |
0
models/__init__.py
Normal file
830
models/runner.py
Normal file
@@ -0,0 +1,830 @@
import os
import os.path as osp
import json
from collections import deque
import time
import re
import shutil
import glob
import pickle
import gc
import numpy as np
import glog as log
try:
    from apex import amp
except ModuleNotFoundError:
    print('apex not found')

import torch
import torch.utils.data as tud
import torch.nn.functional as F
import torch.distributed as dist

from utils.data_utils import load_pickle_lines
from utils.visdial_metrics import SparseGTMetrics, NDCG, scores_to_ranks
import wandb

class Runner:
    def __init__(self, config):
        self.config = config
        if 'rank' in config:
            self.gpu_rank = config['rank']
        else:
            self.gpu_rank = 0

        self.epoch_idx = 0
        self.max_metric = 0.
        self.max_metric_epoch_idx = 0
        self.na_str = 'N/A'

        if self.config["max_ckpt_to_keep"] > 0:
            self.checkpoint_queue = deque(
                [], maxlen=config["max_ckpt_to_keep"])
            self.metrics_queue = deque([], maxlen=config["max_ckpt_to_keep"])

        self.setup_wandb()

    def setup_wandb(self):
        if self.gpu_rank == 0:
            print("[INFO] Set wandb logging on rank {}".format(0))
            run = wandb.init(
                project=self.config['wandb_project'], config=self.config, mode=self.config['wandb_mode'])
        else:
            run = None
        self.run = run

    def forward(self, batch, eval_visdial=False):
        raise NotImplementedError
    def train(self, dataset, dataset_eval=None):
        # wandb.login()
        if os.path.exists(self.config['log_dir']) or self.config['loads_ckpt'] or self.config['loads_best_ckpt']:
            self.load_ckpt()

        if self.config['use_trainval']:
            dataset.split = 'trainval'
        else:
            dataset.split = 'train'
        batch_size = self.config['batch_size']
        if self.config['parallel'] and self.config['dp_type'] != 'dp':
            sampler_tr = tud.distributed.DistributedSampler(
                dataset,
                num_replicas=self.config['num_gpus'],
                rank=self.gpu_rank
            )
        else:
            sampler_tr = None

        data_loader_tr = tud.DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=self.config['training'] and not self.config['parallel'],
            collate_fn=dataset.collate_fn,
            num_workers=self.config['num_workers'],
            sampler=sampler_tr
        )

        start_epoch_idx = self.epoch_idx
        num_iter_epoch = self.config['num_iter_per_epoch']
        if self.config['display']:
            log.info(f'{num_iter_epoch} iter per epoch.')

        # eval before training under two circumstances:
        # for dense finetuning, eval ndcg before the first epoch
        # for mrr training, when continuing training and the last epoch has not been evaluated yet
        eval_dense_at_first = self.config['train_on_dense'] and self.config['skip_mrr_eval'] and start_epoch_idx == 0

        if (eval_dense_at_first or (self.config['eval_at_start'] and len(self.metrics_queue) == 0 and start_epoch_idx > 0)):
            if eval_dense_at_first:
                iter_now = 0
            else:
                iter_now = max(num_iter_epoch * start_epoch_idx, 0)

            if dataset_eval is None:
                dataset.split = 'val'
                dataset_to_eval = dataset
            else:
                dataset_to_eval = dataset_eval

            metrics_results = {}
            metrics_to_maximize, metrics_results['val'] = self.evaluate(
                dataset_to_eval, iter_now)
            if eval_dense_at_first:
                self.max_metric = metrics_to_maximize
                self.max_metric_epoch_idx = -1
            else:
                if self.config['display']:
                    self.save_eval_results(
                        'val', start_epoch_idx - 1, metrics_results)
                if metrics_to_maximize > self.max_metric:
                    self.max_metric = metrics_to_maximize
                    self.max_metric_epoch_idx = start_epoch_idx - 1
                    self.copy_best_results('val', start_epoch_idx - 1)
                    self.copy_best_predictions('val')
            if dataset_eval is None:
                if self.config['use_trainval']:
                    dataset.split = 'trainval'
                else:
                    dataset.split = 'train'

        num_epochs = self.config['num_epochs']

        for epoch_idx in range(start_epoch_idx, num_epochs):
            if self.config['parallel'] and self.config['dp_type'] != 'dp':
                sampler_tr.set_epoch(epoch_idx)

            self.epoch_idx = epoch_idx

            if self.config['display']:
                log.info(f'starting epoch {epoch_idx}')
                log.info('training')

            self.model.train()

            num_batch = 0
            next_logging_pct = .1
            next_evaluating_pct = self.config["next_evaluating_pct"] + .1
            start_time = time.time()
            self.optimizer.zero_grad()

            for batch in data_loader_tr:
                if self.config['eval_before_training']:
                    log.info('Skipping straight to evaluation...')
                    break
                num_batch += 1
                pct = num_batch / num_iter_epoch * 100
                iter_now = num_iter_epoch * epoch_idx + num_batch

                output = self.forward(batch)
                losses = output['losses']

                # optimizer step
                losses['tot_loss'] /= self.config['batch_multiply']
                # debug
                if self.config['debugging']:
                    log.info('try backward')
                if self.config['dp_type'] == 'apex':
                    with amp.scale_loss(losses['tot_loss'], self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    losses['tot_loss'].backward()
                if self.config['debugging']:
                    log.info('backward done')

                if iter_now % self.config['batch_multiply'] == 0:
                    self.optimizer.step()
                    self.optimizer.zero_grad()
                    self.scheduler.step()
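`batch_multiply` implements gradient accumulation: each loss is scaled down by the factor and the optimizer only steps every `batch_multiply` iterations, so the effective batch size becomes `batch_size * batch_multiply`. A minimal sketch of the same pattern:

```python
import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
batch_multiply = 4

optimizer.zero_grad()
for it, x in enumerate(torch.randn(8, 4), start=1):
    loss = model(x).sum() / batch_multiply  # scale so accumulated grads average out
    loss.backward()                         # gradients accumulate across iterations
    if it % batch_multiply == 0:
        optimizer.step()
        optimizer.zero_grad()
```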
                # display and eval
                if pct >= next_logging_pct:
                    if self.config['display']:
                        loss_to_print = ''
                        for key in losses:
                            if losses[key] is not None and isinstance(losses[key], torch.Tensor):
                                loss_to_print += f'[{key}: {losses[key].item():.4f}]'
                        print(
                            f'[{int(pct)}%][Epoch: {epoch_idx + 1}/{num_epochs}][Iter: {num_batch}/{len(data_loader_tr)}] [time: {time.time() - start_time:.2f}] {loss_to_print}'
                        )

                    next_logging_pct += self.config["next_logging_pct"]

                    if self.config['debugging']:
                        break

                if pct >= next_evaluating_pct:
                    next_evaluating_pct += self.config["next_evaluating_pct"]

                if self.run:
                    if self.config['train_on_dense']:
                        self.run.log(
                            {
                                "Train/dense_loss": losses['dense_loss'],
                                "Train/total_loss": losses['tot_loss'],
                            },
                            step=iter_now
                        )
                    else:
                        self.run.log(
                            {
                                "Train/lm_loss": losses['lm_loss'],
                                "Train/img_loss": losses['img_loss'],
                                "Train/nsp_loss": losses['nsp_loss'],
                                "Train/total_loss": losses['tot_loss'],
                            },
                            step=iter_now
                        )

                    lr_gnn, lr_bert = self.scheduler.get_lr()[0], self.scheduler.get_lr()[1]
                    self.run.log(
                        {
                            "Train/lr_gnn": lr_gnn,
                            "Train/lr_bert": lr_bert,
                        },
                        step=iter_now
                    )
                del losses
                # debug
                torch.cuda.empty_cache()

            if self.config['display']:
                log.info(
                    f'100%,\ttime:\t{time.time() - start_time:.2f}'
                )
            ckpt_path = self.save_ckpt()

            if not self.config['skip_visdial_eval'] and self.epoch_idx % self.config['eval_visdial_every'] == 0:

                iter_now = num_iter_epoch * (epoch_idx + 1)

                if dataset_eval is None:
                    dataset.split = 'val'
                    dataset_to_eval = dataset
                else:
                    dataset_to_eval = dataset_eval
                metrics_results = {}
                metrics_to_maximize, metrics_results['val'] = self.evaluate(
                    dataset_to_eval, iter_now)
                if dataset_eval is None:
                    if self.config['use_trainval']:
                        dataset.split = 'trainval'
                    else:
                        dataset.split = 'train'
                if self.config['display']:
                    self.save_eval_results('val', epoch_idx, metrics_results)

                if self.config['display']:

                    if metrics_to_maximize > self.max_metric:
                        self.max_metric = metrics_to_maximize
                        self.max_metric_epoch_idx = epoch_idx
                        self.copy_best_results('val', epoch_idx)
                        self.copy_best_predictions('val')

                    elif not self.config['parallel'] and epoch_idx - self.max_metric_epoch_idx > self.config["early_stop_epoch"]:
                        log.info('Early stop.')
                        break

                    if self.run:
                        self.run.log(
                            {"Val/metric_best": self.max_metric}, step=iter_now)

            if self.config['parallel']:
                if self.config['dp_type'] == 'dp':
                    gc.collect()
                    torch.cuda.empty_cache()
                else:
                    dist.barrier()
                    log.info('Rank {} passed barrier...'.format(self.gpu_rank))

            if self.config['stop_epochs'] >= 0 and epoch_idx + 1 >= self.config['stop_epochs']:
                if self.config['display']:
                    log.info('Stop for reaching stop_epochs.')
                break
    def evaluate(self, dataset, training_iter=None, eval_visdial=True):
        # create files to save the output
        if self.config['predicting']:
            visdial_file_name = None
            if self.config['save_score']:
                visdial_file_name = osp.join(
                    self.config['log_dir'], 'visdial_prediction.pkl')
                if osp.exists(visdial_file_name):
                    dialogs_predicted = load_pickle_lines(
                        visdial_file_name)
                    dialogs_predicted = [d['image_id']
                                         for d in dialogs_predicted]
                else:
                    dialogs_predicted = []
                f_visdial = open(visdial_file_name, 'ab')

            else:
                visdial_file_name = osp.join(
                    self.config['log_dir'], 'visdial_prediction.jsonlines')
                if self.config['parallel'] and self.config['dp_type'] != 'dp':
                    visdial_file_name = visdial_file_name.replace(
                        '.jsonlines', f'_{self.config["rank"]}of{self.config["num_gpus"]}.jsonlines')
                if osp.exists(visdial_file_name):
                    dialogs_predicted_visdial = [json.loads(
                        line)['image_id'] for line in open(visdial_file_name)]
                    f_visdial = open(visdial_file_name, 'a')
                else:
                    dialogs_predicted_visdial = []
                    f_visdial = open(visdial_file_name, 'w')

                dialogs_predicted = dialogs_predicted_visdial

            if len(dialogs_predicted) > 0:
                log.info(f'Found {len(dialogs_predicted)} predicted results.')

            if self.config['display']:
                if visdial_file_name is not None:
                    log.info(
                        f'VisDial predictions saved to {visdial_file_name}')

        elif self.config['display']:
            if self.config['continue_evaluation']:
                predicted_files = os.listdir(
                    osp.join(self.config['visdial_output_dir'], dataset.split))
                dialogs_predicted = [
                    int(re.match(r'(\d+).npz', p).group(1)) for p in predicted_files]
            else:
                if osp.exists(osp.join(self.config['visdial_output_dir'], dataset.split)):
                    shutil.rmtree(
                        osp.join(self.config['visdial_output_dir'], dataset.split))
                os.makedirs(
                    osp.join(self.config['visdial_output_dir'], dataset.split))

                dialogs_predicted = []
            log.info(f'Found {len(dialogs_predicted)} predicted results.')

        if self.config['parallel'] and self.config['dp_type'] != 'dp':
            sampler_val = tud.distributed.DistributedSampler(
                dataset,
                num_replicas=self.config['num_gpus'],
                rank=self.gpu_rank
            )

            sampler_val.set_epoch(self.epoch_idx)
        else:
            sampler_val = None

        data_loader_val = tud.DataLoader(
            dataset=dataset,
            batch_size=self.config['eval_batch_size'],
            shuffle=False,
            collate_fn=dataset.collate_fn,
            num_workers=self.config['num_workers'],
            sampler=sampler_val
        )
        self.model.eval()

        with torch.no_grad():
            if self.config['display']:
                log.info(f'Evaluating {len(dataset)} samples')

            next_logging_pct = self.config["next_logging_pct"] + .1
            if self.config['parallel'] and self.config['dp_type'] == 'dp':
                num_batch_tot = int(
                    np.ceil(len(dataset) / self.config['eval_batch_size']))
            else:
                num_batch_tot = int(np.ceil(
                    len(dataset) / (self.config['eval_batch_size'] * self.config['num_gpus'])))
            num_batch = 0
            if dataset.split == 'val':
                num_options = self.config["num_options"]
                if self.config['skip_mrr_eval']:
                    num_rounds = 1
                else:
                    num_rounds = 10
            elif dataset.split == 'test':
                num_options = 100
                num_rounds = 1
            if self.gpu_rank == 0:
                start_time = time.time()

            for batch in data_loader_val:
                num_batch += 1
                # skip dialogs that have already been predicted
                if self.config['predicting']:
                    image_ids = batch['image_id'].tolist()
                    skip_batch = True
                    for image_id in image_ids:
                        if image_id not in dialogs_predicted:
                            skip_batch = False
                    if skip_batch:
                        continue
                output = self.forward(
                    batch, eval_visdial=eval_visdial)

                # visdial evaluation
                if eval_visdial:
                    img_ids = batch['image_id'].tolist()
                    batch_size = len(img_ids)
                    if not self.config['skip_ndcg_eval']:
                        gt_relevance_round_id = batch['round_id'].tolist()

                    # [batch_size * num_rounds * num_options, 2]
                    nsp_scores = output['nsp_scores']
                    nsp_probs = F.softmax(nsp_scores, dim=1)
                    assert nsp_probs.shape[-1] == 2
                    # num_dim=2, 0 for positive, 1 for negative
                    nsp_probs = nsp_probs[:, 0]
                    nsp_probs = nsp_probs.view(
                        batch_size, num_rounds, num_options)
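The NSP head emits two logits per (round, option) pair, with column 0 being the "is the true answer" class, so a softmax followed by slicing column 0 leaves one matching probability per candidate, reshaped to [batch, rounds, options]. A sketch with random logits:

```python
import torch
import torch.nn.functional as F

batch_size, num_rounds, num_options = 1, 1, 3
nsp_scores = torch.randn(batch_size * num_rounds * num_options, 2)

nsp_probs = F.softmax(nsp_scores, dim=1)  # each row sums to 1
nsp_probs = nsp_probs[:, 0]               # keep P(option matches the dialog)
nsp_probs = nsp_probs.view(batch_size, num_rounds, num_options)
print(nsp_probs.shape)  # torch.Size([1, 1, 3])
```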
|
||||||
|
|
||||||
|
# could be predicting or evaluating
|
||||||
|
if dataset.split == 'val':
|
||||||
|
if self.config['skip_ndcg_eval']:
|
||||||
|
gt_option_inds = batch['gt_option_inds']
|
||||||
|
|
||||||
|
for b in range(batch_size):
|
||||||
|
filename = osp.join(
|
||||||
|
self.config['visdial_output_dir'], dataset.split, f'{img_ids[b]}.npz')
|
||||||
|
if not osp.exists(filename):
|
||||||
|
np.savez(
|
||||||
|
filename,
|
||||||
|
nsp_probs=nsp_probs[b].cpu().numpy(),
|
||||||
|
gt_option_inds=gt_option_inds[b].cpu().numpy()
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# [batch_size, num_rounds]
|
||||||
|
gt_option_inds = batch['gt_option_inds']
|
||||||
|
# [batch_size, num_options]
|
||||||
|
gt_relevance = batch['gt_relevance']
|
||||||
|
|
||||||
|
for b in range(batch_size):
|
||||||
|
filename = osp.join(
|
||||||
|
self.config['visdial_output_dir'], dataset.split, f'{img_ids[b]}.npz')
|
||||||
|
if not osp.exists(filename):
|
||||||
|
np.savez(filename,
|
||||||
|
nsp_probs=nsp_probs[b].cpu().numpy(),
|
||||||
|
gt_option_inds=gt_option_inds[b].cpu(
|
||||||
|
).numpy(),
|
||||||
|
gt_relevance=gt_relevance[b].cpu(
|
||||||
|
).numpy(),
|
||||||
|
gt_relevance_round_id=gt_relevance_round_id[b])
|
||||||
|
|
||||||
|
# must be predicting
|
||||||
|
if dataset.split == 'test':
|
||||||
|
if self.config['save_score']:
|
||||||
|
for b in range(batch_size):
|
||||||
|
prediction = {
|
||||||
|
"image_id": img_ids[b],
|
||||||
|
"nsp_probs": nsp_probs[b].cpu().numpy(),
|
||||||
|
"gt_relevance_round_id": gt_relevance_round_id[b]
|
||||||
|
}
|
||||||
|
pickle.dump(prediction, f_visdial)
|
||||||
|
else:
|
||||||
|
# [eval_batch_size, num_rounds, num_options]
|
||||||
|
ranks = scores_to_ranks(nsp_probs)
|
||||||
|
ranks = ranks.squeeze(1)
|
||||||
|
for b in range(batch_size):
|
||||||
|
prediction = {
|
||||||
|
"image_id": img_ids[b],
|
||||||
|
"round_id": gt_relevance_round_id[b],
|
||||||
|
"ranks": ranks[b].tolist()
|
||||||
|
}
|
||||||
|
f_visdial.write(json.dumps(prediction) + '\n')
|
||||||
|
|
||||||
|
            # debug
            if self.config['debugging']:
                break

            pct = num_batch / num_batch_tot * 100
            if pct >= next_logging_pct:
                if self.config['display'] and self.gpu_rank == 0:
                    log.info(f'{int(pct)}%,\ttime:\t{time.time() - start_time:.2f}')
                next_logging_pct += self.config["next_logging_pct"]
                # debug
                if self.config['debugging']:
                    break

        if self.config['display'] and self.gpu_rank == 0:
            pct = num_batch / num_batch_tot * 100
            log.info(f'{int(pct)}%,\ttime:\t{time.time() - start_time:.2f}')

        if not self.config['validating']:
            self.model.train()

        if self.config['parallel'] and self.config['dp_type'] != 'dp':
            dist.barrier()

        print(f'{self.gpu_rank} passed barrier')

        if self.config['predicting']:
            f_visdial.close()
            if not self.config['save_score']:
                all_visdial_predictions = [json.loads(line) for line in open(visdial_file_name)]
                if self.config['predict_split'] == 'test' and len(all_visdial_predictions) == self.config['num_test_dialogs']:
                    visdial_file_name = visdial_file_name.replace('jsonlines', 'json')
                    with open(visdial_file_name, 'w') as f_visdial:
                        json.dump(all_visdial_predictions, f_visdial)
                    log.info(f'Prediction for submission saved to {visdial_file_name}.')
            return None, None

        if self.config['display']:
            if dataset.split == 'val' and eval_visdial:
                if not self.config['skip_mrr_eval']:
                    sparse_metrics = SparseGTMetrics()
                if not self.config['skip_ndcg_eval']:
                    ndcg = NDCG()

                visdial_output_filenames = glob.glob(
                    osp.join(self.config['visdial_output_dir'], dataset.split, '*.npz'))
                log.info(f'Calculating visdial metrics for {len(visdial_output_filenames)} dialogs')
                for visdial_output_filename in visdial_output_filenames:
                    output = np.load(visdial_output_filename)
                    nsp_probs = torch.from_numpy(output['nsp_probs']).unsqueeze(0)
                    if not self.config['skip_ndcg_eval']:
                        gt_relevance = torch.from_numpy(output['gt_relevance']).unsqueeze(0)
                    if not self.config['skip_mrr_eval']:
                        gt_option_inds = torch.from_numpy(output['gt_option_inds']).unsqueeze(0)
                        sparse_metrics.observe(nsp_probs, gt_option_inds)
                        if not self.config['skip_ndcg_eval']:
                            gt_relevance_round_id = output['gt_relevance_round_id']
                            nsp_probs_dense = nsp_probs[0, gt_relevance_round_id - 1, :].unsqueeze(0)
                    else:
                        nsp_probs_dense = nsp_probs.squeeze(0)  # [1, 100]
                    if not self.config['skip_ndcg_eval']:
                        ndcg.observe(nsp_probs_dense, gt_relevance)
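            # Note (added): gt_relevance_round_id is 1-indexed in the VisDial
            # dense annotations, hence the "- 1" above when selecting the
            # densely annotated round.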
            # visdial eval output
            visdial_metrics = {}
            if dataset.split == 'val' and eval_visdial:
                if not self.config['skip_mrr_eval']:
                    visdial_metrics.update(sparse_metrics.retrieve(reset=True))
                if not self.config['skip_ndcg_eval']:
                    visdial_metrics.update(ndcg.retrieve(reset=True))

            to_print = ''
            for metric_name, metric_value in visdial_metrics.items():
                if 'round' not in metric_name:
                    to_print += f"\n{metric_name}: {metric_value}"
                    if training_iter is not None and self.run:
                        self.run.log({'Val/' + metric_name: metric_value}, step=training_iter)
            log.info(to_print)

            if self.config['metrics_to_maximize'] in visdial_metrics:
                metrics_to_maximize = visdial_metrics[self.config['metrics_to_maximize']]
            else:
                metrics_to_maximize = None
            torch.cuda.empty_cache()
            return metrics_to_maximize, visdial_metrics
        else:
            torch.cuda.empty_cache()
            return None, None

    def save_eval_results(self, split, epoch_idx, metrics_results):
        metrics_filename = osp.join(
            self.config['log_dir'], f'metrics_epoch_{epoch_idx}.json')
        with open(metrics_filename, 'w') as f:
            json.dump(metrics_results, f)
        log.info(f'Results of metrics saved to {metrics_filename}')

        if self.config["max_ckpt_to_keep"] > 0:
            if len(self.metrics_queue) == self.metrics_queue.maxlen:
                todel = self.metrics_queue.popleft()
                os.remove(todel)
            self.metrics_queue.append(metrics_filename)

        if epoch_idx == 'best':
            self.copy_best_predictions(split)

    def copy_best_results(self, split, epoch_idx):
        to_print = 'Copy '

        if not self.config['skip_saving_ckpt']:
            ckpt_path = osp.join(
                self.config['log_dir'], f'epoch_{epoch_idx}.ckpt')
            best_ckpt_path = ckpt_path.replace(f'{epoch_idx}.ckpt', 'best.ckpt')
            shutil.copyfile(ckpt_path, best_ckpt_path)
            to_print += best_ckpt_path + ' '

        metrics_filename = osp.join(
            self.config['log_dir'], f'metrics_epoch_{epoch_idx}.json')
        best_metric_filename = metrics_filename.replace(f'{epoch_idx}.json', 'best.json')
        shutil.copyfile(metrics_filename, best_metric_filename)
        to_print += best_metric_filename + ' '

        log.info(to_print)

    def copy_best_predictions(self, split):
        to_print = 'Copy '

        visdial_output_dir = osp.join(self.config['visdial_output_dir'], split)
        if osp.exists(visdial_output_dir):
            dir_best = visdial_output_dir.replace('output', 'output_best')
            if osp.exists(dir_best):
                shutil.rmtree(dir_best)
            shutil.copytree(visdial_output_dir, dir_best)
            to_print += dir_best + ' '

        log.info(to_print)

    def get_ckpt(self):
        ckpt = {
            'epoch_idx': self.epoch_idx,
            'max_metric': self.max_metric,
            'seed': self.config['random_seed'],
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict()
        }
        if self.config['parallel']:
            ckpt['model_state_dict'] = self.model.module.state_dict()
        else:
            ckpt['model_state_dict'] = self.model.state_dict()
        if self.config['dp_type'] == 'apex':
            ckpt['amp'] = amp.state_dict()
        return ckpt

    def set_ckpt(self, ckpt_dict):
        if not self.config['restarts']:
            self.epoch_idx = ckpt_dict.get('epoch_idx', -1) + 1

        if not self.config['resets_max_metric']:
            self.max_metric = ckpt_dict.get('max_metric', -1)

        if self.config['parallel']:
            model = self.model.module
        else:
            model = self.model

        model_state_dict = model.state_dict()
        former_dict = {
            k: v for k, v in ckpt_dict['model_state_dict'].items() if k in model_state_dict}

        if self.config['display']:
            log.info("number of keys transferred: %d" % len(former_dict))
        assert len(former_dict.keys()) > 0

        model_state_dict.update(former_dict)

        model.load_state_dict(model_state_dict)
        if self.config['display']:
            log.info('loaded model')
        del model_state_dict, former_dict

        if not self.config['validating'] and not (self.config['uses_new_optimizer'] or self.config['sets_new_lr']):
            if 'optimizer' in ckpt_dict:
                self.optimizer.load_state_dict(ckpt_dict['optimizer'])
                if self.config['display']:
                    log.info('loaded optimizer')
            if 'scheduler' in ckpt_dict:
                self.scheduler.last_epoch = ckpt_dict['epoch_idx'] * \
                    self.config['num_iter_per_epoch']
                self.scheduler.load_state_dict(ckpt_dict['scheduler'])

        if 'amp' in ckpt_dict and self.config['dp_type'] == 'apex':
            amp.load_state_dict(ckpt_dict['amp'])

        del ckpt_dict

        torch.cuda.empty_cache()

    def save_ckpt(self):
        ckpt_path = f'{self.config["log_dir"]}/epoch_{self.epoch_idx}.ckpt'
        log.info(f'saving checkpoint {ckpt_path}')
        ckpt = self.get_ckpt()
        if self.config['skip_saving_ckpt']:
            return ckpt_path
        torch_version = float(torch.__version__[:3])
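        # Caveat (added note): truncating to three characters parses a version
        # string such as '1.11.0' as 1.1, so recent PyTorch versions take the
        # plain torch.save branch below; the legacy serialization flag (which
        # keeps checkpoints loadable with torch < 1.6) is only set for versions
        # parsed as greater than 1.4.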
        if torch_version - 1.4 > 1e-3:
            torch.save(ckpt, f=ckpt_path, _use_new_zipfile_serialization=False)
        else:
            torch.save(ckpt, f=ckpt_path)
        del ckpt

        if not (self.config['parallel'] and self.config['dp_type'] in ['ddp', 'apex']):
            torch.cuda.empty_cache()

        if self.config["max_ckpt_to_keep"] > 0:
            if len(self.checkpoint_queue) == self.checkpoint_queue.maxlen:
                todel = self.checkpoint_queue.popleft()
                os.remove(todel)
            self.checkpoint_queue.append(ckpt_path)

        return ckpt_path

    def save_ckpt_best(self):
        ckpt_path = f'{self.config["log_dir"]}/epoch_best.ckpt'
        log.info(f'saving checkpoint {ckpt_path}')
        ckpt = self.get_ckpt()
        torch.save(ckpt, f=ckpt_path)
        del ckpt
        return ckpt_path

    def load_ckpt_best(self):
        ckpt_path = f'{osp.dirname(self.config["log_dir"])}/epoch_best.ckpt'
        if not osp.exists(ckpt_path):
            ckpt_paths = [path for path in os.listdir(f'{self.config["log_dir"]}/')
                          if path.endswith('.ckpt') and 'best' not in path]
            if len(ckpt_paths) == 0:
                if self.config['display']:
                    log.info(f'No .ckpt found in {self.config["log_dir"]}')
                return

            def sort_func(x):
                return int(re.search(r"(\d+)", x).groups()[0])
            ckpt_path = f'{self.config["log_dir"]}/{sorted(ckpt_paths, key=sort_func, reverse=True)[0]}'
        if self.config['display']:
            log.info(f'loading checkpoint {ckpt_path}')
        map_location = {'cuda:0': f'cuda:{self.gpu_rank}'}
        self.set_ckpt(torch.load(ckpt_path, map_location=map_location))

    def load_ckpt(self, ckpt_path=None):
        if not ckpt_path:
            if self.config['validating'] or self.config['loads_best_ckpt']:
                ckpt_path = f'{self.config["log_dir"]}/epoch_best.ckpt'
            else:
                ckpt_paths = [path for path in os.listdir(f'{self.config["log_dir"]}/')
                              if path.endswith('.ckpt') and 'best' not in path]
                if len(ckpt_paths) == 0:
                    if self.config['display']:
                        log.info(f'No .ckpt found in {self.config["log_dir"]}')
                    return

                def sort_func(x):
                    return int(re.search(r"(\d+)", x).groups()[0])
                ckpt_path = f'{self.config["log_dir"]}/{sorted(ckpt_paths, key=sort_func, reverse=True)[0]}'

        if self.config['display']:
            log.info(f'loading checkpoint {ckpt_path}')
        epoch_name = osp.split(ckpt_path)[1].split('.')[0]
        if re.search(r"(\d+)", epoch_name):
            self.checkpoint_queue.append(ckpt_path)
            metrics_filename = osp.join(
                self.config['log_dir'], f'metrics_{epoch_name}.json')
            if osp.exists(metrics_filename):
                self.metrics_queue.append(metrics_filename)

        map_location = {'cuda:0': f'cuda:{self.gpu_rank}'}
        self.set_ckpt(torch.load(ckpt_path, map_location=map_location))

    def match_model_key(self, pretrained_dict, model_dict):
        matched_dict = dict()
        for key in pretrained_dict:
            if key in model_dict:
                matched_key = key
            elif key.startswith('encoder.') and key[8:] in model_dict:
                matched_key = key[8:]
            elif key.startswith('module.') and key[7:] in model_dict:
                matched_key = key[7:]
            elif 'encoder.' + key in model_dict:
                matched_key = 'encoder.' + key
            elif 'module.' + key in model_dict:
                matched_key = 'module.' + key
            else:
                continue
            matched_dict[matched_key] = pretrained_dict[key]

        not_found = ""
        for k in model_dict:
            if k not in matched_dict:
                not_found += k + '\n'

        log.info("Keys from model_dict that were not found in pretrained_dict:")
        log.info(not_found)
        return matched_dict

    def load_pretrained_vilbert(self, start_from=None):
        if start_from is not None:
            self.config["start_path"] = start_from
        if self.config['training'] or self.config['debugging']:
            ckpt_paths = [path for path in os.listdir(f'{self.config["log_dir"]}/')
                          if path.endswith('.ckpt') and 'best' not in path]
            if len(ckpt_paths) > 0:
                if self.config['display']:
                    log.info('Continue training')
                return

        if self.config['display']:
            log.info(f'Loading pretrained VilBERT from {self.config["start_path"]}')
        map_location = {'cuda:0': f'cuda:{self.gpu_rank}'}
        pretrained_dict = torch.load(
            self.config['start_path'], map_location=map_location)
        if 'model_state_dict' in pretrained_dict:
            pretrained_dict = pretrained_dict['model_state_dict']
        if self.config['parallel']:
            model = self.model.module
        else:
            model = self.model
        model_dict = model.state_dict()

        matched_dict = self.match_model_key(pretrained_dict, model_dict)

        if self.config['display']:
            log.info("number of keys transferred: %d" % len(matched_dict))
        assert len(matched_dict.keys()) > 0
        model_dict.update(matched_dict)
        model.load_state_dict(model_dict)

        del pretrained_dict, model_dict, matched_dict
        if not self.config['parallel'] or self.config['dp_type'] == 'dp':
            torch.cuda.empty_cache()

        if self.config['display']:
            log.info('Pretrained VilBERT loaded')
379
models/vdgr.py
Normal file
@@ -0,0 +1,379 @@
import sys
from collections import OrderedDict

import torch
from torch import nn
import torch.nn.functional as F

sys.path.append('../')
from utils.model_utils import listMLE, approxNDCGLoss, listNet, neuralNDCG, neuralNDCG_transposed
from utils.data_utils import sequence_mask
from utils.optim_utils import init_optim
from models.runner import Runner
from models.vilbert_dialog import BertForMultiModalPreTraining, BertConfig


class VDGR(nn.Module):
    def __init__(self, config_path, device, use_apex=False, cache_dir=None):
        super(VDGR, self).__init__()
        config = BertConfig.from_json_file(config_path)

        self.bert_pretrained = BertForMultiModalPreTraining.from_pretrained(
            'bert-base-uncased', config, device, use_apex=use_apex, cache_dir=cache_dir)
        self.bert_pretrained.train()

    def forward(self, input_ids, image_feat, image_loc, image_edge_indices, image_edge_attributes,
                question_edge_indices, question_edge_attributes, question_limits,
                history_edge_indices, history_sep_indices,
                sep_indices=None, sep_len=None, token_type_ids=None,
                attention_mask=None, masked_lm_labels=None, next_sentence_label=None,
                image_attention_mask=None, image_label=None, image_target=None):

        masked_lm_loss = None
        masked_img_loss = None
        nsp_loss = None
        seq_relationship_score = None

        if next_sentence_label is not None and masked_lm_labels is not None \
                and image_target is not None:
            # train mode, output losses
            masked_lm_loss, masked_img_loss, nsp_loss, _, _, seq_relationship_score, _ = \
                self.bert_pretrained(
                    input_ids, image_feat, image_loc, image_edge_indices, image_edge_attributes,
                    question_edge_indices, question_edge_attributes, question_limits,
                    history_edge_indices, history_sep_indices, sep_indices=sep_indices, sep_len=sep_len,
                    token_type_ids=token_type_ids, attention_mask=attention_mask, masked_lm_labels=masked_lm_labels,
                    next_sentence_label=next_sentence_label, image_attention_mask=image_attention_mask,
                    image_label=image_label, image_target=image_target)
        else:
            # inference, output scores
            _, _, seq_relationship_score, _, _, _ = \
                self.bert_pretrained(
                    input_ids, image_feat, image_loc, image_edge_indices, image_edge_attributes,
                    question_edge_indices, question_edge_attributes, question_limits,
                    history_edge_indices, history_sep_indices,
                    sep_indices=sep_indices, sep_len=sep_len,
                    token_type_ids=token_type_ids, attention_mask=attention_mask, masked_lm_labels=masked_lm_labels,
                    next_sentence_label=next_sentence_label, image_attention_mask=image_attention_mask,
                    image_label=image_label, image_target=image_target)

        out = (masked_lm_loss, masked_img_loss, nsp_loss, seq_relationship_score)

        return out

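# Illustrative usage sketch (added note; the config path is the one shipped in
# this repository, the device is a placeholder):
#   model = VDGR('config/bert_base_6layer_6conect.json', device='cuda:0')
#   lm_loss, img_loss, nsp_loss, nsp_scores = model(tokens, image_feat, ...)
# During training all three losses are populated; at inference only the NSP
# scores are returned and the losses are None.
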
class SparseRunner(Runner):
    def __init__(self, config):
        super(SparseRunner, self).__init__(config)
        self.model = VDGR(
            self.config['model_config'], self.config['device'],
            use_apex=self.config['dp_type'] == 'apex',
            cache_dir=self.config['bert_cache_dir'])

        self.model.to(self.config['device'])

        if not self.config['validating'] or self.config['dp_type'] == 'apex':
            self.optimizer, self.scheduler = init_optim(self.model, self.config)

    def forward(self, batch, eval_visdial=False):
        # load data
        for key in batch:
            if isinstance(batch[key], torch.Tensor):
                batch[key] = batch[key].to(self.config['device'])
            elif isinstance(batch[key], list):
                if key != 'dialog_info':  # do not send the dialog_info item to the GPU
                    batch[key] = [x.to(self.config['device']) for x in batch[key]]

        tokens = batch['tokens']
        segments = batch['segments']
        sep_indices = batch['sep_indices']
        mask = batch['mask']
        hist_len = batch['hist_len']
        image_feat = batch['image_feat']
        image_loc = batch['image_loc']
        image_mask = batch['image_mask']
        next_sentence_labels = batch.get('next_sentence_labels', None)
        image_target = batch.get('image_target', None)
        image_label = batch.get('image_label', None)
        # load the graph data
        image_edge_indices = batch['image_edge_indices']
        image_edge_attributes = batch['image_edge_attributes']
        question_edge_indices = batch['question_edge_indices']
        question_edge_attributes = batch['question_edge_attributes']
        question_limits = batch['question_limits']
        history_edge_indices = batch['history_edge_indices']
        history_sep_indices = batch['history_sep_indices']

        sequence_lengths = torch.gather(sep_indices, 1, hist_len.view(-1, 1)) + 1
        sequence_lengths = sequence_lengths.squeeze(1)
        attention_mask_lm_nsp = sequence_mask(sequence_lengths, max_len=tokens.shape[1])
        sep_len = hist_len + 1

        losses = OrderedDict()

        if eval_visdial:
            num_lines = tokens.size(0)
            line_batch_size = self.config['eval_line_batch_size']
            num_line_batches = num_lines // line_batch_size
            if num_lines % line_batch_size > 0:
                num_line_batches += 1
            nsp_scores = []
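            # Note (added): evaluation scores all answer candidates of every
            # round at once, which is too large for a single forward pass; the
            # flattened batch is therefore processed in smaller "line" chunks
            # (eval_line_batch_size rows at a time) and the NSP scores are
            # concatenated afterwards.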
            for j in range(num_line_batches):
                # create chunks of the original batch
                chunk_range = range(j * line_batch_size, min((j + 1) * line_batch_size, num_lines))
                tokens_chunk = tokens[chunk_range]
                segments_chunk = segments[chunk_range]
                sep_indices_chunk = sep_indices[chunk_range]
                mask_chunk = mask[chunk_range]
                sep_len_chunk = sep_len[chunk_range]
                attention_mask_lm_nsp_chunk = attention_mask_lm_nsp[chunk_range]
                image_feat_chunk = image_feat[chunk_range]
                image_loc_chunk = image_loc[chunk_range]
                image_mask_chunk = image_mask[chunk_range]
                image_edge_indices_chunk = image_edge_indices[chunk_range[0]: chunk_range[-1] + 1]
                image_edge_attributes_chunk = image_edge_attributes[chunk_range[0]: chunk_range[-1] + 1]
                question_edge_indices_chunk = question_edge_indices[chunk_range[0]: chunk_range[-1] + 1]
                question_edge_attributes_chunk = question_edge_attributes[chunk_range[0]: chunk_range[-1] + 1]
                question_limits_chunk = question_limits[chunk_range[0]: chunk_range[-1] + 1]
                history_edge_indices_chunk = history_edge_indices[chunk_range[0]: chunk_range[-1] + 1]
                history_sep_indices_chunk = history_sep_indices[chunk_range[0]: chunk_range[-1] + 1]

                _, _, _, nsp_scores_chunk = \
                    self.model(
                        tokens_chunk,
                        image_feat_chunk,
                        image_loc_chunk,
                        image_edge_indices_chunk,
                        image_edge_attributes_chunk,
                        question_edge_indices_chunk,
                        question_edge_attributes_chunk,
                        question_limits_chunk,
                        history_edge_indices_chunk,
                        history_sep_indices_chunk,
                        sep_indices=sep_indices_chunk,
                        sep_len=sep_len_chunk,
                        token_type_ids=segments_chunk,
                        masked_lm_labels=mask_chunk,
                        attention_mask=attention_mask_lm_nsp_chunk,
                        image_attention_mask=image_mask_chunk
                    )
                nsp_scores.append(nsp_scores_chunk)
            nsp_scores = torch.cat(nsp_scores, 0)

        else:
            losses['lm_loss'], losses['img_loss'], losses['nsp_loss'], nsp_scores = \
                self.model(
                    tokens,
                    image_feat,
                    image_loc,
                    image_edge_indices,
                    image_edge_attributes,
                    question_edge_indices,
                    question_edge_attributes,
                    question_limits,
                    history_edge_indices,
                    history_sep_indices,
                    next_sentence_label=next_sentence_labels,
                    image_target=image_target,
                    image_label=image_label,
                    sep_indices=sep_indices,
                    sep_len=sep_len,
                    token_type_ids=segments,
                    masked_lm_labels=mask,
                    attention_mask=attention_mask_lm_nsp,
                    image_attention_mask=image_mask
                )

        losses['tot_loss'] = 0
        for key in ['lm_loss', 'img_loss', 'nsp_loss']:
            if key in losses and losses[key] is not None:
                losses[key] = losses[key].mean()
                losses['tot_loss'] += self.config[f'{key}_coeff'] * losses[key]

        output = {
            'losses': losses,
            'nsp_scores': nsp_scores
        }
        return output


class DenseRunner(Runner):
    def __init__(self, config):
        super(DenseRunner, self).__init__(config)
        self.model = VDGR(
            self.config['model_config'], self.config['device'],
            use_apex=self.config['dp_type'] == 'apex',
            cache_dir=self.config['bert_cache_dir'])

        if not (self.config['parallel'] and self.config['dp_type'] == 'dp'):
            self.model.to(self.config['device'])

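        # Note (added): dense fine-tuning supports several listwise ranking
        # losses. The 'ce' option is realized as a KL divergence between the
        # log-softmax of the predicted scores and the softmax of gt_relevance,
        # which equals cross-entropy with soft targets up to a constant.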
        if self.config['dense_loss'] == 'ce':
            self.dense_loss = nn.KLDivLoss(reduction='batchmean')
        elif self.config['dense_loss'] == 'listmle':
            self.dense_loss = listMLE
        elif self.config['dense_loss'] == 'listnet':
            self.dense_loss = listNet
        elif self.config['dense_loss'] == 'approxndcg':
            self.dense_loss = approxNDCGLoss
        elif self.config['dense_loss'] == 'neural_ndcg':
            self.dense_loss = neuralNDCG
        elif self.config['dense_loss'] == 'neural_ndcg_transposed':
            self.dense_loss = neuralNDCG_transposed
        else:
            raise ValueError(
                'dense_loss must be one of ce, listmle, listnet, approxndcg, neural_ndcg, neural_ndcg_transposed')

        if not self.config['validating'] or self.config['dp_type'] == 'apex':
            self.optimizer, self.scheduler = init_optim(self.model, self.config)

    def forward(self, batch, eval_visdial=False):
        # load data
        for key in batch:
            if isinstance(batch[key], torch.Tensor):
                batch[key] = batch[key].to(self.config['device'])
            elif isinstance(batch[key], list):
                if key != 'dialog_info':  # do not send the dialog_info item to the GPU
                    batch[key] = [x.to(self.config['device']) for x in batch[key]]

        # get embeddings and forward visdial
        tokens = batch['tokens']
        segments = batch['segments']
        sep_indices = batch['sep_indices']
        mask = batch['mask']
        hist_len = batch['hist_len']
        image_feat = batch['image_feat']
        image_loc = batch['image_loc']
        image_mask = batch['image_mask']
        next_sentence_labels = batch.get('next_sentence_labels', None)
        image_target = batch.get('image_target', None)
        image_label = batch.get('image_label', None)

        # load the graph data
        image_edge_indices = batch['image_edge_indices']
        image_edge_attributes = batch['image_edge_attributes']
        question_edge_indices = batch['question_edge_indices']
        question_edge_attributes = batch['question_edge_attributes']
        question_limits = batch['question_limits']
        history_edge_indices = batch['history_edge_indices']
        assert history_edge_indices[0].size(0) == 2
        history_sep_indices = batch['history_sep_indices']

        sequence_lengths = torch.gather(sep_indices, 1, hist_len.view(-1, 1)) + 1
        sequence_lengths = sequence_lengths.squeeze(1)
        attention_mask_lm_nsp = sequence_mask(sequence_lengths, max_len=tokens.shape[1])
        sep_len = hist_len + 1

        losses = OrderedDict()

        if eval_visdial:
            num_lines = tokens.size(0)
            line_batch_size = self.config['eval_line_batch_size']
            num_line_batches = num_lines // line_batch_size
            if num_lines % line_batch_size > 0:
                num_line_batches += 1
            nsp_scores = []
            for j in range(num_line_batches):
                # create chunks of the original batch
                chunk_range = range(j * line_batch_size, min((j + 1) * line_batch_size, num_lines))
                tokens_chunk = tokens[chunk_range]
                segments_chunk = segments[chunk_range]
                sep_indices_chunk = sep_indices[chunk_range]
                mask_chunk = mask[chunk_range]
                sep_len_chunk = sep_len[chunk_range]
                attention_mask_lm_nsp_chunk = attention_mask_lm_nsp[chunk_range]
                image_feat_chunk = image_feat[chunk_range]
                image_loc_chunk = image_loc[chunk_range]
                image_mask_chunk = image_mask[chunk_range]
                image_edge_indices_chunk = image_edge_indices[chunk_range[0]: chunk_range[-1] + 1]
                image_edge_attributes_chunk = image_edge_attributes[chunk_range[0]: chunk_range[-1] + 1]
                question_edge_indices_chunk = question_edge_indices[chunk_range[0]: chunk_range[-1] + 1]
                question_edge_attributes_chunk = question_edge_attributes[chunk_range[0]: chunk_range[-1] + 1]
                question_limits_chunk = question_limits[chunk_range[0]: chunk_range[-1] + 1]
                history_edge_indices_chunk = history_edge_indices[chunk_range[0]: chunk_range[-1] + 1]
                history_sep_indices_chunk = history_sep_indices[chunk_range[0]: chunk_range[-1] + 1]

                _, _, _, nsp_scores_chunk = \
                    self.model(
                        tokens_chunk,
                        image_feat_chunk,
                        image_loc_chunk,
                        image_edge_indices_chunk,
                        image_edge_attributes_chunk,
                        question_edge_indices_chunk,
                        question_edge_attributes_chunk,
                        question_limits_chunk,
                        history_edge_indices_chunk,
                        history_sep_indices_chunk,
                        sep_indices=sep_indices_chunk,
                        sep_len=sep_len_chunk,
                        token_type_ids=segments_chunk,
                        masked_lm_labels=mask_chunk,
                        attention_mask=attention_mask_lm_nsp_chunk,
                        image_attention_mask=image_mask_chunk
                    )
                nsp_scores.append(nsp_scores_chunk)
            nsp_scores = torch.cat(nsp_scores, 0)

        else:
            _, _, _, nsp_scores = \
                self.model(
                    tokens,
                    image_feat,
                    image_loc,
                    image_edge_indices,
                    image_edge_attributes,
                    question_edge_indices,
                    question_edge_attributes,
                    question_limits,
                    history_edge_indices,
                    history_sep_indices,
                    next_sentence_label=next_sentence_labels,
                    image_target=image_target,
                    image_label=image_label,
                    sep_indices=sep_indices,
                    sep_len=sep_len,
                    token_type_ids=segments,
                    masked_lm_labels=mask,
                    attention_mask=attention_mask_lm_nsp,
                    image_attention_mask=image_mask
                )

        if nsp_scores is not None:
            nsp_scores_output = nsp_scores.detach().clone()
            if not eval_visdial:
                nsp_scores = nsp_scores.view(-1, self.config['num_options_dense'], 2)
            if 'next_sentence_labels' in batch and self.config['nsp_loss_coeff'] > 0:
                next_sentence_labels = batch['next_sentence_labels'].to(self.config['device'])
                losses['nsp_loss'] = F.cross_entropy(nsp_scores.view(-1, 2), next_sentence_labels.view(-1))
            else:
                losses['nsp_loss'] = None

            if not eval_visdial:
                gt_relevance = batch['gt_relevance'].to(self.config['device'])
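                # Note (added): column 0 of the NSP head scores the "answer
                # follows the dialog" class, so it serves as the ranking score
                # for each of the dense candidates.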
                nsp_scores = nsp_scores[:, :, 0]
                if self.config['dense_loss'] == 'ce':
                    losses['dense_loss'] = self.dense_loss(
                        F.log_softmax(nsp_scores, dim=1), F.softmax(gt_relevance, dim=1))
                else:
                    losses['dense_loss'] = self.dense_loss(nsp_scores, gt_relevance)
            else:
                losses['dense_loss'] = None
        else:
            nsp_scores_output = None
            losses['nsp_loss'] = None
            losses['dense_loss'] = None

        losses['tot_loss'] = 0
        for key in ['nsp_loss', 'dense_loss']:
            if key in losses and losses[key] is not None:
                losses[key] = losses[key].mean()
                losses['tot_loss'] += self.config[f'{key}_coeff'] * losses[key]

        output = {
            'losses': losses,
            'nsp_scores': nsp_scores_output
        }

        return output
2021
models/vilbert_dialog.py
Normal file
File diff suppressed because it is too large
18
setup_data.sh
Normal file
@@ -0,0 +1,18 @@
cd data
# Extract the graphs
tar xvfz history_adj_matrices.tar.gz
tar xvfz question_adj_matrices.tar.gz
tar xvfz img_adj_matrices.tar.gz

# Remove the .tar.gz archives
rm *.tar.gz

# Download the preprocessed image features
mkdir visdial_img_feat.lmdb
wget https://s3.amazonaws.com/visdial-bert/data/visdial_image_feats.lmdb/data.mdb -O visdial_img_feat.lmdb/data.mdb
wget https://s3.amazonaws.com/visdial-bert/data/visdial_image_feats.lmdb/lock.mdb -O visdial_img_feat.lmdb/lock.mdb

echo "Data setup completed successfully."

cd ..
0
utils/__init__.py
Normal file
290
utils/data_utils.py
Normal file
@@ -0,0 +1,290 @@
import torch
from torch.autograd import Variable
import random
import pickle
import numpy as np
from copy import deepcopy


def load_pickle_lines(filename):
    data = []
    with open(filename, 'rb') as f:
        while True:
            try:
                data.append(pickle.load(f))
            except EOFError:
                break
    return data


def flatten(l):
    return [item for sublist in l for item in sublist]


def build_len_mask_batch(
    # [batch_size], []
    len_batch, max_len=None
):
    if max_len is None:
        max_len = len_batch.max().item()
    batch_size, = len_batch.shape
    # [batch_size, max_len]
    idxes_batch = torch.arange(max_len, device=len_batch.device).view(1, -1).repeat(batch_size, 1)
    # [batch_size, max_len] = [batch_size, max_len] < [batch_size, 1]
    return idxes_batch < len_batch.view(-1, 1)


def sequence_mask(sequence_length, max_len=None):
    if max_len is None:
        max_len = sequence_length.data.max()
    batch_size = sequence_length.size(0)
    seq_range = torch.arange(0, max_len).long()
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_range_expand = Variable(seq_range_expand)
    if sequence_length.is_cuda:
        seq_range_expand = seq_range_expand.to(sequence_length.device)
    seq_length_expand = (sequence_length.unsqueeze(1)
                         .expand_as(seq_range_expand))
    return seq_range_expand < seq_length_expand

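# Example for sequence_mask (added note, illustrative):
#   sequence_mask(torch.tensor([1, 3]), max_len=4)
#   -> tensor([[ True, False, False, False],
#              [ True,  True,  True, False]])
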
def batch_iter(dataloader, params):
    for epochId in range(params['num_epochs']):
        for idx, batch in enumerate(dataloader):
            yield epochId, idx, batch


def list2tensorpad(inp_list, max_seq_len):
    inp_tensor = torch.LongTensor([inp_list])
    inp_tensor_zeros = torch.zeros(1, max_seq_len, dtype=torch.long)
    # after preprocessing, inp_tensor.shape[1] must not exceed max_seq_len
    inp_tensor_zeros[0, :inp_tensor.shape[1]] = inp_tensor
    inp_tensor = inp_tensor_zeros
    return inp_tensor


def encode_input(utterances, start_segment, CLS, SEP, MASK, max_seq_len=256, max_sep_len=25, mask_prob=0.2):
    cur_segment = start_segment
    token_id_list = []
    segment_id_list = []
    sep_token_indices = []
    masked_token_list = []

    token_id_list.append(CLS)
    segment_id_list.append(cur_segment)
    masked_token_list.append(0)

    cur_sep_token_index = 0

    for cur_utterance in utterances:
        # decide which tokens get masked and keep track of them
        cur_masked_index = [1 if random.random() < mask_prob else 0 for _ in range(len(cur_utterance))]
        masked_token_list.extend(cur_masked_index)
        token_id_list.extend(cur_utterance)
        segment_id_list.extend([cur_segment] * len(cur_utterance))

        token_id_list.append(SEP)
        segment_id_list.append(cur_segment)
        masked_token_list.append(0)
        cur_sep_token_index = cur_sep_token_index + len(cur_utterance) + 1
        sep_token_indices.append(cur_sep_token_index)
        cur_segment = cur_segment ^ 1  # the segment id oscillates between 0 and 1
    start_question, end_question = sep_token_indices[-3] + 1, sep_token_indices[-2]
    assert end_question - start_question == len(utterances[-2])

    assert len(segment_id_list) == len(token_id_list) == len(masked_token_list) == sep_token_indices[-1] + 1
    # convert to tensors and pad to the maximum sequence length
    tokens = list2tensorpad(token_id_list, max_seq_len)  # [1, max_len]
    masked_tokens = list2tensorpad(masked_token_list, max_seq_len)
    masked_tokens[0, masked_tokens[0, :] == 0] = -1
    mask = masked_tokens[0, :] == 1
    masked_tokens[0, mask] = tokens[0, mask]
    tokens[0, mask] = MASK

    segment_id_list = list2tensorpad(segment_id_list, max_seq_len)
    return tokens, segment_id_list, list2tensorpad(sep_token_indices, max_sep_len), masked_tokens, start_question, end_question

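# Note (added): encode_input_with_mask below mirrors encode_input but
# additionally returns an explicit input attention mask (1 for every real
# token) and, when get_q_limits=False, skips the question-boundary
# computation (presumably for inputs that carry no current question round).
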
def encode_input_with_mask(utterances, start_segment, CLS, SEP, MASK, max_seq_len=256, max_sep_len=25, mask_prob=0.2, get_q_limits=True):
    cur_segment = start_segment
    token_id_list = []
    segment_id_list = []
    sep_token_indices = []
    masked_token_list = []
    input_mask_list = []

    token_id_list.append(CLS)
    segment_id_list.append(cur_segment)
    masked_token_list.append(0)
    input_mask_list.append(1)

    cur_sep_token_index = 0

    for cur_utterance in utterances:
        # decide which tokens get masked and keep track of them
        cur_masked_index = [1 if random.random() < mask_prob else 0 for _ in range(len(cur_utterance))]
        masked_token_list.extend(cur_masked_index)
        token_id_list.extend(cur_utterance)
        segment_id_list.extend([cur_segment] * len(cur_utterance))
        input_mask_list.extend([1] * len(cur_utterance))

        token_id_list.append(SEP)
        segment_id_list.append(cur_segment)
        masked_token_list.append(0)
        input_mask_list.append(1)
        cur_sep_token_index = cur_sep_token_index + len(cur_utterance) + 1
        sep_token_indices.append(cur_sep_token_index)
        cur_segment = cur_segment ^ 1  # the segment id oscillates between 0 and 1

    if get_q_limits:
        start_question, end_question = sep_token_indices[-3] + 1, sep_token_indices[-2]
        assert end_question - start_question == len(utterances[-2])
    else:
        start_question, end_question = -1, -1
    assert len(segment_id_list) == len(token_id_list) == len(masked_token_list) == len(input_mask_list) == sep_token_indices[-1] + 1
    # convert to tensors and pad to the maximum sequence length
    tokens = list2tensorpad(token_id_list, max_seq_len)
    masked_tokens = list2tensorpad(masked_token_list, max_seq_len)
    input_mask = list2tensorpad(input_mask_list, max_seq_len)
    masked_tokens[0, masked_tokens[0, :] == 0] = -1
    mask = masked_tokens[0, :] == 1
    masked_tokens[0, mask] = tokens[0, mask]
    tokens[0, mask] = MASK

    segment_id_list = list2tensorpad(segment_id_list, max_seq_len)
    return tokens, segment_id_list, list2tensorpad(sep_token_indices, max_sep_len), masked_tokens, input_mask, start_question, end_question


def encode_image_input(features, num_boxes, boxes, image_target, max_regions=37, mask_prob=0.15):
    output_label = []
    num_boxes = min(int(num_boxes), max_regions)

    mix_boxes_pad = np.zeros((max_regions, boxes.shape[-1]))
    mix_features_pad = np.zeros((max_regions, features.shape[-1]))
    mix_image_target = np.zeros((max_regions, image_target.shape[-1]))

    mix_boxes_pad[:num_boxes] = boxes[:num_boxes]
    mix_features_pad[:num_boxes] = features[:num_boxes]
    mix_image_target[:num_boxes] = image_target[:num_boxes]

    boxes = mix_boxes_pad
    features = mix_features_pad
    image_target = mix_image_target
    mask_indexes = []
    for i in range(num_boxes):
        prob = random.random()
        # mask a region with probability mask_prob (15% by default)
        if prob < mask_prob:
            prob /= mask_prob

            # 90% of the time, zero out the region features
            if prob < 0.9:
                features[i] = 0
            output_label.append(1)
            mask_indexes.append(i)
        else:
            # no masking (will be ignored by the loss function later)
            output_label.append(-1)

    image_mask = [1] * (int(num_boxes))
    while len(image_mask) < max_regions:
        image_mask.append(0)
        output_label.append(-1)

    # ensure that at least one region is predicted
    output_label[random.randint(1, len(output_label) - 1)] = 1
    image_label = torch.LongTensor(output_label)
    image_label[0] = 0  # make sure the <IMG> token does not contribute to the masked loss
    image_mask = torch.tensor(image_mask).float()

    features = torch.tensor(features).float()
    spatials = torch.tensor(boxes).float()
    image_target = torch.tensor(image_target).float()

    return features, spatials, image_mask, image_target, image_label

def question_edge_masking(question_edge_indices, question_edge_attributes, mask, question_limits, mask_prob=0.4, max_len=10):
    mask = mask.squeeze().tolist()
    question_limits = question_limits.tolist()
    question_start, question_end = question_limits
    # get the masking of the question
    mask_question = mask[question_start:question_end]
    masked_idx = np.argwhere(np.array(mask_question) > -1).squeeze().tolist()
    if isinstance(masked_idx, int):  # only one question token is masked
        masked_idx = [masked_idx]

    # discard all edge indices and attributes that correspond to masked tokens
    edge_attr_gt = []
    edge_idx_gt_gnn = []
    edge_idx_gt_bert = []
    for i, (question_edge_idx, question_edge_attr) in enumerate(zip(question_edge_indices, question_edge_attributes)):
        if not (question_edge_idx[0] in masked_idx or question_edge_idx[1] in masked_idx):
            # masking
            if random.random() < mask_prob:
                edge_attr_gt.append(np.argwhere(question_edge_attr).item())
                edge_idx_gt_gnn.append(question_edge_idx)
                edge_idx_gt_bert.append([question_edge_idx[0] + question_start, question_edge_idx[1] + question_start])
                question_edge_attr = np.zeros_like(question_edge_attr)
                question_edge_attr[-1] = 1.0  # the [EDGE_MASK] special token is the last one-hot encoding
                question_edge_attributes[i] = question_edge_attr

    # force masking if necessary:
    if len(edge_attr_gt) == 0:
        for i, (question_edge_idx, question_edge_attr) in enumerate(zip(question_edge_indices, question_edge_attributes)):
            if not (question_edge_idx[0] in masked_idx or question_edge_idx[1] in masked_idx):
                # masking
                edge_attr_gt.append(np.argwhere(question_edge_attr).item())
                edge_idx_gt_gnn.append(question_edge_idx)
                edge_idx_gt_bert.append([question_edge_idx[0] + question_start, question_edge_idx[1] + question_start])
                question_edge_attr = np.zeros_like(question_edge_attr)
                question_edge_attr[-1] = 1.0  # the [EDGE_MASK] special token is the last one-hot encoding
                question_edge_attributes[i] = question_edge_attr
                break

    # for the rare case where the conditions for masking were not met
    if len(edge_attr_gt) == 0:
        edge_attr_gt.append(-1)
        edge_idx_gt_gnn.append([0, question_end - question_start])
        edge_idx_gt_bert.append(question_limits)

    # pad to max_len
    while len(edge_attr_gt) < max_len:
        edge_attr_gt.append(-1)
        edge_idx_gt_gnn.append(edge_idx_gt_gnn[-1])
        edge_idx_gt_bert.append(edge_idx_gt_bert[-1])

    # truncate if longer than max_len
    if len(edge_attr_gt) > max_len:
        edge_idx_gt_gnn = edge_idx_gt_gnn[:max_len]
        edge_idx_gt_bert = edge_idx_gt_bert[:max_len]
        edge_attr_gt = edge_attr_gt[:max_len]
    edge_idx_gt_gnn = np.array(edge_idx_gt_gnn)
    edge_idx_gt_bert = np.array(edge_idx_gt_bert)

    first_edge_node_gt_gnn = list(edge_idx_gt_gnn[:, 0])
    second_edge_node_gt_gnn = list(edge_idx_gt_gnn[:, 1])

    first_edge_node_gt_bert = list(edge_idx_gt_bert[:, 0])
    second_edge_node_gt_bert = list(edge_idx_gt_bert[:, 1])

    return question_edge_attributes, edge_attr_gt, first_edge_node_gt_gnn, second_edge_node_gt_gnn, first_edge_node_gt_bert, second_edge_node_gt_bert


def to_data_list(feats, batch_idx):
    feat_list = []
    device = feats.device
    left = 0
    right = 0
    batch_size = batch_idx.max().item() + 1
    for batch in range(batch_size):
        if batch == batch_size - 1:
            right = batch_idx.size(0)
        else:
            right = torch.argwhere(batch_idx == batch + 1)[0].item()
        idx = torch.arange(left, right).unsqueeze(-1).repeat(1, feats.size(1)).to(device)
        feat_list.append(torch.gather(feats, 0, idx))
        left = right

    return feat_list
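# Note (added): to_data_list splits a stacked node-feature tensor back into a
# per-graph list using a PyG-style batch index vector: batch_idx[i] names the
# graph that row i of feats belongs to, and rows belonging to one graph are
# assumed to be contiguous.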
192
utils/image_features_reader.py
Normal file
@@ -0,0 +1,192 @@
from typing import List
import csv
import h5py
import numpy as np
import copy
import pickle
import lmdb  # install lmdb with "pip install lmdb"
import base64
import pdb
import os


class ImageFeaturesH5Reader(object):
    """
    A reader for H5 files containing pre-extracted image features. A typical
    H5 file is expected to have a column named "image_id", and another column
    named "features".

    Example of an H5 file:
    ```
    faster_rcnn_bottomup_features.h5
    |--- "image_id" [shape: (num_images, )]
    |--- "features" [shape: (num_images, num_proposals, feature_size)]
    +--- .attrs ("split", "train")
    ```
    Parameters
    ----------
    features_h5path : str
        Path to an H5 file containing COCO train / val image features.
    in_memory : bool
        Whether to load the whole H5 file in memory. Beware, these files are
        sometimes tens of GBs in size. Set this to true if you have sufficient
        RAM - trade-off between speed and memory.
    """
    def __init__(self, features_path: str, scene_graph_path: str, in_memory: bool = False):
        self.features_path = features_path
        self.scene_graph_path = scene_graph_path
        self._in_memory = in_memory

        self.env = lmdb.open(self.features_path, max_readers=1, readonly=True,
                             lock=False, readahead=False, meminit=False)

        with self.env.begin(write=False) as txn:
            self._image_ids = pickle.loads(txn.get('keys'.encode()))

        self.features = [None] * len(self._image_ids)
        self.num_boxes = [None] * len(self._image_ids)
        self.boxes = [None] * len(self._image_ids)
        self.boxes_ori = [None] * len(self._image_ids)
        self.cls_prob = [None] * len(self._image_ids)
        self.edge_indexes = [None] * len(self._image_ids)
        self.edge_attributes = [None] * len(self._image_ids)

    def __len__(self):
        return len(self._image_ids)

    def __getitem__(self, image_id):
        image_id = str(image_id).encode()
        index = self._image_ids.index(image_id)
        if self._in_memory:
            # Features are cached lazily during the first epoch; loading
            # everything upfront would make startup very slow.
            if self.features[index] is not None:
                features = self.features[index]
                num_boxes = self.num_boxes[index]
                image_location = self.boxes[index]
                image_location_ori = self.boxes_ori[index]
                cls_prob = self.cls_prob[index]
                edge_indexes = self.edge_indexes[index]
                edge_attributes = self.edge_attributes[index]
            else:
                with self.env.begin(write=False) as txn:
                    item = pickle.loads(txn.get(image_id))
                    image_id = item['image_id']
                    image_h = int(item['image_h'])
                    image_w = int(item['image_w'])
                    num_boxes = int(item['num_boxes'])
                    features = np.frombuffer(base64.b64decode(item["features"]), dtype=np.float32).reshape(num_boxes, 2048)
                    boxes = np.frombuffer(base64.b64decode(item['boxes']), dtype=np.float32).reshape(num_boxes, 4)

                    cls_prob = np.frombuffer(base64.b64decode(item['cls_prob']), dtype=np.float32).reshape(num_boxes, 1601)
                    # add an extra row at the top for the <IMG> token
                    g_cls_prob = np.zeros(1601, dtype=np.float32)
                    g_cls_prob[0] = 1
                    cls_prob = np.concatenate([np.expand_dims(g_cls_prob, axis=0), cls_prob], axis=0)

                    self.cls_prob[index] = cls_prob
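                    # Note (added): the global <IMG> token is represented by the
                    # mean of all region features; num_boxes is incremented
                    # below to account for this extra row.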
                    g_feat = np.sum(features, axis=0) / num_boxes
                    num_boxes = num_boxes + 1

                    features = np.concatenate([np.expand_dims(g_feat, axis=0), features], axis=0)
                    self.features[index] = features

                    image_location = np.zeros((boxes.shape[0], 5), dtype=np.float32)
                    image_location[:, :4] = boxes
                    image_location[:, 4] = (image_location[:, 3] - image_location[:, 1]) * (image_location[:, 2] - image_location[:, 0]) / (float(image_w) * float(image_h))

                    image_location_ori = copy.deepcopy(image_location)

                    image_location[:, 0] = image_location[:, 0] / float(image_w)
                    image_location[:, 1] = image_location[:, 1] / float(image_h)
                    image_location[:, 2] = image_location[:, 2] / float(image_w)
                    image_location[:, 3] = image_location[:, 3] / float(image_h)

                    g_location = np.array([0, 0, 1, 1, 1])
                    image_location = np.concatenate([np.expand_dims(g_location, axis=0), image_location], axis=0)
                    self.boxes[index] = image_location

                    g_location_ori = np.array([0, 0, image_w, image_h, image_w * image_h])
                    image_location_ori = np.concatenate([np.expand_dims(g_location_ori, axis=0), image_location_ori], axis=0)
                    self.boxes_ori[index] = image_location_ori
                    self.num_boxes[index] = num_boxes

                    # load the scene graph data
                    pth = os.path.join(self.scene_graph_path, f'{image_id}.pkl')
                    with open(pth, 'rb') as f:
                        graph_data = pickle.load(f)
                    edge_indexes = []
                    edge_attributes = []
                    for e_idx, e_attr in graph_data:
                        edge_indexes.append(e_idx)
                        # one-hot encoding of the edge label
                        e_attr_one_hot = np.zeros((12,), dtype=np.float32)  # 12 = 11 relations + the hub-node relation
                        e_attr_one_hot[e_attr] = 1.0
                        edge_attributes.append(e_attr_one_hot)
                    edge_indexes = np.array(edge_indexes, dtype=np.float64).transpose(1, 0)
                    edge_attributes = np.stack(edge_attributes, axis=0)

                    self.edge_indexes[index] = edge_indexes
                    self.edge_attributes[index] = edge_attributes

        else:
            # read the chunk from file every time if it is not cached in memory
            with self.env.begin(write=False) as txn:
                item = pickle.loads(txn.get(image_id))
                image_id = item['image_id']
                image_h = int(item['image_h'])
                image_w = int(item['image_w'])
                num_boxes = int(item['num_boxes'])
                cls_prob = np.frombuffer(base64.b64decode(item['cls_prob']), dtype=np.float32).reshape(num_boxes, 1601)
                # add an extra row at the top for the <IMG> token
                g_cls_prob = np.zeros(1601, dtype=np.float32)
                g_cls_prob[0] = 1
                cls_prob = np.concatenate([np.expand_dims(g_cls_prob, axis=0), cls_prob], axis=0)

                features = np.frombuffer(base64.b64decode(item["features"]), dtype=np.float32).reshape(num_boxes, 2048)
                boxes = np.frombuffer(base64.b64decode(item['boxes']), dtype=np.float32).reshape(num_boxes, 4)
                g_feat = np.sum(features, axis=0) / num_boxes
                num_boxes = num_boxes + 1
                features = np.concatenate([np.expand_dims(g_feat, axis=0), features], axis=0)

                image_location = np.zeros((boxes.shape[0], 5), dtype=np.float32)
                image_location[:, :4] = boxes
                image_location[:, 4] = (image_location[:, 3] - image_location[:, 1]) * (image_location[:, 2] - image_location[:, 0]) / (float(image_w) * float(image_h))

                image_location_ori = copy.deepcopy(image_location)
                image_location[:, 0] = image_location[:, 0] / float(image_w)
                image_location[:, 1] = image_location[:, 1] / float(image_h)
                image_location[:, 2] = image_location[:, 2] / float(image_w)
                image_location[:, 3] = image_location[:, 3] / float(image_h)

                g_location = np.array([0, 0, 1, 1, 1])
                image_location = np.concatenate([np.expand_dims(g_location, axis=0), image_location], axis=0)

                g_location_ori = np.array([0, 0, image_w, image_h, image_w * image_h])
                image_location_ori = np.concatenate([np.expand_dims(g_location_ori, axis=0), image_location_ori], axis=0)

                # load the scene graph data
                pth = os.path.join(self.scene_graph_path, f'{image_id}.pkl')
                with open(pth, 'rb') as f:
                    graph_data = pickle.load(f)
                edge_indexes = []
                edge_attributes = []
                for e_idx, e_attr in graph_data:
                    edge_indexes.append(e_idx)
                    # one-hot encoding of the edge label
                    e_attr_one_hot = np.zeros((12,), dtype=np.float32)  # 12 = 11 relations + the hub-node relation
                    e_attr_one_hot[e_attr] = 1.0
                    edge_attributes.append(e_attr_one_hot)
                edge_indexes = np.array(edge_indexes, dtype=np.float64).transpose(1, 0)
                edge_attributes = np.stack(edge_attributes, axis=0)

        return features, num_boxes, image_location, image_location_ori, cls_prob, edge_indexes, edge_attributes

    def keys(self) -> List[int]:
        return self._image_ids

    def set_keys(self, new_ids: List[str]):
        self._image_ids = list(map(lambda _id: _id.encode('ascii'), new_ids))
176
utils/init_utils.py
Normal file
@@ -0,0 +1,176 @@
import os
import os.path as osp
import random
import datetime
import itertools
import glob
import subprocess
import pyhocon
import re
import numpy as np
import glog as log
import json
import torch

import sys
sys.path.append('../')

from models import vdgr
from dataloader.dataloader_visdial import VisdialDataset
from dataloader.dataloader_visdial_dense import VisdialDenseDataset


def load_runner(config):
    if config['train_on_dense']:
        return vdgr.DenseRunner(config)
    else:
        return vdgr.SparseRunner(config)


def load_dataset(config):
    dataset_eval = None

    if config['train_on_dense']:
        dataset = VisdialDenseDataset(config)
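        # Note (added): when the sparse MRR evaluation is skipped, the eval
        # loader still needs the full candidate list, so num_options_dense is
        # temporarily swapped to num_options while the eval dataset is built.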
        if config['skip_mrr_eval']:
            temp = config['num_options_dense']
            config['num_options_dense'] = config['num_options']
            dataset_eval = VisdialDenseDataset(config)
            config['num_options_dense'] = temp
        else:
            dataset_eval = VisdialDataset(config)
    else:
        dataset = VisdialDataset(config)
        if config['skip_mrr_eval']:
            dataset_eval = VisdialDenseDataset(config)

    if config['use_trainval']:
        dataset.split = 'trainval'
    else:
        dataset.split = 'train'

    if dataset_eval is not None:
        dataset_eval.split = 'val'

    return dataset, dataset_eval


def initialize_from_env(model, mode, eval_dir, model_type, tag=''):
    if "GPU" in os.environ:
        os.environ["CUDA_VISIBLE_DEVICES"] = os.environ['GPU']
    if mode in ['train', 'debug']:
        config = pyhocon.ConfigFactory.parse_file(f"config/{model_type}.conf")[model]
    else:
        path_config = osp.join(eval_dir, 'code', f"config/{model_type}.conf")
        config = pyhocon.ConfigFactory.parse_file(path_config)[model]
        config['log_dir'] = eval_dir
        config['model_config'] = osp.join(eval_dir, 'code/config/bert_base_6layer_6conect.json')
        if config['dp_type'] == 'apex':
            config['dp_type'] = 'ddp'

    if config['dp_type'] == 'dp':
        config['stack_gr_data'] = True

    config['model_type'] = model_type
    if "CUDA_VISIBLE_DEVICES" in os.environ:
        config['num_gpus'] = len(os.environ["CUDA_VISIBLE_DEVICES"].split(','))
        # multi-gpu setting
        if config['num_gpus'] > 1:
            os.environ['MASTER_ADDR'] = '127.0.0.1'
            os.environ['MASTER_PORT'] = '5678'

    if mode == 'debug':
        model += '_debug'

    if tag:
        model += '-' + tag
    if mode in ['train', 'debug']:
        config['log_dir'] = os.path.join(config["log_dir"], model)
        if not os.path.exists(config["log_dir"]):
            os.makedirs(config["log_dir"])
    config['visdial_output_dir'] = osp.join(config['log_dir'], config['visdial_output_dir'])

    config['timestamp'] = datetime.datetime.now().strftime('%m%d-%H%M%S')

    # add the bert config
    config['bert_config'] = json.load(open(config['model_config'], 'r'))
    if mode in ['predict', 'eval']:
        if (not config['loads_start_path']) and (not config['loads_best_ckpt']):
            config['loads_best_ckpt'] = True
            print('Setting loads_best_ckpt=True under predict or eval mode')
        if config['num_options_dense'] < 100:
            config['num_options_dense'] = 100
            print('Setting num_options_dense=100 under predict or eval mode')
        if config['visdial_version'] == 0.9:
            config['skip_ndcg_eval'] = True

    return config


def set_log_file(fname, file_only=False):
    # if fname already exists, find all log files under the log dir
    # and name the current log file with a new number
    if osp.exists(fname):
        prefix, suffix = osp.splitext(fname)
        log_files = glob.glob(prefix + '*' + suffix)
        count = 0
        for log_file in log_files:
            num = re.search(r'(\d+)', log_file)
            if num is not None:
                num = int(num.group(0))
                count = max(num, count)
        fname = fname.replace(suffix, str(count + 1) + suffix)
    # set log file
    # simple tricks for duplicating logging destination in the logging module such as:
# logging.getLogger().addHandler(logging.FileHandler(filename))
|
||||||
|
# does NOT work well here, because python Traceback message (not via logging module) is not sent to the file,
|
||||||
|
# the following solution (copied from : https://stackoverflow.com/questions/616645) is a little bit
|
||||||
|
# complicated but simulates exactly the "tee" command in linux shell, and it redirects everything
|
||||||
|
if file_only:
|
||||||
|
# we only output messages to file, and stdout/stderr receives nothing.
|
||||||
|
# this feature is designed for executing the script via ssh:
|
||||||
|
# since ssh has a windowing kind of flow control, i.e., if the controller does not read data from a
|
||||||
|
# ssh channel and its buffer fills up, the execution machine will not be able to write anything into the
|
||||||
|
# channel and the process will be set to sleeping (S) status until someone reads all data from the channel.
|
||||||
|
# this is not desired since we do not want to read stdout/stderr from the controller machine.
|
||||||
|
# so, here we use a simple solution: disable output to stdout/stderr and only output messages to log file.
|
||||||
|
log.logger.handlers[0].stream = log.handler.stream = sys.stdout = sys.stderr = f = open(fname, 'w', buffering=1)
|
||||||
|
else:
|
||||||
|
# we output messages to both file and stdout/stderr
|
||||||
|
tee = subprocess.Popen(['tee', fname], stdin=subprocess.PIPE)
|
||||||
|
os.dup2(tee.stdin.fileno(), sys.stdout.fileno())
|
||||||
|
os.dup2(tee.stdin.fileno(), sys.stderr.fileno())
|
||||||
|
|
||||||
|
|
||||||
|
def copy_file_to_log(log_dir):
|
||||||
|
dirs_to_cp = ['.', 'config', 'dataloader', 'models', 'utils']
|
||||||
|
files_to_cp = ['*.py', '*.json', '*.sh', '*.conf']
|
||||||
|
for dir_name in dirs_to_cp:
|
||||||
|
dir_name = osp.join(log_dir, 'code', dir_name)
|
||||||
|
if not osp.exists(dir_name):
|
||||||
|
os.makedirs(dir_name)
|
||||||
|
for dir_name, file_name in itertools.product(dirs_to_cp, files_to_cp):
|
||||||
|
filename = osp.join(dir_name, file_name)
|
||||||
|
if len(glob.glob(filename)) > 0:
|
||||||
|
os.system(f'cp {filename} {osp.join(log_dir, "code", dir_name)}')
|
||||||
|
log.info(f'Files copied to {osp.join(log_dir, "code")}')
|
||||||
|
|
||||||
|
|
||||||
|
def set_random_seed(random_seed):
|
||||||
|
torch.manual_seed(random_seed)
|
||||||
|
torch.cuda.manual_seed(random_seed)
|
||||||
|
random.seed(random_seed)
|
||||||
|
np.random.seed(random_seed)
|
||||||
|
|
||||||
|
|
||||||
|
def set_training_steps(config, num_samples):
|
||||||
|
if config['parallel'] and config['dp_type'] == 'dp':
|
||||||
|
config['num_iter_per_epoch'] = int(np.ceil(num_samples / config['batch_size']))
|
||||||
|
else:
|
||||||
|
config['num_iter_per_epoch'] = int(np.ceil(num_samples / (config['batch_size'] * config['num_gpus'])))
|
||||||
|
if 'train_steps' not in config:
|
||||||
|
config['train_steps'] = config['num_iter_per_epoch'] * config['num_epochs']
|
||||||
|
if 'warmup_steps' not in config:
|
||||||
|
config['warmup_steps'] = int(config['train_steps'] * config['warmup_ratio'])
|
||||||
|
return config
|
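To make the step bookkeeping concrete, a worked example of `set_training_steps` (all numbers are hypothetical):

```python
config = {'parallel': True, 'dp_type': 'ddp', 'batch_size': 8, 'num_gpus': 4,
          'num_epochs': 20, 'warmup_ratio': 0.1}
config = set_training_steps(config, num_samples=123287)
# num_iter_per_epoch = ceil(123287 / (8 * 4)) = 3853
# train_steps        = 3853 * 20             = 77060
# warmup_steps       = int(77060 * 0.1)      = 7706
```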
utils/model_utils.py (new file, 456 lines)
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np


def truncated_normal_(tensor, mean=0, std=1):
    size = tensor.shape
    tmp = tensor.new_empty(size + (4,)).normal_()
    valid = (tmp < 2) & (tmp > -2)
    ind = valid.max(-1, keepdim=True)[1]
    tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
    tensor.data.mul_(std).add_(mean)


def init_params(module, initializer='normal'):

    if isinstance(module, nn.Linear):
        if initializer == 'kaiming_normal':
            nn.init.kaiming_normal_(module.weight.data)
        elif initializer == 'normal':
            nn.init.normal_(module.weight.data, std=0.02)
        elif initializer == 'truncated_normal':
            truncated_normal_(module.weight.data, std=0.02)

        if module.bias is not None:
            nn.init.zeros_(module.bias.data)
        # log.info('initialized Linear')

    elif isinstance(module, nn.Embedding):
        if initializer == 'kaiming_normal':
            nn.init.kaiming_normal_(module.weight.data)
        elif initializer == 'normal':
            nn.init.normal_(module.weight.data, std=0.02)
        elif initializer == 'truncated_normal':
            truncated_normal_(module.weight.data, std=0.02)

    elif isinstance(module, (nn.Conv2d, nn.Conv1d)):
        nn.init.kaiming_normal_(module.weight, mode='fan_out')
        # log.info('initialized Conv')

    elif isinstance(module, (nn.RNNBase, nn.LSTMCell, nn.GRUCell)):
        for name, param in module.named_parameters():
            if 'weight' in name:
                nn.init.orthogonal_(param.data)
            elif 'bias' in name:
                nn.init.normal_(param.data)
        # log.info('initialized LSTM')

    elif isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
        module.weight.data.normal_(1.0, 0.02)
        # log.info('initialized BatchNorm')
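`init_params` is designed to be broadcast over a network with `nn.Module.apply`; a quick sketch on a throwaway model (the model itself is purely illustrative):

```python
import functools
from torch import nn

model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 100))
model.apply(functools.partial(init_params, initializer='truncated_normal'))
```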
def TensorboardWriter(save_path):
    from torch.utils.tensorboard import SummaryWriter
    return SummaryWriter(save_path, comment="Unmt")


DEFAULT_EPS = 1e-8
PADDED_Y_VALUE = -1
def listMLE(y_pred, y_true, eps=DEFAULT_EPS, padded_value_indicator=PADDED_Y_VALUE):
    """
    ListMLE loss introduced in "Listwise Approach to Learning to Rank - Theory and Algorithm".
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :param eps: epsilon value, used for numerical stability
    :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
    :return: loss value, a torch.Tensor
    """
    # shuffle for randomised tie resolution
    random_indices = torch.randperm(y_pred.shape[-1])
    y_pred_shuffled = y_pred[:, random_indices]
    y_true_shuffled = y_true[:, random_indices]

    y_true_sorted, indices = y_true_shuffled.sort(descending=True, dim=-1)

    mask = y_true_sorted == padded_value_indicator

    preds_sorted_by_true = torch.gather(y_pred_shuffled, dim=1, index=indices)
    preds_sorted_by_true[mask] = float("-inf")

    max_pred_values, _ = preds_sorted_by_true.max(dim=1, keepdim=True)

    preds_sorted_by_true_minus_max = preds_sorted_by_true - max_pred_values

    cumsums = torch.cumsum(preds_sorted_by_true_minus_max.exp().flip(dims=[1]), dim=1).flip(dims=[1])

    observation_loss = torch.log(cumsums + eps) - preds_sorted_by_true_minus_max

    observation_loss[mask] = 0.0

    return torch.mean(torch.sum(observation_loss, dim=1))
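A smoke test of `listMLE` on a single hand-built slate (the numbers are arbitrary; `-1` marks a padded option, matching `PADDED_Y_VALUE`):

```python
import torch

y_pred = torch.tensor([[0.5, 2.0, 1.0, 0.1]], requires_grad=True)
y_true = torch.tensor([[1.0, 2.0, 0.0, -1.0]])  # -1 = padding
loss = listMLE(y_pred, y_true)
loss.backward()  # gradients flow only through the non-padded options
```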
def approxNDCGLoss(y_pred, y_true, eps=DEFAULT_EPS, padded_value_indicator=PADDED_Y_VALUE, alpha=1.):
    """
    Loss based on approximate NDCG introduced in "A General Approximation Framework for Direct Optimization of
    Information Retrieval Measures". Please note that this method does not implement any kind of truncation.
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :param eps: epsilon value, used for numerical stability
    :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
    :param alpha: score difference weight used in the sigmoid function
    :return: loss value, a torch.Tensor
    """
    device = y_pred.device
    y_pred = y_pred.clone()
    y_true = y_true.clone()

    padded_mask = y_true == padded_value_indicator
    y_pred[padded_mask] = float("-inf")
    y_true[padded_mask] = float("-inf")

    # Here we sort the true and predicted relevancy scores.
    y_pred_sorted, indices_pred = y_pred.sort(descending=True, dim=-1)
    y_true_sorted, _ = y_true.sort(descending=True, dim=-1)

    # After sorting, we can mask out the pairs of indices (i, j) containing the index of a padded element.
    true_sorted_by_preds = torch.gather(y_true, dim=1, index=indices_pred)
    true_diffs = true_sorted_by_preds[:, :, None] - true_sorted_by_preds[:, None, :]
    padded_pairs_mask = torch.isfinite(true_diffs)
    padded_pairs_mask.diagonal(dim1=-2, dim2=-1).zero_()

    # Here we clamp the -infs to get correct gains and ideal DCGs (maxDCGs)
    true_sorted_by_preds.clamp_(min=0.)
    y_true_sorted.clamp_(min=0.)

    # Here we find the gains, discounts and ideal DCGs per slate.
    pos_idxs = torch.arange(1, y_pred.shape[1] + 1).to(device)
    D = torch.log2(1. + pos_idxs.float())[None, :]
    maxDCGs = torch.sum((torch.pow(2, y_true_sorted) - 1) / D, dim=-1).clamp(min=eps)
    G = (torch.pow(2, true_sorted_by_preds) - 1) / maxDCGs[:, None]

    # Here we approximate the ranking positions according to Eqs 19-20 and later approximate NDCG (Eq 21)
    scores_diffs = (y_pred_sorted[:, :, None] - y_pred_sorted[:, None, :])
    scores_diffs[~padded_pairs_mask] = 0.
    approx_pos = 1. + torch.sum(padded_pairs_mask.float() * (torch.sigmoid(-alpha * scores_diffs).clamp(min=eps)),
                                dim=-1)
    approx_D = torch.log2(1. + approx_pos)
    approx_NDCG = torch.sum((G / approx_D), dim=-1)

    return -torch.mean(approx_NDCG)
def listNet(y_pred, y_true, eps=DEFAULT_EPS, padded_value_indicator=PADDED_Y_VALUE):
    """
    ListNet loss introduced in "Learning to Rank: From Pairwise Approach to Listwise Approach".
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :param eps: epsilon value, used for numerical stability
    :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
    :return: loss value, a torch.Tensor
    """
    y_pred = y_pred.clone()
    y_true = y_true.clone()

    mask = y_true == padded_value_indicator
    y_pred[mask] = float('-inf')
    y_true[mask] = float('-inf')

    preds_smax = F.softmax(y_pred, dim=1)
    true_smax = F.softmax(y_true, dim=1)

    preds_smax = preds_smax + eps
    preds_log = torch.log(preds_smax)

    return torch.mean(-torch.sum(true_smax * preds_log, dim=1))
def deterministic_neural_sort(s, tau, mask):
    """
    Deterministic neural sort.
    Code taken from "Stochastic Optimization of Sorting Networks via Continuous Relaxations", ICLR 2019.
    Minor modifications applied to the original code (masking).
    :param s: values to sort, shape [batch_size, slate_length]
    :param tau: temperature for the final softmax function
    :param mask: mask indicating padded elements
    :return: approximate permutation matrices of shape [batch_size, slate_length, slate_length]
    """
    dev = s.device

    n = s.size()[1]
    one = torch.ones((n, 1), dtype=torch.float32, device=dev)
    s = s.masked_fill(mask[:, :, None], -1e8)
    A_s = torch.abs(s - s.permute(0, 2, 1))
    A_s = A_s.masked_fill(mask[:, :, None] | mask[:, None, :], 0.0)

    B = torch.matmul(A_s, torch.matmul(one, torch.transpose(one, 0, 1)))

    temp = [n - m + 1 - 2 * (torch.arange(n - m, device=dev) + 1) for m in mask.squeeze(-1).sum(dim=1)]
    temp = [t.type(torch.float32) for t in temp]
    temp = [torch.cat((t, torch.zeros(n - len(t), device=dev))) for t in temp]
    scaling = torch.stack(temp).type(torch.float32).to(dev)  # type: ignore

    s = s.masked_fill(mask[:, :, None], 0.0)
    C = torch.matmul(s, scaling.unsqueeze(-2))

    P_max = (C - B).permute(0, 2, 1)
    P_max = P_max.masked_fill(mask[:, :, None] | mask[:, None, :], -np.inf)
    P_max = P_max.masked_fill(mask[:, :, None] & mask[:, None, :], 1.0)
    sm = torch.nn.Softmax(-1)
    P_hat = sm(P_max / tau)
    return P_hat
def sample_gumbel(samples_shape, device, eps=1e-10) -> torch.Tensor:
    """
    Sampling from the Gumbel distribution.
    Code taken from "Stochastic Optimization of Sorting Networks via Continuous Relaxations", ICLR 2019.
    Minor modifications applied to the original code (masking).
    :param samples_shape: shape of the output samples tensor
    :param device: device of the output samples tensor
    :param eps: epsilon for the logarithm function
    :return: Gumbel samples tensor of shape samples_shape
    """
    U = torch.rand(samples_shape, device=device)
    return -torch.log(-torch.log(U + eps) + eps)


def apply_mask_and_get_true_sorted_by_preds(y_pred, y_true, padding_indicator=PADDED_Y_VALUE):
    mask = y_true == padding_indicator

    y_pred[mask] = float('-inf')
    y_true[mask] = 0.0

    _, indices = y_pred.sort(descending=True, dim=-1)
    return torch.gather(y_true, dim=1, index=indices)
def dcg(y_pred, y_true, ats=None, gain_function=lambda x: torch.pow(2, x) - 1, padding_indicator=PADDED_Y_VALUE):
    """
    Discounted Cumulative Gain at k.
    Compute DCG at ranks given by ats or at the maximum rank if ats is None.
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :param ats: optional list of ranks for DCG evaluation, if None, maximum rank is used
    :param gain_function: callable, gain function for the ground truth labels, e.g. torch.pow(2, x) - 1
    :param padding_indicator: an indicator of the y_true index containing a padded item, e.g. -1
    :return: DCG values for each slate and evaluation position, shape [batch_size, len(ats)]
    """
    y_true = y_true.clone()
    y_pred = y_pred.clone()

    actual_length = y_true.shape[1]

    if ats is None:
        ats = [actual_length]
    ats = [min(at, actual_length) for at in ats]

    true_sorted_by_preds = apply_mask_and_get_true_sorted_by_preds(y_pred, y_true, padding_indicator)

    discounts = (torch.tensor(1) / torch.log2(torch.arange(true_sorted_by_preds.shape[1], dtype=torch.float) + 2.0)).to(
        device=true_sorted_by_preds.device)

    gains = gain_function(true_sorted_by_preds)

    discounted_gains = (gains * discounts)[:, :np.max(ats)]

    cum_dcg = torch.cumsum(discounted_gains, dim=1)

    ats_tensor = torch.tensor(ats, dtype=torch.long) - torch.tensor(1)

    dcg = cum_dcg[:, ats_tensor]

    return dcg
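A tiny worked example of the discounting: with relevances [3, 1, 0] ranked perfectly, DCG@3 = (2^3-1)/log2(2) + (2^1-1)/log2(3) + 0 ≈ 7.63.

```python
import torch

y_pred = torch.tensor([[0.9, 0.5, 0.1]])  # already in the ideal order
y_true = torch.tensor([[3.0, 1.0, 0.0]])
print(dcg(y_pred, y_true, ats=[1, 3]))    # tensor([[7.0000, 7.6309]])
```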
def sinkhorn_scaling(mat, mask=None, tol=1e-6, max_iter=50):
    """
    Sinkhorn scaling procedure.
    :param mat: a tensor of square matrices of shape N x M x M, where N is batch size
    :param mask: a tensor of masks of shape N x M
    :param tol: Sinkhorn scaling tolerance
    :param max_iter: maximum number of iterations of the Sinkhorn scaling
    :return: a tensor of (approximately) doubly stochastic matrices
    """
    if mask is not None:
        mat = mat.masked_fill(mask[:, None, :] | mask[:, :, None], 0.0)
        mat = mat.masked_fill(mask[:, None, :] & mask[:, :, None], 1.0)

    for _ in range(max_iter):
        mat = mat / mat.sum(dim=1, keepdim=True).clamp(min=DEFAULT_EPS)
        mat = mat / mat.sum(dim=2, keepdim=True).clamp(min=DEFAULT_EPS)

        if torch.max(torch.abs(mat.sum(dim=2) - 1.)) < tol and torch.max(torch.abs(mat.sum(dim=1) - 1.)) < tol:
            break

    if mask is not None:
        mat = mat.masked_fill(mask[:, None, :] | mask[:, :, None], 0.0)

    return mat
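A quick sanity check of the scaling, assuming no padding: after convergence every row and column of each matrix sums to approximately 1 (doubly stochastic):

```python
import torch

mat = torch.rand(2, 5, 5) + 1e-3  # batch of positive square matrices
ds = sinkhorn_scaling(mat)
print(ds.sum(dim=1))  # all entries ≈ 1 (column sums)
print(ds.sum(dim=2))  # all entries ≈ 1 (row sums)
```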
def stochastic_neural_sort(s, n_samples, tau, mask, beta=1.0, log_scores=True, eps=1e-10):
    """
    Stochastic neural sort. Please note that memory complexity grows by factor n_samples.
    Code taken from "Stochastic Optimization of Sorting Networks via Continuous Relaxations", ICLR 2019.
    Minor modifications applied to the original code (masking).
    :param s: values to sort, shape [batch_size, slate_length]
    :param n_samples: number of samples (approximations) for each permutation matrix
    :param tau: temperature for the final softmax function
    :param mask: mask indicating padded elements
    :param beta: scale parameter for the Gumbel distribution
    :param log_scores: whether to apply the logarithm function to scores prior to Gumbel perturbation
    :param eps: epsilon for the logarithm function
    :return: approximate permutation matrices of shape [n_samples, batch_size, slate_length, slate_length]
    """
    dev = s.device

    batch_size = s.size()[0]
    n = s.size()[1]
    s_positive = s + torch.abs(s.min())
    samples = beta * sample_gumbel([n_samples, batch_size, n, 1], device=dev)
    if log_scores:
        s_positive = torch.log(s_positive + eps)

    s_perturb = (s_positive + samples).view(n_samples * batch_size, n, 1)
    mask_repeated = mask.repeat_interleave(n_samples, dim=0)

    P_hat = deterministic_neural_sort(s_perturb, tau, mask_repeated)
    P_hat = P_hat.view(n_samples, batch_size, n, n)
    return P_hat
def neuralNDCG(y_pred, y_true, padded_value_indicator=PADDED_Y_VALUE, temperature=1., powered_relevancies=True, k=None,
               stochastic=False, n_samples=32, beta=0.1, log_scores=True):
    """
    NeuralNDCG loss introduced in "NeuralNDCG: Direct Optimisation of a Ranking Metric via Differentiable
    Relaxation of Sorting" - https://arxiv.org/abs/2102.07831. Based on the NeuralSort algorithm.
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
    :param temperature: temperature for the NeuralSort algorithm
    :param powered_relevancies: whether to apply the 2^x - 1 gain function, x otherwise
    :param k: rank at which the loss is truncated
    :param stochastic: whether to calculate the stochastic variant
    :param n_samples: how many stochastic samples are taken, used if stochastic == True
    :param beta: beta parameter for the NeuralSort algorithm, used if stochastic == True
    :param log_scores: log_scores parameter for the NeuralSort algorithm, used if stochastic == True
    :return: loss value, a torch.Tensor
    """
    dev = y_pred.device

    if k is None:
        k = y_true.shape[1]

    mask = (y_true == padded_value_indicator)
    # Choose the deterministic/stochastic variant
    if stochastic:
        P_hat = stochastic_neural_sort(y_pred.unsqueeze(-1), n_samples=n_samples, tau=temperature, mask=mask,
                                       beta=beta, log_scores=log_scores)
    else:
        P_hat = deterministic_neural_sort(y_pred.unsqueeze(-1), tau=temperature, mask=mask).unsqueeze(0)

    # Perform Sinkhorn scaling to obtain doubly stochastic permutation matrices
    P_hat = sinkhorn_scaling(P_hat.view(P_hat.shape[0] * P_hat.shape[1], P_hat.shape[2], P_hat.shape[3]),
                             mask.repeat_interleave(P_hat.shape[0], dim=0), tol=1e-6, max_iter=50)
    P_hat = P_hat.view(int(P_hat.shape[0] / y_pred.shape[0]), y_pred.shape[0], P_hat.shape[1], P_hat.shape[2])

    # Mask P_hat and apply it to the true labels, i.e. approximately sort them
    P_hat = P_hat.masked_fill(mask[None, :, :, None] | mask[None, :, None, :], 0.)
    y_true_masked = y_true.masked_fill(mask, 0.).unsqueeze(-1).unsqueeze(0)
    if powered_relevancies:
        y_true_masked = torch.pow(2., y_true_masked) - 1.

    ground_truth = torch.matmul(P_hat, y_true_masked).squeeze(-1)
    discounts = (torch.tensor(1.) / torch.log2(torch.arange(y_true.shape[-1], dtype=torch.float) + 2.)).to(dev)
    discounted_gains = ground_truth * discounts

    if powered_relevancies:
        idcg = dcg(y_true, y_true, ats=[k]).permute(1, 0)
    else:
        idcg = dcg(y_true, y_true, ats=[k], gain_function=lambda x: x).permute(1, 0)

    discounted_gains = discounted_gains[:, :, :k]
    ndcg = discounted_gains.sum(dim=-1) / (idcg + DEFAULT_EPS)
    idcg_mask = idcg == 0.
    ndcg = ndcg.masked_fill(idcg_mask.repeat(ndcg.shape[0], 1), 0.)

    assert (ndcg < 0.).sum() == 0, "every ndcg should be non-negative"
    if idcg_mask.all():
        return torch.tensor(0.)

    mean_ndcg = ndcg.sum() / ((~idcg_mask).sum() * ndcg.shape[0])  # type: ignore
    return -1. * mean_ndcg  # -1 because we want to maximize NDCG
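Putting the pieces together, a minimal call on one padded slate (arbitrary values; `-1` is padding). The returned value is `-NDCG`, so minimizing the loss maximizes NDCG:

```python
import torch

y_pred = torch.tensor([[2.0, 1.0, 0.5, 0.0]], requires_grad=True)
y_true = torch.tensor([[2.0, 0.0, 1.0, -1.0]])
loss = neuralNDCG(y_pred, y_true, temperature=1.0)
loss.backward()
```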
def neuralNDCG_transposed(y_pred, y_true, padded_value_indicator=PADDED_Y_VALUE, temperature=1.,
                          powered_relevancies=True, k=None, stochastic=False, n_samples=32, beta=0.1, log_scores=True,
                          max_iter=50, tol=1e-6):
    """
    NeuralNDCG Transposed loss introduced in "NeuralNDCG: Direct Optimisation of a Ranking Metric via Differentiable
    Relaxation of Sorting" - https://arxiv.org/abs/2102.07831. Based on the NeuralSort algorithm.
    :param y_pred: predictions from the model, shape [batch_size, slate_length]
    :param y_true: ground truth labels, shape [batch_size, slate_length]
    :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
    :param temperature: temperature for the NeuralSort algorithm
    :param powered_relevancies: whether to apply the 2^x - 1 gain function, x otherwise
    :param k: rank at which the loss is truncated
    :param stochastic: whether to calculate the stochastic variant
    :param n_samples: how many stochastic samples are taken, used if stochastic == True
    :param beta: beta parameter for the NeuralSort algorithm, used if stochastic == True
    :param log_scores: log_scores parameter for the NeuralSort algorithm, used if stochastic == True
    :param max_iter: maximum iteration count for Sinkhorn scaling
    :param tol: tolerance for Sinkhorn scaling
    :return: loss value, a torch.Tensor
    """
    dev = y_pred.device

    if k is None:
        k = y_true.shape[1]

    mask = (y_true == padded_value_indicator)

    if stochastic:
        P_hat = stochastic_neural_sort(y_pred.unsqueeze(-1), n_samples=n_samples, tau=temperature, mask=mask,
                                       beta=beta, log_scores=log_scores)
    else:
        P_hat = deterministic_neural_sort(y_pred.unsqueeze(-1), tau=temperature, mask=mask).unsqueeze(0)

    # Perform Sinkhorn scaling to obtain doubly stochastic permutation matrices
    P_hat_masked = sinkhorn_scaling(P_hat.view(P_hat.shape[0] * y_pred.shape[0], y_pred.shape[1], y_pred.shape[1]),
                                    mask.repeat_interleave(P_hat.shape[0], dim=0), tol=tol, max_iter=max_iter)
    P_hat_masked = P_hat_masked.view(P_hat.shape[0], y_pred.shape[0], y_pred.shape[1], y_pred.shape[1])
    discounts = (torch.tensor(1) / torch.log2(torch.arange(y_true.shape[-1], dtype=torch.float) + 2.)).to(dev)

    # This takes care of the @k metric truncation - if something is @>k, it is useless and gets a 0.0 discount
    discounts[k:] = 0.
    discounts = discounts[None, None, :, None]

    # Here the discounts become expected discounts
    discounts = torch.matmul(P_hat_masked.permute(0, 1, 3, 2), discounts).squeeze(-1)
    if powered_relevancies:
        gains = torch.pow(2., y_true) - 1
        discounted_gains = gains.unsqueeze(0) * discounts
        idcg = dcg(y_true, y_true, ats=[k]).squeeze()
    else:
        gains = y_true
        discounted_gains = gains.unsqueeze(0) * discounts
        idcg = dcg(y_true, y_true, ats=[k]).squeeze()

    ndcg = discounted_gains.sum(dim=2) / (idcg + DEFAULT_EPS)
    idcg_mask = idcg == 0.
    ndcg = ndcg.masked_fill(idcg_mask, 0.)

    assert (ndcg < 0.).sum() == 0, "every ndcg should be non-negative"
    if idcg_mask.all():
        return torch.tensor(0.)

    mean_ndcg = ndcg.sum() / ((~idcg_mask).sum() * ndcg.shape[0])  # type: ignore
    return -1. * mean_ndcg  # -1 because we want to maximize NDCG
utils/modules.py (new file, 41 lines)
from collections import Counter, defaultdict
import logging
from typing import Union, List, Dict, Any
import torch
from torch import nn


class Identity(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x


class Reshaper(nn.Module):
    def __init__(self, *output_shape):
        super().__init__()
        self.output_shape = output_shape

    def forward(self, input: torch.Tensor):
        return input.view(*self.output_shape)


class Normalizer(nn.Module):
    def __init__(self, target_norm=1.):
        super().__init__()
        self.target_norm = target_norm

    def forward(self, input: torch.Tensor):
        return input * self.target_norm / input.norm(p=2, dim=1, keepdim=True)


class Squeezer(nn.Module):
    def __init__(self, dim=-1):
        super().__init__()
        self.dim = dim

    def forward(self, input):
        return torch.squeeze(input, dim=self.dim)
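These wrappers exist so tensor-shape plumbing can live inside `nn.Sequential`; a sketch with illustrative shapes:

```python
import torch
from torch import nn

head = nn.Sequential(
    nn.Linear(768, 1),  # (B, N, 768) -> (B, N, 1)
    Squeezer(dim=-1),   # (B, N, 1)   -> (B, N)
)
scores = head(torch.randn(4, 100, 768))  # -> shape (4, 100)
```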
utils/optim_utils.py (new file, 389 lines)
import logging
import math
import numpy as np
import random
import functools
import glog as log

import torch
from torch import nn, optim
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler, ConstantLR
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
from pytorch_transformers.optimization import AdamW


class WarmupLinearScheduleNonZero(_LRScheduler):
    """ Linear warmup and then linear decay.
        Linearly increases the learning rate from 0 to max_lr over `warmup_steps` training steps.
        Linearly decreases the learning rate to min_lr over the remaining `t_total - warmup_steps` steps.
    """
    def __init__(self, optimizer, warmup_steps, t_total, min_lr=1e-5, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.t_total = t_total
        self.min_lr = min_lr
        super(WarmupLinearScheduleNonZero, self).__init__(optimizer, last_epoch=last_epoch)

    def get_lr(self):
        step = self.last_epoch
        if step < self.warmup_steps:
            lr_factor = float(step) / float(max(1, self.warmup_steps))
        else:
            lr_factor = max(0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps)))

        return [base_lr * lr_factor if (base_lr * lr_factor) > self.min_lr else self.min_lr for base_lr in self.base_lrs]
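Reading off a few points of the schedule makes the "non-zero" part visible: with a base learning rate of 2e-5, `warmup_steps=1000`, `t_total=10000` and `min_lr=1e-5` (all hypothetical values), the floor applies during warmup as well:

```python
import torch

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=2e-5)
sched = WarmupLinearScheduleNonZero(opt, warmup_steps=1000, t_total=10000, min_lr=1e-5)
for step in [100, 1000, 5500, 9999]:
    sched.last_epoch = step  # probe get_lr directly, for illustration only
    print(step, sched.get_lr())
# 100  -> factor 0.1 -> 2e-6, floored to 1e-5
# 1000 -> factor 1.0 -> 2e-5 (peak)
# 5500 -> factor 0.5 -> 1e-5 (at the floor)
# 9999 -> factor ~0  -> 1e-5
```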
def init_optim(model, config):
    gnn_params = []

    encoder_params_with_decay = []
    encoder_params_without_decay = []

    exclude_from_weight_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

    for module_name, module in model.named_children():
        for param_name, param in module.named_parameters():
            if param.requires_grad:
                if "gnn" in param_name:
                    gnn_params.append(param)
                elif module_name == 'encoder':
                    if any(ex in param_name for ex in exclude_from_weight_decay):
                        encoder_params_without_decay.append(param)
                    else:
                        encoder_params_with_decay.append(param)

    optimizer_grouped_parameters = [
        {
            'params': gnn_params,
            'weight_decay': config.gnn_weight_decay,
            'lr': config['learning_rate_gnn'] if config.use_diff_lr_gnn else config['learning_rate_bert']
        }
    ]

    optimizer_grouped_parameters.extend(
        [
            {
                'params': encoder_params_without_decay,
                'weight_decay': 0,
                'lr': config['learning_rate_bert']
            },
            {
                'params': encoder_params_with_decay,
                'weight_decay': 0.01,
                'lr': config['learning_rate_bert']
            }
        ]
    )
    optimizer = AdamW(optimizer_grouped_parameters, lr=config['learning_rate_gnn'])
    scheduler = WarmupLinearScheduleNonZero(
        optimizer,
        warmup_steps=config['warmup_steps'],
        t_total=config['train_steps'],
        min_lr=config['min_lr']
    )

    return optimizer, scheduler
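The optimizer/scheduler pair is then driven once per iteration; a sketch of the loop (the `model` and `loader` objects are placeholders, not names from this repository):

```python
optimizer, scheduler = init_optim(model, config)
for batch in loader:
    loss = model(batch)   # placeholder forward pass returning a loss tensor
    loss.backward()
    optimizer.step()
    scheduler.step()      # per-iteration schedule, not per-epoch
    optimizer.zero_grad()
```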
def build_torch_optimizer(model, config):
    """Builds the PyTorch optimizer.

    We use the default parameters for Adam that are suggested by the original
    paper https://arxiv.org/pdf/1412.6980.pdf.
    These values are also used by other established implementations, e.g.
    https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer
    and https://keras.io/optimizers/.
    More recently, slightly different values were used in the paper
    "Attention is all you need" (https://arxiv.org/pdf/1706.03762.pdf), in
    particular beta2=0.98. However, beta2=0.999 is still arguably the more
    established value, so we use it here as well.

    Args:
        model: The model to optimize.
        config: The dictionary of options.

    Returns:
        A ``torch.optim.Optimizer`` instance.
    """
    betas = [0.9, 0.999]
    exclude_from_weight_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

    params = {'bert': [], 'task': []}
    for module_name, module in model.named_children():
        if module_name == 'encoder':
            param_type = 'bert'
        else:
            param_type = 'task'
        for param_name, param in module.named_parameters():
            if param.requires_grad:
                if any(ex in param_name for ex in exclude_from_weight_decay):
                    params[param_type] += [
                        {
                            "params": [param],
                            "weight_decay": 0
                        }
                    ]
                else:
                    params[param_type] += [
                        {
                            "params": [param],
                            "weight_decay": 0.01
                        }
                    ]
    if config['task_optimizer'] == 'adamw':
        log.info('Using AdamW as task optimizer')
        task_optimizer = AdamWeightDecay(params['task'],
                                         lr=config["learning_rate_task"],
                                         betas=betas,
                                         eps=1e-6)
    elif config['task_optimizer'] == 'adam':
        log.info('Using Adam as task optimizer')
        task_optimizer = optim.Adam(params['task'],
                                    lr=config["learning_rate_task"],
                                    betas=betas,
                                    eps=1e-6)
    if len(params['bert']) > 0:
        bert_optimizer = AdamWeightDecay(params['bert'],
                                         lr=config["learning_rate_bert"],
                                         betas=betas,
                                         eps=1e-6)
        optimizer = MultipleOptimizer([bert_optimizer, task_optimizer])
    else:
        optimizer = task_optimizer

    return optimizer
def make_learning_rate_decay_fn(decay_method, train_steps, **kwargs):
    """Returns the learning rate decay function from options."""
    if decay_method == "linear":
        return functools.partial(
            linear_decay,
            global_steps=train_steps,
            **kwargs)
    elif decay_method == "exp":
        return functools.partial(
            exp_decay,
            global_steps=train_steps,
            **kwargs)
    else:
        raise ValueError(f'{decay_method} not found')


def linear_decay(step, global_steps, warmup_steps=100, initial_learning_rate=1, end_learning_rate=0, **kwargs):
    if step < warmup_steps:
        return initial_learning_rate * step / warmup_steps
    else:
        return (initial_learning_rate - end_learning_rate) * \
               (1 - (step - warmup_steps) / (global_steps - warmup_steps)) + \
               end_learning_rate


def exp_decay(step, global_steps, decay_exp=1, warmup_steps=100, initial_learning_rate=1, end_learning_rate=0, **kwargs):
    if step < warmup_steps:
        return initial_learning_rate * step / warmup_steps
    else:
        return (initial_learning_rate - end_learning_rate) * \
               ((1 - (step - warmup_steps) / (global_steps - warmup_steps)) ** decay_exp) + \
               end_learning_rate
class MultipleOptimizer(object):
    """ Implement multiple optimizers needed for sparse adam """

    def __init__(self, op):
        """Wraps a list of optimizers that are stepped together."""
        self.optimizers = op

    @property
    def param_groups(self):
        param_groups = []
        for optimizer in self.optimizers:
            param_groups.extend(optimizer.param_groups)
        return param_groups

    def zero_grad(self):
        """Zero the gradients of all wrapped optimizers."""
        for op in self.optimizers:
            op.zero_grad()

    def step(self):
        """Step all wrapped optimizers."""
        for op in self.optimizers:
            op.step()

    @property
    def state(self):
        """The merged state of all wrapped optimizers."""
        return {k: v for op in self.optimizers for k, v in op.state.items()}

    def state_dict(self):
        """A list with one state dict per wrapped optimizer."""
        return [op.state_dict() for op in self.optimizers]

    def load_state_dict(self, state_dicts):
        """Restore each wrapped optimizer from its own state dict."""
        assert len(state_dicts) == len(self.optimizers)
        for i in range(len(state_dicts)):
            self.optimizers[i].load_state_dict(state_dicts[i])
class OptimizerBase(object):
    """
    Controller class for optimization. Mostly a thin
    wrapper for `optim`, but also useful for implementing
    rate scheduling beyond what is currently available.
    Also implements necessary methods for training RNNs such
    as grad manipulations.
    """

    def __init__(self,
                 optimizer,
                 learning_rate,
                 learning_rate_decay_fn=None,
                 max_grad_norm=None):
        """Initializes the controller.

        Args:
            optimizer: A ``torch.optim.Optimizer`` instance.
            learning_rate: The initial learning rate.
            learning_rate_decay_fn: An optional callable taking the current
                step as argument and returning a learning rate scaling factor.
            max_grad_norm: Clip gradients to this global norm.
        """
        self._optimizer = optimizer
        self._learning_rate = learning_rate
        self._learning_rate_decay_fn = learning_rate_decay_fn
        self._max_grad_norm = max_grad_norm or 0
        self._training_step = 1
        self._decay_step = 1

    @classmethod
    def from_opt(cls, model, config, checkpoint=None):
        """Builds the optimizer from options.

        Args:
            cls: The ``Optimizer`` class to instantiate.
            model: The model to optimize.
            config: The dict of user options.
            checkpoint: An optional checkpoint to load states from.

        Returns:
            An ``Optimizer`` instance.
        """
        optim_opt = config
        optim_state_dict = None

        if config["loads_ckpt"] and checkpoint is not None:
            optim = checkpoint['optim']
            ckpt_opt = checkpoint['opt']
            ckpt_state_dict = {}
            if isinstance(optim, Optimizer):  # Backward compatibility.
                ckpt_state_dict['training_step'] = optim._step + 1
                ckpt_state_dict['decay_step'] = optim._step + 1
                ckpt_state_dict['optimizer'] = optim.optimizer.state_dict()
            else:
                ckpt_state_dict = optim

            if config["reset_optim"] == 'none':
                # Load everything from the checkpoint.
                optim_opt = ckpt_opt
                optim_state_dict = ckpt_state_dict
            elif config["reset_optim"] == 'all':
                # Build everything from scratch.
                pass
            elif config["reset_optim"] == 'states':
                # Reset optimizer, keep options.
                optim_opt = ckpt_opt
                optim_state_dict = ckpt_state_dict
                del optim_state_dict['optimizer']
            elif config["reset_optim"] == 'keep_states':
                # Reset options, keep optimizer.
                optim_state_dict = ckpt_state_dict

        learning_rates = [
            optim_opt["learning_rate_bert"],
            optim_opt["learning_rate_gnn"]
        ]
        decay_fn = [
            make_learning_rate_decay_fn(optim_opt['decay_method_bert'],
                                        optim_opt['train_steps'],
                                        warmup_steps=optim_opt['warmup_steps'],
                                        decay_exp=optim_opt['decay_exp']),
            make_learning_rate_decay_fn(optim_opt['decay_method_gnn'],
                                        optim_opt['train_steps'],
                                        warmup_steps=optim_opt['warmup_steps'],
                                        decay_exp=optim_opt['decay_exp']),
        ]
        optimizer = cls(
            build_torch_optimizer(model, optim_opt),
            learning_rates,
            learning_rate_decay_fn=decay_fn,
            max_grad_norm=optim_opt["max_grad_norm"])
        if optim_state_dict:
            optimizer.load_state_dict(optim_state_dict)
        return optimizer

    @property
    def training_step(self):
        """The current training step."""
        return self._training_step

    def learning_rate(self):
        """Returns the current learning rate."""
        if self._learning_rate_decay_fn is None:
            return self._learning_rate
        return [decay_fn(self._decay_step) * learning_rate
                for decay_fn, learning_rate in
                zip(self._learning_rate_decay_fn, self._learning_rate)]

    def state_dict(self):
        return {
            'training_step': self._training_step,
            'decay_step': self._decay_step,
            'optimizer': self._optimizer.state_dict()
        }

    def load_state_dict(self, state_dict):
        self._training_step = state_dict['training_step']
        # State can be partially restored.
        if 'decay_step' in state_dict:
            self._decay_step = state_dict['decay_step']
        if 'optimizer' in state_dict:
            self._optimizer.load_state_dict(state_dict['optimizer'])

    def zero_grad(self):
        """Zero the gradients of optimized parameters."""
        self._optimizer.zero_grad()

    def backward(self, loss):
        """Wrapper for the backward pass. Some optimizers require ownership
        of the backward pass."""
        loss.backward()

    def step(self):
        """Update the model parameters based on current gradients.

        Optionally, will employ gradient modification or update the learning
        rate.
        """
        learning_rate = self.learning_rate()

        if isinstance(self._optimizer, MultipleOptimizer):
            optimizers = self._optimizer.optimizers
        else:
            optimizers = [self._optimizer]
        for lr, op in zip(learning_rate, optimizers):
            for group in op.param_groups:
                group['lr'] = lr
                if self._max_grad_norm > 0:
                    clip_grad_norm_(group['params'], self._max_grad_norm)
        self._optimizer.step()
        self._decay_step += 1
        self._training_step += 1
utils/visdial_metrics.py (new file, 322 lines)
"""
A Metric observes the output of a certain model, for example in the form of
logits or scores, and accumulates a particular metric with reference to some
provided targets. In the context of VisDial, we use Recall (@ 1, 5, 10), Mean
Rank, Mean Reciprocal Rank (MRR) and Normalized Discounted Cumulative Gain
(NDCG).

Each ``Metric`` must implement at least three methods:
    - ``observe``, update the accumulated metric with currently observed
      outputs and targets.
    - ``retrieve``, return the accumulated metric and optionally reset the
      internally accumulated metric (this is commonly done between two epochs
      after validation).
    - ``reset``, explicitly reset the internally accumulated metric.

Caveat: if you wish to implement your own class of Metric, make sure you call
``detach`` on output tensors (like logits), else it will cause memory leaks.
"""
import torch
import torch.distributed as dist
import numpy as np

def scores_to_ranks(scores: torch.Tensor):
    """Convert model output scores into ranks."""
    batch_size, num_rounds, num_options = scores.size()
    scores = scores.view(-1, num_options)

    # sort in descending order - largest score gets highest rank
    sorted_ranks, ranked_idx = scores.sort(1, descending=True)

    # the i-th position in ranked_idx specifies which score shall take this
    # position, but we want the i-th position to hold the rank of the score
    # at that position; do this conversion
    ranks = ranked_idx.clone().fill_(0)
    for i in range(ranked_idx.size(0)):
        for j in range(num_options):
            ranks[i][ranked_idx[i][j]] = j
    # convert from 0-99 ranks to 1-100 ranks
    ranks += 1
    ranks = ranks.view(batch_size, num_rounds, num_options)
    return ranks
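A tiny worked example of the conversion (one dialog, one round, four options): the option with the highest score receives rank 1.

```python
import torch

scores = torch.tensor([[[0.1, 0.7, 0.3, 0.2]]])  # (batch=1, rounds=1, options=4)
print(scores_to_ranks(scores))  # tensor([[[4, 1, 2, 3]]])
```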
class SparseGTMetrics(object):
    """
    A class to accumulate all metrics with sparse ground truth annotations.
    These include Recall (@ 1, 5, 10), Mean Rank and Mean Reciprocal Rank.
    """

    def __init__(self):
        self._rank_list = []
        self._rank_list_rnd = []
        self.num_rounds = None

    def observe(
        self, predicted_scores: torch.Tensor, target_ranks: torch.Tensor
    ):
        predicted_scores = predicted_scores.detach()

        # shape: (batch_size, num_rounds, num_options)
        predicted_ranks = scores_to_ranks(predicted_scores)
        batch_size, num_rounds, num_options = predicted_ranks.size()
        self.num_rounds = num_rounds
        # collapse batch dimension
        predicted_ranks = predicted_ranks.view(
            batch_size * num_rounds, num_options
        )

        # shape: (batch_size * num_rounds, )
        target_ranks = target_ranks.view(batch_size * num_rounds).long()

        # shape: (batch_size * num_rounds, )
        predicted_gt_ranks = predicted_ranks[
            torch.arange(batch_size * num_rounds), target_ranks
        ]
        self._rank_list.extend(list(predicted_gt_ranks.cpu().numpy()))

        predicted_gt_ranks_rnd = predicted_gt_ranks.view(batch_size, num_rounds)
        # per-round predicted gt ranks
        self._rank_list_rnd.append(predicted_gt_ranks_rnd.cpu().numpy())

    def retrieve(self, reset: bool = True):
        num_examples = len(self._rank_list)
        if num_examples > 0:
            # convert to a tensor for easy calculation.
            __rank_list = torch.tensor(self._rank_list).float()
            metrics = {
                "r@1": torch.mean((__rank_list <= 1).float()).item(),
                "r@5": torch.mean((__rank_list <= 5).float()).item(),
                "r@10": torch.mean((__rank_list <= 10).float()).item(),
                "mean": torch.mean(__rank_list).item(),
                "mrr": torch.mean(__rank_list.reciprocal()).item()
            }
            # add round metrics
            _rank_list_rnd = np.concatenate(self._rank_list_rnd)
            _rank_list_rnd = _rank_list_rnd.astype(float)
            r_1_rnd = np.mean(_rank_list_rnd <= 1, axis=0)
            r_5_rnd = np.mean(_rank_list_rnd <= 5, axis=0)
            r_10_rnd = np.mean(_rank_list_rnd <= 10, axis=0)
            mean_rnd = np.mean(_rank_list_rnd, axis=0)
            mrr_rnd = np.mean(np.reciprocal(_rank_list_rnd), axis=0)

            for rnd in range(1, self.num_rounds + 1):
                metrics["r_1" + "_round_" + str(rnd)] = r_1_rnd[rnd - 1]
                metrics["r_5" + "_round_" + str(rnd)] = r_5_rnd[rnd - 1]
                metrics["r_10" + "_round_" + str(rnd)] = r_10_rnd[rnd - 1]
                metrics["mean" + "_round_" + str(rnd)] = mean_rnd[rnd - 1]
                metrics["mrr" + "_round_" + str(rnd)] = mrr_rnd[rnd - 1]
        else:
            metrics = {}

        if reset:
            self.reset()
        return metrics

    def reset(self):
        self._rank_list = []
        self._rank_list_rnd = []
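A sketch of the intended evaluation loop (the dataloader, model call and batch key are placeholders; tensor shapes follow the comments above):

```python
metric = SparseGTMetrics()
for batch in val_loader:                       # placeholder eval dataloader
    scores = model(batch)                      # (batch, rounds, options)
    metric.observe(scores, batch['gt_index'])  # index of the ground-truth option
print(metric.retrieve(reset=True))             # {'r@1': ..., 'mrr': ..., per-round keys}
```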
class NDCG(object):
    def __init__(self):
        self._ndcg_numerator = 0.0
        self._ndcg_denominator = 0.0

    def observe(
        self, predicted_scores: torch.Tensor, target_relevance: torch.Tensor
    ):
        """
        Observe model output scores and target ground truth relevance and
        accumulate NDCG metric.

        Parameters
        ----------
        predicted_scores: torch.Tensor
            A tensor of shape (batch_size, num_options), because dense
            annotations are available for 1 randomly picked round out of 10.
        target_relevance: torch.Tensor
            A tensor of the same shape as predicted scores, indicating ground
            truth relevance of each answer option for a particular round.
        """
        predicted_scores = predicted_scores.detach()

        # shape: (batch_size, 1, num_options)
        predicted_scores = predicted_scores.unsqueeze(1)
        predicted_ranks = scores_to_ranks(predicted_scores)

        # shape: (batch_size, num_options)
        predicted_ranks = predicted_ranks.squeeze(1)
        batch_size, num_options = predicted_ranks.size()

        k = torch.sum(target_relevance != 0, dim=-1)

        # shape: (batch_size, num_options)
        _, rankings = torch.sort(predicted_ranks, dim=-1)
        # Sort relevance in descending order so the highest relevance gets the
        # top rank.
        _, best_rankings = torch.sort(
            target_relevance, dim=-1, descending=True
        )

        # shape: (batch_size, )
        batch_ndcg = []
        for batch_index in range(batch_size):
            num_relevant = k[batch_index]
            dcg = self._dcg(
                rankings[batch_index][:num_relevant],
                target_relevance[batch_index],
            )
            best_dcg = self._dcg(
                best_rankings[batch_index][:num_relevant],
                target_relevance[batch_index],
            )
            batch_ndcg.append(dcg / best_dcg)

        self._ndcg_denominator += batch_size
        self._ndcg_numerator += sum(batch_ndcg)

    def _dcg(self, rankings: torch.Tensor, relevance: torch.Tensor):
        sorted_relevance = relevance[rankings].cpu().float()
        discounts = torch.log2(torch.arange(len(rankings)).float() + 2)
        return torch.sum(sorted_relevance / discounts, dim=-1)

    def retrieve(self, reset: bool = True):
        if self._ndcg_denominator > 0:
            metrics = {
                "ndcg": float(self._ndcg_numerator / self._ndcg_denominator)
            }
        else:
            metrics = {}

        if reset:
            self.reset()
        return metrics

    def reset(self):
        self._ndcg_numerator = 0.0
        self._ndcg_denominator = 0.0
class SparseGTMetricsParallel(object):
    """
    A class to accumulate all metrics with sparse ground truth annotations.
    These include Recall (@ 1, 5, 10), Mean Rank and Mean Reciprocal Rank.
    """

    def __init__(self, gpu_rank):
        self.rank_1 = 0
        self.rank_5 = 0
        self.rank_10 = 0
        self.ranks = 0
        self.reciprocal = 0
        self.count = 0
        self.gpu_rank = gpu_rank
        self.img_ids = []

    def observe(
        self, img_id: list, predicted_scores: torch.Tensor, target_ranks: torch.Tensor
    ):
        if img_id in self.img_ids:
            return
        else:
            self.img_ids.append(img_id)

        predicted_scores = predicted_scores.detach()

        # shape: (batch_size, num_rounds, num_options)
        predicted_ranks = scores_to_ranks(predicted_scores)
        batch_size, num_rounds, num_options = predicted_ranks.size()
        self.num_rounds = num_rounds
        # collapse batch dimension
        predicted_ranks = predicted_ranks.view(
            batch_size * num_rounds, num_options
        )

        # shape: (batch_size * num_rounds, )
        target_ranks = target_ranks.view(batch_size * num_rounds).long()

        # shape: (batch_size * num_rounds, )
        predicted_gt_ranks = predicted_ranks[
            torch.arange(batch_size * num_rounds), target_ranks
        ]

        self.rank_1 += (predicted_gt_ranks <= 1).sum().item()
        self.rank_5 += (predicted_gt_ranks <= 5).sum().item()
        self.rank_10 += (predicted_gt_ranks <= 10).sum().item()
        self.ranks += predicted_gt_ranks.sum().item()
        self.reciprocal += predicted_gt_ranks.float().reciprocal().sum().item()
        self.count += batch_size * num_rounds

    def retrieve(self):
        if self.count > 0:
            # gather the partial results from all GPUs: each entry of `t`
            # holds this process' local accumulators
            t = torch.tensor([self.rank_1, self.rank_5, self.rank_10, self.ranks, self.reciprocal, self.count],
                             dtype=torch.float32, device=f'cuda:{self.gpu_rank}')
            dist.barrier()  # synchronizes all processes
            dist.all_reduce(t, op=torch.distributed.ReduceOp.SUM)  # sum across all processes so that every process gets the final result
            t = t.tolist()
            self.rank_1, self.rank_5, self.rank_10, self.ranks, self.reciprocal, self.count = t

            metrics = {
                "r@1": self.rank_1 / self.count,
                "r@5": self.rank_5 / self.count,
                "r@10": self.rank_10 / self.count,
                "mean": self.ranks / self.count,
                "mrr": self.reciprocal / self.count,
                "tot_rnds": self.count,
            }
        else:
            metrics = {}

        return metrics

    def get_count(self):
        return int(self.count)
class NDCGParallel(NDCG):
    def __init__(self, gpu_rank):
        super(NDCGParallel, self).__init__()
        self.gpu_rank = gpu_rank
        self.img_ids = []
        self.count = 0

    def observe(
        self, img_id: int, predicted_scores: torch.Tensor, target_relevance: torch.Tensor
    ):
        """
        Observe model output scores and target ground truth relevance and
        accumulate NDCG metric.

        Parameters
        ----------
        predicted_scores: torch.Tensor
            A tensor of shape (batch_size, num_options), because dense
            annotations are available for 1 randomly picked round out of 10.
        target_relevance: torch.Tensor
            A tensor of the same shape as predicted scores, indicating ground
            truth relevance of each answer option for a particular round.
        """
        if img_id in self.img_ids:
            return
        else:
            self.img_ids.append(img_id)
            self.count += 1

        super(NDCGParallel, self).observe(predicted_scores, target_relevance)

    def retrieve(self):
        if self._ndcg_denominator > 0:
            # gather the partial numerator/denominator from all GPUs
            t = torch.tensor([self._ndcg_numerator, self._ndcg_denominator, self.count],
                             dtype=torch.float32, device=f'cuda:{self.gpu_rank}')
            dist.barrier()  # synchronizes all processes
            dist.all_reduce(t, op=torch.distributed.ReduceOp.SUM)  # sum across all processes so that every process gets the final result
            t = t.tolist()
            self._ndcg_numerator, self._ndcg_denominator, self.count = t
            metrics = {
                "ndcg": float(self._ndcg_numerator / self._ndcg_denominator)
            }
        else:
            metrics = {}
        return metrics

    def get_count(self):
        return int(self.count)