initial commit

Andreas Bulling 2025-06-24 08:38:09 +02:00
commit a82bbc593e
129 changed files with 33981 additions and 0 deletions

eval_visdial.py Normal file

@@ -0,0 +1,81 @@
import os
import torch
import json
from utils.metrics import SparseGTMetrics, NDCG
from Levenshtein import ratio
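# Merge the partial VisDial result files in output_dir into one JSON file and
# evaluate the generated answers against the VisDial v1.0 val annotations.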
output_dir = '/pfss/mlde/workspaces/mlde_wsp_Rohrbach/users/ma35vahy/V2Dial_new_v2/output/visdial'
file_paths = os.listdir(output_dir)
file_paths = list(filter(lambda f: 'part' in f, file_paths))
name = file_paths[0]
file_paths = list(map(lambda f: os.path.join(output_dir, f), file_paths))
results = {}
count = 0
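# Collect all partial results into a single dict, deleting each shard after it is read.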
for pth in file_paths:
    with open(pth, 'r') as f:
        partial_res = json.load(f)
    count += len(partial_res)
    results.update(partial_res)
    os.remove(pth)
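# Name the merged file after the first shard, dropping its trailing '-...' segment.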
name = "".join(name.split('-')[:-1]) + '.json'
output_path = os.path.join(output_dir, name)
with open(output_path, 'w') as f:
    json.dump(results, f, indent=4)
# result_path = '/pfss/mlde/workspaces/mlde_wsp_Rohrbach/users/ma35vahy/V2Dial_new_v2/output/visdial/zeroshot_visdial_after_champagne_googleflant5large_results_dstc8_beam_depth_8_lenPen_0.3.json'
# with open(result_path, 'r') as f:
# results = json.load(f)
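# Ground-truth annotations: the val dialogs (answer options and gt_index per round)
# and the dense relevance annotations used for NDCG.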
annos_path = '/pfss/mlde/workspaces/mlde_wsp_Rohrbach/data/annotations/visdial_v1.0/visdial_1.0_val.json'
dense_annos_path = '/pfss/mlde/workspaces/mlde_wsp_Rohrbach/data/annotations/visdial_v1.0/visdial_1.0_val_dense_annotations.json'
with open(annos_path, 'r') as f:
    data = json.load(f)['data']
all_answers = data['answers']
all_questions = data['questions']
dialogs = data['dialogs']
dialogs_dict = {}
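# Index every dialog round as '<image_id>_<round>' so each generated answer can be
# matched to its answer options and ground-truth index.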
for dialog in dialogs:
    image_id = dialog['image_id']
    for i, turn in enumerate(dialog['dialog']):
        answer_opts = [all_answers[a] for a in turn['answer_options']]
        dialogs_dict[str(image_id) + '_' + str(i+1)] = {
            'answer_opts': answer_opts,
            'gt_index': turn['gt_index']
        }
with open(dense_annos_path, 'r') as f:
    dense_data = json.load(f)
dense_data = {str(d['image_id']) + '_' + str(d['round_id']): d['gt_relevance'] for d in dense_data}
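# Score each generated answer by ranking the answer options with the Levenshtein
# ratio to the generated string; accumulate sparse retrieval metrics and, for rounds
# with dense annotations, NDCG.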
sparse_metrics = SparseGTMetrics()
ndcg = NDCG()
for res_key, res in results.items():
    answer_opts = dialogs_dict[res_key]['answer_opts']
    gt_index = torch.tensor(dialogs_dict[res_key]['gt_index'])
    scores = torch.tensor([ratio(res, answer_opt) for answer_opt in answer_opts]).unsqueeze(0).unsqueeze(0)
    sparse_metrics.observe(scores, gt_index)
    if res_key in dense_data:
        gt_relevance = torch.tensor(dense_data[res_key]).unsqueeze(0)
        ndcg.observe(scores.squeeze(0), gt_relevance)
print(sparse_metrics.retrieve())
print(ndcg.retrieve())