add code

parent 58fa893d4d
commit c8f2babd76

8 changed files with 300 additions and 7 deletions
2  .gitignore  vendored  Normal file
@@ -0,0 +1,2 @@
.DS_Store
*.pyc
18  Code/dataset_new.py  Normal file
@@ -0,0 +1,18 @@
import torch
from torch.utils.data import Dataset
import numpy as np


class ImagesWithSaliency(Dataset):
    """Dataset backed by a pickled .npy object array; each record is indexable,
    with the image at index 0 and the ground-truth heatmap at index 3
    (see padding_fn in tokenizer_bert.py for the full field order)."""

    def __init__(self, npy_path, dtype=None):
        self.dtype = dtype
        self.datas = np.load(npy_path, allow_pickle=True)

    def __len__(self):
        return len(self.datas)

    def __getitem__(self, idx):
        # Optionally cast the image (index 0) and the heatmap (index 3) to the requested dtype.
        if self.dtype:
            self.datas[idx][0] = self.datas[idx][0].type(self.dtype)
            self.datas[idx][3] = self.datas[idx][3].type(self.dtype)

        return self.datas[idx]
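Not part of the commit, but useful for orientation: a minimal sketch of the record layout ImagesWithSaliency expects. The five-field order (image, question, fixation map, heatmap, file name) is inferred from the dtype casts above and from padding_fn in tokenizer_bert.py further down; the tensor shapes and the file name are illustrative assumptions only.

import numpy as np
import torch

from dataset_new import ImagesWithSaliency

# One toy record; field order and shapes are assumptions (see note above).
record = [
    torch.rand(3, 224, 224),       # [0] input image, cast by `dtype` in __getitem__
    "What is the highest bar?",    # [1] question text
    torch.zeros(1, 128, 128),      # [2] binary fixation map
    torch.rand(1, 128, 128),       # [3] ground-truth saliency heatmap, also cast by `dtype`
    "example.png",                 # [4] file name used when saving predictions
]

records = np.empty(1, dtype=object)   # object array so np.load(..., allow_pickle=True) round-trips
records[0] = record
np.save("toy_test.npy", records)

ds = ImagesWithSaliency("toy_test.npy", dtype=torch.float32)
img, q, fix, hm, name = ds[0]
print(len(ds), img.shape, q)          # 1 torch.Size([3, 224, 224]) What is the highest bar?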
6  Code/env.py  Normal file
@@ -0,0 +1,6 @@
import os

# Redirect the PyTorch hub and Hugging Face caches to project storage.
# evaluation.py runs `from env import *` before importing transformers,
# so these are set before any pretrained weights are downloaded.
os.environ['TORCH_HOME'] = '/projects/wang/.cache/torch'
os.environ['TRANSFORMERS_CACHE'] = '/projects/wang/.cache'

my_variable = os.environ.get('TORCH_HOME')  # read back (unused elsewhere in this commit)
110  Code/evaluation.py  Normal file
@@ -0,0 +1,110 @@
import torch
from torch.utils.data import DataLoader
from env import *

import argparse
from dataset_new import ImagesWithSaliency
from torchvision.utils import save_image
from transformers import SwinModel
from pathlib import Path


def evaluation(Model: str, ckpt: str, device, batch_size: int):
    eps = 1e-10

    # Pick the language backbone; each option brings its own SalFormer variant and collate function.
    if Model == 'llama':
        from model_llama import SalFormer
        from transformers import LlamaModel
        from tokenizer_llama import padding_fn
        # llm = LlamaModel.from_pretrained("Enoch/llama-7b-hf", low_cpu_mem_usage=True)
        llm = LlamaModel.from_pretrained("daryl149/Llama-2-7b-chat-hf", low_cpu_mem_usage=True)
        neuron_n = 4096
        print("llama loaded")
    elif Model == 'bloom':
        from model_llama import SalFormer
        from transformers import BloomModel
        from tokenizer_bloom import padding_fn
        llm = BloomModel.from_pretrained("bigscience/bloom-3b")
        neuron_n = 2560
        print('BloomModel loaded')
    elif Model == 'bert':
        from model_swin import SalFormer
        from transformers import BertModel
        from tokenizer_bert import padding_fn
        llm = BertModel.from_pretrained("bert-base-uncased")
        print('BertModel loaded')
    else:
        print('model not available, possibilities: llama, bloom, bert')
        return

    test_set = ImagesWithSaliency("data/test.npy")

    Path('./eval_results').mkdir(parents=True, exist_ok=True)

    # Vision backbone: Swin-Tiny (alternatives left as comments).
    # vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
    vit = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
    # vit = timm.create_model('xception41p.ra3_in1k', pretrained=True)

    if Model == 'bert':
        model = SalFormer(vit, llm).to(device)
    else:
        model = SalFormer(vit, llm, neuron_n=neuron_n).to(device)

    checkpoint = torch.load(ckpt)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    test_dataloader = DataLoader(test_set, batch_size=batch_size, shuffle=False, collate_fn=padding_fn, num_workers=8)
    kl_loss = torch.nn.KLDivLoss(reduction="batchmean", log_target=True)

    test_kl, test_cc, test_nss = 0, 0, 0
    for batch, (img, input_ids, fix, hm, name) in enumerate(test_dataloader):
        img = img.to(device)
        input_ids = input_ids.to(device)  # full tokenizer output (ids + attention mask)
        fix = fix.to(device)
        hm = hm.to(device)

        y = model(img, input_ids)

        # Normalise prediction and ground-truth heatmap to probability distributions for the KL term.
        y_sum = y.view(y.shape[0], -1).sum(1, keepdim=True)
        y_distribution = y / (y_sum[:, :, None, None] + eps)

        hm_sum = hm.view(y.shape[0], -1).sum(1, keepdim=True)
        hm_distribution = hm / (hm_sum[:, :, None, None] + eps)
        hm_distribution = hm_distribution + eps
        hm_distribution = hm_distribution / (1 + eps)

        # NSS: mean standardised prediction value at fixated pixels.
        if fix.sum() != 0:
            normal_y = (y - y.mean()) / y.std()
            nss = torch.sum(normal_y * fix) / fix.sum()
        else:
            nss = torch.Tensor([0.0]).to(device)
        kl = kl_loss(torch.log(y_distribution), torch.log(hm_distribution))

        # CC: Pearson correlation between prediction and ground-truth heatmap.
        vy = y - torch.mean(y)
        vhm = hm - torch.mean(hm)

        if (torch.sqrt(torch.sum(vy ** 2)) * torch.sqrt(torch.sum(vhm ** 2))) != 0:
            cc = torch.sum(vy * vhm) / (torch.sqrt(torch.sum(vy ** 2)) * torch.sqrt(torch.sum(vhm ** 2)))
        else:
            cc = torch.Tensor([0.0]).to(device)

        test_kl += kl.item() / len(test_dataloader)
        test_cc += cc.item() / len(test_dataloader)
        test_nss += nss.item() / len(test_dataloader)

        # Save every predicted saliency map under its source image name.
        for i in range(0, y.shape[0]):
            save_image(y[i], f"./eval_results/{name[i]}")

    print("kl:", test_kl, "cc", test_cc, "nss", test_nss)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default='bert')
    parser.add_argument("--device", type=str, default='cuda')
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--ckpt", type=str, default='./ckpt/model_bert_freeze_10kl_5cc_2nss.tar')
    args = vars(parser.parse_args())

    evaluation(Model=args['model'], device=args['device'], ckpt=args['ckpt'], batch_size=args['batch_size'])
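In formulas, the three scores accumulated by the loop above are, with y the predicted saliency map, h the ground-truth heatmap (both normalised to sum to one for the KL term), f the binary fixation map, sigma_y the standard deviation of y, and B the batch size:

\[
\mathrm{KL} = \frac{1}{B}\sum_i h_i \bigl(\log h_i - \log y_i\bigr), \qquad
\mathrm{CC} = \frac{\sum_i (y_i-\bar y)(h_i-\bar h)}{\sqrt{\sum_i (y_i-\bar y)^2}\,\sqrt{\sum_i (h_i-\bar h)^2}}, \qquad
\mathrm{NSS} = \frac{1}{\sum_i f_i}\sum_i \frac{y_i-\bar y}{\sigma_y}\, f_i
\]

Each per-batch value is divided by len(test_dataloader), so the printed numbers are averages over all test batches.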
1  Code/evaluation.sh  Executable file
@@ -0,0 +1 @@
python evaluation.py --model 'bert' --ckpt './VisSalFormer_weights.tar' --device 'cuda'
116  Code/model_swin.py  Normal file
@@ -0,0 +1,116 @@
import torch


class SalFormer(torch.nn.Module):
    def __init__(self, vision_encoder, bert):
        """
        Question-driven saliency model: fuses Swin image tokens with
        language-model question features and decodes the fused representation
        into a single-channel saliency map.
        """
        super().__init__()

        self.vit = vision_encoder
        self.feature_dim = 768
        self.bert = bert

        # Two-layer MLP heads projecting the vision and text token features.
        self.vision_head = torch.nn.Sequential(
            torch.nn.Linear(self.feature_dim, self.feature_dim),
            torch.nn.GELU(),
            torch.nn.Linear(self.feature_dim, self.feature_dim),
            torch.nn.GELU()
        )
        self.text_head = torch.nn.Sequential(
            torch.nn.Linear(self.feature_dim, self.feature_dim),
            torch.nn.GELU(),
            torch.nn.Linear(self.feature_dim, self.feature_dim),
            torch.nn.GELU()
        )

        self.cross_attention = torch.nn.MultiheadAttention(self.feature_dim, 16, kdim=self.feature_dim, vdim=self.feature_dim, batch_first=True)
        self.cross_attention1 = torch.nn.MultiheadAttention(self.feature_dim, 16, kdim=self.feature_dim, vdim=self.feature_dim, batch_first=True)

        self.ln1 = torch.nn.LayerNorm(self.feature_dim)
        self.ln2 = torch.nn.LayerNorm(self.feature_dim)

        # Note: attribute name spelling kept unchanged; the state_dict keys of saved checkpoints depend on it.
        self.self_attetion = torch.nn.MultiheadAttention(self.feature_dim, 16, batch_first=True)

        # 10 learned query slots that pool the question tokens, plus positional
        # embeddings for the 49 Swin image tokens and the 10 pooled text tokens.
        self.text_feature_query = torch.nn.Parameter(torch.randn(10, self.feature_dim).unsqueeze(0) / 2)
        self.img_positional_embedding = torch.nn.Parameter(torch.zeros(49, self.feature_dim))
        self.text_positional_embedding = torch.nn.Parameter(torch.zeros(10, self.feature_dim))

        self.dense1 = torch.nn.Linear(self.feature_dim, self.feature_dim)
        self.relu1 = torch.nn.ReLU()

        # Convolutional decoder: 7x7 feature grid -> 128x128 single-channel saliency map.
        self.decoder = torch.nn.Sequential(
            torch.nn.Conv2d(self.feature_dim, 512, 3),
            torch.nn.BatchNorm2d(512),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=0.1),
            torch.nn.Conv2d(512, 512, 3),
            torch.nn.BatchNorm2d(512),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=0.1),
            torch.nn.Upsample((16, 16), mode='bilinear'),
            torch.nn.Conv2d(512, 256, 3),
            torch.nn.BatchNorm2d(256),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=0.1),
            torch.nn.Conv2d(256, 256, 3),
            torch.nn.BatchNorm2d(256),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=0.1),
            torch.nn.Upsample((32, 32), mode='bilinear'),
            torch.nn.Conv2d(256, 128, 3),
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=0.1),
            torch.nn.Conv2d(128, 128, 3),
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=0.1),
            torch.nn.Upsample((130, 130), mode='bilinear'),
            torch.nn.Conv2d(128, 1, 3),
            torch.nn.BatchNorm2d(1),
            torch.nn.Sigmoid(),
        )

        self.vit.eval()
        self.bert.eval()
        self.train(True)

    # def eval(self):
    #     super().eval()
    #     self.vit.eval()
    #     self.bert.eval()

    # def train(self, mode=True):
    #     super().train(mode)
    #     self.vit.train(mode)
    #     self.bert.train(mode)

    def forward(self, img, q_inputs):
        img_features = self.vit.forward(img, return_dict=True)["last_hidden_state"]
        with torch.no_grad():
            text_features = self.bert(**q_inputs)["last_hidden_state"]
        # text_features = torch.unsqueeze(bert_output["last_hidden_state"][:,0,:], 1)

        # Pool the variable-length question into 10 fixed query slots via cross-attention.
        text_features = self.cross_attention.forward(self.text_feature_query.repeat([text_features.shape[0], 1, 1]), text_features, text_features, need_weights=False)[0]

        # Concatenate vision and text tokens, then fuse them with self-attention + residual + LayerNorm.
        fused_features = torch.concat((self.vision_head(img_features) + self.img_positional_embedding, self.text_head(text_features) + self.text_positional_embedding), 1)
        att_fused_features = self.self_attetion.forward(fused_features, fused_features, fused_features, need_weights=False)[0]
        fused_features = fused_features + att_fused_features
        fused_features = self.ln1(fused_features)

        # Let the image tokens attend to the fused sequence, with a residual connection.
        features = self.cross_attention1.forward(img_features, fused_features, fused_features, need_weights=False)[0]
        features = img_features + features
        features = self.ln2(features)

        features = self.dense1(features)
        latent_features = self.relu1(features)

        # Reshape the 49 image tokens back into a 7x7 grid and decode to a saliency map.
        latent_features = latent_features.permute(0, 2, 1)
        out = torch.reshape(latent_features, (features.shape[0], self.feature_dim, 7, 7))
        out = self.decoder(out)

        return out
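A minimal smoke test of the forward pass, not included in the commit: it assumes the same Hugging Face checkpoints that evaluation.py loads and uses random inputs, so it only checks shapes (the fusion and decoder weights are untrained here; evaluation.py shows how to load the released checkpoint). The expected output shape follows from the decoder's final Upsample((130, 130)) and Conv2d(128, 1, 3) pair.

import torch
from transformers import AutoTokenizer, BertModel, SwinModel

from model_swin import SalFormer

vit = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
bert = BertModel.from_pretrained("bert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

model = SalFormer(vit, bert).eval()

img = torch.randn(2, 3, 224, 224)                        # two dummy 224x224 RGB charts
questions = ["What is the highest bar?", "Which line peaks first?"]   # hypothetical questions
q_inputs = tokenizer(questions, return_tensors="pt", padding=True)

with torch.no_grad():
    saliency = model(img, q_inputs)

print(saliency.shape)                                    # torch.Size([2, 1, 128, 128])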
15  Code/tokenizer_bert.py  Normal file
@@ -0,0 +1,15 @@
import torch
from transformers import AutoTokenizer

# tokenizer = AutoTokenizer.from_pretrained("roberta-base")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
print('bert-base-uncased tokenizer loaded')


def padding_fn(data):
    """collate_fn for the DataLoader: stack images, fixation maps and heatmaps,
    and tokenise the batch of questions with padding to the longest one."""
    img, q, fix, hm, name = zip(*data)

    input_ids = tokenizer(q, return_tensors="pt", padding=True)

    return torch.stack(img), input_ids, torch.stack(fix), torch.stack(hm), name
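Not spelled out in the commit itself: dataset_new.py and tokenizer_bert.py are meant to be combined through a DataLoader, exactly as evaluation.py does. A short sketch (the data/test.npy path is the one evaluation.py uses and is assumed to exist):

import torch
from torch.utils.data import DataLoader

from dataset_new import ImagesWithSaliency
from tokenizer_bert import padding_fn

test_set = ImagesWithSaliency("data/test.npy", dtype=torch.float32)   # path as in evaluation.py
loader = DataLoader(test_set, batch_size=4, shuffle=False, collate_fn=padding_fn)

img, input_ids, fix, hm, name = next(iter(loader))
# img/fix/hm are stacked tensors, input_ids is the tokenizer's BatchEncoding, name a tuple of file names
print(img.shape, input_ids["input_ids"].shape, fix.shape, hm.shape, len(name))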
39  README.md
@@ -12,13 +12,38 @@ $Root Directory
 │
 │─ README.md —— this file
 │
-│─ VisSalFormer —— Source code of the network to predict question-driven saliency
+│─ Code —— Source code of the VisSalFormer model to predict question-driven saliency
 │ │
-│ │─ coming soon
-│
-└─ SalChartQA —— The dataset
+│ │─ environment.yml —— conda environment
+│ │
+│ │─ env.py —— sets the $TORCH_HOME and $TRANSFORMERS_CACHE environment variables
+│ │
+│ │─ dataset_new.py —— dataloader for SalChartQA
+│ │
+│ │─ evaluation.py —— evaluation script to load VisSalFormer weights and make predictions
+│ │
+│ │─ evaluation.sh —— bash script to run evaluation.py
+│ │
+│ │─ model_swin.py —— definition of the VisSalFormer model
+│ │
+│ │─ tokenizer_bert.py —— BERT tokenizer
+│ │
+│ └─ VisSalFormer_weights.tar —— weights of VisSalFormer
+│
+└─ SalChartQA.zip —— The SalChartQA dataset
 │
-│─ coming soon
+│─ fixationByVis —— BubbleView data (mouse clicks) of AMT workers
+│
+│─ image_questions.json —— visualisation-question pairs
+│
+│─ raw_img —— original visualisations from the ChartQA dataset
+│
+│─ saliency_all —— saliency maps from all AMT workers
+│
+│─ saliency_ans —— saliency maps aggregated separately for AMT workers who answered a question correctly or wrongly
+│
+└─ unified_approved.csv —— responses from AMT workers
+
 ```
 
 If you think our work is useful to you, please consider citing our paper as:
@@ -28,9 +53,9 @@ If you think our work is useful to you, please consider citing our paper as:
 title = {SalChartQA: Question-driven Saliency on Information Visualisations},
 author = {Wang, Yao and Wang, Weitian and Abdelhafez, Abdullah and Elfares, Mayar and Hu, Zhiming and B{\^a}ce, Mihai and Bulling, Andreas},
 year = {2024},
-pages = {1--20},
+pages = {1--14},
 booktitle = {Proc. ACM SIGCHI Conference on Human Factors in Computing Systems (CHI)},
-doi = {}
+doi = {10.1145/3613904.3642942}
 }
 
 ```