# InferringIntention/watch_and_help/watch_strategy_full/predicate-train-strategy.py
# (source-listing metadata: 2024-03-24 23:42:27 +01:00 — 368 lines — 14 KiB — Python)
import resource
import time
from termcolor import colored
import torch
from torch.utils.data import DataLoader
from helper import Constant, LinearStep
from predicate.utils import save, setup, save_checkpoint
from predicate.utils import summary, write_prob, summary_eval, write_prob_strategy
import random
import json
import pickle
import numpy as np
topk = 1
p_th = 0.5
def print_output(args, outputs, targets, file_names, test_dset):
    """Print predicted vs. ground-truth goal predicates and collect them.

    Parameters
    ----------
    args : namespace with `inference` (0 train / 1 infer / 2 strategy infer),
        `multi_classifier`, and — when `inference == 1` — `single`, `resume`.
    outputs : per-minibatch network outputs, reshapeable to predicate logits.
    targets : per-minibatch ground-truth predicate ids (or per-predicate
        counts when `multi_classifier` is set).
    file_names : per-minibatch sequence of demo file names (result-dict keys).
    test_dset : dataset providing `goal_predicates` (name -> id) plus the
        `max_goal_length` / `max_subgoal_length` reshape dimensions.

    Returns
    -------
    dict
        file name -> {'ground_truth', 'prediction', 'ground_truth_id',
        'prediction_id'} lists. (Now returned so callers can consume it
        directly; previously it was only pickled when `inference == 1`.)
    """
    goal_predicates = test_dset.goal_predicates
    # Invert name -> id into id -> name for printing.
    goal_predicates = {v: k for k, v in goal_predicates.items()}
    json_output = {}
    for i, target in enumerate(targets):
        if args.inference == 0:
            # During training only print a random subset of minibatches.
            p = random.uniform(0, 1)
            if p > p_th:
                continue
        file_name = file_names[i]
        output = outputs[i]
        if args.multi_classifier:
            # One (max_subgoal_length+1)-way count classifier per predicate.
            output = torch.Tensor(output).view(-1, len(test_dset.goal_predicates), test_dset.max_subgoal_length + 1)
            target = torch.Tensor(target).view(-1, len(test_dset.goal_predicates))
        else:
            # A sequence of max_goal_length predicate classifications.
            output = torch.Tensor(output).view(-1, test_dset.max_goal_length, len(test_dset.goal_predicates))
            target = torch.Tensor(target).view(-1, test_dset.max_goal_length)
        output = output.numpy()
        target = target.numpy()
        if args.inference == 0:
            # Only show the first sample in each minibatch.
            target_inference = [target[0]]
            output_inference = [output[0]]
            file_name_inference = [file_name[0]]
        else:
            target_inference = target
            output_inference = output
            file_name_inference = file_name
        for (target_j, output_j, file_name_j) in zip(target_inference, output_inference, file_name_inference):
            assert file_name_j not in json_output
            json_output[file_name_j] = {}
            json_output[file_name_j]['ground_truth'] = []
            json_output[file_name_j]['prediction'] = []
            json_output[file_name_j]['ground_truth_id'] = []
            json_output[file_name_j]['prediction_id'] = []
            print('----------------------------------------------------------------------------------')
            if args.multi_classifier:
                assert len(target_j) == len(goal_predicates) == len(output_j)
                for k, target_k in enumerate(target_j):
                    output_k = output_j[k]
                    strtar = ('tar: %s %d' % (goal_predicates[k], target_k)).ljust(50, ' ')
                    strpre = '| gen: %s %d' % (goal_predicates[k], output_k.argmax())
                    print(strtar + strpre)
                    json_output[file_name_j]['ground_truth_id'].append(int(target_k))
                    # Cast numpy scalar -> int so ids serialize consistently
                    # with ground_truth_id (was a raw numpy int before).
                    json_output[file_name_j]['prediction_id'].append(int(output_k.argmax()))
                    json_output[file_name_j]['ground_truth'].append(goal_predicates[k])
                    json_output[file_name_j]['prediction'].append(goal_predicates[k])
            else:
                for k, target_k in enumerate(target_j):
                    output_k = output_j[k]
                    strtar = ('tar: %s' % goal_predicates[int(target_k)]).ljust(50, ' ')
                    strpre = '| gen: %s' % goal_predicates[output_k.argmax()]
                    print(strtar + strpre)
                    json_output[file_name_j]['ground_truth_id'].append(int(target_k))
                    json_output[file_name_j]['prediction_id'].append(int(output_k.argmax()))
                    json_output[file_name_j]['ground_truth'].append(goal_predicates[int(target_k)])
                    json_output[file_name_j]['prediction'].append(goal_predicates[output_k.argmax()])
            print('----------------------------------------------------------------------------------')
    if args.inference == 1:
        if args.single:
            out_path = "dataset/test_output_" + args.resume.split('/')[-2] + "_single_task.p"
        else:
            out_path = "dataset/test_output_" + args.resume.split('/')[-2] + "_multiple_task.p"
        # Use a context manager so the file handle is closed (it leaked before).
        with open(out_path, "wb") as f:
            pickle.dump(json_output, f)
    return json_output
def run_one_iteration(model, optim, batch_data, train_args, args):
    """Run a single optimization step on one minibatch.

    Puts the model in training mode, computes loss via the model's forward
    pass, backpropagates, and applies the optimizer update.

    Returns the (unchanged) batch, the info dict from the forward pass,
    and the loss tensor.
    """
    model.train()
    optim.zero_grad()
    step_loss, step_info = model(batch_data, **train_args)
    step_loss.backward()
    optim.step()
    return batch_data, step_info, step_loss
def train(
        args,
        model,
        optim,
        train_loader,
        test_loader,
        val_loader,
        checkpoint_dir,
        writer,
        train_dset,
        test_dset,
        task):
    """Train `model` on `train_loader`, periodically evaluating on
    `test_loader` and saving the best top-1 checkpoint for `task`.

    NOTE(review): when args.inference == 1 a test summary is written first,
    but the training loop below still runs afterwards — presumably the
    caller sets args.train_iters accordingly; confirm.
    """
    # Train
    print(colored('Start training...', 'red'))
    # loader for the testing set
    def _loader():
        # Endless generator cycling over the test set.
        while True:
            for batch_data in test_loader:
                yield batch_data
    # NOTE(review): appears unused in this function — confirm before removing.
    get_next_data_fn = _loader().__iter__().__next__
    train_args = {}
    if args.inference == 1:
        # Inference mode: evaluate once on the test set and dump probabilities.
        info = summary(
            args,
            writer,
            None,
            None,
            model,
            test_loader,
            'test')
        print('test top1', info['top1'])
        write_prob(info, args)
    def _train_loop(task):
        # `iter` shadows the builtin; kept as-is to avoid behavior changes.
        iter = 0
        summary_t1 = time.time()
        test_best_top1 = 0
        print('start while')
        print('train iterations: ',args.train_iters)
        while iter <= args.train_iters:
            for batch_data in train_loader:
                results = run_one_iteration(model, optim, batch_data, train_args, args)
                batch_data, info, loss = results
                if iter % 10 == 0:
                    # Log training progress every 10 iterations.
                    print('%s: training %d / %d: loss %.4f: acc %.4f' % (args.checkpoint, iter, len(train_loader), loss, info['top1']))
                    # fps over the last 10 iterations.
                    fps = 10. / (time.time() - summary_t1)
                    info = summary(
                        args,
                        writer,
                        info,
                        train_args,
                        model,
                        None,
                        'train',
                        fps=fps)
                    if iter > 0:
                        summary_t1 = time.time()
                if iter % (len(train_loader)*1) == 0 and iter>0:
                    # Roughly once per epoch: evaluate on the test set.
                    info = summary(
                        args,
                        writer,
                        None,
                        None,
                        model,
                        test_loader,
                        'test')
                    if info['top1']>test_best_top1:
                        # Save only when test top-1 improves.
                        test_best_top1 = info['top1']
                        save(args, iter, checkpoint_dir, model, task)
                    # NOTE(review): original indentation was lost; this may have
                    # been inside the improvement branch above — confirm.
                    save_checkpoint(args, iter, checkpoint_dir, model, task)
                iter += 1
    print('start train loop')
    _train_loop(task)
    print('train loop done')
def main():
    """Entry point: build datasets/loaders and either train per-task models
    or (args.inference == 2) evaluate saved per-task checkpoints as strategies.

    args.inference selects the mode: 0 = train, 1 = infer, 2 = strategy infer.
    """
    args, checkpoint_dir, writer, model_config = setup(train=True)
    print(args)
    from predicate.demo_dataset_graph_strategy_test import get_dataset
    from predicate.demo_dataset_graph_strategy_test import collate_fn
    from predicate.demo_dataset_graph_strategy_test import to_cuda_fn
    #strategy inference
    if args.inference == 2: # 0: not infer, 1: infer, 2: strategy infer
        from network.encoder_decoder import ActionDemo2Predicate
        test_tasks = ['put_fridge', 'put_dishwasher', 'read_book', 'prepare_food', 'setup_table']
        new_test_tasks = ['put_fridge', 'put_dishwasher', 'read_book']
        train_dsets = []
        test_dsets = []
        new_test_dsets = []
        models = []
        train_loaders = []
        test_loaders = []
        val_loaders = []
        # Load one trained model (and its datasets/loaders) per held-out task.
        for i in range(len(new_test_tasks)):
            # NOTE(review): loss weights are indexed with test_tasks[i] while the
            # dataset uses new_test_tasks[i]; the first three entries coincide
            # here, but confirm this pairing is intentional.
            loss_weights = np.load('dataset/watch_data/loss_weight_'+test_tasks[i]+'_new_test_task'+'.npy')
            train_dset, test_dset, new_test_dset = get_dataset(args, new_test_tasks[i], train=True )
            train_dsets.append(train_dset)
            test_dsets.append(test_dset)
            new_test_dsets.append(new_test_dset)
            model = ActionDemo2Predicate(args, train_dset, loss_weights, **model_config)
            model.load(args.checkpoint+'/demo2predicate-checkpoint_model_'+new_test_tasks[i]+'.ckpt', True)
            model.cuda()
            model.eval()
            models.append(model)
            train_loader = DataLoader(
                dataset=train_dset,
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.n_workers,
                drop_last=True)
            # NOTE(review): if args.testset matches neither branch below,
            # test_loader/val_loader are unbound and the appends raise NameError.
            if args.testset == 'test_task':
                test_loader = DataLoader(
                    dataset=test_dset,
                    batch_size=args.batch_size,
                    shuffle=False,
                    num_workers=0,
                    drop_last=True)
                val_loader = DataLoader(
                    dataset=test_dset,
                    batch_size=args.batch_size,
                    shuffle=False,
                    num_workers=0,
                    drop_last=True)
            if args.testset == 'new_test_task':
                test_loader = DataLoader(
                    dataset=new_test_dset,
                    batch_size=args.batch_size,
                    shuffle=False,
                    num_workers=0,
                    drop_last=True)
                val_loader = DataLoader(
                    dataset=new_test_dset,
                    batch_size=args.batch_size,
                    shuffle=False,
                    num_workers=0,
                    drop_last=True)
            train_loaders.append(train_loader)
            test_loaders.append(test_loader)
            val_loaders.append(val_loader)
        # Evaluate every model on every task's test loader, then merge and
        # dump the concatenated probabilities per model.
        for i in range(len(models)):
            infos = []
            for j in range(len(test_loaders)):
                info = summary_eval(
                    models[i],
                    test_loaders[j],
                    test_loaders[j].dataset)
                print('test top1', info['top1'])
                infos.append(info)
            total_info = {
                "prob": np.concatenate((infos[0]["prob"], infos[1]["prob"], infos[2]["prob"]), axis=0),
                "target": np.concatenate((infos[0]["target"], infos[1]["target"], infos[2]["target"]), axis=0), #batch_target.cpu().numpy(),
                "task_name": np.concatenate((infos[0]["task_name"], infos[1]["task_name"], infos[2]["task_name"]), axis=0), #batch_task_name,
                "action_id": np.concatenate((infos[0]["action_id"], infos[1]["action_id"], infos[2]["action_id"]), axis=0) #batch_action_id.cpu().numpy()
            }
            write_prob_strategy(total_info, test_tasks[i], args)
    else:
        # Training / plain inference: one model per task in new_test_tasks.
        print('get dataset')
        test_tasks = ['put_dishwasher', 'read_book', 'put_fridge', 'prepare_food', 'setup_table']
        new_test_tasks = ['put_dishwasher', 'read_book', 'put_fridge']
        for i in range(len(new_test_tasks)):
            train_dset, test_dset, new_test_dset = get_dataset(args, test_tasks[i], train=True )
            print('train set len:',len(train_dset))
            train_loader = DataLoader(
                dataset=train_dset,
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.n_workers,
                drop_last=True)
            # NOTE(review): both branches below build test_loader from
            # new_test_dset and val_loader from test_dset — confirm the
            # single/multi distinction is intentionally a no-op here.
            if args.single:
                test_loader = DataLoader(
                    dataset=new_test_dset,
                    batch_size=args.batch_size,
                    shuffle=True,
                    num_workers=0,
                    drop_last=True)
                val_loader = DataLoader(
                    dataset=test_dset,
                    batch_size=args.batch_size,
                    shuffle=True,
                    num_workers=0,
                    drop_last=True)
            else:
                test_loader = DataLoader(
                    dataset=new_test_dset,
                    batch_size=args.batch_size,
                    shuffle=True,
                    num_workers=0,
                    drop_last=True)
                val_loader = DataLoader(
                    dataset=test_dset,
                    batch_size=args.batch_size,
                    shuffle=True,
                    num_workers=0,
                    drop_last=True)
            # initialize model
            loss_weights = np.load('dataset/watch_data/loss_weight_'+test_tasks[i]+'_train_task'+'.npy')
            if args.inputtype=='graphinput':
                from network.encoder_decoder import GraphDemo2Predicate
                model = GraphDemo2Predicate(args, train_dset, **model_config)
            elif args.inputtype=='actioninput':
                from network.encoder_decoder import ActionDemo2Predicate
                model = ActionDemo2Predicate(args, train_dset, loss_weights, **model_config)
            if args.resume!='':
                # Resume from a saved checkpoint.
                model.load(args.resume, True)
            optim = torch.optim.Adam(
                filter(
                    lambda p: p.requires_grad,
                    model.parameters()),
                args.model_lr_rate)
            if args.gpu_id is not None:
                model.cuda()
            # main loop
            train(
                args,
                model,
                optim,
                train_loader,
                test_loader,
                val_loader,
                checkpoint_dir,
                writer,
                train_dset,
                test_dset,
                test_tasks[i])
# Raise the soft open-file limit to 4096: many DataLoader workers can exhaust
# the default. NOTE(review): placement at module level inferred after the
# original indentation was lost — confirm.
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (1024 * 4, rlimit[1]))
if __name__ == '__main__':
    from multiprocessing import set_start_method
    # CUDA tensors in DataLoader workers require the 'spawn' start method;
    # ignore the RuntimeError raised if the start method was already set.
    try:
        set_start_method('spawn')
    except RuntimeError:
        pass
    main()