This commit is contained in:
Matteo Bortoletto 2024-02-01 15:40:47 +01:00
parent a333481e05
commit de0bea7508
18 changed files with 3150 additions and 2 deletions

0
tom/__init__.py Normal file
View file

310
tom/dataset.py Normal file
View file

@ -0,0 +1,310 @@
import pickle as pkl
import os
import torch
import torch.utils.data
import torch.nn.functional as F
import dgl
import random
from dgl.data import DGLDataset
def collate_function_seq(batch):
    """Collate training samples into one batch.

    Every sample is indexed positionally:
    0 -> demonstration graph, 1 -> demonstration actions tensor,
    2 -> demonstration lengths, 3 -> query graph, 4 -> target action tensor.

    Graphs are merged with ``dgl.batch``, tensors are stacked along a new
    leading dimension and the lengths are collected into a plain list.

    Returns:
        list: [dem_frames, dem_actions, dem_lens, query_frames, target_actions]
    """
    demo_graphs, demo_acts, demo_lengths = [], [], []
    query_graphs, target_acts = [], []
    for sample in batch:
        demo_graphs.append(sample[0])
        demo_acts.append(sample[1])
        demo_lengths.append(sample[2])
        query_graphs.append(sample[3])
        target_acts.append(sample[4])
    return [
        dgl.batch(demo_graphs),
        torch.stack(demo_acts),
        demo_lengths,
        dgl.batch(query_graphs),
        torch.stack(target_acts),
    ]
def collate_function_seq_test(batch):
    """Collate evaluation samples (expected vs. unexpected trials).

    Sample layout by position: 0-2 expected demonstrations (states, actions,
    lens), 3-5 unexpected demonstrations, 6/7 expected query frames and target
    actions, 8/9 unexpected query frames and target actions.

    Demonstration states and actions are taken from the FIRST sample only and
    the actions get a leading dimension of size 1, whereas lengths and query
    data are gathered from every sample.
    NOTE(review): this looks like it expects batch_size == 1 -- confirm
    against the DataLoader configuration.

    Returns:
        list: the ten collated fields, in the same order as the sample layout.
    """
    head = batch[0]

    exp_dem_states = dgl.batch(head[0])
    exp_dem_actions = torch.stack(head[1]).unsqueeze(dim=0)
    exp_dem_lens = [sample[2] for sample in batch]

    unexp_dem_states = dgl.batch(head[3])
    unexp_dem_actions = torch.stack(head[4]).unsqueeze(dim=0)
    unexp_dem_lens = [sample[5] for sample in batch]

    exp_query_frames = dgl.batch([sample[6] for sample in batch])
    exp_targets = torch.stack([sample[7] for sample in batch])

    unexp_query_frames = dgl.batch([sample[8] for sample in batch])
    unexp_targets = torch.stack([sample[9] for sample in batch])

    return [
        exp_dem_states, exp_dem_actions, exp_dem_lens,
        unexp_dem_states, unexp_dem_actions, unexp_dem_lens,
        exp_query_frames, exp_targets,
        unexp_query_frames, unexp_targets,
    ]
def collate_function_mental(batch):
    """Collate samples that also carry the past frames of the test trial.

    Sample layout by position: 0 demonstration graph, 1 demonstration actions,
    2 demonstration lengths, 3 past test graph, 4 past test actions,
    5 past test length, 6 query graph, 7 target action.

    Graph fields are merged with ``dgl.batch``, tensor fields are stacked and
    the two length fields are returned as plain lists.

    Returns:
        list: [dem_frames, dem_actions, dem_lens, past_test_frames,
        past_test_actions, past_test_len, query_frames, target_actions]
    """
    def column(position):
        # Gather one positional field across the whole batch.
        return [sample[position] for sample in batch]

    demo_graph = dgl.batch(column(0))
    demo_acts = torch.stack(column(1))
    demo_lengths = column(2)
    past_graph = dgl.batch(column(3))
    past_acts = torch.stack(column(4))
    past_lengths = column(5)
    query_graph = dgl.batch(column(6))
    targets = torch.stack(column(7))
    return [
        demo_graph, demo_acts, demo_lengths,
        past_graph, past_acts, past_lengths,
        query_graph, targets,
    ]
class ToMnetDGLDataset(DGLDataset):
    """
    Training dataset class.

    Every sample is one pickle file ``<idx>.pkl`` in a directory derived from
    ``path``, ``mode`` and ``types`` (``path`` is concatenated as-is, no
    separator is inserted):

    * ``mode='train'`` with four types:
      ``<path>train_dgl_hetero_nobound_4feats/``
    * ``mode='train'`` with one to three types: the same name followed by
      ``_`` and the upper-cased first letter of every type
    * ``mode='val'`` with exactly one type:
      ``<path>val_dgl_hetero_nobound_4feats/<type>/``

    NOTE(review): ``DGLDataset.__init__`` is not called by this class --
    confirm that no base-class state is required by its consumers.
    """

    def __init__(self, path, types=None, mode="train"):
        """Resolve the sample directory.

        Args:
            path: directory prefix of the dataset.
            types: sequence of task-type names (required; its length is read).
            mode: ``'train'`` or ``'val'``.

        Raises:
            ValueError: for an unknown mode or, in train mode, when the number
                of types is not between 1 and 4.
        """
        self.path = path
        self.types = types
        self.mode = mode
        print('Mode:', self.mode)
        if self.mode == 'train':
            n_types = len(self.types)
            if n_types == 4:
                self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/'
            elif 1 <= n_types <= 3:
                # Sub-sets of task types live in directories tagged with the
                # initials of the selected types.
                initials = ''.join(t[0].upper() for t in self.types)
                self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + initials + '/'
                if n_types > 1:
                    print(initials)
            else:
                raise ValueError('Number of types must be between 1 and 4.')
        elif self.mode == 'val':
            assert len(self.types) == 1, 'Validation expects exactly one task type.'
            self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/' + self.types[0] + '/'
        else:
            raise ValueError(f"Unsupported mode {self.mode!r}; expected 'train' or 'val'.")

    def get_test(self, states, actions):
        """Pick one random frame of a trial as the query.

        Args:
            states: batched graph holding the frames of one trial.
            actions: per-frame actions, indexable by frame position.

        Returns:
            (query_graph, target_action) for the sampled frame.
        """
        # states is a batched graph -> unbatch it, then sample one sub-graph
        # together with the action at the same position.
        frame_graphs = dgl.unbatch(states)
        query_idx = random.randint(0, len(frame_graphs) - 1)
        return frame_graphs[query_idx], actions[query_idx]

    def __getitem__(self, idx):
        """Load episode ``idx`` and split it into demonstrations and a query.

        The trials of the episode are shuffled; the last one after shuffling
        supplies the query frame and all others form the demonstrations.

        Returns:
            (dem_s, dem_a, dem_lens, test_s, test_a)
        """
        # NOTE: pickle can execute arbitrary code -- only load trusted files.
        with open(self.path+str(idx)+'.pkl', 'rb') as f:
            states, actions, lens, _ = pkl.load(f)
        # shuffle the trials while keeping states/actions/lens aligned
        ziplist = list(zip(states, actions, lens))
        random.shuffle(ziplist)
        states, actions, lens = (list(column) for column in zip(*ziplist))
        # last trial -> test (one random frame), the rest -> demonstrations
        test_s, test_a = self.get_test(states[-1], actions[-1])
        dem_s = dgl.batch(states[:-1])
        dem_a = torch.stack(actions[:-1])
        dem_lens = lens[:-1]
        return dem_s, dem_a, dem_lens, test_s, test_a

    def __len__(self):
        """Number of entries in the sample directory."""
        return len(os.listdir(self.path))
class TestToMnetDGLDataset(DGLDataset):
    """
    Testing dataset class.

    Samples are pickle files ``<idx>.pkl`` located in
    ``<path>_dgl_hetero_nobound_4feats/<task_type>/``. Each file stores twelve
    fields; the 4th and 8th are discarded when a sample is read.
    """

    def __init__(self, path, task_type=None, mode="test"):
        """Resolve the sample directory; only ``mode='test'`` is accepted."""
        self.path = path
        self.type = task_type
        self.mode = mode
        print('Mode:', self.mode)
        if self.mode != 'test':
            raise ValueError
        self.path = self.path + '_dgl_hetero_nobound_4feats/' + self.type + '/'

    def __getitem__(self, idx):
        """Load sample ``idx`` and sanity-check its layout.

        Returns:
            tuple of ten fields: expected demonstrations (states, actions,
            lens), unexpected demonstrations (states, actions, lens), expected
            query frames and targets, unexpected query frames and targets.
        """
        with open(self.path+str(idx)+'.pkl', 'rb') as f:
            record = pkl.load(f)
        # Positions 3 and 7 are ignored (the original comment calls them n_nodes).
        (exp_states, exp_actions, exp_lens, _,
         unexp_states, unexp_actions, unexp_lens, _,
         exp_query, exp_target,
         unexp_query, unexp_target) = record
        # Every demonstration field must hold exactly eight trials.
        for demo_field in (exp_states, exp_actions, exp_lens,
                           unexp_states, unexp_actions, unexp_lens):
            assert len(demo_field) == 8
        # One target action per query frame.
        assert len(dgl.unbatch(exp_query)) == exp_target.size()[0]
        assert len(dgl.unbatch(unexp_query)) == unexp_target.size()[0]
        return (
            exp_states, exp_actions, exp_lens,
            unexp_states, unexp_actions, unexp_lens,
            exp_query, exp_target,
            unexp_query, unexp_target,
        )

    def __len__(self):
        """Number of entries in the sample directory."""
        return len(os.listdir(self.path))
class ToMnetDGLDatasetUndersample(DGLDataset):
    """
    Training dataset class for the behavior cloning mlp model.

    Same on-disk layout as the regular training dataset, but the first
    ``_FIRST_INDEX`` sample files are skipped: item ``i`` reads file
    ``<i + _FIRST_INDEX>.pkl`` and the reported length is reduced accordingly.

    NOTE(review): ``DGLDataset.__init__`` is not called by this class --
    confirm that no base-class state is required by its consumers.
    """

    # Number of leading sample files that are skipped. Kept in one place so
    # that __getitem__ and __len__ cannot drift apart.
    _FIRST_INDEX = 3175

    def __init__(self, path, types=None, mode="train"):
        """Resolve the sample directory.

        Args:
            path: directory prefix of the dataset (concatenated as-is).
            types: sequence of task-type names (required; its length is read).
            mode: ``'train'`` or ``'val'``.

        Raises:
            ValueError: for an unknown mode or, in train mode, when the number
                of types is not between 1 and 4.
        """
        self.path = path
        self.types = types
        self.mode = mode
        print('Mode:', self.mode)
        if self.mode == 'train':
            n_types = len(self.types)
            if n_types == 4:
                self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/'
            elif 1 <= n_types <= 3:
                # Directory suffix = initials of the selected task types.
                initials = ''.join(t[0].upper() for t in self.types)
                self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + initials + '/'
                if n_types > 1:
                    print(initials)
            else:
                raise ValueError('Number of types must be between 1 and 4.')
        elif self.mode == 'val':
            assert len(self.types) == 1, 'Validation expects exactly one task type.'
            self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/' + self.types[0] + '/'
        else:
            raise ValueError(f"Unsupported mode {self.mode!r}; expected 'train' or 'val'.")
        print('Undersampled dataset!')

    def get_test(self, states, actions):
        """Pick one random frame of a trial as the query.

        Args:
            states: batched graph holding the frames of one trial.
            actions: per-frame actions, indexable by frame position.

        Returns:
            (query_graph, target_action) for the sampled frame.
        """
        frame_graphs = dgl.unbatch(states)
        query_idx = random.randint(0, len(frame_graphs) - 1)
        return frame_graphs[query_idx], actions[query_idx]

    def __getitem__(self, idx):
        """Load episode ``idx + _FIRST_INDEX`` and split it.

        The trials are shuffled; the last one after shuffling supplies the
        query frame and all others form the demonstrations.

        Returns:
            (dem_s, dem_a, dem_lens, test_s, test_a)
        """
        idx = idx + self._FIRST_INDEX
        # NOTE: pickle can execute arbitrary code -- only load trusted files.
        with open(self.path+str(idx)+'.pkl', 'rb') as f:
            states, actions, lens, _ = pkl.load(f)
        # shuffle the trials while keeping states/actions/lens aligned
        ziplist = list(zip(states, actions, lens))
        random.shuffle(ziplist)
        states, actions, lens = (list(column) for column in zip(*ziplist))
        # last trial -> test (one random frame), the rest -> demonstrations
        test_s, test_a = self.get_test(states[-1], actions[-1])
        dem_s = dgl.batch(states[:-1])
        dem_a = torch.stack(actions[:-1])
        dem_lens = lens[:-1]
        return dem_s, dem_a, dem_lens, test_s, test_a

    def __len__(self):
        """Directory entry count minus the skipped leading samples."""
        return len(os.listdir(self.path)) - self._FIRST_INDEX
class ToMnetDGLDatasetMental(DGLDataset):
    """
    Training dataset class.

    In addition to demonstrations and a query frame, every sample carries the
    frames/actions of the test trial that precede the query frame.

    NOTE(review): ``DGLDataset.__init__`` is not called by this class --
    confirm that no base-class state is required by its consumers.
    """

    # Number of rows the past-action tensor is padded to.
    # NOTE(review): presumably the maximum trial length -- confirm. A query
    # index above this value makes the pad amount negative.
    _PAST_LEN = 41

    def __init__(self, path, types=None, mode="train"):
        """Resolve the sample directory.

        Args:
            path: directory prefix of the dataset (concatenated as-is).
            types: sequence of task-type names (required; its length is read).
            mode: ``'train'`` or ``'val'``.

        Raises:
            ValueError: for an unknown mode or, in train mode, when the number
                of types is not between 1 and 4.
        """
        self.path = path
        self.types = types
        self.mode = mode
        print('Mode:', self.mode)
        if self.mode == 'train':
            n_types = len(self.types)
            if n_types == 4:
                self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/'
            elif 1 <= n_types <= 3:
                # Directory suffix = initials of the selected task types.
                initials = ''.join(t[0].upper() for t in self.types)
                self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + initials + '/'
                if n_types > 1:
                    print(initials)
            else:
                raise ValueError('Number of types must be between 1 and 4.')
        elif self.mode == 'val':
            assert len(self.types) == 1, 'Validation expects exactly one task type.'
            self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/' + self.types[0] + '/'
        else:
            raise ValueError(f"Unsupported mode {self.mode!r}; expected 'train' or 'val'.")

    def get_test(self, states, actions):
        """Sample a query frame and return it with the frames preceding it.

        Args:
            states: batched graph holding the frames of one trial.
            actions: tensor of per-frame actions (sliced and padded along its
                second-to-last dimension, so it needs at least two dims).

        Returns:
            (past_test_graphs, past_test_actions, past_len, test_graph,
            test_action). ``past_test_actions`` is zero-padded to
            ``_PAST_LEN`` rows. When the first frame is sampled there is no
            past, so the query frame/action stand in for it and ``past_len``
            is 1.
        """
        frame_graphs = dgl.unbatch(states)
        query_idx = random.randint(0, len(frame_graphs) - 1)
        test_graph = frame_graphs[query_idx]
        test_action = actions[query_idx]

        if query_idx == 0:
            # NOTE: since there are no past frames, return the test frame and
            # action and query_idx=1; not sure what would be a better alternative
            padded_action = F.pad(
                test_action.unsqueeze(0), (0, 0, 0, self._PAST_LEN - 1), 'constant', 0
            )
            return test_graph, padded_action, 1, test_graph, test_action

        if query_idx == 1:
            # A single past frame is returned as-is, without dgl.batch.
            past_test_graphs = frame_graphs[0]
        else:
            past_test_graphs = dgl.batch(frame_graphs[:query_idx])
        past_test_actions = F.pad(
            actions[:query_idx], (0, 0, 0, self._PAST_LEN - query_idx), 'constant', 0
        )
        return past_test_graphs, past_test_actions, query_idx, test_graph, test_action

    def __getitem__(self, idx):
        """Load episode ``idx`` and split it into demonstrations and a test trial.

        The trials are shuffled; the last one after shuffling is the test
        trial and all others form the demonstrations.

        Returns:
            (dem_s, dem_a, dem_lens, past_test_s, past_test_a, past_test_len,
            test_s, test_a)
        """
        # NOTE: pickle can execute arbitrary code -- only load trusted files.
        with open(self.path+str(idx)+'.pkl', 'rb') as f:
            states, actions, lens, _ = pkl.load(f)
        ziplist = list(zip(states, actions, lens))
        random.shuffle(ziplist)
        states, actions, lens = (list(column) for column in zip(*ziplist))
        past_test_s, past_test_a, past_test_len, test_s, test_a = self.get_test(states[-1], actions[-1])
        dem_s = dgl.batch(states[:-1])
        dem_a = torch.stack(actions[:-1])
        dem_lens = lens[:-1]
        return dem_s, dem_a, dem_lens, past_test_s, past_test_a, past_test_len, test_s, test_a

    def __len__(self):
        """Number of entries in the sample directory."""
        return len(os.listdir(self.path))
# --------------------------------------------------------------------------------------------------------
if __name__ == "__main__":
    # Manual smoke test: load one sample and drop into the debugger.
    # Task names kept for reference while debugging; not used below.
    types = [
        'preference', 'multi_agent', 'inaccessible_goal',
        'efficiency_irrational', 'efficiency_time','efficiency_path',
        'instrumental_no_barrier', 'instrumental_blocking_barrier', 'instrumental_inconsequential_barrier'
    ]
    mental_dataset = ToMnetDGLDatasetMental(
        path='/datasets/external/bib_train/graphs/all_tasks/',
        types=['instrumental_action'],
        mode='train'
    )
    # The sixth field must not be bound to the name `len`: that would rebind
    # the builtin at module scope and break every len(...) call made by the
    # dataset classes during the debugging session below.
    (dem_frames, dem_actions, dem_lens,
     past_test_frames, past_test_actions, past_test_len,
     test_frame, test_action) = mental_dataset[99]
    breakpoint()

877
tom/gnn.py Normal file
View file

@ -0,0 +1,877 @@
import dgl.nn.pytorch as dglnn
import torch.nn as nn
import torch.nn.functional as F
import torch
import dgl
import sys
import copy
from wandb import agent
sys.path.append('/projects/bortoletto/irene/')
from tom.norm import Norm
class RSAGEv4(nn.Module):
    """Relational GraphSAGE encoder (LSTM aggregator) with mean-node readout.

    Node inputs are read from ``g.ndata``: ``type`` (9 features), ``pos``
    (2, divided by 170), ``color`` (3, divided by 255) and ``shape`` (18).
    Type, color and shape embeddings are fused first, then fused with the
    position embedding, and the result is fed to ``n_layers`` heterogeneous
    SAGEConv layers (one SAGEConv per relation, summed).
    """

    def __init__(
        self,
        hidden_channels,
        out_channels,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
    ):
        super().__init__()
        embed_dim = int(hidden_channels)
        self.embedding_type = nn.Linear(9, embed_dim)
        self.embedding_pos = nn.Linear(2, embed_dim)
        self.embedding_color = nn.Linear(3, embed_dim)
        self.embedding_shape = nn.Linear(18, embed_dim)
        # type + color + shape -> one attribute vector
        self.combine_attr = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*3, hidden_channels),
        )
        # position + attribute vector -> node input features
        self.combine_pos = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*2, hidden_channels),
        )
        conv_stack = nn.ModuleList()
        for depth in range(n_layers):
            is_last = depth == n_layers - 1
            per_relation = {}
            for rel in rel_names:
                per_relation[rel] = dglnn.SAGEConv(
                    in_feats=hidden_channels,
                    out_feats=out_channels if is_last else hidden_channels,
                    aggregator_type='lstm',
                    feat_drop=dropout,
                    # the final layer is left linear
                    activation=None if is_last else activation,
                )
            conv_stack.append(dglnn.HeteroGraphConv(per_relation, aggregate='sum'))
        self.layers = conv_stack

    def forward(self, g):
        """Encode graph ``g`` and return the mean of the final 'obj' node features."""
        type_emb = self.embedding_type(g.ndata['type'].float())
        pos_emb = self.embedding_pos(g.ndata['pos']/170.)
        color_emb = self.embedding_color(g.ndata['color']/255.)
        shape_emb = self.embedding_shape(g.ndata['shape'].float())
        attr_emb = self.combine_attr(torch.cat([type_emb, color_emb, shape_emb], dim=1))
        h = {'obj': self.combine_pos(torch.cat((pos_emb, attr_emb), dim=1))}
        for conv in self.layers:
            h = conv(g, h)
        # local_scope keeps the temporary 'h' feature off the caller's graph
        with g.local_scope():
            g.ndata['h'] = h['obj']
            readout = dgl.mean_nodes(g, 'h')
        return readout
# -------------------------------------------------------------------------------------------
class RAGNNv4(nn.Module):
    """Relational AGNN encoder with mean-node readout.

    Node inputs are read from ``g.ndata``: ``type`` (9 features), ``pos``
    (2, divided by 170), ``color`` (3, divided by 255) and ``shape`` (18).
    Each layer applies one ``AGNNConv`` per relation and sums the results.

    ``out_channels``, ``dropout`` and ``activation`` are accepted but not
    used by this class.
    """

    def __init__(
        self,
        hidden_channels,
        out_channels,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
    ):
        super().__init__()
        embed_dim = int(hidden_channels)
        self.embedding_type = nn.Linear(9, embed_dim)
        self.embedding_pos = nn.Linear(2, embed_dim)
        self.embedding_color = nn.Linear(3, embed_dim)
        self.embedding_shape = nn.Linear(18, embed_dim)
        # type + color + shape -> one attribute vector
        self.combine_attr = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*3, hidden_channels),
        )
        # position + attribute vector -> node input features
        self.combine_pos = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*2, hidden_channels),
        )
        conv_stack = nn.ModuleList()
        for _ in range(n_layers):
            per_relation = {}
            for rel in rel_names:
                per_relation[rel] = dglnn.AGNNConv()
            conv_stack.append(dglnn.HeteroGraphConv(per_relation, aggregate='sum'))
        self.layers = conv_stack

    def forward(self, g):
        """Encode graph ``g`` and return the mean of the final 'obj' node features."""
        type_emb = self.embedding_type(g.ndata['type'].float())
        pos_emb = self.embedding_pos(g.ndata['pos']/170.)
        color_emb = self.embedding_color(g.ndata['color']/255.)
        shape_emb = self.embedding_shape(g.ndata['shape'].float())
        attr_emb = self.combine_attr(torch.cat([type_emb, color_emb, shape_emb], dim=1))
        h = {'obj': self.combine_pos(torch.cat((pos_emb, attr_emb), dim=1))}
        for conv in self.layers:
            h = conv(g, h)
        # local_scope keeps the temporary 'h' feature off the caller's graph
        with g.local_scope():
            g.ndata['h'] = h['obj']
            readout = dgl.mean_nodes(g, 'h')
        return readout
# -------------------------------------------------------------------------------------------
class RGATv2(nn.Module):
    """Relational GATv2 encoder using embedding tables for type and shape.

    ``type`` and ``shape`` are converted to indices with ``argmax`` over
    dim 1 and looked up in embedding tables of 9 and 18 entries; ``pos``
    (divided by 170) and ``color`` (divided by 255) go through linear layers.
    Each of the four embeddings has ``int(hidden_channels*num_heads/4)``
    features. Hidden layers use ``num_heads`` heads (flattened), the last
    layer uses a single head whose head axis is averaged away.
    """

    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        quarter_dim = int(hidden_channels*num_heads/4)
        feat_dim = hidden_channels*num_heads
        self.embedding_type = nn.Embedding(9, quarter_dim)
        self.embedding_pos = nn.Linear(2, quarter_dim)
        self.embedding_color = nn.Linear(3, quarter_dim)
        self.embedding_shape = nn.Embedding(18, quarter_dim)
        self.combine = nn.Sequential(
            nn.ReLU(),
            nn.Linear(feat_dim, feat_dim),
        )
        conv_stack = nn.ModuleList()
        for depth in range(n_layers):
            is_last = depth == n_layers - 1
            per_relation = {}
            for rel in rel_names:
                per_relation[rel] = dglnn.GATv2Conv(
                    in_feats=feat_dim,
                    out_feats=out_channels if is_last else hidden_channels,
                    num_heads=1 if is_last else num_heads,
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    activation=None if is_last else activation,
                )
            conv_stack.append(dglnn.HeteroGraphConv(per_relation, aggregate='sum'))
        self.layers = conv_stack

    def forward(self, g):
        """Encode graph ``g`` and return the mean of the final 'obj' node features."""
        type_emb = self.embedding_type(torch.argmax(g.ndata['type'], dim=1))
        pos_emb = self.embedding_pos(g.ndata['pos']/170.)
        color_emb = self.embedding_color(g.ndata['color']/255.)
        shape_emb = self.embedding_shape(torch.argmax(g.ndata['shape'], dim=1))
        h = {'obj': self.combine(torch.cat([type_emb, pos_emb, color_emb, shape_emb], dim=1))}
        last = len(self.layers) - 1
        for depth, conv in enumerate(self.layers):
            h = conv(g, h)
            if depth == last:
                # final layer: average over the head axis
                h = {k: v.mean(1) for k, v in h.items()}
            else:
                # hidden layers: concatenate the heads
                h = {k: v.flatten(1) for k, v in h.items()}
        with g.local_scope():
            g.ndata['h'] = h['obj']
            readout = dgl.mean_nodes(g, 'h')
        return readout
# -------------------------------------------------------------------------------------------
class RGATv3(nn.Module):
    """Relational GATv2 encoder with linear feature embeddings.

    ``type`` (9), ``pos`` (2, divided by 170), ``color`` (3, divided by 255)
    and ``shape`` (18) are each projected to ``hidden_channels*num_heads``
    features, concatenated and fused by one linear layer. Hidden layers use
    ``num_heads`` heads (flattened); the last layer uses a single head whose
    head axis is averaged away.
    """

    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        embed_dim = int(hidden_channels*num_heads)
        feat_dim = hidden_channels*num_heads
        self.embedding_type = nn.Linear(9, embed_dim)
        self.embedding_pos = nn.Linear(2, embed_dim)
        self.embedding_color = nn.Linear(3, embed_dim)
        self.embedding_shape = nn.Linear(18, embed_dim)
        self.combine = nn.Sequential(
            nn.ReLU(),
            nn.Linear(feat_dim*4, feat_dim),
        )
        conv_stack = nn.ModuleList()
        for depth in range(n_layers):
            is_last = depth == n_layers - 1
            per_relation = {}
            for rel in rel_names:
                per_relation[rel] = dglnn.GATv2Conv(
                    in_feats=feat_dim,
                    out_feats=out_channels if is_last else hidden_channels,
                    # TODO: change to num_heads always
                    num_heads=1 if is_last else num_heads,
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    activation=None if is_last else activation,
                )
            conv_stack.append(dglnn.HeteroGraphConv(per_relation, aggregate='sum'))
        self.layers = conv_stack

    def forward(self, g):
        """Encode graph ``g`` and return the mean of the final 'obj' node features."""
        type_emb = self.embedding_type(g.ndata['type'].float())
        # NOTE: this should be 180 because I remove the boundary walls!
        pos_emb = self.embedding_pos(g.ndata['pos']/170.)
        color_emb = self.embedding_color(g.ndata['color']/255.)
        shape_emb = self.embedding_shape(g.ndata['shape'].float())
        h = {'obj': self.combine(torch.cat([type_emb, pos_emb, color_emb, shape_emb], dim=1))}
        last = len(self.layers) - 1
        for depth, conv in enumerate(self.layers):
            h = conv(g, h)
            if depth == last:
                # final layer: average over the head axis
                h = {k: v.mean(1) for k, v in h.items()}
            else:
                # hidden layers: concatenate the heads
                h = {k: v.flatten(1) for k, v in h.items()}
        with g.local_scope():
            g.ndata['h'] = h['obj']
            readout = dgl.mean_nodes(g, 'h')
        return readout
# -------------------------------------------------------------------------------------------
class RGCNv2(nn.Module):
    """Relational GCN encoder built from ``RelGraphConv`` layers.

    ``type`` (9), ``pos`` (2, divided by 170), ``color`` (3, divided by 255)
    and ``shape`` (18) are each projected to ``hidden_channels`` features,
    concatenated and fused by one linear layer. ``n_layers - 1`` hidden
    layers are followed by one layer mapping to ``out_channels``; every
    layer, including the last, uses ``activation``.
    """

    def __init__(
        self,
        hidden_channels,
        out_channels,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ReLU()
    ):
        super().__init__()
        embed_dim = int(hidden_channels)
        self.embedding_type = nn.Linear(9, embed_dim)
        self.embedding_pos = nn.Linear(2, embed_dim)
        self.embedding_color = nn.Linear(3, embed_dim)
        self.embedding_shape = nn.Linear(18, embed_dim)
        self.combine = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*4, hidden_channels),
        )
        # Output width of every layer: hidden ... hidden, then out_channels.
        layer_widths = [hidden_channels] * (n_layers - 1) + [out_channels]
        self.layers = nn.ModuleList()
        for width in layer_widths:
            self.layers.append(
                dglnn.RelGraphConv(
                    in_feat=hidden_channels,
                    out_feat=width,
                    num_rels=len(rel_names),
                    regularizer=None,
                    num_bases=None,
                    bias=True,
                    activation=activation,
                    self_loop=True,
                    dropout=dropout,
                    layer_norm=False,
                )
            )

    def forward(self, g):
        """Encode graph ``g`` and return the mean of the final node features.

        NOTE(review): DGL documents the homogeneous conversion as the
        module-level ``dgl.to_homogeneous`` and ``RelGraphConv`` as taking
        integer edge-type IDs -- confirm that ``g.to_homogeneous()`` and
        ``g.etypes`` work as intended with the installed DGL version.
        """
        g = g.to_homogeneous()
        type_emb = self.embedding_type(g.ndata['type'].float())
        pos_emb = self.embedding_pos(g.ndata['pos']/170.)
        color_emb = self.embedding_color(g.ndata['color']/255.)
        shape_emb = self.embedding_shape(g.ndata['shape'].float())
        h = self.combine(torch.cat([type_emb, pos_emb, color_emb, shape_emb], dim=1))
        for conv in self.layers:
            h = conv(g, h, g.etypes)
        with g.local_scope():
            g.ndata['h'] = h
            readout = dgl.mean_nodes(g, 'h')
        return readout
# -------------------------------------------------------------------------------------------
class RGATv3Agent(nn.Module):
    """Relational GATv2 encoder returning agent features plus graph context.

    ``type`` (9), ``pos`` (2, divided by 200), ``color`` (3, divided by 255)
    and ``shape`` (18) are each projected to ``hidden_channels*num_heads``
    features, concatenated and fused by one linear layer. The output is the
    final feature of the nodes whose first ``type`` component equals 1, added
    to the per-graph mean of all node features.
    """

    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        embed_dim = int(hidden_channels*num_heads)
        feat_dim = hidden_channels*num_heads
        self.embedding_type = nn.Linear(9, embed_dim)
        self.embedding_pos = nn.Linear(2, embed_dim)
        self.embedding_color = nn.Linear(3, embed_dim)
        self.embedding_shape = nn.Linear(18, embed_dim)
        self.combine = nn.Sequential(
            nn.ReLU(),
            nn.Linear(feat_dim*4, feat_dim),
        )
        conv_stack = nn.ModuleList()
        for depth in range(n_layers):
            is_last = depth == n_layers - 1
            per_relation = {}
            for rel in rel_names:
                per_relation[rel] = dglnn.GATv2Conv(
                    in_feats=feat_dim,
                    out_feats=out_channels if is_last else hidden_channels,
                    # TODO: change to num_heads always
                    num_heads=1 if is_last else num_heads,
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    activation=None if is_last else activation,
                )
            conv_stack.append(dglnn.HeteroGraphConv(per_relation, aggregate='sum'))
        self.layers = conv_stack

    def forward(self, g):
        """Return agent-node features plus the mean-node context of ``g``.

        NOTE(review): the sum ``agent + context`` presumably relies on exactly
        one agent node per graph in the batch -- confirm.
        """
        agent_mask = g.ndata['type'][:, 0] == 1
        type_emb = self.embedding_type(g.ndata['type'].float())
        pos_emb = self.embedding_pos(g.ndata['pos']/200.)
        color_emb = self.embedding_color(g.ndata['color']/255.)
        shape_emb = self.embedding_shape(g.ndata['shape'].float())
        h = {'obj': self.combine(torch.cat([type_emb, pos_emb, color_emb, shape_emb], dim=1))}
        last = len(self.layers) - 1
        for depth, conv in enumerate(self.layers):
            h = conv(g, h)
            if depth == last:
                # final layer: average over the head axis
                h = {k: v.mean(1) for k, v in h.items()}
            else:
                # hidden layers: concatenate the heads
                h = {k: v.flatten(1) for k, v in h.items()}
        with g.local_scope():
            g.ndata['h'] = h['obj']
            agent_feat = g.ndata['h'][agent_mask, :]
            context = dgl.mean_nodes(g, 'h')
        return agent_feat + context
# -------------------------------------------------------------------------------------------
class RGATv4(nn.Module):
    """Relational GATv2 encoder with two-stage feature fusion.

    ``type`` (9), ``color`` (3, divided by 255) and ``shape`` (18) embeddings
    are fused into one attribute vector, which is then fused with the
    ``pos`` (2, divided by 170) embedding. Hidden layers use ``num_heads``
    heads (flattened); the last layer uses a single head whose head axis is
    averaged away. GATv2Conv is created with ``share_weights=False``.
    """

    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        embed_dim = int(hidden_channels*num_heads)
        feat_dim = hidden_channels*num_heads
        self.embedding_type = nn.Linear(9, embed_dim)
        self.embedding_pos = nn.Linear(2, embed_dim)
        self.embedding_color = nn.Linear(3, embed_dim)
        self.embedding_shape = nn.Linear(18, embed_dim)
        # type + color + shape -> one attribute vector
        self.combine_attr = nn.Sequential(
            nn.ReLU(),
            nn.Linear(feat_dim*3, feat_dim),
        )
        # position + attribute vector -> node input features
        self.combine_pos = nn.Sequential(
            nn.ReLU(),
            nn.Linear(feat_dim*2, feat_dim),
        )
        conv_stack = nn.ModuleList()
        for depth in range(n_layers):
            is_last = depth == n_layers - 1
            per_relation = {}
            for rel in rel_names:
                per_relation[rel] = dglnn.GATv2Conv(
                    in_feats=feat_dim,
                    out_feats=out_channels if is_last else hidden_channels,
                    # TODO: change to num_heads always
                    num_heads=1 if is_last else num_heads,
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    share_weights=False,
                    activation=None if is_last else activation,
                )
            conv_stack.append(dglnn.HeteroGraphConv(per_relation, aggregate='sum'))
        self.layers = conv_stack

    def forward(self, g):
        """Encode graph ``g`` and return the mean of the final 'obj' node features."""
        type_emb = self.embedding_type(g.ndata['type'].float())
        pos_emb = self.embedding_pos(g.ndata['pos']/170.)
        color_emb = self.embedding_color(g.ndata['color']/255.)
        shape_emb = self.embedding_shape(g.ndata['shape'].float())
        attr_emb = self.combine_attr(torch.cat([type_emb, color_emb, shape_emb], dim=1))
        h = {'obj': self.combine_pos(torch.cat((pos_emb, attr_emb), dim=1))}
        last = len(self.layers) - 1
        for depth, conv in enumerate(self.layers):
            h = conv(g, h)
            if depth == last:
                # final layer: average over the head axis
                h = {k: v.mean(1) for k, v in h.items()}
            else:
                # hidden layers: concatenate the heads
                h = {k: v.flatten(1) for k, v in h.items()}
        with g.local_scope():
            g.ndata['h'] = h['obj']
            readout = dgl.mean_nodes(g, 'h')
        return readout
# -------------------------------------------------------------------------------------------
class RGATv4Norm(nn.Module):
# multi-layer GNN for one single feature
# Relational GATv2 encoder with two-stage feature fusion (type/color/shape
# first, then position) and one Norm module per layer.
# Hidden layers use num_heads heads, the last layer a single head.
def __init__(
self,
hidden_channels,
out_channels,
num_heads,
rel_names,
dropout,
n_layers,
activation=nn.ELU(),
residual=False
):
super().__init__()
# Linear embeddings of the raw node features: type (9), pos (2),
# color (3) and shape (18), each mapped to hidden_channels*num_heads.
self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
# Fuses the concatenated type/color/shape embeddings.
self.combine_attr = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*num_heads*3, hidden_channels*num_heads)
)
# Fuses the position embedding with the fused attribute vector.
self.combine_pos = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*num_heads*2, hidden_channels*num_heads)
)
# One HeteroGraphConv per layer holding one GATv2Conv per relation;
# the last layer has no activation.
self.layers = nn.ModuleList([
dglnn.HeteroGraphConv({
rel: dglnn.GATv2Conv(
in_feats=hidden_channels*num_heads,
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always
feat_drop=dropout,
attn_drop=dropout,
residual=residual,
activation=activation if l < n_layers - 1 else None
)
for rel in rel_names}, aggregate='sum')
for l in range(n_layers)
])
# One Norm (imported from tom.norm) per layer, sized to that layer's
# output width. NOTE(review): 'gn' presumably selects graph norm --
# confirm against tom/norm.py.
self.norms = nn.ModuleList([
Norm(
norm_type='gn',
hidden_dim=hidden_channels*num_heads if l < n_layers - 1 else out_channels
)
for l in range(n_layers)
])
def forward(self, g):
# pos is divided by 170 and color by 255 before embedding.
attr = []
attr.append(self.embedding_type(g.ndata['type'].float()))
pos = self.embedding_pos(g.ndata['pos']/170.)
attr.append(self.embedding_color(g.ndata['color']/255.))
attr.append(self.embedding_shape(g.ndata['shape'].float()))
combined_attr = self.combine_attr(torch.cat(attr, dim=1))
h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
for l, conv in enumerate(self.layers):
h = conv(g, h)
# Hidden layers concatenate the heads; the last layer averages them.
if l != len(self.layers) - 1:
h = {k: v.flatten(1) for k, v in h.items()}
else:
h = {k: v.mean(1) for k, v in h.items()}
# Layer-specific normalisation, called with the graph and the features.
h = {k: self.norms[l](g, v) for k, v in h.items()}
# Readout: mean of the 'obj' node features, computed inside a local
# scope so the temporary 'h' feature does not persist on g.
with g.local_scope():
g.ndata['h'] = h['obj']
out = dgl.mean_nodes(g, 'h')
return out
# -------------------------------------------------------------------------------------------
class RGATv3Norm(nn.Module):
# multi-layer GNN for one single feature
# Relational GATv2 encoder with single-stage feature fusion (all four
# embeddings concatenated) and one Norm module per layer.
# Hidden layers use num_heads heads, the last layer a single head.
def __init__(
self,
hidden_channels,
out_channels,
num_heads,
rel_names,
dropout,
n_layers,
activation=nn.ELU(),
residual=False
):
super().__init__()
# Linear embeddings of the raw node features: type (9), pos (2),
# color (3) and shape (18), each mapped to hidden_channels*num_heads.
self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
# Fuses the four concatenated embeddings into one node feature vector.
self.combine = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*num_heads*4, hidden_channels*num_heads)
)
# One HeteroGraphConv per layer holding one GATv2Conv per relation;
# the last layer has no activation.
self.layers = nn.ModuleList([
dglnn.HeteroGraphConv({
rel: dglnn.GATv2Conv(
in_feats=hidden_channels*num_heads,
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always
feat_drop=dropout,
attn_drop=dropout,
residual=residual,
activation=activation if l < n_layers - 1 else None
)
for rel in rel_names}, aggregate='sum')
for l in range(n_layers)
])
# One Norm (imported from tom.norm) per layer, sized to that layer's
# output width. NOTE(review): 'gn' presumably selects graph norm --
# confirm against tom/norm.py.
self.norms = nn.ModuleList([
Norm(
norm_type='gn',
hidden_dim=hidden_channels*num_heads if l < n_layers - 1 else out_channels
)
for l in range(n_layers)
])
def forward(self, g):
# pos is divided by 170 and color by 255 before embedding.
feats = []
feats.append(self.embedding_type(g.ndata['type'].float()))
feats.append(self.embedding_pos(g.ndata['pos']/170.))
feats.append(self.embedding_color(g.ndata['color']/255.))
feats.append(self.embedding_shape(g.ndata['shape'].float()))
h = {'obj': self.combine(torch.cat(feats, dim=1))}
for l, conv in enumerate(self.layers):
h = conv(g, h)
# Hidden layers concatenate the heads; the last layer averages them.
if l != len(self.layers) - 1:
h = {k: v.flatten(1) for k, v in h.items()}
else:
h = {k: v.mean(1) for k, v in h.items()}
# Layer-specific normalisation, called with the graph and the features.
h = {k: self.norms[l](g, v) for k, v in h.items()}
# Readout: mean of the 'obj' node features, computed inside a local
# scope so the temporary 'h' feature does not persist on g.
with g.local_scope():
g.ndata['h'] = h['obj']
out = dgl.mean_nodes(g, 'h')
return out
# -------------------------------------------------------------------------------------------
class RGATv4Agent(nn.Module):
    """Relational GATv2 encoder combining agent features with scene context.

    Feature fusion is two-stage: ``type`` (9), ``color`` (3, divided by 255)
    and ``shape`` (18) embeddings are fused first, then fused with the
    ``pos`` (2, divided by 170) embedding. After message passing, the
    features of the agent nodes (first ``type`` component equal to 1) are
    concatenated with the mean feature of the remaining nodes and projected
    back to ``out_channels``.
    """

    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        embed_dim = int(hidden_channels*num_heads)
        feat_dim = hidden_channels*num_heads
        self.embedding_type = nn.Linear(9, embed_dim)
        self.embedding_pos = nn.Linear(2, embed_dim)
        self.embedding_color = nn.Linear(3, embed_dim)
        self.embedding_shape = nn.Linear(18, embed_dim)
        # type + color + shape -> one attribute vector
        self.combine_attr = nn.Sequential(
            nn.ReLU(),
            nn.Linear(feat_dim*3, feat_dim),
        )
        # position + attribute vector -> node input features
        self.combine_pos = nn.Sequential(
            nn.ReLU(),
            nn.Linear(feat_dim*2, feat_dim),
        )
        conv_stack = nn.ModuleList()
        for depth in range(n_layers):
            is_last = depth == n_layers - 1
            per_relation = {}
            for rel in rel_names:
                per_relation[rel] = dglnn.GATv2Conv(
                    in_feats=feat_dim,
                    out_feats=out_channels if is_last else hidden_channels,
                    # TODO: change to num_heads always
                    num_heads=1 if is_last else num_heads,
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    activation=None if is_last else activation,
                )
            conv_stack.append(dglnn.HeteroGraphConv(per_relation, aggregate='sum'))
        self.layers = conv_stack
        # [agent features, context features] -> out_channels
        self.combine_agent_context = nn.Linear(out_channels*2, out_channels)

    def forward(self, g):
        """Return the fused agent/context representation of ``g``.

        NOTE(review): concatenating agent rows with per-graph context rows
        presumably relies on exactly one agent node per graph -- confirm.
        """
        agent_mask = g.ndata['type'][:, 0] == 1
        type_emb = self.embedding_type(g.ndata['type'].float())
        pos_emb = self.embedding_pos(g.ndata['pos']/170.)
        color_emb = self.embedding_color(g.ndata['color']/255.)
        shape_emb = self.embedding_shape(g.ndata['shape'].float())
        attr_emb = self.combine_attr(torch.cat([type_emb, color_emb, shape_emb], dim=1))
        h = {'obj': self.combine_pos(torch.cat((pos_emb, attr_emb), dim=1))}
        last = len(self.layers) - 1
        for depth, conv in enumerate(self.layers):
            h = conv(g, h)
            if depth == last:
                # final layer: average over the head axis
                h = {k: v.mean(1) for k, v in h.items()}
            else:
                # hidden layers: concatenate the heads
                h = {k: v.flatten(1) for k, v in h.items()}
        with g.local_scope():
            g.ndata['h'] = h['obj']
            h_a = g.ndata['h'][agent_mask, :]
            # Context = mean over all non-agent nodes, computed on a copy of
            # the graph from which the agent nodes have been removed.
            g_no_agent = copy.deepcopy(g)
            agent_ids = [i for i, is_agent in enumerate(agent_mask) if is_agent]
            g_no_agent.remove_nodes(agent_ids)
            h_g = dgl.mean_nodes(g_no_agent, 'h')
        return self.combine_agent_context(torch.cat((h_a, h_g), dim=1))
# -------------------------------------------------------------------------------------------
class RGCNv4(nn.Module):
    """Relational GCN encoder (``GraphConv`` per relation) with mean readout.

    ``type`` (9), ``color`` (3, divided by 255) and ``shape`` (18) embeddings
    are fused into one attribute vector, which is then fused with the
    ``pos`` (2, divided by 170) embedding. Each layer applies one
    ``GraphConv`` per relation and sums the results; the last layer maps to
    ``out_channels`` and has no activation.
    """

    def __init__(
        self,
        hidden_channels,
        out_channels,
        rel_names,
        n_layers,
        activation=nn.ELU(),
    ):
        super().__init__()
        embed_dim = int(hidden_channels)
        self.embedding_type = nn.Linear(9, embed_dim)
        self.embedding_pos = nn.Linear(2, embed_dim)
        self.embedding_color = nn.Linear(3, embed_dim)
        self.embedding_shape = nn.Linear(18, embed_dim)
        # type + color + shape -> one attribute vector
        self.combine_attr = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*3, hidden_channels),
        )
        # position + attribute vector -> node input features
        self.combine_pos = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*2, hidden_channels),
        )
        conv_stack = nn.ModuleList()
        for depth in range(n_layers):
            is_last = depth == n_layers - 1
            per_relation = {}
            for rel in rel_names:
                per_relation[rel] = dglnn.GraphConv(
                    in_feats=hidden_channels,
                    out_feats=out_channels if is_last else hidden_channels,
                    activation=None if is_last else activation,
                )
            conv_stack.append(dglnn.HeteroGraphConv(per_relation, aggregate='sum'))
        self.layers = conv_stack

    def forward(self, g):
        """Encode graph ``g`` and return the mean of the final 'obj' node features."""
        type_emb = self.embedding_type(g.ndata['type'].float())
        pos_emb = self.embedding_pos(g.ndata['pos']/170.)
        color_emb = self.embedding_color(g.ndata['color']/255.)
        shape_emb = self.embedding_shape(g.ndata['shape'].float())
        attr_emb = self.combine_attr(torch.cat([type_emb, color_emb, shape_emb], dim=1))
        h = {'obj': self.combine_pos(torch.cat((pos_emb, attr_emb), dim=1))}
        for conv in self.layers:
            h = conv(g, h)
        # local_scope keeps the temporary 'h' feature off the caller's graph
        with g.local_scope():
            g.ndata['h'] = h['obj']
            readout = dgl.mean_nodes(g, 'h')
        return readout
# -------------------------------------------------------------------------------------------
class RGATv5(nn.Module):
    """Relational GATv2 encoder with a learned soft weighting of the four
    node-feature embeddings and a skip connection from the fused input.

    Node ``type`` (9 values), ``pos`` (2), ``color`` (3) and ``shape`` (18)
    features are each embedded to ``hidden_channels * num_heads`` values and
    concatenated.  A linear layer produces four logits per node; their softmax
    scales the four embedding blocks before they are fused by ``combine``.
    The fused features pass through ``n_layers`` heterogeneous GATv2 layers
    whose heads are flattened, the fused input is added back, and the result
    is mean-pooled per graph.

    Args:
        hidden_channels: per-head width of the hidden layers.
        out_channels: per-head width of the last layer.
        num_heads: number of attention heads in every layer.
        rel_names: relation (edge type) names; one ``GATv2Conv`` per relation.
        dropout: used as both feature and attention dropout.
        n_layers: number of ``HeteroGraphConv`` layers.
        activation: activation of every layer except the last.
        residual: forwarded to ``GATv2Conv``.
    """
    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.num_heads = num_heads
        self.embedding_type = nn.Linear(9, hidden_channels*num_heads)
        self.embedding_pos = nn.Linear(2, hidden_channels*num_heads)
        self.embedding_color = nn.Linear(3, hidden_channels*num_heads)
        self.embedding_shape = nn.Linear(18, hidden_channels*num_heads)
        # fuses the four (re-weighted) embeddings back to one embedding width
        self.combine = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads*4, hidden_channels*num_heads)
        )
        # one logit per feature group (type, pos, color, shape)
        self.attention = nn.Linear(hidden_channels*num_heads*4, 4)
        self.layers = nn.ModuleList([
            dglnn.HeteroGraphConv({
                rel: dglnn.GATv2Conv(
                    in_feats=hidden_channels*num_heads,
                    out_feats=hidden_channels if l < n_layers - 1 else out_channels,
                    num_heads=num_heads,
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    activation=activation if l < n_layers - 1 else None
                )
                for rel in rel_names}, aggregate='sum')
            for l in range(n_layers)
        ])
    def forward(self, g):
        """Return one pooled vector per graph batched in ``g``."""
        feats = []
        feats.append(self.embedding_type(g.ndata['type'].float()))
        # positions and colours are rescaled by the constants 170 and 255
        feats.append(self.embedding_pos(g.ndata['pos']/170.))
        feats.append(self.embedding_color(g.ndata['color']/255.))
        feats.append(self.embedding_shape(g.ndata['shape'].float()))
        h = torch.cat(feats, dim=1)
        # softmax over the four feature groups, then stretch each weight over
        # the hidden_channels*num_heads columns of its embedding block
        feat_attn = F.softmax(self.attention(h), dim=1)
        h = h * feat_attn.repeat_interleave(self.hidden_channels*self.num_heads, dim=1)
        h_in = self.combine(h)
        h = {'obj': h_in}
        for conv in self.layers:
            h = conv(g, h)
            # merge the head dimension into the feature dimension
            h = {k: v.flatten(1) for k, v in h.items()}
            #if l != len(self.layers) - 1:
            #    h = {k: v.flatten(1) for k, v in h.items()}
            #else:
            #    h = {k: v.mean(1) for k, v in h.items()}
        # NOTE(review): skip connection from the fused input, assumed to be
        # applied once after the last layer -- confirm it is not meant to run
        # inside the loop.  Presumably requires out_channels == hidden_channels
        # so that the flattened output matches h_in; verify against callers.
        h = {k: v + h_in for k, v in h.items()}
        # local_scope keeps the temporary 'h' feature from leaking into ``g``
        with g.local_scope():
            g.ndata['h'] = h['obj']
            out = dgl.mean_nodes(g, 'h')
        return out
# -------------------------------------------------------------------------------------------
class RGATv6(nn.Module):
    """Relational GATv2 encoder with a global-attention-pooling read-out.

    Per the original author's note: RGATv6 = RGATv4 + Global Attention Pooling.

    Node ``type`` (9 values), ``color`` (3) and ``shape`` (18) embeddings are
    fused into one attribute vector, which is then fused with the ``pos`` (2)
    embedding.  The node states go through ``n_layers`` heterogeneous GATv2
    layers; hidden layers flatten their heads, the last layer uses a single
    head that is averaged away.  A gated attention pooling turns the node
    states of each graph into one vector of size ``out_channels``.

    Args:
        hidden_channels: per-head width of the hidden layers.
        out_channels: width of the last layer and of the pooled output.
        num_heads: attention heads of the hidden layers.
        rel_names: relation (edge type) names; one ``GATv2Conv`` per relation.
        dropout: used as both feature and attention dropout.
        n_layers: number of ``HeteroGraphConv`` layers.
        activation: activation of every layer except the last.
        residual: forwarded to ``GATv2Conv``.
    """

    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        width = hidden_channels * num_heads
        self.embedding_type = nn.Linear(9, int(width))
        self.embedding_pos = nn.Linear(2, int(width))
        self.embedding_color = nn.Linear(3, int(width))
        self.embedding_shape = nn.Linear(18, int(width))
        # type + colour + shape -> one attribute vector
        self.combine_attr = nn.Sequential(
            nn.ReLU(),
            nn.Linear(width * 3, width),
        )
        # position + attribute vector -> initial node state
        self.combine_pos = nn.Sequential(
            nn.ReLU(),
            nn.Linear(width * 2, width),
        )

        stack = []
        for depth in range(n_layers):
            is_hidden = depth < n_layers - 1
            per_relation = {}
            for rel in rel_names:
                per_relation[rel] = dglnn.GATv2Conv(
                    in_feats=width,
                    out_feats=hidden_channels if is_hidden else out_channels,
                    num_heads=num_heads if is_hidden else 1,  # TODO: change to num_heads always
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    activation=activation if is_hidden else None,
                )
            stack.append(dglnn.HeteroGraphConv(per_relation, aggregate='sum'))
        self.layers = nn.ModuleList(stack)

        self.gap = dglnn.GlobalAttentionPooling(nn.Linear(out_channels, 1))

    def forward(self, g):
        """Return one attention-pooled vector per graph batched in ``g``."""
        type_emb = self.embedding_type(g.ndata['type'].float())
        # positions and colours are rescaled by the constants 170 and 255
        pos_emb = self.embedding_pos(g.ndata['pos'] / 170.)
        color_emb = self.embedding_color(g.ndata['color'] / 255.)
        shape_emb = self.embedding_shape(g.ndata['shape'].float())

        attr_emb = self.combine_attr(
            torch.cat([type_emb, color_emb, shape_emb], dim=1)
        )
        h = {'obj': self.combine_pos(torch.cat((pos_emb, attr_emb), dim=1))}

        last = len(self.layers) - 1
        for depth, conv in enumerate(self.layers):
            h = conv(g, h)
            if depth == last:
                # final layer: average over the head dimension
                h = {name: feat.mean(1) for name, feat in h.items()}
            else:
                # hidden layer: concatenate the heads
                h = {name: feat.flatten(1) for name, feat in h.items()}

        return self.gap(g, h['obj'])
# -------------------------------------------------------------------------------------------
class RGATv6Agent(nn.Module):
    """Relational GATv2 encoder that combines the agent node with a pooled
    summary of all remaining nodes.

    The node encoder is the same as in ``RGATv6`` (separate embeddings for
    ``type``/``pos``/``color``/``shape``, two fusion steps, ``n_layers``
    heterogeneous GATv2 layers).  Nodes whose first ``type`` entry equals 1
    are treated as agent nodes: their states are kept as they are, the other
    nodes are summarised per graph by global attention pooling, and both
    parts are concatenated and projected to ``out_channels``.

    Args:
        hidden_channels: per-head width of the hidden layers.
        out_channels: width of the last layer and of the returned vectors.
        num_heads: attention heads of the hidden layers.
        rel_names: relation (edge type) names; one ``GATv2Conv`` per relation.
        dropout: used as both feature and attention dropout.
        n_layers: number of ``HeteroGraphConv`` layers.
        activation: activation of every layer except the last.
        residual: forwarded to ``GATv2Conv``.
    """

    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        width = hidden_channels * num_heads
        self.embedding_type = nn.Linear(9, int(width))
        self.embedding_pos = nn.Linear(2, int(width))
        self.embedding_color = nn.Linear(3, int(width))
        self.embedding_shape = nn.Linear(18, int(width))
        # type + colour + shape -> one attribute vector
        self.combine_attr = nn.Sequential(
            nn.ReLU(),
            nn.Linear(width * 3, width),
        )
        # position + attribute vector -> initial node state
        self.combine_pos = nn.Sequential(
            nn.ReLU(),
            nn.Linear(width * 2, width),
        )

        stack = []
        for depth in range(n_layers):
            is_hidden = depth < n_layers - 1
            per_relation = {}
            for rel in rel_names:
                per_relation[rel] = dglnn.GATv2Conv(
                    in_feats=width,
                    out_feats=hidden_channels if is_hidden else out_channels,
                    num_heads=num_heads if is_hidden else 1,  # TODO: change to num_heads always
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    activation=activation if is_hidden else None,
                )
            stack.append(dglnn.HeteroGraphConv(per_relation, aggregate='sum'))
        self.layers = nn.ModuleList(stack)

        self.gap = dglnn.GlobalAttentionPooling(nn.Linear(out_channels, 1))
        self.combine_agent_context = nn.Linear(out_channels * 2, out_channels)

    def forward(self, g):
        """Return the fused (agent state, pooled context) encoding of ``g``.

        NOTE(review): the concatenation along dim 1 presumably relies on
        exactly one agent node per graph -- confirm against the dataset.
        """
        agent_mask = g.ndata['type'][:, 0] == 1

        type_emb = self.embedding_type(g.ndata['type'].float())
        # positions and colours are rescaled by the constants 170 and 255
        pos_emb = self.embedding_pos(g.ndata['pos'] / 170.)
        color_emb = self.embedding_color(g.ndata['color'] / 255.)
        shape_emb = self.embedding_shape(g.ndata['shape'].float())

        attr_emb = self.combine_attr(
            torch.cat([type_emb, color_emb, shape_emb], dim=1)
        )
        h = {'obj': self.combine_pos(torch.cat((pos_emb, attr_emb), dim=1))}

        last = len(self.layers) - 1
        for depth, conv in enumerate(self.layers):
            h = conv(g, h)
            if depth == last:
                # final layer: average over the head dimension
                h = {name: feat.mean(1) for name, feat in h.items()}
            else:
                # hidden layer: concatenate the heads
                h = {name: feat.flatten(1) for name, feat in h.items()}

        # local_scope keeps the temporary 'h' feature from leaking into ``g``
        with g.local_scope():
            g.ndata['h'] = h['obj']
            node_states = g.ndata['h']
            h_a = node_states[agent_mask, :]
            h_g = node_states[~agent_mask, :]
            # pool the context over a copy of the graph without agent nodes
            agent_ids = [
                node_id
                for node_id, is_agent in enumerate(agent_mask.tolist())
                if is_agent
            ]
            g_no_agent = copy.deepcopy(g)
            g_no_agent.remove_nodes(agent_ids)
            h_g = self.gap(g_no_agent, h_g)  # dgl.mean_nodes(g_no_agent, 'h')
            out = self.combine_agent_context(torch.cat((h_a, h_g), dim=1))
        return out
# -------------------------------------------------------------------------------------------

513
tom/model.py Normal file
View file

@ -0,0 +1,513 @@
from argparse import ArgumentParser
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader
from tom.dataset import *
from tom.transformer import TransformerEncoder
from tom.gnn import RGATv2, RGATv3, RGATv3Agent, RGATv4, RGATv4Norm, RSAGEv4, RAGNNv4
class MlpModel(nn.Module):
    """Multilayer perceptron with a linear last layer.

    Args:
        input_size (int): number of input features.
        hidden_sizes (int | sequence of int | None): widths of the hidden
            layers; a single int means one hidden layer, ``None`` or an empty
            sequence means no hidden layer (linear model).
        output_size (int | None): width of the final linear layer.  If
            ``None`` no final layer is added, so the last hidden layer (with
            its nonlinearity applied) is the output.
        nonlinearity: torch nonlinearity Module class (not Functional); it is
            instantiated once per hidden layer.
        dropout (float | None): if given, a ``Dropout`` with this probability
            follows every hidden nonlinearity.
    """

    def __init__(
        self,
        input_size,
        hidden_sizes,  # Can be empty list or None for none.
        output_size=None,  # if None, last layer has nonlinearity applied.
        nonlinearity=nn.ReLU,  # Module, not Functional.
        dropout=None  # Dropout value
    ):
        super().__init__()
        if hidden_sizes is None:
            hidden_sizes = []
        elif isinstance(hidden_sizes, int):
            hidden_sizes = [hidden_sizes]
        else:
            # accept tuples and other sequences, not only lists
            hidden_sizes = list(hidden_sizes)

        sequence = []
        last_size = input_size
        for width in hidden_sizes:
            sequence.append(nn.Linear(last_size, width))
            sequence.append(nonlinearity())
            if dropout is not None:
                sequence.append(nn.Dropout(dropout))
            last_size = width
        if output_size is not None:
            sequence.append(nn.Linear(last_size, output_size))
            last_size = output_size

        self.model = nn.Sequential(*sequence)
        # With neither hidden layers nor an output layer the model is the
        # identity, so the output size falls back to ``input_size``.
        self._output_size = last_size

    def forward(self, input):
        """Compute the model on the input, assuming input shape [B,input_size]."""
        return self.model(input)

    @property
    def output_size(self):
        """Returns the output size of the model."""
        return self._output_size
# ---------------------------------------------------------------------------------------------------------------------------------
class GraphBCRNN(pl.LightningModule):
    """
    Baseline model for the BC-RNN algorithm.

    A relational GNN encodes every frame of the familiarization trials, a
    bidirectional LSTM turns each (state, action) trajectory into a context
    vector, and an MLP policy predicts the action for the query frame from
    the averaged context and the encoded query state.
    """

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a parser that extends ``parent_parser`` with the model options."""
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--lr', type=float, default=1e-4)
        parser.add_argument('--action_dim', type=int, default=2)
        parser.add_argument('--context_dim', type=int, default=32)  # lstm hidden size
        parser.add_argument('--beta', type=float, default=0.01)
        parser.add_argument('--dropout', type=float, default=0.2)
        parser.add_argument('--process_data', type=int, default=0)
        parser.add_argument('--max_len', type=int, default=30)
        # arguments for gnn
        parser.add_argument('--gnn_type', type=str, default='RGATv4')
        parser.add_argument('--state_dim', type=int, default=128)  # gnn out_feats
        # ``type=list`` would split the command-line string into single
        # characters; parse a proper list of ints instead.
        parser.add_argument('--feats_dims', type=int, nargs='+', default=[9, 2, 3, 18])
        parser.add_argument('--aggregation', type=str, default='sum')
        # arguments for mpl
        #parser.add_argument('--mpl_hid_feats', type=list, default=[256, 64, 16])
        return parser

    def __init__(self, hparams):
        """Build the GNN encoder, the LSTM context encoder and the MLP policy.

        Args:
            hparams: namespace carrying the options declared in
                ``add_model_specific_args``; the data loaders additionally
                read ``data_path``, ``types``, ``batch_size`` and
                ``num_workers`` from it.

        Raises:
            ValueError: for an unsupported ``gnn_type`` or ``aggregation``.
        """
        super().__init__()
        self.hparams = hparams
        self.lr = self.hparams.lr
        self.state_dim = self.hparams.state_dim
        self.action_dim = self.hparams.action_dim
        self.context_dim = self.hparams.context_dim
        self.beta = self.hparams.beta
        self.dropout = self.hparams.dropout
        self.max_len = self.hparams.max_len
        self.feats_dims = self.hparams.feats_dims  # type, position, color, shape
        self.rel_names = [
            'is_aligned', 'is_back', 'is_close', 'is_down_adj', 'is_down_left_adj',
            'is_down_right_adj', 'is_front', 'is_left', 'is_left_adj', 'is_right',
            'is_right_adj', 'is_top_adj', 'is_top_left_adj', 'is_top_right_adj'
        ]
        self.gnn_aggregation = self.hparams.aggregation

        gnn_type = self.hparams.gnn_type
        attention_encoders = {
            'RGATv2': RGATv2,
            'RGATv3': RGATv3,
            'RGATv4': RGATv4,
            'RGATv3Agent': RGATv3Agent,
        }
        if gnn_type in attention_encoders:
            self.gnn_encoder = attention_encoders[gnn_type](
                hidden_channels=self.state_dim,
                out_channels=self.state_dim,
                num_heads=4,
                rel_names=self.rel_names,
                dropout=0.0,
                n_layers=2,
                activation=nn.ELU(),
                residual=False
            )
        elif gnn_type == 'RSAGEv4':
            self.gnn_encoder = RSAGEv4(
                hidden_channels=self.state_dim,
                out_channels=self.state_dim,
                rel_names=self.rel_names,
                dropout=0.0,
                n_layers=2,
                activation=nn.ELU()
            )
        else:
            # Fail here rather than with an AttributeError in ``forward``.
            raise ValueError(f'Unsupported gnn_type: {gnn_type!r}')

        if self.gnn_aggregation == 'cat_axis_1':
            self.lstm_input_size = self.state_dim * len(self.feats_dims) + self.action_dim
            self.mlp_input_size = self.state_dim * len(self.feats_dims) + self.context_dim * 2
        elif self.gnn_aggregation == 'sum':
            self.lstm_input_size = self.state_dim + self.action_dim
            self.mlp_input_size = self.state_dim + self.context_dim * 2
        else:
            raise ValueError('Fix this')

        self.context_enc = nn.LSTM(self.lstm_input_size, self.context_dim, 2,
                                   batch_first=True, bidirectional=True)
        self.policy = MlpModel(input_size=self.mlp_input_size, hidden_sizes=[256, 128, 256],
                               output_size=self.action_dim, dropout=self.dropout)
        self.past_samples = []

    def forward(self, batch):
        """Predict the query action for every sample in ``batch``.

        Args:
            batch: ``[dem_frames, dem_actions, dem_lens, query_frame,
                target_action]`` as produced by ``collate_function_seq``.
                ``dem_actions`` is unpacked as ``[batch, trials, steps, _]``
                and ``dem_lens`` holds, per sample, the number of frames of
                each trial.

        Returns:
            ``(target_action, test_actions_pred)``; predictions are squashed
            with ``tanh``.
        """
        dem_frames, dem_actions, dem_lens, query_frame, target_action = batch
        dem_actions = dem_actions.float()
        target_action = target_action.float()

        # one row per demonstration frame, for all samples back to back
        dem_states = self.gnn_encoder(dem_frames)
        b, l, s, _ = dem_actions.size()

        # Pad once over the flattened trial lengths so the read offset keeps
        # advancing across samples.  Padding each sample separately restarts
        # at row 0 and would reuse the first sample's frames for every sample.
        flat_lens = [int(t) for trial_lens in dem_lens for t in trial_lens]
        dem_states_new = self._sequence_to_padding(dem_states, flat_lens, self.max_len)
        dem_states_new = dem_states_new.to(self.device)  # [b*l, max_len, state features]

        # concatenate states and actions to get the expert trajectories
        dem_actions = dem_actions.view(b * l, s, -1)
        dem_traj = torch.cat([dem_states_new, dem_actions], dim=2)

        # embed the expert trajectories; lengths must live on the CPU
        dem_lens = torch.tensor(flat_lens, dtype=torch.int64)
        x_lstm = nn.utils.rnn.pack_padded_sequence(dem_traj, dem_lens, batch_first=True, enforce_sorted=False)
        x_lstm, _ = self.context_enc(x_lstm)
        x_lstm, _ = nn.utils.rnn.pad_packed_sequence(x_lstm, batch_first=True)
        # NOTE(review): index -1 is the last step of the *longest* trajectory;
        # shorter trajectories are zero-padded there.  Left unchanged to keep
        # the model definition stable -- confirm this is intended.
        x_out = x_lstm[:, -1, :]
        x_out = x_out.view(b, l, -1)
        context = torch.mean(x_out, dim=1)  # average over the trials of a sample

        # concat context embedding to the state embedding of the query frame
        test_states = self.gnn_encoder(query_frame)
        test_context_states = torch.cat([context, test_states], dim=1)
        test_actions_pred = torch.tanh(self.policy(test_context_states))
        return target_action, test_actions_pred

    def _sequence_to_padding(self, x, lengths, max_length):
        """Split ``x`` into consecutive chunks and zero-pad them to ``max_length``.

        Args:
            x: tensor whose first dimension holds all sequences back to back.
            lengths: length of each sequence, in order.
            max_length: padded length; every entry of ``lengths`` must not
                exceed it.

        Returns:
            Tensor of shape ``[len(lengths), max_length, *x.shape[1:]]`` with
            the dtype and device of ``x``.
        """
        ret_tensor = x.new_zeros((len(lengths), max_length) + tuple(x.shape[1:]))
        cum_len = 0
        for i, l in enumerate(lengths):
            ret_tensor[i, :l] = x[cum_len: cum_len + l]
            cum_len += l
        return ret_tensor

    def training_step(self, batch, batch_idx):
        """Return the MSE between target and predicted actions and log it."""
        test_actions, test_actions_pred = self.forward(batch)
        loss = F.mse_loss(test_actions, test_actions_pred)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Log the validation MSE (one validation loader per task type)."""
        test_actions, test_actions_pred = self.forward(batch)
        loss = F.mse_loss(test_actions, test_actions_pred)
        self.log('val_loss', loss, on_epoch=True, logger=True)

    def configure_optimizers(self):
        """Use Adam with the configured learning rate."""
        optim = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optim
        #optim = torch.optim.AdamW(self.parameters(), lr=self.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01)
        #scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optim, gamma=0.8)
        #return [optim], [scheduler]

    def train_dataloader(self):
        """Build the shuffled training loader over all configured task types."""
        train_dataset = ToMnetDGLDataset(path=self.hparams.data_path,
                                         types=self.hparams.types,
                                         mode='train')
        train_loader = DataLoader(dataset=train_dataset,
                                  batch_size=self.hparams.batch_size,
                                  collate_fn=collate_function_seq,
                                  num_workers=self.hparams.num_workers,
                                  #pin_memory=True,
                                  shuffle=True)
        return train_loader

    def val_dataloader(self):
        """Build one validation loader per task type."""
        val_loaders = []
        for t in self.hparams.types:
            val_dataset = ToMnetDGLDataset(path=self.hparams.data_path,
                                           types=[t],
                                           mode='val')
            val_loaders.append(DataLoader(dataset=val_dataset,
                                          batch_size=self.hparams.batch_size,
                                          collate_fn=collate_function_seq,
                                          num_workers=self.hparams.num_workers,
                                          #pin_memory=True,
                                          shuffle=False))
        return val_loaders

    def configure_callbacks(self):
        """Keep a checkpoint of every epoch (``save_top_k=-1``)."""
        checkpoint = ModelCheckpoint(
            dirpath=None,  # automatically set
            #filename=self.params['bc_model']+'-'+self.params['gnn_type']+'-'+self.gnn_params['feats_aggr']+'-{epoch:02d}',
            save_top_k=-1,
            period=1
        )
        return [checkpoint]
# ---------------------------------------------------------------------------------------------------------------------------------
class GraphBC_T(pl.LightningModule):
    """
    BC model with a GNN state encoder, a transformer context encoder and an
    MLP policy.

    Every demonstration frame is encoded by the GNN, concatenated with its
    action and projected to ``d_model``.  Each trial becomes a zero-padded
    sequence with a learned CLS token appended; the transformer output at the
    CLS position, averaged over the trials of a sample, is the context that
    the policy combines with the encoded query frame.
    """

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a parser that extends ``parent_parser`` with the model options."""
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--lr', type=float, default=1e-4)
        parser.add_argument('--action_dim', type=int, default=2)
        parser.add_argument('--context_dim', type=int, default=32)  # lstm hidden size
        parser.add_argument('--beta', type=float, default=0.01)
        parser.add_argument('--dropout', type=float, default=0.2)
        parser.add_argument('--process_data', type=int, default=0)
        parser.add_argument('--max_len', type=int, default=30)
        # arguments for gnn
        parser.add_argument('--state_dim', type=int, default=128)  # gnn out_feats
        # ``type=list`` would split the command-line string into single
        # characters; parse a proper list of ints instead.
        parser.add_argument('--feats_dims', type=int, nargs='+', default=[9, 2, 3, 18])
        parser.add_argument('--aggregation', type=str, default='cat_axis_1')
        parser.add_argument('--gnn_type', type=str, default='RGATv3')
        # arguments for mpl
        #parser.add_argument('--mpl_hid_feats', type=list, default=[256, 64, 16])
        # arguments for transformer
        parser.add_argument('--d_model', type=int, default=128)
        parser.add_argument('--nhead', type=int, default=4)
        parser.add_argument('--dim_feedforward', type=int, default=512)
        parser.add_argument('--transformer_dropout', type=float, default=0.3)
        parser.add_argument('--transformer_activation', type=str, default='gelu')
        parser.add_argument('--num_encoder_layers', type=int, default=6)
        parser.add_argument('--transformer_norm_input', type=int, default=0)
        return parser

    def __init__(self, hparams):
        """Build the GNN encoder, the transformer context encoder and the policy.

        Args:
            hparams: namespace carrying the options declared in
                ``add_model_specific_args``; the data loaders additionally
                read ``data_path``, ``types``, ``batch_size`` and
                ``num_workers`` from it.

        Raises:
            ValueError: for an unsupported ``gnn_type`` or ``aggregation``.
        """
        super().__init__()
        self.hparams = hparams
        self.lr = self.hparams.lr
        self.state_dim = self.hparams.state_dim
        self.action_dim = self.hparams.action_dim
        self.context_dim = self.hparams.context_dim
        self.beta = self.hparams.beta
        self.dropout = self.hparams.dropout
        self.max_len = self.hparams.max_len
        self.feats_dims = self.hparams.feats_dims  # type, position, color, shape
        self.rel_names = [
            'is_aligned', 'is_back', 'is_close', 'is_down_adj', 'is_down_left_adj',
            'is_down_right_adj', 'is_front', 'is_left', 'is_left_adj', 'is_right',
            'is_right_adj', 'is_top_adj', 'is_top_left_adj', 'is_top_right_adj'
        ]
        self.gnn_aggregation = self.hparams.aggregation

        gnn_type = self.hparams.gnn_type
        # attention encoders, mapped to the ``residual`` flag each one uses
        attention_encoders = {
            'RGATv3': (RGATv3, False),
            'RGATv4': (RGATv4, False),
            'RGATv4Norm': (RGATv4Norm, True),
        }
        plain_encoders = {'RSAGEv4': RSAGEv4, 'RAGNNv4': RAGNNv4}
        if gnn_type in attention_encoders:
            encoder_cls, use_residual = attention_encoders[gnn_type]
            self.gnn_encoder = encoder_cls(
                hidden_channels=self.state_dim,
                out_channels=self.state_dim,
                num_heads=4,
                rel_names=self.rel_names,
                dropout=0.0,
                n_layers=2,
                activation=nn.ELU(),
                residual=use_residual
            )
        elif gnn_type in plain_encoders:
            self.gnn_encoder = plain_encoders[gnn_type](
                hidden_channels=self.state_dim,
                out_channels=self.state_dim,
                rel_names=self.rel_names,
                dropout=0.0,
                n_layers=2,
                activation=nn.ELU()
            )
        else:
            # Fail here rather than with an AttributeError in ``forward``.
            raise ValueError(f'Unsupported gnn_type: {gnn_type!r}')

        self.d_model = self.hparams.d_model
        self.nhead = self.hparams.nhead
        self.dim_feedforward = self.hparams.dim_feedforward
        self.transformer_dropout = self.hparams.transformer_dropout
        self.transformer_activation = self.hparams.transformer_activation
        self.num_encoder_layers = self.hparams.num_encoder_layers
        self.transformer_norm_input = self.hparams.transformer_norm_input
        self.context_enc = TransformerEncoder(
            self.d_model,
            self.nhead,
            self.dim_feedforward,
            self.transformer_dropout,
            self.transformer_activation,
            self.num_encoder_layers,
            self.max_len,
            self.transformer_norm_input
        )
        if self.gnn_aggregation == 'cat_axis_1':
            self.gnn2transformer = nn.Linear(self.state_dim * len(self.feats_dims) + self.action_dim, self.d_model)
            self.mlp_input_size = len(self.feats_dims) * self.state_dim + self.d_model
        elif self.gnn_aggregation == 'sum':
            self.gnn2transformer = nn.Linear(self.state_dim + self.action_dim, self.d_model)
            self.mlp_input_size = self.state_dim + self.d_model
        else:
            raise ValueError('Only sum and cat1 aggregations are available.')
        self.policy = MlpModel(input_size=self.mlp_input_size, hidden_sizes=[256, 128, 256],
                               output_size=self.action_dim, dropout=self.dropout)
        # positional embedding table, one extra slot for the CLS token
        self.embedding = nn.Embedding(self.max_len + 1, self.d_model)
        self.emb_layer_norm = nn.LayerNorm(self.d_model)
        self.emb_dropout = nn.Dropout(p=self.transformer_dropout)
        # Learned CLS token.  It must be registered on the module: building a
        # new random Parameter inside ``forward`` is never trained and differs
        # on every call.  Created last so the seeded initialisation of the
        # modules above is unchanged.
        self.cls_embedding = nn.Parameter(torch.randn(1, 1, self.d_model))

    def on_load_checkpoint(self, checkpoint):
        """Keep checkpoints saved before ``cls_embedding`` was registered loadable.

        Such checkpoints have no entry for the CLS token; the freshly
        initialised value is inserted so a strict state-dict load succeeds.
        """
        state_dict = checkpoint.get('state_dict')
        if state_dict is not None and 'cls_embedding' not in state_dict:
            state_dict['cls_embedding'] = self.cls_embedding.detach().clone()

    def forward(self, batch):
        """Predict the query action for every sample in ``batch``.

        Args:
            batch: ``[dem_frames, dem_actions, dem_lens, query_frame,
                target_action]`` as produced by ``collate_function_seq``.
                ``dem_actions`` is unpacked as ``[batch, trials, steps, _]``
                and ``dem_lens`` holds, per sample, the number of frames of
                each trial.

        Returns:
            ``(target_action, test_actions_pred)``; predictions are squashed
            with ``tanh``.
        """
        dem_frames, dem_actions, dem_lens, query_frame, target_action = batch
        dem_actions = dem_actions.float()
        target_action = target_action.float()

        # one row per demonstration frame, trial after trial
        dem_states = self.gnn_encoder(dem_frames)
        b, l, s, _ = dem_actions.size()
        dem_lens = torch.tensor([t for dl in dem_lens for t in dl]).to(torch.int64).cpu()
        dem_lens = dem_lens.view(-1)

        # Keep the valid action rows trial after trial so they line up with
        # ``dem_states``.  (The data of a PackedSequence is ordered time step
        # by time step over length-sorted sequences, which would pair states
        # with the wrong actions.)
        valid_steps = torch.arange(s).unsqueeze(0) < dem_lens.unsqueeze(1)  # [b*l, s]
        dem_actions_flat = dem_actions.view(b * l, s, -1)[valid_steps.to(dem_actions.device)]
        dem_traj = torch.cat([dem_states, dem_actions_flat], dim=1)
        h_node = self.gnn2transformer(dem_traj)

        # scatter the frames back into zero-padded per-trial sequences
        hidden_dim = h_node.size()[1]
        padded_trajectory = h_node.new_zeros(b * l, self.max_len, hidden_dim)
        j = 0
        for idx, i in enumerate(dem_lens.tolist()):
            padded_trajectory[idx, :i] = h_node[j:j + i]
            j += i
        mask = self.make_mask(padded_trajectory).to(self.device)
        transformer_input = padded_trajectory.transpose(0, 1)  # [max_len, b*l, d_model]

        # append the CLS token; it is never masked out
        cls_embedding = self.cls_embedding.expand(1, b * l, -1)
        transformer_input = torch.cat([transformer_input, cls_embedding], dim=0)
        zeros = mask.new_zeros(mask.size(0), 1)
        mask = torch.cat([mask, zeros], dim=1)

        # learned positional embeddings, one position per step plus CLS
        indices = torch.arange(self.max_len + 1, dtype=torch.long, device=self.device)
        positional_embeddings = self.embedding(indices).unsqueeze(1)
        pe_input = positional_embeddings + transformer_input
        # Layernorm and dropout
        transformer_in = self.emb_dropout(self.emb_layer_norm(pe_input))

        # transformer encoding; the CLS output summarises each trial
        out, _ = self.context_enc(transformer_in, mask)
        cls = out[-1]
        cls = cls.view(b, l, -1)
        context = torch.mean(cls, dim=1)  # average over the trials of a sample

        # concat context embedding to the state embedding of the query frame
        test_states = self.gnn_encoder(query_frame)
        test_context_states = torch.cat([context, test_states], dim=1)
        x = self.policy(test_context_states)
        test_actions_pred = torch.tanh(x)
        return target_action, test_actions_pred

    def make_mask(self, feature):
        """Return a boolean mask that is True where the feature vector is all zero.

        Used as key-padding mask: zero-padded steps of a trajectory are True.
        """
        return (torch.sum(
            torch.abs(feature),
            dim=-1
        ) == 0)#.unsqueeze(1).unsqueeze(2)

    def training_step(self, batch, batch_idx):
        """Return the MSE between target and predicted actions and log it."""
        test_actions, test_actions_pred = self.forward(batch)
        loss = F.mse_loss(test_actions, test_actions_pred)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Log the validation MSE (one validation loader per task type)."""
        test_actions, test_actions_pred = self.forward(batch)
        loss = F.mse_loss(test_actions, test_actions_pred)
        self.log('val_loss', loss, on_epoch=True, logger=True)

    def configure_optimizers(self):
        """Use Adam with the configured learning rate."""
        optim = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optim
        #optim = torch.optim.AdamW(self.parameters(), lr=self.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01)
        #scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optim, gamma=0.96)
        #return [optim], [scheduler]

    def train_dataloader(self):
        """Build the shuffled training loader over all configured task types."""
        train_dataset = ToMnetDGLDataset(path=self.hparams.data_path,
                                         types=self.hparams.types,
                                         mode='train')
        train_loader = DataLoader(dataset=train_dataset,
                                  batch_size=self.hparams.batch_size,
                                  collate_fn=collate_function_seq,
                                  num_workers=self.hparams.num_workers,
                                  #pin_memory=True,
                                  shuffle=True)
        return train_loader

    def val_dataloader(self):
        """Build one validation loader per task type."""
        val_loaders = []
        for t in self.hparams.types:
            val_dataset = ToMnetDGLDataset(path=self.hparams.data_path,
                                           types=[t],
                                           mode='val')
            val_loaders.append(DataLoader(dataset=val_dataset,
                                          batch_size=self.hparams.batch_size,
                                          collate_fn=collate_function_seq,
                                          num_workers=self.hparams.num_workers,
                                          #pin_memory=True,
                                          shuffle=False))
        return val_loaders

    def configure_callbacks(self):
        """Keep a checkpoint of every epoch (``save_top_k=-1``)."""
        checkpoint = ModelCheckpoint(
            dirpath=None,  # automatically set
            #filename=self.params['bc_model']+'-'+self.params['gnn_type']+'-'+self.gnn_params['feats_aggr']+'-{epoch:02d}',
            save_top_k=-1,
            period=1
        )
        return [checkpoint]

46
tom/norm.py Normal file
View file

@ -0,0 +1,46 @@
import torch
import torch.nn as nn
class Norm(nn.Module):
    """Optional normalisation layer for node features of a batched graph.

    Args:
        norm_type: ``'bn'`` applies ``BatchNorm1d`` over all nodes; ``'gn'``
            normalises the nodes of every graph in the batch separately with
            a learnable scale on the subtracted mean; any other value turns
            the layer into the identity.
        hidden_dim: feature width of the tensors to normalise.
        print_info: stored on the module, not used by this class.
    """

    def __init__(self, norm_type, hidden_dim=64, print_info=None):
        super().__init__()
        # assert norm_type in ['bn', 'ln', 'gn', None]
        self.norm = None
        self.print_info = print_info
        if norm_type == 'bn':
            self.norm = nn.BatchNorm1d(hidden_dim)
        elif norm_type == 'gn':
            # the string marks the per-graph path taken in ``forward``
            self.norm = norm_type
            self.weight = nn.Parameter(torch.ones(hidden_dim))
            self.bias = nn.Parameter(torch.zeros(hidden_dim))
            self.mean_scale = nn.Parameter(torch.ones(hidden_dim))

    def forward(self, graph, tensor, print_=False):
        """Normalise ``tensor`` (one row per ``'obj'`` node of ``graph``).

        Args:
            graph: batched graph; only read on the ``'gn'`` path, through
                ``graph.batch_num_nodes('obj')``.
            tensor: node features with the node dimension first.
            print_: unused.

        Returns:
            Tensor with the shape of ``tensor``.
        """
        if self.norm is None:
            return tensor
        if not isinstance(self.norm, str):
            return self.norm(tensor)

        batch_list = graph.batch_num_nodes('obj')
        batch_size = len(batch_list)
        batch_list = batch_list.long().to(tensor.device)
        trailing = (1,) * (tensor.dim() - 1)

        # graph id of every node, broadcast to the shape of ``tensor``
        batch_index = torch.arange(batch_size, device=tensor.device).repeat_interleave(batch_list)
        batch_index = batch_index.view((-1,) + trailing).expand_as(tensor)
        # node count of every graph, broadcastable against the accumulators
        counts = batch_list.view((-1,) + trailing).to(tensor.dtype)

        # Accumulators share dtype and device with ``tensor``; scatter_add_
        # requires matching dtypes.
        mean = tensor.new_zeros(batch_size, *tensor.shape[1:])
        mean = mean.scatter_add_(0, batch_index, tensor) / counts
        mean = mean.repeat_interleave(batch_list, dim=0)

        sub = tensor - mean * self.mean_scale

        std = tensor.new_zeros(batch_size, *tensor.shape[1:])
        std = std.scatter_add_(0, batch_index, sub.pow(2)) / counts
        std = (std + 1e-6).sqrt()
        std = std.repeat_interleave(batch_list, dim=0)
        return self.weight * sub / std + self.bias

89
tom/transformer.py Normal file
View file

@ -0,0 +1,89 @@
import torch
import torch.nn as nn
import math
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding for sequence-first inputs.

    Even feature indices hold ``sin(pos * freq)`` and odd indices hold
    ``cos(pos * freq)``.  The table is stored in the ``pe`` buffer with shape
    ``[max_len, 1, d_model]``.

    Args:
        d_model: embedding width (even or odd).
        dropout: dropout probability applied after adding the encoding.
        max_len: number of positions in the table; inputs must not be longer.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # An odd d_model has one cosine column fewer than sine columns, so
        # the angles are trimmed to the number of odd indices.
        pe[:, 0, 1::2] = torch.cos(angles[:, :d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]

        Returns:
            ``x`` plus the encoding of its first ``seq_len`` positions, with
            dropout applied.
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)
class TransformerEncoder(nn.Module):
    """Stack of ``nn.TransformerEncoderLayer`` with a final ``LayerNorm`` and
    an optional ``LayerNorm`` on the input.

    Args:
        d_model: embedding width.
        nhead: number of attention heads.
        dim_feedforward: width of the feed-forward sub-layer.
        transformer_dropout: dropout probability inside the encoder layers.
        transformer_activation: activation of the feed-forward sub-layer.
        num_encoder_layers: number of encoder layers.
        max_input_len: stored as ``max_input_len``; not used by this class.
        transformer_norm_input: if truthy, the input is layer-normalised
            before it enters the encoder.
    """

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward,
        transformer_dropout,
        transformer_activation,
        num_encoder_layers,
        max_input_len,
        transformer_norm_input
    ):
        super().__init__()
        self.d_model = d_model
        self.num_layer = num_encoder_layers
        self.max_input_len = max_input_len

        # one template layer; nn.TransformerEncoder clones it per layer
        layer_template = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=transformer_dropout,
            activation=transformer_activation,
        )
        final_norm = nn.LayerNorm(d_model)
        self.transformer = nn.TransformerEncoder(
            layer_template, num_encoder_layers, final_norm
        )

        if transformer_norm_input:
            self.norm_input = nn.LayerNorm(d_model)
        else:
            self.norm_input = None

    def forward(self, padded_h_node, src_padding_mask):
        """Encode a padded, sequence-first batch.

        Args:
            padded_h_node: input of shape (S, B, h_d), e.g. 63 x 257 x 128.
            src_padding_mask: key-padding mask of shape (B, S), e.g. 257 x 63.

        Returns:
            Tuple of the encoder output, shape (S, B, h_d), and the
            unchanged ``src_padding_mask``.
        """
        encoder_input = padded_h_node
        if self.norm_input is not None:
            encoder_input = self.norm_input(encoder_input)
        encoded = self.transformer(
            encoder_input, src_key_padding_mask=src_padding_mask
        )
        return encoded, src_padding_mask
if __name__ == '__main__':
    # Smoke check: build a small encoder and print its optional input norm
    # (``None`` here, because ``transformer_norm_input`` is 0).
    demo_settings = dict(
        d_model=12,
        nhead=4,
        dim_feedforward=32,
        transformer_dropout=0.0,
        transformer_activation='gelu',
        num_encoder_layers=4,
        max_input_len=34,
        transformer_norm_input=0,
    )
    model = TransformerEncoder(**demo_settings)
    print(model.norm_input)