"""Score MToMnet variants on false-belief frames and plot first- vs. second-order results.

For every model in ``model_to_subdir``, the per-frame prediction CSVs in each of
its run directories under ``predictions/`` are read and scored on three mind
groups: first-order (m1, m2), second-order (m12, m21), and all four together.
The mean and (population) standard deviation across runs are drawn as a grouped
bar chart and saved to ``results/false_belief_first_vs_second.pdf``.
"""
import contextlib
import csv
import os

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.metrics import f1_score
from tqdm import tqdm

ALPHA = 0.7
BAR_WIDTH = 0.27

sns.set_theme(style='whitegrid')

MTOM_COLORS = {
    "MN1": (110/255, 117/255, 161/255),
    "MN2": (179/255, 106/255, 98/255),
    "Base": (193/255, 198/255, 208/255),
    "CG": (170/255, 129/255, 42/255),
    "IC": (97/255, 112/255, 83/255),
    "DB": (144/255, 63/255, 110/255),
}

# Model label (matplotlib mathtext) -> run directories under predictions/.
# Raw strings are required: \parallel, \oplus, ... are invalid Python escape
# sequences in ordinary string literals (SyntaxWarning, future error).
model_to_subdir = {
    r"IC$\parallel$": ["2023-07-16_10-34-32_train", "2023-07-18_13-49-57_train", "2023-07-19_12-17-46_train"],
    r"IC$\oplus$": ["2023-07-16_10-35-02_train", "2023-07-18_13-50-32_train", "2023-07-19_12-18-18_train"],
    r"IC$\otimes$": ["2023-07-16_10-35-41_train", "2023-07-18_13-52-26_train", "2023-07-19_12-18-49_train"],
    r"IC$\odot$": ["2023-07-16_10-36-04_train", "2023-07-18_13-53-03_train", "2023-07-19_12-19-50_train"],
    r"CG$\parallel$": ["2023-07-15_14-12-36_train", "2023-07-17_11-54-28_train", "2023-07-19_00-30-05_train"],
    r"CG$\oplus$": ["2023-07-15_14-14-08_train", "2023-07-17_11-56-05_train", "2023-07-19_00-30-47_train"],
    r"CG$\otimes$": ["2023-07-15_14-14-53_train", "2023-07-17_11-56-39_train", "2023-07-19_00-31-36_train"],
    r"CG$\odot$": ["2023-07-15_14-10-05_train", "2023-07-17_11-57-30_train", "2023-07-19_00-32-10_train"],
    "DB": ["2023-08-08_12-56-02_train", "2023-08-08_19-07-43_train", "2023-08-08_19-08-47_train"],
    "Base": ["2023-08-08_12-53-38_train", "2023-08-08_19-10-02_train", "2023-08-08_19-10-51_train"],
}

# Integer-valued CSV columns, in file order; the 12th (last) column is the
# free-text ``false_belief`` tag.
_CSV_INT_COLUMNS = (
    'frame',
    'm1_pred', 'm2_pred', 'm12_pred', 'm21_pred', 'mc_pred',
    'm1_label', 'm2_label', 'm12_label', 'm21_label', 'mc_label',
)

# Mind subsets scored for every run, in plotting order:
# first-order, second-order, and both together.
_MIND_GROUPS = (
    ['m1', 'm2'],
    ['m12', 'm21'],
    ['m1', 'm2', 'm12', 'm21'],
)

# Models whose value labels get a trailing '*'. The marker's meaning is not
# stated in this script (presumably a footnote/significance mark — confirm).
_STARRED_MODELS = {'Base', 'DB', r'IC$\odot$'}


def read_data_from_csv(subdirectory_path):
    """Read every ``*.csv`` file in *subdirectory_path* into a list of row dicts.

    The first row of each file is treated as a header and skipped; blank rows
    are ignored. Files are processed in sorted filename order so the output
    (and any per-frame logs derived from it) is deterministic.

    Returns:
        A list of dicts with the integer keys in ``_CSV_INT_COLUMNS`` plus the
        string key ``'false_belief'``.

    Raises:
        ValueError: if a row does not have exactly 12 columns or an integer
            column cannot be parsed.
    """
    print(subdirectory_path)
    n_columns = len(_CSV_INT_COLUMNS) + 1
    data = []
    csv_files = sorted(name for name in os.listdir(subdirectory_path) if name.endswith('.csv'))
    for csv_file in csv_files:
        file_path = os.path.join(subdirectory_path, csv_file)
        # newline='' is what the csv module requires for correct parsing.
        with open(file_path, 'r', newline='') as file:
            reader = csv.reader(file)
            next(reader, None)  # skip the header row
            for row in reader:
                if not row:
                    continue
                if len(row) != n_columns:
                    raise ValueError(f"{file_path}: expected {n_columns} columns, got {len(row)}: {row!r}")
                record = {name: int(value) for name, value in zip(_CSV_INT_COLUMNS, row)}
                record['false_belief'] = row[-1]
                data.append(record)
    return data


def _iter_false_belief(data, mind):
    """Yield ``(item, belief_type)`` for false-belief rows selected by *mind*.

    A row counts as false-belief when ``'false'`` occurs in its
    ``false_belief`` tag; its mind is the tag's prefix before the first '_'.
    *mind* is ``"all"``, a single mind name, or an iterable of mind names.
    A single name is matched exactly (not as a substring, so ``"m12"`` does
    not also select ``"m1"`` rows).
    """
    if mind == "all":
        selected = None
    elif isinstance(mind, str):
        selected = {mind}
    else:
        selected = set(mind)
    for item in data:
        if 'false' not in item['false_belief']:
            continue
        belief_type = item['false_belief'].split('_')[0]
        if selected is None or belief_type in selected:
            yield item, belief_type


def _fb_log_path(log_root, folder, mind):
    """Return the per-frame correctness log path, e.g. ``<root>/<folder>/fb_m1_m2.txt``."""
    suffix = '_'.join(mind) if isinstance(mind, list) else f"{mind}"
    return os.path.join(log_root, folder, f"fb_{suffix}.txt")


def compute_correct_false_belief(data, mind="all", folder=None, log_root="predictions"):
    """Return prediction accuracy on false-belief frames for the selected mind(s).

    Args:
        data: rows as returned by ``read_data_from_csv``.
        mind: ``"all"``, a single mind name, or a list of mind names.
        folder: when not None, one line per scored frame (``1`` correct,
            ``0`` wrong) is appended to ``<log_root>/<folder>/fb_<minds>.txt``.
            The file is only touched if at least one frame is scored.
        log_root: root directory for those logs (defaults to ``predictions``).

    Returns:
        correct / total as a float, or 0.0 when no frame is selected.
    """
    outcomes = [
        int(item[f"{belief_type}_pred"] == item[f"{belief_type}_label"])
        for item, belief_type in _iter_false_belief(data, mind)
    ]
    if folder is not None and outcomes:
        # Open once per call instead of once per row.
        with open(_fb_log_path(log_root, folder, mind), "a") as log_file:
            log_file.writelines(f"{outcome}\n" for outcome in outcomes)
    if not outcomes:
        return 0.0
    return sum(outcomes) / len(outcomes)


def compute_macro_f1_score(data, mind="all", folder=None):
    """Return the macro-averaged F1 score on false-belief frames for the selected mind(s).

    Args:
        data: rows as returned by ``read_data_from_csv``.
        mind: ``"all"``, a single mind name, or a list of mind names.
        folder: accepted (and ignored) so this function can be used
            interchangeably with ``compute_correct_false_belief``.

    Returns:
        ``sklearn.metrics.f1_score(..., average='macro')``, or 0.0 when no
        frame is selected.
    """
    pairs = [
        (int(item[f"{belief_type}_label"]), int(item[f"{belief_type}_pred"]))
        for item, belief_type in _iter_false_belief(data, mind)
    ]
    if not pairs:
        return 0.0
    y_true = [label for label, _ in pairs]
    y_pred = [pred for _, pred in pairs]
    return f1_score(y_true, y_pred, average='macro')


def delete_files_in_subfolders(folder_path, file_names_to_delete):
    """
    Delete specified files in all subfolders of a given folder.

    Parameters:
        folder_path: The path to the folder (walked recursively, including
            the folder itself).
        file_names_to_delete: A list of file names to be deleted.

    Returns:
        None
    """
    for root, _, _ in os.walk(folder_path):
        for file_name in file_names_to_delete:
            file_path = os.path.join(root, file_name)
            # EAFP: avoids the exists()/remove() race; missing files are fine.
            with contextlib.suppress(FileNotFoundError):
                os.remove(file_path)
                print(f"Deleted: {file_path}")


def _score_models(models, score_function, parent_dir):
    """Score every run of every model on each of ``_MIND_GROUPS``.

    Returns:
        One ``(means, stds)`` pair per mind group; ``means[i]`` / ``stds[i]``
        are the mean / population std (``np.std``, ddof=0) over the runs of
        ``models[i]``.
    """
    means = [[] for _ in _MIND_GROUPS]
    stds = [[] for _ in _MIND_GROUPS]
    for model in models:
        run_scores = [[] for _ in _MIND_GROUPS]
        for run_dir in model_to_subdir[model]:
            data = read_data_from_csv(os.path.join(parent_dir, run_dir))
            for scores, minds in zip(run_scores, _MIND_GROUPS):
                scores.append(score_function(data, minds, run_dir))
        for group_idx, scores in enumerate(run_scores):
            means[group_idx].append(np.mean(scores))
            stds[group_idx].append(np.std(scores))
    return list(zip(means, stds))


def _add_labels(bars, std_values, starred_indices):
    """Write each bar's height (2 decimals) just above its error bar; star selected bars."""
    for idx, (bar, std) in enumerate(zip(bars, std_values)):
        height = bar.get_height()
        marker = '*' if idx in starred_indices else ''
        plt.text(bar.get_x() + bar.get_width() / 2., height + std + 0.01,
                 f'{height:.2f}{marker}', ha='center', va='bottom', fontsize=10)


def main():
    """Recompute false-belief scores for all models and save the bar chart."""
    parent_dir = 'predictions'
    # Logs are appended to, so stale ones from a previous run must go first.
    delete_files_in_subfolders(parent_dir, [f"fb_{'_'.join(group)}.txt" for group in _MIND_GROUPS])

    metric = "Accuracy"
    score_functions = {
        "Macro F1": compute_macro_f1_score,
        "Accuracy": compute_correct_false_belief,
    }
    try:
        score_function = score_functions[metric]
    except KeyError:
        raise ValueError(f"Unknown metric {metric!r}; expected one of {sorted(score_functions)}") from None

    models = [
        'Base', 'DB',
        r'CG$\parallel$', r'CG$\oplus$', r'CG$\otimes$', r'CG$\odot$',
        r'IC$\parallel$', r'IC$\oplus$', r'IC$\otimes$', r'IC$\odot$',
    ]
    scores = _score_models(models, score_function, parent_dir)

    group_labels = ('First-order false belief', 'Second-order false belief', 'Both')
    offsets = (-BAR_WIDTH, 0.0, BAR_WIDTH)
    starred_indices = {i for i, model in enumerate(models) if model in _STARRED_MODELS}
    x = np.arange(len(models))

    plt.figure(figsize=(13, 3.5))
    drawn = []
    for label, offset, (means, stds) in zip(group_labels, offsets, scores):
        bars = plt.bar(x + offset, means, width=BAR_WIDTH, yerr=stds,
                       capsize=4, label=label, alpha=ALPHA)
        drawn.append((bars, stds))
    for bars, stds in drawn:
        _add_labels(bars, stds, starred_indices)

    ax = plt.gca()
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    plt.xlabel('MToMnet', fontsize=14)
    plt.ylabel('Macro F1 Score' if metric == "Macro F1" else 'Accuracy', fontsize=14)
    plt.xticks(x, models, rotation=0, fontsize=14)
    plt.yticks(fontsize=14)
    plt.legend(fontsize=14, loc='upper center', bbox_to_anchor=(0.5, 1.3), ncol=3)
    plt.tight_layout()

    os.makedirs('results', exist_ok=True)  # savefig does not create directories
    plt.savefig('results/false_belief_first_vs_second.pdf')


if __name__ == "__main__":
    main()