# IRENE/tom/dataset.py
# 310 lines, 14 KiB, Python
# 2024-02-01 15:40:47 +01:00
import pickle as pkl
import os
import torch
import torch.utils.data
import torch.nn.functional as F
import dgl
import random
from dgl.data import DGLDataset
def collate_function_seq(batch):
    """
    Merge a list of training samples into one mini-batch.

    Every sample is read positionally:
        [0] demonstration graph, [1] demonstration actions, [2] demonstration
        lengths, [3] query-frame graph, [4] target action.

    Graph fields are merged with ``dgl.batch`` (not ``torch.stack``), tensor
    fields with ``torch.stack``; the lengths are simply gathered in a list.

    Returns:
        list: [dem_frames, dem_actions, dem_lens, query_frames, target_actions]
    """
    def field(position):
        # Gather one positional field across every sample of the batch.
        return [sample[position] for sample in batch]

    return [
        dgl.batch(field(0)),
        torch.stack(field(1)),
        field(2),
        dgl.batch(field(3)),
        torch.stack(field(4)),
    ]
def collate_function_seq_test(batch):
    """
    Merge evaluation samples (expected / unexpected pairs) into one mini-batch.

    Every sample is read positionally:
        [0..2] expected demonstrations (states, actions, lens),
        [3..5] unexpected demonstrations (states, actions, lens),
        [6], [7] expected query frames and their target actions,
        [8], [9] unexpected query frames and their target actions.

    NOTE: the demonstration states and actions are taken from the FIRST
    sample of the batch only, whereas lengths, query frames and targets are
    gathered over all samples. Presumably this collate function is meant to
    be used with a batch size of 1 -- confirm against the DataLoader setup.

    Returns:
        list: the ten collated fields, in the same order as listed above.
    """
    def field(position):
        # Gather one positional field across every sample of the batch.
        return [sample[position] for sample in batch]

    def first(position):
        # Same gathering as ``field`` but keep only the first sample's entry.
        return field(position)[0]

    # Demonstrations: first sample only; a leading batch axis is added to the
    # stacked actions.
    expected_dem_graph = dgl.batch(first(0))
    expected_dem_actions = torch.stack(first(1)).unsqueeze(dim=0)
    expected_dem_lens = field(2)

    unexpected_dem_graph = dgl.batch(first(3))
    unexpected_dem_actions = torch.stack(first(4)).unsqueeze(dim=0)
    unexpected_dem_lens = field(5)

    # Queries and targets: gathered over the whole batch.
    expected_query_graph = dgl.batch(field(6))
    expected_targets = torch.stack(field(7))

    unexpected_query_graph = dgl.batch(field(8))
    unexpected_targets = torch.stack(field(9))

    collated = [
        expected_dem_graph, expected_dem_actions, expected_dem_lens,
        unexpected_dem_graph, unexpected_dem_actions, unexpected_dem_lens,
        expected_query_graph, expected_targets,
        unexpected_query_graph, unexpected_targets,
    ]
    return collated
def collate_function_mental(batch):
    """
    Merge samples that carry the query trial's past frames into a mini-batch.

    Every sample is read positionally:
        [0] demonstration graph, [1] demonstration actions, [2] demonstration
        lengths, [3] past-frames graph of the query trial, [4] past actions of
        the query trial, [5] number of past frames, [6] query-frame graph,
        [7] target action.

    Graph fields are merged with ``dgl.batch``, tensor fields with
    ``torch.stack``; length fields are gathered in plain lists.

    Returns:
        list: [dem_frames, dem_actions, dem_lens, past_test_frames,
               past_test_actions, past_test_len, query_frames, target_actions]
    """
    def field(position):
        # Gather one positional field across every sample of the batch.
        return [sample[position] for sample in batch]

    return [
        dgl.batch(field(0)),
        torch.stack(field(1)),
        field(2),
        dgl.batch(field(3)),
        torch.stack(field(4)),
        field(5),
        dgl.batch(field(6)),
        torch.stack(field(7)),
    ]
class ToMnetDGLDataset(DGLDataset):
    """
    Training / validation dataset of pickled episodes.

    Each episode lives in ``<split dir>/<idx>.pkl``. The split directory is
    derived from ``path`` (used as a raw string prefix, so it must already end
    with a path separator), ``mode`` and ``types``:

        train, 4 types   -> <path>train_dgl_hetero_nobound_4feats/
        train, 1-3 types -> <path>train_dgl_hetero_nobound_4feats_<INITIALS>/
        val, 1 type      -> <path>val_dgl_hetero_nobound_4feats/<type>/

    where <INITIALS> is the upper-cased first letter of every type name, in
    the order given.

    NOTE(review): ``DGLDataset.__init__`` is not invoked; only ``__getitem__``
    and ``__len__`` defined here are relied upon.
    """

    def __init__(self, path, types=None, mode="train"):
        """
        Args:
            path: string prefix of the dataset root (concatenated as-is).
            types: non-empty sequence of task type names (1 to 4 entries for
                'train', exactly 1 for 'val').
            mode: 'train' or 'val'.

        Raises:
            ValueError: on an unsupported mode or an unsupported ``types``.
        """
        self.path = path
        self.types = types
        self.mode = mode
        print('Mode:', self.mode)
        self.path = self._resolve_split_dir()

    def _resolve_split_dir(self):
        """Return the directory (with trailing '/') holding the episode pickles."""
        if self.mode not in ('train', 'val'):
            raise ValueError(
                f"Unsupported mode {self.mode!r}; expected 'train' or 'val'."
            )
        if self.types is None or len(self.types) == 0:
            raise ValueError('`types` must be a non-empty sequence of task type names.')

        n_types = len(self.types)
        # Hand-toggled variants seen in earlier revisions of this file used the
        # suffixes '_4feats_global' and '_4feats_local' instead of '_4feats'.
        prefix = self.path + self.mode + '_dgl_hetero_nobound_4feats'

        if self.mode == 'val':
            # Explicit check instead of `assert`, which is stripped under -O.
            if n_types != 1:
                raise ValueError(
                    f"Mode 'val' expects exactly one type, got {n_types}."
                )
            return prefix + '/' + self.types[0] + '/'

        if n_types == 4:
            return prefix + '/'
        if n_types > 4:
            raise ValueError(
                f'Number of types must be between 1 and 4, got {n_types}.'
            )
        initials = ''.join(name[0].upper() for name in self.types)
        if n_types > 1:
            print(initials)
        return prefix + '_' + initials + '/'

    def get_test(self, states, actions):
        """
        Pick one random frame of a trial as the query.

        Args:
            states: batched graph holding one sub-graph per frame of the trial.
            actions: per-frame actions, indexable by frame position.

        Returns:
            (query_graph, target_action) for a uniformly drawn frame index.
        """
        frame_graphs = dgl.unbatch(states)
        query_idx = random.randint(0, len(frame_graphs) - 1)
        return frame_graphs[query_idx], actions[query_idx]

    def __getitem__(self, idx):
        """
        Load episode ``idx`` and split it into demonstrations and one query.

        The trials of the episode are shuffled; the last one after shuffling
        provides the query frame (see ``get_test``), the others become the
        demonstrations.

        Returns:
            dem_s: batched graph of all demonstration trials.
            dem_a: stacked demonstration actions (assumes every trial's action
                tensor has the same shape -- TODO confirm padding upstream).
            dem_lens: list with the length entry of every demonstration trial.
            test_s: graph of the query frame.
            test_a: action taken at the query frame.
        """
        # NOTE(security): pickle can execute arbitrary code while loading;
        # only point `path` at trusted data.
        with open(self.path + str(idx) + '.pkl', 'rb') as f:
            # The fourth stored field is not used here.
            states, actions, lens, _ = pkl.load(f)

        trials = list(zip(states, actions, lens))
        random.shuffle(trials)

        query_states, query_actions, _ = trials[-1]
        test_s, test_a = self.get_test(query_states, query_actions)

        demos = trials[:-1]
        dem_s = dgl.batch([trial[0] for trial in demos])
        dem_a = torch.stack([trial[1] for trial in demos])
        dem_lens = [trial[2] for trial in demos]
        return dem_s, dem_a, dem_lens, test_s, test_a

    def __len__(self):
        """Number of entries in the split directory (one file per episode assumed)."""
        return len(os.listdir(self.path))
class TestToMnetDGLDataset(DGLDataset):
    """
    Evaluation dataset of paired expected / unexpected episodes.

    Episodes are read from
    ``<path>_dgl_hetero_nobound_4feats/<task_type>/<idx>.pkl``; ``path`` is
    used as a raw string prefix.
    """

    def __init__(self, path, task_type=None, mode="test"):
        """
        Args:
            path: string prefix of the dataset root (concatenated as-is).
            task_type: name of the task sub-directory to read.
            mode: must be 'test'; any other value raises ``ValueError``.
        """
        self.path = path
        self.type = task_type
        self.mode = mode
        print('Mode:', self.mode)
        if self.mode != 'test':
            raise ValueError
        # Hand-toggled variants seen in earlier revisions used
        # '_dgl_hetero_nobound_4feats_global/' or '..._local/' here.
        self.path = self.path + '_dgl_hetero_nobound_4feats/' + self.type + '/'

    def __getitem__(self, idx):
        """
        Load episode ``idx`` and return its ten fields.

        The pickle stores twelve fields; the two discarded ones are, per the
        original comment, node counts. Each demonstration field must hold
        exactly 8 entries and each query graph must contain as many frames as
        its target-action tensor has rows.

        Returns:
            tuple: (dem_expected_states, dem_expected_actions,
                    dem_expected_lens, dem_unexpected_states,
                    dem_unexpected_actions, dem_unexpected_lens,
                    query_expected_frames, target_expected_actions,
                    query_unexpected_frames, target_unexpected_actions)
        """
        episode_file = self.path + str(idx) + '.pkl'
        with open(episode_file, 'rb') as handle:
            (exp_states, exp_actions, exp_lens, _,
             unexp_states, unexp_actions, unexp_lens, _,
             exp_query, exp_targets,
             unexp_query, unexp_targets) = pkl.load(handle)

        # Sanity checks on the stored episode layout.
        demo_fields = (
            exp_states, exp_actions, exp_lens,
            unexp_states, unexp_actions, unexp_lens,
        )
        for demo_field in demo_fields:
            assert len(demo_field) == 8
        assert len(dgl.unbatch(exp_query)) == exp_targets.size()[0]
        assert len(dgl.unbatch(unexp_query)) == unexp_targets.size()[0]

        return (
            exp_states, exp_actions, exp_lens,
            unexp_states, unexp_actions, unexp_lens,
            exp_query, exp_targets,
            unexp_query, unexp_targets,
        )

    def __len__(self):
        """Number of entries found in the episode directory."""
        entries = os.listdir(self.path)
        return len(entries)
class ToMnetDGLDatasetUndersample(DGLDataset):
    """
    Undersampled training dataset (per the original docstring: used for the
    behavior cloning mlp model).

    Identical in layout to the regular training dataset, except that the first
    ``start_idx`` episode files are skipped: item ``i`` is read from
    ``<split dir>/<i + start_idx>.pkl``.

    The split directory is derived from ``path`` (a raw string prefix),
    ``mode`` and ``types``:

        train, 4 types   -> <path>train_dgl_hetero_nobound_4feats/
        train, 1-3 types -> <path>train_dgl_hetero_nobound_4feats_<INITIALS>/
        val, 1 type      -> <path>val_dgl_hetero_nobound_4feats/<type>/

    NOTE(review): ``DGLDataset.__init__`` is not invoked; only ``__getitem__``
    and ``__len__`` defined here are relied upon.
    """

    # Number of leading episode files skipped by default.
    DEFAULT_START_IDX = 3175

    def __init__(self, path, types=None, mode="train", start_idx=DEFAULT_START_IDX):
        """
        Args:
            path: string prefix of the dataset root (concatenated as-is).
            types: non-empty sequence of task type names (1 to 4 entries for
                'train', exactly 1 for 'val').
            mode: 'train' or 'val'.
            start_idx: index of the first episode file to expose; defaults to
                the previously hard-coded 3175.

        Raises:
            ValueError: on an unsupported mode or an unsupported ``types``.
        """
        self.path = path
        self.types = types
        self.mode = mode
        self.start_idx = start_idx
        print('Mode:', self.mode)
        self.path = self._resolve_split_dir()
        print('Undersampled dataset!')

    def _resolve_split_dir(self):
        """Return the directory (with trailing '/') holding the episode pickles."""
        if self.mode not in ('train', 'val'):
            raise ValueError(
                f"Unsupported mode {self.mode!r}; expected 'train' or 'val'."
            )
        if self.types is None or len(self.types) == 0:
            raise ValueError('`types` must be a non-empty sequence of task type names.')

        n_types = len(self.types)
        prefix = self.path + self.mode + '_dgl_hetero_nobound_4feats'

        if self.mode == 'val':
            # Explicit check instead of `assert`, which is stripped under -O.
            if n_types != 1:
                raise ValueError(
                    f"Mode 'val' expects exactly one type, got {n_types}."
                )
            return prefix + '/' + self.types[0] + '/'

        if n_types == 4:
            return prefix + '/'
        if n_types > 4:
            raise ValueError(
                f'Number of types must be between 1 and 4, got {n_types}.'
            )
        initials = ''.join(name[0].upper() for name in self.types)
        if n_types > 1:
            print(initials)
        return prefix + '_' + initials + '/'

    def get_test(self, states, actions):
        """
        Pick one random frame of a trial as the query.

        Args:
            states: batched graph holding one sub-graph per frame of the trial.
            actions: per-frame actions, indexable by frame position.

        Returns:
            (query_graph, target_action) for a uniformly drawn frame index.
        """
        frame_graphs = dgl.unbatch(states)
        query_idx = random.randint(0, len(frame_graphs) - 1)
        return frame_graphs[query_idx], actions[query_idx]

    def __getitem__(self, idx):
        """
        Load episode ``idx + start_idx`` and split it into demonstrations and
        one query.

        The trials of the episode are shuffled; the last one after shuffling
        provides the query frame (see ``get_test``), the others become the
        demonstrations.

        Returns:
            dem_s: batched graph of all demonstration trials.
            dem_a: stacked demonstration actions (assumes every trial's action
                tensor has the same shape -- TODO confirm padding upstream).
            dem_lens: list with the length entry of every demonstration trial.
            test_s: graph of the query frame.
            test_a: action taken at the query frame.
        """
        file_idx = idx + self.start_idx
        # NOTE(security): pickle can execute arbitrary code while loading;
        # only point `path` at trusted data.
        with open(self.path + str(file_idx) + '.pkl', 'rb') as f:
            # The fourth stored field is not used here.
            states, actions, lens, _ = pkl.load(f)

        trials = list(zip(states, actions, lens))
        random.shuffle(trials)

        query_states, query_actions, _ = trials[-1]
        test_s, test_a = self.get_test(query_states, query_actions)

        demos = trials[:-1]
        dem_s = dgl.batch([trial[0] for trial in demos])
        dem_a = torch.stack([trial[1] for trial in demos])
        dem_lens = [trial[2] for trial in demos]
        return dem_s, dem_a, dem_lens, test_s, test_a

    def __len__(self):
        """
        Number of exposed episodes: directory entries minus the skipped ones.

        Clamped at 0 because a negative ``__len__`` makes ``len()`` raise.
        """
        return max(0, len(os.listdir(self.path)) - self.start_idx)
class ToMnetDGLDatasetMental(DGLDataset):
    """
    Training / validation dataset that also returns the query trial's history.

    In addition to the demonstrations and the query frame, every item carries
    the frames and actions that precede the query frame inside its trial.

    The split directory is derived from ``path`` (a raw string prefix),
    ``mode`` and ``types``:

        train, 4 types   -> <path>train_dgl_hetero_nobound_4feats/
        train, 1-3 types -> <path>train_dgl_hetero_nobound_4feats_<INITIALS>/
        val, 1 type      -> <path>val_dgl_hetero_nobound_4feats/<type>/

    NOTE(review): ``DGLDataset.__init__`` is not invoked; only ``__getitem__``
    and ``__len__`` defined here are relied upon.
    """

    # Length the past-action tensor is zero-padded to (previously a bare 41).
    # Presumably the maximum number of frames in a trial -- TODO confirm.
    PAST_PAD_LEN = 41

    def __init__(self, path, types=None, mode="train"):
        """
        Args:
            path: string prefix of the dataset root (concatenated as-is).
            types: non-empty sequence of task type names (1 to 4 entries for
                'train', exactly 1 for 'val').
            mode: 'train' or 'val'.

        Raises:
            ValueError: on an unsupported mode or an unsupported ``types``.
        """
        self.path = path
        self.types = types
        self.mode = mode
        print('Mode:', self.mode)
        self.path = self._resolve_split_dir()

    def _resolve_split_dir(self):
        """Return the directory (with trailing '/') holding the episode pickles."""
        if self.mode not in ('train', 'val'):
            raise ValueError(
                f"Unsupported mode {self.mode!r}; expected 'train' or 'val'."
            )
        if self.types is None or len(self.types) == 0:
            raise ValueError('`types` must be a non-empty sequence of task type names.')

        n_types = len(self.types)
        prefix = self.path + self.mode + '_dgl_hetero_nobound_4feats'

        if self.mode == 'val':
            # Explicit check instead of `assert`, which is stripped under -O.
            if n_types != 1:
                raise ValueError(
                    f"Mode 'val' expects exactly one type, got {n_types}."
                )
            return prefix + '/' + self.types[0] + '/'

        if n_types == 4:
            return prefix + '/'
        if n_types > 4:
            raise ValueError(
                f'Number of types must be between 1 and 4, got {n_types}.'
            )
        initials = ''.join(name[0].upper() for name in self.types)
        if n_types > 1:
            print(initials)
        return prefix + '_' + initials + '/'

    def get_test(self, states, actions):
        """
        Pick a random query frame of a trial together with its history.

        Args:
            states: batched graph holding one sub-graph per frame of the trial.
            actions: tensor of per-frame actions, indexed by frame on dim 0.

        Returns:
            past_test_graphs: the frames preceding the query frame; a single
                graph when there is one past frame, otherwise a batched graph.
            past_test_actions: the matching actions, zero-padded at the end of
                the second-to-last dimension up to ``PAST_PAD_LEN`` entries.
            past_test_len: number of valid (unpadded) past frames.
            test_graph: graph of the query frame.
            test_action: action taken at the query frame.

        When the very first frame is drawn there is no history, so the query
        frame and its action are returned as a one-frame history instead
        (``past_test_len == 1``). The original author noted that a better
        alternative may exist.
        """
        frame_graphs = dgl.unbatch(states)
        query_idx = random.randint(0, len(frame_graphs) - 1)
        test_graph = frame_graphs[query_idx]
        test_action = actions[query_idx]

        # query_idx == 0 falls back to a one-frame history made of frame 0,
        # which is exactly what query_idx == 1 yields as well.
        past_len = max(query_idx, 1)
        past_frames = frame_graphs[:past_len]
        if past_len == 1:
            # A lone frame is returned as-is, not wrapped by dgl.batch.
            past_test_graphs = past_frames[0]
        else:
            past_test_graphs = dgl.batch(past_frames)

        # Pad tuple is (last-dim left, last-dim right, dim -2 left, dim -2 right).
        # NOTE(review): a history longer than PAST_PAD_LEN would make the pad
        # amount negative, which F.pad treats as cropping -- confirm that
        # trials never exceed PAST_PAD_LEN frames.
        past_test_actions = F.pad(
            actions[:past_len],
            (0, 0, 0, self.PAST_PAD_LEN - past_len),
            'constant',
            0,
        )
        return past_test_graphs, past_test_actions, past_len, test_graph, test_action

    def __getitem__(self, idx):
        """
        Load episode ``idx`` and split it into demonstrations, query history
        and query.

        The trials of the episode are shuffled; the last one after shuffling
        is handed to ``get_test``, the others become the demonstrations.

        Returns:
            dem_s: batched graph of all demonstration trials.
            dem_a: stacked demonstration actions (assumes every trial's action
                tensor has the same shape -- TODO confirm padding upstream).
            dem_lens: list with the length entry of every demonstration trial.
            past_test_s, past_test_a, past_test_len, test_s, test_a:
                the five values returned by ``get_test``.
        """
        # NOTE(security): pickle can execute arbitrary code while loading;
        # only point `path` at trusted data.
        with open(self.path + str(idx) + '.pkl', 'rb') as f:
            # The fourth stored field is not used here.
            states, actions, lens, _ = pkl.load(f)

        trials = list(zip(states, actions, lens))
        random.shuffle(trials)

        query_states, query_actions, _ = trials[-1]
        past_test_s, past_test_a, past_test_len, test_s, test_a = self.get_test(
            query_states, query_actions
        )

        demos = trials[:-1]
        dem_s = dgl.batch([trial[0] for trial in demos])
        dem_a = torch.stack([trial[1] for trial in demos])
        dem_lens = [trial[2] for trial in demos]
        return dem_s, dem_a, dem_lens, past_test_s, past_test_a, past_test_len, test_s, test_a

    def __len__(self):
        """Number of entries in the split directory (one file per episode assumed)."""
        return len(os.listdir(self.path))
# --------------------------------------------------------------------------------------------------------
if __name__ == "__main__":
    # Manual smoke test: load one item of the "mental" dataset and drop into
    # the debugger to inspect it.
    # Task type names kept for reference during the debugging session; the
    # list itself is not passed to the dataset below.
    types = [
        'preference', 'multi_agent', 'inaccessible_goal',
        'efficiency_irrational', 'efficiency_time', 'efficiency_path',
        'instrumental_no_barrier', 'instrumental_blocking_barrier', 'instrumental_inconsequential_barrier'
    ]
    mental_dataset = ToMnetDGLDatasetMental(
        path='/datasets/external/bib_train/graphs/all_tasks/',
        types=['instrumental_action'],
        mode='train'
    )
    # The sixth field used to be bound to `len`, shadowing the builtin for the
    # rest of the session (including inside the debugger).
    (dem_frames, dem_actions, dem_lens,
     past_test_frames, past_test_actions, past_test_len,
     test_frame, test_action) = mental_dataset[99]
    breakpoint()