Make code public

commit 8e03ef1c38

49 changed files with 545354 additions and 0 deletions
models/__init__.py (new file, 0 lines)
models/avsd_bart.py (new file, 1438 lines; diff suppressed because it is too large)
models/gnns.py (new file, 801 lines)

@@ -0,0 +1,801 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn.dense import DenseGATConv, DenseGCNConv, DenseSAGEConv
from torch.nn.parameter import Parameter
from typing import Optional, Tuple
from .utils import get_knn_graph
import torch_sparse

class BartAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `key_value_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k, v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.reshape(*proj_shape)
        value_states = value_states.reshape(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value


class MLPModule(nn.Module):
    def __init__(self, d_in, d_hidden, d_out, num_layers=3, dropout=0.3, use_non_linear=False, use_batch_norm=False):
        super(MLPModule, self).__init__()

        self.use_batch_norm = use_batch_norm

        self.fcs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()

        if num_layers == 1:
            self.fcs.append(nn.Linear(d_in, d_out))
        else:
            self.fcs.append(nn.Linear(d_in, d_hidden))
            self.batch_norms.append(nn.BatchNorm1d(d_hidden))
            for _ in range(num_layers - 2):
                self.fcs.append(nn.Linear(d_hidden, d_hidden))
                self.batch_norms.append(nn.BatchNorm1d(d_hidden))

            self.fcs.append(nn.Linear(d_hidden, d_out))

        self.act_fn = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.use_non_linear = use_non_linear

    def reset_parameters(self):
        for fc in self.fcs:
            fc.reset_parameters()
        for bn in self.batch_norms:
            bn.reset_parameters()

    def forward(self, X):
        for fc, bn in zip(self.fcs[:-1], self.batch_norms):
            X = fc(X)
            X = self.act_fn(X)
            if self.use_batch_norm:
                # BatchNorm1d expects (batch, channels, length) for 3D inputs
                if X.dim() > 2:
                    X = X.transpose(1, 2)
                X = bn(X)
                if X.dim() > 2:
                    X = X.transpose(1, 2)
            X = self.dropout(X)
        X = self.fcs[-1](X)
        return X


class GATModule(nn.Module):
    def __init__(self, d_in, d_hidden, d_out, num_layers=3, dropout=0.3, concat=True, heads=2, use_non_linear=False, use_batch_norm=False):
        super(GATModule, self).__init__()
        self.gnns = nn.ModuleList()
        if concat:
            d_hidden = d_hidden // heads
            d_out = d_out // heads

        self.gnns.append(DenseGATConv(d_in, d_hidden, heads=heads, concat=concat, dropout=dropout))

        self.batch_norms = nn.ModuleList()
        self.batch_norms.append(nn.BatchNorm1d(d_hidden * heads if concat else d_hidden))

        for _ in range(num_layers - 2):
            self.gnns.append(DenseGATConv(
                d_hidden * heads if concat else d_hidden, d_hidden,
                heads=heads, concat=concat, dropout=dropout
            ))
            self.batch_norms.append(nn.BatchNorm1d(d_hidden * heads if concat else d_hidden))

        self.gnns.append(DenseGATConv(
            d_hidden * heads if concat else d_hidden, d_out,
            heads=heads, concat=concat, dropout=dropout
        ))

        self.dropout = nn.Dropout(dropout)
        self.non_linear = nn.GELU()
        self.use_batch_norm = use_batch_norm
        self.use_non_linear = use_non_linear

    def reset_parameters(self):
        for gnn in self.gnns:
            gnn.reset_parameters()
        for batch_norm in self.batch_norms:
            batch_norm.reset_parameters()

    def forward(self, X, A):
        Z = self.dropout(X)
        for i in range(len(self.gnns) - 1):
            Z = self.gnns[i](Z, A)
            if self.use_batch_norm:
                Z = Z.transpose(1, 2)
                Z = self.batch_norms[i](Z)
                Z = Z.transpose(1, 2)
            if self.use_non_linear:
                Z = self.non_linear(Z)
            Z = self.dropout(Z)
        Z = self.gnns[-1](Z, A)
        return Z


class GCNModule(nn.Module):
    def __init__(self, d_in, d_hidden, d_out, num_layers=3, dropout=0.3, use_non_linear=False, use_batch_norm=False):
        super(GCNModule, self).__init__()
        self.gnns = nn.ModuleList()

        self.gnns.append(DenseGCNConv(d_in, d_hidden))

        self.batch_norms = nn.ModuleList()
        self.batch_norms.append(nn.BatchNorm1d(d_hidden))

        for _ in range(num_layers - 2):
            self.gnns.append(DenseGCNConv(d_hidden, d_hidden))
            self.batch_norms.append(nn.BatchNorm1d(d_hidden))

        self.gnns.append(DenseGCNConv(d_hidden, d_out))

        self.dropout = nn.Dropout(dropout)
        self.non_linear = nn.GELU()
        self.use_batch_norm = use_batch_norm
        self.use_non_linear = use_non_linear

    def reset_parameters(self):
        for gnn in self.gnns:
            gnn.reset_parameters()
        for batch_norm in self.batch_norms:
            batch_norm.reset_parameters()

    def forward(self, X, A):
        Z = self.dropout(X)
        for i in range(len(self.gnns) - 1):
            Z = self.gnns[i](Z, A)
            if self.use_batch_norm:
                Z = Z.transpose(1, 2)
                Z = self.batch_norms[i](Z)
                Z = Z.transpose(1, 2)
            if self.use_non_linear:
                Z = self.non_linear(Z)
            Z = self.dropout(Z)
        Z = self.gnns[-1](Z, A)
        return Z


class SAGEModule(nn.Module):
    def __init__(self, d_in, d_hidden, d_out, num_layers=3, dropout=0.3, use_non_linear=False, use_batch_norm=False):
        super(SAGEModule, self).__init__()
        self.gnns = nn.ModuleList()

        self.gnns.append(DenseSAGEConv(d_in, d_hidden))

        self.batch_norms = nn.ModuleList()
        self.batch_norms.append(nn.BatchNorm1d(d_hidden))

        for _ in range(num_layers - 2):
            self.gnns.append(DenseSAGEConv(d_hidden, d_hidden))
            self.batch_norms.append(nn.BatchNorm1d(d_hidden))

        self.gnns.append(DenseSAGEConv(d_hidden, d_out))

        self.dropout = nn.Dropout(dropout)
        self.non_linear = nn.GELU()
        self.use_batch_norm = use_batch_norm
        self.use_non_linear = use_non_linear

    def reset_parameters(self):
        for gnn in self.gnns:
            gnn.reset_parameters()
        for batch_norm in self.batch_norms:
            batch_norm.reset_parameters()

    def forward(self, X, A):
        Z = self.dropout(X)
        for i in range(len(self.gnns) - 1):
            Z = self.gnns[i](Z, A)
            if self.use_batch_norm:
                Z = Z.transpose(1, 2)
                Z = self.batch_norms[i](Z)
                Z = Z.transpose(1, 2)
            if self.use_non_linear:
                Z = self.non_linear(Z)
            Z = self.dropout(Z)
        Z = self.gnns[-1](Z, A)
        return Z


class GlobalGraphLearner(nn.Module):
    def __init__(self, d_in, num_heads, random=False):
        super(GlobalGraphLearner, self).__init__()
        self.random = random
        if not self.random:
            w = torch.Tensor(num_heads, d_in)
            self.w = Parameter(nn.init.xavier_uniform_(w), requires_grad=True)

    def reset_parameters(self):
        if not self.random:
            self.w = Parameter(nn.init.xavier_uniform_(self.w))

    def forward(self, Z):
        if self.random:
            att_global = torch.randn((Z.size(0), Z.size(1), Z.size(1))).to(Z.device)
        else:
            w_expanded = self.w.unsqueeze(1).unsqueeze(1)
            Z = Z.unsqueeze(0) * w_expanded
            Z = F.normalize(Z, p=2, dim=-1)
            att_global = torch.matmul(Z, Z.transpose(-1, -2)).mean(0)
        # keep only non-negative similarities (ReLU-style sparsification)
        mask_global = (att_global > 0).detach().float()
        att_global = att_global * mask_global

        return att_global


class DenseAPPNP(nn.Module):
    def __init__(self, K, alpha):
        super().__init__()
        self.K = K
        self.alpha = alpha

    def forward(self, x, adj_t):
        h = x
        for _ in range(self.K):
            if adj_t.is_sparse:
                x = torch_sparse.spmm(adj_t, x)
            else:
                x = torch.matmul(adj_t, x)
            x = x * (1 - self.alpha)
            x += self.alpha * h
            x /= self.K
        return x

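# Orientation note (added comment, not in the original commit): each loop
# iteration above blends one dense propagation step adj_t @ x with the input
# features h via the teleport weight alpha, a dense variant of the
# personalized-PageRank propagation used by APPNP (Klicpera et al., 2019);
# e.g. DenseAPPNP(K=10, alpha=0.1)(torch.randn(2, 16, 64), torch.rand(2, 16, 16)).

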
class Dense_APPNP_Net(nn.Module):
    def __init__(self, d_in, d_hidden, d_out, dropout=.5, K=10, alpha=.1):
        super(Dense_APPNP_Net, self).__init__()
        self.lin1 = nn.Linear(d_in, d_hidden)
        self.lin2 = nn.Linear(d_hidden, d_out)
        self.prop1 = DenseAPPNP(K, alpha)
        self.dropout = dropout

    def reset_parameters(self):
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, x, adj_t):
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lin2(x)
        x = self.prop1(x, adj_t)
        return x


class MMGraphLearner(nn.Module):
    def __init__(self, d_in, num_heads, random=False):
        super(MMGraphLearner, self).__init__()
        self.random = random
        if not self.random:
            w = torch.Tensor(num_heads, d_in)
            self.w = Parameter(nn.init.xavier_uniform_(w), requires_grad=True)

        self.fc = nn.Linear(d_in, d_in)

    def reset_parameters(self):
        if not self.random:
            self.fc.reset_parameters()
            self.w = Parameter(nn.init.xavier_uniform_(self.w), requires_grad=True)

    def forward(self, features):
        if self.random:
            att = torch.randn((features.size(0), features.size(1), features.size(1))).to(features.device)
        else:
            features = self.fc(features)
            w_expanded = self.w.unsqueeze(1).unsqueeze(1)
            features = features.unsqueeze(0) * w_expanded
            features = F.normalize(features, p=2, dim=-1)
            att = torch.matmul(features, features.transpose(-1, -2)).mean(0)
        # keep only non-negative similarities (ReLU-style sparsification)
        mask = (att > 0).detach().float()
        att = att * mask
        return att


class QNetLocal(nn.Module):
    def __init__(self, config):
        super(QNetLocal, self).__init__()
        self.config = config

        self.mm_gnn_modules = nn.ModuleList()
        self.mm_graph_learners_1 = nn.ModuleList()
        self.mm_graph_learners_2 = nn.ModuleList()

        for _ in range(self.config.num_modalities):
            if self.config.gnn_type == 'gat':
                self.mm_gnn_modules.append(GATModule(
                    self.config.d_model, self.config.d_model, self.config.d_model,
                    num_layers=self.config.num_local_gnn_layers,
                    heads=self.config.num_local_gnn_heads,
                    dropout=self.config.local_gnn_dropout,
                    concat=self.config.local_gnn_concat,
                    use_batch_norm=self.config.use_local_gnn_bn,
                    use_non_linear=self.config.use_non_linear
                ))
            elif self.config.gnn_type == 'appnp':
                self.mm_gnn_modules.append(Dense_APPNP_Net(
                    self.config.d_model, self.config.d_model, self.config.d_model,
                    dropout=self.config.local_gnn_dropout,
                    K=self.config.gnn_K, alpha=self.config.gnn_alpha
                ))
            elif self.config.gnn_type == 'gcn':
                self.mm_gnn_modules.append(GCNModule(
                    self.config.d_model, self.config.d_model, self.config.d_model,
                    num_layers=self.config.num_local_gnn_layers,
                    dropout=self.config.local_gnn_dropout,
                    use_batch_norm=self.config.use_local_gnn_bn,
                    use_non_linear=self.config.use_non_linear
                ))
            elif self.config.gnn_type == 'sage':
                self.mm_gnn_modules.append(SAGEModule(
                    self.config.d_model, self.config.d_model, self.config.d_model,
                    num_layers=self.config.num_local_gnn_layers,
                    dropout=self.config.local_gnn_dropout,
                    use_batch_norm=self.config.use_local_gnn_bn,
                    use_non_linear=self.config.use_non_linear
                ))
            else:
                raise ValueError(f'Unknown gnn_type: {self.config.gnn_type}')
            self.mm_graph_learners_1.append(MMGraphLearner(self.config.d_model, self.config.num_local_gr_learner_heads, random=self.config.use_random_graphs))
            self.mm_graph_learners_2.append(MMGraphLearner(self.config.d_model * 2, self.config.num_local_gr_learner_heads, random=self.config.use_random_graphs))

    def reset_parameters(self):
        for i in range(self.config.num_modalities):
            self.mm_gnn_modules[i].reset_parameters()
            self.mm_graph_learners_1[i].reset_parameters()
            self.mm_graph_learners_2[i].reset_parameters()

    def forward(self, features, A_tildes=None):
        mm_Xs = features
        device = features[0].device

        if A_tildes is None:
            A_tildes = []
            for mm_X in mm_Xs:
                A_tildes.append(get_knn_graph(mm_X, self.config.num_nn, device))

        ################# Multi-modal graph learner (upper branch) #################
        A_primes = []
        for i, mm_X in enumerate(mm_Xs):  # iterate over the modalities
            A_primes.append(self.mm_graph_learners_1[i](mm_X))

        # Linear combination of A_primes with A_tildes
        A_primes = [(1 - self.config.init_adj_ratio) * A_prime + self.config.init_adj_ratio * A_tilde for A_prime, A_tilde in zip(A_primes, A_tildes)]

        ################# Multi-modal gnn (upper branch) #################
        Z_primes = []
        for i, (mm_X, A_prime) in enumerate(zip(mm_Xs, A_primes)):
            Z_primes.append(self.mm_gnn_modules[i](mm_X, A_prime))

        ################# Multi-modal gnn (lower branch) #################
        Z_double_primes = []
        for i, (mm_X, A_tilde) in enumerate(zip(mm_Xs, A_tildes)):
            Z_double_primes.append(self.mm_gnn_modules[i](mm_X, A_tilde))

        Z_concats = [torch.cat([Z_1, Z_2], dim=-1) for Z_1, Z_2 in zip(Z_primes, Z_double_primes)]

        ################# Multi-modal graph learner (lower branch) #################
        A_double_primes = []
        for i, Z_concat in enumerate(Z_concats):
            A_double_primes.append(self.mm_graph_learners_2[i](Z_concat))

        A_double_primes = [(1 - self.config.init_adj_ratio) * A_double_prime + self.config.init_adj_ratio * A_tilde for A_double_prime, A_tilde in zip(A_double_primes, A_tildes)]

        As = [(1 - self.config.adj_ratio) * A_prime + self.config.adj_ratio * A_double_prime for A_prime, A_double_prime in zip(A_primes, A_double_primes)]

        ################## Average across branches ##################
        Zs = [0.5 * Z1 + 0.5 * Z2 for Z1, Z2 in zip(Z_primes, Z_double_primes)]
        return As, Zs


class PNetLocal(nn.Module):
    def __init__(self, config):
        super(PNetLocal, self).__init__()
        self.config = config
        self.mm_gnn_modules = nn.ModuleList()
        self.mm_mlp_modules = nn.ModuleList()

        self.mm_graph_learners_1 = nn.ModuleList()
        self.mm_graph_learners_2 = nn.ModuleList()

        for _ in range(self.config.num_modalities):
            if self.config.gnn_type == 'gat':
                self.mm_gnn_modules.append(GATModule(
                    self.config.d_model, self.config.d_model, self.config.d_model,
                    num_layers=self.config.num_local_gnn_layers,
                    heads=self.config.num_local_gnn_heads,
                    dropout=self.config.local_gnn_dropout,
                    concat=self.config.local_gnn_concat,
                    use_batch_norm=self.config.use_local_gnn_bn,
                    use_non_linear=self.config.use_non_linear
                ))
            elif self.config.gnn_type == 'appnp':
                self.mm_gnn_modules.append(Dense_APPNP_Net(
                    self.config.d_model, self.config.d_model, self.config.d_model,
                    dropout=self.config.local_gnn_dropout,
                    K=self.config.gnn_K, alpha=self.config.gnn_alpha
                ))
            elif self.config.gnn_type == 'gcn':
                self.mm_gnn_modules.append(GCNModule(
                    self.config.d_model, self.config.d_model, self.config.d_model,
                    num_layers=self.config.num_local_gnn_layers,
                    dropout=self.config.local_gnn_dropout,
                    use_batch_norm=self.config.use_local_gnn_bn,
                    use_non_linear=self.config.use_non_linear
                ))
            elif self.config.gnn_type == 'sage':
                self.mm_gnn_modules.append(SAGEModule(
                    self.config.d_model, self.config.d_model, self.config.d_model,
                    num_layers=self.config.num_local_gnn_layers,
                    dropout=self.config.local_gnn_dropout,
                    use_batch_norm=self.config.use_local_gnn_bn,
                    use_non_linear=self.config.use_non_linear
                ))
            else:
                raise ValueError(f'Unknown gnn_type: {self.config.gnn_type}')
            self.mm_mlp_modules.append(MLPModule(
                self.config.d_model, self.config.d_model, self.config.d_model,
                num_layers=self.config.num_local_fc_layers,
                dropout=self.config.local_fc_dropout,
                use_batch_norm=self.config.use_local_fc_bn,
                use_non_linear=self.config.use_non_linear
            ))

            self.mm_graph_learners_1.append(MMGraphLearner(self.config.d_model, self.config.num_local_gr_learner_heads, random=self.config.use_random_graphs))
            self.mm_graph_learners_2.append(MMGraphLearner(self.config.d_model * 2, self.config.num_local_gr_learner_heads, random=self.config.use_random_graphs))

    def reset_parameters(self):
        for i in range(self.config.num_modalities):
            self.mm_gnn_modules[i].reset_parameters()
            self.mm_mlp_modules[i].reset_parameters()
            self.mm_graph_learners_1[i].reset_parameters()
            self.mm_graph_learners_2[i].reset_parameters()

    def forward(self, features):
        mm_Xs = features

        ################# Multi-modal graph learner (upper branch) #################
        A_primes = []
        for i, mm_X in enumerate(mm_Xs):  # iterate over the modalities
            A_primes.append(self.mm_graph_learners_1[i](mm_X))

        ################# Multi-modal gnn (upper branch) #################
        Z_primes = []
        for i, (mm_X, A_prime) in enumerate(zip(mm_Xs, A_primes)):
            Z_primes.append(self.mm_gnn_modules[i](mm_X, A_prime))

        ################# Multi-modal mlp (lower branch) #################
        Z_double_primes = []
        for i, mm_X in enumerate(mm_Xs):
            Z_double_primes.append(self.mm_mlp_modules[i](mm_X))

        Z_concats = [torch.cat([Z_1, Z_2], dim=-1) for Z_1, Z_2 in zip(Z_primes, Z_double_primes)]

        ################# Multi-modal graph learner (lower branch) #################
        A_double_primes = []
        for i, Z_concat in enumerate(Z_concats):
            A_double_primes.append(self.mm_graph_learners_2[i](Z_concat))

        As = [(1 - self.config.adj_ratio) * A_prime + self.config.adj_ratio * A_double_prime for A_prime, A_double_prime in zip(A_primes, A_double_primes)]

        ################## Average across branches ##################
        Zs = [0.5 * Z1 + 0.5 * Z2 for Z1, Z2 in zip(Z_primes, Z_double_primes)]

        return As, Zs


class QNetGlobal(nn.Module):
    def __init__(self, config):
        super(QNetGlobal, self).__init__()
        self.config = config
        if self.config.gnn_type == 'gat':
            self.gnn = GATModule(
                self.config.d_model, self.config.d_model, self.config.d_model,
                num_layers=self.config.num_global_gnn_layers,
                heads=self.config.num_global_gnn_heads,
                dropout=self.config.global_gnn_dropout,
                concat=self.config.global_gnn_concat,
                use_batch_norm=self.config.use_global_gnn_bn,
                use_non_linear=self.config.use_non_linear
            )
        elif self.config.gnn_type == 'appnp':
            self.gnn = Dense_APPNP_Net(
                self.config.d_model, self.config.d_model, self.config.d_model,
                dropout=self.config.local_gnn_dropout,
                K=self.config.gnn_K, alpha=self.config.gnn_alpha
            )
        elif self.config.gnn_type == 'gcn':
            self.gnn = GCNModule(
                self.config.d_model, self.config.d_model, self.config.d_model,
                num_layers=self.config.num_global_gnn_layers,
                dropout=self.config.global_gnn_dropout,
                use_batch_norm=self.config.use_global_gnn_bn,
                use_non_linear=self.config.use_non_linear
            )
        elif self.config.gnn_type == 'sage':
            self.gnn = SAGEModule(
                self.config.d_model, self.config.d_model, self.config.d_model,
                num_layers=self.config.num_global_gnn_layers,
                dropout=self.config.global_gnn_dropout,
                use_batch_norm=self.config.use_global_gnn_bn,
                use_non_linear=self.config.use_non_linear
            )
        else:
            raise ValueError(f'Unknown gnn_type: {self.config.gnn_type}')

        self.graph_learner_1 = GlobalGraphLearner(self.config.d_model, self.config.num_global_gr_learner_heads, random=self.config.use_random_graphs)
        self.graph_learner_2 = GlobalGraphLearner(self.config.d_model * 2, self.config.num_global_gr_learner_heads, random=self.config.use_random_graphs)

    def reset_parameters(self):
        self.gnn.reset_parameters()
        self.graph_learner_1.reset_parameters()
        self.graph_learner_2.reset_parameters()

    def forward(self, Z, A):
        ################# Graph learner (upper branch) #################
        A_prime = self.graph_learner_1(Z)
        A_prime = (1 - self.config.init_adj_ratio) * A_prime + self.config.init_adj_ratio * A

        ################# Gnn (upper branch) #################
        Z_prime = self.gnn(Z, A_prime)

        ################# Gnn (lower branch) #################
        Z_double_prime = self.gnn(Z, A)
        Z_concat = torch.cat([Z_prime, Z_double_prime], dim=-1)

        ################# Graph learner (lower branch) #################
        A_double_prime = self.graph_learner_2(Z_concat)
        A_double_prime = (1 - self.config.init_adj_ratio) * A_double_prime + self.config.init_adj_ratio * A

        ################## Average across branches ##################
        A_global = (1 - self.config.adj_ratio) * A_prime + self.config.adj_ratio * A_double_prime
        Z_global = 0.5 * Z_prime + 0.5 * Z_double_prime
        return A_global, Z_global


class PNetGlobal(nn.Module):
    def __init__(self, config):
        super(PNetGlobal, self).__init__()
        self.config = config

        if self.config.gnn_type == 'gat':
            self.gnn = GATModule(
                self.config.d_model, self.config.d_model, self.config.d_model,
                num_layers=self.config.num_global_gnn_layers,
                heads=self.config.num_global_gnn_heads,
                dropout=self.config.global_gnn_dropout,
                concat=self.config.global_gnn_concat,
                use_batch_norm=self.config.use_global_gnn_bn,
                use_non_linear=self.config.use_non_linear
            )
        elif self.config.gnn_type == 'appnp':
            self.gnn = Dense_APPNP_Net(
                self.config.d_model, self.config.d_model, self.config.d_model,
                dropout=self.config.local_gnn_dropout,
                K=self.config.gnn_K, alpha=self.config.gnn_alpha
            )
        elif self.config.gnn_type == 'gcn':
            self.gnn = GCNModule(
                self.config.d_model, self.config.d_model, self.config.d_model,
                num_layers=self.config.num_global_gnn_layers,
                dropout=self.config.global_gnn_dropout,
                use_batch_norm=self.config.use_global_gnn_bn,
                use_non_linear=self.config.use_non_linear
            )
        elif self.config.gnn_type == 'sage':
            self.gnn = SAGEModule(
                self.config.d_model, self.config.d_model, self.config.d_model,
                num_layers=self.config.num_global_gnn_layers,
                dropout=self.config.global_gnn_dropout,
                use_batch_norm=self.config.use_global_gnn_bn,
                use_non_linear=self.config.use_non_linear
            )
        else:
            raise ValueError(f'Unknown gnn_type: {self.config.gnn_type}')

        self.mlp = MLPModule(
            self.config.d_model, self.config.d_model, self.config.d_model,
            num_layers=self.config.num_global_fc_layers,
            dropout=self.config.global_fc_dropout,
            use_batch_norm=self.config.use_global_fc_bn,
            use_non_linear=self.config.use_non_linear
        )

        self.graph_learner_1 = GlobalGraphLearner(self.config.d_model, self.config.num_global_gr_learner_heads, random=self.config.use_random_graphs)
        self.graph_learner_2 = GlobalGraphLearner(self.config.d_model * 2, self.config.num_global_gr_learner_heads, random=self.config.use_random_graphs)

    def reset_parameters(self):
        self.gnn.reset_parameters()
        self.mlp.reset_parameters()
        self.graph_learner_1.reset_parameters()
        self.graph_learner_2.reset_parameters()

    def forward(self, Z, A):
        ################# Graph learner (upper branch) #################
        A_prime = self.graph_learner_1(Z)

        ################# Gnn (upper branch) #################
        Z_prime = self.gnn(Z, A_prime)

        ################# mlp (lower branch) #################
        Z_double_prime = self.mlp(Z)
        Z_concat = torch.cat([Z_prime, Z_double_prime], dim=-1)

        ################# Graph learner (lower branch) #################
        A_double_prime = self.graph_learner_2(Z_concat)
        # A_double_prime = (1 - self.config.init_adj_ratio) * A_double_prime + self.config.init_adj_ratio * A

        ################## Average across branches ##################
        A_global = (1 - self.config.adj_ratio) * A_prime + self.config.adj_ratio * A_double_prime
        Z_global = 0.5 * Z_prime + 0.5 * Z_double_prime
        return A_global, Z_global
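The modules above are consumed by the BART encoders in models/avsd_bart.py and models/nextqa_bart.py. As a quick orientation, the following smoke-test sketch (not part of the commit; every shape and hyperparameter here is an assumption) shows how a dense GNN block and a graph learner fit together:

    import torch
    from models.gnns import GCNModule, GlobalGraphLearner

    X = torch.randn(2, 16, 128)                     # (batch, num_nodes, d_model)
    A = torch.eye(16).unsqueeze(0).repeat(2, 1, 1)  # initial dense adjacency per sample

    gnn = GCNModule(128, 128, 128, num_layers=3, dropout=0.1)
    Z = gnn(X, A)                                   # -> (2, 16, 128)

    learner = GlobalGraphLearner(128, num_heads=4)
    A_learned = learner(Z)                          # -> (2, 16, 16); negative scores masked to 0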
models/nextqa_bart.py (new file, 1397 lines; diff suppressed because it is too large)
models/utils.py (new file, 249 lines)

@@ -0,0 +1,249 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.utils import ModelOutput
from typing import Optional, Tuple

class ELBO(nn.Module):
    """Computes E_Q[log Q] - E_Q[log P] (an entry-wise KL(Q || P) term) between
    binary distributions induced by the Q-net and P-net adjacency scores."""

    def __init__(self):
        super(ELBO, self).__init__()

    def forward(self, QA, PA):
        QA_flattened = QA.view(-1).unsqueeze(-1)
        PA_flattened = PA.view(-1).unsqueeze(-1)

        # pair every score with a zero logit, turning it into a 2-class distribution
        QA_flattened = torch.cat([torch.zeros_like(QA_flattened), QA_flattened], dim=-1)
        PA_flattened = torch.cat([torch.zeros_like(PA_flattened), PA_flattened], dim=-1)

        log_QA = F.log_softmax(QA_flattened, dim=1)
        log_PA = F.log_softmax(PA_flattened, dim=1)

        QA_dist = torch.exp(log_QA)

        loss_QA = torch.mean(log_QA * QA_dist)
        loss_PA = torch.mean(log_PA * QA_dist)

        loss = loss_QA - loss_PA

        return loss


def seperate_nextqa_input_modalities(
        features, i3d_rgb_interval, i3d_flow_interval, question_intervals,
        vis_state_vector_idx, question_state_vector_idx,
        attention_values=None):
    """Separates the multimodal NExT-QA encoder hidden states into per-modality
    tensors. The state token embeddings are left out (hence the +1 while indexing).

    Args:
        features: encoder hidden states of shape (batch, seq_len, d_model).
        i3d_rgb_interval: (start, end) positions of the I3D RGB tokens.
        i3d_flow_interval: (start, end) positions of the I3D flow tokens.
        question_intervals: per-sample (start, end) positions of the question tokens.
        vis_state_vector_idx: positions of the visual state tokens.
        question_state_vector_idx: per-sample positions of the question state tokens.
        attention_values: optional encoder self-attention maps used to score the tokens.

    Returns:
        A list of per-modality hidden states and a list of per-modality attention scores.
    """
    features_copy = features.clone()  # .detach()
    i3d_rgb_hidden = features_copy[:, i3d_rgb_interval[0]+1:i3d_rgb_interval[1], :]
    i3d_flow_hidden = features_copy[:, i3d_flow_interval[0]+1:i3d_flow_interval[1], :]

    question_hidden = []
    features_split = torch.split(features_copy, 1, dim=0)
    for ques_inter, feat in zip(question_intervals, features_split):
        ques_idx = torch.arange(ques_inter[0]+1, ques_inter[1]).unsqueeze(0).unsqueeze(-1).repeat(1, 1, feat.size(-1)).to(feat.device)
        question_hidden.append(torch.gather(feat, 1, ques_idx))

    if attention_values is None:
        i3d_rgb_att = None
        i3d_flow_att = None
        question_att = None
    else:
        attention_values = attention_values.mean(1)
        i3d_rgb_att = attention_values[:, vis_state_vector_idx[0], vis_state_vector_idx[0]+1:vis_state_vector_idx[1]]
        i3d_flow_att = attention_values[:, vis_state_vector_idx[1], vis_state_vector_idx[1]+1:question_state_vector_idx[0]]
        question_att = [attention_values[i, question_state_vector_idx[i], question_intervals[i][0] + 1: question_intervals[i][1]] for i in range(len(question_state_vector_idx))]

    features_list = [i3d_rgb_hidden, i3d_flow_hidden, question_hidden]
    att = [i3d_rgb_att, i3d_flow_att, question_att]

    return features_list, att


def seperate_input_modalities(
        features, i3d_rgb_interval, i3d_flow_interval, sam_interval, audio_interval, history_intervals, question_intervals,
        vis_state_vector_idx, history_state_vector_idx, question_state_vector_idx,
        attention_values=None):
    """Separates the multimodal AVSD encoder hidden states into per-modality
    tensors. The state token embeddings are left out (hence the +1 while indexing).

    Args:
        features: encoder hidden states of shape (batch, seq_len, d_model).
        i3d_rgb_interval: (start, end) positions of the I3D RGB tokens.
        i3d_flow_interval: (start, end) positions of the I3D flow tokens.
        sam_interval: (start, end) positions of the SAM feature tokens.
        audio_interval: (start, end) positions of the audio tokens.
        history_intervals: per-sample (start, end) positions of the dialog history tokens.
        question_intervals: per-sample (start, end) positions of the question tokens.
        vis_state_vector_idx: positions of the visual and audio state tokens.
        history_state_vector_idx: per-sample positions of the history state tokens.
        question_state_vector_idx: per-sample positions of the question state tokens.
        attention_values: optional encoder self-attention maps used to score the tokens.

    Returns:
        A list of per-modality hidden states and a list of per-modality attention scores.
    """
    features_copy = features.clone()  # .detach()
    i3d_rgb_hidden = features_copy[:, i3d_rgb_interval[0]+1:i3d_rgb_interval[1], :]
    i3d_flow_hidden = features_copy[:, i3d_flow_interval[0]+1:i3d_flow_interval[1], :]
    sam_hidden = features_copy[:, sam_interval[0]+1:sam_interval[1], :]
    audio_hidden = features_copy[:, audio_interval[0]+1:audio_interval[1], :]

    history_hidden = []
    question_hidden = []
    features_split = torch.split(features_copy, 1, dim=0)
    for hist_inter, ques_inter, feat in zip(history_intervals, question_intervals, features_split):
        hist_idx = torch.arange(hist_inter[0]+1, hist_inter[1]).unsqueeze(0).unsqueeze(-1).repeat(1, 1, feat.size(-1)).to(feat.device)
        history_hidden.append(torch.gather(feat, 1, hist_idx))

        ques_idx = torch.arange(ques_inter[0]+1, ques_inter[1]).unsqueeze(0).unsqueeze(-1).repeat(1, 1, feat.size(-1)).to(feat.device)
        question_hidden.append(torch.gather(feat, 1, ques_idx))

    if attention_values is None:
        i3d_rgb_att = None
        i3d_flow_att = None
        sam_att = None
        audio_att = None
        history_att = None
        question_att = None
    else:
        attention_values = attention_values.mean(1)
        i3d_rgb_att = attention_values[:, vis_state_vector_idx[0], vis_state_vector_idx[0]+1:vis_state_vector_idx[1]]
        i3d_flow_att = attention_values[:, vis_state_vector_idx[1], vis_state_vector_idx[1]+1:vis_state_vector_idx[2]]
        sam_att = attention_values[:, vis_state_vector_idx[2], vis_state_vector_idx[2]+1:vis_state_vector_idx[3]]
        audio_att = attention_values[:, vis_state_vector_idx[3], vis_state_vector_idx[3]+1:history_state_vector_idx[0] - 1]
        history_att = [attention_values[i, history_state_vector_idx[i], history_intervals[i][0] + 1: history_intervals[i][1]] for i in range(len(history_state_vector_idx))]
        question_att = [attention_values[i, question_state_vector_idx[i], question_intervals[i][0] + 1: question_intervals[i][1]] for i in range(len(question_state_vector_idx))]

    features_list = [i3d_rgb_hidden, i3d_flow_hidden, sam_hidden, audio_hidden, history_hidden, question_hidden]
    att = [i3d_rgb_att, i3d_flow_att, sam_att, audio_att, history_att, question_att]

    return features_list, att


def get_knn_graph(features, num_nn, device):
    """Builds a dense k-nearest-neighbour adjacency matrix from pairwise cosine
    similarities of the node features."""
    features = features.permute((1, 2, 0))
    cosine_sim_pairwise = F.cosine_similarity(features, features.unsqueeze(1), dim=-2)
    cosine_sim_pairwise = cosine_sim_pairwise.permute((2, 0, 1))
    num_nn = min(num_nn, cosine_sim_pairwise.size(-1))
    adj_mat = torch.zeros_like(cosine_sim_pairwise).to(device)
    _, to_keep = torch.topk(cosine_sim_pairwise, num_nn, dim=-1, sorted=False)
    adj_mat = adj_mat.scatter(-1, to_keep, torch.ones_like(adj_mat).to(device))
    return adj_mat

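# Illustrative usage (added comment, not in the original commit; shapes assumed):
#   X = torch.randn(2, 16, 768)                             # (batch, tokens, d_model)
#   A_tilde = get_knn_graph(X, num_nn=5, device=X.device)   # (2, 16, 16) 0/1 adjacency
# keeps the num_nn most cosine-similar neighbours of every token (self included).

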
def track_features_vis(features, att, top_k, device, node_idx=None):
    """Selects the top-k constituents (nodes) of the i3d, audio, and sam input
    modalities, either by their attention score w.r.t. the modality's state token
    or randomly when no attention values are given (used to build A_tilde in the
    paper).
    """
    features = features.clone().detach()
    top_k = min(features.size(1), top_k)
    if att is None:
        node_idx = torch.randint(low=0, high=features.size(1), size=(features.size(0), top_k))
    else:
        _, node_idx = torch.topk(att, top_k, dim=-1, sorted=False)

    node_idx = node_idx.unsqueeze(-1).repeat(1, 1, features.size(-1)).to(device)

    selected_features = torch.gather(features, 1, node_idx)

    return selected_features, node_idx


def track_features_text(features, att, top_k, device, node_idx=None):
    """Selects the top-k constituents (tokens) of the history and question
    inputs, either by their attention score w.r.t. the corresponding state token
    or randomly when no attention values are given (used to build A_tilde in the
    paper).
    """
    hidden_dim = features[0].size(-1)
    min_len = min([feat.size(1) for feat in features])
    top_k = min(min_len, top_k)
    if att is None:
        node_idx = [torch.randint(low=0, high=feat.size(1), size=(feat.size(0), top_k)) for feat in features]
    else:
        node_idx = [torch.topk(a, top_k, dim=-1, sorted=False)[-1] for a in att]

    node_idx = [idx.unsqueeze(-1).repeat(1, 1, hidden_dim).to(device) for idx in node_idx]

    selected_features = [torch.gather(feat, 1, idx) for feat, idx in zip(features, node_idx)]
    selected_features = torch.cat(selected_features, dim=0)

    return selected_features, node_idx


def diag_tensor(tensors):
    """Stacks a list of per-modality square matrices into one block-diagonal
    matrix and returns the block boundary indices."""
    device = tensors[0].device
    n = sum([t.size(-1) for t in tensors])
    bsz = tensors[0].size(0)
    block_diag = torch.zeros((bsz, n, n)).float().to(device)
    delimiter = 0
    delimiters = [0]
    for t in tensors:
        block_diag[:, delimiter:delimiter + t.size(-1), delimiter:delimiter + t.size(-1)] = t
        delimiter += t.size(-1)
        delimiters.append(delimiter)

    return block_diag, delimiters


def embed_graphs(features, delimiters):
    state_vectors = []
    for i in range(len(delimiters) - 1):
        state_vectors.append(features[:, delimiters[i]:delimiters[i+1], :].mean(dim=1))
    return state_vectors


class AVSDEncoderOutput(ModelOutput):
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    QAs_local = None
    PAs_local = None
    QA_global = None
    PA_global = None
    state_vectors = None


class AVSDSeq2SeqModelOutput(ModelOutput):
    last_hidden_state: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    QAs_local = None
    PAs_local = None
    QA_global = None
    PA_global = None
    state_vectors = None


class AVSDSeq2SeqLMOutput(ModelOutput):
    gen_loss: Optional[torch.FloatTensor] = None
    elbo_loss_global: Optional[torch.FloatTensor] = None
    elbo_loss_local: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_QAs_local = None
    encoder_PAs_local = None
    encoder_QA_global = None
    encoder_PA_global = None
    encoder_state_vectors = None
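For completeness, a sketch of the loss-side helpers above (again illustrative, not part of the commit; the shapes are assumptions and the real call sites are in the model files):

    import torch
    from models.utils import ELBO, diag_tensor

    A_q = torch.rand(2, 16, 16)              # adjacency scores from the Q (posterior) branch
    A_p = torch.rand(2, 16, 16)              # adjacency scores from the P (prior) branch
    loss = ELBO()(A_q, A_p)                  # scalar: E_Q[log Q] - E_Q[log P]

    blocks = [torch.rand(2, 8, 8), torch.rand(2, 16, 16)]
    A_global, delims = diag_tensor(blocks)   # (2, 24, 24) block-diagonal; delims == [0, 8, 24]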