"""Pre-compute dataset items in parallel and cache each one as a pickle file.

Usage:
    python <this script> --cpus N --mode {train,val,test}

For the selected mode the script builds the corresponding dataset object(s),
then uses a multiprocessing pool to fetch every item and write it to
``<output dir>/<idx>.pkl``. Items whose file already exists are skipped, so an
interrupted run can simply be restarted.
"""
import argparse
import multiprocessing as mp
import os
import pickle as pkl
import sys

sys.path.append('/projects/bortoletto/icml2023_matteo/utils')
from dataset import TransitionDataset, TestTransitionDatasetSequence

# Instantiate the parser
parser = argparse.ArgumentParser()
parser.add_argument('--cpus', type=int, help='Number of processes')
parser.add_argument('--mode', type=str,
                    help='Train (train), validation (val) or test (test)')
args = parser.parse_args()

# When --cpus is omitted this is None and mp.Pool falls back to os.cpu_count().
NUM_PROCESSES = args.cpus
MODE = args.mode

# Per-process state read by generate_files(). Both are installed in every
# worker by _init_worker(), so they are available regardless of the
# multiprocessing start method (fork, spawn or forkserver).
PATH = None
dataset = None

# Names, in order, of the values a dataset item is expected to contain for
# each mode. Only the count is checked; the names document the layout of the
# list stored in each pickle file.
_TRAIN_VAL_FIELDS = ('states', 'actions', 'lens', 'n_nodes')
_TEST_FIELDS = (
    'dem_expected_states', 'dem_expected_actions',
    'dem_expected_lens', 'dem_expected_nodes',
    'dem_unexpected_states', 'dem_unexpected_actions',
    'dem_unexpected_lens', 'dem_unexpected_nodes',
    'query_expected_frames', 'target_expected_actions',
    'query_unexpected_frames', 'target_unexpected_actions',
)
_FIELDS_BY_MODE = {
    'train': _TRAIN_VAL_FIELDS,
    'val': _TRAIN_VAL_FIELDS,
    'test': _TEST_FIELDS,
}


def _init_worker(path, dataset_obj):
    """Install the output directory and dataset as globals of this process."""
    global PATH, dataset
    PATH = path
    dataset = dataset_obj


def _dump_atomically(payload, target):
    """Pickle ``payload`` to ``target`` without ever exposing a partial file.

    The data is written to a temporary file in the same directory and then
    moved into place with os.replace(). This matters because generate_files()
    treats any existing ``<idx>.pkl`` as finished work: a file truncated by an
    interrupted run would otherwise be skipped forever.
    """
    tmp_path = target + '.' + str(os.getpid()) + '.tmp'
    try:
        with open(tmp_path, 'wb') as f:
            pkl.dump(payload, f)
        os.replace(tmp_path, target)
    except BaseException:
        # Best-effort cleanup of the partial temporary file, then re-raise.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


def generate_files(idx):
    """Fetch item ``idx`` from the global ``dataset`` and cache it on disk.

    The item is stored as a list in ``PATH + str(idx) + '.pkl'``. Nothing is
    done if that file already exists.

    Args:
        idx: Index of the dataset item to generate.

    Raises:
        ValueError: If MODE is not train, val or test, or if the item does not
            contain the number of values expected for MODE.
    """
    print('Generating idx', idx)
    target = PATH + str(idx) + '.pkl'
    if os.path.exists(target):
        print('Index', idx, 'skipped.')
        return
    if MODE not in _FIELDS_BY_MODE:
        raise ValueError('MODE can be only train, val or test.')
    expected = _FIELDS_BY_MODE[MODE]
    fields = list(dataset[idx])
    if len(fields) != len(expected):
        raise ValueError(
            'Expected ' + str(len(expected)) + ' values for item ' + str(idx)
            + ' in ' + MODE + ' mode, got ' + str(len(fields)) + '.'
        )
    _dump_atomically(fields, target)
    print(target + ' saved.')


def _ensure_dir(path):
    """Create ``path`` (and parents) if missing and report the creation."""
    if not os.path.exists(path):
        # exist_ok guards against another process creating it in between.
        os.makedirs(path, exist_ok=True)
        print(path, 'directory created.')


def _run_pool(path, dataset_obj, *label):
    """Generate every item of ``dataset_obj`` into ``path`` with a pool.

    Args:
        path: Output directory prefix (must end with a path separator).
        dataset_obj: Object supporting ``len()`` and integer indexing. With a
            start method other than fork it is sent to the workers and must
            therefore be picklable.
        *label: Extra words printed between 'Starting' and 'graph generation'.
    """
    print('Starting', *label, 'graph generation with', NUM_PROCESSES,
          'processes...')
    with mp.Pool(processes=NUM_PROCESSES, initializer=_init_worker,
                 initargs=(path, dataset_obj)) as pool:
        # map() blocks until every index has been processed, so leaving the
        # with-block only tears down idle workers.
        pool.map(generate_files, range(len(dataset_obj)))


def main():
    """Build the dataset(s) for MODE and generate their pickle files."""
    if MODE == 'train':
        print('TRAIN MODE')
        path = '/datasets/external/bib_train/graphs/all_tasks/train_dgl_hetero_nobound_4feats/'
        _ensure_dir(path)
        train_set = TransitionDataset(
            path='/datasets/external/bib_train/',
            types=['instrumental_action', 'multi_agent', 'preference', 'single_object'],
            mode="train",
            max_len=30,
            num_test=1,
            num_trials=9,
            action_range=10,
            process_data=0
        )
        _run_pool(path, train_set)
    elif MODE == 'val':
        print('VALIDATION MODE')
        types = ['multi_agent', 'instrumental_action', 'preference', 'single_object']
        # One output directory and one dataset per task type.
        for task in types:
            path = '/datasets/external/bib_train/graphs/all_tasks/val_dgl_hetero_nobound_4feats/' + task + '/'
            _ensure_dir(path)
            val_set = TransitionDataset(
                path='/datasets/external/bib_train/',
                types=[task],
                mode="val",
                max_len=30,
                num_test=1,
                num_trials=9,
                action_range=10,
                process_data=0
            )
            _run_pool(path, val_set, task)
    elif MODE == 'test':
        print('TEST MODE')
        types = [
            'preference', 'multi_agent', 'inaccessible_goal',
            'efficiency_irrational', 'efficiency_time', 'efficiency_path',
            'instrumental_no_barrier', 'instrumental_blocking_barrier',
            'instrumental_inconsequential_barrier'
        ]
        # One output directory and one dataset per task type.
        for task in types:
            path = '/datasets/external/bib_evaluation_1_1/graphs/all_tasks_dgl_hetero_nobound_4feats/' + task + '/'
            _ensure_dir(path)
            test_set = TestTransitionDatasetSequence(
                path='/datasets/external/bib_evaluation_1_1/',
                task_type=task,
                mode="test",
                num_test=1,
                num_trials=9,
                action_range=10,
                process_data=0,
                max_len=30
            )
            _run_pool(path, test_set, task)
    else:
        raise ValueError('MODE can be only train, val or test.')


if __name__ == "__main__":
    main()