up
commit de0bea7508 (parent a333481e05)
18 changed files with 3150 additions and 2 deletions

tom/__init__.py    (new file, 0 lines)

tom/dataset.py    (new file, 310 lines)
@@ -0,0 +1,310 @@
import pickle as pkl
import os
import torch
import torch.utils.data
import torch.nn.functional as F
import dgl
import random
from dgl.data import DGLDataset

def collate_function_seq(batch):
    #dem_frames = torch.stack([item[0] for item in batch])
    dem_frames = dgl.batch([item[0] for item in batch])
    dem_actions = torch.stack([item[1] for item in batch])
    dem_lens = [item[2] for item in batch]
    #query_frames = torch.stack([item[3] for item in batch])
    query_frames = dgl.batch([item[3] for item in batch])
    target_actions = torch.stack([item[4] for item in batch])
    return [dem_frames, dem_actions, dem_lens, query_frames, target_actions]

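These collate functions are the glue between the pickled graph episodes and PyTorch's loaders; a minimal usage sketch (the batch size is hypothetical, and `dataset` stands for a ToMnetDGLDataset as defined below):

from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=16, shuffle=True,
                    collate_fn=collate_function_seq)
dem_frames, dem_actions, dem_lens, query_frames, target_actions = next(iter(loader))
# dem_frames and query_frames come back as single batched DGLGraphs;
# the action tensors are stacked along dim 0, one entry per sample
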
def collate_function_seq_test(batch):
    # NOTE: the [0] indexing assumes the test loader runs with batch_size=1
    dem_expected_states = dgl.batch([item[0] for item in batch][0])
    dem_expected_actions = torch.stack([item[1] for item in batch][0]).unsqueeze(dim=0)
    dem_expected_lens = [item[2] for item in batch]
    #print(dem_expected_actions.size())
    dem_unexpected_states = dgl.batch([item[3] for item in batch][0])
    dem_unexpected_actions = torch.stack([item[4] for item in batch][0]).unsqueeze(dim=0)
    dem_unexpected_lens = [item[5] for item in batch]
    query_expected_frames = dgl.batch([item[6] for item in batch])
    target_expected_actions = torch.stack([item[7] for item in batch])
    #print(target_expected_actions.size())
    query_unexpected_frames = dgl.batch([item[8] for item in batch])
    target_unexpected_actions = torch.stack([item[9] for item in batch])
    return [
        dem_expected_states, dem_expected_actions, dem_expected_lens,
        dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens,
        query_expected_frames, target_expected_actions,
        query_unexpected_frames, target_unexpected_actions
    ]

def collate_function_mental(batch):
    dem_frames = dgl.batch([item[0] for item in batch])
    dem_actions = torch.stack([item[1] for item in batch])
    dem_lens = [item[2] for item in batch]
    past_test_frames = dgl.batch([item[3] for item in batch])
    past_test_actions = torch.stack([item[4] for item in batch])
    past_test_len = [item[5] for item in batch]
    query_frames = dgl.batch([item[6] for item in batch])
    target_actions = torch.stack([item[7] for item in batch])
    return [dem_frames, dem_actions, dem_lens, past_test_frames, past_test_actions,
            past_test_len, query_frames, target_actions]


class ToMnetDGLDataset(DGLDataset):
    """
    Training dataset class.
    """
    def __init__(self, path, types=None, mode="train"):
        self.path = path
        self.types = types
        self.mode = mode
        print('Mode:', self.mode)

        if self.mode == 'train':
            if len(self.types) == 4:
                self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/'
                #self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_global/'
                #self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_local/'
            elif len(self.types) == 3:
                self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + \
                    self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper() + '/'
                print(self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper())
            elif len(self.types) == 2:
                self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + \
                    self.types[0][0].upper() + self.types[1][0].upper() + '/'
                print(self.types[0][0].upper() + self.types[1][0].upper())
            elif len(self.types) == 1:
                self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + '/'
            else:
                raise ValueError('Number of types must be between 1 and 4.')
        elif self.mode == 'val':
            assert len(self.types) == 1
            self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/' + self.types[0] + '/'
            #self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_global/' + self.types[0] + '/'
            #self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_local/' + self.types[0] + '/'
        else:
            raise ValueError('mode must be "train" or "val".')

    def get_test(self, states, actions):
        # states is a batched graph: unbatch it, take the length, pick one
        # frame sub-graph at random and select the corresponding action
        frame_graphs = dgl.unbatch(states)
        trial_len = len(frame_graphs)
        query_idx = random.randint(0, trial_len - 1)
        query_graph = frame_graphs[query_idx]
        target_action = actions[query_idx]
        return query_graph, target_action

    def __getitem__(self, idx):
        with open(self.path + str(idx) + '.pkl', 'rb') as f:
            states, actions, lens, _ = pkl.load(f)
        # shuffle the trials
        ziplist = list(zip(states, actions, lens))
        random.shuffle(ziplist)
        states, actions, lens = zip(*ziplist)
        # convert tuples to lists
        states, actions, lens = [*states], [*actions], [*lens]
        # use the last trial as the test trial and pick a random query frame from it
        test_s, test_a = self.get_test(states[-1], actions[-1])
        dem_s = states[:-1]
        dem_a = actions[:-1]
        dem_lens = lens[:-1]
        dem_s = dgl.batch(dem_s)
        dem_a = torch.stack(dem_a)
        return dem_s, dem_a, dem_lens, test_s, test_a

    def __len__(self):
        return len(os.listdir(self.path))

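For orientation, a shape sketch of one training item (the path and type here are placeholders borrowed from the __main__ block below; the 8-trial convention is the one asserted in TestToMnetDGLDataset):

ds = ToMnetDGLDataset(path='/datasets/external/bib_train/graphs/all_tasks/',
                      types=['preference'], mode='train')
dem_s, dem_a, dem_lens, test_s, test_a = ds[0]
# dem_s: one batched DGLGraph with every frame of the 8 demonstration trials
# dem_a: per-trial padded action tensors stacked along dim 0
# dem_lens: the true (unpadded) frame count of each demonstration trial
# test_s, test_a: a random query frame from the held-out trial and its action
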
class TestToMnetDGLDataset(DGLDataset):
    """
    Testing dataset class.
    """
    def __init__(self, path, task_type=None, mode="test"):
        self.path = path
        self.type = task_type
        self.mode = mode
        print('Mode:', self.mode)

        if self.mode == 'test':
            self.path = self.path + '_dgl_hetero_nobound_4feats/' + self.type + '/'
            #self.path = self.path + '_dgl_hetero_nobound_4feats_global/' + self.type + '/'
            #self.path = self.path + '_dgl_hetero_nobound_4feats_local/' + self.type + '/'
        else:
            raise ValueError('mode must be "test".')

    def __getitem__(self, idx):
        with open(self.path + str(idx) + '.pkl', 'rb') as f:
            (dem_expected_states, dem_expected_actions, dem_expected_lens, _,
             dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, _,
             query_expected_frames, target_expected_actions,
             query_unexpected_frames, target_unexpected_actions) = pkl.load(f)
        assert len(dem_expected_states) == 8
        assert len(dem_expected_actions) == 8
        assert len(dem_expected_lens) == 8
        assert len(dem_unexpected_states) == 8
        assert len(dem_unexpected_actions) == 8
        assert len(dem_unexpected_lens) == 8
        assert len(dgl.unbatch(query_expected_frames)) == target_expected_actions.size()[0]
        assert len(dgl.unbatch(query_unexpected_frames)) == target_unexpected_actions.size()[0]
        # the two ignored fields (n_nodes) are not needed here
        return (dem_expected_states, dem_expected_actions, dem_expected_lens,
                dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens,
                query_expected_frames, target_expected_actions,
                query_unexpected_frames, target_unexpected_actions)

    def __len__(self):
        return len(os.listdir(self.path))

class ToMnetDGLDatasetUndersample(DGLDataset):
    """
    Training dataset class for the behavior-cloning MLP model.
    Undersampled: the first 3175 episodes are skipped (see __getitem__ and __len__).
    """
    def __init__(self, path, types=None, mode="train"):
        self.path = path
        self.types = types
        self.mode = mode
        print('Mode:', self.mode)

        if self.mode == 'train':
            if len(self.types) == 4:
                self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/'
            elif len(self.types) == 3:
                self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + \
                    self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper() + '/'
                print(self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper())
            elif len(self.types) == 2:
                self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + \
                    self.types[0][0].upper() + self.types[1][0].upper() + '/'
                print(self.types[0][0].upper() + self.types[1][0].upper())
            elif len(self.types) == 1:
                self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + '/'
            else:
                raise ValueError('Number of types must be between 1 and 4.')
        elif self.mode == 'val':
            assert len(self.types) == 1
            self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/' + self.types[0] + '/'
        else:
            raise ValueError('mode must be "train" or "val".')
        print('Undersampled dataset!')

    def get_test(self, states, actions):
        # states is a batched graph: unbatch it, take the length, pick one
        # frame sub-graph at random and select the corresponding action
        frame_graphs = dgl.unbatch(states)
        trial_len = len(frame_graphs)
        query_idx = random.randint(0, trial_len - 1)
        query_graph = frame_graphs[query_idx]
        target_action = actions[query_idx]
        return query_graph, target_action

    def __getitem__(self, idx):
        idx = idx + 3175  # skip the first 3175 episodes
        with open(self.path + str(idx) + '.pkl', 'rb') as f:
            states, actions, lens, _ = pkl.load(f)
        # shuffle the trials
        ziplist = list(zip(states, actions, lens))
        random.shuffle(ziplist)
        states, actions, lens = zip(*ziplist)
        # convert tuples to lists
        states, actions, lens = [*states], [*actions], [*lens]
        # use the last trial as the test trial and pick a random query frame from it
        test_s, test_a = self.get_test(states[-1], actions[-1])
        dem_s = states[:-1]
        dem_a = actions[:-1]
        dem_lens = lens[:-1]
        dem_s = dgl.batch(dem_s)
        dem_a = torch.stack(dem_a)
        return dem_s, dem_a, dem_lens, test_s, test_a

    def __len__(self):
        return len(os.listdir(self.path)) - 3175

class ToMnetDGLDatasetMental(DGLDataset):
    """
    Training dataset class.
    """
    def __init__(self, path, types=None, mode="train"):
        self.path = path
        self.types = types
        self.mode = mode
        print('Mode:', self.mode)

        if self.mode == 'train':
            if len(self.types) == 4:
                self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/'
            elif len(self.types) == 3:
                self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + \
                    self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper() + '/'
                print(self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper())
            elif len(self.types) == 2:
                self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + \
                    self.types[0][0].upper() + self.types[1][0].upper() + '/'
                print(self.types[0][0].upper() + self.types[1][0].upper())
            elif len(self.types) == 1:
                self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + '/'
            else:
                raise ValueError('Number of types must be between 1 and 4.')
        elif self.mode == 'val':
            assert len(self.types) == 1
            self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/' + self.types[0] + '/'
        else:
            raise ValueError('mode must be "train" or "val".')

    def get_test(self, states, actions):
        """
        return: past_test_graphs, past_test_actions, query_idx, test_graph, test_action
        """
        frame_graphs = dgl.unbatch(states)
        trial_len = len(frame_graphs)
        query_idx = random.randint(0, trial_len - 1)
        test_graph = frame_graphs[query_idx]
        test_action = actions[query_idx]
        if query_idx > 0:
            #past_test_actions = F.pad(past_test_actions, (0,0,0,41-query_idx), 'constant', 0)
            if query_idx == 1:
                past_test_graphs = frame_graphs[0]
                past_test_actions = actions[:query_idx]
                past_test_actions = F.pad(past_test_actions, (0, 0, 0, 41 - query_idx), 'constant', 0)
                return past_test_graphs, past_test_actions, query_idx, test_graph, test_action
            else:
                past_test_graphs = frame_graphs[:query_idx]
                past_test_actions = actions[:query_idx]
                past_test_graphs = dgl.batch(past_test_graphs)
                past_test_actions = F.pad(past_test_actions, (0, 0, 0, 41 - query_idx), 'constant', 0)
                return past_test_graphs, past_test_actions, query_idx, test_graph, test_action
        else:
            test_action_ = F.pad(test_action.unsqueeze(0), (0, 0, 0, 41 - 1), 'constant', 0)
            # NOTE: since there are no past frames, return the test frame and
            # action with query_idx=1; not sure what would be a better alternative
            return test_graph, test_action_, 1, test_graph, test_action

    def __getitem__(self, idx):
        with open(self.path + str(idx) + '.pkl', 'rb') as f:
            states, actions, lens, _ = pkl.load(f)
        ziplist = list(zip(states, actions, lens))
        random.shuffle(ziplist)
        states, actions, lens = zip(*ziplist)
        states, actions, lens = [*states], [*actions], [*lens]
        past_test_s, past_test_a, past_test_len, test_s, test_a = self.get_test(states[-1], actions[-1])
        dem_s = states[:-1]
        dem_a = actions[:-1]
        dem_lens = lens[:-1]
        dem_s = dgl.batch(dem_s)
        dem_a = torch.stack(dem_a)
        return dem_s, dem_a, dem_lens, past_test_s, past_test_a, past_test_len, test_s, test_a

    def __len__(self):
        return len(os.listdir(self.path))

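The 41-step padding above keeps every past-trajectory tensor the same length regardless of where the query frame falls; a small sketch of that F.pad call (41 is the fixed maximum used above; the pad tuple pads the last dim by (0, 0) and the time dim by (0, 41 - n)):

a = torch.randn(5, 2)                         # 5 past actions, 2-d each
a = F.pad(a, (0, 0, 0, 41 - 5), 'constant', 0)
a.shape                                       # -> torch.Size([41, 2])
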
# --------------------------------------------------------------------------------------------------------

if __name__ == "__main__":

    types = [
        'preference', 'multi_agent', 'inaccessible_goal',
        'efficiency_irrational', 'efficiency_time', 'efficiency_path',
        'instrumental_no_barrier', 'instrumental_blocking_barrier', 'instrumental_inconsequential_barrier'
    ]

    mental_dataset = ToMnetDGLDatasetMental(
        path='/datasets/external/bib_train/graphs/all_tasks/',
        types=['instrumental_action'],
        mode='train'
    )
    dem_frames, dem_actions, dem_lens, past_test_frames, past_test_actions, \
        past_test_len, test_frame, test_action = mental_dataset.__getitem__(99)
    breakpoint()
tom/gnn.py    (new file, 877 lines)
@@ -0,0 +1,877 @@
import dgl.nn.pytorch as dglnn
import torch.nn as nn
import torch.nn.functional as F
import torch
import dgl
import sys
import copy

sys.path.append('/projects/bortoletto/irene/')
from tom.norm import Norm

class RSAGEv4(nn.Module):
    # multi-layer GNN for one single feature
    def __init__(
        self,
        hidden_channels,
        out_channels,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
    ):
        super().__init__()
        self.embedding_type = nn.Linear(9, int(hidden_channels))
        self.embedding_pos = nn.Linear(2, int(hidden_channels))
        self.embedding_color = nn.Linear(3, int(hidden_channels))
        self.embedding_shape = nn.Linear(18, int(hidden_channels))
        self.combine_attr = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*3, hidden_channels)
        )
        self.combine_pos = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*2, hidden_channels)
        )
        self.layers = nn.ModuleList([
            dglnn.HeteroGraphConv({
                rel: dglnn.SAGEConv(
                    in_feats=hidden_channels,
                    out_feats=hidden_channels if l < n_layers - 1 else out_channels,
                    aggregator_type='lstm',
                    feat_drop=dropout,
                    activation=activation if l < n_layers - 1 else None
                )
                for rel in rel_names}, aggregate='sum')
            for l in range(n_layers)
        ])

    def forward(self, g):
        attr = []
        attr.append(self.embedding_type(g.ndata['type'].float()))
        pos = self.embedding_pos(g.ndata['pos']/170.)
        attr.append(self.embedding_color(g.ndata['color']/255.))
        attr.append(self.embedding_shape(g.ndata['shape'].float()))
        combined_attr = self.combine_attr(torch.cat(attr, dim=1))
        h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
        for l, conv in enumerate(self.layers):
            h = conv(g, h)
        with g.local_scope():
            g.ndata['h'] = h['obj']
            out = dgl.mean_nodes(g, 'h')
        return out


# -------------------------------------------------------------------------------------------

class RAGNNv4(nn.Module):
    # multi-layer GNN for one single feature
    # NOTE: AGNNConv only reweights neighbors, so the feature dimension is
    # preserved and the out_channels/dropout arguments are effectively unused here
    def __init__(
        self,
        hidden_channels,
        out_channels,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
    ):
        super().__init__()
        self.embedding_type = nn.Linear(9, int(hidden_channels))
        self.embedding_pos = nn.Linear(2, int(hidden_channels))
        self.embedding_color = nn.Linear(3, int(hidden_channels))
        self.embedding_shape = nn.Linear(18, int(hidden_channels))
        self.combine_attr = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*3, hidden_channels)
        )
        self.combine_pos = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*2, hidden_channels)
        )
        self.layers = nn.ModuleList([
            dglnn.HeteroGraphConv({
                rel: dglnn.AGNNConv()
                for rel in rel_names}, aggregate='sum')
            for l in range(n_layers)
        ])

    def forward(self, g):
        attr = []
        attr.append(self.embedding_type(g.ndata['type'].float()))
        pos = self.embedding_pos(g.ndata['pos']/170.)
        attr.append(self.embedding_color(g.ndata['color']/255.))
        attr.append(self.embedding_shape(g.ndata['shape'].float()))
        combined_attr = self.combine_attr(torch.cat(attr, dim=1))
        h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
        for l, conv in enumerate(self.layers):
            h = conv(g, h)
        with g.local_scope():
            g.ndata['h'] = h['obj']
            out = dgl.mean_nodes(g, 'h')
        return out


# -------------------------------------------------------------------------------------------

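Since AGNNConv carries no projection weights, it can be sanity-checked in isolation; a minimal sketch on a toy homogeneous graph (sizes are made up):

import dgl
import dgl.nn.pytorch as dglnn
import torch

g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 0])))
conv = dglnn.AGNNConv()
out = conv(g, torch.randn(2, 8))
out.shape                                     # -> torch.Size([2, 8]); the feature dim is preserved
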
class RGATv2(nn.Module):
    # multi-layer GNN for one single feature
    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        self.embedding_type = nn.Embedding(9, int(hidden_channels*num_heads/4))
        self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads/4))
        self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads/4))
        self.embedding_shape = nn.Embedding(18, int(hidden_channels*num_heads/4))
        self.combine = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads, hidden_channels*num_heads)
        )
        self.layers = nn.ModuleList([
            dglnn.HeteroGraphConv({
                rel: dglnn.GATv2Conv(
                    in_feats=hidden_channels*num_heads,
                    out_feats=hidden_channels if l < n_layers - 1 else out_channels,
                    num_heads=num_heads if l < n_layers - 1 else 1,
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    activation=activation if l < n_layers - 1 else None
                )
                for rel in rel_names}, aggregate='sum')
            for l in range(n_layers)
        ])

    def forward(self, g):
        feats = []
        feats.append(self.embedding_type(torch.argmax(g.ndata['type'], dim=1)))
        feats.append(self.embedding_pos(g.ndata['pos']/170.))
        feats.append(self.embedding_color(g.ndata['color']/255.))
        feats.append(self.embedding_shape(torch.argmax(g.ndata['shape'], dim=1)))
        h = {'obj': self.combine(torch.cat(feats, dim=1))}
        for l, conv in enumerate(self.layers):
            h = conv(g, h)
            if l != len(self.layers) - 1:
                h = {k: v.flatten(1) for k, v in h.items()}
            else:
                h = {k: v.mean(1) for k, v in h.items()}
        with g.local_scope():
            g.ndata['h'] = h['obj']
            out = dgl.mean_nodes(g, 'h')
        return out


# -------------------------------------------------------------------------------------------

class RGATv3(nn.Module):
    # multi-layer GNN for one single feature
    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
        self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
        self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
        self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
        self.combine = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads*4, hidden_channels*num_heads)
        )
        self.layers = nn.ModuleList([
            dglnn.HeteroGraphConv({
                rel: dglnn.GATv2Conv(
                    in_feats=hidden_channels*num_heads,
                    out_feats=hidden_channels if l < n_layers - 1 else out_channels,
                    num_heads=num_heads if l < n_layers - 1 else 1,  # TODO: change to num_heads always
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    activation=activation if l < n_layers - 1 else None
                )
                for rel in rel_names}, aggregate='sum')
            for l in range(n_layers)
        ])

    def forward(self, g):
        feats = []
        feats.append(self.embedding_type(g.ndata['type'].float()))
        feats.append(self.embedding_pos(g.ndata['pos']/170.))  # NOTE: this should be 180 because I remove the boundary walls!
        feats.append(self.embedding_color(g.ndata['color']/255.))
        feats.append(self.embedding_shape(g.ndata['shape'].float()))
        h = {'obj': self.combine(torch.cat(feats, dim=1))}
        for l, conv in enumerate(self.layers):
            h = conv(g, h)
            if l != len(self.layers) - 1:
                h = {k: v.flatten(1) for k, v in h.items()}
            else:
                h = {k: v.mean(1) for k, v in h.items()}
        with g.local_scope():
            g.ndata['h'] = h['obj']
            out = dgl.mean_nodes(g, 'h')
        return out


# -------------------------------------------------------------------------------------------

class RGCNv2(nn.Module):
    # multi-layer GNN for one single feature
    def __init__(
        self,
        hidden_channels,
        out_channels,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ReLU()
    ):
        super().__init__()
        self.embedding_type = nn.Linear(9, int(hidden_channels))
        self.embedding_pos = nn.Linear(2, int(hidden_channels))
        self.embedding_color = nn.Linear(3, int(hidden_channels))
        self.embedding_shape = nn.Linear(18, int(hidden_channels))
        self.combine = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*4, hidden_channels)
        )
        self.layers = nn.ModuleList([
            dglnn.RelGraphConv(
                in_feat=hidden_channels,
                out_feat=hidden_channels,
                num_rels=len(rel_names),
                regularizer=None,
                num_bases=None,
                bias=True,
                activation=activation,
                self_loop=True,
                dropout=dropout,
                layer_norm=False
            )
            for _ in range(n_layers-1)])
        self.layers.append(
            dglnn.RelGraphConv(
                in_feat=hidden_channels,
                out_feat=out_channels,
                num_rels=len(rel_names),
                regularizer=None,
                num_bases=None,
                bias=True,
                activation=activation,
                self_loop=True,
                dropout=dropout,
                layer_norm=False
            )
        )

    def forward(self, g):
        # RelGraphConv expects a homogeneous graph plus integer edge-type ids;
        # dgl.to_homogeneous carries the node features over and stores the edge
        # types in g.edata[dgl.ETYPE] (the original called g.to_homogeneous()
        # and passed g.etypes, neither of which matches the DGL API)
        g = dgl.to_homogeneous(g, ndata=['type', 'pos', 'color', 'shape'])
        feats = []
        feats.append(self.embedding_type(g.ndata['type'].float()))
        feats.append(self.embedding_pos(g.ndata['pos']/170.))
        feats.append(self.embedding_color(g.ndata['color']/255.))
        feats.append(self.embedding_shape(g.ndata['shape'].float()))
        h = self.combine(torch.cat(feats, dim=1))
        for conv in self.layers:
            h = conv(g, h, g.edata[dgl.ETYPE])
        with g.local_scope():
            g.ndata['h'] = h
            out = dgl.mean_nodes(g, 'h')
        return out


# -------------------------------------------------------------------------------------------

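For reference, a small sketch of what dgl.to_homogeneous does with a two-relation graph (toy tensors; only 'pos' is carried over here):

import dgl
import torch

hg = dgl.heterograph({
    ('obj', 'is_close', 'obj'): (torch.tensor([0]), torch.tensor([1])),
    ('obj', 'is_left', 'obj'): (torch.tensor([1]), torch.tensor([0])),
})
hg.ndata['pos'] = torch.rand(2, 2)
g = dgl.to_homogeneous(hg, ndata=['pos'])
g.edata[dgl.ETYPE]                            # -> tensor([0, 1]): one integer id per relation
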
class RGATv3Agent(nn.Module):
    # multi-layer GNN for one single feature
    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
        self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
        self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
        self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
        self.combine = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads*4, hidden_channels*num_heads)
        )
        self.layers = nn.ModuleList([
            dglnn.HeteroGraphConv({
                rel: dglnn.GATv2Conv(
                    in_feats=hidden_channels*num_heads,
                    out_feats=hidden_channels if l < n_layers - 1 else out_channels,
                    num_heads=num_heads if l < n_layers - 1 else 1,  # TODO: change to num_heads always
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    activation=activation if l < n_layers - 1 else None
                )
                for rel in rel_names}, aggregate='sum')
            for l in range(n_layers)
        ])

    def forward(self, g):
        # the first column of the one-hot 'type' marks the agent node; this
        # assumes exactly one agent node per graph so that out is [batch, out_channels]
        agent_mask = g.ndata['type'][:, 0] == 1
        feats = []
        feats.append(self.embedding_type(g.ndata['type'].float()))
        feats.append(self.embedding_pos(g.ndata['pos']/200.))
        feats.append(self.embedding_color(g.ndata['color']/255.))
        feats.append(self.embedding_shape(g.ndata['shape'].float()))
        h = {'obj': self.combine(torch.cat(feats, dim=1))}
        for l, conv in enumerate(self.layers):
            h = conv(g, h)
            if l != len(self.layers) - 1:
                h = {k: v.flatten(1) for k, v in h.items()}
            else:
                h = {k: v.mean(1) for k, v in h.items()}
        with g.local_scope():
            g.ndata['h'] = h['obj']
            out = g.ndata['h'][agent_mask, :]
            ctx = dgl.mean_nodes(g, 'h')
        return out + ctx


# -------------------------------------------------------------------------------------------

class RGATv4(nn.Module):
    # multi-layer GNN for one single feature
    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
        self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
        self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
        self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
        self.combine_attr = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads*3, hidden_channels*num_heads)
        )
        self.combine_pos = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads*2, hidden_channels*num_heads)
        )
        self.layers = nn.ModuleList([
            dglnn.HeteroGraphConv({
                rel: dglnn.GATv2Conv(
                    in_feats=hidden_channels*num_heads,
                    out_feats=hidden_channels if l < n_layers - 1 else out_channels,
                    num_heads=num_heads if l < n_layers - 1 else 1,  # TODO: change to num_heads always
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    share_weights=False,
                    activation=activation if l < n_layers - 1 else None
                )
                for rel in rel_names}, aggregate='sum')
            for l in range(n_layers)
        ])

    def forward(self, g):
        attr = []
        attr.append(self.embedding_type(g.ndata['type'].float()))
        pos = self.embedding_pos(g.ndata['pos']/170.)
        attr.append(self.embedding_color(g.ndata['color']/255.))
        attr.append(self.embedding_shape(g.ndata['shape'].float()))
        combined_attr = self.combine_attr(torch.cat(attr, dim=1))
        h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
        for l, conv in enumerate(self.layers):
            h = conv(g, h)
            if l != len(self.layers) - 1:
                h = {k: v.flatten(1) for k, v in h.items()}
            else:
                h = {k: v.mean(1) for k, v in h.items()}
        with g.local_scope():
            g.ndata['h'] = h['obj']
            out = dgl.mean_nodes(g, 'h')
        return out


# -------------------------------------------------------------------------------------------

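A self-contained smoke test for the encoder above (the toy graph layout, feature values and sizes are made up; the real graphs come from the preprocessed pickles and use the 14 relation names listed in tom/model.py):

import dgl
import torch

g = dgl.heterograph({('obj', 'is_close', 'obj'):
                     (torch.tensor([0, 1]), torch.tensor([1, 0]))})
g.ndata['type'] = torch.eye(9)[torch.tensor([0, 3])]        # one-hot over 9 types
g.ndata['pos'] = torch.rand(2, 2) * 170
g.ndata['color'] = torch.randint(0, 256, (2, 3)).float()
g.ndata['shape'] = torch.eye(18)[torch.tensor([2, 5])]      # one-hot over 18 shapes
enc = RGATv4(hidden_channels=8, out_channels=16, num_heads=2,
             rel_names=['is_close'], dropout=0.0, n_layers=2)
enc(g).shape                                  # -> torch.Size([1, 16]): one pooled vector per graph
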
class RGATv4Norm(nn.Module):
    # multi-layer GNN for one single feature
    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
        self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
        self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
        self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
        self.combine_attr = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads*3, hidden_channels*num_heads)
        )
        self.combine_pos = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads*2, hidden_channels*num_heads)
        )
        self.layers = nn.ModuleList([
            dglnn.HeteroGraphConv({
                rel: dglnn.GATv2Conv(
                    in_feats=hidden_channels*num_heads,
                    out_feats=hidden_channels if l < n_layers - 1 else out_channels,
                    num_heads=num_heads if l < n_layers - 1 else 1,  # TODO: change to num_heads always
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    activation=activation if l < n_layers - 1 else None
                )
                for rel in rel_names}, aggregate='sum')
            for l in range(n_layers)
        ])
        self.norms = nn.ModuleList([
            Norm(
                norm_type='gn',
                hidden_dim=hidden_channels*num_heads if l < n_layers - 1 else out_channels
            )
            for l in range(n_layers)
        ])

    def forward(self, g):
        attr = []
        attr.append(self.embedding_type(g.ndata['type'].float()))
        pos = self.embedding_pos(g.ndata['pos']/170.)
        attr.append(self.embedding_color(g.ndata['color']/255.))
        attr.append(self.embedding_shape(g.ndata['shape'].float()))
        combined_attr = self.combine_attr(torch.cat(attr, dim=1))
        h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
        for l, conv in enumerate(self.layers):
            h = conv(g, h)
            if l != len(self.layers) - 1:
                h = {k: v.flatten(1) for k, v in h.items()}
            else:
                h = {k: v.mean(1) for k, v in h.items()}
            h = {k: self.norms[l](g, v) for k, v in h.items()}
        with g.local_scope():
            g.ndata['h'] = h['obj']
            out = dgl.mean_nodes(g, 'h')
        return out


# -------------------------------------------------------------------------------------------

class RGATv3Norm(nn.Module):
    # multi-layer GNN for one single feature
    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
        self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
        self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
        self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
        self.combine = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads*4, hidden_channels*num_heads)
        )
        self.layers = nn.ModuleList([
            dglnn.HeteroGraphConv({
                rel: dglnn.GATv2Conv(
                    in_feats=hidden_channels*num_heads,
                    out_feats=hidden_channels if l < n_layers - 1 else out_channels,
                    num_heads=num_heads if l < n_layers - 1 else 1,  # TODO: change to num_heads always
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    activation=activation if l < n_layers - 1 else None
                )
                for rel in rel_names}, aggregate='sum')
            for l in range(n_layers)
        ])
        self.norms = nn.ModuleList([
            Norm(
                norm_type='gn',
                hidden_dim=hidden_channels*num_heads if l < n_layers - 1 else out_channels
            )
            for l in range(n_layers)
        ])

    def forward(self, g):
        feats = []
        feats.append(self.embedding_type(g.ndata['type'].float()))
        feats.append(self.embedding_pos(g.ndata['pos']/170.))
        feats.append(self.embedding_color(g.ndata['color']/255.))
        feats.append(self.embedding_shape(g.ndata['shape'].float()))
        h = {'obj': self.combine(torch.cat(feats, dim=1))}
        for l, conv in enumerate(self.layers):
            h = conv(g, h)
            if l != len(self.layers) - 1:
                h = {k: v.flatten(1) for k, v in h.items()}
            else:
                h = {k: v.mean(1) for k, v in h.items()}
            h = {k: self.norms[l](g, v) for k, v in h.items()}
        with g.local_scope():
            g.ndata['h'] = h['obj']
            out = dgl.mean_nodes(g, 'h')
        return out


# -------------------------------------------------------------------------------------------

class RGATv4Agent(nn.Module):
    # multi-layer GNN for one single feature
    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
        self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
        self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
        self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
        self.combine_attr = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads*3, hidden_channels*num_heads)
        )
        self.combine_pos = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads*2, hidden_channels*num_heads)
        )
        self.layers = nn.ModuleList([
            dglnn.HeteroGraphConv({
                rel: dglnn.GATv2Conv(
                    in_feats=hidden_channels*num_heads,
                    out_feats=hidden_channels if l < n_layers - 1 else out_channels,
                    num_heads=num_heads if l < n_layers - 1 else 1,  # TODO: change to num_heads always
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    activation=activation if l < n_layers - 1 else None
                )
                for rel in rel_names}, aggregate='sum')
            for l in range(n_layers)
        ])
        self.combine_agent_context = nn.Linear(out_channels*2, out_channels)

    def forward(self, g):
        # the first column of the one-hot 'type' marks the agent node; this
        # assumes one agent node per graph so that h_a is [batch, out_channels]
        agent_mask = g.ndata['type'][:, 0] == 1
        attr = []
        attr.append(self.embedding_type(g.ndata['type'].float()))
        pos = self.embedding_pos(g.ndata['pos']/170.)
        attr.append(self.embedding_color(g.ndata['color']/255.))
        attr.append(self.embedding_shape(g.ndata['shape'].float()))
        combined_attr = self.combine_attr(torch.cat(attr, dim=1))
        h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
        for l, conv in enumerate(self.layers):
            h = conv(g, h)
            if l != len(self.layers) - 1:
                h = {k: v.flatten(1) for k, v in h.items()}
            else:
                h = {k: v.mean(1) for k, v in h.items()}
        with g.local_scope():
            g.ndata['h'] = h['obj']
            h_a = g.ndata['h'][agent_mask, :]
            # pool the context over the non-agent nodes only
            g_no_agent = copy.deepcopy(g)
            g_no_agent.remove_nodes([i for i, x in enumerate(agent_mask) if x])
            h_g = dgl.mean_nodes(g_no_agent, 'h')
            out = self.combine_agent_context(torch.cat((h_a, h_g), dim=1))
        return out


# -------------------------------------------------------------------------------------------

class RGCNv4(nn.Module):
    # multi-layer GNN for one single feature
    def __init__(
        self,
        hidden_channels,
        out_channels,
        rel_names,
        n_layers,
        activation=nn.ELU(),
    ):
        super().__init__()
        self.embedding_type = nn.Linear(9, int(hidden_channels))
        self.embedding_pos = nn.Linear(2, int(hidden_channels))
        self.embedding_color = nn.Linear(3, int(hidden_channels))
        self.embedding_shape = nn.Linear(18, int(hidden_channels))
        self.combine_attr = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*3, hidden_channels)
        )
        self.combine_pos = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*2, hidden_channels)
        )
        self.layers = nn.ModuleList([
            dglnn.HeteroGraphConv({
                rel: dglnn.GraphConv(
                    in_feats=hidden_channels,
                    out_feats=hidden_channels if l < n_layers - 1 else out_channels,
                    activation=activation if l < n_layers - 1 else None
                )
                for rel in rel_names}, aggregate='sum')
            for l in range(n_layers)
        ])

    def forward(self, g):
        attr = []
        attr.append(self.embedding_type(g.ndata['type'].float()))
        pos = self.embedding_pos(g.ndata['pos']/170.)
        attr.append(self.embedding_color(g.ndata['color']/255.))
        attr.append(self.embedding_shape(g.ndata['shape'].float()))
        combined_attr = self.combine_attr(torch.cat(attr, dim=1))
        h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
        for l, conv in enumerate(self.layers):
            h = conv(g, h)
        with g.local_scope():
            g.ndata['h'] = h['obj']
            out = dgl.mean_nodes(g, 'h')
        return out


# -------------------------------------------------------------------------------------------

class RGATv5(nn.Module):

    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.num_heads = num_heads
        self.embedding_type = nn.Linear(9, hidden_channels*num_heads)
        self.embedding_pos = nn.Linear(2, hidden_channels*num_heads)
        self.embedding_color = nn.Linear(3, hidden_channels*num_heads)
        self.embedding_shape = nn.Linear(18, hidden_channels*num_heads)
        self.combine = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads*4, hidden_channels*num_heads)
        )
        self.attention = nn.Linear(hidden_channels*num_heads*4, 4)
        self.layers = nn.ModuleList([
            dglnn.HeteroGraphConv({
                rel: dglnn.GATv2Conv(
                    in_feats=hidden_channels*num_heads,
                    out_feats=hidden_channels if l < n_layers - 1 else out_channels,
                    num_heads=num_heads,
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    activation=activation if l < n_layers - 1 else None
                )
                for rel in rel_names}, aggregate='sum')
            for l in range(n_layers)
        ])

    def forward(self, g):
        feats = []
        feats.append(self.embedding_type(g.ndata['type'].float()))
        feats.append(self.embedding_pos(g.ndata['pos']/170.))
        feats.append(self.embedding_color(g.ndata['color']/255.))
        feats.append(self.embedding_shape(g.ndata['shape'].float()))
        h = torch.cat(feats, dim=1)
        # per-feature attention: one scalar weight per embedding block
        feat_attn = F.softmax(self.attention(h), dim=1)
        h = h * feat_attn.repeat_interleave(self.hidden_channels*self.num_heads, dim=1)
        h_in = self.combine(h)
        h = {'obj': h_in}
        for conv in self.layers:
            h = conv(g, h)
            h = {k: v.flatten(1) for k, v in h.items()}
            #if l != len(self.layers) - 1:
            #    h = {k: v.flatten(1) for k, v in h.items()}
            #else:
            #    h = {k: v.mean(1) for k, v in h.items()}
            # residual back to the input embedding; the flattened shapes only
            # match if out_channels == hidden_channels
            h = {k: v + h_in for k, v in h.items()}
        with g.local_scope():
            g.ndata['h'] = h['obj']
            out = dgl.mean_nodes(g, 'h')
        return out


# -------------------------------------------------------------------------------------------

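The per-feature attention in RGATv5 computes one scalar weight per embedding block and rescales each block before the GNN layers; the core of it, isolated (H stands for hidden_channels * num_heads and the sizes are illustrative):

import torch
import torch.nn.functional as F

H = 8
h = torch.randn(5, 4 * H)                     # 5 nodes, 4 concatenated feature blocks
attention = torch.nn.Linear(4 * H, 4)
w = F.softmax(attention(h), dim=1)            # one weight per block, per node
h = h * w.repeat_interleave(H, dim=1)         # rescale each block by its weight
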
class RGATv6(nn.Module):

    # RGATv6 = RGATv4 + Global Attention Pooling

    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
        self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
        self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
        self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
        self.combine_attr = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads*3, hidden_channels*num_heads)
        )
        self.combine_pos = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads*2, hidden_channels*num_heads)
        )
        self.layers = nn.ModuleList([
            dglnn.HeteroGraphConv({
                rel: dglnn.GATv2Conv(
                    in_feats=hidden_channels*num_heads,
                    out_feats=hidden_channels if l < n_layers - 1 else out_channels,
                    num_heads=num_heads if l < n_layers - 1 else 1,  # TODO: change to num_heads always
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    activation=activation if l < n_layers - 1 else None
                )
                for rel in rel_names}, aggregate='sum')
            for l in range(n_layers)
        ])
        gate_nn = nn.Linear(out_channels, 1)
        self.gap = dglnn.GlobalAttentionPooling(gate_nn)

    def forward(self, g):
        attr = []
        attr.append(self.embedding_type(g.ndata['type'].float()))
        pos = self.embedding_pos(g.ndata['pos']/170.)
        attr.append(self.embedding_color(g.ndata['color']/255.))
        attr.append(self.embedding_shape(g.ndata['shape'].float()))
        combined_attr = self.combine_attr(torch.cat(attr, dim=1))
        h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
        for l, conv in enumerate(self.layers):
            h = conv(g, h)
            if l != len(self.layers) - 1:
                h = {k: v.flatten(1) for k, v in h.items()}
            else:
                h = {k: v.mean(1) for k, v in h.items()}
        #with g.local_scope():
        #    g.ndata['h'] = h['obj']
        #    out = dgl.mean_nodes(g, 'h')
        out = self.gap(g, h['obj'])
        return out


# -------------------------------------------------------------------------------------------

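GlobalAttentionPooling replaces the unweighted node mean with a learned, gated one; a minimal sketch on a toy graph (hypothetical sizes):

import dgl
import dgl.nn.pytorch as dglnn
import torch
import torch.nn as nn

g = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])))
gap = dglnn.GlobalAttentionPooling(nn.Linear(4, 1))
gap(g, torch.randn(3, 4)).shape               # -> torch.Size([1, 4]): one vector per graph
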
class RGATv6Agent(nn.Module):

    # RGATv6 = RGATv4 + Global Attention Pooling

    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
        self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
        self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
        self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
        self.combine_attr = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads*3, hidden_channels*num_heads)
        )
        self.combine_pos = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads*2, hidden_channels*num_heads)
        )
        self.layers = nn.ModuleList([
            dglnn.HeteroGraphConv({
                rel: dglnn.GATv2Conv(
                    in_feats=hidden_channels*num_heads,
                    out_feats=hidden_channels if l < n_layers - 1 else out_channels,
                    num_heads=num_heads if l < n_layers - 1 else 1,  # TODO: change to num_heads always
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    activation=activation if l < n_layers - 1 else None
                )
                for rel in rel_names}, aggregate='sum')
            for l in range(n_layers)
        ])
        gate_nn = nn.Linear(out_channels, 1)
        self.gap = dglnn.GlobalAttentionPooling(gate_nn)
        self.combine_agent_context = nn.Linear(out_channels*2, out_channels)

    def forward(self, g):
        agent_mask = g.ndata['type'][:, 0] == 1
        attr = []
        attr.append(self.embedding_type(g.ndata['type'].float()))
        pos = self.embedding_pos(g.ndata['pos']/170.)
        attr.append(self.embedding_color(g.ndata['color']/255.))
        attr.append(self.embedding_shape(g.ndata['shape'].float()))
        combined_attr = self.combine_attr(torch.cat(attr, dim=1))
        h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
        for l, conv in enumerate(self.layers):
            h = conv(g, h)
            if l != len(self.layers) - 1:
                h = {k: v.flatten(1) for k, v in h.items()}
            else:
                h = {k: v.mean(1) for k, v in h.items()}
        with g.local_scope():
            g.ndata['h'] = h['obj']
            h_a = g.ndata['h'][agent_mask, :]
            h_g = g.ndata['h'][~agent_mask, :]
            g_no_agent = copy.deepcopy(g)
            g_no_agent.remove_nodes([i for i, x in enumerate(agent_mask) if x])
            h_g = self.gap(g_no_agent, h_g)  # dgl.mean_nodes(g_no_agent, 'h')
            out = self.combine_agent_context(torch.cat((h_a, h_g), dim=1))
        return out


# -------------------------------------------------------------------------------------------

tom/model.py    (new file, 513 lines)
@@ -0,0 +1,513 @@
from argparse import ArgumentParser
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader

from tom.dataset import *
from tom.transformer import TransformerEncoder
from tom.gnn import RGATv2, RGATv3, RGATv3Agent, RGATv4, RGATv4Norm, RSAGEv4, RAGNNv4

class MlpModel(nn.Module):
    """Multilayer Perceptron with last layer linear.

    Args:
        input_size (int): number of inputs
        hidden_sizes (list): can be empty list for none (linear model).
        output_size: linear layer at output, or if ``None``, the last hidden
            size will be the output size and will have nonlinearity applied
        nonlinearity: torch nonlinearity Module (not Functional).
    """

    def __init__(
        self,
        input_size,
        hidden_sizes,          # can be an empty list or None for none
        output_size=None,      # if None, the last layer has nonlinearity applied
        nonlinearity=nn.ReLU,  # Module, not Functional
        dropout=None           # dropout value
    ):
        super().__init__()
        if isinstance(hidden_sizes, int):
            hidden_sizes = [hidden_sizes]
        elif hidden_sizes is None:
            hidden_sizes = []
        hidden_layers = [nn.Linear(n_in, n_out) for n_in, n_out in
                         zip([input_size] + hidden_sizes[:-1], hidden_sizes)]
        sequence = list()
        for i, layer in enumerate(hidden_layers):
            if dropout is not None:
                sequence.extend([layer, nonlinearity(), nn.Dropout(dropout)])
            else:
                sequence.extend([layer, nonlinearity()])

        if output_size is not None:
            last_size = hidden_sizes[-1] if hidden_sizes else input_size
            sequence.append(torch.nn.Linear(last_size, output_size))
        self.model = nn.Sequential(*sequence)
        self._output_size = (hidden_sizes[-1] if output_size is None
                             else output_size)

    def forward(self, input):
        """Compute the model on the input, assuming input shape [B, input_size]."""
        return self.model(input)

    @property
    def output_size(self):
        """Returns the output size of the model."""
        return self._output_size


# ---------------------------------------------------------------------------------------------------------------------------------

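A quick sketch of this head as it is configured further down (the 192-d input mirrors the default state_dim + 2 * context_dim; all numbers here are illustrative):

import torch

mlp = MlpModel(input_size=192, hidden_sizes=[256, 128, 256],
               output_size=2, dropout=0.2)
action = torch.tanh(mlp(torch.randn(4, 192)))
action.shape                                  # -> torch.Size([4, 2])
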
class GraphBCRNN(pl.LightningModule):
    """
    Implementation of the baseline model for the BC-RNN algorithm.
    R-GCN + LSTM are used to encode the familiarization trials.
    """
    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--lr', type=float, default=1e-4)
        parser.add_argument('--action_dim', type=int, default=2)
        parser.add_argument('--context_dim', type=int, default=32)  # lstm hidden size
        parser.add_argument('--beta', type=float, default=0.01)
        parser.add_argument('--dropout', type=float, default=0.2)
        parser.add_argument('--process_data', type=int, default=0)
        parser.add_argument('--max_len', type=int, default=30)
        # arguments for the gnn encoder
        parser.add_argument('--gnn_type', type=str, default='RGATv4')
        parser.add_argument('--state_dim', type=int, default=128)  # gnn out_feats
        parser.add_argument('--feats_dims', type=list, default=[9, 2, 3, 18])
        parser.add_argument('--aggregation', type=str, default='sum')
        # arguments for the mlp policy
        #parser.add_argument('--mlp_hid_feats', type=list, default=[256, 64, 16])
        return parser

    def __init__(self, hparams):
        super().__init__()

        self.hparams = hparams
        self.lr = self.hparams.lr
        self.state_dim = self.hparams.state_dim
        self.action_dim = self.hparams.action_dim
        self.context_dim = self.hparams.context_dim
        self.beta = self.hparams.beta
        self.dropout = self.hparams.dropout
        self.max_len = self.hparams.max_len
        self.feats_dims = self.hparams.feats_dims  # type, position, color, shape
        self.rel_names = [
            'is_aligned', 'is_back', 'is_close', 'is_down_adj', 'is_down_left_adj',
            'is_down_right_adj', 'is_front', 'is_left', 'is_left_adj', 'is_right',
            'is_right_adj', 'is_top_adj', 'is_top_left_adj', 'is_top_right_adj'
        ]
        self.gnn_aggregation = self.hparams.aggregation

        if self.hparams.gnn_type == 'RGATv2':
            self.gnn_encoder = RGATv2(
                hidden_channels=self.state_dim,
                out_channels=self.state_dim,
                num_heads=4,
                rel_names=self.rel_names,
                dropout=0.0,
                n_layers=2,
                activation=nn.ELU(),
                residual=False
            )
        if self.hparams.gnn_type == 'RGATv3':
            self.gnn_encoder = RGATv3(
                hidden_channels=self.state_dim,
                out_channels=self.state_dim,
                num_heads=4,
                rel_names=self.rel_names,
                dropout=0.0,
                n_layers=2,
                activation=nn.ELU(),
                residual=False
            )
        if self.hparams.gnn_type == 'RGATv4':
            self.gnn_encoder = RGATv4(
                hidden_channels=self.state_dim,
                out_channels=self.state_dim,
                num_heads=4,
                rel_names=self.rel_names,
                dropout=0.0,
                n_layers=2,
                activation=nn.ELU(),
                residual=False
            )
        if self.hparams.gnn_type == 'RGATv3Agent':
            self.gnn_encoder = RGATv3Agent(
                hidden_channels=self.state_dim,
                out_channels=self.state_dim,
                num_heads=4,
                rel_names=self.rel_names,
                dropout=0.0,
                n_layers=2,
                activation=nn.ELU(),
                residual=False
            )
        if self.hparams.gnn_type == 'RSAGEv4':
            self.gnn_encoder = RSAGEv4(
                hidden_channels=self.state_dim,
                out_channels=self.state_dim,
                rel_names=self.rel_names,
                dropout=0.0,
                n_layers=2,
                activation=nn.ELU()
            )
        if self.gnn_aggregation == 'cat_axis_1':
            self.lstm_input_size = self.state_dim * len(self.feats_dims) + self.action_dim
            self.mlp_input_size = self.state_dim * len(self.feats_dims) + self.context_dim * 2
        elif self.gnn_aggregation == 'sum':
            self.lstm_input_size = self.state_dim + self.action_dim
            self.mlp_input_size = self.state_dim + self.context_dim * 2
        else:
            raise ValueError(f'Unknown gnn_aggregation: {self.gnn_aggregation}')

        self.context_enc = nn.LSTM(self.lstm_input_size, self.context_dim, 2,
                                   batch_first=True, bidirectional=True)

        self.policy = MlpModel(input_size=self.mlp_input_size, hidden_sizes=[256, 128, 256],
                               output_size=self.action_dim, dropout=self.dropout)

        self.past_samples = []

    def forward(self, batch):
        dem_frames, dem_actions, dem_lens, query_frame, target_action = batch
        dem_actions = dem_actions.float()
        target_action = target_action.float()
        dem_states = self.gnn_encoder(dem_frames)  # [total number of frames, 128 * number of features if cat_axis_1]
        # segment according to the number of frames in each episode and pad with zeros
        # to obtain tensors of shape [batch size, num of trials (8), max num of frames (30), hidden dim]
        b, l, s, _ = dem_actions.size()
        dem_states_new = []
        for i in range(b):
            dem_states_new.append(self._sequence_to_padding(dem_states, dem_lens[i], self.max_len))
        dem_states_new = torch.stack(dem_states_new).to(self.device)  # [b, 8, 30, 128 * number of features if cat_axis_1]
        # concatenate states and actions to get the expert trajectory
        dem_states_new = dem_states_new.view(b * l, s, -1)  # [b*8, 30, 128 * number of features if cat_axis_1]
        dem_actions = dem_actions.view(b * l, s, -1)  # [b*8, 30, action_dim]
        dem_traj = torch.cat([dem_states_new, dem_actions], dim=2)
        # embed the expert trajectory to get a context embedding batch x samples x dim
        dem_lens = torch.tensor([t for dl in dem_lens for t in dl]).to(torch.int64).cpu()
        dem_lens = dem_lens.view(-1)
        x_lstm = nn.utils.rnn.pack_padded_sequence(dem_traj, dem_lens, batch_first=True, enforce_sorted=False)
        x_lstm, _ = self.context_enc(x_lstm)
        x_lstm, _ = nn.utils.rnn.pad_packed_sequence(x_lstm, batch_first=True)
        x_out = x_lstm[:, -1, :]
        x_out = x_out.view(b, l, -1)
        context = torch.mean(x_out, dim=1)  # [b, 64] (64 = 32*2, bidirectional)
        # concatenate the context embedding with the state embedding of the test frame
        test_states = self.gnn_encoder(query_frame)  # [b, 128]
        test_context_states = torch.cat([context, test_states], dim=1)  # [b, 192] (192 = 64 + 128)
        # for each test state, predict the action
        test_actions_pred = torch.tanh(self.policy(test_context_states))  # [b, 2]
        return target_action, test_actions_pred

    def _sequence_to_padding(self, x, lengths, max_length):
        # declare the shape; it works for x of any feature shape
        ret_tensor = torch.zeros((len(lengths), max_length) + tuple(x.shape[1:]))
        cum_len = 0
        for i, l in enumerate(lengths):
            ret_tensor[i, :l] = x[cum_len: cum_len + l]
            cum_len += l
        return ret_tensor

def training_step(self, batch, batch_idx):
|
||||
test_actions, test_actions_pred = self.forward(batch)
|
||||
loss = F.mse_loss(test_actions, test_actions_pred)
|
||||
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
|
||||
return loss
|
||||
|
||||
def validation_step(self, batch, batch_idx, dataloader_idx=0):
|
||||
test_actions, test_actions_pred = self.forward(batch)
|
||||
loss = F.mse_loss(test_actions, test_actions_pred)
|
||||
self.log('val_loss', loss, on_epoch=True, logger=True)
|
||||
|
||||
def configure_optimizers(self):
|
||||
optim = torch.optim.Adam(self.parameters(), lr=self.lr)
|
||||
return optim
|
||||
#optim = torch.optim.AdamW(self.parameters(), lr=self.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01)
|
||||
#scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optim, gamma=0.8)
|
||||
#return [optim], [scheduler]
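
    # Note: if the commented-out AdamW + ExponentialLR variant above is
    # restored, returning `[optim], [scheduler]` makes Lightning step the
    # scheduler once per epoch by default; a minimal sketch with the values above:
    #
    #   optim = torch.optim.AdamW(self.parameters(), lr=self.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01)
    #   scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optim, gamma=0.8)
    #   return [optim], [scheduler]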

    def train_dataloader(self):
        train_dataset = ToMnetDGLDataset(path=self.hparams.data_path,
                                         types=self.hparams.types,
                                         mode='train')
        train_loader = DataLoader(dataset=train_dataset,
                                  batch_size=self.hparams.batch_size,
                                  collate_fn=collate_function_seq,
                                  num_workers=self.hparams.num_workers,
                                  #pin_memory=True,
                                  shuffle=True)
        return train_loader

    def val_dataloader(self):
        val_datasets = []
        val_loaders = []
        for t in self.hparams.types:
            val_datasets.append(ToMnetDGLDataset(path=self.hparams.data_path,
                                                 types=[t],
                                                 mode='val'))
            val_loaders.append(DataLoader(dataset=val_datasets[-1],
                                          batch_size=self.hparams.batch_size,
                                          collate_fn=collate_function_seq,
                                          num_workers=self.hparams.num_workers,
                                          #pin_memory=True,
                                          shuffle=False))
        return val_loaders

    def configure_callbacks(self):
        checkpoint = ModelCheckpoint(
            dirpath=None,  # set automatically by Lightning
            #filename=self.params['bc_model']+'-'+self.params['gnn_type']+'-'+self.gnn_params['feats_aggr']+'-{epoch:02d}',
            save_top_k=-1,
            period=1
        )
        return [checkpoint]


# ---------------------------------------------------------------------------------------------------------------------------------


class GraphBC_T(pl.LightningModule):
    """
    BC model with GraphTrans encoder (GNN + Transformer) and MLP policy.
    """
    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--lr', type=float, default=1e-4)
        parser.add_argument('--action_dim', type=int, default=2)
        parser.add_argument('--context_dim', type=int, default=32)  # lstm hidden size
        parser.add_argument('--beta', type=float, default=0.01)
        parser.add_argument('--dropout', type=float, default=0.2)
        parser.add_argument('--process_data', type=int, default=0)
        parser.add_argument('--max_len', type=int, default=30)
        # arguments for gnn
        parser.add_argument('--state_dim', type=int, default=128)  # gnn out_feats
        # note: argparse's type=list splits a command-line string into characters,
        # so --feats_dims is only usable through its default here
        parser.add_argument('--feats_dims', type=list, default=[9, 2, 3, 18])
        parser.add_argument('--aggregation', type=str, default='cat_axis_1')
        parser.add_argument('--gnn_type', type=str, default='RGATv3')
        # arguments for mlp
        #parser.add_argument('--mlp_hid_feats', type=list, default=[256, 64, 16])
        # arguments for transformer
        parser.add_argument('--d_model', type=int, default=128)
        parser.add_argument('--nhead', type=int, default=4)
        parser.add_argument('--dim_feedforward', type=int, default=512)
        parser.add_argument('--transformer_dropout', type=float, default=0.3)
        parser.add_argument('--transformer_activation', type=str, default='gelu')
        parser.add_argument('--num_encoder_layers', type=int, default=6)
        parser.add_argument('--transformer_norm_input', type=int, default=0)
        return parser
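
    # A sketch of how these flags are typically consumed (the main script and
    # any data arguments such as --data_path/--types/--batch_size/--num_workers
    # are assumptions; they are read from self.hparams elsewhere in this class):
    #
    #   parser = GraphBC_T.add_model_specific_args(ArgumentParser())
    #   hparams = parser.parse_args([])   # defaults only
    #   hparams.d_model, hparams.nhead    # (128, 4)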

    def __init__(self, hparams):
        super().__init__()

        self.hparams = hparams
        self.lr = self.hparams.lr
        self.state_dim = self.hparams.state_dim
        self.action_dim = self.hparams.action_dim
        self.context_dim = self.hparams.context_dim
        self.beta = self.hparams.beta
        self.dropout = self.hparams.dropout
        self.max_len = self.hparams.max_len
        self.feats_dims = self.hparams.feats_dims  # type, position, color, shape

        self.rel_names = [
            'is_aligned', 'is_back', 'is_close', 'is_down_adj', 'is_down_left_adj',
            'is_down_right_adj', 'is_front', 'is_left', 'is_left_adj', 'is_right',
            'is_right_adj', 'is_top_adj', 'is_top_left_adj', 'is_top_right_adj'
        ]
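        # The GNN encoders below expect DGL heterographs whose canonical edge
        # types use these relation names between 'obj' nodes. A toy sketch
        # (node/edge indices are hypothetical; a real frame provides all 14
        # relations):
        #
        #   g = dgl.heterograph({
        #       ('obj', 'is_close', 'obj'): ([0, 1], [1, 0]),
        #       ('obj', 'is_left', 'obj'): ([0], [1]),
        #   })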
        self.gnn_aggregation = self.hparams.aggregation
        if self.hparams.gnn_type == 'RGATv3':
            self.gnn_encoder = RGATv3(
                hidden_channels=self.state_dim,
                out_channels=self.state_dim,
                num_heads=4,
                rel_names=self.rel_names,
                dropout=0.0,
                n_layers=2,
                activation=nn.ELU(),
                residual=False
            )
        elif self.hparams.gnn_type == 'RGATv4':
            self.gnn_encoder = RGATv4(
                hidden_channels=self.state_dim,
                out_channels=self.state_dim,
                num_heads=4,
                rel_names=self.rel_names,
                dropout=0.0,
                n_layers=2,
                activation=nn.ELU(),
                residual=False
            )
        elif self.hparams.gnn_type == 'RSAGEv4':
            self.gnn_encoder = RSAGEv4(
                hidden_channels=self.state_dim,
                out_channels=self.state_dim,
                rel_names=self.rel_names,
                dropout=0.0,
                n_layers=2,
                activation=nn.ELU()
            )
        elif self.hparams.gnn_type == 'RAGNNv4':
            self.gnn_encoder = RAGNNv4(
                hidden_channels=self.state_dim,
                out_channels=self.state_dim,
                rel_names=self.rel_names,
                dropout=0.0,
                n_layers=2,
                activation=nn.ELU()
            )
        elif self.hparams.gnn_type == 'RGATv4Norm':
            self.gnn_encoder = RGATv4Norm(
                hidden_channels=self.state_dim,
                out_channels=self.state_dim,
                num_heads=4,
                rel_names=self.rel_names,
                dropout=0.0,
                n_layers=2,
                activation=nn.ELU(),
                residual=True
            )

        self.d_model = self.hparams.d_model
        self.nhead = self.hparams.nhead
        self.dim_feedforward = self.hparams.dim_feedforward
        self.transformer_dropout = self.hparams.transformer_dropout
        self.transformer_activation = self.hparams.transformer_activation
        self.num_encoder_layers = self.hparams.num_encoder_layers
        self.transformer_norm_input = self.hparams.transformer_norm_input
        self.context_enc = TransformerEncoder(
            self.d_model,
            self.nhead,
            self.dim_feedforward,
            self.transformer_dropout,
            self.transformer_activation,
            self.num_encoder_layers,
            self.max_len,
            self.transformer_norm_input
        )

        if self.gnn_aggregation == 'cat_axis_1':
            self.gnn2transformer = nn.Linear(self.state_dim * len(self.feats_dims) + self.action_dim, self.d_model)
            self.mlp_input_size = len(self.feats_dims) * self.state_dim + self.d_model
        elif self.gnn_aggregation == 'sum':
            self.gnn2transformer = nn.Linear(self.state_dim + self.action_dim, self.d_model)
            self.mlp_input_size = self.state_dim + self.d_model
        else:
            raise ValueError('Only "sum" and "cat_axis_1" aggregations are available.')

        self.policy = MlpModel(input_size=self.mlp_input_size, hidden_sizes=[256, 128, 256],
                               output_size=self.action_dim, dropout=self.dropout)

        # learnable CLS token (an nn.Parameter registered here so it is trained with the model)
        self.cls_embedding = nn.Parameter(torch.randn(1, 1, self.d_model))
        # learned positional embeddings
        self.embedding = nn.Embedding(self.max_len + 1, self.d_model)  # +1 for the CLS token
        self.emb_layer_norm = nn.LayerNorm(self.d_model)
        self.emb_dropout = nn.Dropout(p=self.transformer_dropout)
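
        # Dimension check for the two aggregations, using the defaults above
        # (state_dim=128, action_dim=2, d_model=128, len(feats_dims)=4):
        #   cat_axis_1: gnn2transformer maps 128 * 4 + 2 = 514 -> 128 and the
        #               policy MLP takes 128 * 4 + 128 = 640 inputs;
        #   sum:        gnn2transformer maps 128 + 2 = 130 -> 128 and the
        #               policy MLP takes 128 + 128 = 256 inputs.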

    def forward(self, batch):
        dem_frames, dem_actions, dem_lens, query_frame, target_action = batch
        dem_actions = dem_actions.float()
        target_action = target_action.float()
        dem_states = self.gnn_encoder(dem_frames)
        b, l, s, _ = dem_actions.size()
        dem_lens = torch.tensor([t for dl in dem_lens for t in dl]).to(torch.int64).cpu()
        dem_lens = dem_lens.view(-1)
        # drop padded action steps; [0] takes the PackedSequence's flat .data tensor
        dem_actions_packed = torch.nn.utils.rnn.pack_padded_sequence(dem_actions.view(b * l, s, -1), dem_lens, batch_first=True, enforce_sorted=False)[0]
        dem_traj = torch.cat([dem_states, dem_actions_packed], dim=1)
        h_node = self.gnn2transformer(dem_traj)
        hidden_dim = h_node.size()[1]
        # re-pad the flat node sequence into [b*l, max_len, hidden_dim]
        padded_trajectory = torch.zeros(b * l, self.max_len, hidden_dim).to(self.device)
        j = 0
        for idx, i in enumerate(dem_lens):
            padded_trajectory[idx][:i] = h_node[j:j + i]
            j += i
        mask = self.make_mask(padded_trajectory).to(self.device)
        transformer_input = padded_trajectory.transpose(0, 1)  # [30, b*l, 128]
        # append the learnable CLS token (registered in __init__) to every sequence
        cls_embedding = self.cls_embedding.expand(1, b * l, -1).to(self.device)
        transformer_input = torch.cat([transformer_input, cls_embedding], dim=0)  # [31, b*l, 128]
        zeros = mask.data.new(mask.size(0), 1).fill_(0)
        mask = torch.cat([mask, zeros], dim=1)  # the CLS position is never masked
        # add learned positional embeddings
        indices = torch.arange(self.max_len + 1, dtype=torch.int).to(self.device)  # +1 for the CLS position: [0, 1, ..., 30]
        positional_embeddings = self.embedding(indices).unsqueeze(1)  # torch.Size([31, 1, 128])
        pe_input = positional_embeddings + transformer_input  # torch.Size([31, b*l, 128])
        # layernorm and dropout
        transformer_in = self.emb_dropout(self.emb_layer_norm(pe_input))  # torch.Size([31, b*l, 128])
        # transformer encoding; the CLS output summarizes each demonstration
        out, _ = self.context_enc(transformer_in, mask)  # [31, b*l, 128]
        cls = out[-1]
        cls = cls.view(b, l, -1)  # [b, 8, 128]
        context = torch.mean(cls, dim=1)
        # action prediction
        test_states = self.gnn_encoder(query_frame)  # torch.Size([batchsize, 512] for cat_axis_1)
        test_context_states = torch.cat([context, test_states], dim=1)  # torch.Size([batchsize, d_model + gnn output dim])
        # for each test state, predict an action
        test_actions_pred = torch.tanh(self.policy(test_context_states))  # torch.Size([batchsize, 2])
        return target_action, test_actions_pred

    def make_mask(self, feature):
        # padding positions have all-zero features; mark them True so the
        # transformer's key-padding mask ignores them
        return (torch.sum(
            torch.abs(feature),
            dim=-1
        ) == 0)  #.unsqueeze(1).unsqueeze(2)
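
    # Toy sketch of the mask semantics (hypothetical values):
    #
    #   t = torch.zeros(2, 3, 4)
    #   t[0, :2] = 1.0                  # first sequence has 2 valid frames
    #   self.make_mask(t)
    #   # tensor([[False, False,  True],
    #   #         [ True,  True,  True]])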

    def training_step(self, batch, batch_idx):
        test_actions, test_actions_pred = self.forward(batch)
        loss = F.mse_loss(test_actions, test_actions_pred)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        test_actions, test_actions_pred = self.forward(batch)
        loss = F.mse_loss(test_actions, test_actions_pred)
        self.log('val_loss', loss, on_epoch=True, logger=True)

    def configure_optimizers(self):
        optim = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optim
        #optim = torch.optim.AdamW(self.parameters(), lr=self.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01)
        #scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optim, gamma=0.96)
        #return [optim], [scheduler]

    def train_dataloader(self):
        train_dataset = ToMnetDGLDataset(path=self.hparams.data_path,
                                         types=self.hparams.types,
                                         mode='train')
        train_loader = DataLoader(dataset=train_dataset,
                                  batch_size=self.hparams.batch_size,
                                  collate_fn=collate_function_seq,
                                  num_workers=self.hparams.num_workers,
                                  #pin_memory=True,
                                  shuffle=True)
        return train_loader

    def val_dataloader(self):
        val_datasets = []
        val_loaders = []
        for t in self.hparams.types:
            val_datasets.append(ToMnetDGLDataset(path=self.hparams.data_path,
                                                 types=[t],
                                                 mode='val'))
            val_loaders.append(DataLoader(dataset=val_datasets[-1],
                                          batch_size=self.hparams.batch_size,
                                          collate_fn=collate_function_seq,
                                          num_workers=self.hparams.num_workers,
                                          #pin_memory=True,
                                          shuffle=False))
        return val_loaders

    def configure_callbacks(self):
        checkpoint = ModelCheckpoint(
            dirpath=None,  # set automatically by Lightning
            #filename=self.params['bc_model']+'-'+self.params['gnn_type']+'-'+self.gnn_params['feats_aggr']+'-{epoch:02d}',
            save_top_k=-1,
            period=1
        )
        return [checkpoint]
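

# A minimal end-to-end training sketch. The data flags (--data_path, --types,
# --batch_size, --num_workers) and the type name 'preference' are assumptions
# for illustration; they are read from self.hparams by the dataloaders above.
#
#   parser = GraphBC_T.add_model_specific_args(ArgumentParser())
#   parser.add_argument('--data_path', type=str, default='data/')
#   parser.add_argument('--types', nargs='+', default=['preference'])
#   parser.add_argument('--batch_size', type=int, default=16)
#   parser.add_argument('--num_workers', type=int, default=4)
#   hparams = parser.parse_args()
#   model = GraphBC_T(hparams)
#   trainer = pl.Trainer(gpus=1, max_epochs=100)
#   trainer.fit(model)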


46
tom/norm.py
Normal file
@ -0,0 +1,46 @@
import torch
import torch.nn as nn


class Norm(nn.Module):
    """
    Graph-aware normalization: BatchNorm ('bn'), GraphNorm ('gn'), or identity (None).
    """

    def __init__(self, norm_type, hidden_dim=64, print_info=None):
        super(Norm, self).__init__()
        # assert norm_type in ['bn', 'ln', 'gn', None]
        self.norm = None
        self.print_info = print_info
        if norm_type == 'bn':
            self.norm = nn.BatchNorm1d(hidden_dim)
        elif norm_type == 'gn':
            self.norm = norm_type
            self.weight = nn.Parameter(torch.ones(hidden_dim))
            self.bias = nn.Parameter(torch.zeros(hidden_dim))
            self.mean_scale = nn.Parameter(torch.ones(hidden_dim))

    def forward(self, graph, tensor, print_=False):
        if self.norm is not None and type(self.norm) != str:
            return self.norm(tensor)
        elif self.norm is None:
            return tensor

        # GraphNorm branch: normalize node features per graph in the batch
        batch_list = graph.batch_num_nodes('obj')
        batch_size = len(batch_list)
        #batch_list = torch.tensor(batch_list).long().to(tensor.device)
        batch_list = batch_list.long().to(tensor.device)
        batch_index = torch.arange(batch_size).to(tensor.device).repeat_interleave(batch_list)
        batch_index = batch_index.view((-1,) + (1,) * (tensor.dim() - 1)).expand_as(tensor)
        # per-graph mean, broadcast back to the nodes of each graph
        mean = torch.zeros(batch_size, *tensor.shape[1:]).to(tensor.device)
        mean = mean.scatter_add_(0, batch_index, tensor)
        mean = (mean.T / batch_list).T
        mean = mean.repeat_interleave(batch_list, dim=0)

        sub = tensor - mean * self.mean_scale

        # per-graph standard deviation of the centered features
        std = torch.zeros(batch_size, *tensor.shape[1:]).to(tensor.device)
        std = std.scatter_add_(0, batch_index, sub.pow(2))
        std = ((std.T / batch_list).T + 1e-6).sqrt()
        std = std.repeat_interleave(batch_list, dim=0)
        return self.weight * sub / std + self.bias
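

# A toy sketch of the 'gn' branch, assuming a batched DGL heterograph with an
# 'obj' node type (the graph construction here is illustrative):
#
#   import dgl
#   g = dgl.batch([
#       dgl.heterograph({('obj', 'is_close', 'obj'): ([0, 1], [1, 0])}),
#       dgl.heterograph({('obj', 'is_close', 'obj'): ([0], [1])}),
#   ])
#   norm = Norm('gn', hidden_dim=8)
#   h = torch.randn(g.num_nodes('obj'), 8)   # 4 nodes in total
#   out = norm(g, h)                          # normalized per graph, shape [4, 8]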


89
tom/transformer.py
Normal file
@ -0,0 +1,89 @@
import torch
import torch.nn as nn
import math


class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)
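
    # Sketch: fixed sin/cos encodings are added along the sequence axis, then
    # dropout is applied (values below are hypothetical):
    #
    #   pe = PositionalEncoding(d_model=8, dropout=0.0, max_len=16)
    #   x = torch.zeros(5, 2, 8)     # [seq_len, batch, d_model]
    #   y = pe(x)                    # y[t, b] equals pe.pe[t, 0] for every b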


class TransformerEncoder(nn.Module):

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward,
        transformer_dropout,
        transformer_activation,
        num_encoder_layers,
        max_input_len,
        transformer_norm_input
    ):
        super().__init__()
        self.d_model = d_model
        self.num_layer = num_encoder_layers
        self.max_input_len = max_input_len

        # creating the Transformer encoder model
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=transformer_dropout,
            activation=transformer_activation
        )
        encoder_norm = nn.LayerNorm(d_model)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        self.norm_input = None
        if transformer_norm_input:
            self.norm_input = nn.LayerNorm(d_model)

    def forward(self, padded_h_node, src_padding_mask):
        """
        padded_h_node: n_b x B x h_d  # e.g. 63, 257, 128
        src_padding_mask: B x n_b     # e.g. 257, 63
        """
        # (S, B, h_d), (B, S)
        if self.norm_input is not None:
            padded_h_node = self.norm_input(padded_h_node)

        transformer_out = self.transformer(padded_h_node, src_key_padding_mask=src_padding_mask)  # (S, B, h_d)

        return transformer_out, src_padding_mask


if __name__ == '__main__':
    model = TransformerEncoder(
        d_model=12,
        nhead=4,
        dim_feedforward=32,
        transformer_dropout=0.0,
        transformer_activation='gelu',
        num_encoder_layers=4,
        max_input_len=34,
        transformer_norm_input=0
    )
    print(model.norm_input)
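
    # Forward-pass smoke test with random inputs (a sketch; sizes follow the
    # constructor arguments above):
    h = torch.randn(34, 2, 12)                       # [seq_len, batch, d_model]
    pad_mask = torch.zeros(2, 34, dtype=torch.bool)  # no positions masked
    out, mask = model(h, pad_mask)
    print(out.shape)                                 # torch.Size([34, 2, 12])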