diff --git a/models/nextqa_bart.py b/models/nextqa_bart.py index 4729be0..59ab6f1 100644 --- a/models/nextqa_bart.py +++ b/models/nextqa_bart.py @@ -635,8 +635,8 @@ class BartEncoder(BartPretrainedModel): question_X, question_node_idx = track_features_text(question_hidden, question_att, self.config.top_k, device) # NOTE: The indices need to be adjusted to match the global input - i3d_rgb_node_idx += 1 - i3d_flow_node_idx += i3d_flow_interval[0] + 1 + i3d_rgb_node_idx = i3d_rgb_node_idx + 1 + i3d_flow_node_idx = i3d_flow_node_idx + i3d_flow_interval[0] + 1 question_node_idx = [x + qi[0] + 1 for x, qi in zip(question_node_idx, question_intervals)]