From c8f2babd763515f99843082d97b6a60428f15eec Mon Sep 17 00:00:00 2001 From: Yao Wang Date: Mon, 22 Jan 2024 21:02:16 +0800 Subject: [PATCH] add code --- .gitignore | 2 + Code/dataset_new.py | 18 +++++++ Code/env.py | 6 +++ Code/evaluation.py | 110 ++++++++++++++++++++++++++++++++++++++ Code/evaluation.sh | 1 + Code/model_swin.py | 116 +++++++++++++++++++++++++++++++++++++++++ Code/tokenizer_bert.py | 15 ++++++ README.md | 39 +++++++++++--- 8 files changed, 300 insertions(+), 7 deletions(-) create mode 100644 .gitignore create mode 100644 Code/dataset_new.py create mode 100644 Code/env.py create mode 100644 Code/evaluation.py create mode 100755 Code/evaluation.sh create mode 100644 Code/model_swin.py create mode 100644 Code/tokenizer_bert.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ee8ad4a --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +.DS_STORE +.pyc diff --git a/Code/dataset_new.py b/Code/dataset_new.py new file mode 100644 index 0000000..fa32b2e --- /dev/null +++ b/Code/dataset_new.py @@ -0,0 +1,18 @@ +import torch +from torch.utils.data import Dataset +import numpy as np + +class ImagesWithSaliency(Dataset): + def __init__(self, npy_path, dtype=None): + self.dtype = dtype + self.datas = np.load(npy_path, allow_pickle = True) + + def __len__(self): + return len(self.datas) + + def __getitem__(self, idx): + if self.dtype: + self.datas[idx][0] = self.datas[idx][0].type(self.dtype) + self.datas[idx][3] = self.datas[idx][3].type(self.dtype) + + return self.datas[idx] diff --git a/Code/env.py b/Code/env.py new file mode 100644 index 0000000..8962cd8 --- /dev/null +++ b/Code/env.py @@ -0,0 +1,6 @@ + +import os +os.environ['TORCH_HOME'] = '/projects/wang/.cache/torch' +os.environ['TRANSFORMERS_CACHE'] = '/projects/wang/.cache' + +my_variable = os.environ.get('TORCH_HOME') \ No newline at end of file diff --git a/Code/evaluation.py b/Code/evaluation.py new file mode 100644 index 0000000..091f2e8 --- /dev/null +++ b/Code/evaluation.py @@ -0,0 +1,110 @@ +import torch +from torch.utils.data import DataLoader +from env import * + +import argparse +from dataset_new import ImagesWithSaliency +from torchvision.utils import save_image +from transformers import SwinModel +from pathlib import Path + +def evaluation(Model:str, ckpt: str, device, batch_size:int): + eps=1e-10 + + if Model == 'llama': + from model_llama import SalFormer + from transformers import LlamaModel + from tokenizer_llama import padding_fn + # llm = LlamaModel.from_pretrained("Enoch/llama-7b-hf", low_cpu_mem_usage=True) + llm = LlamaModel.from_pretrained("daryl149/Llama-2-7b-chat-hf", low_cpu_mem_usage=True) + neuron_n = 4096 + print("llama loaded") + elif Model == 'bloom': + from model_llama import SalFormer + from transformers import BloomModel + from tokenizer_bloom import padding_fn + llm = BloomModel.from_pretrained("bigscience/bloom-3b") + neuron_n = 2560 + print('BloomModel loaded') + elif Model == 'bert': + from model_swin import SalFormer + from transformers import BertModel + from tokenizer_bert import padding_fn + llm = BertModel.from_pretrained("bert-base-uncased") + print('BertModel loaded') + else: + print('model not available, possiblilities: llama, bloom, bert') + return + + test_set = ImagesWithSaliency("data/test.npy") + + Path('./eval_results').mkdir(parents=True, exist_ok=True) + + # vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k") + vit = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224") + # vit = timm.create_model('xception41p.ra3_in1k', pretrained=True) + + if Model == 'bert': + model = SalFormer(vit, llm).to(device) + else: + model = SalFormer(vit, llm, neuron_n = neuron_n).to(device) + + checkpoint = torch.load(ckpt) + model.load_state_dict(checkpoint['model_state_dict']) + model.eval() + + + test_dataloader = DataLoader(test_set, batch_size=batch_size, shuffle=False, collate_fn=padding_fn, num_workers=8) + kl_loss = torch.nn.KLDivLoss(reduction="batchmean", log_target=True) + + test_kl, test_cc, test_nss = 0,0,0 + for batch, (img, input_ids, fix, hm, name) in enumerate(test_dataloader): + img = img.to(device) + input_ids = input_ids.to(device) + fix = fix.to(device) + hm = hm.to(device) + + y = model(img, input_ids) + + y_sum = y.view(y.shape[0], -1).sum(1, keepdim=True) + y_distribution = y / (y_sum[:, :, None, None] + eps) + + hm_sum = hm.view(y.shape[0], -1).sum(1, keepdim=True) + hm_distribution = hm / (hm_sum[:, :, None, None] + eps) + hm_distribution = hm_distribution + eps + hm_distribution = hm_distribution / (1+eps) + + if fix.sum() != 0: + normal_y = (y-y.mean())/y.std() + nss = torch.sum(normal_y*fix)/fix.sum() + else: + nss = torch.Tensor([0.0]).to(device) + kl = kl_loss(torch.log(y_distribution), torch.log(hm_distribution)) + + vy = y - torch.mean(y) + vhm = hm - torch.mean(hm) + + if (torch.sqrt(torch.sum(vy ** 2)) * torch.sqrt(torch.sum(vhm ** 2))) != 0: + cc = torch.sum(vy * vhm) / (torch.sqrt(torch.sum(vy ** 2)) * torch.sqrt(torch.sum(vhm ** 2))) + else: + cc = torch.Tensor([0.0]).to(device) + + test_kl += kl.item()/len(test_dataloader) + test_cc += cc.item()/len(test_dataloader) + test_nss += nss.item()/len(test_dataloader) + + for i in range(0, y.shape[0]): + save_image(y[i], f"./eval_results/{name[i]}") + + print("kl:", test_kl, "cc", test_cc, "nss", test_nss) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str, default='bert') + parser.add_argument("--device", type=str, default='cuda') + parser.add_argument("--batch_size", type=int, default=16) + parser.add_argument("--ckpt", type=str, default='./ckpt/model_bert_freeze_10kl_5cc_2nss.tar') + args = vars(parser.parse_args()) + + evaluation(Model = args['model'], device = args['device'], ckpt = args['ckpt'], batch_size = args['batch_size']) diff --git a/Code/evaluation.sh b/Code/evaluation.sh new file mode 100755 index 0000000..5691af2 --- /dev/null +++ b/Code/evaluation.sh @@ -0,0 +1 @@ +python evaluation.py --model 'bert' --ckpt './VisSalFormer_weights.tar' --device 'cuda' \ No newline at end of file diff --git a/Code/model_swin.py b/Code/model_swin.py new file mode 100644 index 0000000..5d97147 --- /dev/null +++ b/Code/model_swin.py @@ -0,0 +1,116 @@ +import torch + +class SalFormer(torch.nn.Module): + def __init__(self, vision_encoder, bert): + """ + In the constructor we instantiate four parameters and assign them as + member parameters. + """ + super().__init__() + + self.vit = vision_encoder + self.feature_dim = 768 + self.bert = bert + + self.vision_head = torch.nn.Sequential( + torch.nn.Linear(self.feature_dim, self.feature_dim), + torch.nn.GELU(), + torch.nn.Linear(self.feature_dim, self.feature_dim), + torch.nn.GELU() + ) + + self.text_head = torch.nn.Sequential( + torch.nn.Linear(self.feature_dim, self.feature_dim), + torch.nn.GELU(), + torch.nn.Linear(self.feature_dim, self.feature_dim), + torch.nn.GELU() + ) + + self.cross_attention = torch.nn.MultiheadAttention(self.feature_dim, 16, kdim=self.feature_dim, vdim=self.feature_dim, batch_first=True) + self.cross_attention1 = torch.nn.MultiheadAttention(self.feature_dim, 16, kdim=self.feature_dim, vdim=self.feature_dim, batch_first=True) + + self.ln1 = torch.nn.LayerNorm(self.feature_dim) + self.ln2 = torch.nn.LayerNorm(self.feature_dim) + + self.self_attetion = torch.nn.MultiheadAttention(self.feature_dim, 16, batch_first=True) + + self.text_feature_query = torch.nn.Parameter(torch.randn(10, self.feature_dim).unsqueeze(0)/2) + self.img_positional_embedding = torch.nn.Parameter(torch.zeros(49, self.feature_dim)) + self.text_positional_embedding = torch.nn.Parameter(torch.zeros(10, self.feature_dim)) + + + self.dense1 = torch.nn.Linear(self.feature_dim, self.feature_dim) + self.relu1 = torch.nn.ReLU() + + self.decoder = torch.nn.Sequential( + torch.nn.Conv2d(self.feature_dim, 512, 3), + torch.nn.BatchNorm2d(512), + torch.nn.ReLU(), + torch.nn.Dropout(p=0.1), + torch.nn.Conv2d(512, 512, 3), + torch.nn.BatchNorm2d(512), + torch.nn.ReLU(), + torch.nn.Dropout(p=0.1), + torch.nn.Upsample((16,16), mode='bilinear'), + torch.nn.Conv2d(512, 256, 3), + torch.nn.BatchNorm2d(256), + torch.nn.ReLU(), + torch.nn.Dropout(p=0.1), + torch.nn.Conv2d(256, 256, 3), + torch.nn.BatchNorm2d(256), + torch.nn.ReLU(), + torch.nn.Dropout(p=0.1), + torch.nn.Upsample((32,32), mode='bilinear'), + torch.nn.Conv2d(256, 128, 3), + torch.nn.BatchNorm2d(128), + torch.nn.ReLU(), + torch.nn.Dropout(p=0.1), + torch.nn.Conv2d(128, 128, 3), + torch.nn.BatchNorm2d(128), + torch.nn.ReLU(), + torch.nn.Dropout(p=0.1), + torch.nn.Upsample((130,130), mode='bilinear'), + torch.nn.Conv2d(128, 1, 3), + torch.nn.BatchNorm2d(1), + torch.nn.Sigmoid(), + ) + + self.vit.eval() + self.bert.eval() + self.train(True) + + # def eval(self): + # super().eval() + # self.vit.eval() + # self.bert.eval() + + # def train(self, mode=True): + # super().train(mode) + # self.vit.train(mode) + # self.bert.train(mode) + + def forward(self, img, q_inputs): + + img_features = self.vit.forward(img, return_dict =True)["last_hidden_state"] + with torch.no_grad(): + text_features = self.bert(**q_inputs)["last_hidden_state"] + # text_features = torch.unsqueeze(bert_output["last_hidden_state"][:,0,:], 1) + text_features = self.cross_attention.forward(self.text_feature_query.repeat([text_features.shape[0], 1, 1]), text_features, text_features, need_weights=False)[0] + + fused_features = torch.concat((self.vision_head(img_features)+self.img_positional_embedding, self.text_head(text_features)+self.text_positional_embedding), 1) + att_fused_features = self.self_attetion.forward(fused_features, fused_features, fused_features, need_weights=False)[0] + fused_features = fused_features + att_fused_features + fused_features = self.ln1(fused_features) + + features = self.cross_attention1.forward(img_features, fused_features, fused_features, need_weights=False)[0] + features = img_features + features + features = self.ln2(features) + + features = self.dense1(features) + latent_features = self.relu1(features) + + latent_features = latent_features.permute(0,2,1) + out = torch.reshape(latent_features, (features.shape[0], self.feature_dim, 7, 7)) + out = self.decoder(out) + + return out diff --git a/Code/tokenizer_bert.py b/Code/tokenizer_bert.py new file mode 100644 index 0000000..c504176 --- /dev/null +++ b/Code/tokenizer_bert.py @@ -0,0 +1,15 @@ +import torch +from transformers import AutoTokenizer + +# tokenizer = AutoTokenizer.from_pretrained("roberta-base") +tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") +print('bert-base-uncased tokenizer loaded') + +def padding_fn(data): + img, q, fix, hm, name = zip(*data) + + input_ids = tokenizer(q, return_tensors="pt", padding=True) + + return torch.stack(img), input_ids, torch.stack(fix), torch.stack(hm), name + + diff --git a/README.md b/README.md index eebff5e..d16eaef 100644 --- a/README.md +++ b/README.md @@ -12,13 +12,38 @@ $Root Directory │ │─ README.md —— this file │ -|─ VisSalFormer —— Source code of the network to predict question-driven saliency +|─ Code —— Source code of the VisSalFormer model to predict question-driven saliency │ │ -│ │─ coming soon -│ -└─ SalChartQA —— The dataset +│ |─ environment.yml —— conda environment +│ │ +│ |─ env.py —— python envorinment $TORCH_HOME and $TRANSFORMERS_CACHE +│ │ +│ │─ dataset_new.py —— dataloader for SalChartQA +│ │ +│ │─ evaluation.py —— evaluation script to load VisSalFormer weights and make predictions +│ │ +│ │─ evaluation.sh —— bash script to run evaluation.py +│ │ +│ │─ model_swin.py —— definition of the VisSalFormer model +│ │ +│ │─ tokenizer_bert.py —— tokenizer of Bert +│ │ +│ └─ VisSalFormer_weights.tar —— weights of VisSalFormer +│ +└─ SalChartQA.zip —— The SalChartQA dataset │ - │─ coming soon + │─ fixationByVis —— BubbleView data (mouse clicks) of AMT workers + │ + │─ image_questions.json —— visualisation-question pairs + │ + │─ raw_img —— original visualisations from the ChartQA dataset + │ + │─ saliency_all —— saliency maps from all AMT workers + │ + │─ saliency_ans —— saliency maps aggretated by all AMT workers who either answered a question correctly or wrongly + │ + └─ unified_approved.csv —— responses from AMT workers + ``` If you think our work is useful to you, please consider citing our paper as: @@ -28,9 +53,9 @@ If you think our work is useful to you, please consider citing our paper as: title = {SalChartQA: Question-driven Saliency on Information Visualisations}, author = {Wang, Yao and Wang, Weitian and Abdelhafez, Abdullah and Elfares, Mayar and Hu, Zhiming and B{\^a}ce, Mihai and Bulling, Andreas}, year = {2024}, - pages = {1--20}, + pages = {1--14}, booktitle = {Proc. ACM SIGCHI Conference on Human Factors in Computing Systems (CHI)}, - doi = {} + doi = {10.1145/3613904.3642942} } ```