import copy
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl
import dgl.nn.pytorch as dglnn

sys.path.append('/projects/bortoletto/irene/')
from tom.norm import Norm


class RSAGEv4(nn.Module):
    # multi-layer GNN for one single feature
    def __init__(
        self,
        hidden_channels,
        out_channels,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
    ):
        super().__init__()
        self.embedding_type = nn.Linear(9, int(hidden_channels))
        self.embedding_pos = nn.Linear(2, int(hidden_channels))
        self.embedding_color = nn.Linear(3, int(hidden_channels))
        self.embedding_shape = nn.Linear(18, int(hidden_channels))
        self.combine_attr = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*3, hidden_channels)
        )
        self.combine_pos = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*2, hidden_channels)
        )
        self.layers = nn.ModuleList([
            dglnn.HeteroGraphConv({
                rel: dglnn.SAGEConv(
                    in_feats=hidden_channels,
                    out_feats=hidden_channels if l < n_layers - 1 else out_channels,
                    aggregator_type='lstm',
                    feat_drop=dropout,
                    activation=activation if l < n_layers - 1 else None
                ) for rel in rel_names}, aggregate='sum')
            for l in range(n_layers)
        ])

    def forward(self, g):
        attr = []
        attr.append(self.embedding_type(g.ndata['type'].float()))
        pos = self.embedding_pos(g.ndata['pos']/170.)
        attr.append(self.embedding_color(g.ndata['color']/255.))
        attr.append(self.embedding_shape(g.ndata['shape'].float()))
        combined_attr = self.combine_attr(torch.cat(attr, dim=1))
        h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
        for conv in self.layers:
            h = conv(g, h)
        with g.local_scope():
            g.ndata['h'] = h['obj']
            out = dgl.mean_nodes(g, 'h')
        return out

# -------------------------------------------------------------------------------------------
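# The helpers below are illustrative sketches, not part of the original models:
# TOY_RELS and _toy_graph build a hypothetical single-node-type heterograph
# whose feature sizes match the embedding layers above (type: 9-dim one-hot,
# pos: 2-dim coordinates on a ~170-unit grid, color: RGB in [0, 255],
# shape: 18-dim one-hot). _example_rsage is a minimal smoke test.
TOY_RELS = ['is_aligned', 'is_close']  # hypothetical relation names


def _toy_graph(n_nodes=3):
    # One directed cycle per relation, so every node has in-degree > 0
    # (the GATv2Conv layers below reject zero-in-degree nodes by default).
    src = torch.arange(n_nodes)
    dst = torch.roll(src, -1)
    g = dgl.heterograph({('obj', rel, 'obj'): (src, dst) for rel in TOY_RELS})
    g.ndata['type'] = F.one_hot(torch.arange(n_nodes) % 9, 9)
    g.ndata['pos'] = torch.rand(n_nodes, 2) * 170.
    g.ndata['color'] = torch.randint(0, 256, (n_nodes, 3)).float()
    g.ndata['shape'] = F.one_hot(torch.arange(n_nodes) % 18, 18)
    return g


def _example_rsage():
    model = RSAGEv4(hidden_channels=8, out_channels=4, rel_names=TOY_RELS,
                    dropout=0.0, n_layers=2)
    return model(_toy_graph())  # shape (1, 4): one pooled vector per graph

# -------------------------------------------------------------------------------------------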
class RAGNNv4(nn.Module):
    # multi-layer GNN for one single feature
    def __init__(
        self,
        hidden_channels,
        out_channels,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
    ):
        super().__init__()
        self.embedding_type = nn.Linear(9, int(hidden_channels))
        self.embedding_pos = nn.Linear(2, int(hidden_channels))
        self.embedding_color = nn.Linear(3, int(hidden_channels))
        self.embedding_shape = nn.Linear(18, int(hidden_channels))
        self.combine_attr = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*3, hidden_channels)
        )
        self.combine_pos = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*2, hidden_channels)
        )
        # AGNNConv has no projection weights, so node features keep the
        # hidden_channels width through every layer; out_channels, dropout
        # and activation are accepted for interface parity but unused here.
        self.layers = nn.ModuleList([
            dglnn.HeteroGraphConv({
                rel: dglnn.AGNNConv() for rel in rel_names}, aggregate='sum')
            for _ in range(n_layers)
        ])

    def forward(self, g):
        attr = []
        attr.append(self.embedding_type(g.ndata['type'].float()))
        pos = self.embedding_pos(g.ndata['pos']/170.)
        attr.append(self.embedding_color(g.ndata['color']/255.))
        attr.append(self.embedding_shape(g.ndata['shape'].float()))
        combined_attr = self.combine_attr(torch.cat(attr, dim=1))
        h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
        for conv in self.layers:
            h = conv(g, h)
        with g.local_scope():
            g.ndata['h'] = h['obj']
            out = dgl.mean_nodes(g, 'h')
        return out

# -------------------------------------------------------------------------------------------

class RGATv2(nn.Module):
    # multi-layer GNN for one single feature
    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        self.embedding_type = nn.Embedding(9, int(hidden_channels*num_heads/4))
        self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads/4))
        self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads/4))
        self.embedding_shape = nn.Embedding(18, int(hidden_channels*num_heads/4))
        self.combine = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads, hidden_channels*num_heads)
        )
        self.layers = nn.ModuleList([
            dglnn.HeteroGraphConv({
                rel: dglnn.GATv2Conv(
                    in_feats=hidden_channels*num_heads,
                    out_feats=hidden_channels if l < n_layers - 1 else out_channels,
                    num_heads=num_heads if l < n_layers - 1 else 1,
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    activation=activation if l < n_layers - 1 else None
                ) for rel in rel_names}, aggregate='sum')
            for l in range(n_layers)
        ])

    def forward(self, g):
        feats = []
        feats.append(self.embedding_type(torch.argmax(g.ndata['type'], dim=1)))
        feats.append(self.embedding_pos(g.ndata['pos']/170.))
        feats.append(self.embedding_color(g.ndata['color']/255.))
        feats.append(self.embedding_shape(torch.argmax(g.ndata['shape'], dim=1)))
        h = {'obj': self.combine(torch.cat(feats, dim=1))}
        for l, conv in enumerate(self.layers):
            h = conv(g, h)
            if l != len(self.layers) - 1:
                h = {k: v.flatten(1) for k, v in h.items()}
            else:
                h = {k: v.mean(1) for k, v in h.items()}
        with g.local_scope():
            g.ndata['h'] = h['obj']
            out = dgl.mean_nodes(g, 'h')
        return out

# -------------------------------------------------------------------------------------------
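# Shape bookkeeping for the multi-head GATv2 stacks used throughout this file,
# restated as a standalone sketch with hypothetical sizes: hidden layers emit
# (N, num_heads, hidden) and are flattened to (N, num_heads*hidden) before the
# next layer, while the single-head output layer emits (N, 1, out) and is
# averaged over the head dimension.
def _head_shapes(n_nodes=5, num_heads=4, hidden=8, out=6):
    inter = torch.rand(n_nodes, num_heads, hidden)
    assert inter.flatten(1).shape == (n_nodes, num_heads * hidden)
    final = torch.rand(n_nodes, 1, out)
    assert final.mean(1).shape == (n_nodes, out)

# -------------------------------------------------------------------------------------------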
class RGATv3(nn.Module):
    # multi-layer GNN for one single feature
    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
        self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
        self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
        self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
        self.combine = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads*4, hidden_channels*num_heads)
        )
        self.layers = nn.ModuleList([
            dglnn.HeteroGraphConv({
                rel: dglnn.GATv2Conv(
                    in_feats=hidden_channels*num_heads,
                    out_feats=hidden_channels if l < n_layers - 1 else out_channels,
                    num_heads=num_heads if l < n_layers - 1 else 1,  # TODO: change to num_heads always
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    activation=activation if l < n_layers - 1 else None
                ) for rel in rel_names}, aggregate='sum')
            for l in range(n_layers)
        ])

    def forward(self, g):
        feats = []
        feats.append(self.embedding_type(g.ndata['type'].float()))
        feats.append(self.embedding_pos(g.ndata['pos']/170.))  # NOTE: this should be 180 because I remove the boundary walls!
        feats.append(self.embedding_color(g.ndata['color']/255.))
        feats.append(self.embedding_shape(g.ndata['shape'].float()))
        h = {'obj': self.combine(torch.cat(feats, dim=1))}
        for l, conv in enumerate(self.layers):
            h = conv(g, h)
            if l != len(self.layers) - 1:
                h = {k: v.flatten(1) for k, v in h.items()}
            else:
                h = {k: v.mean(1) for k, v in h.items()}
        with g.local_scope():
            g.ndata['h'] = h['obj']
            out = dgl.mean_nodes(g, 'h')
        return out

# -------------------------------------------------------------------------------------------

class RGCNv2(nn.Module):
    # multi-layer GNN for one single feature
    def __init__(
        self,
        hidden_channels,
        out_channels,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ReLU()
    ):
        super().__init__()
        self.embedding_type = nn.Linear(9, int(hidden_channels))
        self.embedding_pos = nn.Linear(2, int(hidden_channels))
        self.embedding_color = nn.Linear(3, int(hidden_channels))
        self.embedding_shape = nn.Linear(18, int(hidden_channels))
        self.combine = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*4, hidden_channels)
        )
        self.layers = nn.ModuleList([
            dglnn.RelGraphConv(
                in_feat=hidden_channels,
                out_feat=hidden_channels,
                num_rels=len(rel_names),
                regularizer=None,
                num_bases=None,
                bias=True,
                activation=activation,
                self_loop=True,
                dropout=dropout,
                layer_norm=False
            ) for _ in range(n_layers-1)])
        self.layers.append(
            dglnn.RelGraphConv(
                in_feat=hidden_channels,
                out_feat=out_channels,
                num_rels=len(rel_names),
                regularizer=None,
                num_bases=None,
                bias=True,
                activation=activation,
                self_loop=True,
                dropout=dropout,
                layer_norm=False
            )
        )

    def forward(self, g):
        # RelGraphConv runs on a homogeneous graph with integer edge type ids;
        # listing the feature names keeps ndata through the conversion.
        g = dgl.to_homogeneous(g, ndata=['type', 'pos', 'color', 'shape'])
        feats = []
        feats.append(self.embedding_type(g.ndata['type'].float()))
        feats.append(self.embedding_pos(g.ndata['pos']/170.))
        feats.append(self.embedding_color(g.ndata['color']/255.))
        feats.append(self.embedding_shape(g.ndata['shape'].float()))
        h = self.combine(torch.cat(feats, dim=1))
        for conv in self.layers:
            h = conv(g, h, g.edata[dgl.ETYPE])
        with g.local_scope():
            g.ndata['h'] = h
            out = dgl.mean_nodes(g, 'h')
        return out

# -------------------------------------------------------------------------------------------
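# Sketch of the homogeneous conversion RGCNv2 depends on, assuming node
# features must survive it: dgl.to_homogeneous() drops ndata unless the
# feature names are listed explicitly, and RelGraphConv expects integer edge
# type ids (g.edata[dgl.ETYPE]) rather than edge type name strings.
def _to_homogeneous_example(hetero_g):
    homo_g = dgl.to_homogeneous(hetero_g, ndata=['type', 'pos', 'color', 'shape'])
    etypes = homo_g.edata[dgl.ETYPE]  # 1D int tensor, one type id per edge
    return homo_g, etypes

# -------------------------------------------------------------------------------------------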
class RGATv3Agent(nn.Module):
    # multi-layer GNN for one single feature
    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
        self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
        self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
        self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
        self.combine = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads*4, hidden_channels*num_heads)
        )
        self.layers = nn.ModuleList([
            dglnn.HeteroGraphConv({
                rel: dglnn.GATv2Conv(
                    in_feats=hidden_channels*num_heads,
                    out_feats=hidden_channels if l < n_layers - 1 else out_channels,
                    num_heads=num_heads if l < n_layers - 1 else 1,  # TODO: change to num_heads always
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    activation=activation if l < n_layers - 1 else None
                ) for rel in rel_names}, aggregate='sum')
            for l in range(n_layers)
        ])

    def forward(self, g):
        agent_mask = g.ndata['type'][:, 0] == 1
        feats = []
        feats.append(self.embedding_type(g.ndata['type'].float()))
        feats.append(self.embedding_pos(g.ndata['pos']/200.))  # NOTE: scale differs from the /170. used elsewhere
        feats.append(self.embedding_color(g.ndata['color']/255.))
        feats.append(self.embedding_shape(g.ndata['shape'].float()))
        h = {'obj': self.combine(torch.cat(feats, dim=1))}
        for l, conv in enumerate(self.layers):
            h = conv(g, h)
            if l != len(self.layers) - 1:
                h = {k: v.flatten(1) for k, v in h.items()}
            else:
                h = {k: v.mean(1) for k, v in h.items()}
        with g.local_scope():
            g.ndata['h'] = h['obj']
            out = g.ndata['h'][agent_mask, :]
            ctx = dgl.mean_nodes(g, 'h')
        return out + ctx

# -------------------------------------------------------------------------------------------

class RGATv4(nn.Module):
    # multi-layer GNN for one single feature
    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
        self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
        self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
        self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
        self.combine_attr = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads*3, hidden_channels*num_heads)
        )
        self.combine_pos = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads*2, hidden_channels*num_heads)
        )
        self.layers = nn.ModuleList([
            dglnn.HeteroGraphConv({
                rel: dglnn.GATv2Conv(
                    in_feats=hidden_channels*num_heads,
                    out_feats=hidden_channels if l < n_layers - 1 else out_channels,
                    num_heads=num_heads if l < n_layers - 1 else 1,  # TODO: change to num_heads always
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    share_weights=False,
                    activation=activation if l < n_layers - 1 else None
                ) for rel in rel_names}, aggregate='sum')
            for l in range(n_layers)
        ])

    def forward(self, g):
        attr = []
        attr.append(self.embedding_type(g.ndata['type'].float()))
        pos = self.embedding_pos(g.ndata['pos']/170.)
        attr.append(self.embedding_color(g.ndata['color']/255.))
        attr.append(self.embedding_shape(g.ndata['shape'].float()))
        combined_attr = self.combine_attr(torch.cat(attr, dim=1))
        h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
        for l, conv in enumerate(self.layers):
            h = conv(g, h)
            if l != len(self.layers) - 1:
                h = {k: v.flatten(1) for k, v in h.items()}
            else:
                h = {k: v.mean(1) for k, v in h.items()}
        with g.local_scope():
            g.ndata['h'] = h['obj']
            out = dgl.mean_nodes(g, 'h')
        return out

# -------------------------------------------------------------------------------------------
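# Minimal usage sketch for RGATv4, reusing the hypothetical toy-graph helper
# above; batching two copies shows the per-graph mean readout.
def _example_rgatv4():
    model = RGATv4(hidden_channels=8, out_channels=4, num_heads=2,
                   rel_names=TOY_RELS, dropout=0.0, n_layers=2)
    batched = dgl.batch([_toy_graph(), _toy_graph()])
    return model(batched)  # shape (2, 4): one pooled vector per graph

# -------------------------------------------------------------------------------------------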
class RGATv4Norm(nn.Module):
    # multi-layer GNN for one single feature
    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
        self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
        self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
        self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
        self.combine_attr = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads*3, hidden_channels*num_heads)
        )
        self.combine_pos = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads*2, hidden_channels*num_heads)
        )
        self.layers = nn.ModuleList([
            dglnn.HeteroGraphConv({
                rel: dglnn.GATv2Conv(
                    in_feats=hidden_channels*num_heads,
                    out_feats=hidden_channels if l < n_layers - 1 else out_channels,
                    num_heads=num_heads if l < n_layers - 1 else 1,  # TODO: change to num_heads always
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    activation=activation if l < n_layers - 1 else None
                ) for rel in rel_names}, aggregate='sum')
            for l in range(n_layers)
        ])
        self.norms = nn.ModuleList([
            Norm(
                norm_type='gn',
                hidden_dim=hidden_channels*num_heads if l < n_layers - 1 else out_channels
            ) for l in range(n_layers)
        ])

    def forward(self, g):
        attr = []
        attr.append(self.embedding_type(g.ndata['type'].float()))
        pos = self.embedding_pos(g.ndata['pos']/170.)
        attr.append(self.embedding_color(g.ndata['color']/255.))
        attr.append(self.embedding_shape(g.ndata['shape'].float()))
        combined_attr = self.combine_attr(torch.cat(attr, dim=1))
        h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
        for l, conv in enumerate(self.layers):
            h = conv(g, h)
            if l != len(self.layers) - 1:
                h = {k: v.flatten(1) for k, v in h.items()}
            else:
                h = {k: v.mean(1) for k, v in h.items()}
            h = {k: self.norms[l](g, v) for k, v in h.items()}
        with g.local_scope():
            g.ndata['h'] = h['obj']
            out = dgl.mean_nodes(g, 'h')
        return out

# -------------------------------------------------------------------------------------------

class RGATv3Norm(nn.Module):
    # multi-layer GNN for one single feature
    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
        self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
        self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
        self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
        self.combine = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads*4, hidden_channels*num_heads)
        )
        self.layers = nn.ModuleList([
            dglnn.HeteroGraphConv({
                rel: dglnn.GATv2Conv(
                    in_feats=hidden_channels*num_heads,
                    out_feats=hidden_channels if l < n_layers - 1 else out_channels,
                    num_heads=num_heads if l < n_layers - 1 else 1,  # TODO: change to num_heads always
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    activation=activation if l < n_layers - 1 else None
                ) for rel in rel_names}, aggregate='sum')
            for l in range(n_layers)
        ])
        self.norms = nn.ModuleList([
            Norm(
                norm_type='gn',
                hidden_dim=hidden_channels*num_heads if l < n_layers - 1 else out_channels
            ) for l in range(n_layers)
        ])

    def forward(self, g):
        feats = []
        feats.append(self.embedding_type(g.ndata['type'].float()))
        feats.append(self.embedding_pos(g.ndata['pos']/170.))
        feats.append(self.embedding_color(g.ndata['color']/255.))
        feats.append(self.embedding_shape(g.ndata['shape'].float()))
        h = {'obj': self.combine(torch.cat(feats, dim=1))}
        for l, conv in enumerate(self.layers):
            h = conv(g, h)
            if l != len(self.layers) - 1:
                h = {k: v.flatten(1) for k, v in h.items()}
            else:
                h = {k: v.mean(1) for k, v in h.items()}
            h = {k: self.norms[l](g, v) for k, v in h.items()}
        with g.local_scope():
            g.ndata['h'] = h['obj']
            out = dgl.mean_nodes(g, 'h')
        return out

# -------------------------------------------------------------------------------------------
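# The Norm module is imported from tom.norm and its internals are not shown
# here. As a rough, explicitly hypothetical sketch of a GraphNorm-style
# operation (per-graph mean subtraction and variance scaling), matching the
# (graph, node_features) call signature used above:
def _graphnorm_sketch(g, h, eps=1e-5):
    with g.local_scope():
        g.ndata['x'] = h
        mean = dgl.broadcast_nodes(g, dgl.mean_nodes(g, 'x'))
        g.ndata['sq'] = (h - mean) ** 2
        var = dgl.broadcast_nodes(g, dgl.mean_nodes(g, 'sq'))
    return (h - mean) / torch.sqrt(var + eps)

# -------------------------------------------------------------------------------------------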
class RGATv4Agent(nn.Module):
    # multi-layer GNN for one single feature
    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
        self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
        self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
        self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
        self.combine_attr = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads*3, hidden_channels*num_heads)
        )
        self.combine_pos = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads*2, hidden_channels*num_heads)
        )
        self.layers = nn.ModuleList([
            dglnn.HeteroGraphConv({
                rel: dglnn.GATv2Conv(
                    in_feats=hidden_channels*num_heads,
                    out_feats=hidden_channels if l < n_layers - 1 else out_channels,
                    num_heads=num_heads if l < n_layers - 1 else 1,  # TODO: change to num_heads always
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    activation=activation if l < n_layers - 1 else None
                ) for rel in rel_names}, aggregate='sum')
            for l in range(n_layers)
        ])
        self.combine_agent_context = nn.Linear(out_channels*2, out_channels)

    def forward(self, g):
        agent_mask = g.ndata['type'][:, 0] == 1
        attr = []
        attr.append(self.embedding_type(g.ndata['type'].float()))
        pos = self.embedding_pos(g.ndata['pos']/170.)
        attr.append(self.embedding_color(g.ndata['color']/255.))
        attr.append(self.embedding_shape(g.ndata['shape'].float()))
        combined_attr = self.combine_attr(torch.cat(attr, dim=1))
        h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
        for l, conv in enumerate(self.layers):
            h = conv(g, h)
            if l != len(self.layers) - 1:
                h = {k: v.flatten(1) for k, v in h.items()}
            else:
                h = {k: v.mean(1) for k, v in h.items()}
        with g.local_scope():
            g.ndata['h'] = h['obj']
            h_a = g.ndata['h'][agent_mask, :]
            g_no_agent = copy.deepcopy(g)
            g_no_agent.remove_nodes([i for i, x in enumerate(agent_mask) if x])
            h_g = dgl.mean_nodes(g_no_agent, 'h')
            out = self.combine_agent_context(torch.cat((h_a, h_g), dim=1))
        return out

# -------------------------------------------------------------------------------------------
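# Usage sketch for RGATv4Agent, assuming the toy-graph helper above, where
# node 0 carries type index 0 and therefore plays the agent role; the model
# fuses that node's embedding with a mean over the remaining nodes.
def _example_rgatv4_agent():
    model = RGATv4Agent(hidden_channels=8, out_channels=4, num_heads=2,
                        rel_names=TOY_RELS, dropout=0.0, n_layers=2)
    return model(_toy_graph())  # shape (1, 4): agent embedding fused with context

# -------------------------------------------------------------------------------------------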
class RGCNv4(nn.Module):
    # multi-layer GNN for one single feature
    def __init__(
        self,
        hidden_channels,
        out_channels,
        rel_names,
        n_layers,
        activation=nn.ELU(),
    ):
        super().__init__()
        self.embedding_type = nn.Linear(9, int(hidden_channels))
        self.embedding_pos = nn.Linear(2, int(hidden_channels))
        self.embedding_color = nn.Linear(3, int(hidden_channels))
        self.embedding_shape = nn.Linear(18, int(hidden_channels))
        self.combine_attr = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*3, hidden_channels)
        )
        self.combine_pos = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*2, hidden_channels)
        )
        self.layers = nn.ModuleList([
            dglnn.HeteroGraphConv({
                rel: dglnn.GraphConv(
                    in_feats=hidden_channels,
                    out_feats=hidden_channels if l < n_layers - 1 else out_channels,
                    activation=activation if l < n_layers - 1 else None
                ) for rel in rel_names}, aggregate='sum')
            for l in range(n_layers)
        ])

    def forward(self, g):
        attr = []
        attr.append(self.embedding_type(g.ndata['type'].float()))
        pos = self.embedding_pos(g.ndata['pos']/170.)
        attr.append(self.embedding_color(g.ndata['color']/255.))
        attr.append(self.embedding_shape(g.ndata['shape'].float()))
        combined_attr = self.combine_attr(torch.cat(attr, dim=1))
        h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
        for conv in self.layers:
            h = conv(g, h)
        with g.local_scope():
            g.ndata['h'] = h['obj']
            out = dgl.mean_nodes(g, 'h')
        return out

# -------------------------------------------------------------------------------------------

class RGATv5(nn.Module):
    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.num_heads = num_heads
        self.embedding_type = nn.Linear(9, hidden_channels*num_heads)
        self.embedding_pos = nn.Linear(2, hidden_channels*num_heads)
        self.embedding_color = nn.Linear(3, hidden_channels*num_heads)
        self.embedding_shape = nn.Linear(18, hidden_channels*num_heads)
        self.combine = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads*4, hidden_channels*num_heads)
        )
        self.attention = nn.Linear(hidden_channels*num_heads*4, 4)
        self.layers = nn.ModuleList([
            dglnn.HeteroGraphConv({
                rel: dglnn.GATv2Conv(
                    in_feats=hidden_channels*num_heads,
                    out_feats=hidden_channels if l < n_layers - 1 else out_channels,
                    num_heads=num_heads,
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    activation=activation if l < n_layers - 1 else None
                ) for rel in rel_names}, aggregate='sum')
            for l in range(n_layers)
        ])

    def forward(self, g):
        feats = []
        feats.append(self.embedding_type(g.ndata['type'].float()))
        feats.append(self.embedding_pos(g.ndata['pos']/170.))
        feats.append(self.embedding_color(g.ndata['color']/255.))
        feats.append(self.embedding_shape(g.ndata['shape'].float()))
        h = torch.cat(feats, dim=1)
        feat_attn = F.softmax(self.attention(h), dim=1)
        h = h * feat_attn.repeat_interleave(self.hidden_channels*self.num_heads, dim=1)
        h_in = self.combine(h)
        h = {'obj': h_in}
        for conv in self.layers:
            h = conv(g, h)
            # every layer keeps num_heads heads, so flattening restores the
            # hidden_channels*num_heads width for the skip connection below
            h = {k: v.flatten(1) for k, v in h.items()}
            # NOTE: the residual with h_in assumes out_channels == hidden_channels
            h = {k: v + h_in for k, v in h.items()}
        with g.local_scope():
            g.ndata['h'] = h['obj']
            out = dgl.mean_nodes(g, 'h')
        return out

# -------------------------------------------------------------------------------------------
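# The per-feature attention from RGATv5.forward in isolation: a softmax over
# the four feature groups (type, pos, color, shape) yields one weight per
# group, expanded to the group's width with repeat_interleave before gating.
# Sizes and the fresh Linear are hypothetical, for illustration only.
def _feature_attention(h, n_groups=4, group_dim=16):
    attn = F.softmax(nn.Linear(n_groups * group_dim, n_groups)(h), dim=1)
    return h * attn.repeat_interleave(group_dim, dim=1)

# -------------------------------------------------------------------------------------------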
class RGATv6(nn.Module):
    # RGATv6 = RGATv4 + Global Attention Pooling
    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
        self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
        self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
        self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
        self.combine_attr = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads*3, hidden_channels*num_heads)
        )
        self.combine_pos = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads*2, hidden_channels*num_heads)
        )
        self.layers = nn.ModuleList([
            dglnn.HeteroGraphConv({
                rel: dglnn.GATv2Conv(
                    in_feats=hidden_channels*num_heads,
                    out_feats=hidden_channels if l < n_layers - 1 else out_channels,
                    num_heads=num_heads if l < n_layers - 1 else 1,  # TODO: change to num_heads always
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    activation=activation if l < n_layers - 1 else None
                ) for rel in rel_names}, aggregate='sum')
            for l in range(n_layers)
        ])
        gate_nn = nn.Linear(out_channels, 1)
        self.gap = dglnn.GlobalAttentionPooling(gate_nn)

    def forward(self, g):
        attr = []
        attr.append(self.embedding_type(g.ndata['type'].float()))
        pos = self.embedding_pos(g.ndata['pos']/170.)
        attr.append(self.embedding_color(g.ndata['color']/255.))
        attr.append(self.embedding_shape(g.ndata['shape'].float()))
        combined_attr = self.combine_attr(torch.cat(attr, dim=1))
        h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
        for l, conv in enumerate(self.layers):
            h = conv(g, h)
            if l != len(self.layers) - 1:
                h = {k: v.flatten(1) for k, v in h.items()}
            else:
                h = {k: v.mean(1) for k, v in h.items()}
        # attention-weighted readout instead of the dgl.mean_nodes used above
        out = self.gap(g, h['obj'])
        return out

# -------------------------------------------------------------------------------------------

class RGATv6Agent(nn.Module):
    # RGATv6 = RGATv4 + Global Attention Pooling
    def __init__(
        self,
        hidden_channels,
        out_channels,
        num_heads,
        rel_names,
        dropout,
        n_layers,
        activation=nn.ELU(),
        residual=False
    ):
        super().__init__()
        self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
        self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
        self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
        self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
        self.combine_attr = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads*3, hidden_channels*num_heads)
        )
        self.combine_pos = nn.Sequential(
            nn.ReLU(),
            nn.Linear(hidden_channels*num_heads*2, hidden_channels*num_heads)
        )
        self.layers = nn.ModuleList([
            dglnn.HeteroGraphConv({
                rel: dglnn.GATv2Conv(
                    in_feats=hidden_channels*num_heads,
                    out_feats=hidden_channels if l < n_layers - 1 else out_channels,
                    num_heads=num_heads if l < n_layers - 1 else 1,  # TODO: change to num_heads always
                    feat_drop=dropout,
                    attn_drop=dropout,
                    residual=residual,
                    activation=activation if l < n_layers - 1 else None
                ) for rel in rel_names}, aggregate='sum')
            for l in range(n_layers)
        ])
        gate_nn = nn.Linear(out_channels, 1)
        self.gap = dglnn.GlobalAttentionPooling(gate_nn)
        self.combine_agent_context = nn.Linear(out_channels*2, out_channels)

    def forward(self, g):
        agent_mask = g.ndata['type'][:, 0] == 1
        attr = []
        attr.append(self.embedding_type(g.ndata['type'].float()))
        pos = self.embedding_pos(g.ndata['pos']/170.)
        attr.append(self.embedding_color(g.ndata['color']/255.))
        attr.append(self.embedding_shape(g.ndata['shape'].float()))
        combined_attr = self.combine_attr(torch.cat(attr, dim=1))
        h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
        for l, conv in enumerate(self.layers):
            h = conv(g, h)
            if l != len(self.layers) - 1:
                h = {k: v.flatten(1) for k, v in h.items()}
            else:
                h = {k: v.mean(1) for k, v in h.items()}
        with g.local_scope():
            g.ndata['h'] = h['obj']
            h_a = g.ndata['h'][agent_mask, :]
            h_g = g.ndata['h'][~agent_mask, :]
            g_no_agent = copy.deepcopy(g)
            g_no_agent.remove_nodes([i for i, x in enumerate(agent_mask) if x])
            # attention pooling over the agent-free graph;
            # dgl.mean_nodes(g_no_agent, 'h') is the non-attentive alternative
            h_g = self.gap(g_no_agent, h_g)
            out = self.combine_agent_context(torch.cat((h_a, h_g), dim=1))
        return out

# -------------------------------------------------------------------------------------------
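# Optional smoke test tying the hypothetical example helpers above together;
# it checks output shapes only, no training.
if __name__ == '__main__':
    print('RSAGEv4     :', _example_rsage().shape)
    print('RGATv4      :', _example_rgatv4().shape)
    print('RGATv4Agent :', _example_rgatv4_agent().shape)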