# ActionDiffusion_WACV2025/dataloader/data_load_mlp.py
import json
import math
import os
from collections import namedtuple

import numpy as np
import torch
from torch.utils.data import Dataset

Batch = namedtuple('Batch', 'Observations Actions Class')
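# A Batch holds one sample as returned by __getitem__:
#   Observations: concatenated start/end frame features for each step
#   Actions:      the per-step action labels
#   Class:        the task class id (train) or event class id (val)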


def get_vids_from_json(path):
    """Read a json split file and group video ids by task."""
    task_vids = {}
    with open(path, 'r') as f:
        json_data = json.load(f)
    for entry in json_data:
        task = entry['task']
        vid = entry['vid']
        if task not in task_vids:
            task_vids[task] = []
        task_vids[task].append(vid)
    return task_vids
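# The split json is expected to be a list of records carrying at least the
# 'task' and 'vid' keys read above, e.g.
#   [{"task": "23521", "vid": "<youtube_id>", ...}, ...]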


def get_vids(path):
    """Read a csv split file ("task,vid,url" per line) and group video ids by task."""
    task_vids = {}
    with open(path, 'r') as f:
        for line in f:
            task, vid, url = line.strip().split(',')
            if task not in task_vids:
                task_vids[task] = []
            task_vids[task].append(vid)
    return task_vids


def read_task_info(path):
    """Parse tasks_primary.txt, which stores one task per six-line record."""
    titles = {}
    urls = {}
    n_steps = {}
    steps = {}
    with open(path, 'r') as f:
        idx = f.readline()
        while idx != '':
            idx = idx.strip()
            titles[idx] = f.readline().strip()
            urls[idx] = f.readline().strip()
            n_steps[idx] = int(f.readline().strip())
            steps[idx] = f.readline().strip().split(',')
            f.readline()  # skip the blank line separating records
            idx = f.readline()
    return {'title': titles, 'url': urls, 'n_steps': n_steps, 'steps': steps}
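# tasks_primary.txt record layout (one task per block, matching the reads above):
#   <task id>
#   <title>
#   <url>
#   <number of steps>
#   <comma-separated step names>
#   <blank line>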


class PlanningDataset(Dataset):
    """Sliding-window dataset of (observation, action, class) triples for
    procedure planning on CrossTask, COIN, or NIV features."""

    def __init__(self,
                 root,
                 args=None,
                 is_val=False,
                 model=None,
                 crosstask_use_feature_how=True,
                 ):
        self.is_val = is_val
        self.data_root = root
        self.args = args
        self.max_traj_len = args.horizon
        self.vid_names = None
        self.frame_cnts = None
        self.images = None
        self.last_vid = ''  # cache key so one video's features are not reloaded repeatedly
        self.crosstask_use_feature_how = crosstask_use_feature_how
        if args.dataset == 'crosstask':
            """
            Expected layout under <root>/dataset:
            .
            └── crosstask
                ├── crosstask_features
                └── crosstask_release
                    ├── tasks_primary.txt
                    ├── videos.csv or json
                    └── videos_val.csv or json
            """
            val_csv_path = os.path.join(
                root, 'dataset', 'crosstask', 'crosstask_release', 'test_list.json')  # 'videos_val.csv'
            video_csv_path = os.path.join(
                root, 'dataset', 'crosstask', 'crosstask_release', 'train_list.json')  # 'videos.csv'
            if crosstask_use_feature_how:
                self.features_path = os.path.join(root, 'dataset', 'crosstask', 'processed_data')
            else:
                self.features_path = os.path.join(root, 'dataset', 'crosstask', 'crosstask_features')
            self.constraints_path = os.path.join(
                root, 'dataset', 'crosstask', 'crosstask_release', 'annotations')
            self.action_one_hot = np.load(
                os.path.join(root, 'dataset', 'crosstask', 'crosstask_release', 'actions_one_hot.npy'),
                allow_pickle=True).item()
            # The 18 primary CrossTask tasks, mapped to contiguous class indices.
            self.task_class = {
                '23521': 0, '59684': 1, '71781': 2, '113766': 3, '105222': 4,
                '94276': 5, '53193': 6, '105253': 7, '44047': 8, '76400': 9,
                '16815': 10, '95603': 11, '109972': 12, '44789': 13, '40567': 14,
                '77721': 15, '87706': 16, '91515': 17,
            }
# cross_task_data_name = "/data1/wanghanlin/diffusion_planning/jsons_crosstask105/sliding_window_cross_task_data_{}_{}_new_task_id_73.json".format(
# is_val, self.max_traj_len)
if is_val:
cross_task_data_name = args.json_path_val
else:
cross_task_data_name = args.json_path_train
if os.path.exists(cross_task_data_name):
with open(cross_task_data_name, 'r') as f:
self.json_data = json.load(f)
print('Loaded {}'.format(cross_task_data_name))
else:
file_type = val_csv_path.split('.')[-1]
if file_type == 'json':
all_task_vids = get_vids_from_json(video_csv_path)
val_vids = get_vids_from_json(val_csv_path)
else:
all_task_vids = get_vids(video_csv_path)
val_vids = get_vids(val_csv_path)
if is_val:
task_vids = val_vids
else:
task_vids = {task: [vid for vid in vids if task not in val_vids or vid not in val_vids[task]] for
task, vids in
all_task_vids.items()}
primary_info = read_task_info(os.path.join(
root, 'dataset', 'crosstask', 'crosstask_release', 'tasks_primary.txt'))
self.n_steps = primary_info['n_steps']
all_tasks = set(self.n_steps.keys())
task_vids = {task: vids for task,
vids in task_vids.items() if task in all_tasks}
all_vids = []
for task, vids in task_vids.items():
all_vids.extend([(task, vid) for vid in vids])
json_data = []
for idx in range(len(all_vids)):
task, vid = all_vids[idx]
if self.crosstask_use_feature_how:
video_path = os.path.join(
self.features_path, str(task) + '_' + str(vid) + '.npy')
else:
video_path = os.path.join(
self.features_path, str(vid) + '.npy')
legal_range = self.process_single(task, vid)
if not legal_range:
continue
temp_len = len(legal_range)
temp = []
while temp_len < self.max_traj_len:
temp.append(legal_range[0])
temp_len += 1
temp.extend(legal_range)
legal_range = temp
for i in range(len(legal_range) - self.max_traj_len + 1):
legal_range_current = legal_range[i:i + self.max_traj_len]
json_data.append({'id': {'vid': vid, 'task': task, 'feature': video_path,
'legal_range': legal_range_current, 'task_id': self.task_class[task]},
'instruction_len': self.n_steps[task]})
self.json_data = json_data
with open(cross_task_data_name, 'w') as f:
json.dump(json_data, f)
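            # Worked example of the padding and sliding-window loops above, with
            # max_traj_len = 3: steps [a, b, c, d] need no padding and yield the
            # windows [a, b, c] and [b, c, d]; a two-step video [a, b] is first
            # padded to [a, a, b] and yields the single window [a, a, b].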
        elif args.dataset == 'coin':
            coin_path = os.path.join(root, 'dataset/coin', 'full_npy/')
            val_csv_path = os.path.join(root, 'dataset/coin', 'coin_test_30.json')
            video_csv_path = os.path.join(root, 'dataset/coin', 'coin_train_70.json')
            coin_data_name = args.json_path_val if is_val else args.json_path_train
            if os.path.exists(coin_data_name):
                with open(coin_data_name, 'r') as f:
                    self.json_data = json.load(f)
                print('Loaded {}'.format(coin_data_name))
            else:
                json_data = []
                num_skipped = 0
                split_path = val_csv_path if is_val else video_csv_path
                with open(split_path, 'r') as f:
                    coin_data = json.load(f)
                for entry in coin_data:
                    for (k, v) in entry.items():
                        file_name = v['class'] + '_' + str(v['recipe_type']) + '_' + k + '.npy'
                        file_path = coin_path + file_name
                        images_ = np.load(file_path, allow_pickle=True)
                        images = images_['frames_features']
                        legal_range = []
                        # Skip videos whose last annotated segment ends past the available features.
                        last_action = math.ceil(v['annotation'][-1]['segment'][1])
                        if last_action > len(images):
                            print(k, last_action, len(images))
                            num_skipped += 1
                            continue
                        for annotation in v['annotation']:
                            action_label = int(annotation['id']) - 1
                            start_idx, end_idx = annotation['segment']
                            start_idx = math.floor(start_idx)
                            end_idx = math.ceil(end_idx)
                            # Clip segment ends to the last available frame.
                            if end_idx < images.shape[0]:
                                legal_range.append((start_idx, end_idx, action_label))
                            else:
                                legal_range.append((start_idx, images.shape[0] - 1, action_label))
                        # Pad and slide, as for CrossTask above.
                        temp_len = len(legal_range)
                        temp = []
                        while temp_len < self.max_traj_len:
                            temp.append(legal_range[0])
                            temp_len += 1
                        temp.extend(legal_range)
                        legal_range = temp
                        for i in range(len(legal_range) - self.max_traj_len + 1):
                            legal_range_current = legal_range[i:i + self.max_traj_len]
                            json_data.append({'id': {'vid': k, 'feature': file_path,
                                                     'legal_range': legal_range_current,
                                                     'task_id': v['recipe_type']},
                                              'instruction_len': 0})
                print('Skipped {} videos whose annotations end past the features'.format(num_skipped))
                self.json_data = json_data
                with open(coin_data_name, 'w') as f:
                    json.dump(json_data, f)
        elif args.dataset == 'NIV':
            val_csv_path = os.path.join(root, 'dataset/NIV', 'test30_new.json')
            video_csv_path = os.path.join(root, 'dataset/NIV', 'train70_new.json')
            niv_data_name = args.json_path_val if is_val else args.json_path_train
            if os.path.exists(niv_data_name):
                with open(niv_data_name, 'r') as f:
                    self.json_data = json.load(f)
                print('Loaded {}'.format(niv_data_name))
            else:
                json_data = []
                split_path = val_csv_path if is_val else video_csv_path
                with open(split_path, 'r') as f:
                    niv_data = json.load(f)
                for d in niv_data:
                    legal_range = []
                    path = os.path.join(root, 'dataset/NIV', 'processed_data', d['feature'])
                    info = np.load(path, allow_pickle=True)
                    num_steps = int(info['num_steps'])
                    assert num_steps == len(info['steps_ids'])
                    assert num_steps == len(info['steps_starts'])
                    assert num_steps == len(info['steps_ends'])
                    starts = info['steps_starts']
                    ends = info['steps_ends']
                    action_labels = info['steps_ids']
                    images = info['frames_features']
                    for i in range(num_steps):
                        action_label = int(action_labels[i])
                        start_idx = math.floor(float(starts[i]))
                        end_idx = math.ceil(float(ends[i]))
                        # Clip segment ends to the last available frame.
                        if end_idx < images.shape[0]:
                            legal_range.append((start_idx, end_idx, action_label))
                        else:
                            legal_range.append((start_idx, images.shape[0] - 1, action_label))
                    # Pad and slide, as for CrossTask above.
                    temp_len = len(legal_range)
                    temp = []
                    while temp_len < self.max_traj_len:
                        temp.append(legal_range[0])
                        temp_len += 1
                    temp.extend(legal_range)
                    legal_range = temp
                    for i in range(len(legal_range) - self.max_traj_len + 1):
                        legal_range_current = legal_range[i:i + self.max_traj_len]
                        json_data.append({'id': {'feature': path,
                                                 'legal_range': legal_range_current,
                                                 'task_id': d['task_id']},
                                          'instruction_len': 0})
                self.json_data = json_data
                with open(niv_data_name, 'w') as f:
                    json.dump(json_data, f)
                print('Built {} sliding-window samples'.format(len(json_data)))
        else:
            raise NotImplementedError('Dataset {} is not implemented'.format(args.dataset))
        self.model = model
        self.prepare_data()
        self.M = 3  # number of consecutive feature frames concatenated per observation

    def process_single(self, task, vid):
        """Return the clipped (start, end, action) ranges for one CrossTask video,
        or False if its features are missing or an annotation starts out of range."""
        if self.crosstask_use_feature_how:
            feature_file = os.path.join(self.features_path, str(task) + '_' + str(vid) + '.npy')
            if not os.path.exists(feature_file):
                return False
            images_ = np.load(feature_file, allow_pickle=True)
            images = images_['frames_features']
        else:
            feature_file = os.path.join(self.features_path, vid + '.npy')
            if not os.path.exists(feature_file):
                return False
            images = np.load(feature_file)
        cnst_path = os.path.join(self.constraints_path, task + '_' + vid + '.csv')
        legal_range = self.read_assignment(task, cnst_path)
        legal_range_ret = []
        for (start_idx, end_idx, action_label) in legal_range:
            if not start_idx < images.shape[0]:
                print(task, vid, start_idx, images.shape[0])
                return False
            # Clip segment ends to the last available frame.
            if end_idx < images.shape[0]:
                legal_range_ret.append((start_idx, end_idx, action_label))
            else:
                legal_range_ret.append((start_idx, images.shape[0] - 1, action_label))
        return legal_range_ret

    def read_assignment(self, task_id, path):
        """Read one CrossTask annotation csv into (start, end, action_label) tuples."""
        legal_range = []
        with open(path, 'r') as f:
            for line in f:
                step, start, end = line.strip().split(',')
                start = int(math.floor(float(start)))
                end = int(math.ceil(float(end)))
                action_label_ind = self.action_one_hot[task_id + '_' + step]
                legal_range.append((start, end, action_label_ind))
        return legal_range
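    # Annotation csv layout (one step per line, matching the parse above):
    #   <step number>,<start time>,<end time>
    # The label index is looked up from action_one_hot via "<task id>_<step number>".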

    def prepare_data(self):
        vid_names = []
        frame_cnts = []
        for listdata in self.json_data:
            vid_names.append(listdata['id'])
            frame_cnts.append(listdata['instruction_len'])
        self.vid_names = vid_names
        self.frame_cnts = frame_cnts

    def curate_dataset(self, images, legal_range, M=2):
        """For every step, append two observations: M concatenated frames around
        the step's start and M concatenated frames around its end."""
        images_list = []
        labels_onehot_list = []
        idx_list = []
        for start_idx, end_idx, action_label in legal_range:
            idx = start_idx
            idx_list.append(idx)
            # Start observation: frames [start, start + M), shifted back if the video is too short.
            image_start_idx = max(0, idx)
            if image_start_idx + M <= len(images):
                image_start = images[image_start_idx: image_start_idx + M]
            else:
                image_start = images[len(images) - M: len(images)]
            images_list.append(np.concatenate(image_start, axis=0))
            labels_onehot_list.append(action_label)
            # End observation: frames [end - 2, end + M - 2), so the segment end sits inside the window.
            end_idx = max(2, end_idx)
            image_end = images[end_idx - 2:end_idx + M - 2]
            images_list.append(np.concatenate(image_end, axis=0))
        return images_list, labels_onehot_list, idx_list
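    # e.g. with M = 3 and d-dimensional per-frame features, every observation is a
    # single (3 * d)-dimensional vector, and each legal_range of length max_traj_len
    # contributes 2 * max_traj_len observations (one start and one end per step).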

    def sample_single(self, index):
        folder_id = self.vid_names[index]
        class_id = folder_id['task_id']  # event class at val time, task class at train time
        if self.args.dataset == 'crosstask':
            # Reload features only when the video changes.
            if folder_id['vid'] != self.last_vid:
                if self.crosstask_use_feature_how:
                    images_ = np.load(folder_id['feature'], allow_pickle=True)
                    self.images = images_['frames_features']
                else:
                    self.images = np.load(os.path.join(self.features_path, folder_id['vid'] + '.npy'))
                self.last_vid = folder_id['vid']
        else:
            images_ = np.load(folder_id['feature'], allow_pickle=True)
            self.images = images_['frames_features']
        images, labels_matrix, idx_list = self.curate_dataset(
            self.images, folder_id['legal_range'], M=self.M)
        frames = torch.tensor(np.array(images))
        labels_tensor = torch.tensor(labels_matrix, dtype=torch.long)
        class_id = torch.tensor(class_id, dtype=torch.long)
        return frames, labels_tensor, class_id

    def __getitem__(self, index):
        frames, labels, class_id = self.sample_single(index)
        return Batch(frames, labels, class_id)

    def __len__(self):
        return len(self.json_data)
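

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original training code). The mock args
    # namespace below carries only the fields this loader actually reads; real runs
    # build args with argparse, and every path here is a placeholder assumption.
    from argparse import Namespace
    from torch.utils.data import DataLoader

    args = Namespace(
        dataset='crosstask',                    # or 'coin' / 'NIV'
        horizon=3,                              # max_traj_len: steps per sliding window
        json_path_train='train_windows.json',   # placeholder cache path
        json_path_val='val_windows.json',       # placeholder cache path
    )
    dataset = PlanningDataset(root='.', args=args, is_val=False)
    loader = DataLoader(dataset, batch_size=32, shuffle=True)
    for batch in loader:
        # PyTorch's default collation preserves the Batch namedtuple type.
        print(batch.Observations.shape, batch.Actions.shape, batch.Class.shape)
        break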