import os
import argparse
import json

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
#from umap import UMAP
from natsort import natsorted


COLORS_DM = {
    "red": (242/256, 165/256, 179/256),
    "blue": (195/256, 219/256, 252/256),
    "green": (156/256, 228/256, 213/256),
    "yellow": (250/256, 236/256, 144/256),
    "violet": (207/256, 187/256, 244/256),
    "orange": (244/256, 188/256, 154/256)
}

MTOM_COLORS = {
    "MN1": (110/255, 117/255, 161/255),
    "MN2": (179/255, 106/255, 98/255),
    "Base": (193/255, 198/255, 208/255),
    "CG": (170/255, 129/255, 42/255),
    "IC": (97/255, 112/255, 83/255),
    "DB": (144/255, 63/255, 110/255)
}

COLORS = sns.color_palette()

OBJECTS = [
    'none', 'apple', 'orange', 'lemon', 'potato', 'wine', 'wineopener',
    'knife', 'mug', 'peeler', 'bowl', 'chocolate', 'sugar', 'magazine',
    'cracker', 'chips', 'scissors', 'cap', 'marker', 'sardinecan', 'tomatocan',
    'plant', 'walnut', 'nail', 'waterspray', 'hammer', 'canopener'
]

tried_once = [2, 4, 5, 6, 7, 8, 9, 10, 12, 15, 18, 19, 20, 21, 22, 23, 24,
    25, 26, 27, 28, 29, 30, 32, 38, 40, 41, 42, 44, 47, 50, 51, 52, 53, 55,
    56, 57, 61, 63, 65, 67, 68, 70, 71, 72, 73, 74, 76, 80, 81, 83, 85, 87,
    88, 89, 90, 92, 93, 96, 97, 99, 101, 102, 105, 106, 108, 110, 111, 112,
    113, 114, 116, 118, 121, 123, 125, 131, 132, 134, 135, 140, 142, 143,
    145, 146, 148, 149, 151, 152, 154, 155, 156, 157, 160, 161, 162, 165,
    169, 170, 171, 173, 175, 176, 178, 179, 180, 181, 182, 183, 184, 185,
    186, 187, 190, 191, 194, 196, 203, 204, 206, 207, 208, 209, 210, 211,
    213, 214, 216, 218, 219, 220, 222, 225, 227, 228, 229, 232, 233, 235,
    236, 237, 238, 239, 242, 243, 246, 247, 249, 251, 252, 254, 255, 256,
    257, 260, 261, 262, 263, 265, 266, 268, 270, 272, 275, 277, 278, 279,
    281, 282, 287, 290, 296, 298]

tried_twice = [1, 3, 11, 13, 14, 16, 17, 31, 33, 35, 36, 37, 39,
    43, 45, 49, 54, 58, 59, 60, 62, 64, 66, 69, 75, 77, 79, 82, 84, 86,
    91, 94, 95, 98, 100, 103, 104, 107, 109, 115, 117, 119, 120, 122,
    124, 126, 127, 128, 129, 130, 133, 136, 137, 138, 139, 141, 144,
    147, 150, 153, 158, 159, 164, 166, 167, 168, 172, 174, 177, 188,
    189, 192, 193, 195, 197, 198, 199, 200, 201, 202, 205, 212, 215,
    217, 221, 223, 224, 226, 230, 231, 234, 240, 241, 244, 245, 248,
    250, 253, 258, 259, 264, 267, 269, 271, 273, 274, 276, 280, 283,
    284, 285, 286, 288, 291, 292, 293, 294, 295, 297, 299]

tried_thrice = [34, 46, 48, 78, 163, 289]

friends = list(range(150, 300))
strangers = list(range(0, 150))


def get_input_dim(modalities):
    dimensions = {
        'rgb': 512,
        'pose': 150,
        'gaze': 6,
        'ocr': 64,
        'bbox': 108
    }
    return sum(dimensions[modality] for modality in modalities)

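
# Usage sketch (illustrative): the input dimension is the sum of the selected
# per-modality feature sizes, e.g. pose (150) + gaze (6) = 156.
def _demo_get_input_dim():
    assert get_input_dim(['pose', 'gaze']) == 156
    assert get_input_dim(['rgb', 'pose', 'gaze', 'ocr', 'bbox']) == 840
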

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

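
# Usage sketch (illustrative): only trainable parameters are counted, so
# frozen weights are excluded.
def _demo_count_parameters():
    layer = torch.nn.Linear(10, 5)        # 10*5 weights + 5 biases = 55
    assert count_parameters(layer) == 55
    layer.weight.requires_grad = False    # freeze the weight matrix
    assert count_parameters(layer) == 5
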

def onehot(label, n_classes):
    # [batch, seq, 2] labels (left and right belief) -> [batch, seq, 2*n_classes]
    if len(label.size()) == 3:
        batch_size, seq_len, _ = label.size()
        label_1 = label[:, :, 0].contiguous().view(batch_size*seq_len, 1)
        label_2 = label[:, :, 1].contiguous().view(batch_size*seq_len, 1)
        onehot_1 = torch.zeros(batch_size*seq_len, n_classes).scatter_(1, label_1, 1).view(-1, seq_len, n_classes)
        onehot_2 = torch.zeros(batch_size*seq_len, n_classes).scatter_(1, label_2, 1).view(-1, seq_len, n_classes)
        return torch.cat((onehot_1, onehot_2), dim=2)
    # [batch] labels -> [batch, n_classes]
    else:
        return torch.zeros(label.size(0), n_classes).scatter_(1, label.view(-1, 1), 1)

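
# Sanity-check sketch (illustrative): for [batch, seq, 2] integer labels the
# result is [batch, seq, 2*n_classes] with exactly two ones per time step.
def _demo_onehot():
    labels = torch.randint(0, 27, (4, 10, 2))
    oh = onehot(labels, 27)
    assert oh.shape == (4, 10, 54)
    assert torch.all(oh.sum(dim=-1) == 2)
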

def mixup(data, alpha, n_classes):
    # Sample the mixing coefficient from Beta(alpha, alpha).
    lam = torch.FloatTensor([np.random.beta(alpha, alpha)])
    frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths = data
    indices = torch.randperm(labels.size(0))
    # labels: convert to one-hot, then mix
    labels2 = labels[indices]
    labels = onehot(labels, n_classes)
    labels2 = onehot(labels2, n_classes)
    labels = labels * lam + labels2 * (1 - lam)
    # frames
    frames2 = frames[indices]
    frames = frames * lam + frames2 * (1 - lam)
    # poses
    poses2 = poses[indices]
    poses = poses * lam + poses2 * (1 - lam)
    # gazes
    gazes2 = gazes[indices]
    gazes = gazes * lam + gazes2 * (1 - lam)
    # bboxes
    bboxes2 = bboxes[indices]
    bboxes = bboxes * lam + bboxes2 * (1 - lam)
    return frames, labels, poses, gazes, bboxes, ocr_graphs, sequence_lengths

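
# Usage sketch (illustrative): `data` is a padded batch as produced by
# `pad_collate` below; labels come back as soft one-hot vectors, so the
# mixup-aware accuracy/loss variants must be used afterwards.
def _demo_mixup():
    frames = torch.randn(4, 10, 512)
    labels = torch.randint(0, 27, (4, 10, 2))
    poses, gazes = torch.randn(4, 10, 150), torch.randn(4, 10, 6)
    bboxes = torch.randn(4, 10, 108)
    batch = (frames, labels, poses, gazes, bboxes, None, [10, 10, 10, 10])
    _, mixed_labels, *_ = mixup(batch, alpha=0.4, n_classes=27)
    assert mixed_labels.shape == (4, 10, 54)
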

def get_classification_accuracy(pred_left_labels, pred_right_labels, labels, sequence_lengths):
    max_len = max(sequence_lengths)
    pred_left_labels = torch.reshape(pred_left_labels, (-1, max_len, 27))
    pred_right_labels = torch.reshape(pred_right_labels, (-1, max_len, 27))
    labels = torch.reshape(labels, (-1, max_len, 2))
    left_correct = torch.argmax(pred_left_labels, 2) == labels[:, :, 0]
    right_correct = torch.argmax(pred_right_labels, 2) == labels[:, :, 1]
    # Two predictions (left and right) per valid time step.
    num_pred = sum(sequence_lengths) * 2
    num_correct = 0
    # Only count time steps within each sequence's true length, ignoring padding.
    for i in range(len(sequence_lengths)):
        size = sequence_lengths[i]
        num_correct += (torch.sum(left_correct[i][:size]) + torch.sum(right_correct[i][:size])).item()
    acc = num_correct / num_pred
    return acc, num_correct, num_pred


def get_classification_accuracy_mixup(pred_left_labels, pred_right_labels, labels, sequence_lengths):
    # Same as above, but labels are soft 54-dim (2 x 27) mixup vectors, so the
    # target classes are recovered via argmax over each half.
    max_len = max(sequence_lengths)
    pred_left_labels = torch.reshape(pred_left_labels, (-1, max_len, 27))
    pred_right_labels = torch.reshape(pred_right_labels, (-1, max_len, 27))
    labels = torch.reshape(labels, (-1, max_len, 54))
    left_correct = torch.argmax(pred_left_labels, 2) == torch.argmax(labels[:, :, :27], 2)
    right_correct = torch.argmax(pred_right_labels, 2) == torch.argmax(labels[:, :, 27:], 2)
    num_pred = sum(sequence_lengths) * 2
    num_correct = 0
    for i in range(len(sequence_lengths)):
        size = sequence_lengths[i]
        num_correct += (torch.sum(left_correct[i][:size]) + torch.sum(right_correct[i][:size])).item()
    acc = num_correct / num_pred
    return acc, num_correct, num_pred

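
# Worked micro-example (illustrative): one sequence of length 2 with both
# hands predicted correctly at every step gives accuracy 4/4 = 1.0.
def _demo_classification_accuracy():
    labels = torch.zeros(1, 2, 2, dtype=torch.long)   # true class is 0 everywhere
    preds = torch.zeros(2, 27)
    preds[:, 0] = 1.0                                 # argmax -> class 0
    acc, num_correct, num_pred = get_classification_accuracy(preds, preds, labels, [2])
    assert (acc, num_correct, num_pred) == (1.0, 4, 4)
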

def pad_collate(batch):
    (aa, bb, cc, dd, ee, ff) = zip(*batch)
    seq_lens = [len(a) for a in aa]
    aa_pad = pad_sequence(aa, batch_first=True, padding_value=0)
    bb_pad = pad_sequence(bb, batch_first=True, padding_value=-1)  # -1 marks padded labels
    # Optional modalities may be missing (None) for ablated models.
    cc_pad = pad_sequence(cc, batch_first=True, padding_value=0) if cc[0] is not None else None
    dd_pad = pad_sequence(dd, batch_first=True, padding_value=0) if dd[0] is not None else None
    ee_pad = pad_sequence(ee, batch_first=True, padding_value=0) if ee[0] is not None else None
    ff_pad = pad_sequence(ff, batch_first=True, padding_value=0) if ff[0] is not None else None
    return aa_pad, bb_pad, cc_pad, dd_pad, ee_pad, ff_pad, seq_lens

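
# Usage sketch (illustrative): pad_collate is meant to be passed as the
# `collate_fn` of a DataLoader over variable-length videos, e.g.
# DataLoader(dataset, batch_size=4, collate_fn=pad_collate). Items are
# 6-tuples of per-frame tensors (or None for unused modalities).
def _demo_pad_collate():
    def item(t):
        return (torch.randn(t, 512), torch.zeros(t, 2, dtype=torch.long),
                torch.randn(t, 150), torch.randn(t, 6),
                torch.randn(t, 108), torch.randn(t, 64))
    aa, bb, cc, dd, ee, ff, seq_lens = pad_collate([item(5), item(8)])
    assert aa.shape == (2, 8, 512) and seq_lens == [5, 8]
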

def ocr_loss(pL, lL, pR, lR, ocr, loss, eta=10.0, mode='abs'):
    """
    Custom loss based on the negative OCR matrix.

    Input:
    pL: tensor of shape [batch_size, num_classes] representing left belief predictions
    lL: tensor of shape [batch_size] representing the left belief labels
    pR: tensor of shape [batch_size, num_classes] representing right belief predictions
    lR: tensor of shape [batch_size] representing the right belief labels
    ocr: negative OCR matrix (i.e. 1-OCR)
    loss: loss function
    eta: hyperparameter for the interaction term (note: currently overridden
         below by a dynamic weight that balances the interaction term against
         the sum of the left and right losses)

    Output:
    Final loss resulting from weighting left and right loss with the respective
    OCR coefficients and summing them.
    """
    bs = pL.shape[0]
    p = torch.tensor([ocr[torch.argmax(pL[i]), torch.argmax(pR[i])] for i in range(bs)], device=pL.device)
    if len([*lR[0].size()]) > 0:  # soft (mixup) labels: recover classes via argmax
        g = torch.tensor([ocr[torch.argmax(lL[i]), torch.argmax(lR[i])] for i in range(bs)], device=pL.device)
    else:  # hard integer labels
        g = torch.tensor([ocr[lL[i], lR[i]] for i in range(bs)], device=pL.device)
    left_loss = torch.mean(loss(pL, lL))
    right_loss = torch.mean(loss(pR, lR))
    if mode == 'abs':
        interaction_loss = torch.mean(torch.abs(g - p))
    elif mode == 'mse':
        interaction_loss = torch.mean(torch.pow(g - p, 2))
    else:
        raise NotImplementedError
    # Rescale the interaction term to match the magnitude of the belief losses.
    eta = (left_loss + right_loss) / interaction_loss
    interaction_loss = interaction_loss * eta
    print(f"Left: {left_loss} --- Right: {right_loss} --- Interaction: {interaction_loss}")
    return left_loss + right_loss + interaction_loss

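
# Usage sketch (illustrative): `ocr` is a [num_classes, num_classes] matrix of
# 1-OCR scores and `loss` an elementwise criterion such as
# nn.CrossEntropyLoss(reduction='none'), so ocr_loss can take its mean itself.
def _demo_ocr_loss():
    pL, pR = torch.randn(4, 27), torch.randn(4, 27)
    lL, lR = torch.randint(0, 27, (4,)), torch.randint(0, 27, (4,))
    criterion = torch.nn.CrossEntropyLoss(reduction='none')
    total = ocr_loss(pL, lL, pR, lR, torch.rand(27, 27), criterion)
    assert total.dim() == 0   # scalar loss
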

def spherical2cartesial(x):
    """
    Convert [N, 2] spherical gaze angles to [N, 3] unit direction vectors
    (assuming x[:, 0] is the azimuth and x[:, 1] the elevation).
    From https://colab.research.google.com/drive/1SJbzd-gFTbiYjfZynIfrG044fWi6svbV?usp=sharing#scrollTo=78QhNw4MYSYp
    """
    output = torch.zeros(x.size(0), 3)
    output[:, 2] = -torch.cos(x[:, 1]) * torch.cos(x[:, 0])
    output[:, 0] = torch.cos(x[:, 1]) * torch.sin(x[:, 0])
    output[:, 1] = torch.sin(x[:, 1])
    return output

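
# Sanity-check sketch (illustrative): every output row is a point on the unit
# sphere, so its norm is 1 for any pair of input angles.
def _demo_spherical2cartesial():
    angles = torch.rand(8, 2) * 2 * np.pi
    vecs = spherical2cartesial(angles)
    assert torch.allclose(vecs.norm(dim=1), torch.ones(8), atol=1e-5)
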

def presave():
    """Preprocess all test frames once and pickle them per video."""
    from PIL import Image
    from torchvision import transforms
    import pickle as pkl

    preprocess = transforms.Compose([
        transforms.Resize((128, 128)),
        #transforms.Resize(256),
        #transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    frame_path = '/scratch/bortoletto/data/boss/test/frame'
    frame_dirs = os.listdir(frame_path)
    frame_paths = []
    for frame_dir in natsorted(frame_dirs):
        paths = [os.path.join(frame_path, frame_dir, i) for i in natsorted(os.listdir(os.path.join(frame_path, frame_dir)))]
        frame_paths.append(paths)

    save_folder = '/scratch/bortoletto/data/boss/presaved128/'+frame_path.split('/')[-2]+'/'+frame_path.split('/')[-1]
    if not os.path.exists(save_folder):
        os.makedirs(save_folder)
        print(save_folder, 'created')

    for video in natsorted(frame_paths):
        print(video[0].split('/')[-2], end='\r')
        images = [preprocess(Image.open(i)) for i in video]
        video_name = video[0].split('/')[-2]  # more robust than a hard-coded path index
        with open(save_folder+'/'+video_name+'.pkl', 'wb') as f:
            pkl.dump(images, f)


def compute_cosine_similarity(tensor1, tensor2):
    return F.cosine_similarity(tensor1, tensor2, dim=-1)


def find_most_similar_embedding(rgb, pose, gaze, ocr, bbox, repr):
    # Cosine similarity between each modality embedding and the reference
    # representation, computed per time step.
    modalities = [gaze, pose, ocr, bbox, rgb]
    names = ['gaze', 'pose', 'ocr', 'bbox', 'rgb']
    similarities = torch.stack([compute_cosine_similarity(m, repr) for m in modalities])
    max_index = torch.argmax(similarities, dim=0)
    main_modality = [modalities[idx] for idx in max_index]
    main_modality_name = [names[idx] for idx in max_index]
    return main_modality, main_modality_name

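
# Usage sketch (illustrative): all inputs are [seq_len, dim] tensors living in
# a shared embedding space; per time step, the function names the modality
# whose embedding is closest (by cosine similarity) to the reference `repr`.
def _demo_find_most_similar_embedding():
    rgb, pose, gaze, ocr, bbox, h = (torch.randn(10, 64) for _ in range(6))
    _, names = find_most_similar_embedding(rgb, pose, gaze, ocr, bbox, h)
    assert len(names) == 10
    assert set(names) <= {'rgb', 'pose', 'gaze', 'ocr', 'bbox'}
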

def plot_similarity_histogram(values, filename, labels):
    def count_elements(list_of_strings, target_list):
        return [target_list.count(element) for element in list_of_strings]
    fig, ax = plt.subplots(figsize=(12, 4))
    # Remove spines
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    colors = [MTOM_COLORS['MN1'], MTOM_COLORS['MN2'], MTOM_COLORS['CG'], MTOM_COLORS['CG']]
    alphas = [0.6, 0.6, 0.6, 0.6]
    edgecolors = ['black', 'black', MTOM_COLORS['MN1'], MTOM_COLORS['MN2']]
    linewidths = [1.0, 1.0, 5.0, 5.0]
    if isinstance(values[0], list):
        num_lists = len(values)
        bar_width = 0.8 / num_lists
        unique_strings = ['rgb', 'pose', 'gaze', 'bbox', 'ocr']
        for i, val in enumerate(values):
            counts = count_elements(unique_strings, val)
            x = np.arange(len(unique_strings))
            x_shifted = x + (i - (len(labels) - 1) / 2) * bar_width
            ax.bar(x_shifted, counts, width=bar_width, label=f'{labels[i]}', color=colors[i], edgecolor=edgecolors[i], linewidth=linewidths[i], alpha=alphas[i])
        ax.set_xlabel('Modality', fontsize=18)
        ax.set_ylabel('Counts', fontsize=18)
        ax.set_xticks(np.arange(len(unique_strings)))
        ax.set_xticklabels(unique_strings, fontsize=18)
        ax.tick_params(axis='y', labelsize=16)
        ax.legend(fontsize=18)
        ax.grid(axis='y')
    else:
        unique_strings, counts = np.unique(values, return_counts=True)
        ax.bar(unique_strings, counts)
        ax.set_xlabel('Modality')
        ax.set_ylabel('Counts')
    plt.savefig(filename, bbox_inches='tight')


def plot_scores_histogram(data, filename, size=(8,6), rotation=0, colors=None):
    means = []
    stds = []
    for key, values in data.items():
        means.append(np.mean(values))
        stds.append(np.std(values))
    fig, ax = plt.subplots(figsize=size)
    x = np.arange(len(data))
    width = 0.6
    rects1 = ax.bar(
        x, means, width, label='Mean', yerr=stds,
        capsize=5,
        color='teal' if colors is None else colors,
        edgecolor='black',
        linewidth=1.5,
        alpha=0.6
    )
    ax.set_ylabel('Accuracy', fontsize=18)
    #ax.set_title('Mean and Standard Deviation of Results', fontsize=14)
    if filename == 'results/all':
        xticklabels = [
            'Base',
            'DB',
            r'CG$\otimes$', r'CG$\oplus$', r'CG$\odot$', r'CG$\parallel$',
            r'IC$\otimes$', r'IC$\oplus$', r'IC$\odot$', r'IC$\parallel$',
            'CNN', 'CNN+GRU', 'CNN+LSTM', 'CNN+Conv1D'
        ]
    else:
        xticklabels = list(data.keys())
    ax.set_xticks(np.arange(len(xticklabels)))
    ax.set_xticklabels(xticklabels, rotation=rotation, fontsize=16)
    ax.tick_params(axis='y', labelsize=16)
    # Add value labels above each bar, avoiding overlap with the error bars.
    for rect, std in zip(rects1, stds):
        height = rect.get_height()
        if height + std < ax.get_ylim()[1]:
            ax.annotate(f'{height:.3f}', xy=(rect.get_x() + rect.get_width() / 2, height + std),
                        xytext=(0, 5), textcoords="offset points", ha='center', va='bottom', fontsize=14)
        else:
            ax.annotate(f'{height:.3f}', xy=(rect.get_x() + rect.get_width() / 2, height - std),
                        xytext=(0, -12), textcoords="offset points", ha='center', va='top', fontsize=14)
    # Remove spines
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.grid(axis='x')
    plt.tight_layout()
    plt.savefig(f'{filename}.pdf', bbox_inches='tight')

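
# Usage sketch (illustrative): `data` maps model names to lists of per-seed
# accuracies; each bar shows the mean with a standard-deviation error bar.
# The output path below is hypothetical.
def _demo_plot_scores_histogram():
    scores = {'Base': [0.61, 0.63, 0.60], 'CG': [0.68, 0.70, 0.69]}
    plot_scores_histogram(scores, filename='/tmp/demo_scores')
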

def plot_confusion_matrices(left_cm, right_cm, labels, title, annot=True):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
    left_xticklabels = [OBJECTS[i] for i in labels[0]]
    left_yticklabels = [OBJECTS[i] for i in labels[1]]
    right_xticklabels = [OBJECTS[i] for i in labels[2]]
    right_yticklabels = [OBJECTS[i] for i in labels[3]]
    sns.heatmap(
        left_cm,
        annot=annot,
        fmt='.0f',
        cmap='Blues',
        cbar=False,
        xticklabels=left_xticklabels,
        yticklabels=left_yticklabels,
        ax=ax1
    )
    ax1.set_xlabel('Predicted')
    ax1.set_ylabel('True')
    ax1.set_title('Left Confusion Matrix')
    sns.heatmap(
        right_cm,
        annot=annot,
        fmt='.0f',
        cmap='Blues',
        cbar=False,
        xticklabels=right_xticklabels,
        yticklabels=right_yticklabels,
        ax=ax2
    )
    ax2.set_xlabel('Predicted')
    ax2.set_ylabel('True')
    ax2.set_title('Right Confusion Matrix')
    #plt.suptitle(title)
    plt.tight_layout()
    plt.savefig(title + '.pdf')
    plt.close()


def plot_confusion_matrix(confusion_matrix, labels, title, annot=True):
    plt.figure(figsize=(6, 6))
    xticklabels = [OBJECTS[i] for i in labels[0]]
    yticklabels = [OBJECTS[i] for i in labels[1]]
    sns.heatmap(
        confusion_matrix,
        annot=annot,
        fmt='.0f',
        cmap='Blues',
        cbar=False,
        xticklabels=xticklabels,
        yticklabels=yticklabels
    )
    plt.xlabel('Predicted')
    plt.ylabel('True')
    #plt.title(title)
    plt.tight_layout()
    plt.savefig(title + '.pdf')
    plt.close()

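
# Usage sketch (illustrative): confusion matrices are pandas crosstabs over
# integer object ids, which the plotting helpers map to names via OBJECTS.
# The output path is hypothetical.
def _demo_plot_confusion_matrix():
    y_true = pd.Series([1, 2, 2, 5, 1])
    y_pred = pd.Series([1, 2, 5, 5, 2])
    cm = pd.crosstab(y_true, y_pred)
    plot_confusion_matrix(cm, labels=[cm.columns, cm.index], title='/tmp/demo_cm')
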

if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, choices=['confusion', 'similarity', 'scores', 'friends_vs_strangers'])
    parser.add_argument('--folder_path', type=str)
    args = parser.parse_args()

    if args.task == 'similarity':
        sns.set_theme(style='white')
    else:
        sns.set_theme(style='whitegrid')

    # -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* #

    if args.task == 'similarity':

        out_left_main_mods_full_test = []
        out_right_main_mods_full_test = []
        cell_left_main_mods_full_test = []
        cell_right_main_mods_full_test = []
        cm_left_main_mods_full_test = []
        cm_right_main_mods_full_test = []

        for i in range(60):

            print(f'Computing analysis for test video {i}...', end='\r')

            emb_file = os.path.join(args.folder_path, f'{i}.pt')
            data = torch.load(emb_file)
            if len(data) == 14:  # implicit
                model = 'impl'
                out_left, cell_left, out_right, cell_right, feats = data[0], data[1], data[2], data[3], data[4:]
                out_left = out_left.squeeze(0)
                cell_left = cell_left.squeeze(0)
                out_right = out_right.squeeze(0)
                cell_right = cell_right.squeeze(0)
            elif len(data) == 13:  # common mind
                model = 'cm'
                out_left, out_right, common_mind, feats = data[0], data[1], data[2], data[3:]
                out_left = out_left.squeeze(0)
                out_right = out_right.squeeze(0)
                common_mind = common_mind.squeeze(0)
            elif len(data) == 12:  # speaker-listener
                model = 'sl'
                out_left, out_right, feats = data[0], data[1], data[2:]
                out_left = out_left.squeeze(0)
                out_right = out_right.squeeze(0)
            else:
                raise ValueError("Data should have 14 (impl), 13 (cm) or 12 (sl) elements!")

            # ====== PCA for left and right embeddings ====== #

            out_left_and_right = np.concatenate((out_left, out_right), axis=0)

            pca = PCA(n_components=2)
            pca_result = pca.fit_transform(out_left_and_right)

            # Separate the PCA results for each tensor (left frames come first).
            pca_result_left = pca_result[:out_left.shape[0]]
            pca_result_right = pca_result[out_left.shape[0]:]

            plt.figure(figsize=(7, 6))
            plt.scatter(pca_result_left[:, 0], pca_result_left[:, 1], label='MindNet$_1$', color=MTOM_COLORS['MN1'], s=100)
            plt.scatter(pca_result_right[:, 0], pca_result_right[:, 1], label='MindNet$_2$', color=MTOM_COLORS['MN2'], s=100)
            plt.xlabel('Principal Component 1', fontsize=30)
            plt.ylabel('Principal Component 2', fontsize=30)
            plt.grid(False)
            plt.legend(fontsize=30)
            plt.tight_layout()
            plt.savefig(f'{args.folder_path}/{i}_pca.pdf')
            plt.close()

            # ====== Feature similarity ====== #

            if len(feats) == 10:
                left_rgb, left_ocr, left_pose, left_gaze, left_bbox, right_rgb, right_ocr, right_pose, right_gaze, right_bbox = feats
                left_rgb = left_rgb.squeeze(0)
                left_ocr = left_ocr.squeeze(0)
                left_pose = left_pose.squeeze(0)
                left_gaze = left_gaze.squeeze(0)
                left_bbox = left_bbox.squeeze(0)
                right_rgb = right_rgb.squeeze(0)
                right_ocr = right_ocr.squeeze(0)
                right_pose = right_pose.squeeze(0)
                right_gaze = right_gaze.squeeze(0)
                right_bbox = right_bbox.squeeze(0)
            else:
                raise NotImplementedError("Ablated versions are not supported yet.")

            # out:  [1, seq_len, dim] --- squeeze ---> [seq_len, dim]
            # cell: [1, 1, dim] ------- squeeze ---> [1, dim]
            # cm:   [1, seq_len, dim] --- squeeze ---> [seq_len, dim]
            # feat: [1, seq_len, dim] --- squeeze ---> [seq_len, dim]

            _, out_left_main_mods = find_most_similar_embedding(left_rgb, left_pose, left_gaze, left_ocr, left_bbox, out_left)
            _, out_right_main_mods = find_most_similar_embedding(right_rgb, right_pose, right_gaze, right_ocr, right_bbox, out_right)
            out_left_main_mods_full_test += out_left_main_mods
            out_right_main_mods_full_test += out_right_main_mods
            if model == 'impl':
                _, cell_left_main_mods = find_most_similar_embedding(left_rgb, left_pose, left_gaze, left_ocr, left_bbox, cell_left)
                _, cell_right_main_mods = find_most_similar_embedding(right_rgb, right_pose, right_gaze, right_ocr, right_bbox, cell_right)
                cell_left_main_mods_full_test += cell_left_main_mods
                cell_right_main_mods_full_test += cell_right_main_mods
            if model == 'cm':
                _, cm_left_main_mods = find_most_similar_embedding(left_rgb, left_pose, left_gaze, left_ocr, left_bbox, common_mind)
                _, cm_right_main_mods = find_most_similar_embedding(right_rgb, right_pose, right_gaze, right_ocr, right_bbox, common_mind)
                cm_left_main_mods_full_test += cm_left_main_mods
                cm_right_main_mods_full_test += cm_right_main_mods

        if model == 'impl':
            plot_similarity_histogram(
                [out_left_main_mods_full_test,
                 out_right_main_mods_full_test,
                 cell_left_main_mods_full_test,
                 cell_right_main_mods_full_test],
                f'{args.folder_path}/boss_similarity_impl_hist_all.pdf',
                [r'$h_1$', r'$h_2$', r'$c_1$', r'$c_2$']
            )
        elif model == 'cm':
            plot_similarity_histogram(
                [out_left_main_mods_full_test,
                 out_right_main_mods_full_test,
                 cm_left_main_mods_full_test,
                 cm_right_main_mods_full_test],
                f'{args.folder_path}/boss_similarity_cm_hist_all.pdf',
                [r'$h_1$', r'$h_2$', r'$cg$ w/ 1', r'$cg$ w/ 2']
            )
        elif model == 'sl':
            plot_similarity_histogram(
                [out_left_main_mods_full_test,
                 out_right_main_mods_full_test],
                f'{args.folder_path}/boss_similarity_sl_hist_all.pdf',
                [r'$h_1$', r'$h_2$']
            )

    # -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* #

    elif args.task == 'scores':

        all_results = json.load(open('results/all.json'))
        abl_tom_cm_concat = json.load(open('results/abl_cm_concat.json'))

        plot_scores_histogram(
            all_results,
            filename='results/all',
            size=(10, 4),
            rotation=45,
            colors=[COLORS[7]] + [MTOM_COLORS['DB']] + [MTOM_COLORS['CG']]*4 + [MTOM_COLORS['IC']]*4 + [COLORS[4]]*4
        )

        plot_scores_histogram(
            abl_tom_cm_concat,
            size=(10, 4),
            filename='results/abl_cm_concat',
            colors=[MTOM_COLORS['CG']]*8
        )

    # -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* #

    elif args.task == 'confusion':

        # Lists to store the per-video dataframes and confusion matrices
        all_df = []
        left_matrices = []
        right_matrices = []

        # Iterate over the CSV files
        for filename in natsorted(os.listdir(args.folder_path)):
            if filename.endswith('.csv'):
                file_path = os.path.join(args.folder_path, filename)

                print(f'Processing {file_path}...')

                # Read the CSV file
                df = pd.read_csv(file_path)

                # Extract the left and right labels and predictions
                left_labels = df['left_label']
                right_labels = df['right_label']
                left_preds = df['left_pred']
                right_preds = df['right_pred']

                # Calculate the confusion matrices for left and right
                left_cm = pd.crosstab(left_labels, left_preds)
                right_cm = pd.crosstab(right_labels, right_preds)

                # Append the confusion matrices to the lists
                left_matrices.append(left_cm)
                right_matrices.append(right_cm)
                all_df.append(df)

        # Plot and save the confusion matrices for left and right
        for i, cm in enumerate(zip(left_matrices, right_matrices)):
            print(f'Computing confusion matrices for video {i}...', end='\r')
            plot_confusion_matrices(
                cm[0],
                cm[1],
                labels=[cm[0].columns, cm[0].index, cm[1].columns, cm[1].index],
                title=f'{args.folder_path}/{i}_cm'
            )

        merged_df = pd.concat(all_df).reset_index(drop=True)

        merged_left_cm = pd.crosstab(merged_df['left_label'], merged_df['left_pred'])
        merged_right_cm = pd.crosstab(merged_df['right_label'], merged_df['right_pred'])
        plot_confusion_matrices(
            merged_left_cm,
            merged_right_cm,
            labels=[merged_left_cm.columns, merged_left_cm.index, merged_right_cm.columns, merged_right_cm.index],
            title=f'{args.folder_path}/all_lr',
            annot=False
        )

        merged_preds = pd.concat([merged_df['left_pred'], merged_df['right_pred']]).reset_index(drop=True)
        merged_labels = pd.concat([merged_df['left_label'], merged_df['right_label']]).reset_index(drop=True)
        merged_all_cm = pd.crosstab(merged_labels, merged_preds)
        plot_confusion_matrix(
            merged_all_cm,
            labels=[merged_all_cm.columns, merged_all_cm.index],
            title=f'{args.folder_path}/all_merged',
            annot=False
        )

    # -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* #

    elif args.task == 'friends_vs_strangers':

        # friends ids: 30-59
        # stranger ids: 0-29
        out_left_list = []
        out_right_list = []
        cell_left_list = []
        cell_right_list = []
        common_mind_list = []

        for i in range(60):

            print(f'Computing analysis for test video {i}...', end='\r')

            emb_file = os.path.join(args.folder_path, f'{i}.pt')
            data = torch.load(emb_file)
            if len(data) == 14:  # implicit
                model = 'impl'
                out_left, cell_left, out_right, cell_right, feats = data[0], data[1], data[2], data[3], data[4:]
                out_left = out_left.squeeze(0)
                cell_left = cell_left.squeeze(0)
                out_right = out_right.squeeze(0)
                cell_right = cell_right.squeeze(0)
                out_left_list.append(out_left)
                out_right_list.append(out_right)
                cell_left_list.append(cell_left)
                cell_right_list.append(cell_right)
            elif len(data) == 13:  # common mind
                model = 'cm'
                out_left, out_right, common_mind, feats = data[0], data[1], data[2], data[3:]
                out_left = out_left.squeeze(0)
                out_right = out_right.squeeze(0)
                common_mind = common_mind.squeeze(0)
                out_left_list.append(out_left)
                out_right_list.append(out_right)
                common_mind_list.append(common_mind)
            elif len(data) == 12:  # speaker-listener
                model = 'sl'
                out_left, out_right, feats = data[0], data[1], data[2:]
                out_left = out_left.squeeze(0)
                out_right = out_right.squeeze(0)
                out_left_list.append(out_left)
                out_right_list.append(out_right)
            else:
                raise ValueError("Data should have 14 (impl), 13 (cm) or 12 (sl) elements!")

        # ====== PCA for left and right embeddings ====== #

        print('\rComputing PCA...')

        strangers_nframes = sum([out_left_list[i].shape[0] for i in range(30)])

        left = torch.cat(out_left_list, 0)
        right = torch.cat(out_right_list, 0)

        out_left_and_right = torch.cat([left, right], dim=0).numpy()
        #out_left_and_right = StandardScaler().fit_transform(out_left_and_right)

        pca = PCA(n_components=2)
        pca_result = pca.fit_transform(out_left_and_right)
        #pca_result = TSNE(n_components=2, learning_rate='auto', init='random', perplexity=3).fit_transform(out_left_and_right)
        #pca_result = UMAP().fit_transform(out_left_and_right)

        # The rows are ordered [left of videos 0-59; right of videos 0-59] and
        # videos 0-29 are strangers, so the strangers' frames occupy the first
        # strangers_nframes rows of each half (a plain split at strangers_nframes
        # would wrongly count all right embeddings as friends).
        total_frames = left.shape[0]
        strangers_mask = np.zeros(2 * total_frames, dtype=bool)
        strangers_mask[:strangers_nframes] = True
        strangers_mask[total_frames:total_frames + strangers_nframes] = True
        pca_result_strangers = pca_result[strangers_mask]
        pca_result_friends = pca_result[~strangers_mask]

        plt.scatter(pca_result_friends[:, 0], pca_result_friends[:, 1], label='Friends')
        plt.scatter(pca_result_strangers[:, 0], pca_result_strangers[:, 1], label='Strangers')
        plt.xlabel('Principal Component 1')
        plt.ylabel('Principal Component 2')
        plt.legend()
        plt.savefig(f'{args.folder_path}/friends_vs_strangers_pca.pdf')
        plt.close()

    # -*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-* #

    else:

        raise ValueError(f'Unknown task: {args.task}')