Update models/avsd_bart.py

This commit is contained in:
Adnen Abdessaied 2024-10-17 14:09:53 +02:00
parent 0c864f23b1
commit 03ec3544b9

View file

@ -645,11 +645,11 @@ class BartEncoder(BartPretrainedModel):
history_X, history_node_idx = track_features_text(history_hidden, history_att, self.config.top_k, device) history_X, history_node_idx = track_features_text(history_hidden, history_att, self.config.top_k, device)
question_X, question_node_idx = track_features_text(question_hidden, question_att, self.config.top_k, device) question_X, question_node_idx = track_features_text(question_hidden, question_att, self.config.top_k, device)
# NOTE: The indices need to be adjusted to match the global input # NOTE: The indices need to be adjusted (not inplace) to match the global input
i3d_rgb_node_idx += 1 i3d_rgb_node_idx = i3d_rgb_node_idx + 1
i3d_flow_node_idx += i3d_flow_interval[0] + 1 i3d_flow_node_idx = i3d_flow_node_idx + i3d_flow_interval[0] + 1
sam_node_idx += sam_interval[0] + 1 sam_node_idx = sam_node_idx + sam_interval[0] + 1
audio_node_idx += audio_interval[0] + 1 audio_node_idx = audio_node_idx + audio_interval[0] + 1
history_node_idx = [x + history_intervals[0][0] + 1 for x in history_node_idx] history_node_idx = [x + history_intervals[0][0] + 1 for x in history_node_idx]
question_node_idx = [x + qi[0] + 1 for x, qi in zip(question_node_idx, question_intervals)] question_node_idx = [x + qi[0] + 1 for x, qi in zip(question_node_idx, question_intervals)]