# mtomnet/tbd/utils/fb_scores_err.py
import os
import csv

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.metrics import f1_score

ALPHA = 0.7       # bar transparency
BAR_WIDTH = 0.27  # width of each bar within a group

sns.set_theme(style='whitegrid')
# sns.set_palette('mako')
# RGB palette for the MToMnet model families (not used directly in this script).
MTOM_COLORS = {
    "MN1": (110/255, 117/255, 161/255),
    "MN2": (179/255, 106/255, 98/255),
    "Base": (193/255, 198/255, 208/255),
    "CG": (170/255, 129/255, 42/255),
    "IC": (97/255, 112/255, 83/255),
    "DB": (144/255, 63/255, 110/255),
}
# Maps each model label (matplotlib mathtext) to the three training-run
# subdirectories over which means and standard deviations are computed.
model_to_subdir = {
    r"IC$\parallel$": ["2023-07-16_10-34-32_train", "2023-07-18_13-49-57_train", "2023-07-19_12-17-46_train"],
    r"IC$\oplus$": ["2023-07-16_10-35-02_train", "2023-07-18_13-50-32_train", "2023-07-19_12-18-18_train"],
    r"IC$\otimes$": ["2023-07-16_10-35-41_train", "2023-07-18_13-52-26_train", "2023-07-19_12-18-49_train"],
    r"IC$\odot$": ["2023-07-16_10-36-04_train", "2023-07-18_13-53-03_train", "2023-07-19_12-19-50_train"],
    r"CG$\parallel$": ["2023-07-15_14-12-36_train", "2023-07-17_11-54-28_train", "2023-07-19_00-30-05_train"],
    r"CG$\oplus$": ["2023-07-15_14-14-08_train", "2023-07-17_11-56-05_train", "2023-07-19_00-30-47_train"],
    r"CG$\otimes$": ["2023-07-15_14-14-53_train", "2023-07-17_11-56-39_train", "2023-07-19_00-31-36_train"],
    r"CG$\odot$": ["2023-07-15_14-10-05_train", "2023-07-17_11-57-30_train", "2023-07-19_00-32-10_train"],
    "DB": ["2023-08-08_12-56-02_train", "2023-08-08_19-07-43_train", "2023-08-08_19-08-47_train"],
    "Base": ["2023-08-08_12-53-38_train", "2023-08-08_19-10-02_train", "2023-08-08_19-10-51_train"],
}


def read_data_from_csv(subdirectory_path):
    """Read every per-frame prediction CSV in a run directory into a list of dicts."""
    print(subdirectory_path)
    data = []
    csv_files = [file for file in os.listdir(subdirectory_path) if file.endswith('.csv')]
    for csv_file in csv_files:
        file_path = os.path.join(subdirectory_path, csv_file)
        with open(file_path, 'r') as file:
            reader = csv.reader(file)
            next(reader, None)  # skip the header row
            for row in reader:
                frame, m1_pred, m2_pred, m12_pred, m21_pred, mc_pred, \
                    m1_label, m2_label, m12_label, m21_label, mc_label, false_belief = row
                data.append({
                    'frame': int(frame),
                    'm1_pred': int(m1_pred),
                    'm2_pred': int(m2_pred),
                    'm12_pred': int(m12_pred),
                    'm21_pred': int(m21_pred),
                    'mc_pred': int(mc_pred),
                    'm1_label': int(m1_label),
                    'm2_label': int(m2_label),
                    'm12_label': int(m12_label),
                    'm21_label': int(m21_label),
                    'mc_label': int(mc_label),
                    'false_belief': false_belief,
                })
    return data
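
# CSV layout assumed by the unpacking above (values are illustrative, not taken
# from a real run):
#
#   frame,m1_pred,m2_pred,m12_pred,m21_pred,mc_pred,m1_label,m2_label,m12_label,m21_label,mc_label,false_belief
#   17,1,0,1,0,1,1,0,1,0,1,m1_false_belief
#
# The metrics below only consider rows whose 'false_belief' column contains the
# substring 'false'; the prefix before the first '_' (here 'm1') names the mind
# whose prediction is scored. The exact label strings are an assumption.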


def compute_correct_false_belief(data, mind="all", folder=None):
    """Accuracy over false-belief frames; optionally log each hit/miss to a text file."""
    total_false_belief = 0
    correct_false_belief = 0
    for item in data:
        if 'false' in item['false_belief']:
            false_belief_type = item['false_belief'].split('_')[0]
            if mind == "all" or false_belief_type in mind:
                total_false_belief += 1
                correct = item[f"{false_belief_type}_pred"] == item[f"{false_belief_type}_label"]
                if correct:
                    correct_false_belief += 1
                if folder is not None:
                    mind_tag = '_'.join(mind) if isinstance(mind, list) else mind
                    with open(f"predictions/{folder}/fb_{mind_tag}.txt", "a") as f:
                        f.write(f"{int(correct)}\n")
    if total_false_belief == 0:
        return 0.0
    return correct_false_belief / total_false_belief
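
# Minimal usage sketch (the run directory name comes from model_to_subdir above):
#
#   run_data = read_data_from_csv('predictions/2023-08-08_12-53-38_train')
#   first_order = compute_correct_false_belief(run_data, mind=['m1', 'm2'])
#   second_order = compute_correct_false_belief(run_data, mind=['m12', 'm21'])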


def compute_macro_f1_score(data, mind="all", folder=None):
    """Macro-averaged F1 over false-belief frames. `folder` is unused; it exists
    only so both score functions can be called with the same arguments in __main__."""
    y_true = []
    y_pred = []
    for item in data:
        if 'false' in item['false_belief']:
            false_belief_type = item['false_belief'].split('_')[0]
            if mind == "all" or false_belief_type in mind:
                y_true.append(int(item[f"{false_belief_type}_label"]))
                y_pred.append(int(item[f"{false_belief_type}_pred"]))
    if not y_true:
        return 0.0
    return f1_score(y_true, y_pred, average='macro')


def delete_files_in_subfolders(folder_path, file_names_to_delete):
    """
    Delete specified files in all subfolders of a given folder.

    Parameters:
        folder_path: The path to the folder containing subfolders.
        file_names_to_delete: A list of file names to be deleted.

    Returns:
        None
    """
    for root, _, _ in os.walk(folder_path):
        for file_name in file_names_to_delete:
            file_path = os.path.join(root, file_name)
            if os.path.exists(file_path):
                os.remove(file_path)
                print(f"Deleted: {file_path}")


if __name__ == "__main__":
    folder_path = "predictions"
    files_to_delete = ["fb_m1_m2_m12_m21.txt", "fb_m1_m2.txt", "fb_m12_m21.txt"]
    delete_files_in_subfolders(folder_path, files_to_delete)
    metric = "Accuracy"
    if metric == "Macro F1":
        score_function = compute_macro_f1_score
    elif metric == "Accuracy":
        score_function = compute_correct_false_belief
    else:
        raise ValueError(f"Unknown metric: {metric}")
    models = [
        'Base', 'DB',
        r'CG$\parallel$', r'CG$\oplus$', r'CG$\otimes$', r'CG$\odot$',
        r'IC$\parallel$', r'IC$\oplus$', r'IC$\otimes$', r'IC$\odot$',
    ]
    parent_dir = 'predictions'
    minds = ['m1', 'm2', 'm12', 'm21']
    score_m1_m2 = []
    score_m12_m21 = []
    score_all = []
    std_m1_m2 = []
    std_m12_m21 = []
    std_all = []
    for model in models:
        model_scores_m1_m2 = []
        model_scores_m12_m21 = []
        model_scores_all = []
        for s in range(3):  # three training runs per model
            subdir_path = os.path.join(parent_dir, model_to_subdir[model][s])
            data = read_data_from_csv(subdir_path)
            model_scores_m1_m2.append(score_function(data, ['m1', 'm2'], model_to_subdir[model][s]))
            model_scores_m12_m21.append(score_function(data, ['m12', 'm21'], model_to_subdir[model][s]))
            model_scores_all.append(score_function(data, ['m1', 'm2', 'm12', 'm21'], model_to_subdir[model][s]))
        score_m1_m2.append(np.mean(model_scores_m1_m2))
        std_m1_m2.append(np.std(model_scores_m1_m2))
        score_m12_m21.append(np.mean(model_scores_m12_m21))
        std_m12_m21.append(np.std(model_scores_m12_m21))
        score_all.append(np.mean(model_scores_all))
        std_all.append(np.std(model_scores_all))
    # Collect per-model means and standard deviations for the grouped bar plot.
    data = {
        'Model': list(models),
        'FO_FB_mean': score_m1_m2,
        'FO_FB_std': std_m1_m2,
        'SO_FB_mean': score_m12_m21,
        'SO_FB_std': std_m12_m21,
        'Both_mean': score_all,
        'Both_std': std_all,
    }
    models = data['Model']
    fo_fb_mean = data['FO_FB_mean']
    fo_fb_std = data['FO_FB_std']
    so_fb_mean = data['SO_FB_mean']
    so_fb_std = data['SO_FB_std']
    both_mean = data['Both_mean']
    both_std = data['Both_std']
    bar_width = BAR_WIDTH
    x = np.arange(len(models))
    plt.figure(figsize=(13, 3.5))
    fo_fb_bars = plt.bar(x - bar_width, fo_fb_mean, width=bar_width, yerr=fo_fb_std, capsize=4,
                         label='First-order false belief', alpha=ALPHA)
    so_fb_bars = plt.bar(x, so_fb_mean, width=bar_width, yerr=so_fb_std, capsize=4,
                         label='Second-order false belief', alpha=ALPHA)
    both_bars = plt.bar(x + bar_width, both_mean, width=bar_width, yerr=both_std, capsize=4,
                        label='Both', alpha=ALPHA)

    def add_labels(bars, std_values):
        """Write each bar's mean above its error bar; bars at indices 0, 1 and 9 get an asterisk."""
        for cnt, (bar, std) in enumerate(zip(bars, std_values)):
            height = bar.get_height()
            offset = std + 0.01
            label = f'{height:.2f}*' if cnt in (0, 1, 9) else f'{height:.2f}'
            plt.text(bar.get_x() + bar.get_width() / 2., height + offset, label,
                     ha='center', va='bottom', fontsize=10)

    add_labels(fo_fb_bars, fo_fb_std)
    add_labels(so_fb_bars, so_fb_std)
    add_labels(both_bars, both_std)
    plt.gca().spines['top'].set_visible(False)
    plt.gca().spines['right'].set_visible(False)
    plt.xlabel('MToMnet', fontsize=14)
    plt.ylabel('Macro F1 Score' if metric == "Macro F1" else 'Accuracy', fontsize=14)
    plt.xticks(x, models, rotation=0, fontsize=14)
    plt.yticks(fontsize=14)
    plt.legend(fontsize=14, loc='upper center', bbox_to_anchor=(0.5, 1.3), ncol=3)
    plt.tight_layout()
    os.makedirs('results', exist_ok=True)  # savefig does not create the directory itself
    plt.savefig('results/false_belief_first_vs_second.pdf')