parent a333481e05
commit de0bea7508
18 changed files with 3150 additions and 2 deletions

tom/dataset.py (new file, 310 lines)

@@ -0,0 +1,310 @@
import pickle as pkl
import os
import torch
import torch.utils.data
import torch.nn.functional as F
import dgl
import random
from dgl.data import DGLDataset


def collate_function_seq(batch):
    # Batch the demonstration graphs into one DGL graph and stack the
    # corresponding action tensors; the lengths stay a plain Python list.
    #dem_frames = torch.stack([item[0] for item in batch])
    dem_frames = dgl.batch([item[0] for item in batch])
    dem_actions = torch.stack([item[1] for item in batch])
    dem_lens = [item[2] for item in batch]
    #query_frames = torch.stack([item[3] for item in batch])
    query_frames = dgl.batch([item[3] for item in batch])
    target_actions = torch.stack([item[4] for item in batch])
    return [dem_frames, dem_actions, dem_lens, query_frames, target_actions]


def collate_function_seq_test(batch):
    # Test-time collate: effectively assumes batch_size == 1, since only the
    # first sample's demonstrations are used (hence the [0] indexing).
    dem_expected_states = dgl.batch([item[0] for item in batch][0])
    dem_expected_actions = torch.stack([item[1] for item in batch][0]).unsqueeze(dim=0)
    dem_expected_lens = [item[2] for item in batch]
    #print(dem_expected_actions.size())
    dem_unexpected_states = dgl.batch([item[3] for item in batch][0])
    dem_unexpected_actions = torch.stack([item[4] for item in batch][0]).unsqueeze(dim=0)
    dem_unexpected_lens = [item[5] for item in batch]
    query_expected_frames = dgl.batch([item[6] for item in batch])
    target_expected_actions = torch.stack([item[7] for item in batch])
    #print(target_expected_actions.size())
    query_unexpected_frames = dgl.batch([item[8] for item in batch])
    target_unexpected_actions = torch.stack([item[9] for item in batch])
    return [
        dem_expected_states, dem_expected_actions, dem_expected_lens,
        dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens,
        query_expected_frames, target_expected_actions,
        query_unexpected_frames, target_unexpected_actions
    ]


def collate_function_mental(batch):
    # Same layout as collate_function_seq, plus the past test frames/actions
    # used by the mental-state variant.
    dem_frames = dgl.batch([item[0] for item in batch])
    dem_actions = torch.stack([item[1] for item in batch])
    dem_lens = [item[2] for item in batch]
    past_test_frames = dgl.batch([item[3] for item in batch])
    past_test_actions = torch.stack([item[4] for item in batch])
    past_test_len = [item[5] for item in batch]
    query_frames = dgl.batch([item[6] for item in batch])
    target_actions = torch.stack([item[7] for item in batch])
    return [dem_frames, dem_actions, dem_lens, past_test_frames, past_test_actions, past_test_len, query_frames, target_actions]
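
For reference, a minimal toy sketch of the dgl.batch / dgl.unbatch round trip that these collate functions and the get_test helpers below rely on; the two small graphs here are made-up examples, not data from this repository:

    import dgl
    import torch

    g1 = dgl.graph((torch.tensor([0]), torch.tensor([1])), num_nodes=2)
    g2 = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2])), num_nodes=3)
    bg = dgl.batch([g1, g2])     # a single graph holding both: 5 nodes, 3 edges
    parts = dgl.unbatch(bg)      # recovers the original two graphs
    assert parts[0].num_nodes() == 2 and parts[1].num_nodes() == 3
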

class ToMnetDGLDataset(DGLDataset):
    """
    Training dataset class.
    """
    def __init__(self, path, types=None, mode="train"):
        self.path = path
        self.types = types
        self.mode = mode
        print('Mode:', self.mode)

        if self.mode == 'train':
            # The directory suffix encodes the initials of the selected task
            # types, e.g. three types starting with p, m, i give '_PMI'.
            if len(self.types) == 4:
                self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/'
                #self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_global/'
                #self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_local/'
            elif len(self.types) == 3:
                self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper() + '/'
                print(self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper())
            elif len(self.types) == 2:
                self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + '/'
                print(self.types[0][0].upper() + self.types[1][0].upper())
            elif len(self.types) == 1:
                self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + '/'
            else:
                raise ValueError('Expected between 1 and 4 task types.')
        elif self.mode == 'val':
            assert len(self.types) == 1
            self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/' + self.types[0] + '/'
            #self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_global/' + self.types[0] + '/'
            #self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_local/' + self.types[0] + '/'
        else:
            raise ValueError('Unsupported mode: ' + str(self.mode))

    def get_test(self, states, actions):
        # states is a batched graph -> unbatch it, pick one sub-graph at random
        # and select the corresponding action.
        frame_graphs = dgl.unbatch(states)
        trial_len = len(frame_graphs)
        query_idx = random.randint(0, trial_len - 1)
        query_graph = frame_graphs[query_idx]
        target_action = actions[query_idx]
        return query_graph, target_action

    def __getitem__(self, idx):
        with open(self.path + str(idx) + '.pkl', 'rb') as f:
            states, actions, lens, _ = pkl.load(f)
        # shuffle the trials
        ziplist = list(zip(states, actions, lens))
        random.shuffle(ziplist)
        states, actions, lens = zip(*ziplist)
        # convert tuples to lists
        states, actions, lens = [*states], [*actions], [*lens]
        # use the last trial as the query and pick a random frame from it
        test_s, test_a = self.get_test(states[-1], actions[-1])
        dem_s = states[:-1]
        dem_a = actions[:-1]
        dem_lens = lens[:-1]
        dem_s = dgl.batch(dem_s)
        dem_a = torch.stack(dem_a)
        return dem_s, dem_a, dem_lens, test_s, test_a

    def __len__(self):
        return len(os.listdir(self.path))
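
A minimal sketch of how this training dataset would typically be wired to collate_function_seq through a PyTorch DataLoader. The path matches the one used in the __main__ block below; the single task type, batch size and shuffle setting are illustrative assumptions:

    train_set = ToMnetDGLDataset(
        path='/datasets/external/bib_train/graphs/all_tasks/',
        types=['preference'],        # assumed choice of task type
        mode='train',
    )
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=16,               # assumed value, for illustration only
        shuffle=True,
        collate_fn=collate_function_seq,
    )
    dem_frames, dem_actions, dem_lens, query_frames, target_actions = next(iter(train_loader))
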

class TestToMnetDGLDataset(DGLDataset):
    """
    Testing dataset class.
    """
    def __init__(self, path, task_type=None, mode="test"):
        self.path = path
        self.type = task_type
        self.mode = mode
        print('Mode:', self.mode)

        if self.mode == 'test':
            self.path = self.path + '_dgl_hetero_nobound_4feats/' + self.type + '/'
            #self.path = self.path + '_dgl_hetero_nobound_4feats_global/' + self.type + '/'
            #self.path = self.path + '_dgl_hetero_nobound_4feats_local/' + self.type + '/'
        else:
            raise ValueError('Unsupported mode: ' + str(self.mode))

    def __getitem__(self, idx):
        with open(self.path + str(idx) + '.pkl', 'rb') as f:
            dem_expected_states, dem_expected_actions, dem_expected_lens, _, \
                dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, _, \
                query_expected_frames, target_expected_actions, \
                query_unexpected_frames, target_unexpected_actions = pkl.load(f)
        # each episode provides 8 expected and 8 unexpected demonstrations
        assert len(dem_expected_states) == 8
        assert len(dem_expected_actions) == 8
        assert len(dem_expected_lens) == 8
        assert len(dem_unexpected_states) == 8
        assert len(dem_unexpected_actions) == 8
        assert len(dem_unexpected_lens) == 8
        assert len(dgl.unbatch(query_expected_frames)) == target_expected_actions.size()[0]
        assert len(dgl.unbatch(query_unexpected_frames)) == target_unexpected_actions.size()[0]
        # ignore n_nodes
        return dem_expected_states, dem_expected_actions, dem_expected_lens, \
            dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens, \
            query_expected_frames, target_expected_actions, \
            query_unexpected_frames, target_unexpected_actions

    def __len__(self):
        return len(os.listdir(self.path))
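
A corresponding sketch for the test split. Because collate_function_seq_test only reads element [0] of each demonstration field, a batch size of 1 is the natural choice; the path prefix and task type are assumptions:

    test_set = TestToMnetDGLDataset(
        path='/datasets/external/bib_train/graphs/all_tasks/test',   # assumed prefix
        task_type='preference',                                      # assumed task type
        mode='test',
    )
    test_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=1,                # the collate function uses item [0] only
        shuffle=False,
        collate_fn=collate_function_seq_test,
    )
    (dem_e_states, dem_e_actions, dem_e_lens,
     dem_u_states, dem_u_actions, dem_u_lens,
     query_e_frames, target_e_actions,
     query_u_frames, target_u_actions) = next(iter(test_loader))
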

class ToMnetDGLDatasetUndersample(DGLDataset):
    """
    Training dataset class for the behavior cloning MLP model.
    """
    def __init__(self, path, types=None, mode="train"):
        self.path = path
        self.types = types
        self.mode = mode
        print('Mode:', self.mode)

        if self.mode == 'train':
            if len(self.types) == 4:
                self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/'
            elif len(self.types) == 3:
                self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper() + '/'
                print(self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper())
            elif len(self.types) == 2:
                self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + '/'
                print(self.types[0][0].upper() + self.types[1][0].upper())
            elif len(self.types) == 1:
                self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + '/'
            else:
                raise ValueError('Expected between 1 and 4 task types.')
        elif self.mode == 'val':
            assert len(self.types) == 1
            self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/' + self.types[0] + '/'
        else:
            raise ValueError('Unsupported mode: ' + str(self.mode))
        print('Undersampled dataset!')

    def get_test(self, states, actions):
        # states is a batched graph -> unbatch it, pick one sub-graph at random
        # and select the corresponding action.
        frame_graphs = dgl.unbatch(states)
        trial_len = len(frame_graphs)
        query_idx = random.randint(0, trial_len - 1)
        query_graph = frame_graphs[query_idx]
        target_action = actions[query_idx]
        return query_graph, target_action

    def __getitem__(self, idx):
        # offset the index so that the first 3175 files are never used
        idx = idx + 3175
        with open(self.path + str(idx) + '.pkl', 'rb') as f:
            states, actions, lens, _ = pkl.load(f)
        # shuffle the trials
        ziplist = list(zip(states, actions, lens))
        random.shuffle(ziplist)
        states, actions, lens = zip(*ziplist)
        # convert tuples to lists
        states, actions, lens = [*states], [*actions], [*lens]
        # use the last trial as the query and pick a random frame from it
        test_s, test_a = self.get_test(states[-1], actions[-1])
        dem_s = states[:-1]
        dem_a = actions[:-1]
        dem_lens = lens[:-1]
        dem_s = dgl.batch(dem_s)
        dem_a = torch.stack(dem_a)
        return dem_s, dem_a, dem_lens, test_s, test_a

    def __len__(self):
        return len(os.listdir(self.path)) - 3175
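
The undersampled variant above behaves like ToMnetDGLDataset except that item indices are shifted by 3175, so the first 3175 pickle files are never served; it drops into the same DataLoader / collate_function_seq setup sketched earlier. A short example under the same assumed path and task type:

    small_set = ToMnetDGLDatasetUndersample(
        path='/datasets/external/bib_train/graphs/all_tasks/',
        types=['preference'],        # assumed choice of task type
        mode='train',
    )
    print(len(small_set))            # directory size minus the 3175 skipped files
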

class ToMnetDGLDatasetMental(DGLDataset):
    """
    Training dataset class (mental-state variant).
    """
    def __init__(self, path, types=None, mode="train"):
        self.path = path
        self.types = types
        self.mode = mode
        print('Mode:', self.mode)

        if self.mode == 'train':
            if len(self.types) == 4:
                self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/'
            elif len(self.types) == 3:
                self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper() + '/'
                print(self.types[0][0].upper() + self.types[1][0].upper() + self.types[2][0].upper())
            elif len(self.types) == 2:
                self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + self.types[1][0].upper() + '/'
                print(self.types[0][0].upper() + self.types[1][0].upper())
            elif len(self.types) == 1:
                self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats_' + self.types[0][0].upper() + '/'
            else:
                raise ValueError('Expected between 1 and 4 task types.')
        elif self.mode == 'val':
            assert len(self.types) == 1
            self.path = self.path + self.mode + '_dgl_hetero_nobound_4feats/' + self.types[0] + '/'
        else:
            raise ValueError('Unsupported mode: ' + str(self.mode))

    def get_test(self, states, actions):
        """
        return: past_test_graphs, past_test_actions, past_test_len, test_graph, test_action
        """
        frame_graphs = dgl.unbatch(states)
        trial_len = len(frame_graphs)
        query_idx = random.randint(0, trial_len - 1)
        test_graph = frame_graphs[query_idx]
        test_action = actions[query_idx]
        if query_idx > 0:
            #past_test_actions = F.pad(past_test_actions, (0,0,0,41-query_idx), 'constant', 0)
            if query_idx == 1:
                # a single past frame: no need to batch one graph
                past_test_graphs = frame_graphs[0]
                past_test_actions = actions[:query_idx]
                past_test_actions = F.pad(past_test_actions, (0, 0, 0, 41 - query_idx), 'constant', 0)
                return past_test_graphs, past_test_actions, query_idx, test_graph, test_action
            else:
                past_test_graphs = frame_graphs[:query_idx]
                past_test_actions = actions[:query_idx]
                past_test_graphs = dgl.batch(past_test_graphs)
                # pad the past actions to a fixed length of 41 steps
                past_test_actions = F.pad(past_test_actions, (0, 0, 0, 41 - query_idx), 'constant', 0)
                return past_test_graphs, past_test_actions, query_idx, test_graph, test_action
        else:
            test_action_ = F.pad(test_action.unsqueeze(0), (0, 0, 0, 41 - 1), 'constant', 0)
            # NOTE: since there are no past frames, return the test frame and action and query_idx=1
            # not sure what would be a better alternative
            return test_graph, test_action_, 1, test_graph, test_action

    def __getitem__(self, idx):
        with open(self.path + str(idx) + '.pkl', 'rb') as f:
            states, actions, lens, _ = pkl.load(f)
        # shuffle the trials
        ziplist = list(zip(states, actions, lens))
        random.shuffle(ziplist)
        states, actions, lens = zip(*ziplist)
        states, actions, lens = [*states], [*actions], [*lens]
        # use the last trial for the query; get_test also returns the frames that precede it
        past_test_s, past_test_a, past_test_len, test_s, test_a = self.get_test(states[-1], actions[-1])
        dem_s = states[:-1]
        dem_a = actions[:-1]
        dem_lens = lens[:-1]
        dem_s = dgl.batch(dem_s)
        dem_a = torch.stack(dem_a)
        return dem_s, dem_a, dem_lens, past_test_s, past_test_a, past_test_len, test_s, test_a

    def __len__(self):
        return len(os.listdir(self.path))
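
Finally, a sketch pairing the mental-state dataset with collate_function_mental; the constructor arguments mirror the __main__ block below, while the loader settings are assumptions:

    mental_set = ToMnetDGLDatasetMental(
        path='/datasets/external/bib_train/graphs/all_tasks/',
        types=['instrumental_action'],
        mode='train',
    )
    mental_loader = torch.utils.data.DataLoader(
        mental_set,
        batch_size=16,               # assumed value, for illustration only
        shuffle=True,
        collate_fn=collate_function_mental,
    )
    (dem_frames, dem_actions, dem_lens,
     past_test_frames, past_test_actions, past_test_lens,
     query_frames, target_actions) = next(iter(mental_loader))
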

# --------------------------------------------------------------------------------------------------------


if __name__ == "__main__":

    types = [
        'preference', 'multi_agent', 'inaccessible_goal',
        'efficiency_irrational', 'efficiency_time', 'efficiency_path',
        'instrumental_no_barrier', 'instrumental_blocking_barrier', 'instrumental_inconsequential_barrier'
    ]

    mental_dataset = ToMnetDGLDatasetMental(
        path='/datasets/external/bib_train/graphs/all_tasks/',
        types=['instrumental_action'],
        mode='train'
    )
    dem_frames, dem_actions, dem_lens, past_test_frames, past_test_actions, past_test_len, test_frame, test_action = mental_dataset[99]
    breakpoint()