# mtomnet/tbd/utils/similarity.py
# (75 lines, 2.5 KiB, Python; last modified 2025-01-10 15:39:20 +01:00)
import os
import torch
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import seaborn as sns
# Folder that is scanned for '<index>.pt' files and that receives the
# generated '<index>_pca.pdf' plots. 'PATH_TO_FOLDER' is a placeholder.
FOLDER_PATH = 'PATH_TO_FOLDER'
print(FOLDER_PATH)

# Plot colours, written as 8-bit RGB triples and scaled to the 0-1 floats
# matplotlib expects. "MN1" / "MN2" colour the two scatter series below.
_MTOM_RGB_255 = {
    "MN1": (110, 117, 161),
    "MN2": (179, 106, 98),
    "Base": (193, 198, 208),
    "CG": (170, 129, 42),
    "IC": (97, 112, 83),
    "DB": (144, 63, 110),
}
MTOM_COLORS = {
    label: tuple(channel / 255 for channel in rgb)
    for label, rgb in _MTOM_RGB_255.items()
}

# NOTE: the palette is read *before* the theme is set; keep this order,
# since setting the theme may change the active palette.
COLORS = sns.color_palette()
sns.set_theme(style='white')

# Module-level accumulators (left/right variants of "out", "cell" and "cm").
# They are not filled in the visible part of this file.
out_left_main_mods_full_test, out_right_main_mods_full_test = [], []
cell_left_main_mods_full_test, cell_right_main_mods_full_test = [], []
cm_left_main_mods_full_test, cm_right_main_mods_full_test = [], []
# Produce one 2-D PCA scatter plot per saved embedding file.
# Only the *count* of '.pt' files is used, so the files are expected to be
# named '0.pt', '1.pt', ... without gaps.
num_videos = sum(1 for filename in os.listdir(FOLDER_PATH) if filename.endswith('.pt'))
for i in range(num_videos):
    print(f'Computing analysis for test video {i}...', end='\r')
    emb_file = os.path.join(FOLDER_PATH, f'{i}.pt')
    # NOTE(review): torch.load unpickles the file; only use trusted files.
    data = torch.load(emb_file)

    # The number of stored entries identifies which model wrote the file.
    if len(data) == 13:  # implicit
        model = 'impl'
        out_left, cell_left, out_right, cell_right, feats = data[0], data[1], data[2], data[3], data[4:]
    elif len(data) == 12:  # common mind
        model = 'cm'
        out_left, out_right, common_mind, feats = data[0], data[1], data[2], data[3:]
    elif len(data) == 11:  # speaker-listener
        model = 'sl'
        out_left, out_right, feats = data[0], data[1], data[2:]
    else:
        raise ValueError(
            f"Unexpected number of entries in {emb_file}: got {len(data)}, "
            "expected 13 (impl), 12 (cm) or 11 (sl)"
        )

    # ====== PCA for left and right embeddings ====== #
    # Flatten the first element of each output to rows of width 64.
    # Assumes the entries can be consumed by np.concatenate (e.g. CPU
    # tensors without grad, or numpy arrays) -- TODO confirm.
    out_left_pca = out_left[0].reshape(-1, 64)
    out_right_pca = out_right[0].reshape(-1, 64)
    out_left_and_right = np.concatenate((out_left_pca, out_right_pca), axis=0)

    # Fit one PCA on the stacked rows so both sides share the same axes.
    pca = PCA(n_components=2)
    pca_result = pca.fit_transform(out_left_and_right)

    # Separate the PCA results for each tensor. The left rows come first in
    # the stacked array, so the split point is the number of LEFT rows for
    # both slices (using the right row count is wrong when sizes differ).
    n_left = out_left_pca.shape[0]
    pca_result_left = pca_result[:n_left]
    pca_result_right = pca_result[n_left:]

    plt.figure(figsize=(6.8, 6))
    plt.scatter(pca_result_left[:, 0], pca_result_left[:, 1], label='$h_1$', color=MTOM_COLORS['MN1'], s=100)
    plt.scatter(pca_result_right[:, 0], pca_result_right[:, 1], label='$h_2$', color=MTOM_COLORS['MN2'], s=100)
    plt.xlabel('Principal Component 1', fontsize=32)
    plt.ylabel('Principal Component 2', fontsize=32)
    # Set tick locations and font size in one call so the size is applied to
    # the final set of tick labels.
    plt.xticks([-0.4, -0.2, 0.0, 0.2, 0.4], fontsize=24)
    plt.yticks(fontsize=24)
    plt.legend(fontsize=32)
    plt.tight_layout()
    plt.savefig(os.path.join(FOLDER_PATH, f'{i}_pca.pdf'))
    # Close the figure so figures do not accumulate across iterations.
    plt.close()