ActionDiffusion_WACV2025/dataloader/data_load.py
2024-12-02 15:42:58 +01:00

167 lines
6.8 KiB
Python

import os
import numpy as np
import torch
from torch.utils.data import Dataset
import json
from collections import namedtuple
# One sample: stacked frame features (Observations), the action-label
# sequence (Actions), and the task/event class id (Class).
Batch = namedtuple('Batch', ['Observations', 'Actions', 'Class'])
class PlanningDataset(Dataset):
"""
load video and action features from dataset
"""
def __init__(self,
root,
args=None,
is_val=False,
model=None,
):
self.is_val = is_val
self.data_root = root
self.args = args
self.max_traj_len = args.horizon
self.vid_names = None
self.frame_cnts = None
self.images = None
self.last_vid = ''
if args.dataset == 'crosstask':
if is_val:
cross_task_data_name = args.json_path_val
# "/data1/wanghanlin/diffusion_planning/jsons_crosstask105/sliding_window_cross_task_data_{}_{}_new_task_id_73_with_event_class.json".format(is_val, self.max_traj_len)
else:
cross_task_data_name = args.json_path_train
# "/data1/wanghanlin/diffusion_planning/jsons_crosstask105/sliding_window_cross_task_data_{}_{}_new_task_id_73.json".format(
# is_val, self.max_traj_len)
if os.path.exists(cross_task_data_name):
with open(cross_task_data_name, 'r') as f:
self.json_data = json.load(f)
print('Loaded {}'.format(cross_task_data_name))
else:
assert 0
elif args.dataset == 'coin':
if is_val:
coin_data_name = args.json_path_val
# "/data1/wanghanlin/diffusion_planning/jsons_coin/sliding_window_cross_task_data_{}_{}_new_task_id_73_with_event_class.json".format(
# is_val, self.max_traj_len)
else:
coin_data_name = args.json_path_train
# "/data1/wanghanlin/diffusion_planning/jsons_coin/sliding_window_cross_task_data_{}_{}_new_task_id_73.json".format(
# is_val, self.max_traj_len)
if os.path.exists(coin_data_name):
with open(coin_data_name, 'r') as f:
self.json_data = json.load(f)
print('Loaded {}'.format(coin_data_name))
else:
assert 0
elif args.dataset == 'NIV':
if is_val:
niv_data_name = args.json_path_val
# "/data1/wanghanlin/diffusion_planning/jsons_niv/sliding_window_cross_task_data_{}_{}_new_task_id_73_with_event_class.json".format(
# is_val, self.max_traj_len)
else:
niv_data_name = args.json_path_train
# "/data1/wanghanlin/diffusion_planning/jsons_niv/sliding_window_cross_task_data_{}_{}_new_task_id_73.json".format(
# is_val, self.max_traj_len)
if os.path.exists(niv_data_name):
with open(niv_data_name, 'r') as f:
self.json_data = json.load(f)
print('Loaded {}'.format(niv_data_name))
else:
assert 0
else:
raise NotImplementedError(
'Dataset {} is not implemented'.format(args.dataset))
self.model = model
self.prepare_data()
#self.M = args.horizon
self.M = 3
def prepare_data(self):
vid_names = []
frame_cnts = []
for listdata in self.json_data:
vid_names.append(listdata['id'])
frame_cnts.append(listdata['instruction_len'])
self.vid_names = vid_names
self.frame_cnts = frame_cnts
def curate_dataset(self, images, legal_range, M=2):
images_list = []
labels_onehot_list = []
idx_list = []
for start_idx, end_idx, action_label in legal_range:
idx = start_idx
idx_list.append(idx)
image_start_idx = max(0, idx)
if image_start_idx + M <= len(images):
image_start = images[image_start_idx: image_start_idx + M]
else:
image_start = images[len(images) - M: len(images)]
image_start_cat = image_start[0]
for w in range(len(image_start) - 1):
image_start_cat = np.concatenate((image_start_cat, image_start[w + 1]), axis=0)
images_list.append(image_start_cat)
labels_onehot_list.append(action_label)
end_idx = max(2, end_idx)
image_end = images[end_idx - 2:end_idx + M - 2]
image_end_cat = image_end[0]
for w in range(len(image_end) - 1):
image_end_cat = np.concatenate((image_end_cat, image_end[w + 1]), axis=0)
images_list.append(image_end_cat)
'''end_idx = max(M-1, end_idx)
image_end = images[end_idx - (M-1):end_idx + M - (M-1)]
image_end_cat = image_end[0]
for w in range(len(image_end) - 1):
image_end_cat = np.concatenate((image_end_cat, image_end[w + 1]), axis=0)
images_list.append(image_end_cat)'''
return images_list, labels_onehot_list, idx_list
def sample_single(self, index):
folder_id = self.vid_names[index]
if self.is_val:
event_class = folder_id['event_class'] # was event_class
else:
task_class = folder_id['task_id']
if self.args.dataset == 'crosstask':
if folder_id['vid'] != self.last_vid:
images_ = np.load(folder_id['feature'], allow_pickle=True)
self.images = images_['frames_features']
self.last_vid = folder_id['vid']
else:
images_ = np.load(folder_id['feature'], allow_pickle=True)
self.images = images_['frames_features']
images, labels_matrix, idx_list = self.curate_dataset(
self.images, folder_id['legal_range'], M=self.M)
frames = torch.tensor(np.array(images))
labels_tensor = torch.tensor(labels_matrix, dtype=torch.long)
if self.is_val:
event_class = torch.tensor(event_class, dtype=torch.long)
return frames, labels_tensor, event_class
else:
task_class = torch.tensor(task_class, dtype=torch.long)
return frames, labels_tensor, task_class
def __getitem__(self, index):
if self.is_val:
frames, labels, event_class = self.sample_single(index)
else:
frames, labels, task = self.sample_single(index)
if self.is_val:
batch = Batch(frames, labels, event_class)
else:
batch = Batch(frames, labels, task)
return batch
def __len__(self):
return len(self.json_data)