111 lines
4.1 KiB
Python
111 lines
4.1 KiB
Python
|
import torch
|
||
|
from torch.utils.data import DataLoader
|
||
|
from env import *
|
||
|
|
||
|
import argparse
|
||
|
from dataset_new import ImagesWithSaliency
|
||
|
from torchvision.utils import save_image
|
||
|
from transformers import SwinModel
|
||
|
from pathlib import Path
|
||
|
|
||
|
def evaluation(Model:str, ckpt: str, device, batch_size:int):
    """Evaluate a trained SalFormer checkpoint on the test split.

    Loads the language backbone selected by ``Model``, combines it with a
    Swin vision backbone inside ``SalFormer``, restores the weights stored
    under ``'model_state_dict'`` in ``ckpt`` and runs the model over
    ``data/test.npy``. Every predicted map is written to ``./eval_results/``
    under the name supplied by the dataset, and the batch-averaged KL, CC
    and NSS values are printed at the end.

    Args:
        Model: One of ``'llama'``, ``'bloom'`` or ``'bert'``. Any other value
            prints a message and returns without evaluating anything.
        ckpt: Path to a checkpoint readable by ``torch.load``.
        device: Device (e.g. ``'cuda'`` or ``'cpu'``) that hosts the model
            and the batches.
        batch_size: Number of samples per evaluation batch.

    Returns:
        None. The metrics are printed, not returned.
    """
    # Small constant that keeps the normalisations below away from a
    # division by zero and keeps log() of the target map finite.
    eps = 1e-10

    # Backbone-specific imports are done lazily so that only the selected
    # language model (and its tokenizer helper) has to be importable.
    if Model == 'llama':
        from model_llama import SalFormer
        from transformers import LlamaModel
        from tokenizer_llama import padding_fn
        # llm = LlamaModel.from_pretrained("Enoch/llama-7b-hf", low_cpu_mem_usage=True)
        llm = LlamaModel.from_pretrained("daryl149/Llama-2-7b-chat-hf", low_cpu_mem_usage=True)
        neuron_n = 4096
        print("llama loaded")
    elif Model == 'bloom':
        from model_llama import SalFormer
        from transformers import BloomModel
        from tokenizer_bloom import padding_fn
        llm = BloomModel.from_pretrained("bigscience/bloom-3b")
        neuron_n = 2560
        print('BloomModel loaded')
    elif Model == 'bert':
        from model_swin import SalFormer
        from transformers import BertModel
        from tokenizer_bert import padding_fn
        llm = BertModel.from_pretrained("bert-base-uncased")
        print('BertModel loaded')
    else:
        print('model not available, possiblilities: llama, bloom, bert')
        return

    test_set = ImagesWithSaliency("data/test.npy")

    Path('./eval_results').mkdir(parents=True, exist_ok=True)

    # vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
    vit = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
    # vit = timm.create_model('xception41p.ra3_in1k', pretrained=True)

    # The BERT variant is built without an explicit neuron_n; the other
    # backbones pass the value chosen above.
    if Model == 'bert':
        model = SalFormer(vit, llm).to(device)
    else:
        model = SalFormer(vit, llm, neuron_n = neuron_n).to(device)

    # map_location lets a checkpoint saved on a GPU be restored when the
    # evaluation runs on a different device (e.g. --device cpu).
    checkpoint = torch.load(ckpt, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    test_dataloader = DataLoader(test_set, batch_size=batch_size, shuffle=False, collate_fn=padding_fn, num_workers=8)
    kl_loss = torch.nn.KLDivLoss(reduction="batchmean", log_target=True)
    n_batches = len(test_dataloader)

    test_kl, test_cc, test_nss = 0, 0, 0
    # Pure inference: disabling autograd avoids building a graph per batch,
    # which would otherwise hold on to activations for no benefit.
    with torch.no_grad():
        for img, input_ids, fix, hm, name in test_dataloader:
            img = img.to(device)
            input_ids = input_ids.to(device)
            fix = fix.to(device)
            hm = hm.to(device)

            y = model(img, input_ids)

            # Normalise each prediction so it sums to (almost) one. The
            # (B, 1) sums are expanded to (B, 1, 1, 1) to broadcast against
            # y -- presumably a 4-D map batch, confirm against the model.
            y_sum = y.view(y.shape[0], -1).sum(1, keepdim=True)
            y_distribution = y / (y_sum[:, :, None, None] + eps)

            # Same normalisation for the ground-truth map, followed by an
            # eps shift so that its logarithm is finite everywhere.
            hm_sum = hm.view(y.shape[0], -1).sum(1, keepdim=True)
            hm_distribution = hm / (hm_sum[:, :, None, None] + eps)
            hm_distribution = hm_distribution + eps
            hm_distribution = hm_distribution / (1+eps)

            # NSS: standardise the prediction over the whole batch and
            # average it at the fixation locations; 0 when there are none.
            if fix.sum() != 0:
                normal_y = (y-y.mean())/y.std()
                nss = torch.sum(normal_y*fix)/fix.sum()
            else:
                nss = torch.Tensor([0.0]).to(device)

            # Both arguments are passed in log space (log_target=True).
            kl = kl_loss(torch.log(y_distribution), torch.log(hm_distribution))

            # CC: correlation between the mean-centred prediction and the
            # mean-centred ground truth; 0 when either one is constant.
            vy = y - torch.mean(y)
            vhm = hm - torch.mean(hm)
            cc_denominator = torch.sqrt(torch.sum(vy ** 2)) * torch.sqrt(torch.sum(vhm ** 2))
            if cc_denominator != 0:
                cc = torch.sum(vy * vhm) / cc_denominator
            else:
                cc = torch.Tensor([0.0]).to(device)

            # Running mean over batches (each batch weighs the same).
            test_kl += kl.item()/n_batches
            test_cc += cc.item()/n_batches
            test_nss += nss.item()/n_batches

            # Persist every predicted map under its dataset-provided name.
            for i, prediction in enumerate(y):
                save_image(prediction, f"./eval_results/{name[i]}")

    print("kl:", test_kl, "cc", test_cc, "nss", test_nss)
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
    # Command-line entry point: every option has a default, so running the
    # script without arguments evaluates the BERT checkpoint on CUDA.
    cli = argparse.ArgumentParser()
    option_table = (
        ("--model", str, 'bert'),
        ("--device", str, 'cuda'),
        ("--batch_size", int, 16),
        ("--ckpt", str, './ckpt/model_bert_freeze_10kl_5cc_2nss.tar'),
    )
    for flag, value_type, fallback in option_table:
        cli.add_argument(flag, type=value_type, default=fallback)
    options = cli.parse_args()

    evaluation(
        Model=options.model,
        device=options.device,
        ckpt=options.ckpt,
        batch_size=options.batch_size,
    )
|