# InferringIntention/keyboard_and_mouse/networks.py
# (snapshot date: 2024-03-24)
import torch
import torch.nn as nn
class fc_block(nn.Module):
    """A fully-connected block: Linear -> optional BatchNorm1d -> optional activation.

    Args:
        in_channels: input feature dimension.
        out_channels: output feature dimension.
        norm: when truthy, append a ``BatchNorm1d(out_channels)`` layer.
        activation_fn: an activation *class* (e.g. ``nn.ReLU``) instantiated
            with no arguments, or ``None`` for no activation.
    """

    def __init__(self, in_channels, out_channels, norm, activation_fn):
        super(fc_block, self).__init__()
        # Sub-module names ('linear', 'batchnorm', 'activation') are part of
        # the state_dict layout and must stay stable.
        layers = nn.Sequential()
        layers.add_module('linear', nn.Linear(in_channels, out_channels))
        if norm:
            layers.add_module('batchnorm', nn.BatchNorm1d(out_channels))
        if activation_fn is not None:
            layers.add_module('activation', activation_fn())
        self.block = layers

    def forward(self, x):
        """Apply linear / norm / activation in sequence."""
        return self.block(x)
class ActionDemoEncoder(nn.Module):
    """Encodes a batch of action-id demonstrations into per-step features.

    Each action id is embedded, passed through one fc_block, then pooled per
    step according to `pooling` ('max', 'avg', 'lstmavg' or 'lstmlast').
    The result is reshaped to (batch_size, 35, hidden).
    """

    def __init__(self, args, pooling):
        super(ActionDemoEncoder, self).__init__()
        hidden_size = args.demo_hidden
        self.hidden_size = hidden_size
        # Batch size is fixed at construction time; forward() relies on it
        # for the final reshape.
        self.bs = args.batch_size
        len_action_predicates = 35 # max_action_len
        self.action_embed = nn.Embedding(len_action_predicates, hidden_size)
        feat2hidden = nn.Sequential()
        feat2hidden.add_module(
            'fc_block1', fc_block(hidden_size, hidden_size, False, nn.ReLU))
        self.feat2hidden = feat2hidden
        self.pooling = pooling
        # NOTE(review): 'bilstmavg'/'bilstmlast' also pass this substring
        # check, but forward() has no branch for them and would raise
        # ValueError -- confirm whether bidirectional modes were ever wired up.
        if 'lstm' in self.pooling:
            self.lstm = nn.LSTM(hidden_size, hidden_size)

    def forward(self, batch_data):
        # Flatten to one action id per row; assumes batch_data holds exactly
        # batch_size * 35 ids in total -- TODO confirm against caller.
        batch_data = batch_data.view(-1,1)
        stacked_demo_feat = self.action_embed(batch_data)  # (N, 1, hidden)
        stacked_demo_feat = self.feat2hidden(stacked_demo_feat)
        batch_demo_feat = []
        start = 0  # NOTE(review): unused
        for length in range(0,batch_data.shape[0]):
            # NOTE(review): looks off-by-one -- row 0 is selected for both
            # length == 0 and length == 1, and the final row (N-1) is never
            # used. Confirm whether "previous step's feature" is intentional.
            if length == 0:
                feat = stacked_demo_feat[0:1, :]
            else:
                feat = stacked_demo_feat[(length-1):length, :]
            # Slices are 3-D here, so this promotes to 4-D before pooling
            # over dim 0 below.
            if len(feat.size()) == 3:
                feat = feat.unsqueeze(0)
            if self.pooling == 'max':
                feat = torch.max(feat, 0)[0]
            elif self.pooling == 'avg':
                feat = torch.mean(feat, 0)
            elif self.pooling == 'lstmavg':
                # Run the single-step slice through the LSTM, then average
                # the outputs over time.
                lstm_out, hidden = self.lstm(feat.view(len(feat), 1, -1))
                lstm_out = lstm_out.view(len(feat), -1)
                feat = torch.mean(lstm_out, 0)
            elif self.pooling == 'lstmlast':
                # Same as above but keep only the last timestep's output.
                lstm_out, hidden = self.lstm(feat.view(len(feat), 1, -1))
                lstm_out = lstm_out.view(len(feat), -1)
                feat = lstm_out[-1]
            else:
                raise ValueError
            batch_demo_feat.append(feat)
        demo_emb = torch.stack(batch_demo_feat, 0)
        # Assumes N == batch_size * 35, otherwise this view fails.
        demo_emb = demo_emb.view(self.bs, 35, -1)
        return demo_emb
class PredicateClassifier(nn.Module):
    """Maps stacked per-step demo embeddings to logits over 7 actions.

    Expects an input that flattens to (batch, hidden * 35); applies
    fc_block(Tanh) -> Dropout -> fc_block(linear, 7 outputs).
    """

    def __init__(self, args):
        super(PredicateClassifier, self).__init__()
        self.hidden_size = args.demo_hidden
        # Sub-module names are part of the state_dict layout; keep them stable.
        net = nn.Sequential()
        net.add_module(
            'fc_block1',
            fc_block(self.hidden_size * 35, self.hidden_size, False, nn.Tanh))
        net.add_module('dropout', nn.Dropout(args.dropout))
        # 7 is all possible actions
        net.add_module('fc_block2', fc_block(self.hidden_size, 7, False, None))
        self.classifier = net

    def forward(self, input_emb):
        flat = input_emb.view(-1, self.hidden_size * 35)
        return self.classifier(flat)
class ActionDemo2Predicate(nn.Module):
    """End-to-end model: ActionDemoEncoder (pooled per step) followed by a
    PredicateClassifier producing 7 action logits.

    The encoder is wrapped in DataParallel; both sub-modules are also
    registered on ``all_modules`` so the whole model can be saved / loaded
    through a single state_dict (see save/load).
    """

    # Pooling modes forwarded verbatim to ActionDemoEncoder; identical to the
    # original if/elif chain.
    POOLING_MODES = ('max', 'avg', 'lstmavg', 'bilstmavg', 'lstmlast',
                     'bilstmlast')

    def __init__(self, args, **kwargs):
        super(ActionDemo2Predicate, self).__init__()
        print('------------------------------------------------------------------------------------------')
        print('ActionDemo2Predicate')
        print('------------------------------------------------------------------------------------------')
        model_type = args.model_type
        print('model_type', model_type)
        pooling = model_type.lower()
        # Collapsed the six-way if/elif chain into a membership test; unknown
        # types still raise ValueError, now with a useful message.
        if pooling not in self.POOLING_MODES:
            raise ValueError('unknown model_type: {}'.format(model_type))
        demo_encoder = ActionDemoEncoder(args, pooling)
        demo_encoder = torch.nn.DataParallel(demo_encoder)
        predicate_decoder = PredicateClassifier(args)
        # Registered together for quick save and load via one state_dict.
        all_modules = nn.Sequential()
        all_modules.add_module('demo_encoder', demo_encoder)
        all_modules.add_module('predicate_decoder', predicate_decoder)
        self.demo_encoder = demo_encoder
        self.predicate_decoder = predicate_decoder
        self.all_modules = all_modules
        self.to_cuda_fn = None

    def set_to_cuda_fn(self, to_cuda_fn):
        """Install an optional callable that moves a batch onto the GPU."""
        self.to_cuda_fn = to_cuda_fn

    def forward(self, data, **kwargs):
        '''
        Note: The order of the `data` won't change in this function
        '''
        if self.to_cuda_fn:
            data = self.to_cuda_fn(data)
        batch_demo_emb = self.demo_encoder(data)
        pred = self.predicate_decoder(batch_demo_emb)
        return pred

    def write_summary(self, writer, info, postfix):
        """Log every known summary key present in `info` to `writer`."""
        model_name = 'Demo2Predicate-{}/'.format(postfix)
        # BUG FIX: `self.summary_keys` is never assigned anywhere in this
        # file, so the original raised AttributeError here; default to no
        # keys so the call degrades gracefully.
        for k in getattr(self, 'summary_keys', []):
            if k in info.keys():
                writer.scalar_summary(model_name + k, info[k])

    def save(self, path, verbose=False):
        """Save all sub-module weights (one state_dict) to `path`."""
        if verbose:
            print(self._colored('[*] Save model at {}'.format(path), 'magenta'))
        torch.save(self.all_modules.state_dict(), path)

    def load(self, path, verbose=False):
        """Load weights saved by save(), mapping tensors onto the CPU."""
        if verbose:
            print(self._colored('[*] Load model at {}'.format(path), 'magenta'))
        self.all_modules.load_state_dict(
            torch.load(
                path,
                map_location=lambda storage, loc: storage))

    @staticmethod
    def _colored(text, color):
        """Colorize `text` if termcolor is available, else return it as-is.

        BUG FIX: the original called `colored(...)` without ever importing
        it, so any verbose save/load raised NameError.
        """
        try:
            from termcolor import colored  # optional dependency
        except ImportError:
            return text
        return colored(text, color)