Code release

Adnen Abdessaied 2023-10-25 15:38:09 +02:00
commit 09fb25e339
29 changed files with 7162 additions and 0 deletions

README.md Normal file

@@ -0,0 +1,277 @@
<div align="center">
<h1> VD-GR: Boosting Visual Dialog with Cascaded Spatial-Temporal Multi-Modal GRaphs </h1>
**[Adnen Abdessaied][5], &nbsp; [Lei Shi][6], &nbsp; [Andreas Bulling][7]** <br> <br>
**WACV'24, Hawaii, USA** <img src="misc/usa.png" width="3%" align="center"> <br>
**[[Paper][8]]**
-------------------
<img src="misc/teaser_1.png" width="100%" align="middle"><br><br>
</div>
# Table of Contents
* [Setup and Dependencies](#Setup-and-Dependencies)
* [Download Data](#Download-Data)
* [Pre-trained Checkpoints](#Pre-trained-Checkpoints)
* [Training](#Training)
* [Results](#Results)
# Setup and Dependencies
We implemented our model using Python 3.7 and PyTorch 1.11.0 (CUDA 11.3, CuDNN 8.2.0). We recommend setting up a virtual environment using Anaconda. <br>
1. Install [git lfs][1] on your system
2. Clone our repository to download the data, checkpoints, and code
```shell
git lfs install
git clone this_repo.git
```
3. Create a conda environment and install dependencies
```shell
conda create -n vdgr python=3.7
conda activate vdgr
conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch
conda install pyg -c pyg # 2.1.0
pip install pytorch-transformers
pip install pytorch_pretrained_bert
pip install pyhocon glog wandb lmdb
```
4. If you wish to speed up training, we recommend installing [apex][2]
```shell
git clone https://github.com/NVIDIA/apex
cd apex
# if pip >= 23.1 (ref: https://pip.pypa.io/en/stable/news/#v23-1) which supports multiple `--config-settings` with the same key...
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
# otherwise
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --global-option="--cpp_ext" --global-option="--cuda_ext" ./
cd ..
```
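To check that the environment is usable before downloading any data, a quick sanity check along these lines can help (a minimal sketch; the version numbers are the ones pinned above):
```python
# Minimal environment check -- run inside the activated vdgr environment.
import torch
import torch_geometric

print("PyTorch:", torch.__version__)            # expected: 1.11.0
print("PyG:", torch_geometric.__version__)      # expected: 2.1.0
print("CUDA available:", torch.cuda.is_available())
print("Visible GPUs:", torch.cuda.device_count())

try:
    from apex import amp  # noqa: F401 -- only present if you installed apex
    print("apex is available")
except ImportError:
    print("apex not installed (optional)")
```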
# Download Data
1. Download the extracted visual features of [VisDial][3] and set up all the files we used in our work. We provide a shell script for convenience:
```shell
./setup_data.sh # Please make sure you have enough disk space
```
If everything was set up correctly, the ```data/``` folder should look like this:
```
├── history_adj_matrices
│   ├── test
│   │   └── *.pkl
│   ├── train
│   │   └── *.pkl
│   └── val
│       └── *.pkl
├── question_adj_matrices
│   ├── test
│   │   └── *.pkl
│   ├── train
│   │   └── *.pkl
│   └── val
│       └── *.pkl
├── img_adj_matrices
│   └── *.pkl
├── parse_vocab.pkl
├── test_dense_mapping.json
├── tr_dense_mapping.json
├── val_dense_mapping.json
├── visdial_0.9_test.json
├── visdial_0.9_train.json
├── visdial_0.9_val.json
├── visdial_1.0_test.json
├── visdial_1.0_train_dense_annotations.json
├── visdial_1.0_train_dense.json
├── visdial_1.0_train.json
├── visdial_1.0_val_dense_annotations.json
├── visdial_1.0_val.json
├── visdialconv_dense_annotations.json
├── visdialconv.json
├── vispro_dense_annotations.json
└── vispro.json
```
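Before training, it can be worth verifying the layout programmatically; the following is a minimal sketch based on the tree above (adjust the lists if your setup differs):
```python
# Quick check that setup_data.sh produced the expected layout (paths from the tree above).
import os

expected_dirs = [
    "data/img_adj_matrices",
    "data/question_adj_matrices/train", "data/question_adj_matrices/val", "data/question_adj_matrices/test",
    "data/history_adj_matrices/train", "data/history_adj_matrices/val", "data/history_adj_matrices/test",
]
expected_files = [
    "data/parse_vocab.pkl",
    "data/visdial_1.0_train.json", "data/visdial_1.0_val.json", "data/visdial_1.0_test.json",
    "data/visdial_1.0_val_dense_annotations.json",
]

missing = [p for p in expected_dirs if not os.path.isdir(p)]
missing += [p for p in expected_files if not os.path.isfile(p)]
print("Data setup looks complete." if not missing else f"Missing entries: {missing}")
```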
# Pre-trained Checkpoints
For convenience, we provide checkpoints of our model after the warm-up training stage in ```ckpt/``` for both VisDial v1.0 and VisDial v0.9. <br>
These checkpoints will be downloaded with the code if you use ```git lfs```.
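A quick way to confirm that ```git lfs``` fetched the real checkpoint files (and not just pointer stubs) is to load one on CPU; a minimal sketch using the warm-up checkpoint referenced in ```config/vdgr.conf```:
```python
# If git lfs only left pointer files behind, this load will fail.
import torch

ckpt_path = "ckpt/vdgr_visdial_v1.0_after_warmup_K2.ckpt"
state = torch.load(ckpt_path, map_location="cpu")
print(f"Loaded {ckpt_path}: object of type {type(state).__name__}")
```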
# Training
We trained our model on 8 Nvidia Tesla V100-32GB GPUs. The default hyperparameters in ```config/vdgr.conf``` and ```config/bert_base_6layer_6conect.json``` need to be adjusted if your setup differs from ours.
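The configs use the HOCON format, so they can also be inspected programmatically with ```pyhocon``` (already installed above); a small sketch, assuming the default file layout:
```python
# Inspect the Phase 1 hyperparameters you will most likely need to adapt.
from pyhocon import ConfigFactory

config = ConfigFactory.parse_file("config/vdgr.conf")["P1"]
for key in ["batch_size", "eval_batch_size", "dp_type", "num_epochs", "learning_rate_bert"]:
    print(key, "=", config[key])
```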
## Phase 1
### Training
1. In this phase, we train our model on VisDial v1.0 via
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
--model vdgr/P1 \
--mode train \
--tag K2_v1.0 \
--wandb_mode online \
--wandb_project your_wandb_project_name
```
⚠️ On a setup similar to ours, this takes roughly 20 hours to complete when training with apex.
2. To train on VisDial v0.9:
* Set ```visdial_version = 0.9``` in ```config/vdgr.conf```
* Set ```start_path = ckpt/vdgr_visdial_v0.9_after_warmup_K2.ckpt``` in ```config/vdgr.conf```
* Run
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
--model vdgr/P1 \
--mode train \
--tag K2_v0.9 \
--wandb_mode online \
--wandb_project your_wandb_project_name
```
### Inference
1. For inference on VisDial v1.0 val, VisDialConv, or VisPro:
* Set ```eval_dataset``` to one of ```{visdial, visdial_conv, visdial_vispro}``` in ```logs/vdgr/P1_K2_v1.0/code/config/vdgr.conf```
* Run
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
--model vdgr/P1 \
--mode eval \
--eval_dir logs/vdgr/P1_K2_v1.0 \
--wandb_mode offline
```
2. For inference on VisDial v0.9:
* Set ```eval_dataset = visdial``` in ```logs/vdgr/P1_K2_v0.9/code/config/vdgr.conf```
* Run
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
--model vdgr/P1 \
--mode eval \
--eval_dir logs/vdgr/P1_K2_v0.9 \
--wandb_mode offline
```
⚠️ This might take some time to finish, as the test split of VisDial v0.9 is large.
3. For inference on the VisDial v1.0 ```test``` split:
* Run
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
--model vdgr/P1 \
--mode predict \
--eval_dir logs/vdgr/P1_K2_v1.0 \
--wandb_mode offline
```
* The output file will be saved in ```output/```
## Phase 2
In this phase, we fine-tune on the dense annotations to improve the NDCG score (only supported for VisDial v1.0). A sketch of the dense losses involved follows the steps below.
1. Run
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
--model vdgr/P2_CE \
--mode train \
--tag K2_v1.0_CE \
--wandb_mode online \
--wandb_project your_wandb_project_name
```
This takes roughly 3-4 hours to complete on the same setup as before, using [DP][4] for training.
2. For inference on VisDial v1.0:
* Run:
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
--model vdgr/P2_CE \
--mode predict \
--eval_dir logs/vdgr/P1_K2_v1.0_CE \
--wandb_mode offline
```
* The output file will be saved in ```output/```
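For intuition, the dense fine-tuning stage optimizes the 100 per-option scores against the crowd-sourced relevance scores. Below is a minimal sketch of the two dense losses configurable via ```dense_loss``` (```ce``` and, in Phase 3, ```listnet```); the tensor names are hypothetical and the repository's exact implementation may differ:
```python
import torch.nn.functional as F

def dense_ce_loss(scores, relevance):
    """'ce': cross-entropy between the predicted option scores and the normalized
    dense relevance used as soft targets. scores, relevance: [batch, 100]."""
    log_probs = F.log_softmax(scores, dim=-1)
    targets = relevance / relevance.sum(dim=-1, keepdim=True).clamp(min=1e-8)
    return -(targets * log_probs).sum(dim=-1).mean()

def dense_listnet_loss(scores, relevance):
    """'listnet': cross-entropy between the top-one probability distributions
    induced by the scores and by the relevance (ListNet)."""
    log_probs = F.log_softmax(scores, dim=-1)
    targets = F.softmax(relevance, dim=-1)
    return -(targets * log_probs).sum(dim=-1).mean()
```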
## Phase 3
### Training
In the final phase, we train an ensemble comprising 8 models, using ```K={1,2,3,4}``` and ```dense_loss={ce, listnet}``` (the complete run matrix is sketched after the steps below).
For ```K=k```:
1. Set the value of ```num_v_gnn_layers, num_q_gnn_layers, num_h_gnn_layers``` to ```k```
2. Set ```start_path = ckpt/vdgr_visdial_v1.0_after_warmup_K[k].ckpt``` in ```config/vdgr.conf``` (P1)
3. Phase 1 training:
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
--model vdgr/P1 \
--mode train \
--tag K[k]_v1.0 \
--wandb_mode online \
--wandb_project your_wandb_project_name
```
4. Set ```start_path = logs/vdgr/P1_K[k]_v1.0/epoch_best.ckpt``` in ```config/vdgr.conf``` (P2)
5. Phase 2 training:
* Fine-tune with CE:
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
--model vdgr/P2_CE \
--mode train \
--tag K[k]_v1.0_CE \
--wandb_mode online \
--wandb_project your_wandb_project_name
```
* Fine-tune with LISTNET:
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
--model vdgr/P2_LISTNET \
--mode train \
--tag K[k]_v1.0_LISTNET \
--wandb_mode online \
--wandb_project your_wandb_project_name
```
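The eight runs are simply the cross product of ```K={1,2,3,4}``` and ```dense_loss={ce, listnet}```; a small helper that prints the corresponding Phase 2 commands (matching the tags used above):
```python
# Print the 8 Phase 2 training commands of the ensemble.
from itertools import product

GPUS = "CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7"
for k, loss in product([1, 2, 3, 4], ["CE", "LISTNET"]):
    print(
        f"{GPUS} python main.py --model vdgr/P2_{loss} --mode train "
        f"--tag K{k}_v1.0_{loss} --wandb_mode online --wandb_project your_wandb_project_name"
    )
```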
### Inference
1. For inference on VisDial v1.0 test:
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python main.py \
--model vdgr/P2_[CE,LISTNET] \
--mode predict \
--eval_dir logs/vdgr/P2_K[1,2,3,4]_v1.0_[CE,LISTNET] \
--wandb_mode offline
```
2. Finally, merge the outputs of all models:
```shell
python ensemble.py \
--exp test \
--mode predict
```
The output file will be saved in ```output/```
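Conceptually, ```ensemble.py``` merges the per-model predictions listed in ```config/ensemble.conf```. As a heavily simplified sketch of the idea (score averaging over hypothetical per-model score tensors; the real script's file format and merging rule may differ):
```python
import torch

def average_and_rank(score_files):
    """Average per-option scores from several models and turn them into ranks.
    Each (hypothetical) file holds a [num_dialogs, 100] tensor of option scores."""
    stacked = torch.stack([torch.load(f) for f in score_files], dim=0)
    mean_scores = stacked.mean(dim=0)
    # rank 1 = best option, as expected by the VisDial test submission format
    ranks = mean_scores.argsort(dim=-1, descending=True).argsort(dim=-1) + 1
    return ranks
```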
# Results
## VisDial v0.9
| Model | MRR | R@1 | R@5 | R@10 | Mean |
|:--------:|:---:|:---:|:---:|:----:|:----:|
| Prev. SOTA | 71.99 | 59.41 | 87.92 | 94.59 | 2.87 |
| VD-GR | **74.50** | **62.10** | **90.49** | **96.37** | **2.45** |
## VisDialConv
| Model | NDCG | MRR | R@1 | R@5 | R@10 | Mean |
|:--------:|:----:|:---:|:---:|:---:|:----:|:----:|
| Prev. SOTA | 61.72 | 61.79 | 48.95 | 77.50 | 86.71 | 4.72 |
| VD-GR | **67.09** | **66.82** | **54.47** | **81.71** | **91.44** | **3.54** |
## VisPro
| Model | NDCG | MRR | R@1 | R@5 | R@10 | Mean |
|:--------:|:----:|:---:|:---:|:---:|:----:|:----:|
| Prev. SOTA | 59.30 | 62.29 | 48.35 | 80.10 | 88.87 | 4.37 |
| VD-GR | **60.35** | **69.89** | **57.21** | **85.97** | **92.68** | **3.15** |
## VisDial v1.0 Val
| Model | NDCG | MRR | R@1 | R@5 | R@10 | Mean |
|:--------:|:----:|:---:|:---:|:---:|:----:|:----:|
| Prev. SOTA | 65.47 | 69.71 | 56.79 | 85.82 | 93.64 | 3.15 |
| VD-GR | 64.32 | **69.91** | **57.01** | **86.14** | **93.74** | **3.13** |
## VisDial v1.0 Test
| Model | NDCG | MRR | R@1 | R@5 | R@10 | Mean |
|:--------:|:----:|:---:|:---:|:---:|:----:|:----:|
| Prev. SOTA | 64.91 | 68.73 | 55.73 | 85.38 | 93.53 | 3.21 |
| VD-GR | 63.49 | 68.65 | 55.33 | **85.58** | **93.85** | **3.20** |
| ♣️ Prev. SOTA | 75.92 | 56.18 | 45.32 | 68.05 | 80.98 | 5.42 |
| ♣️ VD-GR | **75.95** | **58.30** | **46.55** | **71.45** | 84.52 | **5.32** |
| ♣️♦️ Prev. SOTA | 76.17 | 56.42 | 44.75 | 70.23 | 84.52 | 5.47 |
| ♣️♦️ VD-GR | **76.43** | 56.35 | **45.18** | 68.13 | 82.18 | 5.79 |
♣️ = Finetuning on dense annotations, ♦️ = Ensemble model
[1]: https://git-lfs.com/
[2]: https://github.com/NVIDIA/apex
[3]: https://visualdialog.org/
[4]: https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html
[5]: https://adnenabdessaied.de
[6]: https://www.perceptualui.org/people/shi/
[7]: https://www.perceptualui.org/people/bulling/
[8]: https://drive.google.com/file/d/1GT0WDinA_z5FdwVc_bWtyB-cwQkGIf7C/view?usp=sharing

ckpt/.gitkeep Normal file

config/bert_base_6layer_6conect.json Normal file

@@ -0,0 +1,40 @@
{
"attention_probs_dropout_prob": 0.1,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"max_position_embeddings": 512,
"num_attention_heads": 12,
"num_hidden_layers": 12,
"type_vocab_size": 2,
"vocab_size": 30522,
"v_feature_size": 2048,
"v_target_size": 1601,
"v_hidden_size": 1024,
"v_num_hidden_layers": 6,
"v_num_attention_heads": 8,
"v_intermediate_size": 1024,
"bi_hidden_size": 1024,
"bi_num_attention_heads": 8,
"bi_intermediate_size": 1024,
"bi_attention_type": 1,
"v_attention_probs_dropout_prob": 0.1,
"v_hidden_act": "gelu",
"v_hidden_dropout_prob": 0.1,
"v_initializer_range": 0.02,
"pooling_method": "mul",
"v_biattention_id": [0, 1, 2, 3, 4, 5],
"t_biattention_id": [6, 7, 8, 9, 10, 11],
"gnn_act": "gelu",
"num_v_gnn_layers": 2,
"num_q_gnn_layers": 2,
"num_h_gnn_layers": 2,
"num_gnn_attention_heads": 4,
"gnn_dropout_prob": 0.1,
"v_gnn_edge_dim": 12,
"q_gnn_edge_dim": 48,
"v_gnn_ids": [0, 1, 2, 3, 4, 5],
"t_gnn_ids": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
}

config/ensemble.conf Normal file

@@ -0,0 +1,33 @@
test = {
split = test
skip_mrr_eval = true
# data
visdial_test_data = data/visdial_1.0_test.json
# directory
log_dir = logs/vdgr_ensemble
pred_dir = logs/vdgr
visdial_output_dir = visdial_output
processed = [
false,
false,
false,
false,
false,
false,
false,
false
]
models = [
"P2_K1_v1.0_CE",
"P2_K2_v1.0_CE",
"P2_K3_v1.0_CE",
"P2_K4_v1.0_CE",
"P2_K1_v1.0_LISTNET",
"P2_K2_v1.0_LISTNET",
"P2_K3_v1.0_LISTNET",
"P2_K4_v1.0_LISTNET",
]
}

config/vdgr.conf Normal file

@@ -0,0 +1,188 @@
# Phase 1
P1 {
use_cpu = false
visdial_version = 1.0
train_on_dense = false
metrics_to_maximize = mrr
# visdial data
visdial_image_feats = data/visdial_img_feat.lmdb
visdial_image_adj_matrices = data/img_adj_matrices
visdial_question_adj_matrices = data/question_adj_matrices
visdial_history_adj_matrices = data/history_adj_matrices
visdial_train = data/visdial_1.0_train.json
visdial_val = data/visdial_1.0_val.json
visdial_test = data/visdial_1.0_test.json
visdial_val_dense_annotations = data/visdial_1.0_val_dense_annotations.json
visdial_train_09 = data/visdial_0.9_train.json
visdial_val_09 = data/visdial_0.9_val.json
visdial_test_09 = data/visdial_0.9_test.json
visdialconv_val = data/visdial_conv.json
visdialconv_val_dense_annotations = data/visdialconv_dense_annotations.json
visdialvispro_val = data/vispro.json
visdialvispro_val_dense_annotations = data/vispro_dense_annotations.json
visdial_question_parse_vocab = data/parse_vocab.pkl
# init
start_path = ckpt/vdgr_visdial_v1.0_after_warmup_K2.ckpt
model_config = config/bert_base_6layer_6conect.json
# visdial training
freeze_vilbert = false
visdial_tot_rounds = 11
num_negative_samples = 1
sequences_per_image = 2
batch_size = 8
lm_loss_coeff = 1
nsp_loss_coeff = 1
img_loss_coeff = 1
batch_multiply = 1
use_trainval = false
dense_loss = ce
dense_loss_coeff = 0
dataloader_text_only = false
rlv_hst_only = false
rlv_hst_dense_round = false
# visdial model
mask_prob = 0.1
image_mask_prob = 0.1
max_seq_len = 256
num_options = 100
num_options_dense = 100
use_embedding = joint
# visdial evaluation
eval_visdial_on_test = true
eval_batch_size = 1
eval_line_batch_size = 200
skip_mrr_eval = false
skip_ndcg_eval = false
skip_visdial_eval = false
eval_visdial_every = 1
eval_dataset = visdial # visdial_vispro # choices = [visdial, visdial_conv, visdial_vispro ]
continue_evaluation = false
eval_at_start = false
eval_before_training = false
initializer = normal
bert_cased = false
# restore ckpt
loads_best_ckpt = false
loads_ckpt = false
restarts = false
resets_max_metric = false
uses_new_optimizer = false
sets_new_lr = false
loads_start_path = false
# logging
random_seed = 42
next_logging_pct = 1.0
next_evaluating_pct = 50.0
max_ckpt_to_keep = 1
num_epochs = 20
early_stop_epoch = 5
skip_saving_ckpt = false
dp_type = apex
stack_gr_data = false
master_port = 5122
stop_epochs = -1
train_each_round = false
drop_last_answer = false
num_samples = -1
# predicting
predict_split = test
predict_each_round = false
predict_dense_round = false
num_test_dialogs = 8000
num_val_dialogs = 2064
save_score = false
# optimizer
reset_optim = none
learning_rate_bert = 5e-6
learning_rate_gnn = 2e-4
gnn_weight_decay = 0.01
use_diff_lr_gnn = true
min_lr = 0
decay_method_bert = linear
decay_method_gnn = linear
decay_exp = 2
max_grad_norm = 1.0
task_optimizer = adam
warmup_ratio = 0.1
# directory
log_dir = logs/vdgr
data_dir = data
visdial_output_dir = visdial_output
bert_cache_dir = transformers
# keep track of other hparams in bert json
v_gnn_edge_dim = 12 # 11 classes + hub_node
q_gnn_edge_dim = 48 # 47 classes + hub_node
num_v_gnn_layers = 2
num_q_gnn_layers = 2
num_h_gnn_layers = 2
num_gnn_attention_heads = 4
v_gnn_ids = [0, 1, 2, 3, 4, 5]
t_gnn_ids = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
}
# Phase 2
P2_CE = ${P1} {
# basic
train_on_dense = true
use_trainval = true
metrics_to_maximize = ndcg
visdial_train_dense = data/visdial_1.0_train_dense.json
visdial_train_dense_annotations = data/visdial_1.0_train_dense_annotations.json
visdial_val_dense = data/visdial_1.0_val.json
tr_graph_idx_mapping = data/tr_dense_mapping.json
val_graph_idx_mapping = data/val_dense_mapping.json
test_graph_idx_mapping = data/test_dense_mapping.json
visdial_val = data/visdial_1.0_val.json
visdial_val_dense_annotations = data/visdial_1.0_val_dense_annotations.json
# data
start_path = logs/vdgr/P1_K2_v1.0/epoch_best.ckpt
rlv_hst_only = false
# visdial training
nsp_loss_coeff = 0
dense_loss_coeff = 1
batch_multiply = 10
batch_size = 1
# visdial model
num_options_dense = 100
# visdial evaluation
eval_batch_size = 1
eval_line_batch_size = 100
skip_mrr_eval = true
# training
stop_epochs = 3
dp_type = dp
dense_loss = ce
# optimizer
learning_rate_bert = 1e-4
}
P2_LISTNET = ${P2_CE} {
dense_loss = listnet
}

data/.gitkeep Normal file

dataloader/__init__.py Normal file

dataloader/dataloader_base.py Normal file

@@ -0,0 +1,269 @@
import torch
from torch.utils import data
import json
import os
import glog as log
import pickle
import torch.utils.data as tud
from pytorch_transformers.tokenization_bert import BertTokenizer
from utils.image_features_reader import ImageFeaturesH5Reader
class DatasetBase(data.Dataset):
def __init__(self, config):
if config['display']:
log.info('Initializing dataset')
# Fetch the correct dataset for evaluation
if config['validating']:
assert config.eval_dataset in ['visdial', 'visdial_conv', 'visdial_vispro', 'visdial_v09']
if config.eval_dataset == 'visdial_conv':
config['visdial_val'] = config.visdialconv_val
config['visdial_val_dense_annotations'] = config.visdialconv_val_dense_annotations
elif config.eval_dataset == 'visdial_vispro':
config['visdial_val'] = config.visdialvispro_val
config['visdial_val_dense_annotations'] = config.visdialvispro_val_dense_annotations
elif config.eval_dataset == 'visdial_v09':
config['visdial_val_09'] = config.visdial_test_09
config['visdial_val_dense_annotations'] = None
self.config = config
self.numDataPoints = {}
if not config['dataloader_text_only']:
self._image_features_reader = ImageFeaturesH5Reader(
config['visdial_image_feats'],
config['visdial_image_adj_matrices']
)
if self.config['training'] or self.config['validating'] or self.config['predicting']:
split2data = {'train': 'train', 'val': 'val', 'test': 'test'}
elif self.config['debugging']:
split2data = {'train': 'val', 'val': 'val', 'test': 'test'}
elif self.config['visualizing']:
split2data = {'train': 'train', 'val': 'train', 'test': 'test'}
filename = f'visdial_{split2data["train"]}'
if config['train_on_dense']:
filename += '_dense'
if self.config['visdial_version'] == 0.9:
filename += '_09'
with open(config[filename]) as f:
self.visdial_data_train = json.load(f)
if self.config.num_samples > 0:
self.visdial_data_train['data']['dialogs'] = self.visdial_data_train['data']['dialogs'][:self.config.num_samples]
self.numDataPoints['train'] = len(self.visdial_data_train['data']['dialogs'])
filename = f'visdial_{split2data["val"]}'
if config['train_on_dense'] and config['training']:
filename += '_dense'
if self.config['visdial_version'] == 0.9:
filename += '_09'
with open(config[filename]) as f:
self.visdial_data_val = json.load(f)
if self.config.num_samples > 0:
self.visdial_data_val['data']['dialogs'] = self.visdial_data_val['data']['dialogs'][:self.config.num_samples]
self.numDataPoints['val'] = len(self.visdial_data_val['data']['dialogs'])
if config['train_on_dense']:
self.numDataPoints['trainval'] = self.numDataPoints['train'] + self.numDataPoints['val']
with open(config[f'visdial_{split2data["test"]}']) as f:
self.visdial_data_test = json.load(f)
self.numDataPoints['test'] = len(self.visdial_data_test['data']['dialogs'])
self.rlv_hst_train = None
self.rlv_hst_val = None
self.rlv_hst_test = None
if config['train_on_dense'] or config['predict_dense_round']:
with open(config[f'visdial_{split2data["train"]}_dense_annotations']) as f:
self.visdial_data_train_dense = json.load(f)
if config['train_on_dense']:
self.subsets = ['train', 'val', 'trainval', 'test']
else:
self.subsets = ['train','val','test']
self.num_options = config["num_options"]
self.num_options_dense = config["num_options_dense"]
if config['visdial_version'] != 0.9:
with open(config[f'visdial_{split2data["val"]}_dense_annotations']) as f:
self.visdial_data_val_dense = json.load(f)
else:
self.visdial_data_val_dense = None
self._split = 'train'
self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', cache_dir=config['bert_cache_dir'])
# fetch the token indices of [CLS], [MASK] and [SEP]
tokens = ['[CLS]','[MASK]','[SEP]']
indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokens)
self.CLS = indexed_tokens[0]
self.MASK = indexed_tokens[1]
self.SEP = indexed_tokens[2]
self._max_region_num = 37
self.predict_each_round = self.config['predicting'] and self.config['predict_each_round']
self.keys_to_expand = ['image_feat', 'image_loc', 'image_mask', 'image_target', 'image_label']
self.keys_to_flatten_1d = ['hist_len', 'next_sentence_labels', 'round_id', 'image_id']
self.keys_to_flatten_2d = ['tokens', 'segments', 'sep_indices', 'mask', 'image_mask', 'image_label', 'input_mask', 'question_limits']
self.keys_to_flatten_3d = ['image_feat', 'image_loc', 'image_target', ]
self.keys_other = ['gt_relevance', 'gt_option_inds']
self.keys_lists_to_flatten = ['image_edge_indices', 'image_edge_attributes', 'question_edge_indices', 'question_edge_attributes', 'history_edge_indices', 'history_sep_indices']
if config['stack_gr_data']:
self.keys_to_flatten_3d.extend(self.keys_lists_to_flatten[:-1])
self.keys_to_flatten_2d.append(self.keys_lists_to_flatten[-1])
self.keys_to_flatten_1d.extend(['len_image_gr', 'len_question_gr', 'len_history_gr', 'len_history_sep'])
self.keys_lists_to_flatten = []
self.keys_to_list = ['tot_len']
# Load the parse vocab for question graph relationship mapping
if os.path.isfile(config['visdial_question_parse_vocab']):
with open(config['visdial_question_parse_vocab'], 'rb') as f:
self.parse_vocab = pickle.load(f)
def __len__(self):
return self.numDataPoints[self._split]
@property
def split(self):
return self._split
@split.setter
def split(self, split):
assert split in self.subsets
self._split = split
def tokens2str(self, seq):
dialog_sequence = ''
for sentence in seq:
for word in sentence:
dialog_sequence += self.tokenizer._convert_id_to_token(word) + " "
dialog_sequence += ' </end> '
dialog_sequence = dialog_sequence.encode('utf8')
return dialog_sequence
def pruneRounds(self, context, num_rounds):
start_segment = 1
len_context = len(context)
cur_rounds = (len(context) // 2) + 1
l_index = 0
if cur_rounds > num_rounds:
# caption is not part of the final input
l_index = len_context - (2 * num_rounds)
start_segment = 0
return context[l_index:], start_segment
def tokenize_utterance(self, sent, sentences, tot_len, sentence_count, sentence_map, speakers):
sentences.extend(sent + ['[SEP]'])
tokenized_sent = self.tokenizer.convert_tokens_to_ids(sent)
assert len(sent) == len(tokenized_sent), 'sub-word tokens are not allowed!'
sent_len = len(tokenized_sent)
tot_len += sent_len + 1 # the additional 1 is for the sep token
sentence_count += 1
sentence_map.extend([sentence_count * 2 - 1] * sent_len)
sentence_map.append(sentence_count * 2) # for [SEP]
speakers.extend([2] * (sent_len + 1))
return tokenized_sent, sentences, tot_len, sentence_count, sentence_map, speakers
def __getitem__(self, index):
raise NotImplementedError
def collate_fn(self, batch):
tokens_size = batch[0]['tokens'].size()
num_rounds, num_samples = tokens_size[0], tokens_size[1]
merged_batch = {key: [d[key] for d in batch] for key in batch[0]}
if self.config['stack_gr_data']:
if (len(batch)) > 1:
max_question_gr_len = max([length.max().item() for length in merged_batch['len_question_gr']])
max_history_gr_len = max([length.max().item() for length in merged_batch['len_history_gr']])
max_history_sep_len = max([length.max().item() for length in merged_batch['len_history_sep']])
max_image_gr_len = max([length.max().item() for length in merged_batch['len_image_gr']])
question_edge_indices_padded = []
question_edge_attributes_padded = []
for q_e_idx, q_e_attr in zip(merged_batch['question_edge_indices'], merged_batch['question_edge_attributes']):
b_size, edge_dim, orig_len = q_e_idx.size()
q_e_idx_padded = torch.zeros((b_size, edge_dim, max_question_gr_len))
q_e_idx_padded[:, :, :orig_len] = q_e_idx
question_edge_indices_padded.append(q_e_idx_padded)
edge_attr_dim = q_e_attr.size(-1)
q_e_attr_padded = torch.zeros((b_size, max_question_gr_len, edge_attr_dim))
q_e_attr_padded[:, :orig_len, :] = q_e_attr
question_edge_attributes_padded.append(q_e_attr_padded)
merged_batch['question_edge_indices'] = question_edge_indices_padded
merged_batch['question_edge_attributes'] = question_edge_attributes_padded
history_edge_indices_padded = []
for h_e_idx in merged_batch['history_edge_indices']:
b_size, _, orig_len = h_e_idx.size()
h_edge_idx_padded = torch.zeros((b_size, 2, max_history_gr_len))
h_edge_idx_padded[:, :, :orig_len] = h_e_idx
history_edge_indices_padded.append(h_edge_idx_padded)
merged_batch['history_edge_indices'] = history_edge_indices_padded
history_sep_indices_padded = []
for hist_sep_idx in merged_batch['history_sep_indices']:
b_size, orig_len = hist_sep_idx.size()
hist_sep_idx_padded = torch.zeros((b_size, max_history_sep_len))
hist_sep_idx_padded[:, :orig_len] = hist_sep_idx
history_sep_indices_padded.append(hist_sep_idx_padded)
merged_batch['history_sep_indices'] = history_sep_indices_padded
image_edge_indices_padded = []
image_edge_attributes_padded = []
for img_e_idx, img_e_attr in zip(merged_batch['image_edge_indices'], merged_batch['image_edge_attributes']):
b_size, edge_dim, orig_len = img_e_idx.size()
img_e_idx_padded = torch.zeros((b_size, edge_dim, max_image_gr_len))
img_e_idx_padded[:, :, :orig_len] = img_e_idx
image_edge_indices_padded.append(img_e_idx_padded)
edge_attr_dim = img_e_attr.size(-1)
img_e_attr_padded = torch.zeros((b_size, max_image_gr_len, edge_attr_dim))
img_e_attr_padded[:, :orig_len, :] = img_e_attr
image_edge_attributes_padded.append(img_e_attr_padded)
merged_batch['image_edge_indices'] = image_edge_indices_padded
merged_batch['image_edge_attributes'] = image_edge_attributes_padded
out = {}
for key in merged_batch:
if key in self.keys_lists_to_flatten:
temp = []
for b in merged_batch[key]:
for x in b:
temp.append(x)
merged_batch[key] = temp
elif key in self.keys_to_list:
pass
else:
merged_batch[key] = torch.stack(merged_batch[key], 0)
if key in self.keys_to_expand:
if len(merged_batch[key].size()) == 3:
size0, size1, size2 = merged_batch[key].size()
expand_size = (size0, num_rounds, num_samples, size1, size2)
elif len(merged_batch[key].size()) == 2:
size0, size1 = merged_batch[key].size()
expand_size = (size0, num_rounds, num_samples, size1)
merged_batch[key] = merged_batch[key].unsqueeze(1).unsqueeze(1).expand(expand_size).contiguous()
if key in self.keys_to_flatten_1d:
merged_batch[key] = merged_batch[key].reshape(-1)
elif key in self.keys_to_flatten_2d:
merged_batch[key] = merged_batch[key].reshape(-1, merged_batch[key].shape[-1])
elif key in self.keys_to_flatten_3d:
merged_batch[key] = merged_batch[key].reshape(-1, merged_batch[key].shape[-2], merged_batch[key].shape[-1])
else:
assert key in self.keys_other, f'unrecognized key in collate_fn: {key}'
out[key] = merged_batch[key]
return out


@@ -0,0 +1,615 @@
import torch
import os
import numpy as np
import random
import pickle
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from utils.data_utils import encode_input, encode_input_with_mask, encode_image_input
from dataloader.dataloader_base import DatasetBase
class VisdialDataset(DatasetBase):
def __init__(self, config):
super(VisdialDataset, self).__init__(config)
def __getitem__(self, index):
MAX_SEQ_LEN = self.config['max_seq_len']
cur_data = None
if self._split == 'train':
cur_data = self.visdial_data_train['data']
ques_adj_matrices_dir = os.path.join(self.config['visdial_question_adj_matrices'], 'train')
hist_adj_matrices_dir = os.path.join(self.config['visdial_history_adj_matrices'], 'train')
elif self._split == 'val':
cur_data = self.visdial_data_val['data']
ques_adj_matrices_dir = os.path.join(self.config['visdial_question_adj_matrices'], 'val')
hist_adj_matrices_dir = os.path.join(self.config['visdial_history_adj_matrices'], 'val')
else:
cur_data = self.visdial_data_test['data']
ques_adj_matrices_dir = os.path.join(self.config['visdial_question_adj_matrices'], 'test')
hist_adj_matrices_dir = os.path.join(self.config['visdial_history_adj_matrices'], 'test')
if self.config['visdial_version'] == 0.9:
ques_adj_matrices_dir = os.path.join(self.config['visdial_question_adj_matrices'], 'train')
hist_adj_matrices_dir = os.path.join(self.config['visdial_history_adj_matrices'], 'train')
self.num_bad_samples = 0
# number of options to score on
num_options = self.num_options
assert num_options > 1 and num_options <= 100
num_dialog_rounds = 10
dialog = cur_data['dialogs'][index]
cur_questions = cur_data['questions']
cur_answers = cur_data['answers']
img_id = dialog['image_id']
graph_idx = dialog.get('dialog_idx', index)
if self._split == 'train':
# caption
sent = dialog['caption'].split(' ')
sentences = ['[CLS]']
tot_len = 1 # for the CLS token
sentence_map = [0] # for the CLS token
sentence_count = 0
speakers = [0]
tokenized_sent, sentences, tot_len, sentence_count, sentence_map, speakers = \
self.tokenize_utterance(sent, sentences, tot_len, sentence_count, sentence_map, speakers)
utterances = [[tokenized_sent]]
utterances_random = [[tokenized_sent]]
for rnd, utterance in enumerate(dialog['dialog']):
cur_rnd_utterance = utterances[-1].copy()
cur_rnd_utterance_random = utterances[-1].copy()
# question
sent = cur_questions[utterance['question']].split(' ')
tokenized_sent, sentences, tot_len, sentence_count, sentence_map, speakers = \
self.tokenize_utterance(sent, sentences, tot_len, sentence_count, sentence_map, speakers)
cur_rnd_utterance.append(tokenized_sent)
cur_rnd_utterance_random.append(tokenized_sent)
# answer
sent = cur_answers[utterance['answer']].split(' ')
tokenized_sent, sentences, tot_len, sentence_count, sentence_map, speakers = \
self.tokenize_utterance(sent, sentences, tot_len, sentence_count, sentence_map, speakers)
cur_rnd_utterance.append(tokenized_sent)
utterances.append(cur_rnd_utterance)
# randomly select one random utterance in that round
num_inds = len(utterance['answer_options'])
gt_option_ind = utterance['gt_index']
negative_samples = []
for _ in range(self.config["num_negative_samples"]):
all_inds = list(range(100))
all_inds.remove(gt_option_ind)
all_inds = all_inds[:(num_options-1)]
tokenized_random_utterance = None
option_ind = None
while len(all_inds):
option_ind = random.choice(all_inds)
tokenized_random_utterance = self.tokenizer.convert_tokens_to_ids(cur_answers[utterance['answer_options'][option_ind]].split(' '))
# the 1 here is for the sep token at the end of each utterance
if(MAX_SEQ_LEN >= (tot_len + len(tokenized_random_utterance) + 1)):
break
else:
all_inds.remove(option_ind)
if len(all_inds) == 0:
# all the options exceed the max len. Truncate the last utterance in this case.
tokenized_random_utterance = tokenized_random_utterance[:len(tokenized_sent)]
t = cur_rnd_utterance_random.copy()
t.append(tokenized_random_utterance)
negative_samples.append(t)
utterances_random.append(negative_samples)
# removing the caption in the beginning
utterances = utterances[1:]
utterances_random = utterances_random[1:]
assert len(utterances) == len(utterances_random) == num_dialog_rounds
assert tot_len <= MAX_SEQ_LEN, '{} {} tot_len = {} > max_seq_len'.format(
self._split, index, tot_len
)
tokens_all = []
question_limits_all = []
question_edge_indices_all = []
question_edge_attributes_all = []
history_edge_indices_all = []
history_sep_indices_all = []
mask_all = []
segments_all = []
sep_indices_all = []
next_labels_all = []
hist_len_all = []
# randomly pick several rounds to train
pos_rounds = sorted(random.sample(range(num_dialog_rounds), self.config['sequences_per_image'] // 2), reverse=True)
neg_rounds = sorted(random.sample(range(num_dialog_rounds), self.config['sequences_per_image'] // 2), reverse=True)
tokens_all_rnd = []
question_limits_all_rnd = []
mask_all_rnd = []
segments_all_rnd = []
sep_indices_all_rnd = []
next_labels_all_rnd = []
hist_len_all_rnd = []
for j in pos_rounds:
context = utterances[j]
context, start_segment = self.pruneRounds(context, self.config['visdial_tot_rounds'])
if j == pos_rounds[0]: # dialog with positive label and max rounds
tokens, segments, sep_indices, mask, input_mask, start_question, end_question = encode_input_with_mask(context, start_segment, self.CLS,
self.SEP, self.MASK, max_seq_len=MAX_SEQ_LEN, mask_prob=self.config["mask_prob"])
else:
tokens, segments, sep_indices, mask, start_question, end_question = encode_input(context, start_segment, self.CLS,
self.SEP, self.MASK, max_seq_len=MAX_SEQ_LEN, mask_prob=self.config["mask_prob"])
tokens_all_rnd.append(tokens)
question_limits_all_rnd.append(torch.tensor([start_question, end_question]))
mask_all_rnd.append(mask)
sep_indices_all_rnd.append(sep_indices)
next_labels_all_rnd.append(torch.LongTensor([0]))
segments_all_rnd.append(segments)
hist_len_all_rnd.append(torch.LongTensor([len(context)-1]))
tokens_all.append(torch.cat(tokens_all_rnd,0).unsqueeze(0))
mask_all.append(torch.cat(mask_all_rnd,0).unsqueeze(0))
question_limits_all.extend(question_limits_all_rnd)
segments_all.append(torch.cat(segments_all_rnd, 0).unsqueeze(0))
sep_indices_all.append(torch.cat(sep_indices_all_rnd, 0).unsqueeze(0))
next_labels_all.append(torch.cat(next_labels_all_rnd, 0).unsqueeze(0))
hist_len_all.append(torch.cat(hist_len_all_rnd,0).unsqueeze(0))
assert len(pos_rounds) == 1
question_graphs = pickle.load(
open(os.path.join(ques_adj_matrices_dir, f'{graph_idx}.pkl'), 'rb')
)
question_graph_pos = question_graphs[pos_rounds[0]]
question_edge_index_pos = []
question_edge_attribute_pos = []
for edge_idx, edge_attr in question_graph_pos:
question_edge_index_pos.append(edge_idx)
edge_attr_one_hot = np.zeros((len(self.parse_vocab) + 1,), dtype=np.float32)
edge_attr_one_hot[self.parse_vocab.get(edge_attr, len(self.parse_vocab))] = 1.0
question_edge_attribute_pos.append(edge_attr_one_hot)
question_edge_index_pos = np.array(question_edge_index_pos, dtype=np.float64)
question_edge_attribute_pos = np.stack(question_edge_attribute_pos, axis=0)
question_edge_indices_all.append(
torch.from_numpy(question_edge_index_pos).t().long().contiguous()
)
question_edge_attributes_all.append(
torch.from_numpy(question_edge_attribute_pos)
)
history_edge_indices = pickle.load(
open(os.path.join(hist_adj_matrices_dir, f'{graph_idx}.pkl'), 'rb')
)
history_edge_indices_all.append(
torch.tensor(history_edge_indices[pos_rounds[0]]).t().long().contiguous()
)
# Get the [SEP] tokens that will represent the history graph node features
hist_idx_pos = [i * 2 for i in range(pos_rounds[0] + 1)]
sep_indices = sep_indices.squeeze(0).numpy()
history_sep_indices_all.append(torch.from_numpy(sep_indices[hist_idx_pos]))
if len(neg_rounds) > 0:
tokens_all_rnd = []
question_limits_all_rnd = []
mask_all_rnd = []
segments_all_rnd = []
sep_indices_all_rnd = []
next_labels_all_rnd = []
hist_len_all_rnd = []
for j in neg_rounds:
negative_samples = utterances_random[j]
for context_random in negative_samples:
context_random, start_segment = self.pruneRounds(context_random, self.config['visdial_tot_rounds'])
tokens_random, segments_random, sep_indices_random, mask_random, start_question, end_question = encode_input(context_random, start_segment, self.CLS,
self.SEP, self.MASK, max_seq_len=MAX_SEQ_LEN, mask_prob=self.config["mask_prob"])
tokens_all_rnd.append(tokens_random)
question_limits_all_rnd.append(torch.tensor([start_question, end_question]))
mask_all_rnd.append(mask_random)
sep_indices_all_rnd.append(sep_indices_random)
next_labels_all_rnd.append(torch.LongTensor([1]))
segments_all_rnd.append(segments_random)
hist_len_all_rnd.append(torch.LongTensor([len(context_random)-1]))
tokens_all.append(torch.cat(tokens_all_rnd,0).unsqueeze(0))
mask_all.append(torch.cat(mask_all_rnd,0).unsqueeze(0))
question_limits_all.extend(question_limits_all_rnd)
segments_all.append(torch.cat(segments_all_rnd, 0).unsqueeze(0))
sep_indices_all.append(torch.cat(sep_indices_all_rnd, 0).unsqueeze(0))
next_labels_all.append(torch.cat(next_labels_all_rnd, 0).unsqueeze(0))
hist_len_all.append(torch.cat(hist_len_all_rnd,0).unsqueeze(0))
assert len(neg_rounds) == 1
question_graph_neg = question_graphs[neg_rounds[0]]
question_edge_index_neg = []
question_edge_attribute_neg = []
for edge_idx, edge_attr in question_graph_neg:
question_edge_index_neg.append(edge_idx)
edge_attr_one_hot = np.zeros((len(self.parse_vocab) + 1,), dtype=np.float32)
edge_attr_one_hot[self.parse_vocab.get(edge_attr, len(self.parse_vocab))] = 1.0
question_edge_attribute_neg.append(edge_attr_one_hot)
question_edge_index_neg = np.array(question_edge_index_neg, dtype=np.float64)
question_edge_attribute_neg = np.stack(question_edge_attribute_neg, axis=0)
question_edge_indices_all.append(
torch.from_numpy(question_edge_index_neg).t().long().contiguous()
)
question_edge_attributes_all.append(
torch.from_numpy(question_edge_attribute_neg)
)
history_edge_indices_all.append(
torch.tensor(history_edge_indices[neg_rounds[0]]).t().long().contiguous()
)
# Get the [SEP] tokens that will represent the history graph node features
hist_idx_neg = [i * 2 for i in range(neg_rounds[0] + 1)]
sep_indices_random = sep_indices_random.squeeze(0).numpy()
history_sep_indices_all.append(torch.from_numpy(sep_indices_random[hist_idx_neg]))
tokens_all = torch.cat(tokens_all, 0) # [2, num_pos, max_len]
question_limits_all = torch.stack(question_limits_all, 0) # [2, 2]
mask_all = torch.cat(mask_all,0)
segments_all = torch.cat(segments_all, 0)
sep_indices_all = torch.cat(sep_indices_all, 0)
next_labels_all = torch.cat(next_labels_all, 0)
hist_len_all = torch.cat(hist_len_all, 0)
input_mask_all = torch.LongTensor(input_mask) # [max_len]
item = {}
item['tokens'] = tokens_all
item['question_limits'] = question_limits_all
item['question_edge_indices'] = question_edge_indices_all
item['question_edge_attributes'] = question_edge_attributes_all
item['history_edge_indices'] = history_edge_indices_all
item['history_sep_indices'] = history_sep_indices_all
item['segments'] = segments_all
item['sep_indices'] = sep_indices_all
item['mask'] = mask_all
item['next_sentence_labels'] = next_labels_all
item['hist_len'] = hist_len_all
item['input_mask'] = input_mask_all
# get image features
if not self.config['dataloader_text_only']:
features, num_boxes, boxes, _ , image_target, image_edge_indexes, image_edge_attributes = self._image_features_reader[img_id]
features, spatials, image_mask, image_target, image_label = encode_image_input(features, num_boxes, boxes, image_target, max_regions=self._max_region_num)
else:
features = spatials = image_mask = image_target = image_label = torch.tensor([0])
elif self._split == 'val':
gt_relevance = None
gt_option_inds = []
options_all = []
# caption
sent = dialog['caption'].split(' ')
sentences = ['[CLS]']
tot_len = 1 # for the CLS token
sentence_map = [0] # for the CLS token
sentence_count = 0
speakers = [0]
tokenized_sent, sentences, tot_len, sentence_count, sentence_map, speakers = \
self.tokenize_utterance(sent, sentences, tot_len, sentence_count, sentence_map, speakers)
utterances = [[tokenized_sent]]
for rnd, utterance in enumerate(dialog['dialog']):
cur_rnd_utterance = utterances[-1].copy()
# question
sent = cur_questions[utterance['question']].split(' ')
tokenized_sent, sentences, tot_len, sentence_count, sentence_map, speakers = \
self.tokenize_utterance(sent, sentences, tot_len, sentence_count, sentence_map, speakers)
cur_rnd_utterance.append(tokenized_sent)
# current round
gt_option_ind = utterance['gt_index']
# first select gt option id, then choose the first num_options inds
option_inds = []
option_inds.append(gt_option_ind)
all_inds = list(range(100))
all_inds.remove(gt_option_ind)
all_inds = all_inds[:(num_options-1)]
option_inds.extend(all_inds)
gt_option_inds.append(0)
cur_rnd_options = []
answer_options = [utterance['answer_options'][k] for k in option_inds]
assert len(answer_options) == len(option_inds) == num_options
assert answer_options[0] == utterance['answer']
# for evaluation of all options and dense relevance
if self.visdial_data_val_dense:
if rnd == self.visdial_data_val_dense[index]['round_id'] - 1:
# only 1 round has gt_relevance for each example
if 'relevance' in self.visdial_data_val_dense[index]:
gt_relevance = torch.Tensor(self.visdial_data_val_dense[index]['relevance'])
else:
gt_relevance = torch.Tensor(self.visdial_data_val_dense[index]['gt_relevance'])
# shuffle based on new indices
gt_relevance = gt_relevance[torch.LongTensor(option_inds)]
else:
gt_relevance = -1
for answer_option in answer_options:
cur_rnd_cur_option = cur_rnd_utterance.copy()
cur_rnd_cur_option.append(self.tokenizer.convert_tokens_to_ids(cur_answers[answer_option].split(' ')))
cur_rnd_options.append(cur_rnd_cur_option)
# answer
sent = cur_answers[utterance['answer']].split(' ')
tokenized_sent, sentences, tot_len, sentence_count, sentence_map, speakers = \
self.tokenize_utterance(sent, sentences, tot_len, sentence_count, sentence_map, speakers)
cur_rnd_utterance.append(tokenized_sent)
utterances.append(cur_rnd_utterance)
options_all.append(cur_rnd_options)
# encode the input and create batch x 10 x 100 x max_len arrays (batch x num_rounds x num_options)
tokens_all = []
question_limits_all = []
mask_all = []
segments_all = []
sep_indices_all = []
hist_len_all = []
history_sep_indices_all = []
for rnd, cur_rnd_options in enumerate(options_all):
tokens_all_rnd = []
mask_all_rnd = []
segments_all_rnd = []
sep_indices_all_rnd = []
hist_len_all_rnd = []
for j, cur_rnd_option in enumerate(cur_rnd_options):
cur_rnd_option, start_segment = self.pruneRounds(cur_rnd_option, self.config['visdial_tot_rounds'])
if rnd == len(options_all) - 1 and j == 0: # gt dialog
tokens, segments, sep_indices, mask, input_mask, start_question, end_question = encode_input_with_mask(cur_rnd_option, start_segment, self.CLS,
self.SEP, self.MASK, max_seq_len=MAX_SEQ_LEN, mask_prob=0)
else:
tokens, segments, sep_indices, mask, start_question, end_question = encode_input(cur_rnd_option, start_segment,self.CLS,
self.SEP, self.MASK ,max_seq_len=MAX_SEQ_LEN, mask_prob=0)
tokens_all_rnd.append(tokens)
mask_all_rnd.append(mask)
segments_all_rnd.append(segments)
sep_indices_all_rnd.append(sep_indices)
hist_len_all_rnd.append(torch.LongTensor([len(cur_rnd_option)-1]))
question_limits_all.append(torch.tensor([start_question, end_question]).unsqueeze(0).repeat(100, 1))
tokens_all.append(torch.cat(tokens_all_rnd,0).unsqueeze(0))
mask_all.append(torch.cat(mask_all_rnd,0).unsqueeze(0))
segments_all.append(torch.cat(segments_all_rnd,0).unsqueeze(0))
sep_indices_all.append(torch.cat(sep_indices_all_rnd,0).unsqueeze(0))
hist_len_all.append(torch.cat(hist_len_all_rnd,0).unsqueeze(0))
# Get the [SEP] tokens that will represent the history graph node features
# It will be the same for all answer candidates as the history does not change
# for each answer
hist_idx = [i * 2 for i in range(rnd + 1)]
history_sep_indices_all.extend(sep_indices.squeeze(0)[hist_idx].contiguous() for _ in range(100))
tokens_all = torch.cat(tokens_all, 0) # [10, 100, max_len]
mask_all = torch.cat(mask_all, 0)
segments_all = torch.cat(segments_all, 0)
sep_indices_all = torch.cat(sep_indices_all, 0)
hist_len_all = torch.cat(hist_len_all, 0)
input_mask_all = torch.LongTensor(input_mask) # [max_len]
# load graph data
question_limits_all = torch.stack(question_limits_all, 0) # [10, 100, 2]
question_graphs = pickle.load(
open(os.path.join(ques_adj_matrices_dir, f'{graph_idx}.pkl'), 'rb')
)
question_edge_indices_all = [] # [10, N] we do not repeat it 100 times here
question_edge_attributes_all = [] # [10, N] we do not repeat it 100 times here
for q_graph_round in question_graphs:
question_edge_index = []
question_edge_attribute = []
for edge_index, edge_attr in q_graph_round:
question_edge_index.append(edge_index)
edge_attr_one_hot = np.zeros((len(self.parse_vocab) + 1,), dtype=np.float32)
edge_attr_one_hot[self.parse_vocab.get(edge_attr, len(self.parse_vocab))] = 1.0
question_edge_attribute.append(edge_attr_one_hot)
question_edge_index = np.array(question_edge_index, dtype=np.float64)
question_edge_attribute = np.stack(question_edge_attribute, axis=0)
question_edge_indices_all.extend(
[torch.from_numpy(question_edge_index).t().long().contiguous() for _ in range(100)])
question_edge_attributes_all.extend(
[torch.from_numpy(question_edge_attribute).contiguous() for _ in range(100)])
_history_edge_incides_all = pickle.load(
open(os.path.join(hist_adj_matrices_dir, f'{graph_idx}.pkl'), 'rb')
)
history_edge_incides_all = []
for hist_edge_indices_rnd in _history_edge_incides_all:
history_edge_incides_all.extend(
[torch.tensor(hist_edge_indices_rnd).t().long().contiguous() for _ in range(100)]
)
item = {}
item['tokens'] = tokens_all
item['segments'] = segments_all
item['sep_indices'] = sep_indices_all
item['mask'] = mask_all
item['hist_len'] = hist_len_all
item['input_mask'] = input_mask_all
item['gt_option_inds'] = torch.LongTensor(gt_option_inds)
# return dense annotation data as well
if self.visdial_data_val_dense:
item['round_id'] = torch.LongTensor([self.visdial_data_val_dense[index]['round_id']])
item['gt_relevance'] = gt_relevance
item['question_limits'] = question_limits_all
item['question_edge_indices'] = question_edge_indices_all
item['question_edge_attributes'] = question_edge_attributes_all
item['history_edge_indices'] = history_edge_incides_all
item['history_sep_indices'] = history_sep_indices_all
# get image features
if not self.config['dataloader_text_only']:
features, num_boxes, boxes, _ , image_target, image_edge_indexes, image_edge_attributes = self._image_features_reader[img_id]
features, spatials, image_mask, image_target, image_label = encode_image_input(
features, num_boxes, boxes, image_target, max_regions=self._max_region_num, mask_prob=0)
else:
features = spatials = image_mask = image_target = image_label = torch.tensor([0])
elif self.split == 'test':
assert num_options == 100
cur_rnd_utterance = [self.tokenizer.convert_tokens_to_ids(dialog['caption'].split(' '))]
options_all = []
for rnd,utterance in enumerate(dialog['dialog']):
cur_rnd_utterance.append(self.tokenizer.convert_tokens_to_ids(cur_questions[utterance['question']].split(' ')))
if rnd != len(dialog['dialog'])-1:
cur_rnd_utterance.append(self.tokenizer.convert_tokens_to_ids(cur_answers[utterance['answer']].split(' ')))
for answer_option in dialog['dialog'][-1]['answer_options']:
cur_option = cur_rnd_utterance.copy()
cur_option.append(self.tokenizer.convert_tokens_to_ids(cur_answers[answer_option].split(' ')))
options_all.append(cur_option)
tokens_all = []
mask_all = []
segments_all = []
sep_indices_all = []
hist_len_all = []
for j, option in enumerate(options_all):
option, start_segment = self.pruneRounds(option, self.config['visdial_tot_rounds'])
tokens, segments, sep_indices, mask = encode_input(option, start_segment ,self.CLS,
self.SEP, self.MASK ,max_seq_len=MAX_SEQ_LEN, mask_prob=0)
tokens_all.append(tokens)
mask_all.append(mask)
segments_all.append(segments)
sep_indices_all.append(sep_indices)
hist_len_all.append(torch.LongTensor([len(option)-1]))
tokens_all = torch.cat(tokens_all,0)
mask_all = torch.cat(mask_all,0)
segments_all = torch.cat(segments_all, 0)
sep_indices_all = torch.cat(sep_indices_all, 0)
hist_len_all = torch.cat(hist_len_all,0)
hist_idx = [i*2 for i in range(len(dialog['dialog']))]
history_sep_indices_all = [sep_indices.squeeze(0)[hist_idx].contiguous() for _ in range(num_options)]
with open(os.path.join(ques_adj_matrices_dir, f'{graph_idx}.pkl'), 'rb') as f:
question_graphs = pickle.load(f)
q_graph_last = question_graphs[-1]
question_edge_index = []
question_edge_attribute = []
for edge_index, edge_attr in q_graph_last:
question_edge_index.append(edge_index)
edge_attr_one_hot = np.zeros((len(self.parse_vocab) + 1,), dtype=np.float32)
edge_attr_one_hot[self.parse_vocab.get(edge_attr, len(self.parse_vocab))] = 1.0
question_edge_attribute.append(edge_attr_one_hot)
question_edge_index = np.array(question_edge_index, dtype=np.float64)
question_edge_attribute = np.stack(question_edge_attribute, axis=0)
question_edge_indices_all = [torch.from_numpy(question_edge_index).t().long().contiguous() for _ in range(num_options)]
question_edge_attributes_all = [torch.from_numpy(question_edge_attribute).contiguous() for _ in range(num_options)]
with open(os.path.join(hist_adj_matrices_dir, f'{graph_idx}.pkl'), 'rb') as f:
_history_edge_incides_all = pickle.load(f)
_history_edge_incides_last = _history_edge_incides_all[-1]
history_edge_index_all = [torch.tensor(_history_edge_incides_last).t().long().contiguous() for _ in range(num_options)]
if self.config['stack_gr_data']:
question_edge_indices_all = torch.stack(question_edge_indices_all, dim=0)
question_edge_attributes_all = torch.stack(question_edge_attributes_all, dim=0)
history_edge_index_all = torch.stack(history_edge_index_all, dim=0)
history_sep_indices_all = torch.stack(history_sep_indices_all, dim=0)
len_question_gr = torch.tensor(question_edge_indices_all.size(-1)).unsqueeze(0).repeat(num_options, 1)
len_history_gr = torch.tensor(history_edge_index_all.size(-1)).repeat(num_options, 1)
len_history_sep = torch.tensor(history_sep_indices_all.size(-1)).repeat(num_options, 1)
item = {}
item['tokens'] = tokens_all.unsqueeze(0)
item['segments'] = segments_all.unsqueeze(0)
item['sep_indices'] = sep_indices_all.unsqueeze(0)
item['mask'] = mask_all.unsqueeze(0)
item['hist_len'] = hist_len_all.unsqueeze(0)
item['question_limits'] = question_limits_all
item['question_edge_indices'] = question_edge_indices_all
item['question_edge_attributes'] = question_edge_attributes_all
item['history_edge_indices'] = history_edge_index_all
item['history_sep_indices'] = history_sep_indices_all
if self.config['stack_gr_data']:
item['len_question_gr'] = len_question_gr
item['len_history_gr'] = len_history_gr
item['len_history_sep'] = len_history_sep
item['round_id'] = torch.LongTensor([dialog['round_id']])
# get image features
if not self.config['dataloader_text_only']:
features, num_boxes, boxes, _ , image_target, image_edge_indexes, image_edge_attributes = self._image_features_reader[img_id]
features, spatials, image_mask, image_target, image_label = encode_image_input(features, num_boxes, boxes, image_target, max_regions=self._max_region_num, mask_prob=0)
else:
features = spatials = image_mask = image_target = image_label = torch.tensor([0])
item['image_feat'] = features
item['image_loc'] = spatials
item['image_mask'] = image_mask
item['image_target'] = image_target
item['image_label'] = image_label
item['image_id'] = torch.LongTensor([img_id])
if self._split == 'train':
# cheap hack to duplicate the image graph data for the positive and negative examples
item['image_edge_indices'] = [torch.from_numpy(image_edge_indexes).long(), torch.from_numpy(image_edge_indexes).long()]
item['image_edge_attributes'] = [torch.from_numpy(image_edge_attributes), torch.from_numpy(image_edge_attributes)]
elif self._split == 'val':
# cheap hack to replicate the image graph data for the 10 rounds x 100 answer options
item['image_edge_indices'] = [torch.from_numpy(image_edge_indexes).contiguous().long() for _ in range(1000)]
item['image_edge_attributes'] = [torch.from_numpy(image_edge_attributes).contiguous() for _ in range(1000)]
else:
# cheap hack to replicate the image graph data for the 100 answer options
item['image_edge_indices'] = [torch.from_numpy(image_edge_indexes).contiguous().long() for _ in range(100)]
item['image_edge_attributes'] = [torch.from_numpy(image_edge_attributes).contiguous() for _ in range(100)]
if self.config['stack_gr_data']:
item['image_edge_indices'] = torch.stack(item['image_edge_indices'], dim=0)
item['image_edge_attributes'] = torch.stack(item['image_edge_attributes'], dim=0)
len_image_gr = torch.tensor(item['image_edge_indices'].size(-1)).unsqueeze(0).repeat(num_options)
item['len_image_gr'] = len_image_gr
return item


@@ -0,0 +1,313 @@
import torch
import json
import os
import time
import numpy as np
import random
from tqdm import tqdm
import copy
import pyhocon
import glog as log
from collections import OrderedDict
import argparse
import pickle
import torch.utils.data as tud
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from utils.data_utils import encode_input, encode_image_input
from dataloader.dataloader_base import DatasetBase