update readme
This commit is contained in:
parent
35ee4b75e8
commit
249a01f342
18 changed files with 4936 additions and 0 deletions
173
model/graph_convolution_network.py
Normal file
173
model/graph_convolution_network.py
Normal file
|
@ -0,0 +1,173 @@
|
|||
import torch.nn as nn
|
||||
import torch
|
||||
from dataclasses import dataclass
|
||||
from torch.nn.parameter import Parameter
|
||||
from numbers import Number
|
||||
import torch.nn.functional as F
|
||||
from .blocks import *
|
||||
import math
|
||||
|
||||
|
||||
class graph_convolution(nn.Module):
|
||||
def __init__(self, in_features, out_features, node_n = 3, seq_len = 80, bias=True):
|
||||
super(graph_convolution, self).__init__()
|
||||
|
||||
self.temporal_graph_weights = Parameter(torch.FloatTensor(seq_len, seq_len))
|
||||
self.feature_weights = Parameter(torch.FloatTensor(in_features, out_features))
|
||||
self.spatial_graph_weights = Parameter(torch.FloatTensor(node_n, node_n))
|
||||
|
||||
if bias:
|
||||
self.bias = Parameter(torch.FloatTensor(seq_len))
|
||||
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
stdv = 1. / math.sqrt(self.spatial_graph_weights.size(1))
|
||||
self.feature_weights.data.uniform_(-stdv, stdv)
|
||||
self.temporal_graph_weights.data.uniform_(-stdv, stdv)
|
||||
self.spatial_graph_weights.data.uniform_(-stdv, stdv)
|
||||
if self.bias is not None:
|
||||
self.bias.data.uniform_(-stdv, stdv)
|
||||
|
||||
def forward(self, x):
|
||||
y = torch.matmul(x, self.temporal_graph_weights)
|
||||
y = torch.matmul(y.permute(0, 3, 2, 1), self.feature_weights)
|
||||
y = torch.matmul(self.spatial_graph_weights, y).permute(0, 3, 2, 1).contiguous()
|
||||
|
||||
if self.bias is not None:
|
||||
return (y + self.bias)
|
||||
else:
|
||||
return y
|
||||
|
||||
|
||||
@dataclass
|
||||
class residual_graph_convolution_config():
|
||||
in_features: int
|
||||
seq_len: int
|
||||
emb_channels: int
|
||||
dropout: float
|
||||
out_features: int = None
|
||||
node_n: int = 3
|
||||
# condition the block with time (and encoder's output)
|
||||
use_condition: bool = True
|
||||
# whether to condition with both time & encoder's output
|
||||
two_cond: bool = False
|
||||
# number of encoders' output channels
|
||||
cond_emb_channels: int = None
|
||||
has_lateral: bool = False
|
||||
graph_convolution_bias: bool = True
|
||||
scale_bias: float = 1
|
||||
|
||||
def __post_init__(self):
|
||||
self.out_features = self.out_features or self.in_features
|
||||
self.cond_emb_channels = self.cond_emb_channels or self.emb_channels
|
||||
|
||||
def make_model(self):
|
||||
return residual_graph_convolution(self)
|
||||
|
||||
|
||||
class residual_graph_convolution(TimestepBlock):
|
||||
def __init__(self, conf: residual_graph_convolution_config):
|
||||
super(residual_graph_convolution, self).__init__()
|
||||
self.conf = conf
|
||||
|
||||
self.gcn = graph_convolution(conf.in_features, conf.out_features, node_n=conf.node_n, seq_len=conf.seq_len, bias=conf.graph_convolution_bias)
|
||||
self.ln = nn.LayerNorm([conf.out_features, conf.node_n, conf.seq_len])
|
||||
self.act_f = nn.Tanh()
|
||||
self.dropout = nn.Dropout(conf.dropout)
|
||||
|
||||
if conf.use_condition:
|
||||
# condition layers for the out_layers
|
||||
self.emb_layers = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(conf.emb_channels, conf.out_features),
|
||||
)
|
||||
|
||||
if conf.two_cond:
|
||||
self.cond_emb_layers = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(conf.cond_emb_channels, conf.out_features),
|
||||
)
|
||||
|
||||
if conf.in_features == conf.out_features:
|
||||
self.skip_connection = nn.Identity()
|
||||
else:
|
||||
self.skip_connection = nn.Sequential(
|
||||
graph_convolution(conf.in_features, conf.out_features, node_n=conf.node_n, seq_len=conf.seq_len, bias=conf.graph_convolution_bias),
|
||||
nn.Tanh(),
|
||||
)
|
||||
|
||||
def forward(self, x, emb=None, cond=None, lateral=None):
|
||||
if self.conf.has_lateral:
|
||||
# lateral may be supplied even if it doesn't require
|
||||
# the model will take the lateral only if "has_lateral"
|
||||
assert lateral is not None
|
||||
x = torch.cat((x, lateral), dim =1)
|
||||
|
||||
y = self.gcn(x)
|
||||
y = self.ln(y)
|
||||
|
||||
if self.conf.use_condition:
|
||||
if emb is not None:
|
||||
emb = self.emb_layers(emb).type(x.dtype)
|
||||
# adjusting shapes
|
||||
while len(emb.shape) < len(y.shape):
|
||||
emb = emb[..., None]
|
||||
|
||||
if self.conf.two_cond or True:
|
||||
if cond is not None:
|
||||
if not isinstance(cond, torch.Tensor):
|
||||
assert isinstance(cond, dict)
|
||||
cond = cond['cond']
|
||||
cond = self.cond_emb_layers(cond).type(x.dtype)
|
||||
while len(cond.shape) < len(y.shape):
|
||||
cond = cond[..., None]
|
||||
scales = [emb, cond]
|
||||
else:
|
||||
scales = [emb]
|
||||
|
||||
# condition scale bias could be a list
|
||||
if isinstance(self.conf.scale_bias, Number):
|
||||
biases = [self.conf.scale_bias] * len(scales)
|
||||
else:
|
||||
# a list
|
||||
biases = self.conf.scale_bias
|
||||
|
||||
# scale for each condition
|
||||
for i, scale in enumerate(scales):
|
||||
# if scale is None, it indicates that the condition is not provided
|
||||
if scale is not None:
|
||||
y = y*(biases[i] + scale)
|
||||
|
||||
y = self.act_f(y)
|
||||
y = self.dropout(y)
|
||||
return self.skip_connection(x) + y
|
||||
|
||||
|
||||
class graph_downsample(nn.Module):
|
||||
"""
|
||||
A downsampling layer
|
||||
"""
|
||||
def __init__(self, kernel_size = 2):
|
||||
super().__init__()
|
||||
self.downsample = nn.AvgPool1d(kernel_size = kernel_size)
|
||||
|
||||
def forward(self, x):
|
||||
bs, features, node_n, seq_len = x.shape
|
||||
x = x.reshape(bs, features*node_n, seq_len)
|
||||
x = self.downsample(x)
|
||||
x = x.reshape(bs, features, node_n, -1)
|
||||
return x
|
||||
|
||||
|
||||
class graph_upsample(nn.Module):
|
||||
"""
|
||||
An upsampling layer
|
||||
"""
|
||||
def __init__(self, scale_factor=2):
|
||||
super().__init__()
|
||||
self.scale_factor = scale_factor
|
||||
|
||||
def forward(self, x):
|
||||
x = F.interpolate(x, (x.shape[2], x.shape[3]*self.scale_factor), mode="nearest")
|
||||
return x
|
Loading…
Add table
Add a link
Reference in a new issue