IRENE/tom/model.py

513 lines
24 KiB
Python
Raw Normal View History

2024-02-01 15:40:47 +01:00
from argparse import ArgumentParser
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader
from tom.dataset import *
from tom.transformer import TransformerEncoder
from tom.gnn import RGATv2, RGATv3, RGATv3Agent, RGATv4, RGATv4Norm, RSAGEv4, RAGNNv4
class MlpModel(nn.Module):
    """Multilayer perceptron whose optional final layer is linear.

    Args:
        input_size (int): number of input features.
        hidden_sizes (int | list | tuple | None): sizes of the hidden layers.
            An int means a single hidden layer; an empty sequence or ``None``
            means no hidden layers (a purely linear model when
            ``output_size`` is given).
        output_size (int | None): size of a final linear layer.  If ``None``
            there is no final linear layer: the last hidden size is the output
            size and it has the nonlinearity applied.  With no hidden layers
            either, the model is the identity and the output size is
            ``input_size``.
        nonlinearity: torch nonlinearity Module class (not Functional); one
            instance is created per hidden layer.
        dropout (float | None): dropout probability applied after every hidden
            nonlinearity; ``None`` disables dropout.
    """

    def __init__(
        self,
        input_size,
        hidden_sizes,  # int, sequence, or None/empty for no hidden layers.
        output_size=None,  # if None, last layer has nonlinearity applied.
        nonlinearity=nn.ReLU,  # Module, not Functional.
        dropout=None,  # Dropout probability, or None for no dropout.
    ):
        super().__init__()
        if isinstance(hidden_sizes, int):
            hidden_sizes = [hidden_sizes]
        elif hidden_sizes is None:
            hidden_sizes = []
        else:
            # Copy into a list so tuples (or other sequences) concatenate below.
            hidden_sizes = list(hidden_sizes)

        sequence = []
        for n_in, n_out in zip([input_size] + hidden_sizes[:-1], hidden_sizes):
            sequence.append(nn.Linear(n_in, n_out))
            sequence.append(nonlinearity())
            if dropout is not None:
                sequence.append(nn.Dropout(dropout))

        # Width of the tensor leaving the hidden stack (the input itself when
        # there are no hidden layers).
        last_size = hidden_sizes[-1] if hidden_sizes else input_size
        if output_size is not None:
            sequence.append(nn.Linear(last_size, output_size))
        self.model = nn.Sequential(*sequence)
        self._output_size = last_size if output_size is None else output_size

    def forward(self, input):
        """Compute the model on the input, assuming input shape [B, input_size]."""
        return self.model(input)

    @property
    def output_size(self):
        """Return the output size of the model."""
        return self._output_size
# ---------------------------------------------------------------------------------------------------------------------------------
class GraphBCRNN(pl.LightningModule):
    """
    Baseline behaviour-cloning model (BC-RNN).

    Every familiarization frame is embedded with a relational GNN; the frame
    embeddings are concatenated with the demonstrated actions and a 2-layer
    bidirectional LSTM summarizes each demonstration trial.  The trial
    summaries are averaged into a context vector which, together with the GNN
    embedding of the query frame, is fed to an MLP that predicts the action.
    """

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Extend ``parent_parser`` with the hyper-parameters of this model."""
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--lr', type=float, default=1e-4)
        parser.add_argument('--action_dim', type=int, default=2)
        parser.add_argument('--context_dim', type=int, default=32)  # lstm hidden size
        parser.add_argument('--beta', type=float, default=0.01)
        parser.add_argument('--dropout', type=float, default=0.2)
        parser.add_argument('--process_data', type=int, default=0)
        parser.add_argument('--max_len', type=int, default=30)
        # arguments for gnn
        parser.add_argument('--gnn_type', type=str, default='RGATv4')
        parser.add_argument('--state_dim', type=int, default=128)  # gnn out_feats
        # NOTE(review): type=list turns a CLI string into a list of characters;
        # this model only uses len(feats_dims).
        parser.add_argument('--feats_dims', type=list, default=[9, 2, 3, 18])
        parser.add_argument('--aggregation', type=str, default='sum')
        return parser

    def __init__(self, hparams):
        super().__init__()
        # NOTE(review): newer pytorch_lightning releases make ``hparams``
        # read-only and expect ``self.save_hyperparameters(hparams)``; confirm
        # against the pinned version.
        self.hparams = hparams
        self.lr = self.hparams.lr
        self.state_dim = self.hparams.state_dim
        self.action_dim = self.hparams.action_dim
        self.context_dim = self.hparams.context_dim
        self.beta = self.hparams.beta
        self.dropout = self.hparams.dropout
        self.max_len = self.hparams.max_len
        self.feats_dims = self.hparams.feats_dims  # type, position, color, shape
        self.rel_names = [
            'is_aligned', 'is_back', 'is_close', 'is_down_adj', 'is_down_left_adj',
            'is_down_right_adj', 'is_front', 'is_left', 'is_left_adj', 'is_right',
            'is_right_adj', 'is_top_adj', 'is_top_left_adj', 'is_top_right_adj'
        ]
        self.gnn_aggregation = self.hparams.aggregation
        self.gnn_encoder = self._build_gnn_encoder(self.hparams.gnn_type)

        # 'cat_axis_1' keeps one state_dim embedding per feature group side by
        # side; 'sum' collapses them into a single state_dim embedding.
        if self.gnn_aggregation == 'cat_axis_1':
            state_size = self.state_dim * len(self.feats_dims)
        elif self.gnn_aggregation == 'sum':
            state_size = self.state_dim
        else:
            raise ValueError('Only sum and cat_axis_1 aggregations are available.')
        self.lstm_input_size = state_size + self.action_dim
        # The LSTM is bidirectional, hence 2 * context_dim context features.
        self.mlp_input_size = state_size + self.context_dim * 2

        self.context_enc = nn.LSTM(self.lstm_input_size, self.context_dim, 2,
                                   batch_first=True, bidirectional=True)
        self.policy = MlpModel(input_size=self.mlp_input_size, hidden_sizes=[256, 128, 256],
                               output_size=self.action_dim, dropout=self.dropout)
        self.past_samples = []

    def _build_gnn_encoder(self, gnn_type):
        """Instantiate the relational GNN encoder named by ``gnn_type``.

        Raises:
            ValueError: if ``gnn_type`` is not a supported encoder name.
        """
        attention_encoders = {
            'RGATv2': RGATv2,
            'RGATv3': RGATv3,
            'RGATv4': RGATv4,
            'RGATv3Agent': RGATv3Agent,
        }
        if gnn_type in attention_encoders:
            return attention_encoders[gnn_type](
                hidden_channels=self.state_dim,
                out_channels=self.state_dim,
                num_heads=4,
                rel_names=self.rel_names,
                dropout=0.0,
                n_layers=2,
                activation=nn.ELU(),
                residual=False,
            )
        if gnn_type == 'RSAGEv4':
            return RSAGEv4(
                hidden_channels=self.state_dim,
                out_channels=self.state_dim,
                rel_names=self.rel_names,
                dropout=0.0,
                n_layers=2,
                activation=nn.ELU(),
            )
        raise ValueError(f'Unsupported gnn_type {gnn_type!r} for GraphBCRNN.')

    def forward(self, batch):
        """Predict the query action from the demonstrations in ``batch``.

        Args:
            batch: tuple ``(dem_frames, dem_actions, dem_lens, query_frame,
                target_action)`` as produced by ``collate_function_seq``.
                ``dem_actions`` is 4-D ``[batch, trials, steps, action]`` and
                ``dem_lens`` holds, per batch element, the number of valid
                steps of each trial.

        Returns:
            tuple: ``(target_action, predicted_action)`` as float tensors; the
            prediction is squashed into ``[-1, 1]`` by ``tanh``.
        """
        dem_frames, dem_actions, dem_lens, query_frame, target_action = batch
        dem_actions = dem_actions.float()
        target_action = target_action.float()
        b, l, s, _ = dem_actions.size()

        # One embedding per demonstration frame, all batch elements back to back.
        dem_states = self.gnn_encoder(dem_frames)

        # Pad each batch element's frames to [trials, max_len, dim].  Each
        # element must read its own slice of dem_states; padding every element
        # from frame 0 would make them all reuse the first element's frames.
        frames_per_sample = [int(sum(trial_lens)) for trial_lens in dem_lens]
        state_chunks = torch.split(dem_states, frames_per_sample, dim=0)
        dem_states_new = torch.stack([
            self._sequence_to_padding(chunk, trial_lens, self.max_len)
            for chunk, trial_lens in zip(state_chunks, dem_lens)
        ]).to(self.device)

        # Expert trajectory: per-step state embedding followed by the action.
        dem_states_new = dem_states_new.view(b * l, s, -1)
        dem_actions = dem_actions.view(b * l, s, -1)
        dem_traj = torch.cat([dem_states_new, dem_actions], dim=2)

        # pack_padded_sequence requires the lengths on the CPU.
        flat_lens = torch.tensor([t for dl in dem_lens for t in dl]).to(torch.int64).cpu().view(-1)
        packed = nn.utils.rnn.pack_padded_sequence(dem_traj, flat_lens, batch_first=True, enforce_sorted=False)
        packed_out, _ = self.context_enc(packed)
        x_lstm, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)

        # Take each trial's output at its own last valid step; index -1 would
        # land on zero padding for every trial shorter than the longest one.
        rows = torch.arange(x_lstm.size(0), device=x_lstm.device)
        last_steps = (flat_lens - 1).to(x_lstm.device)
        x_out = x_lstm[rows, last_steps].view(b, l, -1)
        # Average the trial summaries into one context vector per batch element.
        context = torch.mean(x_out, dim=1)

        # Condition the policy on the context and the query-frame embedding.
        test_states = self.gnn_encoder(query_frame)
        test_context_states = torch.cat([context, test_states], dim=1)
        test_actions_pred = torch.tanh(self.policy(test_context_states))
        return target_action, test_actions_pred

    def _sequence_to_padding(self, x, lengths, max_length):
        """Split the rows of ``x`` into consecutive segments and zero-pad them.

        Args:
            x: tensor whose first dimension concatenates all segments.
            lengths: number of rows in each consecutive segment.
            max_length: padded length of every segment.

        Returns:
            Tensor ``[len(lengths), max_length, *x.shape[1:]]`` on ``x``'s
            device and dtype.
        """
        ret_tensor = x.new_zeros((len(lengths), max_length) + tuple(x.shape[1:]))
        cum_len = 0
        for i, length in enumerate(lengths):
            ret_tensor[i, :length] = x[cum_len:cum_len + length]
            cum_len += length
        return ret_tensor

    def training_step(self, batch, batch_idx):
        """Return the MSE between predicted and target actions."""
        test_actions, test_actions_pred = self.forward(batch)
        loss = F.mse_loss(test_actions_pred, test_actions)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Log the validation MSE for one batch of one validation loader."""
        test_actions, test_actions_pred = self.forward(batch)
        loss = F.mse_loss(test_actions_pred, test_actions)
        self.log('val_loss', loss, on_epoch=True, logger=True)

    def configure_optimizers(self):
        """Plain Adam over all parameters."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def train_dataloader(self):
        """Shuffled training loader over all ``hparams.types``."""
        train_dataset = ToMnetDGLDataset(path=self.hparams.data_path,
                                         types=self.hparams.types,
                                         mode='train')
        return DataLoader(dataset=train_dataset,
                          batch_size=self.hparams.batch_size,
                          collate_fn=collate_function_seq,
                          num_workers=self.hparams.num_workers,
                          shuffle=True)

    def val_dataloader(self):
        """One validation loader per entry of ``hparams.types``."""
        return [
            DataLoader(dataset=ToMnetDGLDataset(path=self.hparams.data_path,
                                                types=[t],
                                                mode='val'),
                       batch_size=self.hparams.batch_size,
                       collate_fn=collate_function_seq,
                       num_workers=self.hparams.num_workers,
                       shuffle=False)
            for t in self.hparams.types
        ]

    def configure_callbacks(self):
        """Keep a checkpoint for every epoch (``save_top_k=-1``)."""
        # NOTE(review): ``period`` is the old spelling of ``every_n_epochs`` and
        # is rejected by newer pytorch_lightning releases; confirm the pinned
        # version.
        checkpoint = ModelCheckpoint(
            dirpath=None,  # automatically set
            save_top_k=-1,
            period=1
        )
        return [checkpoint]
# ---------------------------------------------------------------------------------------------------------------------------------
class GraphBC_T(pl.LightningModule):
    """
    BC model with a relational GNN encoder, a Transformer context encoder and
    an MLP policy.

    Each demonstration step (GNN frame embedding + action) is projected to a
    token; a learned CLS token is appended to every trial, and the CLS outputs
    of all trials are averaged into the context used by the policy.
    """

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Extend ``parent_parser`` with the hyper-parameters of this model."""
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--lr', type=float, default=1e-4)
        parser.add_argument('--action_dim', type=int, default=2)
        parser.add_argument('--context_dim', type=int, default=32)  # lstm hidden size
        parser.add_argument('--beta', type=float, default=0.01)
        parser.add_argument('--dropout', type=float, default=0.2)
        parser.add_argument('--process_data', type=int, default=0)
        parser.add_argument('--max_len', type=int, default=30)
        # arguments for gnn
        parser.add_argument('--state_dim', type=int, default=128)  # gnn out_feats
        # NOTE(review): type=list turns a CLI string into a list of characters;
        # this model only uses len(feats_dims).
        parser.add_argument('--feats_dims', type=list, default=[9, 2, 3, 18])
        parser.add_argument('--aggregation', type=str, default='cat_axis_1')
        parser.add_argument('--gnn_type', type=str, default='RGATv3')
        # arguments for transformer
        parser.add_argument('--d_model', type=int, default=128)
        parser.add_argument('--nhead', type=int, default=4)
        parser.add_argument('--dim_feedforward', type=int, default=512)
        parser.add_argument('--transformer_dropout', type=float, default=0.3)
        parser.add_argument('--transformer_activation', type=str, default='gelu')
        parser.add_argument('--num_encoder_layers', type=int, default=6)
        parser.add_argument('--transformer_norm_input', type=int, default=0)
        return parser

    def __init__(self, hparams):
        super().__init__()
        # NOTE(review): newer pytorch_lightning releases make ``hparams``
        # read-only and expect ``self.save_hyperparameters(hparams)``; confirm
        # against the pinned version.
        self.hparams = hparams
        self.lr = self.hparams.lr
        self.state_dim = self.hparams.state_dim
        self.action_dim = self.hparams.action_dim
        self.context_dim = self.hparams.context_dim
        self.beta = self.hparams.beta
        self.dropout = self.hparams.dropout
        self.max_len = self.hparams.max_len
        self.feats_dims = self.hparams.feats_dims  # type, position, color, shape
        self.rel_names = [
            'is_aligned', 'is_back', 'is_close', 'is_down_adj', 'is_down_left_adj',
            'is_down_right_adj', 'is_front', 'is_left', 'is_left_adj', 'is_right',
            'is_right_adj', 'is_top_adj', 'is_top_left_adj', 'is_top_right_adj'
        ]
        self.gnn_aggregation = self.hparams.aggregation
        self.gnn_encoder = self._build_gnn_encoder(self.hparams.gnn_type)

        self.d_model = self.hparams.d_model
        self.nhead = self.hparams.nhead
        self.dim_feedforward = self.hparams.dim_feedforward
        self.transformer_dropout = self.hparams.transformer_dropout
        self.transformer_activation = self.hparams.transformer_activation
        self.num_encoder_layers = self.hparams.num_encoder_layers
        self.transformer_norm_input = self.hparams.transformer_norm_input
        self.context_enc = TransformerEncoder(
            self.d_model,
            self.nhead,
            self.dim_feedforward,
            self.transformer_dropout,
            self.transformer_activation,
            self.num_encoder_layers,
            self.max_len,
            self.transformer_norm_input
        )

        # 'cat_axis_1' keeps one state_dim embedding per feature group side by
        # side; 'sum' collapses them into a single state_dim embedding.
        if self.gnn_aggregation == 'cat_axis_1':
            state_size = self.state_dim * len(self.feats_dims)
        elif self.gnn_aggregation == 'sum':
            state_size = self.state_dim
        else:
            raise ValueError('Only sum and cat_axis_1 aggregations are available.')
        self.gnn2transformer = nn.Linear(state_size + self.action_dim, self.d_model)
        self.mlp_input_size = state_size + self.d_model

        self.policy = MlpModel(input_size=self.mlp_input_size, hidden_sizes=[256, 128, 256],
                               output_size=self.action_dim, dropout=self.dropout)
        # Learned positional embeddings; + 1 position for the CLS token.
        self.embedding = nn.Embedding(self.max_len + 1, self.d_model)
        self.emb_layer_norm = nn.LayerNorm(self.d_model)
        self.emb_dropout = nn.Dropout(p=self.transformer_dropout)
        # Learned CLS token, registered so the optimizer trains it and it moves
        # with the module.  NOTE: it is part of the state dict, so checkpoints
        # without a ``cls_embedding`` entry need ``strict=False`` to load.
        self.cls_embedding = nn.Parameter(torch.randn(1, 1, self.d_model))

    def _build_gnn_encoder(self, gnn_type):
        """Instantiate the relational GNN encoder named by ``gnn_type``.

        Raises:
            ValueError: if ``gnn_type`` is not a supported encoder name.
        """
        attention_encoders = {
            'RGATv3': RGATv3,
            'RGATv4': RGATv4,
            'RGATv4Norm': RGATv4Norm,
        }
        if gnn_type in attention_encoders:
            return attention_encoders[gnn_type](
                hidden_channels=self.state_dim,
                out_channels=self.state_dim,
                num_heads=4,
                rel_names=self.rel_names,
                dropout=0.0,
                n_layers=2,
                activation=nn.ELU(),
                # Only the normalized variant uses residual connections.
                residual=(gnn_type == 'RGATv4Norm'),
            )
        sage_like_encoders = {
            'RSAGEv4': RSAGEv4,
            'RAGNNv4': RAGNNv4,
        }
        if gnn_type in sage_like_encoders:
            return sage_like_encoders[gnn_type](
                hidden_channels=self.state_dim,
                out_channels=self.state_dim,
                rel_names=self.rel_names,
                dropout=0.0,
                n_layers=2,
                activation=nn.ELU(),
            )
        raise ValueError(f'Unsupported gnn_type {gnn_type!r} for GraphBC_T.')

    def forward(self, batch):
        """Predict the query action from the demonstrations in ``batch``.

        Args:
            batch: tuple ``(dem_frames, dem_actions, dem_lens, query_frame,
                target_action)`` as produced by ``collate_function_seq``.
                ``dem_actions`` is 4-D ``[batch, trials, steps, action]`` and
                ``dem_lens`` holds, per batch element, the number of valid
                steps of each trial.

        Returns:
            tuple: ``(target_action, predicted_action)`` as float tensors; the
            prediction is squashed into ``[-1, 1]`` by ``tanh``.
        """
        dem_frames, dem_actions, dem_lens, query_frame, target_action = batch
        dem_actions = dem_actions.float()
        target_action = target_action.float()
        b, l, s, _ = dem_actions.size()
        n_traj = b * l

        # One embedding per demonstration frame, trial after trial.
        dem_states = self.gnn_encoder(dem_frames)
        flat_lens = torch.tensor([t for dl in dem_lens for t in dl]).to(torch.int64).view(-1)

        # Select the valid actions in the same trial-major order as dem_states.
        # Boolean-mask indexing walks rows in order; PackedSequence.data is
        # time-major and would pair states with actions of other steps/trials.
        action_lens = flat_lens.to(dem_actions.device)
        step_is_valid = torch.arange(s, device=dem_actions.device).unsqueeze(0) < action_lens.unsqueeze(1)
        dem_actions_flat = dem_actions.reshape(n_traj, s, -1)[step_is_valid]
        dem_traj = torch.cat([dem_states, dem_actions_flat], dim=1)
        h_node = self.gnn2transformer(dem_traj)

        # Scatter the per-step tokens back into [n_traj, max_len, d_model].
        slot_lens = flat_lens.to(h_node.device)
        slot_is_valid = torch.arange(self.max_len, device=h_node.device).unsqueeze(0) < slot_lens.unsqueeze(1)
        padded_trajectory = h_node.new_zeros(n_traj, self.max_len, h_node.size(1))
        padded_trajectory[slot_is_valid] = h_node
        # True marks padding.  It is derived from the lengths, so a valid token
        # that happens to be all zeros is not masked.
        mask = ~slot_is_valid

        transformer_input = padded_trajectory.transpose(0, 1)  # [max_len, n_traj, d_model]
        # Append the learned CLS token as the last position of every trial.
        cls_tokens = self.cls_embedding.expand(1, n_traj, -1)
        transformer_input = torch.cat([transformer_input, cls_tokens], dim=0)  # [max_len + 1, n_traj, d_model]
        # The CLS position is never masked.
        mask = torch.cat([mask, mask.new_zeros(n_traj, 1)], dim=1)

        # Positional embeddings, then layer norm and dropout.
        positions = torch.arange(self.max_len + 1, dtype=torch.long, device=transformer_input.device)
        positional_embeddings = self.embedding(positions).unsqueeze(1)  # [max_len + 1, 1, d_model]
        transformer_in = self.emb_dropout(self.emb_layer_norm(positional_embeddings + transformer_input))

        # Transformer encoding; the CLS output summarizes each trial.
        out, _ = self.context_enc(transformer_in, mask)
        cls_out = out[-1].view(b, l, -1)
        context = torch.mean(cls_out, dim=1)

        # Condition the policy on the context and the query-frame embedding.
        test_states = self.gnn_encoder(query_frame)
        test_context_states = torch.cat([context, test_states], dim=1)
        test_actions_pred = torch.tanh(self.policy(test_context_states))
        return target_action, test_actions_pred

    def make_mask(self, feature):
        """Return a bool mask that is True where every feature of a position is zero."""
        return (torch.sum(
            torch.abs(feature),
            dim=-1
        ) == 0)

    def training_step(self, batch, batch_idx):
        """Return the MSE between predicted and target actions."""
        test_actions, test_actions_pred = self.forward(batch)
        loss = F.mse_loss(test_actions_pred, test_actions)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Log the validation MSE for one batch of one validation loader."""
        test_actions, test_actions_pred = self.forward(batch)
        loss = F.mse_loss(test_actions_pred, test_actions)
        self.log('val_loss', loss, on_epoch=True, logger=True)

    def configure_optimizers(self):
        """Plain Adam over all parameters."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def train_dataloader(self):
        """Shuffled training loader over all ``hparams.types``."""
        train_dataset = ToMnetDGLDataset(path=self.hparams.data_path,
                                         types=self.hparams.types,
                                         mode='train')
        return DataLoader(dataset=train_dataset,
                          batch_size=self.hparams.batch_size,
                          collate_fn=collate_function_seq,
                          num_workers=self.hparams.num_workers,
                          shuffle=True)

    def val_dataloader(self):
        """One validation loader per entry of ``hparams.types``."""
        return [
            DataLoader(dataset=ToMnetDGLDataset(path=self.hparams.data_path,
                                                types=[t],
                                                mode='val'),
                       batch_size=self.hparams.batch_size,
                       collate_fn=collate_function_seq,
                       num_workers=self.hparams.num_workers,
                       shuffle=False)
            for t in self.hparams.types
        ]

    def configure_callbacks(self):
        """Keep a checkpoint for every epoch (``save_top_k=-1``)."""
        # NOTE(review): ``period`` is the old spelling of ``every_n_epochs`` and
        # is rejected by newer pytorch_lightning releases; confirm the pinned
        # version.
        checkpoint = ModelCheckpoint(
            dirpath=None,  # automatically set
            save_top_k=-1,
            period=1
        )
        return [checkpoint]