up
This commit is contained in:
parent
a333481e05
commit
de0bea7508
18 changed files with 3150 additions and 2 deletions
76
README.md
76
README.md
|
@ -1,3 +1,75 @@
|
|||
# IRENE
|
||||
<div align="center">
|
||||
<h1> Neural Reasoning about Agents' Goals, Preferences, and Actions </h1>
|
||||
|
||||
Official code of "Neural Reasoning About Agents' Goals, Preferences, and Actions"
|
||||
**[Matteo Bortoletto][1], [Lei Shi][2], [Andreas Bulling][3]** <br> <br>
|
||||
**AAAI'24, Vancouver, CA** <br>
|
||||
**[[Paper][4]]**
|
||||
|
||||
</div>
|
||||
|
||||
# Citation
|
||||
If you find our code useful or use it in your own projects, please cite our paper:
|
||||
|
||||
```bibtex
|
||||
@inproceedings{bortoletto2024neural,
|
||||
author = {Bortoletto, Matteo and Lei, Shi and Bulling, Andreas},
|
||||
title = {{Neural Reasoning about Agents' Goals, Preferences, and Actions}},
|
||||
booktitle = {Proc. 38th AAAI Conference on Artificial Intelligence (AAAI)},
|
||||
year = {2024},
|
||||
}
|
||||
```
|
||||
|
||||
# Setup
|
||||
|
||||
This code is based on the [original implementation][5] of the BIB benchmark.
|
||||
|
||||
## Using `virtualenv`
|
||||
```
|
||||
python -m virtualenv /path/to/env
|
||||
source /path/to/env/bin/activate
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Using `conda`
|
||||
```
|
||||
conda create --name <env_name> python=3.8.10 pip=20.0.2 cudatoolkit=10.2.89
|
||||
conda activate <env_name>
|
||||
pip install -r requirements_conda.txt
|
||||
pip install dgl-cu102 dglgo -f https://data.dgl.ai/wheels/repo.html
|
||||
```
|
||||
|
||||
|
||||
# Running the code
|
||||
|
||||
## Activate the environment
|
||||
Run `source bibdgl/bin/activate`.
|
||||
|
||||
## Index data
|
||||
This will create the json files with all the indexed frames for each episode in each video.
|
||||
```
|
||||
python utils/index_data.py
|
||||
```
|
||||
You need to manually set `mode` in the dataset class (in main).
|
||||
|
||||
## Generate graphs
|
||||
This will generate the graphs from the videos:
|
||||
```
|
||||
python /utils/build_graphs.py --mode MODE --cpus NUM_CPUS
|
||||
```
|
||||
`MODE` can be `train`, `val` or `test`. NOTE: check `utils/build_graphs.py` to make sure you're loading the correct dataset to generate the graphs you want.
|
||||
|
||||
## Training
|
||||
You can use the `gtbc.sh`.
|
||||
|
||||
## Testing
|
||||
Use `run_test_tom.sh`.
|
||||
|
||||
# Hardware setup
|
||||
All models are trained on an NVIDIA Tesla V100-SXM2-32GB GPU.
|
||||
|
||||
|
||||
[1]: https://mattbortoletto.github.io/
|
||||
[2]: https://perceptualui.org/people/shi/
|
||||
[3]: https://perceptualui.org/people/bulling/
|
||||
[4]: https://perceptualui.org/publications/bortoletto24_aaai.pdf
|
||||
[5]: https://github.com/kanishkg/bib-baselines
|
||||
|
|
9
run_test.sh
Normal file
9
run_test.sh
Normal file
|
@ -0,0 +1,9 @@
|
|||
echo 314 e31
|
||||
|
||||
CUDA_VISIBLE_DEVICES=1 python test_tom.py \
|
||||
--model_type graphbcrnn \
|
||||
--types efficiency_irrational \
|
||||
--ckpt /projects/bortoletto/icml2023_matteo/wandb/run-20221224_135525-8i1r2aqy/files/bib/8i1r2aqy/checkpoints/epoch\=31-step\=22399.ckpt \
|
||||
--data_path /datasets/external/bib_evaluation_1_1/graphs/all_tasks \
|
||||
--process_data 0 \
|
||||
--surprise_type max
|
18
run_train.sh
Normal file
18
run_train.sh
Normal file
|
@ -0,0 +1,18 @@
|
|||
CUDA_VISIBLE_DEVICES=0 python train_tom.py \
|
||||
--model_type graphbcrnn \
|
||||
--types single_object preference instrumental_action \
|
||||
--data_path /datasets/external/bib_train/graphs/all_tasks/ \
|
||||
--seed 7 \
|
||||
--batch_size 32 \
|
||||
--max_epochs 35 \
|
||||
--gpus 1 \
|
||||
--auto_select_gpus True \
|
||||
--num_workers 2 \
|
||||
--stochastic_weight_avg True \
|
||||
--lr 5e-4 \
|
||||
--check_val_every_n_epoch 1 \
|
||||
--track_grad_norm 2 \
|
||||
--gradient_clip_val 10 \
|
||||
--gnn_type RSAGEv4 \
|
||||
--state_dim 96 \
|
||||
--aggregation sum
|
122
test_tom.py
Normal file
122
test_tom.py
Normal file
|
@ -0,0 +1,122 @@
|
|||
from argparse import ArgumentParser
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
import torch.nn.functional as F
|
||||
import dgl
|
||||
|
||||
from tom.dataset import TestToMnetDGLDataset, collate_function_seq_test
|
||||
from tom.model import GraphBC_T, GraphBCRNN
|
||||
|
||||
|
||||
def get_z_scores(total, total_expected, total_unexpected):
|
||||
mean = np.mean(total)
|
||||
std = np.std(total)
|
||||
print("Z-Score expected: ",
|
||||
(np.mean(total_expected) - mean) / std)
|
||||
print("Z-Score unexpected: ",
|
||||
(np.mean(total_unexpected) - mean) / std)
|
||||
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
parser.add_argument('--model_type', type=str, default='graphbcrnn')
|
||||
parser.add_argument('--ckpt', type=str, default=None, help='path to checkpoint')
|
||||
parser.add_argument('--data_path', type=str, default=None, help='path to the data')
|
||||
parser.add_argument('--process_data', type=int, default=0)
|
||||
parser.add_argument('--surprise_type', type=str, default='max',
|
||||
help='surprise type: mean, max. This is used for comparing the plausibility scores of the two test episodes')
|
||||
parser.add_argument('--types', nargs='+', type=str,
|
||||
default=[
|
||||
'preference', 'multi_agent', 'inaccessible_goal',
|
||||
'efficiency_irrational', 'efficiency_time','efficiency_path',
|
||||
'instrumental_no_barrier', 'instrumental_blocking_barrier', 'instrumental_inconsequential_barrier'
|
||||
],
|
||||
help='types of tasks used for training / testing')
|
||||
parser.add_argument('--filename', type=str, default='')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
filename = args.filename
|
||||
|
||||
if args.model_type == 'graphbct':
|
||||
model = GraphBC_T.load_from_checkpoint(args.ckpt)
|
||||
elif args.model_type == 'graphbcrnn':
|
||||
model = GraphBCRNN.load_from_checkpoint(args.ckpt)
|
||||
else:
|
||||
raise ValueError('Unknown model type.')
|
||||
|
||||
device = 'cuda'
|
||||
model.to(device)
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
for t in args.types:
|
||||
if args.model_type == 'graphbcrnn':
|
||||
test_dataset = TestToMnetDGLDataset(
|
||||
path=args.data_path,
|
||||
task_type=t,
|
||||
mode='test'
|
||||
)
|
||||
test_dataloader = DataLoader(
|
||||
test_dataset,
|
||||
batch_size=1,
|
||||
num_workers=1,
|
||||
pin_memory=True,
|
||||
collate_fn=collate_function_seq_test,
|
||||
shuffle=False
|
||||
)
|
||||
count = 0
|
||||
total, total_expected, total_unexpected = [], [], []
|
||||
pbar = tqdm(test_dataloader)
|
||||
for j, batch in enumerate(pbar):
|
||||
if args.model_type == 'graphbcrnn':
|
||||
dem_expected_states, dem_expected_actions, dem_expected_lens, \
|
||||
dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, \
|
||||
query_expected_frames, target_expected_actions, \
|
||||
query_unexpected_frames, target_unexpected_actions = batch
|
||||
dem_expected_states = dem_expected_states.to(device)
|
||||
dem_expected_actions = dem_expected_actions.to(device)
|
||||
dem_unexpected_states = dem_unexpected_states.to(device)
|
||||
dem_unexpected_actions = dem_unexpected_actions.to(device)
|
||||
target_expected_actions = target_expected_actions.to(device)
|
||||
target_unexpected_actions = target_unexpected_actions.to(device)
|
||||
surprise_expected = []
|
||||
query_expected_frames = dgl.unbatch(query_expected_frames)
|
||||
for i in range(len(query_expected_frames)):
|
||||
if args.model_type == 'graphbcrnn':
|
||||
test_actions, test_actions_pred = model(
|
||||
[dem_expected_states, dem_expected_actions, dem_expected_lens, query_expected_frames[i].to(device), target_expected_actions[:, i, :]]
|
||||
)
|
||||
loss = F.mse_loss(test_actions, test_actions_pred)
|
||||
surprise_expected.append(loss.cpu().detach().numpy())
|
||||
mean_expected_surprise = np.mean(surprise_expected)
|
||||
max_expected_surprise = np.max(surprise_expected)
|
||||
|
||||
# calculate the plausibility scores for the unexpected episode
|
||||
surprise_unexpected = []
|
||||
query_unexpected_frames = dgl.unbatch(query_unexpected_frames)
|
||||
for i in range(len(query_unexpected_frames)):
|
||||
if args.model_type == 'graphbcrnn':
|
||||
test_actions, test_actions_pred = model(
|
||||
[dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, query_unexpected_frames[i].to(device), target_unexpected_actions[:, i, :]]
|
||||
)
|
||||
loss = F.mse_loss(test_actions, test_actions_pred)
|
||||
surprise_unexpected.append(loss.cpu().detach().numpy())
|
||||
mean_unexpected_surprise = np.mean(surprise_unexpected)
|
||||
max_unexpected_surprise = np.max(surprise_unexpected)
|
||||
|
||||
correct_mean = mean_expected_surprise < mean_unexpected_surprise + 0.5 * (mean_expected_surprise == mean_unexpected_surprise)
|
||||
correct_max = max_expected_surprise < max_unexpected_surprise + 0.5 * (max_expected_surprise == max_unexpected_surprise)
|
||||
if args.surprise_type == 'max':
|
||||
count += correct_max
|
||||
elif args.surprise_type == 'mean':
|
||||
count += correct_mean
|
||||
pbar.set_postfix({'accuracy': count/(j+1.), 'type': t})
|
||||
|
||||
total_expected.append(max_expected_surprise)
|
||||
total_unexpected.append(max_unexpected_surprise)
|
||||
total.append(max_expected_surprise)
|
||||
total.append(max_unexpected_surprise)
|
||||
get_z_scores(total, total_expected, total_unexpected)
|
0
tom/__init__.py
Normal file
0
tom/__init__.py
Normal file
310
tom/dataset.py
Normal file
310
tom/dataset.py
Normal file
|
@ -0,0 +1,310 @@
|
|||
import pickle as pkl
|
||||
import os
|
||||
import torch
|
||||
import torch.utils.data
|
||||
import torch.nn.functional as F
|
||||
import dgl
|
||||
import random
|
||||
from dgl.data import DGLDataset
|
||||
|
||||
|
||||
def collate_function_seq(batch):
|
||||
#dem_frames = torch.stack([item[0] for item in batch])
|
||||
dem_frames = dgl.batch([item[0] for item in batch])
|
||||
dem_actions = torch.stack([item[1] for item in batch])
|
||||
dem_lens = [item[2] for item in batch]
|
||||
#query_frames = torch.stack([item[3] for item in batch])
|
||||
query_frames = dgl.batch([item[3] for item in batch])
|
||||
target_actions = torch.stack([item[4] for item in batch])
|
||||
return [dem_frames, dem_actions, dem_lens, query_frames, target_actions]
|
||||
|
||||
def collate_function_seq_test(batch):
|
||||
dem_expected_states = dgl.batch([item[0] for item in batch][0])
|
||||
dem_expected_actions = torch.stack([item[1] for item in batch][0]).unsqueeze(dim=0)
|
||||
dem_expected_lens = [item[2] for item in batch]
|
||||
#print(dem_expected_actions.size())
|
||||
dem_unexpected_states = dgl.batch([item[3] for item in batch][0])
|
||||
dem_unexpected_actions = torch.stack([item[4] for item in batch][0]).unsqueeze(dim=0)
|
||||
dem_unexpected_lens = [item[5] for item in batch]
|
||||
query_expected_frames = dgl.batch([item[6] for item in batch])
|
||||
target_expected_actions = torch.stack([item[7] for item in batch])
|
||||
#print(target_expected_actions.size())
|
||||
query_unexpected_frames = dgl.batch([item[8] for item in batch])
|
||||
target_unexpected_actions = torch.stack([item[9] for item in batch])
|
||||
return [
|
||||
dem_expected_states, dem_expected_actions, dem_expected_lens, \
|
||||
dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, \
|
||||
query_expected_frames, target_expected_actions, \
|
||||
query_unexpected_frames, target_unexpected_actions
|
||||
]
|
||||
|
||||
def collate_function_mental(batch):
|
||||
dem_frames = dgl.batch([item[0] for item in batch])
|
||||
dem_actions = torch.stack([item[1] for item in batch])
|
||||
dem_lens = [item[2] for item in batch]
|
||||
past_test_frames = dgl.batch([item[3] for item in batch])
|
||||
past_test_actions = torch.stack([item[4] for item in batch])
|
||||
past_test_len = [item[5] for item in batch]
|
||||
query_frames = dgl.batch([item[6] for item in batch])
|
||||
target_actions = torch.stack([item[7] for item in batch])
|
||||
return [dem_frames, dem_actions, dem_lens, past_test_frames, past_test_actions, past_test_len, query_frames, target_actions]
|
||||
|
||||
|
||||
class ToMnetDGLDataset(DGLDataset):
|
||||
"""
|
||||
Training dataset class.
|
||||
"""
|
||||
def __init__(self, path, types=None, mode="train"):
|
||||
self.path = path
|
||||
self.types = types
|
||||
self.mode = mode
|
||||
print('Mode:', self.mode)
|
||||
|
||||
if self.mode == 'train':
|
||||
if len(self.types) == 4:
|
||||
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/'
|
||||
#self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_global/'
|
||||
#self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_local/'
|
||||
elif len(self.types) == 3:
|
||||
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper() + '/'
|
||||
print(self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper())
|
||||
elif len(self.types) == 2:
|
||||
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + '/'
|
||||
print(self.types[0][0].upper() + self.types[1][0].upper())
|
||||
elif len(self.types) == 1:
|
||||
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + '/'
|
||||
else: raise ValueError('Number of types different from 1 or 4.')
|
||||
elif self.mode == 'val':
|
||||
assert len(self.types) == 1
|
||||
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/' + self.types[0] + '/'
|
||||
#self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_global/' + self.types[0] + '/'
|
||||
#self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_local/' + self.types[0] + '/'
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
def get_test(self, states, actions):
|
||||
# now states is a batched graph -> unbatch it, take the len, pick one sub-graph
|
||||
# randomly and select the corresponding action
|
||||
frame_graphs = dgl.unbatch(states)
|
||||
trial_len = len(frame_graphs)
|
||||
query_idx = random.randint(0, trial_len - 1)
|
||||
query_graph = frame_graphs[query_idx]
|
||||
target_action = actions[query_idx]
|
||||
return query_graph, target_action
|
||||
|
||||
def __getitem__(self, idx):
|
||||
with open(self.path+str(idx)+'.pkl', 'rb') as f:
|
||||
states, actions, lens, _ = pkl.load(f)
|
||||
# shuffle
|
||||
ziplist = list(zip(states, actions, lens))
|
||||
random.shuffle(ziplist)
|
||||
states, actions, lens = zip(*ziplist)
|
||||
# convert tuples to lists
|
||||
states, actions, lens = [*states], [*actions], [*lens]
|
||||
# pick last element in the list as test and pick random frame
|
||||
test_s, test_a = self.get_test(states[-1], actions[-1])
|
||||
dem_s = states[:-1]
|
||||
dem_a = actions[:-1]
|
||||
dem_lens = lens[:-1]
|
||||
dem_s = dgl.batch(dem_s)
|
||||
dem_a = torch.stack(dem_a)
|
||||
return dem_s, dem_a, dem_lens, test_s, test_a
|
||||
|
||||
def __len__(self):
|
||||
return len(os.listdir(self.path))
|
||||
|
||||
|
||||
class TestToMnetDGLDataset(DGLDataset):
|
||||
"""
|
||||
Testing dataset class.
|
||||
"""
|
||||
def __init__(self, path, task_type=None, mode="test"):
|
||||
self.path = path
|
||||
self.type = task_type
|
||||
self.mode = mode
|
||||
print('Mode:', self.mode)
|
||||
|
||||
if self.mode == 'test':
|
||||
self.path = self.path + '_dgl_hetero_nobound_4feats/' + self.type + '/'
|
||||
#self.path = self.path + '_dgl_hetero_nobound_4feats_global/' + self.type + '/'
|
||||
#self.path = self.path + '_dgl_hetero_nobound_4feats_local/' + self.type + '/'
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
def __getitem__(self, idx):
|
||||
with open(self.path+str(idx)+'.pkl', 'rb') as f:
|
||||
dem_expected_states, dem_expected_actions, dem_expected_lens, _, \
|
||||
dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, _, \
|
||||
query_expected_frames, target_expected_actions, \
|
||||
query_unexpected_frames, target_unexpected_actions = pkl.load(f)
|
||||
assert len(dem_expected_states) == 8
|
||||
assert len(dem_expected_actions) == 8
|
||||
assert len(dem_expected_lens) == 8
|
||||
assert len(dem_unexpected_states) == 8
|
||||
assert len(dem_unexpected_actions) == 8
|
||||
assert len(dem_unexpected_lens) == 8
|
||||
assert len(dgl.unbatch(query_expected_frames)) == target_expected_actions.size()[0]
|
||||
assert len(dgl.unbatch(query_unexpected_frames)) == target_unexpected_actions.size()[0]
|
||||
# ignore n_nodes
|
||||
return dem_expected_states, dem_expected_actions, dem_expected_lens, \
|
||||
dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, \
|
||||
query_expected_frames, target_expected_actions, \
|
||||
query_unexpected_frames, target_unexpected_actions
|
||||
|
||||
def __len__(self):
|
||||
return len(os.listdir(self.path))
|
||||
|
||||
|
||||
class ToMnetDGLDatasetUndersample(DGLDataset):
|
||||
"""
|
||||
Training dataset class for the behavior cloning mlp model.
|
||||
"""
|
||||
def __init__(self, path, types=None, mode="train"):
|
||||
self.path = path
|
||||
self.types = types
|
||||
self.mode = mode
|
||||
print('Mode:', self.mode)
|
||||
|
||||
if self.mode == 'train':
|
||||
if len(self.types) == 4:
|
||||
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/'
|
||||
elif len(self.types) == 3:
|
||||
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper() + '/'
|
||||
print(self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper())
|
||||
elif len(self.types) == 2:
|
||||
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + '/'
|
||||
print(self.types[0][0].upper() + self.types[1][0].upper())
|
||||
elif len(self.types) == 1:
|
||||
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + '/'
|
||||
else: raise ValueError('Number of types different from 1 or 4.')
|
||||
elif self.mode == 'val':
|
||||
assert len(self.types) == 1
|
||||
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/' + self.types[0] + '/'
|
||||
else:
|
||||
raise ValueError
|
||||
print('Undersampled dataset!')
|
||||
|
||||
def get_test(self, states, actions):
|
||||
# now states is a batched graph -> unbatch it, take the len, pick one sub-graph
|
||||
# randomly and select the corresponding action
|
||||
frame_graphs = dgl.unbatch(states)
|
||||
trial_len = len(frame_graphs)
|
||||
query_idx = random.randint(0, trial_len - 1)
|
||||
query_graph = frame_graphs[query_idx]
|
||||
target_action = actions[query_idx]
|
||||
return query_graph, target_action
|
||||
|
||||
def __getitem__(self, idx):
|
||||
idx = idx + 3175
|
||||
with open(self.path+str(idx)+'.pkl', 'rb') as f:
|
||||
states, actions, lens, _ = pkl.load(f)
|
||||
# shuffle
|
||||
ziplist = list(zip(states, actions, lens))
|
||||
random.shuffle(ziplist)
|
||||
states, actions, lens = zip(*ziplist)
|
||||
# convert tuples to lists
|
||||
states, actions, lens = [*states], [*actions], [*lens]
|
||||
# pick last element in the list as test and pick random frame
|
||||
test_s, test_a = self.get_test(states[-1], actions[-1])
|
||||
dem_s = states[:-1]
|
||||
dem_a = actions[:-1]
|
||||
dem_lens = lens[:-1]
|
||||
dem_s = dgl.batch(dem_s)
|
||||
dem_a = torch.stack(dem_a)
|
||||
return dem_s, dem_a, dem_lens, test_s, test_a
|
||||
|
||||
def __len__(self):
|
||||
return len(os.listdir(self.path)) - 3175
|
||||
|
||||
|
||||
class ToMnetDGLDatasetMental(DGLDataset):
|
||||
"""
|
||||
Training dataset class.
|
||||
"""
|
||||
def __init__(self, path, types=None, mode="train"):
|
||||
self.path = path
|
||||
self.types = types
|
||||
self.mode = mode
|
||||
print('Mode:', self.mode)
|
||||
|
||||
if self.mode == 'train':
|
||||
if len(self.types) == 4:
|
||||
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/'
|
||||
elif len(self.types) == 3:
|
||||
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper() + '/'
|
||||
print(self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper())
|
||||
elif len(self.types) == 2:
|
||||
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + '/'
|
||||
print(self.types[0][0].upper() + self.types[1][0].upper())
|
||||
elif len(self.types) == 1:
|
||||
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + '/'
|
||||
else: raise ValueError('Number of types different from 1 or 4.')
|
||||
elif self.mode == 'val':
|
||||
assert len(self.types) == 1
|
||||
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/' + self.types[0] + '/'
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
def get_test(self, states, actions):
|
||||
"""
|
||||
return: past_test_graphs, past_test_actions, test_graph, test_action
|
||||
"""
|
||||
frame_graphs = dgl.unbatch(states)
|
||||
trial_len = len(frame_graphs)
|
||||
query_idx = random.randint(0, trial_len - 1)
|
||||
test_graph = frame_graphs[query_idx]
|
||||
test_action = actions[query_idx]
|
||||
if query_idx > 0:
|
||||
#past_test_actions = F.pad(past_test_actions, (0,0,0,41-query_idx), 'constant', 0)
|
||||
if query_idx == 1:
|
||||
past_test_graphs = frame_graphs[0]
|
||||
past_test_actions = actions[:query_idx]
|
||||
past_test_actions = F.pad(past_test_actions, (0,0,0,41-query_idx), 'constant', 0)
|
||||
return past_test_graphs, past_test_actions, query_idx, test_graph, test_action
|
||||
else:
|
||||
past_test_graphs = frame_graphs[:query_idx]
|
||||
past_test_actions = actions[:query_idx]
|
||||
past_test_graphs = dgl.batch(past_test_graphs)
|
||||
past_test_actions = F.pad(past_test_actions, (0,0,0,41-query_idx), 'constant', 0)
|
||||
return past_test_graphs, past_test_actions, query_idx, test_graph, test_action
|
||||
else:
|
||||
test_action_ = F.pad(test_action.unsqueeze(0), (0,0,0,41-1), 'constant', 0)
|
||||
# NOTE: since there are no past frames, return the test frame and action and query_idx=1
|
||||
# not sure what would be a better alternative
|
||||
return test_graph, test_action_, 1, test_graph, test_action
|
||||
|
||||
def __getitem__(self, idx):
|
||||
with open(self.path+str(idx)+'.pkl', 'rb') as f:
|
||||
states, actions, lens, _ = pkl.load(f)
|
||||
ziplist = list(zip(states, actions, lens))
|
||||
random.shuffle(ziplist)
|
||||
states, actions, lens = zip(*ziplist)
|
||||
states, actions, lens = [*states], [*actions], [*lens]
|
||||
past_test_s, past_test_a, past_test_len, test_s, test_a = self.get_test(states[-1], actions[-1])
|
||||
dem_s = states[:-1]
|
||||
dem_a = actions[:-1]
|
||||
dem_lens = lens[:-1]
|
||||
dem_s = dgl.batch(dem_s)
|
||||
dem_a = torch.stack(dem_a)
|
||||
return dem_s, dem_a, dem_lens, past_test_s, past_test_a, past_test_len, test_s, test_a
|
||||
|
||||
def __len__(self):
|
||||
return len(os.listdir(self.path))
|
||||
|
||||
# --------------------------------------------------------------------------------------------------------
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
types = [
|
||||
'preference', 'multi_agent', 'inaccessible_goal',
|
||||
'efficiency_irrational', 'efficiency_time','efficiency_path',
|
||||
'instrumental_no_barrier', 'instrumental_blocking_barrier', 'instrumental_inconsequential_barrier'
|
||||
]
|
||||
|
||||
mental_dataset = ToMnetDGLDatasetMental(
|
||||
path='/datasets/external/bib_train/graphs/all_tasks/',
|
||||
types=['instrumental_action'],
|
||||
mode='train'
|
||||
)
|
||||
dem_frames, dem_actions, dem_lens, past_test_frames, past_test_actions, len, test_frame, test_action = mental_dataset.__getitem__(99)
|
||||
breakpoint()
|
877
tom/gnn.py
Normal file
877
tom/gnn.py
Normal file
|
@ -0,0 +1,877 @@
|
|||
import dgl.nn.pytorch as dglnn
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch
|
||||
import dgl
|
||||
import sys
|
||||
import copy
|
||||
|
||||
from wandb import agent
|
||||
sys.path.append('/projects/bortoletto/irene/')
|
||||
from tom.norm import Norm
|
||||
|
||||
|
||||
class RSAGEv4(nn.Module):
|
||||
# multi-layer GNN for one single feature
|
||||
def __init__(
|
||||
self,
|
||||
hidden_channels,
|
||||
out_channels,
|
||||
rel_names,
|
||||
dropout,
|
||||
n_layers,
|
||||
activation=nn.ELU(),
|
||||
):
|
||||
super().__init__()
|
||||
self.embedding_type = nn.Linear(9, int(hidden_channels))
|
||||
self.embedding_pos = nn.Linear(2, int(hidden_channels))
|
||||
self.embedding_color = nn.Linear(3, int(hidden_channels))
|
||||
self.embedding_shape = nn.Linear(18, int(hidden_channels))
|
||||
self.combine_attr = nn.Sequential(
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_channels*3, hidden_channels)
|
||||
)
|
||||
self.combine_pos = nn.Sequential(
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_channels*2, hidden_channels)
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
dglnn.HeteroGraphConv({
|
||||
rel: dglnn.SAGEConv(
|
||||
in_feats=hidden_channels,
|
||||
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
|
||||
aggregator_type='lstm',
|
||||
feat_drop=dropout,
|
||||
activation=activation if l < n_layers - 1 else None
|
||||
)
|
||||
for rel in rel_names}, aggregate='sum')
|
||||
for l in range(n_layers)
|
||||
])
|
||||
|
||||
def forward(self, g):
|
||||
attr = []
|
||||
attr.append(self.embedding_type(g.ndata['type'].float()))
|
||||
pos = self.embedding_pos(g.ndata['pos']/170.)
|
||||
attr.append(self.embedding_color(g.ndata['color']/255.))
|
||||
attr.append(self.embedding_shape(g.ndata['shape'].float()))
|
||||
combined_attr = self.combine_attr(torch.cat(attr, dim=1))
|
||||
h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
|
||||
for l, conv in enumerate(self.layers):
|
||||
h = conv(g, h)
|
||||
with g.local_scope():
|
||||
g.ndata['h'] = h['obj']
|
||||
out = dgl.mean_nodes(g, 'h')
|
||||
return out
|
||||
|
||||
# -------------------------------------------------------------------------------------------
|
||||
|
||||
class RAGNNv4(nn.Module):
|
||||
# multi-layer GNN for one single feature
|
||||
def __init__(
|
||||
self,
|
||||
hidden_channels,
|
||||
out_channels,
|
||||
rel_names,
|
||||
dropout,
|
||||
n_layers,
|
||||
activation=nn.ELU(),
|
||||
):
|
||||
super().__init__()
|
||||
self.embedding_type = nn.Linear(9, int(hidden_channels))
|
||||
self.embedding_pos = nn.Linear(2, int(hidden_channels))
|
||||
self.embedding_color = nn.Linear(3, int(hidden_channels))
|
||||
self.embedding_shape = nn.Linear(18, int(hidden_channels))
|
||||
self.combine_attr = nn.Sequential(
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_channels*3, hidden_channels)
|
||||
)
|
||||
self.combine_pos = nn.Sequential(
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_channels*2, hidden_channels)
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
dglnn.HeteroGraphConv({
|
||||
rel: dglnn.AGNNConv()
|
||||
for rel in rel_names}, aggregate='sum')
|
||||
for l in range(n_layers)
|
||||
])
|
||||
|
||||
def forward(self, g):
|
||||
attr = []
|
||||
attr.append(self.embedding_type(g.ndata['type'].float()))
|
||||
pos = self.embedding_pos(g.ndata['pos']/170.)
|
||||
attr.append(self.embedding_color(g.ndata['color']/255.))
|
||||
attr.append(self.embedding_shape(g.ndata['shape'].float()))
|
||||
combined_attr = self.combine_attr(torch.cat(attr, dim=1))
|
||||
h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
|
||||
for l, conv in enumerate(self.layers):
|
||||
h = conv(g, h)
|
||||
with g.local_scope():
|
||||
g.ndata['h'] = h['obj']
|
||||
out = dgl.mean_nodes(g, 'h')
|
||||
return out
|
||||
|
||||
# -------------------------------------------------------------------------------------------
|
||||
|
||||
class RGATv2(nn.Module):
|
||||
# multi-layer GNN for one single feature
|
||||
def __init__(
|
||||
self,
|
||||
hidden_channels,
|
||||
out_channels,
|
||||
num_heads,
|
||||
rel_names,
|
||||
dropout,
|
||||
n_layers,
|
||||
activation=nn.ELU(),
|
||||
residual=False
|
||||
):
|
||||
super().__init__()
|
||||
self.embedding_type = nn.Embedding(9, int(hidden_channels*num_heads/4))
|
||||
self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads/4))
|
||||
self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads/4))
|
||||
self.embedding_shape = nn.Embedding(18, int(hidden_channels*num_heads/4))
|
||||
self.combine = nn.Sequential(
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_channels*num_heads, hidden_channels*num_heads)
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
dglnn.HeteroGraphConv({
|
||||
rel: dglnn.GATv2Conv(
|
||||
in_feats=hidden_channels*num_heads,
|
||||
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
|
||||
num_heads=num_heads if l < n_layers - 1 else 1,
|
||||
feat_drop=dropout,
|
||||
attn_drop=dropout,
|
||||
residual=residual,
|
||||
activation=activation if l < n_layers - 1 else None
|
||||
)
|
||||
for rel in rel_names}, aggregate='sum')
|
||||
for l in range(n_layers)
|
||||
])
|
||||
|
||||
def forward(self, g):
|
||||
feats = []
|
||||
feats.append(self.embedding_type(torch.argmax(g.ndata['type'], dim=1)))
|
||||
feats.append(self.embedding_pos(g.ndata['pos']/170.))
|
||||
feats.append(self.embedding_color(g.ndata['color']/255.))
|
||||
feats.append(self.embedding_shape(torch.argmax(g.ndata['shape'], dim=1)))
|
||||
h = {'obj': self.combine(torch.cat(feats, dim=1))}
|
||||
for l, conv in enumerate(self.layers):
|
||||
h = conv(g, h)
|
||||
if l != len(self.layers) - 1:
|
||||
h = {k: v.flatten(1) for k, v in h.items()}
|
||||
else:
|
||||
h = {k: v.mean(1) for k, v in h.items()}
|
||||
with g.local_scope():
|
||||
g.ndata['h'] = h['obj']
|
||||
out = dgl.mean_nodes(g, 'h')
|
||||
return out
|
||||
|
||||
# -------------------------------------------------------------------------------------------
|
||||
|
||||
class RGATv3(nn.Module):
|
||||
# multi-layer GNN for one single feature
|
||||
def __init__(
|
||||
self,
|
||||
hidden_channels,
|
||||
out_channels,
|
||||
num_heads,
|
||||
rel_names,
|
||||
dropout,
|
||||
n_layers,
|
||||
activation=nn.ELU(),
|
||||
residual=False
|
||||
):
|
||||
super().__init__()
|
||||
self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
|
||||
self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
|
||||
self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
|
||||
self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
|
||||
self.combine = nn.Sequential(
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_channels*num_heads*4, hidden_channels*num_heads)
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
dglnn.HeteroGraphConv({
|
||||
rel: dglnn.GATv2Conv(
|
||||
in_feats=hidden_channels*num_heads,
|
||||
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
|
||||
num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always
|
||||
feat_drop=dropout,
|
||||
attn_drop=dropout,
|
||||
residual=residual,
|
||||
activation=activation if l < n_layers - 1 else None
|
||||
)
|
||||
for rel in rel_names}, aggregate='sum')
|
||||
for l in range(n_layers)
|
||||
])
|
||||
|
||||
def forward(self, g):
|
||||
feats = []
|
||||
feats.append(self.embedding_type(g.ndata['type'].float()))
|
||||
feats.append(self.embedding_pos(g.ndata['pos']/170.)) # NOTE: this should be 180 because I remove the boundary walls!
|
||||
feats.append(self.embedding_color(g.ndata['color']/255.))
|
||||
feats.append(self.embedding_shape(g.ndata['shape'].float()))
|
||||
h = {'obj': self.combine(torch.cat(feats, dim=1))}
|
||||
for l, conv in enumerate(self.layers):
|
||||
h = conv(g, h)
|
||||
if l != len(self.layers) - 1:
|
||||
h = {k: v.flatten(1) for k, v in h.items()}
|
||||
else:
|
||||
h = {k: v.mean(1) for k, v in h.items()}
|
||||
with g.local_scope():
|
||||
g.ndata['h'] = h['obj']
|
||||
out = dgl.mean_nodes(g, 'h')
|
||||
return out
|
||||
|
||||
# -------------------------------------------------------------------------------------------
|
||||
|
||||
class RGCNv2(nn.Module):
|
||||
# multi-layer GNN for one single feature
|
||||
def __init__(
|
||||
self,
|
||||
hidden_channels,
|
||||
out_channels,
|
||||
rel_names,
|
||||
dropout,
|
||||
n_layers,
|
||||
activation=nn.ReLU()
|
||||
):
|
||||
super().__init__()
|
||||
self.embedding_type = nn.Linear(9, int(hidden_channels))
|
||||
self.embedding_pos = nn.Linear(2, int(hidden_channels))
|
||||
self.embedding_color = nn.Linear(3, int(hidden_channels))
|
||||
self.embedding_shape = nn.Linear(18, int(hidden_channels))
|
||||
self.combine = nn.Sequential(
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_channels*4, hidden_channels)
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
dglnn.RelGraphConv(
|
||||
in_feat=hidden_channels,
|
||||
out_feat=hidden_channels,
|
||||
num_rels=len(rel_names),
|
||||
regularizer=None,
|
||||
num_bases=None,
|
||||
bias=True,
|
||||
activation=activation,
|
||||
self_loop=True,
|
||||
dropout=dropout,
|
||||
layer_norm=False
|
||||
)
|
||||
for _ in range(n_layers-1)])
|
||||
self.layers.append(
|
||||
dglnn.RelGraphConv(
|
||||
in_feat=hidden_channels,
|
||||
out_feat=out_channels,
|
||||
num_rels=len(rel_names),
|
||||
regularizer=None,
|
||||
num_bases=None,
|
||||
bias=True,
|
||||
activation=activation,
|
||||
self_loop=True,
|
||||
dropout=dropout,
|
||||
layer_norm=False
|
||||
)
|
||||
)
|
||||
|
||||
def forward(self, g):
|
||||
g = g.to_homogeneous()
|
||||
feats = []
|
||||
feats.append(self.embedding_type(g.ndata['type'].float()))
|
||||
feats.append(self.embedding_pos(g.ndata['pos']/170.))
|
||||
feats.append(self.embedding_color(g.ndata['color']/255.))
|
||||
feats.append(self.embedding_shape(g.ndata['shape'].float()))
|
||||
h = self.combine(torch.cat(feats, dim=1))
|
||||
for conv in self.layers:
|
||||
h = conv(g, h, g.etypes)
|
||||
with g.local_scope():
|
||||
g.ndata['h'] = h
|
||||
out = dgl.mean_nodes(g, 'h')
|
||||
return out
|
||||
|
||||
# -------------------------------------------------------------------------------------------
|
||||
|
||||
class RGATv3Agent(nn.Module):
|
||||
# multi-layer GNN for one single feature
|
||||
def __init__(
|
||||
self,
|
||||
hidden_channels,
|
||||
out_channels,
|
||||
num_heads,
|
||||
rel_names,
|
||||
dropout,
|
||||
n_layers,
|
||||
activation=nn.ELU(),
|
||||
residual=False
|
||||
):
|
||||
super().__init__()
|
||||
self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
|
||||
self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
|
||||
self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
|
||||
self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
|
||||
self.combine = nn.Sequential(
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_channels*num_heads*4, hidden_channels*num_heads)
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
dglnn.HeteroGraphConv({
|
||||
rel: dglnn.GATv2Conv(
|
||||
in_feats=hidden_channels*num_heads,
|
||||
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
|
||||
num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always
|
||||
feat_drop=dropout,
|
||||
attn_drop=dropout,
|
||||
residual=residual,
|
||||
activation=activation if l < n_layers - 1 else None
|
||||
)
|
||||
for rel in rel_names}, aggregate='sum')
|
||||
for l in range(n_layers)
|
||||
])
|
||||
|
||||
def forward(self, g):
|
||||
agent_mask = g.ndata['type'][:, 0] == 1
|
||||
feats = []
|
||||
feats.append(self.embedding_type(g.ndata['type'].float()))
|
||||
feats.append(self.embedding_pos(g.ndata['pos']/200.))
|
||||
feats.append(self.embedding_color(g.ndata['color']/255.))
|
||||
feats.append(self.embedding_shape(g.ndata['shape'].float()))
|
||||
h = {'obj': self.combine(torch.cat(feats, dim=1))}
|
||||
for l, conv in enumerate(self.layers):
|
||||
h = conv(g, h)
|
||||
if l != len(self.layers) - 1:
|
||||
h = {k: v.flatten(1) for k, v in h.items()}
|
||||
else:
|
||||
h = {k: v.mean(1) for k, v in h.items()}
|
||||
with g.local_scope():
|
||||
g.ndata['h'] = h['obj']
|
||||
out = g.ndata['h'][agent_mask, :]
|
||||
ctx = dgl.mean_nodes(g, 'h')
|
||||
return out + ctx
|
||||
|
||||
# -------------------------------------------------------------------------------------------
|
||||
|
||||
class RGATv4(nn.Module):
|
||||
# multi-layer GNN for one single feature
|
||||
def __init__(
|
||||
self,
|
||||
hidden_channels,
|
||||
out_channels,
|
||||
num_heads,
|
||||
rel_names,
|
||||
dropout,
|
||||
n_layers,
|
||||
activation=nn.ELU(),
|
||||
residual=False
|
||||
):
|
||||
super().__init__()
|
||||
self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
|
||||
self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
|
||||
self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
|
||||
self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
|
||||
self.combine_attr = nn.Sequential(
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_channels*num_heads*3, hidden_channels*num_heads)
|
||||
)
|
||||
self.combine_pos = nn.Sequential(
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_channels*num_heads*2, hidden_channels*num_heads)
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
dglnn.HeteroGraphConv({
|
||||
rel: dglnn.GATv2Conv(
|
||||
in_feats=hidden_channels*num_heads,
|
||||
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
|
||||
num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always
|
||||
feat_drop=dropout,
|
||||
attn_drop=dropout,
|
||||
residual=residual,
|
||||
share_weights=False,
|
||||
activation=activation if l < n_layers - 1 else None
|
||||
)
|
||||
for rel in rel_names}, aggregate='sum')
|
||||
for l in range(n_layers)
|
||||
])
|
||||
|
||||
def forward(self, g):
|
||||
attr = []
|
||||
attr.append(self.embedding_type(g.ndata['type'].float()))
|
||||
pos = self.embedding_pos(g.ndata['pos']/170.)
|
||||
attr.append(self.embedding_color(g.ndata['color']/255.))
|
||||
attr.append(self.embedding_shape(g.ndata['shape'].float()))
|
||||
combined_attr = self.combine_attr(torch.cat(attr, dim=1))
|
||||
h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
|
||||
for l, conv in enumerate(self.layers):
|
||||
h = conv(g, h)
|
||||
if l != len(self.layers) - 1:
|
||||
h = {k: v.flatten(1) for k, v in h.items()}
|
||||
else:
|
||||
h = {k: v.mean(1) for k, v in h.items()}
|
||||
with g.local_scope():
|
||||
g.ndata['h'] = h['obj']
|
||||
out = dgl.mean_nodes(g, 'h')
|
||||
return out
|
||||
|
||||
# -------------------------------------------------------------------------------------------
|
||||
|
||||
class RGATv4Norm(nn.Module):
|
||||
# multi-layer GNN for one single feature
|
||||
def __init__(
|
||||
self,
|
||||
hidden_channels,
|
||||
out_channels,
|
||||
num_heads,
|
||||
rel_names,
|
||||
dropout,
|
||||
n_layers,
|
||||
activation=nn.ELU(),
|
||||
residual=False
|
||||
):
|
||||
super().__init__()
|
||||
self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
|
||||
self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
|
||||
self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
|
||||
self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
|
||||
self.combine_attr = nn.Sequential(
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_channels*num_heads*3, hidden_channels*num_heads)
|
||||
)
|
||||
self.combine_pos = nn.Sequential(
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_channels*num_heads*2, hidden_channels*num_heads)
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
dglnn.HeteroGraphConv({
|
||||
rel: dglnn.GATv2Conv(
|
||||
in_feats=hidden_channels*num_heads,
|
||||
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
|
||||
num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always
|
||||
feat_drop=dropout,
|
||||
attn_drop=dropout,
|
||||
residual=residual,
|
||||
activation=activation if l < n_layers - 1 else None
|
||||
)
|
||||
for rel in rel_names}, aggregate='sum')
|
||||
for l in range(n_layers)
|
||||
])
|
||||
self.norms = nn.ModuleList([
|
||||
Norm(
|
||||
norm_type='gn',
|
||||
hidden_dim=hidden_channels*num_heads if l < n_layers - 1 else out_channels
|
||||
)
|
||||
for l in range(n_layers)
|
||||
])
|
||||
|
||||
def forward(self, g):
|
||||
attr = []
|
||||
attr.append(self.embedding_type(g.ndata['type'].float()))
|
||||
pos = self.embedding_pos(g.ndata['pos']/170.)
|
||||
attr.append(self.embedding_color(g.ndata['color']/255.))
|
||||
attr.append(self.embedding_shape(g.ndata['shape'].float()))
|
||||
combined_attr = self.combine_attr(torch.cat(attr, dim=1))
|
||||
h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
|
||||
for l, conv in enumerate(self.layers):
|
||||
h = conv(g, h)
|
||||
if l != len(self.layers) - 1:
|
||||
h = {k: v.flatten(1) for k, v in h.items()}
|
||||
else:
|
||||
h = {k: v.mean(1) for k, v in h.items()}
|
||||
h = {k: self.norms[l](g, v) for k, v in h.items()}
|
||||
with g.local_scope():
|
||||
g.ndata['h'] = h['obj']
|
||||
out = dgl.mean_nodes(g, 'h')
|
||||
return out
|
||||
|
||||
# -------------------------------------------------------------------------------------------
|
||||
|
||||
class RGATv3Norm(nn.Module):
|
||||
# multi-layer GNN for one single feature
|
||||
def __init__(
|
||||
self,
|
||||
hidden_channels,
|
||||
out_channels,
|
||||
num_heads,
|
||||
rel_names,
|
||||
dropout,
|
||||
n_layers,
|
||||
activation=nn.ELU(),
|
||||
residual=False
|
||||
):
|
||||
super().__init__()
|
||||
self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
|
||||
self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
|
||||
self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
|
||||
self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
|
||||
self.combine = nn.Sequential(
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_channels*num_heads*4, hidden_channels*num_heads)
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
dglnn.HeteroGraphConv({
|
||||
rel: dglnn.GATv2Conv(
|
||||
in_feats=hidden_channels*num_heads,
|
||||
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
|
||||
num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always
|
||||
feat_drop=dropout,
|
||||
attn_drop=dropout,
|
||||
residual=residual,
|
||||
activation=activation if l < n_layers - 1 else None
|
||||
)
|
||||
for rel in rel_names}, aggregate='sum')
|
||||
for l in range(n_layers)
|
||||
])
|
||||
self.norms = nn.ModuleList([
|
||||
Norm(
|
||||
norm_type='gn',
|
||||
hidden_dim=hidden_channels*num_heads if l < n_layers - 1 else out_channels
|
||||
)
|
||||
for l in range(n_layers)
|
||||
])
|
||||
|
||||
def forward(self, g):
|
||||
feats = []
|
||||
feats.append(self.embedding_type(g.ndata['type'].float()))
|
||||
feats.append(self.embedding_pos(g.ndata['pos']/170.))
|
||||
feats.append(self.embedding_color(g.ndata['color']/255.))
|
||||
feats.append(self.embedding_shape(g.ndata['shape'].float()))
|
||||
h = {'obj': self.combine(torch.cat(feats, dim=1))}
|
||||
for l, conv in enumerate(self.layers):
|
||||
h = conv(g, h)
|
||||
if l != len(self.layers) - 1:
|
||||
h = {k: v.flatten(1) for k, v in h.items()}
|
||||
else:
|
||||
h = {k: v.mean(1) for k, v in h.items()}
|
||||
h = {k: self.norms[l](g, v) for k, v in h.items()}
|
||||
with g.local_scope():
|
||||
g.ndata['h'] = h['obj']
|
||||
out = dgl.mean_nodes(g, 'h')
|
||||
return out
|
||||
|
||||
# -------------------------------------------------------------------------------------------
|
||||
|
||||
class RGATv4Agent(nn.Module):
|
||||
# multi-layer GNN for one single feature
|
||||
def __init__(
|
||||
self,
|
||||
hidden_channels,
|
||||
out_channels,
|
||||
num_heads,
|
||||
rel_names,
|
||||
dropout,
|
||||
n_layers,
|
||||
activation=nn.ELU(),
|
||||
residual=False
|
||||
):
|
||||
super().__init__()
|
||||
self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
|
||||
self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
|
||||
self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
|
||||
self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
|
||||
self.combine_attr = nn.Sequential(
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_channels*num_heads*3, hidden_channels*num_heads)
|
||||
)
|
||||
self.combine_pos = nn.Sequential(
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_channels*num_heads*2, hidden_channels*num_heads)
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
dglnn.HeteroGraphConv({
|
||||
rel: dglnn.GATv2Conv(
|
||||
in_feats=hidden_channels*num_heads,
|
||||
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
|
||||
num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always
|
||||
feat_drop=dropout,
|
||||
attn_drop=dropout,
|
||||
residual=residual,
|
||||
activation=activation if l < n_layers - 1 else None
|
||||
)
|
||||
for rel in rel_names}, aggregate='sum')
|
||||
for l in range(n_layers)
|
||||
])
|
||||
self.combine_agent_context = nn.Linear(out_channels*2, out_channels)
|
||||
|
||||
def forward(self, g):
|
||||
agent_mask = g.ndata['type'][:, 0] == 1
|
||||
attr = []
|
||||
attr.append(self.embedding_type(g.ndata['type'].float()))
|
||||
pos = self.embedding_pos(g.ndata['pos']/170.)
|
||||
attr.append(self.embedding_color(g.ndata['color']/255.))
|
||||
attr.append(self.embedding_shape(g.ndata['shape'].float()))
|
||||
combined_attr = self.combine_attr(torch.cat(attr, dim=1))
|
||||
h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
|
||||
for l, conv in enumerate(self.layers):
|
||||
h = conv(g, h)
|
||||
if l != len(self.layers) - 1:
|
||||
h = {k: v.flatten(1) for k, v in h.items()}
|
||||
else:
|
||||
h = {k: v.mean(1) for k, v in h.items()}
|
||||
with g.local_scope():
|
||||
g.ndata['h'] = h['obj']
|
||||
h_a = g.ndata['h'][agent_mask, :]
|
||||
g_no_agent = copy.deepcopy(g)
|
||||
g_no_agent.remove_nodes([i for i, x in enumerate(agent_mask) if x])
|
||||
h_g = dgl.mean_nodes(g_no_agent, 'h')
|
||||
out = self.combine_agent_context(torch.cat((h_a, h_g), dim=1))
|
||||
return out
|
||||
|
||||
# -------------------------------------------------------------------------------------------
|
||||
|
||||
class RGCNv4(nn.Module):
|
||||
# multi-layer GNN for one single feature
|
||||
def __init__(
|
||||
self,
|
||||
hidden_channels,
|
||||
out_channels,
|
||||
rel_names,
|
||||
n_layers,
|
||||
activation=nn.ELU(),
|
||||
):
|
||||
super().__init__()
|
||||
self.embedding_type = nn.Linear(9, int(hidden_channels))
|
||||
self.embedding_pos = nn.Linear(2, int(hidden_channels))
|
||||
self.embedding_color = nn.Linear(3, int(hidden_channels))
|
||||
self.embedding_shape = nn.Linear(18, int(hidden_channels))
|
||||
self.combine_attr = nn.Sequential(
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_channels*3, hidden_channels)
|
||||
)
|
||||
self.combine_pos = nn.Sequential(
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_channels*2, hidden_channels)
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
dglnn.HeteroGraphConv({
|
||||
rel: dglnn.GraphConv(
|
||||
in_feats=hidden_channels,
|
||||
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
|
||||
activation=activation if l < n_layers - 1 else None
|
||||
)
|
||||
for rel in rel_names}, aggregate='sum')
|
||||
for l in range(n_layers)
|
||||
])
|
||||
|
||||
def forward(self, g):
|
||||
attr = []
|
||||
attr.append(self.embedding_type(g.ndata['type'].float()))
|
||||
pos = self.embedding_pos(g.ndata['pos']/170.)
|
||||
attr.append(self.embedding_color(g.ndata['color']/255.))
|
||||
attr.append(self.embedding_shape(g.ndata['shape'].float()))
|
||||
combined_attr = self.combine_attr(torch.cat(attr, dim=1))
|
||||
h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
|
||||
for l, conv in enumerate(self.layers):
|
||||
h = conv(g, h)
|
||||
with g.local_scope():
|
||||
g.ndata['h'] = h['obj']
|
||||
out = dgl.mean_nodes(g, 'h')
|
||||
return out
|
||||
|
||||
# -------------------------------------------------------------------------------------------
|
||||
|
||||
class RGATv5(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_channels,
|
||||
out_channels,
|
||||
num_heads,
|
||||
rel_names,
|
||||
dropout,
|
||||
n_layers,
|
||||
activation=nn.ELU(),
|
||||
residual=False
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_channels = hidden_channels
|
||||
self.num_heads = num_heads
|
||||
self.embedding_type = nn.Linear(9, hidden_channels*num_heads)
|
||||
self.embedding_pos = nn.Linear(2, hidden_channels*num_heads)
|
||||
self.embedding_color = nn.Linear(3, hidden_channels*num_heads)
|
||||
self.embedding_shape = nn.Linear(18, hidden_channels*num_heads)
|
||||
self.combine = nn.Sequential(
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_channels*num_heads*4, hidden_channels*num_heads)
|
||||
)
|
||||
self.attention = nn.Linear(hidden_channels*num_heads*4, 4)
|
||||
self.layers = nn.ModuleList([
|
||||
dglnn.HeteroGraphConv({
|
||||
rel: dglnn.GATv2Conv(
|
||||
in_feats=hidden_channels*num_heads,
|
||||
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
|
||||
num_heads=num_heads,
|
||||
feat_drop=dropout,
|
||||
attn_drop=dropout,
|
||||
residual=residual,
|
||||
activation=activation if l < n_layers - 1 else None
|
||||
)
|
||||
for rel in rel_names}, aggregate='sum')
|
||||
for l in range(n_layers)
|
||||
])
|
||||
|
||||
def forward(self, g):
|
||||
feats = []
|
||||
feats.append(self.embedding_type(g.ndata['type'].float()))
|
||||
feats.append(self.embedding_pos(g.ndata['pos']/170.))
|
||||
feats.append(self.embedding_color(g.ndata['color']/255.))
|
||||
feats.append(self.embedding_shape(g.ndata['shape'].float()))
|
||||
h = torch.cat(feats, dim=1)
|
||||
feat_attn = F.softmax(self.attention(h), dim=1)
|
||||
h = h * feat_attn.repeat_interleave(self.hidden_channels*self.num_heads, dim=1)
|
||||
h_in = self.combine(h)
|
||||
h = {'obj': h_in}
|
||||
for conv in self.layers:
|
||||
h = conv(g, h)
|
||||
h = {k: v.flatten(1) for k, v in h.items()}
|
||||
#if l != len(self.layers) - 1:
|
||||
# h = {k: v.flatten(1) for k, v in h.items()}
|
||||
#else:
|
||||
# h = {k: v.mean(1) for k, v in h.items()}
|
||||
h = {k: v + h_in for k, v in h.items()}
|
||||
with g.local_scope():
|
||||
g.ndata['h'] = h['obj']
|
||||
out = dgl.mean_nodes(g, 'h')
|
||||
return out
|
||||
|
||||
# -------------------------------------------------------------------------------------------
|
||||
|
||||
class RGATv6(nn.Module):
|
||||
|
||||
# RGATv6 = RGATv4 + Global Attention Pooling
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_channels,
|
||||
out_channels,
|
||||
num_heads,
|
||||
rel_names,
|
||||
dropout,
|
||||
n_layers,
|
||||
activation=nn.ELU(),
|
||||
residual=False
|
||||
):
|
||||
super().__init__()
|
||||
self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
|
||||
self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
|
||||
self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
|
||||
self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
|
||||
self.combine_attr = nn.Sequential(
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_channels*num_heads*3, hidden_channels*num_heads)
|
||||
)
|
||||
self.combine_pos = nn.Sequential(
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_channels*num_heads*2, hidden_channels*num_heads)
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
dglnn.HeteroGraphConv({
|
||||
rel: dglnn.GATv2Conv(
|
||||
in_feats=hidden_channels*num_heads,
|
||||
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
|
||||
num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always
|
||||
feat_drop=dropout,
|
||||
attn_drop=dropout,
|
||||
residual=residual,
|
||||
activation=activation if l < n_layers - 1 else None
|
||||
)
|
||||
for rel in rel_names}, aggregate='sum')
|
||||
for l in range(n_layers)
|
||||
])
|
||||
gate_nn = nn.Linear(out_channels, 1)
|
||||
self.gap = dglnn.GlobalAttentionPooling(gate_nn)
|
||||
|
||||
def forward(self, g):
|
||||
attr = []
|
||||
attr.append(self.embedding_type(g.ndata['type'].float()))
|
||||
pos = self.embedding_pos(g.ndata['pos']/170.)
|
||||
attr.append(self.embedding_color(g.ndata['color']/255.))
|
||||
attr.append(self.embedding_shape(g.ndata['shape'].float()))
|
||||
combined_attr = self.combine_attr(torch.cat(attr, dim=1))
|
||||
h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
|
||||
for l, conv in enumerate(self.layers):
|
||||
h = conv(g, h)
|
||||
if l != len(self.layers) - 1:
|
||||
h = {k: v.flatten(1) for k, v in h.items()}
|
||||
else:
|
||||
h = {k: v.mean(1) for k, v in h.items()}
|
||||
#with g.local_scope():
|
||||
#g.ndata['h'] = h['obj']
|
||||
#out = dgl.mean_nodes(g, 'h')
|
||||
out = self.gap(g, h['obj'])
|
||||
return out
|
||||
|
||||
# -------------------------------------------------------------------------------------------
|
||||
|
||||
class RGATv6Agent(nn.Module):
|
||||
|
||||
# RGATv6 = RGATv4 + Global Attention Pooling
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_channels,
|
||||
out_channels,
|
||||
num_heads,
|
||||
rel_names,
|
||||
dropout,
|
||||
n_layers,
|
||||
activation=nn.ELU(),
|
||||
residual=False
|
||||
):
|
||||
super().__init__()
|
||||
self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
|