diff --git a/README.md b/README.md
index 8cfac98..3993eda 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,75 @@
-# IRENE
+# IRENE
+
Neural Reasoning about Agents' Goals, Preferences, and Actions
-Official code of "Neural Reasoning About Agents' Goals, Preferences, and Actions"
\ No newline at end of file
+**[Matteo Bortoletto][1], [Lei Shi][2], [Andreas Bulling][3]**
+**AAAI'24, Vancouver, Canada**
+**[[Paper][4]]**
+
+
+
+# Citation
+If you find our code useful or use it in your own projects, please cite our paper:
+
+```bibtex
+@inproceedings{bortoletto2024neural,
+ author = {Bortoletto, Matteo and Shi, Lei and Bulling, Andreas},
+ title = {{Neural Reasoning about Agents' Goals, Preferences, and Actions}},
+ booktitle = {Proc. 38th AAAI Conference on Artificial Intelligence (AAAI)},
+ year = {2024},
+}
+```
+
+# Setup
+
+This code is based on the [original implementation][5] of the BIB benchmark.
+
+## Using `virtualenv`
+```
+python -m virtualenv /path/to/env
+source /path/to/env/bin/activate
+pip install -r requirements.txt
+```
+
+## Using `conda`
+```
+# "bibdgl" is a placeholder environment name
+conda create --name bibdgl python=3.8.10 pip=20.0.2 cudatoolkit=10.2.89
+conda activate bibdgl
+pip install -r requirements_conda.txt
+pip install dgl-cu102 dglgo -f https://data.dgl.ai/wheels/repo.html
+```
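+
+To sanity-check the GPU build of DGL (optional):
+```
+python -c "import dgl, torch; print(dgl.__version__, torch.cuda.is_available())"
+```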
+
+
+# Running the code
+
+## Activate the environment
+Run `source /path/to/env/bin/activate` if you used `virtualenv` (e.g. `source bibdgl/bin/activate`), or `conda activate bibdgl` if you used `conda`.
+
+## Index data
+This will create the JSON files with the indexed frames for each episode in each video.
+```
+python utils/index_data.py
+```
+Note that you need to manually set `mode` (`train`, `val` or `test`) in the dataset class instantiated in the `__main__` section of `utils/index_data.py`.
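+
+As a rough sketch (hypothetical names; check `utils/index_data.py` for the actual class):
+```
+# in the __main__ of utils/index_data.py
+dataset = IndexDataset(mode='train')  # set to 'train', 'val' or 'test'
+```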
+
+## Generate graphs
+This will generate the graphs from the videos:
+```
+python utils/build_graphs.py --mode MODE --cpus NUM_CPUS
+```
+`MODE` can be `train`, `val` or `test`. NOTE: check `utils/build_graphs.py` to make sure you're loading the correct dataset to generate the graphs you want.
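+
+For example, to build the training graphs with 8 worker processes:
+```
+python utils/build_graphs.py --mode train --cpus 8
+```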
+
+## Training
+You can use `run_train.sh`, which calls `train_tom.py` with the default hyperparameters.
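+
+A minimal direct invocation (mirroring the flags in `run_train.sh`; adjust `--data_path` to your setup):
+```
+python train_tom.py \
+  --model_type graphbcrnn \
+  --types single_object preference instrumental_action \
+  --data_path /path/to/bib_train/graphs/all_tasks/ \
+  --gnn_type RSAGEv4 --state_dim 96 --aggregation sum
+```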
+
+## Testing
+Use `run_test.sh`, which calls `test_tom.py` on a trained checkpoint.
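+
+A minimal sketch (checkpoint and data paths are placeholders; flags mirror `run_test.sh`):
+```
+python test_tom.py \
+  --model_type graphbcrnn \
+  --types efficiency_irrational \
+  --ckpt /path/to/checkpoint.ckpt \
+  --data_path /path/to/bib_evaluation_1_1/graphs/all_tasks \
+  --surprise_type max
+```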
+
+# Hardware setup
+All models were trained on a single NVIDIA Tesla V100-SXM2-32GB GPU.
+
+
+[1]: https://mattbortoletto.github.io/
+[2]: https://perceptualui.org/people/shi/
+[3]: https://perceptualui.org/people/bulling/
+[4]: https://perceptualui.org/publications/bortoletto24_aaai.pdf
+[5]: https://github.com/kanishkg/bib-baselines
diff --git a/run_test.sh b/run_test.sh
new file mode 100644
index 0000000..ee0a135
--- /dev/null
+++ b/run_test.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+
+CUDA_VISIBLE_DEVICES=1 python test_tom.py \
+--model_type graphbcrnn \
+--types efficiency_irrational \
+--ckpt /projects/bortoletto/icml2023_matteo/wandb/run-20221224_135525-8i1r2aqy/files/bib/8i1r2aqy/checkpoints/epoch\=31-step\=22399.ckpt \
+--data_path /datasets/external/bib_evaluation_1_1/graphs/all_tasks \
+--process_data 0 \
+--surprise_type max
diff --git a/run_train.sh b/run_train.sh
new file mode 100644
index 0000000..7b1c9fb
--- /dev/null
+++ b/run_train.sh
@@ -0,0 +1,18 @@
+#!/bin/bash
+
+CUDA_VISIBLE_DEVICES=0 python train_tom.py \
+--model_type graphbcrnn \
+--types single_object preference instrumental_action \
+--data_path /datasets/external/bib_train/graphs/all_tasks/ \
+--seed 7 \
+--batch_size 32 \
+--max_epochs 35 \
+--gpus 1 \
+--auto_select_gpus True \
+--num_workers 2 \
+--stochastic_weight_avg True \
+--lr 5e-4 \
+--check_val_every_n_epoch 1 \
+--track_grad_norm 2 \
+--gradient_clip_val 10 \
+--gnn_type RSAGEv4 \
+--state_dim 96 \
+--aggregation sum
diff --git a/test_tom.py b/test_tom.py
new file mode 100644
index 0000000..d72a008
--- /dev/null
+++ b/test_tom.py
@@ -0,0 +1,122 @@
+from argparse import ArgumentParser
+import numpy as np
+from tqdm import tqdm
+
+import torch
+from torch.utils.data import DataLoader
+import torch.nn.functional as F
+import dgl
+
+from tom.dataset import TestToMnetDGLDataset, collate_function_seq_test
+from tom.model import GraphBC_T, GraphBCRNN
+
+
+def get_z_scores(total, total_expected, total_unexpected):
+    """Print z-scores of the expected and unexpected surprise ratings
+    relative to the pooled distribution of all ratings."""
+    mean = np.mean(total)
+    std = np.std(total)
+    print("Z-Score expected: ",
+          (np.mean(total_expected) - mean) / std)
+    print("Z-Score unexpected: ",
+          (np.mean(total_unexpected) - mean) / std)
+
+
+parser = ArgumentParser()
+
+parser.add_argument('--model_type', type=str, default='graphbcrnn')
+parser.add_argument('--ckpt', type=str, default=None, help='path to checkpoint')
+parser.add_argument('--data_path', type=str, default=None, help='path to the data')
+parser.add_argument('--process_data', type=int, default=0)
+parser.add_argument('--surprise_type', type=str, default='max',
+ help='surprise type: mean, max. This is used for comparing the plausibility scores of the two test episodes')
+parser.add_argument('--types', nargs='+', type=str,
+ default=[
+ 'preference', 'multi_agent', 'inaccessible_goal',
+ 'efficiency_irrational', 'efficiency_time','efficiency_path',
+ 'instrumental_no_barrier', 'instrumental_blocking_barrier', 'instrumental_inconsequential_barrier'
+ ],
+ help='types of tasks used for training / testing')
+parser.add_argument('--filename', type=str, default='')
+
+args = parser.parse_args()
+
+filename = args.filename
+
+if args.model_type == 'graphbct':
+ model = GraphBC_T.load_from_checkpoint(args.ckpt)
+elif args.model_type == 'graphbcrnn':
+ model = GraphBCRNN.load_from_checkpoint(args.ckpt)
+else:
+ raise ValueError('Unknown model type.')
+
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+model.to(device)
+model.eval()
+with torch.no_grad():
+ for t in args.types:
+ if args.model_type == 'graphbcrnn':
+ test_dataset = TestToMnetDGLDataset(
+ path=args.data_path,
+ task_type=t,
+ mode='test'
+ )
+ test_dataloader = DataLoader(
+ test_dataset,
+ batch_size=1,
+ num_workers=1,
+ pin_memory=True,
+ collate_fn=collate_function_seq_test,
+ shuffle=False
+ )
+ count = 0
+ total, total_expected, total_unexpected = [], [], []
+ pbar = tqdm(test_dataloader)
+ for j, batch in enumerate(pbar):
+ if args.model_type == 'graphbcrnn':
+ dem_expected_states, dem_expected_actions, dem_expected_lens, \
+ dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, \
+ query_expected_frames, target_expected_actions, \
+ query_unexpected_frames, target_unexpected_actions = batch
+ dem_expected_states = dem_expected_states.to(device)
+ dem_expected_actions = dem_expected_actions.to(device)
+ dem_unexpected_states = dem_unexpected_states.to(device)
+ dem_unexpected_actions = dem_unexpected_actions.to(device)
+ target_expected_actions = target_expected_actions.to(device)
+ target_unexpected_actions = target_unexpected_actions.to(device)
+ surprise_expected = []
+ query_expected_frames = dgl.unbatch(query_expected_frames)
+ for i in range(len(query_expected_frames)):
+ if args.model_type == 'graphbcrnn':
+ test_actions, test_actions_pred = model(
+ [dem_expected_states, dem_expected_actions, dem_expected_lens, query_expected_frames[i].to(device), target_expected_actions[:, i, :]]
+ )
+ loss = F.mse_loss(test_actions, test_actions_pred)
+ surprise_expected.append(loss.cpu().detach().numpy())
+ mean_expected_surprise = np.mean(surprise_expected)
+ max_expected_surprise = np.max(surprise_expected)
+
+ # calculate the plausibility scores for the unexpected episode
+ surprise_unexpected = []
+ query_unexpected_frames = dgl.unbatch(query_unexpected_frames)
+ for i in range(len(query_unexpected_frames)):
+ if args.model_type == 'graphbcrnn':
+ test_actions, test_actions_pred = model(
+ [dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, query_unexpected_frames[i].to(device), target_unexpected_actions[:, i, :]]
+ )
+ loss = F.mse_loss(test_actions, test_actions_pred)
+ surprise_unexpected.append(loss.cpu().detach().numpy())
+ mean_unexpected_surprise = np.mean(surprise_unexpected)
+ max_unexpected_surprise = np.max(surprise_unexpected)
+
+            # score 1 if the expected episode is less surprising, 0.5 on ties
+            correct_mean = (mean_expected_surprise < mean_unexpected_surprise) \
+                + 0.5 * (mean_expected_surprise == mean_unexpected_surprise)
+            correct_max = (max_expected_surprise < max_unexpected_surprise) \
+                + 0.5 * (max_expected_surprise == max_unexpected_surprise)
+ if args.surprise_type == 'max':
+ count += correct_max
+ elif args.surprise_type == 'mean':
+ count += correct_mean
+ pbar.set_postfix({'accuracy': count/(j+1.), 'type': t})
+
+ total_expected.append(max_expected_surprise)
+ total_unexpected.append(max_unexpected_surprise)
+ total.append(max_expected_surprise)
+ total.append(max_unexpected_surprise)
+ get_z_scores(total, total_expected, total_unexpected)
diff --git a/tom/__init__.py b/tom/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tom/dataset.py b/tom/dataset.py
new file mode 100644
index 0000000..d3bcc75
--- /dev/null
+++ b/tom/dataset.py
@@ -0,0 +1,310 @@
+import pickle as pkl
+import os
+import torch
+import torch.utils.data
+import torch.nn.functional as F
+import dgl
+import random
+from dgl.data import DGLDataset
+
+
+def collate_function_seq(batch):
+ #dem_frames = torch.stack([item[0] for item in batch])
+ dem_frames = dgl.batch([item[0] for item in batch])
+ dem_actions = torch.stack([item[1] for item in batch])
+ dem_lens = [item[2] for item in batch]
+ #query_frames = torch.stack([item[3] for item in batch])
+ query_frames = dgl.batch([item[3] for item in batch])
+ target_actions = torch.stack([item[4] for item in batch])
+ return [dem_frames, dem_actions, dem_lens, query_frames, target_actions]
+
+def collate_function_seq_test(batch):
+ dem_expected_states = dgl.batch([item[0] for item in batch][0])
+ dem_expected_actions = torch.stack([item[1] for item in batch][0]).unsqueeze(dim=0)
+ dem_expected_lens = [item[2] for item in batch]
+ #print(dem_expected_actions.size())
+ dem_unexpected_states = dgl.batch([item[3] for item in batch][0])
+ dem_unexpected_actions = torch.stack([item[4] for item in batch][0]).unsqueeze(dim=0)
+ dem_unexpected_lens = [item[5] for item in batch]
+ query_expected_frames = dgl.batch([item[6] for item in batch])
+ target_expected_actions = torch.stack([item[7] for item in batch])
+ #print(target_expected_actions.size())
+ query_unexpected_frames = dgl.batch([item[8] for item in batch])
+ target_unexpected_actions = torch.stack([item[9] for item in batch])
+ return [
+ dem_expected_states, dem_expected_actions, dem_expected_lens, \
+ dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, \
+ query_expected_frames, target_expected_actions, \
+ query_unexpected_frames, target_unexpected_actions
+ ]
+
+def collate_function_mental(batch):
+ dem_frames = dgl.batch([item[0] for item in batch])
+ dem_actions = torch.stack([item[1] for item in batch])
+ dem_lens = [item[2] for item in batch]
+ past_test_frames = dgl.batch([item[3] for item in batch])
+ past_test_actions = torch.stack([item[4] for item in batch])
+ past_test_len = [item[5] for item in batch]
+ query_frames = dgl.batch([item[6] for item in batch])
+ target_actions = torch.stack([item[7] for item in batch])
+ return [dem_frames, dem_actions, dem_lens, past_test_frames, past_test_actions, past_test_len, query_frames, target_actions]
+
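+# NOTE: these collate functions are meant to be passed to a torch DataLoader via
+# `collate_fn` (e.g. DataLoader(dataset, batch_size=32, collate_fn=collate_function_seq));
+# see tom/model.py and test_tom.py for the actual usage.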
+
+class ToMnetDGLDataset(DGLDataset):
+ """
+ Training dataset class.
+ """
+ def __init__(self, path, types=None, mode="train"):
+ self.path = path
+ self.types = types
+ self.mode = mode
+ print('Mode:', self.mode)
+
+ if self.mode == 'train':
+ if len(self.types) == 4:
+ self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/'
+ #self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_global/'
+ #self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_local/'
+ elif len(self.types) == 3:
+ self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper() + '/'
+ print(self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper())
+ elif len(self.types) == 2:
+ self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + '/'
+ print(self.types[0][0].upper() + self.types[1][0].upper())
+ elif len(self.types) == 1:
+ self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + '/'
+            else:
+                raise ValueError('Number of types must be between 1 and 4.')
+ elif self.mode == 'val':
+ assert len(self.types) == 1
+ self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/' + self.types[0] + '/'
+ #self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_global/' + self.types[0] + '/'
+ #self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_local/' + self.types[0] + '/'
+        else:
+            raise ValueError('Unknown mode: ' + self.mode)
+
+ def get_test(self, states, actions):
+ # now states is a batched graph -> unbatch it, take the len, pick one sub-graph
+ # randomly and select the corresponding action
+ frame_graphs = dgl.unbatch(states)
+ trial_len = len(frame_graphs)
+ query_idx = random.randint(0, trial_len - 1)
+ query_graph = frame_graphs[query_idx]
+ target_action = actions[query_idx]
+ return query_graph, target_action
+
+ def __getitem__(self, idx):
+ with open(self.path+str(idx)+'.pkl', 'rb') as f:
+ states, actions, lens, _ = pkl.load(f)
+ # shuffle
+ ziplist = list(zip(states, actions, lens))
+ random.shuffle(ziplist)
+ states, actions, lens = zip(*ziplist)
+ # convert tuples to lists
+ states, actions, lens = [*states], [*actions], [*lens]
+ # pick last element in the list as test and pick random frame
+ test_s, test_a = self.get_test(states[-1], actions[-1])
+ dem_s = states[:-1]
+ dem_a = actions[:-1]
+ dem_lens = lens[:-1]
+ dem_s = dgl.batch(dem_s)
+ dem_a = torch.stack(dem_a)
+ return dem_s, dem_a, dem_lens, test_s, test_a
+
+ def __len__(self):
+ return len(os.listdir(self.path))
+
+
+class TestToMnetDGLDataset(DGLDataset):
+ """
+ Testing dataset class.
+ """
+ def __init__(self, path, task_type=None, mode="test"):
+ self.path = path
+ self.type = task_type
+ self.mode = mode
+ print('Mode:', self.mode)
+
+ if self.mode == 'test':
+ self.path = self.path + '_dgl_hetero_nobound_4feats/' + self.type + '/'
+ #self.path = self.path + '_dgl_hetero_nobound_4feats_global/' + self.type + '/'
+ #self.path = self.path + '_dgl_hetero_nobound_4feats_local/' + self.type + '/'
+        else:
+            raise ValueError('Unknown mode: ' + self.mode)
+
+ def __getitem__(self, idx):
+ with open(self.path+str(idx)+'.pkl', 'rb') as f:
+ dem_expected_states, dem_expected_actions, dem_expected_lens, _, \
+ dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, _, \
+ query_expected_frames, target_expected_actions, \
+ query_unexpected_frames, target_unexpected_actions = pkl.load(f)
+ assert len(dem_expected_states) == 8
+ assert len(dem_expected_actions) == 8
+ assert len(dem_expected_lens) == 8
+ assert len(dem_unexpected_states) == 8
+ assert len(dem_unexpected_actions) == 8
+ assert len(dem_unexpected_lens) == 8
+ assert len(dgl.unbatch(query_expected_frames)) == target_expected_actions.size()[0]
+ assert len(dgl.unbatch(query_unexpected_frames)) == target_unexpected_actions.size()[0]
+ # ignore n_nodes
+ return dem_expected_states, dem_expected_actions, dem_expected_lens, \
+ dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, \
+ query_expected_frames, target_expected_actions, \
+ query_unexpected_frames, target_unexpected_actions
+
+ def __len__(self):
+ return len(os.listdir(self.path))
+
+
+class ToMnetDGLDatasetUndersample(DGLDataset):
+ """
+ Training dataset class for the behavior cloning mlp model.
+ """
+ def __init__(self, path, types=None, mode="train"):
+ self.path = path
+ self.types = types
+ self.mode = mode
+ print('Mode:', self.mode)
+
+ if self.mode == 'train':
+ if len(self.types) == 4:
+ self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/'
+ elif len(self.types) == 3:
+ self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper() + '/'
+ print(self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper())
+ elif len(self.types) == 2:
+ self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + '/'
+ print(self.types[0][0].upper() + self.types[1][0].upper())
+ elif len(self.types) == 1:
+ self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + '/'
+            else:
+                raise ValueError('Number of types must be between 1 and 4.')
+ elif self.mode == 'val':
+ assert len(self.types) == 1
+ self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/' + self.types[0] + '/'
+        else:
+            raise ValueError('Unknown mode: ' + self.mode)
+ print('Undersampled dataset!')
+
+ def get_test(self, states, actions):
+ # now states is a batched graph -> unbatch it, take the len, pick one sub-graph
+ # randomly and select the corresponding action
+ frame_graphs = dgl.unbatch(states)
+ trial_len = len(frame_graphs)
+ query_idx = random.randint(0, trial_len - 1)
+ query_graph = frame_graphs[query_idx]
+ target_action = actions[query_idx]
+ return query_graph, target_action
+
+ def __getitem__(self, idx):
+ idx = idx + 3175
+ with open(self.path+str(idx)+'.pkl', 'rb') as f:
+ states, actions, lens, _ = pkl.load(f)
+ # shuffle
+ ziplist = list(zip(states, actions, lens))
+ random.shuffle(ziplist)
+ states, actions, lens = zip(*ziplist)
+ # convert tuples to lists
+ states, actions, lens = [*states], [*actions], [*lens]
+ # pick last element in the list as test and pick random frame
+ test_s, test_a = self.get_test(states[-1], actions[-1])
+ dem_s = states[:-1]
+ dem_a = actions[:-1]
+ dem_lens = lens[:-1]
+ dem_s = dgl.batch(dem_s)
+ dem_a = torch.stack(dem_a)
+ return dem_s, dem_a, dem_lens, test_s, test_a
+
+ def __len__(self):
+ return len(os.listdir(self.path)) - 3175
+
+
+class ToMnetDGLDatasetMental(DGLDataset):
+ """
+ Training dataset class.
+ """
+ def __init__(self, path, types=None, mode="train"):
+ self.path = path
+ self.types = types
+ self.mode = mode
+ print('Mode:', self.mode)
+
+ if self.mode == 'train':
+ if len(self.types) == 4:
+ self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/'
+ elif len(self.types) == 3:
+ self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper() + '/'
+ print(self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper())
+ elif len(self.types) == 2:
+ self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + '/'
+ print(self.types[0][0].upper() + self.types[1][0].upper())
+ elif len(self.types) == 1:
+ self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + '/'
+            else:
+                raise ValueError('Number of types must be between 1 and 4.')
+ elif self.mode == 'val':
+ assert len(self.types) == 1
+ self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/' + self.types[0] + '/'
+        else:
+            raise ValueError('Unknown mode: ' + self.mode)
+
+ def get_test(self, states, actions):
+ """
+ return: past_test_graphs, past_test_actions, test_graph, test_action
+ """
+ frame_graphs = dgl.unbatch(states)
+ trial_len = len(frame_graphs)
+ query_idx = random.randint(0, trial_len - 1)
+ test_graph = frame_graphs[query_idx]
+ test_action = actions[query_idx]
+ if query_idx > 0:
+ #past_test_actions = F.pad(past_test_actions, (0,0,0,41-query_idx), 'constant', 0)
+ if query_idx == 1:
+ past_test_graphs = frame_graphs[0]
+ past_test_actions = actions[:query_idx]
+ past_test_actions = F.pad(past_test_actions, (0,0,0,41-query_idx), 'constant', 0)
+ return past_test_graphs, past_test_actions, query_idx, test_graph, test_action
+ else:
+ past_test_graphs = frame_graphs[:query_idx]
+ past_test_actions = actions[:query_idx]
+ past_test_graphs = dgl.batch(past_test_graphs)
+ past_test_actions = F.pad(past_test_actions, (0,0,0,41-query_idx), 'constant', 0)
+ return past_test_graphs, past_test_actions, query_idx, test_graph, test_action
+ else:
+ test_action_ = F.pad(test_action.unsqueeze(0), (0,0,0,41-1), 'constant', 0)
+ # NOTE: since there are no past frames, return the test frame and action and query_idx=1
+ # not sure what would be a better alternative
+ return test_graph, test_action_, 1, test_graph, test_action
+
+ def __getitem__(self, idx):
+ with open(self.path+str(idx)+'.pkl', 'rb') as f:
+ states, actions, lens, _ = pkl.load(f)
+ ziplist = list(zip(states, actions, lens))
+ random.shuffle(ziplist)
+ states, actions, lens = zip(*ziplist)
+ states, actions, lens = [*states], [*actions], [*lens]
+ past_test_s, past_test_a, past_test_len, test_s, test_a = self.get_test(states[-1], actions[-1])
+ dem_s = states[:-1]
+ dem_a = actions[:-1]
+ dem_lens = lens[:-1]
+ dem_s = dgl.batch(dem_s)
+ dem_a = torch.stack(dem_a)
+ return dem_s, dem_a, dem_lens, past_test_s, past_test_a, past_test_len, test_s, test_a
+
+ def __len__(self):
+ return len(os.listdir(self.path))
+
+# --------------------------------------------------------------------------------------------------------
+
+if __name__ == "__main__":
+
+ types = [
+ 'preference', 'multi_agent', 'inaccessible_goal',
+ 'efficiency_irrational', 'efficiency_time','efficiency_path',
+ 'instrumental_no_barrier', 'instrumental_blocking_barrier', 'instrumental_inconsequential_barrier'
+ ]
+
+ mental_dataset = ToMnetDGLDatasetMental(
+ path='/datasets/external/bib_train/graphs/all_tasks/',
+ types=['instrumental_action'],
+ mode='train'
+ )
+    dem_frames, dem_actions, dem_lens, past_test_frames, past_test_actions, past_test_len, test_frame, test_action = mental_dataset.__getitem__(99)
+ breakpoint()
\ No newline at end of file
diff --git a/tom/gnn.py b/tom/gnn.py
new file mode 100644
index 0000000..10923d4
--- /dev/null
+++ b/tom/gnn.py
@@ -0,0 +1,877 @@
+import dgl.nn.pytorch as dglnn
+import torch.nn as nn
+import torch.nn.functional as F
+import torch
+import dgl
+import sys
+import copy
+
+sys.path.append('/projects/bortoletto/irene/')
+from tom.norm import Norm
+
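+# NOTE: all encoders below expect a DGL heterograph whose nodes are of type 'obj'
+# and carry the node features 'type' (9-d one-hot), 'pos' (2-d coordinates),
+# 'color' (3-d RGB) and 'shape' (18-d one-hot), matching the embedding layers.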
+
+class RSAGEv4(nn.Module):
+ # multi-layer GNN for one single feature
+ def __init__(
+ self,
+ hidden_channels,
+ out_channels,
+ rel_names,
+ dropout,
+ n_layers,
+ activation=nn.ELU(),
+ ):
+ super().__init__()
+ self.embedding_type = nn.Linear(9, int(hidden_channels))
+ self.embedding_pos = nn.Linear(2, int(hidden_channels))
+ self.embedding_color = nn.Linear(3, int(hidden_channels))
+ self.embedding_shape = nn.Linear(18, int(hidden_channels))
+ self.combine_attr = nn.Sequential(
+ nn.ReLU(),
+ nn.Linear(hidden_channels*3, hidden_channels)
+ )
+ self.combine_pos = nn.Sequential(
+ nn.ReLU(),
+ nn.Linear(hidden_channels*2, hidden_channels)
+ )
+ self.layers = nn.ModuleList([
+ dglnn.HeteroGraphConv({
+ rel: dglnn.SAGEConv(
+ in_feats=hidden_channels,
+ out_feats=hidden_channels if l < n_layers - 1 else out_channels,
+ aggregator_type='lstm',
+ feat_drop=dropout,
+ activation=activation if l < n_layers - 1 else None
+ )
+ for rel in rel_names}, aggregate='sum')
+ for l in range(n_layers)
+ ])
+
+ def forward(self, g):
+ attr = []
+ attr.append(self.embedding_type(g.ndata['type'].float()))
+ pos = self.embedding_pos(g.ndata['pos']/170.)
+ attr.append(self.embedding_color(g.ndata['color']/255.))
+ attr.append(self.embedding_shape(g.ndata['shape'].float()))
+ combined_attr = self.combine_attr(torch.cat(attr, dim=1))
+ h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
+ for l, conv in enumerate(self.layers):
+ h = conv(g, h)
+ with g.local_scope():
+ g.ndata['h'] = h['obj']
+ out = dgl.mean_nodes(g, 'h')
+ return out
+
+# -------------------------------------------------------------------------------------------
+
+class RAGNNv4(nn.Module):
+ # multi-layer GNN for one single feature
+ def __init__(
+ self,
+ hidden_channels,
+ out_channels,
+ rel_names,
+ dropout,
+ n_layers,
+ activation=nn.ELU(),
+ ):
+ super().__init__()
+ self.embedding_type = nn.Linear(9, int(hidden_channels))
+ self.embedding_pos = nn.Linear(2, int(hidden_channels))
+ self.embedding_color = nn.Linear(3, int(hidden_channels))
+ self.embedding_shape = nn.Linear(18, int(hidden_channels))
+ self.combine_attr = nn.Sequential(
+ nn.ReLU(),
+ nn.Linear(hidden_channels*3, hidden_channels)
+ )
+ self.combine_pos = nn.Sequential(
+ nn.ReLU(),
+ nn.Linear(hidden_channels*2, hidden_channels)
+ )
+ self.layers = nn.ModuleList([
+ dglnn.HeteroGraphConv({
+ rel: dglnn.AGNNConv()
+ for rel in rel_names}, aggregate='sum')
+ for l in range(n_layers)
+ ])
+
+ def forward(self, g):
+ attr = []
+ attr.append(self.embedding_type(g.ndata['type'].float()))
+ pos = self.embedding_pos(g.ndata['pos']/170.)
+ attr.append(self.embedding_color(g.ndata['color']/255.))
+ attr.append(self.embedding_shape(g.ndata['shape'].float()))
+ combined_attr = self.combine_attr(torch.cat(attr, dim=1))
+ h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
+ for l, conv in enumerate(self.layers):
+ h = conv(g, h)
+ with g.local_scope():
+ g.ndata['h'] = h['obj']
+ out = dgl.mean_nodes(g, 'h')
+ return out
+
+# -------------------------------------------------------------------------------------------
+
+class RGATv2(nn.Module):
+ # multi-layer GNN for one single feature
+ def __init__(
+ self,
+ hidden_channels,
+ out_channels,
+ num_heads,
+ rel_names,
+ dropout,
+ n_layers,
+ activation=nn.ELU(),
+ residual=False
+ ):
+ super().__init__()
+ self.embedding_type = nn.Embedding(9, int(hidden_channels*num_heads/4))
+ self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads/4))
+ self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads/4))
+ self.embedding_shape = nn.Embedding(18, int(hidden_channels*num_heads/4))
+ self.combine = nn.Sequential(
+ nn.ReLU(),
+ nn.Linear(hidden_channels*num_heads, hidden_channels*num_heads)
+ )
+ self.layers = nn.ModuleList([
+ dglnn.HeteroGraphConv({
+ rel: dglnn.GATv2Conv(
+ in_feats=hidden_channels*num_heads,
+ out_feats=hidden_channels if l < n_layers - 1 else out_channels,
+ num_heads=num_heads if l < n_layers - 1 else 1,
+ feat_drop=dropout,
+ attn_drop=dropout,
+ residual=residual,
+ activation=activation if l < n_layers - 1 else None
+ )
+ for rel in rel_names}, aggregate='sum')
+ for l in range(n_layers)
+ ])
+
+ def forward(self, g):
+ feats = []
+ feats.append(self.embedding_type(torch.argmax(g.ndata['type'], dim=1)))
+ feats.append(self.embedding_pos(g.ndata['pos']/170.))
+ feats.append(self.embedding_color(g.ndata['color']/255.))
+ feats.append(self.embedding_shape(torch.argmax(g.ndata['shape'], dim=1)))
+ h = {'obj': self.combine(torch.cat(feats, dim=1))}
+ for l, conv in enumerate(self.layers):
+ h = conv(g, h)
+ if l != len(self.layers) - 1:
+ h = {k: v.flatten(1) for k, v in h.items()}
+ else:
+ h = {k: v.mean(1) for k, v in h.items()}
+ with g.local_scope():
+ g.ndata['h'] = h['obj']
+ out = dgl.mean_nodes(g, 'h')
+ return out
+
+# -------------------------------------------------------------------------------------------
+
+class RGATv3(nn.Module):
+ # multi-layer GNN for one single feature
+ def __init__(
+ self,
+ hidden_channels,
+ out_channels,
+ num_heads,
+ rel_names,
+ dropout,
+ n_layers,
+ activation=nn.ELU(),
+ residual=False
+ ):
+ super().__init__()
+ self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
+ self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
+ self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
+ self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
+ self.combine = nn.Sequential(
+ nn.ReLU(),
+ nn.Linear(hidden_channels*num_heads*4, hidden_channels*num_heads)
+ )
+ self.layers = nn.ModuleList([
+ dglnn.HeteroGraphConv({
+ rel: dglnn.GATv2Conv(
+ in_feats=hidden_channels*num_heads,
+ out_feats=hidden_channels if l < n_layers - 1 else out_channels,
+ num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always
+ feat_drop=dropout,
+ attn_drop=dropout,
+ residual=residual,
+ activation=activation if l < n_layers - 1 else None
+ )
+ for rel in rel_names}, aggregate='sum')
+ for l in range(n_layers)
+ ])
+
+ def forward(self, g):
+ feats = []
+ feats.append(self.embedding_type(g.ndata['type'].float()))
+ feats.append(self.embedding_pos(g.ndata['pos']/170.)) # NOTE: this should be 180 because I remove the boundary walls!
+ feats.append(self.embedding_color(g.ndata['color']/255.))
+ feats.append(self.embedding_shape(g.ndata['shape'].float()))
+ h = {'obj': self.combine(torch.cat(feats, dim=1))}
+ for l, conv in enumerate(self.layers):
+ h = conv(g, h)
+ if l != len(self.layers) - 1:
+ h = {k: v.flatten(1) for k, v in h.items()}
+ else:
+ h = {k: v.mean(1) for k, v in h.items()}
+ with g.local_scope():
+ g.ndata['h'] = h['obj']
+ out = dgl.mean_nodes(g, 'h')
+ return out
+
+# -------------------------------------------------------------------------------------------
+
+class RGCNv2(nn.Module):
+ # multi-layer GNN for one single feature
+ def __init__(
+ self,
+ hidden_channels,
+ out_channels,
+ rel_names,
+ dropout,
+ n_layers,
+ activation=nn.ReLU()
+ ):
+ super().__init__()
+ self.embedding_type = nn.Linear(9, int(hidden_channels))
+ self.embedding_pos = nn.Linear(2, int(hidden_channels))
+ self.embedding_color = nn.Linear(3, int(hidden_channels))
+ self.embedding_shape = nn.Linear(18, int(hidden_channels))
+ self.combine = nn.Sequential(
+ nn.ReLU(),
+ nn.Linear(hidden_channels*4, hidden_channels)
+ )
+ self.layers = nn.ModuleList([
+ dglnn.RelGraphConv(
+ in_feat=hidden_channels,
+ out_feat=hidden_channels,
+ num_rels=len(rel_names),
+ regularizer=None,
+ num_bases=None,
+ bias=True,
+ activation=activation,
+ self_loop=True,
+ dropout=dropout,
+ layer_norm=False
+ )
+ for _ in range(n_layers-1)])
+ self.layers.append(
+ dglnn.RelGraphConv(
+ in_feat=hidden_channels,
+ out_feat=out_channels,
+ num_rels=len(rel_names),
+ regularizer=None,
+ num_bases=None,
+ bias=True,
+ activation=activation,
+ self_loop=True,
+ dropout=dropout,
+ layer_norm=False
+ )
+ )
+
+ def forward(self, g):
+ g = g.to_homogeneous()
+ feats = []
+ feats.append(self.embedding_type(g.ndata['type'].float()))
+ feats.append(self.embedding_pos(g.ndata['pos']/170.))
+ feats.append(self.embedding_color(g.ndata['color']/255.))
+ feats.append(self.embedding_shape(g.ndata['shape'].float()))
+ h = self.combine(torch.cat(feats, dim=1))
+ for conv in self.layers:
+ h = conv(g, h, g.etypes)
+ with g.local_scope():
+ g.ndata['h'] = h
+ out = dgl.mean_nodes(g, 'h')
+ return out
+
+# -------------------------------------------------------------------------------------------
+
+class RGATv3Agent(nn.Module):
+ # multi-layer GNN for one single feature
+ def __init__(
+ self,
+ hidden_channels,
+ out_channels,
+ num_heads,
+ rel_names,
+ dropout,
+ n_layers,
+ activation=nn.ELU(),
+ residual=False
+ ):
+ super().__init__()
+ self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
+ self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
+ self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
+ self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
+ self.combine = nn.Sequential(
+ nn.ReLU(),
+ nn.Linear(hidden_channels*num_heads*4, hidden_channels*num_heads)
+ )
+ self.layers = nn.ModuleList([
+ dglnn.HeteroGraphConv({
+ rel: dglnn.GATv2Conv(
+ in_feats=hidden_channels*num_heads,
+ out_feats=hidden_channels if l < n_layers - 1 else out_channels,
+ num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always
+ feat_drop=dropout,
+ attn_drop=dropout,
+ residual=residual,
+ activation=activation if l < n_layers - 1 else None
+ )
+ for rel in rel_names}, aggregate='sum')
+ for l in range(n_layers)
+ ])
+
+ def forward(self, g):
+ agent_mask = g.ndata['type'][:, 0] == 1
+ feats = []
+ feats.append(self.embedding_type(g.ndata['type'].float()))
+ feats.append(self.embedding_pos(g.ndata['pos']/200.))
+ feats.append(self.embedding_color(g.ndata['color']/255.))
+ feats.append(self.embedding_shape(g.ndata['shape'].float()))
+ h = {'obj': self.combine(torch.cat(feats, dim=1))}
+ for l, conv in enumerate(self.layers):
+ h = conv(g, h)
+ if l != len(self.layers) - 1:
+ h = {k: v.flatten(1) for k, v in h.items()}
+ else:
+ h = {k: v.mean(1) for k, v in h.items()}
+ with g.local_scope():
+ g.ndata['h'] = h['obj']
+ out = g.ndata['h'][agent_mask, :]
+ ctx = dgl.mean_nodes(g, 'h')
+ return out + ctx
+
+# -------------------------------------------------------------------------------------------
+
+class RGATv4(nn.Module):
+ # multi-layer GNN for one single feature
+ def __init__(
+ self,
+ hidden_channels,
+ out_channels,
+ num_heads,
+ rel_names,
+ dropout,
+ n_layers,
+ activation=nn.ELU(),
+ residual=False
+ ):
+ super().__init__()
+ self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
+ self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
+ self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
+ self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
+ self.combine_attr = nn.Sequential(
+ nn.ReLU(),
+ nn.Linear(hidden_channels*num_heads*3, hidden_channels*num_heads)
+ )
+ self.combine_pos = nn.Sequential(
+ nn.ReLU(),
+ nn.Linear(hidden_channels*num_heads*2, hidden_channels*num_heads)
+ )
+ self.layers = nn.ModuleList([
+ dglnn.HeteroGraphConv({
+ rel: dglnn.GATv2Conv(
+ in_feats=hidden_channels*num_heads,
+ out_feats=hidden_channels if l < n_layers - 1 else out_channels,
+ num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always
+ feat_drop=dropout,
+ attn_drop=dropout,
+ residual=residual,
+ share_weights=False,
+ activation=activation if l < n_layers - 1 else None
+ )
+ for rel in rel_names}, aggregate='sum')
+ for l in range(n_layers)
+ ])
+
+ def forward(self, g):
+ attr = []
+ attr.append(self.embedding_type(g.ndata['type'].float()))
+ pos = self.embedding_pos(g.ndata['pos']/170.)
+ attr.append(self.embedding_color(g.ndata['color']/255.))
+ attr.append(self.embedding_shape(g.ndata['shape'].float()))
+ combined_attr = self.combine_attr(torch.cat(attr, dim=1))
+ h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
+ for l, conv in enumerate(self.layers):
+ h = conv(g, h)
+ if l != len(self.layers) - 1:
+ h = {k: v.flatten(1) for k, v in h.items()}
+ else:
+ h = {k: v.mean(1) for k, v in h.items()}
+ with g.local_scope():
+ g.ndata['h'] = h['obj']
+ out = dgl.mean_nodes(g, 'h')
+ return out
+
+# -------------------------------------------------------------------------------------------
+
+class RGATv4Norm(nn.Module):
+ # multi-layer GNN for one single feature
+ def __init__(
+ self,
+ hidden_channels,
+ out_channels,
+ num_heads,
+ rel_names,
+ dropout,
+ n_layers,
+ activation=nn.ELU(),
+ residual=False
+ ):
+ super().__init__()
+ self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
+ self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
+ self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
+ self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
+ self.combine_attr = nn.Sequential(
+ nn.ReLU(),
+ nn.Linear(hidden_channels*num_heads*3, hidden_channels*num_heads)
+ )
+ self.combine_pos = nn.Sequential(
+ nn.ReLU(),
+ nn.Linear(hidden_channels*num_heads*2, hidden_channels*num_heads)
+ )
+ self.layers = nn.ModuleList([
+ dglnn.HeteroGraphConv({
+ rel: dglnn.GATv2Conv(
+ in_feats=hidden_channels*num_heads,
+ out_feats=hidden_channels if l < n_layers - 1 else out_channels,
+ num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always
+ feat_drop=dropout,
+ attn_drop=dropout,
+ residual=residual,
+ activation=activation if l < n_layers - 1 else None
+ )
+ for rel in rel_names}, aggregate='sum')
+ for l in range(n_layers)
+ ])
+ self.norms = nn.ModuleList([
+ Norm(
+ norm_type='gn',
+ hidden_dim=hidden_channels*num_heads if l < n_layers - 1 else out_channels
+ )
+ for l in range(n_layers)
+ ])
+
+ def forward(self, g):
+ attr = []
+ attr.append(self.embedding_type(g.ndata['type'].float()))
+ pos = self.embedding_pos(g.ndata['pos']/170.)
+ attr.append(self.embedding_color(g.ndata['color']/255.))
+ attr.append(self.embedding_shape(g.ndata['shape'].float()))
+ combined_attr = self.combine_attr(torch.cat(attr, dim=1))
+ h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
+ for l, conv in enumerate(self.layers):
+ h = conv(g, h)
+ if l != len(self.layers) - 1:
+ h = {k: v.flatten(1) for k, v in h.items()}
+ else:
+ h = {k: v.mean(1) for k, v in h.items()}
+ h = {k: self.norms[l](g, v) for k, v in h.items()}
+ with g.local_scope():
+ g.ndata['h'] = h['obj']
+ out = dgl.mean_nodes(g, 'h')
+ return out
+
+# -------------------------------------------------------------------------------------------
+
+class RGATv3Norm(nn.Module):
+ # multi-layer GNN for one single feature
+ def __init__(
+ self,
+ hidden_channels,
+ out_channels,
+ num_heads,
+ rel_names,
+ dropout,
+ n_layers,
+ activation=nn.ELU(),
+ residual=False
+ ):
+ super().__init__()
+ self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
+ self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
+ self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
+ self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
+ self.combine = nn.Sequential(
+ nn.ReLU(),
+ nn.Linear(hidden_channels*num_heads*4, hidden_channels*num_heads)
+ )
+ self.layers = nn.ModuleList([
+ dglnn.HeteroGraphConv({
+ rel: dglnn.GATv2Conv(
+ in_feats=hidden_channels*num_heads,
+ out_feats=hidden_channels if l < n_layers - 1 else out_channels,
+ num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always
+ feat_drop=dropout,
+ attn_drop=dropout,
+ residual=residual,
+ activation=activation if l < n_layers - 1 else None
+ )
+ for rel in rel_names}, aggregate='sum')
+ for l in range(n_layers)
+ ])
+ self.norms = nn.ModuleList([
+ Norm(
+ norm_type='gn',
+ hidden_dim=hidden_channels*num_heads if l < n_layers - 1 else out_channels
+ )
+ for l in range(n_layers)
+ ])
+
+ def forward(self, g):
+ feats = []
+ feats.append(self.embedding_type(g.ndata['type'].float()))
+ feats.append(self.embedding_pos(g.ndata['pos']/170.))
+ feats.append(self.embedding_color(g.ndata['color']/255.))
+ feats.append(self.embedding_shape(g.ndata['shape'].float()))
+ h = {'obj': self.combine(torch.cat(feats, dim=1))}
+ for l, conv in enumerate(self.layers):
+ h = conv(g, h)
+ if l != len(self.layers) - 1:
+ h = {k: v.flatten(1) for k, v in h.items()}
+ else:
+ h = {k: v.mean(1) for k, v in h.items()}
+ h = {k: self.norms[l](g, v) for k, v in h.items()}
+ with g.local_scope():
+ g.ndata['h'] = h['obj']
+ out = dgl.mean_nodes(g, 'h')
+ return out
+
+# -------------------------------------------------------------------------------------------
+
+class RGATv4Agent(nn.Module):
+ # multi-layer GNN for one single feature
+ def __init__(
+ self,
+ hidden_channels,
+ out_channels,
+ num_heads,
+ rel_names,
+ dropout,
+ n_layers,
+ activation=nn.ELU(),
+ residual=False
+ ):
+ super().__init__()
+ self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
+ self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
+ self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
+ self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
+ self.combine_attr = nn.Sequential(
+ nn.ReLU(),
+ nn.Linear(hidden_channels*num_heads*3, hidden_channels*num_heads)
+ )
+ self.combine_pos = nn.Sequential(
+ nn.ReLU(),
+ nn.Linear(hidden_channels*num_heads*2, hidden_channels*num_heads)
+ )
+ self.layers = nn.ModuleList([
+ dglnn.HeteroGraphConv({
+ rel: dglnn.GATv2Conv(
+ in_feats=hidden_channels*num_heads,
+ out_feats=hidden_channels if l < n_layers - 1 else out_channels,
+ num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always
+ feat_drop=dropout,
+ attn_drop=dropout,
+ residual=residual,
+ activation=activation if l < n_layers - 1 else None
+ )
+ for rel in rel_names}, aggregate='sum')
+ for l in range(n_layers)
+ ])
+ self.combine_agent_context = nn.Linear(out_channels*2, out_channels)
+
+ def forward(self, g):
+ agent_mask = g.ndata['type'][:, 0] == 1
+ attr = []
+ attr.append(self.embedding_type(g.ndata['type'].float()))
+ pos = self.embedding_pos(g.ndata['pos']/170.)
+ attr.append(self.embedding_color(g.ndata['color']/255.))
+ attr.append(self.embedding_shape(g.ndata['shape'].float()))
+ combined_attr = self.combine_attr(torch.cat(attr, dim=1))
+ h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
+ for l, conv in enumerate(self.layers):
+ h = conv(g, h)
+ if l != len(self.layers) - 1:
+ h = {k: v.flatten(1) for k, v in h.items()}
+ else:
+ h = {k: v.mean(1) for k, v in h.items()}
+ with g.local_scope():
+ g.ndata['h'] = h['obj']
+ h_a = g.ndata['h'][agent_mask, :]
+ g_no_agent = copy.deepcopy(g)
+ g_no_agent.remove_nodes([i for i, x in enumerate(agent_mask) if x])
+ h_g = dgl.mean_nodes(g_no_agent, 'h')
+ out = self.combine_agent_context(torch.cat((h_a, h_g), dim=1))
+ return out
+
+# -------------------------------------------------------------------------------------------
+
+class RGCNv4(nn.Module):
+ # multi-layer GNN for one single feature
+ def __init__(
+ self,
+ hidden_channels,
+ out_channels,
+ rel_names,
+ n_layers,
+ activation=nn.ELU(),
+ ):
+ super().__init__()
+ self.embedding_type = nn.Linear(9, int(hidden_channels))
+ self.embedding_pos = nn.Linear(2, int(hidden_channels))
+ self.embedding_color = nn.Linear(3, int(hidden_channels))
+ self.embedding_shape = nn.Linear(18, int(hidden_channels))
+ self.combine_attr = nn.Sequential(
+ nn.ReLU(),
+ nn.Linear(hidden_channels*3, hidden_channels)
+ )
+ self.combine_pos = nn.Sequential(
+ nn.ReLU(),
+ nn.Linear(hidden_channels*2, hidden_channels)
+ )
+ self.layers = nn.ModuleList([
+ dglnn.HeteroGraphConv({
+ rel: dglnn.GraphConv(
+ in_feats=hidden_channels,
+ out_feats=hidden_channels if l < n_layers - 1 else out_channels,
+ activation=activation if l < n_layers - 1 else None
+ )
+ for rel in rel_names}, aggregate='sum')
+ for l in range(n_layers)
+ ])
+
+ def forward(self, g):
+ attr = []
+ attr.append(self.embedding_type(g.ndata['type'].float()))
+ pos = self.embedding_pos(g.ndata['pos']/170.)
+ attr.append(self.embedding_color(g.ndata['color']/255.))
+ attr.append(self.embedding_shape(g.ndata['shape'].float()))
+ combined_attr = self.combine_attr(torch.cat(attr, dim=1))
+ h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
+ for l, conv in enumerate(self.layers):
+ h = conv(g, h)
+ with g.local_scope():
+ g.ndata['h'] = h['obj']
+ out = dgl.mean_nodes(g, 'h')
+ return out
+
+# -------------------------------------------------------------------------------------------
+
+class RGATv5(nn.Module):
+
+ def __init__(
+ self,
+ hidden_channels,
+ out_channels,
+ num_heads,
+ rel_names,
+ dropout,
+ n_layers,
+ activation=nn.ELU(),
+ residual=False
+ ):
+ super().__init__()
+ self.hidden_channels = hidden_channels
+ self.num_heads = num_heads
+ self.embedding_type = nn.Linear(9, hidden_channels*num_heads)
+ self.embedding_pos = nn.Linear(2, hidden_channels*num_heads)
+ self.embedding_color = nn.Linear(3, hidden_channels*num_heads)
+ self.embedding_shape = nn.Linear(18, hidden_channels*num_heads)
+ self.combine = nn.Sequential(
+ nn.ReLU(),
+ nn.Linear(hidden_channels*num_heads*4, hidden_channels*num_heads)
+ )
+ self.attention = nn.Linear(hidden_channels*num_heads*4, 4)
+ self.layers = nn.ModuleList([
+ dglnn.HeteroGraphConv({
+ rel: dglnn.GATv2Conv(
+ in_feats=hidden_channels*num_heads,
+ out_feats=hidden_channels if l < n_layers - 1 else out_channels,
+ num_heads=num_heads,
+ feat_drop=dropout,
+ attn_drop=dropout,
+ residual=residual,
+ activation=activation if l < n_layers - 1 else None
+ )
+ for rel in rel_names}, aggregate='sum')
+ for l in range(n_layers)
+ ])
+
+ def forward(self, g):
+ feats = []
+ feats.append(self.embedding_type(g.ndata['type'].float()))
+ feats.append(self.embedding_pos(g.ndata['pos']/170.))
+ feats.append(self.embedding_color(g.ndata['color']/255.))
+ feats.append(self.embedding_shape(g.ndata['shape'].float()))
+ h = torch.cat(feats, dim=1)
+ feat_attn = F.softmax(self.attention(h), dim=1)
+ h = h * feat_attn.repeat_interleave(self.hidden_channels*self.num_heads, dim=1)
+ h_in = self.combine(h)
+ h = {'obj': h_in}
+ for conv in self.layers:
+ h = conv(g, h)
+ h = {k: v.flatten(1) for k, v in h.items()}
+ #if l != len(self.layers) - 1:
+ # h = {k: v.flatten(1) for k, v in h.items()}
+ #else:
+ # h = {k: v.mean(1) for k, v in h.items()}
+ h = {k: v + h_in for k, v in h.items()}
+ with g.local_scope():
+ g.ndata['h'] = h['obj']
+ out = dgl.mean_nodes(g, 'h')
+ return out
+
+# -------------------------------------------------------------------------------------------
+
+class RGATv6(nn.Module):
+
+ # RGATv6 = RGATv4 + Global Attention Pooling
+
+ def __init__(
+ self,
+ hidden_channels,
+ out_channels,
+ num_heads,
+ rel_names,
+ dropout,
+ n_layers,
+ activation=nn.ELU(),
+ residual=False
+ ):
+ super().__init__()
+ self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
+ self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
+ self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
+ self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
+ self.combine_attr = nn.Sequential(
+ nn.ReLU(),
+ nn.Linear(hidden_channels*num_heads*3, hidden_channels*num_heads)
+ )
+ self.combine_pos = nn.Sequential(
+ nn.ReLU(),
+ nn.Linear(hidden_channels*num_heads*2, hidden_channels*num_heads)
+ )
+ self.layers = nn.ModuleList([
+ dglnn.HeteroGraphConv({
+ rel: dglnn.GATv2Conv(
+ in_feats=hidden_channels*num_heads,
+ out_feats=hidden_channels if l < n_layers - 1 else out_channels,
+ num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always
+ feat_drop=dropout,
+ attn_drop=dropout,
+ residual=residual,
+ activation=activation if l < n_layers - 1 else None
+ )
+ for rel in rel_names}, aggregate='sum')
+ for l in range(n_layers)
+ ])
+ gate_nn = nn.Linear(out_channels, 1)
+ self.gap = dglnn.GlobalAttentionPooling(gate_nn)
+
+ def forward(self, g):
+ attr = []
+ attr.append(self.embedding_type(g.ndata['type'].float()))
+ pos = self.embedding_pos(g.ndata['pos']/170.)
+ attr.append(self.embedding_color(g.ndata['color']/255.))
+ attr.append(self.embedding_shape(g.ndata['shape'].float()))
+ combined_attr = self.combine_attr(torch.cat(attr, dim=1))
+ h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
+ for l, conv in enumerate(self.layers):
+ h = conv(g, h)
+ if l != len(self.layers) - 1:
+ h = {k: v.flatten(1) for k, v in h.items()}
+ else:
+ h = {k: v.mean(1) for k, v in h.items()}
+ #with g.local_scope():
+ #g.ndata['h'] = h['obj']
+ #out = dgl.mean_nodes(g, 'h')
+ out = self.gap(g, h['obj'])
+ return out
+
+# -------------------------------------------------------------------------------------------
+
+class RGATv6Agent(nn.Module):
+
+ # RGATv6 = RGATv4 + Global Attention Pooling
+
+ def __init__(
+ self,
+ hidden_channels,
+ out_channels,
+ num_heads,
+ rel_names,
+ dropout,
+ n_layers,
+ activation=nn.ELU(),
+ residual=False
+ ):
+ super().__init__()
+ self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
+ self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
+ self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
+ self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
+ self.combine_attr = nn.Sequential(
+ nn.ReLU(),
+ nn.Linear(hidden_channels*num_heads*3, hidden_channels*num_heads)
+ )
+ self.combine_pos = nn.Sequential(
+ nn.ReLU(),
+ nn.Linear(hidden_channels*num_heads*2, hidden_channels*num_heads)
+ )
+ self.layers = nn.ModuleList([
+ dglnn.HeteroGraphConv({
+ rel: dglnn.GATv2Conv(
+ in_feats=hidden_channels*num_heads,
+ out_feats=hidden_channels if l < n_layers - 1 else out_channels,
+ num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always
+ feat_drop=dropout,
+ attn_drop=dropout,
+ residual=residual,
+ activation=activation if l < n_layers - 1 else None
+ )
+ for rel in rel_names}, aggregate='sum')
+ for l in range(n_layers)
+ ])
+ gate_nn = nn.Linear(out_channels, 1)
+ self.gap = dglnn.GlobalAttentionPooling(gate_nn)
+ self.combine_agent_context = nn.Linear(out_channels*2, out_channels)
+
+ def forward(self, g):
+ agent_mask = g.ndata['type'][:, 0] == 1
+ attr = []
+ attr.append(self.embedding_type(g.ndata['type'].float()))
+ pos = self.embedding_pos(g.ndata['pos']/170.)
+ attr.append(self.embedding_color(g.ndata['color']/255.))
+ attr.append(self.embedding_shape(g.ndata['shape'].float()))
+ combined_attr = self.combine_attr(torch.cat(attr, dim=1))
+ h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
+ for l, conv in enumerate(self.layers):
+ h = conv(g, h)
+ if l != len(self.layers) - 1:
+ h = {k: v.flatten(1) for k, v in h.items()}
+ else:
+ h = {k: v.mean(1) for k, v in h.items()}
+ with g.local_scope():
+ g.ndata['h'] = h['obj']
+ h_a = g.ndata['h'][agent_mask, :]
+ h_g = g.ndata['h'][~agent_mask, :]
+ g_no_agent = copy.deepcopy(g)
+ g_no_agent.remove_nodes([i for i, x in enumerate(agent_mask) if x])
+ h_g = self.gap(g_no_agent, h_g) # dgl.mean_nodes(g_no_agent, 'h')
+ out = self.combine_agent_context(torch.cat((h_a, h_g), dim=1))
+ return out
+
+# -------------------------------------------------------------------------------------------
+
diff --git a/tom/model.py b/tom/model.py
new file mode 100644
index 0000000..14e0bb5
--- /dev/null
+++ b/tom/model.py
@@ -0,0 +1,513 @@
+from argparse import ArgumentParser
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import pytorch_lightning as pl
+from pytorch_lightning.callbacks import ModelCheckpoint
+from torch.utils.data import DataLoader
+
+from tom.dataset import *
+from tom.transformer import TransformerEncoder
+from tom.gnn import RGATv2, RGATv3, RGATv3Agent, RGATv4, RGATv4Norm, RSAGEv4, RAGNNv4
+
+
+class MlpModel(nn.Module):
+ """Multilayer Perceptron with last layer linear.
+ Args:
+ input_size (int): number of inputs
+ hidden_sizes (list): can be empty list for none (linear model).
+ output_size: linear layer at output, or if ``None``, the last hidden size
+ will be the output size and will have nonlinearity applied
+ nonlinearity: torch nonlinearity Module (not Functional).
+ """
+
+ def __init__(
+ self,
+ input_size,
+ hidden_sizes, # Can be empty list or None for none.
+ output_size=None, # if None, last layer has nonlinearity applied.
+ nonlinearity=nn.ReLU, # Module, not Functional.
+ dropout=None # Dropout value
+ ):
+ super().__init__()
+ if isinstance(hidden_sizes, int):
+ hidden_sizes = [hidden_sizes]
+ elif hidden_sizes is None:
+ hidden_sizes = []
+ hidden_layers = [nn.Linear(n_in, n_out) for n_in, n_out in
+ zip([input_size] + hidden_sizes[:-1], hidden_sizes)]
+ sequence = list()
+ for i, layer in enumerate(hidden_layers):
+ if dropout is not None:
+ sequence.extend([layer, nonlinearity(), nn.Dropout(dropout)])
+ else:
+ sequence.extend([layer, nonlinearity()])
+
+ if output_size is not None:
+ last_size = hidden_sizes[-1] if hidden_sizes else input_size
+ sequence.append(torch.nn.Linear(last_size, output_size))
+ self.model = nn.Sequential(*sequence)
+ self._output_size = (hidden_sizes[-1] if output_size is None
+ else output_size)
+
+ def forward(self, input):
+ """Compute the model on the input, assuming input shape [B,input_size]."""
+ return self.model(input)
+
+ @property
+ def output_size(self):
+ """Retuns the output size of the model."""
+ return self._output_size
+
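+# Example (illustrative): the policy head in GraphBCRNN below is constructed as
+# MlpModel(input_size=state_dim + 2*context_dim, hidden_sizes=[256, 128, 256],
+# output_size=action_dim), i.e. an MLP with three hidden layers and a linear output.
+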
+# ---------------------------------------------------------------------------------------------------------------------------------
+
+class GraphBCRNN(pl.LightningModule):
+ """
+    Implementation of the baseline BC-RNN model.
+    A relational GNN and an LSTM are used to encode the familiarization trials.
+ """
+ @staticmethod
+ def add_model_specific_args(parent_parser):
+ parser = ArgumentParser(parents=[parent_parser], add_help=False)
+ parser.add_argument('--lr', type=float, default=1e-4)
+ parser.add_argument('--action_dim', type=int, default=2)
+ parser.add_argument('--context_dim', type=int, default=32) # lstm hidden size
+ parser.add_argument('--beta', type=float, default=0.01)
+ parser.add_argument('--dropout', type=float, default=0.2)
+ parser.add_argument('--process_data', type=int, default=0)
+ parser.add_argument('--max_len', type=int, default=30)
+ # arguments for gnn
+ parser.add_argument('--gnn_type', type=str, default='RGATv4')
+ parser.add_argument('--state_dim', type=int, default=128) # gnn out_feats
+ parser.add_argument('--feats_dims', type=list, default=[9, 2, 3, 18])
+ parser.add_argument('--aggregation', type=str, default='sum')
+        # arguments for mlp
+        #parser.add_argument('--mlp_hid_feats', type=list, default=[256, 64, 16])
+ return parser
+
+ def __init__(self, hparams):
+ super().__init__()
+
+ self.hparams = hparams
+ self.lr = self.hparams.lr
+ self.state_dim = self.hparams.state_dim
+ self.action_dim = self.hparams.action_dim
+ self.context_dim = self.hparams.context_dim
+ self.beta = self.hparams.beta
+ self.dropout = self.hparams.dropout
+ self.max_len = self.hparams.max_len
+ self.feats_dims = self.hparams.feats_dims # type, position, color, shape
+ self.rel_names = [
+ 'is_aligned', 'is_back', 'is_close', 'is_down_adj', 'is_down_left_adj',
+ 'is_down_right_adj', 'is_front', 'is_left', 'is_left_adj', 'is_right',
+ 'is_right_adj', 'is_top_adj', 'is_top_left_adj', 'is_top_right_adj'
+ ]
+ self.gnn_aggregation = self.hparams.aggregation
+
+ if self.hparams.gnn_type == 'RGATv2':
+ self.gnn_encoder = RGATv2(
+ hidden_channels=self.state_dim,
+ out_channels=self.state_dim,
+ num_heads=4,
+ rel_names=self.rel_names,
+ dropout=0.0,
+ n_layers=2,
+ activation=nn.ELU(),
+ residual=False
+ )
+ if self.hparams.gnn_type == 'RGATv3':
+ self.gnn_encoder = RGATv3(
+ hidden_channels=self.state_dim,
+ out_channels=self.state_dim,
+ num_heads=4,
+ rel_names=self.rel_names,
+ dropout=0.0,
+ n_layers=2,
+ activation=nn.ELU(),
+ residual=False
+ )
+ if self.hparams.gnn_type == 'RGATv4':
+ self.gnn_encoder = RGATv4(
+ hidden_channels=self.state_dim,
+ out_channels=self.state_dim,
+ num_heads=4,
+ rel_names=self.rel_names,
+ dropout=0.0,
+ n_layers=2,
+ activation=nn.ELU(),
+ residual=False
+ )
+ if self.hparams.gnn_type == 'RGATv3Agent':
+ self.gnn_encoder = RGATv3Agent(
+ hidden_channels=self.state_dim,
+ out_channels=self.state_dim,
+ num_heads=4,
+ rel_names=self.rel_names,
+ dropout=0.0,
+ n_layers=2,
+ activation=nn.ELU(),
+ residual=False
+ )
+ if self.hparams.gnn_type == 'RSAGEv4':
+ self.gnn_encoder = RSAGEv4(
+ hidden_channels=self.state_dim,
+ out_channels=self.state_dim,
+ rel_names=self.rel_names,
+ dropout=0.0,
+ n_layers=2,
+ activation=nn.ELU()
+ )
+ if self.gnn_aggregation == 'cat_axis_1':
+ self.lstm_input_size = self.state_dim * len(self.feats_dims) + self.action_dim
+ self.mlp_input_size = self.state_dim * len(self.feats_dims) + self.context_dim * 2
+ elif self.gnn_aggregation == 'sum':
+ self.lstm_input_size = self.state_dim + self.action_dim
+ self.mlp_input_size = self.state_dim + self.context_dim * 2
+ else:
+            raise ValueError('Unknown aggregation: ' + self.gnn_aggregation)
+
+ self.context_enc = nn.LSTM(self.lstm_input_size, self.context_dim, 2,
+ batch_first=True, bidirectional=True)
+
+ self.policy = MlpModel(input_size=self.mlp_input_size, hidden_sizes=[256, 128, 256],
+ output_size=self.action_dim, dropout=self.dropout)
+
+ self.past_samples = []
+
+ def forward(self, batch):
+ dem_frames, dem_actions, dem_lens, query_frame, target_action = batch
+ dem_actions = dem_actions.float()
+ target_action = target_action.float()
+ dem_states = self.gnn_encoder(dem_frames) # torch.Size([number of frames, 128 * number of features if cat_axis_1])
+ # segment according the number of frames in each episode and pad with zeros
+ # to obtain tensors of shape [batch size, num of trials (8), max num of frames (30), hidden dim]
+ b, l, s, _ = dem_actions.size()
+ dem_states_new = []
+ for batch in range(b):
+ dem_states_new.append(self._sequence_to_padding(dem_states, dem_lens[batch], self.max_len))
+ dem_states_new = torch.stack(dem_states_new).to(self.device) # torch.Size([batchsize, 8, 30, 128 * number of features if cat_axis_1])
+ # concatenate states and actions to get expert trajectory
+ dem_states_new = dem_states_new.view(b * l, s, -1) # torch.Size([batchsize*8, 30, 128 * number of features if cat_axis_1])
+        dem_actions = dem_actions.view(b * l, s, -1) # torch.Size([batchsize*8, 30, 2])
+ dem_traj = torch.cat([dem_states_new, dem_actions], dim=2) # torch.Size([batchsize*8, 30, 2 + 128 * number of features if cat_axis_1])
+ # embed expert trajectory to get a context embedding batch x samples x dim
+ dem_lens = torch.tensor([t for dl in dem_lens for t in dl]).to(torch.int64).cpu()
+ dem_lens = dem_lens.view(-1)
+ x_lstm = nn.utils.rnn.pack_padded_sequence(dem_traj, dem_lens, batch_first=True, enforce_sorted=False)
+ x_lstm, _ = self.context_enc(x_lstm)
+ x_lstm, _ = nn.utils.rnn.pad_packed_sequence(x_lstm, batch_first=True)
+ x_out = x_lstm[:, -1, :]
+ x_out = x_out.view(b, l, -1)
+ context = torch.mean(x_out, dim=1) # torch.Size([batchsize, 64]) (64=32*2)
+ # concat context embedding to the state embedding of test trajectory
+ test_states = self.gnn_encoder(query_frame) # torch.Size([2, 128])
+ test_context_states = torch.cat([context, test_states], dim=1) # torch.Size([batchsize, 192]) 192=64+128
+ # for each state in the test states calculate action
+ test_actions_pred = torch.tanh(self.policy(test_context_states)) # torch.Size([batchsize, 2])
+ return target_action, test_actions_pred
+
+ def _sequence_to_padding(self, x, lengths, max_length):
+        # declare the output shape; this works for x of any trailing shape
+ ret_tensor = torch.zeros((len(lengths), max_length) + tuple(x.shape[1:]))
+ cum_len = 0
+ for i, l in enumerate(lengths):
+ ret_tensor[i, :l] = x[cum_len: cum_len+l]
+ cum_len += l
+ return ret_tensor
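+    # Example (hypothetical numbers): with lengths=[2, 3], max_length=4 and
+    # x of shape [5, 128], rows 0-1 of x fill ret_tensor[0, :2], rows 2-4
+    # fill ret_tensor[1, :3], and everything else stays zero-padded,
+    # giving a [2, 4, 128] output.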
+
+ def training_step(self, batch, batch_idx):
+ test_actions, test_actions_pred = self.forward(batch)
+ loss = F.mse_loss(test_actions, test_actions_pred)
+ self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
+ return loss
+
+ def validation_step(self, batch, batch_idx, dataloader_idx=0):
+ test_actions, test_actions_pred = self.forward(batch)
+ loss = F.mse_loss(test_actions, test_actions_pred)
+ self.log('val_loss', loss, on_epoch=True, logger=True)
+
+ def configure_optimizers(self):
+ optim = torch.optim.Adam(self.parameters(), lr=self.lr)
+ return optim
+ #optim = torch.optim.AdamW(self.parameters(), lr=self.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01)
+ #scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optim, gamma=0.8)
+ #return [optim], [scheduler]
+
+ def train_dataloader(self):
+ train_dataset = ToMnetDGLDataset(path=self.hparams.data_path,
+ types=self.hparams.types,
+ mode='train')
+ train_loader = DataLoader(dataset=train_dataset,
+ batch_size=self.hparams.batch_size,
+ collate_fn=collate_function_seq,
+ num_workers=self.hparams.num_workers,
+ #pin_memory=True,
+ shuffle=True)
+ return train_loader
+
+ def val_dataloader(self):
+ val_datasets = []
+ val_loaders = []
+ for t in self.hparams.types:
+ val_datasets.append(ToMnetDGLDataset(path=self.hparams.data_path,
+ types=[t],
+ mode='val'))
+ val_loaders.append(DataLoader(dataset=val_datasets[-1],
+ batch_size=self.hparams.batch_size,
+ collate_fn=collate_function_seq,
+ num_workers=self.hparams.num_workers,
+ #pin_memory=True,
+ shuffle=False))
+ return val_loaders
+
+ def configure_callbacks(self):
+ checkpoint = ModelCheckpoint(
+ dirpath=None, # automatically set
+ #filename=self.params['bc_model']+'-'+self.params['gnn_type']+'-'+self.gnn_params['feats_aggr']+'-{epoch:02d}',
+ save_top_k=-1,
+ period=1
+ )
+ return [checkpoint]
+
+# ---------------------------------------------------------------------------------------------------------------------------------
+
+class GraphBC_T(pl.LightningModule):
+ """
+ BC model with GraphTrans encoder, LSTM and MLP.
+ """
+ @staticmethod
+ def add_model_specific_args(parent_parser):
+ parser = ArgumentParser(parents=[parent_parser], add_help=False)
+ parser.add_argument('--lr', type=float, default=1e-4)
+ parser.add_argument('--action_dim', type=int, default=2)
+ parser.add_argument('--context_dim', type=int, default=32) # lstm hidden size
+ parser.add_argument('--beta', type=float, default=0.01)
+ parser.add_argument('--dropout', type=float, default=0.2)
+ parser.add_argument('--process_data', type=int, default=0)
+ parser.add_argument('--max_len', type=int, default=30)
+ # arguments for gnn
+ parser.add_argument('--state_dim', type=int, default=128) # gnn out_feats
+ parser.add_argument('--feats_dims', type=list, default=[9, 2, 3, 18])
+ parser.add_argument('--aggregation', type=str, default='cat_axis_1')
+ parser.add_argument('--gnn_type', type=str, default='RGATv3')
+        # arguments for mlp
+ #parser.add_argument('--mpl_hid_feats', type=list, default=[256, 64, 16])
+ # arguments for transformer
+ parser.add_argument('--d_model', type=int, default=128)
+ parser.add_argument('--nhead', type=int, default=4)
+ parser.add_argument('--dim_feedforward', type=int, default=512)
+ parser.add_argument('--transformer_dropout', type=float, default=0.3)
+ parser.add_argument('--transformer_activation', type=str, default='gelu')
+ parser.add_argument('--num_encoder_layers', type=int, default=6)
+ parser.add_argument('--transformer_norm_input', type=int, default=0)
+ return parser
+
+ def __init__(self, hparams):
+ super().__init__()
+
+ self.hparams = hparams
+ self.lr = self.hparams.lr
+ self.state_dim = self.hparams.state_dim
+ self.action_dim = self.hparams.action_dim
+ self.context_dim = self.hparams.context_dim
+ self.beta = self.hparams.beta
+ self.dropout = self.hparams.dropout
+ self.max_len = self.hparams.max_len
+ self.feats_dims = self.hparams.feats_dims # type, position, color, shape
+
+ self.rel_names = [
+ 'is_aligned', 'is_back', 'is_close', 'is_down_adj', 'is_down_left_adj',
+ 'is_down_right_adj', 'is_front', 'is_left', 'is_left_adj', 'is_right',
+ 'is_right_adj', 'is_top_adj', 'is_top_left_adj', 'is_top_right_adj'
+ ]
+ self.gnn_aggregation = self.hparams.aggregation
+ if self.hparams.gnn_type == 'RGATv3':
+ self.gnn_encoder = RGATv3(
+ hidden_channels=self.state_dim,
+ out_channels=self.state_dim,
+ num_heads=4,
+ rel_names=self.rel_names,
+ dropout=0.0,
+ n_layers=2,
+ activation=nn.ELU(),
+ residual=False
+ )
+ if self.hparams.gnn_type == 'RGATv4':
+ self.gnn_encoder = RGATv4(
+ hidden_channels=self.state_dim,
+ out_channels=self.state_dim,
+ num_heads=4,
+ rel_names=self.rel_names,
+ dropout=0.0,
+ n_layers=2,
+ activation=nn.ELU(),
+ residual=False
+ )
+ if self.hparams.gnn_type == 'RSAGEv4':
+ self.gnn_encoder = RSAGEv4(
+ hidden_channels=self.state_dim,
+ out_channels=self.state_dim,
+ rel_names=self.rel_names,
+ dropout=0.0,
+ n_layers=2,
+ activation=nn.ELU()
+ )
+ if self.hparams.gnn_type == 'RAGNNv4':
+ self.gnn_encoder = RAGNNv4(
+ hidden_channels=self.state_dim,
+ out_channels=self.state_dim,
+ rel_names=self.rel_names,
+ dropout=0.0,
+ n_layers=2,
+ activation=nn.ELU()
+ )
+ if self.hparams.gnn_type == 'RGATv4Norm':
+ self.gnn_encoder = RGATv4Norm(
+ hidden_channels=self.state_dim,
+ out_channels=self.state_dim,
+ num_heads=4,
+ rel_names=self.rel_names,
+ dropout=0.0,
+ n_layers=2,
+ activation=nn.ELU(),
+ residual=True
+ )
+ self.d_model = self.hparams.d_model
+ self.nhead = self.hparams.nhead
+ self.dim_feedforward = self.hparams.dim_feedforward
+ self.transformer_dropout = self.hparams.transformer_dropout
+ self.transformer_activation = self.hparams.transformer_activation
+ self.num_encoder_layers = self.hparams.num_encoder_layers
+ self.transformer_norm_input = self.hparams.transformer_norm_input
+ self.context_enc = TransformerEncoder(
+ self.d_model,
+ self.nhead,
+ self.dim_feedforward,
+ self.transformer_dropout,
+ self.transformer_activation,
+ self.num_encoder_layers,
+ self.max_len,
+ self.transformer_norm_input
+ )
+
+ if self.gnn_aggregation == 'cat_axis_1':
+ self.gnn2transformer = nn.Linear(self.state_dim * len(self.feats_dims) + self.action_dim, self.d_model)
+ self.mlp_input_size = len(self.feats_dims) * self.state_dim + self.d_model
+ elif self.gnn_aggregation == 'sum':
+ self.gnn2transformer = nn.Linear(self.state_dim + self.action_dim, self.d_model)
+ self.mlp_input_size = self.state_dim + self.d_model
+ else:
+            raise ValueError("Only 'sum' and 'cat_axis_1' aggregations are available.")
+
+ self.policy = MlpModel(input_size=self.mlp_input_size, hidden_sizes=[256, 128, 256],
+ output_size=self.action_dim, dropout=self.dropout)
+
+ # CLS Embedding parameters, requires_grad=True
+        self.embedding = nn.Embedding(self.max_len + 1, self.d_model) # +1 for the CLS token
+ self.emb_layer_norm = nn.LayerNorm(self.d_model)
+ self.emb_dropout = nn.Dropout(p=self.transformer_dropout)
+
+ def forward(self, batch):
+ dem_frames, dem_actions, dem_lens, query_frame, target_action = batch
+ dem_actions = dem_actions.float()
+ target_action = target_action.float()
+ dem_states = self.gnn_encoder(dem_frames)
+ b, l, s, _ = dem_actions.size()
+ dem_lens = torch.tensor([t for dl in dem_lens for t in dl]).to(torch.int64).cpu()
+ dem_lens = dem_lens.view(-1)
+ dem_actions_packed = torch.nn.utils.rnn.pack_padded_sequence(dem_actions.view(b*l, s, -1), dem_lens, batch_first=True, enforce_sorted=False)[0]
+ dem_traj = torch.cat([dem_states, dem_actions_packed], dim=1)
+ h_node = self.gnn2transformer(dem_traj)
+ hidden_dim = h_node.size()[1]
+ padded_trajectory = torch.zeros(b*l, self.max_len, hidden_dim).to(self.device)
+ j = 0
+ for idx, i in enumerate(dem_lens):
+ padded_trajectory[idx][:i] = h_node[j:j+i]
+ j += i
+ mask = self.make_mask(padded_trajectory).to(self.device)
+ transformer_input = padded_trajectory.transpose(0, 1) # [30, 16, 128]
+ # add cls:
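+        # NB: the CLS embedding below is re-sampled at every forward pass;
+        # it is not registered on the module, so it is not a learned parameter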
+ cls_embedding = nn.Parameter(torch.randn([1, 1, self.d_model], requires_grad=True)).expand(1, b*l, -1).to(self.device)
+ transformer_input = torch.cat([transformer_input, cls_embedding], dim=0) # [31, 16, 128]
+ zeros = mask.data.new(mask.size(0), 1).fill_(0)
+ mask = torch.cat([mask, zeros], dim=1)
+ # Embed
+        indices = torch.arange(self.max_len + 1, dtype=torch.int).to(self.device) # +1 for the CLS token: [0, 1, ..., 30]
+        positional_embeddings = self.embedding(indices).unsqueeze(1) # torch.Size([31, 1, 128])
+        # generate transformer input
+ pe_input = positional_embeddings + transformer_input # torch.Size([31, 16, 128])
+ # Layernorm and dropout
+ transformer_in = self.emb_dropout(self.emb_layer_norm(pe_input)) # torch.Size([31, 16, 128])
+ # transformer encoding and output parsing
+ out, _ = self.context_enc(transformer_in, mask) # [31, 16, 128]
+ cls = out[-1]
+ cls = cls.view(b, l, -1) # 2, 8, 128
+ context = torch.mean(cls, dim=1)
+ # CLASSIFICATION
+ test_states = self.gnn_encoder(query_frame) # torch.Size([2, 512])
+        test_context_states = torch.cat([context, test_states], dim=1) # torch.Size([batchsize, d_model + gnn state dim])
+ # for each state in the test states calculate action
+ x = self.policy(test_context_states)
+ test_actions_pred = torch.tanh(x) # torch.Size([batchsize, 2])
+ #test_actions_pred = torch.tanh(self.policy(test_context_states)) # torch.Size([batchsize, 2])
+ return target_action, test_actions_pred
+
+ def make_mask(self, feature):
+ return (torch.sum(
+ torch.abs(feature),
+ dim=-1
+ ) == 0)#.unsqueeze(1).unsqueeze(2)
+
+ def training_step(self, batch, batch_idx):
+ test_actions, test_actions_pred = self.forward(batch)
+ loss = F.mse_loss(test_actions, test_actions_pred)
+ self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
+ return loss
+
+ def validation_step(self, batch, batch_idx, dataloader_idx=0):
+ test_actions, test_actions_pred = self.forward(batch)
+ loss = F.mse_loss(test_actions, test_actions_pred)
+ self.log('val_loss', loss, on_epoch=True, logger=True)
+
+ def configure_optimizers(self):
+ optim = torch.optim.Adam(self.parameters(), lr=self.lr)
+ return optim
+ #optim = torch.optim.AdamW(self.parameters(), lr=self.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01)
+ #scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optim, gamma=0.96)
+ #return [optim], [scheduler]
+
+ def train_dataloader(self):
+ train_dataset = ToMnetDGLDataset(path=self.hparams.data_path,
+ types=self.hparams.types,
+ mode='train')
+ train_loader = DataLoader(dataset=train_dataset,
+ batch_size=self.hparams.batch_size,
+ collate_fn=collate_function_seq,
+ num_workers=self.hparams.num_workers,
+ #pin_memory=True,
+ shuffle=True)
+ return train_loader
+
+ def val_dataloader(self):
+ val_datasets = []
+ val_loaders = []
+ for t in self.hparams.types:
+ val_datasets.append(ToMnetDGLDataset(path=self.hparams.data_path,
+ types=[t],
+ mode='val'))
+ val_loaders.append(DataLoader(dataset=val_datasets[-1],
+ batch_size=self.hparams.batch_size,
+ collate_fn=collate_function_seq,
+ num_workers=self.hparams.num_workers,
+ #pin_memory=True,
+ shuffle=False))
+ return val_loaders
+
+ def configure_callbacks(self):
+ checkpoint = ModelCheckpoint(
+ dirpath=None, # automatically set
+ #filename=self.params['bc_model']+'-'+self.params['gnn_type']+'-'+self.gnn_params['feats_aggr']+'-{epoch:02d}',
+ save_top_k=-1,
+ period=1
+ )
+ return [checkpoint]
\ No newline at end of file
diff --git a/tom/norm.py b/tom/norm.py
new file mode 100644
index 0000000..818ecf1
--- /dev/null
+++ b/tom/norm.py
@@ -0,0 +1,46 @@
+import torch
+import torch.nn as nn
+
+
+class Norm(nn.Module):
+
+ def __init__(self, norm_type, hidden_dim=64, print_info=None):
+ super(Norm, self).__init__()
+
+ # assert norm_type in ['bn', 'ln', 'gn', None]
+ self.norm = None
+ self.print_info = print_info
+ if norm_type == 'bn':
+ self.norm = nn.BatchNorm1d(hidden_dim)
+ elif norm_type == 'gn':
+ self.norm = norm_type
+ self.weight = nn.Parameter(torch.ones(hidden_dim))
+ self.bias = nn.Parameter(torch.zeros(hidden_dim))
+
+ self.mean_scale = nn.Parameter(torch.ones(hidden_dim))
+
+ def forward(self, graph, tensor, print_=False):
+
+ if self.norm is not None and type(self.norm) != str:
+ return self.norm(tensor)
+ elif self.norm is None:
+ return tensor
+
+ batch_list = graph.batch_num_nodes('obj')
+ batch_size = len(batch_list)
+ #batch_list = torch.tensor(batch_list).long().to(tensor.device)
+ batch_list = batch_list.long().to(tensor.device)
+ batch_index = torch.arange(batch_size).to(tensor.device).repeat_interleave(batch_list)
+ batch_index = batch_index.view((-1,) + (1,) * (tensor.dim() - 1)).expand_as(tensor)
+ mean = torch.zeros(batch_size, *tensor.shape[1:]).to(tensor.device)
+ mean = mean.scatter_add_(0, batch_index, tensor)
+ mean = (mean.T / batch_list).T
+ mean = mean.repeat_interleave(batch_list, dim=0)
+
+ sub = tensor - mean * self.mean_scale
+
+ std = torch.zeros(batch_size, *tensor.shape[1:]).to(tensor.device)
+ std = std.scatter_add_(0, batch_index, sub.pow(2))
+ std = ((std.T / batch_list).T + 1e-6).sqrt()
+ std = std.repeat_interleave(batch_list, dim=0)
+ return self.weight * sub / std + self.bias
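+
+
+if __name__ == '__main__':
+    # Minimal smoke test (a sketch, not used in training): 'gn' applies the
+    # graph-wise normalisation above to a batched DGL graph with node type
+    # 'obj', matching the graphs built in utils/dataset.py.
+    import dgl
+    g = dgl.heterograph({('obj', 'is_close', 'obj'): ([0, 1], [1, 0])})
+    norm = Norm(norm_type='gn', hidden_dim=64)
+    h = torch.randn(4, 64)  # 2 graphs x 2 nodes, 64-dim features
+    print(norm(dgl.batch([g, g]), h).shape)  # torch.Size([4, 64])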
\ No newline at end of file
diff --git a/tom/transformer.py b/tom/transformer.py
new file mode 100644
index 0000000..0544c9e
--- /dev/null
+++ b/tom/transformer.py
@@ -0,0 +1,89 @@
+import torch
+import torch.nn as nn
+import math
+
+
+class PositionalEncoding(nn.Module):
+
+ def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
+ super().__init__()
+
+ self.dropout = nn.Dropout(p=dropout)
+
+ position = torch.arange(max_len).unsqueeze(1)
+ div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
+ pe = torch.zeros(max_len, 1, d_model)
+ pe[:, 0, 0::2] = torch.sin(position * div_term)
+ pe[:, 0, 1::2] = torch.cos(position * div_term)
+ self.register_buffer('pe', pe)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x: Tensor, shape [seq_len, batch_size, embedding_dim]
+ """
+ x = x + self.pe[:x.size(0)]
+ return self.dropout(x)
+
+
+class TransformerEncoder(nn.Module):
+
+ def __init__(
+ self,
+ d_model,
+ nhead,
+ dim_feedforward,
+ transformer_dropout,
+ transformer_activation,
+ num_encoder_layers,
+ max_input_len,
+ transformer_norm_input
+ ):
+ super().__init__()
+ self.d_model = d_model
+ self.num_layer = num_encoder_layers
+ self.max_input_len = max_input_len
+
+ # Creating Transformer Encoder Model
+ encoder_layer = nn.TransformerEncoderLayer(
+ d_model=d_model,
+ nhead=nhead,
+ dim_feedforward=dim_feedforward,
+ dropout=transformer_dropout,
+ activation=transformer_activation
+ )
+ encoder_norm = nn.LayerNorm(d_model)
+ self.transformer = nn.TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
+
+
+ self.norm_input = None
+ if transformer_norm_input:
+ self.norm_input = nn.LayerNorm(d_model)
+
+ def forward(self, padded_h_node, src_padding_mask):
+ """
+        padded_h_node: (S, B, h_d) node embeddings, e.g. (63, 257, 128)
+        src_padding_mask: (B, S) boolean padding mask, e.g. (257, 63)
+        """
+ if self.norm_input is not None:
+ padded_h_node = self.norm_input(padded_h_node)
+
+ transformer_out = self.transformer(padded_h_node, src_key_padding_mask=src_padding_mask) # (S, B, h_d)
+
+ return transformer_out, src_padding_mask
+
+
+
+if __name__ == '__main__':
+ model = TransformerEncoder(
+ d_model=12,
+ nhead=4,
+ dim_feedforward=32,
+ transformer_dropout=0.0,
+ transformer_activation='gelu',
+ num_encoder_layers=4,
+ max_input_len=34,
+ transformer_norm_input=0
+ )
+ print(model.norm_input)
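+    # Sketch of a forward pass: 5 padded sequences of max_input_len tokens.
+    h = torch.randn(34, 5, 12)                       # (S, B, d_model)
+    pad_mask = torch.zeros(5, 34, dtype=torch.bool)  # (B, S), no padding
+    out, _ = model(h, pad_mask)
+    print(out.shape)  # torch.Size([34, 5, 12])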
\ No newline at end of file
diff --git a/train_tom.py b/train_tom.py
new file mode 100644
index 0000000..9ce2934
--- /dev/null
+++ b/train_tom.py
@@ -0,0 +1,75 @@
+import random
+from argparse import ArgumentParser
+import numpy as np
+import torch
+from pytorch_lightning import Trainer
+from pytorch_lightning.loggers import WandbLogger
+from tom.model import GraphBC_T, GraphBCRNN
+
+torch.multiprocessing.set_sharing_strategy('file_system')
+
+parser = ArgumentParser()
+
+# program level args
+parser.add_argument('--seed', type=int, default=4)
+# data specific args
+parser.add_argument('--data_path', type=str, default='/datasets/external/bib_train/graphs/all_tasks/')
+parser.add_argument('--types', nargs='+', type=str,
+ default=['preference', 'multi_agent', 'single_object', 'instrumental_action'],
+ help='types of tasks used for training / validation')
+parser.add_argument('--train', type=int, default=1)
+parser.add_argument('--num_workers', type=int, default=4)
+parser.add_argument('--batch_size', type=int, default=16)
+parser.add_argument('--model_type', type=str, default='graphbcrnn')
+
+# model specific args
+parser_model = ArgumentParser()
+parser_model = GraphBC_T.add_model_specific_args(parser_model)
+# parser_model = GraphBCRNN.add_model_specific_args(parser_model)
+# NOTE: unfortunately the model has to be selected manually here by (un)commenting the lines above
+
+# add all the available trainer options to argparse
+parser = Trainer.add_argparse_args(parser)
+
+# combine parsers
+parser_all = ArgumentParser(conflict_handler='resolve',
+ parents=[parser, parser_model])
+
+# parse args
+args = parser_all.parse_args()
+args.types = sorted(args.types)
+print(args)
+
+random.seed(args.seed)
+np.random.seed(args.seed)
+torch.manual_seed(args.seed)
+
+# init model
+if args.model_type == 'graphbct':
+ model = GraphBC_T(args)
+elif args.model_type == 'graphbcrnn':
+ model = GraphBCRNN(args)
+else:
+ raise NotImplementedError
+
+torch.autograd.set_detect_anomaly(True)
+
+logger = WandbLogger(project='bib')
+trainer = Trainer(
+ gradient_clip_val=args.gradient_clip_val,
+ gpus=args.gpus,
+ auto_select_gpus=args.auto_select_gpus,
+ track_grad_norm=args.track_grad_norm,
+ check_val_every_n_epoch=args.check_val_every_n_epoch,
+ max_epochs=args.max_epochs,
+ accelerator=args.accelerator,
+ resume_from_checkpoint=args.resume_from_checkpoint,
+ stochastic_weight_avg=args.stochastic_weight_avg,
+ num_sanity_val_steps=args.num_sanity_val_steps,
+ logger=logger
+)
+
+if args.train:
+ trainer.fit(model)
+else:
+ raise NotImplementedError
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/utils/build_graphs.py b/utils/build_graphs.py
new file mode 100644
index 0000000..5af8fab
--- /dev/null
+++ b/utils/build_graphs.py
@@ -0,0 +1,115 @@
+import sys
+sys.path.append('/projects/bortoletto/icml2023_matteo/utils')
+from dataset import TransitionDataset, TestTransitionDatasetSequence
+import multiprocessing as mp
+import argparse
+import pickle as pkl
+import os
+
+# Instantiate the parser
+parser = argparse.ArgumentParser()
+parser.add_argument('--cpus', type=int,
+ help='Number of processes')
+parser.add_argument('--mode', type=str,
+ help='Train (train) or validation (val)')
+args = parser.parse_args()
+
+NUM_PROCESSES = args.cpus
+MODE = args.mode
+
+def generate_files(idx):
+ print('Generating idx', idx)
+ if os.path.exists(PATH+str(idx)+'.pkl'):
+ print('Index', idx, 'skipped.')
+ return
+ if MODE == 'train' or MODE == 'val':
+ states, actions, lens, n_nodes = dataset.__getitem__(idx)
+ with open(PATH+str(idx)+'.pkl', 'wb') as f:
+ pkl.dump([states, actions, lens, n_nodes], f)
+ elif MODE == 'test':
+ dem_expected_states, dem_expected_actions, dem_expected_lens, dem_expected_nodes, \
+ dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, dem_unexpected_nodes, \
+ query_expected_frames, target_expected_actions, \
+ query_unexpected_frames, target_unexpected_actions = dataset.__getitem__(idx)
+ with open(PATH+str(idx)+'.pkl', 'wb') as f:
+ pkl.dump([
+ dem_expected_states, dem_expected_actions, dem_expected_lens, dem_expected_nodes, \
+ dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, dem_unexpected_nodes, \
+ query_expected_frames, target_expected_actions, \
+ query_unexpected_frames, target_unexpected_actions], f
+ )
+ else:
+ raise ValueError('MODE can be only train, val or test.')
+ print(PATH+str(idx)+'.pkl saved.')
+
+if __name__ == "__main__":
+ if MODE == 'train':
+ print('TRAIN MODE')
+ PATH = '/datasets/external/bib_train/graphs/all_tasks/train_dgl_hetero_nobound_4feats/'
+ if not os.path.exists(PATH):
+ os.makedirs(PATH)
+ print(PATH, 'directory created.')
+ dataset = TransitionDataset(
+ path='/datasets/external/bib_train/',
+ types=['instrumental_action', 'multi_agent', 'preference', 'single_object'],
+ mode="train",
+ max_len=30,
+ num_test=1,
+ num_trials=9,
+ action_range=10,
+ process_data=0
+ )
+ pool = mp.Pool(processes=NUM_PROCESSES)
+ print('Starting graph generation with', NUM_PROCESSES, 'processes...')
+        pool.map(generate_files, range(len(dataset)))
+ pool.close()
+ elif MODE == 'val':
+ print('VALIDATION MODE')
+ types = ['multi_agent', 'instrumental_action', 'preference', 'single_object']
+ for t in range(len(types)):
+ PATH = '/datasets/external/bib_train/graphs/all_tasks/val_dgl_hetero_nobound_4feats/'+types[t]+'/'
+ if not os.path.exists(PATH):
+ os.makedirs(PATH)
+ print(PATH, 'directory created.')
+ dataset = TransitionDataset(
+ path='/datasets/external/bib_train/',
+ types=[types[t]],
+ mode="val",
+ max_len=30,
+ num_test=1,
+ num_trials=9,
+ action_range=10,
+ process_data=0
+ )
+ pool = mp.Pool(processes=NUM_PROCESSES)
+ print('Starting', types[t], 'graph generation with', NUM_PROCESSES, 'processes...')
+            pool.map(generate_files, range(len(dataset)))
+ pool.close()
+ elif MODE == 'test':
+ print('TEST MODE')
+ types = [
+ 'preference', 'multi_agent', 'inaccessible_goal',
+ 'efficiency_irrational', 'efficiency_time','efficiency_path',
+ 'instrumental_no_barrier', 'instrumental_blocking_barrier', 'instrumental_inconsequential_barrier'
+ ]
+ for t in range(len(types)):
+ PATH = '/datasets/external/bib_evaluation_1_1/graphs/all_tasks_dgl_hetero_nobound_4feats/'+types[t]+'/'
+ if not os.path.exists(PATH):
+ os.makedirs(PATH)
+ print(PATH, 'directory created.')
+ dataset = TestTransitionDatasetSequence(
+ path='/datasets/external/bib_evaluation_1_1/',
+ task_type=types[t],
+ mode="test",
+ num_test=1,
+ num_trials=9,
+ action_range=10,
+ process_data=0,
+ max_len=30
+ )
+ pool = mp.Pool(processes=NUM_PROCESSES)
+ print('Starting', types[t], 'graph generation with', NUM_PROCESSES, 'processes...')
+            pool.map(generate_files, range(len(dataset)))
+ pool.close()
+ else:
+        raise ValueError('MODE can be only train, val or test.')
diff --git a/utils/dataset.py b/utils/dataset.py
new file mode 100644
index 0000000..061879d
--- /dev/null
+++ b/utils/dataset.py
@@ -0,0 +1,487 @@
+import dgl
+import torch
+import torch.utils.data
+import os
+import pickle as pkl
+import json
+import numpy as np
+from tqdm import tqdm
+
+import sys
+sys.path.append('/projects/bortoletto/irene/')
+from utils.grid_object import *
+from utils.relations import *
+
+# ========================== Helper functions ==========================
+
+def index_data(json_list, path_list):
+    print(f'processing {len(json_list)} files')
+ data_tuples = []
+ for j, v in tqdm(zip(json_list, path_list)):
+ with open(j, 'r') as f:
+ state = json.load(f)
+ ep_lens = [len(x) for x in state]
+ past_len = 0
+ for e, l in enumerate(ep_lens):
+ data_tuples.append([])
+ # skip first 30 frames and last 83 frames
+ for f in range(30, l - 83):
+                # action = halved agent displacement between consecutive frames
+ f0x, f0y = state[e][f]['agent'][0]
+ f1x, f1y = state[e][f + 1]['agent'][0]
+ dx = (f1x - f0x) / 2.
+ dy = (f1y - f0y) / 2.
+ action = [dx, dy]
+ #data_tuples[-1].append((v, past_len + f, action))
+ data_tuples[-1].append((j, past_len + f, action))
+ # data_tuples = [[json file, frame number, action] for each episode in each video]
+ assert len(data_tuples[-1]) > 0
+ past_len += l
+ return data_tuples
+
+# ========================== Dataset class ==========================
+
+class TransitionDataset(torch.utils.data.Dataset):
+ """
+    Training dataset class for the graph-based behavior cloning models.
+ Args:
+ path: path to the dataset
+ types: list of video types to include
+ mode: train, val
+ num_test: number of test state-action pairs
+ num_trials: number of trials in an episode
+        action_range: number of frames to skip; actions are pooled (agent displacement) over these frames
+        process_data: whether to index the videos or not (skip if already indexed)
+        max_len: max number of context state-action pairs
+    __getitem__:
+        returns: (states, actions, lens, n_nodes)
+        states: one batched DGLGraph.heterograph per trial
+        actions: one padded (max_len, 2) action tensor per trial
+        lens: number of sampled frames in each trial
+        n_nodes: number of nodes in each frame graph
+    """
+ def __init__(
+ self,
+ path,
+ types=None,
+ mode="train",
+ num_test=1,
+ num_trials=9,
+ action_range=10,
+ process_data=0,
+ max_len=30
+ ):
+ self.path = path
+ self.types = types
+ self.mode = mode
+ self.num_trials = num_trials
+ self.num_test = num_test
+ self.action_range = action_range
+ self.max_len = max_len
+ self.ep_combs = self.num_trials * (self.num_trials - 2) # 9p2 - 9
+ self.eps = [[x, y] for x in range(self.num_trials) for y in range(self.num_trials) if x != y]
+ types_str = '_'.join(self.types)
+ self.rel_deter_func = [
+ is_top_adj, is_left_adj, is_top_right_adj, is_top_left_adj,
+ is_down_adj, is_right_adj, is_down_left_adj, is_down_right_adj,
+ is_left, is_right, is_front, is_back, is_aligned, is_close
+ ]
+ self.path_list = []
+ self.json_list = []
+ # get video paths and json file paths
+ for t in types:
+ print(f'reading files of type {t} in {mode}')
+ paths = [os.path.join(self.path, t, x) for x in os.listdir(os.path.join(self.path, t)) if
+ x.endswith(f'.mp4')]
+ jsons = [os.path.join(self.path, t, x) for x in os.listdir(os.path.join(self.path, t)) if
+ x.endswith(f'.json') and 'index' not in x]
+ paths = sorted(paths)
+ jsons = sorted(jsons)
+ if mode == 'train':
+ self.path_list += paths[:int(0.8 * len(jsons))]
+ self.json_list += jsons[:int(0.8 * len(jsons))]
+ elif mode == 'val':
+ self.path_list += paths[int(0.8 * len(jsons)):]
+ self.json_list += jsons[int(0.8 * len(jsons)):]
+ else:
+ self.path_list += paths
+ self.json_list += jsons
+ self.data_tuples = []
+ if process_data:
+ # index the videos in the dataset directory. This is done to speed up the retrieval of videos.
+ # frame index, action tuples are stored
+ self.data_tuples = index_data(self.json_list, self.path_list)
+ # tuples of frame index and action (displacement of agent)
+ index_dict = {'data_tuples': self.data_tuples}
+ with open(os.path.join(self.path, f'jindex_bib_{mode}_{types_str}.json'), 'w') as fp:
+ json.dump(index_dict, fp)
+ else:
+ # read pre-indexed data
+ with open(os.path.join(self.path, f'jindex_bib_{mode}_{types_str}.json'), 'r') as fp:
+ index_dict = json.load(fp)
+ self.data_tuples = index_dict['data_tuples']
+ self.tot_trials = len(self.path_list) * 9
+
+ def _get_frame_graph(self, jsonfile, frame_idx):
+ # load json
+ with open(jsonfile, 'rb') as f:
+ frame_data = json.load(f)
+ flat_list = [x for xs in frame_data for x in xs]
+ # extract entities
+ grid_objs = parse_objects(flat_list[frame_idx])
+ # --- build the graph
+ adj = self._get_spatial_rel(grid_objs)
+ # define edges
+ is_top_adj_src, is_top_adj_dst = np.nonzero(adj[0])
+ is_left_adj_src, is_left_adj_dst = np.nonzero(adj[1])
+ is_top_right_adj_src, is_top_right_adj_dst = np.nonzero(adj[2])
+ is_top_left_adj_src, is_top_left_adj_dst = np.nonzero(adj[3])
+ is_down_adj_src, is_down_adj_dst = np.nonzero(adj[4])
+ is_right_adj_src, is_right_adj_dst = np.nonzero(adj[5])
+ is_down_left_adj_src, is_down_left_adj_dst = np.nonzero(adj[6])
+ is_down_right_adj_src, is_down_right_adj_dst = np.nonzero(adj[7])
+ is_left_src, is_left_dst = np.nonzero(adj[8])
+ is_right_src, is_right_dst = np.nonzero(adj[9])
+ is_front_src, is_front_dst = np.nonzero(adj[10])
+ is_back_src, is_back_dst = np.nonzero(adj[11])
+ is_aligned_src, is_aligned_dst = np.nonzero(adj[12])
+ is_close_src, is_close_dst = np.nonzero(adj[13])
+ g = dgl.heterograph({
+ ('obj', 'is_top_adj', 'obj'): (torch.tensor(is_top_adj_src), torch.tensor(is_top_adj_dst)),
+ ('obj', 'is_left_adj', 'obj'): (torch.tensor(is_left_adj_src), torch.tensor(is_left_adj_dst)),
+ ('obj', 'is_top_right_adj', 'obj'): (torch.tensor(is_top_right_adj_src), torch.tensor(is_top_right_adj_dst)),
+ ('obj', 'is_top_left_adj', 'obj'): (torch.tensor(is_top_left_adj_src), torch.tensor(is_top_left_adj_dst)),
+ ('obj', 'is_down_adj', 'obj'): (torch.tensor(is_down_adj_src), torch.tensor(is_down_adj_dst)),
+ ('obj', 'is_right_adj', 'obj'): (torch.tensor(is_right_adj_src), torch.tensor(is_right_adj_dst)),
+ ('obj', 'is_down_left_adj', 'obj'): (torch.tensor(is_down_left_adj_src), torch.tensor(is_down_left_adj_dst)),
+ ('obj', 'is_down_right_adj', 'obj'): (torch.tensor(is_down_right_adj_src), torch.tensor(is_down_right_adj_dst)),
+ ('obj', 'is_left', 'obj'): (torch.tensor(is_left_src), torch.tensor(is_left_dst)),
+ ('obj', 'is_right', 'obj'): (torch.tensor(is_right_src), torch.tensor(is_right_dst)),
+ ('obj', 'is_front', 'obj'): (torch.tensor(is_front_src), torch.tensor(is_front_dst)),
+ ('obj', 'is_back', 'obj'): (torch.tensor(is_back_src), torch.tensor(is_back_dst)),
+ ('obj', 'is_aligned', 'obj'): (torch.tensor(is_aligned_src), torch.tensor(is_aligned_dst)),
+ ('obj', 'is_close', 'obj'): (torch.tensor(is_close_src), torch.tensor(is_close_dst))
+ }, num_nodes_dict={'obj': len(grid_objs)})
+ g = self._add_node_features(grid_objs, g)
+ return g
+
+ def _add_node_features(self, objs, graph):
+ for obj_idx, obj in enumerate(objs):
+ graph.nodes[obj_idx].data['type'] = torch.tensor(obj.type)
+ graph.nodes[obj_idx].data['pos'] = torch.tensor([[obj.x, obj.y]], dtype=torch.float32)
+ assert len(obj.attributes) == 2 and None not in obj.attributes[0] and None not in obj.attributes[1]
+ graph.nodes[obj_idx].data['color'] = torch.tensor([obj.attributes[0]])
+ graph.nodes[obj_idx].data['shape'] = torch.tensor([obj.attributes[1]])
+ return graph
+
+ def _get_spatial_rel(self, objs):
+ spatial_tensors = [np.zeros([len(objs), len(objs)]) for _ in range(len(self.rel_deter_func))]
+ for obj_idx1, obj1 in enumerate(objs):
+ for obj_idx2, obj2 in enumerate(objs):
+ direction_vec = np.array((0, -1)) # Up
+ for rel_idx, func in enumerate(self.rel_deter_func):
+ if func(obj1, obj2, direction_vec):
+ spatial_tensors[rel_idx][obj_idx1, obj_idx2] = 1.0
+ return spatial_tensors
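+    # spatial_tensors[k][i, j] == 1.0 iff relation self.rel_deter_func[k]
+    # holds from object i to object j (one binary adjacency matrix per relation).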
+
+ def get_trial(self, trials, step=10):
+ # retrieve state embeddings and actions from cached file
+ states = []
+ actions = []
+ trial_len = []
+ lens = []
+ n_nodes = []
+ # 8 trials
+ for t in trials:
+ tl = [(t, n) for n in range(0, len(self.data_tuples[t]), step)]
+ if len(tl) > self.max_len: # 30
+ tl = tl[:self.max_len]
+ trial_len.append(tl)
+ for tl in trial_len:
+ states.append([])
+ actions.append([])
+ lens.append(len(tl))
+ for t, n in tl:
+ video = self.data_tuples[t][n][0]
+ states[-1].append(self._get_frame_graph(video, self.data_tuples[t][n][1]))
+ n_nodes.append(states[-1][-1].number_of_nodes())
+ # actions are pooled over frames
+ if len(self.data_tuples[t]) > n + self.action_range:
+ actions_xy = [d[2] for d in self.data_tuples[t][n:n + self.action_range]]
+ else:
+ actions_xy = [d[2] for d in self.data_tuples[t][n:]]
+ actions_xy = np.array(actions_xy)
+ actions_xy = np.mean(actions_xy, axis=0)
+ actions[-1].append(actions_xy)
+ states[-1] = dgl.batch(states[-1])
+ actions[-1] = torch.tensor(np.array(actions[-1]))
+ trial_actions_padded = torch.zeros(self.max_len, actions[-1].size(1))
+ trial_actions_padded[:actions[-1].size(0), :] = actions[-1]
+ actions[-1] = trial_actions_padded
+ return states, actions, lens, n_nodes
+
+ def __getitem__(self, idx):
+ ep_trials = [idx * self.num_trials + t for t in range(self.num_trials)] # [idx, ..., idx+8]
+ states, actions, lens, n_nodes = self.get_trial(ep_trials, step=self.action_range)
+ return states, actions, lens, n_nodes
+
+ def __len__(self):
+ return self.tot_trials // self.num_trials
+
+
+class TestTransitionDatasetSequence(torch.utils.data.Dataset):
+ """
+ Test dataset class for the behavior cloning rnn model. This dataset is used to test the model on the eval data.
+ This class is used to compare plausible and implausible episodes.
+    Args:
+        path: path to the dataset
+        task_type: task type to include
+        mode: test
+        num_test: number of test state-action pairs
+        num_trials: number of trials in an episode
+        action_range: number of frames to skip; actions are pooled (agent displacement) over these frames
+        process_data: whether to index the videos or not (skip if already indexed)
+        max_len: max number of context state-action pairs
+    __getitem__:
+        returns: (dem_expected_states, dem_expected_actions, dem_expected_lens, dem_expected_nodes,
+                 dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, dem_unexpected_nodes,
+                 query_expected_frames, target_expected_actions,
+                 query_unexpected_frames, target_unexpected_actions)
+        dem_*_states: one batched DGLGraph.heterograph per demonstration trial
+        dem_*_actions: one padded (max_len, 2) action tensor per trial
+        dem_*_lens: number of sampled frames in each trial
+        dem_*_nodes: number of nodes in each frame graph
+        query_*_frames: batched DGLGraph.heterograph for the test trial
+        target_*_actions: (num_test, 2)
+ """
+ def __init__(
+ self,
+ path,
+ task_type=None,
+ mode="test",
+ num_test=1,
+ num_trials=9,
+ action_range=10,
+ process_data=0,
+ max_len=30
+ ):
+ self.path = path
+ self.task_type = task_type
+ self.mode = mode
+ self.num_trials = num_trials
+ self.num_test = num_test
+ self.action_range = action_range
+ self.max_len = max_len
+ self.ep_combs = self.num_trials * (self.num_trials - 2) # 9p2 - 9
+ self.eps = [[x, y] for x in range(self.num_trials) for y in range(self.num_trials) if x != y]
+ self.path_list_exp = []
+ self.json_list_exp = []
+ self.path_list_un = []
+ self.json_list_un = []
+
+ print(f'reading files of type {task_type} in {mode}')
+
+ paths_expected = sorted([os.path.join(self.path, task_type, x) for x in os.listdir(os.path.join(self.path, task_type)) if
+ x.endswith('e.mp4')])
+ jsons_expected = sorted([os.path.join(self.path, task_type, x) for x in os.listdir(os.path.join(self.path, task_type)) if
+ x.endswith('e.json') and 'index' not in x])
+ paths_unexpected = sorted([os.path.join(self.path, task_type, x) for x in os.listdir(os.path.join(self.path, task_type)) if
+ x.endswith('u.mp4')])
+ jsons_unexpected = sorted([os.path.join(self.path, task_type, x) for x in os.listdir(os.path.join(self.path, task_type)) if
+ x.endswith('u.json') and 'index' not in x])
+ self.path_list_exp += paths_expected
+ self.json_list_exp += jsons_expected
+ self.path_list_un += paths_unexpected
+ self.json_list_un += jsons_unexpected
+ self.data_expected = []
+ self.data_unexpected = []
+
+ if process_data:
+ # index data. This is done to speed up video retrieval.
+ # frame index, action tuples are stored
+ self.data_expected = index_data(self.json_list_exp, self.path_list_exp)
+ index_dict = {'data_tuples': self.data_expected}
+ with open(os.path.join(self.path, f'jindex_bib_test_{task_type}e.json'), 'w') as fp:
+ json.dump(index_dict, fp)
+
+ self.data_unexpected = index_data(self.json_list_un, self.path_list_un)
+ index_dict = {'data_tuples': self.data_unexpected}
+ with open(os.path.join(self.path, f'jindex_bib_test_{task_type}u.json'), 'w') as fp:
+ json.dump(index_dict, fp)
+ else:
+ with open(os.path.join(self.path, f'jindex_bib_{mode}_{task_type}e.json'), 'r') as fp:
+ index_dict = json.load(fp)
+ self.data_expected = index_dict['data_tuples']
+ with open(os.path.join(self.path, f'jindex_bib_{mode}_{task_type}u.json'), 'r') as fp:
+ index_dict = json.load(fp)
+ self.data_unexpected = index_dict['data_tuples']
+
+ self.rel_deter_func = [
+ is_top_adj, is_left_adj, is_top_right_adj, is_top_left_adj,
+ is_down_adj, is_right_adj, is_down_left_adj, is_down_right_adj,
+ is_left, is_right, is_front, is_back, is_aligned, is_close
+ ]
+
+ print('Done.')
+
+ def _get_frame_graph(self, jsonfile, frame_idx):
+ # load json
+ with open(jsonfile, 'rb') as f:
+ frame_data = json.load(f)
+ flat_list = [x for xs in frame_data for x in xs]
+ # extract entities
+ grid_objs = parse_objects(flat_list[frame_idx])
+ # --- build the graph
+ adj = self._get_spatial_rel(grid_objs)
+ # define edges
+ is_top_adj_src, is_top_adj_dst = np.nonzero(adj[0])
+ is_left_adj_src, is_left_adj_dst = np.nonzero(adj[1])
+ is_top_right_adj_src, is_top_right_adj_dst = np.nonzero(adj[2])
+ is_top_left_adj_src, is_top_left_adj_dst = np.nonzero(adj[3])
+ is_down_adj_src, is_down_adj_dst = np.nonzero(adj[4])
+ is_right_adj_src, is_right_adj_dst = np.nonzero(adj[5])
+ is_down_left_adj_src, is_down_left_adj_dst = np.nonzero(adj[6])
+ is_down_right_adj_src, is_down_right_adj_dst = np.nonzero(adj[7])
+ is_left_src, is_left_dst = np.nonzero(adj[8])
+ is_right_src, is_right_dst = np.nonzero(adj[9])
+ is_front_src, is_front_dst = np.nonzero(adj[10])
+ is_back_src, is_back_dst = np.nonzero(adj[11])
+ is_aligned_src, is_aligned_dst = np.nonzero(adj[12])
+ is_close_src, is_close_dst = np.nonzero(adj[13])
+ g = dgl.heterograph({
+ ('obj', 'is_top_adj', 'obj'): (torch.tensor(is_top_adj_src), torch.tensor(is_top_adj_dst)),
+ ('obj', 'is_left_adj', 'obj'): (torch.tensor(is_left_adj_src), torch.tensor(is_left_adj_dst)),
+ ('obj', 'is_top_right_adj', 'obj'): (torch.tensor(is_top_right_adj_src), torch.tensor(is_top_right_adj_dst)),
+ ('obj', 'is_top_left_adj', 'obj'): (torch.tensor(is_top_left_adj_src), torch.tensor(is_top_left_adj_dst)),
+ ('obj', 'is_down_adj', 'obj'): (torch.tensor(is_down_adj_src), torch.tensor(is_down_adj_dst)),
+ ('obj', 'is_right_adj', 'obj'): (torch.tensor(is_right_adj_src), torch.tensor(is_right_adj_dst)),
+ ('obj', 'is_down_left_adj', 'obj'): (torch.tensor(is_down_left_adj_src), torch.tensor(is_down_left_adj_dst)),
+ ('obj', 'is_down_right_adj', 'obj'): (torch.tensor(is_down_right_adj_src), torch.tensor(is_down_right_adj_dst)),
+ ('obj', 'is_left', 'obj'): (torch.tensor(is_left_src), torch.tensor(is_left_dst)),
+ ('obj', 'is_right', 'obj'): (torch.tensor(is_right_src), torch.tensor(is_right_dst)),
+ ('obj', 'is_front', 'obj'): (torch.tensor(is_front_src), torch.tensor(is_front_dst)),
+ ('obj', 'is_back', 'obj'): (torch.tensor(is_back_src), torch.tensor(is_back_dst)),
+ ('obj', 'is_aligned', 'obj'): (torch.tensor(is_aligned_src), torch.tensor(is_aligned_dst)),
+ ('obj', 'is_close', 'obj'): (torch.tensor(is_close_src), torch.tensor(is_close_dst))
+ }, num_nodes_dict={'obj': len(grid_objs)})
+ g = self._add_node_features(grid_objs, g)
+ return g
+
+ def _add_node_features(self, objs, graph):
+ for obj_idx, obj in enumerate(objs):
+ graph.nodes[obj_idx].data['type'] = torch.tensor(obj.type)
+ graph.nodes[obj_idx].data['pos'] = torch.tensor([[obj.x, obj.y]], dtype=torch.float32)
+ assert len(obj.attributes) == 2 and None not in obj.attributes[0] and None not in obj.attributes[1]
+ graph.nodes[obj_idx].data['color'] = torch.tensor([obj.attributes[0]])
+ graph.nodes[obj_idx].data['shape'] = torch.tensor([obj.attributes[1]])
+ return graph
+
+ def _get_spatial_rel(self, objs):
+ spatial_tensors = [np.zeros([len(objs), len(objs)]) for _ in range(len(self.rel_deter_func))]
+ for obj_idx1, obj1 in enumerate(objs):
+ for obj_idx2, obj2 in enumerate(objs):
+                direction_vec = np.array((0, -1)) # fixed "up" reference direction
+ for rel_idx, func in enumerate(self.rel_deter_func):
+ if func(obj1, obj2, direction_vec):
+ spatial_tensors[rel_idx][obj_idx1, obj_idx2] = 1.0
+ return spatial_tensors
+
+ def get_trial(self, trials, data, step=10):
+ # retrieve state embeddings and actions from cached file
+ states = []
+ actions = []
+ trial_len = []
+ lens = []
+ n_nodes = []
+ for t in trials:
+ tl = [(t, n) for n in range(0, len(data[t]), step)]
+ if len(tl) > self.max_len:
+ tl = tl[:self.max_len]
+ trial_len.append(tl)
+ for tl in trial_len:
+ states.append([])
+ actions.append([])
+ lens.append(len(tl))
+ for t, n in tl:
+ video = data[t][n][0]
+ states[-1].append(self._get_frame_graph(video, data[t][n][1]))
+ n_nodes.append(states[-1][-1].number_of_nodes())
+ if len(data[t]) > n + self.action_range:
+ actions_xy = [d[2] for d in data[t][n:n + self.action_range]]
+ else:
+ actions_xy = [d[2] for d in data[t][n:]]
+ actions_xy = np.array(actions_xy)
+ actions_xy = np.mean(actions_xy, axis=0)
+ actions[-1].append(actions_xy)
+ states[-1] = dgl.batch(states[-1])
+ actions[-1] = torch.tensor(np.array(actions[-1]))
+ trial_actions_padded = torch.zeros(self.max_len, actions[-1].size(1))
+ trial_actions_padded[:actions[-1].size(0), :] = actions[-1]
+ actions[-1] = trial_actions_padded
+ return states, actions, lens, n_nodes
+
+ def get_test(self, trial, data, step=10):
+ # retrieve state embeddings and actions from cached file
+ states = []
+ actions = []
+ trial_len = []
+ trial_len += [(trial, n) for n in range(0, len(data[trial]), step)]
+ for t, n in trial_len:
+ video = data[t][n][0]
+ state = self._get_frame_graph(video, data[t][n][1])
+ if len(data[t]) > n + self.action_range:
+ actions_xy = [d[2] for d in data[t][n:n + self.action_range]]
+ else:
+ actions_xy = [d[2] for d in data[t][n:]]
+ actions_xy = np.array(actions_xy)
+ actions_xy = np.mean(actions_xy, axis=0)
+ actions.append(actions_xy)
+ states.append(state)
+ #states = torch.stack(states)
+ states = dgl.batch(states)
+ actions = torch.tensor(np.array(actions))
+ return states, actions
+
+ def __getitem__(self, idx):
+ ep_trials = [idx * self.num_trials + t for t in range(self.num_trials)]
+ dem_expected_states, dem_expected_actions, dem_expected_lens, dem_expected_nodes = self.get_trial(
+ ep_trials[:-1], self.data_expected, step=self.action_range
+ )
+ dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, dem_unexpected_nodes = self.get_trial(
+ ep_trials[:-1], self.data_unexpected, step=self.action_range
+ )
+ query_expected_frames, target_expected_actions = self.get_test(
+ ep_trials[-1], self.data_expected, step=self.action_range
+ )
+ query_unexpected_frames, target_unexpected_actions = self.get_test(
+ ep_trials[-1], self.data_unexpected, step=self.action_range
+ )
+ return dem_expected_states, dem_expected_actions, dem_expected_lens, dem_expected_nodes, \
+ dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, dem_unexpected_nodes, \
+ query_expected_frames, target_expected_actions, \
+ query_unexpected_frames, target_unexpected_actions
+
+ def __len__(self):
+ return len(self.path_list_exp)
+
+
+if __name__ == '__main__':
+ types = ['preference', 'multi_agent', 'inaccessible_goal',
+ 'efficiency_irrational', 'efficiency_time','efficiency_path',
+ 'instrumental_no_barrier', 'instrumental_blocking_barrier', 'instrumental_inconsequential_barrier']
+ for t in types:
+ ttd = TestTransitionDatasetSequence(path='/datasets/external/bib_evaluation_1_1/', task_type=t, process_data=0, mode='test')
+    for i in range(len(ttd)):
+ print(i, end='\r')
+ dem_expected_states, dem_expected_actions, dem_expected_lens, dem_expected_nodes, \
+ dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, dem_unexpected_nodes, \
+ query_expected_frames, target_expected_actions, \
+            query_unexpected_frames, target_unexpected_actions = ttd[i]
+ for j in range(8):
+ if not torch.tensor([1., 0., 0., 0., 0., 0., 0., 0., 0.]) in dem_expected_states[j].ndata['type']:
+ print(i)
+ if not torch.tensor([1., 0., 0., 0., 0., 0., 0., 0., 0.]) in dem_unexpected_states[j].ndata['type']:
+ print(i)
+ if not torch.tensor([1., 0., 0., 0., 0., 0., 0., 0., 0.]) in query_expected_frames.ndata['type']:
+ print(i)
+ if not torch.tensor([1., 0., 0., 0., 0., 0., 0., 0., 0.]) in query_unexpected_frames.ndata['type']:
+ print(i)
diff --git a/utils/grid_object.py b/utils/grid_object.py
new file mode 100644
index 0000000..d5cc682
--- /dev/null
+++ b/utils/grid_object.py
@@ -0,0 +1,174 @@
+import json
+import pdb
+import numpy as np
+from sklearn.preprocessing import OneHotEncoder
+import itertools
+
+
+SHAPES = {
+ # walls
+ 'square': 0,
+ # objects
+ 'heart': 1, 'clove_tree': 2, 'moon': 3, 'wine': 4, 'double_dia': 5,
+ 'flag': 6, 'capsule': 7, 'vase': 8, 'curved_triangle': 9, 'spoon': 10,
+ 'medal': 11, # inaccessible goal
+ # home
+ 'home': 12,
+ # agent
+ 'pentagon': 13, 'clove': 14, 'kite': 15,
+ # key
+ 'triangle0': 16, 'triangle180': 16, 'triangle90': 16, 'triangle270': 16,
+ # lock
+ 'triangle_slot0': 17, 'triangle_slot180': 17, 'triangle_slot90': 17, 'triangle_slot270': 17
+}
+
+ENTITIES = {
+    'agent': 0, 'walls': 1, 'fuse_walls': 2, 'key': 3,
+ 'blocking': 4, 'home': 5, 'objects': 6, 'pin': 7, 'lock': 8
+}
+
+# =============================== GridObject class ===============================
+
+class GridObject():
+ "object is specified by its location"
+ def __init__(self, x, y, object_type, attributes=[]):
+ self.x = x
+ self.y = y
+ self.type = object_type
+ self.attributes = attributes
+
+ @property
+ def pos(self):
+ return np.array([self.x, self.y])
+
+ @property
+ def name(self):
+ return {'type': str(self.type),
+ 'x': str(self.x),
+ 'y': str(self.y),
+ 'color': self.attributes[0],
+ 'shape': self.attributes[1]}
+
+# =============================== Helper functions ===============================
+
+def type2index(key):
+ for name, idx in ENTITIES.items():
+ if name == key:
+ return idx
+
+def find_shape(shape_string, print_shape=False):
+ try:
+ shape = shape_string.split('/')[-1].split('.')[0]
+    except Exception:
+ shape = shape_string.split('.')[0]
+ if print_shape: print(shape)
+ for name, idx in SHAPES.items():
+ if name == shape:
+ return idx
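+
+# Example (hypothetical path): find_shape('imgs/heart.png') -> 1, i.e. SHAPES['heart']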
+
+def parse_objects(frame):
+ """
+ x and y are computed differently from walls and objects
+ for walls x, y = obj[0][0] + obj[1][0]/2, obj[0][1] + obj[1][1]/2
+ for objects x, y = obj[0][0] + obj[1], obj[0][1] + obj[1]
+ :param obj:
+ :return: GridObject
+ """
+ shape_onehot_encoder = OneHotEncoder(sparse=False)
+ shape_onehot_encoder.fit([[i] for i in range(len(SHAPES)-6)])
+ type_onehot_encoder = OneHotEncoder(sparse=False)
+ type_onehot_encoder.fit([[i] for i in range(len(ENTITIES))])
+ # remove duplicate walls
+ frame['walls'].sort()
+ frame['walls'] = list(k for k, _ in itertools.groupby(frame['walls']))
+ # remove boundary walls
+ frame['walls'] = [w for w in frame['walls'] if (w[0][0] != 0 and w[0][0] != 180 and w[0][1] != 0 and w[0][1] != 180)]
+ # remove duplicate fuse_walls
+ frame['fuse_walls'].sort()
+ frame['fuse_walls'] = list(k for k, _ in itertools.groupby(frame['fuse_walls']))
+ grid_objs = []
+ assert 'agent' in frame.keys()
+ for key in frame.keys():
+ #print(key)
+ if key == 'size':
+ continue
+ obj = frame[key]
+ if obj == []:
+ #print(key, 'skipped')
+ continue
+ obj_type = type2index(key)
+ obj_type = type_onehot_encoder.transform([[obj_type]])
+ if key == 'walls':
+ for wall in obj:
+ x, y = wall[0][0] + wall[1][0]/2, wall[0][1] + wall[1][1]/2
+                #x, y = (wall[0][0] + wall[1][0]/2)/200, (wall[0][1] + wall[1][1]/2)/200  # if you use this, relations.py must be changed accordingly
+ color = [0, 0, 0] if key == 'walls' else [80, 146, 56]
+ #color = [c / 255 for c in color]
+ shape = 0
+ assert shape in SHAPES.values(), 'Shape not found'
+ shape = shape_onehot_encoder.transform([[shape]])[0]
+ grid_obj = GridObject(x=x, y=y, object_type=obj_type, attributes=[color, shape])
+ grid_objs.append(grid_obj)
+ elif key == 'fuse_walls':
+ # resample green barriers
+ obj = [obj[i] for i in range(len(obj)) if (obj[i][0][0] % 20 == 0 and obj[i][0][1] % 20 == 0)]
+ for wall in obj:
+ x, y = wall[0][0] + wall[1][0]/2, wall[0][1] + wall[1][1]/2
+                #x, y = (wall[0][0] + wall[1][0]/2)/200, (wall[0][1] + wall[1][1]/2)/200  # if you use this, relations.py must be changed accordingly
+ color = [80, 146, 56]
+ #color = [c / 255 for c in color]
+ shape = 0
+ assert shape in SHAPES.values(), 'Shape not found'
+ shape = shape_onehot_encoder.transform([[shape]])[0]
+ grid_obj = GridObject(x=x, y=y, object_type=obj_type, attributes=[color, shape])
+ grid_objs.append(grid_obj)
+ elif key == 'objects':
+ for ob in obj:
+ x, y = ob[0][0] + ob[1], ob[0][1] + ob[1]
+ color = ob[-1]
+ #color = [c / 255 for c in color]
+ shape = find_shape(ob[2], print_shape=False)
+ assert shape in SHAPES.values(), 'Shape not found'
+ shape = shape_onehot_encoder.transform([[shape]])[0]
+ grid_obj = GridObject(x=x, y=y, object_type=obj_type, attributes=[color, shape])
+ grid_objs.append(grid_obj)
+ elif key == 'key':
+ obj = obj[0]
+ x, y = obj[0][0] + obj[1], obj[0][1] + obj[1]
+ color = obj[-1]
+ #color = [c / 255 for c in color]
+ shape = find_shape(obj[2], print_shape=False)
+ assert shape in SHAPES.values(), 'Shape not found'
+ shape = shape_onehot_encoder.transform([[shape]])[0]
+ grid_obj = GridObject(x=x, y=y, object_type=obj_type, attributes=[color, shape])
+ grid_objs.append(grid_obj)
+ elif key == 'lock':
+ obj = obj[0]
+ x, y = obj[0][0] + obj[1], obj[0][1] + obj[1]
+ color = obj[-1]
+ #color = [c / 255 for c in color]
+ shape = find_shape(obj[2], print_shape=False)
+ assert shape in SHAPES.values(), 'Shape not found'
+ shape = shape_onehot_encoder.transform([[shape]])[0]
+ grid_obj = GridObject(x=x, y=y, object_type=obj_type, attributes=[color, shape])
+ grid_objs.append(grid_obj)
+ else:
+ try:
+ x, y = obj[0][0] + obj[1], obj[0][1] + obj[1]
+ color = obj[-1]
+ #color = [c / 255 for c in color]
+ shape = find_shape(obj[2], print_shape=False)
+ assert shape in SHAPES.values(), 'Shape not found'
+ shape = shape_onehot_encoder.transform([[shape]])[0]
+            except Exception:
+ # [[[x, y], extension, shape, color]] in some cases in instrumental_no_barrier (bib_evaluation_1_1)
+ x, y = obj[0][0][0] + obj[0][1], obj[0][0][1] + obj[0][1]
+ color = obj[0][-1]
+ #color = [c / 255 for c in color]
+ assert len(color) == 3
+ shape = find_shape(obj[0][2], print_shape=False)
+ assert shape in SHAPES.values(), 'Shape not found'
+ shape = shape_onehot_encoder.transform([[shape]])[0]
+ grid_obj = GridObject(x=x, y=y, object_type=obj_type, attributes=[color, shape])
+ grid_objs.append(grid_obj)
+ return grid_objs
diff --git a/utils/index_data.py b/utils/index_data.py
new file mode 100644
index 0000000..8dd9a18
--- /dev/null
+++ b/utils/index_data.py
@@ -0,0 +1,124 @@
+import json
+import os
+import torch
+import torch.utils.data
+from tqdm import tqdm
+
+
+def index_data(json_list, path_list):
+    print(f'processing {len(json_list)} files')
+ data_tuples = []
+ for j, v in tqdm(zip(json_list, path_list)):
+ with open(j, 'r') as f:
+ state = json.load(f)
+ ep_lens = [len(x) for x in state]
+ past_len = 0
+ for e, l in enumerate(ep_lens):
+ data_tuples.append([])
+ # skip first 30 frames and last 83 frames
+ for f in range(30, l - 83):
+                # action = halved agent displacement between consecutive frames
+ f0x, f0y = state[e][f]['agent'][0]
+ f1x, f1y = state[e][f + 1]['agent'][0]
+ dx = (f1x - f0x) / 2.
+ dy = (f1y - f0y) / 2.
+ action = [dx, dy]
+ #data_tuples[-1].append((v, past_len + f, action))
+ data_tuples[-1].append((j, past_len + f, action))
+ # data_tuples = (json file, frame number, action)
+ assert len(data_tuples[-1]) > 0
+ past_len += l
+ return data_tuples
+
+class TransitionDataset(torch.utils.data.Dataset):
+ """
+    Indexing dataset class for the training videos. It is only used (with
+    process_data=1) to build the pre-indexed jindex_bib_*.json files;
+    __getitem__ is intentionally a no-op.
+ Args:
+ path: path to the dataset
+ types: list of video types to include
+ size: size of the frames to be returned
+ mode: train, val
+ num_context: number of context state-action pairs
+ num_test: number of test state-action pairs
+ num_trials: number of trials in an episode
+        action_range: number of frames to skip; actions are pooled (agent displacement) over these frames
+        process_data: whether to index the videos or not (skip if already indexed)
+    """
+ def __init__(self, path, types=None, size=None, mode="train", num_context=30, num_test=1, num_trials=9,
+ action_range=10, process_data=0):
+
+ self.path = path
+ self.types = types
+ self.size = size
+ self.mode = mode
+ self.num_trials = num_trials
+ self.num_context = num_context
+ self.num_test = num_test
+ self.action_range = action_range
+ self.ep_combs = self.num_trials * (self.num_trials - 2) # 9p2 - 9
+ self.eps = [[x, y] for x in range(self.num_trials) for y in range(self.num_trials) if x != y]
+ types_str = '_'.join(self.types)
+
+ self.path_list = []
+ self.json_list = []
+ # get video paths and json file paths
+ for t in types:
+ print(f'reading files of type {t} in {mode}')
+ paths = [os.path.join(self.path, t, x) for x in os.listdir(os.path.join(self.path, t)) if
+ x.endswith(f'.mp4')]
+ jsons = [os.path.join(self.path, t, x) for x in os.listdir(os.path.join(self.path, t)) if
+ x.endswith(f'.json') and 'index' not in x]
+
+ paths = sorted(paths)
+ jsons = sorted(jsons)
+
+ if mode == 'train':
+ self.path_list += paths[:int(0.8 * len(jsons))]
+ self.json_list += jsons[:int(0.8 * len(jsons))]
+ elif mode == 'val':
+ self.path_list += paths[int(0.8 * len(jsons)):]
+ self.json_list += jsons[int(0.8 * len(jsons)):]
+ else:
+ self.path_list += paths
+ self.json_list += jsons
+
+ self.data_tuples = []
+ if process_data:
+ # index the videos in the dataset directory. This is done to speed up the retrieval of videos.
+ # frame index, action tuples are stored
+ self.data_tuples = index_data(self.json_list, self.path_list)
+ # tuples of frame index and action (displacement of agent)
+ index_dict = {'data_tuples': self.data_tuples}
+ with open(os.path.join(self.path, f'jindex_bib_{mode}_{types_str}.json'), 'w') as fp:
+ json.dump(index_dict, fp)
+ else:
+ # read pre-indexed data
+ with open(os.path.join(self.path, f'jindex_bib_{mode}_{types_str}.json'), 'r') as fp:
+ index_dict = json.load(fp)
+ self.data_tuples = index_dict['data_tuples']
+
+ self.tot_trials = len(self.path_list) * 9
+
+
+ def __getitem__(self, idx):
+ print('Empty')
+ return
+
+ def __len__(self):
+ return self.tot_trials // self.num_trials
+
+
+if __name__ == "__main__":
+ dataset = TransitionDataset(path='/datasets/external/bib_train/',
+ types=['multi_agent', 'instrumental_action'], #['instrumental_action', 'multi_agent', 'preference', 'single_object'],
+ size=(84, 84),
+ mode="train", num_context=30,
+ num_test=1, num_trials=9,
+ action_range=10, process_data=1)
+ print(len(dataset))
diff --git a/utils/relations.py b/utils/relations.py
new file mode 100644
index 0000000..0ef7454
--- /dev/null
+++ b/utils/relations.py
@@ -0,0 +1,116 @@
+import numpy as np
+
+
+# =============================== relationships to build the graph ===============================
+
+def rotate_vec2d(vec, degrees):
+ """
+ rotate a vector anti-clockwise
+ :param vec:
+ :param degrees:
+ :return:
+ """
+ theta = np.radians(degrees)
+ c, s = np.cos(theta), np.sin(theta)
+ R = np.array(((c, -s), (s, c)))
+ return R@vec
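+    # e.g. rotate_vec2d(np.array((0, -1)), 90) -> array([1., 0.])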
+
+# ---------- Remote Directional Relations --------------------------------------------------------
+
+def is_front(obj1, obj2, direction_vec)->bool:
+ diff = obj2.pos - obj1.pos
+ return diff@direction_vec > 0.1
+
+def is_back(obj1, obj2, direction_vec)->bool:
+ diff = obj2.pos - obj1.pos
+ return diff@direction_vec < -0.1
+
+def is_left(obj1, obj2, direction_vec)->bool:
+ left_vec = rotate_vec2d(direction_vec, -90)
+ diff = obj2.pos - obj1.pos
+ return diff@left_vec > 0.1
+
+def is_right(obj1, obj2, direction_vec)->bool:
+ left_vec = rotate_vec2d(direction_vec, 90)
+ diff = obj2.pos - obj1.pos
+ return diff@left_vec > 0.1
+
+# ---------- Alignment and Adjacency Relations ---------------------------------------------------
+
+def is_close(obj1, obj2, direction_vec=None)->bool:
+    # indicates whether two objects are adjacent to each other; unlike the
+    # local directional relations, this carries no directional information
+ distance = np.abs(obj1.pos - obj2.pos)
+ return np.sum(distance)==20
+
+def is_aligned(obj1, obj2, direction_vec=None)->bool:
+    # indicates whether two entities are on the same horizontal or vertical line
+ diff = obj2.pos - obj1.pos
+ return np.any(diff==0)
+
+# ---------- Local Directional Relations ---------------------------------------------------------
+
+def is_top_adj(obj1, obj2, direction_vec=None)->bool:
+ return obj1.x==obj2.x and obj1.y==obj2.y+20
+
+def is_left_adj(obj1, obj2, direction_vec=None)->bool:
+ return obj1.y==obj2.y and obj1.x==obj2.x-20
+
+def is_top_left_adj(obj1, obj2, direction_vec=None)->bool:
+ return obj1.y==obj2.y+20 and obj1.x==obj2.x-20
+
+def is_top_right_adj(obj1, obj2, direction_vec=None)->bool:
+ return obj1.y==obj2.y+20 and obj1.x==obj2.x+20
+
+def is_down_adj(obj1, obj2, direction_vec=None)->bool:
+ return is_top_adj(obj2, obj1)
+
+def is_right_adj(obj1, obj2, direction_vec=None)->bool:
+ return is_left_adj(obj2, obj1)
+
+def is_down_right_adj(obj1, obj2, direction_vec=None)->bool:
+ return is_top_left_adj(obj2, obj1)
+
+def is_down_left_adj(obj1, obj2, direction_vec=None)->bool:
+ return is_top_right_adj(obj2, obj1)
+
+# ---------- More Remote Directional Relations (not used) ----------------------------------------
+
+def top_left(obj1, obj2, direction_vec)->bool:
+ return (obj1.x-obj2.x) <= (obj1.y-obj2.y)
+
+def top_right(obj1, obj2, direction_vec)->bool:
+ return -(obj1.x-obj2.x) <= (obj1.y-obj2.y)
+
+def down_left(obj1, obj2, direction_vec)->bool:
+ return top_right(obj2, obj1, direction_vec)
+
+def down_right(obj1, obj2, direction_vec)->bool:
+ return top_left(obj2, obj1, direction_vec)
+
+def fan_top(obj1, obj2, direction_vec)->bool:
+ top_left = (obj1.x-obj2.x) <= (obj1.y-obj2.y)
+ top_right = -(obj1.x-obj2.x) <= (obj1.y-obj2.y)
+ return top_left and top_right
+
+def fan_down(obj1, obj2, direction_vec)->bool:
+ return fan_top(obj2, obj1, direction_vec)
+
+def fan_right(obj1, obj2, direction_vec)->bool:
+ down_left = (obj1.x-obj2.x) >= (obj1.y-obj2.y)
+ top_right = -(obj1.x-obj2.x) <= (obj1.y-obj2.y)
+ return down_left and top_right
+
+def fan_left(obj1, obj2, direction_vec)->bool:
+ return fan_right(obj2, obj1, direction_vec)
+
+# ---------- Ad-hoc Relations --------------------------------------------------------------------
+
+def needs(obj1, obj2, direction_vec=None)->bool:
+ return np.argmax(obj1.type) == 0 and np.argmax(obj2.type) == 3
+
+def opens(obj1, obj2, direction_vec=None)->bool:
+ return np.argmax(obj1.type) == 3 and np.argmax(obj2.type) == 8
+
+def collects(obj1, obj2, direction_vec=None)->bool:
+ return np.argmax(obj1.type) == 0 and np.argmax(obj2.type) == 6
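+
+
+if __name__ == '__main__':
+    # Minimal sketch: the relations above expect objects exposing .x, .y and
+    # .pos on the 20-pixel grid produced by utils/grid_object.GridObject.
+    class _Obj:
+        def __init__(self, x, y):
+            self.x, self.y, self.pos = x, y, np.array([x, y])
+    a, b = _Obj(40, 60), _Obj(40, 40)
+    up = np.array((0, -1))
+    print(is_top_adj(a, b))    # True: a.y == b.y + 20
+    print(is_close(a, b))      # True: Manhattan distance is exactly 20
+    print(is_front(a, b, up))  # True: b lies in the "up" direction from a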
\ No newline at end of file
diff --git a/utils/run_build_graphs.sh b/utils/run_build_graphs.sh
new file mode 100644
index 0000000..4f2b03b
--- /dev/null
+++ b/utils/run_build_graphs.sh
@@ -0,0 +1 @@
+python build_graphs.py --mode train --cpus 30 && python build_graphs.py --mode val --cpus 30