up
This commit is contained in:
parent
a333481e05
commit
de0bea7508
18 changed files with 3150 additions and 2 deletions
76
README.md
76
README.md
|
@ -1,3 +1,75 @@
|
||||||
# IRENE
|
<div align="center">
|
||||||
|
<h1> Neural Reasoning about Agents' Goals, Preferences, and Actions </h1>
|
||||||
|
|
||||||
Official code of "Neural Reasoning About Agents' Goals, Preferences, and Actions"
|
**[Matteo Bortoletto][1], [Lei Shi][2], [Andreas Bulling][3]** <br> <br>
|
||||||
|
**AAAI'24, Vancouver, CA** <br>
|
||||||
|
**[[Paper][4]]**
|
||||||
|
|
||||||
|
</div>
|
||||||
|
|
||||||
|
# Citation
|
||||||
|
If you find our code useful or use it in your own projects, please cite our paper:
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@inproceedings{bortoletto2024neural,
|
||||||
|
author = {Bortoletto, Matteo and Lei, Shi and Bulling, Andreas},
|
||||||
|
title = {{Neural Reasoning about Agents' Goals, Preferences, and Actions}},
|
||||||
|
booktitle = {Proc. 38th AAAI Conference on Artificial Intelligence (AAAI)},
|
||||||
|
year = {2024},
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
# Setup
|
||||||
|
|
||||||
|
This code is based on the [original implementation][5] of the BIB benchmark.
|
||||||
|
|
||||||
|
## Using `virtualenv`
|
||||||
|
```
|
||||||
|
python -m virtualenv /path/to/env
|
||||||
|
source /path/to/env/bin/activate
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
## Using `conda`
|
||||||
|
```
|
||||||
|
conda create --name <env_name> python=3.8.10 pip=20.0.2 cudatoolkit=10.2.89
|
||||||
|
conda activate <env_name>
|
||||||
|
pip install -r requirements_conda.txt
|
||||||
|
pip install dgl-cu102 dglgo -f https://data.dgl.ai/wheels/repo.html
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
# Running the code
|
||||||
|
|
||||||
|
## Activate the environment
|
||||||
|
Run `source bibdgl/bin/activate`.
|
||||||
|
|
||||||
|
## Index data
|
||||||
|
This will create the json files with all the indexed frames for each episode in each video.
|
||||||
|
```
|
||||||
|
python utils/index_data.py
|
||||||
|
```
|
||||||
|
You need to manually set `mode` in the dataset class (in main).
|
||||||
|
|
||||||
|
## Generate graphs
|
||||||
|
This will generate the graphs from the videos:
|
||||||
|
```
|
||||||
|
python /utils/build_graphs.py --mode MODE --cpus NUM_CPUS
|
||||||
|
```
|
||||||
|
`MODE` can be `train`, `val` or `test`. NOTE: check `utils/build_graphs.py` to make sure you're loading the correct dataset to generate the graphs you want.
|
||||||
|
|
||||||
|
## Training
|
||||||
|
You can use the `gtbc.sh`.
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
Use `run_test_tom.sh`.
|
||||||
|
|
||||||
|
# Hardware setup
|
||||||
|
All models are trained on an NVIDIA Tesla V100-SXM2-32GB GPU.
|
||||||
|
|
||||||
|
|
||||||
|
[1]: https://mattbortoletto.github.io/
|
||||||
|
[2]: https://perceptualui.org/people/shi/
|
||||||
|
[3]: https://perceptualui.org/people/bulling/
|
||||||
|
[4]: https://perceptualui.org/publications/bortoletto24_aaai.pdf
|
||||||
|
[5]: https://github.com/kanishkg/bib-baselines
|
||||||
|
|
9
run_test.sh
Normal file
9
run_test.sh
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
echo 314 e31
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=1 python test_tom.py \
|
||||||
|
--model_type graphbcrnn \
|
||||||
|
--types efficiency_irrational \
|
||||||
|
--ckpt /projects/bortoletto/icml2023_matteo/wandb/run-20221224_135525-8i1r2aqy/files/bib/8i1r2aqy/checkpoints/epoch\=31-step\=22399.ckpt \
|
||||||
|
--data_path /datasets/external/bib_evaluation_1_1/graphs/all_tasks \
|
||||||
|
--process_data 0 \
|
||||||
|
--surprise_type max
|
18
run_train.sh
Normal file
18
run_train.sh
Normal file
|
@ -0,0 +1,18 @@
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python train_tom.py \
|
||||||
|
--model_type graphbcrnn \
|
||||||
|
--types single_object preference instrumental_action \
|
||||||
|
--data_path /datasets/external/bib_train/graphs/all_tasks/ \
|
||||||
|
--seed 7 \
|
||||||
|
--batch_size 32 \
|
||||||
|
--max_epochs 35 \
|
||||||
|
--gpus 1 \
|
||||||
|
--auto_select_gpus True \
|
||||||
|
--num_workers 2 \
|
||||||
|
--stochastic_weight_avg True \
|
||||||
|
--lr 5e-4 \
|
||||||
|
--check_val_every_n_epoch 1 \
|
||||||
|
--track_grad_norm 2 \
|
||||||
|
--gradient_clip_val 10 \
|
||||||
|
--gnn_type RSAGEv4 \
|
||||||
|
--state_dim 96 \
|
||||||
|
--aggregation sum
|
122
test_tom.py
Normal file
122
test_tom.py
Normal file
|
@ -0,0 +1,122 @@
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import dgl
|
||||||
|
|
||||||
|
from tom.dataset import TestToMnetDGLDataset, collate_function_seq_test
|
||||||
|
from tom.model import GraphBC_T, GraphBCRNN
|
||||||
|
|
||||||
|
|
||||||
|
def get_z_scores(total, total_expected, total_unexpected):
|
||||||
|
mean = np.mean(total)
|
||||||
|
std = np.std(total)
|
||||||
|
print("Z-Score expected: ",
|
||||||
|
(np.mean(total_expected) - mean) / std)
|
||||||
|
print("Z-Score unexpected: ",
|
||||||
|
(np.mean(total_unexpected) - mean) / std)
|
||||||
|
|
||||||
|
|
||||||
|
parser = ArgumentParser()
|
||||||
|
|
||||||
|
parser.add_argument('--model_type', type=str, default='graphbcrnn')
|
||||||
|
parser.add_argument('--ckpt', type=str, default=None, help='path to checkpoint')
|
||||||
|
parser.add_argument('--data_path', type=str, default=None, help='path to the data')
|
||||||
|
parser.add_argument('--process_data', type=int, default=0)
|
||||||
|
parser.add_argument('--surprise_type', type=str, default='max',
|
||||||
|
help='surprise type: mean, max. This is used for comparing the plausibility scores of the two test episodes')
|
||||||
|
parser.add_argument('--types', nargs='+', type=str,
|
||||||
|
default=[
|
||||||
|
'preference', 'multi_agent', 'inaccessible_goal',
|
||||||
|
'efficiency_irrational', 'efficiency_time','efficiency_path',
|
||||||
|
'instrumental_no_barrier', 'instrumental_blocking_barrier', 'instrumental_inconsequential_barrier'
|
||||||
|
],
|
||||||
|
help='types of tasks used for training / testing')
|
||||||
|
parser.add_argument('--filename', type=str, default='')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
filename = args.filename
|
||||||
|
|
||||||
|
if args.model_type == 'graphbct':
|
||||||
|
model = GraphBC_T.load_from_checkpoint(args.ckpt)
|
||||||
|
elif args.model_type == 'graphbcrnn':
|
||||||
|
model = GraphBCRNN.load_from_checkpoint(args.ckpt)
|
||||||
|
else:
|
||||||
|
raise ValueError('Unknown model type.')
|
||||||
|
|
||||||
|
device = 'cuda'
|
||||||
|
model.to(device)
|
||||||
|
model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
for t in args.types:
|
||||||
|
if args.model_type == 'graphbcrnn':
|
||||||
|
test_dataset = TestToMnetDGLDataset(
|
||||||
|
path=args.data_path,
|
||||||
|
task_type=t,
|
||||||
|
mode='test'
|
||||||
|
)
|
||||||
|
test_dataloader = DataLoader(
|
||||||
|
test_dataset,
|
||||||
|
batch_size=1,
|
||||||
|
num_workers=1,
|
||||||
|
pin_memory=True,
|
||||||
|
collate_fn=collate_function_seq_test,
|
||||||
|
shuffle=False
|
||||||
|
)
|
||||||
|
count = 0
|
||||||
|
total, total_expected, total_unexpected = [], [], []
|
||||||
|
pbar = tqdm(test_dataloader)
|
||||||
|
for j, batch in enumerate(pbar):
|
||||||
|
if args.model_type == 'graphbcrnn':
|
||||||
|
dem_expected_states, dem_expected_actions, dem_expected_lens, \
|
||||||
|
dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, \
|
||||||
|
query_expected_frames, target_expected_actions, \
|
||||||
|
query_unexpected_frames, target_unexpected_actions = batch
|
||||||
|
dem_expected_states = dem_expected_states.to(device)
|
||||||
|
dem_expected_actions = dem_expected_actions.to(device)
|
||||||
|
dem_unexpected_states = dem_unexpected_states.to(device)
|
||||||
|
dem_unexpected_actions = dem_unexpected_actions.to(device)
|
||||||
|
target_expected_actions = target_expected_actions.to(device)
|
||||||
|
target_unexpected_actions = target_unexpected_actions.to(device)
|
||||||
|
surprise_expected = []
|
||||||
|
query_expected_frames = dgl.unbatch(query_expected_frames)
|
||||||
|
for i in range(len(query_expected_frames)):
|
||||||
|
if args.model_type == 'graphbcrnn':
|
||||||
|
test_actions, test_actions_pred = model(
|
||||||
|
[dem_expected_states, dem_expected_actions, dem_expected_lens, query_expected_frames[i].to(device), target_expected_actions[:, i, :]]
|
||||||
|
)
|
||||||
|
loss = F.mse_loss(test_actions, test_actions_pred)
|
||||||
|
surprise_expected.append(loss.cpu().detach().numpy())
|
||||||
|
mean_expected_surprise = np.mean(surprise_expected)
|
||||||
|
max_expected_surprise = np.max(surprise_expected)
|
||||||
|
|
||||||
|
# calculate the plausibility scores for the unexpected episode
|
||||||
|
surprise_unexpected = []
|
||||||
|
query_unexpected_frames = dgl.unbatch(query_unexpected_frames)
|
||||||
|
for i in range(len(query_unexpected_frames)):
|
||||||
|
if args.model_type == 'graphbcrnn':
|
||||||
|
test_actions, test_actions_pred = model(
|
||||||
|
[dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, query_unexpected_frames[i].to(device), target_unexpected_actions[:, i, :]]
|
||||||
|
)
|
||||||
|
loss = F.mse_loss(test_actions, test_actions_pred)
|
||||||
|
surprise_unexpected.append(loss.cpu().detach().numpy())
|
||||||
|
mean_unexpected_surprise = np.mean(surprise_unexpected)
|
||||||
|
max_unexpected_surprise = np.max(surprise_unexpected)
|
||||||
|
|
||||||
|
correct_mean = mean_expected_surprise < mean_unexpected_surprise + 0.5 * (mean_expected_surprise == mean_unexpected_surprise)
|
||||||
|
correct_max = max_expected_surprise < max_unexpected_surprise + 0.5 * (max_expected_surprise == max_unexpected_surprise)
|
||||||
|
if args.surprise_type == 'max':
|
||||||
|
count += correct_max
|
||||||
|
elif args.surprise_type == 'mean':
|
||||||
|
count += correct_mean
|
||||||
|
pbar.set_postfix({'accuracy': count/(j+1.), 'type': t})
|
||||||
|
|
||||||
|
total_expected.append(max_expected_surprise)
|
||||||
|
total_unexpected.append(max_unexpected_surprise)
|
||||||
|
total.append(max_expected_surprise)
|
||||||
|
total.append(max_unexpected_surprise)
|
||||||
|
get_z_scores(total, total_expected, total_unexpected)
|
0
tom/__init__.py
Normal file
0
tom/__init__.py
Normal file
310
tom/dataset.py
Normal file
310
tom/dataset.py
Normal file
|
@ -0,0 +1,310 @@
|
||||||
|
import pickle as pkl
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import torch.utils.data
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import dgl
|
||||||
|
import random
|
||||||
|
from dgl.data import DGLDataset
|
||||||
|
|
||||||
|
|
||||||
|
def collate_function_seq(batch):
|
||||||
|
#dem_frames = torch.stack([item[0] for item in batch])
|
||||||
|
dem_frames = dgl.batch([item[0] for item in batch])
|
||||||
|
dem_actions = torch.stack([item[1] for item in batch])
|
||||||
|
dem_lens = [item[2] for item in batch]
|
||||||
|
#query_frames = torch.stack([item[3] for item in batch])
|
||||||
|
query_frames = dgl.batch([item[3] for item in batch])
|
||||||
|
target_actions = torch.stack([item[4] for item in batch])
|
||||||
|
return [dem_frames, dem_actions, dem_lens, query_frames, target_actions]
|
||||||
|
|
||||||
|
def collate_function_seq_test(batch):
|
||||||
|
dem_expected_states = dgl.batch([item[0] for item in batch][0])
|
||||||
|
dem_expected_actions = torch.stack([item[1] for item in batch][0]).unsqueeze(dim=0)
|
||||||
|
dem_expected_lens = [item[2] for item in batch]
|
||||||
|
#print(dem_expected_actions.size())
|
||||||
|
dem_unexpected_states = dgl.batch([item[3] for item in batch][0])
|
||||||
|
dem_unexpected_actions = torch.stack([item[4] for item in batch][0]).unsqueeze(dim=0)
|
||||||
|
dem_unexpected_lens = [item[5] for item in batch]
|
||||||
|
query_expected_frames = dgl.batch([item[6] for item in batch])
|
||||||
|
target_expected_actions = torch.stack([item[7] for item in batch])
|
||||||
|
#print(target_expected_actions.size())
|
||||||
|
query_unexpected_frames = dgl.batch([item[8] for item in batch])
|
||||||
|
target_unexpected_actions = torch.stack([item[9] for item in batch])
|
||||||
|
return [
|
||||||
|
dem_expected_states, dem_expected_actions, dem_expected_lens, \
|
||||||
|
dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, \
|
||||||
|
query_expected_frames, target_expected_actions, \
|
||||||
|
query_unexpected_frames, target_unexpected_actions
|
||||||
|
]
|
||||||
|
|
||||||
|
def collate_function_mental(batch):
|
||||||
|
dem_frames = dgl.batch([item[0] for item in batch])
|
||||||
|
dem_actions = torch.stack([item[1] for item in batch])
|
||||||
|
dem_lens = [item[2] for item in batch]
|
||||||
|
past_test_frames = dgl.batch([item[3] for item in batch])
|
||||||
|
past_test_actions = torch.stack([item[4] for item in batch])
|
||||||
|
past_test_len = [item[5] for item in batch]
|
||||||
|
query_frames = dgl.batch([item[6] for item in batch])
|
||||||
|
target_actions = torch.stack([item[7] for item in batch])
|
||||||
|
return [dem_frames, dem_actions, dem_lens, past_test_frames, past_test_actions, past_test_len, query_frames, target_actions]
|
||||||
|
|
||||||
|
|
||||||
|
class ToMnetDGLDataset(DGLDataset):
|
||||||
|
"""
|
||||||
|
Training dataset class.
|
||||||
|
"""
|
||||||
|
def __init__(self, path, types=None, mode="train"):
|
||||||
|
self.path = path
|
||||||
|
self.types = types
|
||||||
|
self.mode = mode
|
||||||
|
print('Mode:', self.mode)
|
||||||
|
|
||||||
|
if self.mode == 'train':
|
||||||
|
if len(self.types) == 4:
|
||||||
|
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/'
|
||||||
|
#self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_global/'
|
||||||
|
#self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_local/'
|
||||||
|
elif len(self.types) == 3:
|
||||||
|
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper() + '/'
|
||||||
|
print(self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper())
|
||||||
|
elif len(self.types) == 2:
|
||||||
|
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + '/'
|
||||||
|
print(self.types[0][0].upper() + self.types[1][0].upper())
|
||||||
|
elif len(self.types) == 1:
|
||||||
|
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + '/'
|
||||||
|
else: raise ValueError('Number of types different from 1 or 4.')
|
||||||
|
elif self.mode == 'val':
|
||||||
|
assert len(self.types) == 1
|
||||||
|
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/' + self.types[0] + '/'
|
||||||
|
#self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_global/' + self.types[0] + '/'
|
||||||
|
#self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_local/' + self.types[0] + '/'
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
|
def get_test(self, states, actions):
|
||||||
|
# now states is a batched graph -> unbatch it, take the len, pick one sub-graph
|
||||||
|
# randomly and select the corresponding action
|
||||||
|
frame_graphs = dgl.unbatch(states)
|
||||||
|
trial_len = len(frame_graphs)
|
||||||
|
query_idx = random.randint(0, trial_len - 1)
|
||||||
|
query_graph = frame_graphs[query_idx]
|
||||||
|
target_action = actions[query_idx]
|
||||||
|
return query_graph, target_action
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
with open(self.path+str(idx)+'.pkl', 'rb') as f:
|
||||||
|
states, actions, lens, _ = pkl.load(f)
|
||||||
|
# shuffle
|
||||||
|
ziplist = list(zip(states, actions, lens))
|
||||||
|
random.shuffle(ziplist)
|
||||||
|
states, actions, lens = zip(*ziplist)
|
||||||
|
# convert tuples to lists
|
||||||
|
states, actions, lens = [*states], [*actions], [*lens]
|
||||||
|
# pick last element in the list as test and pick random frame
|
||||||
|
test_s, test_a = self.get_test(states[-1], actions[-1])
|
||||||
|
dem_s = states[:-1]
|
||||||
|
dem_a = actions[:-1]
|
||||||
|
dem_lens = lens[:-1]
|
||||||
|
dem_s = dgl.batch(dem_s)
|
||||||
|
dem_a = torch.stack(dem_a)
|
||||||
|
return dem_s, dem_a, dem_lens, test_s, test_a
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(os.listdir(self.path))
|
||||||
|
|
||||||
|
|
||||||
|
class TestToMnetDGLDataset(DGLDataset):
|
||||||
|
"""
|
||||||
|
Testing dataset class.
|
||||||
|
"""
|
||||||
|
def __init__(self, path, task_type=None, mode="test"):
|
||||||
|
self.path = path
|
||||||
|
self.type = task_type
|
||||||
|
self.mode = mode
|
||||||
|
print('Mode:', self.mode)
|
||||||
|
|
||||||
|
if self.mode == 'test':
|
||||||
|
self.path = self.path + '_dgl_hetero_nobound_4feats/' + self.type + '/'
|
||||||
|
#self.path = self.path + '_dgl_hetero_nobound_4feats_global/' + self.type + '/'
|
||||||
|
#self.path = self.path + '_dgl_hetero_nobound_4feats_local/' + self.type + '/'
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
with open(self.path+str(idx)+'.pkl', 'rb') as f:
|
||||||
|
dem_expected_states, dem_expected_actions, dem_expected_lens, _, \
|
||||||
|
dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, _, \
|
||||||
|
query_expected_frames, target_expected_actions, \
|
||||||
|
query_unexpected_frames, target_unexpected_actions = pkl.load(f)
|
||||||
|
assert len(dem_expected_states) == 8
|
||||||
|
assert len(dem_expected_actions) == 8
|
||||||
|
assert len(dem_expected_lens) == 8
|
||||||
|
assert len(dem_unexpected_states) == 8
|
||||||
|
assert len(dem_unexpected_actions) == 8
|
||||||
|
assert len(dem_unexpected_lens) == 8
|
||||||
|
assert len(dgl.unbatch(query_expected_frames)) == target_expected_actions.size()[0]
|
||||||
|
assert len(dgl.unbatch(query_unexpected_frames)) == target_unexpected_actions.size()[0]
|
||||||
|
# ignore n_nodes
|
||||||
|
return dem_expected_states, dem_expected_actions, dem_expected_lens, \
|
||||||
|
dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, \
|
||||||
|
query_expected_frames, target_expected_actions, \
|
||||||
|
query_unexpected_frames, target_unexpected_actions
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(os.listdir(self.path))
|
||||||
|
|
||||||
|
|
||||||
|
class ToMnetDGLDatasetUndersample(DGLDataset):
|
||||||
|
"""
|
||||||
|
Training dataset class for the behavior cloning mlp model.
|
||||||
|
"""
|
||||||
|
def __init__(self, path, types=None, mode="train"):
|
||||||
|
self.path = path
|
||||||
|
self.types = types
|
||||||
|
self.mode = mode
|
||||||
|
print('Mode:', self.mode)
|
||||||
|
|
||||||
|
if self.mode == 'train':
|
||||||
|
if len(self.types) == 4:
|
||||||
|
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/'
|
||||||
|
elif len(self.types) == 3:
|
||||||
|
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper() + '/'
|
||||||
|
print(self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper())
|
||||||
|
elif len(self.types) == 2:
|
||||||
|
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + '/'
|
||||||
|
print(self.types[0][0].upper() + self.types[1][0].upper())
|
||||||
|
elif len(self.types) == 1:
|
||||||
|
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + '/'
|
||||||
|
else: raise ValueError('Number of types different from 1 or 4.')
|
||||||
|
elif self.mode == 'val':
|
||||||
|
assert len(self.types) == 1
|
||||||
|
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/' + self.types[0] + '/'
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
print('Undersampled dataset!')
|
||||||
|
|
||||||
|
def get_test(self, states, actions):
|
||||||
|
# now states is a batched graph -> unbatch it, take the len, pick one sub-graph
|
||||||
|
# randomly and select the corresponding action
|
||||||
|
frame_graphs = dgl.unbatch(states)
|
||||||
|
trial_len = len(frame_graphs)
|
||||||
|
query_idx = random.randint(0, trial_len - 1)
|
||||||
|
query_graph = frame_graphs[query_idx]
|
||||||
|
target_action = actions[query_idx]
|
||||||
|
return query_graph, target_action
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
idx = idx + 3175
|
||||||
|
with open(self.path+str(idx)+'.pkl', 'rb') as f:
|
||||||
|
states, actions, lens, _ = pkl.load(f)
|
||||||
|
# shuffle
|
||||||
|
ziplist = list(zip(states, actions, lens))
|
||||||
|
random.shuffle(ziplist)
|
||||||
|
states, actions, lens = zip(*ziplist)
|
||||||
|
# convert tuples to lists
|
||||||
|
states, actions, lens = [*states], [*actions], [*lens]
|
||||||
|
# pick last element in the list as test and pick random frame
|
||||||
|
test_s, test_a = self.get_test(states[-1], actions[-1])
|
||||||
|
dem_s = states[:-1]
|
||||||
|
dem_a = actions[:-1]
|
||||||
|
dem_lens = lens[:-1]
|
||||||
|
dem_s = dgl.batch(dem_s)
|
||||||
|
dem_a = torch.stack(dem_a)
|
||||||
|
return dem_s, dem_a, dem_lens, test_s, test_a
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(os.listdir(self.path)) - 3175
|
||||||
|
|
||||||
|
|
||||||
|
class ToMnetDGLDatasetMental(DGLDataset):
|
||||||
|
"""
|
||||||
|
Training dataset class.
|
||||||
|
"""
|
||||||
|
def __init__(self, path, types=None, mode="train"):
|
||||||
|
self.path = path
|
||||||
|
self.types = types
|
||||||
|
self.mode = mode
|
||||||
|
print('Mode:', self.mode)
|
||||||
|
|
||||||
|
if self.mode == 'train':
|
||||||
|
if len(self.types) == 4:
|
||||||
|
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/'
|
||||||
|
elif len(self.types) == 3:
|
||||||
|
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper() + '/'
|
||||||
|
print(self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper())
|
||||||
|
elif len(self.types) == 2:
|
||||||
|
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + '/'
|
||||||
|
print(self.types[0][0].upper() + self.types[1][0].upper())
|
||||||
|
elif len(self.types) == 1:
|
||||||
|
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + '/'
|
||||||
|
else: raise ValueError('Number of types different from 1 or 4.')
|
||||||
|
elif self.mode == 'val':
|
||||||
|
assert len(self.types) == 1
|
||||||
|
self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/' + self.types[0] + '/'
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
|
def get_test(self, states, actions):
|
||||||
|
"""
|
||||||
|
return: past_test_graphs, past_test_actions, test_graph, test_action
|
||||||
|
"""
|
||||||
|
frame_graphs = dgl.unbatch(states)
|
||||||
|
trial_len = len(frame_graphs)
|
||||||
|
query_idx = random.randint(0, trial_len - 1)
|
||||||
|
test_graph = frame_graphs[query_idx]
|
||||||
|
test_action = actions[query_idx]
|
||||||
|
if query_idx > 0:
|
||||||
|
#past_test_actions = F.pad(past_test_actions, (0,0,0,41-query_idx), 'constant', 0)
|
||||||
|
if query_idx == 1:
|
||||||
|
past_test_graphs = frame_graphs[0]
|
||||||
|
past_test_actions = actions[:query_idx]
|
||||||
|
past_test_actions = F.pad(past_test_actions, (0,0,0,41-query_idx), 'constant', 0)
|
||||||
|
return past_test_graphs, past_test_actions, query_idx, test_graph, test_action
|
||||||
|
else:
|
||||||
|
past_test_graphs = frame_graphs[:query_idx]
|
||||||
|
past_test_actions = actions[:query_idx]
|
||||||
|
past_test_graphs = dgl.batch(past_test_graphs)
|
||||||
|
past_test_actions = F.pad(past_test_actions, (0,0,0,41-query_idx), 'constant', 0)
|
||||||
|
return past_test_graphs, past_test_actions, query_idx, test_graph, test_action
|
||||||
|
else:
|
||||||
|
test_action_ = F.pad(test_action.unsqueeze(0), (0,0,0,41-1), 'constant', 0)
|
||||||
|
# NOTE: since there are no past frames, return the test frame and action and query_idx=1
|
||||||
|
# not sure what would be a better alternative
|
||||||
|
return test_graph, test_action_, 1, test_graph, test_action
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
with open(self.path+str(idx)+'.pkl', 'rb') as f:
|
||||||
|
states, actions, lens, _ = pkl.load(f)
|
||||||
|
ziplist = list(zip(states, actions, lens))
|
||||||
|
random.shuffle(ziplist)
|
||||||
|
states, actions, lens = zip(*ziplist)
|
||||||
|
states, actions, lens = [*states], [*actions], [*lens]
|
||||||
|
past_test_s, past_test_a, past_test_len, test_s, test_a = self.get_test(states[-1], actions[-1])
|
||||||
|
dem_s = states[:-1]
|
||||||
|
dem_a = actions[:-1]
|
||||||
|
dem_lens = lens[:-1]
|
||||||
|
dem_s = dgl.batch(dem_s)
|
||||||
|
dem_a = torch.stack(dem_a)
|
||||||
|
return dem_s, dem_a, dem_lens, past_test_s, past_test_a, past_test_len, test_s, test_a
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(os.listdir(self.path))
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
types = [
|
||||||
|
'preference', 'multi_agent', 'inaccessible_goal',
|
||||||
|
'efficiency_irrational', 'efficiency_time','efficiency_path',
|
||||||
|
'instrumental_no_barrier', 'instrumental_blocking_barrier', 'instrumental_inconsequential_barrier'
|
||||||
|
]
|
||||||
|
|
||||||
|
mental_dataset = ToMnetDGLDatasetMental(
|
||||||
|
path='/datasets/external/bib_train/graphs/all_tasks/',
|
||||||
|
types=['instrumental_action'],
|
||||||
|
mode='train'
|
||||||
|
)
|
||||||
|
dem_frames, dem_actions, dem_lens, past_test_frames, past_test_actions, len, test_frame, test_action = mental_dataset.__getitem__(99)
|
||||||
|
breakpoint()
|
877
tom/gnn.py
Normal file
877
tom/gnn.py
Normal file
|
@ -0,0 +1,877 @@
|
||||||
|
import dgl.nn.pytorch as dglnn
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch
|
||||||
|
import dgl
|
||||||
|
import sys
|
||||||
|
import copy
|
||||||
|
|
||||||
|
from wandb import agent
|
||||||
|
sys.path.append('/projects/bortoletto/irene/')
|
||||||
|
from tom.norm import Norm
|
||||||
|
|
||||||
|
|
||||||
|
class RSAGEv4(nn.Module):
|
||||||
|
# multi-layer GNN for one single feature
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_channels,
|
||||||
|
out_channels,
|
||||||
|
rel_names,
|
||||||
|
dropout,
|
||||||
|
n_layers,
|
||||||
|
activation=nn.ELU(),
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.embedding_type = nn.Linear(9, int(hidden_channels))
|
||||||
|
self.embedding_pos = nn.Linear(2, int(hidden_channels))
|
||||||
|
self.embedding_color = nn.Linear(3, int(hidden_channels))
|
||||||
|
self.embedding_shape = nn.Linear(18, int(hidden_channels))
|
||||||
|
self.combine_attr = nn.Sequential(
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(hidden_channels*3, hidden_channels)
|
||||||
|
)
|
||||||
|
self.combine_pos = nn.Sequential(
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(hidden_channels*2, hidden_channels)
|
||||||
|
)
|
||||||
|
self.layers = nn.ModuleList([
|
||||||
|
dglnn.HeteroGraphConv({
|
||||||
|
rel: dglnn.SAGEConv(
|
||||||
|
in_feats=hidden_channels,
|
||||||
|
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
|
||||||
|
aggregator_type='lstm',
|
||||||
|
feat_drop=dropout,
|
||||||
|
activation=activation if l < n_layers - 1 else None
|
||||||
|
)
|
||||||
|
for rel in rel_names}, aggregate='sum')
|
||||||
|
for l in range(n_layers)
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, g):
|
||||||
|
attr = []
|
||||||
|
attr.append(self.embedding_type(g.ndata['type'].float()))
|
||||||
|
pos = self.embedding_pos(g.ndata['pos']/170.)
|
||||||
|
attr.append(self.embedding_color(g.ndata['color']/255.))
|
||||||
|
attr.append(self.embedding_shape(g.ndata['shape'].float()))
|
||||||
|
combined_attr = self.combine_attr(torch.cat(attr, dim=1))
|
||||||
|
h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
|
||||||
|
for l, conv in enumerate(self.layers):
|
||||||
|
h = conv(g, h)
|
||||||
|
with g.local_scope():
|
||||||
|
g.ndata['h'] = h['obj']
|
||||||
|
out = dgl.mean_nodes(g, 'h')
|
||||||
|
return out
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class RAGNNv4(nn.Module):
|
||||||
|
# multi-layer GNN for one single feature
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_channels,
|
||||||
|
out_channels,
|
||||||
|
rel_names,
|
||||||
|
dropout,
|
||||||
|
n_layers,
|
||||||
|
activation=nn.ELU(),
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.embedding_type = nn.Linear(9, int(hidden_channels))
|
||||||
|
self.embedding_pos = nn.Linear(2, int(hidden_channels))
|
||||||
|
self.embedding_color = nn.Linear(3, int(hidden_channels))
|
||||||
|
self.embedding_shape = nn.Linear(18, int(hidden_channels))
|
||||||
|
self.combine_attr = nn.Sequential(
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(hidden_channels*3, hidden_channels)
|
||||||
|
)
|
||||||
|
self.combine_pos = nn.Sequential(
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(hidden_channels*2, hidden_channels)
|
||||||
|
)
|
||||||
|
self.layers = nn.ModuleList([
|
||||||
|
dglnn.HeteroGraphConv({
|
||||||
|
rel: dglnn.AGNNConv()
|
||||||
|
for rel in rel_names}, aggregate='sum')
|
||||||
|
for l in range(n_layers)
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, g):
|
||||||
|
attr = []
|
||||||
|
attr.append(self.embedding_type(g.ndata['type'].float()))
|
||||||
|
pos = self.embedding_pos(g.ndata['pos']/170.)
|
||||||
|
attr.append(self.embedding_color(g.ndata['color']/255.))
|
||||||
|
attr.append(self.embedding_shape(g.ndata['shape'].float()))
|
||||||
|
combined_attr = self.combine_attr(torch.cat(attr, dim=1))
|
||||||
|
h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
|
||||||
|
for l, conv in enumerate(self.layers):
|
||||||
|
h = conv(g, h)
|
||||||
|
with g.local_scope():
|
||||||
|
g.ndata['h'] = h['obj']
|
||||||
|
out = dgl.mean_nodes(g, 'h')
|
||||||
|
return out
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class RGATv2(nn.Module):
|
||||||
|
# multi-layer GNN for one single feature
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_channels,
|
||||||
|
out_channels,
|
||||||
|
num_heads,
|
||||||
|
rel_names,
|
||||||
|
dropout,
|
||||||
|
n_layers,
|
||||||
|
activation=nn.ELU(),
|
||||||
|
residual=False
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.embedding_type = nn.Embedding(9, int(hidden_channels*num_heads/4))
|
||||||
|
self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads/4))
|
||||||
|
self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads/4))
|
||||||
|
self.embedding_shape = nn.Embedding(18, int(hidden_channels*num_heads/4))
|
||||||
|
self.combine = nn.Sequential(
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(hidden_channels*num_heads, hidden_channels*num_heads)
|
||||||
|
)
|
||||||
|
self.layers = nn.ModuleList([
|
||||||
|
dglnn.HeteroGraphConv({
|
||||||
|
rel: dglnn.GATv2Conv(
|
||||||
|
in_feats=hidden_channels*num_heads,
|
||||||
|
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
|
||||||
|
num_heads=num_heads if l < n_layers - 1 else 1,
|
||||||
|
feat_drop=dropout,
|
||||||
|
attn_drop=dropout,
|
||||||
|
residual=residual,
|
||||||
|
activation=activation if l < n_layers - 1 else None
|
||||||
|
)
|
||||||
|
for rel in rel_names}, aggregate='sum')
|
||||||
|
for l in range(n_layers)
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, g):
|
||||||
|
feats = []
|
||||||
|
feats.append(self.embedding_type(torch.argmax(g.ndata['type'], dim=1)))
|
||||||
|
feats.append(self.embedding_pos(g.ndata['pos']/170.))
|
||||||
|
feats.append(self.embedding_color(g.ndata['color']/255.))
|
||||||
|
feats.append(self.embedding_shape(torch.argmax(g.ndata['shape'], dim=1)))
|
||||||
|
h = {'obj': self.combine(torch.cat(feats, dim=1))}
|
||||||
|
for l, conv in enumerate(self.layers):
|
||||||
|
h = conv(g, h)
|
||||||
|
if l != len(self.layers) - 1:
|
||||||
|
h = {k: v.flatten(1) for k, v in h.items()}
|
||||||
|
else:
|
||||||
|
h = {k: v.mean(1) for k, v in h.items()}
|
||||||
|
with g.local_scope():
|
||||||
|
g.ndata['h'] = h['obj']
|
||||||
|
out = dgl.mean_nodes(g, 'h')
|
||||||
|
return out
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class RGATv3(nn.Module):
|
||||||
|
# multi-layer GNN for one single feature
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_channels,
|
||||||
|
out_channels,
|
||||||
|
num_heads,
|
||||||
|
rel_names,
|
||||||
|
dropout,
|
||||||
|
n_layers,
|
||||||
|
activation=nn.ELU(),
|
||||||
|
residual=False
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
|
||||||
|
self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
|
||||||
|
self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
|
||||||
|
self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
|
||||||
|
self.combine = nn.Sequential(
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(hidden_channels*num_heads*4, hidden_channels*num_heads)
|
||||||
|
)
|
||||||
|
self.layers = nn.ModuleList([
|
||||||
|
dglnn.HeteroGraphConv({
|
||||||
|
rel: dglnn.GATv2Conv(
|
||||||
|
in_feats=hidden_channels*num_heads,
|
||||||
|
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
|
||||||
|
num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always
|
||||||
|
feat_drop=dropout,
|
||||||
|
attn_drop=dropout,
|
||||||
|
residual=residual,
|
||||||
|
activation=activation if l < n_layers - 1 else None
|
||||||
|
)
|
||||||
|
for rel in rel_names}, aggregate='sum')
|
||||||
|
for l in range(n_layers)
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, g):
|
||||||
|
feats = []
|
||||||
|
feats.append(self.embedding_type(g.ndata['type'].float()))
|
||||||
|
feats.append(self.embedding_pos(g.ndata['pos']/170.)) # NOTE: this should be 180 because I remove the boundary walls!
|
||||||
|
feats.append(self.embedding_color(g.ndata['color']/255.))
|
||||||
|
feats.append(self.embedding_shape(g.ndata['shape'].float()))
|
||||||
|
h = {'obj': self.combine(torch.cat(feats, dim=1))}
|
||||||
|
for l, conv in enumerate(self.layers):
|
||||||
|
h = conv(g, h)
|
||||||
|
if l != len(self.layers) - 1:
|
||||||
|
h = {k: v.flatten(1) for k, v in h.items()}
|
||||||
|
else:
|
||||||
|
h = {k: v.mean(1) for k, v in h.items()}
|
||||||
|
with g.local_scope():
|
||||||
|
g.ndata['h'] = h['obj']
|
||||||
|
out = dgl.mean_nodes(g, 'h')
|
||||||
|
return out
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class RGCNv2(nn.Module):
|
||||||
|
# multi-layer GNN for one single feature
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_channels,
|
||||||
|
out_channels,
|
||||||
|
rel_names,
|
||||||
|
dropout,
|
||||||
|
n_layers,
|
||||||
|
activation=nn.ReLU()
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.embedding_type = nn.Linear(9, int(hidden_channels))
|
||||||
|
self.embedding_pos = nn.Linear(2, int(hidden_channels))
|
||||||
|
self.embedding_color = nn.Linear(3, int(hidden_channels))
|
||||||
|
self.embedding_shape = nn.Linear(18, int(hidden_channels))
|
||||||
|
self.combine = nn.Sequential(
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(hidden_channels*4, hidden_channels)
|
||||||
|
)
|
||||||
|
self.layers = nn.ModuleList([
|
||||||
|
dglnn.RelGraphConv(
|
||||||
|
in_feat=hidden_channels,
|
||||||
|
out_feat=hidden_channels,
|
||||||
|
num_rels=len(rel_names),
|
||||||
|
regularizer=None,
|
||||||
|
num_bases=None,
|
||||||
|
bias=True,
|
||||||
|
activation=activation,
|
||||||
|
self_loop=True,
|
||||||
|
dropout=dropout,
|
||||||
|
layer_norm=False
|
||||||
|
)
|
||||||
|
for _ in range(n_layers-1)])
|
||||||
|
self.layers.append(
|
||||||
|
dglnn.RelGraphConv(
|
||||||
|
in_feat=hidden_channels,
|
||||||
|
out_feat=out_channels,
|
||||||
|
num_rels=len(rel_names),
|
||||||
|
regularizer=None,
|
||||||
|
num_bases=None,
|
||||||
|
bias=True,
|
||||||
|
activation=activation,
|
||||||
|
self_loop=True,
|
||||||
|
dropout=dropout,
|
||||||
|
layer_norm=False
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, g):
|
||||||
|
g = g.to_homogeneous()
|
||||||
|
feats = []
|
||||||
|
feats.append(self.embedding_type(g.ndata['type'].float()))
|
||||||
|
feats.append(self.embedding_pos(g.ndata['pos']/170.))
|
||||||
|
feats.append(self.embedding_color(g.ndata['color']/255.))
|
||||||
|
feats.append(self.embedding_shape(g.ndata['shape'].float()))
|
||||||
|
h = self.combine(torch.cat(feats, dim=1))
|
||||||
|
for conv in self.layers:
|
||||||
|
h = conv(g, h, g.etypes)
|
||||||
|
with g.local_scope():
|
||||||
|
g.ndata['h'] = h
|
||||||
|
out = dgl.mean_nodes(g, 'h')
|
||||||
|
return out
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class RGATv3Agent(nn.Module):
|
||||||
|
# multi-layer GNN for one single feature
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_channels,
|
||||||
|
out_channels,
|
||||||
|
num_heads,
|
||||||
|
rel_names,
|
||||||
|
dropout,
|
||||||
|
n_layers,
|
||||||
|
activation=nn.ELU(),
|
||||||
|
residual=False
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
|
||||||
|
self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
|
||||||
|
self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
|
||||||
|
self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
|
||||||
|
self.combine = nn.Sequential(
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(hidden_channels*num_heads*4, hidden_channels*num_heads)
|
||||||
|
)
|
||||||
|
self.layers = nn.ModuleList([
|
||||||
|
dglnn.HeteroGraphConv({
|
||||||
|
rel: dglnn.GATv2Conv(
|
||||||
|
in_feats=hidden_channels*num_heads,
|
||||||
|
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
|
||||||
|
num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always
|
||||||
|
feat_drop=dropout,
|
||||||
|
attn_drop=dropout,
|
||||||
|
residual=residual,
|
||||||
|
activation=activation if l < n_layers - 1 else None
|
||||||
|
)
|
||||||
|
for rel in rel_names}, aggregate='sum')
|
||||||
|
for l in range(n_layers)
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, g):
|
||||||
|
agent_mask = g.ndata['type'][:, 0] == 1
|
||||||
|
feats = []
|
||||||
|
feats.append(self.embedding_type(g.ndata['type'].float()))
|
||||||
|
feats.append(self.embedding_pos(g.ndata['pos']/200.))
|
||||||
|
feats.append(self.embedding_color(g.ndata['color']/255.))
|
||||||
|
feats.append(self.embedding_shape(g.ndata['shape'].float()))
|
||||||
|
h = {'obj': self.combine(torch.cat(feats, dim=1))}
|
||||||
|
for l, conv in enumerate(self.layers):
|
||||||
|
h = conv(g, h)
|
||||||
|
if l != len(self.layers) - 1:
|
||||||
|
h = {k: v.flatten(1) for k, v in h.items()}
|
||||||
|
else:
|
||||||
|
h = {k: v.mean(1) for k, v in h.items()}
|
||||||
|
with g.local_scope():
|
||||||
|
g.ndata['h'] = h['obj']
|
||||||
|
out = g.ndata['h'][agent_mask, :]
|
||||||
|
ctx = dgl.mean_nodes(g, 'h')
|
||||||
|
return out + ctx
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class RGATv4(nn.Module):
|
||||||
|
# multi-layer GNN for one single feature
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_channels,
|
||||||
|
out_channels,
|
||||||
|
num_heads,
|
||||||
|
rel_names,
|
||||||
|
dropout,
|
||||||
|
n_layers,
|
||||||
|
activation=nn.ELU(),
|
||||||
|
residual=False
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
|
||||||
|
self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
|
||||||
|
self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
|
||||||
|
self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
|
||||||
|
self.combine_attr = nn.Sequential(
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(hidden_channels*num_heads*3, hidden_channels*num_heads)
|
||||||
|
)
|
||||||
|
self.combine_pos = nn.Sequential(
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(hidden_channels*num_heads*2, hidden_channels*num_heads)
|
||||||
|
)
|
||||||
|
self.layers = nn.ModuleList([
|
||||||
|
dglnn.HeteroGraphConv({
|
||||||
|
rel: dglnn.GATv2Conv(
|
||||||
|
in_feats=hidden_channels*num_heads,
|
||||||
|
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
|
||||||
|
num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always
|
||||||
|
feat_drop=dropout,
|
||||||
|
attn_drop=dropout,
|
||||||
|
residual=residual,
|
||||||
|
share_weights=False,
|
||||||
|
activation=activation if l < n_layers - 1 else None
|
||||||
|
)
|
||||||
|
for rel in rel_names}, aggregate='sum')
|
||||||
|
for l in range(n_layers)
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, g):
|
||||||
|
attr = []
|
||||||
|
attr.append(self.embedding_type(g.ndata['type'].float()))
|
||||||
|
pos = self.embedding_pos(g.ndata['pos']/170.)
|
||||||
|
attr.append(self.embedding_color(g.ndata['color']/255.))
|
||||||
|
attr.append(self.embedding_shape(g.ndata['shape'].float()))
|
||||||
|
combined_attr = self.combine_attr(torch.cat(attr, dim=1))
|
||||||
|
h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
|
||||||
|
for l, conv in enumerate(self.layers):
|
||||||
|
h = conv(g, h)
|
||||||
|
if l != len(self.layers) - 1:
|
||||||
|
h = {k: v.flatten(1) for k, v in h.items()}
|
||||||
|
else:
|
||||||
|
h = {k: v.mean(1) for k, v in h.items()}
|
||||||
|
with g.local_scope():
|
||||||
|
g.ndata['h'] = h['obj']
|
||||||
|
out = dgl.mean_nodes(g, 'h')
|
||||||
|
return out
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class RGATv4Norm(nn.Module):
|
||||||
|
# multi-layer GNN for one single feature
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_channels,
|
||||||
|
out_channels,
|
||||||
|
num_heads,
|
||||||
|
rel_names,
|
||||||
|
dropout,
|
||||||
|
n_layers,
|
||||||
|
activation=nn.ELU(),
|
||||||
|
residual=False
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
|
||||||
|
self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
|
||||||
|
self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
|
||||||
|
self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
|
||||||
|
self.combine_attr = nn.Sequential(
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(hidden_channels*num_heads*3, hidden_channels*num_heads)
|
||||||
|
)
|
||||||
|
self.combine_pos = nn.Sequential(
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(hidden_channels*num_heads*2, hidden_channels*num_heads)
|
||||||
|
)
|
||||||
|
self.layers = nn.ModuleList([
|
||||||
|
dglnn.HeteroGraphConv({
|
||||||
|
rel: dglnn.GATv2Conv(
|
||||||
|
in_feats=hidden_channels*num_heads,
|
||||||
|
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
|
||||||
|
num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always
|
||||||
|
feat_drop=dropout,
|
||||||
|
attn_drop=dropout,
|
||||||
|
residual=residual,
|
||||||
|
activation=activation if l < n_layers - 1 else None
|
||||||
|
)
|
||||||
|
for rel in rel_names}, aggregate='sum')
|
||||||
|
for l in range(n_layers)
|
||||||
|
])
|
||||||
|
self.norms = nn.ModuleList([
|
||||||
|
Norm(
|
||||||
|
norm_type='gn',
|
||||||
|
hidden_dim=hidden_channels*num_heads if l < n_layers - 1 else out_channels
|
||||||
|
)
|
||||||
|
for l in range(n_layers)
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, g):
|
||||||
|
attr = []
|
||||||
|
attr.append(self.embedding_type(g.ndata['type'].float()))
|
||||||
|
pos = self.embedding_pos(g.ndata['pos']/170.)
|
||||||
|
attr.append(self.embedding_color(g.ndata['color']/255.))
|
||||||
|
attr.append(self.embedding_shape(g.ndata['shape'].float()))
|
||||||
|
combined_attr = self.combine_attr(torch.cat(attr, dim=1))
|
||||||
|
h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
|
||||||
|
for l, conv in enumerate(self.layers):
|
||||||
|
h = conv(g, h)
|
||||||
|
if l != len(self.layers) - 1:
|
||||||
|
h = {k: v.flatten(1) for k, v in h.items()}
|
||||||
|
else:
|
||||||
|
h = {k: v.mean(1) for k, v in h.items()}
|
||||||
|
h = {k: self.norms[l](g, v) for k, v in h.items()}
|
||||||
|
with g.local_scope():
|
||||||
|
g.ndata['h'] = h['obj']
|
||||||
|
out = dgl.mean_nodes(g, 'h')
|
||||||
|
return out
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class RGATv3Norm(nn.Module):
|
||||||
|
# multi-layer GNN for one single feature
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_channels,
|
||||||
|
out_channels,
|
||||||
|
num_heads,
|
||||||
|
rel_names,
|
||||||
|
dropout,
|
||||||
|
n_layers,
|
||||||
|
activation=nn.ELU(),
|
||||||
|
residual=False
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
|
||||||
|
self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
|
||||||
|
self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
|
||||||
|
self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
|
||||||
|
self.combine = nn.Sequential(
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(hidden_channels*num_heads*4, hidden_channels*num_heads)
|
||||||
|
)
|
||||||
|
self.layers = nn.ModuleList([
|
||||||
|
dglnn.HeteroGraphConv({
|
||||||
|
rel: dglnn.GATv2Conv(
|
||||||
|
in_feats=hidden_channels*num_heads,
|
||||||
|
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
|
||||||
|
num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always
|
||||||
|
feat_drop=dropout,
|
||||||
|
attn_drop=dropout,
|
||||||
|
residual=residual,
|
||||||
|
activation=activation if l < n_layers - 1 else None
|
||||||
|
)
|
||||||
|
for rel in rel_names}, aggregate='sum')
|
||||||
|
for l in range(n_layers)
|
||||||
|
])
|
||||||
|
self.norms = nn.ModuleList([
|
||||||
|
Norm(
|
||||||
|
norm_type='gn',
|
||||||
|
hidden_dim=hidden_channels*num_heads if l < n_layers - 1 else out_channels
|
||||||
|
)
|
||||||
|
for l in range(n_layers)
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, g):
|
||||||
|
feats = []
|
||||||
|
feats.append(self.embedding_type(g.ndata['type'].float()))
|
||||||
|
feats.append(self.embedding_pos(g.ndata['pos']/170.))
|
||||||
|
feats.append(self.embedding_color(g.ndata['color']/255.))
|
||||||
|
feats.append(self.embedding_shape(g.ndata['shape'].float()))
|
||||||
|
h = {'obj': self.combine(torch.cat(feats, dim=1))}
|
||||||
|
for l, conv in enumerate(self.layers):
|
||||||
|
h = conv(g, h)
|
||||||
|
if l != len(self.layers) - 1:
|
||||||
|
h = {k: v.flatten(1) for k, v in h.items()}
|
||||||
|
else:
|
||||||
|
h = {k: v.mean(1) for k, v in h.items()}
|
||||||
|
h = {k: self.norms[l](g, v) for k, v in h.items()}
|
||||||
|
with g.local_scope():
|
||||||
|
g.ndata['h'] = h['obj']
|
||||||
|
out = dgl.mean_nodes(g, 'h')
|
||||||
|
return out
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class RGATv4Agent(nn.Module):
|
||||||
|
# multi-layer GNN for one single feature
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_channels,
|
||||||
|
out_channels,
|
||||||
|
num_heads,
|
||||||
|
rel_names,
|
||||||
|
dropout,
|
||||||
|
n_layers,
|
||||||
|
activation=nn.ELU(),
|
||||||
|
residual=False
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
|
||||||
|
self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
|
||||||
|
self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
|
||||||
|
self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
|
||||||
|
self.combine_attr = nn.Sequential(
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(hidden_channels*num_heads*3, hidden_channels*num_heads)
|
||||||
|
)
|
||||||
|
self.combine_pos = nn.Sequential(
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(hidden_channels*num_heads*2, hidden_channels*num_heads)
|
||||||
|
)
|
||||||
|
self.layers = nn.ModuleList([
|
||||||
|
dglnn.HeteroGraphConv({
|
||||||
|
rel: dglnn.GATv2Conv(
|
||||||
|
in_feats=hidden_channels*num_heads,
|
||||||
|
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
|
||||||
|
num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always
|
||||||
|
feat_drop=dropout,
|
||||||
|
attn_drop=dropout,
|
||||||
|
residual=residual,
|
||||||
|
activation=activation if l < n_layers - 1 else None
|
||||||
|
)
|
||||||
|
for rel in rel_names}, aggregate='sum')
|
||||||
|
for l in range(n_layers)
|
||||||
|
])
|
||||||
|
self.combine_agent_context = nn.Linear(out_channels*2, out_channels)
|
||||||
|
|
||||||
|
def forward(self, g):
|
||||||
|
agent_mask = g.ndata['type'][:, 0] == 1
|
||||||
|
attr = []
|
||||||
|
attr.append(self.embedding_type(g.ndata['type'].float()))
|
||||||
|
pos = self.embedding_pos(g.ndata['pos']/170.)
|
||||||
|
attr.append(self.embedding_color(g.ndata['color']/255.))
|
||||||
|
attr.append(self.embedding_shape(g.ndata['shape'].float()))
|
||||||
|
combined_attr = self.combine_attr(torch.cat(attr, dim=1))
|
||||||
|
h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
|
||||||
|
for l, conv in enumerate(self.layers):
|
||||||
|
h = conv(g, h)
|
||||||
|
if l != len(self.layers) - 1:
|
||||||
|
h = {k: v.flatten(1) for k, v in h.items()}
|
||||||
|
else:
|
||||||
|
h = {k: v.mean(1) for k, v in h.items()}
|
||||||
|
with g.local_scope():
|
||||||
|
g.ndata['h'] = h['obj']
|
||||||
|
h_a = g.ndata['h'][agent_mask, :]
|
||||||
|
g_no_agent = copy.deepcopy(g)
|
||||||
|
g_no_agent.remove_nodes([i for i, x in enumerate(agent_mask) if x])
|
||||||
|
h_g = dgl.mean_nodes(g_no_agent, 'h')
|
||||||
|
out = self.combine_agent_context(torch.cat((h_a, h_g), dim=1))
|
||||||
|
return out
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class RGCNv4(nn.Module):
|
||||||
|
# multi-layer GNN for one single feature
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_channels,
|
||||||
|
out_channels,
|
||||||
|
rel_names,
|
||||||
|
n_layers,
|
||||||
|
activation=nn.ELU(),
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.embedding_type = nn.Linear(9, int(hidden_channels))
|
||||||
|
self.embedding_pos = nn.Linear(2, int(hidden_channels))
|
||||||
|
self.embedding_color = nn.Linear(3, int(hidden_channels))
|
||||||
|
self.embedding_shape = nn.Linear(18, int(hidden_channels))
|
||||||
|
self.combine_attr = nn.Sequential(
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(hidden_channels*3, hidden_channels)
|
||||||
|
)
|
||||||
|
self.combine_pos = nn.Sequential(
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(hidden_channels*2, hidden_channels)
|
||||||
|
)
|
||||||
|
self.layers = nn.ModuleList([
|
||||||
|
dglnn.HeteroGraphConv({
|
||||||
|
rel: dglnn.GraphConv(
|
||||||
|
in_feats=hidden_channels,
|
||||||
|
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
|
||||||
|
activation=activation if l < n_layers - 1 else None
|
||||||
|
)
|
||||||
|
for rel in rel_names}, aggregate='sum')
|
||||||
|
for l in range(n_layers)
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, g):
|
||||||
|
attr = []
|
||||||
|
attr.append(self.embedding_type(g.ndata['type'].float()))
|
||||||
|
pos = self.embedding_pos(g.ndata['pos']/170.)
|
||||||
|
attr.append(self.embedding_color(g.ndata['color']/255.))
|
||||||
|
attr.append(self.embedding_shape(g.ndata['shape'].float()))
|
||||||
|
combined_attr = self.combine_attr(torch.cat(attr, dim=1))
|
||||||
|
h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
|
||||||
|
for l, conv in enumerate(self.layers):
|
||||||
|
h = conv(g, h)
|
||||||
|
with g.local_scope():
|
||||||
|
g.ndata['h'] = h['obj']
|
||||||
|
out = dgl.mean_nodes(g, 'h')
|
||||||
|
return out
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class RGATv5(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_channels,
|
||||||
|
out_channels,
|
||||||
|
num_heads,
|
||||||
|
rel_names,
|
||||||
|
dropout,
|
||||||
|
n_layers,
|
||||||
|
activation=nn.ELU(),
|
||||||
|
residual=False
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_channels = hidden_channels
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.embedding_type = nn.Linear(9, hidden_channels*num_heads)
|
||||||
|
self.embedding_pos = nn.Linear(2, hidden_channels*num_heads)
|
||||||
|
self.embedding_color = nn.Linear(3, hidden_channels*num_heads)
|
||||||
|
self.embedding_shape = nn.Linear(18, hidden_channels*num_heads)
|
||||||
|
self.combine = nn.Sequential(
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(hidden_channels*num_heads*4, hidden_channels*num_heads)
|
||||||
|
)
|
||||||
|
self.attention = nn.Linear(hidden_channels*num_heads*4, 4)
|
||||||
|
self.layers = nn.ModuleList([
|
||||||
|
dglnn.HeteroGraphConv({
|
||||||
|
rel: dglnn.GATv2Conv(
|
||||||
|
in_feats=hidden_channels*num_heads,
|
||||||
|
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
|
||||||
|
num_heads=num_heads,
|
||||||
|
feat_drop=dropout,
|
||||||
|
attn_drop=dropout,
|
||||||
|
residual=residual,
|
||||||
|
activation=activation if l < n_layers - 1 else None
|
||||||
|
)
|
||||||
|
for rel in rel_names}, aggregate='sum')
|
||||||
|
for l in range(n_layers)
|
||||||
|
])
|
||||||
|
|
||||||
|
def forward(self, g):
|
||||||
|
feats = []
|
||||||
|
feats.append(self.embedding_type(g.ndata['type'].float()))
|
||||||
|
feats.append(self.embedding_pos(g.ndata['pos']/170.))
|
||||||
|
feats.append(self.embedding_color(g.ndata['color']/255.))
|
||||||
|
feats.append(self.embedding_shape(g.ndata['shape'].float()))
|
||||||
|
h = torch.cat(feats, dim=1)
|
||||||
|
feat_attn = F.softmax(self.attention(h), dim=1)
|
||||||
|
h = h * feat_attn.repeat_interleave(self.hidden_channels*self.num_heads, dim=1)
|
||||||
|
h_in = self.combine(h)
|
||||||
|
h = {'obj': h_in}
|
||||||
|
for conv in self.layers:
|
||||||
|
h = conv(g, h)
|
||||||
|
h = {k: v.flatten(1) for k, v in h.items()}
|
||||||
|
#if l != len(self.layers) - 1:
|
||||||
|
# h = {k: v.flatten(1) for k, v in h.items()}
|
||||||
|
#else:
|
||||||
|
# h = {k: v.mean(1) for k, v in h.items()}
|
||||||
|
h = {k: v + h_in for k, v in h.items()}
|
||||||
|
with g.local_scope():
|
||||||
|
g.ndata['h'] = h['obj']
|
||||||
|
out = dgl.mean_nodes(g, 'h')
|
||||||
|
return out
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class RGATv6(nn.Module):
|
||||||
|
|
||||||
|
# RGATv6 = RGATv4 + Global Attention Pooling
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_channels,
|
||||||
|
out_channels,
|
||||||
|
num_heads,
|
||||||
|
rel_names,
|
||||||
|
dropout,
|
||||||
|
n_layers,
|
||||||
|
activation=nn.ELU(),
|
||||||
|
residual=False
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
|
||||||
|
self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
|
||||||
|
self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
|
||||||
|
self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
|
||||||
|
self.combine_attr = nn.Sequential(
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(hidden_channels*num_heads*3, hidden_channels*num_heads)
|
||||||
|
)
|
||||||
|
self.combine_pos = nn.Sequential(
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(hidden_channels*num_heads*2, hidden_channels*num_heads)
|
||||||
|
)
|
||||||
|
self.layers = nn.ModuleList([
|
||||||
|
dglnn.HeteroGraphConv({
|
||||||
|
rel: dglnn.GATv2Conv(
|
||||||
|
in_feats=hidden_channels*num_heads,
|
||||||
|
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
|
||||||
|
num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always
|
||||||
|
feat_drop=dropout,
|
||||||
|
attn_drop=dropout,
|
||||||
|
residual=residual,
|
||||||
|
activation=activation if l < n_layers - 1 else None
|
||||||
|
)
|
||||||
|
for rel in rel_names}, aggregate='sum')
|
||||||
|
for l in range(n_layers)
|
||||||
|
])
|
||||||
|
gate_nn = nn.Linear(out_channels, 1)
|
||||||
|
self.gap = dglnn.GlobalAttentionPooling(gate_nn)
|
||||||
|
|
||||||
|
def forward(self, g):
|
||||||
|
attr = []
|
||||||
|
attr.append(self.embedding_type(g.ndata['type'].float()))
|
||||||
|
pos = self.embedding_pos(g.ndata['pos']/170.)
|
||||||
|
attr.append(self.embedding_color(g.ndata['color']/255.))
|
||||||
|
attr.append(self.embedding_shape(g.ndata['shape'].float()))
|
||||||
|
combined_attr = self.combine_attr(torch.cat(attr, dim=1))
|
||||||
|
h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
|
||||||
|
for l, conv in enumerate(self.layers):
|
||||||
|
h = conv(g, h)
|
||||||
|
if l != len(self.layers) - 1:
|
||||||
|
h = {k: v.flatten(1) for k, v in h.items()}
|
||||||
|
else:
|
||||||
|
h = {k: v.mean(1) for k, v in h.items()}
|
||||||
|
#with g.local_scope():
|
||||||
|
#g.ndata['h'] = h['obj']
|
||||||
|
#out = dgl.mean_nodes(g, 'h')
|
||||||
|
out = self.gap(g, h['obj'])
|
||||||
|
return out
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class RGATv6Agent(nn.Module):
|
||||||
|
|
||||||
|
# RGATv6 = RGATv4 + Global Attention Pooling
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_channels,
|
||||||
|
out_channels,
|
||||||
|
num_heads,
|
||||||
|
rel_names,
|
||||||
|
dropout,
|
||||||
|
n_layers,
|
||||||
|
activation=nn.ELU(),
|
||||||
|
residual=False
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.embedding_type = nn.Linear(9, int(hidden_channels*num_heads))
|
||||||
|
self.embedding_pos = nn.Linear(2, int(hidden_channels*num_heads))
|
||||||
|
self.embedding_color = nn.Linear(3, int(hidden_channels*num_heads))
|
||||||
|
self.embedding_shape = nn.Linear(18, int(hidden_channels*num_heads))
|
||||||
|
self.combine_attr = nn.Sequential(
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(hidden_channels*num_heads*3, hidden_channels*num_heads)
|
||||||
|
)
|
||||||
|
self.combine_pos = nn.Sequential(
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(hidden_channels*num_heads*2, hidden_channels*num_heads)
|
||||||
|
)
|
||||||
|
self.layers = nn.ModuleList([
|
||||||
|
dglnn.HeteroGraphConv({
|
||||||
|
rel: dglnn.GATv2Conv(
|
||||||
|
in_feats=hidden_channels*num_heads,
|
||||||
|
out_feats=hidden_channels if l < n_layers - 1 else out_channels,
|
||||||
|
num_heads=num_heads if l < n_layers - 1 else 1, # TODO: change to num_heads always
|
||||||
|
feat_drop=dropout,
|
||||||
|
attn_drop=dropout,
|
||||||
|
residual=residual,
|
||||||
|
activation=activation if l < n_layers - 1 else None
|
||||||
|
)
|
||||||
|
for rel in rel_names}, aggregate='sum')
|
||||||
|
for l in range(n_layers)
|
||||||
|
])
|
||||||
|
gate_nn = nn.Linear(out_channels, 1)
|
||||||
|
self.gap = dglnn.GlobalAttentionPooling(gate_nn)
|
||||||
|
self.combine_agent_context = nn.Linear(out_channels*2, out_channels)
|
||||||
|
|
||||||
|
def forward(self, g):
|
||||||
|
agent_mask = g.ndata['type'][:, 0] == 1
|
||||||
|
attr = []
|
||||||
|
attr.append(self.embedding_type(g.ndata['type'].float()))
|
||||||
|
pos = self.embedding_pos(g.ndata['pos']/170.)
|
||||||
|
attr.append(self.embedding_color(g.ndata['color']/255.))
|
||||||
|
attr.append(self.embedding_shape(g.ndata['shape'].float()))
|
||||||
|
combined_attr = self.combine_attr(torch.cat(attr, dim=1))
|
||||||
|
h = {'obj': self.combine_pos(torch.cat((pos, combined_attr), dim=1))}
|
||||||
|
for l, conv in enumerate(self.layers):
|
||||||
|
h = conv(g, h)
|
||||||
|
if l != len(self.layers) - 1:
|
||||||
|
h = {k: v.flatten(1) for k, v in h.items()}
|
||||||
|
else:
|
||||||
|
h = {k: v.mean(1) for k, v in h.items()}
|
||||||
|
with g.local_scope():
|
||||||
|
g.ndata['h'] = h['obj']
|
||||||
|
h_a = g.ndata['h'][agent_mask, :]
|
||||||
|
h_g = g.ndata['h'][~agent_mask, :]
|
||||||
|
g_no_agent = copy.deepcopy(g)
|
||||||
|
g_no_agent.remove_nodes([i for i, x in enumerate(agent_mask) if x])
|
||||||
|
h_g = self.gap(g_no_agent, h_g) # dgl.mean_nodes(g_no_agent, 'h')
|
||||||
|
out = self.combine_agent_context(torch.cat((h_a, h_g), dim=1))
|
||||||
|
return out
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------------------------
|
||||||
|
|
513
tom/model.py
Normal file
513
tom/model.py
Normal file
|
@ -0,0 +1,513 @@
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from tom.dataset import *
|
||||||
|
from tom.transformer import TransformerEncoder
|
||||||
|
from tom.gnn import RGATv2, RGATv3, RGATv3Agent, RGATv4, RGATv4Norm, RSAGEv4, RAGNNv4
|
||||||
|
|
||||||
|
|
||||||
|
class MlpModel(nn.Module):
|
||||||
|
"""Multilayer Perceptron with last layer linear.
|
||||||
|
Args:
|
||||||
|
input_size (int): number of inputs
|
||||||
|
hidden_sizes (list): can be empty list for none (linear model).
|
||||||
|
output_size: linear layer at output, or if ``None``, the last hidden size
|
||||||
|
will be the output size and will have nonlinearity applied
|
||||||
|
nonlinearity: torch nonlinearity Module (not Functional).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_size,
|
||||||
|
hidden_sizes, # Can be empty list or None for none.
|
||||||
|
output_size=None, # if None, last layer has nonlinearity applied.
|
||||||
|
nonlinearity=nn.ReLU, # Module, not Functional.
|
||||||
|
dropout=None # Dropout value
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
if isinstance(hidden_sizes, int):
|
||||||
|
hidden_sizes = [hidden_sizes]
|
||||||
|
elif hidden_sizes is None:
|
||||||
|
hidden_sizes = []
|
||||||
|
hidden_layers = [nn.Linear(n_in, n_out) for n_in, n_out in
|
||||||
|
zip([input_size] + hidden_sizes[:-1], hidden_sizes)]
|
||||||
|
sequence = list()
|
||||||
|
for i, layer in enumerate(hidden_layers):
|
||||||
|
if dropout is not None:
|
||||||
|
sequence.extend([layer, nonlinearity(), nn.Dropout(dropout)])
|
||||||
|
else:
|
||||||
|
sequence.extend([layer, nonlinearity()])
|
||||||
|
|
||||||
|
if output_size is not None:
|
||||||
|
last_size = hidden_sizes[-1] if hidden_sizes else input_size
|
||||||
|
sequence.append(torch.nn.Linear(last_size, output_size))
|
||||||
|
self.model = nn.Sequential(*sequence)
|
||||||
|
self._output_size = (hidden_sizes[-1] if output_size is None
|
||||||
|
else output_size)
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
"""Compute the model on the input, assuming input shape [B,input_size]."""
|
||||||
|
return self.model(input)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_size(self):
|
||||||
|
"""Retuns the output size of the model."""
|
||||||
|
return self._output_size
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class GraphBCRNN(pl.LightningModule):
|
||||||
|
"""
|
||||||
|
Implementation of the baseline model for the BC-RNN algorithm.
|
||||||
|
R-GCN + LSTM are used to encode the familiarization trials
|
||||||
|
"""
|
||||||
|
@staticmethod
|
||||||
|
def add_model_specific_args(parent_parser):
|
||||||
|
parser = ArgumentParser(parents=[parent_parser], add_help=False)
|
||||||
|
parser.add_argument('--lr', type=float, default=1e-4)
|
||||||
|
parser.add_argument('--action_dim', type=int, default=2)
|
||||||
|
parser.add_argument('--context_dim', type=int, default=32) # lstm hidden size
|
||||||
|
parser.add_argument('--beta', type=float, default=0.01)
|
||||||
|
parser.add_argument('--dropout', type=float, default=0.2)
|
||||||
|
parser.add_argument('--process_data', type=int, default=0)
|
||||||
|
parser.add_argument('--max_len', type=int, default=30)
|
||||||
|
# arguments for gnn
|
||||||
|
parser.add_argument('--gnn_type', type=str, default='RGATv4')
|
||||||
|
parser.add_argument('--state_dim', type=int, default=128) # gnn out_feats
|
||||||
|
parser.add_argument('--feats_dims', type=list, default=[9, 2, 3, 18])
|
||||||
|
parser.add_argument('--aggregation', type=str, default='sum')
|
||||||
|
# arguments for mpl
|
||||||
|
#parser.add_argument('--mpl_hid_feats', type=list, default=[256, 64, 16])
|
||||||
|
return parser
|
||||||
|
|
||||||
|
def __init__(self, hparams):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.hparams = hparams
|
||||||
|
self.lr = self.hparams.lr
|
||||||
|
self.state_dim = self.hparams.state_dim
|
||||||
|
self.action_dim = self.hparams.action_dim
|
||||||
|
self.context_dim = self.hparams.context_dim
|
||||||
|
self.beta = self.hparams.beta
|
||||||
|
self.dropout = self.hparams.dropout
|
||||||
|
self.max_len = self.hparams.max_len
|
||||||
|
self.feats_dims = self.hparams.feats_dims # type, position, color, shape
|
||||||
|
self.rel_names = [
|
||||||
|
'is_aligned', 'is_back', 'is_close', 'is_down_adj', 'is_down_left_adj',
|
||||||
|
'is_down_right_adj', 'is_front', 'is_left', 'is_left_adj', 'is_right',
|
||||||
|
'is_right_adj', 'is_top_adj', 'is_top_left_adj', 'is_top_right_adj'
|
||||||
|
]
|
||||||
|
self.gnn_aggregation = self.hparams.aggregation
|
||||||
|
|
||||||
|
if self.hparams.gnn_type == 'RGATv2':
|
||||||
|
self.gnn_encoder = RGATv2(
|
||||||
|
hidden_channels=self.state_dim,
|
||||||
|
out_channels=self.state_dim,
|
||||||
|
num_heads=4,
|
||||||
|
rel_names=self.rel_names,
|
||||||
|
dropout=0.0,
|
||||||
|
n_layers=2,
|
||||||
|
activation=nn.ELU(),
|
||||||
|
residual=False
|
||||||
|
)
|
||||||
|
if self.hparams.gnn_type == 'RGATv3':
|
||||||
|
self.gnn_encoder = RGATv3(
|
||||||
|
hidden_channels=self.state_dim,
|
||||||
|
out_channels=self.state_dim,
|
||||||
|
num_heads=4,
|
||||||
|
rel_names=self.rel_names,
|
||||||
|
dropout=0.0,
|
||||||
|
n_layers=2,
|
||||||
|
activation=nn.ELU(),
|
||||||
|
residual=False
|
||||||
|
)
|
||||||
|
if self.hparams.gnn_type == 'RGATv4':
|
||||||
|
self.gnn_encoder = RGATv4(
|
||||||
|
hidden_channels=self.state_dim,
|
||||||
|
out_channels=self.state_dim,
|
||||||
|
num_heads=4,
|
||||||
|
rel_names=self.rel_names,
|
||||||
|
dropout=0.0,
|
||||||
|
n_layers=2,
|
||||||
|
activation=nn.ELU(),
|
||||||
|
residual=False
|
||||||
|
)
|
||||||
|
if self.hparams.gnn_type == 'RGATv3Agent':
|
||||||
|
self.gnn_encoder = RGATv3Agent(
|
||||||
|
hidden_channels=self.state_dim,
|
||||||
|
out_channels=self.state_dim,
|
||||||
|
num_heads=4,
|
||||||
|
rel_names=self.rel_names,
|
||||||
|
dropout=0.0,
|
||||||
|
n_layers=2,
|
||||||
|
activation=nn.ELU(),
|
||||||
|
residual=False
|
||||||
|
)
|
||||||
|
if self.hparams.gnn_type == 'RSAGEv4':
|
||||||
|
self.gnn_encoder = RSAGEv4(
|
||||||
|
hidden_channels=self.state_dim,
|
||||||
|
out_channels=self.state_dim,
|
||||||
|
rel_names=self.rel_names,
|
||||||
|
dropout=0.0,
|
||||||
|
n_layers=2,
|
||||||
|
activation=nn.ELU()
|
||||||
|
)
|
||||||
|
if self.gnn_aggregation == 'cat_axis_1':
|
||||||
|
self.lstm_input_size = self.state_dim * len(self.feats_dims) + self.action_dim
|
||||||
|
self.mlp_input_size = self.state_dim * len(self.feats_dims) + self.context_dim * 2
|
||||||
|
elif self.gnn_aggregation == 'sum':
|
||||||
|
self.lstm_input_size = self.state_dim + self.action_dim
|
||||||
|
self.mlp_input_size = self.state_dim + self.context_dim * 2
|
||||||
|
else:
|
||||||
|
raise ValueError('Fix this')
|
||||||
|
|
||||||
|
self.context_enc = nn.LSTM(self.lstm_input_size, self.context_dim, 2,
|
||||||
|
batch_first=True, bidirectional=True)
|
||||||
|
|
||||||
|
self.policy = MlpModel(input_size=self.mlp_input_size, hidden_sizes=[256, 128, 256],
|
||||||
|
output_size=self.action_dim, dropout=self.dropout)
|
||||||
|
|
||||||
|
self.past_samples = []
|
||||||
|
|
||||||
|
def forward(self, batch):
|
||||||
|
dem_frames, dem_actions, dem_lens, query_frame, target_action = batch
|
||||||
|
dem_actions = dem_actions.float()
|
||||||
|
target_action = target_action.float()
|
||||||
|
dem_states = self.gnn_encoder(dem_frames) # torch.Size([number of frames, 128 * number of features if cat_axis_1])
|
||||||
|
# segment according the number of frames in each episode and pad with zeros
|
||||||
|
# to obtain tensors of shape [batch size, num of trials (8), max num of frames (30), hidden dim]
|
||||||
|
b, l, s, _ = dem_actions.size()
|
||||||
|
dem_states_new = []
|
||||||
|
for batch in range(b):
|
||||||
|
dem_states_new.append(self._sequence_to_padding(dem_states, dem_lens[batch], self.max_len))
|
||||||
|
dem_states_new = torch.stack(dem_states_new).to(self.device) # torch.Size([batchsize, 8, 30, 128 * number of features if cat_axis_1])
|
||||||
|
# concatenate states and actions to get expert trajectory
|
||||||
|
dem_states_new = dem_states_new.view(b * l, s, -1) # torch.Size([batchsize*8, 30, 128 * number of features if cat_axis_1])
|
||||||
|
dem_actions = dem_actions.view(b * l, s, -1) # torch.Size([batchsize*8, 30, 128])
|
||||||
|
dem_traj = torch.cat([dem_states_new, dem_actions], dim=2) # torch.Size([batchsize*8, 30, 2 + 128 * number of features if cat_axis_1])
|
||||||
|
# embed expert trajectory to get a context embedding batch x samples x dim
|
||||||
|
dem_lens = torch.tensor([t for dl in dem_lens for t in dl]).to(torch.int64).cpu()
|
||||||
|
dem_lens = dem_lens.view(-1)
|
||||||
|
x_lstm = nn.utils.rnn.pack_padded_sequence(dem_traj, dem_lens, batch_first=True, enforce_sorted=False)
|
||||||
|
x_lstm, _ = self.context_enc(x_lstm)
|
||||||
|
x_lstm, _ = nn.utils.rnn.pad_packed_sequence(x_lstm, batch_first=True)
|
||||||
|
x_out = x_lstm[:, -1, :]
|
||||||
|
x_out = x_out.view(b, l, -1)
|
||||||
|
context = torch.mean(x_out, dim=1) # torch.Size([batchsize, 64]) (64=32*2)
|
||||||
|
# concat context embedding to the state embedding of test trajectory
|
||||||
|
test_states = self.gnn_encoder(query_frame) # torch.Size([2, 128])
|
||||||
|
test_context_states = torch.cat([context, test_states], dim=1) # torch.Size([batchsize, 192]) 192=64+128
|
||||||
|
# for each state in the test states calculate action
|
||||||
|
test_actions_pred = torch.tanh(self.policy(test_context_states)) # torch.Size([batchsize, 2])
|
||||||
|
return target_action, test_actions_pred
|
||||||
|
|
||||||
|
def _sequence_to_padding(self, x, lengths, max_length):
|
||||||
|
# declare the shape, it can work for x of any shape.
|
||||||
|
ret_tensor = torch.zeros((len(lengths), max_length) + tuple(x.shape[1:]))
|
||||||
|
cum_len = 0
|
||||||
|
for i, l in enumerate(lengths):
|
||||||
|
ret_tensor[i, :l] = x[cum_len: cum_len+l]
|
||||||
|
cum_len += l
|
||||||
|
return ret_tensor
|
||||||
|
|
||||||
|
def training_step(self, batch, batch_idx):
|
||||||
|
test_actions, test_actions_pred = self.forward(batch)
|
||||||
|
loss = F.mse_loss(test_actions, test_actions_pred)
|
||||||
|
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def validation_step(self, batch, batch_idx, dataloader_idx=0):
|
||||||
|
test_actions, test_actions_pred = self.forward(batch)
|
||||||
|
loss = F.mse_loss(test_actions, test_actions_pred)
|
||||||
|
self.log('val_loss', loss, on_epoch=True, logger=True)
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
optim = torch.optim.Adam(self.parameters(), lr=self.lr)
|
||||||
|
return optim
|
||||||
|
#optim = torch.optim.AdamW(self.parameters(), lr=self.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01)
|
||||||
|
#scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optim, gamma=0.8)
|
||||||
|
#return [optim], [scheduler]
|
||||||
|
|
||||||
|
def train_dataloader(self):
|
||||||
|
train_dataset = ToMnetDGLDataset(path=self.hparams.data_path,
|
||||||
|
types=self.hparams.types,
|
||||||
|
mode='train')
|
||||||
|
train_loader = DataLoader(dataset=train_dataset,
|
||||||
|
batch_size=self.hparams.batch_size,
|
||||||
|
collate_fn=collate_function_seq,
|
||||||
|
num_workers=self.hparams.num_workers,
|
||||||
|
#pin_memory=True,
|
||||||
|
shuffle=True)
|
||||||
|
return train_loader
|
||||||
|
|
||||||
|
def val_dataloader(self):
|
||||||
|
val_datasets = []
|
||||||
|
val_loaders = []
|
||||||
|
for t in self.hparams.types:
|
||||||
|
val_datasets.append(ToMnetDGLDataset(path=self.hparams.data_path,
|
||||||
|
types=[t],
|
||||||
|
mode='val'))
|
||||||
|
val_loaders.append(DataLoader(dataset=val_datasets[-1],
|
||||||
|
batch_size=self.hparams.batch_size,
|
||||||
|
collate_fn=collate_function_seq,
|
||||||
|
num_workers=self.hparams.num_workers,
|
||||||
|
#pin_memory=True,
|
||||||
|
shuffle=False))
|
||||||
|
return val_loaders
|
||||||
|
|
||||||
|
def configure_callbacks(self):
|
||||||
|
checkpoint = ModelCheckpoint(
|
||||||
|
dirpath=None, # automatically set
|
||||||
|
#filename=self.params['bc_model']+'-'+self.params['gnn_type']+'-'+self.gnn_params['feats_aggr']+'-{epoch:02d}',
|
||||||
|
save_top_k=-1,
|
||||||
|
period=1
|
||||||
|
)
|
||||||
|
return [checkpoint]
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class GraphBC_T(pl.LightningModule):
|
||||||
|
"""
|
||||||
|
BC model with GraphTrans encoder, LSTM and MLP.
|
||||||
|
"""
|
||||||
|
@staticmethod
|
||||||
|
def add_model_specific_args(parent_parser):
|
||||||
|
parser = ArgumentParser(parents=[parent_parser], add_help=False)
|
||||||
|
parser.add_argument('--lr', type=float, default=1e-4)
|
||||||
|
parser.add_argument('--action_dim', type=int, default=2)
|
||||||
|
parser.add_argument('--context_dim', type=int, default=32) # lstm hidden size
|
||||||
|
parser.add_argument('--beta', type=float, default=0.01)
|
||||||
|
parser.add_argument('--dropout', type=float, default=0.2)
|
||||||
|
parser.add_argument('--process_data', type=int, default=0)
|
||||||
|
parser.add_argument('--max_len', type=int, default=30)
|
||||||
|
# arguments for gnn
|
||||||
|
parser.add_argument('--state_dim', type=int, default=128) # gnn out_feats
|
||||||
|
parser.add_argument('--feats_dims', type=list, default=[9, 2, 3, 18])
|
||||||
|
parser.add_argument('--aggregation', type=str, default='cat_axis_1')
|
||||||
|
parser.add_argument('--gnn_type', type=str, default='RGATv3')
|
||||||
|
# arguments for mpl
|
||||||
|
#parser.add_argument('--mpl_hid_feats', type=list, default=[256, 64, 16])
|
||||||
|
# arguments for transformer
|
||||||
|
parser.add_argument('--d_model', type=int, default=128)
|
||||||
|
parser.add_argument('--nhead', type=int, default=4)
|
||||||
|
parser.add_argument('--dim_feedforward', type=int, default=512)
|
||||||
|
parser.add_argument('--transformer_dropout', type=float, default=0.3)
|
||||||
|
parser.add_argument('--transformer_activation', type=str, default='gelu')
|
||||||
|
parser.add_argument('--num_encoder_layers', type=int, default=6)
|
||||||
|
parser.add_argument('--transformer_norm_input', type=int, default=0)
|
||||||
|
return parser
|
||||||
|
|
||||||
|
def __init__(self, hparams):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.hparams = hparams
|
||||||
|
self.lr = self.hparams.lr
|
||||||
|
self.state_dim = self.hparams.state_dim
|
||||||
|
self.action_dim = self.hparams.action_dim
|
||||||
|
self.context_dim = self.hparams.context_dim
|
||||||
|
self.beta = self.hparams.beta
|
||||||
|
self.dropout = self.hparams.dropout
|
||||||
|
self.max_len = self.hparams.max_len
|
||||||
|
self.feats_dims = self.hparams.feats_dims # type, position, color, shape
|
||||||
|
|
||||||
|
self.rel_names = [
|
||||||
|
'is_aligned', 'is_back', 'is_close', 'is_down_adj', 'is_down_left_adj',
|
||||||
|
'is_down_right_adj', 'is_front', 'is_left', 'is_left_adj', 'is_right',
|
||||||
|
'is_right_adj', 'is_top_adj', 'is_top_left_adj', 'is_top_right_adj'
|
||||||
|
]
|
||||||
|
self.gnn_aggregation = self.hparams.aggregation
|
||||||
|
if self.hparams.gnn_type == 'RGATv3':
|
||||||
|
self.gnn_encoder = RGATv3(
|
||||||
|
hidden_channels=self.state_dim,
|
||||||
|
out_channels=self.state_dim,
|
||||||
|
num_heads=4,
|
||||||
|
rel_names=self.rel_names,
|
||||||
|
dropout=0.0,
|
||||||
|
n_layers=2,
|
||||||
|
activation=nn.ELU(),
|
||||||
|
residual=False
|
||||||
|
)
|
||||||
|
if self.hparams.gnn_type == 'RGATv4':
|
||||||
|
self.gnn_encoder = RGATv4(
|
||||||
|
hidden_channels=self.state_dim,
|
||||||
|
out_channels=self.state_dim,
|
||||||
|
num_heads=4,
|
||||||
|
rel_names=self.rel_names,
|
||||||
|
dropout=0.0,
|
||||||
|
n_layers=2,
|
||||||
|
activation=nn.ELU(),
|
||||||
|
residual=False
|
||||||
|
)
|
||||||
|
if self.hparams.gnn_type == 'RSAGEv4':
|
||||||
|
self.gnn_encoder = RSAGEv4(
|
||||||
|
hidden_channels=self.state_dim,
|
||||||
|
out_channels=self.state_dim,
|
||||||
|
rel_names=self.rel_names,
|
||||||
|
dropout=0.0,
|
||||||
|
n_layers=2,
|
||||||
|
activation=nn.ELU()
|
||||||
|
)
|
||||||
|
if self.hparams.gnn_type == 'RAGNNv4':
|
||||||
|
self.gnn_encoder = RAGNNv4(
|
||||||
|
hidden_channels=self.state_dim,
|
||||||
|
out_channels=self.state_dim,
|
||||||
|
rel_names=self.rel_names,
|
||||||
|
dropout=0.0,
|
||||||
|
n_layers=2,
|
||||||
|
activation=nn.ELU()
|
||||||
|
)
|
||||||
|
if self.hparams.gnn_type == 'RGATv4Norm':
|
||||||
|
self.gnn_encoder = RGATv4Norm(
|
||||||
|
hidden_channels=self.state_dim,
|
||||||
|
out_channels=self.state_dim,
|
||||||
|
num_heads=4,
|
||||||
|
rel_names=self.rel_names,
|
||||||
|
dropout=0.0,
|
||||||
|
n_layers=2,
|
||||||
|
activation=nn.ELU(),
|
||||||
|
residual=True
|
||||||
|
)
|
||||||
|
self.d_model = self.hparams.d_model
|
||||||
|
self.nhead = self.hparams.nhead
|
||||||
|
self.dim_feedforward = self.hparams.dim_feedforward
|
||||||
|
self.transformer_dropout = self.hparams.transformer_dropout
|
||||||
|
self.transformer_activation = self.hparams.transformer_activation
|
||||||
|
self.num_encoder_layers = self.hparams.num_encoder_layers
|
||||||
|
self.transformer_norm_input = self.hparams.transformer_norm_input
|
||||||
|
self.context_enc = TransformerEncoder(
|
||||||
|
self.d_model,
|
||||||
|
self.nhead,
|
||||||
|
self.dim_feedforward,
|
||||||
|
self.transformer_dropout,
|
||||||
|
self.transformer_activation,
|
||||||
|
self.num_encoder_layers,
|
||||||
|
self.max_len,
|
||||||
|
self.transformer_norm_input
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.gnn_aggregation == 'cat_axis_1':
|
||||||
|
self.gnn2transformer = nn.Linear(self.state_dim * len(self.feats_dims) + self.action_dim, self.d_model)
|
||||||
|
self.mlp_input_size = len(self.feats_dims) * self.state_dim + self.d_model
|
||||||
|
elif self.gnn_aggregation == 'sum':
|
||||||
|
self.gnn2transformer = nn.Linear(self.state_dim + self.action_dim, self.d_model)
|
||||||
|
self.mlp_input_size = self.state_dim + self.d_model
|
||||||
|
else:
|
||||||
|
raise ValueError('Only sum and cat1 aggregations are available.')
|
||||||
|
|
||||||
|
self.policy = MlpModel(input_size=self.mlp_input_size, hidden_sizes=[256, 128, 256],
|
||||||
|
output_size=self.action_dim, dropout=self.dropout)
|
||||||
|
|
||||||
|
# CLS Embedding parameters, requires_grad=True
|
||||||
|
self.embedding = nn.Embedding(self.max_len + 1, self.d_model) # + 1 cause of cls token
|
||||||
|
self.emb_layer_norm = nn.LayerNorm(self.d_model)
|
||||||
|
self.emb_dropout = nn.Dropout(p=self.transformer_dropout)
|
||||||
|
|
||||||
|
def forward(self, batch):
|
||||||
|
dem_frames, dem_actions, dem_lens, query_frame, target_action = batch
|
||||||
|
dem_actions = dem_actions.float()
|
||||||
|
target_action = target_action.float()
|
||||||
|
dem_states = self.gnn_encoder(dem_frames)
|
||||||
|
b, l, s, _ = dem_actions.size()
|
||||||
|
dem_lens = torch.tensor([t for dl in dem_lens for t in dl]).to(torch.int64).cpu()
|
||||||
|
dem_lens = dem_lens.view(-1)
|
||||||
|
dem_actions_packed = torch.nn.utils.rnn.pack_padded_sequence(dem_actions.view(b*l, s, -1), dem_lens, batch_first=True, enforce_sorted=False)[0]
|
||||||
|
dem_traj = torch.cat([dem_states, dem_actions_packed], dim=1)
|
||||||
|
h_node = self.gnn2transformer(dem_traj)
|
||||||
|
hidden_dim = h_node.size()[1]
|
||||||
|
padded_trajectory = torch.zeros(b*l, self.max_len, hidden_dim).to(self.device)
|
||||||
|
j = 0
|
||||||
|
for idx, i in enumerate(dem_lens):
|
||||||
|
padded_trajectory[idx][:i] = h_node[j:j+i]
|
||||||
|
j += i
|
||||||
|
mask = self.make_mask(padded_trajectory).to(self.device)
|
||||||
|
transformer_input = padded_trajectory.transpose(0, 1) # [30, 16, 128]
|
||||||
|
# add cls:
|
||||||
|
cls_embedding = nn.Parameter(torch.randn([1, 1, self.d_model], requires_grad=True)).expand(1, b*l, -1).to(self.device)
|
||||||
|
transformer_input = torch.cat([transformer_input, cls_embedding], dim=0) # [31, 16, 128]
|
||||||
|
zeros = mask.data.new(mask.size(0), 1).fill_(0)
|
||||||
|
mask = torch.cat([mask, zeros], dim=1)
|
||||||
|
# Embed
|
||||||
|
indices = torch.arange(self.max_len + 1, dtype=torch.int).to(self.device) # + 1 cause of cls [0, 1, ..., 30]
|
||||||
|
positional_embeddings = self.embedding(indices).unsqueeze(1) # torch.Size([31, 1, 128])
|
||||||
|
#generate transformer input
|
||||||
|
pe_input = positional_embeddings + transformer_input # torch.Size([31, 16, 128])
|
||||||
|
# Layernorm and dropout
|
||||||
|
transformer_in = self.emb_dropout(self.emb_layer_norm(pe_input)) # torch.Size([31, 16, 128])
|
||||||
|
# transformer encoding and output parsing
|
||||||
|
out, _ = self.context_enc(transformer_in, mask) # [31, 16, 128]
|
||||||
|
cls = out[-1]
|
||||||
|
cls = cls.view(b, l, -1) # 2, 8, 128
|
||||||
|
context = torch.mean(cls, dim=1)
|
||||||
|
# CLASSIFICATION
|
||||||
|
test_states = self.gnn_encoder(query_frame) # torch.Size([2, 512])
|
||||||
|
test_context_states = torch.cat([context, test_states], dim=1) # torch.Size([batchsize, hidden_dim + lstm_hidden_dim]) 192=512+128
|
||||||
|
# for each state in the test states calculate action
|
||||||
|
x = self.policy(test_context_states)
|
||||||
|
test_actions_pred = torch.tanh(x) # torch.Size([batchsize, 2])
|
||||||
|
#test_actions_pred = torch.tanh(self.policy(test_context_states)) # torch.Size([batchsize, 2])
|
||||||
|
return target_action, test_actions_pred
|
||||||
|
|
||||||
|
def make_mask(self, feature):
|
||||||
|
return (torch.sum(
|
||||||
|
torch.abs(feature),
|
||||||
|
dim=-1
|
||||||
|
) == 0)#.unsqueeze(1).unsqueeze(2)
|
||||||
|
|
||||||
|
def training_step(self, batch, batch_idx):
|
||||||
|
test_actions, test_actions_pred = self.forward(batch)
|
||||||
|
loss = F.mse_loss(test_actions, test_actions_pred)
|
||||||
|
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def validation_step(self, batch, batch_idx, dataloader_idx=0):
|
||||||
|
test_actions, test_actions_pred = self.forward(batch)
|
||||||
|
loss = F.mse_loss(test_actions, test_actions_pred)
|
||||||
|
self.log('val_loss', loss, on_epoch=True, logger=True)
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
optim = torch.optim.Adam(self.parameters(), lr=self.lr)
|
||||||
|
return optim
|
||||||
|
#optim = torch.optim.AdamW(self.parameters(), lr=self.lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01)
|
||||||
|
#scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optim, gamma=0.96)
|
||||||
|
#return [optim], [scheduler]
|
||||||
|
|
||||||
|
def train_dataloader(self):
|
||||||
|
train_dataset = ToMnetDGLDataset(path=self.hparams.data_path,
|
||||||
|
types=self.hparams.types,
|
||||||
|
mode='train')
|
||||||
|
train_loader = DataLoader(dataset=train_dataset,
|
||||||
|
batch_size=self.hparams.batch_size,
|
||||||
|
collate_fn=collate_function_seq,
|
||||||
|
num_workers=self.hparams.num_workers,
|
||||||
|
#pin_memory=True,
|
||||||
|
shuffle=True)
|
||||||
|
return train_loader
|
||||||
|
|
||||||
|
def val_dataloader(self):
|
||||||
|
val_datasets = []
|
||||||
|
val_loaders = []
|
||||||
|
for t in self.hparams.types:
|
||||||
|
val_datasets.append(ToMnetDGLDataset(path=self.hparams.data_path,
|
||||||
|
types=[t],
|
||||||
|
mode='val'))
|
||||||
|
val_loaders.append(DataLoader(dataset=val_datasets[-1],
|
||||||
|
batch_size=self.hparams.batch_size,
|
||||||
|
collate_fn=collate_function_seq,
|
||||||
|
num_workers=self.hparams.num_workers,
|
||||||
|
#pin_memory=True,
|
||||||
|
shuffle=False))
|
||||||
|
return val_loaders
|
||||||
|
|
||||||
|
def configure_callbacks(self):
|
||||||
|
checkpoint = ModelCheckpoint(
|
||||||
|
dirpath=None, # automatically set
|
||||||
|
#filename=self.params['bc_model']+'-'+self.params['gnn_type']+'-'+self.gnn_params['feats_aggr']+'-{epoch:02d}',
|
||||||
|
save_top_k=-1,
|
||||||
|
period=1
|
||||||
|
)
|
||||||
|
return [checkpoint]
|
46
tom/norm.py
Normal file
46
tom/norm.py
Normal file
|
@ -0,0 +1,46 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class Norm(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, norm_type, hidden_dim=64, print_info=None):
|
||||||
|
super(Norm, self).__init__()
|
||||||
|
|
||||||
|
# assert norm_type in ['bn', 'ln', 'gn', None]
|
||||||
|
self.norm = None
|
||||||
|
self.print_info = print_info
|
||||||
|
if norm_type == 'bn':
|
||||||
|
self.norm = nn.BatchNorm1d(hidden_dim)
|
||||||
|
elif norm_type == 'gn':
|
||||||
|
self.norm = norm_type
|
||||||
|
self.weight = nn.Parameter(torch.ones(hidden_dim))
|
||||||
|
self.bias = nn.Parameter(torch.zeros(hidden_dim))
|
||||||
|
|
||||||
|
self.mean_scale = nn.Parameter(torch.ones(hidden_dim))
|
||||||
|
|
||||||
|
def forward(self, graph, tensor, print_=False):
|
||||||
|
|
||||||
|
if self.norm is not None and type(self.norm) != str:
|
||||||
|
return self.norm(tensor)
|
||||||
|
elif self.norm is None:
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
batch_list = graph.batch_num_nodes('obj')
|
||||||
|
batch_size = len(batch_list)
|
||||||
|
#batch_list = torch.tensor(batch_list).long().to(tensor.device)
|
||||||
|
batch_list = batch_list.long().to(tensor.device)
|
||||||
|
batch_index = torch.arange(batch_size).to(tensor.device).repeat_interleave(batch_list)
|
||||||
|
batch_index = batch_index.view((-1,) + (1,) * (tensor.dim() - 1)).expand_as(tensor)
|
||||||
|
mean = torch.zeros(batch_size, *tensor.shape[1:]).to(tensor.device)
|
||||||
|
mean = mean.scatter_add_(0, batch_index, tensor)
|
||||||
|
mean = (mean.T / batch_list).T
|
||||||
|
mean = mean.repeat_interleave(batch_list, dim=0)
|
||||||
|
|
||||||
|
sub = tensor - mean * self.mean_scale
|
||||||
|
|
||||||
|
std = torch.zeros(batch_size, *tensor.shape[1:]).to(tensor.device)
|
||||||
|
std = std.scatter_add_(0, batch_index, sub.pow(2))
|
||||||
|
std = ((std.T / batch_list).T + 1e-6).sqrt()
|
||||||
|
std = std.repeat_interleave(batch_list, dim=0)
|
||||||
|
return self.weight * sub / std + self.bias
|
89
tom/transformer.py
Normal file
89
tom/transformer.py
Normal file
|
@ -0,0 +1,89 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
class PositionalEncoding(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dropout = nn.Dropout(p=dropout)
|
||||||
|
|
||||||
|
position = torch.arange(max_len).unsqueeze(1)
|
||||||
|
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
|
||||||
|
pe = torch.zeros(max_len, 1, d_model)
|
||||||
|
pe[:, 0, 0::2] = torch.sin(position * div_term)
|
||||||
|
pe[:, 0, 1::2] = torch.cos(position * div_term)
|
||||||
|
self.register_buffer('pe', pe)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: Tensor, shape [seq_len, batch_size, embedding_dim]
|
||||||
|
"""
|
||||||
|
x = x + self.pe[:x.size(0)]
|
||||||
|
return self.dropout(x)
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerEncoder(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
d_model,
|
||||||
|
nhead,
|
||||||
|
dim_feedforward,
|
||||||
|
transformer_dropout,
|
||||||
|
transformer_activation,
|
||||||
|
num_encoder_layers,
|
||||||
|
max_input_len,
|
||||||
|
transformer_norm_input
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.d_model = d_model
|
||||||
|
self.num_layer = num_encoder_layers
|
||||||
|
self.max_input_len = max_input_len
|
||||||
|
|
||||||
|
# Creating Transformer Encoder Model
|
||||||
|
encoder_layer = nn.TransformerEncoderLayer(
|
||||||
|
d_model=d_model,
|
||||||
|
nhead=nhead,
|
||||||
|
dim_feedforward=dim_feedforward,
|
||||||
|
dropout=transformer_dropout,
|
||||||
|
activation=transformer_activation
|
||||||
|
)
|
||||||
|
encoder_norm = nn.LayerNorm(d_model)
|
||||||
|
self.transformer = nn.TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
|
||||||
|
|
||||||
|
|
||||||
|
self.norm_input = None
|
||||||
|
if transformer_norm_input:
|
||||||
|
self.norm_input = nn.LayerNorm(d_model)
|
||||||
|
|
||||||
|
def forward(self, padded_h_node, src_padding_mask):
|
||||||
|
"""
|
||||||
|
padded_h_node: n_b x B x h_d # 63, 257, 128
|
||||||
|
src_key_padding_mask: B x n_b # 257, 63
|
||||||
|
"""
|
||||||
|
# (S, B, h_d), (B, S)
|
||||||
|
if self.norm_input is not None:
|
||||||
|
padded_h_node = self.norm_input(padded_h_node)
|
||||||
|
|
||||||
|
transformer_out = self.transformer(padded_h_node, src_key_padding_mask=src_padding_mask) # (S, B, h_d)
|
||||||
|
|
||||||
|
return transformer_out, src_padding_mask
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
model = TransformerEncoder(
|
||||||
|
d_model=12,
|
||||||
|
nhead=4,
|
||||||
|
dim_feedforward=32,
|
||||||
|
transformer_dropout=0.0,
|
||||||
|
transformer_activation='gelu',
|
||||||
|
num_encoder_layers=4,
|
||||||
|
max_input_len=34,
|
||||||
|
transformer_norm_input=0
|
||||||
|
)
|
||||||
|
print(model.norm_input)
|
75
train_tom.py
Normal file
75
train_tom.py
Normal file
|
@ -0,0 +1,75 @@
|
||||||
|
import random
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from pytorch_lightning import Trainer
|
||||||
|
from pytorch_lightning.loggers import WandbLogger
|
||||||
|
from tom.model import GraphBC_T, GraphBCRNN
|
||||||
|
|
||||||
|
torch.multiprocessing.set_sharing_strategy('file_system')
|
||||||
|
|
||||||
|
parser = ArgumentParser()
|
||||||
|
|
||||||
|
# program level args
|
||||||
|
parser.add_argument('--seed', type=int, default=4)
|
||||||
|
# data specific args
|
||||||
|
parser.add_argument('--data_path', type=str, default='/datasets/external/bib_train/graphs/all_tasks/')
|
||||||
|
parser.add_argument('--types', nargs='+', type=str,
|
||||||
|
default=['preference', 'multi_agent', 'single_object', 'instrumental_action'],
|
||||||
|
help='types of tasks used for training / validation')
|
||||||
|
parser.add_argument('--train', type=int, default=1)
|
||||||
|
parser.add_argument('--num_workers', type=int, default=4)
|
||||||
|
parser.add_argument('--batch_size', type=int, default=16)
|
||||||
|
parser.add_argument('--model_type', type=str, default='graphbcrnn')
|
||||||
|
|
||||||
|
# model specific args
|
||||||
|
parser_model = ArgumentParser()
|
||||||
|
parser_model = GraphBC_T.add_model_specific_args(parser_model)
|
||||||
|
# parser_model = GraphBCRNN.add_model_specific_args(parser_model)
|
||||||
|
# NOTE: here unfortunately you have to select manually the model
|
||||||
|
|
||||||
|
# add all the available trainer options to argparse
|
||||||
|
parser = Trainer.add_argparse_args(parser)
|
||||||
|
|
||||||
|
# combine parsers
|
||||||
|
parser_all = ArgumentParser(conflict_handler='resolve',
|
||||||
|
parents=[parser, parser_model])
|
||||||
|
|
||||||
|
# parse args
|
||||||
|
args = parser_all.parse_args()
|
||||||
|
args.types = sorted(args.types)
|
||||||
|
print(args)
|
||||||
|
|
||||||
|
random.seed(args.seed)
|
||||||
|
np.random.seed(args.seed)
|
||||||
|
torch.manual_seed(args.seed)
|
||||||
|
|
||||||
|
# init model
|
||||||
|
if args.model_type == 'graphbct':
|
||||||
|
model = GraphBC_T(args)
|
||||||
|
elif args.model_type == 'graphbcrnn':
|
||||||
|
model = GraphBCRNN(args)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
torch.autograd.set_detect_anomaly(True)
|
||||||
|
|
||||||
|
logger = WandbLogger(project='bib')
|
||||||
|
trainer = Trainer(
|
||||||
|
gradient_clip_val=args.gradient_clip_val,
|
||||||
|
gpus=args.gpus,
|
||||||
|
auto_select_gpus=args.auto_select_gpus,
|
||||||
|
track_grad_norm=args.track_grad_norm,
|
||||||
|
check_val_every_n_epoch=args.check_val_every_n_epoch,
|
||||||
|
max_epochs=args.max_epochs,
|
||||||
|
accelerator=args.accelerator,
|
||||||
|
resume_from_checkpoint=args.resume_from_checkpoint,
|
||||||
|
stochastic_weight_avg=args.stochastic_weight_avg,
|
||||||
|
num_sanity_val_steps=args.num_sanity_val_steps,
|
||||||
|
logger=logger
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.train:
|
||||||
|
trainer.fit(model)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
0
utils/__init__.py
Normal file
0
utils/__init__.py
Normal file
115
utils/build_graphs.py
Normal file
115
utils/build_graphs.py
Normal file
|
@ -0,0 +1,115 @@
|
||||||
|
import sys
|
||||||
|
sys.path.append('/projects/bortoletto/icml2023_matteo/utils')
|
||||||
|
from dataset import TransitionDataset, TestTransitionDatasetSequence
|
||||||
|
import multiprocessing as mp
|
||||||
|
import argparse
|
||||||
|
import pickle as pkl
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Instantiate the parser
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--cpus', type=int,
|
||||||
|
help='Number of processes')
|
||||||
|
parser.add_argument('--mode', type=str,
|
||||||
|
help='Train (train) or validation (val)')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
NUM_PROCESSES = args.cpus
|
||||||
|
MODE = args.mode
|
||||||
|
|
||||||
|
def generate_files(idx):
|
||||||
|
print('Generating idx', idx)
|
||||||
|
if os.path.exists(PATH+str(idx)+'.pkl'):
|
||||||
|
print('Index', idx, 'skipped.')
|
||||||
|
return
|
||||||
|
if MODE == 'train' or MODE == 'val':
|
||||||
|
states, actions, lens, n_nodes = dataset.__getitem__(idx)
|
||||||
|
with open(PATH+str(idx)+'.pkl', 'wb') as f:
|
||||||
|
pkl.dump([states, actions, lens, n_nodes], f)
|
||||||
|
elif MODE == 'test':
|
||||||
|
dem_expected_states, dem_expected_actions, dem_expected_lens, dem_expected_nodes, \
|
||||||
|
dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, dem_unexpected_nodes, \
|
||||||
|
query_expected_frames, target_expected_actions, \
|
||||||
|
query_unexpected_frames, target_unexpected_actions = dataset.__getitem__(idx)
|
||||||
|
with open(PATH+str(idx)+'.pkl', 'wb') as f:
|
||||||
|
pkl.dump([
|
||||||
|
dem_expected_states, dem_expected_actions, dem_expected_lens, dem_expected_nodes, \
|
||||||
|
dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, dem_unexpected_nodes, \
|
||||||
|
query_expected_frames, target_expected_actions, \
|
||||||
|
query_unexpected_frames, target_unexpected_actions], f
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError('MODE can be only train, val or test.')
|
||||||
|
print(PATH+str(idx)+'.pkl saved.')
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
if MODE == 'train':
|
||||||
|
print('TRAIN MODE')
|
||||||
|
PATH = '/datasets/external/bib_train/graphs/all_tasks/train_dgl_hetero_nobound_4feats/'
|
||||||
|
if not os.path.exists(PATH):
|
||||||
|
os.makedirs(PATH)
|
||||||
|
print(PATH, 'directory created.')
|
||||||
|
dataset = TransitionDataset(
|
||||||
|
path='/datasets/external/bib_train/',
|
||||||
|
types=['instrumental_action', 'multi_agent', 'preference', 'single_object'],
|
||||||
|
mode="train",
|
||||||
|
max_len=30,
|
||||||
|
num_test=1,
|
||||||
|
num_trials=9,
|
||||||
|
action_range=10,
|
||||||
|
process_data=0
|
||||||
|
)
|
||||||
|
pool = mp.Pool(processes=NUM_PROCESSES)
|
||||||
|
print('Starting graph generation with', NUM_PROCESSES, 'processes...')
|
||||||
|
pool.map(generate_files, [i for i in range(dataset.__len__())])
|
||||||
|
pool.close()
|
||||||
|
elif MODE == 'val':
|
||||||
|
print('VALIDATION MODE')
|
||||||
|
types = ['multi_agent', 'instrumental_action', 'preference', 'single_object']
|
||||||
|
for t in range(len(types)):
|
||||||
|
PATH = '/datasets/external/bib_train/graphs/all_tasks/val_dgl_hetero_nobound_4feats/'+types[t]+'/'
|
||||||
|
if not os.path.exists(PATH):
|
||||||
|
os.makedirs(PATH)
|
||||||
|
print(PATH, 'directory created.')
|
||||||
|
dataset = TransitionDataset(
|
||||||
|
path='/datasets/external/bib_train/',
|
||||||
|
types=[types[t]],
|
||||||
|
mode="val",
|
||||||
|
max_len=30,
|
||||||
|
num_test=1,
|
||||||
|
num_trials=9,
|
||||||
|
action_range=10,
|
||||||
|
process_data=0
|
||||||
|
)
|
||||||
|
pool = mp.Pool(processes=NUM_PROCESSES)
|
||||||
|
print('Starting', types[t], 'graph generation with', NUM_PROCESSES, 'processes...')
|
||||||
|
pool.map(generate_files, [i for i in range(dataset.__len__())])
|
||||||
|
pool.close()
|
||||||
|
elif MODE == 'test':
|
||||||
|
print('TEST MODE')
|
||||||
|
types = [
|
||||||
|
'preference', 'multi_agent', 'inaccessible_goal',
|
||||||
|
'efficiency_irrational', 'efficiency_time','efficiency_path',
|
||||||
|
'instrumental_no_barrier', 'instrumental_blocking_barrier', 'instrumental_inconsequential_barrier'
|
||||||
|
]
|
||||||
|
for t in range(len(types)):
|
||||||
|
PATH = '/datasets/external/bib_evaluation_1_1/graphs/all_tasks_dgl_hetero_nobound_4feats/'+types[t]+'/'
|
||||||
|
if not os.path.exists(PATH):
|
||||||
|
os.makedirs(PATH)
|
||||||
|
print(PATH, 'directory created.')
|
||||||
|
dataset = TestTransitionDatasetSequence(
|
||||||
|
path='/datasets/external/bib_evaluation_1_1/',
|
||||||
|
task_type=types[t],
|
||||||
|
mode="test",
|
||||||
|
num_test=1,
|
||||||
|
num_trials=9,
|
||||||
|
action_range=10,
|
||||||
|
process_data=0,
|
||||||
|
max_len=30
|
||||||
|
)
|
||||||
|
pool = mp.Pool(processes=NUM_PROCESSES)
|
||||||
|
print('Starting', types[t], 'graph generation with', NUM_PROCESSES, 'processes...')
|
||||||
|
pool.map(generate_files, [i for i in range(dataset.__len__())])
|
||||||
|
pool.close()
|
||||||
|
else:
|
||||||
|
raise ValueError
|
487
utils/dataset.py
Normal file
487
utils/dataset.py
Normal file
|
@ -0,0 +1,487 @@
|
||||||
|
import dgl
|
||||||
|
import torch
|
||||||
|
import torch.utils.data
|
||||||
|
import os
|
||||||
|
import pickle as pkl
|
||||||
|
import json
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import sys
|
||||||
|
sys.path.append('/projects/bortoletto/irene/')
|
||||||
|
from utils.grid_object import *
|
||||||
|
from utils.relations import *
|
||||||
|
|
||||||
|
# ========================== Helper functions ==========================
|
||||||
|
|
||||||
|
def index_data(json_list, path_list):
|
||||||
|
print(f'processing files {len(json_list)}')
|
||||||
|
data_tuples = []
|
||||||
|
for j, v in tqdm(zip(json_list, path_list)):
|
||||||
|
with open(j, 'r') as f:
|
||||||
|
state = json.load(f)
|
||||||
|
ep_lens = [len(x) for x in state]
|
||||||
|
past_len = 0
|
||||||
|
for e, l in enumerate(ep_lens):
|
||||||
|
data_tuples.append([])
|
||||||
|
# skip first 30 frames and last 83 frames
|
||||||
|
for f in range(30, l - 83):
|
||||||
|
# find action taken;
|
||||||
|
f0x, f0y = state[e][f]['agent'][0]
|
||||||
|
f1x, f1y = state[e][f + 1]['agent'][0]
|
||||||
|
dx = (f1x - f0x) / 2.
|
||||||
|
dy = (f1y - f0y) / 2.
|
||||||
|
action = [dx, dy]
|
||||||
|
#data_tuples[-1].append((v, past_len + f, action))
|
||||||
|
data_tuples[-1].append((j, past_len + f, action))
|
||||||
|
# data_tuples = [[json file, frame number, action] for each episode in each video]
|
||||||
|
assert len(data_tuples[-1]) > 0
|
||||||
|
past_len += l
|
||||||
|
return data_tuples
|
||||||
|
|
||||||
|
# ========================== Dataset class ==========================
|
||||||
|
|
||||||
|
class TransitionDataset(torch.utils.data.Dataset):
|
||||||
|
"""
|
||||||
|
Training dataset class for the behavior cloning mlp model.
|
||||||
|
Args:
|
||||||
|
path: path to the dataset
|
||||||
|
types: list of video types to include
|
||||||
|
mode: train, val
|
||||||
|
num_test: number of test state-action pairs
|
||||||
|
num_trials: number of trials in an episode
|
||||||
|
action_range: number of frames to skip; actions are combined over these number of frames (displcement) of the agent
|
||||||
|
process_data: whether to the videos or not (skip if already processed)
|
||||||
|
max_len: max number of context state-action pairs
|
||||||
|
__getitem__:
|
||||||
|
returns: (states, actions, lens, n_nodes)
|
||||||
|
dem_frames: batched DGLGraph.heterograph
|
||||||
|
dem_actions: (max_len, 2)
|
||||||
|
query_frames: DGLGraph.heterograph
|
||||||
|
target_actions: (num_test, 2)
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
path,
|
||||||
|
types=None,
|
||||||
|
mode="train",
|
||||||
|
num_test=1,
|
||||||
|
num_trials=9,
|
||||||
|
action_range=10,
|
||||||
|
process_data=0,
|
||||||
|
max_len=30
|
||||||
|
):
|
||||||
|
self.path = path
|
||||||
|
self.types = types
|
||||||
|
self.mode = mode
|
||||||
|
self.num_trials = num_trials
|
||||||
|
self.num_test = num_test
|
||||||
|
self.action_range = action_range
|
||||||
|
self.max_len = max_len
|
||||||
|
self.ep_combs = self.num_trials * (self.num_trials - 2) # 9p2 - 9
|
||||||
|
self.eps = [[x, y] for x in range(self.num_trials) for y in range(self.num_trials) if x != y]
|
||||||
|
types_str = '_'.join(self.types)
|
||||||
|
self.rel_deter_func = [
|
||||||
|
is_top_adj, is_left_adj, is_top_right_adj, is_top_left_adj,
|
||||||
|
is_down_adj, is_right_adj, is_down_left_adj, is_down_right_adj,
|
||||||
|
is_left, is_right, is_front, is_back, is_aligned, is_close
|
||||||
|
]
|
||||||
|
self.path_list = []
|
||||||
|
self.json_list = []
|
||||||
|
# get video paths and json file paths
|
||||||
|
for t in types:
|
||||||
|
print(f'reading files of type {t} in {mode}')
|
||||||
|
paths = [os.path.join(self.path, t, x) for x in os.listdir(os.path.join(self.path, t)) if
|
||||||
|
x.endswith(f'.mp4')]
|
||||||
|
jsons = [os.path.join(self.path, t, x) for x in os.listdir(os.path.join(self.path, t)) if
|
||||||
|
x.endswith(f'.json') and 'index' not in x]
|
||||||
|
paths = sorted(paths)
|
||||||
|
jsons = sorted(jsons)
|
||||||
|
if mode == 'train':
|
||||||
|
self.path_list += paths[:int(0.8 * len(jsons))]
|
||||||
|
self.json_list += jsons[:int(0.8 * len(jsons))]
|
||||||
|
elif mode == 'val':
|
||||||
|
self.path_list += paths[int(0.8 * len(jsons)):]
|
||||||
|
self.json_list += jsons[int(0.8 * len(jsons)):]
|
||||||
|
else:
|
||||||
|
self.path_list += paths
|
||||||
|
self.json_list += jsons
|
||||||
|
self.data_tuples = []
|
||||||
|
if process_data:
|
||||||
|
# index the videos in the dataset directory. This is done to speed up the retrieval of videos.
|
||||||
|
# frame index, action tuples are stored
|
||||||
|
self.data_tuples = index_data(self.json_list, self.path_list)
|
||||||
|
# tuples of frame index and action (displacement of agent)
|
||||||
|
index_dict = {'data_tuples': self.data_tuples}
|
||||||
|
with open(os.path.join(self.path, f'jindex_bib_{mode}_{types_str}.json'), 'w') as fp:
|
||||||
|
json.dump(index_dict, fp)
|
||||||
|
else:
|
||||||
|
# read pre-indexed data
|
||||||
|
with open(os.path.join(self.path, f'jindex_bib_{mode}_{types_str}.json'), 'r') as fp:
|
||||||
|
index_dict = json.load(fp)
|
||||||
|
self.data_tuples = index_dict['data_tuples']
|
||||||
|
self.tot_trials = len(self.path_list) * 9
|
||||||
|
|
||||||
|
def _get_frame_graph(self, jsonfile, frame_idx):
|
||||||
|
# load json
|
||||||
|
with open(jsonfile, 'rb') as f:
|
||||||
|
frame_data = json.load(f)
|
||||||
|
flat_list = [x for xs in frame_data for x in xs]
|
||||||
|
# extract entities
|
||||||
|
grid_objs = parse_objects(flat_list[frame_idx])
|
||||||
|
# --- build the graph
|
||||||
|
adj = self._get_spatial_rel(grid_objs)
|
||||||
|
# define edges
|
||||||
|
is_top_adj_src, is_top_adj_dst = np.nonzero(adj[0])
|
||||||
|
is_left_adj_src, is_left_adj_dst = np.nonzero(adj[1])
|
||||||
|
is_top_right_adj_src, is_top_right_adj_dst = np.nonzero(adj[2])
|
||||||
|
is_top_left_adj_src, is_top_left_adj_dst = np.nonzero(adj[3])
|
||||||
|
is_down_adj_src, is_down_adj_dst = np.nonzero(adj[4])
|
||||||
|
is_right_adj_src, is_right_adj_dst = np.nonzero(adj[5])
|
||||||
|
is_down_left_adj_src, is_down_left_adj_dst = np.nonzero(adj[6])
|
||||||
|
is_down_right_adj_src, is_down_right_adj_dst = np.nonzero(adj[7])
|
||||||
|
is_left_src, is_left_dst = np.nonzero(adj[8])
|
||||||
|
is_right_src, is_right_dst = np.nonzero(adj[9])
|
||||||
|
is_front_src, is_front_dst = np.nonzero(adj[10])
|
||||||
|
is_back_src, is_back_dst = np.nonzero(adj[11])
|
||||||
|
is_aligned_src, is_aligned_dst = np.nonzero(adj[12])
|
||||||
|
is_close_src, is_close_dst = np.nonzero(adj[13])
|
||||||
|
g = dgl.heterograph({
|
||||||
|
('obj', 'is_top_adj', 'obj'): (torch.tensor(is_top_adj_src), torch.tensor(is_top_adj_dst)),
|
||||||
|
('obj', 'is_left_adj', 'obj'): (torch.tensor(is_left_adj_src), torch.tensor(is_left_adj_dst)),
|
||||||
|
('obj', 'is_top_right_adj', 'obj'): (torch.tensor(is_top_right_adj_src), torch.tensor(is_top_right_adj_dst)),
|
||||||
|
('obj', 'is_top_left_adj', 'obj'): (torch.tensor(is_top_left_adj_src), torch.tensor(is_top_left_adj_dst)),
|
||||||
|
('obj', 'is_down_adj', 'obj'): (torch.tensor(is_down_adj_src), torch.tensor(is_down_adj_dst)),
|
||||||
|
('obj', 'is_right_adj', 'obj'): (torch.tensor(is_right_adj_src), torch.tensor(is_right_adj_dst)),
|
||||||
|
('obj', 'is_down_left_adj', 'obj'): (torch.tensor(is_down_left_adj_src), torch.tensor(is_down_left_adj_dst)),
|
||||||
|
('obj', 'is_down_right_adj', 'obj'): (torch.tensor(is_down_right_adj_src), torch.tensor(is_down_right_adj_dst)),
|
||||||
|
('obj', 'is_left', 'obj'): (torch.tensor(is_left_src), torch.tensor(is_left_dst)),
|
||||||
|
('obj', 'is_right', 'obj'): (torch.tensor(is_right_src), torch.tensor(is_right_dst)),
|
||||||
|
('obj', 'is_front', 'obj'): (torch.tensor(is_front_src), torch.tensor(is_front_dst)),
|
||||||
|
('obj', 'is_back', 'obj'): (torch.tensor(is_back_src), torch.tensor(is_back_dst)),
|
||||||
|
('obj', 'is_aligned', 'obj'): (torch.tensor(is_aligned_src), torch.tensor(is_aligned_dst)),
|
||||||
|
('obj', 'is_close', 'obj'): (torch.tensor(is_close_src), torch.tensor(is_close_dst))
|
||||||
|
}, num_nodes_dict={'obj': len(grid_objs)})
|
||||||
|
g = self._add_node_features(grid_objs, g)
|
||||||
|
breakpoint()
|
||||||
|
return g
|
||||||
|
|
||||||
|
def _add_node_features(self, objs, graph):
|
||||||
|
for obj_idx, obj in enumerate(objs):
|
||||||
|
graph.nodes[obj_idx].data['type'] = torch.tensor(obj.type)
|
||||||
|
graph.nodes[obj_idx].data['pos'] = torch.tensor([[obj.x, obj.y]], dtype=torch.float32)
|
||||||
|
assert len(obj.attributes) == 2 and None not in obj.attributes[0] and None not in obj.attributes[1]
|
||||||
|
graph.nodes[obj_idx].data['color'] = torch.tensor([obj.attributes[0]])
|
||||||
|
graph.nodes[obj_idx].data['shape'] = torch.tensor([obj.attributes[1]])
|
||||||
|
return graph
|
||||||
|
|
||||||
|
def _get_spatial_rel(self, objs):
|
||||||
|
spatial_tensors = [np.zeros([len(objs), len(objs)]) for _ in range(len(self.rel_deter_func))]
|
||||||
|
for obj_idx1, obj1 in enumerate(objs):
|
||||||
|
for obj_idx2, obj2 in enumerate(objs):
|
||||||
|
direction_vec = np.array((0, -1)) # Up
|
||||||
|
for rel_idx, func in enumerate(self.rel_deter_func):
|
||||||
|
if func(obj1, obj2, direction_vec):
|
||||||
|
spatial_tensors[rel_idx][obj_idx1, obj_idx2] = 1.0
|
||||||
|
return spatial_tensors
|
||||||
|
|
||||||
|
def get_trial(self, trials, step=10):
|
||||||
|
# retrieve state embeddings and actions from cached file
|
||||||
|
states = []
|
||||||
|
actions = []
|
||||||
|
trial_len = []
|
||||||
|
lens = []
|
||||||
|
n_nodes = []
|
||||||
|
# 8 trials
|
||||||
|
for t in trials:
|
||||||
|
tl = [(t, n) for n in range(0, len(self.data_tuples[t]), step)]
|
||||||
|
if len(tl) > self.max_len: # 30
|
||||||
|
tl = tl[:self.max_len]
|
||||||
|
trial_len.append(tl)
|
||||||
|
for tl in trial_len:
|
||||||
|
states.append([])
|
||||||
|
actions.append([])
|
||||||
|
lens.append(len(tl))
|
||||||
|
for t, n in tl:
|
||||||
|
video = self.data_tuples[t][n][0]
|
||||||
|
states[-1].append(self._get_frame_graph(video, self.data_tuples[t][n][1]))
|
||||||
|
n_nodes.append(states[-1][-1].number_of_nodes())
|
||||||
|
# actions are pooled over frames
|
||||||
|
if len(self.data_tuples[t]) > n + self.action_range:
|
||||||
|
actions_xy = [d[2] for d in self.data_tuples[t][n:n + self.action_range]]
|
||||||
|
else:
|
||||||
|
actions_xy = [d[2] for d in self.data_tuples[t][n:]]
|
||||||
|
actions_xy = np.array(actions_xy)
|
||||||
|
actions_xy = np.mean(actions_xy, axis=0)
|
||||||
|
actions[-1].append(actions_xy)
|
||||||
|
states[-1] = dgl.batch(states[-1])
|
||||||
|
actions[-1] = torch.tensor(np.array(actions[-1]))
|
||||||
|
trial_actions_padded = torch.zeros(self.max_len, actions[-1].size(1))
|
||||||
|
trial_actions_padded[:actions[-1].size(0), :] = actions[-1]
|
||||||
|
actions[-1] = trial_actions_padded
|
||||||
|
return states, actions, lens, n_nodes
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
ep_trials = [idx * self.num_trials + t for t in range(self.num_trials)] # [idx, ..., idx+8]
|
||||||
|
states, actions, lens, n_nodes = self.get_trial(ep_trials, step=self.action_range)
|
||||||
|
return states, actions, lens, n_nodes
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.tot_trials // self.num_trials
|
||||||
|
|
||||||
|
|
||||||
|
class TestTransitionDatasetSequence(torch.utils.data.Dataset):
|
||||||
|
"""
|
||||||
|
Test dataset class for the behavior cloning rnn model. This dataset is used to test the model on the eval data.
|
||||||
|
This class is used to compare plausible and implausible episodes.
|
||||||
|
Args:
|
||||||
|
path: path to the dataset
|
||||||
|
types: list of video types to include
|
||||||
|
size: size of the frames to be returned
|
||||||
|
mode: test
|
||||||
|
num_context: number of context state-action pairs
|
||||||
|
num_test: number of test state-action pairs
|
||||||
|
num_trials: number of trials in an episode
|
||||||
|
action_range: number of frames to skip; actions are combined over these number of frames (displcement) of the agent
|
||||||
|
process_data: whether to the videos or not (skip if already processed)
|
||||||
|
__getitem__:
|
||||||
|
returns: (expected_dem_frames, expected_dem_actions, expected_dem_lens expected_query_frames, expected_target_actions,
|
||||||
|
unexpected_dem_frames, unexpected_dem_actions, unexpected_dem_lens, unexpected_query_frames, unexpected_target_actions)
|
||||||
|
dem_frames: (num_context, max_len, 3, size, size)
|
||||||
|
dem_actions: (num_context, max_len, 2)
|
||||||
|
dem_lens: (num_context)
|
||||||
|
query_frames: (num_test, 3, size, size)
|
||||||
|
target_actions: (num_test, 2)
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
path,
|
||||||
|
task_type=None,
|
||||||
|
mode="test",
|
||||||
|
num_test=1,
|
||||||
|
num_trials=9,
|
||||||
|
action_range=10,
|
||||||
|
process_data=0,
|
||||||
|
max_len=30
|
||||||
|
):
|
||||||
|
self.path = path
|
||||||
|
self.task_type = task_type
|
||||||
|
self.mode = mode
|
||||||
|
self.num_trials = num_trials
|
||||||
|
self.num_test = num_test
|
||||||
|
self.action_range = action_range
|
||||||
|
self.max_len = max_len
|
||||||
|
self.ep_combs = self.num_trials * (self.num_trials - 2) # 9p2 - 9
|
||||||
|
self.eps = [[x, y] for x in range(self.num_trials) for y in range(self.num_trials) if x != y]
|
||||||
|
self.path_list_exp = []
|
||||||
|
self.json_list_exp = []
|
||||||
|
self.path_list_un = []
|
||||||
|
self.json_list_un = []
|
||||||
|
|
||||||
|
print(f'reading files of type {task_type} in {mode}')
|
||||||
|
|
||||||
|
paths_expected = sorted([os.path.join(self.path, task_type, x) for x in os.listdir(os.path.join(self.path, task_type)) if
|
||||||
|
x.endswith('e.mp4')])
|
||||||
|
jsons_expected = sorted([os.path.join(self.path, task_type, x) for x in os.listdir(os.path.join(self.path, task_type)) if
|
||||||
|
x.endswith('e.json') and 'index' not in x])
|
||||||
|
paths_unexpected = sorted([os.path.join(self.path, task_type, x) for x in os.listdir(os.path.join(self.path, task_type)) if
|
||||||
|
x.endswith('u.mp4')])
|
||||||
|
jsons_unexpected = sorted([os.path.join(self.path, task_type, x) for x in os.listdir(os.path.join(self.path, task_type)) if
|
||||||
|
x.endswith('u.json') and 'index' not in x])
|
||||||
|
self.path_list_exp += paths_expected
|
||||||
|
self.json_list_exp += jsons_expected
|
||||||
|
self.path_list_un += paths_unexpected
|
||||||
|
self.json_list_un += jsons_unexpected
|
||||||
|
self.data_expected = []
|
||||||
|
self.data_unexpected = []
|
||||||
|
|
||||||
|
if process_data:
|
||||||
|
# index data. This is done to speed up video retrieval.
|
||||||
|
# frame index, action tuples are stored
|
||||||
|
self.data_expected = index_data(self.json_list_exp, self.path_list_exp)
|
||||||
|
index_dict = {'data_tuples': self.data_expected}
|
||||||
|
with open(os.path.join(self.path, f'jindex_bib_test_{task_type}e.json'), 'w') as fp:
|
||||||
|
json.dump(index_dict, fp)
|
||||||
|
|
||||||
|
self.data_unexpected = index_data(self.json_list_un, self.path_list_un)
|
||||||
|
index_dict = {'data_tuples': self.data_unexpected}
|
||||||
|
with open(os.path.join(self.path, f'jindex_bib_test_{task_type}u.json'), 'w') as fp:
|
||||||
|
json.dump(index_dict, fp)
|
||||||
|
else:
|
||||||
|
with open(os.path.join(self.path, f'jindex_bib_{mode}_{task_type}e.json'), 'r') as fp:
|
||||||
|
index_dict = json.load(fp)
|
||||||
|
self.data_expected = index_dict['data_tuples']
|
||||||
|
with open(os.path.join(self.path, f'jindex_bib_{mode}_{task_type}u.json'), 'r') as fp:
|
||||||
|
index_dict = json.load(fp)
|
||||||
|
self.data_unexpected = index_dict['data_tuples']
|
||||||
|
|
||||||
|
self.rel_deter_func = [
|
||||||
|
is_top_adj, is_left_adj, is_top_right_adj, is_top_left_adj,
|
||||||
|
is_down_adj, is_right_adj, is_down_left_adj, is_down_right_adj,
|
||||||
|
is_left, is_right, is_front, is_back, is_aligned, is_close
|
||||||
|
]
|
||||||
|
|
||||||
|
print('Done.')
|
||||||
|
|
||||||
|
def _get_frame_graph(self, jsonfile, frame_idx):
|
||||||
|
# load json
|
||||||
|
with open(jsonfile, 'rb') as f:
|
||||||
|
frame_data = json.load(f)
|
||||||
|
flat_list = [x for xs in frame_data for x in xs]
|
||||||
|
# extract entities
|
||||||
|
grid_objs = parse_objects(flat_list[frame_idx])
|
||||||
|
# --- build the graph
|
||||||
|
adj = self._get_spatial_rel(grid_objs)
|
||||||
|
# define edges
|
||||||
|
is_top_adj_src, is_top_adj_dst = np.nonzero(adj[0])
|
||||||
|
is_left_adj_src, is_left_adj_dst = np.nonzero(adj[1])
|
||||||
|
is_top_right_adj_src, is_top_right_adj_dst = np.nonzero(adj[2])
|
||||||
|
is_top_left_adj_src, is_top_left_adj_dst = np.nonzero(adj[3])
|
||||||
|
is_down_adj_src, is_down_adj_dst = np.nonzero(adj[4])
|
||||||
|
is_right_adj_src, is_right_adj_dst = np.nonzero(adj[5])
|
||||||
|
is_down_left_adj_src, is_down_left_adj_dst = np.nonzero(adj[6])
|
||||||
|
is_down_right_adj_src, is_down_right_adj_dst = np.nonzero(adj[7])
|
||||||
|
is_left_src, is_left_dst = np.nonzero(adj[8])
|
||||||
|
is_right_src, is_right_dst = np.nonzero(adj[9])
|
||||||
|
is_front_src, is_front_dst = np.nonzero(adj[10])
|
||||||
|
is_back_src, is_back_dst = np.nonzero(adj[11])
|
||||||
|
is_aligned_src, is_aligned_dst = np.nonzero(adj[12])
|
||||||
|
is_close_src, is_close_dst = np.nonzero(adj[13])
|
||||||
|
g = dgl.heterograph({
|
||||||
|
('obj', 'is_top_adj', 'obj'): (torch.tensor(is_top_adj_src), torch.tensor(is_top_adj_dst)),
|
||||||
|
('obj', 'is_left_adj', 'obj'): (torch.tensor(is_left_adj_src), torch.tensor(is_left_adj_dst)),
|
||||||
|
('obj', 'is_top_right_adj', 'obj'): (torch.tensor(is_top_right_adj_src), torch.tensor(is_top_right_adj_dst)),
|
||||||
|
('obj', 'is_top_left_adj', 'obj'): (torch.tensor(is_top_left_adj_src), torch.tensor(is_top_left_adj_dst)),
|
||||||
|
('obj', 'is_down_adj', 'obj'): (torch.tensor(is_down_adj_src), torch.tensor(is_down_adj_dst)),
|
||||||
|
('obj', 'is_right_adj', 'obj'): (torch.tensor(is_right_adj_src), torch.tensor(is_right_adj_dst)),
|
||||||
|
('obj', 'is_down_left_adj', 'obj'): (torch.tensor(is_down_left_adj_src), torch.tensor(is_down_left_adj_dst)),
|
||||||
|
('obj', 'is_down_right_adj', 'obj'): (torch.tensor(is_down_right_adj_src), torch.tensor(is_down_right_adj_dst)),
|
||||||
|
('obj', 'is_left', 'obj'): (torch.tensor(is_left_src), torch.tensor(is_left_dst)),
|
||||||
|
('obj', 'is_right', 'obj'): (torch.tensor(is_right_src), torch.tensor(is_right_dst)),
|
||||||
|
('obj', 'is_front', 'obj'): (torch.tensor(is_front_src), torch.tensor(is_front_dst)),
|
||||||
|
('obj', 'is_back', 'obj'): (torch.tensor(is_back_src), torch.tensor(is_back_dst)),
|
||||||
|
('obj', 'is_aligned', 'obj'): (torch.tensor(is_aligned_src), torch.tensor(is_aligned_dst)),
|
||||||
|
('obj', 'is_close', 'obj'): (torch.tensor(is_close_src), torch.tensor(is_close_dst))
|
||||||
|
}, num_nodes_dict={'obj': len(grid_objs)})
|
||||||
|
g = self._add_node_features(grid_objs, g)
|
||||||
|
return g
|
||||||
|
|
||||||
|
def _add_node_features(self, objs, graph):
|
||||||
|
for obj_idx, obj in enumerate(objs):
|
||||||
|
graph.nodes[obj_idx].data['type'] = torch.tensor(obj.type)
|
||||||
|
graph.nodes[obj_idx].data['pos'] = torch.tensor([[obj.x, obj.y]], dtype=torch.float32)
|
||||||
|
assert len(obj.attributes) == 2 and None not in obj.attributes[0] and None not in obj.attributes[1]
|
||||||
|
graph.nodes[obj_idx].data['color'] = torch.tensor([obj.attributes[0]])
|
||||||
|
graph.nodes[obj_idx].data['shape'] = torch.tensor([obj.attributes[1]])
|
||||||
|
return graph
|
||||||
|
|
||||||
|
def _get_spatial_rel(self, objs):
|
||||||
|
spatial_tensors = [np.zeros([len(objs), len(objs)]) for _ in range(len(self.rel_deter_func))]
|
||||||
|
for obj_idx1, obj1 in enumerate(objs):
|
||||||
|
for obj_idx2, obj2 in enumerate(objs):
|
||||||
|
direction_vec = np.array((0, -1)) # Up why??????????????
|
||||||
|
for rel_idx, func in enumerate(self.rel_deter_func):
|
||||||
|
if func(obj1, obj2, direction_vec):
|
||||||
|
spatial_tensors[rel_idx][obj_idx1, obj_idx2] = 1.0
|
||||||
|
return spatial_tensors
|
||||||
|
|
||||||
|
def get_trial(self, trials, data, step=10):
|
||||||
|
# retrieve state embeddings and actions from cached file
|
||||||
|
states = []
|
||||||
|
actions = []
|
||||||
|
trial_len = []
|
||||||
|
lens = []
|
||||||
|
n_nodes = []
|
||||||
|
for t in trials:
|
||||||
|
tl = [(t, n) for n in range(0, len(data[t]), step)]
|
||||||
|
if len(tl) > self.max_len:
|
||||||
|
tl = tl[:self.max_len]
|
||||||
|
trial_len.append(tl)
|
||||||
|
for tl in trial_len:
|
||||||
|
states.append([])
|
||||||
|
actions.append([])
|
||||||
|
lens.append(len(tl))
|
||||||
|
for t, n in tl:
|
||||||
|
video = data[t][n][0]
|
||||||
|
states[-1].append(self._get_frame_graph(video, data[t][n][1]))
|
||||||
|
n_nodes.append(states[-1][-1].number_of_nodes())
|
||||||
|
if len(data[t]) > n + self.action_range:
|
||||||
|
actions_xy = [d[2] for d in data[t][n:n + self.action_range]]
|
||||||
|
else:
|
||||||
|
actions_xy = [d[2] for d in data[t][n:]]
|
||||||
|
actions_xy = np.array(actions_xy)
|
||||||
|
actions_xy = np.mean(actions_xy, axis=0)
|
||||||
|
actions[-1].append(actions_xy)
|
||||||
|
states[-1] = dgl.batch(states[-1])
|
||||||
|
actions[-1] = torch.tensor(np.array(actions[-1]))
|
||||||
|
trial_actions_padded = torch.zeros(self.max_len, actions[-1].size(1))
|
||||||
|
trial_actions_padded[:actions[-1].size(0), :] = actions[-1]
|
||||||
|
actions[-1] = trial_actions_padded
|
||||||
|
return states, actions, lens, n_nodes
|
||||||
|
|
||||||
|
def get_test(self, trial, data, step=10):
|
||||||
|
# retrieve state embeddings and actions from cached file
|
||||||
|
states = []
|
||||||
|
actions = []
|
||||||
|
trial_len = []
|
||||||
|
trial_len += [(trial, n) for n in range(0, len(data[trial]), step)]
|
||||||
|
for t, n in trial_len:
|
||||||
|
video = data[t][n][0]
|
||||||
|
state = self._get_frame_graph(video, data[t][n][1])
|
||||||
|
if len(data[t]) > n + self.action_range:
|
||||||
|
actions_xy = [d[2] for d in data[t][n:n + self.action_range]]
|
||||||
|
else:
|
||||||
|
actions_xy = [d[2] for d in data[t][n:]]
|
||||||
|
actions_xy = np.array(actions_xy)
|
||||||
|
actions_xy = np.mean(actions_xy, axis=0)
|
||||||
|
actions.append(actions_xy)
|
||||||
|
states.append(state)
|
||||||
|
#states = torch.stack(states)
|
||||||
|
states = dgl.batch(states)
|
||||||
|
actions = torch.tensor(np.array(actions))
|
||||||
|
return states, actions
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
ep_trials = [idx * self.num_trials + t for t in range(self.num_trials)]
|
||||||
|
dem_expected_states, dem_expected_actions, dem_expected_lens, dem_expected_nodes = self.get_trial(
|
||||||
|
ep_trials[:-1], self.data_expected, step=self.action_range
|
||||||
|
)
|
||||||
|
dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, dem_unexpected_nodes = self.get_trial(
|
||||||
|
ep_trials[:-1], self.data_unexpected, step=self.action_range
|
||||||
|
)
|
||||||
|
query_expected_frames, target_expected_actions = self.get_test(
|
||||||
|
ep_trials[-1], self.data_expected, step=self.action_range
|
||||||
|
)
|
||||||
|
query_unexpected_frames, target_unexpected_actions = self.get_test(
|
||||||
|
ep_trials[-1], self.data_unexpected, step=self.action_range
|
||||||
|
)
|
||||||
|
return dem_expected_states, dem_expected_actions, dem_expected_lens, dem_expected_nodes, \
|
||||||
|
dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, dem_unexpected_nodes, \
|
||||||
|
query_expected_frames, target_expected_actions, \
|
||||||
|
query_unexpected_frames, target_unexpected_actions
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.path_list_exp)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
types = ['preference', 'multi_agent', 'inaccessible_goal',
|
||||||
|
'efficiency_irrational', 'efficiency_time','efficiency_path',
|
||||||
|
'instrumental_no_barrier', 'instrumental_blocking_barrier', 'instrumental_inconsequential_barrier']
|
||||||
|
for t in types:
|
||||||
|
ttd = TestTransitionDatasetSequence(path='/datasets/external/bib_evaluation_1_1/', task_type=t, process_data=0, mode='test')
|
||||||
|
for i in range(ttd.__len__()):
|
||||||
|
print(i, end='\r')
|
||||||
|
dem_expected_states, dem_expected_actions, dem_expected_lens, dem_expected_nodes, \
|
||||||
|
dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, dem_unexpected_nodes, \
|
||||||
|
query_expected_frames, target_expected_actions, \
|
||||||
|
query_unexpected_frames, target_unexpected_actions = ttd.__getitem__(i)
|
||||||
|
for j in range(8):
|
||||||
|
if not torch.tensor([1., 0., 0., 0., 0., 0., 0., 0., 0.]) in dem_expected_states[j].ndata['type']:
|
||||||
|
print(i)
|
||||||
|
if not torch.tensor([1., 0., 0., 0., 0., 0., 0., 0., 0.]) in dem_unexpected_states[j].ndata['type']:
|
||||||
|
print(i)
|
||||||
|
if not torch.tensor([1., 0., 0., 0., 0., 0., 0., 0., 0.]) in query_expected_frames.ndata['type']:
|
||||||
|
print(i)
|
||||||
|
if not torch.tensor([1., 0., 0., 0., 0., 0., 0., 0., 0.]) in query_unexpected_frames.ndata['type']:
|
||||||
|
print(i)
|
174
utils/grid_object.py
Normal file
174
utils/grid_object.py
Normal file
|
@ -0,0 +1,174 @@
|
||||||
|
import json
|
||||||
|
import pdb
|
||||||
|
import numpy as np
|
||||||
|
from sklearn.preprocessing import OneHotEncoder
|
||||||
|
import itertools
|
||||||
|
|
||||||
|
|
||||||
|
SHAPES = {
|
||||||
|
# walls
|
||||||
|
'square': 0,
|
||||||
|
# objects
|
||||||
|
'heart': 1, 'clove_tree': 2, 'moon': 3, 'wine': 4, 'double_dia': 5,
|
||||||
|
'flag': 6, 'capsule': 7, 'vase': 8, 'curved_triangle': 9, 'spoon': 10,
|
||||||
|
'medal': 11, # inaccessible goal
|
||||||
|
# home
|
||||||
|
'home': 12,
|
||||||
|
# agent
|
||||||
|
'pentagon': 13, 'clove': 14, 'kite': 15,
|
||||||
|
# key
|
||||||
|
'triangle0': 16, 'triangle180': 16, 'triangle90': 16, 'triangle270': 16,
|
||||||
|
# lock
|
||||||
|
'triangle_slot0': 17, 'triangle_slot180': 17, 'triangle_slot90': 17, 'triangle_slot270': 17
|
||||||
|
}
|
||||||
|
|
||||||
|
ENTITIES = {
|
||||||
|
'agent': 0 , 'walls': 1, 'fuse_walls': 2, 'key': 3,
|
||||||
|
'blocking': 4, 'home': 5, 'objects': 6, 'pin': 7, 'lock': 8
|
||||||
|
}
|
||||||
|
|
||||||
|
# =============================== GridPbject class ===============================
|
||||||
|
|
||||||
|
class GridObject():
|
||||||
|
"object is specified by its location"
|
||||||
|
def __init__(self, x, y, object_type, attributes=[]):
|
||||||
|
self.x = x
|
||||||
|
self.y = y
|
||||||
|
self.type = object_type
|
||||||
|
self.attributes = attributes
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pos(self):
|
||||||
|
return np.array([self.x, self.y])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self):
|
||||||
|
return {'type': str(self.type),
|
||||||
|
'x': str(self.x),
|
||||||
|
'y': str(self.y),
|
||||||
|
'color': self.attributes[0],
|
||||||
|
'shape': self.attributes[1]}
|
||||||
|
|
||||||
|
# =============================== Helper functions ===============================
|
||||||
|
|
||||||
|
def type2index(key):
|
||||||
|
for name, idx in ENTITIES.items():
|
||||||
|
if name == key:
|
||||||
|
return idx
|
||||||
|
|
||||||
|
def find_shape(shape_string, print_shape=False):
|
||||||
|
try:
|
||||||
|
shape = shape_string.split('/')[-1].split('.')[0]
|
||||||
|
except:
|
||||||
|
shape = shape_string.split('.')[0]
|
||||||
|
if print_shape: print(shape)
|
||||||
|
for name, idx in SHAPES.items():
|
||||||
|
if name == shape:
|
||||||
|
return idx
|
||||||
|
|
||||||
|
def parse_objects(frame):
|
||||||
|
"""
|
||||||
|
x and y are computed differently from walls and objects
|
||||||
|
for walls x, y = obj[0][0] + obj[1][0]/2, obj[0][1] + obj[1][1]/2
|
||||||
|
for objects x, y = obj[0][0] + obj[1], obj[0][1] + obj[1]
|
||||||
|
:param obj:
|
||||||
|
:return: GridObject
|
||||||
|
"""
|
||||||
|
shape_onehot_encoder = OneHotEncoder(sparse=False)
|
||||||
|
shape_onehot_encoder.fit([[i] for i in range(len(SHAPES)-6)])
|
||||||
|
type_onehot_encoder = OneHotEncoder(sparse=False)
|
||||||
|
type_onehot_encoder.fit([[i] for i in range(len(ENTITIES))])
|
||||||
|
# remove duplicate walls
|
||||||
|
frame['walls'].sort()
|
||||||
|
frame['walls'] = list(k for k, _ in itertools.groupby(frame['walls']))
|
||||||
|
# remove boundary walls
|
||||||
|
frame['walls'] = [w for w in frame['walls'] if (w[0][0] != 0 and w[0][0] != 180 and w[0][1] != 0 and w[0][1] != 180)]
|
||||||
|
# remove duplicate fuse_walls
|
||||||
|
frame['fuse_walls'].sort()
|
||||||
|
frame['fuse_walls'] = list(k for k, _ in itertools.groupby(frame['fuse_walls']))
|
||||||
|
grid_objs = []
|
||||||
|
assert 'agent' in frame.keys()
|
||||||
|
for key in frame.keys():
|
||||||
|
#print(key)
|
||||||
|
if key == 'size':
|
||||||
|
continue
|
||||||
|
obj = frame[key]
|
||||||
|
if obj == []:
|
||||||
|
#print(key, 'skipped')
|
||||||
|
continue
|
||||||
|
obj_type = type2index(key)
|
||||||
|
obj_type = type_onehot_encoder.transform([[obj_type]])
|
||||||
|
if key == 'walls':
|
||||||
|
for wall in obj:
|
||||||
|
x, y = wall[0][0] + wall[1][0]/2, wall[0][1] + wall[1][1]/2
|
||||||
|
#x, y = (wall[0][0] + wall[1][0]/2)/200, (wall[0][1] + wall[1][1]/2)/200 if u use this you need to change relations.py!!!
|
||||||
|
color = [0, 0, 0] if key == 'walls' else [80, 146, 56]
|
||||||
|
#color = [c / 255 for c in color]
|
||||||
|
shape = 0
|
||||||
|
assert shape in SHAPES.values(), 'Shape not found'
|
||||||
|
shape = shape_onehot_encoder.transform([[shape]])[0]
|
||||||
|
grid_obj = GridObject(x=x, y=y, object_type=obj_type, attributes=[color, shape])
|
||||||
|
grid_objs.append(grid_obj)
|
||||||
|
elif key == 'fuse_walls':
|
||||||
|
# resample green barriers
|
||||||
|
obj = [obj[i] for i in range(len(obj)) if (obj[i][0][0] % 20 == 0 and obj[i][0][1] % 20 == 0)]
|
||||||
|
for wall in obj:
|
||||||
|
x, y = wall[0][0] + wall[1][0]/2, wall[0][1] + wall[1][1]/2
|
||||||
|
#x, y = (wall[0][0] + wall[1][0]/2)/200, (wall[0][1] + wall[1][1]/2)/200 if u use this you need to change relations.py!!!
|
||||||
|
color = [80, 146, 56]
|
||||||
|
#color = [c / 255 for c in color]
|
||||||
|
shape = 0
|
||||||
|
assert shape in SHAPES.values(), 'Shape not found'
|
||||||
|
shape = shape_onehot_encoder.transform([[shape]])[0]
|
||||||
|
grid_obj = GridObject(x=x, y=y, object_type=obj_type, attributes=[color, shape])
|
||||||
|
grid_objs.append(grid_obj)
|
||||||
|
elif key == 'objects':
|
||||||
|
for ob in obj:
|
||||||
|
x, y = ob[0][0] + ob[1], ob[0][1] + ob[1]
|
||||||
|
color = ob[-1]
|
||||||
|
#color = [c / 255 for c in color]
|
||||||
|
shape = find_shape(ob[2], print_shape=False)
|
||||||
|
assert shape in SHAPES.values(), 'Shape not found'
|
||||||
|
shape = shape_onehot_encoder.transform([[shape]])[0]
|
||||||
|
grid_obj = GridObject(x=x, y=y, object_type=obj_type, attributes=[color, shape])
|
||||||
|
grid_objs.append(grid_obj)
|
||||||
|
elif key == 'key':
|
||||||
|
obj = obj[0]
|
||||||
|
x, y = obj[0][0] + obj[1], obj[0][1] + obj[1]
|
||||||
|
color = obj[-1]
|
||||||
|
#color = [c / 255 for c in color]
|
||||||
|
shape = find_shape(obj[2], print_shape=False)
|
||||||
|
assert shape in SHAPES.values(), 'Shape not found'
|
||||||
|
shape = shape_onehot_encoder.transform([[shape]])[0]
|
||||||
|
grid_obj = GridObject(x=x, y=y, object_type=obj_type, attributes=[color, shape])
|
||||||
|
grid_objs.append(grid_obj)
|
||||||
|
elif key == 'lock':
|
||||||
|
obj = obj[0]
|
||||||
|
x, y = obj[0][0] + obj[1], obj[0][1] + obj[1]
|
||||||
|
color = obj[-1]
|
||||||
|
#color = [c / 255 for c in color]
|
||||||
|
shape = find_shape(obj[2], print_shape=False)
|
||||||
|
assert shape in SHAPES.values(), 'Shape not found'
|
||||||
|
shape = shape_onehot_encoder.transform([[shape]])[0]
|
||||||
|
grid_obj = GridObject(x=x, y=y, object_type=obj_type, attributes=[color, shape])
|
||||||
|
grid_objs.append(grid_obj)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
x, y = obj[0][0] + obj[1], obj[0][1] + obj[1]
|
||||||
|
color = obj[-1]
|
||||||
|
#color = [c / 255 for c in color]
|
||||||
|
shape = find_shape(obj[2], print_shape=False)
|
||||||
|
assert shape in SHAPES.values(), 'Shape not found'
|
||||||
|
shape = shape_onehot_encoder.transform([[shape]])[0]
|
||||||
|
except:
|
||||||
|
# [[[x, y], extension, shape, color]] in some cases in instrumental_no_barrier (bib_evaluation_1_1)
|
||||||
|
x, y = obj[0][0][0] + obj[0][1], obj[0][0][1] + obj[0][1]
|
||||||
|
color = obj[0][-1]
|
||||||
|
#color = [c / 255 for c in color]
|
||||||
|
assert len(color) == 3
|
||||||
|
shape = find_shape(obj[0][2], print_shape=False)
|
||||||
|
assert shape in SHAPES.values(), 'Shape not found'
|
||||||
|
shape = shape_onehot_encoder.transform([[shape]])[0]
|
||||||
|
grid_obj = GridObject(x=x, y=y, object_type=obj_type, attributes=[color, shape])
|
||||||
|
grid_objs.append(grid_obj)
|
||||||
|
return grid_objs
|
124
utils/index_data.py
Normal file
124
utils/index_data.py
Normal file
|
@ -0,0 +1,124 @@
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import torch.utils.data
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
def index_data(json_list, path_list):
|
||||||
|
print(f'processing files {len(json_list)}')
|
||||||
|
data_tuples = []
|
||||||
|
for j, v in tqdm(zip(json_list, path_list)):
|
||||||
|
with open(j, 'r') as f:
|
||||||
|
state = json.load(f)
|
||||||
|
ep_lens = [len(x) for x in state]
|
||||||
|
past_len = 0
|
||||||
|
for e, l in enumerate(ep_lens):
|
||||||
|
data_tuples.append([])
|
||||||
|
# skip first 30 frames and last 83 frames
|
||||||
|
for f in range(30, l - 83):
|
||||||
|
# find action taken;
|
||||||
|
f0x, f0y = state[e][f]['agent'][0]
|
||||||
|
f1x, f1y = state[e][f + 1]['agent'][0]
|
||||||
|
dx = (f1x - f0x) / 2.
|
||||||
|
dy = (f1y - f0y) / 2.
|
||||||
|
action = [dx, dy]
|
||||||
|
#data_tuples[-1].append((v, past_len + f, action))
|
||||||
|
data_tuples[-1].append((j, past_len + f, action))
|
||||||
|
# data_tuples = (json file, frame number, action)
|
||||||
|
assert len(data_tuples[-1]) > 0
|
||||||
|
past_len += l
|
||||||
|
return data_tuples
|
||||||
|
|
||||||
|
class TransitionDataset(torch.utils.data.Dataset):
|
||||||
|
"""
|
||||||
|
Training dataset class for the behavior cloning mlp model.
|
||||||
|
Args:
|
||||||
|
path: path to the dataset
|
||||||
|
types: list of video types to include
|
||||||
|
size: size of the frames to be returned
|
||||||
|
mode: train, val
|
||||||
|
num_context: number of context state-action pairs
|
||||||
|
num_test: number of test state-action pairs
|
||||||
|
num_trials: number of trials in an episode
|
||||||
|
action_range: number of frames to skip; actions are combined over these number of frames (displcement) of the agent
|
||||||
|
process_data: whether to the videos or not (skip if already processed)
|
||||||
|
__getitem__:
|
||||||
|
returns: (dem_frames, dem_actions, query_frames, target_actions)
|
||||||
|
dem_frames: (num_context, 3, size, size)
|
||||||
|
dem_actions: (num_context, 2)
|
||||||
|
query_frames: (num_test, 3, size, size)
|
||||||
|
target_actions: (num_test, 2)
|
||||||
|
"""
|
||||||
|
def __init__(self, path, types=None, size=None, mode="train", num_context=30, num_test=1, num_trials=9,
|
||||||
|
action_range=10, process_data=0):
|
||||||
|
|
||||||
|
self.path = path
|
||||||
|
self.types = types
|
||||||
|
self.size = size
|
||||||
|
self.mode = mode
|
||||||
|
self.num_trials = num_trials
|
||||||
|
self.num_context = num_context
|
||||||
|
self.num_test = num_test
|
||||||
|
self.action_range = action_range
|
||||||
|
self.ep_combs = self.num_trials * (self.num_trials - 2) # 9p2 - 9
|
||||||
|
self.eps = [[x, y] for x in range(self.num_trials) for y in range(self.num_trials) if x != y]
|
||||||
|
types_str = '_'.join(self.types)
|
||||||
|
|
||||||
|
self.path_list = []
|
||||||
|
self.json_list = []
|
||||||
|
# get video paths and json file paths
|
||||||
|
for t in types:
|
||||||
|
print(f'reading files of type {t} in {mode}')
|
||||||
|
paths = [os.path.join(self.path, t, x) for x in os.listdir(os.path.join(self.path, t)) if
|
||||||
|
x.endswith(f'.mp4')]
|
||||||
|
jsons = [os.path.join(self.path, t, x) for x in os.listdir(os.path.join(self.path, t)) if
|
||||||
|
x.endswith(f'.json') and 'index' not in x]
|
||||||
|
|
||||||
|
paths = sorted(paths)
|
||||||
|
jsons = sorted(jsons)
|
||||||
|
|
||||||
|
if mode == 'train':
|
||||||
|
self.path_list += paths[:int(0.8 * len(jsons))]
|
||||||
|
self.json_list += jsons[:int(0.8 * len(jsons))]
|
||||||
|
elif mode == 'val':
|
||||||
|
self.path_list += paths[int(0.8 * len(jsons)):]
|
||||||
|
self.json_list += jsons[int(0.8 * len(jsons)):]
|
||||||
|
else:
|
||||||
|
self.path_list += paths
|
||||||
|
self.json_list += jsons
|
||||||
|
|
||||||
|
self.data_tuples = []
|
||||||
|
if process_data:
|
||||||
|
# index the videos in the dataset directory. This is done to speed up the retrieval of videos.
|
||||||
|
# frame index, action tuples are stored
|
||||||
|
self.data_tuples = index_data(self.json_list, self.path_list)
|
||||||
|
# tuples of frame index and action (displacement of agent)
|
||||||
|
index_dict = {'data_tuples': self.data_tuples}
|
||||||
|
with open(os.path.join(self.path, f'jindex_bib_{mode}_{types_str}.json'), 'w') as fp:
|
||||||
|
json.dump(index_dict, fp)
|
||||||
|
else:
|
||||||
|
# read pre-indexed data
|
||||||
|
with open(os.path.join(self.path, f'jindex_bib_{mode}_{types_str}.json'), 'r') as fp:
|
||||||
|
index_dict = json.load(fp)
|
||||||
|
self.data_tuples = index_dict['data_tuples']
|
||||||
|
|
||||||
|
self.tot_trials = len(self.path_list) * 9
|
||||||
|
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
print('Empty')
|
||||||
|
return
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.tot_trials // self.num_trials
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
dataset = TransitionDataset(path='/datasets/external/bib_train/',
|
||||||
|
types=['multi_agent', 'instrumental_action'], #['instrumental_action', 'multi_agent', 'preference', 'single_object'],
|
||||||
|
size=(84, 84),
|
||||||
|
mode="train", num_context=30,
|
||||||
|
num_test=1, num_trials=9,
|
||||||
|
action_range=10, process_data=1)
|
||||||
|
print(len(dataset))
|
116
utils/relations.py
Normal file
116
utils/relations.py
Normal file
|
@ -0,0 +1,116 @@
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
# =============================== relationships to build the graph ===============================
|
||||||
|
|
||||||
|
def rotate_vec2d(vec, degrees):
|
||||||
|
"""
|
||||||
|
rotate a vector anti-clockwise
|
||||||
|
:param vec:
|
||||||
|
:param degrees:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
theta = np.radians(degrees)
|
||||||
|
c, s = np.cos(theta), np.sin(theta)
|
||||||
|
R = np.array(((c, -s), (s, c)))
|
||||||
|
return R@vec
|
||||||
|
|
||||||
|
# ---------- Remote Directional Relations --------------------------------------------------------
|
||||||
|
|
||||||
|
def is_front(obj1, obj2, direction_vec)->bool:
|
||||||
|
diff = obj2.pos - obj1.pos
|
||||||
|
return diff@direction_vec > 0.1
|
||||||
|
|
||||||
|
def is_back(obj1, obj2, direction_vec)->bool:
|
||||||
|
diff = obj2.pos - obj1.pos
|
||||||
|
return diff@direction_vec < -0.1
|
||||||
|
|
||||||
|
def is_left(obj1, obj2, direction_vec)->bool:
|
||||||
|
left_vec = rotate_vec2d(direction_vec, -90)
|
||||||
|
diff = obj2.pos - obj1.pos
|
||||||
|
return diff@left_vec > 0.1
|
||||||
|
|
||||||
|
def is_right(obj1, obj2, direction_vec)->bool:
|
||||||
|
left_vec = rotate_vec2d(direction_vec, 90)
|
||||||
|
diff = obj2.pos - obj1.pos
|
||||||
|
return diff@left_vec > 0.1
|
||||||
|
|
||||||
|
# ---------- Alignment and Adjacency Relations ---------------------------------------------------
|
||||||
|
|
||||||
|
def is_close(obj1, obj2, direction_vec=None)->bool:
|
||||||
|
# indicate whether two objects are adjacent to each other,
|
||||||
|
# which, unlike local directional relations, carry no directional information
|
||||||
|
distance = np.abs(obj1.pos - obj2.pos)
|
||||||
|
return np.sum(distance)==20
|
||||||
|
|
||||||
|
def is_aligned(obj1, obj2, direction_vec=None)->bool:
|
||||||
|
# indicate if two entities are on the same horizontal or vertical line
|
||||||
|
diff = obj2.pos - obj1.pos
|
||||||
|
return np.any(diff==0)
|
||||||
|
|
||||||
|
# ---------- Local Directional Relations ---------------------------------------------------------
|
||||||
|
|
||||||
|
def is_top_adj(obj1, obj2, direction_vec=None)->bool:
|
||||||
|
return obj1.x==obj2.x and obj1.y==obj2.y+20
|
||||||
|
|
||||||
|
def is_left_adj(obj1, obj2, direction_vec=None)->bool:
|
||||||
|
return obj1.y==obj2.y and obj1.x==obj2.x-20
|
||||||
|
|
||||||
|
def is_top_left_adj(obj1, obj2, direction_vec=None)->bool:
|
||||||
|
return obj1.y==obj2.y+20 and obj1.x==obj2.x-20
|
||||||
|
|
||||||
|
def is_top_right_adj(obj1, obj2, direction_vec=None)->bool:
|
||||||
|
return obj1.y==obj2.y+20 and obj1.x==obj2.x+20
|
||||||
|
|
||||||
|
def is_down_adj(obj1, obj2, direction_vec=None)->bool:
|
||||||
|
return is_top_adj(obj2, obj1)
|
||||||
|
|
||||||
|
def is_right_adj(obj1, obj2, direction_vec=None)->bool:
|
||||||
|
return is_left_adj(obj2, obj1)
|
||||||
|
|
||||||
|
def is_down_right_adj(obj1, obj2, direction_vec=None)->bool:
|
||||||
|
return is_top_left_adj(obj2, obj1)
|
||||||
|
|
||||||
|
def is_down_left_adj(obj1, obj2, direction_vec=None)->bool:
|
||||||
|
return is_top_right_adj(obj2, obj1)
|
||||||
|
|
||||||
|
# ---------- More Remote Directional Relations (not used) ----------------------------------------
|
||||||
|
|
||||||
|
def top_left(obj1, obj2, direction_vec)->bool:
|
||||||
|
return (obj1.x-obj2.x) <= (obj1.y-obj2.y)
|
||||||
|
|
||||||
|
def top_right(obj1, obj2, direction_vec)->bool:
|
||||||
|
return -(obj1.x-obj2.x) <= (obj1.y-obj2.y)
|
||||||
|
|
||||||
|
def down_left(obj1, obj2, direction_vec)->bool:
|
||||||
|
return top_right(obj2, obj1, direction_vec)
|
||||||
|
|
||||||
|
def down_right(obj1, obj2, direction_vec)->bool:
|
||||||
|
return top_left(obj2, obj1, direction_vec)
|
||||||
|
|
||||||
|
def fan_top(obj1, obj2, direction_vec)->bool:
|
||||||
|
top_left = (obj1.x-obj2.x) <= (obj1.y-obj2.y)
|
||||||
|
top_right = -(obj1.x-obj2.x) <= (obj1.y-obj2.y)
|
||||||
|
return top_left and top_right
|
||||||
|
|
||||||
|
def fan_down(obj1, obj2, direction_vec)->bool:
|
||||||
|
return fan_top(obj2, obj1, direction_vec)
|
||||||
|
|
||||||
|
def fan_right(obj1, obj2, direction_vec)->bool:
|
||||||
|
down_left = (obj1.x-obj2.x) >= (obj1.y-obj2.y)
|
||||||
|
top_right = -(obj1.x-obj2.x) <= (obj1.y-obj2.y)
|
||||||
|
return down_left and top_right
|
||||||
|
|
||||||
|
def fan_left(obj1, obj2, direction_vec)->bool:
|
||||||
|
return fan_right(obj2, obj1, direction_vec)
|
||||||
|
|
||||||
|
# ---------- Ad-hoc Relations --------------------------------------------------------------------
|
||||||
|
|
||||||
|
def needs(obj1, obj2, direction_vec=None)->bool:
|
||||||
|
return np.argmax(obj1.type) == 0 and np.argmax(obj2.type) == 3
|
||||||
|
|
||||||
|
def opens(obj1, obj2, direction_vec=None)->bool:
|
||||||
|
return np.argmax(obj1.type) == 3 and np.argmax(obj2.type) == 8
|
||||||
|
|
||||||
|
def collects(obj1, obj2, direction_vec=None)->bool:
|
||||||
|
return np.argmax(obj1.type) == 0 and np.argmax(obj2.type) == 6
|
1
utils/run_build_graphs.sh
Normal file
1
utils/run_build_graphs.sh
Normal file
|
@ -0,0 +1 @@
|
||||||
|
python build_graphs.py --mode train --cpus 30 && python build_graphs.py --mode val --cpus 30
|
Loading…
Reference in a new issue