# mtomnet/boss/utils.py
# Source listing metadata: 769 lines, 31 KiB, Python; snapshot 2025-01-10 15:39:20 +01:00.
import os
import torch
import numpy as np
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
#from umap import UMAP
import json
import seaborn as sns
import pandas as pd
from natsort import natsorted
import argparse
import seaborn as sns
# Plot colours as (r, g, b) floats in [0, 1], converted from 8-bit channel
# values. Dividing by 255 (the largest 8-bit value) maps 255 exactly to 1.0;
# this palette previously divided by 256, inconsistent with MTOM_COLORS.
COLORS_DM = {
    "red": (242/255, 165/255, 179/255),
    "blue": (195/255, 219/255, 252/255),
    "green": (156/255, 228/255, 213/255),
    "yellow": (250/255, 236/255, 144/255),
    "violet": (207/255, 187/255, 244/255),
    "orange": (244/255, 188/255, 154/255)
}
# Per-model colours, keyed by the short model/component names used in the plots
# (MindNet 1/2, Base, CG, IC, DB).
MTOM_COLORS = {
    "MN1": (110/255, 117/255, 161/255),
    "MN2": (179/255, 106/255, 98/255),
    "Base": (193/255, 198/255, 208/255),
    "CG": (170/255, 129/255, 42/255),
    "IC": (97/255, 112/255, 83/255),
    "DB": (144/255, 63/255, 110/255)
}
# seaborn's current colour palette; indexed by position by the plotting code
# (e.g. COLORS[4] and COLORS[7] in the 'scores' task).
COLORS = sns.color_palette()
# Object vocabulary: class index -> object name. Index 0 is 'none'; there are
# 27 entries, matching the 27-way classifiers used for the belief predictions.
OBJECTS = (
    "none apple orange lemon potato wine wineopener knife mug peeler bowl "
    "chocolate sugar magazine cracker chips scissors cap marker sardinecan "
    "tomatocan plant walnut nail waterspray hammer canopener"
).split()
# Integer ids grouped into three lists by how often they were "tried"
# (once / twice / thrice). The lists hold 173, 120 and 6 ids respectively, all
# within 1..299.
# NOTE(review): what an id refers to (video / episode / participant pair) and
# what "tried" means is not evident from this file — confirm against the dataset.
tried_once = [2, 4, 5, 6, 7, 8, 9, 10, 12, 15, 18, 19, 20, 21, 22, 23, 24,
25, 26, 27, 28, 29, 30, 32, 38, 40, 41, 42, 44, 47, 50, 51, 52, 53, 55,
56, 57, 61, 63, 65, 67, 68, 70, 71, 72, 73, 74, 76, 80, 81, 83, 85, 87,
88, 89, 90, 92, 93, 96, 97, 99, 101, 102, 105, 106, 108, 110, 111, 112,
113, 114, 116, 118, 121, 123, 125, 131, 132, 134, 135, 140, 142, 143,
145, 146, 148, 149, 151, 152, 154, 155, 156, 157, 160, 161, 162, 165,
169, 170, 171, 173, 175, 176, 178, 179, 180, 181, 182, 183, 184, 185,
186, 187, 190, 191, 194, 196, 203, 204, 206, 207, 208, 209, 210, 211,
213, 214, 216, 218, 219, 220, 222, 225, 227, 228, 229, 232, 233, 235,
236, 237, 238, 239, 242, 243, 246, 247, 249, 251, 252, 254, 255, 256,
257, 260, 261, 262, 263, 265, 266, 268, 270, 272, 275, 277, 278, 279,
281, 282, 287, 290, 296, 298]
tried_twice = [1, 3, 11, 13, 14, 16, 17, 31, 33, 35, 36, 37, 39,
43, 45, 49, 54, 58, 59, 60, 62, 64, 66, 69, 75, 77, 79, 82, 84, 86,
91, 94, 95, 98, 100, 103, 104, 107, 109, 115, 117, 119, 120, 122,
124, 126, 127, 128, 129, 130, 133, 136, 137, 138, 139, 141, 144,
147, 150, 153, 158, 159, 164, 166, 167, 168, 172, 174, 177, 188,
189, 192, 193, 195, 197, 198, 199, 200, 201, 202, 205, 212, 215,
217, 221, 223, 224, 226, 230, 231, 234, 240, 241, 244, 245, 248,
250, 253, 258, 259, 264, 267, 269, 271, 273, 274, 276, 280, 283,
284, 285, 286, 288, 291, 292, 293, 294, 295, 297, 299]
tried_thrice = [34, 46, 48, 78, 163, 289]
# Id split into two halves: ids 150-299 are "friends", ids 0-149 "strangers".
# NOTE(review): these two lists are not referenced in the visible source.
friends = [i for i in range(150, 300)]
strangers = [i for i in range(0, 150)]
def get_input_dim(modalities):
    """Return the summed feature width of the requested input modalities.

    Args:
        modalities: iterable of modality names, each one of
            'rgb', 'pose', 'gaze', 'ocr', 'bbox'.

    Returns:
        int: total input dimension. An unknown name raises ``KeyError``.
    """
    width_of = {'rgb': 512, 'pose': 150, 'gaze': 6, 'ocr': 64, 'bbox': 108}
    total = 0
    for name in modalities:
        total += width_of[name]
    return total
def count_parameters(model):
    """Count the trainable parameters of a ``torch.nn.Module``.

    Only parameters with ``requires_grad=True`` are counted.

    Returns:
        int: total number of trainable scalar weights.
    """
    # ``numel()`` returns a Python int for every shape, including 0-dim
    # parameters. ``np.prod`` of an empty shape returns the float ``1.0``,
    # which silently turned the whole count into a float.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
def onehot(label, n_classes):
    """One-hot encode integer class labels.

    Args:
        label: integer (int64) tensor. Either 3-D ``[batch, seq_len, 2]``
            (a left and a right label per step) or a tensor whose
            ``size(0)`` elements flatten to one label each.
        n_classes: number of classes.

    Returns:
        Float tensor allocated on ``label``'s device:
        ``[batch, seq_len, 2 * n_classes]`` for 3-D input (left one-hot
        followed by right one-hot along the last axis), otherwise
        ``[label.size(0), n_classes]``.
    """
    if label.dim() == 3:
        batch_size, seq_len, _ = label.size()
        encoded = []
        for side in range(2):  # 0 = left label, 1 = right label
            # reshape (not view) also accepts non-contiguous slices.
            index = label[:, :, side].reshape(batch_size * seq_len, 1)
            # Allocate on label's device: scatter_ requires self and index on
            # the same device, so a CPU buffer failed for CUDA labels.
            encoded.append(
                torch.zeros(batch_size * seq_len, n_classes, device=label.device)
                .scatter_(1, index, 1)
                .view(-1, seq_len, n_classes)
            )
        return torch.cat(encoded, dim=2)
    flat_index = label.reshape(-1, 1)
    return torch.zeros(label.size(0), n_classes, device=label.device).scatter_(1, flat_index, 1)
def mixup(data, alpha, n_classes):
    """Apply mixup augmentation to a collated batch.

    One coefficient ``lam ~ Beta(alpha, alpha)`` is drawn per call, and every
    sample is blended with a randomly permuted partner from the same batch.

    Args:
        data: 7-tuple ``(frames, labels, poses, gazes, bboxes, ocr_graphs,
            sequence_lengths)``. ``poses``, ``gazes`` and ``bboxes`` may be
            ``None`` (``pad_collate`` returns ``None`` for absent fields).
        alpha: Beta-distribution concentration parameter.
        n_classes: number of classes used to one-hot encode ``labels``.

    Returns:
        The same 7-tuple with frames, labels (now mixed one-hot vectors),
        poses, gazes and bboxes mixed. ``ocr_graphs`` and
        ``sequence_lengths`` are passed through unchanged.
    """
    # A plain Python float works with tensors on any device. The previous
    # 1-element CPU FloatTensor cannot be combined with CUDA tensors.
    lam = float(np.random.beta(alpha, alpha))
    frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths = data
    indices = torch.randperm(labels.size(0))

    def _blend(x, x_partner):
        # Convex combination of a sample and its permuted partner.
        return x * lam + x_partner * (1 - lam)

    def _blend_optional(x):
        # Missing modalities stay None instead of failing on None[indices].
        return None if x is None else _blend(x, x[indices])

    labels = _blend(onehot(labels, n_classes), onehot(labels[indices], n_classes))
    frames = _blend(frames, frames[indices])
    poses = _blend_optional(poses)
    gazes = _blend_optional(gazes)
    bboxes = _blend_optional(bboxes)
    return frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths
def get_classification_accuracy(pred_left_labels, pred_right_labels, labels, sequence_lengths, n_classes=27):
    """Accuracy of left/right predictions over the valid steps of a padded batch.

    Args:
        pred_left_labels: left scores reshapeable to ``[batch, max_len, n_classes]``.
        pred_right_labels: right scores reshapeable to ``[batch, max_len, n_classes]``.
        labels: integer targets reshapeable to ``[batch, max_len, 2]``;
            ``[..., 0]`` is the left label and ``[..., 1]`` the right label.
        sequence_lengths: true length of every sequence; steps beyond it are
            padding and are ignored. ``max_len`` is ``max(sequence_lengths)``.
        n_classes: number of classes (default 27, the size of ``OBJECTS``).

    Returns:
        ``(acc, num_correct, num_pred)``. Each valid step contributes two
        predictions (left and right), so ``num_pred = 2 * sum(sequence_lengths)``.
    """
    max_len = max(sequence_lengths)
    pred_left = torch.reshape(pred_left_labels, (-1, max_len, n_classes))
    pred_right = torch.reshape(pred_right_labels, (-1, max_len, n_classes))
    targets = torch.reshape(labels, (-1, max_len, 2))
    left_correct = torch.argmax(pred_left, 2) == targets[:, :, 0]
    right_correct = torch.argmax(pred_right, 2) == targets[:, :, 1]
    num_pred = sum(sequence_lengths) * 2
    num_correct = 0
    for i, size in enumerate(sequence_lengths):
        # Only the first `size` steps of sequence i are real data.
        num_correct += (left_correct[i, :size].sum() + right_correct[i, :size].sum()).item()
    acc = num_correct / num_pred
    return acc, num_correct, num_pred
def get_classification_accuracy_mixup(pred_left_labels, pred_right_labels, labels, sequence_lengths, n_classes=27):
    """Accuracy against mixed/one-hot targets over the valid steps of a padded batch.

    Same as ``get_classification_accuracy``, but ``labels`` holds per-step
    class-score vectors (e.g. produced by ``mixup``). The target class is the
    argmax of each half.

    Args:
        pred_left_labels: left scores reshapeable to ``[batch, max_len, n_classes]``.
        pred_right_labels: right scores reshapeable to ``[batch, max_len, n_classes]``.
        labels: reshapeable to ``[batch, max_len, 2 * n_classes]``. The first
            ``n_classes`` entries are the left target and the rest the right target.
        sequence_lengths: true length of every sequence; padding is ignored.
        n_classes: number of classes (default 27, the size of ``OBJECTS``).

    Returns:
        ``(acc, num_correct, num_pred)`` with ``num_pred = 2 * sum(sequence_lengths)``.
    """
    max_len = max(sequence_lengths)
    pred_left = torch.reshape(pred_left_labels, (-1, max_len, n_classes))
    pred_right = torch.reshape(pred_right_labels, (-1, max_len, n_classes))
    targets = torch.reshape(labels, (-1, max_len, 2 * n_classes))
    left_correct = torch.argmax(pred_left, 2) == torch.argmax(targets[:, :, :n_classes], 2)
    right_correct = torch.argmax(pred_right, 2) == torch.argmax(targets[:, :, n_classes:], 2)
    num_pred = sum(sequence_lengths) * 2
    num_correct = 0
    for i, size in enumerate(sequence_lengths):
        # Only the first `size` steps of sequence i are real data.
        num_correct += (left_correct[i, :size].sum() + right_correct[i, :size].sum()).item()
    acc = num_correct / num_pred
    return acc, num_correct, num_pred
def pad_collate(batch):
    """Collate variable-length samples into batch-first padded tensors.

    Each sample is a 6-tuple of tensors whose first dimension is the sequence
    length. The field order presumably matches the unpacking in ``mixup``:
    frames, labels, then four optional modalities.

    Field 1 (labels) is padded with ``-1``; every other field with ``0``.
    Fields 2-5 become ``None`` when the first sample's entry is ``None``.

    Returns:
        7-tuple: the six padded fields followed by the list of per-sample
        lengths (taken from field 0).
    """
    first, second, third, fourth, fifth, sixth = zip(*batch)
    lengths = [len(seq) for seq in first]

    def pad_or_none(column):
        # Presence is decided by the first sample only.
        if column[0] is None:
            return None
        return pad_sequence(column, batch_first=True, padding_value=0)

    return (
        pad_sequence(first, batch_first=True, padding_value=0),
        pad_sequence(second, batch_first=True, padding_value=-1),
        *(pad_or_none(col) for col in (third, fourth, fifth, sixth)),
        lengths,
    )
def ocr_loss(pL, lL, pR, lR, ocr, loss, eta=10.0, mode='abs'):
    """
    Custom loss based on the negative OCR matrix.
    Input:
        pL: tensor of shape [batch_size, num_classes] representing left belief predictions
        lL: left belief labels, shape [batch_size] (class indices) or
            [batch_size, num_classes] (per-class vectors, e.g. from mixup)
        pR: tensor of shape [batch_size, num_classes] representing right belief predictions
        lR: right belief labels, same layout as lL
        ocr: negative OCR matrix (i.e. 1-OCR), indexed as ocr[left_class, right_class]
        loss: loss function applied as loss(pred, label); its output is averaged
        eta: unused. Kept for backward compatibility. The interaction term is
            always rescaled adaptively so that its value equals left + right loss.
        mode: 'abs' (mean absolute) or 'mse' (mean squared) OCR difference
    Output:
        Final loss: left loss + right loss + rescaled interaction term.
    Raises:
        NotImplementedError: if mode is not 'abs' or 'mse'.
    """
    if mode not in ('abs', 'mse'):
        raise NotImplementedError(f"Unknown interaction mode: {mode!r}")
    bs = pL.shape[0]
    # OCR coefficient of the predicted (argmax) left/right class pair.
    p = torch.tensor([ocr[torch.argmax(pL[i]), torch.argmax(pR[i])] for i in range(bs)], device=pL.device)
    if lR[0].dim() > 0:  # vector labels (mixup): use their argmax class
        label_pairs = [(torch.argmax(lL[i]), torch.argmax(lR[i])) for i in range(bs)]
    else:
        label_pairs = [(lL[i], lR[i]) for i in range(bs)]
    g = torch.tensor([ocr[a, b] for a, b in label_pairs], device=pL.device)
    left_loss = torch.mean(loss(pL, lL))
    right_loss = torch.mean(loss(pR, lR))
    if mode == 'abs':
        interaction_loss = torch.mean(torch.abs(g - p))
    else:
        interaction_loss = torch.mean(torch.pow(g - p, 2))
    # p and g are rebuilt from plain values via torch.tensor, so the
    # interaction term carries no gradient of its own. It is rescaled to equal
    # left_loss + right_loss. When it is exactly 0 (OCR of predictions matches
    # the labels for the whole batch), that division gave inf and 0 * inf = NaN,
    # poisoning the loss. In that case the term simply stays 0.
    if interaction_loss > 0:
        eta = (left_loss + right_loss) / interaction_loss
        interaction_loss = interaction_loss * eta
    print(f"Left: {left_loss} --- Right: {right_loss} --- Interaction: {interaction_loss}")
    return left_loss + right_loss + interaction_loss
def spherical2cartesial(x):
    """
    Convert pairs of angles to 3-D unit vectors.

    From https://colab.research.google.com/drive/1SJbzd-gFTbiYjfZynIfrG044fWi6svbV?usp=sharing#scrollTo=78QhNw4MYSYp

    Args:
        x: tensor of shape [N, 2] of angles in radians. Column 0 is the
           horizontal angle and column 1 the elevation, i.e. y = sin(x[:, 1]).
    Returns:
        Tensor [N, 3] with rows (cos(e)*sin(h), sin(e), -cos(e)*cos(h)).
        It is built from x with torch ops, so it stays on x's device and
        uses the dtype torch.cos yields for x. The previous version always
        allocated a CPU float32 buffer.
    """
    horizontal, elevation = x[:, 0], x[:, 1]
    cos_elev = torch.cos(elevation)
    return torch.stack(
        [cos_elev * torch.sin(horizontal), torch.sin(elevation), -cos_elev * torch.cos(horizontal)],
        dim=1,
    )
def presave(frame_path='/scratch/bortoletto/data/boss/test/frame',
            save_root='/scratch/bortoletto/data/boss/presaved128/',
            size=(128, 128)):
    """Preprocess every video's frames once and pickle them for fast loading.

    ``frame_path`` contains one sub-directory per video, each holding that
    video's frame images. Every frame is resized to ``size``, converted to a
    tensor and normalised with the mean/std below (the usual ImageNet
    statistics). The list of frame tensors of one video is pickled to
    ``<save_root>/<a>/<b>/<video_dir_name>.pkl``, where ``<a>/<b>`` are the last
    two components of ``frame_path`` (``test/frame`` for the default).

    Args:
        frame_path: directory of per-video frame folders.
        save_root: root directory for the pickles.
        size: target size passed to ``transforms.Resize``.
    """
    import os
    import pickle as pkl
    from natsort import natsorted
    from PIL import Image
    from torchvision import transforms

    preprocess = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    video_names = natsorted(os.listdir(frame_path))
    parent_name, leaf_name = frame_path.rstrip('/').split('/')[-2:]
    save_folder = os.path.join(save_root, parent_name, leaf_name)
    if not os.path.exists(save_folder):
        os.makedirs(save_folder)
        print(save_folder, 'created')
    for video_name in video_names:
        video_dir = os.path.join(frame_path, video_name)
        print(video_name, end='\r')
        images = []
        for frame_name in natsorted(os.listdir(video_dir)):
            # The context manager closes each image file once it is preprocessed.
            with Image.open(os.path.join(video_dir, frame_name)) as img:
                images.append(preprocess(img))
        # Name the pickle after the video directory itself. The old fixed path
        # index (strings[7]) only worked for one specific directory depth.
        with open(os.path.join(save_folder, video_name + '.pkl'), 'wb') as f:
            pkl.dump(images, f)
def compute_cosine_similarity(tensor1, tensor2):
    """Cosine similarity of two (broadcastable) tensors along their last axis."""
    feature_axis = -1
    return F.cosine_similarity(tensor1, tensor2, dim=feature_axis)
def find_most_similar_embedding(rgb, pose, gaze, ocr, bbox, repr):
    """For every position, pick the modality feature most similar to ``repr``.

    Cosine similarity is taken along the last axis between ``repr`` and each of
    the five modality tensors (standard broadcasting applies). The argmax over
    modalities is taken per position. On ties the earlier modality in the order
    gaze, pose, ocr, bbox, rgb wins.

    Returns:
        ``(main_modality, main_modality_name)``: two lists with one entry per
        position. ``main_modality`` holds the whole winning modality tensor
        (not the winning row). ``main_modality_name`` holds its name:
        'gaze', 'pose', 'ocr', 'bbox' or 'rgb'.
    """
    candidates = [('gaze', gaze), ('pose', pose), ('ocr', ocr), ('bbox', bbox), ('rgb', rgb)]
    scores = torch.stack([F.cosine_similarity(feat, repr, dim=-1) for _, feat in candidates])
    winners = torch.argmax(scores, dim=0)
    picked = [candidates[int(w)] for w in winners]
    return [feat for _, feat in picked], [name for name, _ in picked]
def plot_similarity_histogram(values, filename, labels):
    """Bar-plot how often each modality was the most similar one.

    Args:
        values: either a list of lists of modality names (grouped mode: one
            bar series per inner list, at most four because of the per-series
            styles below) or a flat list of names (simple mode).
        filename: output path for ``savefig``; the file format comes from
            its extension.
        labels: legend label per inner list (grouped mode only). Its length
            also sets the horizontal offset of each series.
    """
    fig, ax = plt.subplots(figsize=(12, 4))
    # Remove spines
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    if isinstance(values[0], list):
        modalities = ['rgb', 'pose', 'gaze', 'bbox', 'ocr']
        # Series 1-2: filled MN1/MN2 bars. Series 3-4: CG-coloured bars whose
        # thick outline uses the colour of series 1 and 2 respectively.
        colors = [MTOM_COLORS['MN1'], MTOM_COLORS['MN2'], MTOM_COLORS['CG'], MTOM_COLORS['CG']]
        edgecolors = ['black', 'black', MTOM_COLORS['MN1'], MTOM_COLORS['MN2']]
        linewidths = [1.0, 1.0, 5.0, 5.0]
        alpha = 0.6
        bar_width = 0.8 / len(values)
        x = np.arange(len(modalities))
        for i, series in enumerate(values):
            counts = [series.count(m) for m in modalities]
            # Centre the group of bars on each tick.
            x_shifted = x + (i - (len(labels) - 1) / 2) * bar_width
            ax.bar(x_shifted, counts, width=bar_width, label=f'{labels[i]}', color=colors[i],
                   edgecolor=edgecolors[i], linewidth=linewidths[i], alpha=alpha)
        ax.set_xlabel('Modality', fontsize=18)
        ax.set_ylabel('Counts', fontsize=18)
        ax.set_xticks(x)
        ax.set_xticklabels(modalities, fontsize=18)
        # tick_params keeps the automatic y locator. Re-setting labels from
        # get_yticklabels() can freeze unpopulated (empty) label texts.
        ax.tick_params(axis='y', labelsize=16)
        ax.legend(fontsize=18)
        ax.grid(axis='y')
    else:
        unique_strings, counts = np.unique(values, return_counts=True)
        ax.bar(unique_strings, counts)
        ax.set_xlabel('Modality')
        ax.set_ylabel('Counts')
    fig.savefig(filename, bbox_inches='tight')
    # Release the figure: repeated calls otherwise accumulate open figures.
    plt.close(fig)
def plot_scores_histogram(data, filename, size=(8,6), rotation=0, colors=None, xticklabels=None):
    """Bar chart of mean ± std per entry of ``data``, saved as ``<filename>.pdf``.

    Args:
        data: mapping from configuration name to a sequence of accuracy values.
            One bar per key, in insertion order.
        filename: output path without extension.
        size: figure size in inches.
        rotation: x tick label rotation in degrees.
        colors: bar colour or list of colours (default ``'teal'``).
        xticklabels: optional explicit tick labels, one per key. When omitted,
            a preset LaTeX label list is used if ``filename == 'results/all'``,
            and the dictionary keys otherwise.
    """
    means = [np.mean(v) for v in data.values()]
    stds = [np.std(v) for v in data.values()]
    fig, ax = plt.subplots(figsize=size)
    x = np.arange(len(data))
    width = 0.6
    bars = ax.bar(
        x, means, width, label='Mean', yerr=stds,
        capsize=5,
        color='teal' if colors is None else colors,
        edgecolor='black',
        linewidth=1.5,
        alpha=0.6
    )
    ax.set_ylabel('Accuracy', fontsize=18)
    if xticklabels is None:
        if filename == 'results/all':
            # Raw strings: '\o' and '\p' are invalid escape sequences in
            # ordinary string literals (SyntaxWarning on Python 3.12+).
            xticklabels = [
                'Base',
                'DB',
                r'CG$\otimes$', r'CG$\oplus$', r'CG$\odot$', r'CG$\parallel$',
                r'IC$\otimes$', r'IC$\oplus$', r'IC$\odot$', r'IC$\parallel$',
                'CNN', 'CNN+GRU', 'CNN+LSTM', 'CNN+Conv1D'
            ]
        else:
            xticklabels = list(data.keys())
    ax.set_xticks(np.arange(len(xticklabels)))
    ax.set_xticklabels(xticklabels, rotation=rotation, fontsize=16)
    # Resize y tick labels without replacing the automatic tick formatter.
    ax.tick_params(axis='y', labelsize=16)
    # Value labels: above the upper error bar when that fits under the y-limit,
    # otherwise below the lower error bar to avoid overlapping it.
    for rect, std in zip(bars, stds):
        height = rect.get_height()
        x_center = rect.get_x() + rect.get_width() / 2
        if height + std < ax.get_ylim()[1]:
            ax.annotate(f'{height:.3f}', xy=(x_center, height + std),
                        xytext=(0, 5), textcoords="offset points", ha='center', va='bottom', fontsize=14)
        else:
            ax.annotate(f'{height:.3f}', xy=(x_center, height - std),
                        xytext=(0, -12), textcoords="offset points", ha='center', va='top', fontsize=14)
    # Remove spines
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.grid(axis='x')
    fig.tight_layout()
    fig.savefig(f'{filename}.pdf', bbox_inches='tight')
    plt.close(fig)
def plot_confusion_matrices(left_cm, right_cm, labels, title, annot=True):
    """Draw the left and right confusion matrices side by side; save ``<title>.pdf``.

    Args:
        left_cm: 2-D count matrix for the left predictions (e.g. ``pd.crosstab``).
        right_cm: 2-D count matrix for the right predictions.
        labels: four sequences of class indices into ``OBJECTS``, in the order
            left x-ticks (predicted), left y-ticks (true), right x-ticks,
            right y-ticks.
        title: output path without the ``.pdf`` extension.
        annot: write the count into each cell.
    """
    _, axes = plt.subplots(1, 2, figsize=(14, 6))
    panels = (
        (axes[0], left_cm, labels[0], labels[1], 'Left Confusion Matrix'),
        (axes[1], right_cm, labels[2], labels[3], 'Right Confusion Matrix'),
    )
    for ax, matrix, pred_ids, true_ids, panel_title in panels:
        sns.heatmap(
            matrix,
            annot=annot,
            fmt='.0f',
            cmap='Blues',
            cbar=False,
            xticklabels=[OBJECTS[k] for k in pred_ids],
            yticklabels=[OBJECTS[k] for k in true_ids],
            ax=ax
        )
        ax.set_xlabel('Predicted')
        ax.set_ylabel('True')
        ax.set_title(panel_title)
    plt.tight_layout()
    plt.savefig(title + '.pdf')
    plt.close()
def plot_confusion_matrix(confusion_matrix, labels, title, annot=True):
    """Draw a single confusion matrix and save it to ``<title>.pdf``.

    Args:
        confusion_matrix: 2-D count matrix (e.g. ``pd.crosstab`` output).
        labels: two sequences of class indices into ``OBJECTS``:
            x-ticks (predicted) and y-ticks (true).
        title: output path without the ``.pdf`` extension.
        annot: write the count into each cell.
    """
    pred_ids, true_ids = labels[0], labels[1]
    fig, ax = plt.subplots(figsize=(6, 6))
    sns.heatmap(
        confusion_matrix,
        annot=annot,
        fmt='.0f',
        cmap='Blues',
        cbar=False,
        xticklabels=[OBJECTS[k] for k in pred_ids],
        yticklabels=[OBJECTS[k] for k in true_ids],
        ax=ax
    )
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    fig.tight_layout()
    fig.savefig(title + '.pdf')
    plt.close(fig)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # required=True: without a task the script used to fall through to a bare NameError.
    parser.add_argument('--task', type=str, required=True,
                        choices=['confusion', 'similarity', 'scores', 'friends_vs_strangers'])
    parser.add_argument('--folder_path', type=str)
    args = parser.parse_args()

    if args.task == 'similarity':
        sns.set_theme(style='white')
    else:
        sns.set_theme(style='whitegrid')

    # -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* #
    if args.task == 'similarity':
        # For each of the 60 test videos, load <folder_path>/<i>.pt, save a PCA
        # plot of the two agents' outputs, and record which input modality each
        # representation is most cosine-similar to. Finally plot the counts.
        out_left_main_mods_full_test = []
        out_right_main_mods_full_test = []
        cell_left_main_mods_full_test = []
        cell_right_main_mods_full_test = []
        cm_left_main_mods_full_test = []
        cm_right_main_mods_full_test = []
        for i in range(60):
            print(f'Computing analysis for test video {i}...', end='\r')
            emb_file = os.path.join(args.folder_path, f'{i}.pt')
            data = torch.load(emb_file)
            # The number of saved tensors identifies the model variant.
            if len(data) == 14:  # implicit
                model = 'impl'
                out_left, cell_left, out_right, cell_right, feats = data[0], data[1], data[2], data[3], data[4:]
                out_left = out_left.squeeze(0)
                cell_left = cell_left.squeeze(0)
                out_right = out_right.squeeze(0)
                cell_right = cell_right.squeeze(0)
            elif len(data) == 13:  # common mind
                model = 'cm'
                out_left, out_right, common_mind, feats = data[0], data[1], data[2], data[3:]
                out_left = out_left.squeeze(0)
                out_right = out_right.squeeze(0)
                common_mind = common_mind.squeeze(0)
            elif len(data) == 12:  # speaker-listener
                model = 'sl'
                out_left, out_right, feats = data[0], data[1], data[2:]
                out_left = out_left.squeeze(0)
                out_right = out_right.squeeze(0)
            else:
                raise ValueError("Data should have 14 (impl), 13 (cm) or 12 (sl) elements!")

            # ====== PCA for left and right embeddings ====== #
            out_left_and_right = np.concatenate((out_left, out_right), axis=0)
            pca = PCA(n_components=2)
            pca_result = pca.fit_transform(out_left_and_right)
            # Split the rows back: the right rows start after the *left* rows.
            # The old slice used the right length, which is only correct when
            # both agents have the same number of rows.
            pca_result_left = pca_result[:out_left.shape[0]]
            pca_result_right = pca_result[out_left.shape[0]:]
            plt.figure(figsize=(7, 6))
            plt.scatter(pca_result_left[:, 0], pca_result_left[:, 1], label='MindNet$_1$', color=MTOM_COLORS['MN1'], s=100)
            plt.scatter(pca_result_right[:, 0], pca_result_right[:, 1], label='MindNet$_2$', color=MTOM_COLORS['MN2'], s=100)
            plt.xlabel('Principal Component 1', fontsize=30)
            plt.ylabel('Principal Component 2', fontsize=30)
            plt.grid(False)
            plt.legend(fontsize=30)
            plt.tight_layout()
            plt.savefig(f'{args.folder_path}/{i}_pca.pdf')
            plt.close()

            # ====== Feature similarity ====== #
            if len(feats) == 10:
                left_rgb, left_ocr, left_pose, left_gaze, left_bbox, right_rgb, right_ocr, right_pose, right_gaze, right_bbox = feats
                left_rgb = left_rgb.squeeze(0)
                left_ocr = left_ocr.squeeze(0)
                left_pose = left_pose.squeeze(0)
                left_gaze = left_gaze.squeeze(0)
                left_bbox = left_bbox.squeeze(0)
                right_rgb = right_rgb.squeeze(0)
                right_ocr = right_ocr.squeeze(0)
                right_pose = right_pose.squeeze(0)
                right_gaze = right_gaze.squeeze(0)
                right_bbox = right_bbox.squeeze(0)
            else:
                raise NotImplementedError("Ablated versions are not supported yet.")
            # out: [1, seq_len, dim] --- squeeze ---> [seq_len, dim]
            # cell: [1, 1, dim] --------- squeeze ---> [1, dim]
            # cm: [1, seq_len, dim] --- squeeze ---> [seq_len, dim]
            # feat: [1, seq_len, dim] --- squeeze ---> [seq_len, dim]
            _, out_left_main_mods = find_most_similar_embedding(left_rgb, left_pose, left_gaze, left_ocr, left_bbox, out_left)
            _, out_right_main_mods = find_most_similar_embedding(right_rgb, right_pose, right_gaze, right_ocr, right_bbox, out_right)
            out_left_main_mods_full_test += out_left_main_mods
            out_right_main_mods_full_test += out_right_main_mods
            if model == 'impl':
                _, cell_left_main_mods = find_most_similar_embedding(left_rgb, left_pose, left_gaze, left_ocr, left_bbox, cell_left)
                _, cell_right_main_mods = find_most_similar_embedding(right_rgb, right_pose, right_gaze, right_ocr, right_bbox, cell_right)
                cell_left_main_mods_full_test += cell_left_main_mods
                cell_right_main_mods_full_test += cell_right_main_mods
            if model == 'cm':
                _, cm_left_main_mods = find_most_similar_embedding(left_rgb, left_pose, left_gaze, left_ocr, left_bbox, common_mind)
                _, cm_right_main_mods = find_most_similar_embedding(right_rgb, right_pose, right_gaze, right_ocr, right_bbox, common_mind)
                cm_left_main_mods_full_test += cm_left_main_mods
                cm_right_main_mods_full_test += cm_right_main_mods

        if model == 'impl':
            plot_similarity_histogram(
                [out_left_main_mods_full_test,
                 out_right_main_mods_full_test,
                 cell_left_main_mods_full_test,
                 cell_right_main_mods_full_test],
                f'{args.folder_path}/boss_similarity_impl_hist_all.pdf',
                [r'$h_1$', r'$h_2$', r'$c_1$', r'$c_2$']
            )
        elif model == 'cm':
            plot_similarity_histogram(
                [out_left_main_mods_full_test,
                 out_right_main_mods_full_test,
                 cm_left_main_mods_full_test,
                 cm_right_main_mods_full_test],
                f'{args.folder_path}/boss_similarity_cm_hist_all.pdf',
                [r'$h_1$', r'$h_2$', r'$cg$ w/ 1', r'$cg$ w/ 2']
            )
        elif model == 'sl':
            # Saved as .pdf like the other variants. A '.py' extension is not an
            # image format matplotlib can write.
            plot_similarity_histogram(
                [out_left_main_mods_full_test,
                 out_right_main_mods_full_test],
                f'{args.folder_path}/boss_similarity_sl_hist_all.pdf',
                [r'$h_1$', r'$h_2$']
            )

    # -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* #
    elif args.task == 'scores':
        # Context managers close the JSON files. `all_scores` avoids shadowing
        # the built-in `all`.
        with open('results/all.json') as f:
            all_scores = json.load(f)
        with open('results/abl_cm_concat.json') as f:
            abl_tom_cm_concat = json.load(f)
        plot_scores_histogram(
            all_scores,
            filename='results/all',
            size=(10, 4),
            rotation=45,
            colors=[COLORS[7]] + [MTOM_COLORS['DB']] + [MTOM_COLORS['CG']]*4 + [MTOM_COLORS['IC']]*4 + [COLORS[4]]*4
        )
        plot_scores_histogram(
            abl_tom_cm_concat,
            size=(10, 4),
            filename='results/abl_cm_concat',
            colors=[MTOM_COLORS['CG']]*8
        )

    # -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* #
    elif args.task == 'confusion':
        # Per-video confusion matrices from every CSV in folder_path, then
        # merged matrices over all videos (left/right separately and pooled).
        all_df = []
        left_matrices = []
        right_matrices = []
        for filename in natsorted(os.listdir(args.folder_path)):
            if filename.endswith('.csv'):
                file_path = os.path.join(args.folder_path, filename)
                print(f'Processing {file_path}...')
                df = pd.read_csv(file_path)
                # Extract the left and right labels and predictions
                left_labels = df['left_label']
                right_labels = df['right_label']
                left_preds = df['left_pred']
                right_preds = df['right_pred']
                # Confusion matrices: rows = true label, columns = prediction
                left_cm = pd.crosstab(left_labels, left_preds)
                right_cm = pd.crosstab(right_labels, right_preds)
                left_matrices.append(left_cm)
                right_matrices.append(right_cm)
                all_df.append(df)
        # Plot and save the confusion matrices for left and right
        for i, cm in enumerate(zip(left_matrices, right_matrices)):
            print(f'Computing confusion matrices for video {i}...', end='\r')
            plot_confusion_matrices(
                cm[0],
                cm[1],
                labels=[cm[0].columns, cm[0].index, cm[1].columns, cm[1].index],
                title=f'{args.folder_path}/{i}_cm'
            )
        merged_df = pd.concat(all_df).reset_index(drop=True)
        merged_left_cm = pd.crosstab(merged_df['left_label'], merged_df['left_pred'])
        merged_right_cm = pd.crosstab(merged_df['right_label'], merged_df['right_pred'])
        plot_confusion_matrices(
            merged_left_cm,
            merged_right_cm,
            labels=[merged_left_cm.columns, merged_left_cm.index, merged_right_cm.columns, merged_right_cm.index],
            title=f'{args.folder_path}/all_lr',
            annot=False
        )
        merged_preds = pd.concat([merged_df['left_pred'], merged_df['right_pred']]).reset_index(drop=True)
        merged_labels = pd.concat([merged_df['left_label'], merged_df['right_label']]).reset_index(drop=True)
        merged_all_cm = pd.crosstab(merged_labels, merged_preds)
        plot_confusion_matrix(
            merged_all_cm,
            labels=[merged_all_cm.columns, merged_all_cm.index],
            title=f'{args.folder_path}/all_merged',
            annot=False
        )

    # -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* #
    elif args.task == 'friends_vs_strangers':
        # friends ids: 30-59
        # stranger ids: 0-29
        n_strangers = 30
        out_left_list = []
        out_right_list = []
        cell_left_list = []
        cell_right_list = []
        common_mind_list = []
        for i in range(60):
            print(f'Computing analysis for test video {i}...', end='\r')
            emb_file = os.path.join(args.folder_path, f'{i}.pt')
            data = torch.load(emb_file)
            if len(data) == 14:  # implicit
                model = 'impl'
                out_left, cell_left, out_right, cell_right, feats = data[0], data[1], data[2], data[3], data[4:]
                out_left = out_left.squeeze(0)
                cell_left = cell_left.squeeze(0)
                out_right = out_right.squeeze(0)
                cell_right = cell_right.squeeze(0)
                out_left_list.append(out_left)
                out_right_list.append(out_right)
                cell_left_list.append(cell_left)
                cell_right_list.append(cell_right)
            elif len(data) == 13:  # common mind
                model = 'cm'
                out_left, out_right, common_mind, feats = data[0], data[1], data[2], data[3:]
                out_left = out_left.squeeze(0)
                out_right = out_right.squeeze(0)
                common_mind = common_mind.squeeze(0)
                out_left_list.append(out_left)
                out_right_list.append(out_right)
                common_mind_list.append(common_mind)
            elif len(data) == 12:  # speaker-listener
                model = 'sl'
                out_left, out_right, feats = data[0], data[1], data[2:]
                out_left = out_left.squeeze(0)
                out_right = out_right.squeeze(0)
                out_left_list.append(out_left)
                out_right_list.append(out_right)
            else:
                raise ValueError("Data should have 14 (impl), 13 (cm) or 12 (sl) elements!")

        # ====== PCA for left and right embeddings ====== #
        print('\rComputing PCA...')
        left = torch.cat(out_left_list, 0)
        right = torch.cat(out_right_list, 0)
        out_left_and_right = torch.cat([left, right], dim=0).numpy()
        pca = PCA(n_components=2)
        pca_result = pca.fit_transform(out_left_and_right)
        # Rows are [all left frames; all right frames], so the stranger videos
        # occupy the start of *both* halves. A single prefix slice previously
        # put the right agent's stranger frames into the friends group.
        n_left = left.shape[0]
        strangers_left = sum(t.shape[0] for t in out_left_list[:n_strangers])
        strangers_right = sum(t.shape[0] for t in out_right_list[:n_strangers])
        is_stranger = np.zeros(pca_result.shape[0], dtype=bool)
        is_stranger[:strangers_left] = True
        is_stranger[n_left:n_left + strangers_right] = True
        pca_result_strangers = pca_result[is_stranger]
        pca_result_friends = pca_result[~is_stranger]
        plt.scatter(pca_result_friends[:, 0], pca_result_friends[:, 1], label='Friends')
        plt.scatter(pca_result_strangers[:, 0], pca_result_strangers[:, 1], label='Strangers')
        plt.xlabel('Principal Component 1')
        plt.ylabel('Principal Component 2')
        plt.legend()
        plt.savefig(f'{args.folder_path}/friends_vs_strangers_pca.pdf')
        plt.close()

    # -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* #
    else:
        raise NameError(f'Unknown task: {args.task}')