import math
from dataclasses import dataclass
from numbers import Number
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from .blocks import *


class graph_convolution(nn.Module):
    def __init__(self, in_features, out_features, node_n=3, seq_len=80, bias=True):
        super(graph_convolution, self).__init__()

        # learnable adjacency over time steps, feature projection weights,
        # and learnable adjacency over graph nodes
        self.temporal_graph_weights = Parameter(torch.FloatTensor(seq_len, seq_len))
        self.feature_weights = Parameter(torch.FloatTensor(in_features, out_features))
        self.spatial_graph_weights = Parameter(torch.FloatTensor(node_n, node_n))

        if bias:
            self.bias = Parameter(torch.FloatTensor(seq_len))
        else:
            # register a None bias so reset_parameters() and forward()
            # can safely check `self.bias is not None` when bias=False
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.spatial_graph_weights.size(1))
        self.feature_weights.data.uniform_(-stdv, stdv)
        self.temporal_graph_weights.data.uniform_(-stdv, stdv)
        self.spatial_graph_weights.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, x):
        # x: (batch, in_features, node_n, seq_len)
        # mix information across time steps with the learned temporal adjacency
        y = torch.matmul(x, self.temporal_graph_weights)
        # project features: (batch, seq_len, node_n, in_features) @ (in_features, out_features)
        y = torch.matmul(y.permute(0, 3, 2, 1), self.feature_weights)
        # mix information across nodes, then restore (batch, out_features, node_n, seq_len)
        y = torch.matmul(self.spatial_graph_weights, y).permute(0, 3, 2, 1).contiguous()

        if self.bias is not None:
            return (y + self.bias)
        else:
            return y
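
# Usage sketch (illustrative only, not part of the original module): the layer
# maps (batch, in_features, node_n, seq_len) to (batch, out_features, node_n,
# seq_len); the concrete sizes below are assumptions for demonstration.
#
#   gc = graph_convolution(in_features=16, out_features=32, node_n=3, seq_len=80)
#   x = torch.randn(4, 16, 3, 80)   # (batch, features, nodes, time)
#   y = gc(x)                        # -> torch.Size([4, 32, 3, 80])

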
@dataclass
class residual_graph_convolution_config:
    in_features: int
    seq_len: int
    emb_channels: int
    dropout: float
    out_features: Optional[int] = None
    node_n: int = 3
    # condition the block with time (and the encoder's output)
    use_condition: bool = True
    # whether to condition with both time & the encoder's output
    two_cond: bool = False
    # number of encoder output channels
    cond_emb_channels: Optional[int] = None
    has_lateral: bool = False
    graph_convolution_bias: bool = True
    scale_bias: float = 1

    def __post_init__(self):
        self.out_features = self.out_features or self.in_features
        self.cond_emb_channels = self.cond_emb_channels or self.emb_channels

    def make_model(self):
        return residual_graph_convolution(self)


class residual_graph_convolution(TimestepBlock):
    def __init__(self, conf: residual_graph_convolution_config):
        super(residual_graph_convolution, self).__init__()
        self.conf = conf

        self.gcn = graph_convolution(conf.in_features, conf.out_features, node_n=conf.node_n,
                                      seq_len=conf.seq_len, bias=conf.graph_convolution_bias)
        self.ln = nn.LayerNorm([conf.out_features, conf.node_n, conf.seq_len])
        self.act_f = nn.Tanh()
        self.dropout = nn.Dropout(conf.dropout)

        if conf.use_condition:
            # conditioning layers for the time embedding
            self.emb_layers = nn.Sequential(
                nn.SiLU(),
                nn.Linear(conf.emb_channels, conf.out_features),
            )

            if conf.two_cond:
                # conditioning layers for the encoder's output
                self.cond_emb_layers = nn.Sequential(
                    nn.SiLU(),
                    nn.Linear(conf.cond_emb_channels, conf.out_features),
                )

        if conf.in_features == conf.out_features:
            self.skip_connection = nn.Identity()
        else:
            self.skip_connection = nn.Sequential(
                graph_convolution(conf.in_features, conf.out_features, node_n=conf.node_n,
                                  seq_len=conf.seq_len, bias=conf.graph_convolution_bias),
                nn.Tanh(),
            )

    def forward(self, x, emb=None, cond=None, lateral=None):
        if self.conf.has_lateral:
            # a lateral connection may be supplied even when it is not needed;
            # the block only consumes it when "has_lateral" is set
            assert lateral is not None
            x = torch.cat((x, lateral), dim=1)

        y = self.gcn(x)
        y = self.ln(y)

        if self.conf.use_condition:
            if emb is not None:
                emb = self.emb_layers(emb).type(x.dtype)
                # adjust shape to broadcast over (node_n, seq_len)
                while len(emb.shape) < len(y.shape):
                    emb = emb[..., None]

            if cond is not None:
                if not isinstance(cond, torch.Tensor):
                    assert isinstance(cond, dict)
                    cond = cond['cond']
                cond = self.cond_emb_layers(cond).type(x.dtype)
                while len(cond.shape) < len(y.shape):
                    cond = cond[..., None]
                scales = [emb, cond]
            else:
                scales = [emb]

            # the condition scale bias can be a single number or a list
            if isinstance(self.conf.scale_bias, Number):
                biases = [self.conf.scale_bias] * len(scales)
            else:
                # a list, one bias per condition
                biases = self.conf.scale_bias

            # scale the activations with each provided condition
            for i, scale in enumerate(scales):
                # a None scale indicates that the condition was not provided
                if scale is not None:
                    y = y * (biases[i] + scale)

        y = self.act_f(y)
        y = self.dropout(y)
        return self.skip_connection(x) + y
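
# Usage sketch (illustrative only; sizes and values below are assumptions):
# build the block from its config and condition it on a time embedding.
#
#   conf = residual_graph_convolution_config(
#       in_features=16, seq_len=80, emb_channels=128, dropout=0.1)
#   block = conf.make_model()
#   x = torch.randn(4, 16, 3, 80)    # (batch, features, nodes, time)
#   t_emb = torch.randn(4, 128)      # time-step embedding
#   y = block(x, emb=t_emb)          # -> torch.Size([4, 16, 3, 80])
#
# With two_cond=True, an encoder output of shape (batch, cond_emb_channels)
# can additionally be passed via the `cond` argument.

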
class graph_downsample(nn.Module):
    """
    A downsampling layer: average-pools along the temporal dimension,
    halving seq_len for the default kernel_size of 2.
    """
    def __init__(self, kernel_size=2):
        super().__init__()
        self.downsample = nn.AvgPool1d(kernel_size=kernel_size)

    def forward(self, x):
        bs, features, node_n, seq_len = x.shape
        # fold the node dimension into the channels so AvgPool1d pools over time
        x = x.reshape(bs, features * node_n, seq_len)
        x = self.downsample(x)
        x = x.reshape(bs, features, node_n, -1)
        return x
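
# Usage sketch (illustrative shapes, assumed for demonstration):
#
#   down = graph_downsample(kernel_size=2)
#   x = torch.randn(4, 16, 3, 80)
#   y = down(x)                      # -> torch.Size([4, 16, 3, 40])

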
class graph_upsample(nn.Module):
    """
    An upsampling layer: nearest-neighbour interpolation along the temporal
    dimension, multiplying seq_len by scale_factor while leaving node_n unchanged.
    """
    def __init__(self, scale_factor=2):
        super().__init__()
        self.scale_factor = scale_factor

    def forward(self, x):
        # x: (batch, features, node_n, seq_len) -> (batch, features, node_n, seq_len * scale_factor)
        x = F.interpolate(x, (x.shape[2], x.shape[3] * self.scale_factor), mode="nearest")
        return x
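
# Usage sketch (illustrative shapes, assumed for demonstration): graph_upsample
# undoes the temporal reduction of graph_downsample.
#
#   up = graph_upsample(scale_factor=2)
#   y = torch.randn(4, 16, 3, 40)
#   z = up(y)                        # -> torch.Size([4, 16, 3, 80])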