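"""Training and evaluation script for demo-to-predicate models.

args.inference selects the mode (per the inline comment in main):
0 trains one model per task, 1 runs inference with a resumed checkpoint,
and 2 runs strategy inference, evaluating one pre-trained model per
held-out task on every test set.
"""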
import resource
import time
import random
import json
import pickle

import numpy as np
import torch
from termcolor import colored
from torch.utils.data import DataLoader

from helper import Constant, LinearStep
from predicate.utils import save, setup, save_checkpoint
from predicate.utils import summary, write_prob, summary_eval, write_prob_strategy


topk = 1    # currently unused
p_th = 0.5  # keep-probability threshold when subsampling printed batches


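# Pretty-print predicted vs. ground-truth goal predicates and collect them
# per demo file; when args.inference == 1 the collected dict is pickled
# under dataset/, with a file name derived from args.resume.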
def print_output(args, outputs, targets, file_names, test_dset):
    goal_predicates = test_dset.goal_predicates
    goal_predicates = {v: k for k, v in goal_predicates.items()}  # id -> predicate name
    json_output = {}

    for i, target in enumerate(targets):
        if args.inference == 0:
            # during training, only print a random ~50% of the minibatches
            p = random.uniform(0, 1)
            if p > p_th:
                continue

        file_name = file_names[i]
        output = outputs[i]

        if args.multi_classifier:
            # one classifier per predicate, predicting a count in [0, max_subgoal_length]
            output = torch.Tensor(output).view(-1, len(test_dset.goal_predicates), test_dset.max_subgoal_length + 1)
            target = torch.Tensor(target).view(-1, len(test_dset.goal_predicates))
        else:
            # a sequence of predicate ids up to max_goal_length
            output = torch.Tensor(output).view(-1, test_dset.max_goal_length, len(test_dset.goal_predicates))
            target = torch.Tensor(target).view(-1, test_dset.max_goal_length)

        output = output.numpy()
        target = target.numpy()

        if args.inference == 0:
            # only show the first sample in each minibatch
            target_inference = [target[0]]
            output_inference = [output[0]]
            file_name_inference = [file_name[0]]
        else:
            target_inference = target
            output_inference = output
            file_name_inference = file_name

        for target_j, output_j, file_name_j in zip(target_inference, output_inference, file_name_inference):
            assert file_name_j not in json_output
            json_output[file_name_j] = {}
            json_output[file_name_j]['ground_truth'] = []
            json_output[file_name_j]['prediction'] = []
            json_output[file_name_j]['ground_truth_id'] = []
            json_output[file_name_j]['prediction_id'] = []

            print('----------------------------------------------------------------------------------')
            if args.multi_classifier:
                assert len(target_j) == len(goal_predicates) == len(output_j)
                for k, target_k in enumerate(target_j):
                    output_k = output_j[k]
                    strtar = ('tar: %s %d' % (goal_predicates[k], target_k)).ljust(50, ' ')
                    strpre = '| gen: %s %d' % (goal_predicates[k], output_k.argmax())
                    print(strtar + strpre)

                    json_output[file_name_j]['ground_truth_id'].append(int(target_k))
                    json_output[file_name_j]['prediction_id'].append(output_k.argmax())
                    # the predicate name is fixed per slot; only the counts differ
                    json_output[file_name_j]['ground_truth'].append(goal_predicates[k])
                    json_output[file_name_j]['prediction'].append(goal_predicates[k])
            else:
                for k, target_k in enumerate(target_j):
                    output_k = output_j[k]
                    strtar = ('tar: %s' % goal_predicates[int(target_k)]).ljust(50, ' ')
                    strpre = '| gen: %s' % goal_predicates[output_k.argmax()]
                    print(strtar + strpre)

                    json_output[file_name_j]['ground_truth_id'].append(int(target_k))
                    json_output[file_name_j]['prediction_id'].append(output_k.argmax())
                    json_output[file_name_j]['ground_truth'].append(goal_predicates[int(target_k)])
                    json_output[file_name_j]['prediction'].append(goal_predicates[output_k.argmax()])
            print('----------------------------------------------------------------------------------')

    if args.inference == 1:
        suffix = '_single_task.p' if args.single else '_multiple_task.p'
        out_path = 'dataset/test_output_' + args.resume.split('/')[-2] + suffix
        with open(out_path, 'wb') as f:
            pickle.dump(json_output, f)


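# One optimization step on a single minibatch: forward, backward, update.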
def run_one_iteration(model, optim, batch_data, train_args, args):
    model.train()
    optim.zero_grad()
    loss, info = model(batch_data, **train_args)
    loss.backward()
    optim.step()
    return batch_data, info, loss


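# Training driver: logs loss/top-1 every 10 iterations, evaluates on the
# test loader once per epoch, and saves the best-scoring model.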
def train(
        args,
        model,
        optim,
        train_loader,
        test_loader,
        val_loader,
        checkpoint_dir,
        writer,
        train_dset,
        test_dset,
        task):
    print(colored('Start training...', 'red'))

    # endless iterator over the testing set
    def _loader():
        while True:
            for batch_data in test_loader:
                yield batch_data

    get_next_data_fn = _loader().__next__  # currently unused
    train_args = {}

    if args.inference == 1:
        info = summary(
            args,
            writer,
            None,
            None,
            model,
            test_loader,
            'test')
        print('test top1', info['top1'])
        write_prob(info, args)

    def _train_loop(task):
        iter = 0
        summary_t1 = time.time()

        test_best_top1 = 0
        print('start while')
        print('train iterations: ', args.train_iters)
        while iter <= args.train_iters:
            for batch_data in train_loader:
                results = run_one_iteration(model, optim, batch_data, train_args, args)
                batch_data, info, loss = results

                if iter % 10 == 0:
                    print('%s: training %d / %d: loss %.4f: acc %.4f'
                          % (args.checkpoint, iter, args.train_iters, loss, info['top1']))

                    fps = 10. / (time.time() - summary_t1)
                    info = summary(
                        args,
                        writer,
                        info,
                        train_args,
                        model,
                        None,
                        'train',
                        fps=fps)
                    if iter > 0:
                        summary_t1 = time.time()

                # evaluate once per epoch
                if iter % (len(train_loader) * 1) == 0 and iter > 0:
                    info = summary(
                        args,
                        writer,
                        None,
                        None,
                        model,
                        test_loader,
                        'test')

                    if info['top1'] > test_best_top1:
                        # keep the best-scoring model seen so far
                        test_best_top1 = info['top1']
                        save(args, iter, checkpoint_dir, model, task)
                    # always refresh the rolling checkpoint
                    save_checkpoint(args, iter, checkpoint_dir, model, task)

                iter += 1

    print('start train loop')
    _train_loop(task)
    print('train loop done')


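# Entry point: builds datasets and loaders per task, then either trains a
# model for each task or runs strategy inference over pre-trained ones.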
def main():
    args, checkpoint_dir, writer, model_config = setup(train=True)
    print(args)

    from predicate.demo_dataset_graph_strategy_test import get_dataset
    from predicate.demo_dataset_graph_strategy_test import collate_fn
    from predicate.demo_dataset_graph_strategy_test import to_cuda_fn

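    # Strategy inference: load one pre-trained ActionDemo2Predicate
    # checkpoint per held-out task, then evaluate every model on every
    # task's test set.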
    if args.inference == 2:  # 0: no inference, 1: inference, 2: strategy inference
        from network.encoder_decoder import ActionDemo2Predicate
        test_tasks = ['put_fridge', 'put_dishwasher', 'read_book', 'prepare_food', 'setup_table']
        new_test_tasks = ['put_fridge', 'put_dishwasher', 'read_book']

        train_dsets = []
        test_dsets = []
        new_test_dsets = []
        models = []
        train_loaders = []
        test_loaders = []
        val_loaders = []
        for i in range(len(new_test_tasks)):
            # test_tasks[i] == new_test_tasks[i] for the first three tasks,
            # so the loss-weight file matches the dataset being loaded
            loss_weights = np.load('dataset/watch_data/loss_weight_' + test_tasks[i] + '_new_test_task' + '.npy')
            train_dset, test_dset, new_test_dset = get_dataset(args, new_test_tasks[i], train=True)
            train_dsets.append(train_dset)
            test_dsets.append(test_dset)
            new_test_dsets.append(new_test_dset)
            model = ActionDemo2Predicate(args, train_dset, loss_weights, **model_config)
            model.load(args.checkpoint + '/demo2predicate-checkpoint_model_' + new_test_tasks[i] + '.ckpt', True)
            model.cuda()
            model.eval()
            models.append(model)

            train_loader = DataLoader(
                dataset=train_dset,
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.n_workers,
                drop_last=True)

            if args.testset == 'test_task':
                eval_dset = test_dset
            elif args.testset == 'new_test_task':
                eval_dset = new_test_dset
            else:
                raise ValueError('unknown testset: %s' % args.testset)

            test_loader = DataLoader(
                dataset=eval_dset,
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=0,
                drop_last=True)

            val_loader = DataLoader(
                dataset=eval_dset,
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=0,
                drop_last=True)

            train_loaders.append(train_loader)
            test_loaders.append(test_loader)
            val_loaders.append(val_loader)

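        # Cross-evaluation: run every model on every task's test loader and
        # merge the per-task outputs into one table for strategy selection.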
        for i in range(len(models)):
            infos = []
            for j in range(len(test_loaders)):
                info = summary_eval(
                    models[i],
                    test_loaders[j],
                    test_loaders[j].dataset)
                print('test top1', info['top1'])
                infos.append(info)
            total_info = {
                key: np.concatenate([info[key] for info in infos], axis=0)
                for key in ('prob', 'target', 'task_name', 'action_id')
            }
            write_prob_strategy(total_info, test_tasks[i], args)

    else:
        print('get dataset')
        test_tasks = ['put_dishwasher', 'read_book', 'put_fridge', 'prepare_food', 'setup_table']
        new_test_tasks = ['put_dishwasher', 'read_book', 'put_fridge']
        for i in range(len(new_test_tasks)):
            train_dset, test_dset, new_test_dset = get_dataset(args, test_tasks[i], train=True)
            print('train set len:', len(train_dset))

            train_loader = DataLoader(
                dataset=train_dset,
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.n_workers,
                drop_last=True)

            # args.single and the multi-task setting currently use identical
            # loaders: test on the held-out tasks, validate on the test split
            test_loader = DataLoader(
                dataset=new_test_dset,
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=0,
                drop_last=True)

            val_loader = DataLoader(
                dataset=test_dset,
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=0,
                drop_last=True)

            # initialize the model, with per-task class weights for the loss
            loss_weights = np.load('dataset/watch_data/loss_weight_' + test_tasks[i] + '_train_task' + '.npy')
            if args.inputtype == 'graphinput':
                from network.encoder_decoder import GraphDemo2Predicate
                model = GraphDemo2Predicate(args, train_dset, **model_config)
            elif args.inputtype == 'actioninput':
                from network.encoder_decoder import ActionDemo2Predicate
                model = ActionDemo2Predicate(args, train_dset, loss_weights, **model_config)
            else:
                raise ValueError('unknown inputtype: %s' % args.inputtype)

            if args.resume != '':
                model.load(args.resume, True)

            optim = torch.optim.Adam(
                filter(lambda p: p.requires_grad, model.parameters()),
                args.model_lr_rate)
            if args.gpu_id is not None:
                model.cuda()

            # main loop
            train(
                args,
                model,
                optim,
                train_loader,
                test_loader,
                val_loader,
                checkpoint_dir,
                writer,
                train_dset,
                test_dset,
                test_tasks[i])


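# Raise the soft limit on open file descriptors: DataLoader workers can
# exhaust the default per-process quota.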
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (1024 * 4, rlimit[1]))


if __name__ == '__main__':
    from multiprocessing import set_start_method
    try:
        # 'spawn' is required when worker processes touch CUDA tensors; if a
        # start method is already set, set_start_method raises RuntimeError
        set_start_method('spawn')
    except RuntimeError:
        pass
    main()