mtomnet/boss/plots/old_vs_new_bbox.py

82 lines
2.8 KiB
Python
Raw Permalink Normal View History

2025-01-10 15:39:20 +01:00
import json
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
# Apply seaborn's white-grid theme globally (the context and palette given
# here are seaborn's own defaults, spelled out), then capture the palette it
# just activated so individual entries can be picked by index.
sns.set_theme(context='notebook', style='whitegrid', palette='deep')
COLORS = sns.color_palette()
# Named colours for the plots, written as 0-255 RGB triples and converted
# once to the 0-1 float tuples matplotlib accepts as colours.
_MTOM_RGB255 = {
    "MN1": (110, 117, 161),
    "MN2": (179, 106, 98),
    "Base": (193, 198, 208),
    "CG": (170, 129, 42),
    "IC": (97, 112, 83),
    "DB": (144, 63, 110),
}
MTOM_COLORS = {
    name: tuple(channel / 255 for channel in rgb)
    for name, rgb in _MTOM_RGB255.items()
}
# New-bbox results, read from JSON. The file is closed deterministically by
# the `with` block (the bare `open()` passed to json.load was never closed),
# and decoded as UTF-8, the encoding JSON text is specified to use.
# NOTE(review): presumably maps each x-axis condition label to a list of
# per-run accuracies -- confirm against the script that writes this file.
with open('results/abl_cm_concat.json', encoding='utf-8') as _fh:
    abl_tom_cm_concat = json.load(_fh)

# Original-bbox results, hard-coded as mean and standard deviation for each
# of the three bars drawn for them.
abl_old_bbox_mean = [0.539406718, 0.5348262324, 0.529845863]
abl_old_bbox_std = [0.03639921819, 0.01519544901, 0.01718265794]

# Output path stem; the '.pdf' extension is appended when the figure is saved.
filename = 'results/abl_cm_concat_old_vs_new_bbox'
# Mean and standard deviation of the scores for every entry of
# abl_tom_cm_concat, in the dict's insertion order -- the same order used for
# the bar x positions and the x tick labels, which are derived from this dict.
# np.std uses its default ddof=0 (population standard deviation).
# NOTE(review): confirm the hard-coded original-bbox stds were computed with
# the same ddof, otherwise the two sets of error bars are not comparable.
means = [np.mean(scores) for scores in abl_tom_cm_concat.values()]
stds = [np.std(scores) for scores in abl_tom_cm_concat.values()]
# One figure with two bar layers: a bar per condition for the new bounding
# boxes, plus original-bbox bars at a subset of the same x positions.
fig, ax = plt.subplots(figsize=(10, 4))
x = np.arange(len(abl_tom_cm_concat))
width = 0.6

# New-bbox bars with error bars. A single colour applies to every bar, so the
# call no longer depends on a hand-sized colour list matching the number of
# conditions in the JSON file.
rects1 = ax.bar(
    x, means, width, label='New bbox', yerr=stds,
    capsize=5,
    color=MTOM_COLORS['CG'],
    edgecolor='black',
    linewidth=1.5,
    alpha=0.6,
)

# Original-bbox bars. They share positions and width with the new-bbox bars
# at those slots, so (semi-transparent) they are drawn overlapping them.
# NOTE(review): positions [0, 2, 5] are hard-coded and must correspond to the
# key order of results/abl_cm_concat.json -- re-check if that file changes.
rects2 = ax.bar(
    [0, 2, 5], abl_old_bbox_mean, width, label='Original bbox',
    yerr=abl_old_bbox_std,
    capsize=5,
    color=COLORS[9],
    edgecolor='black',
    linewidth=1.5,
    alpha=0.6,
)
# Axis labelling. The x ticks are placed explicitly (one per condition), so
# giving them explicit labels is safe. The y axis keeps its automatic ticks;
# only the label size is changed, via tick_params. Re-setting the y labels
# from get_yticklabels() would freeze them with a FixedFormatter on top of
# an automatic locator, which matplotlib warns about.
ax.set_ylabel('Accuracy', fontsize=18)
xticklabels = list(abl_tom_cm_concat.keys())
ax.set_xticks(np.arange(len(xticklabels)))
ax.set_xticklabels(xticklabels, rotation=0, fontsize=16)
ax.tick_params(axis='y', labelsize=16)
def _annotate_bars(ax, rects, errs, fontsize=14):
    """Label each bar with its height (3 decimals), clear of its error bar.

    The label is placed 5 points above the top of the error bar when that
    point lies below the current upper y-limit; otherwise it is placed
    12 points below the bottom of the error bar, keeping it inside the axes.

    Args:
        ax: Axes the bars belong to.
        rects: Bar patches (as returned by ``ax.bar``).
        errs: Error-bar half-lengths, one per bar, in the same order.
        fontsize: Font size of the labels.
    """
    # Annotations do not affect autoscaling, so the limit is loop-invariant.
    y_top = ax.get_ylim()[1]
    for rect, err in zip(rects, errs):
        height = rect.get_height()
        x_center = rect.get_x() + rect.get_width() / 2
        if height + err < y_top:
            ax.annotate(f'{height:.3f}', xy=(x_center, height + err),
                        xytext=(0, 5), textcoords="offset points",
                        ha='center', va='bottom', fontsize=fontsize)
        else:
            ax.annotate(f'{height:.3f}', xy=(x_center, height - err),
                        xytext=(0, -12), textcoords="offset points",
                        ha='center', va='top', fontsize=fontsize)


# Add value labels above each bar, avoiding overlapping with error bars
_annotate_bars(ax, rects1, stds)
_annotate_bars(ax, rects2, abl_old_bbox_std)
# Remove the top and right frame lines.
for side in ('top', 'right'):
    ax.spines[side].set_visible(False)
ax.legend(fontsize=14)
# Hide the vertical (x-axis) grid lines. visible=False is passed explicitly:
# with no `visible` argument Axes.grid *toggles* the lines, so the result
# would depend on whether the active theme had them switched on.
ax.grid(False, axis='x')
fig.tight_layout()
fig.savefig(f'{filename}.pdf', bbox_inches='tight')