Update models/avsd_bart.py
This commit is contained in:
parent
0c864f23b1
commit
03ec3544b9
1 changed files with 5 additions and 5 deletions
|
@ -645,11 +645,11 @@ class BartEncoder(BartPretrainedModel):
|
|||
history_X, history_node_idx = track_features_text(history_hidden, history_att, self.config.top_k, device)
|
||||
question_X, question_node_idx = track_features_text(question_hidden, question_att, self.config.top_k, device)
|
||||
|
||||
# NOTE: The indices need to be adjusted to match the global input
|
||||
i3d_rgb_node_idx += 1
|
||||
i3d_flow_node_idx += i3d_flow_interval[0] + 1
|
||||
sam_node_idx += sam_interval[0] + 1
|
||||
audio_node_idx += audio_interval[0] + 1
|
||||
# NOTE: The indices need to be adjusted (not inplace) to match the global input
|
||||
i3d_rgb_node_idx = i3d_rgb_node_idx + 1
|
||||
i3d_flow_node_idx = i3d_flow_node_idx + i3d_flow_interval[0] + 1
|
||||
sam_node_idx = sam_node_idx + sam_interval[0] + 1
|
||||
audio_node_idx = audio_node_idx + audio_interval[0] + 1
|
||||
history_node_idx = [x + history_intervals[0][0] + 1 for x in history_node_idx]
|
||||
question_node_idx = [x + qi[0] + 1 for x, qi in zip(question_node_idx, question_intervals)]
|
||||
|
||||
|
|
Loading…
Reference in a new issue