This commit is contained in:
Matteo Bortoletto 2024-02-01 15:40:47 +01:00
parent a333481e05
commit de0bea7508
18 changed files with 3150 additions and 2 deletions

View file

@ -1,3 +1,75 @@
# IRENE
<div align="center">
<h1> Neural Reasoning about Agents' Goals, Preferences, and Actions </h1>
Official code of "Neural Reasoning About Agents' Goals, Preferences, and Actions"
**[Matteo Bortoletto][1], &nbsp; [Lei Shi][2], &nbsp; [Andreas Bulling][3]** <br> <br>
**AAAI'24, Vancouver, CA** <br>
**[[Paper][4]]**
</div>
# Citation
If you find our code useful or use it in your own projects, please cite our paper:
```bibtex
@inproceedings{bortoletto2024neural,
author = {Bortoletto, Matteo and Lei, Shi and Bulling, Andreas},
title = {{Neural Reasoning about Agents' Goals, Preferences, and Actions}},
booktitle = {Proc. 38th AAAI Conference on Artificial Intelligence (AAAI)},
year = {2024},
}
```
# Setup
This code is based on the [original implementation][5] of the BIB benchmark.
## Using `virtualenv`
```
python -m virtualenv /path/to/env
source /path/to/env/bin/activate
pip install -r requirements.txt
```
## Using `conda`
```
conda create --name <env_name> python=3.8.10 pip=20.0.2 cudatoolkit=10.2.89
conda activate <env_name>
pip install -r requirements_conda.txt
pip install dgl-cu102 dglgo -f https://data.dgl.ai/wheels/repo.html
```
# Running the code
## Activate the environment
Run `source bibdgl/bin/activate`.
## Index data
This will create the json files with all the indexed frames for each episode in each video.
```
python utils/index_data.py
```
You need to manually set `mode` in the dataset class (in main).
## Generate graphs
This will generate the graphs from the videos:
```
python /utils/build_graphs.py --mode MODE --cpus NUM_CPUS
```
`MODE` can be `train`, `val` or `test`. NOTE: check `utils/build_graphs.py` to make sure you're loading the correct dataset to generate the graphs you want.
## Training
You can use the `gtbc.sh`.
## Testing
Use `run_test_tom.sh`.
# Hardware setup
All models are trained on an NVIDIA Tesla V100-SXM2-32GB GPU.
[1]: https://mattbortoletto.github.io/
[2]: https://perceptualui.org/people/shi/
[3]: https://perceptualui.org/people/bulling/
[4]: https://perceptualui.org/publications/bortoletto24_aaai.pdf
[5]: https://github.com/kanishkg/bib-baselines

9
run_test.sh Normal file
View file

@ -0,0 +1,9 @@
echo 314 e31
CUDA_VISIBLE_DEVICES=1 python test_tom.py \
--model_type graphbcrnn \
--types efficiency_irrational \
--ckpt /projects/bortoletto/icml2023_matteo/wandb/run-20221224_135525-8i1r2aqy/files/bib/8i1r2aqy/checkpoints/epoch\=31-step\=22399.ckpt \
--data_path /datasets/external/bib_evaluation_1_1/graphs/all_tasks \
--process_data 0 \
--surprise_type max

18
run_train.sh Normal file
View file

@ -0,0 +1,18 @@
CUDA_VISIBLE_DEVICES=0 python train_tom.py \
--model_type graphbcrnn \
--types single_object preference instrumental_action \
--data_path /datasets/external/bib_train/graphs/all_tasks/ \
--seed 7 \
--batch_size 32 \
--max_epochs 35 \
--gpus 1 \
--auto_select_gpus True \
--num_workers 2 \
--stochastic_weight_avg True \
--lr 5e-4 \
--check_val_every_n_epoch 1 \
--track_grad_norm 2 \
--gradient_clip_val 10 \
--gnn_type RSAGEv4 \
--state_dim 96 \
--aggregation sum

122
test_tom.py Normal file
View file

@ -0,0 +1,122 @@
from argparse import ArgumentParser
import numpy as np
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
import dgl
from tom.dataset import TestToMnetDGLDataset, collate_function_seq_test
from tom.model import GraphBC_T, GraphBCRNN
def get_z_scores(total, total_expected, total_unexpected):
mean = np.mean(total)
std = np.std(total)
print("Z-Score expected: ",
(np.mean(total_expected) - mean) / std)
print("Z-Score unexpected: ",
(np.mean(total_unexpected) - mean) / std)
parser = ArgumentParser()
parser.add_argument('--model_type', type=str, default='graphbcrnn')
parser.add_argument('--ckpt', type=str, default=None, help='path to checkpoint')
parser.add_argument('--data_path', type=str, default=None, help='path to the data')
parser.add_argument('--process_data', type=int, default=0)
parser.add_argument('--surprise_type', type=str, default='max',
help='surprise type: mean, max. This is used for comparing the plausibility scores of the two test episodes')
parser.add_argument('--types', nargs='+', type=str,
default=[
'preference', 'multi_agent', 'inaccessible_goal',
'efficiency_irrational', 'efficiency_time','efficiency_path',
'instrumental_no_barrier', 'instrumental_blocking_barrier', 'instrumental_inconsequential_barrier'
],
help='types of tasks used for training / testing')
parser.add_argument('--filename', type=str, default='')
args = parser.parse_args()
filename = args.filename
if args.model_type == 'graphbct':
model = GraphBC_T.load_from_checkpoint(args.ckpt)
elif args.model_type == 'graphbcrnn':
model = GraphBCRNN.load_from_checkpoint(args.ckpt)
else:
raise ValueError('Unknown model type.')
device = 'cuda'
model.to(device)
model.eval()
with torch.no_grad():
for t in args.types:
if args.model_type == 'graphbcrnn':
test_dataset = TestToMnetDGLDataset(
path=args.data_path,
task_type=t,
mode='test'
)
test_dataloader = DataLoader(
test_dataset,
batch_size=1,
num_workers=1,
pin_memory=True,
collate_fn=collate_function_seq_test,
shuffle=False
)
count = 0
total, total_expected, total_unexpected = [], [], []
pbar = tqdm(test_dataloader)
for j, batch in enumerate(pbar):
if args.model_type == 'graphbcrnn':
dem_expected_states, dem_expected_actions, dem_expected_lens, \
dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, \
query_expected_frames, target_expected_actions, \
query_unexpected_frames, target_unexpected_actions = batch
dem_expected_states = dem_expected_states.to(device)
dem_expected_actions = dem_expected_actions.to(device)
dem_unexpected_states = dem_unexpected_states.to(device)
dem_unexpected_actions = dem_unexpected_actions.to(device)
target_expected_actions = target_expected_actions.to(device)
target_unexpected_actions = target_unexpected_actions.to(device)
surprise_expected = []
query_expected_frames = dgl.unbatch(query_expected_frames)
for i in range(len(query_expected_frames)):
if args.model_type == 'graphbcrnn':
test_actions, test_actions_pred = model(
[dem_expected_states, dem_expected_actions, dem_expected_lens, query_expected_frames[i].to(device), target_expected_actions[:, i, :]]
)
loss = F.mse_loss(test_actions, test_actions_pred)
surprise_expected.append(loss.cpu().detach().numpy())
mean_expected_surprise = np.mean(surprise_expected)
max_expected_surprise = np.max(surprise_expected)
# calculate the plausibility scores for the unexpected episode
surprise_unexpected = []
query_unexpected_frames = dgl.unbatch(query_unexpected_frames)
for i in range(len(query_unexpected_frames)):
if args.model_type == 'graphbcrnn':
test_actions, test_actions_pred = model(
[dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, query_unexpected_frames[i].to(device), target_unexpected_actions[:, i, :]]
)
loss = F.mse_loss(test_actions, test_actions_pred)
surprise_unexpected.append(loss.cpu().detach().numpy())
mean_unexpected_surprise = np.mean(surprise_unexpected)
max_unexpected_surprise = np.max(surprise_unexpected)
correct_mean = mean_expected_surprise < mean_unexpected_surprise + 0.5 * (mean_expected_surprise == mean_unexpected_surprise)
correct_max = max_expected_surprise < max_unexpected_surprise + 0.5 * (max_expected_surprise == max_unexpected_surprise)
if args.surprise_type == 'max':
count += correct_max
elif args.surprise_type == 'mean':
count += correct_mean
pbar.set_postfix({'accuracy': count/(j+1.), 'type': t})
total_expected.append(max_expected_surprise)
total_unexpected.append(max_unexpected_surprise)
total.append(max_expected_surprise)
total.append(max_unexpected_surprise)
get_z_scores(total, total_expected, total_unexpected)

0
tom/__init__.py Normal file
View file

310
tom/dataset.py Normal file
View file

@ -0,0 +1,310 @@
import pickle as pkl
import os
import torch
import torch.utils.data
import torch.nn.functional as F
import dgl
import random
from dgl.data import DGLDataset
def collate_function_seq(batch):
#dem_frames = torch.stack([item[0] for item in batch])
dem_frames = dgl.batch([item[0] for item in batch])
dem_actions = torch.stack([item[1] for item in batch])
dem_lens = [item[2] for item in batch]
#query_frames = torch.stack([item[3] for item in batch])
query_frames = dgl.batch([item[3] for item in batch])
target_actions = torch.stack([item[4] for item in batch])
return [dem_frames, dem_actions, dem_lens, query_frames, target_actions]
def collate_function_seq_test(batch):
dem_expected_states = dgl.batch([item[0] for item in batch][0])
dem_expected_actions = torch.stack([item[1] for item in batch][0]).unsqueeze(dim=0)
dem_expected_lens = [item[2] for item in batch]
#print(dem_expected_actions.size())
dem_unexpected_states = dgl.batch([item[3] for item in batch][0])
dem_unexpected_actions = torch.stack([item[4] for item in batch][0]).unsqueeze(dim=0)
dem_unexpected_lens = [item[5] for item in batch]
query_expected_frames = dgl.batch([item[6] for item in batch])
target_expected_actions = torch.stack([item[7] for item in batch])
#print(target_expected_actions.size())
query_unexpected_frames = dgl.batch([item[8] for item in batch])
target_unexpected_actions = torch.stack([item[9] for item in batch])
return [
dem_expected_states, dem_expected_actions, dem_expected_lens, \
dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, \
query_expected_frames, target_expected_actions, \
query_unexpected_frames, target_unexpected_actions
]
def collate_function_mental(batch):
dem_frames = dgl.batch([item[0] for item in batch])
dem_actions = torch.stack([item[1] for item in batch])
dem_lens = [item[2] for item in batch]
past_test_frames = dgl.batch([item[3] for item in batch])
past_test_actions = torch.stack([item[4] for item in batch])
past_test_len = [item[5] for item in batch]
query_frames = dgl.batch([item[6] for item in batch])
target_actions = torch.stack([item[7] for item in batch])
return [dem_frames, dem_actions, dem_lens, past_test_frames, past_test_actions, past_test_len, query_frames, target_actions]
class ToMnetDGLDataset(DGLDataset):
"""
Training dataset class.
"""
def __init__(self, path, types=None, mode="train"):
self.path = path
self.types = types
self.mode = mode
print('Mode:', self.mode)
if self.mode == 'train':
if len(self.types) == 4:
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/'
#self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_global/'
#self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_local/'
elif len(self.types) == 3:
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper() + '/'
print(self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper())
elif len(self.types) == 2:
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + '/'
print(self.types[0][0].upper() + self.types[1][0].upper())
elif len(self.types) == 1:
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + '/'
else: raise ValueError('Number of types different from 1 or 4.')
elif self.mode == 'val':
assert len(self.types) == 1
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/' + self.types[0] + '/'
#self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_global/' + self.types[0] + '/'
#self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_local/' + self.types[0] + '/'
else:
raise ValueError
def get_test(self, states, actions):
# now states is a batched graph -> unbatch it, take the len, pick one sub-graph
# randomly and select the corresponding action
frame_graphs = dgl.unbatch(states)
trial_len = len(frame_graphs)
query_idx = random.randint(0, trial_len - 1)
query_graph = frame_graphs[query_idx]
target_action = actions[query_idx]
return query_graph, target_action
def __getitem__(self, idx):
with open(self.path+str(idx)+'.pkl', 'rb') as f:
states, actions, lens, _ = pkl.load(f)
# shuffle
ziplist = list(zip(states, actions, lens))
random.shuffle(ziplist)
states, actions, lens = zip(*ziplist)
# convert tuples to lists
states, actions, lens = [*states], [*actions], [*lens]
# pick last element in the list as test and pick random frame
test_s, test_a = self.get_test(states[-1], actions[-1])
dem_s = states[:-1]
dem_a = actions[:-1]
dem_lens = lens[:-1]
dem_s = dgl.batch(dem_s)
dem_a = torch.stack(dem_a)
return dem_s, dem_a, dem_lens, test_s, test_a
def __len__(self):
return len(os.listdir(self.path))
class TestToMnetDGLDataset(DGLDataset):
"""
Testing dataset class.
"""
def __init__(self, path, task_type=None, mode="test"):
self.path = path
self.type = task_type
self.mode = mode
print('Mode:', self.mode)
if self.mode == 'test':
self.path = self.path + '_dgl_hetero_nobound_4feats/' + self.type + '/'
#self.path = self.path + '_dgl_hetero_nobound_4feats_global/' + self.type + '/'
#self.path = self.path + '_dgl_hetero_nobound_4feats_local/' + self.type + '/'
else:
raise ValueError
def __getitem__(self, idx):
with open(self.path+str(idx)+'.pkl', 'rb') as f:
dem_expected_states, dem_expected_actions, dem_expected_lens, _, \
dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, _, \
query_expected_frames, target_expected_actions, \
query_unexpected_frames, target_unexpected_actions = pkl.load(f)
assert len(dem_expected_states) == 8
assert len(dem_expected_actions) == 8
assert len(dem_expected_lens) == 8
assert len(dem_unexpected_states) == 8
assert len(dem_unexpected_actions) == 8
assert len(dem_unexpected_lens) == 8
assert len(dgl.unbatch(query_expected_frames)) == target_expected_actions.size()[0]
assert len(dgl.unbatch(query_unexpected_frames)) == target_unexpected_actions.size()[0]
# ignore n_nodes
return dem_expected_states, dem_expected_actions, dem_expected_lens, \
dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, \
query_expected_frames, target_expected_actions, \
query_unexpected_frames, target_unexpected_actions
def __len__(self):
return len(os.listdir(self.path))
class ToMnetDGLDatasetUndersample(DGLDataset):
"""
Training dataset class for the behavior cloning mlp model.
"""
def __init__(self, path, types=None, mode="train"):
self.path = path
self.types = types
self.mode = mode
print('Mode:', self.mode)
if self.mode == 'train':
if len(self.types) == 4:
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/'
elif len(self.types) == 3:
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper() + '/'
print(self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper())
elif len(self.types) == 2:
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + '/'
print(self.types[0][0].upper() + self.types[1][0].upper())
elif len(self.types) == 1:
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + '/'
else: raise ValueError('Number of types different from 1 or 4.')
elif self.mode == 'val':
assert len(self.types) == 1
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/' + self.types[0] + '/'
else:
raise ValueError
print('Undersampled dataset!')
def get_test(self, states, actions):
# now states is a batched graph -> unbatch it, take the len, pick one sub-graph
# randomly and select the corresponding action
frame_graphs = dgl.unbatch(states)
trial_len = len(frame_graphs)
query_idx = random.randint(0, trial_len - 1)
query_graph = frame_graphs[query_idx]
target_action = actions[query_idx]
return query_graph, target_action
def __getitem__(self, idx):
idx = idx + 3175
with open(self.path+str(idx)+'.pkl', 'rb') as f:
states, actions, lens, _ = pkl.load(f)
# shuffle
ziplist = list(zip(states, actions, lens))
random.shuffle(ziplist)
states, actions, lens = zip(*ziplist)
# convert tuples to lists
states, actions, lens = [*states], [*actions], [*lens]
# pick last element in the list as test and pick random frame
test_s, test_a = self.get_test(states[-1], actions[-1])
dem_s = states[:-1]
dem_a = actions[:-1]
dem_lens = lens[:-1]
dem_s = dgl.batch(dem_s)
dem_a = torch.stack(dem_a)
return dem_s, dem_a, dem_lens, test_s, test_a
def __len__(self):
return len(os.listdir(self.path)) - 3175
class ToMnetDGLDatasetMental(DGLDataset):
"""
Training dataset class.
"""
def __init__(self, path, types=None, mode="train"):
self.path = path
self.types = types
self.mode = mode
print('Mode:', self.mode)
if self.mode == 'train':
if len(self.types) == 4:
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/'
elif len(self.types) == 3:
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper() + '/'
print(self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper())
elif len(self.types) == 2:
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + '/'
print(self.types[0][0].upper() + self.types[1][0].upper())
elif len(self.types) == 1:
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + '/'
else: raise ValueError('Number of types different from 1 or 4.')
elif self.mode == 'val':
assert len(self.types) == 1
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/' + self.types[0] + '/'
else:
raise ValueError
def get_test(self, states, actions):
"""
return: past_test_graphs, past_test_actions, test_graph, test_action
"""
frame_graphs = dgl.unbatch(states)
trial_len = len(frame_graphs)
query_idx = random.randint(0, trial_len - 1)
test_graph = frame_graphs[query_idx]
test_action = actions[query_idx]
if query_idx > 0:
#past_test_actions = F.pad(past_test_actions, (0,0,0,41-query_idx), 'constant', 0)
if query_idx == 1:
past_test_graphs = frame_graphs[0]
past_test_actions = actions[:query_idx]
past_test_actions = F.pad(past_test_actions, (0,0,0,41-query_idx), 'constant', 0)
return past_test_graphs, past_test_actions, query_idx, test_graph, test_action
else:
past_test_graphs = frame_graphs[:query_idx]
past_test_actions = actions[:query_idx]
past_test_graphs = dgl.batch(past_test_graphs)
past_test_actions = F.pad(past_test_actions, (0,0,0,41-query_idx), 'constant', 0)
return past_test_graphs, past_test_actions, query_idx, test_graph, test_action
else:
test_action_ = F.pad(test_action.unsqueeze(0), (0,0,0,41-1), 'constant', 0)
# NOTE: since there are no past frames, return the test frame and action and query_idx=1
# not sure what would be a better alternative
return test_graph, test_action_, 1, test_graph, test_action
def __getitem__(self, idx):
with open(self.path+str(idx)+'.pkl', 'rb') as f:
states, actions, lens, _ = pkl.load(f)
ziplist = list(zip(states, actions, lens))
random.shuffle(ziplist)
states, actions, lens = zip(*ziplist)
states, actions, lens = [*states], [*actions], [*lens]
past_test_s, past_test_a, past_test_len, test_s, test_a = self.get_test(states[-1], actions[-1])
dem_s = states[:-1]
dem_a = actions[:-1]
dem_lens = lens[:-1]
dem_s = dgl.batch(dem_s)
dem_a = torch.stack(dem_a)
return dem_s, dem_a, dem_lens, past_test_s, past_test_a, past_test_len, test_s, test_a
def __len__(self):
return len(os.listdir(self.path))
# --------------------------------------------------------------------------------------------------------
if __name__ == "__main__":
types = [
'preference', 'multi_agent', 'inaccessible_goal',
'efficiency_irrational', 'efficiency_time','efficiency_path',
'instrumental_no_barrier', 'instrumental_blocking_barrier', 'instrumental_inconsequential_barrier'
]
mental_dataset = ToMnetDGLDatasetMental(
path='/datasets/external/bib_train/graphs/all_tasks/',
types=['instrumental_action'],
mode='train'
)
dem_frames, dem_actions, dem_lens, past_test_frames, past_test_actions, len, test_frame, test_action = mental_dataset.__getitem__(99)
breakpoint()

877
tom/gnn.py Normal file
View file

@ -0,0 +1,877 @@
import dgl.nn.pytorch as dglnn
import torch.nn as nn
import torch.nn.functional as F
import torch
import dgl
import sys
import copy
from wandb import agent
sys.path.append('/projects/bortoletto/irene/')
from tom.norm import Norm
class RSAGEv4(nn.Module):
# multi-layer GNN for one single feature
def __init__(
self,
hidden_channels,
out_channels,
rel_names,
dropout,
n_layers,
activation=nn.ELU(),
):
super().__init__()
self.embedding_type = nn.Linear(9, int(hidden_channels))
self.embedding_pos = nn.Linear(2, int(hidden_channels))
self.embedding_color = nn.Linear(3, int(hidden_channels))
self.embedding_shape = nn.Linear(18, int(hidden_channels))
self.combine_attr = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*3, hidden_channels)
)
self.combine_pos = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*2, hidden_channels)
)
self.layers = nn.ModuleList([
dglnn.HeteroGraphConv({
rel: dglnn.SAGEConv(
in_feats=hidden_channels,
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
aggregator_type='lstm',
feat_drop=dropout,
activation=activation if l < n_layers - 1 else None
)
for rel in rel_names}, aggregate='sum')
for l in range(n_layers)
])
def forward(self, g):
attr = []
attr.append(self.embedding_type(g.ndata['type'].float()))
pos = self.embedding_pos(g.ndata['pos']/170.)
attr.append(self.embedding_color(g.ndata['color']/255.))
attr.append(self.embedding_shape(g.ndata['shape'].float()))
combined_attr = self.combine_attr(torch.cat(attr, dim=1))
h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
for l, conv in enumerate(self.layers):
h = conv(g, h)
with g.local_scope():
g.ndata['h'] = h['obj']
out = dgl.mean_nodes(g, 'h')
return out
# -------------------------------------------------------------------------------------------
class RAGNNv4(nn.Module):
# multi-layer GNN for one single feature
def __init__(
self,
hidden_channels,
out_channels,
rel_names,
dropout,
n_layers,
activation=nn.ELU(),
):
super().__init__()
self.embedding_type = nn.Linear(9, int(hidden_channels))
self.embedding_pos = nn.Linear(2, int(hidden_channels))
self.embedding_color = nn.Linear(3, int(hidden_channels))
self.embedding_shape = nn.Linear(18, int(hidden_channels))
self.combine_attr = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*3, hidden_channels)
)
self.combine_pos = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*2, hidden_channels)
)
self.layers = nn.ModuleList([
dglnn.HeteroGraphConv({
rel: dglnn.AGNNConv()
for rel in rel_names}, aggregate='sum')
for l in range(n_layers)
])
def forward(self, g):
attr = []
attr.append(self.embedding_type(g.ndata['type'].float()))
pos = self.embedding_pos(g.ndata['pos']/170.)
attr.append(self.embedding_color(g.ndata['color']/255.))
attr.append(self.embedding_shape(g.ndata['shape'].float()))
combined_attr = self.combine_attr(torch.cat(attr, dim=1))
h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
for l, conv in enumerate(self.layers):
h = conv(g, h)
with g.local_scope():
g.ndata['h'] = h['obj']
out = dgl.mean_nodes(g, 'h')
return out
# -------------------------------------------------------------------------------------------
class RGATv2(nn.Module):
# multi-layer GNN for one single feature
def __init__(
self,
hidden_channels,
out_channels,
num_heads,
rel_names,
dropout,
n_layers,
activation=nn.ELU(),
residual=False
):
super().__init__()
self.embedding_type = nn.Embedding(9, int(hidden_channels*num_heads/4))
self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads/4))
self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads/4))
self.embedding_shape = nn.Embedding(18, int(hidden_channels*num_heads/4))
self.combine = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*num_heads, hidden_channels*num_heads)
)
self.layers = nn.ModuleList([
dglnn.HeteroGraphConv({
rel: dglnn.GATv2Conv(
in_feats=hidden_channels*num_heads,
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
num_heads=num_heads if l < n_layers - 1 else 1,
feat_drop=dropout,
attn_drop=dropout,
residual=residual,
activation=activation if l < n_layers - 1 else None
)
for rel in rel_names}, aggregate='sum')
for l in range(n_layers)
])
def forward(self, g):
feats = []
feats.append(self.embedding_type(torch.argmax(g.ndata['type'], dim=1)))
feats.append(self.embedding_pos(g.ndata['pos']/170.))
feats.append(self.embedding_color(g.ndata['color']/255.))
feats.append(self.embedding_shape(torch.argmax(g.ndata['shape'], dim=1)))
h = {'obj': self.combine(torch.cat(feats, dim=1))}
for l, conv in enumerate(self.layers):
h = conv(g, h)
if l != len(self.layers) - 1:
h = {k: v.flatten(1) for k, v in h.items()}
else:
h = {k: v.mean(1) for k, v in h.items()}
with g.local_scope():
g.ndata['h'] = h['obj']
out = dgl.mean_nodes(g, 'h')
return out
# -------------------------------------------------------------------------------------------
class RGATv3(nn.Module):
# multi-layer GNN for one single feature
def __init__(
self,
hidden_channels,
out_channels,
num_heads,
rel_names,
dropout,
n_layers,
activation=nn.ELU(),
residual=False
):
super().__init__()
self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
self.combine = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*num_heads*4, hidden_channels*num_heads)
)
self.layers = nn.ModuleList([
dglnn.HeteroGraphConv({
rel: dglnn.GATv2Conv(
in_feats=hidden_channels*num_heads,
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always
feat_drop=dropout,
attn_drop=dropout,
residual=residual,
activation=activation if l < n_layers - 1 else None
)
for rel in rel_names}, aggregate='sum')
for l in range(n_layers)
])
def forward(self, g):
feats = []
feats.append(self.embedding_type(g.ndata['type'].float()))
feats.append(self.embedding_pos(g.ndata['pos']/170.)) # NOTE: this should be 180 because I remove the boundary walls!
feats.append(self.embedding_color(g.ndata['color']/255.))
feats.append(self.embedding_shape(g.ndata['shape'].float()))
h = {'obj': self.combine(torch.cat(feats, dim=1))}
for l, conv in enumerate(self.layers):
h = conv(g, h)
if l != len(self.layers) - 1:
h = {k: v.flatten(1) for k, v in h.items()}
else:
h = {k: v.mean(1) for k, v in h.items()}
with g.local_scope():
g.ndata['h'] = h['obj']
out = dgl.mean_nodes(g, 'h')
return out
# -------------------------------------------------------------------------------------------
class RGCNv2(nn.Module):
# multi-layer GNN for one single feature
def __init__(
self,
hidden_channels,
out_channels,
rel_names,
dropout,
n_layers,
activation=nn.ReLU()
):
super().__init__()
self.embedding_type = nn.Linear(9, int(hidden_channels))
self.embedding_pos = nn.Linear(2, int(hidden_channels))
self.embedding_color = nn.Linear(3, int(hidden_channels))
self.embedding_shape = nn.Linear(18, int(hidden_channels))
self.combine = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*4, hidden_channels)
)
self.layers = nn.ModuleList([
dglnn.RelGraphConv(
in_feat=hidden_channels,
out_feat=hidden_channels,
num_rels=len(rel_names),
regularizer=None,
num_bases=None,
bias=True,
activation=activation,
self_loop=True,
dropout=dropout,
layer_norm=False
)
for _ in range(n_layers-1)])
self.layers.append(
dglnn.RelGraphConv(
in_feat=hidden_channels,
out_feat=out_channels,
num_rels=len(rel_names),
regularizer=None,
num_bases=None,
bias=True,
activation=activation,
self_loop=True,
dropout=dropout,
layer_norm=False
)
)
def forward(self, g):
g = g.to_homogeneous()
feats = []
feats.append(self.embedding_type(g.ndata['type'].float()))
feats.append(self.embedding_pos(g.ndata['pos']/170.))
feats.append(self.embedding_color(g.ndata['color']/255.))
feats.append(self.embedding_shape(g.ndata['shape'].float()))
h = self.combine(torch.cat(feats, dim=1))
for conv in self.layers:
h = conv(g, h, g.etypes)
with g.local_scope():
g.ndata['h'] = h
out = dgl.mean_nodes(g, 'h')
return out
# -------------------------------------------------------------------------------------------
class RGATv3Agent(nn.Module):
# multi-layer GNN for one single feature
def __init__(
self,
hidden_channels,
out_channels,
num_heads,
rel_names,
dropout,
n_layers,
activation=nn.ELU(),
residual=False
):
super().__init__()
self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
self.combine = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*num_heads*4, hidden_channels*num_heads)
)
self.layers = nn.ModuleList([
dglnn.HeteroGraphConv({
rel: dglnn.GATv2Conv(
in_feats=hidden_channels*num_heads,
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always
feat_drop=dropout,
attn_drop=dropout,
residual=residual,
activation=activation if l < n_layers - 1 else None
)
for rel in rel_names}, aggregate='sum')
for l in range(n_layers)
])
def forward(self, g):
agent_mask = g.ndata['type'][:, 0] == 1
feats = []
feats.append(self.embedding_type(g.ndata['type'].float()))
feats.append(self.embedding_pos(g.ndata['pos']/200.))
feats.append(self.embedding_color(g.ndata['color']/255.))
feats.append(self.embedding_shape(g.ndata['shape'].float()))
h = {'obj': self.combine(torch.cat(feats, dim=1))}
for l, conv in enumerate(self.layers):
h = conv(g, h)
if l != len(self.layers) - 1:
h = {k: v.flatten(1) for k, v in h.items()}
else:
h = {k: v.mean(1) for k, v in h.items()}
with g.local_scope():
g.ndata['h'] = h['obj']
out = g.ndata['h'][agent_mask, :]
ctx = dgl.mean_nodes(g, 'h')
return out + ctx
# -------------------------------------------------------------------------------------------
class RGATv4(nn.Module):
# multi-layer GNN for one single feature
def __init__(
self,
hidden_channels,
out_channels,
num_heads,
rel_names,
dropout,
n_layers,
activation=nn.ELU(),
residual=False
):
super().__init__()
self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
self.combine_attr = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*num_heads*3, hidden_channels*num_heads)
)
self.combine_pos = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*num_heads*2, hidden_channels*num_heads)
)
self.layers = nn.ModuleList([
dglnn.HeteroGraphConv({
rel: dglnn.GATv2Conv(
in_feats=hidden_channels*num_heads,
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always
feat_drop=dropout,
attn_drop=dropout,
residual=residual,
share_weights=False,
activation=activation if l < n_layers - 1 else None
)
for rel in rel_names}, aggregate='sum')
for l in range(n_layers)
])
def forward(self, g):
attr = []
attr.append(self.embedding_type(g.ndata['type'].float()))
pos = self.embedding_pos(g.ndata['pos']/170.)
attr.append(self.embedding_color(g.ndata['color']/255.))
attr.append(self.embedding_shape(g.ndata['shape'].float()))
combined_attr = self.combine_attr(torch.cat(attr, dim=1))
h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
for l, conv in enumerate(self.layers):
h = conv(g, h)
if l != len(self.layers) - 1:
h = {k: v.flatten(1) for k, v in h.items()}
else:
h = {k: v.mean(1) for k, v in h.items()}
with g.local_scope():
g.ndata['h'] = h['obj']
out = dgl.mean_nodes(g, 'h')
return out
# -------------------------------------------------------------------------------------------
class RGATv4Norm(nn.Module):
# multi-layer GNN for one single feature
def __init__(
self,
hidden_channels,
out_channels,
num_heads,
rel_names,
dropout,
n_layers,
activation=nn.ELU(),
residual=False
):
super().__init__()
self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
self.combine_attr = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*num_heads*3, hidden_channels*num_heads)
)
self.combine_pos = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*num_heads*2, hidden_channels*num_heads)
)
self.layers = nn.ModuleList([
dglnn.HeteroGraphConv({
rel: dglnn.GATv2Conv(
in_feats=hidden_channels*num_heads,
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always
feat_drop=dropout,
attn_drop=dropout,
residual=residual,
activation=activation if l < n_layers - 1 else None
)
for rel in rel_names}, aggregate='sum')
for l in range(n_layers)
])
self.norms = nn.ModuleList([
Norm(
norm_type='gn',
hidden_dim=hidden_channels*num_heads if l < n_layers - 1 else out_channels
)
for l in range(n_layers)
])
def forward(self, g):
attr = []
attr.append(self.embedding_type(g.ndata['type'].float()))
pos = self.embedding_pos(g.ndata['pos']/170.)
attr.append(self.embedding_color(g.ndata['color']/255.))
attr.append(self.embedding_shape(g.ndata['shape'].float()))
combined_attr = self.combine_attr(torch.cat(attr, dim=1))
h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
for l, conv in enumerate(self.layers):
h = conv(g, h)
if l != len(self.layers) - 1:
h = {k: v.flatten(1) for k, v in h.items()}
else:
h = {k: v.mean(1) for k, v in h.items()}
h = {k: self.norms[l](g, v) for k, v in h.items()}
with g.local_scope():
g.ndata['h'] = h['obj']
out = dgl.mean_nodes(g, 'h')
return out
# -------------------------------------------------------------------------------------------
class RGATv3Norm(nn.Module):
# multi-layer GNN for one single feature
def __init__(
self,
hidden_channels,
out_channels,
num_heads,
rel_names,
dropout,
n_layers,
activation=nn.ELU(),
residual=False
):
super().__init__()
self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
self.combine = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*num_heads*4, hidden_channels*num_heads)
)
self.layers = nn.ModuleList([
dglnn.HeteroGraphConv({
rel: dglnn.GATv2Conv(
in_feats=hidden_channels*num_heads,
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always
feat_drop=dropout,
attn_drop=dropout,
residual=residual,
activation=activation if l < n_layers - 1 else None
)
for rel in rel_names}, aggregate='sum')
for l in range(n_layers)
])
self.norms = nn.ModuleList([
Norm(
norm_type='gn',
hidden_dim=hidden_channels*num_heads if l < n_layers - 1 else out_channels
)
for l in range(n_layers)
])
def forward(self, g):
feats = []
feats.append(self.embedding_type(g.ndata['type'].float()))
feats.append(self.embedding_pos(g.ndata['pos']/170.))
feats.append(self.embedding_color(g.ndata['color']/255.))
feats.append(self.embedding_shape(g.ndata['shape'].float()))
h = {'obj': self.combine(torch.cat(feats, dim=1))}
for l, conv in enumerate(self.layers):
h = conv(g, h)
if l != len(self.layers) - 1:
h = {k: v.flatten(1) for k, v in h.items()}
else:
h = {k: v.mean(1) for k, v in h.items()}
h = {k: self.norms[l](g, v) for k, v in h.items()}
with g.local_scope():
g.ndata['h'] = h['obj']
out = dgl.mean_nodes(g, 'h')
return out
# -------------------------------------------------------------------------------------------
class RGATv4Agent(nn.Module):
# multi-layer GNN for one single feature
def __init__(
self,
hidden_channels,
out_channels,
num_heads,
rel_names,
dropout,
n_layers,
activation=nn.ELU(),
residual=False
):
super().__init__()
self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
self.combine_attr = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*num_heads*3, hidden_channels*num_heads)
)
self.combine_pos = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*num_heads*2, hidden_channels*num_heads)
)
self.layers = nn.ModuleList([
dglnn.HeteroGraphConv({
rel: dglnn.GATv2Conv(
in_feats=hidden_channels*num_heads,
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always
feat_drop=dropout,
attn_drop=dropout,
residual=residual,
activation=activation if l < n_layers - 1 else None
)
for rel in rel_names}, aggregate='sum')
for l in range(n_layers)
])
self.combine_agent_context = nn.Linear(out_channels*2, out_channels)
def forward(self, g):
agent_mask = g.ndata['type'][:, 0] == 1
attr = []
attr.append(self.embedding_type(g.ndata['type'].float()))
pos = self.embedding_pos(g.ndata['pos']/170.)
attr.append(self.embedding_color(g.ndata['color']/255.))
attr.append(self.embedding_shape(g.ndata['shape'].float()))
combined_attr = self.combine_attr(torch.cat(attr, dim=1))
h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
for l, conv in enumerate(self.layers):
h = conv(g, h)
if l != len(self.layers) - 1:
h = {k: v.flatten(1) for k, v in h.items()}
else:
h = {k: v.mean(1) for k, v in h.items()}
with g.local_scope():
g.ndata['h'] = h['obj']
h_a = g.ndata['h'][agent_mask, :]
g_no_agent = copy.deepcopy(g)
g_no_agent.remove_nodes([i for i, x in enumerate(agent_mask) if x])
h_g = dgl.mean_nodes(g_no_agent, 'h')
out = self.combine_agent_context(torch.cat((h_a, h_g), dim=1))
return out
# -------------------------------------------------------------------------------------------
class RGCNv4(nn.Module):
# multi-layer GNN for one single feature
def __init__(
self,
hidden_channels,
out_channels,
rel_names,
n_layers,
activation=nn.ELU(),
):
super().__init__()
self.embedding_type = nn.Linear(9, int(hidden_channels))
self.embedding_pos = nn.Linear(2, int(hidden_channels))
self.embedding_color = nn.Linear(3, int(hidden_channels))
self.embedding_shape = nn.Linear(18, int(hidden_channels))
self.combine_attr = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*3, hidden_channels)
)
self.combine_pos = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*2, hidden_channels)
)
self.layers = nn.ModuleList([
dglnn.HeteroGraphConv({
rel: dglnn.GraphConv(
in_feats=hidden_channels,
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
activation=activation if l < n_layers - 1 else None
)
for rel in rel_names}, aggregate='sum')
for l in range(n_layers)
])
def forward(self, g):
attr = []
attr.append(self.embedding_type(g.ndata['type'].float()))
pos = self.embedding_pos(g.ndata['pos']/170.)
attr.append(self.embedding_color(g.ndata['color']/255.))
attr.append(self.embedding_shape(g.ndata['shape'].float()))
combined_attr = self.combine_attr(torch.cat(attr, dim=1))
h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
for l, conv in enumerate(self.layers):
h = conv(g, h)
with g.local_scope():
g.ndata['h'] = h['obj']
out = dgl.mean_nodes(g, 'h')
return out
# -------------------------------------------------------------------------------------------
class RGATv5(nn.Module):
def __init__(
self,
hidden_channels,
out_channels,
num_heads,
rel_names,
dropout,
n_layers,
activation=nn.ELU(),
residual=False
):
super().__init__()
self.hidden_channels = hidden_channels
self.num_heads = num_heads
self.embedding_type = nn.Linear(9, hidden_channels*num_heads)
self.embedding_pos = nn.Linear(2, hidden_channels*num_heads)
self.embedding_color = nn.Linear(3, hidden_channels*num_heads)
self.embedding_shape = nn.Linear(18, hidden_channels*num_heads)
self.combine = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*num_heads*4, hidden_channels*num_heads)
)
self.attention = nn.Linear(hidden_channels*num_heads*4, 4)
self.layers = nn.ModuleList([
dglnn.HeteroGraphConv({
rel: dglnn.GATv2Conv(
in_feats=hidden_channels*num_heads,
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
num_heads=num_heads,
feat_drop=dropout,
attn_drop=dropout,
residual=residual,
activation=activation if l < n_layers - 1 else None
)
for rel in rel_names}, aggregate='sum')
for l in range(n_layers)
])
def forward(self, g):
feats = []
feats.append(self.embedding_type(g.ndata['type'].float()))
feats.append(self.embedding_pos(g.ndata['pos']/170.))
feats.append(self.embedding_color(g.ndata['color']/255.))
feats.append(self.embedding_shape(g.ndata['shape'].float()))
h = torch.cat(feats, dim=1)
feat_attn = F.softmax(self.attention(h), dim=1)
h = h * feat_attn.repeat_interleave(self.hidden_channels*self.num_heads, dim=1)
h_in = self.combine(h)
h = {'obj': h_in}
for conv in self.layers:
h = conv(g, h)
h = {k: v.flatten(1) for k, v in h.items()}
#if l != len(self.layers) - 1:
# h = {k: v.flatten(1) for k, v in h.items()}
#else:
# h = {k: v.mean(1) for k, v in h.items()}
h = {k: v + h_in for k, v in h.items()}
with g.local_scope():
g.ndata['h'] = h['obj']
out = dgl.mean_nodes(g, 'h')
return out
# -------------------------------------------------------------------------------------------
class RGATv6(nn.Module):
# RGATv6 = RGATv4 + Global Attention Pooling
def __init__(
self,
hidden_channels,
out_channels,
num_heads,
rel_names,
dropout,
n_layers,
activation=nn.ELU(),
residual=False
):
super().__init__()
self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
self.combine_attr = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*num_heads*3, hidden_channels*num_heads)
)
self.combine_pos = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*num_heads*2, hidden_channels*num_heads)
)
self.layers = nn.ModuleList([
dglnn.HeteroGraphConv({
rel: dglnn.GATv2Conv(
in_feats=hidden_channels*num_heads,
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always
feat_drop=dropout,
attn_drop=dropout,
residual=residual,
activation=activation if l < n_layers - 1 else None
)
for rel in rel_names}, aggregate='sum')
for l in range(n_layers)
])
gate_nn = nn.Linear(out_channels, 1)
self.gap = dglnn.GlobalAttentionPooling(gate_nn)
def forward(self, g):
attr = []
attr.append(self.embedding_type(g.ndata['type'].float()))
pos = self.embedding_pos(g.ndata['pos']/170.)
attr.append(self.embedding_color(g.ndata['color']/255.))
attr.append(self.embedding_shape(g.ndata['shape'].float()))
combined_attr = self.combine_attr(torch.cat(attr, dim=1))
h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
for l, conv in enumerate(self.layers):
h = conv(g, h)
if l != len(self.layers) - 1:
h = {k: v.flatten(1) for k, v in h.items()}
else:
h = {k: v.mean(1) for k, v in h.items()}
#with g.local_scope():
#g.ndata['h'] = h['obj']
#out = dgl.mean_nodes(g, 'h')
out = self.gap(g, h['obj'])
return out
# -------------------------------------------------------------------------------------------
class RGATv6Agent(nn.Module):
# RGATv6 = RGATv4 + Global Attention Pooling
def __init__(
self,
hidden_channels,
out_channels,
num_heads,
rel_names,
dropout,
n_layers,
activation=nn.ELU(),
residual=False
):
super().__init__()
self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))