125 lines
5 KiB
Python
125 lines
5 KiB
Python
|
import json
|
||
|
import os
|
||
|
import torch
|
||
|
import torch.utils.data
|
||
|
from tqdm import tqdm
|
||
|
|
||
|
|
||
|
def index_data(json_list, path_list):
    """Build a frame-level index of (json_path, global_frame, action) tuples.

    Args:
        json_list: list of paths to episode-state json files. Each file holds a
            list of episodes; each episode is a list of per-frame dicts with an
            'agent' entry whose first element is the agent (x, y) position.
        path_list: list of video paths, parallel to ``json_list`` (currently
            unused beyond pairing; the json path is stored in the index).

    Returns:
        A list with one sub-list per episode; each sub-list contains
        ``(json_path, global_frame_index, [dx, dy])`` tuples, where the action
        is half the agent displacement between consecutive frames.
    """
    print(f'processing files {len(json_list)}')
    data_tuples = []
    # total= gives tqdm a length, since zip() alone has none
    for json_path, _video_path in tqdm(zip(json_list, path_list), total=len(json_list)):
        with open(json_path, 'r') as fp:
            state = json.load(fp)
        ep_lens = [len(ep) for ep in state]
        past_len = 0
        for e, ep_len in enumerate(ep_lens):
            data_tuples.append([])
            # skip first 30 frames and last 83 frames
            for frame in range(30, ep_len - 83):
                # action = agent displacement between consecutive frames, halved
                f0x, f0y = state[e][frame]['agent'][0]
                f1x, f1y = state[e][frame + 1]['agent'][0]
                action = [(f1x - f0x) / 2., (f1y - f0y) / 2.]
                data_tuples[-1].append((json_path, past_len + frame, action))
            # data_tuples[-1] = (json file, frame number, action) per frame;
            # fails for episodes shorter than 114 frames (empty range above)
            assert len(data_tuples[-1]) > 0
            past_len += ep_len
    return data_tuples
|
||
|
|
||
|
class TransitionDataset(torch.utils.data.Dataset):
    """Training dataset class for the behavior cloning mlp model.

    Args:
        path: root directory of the dataset; files live under ``path/<type>/``.
        types: list of video type subdirectories to include. ``None`` is
            treated as an empty list (the original crashed on ``None``).
        size: size of the frames to be returned (stored; unused in this snippet).
        mode: 'train' uses the first 80% of files, 'val' the remaining 20%,
            any other value uses all files.
        num_context: number of context state-action pairs.
        num_test: number of test state-action pairs.
        num_trials: number of trials in an episode.
        action_range: number of frames to skip; actions are combined over these
            number of frames (displacement) of the agent.
        process_data: truthy -> index the videos now and write the index json;
            falsy -> load the previously written index json.

    __getitem__:
        returns: (dem_frames, dem_actions, query_frames, target_actions)
        dem_frames: (num_context, 3, size, size)
        dem_actions: (num_context, 2)
        query_frames: (num_test, 3, size, size)
        target_actions: (num_test, 2)
    """

    def __init__(self, path, types=None, size=None, mode="train", num_context=30, num_test=1, num_trials=9,
                 action_range=10, process_data=0):
        self.path = path
        # guard: original did '_'.join(None) and crashed when types was omitted
        self.types = types if types is not None else []
        self.size = size
        self.mode = mode
        self.num_trials = num_trials
        self.num_context = num_context
        self.num_test = num_test
        self.action_range = action_range
        self.ep_combs = self.num_trials * (self.num_trials - 2)  # 9P2 - 9
        # all ordered pairs of distinct trial indices
        self.eps = [[x, y] for x in range(self.num_trials) for y in range(self.num_trials) if x != y]
        types_str = '_'.join(self.types)

        self.path_list = []
        self.json_list = []
        # collect video (.mp4) and annotation (.json) paths per type; sorted so
        # the 80/20 train/val split is deterministic across runs
        for t in self.types:
            print(f'reading files of type {t} in {mode}')
            type_dir = os.path.join(self.path, t)
            entries = os.listdir(type_dir)  # listed once, filtered twice
            paths = sorted(os.path.join(type_dir, x) for x in entries if x.endswith('.mp4'))
            jsons = sorted(os.path.join(type_dir, x) for x in entries
                           if x.endswith('.json') and 'index' not in x)

            # NOTE: split point derives from len(jsons) for BOTH lists, as in
            # the original — assumes paths and jsons are parallel
            split = int(0.8 * len(jsons))
            if mode == 'train':
                self.path_list += paths[:split]
                self.json_list += jsons[:split]
            elif mode == 'val':
                self.path_list += paths[split:]
                self.json_list += jsons[split:]
            else:
                self.path_list += paths
                self.json_list += jsons

        self.data_tuples = []
        index_file = os.path.join(self.path, f'jindex_bib_{mode}_{types_str}.json')
        if process_data:
            # index the videos in the dataset directory. This is done to speed
            # up the retrieval of videos. Frame index, action tuples are stored.
            self.data_tuples = index_data(self.json_list, self.path_list)
            # tuples of frame index and action (displacement of agent)
            with open(index_file, 'w') as fp:
                json.dump({'data_tuples': self.data_tuples}, fp)
        else:
            # read pre-indexed data
            with open(index_file, 'r') as fp:
                self.data_tuples = json.load(fp)['data_tuples']

        # 9 trials per video file — TODO confirm this should track num_trials
        self.tot_trials = len(self.path_list) * 9

    def __getitem__(self, idx):
        # placeholder: sample retrieval is not implemented in this snippet
        print('Empty')
        return

    def __len__(self):
        # one sample per episode: total trials / trials-per-episode
        return self.tot_trials // self.num_trials
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Smoke test: index the training split for two scene types and report the
    # number of samples produced.
    config = dict(
        path='/datasets/external/bib_train/',
        types=['multi_agent', 'instrumental_action'],  # ['instrumental_action', 'multi_agent', 'preference', 'single_object'],
        size=(84, 84),
        mode="train",
        num_context=30,
        num_test=1,
        num_trials=9,
        action_range=10,
        process_data=1,
    )
    dataset = TransitionDataset(**config)
    print(len(dataset))
|