# mtomnet/tbd/tbd_dataloader.py
from __future__ import annotations

import csv
import glob
import os
import pickle
import random
import time
from itertools import product
from typing import Optional, Union

import cv2
import numpy as np
import pandas as pd
import torch
import torchvision.transforms as T

from utils.helpers import tracker_skeID, CLIPS_IDS_88, ALL_IDS, UNIQUE_OBJ_IDS

def collate_fn(batch):
    # Unpack the batch into its individual elements
    img_3rd_pov, img_tracker, img_battery, pose1, pose2, bboxes, tracker_id, gaze, labels, exp_id, timestep = zip(*batch)
    # Determine the maximum number of objects in any sample of the batch
    max_n_obj = max(bbox.shape[1] for bbox in bboxes)
    # Zero-pad the bounding box tensors along the object dimension
    bboxes_pad = []
    for bbox in bboxes:
        pad_size = max_n_obj - bbox.shape[1]
        pad = torch.zeros((bbox.shape[0], pad_size, bbox.shape[2]), dtype=torch.float32)
        padded_bbox = torch.cat((bbox, pad), dim=1)
        bboxes_pad.append(padded_bbox)
    # Stack the padded tensors into a single batch tensor
    bboxes_batch = torch.stack(bboxes_pad, dim=0)
    img_3rd_pov = torch.stack(img_3rd_pov, dim=0)
    img_tracker = torch.stack(img_tracker, dim=0)
    img_battery = torch.stack(img_battery, dim=0)
    pose1 = torch.stack(pose1, dim=0)
    pose2 = torch.stack(pose2, dim=0)
    gaze = torch.stack(gaze, dim=0)
    labels = torch.tensor(labels, dtype=torch.long)
    # Return the batched tensors
    return img_3rd_pov, img_tracker, img_battery, pose1, pose2, bboxes_batch, tracker_id, gaze, labels, exp_id, timestep
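
# A minimal sketch of the padding behaviour above (shapes are illustrative, not
# taken from the dataset): two samples whose bbox tensors have shapes
# (T=5, 2, 4) and (5, 4, 4) are zero-padded along dim 1 to (5, 4, 4) each and
# stacked into a (B=2, 5, 4, 4) batch tensor. All-zero rows therefore mean
# "no object" and presumably need to be masked downstream.
#
#   a = torch.rand(5, 2, 4)
#   b = torch.rand(5, 4, 4)
#   pad = torch.zeros(5, 4 - a.shape[1], 4)
#   batched = torch.stack([torch.cat((a, pad), dim=1), b])  # -> (2, 5, 4, 4)
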
def collate_fn_test(batch):
    # Unpack the batch into its individual elements
    img_3rd_pov, img_tracker, img_battery, pose1, pose2, bboxes, tracker_id, gaze, labels, exp_id, timestep, false_beliefs = zip(*batch)
    # Determine the maximum number of objects in any sample of the batch
    max_n_obj = max(bbox.shape[1] for bbox in bboxes)
    # Zero-pad the bounding box tensors along the object dimension
    bboxes_pad = []
    for bbox in bboxes:
        pad_size = max_n_obj - bbox.shape[1]
        pad = torch.zeros((bbox.shape[0], pad_size, bbox.shape[2]), dtype=torch.float32)
        padded_bbox = torch.cat((bbox, pad), dim=1)
        bboxes_pad.append(padded_bbox)
    # Stack the padded tensors into a single batch tensor
    bboxes_batch = torch.stack(bboxes_pad, dim=0)
    img_3rd_pov = torch.stack(img_3rd_pov, dim=0)
    img_tracker = torch.stack(img_tracker, dim=0)
    img_battery = torch.stack(img_battery, dim=0)
    pose1 = torch.stack(pose1, dim=0)
    pose2 = torch.stack(pose2, dim=0)
    gaze = torch.stack(gaze, dim=0)
    labels = torch.tensor(labels, dtype=torch.long)
    # Return the batched tensors
    return img_3rd_pov, img_tracker, img_battery, pose1, pose2, bboxes_batch, tracker_id, gaze, labels, exp_id, timestep, false_beliefs
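
# Note: collate_fn_test is identical to collate_fn except that it also passes
# through the per-sample false-belief annotation returned by the dataset in
# test mode; false_beliefs stays a tuple of strings rather than a tensor.
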
class TBDDataset(torch.utils.data.Dataset):

    def __init__(
        self,
        path: str = "/scratch/bortoletto/data/tbd",
        mode: str = "train",
        tbd_data_path: str = "/scratch/bortoletto/data/tbd/mind_lstm_training_cnn_att/",
        list_of_ids_to_consider: list = ALL_IDS,
        use_preprocessed_img: bool = True,
        resize_img: Optional[Union[tuple, int]] = (128, 128),
    ):
        """TBD Dataset based on the 88-clip version of the TBD data.

        Expects the following folder structure:
        - path
            - tracker_gt_smooth <- eye tracking from the POV, 2D coordinates
            - images/*/         <- the images, by experiment id
                - battery       <- first-person images
                - tracker       <- first-person images of the other person, with eye fixation
                - kinect        <- third-person images
            - skeleton          <- pose estimation, 3D coordinates
            - annotation        <- the labels, i.e. values in [0, 3] (see below)

        Labels are structured as follows:
        {
            "O1": [ <- object with id O1
                {
                    "m1": {
                        "fluent": 3, <- 0: enter, 1: disappear, 2: update, 3: unchanged
                        "loc": null
                    },
                    "m2": {
                        "fluent": 3,
                        "loc": null
                    },
                    "m12": {
                        "fluent": 3,
                        "loc": null
                    },
                    "m21": {
                        "fluent": 3,
                        "loc": null
                    },
                    "mc": {
                        "fluent": 3,
                        "loc": null
                    },
                    "mg": {
                        "fluent": 3,
                        "loc": [
                            22,
                            9
                        ]
                    }
                }, ...
            ], ...
        }

        This corresponds to a strict subset of the raw dataset collected by the
        TBD authors for their paper "Learning Triadic Belief Dynamics in
        Nonverbal Communication from Videos" (CVPR 2021, Oral).

        We keep small amounts of data in memory (everything < 100 MB) and read
        the rest from disk on the fly. This dataset applies normalization.

        Args:
            path (str, optional): Where the folders lie.
                Defaults to "/scratch/bortoletto/data/tbd".
            mode (str, optional): One of "train", "val" or "test".
                Defaults to "train".
            tbd_data_path (str, optional): Where the pickled mind-LSTM clip
                data lie.
            list_of_ids_to_consider (list, optional): List of ids to consider.
                Defaults to ALL_IDS. Otherwise specify a list,
                e.g. ["test_94342_23", "test_boelter_21", ...].
            use_preprocessed_img (bool, optional): Load pre-normalized .pt
                tensors instead of raw .jpg images. Defaults to True.
            resize_img (Optional[Union[tuple, int]], optional): Resize images
                to this size if required. Defaults to (128, 128).
        """
print(f"Loading TBD Dataset in mode {mode}...")
self.mode = mode
start = time.time()
self.skeleton_3D_path = f"{path}/skeleton"
self.tracker_2D_path = f"{path}/tracker_gt_smooth"
self.bbox_csv_path = f"{path}/annotations_with_bbox.csv"
if use_preprocessed_img:
self.img_path = f"{path}/images_norm"
else:
self.img_path = f"{path}/images"
self.obj_ids_path = f"{path}/mind_lstm_training_cnn_att_shu.pkl"
self.label_map = list(product([0, 1, 2, 3], repeat=5))
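        # label_map enumerates all 4^5 = 1024 fluent combinations for the five
        # minds (m1, m2, m12, m21, mc), so an integer class label decodes to a
        # 5-tuple via self.label_map[label], e.g.
        #   self.label_map[0]    == (0, 0, 0, 0, 0)
        #   self.label_map[1023] == (3, 3, 3, 3, 3)
        # (1024 also matches the size of self.mind_count below.)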
        clips = os.listdir(tbd_data_path)
        data = []
        labels = []
        for clip in clips:
            with open(tbd_data_path + clip, 'rb') as f:
                vec_input, label_ = pickle.load(f, encoding='latin1')
            data = data + vec_input
            labels = labels + label_
        # NOTE: the split depends on the RNG state, so seed before constructing
        # the dataset if you need reproducible train/val/test splits.
        c = list(zip(data, labels))
        random.shuffle(c)
        train_ratio = int(len(c) * 0.6)
        validate_ratio = int(len(c) * 0.2)
        data, label = zip(*c)
        train_x, train_y = data[:train_ratio], label[:train_ratio]
        validate_x, validate_y = data[train_ratio:train_ratio + validate_ratio], label[train_ratio:train_ratio + validate_ratio]
        test_x, test_y = data[train_ratio + validate_ratio:], label[train_ratio + validate_ratio:]
        self.mind_count = np.zeros(1024)  # used for CE weights
        if mode == "train":
            self.data, self.labels = train_x, train_y
        elif mode == "val":
            self.data, self.labels = validate_x, validate_y
        elif mode == "test":
            self.data, self.labels = test_x, test_y
            self.false_beliefs_path = f"{path}/store_mind_set"
        # keep small amounts of data in memory
        self.skeleton_3D = self.load_skeleton_3D(self.skeleton_3D_path, list_of_ids_to_consider)
        self.tracker_2D = self.load_tracker_2D(self.tracker_2D_path, list_of_ids_to_consider)
        self.bbox_df = pd.read_csv(self.bbox_csv_path, header=0)
        self.obj_ids = self.load_obj_ids(self.obj_ids_path)
        if not use_preprocessed_img:
            normalisation_steps = [
                T.ToTensor(),
                T.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                )
            ]
            if resize_img is not None:
                normalisation_steps.insert(1, T.Resize(resize_img))
            self.preprocess_img = T.Compose(normalisation_steps)
        else:
            self.preprocess_img = None
        self.use_preprocessed_img = use_preprocessed_img
        print(f"Done loading in {time.time() - start}s.")
    def __len__(self):
        return len(self.data)
    def __getitem__(
        self, idx: int
    ) -> tuple[
        torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, str, torch.Tensor, dict, str, int, str
    ]:
        """Given an index, look up the corresponding experiment id and timestep
        in the experiment, then pick the appropriate data and labels from them.

        Args:
            idx (int): Index into the dataset.

        Returns:
            tuple with the following entries:
            - img_kinect: torch.Tensor of shape (T, C, H, W) (default is [T, 3, 720, 1280])
            - img_tracker: torch.Tensor of shape (T, H, W, C)
            - img_battery: torch.Tensor of shape (T, H, W, C)
            - skeleton_3D: torch.Tensor of shape (T, 26, 3) (skeleton 1)
            - skeleton_3D: torch.Tensor of shape (T, 26, 3) (skeleton 2)
            - bbox: torch.Tensor of shape (T, num_obj, 4)
            - tracker to skeleton ID: str (either skeleton 1 or 2)
            - tracker_2D: torch.Tensor of shape (T, 2)
            - labels: tuple of five fluents ("m1", "m2", "m12", "m21", "mc")
            - experiment_id: str
            - frame_ids: list of int
            - (test mode only) false_belief: str
        """
        labels = self.label_map[self.labels[idx]]
        experiment_id = self.data[idx][1][0].split('/')[6]
        img_data_path = f"{self.img_path}/{experiment_id}"
        frame_ids = [int(os.path.basename(self.data[idx][1][i]).split('_')[0]) for i in range(len(self.data[idx][1]))]
        if self.use_preprocessed_img:
            kinect = sorted(glob.glob(f"{img_data_path}/kinect/*.pt"))
            tracker = sorted(glob.glob(f"{img_data_path}/tracker/*.pt"))
            battery = sorted(glob.glob(f"{img_data_path}/battery/*.pt"))
        else:
            kinect = sorted(glob.glob(f"{img_data_path}/kinect/*.jpg"))
            tracker = sorted(glob.glob(f"{img_data_path}/tracker/*.jpg"))
            battery = sorted(glob.glob(f"{img_data_path}/battery/*.jpg"))
        kinect_img_paths = [kinect[id] for id in frame_ids]
        tracker_img_paths = [tracker[id] for id in frame_ids]
        battery_img_paths = [battery[id] for id in frame_ids]
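        # NOTE (assumption): indexing the sorted file lists with frame_ids
        # relies on frames being numbered contiguously from 0 with fixed-width,
        # zero-padded filenames, so that the i-th sorted file is frame i.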
        # load images
        kinect_imgs = [
            torch.tensor(torch.load(img_path)) if self.use_preprocessed_img else self.preprocess_img(cv2.imread(img_path))
            for img_path in kinect_img_paths
        ]
        kinect_imgs = torch.stack(kinect_imgs, dim=0)
        tracker_imgs = [
            torch.tensor(torch.load(img_path)) if self.use_preprocessed_img else self.preprocess_img(cv2.imread(img_path))
            for img_path in tracker_img_paths
        ]
        tracker_imgs = torch.stack(tracker_imgs, dim=0)
        battery_imgs = [
            torch.tensor(torch.load(img_path)) if self.use_preprocessed_img else self.preprocess_img(cv2.imread(img_path))
            for img_path in battery_img_paths
        ]
        battery_imgs = torch.stack(battery_imgs, dim=0)
        # load object id to check for false beliefs - only for testing
        if self.mode == "test":  # or self.mode == "train":
            if f"{experiment_id}.txt" in os.listdir(self.false_beliefs_path):
                obj_id = self.obj_ids[experiment_id][frame_ids[-1]]
                obj_id = next(x for x in obj_id if x is not None)
                false_belief = next(
                    (
                        line.strip().split(',')[2]
                        for line in open(f"{self.false_beliefs_path}/{experiment_id}.txt")
                        if line.startswith(str(frame_ids[-1]) + ',' + obj_id + ',')
                    ),
                    "no",
                )
                #if experiment_id in ['test_boelter4_0', 'test_boelter4_7', 'test_boelter4_6', 'test_boelter4_8', 'test_boelter2_3',
                #        'test_94342_20', 'test_94342_18', 'test_94342_11', 'test_94342_17', 'test_boelter3_8', 'test_94342_2',
                #        'test_boelter2_17', 'test_boelter3_7', 'test_94342_4', 'test_boelter3_9', 'test_boelter_10',
                #        'test_boelter2_6', 'test_boelter4_10', 'test_boelter4_2', 'test_boelter4_5', 'test_94342_24',
                #        'test_94342_15', 'test_boelter3_5', 'test_94342_8', 'test2', 'test_boelter3_12']:
                #    print('here!')
                #    with open(os.path.join(f'results/hgm_test_fb.csv'), mode='a') as file:
                #        writer = csv.writer(file)
                #        writer.writerow([experiment_id, obj_id, str(frame_ids[-1]), false_belief, labels[0], labels[1], labels[2], labels[3], labels[4]])
            else:
                false_belief = "no"
                #with open(os.path.join(f'results/test_fb.csv'), mode='a') as file:
                #    writer = csv.writer(file)
                #    writer.writerow([experiment_id, str(frame_ids[-1]), false_belief, labels[0], labels[1], labels[2], labels[3], labels[4]])
        df = self.bbox_df[
            (self.bbox_df.experiment_name == experiment_id)
            #& (self.bbox_df.name == obj_id) # NOTE: load the bounding boxes for all objects
            & (self.bbox_df.name != 'P1')
            & (self.bbox_df.name != 'P2')
            & (self.bbox_df.frame.isin(frame_ids))
        ]
        bboxes = []
        for f in frame_ids:
            bbox = torch.tensor(df.loc[df['frame'] == f, ["x_min", "y_min", "x_max", "y_max"]].to_numpy(), dtype=torch.float32)
            # normalise pixel coordinates to [0, 1] by the 1280x720 kinect frame size
            bbox[:, 0] = bbox[:, 0] / 1280.0
            bbox[:, 1] = bbox[:, 1] / 720.0
            bbox[:, 2] = bbox[:, 2] / 1280.0
            bbox[:, 3] = bbox[:, 3] / 720.0
            bboxes.append(bbox)
        bboxes = torch.stack(bboxes)  # NOTE: this needs a collate function because not every video has the same number of objects
        skele1 = self.skeleton_3D[experiment_id]["skele1"][frame_ids]
        skele2 = self.skeleton_3D[experiment_id]["skele2"][frame_ids]
        gaze = self.tracker_2D[experiment_id][frame_ids]
        if self.mode == "test":
            return (
                kinect_imgs,
                tracker_imgs,
                battery_imgs,
                skele1,
                skele2,
                bboxes,
                tracker_skeID[experiment_id],  # <- the tracker skeleton ID
                gaze,
                labels,  # <- per object "m1", "m2", "m12", "m21", "mc"
                experiment_id,
                frame_ids,
                #self.onehot(int(obj_id[1:])),  # <- the object ID as a one-hot encoding
                false_belief,
            )
        else:
            return (
                kinect_imgs,
                tracker_imgs,
                battery_imgs,
                skele1,
                skele2,
                bboxes,
                tracker_skeID[experiment_id],  # <- the tracker skeleton ID
                gaze,
                labels,  # <- per object "m1", "m2", "m12", "m21", "mc"
                experiment_id,
                frame_ids,
                #self.onehot(int(obj_id[1:])),  # <- the object ID as a one-hot encoding
            )
    def onehot(self, x, n=len(UNIQUE_OBJ_IDS)):
        retval = torch.zeros(n)
        if x > 0:
            retval[x - 1] = 1
        return retval
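
    # Example (illustrative): with n = 4, onehot(3) -> tensor([0., 0., 1., 0.])
    # since object ids are 1-based, and onehot(0) -> tensor([0., 0., 0., 0.]),
    # i.e. the all-zero vector encodes "no object".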
    def load_obj_ids(self, path: str):
        with open(path, "rb") as f:
            ids = pickle.load(f)
        return ids

    def extract_labels(self):
        """TODO: Convert an index label to [m1, m2, m12, m21, mc] format."""
        return
    def _flatten_mind_obj_timestep(self, mind_obj_dict: dict) -> np.ndarray:
        """Flattens the mind object dict to an array, i.e. takes
        {
            "m1": {
                "fluent": 3, <- 0: enter, 1: disappear, 2: update, 3: unchanged
                "loc": null
            },
            "m2": {
                "fluent": 3,
                "loc": null
            },
            "m12": {
                "fluent": 3,
                "loc": null
            },
            "m21": {
                "fluent": 3,
                "loc": null
            },
            "mc": {
                "fluent": 3,
                "loc": null
            },
            "mg": {
                "fluent": 3,
                "loc": [
                    22,
                    9
                ]
            }
        }
        and returns [3, 3, 3, 3, 3] ("mg" is excluded).

        Args:
            mind_obj_dict (dict): Mind object dict as described in the __init__ docstring.

        Returns:
            np.ndarray: Array of the five mind fluent labels.
        """
        return np.array([mind_obj["fluent"] for key, mind_obj in mind_obj_dict.items() if key != "mg"])
    def load_skeleton_3D(self, path: str, list_of_ids_to_consider: list):
        """Load skeleton 3D data from disk.
        - path
            - * <- list of ids
                - skele1.p <- 3D coords per id and timestep
                - skele2.p <-

        Args:
            path (str): Where the skeleton 3D data lie.
            list_of_ids_to_consider (list): List of ids to consider.
                Defaults to None, which means all ids. Otherwise specify a list,
                e.g. ["test_94342_23", "test_boelter_21", ...].

        Returns:
            dict: skeleton 3D data as described in the __init__ docstring.
        """
        skeleton_3D = {}
        for experiment_id in list_of_ids_to_consider:
            skeleton_3D[experiment_id] = {}
            with open(f"{path}/{experiment_id}/skele1.p", "rb") as f:
                skeleton_3D[experiment_id]["skele1"] = torch.tensor(np.array(pickle.load(f, encoding="latin1")), dtype=torch.float32)
            with open(f"{path}/{experiment_id}/skele2.p", "rb") as f:
                skeleton_3D[experiment_id]["skele2"] = torch.tensor(np.array(pickle.load(f, encoding="latin1")), dtype=torch.float32)
        return skeleton_3D
    def load_tracker_2D(self, path: str, list_of_ids_to_consider: list):
        """Load tracker 2D data from disk.
        - path
            - *.p <- 2D coords per id and timestep

        Args:
            path (str): Where the tracker 2D data lie.
            list_of_ids_to_consider (list): List of ids to consider.
                Defaults to None, which means all ids. Otherwise specify a list,
                e.g. ["test_94342_23", "test_boelter_21", ...].

        Returns:
            dict: tracker 2D data.
        """
        tracker_2D = {}
        for experiment_id in list_of_ids_to_consider:
            with open(f"{path}/{experiment_id}.p", "rb") as f:
                tracker_2D[experiment_id] = torch.tensor(np.array(pickle.load(f, encoding="latin1")), dtype=torch.float32)
        return tracker_2D
    def load_bbox(self, path: str, list_of_ids_to_consider: list):
        """Load bbox data from disk.
        - bbox_tensors.pickle <- one bbox tensor per experiment id

        Args:
            path (str): Where the bbox data lie.
            list_of_ids_to_consider (list): List of ids to consider.

        Returns:
            dict: bbox data.
        """
        with open(path, "rb") as f:
            pickle_data = pickle.load(f)
        for key in CLIPS_IDS_88:
            if key not in list_of_ids_to_consider:
                pickle_data.pop(key, None)
        return pickle_data

if __name__ == '__main__':
    # os.environ['PYTHONHASHSEED'] = str(42)
    # torch.manual_seed(42)
    # np.random.seed(42)
    # random.seed(42)
    from tqdm import tqdm
    from torch.utils.data import DataLoader

    data = TBDDataset(use_preprocessed_img=True, mode="test")
    for i in tqdm(range(len(data))):
        data[i]
    breakpoint()
    # Just for estimating access time
    data_0 = data[0]
    data_last = data[len(data) - 1]
    idx = np.random.randint(1, len(data) - 1)  # something in between
    start = time.time()
    (
        kinect_imgs,  # <- T x 720 x 1280 x 3 originally, likely smaller now
        tracker_imgs,
        battery_imgs,
        skele1,
        skele2,
        bbox,
        tracker_skeID_sample,  # <- the tracker skeleton ID
        tracker2d,
        label,
        experiment_id,  # from here on for debugging
        timestep,
        #obj_id,  # <- the object ID as a one-hot encoding
        false_belief,
    ) = data[idx]
    end = time.time()
    print(f"Time for one sample: {end - start}")
    print('kinect:', kinect_imgs.shape)
    print('tracker:', tracker_imgs.shape)
    print('battery:', battery_imgs.shape)
    print('skele1:', skele1.shape)
    print('skele2:', skele2.shape)
    print('gaze:', tracker2d.shape)
    print('bbox:', bbox.shape)
    print('label:', label)
    #breakpoint()
    # The dataset is in test mode here, so each sample carries the extra
    # false-belief field and needs collate_fn_test; use collate_fn for
    # train/val mode instead.
    dl = DataLoader(
        data,
        batch_size=4,
        shuffle=False,
        collate_fn=collate_fn_test,
    )
    for j, batch in tqdm(enumerate(dl)):
        #print(j, end='\r')
        img_3rd_pov, img_tracker, img_battery, pose1, pose2, bbox, tracker_id, gaze, labels, exp_id, timestep, false_beliefs = batch
        #breakpoint()
        #print(img_3rd_pov.shape)
        #print(img_tracker.shape)
        #print(img_battery.shape)
        #print(pose1.shape, pose2.shape)
        #print(bbox.shape)
        #print(gaze.shape)
    breakpoint()