import os
import torch
import numpy as np
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
#from umap import UMAP
import json
import seaborn as sns
import pandas as pd
from natsort import natsorted
import argparse

COLORS_DM = {
    "red": (242/256, 165/256, 179/256),
    "blue": (195/256, 219/256, 252/256),
    "green": (156/256, 228/256, 213/256),
    "yellow": (250/256, 236/256, 144/256),
    "violet": (207/256, 187/256, 244/256),
    "orange": (244/256, 188/256, 154/256)
}

MTOM_COLORS = {
    "MN1": (110/255, 117/255, 161/255),
    "MN2": (179/255, 106/255, 98/255),
    "Base": (193/255, 198/255, 208/255),
    "CG": (170/255, 129/255, 42/255),
    "IC": (97/255, 112/255, 83/255),
    "DB": (144/255, 63/255, 110/255)
}

COLORS = sns.color_palette()

OBJECTS = [
    'none', 'apple', 'orange', 'lemon', 'potato', 'wine', 'wineopener',
    'knife', 'mug', 'peeler', 'bowl', 'chocolate', 'sugar', 'magazine',
    'cracker', 'chips', 'scissors', 'cap', 'marker', 'sardinecan', 'tomatocan',
    'plant', 'walnut', 'nail', 'waterspray', 'hammer', 'canopener'
]

tried_once = [2, 4, 5, 6, 7, 8, 9, 10, 12, 15, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 32, 38, 40, 41, 42, 44, 47, 50, 51, 52, 53, 55, 56, 57, 61, 63, 65, 67, 68, 70, 71, 72, 73, 74, 76, 80, 81, 83, 85, 87, 88, 89, 90, 92, 93, 96, 97, 99, 101, 102, 105, 106, 108, 110, 111, 112, 113, 114, 116, 118, 121, 123, 125, 131, 132, 134, 135, 140, 142, 143, 145, 146, 148, 149, 151, 152, 154, 155, 156, 157, 160, 161, 162, 165, 169, 170, 171, 173, 175, 176, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 190, 191, 194, 196, 203, 204, 206, 207, 208, 209, 210, 211, 213, 214, 216, 218, 219, 220, 222, 225, 227, 228, 229, 232, 233, 235, 236, 237, 238, 239, 242, 243, 246, 247, 249, 251, 252, 254, 255, 256, 257, 260, 261, 262, 263, 265, 266, 268, 270, 272, 275, 277, 278, 279, 281, 282, 287, 290, 296, 298]
tried_twice = [1, 3, 11, 13, 14, 16, 17, 31, 33, 35, 36, 37, 39, 43, 45, 49, 54, 58, 59, 60, 62, 64, 66, 69, 75, 77, 79, 82, 84, 86, 91, 94, 95, 98, 100, 103, 104, 107, 109, 115, 117, 119, 120, 122, 124, 126, 127, 128, 129, 130, 133, 136, 137, 138, 139, 141, 144, 147, 150, 153, 158, 159, 164, 166, 167, 168, 172, 174, 177, 188, 189, 192, 193, 195, 197, 198, 199, 200, 201, 202, 205, 212, 215, 217, 221, 223, 224, 226, 230, 231, 234, 240, 241, 244, 245, 248, 250, 253, 258, 259, 264, 267, 269, 271, 273, 274, 276, 280, 283, 284, 285, 286, 288, 291, 292, 293, 294, 295, 297, 299]
tried_thrice = [34, 46, 48, 78, 163, 289]

friends = [i for i in range(150, 300)]
strangers = [i for i in range(0, 150)]


def get_input_dim(modalities):
    dimensions = {
        'rgb': 512,
        'pose': 150,
        'gaze': 6,
        'ocr': 64,
        'bbox': 108
    }
    sum_of_dimensions = sum(dimensions[modality] for modality in modalities)
    return sum_of_dimensions


def count_parameters(model):
    #return sum(p.numel() for p in model.parameters() if p.requires_grad)
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    return sum([np.prod(p.size()) for p in model_parameters])
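
# Usage sketch (illustrative only; not executed anywhere in this file):
#   get_input_dim(['rgb', 'pose', 'gaze'])    # -> 668  (512 + 150 + 6)
#   count_parameters(torch.nn.Linear(10, 5))  # -> 55   (10*5 weights + 5 biases)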


def onehot(label, n_classes):
    if len(label.size()) == 3:
        batch_size, seq_len, _ = label.size()
        label_1_2d = label[:,:,0].contiguous().view(batch_size*seq_len, 1)
        label_2_2d = label[:,:,1].contiguous().view(batch_size*seq_len, 1)
        #onehot_1_2d = torch.zeros(batch_size*seq_len, n_classes).scatter_(1, label_1_2d, 1)
        #onehot_2_2d = torch.zeros(batch_size*seq_len, n_classes).scatter_(1, label_2_2d, 1)
        #onehot_2d = torch.cat((onehot_1_2d, onehot_2_2d), dim=1)
        #onehot_3d = onehot_2d.view(batch_size, seq_len, 2*n_classes)
        onehot_1_2d = torch.zeros(batch_size*seq_len, n_classes).scatter_(1, label_1_2d, 1).view(-1, seq_len, n_classes)
        onehot_2_2d = torch.zeros(batch_size*seq_len, n_classes).scatter_(1, label_2_2d, 1).view(-1, seq_len, n_classes)
        onehot_3d = torch.cat((onehot_1_2d, onehot_2_2d), dim=2)
        return onehot_3d
    else:
        return torch.zeros(label.size(0), n_classes).scatter_(1, label.view(-1, 1), 1)


def mixup(data, alpha, n_classes):
    lam = torch.FloatTensor([np.random.beta(alpha, alpha)])
    frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths = data
    indices = torch.randperm(labels.size(0))
    # labels
    labels2 = labels[indices]
    labels = onehot(labels, n_classes)
    labels2 = onehot(labels2, n_classes)
    labels = labels * lam + labels2 * (1 - lam)
    # frames
    frames2 = frames[indices]
    frames = frames * lam + frames2 * (1 - lam)
    # poses
    poses2 = poses[indices]
    poses = poses * lam + poses2 * (1 - lam)
    # gazes
    gazes2 = gazes[indices]
    gazes = gazes * lam + gazes2 * (1 - lam)
    # bboxes
    bboxes2 = bboxes[indices]
    bboxes = bboxes * lam + bboxes2 * (1 - lam)
    return frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths


def get_classification_accuracy(pred_left_labels, pred_right_labels, labels, sequence_lengths):
    max_len = max(sequence_lengths)
    pred_left_labels = torch.reshape(pred_left_labels, (-1, max_len, 27))
    pred_right_labels = torch.reshape(pred_right_labels, (-1, max_len, 27))
    labels = torch.reshape(labels, (-1, max_len, 2))
    left_correct = torch.argmax(pred_left_labels, 2) == labels[:,:,0]
    right_correct = torch.argmax(pred_right_labels, 2) == labels[:,:,1]
    num_pred = sum(sequence_lengths) * 2
    num_correct = 0
    for i in range(len(sequence_lengths)):
        size = sequence_lengths[i]
        num_correct += (torch.sum(left_correct[i][:size]) + torch.sum(right_correct[i][:size])).item()
    acc = num_correct / num_pred
    return acc, num_correct, num_pred


def get_classification_accuracy_mixup(pred_left_labels, pred_right_labels, labels, sequence_lengths):
    max_len = max(sequence_lengths)
    pred_left_labels = torch.reshape(pred_left_labels, (-1, max_len, 27))
    pred_right_labels = torch.reshape(pred_right_labels, (-1, max_len, 27))
    labels = torch.reshape(labels, (-1, max_len, 54))
    left_correct = torch.argmax(pred_left_labels, 2) == torch.argmax(labels[:,:,:27], 2)
    right_correct = torch.argmax(pred_right_labels, 2) == torch.argmax(labels[:,:,27:], 2)
    num_pred = sum(sequence_lengths) * 2
    num_correct = 0
    for i in range(len(sequence_lengths)):
        size = sequence_lengths[i]
        num_correct += (torch.sum(left_correct[i][:size]) + torch.sum(right_correct[i][:size])).item()
    acc = num_correct / num_pred
    return acc, num_correct, num_pred


def pad_collate(batch):
    (aa, bb, cc, dd, ee, ff) = zip(*batch)
    seq_lens = [len(a) for a in aa]
    aa_pad = pad_sequence(aa, batch_first=True, padding_value=0)
    bb_pad = pad_sequence(bb, batch_first=True, padding_value=-1)
    if cc[0] is not None:
        cc_pad = pad_sequence(cc, batch_first=True, padding_value=0)
    else:
        cc_pad = None
    if dd[0] is not None:
        dd_pad = pad_sequence(dd, batch_first=True, padding_value=0)
    else:
        dd_pad = None
    if ee[0] is not None:
        ee_pad = pad_sequence(ee, batch_first=True, padding_value=0)
    else:
        ee_pad = None
    if ff[0] is not None:
        ff_pad = pad_sequence(ff, batch_first=True, padding_value=0)
    else:
        ff_pad = None
    return aa_pad, bb_pad, cc_pad, dd_pad, ee_pad, ff_pad, seq_lens
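
# Usage sketch for pad_collate (illustrative; assumes a Dataset whose __getitem__
# returns a 6-tuple of per-frame tensors, some of which may be None):
#   loader = torch.utils.data.DataLoader(dataset, batch_size=4, collate_fn=pad_collate)
#   frames, labels, poses, gazes, bboxes, ocr, seq_lens = next(iter(loader))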


def ocr_loss(pL, lL, pR, lR, ocr, loss, eta=10.0, mode='abs'):
    """
    Custom loss based on the negative OCR matrix.
    Input:
        pL: tensor of shape [batch_size, num_classes] representing left belief predictions
        lL: tensor of shape [batch_size] representing the left belief labels
        pR: tensor of shape [batch_size, num_classes] representing right belief predictions
        lR: tensor of shape [batch_size] representing the right belief labels
        ocr: negative OCR matrix (i.e. 1-OCR)
        loss: loss function
        eta: hyperparameter for the interaction term
        mode: 'abs' or 'mse', form of the interaction term
    Output:
        Final loss resulting from weighting left and right loss with the respective
        OCR coefficients and summing them.
    """
    bs = pL.shape[0]
    if len([*lR[0].size()]) > 0: # for mixup
        p = torch.tensor([ocr[torch.argmax(pL[i]), torch.argmax(pR[i])] for i in range(bs)], device=pL.device)
        g = torch.tensor([ocr[torch.argmax(lL[i]), torch.argmax(lR[i])] for i in range(bs)], device=pL.device)
    else:
        p = torch.tensor([ocr[torch.argmax(pL[i]), torch.argmax(pR[i])] for i in range(bs)], device=pL.device)
        g = torch.tensor([ocr[lL[i], lR[i]] for i in range(bs)], device=pL.device)
    left_loss = torch.mean(loss(pL, lL))
    right_loss = torch.mean(loss(pR, lR))
    if mode == 'abs':
        interaction_loss = torch.mean(torch.abs(g - p))
    elif mode == 'mse':
        interaction_loss = torch.mean(torch.pow(g - p, 2))
    else:
        raise NotImplementedError
    # Note: this adaptive weighting overrides the fixed eta argument.
    eta = (left_loss + right_loss) / interaction_loss
    interaction_loss = interaction_loss * eta
    print(f"Left: {left_loss} --- Right: {right_loss} --- Interaction: {interaction_loss}")
    return left_loss + right_loss + interaction_loss


def spherical2cartesial(x):
    """
    From https://colab.research.google.com/drive/1SJbzd-gFTbiYjfZynIfrG044fWi6svbV?usp=sharing#scrollTo=78QhNw4MYSYp
    """
    output = torch.zeros(x.size(0), 3)
    output[:,2] = -torch.cos(x[:,1])*torch.cos(x[:,0])
    output[:,0] = torch.cos(x[:,1])*torch.sin(x[:,0])
    output[:,1] = torch.sin(x[:,1])
    return output


def presave():
    from PIL import Image
    from torchvision import transforms
    import pickle as pkl
    preprocess = transforms.Compose([
        transforms.Resize((128, 128)),
        #transforms.Resize(256),
        #transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    frame_path = '/scratch/bortoletto/data/boss/test/frame'
    frame_dirs = os.listdir(frame_path)
    frame_paths = []
    for frame_dir in natsorted(frame_dirs):
        paths = [os.path.join(frame_path, frame_dir, i) for i in natsorted(os.listdir(os.path.join(frame_path, frame_dir)))]
        frame_paths.append(paths)
    save_folder = '/scratch/bortoletto/data/boss/presaved128/'+frame_path.split('/')[-2]+'/'+frame_path.split('/')[-1]
    if not os.path.exists(save_folder):
        os.makedirs(save_folder)
        print(save_folder, 'created')
    for video in natsorted(frame_paths):
        print(video[0].split('/')[-2], end='\r')
        images = [preprocess(Image.open(i)) for i in video]
        strings = video[0].split('/')
        # strings[7] is the video directory name in this path layout
        with open(save_folder+'/'+strings[7]+'.pkl', 'wb') as f:
            pkl.dump(images, f)


def compute_cosine_similarity(tensor1, tensor2):
    return F.cosine_similarity(tensor1, tensor2, dim=-1)
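
# Example (shapes only; values are illustrative):
#   a, b = torch.randn(10, 64), torch.randn(10, 64)
#   compute_cosine_similarity(a, b).shape   # -> torch.Size([10]), one score per row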


def find_most_similar_embedding(rgb, pose, gaze, ocr, bbox, repr):
    gaze_similarity = compute_cosine_similarity(gaze, repr)
    pose_similarity = compute_cosine_similarity(pose, repr)
    ocr_similarity = compute_cosine_similarity(ocr, repr)
    bbox_similarity = compute_cosine_similarity(bbox, repr)
    rgb_similarity = compute_cosine_similarity(rgb, repr)
    similarities = torch.stack([gaze_similarity, pose_similarity, ocr_similarity, bbox_similarity, rgb_similarity])
    max_index = torch.argmax(similarities, dim=0)
    main_modality = []
    main_modality_name = []
    for idx in max_index:
        if idx == 0:
            main_modality.append(gaze)
            main_modality_name.append('gaze')
        elif idx == 1:
            main_modality.append(pose)
            main_modality_name.append('pose')
        elif idx == 2:
            main_modality.append(ocr)
            main_modality_name.append('ocr')
        elif idx == 3:
            main_modality.append(bbox)
            main_modality_name.append('bbox')
        else:
            main_modality.append(rgb)
            main_modality_name.append('rgb')
    return main_modality, main_modality_name


def plot_similarity_histogram(values, filename, labels):
    def count_elements(list_of_strings, target_list):
        counts = []
        for element in list_of_strings:
            count = target_list.count(element)
            counts.append(count)
        return counts
    fig, ax = plt.subplots(figsize=(12,4))
    colors = ["red", "blue", "green", "violet"]
    # Remove spines
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    colors = [MTOM_COLORS['MN1'], MTOM_COLORS['MN2'], MTOM_COLORS['CG'], MTOM_COLORS['CG']]
    alphas = [0.6, 0.6, 0.6, 0.6]
    edgecolors = ['black', 'black', MTOM_COLORS['MN1'], MTOM_COLORS['MN2']]
    linewidths = [1.0, 1.0, 5.0, 5.0]
    if isinstance(values[0], list):
        num_lists = len(values)
        bar_width = 0.8 / num_lists
        for i, val in enumerate(values):
            unique_strings = ['rgb', 'pose', 'gaze', 'bbox', 'ocr']
            counts = count_elements(unique_strings, val)
            x = np.arange(len(unique_strings))
            x_shifted = x + (i - (len(labels) - 1) / 2) * bar_width #x - (bar_width * num_lists / 2) + (bar_width * i)
            ax.bar(x_shifted, counts, width=bar_width, label=f'{labels[i]}', color=colors[i], edgecolor=edgecolors[i], linewidth=linewidths[i], alpha=alphas[i])
        ax.set_xlabel('Modality', fontsize=18)
        ax.set_ylabel('Counts', fontsize=18)
        ax.set_xticks(np.arange(len(unique_strings)))
        ax.set_xticklabels(unique_strings, fontsize=18)
        ax.set_yticklabels(ax.get_yticklabels(), fontsize=16)
        ax.legend(fontsize=18)
        ax.grid(axis='y')
        plt.savefig(filename, bbox_inches='tight')
    else:
        unique_strings, counts = np.unique(values, return_counts=True)
        ax.bar(unique_strings, counts)
        ax.set_xlabel('Modality')
        ax.set_ylabel('Counts')
        plt.savefig(filename, bbox_inches='tight')
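
# Usage sketch (mirrors the calls in the `similarity` task below; list and file
# names are illustrative):
#   plot_similarity_histogram([names_h1, names_h2], 'similarity_hist.pdf', [r'$h_1$', r'$h_2$'])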


def plot_scores_histogram(data, filename, size=(8,6), rotation=0, colors=None):
    means = []
    stds = []
    for key, values in data.items():
        mean = np.mean(values)
        std = np.std(values)
        means.append(mean)
        stds.append(std)
    fig, ax = plt.subplots(figsize=size)
    x = np.arange(len(data))
    width = 0.6
    rects1 = ax.bar(
        x, means, width,
        label='Mean',
        yerr=stds,
        capsize=5,
        color='teal' if colors is None else colors,
        edgecolor='black',
        linewidth=1.5,
        alpha=0.6
    )
    ax.set_ylabel('Accuracy', fontsize=18)
    #ax.set_title('Mean and Standard Deviation of Results', fontsize=14)
    if filename == 'results/all':
        xticklabels = [
            'Base', 'DB',
            r'CG$\otimes$', r'CG$\oplus$', r'CG$\odot$', r'CG$\parallel$',
            r'IC$\otimes$', r'IC$\oplus$', r'IC$\odot$', r'IC$\parallel$',
            'CNN', 'CNN+GRU', 'CNN+LSTM', 'CNN+Conv1D'
        ]
    else:
        xticklabels = list(data.keys())
    ax.set_xticks(np.arange(len(xticklabels)))
    ax.set_xticklabels(xticklabels, rotation=rotation, fontsize=16) #data.keys(), rotation=rotation, fontsize=16)
    ax.set_yticklabels(ax.get_yticklabels(), fontsize=16)
    #ax.grid(axis='y')
    #ax.set_axisbelow(True)
    #ax.legend(loc='upper right')
    # Add value labels above each bar, avoiding overlapping with error bars
    for rect, std in zip(rects1, stds):
        height = rect.get_height()
        if height + std < ax.get_ylim()[1]:
            ax.annotate(f'{height:.3f}',
                        xy=(rect.get_x() + rect.get_width() / 2, height + std),
                        xytext=(0, 5),
                        textcoords="offset points",
                        ha='center', va='bottom', fontsize=14)
        else:
            ax.annotate(f'{height:.3f}',
                        xy=(rect.get_x() + rect.get_width() / 2, height - std),
                        xytext=(0, -12),
                        textcoords="offset points",
                        ha='center', va='top', fontsize=14)
    # Remove spines
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.grid(axis='x')
    plt.tight_layout()
    plt.savefig(f'{filename}.pdf', bbox_inches='tight')


def plot_confusion_matrices(left_cm, right_cm, labels, title, annot=True):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
    left_xticklabels = [OBJECTS[i] for i in labels[0]]
    left_yticklabels = [OBJECTS[i] for i in labels[1]]
    right_xticklabels = [OBJECTS[i] for i in labels[2]]
    right_yticklabels = [OBJECTS[i] for i in labels[3]]
    sns.heatmap(
        left_cm,
        annot=annot,
        fmt='.0f',
        cmap='Blues',
        cbar=False,
        xticklabels=left_xticklabels,
        yticklabels=left_yticklabels,
        ax=ax1
    )
    ax1.set_xlabel('Predicted')
    ax1.set_ylabel('True')
    ax1.set_title('Left Confusion Matrix')
    sns.heatmap(
        right_cm,
        annot=annot,
        fmt='.0f',
        cmap='Blues',
        cbar=False, #True if annot is False else False,
        xticklabels=right_xticklabels,
        yticklabels=right_yticklabels,
        ax=ax2
    )
    ax2.set_xlabel('Predicted')
    ax2.set_ylabel('True')
    ax2.set_title('Right Confusion Matrix')
    #plt.suptitle(title)
    plt.tight_layout()
    plt.savefig(title + '.pdf')
    plt.close()


def plot_confusion_matrix(confusion_matrix, labels, title, annot=True):
    plt.figure(figsize=(6, 6))
    xticklabels = [OBJECTS[i] for i in labels[0]]
    yticklabels = [OBJECTS[i] for i in labels[1]]
    sns.heatmap(
        confusion_matrix,
        annot=annot,
        fmt='.0f',
        cmap='Blues',
        cbar=False,
        xticklabels=xticklabels,
        yticklabels=yticklabels
    )
    plt.xlabel('Predicted')
    plt.ylabel('True')
    #plt.title(title)
    plt.tight_layout()
    plt.savefig(title + '.pdf')
    plt.close()
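
# Usage sketch (matches the `confusion` task below; inputs are pandas crosstabs and
# the output path 'out/example_cm' is illustrative):
#   left_cm = pd.crosstab(df['left_label'], df['left_pred'])
#   right_cm = pd.crosstab(df['right_label'], df['right_pred'])
#   plot_confusion_matrices(left_cm, right_cm,
#                           labels=[left_cm.columns, left_cm.index, right_cm.columns, right_cm.index],
#                           title='out/example_cm')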


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, choices=['confusion', 'similarity', 'scores', 'friends_vs_strangers'])
    parser.add_argument('--folder_path', type=str)
    args = parser.parse_args()

    if args.task == 'similarity':
        sns.set_theme(style='white')
    else:
        sns.set_theme(style='whitegrid')

    # -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* #

    if args.task == 'similarity':

        out_left_main_mods_full_test = []
        out_right_main_mods_full_test = []
        cell_left_main_mods_full_test = []
        cell_right_main_mods_full_test = []
        cm_left_main_mods_full_test = []
        cm_right_main_mods_full_test = []

        for i in range(60):

            print(f'Computing analysis for test video {i}...', end='\r')

            emb_file = os.path.join(args.folder_path, f'{i}.pt')
            data = torch.load(emb_file)

            if len(data) == 14: # implicit
                model = 'impl'
                out_left, cell_left, out_right, cell_right, feats = data[0], data[1], data[2], data[3], data[4:]
                out_left = out_left.squeeze(0)
                cell_left = cell_left.squeeze(0)
                out_right = out_right.squeeze(0)
                cell_right = cell_right.squeeze(0)
            elif len(data) == 13: # common mind
                model = 'cm'
                out_left, out_right, common_mind, feats = data[0], data[1], data[2], data[3:]
                out_left = out_left.squeeze(0)
                out_right = out_right.squeeze(0)
                common_mind = common_mind.squeeze(0)
            elif len(data) == 12: # speaker-listener
                model = 'sl'
                out_left, out_right, feats = data[0], data[1], data[2:]
                out_left = out_left.squeeze(0)
                out_right = out_right.squeeze(0)
            else:
                raise ValueError("Data should have 14 (impl), 13 (cm) or 12 (sl) elements!")

            # ====== PCA for left and right embeddings ====== #

            out_left_and_right = np.concatenate((out_left, out_right), axis=0)

            pca = PCA(n_components=2)
            pca_result = pca.fit_transform(out_left_and_right)

            # Separate the PCA results for each tensor
            pca_result_left = pca_result[:out_left.shape[0]]
            pca_result_right = pca_result[out_left.shape[0]:]

            plt.figure(figsize=(7,6))
            plt.scatter(pca_result_left[:, 0], pca_result_left[:, 1], label='MindNet$_1$', color=MTOM_COLORS['MN1'], s=100)
            plt.scatter(pca_result_right[:, 0], pca_result_right[:, 1], label='MindNet$_2$', color=MTOM_COLORS['MN2'], s=100)
            plt.xlabel('Principal Component 1', fontsize=30)
            plt.ylabel('Principal Component 2', fontsize=30)
            plt.grid(False)
            plt.legend(fontsize=30)
            plt.tight_layout()
            plt.savefig(f'{args.folder_path}/{i}_pca.pdf')
            plt.close()

            # ====== Feature similarity ====== #

            if len(feats) == 10:
                left_rgb, left_ocr, left_pose, left_gaze, left_bbox, right_rgb, right_ocr, right_pose, right_gaze, right_bbox = feats
                left_rgb = left_rgb.squeeze(0)
                left_ocr = left_ocr.squeeze(0)
                left_pose = left_pose.squeeze(0)
                left_gaze = left_gaze.squeeze(0)
                left_bbox = left_bbox.squeeze(0)
                right_rgb = right_rgb.squeeze(0)
                right_ocr = right_ocr.squeeze(0)
                right_pose = right_pose.squeeze(0)
                right_gaze = right_gaze.squeeze(0)
                right_bbox = right_bbox.squeeze(0)
            else:
                raise NotImplementedError("Ablated versions are not supported yet.")

            # out:  [1, seq_len, dim] --- squeeze ---> [seq_len, dim]
            # cell: [1, 1, dim]       --- squeeze ---> [1, dim]
            # cm:   [1, seq_len, dim] --- squeeze ---> [seq_len, dim]
            # feat: [1, seq_len, dim] --- squeeze ---> [seq_len, dim]

            _, out_left_main_mods = find_most_similar_embedding(left_rgb, left_pose, left_gaze, left_ocr, left_bbox, out_left)
            _, out_right_main_mods = find_most_similar_embedding(right_rgb, right_pose, right_gaze, right_ocr, right_bbox, out_right)
            out_left_main_mods_full_test += out_left_main_mods
            out_right_main_mods_full_test += out_right_main_mods
            if model == 'impl':
                _, cell_left_main_mods = find_most_similar_embedding(left_rgb, left_pose, left_gaze, left_ocr, left_bbox, cell_left)
                _, cell_right_main_mods = find_most_similar_embedding(right_rgb, right_pose, right_gaze, right_ocr, right_bbox, cell_right)
                cell_left_main_mods_full_test += cell_left_main_mods
                cell_right_main_mods_full_test += cell_right_main_mods
            if model == 'cm':
                _, cm_left_main_mods = find_most_similar_embedding(left_rgb, left_pose, left_gaze, left_ocr, left_bbox, common_mind)
                _, cm_right_main_mods = find_most_similar_embedding(right_rgb, right_pose, right_gaze, right_ocr, right_bbox, common_mind)
                cm_left_main_mods_full_test += cm_left_main_mods
                cm_right_main_mods_full_test += cm_right_main_mods

        if model == 'impl':
            plot_similarity_histogram(
                [out_left_main_mods_full_test, out_right_main_mods_full_test, cell_left_main_mods_full_test, cell_right_main_mods_full_test],
                f'{args.folder_path}/boss_similarity_impl_hist_all.pdf',
                [r'$h_1$', r'$h_2$', r'$c_1$', r'$c_2$']
            )
        elif model == 'cm':
            plot_similarity_histogram(
                [out_left_main_mods_full_test, out_right_main_mods_full_test, cm_left_main_mods_full_test, cm_right_main_mods_full_test],
                f'{args.folder_path}/boss_similarity_cm_hist_all.pdf',
                [r'$h_1$', r'$h_2$', r'$cg$ w/ 1', r'$cg$ w/ 2']
            )
        elif model == 'sl':
            plot_similarity_histogram(
                [out_left_main_mods_full_test, out_right_main_mods_full_test],
                f'{args.folder_path}/boss_similarity_sl_hist_all.pdf',
                [r'$h_1$', r'$h_2$']
            )

    # -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* #

    elif args.task == 'scores':

        all = json.load(open('results/all.json'))
        abl_tom_cm_concat = json.load(open('results/abl_cm_concat.json'))
        plot_scores_histogram(
            all,
            filename='results/all',
            size=(10,4),
            rotation=45,
            colors=[COLORS[7]] + [MTOM_COLORS['DB']] + [MTOM_COLORS['CG']]*4 + [MTOM_COLORS['IC']]*4 + [COLORS[4]]*4
        )
        plot_scores_histogram(
            abl_tom_cm_concat,
            size=(10,4),
            filename='results/abl_cm_concat',
            colors=[MTOM_COLORS['CG']]*8
        )

    # -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* #

    elif args.task == 'confusion':

        # List to store confusion matrices
        all_df = []
        left_matrices = []
        right_matrices = []

        # Iterate over the CSV files
        for filename in natsorted(os.listdir(args.folder_path)):
            if filename.endswith('.csv'):
                file_path = os.path.join(args.folder_path, filename)
                print(f'Processing {file_path}...')
                # Read the CSV file
                df = pd.read_csv(file_path)
                # Extract the left and right labels and predictions
                left_labels = df['left_label']
                right_labels = df['right_label']
                left_preds = df['left_pred']
                right_preds = df['right_pred']
                # Calculate the confusion matrices for left and right
                left_cm = pd.crosstab(left_labels, left_preds)
                right_cm = pd.crosstab(right_labels, right_preds)
                # Append the confusion matrices to the list
                left_matrices.append(left_cm)
                right_matrices.append(right_cm)
                all_df.append(df)

        # Plot and save the confusion matrices for left and right
        for i, cm in enumerate(zip(left_matrices, right_matrices)):
            print(f'Computing confusion matrices for video {i}...', end='\r')
            plot_confusion_matrices(
                cm[0], cm[1],
                labels=[cm[0].columns, cm[0].index, cm[1].columns, cm[1].index],
                title=f'{args.folder_path}/{i}_cm'
            )

        merged_df = pd.concat(all_df).reset_index(drop=True)
        merged_left_cm = pd.crosstab(merged_df['left_label'], merged_df['left_pred'])
        merged_right_cm = pd.crosstab(merged_df['right_label'], merged_df['right_pred'])
        plot_confusion_matrices(
            merged_left_cm, merged_right_cm,
            labels=[merged_left_cm.columns, merged_left_cm.index, merged_right_cm.columns, merged_right_cm.index],
            title=f'{args.folder_path}/all_lr',
            annot=False
        )
        merged_preds = pd.concat([merged_df['left_pred'], merged_df['right_pred']]).reset_index(drop=True)
        merged_labels = pd.concat([merged_df['left_label'], merged_df['right_label']]).reset_index(drop=True)
        merged_all_cm = pd.crosstab(merged_labels, merged_preds)
        plot_confusion_matrix(
            merged_all_cm,
            labels=[merged_all_cm.columns, merged_all_cm.index],
            title=f'{args.folder_path}/all_merged',
            annot=False
        )

    # -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* #

    elif args.task == 'friends_vs_strangers':

        # friends ids: 30-59
        # stranger ids: 0-29
        out_left_list = []
        out_right_list = []
        cell_left_list = []
        cell_right_list = []
        common_mind_list = []

        for i in range(60):

            print(f'Computing analysis for test video {i}...', end='\r')

            emb_file = os.path.join(args.folder_path, f'{i}.pt')
            data = torch.load(emb_file)

            if len(data) == 14: # implicit
                model = 'impl'
                out_left, cell_left, out_right, cell_right, feats = data[0], data[1], data[2], data[3], data[4:]
                out_left = out_left.squeeze(0)
                cell_left = cell_left.squeeze(0)
                out_right = out_right.squeeze(0)
                cell_right = cell_right.squeeze(0)
                out_left_list.append(out_left)
                out_right_list.append(out_right)
                cell_left_list.append(cell_left)
                cell_right_list.append(cell_right)
            elif len(data) == 13: # common mind
                model = 'cm'
                out_left, out_right, common_mind, feats = data[0], data[1], data[2], data[3:]
                out_left = out_left.squeeze(0)
                out_right = out_right.squeeze(0)
                common_mind = common_mind.squeeze(0)
                out_left_list.append(out_left)
                out_right_list.append(out_right)
                common_mind_list.append(common_mind)
            elif len(data) == 12: # speaker-listener
                model = 'sl'
                out_left, out_right, feats = data[0], data[1], data[2:]
                out_left = out_left.squeeze(0)
                out_right = out_right.squeeze(0)
                out_left_list.append(out_left)
                out_right_list.append(out_right)
            else:
                raise ValueError("Data should have 14 (impl), 13 (cm) or 12 (sl) elements!")

        # ====== PCA for left and right embeddings ====== #

        print('\rComputing PCA...')
        strangers_nframes = sum([out_left_list[i].shape[0] for i in range(30)])

        left = torch.cat(out_left_list, 0)
        right = torch.cat(out_right_list, 0)
        out_left_and_right = torch.cat([left, right], axis=0).numpy()
        #out_left_and_right = StandardScaler().fit_transform(out_left_and_right)

        pca = PCA(n_components=2)
        pca_result = pca.fit_transform(out_left_and_right)
        #pca_result = TSNE(n_components=2, learning_rate='auto', init='random', perplexity=3).fit_transform(out_left_and_right)
        #pca_result = UMAP().fit_transform(out_left_and_right)

        # Separate the PCA results for each tensor
        #pca_result_left = pca_result[:left.shape[0]]
        #pca_result_right = pca_result[right.shape[0]:]
        pca_result_strangers = pca_result[:strangers_nframes]
        pca_result_friends = pca_result[strangers_nframes:]

        #plt.scatter(pca_result_left[:, 0], pca_result_left[:, 1], label='Left')
        #plt.scatter(pca_result_right[:, 0], pca_result_right[:, 1], label='Right')
        plt.scatter(pca_result_friends[:, 0], pca_result_friends[:, 1], label='Friends')
        plt.scatter(pca_result_strangers[:, 0], pca_result_strangers[:, 1], label='Strangers')
        plt.xlabel('Principal Component 1')
        plt.ylabel('Principal Component 2')
        plt.legend()
        plt.savefig(f'{args.folder_path}/friends_vs_strangers_pca.pdf')
        plt.close()

    # -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* #

    else:
        raise NameError(f'Unknown task: {args.task}')
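
# Example invocations (paths are illustrative; the script name depends on how this
# file is saved in the repository):
#   python analysis.py --task similarity --folder_path <dir with saved {i}.pt embeddings>
#   python analysis.py --task confusion  --folder_path <dir with per-video prediction CSVs>
#   python analysis.py --task scores
#   python analysis.py --task friends_vs_strangers --folder_path <dir with saved {i}.pt embeddings>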