diff --git a/models/avsd_bart.py b/models/avsd_bart.py index 3bd388b..39f5eca 100644 --- a/models/avsd_bart.py +++ b/models/avsd_bart.py @@ -645,11 +645,11 @@ class BartEncoder(BartPretrainedModel): history_X, history_node_idx = track_features_text(history_hidden, history_att, self.config.top_k, device) question_X, question_node_idx = track_features_text(question_hidden, question_att, self.config.top_k, device) - # NOTE: The indices need to be adjusted to match the global input - i3d_rgb_node_idx += 1 - i3d_flow_node_idx += i3d_flow_interval[0] + 1 - sam_node_idx += sam_interval[0] + 1 - audio_node_idx += audio_interval[0] + 1 + # NOTE: The indices need to be adjusted (not inplace) to match the global input + i3d_rgb_node_idx = i3d_rgb_node_idx + 1 + i3d_flow_node_idx = i3d_flow_node_idx + i3d_flow_interval[0] + 1 + sam_node_idx = sam_node_idx + sam_interval[0] + 1 + audio_node_idx = audio_node_idx + audio_interval[0] + 1 history_node_idx = [x + history_intervals[0][0] + 1 for x in history_node_idx] question_node_idx = [x + qi[0] + 1 for x, qi in zip(question_node_idx, question_intervals)]