import os
import json
import math
from collections import namedtuple

import numpy as np
import torch
from torch.utils.data import Dataset

# One sample: stacked frame features, the action-label sequence, and the
# task/event class id.
Batch = namedtuple('Batch', 'Observations Actions Class')


def get_vids_from_json(path):
    """Read a JSON list of ``{'task': ..., 'vid': ...}`` records and group
    video ids by task.

    Returns a dict mapping task id -> list of video ids (file order kept).
    """
    task_vids = {}
    with open(path, 'r') as f:
        json_data = json.load(f)
    for entry in json_data:
        task_vids.setdefault(entry['task'], []).append(entry['vid'])
    return task_vids


def get_vids(path):
    """Read a CSV with ``task,vid,url`` rows and group video ids by task."""
    task_vids = {}
    with open(path, 'r') as f:
        for line in f:
            task, vid, _url = line.strip().split(',')
            task_vids.setdefault(task, []).append(vid)
    return task_vids


def read_task_info(path):
    """Parse a CrossTask ``tasks_primary.txt``-style file.

    The file is a sequence of 6-line records::

        task id / title / url / number of steps / comma-separated steps / blank

    Returns a dict with 'title', 'url', 'n_steps' and 'steps' sub-dicts, each
    keyed by task id.
    """
    titles, urls, n_steps, steps = {}, {}, {}, {}
    with open(path, 'r') as f:
        idx = f.readline()
        while idx != '':  # readline() returns '' only at EOF
            idx = idx.strip()
            titles[idx] = f.readline().strip()
            urls[idx] = f.readline().strip()
            n_steps[idx] = int(f.readline().strip())
            steps[idx] = f.readline().strip().split(',')
            next(f)  # skip the blank separator line
            idx = f.readline()
    return {'title': titles, 'url': urls, 'n_steps': n_steps, 'steps': steps}


class PlanningDataset(Dataset):
    """Sliding-window procedure-planning dataset over pre-extracted video
    features (CrossTask, COIN or NIV).

    Each item is a window of ``args.horizon`` consecutive annotated steps
    from one video, returned as ``Batch(Observations, Actions, Class)``.
    The per-window index is built once from the raw split files and cached
    as JSON at ``args.json_path_train`` / ``args.json_path_val``.
    """

    def __init__(self, root, args=None, is_val=False, model=None,
                 crosstask_use_feature_how=True,
                 ):
        """
        Args:
            root: dataset root directory; for CrossTask the expected layout is::

                .
                └── crosstask
                    ├── crosstask_features
                    └── crosstask_release
                        ├── tasks_primary.txt
                        ├── videos.csv or json
                        └── videos_val.csv or json

            args: namespace providing ``dataset``, ``horizon``,
                ``json_path_train`` and ``json_path_val``.
            is_val: build the validation split instead of train.
            model: stored on ``self.model``; not used inside this class.
            crosstask_use_feature_how: for CrossTask, load the processed
                feature archives (npz-style dict with 'frames_features')
                instead of the raw ``crosstask_features`` arrays.
        """
        self.is_val = is_val
        self.data_root = root
        self.args = args
        self.max_traj_len = args.horizon
        self.vid_names = None
        self.frame_cnts = None
        self.images = None
        self.last_vid = ''  # cache key of the most recently loaded video
        self.crosstask_use_feature_how = crosstask_use_feature_how

        if args.dataset == 'crosstask':
            self._init_crosstask(root, args, is_val)
        elif args.dataset == 'coin':
            self._init_coin(root, args, is_val)
        elif args.dataset == 'NIV':
            self._init_niv(root, args, is_val)
        else:
            raise NotImplementedError(
                'Dataset {} is not implemented'.format(args.dataset))

        self.model = model
        self.prepare_data()
        # Number of consecutive feature frames concatenated per observation.
        self.M = 3

    # ------------------------------------------------------------------ #
    # index construction helpers

    def _pad_to_horizon(self, legal_range):
        """Prepend copies of the first step so the sequence has at least
        ``self.max_traj_len`` entries (guarantees at least one window)."""
        pad = self.max_traj_len - len(legal_range)
        if pad > 0:
            return [legal_range[0]] * pad + list(legal_range)
        return legal_range

    def _windows(self, legal_range):
        """Yield every contiguous window of length ``self.max_traj_len``."""
        for start in range(len(legal_range) - self.max_traj_len + 1):
            yield legal_range[start:start + self.max_traj_len]

    def _init_crosstask(self, root, args, is_val):
        """Set CrossTask paths/metadata and build (or load) the window index."""
        release = os.path.join(root, 'dataset', 'crosstask', 'crosstask_release')
        val_csv_path = os.path.join(release, 'test_list.json')
        video_csv_path = os.path.join(release, 'train_list.json')
        if self.crosstask_use_feature_how:
            self.features_path = os.path.join(
                root, 'dataset', 'crosstask', 'processed_data')
        else:
            self.features_path = os.path.join(
                root, 'dataset', 'crosstask', 'crosstask_features')
        self.constraints_path = os.path.join(release, 'annotations')
        # step-key -> class-index lookup, stored as a pickled dict in .npy
        self.action_one_hot = np.load(
            os.path.join(release, 'actions_one_hot.npy'),
            allow_pickle=True).item()
        # The 18 primary CrossTask task ids -> contiguous class labels.
        self.task_class = {
            '23521': 0, '59684': 1, '71781': 2, '113766': 3, '105222': 4,
            '94276': 5, '53193': 6, '105253': 7, '44047': 8, '76400': 9,
            '16815': 10, '95603': 11, '109972': 12, '44789': 13,
            '40567': 14, '77721': 15, '87706': 16, '91515': 17,
        }

        cache_name = args.json_path_val if is_val else args.json_path_train
        if os.path.exists(cache_name):
            with open(cache_name, 'r') as f:
                self.json_data = json.load(f)
            print('Loaded {}'.format(cache_name))
            return

        # Build the index from the raw split files.
        if val_csv_path.split('.')[-1] == 'json':
            all_task_vids = get_vids_from_json(video_csv_path)
            val_vids = get_vids_from_json(val_csv_path)
        else:
            all_task_vids = get_vids(video_csv_path)
            val_vids = get_vids(val_csv_path)
        if is_val:
            task_vids = val_vids
        else:
            # Drop any training vid that also appears in the val split.
            task_vids = {task: [vid for vid in vids
                                if task not in val_vids
                                or vid not in val_vids[task]]
                         for task, vids in all_task_vids.items()}

        primary_info = read_task_info(os.path.join(release, 'tasks_primary.txt'))
        self.n_steps = primary_info['n_steps']
        all_tasks = set(self.n_steps.keys())
        task_vids = {task: vids for task, vids in task_vids.items()
                     if task in all_tasks}

        all_vids = []
        for task, vids in task_vids.items():
            all_vids.extend((task, vid) for vid in vids)

        json_data = []
        for task, vid in all_vids:
            if self.crosstask_use_feature_how:
                video_path = os.path.join(
                    self.features_path, str(task) + '_' + str(vid) + '.npy')
            else:
                video_path = os.path.join(
                    self.features_path, str(vid) + '.npy')
            legal_range = self.process_single(task, vid)
            if not legal_range:  # missing features or bad annotation
                continue
            legal_range = self._pad_to_horizon(legal_range)
            for window in self._windows(legal_range):
                json_data.append({'id': {'vid': vid, 'task': task,
                                         'feature': video_path,
                                         'legal_range': window,
                                         'task_id': self.task_class[task]},
                                  'instruction_len': self.n_steps[task]})
        self.json_data = json_data
        with open(cache_name, 'w') as f:
            json.dump(json_data, f)

    def _init_coin(self, root, args, is_val):
        """Set COIN paths and build (or load) the window index."""
        coin_path = os.path.join(root, 'dataset/coin', 'full_npy/')
        val_csv_path = os.path.join(root, 'dataset/coin', 'coin_test_30.json')
        video_csv_path = os.path.join(root, 'dataset/coin', 'coin_train_70.json')

        cache_name = args.json_path_val if is_val else args.json_path_train
        if os.path.exists(cache_name):
            with open(cache_name, 'r') as f:
                self.json_data = json.load(f)
            print('Loaded {}'.format(cache_name))
            return

        json_data = []
        num_skipped = 0
        split_path = val_csv_path if is_val else video_csv_path
        with open(split_path, 'r') as f:
            coin_data = json.load(f)
        for record in coin_data:
            for (vid, meta) in record.items():
                file_name = (meta['class'] + '_' + str(meta['recipe_type'])
                             + '_' + vid + '.npy')
                file_path = coin_path + file_name
                images_ = np.load(file_path, allow_pickle=True)
                images = images_['frames_features']
                # Skip videos whose annotations run past the extracted features.
                last_action = math.ceil(meta['annotation'][-1]['segment'][1])
                if last_action > len(images):
                    print(vid, last_action, len(images))
                    num_skipped += 1
                    continue
                legal_range = []
                for annotation in meta['annotation']:
                    action_label = int(annotation['id']) - 1  # ids are 1-based
                    start_idx, end_idx = annotation['segment']
                    start_idx = math.floor(start_idx)
                    end_idx = math.ceil(end_idx)
                    if end_idx < images.shape[0]:
                        legal_range.append((start_idx, end_idx, action_label))
                    else:
                        # clamp the step end to the last available frame
                        legal_range.append(
                            (start_idx, images.shape[0] - 1, action_label))
                legal_range = self._pad_to_horizon(legal_range)
                for window in self._windows(legal_range):
                    json_data.append({'id': {'vid': vid, 'feature': file_path,
                                             'legal_range': window,
                                             'task_id': meta['recipe_type']},
                                      'instruction_len': 0})
        print(num_skipped)
        self.json_data = json_data
        with open(cache_name, 'w') as f:
            json.dump(json_data, f)

    def _init_niv(self, root, args, is_val):
        """Set NIV paths and build (or load) the window index."""
        val_csv_path = os.path.join(root, 'dataset/NIV', 'test30_new.json')
        video_csv_path = os.path.join(root, 'dataset/NIV', 'train70_new.json')

        cache_name = args.json_path_val if is_val else args.json_path_train
        if os.path.exists(cache_name):
            with open(cache_name, 'r') as f:
                self.json_data = json.load(f)
            print('Loaded {}'.format(cache_name))
            return

        json_data = []
        split_path = val_csv_path if is_val else video_csv_path
        with open(split_path, 'r') as f:
            niv_data = json.load(f)
        for d in niv_data:
            path = os.path.join(
                root, 'dataset/NIV', 'processed_data', d['feature'])
            info = np.load(path, allow_pickle=True)
            num_steps = int(info['num_steps'])
            # steps_ids / steps_starts / steps_ends must all align.
            assert num_steps == len(info['steps_ids'])
            assert info['num_steps'] == len(info['steps_starts'])
            assert info['num_steps'] == len(info['steps_ends'])
            starts = info['steps_starts']
            ends = info['steps_ends']
            action_labels = info['steps_ids']
            images = info['frames_features']
            legal_range = []
            for i in range(num_steps):
                action_label = int(action_labels[i])
                start_idx = math.floor(float(starts[i]))
                end_idx = math.ceil(float(ends[i]))
                if end_idx < images.shape[0]:
                    legal_range.append((start_idx, end_idx, action_label))
                else:
                    # clamp the step end to the last available frame
                    legal_range.append(
                        (start_idx, images.shape[0] - 1, action_label))
            legal_range = self._pad_to_horizon(legal_range)
            for window in self._windows(legal_range):
                json_data.append({'id': {'feature': path,
                                         'legal_range': window,
                                         'task_id': d['task_id']},
                                  'instruction_len': 0})
        self.json_data = json_data
        with open(cache_name, 'w') as f:
            json.dump(json_data, f)
        print(len(json_data))

    # ------------------------------------------------------------------ #
    # per-video processing

    def process_single(self, task, vid):
        """Load features for one CrossTask video and clamp its annotation
        ranges to the feature length.

        Returns a list of ``(start_idx, end_idx, action_label)`` tuples, or
        ``False`` when the feature file is missing or an annotation starts
        past the end of the features.
        """
        if self.crosstask_use_feature_how:
            feat_path = os.path.join(
                self.features_path, str(task) + '_' + str(vid) + '.npy')
            if not os.path.exists(feat_path):
                return False
            images_ = np.load(feat_path, allow_pickle=True)
            images = images_['frames_features']
        else:
            feat_path = os.path.join(self.features_path, vid + '.npy')
            if not os.path.exists(feat_path):
                return False
            images = np.load(feat_path)
        cnst_path = os.path.join(
            self.constraints_path, task + '_' + vid + '.csv')
        legal_range = self.read_assignment(task, cnst_path)
        legal_range_ret = []
        for (start_idx, end_idx, action_label) in legal_range:
            if not start_idx < images.shape[0]:
                # annotation begins after the features end: drop the video
                print(task, vid, end_idx, images.shape[0])
                return False
            if end_idx < images.shape[0]:
                legal_range_ret.append((start_idx, end_idx, action_label))
            else:
                # clamp the step end to the last available frame
                legal_range_ret.append(
                    (start_idx, images.shape[0] - 1, action_label))
        return legal_range_ret

    def read_assignment(self, task_id, path):
        """Parse a CrossTask annotation CSV (``step,start,end`` rows, float
        times) into ``(start_idx, end_idx, action_label)`` tuples with
        floored starts and ceiled ends."""
        legal_range = []
        with open(path, 'r') as f:
            for line in f:
                step, start, end = line.strip().split(',')
                start = int(math.floor(float(start)))
                end = int(math.ceil(float(end)))
                action_label_ind = self.action_one_hot[task_id + '_' + step]
                legal_range.append((start, end, action_label_ind))
        return legal_range

    def prepare_data(self):
        """Split the window index into per-item id dicts and instruction
        lengths (kept as parallel lists for fast __getitem__)."""
        self.vid_names = [item['id'] for item in self.json_data]
        self.frame_cnts = [item['instruction_len'] for item in self.json_data]

    def curate_dataset(self, images, legal_range, M=2):
        """Assemble observations for one sample window.

        For every ``(start_idx, end_idx, action_label)`` step, stack ``M``
        consecutive feature frames at the step start and ``M`` frames around
        the step end, each stack flattened by concatenating along axis 0.

        Returns ``(images_list, labels_list, start_idx_list)``; images_list
        holds two entries (start, end) per step, labels one entry per step.
        """
        images_list = []
        labels_onehot_list = []
        idx_list = []
        for start_idx, end_idx, action_label in legal_range:
            idx_list.append(start_idx)
            # M frames from the step start, shifted back when they would
            # run off the end of the video.
            image_start_idx = max(0, start_idx)
            if image_start_idx + M <= len(images):
                image_start = images[image_start_idx:image_start_idx + M]
            else:
                image_start = images[len(images) - M:len(images)]
            images_list.append(np.concatenate(list(image_start), axis=0))
            labels_onehot_list.append(action_label)
            # M frames around the step end: window [end-2, end+M-2).
            end_idx = max(2, end_idx)
            image_end = images[end_idx - 2:end_idx + M - 2]
            images_list.append(np.concatenate(list(image_end), axis=0))
        return images_list, labels_onehot_list, idx_list

    def sample_single(self, index):
        """Load features for item ``index`` and build its
        (frames, labels, class) tensors."""
        folder_id = self.vid_names[index]
        sample_class = folder_id['task_id']
        if self.args.dataset == 'crosstask':
            # Windows from the same video are adjacent in the index, so the
            # last loaded feature array is cached on self.images.
            if folder_id['vid'] != self.last_vid:
                if self.crosstask_use_feature_how:
                    images_ = np.load(folder_id['feature'], allow_pickle=True)
                    self.images = images_['frames_features']
                else:
                    self.images = np.load(os.path.join(
                        self.features_path, folder_id['vid'] + '.npy'))
                # BUGFIX: previously only updated in the feature_how branch,
                # so the cache never hit when crosstask_use_feature_how=False
                # and every access reloaded the .npy from disk.
                self.last_vid = folder_id['vid']
        else:
            images_ = np.load(folder_id['feature'], allow_pickle=True)
            self.images = images_['frames_features']
        images, labels, _idx_list = self.curate_dataset(
            self.images, folder_id['legal_range'], M=self.M)
        frames = torch.tensor(np.array(images))
        labels_tensor = torch.tensor(labels, dtype=torch.long)
        class_tensor = torch.tensor(sample_class, dtype=torch.long)
        return frames, labels_tensor, class_tensor

    def __getitem__(self, index):
        """Return ``Batch(Observations, Actions, Class)`` for item ``index``."""
        frames, labels, sample_class = self.sample_single(index)
        return Batch(frames, labels, sample_class)

    def __len__(self):
        return len(self.json_data)