from __future__ import annotations

from typing import Optional, Union

import torch
import pickle
import time
import glob
import random
import os
import numpy as np
import pandas as pd
import cv2
from itertools import product
import csv
import torchvision.transforms as T

from utils.helpers import tracker_skeID, CLIPS_IDS_88, ALL_IDS, UNIQUE_OBJ_IDS


def collate_fn(batch):
    # Unpack the batch into individual elements
    img_3rd_pov, img_tracker, img_battery, pose1, pose2, bboxes, tracker_id, gaze, labels, exp_id, timestep = zip(*batch)

    # Determine the maximum number of objects in any sample of the batch
    max_n_obj = max(bbox.shape[1] for bbox in bboxes)

    # Pad the bounding box tensors along the object dimension
    bboxes_pad = []
    for bbox in bboxes:
        pad_size = max_n_obj - bbox.shape[1]
        pad = torch.zeros((bbox.shape[0], pad_size, bbox.shape[2]), dtype=torch.float32)
        padded_bbox = torch.cat((bbox, pad), dim=1)
        bboxes_pad.append(padded_bbox)

    # Stack the padded tensors into a batch tensor
    bboxes_batch = torch.stack(bboxes_pad, dim=0)

    img_3rd_pov = torch.stack(img_3rd_pov, dim=0)
    img_tracker = torch.stack(img_tracker, dim=0)
    img_battery = torch.stack(img_battery, dim=0)
    pose1 = torch.stack(pose1, dim=0)
    pose2 = torch.stack(pose2, dim=0)
    gaze = torch.stack(gaze, dim=0)
    labels = torch.tensor(labels, dtype=torch.long)

    # Return the batched tensors
    return img_3rd_pov, img_tracker, img_battery, pose1, pose2, bboxes_batch, tracker_id, gaze, labels, exp_id, timestep


def collate_fn_test(batch):
    # Unpack the batch into individual elements
    img_3rd_pov, img_tracker, img_battery, pose1, pose2, bboxes, tracker_id, gaze, labels, exp_id, timestep, false_beliefs = zip(*batch)

    # Determine the maximum number of objects in any sample of the batch
    max_n_obj = max(bbox.shape[1] for bbox in bboxes)

    # Pad the bounding box tensors along the object dimension
    bboxes_pad = []
    for bbox in bboxes:
        pad_size = max_n_obj - bbox.shape[1]
        pad = torch.zeros((bbox.shape[0], pad_size, bbox.shape[2]), dtype=torch.float32)
        padded_bbox = torch.cat((bbox, pad), dim=1)
        bboxes_pad.append(padded_bbox)

    # Stack the padded tensors into a batch tensor
    bboxes_batch = torch.stack(bboxes_pad, dim=0)

    img_3rd_pov = torch.stack(img_3rd_pov, dim=0)
    img_tracker = torch.stack(img_tracker, dim=0)
    img_battery = torch.stack(img_battery, dim=0)
    pose1 = torch.stack(pose1, dim=0)
    pose2 = torch.stack(pose2, dim=0)
    gaze = torch.stack(gaze, dim=0)
    labels = torch.tensor(labels, dtype=torch.long)

    # Return the batched tensors
    return img_3rd_pov, img_tracker, img_battery, pose1, pose2, bboxes_batch, tracker_id, gaze, labels, exp_id, timestep, false_beliefs
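
# NOTE (usage sketch, not part of the original code; `dataset` and `loader` below
# are illustrative names): both collate functions are meant to be passed to a
# torch DataLoader, e.g.
#
#   loader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn)                 # train / val
#   test_loader = DataLoader(test_dataset, batch_size=4, collate_fn=collate_fn_test)  # test
#
# They zero-pad each sample's bounding-box tensor along the object dimension to the
# largest number of objects in the batch, so downstream code may want to mask out
# all-zero boxes. tracker_id, exp_id and timestep (and false_beliefs in the test
# variant) are returned as per-sample tuples rather than stacked tensors.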
class TBDDataset(torch.utils.data.Dataset):

    def __init__(
        self,
        path: str = "/scratch/bortoletto/data/tbd",
        mode: str = "train",
        tbd_data_path: str = "/scratch/bortoletto/data/tbd/mind_lstm_training_cnn_att/",
        list_of_ids_to_consider: list = ALL_IDS,
        use_preprocessed_img: bool = True,
        resize_img: Optional[Union[tuple, int]] = (128, 128),
    ):
        """TBD Dataset based on the 88 clip version of the TBD data.

        Expects the following folder structure:

        - path
            - tracker_gt_smooth <- eye tracking from the POV camera, 2D coordinates
            - images/*/         <- images by experiment id
                - battery       <- images, 1st person (battery camera)
                - tracker       <- images, 1st person (other person) w/ eye fixation
                - kinect        <- images, 3rd person
            - skeleton          <- pose estimation, 3D coordinates
            - annotation        <- labels, i.e. values in [0, 3] (see below)

        Labels are structured as follows:

        {
            "O1": [                     <- object with id O1
                {
                    "m1": {
                        "fluent": 3,    <- 0: enter, 1: disappear, 2: update, 3: unchange
                        "loc": null
                    },
                    "m2": {
                        "fluent": 3,
                        "loc": null
                    },
                    "m12": {
                        "fluent": 3,
                        "loc": null
                    },
                    "m21": {
                        "fluent": 3,
                        "loc": null
                    },
                    "mc": {
                        "fluent": 3,
                        "loc": null
                    },
                    "mg": {
                        "fluent": 3,
                        "loc": [22, 9]
                    }
                },
                ...
            ],
            ...
        }

        This corresponds to a strict subset of the raw dataset collected by the TBD
        authors in their paper "Learning Triadic Belief Dynamics in Nonverbal
        Communication from Videos" (CVPR 2021, Oral).

        We keep small amounts of data in memory (everything < 100MB); otherwise we
        read from disk on the fly. This dataset applies normalization.

        Args:
            path (str, optional): Where the folders lie.
                Defaults to "/scratch/ruhdorfer/triadic_beleif_data_v2".
            list_of_ids_to_consider (list, optional): List of ids to consider.
                Defaults to ALL_IDS. Otherwise specify a list,
                e.g. ["test_94342_23", "test_boelter_21", ...].
            resize_img (Optional[Union[tuple, int]], optional): Resize image to this
                size if required. Defaults to None.
        """
        print(f"Loading TBD Dataset in mode {mode}...")
        self.mode = mode
        start = time.time()

        self.skeleton_3D_path = f"{path}/skeleton"
        self.tracker_2D_path = f"{path}/tracker_gt_smooth"
        self.bbox_csv_path = f"{path}/annotations_with_bbox.csv"
        if use_preprocessed_img:
            self.img_path = f"{path}/images_norm"
        else:
            self.img_path = f"{path}/images"
        self.obj_ids_path = f"{path}/mind_lstm_training_cnn_att_shu.pkl"

        # every label is a flat index into the cartesian product of the four
        # fluents over the five mind types (m1, m2, m12, m21, mc)
        self.label_map = list(product([0, 1, 2, 3], repeat=5))

        clips = os.listdir(tbd_data_path)
        data = []
        labels = []
        for clip in clips:
            with open(tbd_data_path + clip, 'rb') as f:
                vec_input, label_ = pickle.load(f, encoding='latin1')
            data = data + vec_input
            labels = labels + label_
        c = list(zip(data, labels))
        # NOTE: the 60/20/20 split is re-shuffled on every instantiation; seed the
        # RNG (see the commented-out seeds in __main__) for reproducible splits.
        random.shuffle(c)
        train_ratio = int(len(c) * 0.6)
        validate_ratio = int(len(c) * 0.2)
        data, label = zip(*c)
        train_x, train_y = data[:train_ratio], label[:train_ratio]
        validate_x, validate_y = data[train_ratio:train_ratio + validate_ratio], label[train_ratio:train_ratio + validate_ratio]
        test_x, test_y = data[train_ratio + validate_ratio:], label[train_ratio + validate_ratio:]

        self.mind_count = np.zeros(1024)  # used for CE weights

        if mode == "train":
            self.data, self.labels = train_x, train_y
        elif mode == "val":
            self.data, self.labels = validate_x, validate_y
        elif mode == "test":
            self.data, self.labels = test_x, test_y

        self.false_beliefs_path = f"{path}/store_mind_set"

        # keep small amounts of data in memory
        self.skeleton_3D = self.load_skeleton_3D(self.skeleton_3D_path, list_of_ids_to_consider)
        self.tracker_2D = self.load_tracker_2D(self.tracker_2D_path, list_of_ids_to_consider)
        self.bbox_df = pd.read_csv(self.bbox_csv_path, header=0)
        self.obj_ids = self.load_obj_ids(self.obj_ids_path)

        if not use_preprocessed_img:
            normalisation_steps = [
                T.ToTensor(),
                T.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                )
            ]
            if resize_img is not None:
                normalisation_steps.insert(1, T.Resize(resize_img))
            self.preprocess_img = T.Compose(normalisation_steps)
        else:
            self.preprocess_img = None

        self.use_preprocessed_img = use_preprocessed_img

        print(f"Done loading in {time.time() - start}s.")

    def __len__(self):
        return len(self.data)
    def __getitem__(
        self,
        idx: int
    ) -> tuple[
        torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
        str, torch.Tensor, dict, str, int, str
    ]:
        """Given an index, return the corresponding experiment_id and timestep in the
        experiment, then pick the appropriate data and labels from these.

        Args:
            idx (int): Index of the sample to load.

        Returns:
            tuple: (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
            torch.Tensor, str, torch.Tensor, dict) with the following content:
                - img_kinect: torch.Tensor of shape (T, C, H, W) (default is [T, 3, 720, 1280])
                - img_tracker: torch.Tensor of shape (T, H, W, C)
                - img_battery: torch.Tensor of shape (T, H, W, C)
                - skeleton_3D: torch.Tensor of shape (T, 26, 3) (skeleton 1)
                - skeleton_3D: torch.Tensor of shape (T, 26, 3) (skeleton 2)
                - bbox: torch.Tensor of shape (T, num_obj, 4)
                - tracker to skeleton ID: str (either skeleton 1 or 2)
                - tracker_2D: torch.Tensor of shape (T, 2)
                - labels: dict (see the __init__ docstring)
        """
        labels = self.label_map[self.labels[idx]]
        experiment_id = self.data[idx][1][0].split('/')[6]
        img_data_path = f"{self.img_path}/{experiment_id}"
        frame_ids = [int(os.path.basename(self.data[idx][1][i]).split('_')[0]) for i in range(len(self.data[idx][1]))]

        # collect the (sorted) frame paths for the three camera views
        ext = "pt" if self.use_preprocessed_img else "jpg"
        kinect = sorted(glob.glob(f"{img_data_path}/kinect/*.{ext}"))
        tracker = sorted(glob.glob(f"{img_data_path}/tracker/*.{ext}"))
        battery = sorted(glob.glob(f"{img_data_path}/battery/*.{ext}"))
        kinect_img_paths = [kinect[i] for i in frame_ids]
        tracker_img_paths = [tracker[i] for i in frame_ids]
        battery_img_paths = [battery[i] for i in frame_ids]

        # load images
        kinect_imgs = [
            torch.tensor(torch.load(img_path)) if self.use_preprocessed_img
            else self.preprocess_img(cv2.imread(img_path))
            for img_path in kinect_img_paths
        ]
        kinect_imgs = torch.stack(kinect_imgs, dim=0)

        tracker_imgs = [
            torch.tensor(torch.load(img_path)) if self.use_preprocessed_img
            else self.preprocess_img(cv2.imread(img_path))
            for img_path in tracker_img_paths
        ]
        tracker_imgs = torch.stack(tracker_imgs, dim=0)

        battery_imgs = [
            torch.tensor(torch.load(img_path)) if self.use_preprocessed_img
            else self.preprocess_img(cv2.imread(img_path))
            for img_path in battery_img_paths
        ]
        battery_imgs = torch.stack(battery_imgs, dim=0)

        # load object id to check for false beliefs - only for testing
        if self.mode == "test":  # or self.mode == "train":
            if f"{experiment_id}.txt" in os.listdir(self.false_beliefs_path):
                obj_id = self.obj_ids[experiment_id][frame_ids[-1]]
                obj_id = next(x for x in obj_id if x is not None)
                # lines in store_mind_set/<experiment_id>.txt are parsed as "frame,obj_id,label"
                false_belief = next(
                    (line.strip().split(',')[2]
                     for line in open(f"{self.false_beliefs_path}/{experiment_id}.txt")
                     if line.startswith(str(frame_ids[-1]) + ',' + obj_id + ',')),
                    "no"
                )
                #if experiment_id in ['test_boelter4_0', 'test_boelter4_7', 'test_boelter4_6', 'test_boelter4_8', 'test_boelter2_3',
                #                     'test_94342_20', 'test_94342_18', 'test_94342_11', 'test_94342_17', 'test_boelter3_8', 'test_94342_2',
                #                     'test_boelter2_17', 'test_boelter3_7', 'test_94342_4', 'test_boelter3_9', 'test_boelter_10',
                #                     'test_boelter2_6', 'test_boelter4_10', 'test_boelter4_2', 'test_boelter4_5', 'test_94342_24',
                #                     'test_94342_15', 'test_boelter3_5', 'test_94342_8', 'test2', 'test_boelter3_12']:
                #    print('here!')
                #    with open(os.path.join(f'results/hgm_test_fb.csv'), mode='a') as file:
                #        writer = csv.writer(file)
                #        writer.writerow([experiment_id, obj_id, str(frame_ids[-1]), false_belief, labels[0], labels[1], labels[2], labels[3], labels[4]])
            else:
                false_belief = "no"
            #with open(os.path.join(f'results/test_fb.csv'), mode='a') as file:
            #    writer = csv.writer(file)
            #    writer.writerow([experiment_id, str(frame_ids[-1]), false_belief, labels[0], labels[1], labels[2], labels[3], labels[4]])

        df = self.bbox_df[
            (self.bbox_df.experiment_name == experiment_id)
            #& (self.bbox_df.name == obj_id)  # NOTE: load the bounding boxes for all objects
            & (self.bbox_df.name != 'P1')
            & (self.bbox_df.name != 'P2')
            & (self.bbox_df.frame.isin(frame_ids))
        ]

        bboxes = []
        for f in frame_ids:
            bbox = torch.tensor(
                df.loc[df['frame'] == f, ["x_min", "y_min", "x_max", "y_max"]].to_numpy(),
                dtype=torch.float32
            )
            # normalise coordinates by the kinect resolution (1280 x 720)
            bbox[:, 0] = bbox[:, 0] / 1280.0
            bbox[:, 1] = bbox[:, 1] / 720.0
            bbox[:, 2] = bbox[:, 2] / 1280.0
            bbox[:, 3] = bbox[:, 3] / 720.0
            bboxes.append(bbox)
        # NOTE: this will need a collate function because not every video has the same number of objects
        bboxes = torch.stack(bboxes)

        skele1 = self.skeleton_3D[experiment_id]["skele1"][frame_ids]
        skele2 = self.skeleton_3D[experiment_id]["skele2"][frame_ids]
        gaze = self.tracker_2D[experiment_id][frame_ids]

        if self.mode == "test":
            return (
                kinect_imgs,
                tracker_imgs,
                battery_imgs,
                skele1,
                skele2,
                bboxes,
                tracker_skeID[experiment_id],  # <- This is the tracker skeleton ID
                gaze,
                labels,  # <- per object: "m1", "m2", "m12", "m21", "mc"
                experiment_id,
                frame_ids,
                #self.onehot(int(obj_id[1:])),  # <- This is the object ID as a one-hot encoding
                false_belief
            )
        else:
            return (
                kinect_imgs,
                tracker_imgs,
                battery_imgs,
                skele1,
                skele2,
                bboxes,
                tracker_skeID[experiment_id],  # <- This is the tracker skeleton ID
                gaze,
                labels,  # <- per object: "m1", "m2", "m12", "m21", "mc"
                experiment_id,
                frame_ids
                #self.onehot(int(obj_id[1:]))  # <- This is the object ID as a one-hot encoding
            )
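
    # NOTE (illustrative sketch, not part of the original implementation): the flat
    # class index stored in self.labels can be converted to and from the five mind
    # fluents via self.label_map, since label_map = list(product([0, 1, 2, 3], repeat=5)):
    #
    #   m1, m2, m12, m21, mc = self.label_map[y]            # decode flat index y
    #   y = self.label_map.index((m1, m2, m12, m21, mc))    # encode back
    #
    # This is the conversion that the extract_labels TODO below is meant to implement.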
    def onehot(self, x, n=len(UNIQUE_OBJ_IDS)):
        retval = torch.zeros(n)
        if x > 0:
            retval[x - 1] = 1
        return retval

    def load_obj_ids(self, path: str):
        with open(path, "rb") as f:
            ids = pickle.load(f)
        return ids

    def extract_labels(self):
        """TODO: Convert an index label to the [m1, m2, m12, m21, mc] format."""
        return

    def _flatten_mind_obj_timestep(self, mind_obj_dict: dict) -> list:
        """Flattens the mind object dict to a list.

        I.e. takes

        {
            "m1": {
                "fluent": 3,    <- 0: enter, 1: disappear, 2: update, 3: unchange
                "loc": null
            },
            "m2": {
                "fluent": 3,
                "loc": null
            },
            "m12": {
                "fluent": 3,
                "loc": null
            },
            "m21": {
                "fluent": 3,
                "loc": null
            },
            "mc": {
                "fluent": 3,
                "loc": null
            },
            "mg": {
                "fluent": 3,
                "loc": [22, 9]
            }
        }

        and returns [3, 3, 3, 3, 3] (the "mg" entry is excluded).

        Args:
            mind_obj_dict (dict): Mind object dict as described in the __init__ docstring.

        Returns:
            list: List of mind object labels.
        """
        return np.array([mind_obj["fluent"] for key, mind_obj in mind_obj_dict.items() if key != "mg"])

    def load_skeleton_3D(self, path: str, list_of_ids_to_consider: list):
        """Load skeleton 3D data from disk.

        - path
            - *              <- list of ids
                - skele1.p   <- 3D coordinates per id and timestep
                - skele2.p

        Args:
            path (str): Where the skeleton 3D data lie.
            list_of_ids_to_consider (list): List of ids to consider. Defaults to None,
                which means all ids. Otherwise specify a list,
                e.g. ["test_94342_23", "test_boelter_21", ...].

        Returns:
            dict: skeleton 3D data as described in the __init__ docstring.
""" skeleton_3D = {} for experiment_id in list_of_ids_to_consider: skeleton_3D[experiment_id] = {} with open(f"{path}/{experiment_id}/skele1.p", "rb") as f: skeleton_3D[experiment_id]["skele1"] = torch.tensor(np.array(pickle.load(f, encoding="latin1")), dtype=torch.float32) with open(f"{path}/{experiment_id}/skele2.p", "rb") as f: skeleton_3D[experiment_id]["skele2"] = torch.tensor(np.array(pickle.load(f, encoding="latin1")), dtype=torch.float32) return skeleton_3D def load_tracker_2D(self, path: str, list_of_ids_to_consider: list): """Load tracker 2D data from disk. - path - *.p <- 2D coord per id and timestep Args: path (str): Where the tracker 2D data lie. list_of_ids_to_consider (list): List of ids to consider. Defaults to None which means all ids. Otherwise specify a list, e.g. ["test_94342_23", "test_boelter_21", ...]. Returns: dict: tracker 2D data. """ tracker_2D = {} for experiment_id in list_of_ids_to_consider: with open(f"{path}/{experiment_id}.p", "rb") as f: tracker_2D[experiment_id] = torch.tensor(np.array(pickle.load(f, encoding="latin1")), dtype=torch.float32) return tracker_2D def load_bbox(self, path: str, list_of_ids_to_consider: list): """Load bbox data from disk. - bbox_tensors.pickle <- bbox per experiment id one tensor Args: path (str): Where the bbox data lie. list_of_ids_to_consider (list): List of ids to consider. Returns: dict: bbox data. """ with open(path, "rb") as f: pickle_data = pickle.load(f) for key in CLIPS_IDS_88: if key not in list_of_ids_to_consider: pickle_data.pop(key, None) return pickle_data if __name__ == '__main__': # os.environ['PYTHONHASHSEED'] = str(42) # torch.manual_seed(42) # np.random.seed(42) # random.seed(42) data = TBDDataset(use_preprocessed_img=True, mode="test") from tqdm import tqdm for i in tqdm(range(data.__len__())): data[i] breakpoint() from torch.utils.data import DataLoader # Just for guessing time data_0=data[0] data_last=data[len(data)-1] idx = np.random.randint(1, len(data)-1) # Something in between. start = time.time() ( kinect_imgs, # <- len x 720 x 1280 x 3 originally, likely smaller now tracker_imgs, battery_imgs, skele1, skele2, bbox, tracker_skeID_sample, # <- This is the tracker skeleton ID tracker2d, label, experiment_id, # From here for debugging timestep, #obj_id, # <- This is the object ID as a one-hot false_belief ) = data[idx] end = time.time() print(f"Time for one sample: {end-start}") print('kinect:', kinect_imgs.shape) print('tracker:', tracker_imgs.shape) print('battery:', battery_imgs.shape) print('skele1:', skele1.shape) print('skele2:', skele2.shape) print('gaze:', tracker2d.shape) print('bbox:', bbox.shape) print('label:', label) #breakpoint() dl = DataLoader( data, batch_size=4, shuffle=False, collate_fn=collate_fn ) from tqdm import tqdm for j, batch in tqdm(enumerate(dl)): #print(j, end='\r') img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze, labels, exp_id, timestep = batch #breakpoint() #print(img_3rd_pov.shape) #print(img_tracker.shape) #print(img_battery.shape) #print(pose1.shape, pose2.shape) #print(bbox.shape) #print(gaze.shape) breakpoint()