initial commit
This commit is contained in:
commit
a82bbc593e
129 changed files with 33981 additions and 0 deletions
81
eval_visdial.py
Normal file
81
eval_visdial.py
Normal file
|
@ -0,0 +1,81 @@
|
|||
import os
|
||||
import torch
|
||||
import json
|
||||
from utils.metrcis import SparseGTMetrics, NDCG
|
||||
from Levenshtein import ratio
|
||||
|
||||
|
||||
output_dir = '/pfss/mlde/workspaces/mlde_wsp_Rohrbach/users/ma35vahy/V2Dial_new_v2/output/visdial'
|
||||
|
||||
file_paths = os.listdir(output_dir)
|
||||
file_paths = list(filter(lambda f: 'part' in f , file_paths))
|
||||
name = file_paths[0]
|
||||
file_paths = list(map(lambda f: os.path.join(output_dir, f), file_paths))
|
||||
|
||||
results = {}
|
||||
count = 0
|
||||
for pth in file_paths:
|
||||
with open(pth, 'r') as f:
|
||||
partial_res = json.load(f)
|
||||
count += len(partial_res)
|
||||
results.update(partial_res)
|
||||
# dialogs.extend(data['dialogs'])
|
||||
os.remove(pth)
|
||||
|
||||
|
||||
name = "".join(name.split('-')[:-1]) + '.json'
|
||||
output_path = os.path.join(output_dir, name)
|
||||
with open(output_path, 'w') as f:
|
||||
json.dump(results, f, indent=4)
|
||||
|
||||
|
||||
# result_path = '/pfss/mlde/workspaces/mlde_wsp_Rohrbach/users/ma35vahy/V2Dial_new_v2/output/visdial/zeroshot_visdial_after_champagne_googleflant5large_results_dstc8_beam_depth_8_lenPen_0.3.json'
|
||||
|
||||
# with open(result_path, 'r') as f:
|
||||
# results = json.load(f)
|
||||
|
||||
annos_path = '/pfss/mlde/workspaces/mlde_wsp_Rohrbach/data/annotations/visdial_v1.0/visdial_1.0_val.json'
|
||||
dense_annos_path = '/pfss/mlde/workspaces/mlde_wsp_Rohrbach/data/annotations/visdial_v1.0/visdial_1.0_val_dense_annotations.json'
|
||||
|
||||
with open(annos_path, 'r') as f:
|
||||
data = json.load(f)['data']
|
||||
|
||||
all_answers = data['answers']
|
||||
all_questions = data['questions']
|
||||
|
||||
|
||||
dialogs = data['dialogs']
|
||||
|
||||
dialogs_dict = {}
|
||||
|
||||
for dialog in dialogs:
|
||||
image_id = dialog['image_id']
|
||||
for i, turn in enumerate(dialog['dialog']):
|
||||
answer_opts = [all_answers[a] for a in turn['answer_options']]
|
||||
dialogs_dict[str(image_id) + '_' + str(i+1)] = {
|
||||
'answer_opts': answer_opts,
|
||||
'gt_index': turn['gt_index']
|
||||
}
|
||||
# print('bla')
|
||||
|
||||
with open(dense_annos_path, 'r') as f:
|
||||
dense_data = json.load(f)
|
||||
|
||||
dense_data = {str(d['image_id']) + '_' + str(d['round_id']): d['gt_relevance'] for d in dense_data}
|
||||
|
||||
sparse_metrics = SparseGTMetrics()
|
||||
ndcg = NDCG()
|
||||
|
||||
for res_key, res in results.items():
|
||||
answer_opts = dialogs_dict[res_key]['answer_opts']
|
||||
gt_index = torch.tensor(dialogs_dict[res_key]['gt_index'])
|
||||
|
||||
scores = torch.tensor([ratio(res, answer_opt) for answer_opt in answer_opts]).unsqueeze(0).unsqueeze(0)
|
||||
sparse_metrics.observe(scores, gt_index)
|
||||
if res_key in dense_data:
|
||||
gt_relevance = torch.tensor(dense_data[res_key]).unsqueeze(0)
|
||||
ndcg.observe(scores.squeeze(0), gt_relevance)
|
||||
# print('bla')
|
||||
|
||||
print(sparse_metrics.retrieve())
|
||||
print(ndcg.retrieve())
|
Loading…
Add table
Add a link
Reference in a new issue