115 lines
4.8 KiB
Python
115 lines
4.8 KiB
Python
import sys
|
|
sys.path.append('/projects/bortoletto/icml2023_matteo/utils')
|
|
from dataset import TransitionDataset, TestTransitionDatasetSequence
|
|
import multiprocessing as mp
|
|
import argparse
|
|
import pickle as pkl
|
|
import os
|
|
|
|
# Command-line interface: worker count and which dataset split to process.
parser = argparse.ArgumentParser(
    description='Pre-generate pickled graph samples for the transition datasets.'
)
parser.add_argument('--cpus', type=int,
                    help='Number of processes (default: one per CPU core)')
parser.add_argument('--mode', type=str,
                    choices=['train', 'val', 'test'],
                    help='Dataset split to generate: train, val or test')
args = parser.parse_args()

# NUM_PROCESSES may be None, in which case mp.Pool falls back to os.cpu_count().
NUM_PROCESSES = args.cpus
MODE = args.mode
|
|
|
|
def generate_files(idx):
    """Generate and pickle the dataset sample with index *idx*.

    Relies on module-level globals set up by the main script (and inherited
    by forked worker processes):
      PATH    -- output directory, with trailing slash
      MODE    -- 'train', 'val' or 'test'
      dataset -- dataset instance to read samples from
    Writes ``PATH + str(idx) + '.pkl'``; existing files are skipped so an
    interrupted run can be resumed.

    Raises:
        ValueError: if MODE is not one of 'train', 'val', 'test'.
    """
    print('Generating idx', idx)
    out_file = PATH + str(idx) + '.pkl'
    if os.path.exists(out_file):
        # Already generated on a previous run -- skip.
        print('Index', idx, 'skipped.')
        return
    if MODE in ('train', 'val'):
        # Train/val samples: states, actions, sequence lengths, node counts.
        states, actions, lens, n_nodes = dataset[idx]
        with open(out_file, 'wb') as f:
            pkl.dump([states, actions, lens, n_nodes], f)
    elif MODE == 'test':
        # Test samples pair an "expected" and an "unexpected" episode,
        # each with demo tensors plus query frames and target actions.
        (dem_expected_states, dem_expected_actions,
         dem_expected_lens, dem_expected_nodes,
         dem_unexpected_states, dem_unexpected_actions,
         dem_unexpected_lens, dem_unexpected_nodes,
         query_expected_frames, target_expected_actions,
         query_unexpected_frames, target_unexpected_actions) = dataset[idx]
        with open(out_file, 'wb') as f:
            pkl.dump([
                dem_expected_states, dem_expected_actions,
                dem_expected_lens, dem_expected_nodes,
                dem_unexpected_states, dem_unexpected_actions,
                dem_unexpected_lens, dem_unexpected_nodes,
                query_expected_frames, target_expected_actions,
                query_unexpected_frames, target_unexpected_actions], f)
    else:
        raise ValueError('MODE can be only train, val or test.')
    print(out_file, 'saved.')
|
|
|
|
if __name__ == "__main__":
    # PATH and dataset are deliberately module-level globals: worker processes
    # created by mp.Pool (fork start method) inherit them for generate_files.
    if MODE == 'train':
        print('TRAIN MODE')
        # All task families are mixed into one training set / one output folder.
        PATH = '/datasets/external/bib_train/graphs/all_tasks/train_dgl_hetero_nobound_4feats/'
        if not os.path.exists(PATH):
            os.makedirs(PATH)
            print(PATH, 'directory created.')
        dataset = TransitionDataset(
            path='/datasets/external/bib_train/',
            types=['instrumental_action', 'multi_agent', 'preference', 'single_object'],
            mode="train",
            max_len=30,
            num_test=1,
            num_trials=9,
            action_range=10,
            process_data=0
        )
        print('Starting graph generation with', NUM_PROCESSES, 'processes...')
        with mp.Pool(processes=NUM_PROCESSES) as pool:
            pool.map(generate_files, range(len(dataset)))
    elif MODE == 'val':
        print('VALIDATION MODE')
        # Validation graphs are generated per task type, one folder each.
        types = ['multi_agent', 'instrumental_action', 'preference', 'single_object']
        for task_type in types:
            PATH = '/datasets/external/bib_train/graphs/all_tasks/val_dgl_hetero_nobound_4feats/' + task_type + '/'
            if not os.path.exists(PATH):
                os.makedirs(PATH)
                print(PATH, 'directory created.')
            dataset = TransitionDataset(
                path='/datasets/external/bib_train/',
                types=[task_type],
                mode="val",
                max_len=30,
                num_test=1,
                num_trials=9,
                action_range=10,
                process_data=0
            )
            print('Starting', task_type, 'graph generation with', NUM_PROCESSES, 'processes...')
            with mp.Pool(processes=NUM_PROCESSES) as pool:
                pool.map(generate_files, range(len(dataset)))
    elif MODE == 'test':
        print('TEST MODE')
        # Evaluation tasks (BIB 1.1), generated per task type, one folder each.
        types = [
            'preference', 'multi_agent', 'inaccessible_goal',
            'efficiency_irrational', 'efficiency_time', 'efficiency_path',
            'instrumental_no_barrier', 'instrumental_blocking_barrier', 'instrumental_inconsequential_barrier'
        ]
        for task_type in types:
            PATH = '/datasets/external/bib_evaluation_1_1/graphs/all_tasks_dgl_hetero_nobound_4feats/' + task_type + '/'
            if not os.path.exists(PATH):
                os.makedirs(PATH)
                print(PATH, 'directory created.')
            dataset = TestTransitionDatasetSequence(
                path='/datasets/external/bib_evaluation_1_1/',
                task_type=task_type,
                mode="test",
                num_test=1,
                num_trials=9,
                action_range=10,
                process_data=0,
                max_len=30
            )
            print('Starting', task_type, 'graph generation with', NUM_PROCESSES, 'processes...')
            with mp.Pool(processes=NUM_PROCESSES) as pool:
                pool.map(generate_files, range(len(dataset)))
    else:
        raise ValueError('MODE can be only train, val or test.')
|