update readme

This commit is contained in:
parent 35ee4b75e8
commit 249a01f342

18 changed files with 4936 additions and 0 deletions
6  model/__init__.py  Normal file
@@ -0,0 +1,6 @@
from typing import Union
from .unet import BeatGANsUNetModel, BeatGANsUNetConfig, GCNUNetModel, GCNUNetConfig
from .unet_autoenc import BeatGANsAutoencConfig, BeatGANsAutoencModel, GCNAutoencConfig, GCNAutoencModel

Model = Union[BeatGANsUNetModel, BeatGANsAutoencModel, GCNUNetModel, GCNAutoencModel]
ModelConfig = Union[BeatGANsUNetConfig, BeatGANsAutoencConfig, GCNUNetConfig, GCNAutoencConfig]
579  model/blocks.py  Normal file
@@ -0,0 +1,579 @@
import math
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from numbers import Number
|
||||
|
||||
import torch as th
|
||||
import torch.nn.functional as F
|
||||
from choices import *
|
||||
from config_base import BaseConfig
|
||||
from torch import nn
|
||||
import numpy as np
|
||||
from .nn import (avg_pool_nd, conv_nd, linear, normalization,
|
||||
timestep_embedding, zero_module)
|
||||
|
||||
|
||||
class TimestepBlock(nn.Module):
|
||||
"""
|
||||
Any module where forward() takes timestep embeddings as a second argument.
|
||||
"""
|
||||
@abstractmethod
|
||||
def forward(self, x, emb=None, cond=None, lateral=None):
|
||||
"""
|
||||
Apply the module to `x` given `emb` timestep embeddings.
|
||||
"""
|
||||
|
||||
|
||||
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
||||
"""
|
||||
A sequential module that passes timestep embeddings to the children that
|
||||
support it as an extra input.
|
||||
"""
|
||||
    def forward(self, x, emb=None, cond=None, lateral=None):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb=emb, cond=cond, lateral=lateral)
            else:
                x = layer(x)
        return x
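# A minimal dispatch sketch (illustrative only; `_AddEmbBlock` and `_demo_timestep_sequential`
# are assumed helper names, not part of the model code). It shows how TimestepEmbedSequential
# forwards `emb`/`cond`/`lateral` only to children that subclass TimestepBlock, while plain
# nn.Modules receive just `x`.
class _AddEmbBlock(TimestepBlock):
    # toy TimestepBlock that adds a per-channel projection of the embedding
    def __init__(self, channels, emb_channels):
        super().__init__()
        self.proj = nn.Linear(emb_channels, channels)

    def forward(self, x, emb=None, cond=None, lateral=None):
        return x + self.proj(emb)[..., None]


def _demo_timestep_sequential():
    seq = TimestepEmbedSequential(
        nn.Conv1d(8, 8, 3, padding=1),   # plain module: called as layer(x)
        _AddEmbBlock(8, 32),             # TimestepBlock: called with emb/cond/lateral
    )
    x, emb = th.randn(2, 8, 16), th.randn(2, 32)
    return seq(x, emb=emb).shape  # -> torch.Size([2, 8, 16])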
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResBlockConfig(BaseConfig):
|
||||
channels: int
|
||||
emb_channels: int
|
||||
dropout: float
|
||||
out_channels: int = None
|
||||
# condition the resblock with time (and encoder's output)
|
||||
use_condition: bool = True
|
||||
# whether to use 3x3 conv for skip path when the channels aren't matched
|
||||
use_conv: bool = False
|
||||
# dimension of conv (always 1 = 1d)
|
||||
dims: int = 1
|
||||
up: bool = False
|
||||
down: bool = False
|
||||
# whether to condition with both time & encoder's output
|
||||
two_cond: bool = False
|
||||
# number of encoders' output channels
|
||||
cond_emb_channels: int = None
|
||||
# suggest: False
|
||||
has_lateral: bool = False
|
||||
lateral_channels: int = None
|
||||
# whether to init the convolution with zero weights
|
||||
# this is default from BeatGANs and seems to help learning
|
||||
use_zero_module: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
self.out_channels = self.out_channels or self.channels
|
||||
self.cond_emb_channels = self.cond_emb_channels or self.emb_channels
|
||||
|
||||
def make_model(self):
|
||||
return ResBlock(self)
|
||||
|
||||
|
||||
class ResBlock(TimestepBlock):
|
||||
"""
|
||||
A residual block that can optionally change the number of channels.
|
||||
|
||||
total layers:
|
||||
in_layers
|
||||
- norm
|
||||
- act
|
||||
- conv
|
||||
out_layers
|
||||
- norm
|
||||
- (modulation)
|
||||
- act
|
||||
- conv
|
||||
"""
|
||||
def __init__(self, conf: ResBlockConfig):
|
||||
super().__init__()
|
||||
self.conf = conf
|
||||
|
||||
#############################
|
||||
# IN LAYERS
|
||||
#############################
|
||||
assert conf.lateral_channels is None
|
||||
layers = [
|
||||
normalization(conf.channels),
|
||||
nn.SiLU(),
|
||||
conv_nd(conf.dims, conf.channels, conf.out_channels, 3, padding=1) ## 3 is kernel size
|
||||
]
|
||||
self.in_layers = nn.Sequential(*layers)
|
||||
|
||||
self.updown = conf.up or conf.down
|
||||
|
||||
if conf.up:
|
||||
self.h_upd = Upsample(conf.channels, False, conf.dims)
|
||||
self.x_upd = Upsample(conf.channels, False, conf.dims)
|
||||
elif conf.down:
|
||||
self.h_upd = Downsample(conf.channels, False, conf.dims)
|
||||
self.x_upd = Downsample(conf.channels, False, conf.dims)
|
||||
else:
|
||||
self.h_upd = self.x_upd = nn.Identity()
|
||||
|
||||
#############################
|
||||
# OUT LAYERS CONDITIONS
|
||||
#############################
|
||||
if conf.use_condition:
|
||||
# condition layers for the out_layers
|
||||
self.emb_layers = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
linear(conf.emb_channels, 2 * conf.out_channels),
|
||||
)
|
||||
|
||||
if conf.two_cond:
|
||||
self.cond_emb_layers = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
linear(conf.cond_emb_channels, conf.out_channels),
|
||||
)
|
||||
#############################
|
||||
# OUT LAYERS (ignored when there is no condition)
|
||||
#############################
|
||||
# original version
|
||||
conv = conv_nd(conf.dims,
|
||||
conf.out_channels,
|
||||
conf.out_channels,
|
||||
3,
|
||||
padding=1)
|
||||
        if conf.use_zero_module:
            # zero out the weights of the last conv
            # it seems to help training
            conv = zero_module(conv)
|
||||
|
||||
# construct the layers
|
||||
# - norm
|
||||
# - (modulation)
|
||||
# - act
|
||||
# - dropout
|
||||
# - conv
|
||||
layers = []
|
||||
layers += [
|
||||
normalization(conf.out_channels),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(p=conf.dropout),
|
||||
conv,
|
||||
]
|
||||
self.out_layers = nn.Sequential(*layers)
|
||||
|
||||
#############################
|
||||
# SKIP LAYERS
|
||||
#############################
|
||||
        if conf.out_channels == conf.channels:
            # cannot be used with gatedconv; gatedconv is always used as the first block
            self.skip_connection = nn.Identity()
|
||||
else:
|
||||
if conf.use_conv:
|
||||
kernel_size = 3
|
||||
padding = 1
|
||||
else:
|
||||
kernel_size = 1
|
||||
padding = 0
|
||||
|
||||
self.skip_connection = conv_nd(conf.dims,
|
||||
conf.channels,
|
||||
conf.out_channels,
|
||||
kernel_size,
|
||||
padding=padding)
|
||||
|
||||
def forward(self, x, emb=None, cond=None, lateral=None):
|
||||
"""
|
||||
Apply the block to a Tensor, conditioned on a timestep embedding.
|
||||
|
||||
Args:
|
||||
x: input
|
||||
lateral: lateral connection from the encoder
|
||||
"""
|
||||
return self._forward(x, emb, cond, lateral)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x,
|
||||
emb=None,
|
||||
cond=None,
|
||||
lateral=None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
lateral: required if "has_lateral" and non-gated, with gated, it can be supplied optionally
|
||||
"""
|
||||
if self.conf.has_lateral:
|
||||
# lateral may be supplied even if it doesn't require
|
||||
# the model will take the lateral only if "has_lateral"
|
||||
assert lateral is not None
|
||||
# x = F.interpolate(x, size=(lateral.size(2)), mode='linear' )
|
||||
x = th.cat([x, lateral], dim=1)
|
||||
|
||||
if self.updown:
|
||||
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
|
||||
h = in_rest(x)
|
||||
h = self.h_upd(h)
|
||||
x = self.x_upd(x)
|
||||
h = in_conv(h)
|
||||
else:
|
||||
h = self.in_layers(x)
|
||||
|
||||
        if self.conf.use_condition:
            # it's possible that the network may not receive the time emb
            # this happens with autoenc and setting the time_at
|
||||
if emb is not None:
|
||||
emb_out = self.emb_layers(emb).type(h.dtype)
|
||||
else:
|
||||
emb_out = None
|
||||
|
||||
if self.conf.two_cond:
|
||||
# it's possible that the network is two_cond
|
||||
# but it doesn't get the second condition
|
||||
# in which case, we ignore the second condition
|
||||
# and treat as if the network has one condition
|
||||
if cond is None:
|
||||
cond_out = None
|
||||
else:
|
||||
if not isinstance(cond, th.Tensor):
|
||||
assert isinstance(cond, dict)
|
||||
cond = cond['cond']
|
||||
cond_out = self.cond_emb_layers(cond).type(h.dtype)
|
||||
if cond_out is not None:
|
||||
while len(cond_out.shape) < len(h.shape):
|
||||
cond_out = cond_out[..., None]
|
||||
else:
|
||||
cond_out = None
|
||||
|
||||
# this is the new refactored code
|
||||
h = apply_conditions(
|
||||
h=h,
|
||||
emb=emb_out,
|
||||
cond=cond_out,
|
||||
layers=self.out_layers,
|
||||
scale_bias=1,
|
||||
in_channels=self.conf.out_channels,
|
||||
up_down_layer=None,
|
||||
)
|
||||
|
||||
return self.skip_connection(x) + h
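# A small shape-check sketch for ResBlock (assumed defaults; `_demo_resblock` is an
# illustrative helper, not part of the training code). With use_condition=True and
# two_cond=False, the block applies GroupNorm -> SiLU -> conv, then a scale/shift from
# the time embedding, and finally a 1x1 skip projection because channels != out_channels.
def _demo_resblock():
    conf = ResBlockConfig(
        channels=32,
        emb_channels=128,
        dropout=0.1,
        out_channels=64,
        two_cond=False,
    )
    block = conf.make_model()
    x = th.randn(4, 32, 80)     # [batch, channels, seq_len]
    emb = th.randn(4, 128)      # time embedding
    return block(x, emb=emb).shape  # -> torch.Size([4, 64, 80])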
|
||||
|
||||
|
||||
def apply_conditions(
|
||||
h,
|
||||
emb=None,
|
||||
cond=None,
|
||||
layers: nn.Sequential = None,
|
||||
scale_bias: float = 1,
|
||||
in_channels: int = 512,
|
||||
up_down_layer: nn.Module = None,
|
||||
):
|
||||
"""
|
||||
apply conditions on the feature maps
|
||||
|
||||
Args:
|
||||
emb: time conditional (ready to scale + shift)
|
||||
cond: encoder's conditional (ready to scale + shift)
|
||||
"""
|
||||
two_cond = emb is not None and cond is not None
|
||||
|
||||
if emb is not None:
|
||||
# adjusting shapes
|
||||
while len(emb.shape) < len(h.shape):
|
||||
emb = emb[..., None]
|
||||
|
||||
if two_cond:
|
||||
# adjusting shapes
|
||||
while len(cond.shape) < len(h.shape):
|
||||
cond = cond[..., None]
|
||||
# time first
|
||||
scale_shifts = [emb, cond]
|
||||
else:
|
||||
# "cond" is not used with single cond mode
|
||||
scale_shifts = [emb]
|
||||
|
||||
# support scale, shift or shift only
|
||||
for i, each in enumerate(scale_shifts):
|
||||
if each is None:
|
||||
# special case: the condition is not provided
|
||||
a = None
|
||||
b = None
|
||||
else:
|
||||
if each.shape[1] == in_channels * 2:
|
||||
a, b = th.chunk(each, 2, dim=1)
|
||||
else:
|
||||
a = each
|
||||
b = None
|
||||
scale_shifts[i] = (a, b)
|
||||
|
||||
# condition scale bias could be a list
|
||||
if isinstance(scale_bias, Number):
|
||||
biases = [scale_bias] * len(scale_shifts)
|
||||
else:
|
||||
# a list
|
||||
biases = scale_bias
|
||||
|
||||
    # by default, the scale & shift are applied after the group norm but BEFORE SiLU
    pre_layers, post_layers = layers[0], layers[1:]

    # split the post layers so we can scale up or down just before the conv;
    # post_layers will contain the dropout and the final conv
    mid_layers, post_layers = post_layers[:-2], post_layers[-2:]
|
||||
|
||||
h = pre_layers(h)
|
||||
# scale and shift for each condition
|
||||
for i, (scale, shift) in enumerate(scale_shifts):
|
||||
# if scale is None, it indicates that the condition is not provided
|
||||
if scale is not None:
|
||||
h = h * (biases[i] + scale)
|
||||
if shift is not None:
|
||||
h = h + shift
|
||||
h = mid_layers(h)
|
||||
|
||||
# upscale or downscale if any just before the last conv
|
||||
if up_down_layer is not None:
|
||||
h = up_down_layer(h)
|
||||
h = post_layers(h)
|
||||
return h
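# A direct worked example of the scale/shift conditioning above (illustrative sketch with
# toy layers; `_toy_layers` and `_demo_apply_conditions` are assumed names). A [N, 2*C]
# embedding is chunked into (scale, shift) and applied as h * (1 + scale) + shift right
# after the group norm.
def _demo_apply_conditions():
    in_channels = 8
    _toy_layers = nn.Sequential(
        normalization(in_channels),          # pre layer: norm (scale/shift applied after this)
        nn.SiLU(),                           # mid layer
        nn.Dropout(p=0.1),                   # post layers: dropout + conv
        conv_nd(1, in_channels, in_channels, 3, padding=1),
    )
    h = th.randn(2, in_channels, 20)
    emb = th.randn(2, in_channels * 2)       # chunked into (scale, shift)
    out = apply_conditions(h=h, emb=emb, cond=None, layers=_toy_layers,
                           scale_bias=1, in_channels=in_channels)
    return out.shape  # -> torch.Size([2, 8, 20])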
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
"""
|
||||
An upsampling layer with an optional convolution.
|
||||
|
||||
:param channels: channels in the inputs and outputs.
|
||||
:param use_conv: a bool determining if a convolution is applied.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
||||
upsampling occurs in the inner-two dimensions.
|
||||
"""
|
||||
def __init__(self, channels, use_conv, dims=2, out_channels=None):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.dims = dims
|
||||
if use_conv:
|
||||
self.conv = conv_nd(dims,
|
||||
self.channels,
|
||||
self.out_channels,
|
||||
3,
|
||||
padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
assert x.shape[1] == self.channels
|
||||
if self.dims == 3:
|
||||
x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
|
||||
mode="nearest")
|
||||
else:
|
||||
# if x.shape[2] == 4:
|
||||
# feature = 9
|
||||
# x = F.interpolate(x, size=(feature), mode="nearest")
|
||||
# if x.shape[2] == 8:
|
||||
# feature = 9
|
||||
# x = F.interpolate(x, size=(feature), mode="nearest")
|
||||
# else:
|
||||
x = F.interpolate(x, scale_factor=2, mode="nearest")
|
||||
if self.use_conv:
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
"""
|
||||
A downsampling layer with an optional convolution.
|
||||
|
||||
:param channels: channels in the inputs and outputs.
|
||||
:param use_conv: a bool determining if a convolution is applied.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
||||
downsampling occurs in the inner-two dimensions.
|
||||
"""
|
||||
def __init__(self, channels, use_conv, dims=2, out_channels=None):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.dims = dims
|
||||
self.stride = 2 if dims != 3 else (1, 2, 2)
|
||||
if use_conv:
|
||||
self.op = conv_nd(dims,
|
||||
self.channels,
|
||||
self.out_channels,
|
||||
3,
|
||||
stride=self.stride,
|
||||
padding=1)
|
||||
else:
|
||||
assert self.channels == self.out_channels
|
||||
self.op = avg_pool_nd(dims, kernel_size=self.stride, stride=self.stride)
|
||||
|
||||
def forward(self, x):
|
||||
assert x.shape[1] == self.channels
|
||||
# if x.shape[2] % 2 != 0:
|
||||
# op = avg_pool_nd(self.dims, kernel_size=3, stride=2)
|
||||
# return op(x)
|
||||
# if x.shape[2] % 2 != 0:
|
||||
# op = avg_pool_nd(self.dims, kernel_size=2, stride=1)
|
||||
# return op(x)
|
||||
# else:
|
||||
return self.op(x)
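# Quick shape sketch for the 1-D resampling layers above (illustrative only; `_demo_resampling`
# is an assumed helper name): Downsample halves the temporal length, Upsample doubles it.
def _demo_resampling():
    x = th.randn(2, 16, 80)
    down = Downsample(16, use_conv=True, dims=1)
    up = Upsample(16, use_conv=True, dims=1)
    return down(x).shape, up(x).shape  # -> ([2, 16, 40], [2, 16, 160])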
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
"""
|
||||
An attention block that allows spatial positions to attend to each other.
|
||||
|
||||
Originally ported from here, but adapted to the N-d case.
|
||||
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
channels,
|
||||
num_heads=1,
|
||||
num_head_channels=-1,
|
||||
use_new_attention_order=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
if num_head_channels == -1:
|
||||
self.num_heads = num_heads
|
||||
else:
|
||||
assert (
|
||||
channels % num_head_channels == 0
|
||||
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
|
||||
self.num_heads = channels // num_head_channels
|
||||
self.norm = normalization(channels)
|
||||
self.qkv = conv_nd(1, channels, channels * 3, 1)
|
||||
if use_new_attention_order:
|
||||
# split qkv before split heads
|
||||
self.attention = QKVAttention(self.num_heads)
|
||||
else:
|
||||
# split heads before split qkv
|
||||
self.attention = QKVAttentionLegacy(self.num_heads)
|
||||
|
||||
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
|
||||
|
||||
def forward(self, x):
|
||||
return self._forward(x)
|
||||
|
||||
def _forward(self, x):
|
||||
b, c, *spatial = x.shape
|
||||
x = x.reshape(b, c, -1)
|
||||
qkv = self.qkv(self.norm(x))
|
||||
h = self.attention(qkv)
|
||||
h = self.proj_out(h)
|
||||
return (x + h).reshape(b, c, *spatial)
|
||||
|
||||
|
||||
def count_flops_attn(model, _x, y):
|
||||
"""
|
||||
A counter for the `thop` package to count the operations in an
|
||||
attention operation.
|
||||
Meant to be used like:
|
||||
macs, params = thop.profile(
|
||||
model,
|
||||
inputs=(inputs, timestamps),
|
||||
custom_ops={QKVAttention: QKVAttention.count_flops},
|
||||
)
|
||||
"""
|
||||
b, c, *spatial = y[0].shape
|
||||
num_spatial = int(np.prod(spatial))
|
||||
# We perform two matmuls with the same number of ops.
|
||||
# The first computes the weight matrix, the second computes
|
||||
# the combination of the value vectors.
|
||||
matmul_ops = 2 * b * (num_spatial**2) * c
|
||||
model.total_ops += th.DoubleTensor([matmul_ops])
|
||||
|
||||
|
||||
class QKVAttentionLegacy(nn.Module):
|
||||
"""
|
||||
A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
|
||||
"""
|
||||
def __init__(self, n_heads):
|
||||
super().__init__()
|
||||
self.n_heads = n_heads
|
||||
|
||||
def forward(self, qkv):
|
||||
"""
|
||||
Apply QKV attention.
|
||||
|
||||
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
|
||||
:return: an [N x (H * C) x T] tensor after attention.
|
||||
"""
|
||||
bs, width, length = qkv.shape
|
||||
assert width % (3 * self.n_heads) == 0
|
||||
ch = width // (3 * self.n_heads)
|
||||
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
|
||||
scale = 1 / math.sqrt(math.sqrt(ch))
|
||||
weight = th.einsum(
|
||||
"bct,bcs->bts", q * scale,
|
||||
k * scale) # More stable with f16 than dividing afterwards
|
||||
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||
a = th.einsum("bts,bcs->bct", weight, v)
|
||||
return a.reshape(bs, -1, length)
|
||||
|
||||
@staticmethod
|
||||
def count_flops(model, _x, y):
|
||||
return count_flops_attn(model, _x, y)
|
||||
|
||||
|
||||
class QKVAttention(nn.Module):
|
||||
"""
|
||||
A module which performs QKV attention and splits in a different order.
|
||||
"""
|
||||
def __init__(self, n_heads):
|
||||
super().__init__()
|
||||
self.n_heads = n_heads
|
||||
|
||||
def forward(self, qkv):
|
||||
"""
|
||||
Apply QKV attention.
|
||||
|
||||
:param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
|
||||
:return: an [N x (H * C) x T] tensor after attention.
|
||||
"""
|
||||
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.chunk(3, dim=1)
|
||||
scale = 1 / math.sqrt(math.sqrt(ch))
|
||||
weight = th.einsum(
|
||||
"bct,bcs->bts",
|
||||
(q * scale).view(bs * self.n_heads, ch, length),
|
||||
(k * scale).view(bs * self.n_heads, ch, length),
|
||||
) # More stable with f16 than dividing afterwards
|
||||
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||
a = th.einsum("bts,bcs->bct", weight,
|
||||
v.reshape(bs * self.n_heads, ch, length))
|
||||
return a.reshape(bs, -1, length)
|
||||
|
||||
@staticmethod
|
||||
def count_flops(model, _x, y):
|
||||
return count_flops_attn(model, _x, y)
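# Shape sketch for the attention modules above (illustrative only; `_demo_qkv_attention`
# is an assumed helper name): qkv is [N x (3 * H * C) x T] and the output is [N x (H * C) x T].
def _demo_qkv_attention():
    n_heads, ch, length = 4, 8, 10
    qkv = th.randn(2, 3 * n_heads * ch, length)
    out_legacy = QKVAttentionLegacy(n_heads)(qkv)
    out_new = QKVAttention(n_heads)(qkv)
    return out_legacy.shape, out_new.shape  # both -> torch.Size([2, 32, 10])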
|
||||
|
||||
|
||||
class AttentionPool2d(nn.Module):
|
||||
"""
|
||||
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
spacial_dim: int,
|
||||
embed_dim: int,
|
||||
num_heads_channels: int,
|
||||
output_dim: int = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.positional_embedding = nn.Parameter(
|
||||
th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)
|
||||
self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
|
||||
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
|
||||
self.num_heads = embed_dim // num_heads_channels
|
||||
self.attention = QKVAttention(self.num_heads)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, *_spatial = x.shape
|
||||
x = x.reshape(b, c, -1) # NC(HW)
|
||||
x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
|
||||
x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
|
||||
x = self.qkv_proj(x)
|
||||
x = self.attention(x)
|
||||
x = self.c_proj(x)
|
||||
return x[:, :, 0]
|
173  model/graph_convolution_network.py  Normal file
@@ -0,0 +1,173 @@
import torch.nn as nn
|
||||
import torch
|
||||
from dataclasses import dataclass
|
||||
from torch.nn.parameter import Parameter
|
||||
from numbers import Number
|
||||
import torch.nn.functional as F
|
||||
from .blocks import *
|
||||
import math
|
||||
|
||||
|
||||
class graph_convolution(nn.Module):
|
||||
def __init__(self, in_features, out_features, node_n = 3, seq_len = 80, bias=True):
|
||||
super(graph_convolution, self).__init__()
|
||||
|
||||
self.temporal_graph_weights = Parameter(torch.FloatTensor(seq_len, seq_len))
|
||||
self.feature_weights = Parameter(torch.FloatTensor(in_features, out_features))
|
||||
self.spatial_graph_weights = Parameter(torch.FloatTensor(node_n, node_n))
|
||||
|
||||
        if bias:
            self.bias = Parameter(torch.FloatTensor(seq_len))
        else:
            # register as None so that `self.bias is not None` checks still work when bias is disabled
            self.register_parameter('bias', None)

        self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
stdv = 1. / math.sqrt(self.spatial_graph_weights.size(1))
|
||||
self.feature_weights.data.uniform_(-stdv, stdv)
|
||||
self.temporal_graph_weights.data.uniform_(-stdv, stdv)
|
||||
self.spatial_graph_weights.data.uniform_(-stdv, stdv)
|
||||
if self.bias is not None:
|
||||
self.bias.data.uniform_(-stdv, stdv)
|
||||
|
||||
def forward(self, x):
|
||||
y = torch.matmul(x, self.temporal_graph_weights)
|
||||
y = torch.matmul(y.permute(0, 3, 2, 1), self.feature_weights)
|
||||
y = torch.matmul(self.spatial_graph_weights, y).permute(0, 3, 2, 1).contiguous()
|
||||
|
||||
if self.bias is not None:
|
||||
return (y + self.bias)
|
||||
else:
|
||||
return y
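# Shape sketch for graph_convolution (illustrative only; `_demo_graph_convolution` is an
# assumed helper name): the input is [batch, in_features, node_n, seq_len]; the temporal,
# feature and spatial weights are applied in turn, giving [batch, out_features, node_n, seq_len].
def _demo_graph_convolution():
    gcn = graph_convolution(in_features=3, out_features=16, node_n=3, seq_len=80)
    x = torch.randn(4, 3, 3, 80)
    return gcn(x).shape  # -> torch.Size([4, 16, 3, 80])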
|
||||
|
||||
|
||||
@dataclass
|
||||
class residual_graph_convolution_config():
|
||||
in_features: int
|
||||
seq_len: int
|
||||
emb_channels: int
|
||||
dropout: float
|
||||
out_features: int = None
|
||||
node_n: int = 3
|
||||
# condition the block with time (and encoder's output)
|
||||
use_condition: bool = True
|
||||
# whether to condition with both time & encoder's output
|
||||
two_cond: bool = False
|
||||
# number of encoders' output channels
|
||||
cond_emb_channels: int = None
|
||||
has_lateral: bool = False
|
||||
graph_convolution_bias: bool = True
|
||||
scale_bias: float = 1
|
||||
|
||||
def __post_init__(self):
|
||||
self.out_features = self.out_features or self.in_features
|
||||
self.cond_emb_channels = self.cond_emb_channels or self.emb_channels
|
||||
|
||||
def make_model(self):
|
||||
return residual_graph_convolution(self)
|
||||
|
||||
|
||||
class residual_graph_convolution(TimestepBlock):
|
||||
def __init__(self, conf: residual_graph_convolution_config):
|
||||
super(residual_graph_convolution, self).__init__()
|
||||
self.conf = conf
|
||||
|
||||
self.gcn = graph_convolution(conf.in_features, conf.out_features, node_n=conf.node_n, seq_len=conf.seq_len, bias=conf.graph_convolution_bias)
|
||||
self.ln = nn.LayerNorm([conf.out_features, conf.node_n, conf.seq_len])
|
||||
self.act_f = nn.Tanh()
|
||||
self.dropout = nn.Dropout(conf.dropout)
|
||||
|
||||
if conf.use_condition:
|
||||
# condition layers for the out_layers
|
||||
self.emb_layers = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(conf.emb_channels, conf.out_features),
|
||||
)
|
||||
|
||||
if conf.two_cond:
|
||||
self.cond_emb_layers = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(conf.cond_emb_channels, conf.out_features),
|
||||
)
|
||||
|
||||
if conf.in_features == conf.out_features:
|
||||
self.skip_connection = nn.Identity()
|
||||
else:
|
||||
self.skip_connection = nn.Sequential(
|
||||
graph_convolution(conf.in_features, conf.out_features, node_n=conf.node_n, seq_len=conf.seq_len, bias=conf.graph_convolution_bias),
|
||||
nn.Tanh(),
|
||||
)
|
||||
|
||||
def forward(self, x, emb=None, cond=None, lateral=None):
|
||||
        if self.conf.has_lateral:
            # lateral may be supplied even when it is not required;
            # the block only uses it when "has_lateral" is True
            assert lateral is not None
            x = torch.cat((x, lateral), dim=1)
|
||||
|
||||
y = self.gcn(x)
|
||||
y = self.ln(y)
|
||||
|
||||
if self.conf.use_condition:
|
||||
if emb is not None:
|
||||
emb = self.emb_layers(emb).type(x.dtype)
|
||||
# adjusting shapes
|
||||
while len(emb.shape) < len(y.shape):
|
||||
emb = emb[..., None]
|
||||
|
||||
if self.conf.two_cond or True:
|
||||
if cond is not None:
|
||||
if not isinstance(cond, torch.Tensor):
|
||||
assert isinstance(cond, dict)
|
||||
cond = cond['cond']
|
||||
cond = self.cond_emb_layers(cond).type(x.dtype)
|
||||
while len(cond.shape) < len(y.shape):
|
||||
cond = cond[..., None]
|
||||
scales = [emb, cond]
|
||||
else:
|
||||
scales = [emb]
|
||||
|
||||
# condition scale bias could be a list
|
||||
if isinstance(self.conf.scale_bias, Number):
|
||||
biases = [self.conf.scale_bias] * len(scales)
|
||||
else:
|
||||
# a list
|
||||
biases = self.conf.scale_bias
|
||||
|
||||
# scale for each condition
|
||||
for i, scale in enumerate(scales):
|
||||
# if scale is None, it indicates that the condition is not provided
|
||||
if scale is not None:
|
||||
y = y*(biases[i] + scale)
|
||||
|
||||
y = self.act_f(y)
|
||||
y = self.dropout(y)
|
||||
return self.skip_connection(x) + y
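# Conditioning sketch for residual_graph_convolution (assumed defaults, cond omitted;
# `_demo_residual_gcn` is an illustrative helper name): the time embedding is projected to
# out_features and used as a multiplicative (1 + scale) term after the LayerNorm.
def _demo_residual_gcn():
    conf = residual_graph_convolution_config(
        in_features=16, seq_len=80, emb_channels=128, dropout=0.1, node_n=3,
    )
    block = conf.make_model()
    x = torch.randn(4, 16, 3, 80)
    emb = torch.randn(4, 128)
    return block(x, emb=emb).shape  # -> torch.Size([4, 16, 3, 80])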
|
||||
|
||||
|
||||
class graph_downsample(nn.Module):
|
||||
"""
|
||||
A downsampling layer
|
||||
"""
|
||||
def __init__(self, kernel_size = 2):
|
||||
super().__init__()
|
||||
self.downsample = nn.AvgPool1d(kernel_size = kernel_size)
|
||||
|
||||
def forward(self, x):
|
||||
bs, features, node_n, seq_len = x.shape
|
||||
x = x.reshape(bs, features*node_n, seq_len)
|
||||
x = self.downsample(x)
|
||||
x = x.reshape(bs, features, node_n, -1)
|
||||
return x
|
||||
|
||||
|
||||
class graph_upsample(nn.Module):
|
||||
"""
|
||||
An upsampling layer
|
||||
"""
|
||||
def __init__(self, scale_factor=2):
|
||||
super().__init__()
|
||||
self.scale_factor = scale_factor
|
||||
|
||||
def forward(self, x):
|
||||
x = F.interpolate(x, (x.shape[2], x.shape[3]*self.scale_factor), mode="nearest")
|
||||
return x
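# Shape sketch for the graph resampling layers (illustrative only; `_demo_graph_resampling`
# is an assumed helper name): both operate on [batch, features, node_n, seq_len] and
# halve / double the temporal length.
def _demo_graph_resampling():
    x = torch.randn(4, 16, 3, 80)
    down = graph_downsample(kernel_size=2)
    up = graph_upsample(scale_factor=2)
    return down(x).shape, up(x).shape  # -> ([4, 16, 3, 40], [4, 16, 3, 160])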
|
141  model/nn.py  Normal file
@@ -0,0 +1,141 @@
"""
|
||||
Various utilities for neural networks.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
|
||||
class SiLU(nn.Module):
|
||||
# @th.jit.script
|
||||
def forward(self, x):
|
||||
return x * th.sigmoid(x)
|
||||
|
||||
|
||||
class GroupNorm32(nn.GroupNorm):
|
||||
def forward(self, x):
|
||||
y = super().forward(x.float()).type(x.dtype)
|
||||
return y
|
||||
|
||||
|
||||
def conv_nd(dims, *args, **kwargs):
|
||||
"""
|
||||
Create a 1D, 2D, or 3D convolution module.
|
||||
"""
|
||||
assert dims==1
|
||||
if dims == 1:
|
||||
return nn.Conv1d(*args, **kwargs)
|
||||
elif dims == 2:
|
||||
return nn.Conv2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return nn.Conv3d(*args, **kwargs)
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
|
||||
def linear(*args, **kwargs):
|
||||
"""
|
||||
Create a linear module.
|
||||
"""
|
||||
return nn.Linear(*args, **kwargs)
|
||||
|
||||
|
||||
def avg_pool_nd(dims, *args, **kwargs):
|
||||
"""
|
||||
Create a 1D, 2D, or 3D average pooling module.
|
||||
"""
|
||||
assert dims==1
|
||||
if dims == 1:
|
||||
return nn.AvgPool1d(*args, **kwargs)
|
||||
elif dims == 2:
|
||||
return nn.AvgPool2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return nn.AvgPool3d(*args, **kwargs)
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
|
||||
def update_ema(target_params, source_params, rate=0.99):
|
||||
"""
|
||||
Update target parameters to be closer to those of source parameters using
|
||||
an exponential moving average.
|
||||
|
||||
:param target_params: the target parameter sequence.
|
||||
:param source_params: the source parameter sequence.
|
||||
:param rate: the EMA rate (closer to 1 means slower).
|
||||
"""
|
||||
for targ, src in zip(target_params, source_params):
|
||||
targ.detach().mul_(rate).add_(src, alpha=1 - rate)
|
||||
|
||||
|
||||
def zero_module(module):
|
||||
"""
|
||||
Zero out the parameters of a module and return it.
|
||||
"""
|
||||
for p in module.parameters():
|
||||
p.detach().zero_()
|
||||
return module
|
||||
|
||||
|
||||
def scale_module(module, scale):
|
||||
"""
|
||||
Scale the parameters of a module and return it.
|
||||
"""
|
||||
for p in module.parameters():
|
||||
p.detach().mul_(scale)
|
||||
return module
|
||||
|
||||
|
||||
def mean_flat(tensor):
|
||||
"""
|
||||
Take the mean over all non-batch dimensions.
|
||||
"""
|
||||
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
||||
|
||||
|
||||
def normalization(channels):
|
||||
"""
|
||||
Make a standard normalization layer.
|
||||
|
||||
:param channels: number of input channels.
|
||||
:return: an nn.Module for normalization.
|
||||
"""
|
||||
# return GroupNorm32(channels, channels)
|
||||
return GroupNorm32(min(4, channels), channels)
|
||||
|
||||
|
||||
def timestep_embedding(timesteps, dim, max_period=10000):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
|
||||
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param dim: the dimension of the output.
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
:return: an [N x dim] Tensor of positional embeddings.
|
||||
"""
|
||||
half = dim // 2
|
||||
freqs = th.exp(-math.log(max_period) *
|
||||
th.arange(start=0, end=half, dtype=th.float32) /
|
||||
half).to(device=timesteps.device)
|
||||
args = timesteps[:, None].float() * freqs[None]
|
||||
embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = th.cat(
|
||||
[embedding, th.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
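# A quick sketch of the sinusoidal embedding above (illustrative only; `_demo_timestep_embedding`
# is an assumed helper name): each timestep t is mapped to [cos(t * f_0), ..., sin(t * f_0), ...]
# where the frequencies f_i decay geometrically from 1 down to 1/max_period.
def _demo_timestep_embedding():
    t = th.tensor([0., 10., 100.])
    emb = timestep_embedding(t, dim=64)
    return emb.shape  # -> torch.Size([3, 64])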
|
||||
|
||||
|
||||
def torch_checkpoint(func, args, flag, preserve_rng_state=False):
|
||||
# torch's gradient checkpoint works with automatic mixed precision, given torch >= 1.8
|
||||
if flag:
|
||||
return torch.utils.checkpoint.checkpoint(
|
||||
func, *args, preserve_rng_state=preserve_rng_state)
|
||||
else:
|
||||
return func(*args)
|
954  model/unet.py  Normal file
@@ -0,0 +1,954 @@
import math
|
||||
from dataclasses import dataclass
|
||||
from numbers import Number
|
||||
from typing import NamedTuple, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from choices import *
|
||||
from config_base import BaseConfig
|
||||
from .blocks import *
|
||||
from .graph_convolution_network import *
|
||||
from .nn import (conv_nd, linear, normalization, timestep_embedding, zero_module)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BeatGANsUNetConfig(BaseConfig):
|
||||
seq_len: int = 80
|
||||
in_channels: int = 9
|
||||
# base channels, will be multiplied
|
||||
model_channels: int = 64
|
||||
# output of the unet
|
||||
out_channels: int = 9
|
||||
# how many repeating resblocks per resolution
|
||||
# the decoding side would have "one more" resblock
|
||||
# default: 2
|
||||
num_res_blocks: int = 2
|
||||
# number of time embed channels and style channels
|
||||
embed_channels: int = 256
|
||||
# at what resolutions you want to do self-attention of the feature maps
|
||||
# attentions generally improve performance
|
||||
attention_resolutions: Tuple[int] = (0, )
|
||||
# dropout applies to the resblocks (on feature maps)
|
||||
dropout: float = 0.1
|
||||
channel_mult: Tuple[int] = (1, 2, 4)
|
||||
conv_resample: bool = True
|
||||
# 1 = 1d conv
|
||||
dims: int = 1
|
||||
# number of attention heads
|
||||
num_heads: int = 1
|
||||
# or specify the number of channels per attention head
|
||||
num_head_channels: int = -1
|
||||
# use resblock for upscale/downscale blocks (expensive)
|
||||
# default: True (BeatGANs)
|
||||
resblock_updown: bool = True
|
||||
use_new_attention_order: bool = False
|
||||
resnet_two_cond: bool = True
|
||||
resnet_cond_channels: int = None
|
||||
# init the decoding conv layers with zero weights, this speeds up training
|
||||
# default: True (BeatGANs)
|
||||
resnet_use_zero_module: bool = True
|
||||
|
||||
def make_model(self):
|
||||
return BeatGANsUNetModel(self)
|
||||
|
||||
|
||||
class BeatGANsUNetModel(nn.Module):
|
||||
def __init__(self, conf: BeatGANsUNetConfig):
|
||||
super().__init__()
|
||||
self.conf = conf
|
||||
|
||||
self.dtype = th.float32
|
||||
|
||||
self.time_emb_channels = conf.model_channels
|
||||
self.time_embed = nn.Sequential(
|
||||
linear(self.time_emb_channels, conf.embed_channels),
|
||||
nn.SiLU(),
|
||||
linear(conf.embed_channels, conf.embed_channels),
|
||||
)
|
||||
|
||||
ch = input_ch = int(conf.channel_mult[0] * conf.model_channels)
|
||||
self.input_blocks = nn.ModuleList([
|
||||
TimestepEmbedSequential(
|
||||
conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1)),
|
||||
])
|
||||
|
||||
kwargs = dict(
|
||||
use_condition=True,
|
||||
two_cond=conf.resnet_two_cond,
|
||||
use_zero_module=conf.resnet_use_zero_module,
|
||||
# style channels for the resnet block
|
||||
cond_emb_channels=conf.resnet_cond_channels,
|
||||
)
|
||||
|
||||
self._feature_size = ch
|
||||
|
||||
# input_block_chans = [ch]
|
||||
input_block_chans = [[] for _ in range(len(conf.channel_mult))]
|
||||
input_block_chans[0].append(ch)
|
||||
|
||||
# number of blocks at each resolution
|
||||
self.input_num_blocks = [0 for _ in range(len(conf.channel_mult))]
|
||||
self.input_num_blocks[0] = 1
|
||||
self.output_num_blocks = [0 for _ in range(len(conf.channel_mult))]
|
||||
|
||||
ds = 1
|
||||
resolution = conf.seq_len
|
||||
for level, mult in enumerate(conf.channel_mult):
|
||||
for _ in range(conf.num_res_blocks):
|
||||
layers = [
|
||||
ResBlockConfig(
|
||||
ch,
|
||||
conf.embed_channels,
|
||||
conf.dropout,
|
||||
out_channels=int(mult * conf.model_channels),
|
||||
dims=conf.dims,
|
||||
**kwargs,
|
||||
).make_model()
|
||||
]
|
||||
ch = int(mult * conf.model_channels)
|
||||
if resolution in conf.attention_resolutions:
|
||||
layers.append(
|
||||
AttentionBlock(
|
||||
ch,
|
||||
num_heads=conf.num_heads,
|
||||
num_head_channels=conf.num_head_channels,
|
||||
use_new_attention_order=conf.
|
||||
use_new_attention_order,
|
||||
))
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self._feature_size += ch
|
||||
# input_block_chans.append(ch)
|
||||
input_block_chans[level].append(ch)
|
||||
self.input_num_blocks[level] += 1
|
||||
# print(input_block_chans)
|
||||
if level != len(conf.channel_mult) - 1:
|
||||
resolution //= 2
|
||||
out_ch = ch
|
||||
self.input_blocks.append(
|
||||
TimestepEmbedSequential(
|
||||
ResBlockConfig(
|
||||
ch,
|
||||
conf.embed_channels,
|
||||
conf.dropout,
|
||||
out_channels=out_ch,
|
||||
dims=conf.dims,
|
||||
down=True,
|
||||
**kwargs,
|
||||
).make_model() if conf.
|
||||
resblock_updown else Downsample(ch,
|
||||
conf.conv_resample,
|
||||
dims=conf.dims,
|
||||
out_channels=out_ch)))
|
||||
ch = out_ch
|
||||
# input_block_chans.append(ch)
|
||||
input_block_chans[level + 1].append(ch)
|
||||
self.input_num_blocks[level + 1] += 1
|
||||
ds *= 2
|
||||
self._feature_size += ch
|
||||
|
||||
self.middle_block = TimestepEmbedSequential(
|
||||
ResBlockConfig(
|
||||
ch,
|
||||
conf.embed_channels,
|
||||
conf.dropout,
|
||||
dims=conf.dims,
|
||||
**kwargs,
|
||||
).make_model(),
|
||||
#AttentionBlock(
|
||||
# ch,
|
||||
# num_heads=conf.num_heads,
|
||||
# num_head_channels=conf.num_head_channels,
|
||||
# use_new_attention_order=conf.use_new_attention_order,
|
||||
#),
|
||||
ResBlockConfig(
|
||||
ch,
|
||||
conf.embed_channels,
|
||||
conf.dropout,
|
||||
dims=conf.dims,
|
||||
**kwargs,
|
||||
).make_model(),
|
||||
)
|
||||
self._feature_size += ch
|
||||
|
||||
self.output_blocks = nn.ModuleList([])
|
||||
for level, mult in list(enumerate(conf.channel_mult))[::-1]:
|
||||
for i in range(conf.num_res_blocks + 1):
|
||||
# print(input_block_chans)
|
||||
# ich = input_block_chans.pop()
|
||||
try:
|
||||
ich = input_block_chans[level].pop()
|
||||
                except IndexError:
                    # this happens only when num_res_block > num_enc_res_block;
                    # we will not have enough lateral (skip) connections for all decoder blocks
                    ich = 0
|
||||
# print('pop:', ich)
|
||||
layers = [
|
||||
ResBlockConfig(
|
||||
# only direct channels when gated
|
||||
channels=ch + ich,
|
||||
emb_channels=conf.embed_channels,
|
||||
dropout=conf.dropout,
|
||||
out_channels=int(conf.model_channels * mult),
|
||||
dims=conf.dims,
|
||||
# lateral channels are described here when gated
|
||||
has_lateral=True if ich > 0 else False,
|
||||
lateral_channels=None,
|
||||
**kwargs,
|
||||
).make_model()
|
||||
]
|
||||
ch = int(conf.model_channels * mult)
|
||||
if resolution in conf.attention_resolutions:
|
||||
layers.append(
|
||||
AttentionBlock(
|
||||
ch,
|
||||
num_heads=conf.num_heads,
|
||||
num_head_channels=conf.num_head_channels,
|
||||
use_new_attention_order=conf.
|
||||
use_new_attention_order,
|
||||
))
|
||||
if level and i == conf.num_res_blocks:
|
||||
resolution *= 2
|
||||
out_ch = ch
|
||||
layers.append(
|
||||
ResBlockConfig(
|
||||
ch,
|
||||
conf.embed_channels,
|
||||
conf.dropout,
|
||||
out_channels=out_ch,
|
||||
dims=conf.dims,
|
||||
up=True,
|
||||
**kwargs,
|
||||
).make_model() if (
|
||||
conf.resblock_updown
|
||||
) else Upsample(ch,
|
||||
conf.conv_resample,
|
||||
dims=conf.dims,
|
||||
out_channels=out_ch))
|
||||
ds //= 2
|
||||
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self.output_num_blocks[level] += 1
|
||||
self._feature_size += ch
|
||||
|
||||
# print(input_block_chans)
|
||||
# print('inputs:', self.input_num_blocks)
|
||||
# print('outputs:', self.output_num_blocks)
|
||||
|
||||
if conf.resnet_use_zero_module:
|
||||
self.out = nn.Sequential(
|
||||
normalization(ch),
|
||||
nn.SiLU(),
|
||||
zero_module(
|
||||
conv_nd(conf.dims,
|
||||
input_ch,
|
||||
conf.out_channels,
|
||||
3, ## kernel size
|
||||
padding=1)),
|
||||
)
|
||||
else:
|
||||
self.out = nn.Sequential(
|
||||
normalization(ch),
|
||||
nn.SiLU(),
|
||||
conv_nd(conf.dims, input_ch, conf.out_channels, 3, padding=1),
|
||||
)
|
||||
|
||||
def forward(self, x, t, **kwargs):
|
||||
"""
|
||||
Apply the model to an input batch.
|
||||
|
||||
:param x: an [N x C x ...] Tensor of inputs.
|
||||
:param timesteps: a 1-D batch of timesteps.
|
||||
:return: an [N x C x ...] Tensor of outputs.
|
||||
"""
|
||||
|
||||
# hs = []
|
||||
hs = [[] for _ in range(len(self.conf.channel_mult))]
|
||||
emb = self.time_embed(timestep_embedding(t, self.time_emb_channels))
|
||||
|
||||
# new code supports input_num_blocks != output_num_blocks
|
||||
h = x.type(self.dtype)
|
||||
k = 0
|
||||
for i in range(len(self.input_num_blocks)):
|
||||
for j in range(self.input_num_blocks[i]):
|
||||
h = self.input_blocks[k](h, emb=emb)
|
||||
# print(i, j, h.shape)
|
||||
hs[i].append(h) ## Get output from each layer
|
||||
k += 1
|
||||
assert k == len(self.input_blocks)
|
||||
|
||||
# middle blocks
|
||||
h = self.middle_block(h, emb=emb)
|
||||
|
||||
# output blocks
|
||||
k = 0
|
||||
for i in range(len(self.output_num_blocks)):
|
||||
for j in range(self.output_num_blocks[i]):
|
||||
                # take the lateral connection from the same layer (in reverse order)
                # until there is no more, use None
|
||||
try:
|
||||
lateral = hs[-i - 1].pop()
|
||||
# print(i, j, lateral.shape)
|
||||
except IndexError:
|
||||
lateral = None
|
||||
# print(i, j, lateral)
|
||||
h = self.output_blocks[k](h, emb=emb, lateral=lateral)
|
||||
k += 1
|
||||
|
||||
h = h.type(x.dtype)
|
||||
pred = self.out(h)
|
||||
return Return(pred=pred)
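# End-to-end shape sketch for the UNet (assumed default config; `_demo_beatgans_unet` is an
# illustrative helper, not part of the training pipeline). The model takes a noisy sequence
# and a timestep batch and returns a Return namedtuple with a prediction of the same shape.
def _demo_beatgans_unet():
    conf = BeatGANsUNetConfig()          # seq_len=80, in_channels=9, out_channels=9
    model = conf.make_model()
    x = th.randn(2, 9, 80)               # [batch, channels, seq_len]
    t = th.randint(0, 1000, (2,))
    return model(x, t).pred.shape        # -> torch.Size([2, 9, 80])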
|
||||
|
||||
|
||||
class Return(NamedTuple):
|
||||
pred: th.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class BeatGANsEncoderConfig(BaseConfig):
|
||||
in_channels: int
|
||||
seq_len: int = 80
|
||||
num_res_blocks: int = 2
|
||||
attention_resolutions: Tuple[int] = (0, )
|
||||
model_channels: int = 32
|
||||
out_channels: int = 256
|
||||
dropout: float = 0.1
|
||||
channel_mult: Tuple[int] = (1, 2, 4)
|
||||
use_time_condition: bool = False
|
||||
conv_resample: bool = True
|
||||
dims: int = 1
|
||||
num_heads: int = 1
|
||||
num_head_channels: int = -1
|
||||
resblock_updown: bool = True
|
||||
use_new_attention_order: bool = False
|
||||
pool: str = 'adaptivenonzero'
|
||||
|
||||
def make_model(self):
|
||||
return BeatGANsEncoderModel(self)
|
||||
|
||||
|
||||
class BeatGANsEncoderModel(nn.Module):
|
||||
"""
|
||||
The half UNet model with attention and timestep embedding.
|
||||
|
||||
For usage, see UNet.
|
||||
"""
|
||||
def __init__(self, conf: BeatGANsEncoderConfig):
|
||||
super().__init__()
|
||||
self.conf = conf
|
||||
self.dtype = th.float32
|
||||
|
||||
if conf.use_time_condition:
|
||||
time_embed_dim = conf.model_channels
|
||||
self.time_embed = nn.Sequential(
|
||||
linear(conf.model_channels, time_embed_dim),
|
||||
nn.SiLU(),
|
||||
linear(time_embed_dim, time_embed_dim),
|
||||
)
|
||||
else:
|
||||
time_embed_dim = None
|
||||
|
||||
ch = int(conf.channel_mult[0] * conf.model_channels)
|
||||
self.input_blocks = nn.ModuleList([
|
||||
TimestepEmbedSequential(
|
||||
conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1),
|
||||
)
|
||||
])
|
||||
self._feature_size = ch
|
||||
input_block_chans = [ch]
|
||||
ds = 1
|
||||
resolution = conf.seq_len
|
||||
for level, mult in enumerate(conf.channel_mult):
|
||||
for _ in range(conf.num_res_blocks):
|
||||
layers = [
|
||||
ResBlockConfig(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
conf.dropout,
|
||||
out_channels=int(mult * conf.model_channels),
|
||||
dims=conf.dims,
|
||||
use_condition=conf.use_time_condition,
|
||||
).make_model()
|
||||
]
|
||||
ch = int(mult * conf.model_channels)
|
||||
if resolution in conf.attention_resolutions:
|
||||
layers.append(
|
||||
AttentionBlock(
|
||||
ch,
|
||||
num_heads=conf.num_heads,
|
||||
num_head_channels=conf.num_head_channels,
|
||||
use_new_attention_order=conf.
|
||||
use_new_attention_order,
|
||||
))
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self._feature_size += ch
|
||||
input_block_chans.append(ch)
|
||||
if level != len(conf.channel_mult) - 1:
|
||||
resolution //= 2
|
||||
out_ch = ch
|
||||
self.input_blocks.append(
|
||||
TimestepEmbedSequential(
|
||||
ResBlockConfig(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
conf.dropout,
|
||||
out_channels=out_ch,
|
||||
dims=conf.dims,
|
||||
use_condition=conf.use_time_condition,
|
||||
down=True,
|
||||
).make_model() if (
|
||||
conf.resblock_updown
|
||||
) else Downsample(ch,
|
||||
conf.conv_resample,
|
||||
dims=conf.dims,
|
||||
out_channels=out_ch)))
|
||||
ch = out_ch
|
||||
input_block_chans.append(ch)
|
||||
ds *= 2
|
||||
self._feature_size += ch
|
||||
|
||||
self.middle_block = TimestepEmbedSequential(
|
||||
ResBlockConfig(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
conf.dropout,
|
||||
dims=conf.dims,
|
||||
use_condition=conf.use_time_condition,
|
||||
).make_model(),
|
||||
AttentionBlock(
|
||||
ch,
|
||||
num_heads=conf.num_heads,
|
||||
num_head_channels=conf.num_head_channels,
|
||||
use_new_attention_order=conf.use_new_attention_order,
|
||||
),
|
||||
ResBlockConfig(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
conf.dropout,
|
||||
dims=conf.dims,
|
||||
use_condition=conf.use_time_condition,
|
||||
).make_model(),
|
||||
)
|
||||
self._feature_size += ch
|
||||
if conf.pool == "adaptivenonzero":
|
||||
self.out = nn.Sequential(
|
||||
normalization(ch),
|
||||
nn.SiLU(),
|
||||
## nn.AdaptiveAvgPool2d((1, 1)),
|
||||
nn.AdaptiveAvgPool1d(1),
|
||||
conv_nd(conf.dims, ch, conf.out_channels, 1),
|
||||
nn.Flatten(),
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unexpected {conf.pool} pooling")
|
||||
|
||||
def forward(self, x, t=None):
|
||||
"""
|
||||
Apply the model to an input batch.
|
||||
|
||||
:param x: an [N x C x ...] Tensor of inputs.
|
||||
:param timesteps: a 1-D batch of timesteps.
|
||||
:return: an [N x K] Tensor of outputs.
|
||||
"""
|
||||
        if self.conf.use_time_condition:
            emb = self.time_embed(timestep_embedding(t, self.conf.model_channels))
|
||||
else: ## autoencoding.py
|
||||
emb = None
|
||||
|
||||
results = []
|
||||
h = x.type(self.dtype)
|
||||
for module in self.input_blocks: ## flow input x over all the input blocks
|
||||
h = module(h, emb=emb)
|
||||
if self.conf.pool.startswith("spatial"):
|
||||
results.append(h.type(x.dtype).mean(dim=(2, 3)))
|
||||
h = self.middle_block(h, emb=emb) ## TimestepEmbedSequential(...)
|
||||
if self.conf.pool.startswith("spatial"):
|
||||
results.append(h.type(x.dtype).mean(dim=(2, 3)))
|
||||
h = th.cat(results, axis=-1)
|
||||
else: ## autoencoder.py
|
||||
h = h.type(x.dtype)
|
||||
|
||||
h = h.float()
|
||||
h = self.out(h)
|
||||
|
||||
return h
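# Shape sketch for the semantic encoder (assumed config; `_demo_beatgans_encoder` is an
# illustrative helper name): with pool='adaptivenonzero' the half-UNet pools over time and
# returns a single feature vector per sequence.
def _demo_beatgans_encoder():
    conf = BeatGANsEncoderConfig(in_channels=9, out_channels=256)
    enc = conf.make_model()
    x = th.randn(2, 9, 80)
    return enc(x).shape  # -> torch.Size([2, 256])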
|
||||
|
||||
|
||||
@dataclass
|
||||
class GCNUNetConfig(BaseConfig):
|
||||
in_channels: int = 9
|
||||
node_n: int = 3
|
||||
seq_len: int = 80
|
||||
# base channels, will be multiplied
|
||||
model_channels: int = 32
|
||||
# output of the unet
|
||||
out_channels: int = 9
|
||||
# how many repeating resblocks per resolution
|
||||
num_res_blocks: int = 8
|
||||
# number of time embed channels and style channels
|
||||
embed_channels: int = 256
|
||||
# dropout applies to the resblocks
|
||||
dropout: float = 0.1
|
||||
channel_mult: Tuple[int] = (1, 2, 4)
|
||||
resnet_two_cond: bool = True
|
||||
|
||||
def make_model(self):
|
||||
return GCNUNetModel(self)
|
||||
|
||||
|
||||
class GCNUNetModel(nn.Module):
|
||||
def __init__(self, conf: GCNUNetConfig):
|
||||
super().__init__()
|
||||
self.conf = conf
|
||||
self.dtype = th.float32
|
||||
assert conf.in_channels%conf.node_n == 0
|
||||
self.in_features = conf.in_channels//conf.node_n
|
||||
|
||||
self.time_emb_channels = conf.model_channels*4
|
||||
self.time_embed = nn.Sequential(
|
||||
linear(self.time_emb_channels, conf.embed_channels),
|
||||
nn.SiLU(),
|
||||
linear(conf.embed_channels, conf.embed_channels),
|
||||
)
|
||||
|
||||
ch = int(conf.channel_mult[0] * conf.model_channels)
|
||||
self.input_blocks = nn.ModuleList([
|
||||
TimestepEmbedSequential(
|
||||
graph_convolution(in_features=self.in_features, out_features=ch, node_n=conf.node_n, seq_len=conf.seq_len)),
|
||||
])
|
||||
|
||||
kwargs = dict(
|
||||
use_condition=True,
|
||||
two_cond=conf.resnet_two_cond,
|
||||
)
|
||||
|
||||
input_block_chans = [[] for _ in range(len(conf.channel_mult))]
|
||||
input_block_chans[0].append(ch)
|
||||
|
||||
# number of blocks at each resolution
|
||||
self.input_num_blocks = [0 for _ in range(len(conf.channel_mult))]
|
||||
self.input_num_blocks[0] = 1
|
||||
self.output_num_blocks = [0 for _ in range(len(conf.channel_mult))]
|
||||
|
||||
ds = 1
|
||||
resolution = conf.seq_len
|
||||
for level, mult in enumerate(conf.channel_mult):
|
||||
for _ in range(conf.num_res_blocks):
|
||||
layers = [
|
||||
residual_graph_convolution_config(
|
||||
in_features=ch,
|
||||
seq_len=resolution,
|
||||
emb_channels = conf.embed_channels,
|
||||
dropout=conf.dropout,
|
||||
out_features=int(mult * conf.model_channels),
|
||||
node_n=conf.node_n,
|
||||
**kwargs,
|
||||
).make_model()
|
||||
]
|
||||
ch = int(mult * conf.model_channels)
|
||||
self.input_blocks.append(*layers)
|
||||
input_block_chans[level].append(ch)
|
||||
self.input_num_blocks[level] += 1
|
||||
if level != len(conf.channel_mult) - 1:
|
||||
resolution //= 2
|
||||
out_ch = ch
|
||||
self.input_blocks.append(
|
||||
TimestepEmbedSequential(
|
||||
graph_downsample()))
|
||||
ch = out_ch
|
||||
input_block_chans[level + 1].append(ch)
|
||||
self.input_num_blocks[level + 1] += 1
|
||||
ds *= 2
|
||||
|
||||
self.output_blocks = nn.ModuleList([])
|
||||
for level, mult in list(enumerate(conf.channel_mult))[::-1]:
|
||||
for i in range(conf.num_res_blocks + 1):
|
||||
try:
|
||||
ich = input_block_chans[level].pop()
|
||||
                except IndexError:
                    # this happens only when num_res_block > num_enc_res_block;
                    # we will not have enough lateral (skip) connections for all decoder blocks
                    ich = 0
|
||||
layers = [
|
||||
residual_graph_convolution_config(
|
||||
in_features=ch + ich,
|
||||
seq_len=resolution,
|
||||
emb_channels = conf.embed_channels,
|
||||
dropout=conf.dropout,
|
||||
out_features=int(mult * conf.model_channels),
|
||||
node_n=conf.node_n,
|
||||
has_lateral=True if ich > 0 else False,
|
||||
**kwargs,
|
||||
).make_model()
|
||||
]
|
||||
ch = int(mult*conf.model_channels)
|
||||
if level and i == conf.num_res_blocks:
|
||||
resolution *= 2
|
||||
out_ch = ch
|
||||
layers.append(graph_upsample())
|
||||
ds //= 2
|
||||
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self.output_num_blocks[level] += 1
|
||||
|
||||
self.out = nn.Sequential(
|
||||
graph_convolution(in_features=ch, out_features=self.in_features, node_n=conf.node_n, seq_len=conf.seq_len),
|
||||
nn.Tanh(),
|
||||
)
|
||||
|
||||
def forward(self, x, t, **kwargs):
|
||||
"""
|
||||
Apply the model to an input batch.
|
||||
|
||||
:param x: an [N x C x ...] Tensor of inputs.
|
||||
:param timesteps: a 1-D batch of timesteps.
|
||||
:return: an [N x C x ...] Tensor of outputs.
|
||||
"""
|
||||
bs, channels, seq_len = x.shape
|
||||
x = x.reshape(bs, self.conf.node_n, self.in_features, seq_len).permute(0, 2, 1, 3)
|
||||
|
||||
hs = [[] for _ in range(len(self.conf.channel_mult))]
|
||||
emb = self.time_embed(timestep_embedding(t, self.time_emb_channels))
|
||||
|
||||
# new code supports input_num_blocks != output_num_blocks
|
||||
h = x.type(self.dtype)
|
||||
k = 0
|
||||
for i in range(len(self.input_num_blocks)):
|
||||
for j in range(self.input_num_blocks[i]):
|
||||
h = self.input_blocks[k](h, emb=emb)
|
||||
# print(i, j, h.shape)
|
||||
hs[i].append(h) ## Get output from each layer
|
||||
k += 1
|
||||
assert k == len(self.input_blocks)
|
||||
|
||||
# output blocks
|
||||
k = 0
|
||||
for i in range(len(self.output_num_blocks)):
|
||||
for j in range(self.output_num_blocks[i]):
|
||||
                # take the lateral connection from the same layer (in reverse order)
                # until there is no more, use None
|
||||
try:
|
||||
lateral = hs[-i - 1].pop()
|
||||
# print(i, j, lateral.shape)
|
||||
except IndexError:
|
||||
lateral = None
|
||||
# print(i, j, lateral)
|
||||
h = self.output_blocks[k](h, emb=emb, lateral=lateral)
|
||||
k += 1
|
||||
|
||||
h = h.type(x.dtype)
|
||||
pred = self.out(h)
|
||||
pred = pred.permute(0, 2, 1, 3).reshape(bs, -1, seq_len)
|
||||
|
||||
return Return(pred=pred)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GCNEncoderConfig(BaseConfig):
|
||||
in_channels: int
|
||||
    in_features: int = 3  # features for one node
|
||||
seq_len: int = 40
|
||||
seq_len_future: int = 3
|
||||
num_res_blocks: int = 2
|
||||
model_channels: int = 32
|
||||
out_channels: int = 32
|
||||
dropout: float = 0.1
|
||||
channel_mult: Tuple[int] = (1, 2, 4)
|
||||
use_time_condition: bool = False
|
||||
|
||||
def make_model(self):
|
||||
return GCNEncoderModel(self)
|
||||
|
||||
|
||||
class GCNEncoderModel(nn.Module):
|
||||
def __init__(self, conf: GCNEncoderConfig):
|
||||
super().__init__()
|
||||
self.conf = conf
|
||||
self.dtype = th.float32
|
||||
assert conf.in_channels%conf.in_features == 0
|
||||
self.in_features = conf.in_features
|
||||
self.node_n = conf.in_channels//conf.in_features
|
||||
|
||||
ch = int(conf.channel_mult[0] * conf.model_channels)
|
||||
self.input_blocks = nn.ModuleList([
|
||||
graph_convolution(in_features=self.in_features, out_features=ch, node_n=self.node_n, seq_len=conf.seq_len),
|
||||
])
|
||||
input_block_chans = [ch]
|
||||
ds = 1
|
||||
resolution = conf.seq_len
|
||||
for level, mult in enumerate(conf.channel_mult):
|
||||
for _ in range(conf.num_res_blocks):
|
||||
layers = [
|
||||
residual_graph_convolution_config(
|
||||
in_features=ch,
|
||||
seq_len=resolution,
|
||||
emb_channels = None,
|
||||
dropout=conf.dropout,
|
||||
out_features=int(mult * conf.model_channels),
|
||||
node_n=self.node_n,
|
||||
use_condition=conf.use_time_condition,
|
||||
).make_model()
|
||||
]
|
||||
ch = int(mult * conf.model_channels)
|
||||
self.input_blocks.append(*layers)
|
||||
input_block_chans.append(ch)
|
||||
if level != len(conf.channel_mult) - 1:
|
||||
resolution //= 2
|
||||
out_ch = ch
|
||||
self.input_blocks.append(
|
||||
graph_downsample())
|
||||
ch = out_ch
|
||||
input_block_chans.append(ch)
|
||||
ds *= 2
|
||||
|
||||
self.hand_prediction = nn.Sequential(
|
||||
conv_nd(1, ch*2, ch*2, 3, padding=1),
|
||||
nn.LayerNorm([ch*2, conf.seq_len_future]),
|
||||
nn.Tanh(),
|
||||
conv_nd(1, ch*2, self.in_features*2, 1),
|
||||
nn.Tanh(),
|
||||
)
|
||||
|
||||
self.head_prediction = nn.Sequential(
|
||||
conv_nd(1, ch, ch, 3, padding=1),
|
||||
nn.LayerNorm([ch, conf.seq_len_future]),
|
||||
nn.Tanh(),
|
||||
conv_nd(1, ch, self.in_features, 1),
|
||||
nn.Tanh(),
|
||||
)
|
||||
|
||||
self.out = nn.Sequential(
|
||||
nn.AdaptiveAvgPool1d(1),
|
||||
conv_nd(1, ch*self.node_n, conf.out_channels, 1),
|
||||
nn.Flatten(),
|
||||
)
|
||||
|
||||
|
||||
def forward(self, x, t=None):
|
||||
bs, channels, seq_len = x.shape
|
||||
|
||||
if self.node_n == 3: # both hand and head
|
||||
hand_last = x[:, :6, -1:].expand(-1, -1, self.conf.seq_len_future).clone() #last observed hand position
|
||||
head_last = x[:, 6:, -1:].expand(-1, -1, self.conf.seq_len_future).clone()# last observed head orientation
|
||||
|
||||
if self.node_n == 2: # hand only
|
||||
hand_last = x[:, :, -1:].expand(-1, -1, self.conf.seq_len_future).clone() #last observed hand position
|
||||
|
||||
if self.node_n == 1: # head only
|
||||
head_last = x[:, :, -1:].expand(-1, -1, self.conf.seq_len_future).clone()# last observed head orientation
|
||||
|
||||
|
||||
x = x.reshape(bs, self.node_n, self.in_features, seq_len).permute(0, 2, 1, 3)
|
||||
|
||||
h = x.type(self.dtype)
|
||||
for module in self.input_blocks:
|
||||
h = module(h)
|
||||
|
||||
h = h.type(x.dtype)
|
||||
h = h.float()
|
||||
bs, features, node_n, seq_len = h.shape
|
||||
|
||||
if self.node_n == 3: # both hand and head
|
||||
hand_features = h[:, :, :2, -self.conf.seq_len_future:].reshape(bs, features*2, -1)
|
||||
head_features = h[:, :, 2:, -self.conf.seq_len_future:].reshape(bs, features, -1)
|
||||
|
||||
pred_hand = self.hand_prediction(hand_features) + hand_last
|
||||
pred_head = self.head_prediction(head_features) + head_last
|
||||
pred_head = F.normalize(pred_head, dim=1)# normalize head orientation to unit vectors
|
||||
|
||||
if self.node_n == 2: # hand only
|
||||
hand_features = h[:, :, :, -self.conf.seq_len_future:].reshape(bs, features*2, -1)
|
||||
pred_hand = self.hand_prediction(hand_features) + hand_last
|
||||
pred_head = None
|
||||
|
||||
if self.node_n == 1: # head only
|
||||
head_features = h[:, :, :, -self.conf.seq_len_future:].reshape(bs, features, -1)
|
||||
pred_head = self.head_prediction(head_features) + head_last
|
||||
pred_head = F.normalize(pred_head, dim=1)# normalize head orientation to unit vectors
|
||||
pred_hand = None
|
||||
|
||||
h = h.reshape(bs, features*node_n, seq_len)
|
||||
h = self.out(h)
|
||||
|
||||
return h, pred_hand, pred_head
|
||||
|
||||
|
||||
@dataclass
|
||||
class CNNEncoderConfig(BaseConfig):
|
||||
in_channels: int
|
||||
seq_len: int = 40
|
||||
seq_len_future: int = 3
|
||||
out_channels: int = 128
|
||||
|
||||
def make_model(self):
|
||||
return CNNEncoderModel(self)
|
||||
|
||||
|
||||
class CNNEncoderModel(nn.Module):
|
||||
def __init__(self, conf: CNNEncoderConfig):
|
||||
super().__init__()
|
||||
self.conf = conf
|
||||
self.dtype = th.float32
|
||||
input_dim = conf.in_channels
|
||||
length = conf.seq_len
|
||||
out_channels = conf.out_channels
|
||||
|
||||
self.encoder = nn.Sequential(
|
||||
nn.Conv1d(input_dim, 32, kernel_size=3, padding=1),
|
||||
nn.LayerNorm([32, length]),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv1d(32, 32, kernel_size=3, padding=1),
|
||||
nn.LayerNorm([32, length]),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv1d(32, 32, kernel_size=3, padding=1),
|
||||
nn.LayerNorm([32, length]),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
|
||||
self.out = nn.Linear(32 * length, out_channels)
|
||||
|
||||
def forward(self, x, t=None):
|
||||
bs, channels, seq_len = x.shape
|
||||
hand_last = x[:, :6, -1:].expand(-1, -1, self.conf.seq_len_future).clone() #last observed hand position
|
||||
head_last = x[:, 6:, -1:].expand(-1, -1, self.conf.seq_len_future).clone()# last observed head orientation
|
||||
|
||||
h = x.type(self.dtype)
|
||||
h = self.encoder(h)
|
||||
h = h.view(h.shape[0], -1)
|
||||
|
||||
h = h.type(x.dtype)
|
||||
h = h.float()
|
||||
|
||||
h = self.out(h)
|
||||
return h, hand_last, head_last
|
||||
|
||||
|
||||
@dataclass
class GRUEncoderConfig(BaseConfig):
    in_channels: int
    seq_len: int = 40
    seq_len_future: int = 3
    out_channels: int = 128

    def make_model(self):
        return GRUEncoderModel(self)


class GRUEncoderModel(nn.Module):
    def __init__(self, conf: GRUEncoderConfig):
        super().__init__()
        self.conf = conf
        self.dtype = th.float32
        input_dim = conf.in_channels
        length = conf.seq_len
        feature_channels = 32
        out_channels = conf.out_channels

        self.encoder = nn.GRU(input_dim, feature_channels, 1, batch_first=True)

        self.out = nn.Linear(feature_channels * length, out_channels)

    def forward(self, x, t=None):
        bs, channels, seq_len = x.shape
        hand_last = x[:, :6, -1:].expand(-1, -1, self.conf.seq_len_future).clone()  # last observed hand position
        head_last = x[:, 6:, -1:].expand(-1, -1, self.conf.seq_len_future).clone()  # last observed head orientation

        h = x.type(self.dtype)
        h, _ = self.encoder(h.permute(0, 2, 1))
        h = h.reshape(h.shape[0], -1)

        h = h.type(x.dtype)
        h = h.float()

        h = self.out(h)
        return h, hand_last, head_last

@dataclass
class LSTMEncoderConfig(BaseConfig):
    in_channels: int
    seq_len: int = 40
    seq_len_future: int = 3
    out_channels: int = 128

    def make_model(self):
        return LSTMEncoderModel(self)


class LSTMEncoderModel(nn.Module):
    def __init__(self, conf: LSTMEncoderConfig):
        super().__init__()
        self.conf = conf
        self.dtype = th.float32
        input_dim = conf.in_channels
        length = conf.seq_len
        feature_channels = 32
        out_channels = conf.out_channels

        self.encoder = nn.LSTM(input_dim, feature_channels, 1, batch_first=True)

        self.out = nn.Linear(feature_channels * length, out_channels)

    def forward(self, x, t=None):
        bs, channels, seq_len = x.shape
        hand_last = x[:, :6, -1:].expand(-1, -1, self.conf.seq_len_future).clone()  # last observed hand position
        head_last = x[:, 6:, -1:].expand(-1, -1, self.conf.seq_len_future).clone()  # last observed head orientation

        h = x.type(self.dtype)
        h, _ = self.encoder(h.permute(0, 2, 1))
        h = h.reshape(h.shape[0], -1)

        h = h.type(x.dtype)
        h = h.float()

        h = self.out(h)
        return h, hand_last, head_last

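# Editor's sketch (illustrative only, not part of the original commit): the GRU and
# LSTM baselines are used identically; note the permute to (batch, seq, features)
# required by batch_first=True before all hidden states are flattened into self.out.
#
#   import torch
#   enc = GRUEncoderConfig(in_channels=9, seq_len=40, seq_len_future=3, out_channels=128).make_model()
#   x = torch.randn(8, 9, 40)
#   z, hand_last, head_last = enc(x)           # z: (8, 128), as with the other encoders
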
@dataclass
class MLPEncoderConfig(BaseConfig):
    in_channels: int
    seq_len: int = 40
    seq_len_future: int = 3
    out_channels: int = 128

    def make_model(self):
        return MLPEncoderModel(self)


class MLPEncoderModel(nn.Module):
    def __init__(self, conf: MLPEncoderConfig):
        super().__init__()
        self.conf = conf
        self.dtype = th.float32
        input_dim = conf.in_channels
        length = conf.seq_len
        out_channels = conf.out_channels

        linear_size = 128
        self.encoder = nn.Sequential(
            nn.Linear(length * input_dim, linear_size),
            nn.LayerNorm([linear_size]),
            nn.ReLU(inplace=True),
            nn.Linear(linear_size, linear_size),
            nn.LayerNorm([linear_size]),
            nn.ReLU(inplace=True),
        )

        self.out = nn.Linear(linear_size, out_channels)

    def forward(self, x, t=None):
        bs, channels, seq_len = x.shape
        hand_last = x[:, :6, -1:].expand(-1, -1, self.conf.seq_len_future).clone()  # last observed hand position
        head_last = x[:, 6:, -1:].expand(-1, -1, self.conf.seq_len_future).clone()  # last observed head orientation

        h = x.type(self.dtype)

        h = h.view(h.shape[0], -1)
        h = self.encoder(h)

        h = h.type(x.dtype)
        h = h.float()

        h = self.out(h)
        return h, hand_last, head_last
418
model/unet_autoenc.py
Normal file

@@ -0,0 +1,418 @@
from enum import Enum

import os
import pdb

import torch
from torch import Tensor
from torch.nn.functional import silu

from .unet import *
from choices import *


@dataclass
class BeatGANsAutoencConfig(BeatGANsUNetConfig):
    seq_len_future: int = 3
    enc_out_channels: int = 128
    semantic_encoder_type: str = 'gcn'
    enc_channel_mult: Tuple[int] = None

    def make_model(self):
        return BeatGANsAutoencModel(self)

class BeatGANsAutoencModel(BeatGANsUNetModel):
    def __init__(self, conf: BeatGANsAutoencConfig):
        super().__init__(conf)
        self.conf = conf

        # having only time, cond
        self.time_embed = TimeStyleSeperateEmbed(
            time_channels=conf.model_channels,
            time_out_channels=conf.embed_channels,
        )

        if conf.semantic_encoder_type == 'gcn':
            self.encoder = GCNEncoderConfig(
                seq_len=conf.seq_len,
                seq_len_future=conf.seq_len_future,
                in_channels=conf.in_channels,
                model_channels=16,
                out_channels=conf.enc_out_channels,
                channel_mult=conf.enc_channel_mult or conf.channel_mult,
            ).make_model()
        elif conf.semantic_encoder_type == '1dcnn':
            self.encoder = CNNEncoderConfig(
                seq_len=conf.seq_len,
                seq_len_future=conf.seq_len_future,
                in_channels=conf.in_channels,
                out_channels=conf.enc_out_channels,
            ).make_model()
        elif conf.semantic_encoder_type == 'gru':
            # ensure deterministic behavior of RNNs
            os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:2"
            self.encoder = GRUEncoderConfig(
                seq_len=conf.seq_len,
                seq_len_future=conf.seq_len_future,
                in_channels=conf.in_channels,
                out_channels=conf.enc_out_channels,
            ).make_model()
        elif conf.semantic_encoder_type == 'lstm':
            # ensure deterministic behavior of RNNs
            os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:2"
            self.encoder = LSTMEncoderConfig(
                seq_len=conf.seq_len,
                seq_len_future=conf.seq_len_future,
                in_channels=conf.in_channels,
                out_channels=conf.enc_out_channels,
            ).make_model()
        elif conf.semantic_encoder_type == 'mlp':
            self.encoder = MLPEncoderConfig(
                seq_len=conf.seq_len,
                seq_len_future=conf.seq_len_future,
                in_channels=conf.in_channels,
                out_channels=conf.enc_out_channels,
            ).make_model()
        else:
            raise NotImplementedError()

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick: sample from N(mu, var) using N(0, 1) noise.
        :param mu: (Tensor) mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) log-variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        assert self.conf.is_stochastic
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

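    # Editor's note (illustrative, not part of the original commit): the usual
    # VAE-style reparameterization, e.g. for mu, logvar of shape (B, D):
    #   std = torch.exp(0.5 * logvar)          # logvar is the log-variance
    #   z = mu + std * torch.randn_like(std)
    # which matches eps * std + mu above and keeps sampling differentiable w.r.t. mu and logvar.
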
    def sample_z(self, n: int, device):
        assert self.conf.is_stochastic
        return torch.randn(n, self.conf.enc_out_channels, device=device)

    def noise_to_cond(self, noise: Tensor):
        raise NotImplementedError()
        assert self.conf.noise_net_conf is not None
        return self.noise_net.forward(noise)

    def encode(self, x):
        cond, pred_hand, pred_head = self.encoder.forward(x)
        return cond, pred_hand, pred_head

    @property
    def stylespace_sizes(self):
        modules = list(self.input_blocks.modules()) + list(
            self.middle_block.modules()) + list(self.output_blocks.modules())
        sizes = []
        for module in modules:
            if isinstance(module, ResBlock):
                linear = module.cond_emb_layers[-1]
                sizes.append(linear.weight.shape[0])
        return sizes

    def encode_stylespace(self, x, return_vector: bool = True):
        """
        encode to style space
        """
        modules = list(self.input_blocks.modules()) + list(
            self.middle_block.modules()) + list(self.output_blocks.modules())
        # (n, c)
        cond = self.encoder.forward(x)
        S = []
        for module in modules:
            if isinstance(module, ResBlock):
                # (n, c')
                s = module.cond_emb_layers.forward(cond)
                S.append(s)

        if return_vector:
            # (n, sum_c)
            return torch.cat(S, dim=1)
        else:
            return S

    def forward(self,
                x,
                t,
                x_start=None,
                cond=None,
                style=None,
                noise=None,
                t_cond=None,
                **kwargs):
        """
        Apply the model to an input batch.

        Args:
            x_start: the original input sequence to encode
            cond: output of the encoder
            noise: random noise (to predict the cond)
        """
        if t_cond is None:
            t_cond = t  # randomly sampled timestep with the size of [batch_size]

        if noise is not None:
            # if the noise is given, we predict the cond from noise
            cond = self.noise_to_cond(noise)

        cond_given = True
        if cond is None:
            cond_given = False
            if x is not None:
                assert len(x) == len(x_start), f'{len(x)} != {len(x_start)}'

            cond, pred_hand, pred_head = self.encode(x_start)

        if t is not None:  # t == t_cond
            _t_emb = timestep_embedding(t, self.conf.model_channels)
            _t_cond_emb = timestep_embedding(t_cond, self.conf.model_channels)
        else:
            # this happens when training only the autoenc
            _t_emb = None
            _t_cond_emb = None

        if self.conf.resnet_two_cond:
            res = self.time_embed.forward(  # self.time_embed is an MLP
                time_emb=_t_emb,
                cond=cond,
                time_cond_emb=_t_cond_emb,
            )
        else:
            raise NotImplementedError()

        if self.conf.resnet_two_cond:
            # two cond: first = time emb, second = cond_emb
            emb = res.time_emb
            cond_emb = res.emb
        else:
            # one cond = combined of both time and cond
            emb = res.emb
            cond_emb = None

        # override the style if given
        style = style or res.style  # style is None here; res.style is the cond, e.g. torch.Size([64, 512])

        # where in the model to supply time conditions
        enc_time_emb = emb  # time embeddings
        mid_time_emb = emb
        dec_time_emb = emb
        # where in the model to supply style conditions
        enc_cond_emb = cond_emb  # z_sem embeddings
        mid_cond_emb = cond_emb
        dec_cond_emb = cond_emb

        hs = [[] for _ in range(len(self.conf.channel_mult))]

        if x is not None:
            h = x.type(self.dtype)
            # input blocks
            k = 0
            for i in range(len(self.input_num_blocks)):
                for j in range(self.input_num_blocks[i]):
                    h = self.input_blocks[k](h,
                                             emb=enc_time_emb,
                                             cond=enc_cond_emb)
                    hs[i].append(h)
                    k += 1
            assert k == len(self.input_blocks)

            # middle blocks
            h = self.middle_block(h, emb=mid_time_emb, cond=mid_cond_emb)
        else:
            # no lateral connections
            # happens when training only the autoencoder
            h = None
            hs = [[] for _ in range(len(self.conf.channel_mult))]
            pdb.set_trace()

        # output blocks
        k = 0
        for i in range(len(self.output_num_blocks)):
            for j in range(self.output_num_blocks[i]):
                # take the lateral connection from the same layer (in reverse order, symmetric);
                # once there are no more, use None
                try:
                    lateral = hs[-i - 1].pop()
                except IndexError:
                    lateral = None
                h = self.output_blocks[k](h,
                                          emb=dec_time_emb,
                                          cond=dec_cond_emb,
                                          lateral=lateral)
                k += 1

        pred = self.out(h)

        if cond_given:
            return AutoencReturn(pred=pred, cond=cond)
        else:
            return AutoencReturn(pred=pred, cond=cond, pred_hand=pred_hand, pred_head=pred_head)

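# Editor's sketch (illustrative only, not part of the original commit): a typical
# training-style call, assuming x_t is the noised sequence, t the diffusion step,
# and x_start the clean observation window.
#
#   ret = model(x=x_t, t=t, x_start=x_start)
#   ret.pred                       # denoising output from self.out(h)
#   ret.cond                       # z_sem produced by the semantic encoder
#   ret.pred_hand, ret.pred_head   # auxiliary forecasts (only set when cond was not passed in)
#
# When cond is supplied explicitly, the semantic encoder is skipped and only
# (pred, cond) are populated in AutoencReturn.
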
@dataclass
class GCNAutoencConfig(GCNUNetConfig):
    # number of style channels
    enc_out_channels: int = 256
    enc_channel_mult: Tuple[int] = None

    def make_model(self):
        return GCNAutoencModel(self)


class GCNAutoencModel(GCNUNetModel):
    def __init__(self, conf: GCNAutoencConfig):
        super().__init__(conf)
        self.conf = conf

        # having only time, cond
        self.time_emb_channels = conf.model_channels
        self.time_embed = TimeStyleSeperateEmbed(
            time_channels=self.time_emb_channels,
            time_out_channels=conf.embed_channels,
        )

        self.encoder = GCNEncoderConfig(
            seq_len=conf.seq_len,
            in_channels=conf.in_channels,
            model_channels=32,
            out_channels=conf.enc_out_channels,
            channel_mult=conf.enc_channel_mult or conf.channel_mult,
        ).make_model()

    def encode(self, x):
        cond = self.encoder.forward(x)
        return {'cond': cond}

    def forward(self,
                x,
                t,
                x_start=None,
                cond=None,
                **kwargs):
        """
        Apply the model to an input batch.

        Args:
            x_start: the original input to encode
            cond: output of the encoder
        """
        if cond is None:
            if x is not None:
                assert len(x) == len(x_start), f'{len(x)} != {len(x_start)}'

            tmp = self.encode(x_start)
            cond = tmp['cond']

        if t is not None:
            _t_emb = timestep_embedding(t, self.time_emb_channels)
        else:
            # this happens when training only the autoenc
            _t_emb = None

        if self.conf.resnet_two_cond:
            res = self.time_embed.forward(  # self.time_embed is an MLP
                time_emb=_t_emb,
                cond=cond,
            )
            # two cond: first = time emb, second = cond_emb
            emb = res.time_emb
            cond_emb = res.emb
        else:
            raise NotImplementedError()

        # where in the model to supply time conditions
        enc_time_emb = emb  # time embeddings
        mid_time_emb = emb
        dec_time_emb = emb
        enc_cond_emb = cond_emb  # z_sem embeddings
        mid_cond_emb = cond_emb
        dec_cond_emb = cond_emb

        bs, channels, seq_len = x.shape
        x = x.reshape(bs, self.conf.node_n, self.in_features, seq_len).permute(0, 2, 1, 3)
        hs = [[] for _ in range(len(self.conf.channel_mult))]
        h = x.type(self.dtype)

        # input blocks
        k = 0
        for i in range(len(self.input_num_blocks)):
            for j in range(self.input_num_blocks[i]):
                h = self.input_blocks[k](h,
                                         emb=enc_time_emb,
                                         cond=enc_cond_emb)
                hs[i].append(h)
                k += 1
        assert k == len(self.input_blocks)

        # output blocks
        k = 0
        for i in range(len(self.output_num_blocks)):
            for j in range(self.output_num_blocks[i]):
                # take the lateral connection from the same layer (in reverse order, symmetric);
                # once there are no more, use None
                try:
                    lateral = hs[-i - 1].pop()
                except IndexError:
                    lateral = None
                h = self.output_blocks[k](h,
                                          emb=dec_time_emb,
                                          cond=dec_cond_emb,
                                          lateral=lateral)
                k += 1

        pred = self.out(h)
        pred = pred.permute(0, 2, 1, 3).reshape(bs, -1, seq_len)

        return AutoencReturn(pred=pred, cond=cond)

class AutoencReturn(NamedTuple):
    pred: Tensor
    cond: Tensor = None
    pred_hand: Tensor = None
    pred_head: Tensor = None


class EmbedReturn(NamedTuple):
    # style and time
    emb: Tensor = None
    # time only
    time_emb: Tensor = None
    # style only (but could depend on time)
    style: Tensor = None


class TimeStyleSeperateEmbed(nn.Module):
    # embed only style
    def __init__(self, time_channels, time_out_channels):
        super().__init__()
        self.time_embed = nn.Sequential(
            linear(time_channels, time_out_channels),
            nn.SiLU(),
            linear(time_out_channels, time_out_channels),
        )
        self.style = nn.Identity()

    def forward(self, time_emb=None, cond=None, **kwargs):
        # time_emb stays None in autoenc-only training mode
        if time_emb is not None:
            time_emb = self.time_embed(time_emb)
        style = self.style(cond)  # style == cond
        return EmbedReturn(emb=style, time_emb=time_emb, style=style)
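# Editor's note (illustrative, not part of the original commit): the "style" path is
# an identity, so z_sem reaches the ResBlocks unchanged while only the timestep
# embedding goes through the two-layer MLP. With assumed sizes time_channels=64 and
# time_out_channels=512:
#
#   embed = TimeStyleSeperateEmbed(time_channels=64, time_out_channels=512)
#   res = embed(time_emb=timestep_embedding(t, 64), cond=z_sem)
#   # res.time_emb: (B, 512) transformed time embedding; res.emb is z_sem itself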