# MST-MIXER/models/gnns.py
# 2024-07-08 11:41:28 +02:00
#
# 801 lines
# 33 KiB
# Python

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn.dense import DenseGATConv, DenseGCNConv, DenseSAGEConv
from torch.nn.parameter import Parameter
from typing import Optional, Tuple
from .utils import get_knn_graph
import torch_sparse
class BartAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Queries, keys and values get separate linear projections. They are split into
    `num_heads` heads of size `embed_dim // num_heads`, and scaled dot-product
    attention is applied per head. The heads are then merged back and passed
    through `out_proj`. Works as self-attention, or as cross-attention when
    `key_value_states` is given.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        # Dropout probability applied to the attention probabilities in forward().
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        # 1 / sqrt(head_dim); multiplied into the queries before the dot product.
        self.scaling = self.head_dim**-0.5
        # When True, forward() returns the projected (key, value) states so they can be cached.
        self.is_decoder = is_decoder

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Split heads: (bsz, seq_len, embed_dim) -> (bsz, num_heads, seq_len, head_dim).

        `seq_len` may be -1 to infer it from the tensor size.
        """
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel

        Args:
            hidden_states: (bsz, tgt_len, embed_dim) query-side input.
            key_value_states: if given, keys/values are projected from it (cross-attention).
            past_key_value: cached (key, value) pair in the per-head layout produced by
                `_shape`, i.e. (bsz, num_heads, seq, head_dim).
            attention_mask: additive mask of shape (bsz, 1, tgt_len, src_len), added to
                the raw scores before the softmax.
            layer_head_mask: (num_heads,) multiplier applied per head after the softmax.
            output_attentions: if True, also return the attention probabilities.

        Returns:
            (attn_output of shape (bsz, tgt_len, embed_dim),
             attention probabilities of shape (bsz, num_heads, tgt_len, src_len) or None,
             past_key_value: the new (key, value) pair when `is_decoder`, otherwise the
             value that was passed in).

        Raises:
            ValueError: if an intermediate tensor or a mask has an unexpected size.
        """

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `key_value_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            # Append the new positions after the cached ones along the sequence axis.
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        # Fold heads into the batch dimension so a single bmm handles all heads.
        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.reshape(*proj_shape)
        value_states = value_states.reshape(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            # The mask broadcasts over the head axis; it is additive (e.g. large negatives to block).
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # Merge heads back: (bsz * heads, tgt, head_dim) -> (bsz, tgt, heads, head_dim).
        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value
class MLPModule(nn.Module):
    """Stack of Linear layers with GELU, optional BatchNorm and dropout in between.

    Every hidden layer computes ``dropout(batch_norm?(gelu(linear(x))))``. The
    last Linear layer is applied on its own, with no activation.

    Args:
        d_in: input feature size (last dimension of ``X``).
        d_hidden: width of every hidden layer.
        d_out: output feature size.
        num_layers: number of Linear layers. Exactly 1 gives a single
            ``Linear(d_in, d_out)``; any other value gives at least two layers.
        dropout: dropout probability used after every hidden layer.
        use_non_linear: kept on the instance but not read by ``forward``.
        use_batch_norm: apply ``BatchNorm1d`` after the activation of each hidden layer.
    """

    def __init__(self, d_in, d_hidden, d_out, num_layers=3, dropout=0.3, use_non_linear=False, use_batch_norm=False):
        super(MLPModule, self).__init__()
        self.use_batch_norm = use_batch_norm

        if num_layers == 1:
            widths = [d_in, d_out]
        else:
            widths = [d_in] + [d_hidden] * max(num_layers - 1, 1) + [d_out]
        last_hidden_idx = len(widths) - 2

        self.fcs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        for idx, (n_in, n_out) in enumerate(zip(widths, widths[1:])):
            self.fcs.append(nn.Linear(n_in, n_out))
            # Only hidden layers (all but the final projection) get a BatchNorm.
            if idx < last_hidden_idx:
                self.batch_norms.append(nn.BatchNorm1d(n_out))

        self.act_fn = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        # NOTE(review): stored for parity with the GNN modules but never read;
        # forward() applies the GELU unconditionally.
        self.use_non_linear = use_non_linear

    def reset_parameters(self):
        """Re-initialise all Linear and BatchNorm layers (Linear layers first)."""
        for module in (*self.fcs, *self.batch_norms):
            module.reset_parameters()

    def forward(self, X):
        """Apply the MLP to ``X`` of shape (N, d_in) or (B, N, d_in)."""
        *hidden_layers, output_layer = self.fcs
        for linear, norm in zip(hidden_layers, self.batch_norms):
            X = self.act_fn(linear(X))
            if self.use_batch_norm:
                # BatchNorm1d normalises dim 1, so move the feature axis there for 3-D input.
                X = norm(X.transpose(1, 2)).transpose(1, 2) if X.dim() > 2 else norm(X)
            X = self.dropout(X)
        return output_layer(X)
class GATModule(nn.Module):
    """Stack of ``DenseGATConv`` layers with optional BatchNorm / GELU between them.

    Args:
        d_in: input feature size.
        d_hidden: requested hidden width. With ``concat=True`` it is split across
            heads (``d_hidden // heads`` channels per head), so a hidden layer
            actually outputs ``(d_hidden // heads) * heads`` features.
        d_out: requested output width, split across heads the same way.
        num_layers: total number of GAT layers; values below 2 still build two
            layers (input layer + output layer).
        dropout: passed to every ``DenseGATConv`` and also used as the feature
            dropout applied to the input and after each hidden layer.
        concat: forwarded to ``DenseGATConv`` (concatenate vs. combine heads).
        heads: number of attention heads per layer.
        use_non_linear: apply GELU after each hidden layer.
        use_batch_norm: apply ``BatchNorm1d`` over the feature axis after each hidden layer.
    """

    def __init__(self, d_in, d_hidden, d_out, num_layers=3, dropout=0.3, concat=True, heads=2, use_non_linear=False, use_batch_norm=False):
        super(GATModule, self).__init__()
        # Channels per head handed to DenseGATConv.
        head_hidden = d_hidden // heads if concat else d_hidden
        head_out = d_out // heads if concat else d_out
        # Feature width produced by a hidden layer.
        hidden_width = head_hidden * heads if concat else head_hidden

        self.gnns = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        in_width = d_in
        for _ in range(max(num_layers - 1, 1)):
            self.gnns.append(DenseGATConv(in_width, head_hidden, heads=heads, concat=concat, dropout=dropout))
            self.batch_norms.append(nn.BatchNorm1d(hidden_width))
            in_width = hidden_width
        self.gnns.append(DenseGATConv(in_width, head_out, heads=heads, concat=concat, dropout=dropout))

        self.dropout = nn.Dropout(dropout)
        self.non_linear = nn.GELU()
        self.use_batch_norm = use_batch_norm
        self.use_non_linear = use_non_linear

    def reset_parameters(self):
        """Re-initialise every GAT layer, then every BatchNorm layer."""
        for module in (*self.gnns, *self.batch_norms):
            module.reset_parameters()

    def forward(self, X, A):
        """Run the GAT stack on node features ``X`` with the dense adjacency ``A``.

        ``A`` is passed unchanged to every layer. The BatchNorm path swaps
        dims 1 and 2, so ``X`` is presumably (batch, nodes, d_in) — confirm with callers.
        """
        Z = self.dropout(X)
        *hidden_convs, output_conv = self.gnns
        for conv, norm in zip(hidden_convs, self.batch_norms):
            Z = conv(Z, A)
            if self.use_batch_norm:
                Z = norm(Z.transpose(1, 2)).transpose(1, 2)
            if self.use_non_linear:
                Z = self.non_linear(Z)
            Z = self.dropout(Z)
        return output_conv(Z, A)
class GCNModule(nn.Module):
    """Stack of ``DenseGCNConv`` layers with optional BatchNorm / GELU between them.

    Args:
        d_in: input feature size.
        d_hidden: width of every hidden layer.
        d_out: output feature size.
        num_layers: total number of GCN layers; values below 2 still build two layers.
        dropout: feature dropout applied to the input and after each hidden layer.
        use_non_linear: apply GELU after each hidden layer.
        use_batch_norm: apply ``BatchNorm1d`` over the feature axis after each hidden layer.
    """

    def __init__(self, d_in, d_hidden, d_out, num_layers=3, dropout=0.3, use_non_linear=False, use_batch_norm=False):
        super(GCNModule, self).__init__()
        widths = [d_in] + [d_hidden] * max(num_layers - 1, 1) + [d_out]
        layer_shapes = list(zip(widths, widths[1:]))
        self.gnns = nn.ModuleList(DenseGCNConv(n_in, n_out) for n_in, n_out in layer_shapes)
        # One BatchNorm per hidden layer, i.e. for every layer except the last.
        self.batch_norms = nn.ModuleList(nn.BatchNorm1d(n_out) for _, n_out in layer_shapes[:-1])
        self.dropout = nn.Dropout(dropout)
        self.non_linear = nn.GELU()
        self.use_batch_norm = use_batch_norm
        self.use_non_linear = use_non_linear

    def reset_parameters(self):
        """Re-initialise every GCN layer, then every BatchNorm layer."""
        for module in (*self.gnns, *self.batch_norms):
            module.reset_parameters()

    def forward(self, X, A):
        """Run the GCN stack on node features ``X`` with the dense adjacency ``A``.

        The BatchNorm path swaps dims 1 and 2, so ``X`` is presumably
        (batch, nodes, d_in) — confirm with callers.
        """
        Z = self.dropout(X)
        hidden_convs = list(self.gnns)[:-1]
        for conv, norm in zip(hidden_convs, self.batch_norms):
            Z = conv(Z, A)
            if self.use_batch_norm:
                Z = norm(Z.transpose(1, 2)).transpose(1, 2)
            if self.use_non_linear:
                Z = self.non_linear(Z)
            Z = self.dropout(Z)
        return self.gnns[-1](Z, A)
class SAGEModule(nn.Module):
    """Stack of ``DenseSAGEConv`` layers with optional BatchNorm / GELU between them.

    Args:
        d_in: input feature size.
        d_hidden: width of every hidden layer.
        d_out: output feature size.
        num_layers: total number of SAGE layers; values below 2 still build two layers.
        dropout: feature dropout applied to the input and after each hidden layer.
        use_non_linear: apply GELU after each hidden layer.
        use_batch_norm: apply ``BatchNorm1d`` over the feature axis after each hidden layer.
    """

    def __init__(self, d_in, d_hidden, d_out, num_layers=3, dropout=0.3, use_non_linear=False, use_batch_norm=False):
        super(SAGEModule, self).__init__()
        widths = [d_in] + [d_hidden] * max(num_layers - 1, 1) + [d_out]
        layer_shapes = list(zip(widths, widths[1:]))
        self.gnns = nn.ModuleList(DenseSAGEConv(n_in, n_out) for n_in, n_out in layer_shapes)
        # One BatchNorm per hidden layer, i.e. for every layer except the last.
        self.batch_norms = nn.ModuleList(nn.BatchNorm1d(n_out) for _, n_out in layer_shapes[:-1])
        self.dropout = nn.Dropout(dropout)
        self.non_linear = nn.GELU()
        self.use_batch_norm = use_batch_norm
        self.use_non_linear = use_non_linear

    def reset_parameters(self):
        """Re-initialise every SAGE layer, then every BatchNorm layer."""
        for module in (*self.gnns, *self.batch_norms):
            module.reset_parameters()

    def forward(self, X, A):
        """Run the GraphSAGE stack on node features ``X`` with the dense adjacency ``A``.

        The BatchNorm path swaps dims 1 and 2, so ``X`` is presumably
        (batch, nodes, d_in) — confirm with callers.
        """
        Z = self.dropout(X)
        *hidden_convs, output_conv = self.gnns
        for conv, norm in zip(hidden_convs, self.batch_norms):
            Z = conv(Z, A)
            if self.use_batch_norm:
                Z = norm(Z.transpose(1, 2)).transpose(1, 2)
            if self.use_non_linear:
                Z = self.non_linear(Z)
            Z = self.dropout(Z)
        return output_conv(Z, A)
class GlobalGraphLearner(nn.Module):
    """Learn a dense, non-negative similarity graph over the nodes of each sample.

    The graph is built in several steps. Node features are re-weighted by each
    of ``num_heads`` learnable vectors and L2-normalised. Their pairwise dot
    products are then averaged over the heads, and negative entries are zeroed.
    With ``random=True`` no parameters are created and a Gaussian random matrix
    is used instead, masked the same way.

    Args:
        d_in: size of the last dimension of the input ``Z``.
        num_heads: number of learnable weighting vectors averaged together.
        random: if True, produce random graphs instead of learned ones.
    """

    def __init__(self, d_in, num_heads, random=False):
        super(GlobalGraphLearner, self).__init__()
        self.random = random
        if not self.random:
            w = torch.Tensor(num_heads, d_in)
            self.w = Parameter(nn.init.xavier_uniform_(w), requires_grad=True)

    def reset_parameters(self):
        """Re-initialise ``w`` in place.

        The registered ``Parameter`` object is kept, so optimizers and hooks that
        already reference it keep working. Previously a new ``Parameter`` was
        assigned, which left optimizers updating a stale tensor.
        """
        if not self.random:
            nn.init.xavier_uniform_(self.w)

    def forward(self, Z):
        """Return a (Z.size(0), Z.size(1), Z.size(1)) non-negative adjacency for ``Z``.

        ``Z`` is expected to be (batch, nodes, d_in).
        """
        if self.random:
            att_global = torch.randn((Z.size(0), Z.size(1), Z.size(1))).to(Z.device)
        else:
            # (num_heads, d_in) -> (num_heads, 1, 1, d_in) to broadcast over (batch, nodes).
            w_expanded = self.w.unsqueeze(1).unsqueeze(1)
            Z = Z.unsqueeze(0) * w_expanded
            Z = F.normalize(Z, p=2, dim=-1)
            # Per-head cosine similarities, averaged over heads.
            att_global = torch.matmul(Z, Z.transpose(-1, -2)).mean(0)
        # Keep only positive entries; the mask itself carries no gradient.
        mask_global = (att_global > 0).detach().float()
        att_global = att_global * mask_global
        return att_global
class DenseAPPNP(nn.Module):
    """Parameter-free APPNP-style propagation over a given adjacency.

    Runs ``K`` steps of ``x <- (1 - alpha) * (adj_t @ x) + alpha * h``, where
    ``h`` is the input, and then divides the result by ``K``.

    Args:
        K: number of propagation steps.
        alpha: weight given to the original input ``h`` at every step.
    """

    def __init__(self, K, alpha):
        super().__init__()
        self.K = K
        self.alpha = alpha

    @staticmethod
    def _propagate(adj_t, x):
        """Return ``adj_t @ x`` for a dense or a torch sparse (COO) ``adj_t``."""
        if adj_t.is_sparse:
            # torch_sparse.spmm expects (index, value, m, n, matrix), so it cannot take a
            # torch sparse tensor directly; use torch's native sparse products instead.
            if adj_t.dim() == 2:
                return torch.sparse.mm(adj_t, x)
            return torch.bmm(adj_t, x)
        return torch.matmul(adj_t, x)

    def forward(self, x, adj_t):
        """Propagate ``x`` over ``adj_t`` for ``K`` steps.

        All updates are out-of-place, so the caller's ``x`` is never modified.
        Note that ``K == 0`` divides by zero (inf/nan result).
        """
        h = x
        for _ in range(self.K):
            x = self._propagate(adj_t, x)
            x = x * (1 - self.alpha) + self.alpha * h
        return x / self.K
class Dense_APPNP_Net(nn.Module):
    """Two-layer MLP (ReLU, dropout) whose output is propagated with ``DenseAPPNP``.

    Args:
        d_in: input feature size.
        d_hidden: hidden width of the MLP.
        d_out: output feature size.
        dropout: dropout probability applied before each Linear layer.
        K: number of propagation steps for ``DenseAPPNP``.
        alpha: weight of the MLP output kept at every propagation step.
    """

    def __init__(self, d_in, d_hidden, d_out, dropout=.5, K=10, alpha=.1):
        super(Dense_APPNP_Net, self).__init__()
        self.lin1 = nn.Linear(d_in, d_hidden)
        self.lin2 = nn.Linear(d_hidden, d_out)
        self.prop1 = DenseAPPNP(K, alpha)
        self.dropout = dropout

    def reset_parameters(self):
        """Re-initialise both Linear layers (the propagation step has no parameters)."""
        for layer in (self.lin1, self.lin2):
            layer.reset_parameters()

    def forward(self, x, adj_t):
        """Encode ``x`` with the MLP, then propagate it over ``adj_t``."""
        p, training = self.dropout, self.training
        hidden = F.relu(self.lin1(F.dropout(x, p=p, training=training)))
        projected = self.lin2(F.dropout(hidden, p=p, training=training))
        return self.prop1(projected, adj_t)
class MMGraphLearner(nn.Module):
    """Learn a dense, non-negative similarity graph from (projected) node features.

    In learned mode the features first go through a Linear layer. They are then
    re-weighted per head by learnable vectors and L2-normalised. Pairwise dot
    products are averaged over heads and negative entries are zeroed. With
    ``random=True`` a Gaussian random matrix is used instead, masked the same way.

    Args:
        d_in: size of the last dimension of ``features``.
        num_heads: number of learnable weighting vectors averaged together.
        random: if True, produce random graphs instead of learned ones.
    """

    def __init__(self, d_in, num_heads, random=False):
        super(MMGraphLearner, self).__init__()
        self.random = random
        if not self.random:
            w = torch.Tensor(num_heads, d_in)
            self.w = Parameter(nn.init.xavier_uniform_(w), requires_grad=True)
        # NOTE(review): `fc` is created regardless of `random`; forward() only uses it
        # when `random` is False — confirm this matches existing checkpoint keys.
        self.fc = nn.Linear(d_in, d_in)

    def reset_parameters(self):
        """Re-initialise ``fc`` and ``w`` in place (only in learned mode).

        ``w`` is re-initialised without replacing the registered ``Parameter``,
        so optimizers holding a reference to it keep updating the live tensor.
        """
        if not self.random:
            self.fc.reset_parameters()
            nn.init.xavier_uniform_(self.w)

    def forward(self, features):
        """Return a (features.size(0), features.size(1), features.size(1)) adjacency.

        ``features`` is expected to be (batch, nodes, d_in).
        """
        if self.random:
            att = torch.randn((features.size(0), features.size(1), features.size(1))).to(features.device)
        else:
            features = self.fc(features)
            # (num_heads, d_in) -> (num_heads, 1, 1, d_in) to broadcast over (batch, nodes).
            w_expanded = self.w.unsqueeze(1).unsqueeze(1)
            features = features.unsqueeze(0) * w_expanded
            features = F.normalize(features, p=2, dim=-1)
            # Per-head cosine similarities, averaged over heads.
            att = torch.matmul(features, features.transpose(-1, -2)).mean(0)
        # Keep only positive entries; the mask itself carries no gradient.
        mask = (att > 0).detach().float()
        att = att * mask
        return att
class QNetLocal(nn.Module):
    """Per-modality local graph module (Q-net variant).

    Each modality owns one GNN and two ``MMGraphLearner`` instances. The forward
    pass runs that GNN twice with shared weights:

    * upper branch: on a learned graph blended with the kNN graph ``A_tilde``;
    * lower branch: on ``A_tilde`` directly.

    A second learner then builds another graph from both branches' embeddings.
    The graphs and the embeddings are finally blended per modality.
    """

    def __init__(self, config):
        super(QNetLocal, self).__init__()
        self.config = config
        self.mm_gnn_modules = nn.ModuleList()
        self.mm_graph_learners_1 = nn.ModuleList()
        self.mm_graph_learners_2 = nn.ModuleList()
        cfg = self.config
        for _ in range(cfg.num_modalities):
            self.mm_gnn_modules.append(self._build_local_gnn(cfg))
            self.mm_graph_learners_1.append(MMGraphLearner(cfg.d_model, cfg.num_local_gr_learner_heads, random=cfg.use_random_graphs))
            self.mm_graph_learners_2.append(MMGraphLearner(cfg.d_model * 2, cfg.num_local_gr_learner_heads, random=cfg.use_random_graphs))

    @staticmethod
    def _build_local_gnn(cfg):
        """Instantiate the local GNN selected by ``cfg.gnn_type``; raise ValueError otherwise."""
        d = cfg.d_model
        kind = cfg.gnn_type
        if kind == 'gat':
            return GATModule(
                d, d, d,
                num_layers=cfg.num_local_gnn_layers,
                heads=cfg.num_local_gnn_heads,
                dropout=cfg.local_gnn_dropout,
                concat=cfg.local_gnn_concat,
                use_batch_norm=cfg.use_local_gnn_bn,
                use_non_linear=cfg.use_non_linear,
            )
        if kind == 'appnp':
            return Dense_APPNP_Net(d, d, d, dropout=cfg.local_gnn_dropout, K=cfg.gnn_K, alpha=cfg.gnn_alpha)
        if kind == 'gcn':
            return GCNModule(
                d, d, d,
                num_layers=cfg.num_local_gnn_layers,
                dropout=cfg.local_gnn_dropout,
                use_batch_norm=cfg.use_local_gnn_bn,
                use_non_linear=cfg.use_non_linear,
            )
        if kind == 'sage':
            return SAGEModule(
                d, d, d,
                num_layers=cfg.num_local_gnn_layers,
                dropout=cfg.local_gnn_dropout,
                use_batch_norm=cfg.use_local_gnn_bn,
                use_non_linear=cfg.use_non_linear,
            )
        raise ValueError

    def reset_parameters(self):
        """Re-initialise the GNN and both graph learners of every modality."""
        for i in range(self.config.num_modalities):
            for module_list in (self.mm_gnn_modules, self.mm_graph_learners_1, self.mm_graph_learners_2):
                module_list[i].reset_parameters()

    def forward(self, features, A_tildes=None):
        """Return ``(As, Zs)``: one blended adjacency and one embedding per modality.

        Args:
            features: sequence of per-modality node-feature tensors.
            A_tildes: optional per-modality initial graphs; when None they are built
                with ``get_knn_graph(X, config.num_nn, device)``.
        """
        cfg = self.config
        mm_Xs = features
        device = features[0].device
        if A_tildes is None:
            A_tildes = [get_knn_graph(mm_X, cfg.num_nn, device) for mm_X in mm_Xs]

        def blend(first, second, ratio):
            # (1 - ratio) * first + ratio * second
            return (1 - ratio) * first + ratio * second

        # Upper branch: learned graph per modality, anchored to the initial graph.
        learned_1 = [self.mm_graph_learners_1[i](mm_X) for i, mm_X in enumerate(mm_Xs)]
        A_primes = [blend(A_p, A_t, cfg.init_adj_ratio) for A_p, A_t in zip(learned_1, A_tildes)]
        Z_primes = [self.mm_gnn_modules[i](mm_X, A_p) for i, (mm_X, A_p) in enumerate(zip(mm_Xs, A_primes))]

        # Lower branch: the same GNN on the initial graph.
        Z_double_primes = [self.mm_gnn_modules[i](mm_X, A_t) for i, (mm_X, A_t) in enumerate(zip(mm_Xs, A_tildes))]

        # Second learner sees both branches' embeddings side by side.
        Z_concats = [torch.cat([Z_1, Z_2], dim=-1) for Z_1, Z_2 in zip(Z_primes, Z_double_primes)]
        learned_2 = [self.mm_graph_learners_2[i](Z_c) for i, Z_c in enumerate(Z_concats)]
        A_double_primes = [blend(A_dp, A_t, cfg.init_adj_ratio) for A_dp, A_t in zip(learned_2, A_tildes)]

        As = [blend(A_p, A_dp, cfg.adj_ratio) for A_p, A_dp in zip(A_primes, A_double_primes)]
        Zs = [blend(Z_1, Z_2, 0.5) for Z_1, Z_2 in zip(Z_primes, Z_double_primes)]
        return As, Zs
class PNetLocal(nn.Module):
    """Per-modality local graph module (P-net variant).

    Each modality owns one GNN, one MLP and two ``MMGraphLearner`` instances:

    * upper branch: GNN on a graph learned from the modality features;
    * lower branch: MLP on the same features (no graph).

    A second learner builds another graph from both branches' embeddings. The
    graphs and the embeddings are then blended per modality. No initial graph
    is used.
    """

    def __init__(self, config):
        super(PNetLocal, self).__init__()
        self.config = config
        self.mm_gnn_modules = nn.ModuleList()
        self.mm_mlp_modules = nn.ModuleList()
        self.mm_graph_learners_1 = nn.ModuleList()
        self.mm_graph_learners_2 = nn.ModuleList()
        cfg = self.config
        for _ in range(cfg.num_modalities):
            self.mm_gnn_modules.append(self._build_local_gnn(cfg))
            self.mm_mlp_modules.append(MLPModule(
                cfg.d_model, cfg.d_model, cfg.d_model,
                num_layers=cfg.num_local_fc_layers,
                dropout=cfg.local_fc_dropout,
                use_batch_norm=cfg.use_local_fc_bn,
                use_non_linear=cfg.use_non_linear,
            ))
            self.mm_graph_learners_1.append(MMGraphLearner(cfg.d_model, cfg.num_local_gr_learner_heads, random=cfg.use_random_graphs))
            self.mm_graph_learners_2.append(MMGraphLearner(cfg.d_model * 2, cfg.num_local_gr_learner_heads, random=cfg.use_random_graphs))

    @staticmethod
    def _build_local_gnn(cfg):
        """Instantiate the local GNN selected by ``cfg.gnn_type``; raise ValueError otherwise."""
        d = cfg.d_model
        kind = cfg.gnn_type
        if kind == 'gat':
            return GATModule(
                d, d, d,
                num_layers=cfg.num_local_gnn_layers,
                heads=cfg.num_local_gnn_heads,
                dropout=cfg.local_gnn_dropout,
                concat=cfg.local_gnn_concat,
                use_batch_norm=cfg.use_local_gnn_bn,
                use_non_linear=cfg.use_non_linear,
            )
        if kind == 'appnp':
            return Dense_APPNP_Net(d, d, d, dropout=cfg.local_gnn_dropout, K=cfg.gnn_K, alpha=cfg.gnn_alpha)
        if kind == 'gcn':
            return GCNModule(
                d, d, d,
                num_layers=cfg.num_local_gnn_layers,
                dropout=cfg.local_gnn_dropout,
                use_batch_norm=cfg.use_local_gnn_bn,
                use_non_linear=cfg.use_non_linear,
            )
        if kind == 'sage':
            return SAGEModule(
                d, d, d,
                num_layers=cfg.num_local_gnn_layers,
                dropout=cfg.local_gnn_dropout,
                use_batch_norm=cfg.use_local_gnn_bn,
                use_non_linear=cfg.use_non_linear,
            )
        raise ValueError

    def reset_parameters(self):
        """Re-initialise the GNN, MLP and both graph learners of every modality."""
        for i in range(self.config.num_modalities):
            for module_list in (self.mm_gnn_modules, self.mm_mlp_modules, self.mm_graph_learners_1, self.mm_graph_learners_2):
                module_list[i].reset_parameters()

    def forward(self, features):
        """Return ``(As, Zs)``: one blended adjacency and one embedding per modality."""
        cfg = self.config
        mm_Xs = features

        # Upper branch: learned graph + GNN.
        A_primes = [self.mm_graph_learners_1[i](mm_X) for i, mm_X in enumerate(mm_Xs)]
        Z_primes = [self.mm_gnn_modules[i](mm_X, A_p) for i, (mm_X, A_p) in enumerate(zip(mm_Xs, A_primes))]

        # Lower branch: graph-free MLP.
        Z_double_primes = [self.mm_mlp_modules[i](mm_X) for i, mm_X in enumerate(mm_Xs)]

        # Second learner sees both branches' embeddings side by side.
        Z_concats = [torch.cat([Z_1, Z_2], dim=-1) for Z_1, Z_2 in zip(Z_primes, Z_double_primes)]
        A_double_primes = [self.mm_graph_learners_2[i](Z_c) for i, Z_c in enumerate(Z_concats)]

        As = [(1 - cfg.adj_ratio) * A_p + cfg.adj_ratio * A_dp for A_p, A_dp in zip(A_primes, A_double_primes)]
        Zs = [0.5 * Z_1 + 0.5 * Z_2 for Z_1, Z_2 in zip(Z_primes, Z_double_primes)]
        return As, Zs
class QNetGlobal(nn.Module):
    """Global graph module (Q-net variant).

    One GNN is applied twice with shared weights:

    * on a learned graph blended with the given adjacency ``A``;
    * on ``A`` itself.

    A second learner builds another graph from both embeddings, also blended
    with ``A``. Graphs and embeddings are then blended into a single output.
    """

    def __init__(self, config):
        super(QNetGlobal, self).__init__()
        self.config = config
        self.gnn = self._build_global_gnn(config)
        heads = config.num_global_gr_learner_heads
        self.graph_learner_1 = GlobalGraphLearner(config.d_model, heads, config.use_random_graphs)
        self.graph_learner_2 = GlobalGraphLearner(config.d_model * 2, heads, config.use_random_graphs)

    @staticmethod
    def _build_global_gnn(cfg):
        """Instantiate the global GNN selected by ``cfg.gnn_type``; raise ValueError otherwise."""
        d = cfg.d_model
        kind = cfg.gnn_type
        if kind == 'gat':
            return GATModule(
                d, d, d,
                num_layers=cfg.num_global_gnn_layers,
                heads=cfg.num_global_gnn_heads,
                dropout=cfg.global_gnn_dropout,
                concat=cfg.global_gnn_concat,
                use_batch_norm=cfg.use_global_gnn_bn,
                use_non_linear=cfg.use_non_linear,
            )
        if kind == 'appnp':
            # NOTE(review): this branch reads `local_gnn_dropout`, whereas every other
            # global branch reads `global_gnn_dropout` — confirm this is intended.
            return Dense_APPNP_Net(d, d, d, dropout=cfg.local_gnn_dropout, K=cfg.gnn_K, alpha=cfg.gnn_alpha)
        if kind == 'gcn':
            return GCNModule(
                d, d, d,
                num_layers=cfg.num_global_gnn_layers,
                dropout=cfg.global_gnn_dropout,
                use_batch_norm=cfg.use_global_gnn_bn,
                use_non_linear=cfg.use_non_linear,
            )
        if kind == 'sage':
            return SAGEModule(
                d, d, d,
                num_layers=cfg.num_global_gnn_layers,
                dropout=cfg.global_gnn_dropout,
                use_batch_norm=cfg.use_global_gnn_bn,
                use_non_linear=cfg.use_non_linear,
            )
        raise ValueError

    def reset_parameters(self):
        """Re-initialise the GNN and both graph learners."""
        for module in (self.gnn, self.graph_learner_1, self.graph_learner_2):
            module.reset_parameters()

    def forward(self, Z, A):
        """Return ``(A_global, Z_global)`` for node features ``Z`` and adjacency ``A``."""
        cfg = self.config

        def blend(first, second, ratio):
            # (1 - ratio) * first + ratio * second
            return (1 - ratio) * first + ratio * second

        A_prime = blend(self.graph_learner_1(Z), A, cfg.init_adj_ratio)
        Z_prime = self.gnn(Z, A_prime)
        Z_double_prime = self.gnn(Z, A)

        learned_2 = self.graph_learner_2(torch.cat([Z_prime, Z_double_prime], dim=-1))
        A_double_prime = blend(learned_2, A, cfg.init_adj_ratio)

        A_global = blend(A_prime, A_double_prime, cfg.adj_ratio)
        Z_global = blend(Z_prime, Z_double_prime, 0.5)
        return A_global, Z_global
class PNetGlobal(nn.Module):
    """Global graph module (P-net variant).

    * upper branch: GNN on a graph learned from ``Z``;
    * lower branch: MLP on ``Z`` (no graph).

    A second learner builds another graph from both embeddings. Graphs and
    embeddings are then blended into a single output.
    """

    def __init__(self, config):
        super(PNetGlobal, self).__init__()
        self.config = config
        self.gnn = self._build_global_gnn(config)
        self.mlp = MLPModule(
            config.d_model, config.d_model, config.d_model,
            num_layers=config.num_global_fc_layers,
            dropout=config.global_fc_dropout,
            use_batch_norm=config.use_global_fc_bn,
            use_non_linear=config.use_non_linear,
        )
        heads = config.num_global_gr_learner_heads
        self.graph_learner_1 = GlobalGraphLearner(config.d_model, heads, random=config.use_random_graphs)
        self.graph_learner_2 = GlobalGraphLearner(config.d_model * 2, heads, random=config.use_random_graphs)

    @staticmethod
    def _build_global_gnn(cfg):
        """Instantiate the global GNN selected by ``cfg.gnn_type``; raise ValueError otherwise."""
        d = cfg.d_model
        kind = cfg.gnn_type
        if kind == 'gat':
            return GATModule(
                d, d, d,
                num_layers=cfg.num_global_gnn_layers,
                heads=cfg.num_global_gnn_heads,
                dropout=cfg.global_gnn_dropout,
                concat=cfg.global_gnn_concat,
                use_batch_norm=cfg.use_global_gnn_bn,
                use_non_linear=cfg.use_non_linear,
            )
        if kind == 'appnp':
            # NOTE(review): this branch reads `local_gnn_dropout`, whereas every other
            # global branch reads `global_gnn_dropout` — confirm this is intended.
            return Dense_APPNP_Net(d, d, d, dropout=cfg.local_gnn_dropout, K=cfg.gnn_K, alpha=cfg.gnn_alpha)
        if kind == 'gcn':
            return GCNModule(
                d, d, d,
                num_layers=cfg.num_global_gnn_layers,
                dropout=cfg.global_gnn_dropout,
                use_batch_norm=cfg.use_global_gnn_bn,
                use_non_linear=cfg.use_non_linear,
            )
        if kind == 'sage':
            return SAGEModule(
                d, d, d,
                num_layers=cfg.num_global_gnn_layers,
                dropout=cfg.global_gnn_dropout,
                use_batch_norm=cfg.use_global_gnn_bn,
                use_non_linear=cfg.use_non_linear,
            )
        raise ValueError

    def reset_parameters(self):
        """Re-initialise the GNN, the MLP and both graph learners."""
        for module in (self.gnn, self.mlp, self.graph_learner_1, self.graph_learner_2):
            module.reset_parameters()

    def forward(self, Z, A):
        """Return ``(A_global, Z_global)`` for node features ``Z``.

        ``A`` is accepted to match ``QNetGlobal.forward`` but is not used: both
        graphs are learned from ``Z`` alone.
        """
        A_prime = self.graph_learner_1(Z)
        Z_prime = self.gnn(Z, A_prime)
        Z_double_prime = self.mlp(Z)

        A_double_prime = self.graph_learner_2(torch.cat([Z_prime, Z_double_prime], dim=-1))

        ratio = self.config.adj_ratio
        A_global = (1 - ratio) * A_prime + ratio * A_double_prime
        Z_global = 0.5 * Z_prime + 0.5 * Z_double_prime
        return A_global, Z_global