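"""Training and evaluation utilities for the demo2predicate models:
command-line argument parsing, experiment setup (seeding, device selection,
checkpoint directories, TensorBoard), evaluation summaries, prediction CSV
export, and model checkpointing."""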
import argparse
import csv
import json
import os
import pathlib
import random
import time

import numpy as np
import pandas as pd
import torch
from torch.utils.tensorboard import SummaryWriter

from helper import to_cpu, average_over_list, writer_helper


def grab_args():
    """Parse command-line arguments for training and evaluation."""

    def str2bool(v):
        return v.lower() == 'true'

    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--seed', type=int, default=123, help='random seed')
    parser.add_argument('--verbose', type=str2bool, default=False)
    parser.add_argument('--debug', type=str2bool, default=False)
    parser.add_argument('--prefix', type=str, default='test')
    parser.add_argument('--checkpoint', type=str, default=None)
    parser.add_argument('--n_workers', type=int, default=0)
    parser.add_argument('--train_iters', type=int, default=20000)
    parser.add_argument('--inputtype', type=str, default='actioninput')
    parser.add_argument('--resume', type=str, default='')
    parser.add_argument('--dropout', type=float, default=0)
    parser.add_argument('--inference', type=int, default=0)
    parser.add_argument('--single', type=int, default=0)
    parser.add_argument('--loss_type', type=str, default='regu')  # 'regu' or 'ce'
    parser.add_argument('--testset', type=str, default='test')  # 'test': test set 1, 'new_test': test set 2

    # model config
    parser.add_argument('--model_type', type=str, default='max')
    parser.add_argument('--embedding_dim', type=int, default=100)
    parser.add_argument('--predicate_hidden', type=int, default=128)
    parser.add_argument('--demo_hidden', type=int, default=128)
    parser.add_argument('--multi_classifier', type=int, default=0)
    parser.add_argument('--transformer_nhead', type=int, default=2)

    # train config
    parser.add_argument(
        '--gpu_id',
        metavar='N',
        type=str,
        nargs='+',
        help='specify the gpu id')
    parser.add_argument('--batch_size', type=int, default=2)
    parser.add_argument('--model_lr_rate', type=float, default=3e-4)

    args = parser.parse_args()
    return args

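# A minimal usage sketch of the parser above. The script name and flag
# values are illustrative assumptions, not taken from the original repo:
#
#   python train.py --gpu_id 0 --model_type max --batch_size 2 \
#       --checkpoint checkpoints/run1

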
def setup(train):
    """Seed all RNGs, configure devices, create output directories, and
    build the TensorBoard writer and model config."""

    def _basic_setting(args):
        # set seed for reproducibility
        torch.manual_seed(args.seed)
        random.seed(args.seed)
        np.random.seed(args.seed)

        if args.gpu_id is None:
            os.environ['CUDA_VISIBLE_DEVICES'] = ''
            args.__dict__.update({'cuda': False})
        else:
            os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(args.gpu_id)
            args.__dict__.update({'cuda': True})
            torch.cuda.manual_seed_all(args.seed)

        if args.debug:
            args.verbose = True

    def _basic_checking(args):
        pass

    def _create_checkpoint_dir(args):
        # build a checkpoint directory name from the sorted args, excluding
        # flags that should not affect the run's identity
        if args.debug:
            checkpoint_dir = 'debug'
        elif train:
            checkpoint_dir = 'checkpoint_dir'
        else:
            checkpoint_dir = 'testing_dir'

        checkpoint_dir = os.path.join(checkpoint_dir, 'demo2predicate')

        args_dict = args.__dict__
        keys = sorted(args_dict)
        prefix = ['{}-{}'.format(k, args_dict[k]) for k in keys]
        prefix.remove('debug-{}'.format(args.debug))
        prefix.remove('checkpoint-{}'.format(args.checkpoint))
        prefix.remove('gpu_id-{}'.format(args.gpu_id))

        checkpoint_dir = os.path.join(checkpoint_dir, *prefix)
        checkpoint_dir += '/{}'.format(time.strftime("%Y%m%d-%H%M%S"))

        return checkpoint_dir

    def _make_dirs(checkpoint_dir, tfboard_dir):
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        if not os.path.exists(tfboard_dir):
            os.makedirs(tfboard_dir)

    def _print_args(args):
        # echo all args to stdout and persist them next to the checkpoints
        args_str = ''
        with open(os.path.join(checkpoint_dir, 'args.txt'), 'w') as f:
            for k, v in args.__dict__.items():
                s = '{}: {}'.format(k, v)
                args_str += '{}\n'.format(s)
                print(s)
                f.write(s + '\n')
        print("All the data will be saved in", checkpoint_dir)
        return args_str

    args = grab_args()
    _basic_setting(args)
    _basic_checking(args)

    # note: _create_checkpoint_dir is defined but unused here; the
    # checkpoint directory is taken directly from --checkpoint
    checkpoint_dir = args.checkpoint
    tfboard_dir = os.path.join(checkpoint_dir, 'tfboard')

    _make_dirs(checkpoint_dir, tfboard_dir)
    args_str = _print_args(args)

    writer = SummaryWriter(tfboard_dir)
    writer.add_text('args', args_str, 0)
    writer = writer_helper(writer)

    model_config = {
        "model_type": args.model_type,
        "embedding_dim": args.embedding_dim,
        "predicate_hidden": args.predicate_hidden,
    }
    model_config.update({"demo_hidden": args.demo_hidden})

    return args, checkpoint_dir, writer, model_config

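# How a training entry point might consume setup(). The model factory is a
# hypothetical placeholder; building the model from model_config depends on
# code outside this file:
#
#   args, checkpoint_dir, writer, model_config = setup(train=True)
#   # model = build_model(model_config)  # hypothetical factory

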
def summary(
        args,
        writer,
        info,
        train_args,
        model,
        test_loader,
        postfix,
        fps=None):
    """Write summaries for the given phase; for 'val'/'test', run a full
    evaluation pass first."""
    if postfix == 'train':
        model.write_summary(writer, info, postfix=postfix)
    elif postfix in ('val', 'test'):
        info = summary_eval(
            model,
            test_loader,
            test_loader.dataset)
        model.write_summary(writer, info, postfix=postfix)
    else:
        raise ValueError('unknown postfix: {}'.format(postfix))

    if fps:
        writer.scalar_summary('General/fps', fps)

    return info

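# Sketch of a periodic logging call during training; train_args and the
# loaders come from the surrounding training loop, which is not part of
# this file:
#
#   info = summary(args, writer, info, train_args, model,
#                  test_loader, postfix='val')

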
def summary_eval(
        model,
        loader,
        dset):
    """Evaluate the model over the whole loader and collect per-batch
    losses, accuracies, predictions, targets, and metadata."""
    model.eval()
    print(len(loader))
    with torch.no_grad():
        loss_list = []
        top1_list = []
        action_id_list = []

        prob = []
        target = []
        file_name = []
        for batch_idx, batch_data in enumerate(loader):
            loss, info = model(batch_data)
            loss_list.append(loss.cpu().item())
            top1_list.append(info['top1'])

            prob.append(info['prob'])
            target.append(info['target'])
            file_name.append(info['task_name'])
            action_id_list.append(info['action_id'])

            if batch_idx % 10 == 0:
                print('testing %d / %d: loss %.4f: acc %.4f' % (
                    batch_idx, len(loader), loss, info['top1']))

    info = {
        "loss": sum(loss_list) / len(loss_list),
        "top1": sum(top1_list) / len(top1_list),
        "prob": prob,
        "target": target,
        "task_name": file_name,
        "action_id": action_id_list,
    }
    return info

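# As used above, each forward pass is expected to return (loss, info), with
# info carrying 'top1', 'prob', 'target', 'task_name', and 'action_id' for
# every sample in the batch; summary_eval only aggregates these per batch.

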
def write_prob(info, args):
    """Flatten per-batch outputs and dump them to a per-task CSV."""
    # flatten the list-of-batches structure into flat per-sample lists
    temp_prob_list = [p for batch in info['prob'] for p in batch]
    temp_action_id_list = [a for batch in info['action_id'] for a in batch]
    temp_task_name_list = [t for batch in info['task_name'] for t in batch]
    temp_target_list = [t for batch in info['target'] for t in batch]

    prob = np.array(temp_prob_list)
    action_id = np.array(temp_action_id_list)
    task_name = np.array(temp_task_name_list)
    target = np.array(temp_target_list)

    # columns: action_id | act1..act79 (probabilities) | task_name | gt
    write_data = np.concatenate(
        (np.reshape(action_id, (-1, 1)),
         prob,
         np.reshape(task_name, (-1, 1)),
         np.reshape(target, (-1, 1))),
        axis=1)

    head = ['action_id'] + ['act{}'.format(j + 1) for j in range(79)]
    head += ['task_name', 'gt']

    # make sure the output directory exists before writing
    out_dir = pathlib.Path("prediction") / args.model_type
    out_dir.mkdir(parents=True, exist_ok=True)
    pd.DataFrame(write_data).to_csv(
        "prediction/" + args.model_type + "/" + task_name[0] + "_full.csv",
        header=head)

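# Resulting CSV layout (one row per sample; pandas adds a leading index
# column): action_id, act1..act79, task_name, gt, where act1..act79 are the
# per-action probabilities and gt is the ground-truth label.

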
def write_prob_strategy(info, model_name, args):
    """Like write_prob, but writes per-strategy predictions keyed by model
    name under stan/prediction/<testset>/<model_type>/."""
    # flatten the list-of-batches structure into flat per-sample lists
    temp_prob_list = [p for batch in info['prob'] for p in batch]
    temp_action_id_list = [a for batch in info['action_id'] for a in batch]
    temp_task_name_list = [t for batch in info['task_name'] for t in batch]
    temp_target_list = [t for batch in info['target'] for t in batch]

    prob = np.array(temp_prob_list)
    action_id = np.array(temp_action_id_list)
    task_name = np.array(temp_task_name_list)
    target = np.array(temp_target_list)

    # columns: action_id | act1..act79 (probabilities) | task_name | gt
    write_data = np.concatenate(
        (np.reshape(action_id, (-1, 1)),
         prob,
         np.reshape(task_name, (-1, 1)),
         np.reshape(target, (-1, 1))),
        axis=1)

    head = ['action_id'] + ['act{}'.format(j + 1) for j in range(79)]
    head += ['task_name', 'gt']

    path = pathlib.Path("stan/prediction/" + args.testset + "/" + args.model_type)
    path.mkdir(parents=True, exist_ok=True)
    pd.DataFrame(write_data).to_csv(
        "stan/prediction/" + args.testset + "/" + args.model_type +
        "/model_" + model_name + '_strategy_' + task_name[0] + ".csv",
        header=head)

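# write_prob_strategy mirrors write_prob but keys the output file by model
# and strategy:
#   stan/prediction/<testset>/<model_type>/model_<model_name>_strategy_<task>.csv

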
def save(args, i, checkpoint_dir, model, task):
    """Save the current best model for the given task."""
    save_path = '{}/demo2predicate-{}{}.ckpt'.format(
        checkpoint_dir, 'best_model_', task)
    model.save(save_path, True)


def save_checkpoint(args, i, checkpoint_dir, model, task):
    """Save a rolling checkpoint for the given task."""
    save_path = '{}/demo2predicate-{}{}.ckpt'.format(
        checkpoint_dir, 'checkpoint_model_', task)
    model.save(save_path, True)