# HaHeAE/model/graph_convolution_network.py
import torch.nn as nn
import torch
from dataclasses import dataclass
from torch.nn.parameter import Parameter
from numbers import Number
import torch.nn.functional as F
from .blocks import *
import math


class graph_convolution(nn.Module):
    def __init__(self, in_features, out_features, node_n=3, seq_len=80, bias=True):
        super(graph_convolution, self).__init__()
        self.temporal_graph_weights = Parameter(torch.FloatTensor(seq_len, seq_len))
        self.feature_weights = Parameter(torch.FloatTensor(in_features, out_features))
        self.spatial_graph_weights = Parameter(torch.FloatTensor(node_n, node_n))
        if bias:
            self.bias = Parameter(torch.FloatTensor(seq_len))
        else:
            # register a None bias so reset_parameters() and forward() can test for it
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.spatial_graph_weights.size(1))
        self.feature_weights.data.uniform_(-stdv, stdv)
        self.temporal_graph_weights.data.uniform_(-stdv, stdv)
        self.spatial_graph_weights.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, x):
        # temporal mixing along the sequence dimension
        y = torch.matmul(x, self.temporal_graph_weights)
        # feature transform: (bs, in_features, node_n, seq_len) -> (bs, seq_len, node_n, out_features)
        y = torch.matmul(y.permute(0, 3, 2, 1), self.feature_weights)
        # spatial mixing over the nodes, back to (bs, out_features, node_n, seq_len)
        y = torch.matmul(self.spatial_graph_weights, y).permute(0, 3, 2, 1).contiguous()
        if self.bias is not None:
            return y + self.bias
        else:
            return y
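
# Illustrative usage sketch (not part of the original file): the layer operates on
# tensors shaped (batch, in_features, node_n, seq_len); the sizes below are made up.
#
#     gcn = graph_convolution(in_features=32, out_features=64, node_n=3, seq_len=80)
#     out = gcn(torch.randn(8, 32, 3, 80))   # -> torch.Size([8, 64, 3, 80])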


@dataclass
class residual_graph_convolution_config:
    in_features: int
    seq_len: int
    emb_channels: int
    dropout: float
    out_features: int = None
    node_n: int = 3
    # condition the block with time (and the encoder's output)
    use_condition: bool = True
    # whether to condition with both time & the encoder's output
    two_cond: bool = False
    # number of encoder output channels
    cond_emb_channels: int = None
    has_lateral: bool = False
    graph_convolution_bias: bool = True
    scale_bias: float = 1

    def __post_init__(self):
        self.out_features = self.out_features or self.in_features
        self.cond_emb_channels = self.cond_emb_channels or self.emb_channels

    def make_model(self):
        return residual_graph_convolution(self)


class residual_graph_convolution(TimestepBlock):
    def __init__(self, conf: residual_graph_convolution_config):
        super(residual_graph_convolution, self).__init__()
        self.conf = conf
        self.gcn = graph_convolution(conf.in_features, conf.out_features, node_n=conf.node_n,
                                     seq_len=conf.seq_len, bias=conf.graph_convolution_bias)
        self.ln = nn.LayerNorm([conf.out_features, conf.node_n, conf.seq_len])
        self.act_f = nn.Tanh()
        self.dropout = nn.Dropout(conf.dropout)
        if conf.use_condition:
            # condition layers for the out_layers
            self.emb_layers = nn.Sequential(
                nn.SiLU(),
                nn.Linear(conf.emb_channels, conf.out_features),
            )
            if conf.two_cond:
                self.cond_emb_layers = nn.Sequential(
                    nn.SiLU(),
                    nn.Linear(conf.cond_emb_channels, conf.out_features),
                )
        if conf.in_features == conf.out_features:
            self.skip_connection = nn.Identity()
        else:
            self.skip_connection = nn.Sequential(
                graph_convolution(conf.in_features, conf.out_features, node_n=conf.node_n,
                                  seq_len=conf.seq_len, bias=conf.graph_convolution_bias),
                nn.Tanh(),
            )

    def forward(self, x, emb=None, cond=None, lateral=None):
        if self.conf.has_lateral:
            # lateral may be supplied even when it is not required;
            # the model uses it only if "has_lateral" is set
            assert lateral is not None
            x = torch.cat((x, lateral), dim=1)
        y = self.gcn(x)
        y = self.ln(y)
        if self.conf.use_condition:
            if emb is not None:
                emb = self.emb_layers(emb).type(x.dtype)
                # adjust shape for broadcasting over (features, nodes, seq_len)
                while len(emb.shape) < len(y.shape):
                    emb = emb[..., None]
            if self.conf.two_cond or True:  # note: "or True" makes this branch always taken
                if cond is not None:
                    if not isinstance(cond, torch.Tensor):
                        assert isinstance(cond, dict)
                        cond = cond['cond']
                    cond = self.cond_emb_layers(cond).type(x.dtype)
                    while len(cond.shape) < len(y.shape):
                        cond = cond[..., None]
                scales = [emb, cond]
            else:
                scales = [emb]
            # the condition scale bias can be a single number or a list
            if isinstance(self.conf.scale_bias, Number):
                biases = [self.conf.scale_bias] * len(scales)
            else:
                # a list
                biases = self.conf.scale_bias
            # scale the output by each condition
            for i, scale in enumerate(scales):
                # a None scale means the corresponding condition was not provided
                if scale is not None:
                    y = y * (biases[i] + scale)
        y = self.act_f(y)
        y = self.dropout(y)
        return self.skip_connection(x) + y
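
# Illustrative usage sketch (not part of the original file): building the residual
# block from its config; channel sizes and the dummy time embedding below are made up.
#
#     conf = residual_graph_convolution_config(in_features=32, seq_len=80,
#                                               emb_channels=128, dropout=0.1,
#                                               out_features=64)
#     block = conf.make_model()
#     y = block(torch.randn(8, 32, 3, 80), emb=torch.randn(8, 128))
#     # y has shape (8, 64, 3, 80)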


class graph_downsample(nn.Module):
    """
    A downsampling layer (average pooling along the sequence dimension)
    """
    def __init__(self, kernel_size=2):
        super().__init__()
        self.downsample = nn.AvgPool1d(kernel_size=kernel_size)

    def forward(self, x):
        bs, features, node_n, seq_len = x.shape
        # AvgPool1d expects (batch, channels, length), so fold the nodes into the channels
        x = x.reshape(bs, features * node_n, seq_len)
        x = self.downsample(x)
        x = x.reshape(bs, features, node_n, -1)
        return x


class graph_upsample(nn.Module):
    """
    An upsampling layer (nearest-neighbour interpolation along the sequence dimension)
    """
    def __init__(self, scale_factor=2):
        super().__init__()
        self.scale_factor = scale_factor

    def forward(self, x):
        # interpolate only along the last (sequence) dimension; the node dimension stays fixed
        x = F.interpolate(x, (x.shape[2], x.shape[3] * self.scale_factor), mode="nearest")
        return x
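
# Illustrative usage sketch (not part of the original file): downsampling halves the
# sequence length and upsampling restores it; the shapes below are made up.
#
#     x = torch.randn(8, 64, 3, 80)
#     graph_downsample(kernel_size=2)(x).shape                            # -> torch.Size([8, 64, 3, 40])
#     graph_upsample(scale_factor=2)(torch.randn(8, 64, 3, 40)).shape     # -> torch.Size([8, 64, 3, 80])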