122 lines
5.6 KiB
Python
122 lines
5.6 KiB
Python
from argparse import ArgumentParser
|
|
import numpy as np
|
|
from tqdm import tqdm
|
|
|
|
import torch
|
|
from torch.utils.data import DataLoader
|
|
import torch.nn.functional as F
|
|
import dgl
|
|
|
|
from tom.dataset import TestToMnetDGLDataset, collate_function_seq_test
|
|
from tom.model import GraphBC_T, GraphBCRNN
|
|
|
|
|
|
def get_z_scores(total, total_expected, total_unexpected):
    """Print the z-scores of the expected and unexpected groups.

    The mean of each group is standardised against the mean and the
    standard deviation (NumPy default, i.e. population std) of ``total``.

    Args:
        total: Values defining the reference mean and standard deviation.
        total_expected: Values whose mean is reported as the "expected" z-score.
        total_unexpected: Values whose mean is reported as the "unexpected"
            z-score.

    Returns:
        None. Both z-scores are written to stdout, expected first.
    """
    reference_mean = np.mean(total)
    reference_std = np.std(total)

    # Compute and print one group at a time so output order matches the
    # order in which the groups are evaluated.
    labelled_groups = (
        ("Z-Score expected: ", total_expected),
        ("Z-Score unexpected: ", total_unexpected),
    )
    for label, group in labelled_groups:
        print(label, (np.mean(group) - reference_mean) / reference_std)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Evaluation script: for every task type, compare the model's "surprise"
# (per-frame MSE between the two tensors returned by the model) on the
# expected vs. the unexpected test episode, report a running accuracy and
# print z-scores of the per-episode surprise values.
# ---------------------------------------------------------------------------


def _episode_surprise(model, dem_states, dem_actions, dem_lens,
                      query_frames, target_actions, device):
    """Return ``(mean, max)`` of the per-frame MSE for one test episode.

    Args:
        model: Callable taking a 5-element list
            ``[dem_states, dem_actions, dem_lens, frame, target]`` and
            returning two tensors that are compared with ``F.mse_loss``.
        dem_states, dem_actions, dem_lens: Demonstration inputs, forwarded
            to the model unchanged.
        query_frames: Batched DGL graph; split into frames with
            ``dgl.unbatch``.
        target_actions: Tensor indexed as ``target_actions[:, i, :]`` for
            the i-th query frame.
        device: Device each query frame is moved to before the forward pass.

    Returns:
        Tuple ``(np.mean(losses), np.max(losses))`` over all query frames.
    """
    frame_losses = []
    for i, frame in enumerate(dgl.unbatch(query_frames)):
        test_actions, test_actions_pred = model(
            [dem_states, dem_actions, dem_lens, frame.to(device), target_actions[:, i, :]]
        )
        loss = F.mse_loss(test_actions, test_actions_pred)
        frame_losses.append(loss.cpu().detach().numpy())
    return np.mean(frame_losses), np.max(frame_losses)


def _pair_score(expected_surprise, unexpected_surprise):
    """Score one expected/unexpected pair: 1 if correct, 0.5 on a tie, else 0.

    The original expression ``a < b + 0.5 * (a == b)`` parsed as
    ``a < (b + 0.5 * (a == b))`` because comparisons bind looser than
    arithmetic, so ties were counted as fully correct instead of 0.5.
    """
    return (float(expected_surprise < unexpected_surprise)
            + 0.5 * float(expected_surprise == unexpected_surprise))


parser = ArgumentParser()

parser.add_argument('--model_type', type=str, default='graphbcrnn')
parser.add_argument('--ckpt', type=str, default=None, help='path to checkpoint')
parser.add_argument('--data_path', type=str, default=None, help='path to the data')
# Not read anywhere in the visible code; kept so existing command lines still parse.
parser.add_argument('--process_data', type=int, default=0)
# `choices` rejects typos up front; previously any other value silently
# left the accuracy counter at 0.
parser.add_argument('--surprise_type', type=str, default='max', choices=['mean', 'max'],
                    help='surprise type: mean, max. This is used for comparing the plausibility scores of the two test episodes')
parser.add_argument('--types', nargs='+', type=str,
                    default=[
                        'preference', 'multi_agent', 'inaccessible_goal',
                        'efficiency_irrational', 'efficiency_time', 'efficiency_path',
                        'instrumental_no_barrier', 'instrumental_blocking_barrier', 'instrumental_inconsequential_barrier'
                    ],
                    help='types of tasks used for training / testing')
parser.add_argument('--filename', type=str, default='')

args = parser.parse_args()

# Not used further in the visible code; kept as a module-level name.
filename = args.filename

if args.model_type == 'graphbct':
    model = GraphBC_T.load_from_checkpoint(args.ckpt)
elif args.model_type == 'graphbcrnn':
    model = GraphBCRNN.load_from_checkpoint(args.ckpt)
else:
    raise ValueError('Unknown model type.')

# Fall back to CPU instead of crashing on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
model.eval()

# The dataset / batch layout below was only ever set up for 'graphbcrnn';
# other model types previously failed later with a NameError on the
# undefined dataset. Fail with an explicit message instead.
if args.model_type != 'graphbcrnn':
    raise NotImplementedError(
        "Evaluation is only implemented for model_type 'graphbcrnn'."
    )

with torch.no_grad():
    for t in args.types:
        test_dataset = TestToMnetDGLDataset(
            path=args.data_path,
            task_type=t,
            mode='test'
        )
        test_dataloader = DataLoader(
            test_dataset,
            batch_size=1,
            num_workers=1,
            # Pinned memory is only useful (and only available) with CUDA.
            pin_memory=(device == 'cuda'),
            collate_fn=collate_function_seq_test,
            shuffle=False
        )

        count = 0
        total, total_expected, total_unexpected = [], [], []
        pbar = tqdm(test_dataloader)
        for j, batch in enumerate(pbar):
            (dem_expected_states, dem_expected_actions, dem_expected_lens,
             dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens,
             query_expected_frames, target_expected_actions,
             query_unexpected_frames, target_unexpected_actions) = batch

            dem_expected_states = dem_expected_states.to(device)
            dem_expected_actions = dem_expected_actions.to(device)
            dem_unexpected_states = dem_unexpected_states.to(device)
            dem_unexpected_actions = dem_unexpected_actions.to(device)
            target_expected_actions = target_expected_actions.to(device)
            target_unexpected_actions = target_unexpected_actions.to(device)

            # calculate the plausibility scores for the expected episode
            mean_expected_surprise, max_expected_surprise = _episode_surprise(
                model, dem_expected_states, dem_expected_actions, dem_expected_lens,
                query_expected_frames, target_expected_actions, device
            )

            # calculate the plausibility scores for the unexpected episode
            mean_unexpected_surprise, max_unexpected_surprise = _episode_surprise(
                model, dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens,
                query_unexpected_frames, target_unexpected_actions, device
            )

            # A pair is correct when the expected episode is less surprising.
            if args.surprise_type == 'max':
                count += _pair_score(max_expected_surprise, max_unexpected_surprise)
            elif args.surprise_type == 'mean':
                count += _pair_score(mean_expected_surprise, mean_unexpected_surprise)
            pbar.set_postfix({'accuracy': count/(j+1.), 'type': t})

            # NOTE(review): z-scores are always based on the max surprise,
            # regardless of --surprise_type; kept as in the original.
            total_expected.append(max_expected_surprise)
            total_unexpected.append(max_unexpected_surprise)
            total.append(max_expected_surprise)
            total.append(max_unexpected_surprise)

        # The accumulators are reset per task type, so z-scores are reported
        # once per task type.
        get_z_scores(total, total_expected, total_unexpected)
|