first commit
commit 83b04e2133
109 changed files with 12081 additions and 0 deletions
131  watch_and_help/watch_strategy_full/network/encoder_decoder.py  Normal file
@@ -0,0 +1,131 @@
import copy

import numpy as np
from termcolor import colored

import torch
import torch.nn as nn
import torch.nn.functional as F


def _align_tensor_index(reference_index, tensor_index):
    # For each entry of reference_index, find its position inside tensor_index.
    where_in_tensor = []
    for i in reference_index:
        where = np.where(i == tensor_index)[0][0]
        where_in_tensor.append(where)
    return np.array(where_in_tensor)


def _sort_by_length(list_of_tensor, batch_length, return_idx=False):
    # Reorder every tensor in the batch by descending sequence length.
    idx = np.argsort(np.array(copy.copy(batch_length)))[::-1]
    for i, tensor in enumerate(list_of_tensor):
        if isinstance(tensor, dict):
            list_of_tensor[i]['class_objects'] = [tensor['class_objects'][j] for j in idx]
            list_of_tensor[i]['object_coords'] = [tensor['object_coords'][j] for j in idx]
            list_of_tensor[i]['states_objects'] = [tensor['states_objects'][j] for j in idx]
            list_of_tensor[i]['mask_object'] = [tensor['mask_object'][j] for j in idx]
        else:
            list_of_tensor[i] = [tensor[j] for j in idx]
    if return_idx:
        return list_of_tensor, idx
    return list_of_tensor


def _sort_by_index(list_of_tensor, idx):
    # Apply (or undo) a permutation produced by _sort_by_length.
    for i, tensor in enumerate(list_of_tensor):
        list_of_tensor[i] = [tensor[j] for j in idx]
    return list_of_tensor
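
# Example: batch_length = [2, 5, 3] yields idx = [1, 2, 0], i.e. longest first
# (the ordering torch.nn.utils.rnn.pack_padded_sequence expects); calling
# _sort_by_length(..., return_idx=True) keeps idx around so that _sort_by_index
# can restore the original order afterwards.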


class ActionDemo2Predicate(nn.Module):
    summary_keys = ['loss', 'top1']

    def __init__(self, args, dset, loss_weight, **kwargs):
        from network.module_graph import ActionDemoEncoder, PredicateClassifier
        super(ActionDemo2Predicate, self).__init__()

        print('-' * 90)
        print('ActionDemo2Predicate')
        print('-' * 90)

        model_type = kwargs["model_type"]
        print('model_type', model_type)

        # All supported encoders share one constructor; the pooling strategy
        # is selected by name.
        if model_type.lower() in ('max', 'avg', 'lstmavg', 'bilstmavg', 'lstmlast', 'bilstmlast'):
            demo_encoder = ActionDemoEncoder(args, dset, model_type.lower())
        else:
            raise ValueError('unknown model_type: {}'.format(model_type))
        demo_encoder = torch.nn.DataParallel(demo_encoder)

        predicate_decoder = PredicateClassifier(args, dset, loss_weight)

        # for quick save and load
        all_modules = nn.Sequential()
        all_modules.add_module('demo_encoder', demo_encoder)
        all_modules.add_module('predicate_decoder', predicate_decoder)

        self.demo_encoder = demo_encoder
        self.predicate_decoder = predicate_decoder
        self.all_modules = all_modules
        self.to_cuda_fn = None

    def set_to_cuda_fn(self, to_cuda_fn):
        self.to_cuda_fn = to_cuda_fn

    def forward(self, data, **kwargs):
        if self.to_cuda_fn:
            data = self.to_cuda_fn(data)

        # demonstration
        batch_data = data[0].cuda()
        batch_gt = data[1].cuda()
        batch_task_name = data[2]
        batch_action_id = data[3]

        # demonstration encoder
        batch_demo_emb, _ = self.demo_encoder(batch_data, batch_gt, batch_task_name)

        loss, info = self.predicate_decoder(batch_demo_emb, batch_gt, batch_action_id, batch_task_name)

        return loss, info

    def write_summary(self, writer, info, postfix):
        model_name = 'Demo2Predicate-{}/'.format(postfix)
        for k in self.summary_keys:
            if k in info.keys():
                writer.scalar_summary(model_name + k, info[k])

    def save(self, path, verbose=False):
        if verbose:
            print(colored('[*] Save model at {}'.format(path), 'magenta'))
        torch.save(self.all_modules.state_dict(), path)

    def load(self, path, verbose=False):
        if verbose:
            print(colored('[*] Load model at {}'.format(path), 'magenta'))
        self.all_modules.load_state_dict(
            torch.load(path, map_location=lambda storage, loc: storage))
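
A minimal usage sketch for ActionDemo2Predicate (not part of the commit). It assumes the repo root is on PYTHONPATH, a single CUDA device is available, and the batch matches the encoder's hard-coded view(8, 82, -1): 8 demos of 82 action steps each. The args/dset fields and every numeric value below are illustrative stand-ins, not the project's real configuration.

import argparse

import torch

from network.encoder_decoder import ActionDemo2Predicate

args = argparse.Namespace(demo_hidden=128, dropout=0.0, loss_type='ce')

class DummyDset:                        # hypothetical dataset metadata
    num_goal_predicates = 79
    max_goal_length = 3
    max_action_len = 500

loss_weight = [1.0 / 79] * 79           # assumed reference class distribution
model = ActionDemo2Predicate(args, DummyDset(), loss_weight,
                             model_type='lstmavg').cuda()

data = (
    torch.randint(0, 500, (8, 82)),     # action ids: 8 demos x 82 steps
    torch.randint(0, 79, (8,)),         # one target action class per demo
    ['task'] * 8,                       # task names (carried through to info)
    torch.arange(8),                    # action ids echoed back in info
)
loss, info = model(data)
print(loss.item(), info['top1'])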
249  watch_and_help/watch_strategy_full/network/module_graph.py  Normal file
@@ -0,0 +1,249 @@
import random
import itertools
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from helper import fc_block, Constant


def _calculate_accuracy_predicate(logits, batch_target, max_possible_count=None, topk=1, multi_classifier=False):
    # batch_target holds max_possible_count slots per goal, so the number of
    # goals in the batch is an integer division.
    batch_size = batch_target.size(0) // max_possible_count

    _, pred = logits.topk(topk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(batch_target.view(1, -1).expand_as(pred))

    k = 1
    accuracy = correct[:k].view(-1).float()
    accuracy = accuracy.view(-1, max_possible_count)

    # A goal counts as correct only if every one of its slots is predicted correctly.
    correct_k = (accuracy.sum(1) == max_possible_count).sum(0)
    correct_k = correct_k * (100.0 / batch_size)

    return correct_k
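
# Worked example (hypothetical numbers): with max_possible_count = 3 and two
# goals in the batch, per-slot hits of [[1, 1, 1], [1, 0, 1]] give row sums
# [3, 2]; only the first goal matches all three slots, so
# correct_k = 1 * (100.0 / 2) = 50.0.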


def _calculate_accuracy(
        action_correct,
        object_correct,
        rel_correct,
        target_correct,
        batch_length,
        info):

    # Only the first (l - 1) steps of each sequence carry valid predictions.
    action_valid_correct = [sum(action_correct[i, :(l - 1)])
                            for i, l in enumerate(batch_length)]
    object_valid_correct = [sum(object_correct[i, :(l - 1)])
                            for i, l in enumerate(batch_length)]
    rel_valid_correct = [sum(rel_correct[i, :(l - 1)])
                         for i, l in enumerate(batch_length)]
    target_valid_correct = [sum(target_correct[i, :(l - 1)])
                            for i, l in enumerate(batch_length)]

    denom = sum(batch_length) - 1. * len(batch_length)
    action_accuracy = sum(action_valid_correct).float() / denom
    object_accuracy = sum(object_valid_correct).float() / denom
    rel_accuracy = sum(rel_valid_correct).float() / denom
    target_accuracy = sum(target_valid_correct).float() / denom

    info.update({'action_accuracy': action_accuracy.cpu().item()})
    info.update({'object_accuracy': object_accuracy.cpu().item()})
    info.update({'rel_accuracy': rel_accuracy.cpu().item()})
    info.update({'target_accuracy': target_accuracy.cpu().item()})


class PredicateClassifier(nn.Module):

    def __init__(
            self,
            args,
            dset,
            loss_weight):

        super(PredicateClassifier, self).__init__()

        self.num_goal_predicates = dset.num_goal_predicates
        self.max_possible_count = dset.max_goal_length
        self.loss = custom_loss(loss_weight)

        hidden_size = args.demo_hidden
        self.hidden_size = hidden_size
        self.loss_type = args.loss_type

        # The encoder emits one hidden vector per step of an 82-step demo, so
        # the flattened input is hidden_size * 82; 79 is all possible actions.
        if args.dropout == 0:
            print('dropout', args.dropout)
            classifier = nn.Sequential()
            classifier.add_module('fc_block1', fc_block(hidden_size * 82, hidden_size, False, nn.Tanh))
            classifier.add_module('fc_block2', fc_block(hidden_size, 79, False, None))
        else:
            print('dropout not 0', args.dropout)
            classifier = nn.Sequential()
            classifier.add_module('fc_block1', fc_block(hidden_size * 82, hidden_size, False, nn.Tanh))
            classifier.add_module('dropout', nn.Dropout(args.dropout))
            classifier.add_module('fc_block2', fc_block(hidden_size, 79, False, None))

        self.classifier = classifier

    def forward(self, input_emb, batch_target, batch_action_id, batch_task_name, **kwargs):
        input_emb = input_emb.view(-1, self.hidden_size * 82)
        logits = self.classifier(input_emb)

        prob = F.softmax(logits, 1)

        if self.loss_type == 'ce':
            # plain cross-entropy loss
            loss = F.cross_entropy(logits, batch_target)
        elif self.loss_type == 'regu':
            # cross-entropy plus the label-distribution regularizer below
            loss = self.loss(logits, batch_target)
        else:
            raise ValueError('unknown loss_type: {}'.format(self.loss_type))

        argmax_Y = torch.max(logits, 1)[1].view(-1, 1)
        top1 = (batch_target.float().view(-1, 1) == argmax_Y.float()).sum().item() / len(batch_target.float().view(-1, 1)) * 100

        with torch.no_grad():
            info = {
                "prob": prob.cpu().numpy(),
                "argmax": argmax_Y.cpu().numpy(),
                "loss": loss.cpu().numpy(),
                "top1": top1,
                "target": batch_target.cpu().numpy(),
                "task_name": batch_task_name,
                "action_id": batch_action_id.cpu().numpy()
            }

        return loss, info


class custom_loss(nn.Module):
    def __init__(self, loss_weight) -> None:
        super(custom_loss, self).__init__()
        self.loss = nn.CrossEntropyLoss().cuda()
        self.weight_loss = nn.MSELoss().cuda()
        self.loss_weight = loss_weight
        # Reference class distribution the batch statistics are pulled towards.
        self.counts = torch.FloatTensor(self.loss_weight).cuda()

    def forward(self, pred, target):
        # cross-entropy plus an MSE penalty between the batch's empirical
        # label distribution and the reference distribution
        batch_counts = torch.bincount(target)
        batch_counts = batch_counts.float() / torch.sum(batch_counts)
        if len(batch_counts) < 79:
            # bincount only covers classes seen in the batch; pad to all 79.
            batch_counts = F.pad(input=batch_counts, pad=(0, 79 - len(batch_counts)), mode='constant', value=0)

        celoss = self.loss(pred, target)
        customloss = self.weight_loss(batch_counts, self.counts)
        print('celoss: ', celoss, 'customloss: ', customloss)
        loss = celoss + 1000 * customloss
        return loss
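
# Worked example (hypothetical numbers): with targets [0, 0, 1] in a batch,
# bincount gives [2, 1]; normalizing yields [0.667, 0.333], which is padded to
# all 79 classes. The MSE against self.counts then penalizes batches whose
# label mix drifts away from the reference distribution in loss_weight.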


class PredicateClassifierMultiClassifier(nn.Module):

    def __init__(
            self,
            args,
            dset):

        super(PredicateClassifierMultiClassifier, self).__init__()

        self.num_goal_predicates = dset.num_goal_predicates
        self.max_possible_count = dset.max_goal_length
        self.max_subgoal_length = dset.max_subgoal_length

        hidden_size = args.demo_hidden
        print('hidden_size', hidden_size)
        print('PredicateClassifierMultiClassifier')

        # One (max_subgoal_length + 1)-way head per goal predicate, i.e. a
        # count from 0 to max_subgoal_length for each predicate.
        if args.dropout == 0:
            print('dropout', args.dropout)
            classifier = nn.Sequential()
            classifier.add_module('fc_block1', fc_block(hidden_size, hidden_size, False, nn.Tanh))
            classifier.add_module('fc_block2', fc_block(hidden_size, self.num_goal_predicates * (self.max_subgoal_length + 1), False, None))
        else:
            print('dropout not 0', args.dropout)
            classifier = nn.Sequential()
            classifier.add_module('fc_block1', fc_block(hidden_size, hidden_size, False, nn.Tanh))
            classifier.add_module('dropout', nn.Dropout(args.dropout))
            classifier.add_module('fc_block2', fc_block(hidden_size, self.num_goal_predicates * (self.max_subgoal_length + 1), False, None))

        self.classifier = classifier

    def forward(self, bs, input_emb, batch_target, batch_file_name, **kwargs):
        logits = self.classifier(input_emb)
        # One softmax per predicate over its possible counts.
        logits = logits.reshape([-1, (self.max_subgoal_length + 1)])
        prob = F.softmax(logits, 1)

        batch_target = torch.cat(batch_target)
        loss = F.cross_entropy(logits, batch_target)

        top1 = _calculate_accuracy_predicate(logits, batch_target, self.num_goal_predicates, multi_classifier=True)

        with torch.no_grad():
            info = {
                "prob": prob.cpu().numpy(),
                "loss": loss.cpu().numpy(),
                "top1": top1.cpu().numpy(),
                "target": batch_target.cpu().numpy(),
                "file_name": batch_file_name
            }

        return loss, info


class ActionDemoEncoder(nn.Module):
    def __init__(self, args, dset, pooling):
        super(ActionDemoEncoder, self).__init__()
        hidden_size = args.demo_hidden
        self.hidden_size = hidden_size

        len_action_predicates = dset.max_action_len
        self.action_embed = nn.Embedding(len_action_predicates, hidden_size)

        feat2hidden = nn.Sequential()
        feat2hidden.add_module(
            'fc_block1', fc_block(hidden_size, hidden_size, False, nn.ReLU))
        self.feat2hidden = feat2hidden

        self.pooling = pooling

        if 'lstm' in self.pooling:
            self.lstm = nn.LSTM(hidden_size, hidden_size)

    def forward(self, batch_data, batch_gt, batch_task_name):
        # Flatten the (batch, steps) action ids into one column of tokens.
        batch_data = batch_data.view(-1, 1)

        stacked_demo_feat = self.action_embed(batch_data)
        stacked_demo_feat = self.feat2hidden(stacked_demo_feat)
        batch_demo_feat = []

        for length in range(0, batch_data.shape[0]):
            # Take the single step at this position (step 0 when length == 0).
            if length == 0:
                feat = stacked_demo_feat[0:1, :]
            else:
                feat = stacked_demo_feat[(length - 1):length, :]
            if len(feat.size()) == 3:
                feat = feat.unsqueeze(0)

            if self.pooling == 'max':
                feat = torch.max(feat, 0)[0]
            elif self.pooling == 'avg':
                feat = torch.mean(feat, 0)
            elif self.pooling == 'lstmavg':
                lstm_out, hidden = self.lstm(feat.view(len(feat), 1, -1))
                lstm_out = lstm_out.view(len(feat), -1)
                feat = torch.mean(lstm_out, 0)
            elif self.pooling == 'lstmlast':
                lstm_out, hidden = self.lstm(feat.view(len(feat), 1, -1))
                lstm_out = lstm_out.view(len(feat), -1)
                feat = lstm_out[-1]
            else:
                raise ValueError('unknown pooling: {}'.format(self.pooling))

            batch_demo_feat.append(feat)

        demo_emb = torch.stack(batch_demo_feat, 0)
        # NOTE: batch size 8 and demo length 82 are hard-coded here.
        demo_emb = demo_emb.view(8, 82, -1)
        return demo_emb, batch_demo_feat
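
helper.fc_block is imported by both files but defined elsewhere in the commit. The sketch below is a hypothetical stand-in consistent with how it is called here — fc_block(in_dim, out_dim, use_norm, activation_class_or_None) — not the repository's actual implementation:

import torch.nn as nn

def fc_block(in_features, out_features, use_norm, activation):
    # Linear layer, optional batch norm (assumed; the real norm type is not
    # shown in this diff), and an optional activation passed as a class,
    # e.g. nn.Tanh or nn.ReLU, or None for a bare linear output.
    layers = [nn.Linear(in_features, out_features)]
    if use_norm:
        layers.append(nn.BatchNorm1d(out_features))
    if activation is not None:
        layers.append(activation())
    return nn.Sequential(*layers)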