up
This commit is contained in:
parent
d4aaf7f4ad
commit
25b8b3f343
55 changed files with 7592 additions and 4 deletions
82
boss/plots/old_vs_new_bbox.py
Normal file
82
boss/plots/old_vs_new_bbox.py
Normal file
|
@ -0,0 +1,82 @@
|
|||
import json
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
|
||||
sns.set_theme(style='whitegrid')
|
||||
|
||||
COLORS = sns.color_palette()
|
||||
|
||||
MTOM_COLORS = {
|
||||
"MN1": (110/255, 117/255, 161/255),
|
||||
"MN2": (179/255, 106/255, 98/255),
|
||||
"Base": (193/255, 198/255, 208/255),
|
||||
"CG": (170/255, 129/255, 42/255),
|
||||
"IC": (97/255, 112/255, 83/255),
|
||||
"DB": (144/255, 63/255, 110/255)
|
||||
}
|
||||
|
||||
abl_tom_cm_concat = json.load(open('results/abl_cm_concat.json'))
|
||||
|
||||
abl_old_bbox_mean = [0.539406718, 0.5348262324, 0.529845863]
|
||||
abl_old_bbox_std = [0.03639921819, 0.01519544901, 0.01718265794]
|
||||
|
||||
filename = 'results/abl_cm_concat_old_vs_new_bbox'
|
||||
|
||||
#def plot_scores_histogram(data, filename, size=(8,6), rotation=0, colors=None):
|
||||
|
||||
means = []
|
||||
stds = []
|
||||
for key, values in abl_tom_cm_concat.items():
|
||||
mean = np.mean(values)
|
||||
std = np.std(values)
|
||||
means.append(mean)
|
||||
stds.append(std)
|
||||
fig, ax = plt.subplots(figsize=(10,4))
|
||||
x = np.arange(len(abl_tom_cm_concat))
|
||||
width = 0.6
|
||||
rects1 = ax.bar(
|
||||
x, means, width, label='New bbox', yerr=stds,
|
||||
capsize=5,
|
||||
color=[MTOM_COLORS['CG']]*8,
|
||||
edgecolor='black',
|
||||
linewidth=1.5,
|
||||
alpha=0.6
|
||||
)
|
||||
rects2 = ax.bar(
|
||||
[0, 2, 5], abl_old_bbox_mean, width, label='Original bbox', yerr=abl_old_bbox_std,
|
||||
capsize=5,
|
||||
color=[COLORS[9]]*8,
|
||||
edgecolor='black',
|
||||
linewidth=1.5,
|
||||
alpha=0.6
|
||||
)
|
||||
ax.set_ylabel('Accuracy', fontsize=18)
|
||||
xticklabels = list(abl_tom_cm_concat.keys())
|
||||
ax.set_xticks(np.arange(len(xticklabels)))
|
||||
ax.set_xticklabels(xticklabels, rotation=0, fontsize=16)
|
||||
ax.set_yticklabels(ax.get_yticklabels(), fontsize=16)
|
||||
# Add value labels above each bar, avoiding overlapping with error bars
|
||||
for rect, std in zip(rects1, stds):
|
||||
height = rect.get_height()
|
||||
if height + std < ax.get_ylim()[1]:
|
||||
ax.annotate(f'{height:.3f}', xy=(rect.get_x() + rect.get_width() / 2, height + std),
|
||||
xytext=(0, 5), textcoords="offset points", ha='center', va='bottom', fontsize=14)
|
||||
else:
|
||||
ax.annotate(f'{height:.3f}', xy=(rect.get_x() + rect.get_width() / 2, height - std),
|
||||
xytext=(0, -12), textcoords="offset points", ha='center', va='top', fontsize=14)
|
||||
for rect, std in zip(rects2, abl_old_bbox_std):
|
||||
height = rect.get_height()
|
||||
if height + std < ax.get_ylim()[1]:
|
||||
ax.annotate(f'{height:.3f}', xy=(rect.get_x() + rect.get_width() / 2, height + std),
|
||||
xytext=(0, 5), textcoords="offset points", ha='center', va='bottom', fontsize=14)
|
||||
else:
|
||||
ax.annotate(f'{height:.3f}', xy=(rect.get_x() + rect.get_width() / 2, height - std),
|
||||
xytext=(0, -12), textcoords="offset points", ha='center', va='top', fontsize=14)
|
||||
# Remove spines
|
||||
ax.spines['top'].set_visible(False)
|
||||
ax.spines['right'].set_visible(False)
|
||||
ax.legend(fontsize=14)
|
||||
ax.grid(axis='x')
|
||||
plt.tight_layout()
|
||||
plt.savefig(f'{filename}.pdf', bbox_inches='tight')
|
82
boss/plots/pca.py
Normal file
82
boss/plots/pca.py
Normal file
|
@ -0,0 +1,82 @@
|
|||
import torch
|
||||
import os
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from sklearn.decomposition import PCA
|
||||
|
||||
|
||||
FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-22_12-00-38_train_None" # no_tom seed 1
|
||||
|
||||
#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-17_23-39-41_train_None" # impl mult seed 1
|
||||
#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-18_12-47-07_train_None" # impl sum seed 1
|
||||
#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-18_15-58-44_train_None" # impl attn seed 1
|
||||
#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-18_12-58-04_train_None" # impl concat seed 1
|
||||
|
||||
#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-17_23-40-01_train_None" # cm mult seed 1
|
||||
#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-18_12-45-55_train_None" # cm sum seed 1
|
||||
#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-18_12-50-42_train_None" # cm attn seed 1
|
||||
#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-18_12-57-15_train_None" # cm concat seed 1
|
||||
|
||||
#FOLDER_PATH = "/scratch/bortoletto/dev/boss/predictions/2023-05-17_23-37-50_train_None" # db seed 1
|
||||
|
||||
print(FOLDER_PATH)
|
||||
|
||||
MTOM_COLORS = {
|
||||
"MN1": (110/255, 117/255, 161/255),
|
||||
"MN2": (179/255, 106/255, 98/255),
|
||||
"Base": (193/255, 198/255, 208/255),
|
||||
"CG": (170/255, 129/255, 42/255),
|
||||
"IC": (97/255, 112/255, 83/255),
|
||||
"DB": (144/255, 63/255, 110/255)
|
||||
}
|
||||
|
||||
sns.set_theme(style='white')
|
||||
|
||||
for i in range(60):
|
||||
|
||||
print(f'Computing analysis for test video {i}...', end='\r')
|
||||
|
||||
emb_file = os.path.join(FOLDER_PATH, f'{i}.pt')
|
||||
data = torch.load(emb_file)
|
||||
if len(data) == 14: # implicit
|
||||
model = 'impl'
|
||||
out_left, cell_left, out_right, cell_right, feats = data[0], data[1], data[2], data[3], data[4:]
|
||||
out_left = out_left.squeeze(0)
|
||||
cell_left = cell_left.squeeze(0)
|
||||
out_right = out_right.squeeze(0)
|
||||
cell_right = cell_right.squeeze(0)
|
||||
elif len(data) == 13: # common mind
|
||||
model = 'cm'
|
||||
out_left, out_right, common_mind, feats = data[0], data[1], data[2], data[3:]
|
||||
out_left = out_left.squeeze(0)
|
||||
out_right = out_right.squeeze(0)
|
||||
common_mind = common_mind.squeeze(0)
|
||||
elif len(data) == 12: # speaker-listener
|
||||
model = 'sl'
|
||||
out_left, out_right, feats = data[0], data[1], data[2:]
|
||||
out_left = out_left.squeeze(0)
|
||||
out_right = out_right.squeeze(0)
|
||||
else: raise ValueError("Data should have 14 (impl), 13 (cm) or 12 (sl) elements!")
|
||||
|
||||
# ====== PCA for left and right embeddings ====== #
|
||||
|
||||
out_left_and_right = np.concatenate((out_left, out_right), axis=0)
|
||||
|
||||
pca = PCA(n_components=2)
|
||||
pca_result = pca.fit_transform(out_left_and_right)
|
||||
|
||||
# Separate the PCA results for each tensor
|
||||
pca_result_left = pca_result[:out_left.shape[0]]
|
||||
pca_result_right = pca_result[out_right.shape[0]:]
|
||||
|
||||
plt.figure(figsize=(7,6))
|
||||
plt.scatter(pca_result_left[:, 0], pca_result_left[:, 1], label='$h_1$', color=MTOM_COLORS['MN1'], s=100)
|
||||
plt.scatter(pca_result_right[:, 0], pca_result_right[:, 1], label='$h_2$', color=MTOM_COLORS['MN2'], s=100)
|
||||
plt.xlabel('Principal Component 1', fontsize=30)
|
||||
plt.ylabel('Principal Component 2', fontsize=30)
|
||||
plt.grid(False)
|
||||
plt.legend(fontsize=30)
|
||||
plt.tight_layout()
|
||||
plt.savefig(f'{FOLDER_PATH}/{i}_pca.pdf')
|
||||
plt.close()
|
Loading…
Add table
Add a link
Reference in a new issue