import torch
import torch.nn as nn


class fc_block(nn.Module):
    """Linear layer optionally followed by batch norm and an activation.

    Args:
        in_channels: input feature size of the linear layer.
        out_channels: output feature size of the linear layer.
        norm: when truthy, a ``BatchNorm1d(out_channels)`` follows the linear.
        activation_fn: zero-argument callable returning a module (e.g.
            ``nn.ReLU``), or ``None`` for no activation.
    """

    def __init__(self, in_channels, out_channels, norm, activation_fn):
        super(fc_block, self).__init__()
        # Sub-module names are part of the state_dict keys; keep them stable.
        block = nn.Sequential()
        block.add_module('linear', nn.Linear(in_channels, out_channels))
        if norm:
            block.add_module('batchnorm', nn.BatchNorm1d(out_channels))
        if activation_fn is not None:
            block.add_module('activation', activation_fn())
        self.block = block

    def forward(self, x):
        """Apply the linear / norm / activation stack to ``x``."""
        return self.block(x)


class ActionDemoEncoder(nn.Module):
    """Embed every action token of a demonstration into a hidden vector.

    Args:
        args: configuration object; ``args.demo_hidden`` (hidden size) and
            ``args.batch_size`` are read.
        pooling: one of ``'max'``, ``'avg'``, ``'lstmavg'``, ``'lstmlast'``.
            Any name containing ``'lstm'`` allocates an ``nn.LSTM``, but only
            the four names above are implemented by ``forward``.
    """

    # Size of the embedding table (number of distinct token ids accepted).
    NUM_ACTION_PREDICATES = 35
    # Number of tokens per demonstration in the output layout.
    MAX_ACTION_LEN = 35

    def __init__(self, args, pooling):
        super(ActionDemoEncoder, self).__init__()
        hidden_size = args.demo_hidden
        self.hidden_size = hidden_size
        # Kept for backward compatibility; forward() no longer depends on it.
        self.bs = args.batch_size

        self.action_embed = nn.Embedding(
            self.NUM_ACTION_PREDICATES, hidden_size)

        feat2hidden = nn.Sequential()
        feat2hidden.add_module(
            'fc_block1', fc_block(hidden_size, hidden_size, False, nn.ReLU))
        self.feat2hidden = feat2hidden

        self.pooling = pooling
        if 'lstm' in self.pooling:
            self.lstm = nn.LSTM(hidden_size, hidden_size)

    def forward(self, batch_data):
        """Encode a batch of action-token ids.

        Args:
            batch_data: integer tensor of token ids in
                ``[0, NUM_ACTION_PREDICATES)`` whose element count is a
                multiple of ``MAX_ACTION_LEN`` (presumably shaped
                ``(batch, 35)`` -- confirm against the caller).

        Returns:
            Tensor of shape ``(num_tokens // 35, 35, hidden_size)`` where
            entry ``[b, t]`` is the encoding of token ``t`` of sample ``b``.

        Raises:
            ValueError: if ``self.pooling`` is not an implemented mode.
        """
        hidden = self.hidden_size
        token_ids = batch_data.reshape(-1, 1)
        # (num_tokens, 1, hidden)
        token_feat = self.feat2hidden(self.action_embed(token_ids))

        # Every token is its own length-1 sequence.  Lay the tokens out as
        # (seq_len=1, batch=num_tokens, hidden) so that all of them are
        # pooled in one call.  The previous per-token loop indexed row
        # ``i - 1`` for output ``i``, which duplicated the first token,
        # dropped the last one and leaked features across samples.
        seq = token_feat.transpose(0, 1)

        if self.pooling == 'max':
            pooled = torch.max(seq, 0)[0]
        elif self.pooling == 'avg':
            pooled = torch.mean(seq, 0)
        elif self.pooling == 'lstmavg':
            lstm_out, _ = self.lstm(seq)
            pooled = torch.mean(lstm_out, 0)
        elif self.pooling == 'lstmlast':
            lstm_out, _ = self.lstm(seq)
            pooled = lstm_out[-1]
        else:
            raise ValueError(
                'unsupported pooling mode: {!r}'.format(self.pooling))

        # Infer the batch dimension from the data instead of args.batch_size
        # so partial batches and DataParallel shards are handled.
        return pooled.reshape(-1, self.MAX_ACTION_LEN, hidden)
class PredicateClassifier(nn.Module):
    """Classify a flattened demonstration embedding into action logits.

    Args:
        args: configuration object; ``args.demo_hidden`` (hidden size) and
            ``args.dropout`` (dropout probability) are read.
    """

    # Number of per-token embeddings concatenated into the classifier input.
    MAX_ACTION_LEN = 35
    # 7 is all possible actions
    NUM_ACTIONS = 7

    def __init__(self, args,):
        super(PredicateClassifier, self).__init__()
        hidden_size = args.demo_hidden
        self.hidden_size = hidden_size

        classifier = nn.Sequential()
        classifier.add_module(
            'fc_block1',
            fc_block(hidden_size * self.MAX_ACTION_LEN, hidden_size,
                     False, nn.Tanh))
        classifier.add_module('dropout', nn.Dropout(args.dropout))
        classifier.add_module(
            'fc_block2',
            fc_block(hidden_size, self.NUM_ACTIONS, False, None))
        self.classifier = classifier

    def forward(self, input_emb):
        """Flatten ``input_emb`` to ``(-1, hidden_size * 35)`` and classify.

        Returns:
            Tensor of shape ``(rows, 7)`` with unnormalised scores.
        """
        flat = input_emb.view(-1, self.hidden_size * self.MAX_ACTION_LEN)
        return self.classifier(flat)


class ActionDemo2Predicate(nn.Module):
    """Demo encoder followed by a predicate classifier.

    Args:
        args: configuration object; ``args.model_type`` selects the encoder
            pooling mode (case-insensitive) and the object is forwarded to
            ``ActionDemoEncoder`` and ``PredicateClassifier``.
        **kwargs: accepted and ignored.

    Raises:
        ValueError: if ``args.model_type`` is not one of
            ``SUPPORTED_MODEL_TYPES``.
    """

    # NOTE(review): the 'bilstm*' names are accepted here as before; whether
    # the encoder actually implements them should be confirmed.
    SUPPORTED_MODEL_TYPES = (
        'max', 'avg', 'lstmavg', 'bilstmavg', 'lstmlast', 'bilstmlast')

    def __init__(self, args, **kwargs):
        super(ActionDemo2Predicate, self).__init__()
        banner = '------------------------------------------------------------------------------------------'
        print(banner)
        print('ActionDemo2Predicate')
        print(banner)

        model_type = args.model_type
        print('model_type', model_type)

        pooling = model_type.lower()
        if pooling not in self.SUPPORTED_MODEL_TYPES:
            raise ValueError(
                'unsupported model_type: {!r}'.format(model_type))
        demo_encoder = ActionDemoEncoder(args, pooling)
        demo_encoder = torch.nn.DataParallel(demo_encoder)

        predicate_decoder = PredicateClassifier(args)

        # for quick save and load
        all_modules = nn.Sequential()
        all_modules.add_module('demo_encoder', demo_encoder)
        all_modules.add_module('predicate_decoder', predicate_decoder)

        self.demo_encoder = demo_encoder
        self.predicate_decoder = predicate_decoder
        self.all_modules = all_modules
        self.to_cuda_fn = None

    @staticmethod
    def _highlight(message):
        """Return ``message`` in magenta when termcolor is installed.

        The file never imported ``colored``, so verbose save/load used to
        raise NameError; fall back to the plain message when termcolor is
        unavailable.
        """
        try:
            from termcolor import colored
        except ImportError:
            return message
        return colored(message, 'magenta')

    def set_to_cuda_fn(self, to_cuda_fn):
        """Register a callable applied to ``data`` at the start of forward."""
        self.to_cuda_fn = to_cuda_fn

    def forward(self, data, **kwargs):
        '''
            Note: The order of the `data` won't change in this function
        '''
        if self.to_cuda_fn:
            data = self.to_cuda_fn(data)

        batch_demo_emb = self.demo_encoder(data)
        pred = self.predicate_decoder(batch_demo_emb)
        return pred

    def write_summary(self, writer, info, postfix):
        """Write the entries of ``info`` listed in ``self.summary_keys``.

        ``summary_keys`` is not set by ``__init__``; when nobody has assigned
        it, nothing is written instead of raising AttributeError.

        Args:
            writer: object exposing ``scalar_summary(tag, value)``.
            info: mapping from key to scalar value.
            postfix: suffix used in the ``Demo2Predicate-<postfix>/`` tag.
        """
        model_name = 'Demo2Predicate-{}/'.format(postfix)
        for k in getattr(self, 'summary_keys', ()):
            if k in info:
                writer.scalar_summary(model_name + k, info[k])

    def save(self, path, verbose=False):
        """Save the state dict of ``all_modules`` to ``path``."""
        if verbose:
            print(self._highlight('[*] Save model at {}'.format(path)))
        torch.save(self.all_modules.state_dict(), path)

    def load(self, path, verbose=False):
        """Load the state dict of ``all_modules`` from ``path``.

        Storages are kept where ``torch.load`` deserialises them (no device
        remapping) via the identity ``map_location``.
        """
        if verbose:
            print(self._highlight('[*] Load model at {}'.format(path)))
        self.all_modules.load_state_dict(
            torch.load(
                path,
                map_location=lambda storage, loc: storage))