from torch.utils.data import Dataset
import os
from torchvision import transforms
from natsort import natsorted
import numpy as np
import pickle
import torch
import cv2
import ast
from einops import rearrange

from utils import spherical2cartesial


class Data(Dataset):
    def __init__(
        self,
        frame_path,
        label_path,
        pose_path=None,
        gaze_path=None,
        bbox_path=None,
        ocr_graph_path=None,
        presaved=128,
        sizes=(128, 128),
        spatial_patch_size: int = 16,
        temporal_patch_size: int = 4,
        img_channels: int = 3,
        patch_data: bool = False,
        flatten_dim: int = 1
    ):
        self.data = {'frame_paths': [], 'labels': [], 'poses': [], 'gazes': [], 'bboxes': [], 'ocr_graph': []}
        self.size_1, self.size_2 = sizes
        self.patch_data = patch_data
        if self.patch_data:
            self.spatial_patch_size = spatial_patch_size
            self.temporal_patch_size = temporal_patch_size
            self.num_patches = (self.size_1 // self.spatial_patch_size) ** 2
            self.spatial_patch_dim = img_channels * spatial_patch_size ** 2
            assert self.size_1 % spatial_patch_size == 0 and self.size_2 % spatial_patch_size == 0, \
                'Image dimensions must be divisible by the patch size.'

        if presaved == 224:
            self.presaved = 'presaved'
        elif presaved == 128:
            self.presaved = 'presaved128'
        else:
            raise ValueError(f'Unsupported presaved size: {presaved}')

        frame_dirs = os.listdir(frame_path)
        episode_ids = natsorted(frame_dirs)
        seq_lens = []
        for frame_dir in natsorted(frame_dirs):
            frame_paths = [
                os.path.join(frame_path, frame_dir, i)
                for i in natsorted(os.listdir(os.path.join(frame_path, frame_dir)))
            ]
            seq_lens.append(len(frame_paths))
            self.data['frame_paths'].append(frame_paths)

        print('Labels...')
        with open(label_path, 'rb') as fp:
            labels = pickle.load(fp)
        left_labels, right_labels = labels
        for i in episode_ids:
            episode_id = int(i)
            left_labels_i = left_labels[episode_id]
            right_labels_i = right_labels[episode_id]
            self.data['labels'].append([left_labels_i, right_labels_i])

        print('Pose...')
        self.pose_path = pose_path
        if pose_path is not None:
            for i in episode_ids:
                pose_i_dir = os.path.join(pose_path, i)
                poses = []
                for j in natsorted(os.listdir(pose_i_dir)):
                    fs = cv2.FileStorage(os.path.join(pose_i_dir, j), cv2.FILE_STORAGE_READ)
                    pose_mat = fs.getNode("pose_0").mat()
                    # Keep at most the first two detected people per frame.
                    if torch.tensor(pose_mat).shape != (2, 25, 3):
                        poses.append(pose_mat[:2, :, :])
                    else:
                        poses.append(pose_mat)
                poses = np.array(poses)
                # Normalise pixel coordinates by the original frame resolution (1920x1088).
                poses[:, :, :, 0] = poses[:, :, :, 0] / 1920
                poses[:, :, :, 1] = poses[:, :, :, 1] / 1088
                self.data['poses'].append(torch.flatten(torch.tensor(poses), flatten_dim))
                # self.data['poses'].append(torch.tensor(np.array(poses)))

        print('Gaze...')
        """
        Gaze information is represented by a 3D gaze vector based on the
        observing camera's Cartesian eye coordinate system.
        """
        self.gaze_path = gaze_path
        if gaze_path is not None:
            for i in episode_ids:
                gaze_i_txt = os.path.join(gaze_path, '{}.txt'.format(i))
                with open(gaze_i_txt, 'r') as fp:
                    gaze_i = fp.readlines()
                gaze_i = ast.literal_eval(gaze_i[0])
                # Up to two people per frame, each with a 3D gaze vector.
                gaze_i_tensor = torch.zeros((len(gaze_i), 2, 3))
                for j in range(len(gaze_i)):
                    if len(gaze_i[j]) >= 2:
                        gaze_i_tensor[j, :] = torch.tensor(gaze_i[j][:2])
                    elif len(gaze_i[j]) == 1:
                        gaze_i_tensor[j, 0] = torch.tensor(gaze_i[j][0])
                    else:
                        continue
                self.data['gazes'].append(torch.flatten(gaze_i_tensor, flatten_dim))

        print('Bbox...')
        self.bbox_path = bbox_path
        if bbox_path is not None:
            self.objects = [
                'none', 'apple', 'orange', 'lemon', 'potato', 'wine', 'wineopener',
                'knife', 'mug', 'peeler', 'bowl', 'chocolate', 'sugar', 'magazine',
                'cracker', 'chips', 'scissors', 'cap', 'marker', 'sardinecan',
                'tomatocan', 'plant', 'walnut', 'nail', 'waterspray', 'hammer',
                'canopener'
            ]
            """
            # NOTE: old bbox
            for i in episode_ids:
                bbox_i_dir = os.path.join(bbox_path, i)
                with open(bbox_i_dir, 'rb') as fp:
                    bboxes_i = pickle.load(fp)
                    len_i = len(bboxes_i)
                bboxes_i_tensor = torch.zeros((len_i, len(self.objects), 4))
                for j in range(len(bboxes_i)):
                    items_i_j, bboxes_i_j = bboxes_i[j]
                    for k in range(len(items_i_j)):
                        bboxes_i_tensor[j, self.objects.index(items_i_j[k])] = torch.tensor([
                            bboxes_i_j[k][0] / 1920,  # * self.size_1,
                            bboxes_i_j[k][1] / 1088,  # * self.size_2,
                            bboxes_i_j[k][2] / 1920,  # * self.size_1,
                            bboxes_i_j[k][3] / 1088,  # * self.size_2
                        ])  # [x_min, y_min, x_max, y_max]
                self.data['bboxes'].append(torch.flatten(bboxes_i_tensor, 1))
            """
            # NOTE: new bbox (per-frame txt files, one line per object: class x_center y_center width height)
            for i in episode_ids:
                bbox_dir = os.path.join(bbox_path, i)
                bbox_tensor = torch.zeros((len(os.listdir(bbox_dir)), len(self.objects), 4))  # TODO: we might want to cut it to 10 objects
                for index, bbox in enumerate(sorted(os.listdir(bbox_dir), key=len)):
                    with open(os.path.join(bbox_dir, bbox), 'r') as fp:
                        bbox_content = fp.readlines()
                    for bbox_content_line in bbox_content:
                        bbox_content_values = bbox_content_line.split()
                        class_index, x_center, y_center, x_width, y_height = map(float, bbox_content_values)
                        bbox_tensor[index, int(class_index)] = torch.FloatTensor([x_center, y_center, x_width, y_height])
                self.data['bboxes'].append(torch.flatten(bbox_tensor, 1))

        print('OCR...\n')
        self.ocr_graph_path = ocr_graph_path
        if ocr_graph_path is not None:
            # Object-context co-occurrence counts: [object, [context_object, count], ...]
            ocr_graph = [
                [15, [10, 4], [17, 2]],
                [13, [16, 7], [18, 4]],
                [11, [16, 4], [7, 10]],
                [14, [10, 11], [7, 1]],
                [12, [10, 9], [16, 3]],
                [1, [7, 2], [9, 9], [10, 2]],
                [5, [8, 8], [6, 8]],
                [4, [9, 8], [7, 6]],
                [3, [10, 1], [8, 3], [7, 4], [9, 2], [6, 1]],
                [2, [10, 1], [7, 7], [9, 3]],
                [19, [10, 2], [26, 6]],
                [20, [10, 7], [26, 5]],
                [22, [25, 4], [10, 8]],
                [23, [25, 15]],
                [21, [16, 5], [24, 8]]
            ]
            ocr_tensor = torch.zeros((27, 27))
            for ocr in ocr_graph:
                obj = ocr[0]
                contexts = ocr[1:]
                total_context_count = sum([i[1] for i in contexts])
                for context in contexts:
                    ocr_tensor[obj, context[0]] = context[1] / total_context_count
            ocr_tensor = torch.flatten(ocr_tensor)
            for i in episode_ids:
                self.data['ocr_graph'].append(ocr_tensor)

        self.frame_path = frame_path
        self.transform = transforms.Compose([
            # transforms.ToPILImage(),
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.data['frame_paths'])

    def __getitem__(self, idx):
        # The presaved frames of an episode are stored as a single pickle whose path
        # mirrors the frame directory, with the presaved folder name inserted.
        frames_path = self.data['frame_paths'][idx][0].split('/')
        frames_path.insert(5, self.presaved)
        frames_path = '/'.join(frames_path[:-1]) + '.pkl'
        with open(frames_path, 'rb') as f:
            images = pickle.load(f)
        images = torch.stack(images)
        if self.patch_data:
            images = rearrange(
                images,
                '(t p3) c (h p1) (w p2) -> t p3 (h w) (p1 p2 c)',
                p1=self.spatial_patch_size,
                p2=self.spatial_patch_size,
                p3=self.temporal_patch_size
            )

        if self.pose_path is not None:
            pose = self.data['poses'][idx]
            if self.patch_data:
                pose = rearrange(pose, '(t p1) d -> t p1 d', p1=self.temporal_patch_size)
        else:
            pose = None

        if self.gaze_path is not None:
            gaze = self.data['gazes'][idx]
            if self.patch_data:
                gaze = rearrange(gaze, '(t p1) d -> t p1 d', p1=self.temporal_patch_size)
        else:
            gaze = None

        if self.bbox_path is not None:
            bbox = self.data['bboxes'][idx]
            if self.patch_data:
                bbox = rearrange(bbox, '(t p1) d -> t p1 d', p1=self.temporal_patch_size)
        else:
            bbox = None

        if self.ocr_graph_path is not None:
            ocr_graphs = self.data['ocr_graph'][idx]
        else:
            ocr_graphs = None

        return images, torch.permute(torch.tensor(self.data['labels'][idx]), (1, 0)), pose, gaze, bbox, ocr_graphs


class DataTest(Dataset):
    def __init__(
        self,
        frame_path,
        label_path,
        pose_path=None,
        gaze_path=None,
        bbox_path=None,
        ocr_graph_path=None,
        presaved=128,
        frame_ids=None,
        median=None,
        sizes=(128, 128),
        spatial_patch_size: int = 16,
        temporal_patch_size: int = 4,
        img_channels: int = 3,
        patch_data: bool = False,
        flatten_dim: int = 1
    ):
        self.data = {'frame_paths': [], 'labels': [], 'poses': [], 'gazes': [], 'bboxes': [], 'ocr_graph': []}
        self.size_1, self.size_2 = sizes
        self.patch_data = patch_data
        if self.patch_data:
            self.spatial_patch_size = spatial_patch_size
            self.temporal_patch_size = temporal_patch_size
            self.num_patches = (self.size_1 // self.spatial_patch_size) ** 2
            self.spatial_patch_dim = img_channels * spatial_patch_size ** 2
            assert self.size_1 % spatial_patch_size == 0 and self.size_2 % spatial_patch_size == 0, \
                'Image dimensions must be divisible by the patch size.'

        if presaved == 224:
            self.presaved = 'presaved'
        elif presaved == 128:
            self.presaved = 'presaved128'
        else:
            raise ValueError(f'Unsupported presaved size: {presaved}')

        if frame_ids is not None:
            # Keep only the episodes whose integer id appears in frame_ids.
            test_ids = []
            frame_dirs = os.listdir(frame_path)
            episode_ids = natsorted(frame_dirs)
            for frame_dir in episode_ids:
                if int(frame_dir) in frame_ids:
                    test_ids.append(frame_dir)
                    frame_paths = [
                        os.path.join(frame_path, frame_dir, i)
                        for i in natsorted(os.listdir(os.path.join(frame_path, frame_dir)))
                    ]
                    self.data['frame_paths'].append(frame_paths)
        elif median is not None:
            # median = (threshold, above): keep episodes whose sequence length is at or
            # above the threshold when above is truthy, below it otherwise.
            test_ids = []
            frame_dirs = os.listdir(frame_path)
            episode_ids = natsorted(frame_dirs)
            for frame_dir in episode_ids:
                frame_paths = [
                    os.path.join(frame_path, frame_dir, i)
                    for i in natsorted(os.listdir(os.path.join(frame_path, frame_dir)))
                ]
                seq_len = len(frame_paths)
                if (median[1] and seq_len >= median[0]) or (not median[1] and seq_len < median[0]):
                    self.data['frame_paths'].append(frame_paths)
                    # Keep ids as strings so the path joins below work in every branch.
                    test_ids.append(frame_dir)
        else:
            frame_dirs = os.listdir(frame_path)
            episode_ids = natsorted(frame_dirs)
            test_ids = episode_ids.copy()
            for frame_dir in episode_ids:
                frame_paths = [
                    os.path.join(frame_path, frame_dir, i)
                    for i in natsorted(os.listdir(os.path.join(frame_path, frame_dir)))
                ]
                self.data['frame_paths'].append(frame_paths)

        print('Labels...')
        with open(label_path, 'rb') as fp:
            labels = pickle.load(fp)
        left_labels, right_labels = labels
        for i in test_ids:
            episode_id = int(i)
            left_labels_i = left_labels[episode_id]
            right_labels_i = right_labels[episode_id]
            self.data['labels'].append([left_labels_i, right_labels_i])

        print('Pose...')
        self.pose_path = pose_path
        if pose_path is not None:
            for i in test_ids:
                pose_i_dir = os.path.join(pose_path, i)
                poses = []
                for j in natsorted(os.listdir(pose_i_dir)):
                    fs = cv2.FileStorage(os.path.join(pose_i_dir, j), cv2.FILE_STORAGE_READ)
                    pose_mat = fs.getNode("pose_0").mat()
                    # Keep at most the first two detected people per frame.
                    if torch.tensor(pose_mat).shape != (2, 25, 3):
                        poses.append(pose_mat[:2, :, :])
                    else:
                        poses.append(pose_mat)
                poses = np.array(poses)
                # Normalise pixel coordinates by the original frame resolution (1920x1088).
                poses[:, :, :, 0] = poses[:, :, :, 0] / 1920
                poses[:, :, :, 1] = poses[:, :, :, 1] / 1088
                self.data['poses'].append(torch.flatten(torch.tensor(poses), flatten_dim))

        print('Gaze...')
        self.gaze_path = gaze_path
        if gaze_path is not None:
            for i in test_ids:
                gaze_i_txt = os.path.join(gaze_path, '{}.txt'.format(i))
                with open(gaze_i_txt, 'r') as fp:
                    gaze_i = fp.readlines()
                gaze_i = ast.literal_eval(gaze_i[0])
                gaze_i_tensor = torch.zeros((len(gaze_i), 2, 3))
                for j in range(len(gaze_i)):
                    if len(gaze_i[j]) >= 2:
                        gaze_i_tensor[j, :] = torch.tensor(gaze_i[j][:2])
                    elif len(gaze_i[j]) == 1:
                        gaze_i_tensor[j, 0] = torch.tensor(gaze_i[j][0])
                    else:
                        continue
                self.data['gazes'].append(torch.flatten(gaze_i_tensor, flatten_dim))

        print('Bbox...')
        self.bbox_path = bbox_path
        if bbox_path is not None:
            self.objects = [
                'none', 'apple', 'orange', 'lemon', 'potato', 'wine', 'wineopener',
                'knife', 'mug', 'peeler', 'bowl', 'chocolate', 'sugar', 'magazine',
                'cracker', 'chips', 'scissors', 'cap', 'marker', 'sardinecan',
                'tomatocan', 'plant', 'walnut', 'nail', 'waterspray', 'hammer',
                'canopener'
            ]
            """
            NOTE: old bbox
            for i in test_ids:
                bbox_i_dir = os.path.join(bbox_path, i)
                with open(bbox_i_dir, 'rb') as fp:
                    bboxes_i = pickle.load(fp)
                    len_i = len(bboxes_i)
                bboxes_i_tensor = torch.zeros((len_i, len(self.objects), 4))
                for j in range(len(bboxes_i)):
                    items_i_j, bboxes_i_j = bboxes_i[j]
                    for k in range(len(items_i_j)):
                        bboxes_i_tensor[j, self.objects.index(items_i_j[k])] = torch.tensor([
                            bboxes_i_j[k][0] / 1920,  # * self.size_1,
                            bboxes_i_j[k][1] / 1088,  # * self.size_2,
                            bboxes_i_j[k][2] / 1920,  # * self.size_1,
                            bboxes_i_j[k][3] / 1088,  # * self.size_2
                        ])  # [x_min, y_min, x_max, y_max]
                self.data['bboxes'].append(torch.flatten(bboxes_i_tensor, 1))
            """
            # NOTE: new bbox (per-frame txt files, one line per object: class x_center y_center width height)
            for i in test_ids:
                bbox_dir = os.path.join(bbox_path, i)
                bbox_tensor = torch.zeros((len(os.listdir(bbox_dir)), len(self.objects), 4))  # TODO: we might want to cut it to 10 objects
                for index, bbox in enumerate(sorted(os.listdir(bbox_dir), key=len)):
                    with open(os.path.join(bbox_dir, bbox), 'r') as fp:
                        bbox_content = fp.readlines()
                    for bbox_content_line in bbox_content:
                        bbox_content_values = bbox_content_line.split()
                        class_index, x_center, y_center, x_width, y_height = map(float, bbox_content_values)
                        bbox_tensor[index, int(class_index)] = torch.FloatTensor([x_center, y_center, x_width, y_height])
                self.data['bboxes'].append(torch.flatten(bbox_tensor, 1))

        print('OCR...\n')
        self.ocr_graph_path = ocr_graph_path
        if ocr_graph_path is not None:
            # Object-context co-occurrence counts: [object, [context_object, count], ...]
            ocr_graph = [
                [15, [10, 4], [17, 2]],
                [13, [16, 7], [18, 4]],
                [11, [16, 4], [7, 10]],
                [14, [10, 11], [7, 1]],
                [12, [10, 9], [16, 3]],
                [1, [7, 2], [9, 9], [10, 2]],
                [5, [8, 8], [6, 8]],
                [4, [9, 8], [7, 6]],
                [3, [10, 1], [8, 3], [7, 4], [9, 2], [6, 1]],
                [2, [10, 1], [7, 7], [9, 3]],
                [19, [10, 2], [26, 6]],
                [20, [10, 7], [26, 5]],
                [22, [25, 4], [10, 8]],
                [23, [25, 15]],
                [21, [16, 5], [24, 8]]
            ]
            ocr_tensor = torch.zeros((27, 27))
            for ocr in ocr_graph:
                obj = ocr[0]
                contexts = ocr[1:]
                total_context_count = sum([i[1] for i in contexts])
                for context in contexts:
                    ocr_tensor[obj, context[0]] = context[1] / total_context_count
            ocr_tensor = torch.flatten(ocr_tensor)
            for i in test_ids:
                self.data['ocr_graph'].append(ocr_tensor)

        self.frame_path = frame_path
        self.transform = transforms.Compose([
            # transforms.ToPILImage(),
            transforms.Resize((self.size_1, self.size_2)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.data['frame_paths'])

    def __getitem__(self, idx):
        # The presaved frames of an episode are stored as a single pickle whose path
        # mirrors the frame directory, with the presaved folder name inserted.
        frames_path = self.data['frame_paths'][idx][0].split('/')
        frames_path.insert(5, self.presaved)
        frames_path = '/'.join(frames_path[:-1]) + '.pkl'
        with open(frames_path, 'rb') as f:
            images = pickle.load(f)
        images = torch.stack(images)
        if self.patch_data:
            images = rearrange(
                images,
                '(t p3) c (h p1) (w p2) -> t p3 (h w) (p1 p2 c)',
                p1=self.spatial_patch_size,
                p2=self.spatial_patch_size,
                p3=self.temporal_patch_size
            )

        if self.pose_path is not None:
            pose = self.data['poses'][idx]
            if self.patch_data:
                pose = rearrange(pose, '(t p1) d -> t p1 d', p1=self.temporal_patch_size)
        else:
            pose = None

        if self.gaze_path is not None:
            gaze = self.data['gazes'][idx]
            if self.patch_data:
                gaze = rearrange(gaze, '(t p1) d -> t p1 d', p1=self.temporal_patch_size)
        else:
            gaze = None

        if self.bbox_path is not None:
            bbox = self.data['bboxes'][idx]
            if self.patch_data:
                bbox = rearrange(bbox, '(t p1) d -> t p1 d', p1=self.temporal_patch_size)
        else:
            bbox = None

        if self.ocr_graph_path is not None:
            ocr_graphs = self.data['ocr_graph'][idx]
        else:
            ocr_graphs = None

        return images, torch.permute(torch.tensor(self.data['labels'][idx]), (1, 0)), pose, gaze, bbox, ocr_graphs
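

# ---------------------------------------------------------------------------
# Minimal usage sketch. It only illustrates the patch rearrangement applied in
# __getitem__ when patch_data=True, using a dummy clip; the dataset paths in
# the commented-out instantiation are hypothetical placeholders, not the real
# data layout.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    # Dummy clip: 8 frames, 3 channels, 128x128 pixels (the number of frames
    # must be divisible by the temporal patch size).
    dummy_images = torch.randn(8, 3, 128, 128)
    patched = rearrange(
        dummy_images,
        '(t p3) c (h p1) (w p2) -> t p3 (h w) (p1 p2 c)',
        p1=16, p2=16, p3=4
    )
    # 8 frames / 4 per temporal patch = 2 groups, (128 / 16) ** 2 = 64 spatial
    # patches per frame, each flattened to 16 * 16 * 3 = 768 values.
    assert patched.shape == (2, 4, 64, 768)

    # Hypothetical instantiation (requires the dataset on disk; paths are placeholders):
    # train_data = Data(
    #     frame_path='/path/to/frames',
    #     label_path='/path/to/labels.pkl',
    #     pose_path='/path/to/poses',
    #     gaze_path='/path/to/gazes',
    #     bbox_path='/path/to/bboxes',
    #     ocr_graph_path='/path/to/ocr',
    #     presaved=128,
    #     patch_data=True,
    # )
    # loader = torch.utils.data.DataLoader(train_data, batch_size=1, shuffle=True)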