mtomnet/boss/dataloader.py
2025-01-10 15:39:20 +01:00

500 lines
21 KiB
Python
Raw Permalink Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from torch.utils.data import Dataset
import os
from torchvision import transforms
from natsort import natsorted
import numpy as np
import pickle
import torch
import cv2
import ast
from einops import rearrange
from utils import spherical2cartesial
class _BossDatasetBase(Dataset):
    """Shared loading and indexing logic for :class:`Data` and :class:`DataTest`.

    Subclasses decide which episodes to include and then call
    :meth:`_load_annotations`; everything is stored in ``self.data`` as
    per-episode lists that are indexed in parallel by ``__getitem__``.
    """

    # Class names; a bbox line's class index selects a row of this length.
    _OBJECTS = (
        'none', 'apple', 'orange', 'lemon', 'potato', 'wine', 'wineopener',
        'knife', 'mug', 'peeler', 'bowl', 'chocolate', 'sugar', 'magazine',
        'cracker', 'chips', 'scissors', 'cap', 'marker', 'sardinecan', 'tomatocan',
        'plant', 'walnut', 'nail', 'waterspray', 'hammer', 'canopener',
    )

    # Each entry is [row_index, [column_index, count], ...]; counts are
    # normalised per row when the matrix is built.
    # NOTE(review): indices presumably refer to positions in _OBJECTS - confirm.
    _OCR_GRAPH = (
        [15, [10, 4], [17, 2]],
        [13, [16, 7], [18, 4]],
        [11, [16, 4], [7, 10]],
        [14, [10, 11], [7, 1]],
        [12, [10, 9], [16, 3]],
        [1, [7, 2], [9, 9], [10, 2]],
        [5, [8, 8], [6, 8]],
        [4, [9, 8], [7, 6]],
        [3, [10, 1], [8, 3], [7, 4], [9, 2], [6, 1]],
        [2, [10, 1], [7, 7], [9, 3]],
        [19, [10, 2], [26, 6]],
        [20, [10, 7], [26, 5]],
        [22, [25, 4], [10, 8]],
        [23, [25, 15]],
        [21, [16, 5], [24, 8]],
    )
    _OCR_SIZE = 27

    # Pose x / y coordinates are divided by these values
    # (presumably the raw frame width / height in pixels - confirm).
    _POSE_X_SCALE = 1920
    _POSE_Y_SCALE = 1088

    @staticmethod
    def _presaved_dirname(presaved):
        """Return the directory name that holds pre-saved frames of size *presaved*.

        Raises:
            ValueError: if *presaved* is neither 128 nor 224.
        """
        if presaved == 224:
            return 'presaved'
        if presaved == 128:
            return 'presaved128'
        raise ValueError(f'presaved must be 128 or 224, got {presaved!r}')

    def _configure(self, presaved, sizes, spatial_patch_size,
                   temporal_patch_size, img_channels, patch_data):
        """Initialise the storage dict and the size / patching attributes."""
        self.data = {'frame_paths': [], 'labels': [], 'poses': [], 'gazes': [], 'bboxes': [], 'ocr_graph': []}
        self.size_1, self.size_2 = sizes
        self.patch_data = patch_data
        if self.patch_data:
            self.spatial_patch_size = spatial_patch_size
            self.temporal_patch_size = temporal_patch_size
            self.num_patches = (self.size_1 // self.spatial_patch_size) ** 2
            self.spatial_patch_dim = img_channels * spatial_patch_size ** 2
            assert self.size_1 % spatial_patch_size == 0 and self.size_2 % spatial_patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        self.presaved = self._presaved_dirname(presaved)

    @staticmethod
    def _frame_files(frame_path, episode):
        """Return the naturally sorted frame file paths of one episode directory."""
        episode_dir = os.path.join(frame_path, episode)
        return [os.path.join(episode_dir, name) for name in natsorted(os.listdir(episode_dir))]

    def _load_annotations(self, ids, label_path, pose_path, gaze_path,
                          bbox_path, ocr_graph_path, flatten_dim):
        """Load labels and every enabled modality for the episodes in *ids*.

        *ids* are episode directory names (strings). A modality is skipped
        when its path is ``None``; the path is stored on ``self`` either way
        because ``__getitem__`` uses it to decide what to return.
        """
        print('Labels...')
        self._load_labels(label_path, ids)

        print('Pose...')
        self.pose_path = pose_path
        if pose_path is not None:
            self._load_poses(pose_path, ids, flatten_dim)

        print('Gaze...')
        self.gaze_path = gaze_path
        if gaze_path is not None:
            self._load_gazes(gaze_path, ids, flatten_dim)

        print('Bbox...')
        self.bbox_path = bbox_path
        if bbox_path is not None:
            self._load_bboxes(bbox_path, ids)

        print('OCR...\n')
        self.ocr_graph_path = ocr_graph_path
        if ocr_graph_path is not None:
            # The path is only used as an on/off switch; the graph is hard-coded.
            self._load_ocr_graph(ids)

    def _load_labels(self, label_path, ids):
        """Append ``[left_labels, right_labels]`` for every episode in *ids*."""
        # NOTE: pickle can execute arbitrary code; only load trusted label files.
        with open(label_path, 'rb') as fp:
            left_labels, right_labels = pickle.load(fp)
        for episode in ids:
            episode_id = int(episode)
            self.data['labels'].append([left_labels[episode_id], right_labels[episode_id]])

    def _load_poses(self, pose_path, ids, flatten_dim):
        """Read the ``pose_0`` matrix of every pose file and store one tensor per episode."""
        for episode in ids:
            pose_dir = os.path.join(pose_path, episode)
            poses = []
            for name in natsorted(os.listdir(pose_dir)):
                fs = cv2.FileStorage(os.path.join(pose_dir, name), cv2.FILE_STORAGE_READ)
                try:
                    pose = fs.getNode("pose_0").mat()
                finally:
                    # Release the handle even if reading fails.
                    fs.release()
                if pose.shape != (2, 25, 3):
                    # Keep only the first two entries along the leading axis.
                    pose = pose[:2, :, :]
                poses.append(pose)
            poses = np.array(poses)
            poses[:, :, :, 0] = poses[:, :, :, 0] / self._POSE_X_SCALE
            poses[:, :, :, 1] = poses[:, :, :, 1] / self._POSE_Y_SCALE
            self.data['poses'].append(torch.flatten(torch.tensor(poses), flatten_dim))

    def _load_gazes(self, gaze_path, ids, flatten_dim):
        """Parse ``<episode>.txt`` gaze files into zero-padded ``(frames, 2, 3)`` tensors.

        Gaze information is represented by a 3D gaze vector based on the
        observing camera's Cartesian eye coordinate system. The first line of
        each file is a Python literal: one list of gaze vectors per frame.
        At most two vectors are kept per frame; missing ones stay zero.
        """
        for episode in ids:
            gaze_file = os.path.join(gaze_path, '{}.txt'.format(episode))
            with open(gaze_file, 'r') as fp:
                gaze_frames = ast.literal_eval(fp.readline())
            gaze_tensor = torch.zeros((len(gaze_frames), 2, 3))
            for frame_idx, vectors in enumerate(gaze_frames):
                if len(vectors) >= 2:
                    gaze_tensor[frame_idx, :] = torch.tensor(vectors[:2])
                elif len(vectors) == 1:
                    gaze_tensor[frame_idx, 0] = torch.tensor(vectors[0])
            self.data['gazes'].append(torch.flatten(gaze_tensor, flatten_dim))

    def _load_bboxes(self, bbox_path, ids):
        """Read per-frame bbox text files into ``(frames, objects * 4)`` tensors.

        Every non-empty line holds ``class_index x_center y_center width height``.
        An older pickle-based format (``[x_min, y_min, x_max, y_max]`` divided by
        1920 / 1088) was supported previously and is no longer read.
        """
        self.objects = list(self._OBJECTS)
        for episode in ids:
            bbox_dir = os.path.join(bbox_path, episode)
            # Natural sort keeps files in frame order; sorting by name length
            # alone would leave equal-length names in arbitrary listdir order.
            bbox_files = natsorted(os.listdir(bbox_dir))
            # TODO: we might want to cut it to 10 objects
            bbox_tensor = torch.zeros((len(bbox_files), len(self.objects), 4))
            for frame_idx, name in enumerate(bbox_files):
                with open(os.path.join(bbox_dir, name), 'r') as fp:
                    for line in fp:
                        values = line.split()
                        if not values:
                            continue  # tolerate blank lines
                        class_index, x_center, y_center, x_width, y_height = map(float, values)
                        bbox_tensor[frame_idx, int(class_index)] = torch.tensor(
                            [x_center, y_center, x_width, y_height]
                        )
            self.data['bboxes'].append(torch.flatten(bbox_tensor, 1))

    def _load_ocr_graph(self, ids):
        """Build the row-normalised, flattened graph matrix and store it once per episode."""
        ocr_tensor = torch.zeros((self._OCR_SIZE, self._OCR_SIZE))
        for obj, *contexts in self._OCR_GRAPH:
            total_count = sum(count for _, count in contexts)
            for context_idx, count in contexts:
                ocr_tensor[obj, context_idx] = count / total_count
        ocr_tensor = torch.flatten(ocr_tensor)
        # Every episode shares the same tensor object.
        for _ in ids:
            self.data['ocr_graph'].append(ocr_tensor)

    def _split_time(self, sequence):
        """Group the leading (time) axis into temporal patches when patching is enabled."""
        if not self.patch_data:
            return sequence
        return rearrange(sequence, '(t p1) d -> t p1 d', p1=self.temporal_patch_size)

    def __len__(self):
        """Return the number of loaded episodes."""
        return len(self.data['frame_paths'])

    def __getitem__(self, idx):
        """Return ``(images, labels, pose, gaze, bbox, ocr_graphs)`` for episode *idx*.

        Frames are read from a pickle holding a list of tensors; disabled
        modalities are returned as ``None``. Labels are the stored
        ``[left, right]`` pair converted to a tensor with its two axes swapped.
        """
        # Derive the pickle path from the first frame path: insert the
        # pre-saved directory as the 6th '/'-separated component and replace
        # the trailing '/<frame file>' with '.pkl'.
        # NOTE(review): the hard-coded position 5 ties this to one specific
        # directory depth of frame_path - confirm against the dataset layout.
        parts = self.data['frame_paths'][idx][0].split('/')
        parts.insert(5, self.presaved)
        frames_file = '/'.join(parts[:-1]) + '.pkl'
        # NOTE: pickle can execute arbitrary code; only load trusted files.
        with open(frames_file, 'rb') as f:
            images = torch.stack(pickle.load(f))
        if self.patch_data:
            images = rearrange(
                images,
                '(t p3) c (h p1) (w p2) -> t p3 (h w) (p1 p2 c)',
                p1=self.spatial_patch_size,
                p2=self.spatial_patch_size,
                p3=self.temporal_patch_size
            )

        pose = self._split_time(self.data['poses'][idx]) if self.pose_path is not None else None
        gaze = self._split_time(self.data['gazes'][idx]) if self.gaze_path is not None else None
        bbox = self._split_time(self.data['bboxes'][idx]) if self.bbox_path is not None else None
        ocr_graphs = self.data['ocr_graph'][idx] if self.ocr_graph_path is not None else None

        labels = torch.permute(torch.tensor(self.data['labels'][idx]), (1, 0))
        return images, labels, pose, gaze, bbox, ocr_graphs


class Data(_BossDatasetBase):
    """Dataset over every episode directory found in ``frame_path``.

    Args:
        frame_path: directory containing one sub-directory of frames per
            episode; sub-directory names must be convertible to ``int``.
        label_path: pickle file holding a ``(left_labels, right_labels)`` pair,
            each indexable by integer episode id.
        pose_path: directory of per-episode pose files, or ``None`` to skip.
        gaze_path: directory of ``<episode>.txt`` gaze files, or ``None`` to skip.
        bbox_path: directory of per-episode bbox text files, or ``None`` to skip.
        ocr_graph_path: any non-``None`` value enables the built-in graph matrix.
        presaved: 128 or 224; selects the pre-saved frame directory name.
        sizes: pair of image sizes, validated against ``spatial_patch_size``
            when ``patch_data`` is true.
        spatial_patch_size, temporal_patch_size, img_channels: patching
            parameters, only used when ``patch_data`` is true.
        patch_data: rearrange samples into spatio-temporal patches.
        flatten_dim: start dimension used when flattening pose and gaze tensors.

    Raises:
        ValueError: if ``presaved`` is neither 128 nor 224.
    """

    def __init__(
        self,
        frame_path,
        label_path,
        pose_path=None,
        gaze_path=None,
        bbox_path=None,
        ocr_graph_path=None,
        presaved=128,
        sizes = (128,128),
        spatial_patch_size: int = 16,
        temporal_patch_size: int = 4,
        img_channels: int = 3,
        patch_data: bool = False,
        flatten_dim: int = 1
    ):
        self._configure(presaved, sizes, spatial_patch_size, temporal_patch_size,
                        img_channels, patch_data)

        episode_ids = natsorted(os.listdir(frame_path))
        for episode in episode_ids:
            self.data['frame_paths'].append(self._frame_files(frame_path, episode))

        self._load_annotations(episode_ids, label_path, pose_path, gaze_path,
                               bbox_path, ocr_graph_path, flatten_dim)

        self.frame_path = frame_path
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])


class DataTest(_BossDatasetBase):
    """Dataset over a selectable subset of the episodes in ``frame_path``.

    Takes the same arguments as :class:`Data` plus two optional filters
    (``frame_ids`` wins when both are given):

    Args:
        frame_ids: container of integer episode ids; only episodes whose
            directory name converts to one of these ids are kept.
        median: indexable pair ``(threshold, keep_long)``. With a truthy
            ``keep_long`` only episodes with at least ``threshold`` frames are
            kept, otherwise only episodes with fewer frames.

    Raises:
        ValueError: if ``presaved`` is neither 128 nor 224.
    """

    def __init__(
        self,
        frame_path,
        label_path,
        pose_path=None,
        gaze_path=None,
        bbox_path=None,
        ocr_graph_path=None,
        presaved=128,
        frame_ids=None,
        median=None,
        sizes = (128,128),
        spatial_patch_size: int = 16,
        temporal_patch_size: int = 4,
        img_channels: int = 3,
        patch_data: bool = False,
        flatten_dim: int = 1
    ):
        self._configure(presaved, sizes, spatial_patch_size, temporal_patch_size,
                        img_channels, patch_data)

        episode_ids = natsorted(os.listdir(frame_path))
        test_ids = []
        if frame_ids is not None:
            for episode in episode_ids:
                if int(episode) in frame_ids:
                    test_ids.append(episode)
                    self.data['frame_paths'].append(self._frame_files(frame_path, episode))
        elif median is not None:
            for episode in episode_ids:
                frame_paths = self._frame_files(frame_path, episode)
                seq_len = len(frame_paths)
                if (median[1] and seq_len >= median[0]) or (not median[1] and seq_len < median[0]):
                    self.data['frame_paths'].append(frame_paths)
                    # Keep the directory name (str): the modality loaders
                    # join it onto their base paths.
                    test_ids.append(episode)
        else:
            test_ids = list(episode_ids)
            for episode in episode_ids:
                self.data['frame_paths'].append(self._frame_files(frame_path, episode))

        self._load_annotations(test_ids, label_path, pose_path, gaze_path,
                               bbox_path, ocr_graph_path, flatten_dim)

        self.frame_path = frame_path
        self.transform = transforms.Compose([
            transforms.Resize((self.size_1, self.size_2)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])