import json
import os

import torch
import torch.utils.data
from tqdm import tqdm


def index_data(json_list, path_list):
    """Build an index of (json file, frame number, action) tuples for every trial."""
    print(f'processing {len(json_list)} files')
    data_tuples = []
    for j, v in tqdm(zip(json_list, path_list), total=len(json_list)):
        with open(j, 'r') as fp:
            state = json.load(fp)
        ep_lens = [len(x) for x in state]
        past_len = 0
        for e, l in enumerate(ep_lens):
            data_tuples.append([])
            # skip the first 30 frames and the last 83 frames
            for f in range(30, l - 83):
                # the action is the agent's displacement between consecutive frames
                f0x, f0y = state[e][f]['agent'][0]
                f1x, f1y = state[e][f + 1]['agent'][0]
                dx = (f1x - f0x) / 2.
                dy = (f1y - f0y) / 2.
                action = [dx, dy]
                # data_tuples[-1].append((v, past_len + f, action))
                data_tuples[-1].append((j, past_len + f, action))
                # data_tuples entries are (json file, frame number, action)
            assert len(data_tuples[-1]) > 0
            past_len += l
    return data_tuples


class TransitionDataset(torch.utils.data.Dataset):
    """
    Training dataset class for the behavior cloning MLP model.

    Args:
        path: path to the dataset
        types: list of video types to include
        size: size of the frames to be returned
        mode: 'train' or 'val'; any other value uses all files
        num_context: number of context state-action pairs
        num_test: number of test state-action pairs
        num_trials: number of trials in an episode
        action_range: number of frames to skip; actions are accumulated over
            this many frames (displacement of the agent)
        process_data: whether to index the videos (skip if already processed)

    __getitem__ returns: (dem_frames, dem_actions, query_frames, target_actions)
        dem_frames: (num_context, 3, size, size)
        dem_actions: (num_context, 2)
        query_frames: (num_test, 3, size, size)
        target_actions: (num_test, 2)
    """

    def __init__(self, path, types=None, size=None, mode="train", num_context=30,
                 num_test=1, num_trials=9, action_range=10, process_data=0):
        self.path = path
        self.types = types
        self.size = size
        self.mode = mode
        self.num_trials = num_trials
        self.num_context = num_context
        self.num_test = num_test
        self.action_range = action_range
        self.ep_combs = self.num_trials * (self.num_trials - 2)  # 9P2 - 9
        self.eps = [[x, y] for x in range(self.num_trials)
                    for y in range(self.num_trials) if x != y]
        types_str = '_'.join(self.types)

        self.path_list = []
        self.json_list = []
        # collect video paths and json file paths for every requested type
        for t in types:
            print(f'reading files of type {t} in {mode}')
            paths = [os.path.join(self.path, t, x)
                     for x in os.listdir(os.path.join(self.path, t))
                     if x.endswith('.mp4')]
            jsons = [os.path.join(self.path, t, x)
                     for x in os.listdir(os.path.join(self.path, t))
                     if x.endswith('.json') and 'index' not in x]
            paths = sorted(paths)
            jsons = sorted(jsons)
            # 80/20 train/val split; any other mode uses all files
            if mode == 'train':
                self.path_list += paths[:int(0.8 * len(jsons))]
                self.json_list += jsons[:int(0.8 * len(jsons))]
            elif mode == 'val':
                self.path_list += paths[int(0.8 * len(jsons)):]
                self.json_list += jsons[int(0.8 * len(jsons)):]
            else:
                self.path_list += paths
                self.json_list += jsons

        self.data_tuples = []
        if process_data:
            # index the videos in the dataset directory; indexing is done once
            # to speed up the retrieval of videos.
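            # (the cached index lets later runs skip re-parsing every trial
            # json; see the `else` branch below, which reads it straight back)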
            # (frame index, action) tuples are stored
            self.data_tuples = index_data(self.json_list, self.path_list)  # tuples of frame index and action (displacement of agent)
            index_dict = {'data_tuples': self.data_tuples}
            with open(os.path.join(self.path, f'jindex_bib_{mode}_{types_str}.json'), 'w') as fp:
                json.dump(index_dict, fp)
        else:
            # read pre-indexed data
            with open(os.path.join(self.path, f'jindex_bib_{mode}_{types_str}.json'), 'r') as fp:
                index_dict = json.load(fp)
            self.data_tuples = index_dict['data_tuples']
        self.tot_trials = len(self.path_list) * 9  # 9 trials per video file

    def __getitem__(self, idx):
        # placeholder: frame/action retrieval is not implemented yet
        print('Empty')
        return

    def __len__(self):
        return self.tot_trials // self.num_trials


if __name__ == "__main__":
    dataset = TransitionDataset(path='/datasets/external/bib_train/',
                                types=['multi_agent', 'instrumental_action'],
                                # ['instrumental_action', 'multi_agent', 'preference', 'single_object'],
                                size=(84, 84),
                                mode="train",
                                num_context=30,
                                num_test=1,
                                num_trials=9,
                                action_range=10,
                                process_data=1)
    print(len(dataset))
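
# ---------------------------------------------------------------------------
# `__getitem__` above is still a stub. The sketch below shows one way the
# indexed (json file, frame number, action) tuples could be assembled into the
# (dem_frames, dem_actions, query_frames, target_actions) batch promised by
# the class docstring. It is only an illustration: the `read_frame` helper,
# the OpenCV dependency, and the assumption that each trial json shares its
# basename with an .mp4 are NOT part of the original pipeline.
# ---------------------------------------------------------------------------
def example_getitem(dataset, idx):
    """Hypothetical sketch of TransitionDataset.__getitem__ for video `idx`."""
    import random

    import cv2  # assumed dependency, not imported by the original module

    def read_frame(json_path, frame_idx):
        # assumed naming scheme: trial json and video share a basename
        video_path = json_path.replace('.json', '.mp4')
        cap = cv2.VideoCapture(video_path)
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
        ok, frame = cap.read()
        cap.release()
        if not ok:
            raise IOError(f'could not read frame {frame_idx} of {video_path}')
        frame = cv2.resize(frame, dataset.size)
        # HWC uint8 -> CHW float in [0, 1]
        return torch.from_numpy(frame.transpose(2, 0, 1)).float() / 255.

    # the index groups tuples per trial; take the num_trials trials of video idx
    trials = dataset.data_tuples[idx * dataset.num_trials:(idx + 1) * dataset.num_trials]
    pool = [t for trial in trials for t in trial]
    picks = random.sample(pool, dataset.num_context + dataset.num_test)

    frames = torch.stack([read_frame(j, f) for j, f, _ in picks])
    actions = torch.tensor([a for _, _, a in picks])
    return (frames[:dataset.num_context], actions[:dataset.num_context],
            frames[dataset.num_context:], actions[dataset.num_context:])

# usage (hypothetical): dem_f, dem_a, q_f, q_a = example_getitem(dataset, 0)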