This commit is contained in:
Matteo Bortoletto 2024-02-01 15:40:47 +01:00
parent a333481e05
commit de0bea7508
18 changed files with 3150 additions and 2 deletions

View file

@@ -1,3 +1,75 @@
# IRENE
<div align="center">
<h1> Neural Reasoning about Agents' Goals, Preferences, and Actions </h1>
Official code for "Neural Reasoning About Agents' Goals, Preferences, and Actions".
**[Matteo Bortoletto][1], &nbsp; [Lei Shi][2], &nbsp; [Andreas Bulling][3]** <br> <br>
**AAAI'24, Vancouver, Canada** <br>
**[[Paper][4]]**
</div>
# Citation
If you find our code useful or use it in your own projects, please cite our paper:
```bibtex
@inproceedings{bortoletto2024neural,
author = {Bortoletto, Matteo and Shi, Lei and Bulling, Andreas},
title = {{Neural Reasoning about Agents' Goals, Preferences, and Actions}},
booktitle = {Proc. 38th AAAI Conference on Artificial Intelligence (AAAI)},
year = {2024},
}
```
# Setup
This code is based on the [original implementation][5] of the BIB benchmark.
## Using `virtualenv`
```
python -m virtualenv /path/to/env
source /path/to/env/bin/activate
pip install -r requirements.txt
```
## Using `conda`
```
conda create --name <env_name> python=3.8.10 pip=20.0.2 cudatoolkit=10.2.89
conda activate <env_name>
pip install -r requirements_conda.txt
pip install dgl-cu102 dglgo -f https://data.dgl.ai/wheels/repo.html
```
# Running the code
## Activate the environment
Run `source /path/to/env/bin/activate` (or `conda activate <env_name>` if you used conda).
## Index data
This will create the JSON files with the indexed frames for each episode in each video:
```
python utils/index_data.py
```
You need to manually set `mode` in the dataset class (in the `__main__` section of the script).
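For example, setting `mode = 'train'` there and re-running the command above indexes the training split; repeat with `'val'` and `'test'` for the other splits (the exact location of the assignment is defined in `utils/index_data.py`).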
## Generate graphs
This will generate the graphs from the videos:
```
python utils/build_graphs.py --mode MODE --cpus NUM_CPUS
```
`MODE` can be `train`, `val`, or `test`. NOTE: check `utils/build_graphs.py` to make sure you are loading the correct dataset for the graphs you want to generate.
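For example, to build the training graphs with 8 worker processes:
```
python utils/build_graphs.py --mode train --cpus 8
```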
## Training
You can use the `gtbc.sh` script.
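This commit also adds `run_train.sh`, which wraps `train_tom.py`; a trimmed invocation (with the dataset path replaced by a placeholder) looks like this:
```
CUDA_VISIBLE_DEVICES=0 python train_tom.py \
    --model_type graphbcrnn \
    --types single_object preference instrumental_action \
    --data_path /path/to/bib_train/graphs/all_tasks/ \
    --batch_size 32 \
    --max_epochs 35 \
    --gnn_type RSAGEv4 \
    --state_dim 96 \
    --aggregation sum
```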
## Testing
Use `run_test_tom.sh`.
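`run_test.sh`, also added in this commit, shows the expected arguments of `test_tom.py`; the checkpoint and data paths below are placeholders:
```
CUDA_VISIBLE_DEVICES=0 python test_tom.py \
    --model_type graphbcrnn \
    --types efficiency_irrational \
    --ckpt /path/to/checkpoint.ckpt \
    --data_path /path/to/bib_evaluation/graphs/all_tasks \
    --process_data 0 \
    --surprise_type max
```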
# Hardware setup
All models were trained on a single NVIDIA Tesla V100-SXM2-32GB GPU.
[1]: https://mattbortoletto.github.io/
[2]: https://perceptualui.org/people/shi/
[3]: https://perceptualui.org/people/bulling/
[4]: https://perceptualui.org/publications/bortoletto24_aaai.pdf
[5]: https://github.com/kanishkg/bib-baselines

9
run_test.sh Normal file
View file

@@ -0,0 +1,9 @@
echo 314 e31
CUDA_VISIBLE_DEVICES=1 python test_tom.py \
--model_type graphbcrnn \
--types efficiency_irrational \
--ckpt /projects/bortoletto/icml2023_matteo/wandb/run-20221224_135525-8i1r2aqy/files/bib/8i1r2aqy/checkpoints/epoch\=31-step\=22399.ckpt \
--data_path /datasets/external/bib_evaluation_1_1/graphs/all_tasks \
--process_data 0 \
--surprise_type max

18
run_train.sh Normal file
View file

@@ -0,0 +1,18 @@
CUDA_VISIBLE_DEVICES=0 python train_tom.py \
--model_type graphbcrnn \
--types single_object preference instrumental_action \
--data_path /datasets/external/bib_train/graphs/all_tasks/ \
--seed 7 \
--batch_size 32 \
--max_epochs 35 \
--gpus 1 \
--auto_select_gpus True \
--num_workers 2 \
--stochastic_weight_avg True \
--lr 5e-4 \
--check_val_every_n_epoch 1 \
--track_grad_norm 2 \
--gradient_clip_val 10 \
--gnn_type RSAGEv4 \
--state_dim 96 \
--aggregation sum

122
test_tom.py Normal file
View file

@@ -0,0 +1,122 @@
from argparse import ArgumentParser
import numpy as np
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
import dgl
from tom.dataset import TestToMnetDGLDataset, collate_function_seq_test
from tom.model import GraphBC_T, GraphBCRNN
def get_z_scores(total, total_expected, total_unexpected):
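# prints the z-scores of the expected and unexpected surprise values w.r.t. the pooled surprise distribution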
mean = np.mean(total)
std = np.std(total)
print("Z-Score expected: ",
(np.mean(total_expected) - mean) / std)
print("Z-Score unexpected: ",
(np.mean(total_unexpected) - mean) / std)
parser = ArgumentParser()
parser.add_argument('--model_type', type=str, default='graphbcrnn')
parser.add_argument('--ckpt', type=str, default=None, help='path to checkpoint')
parser.add_argument('--data_path', type=str, default=None, help='path to the data')
parser.add_argument('--process_data', type=int, default=0)
parser.add_argument('--surprise_type', type=str, default='max',
help='surprise type: mean, max. This is used for comparing the plausibility scores of the two test episodes')
parser.add_argument('--types', nargs='+', type=str,
default=[
'preference', 'multi_agent', 'inaccessible_goal',
'efficiency_irrational', 'efficiency_time','efficiency_path',
'instrumental_no_barrier', 'instrumental_blocking_barrier', 'instrumental_inconsequential_barrier'
],
help='types of tasks used for training / testing')
parser.add_argument('--filename', type=str, default='')
args = parser.parse_args()
filename = args.filename
if args.model_type == 'graphbct':
model = GraphBC_T.load_from_checkpoint(args.ckpt)
elif args.model_type == 'graphbcrnn':
model = GraphBCRNN.load_from_checkpoint(args.ckpt)
else:
raise ValueError('Unknown model type.')
device = 'cuda'
model.to(device)
model.eval()
with torch.no_grad():
for t in args.types:
if args.model_type == 'graphbcrnn':
test_dataset = TestToMnetDGLDataset(
path=args.data_path,
task_type=t,
mode='test'
)
test_dataloader = DataLoader(
test_dataset,
batch_size=1,
num_workers=1,
pin_memory=True,
collate_fn=collate_function_seq_test,
shuffle=False
)
count = 0
total, total_expected, total_unexpected = [], [], []
pbar = tqdm(test_dataloader)
for j, batch in enumerate(pbar):
if args.model_type == 'graphbcrnn':
dem_expected_states, dem_expected_actions, dem_expected_lens, \
dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, \
query_expected_frames, target_expected_actions, \
query_unexpected_frames, target_unexpected_actions = batch
dem_expected_states = dem_expected_states.to(device)
dem_expected_actions = dem_expected_actions.to(device)
dem_unexpected_states = dem_unexpected_states.to(device)
dem_unexpected_actions = dem_unexpected_actions.to(device)
target_expected_actions = target_expected_actions.to(device)
target_unexpected_actions = target_unexpected_actions.to(device)
surprise_expected = []
query_expected_frames = dgl.unbatch(query_expected_frames)
for i in range(len(query_expected_frames)):
if args.model_type == 'graphbcrnn':
test_actions, test_actions_pred = model(
[dem_expected_states, dem_expected_actions, dem_expected_lens, query_expected_frames[i].to(device), target_expected_actions[:, i, :]]
)
loss = F.mse_loss(test_actions, test_actions_pred)
surprise_expected.append(loss.cpu().detach().numpy())
mean_expected_surprise = np.mean(surprise_expected)
max_expected_surprise = np.max(surprise_expected)
# calculate the plausibility scores for the unexpected episode
surprise_unexpected = []
query_unexpected_frames = dgl.unbatch(query_unexpected_frames)
for i in range(len(query_unexpected_frames)):
if args.model_type == 'graphbcrnn':
test_actions, test_actions_pred = model(
[dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, query_unexpected_frames[i].to(device), target_unexpected_actions[:, i, :]]
)
loss = F.mse_loss(test_actions, test_actions_pred)
surprise_unexpected.append(loss.cpu().detach().numpy())
mean_unexpected_surprise = np.mean(surprise_unexpected)
max_unexpected_surprise = np.max(surprise_unexpected)
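# score 1 if the expected episode is less surprising than the unexpected one, 0.5 in case of a tie, 0 otherwise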
correct_mean = (mean_expected_surprise < mean_unexpected_surprise) + 0.5 * (mean_expected_surprise == mean_unexpected_surprise)
correct_max = (max_expected_surprise < max_unexpected_surprise) + 0.5 * (max_expected_surprise == max_unexpected_surprise)
if args.surprise_type == 'max':
count += correct_max
elif args.surprise_type == 'mean':
count += correct_mean
pbar.set_postfix({'accuracy': count/(j+1.), 'type': t})
total_expected.append(max_expected_surprise)
total_unexpected.append(max_unexpected_surprise)
total.append(max_expected_surprise)
total.append(max_unexpected_surprise)
get_z_scores(total, total_expected, total_unexpected)

0
tom/__init__.py Normal file
View file

310
tom/dataset.py Normal file
View file

@@ -0,0 +1,310 @@
import pickle as pkl
import os
import torch
import torch.utils.data
import torch.nn.functional as F
import dgl
import random
from dgl.data import DGLDataset
def collate_function_seq(batch):
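# demonstration and query graphs are merged with dgl.batch; the corresponding actions are stacked into tensors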
#dem_frames = torch.stack([item[0] for item in batch])
dem_frames = dgl.batch([item[0] for item in batch])
dem_actions = torch.stack([item[1] for item in batch])
dem_lens = [item[2] for item in batch]
#query_frames = torch.stack([item[3] for item in batch])
query_frames = dgl.batch([item[3] for item in batch])
target_actions = torch.stack([item[4] for item in batch])
return [dem_frames, dem_actions, dem_lens, query_frames, target_actions]
def collate_function_seq_test(batch):
dem_expected_states = dgl.batch([item[0] for item in batch][0])
dem_expected_actions = torch.stack([item[1] for item in batch][0]).unsqueeze(dim=0)
dem_expected_lens = [item[2] for item in batch]
#print(dem_expected_actions.size())
dem_unexpected_states = dgl.batch([item[3] for item in batch][0])
dem_unexpected_actions = torch.stack([item[4] for item in batch][0]).unsqueeze(dim=0)
dem_unexpected_lens = [item[5] for item in batch]
query_expected_frames = dgl.batch([item[6] for item in batch])
target_expected_actions = torch.stack([item[7] for item in batch])
#print(target_expected_actions.size())
query_unexpected_frames = dgl.batch([item[8] for item in batch])
target_unexpected_actions = torch.stack([item[9] for item in batch])
return [
dem_expected_states, dem_expected_actions, dem_expected_lens, \
dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, \
query_expected_frames, target_expected_actions, \
query_unexpected_frames, target_unexpected_actions
]
def collate_function_mental(batch):
dem_frames = dgl.batch([item[0] for item in batch])
dem_actions = torch.stack([item[1] for item in batch])
dem_lens = [item[2] for item in batch]
past_test_frames = dgl.batch([item[3] for item in batch])
past_test_actions = torch.stack([item[4] for item in batch])
past_test_len = [item[5] for item in batch]
query_frames = dgl.batch([item[6] for item in batch])
target_actions = torch.stack([item[7] for item in batch])
return [dem_frames, dem_actions, dem_lens, past_test_frames, past_test_actions, past_test_len, query_frames, target_actions]
class ToMnetDGLDataset(DGLDataset):
"""
Training dataset class.
"""
def __init__(self, path, types=None, mode="train"):
self.path = path
self.types = types
self.mode = mode
print('Mode:', self.mode)
if self.mode == 'train':
if len(self.types) == 4:
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/'
#self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_global/'
#self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_local/'
elif len(self.types) == 3:
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper() + '/'
print(self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper())
elif len(self.types) == 2:
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + '/'
print(self.types[0][0].upper() + self.types[1][0].upper())
elif len(self.types) == 1:
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + '/'
else: raise ValueError('Number of types must be between 1 and 4.')
elif self.mode == 'val':
assert len(self.types) == 1
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/' + self.types[0] + '/'
#self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_global/' + self.types[0] + '/'
#self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_local/' + self.types[0] + '/'
else:
raise ValueError
def get_test(self, states, actions):
# now states is a batched graph -> unbatch it, take the len, pick one sub-graph
# randomly and select the corresponding action
frame_graphs = dgl.unbatch(states)
trial_len = len(frame_graphs)
query_idx = random.randint(0, trial_len - 1)
query_graph = frame_graphs[query_idx]
target_action = actions[query_idx]
return query_graph, target_action
def __getitem__(self, idx):
with open(self.path+str(idx)+'.pkl', 'rb') as f:
states, actions, lens, _ = pkl.load(f)
# shuffle
ziplist = list(zip(states, actions, lens))
random.shuffle(ziplist)
states, actions, lens = zip(*ziplist)
# convert tuples to lists
states, actions, lens = [*states], [*actions], [*lens]
# pick last element in the list as test and pick random frame
test_s, test_a = self.get_test(states[-1], actions[-1])
dem_s = states[:-1]
dem_a = actions[:-1]
dem_lens = lens[:-1]
dem_s = dgl.batch(dem_s)
dem_a = torch.stack(dem_a)
return dem_s, dem_a, dem_lens, test_s, test_a
def __len__(self):
return len(os.listdir(self.path))
class TestToMnetDGLDataset(DGLDataset):
"""
Testing dataset class.
"""
def __init__(self, path, task_type=None, mode="test"):
self.path = path
self.type = task_type
self.mode = mode
print('Mode:', self.mode)
if self.mode == 'test':
self.path = self.path + '_dgl_hetero_nobound_4feats/' + self.type + '/'
#self.path = self.path + '_dgl_hetero_nobound_4feats_global/' + self.type + '/'
#self.path = self.path + '_dgl_hetero_nobound_4feats_local/' + self.type + '/'
else:
raise ValueError
def __getitem__(self, idx):
with open(self.path+str(idx)+'.pkl', 'rb') as f:
dem_expected_states, dem_expected_actions, dem_expected_lens, _, \
dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, _, \
query_expected_frames, target_expected_actions, \
query_unexpected_frames, target_unexpected_actions = pkl.load(f)
assert len(dem_expected_states) == 8
assert len(dem_expected_actions) == 8
assert len(dem_expected_lens) == 8
assert len(dem_unexpected_states) == 8
assert len(dem_unexpected_actions) == 8
assert len(dem_unexpected_lens) == 8
assert len(dgl.unbatch(query_expected_frames)) == target_expected_actions.size()[0]
assert len(dgl.unbatch(query_unexpected_frames)) == target_unexpected_actions.size()[0]
# ignore n_nodes
return dem_expected_states, dem_expected_actions, dem_expected_lens, \
dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, \
query_expected_frames, target_expected_actions, \
query_unexpected_frames, target_unexpected_actions
def __len__(self):
return len(os.listdir(self.path))
class ToMnetDGLDatasetUndersample(DGLDataset):
"""
Undersampled training dataset class.
"""
def __init__(self, path, types=None, mode="train"):
self.path = path
self.types = types
self.mode = mode
print('Mode:', self.mode)
if self.mode == 'train':
if len(self.types) == 4:
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/'
elif len(self.types) == 3:
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper() + '/'
print(self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper())
elif len(self.types) == 2:
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + '/'
print(self.types[0][0].upper() + self.types[1][0].upper())
elif len(self.types) == 1:
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + '/'
else: raise ValueError('Number of types must be between 1 and 4.')
elif self.mode == 'val':
assert len(self.types) == 1
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/' + self.types[0] + '/'
else:
raise ValueError
print('Undersampled dataset!')
def get_test(self, states, actions):
# now states is a batched graph -> unbatch it, take the len, pick one sub-graph
# randomly and select the corresponding action
frame_graphs = dgl.unbatch(states)
trial_len = len(frame_graphs)
query_idx = random.randint(0, trial_len - 1)
query_graph = frame_graphs[query_idx]
target_action = actions[query_idx]
return query_graph, target_action
def __getitem__(self, idx):
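# offset by 3175 so the first 3175 samples are skipped; __len__ subtracts the same amount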
idx = idx + 3175
with open(self.path+str(idx)+'.pkl', 'rb') as f:
states, actions, lens, _ = pkl.load(f)
# shuffle
ziplist = list(zip(states, actions, lens))
random.shuffle(ziplist)
states, actions, lens = zip(*ziplist)
# convert tuples to lists
states, actions, lens = [*states], [*actions], [*lens]
# pick last element in the list as test and pick random frame
test_s, test_a = self.get_test(states[-1], actions[-1])
dem_s = states[:-1]
dem_a = actions[:-1]
dem_lens = lens[:-1]
dem_s = dgl.batch(dem_s)
dem_a = torch.stack(dem_a)
return dem_s, dem_a, dem_lens, test_s, test_a
def __len__(self):
return len(os.listdir(self.path)) - 3175
class ToMnetDGLDatasetMental(DGLDataset):
"""
Training dataset class.
"""
def __init__(self, path, types=None, mode="train"):
self.path = path
self.types = types
self.mode = mode
print('Mode:', self.mode)
if self.mode == 'train':
if len(self.types) == 4:
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/'
elif len(self.types) == 3:
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper() + '/'
print(self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper())
elif len(self.types) == 2:
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + '/'
print(self.types[0][0].upper() + self.types[1][0].upper())
elif len(self.types) == 1:
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + '/'
else: raise ValueError('Number of types must be between 1 and 4.')
elif self.mode == 'val':
assert len(self.types) == 1
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/' + self.types[0] + '/'
else:
raise ValueError
def get_test(self, states, actions):
"""
return: past_test_graphs, past_test_actions, test_graph, test_action
"""
frame_graphs = dgl.unbatch(states)
trial_len = len(frame_graphs)
query_idx = random.randint(0, trial_len - 1)
test_graph = frame_graphs[query_idx]
test_action = actions[query_idx]
if query_idx > 0:
#past_test_actions = F.pad(past_test_actions, (0,0,0,41-query_idx), 'constant', 0)
if query_idx == 1:
past_test_graphs = frame_graphs[0]
past_test_actions = actions[:query_idx]
past_test_actions = F.pad(past_test_actions, (0,0,0,41-query_idx), 'constant', 0)
return past_test_graphs, past_test_actions, query_idx, test_graph, test_action
else:
past_test_graphs = frame_graphs[:query_idx]
past_test_actions = actions[:query_idx]
past_test_graphs = dgl.batch(past_test_graphs)
past_test_actions = F.pad(past_test_actions, (0,0,0,41-query_idx), 'constant', 0)
return past_test_graphs, past_test_actions, query_idx, test_graph, test_action
else:
test_action_ = F.pad(test_action.unsqueeze(0), (0,0,0,41-1), 'constant', 0)
# NOTE: since there are no past frames, return the test frame and action and query_idx=1
# not sure what would be a better alternative
return test_graph, test_action_, 1, test_graph, test_action
def __getitem__(self, idx):
with open(self.path+str(idx)+'.pkl', 'rb') as f:
states, actions, lens, _ = pkl.load(f)
ziplist = list(zip(states, actions, lens))
random.shuffle(ziplist)
states, actions, lens = zip(*ziplist)
states, actions, lens = [*states], [*actions], [*lens]
past_test_s, past_test_a, past_test_len, test_s, test_a = self.get_test(states[-1], actions[-1])
dem_s = states[:-1]
dem_a = actions[:-1]
dem_lens = lens[:-1]
dem_s = dgl.batch(dem_s)
dem_a = torch.stack(dem_a)
return dem_s, dem_a, dem_lens, past_test_s, past_test_a, past_test_len, test_s, test_a
def __len__(self):
return len(os.listdir(self.path))
# --------------------------------------------------------------------------------------------------------
if __name__ == "__main__":
types = [
'preference', 'multi_agent', 'inaccessible_goal',
'efficiency_irrational', 'efficiency_time','efficiency_path',
'instrumental_no_barrier', 'instrumental_blocking_barrier', 'instrumental_inconsequential_barrier'
]
mental_dataset = ToMnetDGLDatasetMental(
path='/datasets/external/bib_train/graphs/all_tasks/',
types=['instrumental_action'],
mode='train'
)
dem_frames, dem_actions, dem_lens, past_test_frames, past_test_actions, past_test_len, test_frame, test_action = mental_dataset.__getitem__(99)
breakpoint()

877
tom/gnn.py Normal file
View file

@@ -0,0 +1,877 @@
import dgl.nn.pytorch as dglnn
import torch.nn as nn
import torch.nn.functional as F
import torch
import dgl
import sys
import copy
sys.path.append('/projects/bortoletto/irene/')
from tom.norm import Norm
class RSAGEv4(nn.Module):
# multi-layer GNN for one single feature
def __init__(
self,
hidden_channels,
out_channels,
rel_names,
dropout,
n_layers,
activation=nn.ELU(),
):
super().__init__()
self.embedding_type = nn.Linear(9, int(hidden_channels))
self.embedding_pos = nn.Linear(2, int(hidden_channels))
self.embedding_color = nn.Linear(3, int(hidden_channels))
self.embedding_shape = nn.Linear(18, int(hidden_channels))
self.combine_attr = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*3, hidden_channels)
)
self.combine_pos = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*2, hidden_channels)
)
self.layers = nn.ModuleList([
dglnn.HeteroGraphConv({
rel: dglnn.SAGEConv(
in_feats=hidden_channels,
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
aggregator_type='lstm',
feat_drop=dropout,
activation=activation if l < n_layers - 1 else None
)
for rel in rel_names}, aggregate='sum')
for l in range(n_layers)
])
def forward(self, g):
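# embed node type/position/color/shape, fuse the attributes with the position embedding, run the relational SAGE layers, then mean-pool node states over each graph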
attr = []
attr.append(self.embedding_type(g.ndata['type'].float()))
pos = self.embedding_pos(g.ndata['pos']/170.)
attr.append(self.embedding_color(g.ndata['color']/255.))
attr.append(self.embedding_shape(g.ndata['shape'].float()))
combined_attr = self.combine_attr(torch.cat(attr, dim=1))
h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
for l, conv in enumerate(self.layers):
h = conv(g, h)
with g.local_scope():
g.ndata['h'] = h['obj']
out = dgl.mean_nodes(g, 'h')
return out
# -------------------------------------------------------------------------------------------
class RAGNNv4(nn.Module):
# multi-layer GNN for one single feature
def __init__(
self,
hidden_channels,
out_channels,
rel_names,
dropout,
n_layers,
activation=nn.ELU(),
):
super().__init__()
self.embedding_type = nn.Linear(9, int(hidden_channels))
self.embedding_pos = nn.Linear(2, int(hidden_channels))
self.embedding_color = nn.Linear(3, int(hidden_channels))
self.embedding_shape = nn.Linear(18, int(hidden_channels))
self.combine_attr = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*3, hidden_channels)
)
self.combine_pos = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*2, hidden_channels)
)
self.layers = nn.ModuleList([
dglnn.HeteroGraphConv({
rel: dglnn.AGNNConv()
for rel in rel_names}, aggregate='sum')
for l in range(n_layers)
])
def forward(self, g):
attr = []
attr.append(self.embedding_type(g.ndata['type'].float()))
pos = self.embedding_pos(g.ndata['pos']/170.)
attr.append(self.embedding_color(g.ndata['color']/255.))
attr.append(self.embedding_shape(g.ndata['shape'].float()))
combined_attr = self.combine_attr(torch.cat(attr, dim=1))
h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
for l, conv in enumerate(self.layers):
h = conv(g, h)
with g.local_scope():
g.ndata['h'] = h['obj']
out = dgl.mean_nodes(g, 'h')
return out
# -------------------------------------------------------------------------------------------
class RGATv2(nn.Module):
# multi-layer GNN for one single feature
def __init__(
self,
hidden_channels,
out_channels,
num_heads,
rel_names,
dropout,
n_layers,
activation=nn.ELU(),
residual=False
):
super().__init__()
self.embedding_type = nn.Embedding(9, int(hidden_channels*num_heads/4))
self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads/4))
self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads/4))
self.embedding_shape = nn.Embedding(18, int(hidden_channels*num_heads/4))
self.combine = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*num_heads, hidden_channels*num_heads)
)
self.layers = nn.ModuleList([
dglnn.HeteroGraphConv({
rel: dglnn.GATv2Conv(
in_feats=hidden_channels*num_heads,
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
num_heads=num_heads if l < n_layers - 1 else 1,
feat_drop=dropout,
attn_drop=dropout,
residual=residual,
activation=activation if l < n_layers - 1 else None
)
for rel in rel_names}, aggregate='sum')
for l in range(n_layers)
])
def forward(self, g):
feats = []
feats.append(self.embedding_type(torch.argmax(g.ndata['type'], dim=1)))
feats.append(self.embedding_pos(g.ndata['pos']/170.))
feats.append(self.embedding_color(g.ndata['color']/255.))
feats.append(self.embedding_shape(torch.argmax(g.ndata['shape'], dim=1)))
h = {'obj': self.combine(torch.cat(feats, dim=1))}
for l, conv in enumerate(self.layers):
h = conv(g, h)
if l != len(self.layers) - 1:
h = {k: v.flatten(1) for k, v in h.items()}
else:
h = {k: v.mean(1) for k, v in h.items()}
with g.local_scope():
g.ndata['h'] = h['obj']
out = dgl.mean_nodes(g, 'h')
return out
# -------------------------------------------------------------------------------------------
class RGATv3(nn.Module):
# multi-layer GNN for one single feature
def __init__(
self,
hidden_channels,
out_channels,
num_heads,
rel_names,
dropout,
n_layers,
activation=nn.ELU(),
residual=False
):
super().__init__()
self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
self.combine = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*num_heads*4, hidden_channels*num_heads)
)
self.layers = nn.ModuleList([
dglnn.HeteroGraphConv({
rel: dglnn.GATv2Conv(
in_feats=hidden_channels*num_heads,
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always
feat_drop=dropout,
attn_drop=dropout,
residual=residual,
activation=activation if l < n_layers - 1 else None
)
for rel in rel_names}, aggregate='sum')
for l in range(n_layers)
])
def forward(self, g):
feats = []
feats.append(self.embedding_type(g.ndata['type'].float()))
feats.append(self.embedding_pos(g.ndata['pos']/170.)) # NOTE: this should be 180 because I remove the boundary walls!
feats.append(self.embedding_color(g.ndata['color']/255.))
feats.append(self.embedding_shape(g.ndata['shape'].float()))
h = {'obj': self.combine(torch.cat(feats, dim=1))}
for l, conv in enumerate(self.layers):
h = conv(g, h)
if l != len(self.layers) - 1:
h = {k: v.flatten(1) for k, v in h.items()}
else:
h = {k: v.mean(1) for k, v in h.items()}
with g.local_scope():
g.ndata['h'] = h['obj']
out = dgl.mean_nodes(g, 'h')
return out
# -------------------------------------------------------------------------------------------
class RGCNv2(nn.Module):
# multi-layer GNN for one single feature
def __init__(
self,
hidden_channels,
out_channels,
rel_names,
dropout,
n_layers,
activation=nn.ReLU()
):
super().__init__()
self.embedding_type = nn.Linear(9, int(hidden_channels))
self.embedding_pos = nn.Linear(2, int(hidden_channels))
self.embedding_color = nn.Linear(3, int(hidden_channels))
self.embedding_shape = nn.Linear(18, int(hidden_channels))
self.combine = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*4, hidden_channels)
)
self.layers = nn.ModuleList([
dglnn.RelGraphConv(
in_feat=hidden_channels,
out_feat=hidden_channels,
num_rels=len(rel_names),
regularizer=None,
num_bases=None,
bias=True,
activation=activation,
self_loop=True,
dropout=dropout,
layer_norm=False
)
for _ in range(n_layers-1)])
self.layers.append(
dglnn.RelGraphConv(
in_feat=hidden_channels,
out_feat=out_channels,
num_rels=len(rel_names),
regularizer=None,
num_bases=None,
bias=True,
activation=activation,
self_loop=True,
dropout=dropout,
layer_norm=False
)
)
def forward(self, g):
g = g.to_homogeneous()
feats = []
feats.append(self.embedding_type(g.ndata['type'].float()))
feats.append(self.embedding_pos(g.ndata['pos']/170.))
feats.append(self.embedding_color(g.ndata['color']/255.))
feats.append(self.embedding_shape(g.ndata['shape'].float()))
h = self.combine(torch.cat(feats, dim=1))
for conv in self.layers:
h = conv(g, h, g.etypes)
with g.local_scope():
g.ndata['h'] = h
out = dgl.mean_nodes(g, 'h')
return out
# -------------------------------------------------------------------------------------------
class RGATv3Agent(nn.Module):
# multi-layer GNN for one single feature
def __init__(
self,
hidden_channels,
out_channels,
num_heads,
rel_names,
dropout,
n_layers,
activation=nn.ELU(),
residual=False
):
super().__init__()
self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
self.combine = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*num_heads*4, hidden_channels*num_heads)
)
self.layers = nn.ModuleList([
dglnn.HeteroGraphConv({
rel: dglnn.GATv2Conv(
in_feats=hidden_channels*num_heads,
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always
feat_drop=dropout,
attn_drop=dropout,
residual=residual,
activation=activation if l < n_layers - 1 else None
)
for rel in rel_names}, aggregate='sum')
for l in range(n_layers)
])
def forward(self, g):
agent_mask = g.ndata['type'][:, 0] == 1
feats = []
feats.append(self.embedding_type(g.ndata['type'].float()))
feats.append(self.embedding_pos(g.ndata['pos']/200.))
feats.append(self.embedding_color(g.ndata['color']/255.))
feats.append(self.embedding_shape(g.ndata['shape'].float()))
h = {'obj': self.combine(torch.cat(feats, dim=1))}
for l, conv in enumerate(self.layers):
h = conv(g, h)
if l != len(self.layers) - 1:
h = {k: v.flatten(1) for k, v in h.items()}
else:
h = {k: v.mean(1) for k, v in h.items()}
with g.local_scope():
g.ndata['h'] = h['obj']
out = g.ndata['h'][agent_mask, :]
ctx = dgl.mean_nodes(g, 'h')
return out + ctx
# -------------------------------------------------------------------------------------------
class RGATv4(nn.Module):
# multi-layer GNN for one single feature
def __init__(
self,
hidden_channels,
out_channels,
num_heads,
rel_names,
dropout,
n_layers,
activation=nn.ELU(),
residual=False
):
super().__init__()
self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
self.combine_attr = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*num_heads*3, hidden_channels*num_heads)
)
self.combine_pos = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*num_heads*2, hidden_channels*num_heads)
)
self.layers = nn.ModuleList([
dglnn.HeteroGraphConv({
rel: dglnn.GATv2Conv(
in_feats=hidden_channels*num_heads,
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always
feat_drop=dropout,
attn_drop=dropout,
residual=residual,
share_weights=False,
activation=activation if l < n_layers - 1 else None
)
for rel in rel_names}, aggregate='sum')
for l in range(n_layers)
])
def forward(self, g):
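# same embedding scheme as the SAGE variant above; multi-head GATv2 outputs are flattened between layers and averaged at the last layer before mean pooling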
attr = []
attr.append(self.embedding_type(g.ndata['type'].float()))
pos = self.embedding_pos(g.ndata['pos']/170.)
attr.append(self.embedding_color(g.ndata['color']/255.))
attr.append(self.embedding_shape(g.ndata['shape'].float()))
combined_attr = self.combine_attr(torch.cat(attr, dim=1))
h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
for l, conv in enumerate(self.layers):
h = conv(g, h)
if l != len(self.layers) - 1:
h = {k: v.flatten(1) for k, v in h.items()}
else:
h = {k: v.mean(1) for k, v in h.items()}
with g.local_scope():
g.ndata['h'] = h['obj']
out = dgl.mean_nodes(g, 'h')
return out
# -------------------------------------------------------------------------------------------
class RGATv4Norm(nn.Module):
# multi-layer GNN for one single feature
def __init__(
self,
hidden_channels,
out_channels,
num_heads,
rel_names,
dropout,
n_layers,
activation=nn.ELU(),
residual=False
):
super().__init__()
self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
self.combine_attr = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*num_heads*3, hidden_channels*num_heads)
)
self.combine_pos = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*num_heads*2, hidden_channels*num_heads)
)
self.layers = nn.ModuleList([
dglnn.HeteroGraphConv({
rel: dglnn.GATv2Conv(
in_feats=hidden_channels*num_heads,
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always
feat_drop=dropout,
attn_drop=dropout,
residual=residual,
activation=activation if l < n_layers - 1 else None
)
for rel in rel_names}, aggregate='sum')
for l in range(n_layers)
])
self.norms = nn.ModuleList([
Norm(
norm_type='gn',
hidden_dim=hidden_channels*num_heads if l < n_layers - 1 else out_channels
)
for l in range(n_layers)
])
def forward(self, g):
attr = []
attr.append(self.embedding_type(g.ndata['type'].float()))
pos = self.embedding_pos(g.ndata['pos']/170.)
attr.append(self.embedding_color(g.ndata['color']/255.))
attr.append(self.embedding_shape(g.ndata['shape'].float()))
combined_attr = self.combine_attr(torch.cat(attr, dim=1))
h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
for l, conv in enumerate(self.layers):
h = conv(g, h)
if l != len(self.layers) - 1:
h = {k: v.flatten(1) for k, v in h.items()}
else:
h = {k: v.mean(1) for k, v in h.items()}
h = {k: self.norms[l](g, v) for k, v in h.items()}
with g.local_scope():
g.ndata['h'] = h['obj']
out = dgl.mean_nodes(g, 'h')
return out
# -------------------------------------------------------------------------------------------
class RGATv3Norm(nn.Module):
# multi-layer GNN for one single feature
def __init__(
self,
hidden_channels,
out_channels,
num_heads,
rel_names,
dropout,
n_layers,
activation=nn.ELU(),
residual=False
):
super().__init__()
self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
self.combine = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*num_heads*4, hidden_channels*num_heads)
)
self.layers = nn.ModuleList([
dglnn.HeteroGraphConv({
rel: dglnn.GATv2Conv(
in_feats=hidden_channels*num_heads,
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always
feat_drop=dropout,
attn_drop=dropout,
residual=residual,
activation=activation if l < n_layers - 1 else None
)
for rel in rel_names}, aggregate='sum')
for l in range(n_layers)
])
self.norms = nn.ModuleList([
Norm(
norm_type='gn',
hidden_dim=hidden_channels*num_heads if l < n_layers - 1 else out_channels
)
for l in range(n_layers)
])
def forward(self, g):
feats = []
feats.append(self.embedding_type(g.ndata['type'].float()))
feats.append(self.embedding_pos(g.ndata['pos']/170.))
feats.append(self.embedding_color(g.ndata['color']/255.))
feats.append(self.embedding_shape(g.ndata['shape'].float()))
h = {'obj': self.combine(torch.cat(feats, dim=1))}
for l, conv in enumerate(self.layers):
h = conv(g, h)
if l != len(self.layers) - 1:
h = {k: v.flatten(1) for k, v in h.items()}
else:
h = {k: v.mean(1) for k, v in h.items()}
h = {k: self.norms[l](g, v) for k, v in h.items()}
with g.local_scope():
g.ndata['h'] = h['obj']
out = dgl.mean_nodes(g, 'h')
return out
# -------------------------------------------------------------------------------------------
class RGATv4Agent(nn.Module):
# multi-layer GNN for one single feature
def __init__(
self,
hidden_channels,
out_channels,
num_heads,
rel_names,
dropout,
n_layers,
activation=nn.ELU(),
residual=False
):
super().__init__()
self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
self.combine_attr = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*num_heads*3, hidden_channels*num_heads)
)
self.combine_pos = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*num_heads*2, hidden_channels*num_heads)
)
self.layers = nn.ModuleList([
dglnn.HeteroGraphConv({
rel: dglnn.GATv2Conv(
in_feats=hidden_channels*num_heads,
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always
feat_drop=dropout,
attn_drop=dropout,
residual=residual,
activation=activation if l < n_layers - 1 else None
)
for rel in rel_names}, aggregate='sum')
for l in range(n_layers)
])
self.combine_agent_context = nn.Linear(out_channels*2, out_channels)
def forward(self, g):
agent_mask = g.ndata['type'][:, 0] == 1
attr = []
attr.append(self.embedding_type(g.ndata['type'].float()))
pos = self.embedding_pos(g.ndata['pos']/170.)
attr.append(self.embedding_color(g.ndata['color']/255.))
attr.append(self.embedding_shape(g.ndata['shape'].float()))
combined_attr = self.combine_attr(torch.cat(attr, dim=1))
h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
for l, conv in enumerate(self.layers):
h = conv(g, h)
if l != len(self.layers) - 1:
h = {k: v.flatten(1) for k, v in h.items()}
else:
h = {k: v.mean(1) for k, v in h.items()}
with g.local_scope():
g.ndata['h'] = h['obj']
h_a = g.ndata['h'][agent_mask, :]
g_no_agent = copy.deepcopy(g)
g_no_agent.remove_nodes([i for i, x in enumerate(agent_mask) if x])
h_g = dgl.mean_nodes(g_no_agent, 'h')
out = self.combine_agent_context(torch.cat((h_a, h_g), dim=1))
return out
# -------------------------------------------------------------------------------------------
class RGCNv4(nn.Module):
# multi-layer GNN for one single feature
def __init__(
self,
hidden_channels,
out_channels,
rel_names,
n_layers,
activation=nn.ELU(),
):
super().__init__()
self.embedding_type = nn.Linear(9, int(hidden_channels))
self.embedding_pos = nn.Linear(2, int(hidden_channels))
self.embedding_color = nn.Linear(3, int(hidden_channels))
self.embedding_shape = nn.Linear(18, int(hidden_channels))
self.combine_attr = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*3, hidden_channels)
)
self.combine_pos = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*2, hidden_channels)
)
self.layers = nn.ModuleList([
dglnn.HeteroGraphConv({
rel: dglnn.GraphConv(
in_feats=hidden_channels,
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
activation=activation if l < n_layers - 1 else None
)
for rel in rel_names}, aggregate='sum')
for l in range(n_layers)
])
def forward(self, g):
attr = []
attr.append(self.embedding_type(g.ndata['type'].float()))
pos = self.embedding_pos(g.ndata['pos']/170.)
attr.append(self.embedding_color(g.ndata['color']/255.))
attr.append(self.embedding_shape(g.ndata['shape'].float()))
combined_attr = self.combine_attr(torch.cat(attr, dim=1))
h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
for l, conv in enumerate(self.layers):
h = conv(g, h)
with g.local_scope():
g.ndata['h'] = h['obj']
out = dgl.mean_nodes(g, 'h')
return out
# -------------------------------------------------------------------------------------------
class RGATv5(nn.Module):
def __init__(
self,
hidden_channels,
out_channels,
num_heads,
rel_names,
dropout,
n_layers,
activation=nn.ELU(),
residual=False
):
super().__init__()
self.hidden_channels = hidden_channels
self.num_heads = num_heads
self.embedding_type = nn.Linear(9, hidden_channels*num_heads)
self.embedding_pos = nn.Linear(2, hidden_channels*num_heads)
self.embedding_color = nn.Linear(3, hidden_channels*num_heads)
self.embedding_shape = nn.Linear(18, hidden_channels*num_heads)
self.combine = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*num_heads*4, hidden_channels*num_heads)
)
self.attention = nn.Linear(hidden_channels*num_heads*4, 4)
self.layers = nn.ModuleList([
dglnn.HeteroGraphConv({
rel: dglnn.GATv2Conv(
in_feats=hidden_channels*num_heads,
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
num_heads=num_heads,
feat_drop=dropout,
attn_drop=dropout,
residual=residual,
activation=activation if l < n_layers - 1 else None
)
for rel in rel_names}, aggregate='sum')
for l in range(n_layers)
])
def forward(self, g):
feats = []
feats.append(self.embedding_type(g.ndata['type'].float()))
feats.append(self.embedding_pos(g.ndata['pos']/170.))
feats.append(self.embedding_color(g.ndata['color']/255.))
feats.append(self.embedding_shape(g.ndata['shape'].float()))
h = torch.cat(feats, dim=1)
feat_attn = F.softmax(self.attention(h), dim=1)
h = h * feat_attn.repeat_interleave(self.hidden_channels*self.num_heads, dim=1)
h_in = self.combine(h)
h = {'obj': h_in}
for conv in self.layers:
h = conv(g, h)
h = {k: v.flatten(1) for k, v in h.items()}
#if l != len(self.layers) - 1:
# h = {k: v.flatten(1) for k, v in h.items()}
#else:
# h = {k: v.mean(1) for k, v in h.items()}
h = {k: v + h_in for k, v in h.items()}
with g.local_scope():
g.ndata['h'] = h['obj']
out = dgl.mean_nodes(g, 'h')
return out
# -------------------------------------------------------------------------------------------
class RGATv6(nn.Module):
# RGATv6 = RGATv4 + Global Attention Pooling
def __init__(
self,
hidden_channels,
out_channels,
num_heads,
rel_names,
dropout,
n_layers,
activation=nn.ELU(),
residual=False
):
super().__init__()
self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
self.combine_attr = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*num_heads*3, hidden_channels*num_heads)
)
self.combine_pos = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*num_heads*2, hidden_channels*num_heads)
)
self.layers = nn.ModuleList([
dglnn.HeteroGraphConv({
rel: dglnn.GATv2Conv(
in_feats=hidden_channels*num_heads,
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always
feat_drop=dropout,
attn_drop=dropout,
residual=residual,
activation=activation if l < n_layers - 1 else None
)
for rel in rel_names}, aggregate='sum')
for l in range(n_layers)
])
gate_nn = nn.Linear(out_channels, 1)
self.gap = dglnn.GlobalAttentionPooling(gate_nn)
def forward(self, g):
attr = []
attr.append(self.embedding_type(g.ndata['type'].float()))
pos = self.embedding_pos(g.ndata['pos']/170.)
attr.append(self.embedding_color(g.ndata['color']/255.))
attr.append(self.embedding_shape(g.ndata['shape'].float()))
combined_attr = self.combine_attr(torch.cat(attr, dim=1))
h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
for l, conv in enumerate(self.layers):
h = conv(g, h)
if l != len(self.layers) - 1:
h = {k: v.flatten(1) for k, v in h.items()}
else:
h = {k: v.mean(1) for k, v in h.items()}
#with g.local_scope():
#g.ndata['h'] = h['obj']
#out = dgl.mean_nodes(g, 'h')
out = self.gap(g, h['obj'])
return out
# -------------------------------------------------------------------------------------------
class RGATv6Agent(nn.Module):
# RGATv6 = RGATv4 + Global Attention Pooling
def __init__(
self,
hidden_channels,
out_channels,
num_heads,
rel_names,
dropout,
n_layers,
activation=nn.ELU(),
residual=False
):
super().__init__()
self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
self.combine_attr = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*num_heads*3, hidden_channels*num_heads)
)
self.combine_pos = nn.Sequential(
nn.ReLU(),
nn.Linear(hidden_channels*num_heads*2, hidden_channels*num_heads)
)
self.layers = nn.ModuleList([
dglnn.HeteroGraphConv({
rel: dglnn.GATv2Conv(
in_feats=hidden_channels*num_heads,
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always
feat_drop=dropout,
attn_drop=dropout,
residual=residual,
activation=activation if l < n_layers - 1 else None
)
for rel in rel_names}, aggregate='sum')
for l in range(n_layers)
])
gate_nn = nn.Linear(out_channels, 1)
self.gap = dglnn.GlobalAttentionPooling(gate_nn)
self.combine_agent_context = nn.Linear(out_channels*2, out_channels)
def forward(self, g):
agent_mask = g.ndata['type'][:, 0] == 1
attr = []
attr.append(self.embedding_type(g.ndata['type'].float()))
pos = self.embedding_pos(g.ndata['pos']/170.)
attr.append(self.embedding_color(g.ndata['color']/255.))
attr.append(self.embedding_shape(g.ndata['shape'].float()))
combined_attr = self.combine_attr(torch.cat(attr, dim=1))
h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
for l, conv in enumerate(self.layers):
h = conv(g, h)
if l != len(self.layers) - 1:
h = {k: v.flatten(1) for k, v in h.items()}
else:
h = {k: v.mean(1) for k, v in h.items()}
with g.local_scope():
g.ndata['h'] = h['obj']
h_a = g.ndata['h'][agent_mask, :]
h_g = g.ndata['h'][~agent_mask, :]
g_no_agent = copy.deepcopy(g)
g_no_agent.remove_nodes([i for i, x in enumerate(agent_mask) if x])
h_g = self.gap(g_no_agent, h_g) # dgl.mean_nodes(g_no_agent, 'h')
out = self.combine_agent_context(torch.cat((h_a, h_g), dim=1))
return out
# -------------------------------------------------------------------------------------------

513
tom/model.py Normal file
View file

@@ -0,0 +1,513 @@
from argparse import ArgumentParser
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader
from tom.dataset import *
from tom.transformer import TransformerEncoder
from tom.gnn import RGATv2, RGATv3, RGATv3Agent, RGATv4, RGATv4Norm, RSAGEv4, RAGNNv4
class MlpModel(nn.Module):
"""Multilayer Perceptron with last layer linear.
Args:
input_size (int): number of inputs
hidden_sizes (list): can be empty list for none (linear model).
output_size: linear layer at output, or if ``None``, the last hidden size
will be the output size and will have nonlinearity applied
nonlinearity: torch nonlinearity Module (not Functional).
"""
def __init__(
self,
input_size,
hidden_sizes, # Can be empty list or None for none.
output_size=None, # if None, last layer has nonlinearity applied.
nonlinearity=nn.ReLU, # Module, not Functional.
dropout=None # Dropout value
):
super().__init__()
if isinstance(hidden_sizes, int):
hidden_sizes = [hidden_sizes]
elif hidden_sizes is None:
hidden_sizes = []
hidden_layers = [nn.Linear(n_in, n_out) for n_in, n_out in
zip([input_size] + hidden_sizes[:-1], hidden_sizes)]
sequence = list()
for i, layer in enumerate(hidden_layers):
if dropout is not None:
sequence.extend([layer, nonlinearity(), nn.Dropout(dropout)])
else:
sequence.extend([layer, nonlinearity()])
if output_size is not None:
last_size = hidden_sizes[-1] if hidden_sizes else input_size
sequence.append(torch.nn.Linear(last_size, output_size))
self.model = nn.Sequential(*sequence)
self._output_size = (hidden_sizes[-1] if output_size is None
else output_size)
def forward(self, input):
"""Compute the model on the input, assuming input shape [B,input_size]."""
return self.model(input)
@property
def output_size(self):
"""Retuns the output size of the model."""
return self._output_size
# ---------------------------------------------------------------------------------------------------------------------------------
class GraphBCRNN(pl.LightningModule):
"""
Implementation of the baseline model for the BC-RNN algorithm.
R-GCN + LSTM are used to encode the familiarization trials.
"""
@staticmethod
def add_model_specific_args(parent_parser):
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--action_dim', type=int, default=2)
parser.add_argument('--context_dim', type=int, default=32) # lstm hidden size
parser.add_argument('--beta', type=float, default=0.01)
parser.add_argument('--dropout', type=float, default=0.2)
parser.add_argument('--process_data', type=int, default=0)
parser.add_argument('--max_len', type=int, default=30)
# arguments for gnn
parser.add_argument('--gnn_type', type=str, default='RGATv4')
parser.add_argument('--state_dim', type=int, default=128) # gnn out_feats
parser.add_argument('--feats_dims', type=list, default=[9, 2, 3, 18])
parser.add_argument('--aggregation', type=str, default='sum')
# arguments for mpl
#parser.add_argument('--mpl_hid_feats', type=list, default=[256, 64, 16])
return parser
def __init__(self, hparams):
super().__init__()
self.hparams = hparams
self.lr = self.hparams.lr
self.state_dim = self.hparams.state_dim
self.action_dim = self.hparams.action_dim
self.context_dim = self.hparams.context_dim
self.beta = self.hparams.beta
self.dropout = self.hparams.dropout
self.max_len = self.hparams.max_len
self.feats_dims = self.hparams.feats_dims # type, position, color, shape
self.rel_names = [
'is_aligned', 'is_back', 'is_close', 'is_down_adj', 'is_down_left_adj',
'is_down_right_adj', 'is_front', 'is_left', 'is_left_adj', 'is_right',
'is_right_adj', 'is_top_adj', 'is_top_left_adj', 'is_top_right_adj'
]
self.gnn_aggregation = self.hparams.aggregation
if self.hparams.gnn_type == 'RGATv2':
self.gnn_encoder = RGATv2(
hidden_channels=self.state_dim,
out_channels=self.state_dim,
num_heads=4,
rel_names=self.rel_names,
dropout=0.0,
n_layers=2,
activation=nn.ELU(),
residual=False
)
if self.hparams.gnn_type == 'RGATv3':
self.gnn_encoder = RGATv3(
hidden_channels=self.state_dim,
out_channels=self.state_dim,
num_heads=4,
rel_names=self.rel_names,
dropout=0.0,
n_layers=2,
activation=nn.ELU(),
residual=False
)
if self.hparams.gnn_type == 'RGATv4':
self.gnn_encoder = RGATv4(
hidden_channels=self.state_dim,
out_channels=self.state_dim,
num_heads=4,
rel_names=self.rel_names,
dropout=0.0,
n_layers=2,
activation=nn.ELU(),
residual=False
)
if self.hparams.gnn_type == 'RGATv3Agent':
self.gnn_encoder = RGATv3Agent(
hidden_channels=self.state_dim,
out_channels=self.state_dim,
num_heads=4,
rel_names=self.rel_names,
dropout=0.0,
n_layers=2,
activation=nn.ELU(),
residual=False
)
if self.hparams.gnn_type == 'RSAGEv4':
self.gnn_encoder = RSAGEv4(
hidden_channels=self.state_dim,
out_channels=self.state_dim,
rel_names=self.rel_names,
dropout=0.0,
n_layers=2,
activation=nn.ELU()
)
if self.gnn_aggregation == 'cat_axis_1':
self.lstm_input_size = self.state_dim * len(self.feats_dims) + self.action_dim
self.mlp_input_size = self.state_dim * len(self.feats_dims) + self.context_dim * 2
elif self.gnn_aggregation == 'sum':
self.lstm_input_size = self.state_dim + self.action_dim
self.mlp_input_size = self.state_dim + self.context_dim * 2
else:
raise ValueError('Only sum and cat_axis_1 aggregations are available.')
self.context_enc = nn.LSTM(self.lstm_input_size, self.context_dim, 2,
batch_first=True, bidirectional=True)
self.policy = MlpModel(input_size=self.mlp_input_size, hidden_sizes=[256, 128, 256],
output_size=self.action_dim, dropout=self.dropout)
self.past_samples = []
def forward(self, batch):
dem_frames, dem_actions, dem_lens, query_frame, target_action = batch
dem_actions = dem_actions.float()
target_action = target_action.float()
dem_states = self.gnn_encoder(dem_frames) # torch.Size([number of frames, 128 * number of features if cat_axis_1])
# segment according the number of frames in each episode and pad with zeros
# to obtain tensors of shape [batch size, num of trials (8), max num of frames (30), hidden dim]
b, l, s, _ = dem_actions.size()
dem_states_new = []
for b_idx in range(b):
dem_states_new.append(self._sequence_to_padding(dem_states, dem_lens[b_idx], self.max_len))
dem_states_new = torch.stack(dem_states_new).to(self.device) # torch.Size([batchsize, 8, 30, 128 * number of features if cat_axis_1])
# concatenate states and actions to get expert trajectory
dem_states_new = dem_states_new.view(b * l, s, -1) # torch.Size([batchsize*8, 30, 128 * number of features if cat_axis_1])
dem_actions = dem_actions.view(b * l, s, -1) # torch.Size([batchsize*8, 30, 128])
dem_traj = torch.cat([dem_states_new, dem_actions], dim=2) # torch.Size([batchsize*8, 30, 2 + 128 * number of features if cat_axis_1])
# embed expert trajectory to get a context embedding batch x samples x dim
dem_lens = torch.tensor([t for dl in dem_lens for t in dl]).to(torch.int64).cpu()
dem_lens = dem_lens.view(-1)
x_lstm = nn.utils.rnn.pack_padded_sequence(dem_traj, dem_lens, batch_first=True, enforce_sorted=False)
x_lstm, _ = self.context_enc(x_lstm)
x_lstm, _ = nn.utils.rnn.pad_packed_sequence(x_lstm, batch_first=True)
x_out = x_lstm[:, -1, :]
x_out = x_out.view(b, l, -1)
context = torch.mean(x_out, dim=1) # torch.Size([batchsize, 64]) (64=32*2)
# concat context embedding to the state embedding of test trajectory
test_states = self.gnn_encoder(query_frame) # torch.Size([2, 128])
test_context_states = torch.cat([context, test_states], dim=1) # torch.Size([batchsize, 192]) 192=64+128
# for each state in the test states calculate action
test_actions_pred = torch.tanh(self.policy(test_context_states)) # torch.Size([batchsize, 2])
return target_action, test_actions_pred
def _sequence_to_padding(self, x, lengths, max_length):
# declare the shape, it can work for x of any shape.
ret_tensor = torch.zeros((len(lengths), max_length) + tuple(x.shape[1:]))
cum_len = 0
for i, l in enumerate(lengths):
ret_tensor[i, :l] = x[cum_len: cum_len+l]
cum_len += l
return ret_tensor
def training_step(self, batch, batch_idx):
test_actions, test_actions_pred = self.forward(batch)
loss = F.mse_loss(test_actions, test_actions_pred)
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
return loss
def validation_step(self, batch, batch_idx, dataloader_idx=0):
test_actions, test_actions_pred = self.forward(batch)
loss = F.mse_loss(test_actions, test_actions_pred)
self.log('val_loss', loss, on_epoch=True, logger=True)
def configure_optimizers(self):
optim = torch.optim.Adam(self.parameters(), lr=self.lr)
return optim
#optim = torch.optim.AdamW(self.parameters(), lr=self.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01)
#scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optim, gamma=0.8)
#return [optim], [scheduler]
def train_dataloader(self):
train_dataset = ToMnetDGLDataset(path=self.hparams.data_path,
types=self.hparams.types,
mode='train')
train_loader = DataLoader(dataset=train_dataset,
batch_size=self.hparams.batch_size,
collate_fn=collate_function_seq,
num_workers=self.hparams.num_workers,
#pin_memory=True,
shuffle=True)
return train_loader
def val_dataloader(self):
val_datasets = []
val_loaders = []
for t in self.hparams.types:
val_datasets.append(ToMnetDGLDataset(path=self.hparams.data_path,
types=[t],
mode='val'))
val_loaders.append(DataLoader(dataset=val_datasets[-1],
batch_size=self.hparams.batch_size,
collate_fn=collate_function_seq,
num_workers=self.hparams.num_workers,
#pin_memory=True,
shuffle=False))
return val_loaders
def configure_callbacks(self):
checkpoint = ModelCheckpoint(
dirpath=None, # automatically set
#filename=self.params['bc_model']+'-'+self.params['gnn_type']+'-'+self.gnn_params['feats_aggr']+'-{epoch:02d}',
save_top_k=-1,
period=1
)
return [checkpoint]
# ---------------------------------------------------------------------------------------------------------------------------------
class GraphBC_T(pl.LightningModule):
"""
BC model with GraphTrans encoder, LSTM and MLP.
"""
@staticmethod
def add_model_specific_args(parent_parser):
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--action_dim', type=int, default=2)
parser.add_argument('--context_dim', type=int, default=32) # lstm hidden size
parser.add_argument('--beta', type=float, default=0.01)
parser.add_argument('--dropout', type=float, default=0.2)
parser.add_argument('--process_data', type=int, default=0)
parser.add_argument('--max_len', type=int, default=30)
# arguments for gnn
parser.add_argument('--state_dim', type=int, default=128) # gnn out_feats
parser.add_argument('--feats_dims', type=list, default=[9, 2, 3, 18])
parser.add_argument('--aggregation', type=str, default='cat_axis_1')
parser.add_argument('--gnn_type', type=str, default='RGATv3')
# arguments for mlp
#parser.add_argument('--mpl_hid_feats', type=list, default=[256, 64, 16])
# arguments for transformer
parser.add_argument('--d_model', type=int, default=128)
parser.add_argument('--nhead', type=int, default=4)
parser.add_argument('--dim_feedforward', type=int, default=512)
parser.add_argument('--transformer_dropout', type=float, default=0.3)
parser.add_argument('--transformer_activation', type=str, default='gelu')
parser.add_argument('--num_encoder_layers', type=int, default=6)
parser.add_argument('--transformer_norm_input', type=int, default=0)
return parser
def __init__(self, hparams):
super().__init__()
self.hparams = hparams
self.lr = self.hparams.lr
self.state_dim = self.hparams.state_dim
self.action_dim = self.hparams.action_dim
self.context_dim = self.hparams.context_dim
self.beta = self.hparams.beta
self.dropout = self.hparams.dropout
self.max_len = self.hparams.max_len
self.feats_dims = self.hparams.feats_dims # type, position, color, shape
self.rel_names = [
'is_aligned', 'is_back', 'is_close', 'is_down_adj', 'is_down_left_adj',
'is_down_right_adj', 'is_front', 'is_left', 'is_left_adj', 'is_right',
'is_right_adj', 'is_top_adj', 'is_top_left_adj', 'is_top_right_adj'
]
self.gnn_aggregation = self.hparams.aggregation
if self.hparams.gnn_type == 'RGATv3':
self.gnn_encoder = RGATv3(
hidden_channels=self.state_dim,
out_channels=self.state_dim,
num_heads=4,
rel_names=self.rel_names,
dropout=0.0,
n_layers=2,
activation=nn.ELU(),
residual=False
)
if self.hparams.gnn_type == 'RGATv4':
self.gnn_encoder = RGATv4(
hidden_channels=self.state_dim,
out_channels=self.state_dim,
num_heads=4,
rel_names=self.rel_names,
dropout=0.0,
n_layers=2,
activation=nn.ELU(),
residual=False
)
if self.hparams.gnn_type == 'RSAGEv4':
self.gnn_encoder = RSAGEv4(
hidden_channels=self.state_dim,
out_channels=self.state_dim,
rel_names=self.rel_names,
dropout=0.0,
n_layers=2,
activation=nn.ELU()
)
if self.hparams.gnn_type == 'RAGNNv4':
self.gnn_encoder = RAGNNv4(
hidden_channels=self.state_dim,
out_channels=self.state_dim,
rel_names=self.rel_names,
dropout=0.0,
n_layers=2,
activation=nn.ELU()
)
if self.hparams.gnn_type == 'RGATv4Norm':
self.gnn_encoder = RGATv4Norm(
hidden_channels=self.state_dim,
out_channels=self.state_dim,
num_heads=4,
rel_names=self.rel_names,
dropout=0.0,
n_layers=2,
activation=nn.ELU(),
residual=True
)
self.d_model = self.hparams.d_model
self.nhead = self.hparams.nhead
self.dim_feedforward = self.hparams.dim_feedforward
self.transformer_dropout = self.hparams.transformer_dropout
self.transformer_activation = self.hparams.transformer_activation
self.num_encoder_layers = self.hparams.num_encoder_layers
self.transformer_norm_input = self.hparams.transformer_norm_input
self.context_enc = TransformerEncoder(
self.d_model,
self.nhead,
self.dim_feedforward,
self.transformer_dropout,
self.transformer_activation,
self.num_encoder_layers,
self.max_len,
self.transformer_norm_input
)
if self.gnn_aggregation == 'cat_axis_1':
self.gnn2transformer = nn.Linear(self.state_dim * len(self.feats_dims) + self.action_dim, self.d_model)
self.mlp_input_size = len(self.feats_dims) * self.state_dim + self.d_model
elif self.gnn_aggregation == 'sum':
self.gnn2transformer = nn.Linear(self.state_dim + self.action_dim, self.d_model)
self.mlp_input_size = self.state_dim + self.d_model
else:
raise ValueError('Only sum and cat_axis_1 aggregations are available.')
self.policy = MlpModel(input_size=self.mlp_input_size, hidden_sizes=[256, 128, 256],
output_size=self.action_dim, dropout=self.dropout)
# CLS Embedding parameters, requires_grad=True
self.embedding = nn.Embedding(self.max_len + 1, self.d_model) # + 1 because of the CLS token
self.emb_layer_norm = nn.LayerNorm(self.d_model)
self.emb_dropout = nn.Dropout(p=self.transformer_dropout)
def forward(self, batch):
dem_frames, dem_actions, dem_lens, query_frame, target_action = batch
dem_actions = dem_actions.float()
target_action = target_action.float()
dem_states = self.gnn_encoder(dem_frames)
b, l, s, _ = dem_actions.size()
dem_lens = torch.tensor([t for dl in dem_lens for t in dl]).to(torch.int64).cpu()
dem_lens = dem_lens.view(-1)
dem_actions_packed = torch.nn.utils.rnn.pack_padded_sequence(dem_actions.view(b*l, s, -1), dem_lens, batch_first=True, enforce_sorted=False)[0]
dem_traj = torch.cat([dem_states, dem_actions_packed], dim=1)
h_node = self.gnn2transformer(dem_traj)
hidden_dim = h_node.size()[1]
padded_trajectory = torch.zeros(b*l, self.max_len, hidden_dim).to(self.device)
j = 0
for idx, i in enumerate(dem_lens):
padded_trajectory[idx][:i] = h_node[j:j+i]
j += i
mask = self.make_mask(padded_trajectory).to(self.device)
transformer_input = padded_trajectory.transpose(0, 1) # [30, 16, 128]
# add CLS token (note: the parameter is created inside forward, so it is re-sampled on every call rather than learned):
cls_embedding = nn.Parameter(torch.randn([1, 1, self.d_model], requires_grad=True)).expand(1, b*l, -1).to(self.device)
transformer_input = torch.cat([transformer_input, cls_embedding], dim=0) # [31, 16, 128]
zeros = mask.data.new(mask.size(0), 1).fill_(0)
mask = torch.cat([mask, zeros], dim=1)
# Embed
indices = torch.arange(self.max_len + 1, dtype=torch.int).to(self.device) # + 1 because of the CLS token: [0, 1, ..., 30]
positional_embeddings = self.embedding(indices).unsqueeze(1) # torch.Size([31, 1, 128])
#generate transformer input
pe_input = positional_embeddings + transformer_input # torch.Size([31, 16, 128])
# Layernorm and dropout
transformer_in = self.emb_dropout(self.emb_layer_norm(pe_input)) # torch.Size([31, 16, 128])
# transformer encoding and output parsing
out, _ = self.context_enc(transformer_in, mask) # [31, 16, 128]
cls = out[-1]
cls = cls.view(b, l, -1) # 2, 8, 128
context = torch.mean(cls, dim=1)
# CLASSIFICATION
test_states = self.gnn_encoder(query_frame) # [batch_size, gnn_out_dim]
test_context_states = torch.cat([context, test_states], dim=1) # [batch_size, d_model + gnn_out_dim]
# for each state in the test states calculate action
x = self.policy(test_context_states)
test_actions_pred = torch.tanh(x) # torch.Size([batchsize, 2])
#test_actions_pred = torch.tanh(self.policy(test_context_states)) # torch.Size([batchsize, 2])
return target_action, test_actions_pred
def make_mask(self, feature):
return (torch.sum(
torch.abs(feature),
dim=-1
) == 0)#.unsqueeze(1).unsqueeze(2)
def training_step(self, batch, batch_idx):
test_actions, test_actions_pred = self.forward(batch)
loss = F.mse_loss(test_actions, test_actions_pred)
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
return loss
def validation_step(self, batch, batch_idx, dataloader_idx=0):
test_actions, test_actions_pred = self.forward(batch)
loss = F.mse_loss(test_actions, test_actions_pred)
self.log('val_loss', loss, on_epoch=True, logger=True)
def configure_optimizers(self):
optim = torch.optim.Adam(self.parameters(), lr=self.lr)
return optim
#optim = torch.optim.AdamW(self.parameters(), lr=self.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01)
#scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optim, gamma=0.96)
#return [optim], [scheduler]
def train_dataloader(self):
train_dataset = ToMnetDGLDataset(path=self.hparams.data_path,
types=self.hparams.types,
mode='train')
train_loader = DataLoader(dataset=train_dataset,
batch_size=self.hparams.batch_size,
collate_fn=collate_function_seq,
num_workers=self.hparams.num_workers,
#pin_memory=True,
shuffle=True)
return train_loader
def val_dataloader(self):
val_datasets = []
val_loaders = []
for t in self.hparams.types:
val_datasets.append(ToMnetDGLDataset(path=self.hparams.data_path,
types=[t],
mode='val'))
val_loaders.append(DataLoader(dataset=val_datasets[-1],
batch_size=self.hparams.batch_size,
collate_fn=collate_function_seq,
num_workers=self.hparams.num_workers,
#pin_memory=True,
shuffle=False))
return val_loaders
def configure_callbacks(self):
checkpoint = ModelCheckpoint(
dirpath=None, # automatically set
#filename=self.params['bc_model']+'-'+self.params['gnn_type']+'-'+self.gnn_params['feats_aggr']+'-{epoch:02d}',
save_top_k=-1,
period=1
)
return [checkpoint]
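Both models above build the policy input in the same way: the per-trial context embeddings (LSTM final states in `GraphBCRNN`, Transformer CLS outputs in `GraphBC_T`) are averaged over the demonstration trials and concatenated with the GNN embedding of the query frame before the MLP policy. A minimal sketch of that pooling step with dummy tensors; the sizes are illustrative, not the trained configuration:

```python
import torch

b, n_trials, d_model, state_dim = 2, 8, 128, 96       # illustrative sizes only
cls_per_trial = torch.randn(b * n_trials, d_model)     # one context vector per demonstration trial
context = cls_per_trial.view(b, n_trials, -1).mean(dim=1)
query_state = torch.randn(b, state_dim)                # GNN embedding of the query frame
policy_input = torch.cat([context, query_state], dim=1)
print(policy_input.shape)                              # torch.Size([2, 224])
```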

46
tom/norm.py Normal file
View file

@@ -0,0 +1,46 @@
import torch
import torch.nn as nn
class Norm(nn.Module):
def __init__(self, norm_type, hidden_dim=64, print_info=None):
super(Norm, self).__init__()
# assert norm_type in ['bn', 'ln', 'gn', None]
self.norm = None
self.print_info = print_info
if norm_type == 'bn':
self.norm = nn.BatchNorm1d(hidden_dim)
elif norm_type == 'gn':
self.norm = norm_type
self.weight = nn.Parameter(torch.ones(hidden_dim))
self.bias = nn.Parameter(torch.zeros(hidden_dim))
self.mean_scale = nn.Parameter(torch.ones(hidden_dim))
def forward(self, graph, tensor, print_=False):
if self.norm is not None and type(self.norm) != str:
return self.norm(tensor)
elif self.norm is None:
return tensor
batch_list = graph.batch_num_nodes('obj')
batch_size = len(batch_list)
#batch_list = torch.tensor(batch_list).long().to(tensor.device)
batch_list = batch_list.long().to(tensor.device)
batch_index = torch.arange(batch_size).to(tensor.device).repeat_interleave(batch_list)
batch_index = batch_index.view((-1,) + (1,) * (tensor.dim() - 1)).expand_as(tensor)
mean = torch.zeros(batch_size, *tensor.shape[1:]).to(tensor.device)
mean = mean.scatter_add_(0, batch_index, tensor)
mean = (mean.T / batch_list).T
mean = mean.repeat_interleave(batch_list, dim=0)
sub = tensor - mean * self.mean_scale
std = torch.zeros(batch_size, *tensor.shape[1:]).to(tensor.device)
std = std.scatter_add_(0, batch_index, sub.pow(2))
std = ((std.T / batch_list).T + 1e-6).sqrt()
std = std.repeat_interleave(batch_list, dim=0)
return self.weight * sub / std + self.bias
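With `norm_type='gn'` the class above performs a GraphNorm-style normalisation: node features are standardised with per-graph statistics recovered from `batch_num_nodes('obj')`. A hedged usage sketch with two toy heterographs; the graphs and the feature size are made up:

```python
import dgl
import torch

# two tiny graphs with a single 'obj' node type, batched together
g1 = dgl.heterograph({('obj', 'is_close', 'obj'): ([0, 1], [1, 0])})
g2 = dgl.heterograph({('obj', 'is_close', 'obj'): ([0], [1])})
g = dgl.batch([g1, g2])

feats = torch.randn(g.num_nodes('obj'), 64)
norm = Norm('gn', hidden_dim=64)
out = norm(g, feats)          # features standardised with per-graph mean and std
print(out.shape)              # torch.Size([4, 64])
```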

89
tom/transformer.py Normal file
View file

@@ -0,0 +1,89 @@
import torch
import torch.nn as nn
import math
class PositionalEncoding(nn.Module):
def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
pe = torch.zeros(max_len, 1, d_model)
pe[:, 0, 0::2] = torch.sin(position * div_term)
pe[:, 0, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: Tensor, shape [seq_len, batch_size, embedding_dim]
"""
x = x + self.pe[:x.size(0)]
return self.dropout(x)
class TransformerEncoder(nn.Module):
def __init__(
self,
d_model,
nhead,
dim_feedforward,
transformer_dropout,
transformer_activation,
num_encoder_layers,
max_input_len,
transformer_norm_input
):
super().__init__()
self.d_model = d_model
self.num_layer = num_encoder_layers
self.max_input_len = max_input_len
# Creating Transformer Encoder Model
encoder_layer = nn.TransformerEncoderLayer(
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
dropout=transformer_dropout,
activation=transformer_activation
)
encoder_norm = nn.LayerNorm(d_model)
self.transformer = nn.TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
self.norm_input = None
if transformer_norm_input:
self.norm_input = nn.LayerNorm(d_model)
def forward(self, padded_h_node, src_padding_mask):
"""
padded_h_node: n_b x B x h_d # 63, 257, 128
src_key_padding_mask: B x n_b # 257, 63
"""
# (S, B, h_d), (B, S)
if self.norm_input is not None:
padded_h_node = self.norm_input(padded_h_node)
transformer_out = self.transformer(padded_h_node, src_key_padding_mask=src_padding_mask) # (S, B, h_d)
return transformer_out, src_padding_mask
if __name__ == '__main__':
model = TransformerEncoder(
d_model=12,
nhead=4,
dim_feedforward=32,
transformer_dropout=0.0,
transformer_activation='gelu',
num_encoder_layers=4,
max_input_len=34,
transformer_norm_input=0
)
print(model.norm_input)
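The encoder expects inputs of shape `(S, B, d_model)` and a boolean key-padding mask of shape `(B, S)` with `True` marking padded positions. A self-contained, hedged forward-pass sketch reusing the configuration from the `__main__` block above; batch size and padding pattern are illustrative:

```python
import torch
from tom.transformer import TransformerEncoder

model = TransformerEncoder(d_model=12, nhead=4, dim_feedforward=32,
                           transformer_dropout=0.0, transformer_activation='gelu',
                           num_encoder_layers=4, max_input_len=34, transformer_norm_input=0)

S, B = 34, 3
h = torch.randn(S, B, 12)                        # padded trajectory embeddings, (S, B, d_model)
mask = torch.zeros(B, S, dtype=torch.bool)       # True marks padded positions
mask[:, 20:] = True                              # pretend the last 14 steps are padding
out, _ = model(h, mask)
print(out.shape)                                 # torch.Size([34, 3, 12])
```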

75
train_tom.py Normal file
View file

@@ -0,0 +1,75 @@
import random
from argparse import ArgumentParser
import numpy as np
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from tom.model import GraphBC_T, GraphBCRNN
torch.multiprocessing.set_sharing_strategy('file_system')
parser = ArgumentParser()
# program level args
parser.add_argument('--seed', type=int, default=4)
# data specific args
parser.add_argument('--data_path', type=str, default='/datasets/external/bib_train/graphs/all_tasks/')
parser.add_argument('--types', nargs='+', type=str,
default=['preference', 'multi_agent', 'single_object', 'instrumental_action'],
help='types of tasks used for training / validation')
parser.add_argument('--train', type=int, default=1)
parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--model_type', type=str, default='graphbcrnn')
# model specific args
parser_model = ArgumentParser()
parser_model = GraphBC_T.add_model_specific_args(parser_model)
# parser_model = GraphBCRNN.add_model_specific_args(parser_model)
# NOTE: here unfortunately you have to select manually the model
# add all the available trainer options to argparse
parser = Trainer.add_argparse_args(parser)
# combine parsers
parser_all = ArgumentParser(conflict_handler='resolve',
parents=[parser, parser_model])
# parse args
args = parser_all.parse_args()
args.types = sorted(args.types)
print(args)
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
# init model
if args.model_type == 'graphbct':
model = GraphBC_T(args)
elif args.model_type == 'graphbcrnn':
model = GraphBCRNN(args)
else:
raise NotImplementedError
torch.autograd.set_detect_anomaly(True)
logger = WandbLogger(project='bib')
trainer = Trainer(
gradient_clip_val=args.gradient_clip_val,
gpus=args.gpus,
auto_select_gpus=args.auto_select_gpus,
track_grad_norm=args.track_grad_norm,
check_val_every_n_epoch=args.check_val_every_n_epoch,
max_epochs=args.max_epochs,
accelerator=args.accelerator,
resume_from_checkpoint=args.resume_from_checkpoint,
stochastic_weight_avg=args.stochastic_weight_avg,
num_sanity_val_steps=args.num_sanity_val_steps,
logger=logger
)
if args.train:
trainer.fit(model)
else:
raise NotImplementedError
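Because the model-specific parser has to be selected by hand (see the NOTE above), it can help to check how the combined parser resolves flags before launching a run. A hedged sketch, run from the repository root; the flag values are illustrative:

```python
from argparse import ArgumentParser
from pytorch_lightning import Trainer
from tom.model import GraphBC_T

base = ArgumentParser()
base.add_argument('--model_type', type=str, default='graphbct')
base = Trainer.add_argparse_args(base)
model_parser = GraphBC_T.add_model_specific_args(ArgumentParser())

parser_all = ArgumentParser(conflict_handler='resolve', parents=[base, model_parser])
args = parser_all.parse_args(['--gnn_type', 'RGATv4', '--max_epochs', '5'])
print(args.model_type, args.gnn_type, args.max_epochs)   # graphbct RGATv4 5
```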

0
utils/__init__.py Normal file
View file

115
utils/build_graphs.py Normal file
View file

@@ -0,0 +1,115 @@
import sys
sys.path.append('/projects/bortoletto/icml2023_matteo/utils')
from dataset import TransitionDataset, TestTransitionDatasetSequence
import multiprocessing as mp
import argparse
import pickle as pkl
import os
# Instantiate the parser
parser = argparse.ArgumentParser()
parser.add_argument('--cpus', type=int,
help='Number of processes')
parser.add_argument('--mode', type=str,
help='Train (train) or validation (val)')
args = parser.parse_args()
NUM_PROCESSES = args.cpus
MODE = args.mode
def generate_files(idx):
print('Generating idx', idx)
if os.path.exists(PATH+str(idx)+'.pkl'):
print('Index', idx, 'skipped.')
return
if MODE == 'train' or MODE == 'val':
states, actions, lens, n_nodes = dataset.__getitem__(idx)
with open(PATH+str(idx)+'.pkl', 'wb') as f:
pkl.dump([states, actions, lens, n_nodes], f)
elif MODE == 'test':
dem_expected_states, dem_expected_actions, dem_expected_lens, dem_expected_nodes, \
dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, dem_unexpected_nodes, \
query_expected_frames, target_expected_actions, \
query_unexpected_frames, target_unexpected_actions = dataset.__getitem__(idx)
with open(PATH+str(idx)+'.pkl', 'wb') as f:
pkl.dump([
dem_expected_states, dem_expected_actions, dem_expected_lens, dem_expected_nodes, \
dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, dem_unexpected_nodes, \
query_expected_frames, target_expected_actions, \
query_unexpected_frames, target_unexpected_actions], f
)
else:
raise ValueError('MODE can be only train, val or test.')
print(PATH+str(idx)+'.pkl saved.')
if __name__ == "__main__":
if MODE == 'train':
print('TRAIN MODE')
PATH = '/datasets/external/bib_train/graphs/all_tasks/train_dgl_hetero_nobound_4feats/'
if not os.path.exists(PATH):
os.makedirs(PATH)
print(PATH, 'directory created.')
dataset = TransitionDataset(
path='/datasets/external/bib_train/',
types=['instrumental_action', 'multi_agent', 'preference', 'single_object'],
mode="train",
max_len=30,
num_test=1,
num_trials=9,
action_range=10,
process_data=0
)
pool = mp.Pool(processes=NUM_PROCESSES)
print('Starting graph generation with', NUM_PROCESSES, 'processes...')
pool.map(generate_files, [i for i in range(dataset.__len__())])
pool.close()
elif MODE == 'val':
print('VALIDATION MODE')
types = ['multi_agent', 'instrumental_action', 'preference', 'single_object']
for t in range(len(types)):
PATH = '/datasets/external/bib_train/graphs/all_tasks/val_dgl_hetero_nobound_4feats/'+types[t]+'/'
if not os.path.exists(PATH):
os.makedirs(PATH)
print(PATH, 'directory created.')
dataset = TransitionDataset(
path='/datasets/external/bib_train/',
types=[types[t]],
mode="val",
max_len=30,
num_test=1,
num_trials=9,
action_range=10,
process_data=0
)
pool = mp.Pool(processes=NUM_PROCESSES)
print('Starting', types[t], 'graph generation with', NUM_PROCESSES, 'processes...')
pool.map(generate_files, [i for i in range(dataset.__len__())])
pool.close()
elif MODE == 'test':
print('TEST MODE')
types = [
'preference', 'multi_agent', 'inaccessible_goal',
'efficiency_irrational', 'efficiency_time','efficiency_path',
'instrumental_no_barrier', 'instrumental_blocking_barrier', 'instrumental_inconsequential_barrier'
]
for t in range(len(types)):
PATH = '/datasets/external/bib_evaluation_1_1/graphs/all_tasks_dgl_hetero_nobound_4feats/'+types[t]+'/'
if not os.path.exists(PATH):
os.makedirs(PATH)
print(PATH, 'directory created.')
dataset = TestTransitionDatasetSequence(
path='/datasets/external/bib_evaluation_1_1/',
task_type=types[t],
mode="test",
num_test=1,
num_trials=9,
action_range=10,
process_data=0,
max_len=30
)
pool = mp.Pool(processes=NUM_PROCESSES)
print('Starting', types[t], 'graph generation with', NUM_PROCESSES, 'processes...')
pool.map(generate_files, [i for i in range(dataset.__len__())])
pool.close()
else:
raise ValueError
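For `train` and `val`, each generated `.pkl` stores the tuple returned by `TransitionDataset.__getitem__`: a list with one batched frame graph per trial, the per-trial actions padded to `max_len`, the number of sampled frames per trial, and the per-frame node counts. A hedged sketch for inspecting one generated file; the path is illustrative:

```python
import pickle as pkl
import dgl

with open('/datasets/external/bib_train/graphs/all_tasks/train_dgl_hetero_nobound_4feats/0.pkl', 'rb') as f:
    states, actions, lens, n_nodes = pkl.load(f)

print(len(states), lens)                    # one batched graph per trial, sampled frames per trial
first_frame = dgl.unbatch(states[0])[0]     # the frames of a trial are batched into a single graph
print(first_frame.etypes)                   # the 14 spatial relation types
print(actions[0].shape)                     # torch.Size([30, 2]): actions padded to max_len
```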

487
utils/dataset.py Normal file
View file

@@ -0,0 +1,487 @@
import dgl
import torch
import torch.utils.data
import os
import pickle as pkl
import json
import numpy as np
from tqdm import tqdm
import sys
sys.path.append('/projects/bortoletto/irene/')
from utils.grid_object import *
from utils.relations import *
# ========================== Helper functions ==========================
def index_data(json_list, path_list):
print(f'processing {len(json_list)} files')
data_tuples = []
for j, v in tqdm(zip(json_list, path_list)):
with open(j, 'r') as f:
state = json.load(f)
ep_lens = [len(x) for x in state]
past_len = 0
for e, l in enumerate(ep_lens):
data_tuples.append([])
# skip first 30 frames and last 83 frames
for f in range(30, l - 83):
# find action taken;
f0x, f0y = state[e][f]['agent'][0]
f1x, f1y = state[e][f + 1]['agent'][0]
dx = (f1x - f0x) / 2.
dy = (f1y - f0y) / 2.
action = [dx, dy]
#data_tuples[-1].append((v, past_len + f, action))
data_tuples[-1].append((j, past_len + f, action))
# data_tuples = [[json file, frame number, action] for each episode in each video]
assert len(data_tuples[-1]) > 0
past_len += l
return data_tuples
# ========================== Dataset class ==========================
class TransitionDataset(torch.utils.data.Dataset):
"""
Training dataset class for the behavior cloning mlp model.
Args:
path: path to the dataset
types: list of video types to include
mode: train, val
num_test: number of test state-action pairs
num_trials: number of trials in an episode
action_range: number of frames to skip; actions are averaged over this number of frames (displacement of the agent)
process_data: whether to index the videos (skip if already processed)
max_len: max number of context state-action pairs
__getitem__:
returns: (states, actions, lens, n_nodes)
dem_frames: batched DGLGraph.heterograph
dem_actions: (max_len, 2)
query_frames: DGLGraph.heterograph
target_actions: (num_test, 2)
"""
def __init__(
self,
path,
types=None,
mode="train",
num_test=1,
num_trials=9,
action_range=10,
process_data=0,
max_len=30
):
self.path = path
self.types = types
self.mode = mode
self.num_trials = num_trials
self.num_test = num_test
self.action_range = action_range
self.max_len = max_len
self.ep_combs = self.num_trials * (self.num_trials - 2) # 9p2 - 9
self.eps = [[x, y] for x in range(self.num_trials) for y in range(self.num_trials) if x != y]
types_str = '_'.join(self.types)
self.rel_deter_func = [
is_top_adj, is_left_adj, is_top_right_adj, is_top_left_adj,
is_down_adj, is_right_adj, is_down_left_adj, is_down_right_adj,
is_left, is_right, is_front, is_back, is_aligned, is_close
]
self.path_list = []
self.json_list = []
# get video paths and json file paths
for t in types:
print(f'reading files of type {t} in {mode}')
paths = [os.path.join(self.path, t, x) for x in os.listdir(os.path.join(self.path, t)) if
x.endswith(f'.mp4')]
jsons = [os.path.join(self.path, t, x) for x in os.listdir(os.path.join(self.path, t)) if
x.endswith(f'.json') and 'index' not in x]
paths = sorted(paths)
jsons = sorted(jsons)
if mode == 'train':
self.path_list += paths[:int(0.8 * len(jsons))]
self.json_list += jsons[:int(0.8 * len(jsons))]
elif mode == 'val':
self.path_list += paths[int(0.8 * len(jsons)):]
self.json_list += jsons[int(0.8 * len(jsons)):]
else:
self.path_list += paths
self.json_list += jsons
self.data_tuples = []
if process_data:
# index the videos in the dataset directory. This is done to speed up the retrieval of videos.
# frame index, action tuples are stored
self.data_tuples = index_data(self.json_list, self.path_list)
# tuples of frame index and action (displacement of agent)
index_dict = {'data_tuples': self.data_tuples}
with open(os.path.join(self.path, f'jindex_bib_{mode}_{types_str}.json'), 'w') as fp:
json.dump(index_dict, fp)
else:
# read pre-indexed data
with open(os.path.join(self.path, f'jindex_bib_{mode}_{types_str}.json'), 'r') as fp:
index_dict = json.load(fp)
self.data_tuples = index_dict['data_tuples']
self.tot_trials = len(self.path_list) * 9
def _get_frame_graph(self, jsonfile, frame_idx):
# load json
with open(jsonfile, 'rb') as f:
frame_data = json.load(f)
flat_list = [x for xs in frame_data for x in xs]
# extract entities
grid_objs = parse_objects(flat_list[frame_idx])
# --- build the graph
adj = self._get_spatial_rel(grid_objs)
# define edges
is_top_adj_src, is_top_adj_dst = np.nonzero(adj[0])
is_left_adj_src, is_left_adj_dst = np.nonzero(adj[1])
is_top_right_adj_src, is_top_right_adj_dst = np.nonzero(adj[2])
is_top_left_adj_src, is_top_left_adj_dst = np.nonzero(adj[3])
is_down_adj_src, is_down_adj_dst = np.nonzero(adj[4])
is_right_adj_src, is_right_adj_dst = np.nonzero(adj[5])
is_down_left_adj_src, is_down_left_adj_dst = np.nonzero(adj[6])
is_down_right_adj_src, is_down_right_adj_dst = np.nonzero(adj[7])
is_left_src, is_left_dst = np.nonzero(adj[8])
is_right_src, is_right_dst = np.nonzero(adj[9])
is_front_src, is_front_dst = np.nonzero(adj[10])
is_back_src, is_back_dst = np.nonzero(adj[11])
is_aligned_src, is_aligned_dst = np.nonzero(adj[12])
is_close_src, is_close_dst = np.nonzero(adj[13])
g = dgl.heterograph({
('obj', 'is_top_adj', 'obj'): (torch.tensor(is_top_adj_src), torch.tensor(is_top_adj_dst)),
('obj', 'is_left_adj', 'obj'): (torch.tensor(is_left_adj_src), torch.tensor(is_left_adj_dst)),
('obj', 'is_top_right_adj', 'obj'): (torch.tensor(is_top_right_adj_src), torch.tensor(is_top_right_adj_dst)),
('obj', 'is_top_left_adj', 'obj'): (torch.tensor(is_top_left_adj_src), torch.tensor(is_top_left_adj_dst)),
('obj', 'is_down_adj', 'obj'): (torch.tensor(is_down_adj_src), torch.tensor(is_down_adj_dst)),
('obj', 'is_right_adj', 'obj'): (torch.tensor(is_right_adj_src), torch.tensor(is_right_adj_dst)),
('obj', 'is_down_left_adj', 'obj'): (torch.tensor(is_down_left_adj_src), torch.tensor(is_down_left_adj_dst)),
('obj', 'is_down_right_adj', 'obj'): (torch.tensor(is_down_right_adj_src), torch.tensor(is_down_right_adj_dst)),
('obj', 'is_left', 'obj'): (torch.tensor(is_left_src), torch.tensor(is_left_dst)),
('obj', 'is_right', 'obj'): (torch.tensor(is_right_src), torch.tensor(is_right_dst)),
('obj', 'is_front', 'obj'): (torch.tensor(is_front_src), torch.tensor(is_front_dst)),
('obj', 'is_back', 'obj'): (torch.tensor(is_back_src), torch.tensor(is_back_dst)),
('obj', 'is_aligned', 'obj'): (torch.tensor(is_aligned_src), torch.tensor(is_aligned_dst)),
('obj', 'is_close', 'obj'): (torch.tensor(is_close_src), torch.tensor(is_close_dst))
}, num_nodes_dict={'obj': len(grid_objs)})
g = self._add_node_features(grid_objs, g)
return g
def _add_node_features(self, objs, graph):
for obj_idx, obj in enumerate(objs):
graph.nodes[obj_idx].data['type'] = torch.tensor(obj.type)
graph.nodes[obj_idx].data['pos'] = torch.tensor([[obj.x, obj.y]], dtype=torch.float32)
assert len(obj.attributes) == 2 and None not in obj.attributes[0] and None not in obj.attributes[1]
graph.nodes[obj_idx].data['color'] = torch.tensor([obj.attributes[0]])
graph.nodes[obj_idx].data['shape'] = torch.tensor([obj.attributes[1]])
return graph
def _get_spatial_rel(self, objs):
spatial_tensors = [np.zeros([len(objs), len(objs)]) for _ in range(len(self.rel_deter_func))]
for obj_idx1, obj1 in enumerate(objs):
for obj_idx2, obj2 in enumerate(objs):
direction_vec = np.array((0, -1)) # Up
for rel_idx, func in enumerate(self.rel_deter_func):
if func(obj1, obj2, direction_vec):
spatial_tensors[rel_idx][obj_idx1, obj_idx2] = 1.0
return spatial_tensors
def get_trial(self, trials, step=10):
# retrieve state embeddings and actions from cached file
states = []
actions = []
trial_len = []
lens = []
n_nodes = []
# one entry per demonstration trial
for t in trials:
tl = [(t, n) for n in range(0, len(self.data_tuples[t]), step)]
if len(tl) > self.max_len: # 30
tl = tl[:self.max_len]
trial_len.append(tl)
for tl in trial_len:
states.append([])
actions.append([])
lens.append(len(tl))
for t, n in tl:
video = self.data_tuples[t][n][0]
states[-1].append(self._get_frame_graph(video, self.data_tuples[t][n][1]))
n_nodes.append(states[-1][-1].number_of_nodes())
# actions are pooled over frames
if len(self.data_tuples[t]) > n + self.action_range:
actions_xy = [d[2] for d in self.data_tuples[t][n:n + self.action_range]]
else:
actions_xy = [d[2] for d in self.data_tuples[t][n:]]
actions_xy = np.array(actions_xy)
actions_xy = np.mean(actions_xy, axis=0)
actions[-1].append(actions_xy)
states[-1] = dgl.batch(states[-1])
actions[-1] = torch.tensor(np.array(actions[-1]))
trial_actions_padded = torch.zeros(self.max_len, actions[-1].size(1))
trial_actions_padded[:actions[-1].size(0), :] = actions[-1]
actions[-1] = trial_actions_padded
return states, actions, lens, n_nodes
def __getitem__(self, idx):
ep_trials = [idx * self.num_trials + t for t in range(self.num_trials)] # [idx*9, ..., idx*9 + 8]
states, actions, lens, n_nodes = self.get_trial(ep_trials, step=self.action_range)
return states, actions, lens, n_nodes
def __len__(self):
return self.tot_trials // self.num_trials
class TestTransitionDatasetSequence(torch.utils.data.Dataset):
"""
Test dataset class for the behavior cloning rnn model. This dataset is used to test the model on the eval data.
This class is used to compare plausible and implausible episodes.
Args:
path: path to the dataset
types: list of video types to include
size: size of the frames to be returned
mode: test
num_context: number of context state-action pairs
num_test: number of test state-action pairs
num_trials: number of trials in an episode
action_range: number of frames to skip; actions are averaged over this number of frames (displacement of the agent)
process_data: whether to index the videos (skip if already processed)
__getitem__:
returns: (expected_dem_frames, expected_dem_actions, expected_dem_lens, expected_query_frames, expected_target_actions,
unexpected_dem_frames, unexpected_dem_actions, unexpected_dem_lens, unexpected_query_frames, unexpected_target_actions)
dem_frames: (num_context, max_len, 3, size, size)
dem_actions: (num_context, max_len, 2)
dem_lens: (num_context)
query_frames: (num_test, 3, size, size)
target_actions: (num_test, 2)
"""
def __init__(
self,
path,
task_type=None,
mode="test",
num_test=1,
num_trials=9,
action_range=10,
process_data=0,
max_len=30
):
self.path = path
self.task_type = task_type
self.mode = mode
self.num_trials = num_trials
self.num_test = num_test
self.action_range = action_range
self.max_len = max_len
self.ep_combs = self.num_trials * (self.num_trials - 2) # 9p2 - 9
self.eps = [[x, y] for x in range(self.num_trials) for y in range(self.num_trials) if x != y]
self.path_list_exp = []
self.json_list_exp = []
self.path_list_un = []
self.json_list_un = []
print(f'reading files of type {task_type} in {mode}')
paths_expected = sorted([os.path.join(self.path, task_type, x) for x in os.listdir(os.path.join(self.path, task_type)) if
x.endswith('e.mp4')])
jsons_expected = sorted([os.path.join(self.path, task_type, x) for x in os.listdir(os.path.join(self.path, task_type)) if
x.endswith('e.json') and 'index' not in x])
paths_unexpected = sorted([os.path.join(self.path, task_type, x) for x in os.listdir(os.path.join(self.path, task_type)) if
x.endswith('u.mp4')])
jsons_unexpected = sorted([os.path.join(self.path, task_type, x) for x in os.listdir(os.path.join(self.path, task_type)) if
x.endswith('u.json') and 'index' not in x])
self.path_list_exp += paths_expected
self.json_list_exp += jsons_expected
self.path_list_un += paths_unexpected
self.json_list_un += jsons_unexpected
self.data_expected = []
self.data_unexpected = []
if process_data:
# index data. This is done to speed up video retrieval.
# frame index, action tuples are stored
self.data_expected = index_data(self.json_list_exp, self.path_list_exp)
index_dict = {'data_tuples': self.data_expected}
with open(os.path.join(self.path, f'jindex_bib_test_{task_type}e.json'), 'w') as fp:
json.dump(index_dict, fp)
self.data_unexpected = index_data(self.json_list_un, self.path_list_un)
index_dict = {'data_tuples': self.data_unexpected}
with open(os.path.join(self.path, f'jindex_bib_test_{task_type}u.json'), 'w') as fp:
json.dump(index_dict, fp)
else:
with open(os.path.join(self.path, f'jindex_bib_{mode}_{task_type}e.json'), 'r') as fp:
index_dict = json.load(fp)
self.data_expected = index_dict['data_tuples']
with open(os.path.join(self.path, f'jindex_bib_{mode}_{task_type}u.json'), 'r') as fp:
index_dict = json.load(fp)
self.data_unexpected = index_dict['data_tuples']
self.rel_deter_func = [
is_top_adj, is_left_adj, is_top_right_adj, is_top_left_adj,
is_down_adj, is_right_adj, is_down_left_adj, is_down_right_adj,
is_left, is_right, is_front, is_back, is_aligned, is_close
]
print('Done.')
def _get_frame_graph(self, jsonfile, frame_idx):
# load json
with open(jsonfile, 'rb') as f:
frame_data = json.load(f)
flat_list = [x for xs in frame_data for x in xs]
# extract entities
grid_objs = parse_objects(flat_list[frame_idx])
# --- build the graph
adj = self._get_spatial_rel(grid_objs)
# define edges
is_top_adj_src, is_top_adj_dst = np.nonzero(adj[0])
is_left_adj_src, is_left_adj_dst = np.nonzero(adj[1])
is_top_right_adj_src, is_top_right_adj_dst = np.nonzero(adj[2])
is_top_left_adj_src, is_top_left_adj_dst = np.nonzero(adj[3])
is_down_adj_src, is_down_adj_dst = np.nonzero(adj[4])
is_right_adj_src, is_right_adj_dst = np.nonzero(adj[5])
is_down_left_adj_src, is_down_left_adj_dst = np.nonzero(adj[6])
is_down_right_adj_src, is_down_right_adj_dst = np.nonzero(adj[7])
is_left_src, is_left_dst = np.nonzero(adj[8])
is_right_src, is_right_dst = np.nonzero(adj[9])
is_front_src, is_front_dst = np.nonzero(adj[10])
is_back_src, is_back_dst = np.nonzero(adj[11])
is_aligned_src, is_aligned_dst = np.nonzero(adj[12])
is_close_src, is_close_dst = np.nonzero(adj[13])
g = dgl.heterograph({
('obj', 'is_top_adj', 'obj'): (torch.tensor(is_top_adj_src), torch.tensor(is_top_adj_dst)),
('obj', 'is_left_adj', 'obj'): (torch.tensor(is_left_adj_src), torch.tensor(is_left_adj_dst)),
('obj', 'is_top_right_adj', 'obj'): (torch.tensor(is_top_right_adj_src), torch.tensor(is_top_right_adj_dst)),
('obj', 'is_top_left_adj', 'obj'): (torch.tensor(is_top_left_adj_src), torch.tensor(is_top_left_adj_dst)),
('obj', 'is_down_adj', 'obj'): (torch.tensor(is_down_adj_src), torch.tensor(is_down_adj_dst)),
('obj', 'is_right_adj', 'obj'): (torch.tensor(is_right_adj_src), torch.tensor(is_right_adj_dst)),
('obj', 'is_down_left_adj', 'obj'): (torch.tensor(is_down_left_adj_src), torch.tensor(is_down_left_adj_dst)),
('obj', 'is_down_right_adj', 'obj'): (torch.tensor(is_down_right_adj_src), torch.tensor(is_down_right_adj_dst)),
('obj', 'is_left', 'obj'): (torch.tensor(is_left_src), torch.tensor(is_left_dst)),
('obj', 'is_right', 'obj'): (torch.tensor(is_right_src), torch.tensor(is_right_dst)),
('obj', 'is_front', 'obj'): (torch.tensor(is_front_src), torch.tensor(is_front_dst)),
('obj', 'is_back', 'obj'): (torch.tensor(is_back_src), torch.tensor(is_back_dst)),
('obj', 'is_aligned', 'obj'): (torch.tensor(is_aligned_src), torch.tensor(is_aligned_dst)),
('obj', 'is_close', 'obj'): (torch.tensor(is_close_src), torch.tensor(is_close_dst))
}, num_nodes_dict={'obj': len(grid_objs)})
g = self._add_node_features(grid_objs, g)
return g
def _add_node_features(self, objs, graph):
for obj_idx, obj in enumerate(objs):
graph.nodes[obj_idx].data['type'] = torch.tensor(obj.type)
graph.nodes[obj_idx].data['pos'] = torch.tensor([[obj.x, obj.y]], dtype=torch.float32)
assert len(obj.attributes) == 2 and None not in obj.attributes[0] and None not in obj.attributes[1]
graph.nodes[obj_idx].data['color'] = torch.tensor([obj.attributes[0]])
graph.nodes[obj_idx].data['shape'] = torch.tensor([obj.attributes[1]])
return graph
def _get_spatial_rel(self, objs):
spatial_tensors = [np.zeros([len(objs), len(objs)]) for _ in range(len(self.rel_deter_func))]
for obj_idx1, obj1 in enumerate(objs):
for obj_idx2, obj2 in enumerate(objs):
direction_vec = np.array((0, -1)) # fixed 'up' reference direction, matching TransitionDataset
for rel_idx, func in enumerate(self.rel_deter_func):
if func(obj1, obj2, direction_vec):
spatial_tensors[rel_idx][obj_idx1, obj_idx2] = 1.0
return spatial_tensors
def get_trial(self, trials, data, step=10):
# retrieve state embeddings and actions from cached file
states = []
actions = []
trial_len = []
lens = []
n_nodes = []
for t in trials:
tl = [(t, n) for n in range(0, len(data[t]), step)]
if len(tl) > self.max_len:
tl = tl[:self.max_len]
trial_len.append(tl)
for tl in trial_len:
states.append([])
actions.append([])
lens.append(len(tl))
for t, n in tl:
video = data[t][n][0]
states[-1].append(self._get_frame_graph(video, data[t][n][1]))
n_nodes.append(states[-1][-1].number_of_nodes())
if len(data[t]) > n + self.action_range:
actions_xy = [d[2] for d in data[t][n:n + self.action_range]]
else:
actions_xy = [d[2] for d in data[t][n:]]
actions_xy = np.array(actions_xy)
actions_xy = np.mean(actions_xy, axis=0)
actions[-1].append(actions_xy)
states[-1] = dgl.batch(states[-1])
actions[-1] = torch.tensor(np.array(actions[-1]))
trial_actions_padded = torch.zeros(self.max_len, actions[-1].size(1))
trial_actions_padded[:actions[-1].size(0), :] = actions[-1]
actions[-1] = trial_actions_padded
return states, actions, lens, n_nodes
def get_test(self, trial, data, step=10):
# retrieve state embeddings and actions from cached file
states = []
actions = []
trial_len = []
trial_len += [(trial, n) for n in range(0, len(data[trial]), step)]
for t, n in trial_len:
video = data[t][n][0]
state = self._get_frame_graph(video, data[t][n][1])
if len(data[t]) > n + self.action_range:
actions_xy = [d[2] for d in data[t][n:n + self.action_range]]
else:
actions_xy = [d[2] for d in data[t][n:]]
actions_xy = np.array(actions_xy)
actions_xy = np.mean(actions_xy, axis=0)
actions.append(actions_xy)
states.append(state)
#states = torch.stack(states)
states = dgl.batch(states)
actions = torch.tensor(np.array(actions))
return states, actions
def __getitem__(self, idx):
ep_trials = [idx * self.num_trials + t for t in range(self.num_trials)]
dem_expected_states, dem_expected_actions, dem_expected_lens, dem_expected_nodes = self.get_trial(
ep_trials[:-1], self.data_expected, step=self.action_range
)
dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, dem_unexpected_nodes = self.get_trial(
ep_trials[:-1], self.data_unexpected, step=self.action_range
)
query_expected_frames, target_expected_actions = self.get_test(
ep_trials[-1], self.data_expected, step=self.action_range
)
query_unexpected_frames, target_unexpected_actions = self.get_test(
ep_trials[-1], self.data_unexpected, step=self.action_range
)
return dem_expected_states, dem_expected_actions, dem_expected_lens, dem_expected_nodes, \
dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, dem_unexpected_nodes, \
query_expected_frames, target_expected_actions, \
query_unexpected_frames, target_unexpected_actions
def __len__(self):
return len(self.path_list_exp)
if __name__ == '__main__':
types = ['preference', 'multi_agent', 'inaccessible_goal',
'efficiency_irrational', 'efficiency_time','efficiency_path',
'instrumental_no_barrier', 'instrumental_blocking_barrier', 'instrumental_inconsequential_barrier']
for t in types:
ttd = TestTransitionDatasetSequence(path='/datasets/external/bib_evaluation_1_1/', task_type=t, process_data=0, mode='test')
for i in range(ttd.__len__()):
print(i, end='\r')
dem_expected_states, dem_expected_actions, dem_expected_lens, dem_expected_nodes, \
dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, dem_unexpected_nodes, \
query_expected_frames, target_expected_actions, \
query_unexpected_frames, target_unexpected_actions = ttd.__getitem__(i)
for j in range(8):
if not torch.tensor([1., 0., 0., 0., 0., 0., 0., 0., 0.]) in dem_expected_states[j].ndata['type']:
print(i)
if not torch.tensor([1., 0., 0., 0., 0., 0., 0., 0., 0.]) in dem_unexpected_states[j].ndata['type']:
print(i)
if not torch.tensor([1., 0., 0., 0., 0., 0., 0., 0., 0.]) in query_expected_frames.ndata['type']:
print(i)
if not torch.tensor([1., 0., 0., 0., 0., 0., 0., 0., 0.]) in query_unexpected_frames.ndata['type']:
print(i)
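In `get_trial` and `get_test`, the per-frame agent displacements are averaged over `action_range` consecutive frames to form one action per sampled frame, before padding to `max_len`. A small numeric sketch of that pooling; the displacement values are made up:

```python
import numpy as np

action_range = 5
per_frame = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [0.0, 0.0]])
pooled = per_frame[:action_range].mean(axis=0)   # one action for this sampled frame
print(pooled)                                    # [0.4 0.4]
```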

174
utils/grid_object.py Normal file
View file

@@ -0,0 +1,174 @@
import json
import pdb
import numpy as np
from sklearn.preprocessing import OneHotEncoder
import itertools
SHAPES = {
# walls
'square': 0,
# objects
'heart': 1, 'clove_tree': 2, 'moon': 3, 'wine': 4, 'double_dia': 5,
'flag': 6, 'capsule': 7, 'vase': 8, 'curved_triangle': 9, 'spoon': 10,
'medal': 11, # inaccessible goal
# home
'home': 12,
# agent
'pentagon': 13, 'clove': 14, 'kite': 15,
# key
'triangle0': 16, 'triangle180': 16, 'triangle90': 16, 'triangle270': 16,
# lock
'triangle_slot0': 17, 'triangle_slot180': 17, 'triangle_slot90': 17, 'triangle_slot270': 17
}
ENTITIES = {
'agent': 0 , 'walls': 1, 'fuse_walls': 2, 'key': 3,
'blocking': 4, 'home': 5, 'objects': 6, 'pin': 7, 'lock': 8
}
# =============================== GridObject class ===============================
class GridObject():
"object is specified by its location"
def __init__(self, x, y, object_type, attributes=[]):
self.x = x
self.y = y
self.type = object_type
self.attributes = attributes
@property
def pos(self):
return np.array([self.x, self.y])
@property
def name(self):
return {'type': str(self.type),
'x': str(self.x),
'y': str(self.y),
'color': self.attributes[0],
'shape': self.attributes[1]}
# =============================== Helper functions ===============================
def type2index(key):
for name, idx in ENTITIES.items():
if name == key:
return idx
def find_shape(shape_string, print_shape=False):
try:
shape = shape_string.split('/')[-1].split('.')[0]
except:
shape = shape_string.split('.')[0]
if print_shape: print(shape)
for name, idx in SHAPES.items():
if name == shape:
return idx
def parse_objects(frame):
"""
x and y are computed differently for walls and for objects:
for walls: x, y = obj[0][0] + obj[1][0]/2, obj[0][1] + obj[1][1]/2
for objects: x, y = obj[0][0] + obj[1], obj[0][1] + obj[1]
:param frame: dictionary describing one frame
:return: list of GridObject
"""
shape_onehot_encoder = OneHotEncoder(sparse=False)
shape_onehot_encoder.fit([[i] for i in range(len(SHAPES)-6)])
type_onehot_encoder = OneHotEncoder(sparse=False)
type_onehot_encoder.fit([[i] for i in range(len(ENTITIES))])
# remove duplicate walls
frame['walls'].sort()
frame['walls'] = list(k for k, _ in itertools.groupby(frame['walls']))
# remove boundary walls
frame['walls'] = [w for w in frame['walls'] if (w[0][0] != 0 and w[0][0] != 180 and w[0][1] != 0 and w[0][1] != 180)]
# remove duplicate fuse_walls
frame['fuse_walls'].sort()
frame['fuse_walls'] = list(k for k, _ in itertools.groupby(frame['fuse_walls']))
grid_objs = []
assert 'agent' in frame.keys()
for key in frame.keys():
#print(key)
if key == 'size':
continue
obj = frame[key]
if obj == []:
#print(key, 'skipped')
continue
obj_type = type2index(key)
obj_type = type_onehot_encoder.transform([[obj_type]])
if key == 'walls':
for wall in obj:
x, y = wall[0][0] + wall[1][0]/2, wall[0][1] + wall[1][1]/2
#x, y = (wall[0][0] + wall[1][0]/2)/200, (wall[0][1] + wall[1][1]/2)/200 # if you use this normalisation you must also update relations.py
color = [0, 0, 0] if key == 'walls' else [80, 146, 56]
#color = [c / 255 for c in color]
shape = 0
assert shape in SHAPES.values(), 'Shape not found'
shape = shape_onehot_encoder.transform([[shape]])[0]
grid_obj = GridObject(x=x, y=y, object_type=obj_type, attributes=[color, shape])
grid_objs.append(grid_obj)
elif key == 'fuse_walls':
# resample green barriers
obj = [obj[i] for i in range(len(obj)) if (obj[i][0][0] % 20 == 0 and obj[i][0][1] % 20 == 0)]
for wall in obj:
x, y = wall[0][0] + wall[1][0]/2, wall[0][1] + wall[1][1]/2
#x, y = (wall[0][0] + wall[1][0]/2)/200, (wall[0][1] + wall[1][1]/2)/200 # if you use this normalisation you must also update relations.py
color = [80, 146, 56]
#color = [c / 255 for c in color]
shape = 0
assert shape in SHAPES.values(), 'Shape not found'
shape = shape_onehot_encoder.transform([[shape]])[0]
grid_obj = GridObject(x=x, y=y, object_type=obj_type, attributes=[color, shape])
grid_objs.append(grid_obj)
elif key == 'objects':
for ob in obj:
x, y = ob[0][0] + ob[1], ob[0][1] + ob[1]
color = ob[-1]
#color = [c / 255 for c in color]
shape = find_shape(ob[2], print_shape=False)
assert shape in SHAPES.values(), 'Shape not found'
shape = shape_onehot_encoder.transform([[shape]])[0]
grid_obj = GridObject(x=x, y=y, object_type=obj_type, attributes=[color, shape])
grid_objs.append(grid_obj)
elif key == 'key':
obj = obj[0]
x, y = obj[0][0] + obj[1], obj[0][1] + obj[1]
color = obj[-1]
#color = [c / 255 for c in color]
shape = find_shape(obj[2], print_shape=False)
assert shape in SHAPES.values(), 'Shape not found'
shape = shape_onehot_encoder.transform([[shape]])[0]
grid_obj = GridObject(x=x, y=y, object_type=obj_type, attributes=[color, shape])
grid_objs.append(grid_obj)
elif key == 'lock':
obj = obj[0]
x, y = obj[0][0] + obj[1], obj[0][1] + obj[1]
color = obj[-1]
#color = [c / 255 for c in color]
shape = find_shape(obj[2], print_shape=False)
assert shape in SHAPES.values(), 'Shape not found'
shape = shape_onehot_encoder.transform([[shape]])[0]
grid_obj = GridObject(x=x, y=y, object_type=obj_type, attributes=[color, shape])
grid_objs.append(grid_obj)
else:
try:
x, y = obj[0][0] + obj[1], obj[0][1] + obj[1]
color = obj[-1]
#color = [c / 255 for c in color]
shape = find_shape(obj[2], print_shape=False)
assert shape in SHAPES.values(), 'Shape not found'
shape = shape_onehot_encoder.transform([[shape]])[0]
except:
# [[[x, y], extension, shape, color]] in some cases in instrumental_no_barrier (bib_evaluation_1_1)
x, y = obj[0][0][0] + obj[0][1], obj[0][0][1] + obj[0][1]
color = obj[0][-1]
#color = [c / 255 for c in color]
assert len(color) == 3
shape = find_shape(obj[0][2], print_shape=False)
assert shape in SHAPES.values(), 'Shape not found'
shape = shape_onehot_encoder.transform([[shape]])[0]
grid_obj = GridObject(x=x, y=y, object_type=obj_type, attributes=[color, shape])
grid_objs.append(grid_obj)
return grid_objs
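`parse_objects` turns one raw frame dictionary into `GridObject`s with a one-hot entity type, a grid position, an RGB colour and a one-hot shape. A hedged sketch with a minimal hand-made frame; the field layout follows the parsing branches above, but every value (positions, extension, shape path, colour) is made up:

```python
from utils.grid_object import parse_objects

frame = {
    'size': [200, 200],
    'agent': [[20, 40], 10, 'shapes/pentagon.png', [255, 0, 0]],
    'walls': [[[20, 20], [20, 20]]],
    'fuse_walls': [], 'key': [], 'blocking': [],
    'home': [], 'objects': [], 'pin': [], 'lock': []
}

objs = parse_objects(frame)
print(len(objs))        # 2: the agent and one wall segment
print(objs[0].name)     # type, position, colour and shape of the agent
```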

124
utils/index_data.py Normal file
View file

@@ -0,0 +1,124 @@
import json
import os
import torch
import torch.utils.data
from tqdm import tqdm
def index_data(json_list, path_list):
print(f'processing {len(json_list)} files')
data_tuples = []
for j, v in tqdm(zip(json_list, path_list)):
with open(j, 'r') as f:
state = json.load(f)
ep_lens = [len(x) for x in state]
past_len = 0
for e, l in enumerate(ep_lens):
data_tuples.append([])
# skip first 30 frames and last 83 frames
for f in range(30, l - 83):
# find action taken;
f0x, f0y = state[e][f]['agent'][0]
f1x, f1y = state[e][f + 1]['agent'][0]
dx = (f1x - f0x) / 2.
dy = (f1y - f0y) / 2.
action = [dx, dy]
#data_tuples[-1].append((v, past_len + f, action))
data_tuples[-1].append((j, past_len + f, action))
# data_tuples = (json file, frame number, action)
assert len(data_tuples[-1]) > 0
past_len += l
return data_tuples
class TransitionDataset(torch.utils.data.Dataset):
"""
Training dataset class for the behavior cloning mlp model.
Args:
path: path to the dataset
types: list of video types to include
size: size of the frames to be returned
mode: train, val
num_context: number of context state-action pairs
num_test: number of test state-action pairs
num_trials: number of trials in an episode
action_range: number of frames to skip; actions are averaged over this number of frames (displacement of the agent)
process_data: whether to index the videos (skip if already processed)
__getitem__:
returns: (dem_frames, dem_actions, query_frames, target_actions)
dem_frames: (num_context, 3, size, size)
dem_actions: (num_context, 2)
query_frames: (num_test, 3, size, size)
target_actions: (num_test, 2)
"""
def __init__(self, path, types=None, size=None, mode="train", num_context=30, num_test=1, num_trials=9,
action_range=10, process_data=0):
self.path = path
self.types = types
self.size = size
self.mode = mode
self.num_trials = num_trials
self.num_context = num_context
self.num_test = num_test
self.action_range = action_range
self.ep_combs = self.num_trials * (self.num_trials - 2) # 9p2 - 9
self.eps = [[x, y] for x in range(self.num_trials) for y in range(self.num_trials) if x != y]
types_str = '_'.join(self.types)
self.path_list = []
self.json_list = []
# get video paths and json file paths
for t in types:
print(f'reading files of type {t} in {mode}')
paths = [os.path.join(self.path, t, x) for x in os.listdir(os.path.join(self.path, t)) if
x.endswith(f'.mp4')]
jsons = [os.path.join(self.path, t, x) for x in os.listdir(os.path.join(self.path, t)) if
x.endswith(f'.json') and 'index' not in x]
paths = sorted(paths)
jsons = sorted(jsons)
if mode == 'train':
self.path_list += paths[:int(0.8 * len(jsons))]
self.json_list += jsons[:int(0.8 * len(jsons))]
elif mode == 'val':
self.path_list += paths[int(0.8 * len(jsons)):]
self.json_list += jsons[int(0.8 * len(jsons)):]
else:
self.path_list += paths
self.json_list += jsons
self.data_tuples = []
if process_data:
# index the videos in the dataset directory. This is done to speed up the retrieval of videos.
# frame index, action tuples are stored
self.data_tuples = index_data(self.json_list, self.path_list)
# tuples of frame index and action (displacement of agent)
index_dict = {'data_tuples': self.data_tuples}
with open(os.path.join(self.path, f'jindex_bib_{mode}_{types_str}.json'), 'w') as fp:
json.dump(index_dict, fp)
else:
# read pre-indexed data
with open(os.path.join(self.path, f'jindex_bib_{mode}_{types_str}.json'), 'r') as fp:
index_dict = json.load(fp)
self.data_tuples = index_dict['data_tuples']
self.tot_trials = len(self.path_list) * 9
def __getitem__(self, idx):
print('Empty')
return
def __len__(self):
return self.tot_trials // self.num_trials
if __name__ == "__main__":
dataset = TransitionDataset(path='/datasets/external/bib_train/',
types=['multi_agent', 'instrumental_action'], #['instrumental_action', 'multi_agent', 'preference', 'single_object'],
size=(84, 84),
mode="train", num_context=30,
num_test=1, num_trials=9,
action_range=10, process_data=1)
print(len(dataset))
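Indexing writes a `jindex_bib_<mode>_<types>.json` file whose `data_tuples` entry is a list of episodes, each a list of `(json file, frame index, [dx, dy])` tuples, with `[dx, dy]` being half the agent's displacement to the next frame. A hedged sketch of reading one entry back; the file name matches the call above, but the path is illustrative:

```python
import json

index_file = '/datasets/external/bib_train/jindex_bib_train_multi_agent_instrumental_action.json'
with open(index_file, 'r') as fp:
    data_tuples = json.load(fp)['data_tuples']

json_file, frame_idx, (dx, dy) = data_tuples[0][0]   # first frame of the first episode
print(json_file, frame_idx, dx, dy)
```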

116
utils/relations.py Normal file
View file

@@ -0,0 +1,116 @@
import numpy as np
# =============================== relationships to build the graph ===============================
def rotate_vec2d(vec, degrees):
"""
rotate a vector anti-clockwise
:param vec:
:param degrees:
:return:
"""
theta = np.radians(degrees)
c, s = np.cos(theta), np.sin(theta)
R = np.array(((c, -s), (s, c)))
return R@vec
# ---------- Remote Directional Relations --------------------------------------------------------
def is_front(obj1, obj2, direction_vec)->bool:
diff = obj2.pos - obj1.pos
return diff@direction_vec > 0.1
def is_back(obj1, obj2, direction_vec)->bool:
diff = obj2.pos - obj1.pos
return diff@direction_vec < -0.1
def is_left(obj1, obj2, direction_vec)->bool:
left_vec = rotate_vec2d(direction_vec, -90)
diff = obj2.pos - obj1.pos
return diff@left_vec > 0.1
def is_right(obj1, obj2, direction_vec)->bool:
left_vec = rotate_vec2d(direction_vec, 90)
diff = obj2.pos - obj1.pos
return diff@left_vec > 0.1
# ---------- Alignment and Adjacency Relations ---------------------------------------------------
def is_close(obj1, obj2, direction_vec=None)->bool:
# indicates whether two objects are adjacent to each other;
# unlike the local directional relations, this carries no directional information
distance = np.abs(obj1.pos - obj2.pos)
return np.sum(distance)==20
def is_aligned(obj1, obj2, direction_vec=None)->bool:
# indicate if two entities are on the same horizontal or vertical line
diff = obj2.pos - obj1.pos
return np.any(diff==0)
# ---------- Local Directional Relations ---------------------------------------------------------
def is_top_adj(obj1, obj2, direction_vec=None)->bool:
return obj1.x==obj2.x and obj1.y==obj2.y+20
def is_left_adj(obj1, obj2, direction_vec=None)->bool:
return obj1.y==obj2.y and obj1.x==obj2.x-20
def is_top_left_adj(obj1, obj2, direction_vec=None)->bool:
return obj1.y==obj2.y+20 and obj1.x==obj2.x-20
def is_top_right_adj(obj1, obj2, direction_vec=None)->bool:
return obj1.y==obj2.y+20 and obj1.x==obj2.x+20
def is_down_adj(obj1, obj2, direction_vec=None)->bool:
return is_top_adj(obj2, obj1)
def is_right_adj(obj1, obj2, direction_vec=None)->bool:
return is_left_adj(obj2, obj1)
def is_down_right_adj(obj1, obj2, direction_vec=None)->bool:
return is_top_left_adj(obj2, obj1)
def is_down_left_adj(obj1, obj2, direction_vec=None)->bool:
return is_top_right_adj(obj2, obj1)
# ---------- More Remote Directional Relations (not used) ----------------------------------------
def top_left(obj1, obj2, direction_vec)->bool:
return (obj1.x-obj2.x) <= (obj1.y-obj2.y)
def top_right(obj1, obj2, direction_vec)->bool:
return -(obj1.x-obj2.x) <= (obj1.y-obj2.y)
def down_left(obj1, obj2, direction_vec)->bool:
return top_right(obj2, obj1, direction_vec)
def down_right(obj1, obj2, direction_vec)->bool:
return top_left(obj2, obj1, direction_vec)
def fan_top(obj1, obj2, direction_vec)->bool:
top_left = (obj1.x-obj2.x) <= (obj1.y-obj2.y)
top_right = -(obj1.x-obj2.x) <= (obj1.y-obj2.y)
return top_left and top_right
def fan_down(obj1, obj2, direction_vec)->bool:
return fan_top(obj2, obj1, direction_vec)
def fan_right(obj1, obj2, direction_vec)->bool:
down_left = (obj1.x-obj2.x) >= (obj1.y-obj2.y)
top_right = -(obj1.x-obj2.x) <= (obj1.y-obj2.y)
return down_left and top_right
def fan_left(obj1, obj2, direction_vec)->bool:
return fan_right(obj2, obj1, direction_vec)
# ---------- Ad-hoc Relations --------------------------------------------------------------------
def needs(obj1, obj2, direction_vec=None)->bool:
return np.argmax(obj1.type) == 0 and np.argmax(obj2.type) == 3
def opens(obj1, obj2, direction_vec=None)->bool:
return np.argmax(obj1.type) == 3 and np.argmax(obj2.type) == 8
def collects(obj1, obj2, direction_vec=None)->bool:
return np.argmax(obj1.type) == 0 and np.argmax(obj2.type) == 6
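All predicates take two `GridObject`s (plus, for the directional ones, a reference direction vector) and return a boolean; the adjacency checks assume the 20-pixel grid step used by the dataset classes. A hedged sketch with two illustrative objects and the same 'up' reference direction as the datasets:

```python
import numpy as np
from utils.grid_object import GridObject
from utils.relations import is_close, is_right, is_top_adj

a = GridObject(x=40, y=40, object_type=0)
b = GridObject(x=60, y=40, object_type=0)
up = np.array((0, -1))        # reference direction used in the dataset classes

print(is_close(a, b))         # True:  |dx| + |dy| == 20
print(is_right(a, b, up))     # True:  the displacement to b projects onto the rotated axis
print(is_top_adj(a, b))       # False: not vertically adjacent
```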

View file

@@ -0,0 +1 @@
python build_graphs.py --mode train --cpus 30 && python build_graphs.py --mode val --cpus 30