VDGR/models/vilbert_dialog.py

2022 lines
86 KiB
Python
Raw Normal View History

2023-10-25 15:38:09 +02:00
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model."""
import copy
import json
import logging
import math
import os
import shutil
import tarfile
import tempfile
import sys
from io import open
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
from torch.nn.utils.weight_norm import weight_norm
from pytorch_transformers.modeling_bert import BertEmbeddings
from utils.data_utils import sequence_mask, to_data_list
import torch_geometric.nn as pyg_nn
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from pytorch_pretrained_bert.file_utils import cached_path
import pdb
logger = logging.getLogger(__name__)

# Download URLs of the stock BERT checkpoints, keyed by shortcut name. Every
# archive lives under the same S3 prefix and is named "<shortcut>.tar.gz".
PRETRAINED_MODEL_ARCHIVE_MAP = {
    model_name: "https://s3.amazonaws.com/models.huggingface.co/bert/" + model_name + ".tar.gz"
    for model_name in (
        "bert-base-uncased",
        "bert-large-uncased",
        "bert-base-cased",
        "bert-large-cased",
        "bert-base-multilingual-uncased",
        "bert-base-multilingual-cased",
        "bert-base-chinese",
    )
}
def load_tf_weights_in_bert(model, tf_checkpoint_path):
    """Load tf checkpoints in a pytorch model.

    Each checkpoint variable name is split on ``/`` and walked as an attribute
    path on ``model``:

    * ``kernel``, ``gamma`` and ``output_weights`` resolve to ``weight``;
    * ``output_bias`` and ``beta`` resolve to ``bias``;
    * a component of the form ``name_<k>`` selects element ``k`` of attribute ``name``.

    Adam optimizer slots (``adam_v`` / ``adam_m``) are skipped. Variables whose
    last component ends in ``_embeddings`` are written to that module's
    ``weight``, and ``kernel`` arrays are transposed before loading.

    Args:
        model: the PyTorch module whose parameters are overwritten in place.
        tf_checkpoint_path: path of the TensorFlow checkpoint to read.

    Returns:
        ``model`` itself, after loading.

    Raises:
        ImportError: if TensorFlow (or NumPy) is not installed.
        AttributeError: if a checkpoint variable has no matching attribute path.
        AssertionError: if a checkpoint array's shape differs from the target
            parameter's shape; ``args`` is ``(pointer.shape, array.shape)``.
    """
    try:
        import re
        import numpy as np
        import tensorflow as tf
    except ImportError:
        print(
            "Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    print("Converting TensorFlow checkpoint from {}".format(tf_path))
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        print("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        name = name.split("/")
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
        # which are not required for using the pretrained model
        if any(n in ["adam_v", "adam_m"] for n in name):
            print("Skipping {}".format("/".join(name)))
            continue
        pointer = model
        for m_name in name:
            # "layer_3" splits into ["layer", "3", ""]: an attribute name plus an index.
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] == "kernel" or scope_names[0] == "gamma":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "output_weights":
                pointer = getattr(pointer, "weight")
            else:
                pointer = getattr(pointer, scope_names[0])
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        if m_name[-11:] == "_embeddings":
            pointer = getattr(pointer, "weight")
        elif m_name == "kernel":
            # Transposed to match the target parameter layout (presumably TF dense
            # kernels are (in, out) while torch Linear weights are (out, in)).
            array = np.transpose(array)
        # Explicit check rather than `assert` so the guard still runs under
        # `python -O`. The AssertionError carries the same args the former
        # assert-and-re-raise produced.
        if pointer.shape != array.shape:
            raise AssertionError(pointer.shape, array.shape)
        print("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)
    return model
class GeLU(nn.Module):
    """Module form of the exact (erf-based) GELU activation.

    Computes ``x * 0.5 * (1 + erf(x / sqrt(2)))`` elementwise.
    OpenAI GPT uses the tanh approximation
    ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``,
    which gives slightly different results.
    See https://arxiv.org/abs/1606.08415.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        gaussian_cdf_term = 1.0 + torch.erf(x / math.sqrt(2.0))
        return x * 0.5 * gaussian_cdf_term
def gelu(x):
    """Apply the exact (erf-based) GELU activation elementwise.

    Returns ``x * 0.5 * (1 + erf(x / sqrt(2)))``. OpenAI GPT's variant is
    the tanh approximation
    ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``, which
    gives slightly different results.
    See https://arxiv.org/abs/1606.08415.
    """
    half_x = x * 0.5
    return half_x * (1.0 + torch.erf(x / math.sqrt(2.0)))
def swish(x):
    """Apply the swish activation ``x * sigmoid(x)`` elementwise."""
    gate = torch.sigmoid(x)
    return x * gate
# Lookup table from activation names (as stored in config fields such as
# `hidden_act` and `gnn_act`) to the callable that implements them.
ACT2FN = dict(
    GeLU=GeLU(),
    gelu=gelu,
    relu=torch.nn.functional.relu,
    swish=swish,
)
class BertConfig(object):
    """Configuration class to store the configuration of a `BertModel`.

    Holds the hyper-parameters of the text stream, the visual (``v_*``)
    stream, and the co-attention (``bi_*``) layers that connect them.
    """

    def __init__(
        self,
        vocab_size_or_config_json_file,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        v_feature_size=2048,
        v_target_size=1601,
        v_hidden_size=768,
        v_num_hidden_layers=3,
        v_num_attention_heads=12,
        v_intermediate_size=3072,
        bi_hidden_size=1024,
        bi_num_attention_heads=16,
        v_attention_probs_dropout_prob=0.1,
        v_hidden_act="gelu",
        v_hidden_dropout_prob=0.1,
        v_initializer_range=0.2,
        v_biattention_id=None,
        t_biattention_id=None,
        predict_feature=False,
        fast_mode=False,
        fixed_v_layer=0,
        fixed_t_layer=0,
        in_batch_pairs=False,
        fusion_method="mul",
        intra_gate=False,
        with_coattention=True
    ):
        """Constructs BertConfig.

        Args:
            vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in
                `BertModel` (int), or the path to a JSON config file (str). When
                a path is given, only the keys present in the file are set on
                the instance.
            hidden_size: Size of the encoder layers and the pooler layer.
            num_hidden_layers: Number of hidden layers in the Transformer encoder.
            num_attention_heads: Number of attention heads for each attention
                layer in the Transformer encoder.
            intermediate_size: The size of the "intermediate" (i.e., feed-forward)
                layer in the Transformer encoder.
            hidden_act: The non-linear activation function (function or string)
                in the encoder and pooler. If string, "gelu", "relu" and
                "swish" are supported.
            hidden_dropout_prob: The dropout probability for all fully connected
                layers in the embeddings, encoder, and pooler.
            attention_probs_dropout_prob: The dropout ratio for the attention
                probabilities.
            max_position_embeddings: The maximum sequence length that this model
                might ever be used with. Typically set this to something large
                just in case (e.g., 512 or 1024 or 2048).
            type_vocab_size: The vocabulary size of the `token_type_ids` passed
                into `BertModel`.
            initializer_range: The stddev of the truncated_normal_initializer for
                initializing all weight matrices.
            v_biattention_id: Visual layer indices that take part in
                co-attention. Defaults to ``[0, 1]``. Each entry must be
                ``< v_num_hidden_layers``.
            t_biattention_id: Text layer indices that take part in
                co-attention. Defaults to ``[10, 11]``. Each entry must be
                ``< num_hidden_layers``, and the list must have the same length
                as ``v_biattention_id``.

        Raises:
            AssertionError: if the bi-attention layer ids are inconsistent.
            ValueError: if the first argument is neither an int nor a str.
        """
        # Build fresh lists per call: a list literal used as a default would be
        # a single object shared (and mutable) across every config instance.
        if v_biattention_id is None:
            v_biattention_id = [0, 1]
        if t_biattention_id is None:
            t_biattention_id = [10, 11]
        assert len(v_biattention_id) == len(t_biattention_id)
        assert max(v_biattention_id) < v_num_hidden_layers
        assert max(t_biattention_id) < num_hidden_layers

        if isinstance(vocab_size_or_config_json_file, str) or (
            sys.version_info[0] == 2
            and isinstance(vocab_size_or_config_json_file, unicode)
        ):
            with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
                json_config = json.loads(reader.read())
            for key, value in json_config.items():
                self.__dict__[key] = value
        elif isinstance(vocab_size_or_config_json_file, int):
            self.vocab_size = vocab_size_or_config_json_file
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.hidden_act = hidden_act
            self.intermediate_size = intermediate_size
            self.hidden_dropout_prob = hidden_dropout_prob
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.max_position_embeddings = max_position_embeddings
            self.type_vocab_size = type_vocab_size
            self.initializer_range = initializer_range
            self.v_feature_size = v_feature_size
            self.v_hidden_size = v_hidden_size
            self.v_num_hidden_layers = v_num_hidden_layers
            self.v_num_attention_heads = v_num_attention_heads
            self.v_intermediate_size = v_intermediate_size
            self.v_attention_probs_dropout_prob = v_attention_probs_dropout_prob
            self.v_hidden_act = v_hidden_act
            self.v_hidden_dropout_prob = v_hidden_dropout_prob
            self.v_initializer_range = v_initializer_range
            self.v_biattention_id = v_biattention_id
            self.t_biattention_id = t_biattention_id
            self.v_target_size = v_target_size
            self.bi_hidden_size = bi_hidden_size
            self.bi_num_attention_heads = bi_num_attention_heads
            self.predict_feature = predict_feature
            self.fast_mode = fast_mode
            self.fixed_v_layer = fixed_v_layer
            self.fixed_t_layer = fixed_t_layer
            self.in_batch_pairs = in_batch_pairs
            self.fusion_method = fusion_method
            self.intra_gate = intra_gate
            self.with_coattention = with_coattention
        else:
            raise ValueError(
                "First argument must be either a vocabulary size (int) "
                "or the path to a pretrained model config file (str)"
            )

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a config of this class from a Python dictionary of parameters.

        Starts from the defaults (with a placeholder vocab size of -1) and then
        overwrites attributes with every key/value in ``json_object``.
        """
        config = cls(vocab_size_or_config_json_file=-1)
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a config from a json file of parameters."""
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary (deep copy)."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string with sorted keys."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
# Optional faster LayerNorm import, disabled in practice.
# NOTE(review): this import cannot succeed. `torch.nn.LayerNorm` is a class, not
# a module, so `import torch.nn.LayerNorm` raises ImportError
# (ModuleNotFoundError) and the `except` branch always runs. In any case, the
# `BertLayerNorm` class defined immediately after this block would rebind the
# name, so that TF-style implementation is the one the module uses.
try:
    # from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
    import torch.nn.LayerNorm as BertLayerNorm
except ImportError:
    # logger.info(
    #     "Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex ."
    # )
    pass
class BertLayerNorm(nn.Module):
    """LayerNorm over the last dimension, TF style (epsilon inside the square root).

    Normalises ``x`` to zero mean and unit (biased) variance along its last
    axis. It then applies a learnable elementwise scale ``weight``
    (initialised to 1) and shift ``bias`` (initialised to 0).
    """

    def __init__(self, hidden_size, eps=1e-12):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        centered = x - x.mean(-1, keepdim=True)
        variance = centered.pow(2).mean(-1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.variance_epsilon)
        return self.weight * normalized + self.bias
class BertEmbeddingsDialog(nn.Module):
    """Input embeddings for the dialog text stream.

    Sums word, learned absolute-position and token-type (segment) embeddings,
    then applies LayerNorm and dropout.

    Token-type ids are split by range:

    * ids in ``[0, config.type_vocab_size)`` use ``token_type_embeddings``;
    * ids from ``config.type_vocab_size`` upward use
      ``token_type_embeddings_extension``, indexed by
      ``id - config.type_vocab_size``. That table holds 10 rows.
    """
    def __init__(self, config, device):
        super(BertEmbeddingsDialog, self).__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        # Fixed sinusoidal table of shape (max_seq_len, d_model).
        # NOTE(review): `self.pe` is not read by forward() below. It is a plain
        # tensor attribute, not a registered buffer, so it is excluded from
        # state_dict and not moved by `module.to(...)`; presumably that is why
        # `device` is passed in explicitly.
        # NOTE(review): with `i` stepping by 2, the exponents below are
        # (2 * i) / d_model for sin and (2 * (i + 1)) / d_model for cos. The
        # standard Transformer encoding uses i / d_model for both entries of a
        # pair. The loop also writes pe[pos, i + 1], so it assumes an even d_model.
        max_seq_len = 256
        d_model = config.hidden_size
        pe = torch.zeros(max_seq_len, d_model)
        for pos in range(max_seq_len):
            for i in range(0, d_model, 2):
                pe[pos, i] = \
                    math.sin(pos / (10000 ** ((2 * i)/d_model)))
                pe[pos, i + 1] = \
                    math.cos(pos / (10000 ** ((2 * (i + 1))/d_model)))
        self.pe = pe.to(device)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
        # add support for additional segment embeddings. Supporting 10 additional embeddings as of now
        self.token_type_embeddings_extension = nn.Embedding(10,config.hidden_size)
        # adding specialized embeddings for sep tokens
        # NOTE(review): not used by forward() in this class.
        self.sep_embeddings = nn.Embedding(50,config.hidden_size)
        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.config = config

    def forward(self, input_ids, sep_indices=None, sep_len=None, token_type_ids=None):
        """Embed a batch of token ids.

        Args:
            input_ids: token ids with the sequence along dim 1 (presumably
                ``(batch, seq_len)``; TODO confirm against callers).
            sep_indices, sep_len: not used by this method.
            token_type_ids: optional segment ids, presumably the same shape as
                ``input_ids``. Defaults to all zeros. Values must lie in
                ``[0, config.type_vocab_size + 10)`` for both embedding lookups
                to stay in range.

        Returns:
            Summed embeddings after LayerNorm and dropout, with a trailing
            ``config.hidden_size`` dimension.
        """
        seq_length = input_ids.size(1)
        # Positions 0..seq_length-1, broadcast to every row of the batch.
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        # Split ids into base types (< type_vocab_size) and extension types
        # (>= type_vocab_size, re-based to start at 0). In each id tensor,
        # positions of the other kind are zeroed, so both lookups use index 0
        # there; the masks then discard those contributions.
        token_type_ids_extension = token_type_ids - self.config.type_vocab_size
        token_type_ids_extension_mask = (token_type_ids_extension >= 0).float()
        token_type_ids_extension = (token_type_ids_extension.float() * token_type_ids_extension_mask).long()
        token_type_ids_mask = (token_type_ids < self.config.type_vocab_size).float()
        # NOTE(review): the two masks are exact complements (ids >= T vs ids < T),
        # so this check always holds.
        assert torch.sum(token_type_ids_extension_mask + token_type_ids_mask) == \
            torch.numel(token_type_ids) == torch.numel(token_type_ids_mask)
        token_type_ids = (token_type_ids.float() * token_type_ids_mask).long()
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        token_type_embeddings_extension = self.token_type_embeddings_extension(token_type_ids_extension)
        # Each position keeps exactly one of the two segment embeddings.
        token_type_embeddings = (token_type_embeddings * token_type_ids_mask.unsqueeze(-1)) + \
            (token_type_embeddings_extension * token_type_ids_extension_mask.unsqueeze(-1))
        embeddings = words_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
class BertSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention for the text stream.

    ``config.hidden_size`` must be a multiple of ``config.num_attention_heads``.
    Each head attends over ``hidden_size / num_attention_heads`` features.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        heads = config.num_attention_heads
        if width % heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )
        self.num_attention_heads = heads
        self.attention_head_size = int(width / heads)
        self.all_head_size = heads * self.attention_head_size
        # Creation order (query, key, value) is kept so seeded initialisation matches.
        self.query = nn.Linear(width, self.all_head_size)
        self.key = nn.Linear(width, self.all_head_size)
        self.value = nn.Linear(width, self.all_head_size)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Split the last dim into heads: (B, L, H*D) -> (B, H, L, D)."""
        per_head = x.view(*x.size()[:-1], self.num_attention_heads, self.attention_head_size)
        return per_head.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask):
        """Attend ``hidden_states`` to itself.

        ``attention_mask`` is added to the raw scores; it is precomputed once
        for all layers in BertModel forward(). Returns ``(context, probs)``,
        where ``probs`` are the post-dropout attention weights.
        """
        queries, keys, values = (
            self.transpose_for_scores(projection(hidden_states))
            for projection in (self.query, self.key, self.value)
        )
        scores = torch.matmul(queries, keys.transpose(-1, -2)) / math.sqrt(self.attention_head_size)
        scores = scores + attention_mask
        # Dropping whole attention weights follows the original Transformer paper.
        probs = self.dropout(nn.Softmax(dim=-1)(scores))
        merged = torch.matmul(probs, values).permute(0, 2, 1, 3).contiguous()
        context = merged.view(*merged.size()[:-2], self.all_head_size)
        return context, probs
class BertSelfOutput(nn.Module):
    """Project the attention context, apply dropout, add the residual, then LayerNorm."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.LayerNorm = BertLayerNorm(width, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
class BertAttention(nn.Module):
    """Text attention sub-layer: self-attention followed by its output projection.

    ``forward`` returns ``(attention_output, attention_probs)``.
    """

    def __init__(self, config):
        super().__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)

    def forward(self, input_tensor, attention_mask):
        context, probs = self.self(input_tensor, attention_mask)
        return self.output(context, input_tensor), probs
class BertIntermediate(nn.Module):
    """Feed-forward expansion ``hidden_size -> intermediate_size`` followed by an activation.

    ``config.hidden_act`` may be a name looked up in ``ACT2FN`` or a callable
    used as-is.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        act = config.hidden_act
        names_an_activation = isinstance(act, str) or (
            sys.version_info[0] == 2 and isinstance(act, unicode)
        )
        self.intermediate_act_fn = ACT2FN[act] if names_an_activation else act

    def forward(self, hidden_states):
        return self.intermediate_act_fn(self.dense(hidden_states))
class BertOutput(nn.Module):
    """Project the feed-forward output back to ``hidden_size``, add the residual, then LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
class BertLayer(nn.Module):
    """One text encoder layer: attention sub-layer, then the feed-forward sub-layer.

    ``forward`` returns ``(layer_output, attention_probs)``.
    """

    def __init__(self, config):
        super().__init__()
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, hidden_states, attention_mask):
        attended, probs = self.attention(hidden_states, attention_mask)
        expanded = self.intermediate(attended)
        return self.output(expanded, attended), probs
class TextGraphLayer(nn.Module):
    """Graph refinement of the text stream over a history graph and a question graph.

    The two graphs are processed in sequence:

    * History graph: the nodes are the hidden states at ``h_sep_indices``
      plus one hub node projected from the visual hub ``v_hub``. They pass
      through ``config.num_h_gnn_layers`` GATv2 layers (no edge features) and
      are attention-pooled into ``h_hub``.
    * Question graph: ``h_hub`` is appended as an extra node to the question
      token nodes (the ``q_limits`` span). The nodes pass through
      ``config.num_q_gnn_layers`` GATv2 layers with edge features and are
      pooled into ``q_hub``.

    The refined non-hub node features are written back to their token
    positions, and the result is averaged with the input hidden states.
    """
    def __init__(self, config):
        super(TextGraphLayer, self).__init__()
        self.config = config
        self.gnn_act = ACT2FN[config.gnn_act]
        self.num_q_gnn_layers = config.num_q_gnn_layers
        self.num_h_gnn_layers = config.num_h_gnn_layers
        self.q_gnn_layers = []
        self.q_gnn_norm_layers = []
        for _ in range(self.num_q_gnn_layers):
            # Question graph layers (with edge features)
            self.q_gnn_layers.append(
                pyg_nn.GATv2Conv(
                    config.hidden_size, config.hidden_size//config.num_gnn_attention_heads,
                    config.num_gnn_attention_heads,
                    dropout=config.gnn_dropout_prob,
                    edge_dim=config.q_gnn_edge_dim,
                    concat=True
                )
            )
            # After each graph layer, a normalization layer is added
            self.q_gnn_norm_layers.append(pyg_nn.PairNorm())
        self.q_gnn_layers = nn.ModuleList(self.q_gnn_layers)
        self.q_gnn_norm_layers = nn.ModuleList(self.q_gnn_norm_layers)
        self.h_gnn_layers = []
        self.h_gnn_norm_layers = []
        for _ in range(self.num_h_gnn_layers):
            # History graph layers (no edge features)
            self.h_gnn_layers.append(
                pyg_nn.GATv2Conv(
                    config.hidden_size, config.hidden_size//config.num_gnn_attention_heads,
                    config.num_gnn_attention_heads,
                    dropout=config.gnn_dropout_prob,
                    concat=True
                )
            )
            # After each graph layer, a normalization layer is added
            self.h_gnn_norm_layers.append(pyg_nn.PairNorm())
        self.h_gnn_layers = nn.ModuleList(self.h_gnn_layers)
        self.h_gnn_norm_layers = nn.ModuleList(self.h_gnn_norm_layers)
        # Maps the visual hub (v_hidden_size) into the text space (hidden_size).
        self.h_gnn_dense_hub = nn.Linear(config.v_hidden_size, config.hidden_size)
        self.h_gnn_layer_norm_hub = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.h_gnn_dropout_hub = nn.Dropout(config.gnn_dropout_prob)
        # Attention-pooling readouts that reduce each graph to a single hub vector.
        q_dense_pooling = nn.Sequential(
            nn.Linear(config.hidden_size, 1),
            ACT2FN['GeLU'],
            nn.Dropout(config.gnn_dropout_prob)
        )
        self.q_gnn_pooling = pyg_nn.GlobalAttention(q_dense_pooling)
        h_dense_pooling = nn.Sequential(
            nn.Linear(config.hidden_size, 1),
            ACT2FN['GeLU'],
            nn.Dropout(config.gnn_dropout_prob)
        )
        self.h_gnn_pooling = pyg_nn.GlobalAttention(h_dense_pooling)

    def forward(
            self, hidden_states, q_edge_indices, q_edge_attributes,
            q_limits, h_edge_indices, h_sep_indices, v_hub,
            len_q_gr=None, len_h_gr=None, len_h_sep=None):
        """Run the history and question GNNs and fuse their output into the text states.

        Args:
            hidden_states: text states, unpacked below as
                ``(batch_size, seq_len, hidden_size)``.
            q_edge_indices, q_edge_attributes: per-sample question-graph edges
                and edge features. When ``len_q_gr`` is given, they are treated
                as padded and truncated to the first ``len_q_gr[b]`` edges.
            q_limits: per-sample ``[start, end)`` token offsets of the question
                (presumably ``(batch, 2)``; only columns 0 and 1 are read).
            h_edge_indices: per-sample history-graph edges, truncated with
                ``len_h_gr`` when given.
            h_sep_indices: per-sample token positions used as history nodes,
                truncated with ``len_h_sep`` when given.
            v_hub: visual hub features, one row per sample, of size
                ``config.v_hidden_size``.

        Returns:
            ``(final_hidden_states, h_hub, q_hub)``: the average of the input and
            the graph-refined text states, plus the pooled history and question
            hub vectors.
        """
        device = hidden_states.device
        batch_size, _, hidden_size = hidden_states.size()
        if isinstance(q_edge_indices, list):
            assert len(q_edge_indices) == len(q_edge_attributes) == q_limits.size(0) \
                == len(h_edge_indices) == len(h_sep_indices) == batch_size
        else:
            assert q_edge_indices.size(0) == q_edge_attributes.size(0) == q_limits.size(0) \
                == h_edge_indices.size(0) == h_sep_indices.size(0) == batch_size
        if len_q_gr is not None:
            # Strip the padding of each sample's graph tensors.
            q_edge_indices = [t.squeeze(0)[:, :l].long() for t, l in zip(torch.split(q_edge_indices, 1, dim=0), len_q_gr)]
            q_edge_attributes = [t.squeeze(0)[:l, :] for t, l in zip(torch.split(q_edge_attributes, 1, dim=0), len_q_gr)]
            h_edge_indices = [t.squeeze(0)[:, :l].long() for t, l in zip(torch.split(h_edge_indices, 1, dim=0), len_h_gr)]
            h_sep_indices = [t.squeeze(0)[:l].long() for t, l in zip(torch.split(h_sep_indices, 1, dim=0), len_h_sep)]
        else:
            q_edge_indices = [t.squeeze(0) for t in torch.split(q_edge_indices, 1, dim=0)]
            q_edge_attributes = [t.squeeze(0) for t in torch.split(q_edge_attributes, 1, dim=0)]
            h_edge_indices = [t.squeeze(0) for t in torch.split(h_edge_indices, 1, dim=0)]
            h_sep_indices = [t.squeeze(0).long() for t in torch.split(h_sep_indices, 1, dim=0)]
        gnn_hidden_states = hidden_states.clone().detach()

        # Extract the history and question node features (without the hub node)
        h_node_feats = []
        q_node_feats = []
        q_limits = q_limits.tolist()
        q_tok_indices_extended = []
        h_sep_indices_extended = []
        for i, (h_sep_idx, q_limit) in enumerate(zip(h_sep_indices, q_limits)):
            batch_data = gnn_hidden_states[i, :, :].clone().detach()
            # Row indices broadcast over the feature dim, as gather/scatter require.
            h_sep_idx = h_sep_idx.unsqueeze(-1).repeat(1, hidden_size)
            h_sep_indices_extended.append(h_sep_idx)
            h_node_feats.append(torch.gather(batch_data, 0, h_sep_idx))
            q_tok_idx = torch.arange(q_limit[0], q_limit[1]).unsqueeze(-1).repeat(1, hidden_size).to(device)
            q_tok_indices_extended.append(q_tok_idx)
            q_node_feats.append(torch.gather(batch_data, 0, q_tok_idx))

        # if self.use_hub_nodes:
        # Map v_hub to the correct vector space
        v_hub = self.h_gnn_dense_hub(v_hub)
        v_hub = self.h_gnn_layer_norm_hub(v_hub)
        v_hub = self.h_gnn_dropout_hub(v_hub)

        # Add the hub node to the history nodes (as the last node of each graph)
        v_hub = torch.split(v_hub, 1, dim=0)
        h_node_feats = [torch.cat((h, x), dim=0) for h, x in zip(h_node_feats, v_hub)]

        # Create the history graph data and pass them through the GNNs
        pg_hist_data = [Data(x=x, edge_index=idx) for x, idx in zip(h_node_feats, h_edge_indices)]
        pg_hist_loader = DataLoader(pg_hist_data, batch_size=batch_size, shuffle=False)
        # batch_size equals the number of graphs, so this loop runs once.
        for data in pg_hist_loader:
            x_h, edge_index_h, h_gnn_batch_idx = data.x, data.edge_index, data.batch
            for i in range(self.num_h_gnn_layers):
                # Normalization
                x_h = self.h_gnn_norm_layers[i](x_h, h_gnn_batch_idx)
                # Graph propagation
                x_h = self.h_gnn_layers[i](x_h, edge_index_h, edge_attr=None)
                # Activation
                x_h = self.gnn_act(x_h) + x_h
            x_h = self.gnn_act(x_h)
            h_hub = self.h_gnn_pooling(x_h, h_gnn_batch_idx)

        # Add the pooled history hub as the last node of each question graph
        h_hub_split = torch.split(h_hub, 1, dim=0)
        q_node_feats = [torch.cat((q, x), dim=0) for q, x in zip(q_node_feats, h_hub_split)]

        # Create the question graph data and pass them through the GNNs
        pg_ques_data = [Data(x=x, edge_index=idx, edge_attr=attr) for x, idx, attr in zip(q_node_feats, q_edge_indices, q_edge_attributes)]
        pg_ques_loader = DataLoader(pg_ques_data, batch_size=batch_size, shuffle=False)
        for data in pg_ques_loader:
            x_q, edge_index_q, edge_attr_q, q_gnn_batch_idx = data.x, data.edge_index, data.edge_attr, data.batch
            for i in range(self.num_q_gnn_layers):
                # Normalization
                x_q = self.q_gnn_norm_layers[i](x_q, q_gnn_batch_idx)
                # GNN propagation
                x_q = self.q_gnn_layers[i](x_q, edge_index_q, edge_attr=edge_attr_q)
                # Activation
                x_q = self.gnn_act(x_q) + x_q
            x_q = self.gnn_act(x_q)
            q_hub = self.q_gnn_pooling(x_q, q_gnn_batch_idx)

        # Reshape the node features back into one tensor per sample
        h_node_feats = to_data_list(x_h, h_gnn_batch_idx)
        q_node_feats = to_data_list(x_q, q_gnn_batch_idx)

        # Update the text tokens with the graph feats (dropping the trailing hub node).
        # This must be the in-place `scatter_`: `Tensor.scatter` returns a new
        # tensor, so its result would be discarded and gnn_hidden_states would
        # keep the input values.
        zipped_data = zip(h_node_feats, h_sep_indices_extended, q_node_feats, q_tok_indices_extended)
        for i, (h_node_feat, h_sep_idx, q_node_feat, q_tok_idx) in enumerate(zipped_data):
            gnn_hidden_states[i].scatter_(0, h_sep_idx, h_node_feat[:-1])
            gnn_hidden_states[i].scatter_(0, q_tok_idx, q_node_feat[:-1])

        final_hidden_states = 0.5 * (hidden_states + gnn_hidden_states)
        return final_hidden_states, h_hub, q_hub
class BertImageSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention for the image-region stream.

    Sized by the visual config fields: ``config.v_hidden_size`` must be a
    multiple of ``config.v_num_attention_heads``.
    """

    def __init__(self, config):
        super().__init__()
        width = config.v_hidden_size
        heads = config.v_num_attention_heads
        if width % heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.v_hidden_size, config.v_num_attention_heads)
            )
        self.num_attention_heads = heads
        self.attention_head_size = int(width / heads)
        self.all_head_size = heads * self.attention_head_size
        # Creation order (query, key, value) is kept so seeded initialisation matches.
        self.query = nn.Linear(width, self.all_head_size)
        self.key = nn.Linear(width, self.all_head_size)
        self.value = nn.Linear(width, self.all_head_size)
        self.dropout = nn.Dropout(config.v_attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Split the last dim into heads: (B, L, H*D) -> (B, H, L, D)."""
        per_head = x.view(*x.size()[:-1], self.num_attention_heads, self.attention_head_size)
        return per_head.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask):
        """Attend region features to themselves.

        ``attention_mask`` is added to the raw scores; it is precomputed once
        for all layers in BertModel forward(). Returns ``(context, probs)``,
        where ``probs`` are the post-dropout attention weights.
        """
        queries, keys, values = (
            self.transpose_for_scores(projection(hidden_states))
            for projection in (self.query, self.key, self.value)
        )
        scores = torch.matmul(queries, keys.transpose(-1, -2)) / math.sqrt(self.attention_head_size)
        scores = scores + attention_mask
        # Dropping whole attention weights follows the original Transformer paper.
        probs = self.dropout(nn.Softmax(dim=-1)(scores))
        merged = torch.matmul(probs, values).permute(0, 2, 1, 3).contiguous()
        context = merged.view(*merged.size()[:-2], self.all_head_size)
        return context, probs
class BertImageSelfOutput(nn.Module):
    """Project the visual attention context, apply dropout, add the residual, then LayerNorm."""

    def __init__(self, config):
        super().__init__()
        width = config.v_hidden_size
        self.dense = nn.Linear(width, width)
        self.LayerNorm = BertLayerNorm(width, eps=1e-12)
        self.dropout = nn.Dropout(config.v_hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
class BertImageAttention(nn.Module):
    """Visual attention sub-layer: self-attention followed by its output projection.

    ``forward`` returns ``(attention_output, attention_probs)``.
    """

    def __init__(self, config):
        super().__init__()
        self.self = BertImageSelfAttention(config)
        self.output = BertImageSelfOutput(config)

    def forward(self, input_tensor, attention_mask):
        context, probs = self.self(input_tensor, attention_mask)
        return self.output(context, input_tensor), probs
class BertImageIntermediate(nn.Module):
    """Visual feed-forward expansion ``v_hidden_size -> v_intermediate_size`` plus activation.

    ``config.v_hidden_act`` may be a name looked up in ``ACT2FN`` or a
    callable used as-is.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.v_hidden_size, config.v_intermediate_size)
        act = config.v_hidden_act
        names_an_activation = isinstance(act, str) or (
            sys.version_info[0] == 2 and isinstance(act, unicode)
        )
        self.intermediate_act_fn = ACT2FN[act] if names_an_activation else act

    def forward(self, hidden_states):
        return self.intermediate_act_fn(self.dense(hidden_states))
class BertImageOutput(nn.Module):
    """Project the visual feed-forward output back to ``v_hidden_size``, add the residual, then LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.v_intermediate_size, config.v_hidden_size)
        self.LayerNorm = BertLayerNorm(config.v_hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.v_hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
class BertImageLayer(nn.Module):
    """One visual encoder layer: attention sub-layer, then the feed-forward sub-layer.

    ``forward`` returns ``(layer_output, attention_probs)``.
    """

    def __init__(self, config):
        super().__init__()
        self.attention = BertImageAttention(config)
        self.intermediate = BertImageIntermediate(config)
        self.output = BertImageOutput(config)

    def forward(self, hidden_states, attention_mask):
        attended, probs = self.attention(hidden_states, attention_mask)
        expanded = self.intermediate(attended)
        return self.output(expanded, attended), probs
class ImageGraphLayer(nn.Module):
    """Graph refinement of the image-region stream with GATv2 layers.

    The graph is built and processed as follows:

    * Nodes: every position except index 0 (the [IMG] feature), plus one hub
      node projected from ``hub_states``.
    * Propagation: ``config.num_v_gnn_layers`` GATv2 layers with edge features.
    * Output: the region outputs replace positions 1.. of a copy of the input
      and are averaged with the input hidden states. The attention-pooled
      graph vector is returned as the visual hub.
    """
    def __init__(self, config):
        super(ImageGraphLayer, self).__init__()
        self.config = config
        self.gnn_act = ACT2FN[config.gnn_act]
        self.num_gnn_layers = config.num_v_gnn_layers
        self.gnn_layers = []
        self.gnn_norm_layers = []
        for _ in range(self.num_gnn_layers):
            # NOTE(review): with concat=True each layer outputs
            # num_gnn_attention_heads * (v_hidden_size // num_gnn_attention_heads)
            # features. The x.view(...) in forward needs this to equal
            # v_hidden_size, i.e. v_hidden_size must be divisible by the head count.
            self.gnn_layers.append(
                pyg_nn.GATv2Conv(
                    config.v_hidden_size, config.v_hidden_size//config.num_gnn_attention_heads,
                    config.num_gnn_attention_heads,
                    dropout=config.gnn_dropout_prob,
                    edge_dim=config.v_gnn_edge_dim,
                    concat=True
                )
            )
            # After each graph layer, a normalization layer is added
            self.gnn_norm_layers.append(pyg_nn.PairNorm())
        self.gnn_layers = nn.ModuleList(self.gnn_layers)
        self.gnn_norm_layers = nn.ModuleList(self.gnn_norm_layers)
        # Maps the incoming hub (hidden_size) into the visual space (v_hidden_size).
        self.gnn_dense_hub = nn.Linear(config.hidden_size, config.v_hidden_size)
        self.gnn_layer_norm_hub = BertLayerNorm(config.v_hidden_size, eps=1e-12)
        self.gnn_dropout_hub = nn.Dropout(config.gnn_dropout_prob)
        # Attention-pooling readout that reduces each graph to a single hub vector.
        dense_pooling = nn.Sequential(
            nn.Linear(config.v_hidden_size, 1),
            ACT2FN['GeLU'],
            nn.Dropout(config.gnn_dropout_prob)
        )
        self.gnn_pooling = pyg_nn.GlobalAttention(dense_pooling)

    def forward(
            self, hidden_states, edge_indices, edge_attributes, hub_states,
            len_img_gr=None):
        """Refine region features with the image graph.

        Args:
            hidden_states: region features, unpacked below as
                ``(batch_size, num_img_reg, v_hidden_size)``. Position 0 is
                excluded from the graph.
            edge_indices, edge_attributes: per-sample graph edges and edge
                features. When ``len_img_gr`` is given, they are treated as
                padded and truncated to the first ``len_img_gr[b]`` edges.
            hub_states: one row per sample of size ``config.hidden_size``,
                turned into one extra (last) node per graph.
            len_img_gr: optional per-sample true edge counts.

        Returns:
            ``(final_hidden_states, v_hub)``: the average of the input and the
            graph-refined states, and the attention-pooled per-graph vector.
        """
        # assert hub_states is not None
        gnn_hidden_states = hidden_states.clone().detach()
        batch_size, num_img_reg, v_hidden_size = hidden_states.size()
        node_feats = hidden_states.clone().detach()
        # Remove the [IMG] feats
        node_feats = node_feats[:, 1:]
        node_feats = torch.split(node_feats, 1, dim=0)
        if len_img_gr is not None:
            edge_indices = [t.squeeze(0)[:, :l].long() for t, l in zip(torch.split(edge_indices, 1, dim=0), len_img_gr)]
            edge_attributes = [t.squeeze(0)[:l, :] for t, l in zip(torch.split(edge_attributes, 1, dim=0), len_img_gr)]
        # NOTE(review): without len_img_gr, edge_indices / edge_attributes are
        # zipped below as given (iterating dim 0), with no `.long()` cast.

        # Concat the hub states
        # NOTE(review): the order here is dense -> dropout -> LayerNorm, whereas
        # TextGraphLayer uses dense -> LayerNorm -> dropout; confirm it is intended.
        hub_states = self.gnn_dense_hub(hub_states)
        hub_states = self.gnn_dropout_hub(hub_states)
        hub_states = self.gnn_layer_norm_hub(hub_states)
        hub_states = torch.split(hub_states, 1, dim=0)
        # Each graph has (num_img_reg - 1) region nodes + 1 hub node = num_img_reg nodes.
        node_feats = [torch.cat((x.squeeze(0), h), dim=0)
                      for x, h in zip(node_feats, hub_states)]

        pg_data = [Data(x, idx, attr) for x, idx, attr in zip(
            node_feats, edge_indices, edge_attributes)]
        pg_dataloader = DataLoader(
            pg_data, batch_size=batch_size, shuffle=False)

        # Gnn forward pass (batch_size equals the number of graphs, so this runs once)
        for data in pg_dataloader:
            x, edge_index, edge_attr, gnn_batch_idx = data.x, data.edge_index, data.edge_attr, data.batch
            for i in range(self.num_gnn_layers):
                # Normalization
                x = self.gnn_norm_layers[i](x, gnn_batch_idx)
                # GNN propagation
                x = self.gnn_layers[i](x, edge_index, edge_attr=edge_attr)
                # Activation
                x = self.gnn_act(x) + x
            x = self.gnn_act(x)

            # Reshape the output of the GNN to batch_size x num_img_reg x hidden_dim
            v_hub = self.gnn_pooling(x, gnn_batch_idx)
            x = x.view(batch_size, num_img_reg, v_hidden_size)
            # Write the region outputs back (dropping each graph's trailing hub
            # node); position 0 keeps its input value.
            gnn_hidden_states[:, 1:, :] = x[:, :-1, :]

        final_hidden_states = 0.5 * (hidden_states + gnn_hidden_states)
        return final_hidden_states, v_hub
class BertBiAttention(nn.Module):
    """Bi-directional co-attention between a vision stream and a text stream.

    Two attention directions share the same head count and head size, derived
    from ``config.bi_hidden_size`` / ``config.bi_num_attention_heads``:

    * text -> vision: text queries (``query2``) attend over vision keys
      (``key1``) and gather vision values (``value1``);
    * vision -> text: vision queries (``query1``) attend over text keys
      (``key2``) and gather text values (``value2``).
    """

    def __init__(self, config):
        super(BertBiAttention, self).__init__()
        if config.bi_hidden_size % config.bi_num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.bi_hidden_size, config.bi_num_attention_heads)
            )
        self.num_attention_heads = config.bi_num_attention_heads
        # Divisibility was checked above, so floor division is exact.
        self.attention_head_size = config.bi_hidden_size // config.bi_num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Attribute names are state_dict keys and must not be renamed.
        # Registration order is also kept, so seeded initialisation is unchanged.
        # Vision-stream projections: v_hidden_size -> all_head_size.
        self.query1 = nn.Linear(config.v_hidden_size, self.all_head_size)
        self.key1 = nn.Linear(config.v_hidden_size, self.all_head_size)
        self.value1 = nn.Linear(config.v_hidden_size, self.all_head_size)
        self.dropout1 = nn.Dropout(config.v_attention_probs_dropout_prob)

        # Text-stream projections: hidden_size -> all_head_size.
        self.query2 = nn.Linear(config.hidden_size, self.all_head_size)
        self.key2 = nn.Linear(config.hidden_size, self.all_head_size)
        self.value2 = nn.Linear(config.hidden_size, self.all_head_size)
        self.dropout2 = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Split heads: (batch, seq, all_head_size) -> (batch, heads, seq, head_size)."""
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def _attend(self, query_layer, key_layer, value_layer, masks, dropout):
        """Scaled dot-product attention over head-split tensors.

        Args:
            query_layer, key_layer, value_layer: outputs of ``transpose_for_scores``.
            masks: additive masks, added to the raw scores in the given order.
                They must broadcast to (batch, heads, q_len, k_len).
            dropout: dropout module applied to the attention probabilities.

        Returns:
            (context, probs): context has shape (batch, q_len, all_head_size).
            probs are the attention probabilities after dropout.
        """
        scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        scores = scores / math.sqrt(self.attention_head_size)
        for mask in masks:
            scores = scores + mask
        probs = torch.softmax(scores, dim=-1)
        # This drops whole attended tokens. It is unusual, but follows the
        # original Transformer paper.
        probs = dropout(probs)
        context = torch.matmul(probs, value_layer)
        context = context.permute(0, 2, 1, 3).contiguous()
        context = context.view(context.size()[:-2] + (self.all_head_size,))
        return context, probs

    def forward(self, input_tensor1, attention_mask1, input_tensor2, attention_mask2,
                co_attention_mask=None, use_co_attention_mask=False):
        """Run both co-attention directions.

        Args:
            input_tensor1: vision features; last dim is ``v_hidden_size``.
            attention_mask1: additive mask over vision positions. These are the
                keys of the text -> vision direction.
            input_tensor2: text features; last dim is ``hidden_size``.
            attention_mask2: additive mask over text positions. These are the
                keys of the vision -> text direction.
            co_attention_mask: optional additive mask, added as-is to the
                vision -> text scores (last two dims: vision, text).
            use_co_attention_mask: if True, ``co_attention_mask`` is added to
                both directions.

        Returns:
            (context_layer1, context_layer2, (attention_probs1, attention_probs2)).
            context_layer1 has one row per text position (vision information
            gathered for text). context_layer2 has one row per vision position
            (text information gathered for vision).

        Raises:
            ValueError: if ``use_co_attention_mask`` is True but
                ``co_attention_mask`` is None.
        """
        if use_co_attention_mask and co_attention_mask is None:
            raise ValueError(
                "use_co_attention_mask=True requires a co_attention_mask tensor"
            )

        # Head-split projections of the vision input.
        query_layer1 = self.transpose_for_scores(self.query1(input_tensor1))
        key_layer1 = self.transpose_for_scores(self.key1(input_tensor1))
        value_layer1 = self.transpose_for_scores(self.value1(input_tensor1))

        # Head-split projections of the text input.
        query_layer2 = self.transpose_for_scores(self.query2(input_tensor2))
        key_layer2 = self.transpose_for_scores(self.key2(input_tensor2))
        value_layer2 = self.transpose_for_scores(self.value2(input_tensor2))

        # Text queries over vision keys/values ("query2" x "key1").
        masks1 = [attention_mask1]
        if use_co_attention_mask:
            # The co-attention mask is laid out for the vision -> text scores,
            # so its last two dims are swapped for this direction.
            masks1.append(co_attention_mask.permute(0, 1, 3, 2))
        context_layer1, attention_probs1 = self._attend(
            query_layer2, key_layer1, value_layer1, masks1, self.dropout1
        )

        # Vision queries over text keys/values ("query1" x "key2").
        masks2 = [attention_mask2]
        if use_co_attention_mask:
            masks2.append(co_attention_mask)
        context_layer2, attention_probs2 = self._attend(
            query_layer1, key_layer2, value_layer2, masks2, self.dropout2
        )

        return context_layer1, context_layer2, (attention_probs1, attention_probs2)
class BertBiOutput(nn.Module):
    """Output projection, residual add and LayerNorm for both co-attention streams.

    Stream 1 maps ``bi_hidden_size -> v_hidden_size`` and stream 2 maps
    ``bi_hidden_size -> hidden_size``. In each stream the projection is passed
    through dropout, added to that stream's residual input, and layer-normalised.
    """

    def __init__(self, config):
        super(BertBiOutput, self).__init__()
        bi_size = config.bi_hidden_size
        vision_size = config.v_hidden_size
        text_size = config.hidden_size

        # Attribute names are state_dict keys. Registration order matches the
        # historical layout so seeded initialisation is unchanged.
        self.dense1 = nn.Linear(bi_size, vision_size)
        self.LayerNorm1 = BertLayerNorm(vision_size, eps=1e-12)
        self.dropout1 = nn.Dropout(config.v_hidden_dropout_prob)
        # NOTE(review): q_dense1/q_dropout1 and q_dense2/q_dropout2 are never
        # used in forward(). They still own parameters that appear in the
        # state_dict, presumably for checkpoint compatibility. Confirm before
        # removing them.
        self.q_dense1 = nn.Linear(bi_size, vision_size)
        self.q_dropout1 = nn.Dropout(config.v_hidden_dropout_prob)

        self.dense2 = nn.Linear(bi_size, text_size)
        self.LayerNorm2 = BertLayerNorm(text_size, eps=1e-12)
        self.dropout2 = nn.Dropout(config.hidden_dropout_prob)
        self.q_dense2 = nn.Linear(bi_size, text_size)
        self.q_dropout2 = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states1, input_tensor1, hidden_states2, input_tensor2):
        """Return (LN1(drop1(dense1(h1)) + x1), LN2(drop2(dense2(h2)) + x2))."""
        # Both dropouts run before either LayerNorm, stream 1 first, so the
        # RNG consumption in training mode is unchanged.
        vision_delta = self.dropout1(self.dense1(hidden_states1))
        text_delta = self.dropout2(self.dense2(hidden_states2))
        return (
            self.LayerNorm1(vision_delta + input_tensor1),
            self.LayerNorm2(text_delta + input_tensor2),
        )
class BertConnectionLayer(nn.Module):
    """One co-attention block joining the vision and text streams.

    It applies bi-directional attention, then a per-stream output projection
    with a residual, then a per-stream feed-forward (intermediate + output).
    The result is ``(vision_output, text_output, co_attention_probs)``.
    """

    def __init__(self, config):
        super(BertConnectionLayer, self).__init__()
        # Attribute names are state_dict keys; registration order is unchanged.
        self.biattention = BertBiAttention(config)
        self.biOutput = BertBiOutput(config)
        self.v_intermediate = BertImageIntermediate(config)
        self.v_output = BertImageOutput(config)
        self.t_intermediate = BertIntermediate(config)
        self.t_output = BertOutput(config)

    def forward(self, input_tensor1, attention_mask1, input_tensor2, attention_mask2,
                co_attention_mask=None, use_co_attention_mask=False):
        """Run the block on vision (``input_tensor1``) and text (``input_tensor2``)."""
        ctx_for_text, ctx_for_vision, co_attention_probs = self.biattention(
            input_tensor1, attention_mask1, input_tensor2, attention_mask2,
            co_attention_mask, use_co_attention_mask,
        )
        # BertBiAttention returns the text-aligned context first. The two
        # contexts are therefore swapped here to pair each with its own
        # residual input.
        vision_attended, text_attended = self.biOutput(
            ctx_for_vision, input_tensor1, ctx_for_text, input_tensor2
        )

        vision_hidden = self.v_intermediate(vision_attended)
        vision_layer_out = self.v_output(vision_hidden, vision_attended)

        text_hidden = self.t_intermediate(text_attended)
        text_layer_out = self.t_output(text_hidden, text_attended)

        return vision_layer_out, text_layer_out, co_attention_probs
class BertEncoder(nn.Module):
def __init__(self, config):
super(BertEncoder, self).__init__()
# in the bert encoder, we need to extract three things here.
# text bert layer: BertLayer
# vision bert layer: BertImageLayer
# Bi-Attention: Given the output of two bertlayer, perform bi-directional
# attention and add on two layers.
self.FAST_MODE = config.fast_mode
self.with_coattention = config.with_coattention
self.v_biattention_id = config.v_biattention_id
self.t_biattention_id = config.t_biattention_id
self.in_batch_pairs = config.in_batch_pairs
self.fixed_t_layer = config.fixed_t_layer
self.fixed_v_layer = config.fixed_v_layer
self.t_gnn_ids = config.t_gnn_ids
self.v_gnn_ids = config.v_gnn_ids
v_layer = BertImageLayer(config)
connect_layer = BertConnectionLayer(config)
self.layer = []
for _ in range(config.num_hidden_layers):
self.layer.append(BertLayer(config))
self.layer = nn.ModuleList(self.layer)
txt_graph_layer = TextGraphLayer(config)
self.t_gnns = nn.ModuleList([txt_graph_layer for _ in range(len(self.t_gnn_ids))])