This commit is contained in:
parent
a333481e05
commit
de0bea7508
18 changed files with 3150 additions and 2 deletions
0
utils/__init__.py
Normal file
115
utils/build_graphs.py
Normal file
@@ -0,0 +1,115 @@
import sys
sys.path.append('/projects/bortoletto/icml2023_matteo/utils')
from dataset import TransitionDataset, TestTransitionDatasetSequence
import multiprocessing as mp
import argparse
import pickle as pkl
import os

# Instantiate the parser
parser = argparse.ArgumentParser()
parser.add_argument('--cpus', type=int,
                    help='Number of processes')
parser.add_argument('--mode', type=str,
                    help='Dataset split: train, val or test')
args = parser.parse_args()

NUM_PROCESSES = args.cpus
MODE = args.mode

def generate_files(idx):
    print('Generating idx', idx)
    if os.path.exists(PATH + str(idx) + '.pkl'):
        print('Index', idx, 'skipped.')
        return
    if MODE == 'train' or MODE == 'val':
        states, actions, lens, n_nodes = dataset[idx]
        with open(PATH + str(idx) + '.pkl', 'wb') as f:
            pkl.dump([states, actions, lens, n_nodes], f)
    elif MODE == 'test':
        dem_expected_states, dem_expected_actions, dem_expected_lens, dem_expected_nodes, \
        dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, dem_unexpected_nodes, \
        query_expected_frames, target_expected_actions, \
        query_unexpected_frames, target_unexpected_actions = dataset[idx]
        with open(PATH + str(idx) + '.pkl', 'wb') as f:
            pkl.dump([
                dem_expected_states, dem_expected_actions, dem_expected_lens, dem_expected_nodes,
                dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, dem_unexpected_nodes,
                query_expected_frames, target_expected_actions,
                query_unexpected_frames, target_unexpected_actions
            ], f)
    else:
        raise ValueError('MODE can be only train, val or test.')
    print(PATH + str(idx) + '.pkl saved.')

if __name__ == "__main__":
    if MODE == 'train':
        print('TRAIN MODE')
        PATH = '/datasets/external/bib_train/graphs/all_tasks/train_dgl_hetero_nobound_4feats/'
        if not os.path.exists(PATH):
            os.makedirs(PATH)
            print(PATH, 'directory created.')
        dataset = TransitionDataset(
            path='/datasets/external/bib_train/',
            types=['instrumental_action', 'multi_agent', 'preference', 'single_object'],
            mode="train",
            max_len=30,
            num_test=1,
            num_trials=9,
            action_range=10,
            process_data=0
        )
        pool = mp.Pool(processes=NUM_PROCESSES)
        print('Starting graph generation with', NUM_PROCESSES, 'processes...')
        pool.map(generate_files, range(len(dataset)))
        pool.close()
    elif MODE == 'val':
        print('VALIDATION MODE')
        types = ['multi_agent', 'instrumental_action', 'preference', 'single_object']
        for t in range(len(types)):
            PATH = '/datasets/external/bib_train/graphs/all_tasks/val_dgl_hetero_nobound_4feats/' + types[t] + '/'
            if not os.path.exists(PATH):
                os.makedirs(PATH)
                print(PATH, 'directory created.')
            dataset = TransitionDataset(
                path='/datasets/external/bib_train/',
                types=[types[t]],
                mode="val",
                max_len=30,
                num_test=1,
                num_trials=9,
                action_range=10,
                process_data=0
            )
            pool = mp.Pool(processes=NUM_PROCESSES)
            print('Starting', types[t], 'graph generation with', NUM_PROCESSES, 'processes...')
            pool.map(generate_files, range(len(dataset)))
            pool.close()
    elif MODE == 'test':
        print('TEST MODE')
        types = [
            'preference', 'multi_agent', 'inaccessible_goal',
            'efficiency_irrational', 'efficiency_time', 'efficiency_path',
            'instrumental_no_barrier', 'instrumental_blocking_barrier', 'instrumental_inconsequential_barrier'
        ]
        for t in range(len(types)):
            PATH = '/datasets/external/bib_evaluation_1_1/graphs/all_tasks_dgl_hetero_nobound_4feats/' + types[t] + '/'
            if not os.path.exists(PATH):
                os.makedirs(PATH)
                print(PATH, 'directory created.')
            dataset = TestTransitionDatasetSequence(
                path='/datasets/external/bib_evaluation_1_1/',
                task_type=types[t],
                mode="test",
                num_test=1,
                num_trials=9,
                action_range=10,
                process_data=0,
                max_len=30
            )
            pool = mp.Pool(processes=NUM_PROCESSES)
            print('Starting', types[t], 'graph generation with', NUM_PROCESSES, 'processes...')
            pool.map(generate_files, range(len(dataset)))
            pool.close()
    else:
        raise ValueError('MODE can be only train, val or test.')
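Note on the script above: generate_files reads the module-level PATH and dataset that are only assigned under __main__. With the default fork start method on Linux, each pool worker inherits these globals from the parent, which is what makes the pool.map call work. A minimal sketch of reading one cached training item back (the index 0 is just illustrative):

import pickle as pkl

# hypothetical example: first cached training episode written by generate_files
with open('/datasets/external/bib_train/graphs/all_tasks/train_dgl_hetero_nobound_4feats/0.pkl', 'rb') as f:
    states, actions, lens, n_nodes = pkl.load(f)
# states: one batched DGL heterograph per trial; actions: (max_len, 2) padded tensors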
487
utils/dataset.py
Normal file
@@ -0,0 +1,487 @@
import dgl
import torch
import torch.utils.data
import os
import pickle as pkl
import json
import numpy as np
from tqdm import tqdm

import sys
sys.path.append('/projects/bortoletto/irene/')
from utils.grid_object import *
from utils.relations import *

# ========================== Helper functions ==========================

def index_data(json_list, path_list):
    print(f'processing {len(json_list)} files')
    data_tuples = []
    for j, v in tqdm(zip(json_list, path_list)):
        with open(j, 'r') as f:
            state = json.load(f)
        ep_lens = [len(x) for x in state]
        past_len = 0
        for e, l in enumerate(ep_lens):
            data_tuples.append([])
            # skip first 30 frames and last 83 frames
            for f in range(30, l - 83):
                # find the action taken between frames f and f+1
                f0x, f0y = state[e][f]['agent'][0]
                f1x, f1y = state[e][f + 1]['agent'][0]
                dx = (f1x - f0x) / 2.
                dy = (f1y - f0y) / 2.
                action = [dx, dy]
                #data_tuples[-1].append((v, past_len + f, action))
                data_tuples[-1].append((j, past_len + f, action))
            # data_tuples = [[json file, frame number, action] for each episode in each video]
            assert len(data_tuples[-1]) > 0
            past_len += l
    return data_tuples

# ========================== Dataset class ==========================

class TransitionDataset(torch.utils.data.Dataset):
    """
    Training dataset class for the behavior cloning mlp model.
    Args:
        path: path to the dataset
        types: list of video types to include
        mode: train, val
        num_test: number of test state-action pairs
        num_trials: number of trials in an episode
        action_range: number of frames to skip; actions are averaged over this many frames (displacement of the agent)
        process_data: whether to process the videos or not (skip if already processed)
        max_len: max number of context state-action pairs
    __getitem__:
        returns: (states, actions, lens, n_nodes)
        states: one batched DGL heterograph per trial
        actions: (max_len, 2) padded action tensor per trial
        lens: number of state-action pairs per trial
        n_nodes: number of nodes per frame graph
    """
    def __init__(
        self,
        path,
        types=None,
        mode="train",
        num_test=1,
        num_trials=9,
        action_range=10,
        process_data=0,
        max_len=30
    ):
        self.path = path
        self.types = types
        self.mode = mode
        self.num_trials = num_trials
        self.num_test = num_test
        self.action_range = action_range
        self.max_len = max_len
        self.ep_combs = self.num_trials * (self.num_trials - 2)  # 9p2 - 9
        self.eps = [[x, y] for x in range(self.num_trials) for y in range(self.num_trials) if x != y]
        types_str = '_'.join(self.types)
        self.rel_deter_func = [
            is_top_adj, is_left_adj, is_top_right_adj, is_top_left_adj,
            is_down_adj, is_right_adj, is_down_left_adj, is_down_right_adj,
            is_left, is_right, is_front, is_back, is_aligned, is_close
        ]
        self.path_list = []
        self.json_list = []
        # get video paths and json file paths
        for t in types:
            print(f'reading files of type {t} in {mode}')
            paths = [os.path.join(self.path, t, x) for x in os.listdir(os.path.join(self.path, t)) if
                     x.endswith('.mp4')]
            jsons = [os.path.join(self.path, t, x) for x in os.listdir(os.path.join(self.path, t)) if
                     x.endswith('.json') and 'index' not in x]
            paths = sorted(paths)
            jsons = sorted(jsons)
            if mode == 'train':
                self.path_list += paths[:int(0.8 * len(jsons))]
                self.json_list += jsons[:int(0.8 * len(jsons))]
            elif mode == 'val':
                self.path_list += paths[int(0.8 * len(jsons)):]
                self.json_list += jsons[int(0.8 * len(jsons)):]
            else:
                self.path_list += paths
                self.json_list += jsons
        self.data_tuples = []
        if process_data:
            # index the videos in the dataset directory. This is done to speed up the retrieval of videos.
            # frame index, action tuples are stored
            self.data_tuples = index_data(self.json_list, self.path_list)
            # tuples of frame index and action (displacement of agent)
            index_dict = {'data_tuples': self.data_tuples}
            with open(os.path.join(self.path, f'jindex_bib_{mode}_{types_str}.json'), 'w') as fp:
                json.dump(index_dict, fp)
        else:
            # read pre-indexed data
            with open(os.path.join(self.path, f'jindex_bib_{mode}_{types_str}.json'), 'r') as fp:
                index_dict = json.load(fp)
            self.data_tuples = index_dict['data_tuples']
        self.tot_trials = len(self.path_list) * 9

    def _get_frame_graph(self, jsonfile, frame_idx):
        # load json
        with open(jsonfile, 'rb') as f:
            frame_data = json.load(f)
        flat_list = [x for xs in frame_data for x in xs]
        # extract entities
        grid_objs = parse_objects(flat_list[frame_idx])
        # --- build the graph
        adj = self._get_spatial_rel(grid_objs)
        # define edges
        is_top_adj_src, is_top_adj_dst = np.nonzero(adj[0])
        is_left_adj_src, is_left_adj_dst = np.nonzero(adj[1])
        is_top_right_adj_src, is_top_right_adj_dst = np.nonzero(adj[2])
        is_top_left_adj_src, is_top_left_adj_dst = np.nonzero(adj[3])
        is_down_adj_src, is_down_adj_dst = np.nonzero(adj[4])
        is_right_adj_src, is_right_adj_dst = np.nonzero(adj[5])
        is_down_left_adj_src, is_down_left_adj_dst = np.nonzero(adj[6])
        is_down_right_adj_src, is_down_right_adj_dst = np.nonzero(adj[7])
        is_left_src, is_left_dst = np.nonzero(adj[8])
        is_right_src, is_right_dst = np.nonzero(adj[9])
        is_front_src, is_front_dst = np.nonzero(adj[10])
        is_back_src, is_back_dst = np.nonzero(adj[11])
        is_aligned_src, is_aligned_dst = np.nonzero(adj[12])
        is_close_src, is_close_dst = np.nonzero(adj[13])
        g = dgl.heterograph({
            ('obj', 'is_top_adj', 'obj'): (torch.tensor(is_top_adj_src), torch.tensor(is_top_adj_dst)),
            ('obj', 'is_left_adj', 'obj'): (torch.tensor(is_left_adj_src), torch.tensor(is_left_adj_dst)),
            ('obj', 'is_top_right_adj', 'obj'): (torch.tensor(is_top_right_adj_src), torch.tensor(is_top_right_adj_dst)),
            ('obj', 'is_top_left_adj', 'obj'): (torch.tensor(is_top_left_adj_src), torch.tensor(is_top_left_adj_dst)),
            ('obj', 'is_down_adj', 'obj'): (torch.tensor(is_down_adj_src), torch.tensor(is_down_adj_dst)),
            ('obj', 'is_right_adj', 'obj'): (torch.tensor(is_right_adj_src), torch.tensor(is_right_adj_dst)),
            ('obj', 'is_down_left_adj', 'obj'): (torch.tensor(is_down_left_adj_src), torch.tensor(is_down_left_adj_dst)),
            ('obj', 'is_down_right_adj', 'obj'): (torch.tensor(is_down_right_adj_src), torch.tensor(is_down_right_adj_dst)),
            ('obj', 'is_left', 'obj'): (torch.tensor(is_left_src), torch.tensor(is_left_dst)),
            ('obj', 'is_right', 'obj'): (torch.tensor(is_right_src), torch.tensor(is_right_dst)),
            ('obj', 'is_front', 'obj'): (torch.tensor(is_front_src), torch.tensor(is_front_dst)),
            ('obj', 'is_back', 'obj'): (torch.tensor(is_back_src), torch.tensor(is_back_dst)),
            ('obj', 'is_aligned', 'obj'): (torch.tensor(is_aligned_src), torch.tensor(is_aligned_dst)),
            ('obj', 'is_close', 'obj'): (torch.tensor(is_close_src), torch.tensor(is_close_dst))
        }, num_nodes_dict={'obj': len(grid_objs)})
        g = self._add_node_features(grid_objs, g)
        return g

    def _add_node_features(self, objs, graph):
        for obj_idx, obj in enumerate(objs):
            graph.nodes[obj_idx].data['type'] = torch.tensor(obj.type)
            graph.nodes[obj_idx].data['pos'] = torch.tensor([[obj.x, obj.y]], dtype=torch.float32)
            assert len(obj.attributes) == 2 and None not in obj.attributes[0] and None not in obj.attributes[1]
            graph.nodes[obj_idx].data['color'] = torch.tensor([obj.attributes[0]])
            graph.nodes[obj_idx].data['shape'] = torch.tensor([obj.attributes[1]])
        return graph

    def _get_spatial_rel(self, objs):
        spatial_tensors = [np.zeros([len(objs), len(objs)]) for _ in range(len(self.rel_deter_func))]
        for obj_idx1, obj1 in enumerate(objs):
            for obj_idx2, obj2 in enumerate(objs):
                direction_vec = np.array((0, -1))  # fixed 'up' reference direction for the directional relations
                for rel_idx, func in enumerate(self.rel_deter_func):
                    if func(obj1, obj2, direction_vec):
                        spatial_tensors[rel_idx][obj_idx1, obj_idx2] = 1.0
        return spatial_tensors

    def get_trial(self, trials, step=10):
        # retrieve state embeddings and actions from cached file
        states = []
        actions = []
        trial_len = []
        lens = []
        n_nodes = []
        # one entry per trial
        for t in trials:
            tl = [(t, n) for n in range(0, len(self.data_tuples[t]), step)]
            if len(tl) > self.max_len:  # 30
                tl = tl[:self.max_len]
            trial_len.append(tl)
        for tl in trial_len:
            states.append([])
            actions.append([])
            lens.append(len(tl))
            for t, n in tl:
                video = self.data_tuples[t][n][0]
                states[-1].append(self._get_frame_graph(video, self.data_tuples[t][n][1]))
                n_nodes.append(states[-1][-1].number_of_nodes())
                # actions are pooled over frames
                if len(self.data_tuples[t]) > n + self.action_range:
                    actions_xy = [d[2] for d in self.data_tuples[t][n:n + self.action_range]]
                else:
                    actions_xy = [d[2] for d in self.data_tuples[t][n:]]
                actions_xy = np.array(actions_xy)
                actions_xy = np.mean(actions_xy, axis=0)
                actions[-1].append(actions_xy)
            states[-1] = dgl.batch(states[-1])
            actions[-1] = torch.tensor(np.array(actions[-1]))
            trial_actions_padded = torch.zeros(self.max_len, actions[-1].size(1))
            trial_actions_padded[:actions[-1].size(0), :] = actions[-1]
            actions[-1] = trial_actions_padded
        return states, actions, lens, n_nodes

    def __getitem__(self, idx):
        ep_trials = [idx * self.num_trials + t for t in range(self.num_trials)]  # [idx*9, ..., idx*9 + 8]
        states, actions, lens, n_nodes = self.get_trial(ep_trials, step=self.action_range)
        return states, actions, lens, n_nodes

    def __len__(self):
        return self.tot_trials // self.num_trials


class TestTransitionDatasetSequence(torch.utils.data.Dataset):
    """
    Test dataset class for the behavior cloning rnn model. This dataset is used to test the model on the eval data.
    This class is used to compare plausible and implausible episodes.
    Args:
        path: path to the dataset
        task_type: video type to include
        mode: test
        num_test: number of test state-action pairs
        num_trials: number of trials in an episode
        action_range: number of frames to skip; actions are averaged over this many frames (displacement of the agent)
        process_data: whether to process the videos or not (skip if already processed)
        max_len: max number of context state-action pairs
    __getitem__:
        returns: (dem_expected_states, dem_expected_actions, dem_expected_lens, dem_expected_nodes,
                  dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, dem_unexpected_nodes,
                  query_expected_frames, target_expected_actions,
                  query_unexpected_frames, target_unexpected_actions)
    """
    def __init__(
        self,
        path,
        task_type=None,
        mode="test",
        num_test=1,
        num_trials=9,
        action_range=10,
        process_data=0,
        max_len=30
    ):
        self.path = path
        self.task_type = task_type
        self.mode = mode
        self.num_trials = num_trials
        self.num_test = num_test
        self.action_range = action_range
        self.max_len = max_len
        self.ep_combs = self.num_trials * (self.num_trials - 2)  # 9p2 - 9
        self.eps = [[x, y] for x in range(self.num_trials) for y in range(self.num_trials) if x != y]
        self.path_list_exp = []
        self.json_list_exp = []
        self.path_list_un = []
        self.json_list_un = []

        print(f'reading files of type {task_type} in {mode}')

        paths_expected = sorted([os.path.join(self.path, task_type, x) for x in os.listdir(os.path.join(self.path, task_type)) if
                                 x.endswith('e.mp4')])
        jsons_expected = sorted([os.path.join(self.path, task_type, x) for x in os.listdir(os.path.join(self.path, task_type)) if
                                 x.endswith('e.json') and 'index' not in x])
        paths_unexpected = sorted([os.path.join(self.path, task_type, x) for x in os.listdir(os.path.join(self.path, task_type)) if
                                   x.endswith('u.mp4')])
        jsons_unexpected = sorted([os.path.join(self.path, task_type, x) for x in os.listdir(os.path.join(self.path, task_type)) if
                                   x.endswith('u.json') and 'index' not in x])
        self.path_list_exp += paths_expected
        self.json_list_exp += jsons_expected
        self.path_list_un += paths_unexpected
        self.json_list_un += jsons_unexpected
        self.data_expected = []
        self.data_unexpected = []

        if process_data:
            # index data. This is done to speed up video retrieval.
            # frame index, action tuples are stored
            self.data_expected = index_data(self.json_list_exp, self.path_list_exp)
            index_dict = {'data_tuples': self.data_expected}
            with open(os.path.join(self.path, f'jindex_bib_test_{task_type}e.json'), 'w') as fp:
                json.dump(index_dict, fp)

            self.data_unexpected = index_data(self.json_list_un, self.path_list_un)
            index_dict = {'data_tuples': self.data_unexpected}
            with open(os.path.join(self.path, f'jindex_bib_test_{task_type}u.json'), 'w') as fp:
                json.dump(index_dict, fp)
        else:
            with open(os.path.join(self.path, f'jindex_bib_{mode}_{task_type}e.json'), 'r') as fp:
                index_dict = json.load(fp)
            self.data_expected = index_dict['data_tuples']
            with open(os.path.join(self.path, f'jindex_bib_{mode}_{task_type}u.json'), 'r') as fp:
                index_dict = json.load(fp)
            self.data_unexpected = index_dict['data_tuples']

        self.rel_deter_func = [
            is_top_adj, is_left_adj, is_top_right_adj, is_top_left_adj,
            is_down_adj, is_right_adj, is_down_left_adj, is_down_right_adj,
            is_left, is_right, is_front, is_back, is_aligned, is_close
        ]

        print('Done.')

    def _get_frame_graph(self, jsonfile, frame_idx):
        # load json
        with open(jsonfile, 'rb') as f:
            frame_data = json.load(f)
        flat_list = [x for xs in frame_data for x in xs]
        # extract entities
        grid_objs = parse_objects(flat_list[frame_idx])
        # --- build the graph
        adj = self._get_spatial_rel(grid_objs)
        # define edges
        is_top_adj_src, is_top_adj_dst = np.nonzero(adj[0])
        is_left_adj_src, is_left_adj_dst = np.nonzero(adj[1])
        is_top_right_adj_src, is_top_right_adj_dst = np.nonzero(adj[2])
        is_top_left_adj_src, is_top_left_adj_dst = np.nonzero(adj[3])
        is_down_adj_src, is_down_adj_dst = np.nonzero(adj[4])
        is_right_adj_src, is_right_adj_dst = np.nonzero(adj[5])
        is_down_left_adj_src, is_down_left_adj_dst = np.nonzero(adj[6])
        is_down_right_adj_src, is_down_right_adj_dst = np.nonzero(adj[7])
        is_left_src, is_left_dst = np.nonzero(adj[8])
        is_right_src, is_right_dst = np.nonzero(adj[9])
        is_front_src, is_front_dst = np.nonzero(adj[10])
        is_back_src, is_back_dst = np.nonzero(adj[11])
        is_aligned_src, is_aligned_dst = np.nonzero(adj[12])
        is_close_src, is_close_dst = np.nonzero(adj[13])
        g = dgl.heterograph({
            ('obj', 'is_top_adj', 'obj'): (torch.tensor(is_top_adj_src), torch.tensor(is_top_adj_dst)),
            ('obj', 'is_left_adj', 'obj'): (torch.tensor(is_left_adj_src), torch.tensor(is_left_adj_dst)),
            ('obj', 'is_top_right_adj', 'obj'): (torch.tensor(is_top_right_adj_src), torch.tensor(is_top_right_adj_dst)),
            ('obj', 'is_top_left_adj', 'obj'): (torch.tensor(is_top_left_adj_src), torch.tensor(is_top_left_adj_dst)),
            ('obj', 'is_down_adj', 'obj'): (torch.tensor(is_down_adj_src), torch.tensor(is_down_adj_dst)),
            ('obj', 'is_right_adj', 'obj'): (torch.tensor(is_right_adj_src), torch.tensor(is_right_adj_dst)),
            ('obj', 'is_down_left_adj', 'obj'): (torch.tensor(is_down_left_adj_src), torch.tensor(is_down_left_adj_dst)),
            ('obj', 'is_down_right_adj', 'obj'): (torch.tensor(is_down_right_adj_src), torch.tensor(is_down_right_adj_dst)),
            ('obj', 'is_left', 'obj'): (torch.tensor(is_left_src), torch.tensor(is_left_dst)),
            ('obj', 'is_right', 'obj'): (torch.tensor(is_right_src), torch.tensor(is_right_dst)),
            ('obj', 'is_front', 'obj'): (torch.tensor(is_front_src), torch.tensor(is_front_dst)),
            ('obj', 'is_back', 'obj'): (torch.tensor(is_back_src), torch.tensor(is_back_dst)),
            ('obj', 'is_aligned', 'obj'): (torch.tensor(is_aligned_src), torch.tensor(is_aligned_dst)),
            ('obj', 'is_close', 'obj'): (torch.tensor(is_close_src), torch.tensor(is_close_dst))
        }, num_nodes_dict={'obj': len(grid_objs)})
        g = self._add_node_features(grid_objs, g)
        return g

    def _add_node_features(self, objs, graph):
        for obj_idx, obj in enumerate(objs):
            graph.nodes[obj_idx].data['type'] = torch.tensor(obj.type)
            graph.nodes[obj_idx].data['pos'] = torch.tensor([[obj.x, obj.y]], dtype=torch.float32)
            assert len(obj.attributes) == 2 and None not in obj.attributes[0] and None not in obj.attributes[1]
            graph.nodes[obj_idx].data['color'] = torch.tensor([obj.attributes[0]])
            graph.nodes[obj_idx].data['shape'] = torch.tensor([obj.attributes[1]])
        return graph

    def _get_spatial_rel(self, objs):
        spatial_tensors = [np.zeros([len(objs), len(objs)]) for _ in range(len(self.rel_deter_func))]
        for obj_idx1, obj1 in enumerate(objs):
            for obj_idx2, obj2 in enumerate(objs):
                direction_vec = np.array((0, -1))  # fixed 'up' reference direction for the directional relations
                for rel_idx, func in enumerate(self.rel_deter_func):
                    if func(obj1, obj2, direction_vec):
                        spatial_tensors[rel_idx][obj_idx1, obj_idx2] = 1.0
        return spatial_tensors

    def get_trial(self, trials, data, step=10):
        # retrieve state embeddings and actions from cached file
        states = []
        actions = []
        trial_len = []
        lens = []
        n_nodes = []
        for t in trials:
            tl = [(t, n) for n in range(0, len(data[t]), step)]
            if len(tl) > self.max_len:
                tl = tl[:self.max_len]
            trial_len.append(tl)
        for tl in trial_len:
            states.append([])
            actions.append([])
            lens.append(len(tl))
            for t, n in tl:
                video = data[t][n][0]
                states[-1].append(self._get_frame_graph(video, data[t][n][1]))
                n_nodes.append(states[-1][-1].number_of_nodes())
                if len(data[t]) > n + self.action_range:
                    actions_xy = [d[2] for d in data[t][n:n + self.action_range]]
                else:
                    actions_xy = [d[2] for d in data[t][n:]]
                actions_xy = np.array(actions_xy)
                actions_xy = np.mean(actions_xy, axis=0)
                actions[-1].append(actions_xy)
            states[-1] = dgl.batch(states[-1])
            actions[-1] = torch.tensor(np.array(actions[-1]))
            trial_actions_padded = torch.zeros(self.max_len, actions[-1].size(1))
            trial_actions_padded[:actions[-1].size(0), :] = actions[-1]
            actions[-1] = trial_actions_padded
        return states, actions, lens, n_nodes

    def get_test(self, trial, data, step=10):
        # retrieve state embeddings and actions from cached file
        states = []
        actions = []
        trial_len = []
        trial_len += [(trial, n) for n in range(0, len(data[trial]), step)]
        for t, n in trial_len:
            video = data[t][n][0]
            state = self._get_frame_graph(video, data[t][n][1])
            if len(data[t]) > n + self.action_range:
                actions_xy = [d[2] for d in data[t][n:n + self.action_range]]
            else:
                actions_xy = [d[2] for d in data[t][n:]]
            actions_xy = np.array(actions_xy)
            actions_xy = np.mean(actions_xy, axis=0)
            actions.append(actions_xy)
            states.append(state)
        #states = torch.stack(states)
        states = dgl.batch(states)
        actions = torch.tensor(np.array(actions))
        return states, actions

    def __getitem__(self, idx):
        ep_trials = [idx * self.num_trials + t for t in range(self.num_trials)]
        dem_expected_states, dem_expected_actions, dem_expected_lens, dem_expected_nodes = self.get_trial(
            ep_trials[:-1], self.data_expected, step=self.action_range
        )
        dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, dem_unexpected_nodes = self.get_trial(
            ep_trials[:-1], self.data_unexpected, step=self.action_range
        )
        query_expected_frames, target_expected_actions = self.get_test(
            ep_trials[-1], self.data_expected, step=self.action_range
        )
        query_unexpected_frames, target_unexpected_actions = self.get_test(
            ep_trials[-1], self.data_unexpected, step=self.action_range
        )
        return dem_expected_states, dem_expected_actions, dem_expected_lens, dem_expected_nodes, \
               dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, dem_unexpected_nodes, \
               query_expected_frames, target_expected_actions, \
               query_unexpected_frames, target_unexpected_actions

    def __len__(self):
        return len(self.path_list_exp)


if __name__ == '__main__':
    types = ['preference', 'multi_agent', 'inaccessible_goal',
             'efficiency_irrational', 'efficiency_time', 'efficiency_path',
             'instrumental_no_barrier', 'instrumental_blocking_barrier', 'instrumental_inconsequential_barrier']
    for t in types:
        ttd = TestTransitionDatasetSequence(path='/datasets/external/bib_evaluation_1_1/', task_type=t, process_data=0, mode='test')
        for i in range(len(ttd)):
            print(i, end='\r')
            dem_expected_states, dem_expected_actions, dem_expected_lens, dem_expected_nodes, \
            dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, dem_unexpected_nodes, \
            query_expected_frames, target_expected_actions, \
            query_unexpected_frames, target_unexpected_actions = ttd[i]
            # sanity check: the agent one-hot (entity id 0) should appear in every graph
            for j in range(8):
                if not torch.tensor([1., 0., 0., 0., 0., 0., 0., 0., 0.]) in dem_expected_states[j].ndata['type']:
                    print(i)
                if not torch.tensor([1., 0., 0., 0., 0., 0., 0., 0., 0.]) in dem_unexpected_states[j].ndata['type']:
                    print(i)
            if not torch.tensor([1., 0., 0., 0., 0., 0., 0., 0., 0.]) in query_expected_frames.ndata['type']:
                print(i)
            if not torch.tensor([1., 0., 0., 0., 0., 0., 0., 0., 0.]) in query_unexpected_frames.ndata['type']:
                print(i)
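The 14 relation functions in rel_deter_func line up one-to-one with the 14 edge types passed to dgl.heterograph, so each adjacency matrix returned by _get_spatial_rel becomes one edge type. A more compact but equivalent construction (a sketch, not part of the diff; the relation names are those defined above):

import dgl
import torch
import numpy as np

rel_names = ['is_top_adj', 'is_left_adj', 'is_top_right_adj', 'is_top_left_adj',
             'is_down_adj', 'is_right_adj', 'is_down_left_adj', 'is_down_right_adj',
             'is_left', 'is_right', 'is_front', 'is_back', 'is_aligned', 'is_close']

def build_graph(adj, n_objs):
    # one edge type per relation; adj[i] is the (n_objs, n_objs) 0/1 matrix for relation i
    graph_data = {}
    for name, a in zip(rel_names, adj):
        src, dst = np.nonzero(a)
        graph_data[('obj', name, 'obj')] = (torch.tensor(src), torch.tensor(dst))
    return dgl.heterograph(graph_data, num_nodes_dict={'obj': n_objs})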
174
utils/grid_object.py
Normal file
@@ -0,0 +1,174 @@
import json
import pdb
import numpy as np
from sklearn.preprocessing import OneHotEncoder
import itertools


SHAPES = {
    # walls
    'square': 0,
    # objects
    'heart': 1, 'clove_tree': 2, 'moon': 3, 'wine': 4, 'double_dia': 5,
    'flag': 6, 'capsule': 7, 'vase': 8, 'curved_triangle': 9, 'spoon': 10,
    'medal': 11,  # inaccessible goal
    # home
    'home': 12,
    # agent
    'pentagon': 13, 'clove': 14, 'kite': 15,
    # key
    'triangle0': 16, 'triangle180': 16, 'triangle90': 16, 'triangle270': 16,
    # lock
    'triangle_slot0': 17, 'triangle_slot180': 17, 'triangle_slot90': 17, 'triangle_slot270': 17
}

ENTITIES = {
    'agent': 0, 'walls': 1, 'fuse_walls': 2, 'key': 3,
    'blocking': 4, 'home': 5, 'objects': 6, 'pin': 7, 'lock': 8
}

# =============================== GridObject class ===============================

class GridObject():
    "an object is specified by its location, type and attributes"
    def __init__(self, x, y, object_type, attributes=None):
        self.x = x
        self.y = y
        self.type = object_type
        self.attributes = attributes if attributes is not None else []

    @property
    def pos(self):
        return np.array([self.x, self.y])

    @property
    def name(self):
        return {'type': str(self.type),
                'x': str(self.x),
                'y': str(self.y),
                'color': self.attributes[0],
                'shape': self.attributes[1]}

# =============================== Helper functions ===============================

def type2index(key):
    for name, idx in ENTITIES.items():
        if name == key:
            return idx

def find_shape(shape_string, print_shape=False):
    try:
        shape = shape_string.split('/')[-1].split('.')[0]
    except:
        shape = shape_string.split('.')[0]
    if print_shape: print(shape)
    for name, idx in SHAPES.items():
        if name == shape:
            return idx

def parse_objects(frame):
    """
    x and y are computed differently for walls and objects:
    for walls: x, y = obj[0][0] + obj[1][0]/2, obj[0][1] + obj[1][1]/2
    for objects: x, y = obj[0][0] + obj[1], obj[0][1] + obj[1]
    :param frame: dict of entities in one frame
    :return: list of GridObject
    """
    shape_onehot_encoder = OneHotEncoder(sparse=False)
    shape_onehot_encoder.fit([[i] for i in range(len(SHAPES)-6)])  # 18 distinct shape ids (triangle variants share ids)
    type_onehot_encoder = OneHotEncoder(sparse=False)
    type_onehot_encoder.fit([[i] for i in range(len(ENTITIES))])
    # remove duplicate walls
    frame['walls'].sort()
    frame['walls'] = list(k for k, _ in itertools.groupby(frame['walls']))
    # remove boundary walls
    frame['walls'] = [w for w in frame['walls'] if (w[0][0] != 0 and w[0][0] != 180 and w[0][1] != 0 and w[0][1] != 180)]
    # remove duplicate fuse_walls
    frame['fuse_walls'].sort()
    frame['fuse_walls'] = list(k for k, _ in itertools.groupby(frame['fuse_walls']))
    grid_objs = []
    assert 'agent' in frame.keys()
    for key in frame.keys():
        if key == 'size':
            continue
        obj = frame[key]
        if obj == []:
            continue
        obj_type = type2index(key)
        obj_type = type_onehot_encoder.transform([[obj_type]])
        if key == 'walls':
            for wall in obj:
                x, y = wall[0][0] + wall[1][0]/2, wall[0][1] + wall[1][1]/2
                #x, y = (wall[0][0] + wall[1][0]/2)/200, (wall[0][1] + wall[1][1]/2)/200  # if you use this you need to change relations.py!
                color = [0, 0, 0] if key == 'walls' else [80, 146, 56]
                #color = [c / 255 for c in color]
                shape = 0
                assert shape in SHAPES.values(), 'Shape not found'
                shape = shape_onehot_encoder.transform([[shape]])[0]
                grid_obj = GridObject(x=x, y=y, object_type=obj_type, attributes=[color, shape])
                grid_objs.append(grid_obj)
        elif key == 'fuse_walls':
            # resample green barriers
            obj = [obj[i] for i in range(len(obj)) if (obj[i][0][0] % 20 == 0 and obj[i][0][1] % 20 == 0)]
            for wall in obj:
                x, y = wall[0][0] + wall[1][0]/2, wall[0][1] + wall[1][1]/2
                #x, y = (wall[0][0] + wall[1][0]/2)/200, (wall[0][1] + wall[1][1]/2)/200  # if you use this you need to change relations.py!
                color = [80, 146, 56]
                #color = [c / 255 for c in color]
                shape = 0
                assert shape in SHAPES.values(), 'Shape not found'
                shape = shape_onehot_encoder.transform([[shape]])[0]
                grid_obj = GridObject(x=x, y=y, object_type=obj_type, attributes=[color, shape])
                grid_objs.append(grid_obj)
        elif key == 'objects':
            for ob in obj:
                x, y = ob[0][0] + ob[1], ob[0][1] + ob[1]
                color = ob[-1]
                #color = [c / 255 for c in color]
                shape = find_shape(ob[2], print_shape=False)
                assert shape in SHAPES.values(), 'Shape not found'
                shape = shape_onehot_encoder.transform([[shape]])[0]
                grid_obj = GridObject(x=x, y=y, object_type=obj_type, attributes=[color, shape])
                grid_objs.append(grid_obj)
        elif key == 'key':
            obj = obj[0]
            x, y = obj[0][0] + obj[1], obj[0][1] + obj[1]
            color = obj[-1]
            #color = [c / 255 for c in color]
            shape = find_shape(obj[2], print_shape=False)
            assert shape in SHAPES.values(), 'Shape not found'
            shape = shape_onehot_encoder.transform([[shape]])[0]
            grid_obj = GridObject(x=x, y=y, object_type=obj_type, attributes=[color, shape])
            grid_objs.append(grid_obj)
        elif key == 'lock':
            obj = obj[0]
            x, y = obj[0][0] + obj[1], obj[0][1] + obj[1]
            color = obj[-1]
            #color = [c / 255 for c in color]
            shape = find_shape(obj[2], print_shape=False)
            assert shape in SHAPES.values(), 'Shape not found'
            shape = shape_onehot_encoder.transform([[shape]])[0]
            grid_obj = GridObject(x=x, y=y, object_type=obj_type, attributes=[color, shape])
            grid_objs.append(grid_obj)
        else:
            try:
                x, y = obj[0][0] + obj[1], obj[0][1] + obj[1]
                color = obj[-1]
                #color = [c / 255 for c in color]
                shape = find_shape(obj[2], print_shape=False)
                assert shape in SHAPES.values(), 'Shape not found'
                shape = shape_onehot_encoder.transform([[shape]])[0]
            except:
                # [[[x, y], extension, shape, color]] in some cases in instrumental_no_barrier (bib_evaluation_1_1)
                x, y = obj[0][0][0] + obj[0][1], obj[0][0][1] + obj[0][1]
                color = obj[0][-1]
                #color = [c / 255 for c in color]
                assert len(color) == 3
                shape = find_shape(obj[0][2], print_shape=False)
                assert shape in SHAPES.values(), 'Shape not found'
                shape = shape_onehot_encoder.transform([[shape]])[0]
            grid_obj = GridObject(x=x, y=y, object_type=obj_type, attributes=[color, shape])
            grid_objs.append(grid_obj)
    return grid_objs
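A note on the encoder sizes in parse_objects (an observation, not part of the diff): len(SHAPES) - 6 equals 18 because the four triangle* keys and the four triangle_slot* keys collapse onto just two shape ids (16 and 17), so the shape one-hot spans ids 0-17 while the type one-hot spans the 9 entity classes:

# sketch: the feature dimensions implied by the dicts above
n_shape_ids = len(set(SHAPES.values()))   # 18 == len(SHAPES) - 6
n_type_ids = len(ENTITIES)                # 9; each node carries a 9-dim type and an 18-dim shape one-hot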
124
utils/index_data.py
Normal file
@@ -0,0 +1,124 @@
import json
import os
import torch
import torch.utils.data
from tqdm import tqdm


def index_data(json_list, path_list):
    print(f'processing {len(json_list)} files')
    data_tuples = []
    for j, v in tqdm(zip(json_list, path_list)):
        with open(j, 'r') as f:
            state = json.load(f)
        ep_lens = [len(x) for x in state]
        past_len = 0
        for e, l in enumerate(ep_lens):
            data_tuples.append([])
            # skip first 30 frames and last 83 frames
            for f in range(30, l - 83):
                # find the action taken between frames f and f+1
                f0x, f0y = state[e][f]['agent'][0]
                f1x, f1y = state[e][f + 1]['agent'][0]
                dx = (f1x - f0x) / 2.
                dy = (f1y - f0y) / 2.
                action = [dx, dy]
                #data_tuples[-1].append((v, past_len + f, action))
                data_tuples[-1].append((j, past_len + f, action))
            # data_tuples = (json file, frame number, action)
            assert len(data_tuples[-1]) > 0
            past_len += l
    return data_tuples


class TransitionDataset(torch.utils.data.Dataset):
    """
    Training dataset class for the behavior cloning mlp model.
    Only used here to build the JSON index; __getitem__ is a stub.
    Args:
        path: path to the dataset
        types: list of video types to include
        size: size of the frames to be returned
        mode: train, val
        num_context: number of context state-action pairs
        num_test: number of test state-action pairs
        num_trials: number of trials in an episode
        action_range: number of frames to skip; actions are averaged over this many frames (displacement of the agent)
        process_data: whether to process the videos or not (skip if already processed)
    """
    def __init__(self, path, types=None, size=None, mode="train", num_context=30, num_test=1, num_trials=9,
                 action_range=10, process_data=0):

        self.path = path
        self.types = types
        self.size = size
        self.mode = mode
        self.num_trials = num_trials
        self.num_context = num_context
        self.num_test = num_test
        self.action_range = action_range
        self.ep_combs = self.num_trials * (self.num_trials - 2)  # 9p2 - 9
        self.eps = [[x, y] for x in range(self.num_trials) for y in range(self.num_trials) if x != y]
        types_str = '_'.join(self.types)

        self.path_list = []
        self.json_list = []
        # get video paths and json file paths
        for t in types:
            print(f'reading files of type {t} in {mode}')
            paths = [os.path.join(self.path, t, x) for x in os.listdir(os.path.join(self.path, t)) if
                     x.endswith('.mp4')]
            jsons = [os.path.join(self.path, t, x) for x in os.listdir(os.path.join(self.path, t)) if
                     x.endswith('.json') and 'index' not in x]

            paths = sorted(paths)
            jsons = sorted(jsons)

            if mode == 'train':
                self.path_list += paths[:int(0.8 * len(jsons))]
                self.json_list += jsons[:int(0.8 * len(jsons))]
            elif mode == 'val':
                self.path_list += paths[int(0.8 * len(jsons)):]
                self.json_list += jsons[int(0.8 * len(jsons)):]
            else:
                self.path_list += paths
                self.json_list += jsons

        self.data_tuples = []
        if process_data:
            # index the videos in the dataset directory. This is done to speed up the retrieval of videos.
            # frame index, action tuples are stored
            self.data_tuples = index_data(self.json_list, self.path_list)
            # tuples of frame index and action (displacement of agent)
            index_dict = {'data_tuples': self.data_tuples}
            with open(os.path.join(self.path, f'jindex_bib_{mode}_{types_str}.json'), 'w') as fp:
                json.dump(index_dict, fp)
        else:
            # read pre-indexed data
            with open(os.path.join(self.path, f'jindex_bib_{mode}_{types_str}.json'), 'r') as fp:
                index_dict = json.load(fp)
            self.data_tuples = index_dict['data_tuples']

        self.tot_trials = len(self.path_list) * 9

    def __getitem__(self, idx):
        # indexing-only dataset: items are never retrieved
        print('Empty')
        return

    def __len__(self):
        return self.tot_trials // self.num_trials


if __name__ == "__main__":
    dataset = TransitionDataset(path='/datasets/external/bib_train/',
                                types=['multi_agent', 'instrumental_action'],  # ['instrumental_action', 'multi_agent', 'preference', 'single_object'],
                                size=(84, 84),
                                mode="train", num_context=30,
                                num_test=1, num_trials=9,
                                action_range=10, process_data=1)
    print(len(dataset))
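Since index_data.py only writes the JSON index, the file it produces can be inspected directly. A sketch of reading it back (the filename follows the jindex_bib_{mode}_{types_str}.json pattern used above; this particular path is hypothetical):

import json

with open('/datasets/external/bib_train/jindex_bib_train_multi_agent_instrumental_action.json') as fp:
    data_tuples = json.load(fp)['data_tuples']
# data_tuples[t] lists [json_path, frame_number, [dx, dy]] entries for trial t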
116
utils/relations.py
Normal file
@@ -0,0 +1,116 @@
import numpy as np


# =============================== relationships to build the graph ===============================

def rotate_vec2d(vec, degrees):
    """
    rotate a vector anti-clockwise
    :param vec: 2d vector
    :param degrees: rotation angle in degrees
    :return: rotated vector
    """
    theta = np.radians(degrees)
    c, s = np.cos(theta), np.sin(theta)
    R = np.array(((c, -s), (s, c)))
    return R @ vec

# ---------- Remote Directional Relations --------------------------------------------------------

def is_front(obj1, obj2, direction_vec) -> bool:
    diff = obj2.pos - obj1.pos
    return diff @ direction_vec > 0.1

def is_back(obj1, obj2, direction_vec) -> bool:
    diff = obj2.pos - obj1.pos
    return diff @ direction_vec < -0.1

def is_left(obj1, obj2, direction_vec) -> bool:
    left_vec = rotate_vec2d(direction_vec, -90)
    diff = obj2.pos - obj1.pos
    return diff @ left_vec > 0.1

def is_right(obj1, obj2, direction_vec) -> bool:
    right_vec = rotate_vec2d(direction_vec, 90)
    diff = obj2.pos - obj1.pos
    return diff @ right_vec > 0.1

# ---------- Alignment and Adjacency Relations ---------------------------------------------------

def is_close(obj1, obj2, direction_vec=None) -> bool:
    # indicates whether two objects are adjacent to each other;
    # unlike the local directional relations, this carries no directional information
    distance = np.abs(obj1.pos - obj2.pos)
    return np.sum(distance) == 20

def is_aligned(obj1, obj2, direction_vec=None) -> bool:
    # indicates whether two entities are on the same horizontal or vertical line
    diff = obj2.pos - obj1.pos
    return np.any(diff == 0)

# ---------- Local Directional Relations ---------------------------------------------------------

def is_top_adj(obj1, obj2, direction_vec=None) -> bool:
    return obj1.x == obj2.x and obj1.y == obj2.y + 20

def is_left_adj(obj1, obj2, direction_vec=None) -> bool:
    return obj1.y == obj2.y and obj1.x == obj2.x - 20

def is_top_left_adj(obj1, obj2, direction_vec=None) -> bool:
    return obj1.y == obj2.y + 20 and obj1.x == obj2.x - 20

def is_top_right_adj(obj1, obj2, direction_vec=None) -> bool:
    return obj1.y == obj2.y + 20 and obj1.x == obj2.x + 20

def is_down_adj(obj1, obj2, direction_vec=None) -> bool:
    return is_top_adj(obj2, obj1)

def is_right_adj(obj1, obj2, direction_vec=None) -> bool:
    return is_left_adj(obj2, obj1)

def is_down_right_adj(obj1, obj2, direction_vec=None) -> bool:
    return is_top_left_adj(obj2, obj1)

def is_down_left_adj(obj1, obj2, direction_vec=None) -> bool:
    return is_top_right_adj(obj2, obj1)

# ---------- More Remote Directional Relations (not used) ----------------------------------------

def top_left(obj1, obj2, direction_vec) -> bool:
    return (obj1.x - obj2.x) <= (obj1.y - obj2.y)

def top_right(obj1, obj2, direction_vec) -> bool:
    return -(obj1.x - obj2.x) <= (obj1.y - obj2.y)

def down_left(obj1, obj2, direction_vec) -> bool:
    return top_right(obj2, obj1, direction_vec)

def down_right(obj1, obj2, direction_vec) -> bool:
    return top_left(obj2, obj1, direction_vec)

def fan_top(obj1, obj2, direction_vec) -> bool:
    top_left = (obj1.x - obj2.x) <= (obj1.y - obj2.y)
    top_right = -(obj1.x - obj2.x) <= (obj1.y - obj2.y)
    return top_left and top_right

def fan_down(obj1, obj2, direction_vec) -> bool:
    return fan_top(obj2, obj1, direction_vec)

def fan_right(obj1, obj2, direction_vec) -> bool:
    down_left = (obj1.x - obj2.x) >= (obj1.y - obj2.y)
    top_right = -(obj1.x - obj2.x) <= (obj1.y - obj2.y)
    return down_left and top_right

def fan_left(obj1, obj2, direction_vec) -> bool:
    return fan_right(obj2, obj1, direction_vec)

# ---------- Ad-hoc Relations --------------------------------------------------------------------

def needs(obj1, obj2, direction_vec=None) -> bool:
    # agent (entity id 0) needs key (entity id 3)
    return np.argmax(obj1.type) == 0 and np.argmax(obj2.type) == 3

def opens(obj1, obj2, direction_vec=None) -> bool:
    # key (entity id 3) opens lock (entity id 8)
    return np.argmax(obj1.type) == 3 and np.argmax(obj2.type) == 8

def collects(obj1, obj2, direction_vec=None) -> bool:
    # agent (entity id 0) collects object (entity id 6)
    return np.argmax(obj1.type) == 0 and np.argmax(obj2.type) == 6
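A quick sanity check of the rotation convention behind is_left and is_right (a sketch; with the y-down screen coordinates these files use, 'up' is (0, -1) as the direction_vec comments in dataset.py indicate):

import numpy as np

up = np.array((0, -1))
print(rotate_vec2d(up, -90))   # ~(-1, 0): the left_vec used by is_left
print(rotate_vec2d(up, 90))    # ~(1, 0): the right_vec used by is_right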
1
utils/run_build_graphs.sh
Normal file
@@ -0,0 +1 @@
python build_graphs.py --mode train --cpus 30 && python build_graphs.py --mode val --cpus 30