# InferringIntention/keyboard_and_mouse/test.py

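"""Evaluation script for the keyboard-and-mouse intention-prediction models.

Loads one trained ActionDemo2Predicate checkpoint per task (7 tasks), runs
every model on every task's held-out test split for each of the 5 users, and
writes per-sample class probabilities, the assumed intention, and the ground
truth to one CSV per (task model, user) pair under prediction/.

Example invocation (matching the default hyper-parameters):
    python test.py --model_type lstmlast --batch_size 32 --lr 0.1 --hidden_size 256
"""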
import pickle
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import shutil
import matplotlib.pyplot as plt
import argparse
from networks import ActionDemo2Predicate
from pathlib import Path
from termcolor import colored
import pandas as pd

print('torch version:', torch.__version__)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(DEVICE)
torch.manual_seed(256)
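
# Thin Dataset wrapper returning (features, label, action_id) triples, so the
# action id can be carried through evaluation into the output CSV.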
class test_dataset(Dataset):
    def __init__(self, x, label, action_id):
        self.x = x
        self.idx = action_id
        self.labels = label

    def __getitem__(self, index):
        x = self.x[index]
        label = self.labels[index]
        action_idx = self.idx[index]
        return x, label, action_idx

    def __len__(self):
        return len(self.labels)
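
# Run one model over the test loader: returns per-sample class probabilities
# (softmax outputs), ground-truth labels, and action ids, and prints the mean
# per-batch accuracy.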
def test_model(model, test_dataloader, DEVICE):
    model.to(DEVICE)
    model.eval()
    test_acc = []
    logits = []
    labels = []
    action_ids = []
    with torch.no_grad():
        for x, label, action_id in test_dataloader:
            # the default collate function already yields tensors, so .to()
            # suffices (torch.tensor() on a tensor would copy and warn)
            x = x.to(DEVICE)
            label = label.to(DEVICE)
            logps = model(x)
            logps = F.softmax(logps, 1)  # raw scores -> class probabilities
            logits.append(logps.cpu().numpy())
            labels.append(label.cpu().numpy())
            action_ids.append(action_id)
            argmax_Y = torch.max(logps, 1)[1].view(-1, 1)
            test_acc.append((label.float().view(-1, 1) == argmax_Y.float()).sum().item()
                            / len(label.float().view(-1, 1)) * 100)
    test_acc = np.mean(np.array(test_acc))
    print('test acc {:.4f}'.format(test_acc))
    logits = np.concatenate(logits, axis=0)
    labels = np.concatenate(labels, axis=0)
    action_ids = np.concatenate(action_ids, axis=0)
    return logits, labels, action_ids
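
# Checkpoint and prediction paths encode the hyper-parameters (model type,
# batch size, learning rate, hidden size), so runs with different settings
# do not overwrite each other's outputs.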
def main():
    # parse command-line parameters
    parser = argparse.ArgumentParser(description='')
    # store_true avoids the argparse pitfall where type=bool treats any
    # non-empty string (including 'False') as True
    parser.add_argument('--resume', action='store_true', help='resume training')
    parser.add_argument('--batch_size', type=int, default=32, help='batch size')
    parser.add_argument('--lr', type=float, default=1e-1, help='learning rate')
    parser.add_argument('--model_type', type=str, default='lstmlast', help='model type')
    parser.add_argument('--hidden_size', type=int, default=256, help='hidden_size')
    parser.add_argument('--epochs', type=int, default=100, help='training epochs')
    parser.add_argument('--dataset_path', type=str, default='dataset/strategy_dataset/', help='dataset path')
    parser.add_argument('--weight_decay', type=float, default=0.9, help='weight decay for Adam optimizer')
    parser.add_argument('--demo_hidden', type=int, default=512, help='demo_hidden')
    parser.add_argument('--dropout', type=float, default=0.5, help='dropout rate')
    parser.add_argument('--checkpoint', type=str, default='checkpoints/', help='checkpoints path')
    args = parser.parse_args()
    path = (args.checkpoint + args.model_type + '_bs_' + str(args.batch_size)
            + '_lr_' + str(args.lr) + '_hidden_size_' + str(args.hidden_size))

    # load one trained model per task
    models = []
    for i in range(7):  # 7 tasks
        net = ActionDemo2Predicate(args)
        model_path = path + '/task' + str(i) + '_checkpoint.ckpt'
        net.load(model_path)
        models.append(net)
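
    # Evaluate: for each user u, run all 7 task models on each task's test
    # split and collect probabilities, labels, action ids, and the assumed task.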
    for u in range(5):  # 5 users
        task_pred = []
        task_target = []
        task_act = []
        task_task_name = []
        for i in range(7):  # 7 tasks
            # read the pickled test split for task i
            with open(args.dataset_path + 'test_data_' + str(i) + '.pkl', 'rb') as f:
                data_x = pickle.load(f)
            with open(args.dataset_path + 'test_label_' + str(i) + '.pkl', 'rb') as f:
                data_y = pickle.load(f)
            with open(args.dataset_path + 'test_action_id_' + str(i) + '.pkl', 'rb') as f:
                act_idx = pickle.load(f)
            # select user u's split; labels are stored 1-based, shift to 0-based
            x = data_x[u]
            y = data_y[u]
            act = act_idx[u]
            test_set = test_dataset(np.array(x), np.array(y) - 1, np.array(act))
            # note: drop_last=True discards a final partial batch
            test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=4, drop_last=True)
            preds = []
            targets = []
            actions = []
            task_names = []
            for j in range(7):  # collect predictions from all 7 task models
                pred, target, action = test_model(models[j], test_loader, DEVICE)
                preds.append(pred)
                targets.append(target)
                actions.append(action)
                task_names.append(np.full(target.shape, i))  # assumed intention
            task_pred.append(preds)
            task_target.append(targets)
            task_act.append(actions)
            task_task_name.append(task_names)
        # for each model i, gather its predictions on all 7 tasks' data and
        # write one CSV per (model, user) pair
        for i in range(7):
            preds = []
            targets = []
            actions = []
            task_names = []
            for j in range(7):
                preds.append(task_pred[j][i])
                targets.append(task_target[j][i] + 1)  # shift ground truth back to 1-based
                actions.append(task_act[j][i])
                task_names.append(task_task_name[j][i])
            preds = np.concatenate(preds, axis=0)
            targets = np.concatenate(targets, axis=0)
            actions = np.concatenate(actions, axis=0)
            task_names = np.concatenate(task_names, axis=0)
            write_data = np.concatenate((np.reshape(actions, (-1, 1)), preds,
                                         np.reshape(task_names, (-1, 1)),
                                         np.reshape(targets, (-1, 1))), axis=1)
            output_path = ('prediction/task' + str(i) + '/' + args.model_type
                           + '_bs_' + str(args.batch_size) + '_lr_' + str(args.lr)
                           + '_hidden_size_' + str(args.hidden_size))
            Path(output_path).mkdir(parents=True, exist_ok=True)
            output_path = output_path + '/user' + str(u) + '_pred.csv'
            print(write_data.shape)
            head = ['action_id']
            for j in range(7):
                head.append('act' + str(j + 1))
            head.append('task_name')
            head.append('gt')
            pd.DataFrame(write_data).to_csv(output_path, header=head)

if __name__ == '__main__':
    main()