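"""Relational GNN encoders (DGL) over object-centric scene graphs.

Each model embeds the per-node attributes 'type' (9-way one-hot), 'pos'
(2D pixel coordinates), 'color' (RGB) and 'shape' (18-way one-hot), runs a
stack of per-relation graph convolutions (GraphSAGE, AGNN, GATv2 or R-GCN),
and reads the result out as a pooled graph embedding or, in the *Agent
variants, as an agent-centric embedding combined with scene context.
"""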
import copy
import sys

import dgl
import dgl.nn.pytorch as dglnn
import torch
import torch.nn as nn
import torch.nn.functional as F

sys.path.append('/projects/bortoletto/irene/')

from tom.norm import Norm


class RSAGEv4(nn.Module):
    # multi-layer GNN for one single feature
    def __init__(
        self,
        hidden_channels,
        out_channels,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
    ):
        super().__init__()
        # Per-attribute encoders: 'type' and 'shape' are one-hot, 'color' is RGB,
        # 'pos' is a 2D pixel coordinate.
        self.embedding_type = nn.Linear(9, int(hidden_channels))
        self.embedding_pos = nn.Linear(2, int(hidden_channels))
        self.embedding_color = nn.Linear(3, int(hidden_channels))
        self.embedding_shape = nn.Linear(18, int(hidden_channels))
        # Fuse the three appearance attributes, then fuse the result with position.
        self.combine_attr = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*3, hidden_channels)
        )
        self.combine_pos = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*2, hidden_channels)
        )
        # One GraphSAGE conv per relation type, summed across relations.
        self.layers = nn.ModuleList([
            dglnn.HeteroGraphConv({
                rel: dglnn.SAGEConv(
                    in_feats=hidden_channels,
                    out_feats=hidden_channels if l < n_layers - 1 else out_channels,
                    aggregator_type='lstm',
                    feat_drop=dropout,
                    activation=activation if l < n_layers - 1 else None
                )
                for rel in rel_names}, aggregate='sum')
            for l in range(n_layers)
        ])

    def forward(self, g):
        attr = []
        attr.append(self.embedding_type(g.ndata['type'].float()))
        pos = self.embedding_pos(g.ndata['pos']/170.)
        attr.append(self.embedding_color(g.ndata['color']/255.))
        attr.append(self.embedding_shape(g.ndata['shape'].float()))
        combined_attr = self.combine_attr(torch.cat(attr, dim=1))
        h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
        for conv in self.layers:
            h = conv(g, h)
        with g.local_scope():
            g.ndata['h'] = h['obj']
            out = dgl.mean_nodes(g, 'h')
        return out


# -------------------------------------------------------------------------------------------
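

# Minimal smoke-test sketch (illustrative, not part of the original file): a
# tiny single-node-type heterograph carrying the features these encoders
# expect. The relation name 'near' and all sizes below are assumptions.
def _demo_rsagev4():
    src, dst = torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])
    g = dgl.heterograph({('obj', 'near', 'obj'): (src, dst)})
    g.ndata['type'] = F.one_hot(torch.tensor([0, 3, 5]), num_classes=9)
    g.ndata['pos'] = torch.rand(3, 2) * 170.
    g.ndata['color'] = torch.randint(0, 256, (3, 3)).float()
    g.ndata['shape'] = F.one_hot(torch.tensor([2, 7, 11]), num_classes=18)
    model = RSAGEv4(hidden_channels=16, out_channels=8, rel_names=['near'],
                    dropout=0.1, n_layers=2)
    return model(g)  # (1, 8) graph-level embedding from dgl.mean_nodes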


class RAGNNv4(nn.Module):
    # multi-layer GNN for one single feature
    def __init__(
        self,
        hidden_channels,
        out_channels,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
    ):
        super().__init__()
        self.embedding_type = nn.Linear(9, int(hidden_channels))
        self.embedding_pos = nn.Linear(2, int(hidden_channels))
        self.embedding_color = nn.Linear(3, int(hidden_channels))
        self.embedding_shape = nn.Linear(18, int(hidden_channels))
        self.combine_attr = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*3, hidden_channels)
        )
        self.combine_pos = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*2, hidden_channels)
        )
        # AGNNConv is parameter-light and preserves the input feature size, so
        # out_channels, dropout and activation are accepted for interface
        # compatibility but unused here.
        self.layers = nn.ModuleList([
            dglnn.HeteroGraphConv({
                rel: dglnn.AGNNConv()
                for rel in rel_names}, aggregate='sum')
            for l in range(n_layers)
        ])

    def forward(self, g):
        attr = []
        attr.append(self.embedding_type(g.ndata['type'].float()))
        pos = self.embedding_pos(g.ndata['pos']/170.)
        attr.append(self.embedding_color(g.ndata['color']/255.))
        attr.append(self.embedding_shape(g.ndata['shape'].float()))
        combined_attr = self.combine_attr(torch.cat(attr, dim=1))
        h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
        for conv in self.layers:
            h = conv(g, h)
        with g.local_scope():
            g.ndata['h'] = h['obj']
            out = dgl.mean_nodes(g, 'h')
        return out


# -------------------------------------------------------------------------------------------


class RGATv2(nn.Module):
    # multi-layer GNN for one single feature
    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        # Each attribute gets a quarter of the hidden_channels*num_heads width,
        # so the concatenation in forward() is exactly hidden_channels*num_heads
        # wide. 'type' and 'shape' use index embeddings (argmax over the one-hot).
        self.embedding_type = nn.Embedding(9, int(hidden_channels*num_heads/4))
        self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads/4))
        self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads/4))
        self.embedding_shape = nn.Embedding(18, int(hidden_channels*num_heads/4))
        self.combine = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads, hidden_channels*num_heads)
        )
        self.layers = nn.ModuleList([
            dglnn.HeteroGraphConv({
                rel: dglnn.GATv2Conv(
                    in_feats=hidden_channels*num_heads,
                    out_feats=hidden_channels if l < n_layers - 1 else out_channels,
                    num_heads=num_heads if l < n_layers - 1 else 1,
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    activation=activation if l < n_layers - 1 else None
                )
                for rel in rel_names}, aggregate='sum')
            for l in range(n_layers)
        ])

    def forward(self, g):
        feats = []
        feats.append(self.embedding_type(torch.argmax(g.ndata['type'], dim=1)))
        feats.append(self.embedding_pos(g.ndata['pos']/170.))
        feats.append(self.embedding_color(g.ndata['color']/255.))
        feats.append(self.embedding_shape(torch.argmax(g.ndata['shape'], dim=1)))
        h = {'obj': self.combine(torch.cat(feats, dim=1))}
        for l, conv in enumerate(self.layers):
            h = conv(g, h)
            if l != len(self.layers) - 1:
                h = {k: v.flatten(1) for k, v in h.items()}
            else:
                h = {k: v.mean(1) for k, v in h.items()}
        with g.local_scope():
            g.ndata['h'] = h['obj']
            out = dgl.mean_nodes(g, 'h')
        return out


# -------------------------------------------------------------------------------------------
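

# Shape sketch for the multi-head bookkeeping used by the GAT variants above
# and below (illustrative sizes): GATv2Conv returns (N, num_heads, out_feats);
# hidden layers concatenate heads with flatten(1), the final layer averages
# them with mean(1).
def _heads_demo():
    v = torch.rand(5, 4, 8)   # (nodes, heads, feats), as returned by GATv2Conv
    hidden = v.flatten(1)     # (5, 32): heads concatenated between layers
    final = v.mean(1)         # (5, 8): heads averaged at the output layer
    return hidden.shape, final.shape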


class RGATv3(nn.Module):
    # multi-layer GNN for one single feature
    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
        self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
        self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
        self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
        self.combine = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads*4, hidden_channels*num_heads)
        )
        self.layers = nn.ModuleList([
            dglnn.HeteroGraphConv({
                rel: dglnn.GATv2Conv(
                    in_feats=hidden_channels*num_heads,
                    out_feats=hidden_channels if l < n_layers - 1 else out_channels,
                    num_heads=num_heads if l < n_layers - 1 else 1,  # TODO: change to num_heads always
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    activation=activation if l < n_layers - 1 else None
                )
                for rel in rel_names}, aggregate='sum')
            for l in range(n_layers)
        ])

    def forward(self, g):
        feats = []
        feats.append(self.embedding_type(g.ndata['type'].float()))
        feats.append(self.embedding_pos(g.ndata['pos']/170.))  # NOTE: this should be 180 because I remove the boundary walls!
        feats.append(self.embedding_color(g.ndata['color']/255.))
        feats.append(self.embedding_shape(g.ndata['shape'].float()))
        h = {'obj': self.combine(torch.cat(feats, dim=1))}
        for l, conv in enumerate(self.layers):
            h = conv(g, h)
            if l != len(self.layers) - 1:
                h = {k: v.flatten(1) for k, v in h.items()}
            else:
                h = {k: v.mean(1) for k, v in h.items()}
        with g.local_scope():
            g.ndata['h'] = h['obj']
            out = dgl.mean_nodes(g, 'h')
        return out


# -------------------------------------------------------------------------------------------


class RGCNv2(nn.Module):
    # multi-layer GNN for one single feature
    def __init__(
        self,
        hidden_channels,
        out_channels,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ReLU()
    ):
        super().__init__()
        self.embedding_type = nn.Linear(9, int(hidden_channels))
        self.embedding_pos = nn.Linear(2, int(hidden_channels))
        self.embedding_color = nn.Linear(3, int(hidden_channels))
        self.embedding_shape = nn.Linear(18, int(hidden_channels))
        self.combine = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*4, hidden_channels)
        )
        self.layers = nn.ModuleList([
            dglnn.RelGraphConv(
                in_feat=hidden_channels,
                out_feat=hidden_channels,
                num_rels=len(rel_names),
                regularizer=None,
                num_bases=None,
                bias=True,
                activation=activation,
                self_loop=True,
                dropout=dropout,
                layer_norm=False
            )
            for _ in range(n_layers-1)])
        self.layers.append(
            dglnn.RelGraphConv(
                in_feat=hidden_channels,
                out_feat=out_channels,
                num_rels=len(rel_names),
                regularizer=None,
                num_bases=None,
                bias=True,
                activation=activation,
                self_loop=True,
                dropout=dropout,
                layer_norm=False
            )
        )

    def forward(self, g):
        # RelGraphConv works on a homogeneous graph with integer edge-type ids,
        # so convert the heterograph and carry the node features over.
        g = dgl.to_homogeneous(g, ndata=['type', 'pos', 'color', 'shape'])
        feats = []
        feats.append(self.embedding_type(g.ndata['type'].float()))
        feats.append(self.embedding_pos(g.ndata['pos']/170.))
        feats.append(self.embedding_color(g.ndata['color']/255.))
        feats.append(self.embedding_shape(g.ndata['shape'].float()))
        h = self.combine(torch.cat(feats, dim=1))
        etypes = g.edata[dgl.ETYPE]  # per-edge relation ids produced by to_homogeneous
        for conv in self.layers:
            h = conv(g, h, etypes)
        with g.local_scope():
            g.ndata['h'] = h
            out = dgl.mean_nodes(g, 'h')
        return out


# -------------------------------------------------------------------------------------------
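

# Sketch of the heterogeneous-to-homogeneous conversion RGCNv2 relies on
# (illustrative relation names; the conversion and the dgl.ETYPE edge field
# are standard DGL): per-edge integer relation ids feed RelGraphConv.
def _to_homogeneous_demo():
    g = dgl.heterograph({
        ('obj', 'left_of', 'obj'): (torch.tensor([0]), torch.tensor([1])),
        ('obj', 'above', 'obj'): (torch.tensor([1]), torch.tensor([2])),
    })
    hg = dgl.to_homogeneous(g)
    return hg.edata[dgl.ETYPE]  # tensor of relation ids, one per edge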


class RGATv3Agent(nn.Module):
    # multi-layer GNN for one single feature
    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
        self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
        self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
        self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
        self.combine = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads*4, hidden_channels*num_heads)
        )
        self.layers = nn.ModuleList([
            dglnn.HeteroGraphConv({
                rel: dglnn.GATv2Conv(
                    in_feats=hidden_channels*num_heads,
                    out_feats=hidden_channels if l < n_layers - 1 else out_channels,
                    num_heads=num_heads if l < n_layers - 1 else 1,  # TODO: change to num_heads always
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    activation=activation if l < n_layers - 1 else None
                )
                for rel in rel_names}, aggregate='sum')
            for l in range(n_layers)
        ])

    def forward(self, g):
        # Channel 0 of the 'type' one-hot marks the agent node.
        agent_mask = g.ndata['type'][:, 0] == 1
        feats = []
        feats.append(self.embedding_type(g.ndata['type'].float()))
        feats.append(self.embedding_pos(g.ndata['pos']/200.))  # note: normalized by 200 here, unlike the 170 used elsewhere
        feats.append(self.embedding_color(g.ndata['color']/255.))
        feats.append(self.embedding_shape(g.ndata['shape'].float()))
        h = {'obj': self.combine(torch.cat(feats, dim=1))}
        for l, conv in enumerate(self.layers):
            h = conv(g, h)
            if l != len(self.layers) - 1:
                h = {k: v.flatten(1) for k, v in h.items()}
            else:
                h = {k: v.mean(1) for k, v in h.items()}
        with g.local_scope():
            g.ndata['h'] = h['obj']
            # Assumes one agent node per graph, so the agent rows align with mean_nodes.
            out = g.ndata['h'][agent_mask, :]
            ctx = dgl.mean_nodes(g, 'h')
        return out + ctx


# -------------------------------------------------------------------------------------------
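

# Agent-node convention used by the *Agent variants (taken from the masks in
# the code above): channel 0 of the 9-way 'type' one-hot marks the agent.
def _agent_mask_demo():
    t = F.one_hot(torch.tensor([0, 3, 5]), num_classes=9)  # node 0 is the agent
    mask = t[:, 0] == 1
    return mask  # tensor([True, False, False])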


class RGATv4(nn.Module):
    # multi-layer GNN for one single feature
    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
        self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
        self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
        self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
        self.combine_attr = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads*3, hidden_channels*num_heads)
        )
        self.combine_pos = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads*2, hidden_channels*num_heads)
        )
        self.layers = nn.ModuleList([
            dglnn.HeteroGraphConv({
                rel: dglnn.GATv2Conv(
                    in_feats=hidden_channels*num_heads,
                    out_feats=hidden_channels if l < n_layers - 1 else out_channels,
                    num_heads=num_heads if l < n_layers - 1 else 1,  # TODO: change to num_heads always
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    share_weights=False,
                    activation=activation if l < n_layers - 1 else None
                )
                for rel in rel_names}, aggregate='sum')
            for l in range(n_layers)
        ])

    def forward(self, g):
        attr = []
        attr.append(self.embedding_type(g.ndata['type'].float()))
        pos = self.embedding_pos(g.ndata['pos']/170.)
        attr.append(self.embedding_color(g.ndata['color']/255.))
        attr.append(self.embedding_shape(g.ndata['shape'].float()))
        combined_attr = self.combine_attr(torch.cat(attr, dim=1))
        h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
        for l, conv in enumerate(self.layers):
            h = conv(g, h)
            if l != len(self.layers) - 1:
                h = {k: v.flatten(1) for k, v in h.items()}
            else:
                h = {k: v.mean(1) for k, v in h.items()}
        with g.local_scope():
            g.ndata['h'] = h['obj']
            out = dgl.mean_nodes(g, 'h')
        return out


# -------------------------------------------------------------------------------------------


class RGATv4Norm(nn.Module):
    # multi-layer GNN for one single feature
    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
        self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
        self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
        self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
        self.combine_attr = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads*3, hidden_channels*num_heads)
        )
        self.combine_pos = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads*2, hidden_channels*num_heads)
        )
        self.layers = nn.ModuleList([
            dglnn.HeteroGraphConv({
                rel: dglnn.GATv2Conv(
                    in_feats=hidden_channels*num_heads,
                    out_feats=hidden_channels if l < n_layers - 1 else out_channels,
                    num_heads=num_heads if l < n_layers - 1 else 1,  # TODO: change to num_heads always
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    activation=activation if l < n_layers - 1 else None
                )
                for rel in rel_names}, aggregate='sum')
            for l in range(n_layers)
        ])
        # One graph norm per layer, sized to that layer's (head-flattened) output.
        self.norms = nn.ModuleList([
            Norm(
                norm_type='gn',
                hidden_dim=hidden_channels*num_heads if l < n_layers - 1 else out_channels
            )
            for l in range(n_layers)
        ])

    def forward(self, g):
        attr = []
        attr.append(self.embedding_type(g.ndata['type'].float()))
        pos = self.embedding_pos(g.ndata['pos']/170.)
        attr.append(self.embedding_color(g.ndata['color']/255.))
        attr.append(self.embedding_shape(g.ndata['shape'].float()))
        combined_attr = self.combine_attr(torch.cat(attr, dim=1))
        h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
        for l, conv in enumerate(self.layers):
            h = conv(g, h)
            if l != len(self.layers) - 1:
                h = {k: v.flatten(1) for k, v in h.items()}
            else:
                h = {k: v.mean(1) for k, v in h.items()}
            h = {k: self.norms[l](g, v) for k, v in h.items()}
        with g.local_scope():
            g.ndata['h'] = h['obj']
            out = dgl.mean_nodes(g, 'h')
        return out


# -------------------------------------------------------------------------------------------


class RGATv3Norm(nn.Module):
    # multi-layer GNN for one single feature
    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
        self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
        self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
        self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
        self.combine = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads*4, hidden_channels*num_heads)
        )
        self.layers = nn.ModuleList([
            dglnn.HeteroGraphConv({
                rel: dglnn.GATv2Conv(
                    in_feats=hidden_channels*num_heads,
                    out_feats=hidden_channels if l < n_layers - 1 else out_channels,
                    num_heads=num_heads if l < n_layers - 1 else 1,  # TODO: change to num_heads always
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    activation=activation if l < n_layers - 1 else None
                )
                for rel in rel_names}, aggregate='sum')
            for l in range(n_layers)
        ])
        self.norms = nn.ModuleList([
            Norm(
                norm_type='gn',
                hidden_dim=hidden_channels*num_heads if l < n_layers - 1 else out_channels
            )
            for l in range(n_layers)
        ])

    def forward(self, g):
        feats = []
        feats.append(self.embedding_type(g.ndata['type'].float()))
        feats.append(self.embedding_pos(g.ndata['pos']/170.))
        feats.append(self.embedding_color(g.ndata['color']/255.))
        feats.append(self.embedding_shape(g.ndata['shape'].float()))
        h = {'obj': self.combine(torch.cat(feats, dim=1))}
        for l, conv in enumerate(self.layers):
            h = conv(g, h)
            if l != len(self.layers) - 1:
                h = {k: v.flatten(1) for k, v in h.items()}
            else:
                h = {k: v.mean(1) for k, v in h.items()}
            h = {k: self.norms[l](g, v) for k, v in h.items()}
        with g.local_scope():
            g.ndata['h'] = h['obj']
            out = dgl.mean_nodes(g, 'h')
        return out


# -------------------------------------------------------------------------------------------


class RGATv4Agent(nn.Module):
    # multi-layer GNN for one single feature
    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
        self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
        self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
        self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
        self.combine_attr = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads*3, hidden_channels*num_heads)
        )
        self.combine_pos = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads*2, hidden_channels*num_heads)
        )
        self.layers = nn.ModuleList([
            dglnn.HeteroGraphConv({
                rel: dglnn.GATv2Conv(
                    in_feats=hidden_channels*num_heads,
                    out_feats=hidden_channels if l < n_layers - 1 else out_channels,
                    num_heads=num_heads if l < n_layers - 1 else 1,  # TODO: change to num_heads always
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    activation=activation if l < n_layers - 1 else None
                )
                for rel in rel_names}, aggregate='sum')
            for l in range(n_layers)
        ])
        self.combine_agent_context = nn.Linear(out_channels*2, out_channels)

    def forward(self, g):
        # Channel 0 of the 'type' one-hot marks the agent node (one per graph).
        agent_mask = g.ndata['type'][:, 0] == 1
        attr = []
        attr.append(self.embedding_type(g.ndata['type'].float()))
        pos = self.embedding_pos(g.ndata['pos']/170.)
        attr.append(self.embedding_color(g.ndata['color']/255.))
        attr.append(self.embedding_shape(g.ndata['shape'].float()))
        combined_attr = self.combine_attr(torch.cat(attr, dim=1))
        h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
        for l, conv in enumerate(self.layers):
            h = conv(g, h)
            if l != len(self.layers) - 1:
                h = {k: v.flatten(1) for k, v in h.items()}
            else:
                h = {k: v.mean(1) for k, v in h.items()}
        with g.local_scope():
            g.ndata['h'] = h['obj']
            h_a = g.ndata['h'][agent_mask, :]
            # Pool the scene context over a copy of the graph with the agent removed.
            g_no_agent = copy.deepcopy(g)
            g_no_agent.remove_nodes([i for i, x in enumerate(agent_mask) if x])
            h_g = dgl.mean_nodes(g_no_agent, 'h')
            out = self.combine_agent_context(torch.cat((h_a, h_g), dim=1))
        return out


# -------------------------------------------------------------------------------------------


class RGCNv4(nn.Module):
    # multi-layer GNN for one single feature
    def __init__(
        self,
        hidden_channels,
        out_channels,
        rel_names,
        n_layers,
        activation=nn.ELU(),
    ):
        super().__init__()
        self.embedding_type = nn.Linear(9, int(hidden_channels))
        self.embedding_pos = nn.Linear(2, int(hidden_channels))
        self.embedding_color = nn.Linear(3, int(hidden_channels))
        self.embedding_shape = nn.Linear(18, int(hidden_channels))
        self.combine_attr = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*3, hidden_channels)
        )
        self.combine_pos = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*2, hidden_channels)
        )
        self.layers = nn.ModuleList([
            dglnn.HeteroGraphConv({
                rel: dglnn.GraphConv(
                    in_feats=hidden_channels,
                    out_feats=hidden_channels if l < n_layers - 1 else out_channels,
                    activation=activation if l < n_layers - 1 else None
                )
                for rel in rel_names}, aggregate='sum')
            for l in range(n_layers)
        ])

    def forward(self, g):
        attr = []
        attr.append(self.embedding_type(g.ndata['type'].float()))
        pos = self.embedding_pos(g.ndata['pos']/170.)
        attr.append(self.embedding_color(g.ndata['color']/255.))
        attr.append(self.embedding_shape(g.ndata['shape'].float()))
        combined_attr = self.combine_attr(torch.cat(attr, dim=1))
        h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
        for conv in self.layers:
            h = conv(g, h)
        with g.local_scope():
            g.ndata['h'] = h['obj']
            out = dgl.mean_nodes(g, 'h')
        return out


# -------------------------------------------------------------------------------------------


class RGATv5(nn.Module):

    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.num_heads = num_heads
        self.embedding_type = nn.Linear(9, hidden_channels*num_heads)
        self.embedding_pos = nn.Linear(2, hidden_channels*num_heads)
        self.embedding_color = nn.Linear(3, hidden_channels*num_heads)
        self.embedding_shape = nn.Linear(18, hidden_channels*num_heads)
        self.combine = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads*4, hidden_channels*num_heads)
        )
        # Scores the four attribute groups; the softmaxed weights gate each
        # group's slice of the concatenated embedding in forward().
        self.attention = nn.Linear(hidden_channels*num_heads*4, 4)
        self.layers = nn.ModuleList([
            dglnn.HeteroGraphConv({
                rel: dglnn.GATv2Conv(
                    in_feats=hidden_channels*num_heads,
                    out_feats=hidden_channels if l < n_layers - 1 else out_channels,
                    num_heads=num_heads,
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    activation=activation if l < n_layers - 1 else None
                )
                for rel in rel_names}, aggregate='sum')
            for l in range(n_layers)
        ])

    def forward(self, g):
        feats = []
        feats.append(self.embedding_type(g.ndata['type'].float()))
        feats.append(self.embedding_pos(g.ndata['pos']/170.))
        feats.append(self.embedding_color(g.ndata['color']/255.))
        feats.append(self.embedding_shape(g.ndata['shape'].float()))
        h = torch.cat(feats, dim=1)
        feat_attn = F.softmax(self.attention(h), dim=1)
        h = h * feat_attn.repeat_interleave(self.hidden_channels*self.num_heads, dim=1)
        h_in = self.combine(h)
        h = {'obj': h_in}
        for conv in self.layers:
            h = conv(g, h)
            h = {k: v.flatten(1) for k, v in h.items()}
            #if l != len(self.layers) - 1:
            #    h = {k: v.flatten(1) for k, v in h.items()}
            #else:
            #    h = {k: v.mean(1) for k, v in h.items()}
            # Residual back to the input embedding; the last layer's flattened
            # width matches h_in only when out_channels == hidden_channels.
            h = {k: v + h_in for k, v in h.items()}
        with g.local_scope():
            g.ndata['h'] = h['obj']
            out = dgl.mean_nodes(g, 'h')
        return out


# -------------------------------------------------------------------------------------------
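

# Sketch of RGATv5's per-attribute gating (shapes only, illustrative sizes):
# a softmax over the 4 attribute groups rescales each group's slice of the
# concatenated embedding via repeat_interleave.
def _feat_gate_demo(hidden_channels=4, num_heads=2):
    h = torch.rand(5, hidden_channels * num_heads * 4)      # 4 concatenated groups
    attn = F.softmax(nn.Linear(h.shape[1], 4)(h), dim=1)    # (5, 4) group weights
    gates = attn.repeat_interleave(hidden_channels * num_heads, dim=1)
    return h * gates                                        # same shape as h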


class RGATv6(nn.Module):

    # RGATv6 = RGATv4 + Global Attention Pooling

    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
        self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
        self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
        self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
        self.combine_attr = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads*3, hidden_channels*num_heads)
        )
        self.combine_pos = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads*2, hidden_channels*num_heads)
        )
        self.layers = nn.ModuleList([
            dglnn.HeteroGraphConv({
                rel: dglnn.GATv2Conv(
                    in_feats=hidden_channels*num_heads,
                    out_feats=hidden_channels if l < n_layers - 1 else out_channels,
                    num_heads=num_heads if l < n_layers - 1 else 1,  # TODO: change to num_heads always
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    activation=activation if l < n_layers - 1 else None
                )
                for rel in rel_names}, aggregate='sum')
            for l in range(n_layers)
        ])
        gate_nn = nn.Linear(out_channels, 1)
        self.gap = dglnn.GlobalAttentionPooling(gate_nn)

    def forward(self, g):
        attr = []
        attr.append(self.embedding_type(g.ndata['type'].float()))
        pos = self.embedding_pos(g.ndata['pos']/170.)
        attr.append(self.embedding_color(g.ndata['color']/255.))
        attr.append(self.embedding_shape(g.ndata['shape'].float()))
        combined_attr = self.combine_attr(torch.cat(attr, dim=1))
        h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
        for l, conv in enumerate(self.layers):
            h = conv(g, h)
            if l != len(self.layers) - 1:
                h = {k: v.flatten(1) for k, v in h.items()}
            else:
                h = {k: v.mean(1) for k, v in h.items()}
        #with g.local_scope():
        #    g.ndata['h'] = h['obj']
        #    out = dgl.mean_nodes(g, 'h')
        out = self.gap(g, h['obj'])
        return out


# -------------------------------------------------------------------------------------------
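

# Global attention pooling in isolation (illustrative graph and sizes): a
# learned gate scores each node and the readout is the gate-weighted sum,
# returning one vector per graph in the batch.
def _gap_demo():
    g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])))
    feat = torch.rand(3, 8)
    gap = dglnn.GlobalAttentionPooling(nn.Linear(8, 1))
    return gap(g, feat)  # (1, 8)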


class RGATv6Agent(nn.Module):

    # RGATv6 = RGATv4 + Global Attention Pooling

    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
        self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
        self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
        self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
        self.combine_attr = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads*3, hidden_channels*num_heads)
        )
        self.combine_pos = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads*2, hidden_channels*num_heads)
        )
        self.layers = nn.ModuleList([
            dglnn.HeteroGraphConv({
                rel: dglnn.GATv2Conv(
                    in_feats=hidden_channels*num_heads,
                    out_feats=hidden_channels if l < n_layers - 1 else out_channels,
                    num_heads=num_heads if l < n_layers - 1 else 1,  # TODO: change to num_heads always
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    activation=activation if l < n_layers - 1 else None
                )
                for rel in rel_names}, aggregate='sum')
            for l in range(n_layers)
        ])
        gate_nn = nn.Linear(out_channels, 1)
        self.gap = dglnn.GlobalAttentionPooling(gate_nn)
        self.combine_agent_context = nn.Linear(out_channels*2, out_channels)

    def forward(self, g):
        # Channel 0 of the 'type' one-hot marks the agent node (one per graph).
        agent_mask = g.ndata['type'][:, 0] == 1
        attr = []
        attr.append(self.embedding_type(g.ndata['type'].float()))
        pos = self.embedding_pos(g.ndata['pos']/170.)
        attr.append(self.embedding_color(g.ndata['color']/255.))
        attr.append(self.embedding_shape(g.ndata['shape'].float()))
        combined_attr = self.combine_attr(torch.cat(attr, dim=1))
        h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
        for l, conv in enumerate(self.layers):
            h = conv(g, h)
            if l != len(self.layers) - 1:
                h = {k: v.flatten(1) for k, v in h.items()}
            else:
                h = {k: v.mean(1) for k, v in h.items()}
        with g.local_scope():
            g.ndata['h'] = h['obj']
            h_a = g.ndata['h'][agent_mask, :]
            h_g = g.ndata['h'][~agent_mask, :]
            # Attention-pool the context over a copy of the graph with the agent removed.
            g_no_agent = copy.deepcopy(g)
            g_no_agent.remove_nodes([i for i, x in enumerate(agent_mask) if x])
            h_g = self.gap(g_no_agent, h_g)  # dgl.mean_nodes(g_no_agent, 'h')
            out = self.combine_agent_context(torch.cat((h_a, h_g), dim=1))
        return out


# -------------------------------------------------------------------------------------------