import torch
import torch.nn as nn
from termcolor import colored  # used by ActionDemo2Predicate.save()/load()

class fc_block(nn.Module):
    """Linear layer followed by an optional BatchNorm1d and an optional activation."""

    def __init__(self, in_channels, out_channels, norm, activation_fn):
        super(fc_block, self).__init__()
        block = nn.Sequential()
        block.add_module('linear', nn.Linear(in_channels, out_channels))
        if norm:
            block.add_module('batchnorm', nn.BatchNorm1d(out_channels))
        if activation_fn is not None:
            block.add_module('activation', activation_fn())

        self.block = block

    def forward(self, x):
        return self.block(x)
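
# A minimal usage sketch of fc_block in isolation (hypothetical sizes, not part of
# the original training code): a 16 -> 32 linear layer with ReLU and no batch norm.
#
#   block = fc_block(16, 32, norm=False, activation_fn=nn.ReLU)
#   y = block(torch.randn(4, 16))   # y.shape == torch.Size([4, 32])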

class ActionDemoEncoder(nn.Module):
    """Embeds a demonstration's action sequence and pools it into per-step features."""

    def __init__(self, args, pooling):
        super(ActionDemoEncoder, self).__init__()
        hidden_size = args.demo_hidden
        self.hidden_size = hidden_size
        self.bs = args.batch_size

        len_action_predicates = 35  # max_action_len
        self.action_embed = nn.Embedding(len_action_predicates, hidden_size)

        feat2hidden = nn.Sequential()
        feat2hidden.add_module(
            'fc_block1', fc_block(hidden_size, hidden_size, False, nn.ReLU))
        self.feat2hidden = feat2hidden

        self.pooling = pooling

        if 'lstm' in self.pooling:
            self.lstm = nn.LSTM(hidden_size, hidden_size)

    def forward(self, batch_data):
        # Flatten the batch to one action id per row, embed it, and project it.
        batch_data = batch_data.view(-1, 1)
        stacked_demo_feat = self.action_embed(batch_data)
        stacked_demo_feat = self.feat2hidden(stacked_demo_feat)
        batch_demo_feat = []

        # Pool a one-step slice of features for every position in the flattened batch.
        # Positions after the first read the feature at index (length - 1), so index 0
        # is reused for the first two positions and the final embedding is never read.
        for length in range(0, batch_data.shape[0]):
            if length == 0:
                feat = stacked_demo_feat[0:1, :]
            else:
                feat = stacked_demo_feat[(length - 1):length, :]
            if len(feat.size()) == 3:
                feat = feat.unsqueeze(0)

            if self.pooling == 'max':
                feat = torch.max(feat, 0)[0]
            elif self.pooling == 'avg':
                feat = torch.mean(feat, 0)
            elif self.pooling == 'lstmavg':
                lstm_out, hidden = self.lstm(feat.view(len(feat), 1, -1))
                lstm_out = lstm_out.view(len(feat), -1)
                feat = torch.mean(lstm_out, 0)
            elif self.pooling == 'lstmlast':
                lstm_out, hidden = self.lstm(feat.view(len(feat), 1, -1))
                lstm_out = lstm_out.view(len(feat), -1)
                feat = lstm_out[-1]
            else:
                raise ValueError('unsupported pooling: {}'.format(self.pooling))

            batch_demo_feat.append(feat)

        # (batch_size * 35, ...) -> (batch_size, 35, hidden)
        demo_emb = torch.stack(batch_demo_feat, 0)
        demo_emb = demo_emb.view(self.bs, 35, -1)
        return demo_emb
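
# A minimal sketch of the encoder in isolation (hypothetical values, assuming an args
# object with demo_hidden=128 and batch_size=2): the input is a LongTensor of action
# ids whose total number of elements is batch_size * 35.
#
#   import argparse
#   enc = ActionDemoEncoder(argparse.Namespace(demo_hidden=128, batch_size=2), 'max')
#   ids = torch.randint(0, 35, (2, 35))
#   emb = enc(ids)   # emb.shape == torch.Size([2, 35, 128])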

class PredicateClassifier(nn.Module):
    """Maps the flattened (batch_size, 35, hidden) demo embedding to action logits."""

    def __init__(self, args):
        super(PredicateClassifier, self).__init__()
        hidden_size = args.demo_hidden
        self.hidden_size = hidden_size

        classifier = nn.Sequential()
        classifier.add_module('fc_block1', fc_block(hidden_size * 35, hidden_size, False, nn.Tanh))
        classifier.add_module('dropout', nn.Dropout(args.dropout))
        classifier.add_module('fc_block2', fc_block(hidden_size, 7, False, None))  # 7 is all possible actions

        self.classifier = classifier

    def forward(self, input_emb):
        input_emb = input_emb.view(-1, self.hidden_size * 35)
        return self.classifier(input_emb)
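
# A minimal sketch of the classifier on its own (hypothetical values, assuming an
# args object with demo_hidden=128 and dropout=0.5):
#
#   import argparse
#   clf = PredicateClassifier(argparse.Namespace(demo_hidden=128, dropout=0.5))
#   logits = clf(torch.randn(4, 35, 128))   # logits.shape == torch.Size([4, 7])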

class ActionDemo2Predicate(nn.Module):
    """End-to-end model: action-demo encoder followed by the predicate classifier."""

    def __init__(self, args, **kwargs):
        super(ActionDemo2Predicate, self).__init__()

        print('------------------------------------------------------------------------------------------')
        print('ActionDemo2Predicate')
        print('------------------------------------------------------------------------------------------')

        model_type = args.model_type
        print('model_type', model_type)

        if model_type.lower() == 'max':
            demo_encoder = ActionDemoEncoder(args, 'max')
        elif model_type.lower() == 'avg':
            demo_encoder = ActionDemoEncoder(args, 'avg')
        elif model_type.lower() == 'lstmavg':
            demo_encoder = ActionDemoEncoder(args, 'lstmavg')
        elif model_type.lower() == 'bilstmavg':
            demo_encoder = ActionDemoEncoder(args, 'bilstmavg')
        elif model_type.lower() == 'lstmlast':
            demo_encoder = ActionDemoEncoder(args, 'lstmlast')
        elif model_type.lower() == 'bilstmlast':
            demo_encoder = ActionDemoEncoder(args, 'bilstmlast')
        else:
            raise ValueError('unsupported model_type: {}'.format(model_type))
        demo_encoder = torch.nn.DataParallel(demo_encoder)

        predicate_decoder = PredicateClassifier(args)

        # for quick save and load
        all_modules = nn.Sequential()
        all_modules.add_module('demo_encoder', demo_encoder)
        all_modules.add_module('predicate_decoder', predicate_decoder)

        self.demo_encoder = demo_encoder
        self.predicate_decoder = predicate_decoder
        self.all_modules = all_modules
        self.to_cuda_fn = None

    def set_to_cuda_fn(self, to_cuda_fn):
        self.to_cuda_fn = to_cuda_fn

    def forward(self, data, **kwargs):
        '''
        Note: the order of `data` is not changed by this function.
        '''
        if self.to_cuda_fn:
            data = self.to_cuda_fn(data)

        batch_demo_emb = self.demo_encoder(data)
        pred = self.predicate_decoder(batch_demo_emb)
        return pred

    def write_summary(self, writer, info, postfix):
        # `self.summary_keys` is expected to be provided externally; it is not defined in this file.
        model_name = 'Demo2Predicate-{}/'.format(postfix)
        for k in self.summary_keys:
            if k in info.keys():
                writer.scalar_summary(model_name + k, info[k])

    def save(self, path, verbose=False):
        if verbose:
            print(colored('[*] Save model at {}'.format(path), 'magenta'))
        torch.save(self.all_modules.state_dict(), path)

    def load(self, path, verbose=False):
        if verbose:
            print(colored('[*] Load model at {}'.format(path), 'magenta'))
        self.all_modules.load_state_dict(
            torch.load(path, map_location=lambda storage, loc: storage))
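

# A minimal end-to-end sketch (hypothetical hyperparameters, not part of the original
# file): build the model from a Namespace carrying the fields the classes above read
# (model_type, demo_hidden, batch_size, dropout) and run one forward pass.
if __name__ == '__main__':
    import argparse

    args = argparse.Namespace(model_type='max', demo_hidden=128, batch_size=2, dropout=0.5)
    model = ActionDemo2Predicate(args)

    demo_actions = torch.randint(0, 35, (2, 35))   # 2 demos, 35 action ids each
    if torch.cuda.is_available():
        model = model.cuda()
        demo_actions = demo_actions.cuda()

    logits = model(demo_actions)
    print(logits.shape)   # expected: torch.Size([2, 7])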