"""Datasets and collate functions that serve pickled DGL graph episodes."""

import os
import pickle as pkl
import random

import dgl
import torch
import torch.utils.data
import torch.nn.functional as F
from dgl.data import DGLDataset

# Directory-name stem shared by every dataset split handled in this module.
_FEATS_DIR = '_dgl_hetero_nobound_4feats'

# Number of rows the "past actions" tensor is zero-padded to.
# NOTE(review): presumably the maximum number of frames in a trial -- confirm.
_PAST_ACTIONS_LEN = 41


def _resolve_data_dir(path, types, mode):
    """Build the directory that holds the ``<idx>.pkl`` files of a split.

    Shared by the three training/validation dataset classes, which previously
    each carried their own copy of this logic.

    Args:
        path: Root directory prefix (used verbatim, so it is expected to end
            with a path separator).
        types: Sequence of task-type names. In ``'train'`` mode 1 to 4 names
            are accepted; in ``'val'`` mode exactly one.
        mode: ``'train'`` or ``'val'``.

    Returns:
        The resolved directory path as a string ending in ``'/'``.

    Raises:
        ValueError: If ``mode`` is unknown, or the number of types is not
            between 1 and 4 in ``'train'`` mode.
        TypeError: If ``types`` is ``None``.
        AssertionError: If ``mode == 'val'`` and ``types`` does not contain
            exactly one entry.
    """
    if mode not in ('train', 'val'):
        raise ValueError(
            f"Unsupported mode {mode!r}; expected 'train' or 'val'."
        )
    if types is None:
        raise TypeError(
            "'types' must be a sequence of task-type names, not None."
        )

    if mode == 'val':
        assert len(types) == 1, 'Validation expects exactly one task type.'
        return path + mode + _FEATS_DIR + '/' + types[0] + '/'

    n_types = len(types)
    if n_types == 4:
        # All four types together live in the un-suffixed directory.
        return path + mode + _FEATS_DIR + '/'
    if 1 <= n_types <= 3:
        # Subsets are stored under a suffix built from the upper-cased first
        # letter of every type, in the order given.
        initials = ''.join(task[0].upper() for task in types)
        if n_types > 1:
            print(initials)
        return path + mode + _FEATS_DIR + '_' + initials + '/'
    raise ValueError(
        f'Expected between 1 and 4 task types, got {n_types}.'
    )


def _load_shuffled_trials(file_path):
    """Load one pickled episode and shuffle its trials in unison.

    Args:
        file_path: Path of a pickle file holding a 4-tuple
            ``(states, actions, lens, _)``; the fourth entry is ignored.

    Returns:
        Three lists ``(states, actions, lens)`` shuffled with the same
        permutation.
    """
    # SECURITY: pickle executes arbitrary code on load; only point this at
    # trusted, locally generated dataset files.
    with open(file_path, 'rb') as f:
        states, actions, lens, _ = pkl.load(f)

    trials = list(zip(states, actions, lens))
    random.shuffle(trials)
    states, actions, lens = zip(*trials)
    return list(states), list(actions), list(lens)


def _pick_random_frame(states):
    """Unbatch ``states`` and draw a uniformly random frame index.

    Args:
        states: A batched DGL graph holding the frames of one trial.

    Returns:
        ``(frame_graphs, query_idx)``: the list of per-frame graphs and the
        randomly chosen index into it.
    """
    frame_graphs = dgl.unbatch(states)
    query_idx = random.randint(0, len(frame_graphs) - 1)
    return frame_graphs, query_idx


def collate_function_seq(batch):
    """Collate samples laid out like ``ToMnetDGLDataset.__getitem__`` output.

    Args:
        batch: Sequence of 5-tuples
            ``(dem_frames, dem_actions, dem_lens, query_frame, target_action)``.

    Returns:
        A list ``[dem_frames, dem_actions, dem_lens, query_frames,
        target_actions]`` where graphs are merged with ``dgl.batch``, tensors
        are stacked along a new leading dimension, and ``dem_lens`` is a plain
        Python list with one entry per sample.
    """
    dem_frames = dgl.batch([item[0] for item in batch])
    dem_actions = torch.stack([item[1] for item in batch])
    dem_lens = [item[2] for item in batch]
    query_frames = dgl.batch([item[3] for item in batch])
    target_actions = torch.stack([item[4] for item in batch])
    return [dem_frames, dem_actions, dem_lens, query_frames, target_actions]


def collate_function_seq_test(batch):
    """Collate samples laid out like ``TestToMnetDGLDataset.__getitem__`` output.

    Args:
        batch: Non-empty sequence of 10-tuples (expected / unexpected
            demonstrations, their lengths, and expected / unexpected query
            frames with target actions).

    Returns:
        A 10-element list in the same field order as the input tuples.

    NOTE(review): the demonstration graphs and actions are taken from the
    FIRST sample only, while the ``*_lens`` and query fields gather every
    sample -- this effectively assumes ``batch_size == 1``; confirm with the
    caller.
    """
    first = batch[0]

    dem_expected_states = dgl.batch(first[0])
    dem_expected_actions = torch.stack(first[1]).unsqueeze(dim=0)
    dem_expected_lens = [item[2] for item in batch]

    dem_unexpected_states = dgl.batch(first[3])
    dem_unexpected_actions = torch.stack(first[4]).unsqueeze(dim=0)
    dem_unexpected_lens = [item[5] for item in batch]

    query_expected_frames = dgl.batch([item[6] for item in batch])
    target_expected_actions = torch.stack([item[7] for item in batch])

    query_unexpected_frames = dgl.batch([item[8] for item in batch])
    target_unexpected_actions = torch.stack([item[9] for item in batch])

    return [
        dem_expected_states, dem_expected_actions, dem_expected_lens,
        dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens,
        query_expected_frames, target_expected_actions,
        query_unexpected_frames, target_unexpected_actions,
    ]


def collate_function_mental(batch):
    """Collate samples laid out like ``ToMnetDGLDatasetMental.__getitem__`` output.

    Args:
        batch: Sequence of 8-tuples ``(dem_frames, dem_actions, dem_lens,
            past_test_frames, past_test_actions, past_test_len, query_frame,
            target_action)``.

    Returns:
        An 8-element list in the same field order; graphs are merged with
        ``dgl.batch``, tensors stacked, and the two length fields returned as
        plain Python lists.
    """
    dem_frames = dgl.batch([item[0] for item in batch])
    dem_actions = torch.stack([item[1] for item in batch])
    dem_lens = [item[2] for item in batch]
    past_test_frames = dgl.batch([item[3] for item in batch])
    past_test_actions = torch.stack([item[4] for item in batch])
    past_test_len = [item[5] for item in batch]
    query_frames = dgl.batch([item[6] for item in batch])
    target_actions = torch.stack([item[7] for item in batch])
    return [
        dem_frames, dem_actions, dem_lens,
        past_test_frames, past_test_actions, past_test_len,
        query_frames, target_actions,
    ]


class ToMnetDGLDataset(DGLDataset):
    """
    Training dataset class.

    Each item is read from ``<resolved dir>/<idx>.pkl``. The trials of the
    episode are shuffled; the last one supplies a random query frame and the
    remaining ones are returned as demonstrations.
    """

    def __init__(self, path, types=None, mode="train"):
        """Resolve the data directory for the requested split.

        Args:
            path: Root directory prefix of the dataset.
            types: Sequence of task-type names (1-4 for ``'train'``, exactly
                one for ``'val'``).
            mode: ``'train'`` or ``'val'``.

        Raises:
            ValueError: On an unknown mode or unsupported number of types.
            TypeError: If ``types`` is ``None``.
        """
        # NOTE: DGLDataset.__init__ is intentionally not invoked here, as in
        # the original implementation.
        self.path = path
        self.types = types
        self.mode = mode
        print('Mode:', self.mode)
        self.path = _resolve_data_dir(self.path, self.types, self.mode)

    def get_test(self, states, actions):
        """Pick one random frame of a trial together with its action.

        Args:
            states: Batched graph holding all frames of the trial.
            actions: Per-frame actions, indexable by frame position.

        Returns:
            ``(query_graph, target_action)``.
        """
        frame_graphs, query_idx = _pick_random_frame(states)
        return frame_graphs[query_idx], actions[query_idx]

    def __getitem__(self, idx):
        """Return ``(dem_s, dem_a, dem_lens, test_s, test_a)`` for episode ``idx``."""
        states, actions, lens = _load_shuffled_trials(
            self.path + str(idx) + '.pkl'
        )
        # After shuffling, the last trial is the test trial.
        test_s, test_a = self.get_test(states[-1], actions[-1])
        dem_s = dgl.batch(states[:-1])
        dem_a = torch.stack(actions[:-1])
        return dem_s, dem_a, lens[:-1], test_s, test_a

    def __len__(self):
        """Number of entries in the data directory (one file per episode)."""
        return len(os.listdir(self.path))


class TestToMnetDGLDataset(DGLDataset):
    """
    Testing dataset class.

    Each item holds an "expected" and an "unexpected" variant of the same
    episode, each with 8 demonstrations and a batched query sequence.
    """

    def __init__(self, path, task_type=None, mode="test"):
        """Resolve the data directory of the test split.

        Args:
            path: Path prefix; the feature-directory stem is appended to it
                directly (no separator is inserted).
            task_type: Name of the task sub-directory.
            mode: Must be ``'test'``.

        Raises:
            ValueError: If ``mode`` is not ``'test'``.
            TypeError: If ``task_type`` is ``None``.
        """
        self.path = path
        self.type = task_type
        self.mode = mode
        print('Mode:', self.mode)
        if self.mode != 'test':
            raise ValueError(
                f"Unsupported mode {self.mode!r}; expected 'test'."
            )
        if self.type is None:
            raise TypeError("'task_type' must be a task name, not None.")
        self.path = self.path + _FEATS_DIR + '/' + self.type + '/'

    def __getitem__(self, idx):
        """Load episode ``idx`` and return its 10 fields.

        Returns:
            ``(dem_expected_states, dem_expected_actions, dem_expected_lens,
            dem_unexpected_states, dem_unexpected_actions,
            dem_unexpected_lens, query_expected_frames,
            target_expected_actions, query_unexpected_frames,
            target_unexpected_actions)``.

        Raises:
            AssertionError: If the stored episode does not have 8
                demonstrations per variant or the number of query frames
                differs from the number of target actions.
        """
        # SECURITY: pickle executes arbitrary code on load; only point this
        # at trusted, locally generated dataset files.
        with open(self.path + str(idx) + '.pkl', 'rb') as f:
            # The two `_` slots are stored but unused here (the original
            # comment refers to them as n_nodes).
            (dem_expected_states, dem_expected_actions, dem_expected_lens, _,
             dem_unexpected_states, dem_unexpected_actions,
             dem_unexpected_lens, _,
             query_expected_frames, target_expected_actions,
             query_unexpected_frames, target_unexpected_actions) = pkl.load(f)

        assert len(dem_expected_states) == 8
        assert len(dem_expected_actions) == 8
        assert len(dem_expected_lens) == 8
        assert len(dem_unexpected_states) == 8
        assert len(dem_unexpected_actions) == 8
        assert len(dem_unexpected_lens) == 8
        assert len(dgl.unbatch(query_expected_frames)) == target_expected_actions.size()[0]
        assert len(dgl.unbatch(query_unexpected_frames)) == target_unexpected_actions.size()[0]

        return (
            dem_expected_states, dem_expected_actions, dem_expected_lens,
            dem_unexpected_states, dem_unexpected_actions, dem_unexpected_lens,
            query_expected_frames, target_expected_actions,
            query_unexpected_frames, target_unexpected_actions,
        )

    def __len__(self):
        """Number of entries in the data directory (one file per episode)."""
        return len(os.listdir(self.path))


class ToMnetDGLDatasetUndersample(DGLDataset):
    """
    Training dataset class for the behavior cloning mlp model.

    Identical to ``ToMnetDGLDataset`` except that the first ``INDEX_OFFSET``
    episode files are skipped.
    """

    # Number of leading episode files that are skipped: item ``i`` is read
    # from ``<i + INDEX_OFFSET>.pkl`` and the length shrinks accordingly.
    INDEX_OFFSET = 3175

    def __init__(self, path, types=None, mode="train"):
        """Resolve the data directory for the requested split.

        Args:
            path: Root directory prefix of the dataset.
            types: Sequence of task-type names (1-4 for ``'train'``, exactly
                one for ``'val'``).
            mode: ``'train'`` or ``'val'``.

        Raises:
            ValueError: On an unknown mode or unsupported number of types.
            TypeError: If ``types`` is ``None``.
        """
        self.path = path
        self.types = types
        self.mode = mode
        print('Mode:', self.mode)
        self.path = _resolve_data_dir(self.path, self.types, self.mode)
        print('Undersampled dataset!')

    def get_test(self, states, actions):
        """Pick one random frame of a trial together with its action.

        Args:
            states: Batched graph holding all frames of the trial.
            actions: Per-frame actions, indexable by frame position.

        Returns:
            ``(query_graph, target_action)``.
        """
        frame_graphs, query_idx = _pick_random_frame(states)
        return frame_graphs[query_idx], actions[query_idx]

    def __getitem__(self, idx):
        """Return ``(dem_s, dem_a, dem_lens, test_s, test_a)`` for item ``idx``."""
        idx = idx + self.INDEX_OFFSET
        states, actions, lens = _load_shuffled_trials(
            self.path + str(idx) + '.pkl'
        )
        # After shuffling, the last trial is the test trial.
        test_s, test_a = self.get_test(states[-1], actions[-1])
        dem_s = dgl.batch(states[:-1])
        dem_a = torch.stack(actions[:-1])
        return dem_s, dem_a, lens[:-1], test_s, test_a

    def __len__(self):
        """Number of directory entries minus the skipped leading episodes."""
        return len(os.listdir(self.path)) - self.INDEX_OFFSET


class ToMnetDGLDatasetMental(DGLDataset):
    """
    Training dataset class.

    In addition to the demonstrations and the query frame, every item also
    carries the frames and (zero-padded) actions that precede the query frame
    inside the test trial.
    """

    def __init__(self, path, types=None, mode="train"):
        """Resolve the data directory for the requested split.

        Args:
            path: Root directory prefix of the dataset.
            types: Sequence of task-type names (1-4 for ``'train'``, exactly
                one for ``'val'``).
            mode: ``'train'`` or ``'val'``.

        Raises:
            ValueError: On an unknown mode or unsupported number of types.
            TypeError: If ``types`` is ``None``.
        """
        self.path = path
        self.types = types
        self.mode = mode
        print('Mode:', self.mode)
        self.path = _resolve_data_dir(self.path, self.types, self.mode)

    def get_test(self, states, actions):
        """Pick a random query frame plus everything that precedes it.

        Args:
            states: Batched graph holding all frames of the test trial.
            actions: Tensor of per-frame actions; sliced and padded along its
                second-to-last dimension, so it must be at least 2-D.

        Returns:
            ``(past_test_graphs, past_test_actions, past_len, test_graph,
            test_action)`` where ``past_test_actions`` is zero-padded to
            ``_PAST_ACTIONS_LEN`` rows and ``past_len`` is the number of
            valid (un-padded) rows.
        """
        frame_graphs, query_idx = _pick_random_frame(states)
        test_graph = frame_graphs[query_idx]
        test_action = actions[query_idx]

        if query_idx == 0:
            # NOTE: since there are no past frames, return the test frame and
            # action and query_idx=1; not sure what would be a better
            # alternative.
            padded_action = F.pad(
                test_action.unsqueeze(0),
                (0, 0, 0, _PAST_ACTIONS_LEN - 1), 'constant', 0,
            )
            return test_graph, padded_action, 1, test_graph, test_action

        if query_idx == 1:
            # A single past frame is returned as-is, without re-batching.
            past_test_graphs = frame_graphs[0]
        else:
            past_test_graphs = dgl.batch(frame_graphs[:query_idx])

        past_test_actions = F.pad(
            actions[:query_idx],
            (0, 0, 0, _PAST_ACTIONS_LEN - query_idx), 'constant', 0,
        )
        return past_test_graphs, past_test_actions, query_idx, test_graph, test_action

    def __getitem__(self, idx):
        """Return the 8-tuple consumed by ``collate_function_mental``.

        Returns:
            ``(dem_s, dem_a, dem_lens, past_test_s, past_test_a,
            past_test_len, test_s, test_a)``.
        """
        states, actions, lens = _load_shuffled_trials(
            self.path + str(idx) + '.pkl'
        )
        # After shuffling, the last trial is the test trial.
        past_test_s, past_test_a, past_test_len, test_s, test_a = self.get_test(
            states[-1], actions[-1]
        )
        dem_s = dgl.batch(states[:-1])
        dem_a = torch.stack(actions[:-1])
        return (
            dem_s, dem_a, lens[:-1],
            past_test_s, past_test_a, past_test_len,
            test_s, test_a,
        )

    def __len__(self):
        """Number of entries in the data directory (one file per episode)."""
        return len(os.listdir(self.path))


# --------------------------------------------------------------------------------------------------------

if __name__ == "__main__":

    # Task names kept for reference; not used by the smoke test below.
    types = [
        'preference', 'multi_agent', 'inaccessible_goal',
        'efficiency_irrational', 'efficiency_time', 'efficiency_path',
        'instrumental_no_barrier', 'instrumental_blocking_barrier',
        'instrumental_inconsequential_barrier',
    ]

    mental_dataset = ToMnetDGLDatasetMental(
        path='/datasets/external/bib_train/graphs/all_tasks/',
        types=['instrumental_action'],
        mode='train',
    )
    # `past_test_len` was previously named `len`, which shadowed the builtin.
    (dem_frames, dem_actions, dem_lens,
     past_test_frames, past_test_actions, past_test_len,
     test_frame, test_action) = mental_dataset[99]
    # Drop into the debugger to inspect the loaded sample interactively.
    breakpoint()