first commit
This commit is contained in:
commit
8f8cf48929
2819 changed files with 33143 additions and 0 deletions
167
dataloader/data_load.py
Normal file
167
dataloader/data_load.py
Normal file
|
@ -0,0 +1,167 @@
|
|||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
import json
|
||||
from collections import namedtuple
|
||||
|
||||
|
||||
# One training example: stacked frame-feature observations, the per-step
# action labels, and the video's task/event class id.
Batch = namedtuple('Batch', 'Observations Actions Class')
|
||||
|
||||
|
||||
class PlanningDataset(Dataset):
    """
    Load video and action features from dataset.

    Each entry of the pre-built JSON index describes one sliding window of
    ``args.horizon`` consecutive actions inside a video; ``sample_single``
    turns an entry into (frames, action labels, class id) tensors.
    """

    # The three supported datasets share one pre-processed JSON layout.
    SUPPORTED_DATASETS = ('crosstask', 'coin', 'NIV')

    def __init__(self,
                 root,
                 args=None,
                 is_val=False,
                 model=None,
                 ):
        """
        Args:
            root: dataset root directory (stored for reference).
            args: namespace providing ``dataset``, ``horizon``,
                ``json_path_train`` and ``json_path_val``.
            is_val: load the validation split instead of the train split.
            model: optional model handle, stored as-is.

        Raises:
            NotImplementedError: if ``args.dataset`` is not supported.
            FileNotFoundError: if the selected JSON index does not exist.
        """
        self.is_val = is_val
        self.data_root = root
        self.args = args
        self.max_traj_len = args.horizon
        self.vid_names = None
        self.frame_cnts = None
        self.images = None
        self.last_vid = ''  # cache key: last crosstask video whose features were loaded

        if args.dataset not in self.SUPPORTED_DATASETS:
            raise NotImplementedError(
                'Dataset {} is not implemented'.format(args.dataset))

        # The crosstask / coin / NIV branches were identical except for the
        # local variable name, so one code path replaces the triplication.
        json_path = args.json_path_val if is_val else args.json_path_train
        if not os.path.exists(json_path):
            # Was `assert 0`: a real exception survives `python -O` and
            # names the missing file.
            raise FileNotFoundError(
                'JSON index not found: {}'.format(json_path))
        with open(json_path, 'r') as f:
            self.json_data = json.load(f)
        print('Loaded {}'.format(json_path))

        self.model = model
        self.prepare_data()
        # Number of consecutive frames concatenated per observation.
        self.M = 3

    def prepare_data(self):
        """Split the JSON index into per-sample id records and lengths."""
        self.vid_names = [entry['id'] for entry in self.json_data]
        self.frame_cnts = [entry['instruction_len'] for entry in self.json_data]

    def curate_dataset(self, images, legal_range, M=2):
        """Build start/end observations for each annotated action.

        For every ``(start_idx, end_idx, action_label)`` in *legal_range*,
        concatenate ``M`` consecutive frame features at the action start and
        (anchored two frames back) at the action end.

        Args:
            images: per-frame feature array, indexable by frame.
            legal_range: iterable of (start_idx, end_idx, action_label).
            M: number of consecutive frames stitched per observation.

        Returns:
            (images_list, labels_onehot_list, idx_list) where
            ``images_list`` alternates start/end observations, one label is
            kept per action, and ``idx_list`` holds the raw start indices.
        """
        images_list = []
        labels_onehot_list = []
        idx_list = []
        for start_idx, end_idx, action_label in legal_range:
            idx_list.append(start_idx)
            # Clamp so the start window begins inside the video; if it
            # would run past the end, take the last M frames instead.
            image_start_idx = max(0, start_idx)
            if image_start_idx + M <= len(images):
                image_start = images[image_start_idx: image_start_idx + M]
            else:
                image_start = images[len(images) - M: len(images)]
            images_list.append(np.concatenate(list(image_start), axis=0))
            labels_onehot_list.append(action_label)

            # End window is anchored two frames before end_idx (so with
            # M=3 it spans one frame past the action end).
            end_idx = max(2, end_idx)
            image_end = images[end_idx - 2:end_idx + M - 2]
            images_list.append(np.concatenate(list(image_end), axis=0))
        return images_list, labels_onehot_list, idx_list

    def sample_single(self, index):
        """Load one JSON entry's features and build its sample tensors."""
        folder_id = self.vid_names[index]

        if self.is_val:
            event_class = folder_id['event_class']  # was event_class
        else:
            task_class = folder_id['task_id']

        if self.args.dataset == 'crosstask':
            # Consecutive windows usually come from the same video, so the
            # decoded features are cached under self.last_vid.
            if folder_id['vid'] != self.last_vid:
                images_ = np.load(folder_id['feature'], allow_pickle=True)
                self.images = images_['frames_features']
                self.last_vid = folder_id['vid']
        else:
            # coin / NIV entries are reloaded on every call (no caching),
            # mirroring the original behavior.
            images_ = np.load(folder_id['feature'], allow_pickle=True)
            self.images = images_['frames_features']
        images, labels_matrix, idx_list = self.curate_dataset(
            self.images, folder_id['legal_range'], M=self.M)
        frames = torch.tensor(np.array(images))
        labels_tensor = torch.tensor(labels_matrix, dtype=torch.long)

        if self.is_val:
            return frames, labels_tensor, torch.tensor(event_class, dtype=torch.long)
        return frames, labels_tensor, torch.tensor(task_class, dtype=torch.long)

    def __getitem__(self, index):
        """Return a ``Batch(Observations, Actions, Class)`` namedtuple.

        The third field is the event class (val) or task class (train);
        the redundant double ``is_val`` branching of the original is
        collapsed without changing behavior.
        """
        frames, labels, cls = self.sample_single(index)
        return Batch(frames, labels, cls)

    def __len__(self):
        # One sample per JSON index entry.
        return len(self.json_data)
|
456
dataloader/data_load_mlp.py
Normal file
456
dataloader/data_load_mlp.py
Normal file
|
@ -0,0 +1,456 @@
|
|||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
import json
|
||||
import math
|
||||
from collections import namedtuple
|
||||
|
||||
# One training example: stacked frame-feature observations, the per-step
# action labels, and the video's task class id.
Batch = namedtuple('Batch', 'Observations Actions Class')
|
||||
|
||||
|
||||
def get_vids_from_json(path):
    """Group video ids by task from a JSON list of {'task', 'vid'} records.

    Returns a dict mapping each task id to the list of its video ids, in
    file order.
    """
    with open(path, 'r') as f:
        entries = json.load(f)

    task_vids = {}
    for entry in entries:
        task_vids.setdefault(entry['task'], []).append(entry['vid'])
    return task_vids
|
||||
|
||||
|
||||
def get_vids(path):
    """Group video ids by task from a CSV whose rows are ``task,vid,url``.

    Returns a dict mapping each task id to the list of its video ids, in
    file order. The url column is ignored.
    """
    task_vids = {}
    with open(path, 'r') as f:
        for row in f:
            task, vid, _url = row.strip().split(',')
            task_vids.setdefault(task, []).append(vid)
    return task_vids
|
||||
|
||||
|
||||
def read_task_info(path):
    """Parse a ``tasks_primary.txt``-style file of five-line task records.

    Each record is: task id, title, url, step count, comma-separated step
    list — followed by one blank separator line.

    Args:
        path: path to the task-info text file.

    Returns:
        dict with keys 'title', 'url', 'n_steps', 'steps', each mapping
        task id -> the corresponding field.
    """
    titles = {}
    urls = {}
    n_steps = {}
    steps = {}
    with open(path, 'r') as f:
        idx = f.readline()
        while idx != '':
            idx = idx.strip()
            titles[idx] = f.readline().strip()
            urls[idx] = f.readline().strip()
            n_steps[idx] = int(f.readline().strip())
            steps[idx] = f.readline().strip().split(',')
            # readline() (not next(f)) so a file that ends without the
            # trailing separator line does not raise StopIteration.
            f.readline()
            idx = f.readline()
    return {'title': titles, 'url': urls, 'n_steps': n_steps, 'steps': steps}
|
||||
|
||||
|
||||
class PlanningDataset(Dataset):
    """Procedure-planning dataset that can build its own JSON index.

    When the cached sliding-window JSON index is absent, ``__init__``
    constructs it from the raw CrossTask / COIN / NIV annotation files and
    writes it back to disk; afterwards samples are served exactly like the
    lighter loader variant: (frames, action labels, class id) per window.
    """

    def __init__(self,
                 root,
                 args=None,
                 is_val=False,
                 model=None,
                 crosstask_use_feature_how=True,
                 ):
        # Args:
        #   root: dataset root containing `dataset/<name>/...`.
        #   args: namespace with `dataset`, `horizon`, `json_path_train`,
        #       `json_path_val`.
        #   is_val: select the validation split.
        #   model: optional model handle, stored untouched.
        #   crosstask_use_feature_how: use `processed_data/<task>_<vid>.npy`
        #       archives instead of raw `crosstask_features/<vid>.npy`.
        self.is_val = is_val
        self.data_root = root
        self.args = args
        self.max_traj_len = args.horizon  # sliding-window length (planning horizon)
        self.vid_names = None
        self.frame_cnts = None
        self.images = None
        self.last_vid = ''  # cache key: last crosstask video whose features were loaded
        self.crosstask_use_feature_how = crosstask_use_feature_how

        if args.dataset == 'crosstask':
            """
            .
            └── crosstask
                ├── crosstask_features
                └── crosstask_release
                    ├── tasks_primary.txt
                    ├── videos.csv or json
                    └── videos_val.csv or json
            """
            val_csv_path = os.path.join(
                root, 'dataset', 'crosstask', 'crosstask_release', 'test_list.json')  # 'videos_val.csv')
            video_csv_path = os.path.join(
                root, 'dataset', 'crosstask', 'crosstask_release', 'train_list.json')  # 'videos.csv')

            if crosstask_use_feature_how:
                self.features_path = os.path.join(root, 'dataset', 'crosstask', 'processed_data')
            else:
                self.features_path = os.path.join(root, 'dataset', 'crosstask', 'crosstask_features')

            self.constraints_path = os.path.join(
                root, 'dataset', 'crosstask', 'crosstask_release', 'annotations')

            # Pickled dict mapping '<task>_<step>' -> action label index.
            self.action_one_hot = np.load(
                os.path.join(root, 'dataset', 'crosstask', 'crosstask_release', 'actions_one_hot.npy'),
                allow_pickle=True).item()

            # Fixed task-id -> class-index mapping for the 18 primary tasks.
            self.task_class = {
                '23521': 0,
                '59684': 1,
                '71781': 2,
                '113766': 3,
                '105222': 4,
                '94276': 5,
                '53193': 6,
                '105253': 7,
                '44047': 8,
                '76400': 9,
                '16815': 10,
                '95603': 11,
                '109972': 12,
                '44789': 13,
                '40567': 14,
                '77721': 15,
                '87706': 16,
                '91515': 17
            }

            if is_val:
                cross_task_data_name = args.json_path_val
            else:
                cross_task_data_name = args.json_path_train

            if os.path.exists(cross_task_data_name):
                with open(cross_task_data_name, 'r') as f:
                    self.json_data = json.load(f)
                print('Loaded {}'.format(cross_task_data_name))
            else:
                # Cached index missing: build it from the raw release files.
                file_type = val_csv_path.split('.')[-1]
                if file_type == 'json':
                    all_task_vids = get_vids_from_json(video_csv_path)
                    val_vids = get_vids_from_json(val_csv_path)
                else:
                    all_task_vids = get_vids(video_csv_path)
                    val_vids = get_vids(val_csv_path)

                if is_val:
                    task_vids = val_vids
                else:
                    # Training split = all videos minus those in the val split.
                    task_vids = {task: [vid for vid in vids if task not in val_vids or vid not in val_vids[task]] for
                                 task, vids in
                                 all_task_vids.items()}

                primary_info = read_task_info(os.path.join(
                    root, 'dataset', 'crosstask', 'crosstask_release', 'tasks_primary.txt'))

                self.n_steps = primary_info['n_steps']
                all_tasks = set(self.n_steps.keys())

                # Keep only videos belonging to the primary tasks.
                task_vids = {task: vids for task,
                             vids in task_vids.items() if task in all_tasks}

                all_vids = []
                for task, vids in task_vids.items():
                    all_vids.extend([(task, vid) for vid in vids])
                json_data = []
                for idx in range(len(all_vids)):
                    task, vid = all_vids[idx]
                    if self.crosstask_use_feature_how:
                        video_path = os.path.join(
                            self.features_path, str(task) + '_' + str(vid) + '.npy')
                    else:
                        video_path = os.path.join(
                            self.features_path, str(vid) + '.npy')
                    legal_range = self.process_single(task, vid)
                    if not legal_range:
                        # Features missing or an annotation starts out of range.
                        continue

                    # Pad short videos by repeating the first action until at
                    # least one full window of max_traj_len actions exists.
                    temp_len = len(legal_range)
                    temp = []
                    while temp_len < self.max_traj_len:
                        temp.append(legal_range[0])
                        temp_len += 1

                    temp.extend(legal_range)
                    legal_range = temp

                    # Emit one JSON entry per sliding window of actions.
                    for i in range(len(legal_range) - self.max_traj_len + 1):
                        legal_range_current = legal_range[i:i + self.max_traj_len]
                        json_data.append({'id': {'vid': vid, 'task': task, 'feature': video_path,
                                                 'legal_range': legal_range_current, 'task_id': self.task_class[task]},
                                          'instruction_len': self.n_steps[task]})

                self.json_data = json_data
                # Cache the freshly-built index for subsequent runs.
                with open(cross_task_data_name, 'w') as f:
                    json.dump(json_data, f)

        elif args.dataset == 'coin':
            coin_path = os.path.join(root, 'dataset/coin', 'full_npy/')
            val_csv_path = os.path.join(
                root, 'dataset/coin', 'coin_test_30.json')
            video_csv_path = os.path.join(
                root, 'dataset/coin', 'coin_train_70.json')

            if is_val:
                coin_data_name = args.json_path_val
            else:
                coin_data_name = args.json_path_train

            if os.path.exists(coin_data_name):
                with open(coin_data_name, 'r') as f:
                    self.json_data = json.load(f)
                print('Loaded {}'.format(coin_data_name))
            else:
                json_data = []
                num = 0  # count of videos skipped (annotation past feature end)
                if is_val:
                    with open(val_csv_path, 'r') as f:
                        coin_data = json.load(f)
                else:
                    with open(video_csv_path, 'r') as f:
                        coin_data = json.load(f)
                for i in coin_data:
                    for (k, v) in i.items():
                        # Feature archive name: <class>_<recipe_type>_<vid>.npy
                        file_name = v['class'] + '_' + str(v['recipe_type']) + '_' + k + '.npy'
                        file_path = coin_path + file_name
                        images_ = np.load(file_path, allow_pickle=True)
                        images = images_['frames_features']
                        legal_range = []

                        # Skip videos whose last annotated segment extends past
                        # the extracted features.
                        last_action = v['annotation'][-1]['segment'][1]
                        last_action = math.ceil(last_action)
                        if last_action > len(images):
                            print(k, last_action, len(images))
                            num += 1
                            continue

                        for annotation in v['annotation']:
                            action_label = int(annotation['id']) - 1  # ids are 1-based
                            start_idx, end_idx = annotation['segment']
                            start_idx = math.floor(start_idx)
                            end_idx = math.ceil(end_idx)

                            # Clamp the segment end to the last frame.
                            if end_idx < images.shape[0]:
                                legal_range.append((start_idx, end_idx, action_label))
                            else:
                                legal_range.append((start_idx, images.shape[0] - 1, action_label))

                        # Pad short videos (same scheme as the crosstask branch).
                        temp_len = len(legal_range)
                        temp = []
                        while temp_len < self.max_traj_len:
                            temp.append(legal_range[0])
                            temp_len += 1

                        temp.extend(legal_range)
                        legal_range = temp

                        # NOTE(review): the loop variable `i` shadows the outer
                        # `for i in coin_data` — harmless here since the outer
                        # `i` is not used after this point.
                        for i in range(len(legal_range) - self.max_traj_len + 1):
                            legal_range_current = legal_range[i:i + self.max_traj_len]
                            json_data.append({'id': {'vid': k, 'feature': file_path,
                                                     'legal_range': legal_range_current, 'task_id': v['recipe_type']},
                                              'instruction_len': 0})
                print(num)
                self.json_data = json_data
                # Cache the freshly-built index for subsequent runs.
                with open(coin_data_name, 'w') as f:
                    json.dump(json_data, f)

        elif args.dataset == 'NIV':
            val_csv_path = os.path.join(
                root, 'dataset/NIV', 'test30_new.json')
            video_csv_path = os.path.join(
                root, 'dataset/NIV', 'train70_new.json')

            if is_val:
                niv_data_name = args.json_path_val
            else:
                niv_data_name = args.json_path_train

            if os.path.exists(niv_data_name):
                with open(niv_data_name, 'r') as f:
                    self.json_data = json.load(f)
                print('Loaded {}'.format(niv_data_name))
            else:
                json_data = []
                if is_val:
                    with open(val_csv_path, 'r') as f:
                        niv_data = json.load(f)
                else:
                    with open(video_csv_path, 'r') as f:
                        niv_data = json.load(f)
                for d in niv_data:
                    legal_range = []
                    path = os.path.join(root, 'dataset/NIV', 'processed_data', d['feature'])
                    info = np.load(path, allow_pickle=True)
                    num_steps = int(info['num_steps'])
                    # Sanity checks: all per-step arrays must agree in length.
                    assert num_steps == len(info['steps_ids'])
                    assert info['num_steps'] == len(info['steps_starts'])
                    assert info['num_steps'] == len(info['steps_ends'])
                    starts = info['steps_starts']
                    ends = info['steps_ends']
                    action_labels = info['steps_ids']
                    images = info['frames_features']

                    for i in range(num_steps):
                        action_label = int(action_labels[i])
                        start_idx = math.floor(float(starts[i]))
                        end_idx = math.ceil(float(ends[i]))

                        # Clamp the segment end to the last frame.
                        if end_idx < images.shape[0]:
                            legal_range.append((start_idx, end_idx, action_label))
                        else:
                            legal_range.append((start_idx, images.shape[0] - 1, action_label))

                    # Pad short videos (same scheme as the crosstask branch).
                    temp_len = len(legal_range)
                    temp = []
                    while temp_len < self.max_traj_len:
                        temp.append(legal_range[0])
                        temp_len += 1

                    temp.extend(legal_range)
                    legal_range = temp

                    for i in range(len(legal_range) - self.max_traj_len + 1):
                        legal_range_current = legal_range[i:i + self.max_traj_len]
                        json_data.append({'id': {'feature': path,
                                                 'legal_range': legal_range_current, 'task_id': d['task_id']},
                                          'instruction_len': 0})
                self.json_data = json_data
                # Cache the freshly-built index for subsequent runs.
                with open(niv_data_name, 'w') as f:
                    json.dump(json_data, f)
                print(len(json_data))
        else:
            raise NotImplementedError(
                'Dataset {} is not implemented'.format(args.dataset))

        self.model = model
        self.prepare_data()
        # Number of consecutive frames concatenated per observation.
        self.M = 3

    def process_single(self, task, vid):
        """Load one crosstask video's features and clamp its annotated action
        segments to the feature length.

        Returns:
            List of (start, end, action_label) tuples, or False when the
            feature file is missing or an annotation starts out of range.
        """
        if self.crosstask_use_feature_how:
            if not os.path.exists(os.path.join(self.features_path, str(task) + '_' + str(vid) + '.npy')):
                return False
            images_ = np.load(os.path.join(self.features_path, str(task) + '_' + str(vid) + '.npy'), allow_pickle=True)
            images = images_['frames_features']
        else:
            if not os.path.exists(os.path.join(self.features_path, vid + '.npy')):
                return False
            images = np.load(os.path.join(self.features_path, vid + '.npy'))

        cnst_path = os.path.join(
            self.constraints_path, task + '_' + vid + '.csv')
        legal_range = self.read_assignment(task, cnst_path)
        legal_range_ret = []
        for (start_idx, end_idx, action_label) in legal_range:
            # A segment starting past the features invalidates the video.
            if not start_idx < images.shape[0]:
                print(task, vid, end_idx, images.shape[0])
                return False
            # Clamp the segment end to the last frame.
            if end_idx < images.shape[0]:
                legal_range_ret.append((start_idx, end_idx, action_label))
            else:
                legal_range_ret.append((start_idx, images.shape[0] - 1, action_label))

        return legal_range_ret

    def read_assignment(self, task_id, path):
        """Parse a crosstask annotation CSV (``step,start,end`` rows) into
        (start, end, action_label) tuples via the one-hot action map."""
        legal_range = []
        with open(path, 'r') as f:
            for line in f:
                step, start, end = line.strip().split(',')
                start = int(math.floor(float(start)))
                end = int(math.ceil(float(end)))
                action_label_ind = self.action_one_hot[task_id + '_' + step]
                legal_range.append((start, end, action_label_ind))

        return legal_range

    def prepare_data(self):
        """Split the JSON index into per-sample id records and lengths."""
        vid_names = []
        frame_cnts = []
        for listdata in self.json_data:
            vid_names.append(listdata['id'])
            frame_cnts.append(listdata['instruction_len'])
        self.vid_names = vid_names
        self.frame_cnts = frame_cnts

    def curate_dataset(self, images, legal_range, M=2):
        """Build alternating start/end observations for each action.

        For every (start_idx, end_idx, action_label) concatenate M
        consecutive frame features at the action start and, anchored two
        frames back, at the action end.

        Returns:
            (images_list, labels_onehot_list, idx_list): alternating
            start/end observations, one label per action, raw start indices.
        """
        images_list = []
        labels_onehot_list = []
        idx_list = []
        for start_idx, end_idx, action_label in legal_range:
            idx = start_idx
            idx_list.append(idx)
            # Clamp so the start window begins inside the video; fall back
            # to the last M frames if it would run past the end.
            image_start_idx = max(0, idx)
            if image_start_idx + M <= len(images):
                image_start = images[image_start_idx: image_start_idx + M]
            else:
                image_start = images[len(images) - M: len(images)]
            image_start_cat = image_start[0]
            for w in range(len(image_start) - 1):
                image_start_cat = np.concatenate((image_start_cat, image_start[w + 1]), axis=0)
            images_list.append(image_start_cat)
            labels_onehot_list.append(action_label)

            # End window is anchored two frames before end_idx (with M=3 it
            # spans one frame past the action end). No label is appended for
            # the end observation.
            end_idx = max(2, end_idx)
            image_end = images[end_idx - 2:end_idx + M - 2]
            image_end_cat = image_end[0]
            for w in range(len(image_end) - 1):
                image_end_cat = np.concatenate((image_end_cat, image_end[w + 1]), axis=0)
            images_list.append(image_end_cat)
        return images_list, labels_onehot_list, idx_list

    def sample_single(self, index):
        """Load one JSON entry's features and build its sample tensors."""
        folder_id = self.vid_names[index]
        if self.is_val:
            event_class = folder_id['task_id']
        else:
            task_class = folder_id['task_id']

        if self.args.dataset == 'crosstask':
            # Consecutive windows usually share a video: reuse the cached
            # features when the vid is unchanged.
            if folder_id['vid'] != self.last_vid:
                if self.crosstask_use_feature_how:
                    images_ = np.load(folder_id['feature'], allow_pickle=True)
                    self.images = images_['frames_features']
                    self.last_vid = folder_id['vid']
                else:
                    # Raw-feature path: plain array, no caching key update.
                    self.images = np.load(os.path.join(self.features_path, folder_id['vid'] + '.npy'))
        else:
            # coin / NIV entries are reloaded on every call (no caching).
            images_ = np.load(folder_id['feature'], allow_pickle=True)
            self.images = images_['frames_features']

        images, labels_matrix, idx_list = self.curate_dataset(
            self.images, folder_id['legal_range'], M=self.M)
        frames = torch.tensor(np.array(images))
        labels_tensor = torch.tensor(labels_matrix, dtype=torch.long)

        if self.is_val:
            event_class = torch.tensor(event_class, dtype=torch.long)
            return frames, labels_tensor, event_class
        else:
            task_class = torch.tensor(task_class, dtype=torch.long)
            return frames, labels_tensor, task_class

    def __getitem__(self, index):
        """Return a ``Batch(Observations, Actions, Class)`` namedtuple."""
        if self.is_val:
            frames, labels, event_class = self.sample_single(index)
        else:
            frames, labels, task = self.sample_single(index)
        if self.is_val:
            batch = Batch(frames, labels, event_class)
        else:
            batch = Batch(frames, labels, task)
        return batch

    def __len__(self):
        # One sample per JSON index entry.
        return len(self.json_data)
|
Loading…
Add table
Add a link
Reference in a new issue