82 lines
No EOL
3.4 KiB
Python
82 lines
No EOL
3.4 KiB
Python
import torch
|
|
import os
|
|
import seaborn as sns
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
from sklearn.decomposition import PCA
|
|
|
|
|
|
# Root directory holding one sub-folder of per-video `.pt` predictions per run.
_PREDICTIONS_ROOT = "/scratch/bortoletto/dev/boss/predictions"

# Known prediction folders, keyed by the run label recorded in the original
# notes. Set ACTIVE_RUN to one of these keys to analyse a different run,
# instead of toggling commented-out assignments.
PREDICTION_RUNS = {
    "no_tom seed 1": f"{_PREDICTIONS_ROOT}/2023-05-22_12-00-38_train_None",
    "impl mult seed 1": f"{_PREDICTIONS_ROOT}/2023-05-17_23-39-41_train_None",
    "impl sum seed 1": f"{_PREDICTIONS_ROOT}/2023-05-18_12-47-07_train_None",
    "impl attn seed 1": f"{_PREDICTIONS_ROOT}/2023-05-18_15-58-44_train_None",
    "impl concat seed 1": f"{_PREDICTIONS_ROOT}/2023-05-18_12-58-04_train_None",
    "cm mult seed 1": f"{_PREDICTIONS_ROOT}/2023-05-17_23-40-01_train_None",
    "cm sum seed 1": f"{_PREDICTIONS_ROOT}/2023-05-18_12-45-55_train_None",
    "cm attn seed 1": f"{_PREDICTIONS_ROOT}/2023-05-18_12-50-42_train_None",
    "cm concat seed 1": f"{_PREDICTIONS_ROOT}/2023-05-18_12-57-15_train_None",
    "db seed 1": f"{_PREDICTIONS_ROOT}/2023-05-17_23-37-50_train_None",
}

# The run whose predictions this script reads and where it writes its plots.
ACTIVE_RUN = "no_tom seed 1"
FOLDER_PATH = PREDICTION_RUNS[ACTIVE_RUN]

print(FOLDER_PATH)
# Plot palette. Each colour is written as an 8-bit RGB triple and divided by
# 255 to give the 0-1 float tuples passed as `color=` to matplotlib.
MTOM_COLORS = {
    label: tuple(channel / 255 for channel in rgb8)
    for label, rgb8 in {
        "MN1": (110, 117, 161),
        "MN2": (179, 106, 98),
        "Base": (193, 198, 208),
        "CG": (170, 129, 42),
        "IC": (97, 112, 83),
        "DB": (144, 63, 110),
    }.items()
}
sns.set_theme(style='white')

# Number of per-video prediction files (0.pt ... 59.pt) read from FOLDER_PATH.
NUM_TEST_VIDEOS = 60

# Saved-element count -> (model tag, index of out_left, index of out_right).
# Layouts, as unpacked by the original script:
#   14 'impl': out_left, cell_left, out_right, cell_right, *feats
#   13 'cm':   out_left, out_right, common_mind, *feats
#   12 'sl':   out_left, out_right, *feats
_PREDICTION_LAYOUTS = {
    14: ('impl', 0, 2),  # implicit
    13: ('cm', 0, 1),    # common mind
    12: ('sl', 0, 1),    # speaker-listener
}


def _to_numpy(x):
    """Return *x* as a CPU NumPy array, detached from any autograd graph."""
    # np.concatenate cannot consume tensors that require grad or live on GPU.
    if torch.is_tensor(x):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def _unpack_hidden_states(data):
    """Return ``(model, out_left, out_right)`` from one loaded prediction file.

    The model variant is identified by ``len(data)`` (see _PREDICTION_LAYOUTS).
    Dimension 0 of each output is squeezed, a no-op unless it has size 1.

    Raises:
        ValueError: if ``len(data)`` matches none of the known layouts.
    """
    n_items = len(data)
    try:
        model, left_idx, right_idx = _PREDICTION_LAYOUTS[n_items]
    except KeyError:
        raise ValueError(
            "Data should have 14 (impl), 13 (cm) or 12 (sl) elements! "
            f"Got {n_items}."
        ) from None
    return model, data[left_idx].squeeze(0), data[right_idx].squeeze(0)


def _plot_pca(out_left, out_right, save_path):
    """Project left and right states onto one shared 2-D PCA basis and save a scatter plot.

    Assumes both inputs are 2-D (rows = samples) after squeezing; TODO confirm
    against how the prediction files are written.
    """
    left = _to_numpy(out_left)
    right = _to_numpy(out_right)

    # Fit a single PCA on the stacked rows so both sets share the same axes.
    pca_result = PCA(n_components=2).fit_transform(
        np.concatenate((left, right), axis=0)
    )

    # The first len(left) rows are the left states; everything after them is
    # the right states. (Splitting at len(right) is only right when the two
    # sets happen to have equal length.)
    n_left = left.shape[0]
    pca_result_left = pca_result[:n_left]
    pca_result_right = pca_result[n_left:]

    fig = plt.figure(figsize=(7, 6))
    plt.scatter(pca_result_left[:, 0], pca_result_left[:, 1], label='$h_1$', color=MTOM_COLORS['MN1'], s=100)
    plt.scatter(pca_result_right[:, 0], pca_result_right[:, 1], label='$h_2$', color=MTOM_COLORS['MN2'], s=100)
    plt.xlabel('Principal Component 1', fontsize=30)
    plt.ylabel('Principal Component 2', fontsize=30)
    plt.grid(False)
    plt.legend(fontsize=30)
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close(fig)


for i in range(NUM_TEST_VIDEOS):
    print(f'Computing analysis for test video {i}...', end='\r')

    emb_file = os.path.join(FOLDER_PATH, f'{i}.pt')
    # map_location='cpu' lets files saved from GPU runs load on any machine.
    data = torch.load(emb_file, map_location='cpu')
    model, out_left, out_right = _unpack_hidden_states(data)

    # ====== PCA for left and right embeddings ====== #
    _plot_pca(out_left, out_right, os.path.join(FOLDER_PATH, f'{i}_pca.pdf'))

print()  # terminate the carriage-return progress line