from argparse import ArgumentParser

import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader

from tom.dataset import *
from tom.transformer import TransformerEncoder
from tom.gnn import RGATv2, RGATv3, RGATv3Agent, RGATv4, RGATv4Norm, RSAGEv4, RAGNNv4


class MlpModel(nn.Module):
    """Multilayer Perceptron with last layer linear.

    Args:
        input_size (int): number of inputs
        hidden_sizes (list): can be empty list for none (linear model).
        output_size: linear layer at output, or if ``None``, the last hidden size
            will be the output size and will have nonlinearity applied
        nonlinearity: torch nonlinearity Module (not Functional).
    """

    def __init__(
            self,
            input_size,
            hidden_sizes,  # can be an empty list or None for no hidden layers
            output_size=None,  # if None, the last layer has the nonlinearity applied
            nonlinearity=nn.ReLU,  # Module, not Functional
            dropout=None  # dropout value
    ):
        super().__init__()
        if isinstance(hidden_sizes, int):
            hidden_sizes = [hidden_sizes]
        elif hidden_sizes is None:
            hidden_sizes = []
        hidden_layers = [nn.Linear(n_in, n_out) for n_in, n_out in
                         zip([input_size] + hidden_sizes[:-1], hidden_sizes)]
        sequence = list()
        for layer in hidden_layers:
            if dropout is not None:
                sequence.extend([layer, nonlinearity(), nn.Dropout(dropout)])
            else:
                sequence.extend([layer, nonlinearity()])

        if output_size is not None:
            last_size = hidden_sizes[-1] if hidden_sizes else input_size
            sequence.append(torch.nn.Linear(last_size, output_size))
        self.model = nn.Sequential(*sequence)
        self._output_size = (hidden_sizes[-1] if output_size is None
                             else output_size)

    def forward(self, input):
        """Compute the model on the input, assuming input shape [B, input_size]."""
        return self.model(input)

    @property
    def output_size(self):
        """Returns the output size of the model."""
        return self._output_size
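
# A minimal usage sketch for MlpModel (hypothetical sizes, not tied to the
# models below): two hidden layers with dropout and a linear 2-dim output.
#   >>> mlp = MlpModel(input_size=16, hidden_sizes=[64, 32], output_size=2, dropout=0.1)
#   >>> mlp(torch.randn(8, 16)).shape
#   torch.Size([8, 2])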


# ---------------------------------------------------------------------------------------------------------------------------------

class GraphBCRNN(pl.LightningModule):
    """
    Implementation of the baseline model for the BC-RNN algorithm.
    A relational GNN encoder (R-GAT/R-SAGE variants) and a bidirectional LSTM
    are used to encode the familiarization trials.
    """

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--lr', type=float, default=1e-4)
        parser.add_argument('--action_dim', type=int, default=2)
        parser.add_argument('--context_dim', type=int, default=32)  # lstm hidden size
        parser.add_argument('--beta', type=float, default=0.01)
        parser.add_argument('--dropout', type=float, default=0.2)
        parser.add_argument('--process_data', type=int, default=0)
        parser.add_argument('--max_len', type=int, default=30)
        # arguments for gnn
        parser.add_argument('--gnn_type', type=str, default='RGATv4')
        parser.add_argument('--state_dim', type=int, default=128)  # gnn out_feats
        parser.add_argument('--feats_dims', type=list, default=[9, 2, 3, 18])
        parser.add_argument('--aggregation', type=str, default='sum')
        # arguments for mlp
        #parser.add_argument('--mpl_hid_feats', type=list, default=[256, 64, 16])
        return parser
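
    # Hedged wiring sketch for the hook above (standard Lightning pattern; the
    # training script is assumed to add --data_path, --types, --batch_size and
    # --num_workers, since the dataloaders below read them from hparams):
    #   parser = GraphBCRNN.add_model_specific_args(ArgumentParser())
    #   model = GraphBCRNN(parser.parse_args())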

    def __init__(self, hparams):
        super().__init__()

        self.hparams = hparams
        self.lr = self.hparams.lr
        self.state_dim = self.hparams.state_dim
        self.action_dim = self.hparams.action_dim
        self.context_dim = self.hparams.context_dim
        self.beta = self.hparams.beta
        self.dropout = self.hparams.dropout
        self.max_len = self.hparams.max_len
        self.feats_dims = self.hparams.feats_dims  # type, position, color, shape
        self.rel_names = [
            'is_aligned', 'is_back', 'is_close', 'is_down_adj', 'is_down_left_adj',
            'is_down_right_adj', 'is_front', 'is_left', 'is_left_adj', 'is_right',
            'is_right_adj', 'is_top_adj', 'is_top_left_adj', 'is_top_right_adj'
        ]
        self.gnn_aggregation = self.hparams.aggregation

        if self.hparams.gnn_type == 'RGATv2':
            self.gnn_encoder = RGATv2(
                hidden_channels=self.state_dim,
                out_channels=self.state_dim,
                num_heads=4,
                rel_names=self.rel_names,
                dropout=0.0,
                n_layers=2,
                activation=nn.ELU(),
                residual=False
            )
        elif self.hparams.gnn_type == 'RGATv3':
            self.gnn_encoder = RGATv3(
                hidden_channels=self.state_dim,
                out_channels=self.state_dim,
                num_heads=4,
                rel_names=self.rel_names,
                dropout=0.0,
                n_layers=2,
                activation=nn.ELU(),
                residual=False
            )
        elif self.hparams.gnn_type == 'RGATv4':
            self.gnn_encoder = RGATv4(
                hidden_channels=self.state_dim,
                out_channels=self.state_dim,
                num_heads=4,
                rel_names=self.rel_names,
                dropout=0.0,
                n_layers=2,
                activation=nn.ELU(),
                residual=False
            )
        elif self.hparams.gnn_type == 'RGATv3Agent':
            self.gnn_encoder = RGATv3Agent(
                hidden_channels=self.state_dim,
                out_channels=self.state_dim,
                num_heads=4,
                rel_names=self.rel_names,
                dropout=0.0,
                n_layers=2,
                activation=nn.ELU(),
                residual=False
            )
        elif self.hparams.gnn_type == 'RSAGEv4':
            self.gnn_encoder = RSAGEv4(
                hidden_channels=self.state_dim,
                out_channels=self.state_dim,
                rel_names=self.rel_names,
                dropout=0.0,
                n_layers=2,
                activation=nn.ELU()
            )

        if self.gnn_aggregation == 'cat_axis_1':
            self.lstm_input_size = self.state_dim * len(self.feats_dims) + self.action_dim
            self.mlp_input_size = self.state_dim * len(self.feats_dims) + self.context_dim * 2
        elif self.gnn_aggregation == 'sum':
            self.lstm_input_size = self.state_dim + self.action_dim
            self.mlp_input_size = self.state_dim + self.context_dim * 2
        else:
            raise ValueError(
                f"Unknown aggregation '{self.gnn_aggregation}': "
                f"expected 'sum' or 'cat_axis_1'.")
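
        # Dimension check with the defaults above (state_dim=128, four feature
        # groups, action_dim=2, context_dim=32):
        #   'cat_axis_1': lstm_input_size = 4*128 + 2 = 514, mlp_input_size = 4*128 + 2*32 = 576
        #   'sum':        lstm_input_size = 128 + 2 = 130,   mlp_input_size = 128 + 2*32 = 192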

        self.context_enc = nn.LSTM(self.lstm_input_size, self.context_dim, 2,
                                   batch_first=True, bidirectional=True)

        self.policy = MlpModel(input_size=self.mlp_input_size, hidden_sizes=[256, 128, 256],
                               output_size=self.action_dim, dropout=self.dropout)

        self.past_samples = []

    def forward(self, batch):
        dem_frames, dem_actions, dem_lens, query_frame, target_action = batch
        dem_actions = dem_actions.float()
        target_action = target_action.float()
        dem_states = self.gnn_encoder(dem_frames)  # [num frames, 128 * num features if cat_axis_1]
        # segment according to the number of frames in each episode and pad with zeros
        # to obtain tensors of shape [batch size, num of trials (8), max num of frames (30), hidden dim]
        b, l, s, _ = dem_actions.size()
        dem_states_new = []
        for i in range(b):
            dem_states_new.append(self._sequence_to_padding(dem_states, dem_lens[i], self.max_len))
        dem_states_new = torch.stack(dem_states_new).to(self.device)  # [b, 8, 30, 128 * num features if cat_axis_1]
        # concatenate states and actions to get expert trajectory
        dem_states_new = dem_states_new.view(b * l, s, -1)  # [b*8, 30, 128 * num features if cat_axis_1]
        dem_actions = dem_actions.view(b * l, s, -1)  # [b*8, 30, action_dim]
        dem_traj = torch.cat([dem_states_new, dem_actions], dim=2)  # [b*8, 30, action_dim + 128 * num features if cat_axis_1]
        # embed expert trajectory to get a context embedding batch x samples x dim
        dem_lens = torch.tensor([t for dl in dem_lens for t in dl]).to(torch.int64).cpu()
        dem_lens = dem_lens.view(-1)
        x_lstm = nn.utils.rnn.pack_padded_sequence(dem_traj, dem_lens, batch_first=True, enforce_sorted=False)
        x_lstm, _ = self.context_enc(x_lstm)
        x_lstm, _ = nn.utils.rnn.pad_packed_sequence(x_lstm, batch_first=True)
        x_out = x_lstm[:, -1, :]
        x_out = x_out.view(b, l, -1)
        context = torch.mean(x_out, dim=1)  # [b, 64] (64 = 2 * context_dim)
        # concat context embedding to the state embedding of the test trajectory
        test_states = self.gnn_encoder(query_frame)  # [b, 128]
        test_context_states = torch.cat([context, test_states], dim=1)  # [b, 192] (192 = 64 + 128)
        # for each state in the test states calculate the action
        test_actions_pred = torch.tanh(self.policy(test_context_states))  # [b, 2]
        return target_action, test_actions_pred

    def _sequence_to_padding(self, x, lengths, max_length):
        # allocate the output shape; this works for x with any trailing dimensions
        ret_tensor = torch.zeros((len(lengths), max_length) + tuple(x.shape[1:]))
        cum_len = 0
        for i, l in enumerate(lengths):
            ret_tensor[i, :l] = x[cum_len: cum_len + l]
            cum_len += l
        return ret_tensor
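
    # Worked example for the helper above (hypothetical numbers): three frames
    # split into episodes of lengths [2, 1], zero-padded to max_length=3:
    #   >>> x = torch.arange(6.).view(3, 2)               # three frames, feature dim 2
    #   >>> self._sequence_to_padding(x, [2, 1], 3).shape
    #   torch.Size([2, 3, 2])
    # Rows 0-1 fill slot 0, row 2 fills slot 1; the remaining steps stay zero.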

    def training_step(self, batch, batch_idx):
        test_actions, test_actions_pred = self.forward(batch)
        loss = F.mse_loss(test_actions, test_actions_pred)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        test_actions, test_actions_pred = self.forward(batch)
        loss = F.mse_loss(test_actions, test_actions_pred)
        self.log('val_loss', loss, on_epoch=True, logger=True)

    def configure_optimizers(self):
        optim = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optim
        #optim = torch.optim.AdamW(self.parameters(), lr=self.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01)
        #scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optim, gamma=0.8)
        #return [optim], [scheduler]

    def train_dataloader(self):
        train_dataset = ToMnetDGLDataset(path=self.hparams.data_path,
                                         types=self.hparams.types,
                                         mode='train')
        train_loader = DataLoader(dataset=train_dataset,
                                  batch_size=self.hparams.batch_size,
                                  collate_fn=collate_function_seq,
                                  num_workers=self.hparams.num_workers,
                                  #pin_memory=True,
                                  shuffle=True)
        return train_loader

    def val_dataloader(self):
        val_datasets = []
        val_loaders = []
        for t in self.hparams.types:
            val_datasets.append(ToMnetDGLDataset(path=self.hparams.data_path,
                                                 types=[t],
                                                 mode='val'))
            val_loaders.append(DataLoader(dataset=val_datasets[-1],
                                          batch_size=self.hparams.batch_size,
                                          collate_fn=collate_function_seq,
                                          num_workers=self.hparams.num_workers,
                                          #pin_memory=True,
                                          shuffle=False))
        return val_loaders

    def configure_callbacks(self):
        checkpoint = ModelCheckpoint(
            dirpath=None,  # automatically set
            #filename=self.params['bc_model']+'-'+self.params['gnn_type']+'-'+self.gnn_params['feats_aggr']+'-{epoch:02d}',
            save_top_k=-1,
            period=1
        )
        return [checkpoint]


# ---------------------------------------------------------------------------------------------------------------------------------


class GraphBC_T(pl.LightningModule):
    """
    BC model with a GraphTrans-style encoder: a relational GNN feeds a
    Transformer context encoder, followed by an MLP policy head.
    """

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--lr', type=float, default=1e-4)
        parser.add_argument('--action_dim', type=int, default=2)
        parser.add_argument('--context_dim', type=int, default=32)  # kept for parity with GraphBCRNN (not used by this model)
        parser.add_argument('--beta', type=float, default=0.01)
        parser.add_argument('--dropout', type=float, default=0.2)
        parser.add_argument('--process_data', type=int, default=0)
        parser.add_argument('--max_len', type=int, default=30)
        # arguments for gnn
        parser.add_argument('--state_dim', type=int, default=128)  # gnn out_feats
        parser.add_argument('--feats_dims', type=list, default=[9, 2, 3, 18])
        parser.add_argument('--aggregation', type=str, default='cat_axis_1')
        parser.add_argument('--gnn_type', type=str, default='RGATv3')
        # arguments for mlp
        #parser.add_argument('--mpl_hid_feats', type=list, default=[256, 64, 16])
        # arguments for transformer
        parser.add_argument('--d_model', type=int, default=128)
        parser.add_argument('--nhead', type=int, default=4)
        parser.add_argument('--dim_feedforward', type=int, default=512)
        parser.add_argument('--transformer_dropout', type=float, default=0.3)
        parser.add_argument('--transformer_activation', type=str, default='gelu')
        parser.add_argument('--num_encoder_layers', type=int, default=6)
        parser.add_argument('--transformer_norm_input', type=int, default=0)
        return parser

    def __init__(self, hparams):
        super().__init__()

        self.hparams = hparams
        self.lr = self.hparams.lr
        self.state_dim = self.hparams.state_dim
        self.action_dim = self.hparams.action_dim
        self.context_dim = self.hparams.context_dim
        self.beta = self.hparams.beta
        self.dropout = self.hparams.dropout
        self.max_len = self.hparams.max_len
        self.feats_dims = self.hparams.feats_dims  # type, position, color, shape

        self.rel_names = [
            'is_aligned', 'is_back', 'is_close', 'is_down_adj', 'is_down_left_adj',
            'is_down_right_adj', 'is_front', 'is_left', 'is_left_adj', 'is_right',
            'is_right_adj', 'is_top_adj', 'is_top_left_adj', 'is_top_right_adj'
        ]
        self.gnn_aggregation = self.hparams.aggregation

        if self.hparams.gnn_type == 'RGATv3':
            self.gnn_encoder = RGATv3(
                hidden_channels=self.state_dim,
                out_channels=self.state_dim,
                num_heads=4,
                rel_names=self.rel_names,
                dropout=0.0,
                n_layers=2,
                activation=nn.ELU(),
                residual=False
            )
        elif self.hparams.gnn_type == 'RGATv4':
            self.gnn_encoder = RGATv4(
                hidden_channels=self.state_dim,
                out_channels=self.state_dim,
                num_heads=4,
                rel_names=self.rel_names,
                dropout=0.0,
                n_layers=2,
                activation=nn.ELU(),
                residual=False
            )
        elif self.hparams.gnn_type == 'RSAGEv4':
            self.gnn_encoder = RSAGEv4(
                hidden_channels=self.state_dim,
                out_channels=self.state_dim,
                rel_names=self.rel_names,
                dropout=0.0,
                n_layers=2,
                activation=nn.ELU()
            )
        elif self.hparams.gnn_type == 'RAGNNv4':
            self.gnn_encoder = RAGNNv4(
                hidden_channels=self.state_dim,
                out_channels=self.state_dim,
                rel_names=self.rel_names,
                dropout=0.0,
                n_layers=2,
                activation=nn.ELU()
            )
        elif self.hparams.gnn_type == 'RGATv4Norm':
            self.gnn_encoder = RGATv4Norm(
                hidden_channels=self.state_dim,
                out_channels=self.state_dim,
                num_heads=4,
                rel_names=self.rel_names,
                dropout=0.0,
                n_layers=2,
                activation=nn.ELU(),
                residual=True
            )

        self.d_model = self.hparams.d_model
        self.nhead = self.hparams.nhead
        self.dim_feedforward = self.hparams.dim_feedforward
        self.transformer_dropout = self.hparams.transformer_dropout
        self.transformer_activation = self.hparams.transformer_activation
        self.num_encoder_layers = self.hparams.num_encoder_layers
        self.transformer_norm_input = self.hparams.transformer_norm_input
        self.context_enc = TransformerEncoder(
            self.d_model,
            self.nhead,
            self.dim_feedforward,
            self.transformer_dropout,
            self.transformer_activation,
            self.num_encoder_layers,
            self.max_len,
            self.transformer_norm_input
        )

        if self.gnn_aggregation == 'cat_axis_1':
            self.gnn2transformer = nn.Linear(self.state_dim * len(self.feats_dims) + self.action_dim, self.d_model)
            self.mlp_input_size = len(self.feats_dims) * self.state_dim + self.d_model
        elif self.gnn_aggregation == 'sum':
            self.gnn2transformer = nn.Linear(self.state_dim + self.action_dim, self.d_model)
            self.mlp_input_size = self.state_dim + self.d_model
        else:
            raise ValueError("Only 'sum' and 'cat_axis_1' aggregations are available.")
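
        # Dimension check with the defaults above (state_dim=128, four feature
        # groups, action_dim=2, d_model=128):
        #   'cat_axis_1': gnn2transformer maps 4*128 + 2 = 514 -> 128; mlp_input_size = 4*128 + 128 = 640
        #   'sum':        gnn2transformer maps 128 + 2 = 130 -> 128;   mlp_input_size = 128 + 128 = 256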

        self.policy = MlpModel(input_size=self.mlp_input_size, hidden_sizes=[256, 128, 256],
                               output_size=self.action_dim, dropout=self.dropout)

        # CLS embedding parameters, requires_grad=True; registered here so the
        # cls token appears in self.parameters() and is updated by the optimizer
        self.cls_embedding = nn.Parameter(torch.randn(1, 1, self.d_model))
        self.embedding = nn.Embedding(self.max_len + 1, self.d_model)  # + 1 because of the cls token
        self.emb_layer_norm = nn.LayerNorm(self.d_model)
        self.emb_dropout = nn.Dropout(p=self.transformer_dropout)

    def forward(self, batch):
        dem_frames, dem_actions, dem_lens, query_frame, target_action = batch
        dem_actions = dem_actions.float()
        target_action = target_action.float()
        dem_states = self.gnn_encoder(dem_frames)
        b, l, s, _ = dem_actions.size()
        dem_lens = torch.tensor([t for dl in dem_lens for t in dl]).to(torch.int64).cpu()
        dem_lens = dem_lens.view(-1)
        dem_actions_packed = torch.nn.utils.rnn.pack_padded_sequence(dem_actions.view(b * l, s, -1), dem_lens, batch_first=True, enforce_sorted=False)[0]
        dem_traj = torch.cat([dem_states, dem_actions_packed], dim=1)
        h_node = self.gnn2transformer(dem_traj)
        hidden_dim = h_node.size()[1]
        padded_trajectory = torch.zeros(b * l, self.max_len, hidden_dim).to(self.device)
        j = 0
        for idx, i in enumerate(dem_lens):
            padded_trajectory[idx][:i] = h_node[j:j + i]
            j += i
        mask = self.make_mask(padded_trajectory).to(self.device)
        transformer_input = padded_trajectory.transpose(0, 1)  # [30, b*l, 128]
        # append the learned cls token and extend the padding mask accordingly
        cls_embedding = self.cls_embedding.expand(1, b * l, -1)
        transformer_input = torch.cat([transformer_input, cls_embedding], dim=0)  # [31, b*l, 128]
        zeros = mask.data.new(mask.size(0), 1).fill_(0)
        mask = torch.cat([mask, zeros], dim=1)
        # embed positions [0, 1, ..., 30] (+ 1 because of the cls token)
        indices = torch.arange(self.max_len + 1, dtype=torch.long).to(self.device)
        positional_embeddings = self.embedding(indices).unsqueeze(1)  # [31, 1, 128]
        # generate transformer input
        pe_input = positional_embeddings + transformer_input  # [31, b*l, 128]
        # layernorm and dropout
        transformer_in = self.emb_dropout(self.emb_layer_norm(pe_input))  # [31, b*l, 128]
        # transformer encoding and output parsing
        out, _ = self.context_enc(transformer_in, mask)  # [31, b*l, 128]
        cls = out[-1]
        cls = cls.view(b, l, -1)  # [b, 8, 128]
        context = torch.mean(cls, dim=1)
        # action prediction from context + query state
        test_states = self.gnn_encoder(query_frame)  # [b, state_dim * num features if cat_axis_1]
        test_context_states = torch.cat([context, test_states], dim=1)  # [b, mlp_input_size]
        # for each state in the test states calculate the action
        test_actions_pred = torch.tanh(self.policy(test_context_states))  # [b, 2]
        return target_action, test_actions_pred

    def make_mask(self, feature):
        return (torch.sum(
            torch.abs(feature),
            dim=-1
        ) == 0)  # .unsqueeze(1).unsqueeze(2)
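
    # The mask marks padded steps (all-zero feature vectors) with True, the
    # usual src_key_padding_mask-style convention (True = ignore); the cls slot
    # is appended unmasked in forward. A small illustrative check:
    #   >>> feats = torch.tensor([[[1., 2.], [0., 0.]]])  # one real step, one pad
    #   >>> self.make_mask(feats)
    #   tensor([[False,  True]])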

    def training_step(self, batch, batch_idx):
        test_actions, test_actions_pred = self.forward(batch)
        loss = F.mse_loss(test_actions, test_actions_pred)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        test_actions, test_actions_pred = self.forward(batch)
        loss = F.mse_loss(test_actions, test_actions_pred)
        self.log('val_loss', loss, on_epoch=True, logger=True)

    def configure_optimizers(self):
        optim = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optim
        #optim = torch.optim.AdamW(self.parameters(), lr=self.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01)
        #scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optim, gamma=0.96)
        #return [optim], [scheduler]

    def train_dataloader(self):
        train_dataset = ToMnetDGLDataset(path=self.hparams.data_path,
                                         types=self.hparams.types,
                                         mode='train')
        train_loader = DataLoader(dataset=train_dataset,
                                  batch_size=self.hparams.batch_size,
                                  collate_fn=collate_function_seq,
                                  num_workers=self.hparams.num_workers,
                                  #pin_memory=True,
                                  shuffle=True)
        return train_loader

    def val_dataloader(self):
        val_datasets = []
        val_loaders = []
        for t in self.hparams.types:
            val_datasets.append(ToMnetDGLDataset(path=self.hparams.data_path,
                                                 types=[t],
                                                 mode='val'))
            val_loaders.append(DataLoader(dataset=val_datasets[-1],
                                          batch_size=self.hparams.batch_size,
                                          collate_fn=collate_function_seq,
                                          num_workers=self.hparams.num_workers,
                                          #pin_memory=True,
                                          shuffle=False))
        return val_loaders

    def configure_callbacks(self):
        checkpoint = ModelCheckpoint(
            dirpath=None,  # automatically set
            #filename=self.params['bc_model']+'-'+self.params['gnn_type']+'-'+self.gnn_params['feats_aggr']+'-{epoch:02d}',
            save_top_k=-1,
            period=1
        )
        return [checkpoint]