159 lines
5.4 KiB
Python
159 lines
5.4 KiB
Python
|
import pickle
|
||
|
import numpy as np
|
||
|
from torch.utils.data import Dataset, DataLoader
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
import torch.optim as optim
|
||
|
import torch.nn.functional as F
|
||
|
import shutil
|
||
|
import matplotlib.pyplot as plt
|
||
|
import argparse
|
||
|
from networks import ActionDemo2Predicate
|
||
|
from pathlib import Path
|
||
|
from termcolor import colored
|
||
|
import pandas as pd
|
||
|
|
||
|
|
||
|
# Report the installed torch version before anything else runs.
print('torch version: ', torch.__version__)

# Use CUDA when torch reports it is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(DEVICE)

# Seed torch's RNG with a fixed value so repeated runs start from the same state.
torch.manual_seed(256)
|
||
|
|
||
|
class test_dataset(Dataset):
    """Map-style dataset that yields ``(x, label, action_id)`` triples.

    The three sequences are indexed in parallel with the same index; the
    dataset length is taken from ``label``.
    """

    def __init__(self, x, label, action_id):
        self.labels = label
        self.idx = action_id
        self.x = x

    def __getitem__(self, index):
        return self.x[index], self.labels[index], self.idx[index]

    def __len__(self):
        return len(self.labels)
|
||
|
|
||
|
def test_model(model, test_dataloader, DEVICE):
    """Run ``model`` over ``test_dataloader`` and collect its class probabilities.

    Args:
        model: torch module; it is moved to ``DEVICE`` and switched to eval mode.
        test_dataloader: iterable of ``(x, label, action_id)`` batches.
        DEVICE: torch device used for inference.

    Returns:
        ``(probs, labels, action_ids)`` as numpy arrays concatenated over all
        batches, where ``probs`` is the softmax over dim 1 of the model output.

    Raises:
        ValueError: if the loader yields no batches (e.g. fewer samples than one
            batch when the loader uses ``drop_last=True``).

    Also prints the overall accuracy (argmax of ``probs`` vs. ``label``) in percent.
    """
    model.to(DEVICE)
    model.eval()

    probs_batches = []
    label_batches = []
    action_batches = []
    n_correct = 0
    n_total = 0

    with torch.no_grad():
        for x, label, action_id in test_dataloader:
            # as_tensor avoids the copy (and UserWarning) torch.tensor() makes for tensor inputs.
            x = torch.as_tensor(x).to(DEVICE)
            label = torch.as_tensor(label).to(DEVICE)

            probs = F.softmax(model(x), 1)
            pred = probs.argmax(dim=1)

            # Accumulate counts so accuracy is weighted by sample, not averaged per batch
            # (a per-batch mean is wrong whenever batch sizes differ).
            n_correct += (pred.reshape(-1) == label.reshape(-1)).sum().item()
            n_total += label.numel()

            probs_batches.append(probs.cpu().numpy())
            label_batches.append(label.cpu().numpy())
            action_batches.append(np.asarray(action_id))

    if not probs_batches:
        raise ValueError('test_model: the dataloader yielded no batches; '
                         'is the dataset smaller than one batch with drop_last=True?')

    test_acc = 100.0 * n_correct / n_total if n_total else float('nan')
    print('test acc {:.4f}'.format(test_acc))

    return (np.concatenate(probs_batches, axis=0),
            np.concatenate(label_batches, axis=0),
            np.concatenate(action_batches, axis=0))


# The name matches pytest's default ``test*`` pattern; this is a helper, not a test.
test_model.__test__ = False
|
||
|
|
||
|
_NUM_TASKS = 7  # one checkpoint and one test split per task
_NUM_USERS = 5  # per-user entries read from each test pickle


def _str2bool(value):
    """Parse a command-line boolean such as 'True', 'false', '1', '0', 'yes' or 'no'.

    ``type=bool`` is an argparse pitfall: ``bool('False')`` is ``True``, so any
    non-empty value (including ``--resume False``) would enable the flag.
    """
    lowered = value.strip().lower()
    if lowered in ('true', '1', 'yes', 'y'):
        return True
    if lowered in ('false', '0', 'no', 'n', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


def _parse_args():
    """Parse the command-line options (the namespace is also passed to ActionDemo2Predicate)."""
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--resume', type=_str2bool, default=False, help='resume training')
    parser.add_argument('--batch_size', type=int, default=32, help='batch size')
    parser.add_argument('--lr', type=float, default=1e-1, help='learning rate')
    parser.add_argument('--model_type', type=str, default='lstmlast', help='model type')
    parser.add_argument('--hidden_size', type=int, default=256, help='hidden_size')
    parser.add_argument('--epochs', type=int, default=100, help='training epoch')
    parser.add_argument('--dataset_path', type=str, default='dataset/strategy_dataset/', help='dataset path')
    parser.add_argument('--weight_decay', type=float, default=0.9, help='weight decay for Adam optimizer')
    parser.add_argument('--demo_hidden', type=int, default=512, help='demo_hidden')
    parser.add_argument('--dropout', type=float, default=0.5, help='dropout rate')
    parser.add_argument('--checkpoint', type=str, default='checkpoints/', help='checkpoints path')
    return parser.parse_args()


def _run_name(args):
    """Return the run tag '<model_type>_bs_<batch>_lr_<lr>_hidden_size_<hidden>'."""
    return (args.model_type + '_bs_' + str(args.batch_size) + '_lr_' + str(args.lr)
            + '_hidden_size_' + str(args.hidden_size))


def _load_models(args, checkpoint_dir):
    """Build one ActionDemo2Predicate per task and load its task<i>_checkpoint.ckpt."""
    models = []
    for task in range(_NUM_TASKS):
        net = ActionDemo2Predicate(args)
        net.load(checkpoint_dir + '/task' + str(task) + '_checkpoint.ckpt')
        models.append(net)
    return models


def _load_pickle(path):
    """Unpickle one file.

    NOTE: pickle can execute arbitrary code; only load trusted dataset files.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


def _load_test_split(dataset_path, task):
    """Load one task's (data, labels, action ids) test pickles; each is indexed by user."""
    suffix = str(task) + '.pkl'
    return (_load_pickle(dataset_path + 'test_data_' + suffix),
            _load_pickle(dataset_path + 'test_label_' + suffix),
            _load_pickle(dataset_path + 'test_action_id_' + suffix))


def _write_predictions(model_idx, user, run_name, outputs):
    """Write one model's predictions for one user to a CSV file.

    ``outputs`` is a list of ``(probs, gt, action_ids, task_name)`` tuples, one
    per test split, stacked row-wise in list order. Columns are action_id,
    act1..actK (K = number of probability columns), task_name, gt, preceded by
    the DataFrame index.
    """
    probs, targets, actions, task_names = (np.concatenate(part, axis=0) for part in zip(*outputs))
    write_data = np.concatenate((np.reshape(actions, (-1, 1)), probs,
                                 np.reshape(task_names, (-1, 1)), np.reshape(targets, (-1, 1))), axis=1)

    output_dir = 'prediction/' + 'task' + str(model_idx) + '/' + run_name
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    output_path = output_dir + '/user' + str(user) + '_pred.csv'
    print(write_data.shape)

    # Derive the probability columns from the data instead of assuming 7 classes.
    head = (['action_id'] + ['act' + str(k + 1) for k in range(probs.shape[1])]
            + ['task_name', 'gt'])
    pd.DataFrame(write_data).to_csv(output_path, header=head)


def main():
    """Evaluate every task model on every task's test split, per user, and save CSVs.

    For each user ``u`` and model ``m`` this writes
    ``prediction/task<m>/<run_name>/user<u>_pred.csv`` with model ``m``'s class
    probabilities on the test splits of all tasks, in task order.
    """
    args = _parse_args()
    run_name = _run_name(args)
    models = _load_models(args, args.checkpoint + run_name)

    # Load every task's test pickles once instead of re-reading them for each user.
    splits = [_load_test_split(args.dataset_path, task) for task in range(_NUM_TASKS)]

    for u in range(_NUM_USERS):
        # per_model[m]: model m's outputs on each task's test split, in task order.
        per_model = [[] for _ in models]
        for i, (data_x, data_y, act_idx) in enumerate(splits):
            # Labels are shifted by -1 for the model and shifted back (+1) for the 'gt' column.
            test_set = test_dataset(np.array(data_x[u]), np.array(data_y[u]) - 1, np.array(act_idx[u]))
            # NOTE(review): drop_last=True discards the final partial batch, so those samples
            # never reach the CSV. Kept in case the model assumes a fixed batch size — confirm.
            test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False,
                                     num_workers=4, drop_last=True)
            for j, model in enumerate(models):
                pred, target, action = test_model(model, test_loader, DEVICE)
                # task_name column = index i of the test split (the assumed intention).
                per_model[j].append((pred, target + 1, action, np.full(target.shape, i)))

        for j, outputs in enumerate(per_model):
            _write_predictions(j, u, run_name, outputs)
|
||
|
|
||
|
# Script entry point: run the evaluation only when executed directly, not on import.
if __name__ == "__main__":
    main()
|