from argparse import ArgumentParser

import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader

from tom.dataset import *
from tom.transformer import TransformerEncoder
from tom.gnn import RGATv2, RGATv3, RGATv3Agent, RGATv4, RGATv4Norm, RSAGEv4, RAGNNv4


class MlpModel(nn.Module):
    """Multilayer perceptron with an optional linear output layer.

    Args:
        input_size (int): number of input features.
        hidden_sizes (int | Sequence[int] | None): widths of the hidden layers.
            An int means a single hidden layer; ``None`` or an empty sequence
            means no hidden layers (a purely linear model when ``output_size``
            is given, an identity model otherwise).
        output_size (int | None): size of a final linear layer. If ``None`` no
            final linear layer is added and the last hidden layer (with its
            nonlinearity applied) is the output.
        nonlinearity: torch nonlinearity Module class (not Functional); one
            instance is created per hidden layer.
        dropout (float | None): if not ``None``, ``nn.Dropout(dropout)`` is
            inserted after every hidden-layer nonlinearity.
    """

    def __init__(
            self,
            input_size,
            hidden_sizes,  # int, sequence, empty sequence or None.
            output_size=None,  # If None, last layer has nonlinearity applied.
            nonlinearity=nn.ReLU,  # Module, not Functional.
            dropout=None  # Dropout probability, or None for no dropout.
    ):
        super().__init__()
        if isinstance(hidden_sizes, int):
            hidden_sizes = [hidden_sizes]
        elif hidden_sizes is None:
            hidden_sizes = []
        else:
            # Copy into a list so tuples (or any sequence) can be concatenated
            # with ``[input_size]`` below.
            hidden_sizes = list(hidden_sizes)

        # Layers are created in the same order as before (hidden Linears first,
        # then the output Linear), so seeded initialisation is unchanged.
        hidden_layers = [nn.Linear(n_in, n_out)
                         for n_in, n_out in zip([input_size] + hidden_sizes[:-1], hidden_sizes)]
        sequence = []
        for layer in hidden_layers:
            sequence.extend([layer, nonlinearity()])
            if dropout is not None:
                sequence.append(nn.Dropout(dropout))

        last_size = hidden_sizes[-1] if hidden_sizes else input_size
        if output_size is not None:
            sequence.append(nn.Linear(last_size, output_size))
        self.model = nn.Sequential(*sequence)
        # With neither hidden layers nor an output layer the model is the
        # identity, so its output size is the input size (this used to raise
        # IndexError on ``hidden_sizes[-1]``).
        self._output_size = last_size if output_size is None else output_size

    def forward(self, input):
        """Apply the MLP to ``input`` of shape [B, input_size] (nn.Linear also
        accepts extra leading dimensions)."""
        return self.model(input)

    @property
    def output_size(self):
        """Returns the output size of the model."""
        return self._output_size
# ---------------------------------------------------------------------------------------------------------------------------------
class GraphBCRNN(pl.LightningModule):
    """Baseline behaviour-cloning model (BC-RNN).

    A relational GNN encodes every frame; a 2-layer bidirectional LSTM encodes
    each familiarization trial (frame embedding concatenated with the action);
    the per-trial encodings are averaged into a context vector which, together
    with the query-frame embedding, is fed to an MLP whose ``tanh`` output is
    the predicted action.
    """

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a parser extending ``parent_parser`` with this model's arguments."""
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--lr', type=float, default=1e-4)
        parser.add_argument('--action_dim', type=int, default=2)
        parser.add_argument('--context_dim', type=int, default=32)  # lstm hidden size
        parser.add_argument('--beta', type=float, default=0.01)
        parser.add_argument('--dropout', type=float, default=0.2)
        parser.add_argument('--process_data', type=int, default=0)
        parser.add_argument('--max_len', type=int, default=30)
        # arguments for gnn
        parser.add_argument('--gnn_type', type=str, default='RGATv4')
        parser.add_argument('--state_dim', type=int, default=128)  # gnn out_feats
        # ``type=list`` would split a CLI value into single characters
        # ("18" -> ['1', '8']); accept a whitespace-separated list of ints.
        parser.add_argument('--feats_dims', type=int, nargs='+', default=[9, 2, 3, 18])
        parser.add_argument('--aggregation', type=str, default='sum')
        return parser

    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        self.lr = self.hparams.lr
        self.state_dim = self.hparams.state_dim
        self.action_dim = self.hparams.action_dim
        self.context_dim = self.hparams.context_dim
        self.beta = self.hparams.beta
        self.dropout = self.hparams.dropout
        self.max_len = self.hparams.max_len
        self.feats_dims = self.hparams.feats_dims  # type, position, color, shape
        self.rel_names = [
            'is_aligned', 'is_back', 'is_close', 'is_down_adj', 'is_down_left_adj',
            'is_down_right_adj', 'is_front', 'is_left', 'is_left_adj', 'is_right',
            'is_right_adj', 'is_top_adj', 'is_top_left_adj', 'is_top_right_adj'
        ]
        self.gnn_aggregation = self.hparams.aggregation
        self.gnn_encoder = self._build_gnn_encoder(self.hparams.gnn_type)

        n_feats = len(self.feats_dims)
        if self.gnn_aggregation == 'cat_axis_1':
            self.lstm_input_size = self.state_dim * n_feats + self.action_dim
            self.mlp_input_size = self.state_dim * n_feats + self.context_dim * 2
        elif self.gnn_aggregation == 'sum':
            self.lstm_input_size = self.state_dim + self.action_dim
            self.mlp_input_size = self.state_dim + self.context_dim * 2
        else:
            raise ValueError(
                f"Unknown aggregation '{self.gnn_aggregation}': "
                "only 'sum' and 'cat_axis_1' are available.")
        # 2 layers, bidirectional: the trial encoding is 2 * context_dim wide.
        self.context_enc = nn.LSTM(self.lstm_input_size, self.context_dim, 2,
                                   batch_first=True, bidirectional=True)
        self.policy = MlpModel(input_size=self.mlp_input_size, hidden_sizes=[256, 128, 256],
                               output_size=self.action_dim, dropout=self.dropout)
        self.past_samples = []

    def _build_gnn_encoder(self, gnn_type):
        """Instantiate the GNN frame encoder named by ``gnn_type``.

        Raises:
            ValueError: if ``gnn_type`` is not supported by this model.
        """
        common = dict(
            hidden_channels=self.state_dim,
            out_channels=self.state_dim,
            rel_names=self.rel_names,
            dropout=0.0,
            n_layers=2,
            activation=nn.ELU(),
        )
        attention_encoders = {
            'RGATv2': RGATv2,
            'RGATv3': RGATv3,
            'RGATv4': RGATv4,
            'RGATv3Agent': RGATv3Agent,
        }
        if gnn_type in attention_encoders:
            return attention_encoders[gnn_type](num_heads=4, residual=False, **common)
        if gnn_type == 'RSAGEv4':
            return RSAGEv4(**common)
        raise ValueError(
            f"Unknown gnn_type '{gnn_type}' for GraphBCRNN; expected one of "
            f"{sorted(attention_encoders) + ['RSAGEv4']}.")

    def forward(self, batch):
        """Predict the query action.

        ``batch`` is unpacked as ``(dem_frames, dem_actions, dem_lens,
        query_frame, target_action)``. ``dem_actions`` must be 4-D
        ``[b, l, s, action_dim]`` and ``dem_lens`` a per-sample list of
        per-trial frame counts. The assumption that ``dem_frames`` encodes to
        one row per frame, sample-major then trial then frame, matches how
        ``dem_lens`` is flattened here (TODO confirm against the collate fn).

        Returns:
            ``(target_action, predicted_action)``, both float, the prediction in [-1, 1].
        """
        dem_frames, dem_actions, dem_lens, query_frame, target_action = batch
        dem_actions = dem_actions.float()
        target_action = target_action.float()

        # [total num of frames, state_dim (* num features if cat_axis_1)]
        dem_states = self.gnn_encoder(dem_frames)

        b, l, s, _ = dem_actions.size()
        # One length per trial, flattened in (sample, trial) order.
        flat_lens = torch.tensor([t for dl in dem_lens for t in dl]).to(torch.int64).cpu().view(-1)

        # Segment the frame embeddings by trial and zero-pad to max_len. Walking
        # the flattened lengths keeps a running offset across samples; padding
        # sample by sample restarted at frame 0 and gave every sample the
        # states of sample 0.
        dem_states_new = self._sequence_to_padding(dem_states, flat_lens, self.max_len)
        # NOTE(review): assumes the padded action length s equals max_len.
        dem_states_new = dem_states_new.view(b * l, s, -1)
        dem_actions = dem_actions.view(b * l, s, -1)  # [b*l, s, action_dim]
        dem_traj = torch.cat([dem_states_new, dem_actions], dim=2)

        packed = nn.utils.rnn.pack_padded_sequence(dem_traj, flat_lens, batch_first=True,
                                                   enforce_sorted=False)
        _, (h_n, _) = self.context_enc(packed)
        # h_n is [num_layers * 2, b*l, context_dim]; the last two rows are the
        # top layer's forward and backward final states, taken at each trial's
        # true length. Reading the padded output at index -1 instead returned
        # zero padding for every trial shorter than the longest one.
        trial_enc = torch.cat([h_n[-2], h_n[-1]], dim=1)  # [b*l, 2 * context_dim]
        context = torch.mean(trial_enc.view(b, l, -1), dim=1)  # [b, 2 * context_dim]

        test_states = self.gnn_encoder(query_frame)
        test_context_states = torch.cat([context, test_states], dim=1)
        test_actions_pred = torch.tanh(self.policy(test_context_states))  # [b, action_dim]
        return target_action, test_actions_pred

    def _sequence_to_padding(self, x, lengths, max_length):
        """Split ``x`` into consecutive segments and zero-pad each to ``max_length``.

        Args:
            x: tensor whose first dimension holds the segments back to back.
            lengths: segment lengths (ints or 0-d integer tensors), in order.
            max_length: padded length of every segment.

        Returns:
            Tensor of shape ``(len(lengths), max_length) + x.shape[1:]``,
            allocated on ``x.device``.
        """
        padded = torch.zeros((len(lengths), max_length) + tuple(x.shape[1:]), device=x.device)
        offset = 0
        for i, n in enumerate(lengths):
            n = int(n)
            padded[i, :n] = x[offset:offset + n]
            offset += n
        return padded

    def training_step(self, batch, batch_idx):
        """MSE between predicted and target actions, logged as ``train_loss``."""
        test_actions, test_actions_pred = self.forward(batch)
        loss = F.mse_loss(test_actions_pred, test_actions)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """MSE between predicted and target actions, logged as ``val_loss``."""
        test_actions, test_actions_pred = self.forward(batch)
        loss = F.mse_loss(test_actions_pred, test_actions)
        self.log('val_loss', loss, on_epoch=True, logger=True)

    def configure_optimizers(self):
        """Plain Adam over all parameters with learning rate ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def train_dataloader(self):
        """Shuffled training loader over all ``hparams.types``."""
        train_dataset = ToMnetDGLDataset(path=self.hparams.data_path,
                                         types=self.hparams.types, mode='train')
        return DataLoader(dataset=train_dataset,
                          batch_size=self.hparams.batch_size,
                          collate_fn=collate_function_seq,
                          num_workers=self.hparams.num_workers,
                          shuffle=True)

    def val_dataloader(self):
        """One validation loader per entry of ``hparams.types``."""
        return [
            DataLoader(dataset=ToMnetDGLDataset(path=self.hparams.data_path, types=[t], mode='val'),
                       batch_size=self.hparams.batch_size,
                       collate_fn=collate_function_seq,
                       num_workers=self.hparams.num_workers,
                       shuffle=False)
            for t in self.hparams.types
        ]

    def configure_callbacks(self):
        """Keep a checkpoint of every epoch (``save_top_k=-1``, ``period=1``)."""
        checkpoint = ModelCheckpoint(
            dirpath=None,  # automatically set
            save_top_k=-1,
            period=1
        )
        return [checkpoint]
# ---------------------------------------------------------------------------------------------------------------------------------
class GraphBC_T(pl.LightningModule):
    """BC model with a GNN frame encoder, a Transformer trial encoder and an MLP policy.

    Each familiarization trial becomes a sequence of per-frame tokens (GNN
    frame embedding concatenated with the action, projected to ``d_model``).
    A learned CLS token is appended, and the Transformer output at the CLS
    position encodes the trial. Trial encodings are averaged into a context
    vector that is concatenated with the query-frame embedding and fed to the
    MLP policy, whose ``tanh`` output is the predicted action.
    """

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a parser extending ``parent_parser`` with this model's arguments."""
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--lr', type=float, default=1e-4)
        parser.add_argument('--action_dim', type=int, default=2)
        parser.add_argument('--context_dim', type=int, default=32)  # lstm hidden size
        parser.add_argument('--beta', type=float, default=0.01)
        parser.add_argument('--dropout', type=float, default=0.2)
        parser.add_argument('--process_data', type=int, default=0)
        parser.add_argument('--max_len', type=int, default=30)
        # arguments for gnn
        parser.add_argument('--state_dim', type=int, default=128)  # gnn out_feats
        # ``type=list`` would split a CLI value into single characters
        # ("18" -> ['1', '8']); accept a whitespace-separated list of ints.
        parser.add_argument('--feats_dims', type=int, nargs='+', default=[9, 2, 3, 18])
        parser.add_argument('--aggregation', type=str, default='cat_axis_1')
        parser.add_argument('--gnn_type', type=str, default='RGATv3')
        # arguments for transformer
        parser.add_argument('--d_model', type=int, default=128)
        parser.add_argument('--nhead', type=int, default=4)
        parser.add_argument('--dim_feedforward', type=int, default=512)
        parser.add_argument('--transformer_dropout', type=float, default=0.3)
        parser.add_argument('--transformer_activation', type=str, default='gelu')
        parser.add_argument('--num_encoder_layers', type=int, default=6)
        parser.add_argument('--transformer_norm_input', type=int, default=0)
        return parser

    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        self.lr = self.hparams.lr
        self.state_dim = self.hparams.state_dim
        self.action_dim = self.hparams.action_dim
        self.context_dim = self.hparams.context_dim
        self.beta = self.hparams.beta
        self.dropout = self.hparams.dropout
        self.max_len = self.hparams.max_len
        self.feats_dims = self.hparams.feats_dims  # type, position, color, shape
        self.rel_names = [
            'is_aligned', 'is_back', 'is_close', 'is_down_adj', 'is_down_left_adj',
            'is_down_right_adj', 'is_front', 'is_left', 'is_left_adj', 'is_right',
            'is_right_adj', 'is_top_adj', 'is_top_left_adj', 'is_top_right_adj'
        ]
        self.gnn_aggregation = self.hparams.aggregation
        self.gnn_encoder = self._build_gnn_encoder(self.hparams.gnn_type)

        self.d_model = self.hparams.d_model
        self.nhead = self.hparams.nhead
        self.dim_feedforward = self.hparams.dim_feedforward
        self.transformer_dropout = self.hparams.transformer_dropout
        self.transformer_activation = self.hparams.transformer_activation
        self.num_encoder_layers = self.hparams.num_encoder_layers
        self.transformer_norm_input = self.hparams.transformer_norm_input
        self.context_enc = TransformerEncoder(
            self.d_model,
            self.nhead,
            self.dim_feedforward,
            self.transformer_dropout,
            self.transformer_activation,
            self.num_encoder_layers,
            self.max_len,
            self.transformer_norm_input
        )

        n_feats = len(self.feats_dims)
        if self.gnn_aggregation == 'cat_axis_1':
            self.gnn2transformer = nn.Linear(self.state_dim * n_feats + self.action_dim, self.d_model)
            self.mlp_input_size = n_feats * self.state_dim + self.d_model
        elif self.gnn_aggregation == 'sum':
            self.gnn2transformer = nn.Linear(self.state_dim + self.action_dim, self.d_model)
            self.mlp_input_size = self.state_dim + self.d_model
        else:
            raise ValueError('Only sum and cat_axis_1 aggregations are available.')
        self.policy = MlpModel(input_size=self.mlp_input_size, hidden_sizes=[256, 128, 256],
                               output_size=self.action_dim, dropout=self.dropout)

        # Learned positional embeddings: max_len frame positions + 1 for the CLS token.
        self.embedding = nn.Embedding(self.max_len + 1, self.d_model)
        self.emb_layer_norm = nn.LayerNorm(self.d_model)
        self.emb_dropout = nn.Dropout(p=self.transformer_dropout)
        # Learned CLS token. It used to be re-drawn with torch.randn inside
        # forward, i.e. random noise that was never trained. NOTE: checkpoints
        # saved before this change lack this key and need strict=False to load.
        self.cls_embedding = nn.Parameter(torch.randn(1, 1, self.d_model))

    def _build_gnn_encoder(self, gnn_type):
        """Instantiate the GNN frame encoder named by ``gnn_type``.

        Raises:
            ValueError: if ``gnn_type`` is not supported by this model.
        """
        common = dict(
            hidden_channels=self.state_dim,
            out_channels=self.state_dim,
            rel_names=self.rel_names,
            dropout=0.0,
            n_layers=2,
            activation=nn.ELU(),
        )
        if gnn_type in ('RGATv3', 'RGATv4'):
            encoder_cls = RGATv3 if gnn_type == 'RGATv3' else RGATv4
            return encoder_cls(num_heads=4, residual=False, **common)
        if gnn_type == 'RGATv4Norm':
            return RGATv4Norm(num_heads=4, residual=True, **common)
        if gnn_type == 'RSAGEv4':
            return RSAGEv4(**common)
        if gnn_type == 'RAGNNv4':
            return RAGNNv4(**common)
        raise ValueError(
            f"Unknown gnn_type '{gnn_type}' for GraphBC_T; expected one of "
            "['RAGNNv4', 'RGATv3', 'RGATv4', 'RGATv4Norm', 'RSAGEv4'].")

    def forward(self, batch):
        """Predict the query action.

        ``batch`` is unpacked as ``(dem_frames, dem_actions, dem_lens,
        query_frame, target_action)``. ``dem_actions`` must be 4-D
        ``[b, l, s, action_dim]`` and ``dem_lens`` a per-sample list of
        per-trial frame counts. The assumption that ``dem_frames`` encodes to
        one row per frame, sample-major then trial then frame, matches how
        ``dem_lens`` is flattened here (TODO confirm against the collate fn).

        Returns:
            ``(target_action, predicted_action)``, both float, the prediction in [-1, 1].
        """
        dem_frames, dem_actions, dem_lens, query_frame, target_action = batch
        dem_actions = dem_actions.float()
        target_action = target_action.float()

        # One embedding per demonstration frame, trials back to back.
        dem_states = self.gnn_encoder(dem_frames)

        b, l, s, _ = dem_actions.size()
        # One length per trial, flattened in (sample, trial) order.
        dem_lens = torch.tensor([t for dl in dem_lens for t in dl]).to(torch.int64).cpu().view(-1)

        # Keep the real (non-padded) actions trial by trial, so row k lines up
        # with row k of dem_states (boolean indexing preserves row-major order).
        # pack_padded_sequence(...).data is time-major (step 0 of every trial,
        # then step 1, ...), which paired frames with the wrong actions.
        lens = dem_lens.to(dem_actions.device)
        action_steps = torch.arange(s, device=dem_actions.device)
        real_actions = dem_actions.view(b * l, s, -1)[action_steps.unsqueeze(0) < lens.unsqueeze(1)]
        dem_traj = torch.cat([dem_states, real_actions], dim=1)
        h_node = self.gnn2transformer(dem_traj)  # [total num of frames, d_model]

        # Scatter the frame tokens into a zero-padded [b*l, max_len, d_model]
        # tensor. The padding mask comes from the lengths, so a real frame whose
        # features happen to sum to zero is not mistaken for padding.
        frame_steps = torch.arange(self.max_len, device=h_node.device)
        pad_mask = frame_steps.unsqueeze(0) >= lens.to(h_node.device).unsqueeze(1)  # True = padding
        padded_trajectory = torch.zeros(b * l, self.max_len, h_node.size(1),
                                        dtype=h_node.dtype, device=h_node.device)
        padded_trajectory[~pad_mask] = h_node

        # Sequence-first layout with the CLS token appended as the last step.
        transformer_input = padded_trajectory.transpose(0, 1)  # [max_len, b*l, d_model]
        cls_token = self.cls_embedding.expand(1, b * l, -1)
        transformer_input = torch.cat([transformer_input, cls_token], dim=0)  # [max_len + 1, b*l, d_model]
        # The CLS position is never masked.
        mask = torch.cat([pad_mask, pad_mask.new_zeros(b * l, 1)], dim=1)

        positions = torch.arange(self.max_len + 1, device=transformer_input.device)
        positional_embeddings = self.embedding(positions).unsqueeze(1)  # [max_len + 1, 1, d_model]
        transformer_in = self.emb_dropout(self.emb_layer_norm(positional_embeddings + transformer_input))

        out, _ = self.context_enc(transformer_in, mask)  # [max_len + 1, b*l, d_model]
        cls_out = out[-1].view(b, l, -1)  # CLS output per trial
        context = torch.mean(cls_out, dim=1)  # [b, d_model]

        test_states = self.gnn_encoder(query_frame)
        test_context_states = torch.cat([context, test_states], dim=1)
        test_actions_pred = torch.tanh(self.policy(test_context_states))  # [b, action_dim]
        return target_action, test_actions_pred

    def make_mask(self, feature):
        """Boolean mask that is True where a feature row is entirely zero (padding)."""
        return torch.sum(torch.abs(feature), dim=-1) == 0

    def training_step(self, batch, batch_idx):
        """MSE between predicted and target actions, logged as ``train_loss``."""
        test_actions, test_actions_pred = self.forward(batch)
        loss = F.mse_loss(test_actions_pred, test_actions)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """MSE between predicted and target actions, logged as ``val_loss``."""
        test_actions, test_actions_pred = self.forward(batch)
        loss = F.mse_loss(test_actions_pred, test_actions)
        self.log('val_loss', loss, on_epoch=True, logger=True)

    def configure_optimizers(self):
        """Plain Adam over all parameters with learning rate ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def train_dataloader(self):
        """Shuffled training loader over all ``hparams.types``."""
        train_dataset = ToMnetDGLDataset(path=self.hparams.data_path,
                                         types=self.hparams.types, mode='train')
        return DataLoader(dataset=train_dataset,
                          batch_size=self.hparams.batch_size,
                          collate_fn=collate_function_seq,
                          num_workers=self.hparams.num_workers,
                          shuffle=True)

    def val_dataloader(self):
        """One validation loader per entry of ``hparams.types``."""
        return [
            DataLoader(dataset=ToMnetDGLDataset(path=self.hparams.data_path, types=[t], mode='val'),
                       batch_size=self.hparams.batch_size,
                       collate_fn=collate_function_seq,
                       num_workers=self.hparams.num_workers,
                       shuffle=False)
            for t in self.hparams.types
        ]

    def configure_callbacks(self):
        """Keep a checkpoint of every epoch (``save_top_k=-1``, ``period=1``)."""
        checkpoint = ModelCheckpoint(
            dirpath=None,  # automatically set
            save_top_k=-1,
            period=1
        )
        return [checkpoint]