# Evaluate a trained GraphBC_T / GraphBCRNN model on paired expected / unexpected
# test episodes: the model's prediction error (MSE) is used as a surprise score,
# and an episode pair is counted correct when the expected episode is less
# surprising than the unexpected one.
from argparse import ArgumentParser

import numpy as np
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
import dgl

from tom.dataset import TestToMnetDGLDataset, collate_function_seq_test
from tom.model import GraphBC_T, GraphBCRNN


def get_z_scores(total, total_expected, total_unexpected):
    # z-scores of the expected / unexpected surprise values against the pooled distribution
    mean = np.mean(total)
    std = np.std(total)
    print("Z-Score expected: ", (np.mean(total_expected) - mean) / std)
    print("Z-Score unexpected: ", (np.mean(total_unexpected) - mean) / std)


parser = ArgumentParser()
parser.add_argument('--model_type', type=str, default='graphbcrnn')
parser.add_argument('--ckpt', type=str, default=None, help='path to checkpoint')
parser.add_argument('--data_path', type=str, default=None, help='path to the data')
parser.add_argument('--process_data', type=int, default=0)
parser.add_argument('--surprise_type', type=str, default='max',
                    help='surprise type: mean, max. This is used for comparing the '
                         'plausibility scores of the two test episodes')
parser.add_argument('--types', nargs='+', type=str,
                    default=[
                        'preference', 'multi_agent', 'inaccessible_goal',
                        'efficiency_irrational', 'efficiency_time', 'efficiency_path',
                        'instrumental_no_barrier', 'instrumental_blocking_barrier',
                        'instrumental_inconsequential_barrier'
                    ],
                    help='types of tasks used for training / testing')
parser.add_argument('--filename', type=str, default='')
args = parser.parse_args()

filename = args.filename

# load the model from the checkpoint
if args.model_type == 'graphbct':
    model = GraphBC_T.load_from_checkpoint(args.ckpt)
elif args.model_type == 'graphbcrnn':
    model = GraphBCRNN.load_from_checkpoint(args.ckpt)
else:
    raise ValueError('Unknown model type.')

# fall back to CPU when no GPU is available
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
model.eval()

with torch.no_grad():
    for t in args.types:
        if args.model_type == 'graphbcrnn':
            test_dataset = TestToMnetDGLDataset(
                path=args.data_path,
                task_type=t,
                mode='test'
            )
            test_dataloader = DataLoader(
                test_dataset,
                batch_size=1,
                num_workers=1,
                pin_memory=True,
                collate_fn=collate_function_seq_test,
                shuffle=False
            )

        count = 0
        total, total_expected, total_unexpected = [], [], []
        pbar = tqdm(test_dataloader)
        for j, batch in enumerate(pbar):
            if args.model_type == 'graphbcrnn':
                dem_expected_states, dem_expected_actions, dem_expected_lens, \
                    dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, \
                    query_expected_frames, target_expected_actions, \
                    query_unexpected_frames, target_unexpected_actions = batch
                dem_expected_states = dem_expected_states.to(device)
                dem_expected_actions = dem_expected_actions.to(device)
                dem_unexpected_states = dem_unexpected_states.to(device)
                dem_unexpected_actions = dem_unexpected_actions.to(device)
                target_expected_actions = target_expected_actions.to(device)
                target_unexpected_actions = target_unexpected_actions.to(device)

            # calculate the plausibility scores for the expected episode
            surprise_expected = []
            query_expected_frames = dgl.unbatch(query_expected_frames)
            for i in range(len(query_expected_frames)):
                if args.model_type == 'graphbcrnn':
                    test_actions, test_actions_pred = model(
                        [dem_expected_states, dem_expected_actions, dem_expected_lens,
                         query_expected_frames[i].to(device),
                         target_expected_actions[:, i, :]]
                    )
                loss = F.mse_loss(test_actions, test_actions_pred)
                surprise_expected.append(loss.cpu().detach().numpy())
            mean_expected_surprise = np.mean(surprise_expected)
            max_expected_surprise = np.max(surprise_expected)

            # calculate the plausibility scores for the unexpected episode
            surprise_unexpected = []
            query_unexpected_frames = dgl.unbatch(query_unexpected_frames)
            for i in range(len(query_unexpected_frames)):
                if args.model_type == 'graphbcrnn':
                    test_actions, test_actions_pred = model(
                        [dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens,
                         query_unexpected_frames[i].to(device),
                         target_unexpected_actions[:, i, :]]
                    )
                loss = F.mse_loss(test_actions, test_actions_pred)
                surprise_unexpected.append(loss.cpu().detach().numpy())
            mean_unexpected_surprise = np.mean(surprise_unexpected)
            max_unexpected_surprise = np.max(surprise_unexpected)

            # score 1 when the expected episode is less surprising, 0.5 on a tie
            # (parentheses are needed so the comparison is not absorbed into the sum)
            correct_mean = (mean_expected_surprise < mean_unexpected_surprise) \
                + 0.5 * (mean_expected_surprise == mean_unexpected_surprise)
            correct_max = (max_expected_surprise < max_unexpected_surprise) \
                + 0.5 * (max_expected_surprise == max_unexpected_surprise)

            if args.surprise_type == 'max':
                count += correct_max
            elif args.surprise_type == 'mean':
                count += correct_mean
            pbar.set_postfix({'accuracy': count / (j + 1.), 'type': t})

            total_expected.append(max_expected_surprise)
            total_unexpected.append(max_unexpected_surprise)
            total.append(max_expected_surprise)
            total.append(max_unexpected_surprise)

        # report per-task-type z-scores of the surprise values
        get_z_scores(total, total_expected, total_unexpected)
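
# Example invocation (a sketch only; the script filename, checkpoint path, and
# data path below are hypothetical and depend on your checkout and data layout):
#
#   python evaluate_tom.py --model_type graphbcrnn \
#       --ckpt checkpoints/graphbcrnn_best.ckpt \
#       --data_path /path/to/test_data \
#       --surprise_type max \
#       --types preference multi_agent inaccessible_goal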