import dgl
import torch
import torch.utils.data
import os
import pickle as pkl
import json
import numpy as np
from tqdm import tqdm

import sys
sys.path.append('/projects/bortoletto/irene/')
from utils.grid_object import *
from utils.relations import *


# ========================== Helper functions ==========================

def index_data(json_list, path_list):
    print(f'processing {len(json_list)} files')
    data_tuples = []
    for j, v in tqdm(zip(json_list, path_list)):
        with open(j, 'r') as f:
            state = json.load(f)
        ep_lens = [len(x) for x in state]
        past_len = 0
        for e, l in enumerate(ep_lens):
            data_tuples.append([])
            # skip the first 30 frames and the last 83 frames
            for f in range(30, l - 83):
                # find the action taken between frames f and f + 1
                f0x, f0y = state[e][f]['agent'][0]
                f1x, f1y = state[e][f + 1]['agent'][0]
                dx = (f1x - f0x) / 2.
                dy = (f1y - f0y) / 2.
                action = [dx, dy]
                #data_tuples[-1].append((v, past_len + f, action))
                data_tuples[-1].append((j, past_len + f, action))
            # data_tuples = [[(json file, frame number, action), ...] for each episode in each video]
            assert len(data_tuples[-1]) > 0
            past_len += l
    return data_tuples
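
# Illustrative note (not part of the original code): index_data returns one list per
# episode, where each element is a (json_path, global_frame_index, [dx, dy]) tuple.
# A minimal sketch of building an index by hand, with hypothetical paths (requires
# `import glob`):
#
#   jsons = sorted(glob.glob('/path/to/train/preference/*.json'))  # hypothetical path
#   mp4s = sorted(glob.glob('/path/to/train/preference/*.mp4'))    # hypothetical path
#   data_tuples = index_data(jsons, mp4s)
#   data_tuples[0][0]  # -> ('/path/to/train/preference/....json', 30, [dx, dy])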


# ========================== Dataset class ==========================

class TransitionDataset(torch.utils.data.Dataset):
    """
    Training dataset class for the behavior cloning MLP model.
    Args:
        path: path to the dataset
        types: list of video types to include
        mode: train, val
        num_test: number of test state-action pairs
        num_trials: number of trials in an episode
        action_range: number of frames to skip; actions are averaged over this many frames (displacement of the agent)
        process_data: whether to index the videos (skip if already processed)
        max_len: max number of context state-action pairs
    __getitem__:
        returns: (states, actions, lens, n_nodes)
        states: list of batched DGLGraph heterographs, one per trial
        actions: list of padded action tensors of shape (max_len, 2), one per trial
        lens: number of state-action pairs per trial
        n_nodes: number of nodes in each frame graph
    """

    def __init__(
            self,
            path,
            types=None,
            mode="train",
            num_test=1,
            num_trials=9,
            action_range=10,
            process_data=0,
            max_len=30
    ):
        self.path = path
        self.types = types
        self.mode = mode
        self.num_trials = num_trials
        self.num_test = num_test
        self.action_range = action_range
        self.max_len = max_len
        self.ep_combs = self.num_trials * (self.num_trials - 2)  # 9P2 - 9
        self.eps = [[x, y] for x in range(self.num_trials) for y in range(self.num_trials) if x != y]
        types_str = '_'.join(self.types)
        self.rel_deter_func = [
            is_top_adj, is_left_adj, is_top_right_adj, is_top_left_adj,
            is_down_adj, is_right_adj, is_down_left_adj, is_down_right_adj,
            is_left, is_right, is_front, is_back, is_aligned, is_close
        ]
        self.path_list = []
        self.json_list = []
        # get video paths and json file paths
        for t in types:
            print(f'reading files of type {t} in {mode}')
            paths = [os.path.join(self.path, t, x) for x in os.listdir(os.path.join(self.path, t)) if
                     x.endswith('.mp4')]
            jsons = [os.path.join(self.path, t, x) for x in os.listdir(os.path.join(self.path, t)) if
                     x.endswith('.json') and 'index' not in x]
            paths = sorted(paths)
            jsons = sorted(jsons)
            if mode == 'train':
                self.path_list += paths[:int(0.8 * len(jsons))]
                self.json_list += jsons[:int(0.8 * len(jsons))]
            elif mode == 'val':
                self.path_list += paths[int(0.8 * len(jsons)):]
                self.json_list += jsons[int(0.8 * len(jsons)):]
            else:
                self.path_list += paths
                self.json_list += jsons
        self.data_tuples = []
        if process_data:
            # index the videos in the dataset directory; this speeds up retrieval.
            # (frame index, action) tuples are stored
            self.data_tuples = index_data(self.json_list, self.path_list)
            # tuples of frame index and action (displacement of agent)
            index_dict = {'data_tuples': self.data_tuples}
            with open(os.path.join(self.path, f'jindex_bib_{mode}_{types_str}.json'), 'w') as fp:
                json.dump(index_dict, fp)
        else:
            # read pre-indexed data
            with open(os.path.join(self.path, f'jindex_bib_{mode}_{types_str}.json'), 'r') as fp:
                index_dict = json.load(fp)
            self.data_tuples = index_dict['data_tuples']
        self.tot_trials = len(self.path_list) * 9

    def _get_frame_graph(self, jsonfile, frame_idx):
        # load json
        with open(jsonfile, 'rb') as f:
            frame_data = json.load(f)
        flat_list = [x for xs in frame_data for x in xs]
        # extract entities
        grid_objs = parse_objects(flat_list[frame_idx])
        # --- build the graph
        adj = self._get_spatial_rel(grid_objs)
        # define edges, one edge type per spatial relation
        is_top_adj_src, is_top_adj_dst = np.nonzero(adj[0])
        is_left_adj_src, is_left_adj_dst = np.nonzero(adj[1])
        is_top_right_adj_src, is_top_right_adj_dst = np.nonzero(adj[2])
        is_top_left_adj_src, is_top_left_adj_dst = np.nonzero(adj[3])
        is_down_adj_src, is_down_adj_dst = np.nonzero(adj[4])
        is_right_adj_src, is_right_adj_dst = np.nonzero(adj[5])
        is_down_left_adj_src, is_down_left_adj_dst = np.nonzero(adj[6])
        is_down_right_adj_src, is_down_right_adj_dst = np.nonzero(adj[7])
        is_left_src, is_left_dst = np.nonzero(adj[8])
        is_right_src, is_right_dst = np.nonzero(adj[9])
        is_front_src, is_front_dst = np.nonzero(adj[10])
        is_back_src, is_back_dst = np.nonzero(adj[11])
        is_aligned_src, is_aligned_dst = np.nonzero(adj[12])
        is_close_src, is_close_dst = np.nonzero(adj[13])
        g = dgl.heterograph({
            ('obj', 'is_top_adj', 'obj'): (torch.tensor(is_top_adj_src), torch.tensor(is_top_adj_dst)),
            ('obj', 'is_left_adj', 'obj'): (torch.tensor(is_left_adj_src), torch.tensor(is_left_adj_dst)),
            ('obj', 'is_top_right_adj', 'obj'): (torch.tensor(is_top_right_adj_src), torch.tensor(is_top_right_adj_dst)),
            ('obj', 'is_top_left_adj', 'obj'): (torch.tensor(is_top_left_adj_src), torch.tensor(is_top_left_adj_dst)),
            ('obj', 'is_down_adj', 'obj'): (torch.tensor(is_down_adj_src), torch.tensor(is_down_adj_dst)),
            ('obj', 'is_right_adj', 'obj'): (torch.tensor(is_right_adj_src), torch.tensor(is_right_adj_dst)),
            ('obj', 'is_down_left_adj', 'obj'): (torch.tensor(is_down_left_adj_src), torch.tensor(is_down_left_adj_dst)),
            ('obj', 'is_down_right_adj', 'obj'): (torch.tensor(is_down_right_adj_src), torch.tensor(is_down_right_adj_dst)),
            ('obj', 'is_left', 'obj'): (torch.tensor(is_left_src), torch.tensor(is_left_dst)),
            ('obj', 'is_right', 'obj'): (torch.tensor(is_right_src), torch.tensor(is_right_dst)),
            ('obj', 'is_front', 'obj'): (torch.tensor(is_front_src), torch.tensor(is_front_dst)),
            ('obj', 'is_back', 'obj'): (torch.tensor(is_back_src), torch.tensor(is_back_dst)),
            ('obj', 'is_aligned', 'obj'): (torch.tensor(is_aligned_src), torch.tensor(is_aligned_dst)),
            ('obj', 'is_close', 'obj'): (torch.tensor(is_close_src), torch.tensor(is_close_dst))
        }, num_nodes_dict={'obj': len(grid_objs)})
        g = self._add_node_features(grid_objs, g)
        return g

    def _add_node_features(self, objs, graph):
        for obj_idx, obj in enumerate(objs):
            graph.nodes[obj_idx].data['type'] = torch.tensor(obj.type)
            graph.nodes[obj_idx].data['pos'] = torch.tensor([[obj.x, obj.y]], dtype=torch.float32)
            assert len(obj.attributes) == 2 and None not in obj.attributes[0] and None not in obj.attributes[1]
            graph.nodes[obj_idx].data['color'] = torch.tensor([obj.attributes[0]])
            graph.nodes[obj_idx].data['shape'] = torch.tensor([obj.attributes[1]])
        return graph

    def _get_spatial_rel(self, objs):
        spatial_tensors = [np.zeros([len(objs), len(objs)]) for _ in range(len(self.rel_deter_func))]
        for obj_idx1, obj1 in enumerate(objs):
            for obj_idx2, obj2 in enumerate(objs):
                direction_vec = np.array((0, -1))  # up
                for rel_idx, func in enumerate(self.rel_deter_func):
                    if func(obj1, obj2, direction_vec):
                        spatial_tensors[rel_idx][obj_idx1, obj_idx2] = 1.0
        return spatial_tensors
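
    # Note (added for clarity, not from the original code): _get_spatial_rel returns
    # len(self.rel_deter_func) == 14 binary matrices of shape (n_objs, n_objs), in the
    # same order as self.rel_deter_func; adj[k][i, j] == 1.0 means relation k holds
    # from object i to object j, evaluated with the fixed "up" direction vector (0, -1).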

    def get_trial(self, trials, step=10):
        # retrieve state graphs and actions using the cached index
        states = []
        actions = []
        trial_len = []
        lens = []
        n_nodes = []
        # sample every `step` frames from each trial, capped at max_len pairs
        for t in trials:
            tl = [(t, n) for n in range(0, len(self.data_tuples[t]), step)]
            if len(tl) > self.max_len:  # 30
                tl = tl[:self.max_len]
            trial_len.append(tl)
        for tl in trial_len:
            states.append([])
            actions.append([])
            lens.append(len(tl))
            for t, n in tl:
                video = self.data_tuples[t][n][0]
                states[-1].append(self._get_frame_graph(video, self.data_tuples[t][n][1]))
                n_nodes.append(states[-1][-1].number_of_nodes())
                # actions are pooled (averaged) over the next action_range frames
                if len(self.data_tuples[t]) > n + self.action_range:
                    actions_xy = [d[2] for d in self.data_tuples[t][n:n + self.action_range]]
                else:
                    actions_xy = [d[2] for d in self.data_tuples[t][n:]]
                actions_xy = np.array(actions_xy)
                actions_xy = np.mean(actions_xy, axis=0)
                actions[-1].append(actions_xy)
            states[-1] = dgl.batch(states[-1])
            actions[-1] = torch.tensor(np.array(actions[-1]))
            trial_actions_padded = torch.zeros(self.max_len, actions[-1].size(1))
            trial_actions_padded[:actions[-1].size(0), :] = actions[-1]
            actions[-1] = trial_actions_padded
        return states, actions, lens, n_nodes

    def __getitem__(self, idx):
        ep_trials = [idx * self.num_trials + t for t in range(self.num_trials)]  # [9 * idx, ..., 9 * idx + 8]
        states, actions, lens, n_nodes = self.get_trial(ep_trials, step=self.action_range)
        return states, actions, lens, n_nodes

    def __len__(self):
        return self.tot_trials // self.num_trials
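
# Illustrative sketch (assumed usage, not part of the original pipeline): because
# __getitem__ returns lists of batched DGL graphs and padded action tensors, the
# default DataLoader collate cannot stack them. A minimal custom collate function
# (the name collate_episodes and the arguments below are hypothetical) could be:
#
#   def collate_episodes(batch):
#       # batch: list of (states, actions, lens, n_nodes) tuples, one per episode
#       states = [g for ep in batch for g in ep[0]]                  # per-trial batched graphs
#       actions = torch.stack([a for ep in batch for a in ep[1]])    # (episodes * trials, max_len, 2)
#       lens = [l for ep in batch for l in ep[2]]
#       n_nodes = [n for ep in batch for n in ep[3]]
#       return states, actions, lens, n_nodes
#
#   loader = torch.utils.data.DataLoader(
#       TransitionDataset(path='...', types=['preference'], mode='train'),
#       batch_size=2, shuffle=True, collate_fn=collate_episodes)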


class TestTransitionDatasetSequence(torch.utils.data.Dataset):
    """
    Test dataset class for the behavior cloning RNN model. This dataset is used to test the model on the eval data.
    This class is used to compare plausible (expected) and implausible (unexpected) episodes.
    Args:
        path: path to the dataset
        task_type: evaluation task type to include
        mode: test
        num_test: number of test state-action pairs
        num_trials: number of trials in an episode
        action_range: number of frames to skip; actions are averaged over this many frames (displacement of the agent)
        process_data: whether to index the videos (skip if already processed)
        max_len: max number of context state-action pairs
    __getitem__:
        returns: (dem_expected_states, dem_expected_actions, dem_expected_lens, dem_expected_nodes,
                  dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, dem_unexpected_nodes,
                  query_expected_frames, target_expected_actions,
                  query_unexpected_frames, target_unexpected_actions)
        dem_*_states: list of batched DGLGraph heterographs for the demonstration trials
        dem_*_actions: list of padded action tensors of shape (max_len, 2)
        dem_*_lens: number of state-action pairs per demonstration trial
        dem_*_nodes: number of nodes in each demonstration frame graph
        query_*_frames: batched DGLGraph heterograph for the test trial
        target_*_actions: pooled actions for the test trial, shape (number of sampled frames, 2)
    """

    def __init__(
            self,
            path,
            task_type=None,
            mode="test",
            num_test=1,
            num_trials=9,
            action_range=10,
            process_data=0,
            max_len=30
    ):
        self.path = path
        self.task_type = task_type
        self.mode = mode
        self.num_trials = num_trials
        self.num_test = num_test
        self.action_range = action_range
        self.max_len = max_len
        self.ep_combs = self.num_trials * (self.num_trials - 2)  # 9P2 - 9
        self.eps = [[x, y] for x in range(self.num_trials) for y in range(self.num_trials) if x != y]
        self.path_list_exp = []
        self.json_list_exp = []
        self.path_list_un = []
        self.json_list_un = []

        print(f'reading files of type {task_type} in {mode}')

        paths_expected = sorted([os.path.join(self.path, task_type, x) for x in os.listdir(os.path.join(self.path, task_type)) if
                                 x.endswith('e.mp4')])
        jsons_expected = sorted([os.path.join(self.path, task_type, x) for x in os.listdir(os.path.join(self.path, task_type)) if
                                 x.endswith('e.json') and 'index' not in x])
        paths_unexpected = sorted([os.path.join(self.path, task_type, x) for x in os.listdir(os.path.join(self.path, task_type)) if
                                   x.endswith('u.mp4')])
        jsons_unexpected = sorted([os.path.join(self.path, task_type, x) for x in os.listdir(os.path.join(self.path, task_type)) if
                                   x.endswith('u.json') and 'index' not in x])
        self.path_list_exp += paths_expected
        self.json_list_exp += jsons_expected
        self.path_list_un += paths_unexpected
        self.json_list_un += jsons_unexpected
        self.data_expected = []
        self.data_unexpected = []

        if process_data:
            # index data; this speeds up video retrieval.
            # (frame index, action) tuples are stored
            self.data_expected = index_data(self.json_list_exp, self.path_list_exp)
            index_dict = {'data_tuples': self.data_expected}
            with open(os.path.join(self.path, f'jindex_bib_test_{task_type}e.json'), 'w') as fp:
                json.dump(index_dict, fp)

            self.data_unexpected = index_data(self.json_list_un, self.path_list_un)
            index_dict = {'data_tuples': self.data_unexpected}
            with open(os.path.join(self.path, f'jindex_bib_test_{task_type}u.json'), 'w') as fp:
                json.dump(index_dict, fp)
        else:
            # read pre-indexed data
            with open(os.path.join(self.path, f'jindex_bib_{mode}_{task_type}e.json'), 'r') as fp:
                index_dict = json.load(fp)
            self.data_expected = index_dict['data_tuples']
            with open(os.path.join(self.path, f'jindex_bib_{mode}_{task_type}u.json'), 'r') as fp:
                index_dict = json.load(fp)
            self.data_unexpected = index_dict['data_tuples']

        self.rel_deter_func = [
            is_top_adj, is_left_adj, is_top_right_adj, is_top_left_adj,
            is_down_adj, is_right_adj, is_down_left_adj, is_down_right_adj,
            is_left, is_right, is_front, is_back, is_aligned, is_close
        ]

        print('Done.')

    def _get_frame_graph(self, jsonfile, frame_idx):
        # load json
        with open(jsonfile, 'rb') as f:
            frame_data = json.load(f)
        flat_list = [x for xs in frame_data for x in xs]
        # extract entities
        grid_objs = parse_objects(flat_list[frame_idx])
        # --- build the graph
        adj = self._get_spatial_rel(grid_objs)
        # define edges, one edge type per spatial relation
        is_top_adj_src, is_top_adj_dst = np.nonzero(adj[0])
        is_left_adj_src, is_left_adj_dst = np.nonzero(adj[1])
        is_top_right_adj_src, is_top_right_adj_dst = np.nonzero(adj[2])
        is_top_left_adj_src, is_top_left_adj_dst = np.nonzero(adj[3])
        is_down_adj_src, is_down_adj_dst = np.nonzero(adj[4])
        is_right_adj_src, is_right_adj_dst = np.nonzero(adj[5])
        is_down_left_adj_src, is_down_left_adj_dst = np.nonzero(adj[6])
        is_down_right_adj_src, is_down_right_adj_dst = np.nonzero(adj[7])
        is_left_src, is_left_dst = np.nonzero(adj[8])
        is_right_src, is_right_dst = np.nonzero(adj[9])
        is_front_src, is_front_dst = np.nonzero(adj[10])
        is_back_src, is_back_dst = np.nonzero(adj[11])
        is_aligned_src, is_aligned_dst = np.nonzero(adj[12])
        is_close_src, is_close_dst = np.nonzero(adj[13])
        g = dgl.heterograph({
            ('obj', 'is_top_adj', 'obj'): (torch.tensor(is_top_adj_src), torch.tensor(is_top_adj_dst)),
            ('obj', 'is_left_adj', 'obj'): (torch.tensor(is_left_adj_src), torch.tensor(is_left_adj_dst)),
            ('obj', 'is_top_right_adj', 'obj'): (torch.tensor(is_top_right_adj_src), torch.tensor(is_top_right_adj_dst)),
            ('obj', 'is_top_left_adj', 'obj'): (torch.tensor(is_top_left_adj_src), torch.tensor(is_top_left_adj_dst)),
            ('obj', 'is_down_adj', 'obj'): (torch.tensor(is_down_adj_src), torch.tensor(is_down_adj_dst)),
            ('obj', 'is_right_adj', 'obj'): (torch.tensor(is_right_adj_src), torch.tensor(is_right_adj_dst)),
            ('obj', 'is_down_left_adj', 'obj'): (torch.tensor(is_down_left_adj_src), torch.tensor(is_down_left_adj_dst)),
            ('obj', 'is_down_right_adj', 'obj'): (torch.tensor(is_down_right_adj_src), torch.tensor(is_down_right_adj_dst)),
            ('obj', 'is_left', 'obj'): (torch.tensor(is_left_src), torch.tensor(is_left_dst)),
            ('obj', 'is_right', 'obj'): (torch.tensor(is_right_src), torch.tensor(is_right_dst)),
            ('obj', 'is_front', 'obj'): (torch.tensor(is_front_src), torch.tensor(is_front_dst)),
            ('obj', 'is_back', 'obj'): (torch.tensor(is_back_src), torch.tensor(is_back_dst)),
            ('obj', 'is_aligned', 'obj'): (torch.tensor(is_aligned_src), torch.tensor(is_aligned_dst)),
            ('obj', 'is_close', 'obj'): (torch.tensor(is_close_src), torch.tensor(is_close_dst))
        }, num_nodes_dict={'obj': len(grid_objs)})
        g = self._add_node_features(grid_objs, g)
        return g

    def _add_node_features(self, objs, graph):
        for obj_idx, obj in enumerate(objs):
            graph.nodes[obj_idx].data['type'] = torch.tensor(obj.type)
            graph.nodes[obj_idx].data['pos'] = torch.tensor([[obj.x, obj.y]], dtype=torch.float32)
            assert len(obj.attributes) == 2 and None not in obj.attributes[0] and None not in obj.attributes[1]
            graph.nodes[obj_idx].data['color'] = torch.tensor([obj.attributes[0]])
            graph.nodes[obj_idx].data['shape'] = torch.tensor([obj.attributes[1]])
        return graph

    def _get_spatial_rel(self, objs):
        spatial_tensors = [np.zeros([len(objs), len(objs)]) for _ in range(len(self.rel_deter_func))]
        for obj_idx1, obj1 in enumerate(objs):
            for obj_idx2, obj2 in enumerate(objs):
                direction_vec = np.array((0, -1))  # up
                for rel_idx, func in enumerate(self.rel_deter_func):
                    if func(obj1, obj2, direction_vec):
                        spatial_tensors[rel_idx][obj_idx1, obj_idx2] = 1.0
        return spatial_tensors

    def get_trial(self, trials, data, step=10):
        # retrieve state graphs and actions using the cached index
        states = []
        actions = []
        trial_len = []
        lens = []
        n_nodes = []
        for t in trials:
            tl = [(t, n) for n in range(0, len(data[t]), step)]
            if len(tl) > self.max_len:
                tl = tl[:self.max_len]
            trial_len.append(tl)
        for tl in trial_len:
            states.append([])
            actions.append([])
            lens.append(len(tl))
            for t, n in tl:
                video = data[t][n][0]
                states[-1].append(self._get_frame_graph(video, data[t][n][1]))
                n_nodes.append(states[-1][-1].number_of_nodes())
                if len(data[t]) > n + self.action_range:
                    actions_xy = [d[2] for d in data[t][n:n + self.action_range]]
                else:
                    actions_xy = [d[2] for d in data[t][n:]]
                actions_xy = np.array(actions_xy)
                actions_xy = np.mean(actions_xy, axis=0)
                actions[-1].append(actions_xy)
            states[-1] = dgl.batch(states[-1])
            actions[-1] = torch.tensor(np.array(actions[-1]))
            trial_actions_padded = torch.zeros(self.max_len, actions[-1].size(1))
            trial_actions_padded[:actions[-1].size(0), :] = actions[-1]
            actions[-1] = trial_actions_padded
        return states, actions, lens, n_nodes

    def get_test(self, trial, data, step=10):
        # retrieve state graphs and pooled actions for the single test trial
        states = []
        actions = []
        trial_len = []
        trial_len += [(trial, n) for n in range(0, len(data[trial]), step)]
        for t, n in trial_len:
            video = data[t][n][0]
            state = self._get_frame_graph(video, data[t][n][1])
            if len(data[t]) > n + self.action_range:
                actions_xy = [d[2] for d in data[t][n:n + self.action_range]]
            else:
                actions_xy = [d[2] for d in data[t][n:]]
            actions_xy = np.array(actions_xy)
            actions_xy = np.mean(actions_xy, axis=0)
            actions.append(actions_xy)
            states.append(state)
        #states = torch.stack(states)
        states = dgl.batch(states)
        actions = torch.tensor(np.array(actions))
        return states, actions

    def __getitem__(self, idx):
        ep_trials = [idx * self.num_trials + t for t in range(self.num_trials)]
        dem_expected_states, dem_expected_actions, dem_expected_lens, dem_expected_nodes = self.get_trial(
            ep_trials[:-1], self.data_expected, step=self.action_range
        )
        dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, dem_unexpected_nodes = self.get_trial(
            ep_trials[:-1], self.data_unexpected, step=self.action_range
        )
        query_expected_frames, target_expected_actions = self.get_test(
            ep_trials[-1], self.data_expected, step=self.action_range
        )
        query_unexpected_frames, target_unexpected_actions = self.get_test(
            ep_trials[-1], self.data_unexpected, step=self.action_range
        )
        return dem_expected_states, dem_expected_actions, dem_expected_lens, dem_expected_nodes, \
            dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, dem_unexpected_nodes, \
            query_expected_frames, target_expected_actions, \
            query_unexpected_frames, target_unexpected_actions

    def __len__(self):
        return len(self.path_list_exp)
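
# Note (assumed usage, not from the original code): each item pairs the expected
# (plausible) and unexpected (implausible) version of the same episode, so evaluation
# would typically index the dataset directly (as in the __main__ block below) or use a
# DataLoader with batch_size=1, and compare the model's prediction error on
# target_expected_actions versus target_unexpected_actions.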


if __name__ == '__main__':
    types = ['preference', 'multi_agent', 'inaccessible_goal',
             'efficiency_irrational', 'efficiency_time', 'efficiency_path',
             'instrumental_no_barrier', 'instrumental_blocking_barrier', 'instrumental_inconsequential_barrier']
    for t in types:
        ttd = TestTransitionDatasetSequence(path='/datasets/external/bib_evaluation_1_1/', task_type=t, process_data=0, mode='test')
        for i in range(len(ttd)):
            print(i, end='\r')
            dem_expected_states, dem_expected_actions, dem_expected_lens, dem_expected_nodes, \
                dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, dem_unexpected_nodes, \
                query_expected_frames, target_expected_actions, \
                query_unexpected_frames, target_unexpected_actions = ttd[i]
            # sanity check: report items whose graphs contain no node with a 1 in the
            # first position of the one-hot 'type' feature
            for j in range(8):
                if not torch.tensor([1., 0., 0., 0., 0., 0., 0., 0., 0.]) in dem_expected_states[j].ndata['type']:
                    print(i)
                if not torch.tensor([1., 0., 0., 0., 0., 0., 0., 0., 0.]) in dem_unexpected_states[j].ndata['type']:
                    print(i)
            if not torch.tensor([1., 0., 0., 0., 0., 0., 0., 0., 0.]) in query_expected_frames.ndata['type']:
                print(i)
            if not torch.tensor([1., 0., 0., 0., 0., 0., 0., 0., 0.]) in query_unexpected_frames.ndata['type']:
                print(i)