first commit

Lei Shi 2024-03-24 23:42:27 +01:00
commit 83b04e2133
109 changed files with 12081 additions and 0 deletions

@@ -0,0 +1,131 @@
import copy
import numpy as np
from termcolor import colored
import torch
import torch.nn as nn
import torch.nn.functional as F
def _align_tensor_index(reference_index, tensor_index):
    where_in_tensor = []
    for i in reference_index:
        where = np.where(i == tensor_index)[0][0]
        where_in_tensor.append(where)
    return np.array(where_in_tensor)
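# Illustrative example (hypothetical values): for reference_index = [3, 1, 2]
# and tensor_index = np.array([1, 2, 3]), the function returns array([2, 0, 1]),
# i.e. the position of each reference id inside tensor_index.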
def _sort_by_length(list_of_tensor, batch_length, return_idx=False):
    # reorder every tensor in the list by descending demo length
    idx = np.argsort(np.array(batch_length))[::-1]
    for i, tensor in enumerate(list_of_tensor):
        if isinstance(tensor, dict):
            for key in ('class_objects', 'object_coords', 'states_objects', 'mask_object'):
                list_of_tensor[i][key] = [tensor[key][j] for j in idx]
        else:
            list_of_tensor[i] = [tensor[j] for j in idx]
    if return_idx:
        return list_of_tensor, idx
    else:
        return list_of_tensor
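# Example (hypothetical values): batch_length = [2, 5, 3] gives
# idx = array([1, 2, 0]), i.e. samples reordered as lengths [5, 3, 2],
# the descending layout that packed-RNN style code expects.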
def _sort_by_index(list_of_tensor, idx):
    for i, tensor in enumerate(list_of_tensor):
        list_of_tensor[i] = [tensor[j] for j in idx]
    return list_of_tensor
class ActionDemo2Predicate(nn.Module):
    summary_keys = ['loss', 'top1']

    def __init__(self, args, dset, loss_weight, **kwargs):
        from network.module_graph import PredicateClassifier, ActionDemoEncoder
        super(ActionDemo2Predicate, self).__init__()

        print('------------------------------------------------------------------------------------------')
        print('ActionDemo2Predicate')
        print('------------------------------------------------------------------------------------------')

        model_type = kwargs["model_type"]
        print('model_type', model_type)

        pooling = model_type.lower()
        if pooling not in ('max', 'avg', 'lstmavg', 'bilstmavg', 'lstmlast', 'bilstmlast'):
            raise ValueError('unknown model_type: {}'.format(model_type))
        demo_encoder = ActionDemoEncoder(args, dset, pooling)
        demo_encoder = torch.nn.DataParallel(demo_encoder)

        predicate_decoder = PredicateClassifier(args, dset, loss_weight)

        # for quick save and load
        all_modules = nn.Sequential()
        all_modules.add_module('demo_encoder', demo_encoder)
        all_modules.add_module('predicate_decoder', predicate_decoder)

        self.demo_encoder = demo_encoder
        self.predicate_decoder = predicate_decoder
        self.all_modules = all_modules
        self.to_cuda_fn = None

    def set_to_cuda_fn(self, to_cuda_fn):
        self.to_cuda_fn = to_cuda_fn
    def forward(self, data, **kwargs):
        if self.to_cuda_fn:
            data = self.to_cuda_fn(data)

        # demonstration: action sequences, targets, and metadata
        batch_data = data[0].cuda()
        batch_gt = data[1].cuda()
        batch_task_name = data[2]
        batch_action_id = data[3]

        # encode the demonstration, then classify the goal predicate
        batch_demo_emb, _ = self.demo_encoder(batch_data, batch_gt, batch_task_name)
        loss, info = self.predicate_decoder(batch_demo_emb, batch_gt, batch_action_id, batch_task_name)
        return loss, info
    def write_summary(self, writer, info, postfix):
        model_name = 'Demo2Predicate-{}/'.format(postfix)
        for k in self.summary_keys:
            if k in info.keys():
                writer.scalar_summary(model_name + k, info[k])

    def save(self, path, verbose=False):
        if verbose:
            print(colored('[*] Save model at {}'.format(path), 'magenta'))
        torch.save(self.all_modules.state_dict(), path)

    def load(self, path, verbose=False):
        if verbose:
            print(colored('[*] Load model at {}'.format(path), 'magenta'))
        self.all_modules.load_state_dict(
            torch.load(path, map_location=lambda storage, loc: storage))
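For orientation, a minimal usage sketch of the wrapper above follows. It assumes a CUDA device (the modules call .cuda() internally) and uses hypothetical args/dset stand-ins and random data; none of these values come from the repository:

import torch
from types import SimpleNamespace

# hypothetical configuration and dataset descriptors
args = SimpleNamespace(demo_hidden=128, dropout=0.0, loss_type='ce')
dset = SimpleNamespace(num_goal_predicates=10, max_goal_length=5, max_action_len=500)
loss_weight = [1.0 / 79] * 79  # uniform reference distribution over the 79 actions

model = ActionDemo2Predicate(args, dset, loss_weight, model_type='max')

# one fake batch: 8 demos x 82 action ids, matching the hardcoded 8 x 82 layout
data = (
    torch.randint(0, 500, (8, 82)),  # action-id sequences
    torch.randint(0, 79, (8,)),      # ground-truth action labels
    ['task'] * 8,                    # task names
    torch.arange(8),                 # action ids for bookkeeping
)
loss, info = model(data)
loss.backward()
model.save('checkpoint.pt', verbose=True)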

@@ -0,0 +1,249 @@
import random
import itertools
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from helper import fc_block, Constant
def _calculate_accuracy_predicate(logits, batch_target, max_possible_count=None, topk=1, multi_classifier=False):
    # a sample counts as correct only if all of its max_possible_count
    # predicate slots are predicted correctly
    batch_size = batch_target.size(0) // max_possible_count
    _, pred = logits.topk(topk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(batch_target.view(1, -1).expand_as(pred))
    accuracy = correct[:1].view(-1).float()
    accuracy = accuracy.view(-1, max_possible_count)
    correct_k = (accuracy.sum(1) == max_possible_count).sum(0)
    correct_k = correct_k * (100.0 / batch_size)
    return correct_k
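# Example (hypothetical values): with max_possible_count = 2, four targets form
# two samples; a sample only scores when both of its slots are predicted
# correctly, so the returned percentage is 0, 50, or 100.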
def _calculate_accuracy(
        action_correct,
        object_correct,
        rel_correct,
        target_correct,
        batch_length,
        info):
    action_valid_correct = [sum(action_correct[i, :(l - 1)])
                            for i, l in enumerate(batch_length)]
    object_valid_correct = [sum(object_correct[i, :(l - 1)])
                            for i, l in enumerate(batch_length)]
    rel_valid_correct = [sum(rel_correct[i, :(l - 1)])
                         for i, l in enumerate(batch_length)]
    target_valid_correct = [sum(target_correct[i, :(l - 1)])
                            for i, l in enumerate(batch_length)]

    # each sequence contributes (length - 1) valid prediction steps
    denominator = sum(batch_length) - 1. * len(batch_length)
    action_accuracy = sum(action_valid_correct).float() / denominator
    object_accuracy = sum(object_valid_correct).float() / denominator
    rel_accuracy = sum(rel_valid_correct).float() / denominator
    target_accuracy = sum(target_valid_correct).float() / denominator

    info.update({'action_accuracy': action_accuracy.cpu().item()})
    info.update({'object_accuracy': object_accuracy.cpu().item()})
    info.update({'rel_accuracy': rel_accuracy.cpu().item()})
    info.update({'target_accuracy': target_accuracy.cpu().item()})
class PredicateClassifier(nn.Module):

    def __init__(self, args, dset, loss_weight):
        super(PredicateClassifier, self).__init__()
        self.num_goal_predicates = dset.num_goal_predicates
        self.max_possible_count = dset.max_goal_length
        self.loss = custom_loss(loss_weight)
        hidden_size = args.demo_hidden
        self.hidden_size = hidden_size
        self.loss_type = args.loss_type

        print('dropout', args.dropout)
        classifier = nn.Sequential()
        # 82 is the fixed number of action steps per demonstration,
        # 79 is the number of possible actions
        classifier.add_module('fc_block1', fc_block(hidden_size * 82, hidden_size, False, nn.Tanh))
        if args.dropout != 0:
            classifier.add_module('dropout', nn.Dropout(args.dropout))
        classifier.add_module('fc_block2', fc_block(hidden_size, 79, False, None))
        self.classifier = classifier

    def forward(self, input_emb, batch_target, batch_action_id, batch_task_name, **kwargs):
        input_emb = input_emb.view(-1, self.hidden_size * 82)
        logits = self.classifier(input_emb)
        prob = F.softmax(logits, 1)

        if self.loss_type == 'ce':        # plain cross-entropy loss
            loss = F.cross_entropy(logits, batch_target)
        elif self.loss_type == 'regu':    # cross-entropy plus distribution regularization
            loss = self.loss(logits, batch_target)
        else:
            raise ValueError('unknown loss_type: {}'.format(self.loss_type))

        argmax_Y = torch.max(logits, 1)[1].view(-1, 1)
        top1 = (batch_target.float().view(-1, 1) == argmax_Y.float()).sum().item() / batch_target.size(0) * 100

        with torch.no_grad():
            info = {
                "prob": prob.cpu().numpy(),
                "argmax": argmax_Y.cpu().numpy(),
                "loss": loss.cpu().numpy(),
                "top1": top1,
                "target": batch_target.cpu().numpy(),
                "task_name": batch_task_name,
                "action_id": batch_action_id.cpu().numpy(),
            }
        return loss, info
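# Shape sketch (hypothetical size): with demo_hidden = 128 the classifier
# consumes flattened demo embeddings of shape (batch, 128 * 82) and emits
# logits of shape (batch, 79), one score per possible action.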
class custom_loss(nn.Module):

    def __init__(self, loss_weight) -> None:
        super(custom_loss, self).__init__()
        self.loss = nn.CrossEntropyLoss().cuda()
        self.weight_loss = nn.MSELoss().cuda()
        self.loss_weight = loss_weight
        self.counts = torch.FloatTensor(self.loss_weight).cuda()

    def forward(self, pred, target):
        # cross-entropy loss plus an MSE penalty that pushes the batch label
        # distribution towards the reference distribution in self.counts
        batch_counts = torch.bincount(target, minlength=79)  # 79 possible actions
        batch_counts = batch_counts / torch.sum(batch_counts)
        celoss = self.loss(pred, target)
        customloss = self.weight_loss(batch_counts, self.counts)
        print('celoss: ', celoss, 'customloss: ', customloss)
        loss = celoss + 1000 * customloss
        return loss
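# Example (hypothetical values): for target = tensor([0, 0, 1]) the normalized
# batch_counts begin as [2/3, 1/3, 0, ...] over the 79 classes; the MSE term
# then penalizes batches whose label mix drifts from self.counts.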
class PredicateClassifierMultiClassifier(nn.Module):

    def __init__(self, args, dset):
        super(PredicateClassifierMultiClassifier, self).__init__()
        self.num_goal_predicates = dset.num_goal_predicates
        self.max_possible_count = dset.max_goal_length
        self.max_subgoal_length = dset.max_subgoal_length

        hidden_size = args.demo_hidden
        print('hidden_size', hidden_size)
        print('PredicateClassifierMultiClassifier')

        # one (max_subgoal_length + 1)-way classifier per goal predicate
        print('dropout', args.dropout)
        classifier = nn.Sequential()
        classifier.add_module('fc_block1', fc_block(hidden_size, hidden_size, False, nn.Tanh))
        if args.dropout != 0:
            classifier.add_module('dropout', nn.Dropout(args.dropout))
        classifier.add_module('fc_block2', fc_block(hidden_size, self.num_goal_predicates * (self.max_subgoal_length + 1), False, None))
        self.classifier = classifier

    def forward(self, bs, input_emb, batch_target, batch_file_name, **kwargs):
        # bs is kept for interface compatibility but is not used here
        logits = self.classifier(input_emb)
        logits = logits.reshape([-1, (self.max_subgoal_length + 1)])
        prob = F.softmax(logits, 1)

        batch_target = torch.cat(batch_target)
        loss = F.cross_entropy(logits, batch_target)
        top1 = _calculate_accuracy_predicate(logits, batch_target, self.num_goal_predicates, multi_classifier=True)

        with torch.no_grad():
            info = {
                "prob": prob.cpu().numpy(),
                "loss": loss.cpu().numpy(),
                "top1": top1.cpu().numpy(),
                "target": batch_target.cpu().numpy(),
                "file_name": batch_file_name,
            }
        return loss, info
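# Shape sketch: fc_block2 emits num_goal_predicates * (max_subgoal_length + 1)
# scores per sample; the reshape treats each predicate as an independent
# (max_subgoal_length + 1)-way classification over subgoal counts.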
class ActionDemoEncoder(nn.Module):

    def __init__(self, args, dset, pooling):
        super(ActionDemoEncoder, self).__init__()
        hidden_size = args.demo_hidden
        self.hidden_size = hidden_size

        len_action_predicates = dset.max_action_len
        self.action_embed = nn.Embedding(len_action_predicates, hidden_size)

        feat2hidden = nn.Sequential()
        feat2hidden.add_module(
            'fc_block1', fc_block(hidden_size, hidden_size, False, nn.ReLU))
        self.feat2hidden = feat2hidden

        self.pooling = pooling
        if 'lstm' in self.pooling:
            # bidirectional variants halve the per-direction hidden size so
            # the concatenated output stays at hidden_size
            bidirectional = self.pooling.startswith('bi')
            lstm_hidden = hidden_size // 2 if bidirectional else hidden_size
            self.lstm = nn.LSTM(hidden_size, lstm_hidden, bidirectional=bidirectional)

    def forward(self, batch_data, batch_gt, batch_task_name):
        batch_data = batch_data.view(-1, 1)
        stacked_demo_feat = self.action_embed(batch_data)
        stacked_demo_feat = self.feat2hidden(stacked_demo_feat)

        batch_demo_feat = []
        for step in range(batch_data.shape[0]):
            # note: steps 0 and 1 both read the first embedding, as in the
            # original indexing
            if step == 0:
                feat = stacked_demo_feat[0:1, :]
            else:
                feat = stacked_demo_feat[(step - 1):step, :]
            if len(feat.size()) == 3:
                feat = feat.unsqueeze(0)

            if self.pooling == 'max':
                feat = torch.max(feat, 0)[0]
            elif self.pooling == 'avg':
                feat = torch.mean(feat, 0)
            elif self.pooling in ('lstmavg', 'bilstmavg'):
                lstm_out, hidden = self.lstm(feat.view(len(feat), 1, -1))
                lstm_out = lstm_out.view(len(feat), -1)
                feat = torch.mean(lstm_out, 0)
            elif self.pooling in ('lstmlast', 'bilstmlast'):
                lstm_out, hidden = self.lstm(feat.view(len(feat), 1, -1))
                lstm_out = lstm_out.view(len(feat), -1)
                feat = lstm_out[-1]
            else:
                raise ValueError('unknown pooling: {}'.format(self.pooling))

            batch_demo_feat.append(feat)

        demo_emb = torch.stack(batch_demo_feat, 0)
        # hardcoded layout: batch size 8, demo length 82
        demo_emb = demo_emb.view(8, 82, -1)
        return demo_emb, batch_demo_feat
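As a quick CPU-side sanity check, the encoder can be exercised on its own; the sizes below are hypothetical and only need to respect the hardcoded 8 x 82 layout:

import torch
from types import SimpleNamespace

# hypothetical stand-ins for the real args/dset objects
args = SimpleNamespace(demo_hidden=64)
dset = SimpleNamespace(max_action_len=500)

encoder = ActionDemoEncoder(args, dset, pooling='lstmlast')
batch_data = torch.randint(0, 500, (8, 82))  # 8 demos x 82 action ids
demo_emb, per_step_feats = encoder(batch_data, None, None)
print(demo_emb.shape)  # torch.Size([8, 82, 64])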