import math
from dataclasses import dataclass
from enum import Enum
from typing import NamedTuple, Tuple

import torch
from choices import *
from config_base import BaseConfig
from torch import nn
from torch.nn import init

from .blocks import *
from .nn import timestep_embedding
from .unet import *


class LatentNetType(Enum):
    none = 'none'
    skip = 'skip'


class LatentNetReturn(NamedTuple):
    pred: torch.Tensor = None


@dataclass
class MLPSkipNetConfig(BaseConfig):
    """
    Default MLP for the latent DPM in the paper.
    """
    # width of the input/output latent code
    num_channels: int
    # indices of layers whose input gets the latent code x concatenated back in
    skip_layers: Tuple[int, ...]
    # width of the hidden layers
    num_hid_channels: int
    # total number of MLPLNAct layers
    num_layers: int
    # size of the sinusoidal timestep embedding fed to the time MLP
    num_time_emb_channels: int = 64
    activation: Activation = Activation.silu
    use_norm: bool = True
    # added to the conditioning scale: h = h * (condition_bias + cond)
    condition_bias: float = 1
    dropout: float = 0
    last_act: Activation = Activation.none
    # depth of the MLP that maps the timestep embedding to the condition
    num_time_layers: int = 2
    time_last_act: bool = False

    def make_model(self):
        return MLPSkipNet(self)


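# --- Illustrative usage (not part of the original file) ---
# A minimal sketch of how this config might be instantiated; the helper name,
# the widths, and the skip indices below are hypothetical placeholders, not
# values taken from the paper.
def _example_latent_net_config() -> MLPSkipNetConfig:
    return MLPSkipNetConfig(
        num_channels=512,        # dimensionality of the latent code
        skip_layers=(1, 2, 3),   # re-concatenate x before these hidden layers
        num_hid_channels=1024,
        num_layers=5,
    )

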
class MLPSkipNet(nn.Module):
    """
    Concat x to the hidden layers.

    Default MLP for the latent DPM in the paper.
    """
    def __init__(self, conf: MLPSkipNetConfig):
        super().__init__()
        self.conf = conf

        # MLP that turns the sinusoidal timestep embedding into the condition
        layers = []
        for i in range(conf.num_time_layers):
            if i == 0:
                a = conf.num_time_emb_channels
                b = conf.num_channels
            else:
                a = conf.num_channels
                b = conf.num_channels
            layers.append(nn.Linear(a, b))
            if i < conf.num_time_layers - 1 or conf.time_last_act:
                layers.append(conf.activation.get_act())
        self.time_embed = nn.Sequential(*layers)

        # main stack of conditioned MLP blocks
        self.layers = nn.ModuleList([])
        for i in range(conf.num_layers):
            if i == 0:
                # first layer: latent code -> hidden width
                act = conf.activation
                norm = conf.use_norm
                cond = True
                a, b = conf.num_channels, conf.num_hid_channels
                dropout = conf.dropout
            elif i == conf.num_layers - 1:
                # last layer: hidden width -> latent code, no act/norm/cond
                act = Activation.none
                norm = False
                cond = False
                a, b = conf.num_hid_channels, conf.num_channels
                dropout = 0
            else:
                act = conf.activation
                norm = conf.use_norm
                cond = True
                a, b = conf.num_hid_channels, conf.num_hid_channels
                dropout = conf.dropout

            # skip connections concatenate x, enlarging the input width
            if i in conf.skip_layers:
                a += conf.num_channels

            self.layers.append(
                MLPLNAct(
                    a,
                    b,
                    norm=norm,
                    activation=act,
                    cond_channels=conf.num_channels,
                    use_cond=cond,
                    condition_bias=conf.condition_bias,
                    dropout=dropout,
                ))
        self.last_act = conf.last_act.get_act()

    def forward(self, x, t, **kwargs):
        # embed the timestep and map it to the conditioning vector
        t = timestep_embedding(t, self.conf.num_time_emb_channels)
        cond = self.time_embed(t)
        h = x
        for i in range(len(self.layers)):
            if i in self.conf.skip_layers:
                # re-inject the input latent code via concatenation
                h = torch.cat([h, x], dim=1)
            h = self.layers[i](x=h, cond=cond)
        h = self.last_act(h)
        return LatentNetReturn(h)


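# --- Illustrative usage (not part of the original file) ---
# A minimal sketch of a forward pass through the latent net, reusing the
# hypothetical `_example_latent_net_config` helper above; the batch size and
# timestep range are arbitrary placeholders.
def _example_latent_net_forward() -> torch.Tensor:
    model = _example_latent_net_config().make_model()
    x = torch.randn(4, 512)           # batch of latent codes (n, num_channels)
    t = torch.randint(0, 1000, (4,))  # one diffusion timestep per sample
    out = model(x, t)                 # LatentNetReturn(pred=...)
    return out.pred                   # same shape as x: (4, 512)

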
class MLPLNAct(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        norm: bool,
        use_cond: bool,
        activation: Activation,
        cond_channels: int,
        condition_bias: float = 0,
        dropout: float = 0,
    ):
        super().__init__()
        self.activation = activation
        self.condition_bias = condition_bias
        self.use_cond = use_cond

        self.linear = nn.Linear(in_channels, out_channels)
        self.act = activation.get_act()
        if self.use_cond:
            # project the condition (time embedding) to a per-channel scale
            self.linear_emb = nn.Linear(cond_channels, out_channels)
            self.cond_layers = nn.Sequential(self.act, self.linear_emb)
        if norm:
            self.norm = nn.LayerNorm(out_channels)
        else:
            self.norm = nn.Identity()

        if dropout > 0:
            self.dropout = nn.Dropout(p=dropout)
        else:
            self.dropout = nn.Identity()

        self.init_weights()

    def init_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                if self.activation == Activation.relu:
                    init.kaiming_normal_(module.weight,
                                         a=0,
                                         nonlinearity='relu')
                elif self.activation == Activation.lrelu:
                    init.kaiming_normal_(module.weight,
                                         a=0.2,
                                         nonlinearity='leaky_relu')
                elif self.activation == Activation.silu:
                    # silu has no dedicated kaiming gain; reuse the relu gain
                    init.kaiming_normal_(module.weight,
                                         a=0,
                                         nonlinearity='relu')
                else:
                    # keep the default initialization for other activations
                    pass

    def forward(self, x, cond=None):
        x = self.linear(x)
        if self.use_cond:
            # (scale, shift) conditioning; only the scale is used here
            cond = self.cond_layers(cond)
            cond = (cond, None)

            # scale (and shift, if present) first, ...
            x = x * (self.condition_bias + cond[0])
            if cond[1] is not None:
                x = x + cond[1]
            # ... then normalize
            x = self.norm(x)
        else:
            # no condition
            x = self.norm(x)
        x = self.act(x)
        x = self.dropout(x)
        return x


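# --- Illustrative usage (not part of the original file) ---
# A minimal sketch of a single conditioned block in isolation; the widths and
# dropout value are hypothetical, and `cond` stands in for the time embedding
# that MLPSkipNet would normally produce.
def _example_mlp_ln_act() -> torch.Tensor:
    block = MLPLNAct(
        in_channels=512,
        out_channels=1024,
        norm=True,
        use_cond=True,
        activation=Activation.silu,
        cond_channels=512,
        condition_bias=1,
        dropout=0.1,
    )
    h = torch.randn(4, 512)
    cond = torch.randn(4, 512)
    return block(h, cond=cond)  # shape (4, 1024)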