2022 lines
86 KiB
Python
2022 lines
86 KiB
Python
|
# coding=utf-8
|
||
|
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
||
|
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
"""PyTorch BERT model."""
|
||
|
|
||
|
import copy
|
||
|
import json
|
||
|
import logging
|
||
|
import math
|
||
|
import os
|
||
|
import shutil
|
||
|
import tarfile
|
||
|
import tempfile
|
||
|
import sys
|
||
|
from io import open
|
||
|
|
||
|
import torch
|
||
|
from torch import nn
|
||
|
from torch.nn import CrossEntropyLoss
|
||
|
import torch.nn.functional as F
|
||
|
from torch.nn.utils.weight_norm import weight_norm
|
||
|
from pytorch_transformers.modeling_bert import BertEmbeddings
|
||
|
from utils.data_utils import sequence_mask, to_data_list
|
||
|
import torch_geometric.nn as pyg_nn
|
||
|
from torch_geometric.data import Data
|
||
|
from torch_geometric.loader import DataLoader
|
||
|
from pytorch_pretrained_bert.file_utils import cached_path
|
||
|
import pdb
|
||
|
|
||
|
logger = logging.getLogger(__name__)

# Shortcut model name -> download URL of the matching pretrained BERT archive (.tar.gz).
PRETRAINED_MODEL_ARCHIVE_MAP = dict(
    [
        ("bert-base-uncased", "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz"),
        ("bert-large-uncased", "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz"),
        ("bert-base-cased", "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz"),
        ("bert-large-cased", "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz"),
        ("bert-base-multilingual-uncased", "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz"),
        ("bert-base-multilingual-cased", "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz"),
        ("bert-base-chinese", "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz"),
    ]
)
|
||
|
|
||
|
def load_tf_weights_in_bert(model, tf_checkpoint_path):
    """Load the variables of a TensorFlow BERT checkpoint into a PyTorch model in place.

    Each TF variable name is split on ``/`` and walked attribute by attribute
    through ``model``. A scope of the form ``name_<n>`` indexes element ``n`` of
    the attribute ``name``. TF ``kernel``/``gamma``/``output_weights`` map to
    ``weight``, and ``beta``/``output_bias`` map to ``bias``. ``*_embeddings``
    variables map to the embedding's ``weight``, and ``kernel`` arrays are
    transposed before copying.

    Optimizer bookkeeping variables (``adam_m``, ``adam_v``,
    ``AdamWeightDecayOptimizer*``, ``global_step``) have no counterpart in the
    model and are skipped.

    Args:
        model: the PyTorch module whose parameters are overwritten.
        tf_checkpoint_path: path of the TensorFlow checkpoint.

    Returns:
        ``model`` (the same object, updated in place).

    Raises:
        ImportError: if TensorFlow (or numpy) is not installed.
        AttributeError: if a checkpoint variable has no matching attribute in ``model``.
        AssertionError: if a checkpoint array's shape differs from the target
            parameter's shape. Its ``args`` are ``(parameter_shape, array_shape)``.
    """
    try:
        import re
        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise

    # Variables written by the TF optimizer / training loop, not model weights.
    skipped_scopes = frozenset(
        ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
    )
    indexed_scope = re.compile(r"[A-Za-z]+_\d+")

    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
    # Load weights from TF model
    names = []
    arrays = []
    for name, shape in tf.train.list_variables(tf_path):
        logger.info("Loading TF weight {} with shape {}".format(name, shape))
        names.append(name)
        arrays.append(tf.train.load_variable(tf_path, name))

    for name, array in zip(names, arrays):
        name = name.split("/")
        if any(n in skipped_scopes for n in name):
            logger.info("Skipping {}".format("/".join(name)))
            continue
        pointer = model
        for m_name in name:
            if indexed_scope.fullmatch(m_name):
                # "layer_3" -> ["layer", "3", ""]
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] in ("kernel", "gamma", "output_weights"):
                pointer = getattr(pointer, "weight")
            elif scope_names[0] in ("output_bias", "beta"):
                pointer = getattr(pointer, "bias")
            else:
                pointer = getattr(pointer, scope_names[0])
            if len(scope_names) >= 2:
                pointer = pointer[int(scope_names[1])]
        # m_name is the last path component here.
        if m_name.endswith("_embeddings"):
            pointer = getattr(pointer, "weight")
        elif m_name == "kernel":
            # TF dense kernels are stored (in, out); nn.Linear.weight is (out, in).
            array = np.transpose(array)
        if pointer.shape != array.shape:
            # Explicit raise so the check is not stripped under `python -O`; the
            # args match the previous assert-based check.
            raise AssertionError(pointer.shape, array.shape)
        logger.info("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)
    return model
|
||
|
|
||
|
class GeLU(nn.Module):
    """Module form of the exact (erf-based) GELU activation.

    Computes ``0.5 * x * (1 + erf(x / sqrt(2)))`` element-wise. OpenAI GPT's gelu
    uses the tanh approximation instead (slightly different results):
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    Also see https://arxiv.org/abs/1606.08415
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Apply GELU to ``x`` element-wise."""
        erf_term = 1.0 + torch.erf(x / math.sqrt(2.0))
        half_x = x * 0.5
        return half_x * erf_term
|
||
|
|
||
|
|
||
|
def gelu(x):
    """Exact (erf-based) GELU activation, applied element-wise.

    Returns ``x * 0.5 * (1 + erf(x / sqrt(2)))``. OpenAI GPT's gelu is the tanh
    approximation and gives slightly different results:
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    Also see https://arxiv.org/abs/1606.08415
    """
    gate = torch.erf(x / math.sqrt(2.0)).add(1.0)
    return x.mul(0.5).mul(gate)
|
||
|
|
||
|
|
||
|
def swish(x):
    """Swish activation: ``x * sigmoid(x)``, applied element-wise."""
    gate = torch.sigmoid(x)
    return x * gate
|
||
|
|
||
|
|
||
|
# Activation lookup used when a config stores the activation as a string
# (hidden_act, v_hidden_act, gnn_act). "GeLU" is one shared nn.Module instance;
# the other entries are plain functions.
ACT2FN = dict(
    GeLU=GeLU(),
    gelu=gelu,
    relu=torch.nn.functional.relu,
    swish=swish,
)
|
||
|
|
||
|
class BertConfig(object):
    """Configuration class to store the configuration of a `BertModel`.

    Besides the standard BERT text-stream settings it holds the visual-stream
    settings (``v_*``), the ``bi_*`` / ``*_biattention_id`` settings, and a few
    model-variant flags.
    """

    def __init__(
        self,
        vocab_size_or_config_json_file,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        v_feature_size=2048,
        v_target_size=1601,
        v_hidden_size=768,
        v_num_hidden_layers=3,
        v_num_attention_heads=12,
        v_intermediate_size=3072,
        bi_hidden_size=1024,
        bi_num_attention_heads=16,
        v_attention_probs_dropout_prob=0.1,
        v_hidden_act="gelu",
        v_hidden_dropout_prob=0.1,
        v_initializer_range=0.2,
        v_biattention_id=None,
        t_biattention_id=None,
        predict_feature=False,
        fast_mode=False,
        fixed_v_layer=0,
        fixed_t_layer=0,
        in_batch_pairs=False,
        fusion_method="mul",
        intra_gate=False,
        with_coattention=True,
    ):
        """Constructs BertConfig.

        Args:
            vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`,
                or the path to a JSON config file. With a path, only the keys present
                in the file are set as attributes; the keyword arguments are ignored
                (though the biattention ids are still validated).
            hidden_size: Size of the encoder layers and the pooler layer.
            num_hidden_layers: Number of hidden layers in the Transformer encoder.
            num_attention_heads: Number of attention heads for each attention layer in
                the Transformer encoder.
            intermediate_size: The size of the "intermediate" (i.e., feed-forward)
                layer in the Transformer encoder.
            hidden_act: The non-linear activation function (function or string) in the
                encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
            hidden_dropout_prob: The dropout probability for all fully connected
                layers in the embeddings, encoder, and pooler.
            attention_probs_dropout_prob: The dropout ratio for the attention
                probabilities.
            max_position_embeddings: The maximum sequence length that this model might
                ever be used with. Typically set this to something large just in case
                (e.g., 512 or 1024 or 2048).
            type_vocab_size: The vocabulary size of the `token_type_ids` passed into
                `BertModel`.
            initializer_range: The stddev of the truncated_normal_initializer for
                initializing all weight matrices.
            v_biattention_id: visual-stream layer indices (default ``[0, 1]``). Must
                have the same length as ``t_biattention_id``, and every index must be
                ``< v_num_hidden_layers``.
            t_biattention_id: text-stream layer indices (default ``[10, 11]``). Every
                index must be ``< num_hidden_layers``.

        Raises:
            AssertionError: if the biattention index lists fail the checks above.
            ValueError: if the first argument is neither an int nor a str path.
        """
        # None sentinels: a list default would be one object shared (and mutable)
        # across every BertConfig built without these arguments.
        if v_biattention_id is None:
            v_biattention_id = [0, 1]
        if t_biattention_id is None:
            t_biattention_id = [10, 11]

        assert len(v_biattention_id) == len(t_biattention_id)
        # Guard max(): empty lists (no biattention layers) are valid.
        if v_biattention_id:
            assert max(v_biattention_id) < v_num_hidden_layers
        if t_biattention_id:
            assert max(t_biattention_id) < num_hidden_layers

        if isinstance(vocab_size_or_config_json_file, str) or (
            sys.version_info[0] == 2
            and isinstance(vocab_size_or_config_json_file, unicode)
        ):
            with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
                json_config = json.loads(reader.read())
            for key, value in json_config.items():
                self.__dict__[key] = value
        elif isinstance(vocab_size_or_config_json_file, int):
            self.vocab_size = vocab_size_or_config_json_file
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.hidden_act = hidden_act
            self.intermediate_size = intermediate_size
            self.hidden_dropout_prob = hidden_dropout_prob
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.max_position_embeddings = max_position_embeddings
            self.type_vocab_size = type_vocab_size
            self.initializer_range = initializer_range
            self.v_feature_size = v_feature_size
            self.v_hidden_size = v_hidden_size
            self.v_num_hidden_layers = v_num_hidden_layers
            self.v_num_attention_heads = v_num_attention_heads
            self.v_intermediate_size = v_intermediate_size
            self.v_attention_probs_dropout_prob = v_attention_probs_dropout_prob
            self.v_hidden_act = v_hidden_act
            self.v_hidden_dropout_prob = v_hidden_dropout_prob
            self.v_initializer_range = v_initializer_range
            self.v_biattention_id = v_biattention_id
            self.t_biattention_id = t_biattention_id
            self.v_target_size = v_target_size
            self.bi_hidden_size = bi_hidden_size
            self.bi_num_attention_heads = bi_num_attention_heads
            self.predict_feature = predict_feature
            self.fast_mode = fast_mode
            self.fixed_v_layer = fixed_v_layer
            self.fixed_t_layer = fixed_t_layer

            self.in_batch_pairs = in_batch_pairs
            self.fusion_method = fusion_method
            self.intra_gate = intra_gate
            self.with_coattention = with_coattention
        else:
            raise ValueError(
                "First argument must be either a vocabulary size (int)"
                "or the path to a pretrained model config file (str)"
            )

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a `BertConfig` (or subclass) from a Python dictionary of parameters.

        Starts from the default configuration (vocab size -1) and overwrites or
        adds one attribute per dictionary key.
        """
        config = cls(vocab_size_or_config_json_file=-1)
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a `BertConfig` from a json file of parameters."""
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary (a deep copy of its attributes)."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string (sorted keys, 2-space indent, trailing newline)."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
|
||
|
|
||
|
# NOTE: there is deliberately no optional fused/torch LayerNorm import here. The
# previous attempt, ``import torch.nn.LayerNorm as BertLayerNorm``, could never
# succeed (``LayerNorm`` is a class, not a submodule, so Python raised
# ModuleNotFoundError), and any name it bound would have been shadowed by the
# ``class BertLayerNorm`` definition that follows.
# If apex is wanted, import FusedLayerNorm under a different name and select it explicitly.
|
||
|
|
||
|
class BertLayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-12):
        """Build a TF-style layer norm (epsilon added inside the square root).

        Args:
            hidden_size: length of the last dimension being normalised, and of the
                learnable ``weight`` (initialised to 1) and ``bias`` (initialised to 0).
            eps: constant added to the variance before taking the square root.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        """Normalise ``x`` over its last dimension, then scale by ``weight`` and shift by ``bias``."""
        mean = x.mean(dim=-1, keepdim=True)
        centered = x - mean
        variance = centered.pow(2).mean(dim=-1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.variance_epsilon)
        return self.weight * normalized + self.bias
|
||
|
|
||
|
class BertEmbeddingsDialog(nn.Module):
    """Word + position + token-type embeddings for dialog inputs.

    Token-type ids below ``config.type_vocab_size`` use ``token_type_embeddings``.
    Ids at or above it are shifted down by ``type_vocab_size`` and looked up in
    ``token_type_embeddings_extension``, which has 10 rows.
    """

    def __init__(self, config, device):
        super(BertEmbeddingsDialog, self).__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        # Fixed sinusoidal table (256 positions). NOTE(review): forward() below does
        # not use it. It is kept as a plain attribute (not a registered buffer) so
        # state_dict keys stay unchanged.
        self.pe = self._sinusoidal_table(256, config.hidden_size).to(device)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
        # add support for additional segment embeddings. Supporting 10 additional embedding as of now
        self.token_type_embeddings_extension = nn.Embedding(10, config.hidden_size)
        # adding specialized embeddings for sep tokens (not used in forward() here)
        self.sep_embeddings = nn.Embedding(50, config.hidden_size)
        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.config = config

    @staticmethod
    def _sinusoidal_table(max_seq_len, d_model):
        """Return a (max_seq_len, d_model) float32 sinusoidal position table.

        Even column ``i`` holds ``sin(pos / 10000 ** (2 * i / d_model))``. Odd column
        ``i + 1`` holds ``cos(pos / 10000 ** (2 * (i + 1) / d_model))``.

        These are exactly the values of the original nested Python loop, computed
        vectorised (float64, then cast). Odd ``d_model`` is handled, where the
        loop raised IndexError. NOTE(review): the exponents use the absolute
        column index rather than the pair index of Vaswani et al.; this is kept
        so the values do not change.
        """
        pe = torch.zeros(max_seq_len, d_model)
        pos = torch.arange(max_seq_len, dtype=torch.float64).unsqueeze(1)
        even_cols = torch.arange(0, d_model, 2, dtype=torch.float64)
        pe[:, 0::2] = torch.sin(pos / (10000 ** ((2 * even_cols) / d_model)))
        n_odd = pe[:, 1::2].size(1)
        pe[:, 1::2] = torch.cos(pos / (10000 ** ((2 * (even_cols[:n_odd] + 1)) / d_model)))
        return pe

    def forward(self, input_ids, sep_indices=None, sep_len=None, token_type_ids=None):
        """Embed ``input_ids``.

        Args:
            input_ids: token ids; dim 1 is the sequence length.
            sep_indices, sep_len: accepted for interface compatibility; unused here.
            token_type_ids: segment ids shaped like ``input_ids`` (default all zeros).
                Ids >= ``config.type_vocab_size`` select the extension table.

        Returns:
            dropout(LayerNorm(word + position + token-type embeddings)).
        """
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)

        # Split segment ids into the base range [0, type_vocab_size) and the
        # extension range [type_vocab_size, ...). Ids belonging to the other table
        # are zeroed so both lookups stay in bounds; the masks then keep exactly
        # one of the two embeddings per position. (The former assert that the two
        # masks sum to numel was always true, since the masks are complementary,
        # and it forced a host/device sync every call.)
        token_type_ids_extension = token_type_ids - self.config.type_vocab_size
        token_type_ids_extension_mask = (token_type_ids_extension >= 0).float()
        token_type_ids_extension = (token_type_ids_extension.float() * token_type_ids_extension_mask).long()

        token_type_ids_mask = (token_type_ids < self.config.type_vocab_size).float()
        token_type_ids = (token_type_ids.float() * token_type_ids_mask).long()

        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        token_type_embeddings_extension = self.token_type_embeddings_extension(token_type_ids_extension)

        token_type_embeddings = (token_type_embeddings * token_type_ids_mask.unsqueeze(-1)) + \
            (token_type_embeddings_extension * token_type_ids_extension_mask.unsqueeze(-1))

        embeddings = words_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
|
||
|
|
||
|
class BertSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over the text stream.

    ``forward`` returns ``(context, attention_probs)``: the attended states with
    heads merged back into ``all_head_size`` features, and the post-dropout
    attention weights laid out as (batch, heads, query, key).
    """

    def __init__(self, config):
        super().__init__()
        heads = config.num_attention_heads
        if config.hidden_size % heads:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )
        self.num_attention_heads = heads
        self.attention_head_size = int(config.hidden_size / heads)
        self.all_head_size = heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Split the last dim into (heads, head_size) and move heads to dim 1."""
        *leading, _ = x.size()
        split = x.view(*leading, self.num_attention_heads, self.attention_head_size)
        return split.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask):
        q, k, v = (
            self.transpose_for_scores(projection(hidden_states))
            for projection in (self.query, self.key, self.value)
        )

        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.attention_head_size)
        # Additive mask, precomputed once by the caller for all layers.
        scores = scores + attention_mask

        # Dropout on the weights drops whole attend-to tokens, as in the original Transformer.
        probs = self.dropout(scores.softmax(dim=-1))

        context = torch.matmul(probs, v).permute(0, 2, 1, 3).contiguous()
        context = context.view(*context.size()[:-2], self.all_head_size)
        return context, probs
|
||
|
|
||
|
class BertSelfOutput(nn.Module):
    """Post-attention block: dense projection, dropout, then residual add + LayerNorm."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.LayerNorm = BertLayerNorm(width, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
|
||
|
|
||
|
class BertAttention(nn.Module):
    """Full self-attention sub-layer: BertSelfAttention followed by BertSelfOutput."""

    def __init__(self, config):
        super().__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)

    def forward(self, input_tensor, attention_mask):
        """Return ``(attention_output, attention_probs)``; the residual uses ``input_tensor``."""
        context, probs = self.self(input_tensor, attention_mask)
        return self.output(context, input_tensor), probs
|
||
|
|
||
|
|
||
|
class BertIntermediate(nn.Module):
    """Feed-forward expansion: Linear(hidden_size -> intermediate_size) followed by the activation.

    ``config.hidden_act`` may be a name looked up in ``ACT2FN`` or a callable used as-is.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        act = config.hidden_act
        given_by_name = isinstance(act, str) or (
            sys.version_info[0] == 2 and isinstance(act, unicode)
        )
        self.intermediate_act_fn = ACT2FN[act] if given_by_name else act

    def forward(self, hidden_states):
        return self.intermediate_act_fn(self.dense(hidden_states))
|
||
|
|
||
|
|
||
|
class BertOutput(nn.Module):
    """Feed-forward contraction: Linear(intermediate -> hidden), dropout, residual add + LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
        contracted = self.dense(hidden_states)
        contracted = self.dropout(contracted)
        return self.LayerNorm(contracted + input_tensor)
|
||
|
|
||
|
|
||
|
class BertLayer(nn.Module):
    """One Transformer encoder layer: self-attention, then the feed-forward block."""

    def __init__(self, config):
        super().__init__()
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, hidden_states, attention_mask):
        """Return ``(layer_output, attention_probs)``."""
        attended, probs = self.attention(hidden_states, attention_mask)
        expanded = self.intermediate(attended)
        return self.output(expanded, attended), probs
|
||
|
|
||
|
|
||
|
class TextGraphLayer(nn.Module):
    """Refine text hidden states with a history graph and a question graph.

    Per sample:
    - History graph: the nodes are the hidden states at the history separator
      positions (``h_sep_indices``), plus one hub node projected from ``v_hub``.
    - Question graph: the nodes are the question tokens in
      ``[q_limits[i][0], q_limits[i][1])``, plus the attention-pooled history
      summary as a hub node.

    Both graphs run through stacks of PairNorm -> GATv2Conv -> activation. The
    refined (non-hub) node features are written back into a detached copy of
    the hidden states, and the result is averaged with the original states.
    """

    def __init__(self, config):
        super(TextGraphLayer, self).__init__()
        self.config = config
        self.gnn_act = ACT2FN[config.gnn_act]

        self.num_q_gnn_layers = config.num_q_gnn_layers
        self.num_h_gnn_layers = config.num_h_gnn_layers
        # Each GAT head emits hidden_size // heads features; heads are concatenated.
        head_dim = config.hidden_size // config.num_gnn_attention_heads

        # Question graph: GATv2 layers that also consume edge features.
        self.q_gnn_layers = nn.ModuleList(
            pyg_nn.GATv2Conv(
                config.hidden_size, head_dim,
                config.num_gnn_attention_heads,
                dropout=config.gnn_dropout_prob,
                edge_dim=config.q_gnn_edge_dim,
                concat=True,
            )
            for _ in range(self.num_q_gnn_layers)
        )
        # One PairNorm per graph layer, applied before it in forward().
        self.q_gnn_norm_layers = nn.ModuleList(pyg_nn.PairNorm() for _ in range(self.num_q_gnn_layers))

        # History graph: GATv2 layers without edge features.
        self.h_gnn_layers = nn.ModuleList(
            pyg_nn.GATv2Conv(
                config.hidden_size, head_dim,
                config.num_gnn_attention_heads,
                dropout=config.gnn_dropout_prob,
                concat=True,
            )
            for _ in range(self.num_h_gnn_layers)
        )
        self.h_gnn_norm_layers = nn.ModuleList(pyg_nn.PairNorm() for _ in range(self.num_h_gnn_layers))

        # Projects v_hub (v_hidden_size) into the text hidden space.
        self.h_gnn_dense_hub = nn.Linear(config.v_hidden_size, config.hidden_size)
        self.h_gnn_layer_norm_hub = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.h_gnn_dropout_hub = nn.Dropout(config.gnn_dropout_prob)

        self.q_gnn_pooling = pyg_nn.GlobalAttention(self._pooling_gate(config))
        self.h_gnn_pooling = pyg_nn.GlobalAttention(self._pooling_gate(config))

    @staticmethod
    def _pooling_gate(config):
        """Gate network for GlobalAttention: Linear(hidden -> 1), GeLU, dropout."""
        return nn.Sequential(
            nn.Linear(config.hidden_size, 1),
            ACT2FN['GeLU'],
            nn.Dropout(config.gnn_dropout_prob),
        )

    @staticmethod
    def _per_sample(x):
        """Return ``x`` as a list of per-sample tensors (lists pass through; tensors are unstacked on dim 0)."""
        if isinstance(x, (list, tuple)):
            return list(x)
        return [t.squeeze(0) for t in torch.split(x, 1, dim=0)]

    def _propagate(self, x, edge_index, batch_idx, norm_layers, conv_layers, edge_attr=None):
        """Apply PairNorm -> GATv2Conv -> act(x) + x -> act for each layer pair."""
        for norm, conv in zip(norm_layers, conv_layers):
            x = norm(x, batch_idx)
            x = conv(x, edge_index, edge_attr=edge_attr)
            x = self.gnn_act(x) + x
            x = self.gnn_act(x)
        return x

    def forward(
        self, hidden_states, q_edge_indices, q_edge_attributes,
        q_limits, h_edge_indices, h_sep_indices, v_hub,
        len_q_gr=None, len_h_gr=None, len_h_sep=None):
        """Run the history and question GNNs and fuse their output into the text states.

        Args:
            hidden_states: text states, indexed below as (batch, seq, hidden).
            q_edge_indices, q_edge_attributes: per-sample question-graph edges and
                edge features, as lists or as tensors stacked (padded) on dim 0.
            q_limits: tensor whose row i gives the [start, end) question-token span.
            h_edge_indices: per-sample history-graph edges (list or stacked tensor).
            h_sep_indices: per-sample history separator positions (list or stacked tensor).
            v_hub: one row per sample (size ``config.v_hidden_size``) used as the
                history graph's hub node.
            len_q_gr, len_h_gr, len_h_sep: optional true (unpadded) counts of
                question edges, history edges and history separators, used to
                trim padded inputs.

        Returns:
            (final_hidden_states, h_hub, q_hub): the averaged text states and the
            attention-pooled history and question graph summaries.
        """
        device = hidden_states.device
        batch_size, _, hidden_size = hidden_states.size()
        # len() equals size(0) for tensors, so this covers list and tensor inputs.
        assert len(q_edge_indices) == len(q_edge_attributes) == q_limits.size(0) \
            == len(h_edge_indices) == len(h_sep_indices) == batch_size

        # Previously, list inputs passed the assertion but then crashed in torch.split.
        q_edge_indices = self._per_sample(q_edge_indices)
        q_edge_attributes = self._per_sample(q_edge_attributes)
        h_edge_indices = self._per_sample(h_edge_indices)
        h_sep_indices = self._per_sample(h_sep_indices)
        if len_q_gr is not None:
            q_edge_indices = [e[:, :l] for e, l in zip(q_edge_indices, len_q_gr)]
            q_edge_attributes = [a[:l, :] for a, l in zip(q_edge_attributes, len_q_gr)]
            h_edge_indices = [e[:, :l] for e, l in zip(h_edge_indices, len_h_gr)]
            h_sep_indices = [s[:l] for s, l in zip(h_sep_indices, len_h_sep)]
        # Edge indices and gather/scatter indices must be int64.
        q_edge_indices = [e.long() for e in q_edge_indices]
        h_edge_indices = [e.long() for e in h_edge_indices]
        h_sep_indices = [s.long() for s in h_sep_indices]

        # Detached working copy: node features are read from it, and refined
        # features are written back into it.
        gnn_hidden_states = hidden_states.clone().detach()

        # Extract the history and question node features (without the hub node).
        h_node_feats, q_node_feats = [], []
        h_sep_indices_extended, q_tok_indices_extended = [], []
        for i, (h_sep_idx, q_limit) in enumerate(zip(h_sep_indices, q_limits.tolist())):
            row = gnn_hidden_states[i]
            h_idx = h_sep_idx.unsqueeze(-1).repeat(1, hidden_size)
            h_sep_indices_extended.append(h_idx)
            h_node_feats.append(torch.gather(row, 0, h_idx))
            q_idx = torch.arange(q_limit[0], q_limit[1], device=device).unsqueeze(-1).repeat(1, hidden_size)
            q_tok_indices_extended.append(q_idx)
            q_node_feats.append(torch.gather(row, 0, q_idx))

        # Map v_hub into the text space and append it as the last history node.
        v_hub = self.h_gnn_dropout_hub(self.h_gnn_layer_norm_hub(self.h_gnn_dense_hub(v_hub)))
        h_node_feats = [torch.cat((h, hub), dim=0) for h, hub in zip(h_node_feats, torch.split(v_hub, 1, dim=0))]

        # History graph: the loader with batch_size=batch_size yields exactly one batch.
        hist_data = [Data(x=x, edge_index=idx) for x, idx in zip(h_node_feats, h_edge_indices)]
        hist_batch = next(iter(DataLoader(hist_data, batch_size=batch_size, shuffle=False)))
        h_gnn_batch_idx = hist_batch.batch
        x_h = self._propagate(
            hist_batch.x, hist_batch.edge_index, h_gnn_batch_idx,
            self.h_gnn_norm_layers, self.h_gnn_layers, edge_attr=None)
        h_hub = self.h_gnn_pooling(x_h, h_gnn_batch_idx)

        # The pooled history summary becomes the last question-graph node.
        q_node_feats = [torch.cat((q, hub), dim=0) for q, hub in zip(q_node_feats, torch.split(h_hub, 1, dim=0))]

        ques_data = [
            Data(x=x, edge_index=idx, edge_attr=attr)
            for x, idx, attr in zip(q_node_feats, q_edge_indices, q_edge_attributes)
        ]
        ques_batch = next(iter(DataLoader(ques_data, batch_size=batch_size, shuffle=False)))
        q_gnn_batch_idx = ques_batch.batch
        x_q = self._propagate(
            ques_batch.x, ques_batch.edge_index, q_gnn_batch_idx,
            self.q_gnn_norm_layers, self.q_gnn_layers, edge_attr=ques_batch.edge_attr)
        q_hub = self.q_gnn_pooling(x_q, q_gnn_batch_idx)

        # Back to per-sample node features.
        h_node_feats = to_data_list(x_h, h_gnn_batch_idx)
        q_node_feats = to_data_list(x_q, q_gnn_batch_idx)

        # Write refined token features back; [:-1] drops the hub node appended last.
        # scatter_ must be in-place: the out-of-place `scatter` returns a new
        # tensor, which was previously discarded, so no GNN output reached the result.
        zipped_data = zip(h_node_feats, h_sep_indices_extended, q_node_feats, q_tok_indices_extended)
        for i, (h_node_feat, h_sep_idx, q_node_feat, q_tok_idx) in enumerate(zipped_data):
            gnn_hidden_states[i].scatter_(0, h_sep_idx, h_node_feat[:-1])
            gnn_hidden_states[i].scatter_(0, q_tok_idx, q_node_feat[:-1])

        final_hidden_states = 0.5 * (hidden_states + gnn_hidden_states)
        return final_hidden_states, h_hub, q_hub
|
||
|
|
||
|
|
||
|
class BertImageSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over the visual stream (``v_*`` settings).

    ``forward`` returns ``(context, attention_probs)``, with the weights laid out
    as (batch, heads, query, key).
    """

    def __init__(self, config):
        super().__init__()
        width, heads = config.v_hidden_size, config.v_num_attention_heads
        if width % heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.v_hidden_size, config.v_num_attention_heads)
            )
        self.num_attention_heads = heads
        self.attention_head_size = int(width / heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Created (and registered) in the order query, key, value.
        self.query, self.key, self.value = (
            nn.Linear(width, self.all_head_size) for _ in range(3)
        )

        self.dropout = nn.Dropout(config.v_attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Reshape (..., all_head_size) into heads, then move the head axis to dim 1."""
        return x.view(x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)).permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask):
        query = self.transpose_for_scores(self.query(hidden_states))
        key = self.transpose_for_scores(self.key(hidden_states))
        value = self.transpose_for_scores(self.value(hidden_states))

        scale = math.sqrt(self.attention_head_size)
        logits = torch.matmul(query, key.transpose(-1, -2)) / scale
        logits = logits + attention_mask  # additive mask prepared by the caller

        # Dropout on attention weights drops whole attend-to regions (original Transformer recipe).
        weights = self.dropout(torch.softmax(logits, dim=-1))

        merged = torch.matmul(weights, value).permute(0, 2, 1, 3).contiguous()
        merged = merged.view(merged.size()[:-2] + (self.all_head_size,))
        return merged, weights
|
||
|
|
||
|
class BertImageSelfOutput(nn.Module):
    """Visual post-attention block: dense projection, dropout, residual add + LayerNorm."""

    def __init__(self, config):
        super().__init__()
        width = config.v_hidden_size
        self.dense = nn.Linear(width, width)
        self.LayerNorm = BertLayerNorm(width, eps=1e-12)
        self.dropout = nn.Dropout(config.v_hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
        projected = self.dropout(self.dense(hidden_states))
        residual_sum = projected + input_tensor
        return self.LayerNorm(residual_sum)
|
||
|
|
||
|
class BertImageAttention(nn.Module):
    """Visual attention sub-layer: self-attention followed by its output projection.

    ``forward`` returns ``(attention_output, attention_probs)``, where the probs
    come straight from the self-attention module.
    """

    def __init__(self, config):
        super(BertImageAttention, self).__init__()
        self.self = BertImageSelfAttention(config)
        self.output = BertImageSelfOutput(config)

    def forward(self, input_tensor, attention_mask):
        context, probs = self.self(input_tensor, attention_mask)
        # The output block adds the original input back as a residual.
        return self.output(context, input_tensor), probs
|
||
|
|
||
|
|
||
|
class BertImageIntermediate(nn.Module):
    """Feed-forward expansion for the visual stream.

    Projects ``v_hidden_size -> v_intermediate_size`` and applies
    ``config.v_hidden_act``. A string (or, under Python 2, a ``unicode``) is
    looked up in ``ACT2FN``; any other value is used directly as the callable.
    """

    def __init__(self, config):
        super(BertImageIntermediate, self).__init__()
        self.dense = nn.Linear(config.v_hidden_size, config.v_intermediate_size)
        activation = config.v_hidden_act
        is_named = isinstance(activation, str) or (
            sys.version_info[0] == 2 and isinstance(activation, unicode)  # noqa: F821 (Py2 only)
        )
        self.intermediate_act_fn = ACT2FN[activation] if is_named else activation

    def forward(self, hidden_states):
        return self.intermediate_act_fn(self.dense(hidden_states))
|
||
|
|
||
|
|
||
|
class BertImageOutput(nn.Module):
    """Down-projection of the visual feed-forward block with a residual norm.

    ``dense`` maps ``v_intermediate_size -> v_hidden_size``. Dropout is applied
    to the projection, which is then added to ``input_tensor`` and normalised
    with ``BertLayerNorm``.
    """

    def __init__(self, config):
        super(BertImageOutput, self).__init__()
        self.dense = nn.Linear(config.v_intermediate_size, config.v_hidden_size)
        self.LayerNorm = BertLayerNorm(config.v_hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.v_hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        residual_branch = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(residual_branch + input_tensor)
|
||
|
|
||
|
|
||
|
class BertImageLayer(nn.Module):
    """One visual-stream transformer layer: attention, then feed-forward.

    ``forward`` returns ``(layer_output, attention_probs)``.
    """

    def __init__(self, config):
        super(BertImageLayer, self).__init__()
        self.attention = BertImageAttention(config)
        self.intermediate = BertImageIntermediate(config)
        self.output = BertImageOutput(config)

    def forward(self, hidden_states, attention_mask):
        attended, probs = self.attention(hidden_states, attention_mask)
        expanded = self.intermediate(attended)
        # The feed-forward output block uses the attention output as its residual.
        return self.output(expanded, attended), probs
|
||
|
|
||
|
class ImageGraphLayer(nn.Module):
    """Refine visual region features with a stack of graph-attention layers.

    For every example, the region features ``hidden_states[:, 1:]`` (position 0,
    the [IMG] token, is excluded) plus one extra "hub" node projected from
    ``hub_states`` form one graph, with edges / edge features taken from
    ``edge_indices`` / ``edge_attributes``. The graphs are batched and passed
    through ``num_v_gnn_layers`` rounds of PairNorm -> GATv2Conv -> activation.

    Raises:
        ValueError: if at least one GNN layer is configured and ``v_hidden_size``
            is not a multiple of ``num_gnn_attention_heads``.
    """

    def __init__(self, config):
        super(ImageGraphLayer, self).__init__()
        self.num_gnn_layers = config.num_v_gnn_layers

        # Each GATv2Conv concatenates `num_gnn_attention_heads` heads of width
        # `v_hidden_size // num_gnn_attention_heads`. The next layer and
        # forward()'s final `view` both expect exactly `v_hidden_size`
        # channels, so the division must be exact. Fail fast here instead of
        # with an opaque shape error at run time.
        if (
            self.num_gnn_layers > 0
            and config.v_hidden_size % config.num_gnn_attention_heads != 0
        ):
            raise ValueError(
                "The visual hidden size (%d) is not a multiple of the number of "
                "GNN attention heads (%d)"
                % (config.v_hidden_size, config.num_gnn_attention_heads)
            )

        self.config = config
        self.gnn_act = ACT2FN[config.gnn_act]

        # One PairNorm per graph layer; forward() applies it *before* the
        # corresponding convolution.
        gnn_layers = []
        gnn_norm_layers = []
        for _ in range(self.num_gnn_layers):
            gnn_layers.append(
                pyg_nn.GATv2Conv(
                    config.v_hidden_size, config.v_hidden_size // config.num_gnn_attention_heads,
                    config.num_gnn_attention_heads,
                    dropout=config.gnn_dropout_prob,
                    edge_dim=config.v_gnn_edge_dim,
                    concat=True
                )
            )
            gnn_norm_layers.append(pyg_nn.PairNorm())

        self.gnn_layers = nn.ModuleList(gnn_layers)
        self.gnn_norm_layers = nn.ModuleList(gnn_norm_layers)

        # Projects hub states from `hidden_size` to the visual width
        # (presumably a text-stream summary — confirm against the caller).
        self.gnn_dense_hub = nn.Linear(config.hidden_size, config.v_hidden_size)
        self.gnn_layer_norm_hub = BertLayerNorm(config.v_hidden_size, eps=1e-12)
        self.gnn_dropout_hub = nn.Dropout(config.gnn_dropout_prob)

        # Gate network scoring each node for attention pooling over a graph.
        dense_pooling = nn.Sequential(
            nn.Linear(config.v_hidden_size, 1),
            ACT2FN['GeLU'],
            nn.Dropout(config.gnn_dropout_prob)
        )
        self.gnn_pooling = pyg_nn.GlobalAttention(dense_pooling)

    def forward(
            self, hidden_states, edge_indices, edge_attributes, hub_states,
            len_img_gr=None):
        """Run the graph layers over the visual regions of each example.

        Args:
            hidden_states: tensor of shape (batch, num_img_reg, v_hidden_size).
                Position 0 (the [IMG] token) is not fed to the GNN.
            edge_indices: per-example edge indices. When ``len_img_gr`` is
                given, this must be a tensor that is split along dim 0 and
                truncated to the first ``len_img_gr[b]`` columns, then cast to
                long. Otherwise it is iterated per example as-is.
            edge_attributes: per-example edge features. When ``len_img_gr``
                is given, it is split along dim 0 and truncated to the first
                ``len_img_gr[b]`` rows.
            hub_states: tensor of shape (batch, hidden_size). It is projected to
                ``v_hidden_size`` and appended as the last node of each graph.
                Required: it is used unconditionally.
            len_img_gr: optional per-example lengths, presumably the number of
                valid (non-padding) edges — confirm against the data pipeline.

        Returns:
            ``(final_hidden_states, v_hub)``.

            ``final_hidden_states`` has the shape of ``hidden_states`` and is
            the mean of the input and the GNN-refined region features.
            Position 0 is averaged with a detached copy of itself.

            ``v_hub`` is the node representation pooled per graph by
            ``self.gnn_pooling``.
        """
        # Both copies are detached, so gradients from the GNN branch do not
        # flow back into `hidden_states` (only the 0.5 * hidden_states term
        # below carries them).
        # NOTE(review): confirm the detach is intentional; `hub_states` is not
        # detached, so gradients do reach whatever produced it.
        gnn_hidden_states = hidden_states.clone().detach()
        batch_size, num_img_reg, v_hidden_size = hidden_states.size()
        node_feats = hidden_states.clone().detach()
        # Remove the [IMG] feats (position 0), then split into one chunk of
        # shape (1, num_img_reg - 1, v_hidden_size) per example.
        node_feats = node_feats[:, 1:]
        node_feats = torch.split(node_feats, 1, dim=0)

        if len_img_gr is not None:
            # Keep only the first `l` edges of each example (presumably
            # trimming padding).
            edge_indices = [t.squeeze(0)[:, :l].long() for t, l in zip(torch.split(edge_indices, 1, dim=0), len_img_gr)]
            edge_attributes = [t.squeeze(0)[:l, :] for t, l in zip(torch.split(edge_attributes, 1, dim=0), len_img_gr)]

        # Project the hub states and append one hub node to the end of each
        # example's node list, giving num_img_reg nodes per graph.
        hub_states = self.gnn_dense_hub(hub_states)
        hub_states = self.gnn_dropout_hub(hub_states)
        hub_states = self.gnn_layer_norm_hub(hub_states)

        hub_states = torch.split(hub_states, 1, dim=0)
        node_feats = [torch.cat((x.squeeze(0), h), dim=0)
                      for x, h in zip(node_feats, hub_states)]

        pg_data = [Data(x, idx, attr) for x, idx, attr in zip(
            node_feats, edge_indices, edge_attributes)]
        # batch_size equals the number of graphs, so the loader yields a single
        # batch and the loop below runs once; `x` / `gnn_batch_idx` used after
        # the loop refer to that batch.
        pg_dataloader = DataLoader(
            pg_data, batch_size=batch_size, shuffle=False)
        # Gnn forward pass
        for data in pg_dataloader:
            x, edge_index, edge_attr, gnn_batch_idx = data.x, data.edge_index, data.edge_attr, data.batch
            for i in range(self.num_gnn_layers):
                # Normalization (before the convolution)
                x = self.gnn_norm_layers[i](x, gnn_batch_idx)
                # GNN propagation
                x = self.gnn_layers[i](x, edge_index, edge_attr=edge_attr)
                # Activation with a residual around it, then activate again
                x = self.gnn_act(x) + x
                x = self.gnn_act(x)

        # Pool all nodes (regions + hub) of each graph into one vector.
        v_hub = self.gnn_pooling(x, gnn_batch_idx)

        # Reshape the GNN output to (batch_size, num_img_reg, v_hidden_size),
        # drop the trailing hub node and write the refined regions back to
        # positions 1: (position 0 keeps its detached input copy).
        x = x.view(batch_size, num_img_reg, v_hidden_size)
        gnn_hidden_states[:, 1:, :] = x[:, :-1, :]

        final_hidden_states = 0.5 * (hidden_states + gnn_hidden_states)

        return final_hidden_states, v_hub
|
||
|
|
||
|
|
||
|
class BertBiAttention(nn.Module):
    """Bidirectional co-attention between two input streams.

    Stream 1 (``input_tensor1``, width ``v_hidden_size``, the vision input) and
    stream 2 (``input_tensor2``, width ``hidden_size``, the text input) are each
    projected to ``bi_hidden_size`` and split into ``bi_num_attention_heads``
    heads. Each stream's queries attend over the other stream's keys/values:

    * ``context_layer1``: stream-2 queries over stream-1 keys/values, so it has
      stream 2's sequence length.
    * ``context_layer2``: stream-1 queries over stream-2 keys/values, so it has
      stream 1's sequence length.

    Raises:
        ValueError: if ``bi_hidden_size`` is not a multiple of
            ``bi_num_attention_heads``.
    """

    def __init__(self, config):
        super(BertBiAttention, self).__init__()
        heads = config.bi_num_attention_heads
        if config.bi_hidden_size % heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.bi_hidden_size, config.bi_num_attention_heads)
            )

        self.num_attention_heads = heads
        self.attention_head_size = int(config.bi_hidden_size / heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Stream-1 (vision) projections and attention dropout.
        self.query1 = nn.Linear(config.v_hidden_size, self.all_head_size)
        self.key1 = nn.Linear(config.v_hidden_size, self.all_head_size)
        self.value1 = nn.Linear(config.v_hidden_size, self.all_head_size)
        self.dropout1 = nn.Dropout(config.v_attention_probs_dropout_prob)

        # Stream-2 (text) projections and attention dropout.
        self.query2 = nn.Linear(config.hidden_size, self.all_head_size)
        self.key2 = nn.Linear(config.hidden_size, self.all_head_size)
        self.value2 = nn.Linear(config.hidden_size, self.all_head_size)
        self.dropout2 = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Split the last dim into heads: (batch, seq, all) -> (batch, heads, seq, head_size)."""
        split = x.view(x.size()[:-1] + (self.num_attention_heads, self.attention_head_size))
        return split.permute(0, 2, 1, 3)

    def _merge_heads(self, x):
        """Inverse of ``transpose_for_scores``: (batch, heads, seq, head_size) -> (batch, seq, all)."""
        x = x.permute(0, 2, 1, 3).contiguous()
        return x.view(x.size()[:-2] + (self.all_head_size,))

    def _attend(self, query, key, value, attention_mask, extra_mask, dropout):
        """Masked scaled dot-product attention; returns (merged context, post-dropout probs)."""
        scores = torch.matmul(query, key.transpose(-1, -2))
        scores = scores / math.sqrt(self.attention_head_size)
        scores = scores + attention_mask
        if extra_mask is not None:
            scores = scores + extra_mask
        # Dropout on the normalised weights drops whole attended-to tokens, as
        # in the original Transformer paper.
        probs = dropout(torch.softmax(scores, dim=-1))
        return self._merge_heads(torch.matmul(probs, value)), probs

    def forward(self, input_tensor1, attention_mask1, input_tensor2, attention_mask2, co_attention_mask=None, use_co_attention_mask=False):
        """Return ``(context_layer1, context_layer2, (attention_probs1, attention_probs2))``.

        ``attention_mask1`` / ``attention_mask2`` are additive masks over the
        keys of stream 1 / stream 2. When ``use_co_attention_mask`` is true,
        ``co_attention_mask`` is added to the stream-1-query scores as-is, and
        to the stream-2-query scores with its last two dims swapped.
        """
        q1, k1, v1 = (
            self.transpose_for_scores(layer(input_tensor1))
            for layer in (self.query1, self.key1, self.value1)
        )
        q2, k2, v2 = (
            self.transpose_for_scores(layer(input_tensor2))
            for layer in (self.query2, self.key2, self.value2)
        )

        if use_co_attention_mask:
            co_mask_for_1 = co_attention_mask.permute(0, 1, 3, 2)
            co_mask_for_2 = co_attention_mask
        else:
            co_mask_for_1 = co_mask_for_2 = None

        # Stream-2 queries over stream-1 keys/values.
        context_layer1, attention_probs1 = self._attend(
            q2, k1, v1, attention_mask1, co_mask_for_1, self.dropout1
        )
        # Stream-1 queries over stream-2 keys/values.
        context_layer2, attention_probs2 = self._attend(
            q1, k2, v2, attention_mask2, co_mask_for_2, self.dropout2
        )

        return context_layer1, context_layer2, (attention_probs1, attention_probs2)
|
||
|
|
||
|
class BertBiOutput(nn.Module):
    """Output projections for the two co-attention contexts.

    Stream 1 is projected ``bi_hidden_size -> v_hidden_size`` and stream 2
    ``bi_hidden_size -> hidden_size``. Each projection goes through dropout,
    is added to its residual input and is normalised with ``BertLayerNorm``.

    ``q_dense1`` / ``q_dropout1`` / ``q_dense2`` / ``q_dropout2`` are created but
    never used by ``forward``. They are presumably kept so parameter names
    match existing checkpoints; confirm before removing.
    """

    def __init__(self, config):
        super(BertBiOutput, self).__init__()

        # Stream 1 (visual width).
        self.dense1 = nn.Linear(config.bi_hidden_size, config.v_hidden_size)
        self.LayerNorm1 = BertLayerNorm(config.v_hidden_size, eps=1e-12)
        self.dropout1 = nn.Dropout(config.v_hidden_dropout_prob)
        self.q_dense1 = nn.Linear(config.bi_hidden_size, config.v_hidden_size)
        self.q_dropout1 = nn.Dropout(config.v_hidden_dropout_prob)

        # Stream 2 (text width).
        self.dense2 = nn.Linear(config.bi_hidden_size, config.hidden_size)
        self.LayerNorm2 = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout2 = nn.Dropout(config.hidden_dropout_prob)
        self.q_dense2 = nn.Linear(config.bi_hidden_size, config.hidden_size)
        self.q_dropout2 = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states1, input_tensor1, hidden_states2, input_tensor2):
        projected1 = self.dropout1(self.dense1(hidden_states1))
        projected2 = self.dropout2(self.dense2(hidden_states2))
        return (
            self.LayerNorm1(projected1 + input_tensor1),
            self.LayerNorm2(projected2 + input_tensor2),
        )
|
||
|
|
||
|
class BertConnectionLayer(nn.Module):
    """Cross-stream layer: co-attention, then a feed-forward block per stream.

    ``forward`` returns ``(layer_output1, layer_output2, co_attention_probs)``.
    Output 1 goes through the visual feed-forward modules and output 2 through
    the text ones.
    """

    def __init__(self, config):
        super(BertConnectionLayer, self).__init__()
        self.biattention = BertBiAttention(config)

        self.biOutput = BertBiOutput(config)

        self.v_intermediate = BertImageIntermediate(config)
        self.v_output = BertImageOutput(config)

        self.t_intermediate = BertIntermediate(config)
        self.t_output = BertOutput(config)

    def forward(self, input_tensor1, attention_mask1, input_tensor2, attention_mask2, co_attention_mask=None, use_co_attention_mask=False):
        # biattention's first context is aligned with stream 2's sequence and
        # its second with stream 1's, so they are swapped below to pair each
        # context with its own residual input.
        ctx_aligned_to_2, ctx_aligned_to_1, co_attention_probs = self.biattention(
            input_tensor1, attention_mask1, input_tensor2, attention_mask2, co_attention_mask, use_co_attention_mask
        )

        fused1, fused2 = self.biOutput(ctx_aligned_to_1, input_tensor1, ctx_aligned_to_2, input_tensor2)

        layer_output1 = self.v_output(self.v_intermediate(fused1), fused1)
        layer_output2 = self.t_output(self.t_intermediate(fused2), fused2)

        return layer_output1, layer_output2, co_attention_probs
|
||
|
|
||
|
class BertEncoder(nn.Module):
|
||
|
def __init__(self, config):
|
||
|
super(BertEncoder, self).__init__()
|
||
|
|
||
|
# in the bert encoder, we need to extract three things here.
|
||
|
# text bert layer: BertLayer
|
||
|
# vision bert layer: BertImageLayer
|
||
|
# Bi-Attention: Given the output of two bertlayer, perform bi-directional
|
||
|
# attention and add on two layers.
|
||
|
|
||
|
self.FAST_MODE = config.fast_mode
|
||
|
self.with_coattention = config.with_coattention
|
||
|
self.v_biattention_id = config.v_biattention_id
|
||
|
self.t_biattention_id = config.t_biattention_id
|
||
|
self.in_batch_pairs = config.in_batch_pairs
|
||
|
self.fixed_t_layer = config.fixed_t_layer
|
||
|
self.fixed_v_layer = config.fixed_v_layer
|
||
|
self.t_gnn_ids = config.t_gnn_ids
|
||
|
self.v_gnn_ids = config.v_gnn_ids
|
||
|
|
||
|
v_layer = BertImageLayer(config)
|
||
|
connect_layer = BertConnectionLayer(config)
|
||
|
|
||
|
self.layer = []
|
||
|
for _ in range(config.num_hidden_layers):
|
||
|
self.layer.append(BertLayer(config))
|
||
|
|
||
|
self.layer = nn.ModuleList(self.layer)
|
||
|
|
||
|
txt_graph_layer = TextGraphLayer(config)
|
||
|
self.t_gnns = nn.ModuleList([txt_graph_layer for _ in range(len(self.t_gnn_ids))])
|
||
|
|
||
|