From 0c864f23b1696b6098d3c448a7c368d373c542b0 Mon Sep 17 00:00:00 2001 From: abdessaied Date: Fri, 20 Sep 2024 09:32:14 +0200 Subject: [PATCH] Update models/utils.py --- models/utils.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/models/utils.py b/models/utils.py index b97e2e1..8a5ff26 100644 --- a/models/utils.py +++ b/models/utils.py @@ -140,11 +140,7 @@ def get_knn_graph(features, num_nn, device): def track_features_vis(features, att, top_k, device, node_idx=None): - """Computes an adjacency matrix based on the nearset neighbor similiarity for - the i3d, audio, and sam input modalities. The tracked constituents of each modality - are randomly chosen (A_tilde in the paper). - """ - features = features.clone().detach() + top_k = min(features.size(1), top_k) if att is None: node_idx = torch.randint(low=0, high=features.size(1), size=(features.size(0), top_k)) @@ -159,10 +155,7 @@ def track_features_vis(features, att, top_k, device, node_idx=None): def track_features_text(features, att, top_k, device, node_idx=None): - """Computes an adjacency matrix based on the nearset neighbor similiarity for - the history and question inputs. The tracked constituents of each modality - are randomly chosen (A_tilde in the paper). - """ + hidden_dim = features[0].size(-1) min_len = min([feat.size(1) for feat in features]) top_k = min(min_len, top_k)