From de0bea75087e684e61bad5469b83044db521dd3d Mon Sep 17 00:00:00 2001 From: Matteo Bortoletto Date: Thu, 1 Feb 2024 15:40:47 +0100 Subject: [PATCH] up --- README.md | 76 +++- run_test.sh | 9 + run_train.sh | 18 + test_tom.py | 122 ++++++ tom/__init__.py | 0 tom/dataset.py | 310 ++++++++++++++ tom/gnn.py | 877 ++++++++++++++++++++++++++++++++++++++ tom/model.py | 513 ++++++++++++++++++++++ tom/norm.py | 46 ++ tom/transformer.py | 89 ++++ train_tom.py | 75 ++++ utils/__init__.py | 0 utils/build_graphs.py | 115 +++++ utils/dataset.py | 487 +++++++++++++++++++++ utils/grid_object.py | 174 ++++++++ utils/index_data.py | 124 ++++++ utils/relations.py | 116 +++++ utils/run_build_graphs.sh | 1 + 18 files changed, 3150 insertions(+), 2 deletions(-) create mode 100644 run_test.sh create mode 100644 run_train.sh create mode 100644 test_tom.py create mode 100644 tom/__init__.py create mode 100644 tom/dataset.py create mode 100644 tom/gnn.py create mode 100644 tom/model.py create mode 100644 tom/norm.py create mode 100644 tom/transformer.py create mode 100644 train_tom.py create mode 100644 utils/__init__.py create mode 100644 utils/build_graphs.py create mode 100644 utils/dataset.py create mode 100644 utils/grid_object.py create mode 100644 utils/index_data.py create mode 100644 utils/relations.py create mode 100644 utils/run_build_graphs.sh diff --git a/README.md b/README.md index 8cfac98..3993eda 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,75 @@ -# IRENE +
+

Neural Reasoning about Agents' Goals, Preferences, and Actions

-Official code of "Neural Reasoning About Agents' Goals, Preferences, and Actions" \ No newline at end of file +**[Matteo Bortoletto][1],   [Lei Shi][2],   [Andreas Bulling][3]**

+**AAAI'24, Vancouver, CA**
+**[[Paper][4]]** + +
+ +# Citation +If you find our code useful or use it in your own projects, please cite our paper: + +```bibtex +@inproceedings{bortoletto2024neural, + author = {Bortoletto, Matteo and Lei, Shi and Bulling, Andreas}, + title = {{Neural Reasoning about Agents' Goals, Preferences, and Actions}}, + booktitle = {Proc. 38th AAAI Conference on Artificial Intelligence (AAAI)}, + year = {2024}, +} +``` + +# Setup + +This code is based on the [original implementation][5] of the BIB benchmark. + +## Using `virtualenv` +``` +python -m virtualenv /path/to/env +source /path/to/env/bin/activate +pip install -r requirements.txt +``` + +## Using `conda` +``` +conda create --name python=3.8.10 pip=20.0.2 cudatoolkit=10.2.89 +conda activate +pip install -r requirements_conda.txt +pip install dgl-cu102 dglgo -f https://data.dgl.ai/wheels/repo.html +``` + + +# Running the code + +## Activate the environment +Run `source bibdgl/bin/activate`. + +## Index data +This will create the json files with all the indexed frames for each episode in each video. +``` +python utils/index_data.py +``` +You need to manually set `mode` in the dataset class (in main). + +## Generate graphs +This will generate the graphs from the videos: +``` +python /utils/build_graphs.py --mode MODE --cpus NUM_CPUS +``` +`MODE` can be `train`, `val` or `test`. NOTE: check `utils/build_graphs.py` to make sure you're loading the correct dataset to generate the graphs you want. + +## Training +You can use the `gtbc.sh`. + +## Testing +Use `run_test_tom.sh`. + +# Hardware setup +All models are trained on an NVIDIA Tesla V100-SXM2-32GB GPU. + + +[1]: https://mattbortoletto.github.io/ +[2]: https://perceptualui.org/people/shi/ +[3]: https://perceptualui.org/people/bulling/ +[4]: https://perceptualui.org/publications/bortoletto24_aaai.pdf +[5]: https://github.com/kanishkg/bib-baselines diff --git a/run_test.sh b/run_test.sh new file mode 100644 index 0000000..ee0a135 --- /dev/null +++ b/run_test.sh @@ -0,0 +1,9 @@ +echo 314 e31 + +CUDA_VISIBLE_DEVICES=1 python test_tom.py \ +--model_type graphbcrnn \ +--types efficiency_irrational \ +--ckpt /projects/bortoletto/icml2023_matteo/wandb/run-20221224_135525-8i1r2aqy/files/bib/8i1r2aqy/checkpoints/epoch\=31-step\=22399.ckpt \ +--data_path /datasets/external/bib_evaluation_1_1/graphs/all_tasks \ +--process_data 0 \ +--surprise_type max diff --git a/run_train.sh b/run_train.sh new file mode 100644 index 0000000..7b1c9fb --- /dev/null +++ b/run_train.sh @@ -0,0 +1,18 @@ +CUDA_VISIBLE_DEVICES=0 python train_tom.py \ +--model_type graphbcrnn \ +--types single_object preference instrumental_action \ +--data_path /datasets/external/bib_train/graphs/all_tasks/ \ +--seed 7 \ +--batch_size 32 \ +--max_epochs 35 \ +--gpus 1 \ +--auto_select_gpus True \ +--num_workers 2 \ +--stochastic_weight_avg True \ +--lr 5e-4 \ +--check_val_every_n_epoch 1 \ +--track_grad_norm 2 \ +--gradient_clip_val 10 \ +--gnn_type RSAGEv4 \ +--state_dim 96 \ +--aggregation sum diff --git a/test_tom.py b/test_tom.py new file mode 100644 index 0000000..d72a008 --- /dev/null +++ b/test_tom.py @@ -0,0 +1,122 @@ +from argparse import ArgumentParser +import numpy as np +from tqdm import tqdm + +import torch +from torch.utils.data import DataLoader +import torch.nn.functional as F +import dgl + +from tom.dataset import TestToMnetDGLDataset, collate_function_seq_test +from tom.model import GraphBC_T, GraphBCRNN + + +def get_z_scores(total, total_expected, total_unexpected): + mean = np.mean(total) + std = np.std(total) + print("Z-Score expected: ", + 
(np.mean(total_expected) - mean) / std) + print("Z-Score unexpected: ", + (np.mean(total_unexpected) - mean) / std) + + +parser = ArgumentParser() + +parser.add_argument('--model_type', type=str, default='graphbcrnn') +parser.add_argument('--ckpt', type=str, default=None, help='path to checkpoint') +parser.add_argument('--data_path', type=str, default=None, help='path to the data') +parser.add_argument('--process_data', type=int, default=0) +parser.add_argument('--surprise_type', type=str, default='max', + help='surprise type: mean, max. This is used for comparing the plausibility scores of the two test episodes') +parser.add_argument('--types', nargs='+', type=str, + default=[ + 'preference', 'multi_agent', 'inaccessible_goal', + 'efficiency_irrational', 'efficiency_time','efficiency_path', + 'instrumental_no_barrier', 'instrumental_blocking_barrier', 'instrumental_inconsequential_barrier' + ], + help='types of tasks used for training / testing') +parser.add_argument('--filename', type=str, default='') + +args = parser.parse_args() + +filename = args.filename + +if args.model_type == 'graphbct': + model = GraphBC_T.load_from_checkpoint(args.ckpt) +elif args.model_type == 'graphbcrnn': + model = GraphBCRNN.load_from_checkpoint(args.ckpt) +else: + raise ValueError('Unknown model type.') + +device = 'cuda' +model.to(device) +model.eval() +with torch.no_grad(): + for t in args.types: + if args.model_type == 'graphbcrnn': + test_dataset = TestToMnetDGLDataset( + path=args.data_path, + task_type=t, + mode='test' + ) + test_dataloader = DataLoader( + test_dataset, + batch_size=1, + num_workers=1, + pin_memory=True, + collate_fn=collate_function_seq_test, + shuffle=False + ) + count = 0 + total, total_expected, total_unexpected = [], [], [] + pbar = tqdm(test_dataloader) + for j, batch in enumerate(pbar): + if args.model_type == 'graphbcrnn': + dem_expected_states, dem_expected_actions, dem_expected_lens, \ + dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, \ + query_expected_frames, target_expected_actions, \ + query_unexpected_frames, target_unexpected_actions = batch + dem_expected_states = dem_expected_states.to(device) + dem_expected_actions = dem_expected_actions.to(device) + dem_unexpected_states = dem_unexpected_states.to(device) + dem_unexpected_actions = dem_unexpected_actions.to(device) + target_expected_actions = target_expected_actions.to(device) + target_unexpected_actions = target_unexpected_actions.to(device) + surprise_expected = [] + query_expected_frames = dgl.unbatch(query_expected_frames) + for i in range(len(query_expected_frames)): + if args.model_type == 'graphbcrnn': + test_actions, test_actions_pred = model( + [dem_expected_states, dem_expected_actions, dem_expected_lens, query_expected_frames[i].to(device), target_expected_actions[:, i, :]] + ) + loss = F.mse_loss(test_actions, test_actions_pred) + surprise_expected.append(loss.cpu().detach().numpy()) + mean_expected_surprise = np.mean(surprise_expected) + max_expected_surprise = np.max(surprise_expected) + + # calculate the plausibility scores for the unexpected episode + surprise_unexpected = [] + query_unexpected_frames = dgl.unbatch(query_unexpected_frames) + for i in range(len(query_unexpected_frames)): + if args.model_type == 'graphbcrnn': + test_actions, test_actions_pred = model( + [dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, query_unexpected_frames[i].to(device), target_unexpected_actions[:, i, :]] + ) + loss = F.mse_loss(test_actions, test_actions_pred) + 
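+                # per-frame MSE between predicted and target actions serves as the surprise score for this frame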
surprise_unexpected.append(loss.cpu().detach().numpy()) + mean_unexpected_surprise = np.mean(surprise_unexpected) + max_unexpected_surprise = np.max(surprise_unexpected) + + correct_mean = mean_expected_surprise < mean_unexpected_surprise + 0.5 * (mean_expected_surprise == mean_unexpected_surprise) + correct_max = max_expected_surprise < max_unexpected_surprise + 0.5 * (max_expected_surprise == max_unexpected_surprise) + if args.surprise_type == 'max': + count += correct_max + elif args.surprise_type == 'mean': + count += correct_mean + pbar.set_postfix({'accuracy': count/(j+1.), 'type': t}) + + total_expected.append(max_expected_surprise) + total_unexpected.append(max_unexpected_surprise) + total.append(max_expected_surprise) + total.append(max_unexpected_surprise) + get_z_scores(total, total_expected, total_unexpected) diff --git a/tom/__init__.py b/tom/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tom/dataset.py b/tom/dataset.py new file mode 100644 index 0000000..d3bcc75 --- /dev/null +++ b/tom/dataset.py @@ -0,0 +1,310 @@ +import pickle as pkl +import os +import torch +import torch.utils.data +import torch.nn.functional as F +import dgl +import random +from dgl.data import DGLDataset + + +def collate_function_seq(batch): + #dem_frames = torch.stack([item[0] for item in batch]) + dem_frames = dgl.batch([item[0] for item in batch]) + dem_actions = torch.stack([item[1] for item in batch]) + dem_lens = [item[2] for item in batch] + #query_frames = torch.stack([item[3] for item in batch]) + query_frames = dgl.batch([item[3] for item in batch]) + target_actions = torch.stack([item[4] for item in batch]) + return [dem_frames, dem_actions, dem_lens, query_frames, target_actions] + +def collate_function_seq_test(batch): + dem_expected_states = dgl.batch([item[0] for item in batch][0]) + dem_expected_actions = torch.stack([item[1] for item in batch][0]).unsqueeze(dim=0) + dem_expected_lens = [item[2] for item in batch] + #print(dem_expected_actions.size()) + dem_unexpected_states = dgl.batch([item[3] for item in batch][0]) + dem_unexpected_actions = torch.stack([item[4] for item in batch][0]).unsqueeze(dim=0) + dem_unexpected_lens = [item[5] for item in batch] + query_expected_frames = dgl.batch([item[6] for item in batch]) + target_expected_actions = torch.stack([item[7] for item in batch]) + #print(target_expected_actions.size()) + query_unexpected_frames = dgl.batch([item[8] for item in batch]) + target_unexpected_actions = torch.stack([item[9] for item in batch]) + return [ + dem_expected_states, dem_expected_actions, dem_expected_lens, \ + dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, \ + query_expected_frames, target_expected_actions, \ + query_unexpected_frames, target_unexpected_actions + ] + +def collate_function_mental(batch): + dem_frames = dgl.batch([item[0] for item in batch]) + dem_actions = torch.stack([item[1] for item in batch]) + dem_lens = [item[2] for item in batch] + past_test_frames = dgl.batch([item[3] for item in batch]) + past_test_actions = torch.stack([item[4] for item in batch]) + past_test_len = [item[5] for item in batch] + query_frames = dgl.batch([item[6] for item in batch]) + target_actions = torch.stack([item[7] for item in batch]) + return [dem_frames, dem_actions, dem_lens, past_test_frames, past_test_actions, past_test_len, query_frames, target_actions] + + +class ToMnetDGLDataset(DGLDataset): + """ + Training dataset class. 
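+    Each item provides the demonstration (familiarization) trials of one episode together with a
+    single query frame and its target action, sampled at random from the held-out trial.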
+ """ + def __init__(self, path, types=None, mode="train"): + self.path = path + self.types = types + self.mode = mode + print('Mode:', self.mode) + + if self.mode == 'train': + if len(self.types) == 4: + self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/' + #self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_global/' + #self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_local/' + elif len(self.types) == 3: + self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper() + '/' + print(self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper()) + elif len(self.types) == 2: + self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + '/' + print(self.types[0][0].upper() + self.types[1][0].upper()) + elif len(self.types) == 1: + self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + '/' + else: raise ValueError('Number of types different from 1 or 4.') + elif self.mode == 'val': + assert len(self.types) == 1 + self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/' + self.types[0] + '/' + #self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_global/' + self.types[0] + '/' + #self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_local/' + self.types[0] + '/' + else: + raise ValueError + + def get_test(self, states, actions): + # now states is a batched graph -> unbatch it, take the len, pick one sub-graph + # randomly and select the corresponding action + frame_graphs = dgl.unbatch(states) + trial_len = len(frame_graphs) + query_idx = random.randint(0, trial_len - 1) + query_graph = frame_graphs[query_idx] + target_action = actions[query_idx] + return query_graph, target_action + + def __getitem__(self, idx): + with open(self.path+str(idx)+'.pkl', 'rb') as f: + states, actions, lens, _ = pkl.load(f) + # shuffle + ziplist = list(zip(states, actions, lens)) + random.shuffle(ziplist) + states, actions, lens = zip(*ziplist) + # convert tuples to lists + states, actions, lens = [*states], [*actions], [*lens] + # pick last element in the list as test and pick random frame + test_s, test_a = self.get_test(states[-1], actions[-1]) + dem_s = states[:-1] + dem_a = actions[:-1] + dem_lens = lens[:-1] + dem_s = dgl.batch(dem_s) + dem_a = torch.stack(dem_a) + return dem_s, dem_a, dem_lens, test_s, test_a + + def __len__(self): + return len(os.listdir(self.path)) + + +class TestToMnetDGLDataset(DGLDataset): + """ + Testing dataset class. 
+ """ + def __init__(self, path, task_type=None, mode="test"): + self.path = path + self.type = task_type + self.mode = mode + print('Mode:', self.mode) + + if self.mode == 'test': + self.path = self.path + '_dgl_hetero_nobound_4feats/' + self.type + '/' + #self.path = self.path + '_dgl_hetero_nobound_4feats_global/' + self.type + '/' + #self.path = self.path + '_dgl_hetero_nobound_4feats_local/' + self.type + '/' + else: + raise ValueError + + def __getitem__(self, idx): + with open(self.path+str(idx)+'.pkl', 'rb') as f: + dem_expected_states, dem_expected_actions, dem_expected_lens, _, \ + dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, _, \ + query_expected_frames, target_expected_actions, \ + query_unexpected_frames, target_unexpected_actions = pkl.load(f) + assert len(dem_expected_states) == 8 + assert len(dem_expected_actions) == 8 + assert len(dem_expected_lens) == 8 + assert len(dem_unexpected_states) == 8 + assert len(dem_unexpected_actions) == 8 + assert len(dem_unexpected_lens) == 8 + assert len(dgl.unbatch(query_expected_frames)) == target_expected_actions.size()[0] + assert len(dgl.unbatch(query_unexpected_frames)) == target_unexpected_actions.size()[0] + # ignore n_nodes + return dem_expected_states, dem_expected_actions, dem_expected_lens, \ + dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, \ + query_expected_frames, target_expected_actions, \ + query_unexpected_frames, target_unexpected_actions + + def __len__(self): + return len(os.listdir(self.path)) + + +class ToMnetDGLDatasetUndersample(DGLDataset): + """ + Training dataset class for the behavior cloning mlp model. + """ + def __init__(self, path, types=None, mode="train"): + self.path = path + self.types = types + self.mode = mode + print('Mode:', self.mode) + + if self.mode == 'train': + if len(self.types) == 4: + self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/' + elif len(self.types) == 3: + self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper() + '/' + print(self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper()) + elif len(self.types) == 2: + self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + '/' + print(self.types[0][0].upper() + self.types[1][0].upper()) + elif len(self.types) == 1: + self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + '/' + else: raise ValueError('Number of types different from 1 or 4.') + elif self.mode == 'val': + assert len(self.types) == 1 + self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/' + self.types[0] + '/' + else: + raise ValueError + print('Undersampled dataset!') + + def get_test(self, states, actions): + # now states is a batched graph -> unbatch it, take the len, pick one sub-graph + # randomly and select the corresponding action + frame_graphs = dgl.unbatch(states) + trial_len = len(frame_graphs) + query_idx = random.randint(0, trial_len - 1) + query_graph = frame_graphs[query_idx] + target_action = actions[query_idx] + return query_graph, target_action + + def __getitem__(self, idx): + idx = idx + 3175 + with open(self.path+str(idx)+'.pkl', 'rb') as f: + states, actions, lens, _ = pkl.load(f) + # shuffle + ziplist = list(zip(states, actions, lens)) + random.shuffle(ziplist) + states, actions, lens = zip(*ziplist) + # convert tuples to lists + states, actions, lens = [*states], 
[*actions], [*lens] + # pick last element in the list as test and pick random frame + test_s, test_a = self.get_test(states[-1], actions[-1]) + dem_s = states[:-1] + dem_a = actions[:-1] + dem_lens = lens[:-1] + dem_s = dgl.batch(dem_s) + dem_a = torch.stack(dem_a) + return dem_s, dem_a, dem_lens, test_s, test_a + + def __len__(self): + return len(os.listdir(self.path)) - 3175 + + +class ToMnetDGLDatasetMental(DGLDataset): + """ + Training dataset class. + """ + def __init__(self, path, types=None, mode="train"): + self.path = path + self.types = types + self.mode = mode + print('Mode:', self.mode) + + if self.mode == 'train': + if len(self.types) == 4: + self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/' + elif len(self.types) == 3: + self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper() + '/' + print(self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper()) + elif len(self.types) == 2: + self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + '/' + print(self.types[0][0].upper() + self.types[1][0].upper()) + elif len(self.types) == 1: + self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + '/' + else: raise ValueError('Number of types different from 1 or 4.') + elif self.mode == 'val': + assert len(self.types) == 1 + self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/' + self.types[0] + '/' + else: + raise ValueError + + def get_test(self, states, actions): + """ + return: past_test_graphs, past_test_actions, test_graph, test_action + """ + frame_graphs = dgl.unbatch(states) + trial_len = len(frame_graphs) + query_idx = random.randint(0, trial_len - 1) + test_graph = frame_graphs[query_idx] + test_action = actions[query_idx] + if query_idx > 0: + #past_test_actions = F.pad(past_test_actions, (0,0,0,41-query_idx), 'constant', 0) + if query_idx == 1: + past_test_graphs = frame_graphs[0] + past_test_actions = actions[:query_idx] + past_test_actions = F.pad(past_test_actions, (0,0,0,41-query_idx), 'constant', 0) + return past_test_graphs, past_test_actions, query_idx, test_graph, test_action + else: + past_test_graphs = frame_graphs[:query_idx] + past_test_actions = actions[:query_idx] + past_test_graphs = dgl.batch(past_test_graphs) + past_test_actions = F.pad(past_test_actions, (0,0,0,41-query_idx), 'constant', 0) + return past_test_graphs, past_test_actions, query_idx, test_graph, test_action + else: + test_action_ = F.pad(test_action.unsqueeze(0), (0,0,0,41-1), 'constant', 0) + # NOTE: since there are no past frames, return the test frame and action and query_idx=1 + # not sure what would be a better alternative + return test_graph, test_action_, 1, test_graph, test_action + + def __getitem__(self, idx): + with open(self.path+str(idx)+'.pkl', 'rb') as f: + states, actions, lens, _ = pkl.load(f) + ziplist = list(zip(states, actions, lens)) + random.shuffle(ziplist) + states, actions, lens = zip(*ziplist) + states, actions, lens = [*states], [*actions], [*lens] + past_test_s, past_test_a, past_test_len, test_s, test_a = self.get_test(states[-1], actions[-1]) + dem_s = states[:-1] + dem_a = actions[:-1] + dem_lens = lens[:-1] + dem_s = dgl.batch(dem_s) + dem_a = torch.stack(dem_a) + return dem_s, dem_a, dem_lens, past_test_s, past_test_a, past_test_len, test_s, test_a + + def __len__(self): + return len(os.listdir(self.path)) + +# 
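+
+# Example (sketch): how these dataset classes are typically consumed, mirroring
+# GraphBCRNN.train_dataloader in tom/model.py. The path and task types below follow
+# run_train.sh; batch_size is illustrative and may need adjusting.
+def _example_train_loader():
+    from torch.utils.data import DataLoader
+    train_set = ToMnetDGLDataset(
+        path='/datasets/external/bib_train/graphs/all_tasks/',
+        types=['single_object', 'preference', 'instrumental_action'],
+        mode='train',
+    )
+    return DataLoader(
+        train_set,
+        batch_size=32,                    # as in run_train.sh
+        collate_fn=collate_function_seq,  # batches DGL graphs and stacks action tensors
+        shuffle=True,
+    )
+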
-------------------------------------------------------------------------------------------------------- + +if __name__ == "__main__": + + types = [ + 'preference', 'multi_agent', 'inaccessible_goal', + 'efficiency_irrational', 'efficiency_time','efficiency_path', + 'instrumental_no_barrier', 'instrumental_blocking_barrier', 'instrumental_inconsequential_barrier' + ] + + mental_dataset = ToMnetDGLDatasetMental( + path='/datasets/external/bib_train/graphs/all_tasks/', + types=['instrumental_action'], + mode='train' + ) + dem_frames, dem_actions, dem_lens, past_test_frames, past_test_actions, len, test_frame, test_action = mental_dataset.__getitem__(99) + breakpoint() \ No newline at end of file diff --git a/tom/gnn.py b/tom/gnn.py new file mode 100644 index 0000000..10923d4 --- /dev/null +++ b/tom/gnn.py @@ -0,0 +1,877 @@ +import dgl.nn.pytorch as dglnn +import torch.nn as nn +import torch.nn.functional as F +import torch +import dgl +import sys +import copy + +from wandb import agent +sys.path.append('/projects/bortoletto/irene/') +from tom.norm import Norm + + +class RSAGEv4(nn.Module): + # multi-layer GNN for one single feature + def __init__( + self, + hidden_channels, + out_channels, + rel_names, + dropout, + n_layers, + activation=nn.ELU(), + ): + super().__init__() + self.embedding_type = nn.Linear(9, int(hidden_channels)) + self.embedding_pos = nn.Linear(2, int(hidden_channels)) + self.embedding_color = nn.Linear(3, int(hidden_channels)) + self.embedding_shape = nn.Linear(18, int(hidden_channels)) + self.combine_attr = nn.Sequential( + nn.ReLU(), + nn.Linear(hidden_channels*3, hidden_channels) + ) + self.combine_pos = nn.Sequential( + nn.ReLU(), + nn.Linear(hidden_channels*2, hidden_channels) + ) + self.layers = nn.ModuleList([ + dglnn.HeteroGraphConv({ + rel: dglnn.SAGEConv( + in_feats=hidden_channels, + out_feats=hidden_channels if l < n_layers - 1 else out_channels, + aggregator_type='lstm', + feat_drop=dropout, + activation=activation if l < n_layers - 1 else None + ) + for rel in rel_names}, aggregate='sum') + for l in range(n_layers) + ]) + + def forward(self, g): + attr = [] + attr.append(self.embedding_type(g.ndata['type'].float())) + pos = self.embedding_pos(g.ndata['pos']/170.) 
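+        # positions (scaled by 170) are embedded separately and later fused with the combined
+        # type/color/shape embedding through combine_pos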
+ attr.append(self.embedding_color(g.ndata['color']/255.)) + attr.append(self.embedding_shape(g.ndata['shape'].float())) + combined_attr = self.combine_attr(torch.cat(attr, dim=1)) + h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))} + for l, conv in enumerate(self.layers): + h = conv(g, h) + with g.local_scope(): + g.ndata['h'] = h['obj'] + out = dgl.mean_nodes(g, 'h') + return out + +# ------------------------------------------------------------------------------------------- + +class RAGNNv4(nn.Module): + # multi-layer GNN for one single feature + def __init__( + self, + hidden_channels, + out_channels, + rel_names, + dropout, + n_layers, + activation=nn.ELU(), + ): + super().__init__() + self.embedding_type = nn.Linear(9, int(hidden_channels)) + self.embedding_pos = nn.Linear(2, int(hidden_channels)) + self.embedding_color = nn.Linear(3, int(hidden_channels)) + self.embedding_shape = nn.Linear(18, int(hidden_channels)) + self.combine_attr = nn.Sequential( + nn.ReLU(), + nn.Linear(hidden_channels*3, hidden_channels) + ) + self.combine_pos = nn.Sequential( + nn.ReLU(), + nn.Linear(hidden_channels*2, hidden_channels) + ) + self.layers = nn.ModuleList([ + dglnn.HeteroGraphConv({ + rel: dglnn.AGNNConv() + for rel in rel_names}, aggregate='sum') + for l in range(n_layers) + ]) + + def forward(self, g): + attr = [] + attr.append(self.embedding_type(g.ndata['type'].float())) + pos = self.embedding_pos(g.ndata['pos']/170.) + attr.append(self.embedding_color(g.ndata['color']/255.)) + attr.append(self.embedding_shape(g.ndata['shape'].float())) + combined_attr = self.combine_attr(torch.cat(attr, dim=1)) + h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))} + for l, conv in enumerate(self.layers): + h = conv(g, h) + with g.local_scope(): + g.ndata['h'] = h['obj'] + out = dgl.mean_nodes(g, 'h') + return out + +# ------------------------------------------------------------------------------------------- + +class RGATv2(nn.Module): + # multi-layer GNN for one single feature + def __init__( + self, + hidden_channels, + out_channels, + num_heads, + rel_names, + dropout, + n_layers, + activation=nn.ELU(), + residual=False + ): + super().__init__() + self.embedding_type = nn.Embedding(9, int(hidden_channels*num_heads/4)) + self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads/4)) + self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads/4)) + self.embedding_shape = nn.Embedding(18, int(hidden_channels*num_heads/4)) + self.combine = nn.Sequential( + nn.ReLU(), + nn.Linear(hidden_channels*num_heads, hidden_channels*num_heads) + ) + self.layers = nn.ModuleList([ + dglnn.HeteroGraphConv({ + rel: dglnn.GATv2Conv( + in_feats=hidden_channels*num_heads, + out_feats=hidden_channels if l < n_layers - 1 else out_channels, + num_heads=num_heads if l < n_layers - 1 else 1, + feat_drop=dropout, + attn_drop=dropout, + residual=residual, + activation=activation if l < n_layers - 1 else None + ) + for rel in rel_names}, aggregate='sum') + for l in range(n_layers) + ]) + + def forward(self, g): + feats = [] + feats.append(self.embedding_type(torch.argmax(g.ndata['type'], dim=1))) + feats.append(self.embedding_pos(g.ndata['pos']/170.)) + feats.append(self.embedding_color(g.ndata['color']/255.)) + feats.append(self.embedding_shape(torch.argmax(g.ndata['shape'], dim=1))) + h = {'obj': self.combine(torch.cat(feats, dim=1))} + for l, conv in enumerate(self.layers): + h = conv(g, h) + if l != len(self.layers) - 1: + h = {k: v.flatten(1) for k, v in h.items()} + else: + h = 
{k: v.mean(1) for k, v in h.items()} + with g.local_scope(): + g.ndata['h'] = h['obj'] + out = dgl.mean_nodes(g, 'h') + return out + +# ------------------------------------------------------------------------------------------- + +class RGATv3(nn.Module): + # multi-layer GNN for one single feature + def __init__( + self, + hidden_channels, + out_channels, + num_heads, + rel_names, + dropout, + n_layers, + activation=nn.ELU(), + residual=False + ): + super().__init__() + self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads)) + self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads)) + self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads)) + self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads)) + self.combine = nn.Sequential( + nn.ReLU(), + nn.Linear(hidden_channels*num_heads*4, hidden_channels*num_heads) + ) + self.layers = nn.ModuleList([ + dglnn.HeteroGraphConv({ + rel: dglnn.GATv2Conv( + in_feats=hidden_channels*num_heads, + out_feats=hidden_channels if l < n_layers - 1 else out_channels, + num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always + feat_drop=dropout, + attn_drop=dropout, + residual=residual, + activation=activation if l < n_layers - 1 else None + ) + for rel in rel_names}, aggregate='sum') + for l in range(n_layers) + ]) + + def forward(self, g): + feats = [] + feats.append(self.embedding_type(g.ndata['type'].float())) + feats.append(self.embedding_pos(g.ndata['pos']/170.)) # NOTE: this should be 180 because I remove the boundary walls! + feats.append(self.embedding_color(g.ndata['color']/255.)) + feats.append(self.embedding_shape(g.ndata['shape'].float())) + h = {'obj': self.combine(torch.cat(feats, dim=1))} + for l, conv in enumerate(self.layers): + h = conv(g, h) + if l != len(self.layers) - 1: + h = {k: v.flatten(1) for k, v in h.items()} + else: + h = {k: v.mean(1) for k, v in h.items()} + with g.local_scope(): + g.ndata['h'] = h['obj'] + out = dgl.mean_nodes(g, 'h') + return out + +# ------------------------------------------------------------------------------------------- + +class RGCNv2(nn.Module): + # multi-layer GNN for one single feature + def __init__( + self, + hidden_channels, + out_channels, + rel_names, + dropout, + n_layers, + activation=nn.ReLU() + ): + super().__init__() + self.embedding_type = nn.Linear(9, int(hidden_channels)) + self.embedding_pos = nn.Linear(2, int(hidden_channels)) + self.embedding_color = nn.Linear(3, int(hidden_channels)) + self.embedding_shape = nn.Linear(18, int(hidden_channels)) + self.combine = nn.Sequential( + nn.ReLU(), + nn.Linear(hidden_channels*4, hidden_channels) + ) + self.layers = nn.ModuleList([ + dglnn.RelGraphConv( + in_feat=hidden_channels, + out_feat=hidden_channels, + num_rels=len(rel_names), + regularizer=None, + num_bases=None, + bias=True, + activation=activation, + self_loop=True, + dropout=dropout, + layer_norm=False + ) + for _ in range(n_layers-1)]) + self.layers.append( + dglnn.RelGraphConv( + in_feat=hidden_channels, + out_feat=out_channels, + num_rels=len(rel_names), + regularizer=None, + num_bases=None, + bias=True, + activation=activation, + self_loop=True, + dropout=dropout, + layer_norm=False + ) + ) + + def forward(self, g): + g = g.to_homogeneous() + feats = [] + feats.append(self.embedding_type(g.ndata['type'].float())) + feats.append(self.embedding_pos(g.ndata['pos']/170.)) + feats.append(self.embedding_color(g.ndata['color']/255.)) + feats.append(self.embedding_shape(g.ndata['shape'].float())) + h = 
self.combine(torch.cat(feats, dim=1)) + for conv in self.layers: + h = conv(g, h, g.etypes) + with g.local_scope(): + g.ndata['h'] = h + out = dgl.mean_nodes(g, 'h') + return out + +# ------------------------------------------------------------------------------------------- + +class RGATv3Agent(nn.Module): + # multi-layer GNN for one single feature + def __init__( + self, + hidden_channels, + out_channels, + num_heads, + rel_names, + dropout, + n_layers, + activation=nn.ELU(), + residual=False + ): + super().__init__() + self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads)) + self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads)) + self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads)) + self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads)) + self.combine = nn.Sequential( + nn.ReLU(), + nn.Linear(hidden_channels*num_heads*4, hidden_channels*num_heads) + ) + self.layers = nn.ModuleList([ + dglnn.HeteroGraphConv({ + rel: dglnn.GATv2Conv( + in_feats=hidden_channels*num_heads, + out_feats=hidden_channels if l < n_layers - 1 else out_channels, + num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always + feat_drop=dropout, + attn_drop=dropout, + residual=residual, + activation=activation if l < n_layers - 1 else None + ) + for rel in rel_names}, aggregate='sum') + for l in range(n_layers) + ]) + + def forward(self, g): + agent_mask = g.ndata['type'][:, 0] == 1 + feats = [] + feats.append(self.embedding_type(g.ndata['type'].float())) + feats.append(self.embedding_pos(g.ndata['pos']/200.)) + feats.append(self.embedding_color(g.ndata['color']/255.)) + feats.append(self.embedding_shape(g.ndata['shape'].float())) + h = {'obj': self.combine(torch.cat(feats, dim=1))} + for l, conv in enumerate(self.layers): + h = conv(g, h) + if l != len(self.layers) - 1: + h = {k: v.flatten(1) for k, v in h.items()} + else: + h = {k: v.mean(1) for k, v in h.items()} + with g.local_scope(): + g.ndata['h'] = h['obj'] + out = g.ndata['h'][agent_mask, :] + ctx = dgl.mean_nodes(g, 'h') + return out + ctx + +# ------------------------------------------------------------------------------------------- + +class RGATv4(nn.Module): + # multi-layer GNN for one single feature + def __init__( + self, + hidden_channels, + out_channels, + num_heads, + rel_names, + dropout, + n_layers, + activation=nn.ELU(), + residual=False + ): + super().__init__() + self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads)) + self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads)) + self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads)) + self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads)) + self.combine_attr = nn.Sequential( + nn.ReLU(), + nn.Linear(hidden_channels*num_heads*3, hidden_channels*num_heads) + ) + self.combine_pos = nn.Sequential( + nn.ReLU(), + nn.Linear(hidden_channels*num_heads*2, hidden_channels*num_heads) + ) + self.layers = nn.ModuleList([ + dglnn.HeteroGraphConv({ + rel: dglnn.GATv2Conv( + in_feats=hidden_channels*num_heads, + out_feats=hidden_channels if l < n_layers - 1 else out_channels, + num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always + feat_drop=dropout, + attn_drop=dropout, + residual=residual, + share_weights=False, + activation=activation if l < n_layers - 1 else None + ) + for rel in rel_names}, aggregate='sum') + for l in range(n_layers) + ]) + + def forward(self, g): + attr = [] + attr.append(self.embedding_type(g.ndata['type'].float())) + pos = 
self.embedding_pos(g.ndata['pos']/170.) + attr.append(self.embedding_color(g.ndata['color']/255.)) + attr.append(self.embedding_shape(g.ndata['shape'].float())) + combined_attr = self.combine_attr(torch.cat(attr, dim=1)) + h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))} + for l, conv in enumerate(self.layers): + h = conv(g, h) + if l != len(self.layers) - 1: + h = {k: v.flatten(1) for k, v in h.items()} + else: + h = {k: v.mean(1) for k, v in h.items()} + with g.local_scope(): + g.ndata['h'] = h['obj'] + out = dgl.mean_nodes(g, 'h') + return out + +# ------------------------------------------------------------------------------------------- + +class RGATv4Norm(nn.Module): + # multi-layer GNN for one single feature + def __init__( + self, + hidden_channels, + out_channels, + num_heads, + rel_names, + dropout, + n_layers, + activation=nn.ELU(), + residual=False + ): + super().__init__() + self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads)) + self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads)) + self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads)) + self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads)) + self.combine_attr = nn.Sequential( + nn.ReLU(), + nn.Linear(hidden_channels*num_heads*3, hidden_channels*num_heads) + ) + self.combine_pos = nn.Sequential( + nn.ReLU(), + nn.Linear(hidden_channels*num_heads*2, hidden_channels*num_heads) + ) + self.layers = nn.ModuleList([ + dglnn.HeteroGraphConv({ + rel: dglnn.GATv2Conv( + in_feats=hidden_channels*num_heads, + out_feats=hidden_channels if l < n_layers - 1 else out_channels, + num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always + feat_drop=dropout, + attn_drop=dropout, + residual=residual, + activation=activation if l < n_layers - 1 else None + ) + for rel in rel_names}, aggregate='sum') + for l in range(n_layers) + ]) + self.norms = nn.ModuleList([ + Norm( + norm_type='gn', + hidden_dim=hidden_channels*num_heads if l < n_layers - 1 else out_channels + ) + for l in range(n_layers) + ]) + + def forward(self, g): + attr = [] + attr.append(self.embedding_type(g.ndata['type'].float())) + pos = self.embedding_pos(g.ndata['pos']/170.) 
+ attr.append(self.embedding_color(g.ndata['color']/255.)) + attr.append(self.embedding_shape(g.ndata['shape'].float())) + combined_attr = self.combine_attr(torch.cat(attr, dim=1)) + h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))} + for l, conv in enumerate(self.layers): + h = conv(g, h) + if l != len(self.layers) - 1: + h = {k: v.flatten(1) for k, v in h.items()} + else: + h = {k: v.mean(1) for k, v in h.items()} + h = {k: self.norms[l](g, v) for k, v in h.items()} + with g.local_scope(): + g.ndata['h'] = h['obj'] + out = dgl.mean_nodes(g, 'h') + return out + +# ------------------------------------------------------------------------------------------- + +class RGATv3Norm(nn.Module): + # multi-layer GNN for one single feature + def __init__( + self, + hidden_channels, + out_channels, + num_heads, + rel_names, + dropout, + n_layers, + activation=nn.ELU(), + residual=False + ): + super().__init__() + self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads)) + self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads)) + self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads)) + self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads)) + self.combine = nn.Sequential( + nn.ReLU(), + nn.Linear(hidden_channels*num_heads*4, hidden_channels*num_heads) + ) + self.layers = nn.ModuleList([ + dglnn.HeteroGraphConv({ + rel: dglnn.GATv2Conv( + in_feats=hidden_channels*num_heads, + out_feats=hidden_channels if l < n_layers - 1 else out_channels, + num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always + feat_drop=dropout, + attn_drop=dropout, + residual=residual, + activation=activation if l < n_layers - 1 else None + ) + for rel in rel_names}, aggregate='sum') + for l in range(n_layers) + ]) + self.norms = nn.ModuleList([ + Norm( + norm_type='gn', + hidden_dim=hidden_channels*num_heads if l < n_layers - 1 else out_channels + ) + for l in range(n_layers) + ]) + + def forward(self, g): + feats = [] + feats.append(self.embedding_type(g.ndata['type'].float())) + feats.append(self.embedding_pos(g.ndata['pos']/170.)) + feats.append(self.embedding_color(g.ndata['color']/255.)) + feats.append(self.embedding_shape(g.ndata['shape'].float())) + h = {'obj': self.combine(torch.cat(feats, dim=1))} + for l, conv in enumerate(self.layers): + h = conv(g, h) + if l != len(self.layers) - 1: + h = {k: v.flatten(1) for k, v in h.items()} + else: + h = {k: v.mean(1) for k, v in h.items()} + h = {k: self.norms[l](g, v) for k, v in h.items()} + with g.local_scope(): + g.ndata['h'] = h['obj'] + out = dgl.mean_nodes(g, 'h') + return out + +# ------------------------------------------------------------------------------------------- + +class RGATv4Agent(nn.Module): + # multi-layer GNN for one single feature + def __init__( + self, + hidden_channels, + out_channels, + num_heads, + rel_names, + dropout, + n_layers, + activation=nn.ELU(), + residual=False + ): + super().__init__() + self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads)) + self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads)) + self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads)) + self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads)) + self.combine_attr = nn.Sequential( + nn.ReLU(), + nn.Linear(hidden_channels*num_heads*3, hidden_channels*num_heads) + ) + self.combine_pos = nn.Sequential( + nn.ReLU(), + nn.Linear(hidden_channels*num_heads*2, hidden_channels*num_heads) + ) + self.layers = nn.ModuleList([ + 
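+            # one relation-aware GATv2 block per layer; the last layer maps to out_channels with a single head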
dglnn.HeteroGraphConv({ + rel: dglnn.GATv2Conv( + in_feats=hidden_channels*num_heads, + out_feats=hidden_channels if l < n_layers - 1 else out_channels, + num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always + feat_drop=dropout, + attn_drop=dropout, + residual=residual, + activation=activation if l < n_layers - 1 else None + ) + for rel in rel_names}, aggregate='sum') + for l in range(n_layers) + ]) + self.combine_agent_context = nn.Linear(out_channels*2, out_channels) + + def forward(self, g): + agent_mask = g.ndata['type'][:, 0] == 1 + attr = [] + attr.append(self.embedding_type(g.ndata['type'].float())) + pos = self.embedding_pos(g.ndata['pos']/170.) + attr.append(self.embedding_color(g.ndata['color']/255.)) + attr.append(self.embedding_shape(g.ndata['shape'].float())) + combined_attr = self.combine_attr(torch.cat(attr, dim=1)) + h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))} + for l, conv in enumerate(self.layers): + h = conv(g, h) + if l != len(self.layers) - 1: + h = {k: v.flatten(1) for k, v in h.items()} + else: + h = {k: v.mean(1) for k, v in h.items()} + with g.local_scope(): + g.ndata['h'] = h['obj'] + h_a = g.ndata['h'][agent_mask, :] + g_no_agent = copy.deepcopy(g) + g_no_agent.remove_nodes([i for i, x in enumerate(agent_mask) if x]) + h_g = dgl.mean_nodes(g_no_agent, 'h') + out = self.combine_agent_context(torch.cat((h_a, h_g), dim=1)) + return out + +# ------------------------------------------------------------------------------------------- + +class RGCNv4(nn.Module): + # multi-layer GNN for one single feature + def __init__( + self, + hidden_channels, + out_channels, + rel_names, + n_layers, + activation=nn.ELU(), + ): + super().__init__() + self.embedding_type = nn.Linear(9, int(hidden_channels)) + self.embedding_pos = nn.Linear(2, int(hidden_channels)) + self.embedding_color = nn.Linear(3, int(hidden_channels)) + self.embedding_shape = nn.Linear(18, int(hidden_channels)) + self.combine_attr = nn.Sequential( + nn.ReLU(), + nn.Linear(hidden_channels*3, hidden_channels) + ) + self.combine_pos = nn.Sequential( + nn.ReLU(), + nn.Linear(hidden_channels*2, hidden_channels) + ) + self.layers = nn.ModuleList([ + dglnn.HeteroGraphConv({ + rel: dglnn.GraphConv( + in_feats=hidden_channels, + out_feats=hidden_channels if l < n_layers - 1 else out_channels, + activation=activation if l < n_layers - 1 else None + ) + for rel in rel_names}, aggregate='sum') + for l in range(n_layers) + ]) + + def forward(self, g): + attr = [] + attr.append(self.embedding_type(g.ndata['type'].float())) + pos = self.embedding_pos(g.ndata['pos']/170.) 
+ attr.append(self.embedding_color(g.ndata['color']/255.)) + attr.append(self.embedding_shape(g.ndata['shape'].float())) + combined_attr = self.combine_attr(torch.cat(attr, dim=1)) + h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))} + for l, conv in enumerate(self.layers): + h = conv(g, h) + with g.local_scope(): + g.ndata['h'] = h['obj'] + out = dgl.mean_nodes(g, 'h') + return out + +# ------------------------------------------------------------------------------------------- + +class RGATv5(nn.Module): + + def __init__( + self, + hidden_channels, + out_channels, + num_heads, + rel_names, + dropout, + n_layers, + activation=nn.ELU(), + residual=False + ): + super().__init__() + self.hidden_channels = hidden_channels + self.num_heads = num_heads + self.embedding_type = nn.Linear(9, hidden_channels*num_heads) + self.embedding_pos = nn.Linear(2, hidden_channels*num_heads) + self.embedding_color = nn.Linear(3, hidden_channels*num_heads) + self.embedding_shape = nn.Linear(18, hidden_channels*num_heads) + self.combine = nn.Sequential( + nn.ReLU(), + nn.Linear(hidden_channels*num_heads*4, hidden_channels*num_heads) + ) + self.attention = nn.Linear(hidden_channels*num_heads*4, 4) + self.layers = nn.ModuleList([ + dglnn.HeteroGraphConv({ + rel: dglnn.GATv2Conv( + in_feats=hidden_channels*num_heads, + out_feats=hidden_channels if l < n_layers - 1 else out_channels, + num_heads=num_heads, + feat_drop=dropout, + attn_drop=dropout, + residual=residual, + activation=activation if l < n_layers - 1 else None + ) + for rel in rel_names}, aggregate='sum') + for l in range(n_layers) + ]) + + def forward(self, g): + feats = [] + feats.append(self.embedding_type(g.ndata['type'].float())) + feats.append(self.embedding_pos(g.ndata['pos']/170.)) + feats.append(self.embedding_color(g.ndata['color']/255.)) + feats.append(self.embedding_shape(g.ndata['shape'].float())) + h = torch.cat(feats, dim=1) + feat_attn = F.softmax(self.attention(h), dim=1) + h = h * feat_attn.repeat_interleave(self.hidden_channels*self.num_heads, dim=1) + h_in = self.combine(h) + h = {'obj': h_in} + for conv in self.layers: + h = conv(g, h) + h = {k: v.flatten(1) for k, v in h.items()} + #if l != len(self.layers) - 1: + # h = {k: v.flatten(1) for k, v in h.items()} + #else: + # h = {k: v.mean(1) for k, v in h.items()} + h = {k: v + h_in for k, v in h.items()} + with g.local_scope(): + g.ndata['h'] = h['obj'] + out = dgl.mean_nodes(g, 'h') + return out + +# ------------------------------------------------------------------------------------------- + +class RGATv6(nn.Module): + + # RGATv6 = RGATv4 + Global Attention Pooling + + def __init__( + self, + hidden_channels, + out_channels, + num_heads, + rel_names, + dropout, + n_layers, + activation=nn.ELU(), + residual=False + ): + super().__init__() + self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads)) + self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads)) + self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads)) + self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads)) + self.combine_attr = nn.Sequential( + nn.ReLU(), + nn.Linear(hidden_channels*num_heads*3, hidden_channels*num_heads) + ) + self.combine_pos = nn.Sequential( + nn.ReLU(), + nn.Linear(hidden_channels*num_heads*2, hidden_channels*num_heads) + ) + self.layers = nn.ModuleList([ + dglnn.HeteroGraphConv({ + rel: dglnn.GATv2Conv( + in_feats=hidden_channels*num_heads, + out_feats=hidden_channels if l < n_layers - 1 else out_channels, + num_heads=num_heads if l < 
n_layers - 1 else 1, # TODO: change to num_heads always + feat_drop=dropout, + attn_drop=dropout, + residual=residual, + activation=activation if l < n_layers - 1 else None + ) + for rel in rel_names}, aggregate='sum') + for l in range(n_layers) + ]) + gate_nn = nn.Linear(out_channels, 1) + self.gap = dglnn.GlobalAttentionPooling(gate_nn) + + def forward(self, g): + attr = [] + attr.append(self.embedding_type(g.ndata['type'].float())) + pos = self.embedding_pos(g.ndata['pos']/170.) + attr.append(self.embedding_color(g.ndata['color']/255.)) + attr.append(self.embedding_shape(g.ndata['shape'].float())) + combined_attr = self.combine_attr(torch.cat(attr, dim=1)) + h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))} + for l, conv in enumerate(self.layers): + h = conv(g, h) + if l != len(self.layers) - 1: + h = {k: v.flatten(1) for k, v in h.items()} + else: + h = {k: v.mean(1) for k, v in h.items()} + #with g.local_scope(): + #g.ndata['h'] = h['obj'] + #out = dgl.mean_nodes(g, 'h') + out = self.gap(g, h['obj']) + return out + +# ------------------------------------------------------------------------------------------- + +class RGATv6Agent(nn.Module): + + # RGATv6 = RGATv4 + Global Attention Pooling + + def __init__( + self, + hidden_channels, + out_channels, + num_heads, + rel_names, + dropout, + n_layers, + activation=nn.ELU(), + residual=False + ): + super().__init__() + self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads)) + self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads)) + self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads)) + self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads)) + self.combine_attr = nn.Sequential( + nn.ReLU(), + nn.Linear(hidden_channels*num_heads*3, hidden_channels*num_heads) + ) + self.combine_pos = nn.Sequential( + nn.ReLU(), + nn.Linear(hidden_channels*num_heads*2, hidden_channels*num_heads) + ) + self.layers = nn.ModuleList([ + dglnn.HeteroGraphConv({ + rel: dglnn.GATv2Conv( + in_feats=hidden_channels*num_heads, + out_feats=hidden_channels if l < n_layers - 1 else out_channels, + num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always + feat_drop=dropout, + attn_drop=dropout, + residual=residual, + activation=activation if l < n_layers - 1 else None + ) + for rel in rel_names}, aggregate='sum') + for l in range(n_layers) + ]) + gate_nn = nn.Linear(out_channels, 1) + self.gap = dglnn.GlobalAttentionPooling(gate_nn) + self.combine_agent_context = nn.Linear(out_channels*2, out_channels) + + def forward(self, g): + agent_mask = g.ndata['type'][:, 0] == 1 + attr = [] + attr.append(self.embedding_type(g.ndata['type'].float())) + pos = self.embedding_pos(g.ndata['pos']/170.) 
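+        # the agent node (selected via agent_mask above) is later concatenated with an
+        # attention-pooled embedding of the remaining nodes (see combine_agent_context below)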
+ attr.append(self.embedding_color(g.ndata['color']/255.)) + attr.append(self.embedding_shape(g.ndata['shape'].float())) + combined_attr = self.combine_attr(torch.cat(attr, dim=1)) + h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))} + for l, conv in enumerate(self.layers): + h = conv(g, h) + if l != len(self.layers) - 1: + h = {k: v.flatten(1) for k, v in h.items()} + else: + h = {k: v.mean(1) for k, v in h.items()} + with g.local_scope(): + g.ndata['h'] = h['obj'] + h_a = g.ndata['h'][agent_mask, :] + h_g = g.ndata['h'][~agent_mask, :] + g_no_agent = copy.deepcopy(g) + g_no_agent.remove_nodes([i for i, x in enumerate(agent_mask) if x]) + h_g = self.gap(g_no_agent, h_g) # dgl.mean_nodes(g_no_agent, 'h') + out = self.combine_agent_context(torch.cat((h_a, h_g), dim=1)) + return out + +# ------------------------------------------------------------------------------------------- + diff --git a/tom/model.py b/tom/model.py new file mode 100644 index 0000000..14e0bb5 --- /dev/null +++ b/tom/model.py @@ -0,0 +1,513 @@ +from argparse import ArgumentParser +import torch +import torch.nn as nn +import torch.nn.functional as F +import pytorch_lightning as pl +from pytorch_lightning.callbacks import ModelCheckpoint +from torch.utils.data import DataLoader + +from tom.dataset import * +from tom.transformer import TransformerEncoder +from tom.gnn import RGATv2, RGATv3, RGATv3Agent, RGATv4, RGATv4Norm, RSAGEv4, RAGNNv4 + + +class MlpModel(nn.Module): + """Multilayer Perceptron with last layer linear. + Args: + input_size (int): number of inputs + hidden_sizes (list): can be empty list for none (linear model). + output_size: linear layer at output, or if ``None``, the last hidden size + will be the output size and will have nonlinearity applied + nonlinearity: torch nonlinearity Module (not Functional). + """ + + def __init__( + self, + input_size, + hidden_sizes, # Can be empty list or None for none. + output_size=None, # if None, last layer has nonlinearity applied. + nonlinearity=nn.ReLU, # Module, not Functional. + dropout=None # Dropout value + ): + super().__init__() + if isinstance(hidden_sizes, int): + hidden_sizes = [hidden_sizes] + elif hidden_sizes is None: + hidden_sizes = [] + hidden_layers = [nn.Linear(n_in, n_out) for n_in, n_out in + zip([input_size] + hidden_sizes[:-1], hidden_sizes)] + sequence = list() + for i, layer in enumerate(hidden_layers): + if dropout is not None: + sequence.extend([layer, nonlinearity(), nn.Dropout(dropout)]) + else: + sequence.extend([layer, nonlinearity()]) + + if output_size is not None: + last_size = hidden_sizes[-1] if hidden_sizes else input_size + sequence.append(torch.nn.Linear(last_size, output_size)) + self.model = nn.Sequential(*sequence) + self._output_size = (hidden_sizes[-1] if output_size is None + else output_size) + + def forward(self, input): + """Compute the model on the input, assuming input shape [B,input_size].""" + return self.model(input) + + @property + def output_size(self): + """Retuns the output size of the model.""" + return self._output_size + +# --------------------------------------------------------------------------------------------------------------------------------- + +class GraphBCRNN(pl.LightningModule): + """ + Implementation of the baseline model for the BC-RNN algorithm. 
+ R-GCN + LSTM are used to encode the familiarization trials + """ + @staticmethod + def add_model_specific_args(parent_parser): + parser = ArgumentParser(parents=[parent_parser], add_help=False) + parser.add_argument('--lr', type=float, default=1e-4) + parser.add_argument('--action_dim', type=int, default=2) + parser.add_argument('--context_dim', type=int, default=32) # lstm hidden size + parser.add_argument('--beta', type=float, default=0.01) + parser.add_argument('--dropout', type=float, default=0.2) + parser.add_argument('--process_data', type=int, default=0) + parser.add_argument('--max_len', type=int, default=30) + # arguments for gnn + parser.add_argument('--gnn_type', type=str, default='RGATv4') + parser.add_argument('--state_dim', type=int, default=128) # gnn out_feats + parser.add_argument('--feats_dims', type=list, default=[9, 2, 3, 18]) + parser.add_argument('--aggregation', type=str, default='sum') + # arguments for mpl + #parser.add_argument('--mpl_hid_feats', type=list, default=[256, 64, 16]) + return parser + + def __init__(self, hparams): + super().__init__() + + self.hparams = hparams + self.lr = self.hparams.lr + self.state_dim = self.hparams.state_dim + self.action_dim = self.hparams.action_dim + self.context_dim = self.hparams.context_dim + self.beta = self.hparams.beta + self.dropout = self.hparams.dropout + self.max_len = self.hparams.max_len + self.feats_dims = self.hparams.feats_dims # type, position, color, shape + self.rel_names = [ + 'is_aligned', 'is_back', 'is_close', 'is_down_adj', 'is_down_left_adj', + 'is_down_right_adj', 'is_front', 'is_left', 'is_left_adj', 'is_right', + 'is_right_adj', 'is_top_adj', 'is_top_left_adj', 'is_top_right_adj' + ] + self.gnn_aggregation = self.hparams.aggregation + + if self.hparams.gnn_type == 'RGATv2': + self.gnn_encoder = RGATv2( + hidden_channels=self.state_dim, + out_channels=self.state_dim, + num_heads=4, + rel_names=self.rel_names, + dropout=0.0, + n_layers=2, + activation=nn.ELU(), + residual=False + ) + if self.hparams.gnn_type == 'RGATv3': + self.gnn_encoder = RGATv3( + hidden_channels=self.state_dim, + out_channels=self.state_dim, + num_heads=4, + rel_names=self.rel_names, + dropout=0.0, + n_layers=2, + activation=nn.ELU(), + residual=False + ) + if self.hparams.gnn_type == 'RGATv4': + self.gnn_encoder = RGATv4( + hidden_channels=self.state_dim, + out_channels=self.state_dim, + num_heads=4, + rel_names=self.rel_names, + dropout=0.0, + n_layers=2, + activation=nn.ELU(), + residual=False + ) + if self.hparams.gnn_type == 'RGATv3Agent': + self.gnn_encoder = RGATv3Agent( + hidden_channels=self.state_dim, + out_channels=self.state_dim, + num_heads=4, + rel_names=self.rel_names, + dropout=0.0, + n_layers=2, + activation=nn.ELU(), + residual=False + ) + if self.hparams.gnn_type == 'RSAGEv4': + self.gnn_encoder = RSAGEv4( + hidden_channels=self.state_dim, + out_channels=self.state_dim, + rel_names=self.rel_names, + dropout=0.0, + n_layers=2, + activation=nn.ELU() + ) + if self.gnn_aggregation == 'cat_axis_1': + self.lstm_input_size = self.state_dim * len(self.feats_dims) + self.action_dim + self.mlp_input_size = self.state_dim * len(self.feats_dims) + self.context_dim * 2 + elif self.gnn_aggregation == 'sum': + self.lstm_input_size = self.state_dim + self.action_dim + self.mlp_input_size = self.state_dim + self.context_dim * 2 + else: + raise ValueError('Fix this') + + self.context_enc = nn.LSTM(self.lstm_input_size, self.context_dim, 2, + batch_first=True, bidirectional=True) + + self.policy = 
MlpModel(input_size=self.mlp_input_size, hidden_sizes=[256, 128, 256], + output_size=self.action_dim, dropout=self.dropout) + + self.past_samples = [] + + def forward(self, batch): + dem_frames, dem_actions, dem_lens, query_frame, target_action = batch + dem_actions = dem_actions.float() + target_action = target_action.float() + dem_states = self.gnn_encoder(dem_frames) # torch.Size([number of frames, 128 * number of features if cat_axis_1]) + # segment according the number of frames in each episode and pad with zeros + # to obtain tensors of shape [batch size, num of trials (8), max num of frames (30), hidden dim] + b, l, s, _ = dem_actions.size() + dem_states_new = [] + for batch in range(b): + dem_states_new.append(self._sequence_to_padding(dem_states, dem_lens[batch], self.max_len)) + dem_states_new = torch.stack(dem_states_new).to(self.device) # torch.Size([batchsize, 8, 30, 128 * number of features if cat_axis_1]) + # concatenate states and actions to get expert trajectory + dem_states_new = dem_states_new.view(b * l, s, -1) # torch.Size([batchsize*8, 30, 128 * number of features if cat_axis_1]) + dem_actions = dem_actions.view(b * l, s, -1) # torch.Size([batchsize*8, 30, 128]) + dem_traj = torch.cat([dem_states_new, dem_actions], dim=2) # torch.Size([batchsize*8, 30, 2 + 128 * number of features if cat_axis_1]) + # embed expert trajectory to get a context embedding batch x samples x dim + dem_lens = torch.tensor([t for dl in dem_lens for t in dl]).to(torch.int64).cpu() + dem_lens = dem_lens.view(-1) + x_lstm = nn.utils.rnn.pack_padded_sequence(dem_traj, dem_lens, batch_first=True, enforce_sorted=False) + x_lstm, _ = self.context_enc(x_lstm) + x_lstm, _ = nn.utils.rnn.pad_packed_sequence(x_lstm, batch_first=True) + x_out = x_lstm[:, -1, :] + x_out = x_out.view(b, l, -1) + context = torch.mean(x_out, dim=1) # torch.Size([batchsize, 64]) (64=32*2) + # concat context embedding to the state embedding of test trajectory + test_states = self.gnn_encoder(query_frame) # torch.Size([2, 128]) + test_context_states = torch.cat([context, test_states], dim=1) # torch.Size([batchsize, 192]) 192=64+128 + # for each state in the test states calculate action + test_actions_pred = torch.tanh(self.policy(test_context_states)) # torch.Size([batchsize, 2]) + return target_action, test_actions_pred + + def _sequence_to_padding(self, x, lengths, max_length): + # declare the shape, it can work for x of any shape. 
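+        # split the flat per-frame tensor into len(lengths) trials and zero-pad each trial to max_length,
+        # e.g. [sum(dem_lens), D] with 8 trials and max_length=30 -> [8, 30, D]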
+ ret_tensor = torch.zeros((len(lengths), max_length) + tuple(x.shape[1:])) + cum_len = 0 + for i, l in enumerate(lengths): + ret_tensor[i, :l] = x[cum_len: cum_len+l] + cum_len += l + return ret_tensor + + def training_step(self, batch, batch_idx): + test_actions, test_actions_pred = self.forward(batch) + loss = F.mse_loss(test_actions, test_actions_pred) + self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) + return loss + + def validation_step(self, batch, batch_idx, dataloader_idx=0): + test_actions, test_actions_pred = self.forward(batch) + loss = F.mse_loss(test_actions, test_actions_pred) + self.log('val_loss', loss, on_epoch=True, logger=True) + + def configure_optimizers(self): + optim = torch.optim.Adam(self.parameters(), lr=self.lr) + return optim + #optim = torch.optim.AdamW(self.parameters(), lr=self.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01) + #scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optim, gamma=0.8) + #return [optim], [scheduler] + + def train_dataloader(self): + train_dataset = ToMnetDGLDataset(path=self.hparams.data_path, + types=self.hparams.types, + mode='train') + train_loader = DataLoader(dataset=train_dataset, + batch_size=self.hparams.batch_size, + collate_fn=collate_function_seq, + num_workers=self.hparams.num_workers, + #pin_memory=True, + shuffle=True) + return train_loader + + def val_dataloader(self): + val_datasets = [] + val_loaders = [] + for t in self.hparams.types: + val_datasets.append(ToMnetDGLDataset(path=self.hparams.data_path, + types=[t], + mode='val')) + val_loaders.append(DataLoader(dataset=val_datasets[-1], + batch_size=self.hparams.batch_size, + collate_fn=collate_function_seq, + num_workers=self.hparams.num_workers, + #pin_memory=True, + shuffle=False)) + return val_loaders + + def configure_callbacks(self): + checkpoint = ModelCheckpoint( + dirpath=None, # automatically set + #filename=self.params['bc_model']+'-'+self.params['gnn_type']+'-'+self.gnn_params['feats_aggr']+'-{epoch:02d}', + save_top_k=-1, + period=1 + ) + return [checkpoint] + +# --------------------------------------------------------------------------------------------------------------------------------- + +class GraphBC_T(pl.LightningModule): + """ + BC model with GraphTrans encoder, LSTM and MLP. 
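+    The familiarization trials are first encoded with a relational GNN; the per-frame
+    embeddings of each trial are then summarized by a Transformer encoder via a CLS
+    token, and an MLP policy head predicts the action for the query frame.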
+ """ + @staticmethod + def add_model_specific_args(parent_parser): + parser = ArgumentParser(parents=[parent_parser], add_help=False) + parser.add_argument('--lr', type=float, default=1e-4) + parser.add_argument('--action_dim', type=int, default=2) + parser.add_argument('--context_dim', type=int, default=32) # lstm hidden size + parser.add_argument('--beta', type=float, default=0.01) + parser.add_argument('--dropout', type=float, default=0.2) + parser.add_argument('--process_data', type=int, default=0) + parser.add_argument('--max_len', type=int, default=30) + # arguments for gnn + parser.add_argument('--state_dim', type=int, default=128) # gnn out_feats + parser.add_argument('--feats_dims', type=list, default=[9, 2, 3, 18]) + parser.add_argument('--aggregation', type=str, default='cat_axis_1') + parser.add_argument('--gnn_type', type=str, default='RGATv3') + # arguments for mpl + #parser.add_argument('--mpl_hid_feats', type=list, default=[256, 64, 16]) + # arguments for transformer + parser.add_argument('--d_model', type=int, default=128) + parser.add_argument('--nhead', type=int, default=4) + parser.add_argument('--dim_feedforward', type=int, default=512) + parser.add_argument('--transformer_dropout', type=float, default=0.3) + parser.add_argument('--transformer_activation', type=str, default='gelu') + parser.add_argument('--num_encoder_layers', type=int, default=6) + parser.add_argument('--transformer_norm_input', type=int, default=0) + return parser + + def __init__(self, hparams): + super().__init__() + + self.hparams = hparams + self.lr = self.hparams.lr + self.state_dim = self.hparams.state_dim + self.action_dim = self.hparams.action_dim + self.context_dim = self.hparams.context_dim + self.beta = self.hparams.beta + self.dropout = self.hparams.dropout + self.max_len = self.hparams.max_len + self.feats_dims = self.hparams.feats_dims # type, position, color, shape + + self.rel_names = [ + 'is_aligned', 'is_back', 'is_close', 'is_down_adj', 'is_down_left_adj', + 'is_down_right_adj', 'is_front', 'is_left', 'is_left_adj', 'is_right', + 'is_right_adj', 'is_top_adj', 'is_top_left_adj', 'is_top_right_adj' + ] + self.gnn_aggregation = self.hparams.aggregation + if self.hparams.gnn_type == 'RGATv3': + self.gnn_encoder = RGATv3( + hidden_channels=self.state_dim, + out_channels=self.state_dim, + num_heads=4, + rel_names=self.rel_names, + dropout=0.0, + n_layers=2, + activation=nn.ELU(), + residual=False + ) + if self.hparams.gnn_type == 'RGATv4': + self.gnn_encoder = RGATv4( + hidden_channels=self.state_dim, + out_channels=self.state_dim, + num_heads=4, + rel_names=self.rel_names, + dropout=0.0, + n_layers=2, + activation=nn.ELU(), + residual=False + ) + if self.hparams.gnn_type == 'RSAGEv4': + self.gnn_encoder = RSAGEv4( + hidden_channels=self.state_dim, + out_channels=self.state_dim, + rel_names=self.rel_names, + dropout=0.0, + n_layers=2, + activation=nn.ELU() + ) + if self.hparams.gnn_type == 'RAGNNv4': + self.gnn_encoder = RAGNNv4( + hidden_channels=self.state_dim, + out_channels=self.state_dim, + rel_names=self.rel_names, + dropout=0.0, + n_layers=2, + activation=nn.ELU() + ) + if self.hparams.gnn_type == 'RGATv4Norm': + self.gnn_encoder = RGATv4Norm( + hidden_channels=self.state_dim, + out_channels=self.state_dim, + num_heads=4, + rel_names=self.rel_names, + dropout=0.0, + n_layers=2, + activation=nn.ELU(), + residual=True + ) + self.d_model = self.hparams.d_model + self.nhead = self.hparams.nhead + self.dim_feedforward = self.hparams.dim_feedforward + self.transformer_dropout = 
self.hparams.transformer_dropout + self.transformer_activation = self.hparams.transformer_activation + self.num_encoder_layers = self.hparams.num_encoder_layers + self.transformer_norm_input = self.hparams.transformer_norm_input + self.context_enc = TransformerEncoder( + self.d_model, + self.nhead, + self.dim_feedforward, + self.transformer_dropout, + self.transformer_activation, + self.num_encoder_layers, + self.max_len, + self.transformer_norm_input + ) + + if self.gnn_aggregation == 'cat_axis_1': + self.gnn2transformer = nn.Linear(self.state_dim * len(self.feats_dims) + self.action_dim, self.d_model) + self.mlp_input_size = len(self.feats_dims) * self.state_dim + self.d_model + elif self.gnn_aggregation == 'sum': + self.gnn2transformer = nn.Linear(self.state_dim + self.action_dim, self.d_model) + self.mlp_input_size = self.state_dim + self.d_model + else: + raise ValueError('Only sum and cat1 aggregations are available.') + + self.policy = MlpModel(input_size=self.mlp_input_size, hidden_sizes=[256, 128, 256], + output_size=self.action_dim, dropout=self.dropout) + + # CLS Embedding parameters, requires_grad=True + self.embedding = nn.Embedding(self.max_len + 1, self.d_model) # + 1 cause of cls token + self.emb_layer_norm = nn.LayerNorm(self.d_model) + self.emb_dropout = nn.Dropout(p=self.transformer_dropout) + + def forward(self, batch): + dem_frames, dem_actions, dem_lens, query_frame, target_action = batch + dem_actions = dem_actions.float() + target_action = target_action.float() + dem_states = self.gnn_encoder(dem_frames) + b, l, s, _ = dem_actions.size() + dem_lens = torch.tensor([t for dl in dem_lens for t in dl]).to(torch.int64).cpu() + dem_lens = dem_lens.view(-1) + dem_actions_packed = torch.nn.utils.rnn.pack_padded_sequence(dem_actions.view(b*l, s, -1), dem_lens, batch_first=True, enforce_sorted=False)[0] + dem_traj = torch.cat([dem_states, dem_actions_packed], dim=1) + h_node = self.gnn2transformer(dem_traj) + hidden_dim = h_node.size()[1] + padded_trajectory = torch.zeros(b*l, self.max_len, hidden_dim).to(self.device) + j = 0 + for idx, i in enumerate(dem_lens): + padded_trajectory[idx][:i] = h_node[j:j+i] + j += i + mask = self.make_mask(padded_trajectory).to(self.device) + transformer_input = padded_trajectory.transpose(0, 1) # [30, 16, 128] + # add cls: + cls_embedding = nn.Parameter(torch.randn([1, 1, self.d_model], requires_grad=True)).expand(1, b*l, -1).to(self.device) + transformer_input = torch.cat([transformer_input, cls_embedding], dim=0) # [31, 16, 128] + zeros = mask.data.new(mask.size(0), 1).fill_(0) + mask = torch.cat([mask, zeros], dim=1) + # Embed + indices = torch.arange(self.max_len + 1, dtype=torch.int).to(self.device) # + 1 cause of cls [0, 1, ..., 30] + positional_embeddings = self.embedding(indices).unsqueeze(1) # torch.Size([31, 1, 128]) + #generate transformer input + pe_input = positional_embeddings + transformer_input # torch.Size([31, 16, 128]) + # Layernorm and dropout + transformer_in = self.emb_dropout(self.emb_layer_norm(pe_input)) # torch.Size([31, 16, 128]) + # transformer encoding and output parsing + out, _ = self.context_enc(transformer_in, mask) # [31, 16, 128] + cls = out[-1] + cls = cls.view(b, l, -1) # 2, 8, 128 + context = torch.mean(cls, dim=1) + # CLASSIFICATION + test_states = self.gnn_encoder(query_frame) # torch.Size([2, 512]) + test_context_states = torch.cat([context, test_states], dim=1) # torch.Size([batchsize, hidden_dim + lstm_hidden_dim]) 192=512+128 + # for each state in the test states calculate action + x = 
self.policy(test_context_states) + test_actions_pred = torch.tanh(x) # torch.Size([batchsize, 2]) + #test_actions_pred = torch.tanh(self.policy(test_context_states)) # torch.Size([batchsize, 2]) + return target_action, test_actions_pred + + def make_mask(self, feature): + return (torch.sum( + torch.abs(feature), + dim=-1 + ) == 0)#.unsqueeze(1).unsqueeze(2) + + def training_step(self, batch, batch_idx): + test_actions, test_actions_pred = self.forward(batch) + loss = F.mse_loss(test_actions, test_actions_pred) + self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) + return loss + + def validation_step(self, batch, batch_idx, dataloader_idx=0): + test_actions, test_actions_pred = self.forward(batch) + loss = F.mse_loss(test_actions, test_actions_pred) + self.log('val_loss', loss, on_epoch=True, logger=True) + + def configure_optimizers(self): + optim = torch.optim.Adam(self.parameters(), lr=self.lr) + return optim + #optim = torch.optim.AdamW(self.parameters(), lr=self.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01) + #scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optim, gamma=0.96) + #return [optim], [scheduler] + + def train_dataloader(self): + train_dataset = ToMnetDGLDataset(path=self.hparams.data_path, + types=self.hparams.types, + mode='train') + train_loader = DataLoader(dataset=train_dataset, + batch_size=self.hparams.batch_size, + collate_fn=collate_function_seq, + num_workers=self.hparams.num_workers, + #pin_memory=True, + shuffle=True) + return train_loader + + def val_dataloader(self): + val_datasets = [] + val_loaders = [] + for t in self.hparams.types: + val_datasets.append(ToMnetDGLDataset(path=self.hparams.data_path, + types=[t], + mode='val')) + val_loaders.append(DataLoader(dataset=val_datasets[-1], + batch_size=self.hparams.batch_size, + collate_fn=collate_function_seq, + num_workers=self.hparams.num_workers, + #pin_memory=True, + shuffle=False)) + return val_loaders + + def configure_callbacks(self): + checkpoint = ModelCheckpoint( + dirpath=None, # automatically set + #filename=self.params['bc_model']+'-'+self.params['gnn_type']+'-'+self.gnn_params['feats_aggr']+'-{epoch:02d}', + save_top_k=-1, + period=1 + ) + return [checkpoint] \ No newline at end of file diff --git a/tom/norm.py b/tom/norm.py new file mode 100644 index 0000000..818ecf1 --- /dev/null +++ b/tom/norm.py @@ -0,0 +1,46 @@ +import torch +import torch.nn as nn + + +class Norm(nn.Module): + + def __init__(self, norm_type, hidden_dim=64, print_info=None): + super(Norm, self).__init__() + + # assert norm_type in ['bn', 'ln', 'gn', None] + self.norm = None + self.print_info = print_info + if norm_type == 'bn': + self.norm = nn.BatchNorm1d(hidden_dim) + elif norm_type == 'gn': + self.norm = norm_type + self.weight = nn.Parameter(torch.ones(hidden_dim)) + self.bias = nn.Parameter(torch.zeros(hidden_dim)) + + self.mean_scale = nn.Parameter(torch.ones(hidden_dim)) + + def forward(self, graph, tensor, print_=False): + + if self.norm is not None and type(self.norm) != str: + return self.norm(tensor) + elif self.norm is None: + return tensor + + batch_list = graph.batch_num_nodes('obj') + batch_size = len(batch_list) + #batch_list = torch.tensor(batch_list).long().to(tensor.device) + batch_list = batch_list.long().to(tensor.device) + batch_index = torch.arange(batch_size).to(tensor.device).repeat_interleave(batch_list) + batch_index = batch_index.view((-1,) + (1,) * (tensor.dim() - 1)).expand_as(tensor) + mean = torch.zeros(batch_size, 
*tensor.shape[1:]).to(tensor.device) + mean = mean.scatter_add_(0, batch_index, tensor) + mean = (mean.T / batch_list).T + mean = mean.repeat_interleave(batch_list, dim=0) + + sub = tensor - mean * self.mean_scale + + std = torch.zeros(batch_size, *tensor.shape[1:]).to(tensor.device) + std = std.scatter_add_(0, batch_index, sub.pow(2)) + std = ((std.T / batch_list).T + 1e-6).sqrt() + std = std.repeat_interleave(batch_list, dim=0) + return self.weight * sub / std + self.bias \ No newline at end of file diff --git a/tom/transformer.py b/tom/transformer.py new file mode 100644 index 0000000..0544c9e --- /dev/null +++ b/tom/transformer.py @@ -0,0 +1,89 @@ +import torch +import torch.nn as nn +import math + + +class PositionalEncoding(nn.Module): + + def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000): + super().__init__() + + self.dropout = nn.Dropout(p=dropout) + + position = torch.arange(max_len).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) + pe = torch.zeros(max_len, 1, d_model) + pe[:, 0, 0::2] = torch.sin(position * div_term) + pe[:, 0, 1::2] = torch.cos(position * div_term) + self.register_buffer('pe', pe) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: Tensor, shape [seq_len, batch_size, embedding_dim] + """ + x = x + self.pe[:x.size(0)] + return self.dropout(x) + + +class TransformerEncoder(nn.Module): + + def __init__( + self, + d_model, + nhead, + dim_feedforward, + transformer_dropout, + transformer_activation, + num_encoder_layers, + max_input_len, + transformer_norm_input + ): + super().__init__() + self.d_model = d_model + self.num_layer = num_encoder_layers + self.max_input_len = max_input_len + + # Creating Transformer Encoder Model + encoder_layer = nn.TransformerEncoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=transformer_dropout, + activation=transformer_activation + ) + encoder_norm = nn.LayerNorm(d_model) + self.transformer = nn.TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) + + + self.norm_input = None + if transformer_norm_input: + self.norm_input = nn.LayerNorm(d_model) + + def forward(self, padded_h_node, src_padding_mask): + """ + padded_h_node: n_b x B x h_d # 63, 257, 128 + src_key_padding_mask: B x n_b # 257, 63 + """ + # (S, B, h_d), (B, S) + if self.norm_input is not None: + padded_h_node = self.norm_input(padded_h_node) + + transformer_out = self.transformer(padded_h_node, src_key_padding_mask=src_padding_mask) # (S, B, h_d) + + return transformer_out, src_padding_mask + + + +if __name__ == '__main__': + model = TransformerEncoder( + d_model=12, + nhead=4, + dim_feedforward=32, + transformer_dropout=0.0, + transformer_activation='gelu', + num_encoder_layers=4, + max_input_len=34, + transformer_norm_input=0 + ) + print(model.norm_input) \ No newline at end of file diff --git a/train_tom.py b/train_tom.py new file mode 100644 index 0000000..9ce2934 --- /dev/null +++ b/train_tom.py @@ -0,0 +1,75 @@ +import random +from argparse import ArgumentParser +import numpy as np +import torch +from pytorch_lightning import Trainer +from pytorch_lightning.loggers import WandbLogger +from tom.model import GraphBC_T, GraphBCRNN + +torch.multiprocessing.set_sharing_strategy('file_system') + +parser = ArgumentParser() + +# program level args +parser.add_argument('--seed', type=int, default=4) +# data specific args +parser.add_argument('--data_path', type=str, 
default='/datasets/external/bib_train/graphs/all_tasks/') +parser.add_argument('--types', nargs='+', type=str, + default=['preference', 'multi_agent', 'single_object', 'instrumental_action'], + help='types of tasks used for training / validation') +parser.add_argument('--train', type=int, default=1) +parser.add_argument('--num_workers', type=int, default=4) +parser.add_argument('--batch_size', type=int, default=16) +parser.add_argument('--model_type', type=str, default='graphbcrnn') + +# model specific args +parser_model = ArgumentParser() +parser_model = GraphBC_T.add_model_specific_args(parser_model) +# parser_model = GraphBCRNN.add_model_specific_args(parser_model) +# NOTE: here unfortunately you have to select manually the model + +# add all the available trainer options to argparse +parser = Trainer.add_argparse_args(parser) + +# combine parsers +parser_all = ArgumentParser(conflict_handler='resolve', + parents=[parser, parser_model]) + +# parse args +args = parser_all.parse_args() +args.types = sorted(args.types) +print(args) + +random.seed(args.seed) +np.random.seed(args.seed) +torch.manual_seed(args.seed) + +# init model +if args.model_type == 'graphbct': + model = GraphBC_T(args) +elif args.model_type == 'graphbcrnn': + model = GraphBCRNN(args) +else: + raise NotImplementedError + +torch.autograd.set_detect_anomaly(True) + +logger = WandbLogger(project='bib') +trainer = Trainer( + gradient_clip_val=args.gradient_clip_val, + gpus=args.gpus, + auto_select_gpus=args.auto_select_gpus, + track_grad_norm=args.track_grad_norm, + check_val_every_n_epoch=args.check_val_every_n_epoch, + max_epochs=args.max_epochs, + accelerator=args.accelerator, + resume_from_checkpoint=args.resume_from_checkpoint, + stochastic_weight_avg=args.stochastic_weight_avg, + num_sanity_val_steps=args.num_sanity_val_steps, + logger=logger +) + +if args.train: + trainer.fit(model) +else: + raise NotImplementedError diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/build_graphs.py b/utils/build_graphs.py new file mode 100644 index 0000000..5af8fab --- /dev/null +++ b/utils/build_graphs.py @@ -0,0 +1,115 @@ +import sys +sys.path.append('/projects/bortoletto/icml2023_matteo/utils') +from dataset import TransitionDataset, TestTransitionDatasetSequence +import multiprocessing as mp +import argparse +import pickle as pkl +import os + +# Instantiate the parser +parser = argparse.ArgumentParser() +parser.add_argument('--cpus', type=int, + help='Number of processes') +parser.add_argument('--mode', type=str, + help='Train (train) or validation (val)') +args = parser.parse_args() + +NUM_PROCESSES = args.cpus +MODE = args.mode + +def generate_files(idx): + print('Generating idx', idx) + if os.path.exists(PATH+str(idx)+'.pkl'): + print('Index', idx, 'skipped.') + return + if MODE == 'train' or MODE == 'val': + states, actions, lens, n_nodes = dataset.__getitem__(idx) + with open(PATH+str(idx)+'.pkl', 'wb') as f: + pkl.dump([states, actions, lens, n_nodes], f) + elif MODE == 'test': + dem_expected_states, dem_expected_actions, dem_expected_lens, dem_expected_nodes, \ + dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, dem_unexpected_nodes, \ + query_expected_frames, target_expected_actions, \ + query_unexpected_frames, target_unexpected_actions = dataset.__getitem__(idx) + with open(PATH+str(idx)+'.pkl', 'wb') as f: + pkl.dump([ + dem_expected_states, dem_expected_actions, dem_expected_lens, dem_expected_nodes, \ + dem_unexpected_states, 
dem_unexpected_actions, dem_unexpected_lens, dem_unexpected_nodes, \ + query_expected_frames, target_expected_actions, \ + query_unexpected_frames, target_unexpected_actions], f + ) + else: + raise ValueError('MODE can be only train, val or test.') + print(PATH+str(idx)+'.pkl saved.') + +if __name__ == "__main__": + if MODE == 'train': + print('TRAIN MODE') + PATH = '/datasets/external/bib_train/graphs/all_tasks/train_dgl_hetero_nobound_4feats/' + if not os.path.exists(PATH): + os.makedirs(PATH) + print(PATH, 'directory created.') + dataset = TransitionDataset( + path='/datasets/external/bib_train/', + types=['instrumental_action', 'multi_agent', 'preference', 'single_object'], + mode="train", + max_len=30, + num_test=1, + num_trials=9, + action_range=10, + process_data=0 + ) + pool = mp.Pool(processes=NUM_PROCESSES) + print('Starting graph generation with', NUM_PROCESSES, 'processes...') + pool.map(generate_files, [i for i in range(dataset.__len__())]) + pool.close() + elif MODE == 'val': + print('VALIDATION MODE') + types = ['multi_agent', 'instrumental_action', 'preference', 'single_object'] + for t in range(len(types)): + PATH = '/datasets/external/bib_train/graphs/all_tasks/val_dgl_hetero_nobound_4feats/'+types[t]+'/' + if not os.path.exists(PATH): + os.makedirs(PATH) + print(PATH, 'directory created.') + dataset = TransitionDataset( + path='/datasets/external/bib_train/', + types=[types[t]], + mode="val", + max_len=30, + num_test=1, + num_trials=9, + action_range=10, + process_data=0 + ) + pool = mp.Pool(processes=NUM_PROCESSES) + print('Starting', types[t], 'graph generation with', NUM_PROCESSES, 'processes...') + pool.map(generate_files, [i for i in range(dataset.__len__())]) + pool.close() + elif MODE == 'test': + print('TEST MODE') + types = [ + 'preference', 'multi_agent', 'inaccessible_goal', + 'efficiency_irrational', 'efficiency_time','efficiency_path', + 'instrumental_no_barrier', 'instrumental_blocking_barrier', 'instrumental_inconsequential_barrier' + ] + for t in range(len(types)): + PATH = '/datasets/external/bib_evaluation_1_1/graphs/all_tasks_dgl_hetero_nobound_4feats/'+types[t]+'/' + if not os.path.exists(PATH): + os.makedirs(PATH) + print(PATH, 'directory created.') + dataset = TestTransitionDatasetSequence( + path='/datasets/external/bib_evaluation_1_1/', + task_type=types[t], + mode="test", + num_test=1, + num_trials=9, + action_range=10, + process_data=0, + max_len=30 + ) + pool = mp.Pool(processes=NUM_PROCESSES) + print('Starting', types[t], 'graph generation with', NUM_PROCESSES, 'processes...') + pool.map(generate_files, [i for i in range(dataset.__len__())]) + pool.close() + else: + raise ValueError diff --git a/utils/dataset.py b/utils/dataset.py new file mode 100644 index 0000000..061879d --- /dev/null +++ b/utils/dataset.py @@ -0,0 +1,487 @@ +import dgl +import torch +import torch.utils.data +import os +import pickle as pkl +import json +import numpy as np +from tqdm import tqdm + +import sys +sys.path.append('/projects/bortoletto/irene/') +from utils.grid_object import * +from utils.relations import * + +# ========================== Helper functions ========================== + +def index_data(json_list, path_list): + print(f'processing files {len(json_list)}') + data_tuples = [] + for j, v in tqdm(zip(json_list, path_list)): + with open(j, 'r') as f: + state = json.load(f) + ep_lens = [len(x) for x in state] + past_len = 0 + for e, l in enumerate(ep_lens): + data_tuples.append([]) + # skip first 30 frames and last 83 frames + for f in range(30, l - 83): + 
# find action taken; + f0x, f0y = state[e][f]['agent'][0] + f1x, f1y = state[e][f + 1]['agent'][0] + dx = (f1x - f0x) / 2. + dy = (f1y - f0y) / 2. + action = [dx, dy] + #data_tuples[-1].append((v, past_len + f, action)) + data_tuples[-1].append((j, past_len + f, action)) + # data_tuples = [[json file, frame number, action] for each episode in each video] + assert len(data_tuples[-1]) > 0 + past_len += l + return data_tuples + +# ========================== Dataset class ========================== + +class TransitionDataset(torch.utils.data.Dataset): + """ + Training dataset class for the behavior cloning mlp model. + Args: + path: path to the dataset + types: list of video types to include + mode: train, val + num_test: number of test state-action pairs + num_trials: number of trials in an episode + action_range: number of frames to skip; actions are combined over these number of frames (displcement) of the agent + process_data: whether to the videos or not (skip if already processed) + max_len: max number of context state-action pairs + __getitem__: + returns: (states, actions, lens, n_nodes) + dem_frames: batched DGLGraph.heterograph + dem_actions: (max_len, 2) + query_frames: DGLGraph.heterograph + target_actions: (num_test, 2) + """ + def __init__( + self, + path, + types=None, + mode="train", + num_test=1, + num_trials=9, + action_range=10, + process_data=0, + max_len=30 + ): + self.path = path + self.types = types + self.mode = mode + self.num_trials = num_trials + self.num_test = num_test + self.action_range = action_range + self.max_len = max_len + self.ep_combs = self.num_trials * (self.num_trials - 2) # 9p2 - 9 + self.eps = [[x, y] for x in range(self.num_trials) for y in range(self.num_trials) if x != y] + types_str = '_'.join(self.types) + self.rel_deter_func = [ + is_top_adj, is_left_adj, is_top_right_adj, is_top_left_adj, + is_down_adj, is_right_adj, is_down_left_adj, is_down_right_adj, + is_left, is_right, is_front, is_back, is_aligned, is_close + ] + self.path_list = [] + self.json_list = [] + # get video paths and json file paths + for t in types: + print(f'reading files of type {t} in {mode}') + paths = [os.path.join(self.path, t, x) for x in os.listdir(os.path.join(self.path, t)) if + x.endswith(f'.mp4')] + jsons = [os.path.join(self.path, t, x) for x in os.listdir(os.path.join(self.path, t)) if + x.endswith(f'.json') and 'index' not in x] + paths = sorted(paths) + jsons = sorted(jsons) + if mode == 'train': + self.path_list += paths[:int(0.8 * len(jsons))] + self.json_list += jsons[:int(0.8 * len(jsons))] + elif mode == 'val': + self.path_list += paths[int(0.8 * len(jsons)):] + self.json_list += jsons[int(0.8 * len(jsons)):] + else: + self.path_list += paths + self.json_list += jsons + self.data_tuples = [] + if process_data: + # index the videos in the dataset directory. This is done to speed up the retrieval of videos. 
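+            # the index is cached as jindex_bib_<mode>_<types>.json under the dataset root,
+            # so later runs can pass process_data=0 and skip re-indexing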
+            # frame index, action tuples are stored
+            self.data_tuples = index_data(self.json_list, self.path_list)
+            # tuples of frame index and action (displacement of agent)
+            index_dict = {'data_tuples': self.data_tuples}
+            with open(os.path.join(self.path, f'jindex_bib_{mode}_{types_str}.json'), 'w') as fp:
+                json.dump(index_dict, fp)
+        else:
+            # read pre-indexed data
+            with open(os.path.join(self.path, f'jindex_bib_{mode}_{types_str}.json'), 'r') as fp:
+                index_dict = json.load(fp)
+            self.data_tuples = index_dict['data_tuples']
+        self.tot_trials = len(self.path_list) * 9
+
+    def _get_frame_graph(self, jsonfile, frame_idx):
+        # load json
+        with open(jsonfile, 'rb') as f:
+            frame_data = json.load(f)
+        flat_list = [x for xs in frame_data for x in xs]
+        # extract entities
+        grid_objs = parse_objects(flat_list[frame_idx])
+        # --- build the graph
+        adj = self._get_spatial_rel(grid_objs)
+        # define edges
+        is_top_adj_src, is_top_adj_dst = np.nonzero(adj[0])
+        is_left_adj_src, is_left_adj_dst = np.nonzero(adj[1])
+        is_top_right_adj_src, is_top_right_adj_dst = np.nonzero(adj[2])
+        is_top_left_adj_src, is_top_left_adj_dst = np.nonzero(adj[3])
+        is_down_adj_src, is_down_adj_dst = np.nonzero(adj[4])
+        is_right_adj_src, is_right_adj_dst = np.nonzero(adj[5])
+        is_down_left_adj_src, is_down_left_adj_dst = np.nonzero(adj[6])
+        is_down_right_adj_src, is_down_right_adj_dst = np.nonzero(adj[7])
+        is_left_src, is_left_dst = np.nonzero(adj[8])
+        is_right_src, is_right_dst = np.nonzero(adj[9])
+        is_front_src, is_front_dst = np.nonzero(adj[10])
+        is_back_src, is_back_dst = np.nonzero(adj[11])
+        is_aligned_src, is_aligned_dst = np.nonzero(adj[12])
+        is_close_src, is_close_dst = np.nonzero(adj[13])
+        g = dgl.heterograph({
+            ('obj', 'is_top_adj', 'obj'): (torch.tensor(is_top_adj_src), torch.tensor(is_top_adj_dst)),
+            ('obj', 'is_left_adj', 'obj'): (torch.tensor(is_left_adj_src), torch.tensor(is_left_adj_dst)),
+            ('obj', 'is_top_right_adj', 'obj'): (torch.tensor(is_top_right_adj_src), torch.tensor(is_top_right_adj_dst)),
+            ('obj', 'is_top_left_adj', 'obj'): (torch.tensor(is_top_left_adj_src), torch.tensor(is_top_left_adj_dst)),
+            ('obj', 'is_down_adj', 'obj'): (torch.tensor(is_down_adj_src), torch.tensor(is_down_adj_dst)),
+            ('obj', 'is_right_adj', 'obj'): (torch.tensor(is_right_adj_src), torch.tensor(is_right_adj_dst)),
+            ('obj', 'is_down_left_adj', 'obj'): (torch.tensor(is_down_left_adj_src), torch.tensor(is_down_left_adj_dst)),
+            ('obj', 'is_down_right_adj', 'obj'): (torch.tensor(is_down_right_adj_src), torch.tensor(is_down_right_adj_dst)),
+            ('obj', 'is_left', 'obj'): (torch.tensor(is_left_src), torch.tensor(is_left_dst)),
+            ('obj', 'is_right', 'obj'): (torch.tensor(is_right_src), torch.tensor(is_right_dst)),
+            ('obj', 'is_front', 'obj'): (torch.tensor(is_front_src), torch.tensor(is_front_dst)),
+            ('obj', 'is_back', 'obj'): (torch.tensor(is_back_src), torch.tensor(is_back_dst)),
+            ('obj', 'is_aligned', 'obj'): (torch.tensor(is_aligned_src), torch.tensor(is_aligned_dst)),
+            ('obj', 'is_close', 'obj'): (torch.tensor(is_close_src), torch.tensor(is_close_dst))
+            }, num_nodes_dict={'obj': len(grid_objs)})
+        g = self._add_node_features(grid_objs, g)
+        return g
+
+    def _add_node_features(self, objs, graph):
+        for obj_idx, obj in enumerate(objs):
+            graph.nodes[obj_idx].data['type'] = torch.tensor(obj.type)
+            graph.nodes[obj_idx].data['pos'] = torch.tensor([[obj.x, obj.y]], dtype=torch.float32)
+            assert len(obj.attributes) == 2 and None not in obj.attributes[0] and None not in obj.attributes[1]
+            
graph.nodes[obj_idx].data['color'] = torch.tensor([obj.attributes[0]]) + graph.nodes[obj_idx].data['shape'] = torch.tensor([obj.attributes[1]]) + return graph + + def _get_spatial_rel(self, objs): + spatial_tensors = [np.zeros([len(objs), len(objs)]) for _ in range(len(self.rel_deter_func))] + for obj_idx1, obj1 in enumerate(objs): + for obj_idx2, obj2 in enumerate(objs): + direction_vec = np.array((0, -1)) # Up + for rel_idx, func in enumerate(self.rel_deter_func): + if func(obj1, obj2, direction_vec): + spatial_tensors[rel_idx][obj_idx1, obj_idx2] = 1.0 + return spatial_tensors + + def get_trial(self, trials, step=10): + # retrieve state embeddings and actions from cached file + states = [] + actions = [] + trial_len = [] + lens = [] + n_nodes = [] + # 8 trials + for t in trials: + tl = [(t, n) for n in range(0, len(self.data_tuples[t]), step)] + if len(tl) > self.max_len: # 30 + tl = tl[:self.max_len] + trial_len.append(tl) + for tl in trial_len: + states.append([]) + actions.append([]) + lens.append(len(tl)) + for t, n in tl: + video = self.data_tuples[t][n][0] + states[-1].append(self._get_frame_graph(video, self.data_tuples[t][n][1])) + n_nodes.append(states[-1][-1].number_of_nodes()) + # actions are pooled over frames + if len(self.data_tuples[t]) > n + self.action_range: + actions_xy = [d[2] for d in self.data_tuples[t][n:n + self.action_range]] + else: + actions_xy = [d[2] for d in self.data_tuples[t][n:]] + actions_xy = np.array(actions_xy) + actions_xy = np.mean(actions_xy, axis=0) + actions[-1].append(actions_xy) + states[-1] = dgl.batch(states[-1]) + actions[-1] = torch.tensor(np.array(actions[-1])) + trial_actions_padded = torch.zeros(self.max_len, actions[-1].size(1)) + trial_actions_padded[:actions[-1].size(0), :] = actions[-1] + actions[-1] = trial_actions_padded + return states, actions, lens, n_nodes + + def __getitem__(self, idx): + ep_trials = [idx * self.num_trials + t for t in range(self.num_trials)] # [idx, ..., idx+8] + states, actions, lens, n_nodes = self.get_trial(ep_trials, step=self.action_range) + return states, actions, lens, n_nodes + + def __len__(self): + return self.tot_trials // self.num_trials + + +class TestTransitionDatasetSequence(torch.utils.data.Dataset): + """ + Test dataset class for the behavior cloning rnn model. This dataset is used to test the model on the eval data. + This class is used to compare plausible and implausible episodes. 
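+    Note: demonstration and query frames are returned as (batched) DGLGraph heterographs
+    built from the episode JSON files, not as raw image tensors.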
+ Args: + path: path to the dataset + types: list of video types to include + size: size of the frames to be returned + mode: test + num_context: number of context state-action pairs + num_test: number of test state-action pairs + num_trials: number of trials in an episode + action_range: number of frames to skip; actions are combined over these number of frames (displcement) of the agent + process_data: whether to the videos or not (skip if already processed) + __getitem__: + returns: (expected_dem_frames, expected_dem_actions, expected_dem_lens expected_query_frames, expected_target_actions, + unexpected_dem_frames, unexpected_dem_actions, unexpected_dem_lens, unexpected_query_frames, unexpected_target_actions) + dem_frames: (num_context, max_len, 3, size, size) + dem_actions: (num_context, max_len, 2) + dem_lens: (num_context) + query_frames: (num_test, 3, size, size) + target_actions: (num_test, 2) + """ + def __init__( + self, + path, + task_type=None, + mode="test", + num_test=1, + num_trials=9, + action_range=10, + process_data=0, + max_len=30 + ): + self.path = path + self.task_type = task_type + self.mode = mode + self.num_trials = num_trials + self.num_test = num_test + self.action_range = action_range + self.max_len = max_len + self.ep_combs = self.num_trials * (self.num_trials - 2) # 9p2 - 9 + self.eps = [[x, y] for x in range(self.num_trials) for y in range(self.num_trials) if x != y] + self.path_list_exp = [] + self.json_list_exp = [] + self.path_list_un = [] + self.json_list_un = [] + + print(f'reading files of type {task_type} in {mode}') + + paths_expected = sorted([os.path.join(self.path, task_type, x) for x in os.listdir(os.path.join(self.path, task_type)) if + x.endswith('e.mp4')]) + jsons_expected = sorted([os.path.join(self.path, task_type, x) for x in os.listdir(os.path.join(self.path, task_type)) if + x.endswith('e.json') and 'index' not in x]) + paths_unexpected = sorted([os.path.join(self.path, task_type, x) for x in os.listdir(os.path.join(self.path, task_type)) if + x.endswith('u.mp4')]) + jsons_unexpected = sorted([os.path.join(self.path, task_type, x) for x in os.listdir(os.path.join(self.path, task_type)) if + x.endswith('u.json') and 'index' not in x]) + self.path_list_exp += paths_expected + self.json_list_exp += jsons_expected + self.path_list_un += paths_unexpected + self.json_list_un += jsons_unexpected + self.data_expected = [] + self.data_unexpected = [] + + if process_data: + # index data. This is done to speed up video retrieval. 
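+            # expected and unexpected episodes are indexed separately and cached as
+            # jindex_bib_test_<task_type>e.json and jindex_bib_test_<task_type>u.json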
+ # frame index, action tuples are stored + self.data_expected = index_data(self.json_list_exp, self.path_list_exp) + index_dict = {'data_tuples': self.data_expected} + with open(os.path.join(self.path, f'jindex_bib_test_{task_type}e.json'), 'w') as fp: + json.dump(index_dict, fp) + + self.data_unexpected = index_data(self.json_list_un, self.path_list_un) + index_dict = {'data_tuples': self.data_unexpected} + with open(os.path.join(self.path, f'jindex_bib_test_{task_type}u.json'), 'w') as fp: + json.dump(index_dict, fp) + else: + with open(os.path.join(self.path, f'jindex_bib_{mode}_{task_type}e.json'), 'r') as fp: + index_dict = json.load(fp) + self.data_expected = index_dict['data_tuples'] + with open(os.path.join(self.path, f'jindex_bib_{mode}_{task_type}u.json'), 'r') as fp: + index_dict = json.load(fp) + self.data_unexpected = index_dict['data_tuples'] + + self.rel_deter_func = [ + is_top_adj, is_left_adj, is_top_right_adj, is_top_left_adj, + is_down_adj, is_right_adj, is_down_left_adj, is_down_right_adj, + is_left, is_right, is_front, is_back, is_aligned, is_close + ] + + print('Done.') + + def _get_frame_graph(self, jsonfile, frame_idx): + # load json + with open(jsonfile, 'rb') as f: + frame_data = json.load(f) + flat_list = [x for xs in frame_data for x in xs] + # extract entities + grid_objs = parse_objects(flat_list[frame_idx]) + # --- build the graph + adj = self._get_spatial_rel(grid_objs) + # define edges + is_top_adj_src, is_top_adj_dst = np.nonzero(adj[0]) + is_left_adj_src, is_left_adj_dst = np.nonzero(adj[1]) + is_top_right_adj_src, is_top_right_adj_dst = np.nonzero(adj[2]) + is_top_left_adj_src, is_top_left_adj_dst = np.nonzero(adj[3]) + is_down_adj_src, is_down_adj_dst = np.nonzero(adj[4]) + is_right_adj_src, is_right_adj_dst = np.nonzero(adj[5]) + is_down_left_adj_src, is_down_left_adj_dst = np.nonzero(adj[6]) + is_down_right_adj_src, is_down_right_adj_dst = np.nonzero(adj[7]) + is_left_src, is_left_dst = np.nonzero(adj[8]) + is_right_src, is_right_dst = np.nonzero(adj[9]) + is_front_src, is_front_dst = np.nonzero(adj[10]) + is_back_src, is_back_dst = np.nonzero(adj[11]) + is_aligned_src, is_aligned_dst = np.nonzero(adj[12]) + is_close_src, is_close_dst = np.nonzero(adj[13]) + g = dgl.heterograph({ + ('obj', 'is_top_adj', 'obj'): (torch.tensor(is_top_adj_src), torch.tensor(is_top_adj_dst)), + ('obj', 'is_left_adj', 'obj'): (torch.tensor(is_left_adj_src), torch.tensor(is_left_adj_dst)), + ('obj', 'is_top_right_adj', 'obj'): (torch.tensor(is_top_right_adj_src), torch.tensor(is_top_right_adj_dst)), + ('obj', 'is_top_left_adj', 'obj'): (torch.tensor(is_top_left_adj_src), torch.tensor(is_top_left_adj_dst)), + ('obj', 'is_down_adj', 'obj'): (torch.tensor(is_down_adj_src), torch.tensor(is_down_adj_dst)), + ('obj', 'is_right_adj', 'obj'): (torch.tensor(is_right_adj_src), torch.tensor(is_right_adj_dst)), + ('obj', 'is_down_left_adj', 'obj'): (torch.tensor(is_down_left_adj_src), torch.tensor(is_down_left_adj_dst)), + ('obj', 'is_down_right_adj', 'obj'): (torch.tensor(is_down_right_adj_src), torch.tensor(is_down_right_adj_dst)), + ('obj', 'is_left', 'obj'): (torch.tensor(is_left_src), torch.tensor(is_left_dst)), + ('obj', 'is_right', 'obj'): (torch.tensor(is_right_src), torch.tensor(is_right_dst)), + ('obj', 'is_front', 'obj'): (torch.tensor(is_front_src), torch.tensor(is_front_dst)), + ('obj', 'is_back', 'obj'): (torch.tensor(is_back_src), torch.tensor(is_back_dst)), + ('obj', 'is_aligned', 'obj'): (torch.tensor(is_aligned_src), torch.tensor(is_aligned_dst)), + ('obj', 
'is_close', 'obj'): (torch.tensor(is_close_src), torch.tensor(is_close_dst)) + }, num_nodes_dict={'obj': len(grid_objs)}) + g = self._add_node_features(grid_objs, g) + return g + + def _add_node_features(self, objs, graph): + for obj_idx, obj in enumerate(objs): + graph.nodes[obj_idx].data['type'] = torch.tensor(obj.type) + graph.nodes[obj_idx].data['pos'] = torch.tensor([[obj.x, obj.y]], dtype=torch.float32) + assert len(obj.attributes) == 2 and None not in obj.attributes[0] and None not in obj.attributes[1] + graph.nodes[obj_idx].data['color'] = torch.tensor([obj.attributes[0]]) + graph.nodes[obj_idx].data['shape'] = torch.tensor([obj.attributes[1]]) + return graph + + def _get_spatial_rel(self, objs): + spatial_tensors = [np.zeros([len(objs), len(objs)]) for _ in range(len(self.rel_deter_func))] + for obj_idx1, obj1 in enumerate(objs): + for obj_idx2, obj2 in enumerate(objs): + direction_vec = np.array((0, -1)) # Up why?????????????? + for rel_idx, func in enumerate(self.rel_deter_func): + if func(obj1, obj2, direction_vec): + spatial_tensors[rel_idx][obj_idx1, obj_idx2] = 1.0 + return spatial_tensors + + def get_trial(self, trials, data, step=10): + # retrieve state embeddings and actions from cached file + states = [] + actions = [] + trial_len = [] + lens = [] + n_nodes = [] + for t in trials: + tl = [(t, n) for n in range(0, len(data[t]), step)] + if len(tl) > self.max_len: + tl = tl[:self.max_len] + trial_len.append(tl) + for tl in trial_len: + states.append([]) + actions.append([]) + lens.append(len(tl)) + for t, n in tl: + video = data[t][n][0] + states[-1].append(self._get_frame_graph(video, data[t][n][1])) + n_nodes.append(states[-1][-1].number_of_nodes()) + if len(data[t]) > n + self.action_range: + actions_xy = [d[2] for d in data[t][n:n + self.action_range]] + else: + actions_xy = [d[2] for d in data[t][n:]] + actions_xy = np.array(actions_xy) + actions_xy = np.mean(actions_xy, axis=0) + actions[-1].append(actions_xy) + states[-1] = dgl.batch(states[-1]) + actions[-1] = torch.tensor(np.array(actions[-1])) + trial_actions_padded = torch.zeros(self.max_len, actions[-1].size(1)) + trial_actions_padded[:actions[-1].size(0), :] = actions[-1] + actions[-1] = trial_actions_padded + return states, actions, lens, n_nodes + + def get_test(self, trial, data, step=10): + # retrieve state embeddings and actions from cached file + states = [] + actions = [] + trial_len = [] + trial_len += [(trial, n) for n in range(0, len(data[trial]), step)] + for t, n in trial_len: + video = data[t][n][0] + state = self._get_frame_graph(video, data[t][n][1]) + if len(data[t]) > n + self.action_range: + actions_xy = [d[2] for d in data[t][n:n + self.action_range]] + else: + actions_xy = [d[2] for d in data[t][n:]] + actions_xy = np.array(actions_xy) + actions_xy = np.mean(actions_xy, axis=0) + actions.append(actions_xy) + states.append(state) + #states = torch.stack(states) + states = dgl.batch(states) + actions = torch.tensor(np.array(actions)) + return states, actions + + def __getitem__(self, idx): + ep_trials = [idx * self.num_trials + t for t in range(self.num_trials)] + dem_expected_states, dem_expected_actions, dem_expected_lens, dem_expected_nodes = self.get_trial( + ep_trials[:-1], self.data_expected, step=self.action_range + ) + dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, dem_unexpected_nodes = self.get_trial( + ep_trials[:-1], self.data_unexpected, step=self.action_range + ) + query_expected_frames, target_expected_actions = self.get_test( + ep_trials[-1], 
self.data_expected, step=self.action_range + ) + query_unexpected_frames, target_unexpected_actions = self.get_test( + ep_trials[-1], self.data_unexpected, step=self.action_range + ) + return dem_expected_states, dem_expected_actions, dem_expected_lens, dem_expected_nodes, \ + dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, dem_unexpected_nodes, \ + query_expected_frames, target_expected_actions, \ + query_unexpected_frames, target_unexpected_actions + + def __len__(self): + return len(self.path_list_exp) + + +if __name__ == '__main__': + types = ['preference', 'multi_agent', 'inaccessible_goal', + 'efficiency_irrational', 'efficiency_time','efficiency_path', + 'instrumental_no_barrier', 'instrumental_blocking_barrier', 'instrumental_inconsequential_barrier'] + for t in types: + ttd = TestTransitionDatasetSequence(path='/datasets/external/bib_evaluation_1_1/', task_type=t, process_data=0, mode='test') + for i in range(ttd.__len__()): + print(i, end='\r') + dem_expected_states, dem_expected_actions, dem_expected_lens, dem_expected_nodes, \ + dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, dem_unexpected_nodes, \ + query_expected_frames, target_expected_actions, \ + query_unexpected_frames, target_unexpected_actions = ttd.__getitem__(i) + for j in range(8): + if not torch.tensor([1., 0., 0., 0., 0., 0., 0., 0., 0.]) in dem_expected_states[j].ndata['type']: + print(i) + if not torch.tensor([1., 0., 0., 0., 0., 0., 0., 0., 0.]) in dem_unexpected_states[j].ndata['type']: + print(i) + if not torch.tensor([1., 0., 0., 0., 0., 0., 0., 0., 0.]) in query_expected_frames.ndata['type']: + print(i) + if not torch.tensor([1., 0., 0., 0., 0., 0., 0., 0., 0.]) in query_unexpected_frames.ndata['type']: + print(i) diff --git a/utils/grid_object.py b/utils/grid_object.py new file mode 100644 index 0000000..d5cc682 --- /dev/null +++ b/utils/grid_object.py @@ -0,0 +1,174 @@ +import json +import pdb +import numpy as np +from sklearn.preprocessing import OneHotEncoder +import itertools + + +SHAPES = { + # walls + 'square': 0, + # objects + 'heart': 1, 'clove_tree': 2, 'moon': 3, 'wine': 4, 'double_dia': 5, + 'flag': 6, 'capsule': 7, 'vase': 8, 'curved_triangle': 9, 'spoon': 10, + 'medal': 11, # inaccessible goal + # home + 'home': 12, + # agent + 'pentagon': 13, 'clove': 14, 'kite': 15, + # key + 'triangle0': 16, 'triangle180': 16, 'triangle90': 16, 'triangle270': 16, + # lock + 'triangle_slot0': 17, 'triangle_slot180': 17, 'triangle_slot90': 17, 'triangle_slot270': 17 +} + +ENTITIES = { + 'agent': 0 , 'walls': 1, 'fuse_walls': 2, 'key': 3, + 'blocking': 4, 'home': 5, 'objects': 6, 'pin': 7, 'lock': 8 +} + +# =============================== GridPbject class =============================== + +class GridObject(): + "object is specified by its location" + def __init__(self, x, y, object_type, attributes=[]): + self.x = x + self.y = y + self.type = object_type + self.attributes = attributes + + @property + def pos(self): + return np.array([self.x, self.y]) + + @property + def name(self): + return {'type': str(self.type), + 'x': str(self.x), + 'y': str(self.y), + 'color': self.attributes[0], + 'shape': self.attributes[1]} + +# =============================== Helper functions =============================== + +def type2index(key): + for name, idx in ENTITIES.items(): + if name == key: + return idx + +def find_shape(shape_string, print_shape=False): + try: + shape = shape_string.split('/')[-1].split('.')[0] + except: + shape = shape_string.split('.')[0] + if print_shape: 
print(shape) + for name, idx in SHAPES.items(): + if name == shape: + return idx + +def parse_objects(frame): + """ + x and y are computed differently from walls and objects + for walls x, y = obj[0][0] + obj[1][0]/2, obj[0][1] + obj[1][1]/2 + for objects x, y = obj[0][0] + obj[1], obj[0][1] + obj[1] + :param obj: + :return: GridObject + """ + shape_onehot_encoder = OneHotEncoder(sparse=False) + shape_onehot_encoder.fit([[i] for i in range(len(SHAPES)-6)]) + type_onehot_encoder = OneHotEncoder(sparse=False) + type_onehot_encoder.fit([[i] for i in range(len(ENTITIES))]) + # remove duplicate walls + frame['walls'].sort() + frame['walls'] = list(k for k, _ in itertools.groupby(frame['walls'])) + # remove boundary walls + frame['walls'] = [w for w in frame['walls'] if (w[0][0] != 0 and w[0][0] != 180 and w[0][1] != 0 and w[0][1] != 180)] + # remove duplicate fuse_walls + frame['fuse_walls'].sort() + frame['fuse_walls'] = list(k for k, _ in itertools.groupby(frame['fuse_walls'])) + grid_objs = [] + assert 'agent' in frame.keys() + for key in frame.keys(): + #print(key) + if key == 'size': + continue + obj = frame[key] + if obj == []: + #print(key, 'skipped') + continue + obj_type = type2index(key) + obj_type = type_onehot_encoder.transform([[obj_type]]) + if key == 'walls': + for wall in obj: + x, y = wall[0][0] + wall[1][0]/2, wall[0][1] + wall[1][1]/2 + #x, y = (wall[0][0] + wall[1][0]/2)/200, (wall[0][1] + wall[1][1]/2)/200 if u use this you need to change relations.py!!! + color = [0, 0, 0] if key == 'walls' else [80, 146, 56] + #color = [c / 255 for c in color] + shape = 0 + assert shape in SHAPES.values(), 'Shape not found' + shape = shape_onehot_encoder.transform([[shape]])[0] + grid_obj = GridObject(x=x, y=y, object_type=obj_type, attributes=[color, shape]) + grid_objs.append(grid_obj) + elif key == 'fuse_walls': + # resample green barriers + obj = [obj[i] for i in range(len(obj)) if (obj[i][0][0] % 20 == 0 and obj[i][0][1] % 20 == 0)] + for wall in obj: + x, y = wall[0][0] + wall[1][0]/2, wall[0][1] + wall[1][1]/2 + #x, y = (wall[0][0] + wall[1][0]/2)/200, (wall[0][1] + wall[1][1]/2)/200 if u use this you need to change relations.py!!! 
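+                # fuse_walls are the green barriers; coordinates are kept in raw pixel
+                # units so the 20-unit adjacency checks in relations.py remain valid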
+ color = [80, 146, 56] + #color = [c / 255 for c in color] + shape = 0 + assert shape in SHAPES.values(), 'Shape not found' + shape = shape_onehot_encoder.transform([[shape]])[0] + grid_obj = GridObject(x=x, y=y, object_type=obj_type, attributes=[color, shape]) + grid_objs.append(grid_obj) + elif key == 'objects': + for ob in obj: + x, y = ob[0][0] + ob[1], ob[0][1] + ob[1] + color = ob[-1] + #color = [c / 255 for c in color] + shape = find_shape(ob[2], print_shape=False) + assert shape in SHAPES.values(), 'Shape not found' + shape = shape_onehot_encoder.transform([[shape]])[0] + grid_obj = GridObject(x=x, y=y, object_type=obj_type, attributes=[color, shape]) + grid_objs.append(grid_obj) + elif key == 'key': + obj = obj[0] + x, y = obj[0][0] + obj[1], obj[0][1] + obj[1] + color = obj[-1] + #color = [c / 255 for c in color] + shape = find_shape(obj[2], print_shape=False) + assert shape in SHAPES.values(), 'Shape not found' + shape = shape_onehot_encoder.transform([[shape]])[0] + grid_obj = GridObject(x=x, y=y, object_type=obj_type, attributes=[color, shape]) + grid_objs.append(grid_obj) + elif key == 'lock': + obj = obj[0] + x, y = obj[0][0] + obj[1], obj[0][1] + obj[1] + color = obj[-1] + #color = [c / 255 for c in color] + shape = find_shape(obj[2], print_shape=False) + assert shape in SHAPES.values(), 'Shape not found' + shape = shape_onehot_encoder.transform([[shape]])[0] + grid_obj = GridObject(x=x, y=y, object_type=obj_type, attributes=[color, shape]) + grid_objs.append(grid_obj) + else: + try: + x, y = obj[0][0] + obj[1], obj[0][1] + obj[1] + color = obj[-1] + #color = [c / 255 for c in color] + shape = find_shape(obj[2], print_shape=False) + assert shape in SHAPES.values(), 'Shape not found' + shape = shape_onehot_encoder.transform([[shape]])[0] + except: + # [[[x, y], extension, shape, color]] in some cases in instrumental_no_barrier (bib_evaluation_1_1) + x, y = obj[0][0][0] + obj[0][1], obj[0][0][1] + obj[0][1] + color = obj[0][-1] + #color = [c / 255 for c in color] + assert len(color) == 3 + shape = find_shape(obj[0][2], print_shape=False) + assert shape in SHAPES.values(), 'Shape not found' + shape = shape_onehot_encoder.transform([[shape]])[0] + grid_obj = GridObject(x=x, y=y, object_type=obj_type, attributes=[color, shape]) + grid_objs.append(grid_obj) + return grid_objs diff --git a/utils/index_data.py b/utils/index_data.py new file mode 100644 index 0000000..8dd9a18 --- /dev/null +++ b/utils/index_data.py @@ -0,0 +1,124 @@ +import json +import os +import torch +import torch.utils.data +from tqdm import tqdm + + +def index_data(json_list, path_list): + print(f'processing files {len(json_list)}') + data_tuples = [] + for j, v in tqdm(zip(json_list, path_list)): + with open(j, 'r') as f: + state = json.load(f) + ep_lens = [len(x) for x in state] + past_len = 0 + for e, l in enumerate(ep_lens): + data_tuples.append([]) + # skip first 30 frames and last 83 frames + for f in range(30, l - 83): + # find action taken; + f0x, f0y = state[e][f]['agent'][0] + f1x, f1y = state[e][f + 1]['agent'][0] + dx = (f1x - f0x) / 2. + dy = (f1y - f0y) / 2. + action = [dx, dy] + #data_tuples[-1].append((v, past_len + f, action)) + data_tuples[-1].append((j, past_len + f, action)) + # data_tuples = (json file, frame number, action) + assert len(data_tuples[-1]) > 0 + past_len += l + return data_tuples + +class TransitionDataset(torch.utils.data.Dataset): + """ + Training dataset class for the behavior cloning mlp model. 
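+    This variant is only meant to build the frame/action index (run with process_data=1);
+    its __getitem__ is a stub and does not return any data.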
+ Args: + path: path to the dataset + types: list of video types to include + size: size of the frames to be returned + mode: train, val + num_context: number of context state-action pairs + num_test: number of test state-action pairs + num_trials: number of trials in an episode + action_range: number of frames to skip; actions are combined over these number of frames (displcement) of the agent + process_data: whether to the videos or not (skip if already processed) + __getitem__: + returns: (dem_frames, dem_actions, query_frames, target_actions) + dem_frames: (num_context, 3, size, size) + dem_actions: (num_context, 2) + query_frames: (num_test, 3, size, size) + target_actions: (num_test, 2) + """ + def __init__(self, path, types=None, size=None, mode="train", num_context=30, num_test=1, num_trials=9, + action_range=10, process_data=0): + + self.path = path + self.types = types + self.size = size + self.mode = mode + self.num_trials = num_trials + self.num_context = num_context + self.num_test = num_test + self.action_range = action_range + self.ep_combs = self.num_trials * (self.num_trials - 2) # 9p2 - 9 + self.eps = [[x, y] for x in range(self.num_trials) for y in range(self.num_trials) if x != y] + types_str = '_'.join(self.types) + + self.path_list = [] + self.json_list = [] + # get video paths and json file paths + for t in types: + print(f'reading files of type {t} in {mode}') + paths = [os.path.join(self.path, t, x) for x in os.listdir(os.path.join(self.path, t)) if + x.endswith(f'.mp4')] + jsons = [os.path.join(self.path, t, x) for x in os.listdir(os.path.join(self.path, t)) if + x.endswith(f'.json') and 'index' not in x] + + paths = sorted(paths) + jsons = sorted(jsons) + + if mode == 'train': + self.path_list += paths[:int(0.8 * len(jsons))] + self.json_list += jsons[:int(0.8 * len(jsons))] + elif mode == 'val': + self.path_list += paths[int(0.8 * len(jsons)):] + self.json_list += jsons[int(0.8 * len(jsons)):] + else: + self.path_list += paths + self.json_list += jsons + + self.data_tuples = [] + if process_data: + # index the videos in the dataset directory. This is done to speed up the retrieval of videos. 
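+            # the resulting jindex_bib_<mode>_<types>.json file is what utils/dataset.py and
+            # utils/build_graphs.py later load with process_data=0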
+ # frame index, action tuples are stored + self.data_tuples = index_data(self.json_list, self.path_list) + # tuples of frame index and action (displacement of agent) + index_dict = {'data_tuples': self.data_tuples} + with open(os.path.join(self.path, f'jindex_bib_{mode}_{types_str}.json'), 'w') as fp: + json.dump(index_dict, fp) + else: + # read pre-indexed data + with open(os.path.join(self.path, f'jindex_bib_{mode}_{types_str}.json'), 'r') as fp: + index_dict = json.load(fp) + self.data_tuples = index_dict['data_tuples'] + + self.tot_trials = len(self.path_list) * 9 + + + def __getitem__(self, idx): + print('Empty') + return + + def __len__(self): + return self.tot_trials // self.num_trials + + +if __name__ == "__main__": + dataset = TransitionDataset(path='/datasets/external/bib_train/', + types=['multi_agent', 'instrumental_action'], #['instrumental_action', 'multi_agent', 'preference', 'single_object'], + size=(84, 84), + mode="train", num_context=30, + num_test=1, num_trials=9, + action_range=10, process_data=1) + print(len(dataset)) diff --git a/utils/relations.py b/utils/relations.py new file mode 100644 index 0000000..0ef7454 --- /dev/null +++ b/utils/relations.py @@ -0,0 +1,116 @@ +import numpy as np + + +# =============================== relationships to build the graph =============================== + +def rotate_vec2d(vec, degrees): + """ + rotate a vector anti-clockwise + :param vec: + :param degrees: + :return: + """ + theta = np.radians(degrees) + c, s = np.cos(theta), np.sin(theta) + R = np.array(((c, -s), (s, c))) + return R@vec + +# ---------- Remote Directional Relations -------------------------------------------------------- + +def is_front(obj1, obj2, direction_vec)->bool: + diff = obj2.pos - obj1.pos + return diff@direction_vec > 0.1 + +def is_back(obj1, obj2, direction_vec)->bool: + diff = obj2.pos - obj1.pos + return diff@direction_vec < -0.1 + +def is_left(obj1, obj2, direction_vec)->bool: + left_vec = rotate_vec2d(direction_vec, -90) + diff = obj2.pos - obj1.pos + return diff@left_vec > 0.1 + +def is_right(obj1, obj2, direction_vec)->bool: + left_vec = rotate_vec2d(direction_vec, 90) + diff = obj2.pos - obj1.pos + return diff@left_vec > 0.1 + +# ---------- Alignment and Adjacency Relations --------------------------------------------------- + +def is_close(obj1, obj2, direction_vec=None)->bool: + # indicate whether two objects are adjacent to each other, + # which, unlike local directional relations, carry no directional information + distance = np.abs(obj1.pos - obj2.pos) + return np.sum(distance)==20 + +def is_aligned(obj1, obj2, direction_vec=None)->bool: + # indicate if two entities are on the same horizontal or vertical line + diff = obj2.pos - obj1.pos + return np.any(diff==0) + +# ---------- Local Directional Relations --------------------------------------------------------- + +def is_top_adj(obj1, obj2, direction_vec=None)->bool: + return obj1.x==obj2.x and obj1.y==obj2.y+20 + +def is_left_adj(obj1, obj2, direction_vec=None)->bool: + return obj1.y==obj2.y and obj1.x==obj2.x-20 + +def is_top_left_adj(obj1, obj2, direction_vec=None)->bool: + return obj1.y==obj2.y+20 and obj1.x==obj2.x-20 + +def is_top_right_adj(obj1, obj2, direction_vec=None)->bool: + return obj1.y==obj2.y+20 and obj1.x==obj2.x+20 + +def is_down_adj(obj1, obj2, direction_vec=None)->bool: + return is_top_adj(obj2, obj1) + +def is_right_adj(obj1, obj2, direction_vec=None)->bool: + return is_left_adj(obj2, obj1) + +def is_down_right_adj(obj1, obj2, direction_vec=None)->bool: + return 
is_top_left_adj(obj2, obj1) + +def is_down_left_adj(obj1, obj2, direction_vec=None)->bool: + return is_top_right_adj(obj2, obj1) + +# ---------- More Remote Directional Relations (not used) ---------------------------------------- + +def top_left(obj1, obj2, direction_vec)->bool: + return (obj1.x-obj2.x) <= (obj1.y-obj2.y) + +def top_right(obj1, obj2, direction_vec)->bool: + return -(obj1.x-obj2.x) <= (obj1.y-obj2.y) + +def down_left(obj1, obj2, direction_vec)->bool: + return top_right(obj2, obj1, direction_vec) + +def down_right(obj1, obj2, direction_vec)->bool: + return top_left(obj2, obj1, direction_vec) + +def fan_top(obj1, obj2, direction_vec)->bool: + top_left = (obj1.x-obj2.x) <= (obj1.y-obj2.y) + top_right = -(obj1.x-obj2.x) <= (obj1.y-obj2.y) + return top_left and top_right + +def fan_down(obj1, obj2, direction_vec)->bool: + return fan_top(obj2, obj1, direction_vec) + +def fan_right(obj1, obj2, direction_vec)->bool: + down_left = (obj1.x-obj2.x) >= (obj1.y-obj2.y) + top_right = -(obj1.x-obj2.x) <= (obj1.y-obj2.y) + return down_left and top_right + +def fan_left(obj1, obj2, direction_vec)->bool: + return fan_right(obj2, obj1, direction_vec) + +# ---------- Ad-hoc Relations -------------------------------------------------------------------- + +def needs(obj1, obj2, direction_vec=None)->bool: + return np.argmax(obj1.type) == 0 and np.argmax(obj2.type) == 3 + +def opens(obj1, obj2, direction_vec=None)->bool: + return np.argmax(obj1.type) == 3 and np.argmax(obj2.type) == 8 + +def collects(obj1, obj2, direction_vec=None)->bool: + return np.argmax(obj1.type) == 0 and np.argmax(obj2.type) == 6 \ No newline at end of file diff --git a/utils/run_build_graphs.sh b/utils/run_build_graphs.sh new file mode 100644 index 0000000..4f2b03b --- /dev/null +++ b/utils/run_build_graphs.sh @@ -0,0 +1 @@ +python build_graphs.py --mode train --cpus 30 && python build_graphs.py --mode val --cpus 30