Code release

Adnen Abdessaied 2023-10-25 15:38:09 +02:00
commit 09fb25e339
29 changed files with 7162 additions and 0 deletions

277
README.md Normal file

@ -0,0 +1,277 @@
<div align="center">
<h1> VD-GR: Boosting Visual Dialog with Cascaded Spatial-Temporal Multi-Modal GRaphs </h1>
**[Adnen Abdessaied][5], &nbsp; [Lei Shi][6], &nbsp; [Andreas Bulling][7]** <br> <br>
**WACV'24, Hawaii, USA** <img src="misc/usa.png" width="3%" align="center"> <br>
**[[Paper][8]]**
-------------------
<img src="misc/teaser_1.png" width="100%" align="middle"><br><br>
</div>
# Table of Contents
* [Setup and Dependencies](#setup-and-dependencies)
* [Download Data](#download-data)
* [Pre-trained Checkpoints](#pre-trained-checkpoints)
* [Training](#training)
* [Results](#results)
# Setup and Dependencies
We implemented our model using Python 3.7 and PyTorch 1.11.0 (CUDA 11.3, CuDNN 8.2.0). We recommend setting up a virtual environment using Anaconda. <br>
1. Install [git lfs][1] on your system
2. Clone our repository to download the data, checkpoints, and code
```shell
git lfs install
git clone this_repo.git
```
3. Create a conda environment and install dependencies
```shell
conda create -n vdgr python=3.7
conda activate vdgr
conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch
conda install pyg -c pyg # 2.1.0
pip install pytorch-transformers
pip install pytorch_pretrained_bert
pip install pyhocon glog wandb lmdb
```
4. If you wish to speed up training, we recommend installing [apex][2]
```shell
git clone https://github.com/NVIDIA/apex
cd apex
# if pip >= 23.1 (ref: https://pip.pypa.io/en/stable/news/#v23-1) which supports multiple `--config-settings` with the same key...
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
# otherwise
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --global-option="--cpp_ext" --global-option="--cuda_ext" ./
cd ..
```
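The Phase 1 config sets ```dp_type = apex```, i.e. training uses apex's (legacy) automatic mixed precision. The snippet below is only a minimal smoke test of that AMP API so you can verify your apex build; the model, optimizer, and loss here are placeholders, not objects from this repository.
```python
# Minimal smoke test of apex's legacy AMP API (placeholders, not repository code).
import torch
from apex import amp

model = torch.nn.Linear(10, 10).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# O1 = mixed precision: whitelisted ops run in FP16, the rest stay in FP32.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

loss = model(torch.randn(4, 10).cuda()).sum()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()  # backward on the scaled loss to avoid FP16 underflow
optimizer.step()
```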
# Download Data
1. Download the extracted visual features of [VisDial][3] and set up all the files we used in our work. We provide a shell script for convenience:
```shell
./setup_data.sh # Please make sure you have enough disk space
```
If everything was set up correctly, the ```data/``` folder should look like this:
```
├── history_adj_matrices
│   ├── test
│   │   ├── *.pkl
│   ├── train
│   │   ├── *.pkl
│   ├── val
│   │   ├── *.pkl
├── question_adj_matrices
│   ├── test
│   │   ├── *.pkl
│   ├── train
│   │   ├── *.pkl
│   ├── val
│   │   ├── *.pkl
├── img_adj_matrices
│   ├── *.pkl
├── parse_vocab.pkl
├── test_dense_mapping.json
├── tr_dense_mapping.json
├── val_dense_mapping.json
├── visdial_0.9_test.json
├── visdial_0.9_train.json
├── visdial_0.9_val.json
├── visdial_1.0_test.json
├── visdial_1.0_train_dense_annotations.json
├── visdial_1.0_train_dense.json
├── visdial_1.0_train.json
├── visdial_1.0_val_dense_annotations.json
├── visdial_1.0_val.json
├── visdialconv_dense_annotations.json
├── visdialconv.json
├── vispro_dense_annotations.json
└── vispro.json
```
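To verify the download before training, a small sanity check along the lines below can help. It only mirrors a subset of the tree above; ```data_dir``` is an assumption and should point to wherever you extracted the data.
```python
# Sanity check for the expected data/ layout (illustrative, not part of the repository).
import os

data_dir = "data"
expected_files = [
    "parse_vocab.pkl", "visdial_1.0_train.json", "visdial_1.0_val.json",
    "visdial_1.0_test.json", "visdial_1.0_train_dense.json",
    "visdial_1.0_train_dense_annotations.json", "visdial_1.0_val_dense_annotations.json",
    "visdialconv.json", "visdialconv_dense_annotations.json",
    "vispro.json", "vispro_dense_annotations.json",
]
expected_dirs = [
    "img_adj_matrices",
    "history_adj_matrices/train", "history_adj_matrices/val", "history_adj_matrices/test",
    "question_adj_matrices/train", "question_adj_matrices/val", "question_adj_matrices/test",
]
missing = [f for f in expected_files if not os.path.isfile(os.path.join(data_dir, f))]
missing += [d for d in expected_dirs if not os.path.isdir(os.path.join(data_dir, d))]
print("Data folder looks complete." if not missing else f"Missing entries: {missing}")
```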
# Pre-trained Checkpoints
For convenience, we provide checkpoints of our model after the warm-up training stage in ```ckpt/``` for both VisDial v1.0 and VisDial v0.9. <br>
These checkpoints will be downloaded with the code if you use ```git lfs```.
# Training
We trained our model on 8 Nvidia Tesla V100-32GB GPUs. The default hyperparameters in ```config/vdgr.conf``` and ```config/bert_base_6layer_6conect.json``` need to be adjusted if your setup differs from ours.
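Since the configs are HOCON files, the values you plan to change can be inspected programmatically with ```pyhocon``` (already in the dependency list). The keys below are taken from ```config/vdgr.conf```; the snippet itself is just an illustration.
```python
# Print a few Phase 1 hyperparameters you may want to adapt to your hardware (illustrative).
from pyhocon import ConfigFactory

p1 = ConfigFactory.parse_file("config/vdgr.conf")["P1"]
print("per-GPU batch size:", p1["batch_size"])       # 8 by default
print("batch_multiply:    ", p1["batch_multiply"])   # 1 by default
print("BERT learning rate:", p1["learning_rate_bert"])
print("GNN learning rate: ", p1["learning_rate_gnn"])
print("num_epochs:        ", p1["num_epochs"])
```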
## Phase 1
### Training
1. In this phase, we train our model on VisDial v1.0 via
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
--model vdgr/P1 \
--mode train \
--tag K2_v1.0 \
--wandb_mode online \
--wandb_project your_wandb_project_name
```
⚠️ On a setup similar to ours, this takes roughly 20 hours to complete when training with apex.
2. To train on VisDial v0.9:
* Set ```visdial_version = 0.9``` in ```config/vdgr.conf```
* Set ```start_path = ckpt/vdgr_visdial_v0.9_after_warmup_K2.ckpt``` in ```config/vdgr.conf```
* Run
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
--model vdgr/P1 \
--mode train \
--tag K2_v0.9 \
--wandb_mode online \
--wandb_project your_wandb_project_name
```
### Inference
1. For inference on VisDial v1.0 val, VisDialConv, or VisPro:
* Set ```eval_dataset``` to one of ```{visdial, visdial_conv, visdial_vispro}``` in ```logs/vdgr/P1_K2_v1.0/code/config/vdgr.conf```
* Run
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
--model vdgr/P1 \
--mode eval \
--eval_dir logs/vdgr/P1_K2_v1.0 \
--wandb_mode offline
```
2. For inference on VisDial v0.9:
* Set ```eval_dataset = visdial``` in ```logs/vdgr/P1_K2_v0.9/code/config/vdgr.conf```
* Run
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
--model vdgr/P1 \
--mode eval \
--eval_dir logs/vdgr/P1_K2_v0.9 \
--wandb_mode offline
```
⚠️ This might take some time to finish, as the test data of VisDial v0.9 is large.
3. For inference on the VisDial v1.0 test split:
* Run
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
--model vdgr/P1 \
--mode predict \
--eval_dir logs/vdgr/P1_K2_v1.0 \
--wandb_mode offline
```
* The output file will be saved in ```output/```
## Phase 2
In this phase, we fine-tune on the dense annotations to improve the NDCG score (only supported for VisDial v1.0).
1. Run
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
--model vdgr/P2_CE \
--mode train \
--tag K2_v1.0_CE \
--wandb_mode online \
--wandb_project your_wandb_project_name
```
This will take roughly 3-4 hours to complete on the same setup as before, using [DP][4] for training.
2. For inference on VisDial v1.0:
* Run:
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
--model vdgr/P2_CE \
--mode predict \
--eval_dir logs/vdgr/P1_K2_v1.0_CE \
--wandb_mode offline
```
* The output file will be saved in ```output/```
## Phase 3
### Training
In the final phase, we train an ensemble of 8 models obtained by combining ```K={1,2,3,4}``` with ```dense_loss={ce, listnet}```. A sketch of the two dense losses is given after the training commands below.
For ```K=k```:
1. Set the value of ```num_v_gnn_layers, num_q_gnn_layers, num_h_gnn_layers``` to ```k```
2. Set ```start_path = ckpt/vdgr_visdial_v1.0_after_warmup_K[k].ckpt``` in ```config/vdgr.conf``` (P1)
3. Phase 1 training:
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
--model vdgr/P1 \
--mode train \
--tag K[k]_v1.0 \
--wandb_mode online \
--wandb_project your_wandb_project_name
```
4. Set ```start_path = logs/vdgr/P1_K[k]_v1.0/epoch_best.ckpt``` in ```config/vdgr.conf``` (P2)
5. Phase 2 training:
* Fine-tune with CE:
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
--model vdgr/P2_CE \
--mode train \
--tag K[k]_v1.0_CE \
--wandb_mode online \
--wandb_project your_wandb_project_name
```
* Fine-tune with LISTNET:
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
--model vdgr/P2_LISTNET \
--mode train \
--tag K[k]_v1.0_LISTNET \
--wandb_mode online \
--wandb_project your_wandb_project_name
```
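For reference, ```ce``` and ```listnet``` are both listwise objectives computed over the 100 answer options, with the dense relevance annotations as targets. The sketch below shows one common formulation of each loss family; it is an illustration only, not a copy of this repository's implementation, and the tensor names are placeholders.
```python
# Illustrative dense fine-tuning losses over the 100 answer options (not repository code).
# scores: model logits, shape [batch, 100]; gt_relevance: dense annotations, shape [batch, 100].
import torch.nn.functional as F

def dense_ce_loss(scores, gt_relevance, eps=1e-8):
    # Soft cross-entropy: normalize the relevance scores into a target distribution.
    target = gt_relevance / (gt_relevance.sum(dim=-1, keepdim=True) + eps)
    return -(target * F.log_softmax(scores, dim=-1)).sum(dim=-1).mean()

def dense_listnet_loss(scores, gt_relevance):
    # ListNet: cross-entropy between the top-one distributions induced by the
    # annotated relevances and by the predicted scores.
    target = F.softmax(gt_relevance, dim=-1)
    return -(target * F.log_softmax(scores, dim=-1)).sum(dim=-1).mean()
```
Both losses push probability mass toward the densely annotated options; they differ only in how the target distribution is built from the relevance scores.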
### Inference
1. For inference on VisDial v1.0 test:
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
--model vdgr/P2_[CE,LISTNET] \
--mode predict \
--eval_dir logs/vdgr/P2_K[1,2,3,4]_v1.0_[CE,LISTNET] \
--wandb_mode offline
```
2. Finally, merge the outputs of all models
```shell
python ensemble.py \
--exp test \
--mode predict
```
The output file will be saved in ```output/```
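Conceptually, the merge aggregates the per-option predictions of the eight Phase 2 models before ranking. The sketch below only illustrates that idea with plain score averaging; ```ensemble.py``` together with ```config/ensemble.conf``` is the authoritative implementation, and the file names and the ```{image_id: [100 scores]}``` format used here are hypothetical.
```python
# Illustrative score-averaging ensemble (NOT the repository's ensemble.py).
# Assumes each model dumped {image_id: [100 option scores]} -- a hypothetical format.
import json
import numpy as np

model_tags = [
    "P2_K1_v1.0_CE", "P2_K2_v1.0_CE", "P2_K3_v1.0_CE", "P2_K4_v1.0_CE",
    "P2_K1_v1.0_LISTNET", "P2_K2_v1.0_LISTNET", "P2_K3_v1.0_LISTNET", "P2_K4_v1.0_LISTNET",
]
per_model = [json.load(open(f"output/{tag}_scores.json")) for tag in model_tags]

ranks = {}
for image_id in per_model[0]:
    avg = np.mean([np.asarray(m[image_id]) for m in per_model], axis=0)
    # Rank 1 = highest averaged score, following the VisDial ranking convention.
    ranks[image_id] = (np.argsort(np.argsort(-avg)) + 1).tolist()
```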
# Results
## VisDial v0.9
| Model | MRR | R@1 | R@5 | R@10 | Mean |
|:--------:|:---:|:---:|:---:|:----:|:----:|
| Prev. SOTA | 71.99 | 59.41 | 87.92 | 94.59 | 2.87 |
| VD-GR | **74.50** | **62.10** | **90.49** | **96.37** | **2.45** |
## VisDialConv
| Model | NDCG | MRR | R@1 | R@5 | R@10 | Mean |
|:--------:|:----:|:---:|:---:|:---:|:----:|:----:|
| Prev. SOTA | 61.72 | 61.79 | 48.95 | 77.50 | 86.71 | 4.72 |
| VD-GR | **67.09** | **66.82** | **54.47** | **81.71** | **91.44** | **3.54** |
## VisPro
| Model | NDCG | MRR | R@1 | R@5 | R@10 | Mean |
|:--------:|:----:|:---:|:---:|:---:|:----:|:----:|
| Prev. SOTA | 59.30 | 62.29 | 48.35 | 80.10 | 88.87 | 4.37 |
| VD-GR | **60.35** | **69.89** | **57.21** | **85.97** | **92.68** | **3.15** |
## VisDial v1.0 Val
| Model | NDCG | MRR | R@1 | R@5 | R@10 | Mean |
|:--------:|:----:|:---:|:---:|:---:|:----:|:----:|
| Prev. SOTA | 65.47 | 69.71 | 56.79 | 85.82 | 93.64 | 3.15 |
| VD-GR | 64.32 | **69.91** | **57.01** | **86.14** | **93.74** | **3.13** |
## VisDial v1.0 Test
| Model | NDCG | MRR | R@1 | R@5 | R@10 | Mean |
|:--------:|:----:|:---:|:---:|:---:|:----:|:----:|
| Prev. SOTA | 64.91 | 68.73 | 55.73 | 85.38 | 93.53 | 3.21 |
| VD-GR | 63.49 | 68.65 | 55.33 | **85.58** | **93.85** | **3.20** |
| ♣️ Prev. SOTA | 75.92 | 56.18 | 45.32 | 68.05 | 80.98 | 5.42 |
| ♣️ VD-GR | **75.95** | **58.30** | **46.55** | **71.45** | 84.52 | **5.32** |
| ♣️♦️ Prev. SOTA | 76.17 | 56.42 | 44.75 | 70.23 | 84.52 | 5.47 |
| ♣️♦️ VD-GR | **76.43** | 56.35 | **45.18** | 68.13 | 82.18 | 5.79 |
♣️ = Fine-tuning on dense annotations, ♦️ = Ensemble model
[1]: https://git-lfs.com/
[2]: https://github.com/NVIDIA/apex
[3]: https://visualdialog.org/
[4]: https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html
[5]: https://adnenabdessaied.de
[6]: https://www.perceptualui.org/people/shi/
[7]: https://www.perceptualui.org/people/bulling/
[8]: https://drive.google.com/file/d/1GT0WDinA_z5FdwVc_bWtyB-cwQkGIf7C/view?usp=sharing

0
ckpt/.gitkeep Normal file

40
config/bert_base_6layer_6conect.json Normal file

@ -0,0 +1,40 @@
{
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"max_position_embeddings": 512,
"num_attention_heads": 12,
"num_hidden_layers": 12,
"type_vocab_size": 2,
"vocab_size": 30522,
"v_feature_size": 2048,
"v_target_size": 1601,
"v_hidden_size": 1024,
"v_num_hidden_layers": 6,
"v_num_attention_heads": 8,
"v_intermediate_size": 1024,
"bi_hidden_size": 1024,
"bi_num_attention_heads": 8,
"bi_intermediate_size": 1024,
"bi_attention_type": 1,
"v_attention_probs_dropout_prob": 0.1,
"v_hidden_act": "gelu",
"v_hidden_dropout_prob": 0.1,
"v_initializer_range": 0.02,
"pooling_method": "mul",
"v_biattention_id": [0, 1, 2, 3, 4, 5],
"t_biattention_id": [6, 7, 8, 9, 10, 11],
"gnn_act": "gelu",
"num_v_gnn_layers": 2,
"num_q_gnn_layers": 2,
"num_h_gnn_layers": 2,
"num_gnn_attention_heads": 4,
"gnn_dropout_prob": 0.1,
"v_gnn_edge_dim": 12,
"q_gnn_edge_dim": 48,
"v_gnn_ids": [0, 1, 2, 3, 4, 5],
"t_gnn_ids": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
}

33
config/ensemble.conf Normal file

@ -0,0 +1,33 @@
test = {
split = test
skip_mrr_eval = true
# data
visdial_test_data = data/visdial_1.0_test.json
# directory
log_dir = logs/vdgr_ensemble
pred_dir = logs/vdgr
visdial_output_dir = visdial_output
processed = [
false,
false,
false,
false,
false,
false,
false,
false
]
models = [
"P2_K1_v1.0_CE",
"P2_K2_v1.0_CE",
"P2_K3_v1.0_CE",
"P2_K4_v1.0_CE",
"P2_K1_v1.0_LISTNET",
"P2_K2_v1.0_LISTNET",
"P2_K3_v1.0_LISTNET",
"P2_K4_v1.0_LISTNET",
]
}

188
config/vdgr.conf Normal file

@ -0,0 +1,188 @@
# Phase 1
P1 {
use_cpu = false
visdial_version = 1.0
train_on_dense = false
metrics_to_maximize = mrr
# visdial data
visdial_image_feats = data/visdial_img_feat.lmdb
visdial_image_adj_matrices = data/img_adj_matrices
visdial_question_adj_matrices = data/question_adj_matrices
visdial_history_adj_matrices = data/history_adj_matrices
visdial_train = data/visdial_1.0_train.json
visdial_val = data/visdial_1.0_val.json
visdial_test = data/visdial_1.0_test.json
visdial_val_dense_annotations = data/visdial_1.0_val_dense_annotations.json
visdial_train_09 = data/visdial_0.9_train.json
visdial_val_09 = data/visdial_0.9_val.json
visdial_test_09 = data/visdial_0.9_test.json
visdialconv_val = data/visdial_conv.json
visdialconv_val_dense_annotations = data/visdialconv_dense_annotations.json
visdialvispro_val = data/vispro.json
visdialvispro_val_dense_annotations = data/vispro_dense_annotations.json
visdial_question_parse_vocab = data/parse_vocab.pkl
# init
start_path = ckpt/vdgr_visdial_v1.0_after_warmup_K2.ckpt
model_config = config/bert_base_6layer_6conect.json
# visdial training
freeze_vilbert = false
visdial_tot_rounds = 11
num_negative_samples = 1
sequences_per_image = 2
batch_size = 8
lm_loss_coeff = 1
nsp_loss_coeff = 1
img_loss_coeff = 1
batch_multiply = 1
use_trainval = false
dense_loss = ce
dense_loss_coeff = 0
dataloader_text_only = false
rlv_hst_only = false
rlv_hst_dense_round = false
# visdial model
mask_prob = 0.1
image_mask_prob = 0.1
max_seq_len = 256
num_options = 100
num_options_dense = 100
use_embedding = joint
# visdial evaluation
eval_visdial_on_test = true
eval_batch_size = 1
eval_line_batch_size = 200
skip_mrr_eval = false
skip_ndcg_eval = false
skip_visdial_eval = false
eval_visdial_every = 1
eval_dataset = visdial # visdial_vispro # choices = [visdial, visdial_conv, visdial_vispro ]
continue_evaluation = false
eval_at_start = false
eval_before_training = false
initializer = normal
bert_cased = false
# restore ckpt
loads_best_ckpt = false
loads_ckpt = false
restarts = false
resets_max_metric = false
uses_new_optimizer = false
sets_new_lr = false
loads_start_path = false
# logging
random_seed = 42
next_logging_pct = 1.0
next_evaluating_pct = 50.0
max_ckpt_to_keep = 1
num_epochs = 20
early_stop_epoch = 5
skip_saving_ckpt = false
dp_type = apex
stack_gr_data = false
master_port = 5122
stop_epochs = -1
train_each_round = false
drop_last_answer = false
num_samples = -1
# predicting
predict_split = test
predict_each_round = false
predict_dense_round = false
num_test_dialogs = 8000
num_val_dialogs = 2064
save_score = false
# optimizer
reset_optim = none
learning_rate_bert = 5e-6
learning_rate_gnn = 2e-4
gnn_weight_decay = 0.01
use_diff_lr_gnn = true
min_lr = 0
decay_method_bert = linear
decay_method_gnn = linear
decay_exp = 2
max_grad_norm = 1.0
task_optimizer = adam
warmup_ratio = 0.1
# directory
log_dir = logs/vdgr
data_dir = data
visdial_output_dir = visdial_output
bert_cache_dir = transformers
# keep track of other hparams in bert json
v_gnn_edge_dim = 12 # 11 classes + hub_node
q_gnn_edge_dim = 48 # 47 classes + hub_node
num_v_gnn_layers = 2
num_q_gnn_layers = 2
num_h_gnn_layers = 2
num_gnn_attention_heads = 4
v_gnn_ids = [0, 1, 2, 3, 4, 5]
t_gnn_ids = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
}
# Phase 2
P2_CE = ${P1} {
# basic
train_on_dense = true
use_trainval = true
metrics_to_maximize = ndcg
visdial_train_dense = data/visdial_1.0_train_dense.json
visdial_train_dense_annotations = data/visdial_1.0_train_dense_annotations.json
visdial_val_dense = data/visdial_1.0_val.json
tr_graph_idx_mapping = data/tr_dense_mapping.json
val_graph_idx_mapping = data/val_dense_mapping.json
test_graph_idx_mapping = data/test_dense_mapping.json
visdial_val = data/visdial_1.0_val.json
visdial_val_dense_annotations = data/visdial_1.0_val_dense_annotations.json
# data
start_path = logs/vdgr/P1_K2_v1.0/epoch_best.ckpt
rlv_hst_only = false
# visdial training
nsp_loss_coeff = 0
dense_loss_coeff = 1
batch_multiply = 10
batch_size = 1
# visdial model
num_options_dense = 100
# visdial evaluation
eval_batch_size = 1
eval_line_batch_size = 100
skip_mrr_eval = true
# training
stop_epochs = 3
dp_type = dp
dense_loss = ce
# optimizer
learning_rate_bert = 1e-4
}
P2_LISTNET = ${P2_CE} {
dense_loss = listnet
}

0
data/.gitkeep Normal file

0
dataloader/__init__.py Normal file

269
dataloader/dataloader_base.py Normal file

@ -0,0 +1,269 @@
import torch
from torch.utils import data
import json
import os
import glog as log
import pickle
import torch.utils.data as tud
from pytorch_transformers.tokenization_bert import BertTokenizer
from utils.image_features_reader import ImageFeaturesH5Reader
class DatasetBase(data.Dataset):
def __init__(self, config):
if config['display']:
log.info('Initializing dataset')
# Fetch the correct dataset for evaluation
if config['validating']:
assert config.eval_dataset in ['visdial', 'visdial_conv', 'visdial_vispro', 'visdial_v09']
if config.eval_dataset == 'visdial_conv':
config['visdial_val'] = config.visdialconv_val
config['visdial_val_dense_annotations'] = config.visdialconv_val_dense_annotations
elif config.eval_dataset == 'visdial_vispro':
config['visdial_val'] = config.visdialvispro_val
config['visdial_val_dense_annotations'] = config.visdialvispro_val_dense_annotations
elif config.eval_dataset == 'visdial_v09':
config['visdial_val_09'] = config.visdial_test_09
config['visdial_val_dense_annotations'] = None
self.config = config
self.numDataPoints = {}
if not config['dataloader_text_only']:
self._image_features_reader = ImageFeaturesH5Reader(
config['visdial_image_feats'],
config['visdial_image_adj_matrices']
)
if self.config['training'] or self.config['validating'] or self.config['predicting']:
split2data = {'train': 'train', 'val': 'val', 'test': 'test'}
elif self.config['debugging']:
split2data = {'train': 'val', 'val': 'val', 'test': 'test'}
elif self.config['visualizing']:
split2data = {'train': 'train', 'val': 'train', 'test': 'test'}
filename = f'visdial_{split2data["train"]}'
if config['train_on_dense']:
filename += '_dense'
if self.config['visdial_version'] == 0.9:
filename += '_09'
with open(config[filename]) as f:
self.visdial_data_train = json.load(f)
if self.config.num_samples > 0:
self.visdial_data_train['data']['dialogs'] = self.visdial_data_train['data']['dialogs'][:self.config.num_samples]
self.numDataPoints['train'] = len(self.visdial_data_train['data']['dialogs'])
filename = f'visdial_{split2data["val"]}'
if config['train_on_dense'] and config['training']:
filename += '_dense'
if self.config['visdial_version'] == 0.9:
filename += '_09'
with open(config[filename]) as f:
self.visdial_data_val = json.load(f)
if self.config.num_samples > 0:
self.visdial_data_val['data']['dialogs'] = self.visdial_data_val['data']['dialogs'][:self.config.num_samples]
self.numDataPoints['val'] = len(self.visdial_data_val['data']['dialogs'])
if config['train_on_dense']:
self.numDataPoints['trainval'] = self.numDataPoints['train'] + self.numDataPoints['val']
with open(config[f'visdial_{split2data["test"]}']) as f:
self.visdial_data_test = json.load(f)
self.numDataPoints['test'] = len(self.visdial_data_test['data']['dialogs'])
self.rlv_hst_train = None
self.rlv_hst_val = None
self.rlv_hst_test = None
if config['train_on_dense'] or config['predict_dense_round']:
with open(config[f'visdial_{split2data["train"]}_dense_annotations']) as f:
self.visdial_data_train_dense = json.load(f)
if config['train_on_dense']:
self.subsets = ['train', 'val', 'trainval', 'test']
else:
self.subsets = ['train','val','test']
self.num_options = config["num_options"]
self.num_options_dense = config["num_options_dense"]
if config['visdial_version'] != 0.9:
with open(config[f'visdial_{split2data["val"]}_dense_annotations']) as f:
self.visdial_data_val_dense = json.load(f)
else:
self.visdial_data_val_dense = None
self._split = 'train'
self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', cache_dir=config['bert_cache_dir'])
# fetch the token indices of [CLS], [MASK] and [SEP]
tokens = ['[CLS]','[MASK]','[SEP]']
indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokens)
self.CLS = indexed_tokens[0]
self.MASK = indexed_tokens[1]
self.SEP = indexed_tokens[2]
self._max_region_num = 37
self.predict_each_round = self.config['predicting'] and self.config['predict_each_round']
self.keys_to_expand = ['image_feat', 'image_loc', 'image_mask', 'image_target', 'image_label']
self.keys_to_flatten_1d = ['hist_len', 'next_sentence_labels', 'round_id', 'image_id']
self.keys_to_flatten_2d = ['tokens', 'segments', 'sep_indices', 'mask', 'image_mask', 'image_label', 'input_mask', 'question_limits']
self.keys_to_flatten_3d = ['image_feat', 'image_loc', 'image_target', ]
self.keys_other = ['gt_relevance', 'gt_option_inds']
self.keys_lists_to_flatten = ['image_edge_indices', 'image_edge_attributes', 'question_edge_indices', 'question_edge_attributes', 'history_edge_indices', 'history_sep_indices']
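# When stack_gr_data is enabled, the ragged graph tensors above are padded and stacked into
# dense batch tensors in collate_fn (their original lengths travel in the len_* keys);
# otherwise collate_fn concatenates them into flat Python lists.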
if config['stack_gr_data']:
self.keys_to_flatten_3d.extend(self.keys_lists_to_flatten[:-1])
self.keys_to_flatten_2d.append(self.keys_lists_to_flatten[-1])
self.keys_to_flatten_1d.extend(['len_image_gr', 'len_question_gr', 'len_history_gr', 'len_history_sep'])
self.keys_lists_to_flatten = []
self.keys_to_list = ['tot_len']
# Load the parse vocab for question graph relationship mapping
if os.path.isfile(config['visdial_question_parse_vocab']):
with open(config['visdial_question_parse_vocab'], 'rb') as f:
self.parse_vocab = pickle.load(f)
def __len__(self):
return self.numDataPoints[self._split]
@property
def split(self):
return self._split
@split.setter
def split(self, split):
assert split in self.subsets
self._split = split
def tokens2str(self, seq):
dialog_sequence = ''
for sentence in seq:
for word in sentence:
dialog_sequence += self.tokenizer._convert_id_to_token(word) + " "
dialog_sequence += ' </end> '
dialog_sequence = dialog_sequence.encode('utf8')
return dialog_sequence
def pruneRounds(self, context, num_rounds):
start_segment = 1
len_context = len(context)
cur_rounds = (len(context) // 2) + 1
l_index = 0
if cur_rounds > num_rounds:
# caption is not part of the final input
l_index = len_context - (2 * num_rounds)
start_segment = 0
return context[l_index:], start_segment
def tokenize_utterance(self, sent, sentences, tot_len, sentence_count, sentence_map, speakers):
sentences.extend(sent + ['[SEP]'])
tokenized_sent = self.tokenizer.convert_tokens_to_ids(sent)
assert len(sent) == len(tokenized_sent), 'sub-word tokens are not allowed!'
sent_len = len(tokenized_sent)
tot_len += sent_len + 1 # the additional 1 is for the sep token
sentence_count += 1
sentence_map.extend([sentence_count * 2 - 1] * sent_len)
sentence_map.append(sentence_count * 2) # for [SEP]
speakers.extend([2] * (sent_len + 1))
return tokenized_sent, sentences, tot_len, sentence_count, sentence_map, speakers
def __getitem__(self, index):
raise NotImplementedError
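# collate_fn merges a list of per-dialog samples into one batch: graph edge lists are either
# padded and stacked (stack_gr_data) or concatenated as lists, image features are expanded to
# every round/option, and the remaining tensors are reshaped so that rounds x options fold into
# the batch dimension.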
def collate_fn(self, batch):
tokens_size = batch[0]['tokens'].size()
num_rounds, num_samples = tokens_size[0], tokens_size[1]
merged_batch = {key: [d[key] for d in batch] for key in batch[0]}
if self.config['stack_gr_data']:
if (len(batch)) > 1:
max_question_gr_len = max([length.max().item() for length in merged_batch['len_question_gr']])
max_history_gr_len = max([length.max().item() for length in merged_batch['len_history_gr']])
max_history_sep_len = max([length.max().item() for length in merged_batch['len_history_sep']])
max_image_gr_len = max([length.max().item() for length in merged_batch['len_image_gr']])
question_edge_indices_padded = []
question_edge_attributes_padded = []
for q_e_idx, q_e_attr in zip(merged_batch['question_edge_indices'], merged_batch['question_edge_attributes']):
b_size, edge_dim, orig_len = q_e_idx.size()
q_e_idx_padded = torch.zeros((b_size, edge_dim, max_question_gr_len))
q_e_idx_padded[:, :, :orig_len] = q_e_idx
question_edge_indices_padded.append(q_e_idx_padded)
edge_attr_dim = q_e_attr.size(-1)
q_e_attr_padded = torch.zeros((b_size, max_question_gr_len, edge_attr_dim))
q_e_attr_padded[:, :orig_len, :] = q_e_attr
question_edge_attributes_padded.append(q_e_attr_padded)
merged_batch['question_edge_indices'] = question_edge_indices_padded
merged_batch['question_edge_attributes'] = question_edge_attributes_padded
history_edge_indices_padded = []
for h_e_idx in merged_batch['history_edge_indices']:
b_size, _, orig_len = h_e_idx.size()
h_edge_idx_padded = torch.zeros((b_size, 2, max_history_gr_len))
h_edge_idx_padded[:, :, :orig_len] = h_e_idx
history_edge_indices_padded.append(h_edge_idx_padded)
merged_batch['history_edge_indices'] = history_edge_indices_padded
history_sep_indices_padded = []
for hist_sep_idx in merged_batch['history_sep_indices']:
b_size, orig_len = hist_sep_idx.size()
hist_sep_idx_padded = torch.zeros((b_size, max_history_sep_len))
hist_sep_idx_padded[:, :orig_len] = hist_sep_idx
history_sep_indices_padded.append(hist_sep_idx_padded)
merged_batch['history_sep_indices'] = history_sep_indices_padded
image_edge_indices_padded = []
image_edge_attributes_padded = []
for img_e_idx, img_e_attr in zip(merged_batch['image_edge_indices'], merged_batch['image_edge_attributes']):
b_size, edge_dim, orig_len = img_e_idx.size()
img_e_idx_padded = torch.zeros((b_size, edge_dim, max_image_gr_len))
img_e_idx_padded[:, :, :orig_len] = img_e_idx
image_edge_indices_padded.append(img_e_idx_padded)
edge_attr_dim = img_e_attr.size(-1)
img_e_attr_padded = torch.zeros((b_size, max_image_gr_len, edge_attr_dim))
img_e_attr_padded[:, :orig_len, :] = img_e_attr
image_edge_attributes_padded.append(img_e_attr_padded)
merged_batch['image_edge_indices'] = image_edge_indices_padded
merged_batch['image_edge_attributes'] = image_edge_attributes_padded
out = {}
for key in merged_batch:
if key in self.keys_lists_to_flatten:
temp = []
for b in merged_batch[key]:
for x in b:
temp.append(x)
merged_batch[key] = temp
elif key in self.keys_to_list:
pass
else:
merged_batch[key] = torch.stack(merged_batch[key], 0)
if key in self.keys_to_expand:
if len(merged_batch[key].size()) == 3:
size0, size1, size2 = merged_batch[key].size()
expand_size = (size0, num_rounds, num_samples, size1, size2)
elif len(merged_batch[key].size()) == 2:
size0, size1 = merged_batch[key].size()
expand_size = (size0, num_rounds, num_samples, size1)
merged_batch[key] = merged_batch[key].unsqueeze(1).unsqueeze(1).expand(expand_size).contiguous()
if key in self.keys_to_flatten_1d:
merged_batch[key] = merged_batch[key].reshape(-1)
elif key in self.keys_to_flatten_2d:
merged_batch[key] = merged_batch[key].reshape(-1, merged_batch[key].shape[-1])
elif key in self.keys_to_flatten_3d:
merged_batch[key] = merged_batch[key].reshape(-1, merged_batch[key].shape[-2], merged_batch[key].shape[-1])
else:
assert key in self.keys_other, f'unrecognized key in collate_fn: {key}'
out[key] = merged_batch[key]
return out


@ -0,0 +1,615 @@
import torch
import os
import numpy as np
import random
import pickle
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from utils.data_utils import encode_input, encode_input_with_mask, encode_image_input
from dataloader.dataloader_base import DatasetBase
class VisdialDataset(DatasetBase):
def __init__(self, config):
super(VisdialDataset, self).__init__(config)
def __getitem__(self, index):
MAX_SEQ_LEN = self.config['max_seq_len']
cur_data = None
if self._split == 'train':
cur_data = self.visdial_data_train['data']
ques_adj_matrices_dir = os.path.join(self.config['visdial_question_adj_matrices'], 'train')
hist_adj_matrices_dir = os.path.join(self.config['visdial_history_adj_matrices'], 'train')
elif self._split == 'val':
cur_data = self.visdial_data_val['data']
ques_adj_matrices_dir = os.path.join(self.config['visdial_question_adj_matrices'], 'val')
hist_adj_matrices_dir = os.path.join(self.config['visdial_history_adj_matrices'], 'val')
else:
cur_data = self.visdial_data_test['data']
ques_adj_matrices_dir = os.path.join(self.config['visdial_question_adj_matrices'], 'test')
hist_adj_matrices_dir = os.path.join(self.config['visdial_history_adj_matrices'], 'test')
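# For VisDial v0.9, the question/history graphs of every split are read from the train directories.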
if self.config['visdial_version'] == 0.9:
ques_adj_matrices_dir = os.path.join(self.config['visdial_question_adj_matrices'], 'train')
hist_adj_matrices_dir = os.path.join(self.config['visdial_history_adj_matrices'], 'train')
self.num_bad_samples = 0
# number of options to score on
num_options = self.num_options
assert num_options > 1 and num_options <= 100
num_dialog_rounds = 10
dialog = cur_data['dialogs'][index]
cur_questions = cur_data['questions']
cur_answers = cur_data['answers']
img_id = dialog['image_id']
graph_idx = dialog.get('dialog_idx', index)
if self._split == 'train':
# caption
sent = dialog['caption'].split(' ')
sentences = ['[CLS]']
tot_len = 1 # for the CLS token
sentence_map = [0] # for the CLS token
sentence_count = 0
speakers = [0]
tokenized_sent, sentences, tot_len, sentence_count, sentence_map, speakers = \
self.tokenize_utterance(sent, sentences, tot_len, sentence_count, sentence_map, speakers)
utterances = [[tokenized_sent]]
utterances_random = [[tokenized_sent]]
for rnd, utterance in enumerate(dialog['dialog']):
cur_rnd_utterance = utterances[-1].copy()
cur_rnd_utterance_random = utterances[-1].copy()
# question
sent = cur_questions[utterance['question']].split(' ')
tokenized_sent, sentences, tot_len, sentence_count, sentence_map, speakers = \
self.tokenize_utterance(sent, sentences, tot_len, sentence_count, sentence_map, speakers)
cur_rnd_utterance.append(tokenized_sent)
cur_rnd_utterance_random.append(tokenized_sent)
# answer
sent = cur_answers[utterance['answer']].split(' ')
tokenized_sent, sentences, tot_len, sentence_count, sentence_map, speakers = \
self.tokenize_utterance(sent, sentences, tot_len, sentence_count, sentence_map, speakers)
cur_rnd_utterance.append(tokenized_sent)
utterances.append(cur_rnd_utterance)
# randomly sample negative answer options for this round
num_inds = len(utterance['answer_options'])
gt_option_ind = utterance['gt_index']
negative_samples = []
for _ in range(self.config["num_negative_samples"]):
all_inds = list(range(100))
all_inds.remove(gt_option_ind)
all_inds = all_inds[:(num_options-1)]
tokenized_random_utterance = None
option_ind = None
while len(all_inds):
option_ind = random.choice(all_inds)
tokenized_random_utterance = self.tokenizer.convert_tokens_to_ids(cur_answers[utterance['answer_options'][option_ind]].split(' '))
# the 1 here is for the sep token at the end of each utterance
if(MAX_SEQ_LEN >= (tot_len + len(tokenized_random_utterance) + 1)):
break
else:
all_inds.remove(option_ind)
if len(all_inds) == 0:
# all the options exceed the max len. Truncate the last utterance in this case.
tokenized_random_utterance = tokenized_random_utterance[:len(tokenized_sent)]
t = cur_rnd_utterance_random.copy()
t.append(tokenized_random_utterance)
negative_samples.append(t)
utterances_random.append(negative_samples)
# removing the caption in the beginning
utterances = utterances[1:]
utterances_random = utterances_random[1:]
assert len(utterances) == len(utterances_random) == num_dialog_rounds
assert tot_len <= MAX_SEQ_LEN, '{} {} tot_len = {} > max_seq_len'.format(
self._split, index, tot_len
)
tokens_all = []
question_limits_all = []
question_edge_indices_all = []
question_edge_attributes_all = []
history_edge_indices_all = []
history_sep_indices_all = []
mask_all = []
segments_all = []
sep_indices_all = []
next_labels_all = []
hist_len_all = []
# randomly pick several rounds to train
pos_rounds = sorted(random.sample(range(num_dialog_rounds), self.config['sequences_per_image'] // 2), reverse=True)
neg_rounds = sorted(random.sample(range(num_dialog_rounds), self.config['sequences_per_image'] // 2), reverse=True)
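# With sequences_per_image = 2 (the default in config/vdgr.conf), this samples exactly one
# positive and one negative round (cf. the asserts on pos_rounds/neg_rounds below).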
tokens_all_rnd = []
question_limits_all_rnd = []
mask_all_rnd = []
segments_all_rnd = []
sep_indices_all_rnd = []
next_labels_all_rnd = []
hist_len_all_rnd = []
for j in pos_rounds:
context = utterances[j]
context, start_segment = self.pruneRounds(context, self.config['visdial_tot_rounds'])
if j == pos_rounds[0]: # dialog with positive label and max rounds
tokens, segments, sep_indices, mask, input_mask, start_question, end_question = encode_input_with_mask(context, start_segment, self.CLS,
self.SEP, self.MASK, max_seq_len=MAX_SEQ_LEN, mask_prob=self.config["mask_prob"])
else:
tokens, segments, sep_indices, mask, start_question, end_question = encode_input(context, start_segment, self.CLS,
self.SEP, self.MASK, max_seq_len=MAX_SEQ_LEN, mask_prob=self.config["mask_prob"])
tokens_all_rnd.append(tokens)
question_limits_all_rnd.append(torch.tensor([start_question, end_question]))
mask_all_rnd.append(mask)
sep_indices_all_rnd.append(sep_indices)
next_labels_all_rnd.append(torch.LongTensor([0]))
segments_all_rnd.append(segments)
hist_len_all_rnd.append(torch.LongTensor([len(context)-1]))
tokens_all.append(torch.cat(tokens_all_rnd,0).unsqueeze(0))
mask_all.append(torch.cat(mask_all_rnd,0).unsqueeze(0))
question_limits_all.extend(question_limits_all_rnd)
segments_all.append(torch.cat(segments_all_rnd, 0).unsqueeze(0))
sep_indices_all.append(torch.cat(sep_indices_all_rnd, 0).unsqueeze(0))
next_labels_all.append(torch.cat(next_labels_all_rnd, 0).unsqueeze(0))
hist_len_all.append(torch.cat(hist_len_all_rnd,0).unsqueeze(0))
assert len(pos_rounds) == 1
question_graphs = pickle.load(
open(os.path.join(ques_adj_matrices_dir, f'{graph_idx}.pkl'), 'rb')
)
question_graph_pos = question_graphs[pos_rounds[0]]
question_edge_index_pos = []
question_edge_attribute_pos = []
for edge_idx, edge_attr in question_graph_pos:
question_edge_index_pos.append(edge_idx)
edge_attr_one_hot = np.zeros((len(self.parse_vocab) + 1,), dtype=np.float32)
edge_attr_one_hot[self.parse_vocab.get(edge_attr, len(self.parse_vocab))] = 1.0
question_edge_attribute_pos.append(edge_attr_one_hot)
question_edge_index_pos = np.array(question_edge_index_pos, dtype=np.float64)
question_edge_attribute_pos = np.stack(question_edge_attribute_pos, axis=0)
question_edge_indices_all.append(
torch.from_numpy(question_edge_index_pos).t().long().contiguous()
)
question_edge_attributes_all.append(
torch.from_numpy(question_edge_attribute_pos)
)
history_edge_indices = pickle.load(
open(os.path.join(hist_adj_matrices_dir, f'{graph_idx}.pkl'), 'rb')
)
history_edge_indices_all.append(
torch.tensor(history_edge_indices[pos_rounds[0]]).t().long().contiguous()
)
# Get the [SEP] tokens that will represent the history graph node features
hist_idx_pos = [i * 2 for i in range(pos_rounds[0] + 1)]
sep_indices = sep_indices.squeeze(0).numpy()
history_sep_indices_all.append(torch.from_numpy(sep_indices[hist_idx_pos]))
if len(neg_rounds) > 0:
tokens_all_rnd = []
question_limits_all_rnd = []
mask_all_rnd = []
segments_all_rnd = []
sep_indices_all_rnd = []
next_labels_all_rnd = []
hist_len_all_rnd = []
for j in neg_rounds:
negative_samples = utterances_random[j]
for context_random in negative_samples:
context_random, start_segment = self.pruneRounds(context_random, self.config['visdial_tot_rounds'])
tokens_random, segments_random, sep_indices_random, mask_random, start_question, end_question = encode_input(context_random, start_segment, self.CLS,
self.SEP, self.MASK, max_seq_len=MAX_SEQ_LEN, mask_prob=self.config["mask_prob"])
tokens_all_rnd.append(tokens_random)
question_limits_all_rnd.append(torch.tensor([start_question, end_question]))
mask_all_rnd.append(mask_random)
sep_indices_all_rnd.append(sep_indices_random)
next_labels_all_rnd.append(torch.LongTensor([1]))
segments_all_rnd.append(segments_random)
hist_len_all_rnd.append(torch.LongTensor([len(context_random)-1]))
tokens_all.append(torch.cat(tokens_all_rnd,0).unsqueeze(0))
mask_all.append(torch.cat(mask_all_rnd,0).unsqueeze(0))
question_limits_all.extend(question_limits_all_rnd)
segments_all.append(torch.cat(segments_all_rnd, 0).unsqueeze(0))
sep_indices_all.append(torch.cat(sep_indices_all_rnd, 0).unsqueeze(0))
next_labels_all.append(torch.cat(next_labels_all_rnd, 0).unsqueeze(0))
hist_len_all.append(torch.cat(hist_len_all_rnd,0).unsqueeze(0))
assert len(neg_rounds) == 1
question_graph_neg = question_graphs[neg_rounds[0]]
question_edge_index_neg = []
question_edge_attribute_neg = []
for edge_idx, edge_attr in question_graph_neg:
question_edge_index_neg.append(edge_idx)
edge_attr_one_hot = np.zeros((len(self.parse_vocab) + 1,), dtype=np.float32)
edge_attr_one_hot[self.parse_vocab.get(edge_attr, len(self.parse_vocab))] = 1.0
question_edge_attribute_neg.append(edge_attr_one_hot)
question_edge_index_neg = np.array(question_edge_index_neg, dtype=np.float64)
question_edge_attribute_neg = np.stack(question_edge_attribute_neg, axis=0)
question_edge_indices_all.append(
torch.from_numpy(question_edge_index_neg).t().long().contiguous()
)
question_edge_attributes_all.append(
torch.from_numpy(question_edge_attribute_neg)
)
history_edge_indices_all.append(
torch.tensor(history_edge_indices[neg_rounds[0]]).t().long().contiguous()
)
# Get the [SEP] tokens that will represent the history graph node features
hist_idx_neg = [i * 2 for i in range(neg_rounds[0] + 1)]
sep_indices_random = sep_indices_random.squeeze(0).numpy()
history_sep_indices_all.append(torch.from_numpy(sep_indices_random[hist_idx_neg]))
tokens_all = torch.cat(tokens_all, 0) # [2, num_pos, max_len]
question_limits_all = torch.stack(question_limits_all, 0) # [2, 2]
mask_all = torch.cat(mask_all,0)
segments_all = torch.cat(segments_all, 0)
sep_indices_all = torch.cat(sep_indices_all, 0)
next_labels_all = torch.cat(next_labels_all, 0)
hist_len_all = torch.cat(hist_len_all, 0)
input_mask_all = torch.LongTensor(input_mask) # [max_len]
item = {}
item['tokens'] = tokens_all
item['question_limits'] = question_limits_all
item['question_edge_indices'] = question_edge_indices_all
item['question_edge_attributes'] = question_edge_attributes_all
item['history_edge_indices'] = history_edge_indices_all
item['history_sep_indices'] = history_sep_indices_all
item['segments'] = segments_all
item['sep_indices'] = sep_indices_all
item['mask'] = mask_all
item['next_sentence_labels'] = next_labels_all
item['hist_len'] = hist_len_all
item['input_mask'] = input_mask_all
# get image features
if not self.config['dataloader_text_only']:
features, num_boxes, boxes, _ , image_target, image_edge_indexes, image_edge_attributes = self._image_features_reader[img_id]
features, spatials, image_mask, image_target, image_label = encode_image_input(features, num_boxes, boxes, image_target, max_regions=self._max_region_num)
else:
features = spatials = image_mask = image_target = image_label = torch.tensor([0])
elif self._split == 'val':
gt_relevance = None
gt_option_inds = []
options_all = []
# caption
sent = dialog['caption'].split(' ')
sentences = ['[CLS]']
tot_len = 1 # for the CLS token
sentence_map = [0] # for the CLS token
sentence_count = 0
speakers = [0]
tokenized_sent, sentences, tot_len, sentence_count, sentence_map, speakers = \
self.tokenize_utterance(sent, sentences, tot_len, sentence_count, sentence_map, speakers)
utterances = [[tokenized_sent]]
for rnd, utterance in enumerate(dialog['dialog']):
cur_rnd_utterance = utterances[-1].copy()
# question
sent = cur_questions[utterance['question']].split(' ')
tokenized_sent, sentences, tot_len, sentence_count, sentence_map, speakers = \
self.tokenize_utterance(sent, sentences, tot_len, sentence_count, sentence_map, speakers)
cur_rnd_utterance.append(tokenized_sent)
# current round
gt_option_ind = utterance['gt_index']
# first select gt option id, then choose the first num_options inds
option_inds = []
option_inds.append(gt_option_ind)
all_inds = list(range(100))
all_inds.remove(gt_option_ind)
all_inds = all_inds[:(num_options-1)]
option_inds.extend(all_inds)
gt_option_inds.append(0)
cur_rnd_options = []
answer_options = [utterance['answer_options'][k] for k in option_inds]
assert len(answer_options) == len(option_inds) == num_options
assert answer_options[0] == utterance['answer']
# for evaluation of all options and dense relevance
if self.visdial_data_val_dense:
if rnd == self.visdial_data_val_dense[index]['round_id'] - 1:
# only 1 round has gt_relevance for each example
if 'relevance' in self.visdial_data_val_dense[index]:
gt_relevance = torch.Tensor(self.visdial_data_val_dense[index]['relevance'])
else:
gt_relevance = torch.Tensor(self.visdial_data_val_dense[index]['gt_relevance'])
# reorder gt_relevance to match the new option ordering (gt option first)
gt_relevance = gt_relevance[torch.LongTensor(option_inds)]
else:
gt_relevance = -1
for answer_option in answer_options:
cur_rnd_cur_option = cur_rnd_utterance.copy()
cur_rnd_cur_option.append(self.tokenizer.convert_tokens_to_ids(cur_answers[answer_option].split(' ')))
cur_rnd_options.append(cur_rnd_cur_option)
# answer
sent = cur_answers[utterance['answer']].split(' ')
tokenized_sent, sentences, tot_len, sentence_count, sentence_map, speakers = \
self.tokenize_utterance(sent, sentences, tot_len, sentence_count, sentence_map, speakers)
cur_rnd_utterance.append(tokenized_sent)
utterances.append(cur_rnd_utterance)
options_all.append(cur_rnd_options)
# encode the input and create batch x 10 x 100 x max_len arrays (batch x num_rounds x num_options x max_len)
tokens_all = []
question_limits_all = []
mask_all = []
segments_all = []
sep_indices_all = []
hist_len_all = []
history_sep_indices_all = []
for rnd, cur_rnd_options in enumerate(options_all):
tokens_all_rnd = []
mask_all_rnd = []
segments_all_rnd = []
sep_indices_all_rnd = []
hist_len_all_rnd = []
for j, cur_rnd_option in enumerate(cur_rnd_options):
cur_rnd_option, start_segment = self.pruneRounds(cur_rnd_option, self.config['visdial_tot_rounds'])
if rnd == len(options_all) - 1 and j == 0: # gt dialog
tokens, segments, sep_indices, mask, input_mask, start_question, end_question = encode_input_with_mask(cur_rnd_option, start_segment, self.CLS,
self.SEP, self.MASK, max_seq_len=MAX_SEQ_LEN, mask_prob=0)
else:
tokens, segments, sep_indices, mask, start_question, end_question = encode_input(cur_rnd_option, start_segment,self.CLS,
self.SEP, self.MASK ,max_seq_len=MAX_SEQ_LEN, mask_prob=0)
tokens_all_rnd.append(tokens)
mask_all_rnd.append(mask)
segments_all_rnd.append(segments)
sep_indices_all_rnd.append(sep_indices)
hist_len_all_rnd.append(torch.LongTensor([len(cur_rnd_option)-1]))
question_limits_all.append(torch.tensor([start_question, end_question]).unsqueeze(0).repeat(100, 1))
tokens_all.append(torch.cat(tokens_all_rnd,0).unsqueeze(0))
mask_all.append(torch.cat(mask_all_rnd,0).unsqueeze(0))
segments_all.append(torch.cat(segments_all_rnd,0).unsqueeze(0))
sep_indices_all.append(torch.cat(sep_indices_all_rnd,0).unsqueeze(0))
hist_len_all.append(torch.cat(hist_len_all_rnd,0).unsqueeze(0))
# Get the [SEP] tokens that will represent the history graph node features
# It will be the same for all answer candidates as the history does not change
# for each answer
hist_idx = [i * 2 for i in range(rnd + 1)]
history_sep_indices_all.extend(sep_indices.squeeze(0)[hist_idx].contiguous() for _ in range(100))
tokens_all = torch.cat(tokens_all, 0) # [10, 100, max_len]
mask_all = torch.cat(mask_all, 0)
segments_all = torch.cat(segments_all, 0)
sep_indices_all = torch.cat(sep_indices_all, 0)
hist_len_all = torch.cat(hist_len_all, 0)
input_mask_all = torch.LongTensor(input_mask) # [max_len]
# load graph data
question_limits_all = torch.stack(question_limits_all, 0) # [10, 100, 2]
question_graphs = pickle.load(
open(os.path.join(ques_adj_matrices_dir, f'{graph_idx}.pkl'), 'rb')
)
question_edge_indices_all = [] # one question graph per round, repeated below for each of the 100 answer options
question_edge_attributes_all = []
for q_graph_round in question_graphs:
question_edge_index = []
question_edge_attribute = []
for edge_index, edge_attr in q_graph_round:
question_edge_index.append(edge_index)
edge_attr_one_hot = np.zeros((len(self.parse_vocab) + 1,), dtype=np.float32)
edge_attr_one_hot[self.parse_vocab.get(edge_attr, len(self.parse_vocab))] = 1.0
question_edge_attribute.append(edge_attr_one_hot)
question_edge_index = np.array(question_edge_index, dtype=np.float64)
question_edge_attribute = np.stack(question_edge_attribute, axis=0)
question_edge_indices_all.extend(
[torch.from_numpy(question_edge_index).t().long().contiguous() for _ in range(100)])
question_edge_attributes_all.extend(
[torch.from_numpy(question_edge_attribute).contiguous() for _ in range(100)])
_history_edge_incides_all = pickle.load(
open(os.path.join(hist_adj_matrices_dir, f'{graph_idx}.pkl'), 'rb')
)
history_edge_incides_all = []
for hist_edge_indices_rnd in _history_edge_incides_all:
history_edge_incides_all.extend(
[torch.tensor(hist_edge_indices_rnd).t().long().contiguous() for _ in range(100)]
)
item = {}
item['tokens'] = tokens_all
item['segments'] = segments_all
item['sep_indices'] = sep_indices_all
item['mask'] = mask_all
item['hist_len'] = hist_len_all
item['input_mask'] = input_mask_all
item['gt_option_inds'] = torch.LongTensor(gt_option_inds)
# return dense annotation data as well
if self.visdial_data_val_dense:
item['round_id'] = torch.LongTensor([self.visdial_data_val_dense[index]['round_id']])
item['gt_relevance'] = gt_relevance
item['question_limits'] = question_limits_all
item['question_edge_indices'] = question_edge_indices_all
item['question_edge_attributes'] = question_edge_attributes_all
item['history_edge_indices'] = history_edge_incides_all
item['history_sep_indices'] = history_sep_indices_all
# get image features
if not self.config['dataloader_text_only']:
features, num_boxes, boxes, _ , image_target, image_edge_indexes, image_edge_attributes = self._image_features_reader[img_id]
features, spatials, image_mask, image_target, image_label = encode_image_input(
features, num_boxes, boxes, image_target, max_regions=self._max_region_num, mask_prob=0)
else:
features = spatials = image_mask = image_target = image_label = torch.tensor([0])
elif self.split == 'test':
assert num_options == 100
cur_rnd_utterance = [self.tokenizer.convert_tokens_to_ids(dialog['caption'].split(' '))]
options_all = []
for rnd,utterance in enumerate(dialog['dialog']):
cur_rnd_utterance.append(self.tokenizer.convert_tokens_to_ids(cur_questions[utterance['question']].split(' ')))
if rnd != len(dialog['dialog'])-1:
cur_rnd_utterance.append(self.tokenizer.convert_tokens_to_ids(cur_answers[utterance['answer']].split(' ')))
for answer_option in dialog['dialog'][-1]['answer_options']:
cur_option = cur_rnd_utterance.copy()
cur_option.append(self.tokenizer.convert_tokens_to_ids(cur_answers[answer_option].split(' ')))
options_all.append(cur_option)
tokens_all = []
mask_all = []
segments_all = []
sep_indices_all = []
hist_len_all = []
for j, option in enumerate(options_all):
option, start_segment = self.pruneRounds(option, self.config['visdial_tot_rounds'])
tokens, segments, sep_indices, mask = encode_input(option, start_segment ,self.CLS,
self.SEP, self.MASK ,max_seq_len=MAX_SEQ_LEN, mask_prob=0)
tokens_all.append(tokens)
mask_all.append(mask)
segments_all.append(segments)
sep_indices_all.append(sep_indices)
hist_len_all.append(torch.LongTensor([len(option)-1]))
tokens_all = torch.cat(tokens_all,0)
mask_all = torch.cat(mask_all,0)
segments_all = torch.cat(segments_all, 0)
sep_indices_all = torch.cat(sep_indices_all, 0)
hist_len_all = torch.cat(hist_len_all,0)
hist_idx = [i*2 for i in range(len(dialog['dialog']))]
history_sep_indices_all = [sep_indices.squeeze(0)[hist_idx].contiguous() for _ in range(num_options)]
with open(os.path.join(ques_adj_matrices_dir, f'{graph_idx}.pkl'), 'rb') as f:
question_graphs = pickle.load(f)
q_graph_last = question_graphs[-1]
question_edge_index = []
question_edge_attribute = []
for edge_index, edge_attr in q_graph_last:
question_edge_index.append(edge_index)
edge_attr_one_hot = np.zeros((len(self.parse_vocab) + 1,), dtype=np.float32)
edge_attr_one_hot[self.parse_vocab.get(edge_attr, len(self.parse_vocab))] = 1.0
question_edge_attribute.append(edge_attr_one_hot)
question_edge_index = np.array(question_edge_index, dtype=np.float64)
question_edge_attribute = np.stack(question_edge_attribute, axis=0)
question_edge_indices_all = [torch.from_numpy(question_edge_index).t().long().contiguous() for _ in range(num_options)]
question_edge_attributes_all = [torch.from_numpy(question_edge_attribute).contiguous() for _ in range(num_options)]
with open(os.path.join(hist_adj_matrices_dir, f'{graph_idx}.pkl'), 'rb') as f:
_history_edge_incides_all = pickle.load(f)
_history_edge_incides_last = _history_edge_incides_all[-1]
history_edge_index_all = [torch.tensor(_history_edge_incides_last).t().long().contiguous() for _ in range(num_options)]
if self.config['stack_gr_data']:
question_edge_indices_all = torch.stack(question_edge_indices_all, dim=0)
question_edge_attributes_all = torch.stack(question_edge_attributes_all, dim=0)
history_edge_index_all = torch.stack(history_edge_index_all, dim=0)
history_sep_indices_all = torch.stack(history_sep_indices_all, dim=0)
len_question_gr = torch.tensor(question_edge_indices_all.size(-1)).unsqueeze(0).repeat(num_options, 1)
len_history_gr = torch.tensor(history_edge_index_all.size(-1)).repeat(num_options, 1)
len_history_sep = torch.tensor(history_sep_indices_all.size(-1)).repeat(num_options, 1)
item = {}
item['tokens'] = tokens_all.unsqueeze(0)
item['segments'] = segments_all.unsqueeze(0)
item['sep_indices'] = sep_indices_all.unsqueeze(0)
item['mask'] = mask_all.unsqueeze(0)
item['hist_len'] = hist_len_all.unsqueeze(0)
item['question_limits'] = question_limits_all
item['question_edge_indices'] = question_edge_indices_all
item['question_edge_attributes'] = question_edge_attributes_all
item['history_edge_indices'] = history_edge_index_all
item['history_sep_indices'] = history_sep_indices_all
if self.config['stack_gr_data']:
item['len_question_gr'] = len_question_gr
item['len_history_gr'] = len_history_gr
item['len_history_sep'] = len_history_sep
item['round_id'] = torch.LongTensor([dialog['round_id']])
# get image features
if not self.config['dataloader_text_only']:
features, num_boxes, boxes, _ , image_target, image_edge_indexes, image_edge_attributes = self._image_features_reader[img_id]
features, spatials, image_mask, image_target, image_label = encode_image_input(features, num_boxes, boxes, image_target, max_regions=self._max_region_num, mask_prob=0)
else:
features = spatials = image_mask = image_target = image_label = torch.tensor([0])
item['image_feat'] = features
item['image_loc'] = spatials
item['image_mask'] = image_mask
item['image_target'] = image_target
item['image_label'] = image_label
item['image_id'] = torch.LongTensor([img_id])
if self._split == 'train':
# cheap hack to account for the graph data of the positive and negative examples
item['image_edge_indices'] = [torch.from_numpy(image_edge_indexes).long(), torch.from_numpy(image_edge_indexes).long()]
item['image_edge_attributes'] = [torch.from_numpy(image_edge_attributes), torch.from_numpy(image_edge_attributes)]
elif self._split == 'val':
# cheap hack: repeat the image graph for all 10 rounds x 100 answer options
item['image_edge_indices'] = [torch.from_numpy(image_edge_indexes).contiguous().long() for _ in range(1000)]
item['image_edge_attributes'] = [torch.from_numpy(image_edge_attributes).contiguous() for _ in range(1000)]
else:
# cheap hack: repeat the image graph for all 100 answer options
item['image_edge_indices'] = [torch.from_numpy(image_edge_indexes).contiguous().long() for _ in range(100)]
item['image_edge_attributes'] = [torch.from_numpy(image_edge_attributes).contiguous() for _ in range(100)]
if self.config['stack_gr_data']:
item['image_edge_indices'] = torch.stack(item['image_edge_indices'], dim=0)
item['image_edge_attributes'] = torch.stack(item['image_edge_attributes'], dim=0)
len_image_gr = torch.tensor(item['image_edge_indices'].size(-1)).unsqueeze(0).repeat(num_options)
item['len_image_gr'] = len_image_gr
return item


@ -0,0 +1,313 @@
import torch
import json
import os
import time