up
This commit is contained in:
parent
d4aaf7f4ad
commit
25b8b3f343
55 changed files with 7592 additions and 4 deletions
224
tbd/utils/fb_scores_err.py
Normal file
224
tbd/utils/fb_scores_err.py
Normal file
|
@ -0,0 +1,224 @@
|
|||
import os
|
||||
import csv
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
from sklearn.metrics import f1_score
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
ALPHA = 0.7
|
||||
BAR_WIDTH = 0.27
|
||||
sns.set_theme(style='whitegrid')
|
||||
#sns.set_palette('mako')
|
||||
|
||||
MTOM_COLORS = {
|
||||
"MN1": (110/255, 117/255, 161/255),
|
||||
"MN2": (179/255, 106/255, 98/255),
|
||||
"Base": (193/255, 198/255, 208/255),
|
||||
"CG": (170/255, 129/255, 42/255),
|
||||
"IC": (97/255, 112/255, 83/255),
|
||||
"DB": (144/255, 63/255, 110/255)
|
||||
}
|
||||
|
||||
model_to_subdir = {
|
||||
"IC$\parallel$": ["2023-07-16_10-34-32_train", "2023-07-18_13-49-57_train", "2023-07-19_12-17-46_train"],
|
||||
"IC$\oplus$": ["2023-07-16_10-35-02_train", "2023-07-18_13-50-32_train", "2023-07-19_12-18-18_train"],
|
||||
"IC$\otimes$": ["2023-07-16_10-35-41_train", "2023-07-18_13-52-26_train", "2023-07-19_12-18-49_train"],
|
||||
"IC$\odot$": ["2023-07-16_10-36-04_train", "2023-07-18_13-53-03_train", "2023-07-19_12-19-50_train"],
|
||||
"CG$\parallel$": ["2023-07-15_14-12-36_train", "2023-07-17_11-54-28_train", "2023-07-19_00-30-05_train"],
|
||||
"CG$\oplus$": ["2023-07-15_14-14-08_train", "2023-07-17_11-56-05_train", "2023-07-19_00-30-47_train"],
|
||||
"CG$\otimes$": ["2023-07-15_14-14-53_train", "2023-07-17_11-56-39_train", "2023-07-19_00-31-36_train"],
|
||||
"CG$\odot$": ["2023-07-15_14-10-05_train", "2023-07-17_11-57-30_train", "2023-07-19_00-32-10_train"],
|
||||
"DB": ["2023-08-08_12-56-02_train", "2023-08-08_19-07-43_train", "2023-08-08_19-08-47_train"],
|
||||
"Base": ["2023-08-08_12-53-38_train", "2023-08-08_19-10-02_train", "2023-08-08_19-10-51_train"]
|
||||
}
|
||||
|
||||
def read_data_from_csv(subdirectory_path):
|
||||
print(subdirectory_path)
|
||||
data = []
|
||||
csv_files = [file for file in os.listdir(subdirectory_path) if file.endswith('.csv')]
|
||||
for csv_file in csv_files:
|
||||
file_path = os.path.join(subdirectory_path, csv_file)
|
||||
with open(file_path, 'r') as file:
|
||||
reader = csv.reader(file)
|
||||
header_skipped = False
|
||||
for row in reader:
|
||||
if not header_skipped:
|
||||
header_skipped = True
|
||||
continue
|
||||
frame, m1_pred, m2_pred, m12_pred, m21_pred, mc_pred, m1_label, m2_label, m12_label, m21_label, mc_label, false_belief = row
|
||||
data.append({
|
||||
'frame': int(frame),
|
||||
'm1_pred': int(m1_pred),
|
||||
'm2_pred': int(m2_pred),
|
||||
'm12_pred': int(m12_pred),
|
||||
'm21_pred': int(m21_pred),
|
||||
'mc_pred': int(mc_pred),
|
||||
'm1_label': int(m1_label),
|
||||
'm2_label': int(m2_label),
|
||||
'm12_label': int(m12_label),
|
||||
'm21_label': int(m21_label),
|
||||
'mc_label': int(mc_label),
|
||||
'false_belief': false_belief,
|
||||
})
|
||||
return data
|
||||
|
||||
def compute_correct_false_belief(data, mind="all", folder=None):
|
||||
total_false_belief = 0
|
||||
correct_false_belief = 0
|
||||
for item in data:
|
||||
if 'false' in item['false_belief']:
|
||||
false_belief_type = item['false_belief'].split('_')[0]
|
||||
if mind == "all" or false_belief_type in mind:
|
||||
total_false_belief += 1
|
||||
if item[f"{false_belief_type}_pred"] == item[f"{false_belief_type}_label"]:
|
||||
if folder is not None:
|
||||
with open(f"predictions/{folder}/fb_{'_'.join(mind)}.txt" if isinstance(mind, list) else f"predictions/{folder}/fb_{mind}.txt", "a") as f:
|
||||
f.write(f"{str(1)}\n")
|
||||
correct_false_belief += 1
|
||||
else:
|
||||
if folder is not None:
|
||||
with open(f"predictions/{folder}/fb_{'_'.join(mind)}.txt" if isinstance(mind, list) else f"predictions/{folder}/fb_{mind}.txt", "a") as f:
|
||||
f.write(f"{str(0)}\n")
|
||||
if total_false_belief == 0:
|
||||
accuracy = 0.0
|
||||
else:
|
||||
accuracy = correct_false_belief / total_false_belief
|
||||
return accuracy
|
||||
|
||||
def compute_macro_f1_score(data, mind="all"):
|
||||
y_true = []
|
||||
y_pred = []
|
||||
|
||||
for item in data:
|
||||
if 'false' in item['false_belief']:
|
||||
false_belief_type = item['false_belief'].split('_')[0]
|
||||
if mind == "all" or false_belief_type in mind:
|
||||
y_true.append(int(item[f"{false_belief_type}_label"]))
|
||||
y_pred.append(int(item[f"{false_belief_type}_pred"]))
|
||||
|
||||
if not y_true or not y_pred:
|
||||
macro_f1 = 0.0
|
||||
else:
|
||||
macro_f1 = f1_score(y_true, y_pred, average='macro')
|
||||
|
||||
return macro_f1
|
||||
|
||||
def delete_files_in_subfolders(folder_path, file_names_to_delete):
|
||||
"""
|
||||
Delete specified files in all subfolders of a given folder.
|
||||
|
||||
Parameters:
|
||||
folder_path: The path to the folder containing subfolders.
|
||||
file_names_to_delete: A list of file names to be deleted.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for root, _, _ in os.walk(folder_path):
|
||||
for file_name in file_names_to_delete:
|
||||
file_path = os.path.join(root, file_name)
|
||||
if os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
print(f"Deleted: {file_path}")
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
folder_path = "predictions"
|
||||
files_to_delete = ["fb_m1_m2_m12_m21.txt", "fb_m1_m2.txt", "fb_m12_m21.txt"]
|
||||
delete_files_in_subfolders(folder_path, files_to_delete)
|
||||
|
||||
metric = "Accuracy"
|
||||
if metric == "Macro F1":
|
||||
score_function = compute_macro_f1_score
|
||||
elif metric == "Accuracy":
|
||||
score_function = compute_correct_false_belief
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
models = [
|
||||
'Base', 'DB',
|
||||
'CG$\parallel$', 'CG$\oplus$', 'CG$\otimes$', 'CG$\odot$',
|
||||
'IC$\parallel$', 'IC$\oplus$', 'IC$\otimes$', 'IC$\odot$'
|
||||
]
|
||||
|
||||
parent_dir = 'predictions'
|
||||
minds = categories = ['m1', 'm2', 'm12', 'm21']
|
||||
score_m1_m2 = []
|
||||
score_m12_m21 = []
|
||||
score_all = []
|
||||
std_m1_m2 = []
|
||||
std_m12_m21 = []
|
||||
std_all = []
|
||||
|
||||
for model in models:
|
||||
model_scores_m1_m2 = []
|
||||
model_scores_m12_m21 = []
|
||||
model_scores_all = []
|
||||
for s in range(3):
|
||||
subdir_path = os.path.join(parent_dir, model_to_subdir[model][s])
|
||||
data = read_data_from_csv(subdir_path)
|
||||
model_scores_m1_m2.append(score_function(data, ['m1', 'm2'], model_to_subdir[model][s]))
|
||||
model_scores_m12_m21.append(score_function(data, ['m12', 'm21'], model_to_subdir[model][s]))
|
||||
model_scores_all.append(score_function(data, ['m1', 'm2', 'm12', 'm21'], model_to_subdir[model][s]))
|
||||
score_m1_m2.append(np.mean(model_scores_m1_m2))
|
||||
std_m1_m2.append(np.std(model_scores_m1_m2))
|
||||
score_m12_m21.append(np.mean(model_scores_m12_m21))
|
||||
std_m12_m21.append(np.std(model_scores_m12_m21))
|
||||
score_all.append(np.mean(model_scores_all))
|
||||
std_all.append(np.std(model_scores_all))
|
||||
|
||||
# Create a dataframe to use with sns.catplot
|
||||
data = {
|
||||
'Model': [m for m in models],
|
||||
'FO_FB_mean': score_m1_m2,
|
||||
'FO_FB_std': std_m1_m2,
|
||||
'SO_FB_mean': score_m12_m21,
|
||||
'SO_FB_std': std_m12_m21,
|
||||
'Both_mean': score_all,
|
||||
'Both_std': std_all
|
||||
}
|
||||
|
||||
models = data['Model']
|
||||
fo_fb_mean = data['FO_FB_mean']
|
||||
fo_fb_std = data['FO_FB_std']
|
||||
so_fb_mean = data['SO_FB_mean']
|
||||
so_fb_std = data['SO_FB_std']
|
||||
both_mean = data['Both_mean']
|
||||
both_std = data['Both_std']
|
||||
|
||||
bar_width = BAR_WIDTH
|
||||
x = np.arange(len(models))
|
||||
|
||||
plt.figure(figsize=(13, 3.5))
|
||||
fo_fb_bars = plt.bar(x - bar_width, fo_fb_mean, width=bar_width, yerr=fo_fb_std, capsize=4, label='First-order false belief', alpha=ALPHA)
|
||||
so_fb_bars = plt.bar(x, so_fb_mean, width=bar_width, yerr=so_fb_std, capsize=4, label='Second-order false belief', alpha=ALPHA)
|
||||
both_bars = plt.bar(x + bar_width, both_mean, width=bar_width, yerr=both_std, capsize=4, label='Both', alpha=ALPHA)
|
||||
|
||||
def add_labels(bars, std_values):
|
||||
cnt = 0
|
||||
for bar, std in zip(bars, std_values):
|
||||
height = bar.get_height()
|
||||
offset = std + 0.01
|
||||
if cnt == 0 or cnt == 1 or cnt == 9:
|
||||
plt.text(bar.get_x() + bar.get_width() / 2., height + offset, f'{height:.2f}*', ha='center', va='bottom', fontsize=10)
|
||||
else:
|
||||
plt.text(bar.get_x() + bar.get_width() / 2., height + offset, f'{height:.2f}', ha='center', va='bottom', fontsize=10)
|
||||
cnt = cnt + 1
|
||||
|
||||
add_labels(fo_fb_bars, fo_fb_std)
|
||||
add_labels(so_fb_bars, so_fb_std)
|
||||
add_labels(both_bars, both_std)
|
||||
|
||||
plt.gca().spines['top'].set_visible(False)
|
||||
plt.gca().spines['right'].set_visible(False)
|
||||
plt.xlabel('MToMnet', fontsize=14)
|
||||
plt.ylabel('Macro F1 Score' if metric == "Macro F1" else 'Accuracy', fontsize=14)
|
||||
plt.xticks(x, models, rotation=0, fontsize=14)
|
||||
plt.yticks(fontsize=14)
|
||||
plt.legend(fontsize=14, loc='upper center', bbox_to_anchor=(0.5, 1.3), ncol=3)
|
||||
plt.tight_layout()
|
||||
plt.savefig('results/false_belief_first_vs_second.pdf')
|
210
tbd/utils/helpers.py
Normal file
210
tbd/utils/helpers.py
Normal file
File diff suppressed because one or more lines are too long
37
tbd/utils/preprocess_img.py
Normal file
37
tbd/utils/preprocess_img.py
Normal file
|
@ -0,0 +1,37 @@
|
|||
import glob
|
||||
|
||||
import cv2
|
||||
|
||||
import torchvision.transforms as T
|
||||
import torch
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
PATH_IN = "/scratch/bortoletto/data/tbd/images"
|
||||
PATH_OUT = "/scratch/bortoletto/data/tbd/images_norm"
|
||||
|
||||
normalisation_steps = [
|
||||
T.ToTensor(),
|
||||
T.Resize((128,128)),
|
||||
T.Normalize(
|
||||
mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225]
|
||||
)
|
||||
]
|
||||
|
||||
preprocess_img = T.Compose(normalisation_steps)
|
||||
|
||||
def main():
|
||||
print(f"{PATH_IN}/*/*/*.jpg")
|
||||
all_img = glob.glob(f"{PATH_IN}/*/*/*.jpg")
|
||||
print(len(all_img))
|
||||
for img_path in tqdm(all_img):
|
||||
new_img = preprocess_img(cv2.imread(img_path)).numpy()
|
||||
img_path_split = img_path.split("/")
|
||||
os.makedirs(f"{PATH_OUT}/{img_path_split[-3]}/{img_path_split[-2]}", exist_ok=True)
|
||||
out_img = f"{PATH_OUT}/{img_path_split[-3]}/{img_path_split[-2]}/{img_path_split[-1][:-4]}.pt"
|
||||
torch.save(new_img, out_img)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
106
tbd/utils/reformat_labels_ours.py
Normal file
106
tbd/utils/reformat_labels_ours.py
Normal file
|
@ -0,0 +1,106 @@
|
|||
import pandas as pd
|
||||
import os
|
||||
import glob
|
||||
import pickle
|
||||
|
||||
DATASET_LOCATION = "YOUR_PATH_HERE"
|
||||
|
||||
def reframe_annotation():
|
||||
annotation_path = f'{DATASET_LOCATION}/retrieve_annotation/all/'
|
||||
save_path = f'{DATASET_LOCATION}/reformat_annotation/'
|
||||
if not os.path.exists(save_path):
|
||||
os.makedirs(save_path)
|
||||
tasks = glob.glob(annotation_path + '*.txt')
|
||||
id_map = pd.read_csv('id_map.csv')
|
||||
for task in tasks:
|
||||
if not task.split('/')[-1].split('_')[2] == '1.txt':
|
||||
continue
|
||||
with open(task, 'r') as f:
|
||||
lines = f.readlines()
|
||||
task_id = int(task.split('/')[-1].split('_')[1]) + 1
|
||||
clip = id_map.loc[id_map['ID'] == task_id].folder
|
||||
print(task_id, len(clip))
|
||||
if len(clip) == 0:
|
||||
continue
|
||||
with open(save_path + clip.item() + '.txt', 'w') as f:
|
||||
for line in lines:
|
||||
words = line.split()
|
||||
f.write(words[0] + ',' + words[1] + ',' + words[2] + ',' + words[3] + ',' + words[4] + ',' + words[5] +
|
||||
',' + words[6] + ',' + words[7] + ',' + words[8] + ',' + words[9] + ',' + ' '.join(words[10:]) + '\n')
|
||||
f.close()
|
||||
|
||||
def get_grid_location(obj_frame):
|
||||
x_min = obj_frame['x_min']#.item()
|
||||
y_min = obj_frame['y_min']#.item()
|
||||
x_max = obj_frame['x_max']#.item()
|
||||
y_max = obj_frame['y_max']#.item()
|
||||
gridLW = 1280 / 25.
|
||||
gridLH = 720 / 15.
|
||||
center_x, center_y = (x_min + x_max)/2, (y_min + y_max)/2
|
||||
X, Y = int(center_x / gridLW), int(center_y / gridLH)
|
||||
return X, Y
|
||||
|
||||
def regenerate_annotation():
|
||||
annotation_path = f'{DATASET_LOCATION}/reformat_annotation/'
|
||||
save_path=f'{DATASET_LOCATION}/regenerate_annotation/'
|
||||
if not os.path.exists(save_path):
|
||||
os.makedirs(save_path)
|
||||
tasks = glob.glob(annotation_path + '*.txt')
|
||||
for task in tasks:
|
||||
print(task)
|
||||
annt = pd.read_csv(task, sep=",", header=None)
|
||||
annt.columns = ["obj_id", "x_min", "y_min", "x_max", "y_max", "frame", "lost", "occluded", "generated", "name", "label"]
|
||||
obj_records = {}
|
||||
for index, obj_frame in annt.iterrows():
|
||||
if obj_frame['name'].startswith('P'):
|
||||
continue
|
||||
else:
|
||||
assert obj_frame['name'].startswith('O')
|
||||
obj_name = obj_frame['name']
|
||||
# 0: enter 1: disappear 2: update 3: unchange
|
||||
frame_id = obj_frame['frame']
|
||||
curr_loc = get_grid_location(obj_frame)
|
||||
mind_dict = {'m1': {'fluent': 3, 'loc': None}, 'm2': {'fluent': 3, 'loc': None},
|
||||
'm12': {'fluent': 3, 'loc': None},
|
||||
'm21': {'fluent': 3, 'loc': None}, 'mc': {'fluent': 3, 'loc': None},
|
||||
'mg': {'fluent': 3, 'loc': curr_loc}}
|
||||
mind_dict['mg']['loc'] = curr_loc
|
||||
if not type(obj_frame['label']) == float:
|
||||
mind_labels = obj_frame['label'].split()
|
||||
for mind_label in mind_labels:
|
||||
if mind_label == 'in_m1' or mind_label == 'in_m2' or mind_label == 'in_m12' \
|
||||
or mind_label == 'in_m21' or mind_label == 'in_mc' or mind_label == '"in_m1"' or mind_label == '"in_m2"'\
|
||||
or mind_label == '"in_m12"' or mind_label == '"in_m21"' or mind_label == '"in_mc"':
|
||||
mind_name = mind_label.split('_')[1].split('"')[0]
|
||||
mind_dict[mind_name]['loc'] = curr_loc
|
||||
else:
|
||||
mind_name = mind_label.split('_')[0].split('"')
|
||||
if len(mind_name) > 1:
|
||||
mind_name = mind_name[1]
|
||||
else:
|
||||
mind_name = mind_name[0]
|
||||
last_loc = obj_records[obj_name][frame_id - 1][mind_name]['loc']
|
||||
mind_dict[mind_name]['loc'] = last_loc
|
||||
|
||||
for mind_name in mind_dict.keys():
|
||||
if frame_id > 0:
|
||||
curr_loc = mind_dict[mind_name]['loc']
|
||||
last_loc = obj_records[obj_name][frame_id - 1][mind_name]['loc']
|
||||
if last_loc is None and curr_loc is not None:
|
||||
mind_dict[mind_name]['fluent'] = 0
|
||||
elif last_loc is not None and curr_loc is None:
|
||||
mind_dict[mind_name]['fluent'] = 1
|
||||
elif not last_loc == curr_loc:
|
||||
mind_dict[mind_name]['fluent'] = 2
|
||||
if obj_name not in obj_records:
|
||||
obj_records[obj_name] = [mind_dict]
|
||||
else:
|
||||
obj_records[obj_name].append(mind_dict)
|
||||
|
||||
with open(save_path + task.split('/')[-1].split('.')[0] + '.p', 'wb') as f:
|
||||
pickle.dump(obj_records, f)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
reframe_annotation()
|
||||
regenerate_annotation()
|
75
tbd/utils/similarity.py
Normal file
75
tbd/utils/similarity.py
Normal file
|
@ -0,0 +1,75 @@
|
|||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
import matplotlib.pyplot as plt
|
||||
from sklearn.decomposition import PCA
|
||||
import seaborn as sns
|
||||
|
||||
|
||||
FOLDER_PATH = 'PATH_TO_FOLDER'
|
||||
|
||||
print(FOLDER_PATH)
|
||||
|
||||
MTOM_COLORS = {
|
||||
"MN1": (110/255, 117/255, 161/255),
|
||||
"MN2": (179/255, 106/255, 98/255),
|
||||
"Base": (193/255, 198/255, 208/255),
|
||||
"CG": (170/255, 129/255, 42/255),
|
||||
"IC": (97/255, 112/255, 83/255),
|
||||
"DB": (144/255, 63/255, 110/255)
|
||||
}
|
||||
|
||||
COLORS = sns.color_palette()
|
||||
|
||||
sns.set_theme(style='white')
|
||||
|
||||
out_left_main_mods_full_test = []
|
||||
out_right_main_mods_full_test = []
|
||||
cell_left_main_mods_full_test = []
|
||||
cell_right_main_mods_full_test = []
|
||||
cm_left_main_mods_full_test = []
|
||||
cm_right_main_mods_full_test = []
|
||||
|
||||
for i in range(len([filename for filename in os.listdir(FOLDER_PATH) if filename.endswith('.pt')])):
|
||||
|
||||
print(f'Computing analysis for test video {i}...', end='\r')
|
||||
|
||||
emb_file = os.path.join(FOLDER_PATH, f'{i}.pt')
|
||||
data = torch.load(emb_file)
|
||||
if len(data) == 13: # implicit
|
||||
model = 'impl'
|
||||
out_left, cell_left, out_right, cell_right, feats = data[0], data[1], data[2], data[3], data[4:]
|
||||
elif len(data) == 12: # common mind
|
||||
model = 'cm'
|
||||
out_left, out_right, common_mind, feats = data[0], data[1], data[2], data[3:]
|
||||
elif len(data) == 11: # speaker-listener
|
||||
model = 'sl'
|
||||
out_left, out_right, feats = data[0], data[1], data[2:]
|
||||
else: raise ValueError("Data should have 13 (impl), others are not implemented")
|
||||
|
||||
# ====== PCA for left and right embeddings ====== #
|
||||
|
||||
out_left_pca = out_left[0].reshape(-1, 64)
|
||||
out_right_pca = out_right[0].reshape(-1, 64)
|
||||
out_left_and_right = np.concatenate((out_left_pca, out_right_pca), axis=0)
|
||||
|
||||
pca = PCA(n_components=2)
|
||||
pca_result = pca.fit_transform(out_left_and_right)
|
||||
|
||||
# Separate the PCA results for each tensor
|
||||
pca_result_left = pca_result[:out_left_pca.shape[0]]
|
||||
pca_result_right = pca_result[out_right_pca.shape[0]:]
|
||||
|
||||
plt.figure(figsize=(6.8,6))
|
||||
plt.scatter(pca_result_left[:, 0], pca_result_left[:, 1], label='$h_1$', color=MTOM_COLORS['MN1'], s=100)
|
||||
plt.scatter(pca_result_right[:, 0], pca_result_right[:, 1], label='$h_2$', color=MTOM_COLORS['MN2'], s=100)
|
||||
plt.xlabel('Principal Component 1', fontsize=32)
|
||||
plt.ylabel('Principal Component 2', fontsize=32)
|
||||
plt.xticks(fontsize=24)
|
||||
plt.xticks([-0.4, -0.2, 0.0, 0.2, 0.4])
|
||||
plt.yticks(fontsize=24)
|
||||
plt.legend(fontsize=32)
|
||||
plt.tight_layout()
|
||||
plt.savefig(f'{FOLDER_PATH}/{i}_pca.pdf')
|
||||
plt.close()
|
96
tbd/utils/store_mind_set.py
Normal file
96
tbd/utils/store_mind_set.py
Normal file
|
@ -0,0 +1,96 @@
|
|||
import os
|
||||
import pandas as pd
|
||||
import pickle
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def check_append(obj_name, m1, mind_name, obj_frame, flags, label):
|
||||
if label:
|
||||
if not obj_name in m1:
|
||||
m1[obj_name] = []
|
||||
m1[obj_name].append(
|
||||
[obj_frame.frame, [obj_frame.x_min, obj_frame.y_min, obj_frame.x_max, obj_frame.y_max], 0])
|
||||
flags[mind_name] = 1
|
||||
elif not flags[mind_name]:
|
||||
m1[obj_name].append(
|
||||
[obj_frame.frame, [obj_frame.x_min, obj_frame.y_min, obj_frame.x_max, obj_frame.y_max], 0])
|
||||
flags[mind_name] = 1
|
||||
else: # false belief
|
||||
if obj_name in m1:
|
||||
if flags[mind_name]:
|
||||
m1[obj_name].append(
|
||||
[obj_frame.frame, [obj_frame.x_min, obj_frame.y_min, obj_frame.x_max, obj_frame.y_max], 1])
|
||||
flags[mind_name] = 0
|
||||
return flags, m1
|
||||
|
||||
|
||||
def store_mind_set(clip, annotation_path, save_path):
|
||||
if not os.path.exists(save_path):
|
||||
os.makedirs(save_path)
|
||||
annt = pd.read_csv(annotation_path + clip, sep=",", header=None)
|
||||
annt.columns = ["obj_id", "x_min", "y_min", "x_max", "y_max", "frame", "lost", "occluded", "generated", "name",
|
||||
"label"]
|
||||
obj_names = annt.name.unique()
|
||||
m1, m2, m12, m21, mc = {}, {}, {}, {}, {}
|
||||
flags = {'m1':0, 'm2':0, 'm12':0, 'm21':0, 'mc':0}
|
||||
for obj_name in obj_names:
|
||||
if obj_name == 'P1' or obj_name == 'P2':
|
||||
continue
|
||||
obj_frames = annt.loc[annt.name == obj_name]
|
||||
for index, obj_frame in obj_frames.iterrows():
|
||||
if type(obj_frame.label) == float:
|
||||
continue
|
||||
labels = obj_frame.label.split()
|
||||
for label in labels:
|
||||
if label == 'in_m1' or label == '"in_m1"':
|
||||
flags, m1 = check_append(obj_name, m1, 'm1', obj_frame, flags, 1)
|
||||
elif label == 'in_m2' or label == '"in_m2"':
|
||||
flags, m2 = check_append(obj_name, m2, 'm2', obj_frame, flags, 1)
|
||||
elif label == 'in_m12'or label == '"in_m12"':
|
||||
flags, m12 = check_append(obj_name, m12, 'm12', obj_frame, flags, 1)
|
||||
elif label == 'in_m21' or label == '"in_m21"':
|
||||
flags, m21 = check_append(obj_name, m21, 'm21', obj_frame, flags, 1)
|
||||
elif label == 'in_mc'or label == '"in_mc"':
|
||||
flags, mc = check_append(obj_name, mc, 'mc', obj_frame, flags, 1)
|
||||
elif label == 'm1_false' or label == '"m1_false"':
|
||||
flags, m1 = check_append(obj_name, m1, 'm1', obj_frame, flags, 0)
|
||||
flags, m12 = check_append(obj_name, m12, 'm12', obj_frame, flags, 0)
|
||||
flags, m21 = check_append(obj_name, m21, 'm21', obj_frame, flags, 0)
|
||||
false_belief = 'm1_false'
|
||||
with open(save_path + clip.split('.')[0] + '.txt', "a") as file:
|
||||
file.write(f"{obj_frame['frame']},{obj_name},{false_belief}\n")
|
||||
elif label == 'm2_false' or label == '"m2_false"':
|
||||
flags, m2 = check_append(obj_name, m2, 'm2', obj_frame, flags, 0)
|
||||
flags, m12 = check_append(obj_name, m12, 'm12', obj_frame, flags, 0)
|
||||
flags, m21 = check_append(obj_name, m21, 'm21', obj_frame, flags, 0)
|
||||
false_belief = 'm2_false'
|
||||
with open(save_path + clip.split('.')[0] + '.txt', "a") as file:
|
||||
file.write(f"{obj_frame['frame']},{obj_name},{false_belief}\n")
|
||||
elif label == 'm12_false' or label == '"m12_false"':
|
||||
flags, m12 = check_append(obj_name, m12, 'm12', obj_frame, flags, 0)
|
||||
flags, mc = check_append(obj_name, mc, 'mc', obj_frame, flags, 0)
|
||||
false_belief = 'm12_false'
|
||||
with open(save_path + clip.split('.')[0] + '.txt', "a") as file:
|
||||
file.write(f"{obj_frame['frame']},{obj_name},{false_belief}\n")
|
||||
elif label == 'm21_false' or label == '"m21_false"':
|
||||
flags, m21 = check_append(obj_name, m2, 'm21', obj_frame, flags, 0)
|
||||
flags, mc = check_append(obj_name, mc, 'mc', obj_frame, flags, 0)
|
||||
false_belief = 'm21_false'
|
||||
with open(save_path + clip.split('.')[0] + '.txt', "a") as file:
|
||||
file.write(f"{obj_frame['frame']},{obj_name},{false_belief}\n")
|
||||
# print('m1', m1)
|
||||
# print('m2', m2)
|
||||
# print('m12', m12)
|
||||
# print('m21', m21)
|
||||
# print('mc', mc)
|
||||
#with open(save_path + clip.split('.')[0] + '.p', 'wb') as f:
|
||||
# pickle.dump([m1, m2, m12, m21, mc], f)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
annotation_path = '/scratch/bortoletto/data/tbd/reformat_annotation/'
|
||||
save_path = '/scratch/bortoletto/data/tbd/store_mind_set/'
|
||||
|
||||
for clip in tqdm(os.listdir(annotation_path), desc="Processing videos", unit="item"):
|
||||
store_mind_set(clip, annotation_path, save_path)
|
95
tbd/utils/visualize_bbox.py
Normal file
95
tbd/utils/visualize_bbox.py
Normal file
|
@ -0,0 +1,95 @@
|
|||
import time
|
||||
from tbd_dataloader import TBDv2Dataset
|
||||
|
||||
import numpy as np
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as patches
|
||||
|
||||
def point2screen(points):
|
||||
K = [607.13232421875, 0.0, 638.6468505859375, 0.0, 607.1067504882812, 367.1607360839844, 0.0, 0.0, 1.0]
|
||||
K = np.reshape(np.array(K), [3, 3])
|
||||
rot_points = np.array(points) + np.array([0, 0.2, 0])
|
||||
rot_points = rot_points
|
||||
points_camera = rot_points.reshape(3, 1)
|
||||
|
||||
project_matrix = np.array(K).reshape(3, 3)
|
||||
points_prj = project_matrix.dot(points_camera)
|
||||
points_prj = points_prj.transpose()
|
||||
if not points_prj[:, 2][0] == 0.0:
|
||||
points_prj[:, 0] = points_prj[:, 0] / points_prj[:, 2]
|
||||
points_prj[:, 1] = points_prj[:, 1] / points_prj[:, 2]
|
||||
points_screen = points_prj[:, :2]
|
||||
assert points_screen.shape == (1, 2)
|
||||
points_screen = points_screen.reshape(-1)
|
||||
return points_screen
|
||||
|
||||
if __name__ == '__main__':
|
||||
data = TBDv2Dataset(number_frames_to_sample=1, resize_img=None)
|
||||
index = np.random.randint(0, len(data))
|
||||
start = time.time()
|
||||
(
|
||||
kinect_imgs, # <- len x 720 x 1280 x 3
|
||||
tracker_imgs,
|
||||
battery_imgs,
|
||||
skele1,
|
||||
skele2,
|
||||
bbox,
|
||||
tracker_skeID_sample, # <- This is the tracker skeleton ID
|
||||
tracker2d,
|
||||
label,
|
||||
experiment_id, # From here for debugging
|
||||
timestep,
|
||||
obj_id, # <- This is the object ID as a string
|
||||
) = data[index]
|
||||
end = time.time()
|
||||
print(f"Time for one sample: {end-start}")
|
||||
|
||||
img = kinect_imgs[-1]
|
||||
bbox = bbox[-1]
|
||||
print(label.shape)
|
||||
|
||||
print(skele1.shape)
|
||||
print(skele2.shape)
|
||||
|
||||
skele1 = skele1[-1, :,:]
|
||||
skele2 = skele2[-1, :,:]
|
||||
|
||||
print(skele1.shape)
|
||||
|
||||
|
||||
|
||||
# reshape img from c, h, w to h, w, c
|
||||
img = img.permute(1, 2, 0)
|
||||
|
||||
fig, ax = plt.subplots(1)
|
||||
ax.imshow(img)
|
||||
print(bbox[0], bbox[1], bbox[2], bbox[3]) # t(top left x, top left y, width, height)
|
||||
top_left_x, top_left_y, width, height = bbox[0], bbox[1], bbox[2], bbox[3]
|
||||
x_min, y_min, x_max, y_max = bbox[0], bbox[1], bbox[2], bbox[3]
|
||||
|
||||
|
||||
|
||||
|
||||
for i in range(26):
|
||||
print(skele1[i,0], skele1[i,1])
|
||||
print(skele1[i,:].shape)
|
||||
print(point2screen(skele1[i,:]))
|
||||
x, y = point2screen(skele1[i,:])[0], point2screen(skele1[i,:])[1]
|
||||
ax.text(x, y, f"{i}", fontsize=5, color='w')
|
||||
|
||||
wedge = patches.Wedge((x,y), 10, 0, 360, width=10, color='b')
|
||||
ax.add_patch(wedge)
|
||||
|
||||
for i in range(26):
|
||||
x, y = point2screen(skele2[i,:])[0], point2screen(skele2[i,:])[1]
|
||||
ax.text(x, y, f"{i}", fontsize=5, color='w')
|
||||
wedge = patches.Wedge((point2screen(skele2[i,:])[0], point2screen(skele2[i,:])[1]), 10, 0, 360, width=10, color='r')
|
||||
ax.add_patch(wedge)
|
||||
|
||||
# Create a Rectangle patch
|
||||
# rect = patches.Rectangle((top_left_x, top_left_y-height), width, height, linewidth=1, edgecolor='r', facecolor='none')
|
||||
# ax.add_patch(rect)
|
||||
# rect = patches.Rectangle((x_min, y_max), x_max-x_min, y_max-y_min, linewidth=1, edgecolor='g', facecolor='none')
|
||||
# ax.add_patch(rect)
|
||||
fig.savefig(f"bbox_{obj_id}_{index}_{experiment_id}.png")
|
Loading…
Add table
Add a link
Reference in a new issue