# coding=utf-8 # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """PyTorch BERT model.""" import copy import json import logging import math import os import shutil import tarfile import tempfile import sys from io import open import torch from torch import nn from torch.nn import CrossEntropyLoss import torch.nn.functional as F from torch.nn.utils.weight_norm import weight_norm from pytorch_transformers.modeling_bert import BertEmbeddings from utils.data_utils import sequence_mask, to_data_list import torch_geometric.nn as pyg_nn from torch_geometric.data import Data from torch_geometric.loader import DataLoader from pytorch_pretrained_bert.file_utils import cached_path import pdb logger = logging.getLogger(__name__) PRETRAINED_MODEL_ARCHIVE_MAP = { "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz", "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz", "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz", "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz", "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz", "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz", "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", } def load_tf_weights_in_bert(model, tf_checkpoint_path): """ Load tf checkpoints in a pytorch model """ try: import re import numpy as np import tensorflow as tf except ImportError: print( "Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see " "https://www.tensorflow.org/install/ for installation instructions." 
) raise tf_path = os.path.abspath(tf_checkpoint_path) print("Converting TensorFlow checkpoint from {}".format(tf_path)) # Load weights from TF model init_vars = tf.train.list_variables(tf_path) names = [] arrays = [] for name, shape in init_vars: print("Loading TF weight {} with shape {}".format(name, shape)) array = tf.train.load_variable(tf_path, name) names.append(name) arrays.append(array) for name, array in zip(names, arrays): name = name.split("/") # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v # which are not required for using pretrained model if any(n in ["adam_v", "adam_m"] for n in name): print("Skipping {}".format("/".join(name))) continue pointer = model for m_name in name: if re.fullmatch(r"[A-Za-z]+_\d+", m_name): l = re.split(r"_(\d+)", m_name) else: l = [m_name] if l[0] == "kernel" or l[0] == "gamma": pointer = getattr(pointer, "weight") elif l[0] == "output_bias" or l[0] == "beta": pointer = getattr(pointer, "bias") elif l[0] == "output_weights": pointer = getattr(pointer, "weight") else: pointer = getattr(pointer, l[0]) if len(l) >= 2: num = int(l[1]) pointer = pointer[num] if m_name[-11:] == "_embeddings": pointer = getattr(pointer, "weight") elif m_name == "kernel": array = np.transpose(array) try: assert pointer.shape == array.shape except AssertionError as e: e.args += (pointer.shape, array.shape) raise print("Initialize PyTorch weight {}".format(name)) pointer.data = torch.from_numpy(array) return model class GeLU(nn.Module): """Implementation of the gelu activation function. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) Also see https://arxiv.org/abs/1606.08415 """ def __init__(self): super(GeLU, self).__init__() def forward(self, x): return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) def gelu(x): """Implementation of the gelu activation function. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) Also see https://arxiv.org/abs/1606.08415 """ return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) def swish(x): return x * torch.sigmoid(x) ACT2FN = {"GeLU": GeLU(), "gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} class BertConfig(object): """Configuration class to store the configuration of a `BertModel`. """ def __init__( self, vocab_size_or_config_json_file, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, v_feature_size=2048, v_target_size=1601, v_hidden_size=768, v_num_hidden_layers=3, v_num_attention_heads=12, v_intermediate_size=3072, bi_hidden_size=1024, bi_num_attention_heads=16, v_attention_probs_dropout_prob=0.1, v_hidden_act="gelu", v_hidden_dropout_prob=0.1, v_initializer_range=0.2, v_biattention_id=[0, 1], t_biattention_id=[10, 11], predict_feature=False, fast_mode=False, fixed_v_layer=0, fixed_t_layer=0, in_batch_pairs=False, fusion_method="mul", intra_gate=False, with_coattention=True ): """Constructs BertConfig. Args: vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`. hidden_size: Size of the encoder layers and the pooler layer. num_hidden_layers: Number of hidden layers in the Transformer encoder. 
num_attention_heads: Number of attention heads for each attention layer in the Transformer encoder. intermediate_size: The size of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. hidden_act: The non-linear activation function (function or string) in the encoder and pooler. If string, "gelu", "relu" and "swish" are supported. hidden_dropout_prob: The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler. attention_probs_dropout_prob: The dropout ratio for the attention probabilities. max_position_embeddings: The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 512 or 1024 or 2048). type_vocab_size: The vocabulary size of the `token_type_ids` passed into `BertModel`. initializer_range: The sttdev of the truncated_normal_initializer for initializing all weight matrices. """ assert len(v_biattention_id) == len(t_biattention_id) assert max(v_biattention_id) < v_num_hidden_layers assert max(t_biattention_id) < num_hidden_layers if isinstance(vocab_size_or_config_json_file, str) or ( sys.version_info[0] == 2 and isinstance(vocab_size_or_config_json_file, unicode) ): with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader: json_config = json.loads(reader.read()) for key, value in json_config.items(): self.__dict__[key] = value elif isinstance(vocab_size_or_config_json_file, int): self.vocab_size = vocab_size_or_config_json_file self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.hidden_act = hidden_act self.intermediate_size = intermediate_size self.hidden_dropout_prob = hidden_dropout_prob self.attention_probs_dropout_prob = attention_probs_dropout_prob self.max_position_embeddings = max_position_embeddings self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range self.v_feature_size = v_feature_size self.v_hidden_size = v_hidden_size self.v_num_hidden_layers = v_num_hidden_layers self.v_num_attention_heads = v_num_attention_heads self.v_intermediate_size = v_intermediate_size self.v_attention_probs_dropout_prob = v_attention_probs_dropout_prob self.v_hidden_act = v_hidden_act self.v_hidden_dropout_prob = v_hidden_dropout_prob self.v_initializer_range = v_initializer_range self.v_biattention_id = v_biattention_id self.t_biattention_id = t_biattention_id self.v_target_size = v_target_size self.bi_hidden_size = bi_hidden_size self.bi_num_attention_heads = bi_num_attention_heads self.predict_feature = predict_feature self.fast_mode = fast_mode self.fixed_v_layer = fixed_v_layer self.fixed_t_layer = fixed_t_layer self.in_batch_pairs = in_batch_pairs self.fusion_method = fusion_method self.intra_gate = intra_gate self.with_coattention=with_coattention else: raise ValueError( "First argument must be either a vocabulary size (int)" "or the path to a pretrained model config file (str)" ) @classmethod def from_dict(cls, json_object): """Constructs a `BertConfig` from a Python dictionary of parameters.""" config = BertConfig(vocab_size_or_config_json_file=-1) for key, value in json_object.items(): config.__dict__[key] = value return config @classmethod def from_json_file(cls, json_file): """Constructs a `BertConfig` from a json file of parameters.""" with open(json_file, "r", encoding="utf-8") as reader: text = reader.read() return cls.from_dict(json.loads(text)) def __repr__(self): return str(self.to_json_string()) def to_dict(self): """Serializes this 
instance to a Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"


try:
    # from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
    from torch.nn import LayerNorm as BertLayerNorm
except ImportError:
    # logger.info(
    #     "Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex ."
    # )
    pass


class BertLayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the square root)."""
        super(BertLayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias


class BertEmbeddingsDialog(nn.Module):
    def __init__(self, config, device):
        super(BertEmbeddingsDialog, self).__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        max_seq_len = 256
        d_model = config.hidden_size
        pe = torch.zeros(max_seq_len, d_model)
        for pos in range(max_seq_len):
            for i in range(0, d_model, 2):
                pe[pos, i] = math.sin(pos / (10000 ** ((2 * i) / d_model)))
                pe[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1)) / d_model)))
        self.pe = pe.to(device)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
        # Add support for additional segment embeddings (currently 10 extra segments).
        self.token_type_embeddings_extension = nn.Embedding(10, config.hidden_size)
        # Specialized embeddings for the [SEP] tokens.
        self.sep_embeddings = nn.Embedding(50, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.config = config

    def forward(self, input_ids, sep_indices=None, sep_len=None, token_type_ids=None):
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)

        token_type_ids_extension = token_type_ids - self.config.type_vocab_size
        token_type_ids_extension_mask = (token_type_ids_extension >= 0).float()
        token_type_ids_extension = (token_type_ids_extension.float() * token_type_ids_extension_mask).long()

        token_type_ids_mask = (token_type_ids < self.config.type_vocab_size).float()
        assert torch.sum(token_type_ids_extension_mask + token_type_ids_mask) == \
            torch.numel(token_type_ids) == torch.numel(token_type_ids_mask)
        token_type_ids = (token_type_ids.float() * token_type_ids_mask).long()

        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        token_type_embeddings_extension = self.token_type_embeddings_extension(token_type_ids_extension)

        token_type_embeddings = (token_type_embeddings * token_type_ids_mask.unsqueeze(-1)) + \
            (token_type_embeddings_extension * token_type_ids_extension_mask.unsqueeze(-1))

        embeddings =
words_embeddings + position_embeddings + token_type_embeddings embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings class BertSelfAttention(nn.Module): def __init__(self, config): super(BertSelfAttention, self).__init__() if config.hidden_size % config.num_attention_heads != 0: raise ValueError( "The hidden size (%d) is not a multiple of the number of attention " "heads (%d)" % (config.hidden_size, config.num_attention_heads) ) self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size self.query = nn.Linear(config.hidden_size, self.all_head_size) self.key = nn.Linear(config.hidden_size, self.all_head_size) self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + ( self.num_attention_heads, self.attention_head_size, ) x = x.view(*new_x_shape) return x.permute(0, 2, 1, 3) def forward(self, hidden_states, attention_mask): mixed_query_layer = self.query(hidden_states) mixed_key_layer = self.key(hidden_states) mixed_value_layer = self.value(hidden_states) query_layer = self.transpose_for_scores(mixed_query_layer) key_layer = self.transpose_for_scores(mixed_key_layer) value_layer = self.transpose_for_scores(mixed_value_layer) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) attention_scores = attention_scores / math.sqrt(self.attention_head_size) # Apply the attention mask is (precomputed for all layers in BertModel forward() function) attention_scores = attention_scores + attention_mask # Normalize the attention scores to probabilities. attention_probs = nn.Softmax(dim=-1)(attention_scores) # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. 
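        # attention_probs has shape [batch_size, num_attention_heads, seq_len, seq_len];
        # the dropout below randomly zeroes individual attention weights at training time.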
attention_probs = self.dropout(attention_probs) context_layer = torch.matmul(attention_probs, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(*new_context_layer_shape) return context_layer, attention_probs class BertSelfOutput(nn.Module): def __init__(self, config): super(BertSelfOutput, self).__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.LayerNorm(hidden_states + input_tensor) return hidden_states class BertAttention(nn.Module): def __init__(self, config): super(BertAttention, self).__init__() self.self = BertSelfAttention(config) self.output = BertSelfOutput(config) def forward(self, input_tensor, attention_mask): self_output, attention_probs = self.self(input_tensor, attention_mask) attention_output = self.output(self_output, input_tensor) return attention_output, attention_probs class BertIntermediate(nn.Module): def __init__(self, config): super(BertIntermediate, self).__init__() self.dense = nn.Linear(config.hidden_size, config.intermediate_size) if isinstance(config.hidden_act, str) or ( sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode) ): self.intermediate_act_fn = ACT2FN[config.hidden_act] else: self.intermediate_act_fn = config.hidden_act def forward(self, hidden_states): hidden_states = self.dense(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states) return hidden_states class BertOutput(nn.Module): def __init__(self, config): super(BertOutput, self).__init__() self.dense = nn.Linear(config.intermediate_size, config.hidden_size) self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.LayerNorm(hidden_states + input_tensor) return hidden_states class BertLayer(nn.Module): def __init__(self, config): super(BertLayer, self).__init__() self.attention = BertAttention(config) self.intermediate = BertIntermediate(config) self.output = BertOutput(config) def forward(self, hidden_states, attention_mask): attention_output, attention_probs = self.attention(hidden_states, attention_mask) intermediate_output = self.intermediate(attention_output) layer_output = self.output(intermediate_output, attention_output) return layer_output, attention_probs class TextGraphLayer(nn.Module): def __init__(self, config): super(TextGraphLayer, self).__init__() self.config = config self.gnn_act = ACT2FN[config.gnn_act] self.num_q_gnn_layers = config.num_q_gnn_layers self.num_h_gnn_layers = config.num_h_gnn_layers self.q_gnn_layers = [] self.q_gnn_norm_layers = [] for _ in range(self.num_q_gnn_layers): # Graph layers self.q_gnn_layers.append( pyg_nn.GATv2Conv( config.hidden_size, config.hidden_size//config.num_gnn_attention_heads, config.num_gnn_attention_heads, dropout=config.gnn_dropout_prob, edge_dim=config.q_gnn_edge_dim, concat=True ) ) # After each graph layer, a normalization layer is added self.q_gnn_norm_layers.append(pyg_nn.PairNorm()) self.q_gnn_layers = nn.ModuleList(self.q_gnn_layers) self.q_gnn_norm_layers = 
nn.ModuleList(self.q_gnn_norm_layers) self.h_gnn_layers = [] self.h_gnn_norm_layers = [] for _ in range(self.num_h_gnn_layers): self.h_gnn_layers.append( pyg_nn.GATv2Conv( config.hidden_size, config.hidden_size//config.num_gnn_attention_heads, config.num_gnn_attention_heads, dropout=config.gnn_dropout_prob, concat=True ) ) # After each graph layer, a normalization layer is added self.h_gnn_norm_layers.append(pyg_nn.PairNorm()) self.h_gnn_layers = nn.ModuleList(self.h_gnn_layers) self.h_gnn_norm_layers = nn.ModuleList(self.h_gnn_norm_layers) self.h_gnn_dense_hub = nn.Linear(config.v_hidden_size, config.hidden_size) self.h_gnn_layer_norm_hub = BertLayerNorm(config.hidden_size, eps=1e-12) self.h_gnn_dropout_hub = nn.Dropout(config.gnn_dropout_prob) q_dense_pooling = nn.Sequential( nn.Linear(config.hidden_size, 1), ACT2FN['GeLU'], nn.Dropout(config.gnn_dropout_prob) ) self.q_gnn_pooling = pyg_nn.GlobalAttention(q_dense_pooling) h_dense_pooling = nn.Sequential( nn.Linear(config.hidden_size, 1), ACT2FN['GeLU'], nn.Dropout(config.gnn_dropout_prob) ) self.h_gnn_pooling = pyg_nn.GlobalAttention(h_dense_pooling) def forward( self, hidden_states, q_edge_indices, q_edge_attributes, q_limits, h_edge_indices, h_sep_indices, v_hub, len_q_gr=None, len_h_gr=None, len_h_sep=None): device = hidden_states.device batch_size, _, hidden_size = hidden_states.size() if isinstance(q_edge_indices, list): assert len(q_edge_indices) == len(q_edge_attributes) == q_limits.size(0) \ == len(h_edge_indices) == len(h_sep_indices) == batch_size else: assert q_edge_indices.size(0) == q_edge_attributes.size(0) == q_limits.size(0) \ == h_edge_indices.size(0) == h_sep_indices.size(0) == batch_size if len_q_gr is not None: q_edge_indices = [t.squeeze(0)[:, :l].long() for t, l in zip(torch.split(q_edge_indices, 1, dim=0), len_q_gr)] q_edge_attributes = [t.squeeze(0)[:l, :] for t, l in zip(torch.split(q_edge_attributes, 1, dim=0), len_q_gr)] h_edge_indices = [t.squeeze(0)[:, :l].long() for t, l in zip(torch.split(h_edge_indices, 1, dim=0), len_h_gr)] h_sep_indices = [t.squeeze(0)[:l].long() for t, l in zip(torch.split(h_sep_indices, 1, dim=0), len_h_sep)] else: q_edge_indices = [t.squeeze(0) for t in torch.split(q_edge_indices, 1, dim=0)] q_edge_attributes = [t.squeeze(0) for t in torch.split(q_edge_attributes, 1, dim=0)] h_edge_indices = [t.squeeze(0) for t in torch.split(h_edge_indices, 1, dim=0)] h_sep_indices = [t.squeeze(0).long() for t in torch.split(h_sep_indices, 1, dim=0)] gnn_hidden_states = hidden_states.clone().detach() # Extract the history and question node features (without the hub node) h_node_feats = [] q_node_feats = [] q_limits = q_limits.tolist() q_tok_indices_extended = [] h_sep_indices_extended = [] for i, (h_sep_idx, q_limit) in enumerate(zip(h_sep_indices, q_limits)): batch_data = gnn_hidden_states[i, :, :].clone().detach() h_sep_idx = h_sep_idx.unsqueeze(-1).repeat(1, hidden_size) h_sep_indices_extended.append(h_sep_idx) h_node_feats.append(torch.gather(batch_data, 0, h_sep_idx)) q_tok_idx = torch.arange(q_limit[0], q_limit[1]).unsqueeze(-1).repeat(1, hidden_size).to(device) q_tok_indices_extended.append(q_tok_idx) q_node_feats.append(torch.gather(batch_data, 0, q_tok_idx)) # if self.use_hub_nodes: # Map v_hub to the correct vector space v_hub = self.h_gnn_dense_hub(v_hub) v_hub = self.h_gnn_layer_norm_hub(v_hub) v_hub = self.h_gnn_dropout_hub(v_hub) # Add the hub node to the history nodes v_hub = torch.split(v_hub, 1, dim=0) h_node_feats = [torch.cat((h, x), dim=0) for h, x in zip(h_node_feats, v_hub)] # 
Create the history graph data and pass them through the GNNs
        pg_hist_data = [Data(x=x, edge_index=idx) for x, idx in zip(h_node_feats, h_edge_indices)]
        pg_hist_loader = DataLoader(pg_hist_data, batch_size=batch_size, shuffle=False)
        for data in pg_hist_loader:
            x_h, edge_index_h, h_gnn_batch_idx = data.x, data.edge_index, data.batch
            for i in range(self.num_h_gnn_layers):
                # Normalization
                x_h = self.h_gnn_norm_layers[i](x_h, h_gnn_batch_idx)
                # Graph propagation
                x_h = self.h_gnn_layers[i](x_h, edge_index_h, edge_attr=None)
                # Activation (with residual connection)
                x_h = self.gnn_act(x_h) + x_h
            x_h = self.gnn_act(x_h)
            h_hub = self.h_gnn_pooling(x_h, h_gnn_batch_idx)

        # Add the hub nodes
        h_hub_split = torch.split(h_hub, 1, dim=0)
        q_node_feats = [torch.cat((q, x), dim=0) for q, x in zip(q_node_feats, h_hub_split)]

        # Create the question graph data and pass them through the GNNs
        pg_ques_data = [Data(x=x, edge_index=idx, edge_attr=attr) for x, idx, attr in
                        zip(q_node_feats, q_edge_indices, q_edge_attributes)]
        pg_ques_loader = DataLoader(pg_ques_data, batch_size=batch_size, shuffle=False)
        for data in pg_ques_loader:
            x_q, edge_index_q, edge_attr_q, q_gnn_batch_idx = data.x, data.edge_index, data.edge_attr, data.batch
            for i in range(self.num_q_gnn_layers):
                # Normalization
                x_q = self.q_gnn_norm_layers[i](x_q, q_gnn_batch_idx)
                # GNN propagation
                x_q = self.q_gnn_layers[i](x_q, edge_index_q, edge_attr=edge_attr_q)
                # Activation (with residual connection)
                x_q = self.gnn_act(x_q) + x_q
            x_q = self.gnn_act(x_q)
            q_hub = self.q_gnn_pooling(x_q, q_gnn_batch_idx)

        # Reshape the node features
        h_node_feats = to_data_list(x_h, h_gnn_batch_idx)
        q_node_feats = to_data_list(x_q, q_gnn_batch_idx)

        # Update the text tokens with the graph feats (in-place scatter_; the last row of
        # each node-feature tensor is the appended hub node and is dropped here)
        zipped_data = zip(h_node_feats, h_sep_indices_extended, q_node_feats, q_tok_indices_extended)
        for i, (h_node_feat, h_sep_idx, q_node_feat, q_tok_idx) in enumerate(zipped_data):
            gnn_hidden_states[i].scatter_(0, h_sep_idx, h_node_feat[:-1])
            gnn_hidden_states[i].scatter_(0, q_tok_idx, q_node_feat[:-1])

        final_hidden_states = 0.5 * (hidden_states + gnn_hidden_states)
        return final_hidden_states, h_hub, q_hub


class BertImageSelfAttention(nn.Module):
    def __init__(self, config):
        super(BertImageSelfAttention, self).__init__()
        if config.v_hidden_size % config.v_num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.v_hidden_size, config.v_num_attention_heads)
            )
        self.num_attention_heads = config.v_num_attention_heads
        self.attention_head_size = int(
            config.v_hidden_size / config.v_num_attention_heads
        )
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.v_hidden_size, self.all_head_size)
        self.key = nn.Linear(config.v_hidden_size, self.all_head_size)
        self.value = nn.Linear(config.v_hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.v_attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
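        # Per attention head this is standard scaled dot-product attention over the image
        # regions: scores = Q @ K^T / sqrt(attention_head_size), with Q, K, V of shape
        # [batch_size, num_attention_heads, num_regions, attention_head_size].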
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) attention_scores = attention_scores / math.sqrt(self.attention_head_size) # Apply the attention mask is (precomputed for all layers in BertModel forward() function) attention_scores = attention_scores + attention_mask # Normalize the attention scores to probabilities. attention_probs = nn.Softmax(dim=-1)(attention_scores) # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. attention_probs = self.dropout(attention_probs) context_layer = torch.matmul(attention_probs, value_layer) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(*new_context_layer_shape) return context_layer, attention_probs class BertImageSelfOutput(nn.Module): def __init__(self, config): super(BertImageSelfOutput, self).__init__() self.dense = nn.Linear(config.v_hidden_size, config.v_hidden_size) self.LayerNorm = BertLayerNorm(config.v_hidden_size, eps=1e-12) self.dropout = nn.Dropout(config.v_hidden_dropout_prob) def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.LayerNorm(hidden_states + input_tensor) return hidden_states class BertImageAttention(nn.Module): def __init__(self, config): super(BertImageAttention, self).__init__() self.self = BertImageSelfAttention(config) self.output = BertImageSelfOutput(config) def forward(self, input_tensor, attention_mask): self_output, attention_probs = self.self(input_tensor, attention_mask) attention_output = self.output(self_output, input_tensor) return attention_output, attention_probs class BertImageIntermediate(nn.Module): def __init__(self, config): super(BertImageIntermediate, self).__init__() self.dense = nn.Linear(config.v_hidden_size, config.v_intermediate_size) if isinstance(config.v_hidden_act, str) or ( sys.version_info[0] == 2 and isinstance(config.v_hidden_act, unicode) ): self.intermediate_act_fn = ACT2FN[config.v_hidden_act] else: self.intermediate_act_fn = config.v_hidden_act def forward(self, hidden_states): hidden_states = self.dense(hidden_states) hidden_states = self.intermediate_act_fn(hidden_states) return hidden_states class BertImageOutput(nn.Module): def __init__(self, config): super(BertImageOutput, self).__init__() self.dense = nn.Linear(config.v_intermediate_size, config.v_hidden_size) self.LayerNorm = BertLayerNorm(config.v_hidden_size, eps=1e-12) self.dropout = nn.Dropout(config.v_hidden_dropout_prob) def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.LayerNorm(hidden_states + input_tensor) return hidden_states class BertImageLayer(nn.Module): def __init__(self, config): super(BertImageLayer, self).__init__() self.attention = BertImageAttention(config) self.intermediate = BertImageIntermediate(config) self.output = BertImageOutput(config) def forward(self, hidden_states, attention_mask): attention_output, attention_probs = self.attention(hidden_states, attention_mask) intermediate_output = self.intermediate(attention_output) layer_output = self.output(intermediate_output, attention_output) return layer_output, attention_probs class ImageGraphLayer(nn.Module): def __init__(self, config): super(ImageGraphLayer, self).__init__() self.config = config self.gnn_act = 
ACT2FN[config.gnn_act] self.num_gnn_layers = config.num_v_gnn_layers self.gnn_layers = [] self.gnn_norm_layers = [] for _ in range(self.num_gnn_layers): self.gnn_layers.append( pyg_nn.GATv2Conv( config.v_hidden_size, config.v_hidden_size//config.num_gnn_attention_heads, config.num_gnn_attention_heads, dropout=config.gnn_dropout_prob, edge_dim=config.v_gnn_edge_dim, concat=True ) ) # After each graph layer, a normalization layer is added self.gnn_norm_layers.append(pyg_nn.PairNorm()) self.gnn_layers = nn.ModuleList(self.gnn_layers) self.gnn_norm_layers = nn.ModuleList(self.gnn_norm_layers) self.gnn_dense_hub = nn.Linear(config.hidden_size, config.v_hidden_size) self.gnn_layer_norm_hub = BertLayerNorm(config.v_hidden_size, eps=1e-12) self.gnn_dropout_hub = nn.Dropout(config.gnn_dropout_prob) dense_pooling = nn.Sequential( nn.Linear(config.v_hidden_size, 1), ACT2FN['GeLU'], nn.Dropout(config.gnn_dropout_prob) ) self.gnn_pooling = pyg_nn.GlobalAttention(dense_pooling) def forward( self, hidden_states, edge_indices, edge_attributes, hub_states, len_img_gr=None): # assert hub_states is not None gnn_hidden_states = hidden_states.clone().detach() batch_size, num_img_reg, v_hidden_size = hidden_states.size() node_feats = hidden_states.clone().detach() # Remave the [IMG] feats node_feats = node_feats[:, 1:] node_feats = torch.split(node_feats, 1, dim=0) if len_img_gr is not None: edge_indices = [t.squeeze(0)[:, :l].long() for t, l in zip(torch.split(edge_indices, 1, dim=0), len_img_gr)] edge_attributes = [t.squeeze(0)[:l, :] for t, l in zip(torch.split(edge_attributes, 1, dim=0), len_img_gr)] # Concat the hub states hub_states = self.gnn_dense_hub(hub_states) hub_states = self.gnn_dropout_hub(hub_states) hub_states = self.gnn_layer_norm_hub(hub_states) hub_states = torch.split(hub_states, 1, dim=0) node_feats = [torch.cat((x.squeeze(0), h), dim=0) for x, h in zip(node_feats, hub_states)] pg_data = [Data(x, idx, attr) for x, idx, attr in zip( node_feats, edge_indices, edge_attributes)] pg_dataloader = DataLoader( pg_data, batch_size=batch_size, shuffle=False) # Gnn forward pass for data in pg_dataloader: x, edge_index, edge_attr, gnn_batch_idx = data.x, data.edge_index, data.edge_attr, data.batch for i in range(self.num_gnn_layers): # Normalization x = self.gnn_norm_layers[i](x, gnn_batch_idx) # GNN propagation x = self.gnn_layers[i](x, edge_index, edge_attr=edge_attr) # Activation x = self.gnn_act(x) + x x = self.gnn_act(x) # Reshape the output of the GNN to batch_size x num_img_reg x hidden_dim v_hub = self.gnn_pooling(x, gnn_batch_idx) x = x.view(batch_size, num_img_reg, v_hidden_size) gnn_hidden_states[:, 1:, :] = x[:, :-1, :] final_hidden_states = 0.5 * (hidden_states + gnn_hidden_states) return final_hidden_states, v_hub class BertBiAttention(nn.Module): def __init__(self, config): super(BertBiAttention, self).__init__() if config.bi_hidden_size % config.bi_num_attention_heads != 0: raise ValueError( "The hidden size (%d) is not a multiple of the number of attention " "heads (%d)" % (config.bi_hidden_size, config.bi_num_attention_heads) ) self.num_attention_heads = config.bi_num_attention_heads self.attention_head_size = int( config.bi_hidden_size / config.bi_num_attention_heads ) self.all_head_size = self.num_attention_heads * self.attention_head_size # self.scale = nn.Linear(1, self.num_attention_heads, bias=False) # self.scale_act_fn = ACT2FN['relu'] self.query1 = nn.Linear(config.v_hidden_size, self.all_head_size) self.key1 = nn.Linear(config.v_hidden_size, self.all_head_size) self.value1 = 
nn.Linear(config.v_hidden_size, self.all_head_size) # self.logit1 = nn.Linear(config.hidden_size, self.num_attention_heads) self.dropout1 = nn.Dropout(config.v_attention_probs_dropout_prob) self.query2 = nn.Linear(config.hidden_size, self.all_head_size) self.key2 = nn.Linear(config.hidden_size, self.all_head_size) self.value2 = nn.Linear(config.hidden_size, self.all_head_size) # self.logit2 = nn.Linear(config.hidden_size, self.num_attention_heads) self.dropout2 = nn.Dropout(config.attention_probs_dropout_prob) def transpose_for_scores(self, x): new_x_shape = x.size()[:-1] + ( self.num_attention_heads, self.attention_head_size, ) x = x.view(*new_x_shape) return x.permute(0, 2, 1, 3) def forward(self, input_tensor1, attention_mask1, input_tensor2, attention_mask2, co_attention_mask=None, use_co_attention_mask=False): # for vision input. mixed_query_layer1 = self.query1(input_tensor1) mixed_key_layer1 = self.key1(input_tensor1) mixed_value_layer1 = self.value1(input_tensor1) # mixed_logit_layer1 = self.logit1(input_tensor1) query_layer1 = self.transpose_for_scores(mixed_query_layer1) key_layer1 = self.transpose_for_scores(mixed_key_layer1) value_layer1 = self.transpose_for_scores(mixed_value_layer1) # logit_layer1 = self.transpose_for_logits(mixed_logit_layer1) # for text input: mixed_query_layer2 = self.query2(input_tensor2) mixed_key_layer2 = self.key2(input_tensor2) mixed_value_layer2 = self.value2(input_tensor2) # mixed_logit_layer2 = self.logit2(input_tensor2) query_layer2 = self.transpose_for_scores(mixed_query_layer2) key_layer2 = self.transpose_for_scores(mixed_key_layer2) value_layer2 = self.transpose_for_scores(mixed_value_layer2) # logit_layer2 = self.transpose_for_logits(mixed_logit_layer2) # Take the dot product between "query2" and "key1" to get the raw attention scores for value 1. attention_scores1 = torch.matmul(query_layer2, key_layer1.transpose(-1, -2)) attention_scores1 = attention_scores1 / math.sqrt(self.attention_head_size) attention_scores1 = attention_scores1 + attention_mask1 if use_co_attention_mask: attention_scores1 = attention_scores1 + co_attention_mask.permute(0,1,3,2) # Normalize the attention scores to probabilities. attention_probs1 = nn.Softmax(dim=-1)(attention_scores1) # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. attention_probs1 = self.dropout1(attention_probs1) context_layer1 = torch.matmul(attention_probs1, value_layer1) context_layer1 = context_layer1.permute(0, 2, 1, 3).contiguous() new_context_layer_shape1 = context_layer1.size()[:-2] + (self.all_head_size,) context_layer1 = context_layer1.view(*new_context_layer_shape1) # Take the dot product between "query1" and "key2" to get the raw attention scores for value 2. attention_scores2 = torch.matmul(query_layer1, key_layer2.transpose(-1, -2)) attention_scores2 = attention_scores2 / math.sqrt(self.attention_head_size) # Apply the attention mask is (precomputed for all layers in BertModel forward() function) # we can comment this line for single flow. attention_scores2 = attention_scores2 + attention_mask2 if use_co_attention_mask: attention_scores2 = attention_scores2 + co_attention_mask # Normalize the attention scores to probabilities. attention_probs2 = nn.Softmax(dim=-1)(attention_scores2) # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. 
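        # Direction summary: attention_probs1 lets the text queries (query2) attend over
        # the image keys/values (key1/value1), so context_layer1 holds one image summary
        # per text token; attention_probs2 below lets the image queries (query1) attend
        # over the text keys/values (key2/value2), so context_layer2 holds one text
        # summary per image region.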
attention_probs2 = self.dropout2(attention_probs2) context_layer2 = torch.matmul(attention_probs2, value_layer2) context_layer2 = context_layer2.permute(0, 2, 1, 3).contiguous() new_context_layer_shape2 = context_layer2.size()[:-2] + (self.all_head_size,) context_layer2 = context_layer2.view(*new_context_layer_shape2) return context_layer1, context_layer2, (attention_probs1, attention_probs2) class BertBiOutput(nn.Module): def __init__(self, config): super(BertBiOutput, self).__init__() self.dense1 = nn.Linear(config.bi_hidden_size, config.v_hidden_size) self.LayerNorm1 = BertLayerNorm(config.v_hidden_size, eps=1e-12) self.dropout1 = nn.Dropout(config.v_hidden_dropout_prob) self.q_dense1 = nn.Linear(config.bi_hidden_size, config.v_hidden_size) self.q_dropout1 = nn.Dropout(config.v_hidden_dropout_prob) self.dense2 = nn.Linear(config.bi_hidden_size, config.hidden_size) self.LayerNorm2 = BertLayerNorm(config.hidden_size, eps=1e-12) self.dropout2 = nn.Dropout(config.hidden_dropout_prob) self.q_dense2 = nn.Linear(config.bi_hidden_size, config.hidden_size) self.q_dropout2 = nn.Dropout(config.hidden_dropout_prob) def forward(self, hidden_states1, input_tensor1, hidden_states2, input_tensor2): context_state1 = self.dense1(hidden_states1) context_state1 = self.dropout1(context_state1) context_state2 = self.dense2(hidden_states2) context_state2 = self.dropout2(context_state2) hidden_states1 = self.LayerNorm1(context_state1 + input_tensor1) hidden_states2 = self.LayerNorm2(context_state2 + input_tensor2) return hidden_states1, hidden_states2 class BertConnectionLayer(nn.Module): def __init__(self, config): super(BertConnectionLayer, self).__init__() self.biattention = BertBiAttention(config) self.biOutput = BertBiOutput(config) self.v_intermediate = BertImageIntermediate(config) self.v_output = BertImageOutput(config) self.t_intermediate = BertIntermediate(config) self.t_output = BertOutput(config) def forward(self, input_tensor1, attention_mask1, input_tensor2, attention_mask2, co_attention_mask=None, use_co_attention_mask=False): bi_output1, bi_output2, co_attention_probs = self.biattention( input_tensor1, attention_mask1, input_tensor2, attention_mask2, co_attention_mask, use_co_attention_mask ) attention_output1, attention_output2 = self.biOutput(bi_output2, input_tensor1, bi_output1, input_tensor2) intermediate_output1 = self.v_intermediate(attention_output1) layer_output1 = self.v_output(intermediate_output1, attention_output1) intermediate_output2 = self.t_intermediate(attention_output2) layer_output2 = self.t_output(intermediate_output2, attention_output2) return layer_output1, layer_output2, co_attention_probs class BertEncoder(nn.Module): def __init__(self, config): super(BertEncoder, self).__init__() # in the bert encoder, we need to extract three things here. # text bert layer: BertLayer # vision bert layer: BertImageLayer # Bi-Attention: Given the output of two bertlayer, perform bi-directional # attention and add on two layers. 
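        # Illustrative sketch (not executed) of how forward() below interleaves the three
        # kinds of layers, using the BertConfig defaults v_biattention_id=[0, 1] and
        # t_biattention_id=[10, 11] as an example:
        #
        #   for v_id, t_id in zip([0, 1], [10, 11]):
        #       # run the not-yet-run image layers with index < v_id and text layers with
        #       # index < t_id (graph layers are injected at the indices in v_gnn_ids /
        #       # t_gnn_ids), then apply one BertConnectionLayer co-attention block
        #
        # The remaining image and text layers are run after the last co-attention block.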
self.FAST_MODE = config.fast_mode self.with_coattention = config.with_coattention self.v_biattention_id = config.v_biattention_id self.t_biattention_id = config.t_biattention_id self.in_batch_pairs = config.in_batch_pairs self.fixed_t_layer = config.fixed_t_layer self.fixed_v_layer = config.fixed_v_layer self.t_gnn_ids = config.t_gnn_ids self.v_gnn_ids = config.v_gnn_ids v_layer = BertImageLayer(config) connect_layer = BertConnectionLayer(config) self.layer = [] for _ in range(config.num_hidden_layers): self.layer.append(BertLayer(config)) self.layer = nn.ModuleList(self.layer) txt_graph_layer = TextGraphLayer(config) self.t_gnns = nn.ModuleList([txt_graph_layer for _ in range(len(self.t_gnn_ids))]) self.v_layer = nn.ModuleList( [copy.deepcopy(v_layer) for _ in range(config.v_num_hidden_layers)] ) img_graph_layer = ImageGraphLayer(config) self.v_gnns = nn.ModuleList([img_graph_layer for _ in range(len(self.v_gnn_ids))]) self.c_layer = nn.ModuleList( [copy.deepcopy(connect_layer) for _ in range(len(config.v_biattention_id))] ) def forward( self, txt_embedding, image_embedding, txt_attention_mask, image_attention_mask, image_edge_indices, image_edge_attributes, question_edge_indices, question_edge_attributes, question_limits, history_edge_indices, history_sep_indices, co_attention_mask=None, output_all_encoded_layers=True, output_all_attention_masks=False, len_img_gr=None, len_q_gr=None, len_h_gr=None, len_h_sep=None ): v_start = 0 t_start = 0 count = 0 all_encoder_layers_t = [] all_encoder_layers_v = [] all_attention_mask_t = [] all_attnetion_mask_v = [] all_attention_mask_c = [] batch_size, num_words, t_hidden_size = txt_embedding.size() _, num_regions, v_hidden_size = image_embedding.size() # self.pool_feats(txt_embedding) use_co_attention_mask = False # Init the v_hub with the [IMG]-token embedding v_hub = image_embedding[:, 0, :].clone().detach() q_hub = None for v_layer_id, t_layer_id in zip(self.v_biattention_id, self.t_biattention_id): v_end = v_layer_id t_end = t_layer_id assert self.fixed_t_layer <= t_end assert self.fixed_v_layer <= v_end for idx in range(v_start, self.fixed_v_layer): with torch.no_grad(): image_embedding, image_attention_probs = self.v_layer[idx]( image_embedding, image_attention_mask) v_start = self.fixed_v_layer if output_all_attention_masks: all_attnetion_mask_v.append(image_attention_probs) for idx in range(v_start, v_end): # Perfrom graph message passing and aggr. if applicable if idx in self.v_gnn_ids: assert q_hub is not None v_gnn_layer_idx = self.v_gnn_ids.index(idx) image_embedding, v_hub = self.v_gnns[v_gnn_layer_idx]( image_embedding, image_edge_indices, image_edge_attributes, q_hub, len_img_gr=len_img_gr, ) # Perform standard bert self-attention image_embedding, image_attention_probs = self.v_layer[idx]( image_embedding, image_attention_mask) if output_all_attention_masks: all_attnetion_mask_v.append(image_attention_probs) for idx in range(t_start, self.fixed_t_layer): with torch.no_grad(): txt_embedding, txt_attention_probs = self.layer[idx](txt_embedding, txt_attention_mask) t_start = self.fixed_t_layer if output_all_attention_masks: all_attention_mask_t.append(txt_attention_probs) for idx in range(t_start, t_end): # Perfrom graph message passing and aggr. 
if applicable if idx in self.t_gnn_ids: t_gnn_layer_idx = self.t_gnn_ids.index(idx) txt_embedding, h_hub, q_hub = self.t_gnns[t_gnn_layer_idx]( txt_embedding, question_edge_indices, question_edge_attributes, question_limits, history_edge_indices, history_sep_indices, v_hub, len_q_gr=len_q_gr, len_h_gr=len_h_gr, len_h_sep=len_h_sep ) # Perform standard bert self-attention txt_embedding, txt_attention_probs = self.layer[idx](txt_embedding, txt_attention_mask) if output_all_attention_masks: all_attention_mask_t.append(txt_attention_probs) if count == 0 and self.in_batch_pairs: # new batch size is the batch_size ^2 image_embedding = image_embedding.unsqueeze(0).expand(batch_size, batch_size, num_regions, v_hidden_size).contiguous( ).view(batch_size*batch_size, num_regions, v_hidden_size) image_attention_mask = image_attention_mask.unsqueeze(0).expand( batch_size, batch_size, 1, 1, num_regions).contiguous().view(batch_size*batch_size, 1, 1, num_regions) txt_embedding = txt_embedding.unsqueeze(1).expand(batch_size, batch_size, num_words, t_hidden_size).contiguous( ).view(batch_size*batch_size, num_words, t_hidden_size) txt_attention_mask = txt_attention_mask.unsqueeze(1).expand( batch_size, batch_size, 1, 1, num_words).contiguous().view(batch_size*batch_size, 1, 1, num_words) co_attention_mask = co_attention_mask.unsqueeze(1).expand( batch_size, batch_size, 1, num_regions, num_words).contiguous().view(batch_size*batch_size, 1, num_regions, num_words) if self.with_coattention: # do the bi attention. image_embedding, txt_embedding, co_attention_probs = self.c_layer[count]( image_embedding, image_attention_mask, txt_embedding, txt_attention_mask, co_attention_mask, use_co_attention_mask) # use_co_attention_mask = False if output_all_attention_masks: all_attention_mask_c.append(co_attention_probs) v_start = v_end t_start = t_end count += 1 if output_all_encoded_layers: all_encoder_layers_t.append(txt_embedding) all_encoder_layers_v.append(image_embedding) for idx in range(v_start, len(self.v_layer)): # Perfrom graph message passing and aggr. if applicable if idx in self.v_gnn_ids: v_gnn_layer_idx = self.v_gnn_ids.index(idx) image_embedding, v_hub = self.v_gnns[v_gnn_layer_idx]( image_embedding, image_edge_indices, image_edge_attributes, q_hub, len_img_gr=len_img_gr ) # Perform standard bert self-attention image_embedding, image_attention_probs = self.v_layer[idx]( image_embedding, image_attention_mask) if output_all_attention_masks: all_attnetion_mask_v.append(image_attention_probs) for idx in range(t_start, len(self.layer)): # Perfrom graph message passing and aggr. if applicable if idx in self.t_gnn_ids: t_gnn_layer_idx = self.t_gnn_ids.index(idx) txt_embedding, h_hub, q_hub = self.t_gnns[t_gnn_layer_idx]( txt_embedding, question_edge_indices, question_edge_attributes, question_limits, history_edge_indices, history_sep_indices, v_hub, len_q_gr=len_q_gr, len_h_gr=len_h_gr, len_h_sep=len_h_sep ) # Perform standard bert self-attention txt_embedding, txt_attention_probs = self.layer[idx](txt_embedding, txt_attention_mask) if output_all_attention_masks: all_attention_mask_t.append(txt_attention_probs) # add the end part to finish. 
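        # When output_all_encoded_layers is False, only the final text and image states are
        # appended below, so callers always receive at least one entry per stream; the hub
        # states (h_hub, q_hub, v_hub) produced by the graph layers are returned alongside
        # the encoded layers and the attention maps.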
if not output_all_encoded_layers: all_encoder_layers_t.append(txt_embedding) all_encoder_layers_v.append(image_embedding) return all_encoder_layers_t, all_encoder_layers_v, (all_attention_mask_t, all_attnetion_mask_v, all_attention_mask_c), (h_hub, q_hub, v_hub) class BertTextPooler(nn.Module): def __init__(self, config): super(BertTextPooler, self).__init__() self.dense = nn.Linear(config.hidden_size, config.bi_hidden_size) self.activation = nn.ReLU() def forward(self, hidden_states): # We "pool" the model by simply taking the hidden state corresponding # to the first token. first_token_tensor = hidden_states[:, 0] pooled_output = self.dense(first_token_tensor) pooled_output = self.activation(pooled_output) return pooled_output class BertImagePooler(nn.Module): def __init__(self, config): super(BertImagePooler, self).__init__() self.dense = nn.Linear(config.v_hidden_size, config.bi_hidden_size) self.activation = nn.ReLU() def forward(self, hidden_states): # We "pool" the model by simply taking the hidden state corresponding # to the first token. first_token_tensor = hidden_states[:, 0] pooled_output = self.dense(first_token_tensor) pooled_output = self.activation(pooled_output) return pooled_output class BertPredictionHeadTransform(nn.Module): def __init__(self, config): super(BertPredictionHeadTransform, self).__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) if isinstance(config.hidden_act, str) or ( sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode) ): self.transform_act_fn = ACT2FN[config.hidden_act] else: self.transform_act_fn = config.hidden_act self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) def forward(self, hidden_states): hidden_states = self.dense(hidden_states) hidden_states = self.transform_act_fn(hidden_states) hidden_states = self.LayerNorm(hidden_states) return hidden_states class BertImgPredictionHeadTransform(nn.Module): def __init__(self, config): super(BertImgPredictionHeadTransform, self).__init__() self.dense = nn.Linear(config.v_hidden_size, config.v_hidden_size) if isinstance(config.hidden_act, str) or ( sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode) ): self.transform_act_fn = ACT2FN[config.hidden_act] else: self.transform_act_fn = config.v_hidden_act self.LayerNorm = BertLayerNorm(config.v_hidden_size, eps=1e-12) def forward(self, hidden_states): hidden_states = self.dense(hidden_states) hidden_states = self.transform_act_fn(hidden_states) hidden_states = self.LayerNorm(hidden_states) return hidden_states class BertLMPredictionHead(nn.Module): def __init__(self, config, bert_model_embedding_weights): super(BertLMPredictionHead, self).__init__() self.transform = BertPredictionHeadTransform(config) # The output weights are the same as the input embeddings, but there is # an output-only bias for each token. 
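        # In effect (descriptive sketch): logits = hidden_states @ bert_model_embedding_weights.t() + self.bias,
        # i.e. the decoder below is tied to the input embedding matrix and only the
        # per-token bias adds new parameters.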
self.decoder = nn.Linear( bert_model_embedding_weights.size(1), bert_model_embedding_weights.size(0), bias=False, ) self.decoder.weight = bert_model_embedding_weights self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0))) def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) + self.bias return hidden_states class BertOnlyMLMHead(nn.Module): def __init__(self, config, bert_model_embedding_weights): super(BertOnlyMLMHead, self).__init__() self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights) def forward(self, sequence_output): prediction_scores = self.predictions(sequence_output) return prediction_scores class BertOnlyNSPHead(nn.Module): def __init__(self, config): super(BertOnlyNSPHead, self).__init__() self.seq_relationship = nn.Linear(config.hidden_size, 2) def forward(self, pooled_output): seq_relationship_score = self.seq_relationship(pooled_output) return seq_relationship_score class BertPreTrainingHeads(nn.Module): def __init__(self, config, bert_model_embedding_weights): super(BertPreTrainingHeads, self).__init__() self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights) self.bi_seq_relationship = nn.Linear(config.bi_hidden_size, 2) self.imagePredictions = BertImagePredictionHead(config) self.fusion_method = config.fusion_method self.dropout = nn.Dropout(0.1) def forward( self, sequence_output_t, sequence_output_v, pooled_output_t, pooled_output_v ): if self.fusion_method == 'sum': pooled_output = self.dropout(pooled_output_t + pooled_output_v) elif self.fusion_method == 'mul': pooled_output = self.dropout(pooled_output_t * pooled_output_v) else: assert False prediction_scores_t = self.predictions(sequence_output_t) seq_relationship_score = self.bi_seq_relationship(pooled_output) prediction_scores_v = self.imagePredictions(sequence_output_v) return prediction_scores_t, prediction_scores_v, seq_relationship_score class BertImagePredictionHead(nn.Module): def __init__(self, config): super(BertImagePredictionHead, self).__init__() self.transform = BertImgPredictionHeadTransform(config) # The output weights are the same as the input embeddings, but there is # an output-only bias for each token. self.decoder = nn.Linear(config.v_hidden_size, config.v_target_size) def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) return hidden_states class BertPreTrainedModel(nn.Module): """ An abstract class to handle weights initialization and a simple interface for dowloading and loading pretrained models. """ def __init__(self, config, device='cuda:0', default_gpu=True, *inputs, **kwargs): super(BertPreTrainedModel, self).__init__() if not isinstance(config, BertConfig): raise ValueError( "Parameter config in `{}(config)` should be an instance of class `BertConfig`. " "To create a model from a Google pretrained model use " "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( self.__class__.__name__, self.__class__.__name__ ) ) self.config = config def init_bert_weights(self, module): """ Initialize the weights. 
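        Linear and Embedding weights are drawn from a normal distribution with
        std=config.initializer_range, BertLayerNorm modules are reset to weight 1 and
        bias 0, and Linear biases are zeroed (see the body below).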
""" if isinstance(module, (nn.Linear, nn.Embedding)): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) elif isinstance(module, BertLayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() @classmethod def from_pretrained( cls, pretrained_model_name_or_path, config, device, use_apex=False, default_gpu=True, state_dict=None, cache_dir=None, from_tf=False, *inputs, **kwargs ): """ Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict. Download and cache the pre-trained model file if needed. Params: pretrained_model_name_or_path: either: - a str with the name of a pre-trained model to load selected in the list of: . `bert-base-uncased` . `bert-large-uncased` . `bert-base-cased` . `bert-large-cased` . `bert-base-multilingual-uncased` . `bert-base-multilingual-cased` . `bert-base-chinese` - a path or url to a pretrained model archive containing: . `bert_config.json` a configuration file for the model . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance - a path or url to a pretrained model archive containing: . `bert_config.json` a configuration file for the model . `model.chkpt` a TensorFlow checkpoint from_tf: should we load the weights from a locally saved TensorFlow checkpoint cache_dir: an optional path to a folder in which the pre-trained models will be cached. state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models *inputs, **kwargs: additional input for the specific Bert class (ex: num_labels for BertForSequenceClassification) """ CONFIG_NAME = "bert_config.json" WEIGHTS_NAME = "pytorch_model.bin" TF_WEIGHTS_NAME = "model.ckpt" if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP: archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path] else: archive_file = pretrained_model_name_or_path # redirect to the cache, if necessary try: resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) except EnvironmentError: logger.error( "Model name '{}' was not found in model name list ({}). 
" "We assumed '{}' was a path or url but couldn't find any file " "associated to this path or url.".format( pretrained_model_name_or_path, ", ".join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), archive_file, ) ) return None if default_gpu: if resolved_archive_file == archive_file: logger.info("loading archive file {}".format(archive_file)) else: logger.info( "loading archive file {} from cache at {}".format( archive_file, resolved_archive_file ) ) tempdir = None if os.path.isdir(resolved_archive_file) or from_tf: serialization_dir = resolved_archive_file elif resolved_archive_file[-3:] == 'bin': serialization_dir = '/'.join(resolved_archive_file.split('/')[:-1]) WEIGHTS_NAME = resolved_archive_file.split('/')[-1] else: # Extract archive to temp dir tempdir = tempfile.mkdtemp() logger.info( "extracting archive file {} to temp dir {}".format( resolved_archive_file, tempdir ) ) with tarfile.open(resolved_archive_file, "r:gz") as archive: archive.extractall(tempdir) serialization_dir = tempdir # Load config # config_file = os.path.join(serialization_dir, CONFIG_NAME) # config = BertConfig.from_json_file(config_file) if default_gpu: # cancel output # logger.info("Model config {}".format(config)) pass # Instantiate model. model = cls(config, device, use_apex, *inputs, **kwargs) if state_dict is None and not from_tf: weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) map_location = {'cuda:0': device} state_dict = torch.load( weights_path, map_location=map_location ) if 'state_dict' in dir(state_dict): state_dict = state_dict.state_dict() if tempdir: # Clean up temp dir shutil.rmtree(tempdir) if from_tf: # Directly load from a TensorFlow checkpoint weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME) return load_tf_weights_in_bert(model, weights_path) # Load from a PyTorch state_dict old_keys = [] new_keys = [] for key in state_dict.keys(): new_key = None if "gamma" in key: new_key = key.replace("gamma", "weight") if "beta" in key: new_key = key.replace("beta", "bias") if new_key: old_keys.append(key) new_keys.append(new_key) for old_key, new_key in zip(old_keys, new_keys): state_dict[new_key] = state_dict.pop(old_key) missing_keys = [] unexpected_keys = [] error_msgs = [] # copy state_dict so _load_from_state_dict can modify it metadata = getattr(state_dict, "_metadata", None) state_dict = state_dict.copy() if metadata is not None: state_dict._metadata = metadata def load(module, prefix=""): local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) module._load_from_state_dict( state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs, ) for name, child in module._modules.items(): if child is not None: load(child, prefix + name + ".") start_prefix = "" if not hasattr(model, "bert") and any( s.startswith("bert.") for s in state_dict.keys() ): start_prefix = "bert." load(model, prefix=start_prefix) if len(missing_keys) > 0 and default_gpu: logger.info( "Weights of {} not initialized from pretrained model: {}".format( model.__class__.__name__, missing_keys ) ) if len(unexpected_keys) > 0 and default_gpu: logger.info( "Weights from pretrained model not used in {}: {}".format( model.__class__.__name__, unexpected_keys ) ) if len(error_msgs) > 0 and default_gpu: raise RuntimeError( "Error(s) in loading state_dict for {}:\n\t{}".format( model.__class__.__name__, "\n\t".join(error_msgs) ) ) return model class BertModel(BertPreTrainedModel): """BERT model ("Bidirectional Embedding Representations from a Transformer"). 

    Params:
        config: a BertConfig class instance with the configuration to build a new model

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] with the word token
            indices in the vocabulary (see the tokens preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the
            token type indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1
            corresponds to a `sentence B` token (see BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with
            indices selected in [0, 1]. It's a mask to be used if the input sequence length is smaller
            than the max input sequence length in the current batch. It's the mask that we typically
            use for attention when a batch has varying length sentences.
        `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output
            as described below. Default: `True`.

    Outputs: Tuple of (encoded_layers, pooled_output)
        `encoded_layers`: controlled by the `output_all_encoded_layers` argument:
            - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded hidden
              states at the end of each attention block (i.e. 12 full sequences for BERT-base, 24 for
              BERT-large); each encoded hidden state is a torch.FloatTensor of size
              [batch_size, sequence_length, hidden_size],
            - `output_all_encoded_layers=False`: outputs only the full sequence of hidden states
              corresponding to the last attention block, of shape [batch_size, sequence_length, hidden_size],
        `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
            classifier pretrained on top of the hidden state associated with the first token of the
            input (`[CLS]`) to train on the Next-Sentence task (see BERT's paper).

    Example usage (text inputs only; the `forward` of this multimodal variant additionally expects
    image features, image locations and the graph edge inputs listed in its signature below):

    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

    config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

    model = modeling.BertModel(config=config)
    all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
    ```
    """

    def __init__(self, config, device, use_apex=False):
        super(BertModel, self).__init__(config, device)

        # initialize the word embeddings
        self.embeddings = BertEmbeddingsDialog(config, device)

        # initialize the vision embeddings
        self.v_embeddings = BertImageEmbeddings(config)

        self.encoder = BertEncoder(config)
        self.t_pooler = BertTextPooler(config)
        self.v_pooler = BertImagePooler(config)
        self.use_apex = use_apex
        self.apply(self.init_bert_weights)

    def forward(
        self,
        input_txt,
        input_imgs,
        image_loc,
        image_edge_indices,
        image_edge_attributes,
        question_edge_indices,
        question_edge_attributes,
        question_limits,
        history_edge_indices,
        history_sep_indices,
        sep_indices=None,
        sep_len=None,
        token_type_ids=None,
        attention_mask=None,
        image_attention_mask=None,
        co_attention_mask=None,
        output_all_encoded_layers=False,
        output_all_attention_masks=False,
    ):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_txt)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_txt)
        if image_attention_mask is None:
            image_attention_mask = torch.ones(
                input_imgs.size(0), input_imgs.size(1)
            ).type_as(input_txt)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # so we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length].
        # This attention mask is simpler than the triangular masking of causal attention
        # used in OpenAI GPT; we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_image_attention_mask = image_attention_mask.unsqueeze(1).unsqueeze(2)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
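        # Illustrative example (not part of the original code): for attention_mask = [1, 1, 0],
        # the additive mask computed below is (1.0 - m) * -10000.0 = [0.0, 0.0, -10000.0],
        # shaped [1, 1, 1, 3] so it broadcasts over heads and query positions; the padded
        # position therefore receives ~zero weight after every softmax.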
        if self.use_apex:
            dtype = next(self.parameters()).dtype
        else:
            dtype = torch.float32
        extended_attention_mask = extended_attention_mask.to(dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        extended_image_attention_mask = extended_image_attention_mask.to(dtype)  # fp16 compatibility
        extended_image_attention_mask = (1.0 - extended_image_attention_mask) * -10000.0

        if co_attention_mask is None:
            co_attention_mask = torch.zeros(
                input_txt.size(0), input_imgs.size(1), input_txt.size(1)
            ).type_as(extended_image_attention_mask)
        extended_co_attention_mask = co_attention_mask.unsqueeze(1)
        # extended_co_attention_mask = co_attention_mask.unsqueeze(-1)
        extended_co_attention_mask = extended_co_attention_mask * 5.0
        extended_co_attention_mask = extended_co_attention_mask.to(dtype)  # fp16 compatibility

        embedding_output = self.embeddings(
            input_txt, token_type_ids=token_type_ids, sep_indices=sep_indices, sep_len=sep_len
        )
        v_embedding_output = self.v_embeddings(input_imgs, image_loc)
        encoded_layers_t, encoded_layers_v, all_attention_mask, hub_feats = self.encoder(
            embedding_output,
            v_embedding_output,
            extended_attention_mask,
            extended_image_attention_mask,
            image_edge_indices,
            image_edge_attributes,
            question_edge_indices,
            question_edge_attributes,
            question_limits,
            history_edge_indices,
            history_sep_indices,
            co_attention_mask=extended_co_attention_mask,
            output_all_encoded_layers=output_all_encoded_layers,
            output_all_attention_masks=output_all_attention_masks,
        )
        sequence_output_t = encoded_layers_t[-1]
        sequence_output_v = encoded_layers_v[-1]

        pooled_output_t = self.t_pooler(sequence_output_t)
        pooled_output_v = self.v_pooler(sequence_output_v)

        if not output_all_encoded_layers:
            encoded_layers_t = encoded_layers_t[-1]
            encoded_layers_v = encoded_layers_v[-1]

        return encoded_layers_t, encoded_layers_v, pooled_output_t, pooled_output_v, all_attention_mask, hub_feats


class BertImageEmbeddings(nn.Module):
    """Construct the embeddings from image features, spatial location (omitted for now) and token_type embeddings.
    """

    def __init__(self, config):
        super(BertImageEmbeddings, self).__init__()

        self.image_embeddings = nn.Linear(config.v_feature_size, config.v_hidden_size)
        self.image_location_embeddings = nn.Linear(5, config.v_hidden_size)
        self.LayerNorm = BertLayerNorm(config.v_hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, input_loc):
        img_embeddings = self.image_embeddings(input_ids)
        loc_embeddings = self.image_location_embeddings(input_loc)
        embeddings = self.LayerNorm(img_embeddings + loc_embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class BertForMultiModalPreTraining(BertPreTrainedModel):
    """BERT model with multi modal pre-training heads.
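    This class wraps BertModel with the masked language modelling, masked image-region and
    next-sentence prediction heads used during pre-training.

    Example usage (an illustrative sketch, not taken from the original code; `device` is assumed
    to be a torch device string, and the graph inputs required by `forward` are omitted here):

    ```python
    config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
                        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
    model = BertForMultiModalPreTraining(config, device='cuda:0')
    # forward() expects, in addition to `input_ids`, the image features/locations and the
    # image/question/history graph edge tensors listed in its signature below.
    ```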
""" def __init__(self, config, device, use_apex=False): super(BertForMultiModalPreTraining, self).__init__(config, device) self.bert = BertModel(config, device, use_apex) self.cls = BertPreTrainingHeads( config, self.bert.embeddings.word_embeddings.weight ) self.apply(self.init_bert_weights) self.predict_feature = config.predict_feature self.loss_fct = CrossEntropyLoss(ignore_index=-1) print("model's option for predict_feature is ", config.predict_feature) if self.predict_feature: self.vis_criterion = nn.MSELoss(reduction="none") else: self.vis_criterion = nn.KLDivLoss(reduction="none") def forward( self, input_ids, image_feat, image_loc, image_edge_indices, image_edge_attributes, question_edge_indices, question_edge_attributes, question_limits, history_edge_indices, history_sep_indices, sep_indices=None, sep_len=None, token_type_ids=None, attention_mask=None, image_attention_mask=None, masked_lm_labels=None, image_label=None, image_target = None, next_sentence_label=None, output_all_attention_masks=False ): # in this model, we first embed the images. sequence_output_t, sequence_output_v, pooled_output_t, pooled_output_v, all_attention_mask, hub_nodes = self.bert( input_ids, image_feat, image_loc, image_edge_indices, image_edge_attributes, question_edge_indices, question_edge_attributes, question_limits, history_edge_indices, history_sep_indices, sep_indices=sep_indices, sep_len=sep_len, token_type_ids=token_type_ids, attention_mask=attention_mask, image_attention_mask=image_attention_mask, output_all_encoded_layers=False, output_all_attention_masks=output_all_attention_masks ) prediction_scores_t, prediction_scores_v, seq_relationship_score = self.cls( sequence_output_t, sequence_output_v, pooled_output_t, pooled_output_v ) if masked_lm_labels is not None and next_sentence_label is not None and image_target is not None: # prediction_scores_v = prediction_scores_v[:, 1:] if self.predict_feature: img_loss = self.vis_criterion(prediction_scores_v, image_target) masked_img_loss = torch.sum( img_loss * (image_label == 1).unsqueeze(2).float() ) / max(torch.sum((image_label == 1).unsqueeze(2).expand_as(img_loss)),1) else: img_loss = self.vis_criterion( F.log_softmax(prediction_scores_v, dim=2), image_target ) masked_img_loss = torch.sum( img_loss * (image_label == 1).unsqueeze(2).float() ) / max(torch.sum((image_label == 1)), 0) # masked_img_loss = torch.sum(img_loss) / (img_loss.shape[0] * img_loss.shape[1]) masked_lm_loss = self.loss_fct( prediction_scores_t.view(-1, self.config.vocab_size), masked_lm_labels.view(-1), ) next_sentence_loss = self.loss_fct( seq_relationship_score.view(-1, 2), next_sentence_label.view(-1) ) # total_loss = masked_lm_loss + next_sentence_loss + masked_img_loss return masked_lm_loss.unsqueeze(0), masked_img_loss.unsqueeze(0), next_sentence_loss.unsqueeze(0), sequence_output_t, prediction_scores_t, seq_relationship_score, hub_nodes else: return prediction_scores_t, prediction_scores_v, seq_relationship_score, sequence_output_t, all_attention_mask, hub_nodes def get_text_embedding( self, input_ids, image_feat, image_loc, sep_indices=None, sep_len=None, token_type_ids=None, attention_mask=None, image_attention_mask=None, output_all_attention_masks=False ): # in this model, we first embed the images. 
        sequence_output_t, sequence_output_v, pooled_output_t, pooled_output_v, all_attention_mask = self.bert(
            input_ids,
            image_feat,
            image_loc,
            sep_indices=sep_indices,
            sep_len=sep_len,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            image_attention_mask=image_attention_mask,
            output_all_encoded_layers=False,
            output_all_attention_masks=output_all_attention_masks
        )
        return sequence_output_t  # [batch_size, num_words, 768]


class VILBertForVLTasks(BertPreTrainedModel):
    def __init__(self, config, device, num_labels, use_apex=False, dropout_prob=0.1, default_gpu=True):
        super(VILBertForVLTasks, self).__init__(config, device)
        self.num_labels = num_labels
        self.bert = BertModel(config, device, use_apex)
        self.use_apex = use_apex
        self.dropout = nn.Dropout(dropout_prob)
        self.cls = BertPreTrainingHeads(
            config, self.bert.embeddings.word_embeddings.weight
        )
        self.vil_prediction = SimpleClassifier(config.bi_hidden_size, config.bi_hidden_size * 2, num_labels, 0.5)
        # self.vil_prediction = nn.Linear(config.bi_hidden_size, num_labels)
        self.vil_logit = nn.Linear(config.bi_hidden_size, 1)
        self.vision_logit = nn.Linear(config.v_hidden_size, 1)
        self.linguisic_logit = nn.Linear(config.hidden_size, 1)
        self.fusion_method = config.fusion_method
        self.apply(self.init_bert_weights)

    def forward(
        self,
        input_txt,
        input_imgs,
        image_loc,
        token_type_ids=None,
        attention_mask=None,
        image_attention_mask=None,
        co_attention_mask=None,
        output_all_encoded_layers=False,
    ):
        sequence_output_t, sequence_output_v, pooled_output_t, pooled_output_v, _ = self.bert(
            input_txt,
            input_imgs,
            image_loc,
            token_type_ids,
            attention_mask,
            image_attention_mask,
            co_attention_mask,
            output_all_encoded_layers=False,
        )

        vil_prediction = 0
        vil_logit = 0
        vil_binary_prediction = 0
        vision_prediction = 0
        vision_logit = 0
        linguisic_prediction = 0
        linguisic_logit = 0

        linguisic_prediction, vision_prediction, vil_binary_prediction = self.cls(
            sequence_output_t, sequence_output_v, pooled_output_t, pooled_output_v
        )

        if self.fusion_method == 'sum':
            pooled_output = self.dropout(pooled_output_t + pooled_output_v)
        elif self.fusion_method == 'mul':
            pooled_output = self.dropout(pooled_output_t * pooled_output_v)
        else:
            assert False

        vil_prediction = self.vil_prediction(pooled_output)
        vil_logit = self.vil_logit(pooled_output)
        if self.use_apex:
            dtype = next(self.parameters()).dtype
        else:
            dtype = torch.float32
        vision_logit = self.vision_logit(self.dropout(sequence_output_v)) + (
            (1.0 - image_attention_mask) * -10000.0
        ).unsqueeze(2).to(dtype)
        linguisic_logit = self.linguisic_logit(self.dropout(sequence_output_t))

        return vil_prediction, vil_logit, vil_binary_prediction, vision_prediction, vision_logit, linguisic_prediction, linguisic_logit


class SimpleClassifier(nn.Module):
    def __init__(self, in_dim, hid_dim, out_dim, dropout):
        super(SimpleClassifier, self).__init__()
        layers = [
            weight_norm(nn.Linear(in_dim, hid_dim), dim=None),
            nn.ReLU(),
            nn.Dropout(dropout, inplace=True),
            weight_norm(nn.Linear(hid_dim, out_dim), dim=None)
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.main(x)