IRENE/test_tom.py

123 lines
5.6 KiB
Python
Raw Normal View History

2024-02-01 15:40:47 +01:00
from argparse import ArgumentParser
import numpy as np
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
import dgl
from tom.dataset import TestToMnetDGLDataset, collate_function_seq_test
from tom.model import GraphBC_T, GraphBCRNN
def get_z_scores(total, total_expected, total_unexpected):
    """Print and return the z-scores of the expected/unexpected surprise means.

    Both z-scores are measured against the mean and population standard
    deviation (``np.std`` default, ddof=0) of ``total``.

    Args:
        total: all surprise scores (expected and unexpected together).
        total_expected: surprise scores of the expected episodes.
        total_unexpected: surprise scores of the unexpected episodes.

    Returns:
        Tuple ``(z_expected, z_unexpected)``. Both are NaN when ``total`` has
        no spread (or no data), instead of the inf/nan plus divide-by-zero
        warning the unguarded formula produced.
    """
    mean = np.mean(total)
    std = np.std(total)
    # `std > 0` is False for both 0 and NaN (empty input), covering both cases.
    if std > 0:
        z_expected = (np.mean(total_expected) - mean) / std
        z_unexpected = (np.mean(total_unexpected) - mean) / std
    else:
        z_expected = z_unexpected = float('nan')
    print("Z-Score expected: ", z_expected)
    print("Z-Score unexpected: ", z_unexpected)
    return z_expected, z_unexpected
# Command-line configuration for the evaluation run.
parser = ArgumentParser()
parser.add_argument('--model_type', type=str, default='graphbcrnn')
parser.add_argument('--ckpt', type=str, default=None, help='path to checkpoint')
parser.add_argument('--data_path', type=str, default=None, help='path to the data')
parser.add_argument('--process_data', type=int, default=0)
# Restrict to the two supported reductions: any other value used to be
# accepted and then never incremented the accuracy counter (0% reported).
parser.add_argument('--surprise_type', type=str, default='max', choices=['mean', 'max'],
                    help='surprise type: mean, max. This is used for comparing the plausibility scores of the two test episodes')
parser.add_argument('--types', nargs='+', type=str,
                    default=[
                        'preference', 'multi_agent', 'inaccessible_goal',
                        'efficiency_irrational', 'efficiency_time', 'efficiency_path',
                        'instrumental_no_barrier', 'instrumental_blocking_barrier', 'instrumental_inconsequential_barrier',
                    ],
                    help='types of tasks used for training / testing')
parser.add_argument('--filename', type=str, default='')
args = parser.parse_args()
# NOTE(review): `filename` is not used anywhere in the visible code.
filename = args.filename
# Load the trained model of the requested architecture from its checkpoint.
if args.model_type == 'graphbct':
    model = GraphBC_T.load_from_checkpoint(args.ckpt)
elif args.model_type == 'graphbcrnn':
    model = GraphBCRNN.load_from_checkpoint(args.ckpt)
else:
    raise ValueError('Unknown model type.')
# Fall back to the CPU instead of crashing on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
model.eval()
def _episode_surprises(model, dem_states, dem_actions, dem_lens,
                       query_frames, target_actions, device):
    """Return one surprise score per query frame of a single test episode.

    ``query_frames`` is split with ``dgl.unbatch``; frame ``i`` is paired with
    ``target_actions[:, i, :]`` and fed to ``model`` together with the
    demonstration context. A frame's surprise is the MSE between the two
    tensors the model returns.
    """
    surprises = []
    for i, frame in enumerate(dgl.unbatch(query_frames)):
        test_actions, test_actions_pred = model(
            [dem_states, dem_actions, dem_lens, frame.to(device), target_actions[:, i, :]]
        )
        surprises.append(F.mse_loss(test_actions, test_actions_pred).item())
    return surprises


def _pair_score(expected_surprise, unexpected_surprise):
    """Score one expected/unexpected pair: 1 if the expected episode is less
    surprising, 0.5 on an exact tie, 0 otherwise."""
    # Parenthesised on purpose: the previous `a < b + 0.5 * (a == b)` parsed as
    # `a < (b + 0.5)` on ties and awarded a full point instead of half.
    return float(expected_surprise < unexpected_surprise) + 0.5 * float(expected_surprise == unexpected_surprise)


# The batch unpacking below only matches the GraphBCRNN collate layout; other
# model types previously crashed with a NameError on `test_dataset`.
if args.model_type != 'graphbcrnn':
    raise NotImplementedError(f'Evaluation is only implemented for graphbcrnn, got {args.model_type!r}.')

_SURPRISE_REDUCERS = {'max': np.max, 'mean': np.mean}
if args.surprise_type not in _SURPRISE_REDUCERS:
    raise ValueError(f'Unknown surprise type: {args.surprise_type!r}.')
reduce_surprise = _SURPRISE_REDUCERS[args.surprise_type]

with torch.no_grad():
    for t in args.types:
        test_dataset = TestToMnetDGLDataset(
            path=args.data_path,
            task_type=t,
            mode='test',
        )
        test_dataloader = DataLoader(
            test_dataset,
            batch_size=1,
            num_workers=1,
            # Pinned memory only helps host-to-GPU copies.
            pin_memory=(device == 'cuda'),
            collate_fn=collate_function_seq_test,
            shuffle=False,
        )
        count = 0.0
        total, total_expected, total_unexpected = [], [], []
        pbar = tqdm(test_dataloader)
        for j, batch in enumerate(pbar):
            (dem_expected_states, dem_expected_actions, dem_expected_lens,
             dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens,
             query_expected_frames, target_expected_actions,
             query_unexpected_frames, target_unexpected_actions) = batch

            # Per-frame plausibility (surprise) scores of both test episodes.
            surprise_expected = _episode_surprises(
                model, dem_expected_states.to(device), dem_expected_actions.to(device),
                dem_expected_lens, query_expected_frames,
                target_expected_actions.to(device), device,
            )
            surprise_unexpected = _episode_surprises(
                model, dem_unexpected_states.to(device), dem_unexpected_actions.to(device),
                dem_unexpected_lens, query_unexpected_frames,
                target_unexpected_actions.to(device), device,
            )

            count += _pair_score(reduce_surprise(surprise_expected),
                                 reduce_surprise(surprise_unexpected))
            pbar.set_postfix({'accuracy': count / (j + 1.), 'type': t})

            # NOTE(review): z-scores always use the max surprise, independent of
            # --surprise_type (kept as in the original script).
            max_expected_surprise = np.max(surprise_expected)
            max_unexpected_surprise = np.max(surprise_unexpected)
            total_expected.append(max_expected_surprise)
            total_unexpected.append(max_unexpected_surprise)
            total.append(max_expected_surprise)
            total.append(max_unexpected_surprise)
        # The score lists are reset per task type, so report per task type.
        get_z_scores(total, total_expected, total_unexpected)