Code release

This commit is contained in:
Adnen Abdessaied 2023-10-25 15:38:09 +02:00
commit 09fb25e339
29 changed files with 7162 additions and 0 deletions

README.md Normal file

@@ -0,0 +1,277 @@
<div align="center">
<h1> VD-GR: Boosting Visual Dialog with Cascaded Spatial-Temporal Multi-Modal GRaphs </h1>
**[Adnen Abdessaied][5], &nbsp; [Lei Shi][6], &nbsp; [Andreas Bulling][7]** <br> <br>
**WACV'24, Hawaii, USA** <img src="misc/usa.png" width="3%" align="center"> <br>
**[[Paper][8]]**
-------------------
<img src="misc/teaser_1.png" width="100%" align="middle"><br><br>
</div>
# Table of Contents
* [Setup and Dependencies](#Setup-and-Dependencies)
* [Download Data](#Download-Data)
* [Pre-trained Checkpoints](#Pre-trained-Checkpoints)
* [Training](#Training)
* [Results](#Results)
# Setup and Dependencies
We implemented our model using Python 3.7 and PyTorch 1.11.0 (CUDA 11.3, CuDNN 8.2.0). We recommend setting up a virtual environment using Anaconda. <br>
1. Install [git lfs][1] on your system
2. Clone our repository to download the data, checkpoints, and code
```shell
git lfs install
git clone this_repo.git
```
3. Create a conda environment and install dependencies
```shell
conda create -n vdgr python=3.7
conda activate vdgr
conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch
conda install pyg -c pyg # 2.1.0
pip install pytorch-transformers
pip install pytorch_pretrained_bert
pip install pyhocon glog wandb lmdb
```
4. If you wish to speed up training, we recommend installing [apex][2] (a quick import check is sketched after the install commands)
```shell
git clone https://github.com/NVIDIA/apex
cd apex
# if pip >= 23.1 (ref: https://pip.pypa.io/en/stable/news/#v23-1) which supports multiple `--config-settings` with the same key...
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
# otherwise
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --global-option="--cpp_ext" --global-option="--cuda_ext" ./
cd ..
```
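A quick way to confirm the environment (and, if installed, apex) before moving on; the exact versions printed depend on your install:
```shell
python -c "import torch; print(torch.__version__, torch.version.cuda)"
python -c "import torch_geometric; print(torch_geometric.__version__)"
python -c "from apex import amp; print('apex OK')"  # only if you installed apex
```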
# Download Data
1. Download the extracted visual features of [VisDial][3] and set up all files we used in our work. We provide a shell script for convenience:
```shell
./setup_data.sh # Please make sure you have enough disk space
```
If everything was set up correctly, the ```data/``` folder should look like this (a quick sanity check is sketched after the listing):
```
├── history_adj_matrices
│   ├── test
│   │   ├── *.pkl
│   ├── train
│   │   ├── *.pkl
│   ├── val
│   │   ├── *.pkl
├── question_adj_matrices
│   ├── test
│   │   ├── *.pkl
│   ├── train
│   │   ├── *.pkl
│   ├── val
│   │   ├── *.pkl
├── img_adj_matrices
│   ├── *.pkl
├── parse_vocab.pkl
├── test_dense_mapping.json
├── tr_dense_mapping.json
├── val_dense_mapping.json
├── visdial_0.9_test.json
├── visdial_0.9_train.json
├── visdial_0.9_val.json
├── visdial_1.0_test.json
├── visdial_1.0_train_dense_annotations.json
├── visdial_1.0_train_dense.json
├── visdial_1.0_train.json
├── visdial_1.0_val_dense_annotations.json
├── visdial_1.0_val.json
├── visdialconv_dense_annotations.json
├── visdialconv.json
├── vispro_dense_annotations.json
└── vispro.json
```
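A rough way to check that the download finished (the exact counts depend on the splits, so treat this only as a sanity check):
```shell
ls data/*.json | wc -l                      # the JSON files listed above
ls data/img_adj_matrices | wc -l            # presumably one pickle per image
ls data/history_adj_matrices/train | wc -l  # presumably one pickle per training dialog
python -c "import pickle; print(len(pickle.load(open('data/parse_vocab.pkl','rb'))), 'parse relations')"
```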
# Pre-trained Checkpoints
For convenience, we provide checkpoints of our model after the warm-up training stage in ```ckpt/``` for both VisDial v1.0 and VisDial v0.9. <br>
These checkpoints will be downloaded with the code if you use ```git lfs```.
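If the LFS objects were not fetched during cloning (e.g. the clone was made before running ```git lfs install```), they can be pulled explicitly:
```shell
git lfs pull --include="ckpt/*"
```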
# Training
We trained our model on 8 Nvidia Tesla V100-32GB GPUs. The default hyperparameters in ```config/vdgr.conf``` and ```config/bert_base_6layer_6conect.json``` need to be adjusted if your setup differs from ours.
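For example, with fewer GPUs or less memory you can lower the per-GPU batch size in ```config/vdgr.conf```. The values below are only an illustration, not tuned settings; we assume ```batch_multiply``` acts as a gradient-accumulation factor, as the Phase 2 settings (```batch_size = 1```, ```batch_multiply = 10```) suggest:
```
# config/vdgr.conf (P1) -- illustrative values only
batch_size = 4        # per-GPU batch size (release default: 8)
batch_multiply = 2    # assumed accumulation factor (release default: 1)
```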
## Phase 1
### Training
1. In this phase, we train our model on VisDial v1.0 via
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
--model vdgr/P1 \
--mode train \
--tag K2_v1.0 \
--wandb_mode online \
--wandb_project your_wandb_project_name
```
⚠️ On a setup similar to ours, this will take roughly 20h to complete when training with apex.
2. To train on VisDial v0.9:
* Set ```visdial_version = 0.9``` in ```config/vdgr.conf```
* Set ```start_path = ckpt/vdgr_visdial_v0.9_after_warmup_K2.ckpt``` in ```config/vdgr.conf```
* Run
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
--model vdgr/P1 \
--mode train \
--tag K2_v0.9 \
--wandb_mode online \
--wandb_project your_wandb_project_name
```
### Inference
1. For inference on VisDial v1.0 val, VisDialConv, or VisPro:
* Set ```eval_dataset``` to one of ```{visdial, visdial_conv, visdial_vispro}``` in ```logs/vdgr/P1_K2_v1.0/code/config/vdgr.conf```
* Run
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
--model vdgr/P1 \
--mode eval \
--eval_dir logs/vdgr/P1_K2_v1.0 \
--wandb_mode offline
```
2. For inference on VisDial v0.9:
* Set ```eval_dataset = visdial``` in ```logs/vdgr/P1_K2_v0.9/code/config/vdgr.conf```
* Run
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
--model vdgr/P1 \
--mode eval \
--eval_dir logs/vdgr/P1_K2_v0.9 \
--wandb_mode offline
```
⚠️ This might take some time to finish as the testing data of VisDial v0.9 is large.
3. For inference on the VisDial v1.0 test split:
* Run
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
--model vdgr/P1 \
--mode predict \
--eval_dir logs/vdgr/P1_K2_v1.0 \
--wandb_mode offline
```
* The output file will be saved in ```output/```
## Phase 2
In this phase, we fine-tune on the dense annotations to improve the NDCG score (only supported for VisDial v1.0).
1. Run
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
--model vdgr/P2_CE \
--mode train \
--tag K2_v1.0_CE \
--wandb_mode online \
--wandb_project your_wandb_project_name
```
This will take roughly 3-4 hours to complete using the same setup as before and [DP][4] for training.
2. For inference on VisDial v1.0:
* Run:
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
--model vdgr/P2_CE \
--mode predict \
--eval_dir logs/vdgr/P1_K2_v1.0_CE \
--wandb_mode offline
```
* The output file will be saved in ```output/```
## Phase 3
### Training
In the final phase, we train an ensemble comprising 8 models, using ```K={1,2,3,4}``` and ```dense_loss={ce, listnet}``` (a launch sketch follows the per-K steps below).
For ```K=k```:
1. Set the value of ```num_v_gnn_layers, num_q_gnn_layers, num_h_gnn_layers``` to ```k```
2. Set ```start_path = ckpt/vdgr_visdial_v1.0_after_warmup_K[k].ckpt``` in ```config/vdgr.conf``` (P1)
3. Phase 1 training:
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
--model vdgr/P1 \
--mode train \
--tag K[k]_v1.0 \
--wandb_mode online \
--wandb_project your_wandb_project_name
```
4. Set ```start_path = logs/vdgr/P1_K[k]_v1.0/epoch_best.ckpt``` in ```config/vdgr.conf``` (P2)
5. Phase 2 training:
* Fine-tune with CE:
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
--model vdgr/P2_CE \
--mode train \
--tag K[k]_v1.0_CE \
--wandb_mode online \
--wandb_project your_wandb_project_name
```
* Fine-tune with LISTNET:
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
--model vdgr/P2_LISTNET \
--mode train \
--tag K[k]_v1.0_LISTNET \
--wandb_mode online \
--wandb_project your_wandb_project_name
```
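The per-```K``` runs above can be scripted once the corresponding config edits are in place. A minimal, sequential launch sketch (illustrative only; it does not edit the configs for you):
```shell
# illustrative sketch: run Phase 1 and both Phase 2 variants for a single K
# prerequisite: num_*_gnn_layers and start_path (P1) already set for this K
K=3
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
  --model vdgr/P1 --mode train --tag K${K}_v1.0 \
  --wandb_mode online --wandb_project your_wandb_project_name

# before the next step, set start_path = logs/vdgr/P1_K${K}_v1.0/epoch_best.ckpt (P2)
for LOSS in CE LISTNET; do
  CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
    --model vdgr/P2_${LOSS} --mode train --tag K${K}_v1.0_${LOSS} \
    --wandb_mode online --wandb_project your_wandb_project_name
done
```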
### Inference
1. For inference on VisDial v1.0 test:
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
--model vdgr/P2_[CE,LISTNET] \
--mode predict \
--eval_dir logs/vdgr/P2_K[1,2,3,4]_v1.0_[CE,LISTNET] \
--wandb_mode offline
```
2. Finally, merge the outputs of all models (the runs to be merged are listed in ```config/ensemble.conf```; see the note below)
```shell
python ensemble.py \
--exp test \
--mode predict
```
The output file will be saved in ```output/```
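Which runs get merged is controlled by ```config/ensemble.conf``` (included in this release). If you trained only a subset of the 8 models, the ```models``` list there can be trimmed, for example:
```
# config/ensemble.conf -- illustrative: merge only the CE-finetuned runs
models = [
    "P2_K1_v1.0_CE",
    "P2_K2_v1.0_CE",
    "P2_K3_v1.0_CE",
    "P2_K4_v1.0_CE"
]
# the processed list in the same file presumably needs to be shortened to match
```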
# Results
## VisDial v0.9
| Model | MRR | R@1 | R@5 | R@10 | Mean |
|:--------:|:---:|:---:|:---:|:----:|:----:|
| Prev. SOTA | 71.99 | 59.41 | 87.92 | 94.59 | 2.87 |
| VD-GR | **74.50** | **62.10** | **90.49** | **96.37** | **2.45** |
## VisDialConv
| Model | NDCG | MRR | R@1 | R@5 | R@10 | Mean |
|:--------:|:----:|:---:|:---:|:---:|:----:|:----:|
| Prev. SOTA | 61.72 | 61.79 | 48.95 | 77.50 | 86.71 | 4.72 |
| VD-GR | **67.09** | **66.82** | **54.47** | **81.71** | **91.44** | **3.54** |
## VisPro
| Model | NDCG | MRR | R@1 | R@5 | R@10 | Mean |
|:--------:|:----:|:---:|:---:|:---:|:----:|:----:|
| Prev. SOTA | 59.30 | 62.29 | 48.35 | 80.10 | 88.87 | 4.37 |
| VD-GR | **60.35** | **69.89** | **57.21** | **85.97** | **92.68** | **3.15** |
## VisDial v1.0 Val
| Model | NDCG | MRR | R@1 | R@5 | R@10 | Mean |
|:--------:|:----:|:---:|:---:|:---:|:----:|:----:|
| Prev. SOTA | 65.47 | 69.71 | 56.79 | 85.82 | 93.64 | 3.15 |
| VD-GR | 64.32 | **69.91** | **57.01** | **86.14** | **93.74** | **3.13** |
## VisDial v1.0 Test
| Model | NDCG | MRR | R@1 | R@5 | R@10 | Mean |
|:--------:|:----:|:---:|:---:|:---:|:----:|:----:|
| Prev. SOTA | 64.91 | 68.73 | 55.73 | 85.38 | 93.53 | 3.21 |
| VD-GR | 63.49 | 68.65 | 55.33 | **85.58** | **93.85** | **3.20** |
| ♣️ Prev. SOTA | 75.92 | 56.18 | 45.32 | 68.05 | 80.98 | 5.42 |
| ♣️ VD-GR | **75.95** | **58.30** | **46.55** | **71.45** | 84.52 | **5.32** |
| ♣️♦️ Prev. SOTA | 76.17 | 56.42 | 44.75 | 70.23 | 84.52 | 5.47 |
| ♣️♦️ VD-GR | **76.43** | 56.35 | **45.18** | 68.13 | 82.18 | 5.79 |
♣️ = Fine-tuning on dense annotations, ♦️ = Ensemble model
[1]: https://git-lfs.com/
[2]: https://github.com/NVIDIA/apex
[3]: https://visualdialog.org/
[4]: https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html
[5]: https://adnenabdessaied.de
[6]: https://www.perceptualui.org/people/shi/
[7]: https://www.perceptualui.org/people/bulling/
[8]: https://drive.google.com/file/d/1GT0WDinA_z5FdwVc_bWtyB-cwQkGIf7C/view?usp=sharing

ckpt/.gitkeep Normal file

config/bert_base_6layer_6conect.json Normal file

@@ -0,0 +1,40 @@
{
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"max_position_embeddings": 512,
"num_attention_heads": 12,
"num_hidden_layers": 12,
"type_vocab_size": 2,
"vocab_size": 30522,
"v_feature_size": 2048,
"v_target_size": 1601,
"v_hidden_size": 1024,
"v_num_hidden_layers": 6,
"v_num_attention_heads": 8,
"v_intermediate_size": 1024,
"bi_hidden_size": 1024,
"bi_num_attention_heads": 8,
"bi_intermediate_size": 1024,
"bi_attention_type": 1,
"v_attention_probs_dropout_prob": 0.1,
"v_hidden_act": "gelu",
"v_hidden_dropout_prob": 0.1,
"v_initializer_range": 0.02,
"pooling_method": "mul",
"v_biattention_id": [0, 1, 2, 3, 4, 5],
"t_biattention_id": [6, 7, 8, 9, 10, 11],
"gnn_act": "gelu",
"num_v_gnn_layers": 2,
"num_q_gnn_layers": 2,
"num_h_gnn_layers": 2,
"num_gnn_attention_heads": 4,
"gnn_dropout_prob": 0.1,
"v_gnn_edge_dim": 12,
"q_gnn_edge_dim": 48,
"v_gnn_ids": [0, 1, 2, 3, 4, 5],
"t_gnn_ids": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
}

config/ensemble.conf Normal file

@@ -0,0 +1,33 @@
test = {
split = test
skip_mrr_eval = true
# data
visdial_test_data = data/visdial_1.0_test.json
# directory
log_dir = logs/vdgr_ensemble
pred_dir = logs/vdgr
visdial_output_dir = visdial_output
processed = [
false,
false,
false,
false,
false,
false,
false,
false
]
models = [
"P2_K1_v1.0_CE",
"P2_K2_v1.0_CE",
"P2_K3_v1.0_CE",
"P2_K4_v1.0_CE",
"P2_K1_v1.0_LISTNET",
"P2_K2_v1.0_LISTNET",
"P2_K3_v1.0_LISTNET",
"P2_K4_v1.0_LISTNET",
]
}

config/vdgr.conf Normal file

@@ -0,0 +1,188 @@
# Phase 1
P1 {
use_cpu = false
visdial_version = 1.0
train_on_dense = false
metrics_to_maximize = mrr
# visdial data
visdial_image_feats = data/visdial_img_feat.lmdb
visdial_image_adj_matrices = data/img_adj_matrices
visdial_question_adj_matrices = data/question_adj_matrices
visdial_history_adj_matrices = data/history_adj_matrices
visdial_train = data/visdial_1.0_train.json
visdial_val = data/visdial_1.0_val.json
visdial_test = data/visdial_1.0_test.json
visdial_val_dense_annotations = data/visdial_1.0_val_dense_annotations.json
visdial_train_09 = data/visdial_0.9_train.json
visdial_val_09 = data/visdial_0.9_val.json
visdial_test_09 = data/visdial_0.9_test.json
visdialconv_val = data/visdial_conv.json
visdialconv_val_dense_annotations = data/visdialconv_dense_annotations.json
visdialvispro_val = data/vispro.json
visdialvispro_val_dense_annotations = data/vispro_dense_annotations.json
visdial_question_parse_vocab = data/parse_vocab.pkl
# init
start_path = ckpt/vdgr_visdial_v1.0_after_warmup_K2.ckpt
model_config = config/bert_base_6layer_6conect.json
# visdial training
freeze_vilbert = false
visdial_tot_rounds = 11
num_negative_samples = 1
sequences_per_image = 2
batch_size = 8
lm_loss_coeff = 1
nsp_loss_coeff = 1
img_loss_coeff = 1
batch_multiply = 1
use_trainval = false
dense_loss = ce
dense_loss_coeff = 0
dataloader_text_only = false
rlv_hst_only = false
rlv_hst_dense_round = false
# visdial model
mask_prob = 0.1
image_mask_prob = 0.1
max_seq_len = 256
num_options = 100
num_options_dense = 100
use_embedding = joint
# visdial evaluation
eval_visdial_on_test = true
eval_batch_size = 1
eval_line_batch_size = 200
skip_mrr_eval = false
skip_ndcg_eval = false
skip_visdial_eval = false
eval_visdial_every = 1
eval_dataset = visdial # choices = [visdial, visdial_conv, visdial_vispro]
continue_evaluation = false
eval_at_start = false
eval_before_training = false
initializer = normal
bert_cased = false
# restore ckpt
loads_best_ckpt = false
loads_ckpt = false
restarts = false
resets_max_metric = false
uses_new_optimizer = false
sets_new_lr = false
loads_start_path = false
# logging
random_seed = 42
next_logging_pct = 1.0
next_evaluating_pct = 50.0
max_ckpt_to_keep = 1
num_epochs = 20
early_stop_epoch = 5
skip_saving_ckpt = false
dp_type = apex
stack_gr_data = false
master_port = 5122
stop_epochs = -1
train_each_round = false
drop_last_answer = false
num_samples = -1
# predicting
predict_split = test
predict_each_round = false
predict_dense_round = false
num_test_dialogs = 8000
num_val_dialogs = 2064
save_score = false
# optimizer
reset_optim = none
learning_rate_bert = 5e-6
learning_rate_gnn = 2e-4
gnn_weight_decay = 0.01
use_diff_lr_gnn = true
min_lr = 0
decay_method_bert = linear
decay_method_gnn = linear
decay_exp = 2
max_grad_norm = 1.0
task_optimizer = adam
warmup_ratio = 0.1
# directory
log_dir = logs/vdgr
data_dir = data
visdial_output_dir = visdial_output
bert_cache_dir = transformers
# keep track of other hparams in bert json
v_gnn_edge_dim = 12 # 11 classes + hub_node
q_gnn_edge_dim = 48 # 47 classes + hub_node
num_v_gnn_layers = 2
num_q_gnn_layers = 2
num_h_gnn_layers = 2
num_gnn_attention_heads = 4
v_gnn_ids = [0, 1, 2, 3, 4, 5]
t_gnn_ids = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
}
# Phase 2
P2_CE = ${P1} {
# basic
train_on_dense = true
use_trainval = true
metrics_to_maximize = ndcg
visdial_train_dense = data/visdial_1.0_train_dense.json
visdial_train_dense_annotations = data/visdial_1.0_train_dense_annotations.json
visdial_val_dense = data/visdial_1.0_val.json
tr_graph_idx_mapping = data/tr_dense_mapping.json
val_graph_idx_mapping = data/val_dense_mapping.json
test_graph_idx_mapping = data/test_dense_mapping.json
visdial_val = data/visdial_1.0_val.json
visdial_val_dense_annotations = data/visdial_1.0_val_dense_annotations.json
# data
start_path = logs/vdgr/P1_K2_v1.0/epoch_best.ckpt
rlv_hst_only = false
# visdial training
nsp_loss_coeff = 0
dense_loss_coeff = 1
batch_multiply = 10
batch_size = 1
# visdial model
num_options_dense = 100
# visdial evaluation
eval_batch_size = 1
eval_line_batch_size = 100
skip_mrr_eval = true
# training
stop_epochs = 3
dp_type = dp
dense_loss = ce
# optimizer
learning_rate_bert = 1e-4
}
P2_LISTNET = ${P2_CE} {
dense_loss = listnet
}
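The configs above use HOCON inheritance: ```P2_CE``` extends ```P1```, and ```P2_LISTNET``` extends ```P2_CE``` and only swaps the dense loss. A small inspection snippet (assuming the ```pyhocon``` package from the install step) to print the resolved values:
```python
# illustrative inspection snippet, not part of the release code
from pyhocon import ConfigFactory

conf = ConfigFactory.parse_file("config/vdgr.conf")
for phase in ("P1", "P2_CE", "P2_LISTNET"):
    print(phase, conf[phase]["dense_loss"], conf[phase]["num_v_gnn_layers"])
```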

data/.gitkeep Normal file

dataloader/__init__.py Normal file

dataloader/dataloader_base.py Normal file

@@ -0,0 +1,269 @@
import torch
from torch.utils import data
import json
import os
import glog as log
import pickle
import torch.utils.data as tud
from pytorch_transformers.tokenization_bert import BertTokenizer
from utils.image_features_reader import ImageFeaturesH5Reader
class DatasetBase(data.Dataset):
def __init__(self, config):
if config['display']:
log.info('Initializing dataset')
# Fetch the correct dataset for evaluation
if config['validating']:
assert config.eval_dataset in ['visdial', 'visdial_conv', 'visdial_vispro', 'visdial_v09']
if config.eval_dataset == 'visdial_conv':
config['visdial_val'] = config.visdialconv_val
config['visdial_val_dense_annotations'] = config.visdialconv_val_dense_annotations
elif config.eval_dataset == 'visdial_vispro':
config['visdial_val'] = config.visdialvispro_val
config['visdial_val_dense_annotations'] = config.visdialvispro_val_dense_annotations
elif config.eval_dataset == 'visdial_v09':
config['visdial_val_09'] = config.visdial_test_09
config['visdial_val_dense_annotations'] = None
self.config = config
self.numDataPoints = {}
if not config['dataloader_text_only']:
self._image_features_reader = ImageFeaturesH5Reader(
config['visdial_image_feats'],
config['visdial_image_adj_matrices']
)
if self.config['training'] or self.config['validating'] or self.config['predicting']:
split2data = {'train': 'train', 'val': 'val', 'test': 'test'}
elif self.config['debugging']:
split2data = {'train': 'val', 'val': 'val', 'test': 'test'}
elif self.config['visualizing']:
split2data = {'train': 'train', 'val': 'train', 'test': 'test'}
filename = f'visdial_{split2data["train"]}'
if config['train_on_dense']:
filename += '_dense'
if self.config['visdial_version'] == 0.9:
filename += '_09'
with open(config[filename]) as f:
self.visdial_data_train = json.load(f)
if self.config.num_samples > 0:
self.visdial_data_train['data']['dialogs'] = self.visdial_data_train['data']['dialogs'][:self.config.num_samples]
self.numDataPoints['train'] = len(self.visdial_data_train['data']['dialogs'])
filename = f'visdial_{split2data["val"]}'
if config['train_on_dense'] and config['training']:
filename += '_dense'
if self.config['visdial_version'] == 0.9:
filename += '_09'
with open(config[filename]) as f:
self.visdial_data_val = json.load(f)
if self.config.num_samples > 0:
self.visdial_data_val['data']['dialogs'] = self.visdial_data_val['data']['dialogs'][:self.config.num_samples]
self.numDataPoints['val'] = len(self.visdial_data_val['data']['dialogs'])
if config['train_on_dense']:
self.numDataPoints['trainval'] = self.numDataPoints['train'] + self.numDataPoints['val']
with open(config[f'visdial_{split2data["test"]}']) as f:
self.visdial_data_test = json.load(f)
self.numDataPoints['test'] = len(self.visdial_data_test['data']['dialogs'])
self.rlv_hst_train = None
self.rlv_hst_val = None
self.rlv_hst_test = None
if config['train_on_dense'] or config['predict_dense_round']:
with open(config[f'visdial_{split2data["train"]}_dense_annotations']) as f:
self.visdial_data_train_dense = json.load(f)
if config['train_on_dense']:
self.subsets = ['train', 'val', 'trainval', 'test']
else:
self.subsets = ['train','val','test']
self.num_options = config["num_options"]
self.num_options_dense = config["num_options_dense"]
if config['visdial_version'] != 0.9:
with open(config[f'visdial_{split2data["val"]}_dense_annotations']) as f:
self.visdial_data_val_dense = json.load(f)
else:
self.visdial_data_val_dense = None
self._split = 'train'
self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', cache_dir=config['bert_cache_dir'])
# fetch the token indices of [CLS], [MASK], and [SEP]
tokens = ['[CLS]','[MASK]','[SEP]']
indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokens)
self.CLS = indexed_tokens[0]
self.MASK = indexed_tokens[1]
self.SEP = indexed_tokens[2]
self._max_region_num = 37
self.predict_each_round = self.config['predicting'] and self.config['predict_each_round']
self.keys_to_expand = ['image_feat', 'image_loc', 'image_mask', 'image_target', 'image_label']
self.keys_to_flatten_1d = ['hist_len', 'next_sentence_labels', 'round_id', 'image_id']
self.keys_to_flatten_2d = ['tokens', 'segments', 'sep_indices', 'mask', 'image_mask', 'image_label', 'input_mask', 'question_limits']
self.keys_to_flatten_3d = ['image_feat', 'image_loc', 'image_target', ]
self.keys_other = ['gt_relevance', 'gt_option_inds']
self.keys_lists_to_flatten = ['image_edge_indices', 'image_edge_attributes', 'question_edge_indices', 'question_edge_attributes', 'history_edge_indices', 'history_sep_indices']
if config['stack_gr_data']:
self.keys_to_flatten_3d.extend(self.keys_lists_to_flatten[:-1])
self.keys_to_flatten_2d.append(self.keys_lists_to_flatten[-1])
self.keys_to_flatten_1d.extend(['len_image_gr', 'len_question_gr', 'len_history_gr', 'len_history_sep'])
self.keys_lists_to_flatten = []
self.keys_to_list = ['tot_len']
# Load the parse vocab for question graph relationship mapping
if os.path.isfile(config['visdial_question_parse_vocab']):
with open(config['visdial_question_parse_vocab'], 'rb') as f:
self.parse_vocab = pickle.load(f)
def __len__(self):
return self.numDataPoints[self._split]
@property
def split(self):
return self._split
@split.setter
def split(self, split):
assert split in self.subsets
self._split = split
def tokens2str(self, seq):
dialog_sequence = ''
for sentence in seq:
for word in sentence:
dialog_sequence += self.tokenizer._convert_id_to_token(word) + " "
dialog_sequence += ' </end> '
dialog_sequence = dialog_sequence.encode('utf8')
return dialog_sequence
def pruneRounds(self, context, num_rounds):
start_segment = 1
len_context = len(context)
cur_rounds = (len(context) // 2) + 1
l_index = 0
if cur_rounds > num_rounds:
# caption is not part of the final input
l_index = len_context - (2 * num_rounds)
start_segment = 0
return context[l_index:], start_segment
def tokenize_utterance(self, sent, sentences, tot_len, sentence_count, sentence_map, speakers):
sentences.extend(sent + ['[SEP]'])
tokenized_sent = self.tokenizer.convert_tokens_to_ids(sent)
assert len(sent) == len(tokenized_sent), 'sub-word tokens are not allowed!'
sent_len = len(tokenized_sent)
tot_len += sent_len + 1 # the additional 1 is for the sep token
sentence_count += 1
sentence_map.extend([sentence_count * 2 - 1] * sent_len)
sentence_map.append(sentence_count * 2) # for [SEP]
speakers.extend([2] * (sent_len + 1))
return tokenized_sent, sentences, tot_len, sentence_count, sentence_map, speakers
def __getitem__(self, index):
raise NotImplementedError
def collate_fn(self, batch):
tokens_size = batch[0]['tokens'].size()
num_rounds, num_samples = tokens_size[0], tokens_size[1]
merged_batch = {key: [d[key] for d in batch] for key in batch[0]}
if self.config['stack_gr_data']:
if (len(batch)) > 1:
max_question_gr_len = max([length.max().item() for length in merged_batch['len_question_gr']])
max_history_gr_len = max([length.max().item() for length in merged_batch['len_history_gr']])
max_history_sep_len = max([length.max().item() for length in merged_batch['len_history_sep']])
max_image_gr_len = max([length.max().item() for length in merged_batch['len_image_gr']])
question_edge_indices_padded = []
question_edge_attributes_padded = []
for q_e_idx, q_e_attr in zip(merged_batch['question_edge_indices'], merged_batch['question_edge_attributes']):
b_size, edge_dim, orig_len = q_e_idx.size()
q_e_idx_padded = torch.zeros((b_size, edge_dim, max_question_gr_len))
q_e_idx_padded[:, :, :orig_len] = q_e_idx
question_edge_indices_padded.append(q_e_idx_padded)
edge_attr_dim = q_e_attr.size(-1)
q_e_attr_padded = torch.zeros((b_size, max_question_gr_len, edge_attr_dim))
q_e_attr_padded[:, :orig_len, :] = q_e_attr
question_edge_attributes_padded.append(q_e_attr_padded)
merged_batch['question_edge_indices'] = question_edge_indices_padded
merged_batch['question_edge_attributes'] = question_edge_attributes_padded
history_edge_indices_padded = []
for h_e_idx in merged_batch['history_edge_indices']:
b_size, _, orig_len = h_e_idx.size()
h_edge_idx_padded = torch.zeros((b_size, 2, max_history_gr_len))
h_edge_idx_padded[:, :, :orig_len] = h_e_idx
history_edge_indices_padded.append(h_edge_idx_padded)
merged_batch['history_edge_indices'] = history_edge_indices_padded
history_sep_indices_padded = []
for hist_sep_idx in merged_batch['history_sep_indices']:
b_size, orig_len = hist_sep_idx.size()
hist_sep_idx_padded = torch.zeros((b_size, max_history_sep_len))
hist_sep_idx_padded[:, :orig_len] = hist_sep_idx
history_sep_indices_padded.append(hist_sep_idx_padded)
merged_batch['history_sep_indices'] = history_sep_indices_padded
image_edge_indices_padded = []
image_edge_attributes_padded = []
for img_e_idx, img_e_attr in zip(merged_batch['image_edge_indices'], merged_batch['image_edge_attributes']):
b_size, edge_dim, orig_len = img_e_idx.size()
img_e_idx_padded = torch.zeros((b_size, edge_dim, max_image_gr_len))
img_e_idx_padded[:, :, :orig_len] = img_e_idx
image_edge_indices_padded.append(img_e_idx_padded)
edge_attr_dim = img_e_attr.size(-1)
img_e_attr_padded = torch.zeros((b_size, max_image_gr_len, edge_attr_dim))
img_e_attr_padded[:, :orig_len, :] = img_e_attr
image_edge_attributes_padded.append(img_e_attr_padded)
merged_batch['image_edge_indices'] = image_edge_indices_padded
merged_batch['image_edge_attributes'] = image_edge_attributes_padded
out = {}
for key in merged_batch:
if key in self.keys_lists_to_flatten:
temp = []
for b in merged_batch[key]:
for x in b:
temp.append(x)
merged_batch[key] = temp
elif key in self.keys_to_list:
pass
else:
merged_batch[key] = torch.stack(merged_batch[key], 0)
if key in self.keys_to_expand:
if len(merged_batch[key].size()) == 3:
size0, size1, size2 = merged_batch[key].size()
expand_size = (size0, num_rounds, num_samples, size1, size2)
elif len(merged_batch[key].size()) == 2:
size0, size1 = merged_batch[key].size()
expand_size = (size0, num_rounds, num_samples, size1)
merged_batch[key] = merged_batch[key].unsqueeze(1).unsqueeze(1).expand(expand_size).contiguous()
if key in self.keys_to_flatten_1d:
merged_batch[key] = merged_batch[key].reshape(-1)
elif key in self.keys_to_flatten_2d:
merged_batch[key] = merged_batch[key].reshape(-1, merged_batch[key].shape[-1])
elif key in self.keys_to_flatten_3d:
merged_batch[key] = merged_batch[key].reshape(-1, merged_batch[key].shape[-2], merged_batch[key].shape[-1])
else:
assert key in self.keys_other, f'unrecognized key in collate_fn: {key}'
out[key] = merged_batch[key]
return out


@@ -0,0 +1,615 @@
import torch
import os
import numpy as np
import random
import pickle
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from utils.data_utils import encode_input, encode_input_with_mask, encode_image_input
from dataloader.dataloader_base import DatasetBase
class VisdialDataset(DatasetBase):
def __init__(self, config):
super(VisdialDataset, self).__init__(config)
def __getitem__(self, index):
MAX_SEQ_LEN = self.config['max_seq_len']
cur_data = None
if self._split == 'train':
cur_data = self.visdial_data_train['data']
ques_adj_matrices_dir = os.path.join(self.config['visdial_question_adj_matrices'], 'train')
hist_adj_matrices_dir = os.path.join(self.config['visdial_history_adj_matrices'], 'train')
elif self._split == 'val':
cur_data = self.visdial_data_val['data']
ques_adj_matrices_dir = os.path.join(self.config['visdial_question_adj_matrices'], 'val')
hist_adj_matrices_dir = os.path.join(self.config['visdial_history_adj_matrices'], 'val')
else:
cur_data = self.visdial_data_test['data']
ques_adj_matrices_dir = os.path.join(self.config['visdial_question_adj_matrices'], 'test')
hist_adj_matrices_dir = os.path.join(self.config['visdial_history_adj_matrices'], 'test')
if self.config['visdial_version'] == 0.9:
ques_adj_matrices_dir = os.path.join(self.config['visdial_question_adj_matrices'], 'train')
hist_adj_matrices_dir = os.path.join(self.config['visdial_history_adj_matrices'], 'train')
self.num_bad_samples = 0
# number of options to score on
num_options = self.num_options
assert num_options > 1 and num_options <= 100
num_dialog_rounds = 10
dialog = cur_data['dialogs'][index]
cur_questions = cur_data['questions']
cur_answers = cur_data['answers']
img_id = dialog['image_id']
graph_idx = dialog.get('dialog_idx', index)
if self._split == 'train':
# caption
sent = dialog['caption'].split(' ')
sentences = ['[CLS]']
tot_len = 1 # for the CLS token
sentence_map = [0] # for the CLS token
sentence_count = 0
speakers = [0]
tokenized_sent, sentences, tot_len, sentence_count, sentence_map, speakers = \
self.tokenize_utterance(sent, sentences, tot_len, sentence_count, sentence_map, speakers)
utterances = [[tokenized_sent]]
utterances_random = [[tokenized_sent]]
for rnd, utterance in enumerate(dialog['dialog']):
cur_rnd_utterance = utterances[-1].copy()
cur_rnd_utterance_random = utterances[-1].copy()
# question
sent = cur_questions[utterance['question']].split(' ')
tokenized_sent, sentences, tot_len, sentence_count, sentence_map, speakers = \
self.tokenize_utterance(sent, sentences, tot_len, sentence_count, sentence_map, speakers)
cur_rnd_utterance.append(tokenized_sent)
cur_rnd_utterance_random.append(tokenized_sent)
# answer
sent = cur_answers[utterance['answer']].split(' ')
tokenized_sent, sentences, tot_len, sentence_count, sentence_map, speakers = \
self.tokenize_utterance(sent, sentences, tot_len, sentence_count, sentence_map, speakers)
cur_rnd_utterance.append(tokenized_sent)
utterances.append(cur_rnd_utterance)
# randomly sample negative answer options for that round
num_inds = len(utterance['answer_options'])
gt_option_ind = utterance['gt_index']
negative_samples = []
for _ in range(self.config["num_negative_samples"]):
all_inds = list(range(100))
all_inds.remove(gt_option_ind)
all_inds = all_inds[:(num_options-1)]
tokenized_random_utterance = None
option_ind = None
while len(all_inds):
option_ind = random.choice(all_inds)
tokenized_random_utterance = self.tokenizer.convert_tokens_to_ids(cur_answers[utterance['answer_options'][option_ind]].split(' '))
# the 1 here is for the sep token at the end of each utterance
if(MAX_SEQ_LEN >= (tot_len + len(tokenized_random_utterance) + 1)):
break
else:
all_inds.remove(option_ind)
if len(all_inds) == 0:
# all the options exceed the max len. Truncate the last utterance in this case.
tokenized_random_utterance = tokenized_random_utterance[:len(tokenized_sent)]
t = cur_rnd_utterance_random.copy()
t.append(tokenized_random_utterance)
negative_samples.append(t)
utterances_random.append(negative_samples)
# removing the caption in the beginning
utterances = utterances[1:]
utterances_random = utterances_random[1:]
assert len(utterances) == len(utterances_random) == num_dialog_rounds
assert tot_len <= MAX_SEQ_LEN, '{} {} tot_len = {} > max_seq_len'.format(
self._split, index, tot_len
)
tokens_all = []
question_limits_all = []
question_edge_indices_all = []
question_edge_attributes_all = []
history_edge_indices_all = []
history_sep_indices_all = []
mask_all = []
segments_all = []
sep_indices_all = []
next_labels_all = []
hist_len_all = []
# randomly pick several rounds to train
pos_rounds = sorted(random.sample(range(num_dialog_rounds), self.config['sequences_per_image'] // 2), reverse=True)
neg_rounds = sorted(random.sample(range(num_dialog_rounds), self.config['sequences_per_image'] // 2), reverse=True)
tokens_all_rnd = []
question_limits_all_rnd = []
mask_all_rnd = []
segments_all_rnd = []
sep_indices_all_rnd = []
next_labels_all_rnd = []
hist_len_all_rnd = []
for j in pos_rounds:
context = utterances[j]
context, start_segment = self.pruneRounds(context, self.config['visdial_tot_rounds'])
if j == pos_rounds[0]: # dialog with positive label and max rounds
tokens, segments, sep_indices, mask, input_mask, start_question, end_question = encode_input_with_mask(context, start_segment, self.CLS,
self.SEP, self.MASK, max_seq_len=MAX_SEQ_LEN, mask_prob=self.config["mask_prob"])
else:
tokens, segments, sep_indices, mask, start_question, end_question = encode_input(context, start_segment, self.CLS,
self.SEP, self.MASK, max_seq_len=MAX_SEQ_LEN, mask_prob=self.config["mask_prob"])
tokens_all_rnd.append(tokens)
question_limits_all_rnd.append(torch.tensor([start_question, end_question]))
mask_all_rnd.append(mask)
sep_indices_all_rnd.append(sep_indices)
next_labels_all_rnd.append(torch.LongTensor([0]))
segments_all_rnd.append(segments)
hist_len_all_rnd.append(torch.LongTensor([len(context)-1]))
tokens_all.append(torch.cat(tokens_all_rnd,0).unsqueeze(0))
mask_all.append(torch.cat(mask_all_rnd,0).unsqueeze(0))
question_limits_all.extend(question_limits_all_rnd)
segments_all.append(torch.cat(segments_all_rnd, 0).unsqueeze(0))
sep_indices_all.append(torch.cat(sep_indices_all_rnd, 0).unsqueeze(0))
next_labels_all.append(torch.cat(next_labels_all_rnd, 0).unsqueeze(0))
hist_len_all.append(torch.cat(hist_len_all_rnd,0).unsqueeze(0))
assert len(pos_rounds) == 1
question_graphs = pickle.load(
open(os.path.join(ques_adj_matrices_dir, f'{graph_idx}.pkl'), 'rb')
)
question_graph_pos = question_graphs[pos_rounds[0]]
question_edge_index_pos = []
question_edge_attribute_pos = []
for edge_idx, edge_attr in question_graph_pos:
question_edge_index_pos.append(edge_idx)
edge_attr_one_hot = np.zeros((len(self.parse_vocab) + 1,), dtype=np.float32)
edge_attr_one_hot[self.parse_vocab.get(edge_attr, len(self.parse_vocab))] = 1.0
question_edge_attribute_pos.append(edge_attr_one_hot)
question_edge_index_pos = np.array(question_edge_index_pos, dtype=np.float64)
question_edge_attribute_pos = np.stack(question_edge_attribute_pos, axis=0)
question_edge_indices_all.append(
torch.from_numpy(question_edge_index_pos).t().long().contiguous()
)
question_edge_attributes_all.append(
torch.from_numpy(question_edge_attribute_pos)
)
history_edge_indices = pickle.load(
open(os.path.join(hist_adj_matrices_dir, f'{graph_idx}.pkl'), 'rb')
)
history_edge_indices_all.append(
torch.tensor(history_edge_indices[pos_rounds[0]]).t().long().contiguous()
)
# Get the [SEP] tokens that will represent the history graph node features
hist_idx_pos = [i * 2 for i in range(pos_rounds[0] + 1)]
sep_indices = sep_indices.squeeze(0).numpy()
history_sep_indices_all.append(torch.from_numpy(sep_indices[hist_idx_pos]))
if len(neg_rounds) > 0:
tokens_all_rnd = []
question_limits_all_rnd = []
mask_all_rnd = []
segments_all_rnd = []
sep_indices_all_rnd = []
next_labels_all_rnd = []
hist_len_all_rnd = []
for j in neg_rounds:
negative_samples = utterances_random[j]
for context_random in negative_samples:
context_random, start_segment = self.pruneRounds(context_random, self.config['visdial_tot_rounds'])
tokens_random, segments_random, sep_indices_random, mask_random, start_question, end_question = encode_input(context_random, start_segment, self.CLS,
self.SEP, self.MASK, max_seq_len=MAX_SEQ_LEN, mask_prob=self.config["mask_prob"])
tokens_all_rnd.append(tokens_random)
question_limits_all_rnd.append(torch.tensor([start_question, end_question]))
mask_all_rnd.append(mask_random)
sep_indices_all_rnd.append(sep_indices_random)
next_labels_all_rnd.append(torch.LongTensor([1]))
segments_all_rnd.append(segments_random)
hist_len_all_rnd.append(torch.LongTensor([len(context_random)-1]))
tokens_all.append(torch.cat(tokens_all_rnd,0).unsqueeze(0))
mask_all.append(torch.cat(mask_all_rnd,0).unsqueeze(0))
question_limits_all.extend(question_limits_all_rnd)
segments_all.append(torch.cat(segments_all_rnd, 0).unsqueeze(0))
sep_indices_all.append(torch.cat(sep_indices_all_rnd, 0).unsqueeze(0))
next_labels_all.append(torch.cat(next_labels_all_rnd, 0).unsqueeze(0))
hist_len_all.append(torch.cat(hist_len_all_rnd,0).unsqueeze(0))
assert len(neg_rounds) == 1
question_graph_neg = question_graphs[neg_rounds[0]]
question_edge_index_neg = []
question_edge_attribute_neg = []
for edge_idx, edge_attr in question_graph_neg:
question_edge_index_neg.append(edge_idx)
edge_attr_one_hot = np.zeros((len(self.parse_vocab) + 1,), dtype=np.float32)
edge_attr_one_hot[self.parse_vocab.get(edge_attr, len(self.parse_vocab))] = 1.0
question_edge_attribute_neg.append(edge_attr_one_hot)
question_edge_index_neg = np.array(question_edge_index_neg, dtype=np.float64)
question_edge_attribute_neg = np.stack(question_edge_attribute_neg, axis=0)
question_edge_indices_all.append(
torch.from_numpy(question_edge_index_neg).t().long().contiguous()
)
question_edge_attributes_all.append(
torch.from_numpy(question_edge_attribute_neg)
)
history_edge_indices_all.append(
torch.tensor(history_edge_indices[neg_rounds[0]]).t().long().contiguous()
)
# Get the [SEP] tokens that will represent the history graph node features
hist_idx_neg = [i * 2 for i in range(neg_rounds[0] + 1)]
sep_indices_random = sep_indices_random.squeeze(0).numpy()
history_sep_indices_all.append(torch.from_numpy(sep_indices_random[hist_idx_neg]))
tokens_all = torch.cat(tokens_all, 0) # [2, num_pos, max_len]
question_limits_all = torch.stack(question_limits_all, 0) # [2, 2]
mask_all = torch.cat(mask_all,0)
segments_all = torch.cat(segments_all, 0)
sep_indices_all = torch.cat(sep_indices_all, 0)
next_labels_all = torch.cat(next_labels_all, 0)
hist_len_all = torch.cat(hist_len_all, 0)
input_mask_all = torch.LongTensor(input_mask) # [max_len]
item = {}
item['tokens'] = tokens_all
item['question_limits'] = question_limits_all
item['question_edge_indices'] = question_edge_indices_all
item['question_edge_attributes'] = question_edge_attributes_all
item['history_edge_indices'] = history_edge_indices_all
item['history_sep_indices'] = history_sep_indices_all
item['segments'] = segments_all
item['sep_indices'] = sep_indices_all
item['mask'] = mask_all
item['next_sentence_labels'] = next_labels_all
item['hist_len'] = hist_len_all
item['input_mask'] = input_mask_all
# get image features
if not self.config['dataloader_text_only']:
features, num_boxes, boxes, _ , image_target, image_edge_indexes, image_edge_attributes = self._image_features_reader[img_id]
features, spatials, image_mask, image_target, image_label = encode_image_input(features, num_boxes, boxes, image_target, max_regions=self._max_region_num)
else:
features = spatials = image_mask = image_target = image_label = torch.tensor([0])
elif self._split == 'val':
gt_relevance = None
gt_option_inds = []
options_all = []
# caption
sent = dialog['caption'].split(' ')
sentences = ['[CLS]']
tot_len = 1 # for the CLS token
sentence_map = [0] # for the CLS token
sentence_count = 0
speakers = [0]
tokenized_sent, sentences, tot_len, sentence_count, sentence_map, speakers = \
self.tokenize_utterance(sent, sentences, tot_len, sentence_count, sentence_map, speakers)
utterances = [[tokenized_sent]]
for rnd, utterance in enumerate(dialog['dialog']):
cur_rnd_utterance = utterances[-1].copy()
# question
sent = cur_questions[utterance['question']].split(' ')
tokenized_sent, sentences, tot_len, sentence_count, sentence_map, speakers = \
self.tokenize_utterance(sent, sentences, tot_len, sentence_count, sentence_map, speakers)
cur_rnd_utterance.append(tokenized_sent)
# current round
gt_option_ind = utterance['gt_index']
# first select gt option id, then choose the first num_options inds
option_inds = []
option_inds.append(gt_option_ind)
all_inds = list(range(100))
all_inds.remove(gt_option_ind)
all_inds = all_inds[:(num_options-1)]
option_inds.extend(all_inds)
gt_option_inds.append(0)
cur_rnd_options = []
answer_options = [utterance['answer_options'][k] for k in option_inds]
assert len(answer_options) == len(option_inds) == num_options
assert answer_options[0] == utterance['answer']
# for evaluation of all options and dense relevance
if self.visdial_data_val_dense:
if rnd == self.visdial_data_val_dense[index]['round_id'] - 1:
# only 1 round has gt_relevance for each example
if 'relevance' in self.visdial_data_val_dense[index]:
gt_relevance = torch.Tensor(self.visdial_data_val_dense[index]['relevance'])
else:
gt_relevance = torch.Tensor(self.visdial_data_val_dense[index]['gt_relevance'])
# shuffle based on new indices
gt_relevance = gt_relevance[torch.LongTensor(option_inds)]
else:
gt_relevance = -1
for answer_option in answer_options:
cur_rnd_cur_option = cur_rnd_utterance.copy()
cur_rnd_cur_option.append(self.tokenizer.convert_tokens_to_ids(cur_answers[answer_option].split(' ')))
cur_rnd_options.append(cur_rnd_cur_option)
# answer
sent = cur_answers[utterance['answer']].split(' ')
tokenized_sent, sentences, tot_len, sentence_count, sentence_map, speakers = \
self.tokenize_utterance(sent, sentences, tot_len, sentence_count, sentence_map, speakers)
cur_rnd_utterance.append(tokenized_sent)
utterances.append(cur_rnd_utterance)
options_all.append(cur_rnd_options)
# encode the input and create batch x 10 x 100 x max_len arrays (batch x num_rounds x num_options)
tokens_all = []
question_limits_all = []
mask_all = []
segments_all = []
sep_indices_all = []
hist_len_all = []
history_sep_indices_all = []
for rnd, cur_rnd_options in enumerate(options_all):
tokens_all_rnd = []
mask_all_rnd = []
segments_all_rnd = []
sep_indices_all_rnd = []
hist_len_all_rnd = []
for j, cur_rnd_option in enumerate(cur_rnd_options):
cur_rnd_option, start_segment = self.pruneRounds(cur_rnd_option, self.config['visdial_tot_rounds'])
if rnd == len(options_all) - 1 and j == 0: # gt dialog
tokens, segments, sep_indices, mask, input_mask, start_question, end_question = encode_input_with_mask(cur_rnd_option, start_segment, self.CLS,
self.SEP, self.MASK, max_seq_len=MAX_SEQ_LEN, mask_prob=0)
else:
tokens, segments, sep_indices, mask, start_question, end_question = encode_input(cur_rnd_option, start_segment,self.CLS,
self.SEP, self.MASK ,max_seq_len=MAX_SEQ_LEN, mask_prob=0)
tokens_all_rnd.append(tokens)
mask_all_rnd.append(mask)
segments_all_rnd.append(segments)
sep_indices_all_rnd.append(sep_indices)
hist_len_all_rnd.append(torch.LongTensor([len(cur_rnd_option)-1]))
question_limits_all.append(torch.tensor([start_question, end_question]).unsqueeze(0).repeat(100, 1))
tokens_all.append(torch.cat(tokens_all_rnd,0).unsqueeze(0))
mask_all.append(torch.cat(mask_all_rnd,0).unsqueeze(0))
segments_all.append(torch.cat(segments_all_rnd,0).unsqueeze(0))
sep_indices_all.append(torch.cat(sep_indices_all_rnd,0).unsqueeze(0))
hist_len_all.append(torch.cat(hist_len_all_rnd,0).unsqueeze(0))
# Get the [SEP] tokens that will represent the history graph node features
# It will be the same for all answer candidates as the history does not change
# for each answer
hist_idx = [i * 2 for i in range(rnd + 1)]
history_sep_indices_all.extend(sep_indices.squeeze(0)[hist_idx].contiguous() for _ in range(100))
tokens_all = torch.cat(tokens_all, 0) # [10, 100, max_len]
mask_all = torch.cat(mask_all, 0)
segments_all = torch.cat(segments_all, 0)
sep_indices_all = torch.cat(sep_indices_all, 0)
hist_len_all = torch.cat(hist_len_all, 0)
input_mask_all = torch.LongTensor(input_mask) # [max_len]
# load graph data
question_limits_all = torch.stack(question_limits_all, 0) # [10, 100, 2]
question_graphs = pickle.load(
open(os.path.join(ques_adj_matrices_dir, f'{graph_idx}.pkl'), 'rb')
)
question_edge_indices_all = [] # one question graph per round, repeated below for each of the 100 answer options
question_edge_attributes_all = [] # same layout as the edge indices
for q_graph_round in question_graphs:
question_edge_index = []
question_edge_attribute = []
for edge_index, edge_attr in q_graph_round:
question_edge_index.append(edge_index)
edge_attr_one_hot = np.zeros((len(self.parse_vocab) + 1,), dtype=np.float32)
edge_attr_one_hot[self.parse_vocab.get(edge_attr, len(self.parse_vocab))] = 1.0
question_edge_attribute.append(edge_attr_one_hot)
question_edge_index = np.array(question_edge_index, dtype=np.float64)
question_edge_attribute = np.stack(question_edge_attribute, axis=0)
question_edge_indices_all.extend(
[torch.from_numpy(question_edge_index).t().long().contiguous() for _ in range(100)])
question_edge_attributes_all.extend(
[torch.from_numpy(question_edge_attribute).contiguous() for _ in range(100)])
_history_edge_incides_all = pickle.load(
open(os.path.join(hist_adj_matrices_dir, f'{graph_idx}.pkl'), 'rb')
)
history_edge_incides_all = []
for hist_edge_indices_rnd in _history_edge_incides_all:
history_edge_incides_all.extend(
[torch.tensor(hist_edge_indices_rnd).t().long().contiguous() for _ in range(100)]
)
item = {}
item['tokens'] = tokens_all
item['segments'] = segments_all
item['sep_indices'] = sep_indices_all
item['mask'] = mask_all
item['hist_len'] = hist_len_all
item['input_mask'] = input_mask_all
item['gt_option_inds'] = torch.LongTensor(gt_option_inds)
# return dense annotation data as well
if self.visdial_data_val_dense:
item['round_id'] = torch.LongTensor([self.visdial_data_val_dense[index]['round_id']])
item['gt_relevance'] = gt_relevance
item['question_limits'] = question_limits_all
item['question_edge_indices'] = question_edge_indices_all
item['question_edge_attributes'] = question_edge_attributes_all
item['history_edge_indices'] = history_edge_incides_all
item['history_sep_indices'] = history_sep_indices_all
# get image features
if not self.config['dataloader_text_only']:
features, num_boxes, boxes, _ , image_target, image_edge_indexes, image_edge_attributes = self._image_features_reader[img_id]
features, spatials, image_mask, image_target, image_label = encode_image_input(
features, num_boxes, boxes, image_target, max_regions=self._max_region_num, mask_prob=0)
else:
features = spatials = image_mask = image_target = image_label = torch.tensor([0])
elif self.split == 'test':
assert num_options == 100
cur_rnd_utterance = [self.tokenizer.convert_tokens_to_ids(dialog['caption'].split(' '))]
options_all = []
for rnd,utterance in enumerate(dialog['dialog']):
cur_rnd_utterance.append(self.tokenizer.convert_tokens_to_ids(cur_questions[utterance['question']].split(' ')))
if rnd != len(dialog['dialog'])-1:
cur_rnd_utterance.append(self.tokenizer.convert_tokens_to_ids(cur_answers[utterance['answer']].split(' ')))
for answer_option in dialog['dialog'][-1]['answer_options']:
cur_option = cur_rnd_utterance.copy()
cur_option.append(self.tokenizer.convert_tokens_to_ids(cur_answers[answer_option].split(' ')))
options_all.append(cur_option)
tokens_all = []
mask_all = []
segments_all = []
sep_indices_all = []
hist_len_all = []
for j, option in enumerate(options_all):
option, start_segment = self.pruneRounds(option, self.config['visdial_tot_rounds'])
tokens, segments, sep_indices, mask, start_question, end_question = encode_input(option, start_segment, self.CLS,
self.SEP, self.MASK, max_seq_len=MAX_SEQ_LEN, mask_prob=0)
tokens_all.append(tokens)
mask_all.append(mask)
segments_all.append(segments)
sep_indices_all.append(sep_indices)
hist_len_all.append(torch.LongTensor([len(option)-1]))
tokens_all = torch.cat(tokens_all,0)
mask_all = torch.cat(mask_all,0)
segments_all = torch.cat(segments_all, 0)
sep_indices_all = torch.cat(sep_indices_all, 0)
hist_len_all = torch.cat(hist_len_all,0)
# question span of the last round, repeated once per answer option (mirrors the dense dataloader); needed for item['question_limits'] below
question_limits_all = torch.tensor([start_question, end_question]).unsqueeze(0).repeat(num_options, 1)
hist_idx = [i*2 for i in range(len(dialog['dialog']))]
history_sep_indices_all = [sep_indices.squeeze(0)[hist_idx].contiguous() for _ in range(num_options)]
with open(os.path.join(ques_adj_matrices_dir, f'{graph_idx}.pkl'), 'rb') as f:
question_graphs = pickle.load(f)
q_graph_last = question_graphs[-1]
question_edge_index = []
question_edge_attribute = []
for edge_index, edge_attr in q_graph_last:
question_edge_index.append(edge_index)
edge_attr_one_hot = np.zeros((len(self.parse_vocab) + 1,), dtype=np.float32)
edge_attr_one_hot[self.parse_vocab.get(edge_attr, len(self.parse_vocab))] = 1.0
question_edge_attribute.append(edge_attr_one_hot)
question_edge_index = np.array(question_edge_index, dtype=np.float64)
question_edge_attribute = np.stack(question_edge_attribute, axis=0)
question_edge_indices_all = [torch.from_numpy(question_edge_index).t().long().contiguous() for _ in range(num_options)]
question_edge_attributes_all = [torch.from_numpy(question_edge_attribute).contiguous() for _ in range(num_options)]
with open(os.path.join(hist_adj_matrices_dir, f'{graph_idx}.pkl'), 'rb') as f:
_history_edge_incides_all = pickle.load(f)
_history_edge_incides_last = _history_edge_incides_all[-1]
history_edge_index_all = [torch.tensor(_history_edge_incides_last).t().long().contiguous() for _ in range(num_options)]
if self.config['stack_gr_data']:
question_edge_indices_all = torch.stack(question_edge_indices_all, dim=0)
question_edge_attributes_all = torch.stack(question_edge_attributes_all, dim=0)
history_edge_index_all = torch.stack(history_edge_index_all, dim=0)
history_sep_indices_all = torch.stack(history_sep_indices_all, dim=0)
len_question_gr = torch.tensor(question_edge_indices_all.size(-1)).unsqueeze(0).repeat(num_options, 1)
len_history_gr = torch.tensor(history_edge_index_all.size(-1)).repeat(num_options, 1)
len_history_sep = torch.tensor(history_sep_indices_all.size(-1)).repeat(num_options, 1)
item = {}
item['tokens'] = tokens_all.unsqueeze(0)
item['segments'] = segments_all.unsqueeze(0)
item['sep_indices'] = sep_indices_all.unsqueeze(0)
item['mask'] = mask_all.unsqueeze(0)
item['hist_len'] = hist_len_all.unsqueeze(0)
item['question_limits'] = question_limits_all
item['question_edge_indices'] = question_edge_indices_all
item['question_edge_attributes'] = question_edge_attributes_all
item['history_edge_indices'] = history_edge_index_all
item['history_sep_indices'] = history_sep_indices_all
if self.config['stack_gr_data']:
item['len_question_gr'] = len_question_gr
item['len_history_gr'] = len_history_gr
item['len_history_sep'] = len_history_sep
item['round_id'] = torch.LongTensor([dialog['round_id']])
# get image features
if not self.config['dataloader_text_only']:
features, num_boxes, boxes, _ , image_target, image_edge_indexes, image_edge_attributes = self._image_features_reader[img_id]
features, spatials, image_mask, image_target, image_label = encode_image_input(features, num_boxes, boxes, image_target, max_regions=self._max_region_num, mask_prob=0)
else:
features = spatials = image_mask = image_target = image_label = torch.tensor([0])
item['image_feat'] = features
item['image_loc'] = spatials
item['image_mask'] = image_mask
item['image_target'] = image_target
item['image_label'] = image_label
item['image_id'] = torch.LongTensor([img_id])
if self._split == 'train':
# cheap hack to duplicate the image graph data for the positive and negative examples
item['image_edge_indices'] = [torch.from_numpy(image_edge_indexes).long(), torch.from_numpy(image_edge_indexes).long()]
item['image_edge_attributes'] = [torch.from_numpy(image_edge_attributes), torch.from_numpy(image_edge_attributes)]
elif self._split == 'val':
# cheap hack to duplicate the image graph data for all 10 rounds x 100 answer options
item['image_edge_indices'] = [torch.from_numpy(image_edge_indexes).contiguous().long() for _ in range(1000)]
item['image_edge_attributes'] = [torch.from_numpy(image_edge_attributes).contiguous() for _ in range(1000)]
else:
# cheap hack to duplicate the image graph data for the 100 answer options
item['image_edge_indices'] = [torch.from_numpy(image_edge_indexes).contiguous().long() for _ in range(100)]
item['image_edge_attributes'] = [torch.from_numpy(image_edge_attributes).contiguous() for _ in range(100)]
if self.config['stack_gr_data']:
item['image_edge_indices'] = torch.stack(item['image_edge_indices'], dim=0)
item['image_edge_attributes'] = torch.stack(item['image_edge_attributes'], dim=0)
len_image_gr = torch.tensor(item['image_edge_indices'].size(-1)).unsqueeze(0).repeat(num_options)
item['len_image_gr'] = len_image_gr
return item


@@ -0,0 +1,313 @@
import torch
import json
import os
import time
import numpy as np
import random
from tqdm import tqdm
import copy
import pyhocon
import glog as log
from collections import OrderedDict
import argparse
import pickle
import torch.utils.data as tud
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from utils.data_utils import encode_input, encode_image_input
from dataloader.dataloader_base import DatasetBase
class VisdialDenseDataset(DatasetBase):
def __init__(self, config):
super(VisdialDenseDataset, self).__init__(config)
with open(config.tr_graph_idx_mapping, 'r') as f:
self.tr_graph_idx_mapping = json.load(f)
with open(config.val_graph_idx_mapping, 'r') as f:
self.val_graph_idx_mapping = json.load(f)
with open(config.test_graph_idx_mapping, 'r') as f:
self.test_graph_idx_mapping = json.load(f)
self.question_gr_paths = {
'train': os.path.join(self.config['visdial_question_adj_matrices'], 'train'),
'val': os.path.join(self.config['visdial_question_adj_matrices'], 'val'),
'test': os.path.join(self.config['visdial_question_adj_matrices'], 'test')
}
self.history_gr_paths = {
'train': os.path.join(self.config['visdial_history_adj_matrices'], 'train'),
'val': os.path.join(self.config['visdial_history_adj_matrices'], 'val'),
'test': os.path.join(self.config['visdial_history_adj_matrices'], 'test')
}
def __getitem__(self, index):
MAX_SEQ_LEN = self.config['max_seq_len']
cur_data = None
cur_dense_annotations = None
if self._split == 'train':
cur_data = self.visdial_data_train['data']
cur_dense_annotations = self.visdial_data_train_dense
cur_question_gr_path = self.question_gr_paths['train']
cur_history_gr_path = self.history_gr_paths['train']
cur_gr_mapping = self.tr_graph_idx_mapping
if self.config['rlv_hst_only']:
cur_rlv_hst = self.rlv_hst_train
elif self._split == 'val':
cur_data = self.visdial_data_val['data']
cur_dense_annotations = self.visdial_data_val_dense
cur_question_gr_path = self.question_gr_paths['val']
cur_history_gr_path = self.history_gr_paths['val']
cur_gr_mapping = self.val_graph_idx_mapping
if self.config['rlv_hst_only']:
cur_rlv_hst = self.rlv_hst_val
elif self._split == 'trainval':
if index >= self.numDataPoints['train']:
cur_data = self.visdial_data_val['data']
cur_dense_annotations = self.visdial_data_val_dense
cur_gr_mapping = self.val_graph_idx_mapping
index -= self.numDataPoints['train']
cur_question_gr_path = self.question_gr_paths['val']
cur_history_gr_path = self.history_gr_paths['val']
if self.config['rlv_hst_only']:
cur_rlv_hst = self.rlv_hst_val
else:
cur_data = self.visdial_data_train['data']
cur_dense_annotations = self.visdial_data_train_dense
cur_question_gr_path = self.question_gr_paths['train']
cur_gr_mapping = self.tr_graph_idx_mapping
cur_history_gr_path = self.history_gr_paths['train']
if self.config['rlv_hst_only']:
cur_rlv_hst = self.rlv_hst_train
elif self._split == 'test':
cur_data = self.visdial_data_test['data']
cur_question_gr_path = self.question_gr_paths['test']
cur_history_gr_path = self.history_gr_paths['test']
if self.config['rlv_hst_only']:
cur_rlv_hst = self.rlv_hst_test
# number of options to score on
num_options = self.num_options_dense
if self._split == 'test' or self.config['validating'] or self.config['predicting']:
assert num_options == 100
else:
assert num_options >=1 and num_options <= 100
dialog = cur_data['dialogs'][index]
cur_questions = cur_data['questions']
cur_answers = cur_data['answers']
img_id = dialog['image_id']
if self._split != 'test':
graph_idx = cur_gr_mapping[str(img_id)]
else:
graph_idx = index
if self._split != 'test':
assert img_id == cur_dense_annotations[index]['image_id']
if self.config['rlv_hst_only']:
rlv_hst = cur_rlv_hst[str(img_id)] # [10 for each round, 10 for cap + first 9 round ]
if self._split == 'test':
cur_rounds = len(dialog['dialog']) # 1, 2, ..., 10
else:
cur_rounds = cur_dense_annotations[index]['round_id'] # 1, 2, ..., 10
# caption
cur_rnd_utterance = []
include_caption = True
if self.config['rlv_hst_only']:
if self.config['rlv_hst_dense_round']:
if rlv_hst[0] == 0:
include_caption = False
elif rlv_hst[cur_rounds - 1][0] == 0:
include_caption = False
if include_caption:
sent = dialog['caption'].split(' ')
tokenized_sent = self.tokenizer.convert_tokens_to_ids(sent)
cur_rnd_utterance.append(tokenized_sent)
# tot_len += len(sent) + 1
for rnd, utterance in enumerate(dialog['dialog'][:cur_rounds]):
if self.config['rlv_hst_only'] and rnd < cur_rounds - 1:
if self.config['rlv_hst_dense_round']:
if rlv_hst[rnd + 1] == 0:
continue
elif rlv_hst[cur_rounds - 1][rnd + 1] == 0:
continue
# question
sent = cur_questions[utterance['question']].split(' ')
tokenized_sent = self.tokenizer.convert_tokens_to_ids(sent)
cur_rnd_utterance.append(tokenized_sent)
# answer
if rnd != cur_rounds - 1:
sent = cur_answers[utterance['answer']].split(' ')
tokenized_sent = self.tokenizer.convert_tokens_to_ids(sent)
cur_rnd_utterance.append(tokenized_sent)
if self.config['rlv_hst_only']:
num_rlv_rnds = len(cur_rnd_utterance) - 1
else:
num_rlv_rnds = None
if self._split != 'test':
gt_option = dialog['dialog'][cur_rounds - 1]['gt_index']
if self.config['training'] or self.config['debugging']:
# first select gt option id, then choose the first num_options inds
option_inds = []
option_inds.append(gt_option)
all_inds = list(range(100))
all_inds.remove(gt_option)
# debug
if num_options < 100:
random.shuffle(all_inds)
all_inds = all_inds[:(num_options-1)]
option_inds.extend(all_inds)
gt_option = 0
else:
option_inds = range(num_options)
answer_options = [dialog['dialog'][cur_rounds - 1]['answer_options'][k] for k in option_inds]
if 'relevance' in cur_dense_annotations[index]:
key = 'relevance'
else:
key = 'gt_relevance'
gt_relevance = torch.Tensor(cur_dense_annotations[index][key])
gt_relevance = gt_relevance[option_inds]
assert len(answer_options) == len(option_inds) == num_options
else:
answer_options = dialog['dialog'][-1]['answer_options']
assert len(answer_options) == num_options
options_all = []
for answer_option in answer_options:
cur_option = cur_rnd_utterance.copy()
cur_option.append(self.tokenizer.convert_tokens_to_ids(cur_answers[answer_option].split(' ')))
options_all.append(cur_option)
if not self.config['rlv_hst_only']:
assert len(cur_option) == 2 * cur_rounds + 1
tokens_all = []
mask_all = []
segments_all = []
sep_indices_all = []
hist_len_all = []
tot_len_debug = []
for opt_id, option in enumerate(options_all):
option, start_segment = self.pruneRounds(option, self.config['visdial_tot_rounds'])
tokens, segments, sep_indices, mask, start_question, end_question = encode_input(option, start_segment, self.CLS,
self.SEP, self.MASK, max_seq_len=MAX_SEQ_LEN, mask_prob=0)
tokens_all.append(tokens)
mask_all.append(mask)
segments_all.append(segments)
sep_indices_all.append(sep_indices)
hist_len_all.append(torch.LongTensor([len(option)-1]))
len_tokens = sum(len(s) for s in option)
tot_len_debug.append(len_tokens + len(option) + 1)
tokens_all = torch.cat(tokens_all,0)
mask_all = torch.cat(mask_all,0)
segments_all = torch.cat(segments_all, 0)
sep_indices_all = torch.cat(sep_indices_all, 0)
hist_len_all = torch.cat(hist_len_all,0)
question_limits_all = torch.tensor([start_question, end_question]).unsqueeze(0).repeat(num_options, 1)
if self.config['rlv_hst_only']:
assert num_rlv_rnds > 0
hist_idx = [i * 2 for i in range(num_rlv_rnds)]
else:
hist_idx = [i*2 for i in range(cur_rounds)]
history_sep_indices_all = sep_indices.squeeze(0)[hist_idx].contiguous().unsqueeze(0).repeat(num_options, 1)
with open(os.path.join(cur_question_gr_path, f'{graph_idx}.pkl'), 'rb') as f:
question_graphs = pickle.load(f)
question_graph_round = question_graphs[cur_rounds - 1]
question_edge_index = []
question_edge_attribute = []
for edge_index, edge_attr in question_graph_round:
question_edge_index.append(edge_index)
edge_attr_one_hot = np.zeros((len(self.parse_vocab) + 1,), dtype=np.float32)
edge_attr_one_hot[self.parse_vocab.get(edge_attr, len(self.parse_vocab))] = 1.0
question_edge_attribute.append(edge_attr_one_hot)
question_edge_index = np.array(question_edge_index, dtype=np.float64)
question_edge_attribute = np.stack(question_edge_attribute, axis=0)
question_edge_indices_all = [torch.from_numpy(question_edge_index).t().long().contiguous() for _ in range(num_options)]
question_edge_attributes_all = [torch.from_numpy(question_edge_attribute).contiguous() for _ in range(num_options)]
if self.config['rlv_hst_only']:
with open(os.path.join(cur_history_gr_path, f'{graph_idx}.pkl'), 'rb') as f:
_history_edge_indices_round = pickle.load(f)
else:
with open(os.path.join(cur_history_gr_path, f'{graph_idx}.pkl'), 'rb') as f:
_history_edge_indices_all = pickle.load(f)
_history_edge_indices_round = _history_edge_indices_all[cur_rounds - 1]
history_edge_index_all = [torch.tensor(_history_edge_indices_round).t().long().contiguous() for _ in range(num_options)]
if self.config['stack_gr_data']:
question_edge_indices_all = torch.stack(question_edge_indices_all, dim=0)
question_edge_attributes_all = torch.stack(question_edge_attributes_all, dim=0)
history_edge_index_all = torch.stack(history_edge_index_all, dim=0)
item = {}
item['tokens'] = tokens_all.unsqueeze(0) # [1, num_options, max_len]
item['segments'] = segments_all.unsqueeze(0)
item['sep_indices'] = sep_indices_all.unsqueeze(0)
item['mask'] = mask_all.unsqueeze(0)
item['hist_len'] = hist_len_all.unsqueeze(0)
item['question_limits'] = question_limits_all
item['question_edge_indices'] = question_edge_indices_all
item['question_edge_attributes'] = question_edge_attributes_all
item['history_edge_indices'] = history_edge_index_all
item['history_sep_indices'] = history_sep_indices_all
# add dense annotation fields
if self._split != 'test':
item['gt_relevance'] = gt_relevance # [num_options]
item['gt_option_inds'] = torch.LongTensor([gt_option])
# add next sentence labels for training with the nsp loss as well
nsp_labels = torch.ones(*tokens_all.unsqueeze(0).shape[:-1]).long()
nsp_labels[:,gt_option] = 0
item['next_sentence_labels'] = nsp_labels
item['round_id'] = torch.LongTensor([cur_rounds])
else:
if 'round_id' in dialog:
item['round_id'] = torch.LongTensor([dialog['round_id']])
else:
item['round_id'] = torch.LongTensor([cur_rounds])
# get image features
if not self.config['dataloader_text_only']:
features, num_boxes, boxes, _ , image_target, image_edge_indexes, image_edge_attributes = self._image_features_reader[img_id]
features, spatials, image_mask, image_target, image_label = encode_image_input(features, num_boxes, boxes, image_target, max_regions=self._max_region_num, mask_prob=0)
else:
features = spatials = image_mask = image_target = image_label = torch.tensor([0])
item['image_feat'] = features
item['image_loc'] = spatials
item['image_mask'] = image_mask
item['image_id'] = torch.LongTensor([img_id])
item['tot_len'] = torch.LongTensor(tot_len_debug)
item['image_edge_indices'] = [torch.from_numpy(image_edge_indexes).contiguous().long() for _ in range(num_options)]
item['image_edge_attributes'] = [torch.from_numpy(image_edge_attributes).contiguous() for _ in range(num_options)]
if self.config['stack_gr_data']:
item['image_edge_indices'] = torch.stack(item['image_edge_indices'], dim=0)
item['image_edge_attributes'] = torch.stack(item['image_edge_attributes'], dim=0)
return item
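For reference, the snippet below is a minimal, self-contained sketch of the edge-attribute encoding used when building the item above: each dependency label is mapped to a one-hot vector over `parse_vocab`, with the last slot reserved for unseen labels, and the edge list is transposed into a `[2, num_edges]` index tensor. The toy vocabulary and edges are illustrative only, not taken from `data/parse_vocab.pkl`.

```python
import numpy as np
import torch

# Illustrative stand-ins; the real vocabulary is loaded from data/parse_vocab.pkl.
parse_vocab = {'nsubj': 0, 'dobj': 1, 'det': 2}
graph_round = [([0, 1], 'nsubj'), ([1, 2], 'amod')]  # (edge, dependency label) pairs

edge_index, edge_attr = [], []
for (src_dst, label) in graph_round:
    edge_index.append(src_dst)
    one_hot = np.zeros(len(parse_vocab) + 1, dtype=np.float32)
    one_hot[parse_vocab.get(label, len(parse_vocab))] = 1.0  # last slot = out-of-vocabulary label
    edge_attr.append(one_hot)

edge_index = torch.tensor(edge_index).t().long().contiguous()  # [2, num_edges]
edge_attr = torch.from_numpy(np.stack(edge_attr, axis=0))      # [num_edges, len(parse_vocab) + 1]
```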

114
ensemble.py Normal file
View file

@ -0,0 +1,114 @@
import os
import os.path as osp
import numpy as np
import json
import argparse
import pyhocon
import glog as log
import torch
from tqdm import tqdm
from utils.data_utils import load_pickle_lines
from utils.visdial_metrics import scores_to_ranks
from utils.init_utils import set_log_file
parser = argparse.ArgumentParser(description='Ensemble for VisDial')
parser.add_argument('--exp', type=str, default='test',
help='experiment name from .conf')
parser.add_argument('--mode', type=str, default='predict', choices=['eval', 'predict'],
help='eval or predict')
parser.add_argument('--ssh', action='store_true',
help='whether or not we are executing the command via ssh. '
'If set, nothing is logged to the screen; everything is redirected to the log file')
if __name__ == '__main__':
args = parser.parse_args()
# initialization
config = pyhocon.ConfigFactory.parse_file(f"config/ensemble.conf")[args.exp]
config["log_dir"] = os.path.join(config["log_dir"], args.exp)
if not os.path.exists(config["log_dir"]):
os.makedirs(config["log_dir"])
# set logs
log_file = os.path.join(config["log_dir"], f'{args.mode}.log')
set_log_file(log_file, file_only=args.ssh)
# print environment info
log.info(f"Running experiment: {args.exp}")
log.info(f"Results saved to {config['log_dir']}")
log.info(pyhocon.HOCONConverter.convert(config, "hocon"))
if isinstance(config['processed'], list):
assert len(config['models']) == len(config['processed'])
processed = {model:pcd for model, pcd in zip(config['models'], config['processed'])}
else:
processed = {model: config['processed'] for model in config['models']}
if config['split'] == 'test' and np.any(config['processed']):
test_data = json.load(open(config['visdial_test_data']))['data']['dialogs']
imid2rndid = {t['image_id']: len(t['dialog']) for t in test_data}
del test_data
# load predictions files
visdial_outputs = dict()
if args.mode == 'eval':
metrics = {}
for model in config['models']:
pred_filename = osp.join(config['pred_dir'], model, 'visdial_prediction.pkl')
pred_dict = {p['image_id']: p for p in load_pickle_lines(pred_filename)}
log.info(f'Loading {len(pred_dict)} predictions from {pred_filename}')
visdial_outputs[model] = pred_dict
if args.mode == 'eval':
assert len(visdial_outputs[model]) >= num_dialogs
metric = json.load(open(osp.join(config['pred_dir'], model, "metrics_epoch_best.json")))
metrics[model] = metric['val']
image_ids = visdial_outputs[model].keys()
predictions = []
# for each dialog
for image_id in tqdm(image_ids):
scores = []
round_id = None
for model in config['models']:
pred = visdial_outputs[model][image_id]
if config['split'] == 'test' and processed[model]:
# when predicting on processed data, the first few rounds are deleted from some dialogs,
# so the original round ids can only be found in the original test data
round_id_in_pred = imid2rndid[image_id]
else:
round_id_in_pred = pred['gt_relevance_round_id']
if not isinstance(round_id_in_pred, int):
round_id_in_pred = int(round_id_in_pred)
if round_id is None:
round_id = round_id_in_pred
else:
# make sure all models have the same round_id
assert round_id == round_id_in_pred
scores.append(torch.from_numpy(pred['nsp_probs']).unsqueeze(0))
# ensemble scores
scores = torch.cat(scores, 0) # [n_model, num_rounds, num_options]
scores = torch.sum(scores, dim=0, keepdim=True) # [1, num_rounds, num_options]
if scores.size(0) > 1:
scores = scores[round_id - 1].unsqueeze(0)
ranks = scores_to_ranks(scores) # [eval_batch_size, num_rounds, num_options]
ranks = ranks.squeeze(1)
prediction = {
"image_id": image_id,
"round_id": round_id,
"ranks": ranks[0].tolist()
}
predictions.append(prediction)
filename = osp.join(config['log_dir'], f'{config["split"]}_ensemble_preds.json')
with open(filename, 'w') as f:
json.dump(predictions, f)
log.info(f'{len(predictions)} predictions saved to {filename}')
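As a quick illustration of the ensembling step above: per-model NSP probabilities for a dialog are summed over models, and the summed scores are converted to 1-based ranks (rank 1 = highest score), which is the role `scores_to_ranks` plays here. The tensors below are random placeholders, not real predictions.

```python
import torch

# Hypothetical probabilities from three models, each of shape [num_rounds, num_options].
model_probs = [torch.rand(10, 100) for _ in range(3)]

scores = torch.stack(model_probs, dim=0).sum(dim=0)                  # [num_rounds, num_options]
ranks = scores.argsort(dim=-1, descending=True).argsort(dim=-1) + 1  # rank 1 = best-scored option
print(ranks.shape)  # torch.Size([10, 100])
```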

199
main.py Normal file
View file

@ -0,0 +1,199 @@
from utils.init_utils import load_runner, load_dataset, set_random_seed, set_training_steps, initialize_from_env, set_log_file, copy_file_to_log
import torch.distributed as dist
import torch.nn as nn
import torch.multiprocessing as mp
import torch
import os
import sys
import argparse
import pyhocon
import glog as log
import socket
import getpass
try:
from apex.parallel import DistributedDataParallel as DDP
from apex import amp
except ModuleNotFoundError:
print('apex not found')
parser = argparse.ArgumentParser(description='Main script for VD-GR')
parser.add_argument(
'--model',
type=str,
default='vdgr/P1',
help='model name to train or test')
parser.add_argument(
'--mode',
type=str,
default='train',
help='train, eval, predict or debug')
parser.add_argument(
'--wandb_project',
type=str,
default='VD-GR'
)
parser.add_argument(
'--wandb_mode',
type=str,
default='online',
choices=['online', 'offline', 'disabled', 'run', 'dryrun']
)
parser.add_argument(
'--tag',
type=str,
default='K2',
help="Tag to differentiate the different runs"
)
parser.add_argument(
'--eval_dir',
type=str,
default='',
help="Directory of a trained model to evaluate"
)
parser.add_argument('--ssh', action='store_true',
help='whether or not we are executing the command via ssh. '
'If set, nothing is logged to the screen; everything is redirected to the log file')
def main(gpu, config, args):
config['training'] = args.mode == 'train'
config['validating'] = args.mode == 'eval'
config['debugging'] = args.mode == 'debug'
config['predicting'] = args.mode == 'predict'
config['wandb_project'] = args.wandb_project
config['wandb_mode'] = args.wandb_mode
if config['parallel'] and config['dp_type'] != 'dp':
config['rank'] = gpu
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(config['master_port'])
dist.init_process_group(
backend='nccl',
world_size=config['num_gpus'],
rank=gpu
)
config['display'] = gpu == 0
if config['dp_type'] == 'apex':
torch.cuda.set_device(gpu)
else:
config['display'] = True
if config['debugging'] or (config['parallel'] and config['dp_type'] != 'dp'):
config['num_workers'] = 0
else:
config['num_workers'] = 0
# set logs
log_file = os.path.join(config["log_dir"], f'{args.mode}.log')
set_log_file(log_file, file_only=args.ssh)
# print environment info
if config['display']:
log.info('Host: {}, user: {}, CUDA_VISIBLE_DEVICES: {}, cwd: {}'.format(
socket.gethostname(), getpass.getuser(), os.environ.get('CUDA_VISIBLE_DEVICES', ''), os.getcwd()))
log.info('Command line is: {}'.format(' '.join(sys.argv)))
if config['parallel'] and config['dp_type'] != 'dp':
log.info(
f'World_size: {config["num_gpus"]}, cur rank: {config["rank"]}')
log.info(f"Running experiment: {args.model}")
log.info(f"Results saved to {config['log_dir']}")
# initialization
if config['display'] and config['training']:
copy_file_to_log(config['log_dir'])
set_random_seed(config['random_seed'])
device = torch.device(f"cuda:{gpu}")
if config["use_cpu"]:
device = torch.device("cpu")
config['device'] = device
# prepare dataset
dataset, dataset_eval = load_dataset(config)
# set training steps
if not config['validating'] or config['parallel']:
config = set_training_steps(config, len(dataset))
if config['display']:
log.info(pyhocon.HOCONConverter.convert(config, "hocon"))
# load runner
runner = load_runner(config)
# apex
if config['dp_type'] == 'apex':
runner.model, runner.optimizer = amp.initialize(runner.model,
runner.optimizer,
opt_level="O1")
# parallel
if config['parallel']:
if config['dp_type'] == 'dp':
runner.model = nn.DataParallel(runner.model)
runner.model.to(config['device'])
elif config['dp_type'] == 'apex':
runner.model = DDP(runner.model)
elif config['dp_type'] == 'ddp':
torch.cuda.set_device(gpu)
runner.model = runner.model.to(gpu)
runner.model = nn.parallel.DistributedDataParallel(
runner.model,
device_ids=[gpu],
output_device=gpu,
find_unused_parameters=True)
else:
raise ValueError(f'Unrecognized dp_type: {config["dp_type"]}')
if config['training'] or config['debugging']:
runner.load_pretrained_vilbert()
runner.train(dataset, dataset_eval)
else:
if config['loads_start_path']:
runner.load_pretrained_vilbert()
else:
runner.load_ckpt_best()
metrics_results = {}
if config['predicting']:
eval_splits = [config['predict_split']]
else:
eval_splits = ['val']
if config['model_type'] == 'conly' and not config['train_each_round']:
eval_splits.append('test')
for split in eval_splits:
if config['display']:
log.info(f'Results on {split} split of the best epoch')
if dataset_eval is None:
dataset_to_eval = dataset
else:
dataset_to_eval = dataset_eval
dataset_to_eval.split = split
_, metrics_results[split] = runner.evaluate(
dataset_to_eval, eval_visdial=True)
if not config['predicting'] and config['display']:
runner.save_eval_results(split, 'best', metrics_results)
if config['parallel'] and config['dp_type'] != 'dp':
dist.destroy_process_group()
if __name__ == '__main__':
args = parser.parse_args()
# initialization
model_type, model_name = args.model.split('/')
config = initialize_from_env(
model_name, args.mode, args.eval_dir, model_type, tag=args.tag)
if config['num_gpus'] > 1:
config['parallel'] = True
if config['dp_type'] == 'dp':
main(0, config, args)
else:
mp.spawn(main, nprocs=config['num_gpus'], args=(config, args))
else:
config['parallel'] = False
main(0, config, args)
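The multi-GPU branch above relies on `torch.multiprocessing.spawn` inserting the process index as the first argument of the target function (the `gpu` parameter of `main`). The toy example below shows only that calling convention; it does not set up NCCL or any of the config used here.

```python
import torch.multiprocessing as mp

def _worker(rank, greeting):
    # mp.spawn calls _worker(i, *args) for i in range(nprocs); i plays the role of `gpu` above.
    print(f"{greeting} from process {rank}")

if __name__ == '__main__':
    mp.spawn(_worker, nprocs=2, args=("hello",))
```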

0
misc/.gitkeep Normal file
View file

BIN
misc/teaser_1.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.2 MiB

BIN
misc/teaser_2.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.6 MiB

BIN
misc/usa.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 11 KiB

0
models/__init__.py Normal file
View file

830
models/runner.py Normal file
View file

@ -0,0 +1,830 @@
import os
import os.path as osp
import json
from collections import deque
import time
import re
import shutil
import glob
import pickle
import gc
import numpy as np
import glog as log
try:
from apex import amp
except ModuleNotFoundError:
print('apex not found')
import torch
import torch.utils.data as tud
import torch.nn.functional as F
import torch.distributed as dist
from utils.data_utils import load_pickle_lines
from utils.visdial_metrics import SparseGTMetrics, NDCG, scores_to_ranks
import wandb
class Runner:
def __init__(self, config):
self.config = config
if 'rank' in config:
self.gpu_rank = config['rank']
else:
self.gpu_rank = 0
self.epoch_idx = 0
self.max_metric = 0.
self.max_metric_epoch_idx = 0
self.na_str = 'N/A'
if self.config["max_ckpt_to_keep"] > 0:
self.checkpoint_queue = deque(
[], maxlen=config["max_ckpt_to_keep"])
self.metrics_queue = deque([], maxlen=config["max_ckpt_to_keep"])
self.setup_wandb()
def setup_wandb(self):
if self.gpu_rank == 0:
print("[INFO] Set wandb logging on rank {}".format(0))
run = wandb.init(
project=self.config['wandb_project'], config=self.config, mode=self.config['wandb_mode'])
else:
run = None
self.run = run
def forward(self, batch, eval_visdial=False):
raise NotImplementedError
def train(self, dataset, dataset_eval=None):
# wandb.login()
if os.path.exists(self.config['log_dir']) or self.config['loads_ckpt'] or self.config['loads_best_ckpt']:
self.load_ckpt()
if self.config['use_trainval']:
dataset.split = 'trainval'
else:
dataset.split = 'train'
batch_size = self.config['batch_size']
if self.config['parallel'] and self.config['dp_type'] != 'dp':
sampler_tr = tud.distributed.DistributedSampler(
dataset,
num_replicas=self.config['num_gpus'],
rank=self.gpu_rank
)
else:
sampler_tr = None
data_loader_tr = tud.DataLoader(
dataset=dataset,
batch_size=batch_size,
shuffle=self.config['training'] and not self.config['parallel'],
collate_fn=dataset.collate_fn,
num_workers=self.config['num_workers'],
sampler=sampler_tr
)
start_epoch_idx = self.epoch_idx
num_iter_epoch = self.config['num_iter_per_epoch']
if self.config['display']:
log.info(f'{num_iter_epoch} iter per epoch.')
# eval before training
eval_dense_at_first = self.config['train_on_dense'] and self.config['skip_mrr_eval'] and start_epoch_idx == 0
# evaluate before training under two circumstances:
# for dense finetuning, evaluate NDCG before the first epoch
# for MRR training, when continuing a run whose last epoch has not yet been evaluated
if (eval_dense_at_first or (self.config['eval_at_start'] and len(self.metrics_queue) == 0 and start_epoch_idx > 0)):
if eval_dense_at_first:
iter_now = 0
else:
iter_now = max(num_iter_epoch * start_epoch_idx, 0)
if dataset_eval is None:
dataset.split = 'val'
dataset_to_eval = dataset
else:
dataset_to_eval = dataset_eval
metrics_results = {}
metrics_to_maximize, metrics_results['val'] = self.evaluate(
dataset_to_eval, iter_now)
if eval_dense_at_first:
self.max_metric = metrics_to_maximize
self.max_metric_epoch_idx = -1
else:
if self.config['display']:
self.save_eval_results(
'val', start_epoch_idx - 1, metrics_results)
if metrics_to_maximize > self.max_metric:
self.max_metric = metrics_to_maximize
self.max_metric_epoch_idx = start_epoch_idx - 1
self.copy_best_results('val', start_epoch_idx - 1)
self.copy_best_predictions('val')
if dataset_eval is None:
if self.config['use_trainval']:
dataset.split = 'trainval'
else:
dataset.split = 'train'
num_epochs = self.config['num_epochs']
for epoch_idx in range(start_epoch_idx, num_epochs):
if self.config['parallel'] and self.config['dp_type'] != 'dp':
sampler_tr.set_epoch(epoch_idx)
self.epoch_idx = epoch_idx
if self.config['display']:
log.info(f'starting epoch {epoch_idx}')
log.info('training')
self.model.train()
num_batch = 0
next_logging_pct = .1
next_evaluating_pct = self.config["next_evaluating_pct"] + .1
start_time = time.time()
self.optimizer.zero_grad()
for batch in data_loader_tr:
if self.config['eval_before_training']:
log.info('Skipping straight to evaluation...')
break
num_batch += 1
pct = num_batch / num_iter_epoch * 100
iter_now = num_iter_epoch * epoch_idx + num_batch
output = self.forward(batch)
losses = output['losses']
# optimizer step
losses['tot_loss'] /= self.config['batch_multiply']
# debug
if self.config['debugging']:
log.info('try backward')
if self.config['dp_type'] == 'apex':
with amp.scale_loss(losses['tot_loss'], self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
losses['tot_loss'].backward()
if self.config['debugging']:
log.info('backward done')
if iter_now % self.config['batch_multiply'] == 0:
self.optimizer.step()
self.optimizer.zero_grad()
self.scheduler.step()
# display and eval
if pct >= next_logging_pct:
if self.config['display']:
loss_to_print = ''
for key in losses:
if losses[key] is not None and isinstance(losses[key], torch.Tensor):
loss_to_print += f'[{key}: {losses[key].item():.4f}]'
print(
f'[{int(pct)}%][Epoch: {epoch_idx + 1}/{num_epochs}][Iter : {num_batch}/{len(data_loader_tr)}] [time: {time.time() - start_time:.2f}] {loss_to_print}'
)
next_logging_pct += self.config["next_logging_pct"]
if self.config['debugging']:
break
if pct >= next_evaluating_pct:
next_evaluating_pct += self.config["next_evaluating_pct"]
if self.run:
if self.config['train_on_dense']:
self.run.log(
{
"Train/dense_loss": losses['dense_loss'],
"Train/total_loss": losses['tot_loss'],
},
step=iter_now
)
else:
self.run.log(
{
"Train/lm_loss": losses['lm_loss'],
"Train/img_loss": losses['img_loss'],
"Train/nsp_loss": losses['nsp_loss'],
"Train/total_loss": losses['tot_loss'],
},
step=iter_now
)
lr_gnn, lr_bert = self.scheduler.get_lr()[0], self.scheduler.get_lr()[1]
self.run.log(
{
"Train/lr_gnn": lr_gnn,
"Train/lr_bert": lr_bert,
},
step=iter_now
)
del losses
# debug
torch.cuda.empty_cache()
if self.config['display']:
log.info(
f'100%,\ttime:\t{time.time() - start_time:.2f}'
)
ckpt_path = self.save_ckpt()
if not self.config['skip_visdial_eval'] and self.epoch_idx % self.config['eval_visdial_every'] == 0:
iter_now = num_iter_epoch * (epoch_idx + 1)
if dataset_eval is None:
dataset.split = 'val'
dataset_to_eval = dataset
else:
dataset_to_eval = dataset_eval
metrics_results = {}
metrics_to_maximize, metrics_results['val'] = self.evaluate(
dataset_to_eval, iter_now)
if dataset_eval is None:
if self.config['use_trainval']:
dataset.split = 'trainval'
else:
dataset.split = 'train'
if self.config['display']:
self.save_eval_results('val', epoch_idx, metrics_results)
if self.config['display']:
if metrics_to_maximize > self.max_metric:
self.max_metric = metrics_to_maximize
self.max_metric_epoch_idx = epoch_idx
self.copy_best_results('val', epoch_idx)
self.copy_best_predictions('val')
elif not self.config['parallel'] and epoch_idx - self.max_metric_epoch_idx > self.config["early_stop_epoch"]:
log.info('Early stop.')
break
if self.run:
self.run.log(
{"Val/metric_best": self.max_metric}, step=iter_now)
if self.config['parallel']:
if self.config['dp_type'] == 'dp':
gc.collect()
torch.cuda.empty_cache()
else:
dist.barrier()
log.info('Rank {} passed barrier...'.format(self.gpu_rank))
if self.config['stop_epochs'] >= 0 and epoch_idx + 1 >= self.config['stop_epochs']:
if self.config['display']:
log.info('Stop for reaching stop_epochs.')
break
def evaluate(self, dataset, training_iter=None, eval_visdial=True):
# create files to save output
if self.config['predicting']:
visdial_file_name = None
if self.config['save_score']:
visdial_file_name = osp.join(
self.config['log_dir'], f'visdial_prediction.pkl')
if osp.exists(visdial_file_name):
dialogs_predicted = load_pickle_lines(
visdial_file_name)
dialogs_predicted = [d['image_id']
for d in dialogs_predicted]
else:
dialogs_predicted = []
f_visdial = open(visdial_file_name, 'ab')
else:
visdial_file_name = osp.join(
self.config['log_dir'], f'visdial_prediction.jsonlines')
if self.config['parallel'] and self.config['dp_type'] != 'dp':
visdial_file_name = visdial_file_name.replace(
'.jsonlines', f'_{self.config["rank"]}of{self.config["num_gpus"]}.jsonlines')
if osp.exists(visdial_file_name):
dialogs_predicted_visdial = [json.loads(
line)['image_id'] for line in open(visdial_file_name)]
f_visdial = open(visdial_file_name, 'a')
else:
dialogs_predicted_visdial = []
f_visdial = open(visdial_file_name, 'w')
dialogs_predicted = dialogs_predicted_visdial
if len(dialogs_predicted) > 0:
log.info(f'Found {len(dialogs_predicted)} predicted results.')
if self.config['display']:
if visdial_file_name is not None:
log.info(
f'VisDial predictions saved to {visdial_file_name}')
elif self.config['display']:
if self.config['continue_evaluation']:
predicted_files = os.listdir(
osp.join(self.config['visdial_output_dir'], dataset.split))
dialogs_predicted = [
int(re.match(r'(\d+).npz', p).group(1)) for p in predicted_files]
else:
if osp.exists(osp.join(self.config['visdial_output_dir'], dataset.split)):
shutil.rmtree(
osp.join(self.config['visdial_output_dir'], dataset.split))
os.makedirs(
osp.join(self.config['visdial_output_dir'], dataset.split))
dialogs_predicted = []
log.info(f'Found {len(dialogs_predicted)} predicted results.')
if self.config['parallel'] and self.config['dp_type'] != 'dp':
sampler_val = tud.distributed.DistributedSampler(
dataset,
num_replicas=self.config['num_gpus'],
rank=self.gpu_rank
)
sampler_val.set_epoch(self.epoch_idx)
else:
sampler_val = None
data_loader_val = tud.DataLoader(
dataset=dataset,
batch_size=self.config['eval_batch_size'],
shuffle=False,
collate_fn=dataset.collate_fn,
num_workers=self.config['num_workers'],
sampler=sampler_val
)
self.model.eval()
with torch.no_grad():
if self.config['display']:
log.info(f'Evaluating {len(dataset)} samples')
next_logging_pct = self.config["next_logging_pct"] + .1
if self.config['parallel'] and self.config['dp_type'] == 'dp':
num_batch_tot = int(
np.ceil(len(dataset) / self.config['eval_batch_size']))
else:
num_batch_tot = int(np.ceil(
len(dataset) / (self.config['eval_batch_size'] * self.config['num_gpus'])))
num_batch = 0
if dataset.split == 'val':
num_options = self.config["num_options"]
if self.config['skip_mrr_eval']:
num_rounds = 1
else:
num_rounds = 10
elif dataset.split == 'test':
num_options = 100
num_rounds = 1
if self.gpu_rank == 0:
start_time = time.time()
for batch in data_loader_val:
num_batch += 1
# skip dialogs that have been predicted
if self.config['predicting']:
image_ids = batch['image_id'].tolist()
skip_batch = True
for image_id in image_ids:
if image_id not in dialogs_predicted:
skip_batch = False
if skip_batch:
continue
output = self.forward(
batch, eval_visdial=eval_visdial)
# visdial evaluation
if eval_visdial:
img_ids = batch['image_id'].tolist()
batch_size = len(img_ids)
if not self.config['skip_ndcg_eval']:
gt_relevance_round_id = batch['round_id'].tolist()
# [batch_size * num_rounds * num_options, 2]
nsp_scores = output['nsp_scores']
nsp_probs = F.softmax(nsp_scores, dim=1)
assert nsp_probs.shape[-1] == 2
# num_dim=2, 0 for positive, 1 for negative
nsp_probs = nsp_probs[:, 0]
nsp_probs = nsp_probs.view(
batch_size, num_rounds, num_options)
# could be predicting or evaluating
if dataset.split == 'val':
if self.config['skip_ndcg_eval']:
gt_option_inds = batch['gt_option_inds']
for b in range(batch_size):
filename = osp.join(
self.config['visdial_output_dir'], dataset.split, f'{img_ids[b]}.npz')
if not osp.exists(filename):
np.savez(
filename,
nsp_probs=nsp_probs[b].cpu().numpy(),
gt_option_inds=gt_option_inds[b].cpu().numpy()
)
else:
# [batch_size, num_rounds]
gt_option_inds = batch['gt_option_inds']
# [batch_size, num_options]
gt_relevance = batch['gt_relevance']
for b in range(batch_size):
filename = osp.join(
self.config['visdial_output_dir'], dataset.split, f'{img_ids[b]}.npz')
if not osp.exists(filename):
np.savez(filename,
nsp_probs=nsp_probs[b].cpu().numpy(),
gt_option_inds=gt_option_inds[b].cpu(
).numpy(),
gt_relevance=gt_relevance[b].cpu(
).numpy(),
gt_relevance_round_id=gt_relevance_round_id[b])
# must be predicting
if dataset.split == 'test':
if self.config['save_score']:
for b in range(batch_size):
prediction = {
"image_id": img_ids[b],
"nsp_probs": nsp_probs[b].cpu().numpy(),
"gt_relevance_round_id": gt_relevance_round_id[b]
}
pickle.dump(prediction, f_visdial)
else:
# [eval_batch_size, num_rounds, num_options]
ranks = scores_to_ranks(nsp_probs)
ranks = ranks.squeeze(1)
for b in range(batch_size):
prediction = {
"image_id": img_ids[b],
"round_id": gt_relevance_round_id[b],
"ranks": ranks[b].tolist()
}
f_visdial.write(json.dumps(prediction) + '\n')
# debug
if self.config['debugging']:
break
pct = num_batch / num_batch_tot * 100
if pct >= next_logging_pct:
if self.config['display'] and self.gpu_rank == 0:
log.info(
f'{int(pct)}%,\ttime:\t{time.time() - start_time:.2f}'
)
next_logging_pct += self.config["next_logging_pct"]
# debug
if self.config['debugging']:
break
if self.config['display'] and self.gpu_rank == 0:
pct = num_batch / num_batch_tot * 100
log.info(
f'{int(pct)}%,\ttime:\t{time.time() - start_time:.2f}'
)
if not self.config['validating']:
self.model.train()
if self.config['parallel'] and self.config['dp_type'] != 'dp':
dist.barrier()
print(f'{self.gpu_rank} passed barrier')
if self.config['predicting']:
f_visdial.close()
if not self.config['save_score']:
all_visdial_predictions = [json.loads(
line) for line in open(visdial_file_name)]
if self.config['predict_split'] == 'test' and len(all_visdial_predictions) == self.config['num_test_dialogs']:
visdial_file_name = visdial_file_name.replace(
'jsonlines', 'json')
with open(visdial_file_name, 'w') as f_visdial:
json.dump(all_visdial_predictions, f_visdial)
log.info(
f'Prediction for submission saved to {visdial_file_name}.')
return None, None
if self.config['display']:
if dataset.split == 'val' and eval_visdial:
if not self.config['skip_mrr_eval']:
sparse_metrics = SparseGTMetrics()
if not self.config['skip_ndcg_eval']:
ndcg = NDCG()
if dataset.split == 'val' and eval_visdial:
visdial_output_filenames = glob.glob(
osp.join(self.config['visdial_output_dir'], dataset.split, '*.npz'))
log.info(
f'Calculating visdial metrics for {len(visdial_output_filenames)} dialogs')
for visdial_output_filename in visdial_output_filenames:
output = np.load(visdial_output_filename)
nsp_probs = torch.from_numpy(
output['nsp_probs']).unsqueeze(0)
if not self.config['skip_ndcg_eval']:
gt_relevance = torch.from_numpy(output['gt_relevance']).unsqueeze(0)
if not self.config['skip_mrr_eval']:
gt_option_inds = torch.from_numpy(
output['gt_option_inds']).unsqueeze(0)
sparse_metrics.observe(nsp_probs, gt_option_inds)
if not self.config['skip_ndcg_eval']:
gt_relevance_round_id = output['gt_relevance_round_id']
nsp_probs_dense = nsp_probs[0, gt_relevance_round_id - 1, :].unsqueeze(0)
else:
nsp_probs_dense = nsp_probs.squeeze(0) # [1, 100]
if not self.config['skip_ndcg_eval']:
ndcg.observe(nsp_probs_dense, gt_relevance)
# visdial eval output
visdial_metrics = {}
if dataset.split == 'val' and eval_visdial:
if not self.config['skip_mrr_eval']:
visdial_metrics.update(sparse_metrics.retrieve(reset=True))
if not self.config['skip_ndcg_eval']:
visdial_metrics.update(ndcg.retrieve(reset=True))
if self.config['display']:
to_print = ''
for metric_name, metric_value in visdial_metrics.items():
if 'round' not in metric_name:
to_print += f"\n{metric_name}: {metric_value}"
if training_iter is not None:
if self.run:
self.run.log(
{'Val/' + metric_name: metric_value}, step=training_iter)
log.info(to_print)
if self.config['metrics_to_maximize'] in visdial_metrics:
metrics_to_maximize = visdial_metrics[self.config['metrics_to_maximize']]
else:
metrics_to_maximize = None
torch.cuda.empty_cache()
return metrics_to_maximize, visdial_metrics
else:
torch.cuda.empty_cache()
return None, None
def save_eval_results(self, split, epoch_idx, metrics_results):
metrics_filename = osp.join(
self.config['log_dir'], f'metrics_epoch_{epoch_idx}.json')
with open(metrics_filename, 'w') as f:
json.dump(metrics_results, f)
log.info(f'Results of metrics saved to {metrics_filename}')
if self.config["max_ckpt_to_keep"] > 0:
if len(self.metrics_queue) == self.metrics_queue.maxlen:
todel = self.metrics_queue.popleft()
os.remove(todel)
self.metrics_queue.append(metrics_filename)
if epoch_idx == 'best':
self.copy_best_predictions(split)
def copy_best_results(self, split, epoch_idx):
to_print = 'Copy '
if not self.config['skip_saving_ckpt']:
ckpt_path = osp.join(
self.config['log_dir'], f'epoch_{epoch_idx}.ckpt')
best_ckpt_path = ckpt_path.replace(
f'{epoch_idx}.ckpt', 'best.ckpt')
shutil.copyfile(ckpt_path, best_ckpt_path)
to_print += best_ckpt_path + ' '
metrics_filename = osp.join(
self.config['log_dir'], f'metrics_epoch_{epoch_idx}.json')
best_metric_filename = metrics_filename.replace(
f'{epoch_idx}.json', 'best.json')
shutil.copyfile(metrics_filename, best_metric_filename)
to_print += best_metric_filename + ' '
log.info(to_print)
def copy_best_predictions(self, split):
to_print = 'Copy '
visdial_output_dir = osp.join(self.config['visdial_output_dir'], split)
if osp.exists(visdial_output_dir):
dir_best = visdial_output_dir.replace('output', 'output_best')
if osp.exists(dir_best):
shutil.rmtree(dir_best)
shutil.copytree(visdial_output_dir, dir_best)
to_print += dir_best + ' '
log.info(to_print)
def get_ckpt(self):
ckpt = {
'epoch_idx': self.epoch_idx,
'max_metric': self.max_metric,
'seed': self.config['random_seed'],
'optimizer': self.optimizer.state_dict(),
'scheduler': self.scheduler.state_dict()
}
if self.config['parallel']:
ckpt['model_state_dict'] = self.model.module.state_dict()
else:
ckpt['model_state_dict'] = self.model.state_dict()
if self.config['dp_type'] == 'apex':
ckpt['amp'] = amp.state_dict()
return ckpt
def set_ckpt(self, ckpt_dict):
if not self.config['restarts']:
self.epoch_idx = ckpt_dict.get('epoch_idx', -1) + 1
if not self.config['resets_max_metric']:
self.max_metric = ckpt_dict.get('max_metric', -1)
if self.config['parallel']:
model = self.model.module
else:
model = self.model
model_state_dict = model.state_dict()
former_dict = {
k: v for k, v in ckpt_dict['model_state_dict'].items() if k in model_state_dict}
if self.config['display']:
log.info("number of keys transferred: %d" % len(former_dict))
assert len(former_dict.keys()) > 0
model_state_dict.update(former_dict)
model.load_state_dict(model_state_dict)
if self.config['display']:
log.info('loaded model')
del model_state_dict, former_dict
if not self.config['validating'] and not (self.config['uses_new_optimizer'] or self.config['sets_new_lr']):
if 'optimizer' in ckpt_dict:
self.optimizer.load_state_dict(ckpt_dict['optimizer'])
if self.config['display']:
log.info('loaded optimizer')
if 'scheduler' in ckpt_dict:
self.scheduler.last_epoch = ckpt_dict['epoch_idx'] * \
self.config['num_iter_per_epoch']
self.scheduler.load_state_dict(ckpt_dict['scheduler'])
if 'amp' in ckpt_dict and self.config['dp_type'] == 'apex':
amp.load_state_dict(ckpt_dict['amp'])
del ckpt_dict
torch.cuda.empty_cache()
def save_ckpt(self):
ckpt_path = f'{self.config["log_dir"]}/epoch_{self.epoch_idx}.ckpt'
log.info(f'saving checkpoint {ckpt_path}')
ckpt = self.get_ckpt()
if self.config['skip_saving_ckpt']:
return ckpt_path
torch_version = float(torch.__version__[:3])
if torch_version - 1.4 > 1e-3:
torch.save(ckpt, f=ckpt_path, _use_new_zipfile_serialization=False)
else:
torch.save(ckpt, f=ckpt_path)
del ckpt
if not (self.config['parallel'] and self.config['dp_type'] in ['ddp', 'apex']):
torch.cuda.empty_cache()
if self.config["max_ckpt_to_keep"] > 0:
if len(self.checkpoint_queue) == self.checkpoint_queue.maxlen:
todel = self.checkpoint_queue.popleft()
os.remove(todel)
self.checkpoint_queue.append(ckpt_path)
return ckpt_path
def save_ckpt_best(self):
ckpt_path = f'{self.config["log_dir"]}/epoch_best.ckpt'
log.info(f'saving checkpoint {ckpt_path}')
ckpt = self.get_ckpt()
torch.save(ckpt, f=ckpt_path)
del ckpt
return ckpt_path
def load_ckpt_best(self):
ckpt_path = f'{osp.dirname(self.config["log_dir"])}/epoch_best.ckpt'
if not osp.exists(ckpt_path):
ckpt_paths = [path for path in os.listdir(
f'{self.config["log_dir"]}/') if path.endswith('.ckpt') and 'best' not in path]
if len(ckpt_paths) == 0:
if self.config['display']:
log.info(f'No .ckpt found in {self.config["log_dir"]}')
return
def sort_func(x): return int(re.search(r"(\d+)", x).groups()[0])
ckpt_path = f'{self.config["log_dir"]}/{sorted(ckpt_paths, key=sort_func, reverse=True)[0]}'
if self.config['display']:
log.info(f'loading checkpoint {ckpt_path}')
map_location = {'cuda:0': f'cuda:{self.gpu_rank}'}
self.set_ckpt(torch.load(ckpt_path, map_location=map_location))
def load_ckpt(self, ckpt_path=None):
if not ckpt_path:
if self.config['validating'] or self.config['loads_best_ckpt']:
ckpt_path = f'{self.config["log_dir"]}/epoch_best.ckpt'
else:
ckpt_paths = [path for path in os.listdir(
f'{self.config["log_dir"]}/') if path.endswith('.ckpt') and 'best' not in path]
if len(ckpt_paths) == 0:
if self.config['display']:
log.info(f'No .ckpt found in {self.config["log_dir"]}')
return
def sort_func(x): return int(
re.search(r"(\d+)", x).groups()[0])
ckpt_path = f'{self.config["log_dir"]}/{sorted(ckpt_paths, key=sort_func, reverse=True)[0]}'
if self.config['display']:
log.info(f'loading checkpoint {ckpt_path}')
epoch_name = osp.split(ckpt_path)[1].split('.')[0]
if re.search(r"(\d+)", epoch_name):
self.checkpoint_queue.append(ckpt_path)
metrics_filename = osp.join(
self.config['log_dir'], f'metrics_{epoch_name}.json')
if osp.exists(metrics_filename):
self.metrics_queue.append(metrics_filename)
map_location = {'cuda:0': f'cuda:{self.gpu_rank}'}
self.set_ckpt(torch.load(ckpt_path, map_location=map_location))
def match_model_key(self, pretrained_dict, model_dict):
matched_dict = dict()
for key in pretrained_dict:
if key in model_dict:
matched_key = key
elif key.startswith('encoder.') and key[8:] in model_dict:
matched_key = key[8:]
elif key.startswith('module.') and key[7:] in model_dict:
matched_key = key[7:]
elif 'encoder.' + key in model_dict:
matched_key = 'encoder.' + key
elif 'module.' + key in model_dict:
matched_key = 'module.' + key
else:
# not_found.append(key)
continue
matched_dict[matched_key] = pretrained_dict[key]
not_found = ""
for k in model_dict:
if k not in matched_dict:
not_found += k + '\n'
log.info("Keys from model_dict that were not found in pretrained_dict:")
log.info(not_found)
return matched_dict
def load_pretrained_vilbert(self, start_from=None):
if start_from is not None:
self.config["start_path"] = start_from
if self.config['training'] or self.config['debugging']:
ckpt_paths = [path for path in os.listdir(
f'{self.config["log_dir"]}/') if path.endswith('.ckpt') and 'best' not in path]
if len(ckpt_paths) > 0:
if self.config['display']:
log.info('Continue training')
return
if self.config['display']:
log.info(
f'Loading pretrained VilBERT from {self.config["start_path"]}')
map_location = {'cuda:0': f'cuda:{self.gpu_rank}'}
pretrained_dict = torch.load(
self.config['start_path'], map_location=map_location)
if 'model_state_dict' in pretrained_dict:
pretrained_dict = pretrained_dict['model_state_dict']
if self.config['parallel']:
model = self.model.module
else:
model = self.model
model_dict = model.state_dict()
matched_dict = self.match_model_key(pretrained_dict, model_dict)
if self.config['display']:
log.info("number of keys transferred: %d" % len(matched_dict))
assert len(matched_dict.keys()) > 0
model_dict.update(matched_dict)
model.load_state_dict(model_dict)
del pretrained_dict, model_dict, matched_dict
if not self.config['parallel'] or self.config['dp_type'] == 'dp':
torch.cuda.empty_cache()
if self.config['display']:
log.info(f'Pretrained VilBERT loaded')
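`match_model_key` above absorbs the usual prefix mismatches ('module.' from DataParallel, 'encoder.' from the pretraining wrapper) when transferring weights. Below is a minimal illustration of the stripping part of that idea, with made-up keys; the real helper also tries adding the prefixes, not only removing them.

```python
pretrained = {'module.bert.embeddings.weight': 0, 'encoder.pooler.dense.bias': 1}
model_keys = {'bert.embeddings.weight', 'pooler.dense.bias'}

matched = {}
for key, value in pretrained.items():
    # try the key itself, then the key with a 'module.' or 'encoder.' prefix stripped
    for cand in (key, key[len('module.'):], key[len('encoder.'):]):
        if cand in model_keys:
            matched[cand] = value
            break
print(matched)  # {'bert.embeddings.weight': 0, 'pooler.dense.bias': 1}
```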

379
models/vdgr.py Normal file
View file

@ -0,0 +1,379 @@
import sys
from collections import OrderedDict
import torch
from torch import nn
import torch.nn.functional as F
sys.path.append('../')
from utils.model_utils import listMLE, approxNDCGLoss, listNet, neuralNDCG, neuralNDCG_transposed
from utils.data_utils import sequence_mask
from utils.optim_utils import init_optim
from models.runner import Runner
from models.vilbert_dialog import BertForMultiModalPreTraining, BertConfig
class VDGR(nn.Module):
def __init__(self, config_path, device, use_apex=False, cache_dir=None):
super(VDGR, self).__init__()
config = BertConfig.from_json_file(config_path)
self.bert_pretrained = BertForMultiModalPreTraining.from_pretrained('bert-base-uncased', config, device, use_apex=use_apex, cache_dir=cache_dir)
self.bert_pretrained.train()
def forward(self, input_ids, image_feat, image_loc, image_edge_indices, image_edge_attributes,
question_edge_indices, question_edge_attributes, question_limits,
history_edge_indices, history_sep_indices,
sep_indices=None, sep_len=None, token_type_ids=None,
attention_mask=None, masked_lm_labels=None, next_sentence_label=None,
image_attention_mask=None, image_label=None, image_target=None):
masked_lm_loss = None
masked_img_loss = None
nsp_loss = None
seq_relationship_score = None
if next_sentence_label is not None and masked_lm_labels \
is not None and image_target is not None:
# train mode, output losses
masked_lm_loss, masked_img_loss, nsp_loss, _, _, seq_relationship_score, _ = \
self.bert_pretrained(
input_ids, image_feat, image_loc, image_edge_indices, image_edge_attributes,
question_edge_indices, question_edge_attributes, question_limits,
history_edge_indices, history_sep_indices, sep_indices=sep_indices, sep_len=sep_len, \
token_type_ids=token_type_ids, attention_mask=attention_mask, masked_lm_labels=masked_lm_labels, \
next_sentence_label=next_sentence_label, image_attention_mask=image_attention_mask,\
image_label=image_label, image_target=image_target)
else:
# inference, output scores
_, _, seq_relationship_score, _, _, _ = \
self.bert_pretrained(
input_ids, image_feat, image_loc, image_edge_indices, image_edge_attributes,
question_edge_indices, question_edge_attributes, question_limits,
history_edge_indices, history_sep_indices,
sep_indices=sep_indices, sep_len=sep_len, \
token_type_ids=token_type_ids, attention_mask=attention_mask, masked_lm_labels=masked_lm_labels, \
next_sentence_label=next_sentence_label, image_attention_mask=image_attention_mask,\
image_label=image_label, image_target=image_target)
out = (masked_lm_loss, masked_img_loss, nsp_loss, seq_relationship_score)
return out
class SparseRunner(Runner):
def __init__(self, config):
super(SparseRunner, self).__init__(config)
self.model = VDGR(
self.config['model_config'], self.config['device'],
use_apex=self.config['dp_type'] == 'apex',
cache_dir=self.config['bert_cache_dir'])
self.model.to(self.config['device'])
if not self.config['validating'] or self.config['dp_type'] == 'apex':
self.optimizer, self.scheduler = init_optim(self.model, self.config)
def forward(self, batch, eval_visdial=False):
# load data
for key in batch:
if isinstance(batch[key], torch.Tensor):
batch[key] = batch[key].to(self.config['device'])
elif isinstance(batch[key], list):
if key != 'dialog_info': # Do not send the dialog_info item to the gpu
batch[key] = [x.to(self.config['device']) for x in batch[key]]
tokens = batch['tokens']
segments = batch['segments']
sep_indices = batch['sep_indices']
mask = batch['mask']
hist_len = batch['hist_len']
image_feat = batch['image_feat']
image_loc = batch['image_loc']
image_mask = batch['image_mask']
next_sentence_labels = batch.get('next_sentence_labels', None)
image_target = batch.get('image_target', None)
image_label = batch.get('image_label', None)
# load the graph data
image_edge_indices = batch['image_edge_indices']
image_edge_attributes = batch['image_edge_attributes']
question_edge_indices = batch['question_edge_indices']
question_edge_attributes = batch['question_edge_attributes']
question_limits = batch['question_limits']
history_edge_indices = batch['history_edge_indices']
history_sep_indices = batch['history_sep_indices']
sequence_lengths = torch.gather(sep_indices, 1, hist_len.view(-1, 1)) + 1
sequence_lengths = sequence_lengths.squeeze(1)
attention_mask_lm_nsp = sequence_mask(sequence_lengths, max_len=tokens.shape[1])
sep_len = hist_len + 1
losses = OrderedDict()
if eval_visdial:
num_lines = tokens.size(0)
line_batch_size = self.config['eval_line_batch_size']
num_line_batches = num_lines // line_batch_size
if num_lines % line_batch_size > 0:
num_line_batches += 1
nsp_scores = []
for j in range(num_line_batches):
# create chunks of the original batch
chunk_range = range(j*line_batch_size, min((j+1)*line_batch_size, num_lines))
tokens_chunk = tokens[chunk_range]
segments_chunk = segments[chunk_range]
sep_indices_chunk = sep_indices[chunk_range]
mask_chunk = mask[chunk_range]
sep_len_chunk = sep_len[chunk_range]
attention_mask_lm_nsp_chunk = attention_mask_lm_nsp[chunk_range]
image_feat_chunk = image_feat[chunk_range]
image_loc_chunk = image_loc[chunk_range]
image_mask_chunk = image_mask[chunk_range]
image_edge_indices_chunk = image_edge_indices[chunk_range[0]: chunk_range[-1]+1]
image_edge_attributes_chunk = image_edge_attributes[chunk_range[0]: chunk_range[-1]+1]
question_edge_indices_chunk = question_edge_indices[chunk_range[0]: chunk_range[-1]+1]
question_edge_attributes_chunk = question_edge_attributes[chunk_range[0]: chunk_range[-1]+1]
question_limits_chunk = question_limits[chunk_range[0]: chunk_range[-1]+1]
history_edge_indices_chunk = history_edge_indices[chunk_range[0]: chunk_range[-1]+1]
history_sep_indices_chunk = history_sep_indices[chunk_range[0]: chunk_range[-1]+1]
_ , _ , _, nsp_scores_chunk = \
self.model(
tokens_chunk,
image_feat_chunk,
image_loc_chunk,
image_edge_indices_chunk,
image_edge_attributes_chunk,
question_edge_indices_chunk,
question_edge_attributes_chunk,
question_limits_chunk,
history_edge_indices_chunk,
history_sep_indices_chunk,
sep_indices=sep_indices_chunk,
sep_len=sep_len_chunk,
token_type_ids=segments_chunk,
masked_lm_labels=mask_chunk,
attention_mask=attention_mask_lm_nsp_chunk,
image_attention_mask=image_mask_chunk
)
nsp_scores.append(nsp_scores_chunk)
nsp_scores = torch.cat(nsp_scores, 0)
else:
losses['lm_loss'], losses['img_loss'], losses['nsp_loss'], nsp_scores = \
self.model(
tokens,
image_feat,
image_loc,
image_edge_indices,
image_edge_attributes,
question_edge_indices,
question_edge_attributes,
question_limits,
history_edge_indices,
history_sep_indices,
next_sentence_label=next_sentence_labels,
image_target=image_target,
image_label=image_label,
sep_indices=sep_indices,
sep_len=sep_len,
token_type_ids=segments,
masked_lm_labels=mask,
attention_mask=attention_mask_lm_nsp,
image_attention_mask=image_mask
)
losses['tot_loss'] = 0
for key in ['lm_loss', 'img_loss', 'nsp_loss']:
if key in losses and losses[key] is not None:
losses[key] = losses[key].mean()
losses['tot_loss'] += self.config[f'{key}_coeff'] * losses[key]
output = {
'losses': losses,
'nsp_scores': nsp_scores
}
return output
class DenseRunner(Runner):
def __init__(self, config):
super(DenseRunner, self).__init__(config)
self.model = VDGR(
self.config['model_config'], self.config['device'],
use_apex=self.config['dp_type'] == 'apex',
cache_dir=self.config['bert_cache_dir'])
if not(self.config['parallel'] and self.config['dp_type'] == 'dp'):
self.model.to(self.config['device'])
if self.config['dense_loss'] == 'ce':
self.dense_loss = nn.KLDivLoss(reduction='batchmean')
elif self.config['dense_loss'] == 'listmle':
self.dense_loss = listMLE
elif self.config['dense_loss'] == 'listnet':
self.dense_loss = listNet
elif self.config['dense_loss'] == 'approxndcg':
self.dense_loss = approxNDCGLoss
elif self.config['dense_loss'] == 'neural_ndcg':
self.dense_loss = neuralNDCG
elif self.config['dense_loss'] == 'neural_ndcg_transposed':
self.dense_loss = neuralNDCG_transposed
else:
raise ValueError('dense_loss must be one of ce, listmle, listnet, approxndcg, neural_ndcg, neural_ndcg_transposed')
if not self.config['validating'] or self.config['dp_type'] == 'apex':
self.optimizer, self.scheduler = init_optim(self.model, self.config)
def forward(self, batch, eval_visdial=False):
# load data
for key in batch:
if isinstance(batch[key], torch.Tensor):
batch[key] = batch[key].to(self.config['device'])
elif isinstance(batch[key], list):
if key != 'dialog_info': # Do not send the dialog_info item to the gpu
batch[key] = [x.to(self.config['device']) for x in batch[key]]
# get embedding and forward visdial
tokens = batch['tokens']
segments = batch['segments']
sep_indices = batch['sep_indices']
mask = batch['mask']
hist_len = batch['hist_len']
image_feat = batch['image_feat']
image_loc = batch['image_loc']
image_mask = batch['image_mask']
next_sentence_labels = batch.get('next_sentence_labels', None)
image_target = batch.get('image_target', None)
image_label = batch.get('image_label', None)
# load the graph data
image_edge_indices = batch['image_edge_indices']
image_edge_attributes = batch['image_edge_attributes']
question_edge_indices = batch['question_edge_indices']
question_edge_attributes = batch['question_edge_attributes']
question_limits = batch['question_limits']
history_edge_indices = batch['history_edge_indices']
assert history_edge_indices[0].size(0) == 2
history_sep_indices = batch['history_sep_indices']
sequence_lengths = torch.gather(sep_indices, 1, hist_len.view(-1, 1)) + 1
sequence_lengths = sequence_lengths.squeeze(1)
attention_mask_lm_nsp = sequence_mask(sequence_lengths, max_len=tokens.shape[1])
sep_len = hist_len + 1
losses = OrderedDict()
if eval_visdial:
num_lines = tokens.size(0)
line_batch_size = self.config['eval_line_batch_size']
num_line_batches = num_lines // line_batch_size
if num_lines % line_batch_size > 0:
num_line_batches += 1
nsp_scores = []
for j in range(num_line_batches):
# create chunks of the original batch
chunk_range = range(j*line_batch_size, min((j+1)*line_batch_size, num_lines))
tokens_chunk = tokens[chunk_range]
segments_chunk = segments[chunk_range]
sep_indices_chunk = sep_indices[chunk_range]
mask_chunk = mask[chunk_range]
sep_len_chunk = sep_len[chunk_range]
attention_mask_lm_nsp_chunk = attention_mask_lm_nsp[chunk_range]
image_feat_chunk = image_feat[chunk_range]
image_loc_chunk = image_loc[chunk_range]
image_mask_chunk = image_mask[chunk_range]
image_edge_indices_chunk = image_edge_indices[chunk_range[0]: chunk_range[-1]+1]
image_edge_attributes_chunk = image_edge_attributes[chunk_range[0]: chunk_range[-1]+1]
question_edge_indices_chunk = question_edge_indices[chunk_range[0]: chunk_range[-1]+1]
question_edge_attributes_chunk = question_edge_attributes[chunk_range[0]: chunk_range[-1]+1]
question_limits_chunk = question_limits[chunk_range[0]: chunk_range[-1]+1]
history_edge_indices_chunk = history_edge_indices[chunk_range[0]: chunk_range[-1]+1]
history_sep_indices_chunk = history_sep_indices[chunk_range[0]: chunk_range[-1]+1]
_, _, _, nsp_scores_chunk = \
self.model(
tokens_chunk,
image_feat_chunk,
image_loc_chunk,
image_edge_indices_chunk,
image_edge_attributes_chunk,
question_edge_indices_chunk,
question_edge_attributes_chunk,
question_limits_chunk,
history_edge_indices_chunk,
history_sep_indices_chunk,
sep_indices=sep_indices_chunk,
sep_len=sep_len_chunk,
token_type_ids=segments_chunk,
masked_lm_labels=mask_chunk,
attention_mask=attention_mask_lm_nsp_chunk,
image_attention_mask=image_mask_chunk
)
nsp_scores.append(nsp_scores_chunk)
nsp_scores = torch.cat(nsp_scores, 0)
else:
_, _, _, nsp_scores = \
self.model(
tokens,
image_feat,
image_loc,
image_edge_indices,
image_edge_attributes,
question_edge_indices,
question_edge_attributes,
question_limits,
history_edge_indices,
history_sep_indices,
next_sentence_label=next_sentence_labels,
image_target=image_target,
image_label=image_label,
sep_indices=sep_indices,
sep_len=sep_len,
token_type_ids=segments,
masked_lm_labels=mask,
attention_mask=attention_mask_lm_nsp,
image_attention_mask=image_mask
)
if nsp_scores is not None:
nsp_scores_output = nsp_scores.detach().clone()
if not eval_visdial:
nsp_scores = nsp_scores.view(-1, self.config['num_options_dense'], 2)
if 'next_sentence_labels' in batch and self.config['nsp_loss_coeff'] > 0:
next_sentence_labels = batch['next_sentence_labels'].to(self.config['device'])
losses['nsp_loss'] = F.cross_entropy(nsp_scores.view(-1,2), next_sentence_labels.view(-1))
else:
losses['nsp_loss'] = None
if not eval_visdial:
gt_relevance = batch['gt_relevance'].to(self.config['device'])
nsp_scores = nsp_scores[:, :, 0]
if self.config['dense_loss'] == 'ce':
losses['dense_loss'] = self.dense_loss(F.log_softmax(nsp_scores, dim=1), F.softmax(gt_relevance, dim=1))
else:
losses['dense_loss'] = self.dense_loss(nsp_scores, gt_relevance)
else:
losses['dense_loss'] = None
else:
nsp_scores_output = None
losses['nsp_loss'] = None
losses['dense_loss'] = None
losses['tot_loss'] = 0
for key in ['nsp_loss', 'dense_loss']:
if key in losses and losses[key] is not None:
losses[key] = losses[key].mean()
losses['tot_loss'] += self.config[f'{key}_coeff'] * losses[key]
output = {
'losses': losses,
'nsp_scores': nsp_scores_output
}
return output

2021
models/vilbert_dialog.py Normal file

File diff suppressed because it is too large

18
setup_data.sh Normal file
View file

@ -0,0 +1,18 @@
cd data
# Extract the graphs
tar xvfz history_adj_matrices.tar.gz
tar xvfz question_adj_matrices.tar.gz
tar xvfz img_adj_matrices.tar.gz
# Remove the downloaded archives
rm *.tar.gz
# Download the preprocessed image features
mkdir visdial_img_feat.lmdb
wget https://s3.amazonaws.com/visdial-bert/data/visdial_image_feats.lmdb/data.mdb -O visdial_img_feat.lmdb/data.mdb
wget https://s3.amazonaws.com/visdial-bert/data/visdial_image_feats.lmdb/lock.mdb -O visdial_img_feat.lmdb/lock.mdb
echo Data setup completed successfully...
cd ..

0
utils/__init__.py Normal file
View file

290
utils/data_utils.py Normal file
View file

@ -0,0 +1,290 @@
import torch
from torch.autograd import Variable
import random
import pickle
import numpy as np
from copy import deepcopy
def load_pickle_lines(filename):
data = []
with open(filename, 'rb') as f:
while True:
try:
data.append(pickle.load(f))
except EOFError:
break
return data
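# Note: files read by this helper are written by repeatedly pickling one record into a
# file opened in append-binary mode (as models/runner.py does for visdial_prediction.pkl),
# e.g. (illustrative): with open('preds.pkl', 'ab') as f: pickle.dump(record, f)
# Reading therefore loops until pickle raises EOFError.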
def flatten(l):
return [item for sublist in l for item in sublist]
def build_len_mask_batch(
# [batch_size], []
len_batch, max_len=None
):
if max_len is None:
max_len = len_batch.max().item()
# try:
batch_size, = len_batch.shape
# [batch_size, max_len]
idxes_batch = torch.arange(max_len, device=len_batch.device).view(1, -1).repeat(batch_size, 1)
# [batch_size, max_len] = [batch_size, max_len] < [batch_size, 1]
return idxes_batch < len_batch.view(-1, 1)
def sequence_mask(sequence_length, max_len=None):
if max_len is None:
max_len = sequence_length.data.max()
batch_size = sequence_length.size(0)
seq_range = torch.arange(0, max_len).long()
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
seq_range_expand = Variable(seq_range_expand)
if sequence_length.is_cuda:
seq_range_expand = seq_range_expand.to(sequence_length.device)
seq_length_expand = (sequence_length.unsqueeze(1)
.expand_as(seq_range_expand))
return seq_range_expand < seq_length_expand
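# Usage sketch (illustrative values):
#   sequence_mask(torch.tensor([2, 4, 1]), max_len=5)
#   -> tensor([[ True,  True, False, False, False],
#              [ True,  True,  True,  True, False],
#              [ True, False, False, False, False]])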
def batch_iter(dataloader, params):
for epochId in range(params['num_epochs']):
for idx, batch in enumerate(dataloader):
yield epochId, idx, batch
def list2tensorpad(inp_list, max_seq_len):
inp_tensor = torch.LongTensor([inp_list])
inp_tensor_zeros = torch.zeros(1, max_seq_len, dtype=torch.long)
inp_tensor_zeros[0,:inp_tensor.shape[1]] = inp_tensor # after preprocessing, inp_tensor.shape[1] must be <= max_seq_len
inp_tensor = inp_tensor_zeros
return inp_tensor
def encode_input(utterances, start_segment, CLS, SEP, MASK, max_seq_len=256,max_sep_len=25,mask_prob=0.2):
cur_segment = start_segment
token_id_list = []
segment_id_list = []
sep_token_indices = []
masked_token_list = []
token_id_list.append(CLS)
segment_id_list.append(cur_segment)
masked_token_list.append(0)
cur_sep_token_index = 0
for cur_utterance in utterances:
# add the masked token and keep track
cur_masked_index = [1 if random.random() < mask_prob else 0 for _ in range(len(cur_utterance))]
masked_token_list.extend(cur_masked_index)
token_id_list.extend(cur_utterance)
segment_id_list.extend([cur_segment]*len(cur_utterance))
token_id_list.append(SEP)
segment_id_list.append(cur_segment)
masked_token_list.append(0)
cur_sep_token_index = cur_sep_token_index + len(cur_utterance) + 1
sep_token_indices.append(cur_sep_token_index)
cur_segment = cur_segment ^ 1 # the segment id oscillates between 0 and 1
start_question, end_question = sep_token_indices[-3] + 1, sep_token_indices[-2]
assert end_question - start_question == len(utterances[-2])
assert len(segment_id_list) == len(token_id_list) == len(masked_token_list) == sep_token_indices[-1] + 1
# convert to tensors and pad to maximum seq length
tokens = list2tensorpad(token_id_list,max_seq_len) # [1, max_len]
masked_tokens = list2tensorpad(masked_token_list,max_seq_len)
masked_tokens[0,masked_tokens[0,:]==0] = -1
mask = masked_tokens[0,:]==1
masked_tokens[0,mask] = tokens[0,mask]
tokens[0,mask] = MASK
segment_id_list = list2tensorpad(segment_id_list,max_seq_len)
return tokens, segment_id_list, list2tensorpad(sep_token_indices,max_sep_len), masked_tokens, start_question, end_question
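# Note on the expected input (as constructed by the dense dataloader in this release):
# `utterances` is a list of token-id lists ordered [caption, q_1, a_1, ..., q_t, candidate_answer],
# so utterances[-2] is the current question; start_question/end_question mark its span in the
# flattened token sequence, which is what the assertion above checks.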
def encode_input_with_mask(utterances, start_segment, CLS, SEP, MASK, max_seq_len=256,max_sep_len=25,mask_prob=0.2, get_q_limits=True):
cur_segment = start_segment
token_id_list = []
segment_id_list = []
sep_token_indices = []
masked_token_list = []
input_mask_list = []
token_id_list.append(CLS)
segment_id_list.append(cur_segment)
masked_token_list.append(0)
input_mask_list.append(1)
cur_sep_token_index = 0
for cur_utterance in utterances:
# add the masked token and keep track
cur_masked_index = [1 if random.random() < mask_prob else 0 for _ in range(len(cur_utterance))]
masked_token_list.extend(cur_masked_index)
token_id_list.extend(cur_utterance)
segment_id_list.extend([cur_segment]*len(cur_utterance))
input_mask_list.extend([1]*len(cur_utterance))
token_id_list.append(SEP)
segment_id_list.append(cur_segment)
masked_token_list.append(0)
input_mask_list.append(1)
cur_sep_token_index = cur_sep_token_index + len(cur_utterance) + 1
sep_token_indices.append(cur_sep_token_index)
        cur_segment = cur_segment ^ 1  # the segment id oscillates between 0 and 1
if get_q_limits:
start_question, end_question = sep_token_indices[-3] + 1, sep_token_indices[-2]
assert end_question - start_question == len(utterances[-2])
else:
start_question, end_question = -1, -1
assert len(segment_id_list) == len(token_id_list) == len(masked_token_list) ==len(input_mask_list) == sep_token_indices[-1] + 1
# convert to tensors and pad to maximum seq length
tokens = list2tensorpad(token_id_list, max_seq_len)
masked_tokens = list2tensorpad(masked_token_list, max_seq_len)
input_mask = list2tensorpad(input_mask_list,max_seq_len)
masked_tokens[0,masked_tokens[0,:]==0] = -1
mask = masked_tokens[0,:]==1
masked_tokens[0,mask] = tokens[0,mask]
tokens[0,mask] = MASK
segment_id_list = list2tensorpad(segment_id_list,max_seq_len)
return tokens, segment_id_list, list2tensorpad(sep_token_indices,max_sep_len),masked_tokens, input_mask, start_question, end_question
def encode_image_input(features, num_boxes, boxes, image_target, max_regions=37, mask_prob=0.15):
output_label = []
num_boxes = min(int(num_boxes), max_regions)
mix_boxes_pad = np.zeros((max_regions, boxes.shape[-1]))
mix_features_pad = np.zeros((max_regions, features.shape[-1]))
mix_image_target = np.zeros((max_regions, image_target.shape[-1]))
mix_boxes_pad[:num_boxes] = boxes[:num_boxes]
mix_features_pad[:num_boxes] = features[:num_boxes]
mix_image_target[:num_boxes] = image_target[:num_boxes]
boxes = mix_boxes_pad
features = mix_features_pad
image_target = mix_image_target
mask_indexes = []
for i in range(num_boxes):
prob = random.random()
        # mask a region with probability mask_prob (15% by default)
        if prob < mask_prob:
            prob /= mask_prob
            # in 90% of the masked cases, zero out the region features
            if prob < 0.9:
features[i] = 0
output_label.append(1)
mask_indexes.append(i)
else:
            # not masked; the -1 label will be ignored by the loss function later
output_label.append(-1)
image_mask = [1] * (int(num_boxes))
while len(image_mask) < max_regions:
image_mask.append(0)
output_label.append(-1)
    # ensure that at least one region is being predicted
output_label[random.randint(1,len(output_label)-1)] = 1
image_label = torch.LongTensor(output_label)
image_label[0] = 0 # make sure the <IMG> token doesn't contribute to the masked loss
image_mask = torch.tensor(image_mask).float()
features = torch.tensor(features).float()
spatials = torch.tensor(boxes).float()
image_target = torch.tensor(image_target).float()
return features, spatials, image_mask, image_target, image_label
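# Minimal sketch of encode_image_input on random dummy region features; the
# 2048-d features and 1601-way class targets are illustrative shapes only.
def _encode_image_input_example():
    import numpy as np
    num_boxes = 5
    features = np.random.rand(num_boxes, 2048).astype(np.float32)
    boxes = np.random.rand(num_boxes, 5).astype(np.float32)
    image_target = np.random.rand(num_boxes, 1601).astype(np.float32)
    feats, spatials, image_mask, target, image_label = encode_image_input(
        features, num_boxes, boxes, image_target, max_regions=37, mask_prob=0.15)
    # feats/spatials/target are zero-padded to 37 regions; image_mask marks the
    # real regions and image_label holds the masked-region prediction targets.
    return feats, spatials, image_mask, target, image_label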
def question_edge_masking(question_edge_indices, question_edge_attributes, mask, question_limits, mask_prob=0.4, max_len=10):
mask = mask.squeeze().tolist()
question_limits = question_limits.tolist()
question_start, question_end = question_limits
# Get the masking of the question
mask_question = mask[question_start:question_end]
masked_idx = np.argwhere(np.array(mask_question) > -1).squeeze().tolist()
    if isinstance(masked_idx, int):  # only one question token is masked
masked_idx = [masked_idx]
    # get rid of all edge indices and attributes that correspond to masked tokens
edge_attr_gt = []
edge_idx_gt_gnn = []
edge_idx_gt_bert = []
for i, (question_edge_idx, question_edge_attr) in enumerate(zip(question_edge_indices, question_edge_attributes)):
if not(question_edge_idx[0] in masked_idx or question_edge_idx[1] in masked_idx):
# Masking
if random.random() < mask_prob:
edge_attr_gt.append(np.argwhere(question_edge_attr).item())
edge_idx_gt_gnn.append(question_edge_idx)
edge_idx_gt_bert.append([question_edge_idx[0] + question_start, question_edge_idx[1] + question_start])
question_edge_attr = np.zeros_like(question_edge_attr)
question_edge_attr[-1] = 1.0 # The [EDGE_MASK] special token is the last one hot vector encoding
question_edge_attributes[i] = question_edge_attr
else:
continue
    # Force masking if necessary:
if len(edge_attr_gt) == 0:
for i, (question_edge_idx, question_edge_attr) in enumerate(zip(question_edge_indices, question_edge_attributes)):
if not(question_edge_idx[0] in masked_idx or question_edge_idx[1] in masked_idx):
# Masking
edge_attr_gt.append(np.argwhere(question_edge_attr).item())
edge_idx_gt_gnn.append(question_edge_idx)
edge_idx_gt_bert.append([question_edge_idx[0] + question_start, question_edge_idx[1] + question_start])
question_edge_attr = np.zeros_like(question_edge_attr)
question_edge_attr[-1] = 1.0 # The [EDGE_MASK] special token is the last one hot vector encoding
question_edge_attributes[i] = question_edge_attr
break
    # For the rare case where the conditions for masking were not met
if len(edge_attr_gt) == 0:
edge_attr_gt.append(-1)
edge_idx_gt_gnn.append([0, question_end - question_start])
edge_idx_gt_bert.append(question_limits)
# Pad to max_len
while len(edge_attr_gt) < max_len:
edge_attr_gt.append(-1)
edge_idx_gt_gnn.append(edge_idx_gt_gnn[-1])
edge_idx_gt_bert.append(edge_idx_gt_bert[-1])
# Truncate if longer than max_len
if len(edge_attr_gt) > max_len:
edge_idx_gt_gnn = edge_idx_gt_gnn[:max_len]
edge_idx_gt_bert = edge_idx_gt_bert[:max_len]
edge_attr_gt = edge_attr_gt[:max_len]
edge_idx_gt_gnn = np.array(edge_idx_gt_gnn)
edge_idx_gt_bert = np.array(edge_idx_gt_bert)
first_edge_node_gt_gnn = list(edge_idx_gt_gnn[:, 0])
second_edge_node_gt_gnn = list(edge_idx_gt_gnn[:, 1])
first_edge_node_gt_bert = list(edge_idx_gt_bert[:, 0])
second_edge_node_gt_bert = list(edge_idx_gt_bert[:, 1])
return question_edge_attributes, edge_attr_gt, first_edge_node_gt_gnn, second_edge_node_gt_gnn, first_edge_node_gt_bert, second_edge_node_gt_bert
def to_data_list(feats, batch_idx):
feat_list = []
device = feats.device
left = 0
right = 0
batch_size = batch_idx.max().item() + 1
for batch in range(batch_size):
if batch == batch_size - 1:
right = batch_idx.size(0)
else:
right = torch.argwhere(batch_idx == batch + 1)[0].item()
idx = torch.arange(left, right).unsqueeze(-1).repeat(1, feats.size(1)).to(device)
feat_list.append(torch.gather(feats, 0, idx))
left = right
return feat_list
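# Minimal sketch of to_data_list: split node features stacked over a batch back
# into per-sample tensors using a (sorted) batch index vector.
def _to_data_list_example():
    import torch
    feats = torch.arange(10, dtype=torch.float).view(5, 2)  # 5 nodes, 2 dims
    batch_idx = torch.tensor([0, 0, 0, 1, 1])  # first 3 nodes belong to sample 0
    feat_list = to_data_list(feats, batch_idx)
    # feat_list[0] has shape [3, 2], feat_list[1] has shape [2, 2]
    return feat_list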

View file

@ -0,0 +1,192 @@
from typing import List
import csv
import h5py
import numpy as np
import copy
import pickle
import lmdb # install lmdb by "pip install lmdb"
import base64
import pdb
import os
class ImageFeaturesH5Reader(object):
"""
A reader for H5 files containing pre-extracted image features. A typical
H5 file is expected to have a column named "image_id", and another column
named "features".
Example of an H5 file:
```
faster_rcnn_bottomup_features.h5
|--- "image_id" [shape: (num_images, )]
|--- "features" [shape: (num_images, num_proposals, feature_size)]
+--- .attrs ("split", "train")
```
Parameters
----------
    features_path : str
        Path to the LMDB file containing the pre-extracted COCO train / val image features.
    scene_graph_path : str
        Path to the directory holding the per-image scene-graph pickles.
in_memory : bool
Whether to load the whole H5 file in memory. Beware, these files are
sometimes tens of GBs in size. Set this to true if you have sufficient
RAM - trade-off between speed and memory.
"""
def __init__(self, features_path: str, scene_graph_path: str, in_memory: bool = False):
self.features_path = features_path
self.scene_graph_path = scene_graph_path
self._in_memory = in_memory
self.env = lmdb.open(self.features_path, max_readers=1, readonly=True,
lock=False, readahead=False, meminit=False)
with self.env.begin(write=False) as txn:
self._image_ids = pickle.loads(txn.get('keys'.encode()))
self.features = [None] * len(self._image_ids)
self.num_boxes = [None] * len(self._image_ids)
self.boxes = [None] * len(self._image_ids)
self.boxes_ori = [None] * len(self._image_ids)
self.cls_prob = [None] * len(self._image_ids)
self.edge_indexes = [None] * len(self._image_ids)
self.edge_attributes = [None] * len(self._image_ids)
def __len__(self):
return len(self._image_ids)
def __getitem__(self, image_id):
image_id = str(image_id).encode()
index = self._image_ids.index(image_id)
if self._in_memory:
            # Cache features lazily during the first epoch; loading the whole
            # file at once would make start-up very slow.
if self.features[index] is not None:
features = self.features[index]
num_boxes = self.num_boxes[index]
image_location = self.boxes[index]
image_location_ori = self.boxes_ori[index]
cls_prob = self.cls_prob[index]
edge_indexes = self.edge_indexes[index]
edge_attributes = self.edge_attributes[index]
else:
with self.env.begin(write=False) as txn:
item = pickle.loads(txn.get(image_id))
image_id = item['image_id']
image_h = int(item['image_h'])
image_w = int(item['image_w'])
num_boxes = int(item['num_boxes'])
features = np.frombuffer(base64.b64decode(item["features"]), dtype=np.float32).reshape(num_boxes, 2048)
boxes = np.frombuffer(base64.b64decode(item['boxes']), dtype=np.float32).reshape(num_boxes, 4)
cls_prob = np.frombuffer(base64.b64decode(item['cls_prob']), dtype=np.float32).reshape(num_boxes, 1601)
# add an extra row at the top for the <IMG> tokens
g_cls_prob = np.zeros(1601, dtype=np.float32)
g_cls_prob[0] = 1
cls_prob = np.concatenate([np.expand_dims(g_cls_prob,axis=0), cls_prob], axis=0)
self.cls_prob[index] = cls_prob
g_feat = np.sum(features, axis=0) / num_boxes
num_boxes = num_boxes + 1
features = np.concatenate([np.expand_dims(g_feat, axis=0), features], axis=0)
self.features[index] = features
image_location = np.zeros((boxes.shape[0], 5), dtype=np.float32)
image_location[:,:4] = boxes
image_location[:,4] = (image_location[:,3] - image_location[:,1]) * (image_location[:,2] - image_location[:,0]) / (float(image_w) * float(image_h))
image_location_ori = copy.deepcopy(image_location)
image_location[:,0] = image_location[:,0] / float(image_w)
image_location[:,1] = image_location[:,1] / float(image_h)
image_location[:,2] = image_location[:,2] / float(image_w)
image_location[:,3] = image_location[:,3] / float(image_h)
g_location = np.array([0,0,1,1,1])
image_location = np.concatenate([np.expand_dims(g_location, axis=0), image_location], axis=0)
self.boxes[index] = image_location
g_location_ori = np.array([0, 0, image_w, image_h, image_w*image_h])
image_location_ori = np.concatenate([np.expand_dims(g_location_ori, axis=0), image_location_ori], axis=0)
self.boxes_ori[index] = image_location_ori
self.num_boxes[index] = num_boxes
# load the scene graph data
pth = os.path.join(self.scene_graph_path, f'{image_id}.pkl')
with open(pth, 'rb') as f:
graph_data = pickle.load(f)
edge_indexes = []
edge_attributes = []
for e_idx, e_attr in graph_data:
edge_indexes.append(e_idx)
# get one-hot-encoding of the edges
e_attr_one_hot = np.zeros((12,), dtype=np.float32) # 12 = 11 rels + hub-node rel
e_attr_one_hot[e_attr] = 1.0
edge_attributes.append(e_attr_one_hot)
edge_indexes = np.array(edge_indexes, dtype=np.float64).transpose(1, 0)
edge_attributes = np.stack(edge_attributes, axis=0)
self.edge_indexes[index] = edge_indexes
self.edge_attributes[index] = edge_attributes
else:
            # Read the chunk from file every time if it is not loaded in memory.
with self.env.begin(write=False) as txn:
item = pickle.loads(txn.get(image_id))
image_id = item['image_id']
image_h = int(item['image_h'])
image_w = int(item['image_w'])
num_boxes = int(item['num_boxes'])
cls_prob = np.frombuffer(base64.b64decode(item['cls_prob']), dtype=np.float32).reshape(num_boxes, 1601)
# add an extra row at the top for the <IMG> tokens
g_cls_prob = np.zeros(1601, dtype=np.float32)
g_cls_prob[0] = 1
cls_prob = np.concatenate([np.expand_dims(g_cls_prob,axis=0), cls_prob], axis=0)
features = np.frombuffer(base64.b64decode(item["features"]), dtype=np.float32).reshape(num_boxes, 2048)
boxes = np.frombuffer(base64.b64decode(item['boxes']), dtype=np.float32).reshape(num_boxes, 4)
g_feat = np.sum(features, axis=0) / num_boxes
num_boxes = num_boxes + 1
features = np.concatenate([np.expand_dims(g_feat, axis=0), features], axis=0)
image_location = np.zeros((boxes.shape[0], 5), dtype=np.float32)
image_location[:,:4] = boxes
image_location[:,4] = (image_location[:,3] - image_location[:,1]) * (image_location[:,2] - image_location[:,0]) / (float(image_w) * float(image_h))
image_location_ori = copy.deepcopy(image_location)
image_location[:,0] = image_location[:,0] / float(image_w)
image_location[:,1] = image_location[:,1] / float(image_h)
image_location[:,2] = image_location[:,2] / float(image_w)
image_location[:,3] = image_location[:,3] / float(image_h)
g_location = np.array([0,0,1,1,1])
image_location = np.concatenate([np.expand_dims(g_location, axis=0), image_location], axis=0)
g_location_ori = np.array([0,0,image_w,image_h,image_w*image_h])
image_location_ori = np.concatenate([np.expand_dims(g_location_ori, axis=0), image_location_ori], axis=0)
# load the scene graph data
pth = os.path.join(self.scene_graph_path, f'{image_id}.pkl')
with open(pth, 'rb') as f:
graph_data = pickle.load(f)
edge_indexes = []
edge_attributes = []
for e_idx, e_attr in graph_data:
edge_indexes.append(e_idx)
# get one-hot-encoding of the edges
e_attr_one_hot = np.zeros((12,), dtype=np.float32) # 12 = 11 rels + hub-node rel
e_attr_one_hot[e_attr] = 1.0
edge_attributes.append(e_attr_one_hot)
edge_indexes = np.array(edge_indexes, dtype=np.float64).transpose(1, 0)
edge_attributes = np.stack(edge_attributes, axis=0)
return features, num_boxes, image_location, image_location_ori, cls_prob, edge_indexes, edge_attributes
def keys(self) -> List[int]:
return self._image_ids
def set_keys(self, new_ids: List[str]):
        self._image_ids = list(map(lambda _id: _id.encode('ascii'), new_ids))
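# Minimal usage sketch; features_path and scene_graph_path are assumed to point
# to the LMDB image features and the per-image scene-graph pickles prepared by
# the data setup script.
def _image_features_reader_example(features_path, scene_graph_path):
    reader = ImageFeaturesH5Reader(features_path, scene_graph_path, in_memory=False)
    image_id = reader.keys()[0].decode()  # ids are stored as encoded byte strings
    features, num_boxes, locs, locs_ori, cls_prob, e_idx, e_attr = reader[image_id]
    # features: (num_boxes, 2048) region features with a global <IMG> row prepended
    return features.shape, num_boxes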

176
utils/init_utils.py Normal file
View file

@ -0,0 +1,176 @@
import os
import os.path as osp
import random
import datetime
import itertools
import glob
import subprocess
import pyhocon
import re
import numpy as np
import glog as log
import json
import torch
import sys
sys.path.append('../')
from models import vdgr
from dataloader.dataloader_visdial import VisdialDataset
from dataloader.dataloader_visdial_dense import VisdialDenseDataset
def load_runner(config):
if config['train_on_dense']:
return vdgr.DenseRunner(config)
else:
return vdgr.SparseRunner(config)
def load_dataset(config):
dataset_eval = None
if config['train_on_dense']:
dataset = VisdialDenseDataset(config)
if config['skip_mrr_eval']:
temp = config['num_options_dense']
config['num_options_dense'] = config['num_options']
dataset_eval = VisdialDenseDataset(config)
config['num_options_dense'] = temp
else:
dataset_eval = VisdialDataset(config)
else:
dataset = VisdialDataset(config)
if config['skip_mrr_eval']:
dataset_eval = VisdialDenseDataset(config)
if config['use_trainval']:
dataset.split = 'trainval'
else:
dataset.split = 'train'
if dataset_eval is not None:
dataset_eval.split = 'val'
return dataset, dataset_eval
def initialize_from_env(model, mode, eval_dir, model_type, tag=''):
if "GPU" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ['GPU']
if mode in ['train', 'debug']:
config = pyhocon.ConfigFactory.parse_file(f"config/{model_type}.conf")[model]
else:
path_config = osp.join(eval_dir, 'code', f"config/{model_type}.conf")
config = pyhocon.ConfigFactory.parse_file(path_config)[model]
config['log_dir'] = eval_dir
config['model_config'] = osp.join(eval_dir, 'code/config/bert_base_6layer_6conect.json')
if config['dp_type'] == 'apex':
config['dp_type'] = 'ddp'
if config['dp_type'] == 'dp':
config['stack_gr_data'] = True
config['model_type'] = model_type
if "CUDA_VISIBLE_DEVICES" in os.environ:
config['num_gpus'] = len(os.environ["CUDA_VISIBLE_DEVICES"].split(','))
# multi-gpu setting
if config['num_gpus'] > 1:
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '5678'
if mode == 'debug':
model += '_debug'
if tag:
model += '-' + tag
if mode in ['train', 'debug']:
config['log_dir'] = os.path.join(config["log_dir"], model)
if not os.path.exists(config["log_dir"]):
os.makedirs(config["log_dir"])
config['visdial_output_dir'] = osp.join(config['log_dir'], config['visdial_output_dir'])
config['timestamp'] = datetime.datetime.now().strftime('%m%d-%H%M%S')
# add the bert config
config['bert_config'] = json.load(open(config['model_config'], 'r'))
if mode in ['predict', 'eval']:
if (not config['loads_start_path']) and (not config['loads_best_ckpt']):
config['loads_best_ckpt'] = True
print(f'Setting loads_best_ckpt=True under predict or eval mode')
if config['num_options_dense'] < 100:
config['num_options_dense'] = 100
print('Setting num_options_dense=100 under predict or eval mode')
if config['visdial_version'] == 0.9:
config['skip_ndcg_eval'] = True
return config
def set_log_file(fname, file_only=False):
    # if fname already exists, find all log files under the log dir
    # and give the current log file a new number
if osp.exists(fname):
prefix, suffix = osp.splitext(fname)
log_files = glob.glob(prefix + '*' + suffix)
count = 0
for log_file in log_files:
num = re.search(r'(\d+)', log_file)
if num is not None:
num = int(num.group(0))
count = max(num, count)
fname = fname.replace(suffix, str(count + 1) + suffix)
# set log file
    # A simple trick for duplicating the logging destination within the logging module, such as
    #     logging.getLogger().addHandler(logging.FileHandler(filename))
    # does NOT work well here, because Python traceback messages (which bypass the logging module)
    # would not be sent to the file. The following solution (adapted from
    # https://stackoverflow.com/questions/616645) is a little more involved, but it simulates the
    # Linux "tee" command exactly and redirects everything.
if file_only:
        # Only write messages to the file; stdout/stderr receive nothing.
        # This feature is designed for executing the script via ssh: ssh uses a windowing kind of
        # flow control, i.e. if the controller does not read data from an ssh channel and its
        # buffer fills up, the executing machine can no longer write into the channel and the
        # process is put to sleep (S status) until someone reads all data from the channel.
        # Since we do not want to read stdout/stderr from the controller machine, we simply
        # disable output to stdout/stderr and only write messages to the log file.
log.logger.handlers[0].stream = log.handler.stream = sys.stdout = sys.stderr = f = open(fname, 'w', buffering=1)
else:
# we output messages to both file and stdout/stderr
tee = subprocess.Popen(['tee', fname], stdin=subprocess.PIPE)
os.dup2(tee.stdin.fileno(), sys.stdout.fileno())
os.dup2(tee.stdin.fileno(), sys.stderr.fileno())
def copy_file_to_log(log_dir):
dirs_to_cp = ['.', 'config', 'dataloader', 'models', 'utils']
files_to_cp = ['*.py', '*.json', '*.sh', '*.conf']
for dir_name in dirs_to_cp:
dir_name = osp.join(log_dir, 'code', dir_name)
if not osp.exists(dir_name):
os.makedirs(dir_name)
for dir_name, file_name in itertools.product(dirs_to_cp, files_to_cp):
filename = osp.join(dir_name, file_name)
if len(glob.glob(filename)) > 0:
os.system(f'cp {filename} {osp.join(log_dir, "code", dir_name)}')
log.info(f'Files copied to {osp.join(log_dir, "code")}')
def set_random_seed(random_seed):
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
random.seed(random_seed)
np.random.seed(random_seed)
def set_training_steps(config, num_samples):
if config['parallel'] and config['dp_type'] == 'dp':
config['num_iter_per_epoch'] = int(np.ceil(num_samples / config['batch_size']))
else:
config['num_iter_per_epoch'] = int(np.ceil(num_samples / (config['batch_size'] * config['num_gpus'])))
if 'train_steps' not in config:
config['train_steps'] = config['num_iter_per_epoch'] * config['num_epochs']
if 'warmup_steps' not in config:
config['warmup_steps'] = int(config['train_steps'] * config['warmup_ratio'])
return config
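# Worked example with made-up settings: 123287 samples, batch size 40, 8 GPUs
# and DDP give ceil(123287 / (40 * 8)) = 386 iterations per epoch; with 20
# epochs and warmup_ratio 0.1 this yields train_steps = 7720 and
# warmup_steps = 772.
def _set_training_steps_example():
    config = {'parallel': True, 'dp_type': 'ddp', 'batch_size': 40,
              'num_gpus': 8, 'num_epochs': 20, 'warmup_ratio': 0.1}
    return set_training_steps(config, num_samples=123287)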

456
utils/model_utils.py Normal file
View file

@ -0,0 +1,456 @@
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
def truncated_normal_(tensor, mean=0, std=1):
size = tensor.shape
tmp = tensor.new_empty(size + (4,)).normal_()
valid = (tmp < 2) & (tmp > -2)
ind = valid.max(-1, keepdim=True)[1]
tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
tensor.data.mul_(std).add_(mean)
def init_params(module, initializer='normal'):
if isinstance(module, nn.Linear):
if initializer == 'kaiming_normal':
nn.init.kaiming_normal_(module.weight.data)
elif initializer == 'normal':
nn.init.normal_(module.weight.data, std=0.02)
elif initializer == 'truncated_normal':
truncated_normal_(module.weight.data, std=0.02)
if module.bias is not None:
nn.init.zeros_(module.bias.data)
# log.info('initialized Linear')
elif isinstance(module, nn.Embedding):
if initializer == 'kaiming_normal':
nn.init.kaiming_normal_(module.weight.data)
elif initializer == 'normal':
nn.init.normal_(module.weight.data, std=0.02)
elif initializer == 'truncated_normal':
truncated_normal_(module.weight.data, std=0.02)
elif isinstance(module, nn.Conv2d) or isinstance(module, nn.Conv1d):
nn.init.kaiming_normal_(module.weight, mode='fan_out')
# log.info('initialized Conv')
elif isinstance(module, nn.RNNBase) or isinstance(module, nn.LSTMCell) or isinstance(module, nn.GRUCell):
for name, param in module.named_parameters():
if 'weight' in name:
nn.init.orthogonal_(param.data)
elif 'bias' in name:
nn.init.normal_(param.data)
# log.info('initialized LSTM')
elif isinstance(module, nn.BatchNorm1d) or isinstance(module, nn.BatchNorm2d):
module.weight.data.normal_(1.0, 0.02)
# log.info('initialized BatchNorm')
def TensorboardWriter(save_path):
from torch.utils.tensorboard import SummaryWriter
return SummaryWriter(save_path, comment="Unmt")
DEFAULT_EPS = 1e-8
PADDED_Y_VALUE = -1
def listMLE(y_pred, y_true, eps=DEFAULT_EPS, padded_value_indicator=PADDED_Y_VALUE):
"""
ListMLE loss introduced in "Listwise Approach to Learning to Rank - Theory and Algorithm".
:param y_pred: predictions from the model, shape [batch_size, slate_length]
:param y_true: ground truth labels, shape [batch_size, slate_length]
:param eps: epsilon value, used for numerical stability
:param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
:return: loss value, a torch.Tensor
"""
# shuffle for randomised tie resolution
random_indices = torch.randperm(y_pred.shape[-1])
y_pred_shuffled = y_pred[:, random_indices]
y_true_shuffled = y_true[:, random_indices]
y_true_sorted, indices = y_true_shuffled.sort(descending=True, dim=-1)
mask = y_true_sorted == padded_value_indicator
preds_sorted_by_true = torch.gather(y_pred_shuffled, dim=1, index=indices)
preds_sorted_by_true[mask] = float("-inf")
max_pred_values, _ = preds_sorted_by_true.max(dim=1, keepdim=True)
preds_sorted_by_true_minus_max = preds_sorted_by_true - max_pred_values
cumsums = torch.cumsum(preds_sorted_by_true_minus_max.exp().flip(dims=[1]), dim=1).flip(dims=[1])
observation_loss = torch.log(cumsums + eps) - preds_sorted_by_true_minus_max
observation_loss[mask] = 0.0
return torch.mean(torch.sum(observation_loss, dim=1))
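# Minimal sketch of listMLE on a dummy slate of four answer options where the
# third option is the most relevant; PADDED_Y_VALUE (-1) marks a padded slot.
def _listmle_example():
    y_pred = torch.tensor([[0.1, 0.4, 2.0, 0.3]])
    y_true = torch.tensor([[0.0, 1.0, 2.0, PADDED_Y_VALUE]], dtype=torch.float)
    return listMLE(y_pred, y_true)  # scalar loss tensor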
def approxNDCGLoss(y_pred, y_true, eps=DEFAULT_EPS, padded_value_indicator=PADDED_Y_VALUE, alpha=1.):
"""
Loss based on approximate NDCG introduced in "A General Approximation Framework for Direct Optimization of
Information Retrieval Measures". Please note that this method does not implement any kind of truncation.
:param y_pred: predictions from the model, shape [batch_size, slate_length]
:param y_true: ground truth labels, shape [batch_size, slate_length]
:param eps: epsilon value, used for numerical stability
:param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
:param alpha: score difference weight used in the sigmoid function
:return: loss value, a torch.Tensor
"""
device = y_pred.device
y_pred = y_pred.clone()
y_true = y_true.clone()
padded_mask = y_true == padded_value_indicator
y_pred[padded_mask] = float("-inf")
y_true[padded_mask] = float("-inf")
# Here we sort the true and predicted relevancy scores.
y_pred_sorted, indices_pred = y_pred.sort(descending=True, dim=-1)
y_true_sorted, _ = y_true.sort(descending=True, dim=-1)
# After sorting, we can mask out the pairs of indices (i, j) containing index of a padded element.
true_sorted_by_preds = torch.gather(y_true, dim=1, index=indices_pred)
true_diffs = true_sorted_by_preds[:, :, None] - true_sorted_by_preds[:, None, :]
padded_pairs_mask = torch.isfinite(true_diffs)
padded_pairs_mask.diagonal(dim1=-2, dim2=-1).zero_()
# Here we clamp the -infs to get correct gains and ideal DCGs (maxDCGs)
true_sorted_by_preds.clamp_(min=0.)
y_true_sorted.clamp_(min=0.)
# Here we find the gains, discounts and ideal DCGs per slate.
pos_idxs = torch.arange(1, y_pred.shape[1] + 1).to(device)
D = torch.log2(1. + pos_idxs.float())[None, :]
maxDCGs = torch.sum((torch.pow(2, y_true_sorted) - 1) / D, dim=-1).clamp(min=eps)
G = (torch.pow(2, true_sorted_by_preds) - 1) / maxDCGs[:, None]
# Here we approximate the ranking positions according to Eqs 19-20 and later approximate NDCG (Eq 21)
scores_diffs = (y_pred_sorted[:, :, None] - y_pred_sorted[:, None, :])
scores_diffs[~padded_pairs_mask] = 0.
approx_pos = 1. + torch.sum(padded_pairs_mask.float() * (torch.sigmoid(-alpha * scores_diffs).clamp(min=eps)),
dim=-1)
approx_D = torch.log2(1. + approx_pos)
approx_NDCG = torch.sum((G / approx_D), dim=-1)
return -torch.mean(approx_NDCG)
def listNet(y_pred, y_true, eps=DEFAULT_EPS, padded_value_indicator=PADDED_Y_VALUE):
"""
ListNet loss introduced in "Learning to Rank: From Pairwise Approach to Listwise Approach".
:param y_pred: predictions from the model, shape [batch_size, slate_length]
:param y_true: ground truth labels, shape [batch_size, slate_length]
:param eps: epsilon value, used for numerical stability
:param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
:return: loss value, a torch.Tensor
"""
y_pred = y_pred.clone()
y_true = y_true.clone()
mask = y_true == padded_value_indicator
y_pred[mask] = float('-inf')
y_true[mask] = float('-inf')
preds_smax = F.softmax(y_pred, dim=1)
true_smax = F.softmax(y_true, dim=1)
preds_smax = preds_smax + eps
preds_log = torch.log(preds_smax)
return torch.mean(-torch.sum(true_smax * preds_log, dim=1))
def deterministic_neural_sort(s, tau, mask):
"""
Deterministic neural sort.
Code taken from "Stochastic Optimization of Sorting Networks via Continuous Relaxations", ICLR 2019.
Minor modifications applied to the original code (masking).
:param s: values to sort, shape [batch_size, slate_length]
:param tau: temperature for the final softmax function
:param mask: mask indicating padded elements
:return: approximate permutation matrices of shape [batch_size, slate_length, slate_length]
"""
dev = s.device
n = s.size()[1]
one = torch.ones((n, 1), dtype=torch.float32, device=dev)
s = s.masked_fill(mask[:, :, None], -1e8)
A_s = torch.abs(s - s.permute(0, 2, 1))
A_s = A_s.masked_fill(mask[:, :, None] | mask[:, None, :], 0.0)
B = torch.matmul(A_s, torch.matmul(one, torch.transpose(one, 0, 1)))
temp = [n - m + 1 - 2 * (torch.arange(n - m, device=dev) + 1) for m in mask.squeeze(-1).sum(dim=1)]
temp = [t.type(torch.float32) for t in temp]
temp = [torch.cat((t, torch.zeros(n - len(t), device=dev))) for t in temp]
scaling = torch.stack(temp).type(torch.float32).to(dev) # type: ignore
s = s.masked_fill(mask[:, :, None], 0.0)
C = torch.matmul(s, scaling.unsqueeze(-2))
P_max = (C - B).permute(0, 2, 1)
P_max = P_max.masked_fill(mask[:, :, None] | mask[:, None, :], -np.inf)
P_max = P_max.masked_fill(mask[:, :, None] & mask[:, None, :], 1.0)
sm = torch.nn.Softmax(-1)
P_hat = sm(P_max / tau)
return P_hat
def sample_gumbel(samples_shape, device, eps=1e-10) -> torch.Tensor:
"""
Sampling from Gumbel distribution.
Code taken from "Stochastic Optimization of Sorting Networks via Continuous Relaxations", ICLR 2019.
Minor modifications applied to the original code (masking).
:param samples_shape: shape of the output samples tensor
:param device: device of the output samples tensor
:param eps: epsilon for the logarithm function
:return: Gumbel samples tensor of shape samples_shape
"""
U = torch.rand(samples_shape, device=device)
return -torch.log(-torch.log(U + eps) + eps)
def apply_mask_and_get_true_sorted_by_preds(y_pred, y_true, padding_indicator=PADDED_Y_VALUE):
mask = y_true == padding_indicator
y_pred[mask] = float('-inf')
y_true[mask] = 0.0
_, indices = y_pred.sort(descending=True, dim=-1)
return torch.gather(y_true, dim=1, index=indices)
def dcg(y_pred, y_true, ats=None, gain_function=lambda x: torch.pow(2, x) - 1, padding_indicator=PADDED_Y_VALUE):
"""
Discounted Cumulative Gain at k.
Compute DCG at ranks given by ats or at the maximum rank if ats is None.
:param y_pred: predictions from the model, shape [batch_size, slate_length]
:param y_true: ground truth labels, shape [batch_size, slate_length]
:param ats: optional list of ranks for DCG evaluation, if None, maximum rank is used
:param gain_function: callable, gain function for the ground truth labels, e.g. torch.pow(2, x) - 1
:param padding_indicator: an indicator of the y_true index containing a padded item, e.g. -1
:return: DCG values for each slate and evaluation position, shape [batch_size, len(ats)]
"""
y_true = y_true.clone()
y_pred = y_pred.clone()
actual_length = y_true.shape[1]
if ats is None:
ats = [actual_length]
ats = [min(at, actual_length) for at in ats]
true_sorted_by_preds = apply_mask_and_get_true_sorted_by_preds(y_pred, y_true, padding_indicator)
discounts = (torch.tensor(1) / torch.log2(torch.arange(true_sorted_by_preds.shape[1], dtype=torch.float) + 2.0)).to(
device=true_sorted_by_preds.device)
gains = gain_function(true_sorted_by_preds)
discounted_gains = (gains * discounts)[:, :np.max(ats)]
cum_dcg = torch.cumsum(discounted_gains, dim=1)
ats_tensor = torch.tensor(ats, dtype=torch.long) - torch.tensor(1)
dcg = cum_dcg[:, ats_tensor]
return dcg
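# Worked example: with true gains [3, 2, 0] ranked perfectly, DCG@3 is
# (2^3 - 1)/log2(2) + (2^2 - 1)/log2(3) + 0 = 7 + 1.893 ≈ 8.893.
def _dcg_example():
    y_pred = torch.tensor([[0.9, 0.5, 0.1]])
    y_true = torch.tensor([[3.0, 2.0, 0.0]])
    return dcg(y_pred, y_true, ats=[3])  # tensor([[8.8928]])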
def sinkhorn_scaling(mat, mask=None, tol=1e-6, max_iter=50):
"""
Sinkhorn scaling procedure.
:param mat: a tensor of square matrices of shape N x M x M, where N is batch size
:param mask: a tensor of masks of shape N x M
:param tol: Sinkhorn scaling tolerance
:param max_iter: maximum number of iterations of the Sinkhorn scaling
:return: a tensor of (approximately) doubly stochastic matrices
"""
if mask is not None:
mat = mat.masked_fill(mask[:, None, :] | mask[:, :, None], 0.0)
mat = mat.masked_fill(mask[:, None, :] & mask[:, :, None], 1.0)
for _ in range(max_iter):
mat = mat / mat.sum(dim=1, keepdim=True).clamp(min=DEFAULT_EPS)
mat = mat / mat.sum(dim=2, keepdim=True).clamp(min=DEFAULT_EPS)
if torch.max(torch.abs(mat.sum(dim=2) - 1.)) < tol and torch.max(torch.abs(mat.sum(dim=1) - 1.)) < tol:
break
if mask is not None:
mat = mat.masked_fill(mask[:, None, :] | mask[:, :, None], 0.0)
return mat
def stochastic_neural_sort(s, n_samples, tau, mask, beta=1.0, log_scores=True, eps=1e-10):
"""
Stochastic neural sort. Please note that memory complexity grows by factor n_samples.
Code taken from "Stochastic Optimization of Sorting Networks via Continuous Relaxations", ICLR 2019.
Minor modifications applied to the original code (masking).
:param s: values to sort, shape [batch_size, slate_length]
:param n_samples: number of samples (approximations) for each permutation matrix
:param tau: temperature for the final softmax function
:param mask: mask indicating padded elements
:param beta: scale parameter for the Gumbel distribution
:param log_scores: whether to apply the logarithm function to scores prior to Gumbel perturbation
:param eps: epsilon for the logarithm function
:return: approximate permutation matrices of shape [n_samples, batch_size, slate_length, slate_length]
"""
dev = s.device
batch_size = s.size()[0]
n = s.size()[1]
s_positive = s + torch.abs(s.min())
samples = beta * sample_gumbel([n_samples, batch_size, n, 1], device=dev)
if log_scores:
s_positive = torch.log(s_positive + eps)
s_perturb = (s_positive + samples).view(n_samples * batch_size, n, 1)
mask_repeated = mask.repeat_interleave(n_samples, dim=0)
P_hat = deterministic_neural_sort(s_perturb, tau, mask_repeated)
P_hat = P_hat.view(n_samples, batch_size, n, n)
return P_hat
def neuralNDCG(y_pred, y_true, padded_value_indicator=PADDED_Y_VALUE, temperature=1., powered_relevancies=True, k=None,
stochastic=False, n_samples=32, beta=0.1, log_scores=True):
"""
NeuralNDCG loss introduced in "NeuralNDCG: Direct Optimisation of a Ranking Metric via Differentiable
Relaxation of Sorting" - https://arxiv.org/abs/2102.07831. Based on the NeuralSort algorithm.
:param y_pred: predictions from the model, shape [batch_size, slate_length]
:param y_true: ground truth labels, shape [batch_size, slate_length]
:param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
:param temperature: temperature for the NeuralSort algorithm
:param powered_relevancies: whether to apply 2^x - 1 gain function, x otherwise
:param k: rank at which the loss is truncated
:param stochastic: whether to calculate the stochastic variant
:param n_samples: how many stochastic samples are taken, used if stochastic == True
:param beta: beta parameter for NeuralSort algorithm, used if stochastic == True
:param log_scores: log_scores parameter for NeuralSort algorithm, used if stochastic == True
:return: loss value, a torch.Tensor
"""
dev = y_pred.device
if k is None:
k = y_true.shape[1]
mask = (y_true == padded_value_indicator)
# Choose the deterministic/stochastic variant
if stochastic:
P_hat = stochastic_neural_sort(y_pred.unsqueeze(-1), n_samples=n_samples, tau=temperature, mask=mask,
beta=beta, log_scores=log_scores)
else:
P_hat = deterministic_neural_sort(y_pred.unsqueeze(-1), tau=temperature, mask=mask).unsqueeze(0)
# Perform sinkhorn scaling to obtain doubly stochastic permutation matrices
P_hat = sinkhorn_scaling(P_hat.view(P_hat.shape[0] * P_hat.shape[1], P_hat.shape[2], P_hat.shape[3]),
mask.repeat_interleave(P_hat.shape[0], dim=0), tol=1e-6, max_iter=50)
P_hat = P_hat.view(int(P_hat.shape[0] / y_pred.shape[0]), y_pred.shape[0], P_hat.shape[1], P_hat.shape[2])
# Mask P_hat and apply to true labels, ie approximately sort them
P_hat = P_hat.masked_fill(mask[None, :, :, None] | mask[None, :, None, :], 0.)
y_true_masked = y_true.masked_fill(mask, 0.).unsqueeze(-1).unsqueeze(0)
if powered_relevancies:
y_true_masked = torch.pow(2., y_true_masked) - 1.
ground_truth = torch.matmul(P_hat, y_true_masked).squeeze(-1)
discounts = (torch.tensor(1.) / torch.log2(torch.arange(y_true.shape[-1], dtype=torch.float) + 2.)).to(dev)
discounted_gains = ground_truth * discounts
if powered_relevancies:
idcg = dcg(y_true, y_true, ats=[k]).permute(1, 0)
else:
idcg = dcg(y_true, y_true, ats=[k], gain_function=lambda x: x).permute(1, 0)
discounted_gains = discounted_gains[:, :, :k]
ndcg = discounted_gains.sum(dim=-1) / (idcg + DEFAULT_EPS)
idcg_mask = idcg == 0.
ndcg = ndcg.masked_fill(idcg_mask.repeat(ndcg.shape[0], 1), 0.)
    assert (ndcg < 0.).sum() == 0, "every ndcg should be non-negative"
if idcg_mask.all():
return torch.tensor(0.)
mean_ndcg = ndcg.sum() / ((~idcg_mask).sum() * ndcg.shape[0]) # type: ignore
    return -1. * mean_ndcg  # negated because we want to maximize NDCG
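# Minimal sketch of the deterministic neuralNDCG variant on a dummy slate; the
# returned value is the negated (approximate) NDCG, so lower is better.
def _neural_ndcg_example():
    y_pred = torch.tensor([[0.2, 1.5, 0.1, 0.7]])
    y_true = torch.tensor([[0.0, 2.0, 0.0, 1.0]])
    return neuralNDCG(y_pred, y_true, temperature=1., k=4)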
def neuralNDCG_transposed(y_pred, y_true, padded_value_indicator=PADDED_Y_VALUE, temperature=1.,
powered_relevancies=True, k=None, stochastic=False, n_samples=32, beta=0.1, log_scores=True,
max_iter=50, tol=1e-6):
"""
NeuralNDCG Transposed loss introduced in "NeuralNDCG: Direct Optimisation of a Ranking Metric via Differentiable
Relaxation of Sorting" - https://arxiv.org/abs/2102.07831. Based on the NeuralSort algorithm.
:param y_pred: predictions from the model, shape [batch_size, slate_length]
:param y_true: ground truth labels, shape [batch_size, slate_length]
:param padded_value_indicator: an indicator of the y_true index containing a padded item, e.g. -1
:param temperature: temperature for the NeuralSort algorithm
:param powered_relevancies: whether to apply 2^x - 1 gain function, x otherwise
:param k: rank at which the loss is truncated
:param stochastic: whether to calculate the stochastic variant
:param n_samples: how many stochastic samples are taken, used if stochastic == True
:param beta: beta parameter for NeuralSort algorithm, used if stochastic == True
:param log_scores: log_scores parameter for NeuralSort algorithm, used if stochastic == True
:param max_iter: maximum iteration count for Sinkhorn scaling
:param tol: tolerance for Sinkhorn scaling
:return: loss value, a torch.Tensor
"""
dev = y_pred.device
if k is None:
k = y_true.shape[1]
mask = (y_true == padded_value_indicator)
if stochastic:
P_hat = stochastic_neural_sort(y_pred.unsqueeze(-1), n_samples=n_samples, tau=temperature, mask=mask,
beta=beta, log_scores=log_scores)
else:
P_hat = deterministic_neural_sort(y_pred.unsqueeze(-1), tau=temperature, mask=mask).unsqueeze(0)
# Perform sinkhorn scaling to obtain doubly stochastic permutation matrices
P_hat_masked = sinkhorn_scaling(P_hat.view(P_hat.shape[0] * y_pred.shape[0], y_pred.shape[1], y_pred.shape[1]),
mask.repeat_interleave(P_hat.shape[0], dim=0), tol=tol, max_iter=max_iter)
P_hat_masked = P_hat_masked.view(P_hat.shape[0], y_pred.shape[0], y_pred.shape[1], y_pred.shape[1])
discounts = (torch.tensor(1) / torch.log2(torch.arange(y_true.shape[-1], dtype=torch.float) + 2.)).to(dev)
# This takes care of the @k metric truncation - if something is @>k, it is useless and gets 0.0 discount
discounts[k:] = 0.
discounts = discounts[None, None, :, None]
# Here the discounts become expected discounts
discounts = torch.matmul(P_hat_masked.permute(0, 1, 3, 2), discounts).squeeze(-1)
if powered_relevancies:
gains = torch.pow(2., y_true) - 1
discounted_gains = gains.unsqueeze(0) * discounts
idcg = dcg(y_true, y_true, ats=[k]).squeeze()
else:
gains = y_true
discounted_gains = gains.unsqueeze(0) * discounts
idcg = dcg(y_true, y_true, ats=[k]).squeeze()
ndcg = discounted_gains.sum(dim=2) / (idcg + DEFAULT_EPS)
idcg_mask = idcg == 0.
ndcg = ndcg.masked_fill(idcg_mask, 0.)
    assert (ndcg < 0.).sum() == 0, "every ndcg should be non-negative"
if idcg_mask.all():
return torch.tensor(0.)
mean_ndcg = ndcg.sum() / ((~idcg_mask).sum() * ndcg.shape[0]) # type: ignore
    return -1. * mean_ndcg  # negated because we want to maximize NDCG

41
utils/modules.py Normal file
View file

@ -0,0 +1,41 @@
from collections import Counter, defaultdict
import logging
from typing import Union, List, Dict, Any
import torch
from torch import nn
class Identity(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x
class Reshaper(nn.Module):
def __init__(self, *output_shape):
super().__init__()
self.output_shape = output_shape
def forward(self, input: torch.Tensor):
return input.view(*self.output_shape)
class Normalizer(nn.Module):
def __init__(self, target_norm=1.):
super().__init__()
self.target_norm = target_norm
def forward(self, input: torch.Tensor):
return input * self.target_norm / input.norm(p=2, dim=1, keepdim=True)
class Squeezer(nn.Module):
def __init__(self, dim=-1):
super().__init__()
self.dim = dim
def forward(self, input):
return torch.squeeze(input, dim=self.dim)

389
utils/optim_utils.py Normal file
View file

@ -0,0 +1,389 @@
import logging
import math
import numpy as np
import random
import functools
import glog as log
import torch
from torch import nn, optim
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler, ConstantLR
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
from pytorch_transformers.optimization import AdamW
class WarmupLinearScheduleNonZero(_LRScheduler):
""" Linear warmup and then linear decay.
Linearly increases learning rate from 0 to max_lr over `warmup_steps` training steps.
    Linearly decreases the learning rate to min_lr over the remaining `t_total - warmup_steps` steps.
"""
def __init__(self, optimizer, warmup_steps, t_total, min_lr=1e-5, last_epoch=-1):
self.warmup_steps = warmup_steps
self.t_total = t_total
self.min_lr = min_lr
super(WarmupLinearScheduleNonZero, self).__init__(optimizer, last_epoch=last_epoch)
def get_lr(self):
step = self.last_epoch
if step < self.warmup_steps:
lr_factor = float(step) / float(max(1, self.warmup_steps))
else:
lr_factor = max(0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps)))
return [base_lr * lr_factor if (base_lr * lr_factor) > self.min_lr else self.min_lr for base_lr in self.base_lrs]
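# Minimal sketch with a dummy model: the learning rate ramps up to 1e-4 over the
# first 100 steps, then decays linearly but never drops below min_lr.
def _warmup_schedule_example():
    model = nn.Linear(4, 4)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    scheduler = WarmupLinearScheduleNonZero(optimizer, warmup_steps=100,
                                            t_total=1000, min_lr=1e-6)
    lrs = []
    for _ in range(1000):
        optimizer.step()
        scheduler.step()
        lrs.append(optimizer.param_groups[0]['lr'])
    return lrs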
def init_optim(model, config):
optimizer_grouped_parameters = []
gnn_params = []
encoder_params_with_decay = []
encoder_params_without_decay = []
exclude_from_weight_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
for module_name, module in model.named_children():
for param_name, param in module.named_parameters():
if param.requires_grad:
if "gnn" in param_name:
gnn_params.append(param)
elif module_name == 'encoder':
if any(ex in param_name for ex in exclude_from_weight_decay):
encoder_params_without_decay.append(param)
else:
encoder_params_with_decay.append(param)
optimizer_grouped_parameters = [
{
'params': gnn_params,
'weight_decay': config.gnn_weight_decay,
'lr': config['learning_rate_gnn'] if config.use_diff_lr_gnn else config['learning_rate_bert']
}
]
optimizer_grouped_parameters.extend(
[
{
'params': encoder_params_without_decay,
'weight_decay': 0,
'lr': config['learning_rate_bert']
},
{
'params': encoder_params_with_decay,
'weight_decay': 0.01,
'lr': config['learning_rate_bert']
}
]
)
optimizer = AdamW(optimizer_grouped_parameters, lr=config['learning_rate_gnn'])
scheduler = WarmupLinearScheduleNonZero(
optimizer,
warmup_steps=config['warmup_steps'],
t_total=config['train_steps'],
min_lr=config['min_lr']
)
return optimizer, scheduler
def build_torch_optimizer(model, config):
"""Builds the PyTorch optimizer.
We use the default parameters for Adam that are suggested by
the original paper https://arxiv.org/pdf/1412.6980.pdf
These values are also used by other established implementations,
e.g. https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer
https://keras.io/optimizers/
    Recently, slightly different values were used in the paper
    "Attention is all you need" (https://arxiv.org/pdf/1706.03762.pdf),
    in particular beta2=0.98. However, beta2=0.999 is still arguably the more
    established value, so we use that here as well.
Args:
model: The model to optimize.
config: The dictionary of options.
Returns:
A ``torch.optim.Optimizer`` instance.
"""
betas = [0.9, 0.999]
exclude_from_weight_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
params = {'bert': [], 'task': []}
for module_name, module in model.named_children():
if module_name == 'encoder':
param_type = 'bert'
else:
param_type = 'task'
for param_name, param in module.named_parameters():
if param.requires_grad:
if any(ex in param_name for ex in exclude_from_weight_decay):
params[param_type] += [
{
"params": [param],
"weight_decay": 0
}
]
else:
params[param_type] += [
{
"params": [param],
"weight_decay": 0.01
}
]
if config['task_optimizer'] == 'adamw':
log.info('Using AdamW as task optimizer')
task_optimizer = AdamWeightDecay(params['task'],
lr=config["learning_rate_task"],
betas=betas,
eps=1e-6)
elif config['task_optimizer'] == 'adam':
log.info('Using Adam as task optimizer')
task_optimizer = optim.Adam(params['task'],
lr=config["learning_rate_task"],
betas=betas,
eps=1e-6)
if len(params['bert']) > 0:
bert_optimizer = AdamWeightDecay(params['bert'],
lr=config["learning_rate_bert"],
betas=betas,
eps=1e-6)
optimizer = MultipleOptimizer([bert_optimizer, task_optimizer])
else:
optimizer = task_optimizer
return optimizer
def make_learning_rate_decay_fn(decay_method, train_steps, **kwargs):
"""Returns the learning decay function from options."""
if decay_method == "linear":
return functools.partial(
linear_decay,
global_steps=train_steps,
**kwargs)
elif decay_method == "exp":
return functools.partial(
exp_decay,
global_steps=train_steps,
**kwargs)
else:
raise ValueError(f'{decay_method} not found')
def linear_decay(step, global_steps, warmup_steps=100, initial_learning_rate=1, end_learning_rate=0, **kargs):
if step < warmup_steps:
return initial_learning_rate * step / warmup_steps
else:
return (initial_learning_rate - end_learning_rate) * \
(1 - (step - warmup_steps) / (global_steps - warmup_steps)) + \
end_learning_rate
def exp_decay(step, global_steps, decay_exp=1, warmup_steps=100, initial_learning_rate=1, end_learning_rate=0, **kargs):
if step < warmup_steps:
return initial_learning_rate * step / warmup_steps
else:
return (initial_learning_rate - end_learning_rate) * \
((1 - (step - warmup_steps) / (global_steps - warmup_steps)) ** decay_exp) + \
end_learning_rate
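# Quick sanity check of the decay helpers with dummy step counts: with
# warmup_steps=100 and global_steps=1000, step 50 gives a warmup factor of 0.5
# and step 550 sits halfway through the linear decay, also giving 0.5.
def _decay_fn_example():
    warm = linear_decay(50, global_steps=1000, warmup_steps=100)
    mid = linear_decay(550, global_steps=1000, warmup_steps=100)
    return warm, mid  # (0.5, 0.5)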
class MultipleOptimizer(object):
""" Implement multiple optimizers needed for sparse adam """
def __init__(self, op):
""" ? """
self.optimizers = op
@property
def param_groups(self):
param_groups = []
for optimizer in self.optimizers:
param_groups.extend(optimizer.param_groups)
return param_groups
def zero_grad(self):
""" ? """
for op in self.optimizers:
op.zero_grad()
def step(self):
""" ? """
for op in self.optimizers:
op.step()
@property
def state(self):
""" ? """
return {k: v for op in self.optimizers for k, v in op.state.items()}
def state_dict(self):
""" ? """
return [op.state_dict() for op in self.optimizers]
def load_state_dict(self, state_dicts):
""" ? """
assert len(state_dicts) == len(self.optimizers)
for i in range(len(state_dicts)):
self.optimizers[i].load_state_dict(state_dicts[i])
class OptimizerBase(object):
"""
Controller class for optimization. Mostly a thin
wrapper for `optim`, but also useful for implementing
rate scheduling beyond what is currently available.
Also implements necessary methods for training RNNs such
as grad manipulations.
"""
def __init__(self,
optimizer,
learning_rate,
learning_rate_decay_fn=None,
max_grad_norm=None):
"""Initializes the controller.
Args:
optimizer: A ``torch.optim.Optimizer`` instance.
learning_rate: The initial learning rate.
learning_rate_decay_fn: An optional callable taking the current step
as argument and return a learning rate scaling factor.
max_grad_norm: Clip gradients to this global norm.
"""
self._optimizer = optimizer
self._learning_rate = learning_rate
self._learning_rate_decay_fn = learning_rate_decay_fn
self._max_grad_norm = max_grad_norm or 0
self._training_step = 1
self._decay_step = 1
@classmethod
def from_opt(cls, model, config, checkpoint=None):
"""Builds the optimizer from options.
Args:
cls: The ``Optimizer`` class to instantiate.
model: The model to optimize.
          config: The dict of user options.
checkpoint: An optional checkpoint to load states from.
Returns:
An ``Optimizer`` instance.
"""
optim_opt = config
optim_state_dict = None
if config["loads_ckpt"] and checkpoint is not None:
optim = checkpoint['optim']
ckpt_opt = checkpoint['opt']
ckpt_state_dict = {}
if isinstance(optim, Optimizer): # Backward compatibility.
ckpt_state_dict['training_step'] = optim._step + 1
ckpt_state_dict['decay_step'] = optim._step + 1
ckpt_state_dict['optimizer'] = optim.optimizer.state_dict()
else:
ckpt_state_dict = optim
if config["reset_optim"] == 'none':
# Load everything from the checkpoint.
optim_opt = ckpt_opt
optim_state_dict = ckpt_state_dict
elif config["reset_optim"] == 'all':
# Build everything from scratch.
pass
elif config["reset_optim"] == 'states':
# Reset optimizer, keep options.
optim_opt = ckpt_opt
optim_state_dict = ckpt_state_dict
del optim_state_dict['optimizer']
elif config["reset_optim"] == 'keep_states':
# Reset options, keep optimizer.
optim_state_dict = ckpt_state_dict
learning_rates = [
optim_opt["learning_rate_bert"],
optim_opt["learning_rate_gnn"]
]
decay_fn = [
make_learning_rate_decay_fn(optim_opt['decay_method_bert'],
optim_opt['train_steps'],
warmup_steps=optim_opt['warmup_steps'],
decay_exp=optim_opt['decay_exp']),
make_learning_rate_decay_fn(optim_opt['decay_method_gnn'],
optim_opt['train_steps'],
warmup_steps=optim_opt['warmup_steps'],
decay_exp=optim_opt['decay_exp']),
]
optimizer = cls(
build_torch_optimizer(model, optim_opt),
learning_rates,
learning_rate_decay_fn=decay_fn,
max_grad_norm=optim_opt["max_grad_norm"])
if optim_state_dict:
optimizer.load_state_dict(optim_state_dict)
return optimizer
@property
def training_step(self):
"""The current training step."""
return self._training_step
def learning_rate(self):
"""Returns the current learning rate."""
if self._learning_rate_decay_fn is None:
return self._learning_rate
return [decay_fn(self._decay_step) * learning_rate \
for decay_fn, learning_rate in \
zip(self._learning_rate_decay_fn, self._learning_rate)]
def state_dict(self):
return {
'training_step': self._training_step,
'decay_step': self._decay_step,
'optimizer': self._optimizer.state_dict()
}
def load_state_dict(self, state_dict):
self._training_step = state_dict['training_step']
# State can be partially restored.
if 'decay_step' in state_dict:
self._decay_step = state_dict['decay_step']
if 'optimizer' in state_dict:
self._optimizer.load_state_dict(state_dict['optimizer'])
def zero_grad(self):
"""Zero the gradients of optimized parameters."""
self._optimizer.zero_grad()
def backward(self, loss):
"""Wrapper for backward pass. Some optimizer requires ownership of the
backward pass."""
loss.backward()
def step(self):
"""Update the model parameters based on current gradients.
Optionally, will employ gradient modification or update learning
rate.
"""
learning_rate = self.learning_rate()
if isinstance(self._optimizer, MultipleOptimizer):
optimizers = self._optimizer.optimizers
else:
optimizers = [self._optimizer]
for lr, op in zip(learning_rate, optimizers):
for group in op.param_groups:
group['lr'] = lr
if self._max_grad_norm > 0:
clip_grad_norm_(group['params'], self._max_grad_norm)
self._optimizer.step()
self._decay_step += 1
self._training_step += 1

322
utils/visdial_metrics.py Normal file
View file

@ -0,0 +1,322 @@
"""
A Metric observes output of certain model, for example, in form of logits or
scores, and accumulates a particular metric with reference to some provided
targets. In context of VisDial, we use Recall (@ 1, 5, 10), Mean Rank, Mean
Reciprocal Rank (MRR) and Normalized Discounted Cumulative Gain (NDCG).
Each ``Metric`` must implement at least three methods:
    - ``observe``, to update the accumulated metric with the currently observed
      outputs and targets.
    - ``retrieve``, to return the accumulated metric and optionally reset the
      internally accumulated metric (this is commonly done between two epochs
      after validation).
    - ``reset``, to explicitly reset the internally accumulated metric.
Caveat: if you wish to implement your own Metric class, make sure you call
``detach`` on output tensors (such as logits), else it will cause memory leaks.
"""
import torch
import torch.distributed as dist
import numpy as np
def scores_to_ranks(scores: torch.Tensor):
"""Convert model output scores into ranks."""
batch_size, num_rounds, num_options = scores.size()
scores = scores.view(-1, num_options)
# sort in descending order - largest score gets highest rank
sorted_ranks, ranked_idx = scores.sort(1, descending=True)
    # The i-th position in ranked_idx specifies which score should occupy this
    # position, but we want the i-th position to hold the rank of the score at
    # that position, so do this conversion.
ranks = ranked_idx.clone().fill_(0)
for i in range(ranked_idx.size(0)):
for j in range(num_options):
ranks[i][ranked_idx[i][j]] = j
# convert from 0-99 ranks to 1-100 ranks
ranks += 1
ranks = ranks.view(batch_size, num_rounds, num_options)
return ranks
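# Minimal sketch: for a single round with four answer options, the option with
# the highest score gets rank 1.
def _scores_to_ranks_example():
    scores = torch.tensor([[[0.1, 0.7, 0.2, 0.0]]])  # (batch=1, rounds=1, options=4)
    return scores_to_ranks(scores)  # tensor([[[3, 1, 2, 4]]])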
class SparseGTMetrics(object):
"""
A class to accumulate all metrics with sparse ground truth annotations.
These include Recall (@ 1, 5, 10), Mean Rank and Mean Reciprocal Rank.
"""
def __init__(self):
self._rank_list = []
self._rank_list_rnd = []
self.num_rounds = None
def observe(
self, predicted_scores: torch.Tensor, target_ranks: torch.Tensor
):
predicted_scores = predicted_scores.detach()
# shape: (batch_size, num_rounds, num_options)
predicted_ranks = scores_to_ranks(predicted_scores)
batch_size, num_rounds, num_options = predicted_ranks.size()
self.num_rounds = num_rounds
# collapse batch dimension
predicted_ranks = predicted_ranks.view(
batch_size * num_rounds, num_options
)
# shape: (batch_size * num_rounds, )
target_ranks = target_ranks.view(batch_size * num_rounds).long()
# shape: (batch_size * num_rounds, )
predicted_gt_ranks = predicted_ranks[
torch.arange(batch_size * num_rounds), target_ranks
]
self._rank_list.extend(list(predicted_gt_ranks.cpu().numpy()))
predicted_gt_ranks_rnd = predicted_gt_ranks.view(batch_size, num_rounds)
# predicted gt ranks
self._rank_list_rnd.append(predicted_gt_ranks_rnd.cpu().numpy())
def retrieve(self, reset: bool = True):
num_examples = len(self._rank_list)
if num_examples > 0:
# convert to numpy array for easy calculation.
__rank_list = torch.tensor(self._rank_list).float()
metrics = {
"r@1": torch.mean((__rank_list <= 1).float()).item(),
"r@5": torch.mean((__rank_list <= 5).float()).item(),
"r@10": torch.mean((__rank_list <= 10).float()).item(),
"mean": torch.mean(__rank_list).item(),
"mrr": torch.mean(__rank_list.reciprocal()).item()
}
# add round metrics
_rank_list_rnd = np.concatenate(self._rank_list_rnd)
_rank_list_rnd = _rank_list_rnd.astype(float)
r_1_rnd = np.mean(_rank_list_rnd <= 1, axis=0)
r_5_rnd = np.mean(_rank_list_rnd <= 5, axis=0)
r_10_rnd = np.mean(_rank_list_rnd <= 10, axis=0)
mean_rnd = np.mean(_rank_list_rnd, axis=0)
mrr_rnd = np.mean(np.reciprocal(_rank_list_rnd), axis=0)
for rnd in range(1, self.num_rounds + 1):
metrics["r_1" + "_round_" + str(rnd)] = r_1_rnd[rnd-1]
metrics["r_5" + "_round_" + str(rnd)] = r_5_rnd[rnd-1]
metrics["r_10" + "_round_" + str(rnd)] = r_10_rnd[rnd-1]
metrics["mean" + "_round_" + str(rnd)] = mean_rnd[rnd-1]
metrics["mrr" + "_round_" + str(rnd)] = mrr_rnd[rnd-1]
else:
metrics = {}
if reset:
self.reset()
return metrics
def reset(self):
self._rank_list = []
self._rank_list_rnd = []
class NDCG(object):
def __init__(self):
self._ndcg_numerator = 0.0
self._ndcg_denominator = 0.0
def observe(
self, predicted_scores: torch.Tensor, target_relevance: torch.Tensor
):
"""
Observe model output scores and target ground truth relevance and
accumulate NDCG metric.
Parameters
----------
predicted_scores: torch.Tensor
A tensor of shape (batch_size, num_options), because dense
annotations are available for 1 randomly picked round out of 10.
target_relevance: torch.Tensor
A tensor of shape same as predicted scores, indicating ground truth
relevance of each answer option for a particular round.
"""
predicted_scores = predicted_scores.detach()
# shape: (batch_size, 1, num_options)
predicted_scores = predicted_scores.unsqueeze(1)
predicted_ranks = scores_to_ranks(predicted_scores)
# shape: (batch_size, num_options)
predicted_ranks = predicted_ranks.squeeze(1)
batch_size, num_options = predicted_ranks.size()
k = torch.sum(target_relevance != 0, dim=-1)
# shape: (batch_size, num_options)
_, rankings = torch.sort(predicted_ranks, dim=-1)
        # Sort relevance in descending order so the highest relevance gets the top rank.
_, best_rankings = torch.sort(
target_relevance, dim=-1, descending=True
)
# shape: (batch_size, )
batch_ndcg = []
for batch_index in range(batch_size):
num_relevant = k[batch_index]
dcg = self._dcg(
rankings[batch_index][:num_relevant],
target_relevance[batch_index],
)
best_dcg = self._dcg(
best_rankings[batch_index][:num_relevant],
target_relevance[batch_index],
)
batch_ndcg.append(dcg / best_dcg)
self._ndcg_denominator += batch_size
self._ndcg_numerator += sum(batch_ndcg)
def _dcg(self, rankings: torch.Tensor, relevance: torch.Tensor):
sorted_relevance = relevance[rankings].cpu().float()
discounts = torch.log2(torch.arange(len(rankings)).float() + 2)
return torch.sum(sorted_relevance / discounts, dim=-1)
def retrieve(self, reset: bool = True):
if self._ndcg_denominator > 0:
metrics = {
"ndcg": float(self._ndcg_numerator / self._ndcg_denominator)
}
else:
metrics = {}
if reset:
self.reset()
return metrics
def reset(self):
self._ndcg_numerator = 0.0
self._ndcg_denominator = 0.0
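# Minimal sketch of the NDCG accumulator on one dummy example: the most
# relevant option is ranked first, so the retrieved NDCG is 1.0.
def _ndcg_example():
    metric = NDCG()
    scores = torch.tensor([[0.1, 0.9, 0.3, 0.2]])  # (batch=1, options=4)
    relevance = torch.tensor([[0.0, 1.0, 0.5, 0.0]])
    metric.observe(scores, relevance)
    return metric.retrieve(reset=True)  # {'ndcg': 1.0}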
class SparseGTMetricsParallel(object):
"""
A class to accumulate all metrics with sparse ground truth annotations.
These include Recall (@ 1, 5, 10), Mean Rank and Mean Reciprocal Rank.
"""
def __init__(self, gpu_rank):
self.rank_1 = 0
self.rank_5 = 0
self.rank_10 = 0
self.ranks = 0
self.reciprocal = 0
self.count = 0
self.gpu_rank = gpu_rank
self.img_ids = []
def observe(
self, img_id: list, predicted_scores: torch.Tensor, target_ranks: torch.Tensor
):
if img_id in self.img_ids:
return
else:
self.img_ids.append(img_id)
predicted_scores = predicted_scores.detach()
# shape: (batch_size, num_rounds, num_options)
predicted_ranks = scores_to_ranks(predicted_scores)
batch_size, num_rounds, num_options = predicted_ranks.size()
self.num_rounds = num_rounds
# collapse batch dimension
predicted_ranks = predicted_ranks.view(
batch_size * num_rounds, num_options
)
# shape: (batch_size * num_rounds, )
target_ranks = target_ranks.view(batch_size * num_rounds).long()
# shape: (batch_size * num_rounds, )
predicted_gt_ranks = predicted_ranks[
torch.arange(batch_size * num_rounds), target_ranks
]
self.rank_1 += (predicted_gt_ranks <= 1).sum().item()
self.rank_5 += (predicted_gt_ranks <= 5).sum().item()
self.rank_10 += (predicted_gt_ranks <= 10).sum().item()
self.ranks += predicted_gt_ranks.sum().item()
self.reciprocal += predicted_gt_ranks.float().reciprocal().sum().item()
self.count += batch_size * num_rounds
def retrieve(self):
if self.count > 0:
            # gather the partial results from all GPUs
            # pack the per-GPU counters into one tensor so they can be all-reduced
t = torch.tensor([self.rank_1, self.rank_5, self.rank_10, self.ranks, self.reciprocal, self.count], dtype=torch.float32, device=f'cuda:{self.gpu_rank}')
dist.barrier() # synchronizes all processes
dist.all_reduce(t, op=torch.distributed.ReduceOp.SUM,) # Reduces the tensor data across all machines in such a way that all get the final result.
t = t.tolist()
self.rank_1, self.rank_5, self.rank_10, self.ranks, self.reciprocal, self.count = t
# convert to numpy array for easy calculation.
metrics = {
"r@1": self.rank_1 / self.count,
"r@5": self.rank_5 / self.count,
"r@10": self.rank_10 / self.count,
"mean": self.ranks / self.count,
"mrr": self.reciprocal / self.count,
"tot_rnds": self.count,
}
else:
metrics = {}
return metrics
def get_count(self):
return int(self.count)
class NDCGParallel(NDCG):
def __init__(self, gpu_rank):
super(NDCGParallel, self).__init__()
self.gpu_rank = gpu_rank
self.img_ids = []
self.count = 0
def observe(
self, img_id: int, predicted_scores: torch.Tensor, target_relevance: torch.Tensor
):
"""
Observe model output scores and target ground truth relevance and
accumulate NDCG metric.
Parameters
----------
predicted_scores: torch.Tensor
A tensor of shape (batch_size, num_options), because dense
annotations are available for 1 randomly picked round out of 10.
target_relevance: torch.Tensor
A tensor of shape same as predicted scores, indicating ground truth
relevance of each answer option for a particular round.
"""
if img_id in self.img_ids:
return
else:
self.img_ids.append(img_id)
self.count += 1
super(NDCGParallel, self).observe(predicted_scores, target_relevance)
def retrieve(self):
if self._ndcg_denominator > 0:
            # pack the per-GPU counters into one tensor so they can be all-reduced
t = torch.tensor([self._ndcg_numerator, self._ndcg_denominator, self.count], dtype=torch.float32, device=f'cuda:{self.gpu_rank}')
dist.barrier() # synchronizes all processes
dist.all_reduce(t, op=torch.distributed.ReduceOp.SUM,) # Reduces the tensor data across all machines in such a way that all get the final result.
t = t.tolist()
self._ndcg_numerator, self._ndcg_denominator, self.count = t
metrics = {
"ndcg": float(self._ndcg_numerator / self._ndcg_denominator)
}
else:
metrics = {}
return metrics
def get_count(self):
return int(self.count)