import pickle as pkl
import os
import random

import torch
import torch.utils.data
import torch.nn.functional as F

import dgl
from dgl.data import DGLDataset


def collate_function_seq(batch):
    # Batch the demonstration graphs/actions and the single query frame/target
    # action of every sample into one training batch.
    dem_frames = dgl.batch([item[0] for item in batch])
    dem_actions = torch.stack([item[1] for item in batch])
    dem_lens = [item[2] for item in batch]
    query_frames = dgl.batch([item[3] for item in batch])
    target_actions = torch.stack([item[4] for item in batch])
    return [dem_frames, dem_actions, dem_lens, query_frames, target_actions]

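# A minimal usage sketch (not part of the original script): the collate functions
# here are intended as `collate_fn` for a torch DataLoader. The dataset path and
# batch size below are hypothetical; ToMnetDGLDataset is defined further down.
#
#   train_loader = torch.utils.data.DataLoader(
#       ToMnetDGLDataset(path='/path/to/graphs/', types=['preference'], mode='train'),
#       batch_size=16, shuffle=True, collate_fn=collate_function_seq)
#   dem_frames, dem_actions, dem_lens, query_frames, target_actions = next(iter(train_loader))
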
def collate_function_seq_test(batch):
    # Test-time collate: indexing [0] keeps only the first sample of the batch,
    # i.e. this collate assumes the test DataLoader uses batch_size=1.
    dem_expected_states = dgl.batch([item[0] for item in batch][0])
    dem_expected_actions = torch.stack([item[1] for item in batch][0]).unsqueeze(dim=0)
    dem_expected_lens = [item[2] for item in batch]
    dem_unexpected_states = dgl.batch([item[3] for item in batch][0])
    dem_unexpected_actions = torch.stack([item[4] for item in batch][0]).unsqueeze(dim=0)
    dem_unexpected_lens = [item[5] for item in batch]
    query_expected_frames = dgl.batch([item[6] for item in batch])
    target_expected_actions = torch.stack([item[7] for item in batch])
    query_unexpected_frames = dgl.batch([item[8] for item in batch])
    target_unexpected_actions = torch.stack([item[9] for item in batch])
    return [
        dem_expected_states, dem_expected_actions, dem_expected_lens,
        dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens,
        query_expected_frames, target_expected_actions,
        query_unexpected_frames, target_unexpected_actions
    ]

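# collate_function_mental also carries the frames/actions observed so far in the
# current test trial (past_test_*); ToMnetDGLDatasetMental.get_test below zero-pads
# these past actions to a fixed length so they stack across the batch.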
def collate_function_mental(batch):
    dem_frames = dgl.batch([item[0] for item in batch])
    dem_actions = torch.stack([item[1] for item in batch])
    dem_lens = [item[2] for item in batch]
    past_test_frames = dgl.batch([item[3] for item in batch])
    past_test_actions = torch.stack([item[4] for item in batch])
    past_test_len = [item[5] for item in batch]
    query_frames = dgl.batch([item[6] for item in batch])
    target_actions = torch.stack([item[7] for item in batch])
    return [dem_frames, dem_actions, dem_lens, past_test_frames, past_test_actions, past_test_len, query_frames, target_actions]


class ToMnetDGLDataset(DGLDataset):
    """
    Training dataset class.
    """
    def __init__(self, path, types=None, mode="train"):
        self.path = path
        self.types = types
        self.mode = mode
        print('Mode:', self.mode)

        if self.mode == 'train':
            if len(self.types) == 4:
                self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/'
                #self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_global/'
                #self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_local/'
            elif len(self.types) == 3:
                self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper() + '/'
                print(self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper())
            elif len(self.types) == 2:
                self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + '/'
                print(self.types[0][0].upper() + self.types[1][0].upper())
            elif len(self.types) == 1:
                self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + '/'
            else:
                raise ValueError('Number of types must be between 1 and 4.')
        elif self.mode == 'val':
            assert len(self.types) == 1
            self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/' + self.types[0] + '/'
            #self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_global/' + self.types[0] + '/'
            #self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_local/' + self.types[0] + '/'
        else:
            raise ValueError('Unsupported mode: ' + str(self.mode))

    def get_test(self, states, actions):
        # states is a batched graph: unbatch it, pick one sub-graph at random
        # and select the corresponding action
        frame_graphs = dgl.unbatch(states)
        trial_len = len(frame_graphs)
        query_idx = random.randint(0, trial_len - 1)
        query_graph = frame_graphs[query_idx]
        target_action = actions[query_idx]
        return query_graph, target_action

    def __getitem__(self, idx):
        with open(self.path + str(idx) + '.pkl', 'rb') as f:
            states, actions, lens, _ = pkl.load(f)
        # shuffle the trials
        ziplist = list(zip(states, actions, lens))
        random.shuffle(ziplist)
        states, actions, lens = zip(*ziplist)
        # convert tuples to lists
        states, actions, lens = [*states], [*actions], [*lens]
        # pick the last trial as the test trial and sample a random query frame from it
        test_s, test_a = self.get_test(states[-1], actions[-1])
        dem_s = states[:-1]
        dem_a = actions[:-1]
        dem_lens = lens[:-1]
        dem_s = dgl.batch(dem_s)
        dem_a = torch.stack(dem_a)
        return dem_s, dem_a, dem_lens, test_s, test_a

    def __len__(self):
        return len(os.listdir(self.path))

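# A minimal sketch of retrieving one training episode (the path is hypothetical):
#
#   ds = ToMnetDGLDataset(path='/path/to/bib_train/graphs/all_tasks/',
#                         types=['preference'], mode='train')
#   dem_s, dem_a, dem_lens, test_s, test_a = ds[0]
#
# dem_* hold the demonstration trials; test_s/test_a are the query frame and its target action.
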
class TestToMnetDGLDataset(DGLDataset):
    """
    Testing dataset class.
    """
    def __init__(self, path, task_type=None, mode="test"):
        self.path = path
        self.type = task_type
        self.mode = mode
        print('Mode:', self.mode)

        if self.mode == 'test':
            self.path = self.path + '_dgl_hetero_nobound_4feats/' + self.type + '/'
            #self.path = self.path + '_dgl_hetero_nobound_4feats_global/' + self.type + '/'
            #self.path = self.path + '_dgl_hetero_nobound_4feats_local/' + self.type + '/'
        else:
            raise ValueError('Unsupported mode: ' + str(self.mode))

    def __getitem__(self, idx):
        with open(self.path + str(idx) + '.pkl', 'rb') as f:
            dem_expected_states, dem_expected_actions, dem_expected_lens, _, \
                dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, _, \
                query_expected_frames, target_expected_actions, \
                query_unexpected_frames, target_unexpected_actions = pkl.load(f)
        # every episode contains 8 demonstration trials
        assert len(dem_expected_states) == 8
        assert len(dem_expected_actions) == 8
        assert len(dem_expected_lens) == 8
        assert len(dem_unexpected_states) == 8
        assert len(dem_unexpected_actions) == 8
        assert len(dem_unexpected_lens) == 8
        assert len(dgl.unbatch(query_expected_frames)) == target_expected_actions.size()[0]
        assert len(dgl.unbatch(query_unexpected_frames)) == target_unexpected_actions.size()[0]
        # ignore n_nodes
        return dem_expected_states, dem_expected_actions, dem_expected_lens, \
            dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, \
            query_expected_frames, target_expected_actions, \
            query_unexpected_frames, target_unexpected_actions

    def __len__(self):
        return len(os.listdir(self.path))

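# Evaluation sketch (hypothetical path): collate_function_seq_test keeps only the
# first element of each batch, so the test loader is assumed to use batch_size=1.
#
#   test_loader = torch.utils.data.DataLoader(
#       TestToMnetDGLDataset(path='/path/to/test_graphs', task_type='preference'),
#       batch_size=1, shuffle=False, collate_fn=collate_function_seq_test)
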
class ToMnetDGLDatasetUndersample(DGLDataset):
    """
    Training dataset class for the behavior cloning mlp model.
    Undersampled variant: the first 3175 episodes are skipped.
    """
    def __init__(self, path, types=None, mode="train"):
        self.path = path
        self.types = types
        self.mode = mode
        print('Mode:', self.mode)

        if self.mode == 'train':
            if len(self.types) == 4:
                self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/'
            elif len(self.types) == 3:
                self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper() + '/'
                print(self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper())
            elif len(self.types) == 2:
                self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + '/'
                print(self.types[0][0].upper() + self.types[1][0].upper())
            elif len(self.types) == 1:
                self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + '/'
            else:
                raise ValueError('Number of types must be between 1 and 4.')
        elif self.mode == 'val':
            assert len(self.types) == 1
            self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/' + self.types[0] + '/'
        else:
            raise ValueError('Unsupported mode: ' + str(self.mode))
        print('Undersampled dataset!')

    def get_test(self, states, actions):
        # states is a batched graph: unbatch it, pick one sub-graph at random
        # and select the corresponding action
        frame_graphs = dgl.unbatch(states)
        trial_len = len(frame_graphs)
        query_idx = random.randint(0, trial_len - 1)
        query_graph = frame_graphs[query_idx]
        target_action = actions[query_idx]
        return query_graph, target_action

    def __getitem__(self, idx):
        # skip the first 3175 episodes (undersampling offset)
        idx = idx + 3175
        with open(self.path + str(idx) + '.pkl', 'rb') as f:
            states, actions, lens, _ = pkl.load(f)
        # shuffle the trials
        ziplist = list(zip(states, actions, lens))
        random.shuffle(ziplist)
        states, actions, lens = zip(*ziplist)
        # convert tuples to lists
        states, actions, lens = [*states], [*actions], [*lens]
        # pick the last trial as the test trial and sample a random query frame from it
        test_s, test_a = self.get_test(states[-1], actions[-1])
        dem_s = states[:-1]
        dem_a = actions[:-1]
        dem_lens = lens[:-1]
        dem_s = dgl.batch(dem_s)
        dem_a = torch.stack(dem_a)
        return dem_s, dem_a, dem_lens, test_s, test_a

    def __len__(self):
        return len(os.listdir(self.path)) - 3175


class ToMnetDGLDatasetMental(DGLDataset):
    """
    Training dataset class.
    """
    def __init__(self, path, types=None, mode="train"):
        self.path = path
        self.types = types
        self.mode = mode
        print('Mode:', self.mode)

        if self.mode == 'train':
            if len(self.types) == 4:
                self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/'
            elif len(self.types) == 3:
                self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper() + '/'
                print(self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper())
            elif len(self.types) == 2:
                self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + '/'
                print(self.types[0][0].upper() + self.types[1][0].upper())
            elif len(self.types) == 1:
                self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + '/'
            else:
                raise ValueError('Number of types must be between 1 and 4.')
        elif self.mode == 'val':
            assert len(self.types) == 1
            self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/' + self.types[0] + '/'
        else:
            raise ValueError('Unsupported mode: ' + str(self.mode))

    def get_test(self, states, actions):
        """
        return: past_test_graphs, past_test_actions, past_test_len (query_idx), test_graph, test_action
        """
        frame_graphs = dgl.unbatch(states)
        trial_len = len(frame_graphs)
        query_idx = random.randint(0, trial_len - 1)
        test_graph = frame_graphs[query_idx]
        test_action = actions[query_idx]
        if query_idx > 0:
            if query_idx == 1:
                # a single past frame: no batching needed
                past_test_graphs = frame_graphs[0]
                past_test_actions = actions[:query_idx]
                # zero-pad the past-action sequence to a fixed length of 41
                past_test_actions = F.pad(past_test_actions, (0, 0, 0, 41 - query_idx), 'constant', 0)
                return past_test_graphs, past_test_actions, query_idx, test_graph, test_action
            else:
                past_test_graphs = frame_graphs[:query_idx]
                past_test_actions = actions[:query_idx]
                past_test_graphs = dgl.batch(past_test_graphs)
                # zero-pad the past-action sequence to a fixed length of 41
                past_test_actions = F.pad(past_test_actions, (0, 0, 0, 41 - query_idx), 'constant', 0)
                return past_test_graphs, past_test_actions, query_idx, test_graph, test_action
        else:
            test_action_ = F.pad(test_action.unsqueeze(0), (0, 0, 0, 41 - 1), 'constant', 0)
            # NOTE: since there are no past frames, return the test frame and action and query_idx=1;
            # not sure what would be a better alternative
            return test_graph, test_action_, 1, test_graph, test_action

    def __getitem__(self, idx):
        with open(self.path + str(idx) + '.pkl', 'rb') as f:
            states, actions, lens, _ = pkl.load(f)
        # shuffle the trials
        ziplist = list(zip(states, actions, lens))
        random.shuffle(ziplist)
        states, actions, lens = zip(*ziplist)
        # convert tuples to lists
        states, actions, lens = [*states], [*actions], [*lens]
        # pick the last trial as the test trial and sample a random query frame from it
        past_test_s, past_test_a, past_test_len, test_s, test_a = self.get_test(states[-1], actions[-1])
        dem_s = states[:-1]
        dem_a = actions[:-1]
        dem_lens = lens[:-1]
        dem_s = dgl.batch(dem_s)
        dem_a = torch.stack(dem_a)
        return dem_s, dem_a, dem_lens, past_test_s, past_test_a, past_test_len, test_s, test_a

    def __len__(self):
        return len(os.listdir(self.path))


# --------------------------------------------------------------------------------------------------------

if __name__ == "__main__":

    types = [
        'preference', 'multi_agent', 'inaccessible_goal',
        'efficiency_irrational', 'efficiency_time', 'efficiency_path',
        'instrumental_no_barrier', 'instrumental_blocking_barrier', 'instrumental_inconsequential_barrier'
    ]

    mental_dataset = ToMnetDGLDatasetMental(
        path='/datasets/external/bib_train/graphs/all_tasks/',
        types=['instrumental_action'],
        mode='train'
    )
    dem_frames, dem_actions, dem_lens, past_test_frames, past_test_actions, past_test_len, test_frame, test_action = mental_dataset.__getitem__(99)
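
    # A minimal sketch (not in the original script) of wrapping the dataset in a
    # DataLoader with its matching collate function; the batch size is hypothetical.
    #
    #   mental_loader = torch.utils.data.DataLoader(
    #       mental_dataset, batch_size=16, shuffle=True,
    #       collate_fn=collate_function_mental)
    #   batch = next(iter(mental_loader))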
    breakpoint()