Update models/utils.py

This commit is contained in:
Adnen Abdessaied 2024-09-20 09:32:14 +02:00
parent 985a3f8468
commit 0c864f23b1

View file

@ -140,11 +140,7 @@ def get_knn_graph(features, num_nn, device):
def track_features_vis(features, att, top_k, device, node_idx=None):
"""Computes an adjacency matrix based on the nearset neighbor similiarity for
the i3d, audio, and sam input modalities. The tracked constituents of each modality
are randomly chosen (A_tilde in the paper).
"""
features = features.clone().detach()
top_k = min(features.size(1), top_k)
if att is None:
node_idx = torch.randint(low=0, high=features.size(1), size=(features.size(0), top_k))
@ -159,10 +155,7 @@ def track_features_vis(features, att, top_k, device, node_idx=None):
def track_features_text(features, att, top_k, device, node_idx=None):
"""Computes an adjacency matrix based on the nearset neighbor similiarity for
the history and question inputs. The tracked constituents of each modality
are randomly chosen (A_tilde in the paper).
"""
hidden_dim = features[0].size(-1)
min_len = min([feat.size(1) for feat in features])
top_k = min(min_len, top_k)