# 115 lines
# 4.4 KiB
# Raw Normal View History
#
# 2023-10-25 15:38:09 +02:00
import os
import os.path as osp
import numpy as np
import json
import argparse
import pyhocon
import glog as log
import torch
from tqdm import tqdm
from utils.data_utils import load_pickle_lines
from utils.visdial_metrics import scores_to_ranks
def _build_arg_parser():
    """Create the command-line interface of the VisDial ensemble script.

    Options:
        --exp:  experiment name; selects a section of config/ensemble.conf
                (default 'test').
        --mode: 'eval' or 'predict' (default 'predict').
        --ssh:  flag; per its help text, when set nothing is logged to the
                screen and output goes only to the log file.
    """
    cli = argparse.ArgumentParser(description='Ensemble for VisDial')
    cli.add_argument(
        '--exp',
        default='test',
        type=str,
        help='experiment name from .conf',
    )
    cli.add_argument(
        '--mode',
        default='predict',
        type=str,
        choices=['eval', 'predict'],
        help='eval or predict',
    )
    ssh_help = (
        'whether or not we are executing command via ssh. '
        'If set to True, we will not log.info anything to screen and only redirect them to log file'
    )
    cli.add_argument('--ssh', action='store_true', help=ssh_help)
    return cli


parser = _build_arg_parser()
def _resolve_processed_flags(config):
    """Return ``{model_name: processed_flag}`` for every model in ``config['models']``.

    ``config['processed']`` is either a single value shared by all models or a
    list holding one value per model, in the same order as ``config['models']``.

    Raises:
        ValueError: if the list form does not have exactly one entry per model.
    """
    models = config['models']
    flags = config['processed']
    if isinstance(flags, list):
        if len(flags) != len(models):
            raise ValueError(
                f"config['processed'] has {len(flags)} entries but "
                f"config['models'] has {len(models)}")
        return dict(zip(models, flags))
    return {model: flags for model in models}


def _load_test_round_ids(config, processed):
    """Map image_id -> round id, read from the original (unprocessed) test data.

    Only needed on the test split when at least one model predicted on processed
    data: the first few rounds were deleted from some of those dialogs, so the
    original round id can only be recovered from the original test file, where it
    is taken to be the number of rounds in the dialog.

    Returns:
        The mapping, or an empty dict when no lookup is needed.
    """
    if config['split'] != 'test' or not any(processed.values()):
        return {}
    with open(config['visdial_test_data'], encoding='utf-8') as f:
        test_dialogs = json.load(f)['data']['dialogs']
    return {dialog['image_id']: len(dialog['dialog']) for dialog in test_dialogs}


def _load_predictions(models, pred_dir, mode):
    """Load each model's pickled predictions, keyed by image_id.

    Reads ``<pred_dir>/<model>/visdial_prediction.pkl`` for every model. In
    'eval' mode it also reads the 'val' entry of
    ``<pred_dir>/<model>/metrics_epoch_best.json``.

    Returns:
        ``(visdial_outputs, metrics)``. ``visdial_outputs[model][image_id]`` is
        that model's prediction record. ``metrics`` is ``{model: val_metrics}``
        in 'eval' mode and None otherwise.
    """
    visdial_outputs = {}
    metrics = {} if mode == 'eval' else None
    for model in models:
        model_dir = osp.join(pred_dir, model)
        pred_filename = osp.join(model_dir, 'visdial_prediction.pkl')
        pred_dict = {p['image_id']: p for p in load_pickle_lines(pred_filename)}
        log.info(f'Loading {len(pred_dict)} predictions from {pred_filename}')
        visdial_outputs[model] = pred_dict
        if metrics is not None:
            with open(osp.join(model_dir, "metrics_epoch_best.json")) as f:
                metrics[model] = json.load(f)['val']
    return visdial_outputs, metrics


def _common_image_ids(models, visdial_outputs):
    """Return the first model's image ids, after checking every other model covers them.

    Raises:
        ValueError: if a model has no prediction for one of those image ids.
            Without this check the ensemble loop would fail later with a bare
            KeyError.
    """
    reference = models[0]
    image_ids = list(visdial_outputs[reference])
    for model in models[1:]:
        outputs = visdial_outputs[model]
        missing = [image_id for image_id in image_ids if image_id not in outputs]
        if missing:
            raise ValueError(
                f'model {model} is missing predictions for {len(missing)} '
                f'dialogs predicted by {reference}')
    return image_ids


def _ensemble_dialog(image_id, models, visdial_outputs, processed, split, imid2rndid, score_key):
    """Sum the scores of all models for one dialog and rank its answer options.

    Each model's ``pred[score_key]`` is reshaped to ``(num_rounds, num_options)``
    and the per-model scores are summed. This assumes the last axis indexes
    answer options, following the shape comments of the original script
    (TODO confirm). When several rounds are scored, only round ``round_id`` is
    kept.

    Returns:
        ``{"image_id", "round_id", "ranks"}``, where ``ranks`` holds one rank
        per answer option.

    Raises:
        ValueError: if the models disagree on the round id, or the round id
            does not index a scored round.
    """
    scores = []
    round_id = None
    for model in models:
        pred = visdial_outputs[model][image_id]
        if split == 'test' and processed[model]:
            # if predict on processed data, the first few rounds are deleted from some dialogs
            # so the original round ids can only be found in the original test data
            round_id_in_pred = imid2rndid[image_id]
        else:
            round_id_in_pred = pred['gt_relevance_round_id']
        round_id_in_pred = int(round_id_in_pred)
        if round_id is None:
            round_id = round_id_in_pred
        elif round_id != round_id_in_pred:
            # make sure all models have the same round_id
            raise ValueError(
                f'image {image_id}: model {model} has round id {round_id_in_pred}, '
                f'expected {round_id}')
        model_scores = torch.as_tensor(pred[score_key])
        # collapse to [num_rounds, num_options] (a 1-D vector becomes a single round)
        model_scores = model_scores.reshape(-1, model_scores.size(-1))
        scores.append(model_scores.unsqueeze(0))  # [1, num_rounds, num_options]

    # ensemble scores
    scores = torch.cat(scores, 0)  # [n_model, num_rounds, num_options]
    scores = torch.sum(scores, dim=0, keepdim=True)  # [1, num_rounds, num_options]

    # Dim 0 is always 1 after the keepdim sum, so the round axis is dim 1.
    # Slicing (rather than indexing) keeps the tensor 3-D for scores_to_ranks.
    num_rounds = scores.size(1)
    if num_rounds > 1:
        if not 1 <= round_id <= num_rounds:
            raise ValueError(
                f'image {image_id}: round id {round_id} outside 1..{num_rounds}')
        scores = scores[:, round_id - 1:round_id]  # [1, 1, num_options]
    ranks = scores_to_ranks(scores)  # [eval_batch_size, num_rounds, num_options]
    ranks = ranks.squeeze(1)  # [1, num_options]
    return {
        "image_id": image_id,
        "round_id": round_id,
        "ranks": ranks[0].tolist(),
    }


def main():
    """Ensemble the saved predictions of several VisDial models and write ranks to JSON."""
    args = parser.parse_args()

    # initialization
    config = pyhocon.ConfigFactory.parse_file("config/ensemble.conf")[args.exp]
    config["log_dir"] = os.path.join(config["log_dir"], args.exp)
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs
    os.makedirs(config["log_dir"], exist_ok=True)

    # set logs
    log_file = os.path.join(config["log_dir"], f'{args.mode}.log')
    # NOTE(review): set_log_file is not imported anywhere in the visible source;
    # its import was probably lost. Restore it from the project's utils module.
    set_log_file(log_file, file_only=args.ssh)

    # print environment info
    log.info(f"Running experiment: {args.exp}")
    log.info(f"Results saved to {config['log_dir']}")
    log.info(pyhocon.HOCONConverter.convert(config, "hocon"))

    models = list(config['models'])
    if not models:
        raise ValueError("config['models'] must list at least one model")
    processed = _resolve_processed_flags(config)
    imid2rndid = _load_test_round_ids(config, processed)

    # load predictions files
    visdial_outputs, metrics = _load_predictions(models, config['pred_dir'], args.mode)
    if metrics is not None:
        for model, metric in metrics.items():
            log.info(f'{model} val metrics: {metric}')
    # NOTE(review): the original eval branch asserted against an undefined
    # `num_dialogs`. Coverage is now checked for every model in all modes.
    image_ids = _common_image_ids(models, visdial_outputs)

    # NOTE(review): the line reading each model's scores was lost from the
    # source. 'nsp_probs' is a presumed key; confirm it against the code that
    # writes visdial_prediction.pkl, or set `score_key` in the config.
    score_key = config.get('score_key', 'nsp_probs')
    split = config['split']

    # for each dialog
    predictions = [
        _ensemble_dialog(image_id, models, visdial_outputs, processed, split, imid2rndid, score_key)
        for image_id in tqdm(image_ids)
    ]

    filename = osp.join(config['log_dir'], f'{split}_ensemble_preds.json')
    with open(filename, 'w') as f:
        json.dump(predictions, f)
    log.info(f'{len(predictions)} predictions saved to {filename}')


if __name__ == '__main__':
    main()