IRENE/utils/index_data.py
2024-02-01 15:40:47 +01:00

124 lines
5 KiB
Python

import json
import os
import torch
import torch.utils.data
from tqdm import tqdm
def index_data(json_list, path_list):
    """Build a per-trial index of (json file, frame number, action) entries.

    Each JSON file is loaded as a sequence of trials; every trial is a
    sequence of per-frame states whose ``'agent'`` entry starts with an
    ``(x, y)`` pair. For each frame kept from a trial, the action is half the
    change of that pair between the frame and the one that follows it.

    Args:
        json_list: paths of the JSON state files to index.
        path_list: video paths iterated in lockstep with ``json_list``. The
            values themselves are not stored; ``zip`` stops at the shorter of
            the two lists.

    Returns:
        A list holding one inner list per trial, in file order and then trial
        order. Each inner list contains ``(json_path, frame_number, [dx, dy])``
        tuples, where ``frame_number`` is the frame's position in its trial
        plus the summed lengths of the earlier trials of the same file.

    Raises:
        ValueError: if a trial is too short to contribute any frame once the
            leading and trailing frames are skipped.
    """
    # Number of frames dropped at the start and at the end of every trial.
    skip_head = 30
    skip_tail = 83

    print(f'processing files {len(json_list)}')
    data_tuples = []
    # zip() has no length, so hand tqdm the expected count explicitly;
    # otherwise the progress bar cannot show a percentage or an ETA.
    file_pairs = zip(json_list, path_list)
    for json_path, _video_path in tqdm(file_pairs, total=len(json_list)):
        with open(json_path, 'r') as fp:
            state = json.load(fp)
        # Running total of the lengths of the trials already seen in this file.
        past_len = 0
        for trial_idx, trial in enumerate(state):
            trial_len = len(trial)
            entries = []
            for frame in range(skip_head, trial_len - skip_tail):
                # Action taken = half the displacement to the next frame.
                x0, y0 = trial[frame]['agent'][0]
                x1, y1 = trial[frame + 1]['agent'][0]
                action = [(x1 - x0) / 2., (y1 - y0) / 2.]
                entries.append((json_path, past_len + frame, action))
            # Explicit raise instead of `assert`, which is removed under -O.
            if not entries:
                raise ValueError(
                    f'trial {trial_idx} of {json_path} has {trial_len} frames, '
                    f'too few to index after skipping the first {skip_head} '
                    f'and last {skip_tail}'
                )
            data_tuples.append(entries)
            past_len += trial_len
    return data_tuples
class TransitionDataset(torch.utils.data.Dataset):
    """
    Training dataset class for the behavior cloning mlp model.

    The constructor collects the video (``.mp4``) and state (``.json``) files
    of the requested types and either builds a frame/action index from them
    and writes it to disk, or loads a previously written index.

    Args:
        path: path to the dataset
        types: list of video types to include; each one is a sub-directory
            of ``path``. Required despite the ``None`` default.
        size: size of the frames to be returned (stored only)
        mode: ``'train'`` keeps the first 80% of the sorted files, ``'val'``
            the remainder, and any other value keeps all files
        num_context: number of context state-action pairs (stored only)
        num_test: number of test state-action pairs (stored only)
        num_trials: number of trials in an episode
        action_range: number of frames to skip; actions are combined over
            this number of frames (displacement of the agent) (stored only)
        process_data: whether to index the videos or not (skip if already
            processed); when falsy the index file written by an earlier run
            is read instead

    __getitem__:
        Currently a placeholder: it prints ``'Empty'`` and returns ``None``.
        The shapes below describe the intended item and are not produced by
        this class as written:
            dem_frames: (num_context, 3, size, size)
            dem_actions: (num_context, 2)
            query_frames: (num_test, 3, size, size)
            target_actions: (num_test, 2)

    Raises:
        TypeError: if ``types`` is ``None``.
    """

    def __init__(self, path, types=None, size=None, mode="train", num_context=30, num_test=1, num_trials=9,
                 action_range=10, process_data=0):
        # Fail early with a readable message; without this the join below
        # raises an opaque "can only join an iterable" TypeError.
        if types is None:
            raise TypeError('types must be a list of video type directory names, got None')
        super().__init__()
        self.path = path
        self.types = types
        self.size = size
        self.mode = mode
        self.num_trials = num_trials
        self.num_context = num_context
        self.num_test = num_test
        self.action_range = action_range
        # n * (n - 2); for the default of 9 trials this is 9P2 - 9 = 63.
        self.ep_combs = self.num_trials * (self.num_trials - 2)
        # Every ordered pair of distinct trial indices.
        self.eps = [[x, y] for x in range(self.num_trials) for y in range(self.num_trials) if x != y]
        types_str = '_'.join(self.types)
        index_file = os.path.join(self.path, f'jindex_bib_{mode}_{types_str}.json')
        self.path_list = []
        self.json_list = []
        # get video paths and json file paths
        for t in types:
            print(f'reading files of type {t} in {mode}')
            type_dir = os.path.join(self.path, t)
            # List the directory once and filter the same listing twice.
            entries = os.listdir(type_dir)
            paths = sorted(os.path.join(type_dir, x) for x in entries if x.endswith('.mp4'))
            # Files with 'index' in their name are not state files.
            jsons = sorted(os.path.join(type_dir, x) for x in entries
                           if x.endswith('.json') and 'index' not in x)
            # The split position is derived from the json count and applied
            # to both lists.
            split = int(0.8 * len(jsons))
            if mode == 'train':
                self.path_list += paths[:split]
                self.json_list += jsons[:split]
            elif mode == 'val':
                self.path_list += paths[split:]
                self.json_list += jsons[split:]
            else:
                self.path_list += paths
                self.json_list += jsons
        if process_data:
            # index the videos in the dataset directory. This is done to speed up the retrieval of videos.
            # frame index, action tuples are stored
            self.data_tuples = index_data(self.json_list, self.path_list)
            # tuples of frame index and action (displacement of agent)
            with open(index_file, 'w') as fp:
                json.dump({'data_tuples': self.data_tuples}, fp)
        else:
            # read pre-indexed data
            with open(index_file, 'r') as fp:
                self.data_tuples = json.load(fp)['data_tuples']
        # NOTE(review): the factor 9 is hard-coded rather than taken from
        # num_trials - presumably the number of trials per video; confirm.
        self.tot_trials = len(self.path_list) * 9

    def __getitem__(self, idx):
        # Placeholder: no sample is assembled yet.
        print('Empty')
        return

    def __len__(self):
        return self.tot_trials // self.num_trials
if __name__ == "__main__":
    # Script entry: build the training index for the selected video types
    # (process_data=1 makes the constructor write the index file) and report
    # the resulting dataset length.
    # Full list kept for reference:
    # ['instrumental_action', 'multi_agent', 'preference', 'single_object']
    selected_types = ['multi_agent', 'instrumental_action']
    dataset = TransitionDataset(
        path='/datasets/external/bib_train/',
        types=selected_types,
        size=(84, 84),
        mode="train",
        num_context=30,
        num_test=1,
        num_trials=9,
        action_range=10,
        process_data=1,
    )
    print(len(dataset))