IRENE/utils/dataset.py

488 lines
24 KiB
Python
Raw Normal View History

2024-02-01 15:40:47 +01:00
import dgl
import torch
import torch.utils.data
import os
import pickle as pkl
import json
import numpy as np
from tqdm import tqdm
import sys
sys.path.append('/projects/bortoletto/irene/')
from utils.grid_object import *
from utils.relations import *
# ========================== Helper functions ==========================
def index_data(json_list, path_list):
    """
    Build a per-trial index of (json file, frame number, action) entries.

    Args:
        json_list: paths of the json state files to index.
        path_list: paths zipped with ``json_list``; the values are not used,
            but iteration stops at the shorter of the two lists.

    Returns:
        A list with one inner list per trial found in the json files. Each
        inner list holds ``(json file, frame number, [dx, dy])`` entries,
        where the frame number counts frames across all trials of that file
        and ``[dx, dy]`` is half the agent displacement to the next frame.

    Raises:
        AssertionError: if a trial yields no entries, i.e. it has no frames
            left after skipping the first 30 and the last 83.
    """
    print(f'processing files {len(json_list)}')
    data_tuples = []
    for json_path, _unused_path in tqdm(zip(json_list, path_list)):
        with open(json_path, 'r') as handle:
            episodes = json.load(handle)
        frames_before = 0  # number of frames in the preceding trials of this file
        for ep_idx, n_frames in enumerate([len(ep) for ep in episodes]):
            entries = []
            data_tuples.append(entries)
            # skip first 30 frames and last 83 frames
            for frame in range(30, n_frames - 83):
                # action taken = half the displacement of the agent to the next frame
                x_now, y_now = episodes[ep_idx][frame]['agent'][0]
                x_next, y_next = episodes[ep_idx][frame + 1]['agent'][0]
                action = [(x_next - x_now) / 2., (y_next - y_now) / 2.]
                entries.append((json_path, frames_before + frame, action))
            # data_tuples = [[json file, frame number, action] for each episode in each video]
            assert len(entries) > 0
            frames_before += n_frames
    return data_tuples
# ========================== Dataset class ==========================
class TransitionDataset(torch.utils.data.Dataset):
    """
    Training dataset class for the behavior cloning model; frames are served
    as DGL heterographs built from the json state files.

    Args:
        path: path to the dataset (one sub-directory per video type)
        types: list of video types to include (required)
        mode: 'train' uses the first 80% of the sorted files of each type,
            'val' the remaining ones; any other value uses all files
        num_test: number of test state-action pairs (stored only)
        num_trials: number of trials in an episode
        action_range: number of frames to skip; actions are averaged over
            this number of frames (displacement of the agent)
        process_data: if truthy, index the json files and write the index to
            disk; otherwise read the previously written index
        max_len: max number of state-action pairs kept per trial

    __getitem__:
        returns: (states, actions, lens, n_nodes)
            states: list with one batched DGL graph per trial
            actions: list with one zero-padded (max_len, 2) tensor per trial
            lens: list with the number of sampled frames per trial
            n_nodes: flat list with the node count of every sampled frame
    """

    # Edge type names, listed in the same order as ``self.rel_deter_func``.
    REL_NAMES = (
        'is_top_adj', 'is_left_adj', 'is_top_right_adj', 'is_top_left_adj',
        'is_down_adj', 'is_right_adj', 'is_down_left_adj', 'is_down_right_adj',
        'is_left', 'is_right', 'is_front', 'is_back', 'is_aligned', 'is_close',
    )

    def __init__(
            self,
            path,
            types=None,
            mode="train",
            num_test=1,
            num_trials=9,
            action_range=10,
            process_data=0,
            max_len=30
    ):
        if types is None:
            # Fail early with a readable message instead of the opaque
            # TypeError that '_'.join(None) would raise further down.
            raise TypeError('`types` must be a list of video type names, got None')
        self.path = path
        self.types = types
        self.mode = mode
        self.num_trials = num_trials
        self.num_test = num_test
        self.action_range = action_range
        self.max_len = max_len
        self.ep_combs = self.num_trials * (self.num_trials - 2)  # 9p2 - 9
        self.eps = [[x, y] for x in range(self.num_trials) for y in range(self.num_trials) if x != y]
        types_str = '_'.join(self.types)
        self.rel_deter_func = [
            is_top_adj, is_left_adj, is_top_right_adj, is_top_left_adj,
            is_down_adj, is_right_adj, is_down_left_adj, is_down_right_adj,
            is_left, is_right, is_front, is_back, is_aligned, is_close
        ]
        self.path_list = []
        self.json_list = []
        # get video paths and json file paths
        for t in types:
            print(f'reading files of type {t} in {mode}')
            type_dir = os.path.join(self.path, t)
            files = os.listdir(type_dir)  # list the directory once per type
            paths = sorted(os.path.join(type_dir, x) for x in files if x.endswith('.mp4'))
            jsons = sorted(os.path.join(type_dir, x) for x in files
                           if x.endswith('.json') and 'index' not in x)
            # the split point is derived from the number of json files for both lists
            split = int(0.8 * len(jsons))
            if mode == 'train':
                self.path_list += paths[:split]
                self.json_list += jsons[:split]
            elif mode == 'val':
                self.path_list += paths[split:]
                self.json_list += jsons[split:]
            else:
                self.path_list += paths
                self.json_list += jsons
        index_file = os.path.join(self.path, f'jindex_bib_{mode}_{types_str}.json')
        self.data_tuples = []
        if process_data:
            # index the json files once so that frames can be looked up quickly later;
            # (json file, frame index, action) tuples are stored per trial
            self.data_tuples = index_data(self.json_list, self.path_list)
            with open(index_file, 'w') as fp:
                json.dump({'data_tuples': self.data_tuples}, fp)
        else:
            # read pre-indexed data
            with open(index_file, 'r') as fp:
                self.data_tuples = json.load(fp)['data_tuples']
        # NOTE(review): the factor 9 presumably is the number of trials stored
        # in each video file - confirm against the dataset.
        self.tot_trials = len(self.path_list) * 9

    def _get_frame_graph(self, jsonfile, frame_idx):
        """
        Build the heterograph of a single frame.

        Args:
            jsonfile: path of the json state file
            frame_idx: index of the frame, counted over all trials of the file

        Returns:
            A DGL heterograph with one 'obj' node per parsed object, one edge
            type per spatial relation and the node features attached.
        """
        # NOTE(review): the whole json file is re-read for every frame; caching
        # the last loaded file would speed this up if parse_objects does not
        # mutate its input - confirm before adding.
        with open(jsonfile, 'rb') as f:
            frame_data = json.load(f)
        # the json holds a list of trials, each a list of frames; flatten it
        flat_list = [x for xs in frame_data for x in xs]
        # extract entities
        grid_objs = parse_objects(flat_list[frame_idx])
        # --- build the graph: one edge type per relation adjacency matrix
        adj = self._get_spatial_rel(grid_objs)
        edges = {}
        for rel_name, rel_adj in zip(self.REL_NAMES, adj):
            src, dst = np.nonzero(rel_adj)
            edges[('obj', rel_name, 'obj')] = (torch.tensor(src), torch.tensor(dst))
        g = dgl.heterograph(edges, num_nodes_dict={'obj': len(grid_objs)})
        return self._add_node_features(grid_objs, g)

    def _add_node_features(self, objs, graph):
        """Store 'type', 'pos', 'color' and 'shape' features on every node and return the graph."""
        for obj_idx, obj in enumerate(objs):
            graph.nodes[obj_idx].data['type'] = torch.tensor(obj.type)
            graph.nodes[obj_idx].data['pos'] = torch.tensor([[obj.x, obj.y]], dtype=torch.float32)
            assert len(obj.attributes) == 2 and None not in obj.attributes[0] and None not in obj.attributes[1]
            graph.nodes[obj_idx].data['color'] = torch.tensor([obj.attributes[0]])
            graph.nodes[obj_idx].data['shape'] = torch.tensor([obj.attributes[1]])
        return graph

    def _get_spatial_rel(self, objs):
        """
        Evaluate every relation predicate on every ordered pair of objects.

        Returns:
            A list with one (len(objs), len(objs)) array per predicate in
            ``self.rel_deter_func``; entry [i, j] is 1.0 where the predicate
            holds for (objs[i], objs[j]) and 0.0 elsewhere.
        """
        n_objs = len(objs)
        spatial_tensors = [np.zeros([n_objs, n_objs]) for _ in range(len(self.rel_deter_func))]
        for obj_idx1, obj1 in enumerate(objs):
            for obj_idx2, obj2 in enumerate(objs):
                direction_vec = np.array((0, -1))  # Up
                for rel_idx, func in enumerate(self.rel_deter_func):
                    if func(obj1, obj2, direction_vec):
                        spatial_tensors[rel_idx][obj_idx1, obj_idx2] = 1.0
        return spatial_tensors

    def get_trial(self, trials, step=10):
        """
        Build frame graphs and pooled actions for the given trials.

        Args:
            trials: indices into ``self.data_tuples``
            step: distance between sampled frames

        Returns:
            (states, actions, lens, n_nodes) as described in the class docstring.
        """
        states = []
        actions = []
        lens = []
        n_nodes = []
        # every ``step``-th frame of each trial, at most ``max_len`` of them
        sampled = [
            (t, list(range(0, len(self.data_tuples[t]), step))[:self.max_len])
            for t in trials
        ]
        for t, frame_ids in sampled:
            trial_data = self.data_tuples[t]
            lens.append(len(frame_ids))
            graphs = []
            trial_actions = []
            for n in frame_ids:
                graph = self._get_frame_graph(trial_data[n][0], trial_data[n][1])
                graphs.append(graph)
                n_nodes.append(graph.number_of_nodes())
                # actions are pooled (averaged) over the next ``action_range``
                # frames; the slice simply gets shorter near the end of the trial
                window = [d[2] for d in trial_data[n:n + self.action_range]]
                trial_actions.append(np.mean(np.array(window), axis=0))
            states.append(dgl.batch(graphs))
            trial_actions = torch.tensor(np.array(trial_actions))
            # zero-pad the actions to a fixed number of rows
            padded = torch.zeros(self.max_len, trial_actions.size(1))
            padded[:trial_actions.size(0), :] = trial_actions
            actions.append(padded)
        return states, actions, lens, n_nodes

    def __getitem__(self, idx):
        """Return (states, actions, lens, n_nodes) for the ``num_trials`` trials of episode ``idx``."""
        first = idx * self.num_trials
        ep_trials = [first + t for t in range(self.num_trials)]  # [first, ..., first + num_trials - 1]
        return self.get_trial(ep_trials, step=self.action_range)

    def __len__(self):
        """Number of episodes, i.e. total trials divided by the trials per episode."""
        return self.tot_trials // self.num_trials
class TestTransitionDatasetSequence(torch.utils.data.Dataset):
    """
    Test dataset class for the behavior cloning model. It serves the expected
    (file names ending in 'e') and unexpected (ending in 'u') version of every
    episode so that plausible and implausible episodes can be compared.
    Frames are returned as DGL heterographs built from the json state files.

    Args:
        path: path to the dataset (one sub-directory per task type)
        task_type: name of the task sub-directory to read (required)
        mode: used in the name of the index files, 'test' by default
        num_test: number of test state-action pairs (stored only)
        num_trials: number of trials in an episode; the last one is the query
        action_range: number of frames to skip; actions are averaged over
            this number of frames (displacement of the agent)
        process_data: if truthy, index the json files and write the index to
            disk; otherwise read the previously written index
        max_len: max number of state-action pairs kept per demonstration trial

    __getitem__:
        returns: (dem_expected_states, dem_expected_actions, dem_expected_lens, dem_expected_nodes,
                  dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, dem_unexpected_nodes,
                  query_expected_frames, target_expected_actions,
                  query_unexpected_frames, target_unexpected_actions)
            dem_*_states: list with one batched DGL graph per demonstration trial
            dem_*_actions: list with one zero-padded (max_len, 2) tensor per demonstration trial
            dem_*_lens: list with the number of sampled frames per demonstration trial
            dem_*_nodes: flat list with the node count of every sampled frame
            query_*_frames: batched DGL graph of the sampled query-trial frames
            target_*_actions: tensor with one pooled action per sampled query frame
    """

    # Edge type names, listed in the same order as ``self.rel_deter_func``.
    REL_NAMES = (
        'is_top_adj', 'is_left_adj', 'is_top_right_adj', 'is_top_left_adj',
        'is_down_adj', 'is_right_adj', 'is_down_left_adj', 'is_down_right_adj',
        'is_left', 'is_right', 'is_front', 'is_back', 'is_aligned', 'is_close',
    )

    def __init__(
            self,
            path,
            task_type=None,
            mode="test",
            num_test=1,
            num_trials=9,
            action_range=10,
            process_data=0,
            max_len=30
    ):
        if task_type is None:
            # Fail early with a readable message instead of the TypeError that
            # os.path.join(path, None) would raise further down.
            raise TypeError('`task_type` must be the name of a task directory, got None')
        self.path = path
        self.task_type = task_type
        self.mode = mode
        self.num_trials = num_trials
        self.num_test = num_test
        self.action_range = action_range
        self.max_len = max_len
        self.ep_combs = self.num_trials * (self.num_trials - 2)  # 9p2 - 9
        self.eps = [[x, y] for x in range(self.num_trials) for y in range(self.num_trials) if x != y]
        self.rel_deter_func = [
            is_top_adj, is_left_adj, is_top_right_adj, is_top_left_adj,
            is_down_adj, is_right_adj, is_down_left_adj, is_down_right_adj,
            is_left, is_right, is_front, is_back, is_aligned, is_close
        ]

        print(f'reading files of type {task_type} in {mode}')
        task_dir = os.path.join(self.path, task_type)
        files = os.listdir(task_dir)  # list the directory once

        def _collect(suffix, skip_index):
            # sorted full paths of the files with the given suffix; index files
            # written by this class are left out of the json lists
            return sorted(
                os.path.join(task_dir, x) for x in files
                if x.endswith(suffix) and not (skip_index and 'index' in x)
            )

        self.path_list_exp = _collect('e.mp4', False)
        self.json_list_exp = _collect('e.json', True)
        self.path_list_un = _collect('u.mp4', False)
        self.json_list_un = _collect('u.json', True)

        # The same file names are used for writing and reading the index, so an
        # index created with a given ``mode`` is found again with that ``mode``.
        index_exp = os.path.join(self.path, f'jindex_bib_{mode}_{task_type}e.json')
        index_un = os.path.join(self.path, f'jindex_bib_{mode}_{task_type}u.json')
        self.data_expected = []
        self.data_unexpected = []
        if process_data:
            # index data once so that frames can be looked up quickly later;
            # (json file, frame index, action) tuples are stored per trial
            self.data_expected = index_data(self.json_list_exp, self.path_list_exp)
            with open(index_exp, 'w') as fp:
                json.dump({'data_tuples': self.data_expected}, fp)
            self.data_unexpected = index_data(self.json_list_un, self.path_list_un)
            with open(index_un, 'w') as fp:
                json.dump({'data_tuples': self.data_unexpected}, fp)
        else:
            # read pre-indexed data
            with open(index_exp, 'r') as fp:
                self.data_expected = json.load(fp)['data_tuples']
            with open(index_un, 'r') as fp:
                self.data_unexpected = json.load(fp)['data_tuples']
        print('Done.')

    def _get_frame_graph(self, jsonfile, frame_idx):
        """
        Build the heterograph of a single frame.

        Args:
            jsonfile: path of the json state file
            frame_idx: index of the frame, counted over all trials of the file

        Returns:
            A DGL heterograph with one 'obj' node per parsed object, one edge
            type per spatial relation and the node features attached.
        """
        # NOTE(review): the whole json file is re-read for every frame; caching
        # the last loaded file would speed this up if parse_objects does not
        # mutate its input - confirm before adding.
        with open(jsonfile, 'rb') as f:
            frame_data = json.load(f)
        # the json holds a list of trials, each a list of frames; flatten it
        flat_list = [x for xs in frame_data for x in xs]
        # extract entities
        grid_objs = parse_objects(flat_list[frame_idx])
        # --- build the graph: one edge type per relation adjacency matrix
        adj = self._get_spatial_rel(grid_objs)
        edges = {}
        for rel_name, rel_adj in zip(self.REL_NAMES, adj):
            src, dst = np.nonzero(rel_adj)
            edges[('obj', rel_name, 'obj')] = (torch.tensor(src), torch.tensor(dst))
        g = dgl.heterograph(edges, num_nodes_dict={'obj': len(grid_objs)})
        return self._add_node_features(grid_objs, g)

    def _add_node_features(self, objs, graph):
        """Store 'type', 'pos', 'color' and 'shape' features on every node and return the graph."""
        for obj_idx, obj in enumerate(objs):
            graph.nodes[obj_idx].data['type'] = torch.tensor(obj.type)
            graph.nodes[obj_idx].data['pos'] = torch.tensor([[obj.x, obj.y]], dtype=torch.float32)
            assert len(obj.attributes) == 2 and None not in obj.attributes[0] and None not in obj.attributes[1]
            graph.nodes[obj_idx].data['color'] = torch.tensor([obj.attributes[0]])
            graph.nodes[obj_idx].data['shape'] = torch.tensor([obj.attributes[1]])
        return graph

    def _get_spatial_rel(self, objs):
        """
        Evaluate every relation predicate on every ordered pair of objects.

        Returns:
            A list with one (len(objs), len(objs)) array per predicate in
            ``self.rel_deter_func``; entry [i, j] is 1.0 where the predicate
            holds for (objs[i], objs[j]) and 0.0 elsewhere.
        """
        n_objs = len(objs)
        spatial_tensors = [np.zeros([n_objs, n_objs]) for _ in range(len(self.rel_deter_func))]
        for obj_idx1, obj1 in enumerate(objs):
            for obj_idx2, obj2 in enumerate(objs):
                # NOTE(review): a constant (0, -1) reference direction, labelled
                # "Up" by the original author, is passed to every predicate; the
                # reason for keeping it fixed is not documented here.
                direction_vec = np.array((0, -1))
                for rel_idx, func in enumerate(self.rel_deter_func):
                    if func(obj1, obj2, direction_vec):
                        spatial_tensors[rel_idx][obj_idx1, obj_idx2] = 1.0
        return spatial_tensors

    def _pooled_action(self, trial_data, n):
        """Average the actions of ``trial_data`` over the ``action_range`` entries starting at ``n``."""
        # the slice simply gets shorter near the end of the trial
        window = [d[2] for d in trial_data[n:n + self.action_range]]
        return np.mean(np.array(window), axis=0)

    def get_trial(self, trials, data, step=10):
        """
        Build frame graphs and pooled actions for the given demonstration trials.

        Args:
            trials: indices into ``data``
            data: per-trial index (``self.data_expected`` or ``self.data_unexpected``)
            step: distance between sampled frames

        Returns:
            (states, actions, lens, n_nodes): one batched graph, one zero-padded
            (max_len, 2) action tensor and one frame count per trial, plus the
            flat list of node counts of all sampled frames.
        """
        states = []
        actions = []
        lens = []
        n_nodes = []
        # every ``step``-th frame of each trial, at most ``max_len`` of them
        sampled = [(t, list(range(0, len(data[t]), step))[:self.max_len]) for t in trials]
        for t, frame_ids in sampled:
            trial_data = data[t]
            lens.append(len(frame_ids))
            graphs = []
            trial_actions = []
            for n in frame_ids:
                graph = self._get_frame_graph(trial_data[n][0], trial_data[n][1])
                graphs.append(graph)
                n_nodes.append(graph.number_of_nodes())
                trial_actions.append(self._pooled_action(trial_data, n))
            states.append(dgl.batch(graphs))
            trial_actions = torch.tensor(np.array(trial_actions))
            # zero-pad the actions to a fixed number of rows
            padded = torch.zeros(self.max_len, trial_actions.size(1))
            padded[:trial_actions.size(0), :] = trial_actions
            actions.append(padded)
        return states, actions, lens, n_nodes

    def get_test(self, trial, data, step=10):
        """
        Build frame graphs and pooled actions for the query trial.

        Unlike ``get_trial`` the frames are neither limited to ``max_len`` nor
        are the actions padded.

        Returns:
            (states, actions): the batched graph of every ``step``-th frame and
            a tensor with one pooled action per sampled frame.
        """
        trial_data = data[trial]
        states = []
        actions = []
        for n in range(0, len(trial_data), step):
            states.append(self._get_frame_graph(trial_data[n][0], trial_data[n][1]))
            actions.append(self._pooled_action(trial_data, n))
        states = dgl.batch(states)
        actions = torch.tensor(np.array(actions))
        return states, actions

    def __getitem__(self, idx):
        """Return demonstrations (all but the last trial) and query (last trial) of episode ``idx``."""
        ep_trials = [idx * self.num_trials + t for t in range(self.num_trials)]
        dem_trials, query_trial = ep_trials[:-1], ep_trials[-1]
        dem_expected_states, dem_expected_actions, dem_expected_lens, dem_expected_nodes = self.get_trial(
            dem_trials, self.data_expected, step=self.action_range
        )
        dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, dem_unexpected_nodes = self.get_trial(
            dem_trials, self.data_unexpected, step=self.action_range
        )
        query_expected_frames, target_expected_actions = self.get_test(
            query_trial, self.data_expected, step=self.action_range
        )
        query_unexpected_frames, target_unexpected_actions = self.get_test(
            query_trial, self.data_unexpected, step=self.action_range
        )
        return dem_expected_states, dem_expected_actions, dem_expected_lens, dem_expected_nodes, \
            dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, dem_unexpected_nodes, \
            query_expected_frames, target_expected_actions, \
            query_unexpected_frames, target_unexpected_actions

    def __len__(self):
        """Number of episodes, taken from the number of expected video files."""
        return len(self.path_list_exp)
def has_agent_node(node_types, agent_type=(1., 0., 0., 0., 0., 0., 0., 0., 0.)):
    """
    Tell whether at least one row of ``node_types`` equals ``agent_type``.

    ``agent_type in node_types`` is not used on purpose: for tensors it
    evaluates ``(agent_type == node_types).any()``, which is already true when
    a single element matches, so it cannot detect a missing row.

    Args:
        node_types: tensor whose last dimension holds the type encoding of a
            node (assumed to be one row per node - TODO confirm).
        agent_type: the encoding to look for.

    Returns:
        True if some row matches ``agent_type`` in every position.
    """
    agent = torch.tensor(agent_type, dtype=node_types.dtype)
    return bool((node_types == agent).all(dim=-1).any())


if __name__ == '__main__':
    # Sanity check over the evaluation set: print the index of every episode
    # in which a graph lacks a node with the agent type encoding.
    types = ['preference', 'multi_agent', 'inaccessible_goal',
             'efficiency_irrational', 'efficiency_time', 'efficiency_path',
             'instrumental_no_barrier', 'instrumental_blocking_barrier', 'instrumental_inconsequential_barrier']
    for t in types:
        ttd = TestTransitionDatasetSequence(path='/datasets/external/bib_evaluation_1_1/', task_type=t, process_data=0, mode='test')
        for i in range(len(ttd)):
            print(i, end='\r')
            dem_expected_states, dem_expected_actions, dem_expected_lens, dem_expected_nodes, \
                dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, dem_unexpected_nodes, \
                query_expected_frames, target_expected_actions, \
                query_unexpected_frames, target_unexpected_actions = ttd[i]
            # the 8 demonstration trials, expected and unexpected interleaved,
            # followed by the two query trials
            graphs = []
            for j in range(8):
                graphs.append(dem_expected_states[j])
                graphs.append(dem_unexpected_states[j])
            graphs.append(query_expected_frames)
            graphs.append(query_unexpected_frames)
            for graph in graphs:
                if not has_agent_node(graph.ndata['type']):
                    print(i)