# IRENE/tom/gnn.py
# 2024-02-01 15:40:47 +01:00
#
# 877 lines
# 34 KiB
# Python

import dgl.nn.pytorch as dglnn
import torch.nn as nn
import torch.nn.functional as F
import torch
import dgl
import sys
import copy
from wandb import agent
sys.path.append('/projects/bortoletto/irene/')
from tom.norm import Norm
class RSAGEv4(nn.Module):
    """Relational GraphSAGE encoder that pools a graph into one vector.

    Node attributes are read from ``g.ndata`` ('type': 9 columns, 'pos': 2,
    'color': 3, 'shape': 18), embedded by separate linear layers, fused, run
    through a stack of per-relation ``SAGEConv`` layers (LSTM aggregator,
    summed across relations) and mean-pooled with ``dgl.mean_nodes``.

    Args:
        hidden_channels: Width of the node embeddings and hidden layers.
        out_channels: Width produced by the last graph layer.
        rel_names: Edge-type names; every layer holds one ``SAGEConv`` per name.
        dropout: ``feat_drop`` rate handed to each ``SAGEConv``.
        n_layers: Number of graph layers.
        activation: Activation of every layer except the last (which has none).
    """

    def __init__(
        self,
        hidden_channels,
        out_channels,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
    ):
        super().__init__()
        embed_dim = int(hidden_channels)
        self.embedding_type = nn.Linear(9, embed_dim)
        self.embedding_pos = nn.Linear(2, embed_dim)
        self.embedding_color = nn.Linear(3, embed_dim)
        self.embedding_shape = nn.Linear(18, embed_dim)
        # Fuses the three appearance embeddings (type, color, shape).
        self.combine_attr = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels * 3, hidden_channels),
        )
        # Fuses the position embedding with the fused appearance vector.
        self.combine_pos = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels * 2, hidden_channels),
        )
        self.layers = nn.ModuleList()
        for depth in range(n_layers):
            inner = depth < n_layers - 1
            per_relation = {}
            for rel in rel_names:
                per_relation[rel] = dglnn.SAGEConv(
                    in_feats=hidden_channels,
                    out_feats=hidden_channels if inner else out_channels,
                    aggregator_type='lstm',
                    feat_drop=dropout,
                    activation=activation if inner else None,
                )
            self.layers.append(dglnn.HeteroGraphConv(per_relation, aggregate='sum'))

    def forward(self, g):
        """Encode ``g`` and return the mean of the final 'obj' node states.

        Args:
            g: DGL graph whose nodes are of type 'obj' and carry the
                'type', 'pos', 'color' and 'shape' features.

        Returns:
            The result of ``dgl.mean_nodes`` over the last layer's output.
        """
        type_emb = self.embedding_type(g.ndata['type'].float())
        # Raw positions / colours are rescaled by fixed constants before embedding.
        pos_emb = self.embedding_pos(g.ndata['pos'] / 170.)
        color_emb = self.embedding_color(g.ndata['color'] / 255.)
        shape_emb = self.embedding_shape(g.ndata['shape'].float())
        appearance = self.combine_attr(
            torch.cat([type_emb, color_emb, shape_emb], dim=1)
        )
        h = {'obj': self.combine_pos(torch.cat((pos_emb, appearance), dim=1))}
        for conv in self.layers:
            h = conv(g, h)
        # local_scope keeps the temporary 'h' field from leaking into g.
        with g.local_scope():
            g.ndata['h'] = h['obj']
            return dgl.mean_nodes(g, 'h')
# -------------------------------------------------------------------------------------------
class RAGNNv4(nn.Module):
    """Relational AGNN encoder that pools a graph into one vector.

    Same input embedding scheme as the other "v4" encoders (separate linear
    embeddings for 'type', 'pos', 'color', 'shape', fused in two steps),
    followed by ``n_layers`` blocks of per-relation ``AGNNConv`` summed across
    relations, and mean pooling over nodes.

    Args:
        hidden_channels: Width of the node embeddings fed to the graph layers.
        out_channels: Accepted for signature parity; not used by this class.
        rel_names: Edge-type names; every layer holds one ``AGNNConv`` per name.
        dropout: Accepted for signature parity; not used by this class.
        n_layers: Number of graph layers.
        activation: Accepted for signature parity; not used by this class.
    """

    def __init__(
        self,
        hidden_channels,
        out_channels,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
    ):
        super().__init__()
        embed_dim = int(hidden_channels)
        self.embedding_type = nn.Linear(9, embed_dim)
        self.embedding_pos = nn.Linear(2, embed_dim)
        self.embedding_color = nn.Linear(3, embed_dim)
        self.embedding_shape = nn.Linear(18, embed_dim)
        # Fuses the three appearance embeddings (type, color, shape).
        self.combine_attr = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels * 3, hidden_channels),
        )
        # Fuses the position embedding with the fused appearance vector.
        self.combine_pos = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels * 2, hidden_channels),
        )
        self.layers = nn.ModuleList()
        for _ in range(n_layers):
            per_relation = {}
            for rel in rel_names:
                # AGNNConv is built with its defaults; it takes no width arguments here.
                per_relation[rel] = dglnn.AGNNConv()
            self.layers.append(dglnn.HeteroGraphConv(per_relation, aggregate='sum'))

    def forward(self, g):
        """Encode ``g`` and return the mean of the final 'obj' node states.

        Args:
            g: DGL graph whose nodes are of type 'obj' and carry the
                'type', 'pos', 'color' and 'shape' features.

        Returns:
            The result of ``dgl.mean_nodes`` over the last layer's output.
        """
        type_emb = self.embedding_type(g.ndata['type'].float())
        pos_emb = self.embedding_pos(g.ndata['pos'] / 170.)
        color_emb = self.embedding_color(g.ndata['color'] / 255.)
        shape_emb = self.embedding_shape(g.ndata['shape'].float())
        appearance = self.combine_attr(
            torch.cat([type_emb, color_emb, shape_emb], dim=1)
        )
        h = {'obj': self.combine_pos(torch.cat((pos_emb, appearance), dim=1))}
        for conv in self.layers:
            h = conv(g, h)
        # local_scope keeps the temporary 'h' field from leaking into g.
        with g.local_scope():
            g.ndata['h'] = h['obj']
            return dgl.mean_nodes(g, 'h')
# -------------------------------------------------------------------------------------------
class RGATv2(nn.Module):
    """Relational GATv2 encoder with lookup-table embeddings for type/shape.

    'type' and 'shape' are turned into class indices with ``argmax`` over
    dim 1 and looked up in ``nn.Embedding`` tables; 'pos' and 'color' go
    through linear layers. Each of the four embeddings has width
    ``int(hidden_channels * num_heads / 4)``, so their concatenation only
    matches the ``combine`` layer when ``hidden_channels * num_heads`` is
    divisible by 4.

    Args:
        hidden_channels: Per-head width of the hidden graph layers.
        out_channels: Width produced by the last graph layer.
        num_heads: Attention heads in every layer except the last (1 head).
        rel_names: Edge-type names; every layer holds one ``GATv2Conv`` per name.
        dropout: Used for both ``feat_drop`` and ``attn_drop``.
        n_layers: Number of graph layers.
        activation: Activation of every layer except the last (which has none).
        residual: Forwarded to ``GATv2Conv``.
    """

    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        width = hidden_channels * num_heads
        quarter = int(width / 4)
        self.embedding_type = nn.Embedding(9, quarter)
        self.embedding_pos = nn.Linear(2, quarter)
        self.embedding_color = nn.Linear(3, quarter)
        self.embedding_shape = nn.Embedding(18, quarter)
        self.combine = nn.Sequential(
            nn.ReLU(),
            nn.Linear(width, width),
        )
        self.layers = nn.ModuleList()
        for depth in range(n_layers):
            inner = depth < n_layers - 1
            per_relation = {}
            for rel in rel_names:
                per_relation[rel] = dglnn.GATv2Conv(
                    in_feats=width,
                    out_feats=hidden_channels if inner else out_channels,
                    num_heads=num_heads if inner else 1,
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    activation=activation if inner else None,
                )
            self.layers.append(dglnn.HeteroGraphConv(per_relation, aggregate='sum'))

    def forward(self, g):
        """Encode ``g`` and return the mean of the final 'obj' node states.

        Args:
            g: DGL graph whose nodes are of type 'obj' and carry the
                'type', 'pos', 'color' and 'shape' features.

        Returns:
            The result of ``dgl.mean_nodes`` over the last layer's output.
        """
        # argmax converts the (presumably one-hot) rows into embedding indices.
        type_idx = torch.argmax(g.ndata['type'], dim=1)
        shape_idx = torch.argmax(g.ndata['shape'], dim=1)
        parts = [
            self.embedding_type(type_idx),
            self.embedding_pos(g.ndata['pos'] / 170.),
            self.embedding_color(g.ndata['color'] / 255.),
            self.embedding_shape(shape_idx),
        ]
        h = {'obj': self.combine(torch.cat(parts, dim=1))}
        last = len(self.layers) - 1
        for depth, conv in enumerate(self.layers):
            h = conv(g, h)
            if depth == last:
                # Final layer: average over the head axis.
                h = {k: v.mean(1) for k, v in h.items()}
            else:
                # Hidden layers: concatenate the heads into one feature axis.
                h = {k: v.flatten(1) for k, v in h.items()}
        with g.local_scope():
            g.ndata['h'] = h['obj']
            return dgl.mean_nodes(g, 'h')
# -------------------------------------------------------------------------------------------
class RGATv3(nn.Module):
    """Relational GATv2 encoder with linear embeddings and a single fusion step.

    The four node features ('type', 'pos', 'color', 'shape') are each
    projected to ``hidden_channels * num_heads`` dimensions, concatenated and
    fused by ``combine``; the result is processed by per-relation
    ``GATv2Conv`` layers and mean-pooled over nodes.

    Args:
        hidden_channels: Per-head width of the hidden graph layers.
        out_channels: Width produced by the last graph layer.
        num_heads: Attention heads in every layer except the last (1 head).
        rel_names: Edge-type names; every layer holds one ``GATv2Conv`` per name.
        dropout: Used for both ``feat_drop`` and ``attn_drop``.
        n_layers: Number of graph layers.
        activation: Activation of every layer except the last (which has none).
        residual: Forwarded to ``GATv2Conv``.
    """

    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        width = hidden_channels * num_heads
        self.embedding_type = nn.Linear(9, int(width))
        self.embedding_pos = nn.Linear(2, int(width))
        self.embedding_color = nn.Linear(3, int(width))
        self.embedding_shape = nn.Linear(18, int(width))
        self.combine = nn.Sequential(
            nn.ReLU(),
            nn.Linear(width * 4, width),
        )
        self.layers = nn.ModuleList()
        for depth in range(n_layers):
            inner = depth < n_layers - 1
            per_relation = {}
            for rel in rel_names:
                per_relation[rel] = dglnn.GATv2Conv(
                    in_feats=width,
                    out_feats=hidden_channels if inner else out_channels,
                    # TODO: change to num_heads always
                    num_heads=num_heads if inner else 1,
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    activation=activation if inner else None,
                )
            self.layers.append(dglnn.HeteroGraphConv(per_relation, aggregate='sum'))

    def forward(self, g):
        """Encode ``g`` and return the mean of the final 'obj' node states.

        Args:
            g: DGL graph whose nodes are of type 'obj' and carry the
                'type', 'pos', 'color' and 'shape' features.

        Returns:
            The result of ``dgl.mean_nodes`` over the last layer's output.
        """
        parts = [
            self.embedding_type(g.ndata['type'].float()),
            # NOTE: this should be 180 because I remove the boundary walls!
            self.embedding_pos(g.ndata['pos'] / 170.),
            self.embedding_color(g.ndata['color'] / 255.),
            self.embedding_shape(g.ndata['shape'].float()),
        ]
        h = {'obj': self.combine(torch.cat(parts, dim=1))}
        last = len(self.layers) - 1
        for depth, conv in enumerate(self.layers):
            h = conv(g, h)
            if depth == last:
                # Final layer: average over the head axis.
                h = {k: v.mean(1) for k, v in h.items()}
            else:
                # Hidden layers: concatenate the heads into one feature axis.
                h = {k: v.flatten(1) for k, v in h.items()}
        with g.local_scope():
            g.ndata['h'] = h['obj']
            return dgl.mean_nodes(g, 'h')
# -------------------------------------------------------------------------------------------
class RGCNv2(nn.Module):
    """Relational GCN encoder built on ``RelGraphConv``.

    The four node features ('type', 'pos', 'color', 'shape') are embedded by
    separate linear layers and fused by ``combine``. The input graph is
    converted to a homogeneous graph with integer edge-type ids, processed by
    ``n_layers`` ``RelGraphConv`` layers and mean-pooled over nodes.

    Args:
        hidden_channels: Width of the node embeddings and hidden layers.
        out_channels: Width produced by the last graph layer.
        rel_names: Edge-type names; only their count is used (``num_rels``).
        dropout: Dropout rate handed to each ``RelGraphConv``.
        n_layers: Total number of graph layers (``n_layers - 1`` hidden + 1 output).
        activation: Activation applied by every layer, including the last one.
    """

    # Node features that must survive the conversion to a homogeneous graph.
    _NODE_FEATURES = ('type', 'pos', 'color', 'shape')

    def __init__(
        self,
        hidden_channels,
        out_channels,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ReLU()
    ):
        super().__init__()
        self.embedding_type = nn.Linear(9, int(hidden_channels))
        self.embedding_pos = nn.Linear(2, int(hidden_channels))
        self.embedding_color = nn.Linear(3, int(hidden_channels))
        self.embedding_shape = nn.Linear(18, int(hidden_channels))
        self.combine = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*4, hidden_channels)
        )
        self.layers = nn.ModuleList([
            dglnn.RelGraphConv(
                in_feat=hidden_channels,
                out_feat=hidden_channels,
                num_rels=len(rel_names),
                regularizer=None,
                num_bases=None,
                bias=True,
                activation=activation,
                self_loop=True,
                dropout=dropout,
                layer_norm=False
            )
            for _ in range(n_layers-1)])
        # Output layer: same configuration, but projects to out_channels.
        self.layers.append(
            dglnn.RelGraphConv(
                in_feat=hidden_channels,
                out_feat=out_channels,
                num_rels=len(rel_names),
                regularizer=None,
                num_bases=None,
                bias=True,
                activation=activation,
                self_loop=True,
                dropout=dropout,
                layer_norm=False
            )
        )

    def forward(self, g):
        """Encode ``g`` and return the per-graph mean of the final node states.

        Args:
            g: DGL graph carrying the 'type', 'pos', 'color' and 'shape' node
                features. Assumed to have a single node type, so node order is
                unchanged by the homogeneous conversion — TODO confirm.

        Returns:
            The result of ``dgl.mean_nodes`` over the last layer's output.
        """
        # Copy the node features explicitly: to_homogeneous drops them unless
        # they are listed in ``ndata``.
        hg = dgl.to_homogeneous(g, ndata=list(self._NODE_FEATURES))
        # RelGraphConv needs one integer type id per edge, not the list of
        # edge-type names that ``graph.etypes`` returns.
        etypes = hg.edata[dgl.ETYPE]
        feats = [
            self.embedding_type(hg.ndata['type'].float()),
            self.embedding_pos(hg.ndata['pos']/170.),
            self.embedding_color(hg.ndata['color']/255.),
            self.embedding_shape(hg.ndata['shape'].float()),
        ]
        h = self.combine(torch.cat(feats, dim=1))
        for conv in self.layers:
            h = conv(hg, h, etypes)
        # Pool on the original graph so its batch segmentation is used; the
        # converted graph may not carry batch information.
        with g.local_scope():
            g.ndata['h'] = h
            return dgl.mean_nodes(g, 'h')
# -------------------------------------------------------------------------------------------
class RGATv3Agent(nn.Module):
    """RGATv3-style encoder that returns agent state plus graph context.

    Identical in construction to a single-fusion relational GATv2 encoder;
    the readout adds the final states of the agent nodes (rows whose first
    'type' column equals 1) to the mean over all nodes.

    Args:
        hidden_channels: Per-head width of the hidden graph layers.
        out_channels: Width produced by the last graph layer.
        num_heads: Attention heads in every layer except the last (1 head).
        rel_names: Edge-type names; every layer holds one ``GATv2Conv`` per name.
        dropout: Used for both ``feat_drop`` and ``attn_drop``.
        n_layers: Number of graph layers.
        activation: Activation of every layer except the last (which has none).
        residual: Forwarded to ``GATv2Conv``.
    """

    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        width = hidden_channels * num_heads
        self.embedding_type = nn.Linear(9, int(width))
        self.embedding_pos = nn.Linear(2, int(width))
        self.embedding_color = nn.Linear(3, int(width))
        self.embedding_shape = nn.Linear(18, int(width))
        self.combine = nn.Sequential(
            nn.ReLU(),
            nn.Linear(width * 4, width),
        )
        self.layers = nn.ModuleList()
        for depth in range(n_layers):
            inner = depth < n_layers - 1
            per_relation = {}
            for rel in rel_names:
                per_relation[rel] = dglnn.GATv2Conv(
                    in_feats=width,
                    out_feats=hidden_channels if inner else out_channels,
                    # TODO: change to num_heads always
                    num_heads=num_heads if inner else 1,
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    activation=activation if inner else None,
                )
            self.layers.append(dglnn.HeteroGraphConv(per_relation, aggregate='sum'))

    def forward(self, g):
        """Encode ``g`` and return agent-node states plus the graph mean.

        Args:
            g: DGL graph whose nodes are of type 'obj' and carry the
                'type', 'pos', 'color' and 'shape' features.

        Returns:
            ``agent_states + mean_nodes``; the two operands must be
            broadcast-compatible (presumably one agent node per graph —
            TODO confirm).
        """
        # Agents are the nodes whose first 'type' column is 1.
        agent_mask = g.ndata['type'][:, 0] == 1
        parts = [
            self.embedding_type(g.ndata['type'].float()),
            # Unlike the sibling encoders, positions are scaled by 200 here.
            self.embedding_pos(g.ndata['pos'] / 200.),
            self.embedding_color(g.ndata['color'] / 255.),
            self.embedding_shape(g.ndata['shape'].float()),
        ]
        h = {'obj': self.combine(torch.cat(parts, dim=1))}
        last = len(self.layers) - 1
        for depth, conv in enumerate(self.layers):
            h = conv(g, h)
            if depth == last:
                # Final layer: average over the head axis.
                h = {k: v.mean(1) for k, v in h.items()}
            else:
                # Hidden layers: concatenate the heads into one feature axis.
                h = {k: v.flatten(1) for k, v in h.items()}
        with g.local_scope():
            g.ndata['h'] = h['obj']
            agent_state = g.ndata['h'][agent_mask, :]
            context = dgl.mean_nodes(g, 'h')
        return agent_state + context
# -------------------------------------------------------------------------------------------
class RGATv4(nn.Module):
    """Relational GATv2 encoder with two-stage feature fusion.

    'type', 'color' and 'shape' embeddings are fused first
    (``combine_attr``); the result is then fused with the 'pos' embedding
    (``combine_pos``). The fused node states go through per-relation
    ``GATv2Conv`` layers (``share_weights=False``) and are mean-pooled.

    Args:
        hidden_channels: Per-head width of the hidden graph layers.
        out_channels: Width produced by the last graph layer.
        num_heads: Attention heads in every layer except the last (1 head).
        rel_names: Edge-type names; every layer holds one ``GATv2Conv`` per name.
        dropout: Used for both ``feat_drop`` and ``attn_drop``.
        n_layers: Number of graph layers.
        activation: Activation of every layer except the last (which has none).
        residual: Forwarded to ``GATv2Conv``.
    """

    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        width = hidden_channels * num_heads
        self.embedding_type = nn.Linear(9, int(width))
        self.embedding_pos = nn.Linear(2, int(width))
        self.embedding_color = nn.Linear(3, int(width))
        self.embedding_shape = nn.Linear(18, int(width))
        # Fuses the three appearance embeddings (type, color, shape).
        self.combine_attr = nn.Sequential(
            nn.ReLU(),
            nn.Linear(width * 3, width),
        )
        # Fuses the position embedding with the fused appearance vector.
        self.combine_pos = nn.Sequential(
            nn.ReLU(),
            nn.Linear(width * 2, width),
        )
        self.layers = nn.ModuleList()
        for depth in range(n_layers):
            inner = depth < n_layers - 1
            per_relation = {}
            for rel in rel_names:
                per_relation[rel] = dglnn.GATv2Conv(
                    in_feats=width,
                    out_feats=hidden_channels if inner else out_channels,
                    # TODO: change to num_heads always
                    num_heads=num_heads if inner else 1,
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    share_weights=False,
                    activation=activation if inner else None,
                )
            self.layers.append(dglnn.HeteroGraphConv(per_relation, aggregate='sum'))

    def forward(self, g):
        """Encode ``g`` and return the mean of the final 'obj' node states.

        Args:
            g: DGL graph whose nodes are of type 'obj' and carry the
                'type', 'pos', 'color' and 'shape' features.

        Returns:
            The result of ``dgl.mean_nodes`` over the last layer's output.
        """
        type_emb = self.embedding_type(g.ndata['type'].float())
        pos_emb = self.embedding_pos(g.ndata['pos'] / 170.)
        color_emb = self.embedding_color(g.ndata['color'] / 255.)
        shape_emb = self.embedding_shape(g.ndata['shape'].float())
        appearance = self.combine_attr(
            torch.cat([type_emb, color_emb, shape_emb], dim=1)
        )
        h = {'obj': self.combine_pos(torch.cat((pos_emb, appearance), dim=1))}
        last = len(self.layers) - 1
        for depth, conv in enumerate(self.layers):
            h = conv(g, h)
            if depth == last:
                # Final layer: average over the head axis.
                h = {k: v.mean(1) for k, v in h.items()}
            else:
                # Hidden layers: concatenate the heads into one feature axis.
                h = {k: v.flatten(1) for k, v in h.items()}
        with g.local_scope():
            g.ndata['h'] = h['obj']
            return dgl.mean_nodes(g, 'h')
# -------------------------------------------------------------------------------------------
class RGATv4Norm(nn.Module):
    """Two-stage-fusion relational GATv2 encoder with a ``Norm`` after each layer.

    Construction matches the two-stage fusion scheme (appearance features
    fused first, then combined with position). After every graph layer the
    node states are passed through a ``Norm(norm_type='gn')`` module sized to
    that layer's output width.

    Args:
        hidden_channels: Per-head width of the hidden graph layers.
        out_channels: Width produced by the last graph layer.
        num_heads: Attention heads in every layer except the last (1 head).
        rel_names: Edge-type names; every layer holds one ``GATv2Conv`` per name.
        dropout: Used for both ``feat_drop`` and ``attn_drop``.
        n_layers: Number of graph layers (and of norm modules).
        activation: Activation of every layer except the last (which has none).
        residual: Forwarded to ``GATv2Conv``.
    """

    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        width = hidden_channels * num_heads
        self.embedding_type = nn.Linear(9, int(width))
        self.embedding_pos = nn.Linear(2, int(width))
        self.embedding_color = nn.Linear(3, int(width))
        self.embedding_shape = nn.Linear(18, int(width))
        # Fuses the three appearance embeddings (type, color, shape).
        self.combine_attr = nn.Sequential(
            nn.ReLU(),
            nn.Linear(width * 3, width),
        )
        # Fuses the position embedding with the fused appearance vector.
        self.combine_pos = nn.Sequential(
            nn.ReLU(),
            nn.Linear(width * 2, width),
        )
        self.layers = nn.ModuleList()
        for depth in range(n_layers):
            inner = depth < n_layers - 1
            per_relation = {}
            for rel in rel_names:
                per_relation[rel] = dglnn.GATv2Conv(
                    in_feats=width,
                    out_feats=hidden_channels if inner else out_channels,
                    # TODO: change to num_heads always
                    num_heads=num_heads if inner else 1,
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    activation=activation if inner else None,
                )
            self.layers.append(dglnn.HeteroGraphConv(per_relation, aggregate='sum'))
        # One norm per layer, sized to the (head-merged) output of that layer.
        self.norms = nn.ModuleList()
        for depth in range(n_layers):
            norm_dim = width if depth < n_layers - 1 else out_channels
            self.norms.append(Norm(norm_type='gn', hidden_dim=norm_dim))

    def forward(self, g):
        """Encode ``g`` and return the mean of the final 'obj' node states.

        Args:
            g: DGL graph whose nodes are of type 'obj' and carry the
                'type', 'pos', 'color' and 'shape' features.

        Returns:
            The result of ``dgl.mean_nodes`` over the last normalised layer output.
        """
        type_emb = self.embedding_type(g.ndata['type'].float())
        pos_emb = self.embedding_pos(g.ndata['pos'] / 170.)
        color_emb = self.embedding_color(g.ndata['color'] / 255.)
        shape_emb = self.embedding_shape(g.ndata['shape'].float())
        appearance = self.combine_attr(
            torch.cat([type_emb, color_emb, shape_emb], dim=1)
        )
        h = {'obj': self.combine_pos(torch.cat((pos_emb, appearance), dim=1))}
        last = len(self.layers) - 1
        for depth, conv in enumerate(self.layers):
            h = conv(g, h)
            if depth == last:
                # Final layer: average over the head axis.
                h = {k: v.mean(1) for k, v in h.items()}
            else:
                # Hidden layers: concatenate the heads into one feature axis.
                h = {k: v.flatten(1) for k, v in h.items()}
            norm = self.norms[depth]
            h = {k: norm(g, v) for k, v in h.items()}
        with g.local_scope():
            g.ndata['h'] = h['obj']
            return dgl.mean_nodes(g, 'h')
# -------------------------------------------------------------------------------------------
class RGATv3Norm(nn.Module):
    """Single-fusion relational GATv2 encoder with a ``Norm`` after each layer.

    The four node-feature embeddings are concatenated and fused by
    ``combine``. After every graph layer the node states are passed through a
    ``Norm(norm_type='gn')`` module sized to that layer's output width.

    Args:
        hidden_channels: Per-head width of the hidden graph layers.
        out_channels: Width produced by the last graph layer.
        num_heads: Attention heads in every layer except the last (1 head).
        rel_names: Edge-type names; every layer holds one ``GATv2Conv`` per name.
        dropout: Used for both ``feat_drop`` and ``attn_drop``.
        n_layers: Number of graph layers (and of norm modules).
        activation: Activation of every layer except the last (which has none).
        residual: Forwarded to ``GATv2Conv``.
    """

    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        width = hidden_channels * num_heads
        self.embedding_type = nn.Linear(9, int(width))
        self.embedding_pos = nn.Linear(2, int(width))
        self.embedding_color = nn.Linear(3, int(width))
        self.embedding_shape = nn.Linear(18, int(width))
        self.combine = nn.Sequential(
            nn.ReLU(),
            nn.Linear(width * 4, width),
        )
        self.layers = nn.ModuleList()
        for depth in range(n_layers):
            inner = depth < n_layers - 1
            per_relation = {}
            for rel in rel_names:
                per_relation[rel] = dglnn.GATv2Conv(
                    in_feats=width,
                    out_feats=hidden_channels if inner else out_channels,
                    # TODO: change to num_heads always
                    num_heads=num_heads if inner else 1,
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    activation=activation if inner else None,
                )
            self.layers.append(dglnn.HeteroGraphConv(per_relation, aggregate='sum'))
        # One norm per layer, sized to the (head-merged) output of that layer.
        self.norms = nn.ModuleList()
        for depth in range(n_layers):
            norm_dim = width if depth < n_layers - 1 else out_channels
            self.norms.append(Norm(norm_type='gn', hidden_dim=norm_dim))

    def forward(self, g):
        """Encode ``g`` and return the mean of the final 'obj' node states.

        Args:
            g: DGL graph whose nodes are of type 'obj' and carry the
                'type', 'pos', 'color' and 'shape' features.

        Returns:
            The result of ``dgl.mean_nodes`` over the last normalised layer output.
        """
        parts = [
            self.embedding_type(g.ndata['type'].float()),
            self.embedding_pos(g.ndata['pos'] / 170.),
            self.embedding_color(g.ndata['color'] / 255.),
            self.embedding_shape(g.ndata['shape'].float()),
        ]
        h = {'obj': self.combine(torch.cat(parts, dim=1))}
        last = len(self.layers) - 1
        for depth, conv in enumerate(self.layers):
            h = conv(g, h)
            if depth == last:
                # Final layer: average over the head axis.
                h = {k: v.mean(1) for k, v in h.items()}
            else:
                # Hidden layers: concatenate the heads into one feature axis.
                h = {k: v.flatten(1) for k, v in h.items()}
            norm = self.norms[depth]
            h = {k: norm(g, v) for k, v in h.items()}
        with g.local_scope():
            g.ndata['h'] = h['obj']
            return dgl.mean_nodes(g, 'h')
# -------------------------------------------------------------------------------------------
class RGATv4Agent(nn.Module):
    """Two-stage-fusion relational GATv2 encoder with an agent/context readout.

    After the graph layers, the agent nodes' states (rows whose first 'type'
    column equals 1) are concatenated with the mean state of the remaining
    nodes and projected back to ``out_channels`` by ``combine_agent_context``.

    Args:
        hidden_channels: Per-head width of the hidden graph layers.
        out_channels: Width produced by the last graph layer and by the readout.
        num_heads: Attention heads in every layer except the last (1 head).
        rel_names: Edge-type names; every layer holds one ``GATv2Conv`` per name.
        dropout: Used for both ``feat_drop`` and ``attn_drop``.
        n_layers: Number of graph layers.
        activation: Activation of every layer except the last (which has none).
        residual: Forwarded to ``GATv2Conv``.
    """

    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        width = hidden_channels * num_heads
        self.embedding_type = nn.Linear(9, int(width))
        self.embedding_pos = nn.Linear(2, int(width))
        self.embedding_color = nn.Linear(3, int(width))
        self.embedding_shape = nn.Linear(18, int(width))
        # Fuses the three appearance embeddings (type, color, shape).
        self.combine_attr = nn.Sequential(
            nn.ReLU(),
            nn.Linear(width * 3, width),
        )
        # Fuses the position embedding with the fused appearance vector.
        self.combine_pos = nn.Sequential(
            nn.ReLU(),
            nn.Linear(width * 2, width),
        )
        self.layers = nn.ModuleList()
        for depth in range(n_layers):
            inner = depth < n_layers - 1
            per_relation = {}
            for rel in rel_names:
                per_relation[rel] = dglnn.GATv2Conv(
                    in_feats=width,
                    out_feats=hidden_channels if inner else out_channels,
                    # TODO: change to num_heads always
                    num_heads=num_heads if inner else 1,
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    activation=activation if inner else None,
                )
            self.layers.append(dglnn.HeteroGraphConv(per_relation, aggregate='sum'))
        # Maps [agent state ; context mean] back to out_channels.
        self.combine_agent_context = nn.Linear(out_channels * 2, out_channels)

    def forward(self, g):
        """Encode ``g`` and combine agent states with the non-agent mean.

        Args:
            g: DGL graph whose nodes are of type 'obj' and carry the
                'type', 'pos', 'color' and 'shape' features.

        Returns:
            Output of ``combine_agent_context`` applied to the concatenation
            of agent-node states and the mean over non-agent nodes. The two
            must have the same number of rows (presumably one agent per
            graph — TODO confirm).
        """
        # Agents are the nodes whose first 'type' column is 1.
        agent_mask = g.ndata['type'][:, 0] == 1
        type_emb = self.embedding_type(g.ndata['type'].float())
        pos_emb = self.embedding_pos(g.ndata['pos'] / 170.)
        color_emb = self.embedding_color(g.ndata['color'] / 255.)
        shape_emb = self.embedding_shape(g.ndata['shape'].float())
        appearance = self.combine_attr(
            torch.cat([type_emb, color_emb, shape_emb], dim=1)
        )
        h = {'obj': self.combine_pos(torch.cat((pos_emb, appearance), dim=1))}
        last = len(self.layers) - 1
        for depth, conv in enumerate(self.layers):
            h = conv(g, h)
            if depth == last:
                # Final layer: average over the head axis.
                h = {k: v.mean(1) for k, v in h.items()}
            else:
                # Hidden layers: concatenate the heads into one feature axis.
                h = {k: v.flatten(1) for k, v in h.items()}
        with g.local_scope():
            g.ndata['h'] = h['obj']
            agent_state = g.ndata['h'][agent_mask, :]
            # Work on a copy (taken while 'h' is attached) so that removing
            # the agent nodes does not modify the caller's graph.
            context_graph = copy.deepcopy(g)
            agent_ids = [idx for idx, is_agent in enumerate(agent_mask) if is_agent]
            context_graph.remove_nodes(agent_ids)
            context = dgl.mean_nodes(context_graph, 'h')
        return self.combine_agent_context(torch.cat((agent_state, context), dim=1))
# -------------------------------------------------------------------------------------------
class RGCNv4(nn.Module):
    """Relational GCN encoder (per-relation ``GraphConv``) with two-stage fusion.

    'type', 'color' and 'shape' embeddings are fused first
    (``combine_attr``), then combined with the 'pos' embedding
    (``combine_pos``). The node states go through ``n_layers`` blocks of
    per-relation ``GraphConv`` summed across relations and are mean-pooled.

    Args:
        hidden_channels: Width of the node embeddings and hidden layers.
        out_channels: Width produced by the last graph layer.
        rel_names: Edge-type names; every layer holds one ``GraphConv`` per name.
        n_layers: Number of graph layers.
        activation: Activation of every layer except the last (which has none).
    """

    def __init__(
        self,
        hidden_channels,
        out_channels,
        rel_names,
        n_layers,
        activation=nn.ELU(),
    ):
        super().__init__()
        embed_dim = int(hidden_channels)
        self.embedding_type = nn.Linear(9, embed_dim)
        self.embedding_pos = nn.Linear(2, embed_dim)
        self.embedding_color = nn.Linear(3, embed_dim)
        self.embedding_shape = nn.Linear(18, embed_dim)
        # Fuses the three appearance embeddings (type, color, shape).
        self.combine_attr = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels * 3, hidden_channels),
        )
        # Fuses the position embedding with the fused appearance vector.
        self.combine_pos = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels * 2, hidden_channels),
        )
        self.layers = nn.ModuleList()
        for depth in range(n_layers):
            inner = depth < n_layers - 1
            per_relation = {}
            for rel in rel_names:
                per_relation[rel] = dglnn.GraphConv(
                    in_feats=hidden_channels,
                    out_feats=hidden_channels if inner else out_channels,
                    activation=activation if inner else None,
                )
            self.layers.append(dglnn.HeteroGraphConv(per_relation, aggregate='sum'))

    def forward(self, g):
        """Encode ``g`` and return the mean of the final 'obj' node states.

        Args:
            g: DGL graph whose nodes are of type 'obj' and carry the
                'type', 'pos', 'color' and 'shape' features.

        Returns:
            The result of ``dgl.mean_nodes`` over the last layer's output.
        """
        type_emb = self.embedding_type(g.ndata['type'].float())
        pos_emb = self.embedding_pos(g.ndata['pos'] / 170.)
        color_emb = self.embedding_color(g.ndata['color'] / 255.)
        shape_emb = self.embedding_shape(g.ndata['shape'].float())
        appearance = self.combine_attr(
            torch.cat([type_emb, color_emb, shape_emb], dim=1)
        )
        h = {'obj': self.combine_pos(torch.cat((pos_emb, appearance), dim=1))}
        for conv in self.layers:
            h = conv(g, h)
        # local_scope keeps the temporary 'h' field from leaking into g.
        with g.local_scope():
            g.ndata['h'] = h['obj']
            return dgl.mean_nodes(g, 'h')
# -------------------------------------------------------------------------------------------
class RGATv5(nn.Module):
    """Relational GATv2 encoder with feature-group attention and a skip connection.

    The four node-feature embeddings are concatenated; a 4-way softmax
    (``attention``) yields one weight per feature group, which rescales that
    group's slice before fusion by ``combine``. All graph layers use
    ``num_heads`` heads and their head outputs are concatenated. The fused
    input is added back to the graph-layer output, which requires the
    concatenated output width to match ``hidden_channels * num_heads``.

    Args:
        hidden_channels: Per-head width of the hidden graph layers.
        out_channels: Per-head width of the last graph layer.
        num_heads: Attention heads used by every graph layer.
        rel_names: Edge-type names; every layer holds one ``GATv2Conv`` per name.
        dropout: Used for both ``feat_drop`` and ``attn_drop``.
        n_layers: Number of graph layers.
        activation: Activation of every layer except the last (which has none).
        residual: Forwarded to ``GATv2Conv``.
    """

    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.num_heads = num_heads
        width = hidden_channels * num_heads
        self.embedding_type = nn.Linear(9, width)
        self.embedding_pos = nn.Linear(2, width)
        self.embedding_color = nn.Linear(3, width)
        self.embedding_shape = nn.Linear(18, width)
        self.combine = nn.Sequential(
            nn.ReLU(),
            nn.Linear(width * 4, width),
        )
        # Scores the four feature groups from their concatenated embeddings.
        self.attention = nn.Linear(width * 4, 4)
        self.layers = nn.ModuleList()
        for depth in range(n_layers):
            inner = depth < n_layers - 1
            per_relation = {}
            for rel in rel_names:
                per_relation[rel] = dglnn.GATv2Conv(
                    in_feats=width,
                    out_feats=hidden_channels if inner else out_channels,
                    num_heads=num_heads,
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    activation=activation if inner else None,
                )
            self.layers.append(dglnn.HeteroGraphConv(per_relation, aggregate='sum'))

    def forward(self, g):
        """Encode ``g`` and return the mean of the final 'obj' node states.

        Args:
            g: DGL graph whose nodes are of type 'obj' and carry the
                'type', 'pos', 'color' and 'shape' features.

        Returns:
            The result of ``dgl.mean_nodes`` over the skip-connected output.
        """
        stacked = torch.cat(
            [
                self.embedding_type(g.ndata['type'].float()),
                self.embedding_pos(g.ndata['pos'] / 170.),
                self.embedding_color(g.ndata['color'] / 255.),
                self.embedding_shape(g.ndata['shape'].float()),
            ],
            dim=1,
        )
        group_weights = F.softmax(self.attention(stacked), dim=1)
        # Expand each of the 4 group weights across that group's embedding slice.
        group_width = self.hidden_channels * self.num_heads
        stacked = stacked * group_weights.repeat_interleave(group_width, dim=1)
        h_in = self.combine(stacked)
        h = {'obj': h_in}
        for conv in self.layers:
            h = conv(g, h)
            # Every layer (including the last) concatenates its heads.
            h = {k: v.flatten(1) for k, v in h.items()}
        # NOTE(review): the skip connection is applied once, after the last
        # layer — confirm this matches the intended placement.
        h = {k: v + h_in for k, v in h.items()}
        with g.local_scope():
            g.ndata['h'] = h['obj']
            return dgl.mean_nodes(g, 'h')
# -------------------------------------------------------------------------------------------
class RGATv6(nn.Module):
    """RGATv6 = RGATv4 + Global Attention Pooling.

    Two-stage feature fusion followed by per-relation ``GATv2Conv`` layers;
    the readout is ``GlobalAttentionPooling`` with a linear gate instead of a
    plain mean over nodes.

    Args:
        hidden_channels: Per-head width of the hidden graph layers.
        out_channels: Width produced by the last graph layer (and gate input).
        num_heads: Attention heads in every layer except the last (1 head).
        rel_names: Edge-type names; every layer holds one ``GATv2Conv`` per name.
        dropout: Used for both ``feat_drop`` and ``attn_drop``.
        n_layers: Number of graph layers.
        activation: Activation of every layer except the last (which has none).
        residual: Forwarded to ``GATv2Conv``.
    """

    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        width = hidden_channels * num_heads
        self.embedding_type = nn.Linear(9, int(width))
        self.embedding_pos = nn.Linear(2, int(width))
        self.embedding_color = nn.Linear(3, int(width))
        self.embedding_shape = nn.Linear(18, int(width))
        # Fuses the three appearance embeddings (type, color, shape).
        self.combine_attr = nn.Sequential(
            nn.ReLU(),
            nn.Linear(width * 3, width),
        )
        # Fuses the position embedding with the fused appearance vector.
        self.combine_pos = nn.Sequential(
            nn.ReLU(),
            nn.Linear(width * 2, width),
        )
        self.layers = nn.ModuleList()
        for depth in range(n_layers):
            inner = depth < n_layers - 1
            per_relation = {}
            for rel in rel_names:
                per_relation[rel] = dglnn.GATv2Conv(
                    in_feats=width,
                    out_feats=hidden_channels if inner else out_channels,
                    # TODO: change to num_heads always
                    num_heads=num_heads if inner else 1,
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    activation=activation if inner else None,
                )
            self.layers.append(dglnn.HeteroGraphConv(per_relation, aggregate='sum'))
        # The gate maps each node state to a scalar score for the pooling.
        self.gap = dglnn.GlobalAttentionPooling(nn.Linear(out_channels, 1))

    def forward(self, g):
        """Encode ``g`` and return the attention-pooled 'obj' node states.

        Args:
            g: DGL graph whose nodes are of type 'obj' and carry the
                'type', 'pos', 'color' and 'shape' features.

        Returns:
            Output of ``GlobalAttentionPooling`` over the last layer's output.
        """
        type_emb = self.embedding_type(g.ndata['type'].float())
        pos_emb = self.embedding_pos(g.ndata['pos'] / 170.)
        color_emb = self.embedding_color(g.ndata['color'] / 255.)
        shape_emb = self.embedding_shape(g.ndata['shape'].float())
        appearance = self.combine_attr(
            torch.cat([type_emb, color_emb, shape_emb], dim=1)
        )
        h = {'obj': self.combine_pos(torch.cat((pos_emb, appearance), dim=1))}
        last = len(self.layers) - 1
        for depth, conv in enumerate(self.layers):
            h = conv(g, h)
            if depth == last:
                # Final layer: average over the head axis.
                h = {k: v.mean(1) for k, v in h.items()}
            else:
                # Hidden layers: concatenate the heads into one feature axis.
                h = {k: v.flatten(1) for k, v in h.items()}
        return self.gap(g, h['obj'])
# -------------------------------------------------------------------------------------------
class RGATv6Agent(nn.Module):
    """RGATv6 encoder with an agent/context readout.

    Two-stage feature fusion and per-relation ``GATv2Conv`` layers. The
    readout concatenates the agent nodes' states (rows whose first 'type'
    column equals 1) with the ``GlobalAttentionPooling`` of the remaining
    nodes and projects the result to ``out_channels``.

    Args:
        hidden_channels: Per-head width of the hidden graph layers.
        out_channels: Width produced by the last graph layer and by the readout.
        num_heads: Attention heads in every layer except the last (1 head).
        rel_names: Edge-type names; every layer holds one ``GATv2Conv`` per name.
        dropout: Used for both ``feat_drop`` and ``attn_drop``.
        n_layers: Number of graph layers.
        activation: Activation of every layer except the last (which has none).
        residual: Forwarded to ``GATv2Conv``.
    """

    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        width = hidden_channels * num_heads
        self.embedding_type = nn.Linear(9, int(width))
        self.embedding_pos = nn.Linear(2, int(width))
        self.embedding_color = nn.Linear(3, int(width))
        self.embedding_shape = nn.Linear(18, int(width))
        # Fuses the three appearance embeddings (type, color, shape).
        self.combine_attr = nn.Sequential(
            nn.ReLU(),
            nn.Linear(width * 3, width),
        )
        # Fuses the position embedding with the fused appearance vector.
        self.combine_pos = nn.Sequential(
            nn.ReLU(),
            nn.Linear(width * 2, width),
        )
        self.layers = nn.ModuleList()
        for depth in range(n_layers):
            inner = depth < n_layers - 1
            per_relation = {}
            for rel in rel_names:
                per_relation[rel] = dglnn.GATv2Conv(
                    in_feats=width,
                    out_feats=hidden_channels if inner else out_channels,
                    # TODO: change to num_heads always
                    num_heads=num_heads if inner else 1,
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    activation=activation if inner else None,
                )
            self.layers.append(dglnn.HeteroGraphConv(per_relation, aggregate='sum'))
        # The gate maps each node state to a scalar score for the pooling.
        self.gap = dglnn.GlobalAttentionPooling(nn.Linear(out_channels, 1))
        # Maps [agent state ; pooled context] back to out_channels.
        self.combine_agent_context = nn.Linear(out_channels * 2, out_channels)

    def forward(self, g):
        """Encode ``g`` and combine agent states with the pooled non-agent context.

        Args:
            g: DGL graph whose nodes are of type 'obj' and carry the
                'type', 'pos', 'color' and 'shape' features.

        Returns:
            Output of ``combine_agent_context`` applied to the concatenation
            of agent-node states and the attention-pooled non-agent nodes.
            The two must have the same number of rows (presumably one agent
            per graph — TODO confirm).
        """
        # Agents are the nodes whose first 'type' column is 1.
        agent_mask = g.ndata['type'][:, 0] == 1
        type_emb = self.embedding_type(g.ndata['type'].float())
        pos_emb = self.embedding_pos(g.ndata['pos'] / 170.)
        color_emb = self.embedding_color(g.ndata['color'] / 255.)
        shape_emb = self.embedding_shape(g.ndata['shape'].float())
        appearance = self.combine_attr(
            torch.cat([type_emb, color_emb, shape_emb], dim=1)
        )
        h = {'obj': self.combine_pos(torch.cat((pos_emb, appearance), dim=1))}
        last = len(self.layers) - 1
        for depth, conv in enumerate(self.layers):
            h = conv(g, h)
            if depth == last:
                # Final layer: average over the head axis.
                h = {k: v.mean(1) for k, v in h.items()}
            else:
                # Hidden layers: concatenate the heads into one feature axis.
                h = {k: v.flatten(1) for k, v in h.items()}
        with g.local_scope():
            g.ndata['h'] = h['obj']
            agent_state = g.ndata['h'][agent_mask, :]
            others_state = g.ndata['h'][~agent_mask, :]
            # Work on a copy so that removing the agent nodes does not
            # modify the caller's graph.
            context_graph = copy.deepcopy(g)
            agent_ids = [idx for idx, is_agent in enumerate(agent_mask) if is_agent]
            context_graph.remove_nodes(agent_ids)
            # Attention pooling replaces dgl.mean_nodes(context_graph, 'h').
            context = self.gap(context_graph, others_state)
        return self.combine_agent_context(torch.cat((agent_state, context), dim=1))
# -------------------------------------------------------------------------------------------