up
This commit is contained in:
parent
a333481e05
commit
de0bea7508
18 changed files with 3150 additions and 2 deletions
122
test_tom.py
Normal file
122
test_tom.py
Normal file
|
@ -0,0 +1,122 @@
|
|||
from argparse import ArgumentParser
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
import torch.nn.functional as F
|
||||
import dgl
|
||||
|
||||
from tom.dataset import TestToMnetDGLDataset, collate_function_seq_test
|
||||
from tom.model import GraphBC_T, GraphBCRNN
|
||||
|
||||
|
||||
def get_z_scores(total, total_expected, total_unexpected):
|
||||
mean = np.mean(total)
|
||||
std = np.std(total)
|
||||
print("Z-Score expected: ",
|
||||
(np.mean(total_expected) - mean) / std)
|
||||
print("Z-Score unexpected: ",
|
||||
(np.mean(total_unexpected) - mean) / std)
|
||||
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
parser.add_argument('--model_type', type=str, default='graphbcrnn')
|
||||
parser.add_argument('--ckpt', type=str, default=None, help='path to checkpoint')
|
||||
parser.add_argument('--data_path', type=str, default=None, help='path to the data')
|
||||
parser.add_argument('--process_data', type=int, default=0)
|
||||
parser.add_argument('--surprise_type', type=str, default='max',
|
||||
help='surprise type: mean, max. This is used for comparing the plausibility scores of the two test episodes')
|
||||
parser.add_argument('--types', nargs='+', type=str,
|
||||
default=[
|
||||
'preference', 'multi_agent', 'inaccessible_goal',
|
||||
'efficiency_irrational', 'efficiency_time','efficiency_path',
|
||||
'instrumental_no_barrier', 'instrumental_blocking_barrier', 'instrumental_inconsequential_barrier'
|
||||
],
|
||||
help='types of tasks used for training / testing')
|
||||
parser.add_argument('--filename', type=str, default='')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
filename = args.filename
|
||||
|
||||
if args.model_type == 'graphbct':
|
||||
model = GraphBC_T.load_from_checkpoint(args.ckpt)
|
||||
elif args.model_type == 'graphbcrnn':
|
||||
model = GraphBCRNN.load_from_checkpoint(args.ckpt)
|
||||
else:
|
||||
raise ValueError('Unknown model type.')
|
||||
|
||||
device = 'cuda'
|
||||
model.to(device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
for t in args.types:
|
||||
if args.model_type == 'graphbcrnn':
|
||||
test_dataset = TestToMnetDGLDataset(
|
||||
path=args.data_path,
|
||||
task_type=t,
|
||||
mode='test'
|
||||
)
|
||||
test_dataloader = DataLoader(
|
||||
test_dataset,
|
||||
batch_size=1,
|
||||
num_workers=1,
|
||||
pin_memory=True,
|
||||
collate_fn=collate_function_seq_test,
|
||||
shuffle=False
|
||||
)
|
||||
count = 0
|
||||
total, total_expected, total_unexpected = [], [], []
|
||||
pbar = tqdm(test_dataloader)
|
||||
for j, batch in enumerate(pbar):
|
||||
if args.model_type == 'graphbcrnn':
|
||||
dem_expected_states, dem_expected_actions, dem_expected_lens, \
|
||||
dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, \
|
||||
query_expected_frames, target_expected_actions, \
|
||||
query_unexpected_frames, target_unexpected_actions = batch
|
||||
dem_expected_states = dem_expected_states.to(device)
|
||||
dem_expected_actions = dem_expected_actions.to(device)
|
||||
dem_unexpected_states = dem_unexpected_states.to(device)
|
||||
dem_unexpected_actions = dem_unexpected_actions.to(device)
|
||||
target_expected_actions = target_expected_actions.to(device)
|
||||
target_unexpected_actions = target_unexpected_actions.to(device)
|
||||
surprise_expected = []
|
||||
query_expected_frames = dgl.unbatch(query_expected_frames)
|
||||
for i in range(len(query_expected_frames)):
|
||||
if args.model_type == 'graphbcrnn':
|
||||
test_actions, test_actions_pred = model(
|
||||
[dem_expected_states, dem_expected_actions, dem_expected_lens, query_expected_frames[i].to(device), target_expected_actions[:, i, :]]
|
||||
)
|
||||
loss = F.mse_loss(test_actions, test_actions_pred)
|
||||
surprise_expected.append(loss.cpu().detach().numpy())
|
||||
mean_expected_surprise = np.mean(surprise_expected)
|
||||
max_expected_surprise = np.max(surprise_expected)
|
||||
|
||||
# calculate the plausibility scores for the unexpected episode
|
||||
surprise_unexpected = []
|
||||
query_unexpected_frames = dgl.unbatch(query_unexpected_frames)
|
||||
for i in range(len(query_unexpected_frames)):
|
||||
if args.model_type == 'graphbcrnn':
|
||||
test_actions, test_actions_pred = model(
|
||||
[dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, query_unexpected_frames[i].to(device), target_unexpected_actions[:, i, :]]
|
||||
)
|
||||
loss = F.mse_loss(test_actions, test_actions_pred)
|
||||
surprise_unexpected.append(loss.cpu().detach().numpy())
|
||||
mean_unexpected_surprise = np.mean(surprise_unexpected)
|
||||
max_unexpected_surprise = np.max(surprise_unexpected)
|
||||
|
||||
correct_mean = mean_expected_surprise < mean_unexpected_surprise + 0.5 * (mean_expected_surprise == mean_unexpected_surprise)
|
||||
correct_max = max_expected_surprise < max_unexpected_surprise + 0.5 * (max_expected_surprise == max_unexpected_surprise)
|
||||
if args.surprise_type == 'max':
|
||||
count += correct_max
|
||||
elif args.surprise_type == 'mean':
|
||||
count += correct_mean
|
||||
pbar.set_postfix({'accuracy': count/(j+1.), 'type': t})
|
||||
|
||||
total_expected.append(max_expected_surprise)
|
||||
total_unexpected.append(max_unexpected_surprise)
|
||||
total.append(max_expected_surprise)
|
||||
total.append(max_unexpected_surprise)
|
||||
get_z_scores(total, total_expected, total_unexpected)
|
Loading…
Add table
Add a link
Reference in a new issue