Update models/utils.py
This commit is contained in:
parent
985a3f8468
commit
0c864f23b1
1 changed files with 2 additions and 9 deletions
|
@ -140,11 +140,7 @@ def get_knn_graph(features, num_nn, device):
|
||||||
|
|
||||||
|
|
||||||
def track_features_vis(features, att, top_k, device, node_idx=None):
|
def track_features_vis(features, att, top_k, device, node_idx=None):
|
||||||
"""Computes an adjacency matrix based on the nearset neighbor similiarity for
|
|
||||||
the i3d, audio, and sam input modalities. The tracked constituents of each modality
|
|
||||||
are randomly chosen (A_tilde in the paper).
|
|
||||||
"""
|
|
||||||
features = features.clone().detach()
|
|
||||||
top_k = min(features.size(1), top_k)
|
top_k = min(features.size(1), top_k)
|
||||||
if att is None:
|
if att is None:
|
||||||
node_idx = torch.randint(low=0, high=features.size(1), size=(features.size(0), top_k))
|
node_idx = torch.randint(low=0, high=features.size(1), size=(features.size(0), top_k))
|
||||||
|
@ -159,10 +155,7 @@ def track_features_vis(features, att, top_k, device, node_idx=None):
|
||||||
|
|
||||||
|
|
||||||
def track_features_text(features, att, top_k, device, node_idx=None):
|
def track_features_text(features, att, top_k, device, node_idx=None):
|
||||||
"""Computes an adjacency matrix based on the nearset neighbor similiarity for
|
|
||||||
the history and question inputs. The tracked constituents of each modality
|
|
||||||
are randomly chosen (A_tilde in the paper).
|
|
||||||
"""
|
|
||||||
hidden_dim = features[0].size(-1)
|
hidden_dim = features[0].size(-1)
|
||||||
min_len = min([feat.size(1) for feat in features])
|
min_len = min([feat.size(1) for feat in features])
|
||||||
top_k = min(min_len, top_k)
|
top_k = min(min_len, top_k)
|
||||||
|
|
Loading…
Reference in a new issue