81 lines
2.5 KiB
Python
81 lines
2.5 KiB
Python
import os
|
|
import torch
|
|
import json
|
|
from utils.metrcis import SparseGTMetrics, NDCG
|
|
from Levenshtein import ratio
|
|
|
|
|
|
# Directory holding the partial result files that are merged below.
output_dir = '/pfss/mlde/workspaces/mlde_wsp_Rohrbach/users/ma35vahy/V2Dial_new_v2/output/visdial'

# Partial result files are the directory entries whose name contains 'part'.
# os.listdir returns entries in arbitrary order, so sort them to make the
# merge order (and the file the output name is derived from) deterministic.
file_paths = sorted(f for f in os.listdir(output_dir) if 'part' in f)
if not file_paths:
    raise FileNotFoundError(
        "no partial result files (names containing 'part') found in " + output_dir
    )

# The merged file's name is derived from the first partial file's name.
name = file_paths[0]
file_paths = [os.path.join(output_dir, f) for f in file_paths]

# Merge every partial JSON dict into one; on duplicate keys the later file wins.
results = {}
count = 0  # total number of entries read, counted before key de-duplication
for pth in file_paths:
    with open(pth, 'r') as f:
        partial_res = json.load(f)
    count += len(partial_res)
    results.update(partial_res)

# Output name: the first part-file name without its last '-'-separated piece.
# NOTE(review): the remaining pieces are joined with '' so any other '-' in the
# name is dropped as well -- confirm this is intended before changing it.
name = "".join(name.split('-')[:-1]) + '.json'
output_path = os.path.join(output_dir, name)

with open(output_path, 'w') as f:
    json.dump(results, f, indent=4)

# Remove the partial files only once the merged file has been written, so a
# failure while reading or writing cannot lose results. The merged file itself
# is never removed, even if its name happens to contain 'part'.
for pth in file_paths:
    if pth != output_path:
        os.remove(pth)
|
|
|
|
|
|
# Alternative (disabled): load an already-merged results file into `results`
# instead of using the dict built from the partial files above.
# result_path = '/pfss/mlde/workspaces/mlde_wsp_Rohrbach/users/ma35vahy/V2Dial_new_v2/output/visdial/zeroshot_visdial_after_champagne_googleflant5large_results_dstc8_beam_depth_8_lenPen_0.3.json'
# with open(result_path, 'r') as f:
#     results = json.load(f)
|
|
|
|
# Annotation files: the validation dialogs and their dense relevance labels.
annos_path = '/pfss/mlde/workspaces/mlde_wsp_Rohrbach/data/annotations/visdial_v1.0/visdial_1.0_val.json'
dense_annos_path = '/pfss/mlde/workspaces/mlde_wsp_Rohrbach/data/annotations/visdial_v1.0/visdial_1.0_val_dense_annotations.json'

with open(annos_path, 'r') as f:
    data = json.load(f)['data']

# Shared tables from the annotation file; `answer_options` entries index
# into `all_answers`.
all_answers = data['answers']
all_questions = data['questions']
dialogs = data['dialogs']

# Map "<image_id>_<round>" (rounds counted from 1) to the candidate answers
# of that round and the index of the ground-truth answer among them.
dialogs_dict = {}
for dialog in dialogs:
    image_id = dialog['image_id']
    for round_no, turn in enumerate(dialog['dialog'], start=1):
        dialogs_dict[f'{image_id}_{round_no}'] = {
            'answer_opts': [all_answers[a] for a in turn['answer_options']],
            'gt_index': turn['gt_index'],
        }
|
|
|
|
# Load the dense annotations and index each entry's relevance list by
# "<image_id>_<round_id>".
with open(dense_annos_path, 'r') as f:
    dense_entries = json.load(f)

dense_data = {}
for entry in dense_entries:
    dense_key = f"{entry['image_id']}_{entry['round_id']}"
    dense_data[dense_key] = entry['gt_relevance']
|
|
|
|
# Metric accumulators: one fed with the ground-truth index, one with the
# dense relevance labels.
sparse_metrics = SparseGTMetrics()
ndcg = NDCG()

for res_key, res in results.items():
    entry = dialogs_dict[res_key]
    answer_opts = entry['answer_opts']
    gt_index = torch.tensor(entry['gt_index'])

    # Score every candidate answer by its Levenshtein `ratio` against the
    # result, then add two leading axes: shape (1, 1, num_options).
    similarities = [ratio(res, answer_opt) for answer_opt in answer_opts]
    scores = torch.tensor(similarities).unsqueeze(0).unsqueeze(0)
    sparse_metrics.observe(scores, gt_index)

    # NDCG is only observed for keys that have a dense annotation.
    if res_key not in dense_data:
        continue
    gt_relevance = torch.tensor(dense_data[res_key]).unsqueeze(0)
    ndcg.observe(scores.squeeze(0), gt_relevance)

print(sparse_metrics.retrieve())
print(ndcg.retrieve())
|