first commit
This commit is contained in:
commit
83b04e2133
109 changed files with 12081 additions and 0 deletions
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,306 @@
|
|||
import os
|
||||
import random
|
||||
import copy
|
||||
import json
|
||||
|
||||
|
||||
import numpy as np
|
||||
from termcolor import colored
|
||||
from glob import glob
|
||||
import pickle
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data.dataset import Dataset
|
||||
|
||||
|
||||
|
||||
################################
|
||||
# Demonstration
|
||||
################################
|
||||
|
||||
def get_dataset(args, task_name, train):
|
||||
train_data, test_data, new_test_data, train_action_gt, test_action_gt, new_test_action_gt, train_task_name, test_task_name, new_test_task_name, train_action_id, test_action_id, new_test_action_id, action_predicates, all_action, all_object, goal_objects, goal_targets, goal_predicates, max_goal_length, max_action_length, max_node_length, max_subgoal_length = gather_data(args, task_name)
|
||||
train_dset = demo_dset(args, train_data, train_action_gt, train_task_name, train_action_id, action_predicates, all_action, all_object, goal_objects, goal_targets, goal_predicates, max_goal_length, max_action_length, max_node_length, max_subgoal_length)
|
||||
test_dset = demo_dset(args, test_data, test_action_gt, test_task_name, test_action_id, action_predicates, all_action, all_object, goal_objects, goal_targets, goal_predicates, max_goal_length, max_action_length, max_node_length, max_subgoal_length)
|
||||
new_test_dset = demo_dset(args, new_test_data, new_test_action_gt, new_test_task_name, new_test_action_id, action_predicates, all_action, all_object, goal_objects, goal_targets, goal_predicates, max_goal_length, max_action_length, max_node_length, max_subgoal_length)
|
||||
return train_dset, test_dset, new_test_dset
|
||||
|
||||
|
||||
def collate_fn(data_list):
|
||||
graph_data = [data[0] for data in data_list]
|
||||
batch_goal_index = [data[1] for data in data_list]
|
||||
batch_valid_action_with_walk_index = [data[2] for data in data_list]
|
||||
|
||||
if len(graph_data[0])==3:
|
||||
batch_graph_length = [d[0] for d in graph_data]
|
||||
batch_graph_input = [d[1] for d in graph_data]
|
||||
batch_file_name = [d[2] for d in graph_data]
|
||||
else:
|
||||
batch_graph_length = [d[0] for d in graph_data]
|
||||
batch_file_name = [d[1] for d in graph_data]
|
||||
|
||||
if len(graph_data[0])==3:
|
||||
batch_demo_data = (
|
||||
np.arange(len(batch_graph_length)),
|
||||
batch_graph_length,
|
||||
batch_graph_input,
|
||||
batch_file_name
|
||||
)
|
||||
else:
|
||||
batch_demo_data = (
|
||||
np.arange(len(batch_graph_length)),
|
||||
batch_graph_length,
|
||||
batch_file_name
|
||||
)
|
||||
|
||||
return batch_demo_data, batch_goal_index, batch_valid_action_with_walk_index
|
||||
|
||||
|
||||
def to_cuda_fn(data):
|
||||
batch_demo_data, batch_goal_index, batch_valid_action_with_walk_index = data
|
||||
|
||||
if len(batch_demo_data)==4:
|
||||
batch_demo_index, batch_graph_length, batch_graph_input, batch_file_name = batch_demo_data
|
||||
|
||||
batch_graph_input_class_objects = [[torch.tensor(j['class_objects']).cuda() for j in i] for i in batch_graph_input]
|
||||
batch_graph_input_object_coords = [[torch.tensor(j['object_coords']).cuda() for j in i] for i in batch_graph_input]
|
||||
batch_graph_input_states_objects = [[torch.tensor(j['states_objects']).cuda() for j in i] for i in batch_graph_input]
|
||||
batch_graph_input_mask_object = [[torch.tensor(j['mask_object']).cuda() for j in i] for i in batch_graph_input]
|
||||
|
||||
batch_graph_input = { 'class_objects': batch_graph_input_class_objects,
|
||||
'object_coords': batch_graph_input_object_coords,
|
||||
'states_objects': batch_graph_input_states_objects,
|
||||
'mask_object': batch_graph_input_mask_object}
|
||||
|
||||
else:
|
||||
batch_demo_index, batch_graph_length, batch_file_name = batch_demo_data
|
||||
|
||||
|
||||
batch_goal_index = [torch.tensor(i).cuda().long() for i in batch_goal_index]
|
||||
batch_valid_action_with_walk_index = [torch.tensor(i).cuda().long() for i in batch_valid_action_with_walk_index]
|
||||
|
||||
if len(batch_demo_data)==4:
|
||||
batch_demo_data = (
|
||||
batch_demo_index,
|
||||
batch_graph_length,
|
||||
batch_graph_input,
|
||||
batch_file_name
|
||||
)
|
||||
else:
|
||||
batch_demo_data = (
|
||||
batch_demo_index,
|
||||
batch_graph_length,
|
||||
batch_file_name
|
||||
)
|
||||
|
||||
return batch_demo_data, batch_goal_index, batch_valid_action_with_walk_index
|
||||
|
||||
def one_hot(states, graph_node_states):
|
||||
one_hot = np.zeros(len(graph_node_states))
|
||||
for state in states:
|
||||
one_hot[graph_node_states[state]] = 1
|
||||
return one_hot
|
||||
|
||||
def gather_data(args, task):
|
||||
meta_data_path = 'dataset/watch_data/metadata.json'
|
||||
data_path_new_test = 'dataset/watch_data/action/new_test_task_' + task + '_strategy.json'
|
||||
data_path_test = 'dataset/watch_data/action/test_task_' + task + '_strategy.json'
|
||||
data_path_train = 'dataset/watch_data/action/train_task_' + task + '_strategy.json'
|
||||
|
||||
|
||||
with open(data_path_new_test, 'r') as f:
|
||||
new_test_data = json.load(f)
|
||||
|
||||
with open(data_path_test, 'r') as f:
|
||||
test_data = json.load(f)
|
||||
|
||||
#if os.path.exists(data_path):
|
||||
#print('load gather_data, this may take a while...', data_path)
|
||||
with open(data_path_train, 'r') as f:
|
||||
data = json.load(f)
|
||||
|
||||
# temporarily set data_path to test data to test training
|
||||
train_data = data
|
||||
|
||||
with open(meta_data_path, 'r') as f:
|
||||
data = json.load(f)
|
||||
|
||||
action_predicates = data['action_predicates']
|
||||
all_action = data['all_action']
|
||||
all_object = data['all_object']
|
||||
|
||||
goal_objects = data['goal_objects']
|
||||
goal_targets = data['goal_targets']
|
||||
goal_predicates = data['goal_predicates']
|
||||
|
||||
graph_class_names = data['graph_class_names']
|
||||
graph_node_states = data['graph_node_states']
|
||||
|
||||
max_goal_length = data['max_goal_length']
|
||||
max_action_length = data['max_action_length']
|
||||
max_node_length = data['max_node_length']
|
||||
|
||||
|
||||
## -----------------------------------------------------------------------------
|
||||
## add action, goal, and graph node index
|
||||
## -----------------------------------------------------------------------------
|
||||
max_subgoal_length = 1
|
||||
|
||||
train_task_name = np.array(train_data['task_name'])
|
||||
|
||||
test_task_name = np.array(test_data['task_name'])
|
||||
|
||||
new_test_task_name = np.array(new_test_data['task_name'])
|
||||
|
||||
|
||||
for traintest in [train_data, test_data, new_test_data]:
|
||||
for data in traintest['goal']:
|
||||
## goal
|
||||
goal_index = []
|
||||
subgoal_dict = {}
|
||||
for subgoal in data:
|
||||
goal_index.append(goal_predicates[subgoal])
|
||||
|
||||
if goal_predicates[subgoal] not in subgoal_dict:
|
||||
subgoal_dict[goal_predicates[subgoal]] = 1
|
||||
else:
|
||||
subgoal_dict[goal_predicates[subgoal]] += 1
|
||||
|
||||
this_max_subgoal_length = np.max(list(subgoal_dict.values()))
|
||||
if this_max_subgoal_length>max_subgoal_length:
|
||||
max_subgoal_length = this_max_subgoal_length
|
||||
|
||||
|
||||
goal_index.sort()
|
||||
if len(goal_index) < max_goal_length:
|
||||
for i in range(max_goal_length-len(goal_index)):
|
||||
goal_index.append(0)
|
||||
|
||||
## action gt
|
||||
for i in range(len(traintest['action_gt'])): # len(traintest['action_gt'])
|
||||
action_name = traintest['action_gt'][i][0].split(' ')[0]
|
||||
object_name = traintest['action_gt'][i][0].split(' ')[1]
|
||||
predicate_name = ' '.join([action_name, object_name])
|
||||
traintest['action_gt'][i] = action_predicates[predicate_name]
|
||||
|
||||
## action
|
||||
valid_action_with_walk_index = []
|
||||
for i in range(len(traintest['valid_action_with_walks'])):
|
||||
actions_index = []
|
||||
for actions in traintest['valid_action_with_walks'][i]:
|
||||
if actions!='None':
|
||||
action_name = actions[0].split(' ')[0]
|
||||
object_name = actions[0].split(' ')[1]
|
||||
predicate_name = ' '.join([action_name, object_name])
|
||||
else:
|
||||
predicate_name = actions
|
||||
actions_index.append(action_predicates[predicate_name])
|
||||
|
||||
traintest['valid_action_with_walks'][i] = actions_index
|
||||
|
||||
print(len(train_data['action_gt']),np.array(train_data['action_gt']), type(train_data['action_gt']))
|
||||
|
||||
|
||||
train_action_gt = np.array(train_data['action_gt'])
|
||||
test_action_gt = np.array(test_data['action_gt'])
|
||||
new_test_action_gt = np.array(new_test_data['action_gt'])
|
||||
|
||||
train_action_id = np.array(train_data['action_id'])
|
||||
test_action_id = np.array(test_data['action_id'])
|
||||
new_test_action_id = np.array(new_test_data['action_id'])
|
||||
|
||||
train_data = np.array(train_data['valid_action_with_walks'])
|
||||
test_data = np.array(test_data['valid_action_with_walks'])
|
||||
new_test_data = np.array(new_test_data['valid_action_with_walks'])
|
||||
|
||||
|
||||
print('--------------------------------------------------------------------------------')
|
||||
print('train_data', len(train_data), train_data.shape)
|
||||
print('test_data', len(test_data), train_data.shape)
|
||||
print('new_test_data', len(new_test_data), train_data.shape)
|
||||
print('--------------------------------------------------------------------------------')
|
||||
print('train_gt', len(train_action_gt), train_action_gt.shape)
|
||||
print('test_gt', len(test_action_gt), test_action_gt.shape)
|
||||
print('new_test_gt', len(new_test_action_gt), new_test_action_gt.shape)
|
||||
print('--------------------------------------------------------------------------------')
|
||||
print('train_task_name', len(train_task_name), train_task_name.shape)
|
||||
print('test_task_name', len(test_task_name), test_task_name.shape)
|
||||
print('new_test_task_name', len(new_test_task_name), new_test_task_name.shape)
|
||||
print('--------------------------------------------------------------------------------')
|
||||
|
||||
return train_data, test_data, new_test_data, train_action_gt, test_action_gt, new_test_action_gt, train_task_name, test_task_name, new_test_task_name, train_action_id, test_action_id, new_test_action_id, action_predicates, all_action, all_object, goal_objects, goal_targets, goal_predicates, max_goal_length, max_action_length, max_node_length, max_subgoal_length
|
||||
|
||||
class demo_dset(Dataset):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
args,
|
||||
data,
|
||||
gt,
|
||||
task_name,
|
||||
action_id,
|
||||
#action_predicates, all_action, all_object, goal_objects, goal_targets, goal_predicates, graph_class_names, graph_node_states, max_goal_length, max_action_length, max_node_length, max_subgoal_length):
|
||||
action_predicates, all_action, all_object, goal_objects, goal_targets, goal_predicates, max_goal_length, max_action_length, max_node_length, max_subgoal_length):
|
||||
|
||||
|
||||
self.inputtype = args.inputtype
|
||||
self.multi_classifier = args.multi_classifier
|
||||
self.data = data
|
||||
self.gt = gt
|
||||
self.task_name = task_name
|
||||
self.action_id = action_id
|
||||
|
||||
self.max_action_len = 82
|
||||
self.action_predicates = action_predicates
|
||||
self.all_action = all_action
|
||||
self.all_object = all_object
|
||||
|
||||
self.goal_objects = goal_objects
|
||||
self.goal_targets = goal_targets
|
||||
self.goal_predicates = goal_predicates
|
||||
self.num_goal_predicates = len(goal_predicates)
|
||||
|
||||
self.max_goal_length = max_goal_length
|
||||
self.max_action_length = max_action_length
|
||||
self.max_subgoal_length = max_subgoal_length
|
||||
|
||||
if self.inputtype=='graphinput':
|
||||
self.graph_class_names = graph_class_names
|
||||
self.graph_node_states = graph_node_states
|
||||
self.num_node_states = len(graph_node_states)
|
||||
self.max_node_length = max_node_length
|
||||
|
||||
|
||||
print('-----------------------------------------------------------------------------')
|
||||
print('num_goal_predicates', self.num_goal_predicates)
|
||||
print('max_goal_length', self.max_goal_length)
|
||||
print('max_action_length', max_action_length)
|
||||
|
||||
if self.inputtype=='graphinput':
|
||||
print('num_node_states', self.num_node_states)
|
||||
print('max_node_length', max_node_length)
|
||||
print('-----------------------------------------------------------------------------')
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
data = self.data[index]
|
||||
gt = self.gt[index]
|
||||
task_name = self.task_name[index]
|
||||
action_id = self.action_id[index]
|
||||
|
||||
return data, gt, task_name, action_id
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def _preprocess_one_data(self, data):
|
||||
action_gt = data['action_gt']
|
||||
valid_action_with_walk_index = data['valid_action_with_walks']
|
||||
action_length = len(valid_action_with_walk_index)
|
||||
inputdata = (action_length, 'actions')
|
||||
data = [inputdata, action_gt, valid_action_with_walk_index]
|
||||
return data
|
||||
|
||||
|
297
watch_and_help/watch_strategy_full/predicate/utils.py
Normal file
297
watch_and_help/watch_strategy_full/predicate/utils.py
Normal file
|
@ -0,0 +1,297 @@
|
|||
import argparse
|
||||
import random
|
||||
import time
|
||||
import os
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
|
||||
import torch
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from helper import to_cpu, average_over_list, writer_helper
|
||||
import csv
|
||||
import pathlib
|
||||
|
||||
def grab_args():
|
||||
|
||||
def str2bool(v):
|
||||
return v.lower() == 'true'
|
||||
|
||||
parser = argparse.ArgumentParser(description='')
|
||||
parser.add_argument('--seed', type=int, default=123, help='random seed')
|
||||
parser.add_argument('--verbose', type=str2bool, default=False)
|
||||
parser.add_argument('--debug', type=str2bool, default=False)
|
||||
parser.add_argument('--prefix', type=str, default='test')
|
||||
parser.add_argument('--checkpoint', type=str, default=None)
|
||||
parser.add_argument('--n_workers', type=int, default=0)
|
||||
parser.add_argument('--train_iters', type=int, default=2e4)
|
||||
parser.add_argument('--inputtype', type=str, default='actioninput')
|
||||
parser.add_argument('--resume', type=str, default='')
|
||||
parser.add_argument('--dropout', type=float, default=0)
|
||||
parser.add_argument('--inference', type=int, default=0)
|
||||
parser.add_argument('--single', type=int, default=0)
|
||||
parser.add_argument('--loss_type', type=str, default='regu') #regu or ce
|
||||
parser.add_argument('--testset', type=str, default='test') # test: test set 1, new_test: test set 2
|
||||
|
||||
# model config
|
||||
parser.add_argument(
|
||||
'--model_type',
|
||||
type=str,
|
||||
default='max')
|
||||
parser.add_argument('--embedding_dim', type=int, default=100)
|
||||
parser.add_argument('--predicate_hidden', type=int, default=128)
|
||||
parser.add_argument('--demo_hidden', type=int, default=128)
|
||||
parser.add_argument('--multi_classifier', type=int, default=0)
|
||||
parser.add_argument('--transformer_nhead', type=int, default=2)
|
||||
|
||||
# train config
|
||||
parser.add_argument(
|
||||
'--gpu_id',
|
||||
metavar='N',
|
||||
type=str,
|
||||
nargs='+',
|
||||
help='specify the gpu id')
|
||||
parser.add_argument('--batch_size', type=int, default=2)
|
||||
parser.add_argument('--model_lr_rate', type=float, default=3e-4)
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def setup(train):
|
||||
def _basic_setting(args):
|
||||
|
||||
# set seed
|
||||
torch.manual_seed(args.seed)
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
|
||||
if args.gpu_id is None:
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = ''
|
||||
args.__dict__.update({'cuda': False})
|
||||
else:
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = ', '.join(args.gpu_id)
|
||||
args.__dict__.update({'cuda': True})
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
|
||||
if args.debug:
|
||||
args.verbose = True
|
||||
|
||||
def _basic_checking(args):
|
||||
pass
|
||||
|
||||
def _create_checkpoint_dir(args):
|
||||
# setup checkpoint_dir
|
||||
if args.debug:
|
||||
checkpoint_dir = 'debug'
|
||||
elif train:
|
||||
checkpoint_dir = 'checkpoint_dir'
|
||||
else:
|
||||
checkpoint_dir = 'testing_dir'
|
||||
|
||||
checkpoint_dir = os.path.join(checkpoint_dir, 'demo2predicate')
|
||||
|
||||
args_dict = args.__dict__
|
||||
keys = sorted(args_dict)
|
||||
prefix = ['{}-{}'.format(k, args_dict[k]) for k in keys]
|
||||
prefix.remove('debug-{}'.format(args.debug))
|
||||
prefix.remove('checkpoint-{}'.format(args.checkpoint))
|
||||
prefix.remove('gpu_id-{}'.format(args.gpu_id))
|
||||
|
||||
checkpoint_dir = os.path.join(checkpoint_dir, *prefix)
|
||||
|
||||
checkpoint_dir += '/{}'.format(time.strftime("%Y%m%d-%H%M%S"))
|
||||
|
||||
return checkpoint_dir
|
||||
|
||||
def _make_dirs(checkpoint_dir, tfboard_dir):
|
||||
if not os.path.exists(checkpoint_dir):
|
||||
os.makedirs(checkpoint_dir)
|
||||
if not os.path.exists(tfboard_dir):
|
||||
os.makedirs(tfboard_dir)
|
||||
|
||||
def _print_args(args):
|
||||
args_str = ''
|
||||
with open(os.path.join(checkpoint_dir, 'args.txt'), 'w') as f:
|
||||
for k, v in args.__dict__.items():
|
||||
s = '{}: {}'.format(k, v)
|
||||
args_str += '{}\n'.format(s)
|
||||
print(s)
|
||||
f.write(s + '\n')
|
||||
print("All the data will be saved in", checkpoint_dir)
|
||||
return args_str
|
||||
|
||||
args = grab_args()
|
||||
_basic_setting(args)
|
||||
_basic_checking(args)
|
||||
|
||||
checkpoint_dir = args.checkpoint
|
||||
tfboard_dir = os.path.join(checkpoint_dir, 'tfboard')
|
||||
|
||||
_make_dirs(checkpoint_dir, tfboard_dir)
|
||||
args_str = _print_args(args)
|
||||
|
||||
writer = SummaryWriter(tfboard_dir)
|
||||
writer.add_text('args', args_str, 0)
|
||||
writer = writer_helper(writer)
|
||||
|
||||
model_config = {
|
||||
"model_type": args.model_type,
|
||||
"embedding_dim": args.embedding_dim,
|
||||
"predicate_hidden": args.predicate_hidden,
|
||||
}
|
||||
model_config.update({"demo_hidden": args.demo_hidden})
|
||||
|
||||
return args, checkpoint_dir, writer, model_config
|
||||
|
||||
|
||||
def summary(
|
||||
args,
|
||||
writer,
|
||||
info,
|
||||
train_args,
|
||||
model,
|
||||
test_loader,
|
||||
postfix,
|
||||
fps=None):
|
||||
|
||||
if postfix == 'train':
|
||||
model.write_summary(writer, info, postfix=postfix)
|
||||
elif postfix == 'val':
|
||||
info = summary_eval(
|
||||
model,
|
||||
test_loader,
|
||||
test_loader.dataset)
|
||||
model.write_summary(writer, info, postfix=postfix)
|
||||
elif postfix == 'test':
|
||||
info = summary_eval(
|
||||
model,
|
||||
test_loader,
|
||||
test_loader.dataset)
|
||||
|
||||
model.write_summary(writer, info, postfix=postfix)
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
if fps:
|
||||
writer.scalar_summary('General/fps', fps)
|
||||
|
||||
return info
|
||||
|
||||
def summary_eval(
|
||||
model,
|
||||
loader,
|
||||
dset):
|
||||
|
||||
model.eval()
|
||||
print(len(loader))
|
||||
with torch.no_grad():
|
||||
|
||||
loss_list = []
|
||||
top1_list = []
|
||||
iter = 0
|
||||
action_id_list = []
|
||||
|
||||
prob = []
|
||||
target = []
|
||||
file_name = []
|
||||
for batch_data in loader:
|
||||
loss, info = model(batch_data)
|
||||
loss_list.append(loss.cpu().item())
|
||||
top1_list.append(info['top1'])
|
||||
|
||||
prob.append(info['prob'])
|
||||
target.append(info['target'])
|
||||
file_name.append(info['task_name'])
|
||||
action_id_list.append(info['action_id'])
|
||||
|
||||
if iter%10==0:
|
||||
print('testing %d / %d: loss %.4f: acc %.4f' % (iter, len(loader), loss, info['top1']))
|
||||
|
||||
iter += 1
|
||||
|
||||
info = {"loss": sum(loss_list)/ len(loss_list), "top1": sum(top1_list)/ len(top1_list), "prob": prob, "target": target, "task_name": file_name, "action_id": action_id_list}
|
||||
return info
|
||||
|
||||
def write_prob(info, args):
|
||||
temp_prob_list = []
|
||||
temp_action_id_list = []
|
||||
temp_task_name_list = []
|
||||
temp_target_list = []
|
||||
for i in range(len(info['prob'])):
|
||||
for j in range(len(info['prob'][i])):
|
||||
temp_prob_list.append(info['prob'][i][j])
|
||||
|
||||
for i in range(len(info['action_id'])):
|
||||
for j in range(len(info['action_id'][i])):
|
||||
temp_action_id_list.append(info['action_id'][i][j])
|
||||
|
||||
for i in range(len(info['task_name'])):
|
||||
for j in range(len(info['task_name'][i])):
|
||||
temp_task_name_list.append(info['task_name'][i][j])
|
||||
|
||||
for i in range(len(info['target'])):
|
||||
for j in range(len(info['target'][i])):
|
||||
temp_target_list.append(info['target'][i][j])
|
||||
|
||||
prob = np.array(temp_prob_list)
|
||||
action_id = np.array(temp_action_id_list)
|
||||
task_name = np.array(temp_task_name_list)
|
||||
target = np.array(temp_target_list)
|
||||
|
||||
write_data = np.concatenate((np.reshape(action_id, (-1, 1)), prob, np.reshape(task_name, (-1, 1)), np.reshape(target, (-1, 1))), axis=1)
|
||||
import pandas as pd
|
||||
head = []
|
||||
for j in range(79):
|
||||
head.append('act'+str(j+1))
|
||||
head.append('task_name')
|
||||
head.append('gt')
|
||||
head.insert(0,'action_id')
|
||||
pd.DataFrame(write_data).to_csv("prediction/" + args.model_type + "/" + task_name[0] + "_full.csv", header=head)
|
||||
|
||||
def write_prob_strategy(info, model_name, args):
|
||||
temp_prob_list = []
|
||||
temp_action_id_list = []
|
||||
temp_task_name_list = []
|
||||
temp_target_list = []
|
||||
for i in range(len(info['prob'])):
|
||||
for j in range(len(info['prob'][i])):
|
||||
temp_prob_list.append(info['prob'][i][j])
|
||||
|
||||
for i in range(len(info['action_id'])):
|
||||
for j in range(len(info['action_id'][i])):
|
||||
temp_action_id_list.append(info['action_id'][i][j])
|
||||
|
||||
for i in range(len(info['task_name'])):
|
||||
for j in range(len(info['task_name'][i])):
|
||||
temp_task_name_list.append(info['task_name'][i][j])
|
||||
|
||||
for i in range(len(info['target'])):
|
||||
for j in range(len(info['target'][i])):
|
||||
temp_target_list.append(info['target'][i][j])
|
||||
|
||||
prob = np.array(temp_prob_list)
|
||||
action_id = np.array(temp_action_id_list)
|
||||
task_name = np.array(temp_task_name_list)
|
||||
target = np.array(temp_target_list)
|
||||
|
||||
|
||||
write_data = np.concatenate((np.reshape(action_id, (-1, 1)), prob, np.reshape(task_name, (-1, 1)), np.reshape(target, (-1, 1))), axis=1)
|
||||
import pandas as pd
|
||||
head = []
|
||||
for j in range(79):
|
||||
head.append('act'+str(j+1))
|
||||
head.append('task_name')
|
||||
head.append('gt')
|
||||
head.insert(0,'action_id')
|
||||
path = pathlib.Path("stan/prediction/" + args.testset + "/" + args.model_type)
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
pd.DataFrame(write_data).to_csv("stan/prediction/" + args.testset + "/" + args.model_type + "/model_" + model_name + '_strategy_' + task_name[0] + ".csv", header=head)
|
||||
|
||||
def save(args, i, checkpoint_dir, model, task):
|
||||
save_path = '{}/demo2predicate-{}{}.ckpt'.format(checkpoint_dir, 'best_model_', task)
|
||||
model.save(save_path, True)
|
||||
|
||||
def save_checkpoint(args, i, checkpoint_dir, model, task):
|
||||
save_path = '{}/demo2predicate-{}{}.ckpt'.format(checkpoint_dir, 'checkpoint_model_', task)
|
||||
model.save(save_path, True)
|
Loading…
Add table
Add a link
Reference in a new issue