import ast
import os
import pickle

import cv2
import numpy as np
import torch
from einops import rearrange
from natsort import natsorted
from torch.utils.data import Dataset
from torchvision import transforms

from utils import spherical2cartesial  # kept from the original; not called directly in this file


class Data(Dataset):
    def __init__(
        self,
        frame_path,
        label_path,
        pose_path=None,
        gaze_path=None,
        bbox_path=None,
        ocr_graph_path=None,
        presaved=128,
        sizes=(128, 128),
        spatial_patch_size: int = 16,
        temporal_patch_size: int = 4,
        img_channels: int = 3,
        patch_data: bool = False,
        flatten_dim: int = 1
    ):
        self.data = {'frame_paths': [], 'labels': [], 'poses': [], 'gazes': [], 'bboxes': [], 'ocr_graph': []}
        self.size_1, self.size_2 = sizes
        self.patch_data = patch_data
        if self.patch_data:
            self.spatial_patch_size = spatial_patch_size
            self.temporal_patch_size = temporal_patch_size
            self.num_patches = (self.size_1 // self.spatial_patch_size) ** 2
            self.spatial_patch_dim = img_channels * spatial_patch_size ** 2

            assert self.size_1 % spatial_patch_size == 0 and self.size_2 % spatial_patch_size == 0, \
                'Image dimensions must be divisible by the patch size.'
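            # With the 128x128 defaults: 128 // 16 = 8 patches per side, so
            # num_patches = 8 ** 2 = 64 and spatial_patch_dim = 3 * 16 ** 2 = 768.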

        if presaved == 224:
            self.presaved = 'presaved'
        elif presaved == 128:
            self.presaved = 'presaved128'
        else:
            raise ValueError(f'Unsupported presaved size: {presaved} (expected 128 or 224)')

        frame_dirs = os.listdir(frame_path)
        episode_ids = natsorted(frame_dirs)
        seq_lens = []  # collected for reference; not used elsewhere in this class
        for frame_dir in episode_ids:
            frame_paths = [os.path.join(frame_path, frame_dir, i) for i in natsorted(os.listdir(os.path.join(frame_path, frame_dir)))]
            seq_lens.append(len(frame_paths))
            self.data['frame_paths'].append(frame_paths)

        print('Labels...')
        with open(label_path, 'rb') as fp:
            labels = pickle.load(fp)
        left_labels, right_labels = labels
        for i in episode_ids:
            episode_id = int(i)
            left_labels_i = left_labels[episode_id]
            right_labels_i = right_labels[episode_id]
            self.data['labels'].append([left_labels_i, right_labels_i])
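        # The label pickle is a (left_labels, right_labels) pair indexed by episode id;
        # __getitem__ permutes each stacked pair from (2, seq_len) to (seq_len, 2).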

        print('Pose...')
        self.pose_path = pose_path
        if pose_path is not None:
            for i in episode_ids:
                pose_i_dir = os.path.join(pose_path, i)
                poses = []
                for j in natsorted(os.listdir(pose_i_dir)):
                    fs = cv2.FileStorage(os.path.join(pose_i_dir, j), cv2.FILE_STORAGE_READ)
                    pose_mat = fs.getNode("pose_0").mat()
                    if pose_mat.shape != (2, 25, 3):
                        # Keep at most the first two detected people (25 joints each).
                        poses.append(pose_mat[:2, :, :])
                    else:
                        poses.append(pose_mat)
                poses = np.array(poses)
                # Normalise pixel coordinates by the 1920x1088 source resolution.
                poses[:, :, :, 0] = poses[:, :, :, 0] / 1920
                poses[:, :, :, 1] = poses[:, :, :, 1] / 1088
                self.data['poses'].append(torch.flatten(torch.tensor(poses), flatten_dim))
                #self.data['poses'].append(torch.tensor(np.array(poses)))

        print('Gaze...')
        """
        Gaze information is represented by a 3D gaze vector based on the observing
        camera's Cartesian eye coordinate system.
        """
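        # Each frame contributes up to two 3D gaze vectors (one per person);
        # spherical2cartesial (imported above) is presumably how such vectors are
        # produced upstream from spherical (theta, phi) gaze predictions.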
        self.gaze_path = gaze_path
        if gaze_path is not None:
            for i in episode_ids:
                gaze_i_txt = os.path.join(gaze_path, '{}.txt'.format(i))
                with open(gaze_i_txt, 'r') as fp:
                    gaze_i = fp.readlines()
                gaze_i = ast.literal_eval(gaze_i[0])
                gaze_i_tensor = torch.zeros((len(gaze_i), 2, 3))
                for j in range(len(gaze_i)):
                    if len(gaze_i[j]) >= 2:
                        gaze_i_tensor[j, :] = torch.tensor(gaze_i[j][:2])
                    elif len(gaze_i[j]) == 1:
                        gaze_i_tensor[j, 0] = torch.tensor(gaze_i[j][0])
                    # Frames with no detected gaze stay zero-filled.
                self.data['gazes'].append(torch.flatten(gaze_i_tensor, flatten_dim))
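            # With the default flatten_dim=1, each episode entry is (num_frames, 6):
            # two people x one 3-vector per frame, zero-filled where detection is missing.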

        print('Bbox...')
        self.bbox_path = bbox_path
        if bbox_path is not None:
            self.objects = [
                'none', 'apple', 'orange', 'lemon', 'potato', 'wine', 'wineopener',
                'knife', 'mug', 'peeler', 'bowl', 'chocolate', 'sugar', 'magazine',
                'cracker', 'chips', 'scissors', 'cap', 'marker', 'sardinecan', 'tomatocan',
                'plant', 'walnut', 'nail', 'waterspray', 'hammer', 'canopener'
            ]
            """
            # NOTE: old bbox
            for i in episode_ids:
                bbox_i_dir = os.path.join(bbox_path, i)
                with open(bbox_i_dir, 'rb') as fp:
                    bboxes_i = pickle.load(fp)
                    len_i = len(bboxes_i)
                    fp.close()
                bboxes_i_tensor = torch.zeros((len_i, len(self.objects), 4))
                for j in range(len(bboxes_i)):
                    items_i_j, bboxes_i_j = bboxes_i[j]
                    for k in range(len(items_i_j)):
                        bboxes_i_tensor[j, self.objects.index(items_i_j[k])] = torch.tensor([
                            bboxes_i_j[k][0] / 1920,  # * self.size_1,
                            bboxes_i_j[k][1] / 1088,  # * self.size_2,
                            bboxes_i_j[k][2] / 1920,  # * self.size_1,
                            bboxes_i_j[k][3] / 1088,  # * self.size_2
                        ])  # [x_min, y_min, x_max, y_max]
                self.data['bboxes'].append(torch.flatten(bboxes_i_tensor, 1))
            """
            # NOTE: new bbox - one YOLO-style .txt per frame, each line reading
            # "class x_center y_center width height"; coordinates appear to be
            # pre-normalised, since no rescaling is applied (unlike the old format above).
            for i in episode_ids:
                bbox_dir = os.path.join(bbox_path, i)
                bbox_tensor = torch.zeros((len(os.listdir(bbox_dir)), len(self.objects), 4))  # TODO: we might want to cut it to 10 objects
                # natsorted gives a deterministic frame order; sorting by filename length
                # left ties in arbitrary os.listdir() order.
                for index, bbox in enumerate(natsorted(os.listdir(bbox_dir))):
                    with open(os.path.join(bbox_dir, bbox), 'r') as fp:
                        bbox_content = fp.readlines()
                    for bbox_content_line in bbox_content:
                        class_index, x_center, y_center, x_width, y_height = map(float, bbox_content_line.split())
                        bbox_tensor[index][int(class_index)] = torch.FloatTensor([x_center, y_center, x_width, y_height])
                self.data['bboxes'].append(torch.flatten(bbox_tensor, 1))
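            # Per-episode result: (num_frames, 27 * 4) = (num_frames, 108) after flattening.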

        print('OCR...\n')
        self.ocr_graph_path = ocr_graph_path
        if ocr_graph_path is not None:
            # Static object-context co-occurrence graph: each entry is
            # [object_id, [context_object_id, count], ...].
            ocr_graph = [
                [15, [10, 4], [17, 2]],
                [13, [16, 7], [18, 4]],
                [11, [16, 4], [7, 10]],
                [14, [10, 11], [7, 1]],
                [12, [10, 9], [16, 3]],
                [1, [7, 2], [9, 9], [10, 2]],
                [5, [8, 8], [6, 8]],
                [4, [9, 8], [7, 6]],
                [3, [10, 1], [8, 3], [7, 4], [9, 2], [6, 1]],
                [2, [10, 1], [7, 7], [9, 3]],
                [19, [10, 2], [26, 6]],
                [20, [10, 7], [26, 5]],
                [22, [25, 4], [10, 8]],
                [23, [25, 15]],
                [21, [16, 5], [24, 8]]
            ]
            ocr_tensor = torch.zeros((27, 27))
            for ocr in ocr_graph:
                obj = ocr[0]
                contexts = ocr[1:]
                total_context_count = sum(i[1] for i in contexts)
                for context in contexts:
                    # Normalise counts into P(context | object).
                    ocr_tensor[obj, context[0]] = context[1] / total_context_count
            ocr_tensor = torch.flatten(ocr_tensor)
            for i in episode_ids:
                self.data['ocr_graph'].append(ocr_tensor)
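            # e.g. object 15 co-occurs with object 10 four times and with object 17 twice,
            # so row 15 of ocr_tensor stores 4/6 at column 10 and 2/6 at column 17.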

        self.frame_path = frame_path
        # ImageNet-style preprocessing; not used by __getitem__, which loads presaved tensors.
        self.transform = transforms.Compose([
            #transforms.ToPILImage(),
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.data['frame_paths'])

    def __getitem__(self, idx):
        # Swap the raw frame directory for its presaved pickle of preprocessed tensors,
        # i.e. .../<presaved>/<episode>.pkl (the insert index 5 assumes the dataset's
        # fixed directory depth).
        frames_path = self.data['frame_paths'][idx][0].split('/')
        frames_path.insert(5, self.presaved)
        frames_path = '/'.join(frames_path[:-1]) + '.pkl'
        with open(frames_path, 'rb') as f:
            images = pickle.load(f)
        images = torch.stack(images)
        if self.patch_data:
            images = rearrange(
                images,
                '(t p3) c (h p1) (w p2) -> t p3 (h w) (p1 p2 c)',
                p1=self.spatial_patch_size,
                p2=self.spatial_patch_size,
                p3=self.temporal_patch_size
            )
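            # Shape sketch with defaults: (4T, 3, 128, 128) -> (T, 4, 64, 768), i.e. T groups
            # of 4 frames, 8*8 = 64 spatial patches, 16*16*3 = 768 values per patch; this
            # requires num_frames to be divisible by temporal_patch_size.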

        if self.pose_path is not None:
            pose = self.data['poses'][idx]
            if self.patch_data:
                pose = rearrange(pose, '(t p1) d -> t p1 d', p1=self.temporal_patch_size)
        else:
            pose = None

        if self.gaze_path is not None:
            gaze = self.data['gazes'][idx]
            if self.patch_data:
                gaze = rearrange(gaze, '(t p1) d -> t p1 d', p1=self.temporal_patch_size)
        else:
            gaze = None

        if self.bbox_path is not None:
            bbox = self.data['bboxes'][idx]
            if self.patch_data:
                bbox = rearrange(bbox, '(t p1) d -> t p1 d', p1=self.temporal_patch_size)
        else:
            bbox = None

        if self.ocr_graph_path is not None:
            ocr_graphs = self.data['ocr_graph'][idx]
        else:
            ocr_graphs = None

        # Labels stacked as (2, seq_len), permuted to (seq_len, 2).
        return images, torch.permute(torch.tensor(self.data['labels'][idx]), (1, 0)), pose, gaze, bbox, ocr_graphs


class DataTest(Dataset):
    def __init__(
        self,
        frame_path,
        label_path,
        pose_path=None,
        gaze_path=None,
        bbox_path=None,
        ocr_graph_path=None,
        presaved=128,
        frame_ids=None,
        median=None,
        sizes=(128, 128),
        spatial_patch_size: int = 16,
        temporal_patch_size: int = 4,
        img_channels: int = 3,
        patch_data: bool = False,
        flatten_dim: int = 1
    ):
        self.data = {'frame_paths': [], 'labels': [], 'poses': [], 'gazes': [], 'bboxes': [], 'ocr_graph': []}
        self.size_1, self.size_2 = sizes
        self.patch_data = patch_data
        if self.patch_data:
            self.spatial_patch_size = spatial_patch_size
            self.temporal_patch_size = temporal_patch_size
            self.num_patches = (self.size_1 // self.spatial_patch_size) ** 2
            self.spatial_patch_dim = img_channels * spatial_patch_size ** 2

            assert self.size_1 % spatial_patch_size == 0 and self.size_2 % spatial_patch_size == 0, \
                'Image dimensions must be divisible by the patch size.'

        if presaved == 224:
            self.presaved = 'presaved'
        elif presaved == 128:
            self.presaved = 'presaved128'
        else:
            raise ValueError(f'Unsupported presaved size: {presaved} (expected 128 or 224)')

        frame_dirs = os.listdir(frame_path)
        episode_ids = natsorted(frame_dirs)
        if frame_ids is not None:
            # Keep only the explicitly requested episode ids.
            test_ids = []
            for frame_dir in episode_ids:
                if int(frame_dir) in frame_ids:
                    test_ids.append(frame_dir)
                    frame_paths = [os.path.join(frame_path, frame_dir, i) for i in natsorted(os.listdir(os.path.join(frame_path, frame_dir)))]
                    self.data['frame_paths'].append(frame_paths)
        elif median is not None:
            # median = (threshold, above): keep episodes with seq_len >= threshold when
            # above is truthy, otherwise those with seq_len < threshold.
            test_ids = []
            for frame_dir in episode_ids:
                frame_paths = [os.path.join(frame_path, frame_dir, i) for i in natsorted(os.listdir(os.path.join(frame_path, frame_dir)))]
                seq_len = len(frame_paths)
                if (median[1] and seq_len >= median[0]) or (not median[1] and seq_len < median[0]):
                    self.data['frame_paths'].append(frame_paths)
                    # Keep ids as str: they are joined into paths below, and the original
                    # int(frame_dir) would break os.path.join for pose/bbox loading.
                    test_ids.append(frame_dir)
        else:
            test_ids = episode_ids.copy()
            for frame_dir in episode_ids:
                frame_paths = [os.path.join(frame_path, frame_dir, i) for i in natsorted(os.listdir(os.path.join(frame_path, frame_dir)))]
                self.data['frame_paths'].append(frame_paths)
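        # Example: median=(50, True) keeps every episode with at least 50 frames, while
        # frame_ids=[3, 7] would instead select exactly episodes '3' and '7'.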

        print('Labels...')
        with open(label_path, 'rb') as fp:
            labels = pickle.load(fp)
        left_labels, right_labels = labels
        for i in test_ids:
            episode_id = int(i)
            left_labels_i = left_labels[episode_id]
            right_labels_i = right_labels[episode_id]
            self.data['labels'].append([left_labels_i, right_labels_i])

        print('Pose...')
        self.pose_path = pose_path
        if pose_path is not None:
            for i in test_ids:
                pose_i_dir = os.path.join(pose_path, i)
                poses = []
                for j in natsorted(os.listdir(pose_i_dir)):
                    fs = cv2.FileStorage(os.path.join(pose_i_dir, j), cv2.FILE_STORAGE_READ)
                    pose_mat = fs.getNode("pose_0").mat()
                    if pose_mat.shape != (2, 25, 3):
                        # Keep at most the first two detected people (25 joints each).
                        poses.append(pose_mat[:2, :, :])
                    else:
                        poses.append(pose_mat)
                poses = np.array(poses)
                # Normalise pixel coordinates by the 1920x1088 source resolution.
                poses[:, :, :, 0] = poses[:, :, :, 0] / 1920
                poses[:, :, :, 1] = poses[:, :, :, 1] / 1088
                self.data['poses'].append(torch.flatten(torch.tensor(poses), flatten_dim))
        print('Gaze...')
        self.gaze_path = gaze_path
        if gaze_path is not None:
            for i in test_ids:
                gaze_i_txt = os.path.join(gaze_path, '{}.txt'.format(i))
                with open(gaze_i_txt, 'r') as fp:
                    gaze_i = fp.readlines()
                gaze_i = ast.literal_eval(gaze_i[0])
                gaze_i_tensor = torch.zeros((len(gaze_i), 2, 3))
                for j in range(len(gaze_i)):
                    if len(gaze_i[j]) >= 2:
                        gaze_i_tensor[j, :] = torch.tensor(gaze_i[j][:2])
                    elif len(gaze_i[j]) == 1:
                        gaze_i_tensor[j, 0] = torch.tensor(gaze_i[j][0])
                    # Frames with no detected gaze stay zero-filled.
                self.data['gazes'].append(torch.flatten(gaze_i_tensor, flatten_dim))

        print('Bbox...')
        self.bbox_path = bbox_path
        if bbox_path is not None:
            self.objects = [
                'none', 'apple', 'orange', 'lemon', 'potato', 'wine', 'wineopener',
                'knife', 'mug', 'peeler', 'bowl', 'chocolate', 'sugar', 'magazine',
                'cracker', 'chips', 'scissors', 'cap', 'marker', 'sardinecan', 'tomatocan',
                'plant', 'walnut', 'nail', 'waterspray', 'hammer', 'canopener'
            ]
            """ NOTE: old bbox
            for i in test_ids:
                bbox_i_dir = os.path.join(bbox_path, i)
                with open(bbox_i_dir, 'rb') as fp:
                    bboxes_i = pickle.load(fp)
                    len_i = len(bboxes_i)
                    fp.close()
                bboxes_i_tensor = torch.zeros((len_i, len(self.objects), 4))
                for j in range(len(bboxes_i)):
                    items_i_j, bboxes_i_j = bboxes_i[j]
                    for k in range(len(items_i_j)):
                        bboxes_i_tensor[j, self.objects.index(items_i_j[k])] = torch.tensor([
                            bboxes_i_j[k][0] / 1920,  # * self.size_1,
                            bboxes_i_j[k][1] / 1088,  # * self.size_2,
                            bboxes_i_j[k][2] / 1920,  # * self.size_1,
                            bboxes_i_j[k][3] / 1088,  # * self.size_2
                        ])  # [x_min, y_min, x_max, y_max]
                self.data['bboxes'].append(torch.flatten(bboxes_i_tensor, 1))
            """
            # New bbox format: one YOLO-style .txt per frame
            # ("class x_center y_center width height", pre-normalised coordinates).
            for i in test_ids:
                bbox_dir = os.path.join(bbox_path, i)
                bbox_tensor = torch.zeros((len(os.listdir(bbox_dir)), len(self.objects), 4))  # TODO: we might want to cut it to 10 objects
                # natsorted for a deterministic frame order (see note in Data).
                for index, bbox in enumerate(natsorted(os.listdir(bbox_dir))):
                    with open(os.path.join(bbox_dir, bbox), 'r') as fp:
                        bbox_content = fp.readlines()
                    for bbox_content_line in bbox_content:
                        class_index, x_center, y_center, x_width, y_height = map(float, bbox_content_line.split())
                        bbox_tensor[index][int(class_index)] = torch.FloatTensor([x_center, y_center, x_width, y_height])
                self.data['bboxes'].append(torch.flatten(bbox_tensor, 1))

        print('OCR...\n')
        self.ocr_graph_path = ocr_graph_path
        if ocr_graph_path is not None:
            # Same static object-context co-occurrence graph as in Data.
            ocr_graph = [
                [15, [10, 4], [17, 2]],
                [13, [16, 7], [18, 4]],
                [11, [16, 4], [7, 10]],
                [14, [10, 11], [7, 1]],
                [12, [10, 9], [16, 3]],
                [1, [7, 2], [9, 9], [10, 2]],
                [5, [8, 8], [6, 8]],
                [4, [9, 8], [7, 6]],
                [3, [10, 1], [8, 3], [7, 4], [9, 2], [6, 1]],
                [2, [10, 1], [7, 7], [9, 3]],
                [19, [10, 2], [26, 6]],
                [20, [10, 7], [26, 5]],
                [22, [25, 4], [10, 8]],
                [23, [25, 15]],
                [21, [16, 5], [24, 8]]
            ]
            ocr_tensor = torch.zeros((27, 27))
            for ocr in ocr_graph:
                obj = ocr[0]
                contexts = ocr[1:]
                total_context_count = sum(i[1] for i in contexts)
                for context in contexts:
                    ocr_tensor[obj, context[0]] = context[1] / total_context_count
            ocr_tensor = torch.flatten(ocr_tensor)
            for i in test_ids:
                self.data['ocr_graph'].append(ocr_tensor)

        self.frame_path = frame_path
        # Unlike Data, this resizes directly to the target size with no centre crop;
        # also unused by __getitem__, which loads presaved tensors.
        self.transform = transforms.Compose([
            #transforms.ToPILImage(),
            transforms.Resize((self.size_1, self.size_2)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.data['frame_paths'])

    def __getitem__(self, idx):
        # Same presaved-pickle path swap as in Data.__getitem__.
        frames_path = self.data['frame_paths'][idx][0].split('/')
        frames_path.insert(5, self.presaved)
        frames_path = '/'.join(frames_path[:-1]) + '.pkl'
        with open(frames_path, 'rb') as f:
            images = pickle.load(f)
        images = torch.stack(images)
        if self.patch_data:
            images = rearrange(
                images,
                '(t p3) c (h p1) (w p2) -> t p3 (h w) (p1 p2 c)',
                p1=self.spatial_patch_size,
                p2=self.spatial_patch_size,
                p3=self.temporal_patch_size
            )

        if self.pose_path is not None:
            pose = self.data['poses'][idx]
            if self.patch_data:
                pose = rearrange(pose, '(t p1) d -> t p1 d', p1=self.temporal_patch_size)
        else:
            pose = None

        if self.gaze_path is not None:
            gaze = self.data['gazes'][idx]
            if self.patch_data:
                gaze = rearrange(gaze, '(t p1) d -> t p1 d', p1=self.temporal_patch_size)
        else:
            gaze = None

        if self.bbox_path is not None:
            bbox = self.data['bboxes'][idx]
            if self.patch_data:
                bbox = rearrange(bbox, '(t p1) d -> t p1 d', p1=self.temporal_patch_size)
        else:
            bbox = None

        if self.ocr_graph_path is not None:
            ocr_graphs = self.data['ocr_graph'][idx]
        else:
            ocr_graphs = None

        return images, torch.permute(torch.tensor(self.data['labels'][idx]), (1, 0)), pose, gaze, bbox, ocr_graphs
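

if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the original pipeline: it shows how the
    # loader is meant to be driven. All paths below are hypothetical placeholders for
    # wherever the frames, label pickle, and per-modality annotations live on disk.
    from torch.utils.data import DataLoader

    dataset = Data(
        frame_path='data/frames',        # hypothetical layout
        label_path='data/labels.pkl',    # pickle of (left_labels, right_labels)
        pose_path='data/poses',
        gaze_path='data/gazes',
        bbox_path='data/bboxes',
        ocr_graph_path='data/ocr',
        presaved=128,
    )
    # batch_size=1 sidesteps collation of variable-length episodes.
    loader = DataLoader(dataset, batch_size=1, shuffle=True)
    images, labels, pose, gaze, bbox, ocr = next(iter(loader))
    print(images.shape, labels.shape)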