IRENE/utils/build_graphs.py

116 lines
4.8 KiB
Python
Raw Normal View History

2024-02-01 15:40:47 +01:00
import sys
sys.path.append('/projects/bortoletto/icml2023_matteo/utils')
from dataset import TransitionDataset, TestTransitionDatasetSequence
import multiprocessing as mp
import argparse
import pickle as pkl
import os
# Command-line interface. Arguments are parsed at import time so that
# NUM_PROCESSES and MODE exist as module globals, which is where
# generate_files() and the __main__ section below read them from.
parser = argparse.ArgumentParser()
# When --cpus is omitted the value is None; it is passed unchanged to
# multiprocessing.Pool(processes=...).
parser.add_argument('--cpus', type=int,
                    help='Number of processes')
# `choices` rejects an unsupported mode at parse time instead of after the
# dataset has been constructed. All three modes handled below are listed.
parser.add_argument('--mode', type=str, choices=('train', 'val', 'test'),
                    help='Train (train), validation (val) or test (test)')
args = parser.parse_args()
NUM_PROCESSES = args.cpus
MODE = args.mode
def generate_files(idx):
    """Pickle item ``idx`` of the global ``dataset`` to ``PATH + str(idx) + '.pkl'``.

    Reads the module globals ``PATH`` (output directory prefix), ``MODE`` and
    ``dataset``. If the target file already exists the index is skipped, so an
    interrupted run can be resumed.

    The item is stored as a plain list of its fields:

    * ``train`` / ``val``: 4 fields -- states, actions, lens, n_nodes.
    * ``test``: 12 fields -- dem_expected_{states, actions, lens, nodes},
      dem_unexpected_{states, actions, lens, nodes}, query_expected_frames,
      target_expected_actions, query_unexpected_frames,
      target_unexpected_actions.

    Args:
        idx: Index of the dataset item to export.

    Raises:
        ValueError: If ``MODE`` is not ``train``, ``val`` or ``test``, or if the
            item does not have the number of fields expected for ``MODE``.
    """
    print('Generating idx', idx)
    out_path = PATH + str(idx) + '.pkl'
    if os.path.exists(out_path):
        print('Index', idx, 'skipped.')
        return

    if MODE == 'train' or MODE == 'val':
        expected_fields = 4
    elif MODE == 'test':
        expected_fields = 12
    else:
        raise ValueError('MODE can be only train, val or test.')

    fields = list(dataset[idx])
    if len(fields) != expected_fields:
        raise ValueError(
            'Expected ' + str(expected_fields) + ' fields for idx ' + str(idx)
            + ', got ' + str(len(fields)) + '.'
        )

    # Write to a temporary file and rename it into place. A crash or a failed
    # dump must not leave a truncated '<idx>.pkl' behind, because the
    # existence check above would then skip that index on every later run.
    tmp_path = out_path + '.' + str(os.getpid()) + '.tmp'
    try:
        with open(tmp_path, 'wb') as f:
            pkl.dump(fields, f)
        os.replace(tmp_path, out_path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    print(out_path + ' saved.')
def _init_worker(out_dir, ds):
    """Bind the module globals ``PATH`` and ``dataset`` read by generate_files()."""
    global PATH, dataset
    PATH = out_dir
    dataset = ds


def _ensure_dir(out_dir):
    """Create ``out_dir`` (and any missing parents) if it does not exist yet."""
    if not os.path.exists(out_dir):
        # exist_ok guards against another process creating it in between.
        os.makedirs(out_dir, exist_ok=True)
        print(out_dir, 'directory created.')


def _run_pool(out_dir, ds, label=None):
    """Run generate_files() over every index of ``ds`` with a worker pool.

    Args:
        out_dir: Output directory prefix (with trailing slash) for the pickles.
        ds: Dataset to export; must support ``len()`` and indexing.
        label: Optional task name included in the progress message.
    """
    # Set the globals in this process too, so that forked workers inherit them
    # exactly as before. The initializer additionally sets them inside each
    # worker. NOTE(review): under a non-fork start method the dataset would
    # have to be picklable for the initializer arguments -- confirm if needed.
    _init_worker(out_dir, ds)
    # The context manager tears the pool down even if map() raises.
    with mp.Pool(processes=NUM_PROCESSES,
                 initializer=_init_worker, initargs=(out_dir, ds)) as pool:
        if label is None:
            print('Starting graph generation with', NUM_PROCESSES, 'processes...')
        else:
            print('Starting', label, 'graph generation with', NUM_PROCESSES, 'processes...')
        pool.map(generate_files, range(len(ds)))


if __name__ == "__main__":
    if MODE == 'train':
        print('TRAIN MODE')
        out_dir = '/datasets/external/bib_train/graphs/all_tasks/train_dgl_hetero_nobound_4feats/'
        _ensure_dir(out_dir)
        train_set = TransitionDataset(
            path='/datasets/external/bib_train/',
            types=['instrumental_action', 'multi_agent', 'preference', 'single_object'],
            mode="train",
            max_len=30,
            num_test=1,
            num_trials=9,
            action_range=10,
            process_data=0
        )
        _run_pool(out_dir, train_set)
    elif MODE == 'val':
        print('VALIDATION MODE')
        task_types = ['multi_agent', 'instrumental_action', 'preference', 'single_object']
        # One output sub-directory and one dataset per task type.
        for task_type in task_types:
            out_dir = '/datasets/external/bib_train/graphs/all_tasks/val_dgl_hetero_nobound_4feats/' + task_type + '/'
            _ensure_dir(out_dir)
            val_set = TransitionDataset(
                path='/datasets/external/bib_train/',
                types=[task_type],
                mode="val",
                max_len=30,
                num_test=1,
                num_trials=9,
                action_range=10,
                process_data=0
            )
            _run_pool(out_dir, val_set, label=task_type)
    elif MODE == 'test':
        print('TEST MODE')
        task_types = [
            'preference', 'multi_agent', 'inaccessible_goal',
            'efficiency_irrational', 'efficiency_time', 'efficiency_path',
            'instrumental_no_barrier', 'instrumental_blocking_barrier', 'instrumental_inconsequential_barrier'
        ]
        # One output sub-directory and one dataset per evaluation task.
        for task_type in task_types:
            out_dir = '/datasets/external/bib_evaluation_1_1/graphs/all_tasks_dgl_hetero_nobound_4feats/' + task_type + '/'
            _ensure_dir(out_dir)
            test_set = TestTransitionDatasetSequence(
                path='/datasets/external/bib_evaluation_1_1/',
                task_type=task_type,
                mode="test",
                num_test=1,
                num_trials=9,
                action_range=10,
                process_data=0,
                max_len=30
            )
            _run_pool(out_dir, test_set, label=task_type)
    else:
        raise ValueError('MODE can be only train, val or test.')