82 lines
2.8 KiB
Python
82 lines
2.8 KiB
Python
|
import json
|
||
|
import numpy as np
|
||
|
import matplotlib.pyplot as plt
|
||
|
import seaborn as sns
|
||
|
|
||
|
# Apply seaborn's white-grid theme to every figure created by this script.
sns.set_theme(style='whitegrid')

# Current seaborn palette; entries are picked by index when colouring bars.
COLORS = sns.color_palette()

# Named RGB colours, written as 0-255 channel values scaled to the 0-1 range
# that matplotlib expects.
MTOM_COLORS = {
    "MN1": (110/255, 117/255, 161/255),
    "MN2": (179/255, 106/255, 98/255),
    "Base": (193/255, 198/255, 208/255),
    "CG": (170/255, 129/255, 42/255),
    "IC": (97/255, 112/255, 83/255),
    "DB": (144/255, 63/255, 110/255),
}

# Mapping loaded from JSON. The rest of the script averages each value and
# uses the keys as x tick labels.
# NOTE(review): presumably one list of per-run accuracies per configuration --
# confirm against whatever writes this file.
# The context manager closes the file handle deterministically instead of
# leaving it to the garbage collector.
with open('results/abl_cm_concat.json', encoding='utf-8') as results_file:
    abl_tom_cm_concat = json.load(results_file)

# Hard-coded means / standard deviations that the script plots as the
# 'Original bbox' bars.
abl_old_bbox_mean = [0.539406718, 0.5348262324, 0.529845863]
abl_old_bbox_std = [0.03639921819, 0.01519544901, 0.01718265794]

# Output path without extension; the suffix is appended when the figure is saved.
filename = 'results/abl_cm_concat_old_vs_new_bbox'
|
||
|
|
||
|
#def plot_scores_histogram(data, filename, size=(8,6), rotation=0, colors=None):
|
||
|
|
||
|
# Per-entry mean and standard deviation (np.std with its default settings),
# collected in the mapping's iteration order so they line up with the keys
# used as x tick labels.
means = [np.mean(scores) for scores in abl_tom_cm_concat.values()]
stds = [np.std(scores) for scores in abl_tom_cm_concat.values()]
|
||
|
def _annotate_bars(ax, rects, errors):
    """Label every bar with its height, formatted to three decimals.

    The label is placed just above the top of the error bar. If that point
    would lie at or beyond the current upper y-limit, it is placed below the
    lower end of the error bar instead.

    Args:
        ax: Axes that owns the bars and receives the annotations.
        rects: Bar container (or iterable of bar patches) returned by ``ax.bar``.
        errors: Error-bar half-lengths, one per bar, in the same order.
    """
    # Read once: the axis limit is not changed by the annotations added here.
    y_top = ax.get_ylim()[1]
    for rect, err in zip(rects, errors):
        height = rect.get_height()
        x_center = rect.get_x() + rect.get_width() / 2
        if height + err < y_top:
            ax.annotate(
                f'{height:.3f}', xy=(x_center, height + err),
                xytext=(0, 5), textcoords="offset points",
                ha='center', va='bottom', fontsize=14,
            )
        else:
            ax.annotate(
                f'{height:.3f}', xy=(x_center, height - err),
                xytext=(0, -12), textcoords="offset points",
                ha='center', va='top', fontsize=14,
            )


fig, ax = plt.subplots(figsize=(10, 4))

# One bar slot per entry in the loaded results.
x = np.arange(len(abl_tom_cm_concat))
width = 0.6

# A single colour applies to every bar, so no list sized to the bar count is needed.
rects1 = ax.bar(
    x, means, width, label='New bbox', yerr=stds,
    capsize=5,
    color=MTOM_COLORS['CG'],
    edgecolor='black',
    linewidth=1.5,
    alpha=0.6,
)

# Drawn at fixed slots 0, 2 and 5, i.e. on top of the 'New bbox' bars there.
# NOTE(review): presumably the configurations for which original-bbox results
# exist -- confirm the indices match the key order of the JSON file.
rects2 = ax.bar(
    [0, 2, 5], abl_old_bbox_mean, width, label='Original bbox', yerr=abl_old_bbox_std,
    capsize=5,
    color=COLORS[9],
    edgecolor='black',
    linewidth=1.5,
    alpha=0.6,
)

ax.set_ylabel('Accuracy', fontsize=18)
xticklabels = list(abl_tom_cm_concat.keys())
ax.set_xticks(x)
ax.set_xticklabels(xticklabels, rotation=0, fontsize=16)
# Resize the y tick labels without pinning their text: passing
# get_yticklabels() back into set_yticklabels() fixes the label strings while
# the tick positions stay automatic, which can mislabel or blank the axis.
ax.tick_params(axis='y', labelsize=16)

# Add value labels above each bar, avoiding overlapping with error bars
_annotate_bars(ax, rects1, stds)
_annotate_bars(ax, rects2, abl_old_bbox_std)

# Remove spines
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.legend(fontsize=14)
# Explicitly hide the vertical grid lines. The bare ax.grid(axis='x') call
# toggles the current state, which only hides them when the grid starts on.
ax.grid(False, axis='x')
plt.tight_layout()
plt.savefig(f'{filename}.pdf', bbox_inches='tight')
|