commit 09fb25e339b3babf282a35d4a3bd67fafe6955c3 Author: abdessaied Date: Wed Oct 25 15:38:09 2023 +0200 Code release diff --git a/README.md b/README.md new file mode 100644 index 0000000..6ccab22 --- /dev/null +++ b/README.md @@ -0,0 +1,277 @@ +
+

VD-GR: Boosting Visual Dialog with Cascaded Spatial-Temporal Multi-Modal GRaphs

+ +**[Adnen Abdessaied][5],   [Lei Shi][6],   [Andreas Bulling][7]**

+**WACV'24, Hawaii, USA**
+**[[Paper][8]]** + +------------------- +

+ +
+ +# Table of Contents +* [Setup and Dependencies](#Setup-and-Dependencies) +* [Download Data](#Download-Data) +* [Pre-trained Checkpoints](#Pre-trained-Checkpoints) +* [Training](#Training) +* [Results](#Results) + +# Setup and Dependencies +We implemented our model using Python 3.7 and PyTorch 1.11.0 (CUDA 11.3, CuDNN 8.2.0). We recommend setting up a virtual environment using Anaconda.
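+Once the environment from the steps below is set up, a quick check along these lines can confirm that the installed versions match the ones listed above (this snippet is illustrative and not part of the repository):
+```python
+import torch
+import torch_geometric
+
+print(torch.__version__)               # expected: 1.11.0
+print(torch.version.cuda)              # expected: 11.3
+print(torch.backends.cudnn.version())  # e.g. 8200 for CuDNN 8.2.0
+print(torch_geometric.__version__)     # expected: 2.1.0
+print(torch.cuda.is_available())       # training assumes CUDA-capable GPUs
+```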
+1. Install [git lfs][1] on your system +2. Clone our repository to download the data, checkpoints, and code + ```shell + git lfs install + git clone this_repo.git + ``` +3. Create a conda environment and install dependencies + ```shell + conda create -n vdgr python=3.7 + conda activate vdgr + conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch + conda install pyg -c pyg # 2.1.0 + pip install pytorch-transformers + pip install pytorch_pretrained_bert + pip install pyhocon glog wandb lmdb + ``` +4. If you wish to speed up training, we recommend installing [apex][2] + ```shell + git clone https://github.com/NVIDIA/apex + cd apex + # if pip >= 23.1 (ref: https://pip.pypa.io/en/stable/news/#v23-1) which supports multiple `--config-settings` with the same key... + pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ + # otherwise + pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --global-option="--cpp_ext" --global-option="--cuda_ext" ./ + cd .. + ``` + +# Download Data +1. Download the extracted visual features of [VisDial][3] and set up all the files we used in our work. We provide a shell script for convenience: +```shell +./setup_data.sh # Please make sure you have enough disk space +``` +If everything was set up correctly, the ```data/``` folder should look like this: +``` +├── history_adj_matrices +│ ├── test +│ ├── *.pkl +│ ├── train +│ ├── *.pkl +│ ├── val +│ ├── *.pkl +├── question_adj_matrices +│ ├── test +│ ├── *.pkl +│ ├── train +│ ├── *.pkl +│ ├── val +│ ├── *.pkl +├── img_adj_matrices +│ ├── *.pkl +├── parse_vocab.pkl +├── test_dense_mapping.json +├── tr_dense_mapping.json +├── val_dense_mapping.json +├── visdial_0.9_test.json +├── visdial_0.9_train.json +├── visdial_0.9_val.json +├── visdial_1.0_test.json +├── visdial_1.0_train_dense_annotations.json +├── visdial_1.0_train_dense.json +├── visdial_1.0_train.json +├── visdial_1.0_val_dense_annotations.json +├── visdial_1.0_val.json +├── visdialconv_dense_annotations.json +├── visdialconv.json +├── vispro_dense_annotations.json +└── vispro.json +``` +# Pre-trained Checkpoints +For convenience, we provide checkpoints of our model after the warm-up training stage in ```ckpt/``` for both VisDial v1.0 and VisDial v0.9. +These checkpoints will be downloaded with the code if you use ```git lfs```.
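+Before launching training, it can save time to verify that ```./setup_data.sh``` and ```git lfs``` actually materialized the files listed above rather than leaving gaps or small LFS pointer stubs. The sketch below is illustrative and not part of the repository; it assumes the ```data/``` layout shown above and the warm-up checkpoint name referenced by ```start_path``` in ```config/vdgr.conf```:
+```python
+import glob
+import os
+import pickle
+
+import torch
+
+# A few representative files from the data/ layout above.
+for path in ["data/parse_vocab.pkl",
+             "data/visdial_1.0_train.json",
+             "data/visdial_1.0_val_dense_annotations.json"]:
+    assert os.path.isfile(path), f"{path} is missing -- re-run ./setup_data.sh"
+
+# Graphs are stored as one pickle file per image/dialog; make sure they unpickle.
+sample_graph = sorted(glob.glob("data/img_adj_matrices/*.pkl"))[0]
+with open(sample_graph, "rb") as f:
+    print(type(pickle.load(f)))
+
+# Warm-up checkpoint (name as referenced in config/vdgr.conf). Un-pulled git-lfs
+# files are tiny text pointers, so a size check is a cheap guard.
+ckpt_path = "ckpt/vdgr_visdial_v1.0_after_warmup_K2.ckpt"
+assert os.path.getsize(ckpt_path) > 1_000_000, f"{ckpt_path} looks like an LFS pointer stub"
+state = torch.load(ckpt_path, map_location="cpu")
+print(type(state))  # should load on CPU without error
+```
+If any of these checks fail, the usual culprits are an interrupted ```./setup_data.sh``` run or cloning before running ```git lfs install```.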
+ +# Training +We trained our model on 8 Nvidia Tesla V100-32GB GPUs. The default hyperparameters in ```config/vdgr.conf``` and ```config/bert_base_6layer_6conect.json``` need to be adjusted if your setup differs from ours. + +## Phase 1 +### Training +1. In this phase, we train our model on VisDial v1.0 via +```shell +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \ +--model vdgr/P1 \ +--mode train \ +--tag K2_v1.0 \ +--wandb_mode online \ +--wandb_project your_wandb_project_name +``` +⚠️ On a setup similar to ours, this will take roughly 20h to complete using apex for training. + +2. To train on VisDial v0.9: + * Set ```visdial_version = 0.9``` in ```config/vdgr.conf``` + * Set ```start_path = ckpt/vdgr_visdial_v0.9_after_warmup_K2.ckpt``` in ```config/vdgr.conf``` + * Run + ```shell + CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \ + --model vdgr/P1 \ + --mode train \ + --tag K2_v0.9 \ + --wandb_mode online \ + --wandb_project your_wandb_project_name + ``` +### Inference +1. For inference on VisDial v1.0 val, VisDialConv, or VisPro: + * Set ```eval_dataset = {visdial, visdial_conv, visdial_vispro}``` in ```logs/vdgr/P1_K2_v1.0/code/config/vdgr.conf``` + * Run + ```shell + CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \ + --model vdgr/P1 \ + --mode eval \ + --eval_dir logs/vdgr/P1_K2_v1.0 \ + --wandb_mode offline \ + ``` +2. For inference on VisDial v0.9: + * Set ```eval_dataset = visdial``` in ```logs/vdgr/P1_K2_v0.9/code/config/vdgr.conf``` + * Run + ```shell + CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \ + --model vdgr/P1 \ + --mode eval \ + --eval_dir logs/vdgr/P1_K2_v0.9 \ + --wandb_mode offline \ + ``` +⚠️ This might take some time to finish, as the test data of VisDial v0.9 is large. + +3. For inference on the VisDial v1.0 ```test``` split: + * Run + ```shell + CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \ + --model vdgr/P1 \ + --mode predict \ + --eval_dir logs/vdgr/P1_K2_v1.0 \ + --wandb_mode offline \ + ``` + * The output file will be saved in ```output/``` + +## Phase 2 +In this phase, we fine-tune on dense annotations to improve the NDCG score (only supported for VisDial v1.0). +1. Run +```shell +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \ +--model vdgr/P2_CE \ +--mode train \ +--tag K2_v1.0_CE \ +--wandb_mode online \ +--wandb_project your_wandb_project_name +``` +⚠️ This will take roughly 3-4 hours to complete using the same setup as before and [DP][4] for training. + +2. For inference on VisDial v1.0: + * Run: + ```shell + CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \ + --model vdgr/P2_CE \ + --mode predict \ + --eval_dir logs/vdgr/P1_K2_v1.0_CE \ + --wandb_mode offline \ + ``` + * The output file will be saved in ```output/``` + +## Phase 3 +### Training +In the final phase, we train an ensemble comprising 8 models using ```K={1,2,3,4}``` and ```dense_loss={ce, listnet}```. +For ```K=k```: +1. Set the value of ```num_v_gnn_layers, num_q_gnn_layers, num_h_gnn_layers``` to ```k``` +2. Set ```start_path = ckpt/vdgr_visdial_v1.0_after_warmup_K[k].ckpt``` in ```config/vdgr.conf``` (P1) +3. Phase 1 training: +```shell +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \ +--model vdgr/P1 \ +--mode train \ +--tag K[k]_v1.0 \ +--wandb_mode online \ +--wandb_project your_wandb_project_name +``` +4. Set ```start_path = logs/vdgr/P1_K[k]_v1.0/epoch_best.ckpt``` in ```config/vdgr.conf``` (P2) +5.
Phase 2 training: +* Fine-tune with CE: + ```shell + CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \ + --model vdgr/P2_CE \ + --mode train \ + --tag K[k]_v1.0_CE \ + --wandb_mode online \ + --wandb_project your_wandb_project_name +``` +* Fine-tune with LISTNET: + ```shell + CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \ + --model vdgr/P2_LISTNET \ + --mode train \ + --tag K[k]_v1.0_LISTNET \ + --wandb_mode online \ + --wandb_project your_wandb_project_name +``` +### Inference +1. For inference on VisDial v1.0 test: +```shell +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \ +--model vdgr/P2_[CE,LISTNET] \ +--mode predict \ +--eval_dir logs/vdgr/P2_K[1,2,3,4]_v1.0_[CE,LISTNET] \ +--wandb_mode offline \ +``` +2. Finally, merge the outputs of all models +```shell + python ensemble.py \ +--exp test \ +--mode predict \ +``` +The output file will be saved in ```output/``` + +# Results +## VisDial v0.9 +| Model | MRR | R@1 | R@5 | R@10 | Mean | +|:--------:|:---:|:---:|:---:|:----:|:----:| +| Prev. SOTA | 71.99 | 59.41 | 87.92 | 94.59 | 2.87 | +| VD-GR | **74.50** | **62.10** | **90.49** | **96.37** | **2.45** | + +## VisDialConv +| Model | NDCG | MRR | R@1 | R@5 | R@10 | Mean | +|:--------:|:----:|:---:|:---:|:---:|:----:|:----:| +| Prev. SOTA | 61.72 | 61.79 | 48.95 | 77.50 | 86.71 | 4.72 | +| VD-GR | **67.09** | **66.82** | **54.47** | **81.71** | **91.44** | **3.54** | + +## VisPro +| Model | NDCG | MRR | R@1 | R@5 | R@10 | Mean | +|:--------:|:----:|:---:|:---:|:---:|:----:|:----:| +| Prev. SOTA | 59.30 | 62.29 | 48.35 | 80.10 | 88.87 | 4.37 | +| VD-GR | **60.35** | **69.89** | **57.21** | **85.97** | **92.68** | **3.15** | + +## VisDial V1.0 Val +| Model | NDCG | MRR | R@1 | R@5 | R@10 | Mean | +|:--------:|:----:|:---:|:---:|:---:|:----:|:----:| +| Prev. SOTA | 65.47 | 69.71 | 56.79 | 85.82 | 93.64 | 3.15 | +| VD-GR | 64.32 | **69.91** | **57.01** | **86.14** | **93.74** | **3.13** | + +## VisDial V1.0 Test +| Model | NDCG | MRR | R@1 | R@5 | R@10 | Mean | +|:--------:|:----:|:---:|:---:|:---:|:----:|:----:| +| Prev. SOTA | 64.91 | 68.73 | 55.73 | 85.38 | 93.53 | 3.21 | +| VD-GR | 63.49 | 68.65 | 55.33 | **85.58** | **93.85** | **3.20** | +| ♣️ Prev. SOTA | 75.92 | 56.18 | 45.32 | 68.05 | 80.98 | 5.42 | +| ♣️ VD-GR | **75.95** | **58.30** | **46.55** | **71.45** | 84.52 | **5.32** | +| ♣️♦️ Prev. 
SOTA | 76.17 | 56.42 | 44.75 | 70.23 | 84.52 | 5.47 | +| ♣️♦️ VD-GR | **76.43** | 56.35 | **45.18** | 68.13 | 82.18 | 5.79 | + +♣️ = Finetuning on dense annotations, ♦️ = Ensemble model + + +[1]: https://git-lfs.com/ +[2]: https://github.com/NVIDIA/apex +[3]: https://visualdialog.org/ +[4]: https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html +[5]: https://adnenabdessaied.de +[6]: https://www.perceptualui.org/people/shi/ +[7]: https://www.perceptualui.org/people/bulling/ +[8]: https://drive.google.com/file/d/1GT0WDinA_z5FdwVc_bWtyB-cwQkGIf7C/view?usp=sharing diff --git a/ckpt/.gitkeep b/ckpt/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/config/bert_base_6layer_6conect.json b/config/bert_base_6layer_6conect.json new file mode 100644 index 0000000..f8e3563 --- /dev/null +++ b/config/bert_base_6layer_6conect.json @@ -0,0 +1,40 @@ +{ + "attention_probs_dropout_prob": 0.1, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "max_position_embeddings": 512, + "num_attention_heads": 12, + "num_hidden_layers": 12, + "type_vocab_size": 2, + "vocab_size": 30522, + "v_feature_size": 2048, + "v_target_size": 1601, + "v_hidden_size": 1024, + "v_num_hidden_layers": 6, + "v_num_attention_heads": 8, + "v_intermediate_size": 1024, + "bi_hidden_size": 1024, + "bi_num_attention_heads": 8, + "bi_intermediate_size": 1024, + "bi_attention_type": 1, + "v_attention_probs_dropout_prob": 0.1, + "v_hidden_act": "gelu", + "v_hidden_dropout_prob": 0.1, + "v_initializer_range": 0.02, + "pooling_method": "mul", + "v_biattention_id": [0, 1, 2, 3, 4, 5], + "t_biattention_id": [6, 7, 8, 9, 10, 11], + "gnn_act": "gelu", + "num_v_gnn_layers": 2, + "num_q_gnn_layers": 2, + "num_h_gnn_layers": 2, + "num_gnn_attention_heads": 4, + "gnn_dropout_prob": 0.1, + "v_gnn_edge_dim": 12, + "q_gnn_edge_dim": 48, + "v_gnn_ids": [0, 1, 2, 3, 4, 5], + "t_gnn_ids": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] +} \ No newline at end of file diff --git a/config/ensemble.conf b/config/ensemble.conf new file mode 100644 index 0000000..f3844c2 --- /dev/null +++ b/config/ensemble.conf @@ -0,0 +1,33 @@ +test = { + split = test + skip_mrr_eval = true + # data + visdial_test_data = data/visdial_1.0_test.json + + # directory + log_dir = logs/vdgr_ensemble + pred_dir = logs/vdgr + visdial_output_dir = visdial_output + + processed = [ + false, + false, + false, + false, + false, + false, + false, + false + ] + + models = [ + "P2_K1_v1.0_CE", + "P2_K2_v1.0_CE", + "P2_K3_v1.0_CE", + "P2_K4_v1.0_CE", + "P2_K1_v1.0_LISTNET", + "P2_K2_v1.0_LISTNET", + "P2_K3_v1.0_LISTNET", + "P2_K4_v1.0_LISTNET", + ] +} diff --git a/config/vdgr.conf b/config/vdgr.conf new file mode 100644 index 0000000..6ef305c --- /dev/null +++ b/config/vdgr.conf @@ -0,0 +1,188 @@ +# Phase 1 +P1 { + use_cpu = false + visdial_version = 1.0 + train_on_dense = false + metrics_to_maximize = mrr + + # visdial data + visdial_image_feats = data/visdial_img_feat.lmdb + + visdial_image_adj_matrices = data/img_adj_matrices + visdial_question_adj_matrices = data/question_adj_matrices + visdial_history_adj_matrices = data/history_adj_matrices + + visdial_train = data/visdial_1.0_train.json + visdial_val = data/visdial_1.0_val.json + visdial_test = data/visdial_1.0_test.json + visdial_val_dense_annotations = data/visdial_1.0_val_dense_annotations.json + + visdial_train_09 = data/visdial_0.9_train.json + visdial_val_09 = data/visdial_0.9_val.json + visdial_test_09 = 
data/visdial_0.9_test.json + + visdialconv_val = data/visdial_conv.json + visdialconv_val_dense_annotations = data/visdialconv_dense_annotations.json + + visdialvispro_val = data/vispro.json + visdialvispro_val_dense_annotations = data/vispro_dense_annotations.json + + visdial_question_parse_vocab = data/parse_vocab.pkl + + # init + start_path = ckpt/vdgr_visdial_v1.0_after_warmup_K2.ckpt + model_config = config/bert_base_6layer_6conect.json + + # visdial training + freeze_vilbert = false + visdial_tot_rounds = 11 + num_negative_samples = 1 + sequences_per_image = 2 + batch_size = 8 + lm_loss_coeff = 1 + nsp_loss_coeff = 1 + img_loss_coeff = 1 + batch_multiply = 1 + use_trainval = false + dense_loss = ce + dense_loss_coeff = 0 + dataloader_text_only = false + rlv_hst_only = false + rlv_hst_dense_round = false + + # visdial model + mask_prob = 0.1 + image_mask_prob = 0.1 + max_seq_len = 256 + num_options = 100 + num_options_dense = 100 + use_embedding = joint + + # visdial evaluation + eval_visdial_on_test = true + eval_batch_size = 1 + eval_line_batch_size = 200 + skip_mrr_eval = false + skip_ndcg_eval = false + skip_visdial_eval = false + eval_visdial_every = 1 + eval_dataset = visdial # visdial_vispro # choices = [visdial, visdial_conv, visdial_vispro ] + + continue_evaluation = false + eval_at_start = false + eval_before_training = false + initializer = normal + bert_cased = false + + # restore ckpt + loads_best_ckpt = false + loads_ckpt = false + restarts = false + resets_max_metric = false + uses_new_optimizer = false + sets_new_lr = false + loads_start_path = false + + # logging + random_seed = 42 + next_logging_pct = 1.0 + next_evaluating_pct = 50.0 + max_ckpt_to_keep = 1 + num_epochs = 20 + early_stop_epoch = 5 + skip_saving_ckpt = false + dp_type = apex + stack_gr_data = false + master_port = 5122 + stop_epochs = -1 + train_each_round = false + drop_last_answer = false + num_samples = -1 + + # predicting + predict_split = test + predict_each_round = false + predict_dense_round = false + num_test_dialogs = 8000 + num_val_dialogs = 2064 + save_score = false + + # optimizer + reset_optim = none + learning_rate_bert = 5e-6 + learning_rate_gnn = 2e-4 + gnn_weight_decay = 0.01 + use_diff_lr_gnn = true + min_lr = 0 + decay_method_bert = linear + decay_method_gnn = linear + decay_exp = 2 + max_grad_norm = 1.0 + task_optimizer = adam + warmup_ratio = 0.1 + + # directory + log_dir = logs/vdgr + data_dir = data + visdial_output_dir = visdial_output + bert_cache_dir = transformers + + # keep track of other hparams in bert json + v_gnn_edge_dim = 12 # 11 classes + hub_node + q_gnn_edge_dim = 48 # 47 classes + hub_node + num_v_gnn_layers = 2 + num_q_gnn_layers = 2 + num_h_gnn_layers = 2 + num_gnn_attention_heads = 4 + v_gnn_ids = [0, 1, 2, 3, 4, 5] + t_gnn_ids = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] +} + +# Phase 2 +P2_CE = ${P1} { + # basic + train_on_dense = true + use_trainval = true + metrics_to_maximize = ndcg + + visdial_train_dense = data/visdial_1.0_train_dense.json + visdial_train_dense_annotations = data/visdial_1.0_train_dense_annotations.json + visdial_val_dense = data/visdial_1.0_val.json + + tr_graph_idx_mapping = data/tr_dense_mapping.json + val_graph_idx_mapping = data/val_dense_mapping.json + test_graph_idx_mapping = data/test_dense_mapping.json + + visdial_val = data/visdial_1.0_val.json + visdial_val_dense_annotations = data/visdial_1.0_val_dense_annotations.json + + # data + start_path = logs/vdgr/P1_K2_v1.0/epoch_best.ckpt + rlv_hst_only = false + + # visdial training + 
nsp_loss_coeff = 0 + dense_loss_coeff = 1 + batch_multiply = 10 + batch_size = 1 + + # visdial model + num_options_dense = 100 + + # visdial evaluation + eval_batch_size = 1 + eval_line_batch_size = 100 + skip_mrr_eval = true + + # training + stop_epochs = 3 + dp_type = dp + dense_loss = ce + + # optimizer + learning_rate_bert = 1e-4 +} + +P2_LISTNET = ${P2_CE} { + dense_loss = listnet +} diff --git a/data/.gitkeep b/data/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/dataloader/__init__.py b/dataloader/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dataloader/dataloader_base.py b/dataloader/dataloader_base.py new file mode 100644 index 0000000..41c78da --- /dev/null +++ b/dataloader/dataloader_base.py @@ -0,0 +1,269 @@ +import torch +from torch.utils import data +import json +import os +import glog as log +import pickle + +import torch.utils.data as tud +from pytorch_transformers.tokenization_bert import BertTokenizer + +from utils.image_features_reader import ImageFeaturesH5Reader + + +class DatasetBase(data.Dataset): + + def __init__(self, config): + + if config['display']: + log.info('Initializing dataset') + + # Fetch the correct dataset for evaluation + if config['validating']: + assert config.eval_dataset in ['visdial', 'visdial_conv', 'visdial_vispro', 'visdial_v09'] + if config.eval_dataset == 'visdial_conv': + config['visdial_val'] = config.visdialconv_val + config['visdial_val_dense_annotations'] = config.visdialconv_val_dense_annotations + elif config.eval_dataset == 'visdial_vispro': + config['visdial_val'] = config.visdialvispro_val + config['visdial_val_dense_annotations'] = config.visdialvispro_val_dense_annotations + elif config.eval_dataset == 'visdial_v09': + config['visdial_val_09'] = config.visdial_test_09 + config['visdial_val_dense_annotations'] = None + + self.config = config + self.numDataPoints = {} + + if not config['dataloader_text_only']: + self._image_features_reader = ImageFeaturesH5Reader( + config['visdial_image_feats'], + config['visdial_image_adj_matrices'] + ) + + if self.config['training'] or self.config['validating'] or self.config['predicting']: + split2data = {'train': 'train', 'val': 'val', 'test': 'test'} + elif self.config['debugging']: + split2data = {'train': 'val', 'val': 'val', 'test': 'test'} + elif self.config['visualizing']: + split2data = {'train': 'train', 'val': 'train', 'test': 'test'} + + filename = f'visdial_{split2data["train"]}' + if config['train_on_dense']: + filename += '_dense' + if self.config['visdial_version'] == 0.9: + filename += '_09' + + with open(config[filename]) as f: + self.visdial_data_train = json.load(f) + if self.config.num_samples > 0: + self.visdial_data_train['data']['dialogs'] = self.visdial_data_train['data']['dialogs'][:self.config.num_samples] + self.numDataPoints['train'] = len(self.visdial_data_train['data']['dialogs']) + + filename = f'visdial_{split2data["val"]}' + if config['train_on_dense'] and config['training']: + filename += '_dense' + if self.config['visdial_version'] == 0.9: + filename += '_09' + + with open(config[filename]) as f: + self.visdial_data_val = json.load(f) + if self.config.num_samples > 0: + self.visdial_data_val['data']['dialogs'] = self.visdial_data_val['data']['dialogs'][:self.config.num_samples] + self.numDataPoints['val'] = len(self.visdial_data_val['data']['dialogs']) + + if config['train_on_dense']: + self.numDataPoints['trainval'] = self.numDataPoints['train'] + self.numDataPoints['val'] + with 
open(config[f'visdial_{split2data["test"]}']) as f: + self.visdial_data_test = json.load(f) + self.numDataPoints['test'] = len(self.visdial_data_test['data']['dialogs']) + + self.rlv_hst_train = None + self.rlv_hst_val = None + self.rlv_hst_test = None + + if config['train_on_dense'] or config['predict_dense_round']: + with open(config[f'visdial_{split2data["train"]}_dense_annotations']) as f: + self.visdial_data_train_dense = json.load(f) + if config['train_on_dense']: + self.subsets = ['train', 'val', 'trainval', 'test'] + else: + self.subsets = ['train','val','test'] + self.num_options = config["num_options"] + self.num_options_dense = config["num_options_dense"] + if config['visdial_version'] != 0.9: + with open(config[f'visdial_{split2data["val"]}_dense_annotations']) as f: + self.visdial_data_val_dense = json.load(f) + else: + self.visdial_data_val_dense = None + self._split = 'train' + self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', cache_dir=config['bert_cache_dir']) + # fetching token indicecs of [CLS] and [SEP] + tokens = ['[CLS]','[MASK]','[SEP]'] + indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokens) + self.CLS = indexed_tokens[0] + self.MASK = indexed_tokens[1] + self.SEP = indexed_tokens[2] + self._max_region_num = 37 + self.predict_each_round = self.config['predicting'] and self.config['predict_each_round'] + + self.keys_to_expand = ['image_feat', 'image_loc', 'image_mask', 'image_target', 'image_label'] + self.keys_to_flatten_1d = ['hist_len', 'next_sentence_labels', 'round_id', 'image_id'] + self.keys_to_flatten_2d = ['tokens', 'segments', 'sep_indices', 'mask', 'image_mask', 'image_label', 'input_mask', 'question_limits'] + self.keys_to_flatten_3d = ['image_feat', 'image_loc', 'image_target', ] + self.keys_other = ['gt_relevance', 'gt_option_inds'] + self.keys_lists_to_flatten = ['image_edge_indices', 'image_edge_attributes', 'question_edge_indices', 'question_edge_attributes', 'history_edge_indices', 'history_sep_indices'] + if config['stack_gr_data']: + self.keys_to_flatten_3d.extend(self.keys_lists_to_flatten[:-1]) + self.keys_to_flatten_2d.append(self.keys_lists_to_flatten[-1]) + self.keys_to_flatten_1d.extend(['len_image_gr', 'len_question_gr', 'len_history_gr', 'len_history_sep']) + self.keys_lists_to_flatten = [] + + self.keys_to_list = ['tot_len'] + + # Load the parse vocab for question graph relationship mapping + if os.path.isfile(config['visdial_question_parse_vocab']): + with open(config['visdial_question_parse_vocab'], 'rb') as f: + self.parse_vocab = pickle.load(f) + + def __len__(self): + return self.numDataPoints[self._split] + + @property + def split(self): + return self._split + + @split.setter + def split(self, split): + assert split in self.subsets + self._split = split + + def tokens2str(self, seq): + dialog_sequence = '' + for sentence in seq: + for word in sentence: + dialog_sequence += self.tokenizer._convert_id_to_token(word) + " " + dialog_sequence += ' ' + dialog_sequence = dialog_sequence.encode('utf8') + return dialog_sequence + + def pruneRounds(self, context, num_rounds): + start_segment = 1 + len_context = len(context) + cur_rounds = (len(context) // 2) + 1 + l_index = 0 + if cur_rounds > num_rounds: + # caption is not part of the final input + l_index = len_context - (2 * num_rounds) + start_segment = 0 + return context[l_index:], start_segment + + def tokenize_utterance(self, sent, sentences, tot_len, sentence_count, sentence_map, speakers): + sentences.extend(sent + ['[SEP]']) + tokenized_sent = 
self.tokenizer.convert_tokens_to_ids(sent) + assert len(sent) == len(tokenized_sent), 'sub-word tokens are not allowed!' + + sent_len = len(tokenized_sent) + tot_len += sent_len + 1 # the additional 1 is for the sep token + sentence_count += 1 + sentence_map.extend([sentence_count * 2 - 1] * sent_len) + sentence_map.append(sentence_count * 2) # for [SEP] + speakers.extend([2] * (sent_len + 1)) + + return tokenized_sent, sentences, tot_len, sentence_count, sentence_map, speakers + + def __getitem__(self, index): + return NotImplementedError + + def collate_fn(self, batch): + tokens_size = batch[0]['tokens'].size() + num_rounds, num_samples = tokens_size[0], tokens_size[1] + merged_batch = {key: [d[key] for d in batch] for key in batch[0]} + + if self.config['stack_gr_data']: + if (len(batch)) > 1: + max_question_gr_len = max([length.max().item() for length in merged_batch['len_question_gr']]) + max_history_gr_len = max([length.max().item() for length in merged_batch['len_history_gr']]) + max_history_sep_len = max([length.max().item() for length in merged_batch['len_history_sep']]) + max_image_gr_len = max([length.max().item() for length in merged_batch['len_image_gr']]) + + question_edge_indices_padded = [] + question_edge_attributes_padded = [] + + for q_e_idx, q_e_attr in zip(merged_batch['question_edge_indices'], merged_batch['question_edge_attributes']): + b_size, edge_dim, orig_len = q_e_idx.size() + q_e_idx_padded = torch.zeros((b_size, edge_dim, max_question_gr_len)) + q_e_idx_padded[:, :, :orig_len] = q_e_idx + question_edge_indices_padded.append(q_e_idx_padded) + + edge_attr_dim = q_e_attr.size(-1) + q_e_attr_padded = torch.zeros((b_size, max_question_gr_len, edge_attr_dim)) + q_e_attr_padded[:, :orig_len, :] = q_e_attr + question_edge_attributes_padded.append(q_e_attr_padded) + + merged_batch['question_edge_indices'] = question_edge_indices_padded + merged_batch['question_edge_attributes'] = question_edge_attributes_padded + + history_edge_indices_padded = [] + for h_e_idx in merged_batch['history_edge_indices']: + b_size, _, orig_len = h_e_idx.size() + h_edge_idx_padded = torch.zeros((b_size, 2, max_history_gr_len)) + h_edge_idx_padded[:, :, :orig_len] = h_e_idx + history_edge_indices_padded.append(h_edge_idx_padded) + merged_batch['history_edge_indices'] = history_edge_indices_padded + + history_sep_indices_padded = [] + for hist_sep_idx in merged_batch['history_sep_indices']: + b_size, orig_len = hist_sep_idx.size() + hist_sep_idx_padded = torch.zeros((b_size, max_history_sep_len)) + hist_sep_idx_padded[:, :orig_len] = hist_sep_idx + history_sep_indices_padded.append(hist_sep_idx_padded) + merged_batch['history_sep_indices'] = history_sep_indices_padded + + image_edge_indices_padded = [] + image_edge_attributes_padded = [] + for img_e_idx, img_e_attr in zip(merged_batch['image_edge_indices'], merged_batch['image_edge_attributes']): + b_size, edge_dim, orig_len = img_e_idx.size() + img_e_idx_padded = torch.zeros((b_size, edge_dim, max_image_gr_len)) + img_e_idx_padded[:, :, :orig_len] = img_e_idx + image_edge_indices_padded.append(img_e_idx_padded) + + edge_attr_dim = img_e_attr.size(-1) + img_e_attr_padded = torch.zeros((b_size, max_image_gr_len, edge_attr_dim)) + img_e_attr_padded[:, :orig_len, :] = img_e_attr + image_edge_attributes_padded.append(img_e_attr_padded) + + merged_batch['image_edge_indices'] = image_edge_indices_padded + merged_batch['image_edge_attributes'] = image_edge_attributes_padded + + out = {} + for key in merged_batch: + if key in 
self.keys_lists_to_flatten: + temp = [] + for b in merged_batch[key]: + for x in b: + temp.append(x) + merged_batch[key] = temp + + elif key in self.keys_to_list: + pass + else: + merged_batch[key] = torch.stack(merged_batch[key], 0) + if key in self.keys_to_expand: + if len(merged_batch[key].size()) == 3: + size0, size1, size2 = merged_batch[key].size() + expand_size = (size0, num_rounds, num_samples, size1, size2) + elif len(merged_batch[key].size()) == 2: + size0, size1 = merged_batch[key].size() + expand_size = (size0, num_rounds, num_samples, size1) + merged_batch[key] = merged_batch[key].unsqueeze(1).unsqueeze(1).expand(expand_size).contiguous() + if key in self.keys_to_flatten_1d: + merged_batch[key] = merged_batch[key].reshape(-1) + elif key in self.keys_to_flatten_2d: + merged_batch[key] = merged_batch[key].reshape(-1, merged_batch[key].shape[-1]) + elif key in self.keys_to_flatten_3d: + merged_batch[key] = merged_batch[key].reshape(-1, merged_batch[key].shape[-2], merged_batch[key].shape[-1]) + else: + assert key in self.keys_other, f'unrecognized key in collate_fn: {key}' + + out[key] = merged_batch[key] + return out diff --git a/dataloader/dataloader_visdial.py b/dataloader/dataloader_visdial.py new file mode 100644 index 0000000..4c13035 --- /dev/null +++ b/dataloader/dataloader_visdial.py @@ -0,0 +1,615 @@ +import torch +import os +import numpy as np +import random +import pickle + +import sys +sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) + +from utils.data_utils import encode_input, encode_input_with_mask, encode_image_input +from dataloader.dataloader_base import DatasetBase + + +class VisdialDataset(DatasetBase): + + def __init__(self, config): + super(VisdialDataset, self).__init__(config) + + def __getitem__(self, index): + MAX_SEQ_LEN = self.config['max_seq_len'] + cur_data = None + if self._split == 'train': + cur_data = self.visdial_data_train['data'] + ques_adj_matrices_dir = os.path.join(self.config['visdial_question_adj_matrices'], 'train') + hist_adj_matrices_dir = os.path.join(self.config['visdial_history_adj_matrices'], 'train') + + elif self._split == 'val': + cur_data = self.visdial_data_val['data'] + ques_adj_matrices_dir = os.path.join(self.config['visdial_question_adj_matrices'], 'val') + hist_adj_matrices_dir = os.path.join(self.config['visdial_history_adj_matrices'], 'val') + + else: + cur_data = self.visdial_data_test['data'] + ques_adj_matrices_dir = os.path.join(self.config['visdial_question_adj_matrices'], 'test') + hist_adj_matrices_dir = os.path.join(self.config['visdial_history_adj_matrices'], 'test') + + if self.config['visdial_version'] == 0.9: + ques_adj_matrices_dir = os.path.join(self.config['visdial_question_adj_matrices'], 'train') + hist_adj_matrices_dir = os.path.join(self.config['visdial_history_adj_matrices'], 'train') + + self.num_bad_samples = 0 + # number of options to score on + num_options = self.num_options + assert num_options > 1 and num_options <= 100 + num_dialog_rounds = 10 + + dialog = cur_data['dialogs'][index] + cur_questions = cur_data['questions'] + cur_answers = cur_data['answers'] + img_id = dialog['image_id'] + graph_idx = dialog.get('dialog_idx', index) + + if self._split == 'train': + # caption + sent = dialog['caption'].split(' ') + sentences = ['[CLS]'] + tot_len = 1 # for the CLS token + sentence_map = [0] # for the CLS token + sentence_count = 0 + speakers = [0] + + tokenized_sent, sentences, tot_len, sentence_count, sentence_map, speakers = \ + self.tokenize_utterance(sent, sentences, tot_len, 
sentence_count, sentence_map, speakers) + + utterances = [[tokenized_sent]] + utterances_random = [[tokenized_sent]] + + for rnd, utterance in enumerate(dialog['dialog']): + cur_rnd_utterance = utterances[-1].copy() + cur_rnd_utterance_random = utterances[-1].copy() + + # question + sent = cur_questions[utterance['question']].split(' ') + tokenized_sent, sentences, tot_len, sentence_count, sentence_map, speakers = \ + self.tokenize_utterance(sent, sentences, tot_len, sentence_count, sentence_map, speakers) + + cur_rnd_utterance.append(tokenized_sent) + cur_rnd_utterance_random.append(tokenized_sent) + + # answer + sent = cur_answers[utterance['answer']].split(' ') + tokenized_sent, sentences, tot_len, sentence_count, sentence_map, speakers = \ + self.tokenize_utterance(sent, sentences, tot_len, sentence_count, sentence_map, speakers) + cur_rnd_utterance.append(tokenized_sent) + + utterances.append(cur_rnd_utterance) + + # randomly select one random utterance in that round + num_inds = len(utterance['answer_options']) + gt_option_ind = utterance['gt_index'] + + negative_samples = [] + + for _ in range(self.config["num_negative_samples"]): + + all_inds = list(range(100)) + all_inds.remove(gt_option_ind) + all_inds = all_inds[:(num_options-1)] + tokenized_random_utterance = None + option_ind = None + + while len(all_inds): + option_ind = random.choice(all_inds) + tokenized_random_utterance = self.tokenizer.convert_tokens_to_ids(cur_answers[utterance['answer_options'][option_ind]].split(' ')) + # the 1 here is for the sep token at the end of each utterance + if(MAX_SEQ_LEN >= (tot_len + len(tokenized_random_utterance) + 1)): + break + else: + all_inds.remove(option_ind) + if len(all_inds) == 0: + # all the options exceed the max len. Truncate the last utterance in this case. 
+ tokenized_random_utterance = tokenized_random_utterance[:len(tokenized_sent)] + t = cur_rnd_utterance_random.copy() + t.append(tokenized_random_utterance) + negative_samples.append(t) + + utterances_random.append(negative_samples) + + # removing the caption in the beginning + utterances = utterances[1:] + utterances_random = utterances_random[1:] + assert len(utterances) == len(utterances_random) == num_dialog_rounds + assert tot_len <= MAX_SEQ_LEN, '{} {} tot_len = {} > max_seq_len'.format( + self._split, index, tot_len + ) + + tokens_all = [] + question_limits_all = [] + question_edge_indices_all = [] + question_edge_attributes_all = [] + history_edge_indices_all = [] + history_sep_indices_all = [] + mask_all = [] + segments_all = [] + sep_indices_all = [] + next_labels_all = [] + hist_len_all = [] + + # randomly pick several rounds to train + pos_rounds = sorted(random.sample(range(num_dialog_rounds), self.config['sequences_per_image'] // 2), reverse=True) + neg_rounds = sorted(random.sample(range(num_dialog_rounds), self.config['sequences_per_image'] // 2), reverse=True) + + tokens_all_rnd = [] + question_limits_all_rnd = [] + mask_all_rnd = [] + segments_all_rnd = [] + sep_indices_all_rnd = [] + next_labels_all_rnd = [] + hist_len_all_rnd = [] + + for j in pos_rounds: + context = utterances[j] + context, start_segment = self.pruneRounds(context, self.config['visdial_tot_rounds']) + if j == pos_rounds[0]: # dialog with positive label and max rounds + tokens, segments, sep_indices, mask, input_mask, start_question, end_question = encode_input_with_mask(context, start_segment, self.CLS, + self.SEP, self.MASK, max_seq_len=MAX_SEQ_LEN, mask_prob=self.config["mask_prob"]) + else: + tokens, segments, sep_indices, mask, start_question, end_question = encode_input(context, start_segment, self.CLS, + self.SEP, self.MASK, max_seq_len=MAX_SEQ_LEN, mask_prob=self.config["mask_prob"]) + tokens_all_rnd.append(tokens) + question_limits_all_rnd.append(torch.tensor([start_question, end_question])) + mask_all_rnd.append(mask) + sep_indices_all_rnd.append(sep_indices) + next_labels_all_rnd.append(torch.LongTensor([0])) + segments_all_rnd.append(segments) + hist_len_all_rnd.append(torch.LongTensor([len(context)-1])) + + tokens_all.append(torch.cat(tokens_all_rnd,0).unsqueeze(0)) + mask_all.append(torch.cat(mask_all_rnd,0).unsqueeze(0)) + question_limits_all.extend(question_limits_all_rnd) + segments_all.append(torch.cat(segments_all_rnd, 0).unsqueeze(0)) + sep_indices_all.append(torch.cat(sep_indices_all_rnd, 0).unsqueeze(0)) + next_labels_all.append(torch.cat(next_labels_all_rnd, 0).unsqueeze(0)) + hist_len_all.append(torch.cat(hist_len_all_rnd,0).unsqueeze(0)) + + assert len(pos_rounds) == 1 + question_graphs = pickle.load( + open(os.path.join(ques_adj_matrices_dir, f'{graph_idx}.pkl'), 'rb') + ) + + question_graph_pos = question_graphs[pos_rounds[0]] + question_edge_index_pos = [] + question_edge_attribute_pos = [] + for edge_idx, edge_attr in question_graph_pos: + question_edge_index_pos.append(edge_idx) + edge_attr_one_hot = np.zeros((len(self.parse_vocab) + 1,), dtype=np.float32) + edge_attr_one_hot[self.parse_vocab.get(edge_attr, len(self.parse_vocab))] = 1.0 + question_edge_attribute_pos.append(edge_attr_one_hot) + + question_edge_index_pos = np.array(question_edge_index_pos, dtype=np.float64) + question_edge_attribute_pos = np.stack(question_edge_attribute_pos, axis=0) + + question_edge_indices_all.append( + torch.from_numpy(question_edge_index_pos).t().long().contiguous() + ) + + 
question_edge_attributes_all.append( + torch.from_numpy(question_edge_attribute_pos) + ) + + history_edge_indices = pickle.load( + open(os.path.join(hist_adj_matrices_dir, f'{graph_idx}.pkl'), 'rb') + ) + + history_edge_indices_all.append( + torch.tensor(history_edge_indices[pos_rounds[0]]).t().long().contiguous() + ) + # Get the [SEP] tokens that will represent the history graph node features + hist_idx_pos = [i * 2 for i in range(pos_rounds[0] + 1)] + sep_indices = sep_indices.squeeze(0).numpy() + history_sep_indices_all.append(torch.from_numpy(sep_indices[hist_idx_pos])) + + if len(neg_rounds) > 0: + tokens_all_rnd = [] + question_limits_all_rnd = [] + mask_all_rnd = [] + segments_all_rnd = [] + sep_indices_all_rnd = [] + next_labels_all_rnd = [] + hist_len_all_rnd = [] + + for j in neg_rounds: + + negative_samples = utterances_random[j] + for context_random in negative_samples: + context_random, start_segment = self.pruneRounds(context_random, self.config['visdial_tot_rounds']) + tokens_random, segments_random, sep_indices_random, mask_random, start_question, end_question = encode_input(context_random, start_segment, self.CLS, + self.SEP, self.MASK, max_seq_len=MAX_SEQ_LEN, mask_prob=self.config["mask_prob"]) + tokens_all_rnd.append(tokens_random) + question_limits_all_rnd.append(torch.tensor([start_question, end_question])) + mask_all_rnd.append(mask_random) + sep_indices_all_rnd.append(sep_indices_random) + next_labels_all_rnd.append(torch.LongTensor([1])) + segments_all_rnd.append(segments_random) + hist_len_all_rnd.append(torch.LongTensor([len(context_random)-1])) + + tokens_all.append(torch.cat(tokens_all_rnd,0).unsqueeze(0)) + mask_all.append(torch.cat(mask_all_rnd,0).unsqueeze(0)) + question_limits_all.extend(question_limits_all_rnd) + segments_all.append(torch.cat(segments_all_rnd, 0).unsqueeze(0)) + sep_indices_all.append(torch.cat(sep_indices_all_rnd, 0).unsqueeze(0)) + next_labels_all.append(torch.cat(next_labels_all_rnd, 0).unsqueeze(0)) + hist_len_all.append(torch.cat(hist_len_all_rnd,0).unsqueeze(0)) + + assert len(neg_rounds) == 1 + + question_graph_neg = question_graphs[neg_rounds[0]] + question_edge_index_neg = [] + question_edge_attribute_neg = [] + for edge_idx, edge_attr in question_graph_neg: + question_edge_index_neg.append(edge_idx) + edge_attr_one_hot = np.zeros((len(self.parse_vocab) + 1,), dtype=np.float32) + edge_attr_one_hot[self.parse_vocab.get(edge_attr, len(self.parse_vocab))] = 1.0 + question_edge_attribute_neg.append(edge_attr_one_hot) + + question_edge_index_neg = np.array(question_edge_index_neg, dtype=np.float64) + question_edge_attribute_neg = np.stack(question_edge_attribute_neg, axis=0) + + question_edge_indices_all.append( + torch.from_numpy(question_edge_index_neg).t().long().contiguous() + ) + + question_edge_attributes_all.append( + torch.from_numpy(question_edge_attribute_neg) + ) + + history_edge_indices_all.append( + torch.tensor(history_edge_indices[neg_rounds[0]]).t().long().contiguous() + ) + + # Get the [SEP] tokens that will represent the history graph node features + hist_idx_neg = [i * 2 for i in range(neg_rounds[0] + 1)] + sep_indices_random = sep_indices_random.squeeze(0).numpy() + history_sep_indices_all.append(torch.from_numpy(sep_indices_random[hist_idx_neg])) + + tokens_all = torch.cat(tokens_all, 0) # [2, num_pos, max_len] + question_limits_all = torch.stack(question_limits_all, 0) # [2, 2] + mask_all = torch.cat(mask_all,0) + segments_all = torch.cat(segments_all, 0) + sep_indices_all = torch.cat(sep_indices_all, 0) + 
next_labels_all = torch.cat(next_labels_all, 0) + hist_len_all = torch.cat(hist_len_all, 0) + input_mask_all = torch.LongTensor(input_mask) # [max_len] + + item = {} + + item['tokens'] = tokens_all + item['question_limits'] = question_limits_all + item['question_edge_indices'] = question_edge_indices_all + item['question_edge_attributes'] = question_edge_attributes_all + + item['history_edge_indices'] = history_edge_indices_all + item['history_sep_indices'] = history_sep_indices_all + item['segments'] = segments_all + item['sep_indices'] = sep_indices_all + item['mask'] = mask_all + item['next_sentence_labels'] = next_labels_all + item['hist_len'] = hist_len_all + item['input_mask'] = input_mask_all + + # get image features + if not self.config['dataloader_text_only']: + features, num_boxes, boxes, _ , image_target, image_edge_indexes, image_edge_attributes = self._image_features_reader[img_id] + features, spatials, image_mask, image_target, image_label = encode_image_input(features, num_boxes, boxes, image_target, max_regions=self._max_region_num) + else: + features = spatials = image_mask = image_target = image_label = torch.tensor([0]) + + elif self._split == 'val': + gt_relevance = None + gt_option_inds = [] + options_all = [] + + # caption + sent = dialog['caption'].split(' ') + sentences = ['[CLS]'] + tot_len = 1 # for the CLS token + sentence_map = [0] # for the CLS token + sentence_count = 0 + speakers = [0] + + tokenized_sent, sentences, tot_len, sentence_count, sentence_map, speakers = \ + self.tokenize_utterance(sent, sentences, tot_len, sentence_count, sentence_map, speakers) + utterances = [[tokenized_sent]] + + for rnd, utterance in enumerate(dialog['dialog']): + cur_rnd_utterance = utterances[-1].copy() + + # question + sent = cur_questions[utterance['question']].split(' ') + tokenized_sent, sentences, tot_len, sentence_count, sentence_map, speakers = \ + self.tokenize_utterance(sent, sentences, tot_len, sentence_count, sentence_map, speakers) + + cur_rnd_utterance.append(tokenized_sent) + + # current round + gt_option_ind = utterance['gt_index'] + # first select gt option id, then choose the first num_options inds + option_inds = [] + option_inds.append(gt_option_ind) + all_inds = list(range(100)) + all_inds.remove(gt_option_ind) + all_inds = all_inds[:(num_options-1)] + option_inds.extend(all_inds) + gt_option_inds.append(0) + cur_rnd_options = [] + answer_options = [utterance['answer_options'][k] for k in option_inds] + assert len(answer_options) == len(option_inds) == num_options + assert answer_options[0] == utterance['answer'] + + # for evaluation of all options and dense relevance + if self.visdial_data_val_dense: + if rnd == self.visdial_data_val_dense[index]['round_id'] - 1: + # only 1 round has gt_relevance for each example + if 'relevance' in self.visdial_data_val_dense[index]: + gt_relevance = torch.Tensor(self.visdial_data_val_dense[index]['relevance']) + else: + gt_relevance = torch.Tensor(self.visdial_data_val_dense[index]['gt_relevance']) + # shuffle based on new indices + gt_relevance = gt_relevance[torch.LongTensor(option_inds)] + else: + gt_relevance = -1 + + for answer_option in answer_options: + cur_rnd_cur_option = cur_rnd_utterance.copy() + cur_rnd_cur_option.append(self.tokenizer.convert_tokens_to_ids(cur_answers[answer_option].split(' '))) + cur_rnd_options.append(cur_rnd_cur_option) + + # answer + sent = cur_answers[utterance['answer']].split(' ') + tokenized_sent, sentences, tot_len, sentence_count, sentence_map, speakers = \ + 
self.tokenize_utterance(sent, sentences, tot_len, sentence_count, sentence_map, speakers) + cur_rnd_utterance.append(tokenized_sent) + + utterances.append(cur_rnd_utterance) + options_all.append(cur_rnd_options) + + # encode the input and create batch x 10 x 100 * max_len arrays (batch x num_rounds x num_options) + tokens_all = [] + question_limits_all = [] + mask_all = [] + segments_all = [] + sep_indices_all = [] + hist_len_all = [] + history_sep_indices_all = [] + + for rnd, cur_rnd_options in enumerate(options_all): + + tokens_all_rnd = [] + mask_all_rnd = [] + segments_all_rnd = [] + sep_indices_all_rnd = [] + hist_len_all_rnd = [] + + for j, cur_rnd_option in enumerate(cur_rnd_options): + + cur_rnd_option, start_segment = self.pruneRounds(cur_rnd_option, self.config['visdial_tot_rounds']) + if rnd == len(options_all) - 1 and j == 0: # gt dialog + tokens, segments, sep_indices, mask, input_mask, start_question, end_question = encode_input_with_mask(cur_rnd_option, start_segment, self.CLS, + self.SEP, self.MASK, max_seq_len=MAX_SEQ_LEN, mask_prob=0) + else: + tokens, segments, sep_indices, mask, start_question, end_question = encode_input(cur_rnd_option, start_segment,self.CLS, + self.SEP, self.MASK ,max_seq_len=MAX_SEQ_LEN, mask_prob=0) + + tokens_all_rnd.append(tokens) + mask_all_rnd.append(mask) + segments_all_rnd.append(segments) + sep_indices_all_rnd.append(sep_indices) + hist_len_all_rnd.append(torch.LongTensor([len(cur_rnd_option)-1])) + + question_limits_all.append(torch.tensor([start_question, end_question]).unsqueeze(0).repeat(100, 1)) + tokens_all.append(torch.cat(tokens_all_rnd,0).unsqueeze(0)) + mask_all.append(torch.cat(mask_all_rnd,0).unsqueeze(0)) + segments_all.append(torch.cat(segments_all_rnd,0).unsqueeze(0)) + sep_indices_all.append(torch.cat(sep_indices_all_rnd,0).unsqueeze(0)) + hist_len_all.append(torch.cat(hist_len_all_rnd,0).unsqueeze(0)) + # Get the [SEP] tokens that will represent the history graph node features + # It will be the same for all answer candidates as the history does not change + # for each answer + hist_idx = [i * 2 for i in range(rnd + 1)] + history_sep_indices_all.extend(sep_indices.squeeze(0)[hist_idx].contiguous() for _ in range(100)) + + tokens_all = torch.cat(tokens_all, 0) # [10, 100, max_len] + mask_all = torch.cat(mask_all, 0) + segments_all = torch.cat(segments_all, 0) + sep_indices_all = torch.cat(sep_indices_all, 0) + hist_len_all = torch.cat(hist_len_all, 0) + input_mask_all = torch.LongTensor(input_mask) # [max_len] + + # load graph data + question_limits_all = torch.stack(question_limits_all, 0) # [10, 100, 2] + + question_graphs = pickle.load( + open(os.path.join(ques_adj_matrices_dir, f'{graph_idx}.pkl'), 'rb') + ) + question_edge_indices_all = [] # [10, N] we do not repeat it 100 times here + question_edge_attributes_all = [] # [10, N] we do not repeat it 100 times here + + for q_graph_round in question_graphs: + question_edge_index = [] + question_edge_attribute = [] + for edge_index, edge_attr in q_graph_round: + question_edge_index.append(edge_index) + edge_attr_one_hot = np.zeros((len(self.parse_vocab) + 1,), dtype=np.float32) + edge_attr_one_hot[self.parse_vocab.get(edge_attr, len(self.parse_vocab))] = 1.0 + question_edge_attribute.append(edge_attr_one_hot) + question_edge_index = np.array(question_edge_index, dtype=np.float64) + question_edge_attribute = np.stack(question_edge_attribute, axis=0) + + question_edge_indices_all.extend( + [torch.from_numpy(question_edge_index).t().long().contiguous() for _ in range(100)]) + 
question_edge_attributes_all.extend( + [torch.from_numpy(question_edge_attribute).contiguous() for _ in range(100)]) + + _history_edge_incides_all = pickle.load( + open(os.path.join(hist_adj_matrices_dir, f'{graph_idx}.pkl'), 'rb') + ) + history_edge_incides_all = [] + for hist_edge_indices_rnd in _history_edge_incides_all: + history_edge_incides_all.extend( + [torch.tensor(hist_edge_indices_rnd).t().long().contiguous() for _ in range(100)] + ) + + item = {} + item['tokens'] = tokens_all + item['segments'] = segments_all + item['sep_indices'] = sep_indices_all + item['mask'] = mask_all + item['hist_len'] = hist_len_all + item['input_mask'] = input_mask_all + + item['gt_option_inds'] = torch.LongTensor(gt_option_inds) + + # return dense annotation data as well + if self.visdial_data_val_dense: + item['round_id'] = torch.LongTensor([self.visdial_data_val_dense[index]['round_id']]) + item['gt_relevance'] = gt_relevance + + item['question_limits'] = question_limits_all + + item['question_edge_indices'] = question_edge_indices_all + item['question_edge_attributes'] = question_edge_attributes_all + + item['history_edge_indices'] = history_edge_incides_all + item['history_sep_indices'] = history_sep_indices_all + + # get image features + if not self.config['dataloader_text_only']: + features, num_boxes, boxes, _ , image_target, image_edge_indexes, image_edge_attributes = self._image_features_reader[img_id] + features, spatials, image_mask, image_target, image_label = encode_image_input( + features, num_boxes, boxes, image_target, max_regions=self._max_region_num, mask_prob=0) + else: + features = spatials = image_mask = image_target = image_label = torch.tensor([0]) + + elif self.split == 'test': + assert num_options == 100 + cur_rnd_utterance = [self.tokenizer.convert_tokens_to_ids(dialog['caption'].split(' '))] + options_all = [] + for rnd,utterance in enumerate(dialog['dialog']): + cur_rnd_utterance.append(self.tokenizer.convert_tokens_to_ids(cur_questions[utterance['question']].split(' '))) + if rnd != len(dialog['dialog'])-1: + cur_rnd_utterance.append(self.tokenizer.convert_tokens_to_ids(cur_answers[utterance['answer']].split(' '))) + for answer_option in dialog['dialog'][-1]['answer_options']: + cur_option = cur_rnd_utterance.copy() + cur_option.append(self.tokenizer.convert_tokens_to_ids(cur_answers[answer_option].split(' '))) + options_all.append(cur_option) + + tokens_all = [] + mask_all = [] + segments_all = [] + sep_indices_all = [] + hist_len_all = [] + + for j, option in enumerate(options_all): + option, start_segment = self.pruneRounds(option, self.config['visdial_tot_rounds']) + tokens, segments, sep_indices, mask = encode_input(option, start_segment ,self.CLS, + self.SEP, self.MASK ,max_seq_len=MAX_SEQ_LEN, mask_prob=0) + + tokens_all.append(tokens) + mask_all.append(mask) + segments_all.append(segments) + sep_indices_all.append(sep_indices) + hist_len_all.append(torch.LongTensor([len(option)-1])) + + tokens_all = torch.cat(tokens_all,0) + mask_all = torch.cat(mask_all,0) + segments_all = torch.cat(segments_all, 0) + sep_indices_all = torch.cat(sep_indices_all, 0) + hist_len_all = torch.cat(hist_len_all,0) + hist_idx = [i*2 for i in range(len(dialog['dialog']))] + history_sep_indices_all = [sep_indices.squeeze(0)[hist_idx].contiguous() for _ in range(num_options)] + + with open(os.path.join(ques_adj_matrices_dir, f'{graph_idx}.pkl'), 'rb') as f: + question_graphs = pickle.load(f) + q_graph_last = question_graphs[-1] + question_edge_index = [] + question_edge_attribute = [] + for 
edge_index, edge_attr in q_graph_last: + question_edge_index.append(edge_index) + edge_attr_one_hot = np.zeros((len(self.parse_vocab) + 1,), dtype=np.float32) + edge_attr_one_hot[self.parse_vocab.get(edge_attr, len(self.parse_vocab))] = 1.0 + question_edge_attribute.append(edge_attr_one_hot) + question_edge_index = np.array(question_edge_index, dtype=np.float64) + question_edge_attribute = np.stack(question_edge_attribute, axis=0) + + question_edge_indices_all = [torch.from_numpy(question_edge_index).t().long().contiguous() for _ in range(num_options)] + question_edge_attributes_all = [torch.from_numpy(question_edge_attribute).contiguous() for _ in range(num_options)] + + with open(os.path.join(hist_adj_matrices_dir, f'{graph_idx}.pkl'), 'rb') as f: + _history_edge_incides_all = pickle.load(f) + _history_edge_incides_last = _history_edge_incides_all[-1] + history_edge_index_all = [torch.tensor(_history_edge_incides_last).t().long().contiguous() for _ in range(num_options)] + + if self.config['stack_gr_data']: + question_edge_indices_all = torch.stack(question_edge_indices_all, dim=0) + question_edge_attributes_all = torch.stack(question_edge_attributes_all, dim=0) + history_edge_index_all = torch.stack(history_edge_index_all, dim=0) + history_sep_indices_all = torch.stack(history_sep_indices_all, dim=0) + len_question_gr = torch.tensor(question_edge_indices_all.size(-1)).unsqueeze(0).repeat(num_options, 1) + len_history_gr = torch.tensor(history_edge_index_all.size(-1)).repeat(num_options, 1) + len_history_sep = torch.tensor(history_sep_indices_all.size(-1)).repeat(num_options, 1) + + item = {} + item['tokens'] = tokens_all.unsqueeze(0) + item['segments'] = segments_all.unsqueeze(0) + item['sep_indices'] = sep_indices_all.unsqueeze(0) + item['mask'] = mask_all.unsqueeze(0) + item['hist_len'] = hist_len_all.unsqueeze(0) + item['question_limits'] = question_limits_all + item['question_edge_indices'] = question_edge_indices_all + item['question_edge_attributes'] = question_edge_attributes_all + + item['history_edge_indices'] = history_edge_index_all + item['history_sep_indices'] = history_sep_indices_all + + if self.config['stack_gr_data']: + item['len_question_gr'] = len_question_gr + item['len_history_gr'] = len_history_gr + item['len_history_sep'] = len_history_sep + + item['round_id'] = torch.LongTensor([dialog['round_id']]) + + # get image features + if not self.config['dataloader_text_only']: + features, num_boxes, boxes, _ , image_target, image_edge_indexes, image_edge_attributes = self._image_features_reader[img_id] + features, spatials, image_mask, image_target, image_label = encode_image_input(features, num_boxes, boxes, image_target, max_regions=self._max_region_num, mask_prob=0) + else: + features = spatials = image_mask = image_target = image_label = torch.tensor([0]) + + item['image_feat'] = features + item['image_loc'] = spatials + item['image_mask'] = image_mask + item['image_target'] = image_target + item['image_label'] = image_label + item['image_id'] = torch.LongTensor([img_id]) + if self._split == 'train': + # cheap hack to account for the graph data for the postitive and negatice examples + item['image_edge_indices'] = [torch.from_numpy(image_edge_indexes).long(), torch.from_numpy(image_edge_indexes).long()] + item['image_edge_attributes'] = [torch.from_numpy(image_edge_attributes), torch.from_numpy(image_edge_attributes)] + elif self._split == 'val': + # cheap hack to account for the graph data for the postitive and negatice examples + item['image_edge_indices'] = 
[torch.from_numpy(image_edge_indexes).contiguous().long() for _ in range(1000)] + item['image_edge_attributes'] = [torch.from_numpy(image_edge_attributes).contiguous() for _ in range(1000)] + + else: + # cheap hack to account for the graph data for the postitive and negatice examples + item['image_edge_indices'] = [torch.from_numpy(image_edge_indexes).contiguous().long() for _ in range(100)] + item['image_edge_attributes'] = [torch.from_numpy(image_edge_attributes).contiguous() for _ in range(100)] + + if self.config['stack_gr_data']: + item['image_edge_indices'] = torch.stack(item['image_edge_indices'], dim=0) + item['image_edge_attributes'] = torch.stack(item['image_edge_attributes'], dim=0) + len_image_gr = torch.tensor(item['image_edge_indices'].size(-1)).unsqueeze(0).repeat(num_options) + item['len_image_gr'] = len_image_gr + + return item diff --git a/dataloader/dataloader_visdial_dense.py b/dataloader/dataloader_visdial_dense.py new file mode 100644 index 0000000..72cdbbf --- /dev/null +++ b/dataloader/dataloader_visdial_dense.py @@ -0,0 +1,313 @@ +import torch +import json +import os +import time +import numpy as np +import random +from tqdm import tqdm +import copy +import pyhocon +import glog as log +from collections import OrderedDict +import argparse +import pickle +import torch.utils.data as tud +import sys +sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) + +from utils.data_utils import encode_input, encode_image_input +from dataloader.dataloader_base import DatasetBase + + +class VisdialDenseDataset(DatasetBase): + + def __init__(self, config): + super(VisdialDenseDataset, self).__init__(config) + with open(config.tr_graph_idx_mapping, 'r') as f: + self.tr_graph_idx_mapping = json.load(f) + + with open(config.val_graph_idx_mapping, 'r') as f: + self.val_graph_idx_mapping = json.load(f) + + with open(config.test_graph_idx_mapping, 'r') as f: + self.test_graph_idx_mapping = json.load(f) + + + self.question_gr_paths = { + 'train': os.path.join(self.config['visdial_question_adj_matrices'], 'train'), + 'val': os.path.join(self.config['visdial_question_adj_matrices'], 'val'), + 'test': os.path.join(self.config['visdial_question_adj_matrices'], 'test') + } + + self.history_gr_paths = { + 'train': os.path.join(self.config['visdial_history_adj_matrices'], 'train'), + 'val': os.path.join(self.config['visdial_history_adj_matrices'], 'val'), + 'test': os.path.join(self.config['visdial_history_adj_matrices'], 'test') + } + + + def __getitem__(self, index): + MAX_SEQ_LEN = self.config['max_seq_len'] + cur_data = None + cur_dense_annotations = None + + if self._split == 'train': + cur_data = self.visdial_data_train['data'] + cur_dense_annotations = self.visdial_data_train_dense + cur_question_gr_path = self.question_gr_paths['train'] + cur_history_gr_path = self.history_gr_paths['train'] + cur_gr_mapping = self.tr_graph_idx_mapping + + if self.config['rlv_hst_only']: + cur_rlv_hst = self.rlv_hst_train + elif self._split == 'val': + cur_data = self.visdial_data_val['data'] + cur_dense_annotations = self.visdial_data_val_dense + cur_question_gr_path = self.question_gr_paths['val'] + cur_history_gr_path = self.history_gr_paths['val'] + cur_gr_mapping = self.val_graph_idx_mapping + + if self.config['rlv_hst_only']: + cur_rlv_hst = self.rlv_hst_val + elif self._split == 'trainval': + if index >= self.numDataPoints['train']: + cur_data = self.visdial_data_val['data'] + cur_dense_annotations = self.visdial_data_val_dense + cur_gr_mapping = self.val_graph_idx_mapping + index -= 
self.numDataPoints['train'] + cur_question_gr_path = self.question_gr_paths['val'] + cur_history_gr_path = self.history_gr_paths['val'] + if self.config['rlv_hst_only']: + cur_rlv_hst = self.rlv_hst_val + else: + cur_data = self.visdial_data_train['data'] + cur_dense_annotations = self.visdial_data_train_dense + cur_question_gr_path = self.question_gr_paths['train'] + cur_gr_mapping = self.tr_graph_idx_mapping + cur_history_gr_path = self.history_gr_paths['train'] + if self.config['rlv_hst_only']: + cur_rlv_hst = self.rlv_hst_train + elif self._split == 'test': + cur_data = self.visdial_data_test['data'] + cur_question_gr_path = self.question_gr_paths['test'] + cur_history_gr_path = self.history_gr_paths['test'] + if self.config['rlv_hst_only']: + cur_rlv_hst = self.rlv_hst_test + + # number of options to score on + num_options = self.num_options_dense + if self._split == 'test' or self.config['validating'] or self.config['predicting']: + assert num_options == 100 + else: + assert num_options >=1 and num_options <= 100 + + dialog = cur_data['dialogs'][index] + cur_questions = cur_data['questions'] + cur_answers = cur_data['answers'] + img_id = dialog['image_id'] + if self._split != 'test': + graph_idx = cur_gr_mapping[str(img_id)] + else: + graph_idx = index + + if self._split != 'test': + assert img_id == cur_dense_annotations[index]['image_id'] + if self.config['rlv_hst_only']: + rlv_hst = cur_rlv_hst[str(img_id)] # [10 for each round, 10 for cap + first 9 round ] + + if self._split == 'test': + cur_rounds = len(dialog['dialog']) # 1, 2, ..., 10 + else: + cur_rounds = cur_dense_annotations[index]['round_id'] # 1, 2, ..., 10 + + # caption + cur_rnd_utterance = [] + include_caption = True + if self.config['rlv_hst_only']: + if self.config['rlv_hst_dense_round']: + if rlv_hst[0] == 0: + include_caption = False + elif rlv_hst[cur_rounds - 1][0] == 0: + include_caption = False + if include_caption: + sent = dialog['caption'].split(' ') + tokenized_sent = self.tokenizer.convert_tokens_to_ids(sent) + cur_rnd_utterance.append(tokenized_sent) + # tot_len += len(sent) + 1 + + for rnd, utterance in enumerate(dialog['dialog'][:cur_rounds]): + if self.config['rlv_hst_only'] and rnd < cur_rounds - 1: + if self.config['rlv_hst_dense_round']: + if rlv_hst[rnd + 1] == 0: + continue + elif rlv_hst[cur_rounds - 1][rnd + 1] == 0: + continue + # question + sent = cur_questions[utterance['question']].split(' ') + tokenized_sent = self.tokenizer.convert_tokens_to_ids(sent) + cur_rnd_utterance.append(tokenized_sent) + + # answer + if rnd != cur_rounds - 1: + sent = cur_answers[utterance['answer']].split(' ') + tokenized_sent = self.tokenizer.convert_tokens_to_ids(sent) + cur_rnd_utterance.append(tokenized_sent) + + if self.config['rlv_hst_only']: + num_rlv_rnds = len(cur_rnd_utterance) - 1 + else: + num_rlv_rnds = None + + if self._split != 'test': + gt_option = dialog['dialog'][cur_rounds - 1]['gt_index'] + if self.config['training'] or self.config['debugging']: + # first select gt option id, then choose the first num_options inds + option_inds = [] + option_inds.append(gt_option) + all_inds = list(range(100)) + all_inds.remove(gt_option) + # debug + if num_options < 100: + random.shuffle(all_inds) + all_inds = all_inds[:(num_options-1)] + option_inds.extend(all_inds) + gt_option = 0 + else: + option_inds = range(num_options) + answer_options = [dialog['dialog'][cur_rounds - 1]['answer_options'][k] for k in option_inds] + if 'relevance' in cur_dense_annotations[index]: + key = 'relevance' + else: + key = 
'gt_relevance' + gt_relevance = torch.Tensor(cur_dense_annotations[index][key]) + gt_relevance = gt_relevance[option_inds] + assert len(answer_options) == len(option_inds) == num_options + else: + answer_options = dialog['dialog'][-1]['answer_options'] + assert len(answer_options) == num_options + + options_all = [] + for answer_option in answer_options: + cur_option = cur_rnd_utterance.copy() + cur_option.append(self.tokenizer.convert_tokens_to_ids(cur_answers[answer_option].split(' '))) + options_all.append(cur_option) + if not self.config['rlv_hst_only']: + assert len(cur_option) == 2 * cur_rounds + 1 + + tokens_all = [] + mask_all = [] + segments_all = [] + sep_indices_all = [] + hist_len_all = [] + tot_len_debug = [] + + for opt_id, option in enumerate(options_all): + option, start_segment = self.pruneRounds(option, self.config['visdial_tot_rounds']) + tokens, segments, sep_indices, mask, start_question, end_question = encode_input(option, start_segment ,self.CLS, + self.SEP, self.MASK ,max_seq_len=MAX_SEQ_LEN, mask_prob=0) + + tokens_all.append(tokens) + mask_all.append(mask) + segments_all.append(segments) + sep_indices_all.append(sep_indices) + hist_len_all.append(torch.LongTensor([len(option)-1])) + + len_tokens = sum(len(s) for s in option) + tot_len_debug.append(len_tokens + len(option) + 1) + + tokens_all = torch.cat(tokens_all,0) + mask_all = torch.cat(mask_all,0) + segments_all = torch.cat(segments_all, 0) + sep_indices_all = torch.cat(sep_indices_all, 0) + hist_len_all = torch.cat(hist_len_all,0) + question_limits_all = torch.tensor([start_question, end_question]).unsqueeze(0).repeat(num_options, 1) + if self.config['rlv_hst_only']: + assert num_rlv_rnds > 0 + hist_idx = [i * 2 for i in range(num_rlv_rnds)] + else: + hist_idx = [i*2 for i in range(cur_rounds)] + history_sep_indices_all = sep_indices.squeeze(0)[hist_idx].contiguous().unsqueeze(0).repeat(num_options, 1) + + with open(os.path.join(cur_question_gr_path, f'{graph_idx}.pkl'), 'rb') as f: + question_graphs = pickle.load(f) + question_graph_round = question_graphs[cur_rounds - 1] + question_edge_index = [] + question_edge_attribute = [] + for edge_index, edge_attr in question_graph_round: + question_edge_index.append(edge_index) + edge_attr_one_hot = np.zeros((len(self.parse_vocab) + 1,), dtype=np.float32) + edge_attr_one_hot[self.parse_vocab.get(edge_attr, len(self.parse_vocab))] = 1.0 + question_edge_attribute.append(edge_attr_one_hot) + question_edge_index = np.array(question_edge_index, dtype=np.float64) + question_edge_attribute = np.stack(question_edge_attribute, axis=0) + + question_edge_indices_all = [torch.from_numpy(question_edge_index).t().long().contiguous() for _ in range(num_options)] + question_edge_attributes_all = [torch.from_numpy(question_edge_attribute).contiguous() for _ in range(num_options)] + + if self.config['rlv_hst_only']: + with open(os.path.join(cur_history_gr_path, f'{graph_idx}.pkl'), 'rb') as f: + _history_edge_incides_round = pickle.load(f) + else: + with open(os.path.join(cur_history_gr_path, f'{graph_idx}.pkl'), 'rb') as f: + _history_edge_incides_all = pickle.load(f) + _history_edge_incides_round = _history_edge_incides_all[cur_rounds - 1] + + history_edge_index_all = [torch.tensor(_history_edge_incides_round).t().long().contiguous() for _ in range(num_options)] + + if self.config['stack_gr_data']: + question_edge_indices_all = torch.stack(question_edge_indices_all, dim=0) + question_edge_attributes_all = torch.stack(question_edge_attributes_all, dim=0) + history_edge_index_all = 
torch.stack(history_edge_index_all, dim=0) + + item = {} + + item['tokens'] = tokens_all.unsqueeze(0) # [1, num_options, max_len] + item['segments'] = segments_all.unsqueeze(0) + item['sep_indices'] = sep_indices_all.unsqueeze(0) + item['mask'] = mask_all.unsqueeze(0) + item['hist_len'] = hist_len_all.unsqueeze(0) + item['question_limits'] = question_limits_all + item['question_edge_indices'] = question_edge_indices_all + item['question_edge_attributes'] = question_edge_attributes_all + item['history_edge_indices'] = history_edge_index_all + item['history_sep_indices'] = history_sep_indices_all + + # add dense annotation fields + if self._split != 'test': + item['gt_relevance'] = gt_relevance # [num_options] + item['gt_option_inds'] = torch.LongTensor([gt_option]) + + # add next sentence labels for training with the nsp loss as well + nsp_labels = torch.ones(*tokens_all.unsqueeze(0).shape[:-1]).long() + nsp_labels[:,gt_option] = 0 + item['next_sentence_labels'] = nsp_labels + + item['round_id'] = torch.LongTensor([cur_rounds]) + else: + if 'round_id' in dialog: + item['round_id'] = torch.LongTensor([dialog['round_id']]) + else: + item['round_id'] = torch.LongTensor([cur_rounds]) + + # get image features + if not self.config['dataloader_text_only']: + features, num_boxes, boxes, _ , image_target, image_edge_indexes, image_edge_attributes = self._image_features_reader[img_id] + features, spatials, image_mask, image_target, image_label = encode_image_input(features, num_boxes, boxes, image_target, max_regions=self._max_region_num, mask_prob=0) + else: + features = spatials = image_mask = image_target = image_label = torch.tensor([0]) + item['image_feat'] = features + item['image_loc'] = spatials + item['image_mask'] = image_mask + item['image_id'] = torch.LongTensor([img_id]) + item['tot_len'] = torch.LongTensor(tot_len_debug) + + + + item['image_edge_indices'] = [torch.from_numpy(image_edge_indexes).contiguous().long() for _ in range(num_options)] + item['image_edge_attributes'] = [torch.from_numpy(image_edge_attributes).contiguous() for _ in range(num_options)] + + if self.config['stack_gr_data']: + item['image_edge_indices'] = torch.stack(item['image_edge_indices'], dim=0) + item['image_edge_attributes'] = torch.stack(item['image_edge_attributes'], dim=0) + + return item diff --git a/ensemble.py b/ensemble.py new file mode 100644 index 0000000..850917f --- /dev/null +++ b/ensemble.py @@ -0,0 +1,114 @@ +import os +import os.path as osp +import numpy as np +import json +import argparse +import pyhocon +import glog as log +import torch +from tqdm import tqdm + +from utils.data_utils import load_pickle_lines +from utils.visdial_metrics import scores_to_ranks + + +parser = argparse.ArgumentParser(description='Ensemble for VisDial') +parser.add_argument('--exp', type=str, default='test', + help='experiment name from .conf') +parser.add_argument('--mode', type=str, default='predict', choices=['eval', 'predict'], + help='eval or predict') +parser.add_argument('--ssh', action='store_true', + help='whether or not we are executing command via ssh. 
' + 'If set to True, we will not log.info anything to screen and only redirect them to log file') + + +if __name__ == '__main__': + args = parser.parse_args() + + # initialization + config = pyhocon.ConfigFactory.parse_file(f"config/ensemble.conf")[args.exp] + config["log_dir"] = os.path.join(config["log_dir"], args.exp) + if not os.path.exists(config["log_dir"]): + os.makedirs(config["log_dir"]) + + # set logs + log_file = os.path.join(config["log_dir"], f'{args.mode}.log') + set_log_file(log_file, file_only=args.ssh) + + # print environment info + log.info(f"Running experiment: {args.exp}") + log.info(f"Results saved to {config['log_dir']}") + log.info(pyhocon.HOCONConverter.convert(config, "hocon")) + + if isinstance(config['processed'], list): + assert len(config['models']) == len(config['processed']) + processed = {model:pcd for model, pcd in zip(config['models'], config['processed'])} + else: + processed = {model: config['processed'] for model in config['models']} + + if config['split'] == 'test' and np.any(config['processed']): + test_data = json.load(open(config['visdial_test_data']))['data']['dialogs'] + imid2rndid = {t['image_id']: len(t['dialog']) for t in test_data} + del test_data + + # load predictions files + visdial_outputs = dict() + if args.mode == 'eval': + metrics = {} + for model in config['models']: + pred_filename = osp.join(config['pred_dir'], model, 'visdial_prediction.pkl') + pred_dict = {p['image_id']: p for p in load_pickle_lines(pred_filename)} + log.info(f'Loading {len(pred_dict)} predictions from {pred_filename}') + visdial_outputs[model] = pred_dict + if args.mode == 'eval': + assert len(visdial_outputs[model]) >= num_dialogs + metric = json.load(open(osp.join(config['pred_dir'], model, "metrics_epoch_best.json"))) + metrics[model] = metric['val'] + + image_ids = visdial_outputs[model].keys() + predictions = [] + + # for each dialog + for image_id in tqdm(image_ids): + scores = [] + round_id = None + + for model in config['models']: + pred = visdial_outputs[model][image_id] + + if config['split'] == 'test' and processed[model]: + # if predict on processed data, the first few rounds are deleted from some dialogs + # so the original round ids can only be found in the original test data + round_id_in_pred = imid2rndid[image_id] + else: + round_id_in_pred = pred['gt_relevance_round_id'] + + if not isinstance(round_id_in_pred, int): + round_id_in_pred = int(round_id_in_pred) + if round_id is None: + round_id = round_id_in_pred + else: + # make sure all models have the same round_id + assert round_id == round_id_in_pred + scores.append(torch.from_numpy(pred['nsp_probs']).unsqueeze(0)) + + # ensemble scores + scores = torch.cat(scores, 0) # [n_model, num_rounds, num_options] + scores = torch.sum(scores, dim=0, keepdim=True) # [1, num_rounds, num_options] + + + if scores.size(0) > 1: + scores = scores[round_id - 1].unsqueeze(0) + ranks = scores_to_ranks(scores) # [eval_batch_size, num_rounds, num_options] + ranks = ranks.squeeze(1) + prediction = { + "image_id": image_id, + "round_id": round_id, + "ranks": ranks[0].tolist() + } + predictions.append(prediction) + + filename = osp.join(config['log_dir'], f'{config["split"]}_ensemble_preds.json') + with open(filename, 'w') as f: + json.dump(predictions, f) + log.info(f'{len(predictions)} predictions saved to {filename}') diff --git a/main.py b/main.py new file mode 100644 index 0000000..d6ef13c --- /dev/null +++ b/main.py @@ -0,0 +1,199 @@ +from utils.init_utils import load_runner, load_dataset, set_random_seed, 
set_training_steps, initialize_from_env, set_log_file, copy_file_to_log +import torch.distributed as dist +import torch.nn as nn +import torch.multiprocessing as mp +import torch +import os +import sys +import argparse +import pyhocon +import glog as log +import socket +import getpass + +try: + from apex.parallel import DistributedDataParallel as DDP + from apex import amp +except ModuleNotFoundError: + print('apex not found') + +parser = argparse.ArgumentParser(description='Main script for VD-GR') +parser.add_argument( + '--model', + type=str, + default='vdgr/P1', + help='model name to train or test') + +parser.add_argument( + '--mode', + type=str, + default='train', + help='train, eval, predict or debug') + +parser.add_argument( + '--wandb_project', + type=str, + default='VD-GR' +) + +parser.add_argument( + '--wandb_mode', + type=str, + default='online', + choices=['online', 'offline', 'disabled', 'run', 'dryrun'] +) + +parser.add_argument( + '--tag', + type=str, + default='K2', + help="Tag to differentiate the different runs" +) + +parser.add_argument( + '--eval_dir', + type=str, + default='', + help="Directory of a trained model to evaluate" +) + +parser.add_argument('--ssh', action='store_true', + help='whether or not we are executing command via ssh. ' + 'If set to True, we will not log.info anything to screen and only redirect them to log file') + + +def main(gpu, config, args): + config['training'] = args.mode == 'train' + config['validating'] = args.mode == 'eval' + config['debugging'] = args.mode == 'debug' + config['predicting'] = args.mode == 'predict' + config['wandb_project'] = args.wandb_project + config['wandb_mode'] = args.wandb_mode + + if config['parallel'] and config['dp_type'] != 'dp': + config['rank'] = gpu + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(config['master_port']) + dist.init_process_group( + backend='nccl', + world_size=config['num_gpus'], + rank=gpu + ) + config['display'] = gpu == 0 + if config['dp_type'] == 'apex': + torch.cuda.set_device(gpu) + else: + config['display'] = True + if config['debugging'] or (config['parallel'] and config['dp_type'] != 'dp'): + config['num_workers'] = 0 + else: + config['num_workers'] = 0 + # set logs + log_file = os.path.join(config["log_dir"], f'{args.mode}.log') + set_log_file(log_file, file_only=args.ssh) + + # print environment info + if config['display']: + log.info('Host: {}, user: {}, CUDA_VISIBLE_DEVICES: {}, cwd: {}'.format( + socket.gethostname(), getpass.getuser(), os.environ.get('CUDA_VISIBLE_DEVICES', ''), os.getcwd())) + log.info('Command line is: {}'.format(' '.join(sys.argv))) + + if config['parallel'] and config['dp_type'] != 'dp': + log.info( + f'World_size: {config["num_gpus"]}, cur rank: {config["rank"]}') + log.info(f"Running experiment: {args.model}") + log.info(f"Results saved to {config['log_dir']}") + + # initialization + if config['display'] and config['training']: + copy_file_to_log(config['log_dir']) + set_random_seed(config['random_seed']) + + device = torch.device(f"cuda:{gpu}") + if config["use_cpu"]: + device = torch.device("cpu") + config['device'] = device + + # prepare dataset + dataset, dataset_eval = load_dataset(config) + + # set training steps + if not config['validating'] or config['parallel']: + config = set_training_steps(config, len(dataset)) + + if config['display']: + log.info(pyhocon.HOCONConverter.convert(config, "hocon")) + + # load runner + runner = load_runner(config) + # apex + if config['dp_type'] == 'apex': + runner.model, runner.optimizer 
= amp.initialize(runner.model, + runner.optimizer, + opt_level="O1") + # parallel + if config['parallel']: + if config['dp_type'] == 'dp': + runner.model = nn.DataParallel(runner.model) + runner.model.to(config['device']) + elif config['dp_type'] == 'apex': + runner.model = DDP(runner.model) + elif config['dp_type'] == 'ddp': + torch.cuda.set_device(gpu) + runner.model = runner.model.to(gpu) + runner.model = nn.parallel.DistributedDataParallel( + runner.model, + device_ids=[gpu], + output_device=gpu, + find_unused_parameters=True) + else: + raise ValueError(f'Unrecognized dp_type: {config["dp_type"]}') + + if config['training'] or config['debugging']: + runner.load_pretrained_vilbert() + runner.train(dataset, dataset_eval) + else: + if config['loads_start_path']: + runner.load_pretrained_vilbert() + else: + runner.load_ckpt_best() + + metrics_results = {} + if config['predicting']: + eval_splits = [config['predict_split']] + else: + eval_splits = ['val'] + if config['model_type'] == 'conly' and not config['train_each_round']: + eval_splits.append('test') + for split in eval_splits: + if config['display']: + log.info(f'Results on {split} split of the best epoch') + if dataset_eval is None: + dataset_to_eval = dataset + else: + dataset_to_eval = dataset_eval + dataset_to_eval.split = split + _, metrics_results[split] = runner.evaluate( + dataset_to_eval, eval_visdial=True) + if not config['predicting'] and config['display']: + runner.save_eval_results(split, 'best', metrics_results) + + if config['parallel'] and config['dp_type'] != 'dp': + dist.destroy_process_group() + + +if __name__ == '__main__': + args = parser.parse_args() + # initialization + model_type, model_name = args.model.split('/') + config = initialize_from_env( + model_name, args.mode, args.eval_dir, model_type, tag=args.tag) + if config['num_gpus'] > 1: + config['parallel'] = True + if config['dp_type'] == 'dp': + main(0, config, args) + else: + mp.spawn(main, nprocs=config['num_gpus'], args=(config, args)) + else: + config['parallel'] = False + main(0, config, args) diff --git a/misc/.gitkeep b/misc/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/misc/teaser_1.png b/misc/teaser_1.png new file mode 100644 index 0000000..e493fa2 Binary files /dev/null and b/misc/teaser_1.png differ diff --git a/misc/teaser_2.png b/misc/teaser_2.png new file mode 100644 index 0000000..d279406 Binary files /dev/null and b/misc/teaser_2.png differ diff --git a/misc/usa.png b/misc/usa.png new file mode 100644 index 0000000..eebfcf4 Binary files /dev/null and b/misc/usa.png differ diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/models/runner.py b/models/runner.py new file mode 100644 index 0000000..173a572 --- /dev/null +++ b/models/runner.py @@ -0,0 +1,830 @@ +import os +import os.path as osp +import json +from collections import deque +import time +import re +import shutil +import glob +import pickle +import gc +import numpy as np +import glog as log +try: + from apex import amp +except ModuleNotFoundError: + print('apex not found') + +import torch +import torch.utils.data as tud +import torch.nn.functional as F +import torch.distributed as dist + +from utils.data_utils import load_pickle_lines +from utils.visdial_metrics import SparseGTMetrics, NDCG, scores_to_ranks +import wandb + + +class Runner: + def __init__(self, config): + self.config = config + if 'rank' in config: + self.gpu_rank = config['rank'] + else: + self.gpu_rank = 0 + + self.epoch_idx = 0 + 
self.max_metric = 0. + self.max_metric_epoch_idx = 0 + self.na_str = 'N/A' + + if self.config["max_ckpt_to_keep"] > 0: + self.checkpoint_queue = deque( + [], maxlen=config["max_ckpt_to_keep"]) + self.metrics_queue = deque([], maxlen=config["max_ckpt_to_keep"]) + + self.setup_wandb() + + def setup_wandb(self): + if self.gpu_rank == 0: + print("[INFO] Set wandb logging on rank {}".format(0)) + run = wandb.init( + project=self.config['wandb_project'], config=self.config, mode=self.config['wandb_mode']) + else: + run = None + self.run = run + + def forward(self, batch, eval_visdial=False): + return NotImplementedError + + def train(self, dataset, dataset_eval=None): + # wandb.login() + if os.path.exists(self.config['log_dir']) or self.config['loads_ckpt'] or self.config['loads_best_ckpt']: + self.load_ckpt() + + if self.config['use_trainval']: + dataset.split = 'trainval' + else: + dataset.split = 'train' + batch_size = self.config['batch_size'] + if self.config['parallel'] and self.config['dp_type'] != 'dp': + sampler_tr = tud.distributed.DistributedSampler( + dataset, + num_replicas=self.config['num_gpus'], + rank=self.gpu_rank + ) + else: + sampler_tr = None + + data_loader_tr = tud.DataLoader( + dataset=dataset, + batch_size=batch_size, + shuffle=self.config['training'] and not self.config['parallel'], + collate_fn=dataset.collate_fn, + num_workers=self.config['num_workers'], + sampler=sampler_tr + ) + + + start_epoch_idx = self.epoch_idx + num_iter_epoch = self.config['num_iter_per_epoch'] + if self.config['display']: + log.info(f'{num_iter_epoch} iter per epoch.') + + # eval before training + eval_dense_at_first = self.config['train_on_dense'] and self.config['skip_mrr_eval'] and start_epoch_idx == 0 + # eval before training under 2 circumstances: + # for dense finetuning, eval ndcg before the first epoch + # for mrr training, continue training and the last epoch is not evaluated + + if (eval_dense_at_first or (self.config['eval_at_start'] and len(self.metrics_queue) == 0 and start_epoch_idx > 0)): + if eval_dense_at_first: + iter_now = 0 + else: + iter_now = max(num_iter_epoch * start_epoch_idx, 0) + + if dataset_eval is None: + dataset.split = 'val' + dataset_to_eval = dataset + else: + dataset_to_eval = dataset_eval + + metrics_results = {} + metrics_to_maximize, metrics_results['val'] = self.evaluate( + dataset_to_eval, iter_now) + if eval_dense_at_first: + self.max_metric = metrics_to_maximize + self.max_metric_epoch_idx = -1 + else: + if self.config['display']: + self.save_eval_results( + 'val', start_epoch_idx - 1, metrics_results) + if metrics_to_maximize > self.max_metric: + self.max_metric = metrics_to_maximize + self.max_metric_epoch_idx = start_epoch_idx - 1 + self.copy_best_results('val', start_epoch_idx - 1) + self.copy_best_predictions('val') + if dataset_eval is None: + if self.config['use_trainval']: + dataset.split = 'trainval' + else: + dataset.split = 'train' + + num_epochs = self.config['num_epochs'] + + for epoch_idx in range(start_epoch_idx, num_epochs): + if self.config['parallel'] and self.config['dp_type'] != 'dp': + sampler_tr.set_epoch(epoch_idx) + + self.epoch_idx = epoch_idx + + if self.config['display']: + log.info(f'starting epoch {epoch_idx}') + log.info('training') + + self.model.train() + + num_batch = 0 + next_logging_pct = .1 + next_evaluating_pct = self.config["next_evaluating_pct"] + .1 + start_time = time.time() + self.optimizer.zero_grad() + + for batch in data_loader_tr: + if self.config['eval_before_training']: + log.info('Skipping stright to 
evaluation...') + break + num_batch += 1 + pct = num_batch / num_iter_epoch * 100 + iter_now = num_iter_epoch * epoch_idx + num_batch + + output = self.forward(batch) + losses = output['losses'] + + # optimizer step + losses['tot_loss'] /= self.config['batch_multiply'] + # debug + if self.config['debugging']: + log.info('try backward') + if self.config['dp_type'] == 'apex': + with amp.scale_loss(losses['tot_loss'], self.optimizer) as scaled_loss: + scaled_loss.backward() + else: + losses['tot_loss'].backward() + if self.config['debugging']: + log.info('backward done') + + if iter_now % self.config['batch_multiply'] == 0: + self.optimizer.step() + self.optimizer.zero_grad() + self.scheduler.step() + + # display and eval + if pct >= next_logging_pct: + if self.config['display']: + loss_to_print = '' + for key in losses: + if losses[key] is not None and isinstance(losses[key], torch.Tensor): + loss_to_print += f'[{key}: {losses[key].item():.4f}]' + print( + f'[{int(pct)}%][Epoch: {epoch_idx + 1}/{num_epochs}][Iter : {num_batch}/{len(data_loader_tr)}] [time: {time.time() - start_time:.2f}] {loss_to_print}' + ) + + + next_logging_pct += self.config["next_logging_pct"] + + if self.config['debugging']: + break + + if pct >= next_evaluating_pct: + next_evaluating_pct += self.config["next_evaluating_pct"] + + if self.run: + if self.config['train_on_dense']: + self.run.log( + { + "Train/dense_loss": losses['dense_loss'], + "Train/total_loss": losses['tot_loss'], + }, + step=iter_now + ) + + else: + self.run.log( + { + "Train/lm_loss": losses['lm_loss'], + "Train/img_loss": losses['img_loss'], + "Train/nsp_loss": losses['nsp_loss'], + "Train/total_loss": losses['tot_loss'], + }, + step=iter_now + ) + + lr_gnn, lr_bert = self.scheduler.get_lr()[0], self.scheduler.get_lr()[1] + self.run.log( + { + "Train/lr_gnn": lr_gnn, + "Train/lr_bert": lr_bert, + }, + step=iter_now + ) + del losses + # debug + torch.cuda.empty_cache() + + if self.config['display']: + log.info( + f'100%,\ttime:\t{time.time() - start_time:.2f}' + ) + ckpt_path = self.save_ckpt() + + if not self.config['skip_visdial_eval'] and self.epoch_idx % self.config['eval_visdial_every'] == 0: + + iter_now = num_iter_epoch * (epoch_idx + 1) + + if dataset_eval is None: + dataset.split = 'val' + dataset_to_eval = dataset + else: + dataset_to_eval = dataset_eval + metrics_results = {} + metrics_to_maximize, metrics_results['val'] = self.evaluate( + dataset_to_eval, iter_now) + if dataset_eval is None: + if self.config['use_trainval']: + dataset.split = 'trainval' + else: + dataset.split = 'train' + if self.config['display']: + self.save_eval_results('val', epoch_idx, metrics_results) + + if self.config['display']: + + if metrics_to_maximize > self.max_metric: + self.max_metric = metrics_to_maximize + self.max_metric_epoch_idx = epoch_idx + self.copy_best_results('val', epoch_idx) + self.copy_best_predictions('val') + + elif not self.config['parallel'] and epoch_idx - self.max_metric_epoch_idx > self.config["early_stop_epoch"]: + log.info('Early stop.') + break + + if self.run: + self.run.log( + {"Val/metric_best": self.max_metric}, step=iter_now) + + if self.config['parallel']: + if self.config['dp_type'] == 'dp': + gc.collect() + torch.cuda.empty_cache() + else: + dist.barrier() + log.info('Rank {} passed barrier...'.format(self.gpu_rank)) + + if self.config['stop_epochs'] >= 0 and epoch_idx + 1 >= self.config['stop_epochs']: + if self.config['display']: + log.info('Stop for reaching stop_epochs.') + break + + def evaluate(self, dataset, 
training_iter=None, eval_visdial=True): + # create files to save output + if self.config['predicting']: + visdial_file_name = None + if self.config['save_score']: + visdial_file_name = osp.join( + self.config['log_dir'], f'visdial_prediction.pkl') + if osp.exists(visdial_file_name): + dialogs_predicted = load_pickle_lines( + visdial_file_name) + dialogs_predicted = [d['image_id'] + for d in dialogs_predicted] + else: + dialogs_predicted = [] + f_visdial = open(visdial_file_name, 'ab') + + else: + visdial_file_name = osp.join( + self.config['log_dir'], f'visdial_prediction.jsonlines') + if self.config['parallel'] and self.config['dp_type'] != 'dp': + visdial_file_name = visdial_file_name.replace( + '.jsonlines', f'_{self.config["rank"]}of{self.config["num_gpus"]}.jsonlines') + if osp.exists(visdial_file_name): + dialogs_predicted_visdial = [json.loads( + line)['image_id'] for line in open(visdial_file_name)] + f_visdial = open(visdial_file_name, 'a') + else: + dialogs_predicted_visdial = [] + f_visdial = open(visdial_file_name, 'w') + + dialogs_predicted = dialogs_predicted_visdial + + if len(dialogs_predicted) > 0: + log.info(f'Found {len(dialogs_predicted)} predicted results.') + + if self.config['display']: + if visdial_file_name is not None: + log.info( + f'VisDial predictions saved to {visdial_file_name}') + + elif self.config['display']: + if self.config['continue_evaluation']: + predicted_files = os.listdir( + osp.join(self.config['visdial_output_dir'], dataset.split)) + dialogs_predicted = [ + int(re.match(r'(\d+).npz', p).group(1)) for p in predicted_files] + else: + if osp.exists(osp.join(self.config['visdial_output_dir'], dataset.split)): + shutil.rmtree( + osp.join(self.config['visdial_output_dir'], dataset.split)) + os.makedirs( + osp.join(self.config['visdial_output_dir'], dataset.split)) + + dialogs_predicted = [] + log.info(f'Found {len(dialogs_predicted)} predicted results.') + + if self.config['parallel'] and self.config['dp_type'] != 'dp': + sampler_val = tud.distributed.DistributedSampler( + dataset, + num_replicas=self.config['num_gpus'], + rank=self.gpu_rank + ) + + sampler_val.set_epoch(self.epoch_idx) + else: + sampler_val = None + + data_loader_val = tud.DataLoader( + dataset=dataset, + batch_size=self.config['eval_batch_size'], + shuffle=False, + collate_fn=dataset.collate_fn, + num_workers=self.config['num_workers'], + sampler=sampler_val + ) + self.model.eval() + + with torch.no_grad(): + if self.config['display']: + log.info(f'Evaluating {len(dataset)} samples') + + next_logging_pct = self.config["next_logging_pct"] + .1 + if self.config['parallel'] and self.config['dp_type'] == 'dp': + num_batch_tot = int( + np.ceil(len(dataset) / self.config['eval_batch_size'])) + else: + num_batch_tot = int(np.ceil( + len(dataset) / (self.config['eval_batch_size'] * self.config['num_gpus']))) + num_batch = 0 + if dataset.split == 'val': + num_options = self.config["num_options"] + if self.config['skip_mrr_eval']: + num_rounds = 1 + else: + num_rounds = 10 + elif dataset.split == 'test': + num_options = 100 + num_rounds = 1 + if self.gpu_rank == 0: + start_time = time.time() + + for batch in data_loader_val: + num_batch += 1 + # skip dialogs that have been predicted + if self.config['predicting']: + image_ids = batch['image_id'].tolist() + skip_batch = True + for image_id in image_ids: + if image_id not in dialogs_predicted: + skip_batch = False + if skip_batch: + continue + output = self.forward( + batch, eval_visdial=eval_visdial) + + # visdial evaluation + if eval_visdial: 
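+ # The NSP head scores arrive flattened as [batch_size * num_rounds * num_options, 2];
+ # the positive-class probability (index 0) is kept and reshaped to
+ # [batch_size, num_rounds, num_options] before being saved (val) or ranked (test).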
+ img_ids = batch['image_id'].tolist() + batch_size = len(img_ids) + if not self.config['skip_ndcg_eval']: + gt_relevance_round_id = batch['round_id'].tolist() + + # [batch_size * num_rounds * num_options, 2] + nsp_scores = output['nsp_scores'] + nsp_probs = F.softmax(nsp_scores, dim=1) + assert nsp_probs.shape[-1] == 2 + # num_dim=2, 0 for postive, 1 for negative + nsp_probs = nsp_probs[:, 0] + nsp_probs = nsp_probs.view( + batch_size, num_rounds, num_options) + + # could be predicting or evaluating + if dataset.split == 'val': + if self.config['skip_ndcg_eval']: + gt_option_inds = batch['gt_option_inds'] + + for b in range(batch_size): + filename = osp.join( + self.config['visdial_output_dir'], dataset.split, f'{img_ids[b]}.npz') + if not osp.exists(filename): + np.savez( + filename, + nsp_probs=nsp_probs[b].cpu().numpy(), + gt_option_inds=gt_option_inds[b].cpu().numpy() + ) + else: + # [batch_size, num_rounds] + gt_option_inds = batch['gt_option_inds'] + # [batch_size, num_options] + gt_relevance = batch['gt_relevance'] + + for b in range(batch_size): + filename = osp.join( + self.config['visdial_output_dir'], dataset.split, f'{img_ids[b]}.npz') + if not osp.exists(filename): + np.savez(filename, + nsp_probs=nsp_probs[b].cpu().numpy(), + gt_option_inds=gt_option_inds[b].cpu( + ).numpy(), + gt_relevance=gt_relevance[b].cpu( + ).numpy(), + gt_relevance_round_id=gt_relevance_round_id[b]) + + # must be predicting + if dataset.split == 'test': + if self.config['save_score']: + for b in range(batch_size): + prediction = { + "image_id": img_ids[b], + "nsp_probs": nsp_probs[b].cpu().numpy(), + "gt_relevance_round_id": gt_relevance_round_id[b] + } + pickle.dump(prediction, f_visdial) + else: + # [eval_batch_size, num_rounds, num_options] + ranks = scores_to_ranks(nsp_probs) + ranks = ranks.squeeze(1) + for b in range(batch_size): + prediction = { + "image_id": img_ids[b], + "round_id": gt_relevance_round_id[b], + "ranks": ranks[b].tolist() + } + f_visdial.write(json.dumps(prediction) + '\n') + + # debug + if self.config['debugging']: + break + + pct = num_batch / num_batch_tot * 100 + if pct >= next_logging_pct: + if self.config['display'] and self.gpu_rank == 0: + log.info( + f'{int(pct)}%,\ttime:\t{time.time() - start_time:.2f}' + ) + next_logging_pct += self.config["next_logging_pct"] + # debug + if self.config['debugging']: + break + + if self.config['display'] and self.gpu_rank == 0: + pct = num_batch / num_batch_tot * 100 + log.info( + f'{int(pct)}%,\ttime:\t{time.time() - start_time:.2f}' + ) + + if not self.config['validating']: + self.model.train() + + if self.config['parallel'] and self.config['dp_type'] != 'dp': + dist.barrier() + + print(f'{self.gpu_rank} passed barrier') + + if self.config['predicting']: + f_visdial.close() + if not self.config['save_score']: + all_visdial_predictions = [json.loads( + line) for line in open(visdial_file_name)] + if self.config['predict_split'] == 'test' and len(all_visdial_predictions) == self.config['num_test_dialogs']: + visdial_file_name = visdial_file_name.replace( + 'jsonlines', 'json') + with open(visdial_file_name, 'w') as f_visdial: + json.dump(all_visdial_predictions, f_visdial) + log.info( + f'Prediction for submisson save to {visdial_file_name}.') + return None, None + + if self.config['display']: + if dataset.split == 'val' and eval_visdial: + if not self.config['skip_mrr_eval']: + sparse_metrics = SparseGTMetrics() + if not self.config['skip_ndcg_eval']: + ndcg = NDCG() + + if dataset.split == 'val' and eval_visdial: + 
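+ # Aggregate the per-dialog .npz files written in the loop above: each file holds
+ # the NSP probabilities together with gt_option_inds (and, when NDCG is evaluated,
+ # gt_relevance and its round id), from which MRR/recall and NDCG are computed below.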
visdial_output_filenames = glob.glob( + osp.join(self.config['visdial_output_dir'], dataset.split, '*.npz')) + log.info( + f'Calculating visdial metrics for {len(visdial_output_filenames)} dialogs') + for visdial_output_filename in visdial_output_filenames: + output = np.load(visdial_output_filename) + nsp_probs = torch.from_numpy( + output['nsp_probs']).unsqueeze(0) + if not self.config['skip_ndcg_eval']: + gt_relevance = torch.from_numpy(output['gt_relevance']).unsqueeze(0) + if not self.config['skip_mrr_eval']: + gt_option_inds = torch.from_numpy( + output['gt_option_inds']).unsqueeze(0) + sparse_metrics.observe(nsp_probs, gt_option_inds) + if not self.config['skip_ndcg_eval']: + gt_relevance_round_id = output['gt_relevance_round_id'] + nsp_probs_dense = nsp_probs[0, gt_relevance_round_id - 1, :].unsqueeze(0) + else: + nsp_probs_dense = nsp_probs.squeeze(0) # [1, 100] + if not self.config['skip_ndcg_eval']: + ndcg.observe(nsp_probs_dense, gt_relevance) + + # visdial eval output + visdial_metrics = {} + if dataset.split == 'val' and eval_visdial: + if not self.config['skip_mrr_eval']: + visdial_metrics.update(sparse_metrics.retrieve(reset=True)) + if not self.config['skip_ndcg_eval']: + visdial_metrics.update(ndcg.retrieve(reset=True)) + + if self.config['display']: + to_print = '' + for metric_name, metric_value in visdial_metrics.items(): + if 'round' not in metric_name: + to_print += f"\n{metric_name}: {metric_value}" + if training_iter is not None: + if self.run: + self.run.log( + {'Val/' + metric_name: metric_value}, step=training_iter) + log.info(to_print) + + if self.config['metrics_to_maximize'] in visdial_metrics: + metrics_to_maximize = visdial_metrics[self.config['metrics_to_maximize']] + else: + metrics_to_maximize = None + torch.cuda.empty_cache() + return metrics_to_maximize, visdial_metrics + else: + torch.cuda.empty_cache() + return None, None + + def save_eval_results(self, split, epoch_idx, metrics_results): + + metrics_filename = osp.join( + self.config['log_dir'], f'metrics_epoch_{epoch_idx}.json') + with open(metrics_filename, 'w') as f: + json.dump(metrics_results, f) + log.info(f'Results of metrics saved to {metrics_filename}') + + if self.config["max_ckpt_to_keep"] > 0: + if len(self.metrics_queue) == self.metrics_queue.maxlen: + todel = self.metrics_queue.popleft() + os.remove(todel) + self.metrics_queue.append(metrics_filename) + + if epoch_idx == 'best': + self.copy_best_predictions(split) + + def copy_best_results(self, split, epoch_idx): + to_print = 'Copy ' + + if not self.config['skip_saving_ckpt']: + ckpt_path = osp.join( + self.config['log_dir'], f'epoch_{epoch_idx}.ckpt') + best_ckpt_path = ckpt_path.replace( + f'{epoch_idx}.ckpt', 'best.ckpt') + shutil.copyfile(ckpt_path, best_ckpt_path) + to_print += best_ckpt_path + ' ' + + metrics_filename = osp.join( + self.config['log_dir'], f'metrics_epoch_{epoch_idx}.json') + best_metric_filename = metrics_filename.replace( + f'{epoch_idx}.json', 'best.json') + shutil.copyfile(metrics_filename, best_metric_filename) + to_print += best_metric_filename + ' ' + + log.info(to_print) + + def copy_best_predictions(self, split): + to_print = 'Copy ' + + visdial_output_dir = osp.join(self.config['visdial_output_dir'], split) + if osp.exists(visdial_output_dir): + dir_best = visdial_output_dir.replace('output', 'output_best') + if osp.exists(dir_best): + shutil.rmtree(dir_best) + shutil.copytree(visdial_output_dir, dir_best) + to_print += dir_best + ' ' + + log.info(to_print) + + def get_ckpt(self): + ckpt = { + 
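+ # State needed to resume training: epoch counter, best metric, RNG seed and
+ # optimizer/scheduler state; the (module-unwrapped) model weights and, for
+ # apex runs, the amp loss-scaling state are added right after this dict.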
'epoch_idx': self.epoch_idx, + 'max_metric': self.max_metric, + 'seed': self.config['random_seed'], + 'optimizer': self.optimizer.state_dict(), + 'scheduler': self.scheduler.state_dict() + } + if self.config['parallel']: + ckpt['model_state_dict'] = self.model.module.state_dict() + else: + ckpt['model_state_dict'] = self.model.state_dict() + if self.config['dp_type'] == 'apex': + ckpt['amp'] = amp.state_dict() + return ckpt + + def set_ckpt(self, ckpt_dict): + if not self.config['restarts']: + self.epoch_idx = ckpt_dict.get('epoch_idx', -1) + 1 + + if not self.config['resets_max_metric']: + self.max_metric = ckpt_dict.get('max_metric', -1) + + if self.config['parallel']: + model = self.model.module + else: + model = self.model + + model_state_dict = model.state_dict() + former_dict = { + k: v for k, v in ckpt_dict['model_state_dict'].items() if k in model_state_dict} + + if self.config['display']: + log.info("number of keys transferred: %d" % len(former_dict)) + assert len(former_dict.keys()) > 0 + + model_state_dict.update(former_dict) + + model.load_state_dict(model_state_dict) + if self.config['display']: + log.info('loaded model') + del model_state_dict, former_dict + + if not self.config['validating'] and not (self.config['uses_new_optimizer'] or self.config['sets_new_lr']): + if 'optimizer' in ckpt_dict: + self.optimizer.load_state_dict(ckpt_dict['optimizer']) + if self.config['display']: + log.info('loaded optimizer') + if 'scheduler' in ckpt_dict: + self.scheduler.last_epcoh = ckpt_dict['epoch_idx'] * \ + self.config['num_iter_per_epoch'] + self.scheduler.load_state_dict(ckpt_dict['scheduler']) + + if 'amp' in ckpt_dict and self.config['dp_type'] == 'apex': + amp.load_state_dict(ckpt_dict['amp']) + + del ckpt_dict + + torch.cuda.empty_cache() + + def save_ckpt(self): + ckpt_path = f'{self.config["log_dir"]}/epoch_{self.epoch_idx}.ckpt' + log.info(f'saving checkpoint {ckpt_path}') + ckpt = self.get_ckpt() + if self.config['skip_saving_ckpt']: + return ckpt_path + torch_version = float(torch.__version__[:3]) + if torch_version - 1.4 > 1e-3: + torch.save(ckpt, f=ckpt_path, _use_new_zipfile_serialization=False) + else: + torch.save(ckpt, f=ckpt_path) + del ckpt + + if not (self.config['parallel'] and self.config['dp_type'] in ['ddp', 'apex']): + torch.cuda.empty_cache() + + if self.config["max_ckpt_to_keep"] > 0: + if len(self.checkpoint_queue) == self.checkpoint_queue.maxlen: + todel = self.checkpoint_queue.popleft() + os.remove(todel) + self.checkpoint_queue.append(ckpt_path) + + return ckpt_path + + def save_ckpt_best(self): + ckpt_path = f'{self.config["log_dir"]}/epoch_best.ckpt' + log.info(f'saving checkpoint {ckpt_path}') + ckpt = self.get_ckpt() + torch.save(ckpt, f=ckpt_path) + del ckpt + return ckpt_path + + def load_ckpt_best(self): + ckpt_path = f'{osp.dirname(self.config["log_dir"])}/epoch_best.ckpt' + if not osp.exists(ckpt_path): + ckpt_paths = [path for path in os.listdir( + f'{self.config["log_dir"]}/') if path.endswith('.ckpt') and 'best' not in path] + if len(ckpt_paths) == 0: + if self.config['display']: + log.info(f'No .ckpt found in {self.config["log_dir"]}') + return + + def sort_func(x): return int(re.search(r"(\d+)", x).groups()[0]) + ckpt_path = f'{self.config["log_dir"]}/{sorted(ckpt_paths, key=sort_func, reverse=True)[0]}' + if self.config['display']: + log.info(f'loading checkpoint {ckpt_path}') + map_location = {'cuda:0': f'cuda:{self.gpu_rank}'} + self.set_ckpt(torch.load(ckpt_path, map_location=map_location)) + + def load_ckpt(self, ckpt_path=None): + if 
not ckpt_path: + if self.config['validating'] or self.config['loads_best_ckpt']: + ckpt_path = f'{self.config["log_dir"]}/epoch_best.ckpt' + else: + ckpt_paths = [path for path in os.listdir( + f'{self.config["log_dir"]}/') if path.endswith('.ckpt') and 'best' not in path] + if len(ckpt_paths) == 0: + if self.config['display']: + log.info(f'No .ckpt found in {self.config["log_dir"]}') + return + + def sort_func(x): return int( + re.search(r"(\d+)", x).groups()[0]) + ckpt_path = f'{self.config["log_dir"]}/{sorted(ckpt_paths, key=sort_func, reverse=True)[0]}' + + if self.config['display']: + log.info(f'loading checkpoint {ckpt_path}') + epoch_name = osp.split(ckpt_path)[1].split('.')[0] + if re.search(r"(\d+)", epoch_name): + self.checkpoint_queue.append(ckpt_path) + metrics_filename = osp.join( + self.config['log_dir'], f'metrics_{epoch_name}.json') + if osp.exists(metrics_filename): + self.metrics_queue.append(metrics_filename) + + map_location = {'cuda:0': f'cuda:{self.gpu_rank}'} + self.set_ckpt(torch.load(ckpt_path, map_location=map_location)) + + def match_model_key(self, pretrained_dict, model_dict): + matched_dict = dict() + for key in pretrained_dict: + if key in model_dict: + matched_key = key + elif key.startswith('encoder.') and key[8:] in model_dict: + matched_key = key[8:] + elif key.startswith('module.') and key[7:] in model_dict: + matched_key = key[7:] + elif 'encoder.' + key in model_dict: + matched_key = 'encoder.' + key + elif 'module.' + key in model_dict: + matched_key = 'module.' + key + else: + # not_found.append(key) + continue + matched_dict[matched_key] = pretrained_dict[key] + + not_found = "" + for k in model_dict: + if k not in matched_dict: + not_found += k + '\n' + + log.info("Keys from model_dict that were not found in pretrained_dict:") + log.info(not_found) + return matched_dict + + def load_pretrained_vilbert(self, start_from=None): + if start_from is not None: + self.config["start_path"] = start_from + if self.config['training'] or self.config['debugging']: + ckpt_paths = [path for path in os.listdir( + f'{self.config["log_dir"]}/') if path.endswith('.ckpt') and 'best' not in path] + if len(ckpt_paths) > 0: + if self.config['display']: + log.info('Continue training') + return + + if self.config['display']: + log.info( + f'Loading pretrained VilBERT from {self.config["start_path"]}') + map_location = {'cuda:0': f'cuda:{self.gpu_rank}'} + pretrained_dict = torch.load( + self.config['start_path'], map_location=map_location) + if 'model_state_dict' in pretrained_dict: + pretrained_dict = pretrained_dict['model_state_dict'] + if self.config['parallel']: + model = self.model.module + else: + model = self.model + model_dict = model.state_dict() + + matched_dict = self.match_model_key(pretrained_dict, model_dict) + + if self.config['display']: + log.info("number of keys transferred: %d" % len(matched_dict)) + assert len(matched_dict.keys()) > 0 + model_dict.update(matched_dict) + model.load_state_dict(model_dict) + + del pretrained_dict, model_dict, matched_dict + if not self.config['parallel'] or self.config['dp_type'] == 'dp': + torch.cuda.empty_cache() + + if self.config['display']: + log.info(f'Pretrained VilBERT loaded') diff --git a/models/vdgr.py b/models/vdgr.py new file mode 100644 index 0000000..6aa7f32 --- /dev/null +++ b/models/vdgr.py @@ -0,0 +1,379 @@ +import sys +from collections import OrderedDict + +import torch +from torch import nn +import torch.nn.functional as F + +sys.path.append('../') +from utils.model_utils import listMLE, approxNDCGLoss, 
listNet, neuralNDCG, neuralNDCG_transposed + +from utils.data_utils import sequence_mask +from utils.optim_utils import init_optim +from models.runner import Runner + +from models.vilbert_dialog import BertForMultiModalPreTraining, BertConfig + + +class VDGR(nn.Module): + + def __init__(self, config_path, device, use_apex=False, cache_dir=None): + super(VDGR, self).__init__() + config = BertConfig.from_json_file(config_path) + + self.bert_pretrained = BertForMultiModalPreTraining.from_pretrained('bert-base-uncased', config, device, use_apex=use_apex, cache_dir=cache_dir) + self.bert_pretrained.train() + + def forward(self, input_ids, image_feat, image_loc, image_edge_indices, image_edge_attributes, + question_edge_indices, question_edge_attributes, question_limits, + history_edge_indices, history_sep_indices, + sep_indices=None, sep_len=None, token_type_ids=None, + attention_mask=None, masked_lm_labels=None, next_sentence_label=None, + image_attention_mask=None, image_label=None, image_target=None): + + masked_lm_loss = None + masked_img_loss = None + nsp_loss = None + seq_relationship_score = None + + if next_sentence_label is not None and masked_lm_labels \ + is not None and image_target is not None: + # train mode, output losses + masked_lm_loss, masked_img_loss, nsp_loss, _, _, seq_relationship_score, _ = \ + self.bert_pretrained( + input_ids, image_feat, image_loc, image_edge_indices, image_edge_attributes, + question_edge_indices, question_edge_attributes, question_limits, + history_edge_indices, history_sep_indices, sep_indices=sep_indices, sep_len=sep_len, \ + token_type_ids=token_type_ids, attention_mask=attention_mask, masked_lm_labels=masked_lm_labels, \ + next_sentence_label=next_sentence_label, image_attention_mask=image_attention_mask,\ + image_label=image_label, image_target=image_target) + else: + #inference, output scores + _, _, seq_relationship_score, _, _, _ = \ + self.bert_pretrained( + input_ids, image_feat, image_loc, image_edge_indices, image_edge_attributes, + question_edge_indices, question_edge_attributes, question_limits, + history_edge_indices, history_sep_indices, + sep_indices=sep_indices, sep_len=sep_len, \ + token_type_ids=token_type_ids, attention_mask=attention_mask, masked_lm_labels=masked_lm_labels, \ + next_sentence_label=next_sentence_label, image_attention_mask=image_attention_mask,\ + image_label=image_label, image_target=image_target) + + out = (masked_lm_loss, masked_img_loss, nsp_loss, seq_relationship_score) + + return out + + +class SparseRunner(Runner): + def __init__(self, config): + super(SparseRunner, self).__init__(config) + self.model = VDGR( + self.config['model_config'], self.config['device'], + use_apex=self.config['dp_type'] == 'apex', + cache_dir=self.config['bert_cache_dir']) + + self.model.to(self.config['device']) + + if not self.config['validating'] or self.config['dp_type'] == 'apex': + self.optimizer, self.scheduler = init_optim(self.model, self.config) + + def forward(self, batch, eval_visdial=False): + # load data + for key in batch: + if isinstance(batch[key], torch.Tensor): + batch[key] = batch[key].to(self.config['device']) + elif isinstance(batch[key], list): + if key != 'dialog_info': # Do not send the dialog_info item to the gpu + batch[key] = [x.to(self.config['device']) for x in batch[key]] + + tokens = batch['tokens'] + segments = batch['segments'] + sep_indices = batch['sep_indices'] + mask = batch['mask'] + hist_len = batch['hist_len'] + image_feat = batch['image_feat'] + image_loc = batch['image_loc'] + image_mask 
= batch['image_mask'] + next_sentence_labels = batch.get('next_sentence_labels', None) + image_target = batch.get('image_target', None) + image_label = batch.get('image_label', None) + # load the graph data + image_edge_indices = batch['image_edge_indices'] + image_edge_attributes = batch['image_edge_attributes'] + question_edge_indices = batch['question_edge_indices'] + question_edge_attributes = batch['question_edge_attributes'] + question_limits = batch['question_limits'] + history_edge_indices = batch['history_edge_indices'] + history_sep_indices = batch['history_sep_indices'] + + sequence_lengths = torch.gather(sep_indices, 1, hist_len.view(-1, 1)) + 1 + sequence_lengths = sequence_lengths.squeeze(1) + attention_mask_lm_nsp = sequence_mask(sequence_lengths, max_len=tokens.shape[1]) + sep_len = hist_len + 1 + + losses = OrderedDict() + + if eval_visdial: + num_lines = tokens.size(0) + line_batch_size = self.config['eval_line_batch_size'] + num_line_batches = num_lines // line_batch_size + if num_lines % line_batch_size > 0: + num_line_batches += 1 + nsp_scores = [] + for j in range(num_line_batches): + # create chunks of the original batch + chunk_range = range(j*line_batch_size, min((j+1)*line_batch_size, num_lines)) + tokens_chunk = tokens[chunk_range] + segments_chunk = segments[chunk_range] + sep_indices_chunk = sep_indices[chunk_range] + mask_chunk = mask[chunk_range] + sep_len_chunk = sep_len[chunk_range] + attention_mask_lm_nsp_chunk = attention_mask_lm_nsp[chunk_range] + image_feat_chunk = image_feat[chunk_range] + image_loc_chunk = image_loc[chunk_range] + image_mask_chunk = image_mask[chunk_range] + image_edge_indices_chunk = image_edge_indices[chunk_range[0]: chunk_range[-1]+1] + image_edge_attributes_chunk = image_edge_attributes[chunk_range[0]: chunk_range[-1]+1] + question_edge_indices_chunk = question_edge_indices[chunk_range[0]: chunk_range[-1]+1] + question_edge_attributes_chunk = question_edge_attributes[chunk_range[0]: chunk_range[-1]+1] + question_limits_chunk = question_limits[chunk_range[0]: chunk_range[-1]+1] + history_edge_indices_chunk = history_edge_indices[chunk_range[0]: chunk_range[-1]+1] + history_sep_indices_chunk = history_sep_indices[chunk_range[0]: chunk_range[-1]+1] + + _ , _ , _, nsp_scores_chunk = \ + self.model( + tokens_chunk, + image_feat_chunk, + image_loc_chunk, + image_edge_indices_chunk, + image_edge_attributes_chunk, + question_edge_indices_chunk, + question_edge_attributes_chunk, + question_limits_chunk, + history_edge_indices_chunk, + history_sep_indices_chunk, + sep_indices=sep_indices_chunk, + sep_len=sep_len_chunk, + token_type_ids=segments_chunk, + masked_lm_labels=mask_chunk, + attention_mask=attention_mask_lm_nsp_chunk, + image_attention_mask=image_mask_chunk + ) + nsp_scores.append(nsp_scores_chunk) + nsp_scores = torch.cat(nsp_scores, 0) + + else: + losses['lm_loss'], losses['img_loss'], losses['nsp_loss'], nsp_scores = \ + self.model( + tokens, + image_feat, + image_loc, + image_edge_indices, + image_edge_attributes, + question_edge_indices, + question_edge_attributes, + question_limits, + history_edge_indices, + history_sep_indices, + next_sentence_label=next_sentence_labels, + image_target=image_target, + image_label=image_label, + sep_indices=sep_indices, + sep_len=sep_len, + token_type_ids=segments, + masked_lm_labels=mask, + attention_mask=attention_mask_lm_nsp, + image_attention_mask=image_mask + ) + + losses['tot_loss'] = 0 + for key in ['lm_loss', 'img_loss', 'nsp_loss']: + if key in losses and losses[key] is not None: + 
losses[key] = losses[key].mean() + losses['tot_loss'] += self.config[f'{key}_coeff'] * losses[key] + + output = { + 'losses': losses, + 'nsp_scores': nsp_scores + } + return output + + +class DenseRunner(Runner): + def __init__(self, config): + super(DenseRunner, self).__init__(config) + self.model = VDGR( + self.config['model_config'], self.config['device'], + use_apex=self.config['dp_type'] == 'apex', + cache_dir=self.config['bert_cache_dir']) + + if not(self.config['parallel'] and self.config['dp_type'] == 'dp'): + self.model.to(self.config['device']) + + if self.config['dense_loss'] == 'ce': + self.dense_loss = nn.KLDivLoss(reduction='batchmean') + elif self.config['dense_loss'] == 'listmle': + self.dense_loss = listMLE + elif self.config['dense_loss'] == 'listnet': + self.dense_loss = listNet + elif self.config['dense_loss'] == 'approxndcg': + self.dense_loss = approxNDCGLoss + elif self.config['dense_loss'] == 'neural_ndcg': + self.dense_loss = neuralNDCG + elif self.config['dense_loss'] == 'neural_ndcg_transposed': + self.dense_loss = neuralNDCG_transposed + else: + raise ValueError('dense_loss must be one of ce, listmle, listnet, approxndcg, neural_ndcg, neural_ndcg_transposed') + + if not self.config['validating'] or self.config['dp_type'] == 'apex': + self.optimizer, self.scheduler = init_optim(self.model, self.config) + + def forward(self, batch, eval_visdial=False): + # load data + for key in batch: + if isinstance(batch[key], torch.Tensor): + batch[key] = batch[key].to(self.config['device']) + elif isinstance(batch[key], list): + if key != 'dialog_info': # Do not send the dialog_info item to the gpu + batch[key] = [x.to(self.config['device']) for x in batch[key]] + + # get embedding and forward visdial + tokens = batch['tokens'] + segments = batch['segments'] + sep_indices = batch['sep_indices'] + mask = batch['mask'] + hist_len = batch['hist_len'] + image_feat = batch['image_feat'] + image_loc = batch['image_loc'] + image_mask = batch['image_mask'] + next_sentence_labels = batch.get('next_sentence_labels', None) + image_target = batch.get('image_target', None) + image_label = batch.get('image_label', None) + + # load the graph data + image_edge_indices = batch['image_edge_indices'] + image_edge_attributes = batch['image_edge_attributes'] + question_edge_indices = batch['question_edge_indices'] + question_edge_attributes = batch['question_edge_attributes'] + question_limits = batch['question_limits'] + history_edge_indices = batch['history_edge_indices'] + assert history_edge_indices[0].size(0) == 2 + history_sep_indices = batch['history_sep_indices'] + + sequence_lengths = torch.gather(sep_indices, 1, hist_len.view(-1, 1)) + 1 + sequence_lengths = sequence_lengths.squeeze(1) + attention_mask_lm_nsp = sequence_mask(sequence_lengths, max_len=tokens.shape[1]) + sep_len = hist_len + 1 + + losses = OrderedDict() + + if eval_visdial: + num_lines = tokens.size(0) + line_batch_size = self.config['eval_line_batch_size'] + num_line_batches = num_lines // line_batch_size + if num_lines % line_batch_size > 0: + num_line_batches += 1 + nsp_scores = [] + for j in range(num_line_batches): + # create chunks of the original batch + chunk_range = range(j*line_batch_size, min((j+1)*line_batch_size, num_lines)) + tokens_chunk = tokens[chunk_range] + segments_chunk = segments[chunk_range] + sep_indices_chunk = sep_indices[chunk_range] + mask_chunk = mask[chunk_range] + sep_len_chunk = sep_len[chunk_range] + attention_mask_lm_nsp_chunk = attention_mask_lm_nsp[chunk_range] + image_feat_chunk = 
image_feat[chunk_range] + image_loc_chunk = image_loc[chunk_range] + image_mask_chunk = image_mask[chunk_range] + image_edge_indices_chunk = image_edge_indices[chunk_range[0]: chunk_range[-1]+1] + image_edge_attributes_chunk = image_edge_attributes[chunk_range[0]: chunk_range[-1]+1] + question_edge_indices_chunk = question_edge_indices[chunk_range[0]: chunk_range[-1]+1] + question_edge_attributes_chunk = question_edge_attributes[chunk_range[0]: chunk_range[-1]+1] + question_limits_chunk = question_limits[chunk_range[0]: chunk_range[-1]+1] + history_edge_indices_chunk = history_edge_indices[chunk_range[0]: chunk_range[-1]+1] + history_sep_indices_chunk = history_sep_indices[chunk_range[0]: chunk_range[-1]+1] + + _, _, _, nsp_scores_chunk = \ + self.model( + tokens_chunk, + image_feat_chunk, + image_loc_chunk, + image_edge_indices_chunk, + image_edge_attributes_chunk, + question_edge_indices_chunk, + question_edge_attributes_chunk, + question_limits_chunk, + history_edge_indices_chunk, + history_sep_indices_chunk, + sep_indices=sep_indices_chunk, + sep_len=sep_len_chunk, + token_type_ids=segments_chunk, + masked_lm_labels=mask_chunk, + attention_mask=attention_mask_lm_nsp_chunk, + image_attention_mask=image_mask_chunk + ) + nsp_scores.append(nsp_scores_chunk) + nsp_scores = torch.cat(nsp_scores, 0) + + else: + _, _, _, nsp_scores = \ + self.model( + tokens, + image_feat, + image_loc, + image_edge_indices, + image_edge_attributes, + question_edge_indices, + question_edge_attributes, + question_limits, + history_edge_indices, + history_sep_indices, + next_sentence_label=next_sentence_labels, + image_target=image_target, + image_label=image_label, + sep_indices=sep_indices, + sep_len=sep_len, + token_type_ids=segments, + masked_lm_labels=mask, + attention_mask=attention_mask_lm_nsp, + image_attention_mask=image_mask + ) + + + if nsp_scores is not None: + nsp_scores_output = nsp_scores.detach().clone() + if not eval_visdial: + nsp_scores = nsp_scores.view(-1, self.config['num_options_dense'], 2) + if 'next_sentence_labels' in batch and self.config['nsp_loss_coeff'] > 0: + next_sentence_labels = batch['next_sentence_labels'].to(self.config['device']) + losses['nsp_loss'] = F.cross_entropy(nsp_scores.view(-1,2), next_sentence_labels.view(-1)) + else: + losses['nsp_loss'] = None + + if not eval_visdial: + gt_relevance = batch['gt_relevance'].to(self.config['device']) + nsp_scores = nsp_scores[:, :, 0] + if self.config['dense_loss'] == 'ce': + losses['dense_loss'] = self.dense_loss(F.log_softmax(nsp_scores, dim=1), F.softmax(gt_relevance, dim=1)) + else: + losses['dense_loss'] = self.dense_loss(nsp_scores, gt_relevance) + else: + losses['dense_loss'] = None + else: + nsp_scores_output = None + losses['nsp_loss'] = None + losses['dense_loss'] = None + + losses['tot_loss'] = 0 + for key in ['nsp_loss', 'dense_loss']: + if key in losses and losses[key] is not None: + losses[key] = losses[key].mean() + losses['tot_loss'] += self.config[f'{key}_coeff'] * losses[key] + + output = { + 'losses': losses, + 'nsp_scores': nsp_scores_output + } + + return output diff --git a/models/vilbert_dialog.py b/models/vilbert_dialog.py new file mode 100644 index 0000000..36b3837 --- /dev/null +++ b/models/vilbert_dialog.py @@ -0,0 +1,2021 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BERT model.""" + +import copy +import json +import logging +import math +import os +import shutil +import tarfile +import tempfile +import sys +from io import open + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss +import torch.nn.functional as F +from torch.nn.utils.weight_norm import weight_norm +from pytorch_transformers.modeling_bert import BertEmbeddings +from utils.data_utils import sequence_mask, to_data_list +import torch_geometric.nn as pyg_nn +from torch_geometric.data import Data +from torch_geometric.loader import DataLoader +from pytorch_pretrained_bert.file_utils import cached_path +import pdb + +logger = logging.getLogger(__name__) + +PRETRAINED_MODEL_ARCHIVE_MAP = { + "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz", + "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz", + "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz", + "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz", + "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz", + "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz", + "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", +} + +def load_tf_weights_in_bert(model, tf_checkpoint_path): + """ Load tf checkpoints in a pytorch model + """ + try: + import re + import numpy as np + import tensorflow as tf + except ImportError: + print( + "Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." 
+ ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + print("Converting TensorFlow checkpoint from {}".format(tf_path)) + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + print("Loading TF weight {} with shape {}".format(name, shape)) + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any(n in ["adam_v", "adam_m"] for n in name): + print("Skipping {}".format("/".join(name))) + continue + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + l = re.split(r"_(\d+)", m_name) + else: + l = [m_name] + if l[0] == "kernel" or l[0] == "gamma": + pointer = getattr(pointer, "weight") + elif l[0] == "output_bias" or l[0] == "beta": + pointer = getattr(pointer, "bias") + elif l[0] == "output_weights": + pointer = getattr(pointer, "weight") + else: + pointer = getattr(pointer, l[0]) + if len(l) >= 2: + num = int(l[1]) + pointer = pointer[num] + if m_name[-11:] == "_embeddings": + pointer = getattr(pointer, "weight") + elif m_name == "kernel": + array = np.transpose(array) + try: + assert pointer.shape == array.shape + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + print("Initialize PyTorch weight {}".format(name)) + pointer.data = torch.from_numpy(array) + return model + +class GeLU(nn.Module): + """Implementation of the gelu activation function. + For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): + 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + Also see https://arxiv.org/abs/1606.08415 + """ + + def __init__(self): + super(GeLU, self).__init__() + + def forward(self, x): + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + + +def gelu(x): + """Implementation of the gelu activation function. + For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): + 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + Also see https://arxiv.org/abs/1606.08415 + """ + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + + +def swish(x): + return x * torch.sigmoid(x) + + +ACT2FN = {"GeLU": GeLU(), "gelu": gelu, + "relu": torch.nn.functional.relu, "swish": swish} + +class BertConfig(object): + """Configuration class to store the configuration of a `BertModel`. 
+ """ + + def __init__( + self, + vocab_size_or_config_json_file, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + v_feature_size=2048, + v_target_size=1601, + v_hidden_size=768, + v_num_hidden_layers=3, + v_num_attention_heads=12, + v_intermediate_size=3072, + bi_hidden_size=1024, + bi_num_attention_heads=16, + v_attention_probs_dropout_prob=0.1, + v_hidden_act="gelu", + v_hidden_dropout_prob=0.1, + v_initializer_range=0.2, + v_biattention_id=[0, 1], + t_biattention_id=[10, 11], + predict_feature=False, + fast_mode=False, + fixed_v_layer=0, + fixed_t_layer=0, + in_batch_pairs=False, + fusion_method="mul", + intra_gate=False, + with_coattention=True + ): + + """Constructs BertConfig. + + Args: + vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`. + hidden_size: Size of the encoder layers and the pooler layer. + num_hidden_layers: Number of hidden layers in the Transformer encoder. + num_attention_heads: Number of attention heads for each attention layer in + the Transformer encoder. + intermediate_size: The size of the "intermediate" (i.e., feed-forward) + layer in the Transformer encoder. + hidden_act: The non-linear activation function (function or string) in the + encoder and pooler. If string, "gelu", "relu" and "swish" are supported. + hidden_dropout_prob: The dropout probabilitiy for all fully connected + layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob: The dropout ratio for the attention + probabilities. + max_position_embeddings: The maximum sequence length that this model might + ever be used with. Typically set this to something large just in case + (e.g., 512 or 1024 or 2048). + type_vocab_size: The vocabulary size of the `token_type_ids` passed into + `BertModel`. + initializer_range: The sttdev of the truncated_normal_initializer for + initializing all weight matrices. 
+ """ + assert len(v_biattention_id) == len(t_biattention_id) + assert max(v_biattention_id) < v_num_hidden_layers + assert max(t_biattention_id) < num_hidden_layers + + if isinstance(vocab_size_or_config_json_file, str) or ( + sys.version_info[0] == 2 + and isinstance(vocab_size_or_config_json_file, unicode) + ): + with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader: + json_config = json.loads(reader.read()) + for key, value in json_config.items(): + self.__dict__[key] = value + elif isinstance(vocab_size_or_config_json_file, int): + self.vocab_size = vocab_size_or_config_json_file + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.v_feature_size = v_feature_size + self.v_hidden_size = v_hidden_size + self.v_num_hidden_layers = v_num_hidden_layers + self.v_num_attention_heads = v_num_attention_heads + self.v_intermediate_size = v_intermediate_size + self.v_attention_probs_dropout_prob = v_attention_probs_dropout_prob + self.v_hidden_act = v_hidden_act + self.v_hidden_dropout_prob = v_hidden_dropout_prob + self.v_initializer_range = v_initializer_range + self.v_biattention_id = v_biattention_id + self.t_biattention_id = t_biattention_id + self.v_target_size = v_target_size + self.bi_hidden_size = bi_hidden_size + self.bi_num_attention_heads = bi_num_attention_heads + self.predict_feature = predict_feature + self.fast_mode = fast_mode + self.fixed_v_layer = fixed_v_layer + self.fixed_t_layer = fixed_t_layer + + self.in_batch_pairs = in_batch_pairs + self.fusion_method = fusion_method + self.intra_gate = intra_gate + self.with_coattention=with_coattention + else: + raise ValueError( + "First argument must be either a vocabulary size (int)" + "or the path to a pretrained model config file (str)" + ) + + @classmethod + def from_dict(cls, json_object): + """Constructs a `BertConfig` from a Python dictionary of parameters.""" + config = BertConfig(vocab_size_or_config_json_file=-1) + for key, value in json_object.items(): + config.__dict__[key] = value + return config + + @classmethod + def from_json_file(cls, json_file): + """Constructs a `BertConfig` from a json file of parameters.""" + with open(json_file, "r", encoding="utf-8") as reader: + text = reader.read() + return cls.from_dict(json.loads(text)) + + def __repr__(self): + return str(self.to_json_string()) + + def to_dict(self): + """Serializes this instance to a Python dictionary.""" + output = copy.deepcopy(self.__dict__) + return output + + def to_json_string(self): + """Serializes this instance to a JSON string.""" + return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" + +try: + # from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm + import torch.nn.LayerNorm as BertLayerNorm +except ImportError: + # logger.info( + # "Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex ." + # ) + pass + + class BertLayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-12): + """Construct a layernorm module in the TF style (epsilon inside the square root). 
+ """ + super(BertLayerNorm, self).__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.bias = nn.Parameter(torch.zeros(hidden_size)) + self.variance_epsilon = eps + + def forward(self, x): + u = x.mean(-1, keepdim=True) + s = (x - u).pow(2).mean(-1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.variance_epsilon) + return self.weight * x + self.bias + +class BertEmbeddingsDialog(nn.Module): + def __init__(self, config, device): + super(BertEmbeddingsDialog, self).__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + max_seq_len = 256 + d_model = config.hidden_size + pe = torch.zeros(max_seq_len, d_model) + for pos in range(max_seq_len): + for i in range(0, d_model, 2): + pe[pos, i] = \ + math.sin(pos / (10000 ** ((2 * i)/d_model))) + pe[pos, i + 1] = \ + math.cos(pos / (10000 ** ((2 * (i + 1))/d_model))) + self.pe = pe.to(device) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + # add support for additional segment embeddings. Supporting 10 additional embedding as of now + self.token_type_embeddings_extension = nn.Embedding(10,config.hidden_size) + # adding specialized embeddings for sep tokens + self.sep_embeddings = nn.Embedding(50,config.hidden_size) + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.config = config + + def forward(self, input_ids, sep_indices=None, sep_len=None, token_type_ids=None): + seq_length = input_ids.size(1) + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + + token_type_ids_extension = token_type_ids - self.config.type_vocab_size + token_type_ids_extension_mask = (token_type_ids_extension >= 0).float() + token_type_ids_extension = (token_type_ids_extension.float() * token_type_ids_extension_mask).long() + + token_type_ids_mask = (token_type_ids < self.config.type_vocab_size).float() + assert torch.sum(token_type_ids_extension_mask + token_type_ids_mask) == \ + torch.numel(token_type_ids) == torch.numel(token_type_ids_mask) + token_type_ids = (token_type_ids.float() * token_type_ids_mask).long() + + token_type_embeddings = self.token_type_embeddings(token_type_ids) + token_type_embeddings_extension = self.token_type_embeddings_extension(token_type_ids_extension) + + token_type_embeddings = (token_type_embeddings * token_type_ids_mask.unsqueeze(-1)) + \ + (token_type_embeddings_extension * token_type_ids_extension_mask.unsqueeze(-1)) + + embeddings = words_embeddings + position_embeddings + token_type_embeddings + + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + +class BertSelfAttention(nn.Module): + def __init__(self, config): + super(BertSelfAttention, self).__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + self.num_attention_heads = 
config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, attention_mask): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + return context_layer, attention_probs + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super(BertSelfOutput, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + +class BertAttention(nn.Module): + def __init__(self, config): + super(BertAttention, self).__init__() + self.self = BertSelfAttention(config) + self.output = BertSelfOutput(config) + + def forward(self, input_tensor, attention_mask): + self_output, attention_probs = self.self(input_tensor, attention_mask) + attention_output = self.output(self_output, input_tensor) + return attention_output, attention_probs + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super(BertIntermediate, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str) or ( + sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode) + ): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class 
BertOutput(nn.Module): + def __init__(self, config): + super(BertOutput, self).__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config): + super(BertLayer, self).__init__() + self.attention = BertAttention(config) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward(self, hidden_states, attention_mask): + attention_output, attention_probs = self.attention(hidden_states, attention_mask) + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output, attention_probs + + +class TextGraphLayer(nn.Module): + def __init__(self, config): + super(TextGraphLayer, self).__init__() + self.config = config + self.gnn_act = ACT2FN[config.gnn_act] + + self.num_q_gnn_layers = config.num_q_gnn_layers + self.num_h_gnn_layers = config.num_h_gnn_layers + + self.q_gnn_layers = [] + self.q_gnn_norm_layers = [] + + for _ in range(self.num_q_gnn_layers): + # Graph layers + self.q_gnn_layers.append( + pyg_nn.GATv2Conv( + config.hidden_size, config.hidden_size//config.num_gnn_attention_heads, + config.num_gnn_attention_heads, + dropout=config.gnn_dropout_prob, + edge_dim=config.q_gnn_edge_dim, + concat=True + ) + ) + # After each graph layer, a normalization layer is added + self.q_gnn_norm_layers.append(pyg_nn.PairNorm()) + + self.q_gnn_layers = nn.ModuleList(self.q_gnn_layers) + self.q_gnn_norm_layers = nn.ModuleList(self.q_gnn_norm_layers) + + self.h_gnn_layers = [] + self.h_gnn_norm_layers = [] + + for _ in range(self.num_h_gnn_layers): + self.h_gnn_layers.append( + pyg_nn.GATv2Conv( + config.hidden_size, config.hidden_size//config.num_gnn_attention_heads, + config.num_gnn_attention_heads, + dropout=config.gnn_dropout_prob, + concat=True + ) + ) + # After each graph layer, a normalization layer is added + self.h_gnn_norm_layers.append(pyg_nn.PairNorm()) + + self.h_gnn_layers = nn.ModuleList(self.h_gnn_layers) + self.h_gnn_norm_layers = nn.ModuleList(self.h_gnn_norm_layers) + + self.h_gnn_dense_hub = nn.Linear(config.v_hidden_size, config.hidden_size) + self.h_gnn_layer_norm_hub = BertLayerNorm(config.hidden_size, eps=1e-12) + self.h_gnn_dropout_hub = nn.Dropout(config.gnn_dropout_prob) + + q_dense_pooling = nn.Sequential( + nn.Linear(config.hidden_size, 1), + ACT2FN['GeLU'], + nn.Dropout(config.gnn_dropout_prob) + ) + self.q_gnn_pooling = pyg_nn.GlobalAttention(q_dense_pooling) + h_dense_pooling = nn.Sequential( + nn.Linear(config.hidden_size, 1), + ACT2FN['GeLU'], + nn.Dropout(config.gnn_dropout_prob) + ) + self.h_gnn_pooling = pyg_nn.GlobalAttention(h_dense_pooling) + + + def forward( + self, hidden_states, q_edge_indices, q_edge_attributes, + q_limits, h_edge_indices, h_sep_indices, v_hub, + len_q_gr=None, len_h_gr=None, len_h_sep=None): + device = hidden_states.device + batch_size, _, hidden_size = hidden_states.size() + if isinstance(q_edge_indices, list): + assert len(q_edge_indices) == len(q_edge_attributes) == q_limits.size(0) \ + == len(h_edge_indices) == len(h_sep_indices) == batch_size + else: + assert q_edge_indices.size(0) == 
q_edge_attributes.size(0) == q_limits.size(0) \ + == h_edge_indices.size(0) == h_sep_indices.size(0) == batch_size + if len_q_gr is not None: + q_edge_indices = [t.squeeze(0)[:, :l].long() for t, l in zip(torch.split(q_edge_indices, 1, dim=0), len_q_gr)] + q_edge_attributes = [t.squeeze(0)[:l, :] for t, l in zip(torch.split(q_edge_attributes, 1, dim=0), len_q_gr)] + h_edge_indices = [t.squeeze(0)[:, :l].long() for t, l in zip(torch.split(h_edge_indices, 1, dim=0), len_h_gr)] + h_sep_indices = [t.squeeze(0)[:l].long() for t, l in zip(torch.split(h_sep_indices, 1, dim=0), len_h_sep)] + else: + q_edge_indices = [t.squeeze(0) for t in torch.split(q_edge_indices, 1, dim=0)] + q_edge_attributes = [t.squeeze(0) for t in torch.split(q_edge_attributes, 1, dim=0)] + h_edge_indices = [t.squeeze(0) for t in torch.split(h_edge_indices, 1, dim=0)] + h_sep_indices = [t.squeeze(0).long() for t in torch.split(h_sep_indices, 1, dim=0)] + + gnn_hidden_states = hidden_states.clone().detach() + # Extract the history and question node features (without the hub node) + h_node_feats = [] + q_node_feats = [] + q_limits = q_limits.tolist() + q_tok_indices_extended = [] + h_sep_indices_extended = [] + for i, (h_sep_idx, q_limit) in enumerate(zip(h_sep_indices, q_limits)): + batch_data = gnn_hidden_states[i, :, :].clone().detach() + h_sep_idx = h_sep_idx.unsqueeze(-1).repeat(1, hidden_size) + h_sep_indices_extended.append(h_sep_idx) + h_node_feats.append(torch.gather(batch_data, 0, h_sep_idx)) + q_tok_idx = torch.arange(q_limit[0], q_limit[1]).unsqueeze(-1).repeat(1, hidden_size).to(device) + q_tok_indices_extended.append(q_tok_idx) + q_node_feats.append(torch.gather(batch_data, 0, q_tok_idx)) + + # if self.use_hub_nodes: + # Map v_hub to the correct vector space + v_hub = self.h_gnn_dense_hub(v_hub) + v_hub = self.h_gnn_layer_norm_hub(v_hub) + v_hub = self.h_gnn_dropout_hub(v_hub) + # Add the hub node to the history nodes + v_hub = torch.split(v_hub, 1, dim=0) + h_node_feats = [torch.cat((h, x), dim=0) for h, x in zip(h_node_feats, v_hub)] + + # Create the history graph data and pass them through the GNNs + pg_hist_data = [Data(x=x, edge_index=idx) for x, idx in zip(h_node_feats, h_edge_indices)] + pg_hist_loader = DataLoader(pg_hist_data, batch_size=batch_size, shuffle=False) + for data in pg_hist_loader: + x_h, edge_index_h, h_gnn_batch_idx = data.x, data.edge_index, data.batch + for i in range(self.num_h_gnn_layers): + # Normalization + x_h = self.h_gnn_norm_layers[i](x_h, h_gnn_batch_idx) + # Graph propagation + x_h = self.h_gnn_layers[i](x_h, edge_index_h, edge_attr=None) + # Activation + x_h = self.gnn_act(x_h) + x_h + x_h = self.gnn_act(x_h) + + + h_hub = self.h_gnn_pooling(x_h, h_gnn_batch_idx) + + # Add the hub nodes + h_hub_split = torch.split(h_hub, 1, dim=0) + q_node_feats = [torch.cat((q, x), dim=0) for q, x in zip(q_node_feats, h_hub_split)] + + + # Create the question graph data and pass them through the GNNs + pg_ques_data = [Data(x=x, edge_index=idx, edge_attr=attr) for x, idx, attr in zip(q_node_feats, q_edge_indices, q_edge_attributes)] + pg_ques_loader = DataLoader(pg_ques_data, batch_size=batch_size, shuffle=False) + for data in pg_ques_loader: + x_q, edge_index_q, edge_attr_q, q_gnn_batch_idx = data.x, data.edge_index, data.edge_attr, data.batch + for i in range(self.num_q_gnn_layers): + # Normalization + x_q = self.q_gnn_norm_layers[i](x_q, q_gnn_batch_idx) + # GNN propagation + x_q = self.q_gnn_layers[i](x_q, edge_index_q, edge_attr=edge_attr_q) + # Activation + x_q = self.gnn_act(x_q) + x_q 
+ x_q = self.gnn_act(x_q) + + + q_hub = self.q_gnn_pooling(x_q, q_gnn_batch_idx) + # Reshape the node features + h_node_feats = to_data_list(x_h, h_gnn_batch_idx) + q_node_feats = to_data_list(x_q, q_gnn_batch_idx) + + # Update the text tokens with the graph feats + zipped_data = zip(h_node_feats, h_sep_indices_extended, q_node_feats, q_tok_indices_extended) + for i, (h_node_feat, h_sep_idx, q_node_feat, q_tok_idx) in enumerate(zipped_data): + gnn_hidden_states[i].scatter(0, h_sep_idx, h_node_feat[:-1]) + gnn_hidden_states[i].scatter(0, q_tok_idx, q_node_feat[:-1]) + + final_hidden_states = 0.5 * (hidden_states + gnn_hidden_states) + return final_hidden_states, h_hub, q_hub + + +class BertImageSelfAttention(nn.Module): + def __init__(self, config): + super(BertImageSelfAttention, self).__init__() + if config.v_hidden_size % config.v_num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.v_hidden_size, config.v_num_attention_heads) + ) + self.num_attention_heads = config.v_num_attention_heads + self.attention_head_size = int( + config.v_hidden_size / config.v_num_attention_heads + ) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.v_hidden_size, self.all_head_size) + self.key = nn.Linear(config.v_hidden_size, self.all_head_size) + self.value = nn.Linear(config.v_hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.v_attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, attention_mask): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attention_probs = self.dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + return context_layer, attention_probs + +class BertImageSelfOutput(nn.Module): + def __init__(self, config): + super(BertImageSelfOutput, self).__init__() + self.dense = nn.Linear(config.v_hidden_size, config.v_hidden_size) + self.LayerNorm = BertLayerNorm(config.v_hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.v_hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + +class BertImageAttention(nn.Module): + def __init__(self, config): + super(BertImageAttention, self).__init__() + self.self = BertImageSelfAttention(config) + self.output = BertImageSelfOutput(config) + + def forward(self, input_tensor, attention_mask): + self_output, attention_probs = self.self(input_tensor, attention_mask) + attention_output = self.output(self_output, input_tensor) + return attention_output, attention_probs + + +class BertImageIntermediate(nn.Module): + def __init__(self, config): + super(BertImageIntermediate, self).__init__() + self.dense = nn.Linear(config.v_hidden_size, config.v_intermediate_size) + if isinstance(config.v_hidden_act, str) or ( + sys.version_info[0] == 2 and isinstance(config.v_hidden_act, unicode) + ): + self.intermediate_act_fn = ACT2FN[config.v_hidden_act] + else: + self.intermediate_act_fn = config.v_hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertImageOutput(nn.Module): + def __init__(self, config): + super(BertImageOutput, self).__init__() + self.dense = nn.Linear(config.v_intermediate_size, config.v_hidden_size) + self.LayerNorm = BertLayerNorm(config.v_hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.v_hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertImageLayer(nn.Module): + def __init__(self, config): + super(BertImageLayer, self).__init__() + self.attention = BertImageAttention(config) + self.intermediate = BertImageIntermediate(config) + self.output = BertImageOutput(config) + + def forward(self, hidden_states, attention_mask): + attention_output, attention_probs = self.attention(hidden_states, attention_mask) + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output, attention_probs + +class ImageGraphLayer(nn.Module): + def __init__(self, config): + super(ImageGraphLayer, self).__init__() + self.config = config + self.gnn_act = ACT2FN[config.gnn_act] + + self.num_gnn_layers = config.num_v_gnn_layers + self.gnn_layers = [] + self.gnn_norm_layers = [] + + for _ in range(self.num_gnn_layers): + self.gnn_layers.append( + pyg_nn.GATv2Conv( + config.v_hidden_size, config.v_hidden_size//config.num_gnn_attention_heads, + config.num_gnn_attention_heads, + dropout=config.gnn_dropout_prob, + 
edge_dim=config.v_gnn_edge_dim, + concat=True + ) + ) + # After each graph layer, a normalization layer is added + self.gnn_norm_layers.append(pyg_nn.PairNorm()) + + self.gnn_layers = nn.ModuleList(self.gnn_layers) + self.gnn_norm_layers = nn.ModuleList(self.gnn_norm_layers) + + self.gnn_dense_hub = nn.Linear(config.hidden_size, config.v_hidden_size) + self.gnn_layer_norm_hub = BertLayerNorm(config.v_hidden_size, eps=1e-12) + self.gnn_dropout_hub = nn.Dropout(config.gnn_dropout_prob) + + dense_pooling = nn.Sequential( + nn.Linear(config.v_hidden_size, 1), + ACT2FN['GeLU'], + nn.Dropout(config.gnn_dropout_prob) + ) + self.gnn_pooling = pyg_nn.GlobalAttention(dense_pooling) + + def forward( + self, hidden_states, edge_indices, edge_attributes, hub_states, + len_img_gr=None): + # assert hub_states is not None + gnn_hidden_states = hidden_states.clone().detach() + batch_size, num_img_reg, v_hidden_size = hidden_states.size() + node_feats = hidden_states.clone().detach() + # Remave the [IMG] feats + node_feats = node_feats[:, 1:] + node_feats = torch.split(node_feats, 1, dim=0) + + if len_img_gr is not None: + edge_indices = [t.squeeze(0)[:, :l].long() for t, l in zip(torch.split(edge_indices, 1, dim=0), len_img_gr)] + edge_attributes = [t.squeeze(0)[:l, :] for t, l in zip(torch.split(edge_attributes, 1, dim=0), len_img_gr)] + + # Concat the hub states + hub_states = self.gnn_dense_hub(hub_states) + hub_states = self.gnn_dropout_hub(hub_states) + hub_states = self.gnn_layer_norm_hub(hub_states) + + hub_states = torch.split(hub_states, 1, dim=0) + node_feats = [torch.cat((x.squeeze(0), h), dim=0) + for x, h in zip(node_feats, hub_states)] + + pg_data = [Data(x, idx, attr) for x, idx, attr in zip( + node_feats, edge_indices, edge_attributes)] + pg_dataloader = DataLoader( + pg_data, batch_size=batch_size, shuffle=False) + # Gnn forward pass + for data in pg_dataloader: + x, edge_index, edge_attr, gnn_batch_idx = data.x, data.edge_index, data.edge_attr, data.batch + for i in range(self.num_gnn_layers): + # Normalization + x = self.gnn_norm_layers[i](x, gnn_batch_idx) + # GNN propagation + x = self.gnn_layers[i](x, edge_index, edge_attr=edge_attr) + # Activation + x = self.gnn_act(x) + x + x = self.gnn_act(x) + + # Reshape the output of the GNN to batch_size x num_img_reg x hidden_dim + v_hub = self.gnn_pooling(x, gnn_batch_idx) + + x = x.view(batch_size, num_img_reg, v_hidden_size) + gnn_hidden_states[:, 1:, :] = x[:, :-1, :] + + final_hidden_states = 0.5 * (hidden_states + gnn_hidden_states) + + return final_hidden_states, v_hub + + +class BertBiAttention(nn.Module): + def __init__(self, config): + super(BertBiAttention, self).__init__() + if config.bi_hidden_size % config.bi_num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.bi_hidden_size, config.bi_num_attention_heads) + ) + + self.num_attention_heads = config.bi_num_attention_heads + self.attention_head_size = int( + config.bi_hidden_size / config.bi_num_attention_heads + ) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + # self.scale = nn.Linear(1, self.num_attention_heads, bias=False) + # self.scale_act_fn = ACT2FN['relu'] + + self.query1 = nn.Linear(config.v_hidden_size, self.all_head_size) + self.key1 = nn.Linear(config.v_hidden_size, self.all_head_size) + self.value1 = nn.Linear(config.v_hidden_size, self.all_head_size) + # self.logit1 = nn.Linear(config.hidden_size, self.num_attention_heads) + + self.dropout1 = 
nn.Dropout(config.v_attention_probs_dropout_prob) + + self.query2 = nn.Linear(config.hidden_size, self.all_head_size) + self.key2 = nn.Linear(config.hidden_size, self.all_head_size) + self.value2 = nn.Linear(config.hidden_size, self.all_head_size) + # self.logit2 = nn.Linear(config.hidden_size, self.num_attention_heads) + + self.dropout2 = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, input_tensor1, attention_mask1, input_tensor2, attention_mask2, co_attention_mask=None, use_co_attention_mask=False): + + # for vision input. + mixed_query_layer1 = self.query1(input_tensor1) + mixed_key_layer1 = self.key1(input_tensor1) + mixed_value_layer1 = self.value1(input_tensor1) + # mixed_logit_layer1 = self.logit1(input_tensor1) + + query_layer1 = self.transpose_for_scores(mixed_query_layer1) + key_layer1 = self.transpose_for_scores(mixed_key_layer1) + value_layer1 = self.transpose_for_scores(mixed_value_layer1) + # logit_layer1 = self.transpose_for_logits(mixed_logit_layer1) + + # for text input: + mixed_query_layer2 = self.query2(input_tensor2) + mixed_key_layer2 = self.key2(input_tensor2) + mixed_value_layer2 = self.value2(input_tensor2) + # mixed_logit_layer2 = self.logit2(input_tensor2) + + query_layer2 = self.transpose_for_scores(mixed_query_layer2) + key_layer2 = self.transpose_for_scores(mixed_key_layer2) + value_layer2 = self.transpose_for_scores(mixed_value_layer2) + # logit_layer2 = self.transpose_for_logits(mixed_logit_layer2) + + # Take the dot product between "query2" and "key1" to get the raw attention scores for value 1. + attention_scores1 = torch.matmul(query_layer2, key_layer1.transpose(-1, -2)) + attention_scores1 = attention_scores1 / math.sqrt(self.attention_head_size) + attention_scores1 = attention_scores1 + attention_mask1 + + if use_co_attention_mask: + attention_scores1 = attention_scores1 + co_attention_mask.permute(0,1,3,2) + + # Normalize the attention scores to probabilities. + attention_probs1 = nn.Softmax(dim=-1)(attention_scores1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs1 = self.dropout1(attention_probs1) + + context_layer1 = torch.matmul(attention_probs1, value_layer1) + context_layer1 = context_layer1.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape1 = context_layer1.size()[:-2] + (self.all_head_size,) + context_layer1 = context_layer1.view(*new_context_layer_shape1) + + # Take the dot product between "query1" and "key2" to get the raw attention scores for value 2. + attention_scores2 = torch.matmul(query_layer1, key_layer2.transpose(-1, -2)) + attention_scores2 = attention_scores2 / math.sqrt(self.attention_head_size) + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + + # we can comment this line for single flow. + attention_scores2 = attention_scores2 + attention_mask2 + if use_co_attention_mask: + attention_scores2 = attention_scores2 + co_attention_mask + + # Normalize the attention scores to probabilities. + attention_probs2 = nn.Softmax(dim=-1)(attention_scores2) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attention_probs2 = self.dropout2(attention_probs2) + + context_layer2 = torch.matmul(attention_probs2, value_layer2) + context_layer2 = context_layer2.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape2 = context_layer2.size()[:-2] + (self.all_head_size,) + context_layer2 = context_layer2.view(*new_context_layer_shape2) + + return context_layer1, context_layer2, (attention_probs1, attention_probs2) + +class BertBiOutput(nn.Module): + def __init__(self, config): + super(BertBiOutput, self).__init__() + + self.dense1 = nn.Linear(config.bi_hidden_size, config.v_hidden_size) + self.LayerNorm1 = BertLayerNorm(config.v_hidden_size, eps=1e-12) + self.dropout1 = nn.Dropout(config.v_hidden_dropout_prob) + + self.q_dense1 = nn.Linear(config.bi_hidden_size, config.v_hidden_size) + self.q_dropout1 = nn.Dropout(config.v_hidden_dropout_prob) + + self.dense2 = nn.Linear(config.bi_hidden_size, config.hidden_size) + self.LayerNorm2 = BertLayerNorm(config.hidden_size, eps=1e-12) + self.dropout2 = nn.Dropout(config.hidden_dropout_prob) + + self.q_dense2 = nn.Linear(config.bi_hidden_size, config.hidden_size) + self.q_dropout2 = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states1, input_tensor1, hidden_states2, input_tensor2): + + + context_state1 = self.dense1(hidden_states1) + context_state1 = self.dropout1(context_state1) + + context_state2 = self.dense2(hidden_states2) + context_state2 = self.dropout2(context_state2) + + hidden_states1 = self.LayerNorm1(context_state1 + input_tensor1) + hidden_states2 = self.LayerNorm2(context_state2 + input_tensor2) + + return hidden_states1, hidden_states2 + +class BertConnectionLayer(nn.Module): + def __init__(self, config): + super(BertConnectionLayer, self).__init__() + self.biattention = BertBiAttention(config) + + self.biOutput = BertBiOutput(config) + + self.v_intermediate = BertImageIntermediate(config) + self.v_output = BertImageOutput(config) + + self.t_intermediate = BertIntermediate(config) + self.t_output = BertOutput(config) + + def forward(self, input_tensor1, attention_mask1, input_tensor2, attention_mask2, co_attention_mask=None, use_co_attention_mask=False): + + bi_output1, bi_output2, co_attention_probs = self.biattention( + input_tensor1, attention_mask1, input_tensor2, attention_mask2, co_attention_mask, use_co_attention_mask + ) + + attention_output1, attention_output2 = self.biOutput(bi_output2, input_tensor1, bi_output1, input_tensor2) + + intermediate_output1 = self.v_intermediate(attention_output1) + layer_output1 = self.v_output(intermediate_output1, attention_output1) + + intermediate_output2 = self.t_intermediate(attention_output2) + layer_output2 = self.t_output(intermediate_output2, attention_output2) + + return layer_output1, layer_output2, co_attention_probs + +class BertEncoder(nn.Module): + def __init__(self, config): + super(BertEncoder, self).__init__() + + # in the bert encoder, we need to extract three things here. + # text bert layer: BertLayer + # vision bert layer: BertImageLayer + # Bi-Attention: Given the output of two bertlayer, perform bi-directional + # attention and add on two layers. 
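+        # In addition to the standard text / vision BERT layers, graph layers
+        # (TextGraphLayer and ImageGraphLayer) are interleaved at the positions given by
+        # config.t_gnn_ids and config.v_gnn_ids; a single instance of each is reused at
+        # every position, and their hub-node outputs are exchanged between the two streams.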
+ + self.FAST_MODE = config.fast_mode + self.with_coattention = config.with_coattention + self.v_biattention_id = config.v_biattention_id + self.t_biattention_id = config.t_biattention_id + self.in_batch_pairs = config.in_batch_pairs + self.fixed_t_layer = config.fixed_t_layer + self.fixed_v_layer = config.fixed_v_layer + self.t_gnn_ids = config.t_gnn_ids + self.v_gnn_ids = config.v_gnn_ids + + v_layer = BertImageLayer(config) + connect_layer = BertConnectionLayer(config) + + self.layer = [] + for _ in range(config.num_hidden_layers): + self.layer.append(BertLayer(config)) + + self.layer = nn.ModuleList(self.layer) + + txt_graph_layer = TextGraphLayer(config) + self.t_gnns = nn.ModuleList([txt_graph_layer for _ in range(len(self.t_gnn_ids))]) + + + self.v_layer = nn.ModuleList( + [copy.deepcopy(v_layer) for _ in range(config.v_num_hidden_layers)] + ) + + img_graph_layer = ImageGraphLayer(config) + self.v_gnns = nn.ModuleList([img_graph_layer for _ in range(len(self.v_gnn_ids))]) + self.c_layer = nn.ModuleList( + [copy.deepcopy(connect_layer) + for _ in range(len(config.v_biattention_id))] + ) + + + def forward( + self, + txt_embedding, + image_embedding, + txt_attention_mask, + image_attention_mask, + image_edge_indices, + image_edge_attributes, + question_edge_indices, + question_edge_attributes, + question_limits, + history_edge_indices, + history_sep_indices, + co_attention_mask=None, + output_all_encoded_layers=True, + output_all_attention_masks=False, + len_img_gr=None, len_q_gr=None, len_h_gr=None, len_h_sep=None + ): + + v_start = 0 + t_start = 0 + count = 0 + all_encoder_layers_t = [] + all_encoder_layers_v = [] + + all_attention_mask_t = [] + all_attnetion_mask_v = [] + all_attention_mask_c = [] + + batch_size, num_words, t_hidden_size = txt_embedding.size() + _, num_regions, v_hidden_size = image_embedding.size() + + # self.pool_feats(txt_embedding) + use_co_attention_mask = False + + # Init the v_hub with the [IMG]-token embedding + v_hub = image_embedding[:, 0, :].clone().detach() + + + q_hub = None + for v_layer_id, t_layer_id in zip(self.v_biattention_id, self.t_biattention_id): + + v_end = v_layer_id + t_end = t_layer_id + + assert self.fixed_t_layer <= t_end + assert self.fixed_v_layer <= v_end + + for idx in range(v_start, self.fixed_v_layer): + with torch.no_grad(): + image_embedding, image_attention_probs = self.v_layer[idx]( + image_embedding, image_attention_mask) + v_start = self.fixed_v_layer + + if output_all_attention_masks: + all_attnetion_mask_v.append(image_attention_probs) + + for idx in range(v_start, v_end): + # Perfrom graph message passing and aggr. if applicable + if idx in self.v_gnn_ids: + assert q_hub is not None + v_gnn_layer_idx = self.v_gnn_ids.index(idx) + image_embedding, v_hub = self.v_gnns[v_gnn_layer_idx]( + image_embedding, + image_edge_indices, + image_edge_attributes, + q_hub, + len_img_gr=len_img_gr, + ) + + # Perform standard bert self-attention + image_embedding, image_attention_probs = self.v_layer[idx]( + image_embedding, image_attention_mask) + if output_all_attention_masks: + all_attnetion_mask_v.append(image_attention_probs) + + for idx in range(t_start, self.fixed_t_layer): + with torch.no_grad(): + txt_embedding, txt_attention_probs = self.layer[idx](txt_embedding, txt_attention_mask) + t_start = self.fixed_t_layer + if output_all_attention_masks: + all_attention_mask_t.append(txt_attention_probs) + + for idx in range(t_start, t_end): + # Perfrom graph message passing and aggr. 
if applicable + if idx in self.t_gnn_ids: + t_gnn_layer_idx = self.t_gnn_ids.index(idx) + txt_embedding, h_hub, q_hub = self.t_gnns[t_gnn_layer_idx]( + txt_embedding, + question_edge_indices, + question_edge_attributes, + question_limits, + history_edge_indices, + history_sep_indices, + v_hub, + len_q_gr=len_q_gr, + len_h_gr=len_h_gr, + len_h_sep=len_h_sep + ) + # Perform standard bert self-attention + txt_embedding, txt_attention_probs = self.layer[idx](txt_embedding, txt_attention_mask) + if output_all_attention_masks: + all_attention_mask_t.append(txt_attention_probs) + + if count == 0 and self.in_batch_pairs: + # new batch size is the batch_size ^2 + image_embedding = image_embedding.unsqueeze(0).expand(batch_size, batch_size, num_regions, v_hidden_size).contiguous( + ).view(batch_size*batch_size, num_regions, v_hidden_size) + image_attention_mask = image_attention_mask.unsqueeze(0).expand( + batch_size, batch_size, 1, 1, num_regions).contiguous().view(batch_size*batch_size, 1, 1, num_regions) + + txt_embedding = txt_embedding.unsqueeze(1).expand(batch_size, batch_size, num_words, t_hidden_size).contiguous( + ).view(batch_size*batch_size, num_words, t_hidden_size) + txt_attention_mask = txt_attention_mask.unsqueeze(1).expand( + batch_size, batch_size, 1, 1, num_words).contiguous().view(batch_size*batch_size, 1, 1, num_words) + co_attention_mask = co_attention_mask.unsqueeze(1).expand( + batch_size, batch_size, 1, num_regions, num_words).contiguous().view(batch_size*batch_size, 1, num_regions, num_words) + + if self.with_coattention: + # do the bi attention. + image_embedding, txt_embedding, co_attention_probs = self.c_layer[count]( + image_embedding, image_attention_mask, txt_embedding, txt_attention_mask, co_attention_mask, use_co_attention_mask) + + # use_co_attention_mask = False + if output_all_attention_masks: + all_attention_mask_c.append(co_attention_probs) + + v_start = v_end + t_start = t_end + count += 1 + + if output_all_encoded_layers: + all_encoder_layers_t.append(txt_embedding) + all_encoder_layers_v.append(image_embedding) + + for idx in range(v_start, len(self.v_layer)): + # Perfrom graph message passing and aggr. if applicable + if idx in self.v_gnn_ids: + v_gnn_layer_idx = self.v_gnn_ids.index(idx) + image_embedding, v_hub = self.v_gnns[v_gnn_layer_idx]( + image_embedding, + image_edge_indices, + image_edge_attributes, + q_hub, + len_img_gr=len_img_gr + ) + + # Perform standard bert self-attention + image_embedding, image_attention_probs = self.v_layer[idx]( + image_embedding, image_attention_mask) + + if output_all_attention_masks: + all_attnetion_mask_v.append(image_attention_probs) + + + for idx in range(t_start, len(self.layer)): + # Perfrom graph message passing and aggr. if applicable + if idx in self.t_gnn_ids: + t_gnn_layer_idx = self.t_gnn_ids.index(idx) + txt_embedding, h_hub, q_hub = self.t_gnns[t_gnn_layer_idx]( + txt_embedding, + question_edge_indices, + question_edge_attributes, + question_limits, + history_edge_indices, + history_sep_indices, + v_hub, + len_q_gr=len_q_gr, + len_h_gr=len_h_gr, + len_h_sep=len_h_sep + ) + + # Perform standard bert self-attention + txt_embedding, txt_attention_probs = self.layer[idx](txt_embedding, txt_attention_mask) + + if output_all_attention_masks: + all_attention_mask_t.append(txt_attention_probs) + + # add the end part to finish. 
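+        # If intermediate outputs were not collected above, keep only the final text and
+        # image encodings from the last layer.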
+ if not output_all_encoded_layers: + all_encoder_layers_t.append(txt_embedding) + all_encoder_layers_v.append(image_embedding) + + + return all_encoder_layers_t, all_encoder_layers_v, (all_attention_mask_t, all_attnetion_mask_v, all_attention_mask_c), (h_hub, q_hub, v_hub) + + +class BertTextPooler(nn.Module): + def __init__(self, config): + super(BertTextPooler, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.bi_hidden_size) + self.activation = nn.ReLU() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + +class BertImagePooler(nn.Module): + def __init__(self, config): + super(BertImagePooler, self).__init__() + self.dense = nn.Linear(config.v_hidden_size, config.bi_hidden_size) + self.activation = nn.ReLU() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super(BertPredictionHeadTransform, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str) or ( + sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode) + ): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + +class BertImgPredictionHeadTransform(nn.Module): + def __init__(self, config): + super(BertImgPredictionHeadTransform, self).__init__() + self.dense = nn.Linear(config.v_hidden_size, config.v_hidden_size) + if isinstance(config.hidden_act, str) or ( + sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode) + ): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.v_hidden_act + self.LayerNorm = BertLayerNorm(config.v_hidden_size, eps=1e-12) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + +class BertLMPredictionHead(nn.Module): + def __init__(self, config, bert_model_embedding_weights): + super(BertLMPredictionHead, self).__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. 
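+        # Note: the Linear below is created with bias=False and its weight is tied to the
+        # word embeddings; the separate self.bias parameter is added explicitly in forward().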
+ self.decoder = nn.Linear( + bert_model_embedding_weights.size(1), + bert_model_embedding_weights.size(0), + bias=False, + ) + self.decoder.weight = bert_model_embedding_weights + self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0))) + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + self.bias + return hidden_states + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config, bert_model_embedding_weights): + super(BertOnlyMLMHead, self).__init__() + self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + +class BertOnlyNSPHead(nn.Module): + def __init__(self, config): + super(BertOnlyNSPHead, self).__init__() + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, pooled_output): + seq_relationship_score = self.seq_relationship(pooled_output) + return seq_relationship_score + +class BertPreTrainingHeads(nn.Module): + def __init__(self, config, bert_model_embedding_weights): + super(BertPreTrainingHeads, self).__init__() + self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights) + self.bi_seq_relationship = nn.Linear(config.bi_hidden_size, 2) + self.imagePredictions = BertImagePredictionHead(config) + self.fusion_method = config.fusion_method + self.dropout = nn.Dropout(0.1) + + def forward( + self, sequence_output_t, sequence_output_v, pooled_output_t, pooled_output_v + ): + + if self.fusion_method == 'sum': + pooled_output = self.dropout(pooled_output_t + pooled_output_v) + elif self.fusion_method == 'mul': + pooled_output = self.dropout(pooled_output_t * pooled_output_v) + else: + assert False + + prediction_scores_t = self.predictions(sequence_output_t) + seq_relationship_score = self.bi_seq_relationship(pooled_output) + prediction_scores_v = self.imagePredictions(sequence_output_v) + + return prediction_scores_t, prediction_scores_v, seq_relationship_score + +class BertImagePredictionHead(nn.Module): + def __init__(self, config): + super(BertImagePredictionHead, self).__init__() + self.transform = BertImgPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.v_hidden_size, config.v_target_size) + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + +class BertPreTrainedModel(nn.Module): + """ An abstract class to handle weights initialization and + a simple interface for dowloading and loading pretrained models. + """ + + def __init__(self, config, device='cuda:0', default_gpu=True, *inputs, **kwargs): + super(BertPreTrainedModel, self).__init__() + + if not isinstance(config, BertConfig): + raise ValueError( + "Parameter config in `{}(config)` should be an instance of class `BertConfig`. " + "To create a model from a Google pretrained model use " + "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( + self.__class__.__name__, self.__class__.__name__ + ) + ) + + self.config = config + + def init_bert_weights(self, module): + """ Initialize the weights. 
+ """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, BertLayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path, + config, + device, + use_apex=False, + default_gpu=True, + state_dict=None, + cache_dir=None, + from_tf=False, + *inputs, + **kwargs + ): + """ + Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict. + Download and cache the pre-trained model file if needed. + + Params: + pretrained_model_name_or_path: either: + - a str with the name of a pre-trained model to load selected in the list of: + . `bert-base-uncased` + . `bert-large-uncased` + . `bert-base-cased` + . `bert-large-cased` + . `bert-base-multilingual-uncased` + . `bert-base-multilingual-cased` + . `bert-base-chinese` + - a path or url to a pretrained model archive containing: + . `bert_config.json` a configuration file for the model + . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance + - a path or url to a pretrained model archive containing: + . `bert_config.json` a configuration file for the model + . `model.chkpt` a TensorFlow checkpoint + from_tf: should we load the weights from a locally saved TensorFlow checkpoint + cache_dir: an optional path to a folder in which the pre-trained models will be cached. + state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models + *inputs, **kwargs: additional input for the specific Bert class + (ex: num_labels for BertForSequenceClassification) + """ + CONFIG_NAME = "bert_config.json" + WEIGHTS_NAME = "pytorch_model.bin" + TF_WEIGHTS_NAME = "model.ckpt" + + if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP: + archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path] + else: + archive_file = pretrained_model_name_or_path + # redirect to the cache, if necessary + try: + resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) + except EnvironmentError: + logger.error( + "Model name '{}' was not found in model name list ({}). 
" + "We assumed '{}' was a path or url but couldn't find any file " + "associated to this path or url.".format( + pretrained_model_name_or_path, + ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), + archive_file, + ) + ) + return None + + if default_gpu: + if resolved_archive_file == archive_file: + logger.info("loading archive file {}".format(archive_file)) + else: + logger.info( + "loading archive file {} from cache at {}".format( + archive_file, resolved_archive_file + ) + ) + tempdir = None + if os.path.isdir(resolved_archive_file) or from_tf: + serialization_dir = resolved_archive_file + elif resolved_archive_file[-3:] == 'bin': + serialization_dir = '/'.join(resolved_archive_file.split('/')[:-1]) + WEIGHTS_NAME = resolved_archive_file.split('/')[-1] + else: + # Extract archive to temp dir + tempdir = tempfile.mkdtemp() + logger.info( + "extracting archive file {} to temp dir {}".format( + resolved_archive_file, tempdir + ) + ) + with tarfile.open(resolved_archive_file, "r:gz") as archive: + archive.extractall(tempdir) + serialization_dir = tempdir + # Load config + # config_file = os.path.join(serialization_dir, CONFIG_NAME) + # config = BertConfig.from_json_file(config_file) + if default_gpu: + # cancel output + # logger.info("Model config {}".format(config)) + pass + # Instantiate model. + model = cls(config, device, use_apex, *inputs, **kwargs) + if state_dict is None and not from_tf: + weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) + map_location = {'cuda:0': device} + state_dict = torch.load( + weights_path, + map_location=map_location + ) + if 'state_dict' in dir(state_dict): + state_dict = state_dict.state_dict() + + if tempdir: + # Clean up temp dir + shutil.rmtree(tempdir) + if from_tf: + # Directly load from a TensorFlow checkpoint + weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME) + return load_tf_weights_in_bert(model, weights_path) + # Load from a PyTorch state_dict + old_keys = [] + new_keys = [] + for key in state_dict.keys(): + new_key = None + if "gamma" in key: + new_key = key.replace("gamma", "weight") + if "beta" in key: + new_key = key.replace("beta", "bias") + if new_key: + old_keys.append(key) + new_keys.append(new_key) + for old_key, new_key in zip(old_keys, new_keys): + state_dict[new_key] = state_dict.pop(old_key) + + missing_keys = [] + unexpected_keys = [] + error_msgs = [] + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, "_metadata", None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + def load(module, prefix=""): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + module._load_from_state_dict( + state_dict, + prefix, + local_metadata, + True, + missing_keys, + unexpected_keys, + error_msgs, + ) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + ".") + + start_prefix = "" + if not hasattr(model, "bert") and any( + s.startswith("bert.") for s in state_dict.keys() + ): + start_prefix = "bert." 
+ load(model, prefix=start_prefix) + if len(missing_keys) > 0 and default_gpu: + logger.info( + "Weights of {} not initialized from pretrained model: {}".format( + model.__class__.__name__, missing_keys + ) + ) + if len(unexpected_keys) > 0 and default_gpu: + logger.info( + "Weights from pretrained model not used in {}: {}".format( + model.__class__.__name__, unexpected_keys + ) + ) + if len(error_msgs) > 0 and default_gpu: + raise RuntimeError( + "Error(s) in loading state_dict for {}:\n\t{}".format( + model.__class__.__name__, "\n\t".join(error_msgs) + ) + ) + return model + + +class BertModel(BertPreTrainedModel): + """BERT model ("Bidirectional Embedding Representations from a Transformer"). + + Params: + config: a BertConfig class instance with the configuration to build a new model + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`. + + Outputs: Tuple of (encoded_layers, pooled_output) + `encoded_layers`: controled by `output_all_encoded_layers` argument: + - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end + of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each + encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size], + - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding + to the last attention block of shape [batch_size, sequence_length, hidden_size], + `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a + classifier pretrained on top of the hidden state associated to the first character of the + input (`CLS`) to train on the Next-Sentence task (see BERT's paper). 
+ + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) + + config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + + model = modeling.BertModel(config=config) + all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) + ``` + """ + + def __init__(self, config, device, use_apex=False): + super(BertModel, self).__init__(config, device) + + # initilize word embedding + self.embeddings = BertEmbeddingsDialog(config, device) + + # initlize the vision embedding + self.v_embeddings = BertImageEmbeddings(config) + + self.encoder = BertEncoder(config) + self.t_pooler = BertTextPooler(config) + self.v_pooler = BertImagePooler(config) + + self.use_apex = use_apex + + self.apply(self.init_bert_weights) + + def forward( + self, + input_txt, + input_imgs, + image_loc, + image_edge_indices, + image_edge_attributes, + question_edge_indices, + question_edge_attributes, + question_limits, + history_edge_indices, + history_sep_indices, + sep_indices=None, + sep_len=None, + token_type_ids=None, + attention_mask=None, + image_attention_mask=None, + co_attention_mask=None, + output_all_encoded_layers=False, + output_all_attention_masks=False, + ): + if attention_mask is None: + attention_mask = torch.ones_like(input_txt) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_txt) + if image_attention_mask is None: + image_attention_mask = torch.ones( + input_imgs.size(0), input_imgs.size(1) + ).type_as(input_txt) + + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + extended_image_attention_mask = image_attention_mask.unsqueeze(1).unsqueeze(2) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. 
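+        # E.g. an attention_mask row of [1, 1, 0] becomes additive scores [0.0, 0.0, -10000.0];
+        # after the softmax the padded position receives practically zero attention weight,
+        # which is equivalent to removing it from the sequence.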
+ if self.use_apex: + dtype = dtype=next(self.parameters()).dtype + else: + dtype = torch.float32 + extended_attention_mask = extended_attention_mask.to(dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + extended_image_attention_mask = extended_image_attention_mask.to(dtype) # fp16 compatibility + extended_image_attention_mask = (1.0 - extended_image_attention_mask) * -10000.0 + + if co_attention_mask is None: + co_attention_mask = torch.zeros(input_txt.size(0), input_imgs.size(1), input_txt.size(1)).type_as(extended_image_attention_mask) + + extended_co_attention_mask = co_attention_mask.unsqueeze(1) + + # extended_co_attention_mask = co_attention_mask.unsqueeze(-1) + extended_co_attention_mask = extended_co_attention_mask * 5.0 + extended_co_attention_mask = extended_co_attention_mask.to(dtype) # fp16 compatibility + + embedding_output = self.embeddings(input_txt, token_type_ids=token_type_ids, sep_indices=sep_indices, sep_len=sep_len) + v_embedding_output = self.v_embeddings(input_imgs, image_loc) + + encoded_layers_t, encoded_layers_v, all_attention_mask, hub_feats = self.encoder( + embedding_output, + v_embedding_output, + extended_attention_mask, + extended_image_attention_mask, + image_edge_indices, + image_edge_attributes, + question_edge_indices, + question_edge_attributes, + question_limits, + history_edge_indices, + history_sep_indices, + co_attention_mask=extended_co_attention_mask, + output_all_encoded_layers=output_all_encoded_layers, + output_all_attention_masks=output_all_attention_masks, + ) + + sequence_output_t = encoded_layers_t[-1] + sequence_output_v = encoded_layers_v[-1] + + pooled_output_t = self.t_pooler(sequence_output_t) + pooled_output_v = self.v_pooler(sequence_output_v) + + if not output_all_encoded_layers: + encoded_layers_t = encoded_layers_t[-1] + encoded_layers_v = encoded_layers_v[-1] + + return encoded_layers_t, encoded_layers_v, pooled_output_t, pooled_output_v, all_attention_mask, hub_feats + +class BertImageEmbeddings(nn.Module): + """Construct the embeddings from image, spatial location (omit now) and token_type embeddings. + """ + def __init__(self, config): + super(BertImageEmbeddings, self).__init__() + + self.image_embeddings = nn.Linear(config.v_feature_size, config.v_hidden_size) + self.image_location_embeddings = nn.Linear(5, config.v_hidden_size) + self.LayerNorm = BertLayerNorm(config.v_hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, input_ids, input_loc): + + img_embeddings = self.image_embeddings(input_ids) + loc_embeddings = self.image_location_embeddings(input_loc) + embeddings = self.LayerNorm(img_embeddings+loc_embeddings) + embeddings = self.dropout(embeddings) + + return embeddings + +class BertForMultiModalPreTraining(BertPreTrainedModel): + """BERT model with multi modal pre-training heads. 
+ """ + + def __init__(self, config, device, use_apex=False): + super(BertForMultiModalPreTraining, self).__init__(config, device) + + self.bert = BertModel(config, device, use_apex) + self.cls = BertPreTrainingHeads( + config, self.bert.embeddings.word_embeddings.weight + ) + + self.apply(self.init_bert_weights) + self.predict_feature = config.predict_feature + self.loss_fct = CrossEntropyLoss(ignore_index=-1) + + print("model's option for predict_feature is ", config.predict_feature) + + if self.predict_feature: + self.vis_criterion = nn.MSELoss(reduction="none") + else: + self.vis_criterion = nn.KLDivLoss(reduction="none") + + def forward( + self, + input_ids, + image_feat, + image_loc, + image_edge_indices, + image_edge_attributes, + question_edge_indices, + question_edge_attributes, + question_limits, + history_edge_indices, + history_sep_indices, + sep_indices=None, + sep_len=None, + token_type_ids=None, + attention_mask=None, + image_attention_mask=None, + masked_lm_labels=None, + image_label=None, + image_target = None, + next_sentence_label=None, + output_all_attention_masks=False + ): + + # in this model, we first embed the images. + sequence_output_t, sequence_output_v, pooled_output_t, pooled_output_v, all_attention_mask, hub_nodes = self.bert( + input_ids, + image_feat, + image_loc, + image_edge_indices, + image_edge_attributes, + question_edge_indices, + question_edge_attributes, + question_limits, + history_edge_indices, + history_sep_indices, + sep_indices=sep_indices, + sep_len=sep_len, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + image_attention_mask=image_attention_mask, + output_all_encoded_layers=False, + output_all_attention_masks=output_all_attention_masks + ) + + prediction_scores_t, prediction_scores_v, seq_relationship_score = self.cls( + sequence_output_t, sequence_output_v, pooled_output_t, pooled_output_v + ) + if masked_lm_labels is not None and next_sentence_label is not None and image_target is not None: + + # prediction_scores_v = prediction_scores_v[:, 1:] + if self.predict_feature: + img_loss = self.vis_criterion(prediction_scores_v, image_target) + masked_img_loss = torch.sum( + img_loss * (image_label == 1).unsqueeze(2).float() + ) / max(torch.sum((image_label == 1).unsqueeze(2).expand_as(img_loss)),1) + + else: + img_loss = self.vis_criterion( + F.log_softmax(prediction_scores_v, dim=2), image_target + ) + masked_img_loss = torch.sum( + img_loss * (image_label == 1).unsqueeze(2).float() + ) / max(torch.sum((image_label == 1)), 0) + + # masked_img_loss = torch.sum(img_loss) / (img_loss.shape[0] * img_loss.shape[1]) + masked_lm_loss = self.loss_fct( + prediction_scores_t.view(-1, self.config.vocab_size), + masked_lm_labels.view(-1), + ) + next_sentence_loss = self.loss_fct( + seq_relationship_score.view(-1, 2), next_sentence_label.view(-1) + ) + # total_loss = masked_lm_loss + next_sentence_loss + masked_img_loss + return masked_lm_loss.unsqueeze(0), masked_img_loss.unsqueeze(0), next_sentence_loss.unsqueeze(0), sequence_output_t, prediction_scores_t, seq_relationship_score, hub_nodes + else: + return prediction_scores_t, prediction_scores_v, seq_relationship_score, sequence_output_t, all_attention_mask, hub_nodes + + def get_text_embedding( + self, + input_ids, + image_feat, + image_loc, + sep_indices=None, + sep_len=None, + token_type_ids=None, + attention_mask=None, + image_attention_mask=None, + output_all_attention_masks=False + ): + + # in this model, we first embed the images. 
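+        # NOTE: `BertModel.forward` above also takes the graph-related positional arguments
+        # (image/question edge indices and attributes, question_limits, history_edge_indices,
+        # history_sep_indices) and returns six values, so this helper appears to be kept from
+        # the original ViLBERT code path and would need those arguments to run as-is here.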
+ sequence_output_t, sequence_output_v, pooled_output_t, pooled_output_v, all_attention_mask = self.bert( + input_ids, + image_feat, + image_loc, + sep_indices=sep_indices, + sep_len=sep_len, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + image_attention_mask=image_attention_mask, + output_all_encoded_layers=False, + output_all_attention_masks=output_all_attention_masks + ) + + return sequence_output_t # [batch_size, num_words, 768] + + +class VILBertForVLTasks(BertPreTrainedModel): + def __init__(self, config, device, num_labels, use_apex=False, dropout_prob=0.1, default_gpu=True): + super(VILBertForVLTasks, self).__init__(config) + self.num_labels = num_labels + self.bert = BertModel(config, device, use_apex) + self.use_apex = use_apex + self.dropout = nn.Dropout(dropout_prob) + self.cls = BertPreTrainingHeads( + config, self.bert.embeddings.word_embeddings.weight + ) + self.vil_prediction = SimpleClassifier(config.bi_hidden_size, config.bi_hidden_size*2, num_labels, 0.5) + # self.vil_prediction = nn.Linear(config.bi_hidden_size, num_labels) + self.vil_logit = nn.Linear(config.bi_hidden_size, 1) + self.vision_logit = nn.Linear(config.v_hidden_size, 1) + self.linguisic_logit = nn.Linear(config.hidden_size, 1) + self.fusion_method = config.fusion_method + self.apply(self.init_bert_weights) + + def forward( + self, + input_txt, + input_imgs, + image_loc, + token_type_ids=None, + attention_mask=None, + image_attention_mask=None, + co_attention_mask=None, + output_all_encoded_layers=False, + ): + sequence_output_t, sequence_output_v, pooled_output_t, pooled_output_v, _ = self.bert( + input_txt, + input_imgs, + image_loc, + token_type_ids, + attention_mask, + image_attention_mask, + co_attention_mask, + output_all_encoded_layers=False, + ) + + vil_prediction = 0 + vil_logit = 0 + vil_binary_prediction = 0 + vision_prediction = 0 + vision_logit = 0 + linguisic_prediction = 0 + linguisic_logit = 0 + + linguisic_prediction, vision_prediction, vil_binary_prediction = self.cls( + sequence_output_t, sequence_output_v, pooled_output_t, pooled_output_v + ) + + if self.fusion_method == 'sum': + pooled_output = self.dropout(pooled_output_t + pooled_output_v) + elif self.fusion_method == 'mul': + pooled_output = self.dropout(pooled_output_t * pooled_output_v) + else: + assert False + + vil_prediction = self.vil_prediction(pooled_output) + vil_logit = self.vil_logit(pooled_output) + if self.use_apex: + dtype = next(self.parameters()).dtype + else: + dtype = torch.float32 + vision_logit = self.vision_logit(self.dropout(sequence_output_v)) + ((1.0 - image_attention_mask)* -10000.0).unsqueeze(2).to(dtype) + linguisic_logit = self.linguisic_logit(self.dropout(sequence_output_t)) + + return vil_prediction, vil_logit, vil_binary_prediction, vision_prediction, vision_logit, linguisic_prediction, linguisic_logit + +class SimpleClassifier(nn.Module): + def __init__(self, in_dim, hid_dim, out_dim, dropout): + super(SimpleClassifier, self).__init__() + layers = [ + weight_norm(nn.Linear(in_dim, hid_dim), dim=None), + nn.ReLU(), + nn.Dropout(dropout, inplace=True), + weight_norm(nn.Linear(hid_dim, out_dim), dim=None) + ] + self.main = nn.Sequential(*layers) + + def forward(self, x): + logits = self.main(x) diff --git a/setup_data.sh b/setup_data.sh new file mode 100644 index 0000000..b749fd0 --- /dev/null +++ b/setup_data.sh @@ -0,0 +1,18 @@ +cd data +# Exract the graphs +tar xvfz history_adj_matrices.tar.gz +tar xvfz question_adj_matrices.tar.gz +tar xvfz img_adj_matrices.tar.gz + +# Remove the 
.tar files +rm *.tar.gz + +# Download the preprocessed image features +mkdir visdial_img_feat.lmdb +wget https://s3.amazonaws.com/visdial-bert/data/visdial_image_feats.lmdb/data.mdb -O visdial_img_feat.lmdb/data.mdb +wget https://s3.amazonaws.com/visdial-bert/data/visdial_image_feats.lmdb/lock.mdb -O visdial_img_feat.lmdb/lock.mdb + +echo Data setup successfully... + +cd .. + diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/data_utils.py b/utils/data_utils.py new file mode 100644 index 0000000..ac8503e --- /dev/null +++ b/utils/data_utils.py @@ -0,0 +1,290 @@ +import torch +from torch.autograd import Variable +import random +import pickle +import numpy as np +from copy import deepcopy + + +def load_pickle_lines(filename): + data = [] + with open(filename, 'rb') as f: + while True: + try: + data.append(pickle.load(f)) + except EOFError: + break + return data + + +def flatten(l): + return [item for sublist in l for item in sublist] + + +def build_len_mask_batch( + # [batch_size], [] + len_batch, max_len=None +): + if max_len is None: + max_len = len_batch.max().item() + # try: + batch_size, = len_batch.shape + # [batch_size, max_len] + idxes_batch = torch.arange(max_len, device=len_batch.device).view(1, -1).repeat(batch_size, 1) + # [batch_size, max_len] = [batch_size, max_len] < [batch_size, 1] + return idxes_batch < len_batch.view(-1, 1) + + +def sequence_mask(sequence_length, max_len=None): + if max_len is None: + max_len = sequence_length.data.max() + batch_size = sequence_length.size(0) + seq_range = torch.arange(0, max_len).long() + seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) + seq_range_expand = Variable(seq_range_expand) + if sequence_length.is_cuda: + seq_range_expand = seq_range_expand.to(sequence_length.device) + seq_length_expand = (sequence_length.unsqueeze(1) + .expand_as(seq_range_expand)) + return seq_range_expand < seq_length_expand + +def batch_iter(dataloader, params): + for epochId in range(params['num_epochs']): + for idx, batch in enumerate(dataloader): + yield epochId, idx, batch + +def list2tensorpad(inp_list, max_seq_len): + inp_tensor = torch.LongTensor([inp_list]) + inp_tensor_zeros = torch.zeros(1, max_seq_len, dtype=torch.long) + inp_tensor_zeros[0,:inp_tensor.shape[1]] = inp_tensor # after preprocess, inp_tensor.shape[1] must < max_seq_len + inp_tensor = inp_tensor_zeros + return inp_tensor + + +def encode_input(utterances, start_segment, CLS, SEP, MASK, max_seq_len=256,max_sep_len=25,mask_prob=0.2): + + cur_segment = start_segment + token_id_list = [] + segment_id_list = [] + sep_token_indices = [] + masked_token_list = [] + + token_id_list.append(CLS) + segment_id_list.append(cur_segment) + masked_token_list.append(0) + + cur_sep_token_index = 0 + + for cur_utterance in utterances: + # add the masked token and keep track + cur_masked_index = [1 if random.random() < mask_prob else 0 for _ in range(len(cur_utterance))] + masked_token_list.extend(cur_masked_index) + token_id_list.extend(cur_utterance) + segment_id_list.extend([cur_segment]*len(cur_utterance)) + + token_id_list.append(SEP) + segment_id_list.append(cur_segment) + masked_token_list.append(0) + cur_sep_token_index = cur_sep_token_index + len(cur_utterance) + 1 + sep_token_indices.append(cur_sep_token_index) + cur_segment = cur_segment ^ 1 # cur segment osciallates between 0 and 1 + start_question, end_question = sep_token_indices[-3] + 1, sep_token_indices[-2] + assert end_question - start_question == 
len(utterances[-2]) + + assert len(segment_id_list) == len(token_id_list) == len(masked_token_list) == sep_token_indices[-1] + 1 + # convert to tensors and pad to maximum seq length + tokens = list2tensorpad(token_id_list,max_seq_len) # [1, max_len] + masked_tokens = list2tensorpad(masked_token_list,max_seq_len) + masked_tokens[0,masked_tokens[0,:]==0] = -1 + mask = masked_tokens[0,:]==1 + masked_tokens[0,mask] = tokens[0,mask] + tokens[0,mask] = MASK + + segment_id_list = list2tensorpad(segment_id_list,max_seq_len) + return tokens, segment_id_list, list2tensorpad(sep_token_indices,max_sep_len), masked_tokens, start_question, end_question + +def encode_input_with_mask(utterances, start_segment, CLS, SEP, MASK, max_seq_len=256,max_sep_len=25,mask_prob=0.2, get_q_limits=True): + + cur_segment = start_segment + token_id_list = [] + segment_id_list = [] + sep_token_indices = [] + masked_token_list = [] + input_mask_list = [] + + token_id_list.append(CLS) + segment_id_list.append(cur_segment) + masked_token_list.append(0) + input_mask_list.append(1) + + cur_sep_token_index = 0 + + for cur_utterance in utterances: + # add the masked token and keep track + cur_masked_index = [1 if random.random() < mask_prob else 0 for _ in range(len(cur_utterance))] + masked_token_list.extend(cur_masked_index) + token_id_list.extend(cur_utterance) + segment_id_list.extend([cur_segment]*len(cur_utterance)) + input_mask_list.extend([1]*len(cur_utterance)) + + token_id_list.append(SEP) + segment_id_list.append(cur_segment) + masked_token_list.append(0) + input_mask_list.append(1) + cur_sep_token_index = cur_sep_token_index + len(cur_utterance) + 1 + sep_token_indices.append(cur_sep_token_index) + cur_segment = cur_segment ^ 1 # cur segment osciallates between 0 and 1 + + if get_q_limits: + start_question, end_question = sep_token_indices[-3] + 1, sep_token_indices[-2] + assert end_question - start_question == len(utterances[-2]) + else: + start_question, end_question = -1, -1 + assert len(segment_id_list) == len(token_id_list) == len(masked_token_list) ==len(input_mask_list) == sep_token_indices[-1] + 1 + # convert to tensors and pad to maximum seq length + tokens = list2tensorpad(token_id_list, max_seq_len) + masked_tokens = list2tensorpad(masked_token_list, max_seq_len) + input_mask = list2tensorpad(input_mask_list,max_seq_len) + masked_tokens[0,masked_tokens[0,:]==0] = -1 + mask = masked_tokens[0,:]==1 + masked_tokens[0,mask] = tokens[0,mask] + tokens[0,mask] = MASK + + segment_id_list = list2tensorpad(segment_id_list,max_seq_len) + return tokens, segment_id_list, list2tensorpad(sep_token_indices,max_sep_len),masked_tokens, input_mask, start_question, end_question + + +def encode_image_input(features, num_boxes, boxes, image_target, max_regions=37, mask_prob=0.15): + output_label = [] + num_boxes = min(int(num_boxes), max_regions) + + mix_boxes_pad = np.zeros((max_regions, boxes.shape[-1])) + mix_features_pad = np.zeros((max_regions, features.shape[-1])) + mix_image_target = np.zeros((max_regions, image_target.shape[-1])) + + mix_boxes_pad[:num_boxes] = boxes[:num_boxes] + mix_features_pad[:num_boxes] = features[:num_boxes] + mix_image_target[:num_boxes] = image_target[:num_boxes] + + boxes = mix_boxes_pad + features = mix_features_pad + image_target = mix_image_target + mask_indexes = [] + for i in range(num_boxes): + prob = random.random() + # mask token with 15% probability + if prob < mask_prob: + prob /= mask_prob + + # 80% randomly change token to mask token + if prob < 0.9: + features[i] = 0 + 
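+                # ViLBERT-style region masking: each region is considered with probability
+                # `mask_prob`; a masked region has its features zeroed (most of the time) and
+                # its label set to 1 so the visual head learns to predict its class
+                # distribution, while a label of -1 marks regions that should be ignored.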
output_label.append(1) + mask_indexes.append(i) + else: + # no masking token (will be ignored by loss function later) + output_label.append(-1) + + image_mask = [1] * (int(num_boxes)) + while len(image_mask) < max_regions: + image_mask.append(0) + output_label.append(-1) + + # ensure we have atleast one region being predicted + output_label[random.randint(1,len(output_label)-1)] = 1 + image_label = torch.LongTensor(output_label) + image_label[0] = 0 # make sure the token doesn't contribute to the masked loss + image_mask = torch.tensor(image_mask).float() + + features = torch.tensor(features).float() + spatials = torch.tensor(boxes).float() + image_target = torch.tensor(image_target).float() + + return features, spatials, image_mask, image_target, image_label + + +def question_edge_masking(question_edge_indices, question_edge_attributes, mask, question_limits, mask_prob=0.4, max_len=10): + mask = mask.squeeze().tolist() + question_limits = question_limits.tolist() + question_start, question_end = question_limits + # Get the masking of the question + mask_question = mask[question_start:question_end] + masked_idx = np.argwhere(np.array(mask_question) > -1).squeeze().tolist() + if isinstance(masked_idx, (int)): # only one question token is masked + masked_idx = [masked_idx] + + # get rid of all edge indices and attributes that corresond to masked tokens + edge_attr_gt = [] + edge_idx_gt_gnn = [] + edge_idx_gt_bert = [] + for i, (question_edge_idx, question_edge_attr) in enumerate(zip(question_edge_indices, question_edge_attributes)): + if not(question_edge_idx[0] in masked_idx or question_edge_idx[1] in masked_idx): + # Masking + if random.random() < mask_prob: + edge_attr_gt.append(np.argwhere(question_edge_attr).item()) + edge_idx_gt_gnn.append(question_edge_idx) + edge_idx_gt_bert.append([question_edge_idx[0] + question_start, question_edge_idx[1] + question_start]) + question_edge_attr = np.zeros_like(question_edge_attr) + question_edge_attr[-1] = 1.0 # The [EDGE_MASK] special token is the last one hot vector encoding + question_edge_attributes[i] = question_edge_attr + else: + continue + # Force masking if the necessary: + if len(edge_attr_gt) == 0: + for i, (question_edge_idx, question_edge_attr) in enumerate(zip(question_edge_indices, question_edge_attributes)): + if not(question_edge_idx[0] in masked_idx or question_edge_idx[1] in masked_idx): + # Masking + edge_attr_gt.append(np.argwhere(question_edge_attr).item()) + edge_idx_gt_gnn.append(question_edge_idx) + edge_idx_gt_bert.append([question_edge_idx[0] + question_start, question_edge_idx[1] + question_start]) + question_edge_attr = np.zeros_like(question_edge_attr) + question_edge_attr[-1] = 1.0 # The [EDGE_MASK] special token is the last one hot vector encoding + question_edge_attributes[i] = question_edge_attr + break + + # For the rare case, where the conditions for masking were not met + if len(edge_attr_gt) == 0: + edge_attr_gt.append(-1) + edge_idx_gt_gnn.append([0, question_end - question_start]) + edge_idx_gt_bert.append(question_limits) + + # Pad to max_len + while len(edge_attr_gt) < max_len: + edge_attr_gt.append(-1) + edge_idx_gt_gnn.append(edge_idx_gt_gnn[-1]) + edge_idx_gt_bert.append(edge_idx_gt_bert[-1]) + + # Truncate if longer than max_len + if len(edge_attr_gt) > max_len: + edge_idx_gt_gnn = edge_idx_gt_gnn[:max_len] + edge_idx_gt_bert = edge_idx_gt_bert[:max_len] + edge_attr_gt = edge_attr_gt[:max_len] + edge_idx_gt_gnn = np.array(edge_idx_gt_gnn) + edge_idx_gt_bert = np.array(edge_idx_gt_bert) + + 
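+    # At this point `edge_idx_gt_gnn` holds the masked question edges in question-local node
+    # indices (as used by the question GNN), `edge_idx_gt_bert` the same edges shifted by
+    # `question_start` into positions of the full token sequence, and `edge_attr_gt` the
+    # ground-truth relation indices (-1 is used as a padding value). All three are padded or
+    # truncated to `max_len` entries before being split into per-node lists below.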
first_edge_node_gt_gnn = list(edge_idx_gt_gnn[:, 0]) + second_edge_node_gt_gnn = list(edge_idx_gt_gnn[:, 1]) + + first_edge_node_gt_bert = list(edge_idx_gt_bert[:, 0]) + second_edge_node_gt_bert = list(edge_idx_gt_bert[:, 1]) + + return question_edge_attributes, edge_attr_gt, first_edge_node_gt_gnn, second_edge_node_gt_gnn, first_edge_node_gt_bert, second_edge_node_gt_bert + + +def to_data_list(feats, batch_idx): + feat_list = [] + device = feats.device + left = 0 + right = 0 + batch_size = batch_idx.max().item() + 1 + for batch in range(batch_size): + if batch == batch_size - 1: + right = batch_idx.size(0) + else: + right = torch.argwhere(batch_idx == batch + 1)[0].item() + idx = torch.arange(left, right).unsqueeze(-1).repeat(1, feats.size(1)).to(device) + feat_list.append(torch.gather(feats, 0, idx)) + left = right + + return feat_list + diff --git a/utils/image_features_reader.py b/utils/image_features_reader.py new file mode 100644 index 0000000..894457d --- /dev/null +++ b/utils/image_features_reader.py @@ -0,0 +1,192 @@ +from typing import List +import csv +import h5py +import numpy as np +import copy +import pickle +import lmdb # install lmdb by "pip install lmdb" +import base64 +import pdb +import os + + +class ImageFeaturesH5Reader(object): + """ + A reader for H5 files containing pre-extracted image features. A typical + H5 file is expected to have a column named "image_id", and another column + named "features". + + Example of an H5 file: + ``` + faster_rcnn_bottomup_features.h5 + |--- "image_id" [shape: (num_images, )] + |--- "features" [shape: (num_images, num_proposals, feature_size)] + +--- .attrs ("split", "train") + ``` + Parameters + ---------- + features_h5path : str + Path to an H5 file containing COCO train / val image features. + in_memory : bool + Whether to load the whole H5 file in memory. Beware, these files are + sometimes tens of GBs in size. Set this to true if you have sufficient + RAM - trade-off between speed and memory. + """ + def __init__(self, features_path: str, scene_graph_path: str, in_memory: bool = False): + self.features_path = features_path + self.scene_graph_path = scene_graph_path + self._in_memory = in_memory + + self.env = lmdb.open(self.features_path, max_readers=1, readonly=True, + lock=False, readahead=False, meminit=False) + + with self.env.begin(write=False) as txn: + self._image_ids = pickle.loads(txn.get('keys'.encode())) + + self.features = [None] * len(self._image_ids) + self.num_boxes = [None] * len(self._image_ids) + self.boxes = [None] * len(self._image_ids) + self.boxes_ori = [None] * len(self._image_ids) + self.cls_prob = [None] * len(self._image_ids) + self.edge_indexes = [None] * len(self._image_ids) + self.edge_attributes = [None] * len(self._image_ids) + + def __len__(self): + return len(self._image_ids) + + def __getitem__(self, image_id): + + image_id = str(image_id).encode() + index = self._image_ids.index(image_id) + if self._in_memory: + # Load features during first epoch, all not loaded together as it + # has a slow start. 
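+            # A minimal usage sketch (the paths are assumptions based on setup_data.sh and the
+            # data/ layout, not fixed by this class):
+            #   reader = ImageFeaturesH5Reader('data/visdial_img_feat.lmdb',
+            #                                  'data/img_adj_matrices', in_memory=True)
+            #   features, num_boxes, boxes, boxes_ori, cls_prob, edge_idx, edge_attr = reader[img_id]
+            # With in_memory=True each entry is decoded once and cached in the per-index lists
+            # created in __init__.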
+ if self.features[index] is not None: + features = self.features[index] + num_boxes = self.num_boxes[index] + image_location = self.boxes[index] + image_location_ori = self.boxes_ori[index] + cls_prob = self.cls_prob[index] + edge_indexes = self.edge_indexes[index] + edge_attributes = self.edge_attributes[index] + else: + with self.env.begin(write=False) as txn: + item = pickle.loads(txn.get(image_id)) + image_id = item['image_id'] + image_h = int(item['image_h']) + image_w = int(item['image_w']) + num_boxes = int(item['num_boxes']) + features = np.frombuffer(base64.b64decode(item["features"]), dtype=np.float32).reshape(num_boxes, 2048) + boxes = np.frombuffer(base64.b64decode(item['boxes']), dtype=np.float32).reshape(num_boxes, 4) + + cls_prob = np.frombuffer(base64.b64decode(item['cls_prob']), dtype=np.float32).reshape(num_boxes, 1601) + # add an extra row at the top for the tokens + g_cls_prob = np.zeros(1601, dtype=np.float32) + g_cls_prob[0] = 1 + cls_prob = np.concatenate([np.expand_dims(g_cls_prob,axis=0), cls_prob], axis=0) + + self.cls_prob[index] = cls_prob + + g_feat = np.sum(features, axis=0) / num_boxes + num_boxes = num_boxes + 1 + + features = np.concatenate([np.expand_dims(g_feat, axis=0), features], axis=0) + self.features[index] = features + + image_location = np.zeros((boxes.shape[0], 5), dtype=np.float32) + image_location[:,:4] = boxes + image_location[:,4] = (image_location[:,3] - image_location[:,1]) * (image_location[:,2] - image_location[:,0]) / (float(image_w) * float(image_h)) + + image_location_ori = copy.deepcopy(image_location) + + image_location[:,0] = image_location[:,0] / float(image_w) + image_location[:,1] = image_location[:,1] / float(image_h) + image_location[:,2] = image_location[:,2] / float(image_w) + image_location[:,3] = image_location[:,3] / float(image_h) + + g_location = np.array([0,0,1,1,1]) + image_location = np.concatenate([np.expand_dims(g_location, axis=0), image_location], axis=0) + self.boxes[index] = image_location + + g_location_ori = np.array([0, 0, image_w, image_h, image_w*image_h]) + image_location_ori = np.concatenate([np.expand_dims(g_location_ori, axis=0), image_location_ori], axis=0) + self.boxes_ori[index] = image_location_ori + self.num_boxes[index] = num_boxes + + # load the scene graph data + pth = os.path.join(self.scene_graph_path, f'{image_id}.pkl') + with open(pth, 'rb') as f: + graph_data = pickle.load(f) + edge_indexes = [] + edge_attributes = [] + for e_idx, e_attr in graph_data: + edge_indexes.append(e_idx) + # get one-hot-encoding of the edges + e_attr_one_hot = np.zeros((12,), dtype=np.float32) # 12 = 11 rels + hub-node rel + e_attr_one_hot[e_attr] = 1.0 + edge_attributes.append(e_attr_one_hot) + edge_indexes = np.array(edge_indexes, dtype=np.float64).transpose(1, 0) + edge_attributes = np.stack(edge_attributes, axis=0) + + self.edge_indexes[index] = edge_indexes + self.edge_attributes[index] = edge_attributes + + else: + # Read chunk from file everytime if not loaded in memory. 
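+            # In both branches a "global" region is prepended to the proposals: its feature is
+            # the mean of all region features, its class probability is a one-hot on index 0,
+            # and its location covers the whole image ([0, 0, 1, 1, 1] after normalisation),
+            # so num_boxes is incremented by one.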
+ with self.env.begin(write=False) as txn: + item = pickle.loads(txn.get(image_id)) + image_id = item['image_id'] + image_h = int(item['image_h']) + image_w = int(item['image_w']) + num_boxes = int(item['num_boxes']) + cls_prob = np.frombuffer(base64.b64decode(item['cls_prob']), dtype=np.float32).reshape(num_boxes, 1601) + # add an extra row at the top for the tokens + g_cls_prob = np.zeros(1601, dtype=np.float32) + g_cls_prob[0] = 1 + cls_prob = np.concatenate([np.expand_dims(g_cls_prob,axis=0), cls_prob], axis=0) + + features = np.frombuffer(base64.b64decode(item["features"]), dtype=np.float32).reshape(num_boxes, 2048) + boxes = np.frombuffer(base64.b64decode(item['boxes']), dtype=np.float32).reshape(num_boxes, 4) + g_feat = np.sum(features, axis=0) / num_boxes + num_boxes = num_boxes + 1 + features = np.concatenate([np.expand_dims(g_feat, axis=0), features], axis=0) + + image_location = np.zeros((boxes.shape[0], 5), dtype=np.float32) + image_location[:,:4] = boxes + image_location[:,4] = (image_location[:,3] - image_location[:,1]) * (image_location[:,2] - image_location[:,0]) / (float(image_w) * float(image_h)) + + image_location_ori = copy.deepcopy(image_location) + image_location[:,0] = image_location[:,0] / float(image_w) + image_location[:,1] = image_location[:,1] / float(image_h) + image_location[:,2] = image_location[:,2] / float(image_w) + image_location[:,3] = image_location[:,3] / float(image_h) + + g_location = np.array([0,0,1,1,1]) + image_location = np.concatenate([np.expand_dims(g_location, axis=0), image_location], axis=0) + + g_location_ori = np.array([0,0,image_w,image_h,image_w*image_h]) + image_location_ori = np.concatenate([np.expand_dims(g_location_ori, axis=0), image_location_ori], axis=0) + + # load the scene graph data + pth = os.path.join(self.scene_graph_path, f'{image_id}.pkl') + with open(pth, 'rb') as f: + graph_data = pickle.load(f) + edge_indexes = [] + edge_attributes = [] + for e_idx, e_attr in graph_data: + edge_indexes.append(e_idx) + # get one-hot-encoding of the edges + e_attr_one_hot = np.zeros((12,), dtype=np.float32) # 12 = 11 rels + hub-node rel + e_attr_one_hot[e_attr] = 1.0 + edge_attributes.append(e_attr_one_hot) + edge_indexes = np.array(edge_indexes, dtype=np.float64).transpose(1, 0) + edge_attributes = np.stack(edge_attributes, axis=0) + + return features, num_boxes, image_location, image_location_ori, cls_prob, edge_indexes, edge_attributes + + + def keys(self) -> List[int]: + return self._image_ids + + def set_keys(self, new_ids: List[str]): + self._image_ids = list(map(lambda _id: _id.encode('ascii') ,new_ids)) diff --git a/utils/init_utils.py b/utils/init_utils.py new file mode 100644 index 0000000..b5665bc --- /dev/null +++ b/utils/init_utils.py @@ -0,0 +1,176 @@ +import os +import os.path as osp +import random +import datetime +import itertools +import glob +import subprocess +import pyhocon +import glob +import re +import numpy as np +import glog as log +import json +import torch + +import sys +sys.path.append('../') + +from models import vdgr +from dataloader.dataloader_visdial import VisdialDataset + +from dataloader.dataloader_visdial_dense import VisdialDenseDataset + + +def load_runner(config): + if config['train_on_dense']: + return vdgr.DenseRunner(config) + else: + return vdgr.SparseRunner(config) + +def load_dataset(config): + dataset_eval = None + + if config['train_on_dense']: + dataset = VisdialDenseDataset(config) + if config['skip_mrr_eval']: + temp = config['num_options_dense'] + config['num_options_dense'] = 
config['num_options'] + dataset_eval = VisdialDenseDataset(config) + config['num_options_dense'] = temp + else: + dataset_eval = VisdialDataset(config) + else: + dataset = VisdialDataset(config) + if config['skip_mrr_eval']: + dataset_eval = VisdialDenseDataset(config) + + if config['use_trainval']: + dataset.split = 'trainval' + else: + dataset.split = 'train' + + if dataset_eval is not None: + dataset_eval.split = 'val' + + return dataset, dataset_eval + + +def initialize_from_env(model, mode, eval_dir, model_type, tag=''): + if "GPU" in os.environ: + os.environ["CUDA_VISIBLE_DEVICES"] = os.environ['GPU'] + if mode in ['train', 'debug']: + config = pyhocon.ConfigFactory.parse_file(f"config/{model_type}.conf")[model] + else: + path_config = osp.join(eval_dir, 'code', f"config/{model_type}.conf") + config = pyhocon.ConfigFactory.parse_file(path_config)[model] + config['log_dir'] = eval_dir + config['model_config'] = osp.join(eval_dir, 'code/config/bert_base_6layer_6conect.json') + if config['dp_type'] == 'apex': + config['dp_type'] = 'ddp' + + if config['dp_type'] == 'dp': + config['stack_gr_data'] = True + + config['model_type'] = model_type + if "CUDA_VISIBLE_DEVICES" in os.environ: + config['num_gpus'] = len(os.environ["CUDA_VISIBLE_DEVICES"].split(',')) + # multi-gpu setting + if config['num_gpus'] > 1: + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = '5678' + + if mode == 'debug': + model += '_debug' + + if tag: + model += '-' + tag + if mode in ['train', 'debug']: + config['log_dir'] = os.path.join(config["log_dir"], model) + if not os.path.exists(config["log_dir"]): + os.makedirs(config["log_dir"]) + config['visdial_output_dir'] = osp.join(config['log_dir'], config['visdial_output_dir']) + + config['timestamp'] = datetime.datetime.now().strftime('%m%d-%H%M%S') + + # add the bert config + config['bert_config'] = json.load(open(config['model_config'], 'r')) + if mode in ['predict', 'eval']: + if (not config['loads_start_path']) and (not config['loads_best_ckpt']): + config['loads_best_ckpt'] = True + print(f'Setting loads_best_ckpt=True under predict or eval mode') + if config['num_options_dense'] < 100: + config['num_options_dense'] = 100 + print('Setting num_options_dense=100 under predict or eval mode') + if config['visdial_version'] == 0.9: + config['skip_ndcg_eval'] = True + + return config + + +def set_log_file(fname, file_only=False): + # if fname already exists, find all log file under log dir, + # and name the current log file with a new number + if osp.exists(fname): + prefix, suffix = osp.splitext(fname) + log_files = glob.glob(prefix + '*' + suffix) + count = 0 + for log_file in log_files: + num = re.search(r'(\d+)', log_file) + if num is not None: + num = int(num.group(0)) + count = max(num, count) + fname = fname.replace(suffix, str(count + 1) + suffix) + # set log file + # simple tricks for duplicating logging destination in the logging module such as: + # logging.getLogger().addHandler(logging.FileHandler(filename)) + # does NOT work well here, because python Traceback message (not via logging module) is not sent to the file, + # the following solution (copied from : https://stackoverflow.com/questions/616645) is a little bit + # complicated but simulates exactly the "tee" command in linux shell, and it redirects everything + if file_only: + # we only output messages to file, and stdout/stderr receives nothing. 
+ # this feature is designed for executing the script via ssh: + # since ssh has a windowing kind of flow control, i.e., if the controller does not read data from a + # ssh channel and its buffer fills up, the execution machine will not be able to write anything into the + # channel and the process will be set to sleeping (S) status until someone reads all data from the channel. + # this is not desired since we do not want to read stdout/stderr from the controller machine. + # so, here we use a simple solution: disable output to stdout/stderr and only output messages to log file. + log.logger.handlers[0].stream = log.handler.stream = sys.stdout = sys.stderr = f = open(fname, 'w', buffering=1) + else: + # we output messages to both file and stdout/stderr + tee = subprocess.Popen(['tee', fname], stdin=subprocess.PIPE) + os.dup2(tee.stdin.fileno(), sys.stdout.fileno()) + os.dup2(tee.stdin.fileno(), sys.stderr.fileno()) + + +def copy_file_to_log(log_dir): + dirs_to_cp = ['.', 'config', 'dataloader', 'models', 'utils'] + files_to_cp = ['*.py', '*.json', '*.sh', '*.conf'] + for dir_name in dirs_to_cp: + dir_name = osp.join(log_dir, 'code', dir_name) + if not osp.exists(dir_name): + os.makedirs(dir_name) + for dir_name, file_name in itertools.product(dirs_to_cp, files_to_cp): + filename = osp.join(dir_name, file_name) + if len(glob.glob(filename)) > 0: + os.system(f'cp {filename} {osp.join(log_dir, "code", dir_name)}') + log.info(f'Files copied to {osp.join(log_dir, "code")}') + + +def set_random_seed(random_seed): + torch.manual_seed(random_seed) + torch.cuda.manual_seed(random_seed) + random.seed(random_seed) + np.random.seed(random_seed) + + +def set_training_steps(config, num_samples): + if config['parallel'] and config['dp_type'] == 'dp': + config['num_iter_per_epoch'] = int(np.ceil(num_samples / config['batch_size'])) + else: + config['num_iter_per_epoch'] = int(np.ceil(num_samples / (config['batch_size'] * config['num_gpus']))) + if 'train_steps' not in config: + config['train_steps'] = config['num_iter_per_epoch'] * config['num_epochs'] + if 'warmup_steps' not in config: + config['warmup_steps'] = int(config['train_steps'] * config['warmup_ratio']) + return config diff --git a/utils/model_utils.py b/utils/model_utils.py new file mode 100644 index 0000000..734a550 --- /dev/null +++ b/utils/model_utils.py @@ -0,0 +1,456 @@ +import torch +from torch import nn +import torch.nn.functional as F +import numpy as np + +def truncated_normal_(tensor, mean=0, std=1): + size = tensor.shape + tmp = tensor.new_empty(size + (4,)).normal_() + valid = (tmp < 2) & (tmp > -2) + ind = valid.max(-1, keepdim=True)[1] + tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1)) + tensor.data.mul_(std).add_(mean) + + +def init_params(module, initializer='normal'): + + if isinstance(module, nn.Linear): + if initializer == 'kaiming_normal': + nn.init.kaiming_normal_(module.weight.data) + elif initializer == 'normal': + nn.init.normal_(module.weight.data, std=0.02) + elif initializer == 'truncated_normal': + truncated_normal_(module.weight.data, std=0.02) + + if module.bias is not None: + nn.init.zeros_(module.bias.data) + + # log.info('initialized Linear') + + elif isinstance(module, nn.Embedding): + if initializer == 'kaiming_normal': + nn.init.kaiming_normal_(module.weight.data) + elif initializer == 'normal': + nn.init.normal_(module.weight.data, std=0.02) + elif initializer == 'truncated_normal': + truncated_normal_(module.weight.data, std=0.02) + + elif isinstance(module, nn.Conv2d) or isinstance(module, nn.Conv1d): 
+ nn.init.kaiming_normal_(module.weight, mode='fan_out') + # log.info('initialized Conv') + + elif isinstance(module, nn.RNNBase) or isinstance(module, nn.LSTMCell) or isinstance(module, nn.GRUCell): + for name, param in module.named_parameters(): + if 'weight' in name: + nn.init.orthogonal_(param.data) + elif 'bias' in name: + nn.init.normal_(param.data) + + # log.info('initialized LSTM') + + elif isinstance(module, nn.BatchNorm1d) or isinstance(module, nn.BatchNorm2d): + module.weight.data.normal_(1.0, 0.02) + # log.info('initialized BatchNorm') + + +def TensorboardWriter(save_path): + from torch.utils.tensorboard import SummaryWriter + return SummaryWriter(save_path, comment="Unmt") + + +DEFAULT_EPS = 1e-8 +PADDED_Y_VALUE = -1 + + +def listMLE(y_pred, y_true, eps=DEFAULT_EPS, padded_value_indicator=PADDED_Y_VALUE): + """ + ListMLE loss introduced in "Listwise Approach to Learning to Rank - Theory and Algorithm". + :param y_pred: predictions from the model, shape [batch_size, slate_length] + :param y_true: ground truth labels, shape [batch_size, slate_length] + :param eps: epsilon value, used for numerical stability + :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1 + :return: loss value, a torch.Tensor + """ + # shuffle for randomised tie resolution + random_indices = torch.randperm(y_pred.shape[-1]) + y_pred_shuffled = y_pred[:, random_indices] + y_true_shuffled = y_true[:, random_indices] + + y_true_sorted, indices = y_true_shuffled.sort(descending=True, dim=-1) + + mask = y_true_sorted == padded_value_indicator + + preds_sorted_by_true = torch.gather(y_pred_shuffled, dim=1, index=indices) + preds_sorted_by_true[mask] = float("-inf") + + max_pred_values, _ = preds_sorted_by_true.max(dim=1, keepdim=True) + + preds_sorted_by_true_minus_max = preds_sorted_by_true - max_pred_values + + cumsums = torch.cumsum(preds_sorted_by_true_minus_max.exp().flip(dims=[1]), dim=1).flip(dims=[1]) + + observation_loss = torch.log(cumsums + eps) - preds_sorted_by_true_minus_max + + observation_loss[mask] = 0.0 + + return torch.mean(torch.sum(observation_loss, dim=1)) + + +def approxNDCGLoss(y_pred, y_true, eps=DEFAULT_EPS, padded_value_indicator=PADDED_Y_VALUE, alpha=1.): + """ + Loss based on approximate NDCG introduced in "A General Approximation Framework for Direct Optimization of + Information Retrieval Measures". Please note that this method does not implement any kind of truncation. + :param y_pred: predictions from the model, shape [batch_size, slate_length] + :param y_true: ground truth labels, shape [batch_size, slate_length] + :param eps: epsilon value, used for numerical stability + :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1 + :param alpha: score difference weight used in the sigmoid function + :return: loss value, a torch.Tensor + """ + device = y_pred.device + y_pred = y_pred.clone() + y_true = y_true.clone() + + padded_mask = y_true == padded_value_indicator + y_pred[padded_mask] = float("-inf") + y_true[padded_mask] = float("-inf") + + # Here we sort the true and predicted relevancy scores. + y_pred_sorted, indices_pred = y_pred.sort(descending=True, dim=-1) + y_true_sorted, _ = y_true.sort(descending=True, dim=-1) + + # After sorting, we can mask out the pairs of indices (i, j) containing index of a padded element. 
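+    # The smooth rank of item i is approximated as
+    #   rank_i ~= 1 + sum_{j != i} sigmoid(-alpha * (s_i - s_j)),
+    # so for well-separated scores the sigmoids saturate to 0/1 and the approximation
+    # approaches the true rank; the NDCG below then applies the usual log2 discounts to
+    # these soft positions.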
+ true_sorted_by_preds = torch.gather(y_true, dim=1, index=indices_pred) + true_diffs = true_sorted_by_preds[:, :, None] - true_sorted_by_preds[:, None, :] + padded_pairs_mask = torch.isfinite(true_diffs) + padded_pairs_mask.diagonal(dim1=-2, dim2=-1).zero_() + + # Here we clamp the -infs to get correct gains and ideal DCGs (maxDCGs) + true_sorted_by_preds.clamp_(min=0.) + y_true_sorted.clamp_(min=0.) + + # Here we find the gains, discounts and ideal DCGs per slate. + pos_idxs = torch.arange(1, y_pred.shape[1] + 1).to(device) + D = torch.log2(1. + pos_idxs.float())[None, :] + maxDCGs = torch.sum((torch.pow(2, y_true_sorted) - 1) / D, dim=-1).clamp(min=eps) + G = (torch.pow(2, true_sorted_by_preds) - 1) / maxDCGs[:, None] + + # Here we approximate the ranking positions according to Eqs 19-20 and later approximate NDCG (Eq 21) + scores_diffs = (y_pred_sorted[:, :, None] - y_pred_sorted[:, None, :]) + scores_diffs[~padded_pairs_mask] = 0. + approx_pos = 1. + torch.sum(padded_pairs_mask.float() * (torch.sigmoid(-alpha * scores_diffs).clamp(min=eps)), + dim=-1) + approx_D = torch.log2(1. + approx_pos) + approx_NDCG = torch.sum((G / approx_D), dim=-1) + + return -torch.mean(approx_NDCG) + # return -torch.mean(approx_NDCG) + + +def listNet(y_pred, y_true, eps=DEFAULT_EPS, padded_value_indicator=PADDED_Y_VALUE): + """ + ListNet loss introduced in "Learning to Rank: From Pairwise Approach to Listwise Approach". + :param y_pred: predictions from the model, shape [batch_size, slate_length] + :param y_true: ground truth labels, shape [batch_size, slate_length] + :param eps: epsilon value, used for numerical stability + :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1 + :return: loss value, a torch.Tensor + """ + y_pred = y_pred.clone() + y_true = y_true.clone() + + mask = y_true == padded_value_indicator + y_pred[mask] = float('-inf') + y_true[mask] = float('-inf') + + preds_smax = F.softmax(y_pred, dim=1) + true_smax = F.softmax(y_true, dim=1) + + preds_smax = preds_smax + eps + preds_log = torch.log(preds_smax) + + return torch.mean(-torch.sum(true_smax * preds_log, dim=1)) + + +def deterministic_neural_sort(s, tau, mask): + """ + Deterministic neural sort. + Code taken from "Stochastic Optimization of Sorting Networks via Continuous Relaxations", ICLR 2019. + Minor modifications applied to the original code (masking). 
+ :param s: values to sort, shape [batch_size, slate_length] + :param tau: temperature for the final softmax function + :param mask: mask indicating padded elements + :return: approximate permutation matrices of shape [batch_size, slate_length, slate_length] + """ + dev = s.device + + n = s.size()[1] + one = torch.ones((n, 1), dtype=torch.float32, device=dev) + s = s.masked_fill(mask[:, :, None], -1e8) + A_s = torch.abs(s - s.permute(0, 2, 1)) + A_s = A_s.masked_fill(mask[:, :, None] | mask[:, None, :], 0.0) + + B = torch.matmul(A_s, torch.matmul(one, torch.transpose(one, 0, 1))) + + temp = [n - m + 1 - 2 * (torch.arange(n - m, device=dev) + 1) for m in mask.squeeze(-1).sum(dim=1)] + temp = [t.type(torch.float32) for t in temp] + temp = [torch.cat((t, torch.zeros(n - len(t), device=dev))) for t in temp] + scaling = torch.stack(temp).type(torch.float32).to(dev) # type: ignore + + s = s.masked_fill(mask[:, :, None], 0.0) + C = torch.matmul(s, scaling.unsqueeze(-2)) + + P_max = (C - B).permute(0, 2, 1) + P_max = P_max.masked_fill(mask[:, :, None] | mask[:, None, :], -np.inf) + P_max = P_max.masked_fill(mask[:, :, None] & mask[:, None, :], 1.0) + sm = torch.nn.Softmax(-1) + P_hat = sm(P_max / tau) + return P_hat + +def sample_gumbel(samples_shape, device, eps=1e-10) -> torch.Tensor: + """ + Sampling from Gumbel distribution. + Code taken from "Stochastic Optimization of Sorting Networks via Continuous Relaxations", ICLR 2019. + Minor modifications applied to the original code (masking). + :param samples_shape: shape of the output samples tensor + :param device: device of the output samples tensor + :param eps: epsilon for the logarithm function + :return: Gumbel samples tensor of shape samples_shape + """ + U = torch.rand(samples_shape, device=device) + return -torch.log(-torch.log(U + eps) + eps) + + +def apply_mask_and_get_true_sorted_by_preds(y_pred, y_true, padding_indicator=PADDED_Y_VALUE): + mask = y_true == padding_indicator + + y_pred[mask] = float('-inf') + y_true[mask] = 0.0 + + _, indices = y_pred.sort(descending=True, dim=-1) + return torch.gather(y_true, dim=1, index=indices) + + +def dcg(y_pred, y_true, ats=None, gain_function=lambda x: torch.pow(2, x) - 1, padding_indicator=PADDED_Y_VALUE): + """ + Discounted Cumulative Gain at k. + Compute DCG at ranks given by ats or at the maximum rank if ats is None. + :param y_pred: predictions from the model, shape [batch_size, slate_length] + :param y_true: ground truth labels, shape [batch_size, slate_length] + :param ats: optional list of ranks for DCG evaluation, if None, maximum rank is used + :param gain_function: callable, gain function for the ground truth labels, e.g. torch.pow(2, x) - 1 + :param padding_indicator: an indicator of the y_true index containing a padded item, e.g. 
-1 + :return: DCG values for each slate and evaluation position, shape [batch_size, len(ats)] + """ + y_true = y_true.clone() + y_pred = y_pred.clone() + + actual_length = y_true.shape[1] + + if ats is None: + ats = [actual_length] + ats = [min(at, actual_length) for at in ats] + + true_sorted_by_preds = apply_mask_and_get_true_sorted_by_preds(y_pred, y_true, padding_indicator) + + discounts = (torch.tensor(1) / torch.log2(torch.arange(true_sorted_by_preds.shape[1], dtype=torch.float) + 2.0)).to( + device=true_sorted_by_preds.device) + + gains = gain_function(true_sorted_by_preds) + + discounted_gains = (gains * discounts)[:, :np.max(ats)] + + cum_dcg = torch.cumsum(discounted_gains, dim=1) + + ats_tensor = torch.tensor(ats, dtype=torch.long) - torch.tensor(1) + + dcg = cum_dcg[:, ats_tensor] + + return dcg + + +def sinkhorn_scaling(mat, mask=None, tol=1e-6, max_iter=50): + """ + Sinkhorn scaling procedure. + :param mat: a tensor of square matrices of shape N x M x M, where N is batch size + :param mask: a tensor of masks of shape N x M + :param tol: Sinkhorn scaling tolerance + :param max_iter: maximum number of iterations of the Sinkhorn scaling + :return: a tensor of (approximately) doubly stochastic matrices + """ + if mask is not None: + mat = mat.masked_fill(mask[:, None, :] | mask[:, :, None], 0.0) + mat = mat.masked_fill(mask[:, None, :] & mask[:, :, None], 1.0) + + for _ in range(max_iter): + mat = mat / mat.sum(dim=1, keepdim=True).clamp(min=DEFAULT_EPS) + mat = mat / mat.sum(dim=2, keepdim=True).clamp(min=DEFAULT_EPS) + + if torch.max(torch.abs(mat.sum(dim=2) - 1.)) < tol and torch.max(torch.abs(mat.sum(dim=1) - 1.)) < tol: + break + + if mask is not None: + mat = mat.masked_fill(mask[:, None, :] | mask[:, :, None], 0.0) + + return mat + + + +def stochastic_neural_sort(s, n_samples, tau, mask, beta=1.0, log_scores=True, eps=1e-10): + """ + Stochastic neural sort. Please note that memory complexity grows by factor n_samples. + Code taken from "Stochastic Optimization of Sorting Networks via Continuous Relaxations", ICLR 2019. + Minor modifications applied to the original code (masking). 
+ :param s: values to sort, shape [batch_size, slate_length] + :param n_samples: number of samples (approximations) for each permutation matrix + :param tau: temperature for the final softmax function + :param mask: mask indicating padded elements + :param beta: scale parameter for the Gumbel distribution + :param log_scores: whether to apply the logarithm function to scores prior to Gumbel perturbation + :param eps: epsilon for the logarithm function + :return: approximate permutation matrices of shape [n_samples, batch_size, slate_length, slate_length] + """ + dev = s.device + + batch_size = s.size()[0] + n = s.size()[1] + s_positive = s + torch.abs(s.min()) + samples = beta * sample_gumbel([n_samples, batch_size, n, 1], device=dev) + if log_scores: + s_positive = torch.log(s_positive + eps) + + s_perturb = (s_positive + samples).view(n_samples * batch_size, n, 1) + mask_repeated = mask.repeat_interleave(n_samples, dim=0) + + P_hat = deterministic_neural_sort(s_perturb, tau, mask_repeated) + P_hat = P_hat.view(n_samples, batch_size, n, n) + return P_hat + + +def neuralNDCG(y_pred, y_true, padded_value_indicator=PADDED_Y_VALUE, temperature=1., powered_relevancies=True, k=None, + stochastic=False, n_samples=32, beta=0.1, log_scores=True): + """ + NeuralNDCG loss introduced in "NeuralNDCG: Direct Optimisation of a Ranking Metric via Differentiable + Relaxation of Sorting" - https://arxiv.org/abs/2102.07831. Based on the NeuralSort algorithm. + :param y_pred: predictions from the model, shape [batch_size, slate_length] + :param y_true: ground truth labels, shape [batch_size, slate_length] + :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1 + :param temperature: temperature for the NeuralSort algorithm + :param powered_relevancies: whether to apply 2^x - 1 gain function, x otherwise + :param k: rank at which the loss is truncated + :param stochastic: whether to calculate the stochastic variant + :param n_samples: how many stochastic samples are taken, used if stochastic == True + :param beta: beta parameter for NeuralSort algorithm, used if stochastic == True + :param log_scores: log_scores parameter for NeuralSort algorithm, used if stochastic == True + :return: loss value, a torch.Tensor + """ + dev = y_pred.device + + if k is None: + k = y_true.shape[1] + + mask = (y_true == padded_value_indicator) + # Choose the deterministic/stochastic variant + if stochastic: + P_hat = stochastic_neural_sort(y_pred.unsqueeze(-1), n_samples=n_samples, tau=temperature, mask=mask, + beta=beta, log_scores=log_scores) + else: + P_hat = deterministic_neural_sort(y_pred.unsqueeze(-1), tau=temperature, mask=mask).unsqueeze(0) + + # Perform sinkhorn scaling to obtain doubly stochastic permutation matrices + P_hat = sinkhorn_scaling(P_hat.view(P_hat.shape[0] * P_hat.shape[1], P_hat.shape[2], P_hat.shape[3]), + mask.repeat_interleave(P_hat.shape[0], dim=0), tol=1e-6, max_iter=50) + P_hat = P_hat.view(int(P_hat.shape[0] / y_pred.shape[0]), y_pred.shape[0], P_hat.shape[1], P_hat.shape[2]) + + # Mask P_hat and apply to true labels, ie approximately sort them + P_hat = P_hat.masked_fill(mask[None, :, :, None] | mask[None, :, None, :], 0.) + y_true_masked = y_true.masked_fill(mask, 0.).unsqueeze(-1).unsqueeze(0) + if powered_relevancies: + y_true_masked = torch.pow(2., y_true_masked) - 1. + + ground_truth = torch.matmul(P_hat, y_true_masked).squeeze(-1) + discounts = (torch.tensor(1.) 
/ torch.log2(torch.arange(y_true.shape[-1], dtype=torch.float) + 2.)).to(dev)
+    discounted_gains = ground_truth * discounts
+
+    if powered_relevancies:
+        idcg = dcg(y_true, y_true, ats=[k]).permute(1, 0)
+    else:
+        idcg = dcg(y_true, y_true, ats=[k], gain_function=lambda x: x).permute(1, 0)
+
+    discounted_gains = discounted_gains[:, :, :k]
+    ndcg = discounted_gains.sum(dim=-1) / (idcg + DEFAULT_EPS)
+    idcg_mask = idcg == 0.
+    ndcg = ndcg.masked_fill(idcg_mask.repeat(ndcg.shape[0], 1), 0.)
+
+    assert (ndcg < 0.).sum() == 0, "every ndcg should be non-negative"
+    if idcg_mask.all():
+        return torch.tensor(0.)
+
+    mean_ndcg = ndcg.sum() / ((~idcg_mask).sum() * ndcg.shape[0])  # type: ignore
+    return -1. * mean_ndcg  # negated because we want to maximize NDCG
+
+
+def neuralNDCG_transposed(y_pred, y_true, padded_value_indicator=PADDED_Y_VALUE, temperature=1.,
+                          powered_relevancies=True, k=None, stochastic=False, n_samples=32, beta=0.1, log_scores=True,
+                          max_iter=50, tol=1e-6):
+    """
+    NeuralNDCG Transposed loss introduced in "NeuralNDCG: Direct Optimisation of a Ranking Metric via Differentiable
+    Relaxation of Sorting" - https://arxiv.org/abs/2102.07831. Based on the NeuralSort algorithm.
+    :param y_pred: predictions from the model, shape [batch_size, slate_length]
+    :param y_true: ground truth labels, shape [batch_size, slate_length]
+    :param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
+    :param temperature: temperature for the NeuralSort algorithm
+    :param powered_relevancies: whether to apply 2^x - 1 gain function, x otherwise
+    :param k: rank at which the loss is truncated
+    :param stochastic: whether to calculate the stochastic variant
+    :param n_samples: how many stochastic samples are taken, used if stochastic == True
+    :param beta: beta parameter for NeuralSort algorithm, used if stochastic == True
+    :param log_scores: log_scores parameter for NeuralSort algorithm, used if stochastic == True
+    :param max_iter: maximum iteration count for Sinkhorn scaling
+    :param tol: tolerance for Sinkhorn scaling
+    :return: loss value, a torch.Tensor
+    """
+    dev = y_pred.device
+
+    if k is None:
+        k = y_true.shape[1]
+
+    mask = (y_true == padded_value_indicator)
+
+    if stochastic:
+        P_hat = stochastic_neural_sort(y_pred.unsqueeze(-1), n_samples=n_samples, tau=temperature, mask=mask,
+                                       beta=beta, log_scores=log_scores)
+    else:
+        P_hat = deterministic_neural_sort(y_pred.unsqueeze(-1), tau=temperature, mask=mask).unsqueeze(0)
+
+    # Perform sinkhorn scaling to obtain doubly stochastic permutation matrices
+    P_hat_masked = sinkhorn_scaling(P_hat.view(P_hat.shape[0] * y_pred.shape[0], y_pred.shape[1], y_pred.shape[1]),
+                                    mask.repeat_interleave(P_hat.shape[0], dim=0), tol=tol, max_iter=max_iter)
+    P_hat_masked = P_hat_masked.view(P_hat.shape[0], y_pred.shape[0], y_pred.shape[1], y_pred.shape[1])
+    discounts = (torch.tensor(1) / torch.log2(torch.arange(y_true.shape[-1], dtype=torch.float) + 2.)).to(dev)
+
+    # This takes care of the @k metric truncation - if something is @>k, it is useless and gets 0.0 discount
+    discounts[k:] = 0.
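+    # Reshape the positional discounts to [1, 1, slate_length, 1] so that multiplying by the transposed
+    # permutation matrices below yields an expected discount per item rather than per rank position.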
+    discounts = discounts[None, None, :, None]
+
+    # Here the discounts become expected discounts
+    discounts = torch.matmul(P_hat_masked.permute(0, 1, 3, 2), discounts).squeeze(-1)
+    if powered_relevancies:
+        gains = torch.pow(2., y_true) - 1
+        discounted_gains = gains.unsqueeze(0) * discounts
+        idcg = dcg(y_true, y_true, ats=[k]).squeeze()
+    else:
+        gains = y_true
+        discounted_gains = gains.unsqueeze(0) * discounts
+        idcg = dcg(y_true, y_true, ats=[k], gain_function=lambda x: x).squeeze()
+
+    ndcg = discounted_gains.sum(dim=2) / (idcg + DEFAULT_EPS)
+    idcg_mask = idcg == 0.
+    ndcg = ndcg.masked_fill(idcg_mask, 0.)
+
+    assert (ndcg < 0.).sum() == 0, "every ndcg should be non-negative"
+    if idcg_mask.all():
+        return torch.tensor(0.)
+
+    mean_ndcg = ndcg.sum() / ((~idcg_mask).sum() * ndcg.shape[0])  # type: ignore
+    return -1. * mean_ndcg  # negated because we want to maximize NDCG
diff --git a/utils/modules.py b/utils/modules.py
new file mode 100644
index 0000000..cccb657
--- /dev/null
+++ b/utils/modules.py
@@ -0,0 +1,41 @@
+from collections import Counter, defaultdict
+import logging
+from typing import Union, List, Dict, Any
+import torch
+from torch import nn
+
+
+class Identity(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return x
+
+
+class Reshaper(nn.Module):
+    def __init__(self, *output_shape):
+        super().__init__()
+
+        self.output_shape = output_shape
+
+    def forward(self, input: torch.Tensor):
+        return input.view(*self.output_shape)
+
+
+class Normalizer(nn.Module):
+    def __init__(self, target_norm=1.):
+        super().__init__()
+        self.target_norm = target_norm
+
+    def forward(self, input: torch.Tensor):
+        return input * self.target_norm / input.norm(p=2, dim=1, keepdim=True)
+
+
+class Squeezer(nn.Module):
+    def __init__(self, dim=-1):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, input):
+        return torch.squeeze(input, dim=self.dim)
diff --git a/utils/optim_utils.py b/utils/optim_utils.py
new file mode 100644
index 0000000..4d98597
--- /dev/null
+++ b/utils/optim_utils.py
@@ -0,0 +1,389 @@
+import logging
+import math
+import numpy as np
+import random
+import functools
+import glog as log
+
+import torch
+from torch import nn, optim
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler, ConstantLR
+import torch.nn.functional as F
+from torch.nn.utils import clip_grad_norm_
+from pytorch_transformers.optimization import AdamW
+
+
+class WarmupLinearScheduleNonZero(_LRScheduler):
+    """ Linear warmup and then linear decay.
+        Linearly increases the learning rate from 0 to the base learning rate over `warmup_steps` training steps.
+        Then linearly decays it over the remaining `t_total - warmup_steps` steps, never dropping below `min_lr`.
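+        At step t, the learning rate of each parameter group is
+        max(min_lr, base_lr * t / warmup_steps) during warmup and
+        max(min_lr, base_lr * (t_total - t) / (t_total - warmup_steps)) afterwards.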
+    """
+    def __init__(self, optimizer, warmup_steps, t_total, min_lr=1e-5, last_epoch=-1):
+        self.warmup_steps = warmup_steps
+        self.t_total = t_total
+        self.min_lr = min_lr
+        super(WarmupLinearScheduleNonZero, self).__init__(optimizer, last_epoch=last_epoch)
+
+    def get_lr(self):
+        step = self.last_epoch
+        if step < self.warmup_steps:
+            lr_factor = float(step) / float(max(1, self.warmup_steps))
+        else:
+            lr_factor = max(0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps)))
+
+        return [base_lr * lr_factor if (base_lr * lr_factor) > self.min_lr else self.min_lr for base_lr in self.base_lrs]
+
+
+def init_optim(model, config):
+    optimizer_grouped_parameters = []
+
+    gnn_params = []
+
+    encoder_params_with_decay = []
+    encoder_params_without_decay = []
+
+    exclude_from_weight_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
+
+    for module_name, module in model.named_children():
+        for param_name, param in module.named_parameters():
+            if param.requires_grad:
+                if "gnn" in param_name:
+                    gnn_params.append(param)
+                elif module_name == 'encoder':
+                    if any(ex in param_name for ex in exclude_from_weight_decay):
+                        encoder_params_without_decay.append(param)
+                    else:
+                        encoder_params_with_decay.append(param)
+
+    optimizer_grouped_parameters = [
+        {
+            'params': gnn_params,
+            'weight_decay': config.gnn_weight_decay,
+            'lr': config['learning_rate_gnn'] if config.use_diff_lr_gnn else config['learning_rate_bert']
+        }
+    ]
+
+    optimizer_grouped_parameters.extend(
+        [
+            {
+                'params': encoder_params_without_decay,
+                'weight_decay': 0,
+                'lr': config['learning_rate_bert']
+            },
+            {
+                'params': encoder_params_with_decay,
+                'weight_decay': 0.01,
+                'lr': config['learning_rate_bert']
+            }
+        ]
+    )
+    optimizer = AdamW(optimizer_grouped_parameters, lr=config['learning_rate_gnn'])
+    scheduler = WarmupLinearScheduleNonZero(
+        optimizer,
+        warmup_steps=config['warmup_steps'],
+        t_total=config['train_steps'],
+        min_lr=config['min_lr']
+    )
+
+    return optimizer, scheduler
+
+
+def build_torch_optimizer(model, config):
+    """Builds the PyTorch optimizer.
+
+    We use the default parameters for Adam that are suggested by
+    the original paper https://arxiv.org/pdf/1412.6980.pdf.
+    These values are also used by other established implementations,
+    e.g. https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer
+    https://keras.io/optimizers/
+    The paper "Attention is all you need" (https://arxiv.org/pdf/1706.03762.pdf)
+    used the slightly different value beta2=0.98; however, beta2=0.999 is still
+    arguably the more established choice, so we use it here as well.
+
+    Args:
+      model: The model to optimize.
+      config: The dictionary of options.
+
+    Returns:
+      A ``torch.optim.Optimizer`` instance.
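+      When the model has an ``encoder`` submodule, this is a ``MultipleOptimizer`` wrapping a separate
+      BERT optimizer and task optimizer; otherwise only the task optimizer is returned.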
+ """ + params = [p for p in model.parameters() if p.requires_grad] + betas = [0.9, 0.999] + exclude_from_weight_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] + + params = {'bert': [], 'task': []} + for module_name, module in model.named_children(): + if module_name == 'encoder': + param_type = 'bert' + else: + param_type = 'task' + for param_name, param in module.named_parameters(): + if param.requires_grad: + if any(ex in param_name for ex in exclude_from_weight_decay): + params[param_type] += [ + { + "params": [param], + "weight_decay": 0 + } + ] + else: + params[param_type] += [ + { + "params": [param], + "weight_decay": 0.01 + } + ] + if config['task_optimizer'] == 'adamw': + log.info('Using AdamW as task optimizer') + task_optimizer = AdamWeightDecay(params['task'], + lr=config["learning_rate_task"], + betas=betas, + eps=1e-6) + elif config['task_optimizer'] == 'adam': + log.info('Using Adam as task optimizer') + task_optimizer = optim.Adam(params['task'], + lr=config["learning_rate_task"], + betas=betas, + eps=1e-6) + if len(params['bert']) > 0: + bert_optimizer = AdamWeightDecay(params['bert'], + lr=config["learning_rate_bert"], + betas=betas, + eps=1e-6) + optimizer = MultipleOptimizer([bert_optimizer, task_optimizer]) + else: + optimizer = task_optimizer + + return optimizer + + +def make_learning_rate_decay_fn(decay_method, train_steps, **kwargs): + """Returns the learning decay function from options.""" + if decay_method == "linear": + return functools.partial( + linear_decay, + global_steps=train_steps, + **kwargs) + elif decay_method == "exp": + return functools.partial( + exp_decay, + global_steps=train_steps, + **kwargs) + else: + raise ValueError(f'{decay_method} not found') + + +def linear_decay(step, global_steps, warmup_steps=100, initial_learning_rate=1, end_learning_rate=0, **kargs): + if step < warmup_steps: + return initial_learning_rate * step / warmup_steps + else: + return (initial_learning_rate - end_learning_rate) * \ + (1 - (step - warmup_steps) / (global_steps - warmup_steps)) + \ + end_learning_rate + +def exp_decay(step, global_steps, decay_exp=1, warmup_steps=100, initial_learning_rate=1, end_learning_rate=0, **kargs): + if step < warmup_steps: + return initial_learning_rate * step / warmup_steps + else: + return (initial_learning_rate - end_learning_rate) * \ + ((1 - (step - warmup_steps) / (global_steps - warmup_steps)) ** decay_exp) + \ + end_learning_rate + + +class MultipleOptimizer(object): + """ Implement multiple optimizers needed for sparse adam """ + + def __init__(self, op): + """ ? """ + self.optimizers = op + + @property + def param_groups(self): + param_groups = [] + for optimizer in self.optimizers: + param_groups.extend(optimizer.param_groups) + return param_groups + + def zero_grad(self): + """ ? """ + for op in self.optimizers: + op.zero_grad() + + def step(self): + """ ? """ + for op in self.optimizers: + op.step() + + @property + def state(self): + """ ? """ + return {k: v for op in self.optimizers for k, v in op.state.items()} + + def state_dict(self): + """ ? """ + return [op.state_dict() for op in self.optimizers] + + def load_state_dict(self, state_dicts): + """ ? """ + assert len(state_dicts) == len(self.optimizers) + for i in range(len(state_dicts)): + self.optimizers[i].load_state_dict(state_dicts[i]) + + +class OptimizerBase(object): + """ + Controller class for optimization. Mostly a thin + wrapper for `optim`, but also useful for implementing + rate scheduling beyond what is currently available. 
+ Also implements necessary methods for training RNNs such + as grad manipulations. + """ + + def __init__(self, + optimizer, + learning_rate, + learning_rate_decay_fn=None, + max_grad_norm=None): + """Initializes the controller. + + Args: + optimizer: A ``torch.optim.Optimizer`` instance. + learning_rate: The initial learning rate. + learning_rate_decay_fn: An optional callable taking the current step + as argument and return a learning rate scaling factor. + max_grad_norm: Clip gradients to this global norm. + """ + self._optimizer = optimizer + self._learning_rate = learning_rate + self._learning_rate_decay_fn = learning_rate_decay_fn + self._max_grad_norm = max_grad_norm or 0 + self._training_step = 1 + self._decay_step = 1 + + @classmethod + def from_opt(cls, model, config, checkpoint=None): + """Builds the optimizer from options. + + Args: + cls: The ``Optimizer`` class to instantiate. + model: The model to optimize. + opt: The dict of user options. + checkpoint: An optional checkpoint to load states from. + + Returns: + An ``Optimizer`` instance. + """ + optim_opt = config + optim_state_dict = None + + if config["loads_ckpt"] and checkpoint is not None: + optim = checkpoint['optim'] + ckpt_opt = checkpoint['opt'] + ckpt_state_dict = {} + if isinstance(optim, Optimizer): # Backward compatibility. + ckpt_state_dict['training_step'] = optim._step + 1 + ckpt_state_dict['decay_step'] = optim._step + 1 + ckpt_state_dict['optimizer'] = optim.optimizer.state_dict() + else: + ckpt_state_dict = optim + + if config["reset_optim"] == 'none': + # Load everything from the checkpoint. + optim_opt = ckpt_opt + optim_state_dict = ckpt_state_dict + elif config["reset_optim"] == 'all': + # Build everything from scratch. + pass + elif config["reset_optim"] == 'states': + # Reset optimizer, keep options. + optim_opt = ckpt_opt + optim_state_dict = ckpt_state_dict + del optim_state_dict['optimizer'] + elif config["reset_optim"] == 'keep_states': + # Reset options, keep optimizer. + optim_state_dict = ckpt_state_dict + + learning_rates = [ + optim_opt["learning_rate_bert"], + optim_opt["learning_rate_gnn"] + ] + decay_fn = [ + make_learning_rate_decay_fn(optim_opt['decay_method_bert'], + optim_opt['train_steps'], + warmup_steps=optim_opt['warmup_steps'], + decay_exp=optim_opt['decay_exp']), + make_learning_rate_decay_fn(optim_opt['decay_method_gnn'], + optim_opt['train_steps'], + warmup_steps=optim_opt['warmup_steps'], + decay_exp=optim_opt['decay_exp']), + ] + optimizer = cls( + build_torch_optimizer(model, optim_opt), + learning_rates, + learning_rate_decay_fn=decay_fn, + max_grad_norm=optim_opt["max_grad_norm"]) + if optim_state_dict: + optimizer.load_state_dict(optim_state_dict) + return optimizer + + @property + def training_step(self): + """The current training step.""" + return self._training_step + + def learning_rate(self): + """Returns the current learning rate.""" + if self._learning_rate_decay_fn is None: + return self._learning_rate + return [decay_fn(self._decay_step) * learning_rate \ + for decay_fn, learning_rate in \ + zip(self._learning_rate_decay_fn, self._learning_rate)] + + def state_dict(self): + return { + 'training_step': self._training_step, + 'decay_step': self._decay_step, + 'optimizer': self._optimizer.state_dict() + } + + def load_state_dict(self, state_dict): + self._training_step = state_dict['training_step'] + # State can be partially restored. 
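+        # Only 'training_step' is required; 'decay_step' and 'optimizer' are restored when present.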
+        if 'decay_step' in state_dict:
+            self._decay_step = state_dict['decay_step']
+        if 'optimizer' in state_dict:
+            self._optimizer.load_state_dict(state_dict['optimizer'])
+
+    def zero_grad(self):
+        """Zero the gradients of optimized parameters."""
+        self._optimizer.zero_grad()
+
+    def backward(self, loss):
+        """Wrapper for backward pass. Some optimizer requires ownership of the
+        backward pass."""
+        loss.backward()
+
+    def step(self):
+        """Update the model parameters based on current gradients.
+
+        Optionally, will employ gradient modification or update learning
+        rate.
+        """
+        learning_rate = self.learning_rate()
+
+        if isinstance(self._optimizer, MultipleOptimizer):
+            optimizers = self._optimizer.optimizers
+        else:
+            optimizers = [self._optimizer]
+        for lr, op in zip(learning_rate, optimizers):
+            for group in op.param_groups:
+                group['lr'] = lr
+                if self._max_grad_norm > 0:
+                    clip_grad_norm_(group['params'], self._max_grad_norm)
+        self._optimizer.step()
+        self._decay_step += 1
+        self._training_step += 1
+
diff --git a/utils/visdial_metrics.py b/utils/visdial_metrics.py
new file mode 100644
index 0000000..fc9a383
--- /dev/null
+++ b/utils/visdial_metrics.py
@@ -0,0 +1,322 @@
+"""
+A Metric observes the output of a certain model, for example in the form of
+logits or scores, and accumulates a particular metric with reference to some
+provided targets. In the context of VisDial, we use Recall (@ 1, 5, 10), Mean
+Rank, Mean Reciprocal Rank (MRR) and Normalized Discounted Cumulative Gain (NDCG).
+
+Each ``Metric`` must at least implement three methods:
+    - ``observe``, update the accumulated metric with currently observed
+      outputs and targets.
+    - ``retrieve`` to return the accumulated metric and optionally reset the
+      internally accumulated metric (this is commonly done between two epochs
+      after validation).
+    - ``reset`` to explicitly reset the internally accumulated metric.
+
+Caveat: if you wish to implement your own class of Metric, make sure you call
+``detach`` on output tensors (like logits), else it will cause memory leaks.
+"""
+import torch
+import torch.distributed as dist
+import numpy as np
+
+def scores_to_ranks(scores: torch.Tensor):
+    """Convert model output scores into ranks."""
+    batch_size, num_rounds, num_options = scores.size()
+    scores = scores.view(-1, num_options)
+
+    # sort in descending order - the largest score gets rank 1 (the best rank)
+    sorted_ranks, ranked_idx = scores.sort(1, descending=True)
+
+    # i-th position in ranked_idx specifies which score shall take this
+    # position but we want i-th position to have rank of score at that
+    # position, do this conversion
+    ranks = ranked_idx.clone().fill_(0)
+    for i in range(ranked_idx.size(0)):
+        for j in range(num_options):
+            ranks[i][ranked_idx[i][j]] = j
+    # convert from 0-99 ranks to 1-100 ranks
+    ranks += 1
+    ranks = ranks.view(batch_size, num_rounds, num_options)
+    return ranks
+
+class SparseGTMetrics(object):
+    """
+    A class to accumulate all metrics with sparse ground truth annotations.
+    These include Recall (@ 1, 5, 10), Mean Rank and Mean Reciprocal Rank.
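+    Metrics are accumulated across observed batches and reported by ``retrieve`` both overall and
+    per dialog round.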
+ """ + + def __init__(self): + self._rank_list = [] + self._rank_list_rnd = [] + self.num_rounds = None + + def observe( + self, predicted_scores: torch.Tensor, target_ranks: torch.Tensor + ): + predicted_scores = predicted_scores.detach() + + # shape: (batch_size, num_rounds, num_options) + predicted_ranks = scores_to_ranks(predicted_scores) + batch_size, num_rounds, num_options = predicted_ranks.size() + self.num_rounds = num_rounds + # collapse batch dimension + predicted_ranks = predicted_ranks.view( + batch_size * num_rounds, num_options + ) + + # shape: (batch_size * num_rounds, ) + target_ranks = target_ranks.view(batch_size * num_rounds).long() + + # shape: (batch_size * num_rounds, ) + predicted_gt_ranks = predicted_ranks[ + torch.arange(batch_size * num_rounds), target_ranks + ] + self._rank_list.extend(list(predicted_gt_ranks.cpu().numpy())) + + predicted_gt_ranks_rnd = predicted_gt_ranks.view(batch_size, num_rounds) + # predicted gt ranks + self._rank_list_rnd.append(predicted_gt_ranks_rnd.cpu().numpy()) + + def retrieve(self, reset: bool = True): + num_examples = len(self._rank_list) + if num_examples > 0: + # convert to numpy array for easy calculation. + __rank_list = torch.tensor(self._rank_list).float() + metrics = { + "r@1": torch.mean((__rank_list <= 1).float()).item(), + "r@5": torch.mean((__rank_list <= 5).float()).item(), + "r@10": torch.mean((__rank_list <= 10).float()).item(), + "mean": torch.mean(__rank_list).item(), + "mrr": torch.mean(__rank_list.reciprocal()).item() + } + # add round metrics + _rank_list_rnd = np.concatenate(self._rank_list_rnd) + _rank_list_rnd = _rank_list_rnd.astype(float) + r_1_rnd = np.mean(_rank_list_rnd <= 1, axis=0) + r_5_rnd = np.mean(_rank_list_rnd <= 5, axis=0) + r_10_rnd = np.mean(_rank_list_rnd <= 10, axis=0) + mean_rnd = np.mean(_rank_list_rnd, axis=0) + mrr_rnd = np.mean(np.reciprocal(_rank_list_rnd), axis=0) + + for rnd in range(1, self.num_rounds + 1): + metrics["r_1" + "_round_" + str(rnd)] = r_1_rnd[rnd-1] + metrics["r_5" + "_round_" + str(rnd)] = r_5_rnd[rnd-1] + metrics["r_10" + "_round_" + str(rnd)] = r_10_rnd[rnd-1] + metrics["mean" + "_round_" + str(rnd)] = mean_rnd[rnd-1] + metrics["mrr" + "_round_" + str(rnd)] = mrr_rnd[rnd-1] + else: + metrics = {} + + if reset: + self.reset() + return metrics + + def reset(self): + self._rank_list = [] + self._rank_list_rnd = [] + +class NDCG(object): + def __init__(self): + self._ndcg_numerator = 0.0 + self._ndcg_denominator = 0.0 + + def observe( + self, predicted_scores: torch.Tensor, target_relevance: torch.Tensor + ): + """ + Observe model output scores and target ground truth relevance and + accumulate NDCG metric. + + Parameters + ---------- + predicted_scores: torch.Tensor + A tensor of shape (batch_size, num_options), because dense + annotations are available for 1 randomly picked round out of 10. + target_relevance: torch.Tensor + A tensor of shape same as predicted scores, indicating ground truth + relevance of each answer option for a particular round. 
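+
+        NDCG is computed at k equal to the number of answer options with
+        non-zero relevance for each instance.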
+ """ + predicted_scores = predicted_scores.detach() + + # shape: (batch_size, 1, num_options) + predicted_scores = predicted_scores.unsqueeze(1) + predicted_ranks = scores_to_ranks(predicted_scores) + + # shape: (batch_size, num_options) + predicted_ranks = predicted_ranks.squeeze(1) + batch_size, num_options = predicted_ranks.size() + + k = torch.sum(target_relevance != 0, dim=-1) + + # shape: (batch_size, num_options) + _, rankings = torch.sort(predicted_ranks, dim=-1) + # Sort relevance in descending order so highest relevance gets top rnk. + _, best_rankings = torch.sort( + target_relevance, dim=-1, descending=True + ) + + # shape: (batch_size, ) + batch_ndcg = [] + for batch_index in range(batch_size): + num_relevant = k[batch_index] + dcg = self._dcg( + rankings[batch_index][:num_relevant], + target_relevance[batch_index], + ) + best_dcg = self._dcg( + best_rankings[batch_index][:num_relevant], + target_relevance[batch_index], + ) + batch_ndcg.append(dcg / best_dcg) + + self._ndcg_denominator += batch_size + self._ndcg_numerator += sum(batch_ndcg) + + def _dcg(self, rankings: torch.Tensor, relevance: torch.Tensor): + sorted_relevance = relevance[rankings].cpu().float() + discounts = torch.log2(torch.arange(len(rankings)).float() + 2) + return torch.sum(sorted_relevance / discounts, dim=-1) + + def retrieve(self, reset: bool = True): + if self._ndcg_denominator > 0: + metrics = { + "ndcg": float(self._ndcg_numerator / self._ndcg_denominator) + } + else: + metrics = {} + + if reset: + self.reset() + return metrics + + def reset(self): + self._ndcg_numerator = 0.0 + self._ndcg_denominator = 0.0 + +class SparseGTMetricsParallel(object): + """ + A class to accumulate all metrics with sparse ground truth annotations. + These include Recall (@ 1, 5, 10), Mean Rank and Mean Reciprocal Rank. 
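+    Designed for distributed evaluation: duplicate batches (identified by their image ids) are skipped
+    in ``observe``, and per-process sums are aggregated across GPUs with ``all_reduce`` in ``retrieve``.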
+ """ + + def __init__(self, gpu_rank): + self.rank_1 = 0 + self.rank_5 = 0 + self.rank_10 = 0 + self.ranks = 0 + self.reciprocal = 0 + self.count = 0 + self.gpu_rank = gpu_rank + self.img_ids = [] + + def observe( + self, img_id: list, predicted_scores: torch.Tensor, target_ranks: torch.Tensor + ): + if img_id in self.img_ids: + return + else: + self.img_ids.append(img_id) + + predicted_scores = predicted_scores.detach() + + # shape: (batch_size, num_rounds, num_options) + predicted_ranks = scores_to_ranks(predicted_scores) + batch_size, num_rounds, num_options = predicted_ranks.size() + self.num_rounds = num_rounds + # collapse batch dimension + predicted_ranks = predicted_ranks.view( + batch_size * num_rounds, num_options + ) + + # shape: (batch_size * num_rounds, ) + target_ranks = target_ranks.view(batch_size * num_rounds).long() + + # shape: (batch_size * num_rounds, ) + predicted_gt_ranks = predicted_ranks[ + torch.arange(batch_size * num_rounds), target_ranks + ] + + self.rank_1 += (predicted_gt_ranks <= 1).sum().item() + self.rank_5 += (predicted_gt_ranks <= 5).sum().item() + self.rank_10 += (predicted_gt_ranks <= 10).sum().item() + self.ranks += predicted_gt_ranks.sum().item() + self.reciprocal += predicted_gt_ranks.float().reciprocal().sum().item() + self.count += batch_size * num_rounds + + def retrieve(self): + if self.count > 0: + # retrieve data from all gpu + # define tensor on GPU, count and total is the result at each GPU + t = torch.tensor([self.rank_1, self.rank_5, self.rank_10, self.ranks, self.reciprocal, self.count], dtype=torch.float32, device=f'cuda:{self.gpu_rank}') + dist.barrier() # synchronizes all processes + dist.all_reduce(t, op=torch.distributed.ReduceOp.SUM,) # Reduces the tensor data across all machines in such a way that all get the final result. + t = t.tolist() + self.rank_1, self.rank_5, self.rank_10, self.ranks, self.reciprocal, self.count = t + + # convert to numpy array for easy calculation. + metrics = { + "r@1": self.rank_1 / self.count, + "r@5": self.rank_5 / self.count, + "r@10": self.rank_10 / self.count, + "mean": self.ranks / self.count, + "mrr": self.reciprocal / self.count, + "tot_rnds": self.count, + } + + else: + metrics = {} + + return metrics + + def get_count(self): + return int(self.count) + +class NDCGParallel(NDCG): + def __init__(self, gpu_rank): + super(NDCGParallel, self).__init__() + self.gpu_rank = gpu_rank + self.img_ids = [] + self.count = 0 + + def observe( + self, img_id: int, predicted_scores: torch.Tensor, target_relevance: torch.Tensor + ): + """ + Observe model output scores and target ground truth relevance and + accumulate NDCG metric. + + Parameters + ---------- + predicted_scores: torch.Tensor + A tensor of shape (batch_size, num_options), because dense + annotations are available for 1 randomly picked round out of 10. + target_relevance: torch.Tensor + A tensor of shape same as predicted scores, indicating ground truth + relevance of each answer option for a particular round. 
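+
+        Duplicate image ids are ignored; the per-process NDCG numerator and
+        denominator are summed across GPUs with ``all_reduce`` in ``retrieve``.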
+ """ + if img_id in self.img_ids: + return + else: + self.img_ids.append(img_id) + self.count += 1 + + super(NDCGParallel, self).observe(predicted_scores, target_relevance) + + + def retrieve(self): + if self._ndcg_denominator > 0: + # define tensor on GPU, count and total is the result at each GPU + t = torch.tensor([self._ndcg_numerator, self._ndcg_denominator, self.count], dtype=torch.float32, device=f'cuda:{self.gpu_rank}') + dist.barrier() # synchronizes all processes + dist.all_reduce(t, op=torch.distributed.ReduceOp.SUM,) # Reduces the tensor data across all machines in such a way that all get the final result. + t = t.tolist() + self._ndcg_numerator, self._ndcg_denominator, self.count = t + metrics = { + "ndcg": float(self._ndcg_numerator / self._ndcg_denominator) + } + else: + metrics = {} + return metrics + + def get_count(self): + return int(self.count)