diff --git a/models/utils.py b/models/utils.py index b97e2e1..8a5ff26 100644 --- a/models/utils.py +++ b/models/utils.py @@ -140,11 +140,7 @@ def get_knn_graph(features, num_nn, device): def track_features_vis(features, att, top_k, device, node_idx=None): - """Computes an adjacency matrix based on the nearset neighbor similiarity for - the i3d, audio, and sam input modalities. The tracked constituents of each modality - are randomly chosen (A_tilde in the paper). - """ - features = features.clone().detach() + top_k = min(features.size(1), top_k) if att is None: node_idx = torch.randint(low=0, high=features.size(1), size=(features.size(0), top_k)) @@ -159,10 +155,7 @@ def track_features_vis(features, att, top_k, device, node_idx=None): def track_features_text(features, att, top_k, device, node_idx=None): - """Computes an adjacency matrix based on the nearset neighbor similiarity for - the history and question inputs. The tracked constituents of each modality - are randomly chosen (A_tilde in the paper). - """ + hidden_dim = features[0].size(-1) min_len = min([feat.size(1) for feat in features]) top_k = min(min_len, top_k)