This commit is contained in:
Guanhua Zhang 2024-10-08 14:18:47 +02:00
commit b102c2a534
20 changed files with 4305 additions and 0 deletions

176
model/MI.py Normal file
View file

@ -0,0 +1,176 @@
'''
Differentiable approximation to the mutual information (MI) metric.
Implementation in PyTorch
'''
# Imports #
# ----------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
import os
# Note: This code snippet was taken from the discussion found at:
# https://discuss.pytorch.org/t/differentiable-torch-histc/25865/2
# By Tony-Y
class SoftHistogram1D(nn.Module):
'''
Differentiable 1D histogram calculation (supported via pytorch's autograd)
inupt:
x - N x D array, where N is the batch size and D is the length of each data series
bins - Number of bins for the histogram
min - Scalar min value to be included in the histogram
max - Scalar max value to be included in the histogram
sigma - Scalar smoothing factor fir the bin approximation via sigmoid functions.
Larger values correspond to sharper edges, and thus yield a more accurate approximation
output:
N x bins array, where each row is a histogram
'''
def __init__(self, bins=50, min=0, max=1, sigma=10):
super(SoftHistogram1D, self).__init__()
self.bins = bins
self.min = min
self.max = max
self.sigma = sigma
self.delta = float(max - min) / float(bins)
self.centers = float(min) + self.delta * (torch.arange(bins).float() + 0.5) # Bin centers
self.centers = nn.Parameter(self.centers, requires_grad=False) # Wrap for allow for cuda support
def forward(self, x):
# Replicate x and for each row remove center
x = torch.unsqueeze(x, 1) - torch.unsqueeze(self.centers, 1)
# Bin approximation using a sigmoid function
x = torch.sigmoid(self.sigma * (x + self.delta / 2)) - torch.sigmoid(self.sigma * (x - self.delta / 2))
# Sum along the non-batch dimensions
x = x.sum(dim=-1)
# x = x / x.sum(dim=-1).unsqueeze(1) # normalization
return x
# Note: This is an extension to the 2D case of the previous code snippet
class SoftHistogram2D(nn.Module):
'''
Differentiable 1D histogram calculation (supported via pytorch's autograd)
inupt:
x, y - N x D array, where N is the batch size and D is the length of each data series
(i.e. vectorized image or vectorized 3D volume)
bins - Number of bins for the histogram
min - Scalar min value to be included in the histogram
max - Scalar max value to be included in the histogram
sigma - Scalar smoothing factor fir the bin approximation via sigmoid functions.
Larger values correspond to sharper edges, and thus yield a more accurate approximation
output:
N x bins array, where each row is a histogram
'''
def __init__(self, bins=50, min=0, max=1, sigma=10):
super(SoftHistogram2D, self).__init__()
self.bins = bins
self.min = min
self.max = max
self.sigma = sigma
self.delta = float(max - min) / float(bins)
self.centers = float(min) + self.delta * (torch.arange(bins).float() + 0.5) # Bin centers
self.centers = nn.Parameter(self.centers, requires_grad=False) # Wrap for allow for cuda support
def forward(self, x, y):
assert x.size() == y.size(), "(SoftHistogram2D) x and y sizes do not match"
# Replicate x and for each row remove center
x = torch.unsqueeze(x, 1) - torch.unsqueeze(self.centers, 1)
y = torch.unsqueeze(y, 1) - torch.unsqueeze(self.centers, 1)
# Bin approximation using a sigmoid function (can be sigma_x and sigma_y respectively - same for delta)
x = torch.sigmoid(self.sigma * (x + self.delta / 2)) - torch.sigmoid(self.sigma * (x - self.delta / 2))
y = torch.sigmoid(self.sigma * (y + self.delta / 2)) - torch.sigmoid(self.sigma * (y - self.delta / 2))
# Batched matrix multiplication - this way we sum jointly
z = torch.matmul(x, y.permute((0, 2, 1)))
return z
class MI_pytorch(nn.Module):
'''
This class is a pytorch implementation of the mutual information (MI) calculation between two images.
This is an approximation, as the images' histograms rely on differentiable approximations of rectangular windows.
I(X, Y) = H(X) + H(Y) - H(X, Y) = \sum(\sum(p(X, Y) * log(p(Y, Y)/(p(X) * p(Y)))))
where H(X) = -\sum(p(x) * log(p(x))) is the entropy
'''
def __init__(self, bins=50, min=0, max=1, sigma=10, reduction='sum'):
super(MI_pytorch, self).__init__()
self.bins = bins
self.min = min
self.max = max
self.sigma = sigma
self.reduction = reduction
# 2D joint histogram
self.hist2d = SoftHistogram2D(bins, min, max, sigma)
# Epsilon - to avoid log(0)
self.eps = torch.tensor(0.00000001, dtype=torch.float32, requires_grad=False)
def forward(self, im1, im2):
'''
Forward implementation of a differentiable MI estimator for batched images
:param im1: N x ... tensor, where N is the batch size
... dimensions can take any form, i.e. 2D images or 3D volumes.
:param im2: N x ... tensor, where N is the batch size
:return: N x 1 vector - the approximate MI values between the batched im1 and im2
'''
# Check for valid inputs
assert im1.size() == im2.size(), "(MI_pytorch) Inputs should have the same dimensions."
batch_size = im1.size()[0]
# Flatten tensors
im1_flat = im1.view(im1.size()[0], -1)
im2_flat = im2.view(im2.size()[0], -1)
# Calculate joint histogram
hgram = self.hist2d(im1_flat, im2_flat)
# Convert to a joint distribution
# Pxy = torch.distributions.Categorical(probs=hgram).probs
Pxy = torch.div(hgram, torch.sum(hgram.view(hgram.size()[0], -1)))
# Calculate the marginal distributions
Py = torch.sum(Pxy, dim=1).unsqueeze(1)
Px = torch.sum(Pxy, dim=2).unsqueeze(1)
# Use the KL divergence distance to calculate the MI
Px_Py = torch.matmul(Px.permute((0, 2, 1)), Py)
# Reshape to batch_size X all_the_rest
Pxy = Pxy.reshape(batch_size, -1)
Px_Py = Px_Py.reshape(batch_size, -1)
# Calculate mutual information - this is an approximation due to the histogram calculation and eps,
# but it can handle batches
if batch_size == 1:
# No need for eps approximation in the case of a single batch
nzs = Pxy > 0 # Calculate based on the non-zero values only
mut_info = torch.matmul(Pxy[nzs], torch.log(Pxy[nzs]) - torch.log(Px_Py[nzs])) # MI calculation
else:
# For arbitrary batch size > 1
mut_info = torch.sum(Pxy * (torch.log(Pxy + self.eps) - torch.log(Px_Py + self.eps)), dim=1)
# Reduction
if self.reduction == 'sum':
mut_info = torch.sum(mut_info)
elif self.reduction == 'batchmean':
mut_info = torch.sum(mut_info)
mut_info = mut_info / float(batch_size)
elif self.reduction=='individual':
pass
return mut_info

6
model/__init__.py Normal file
View file

@ -0,0 +1,6 @@
from typing import Union
from .unet import BeatGANsUNetModel, BeatGANsUNetConfig
from .unet_autoenc import BeatGANsAutoencConfig, BeatGANsAutoencModel
Model = Union[BeatGANsUNetModel, BeatGANsAutoencModel]
ModelConfig = Union[BeatGANsUNetConfig, BeatGANsAutoencConfig]

495
model/blocks.py Normal file
View file

@ -0,0 +1,495 @@
import math
from abc import abstractmethod
from dataclasses import dataclass
from numbers import Number
import torch as th
import torch.nn.functional as F
from choices import *
from config_base import BaseConfig
from torch import nn
from .nn import (avg_pool_nd, conv_nd, linear, normalization,
timestep_embedding, torch_checkpoint, zero_module)
class ScaleAt(Enum):
after_norm = 'afternorm'
class TimestepBlock(nn.Module):
"""
Any module where forward() takes timestep embeddings as a second argument.
"""
@abstractmethod
def forward(self, x, emb=None, cond=None, lateral=None):
"""
Apply the module to `x` given `emb` timestep embeddings.
"""
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
"""
A sequential module that passes timestep embeddings to the children that
support it as an extra input.
"""
def forward(self, x, emb=None, cond=None, lateral=None):
for layer in self:
if isinstance(layer, TimestepBlock):
x = layer(x, emb=emb, cond=cond, lateral=lateral)
else:
x = layer(x)
return x
@dataclass
class ResBlockConfig(BaseConfig):
channels: int
emb_channels: int
dropout: float
out_channels: int = None
use_condition: bool = True
use_conv: bool = False
dims: int = 2
use_checkpoint: bool = False
up: bool = False
down: bool = False
two_cond: bool = False
cond_emb_channels: int = None
has_lateral: bool = False
lateral_channels: int = None
use_zero_module: bool = True
def __post_init__(self):
self.out_channels = self.out_channels or self.channels
self.cond_emb_channels = self.cond_emb_channels or self.emb_channels
def make_model(self):
return ResBlock(self)
class ResBlock(TimestepBlock):
"""
A residual block that can optionally change the number of channels.
"""
def __init__(self, conf: ResBlockConfig):
super().__init__()
self.conf = conf
assert conf.lateral_channels is None
layers = [
normalization(conf.channels),
nn.SiLU(),
conv_nd(conf.dims, conf.channels, conf.out_channels, 3, padding=1)
]
self.in_layers = nn.Sequential(*layers)
self.updown = conf.up or conf.down
if conf.up:
self.h_upd = Upsample(conf.channels, False, conf.dims)
self.x_upd = Upsample(conf.channels, False, conf.dims)
elif conf.down:
self.h_upd = Downsample(conf.channels, False, conf.dims)
self.x_upd = Downsample(conf.channels, False, conf.dims)
else:
self.h_upd = self.x_upd = nn.Identity()
if conf.use_condition:
self.emb_layers = nn.Sequential(
nn.SiLU(),
linear(conf.emb_channels, 2 * conf.out_channels),
)
if conf.two_cond:
self.cond_emb_layers = nn.Sequential(
nn.SiLU(),
linear(conf.cond_emb_channels, conf.out_channels),
)
conv = conv_nd(conf.dims,
conf.out_channels,
conf.out_channels,
3,
padding=1)
if conf.use_zero_module:
conv = zero_module(conv)
layers = []
layers += [
normalization(conf.out_channels),
nn.SiLU(),
nn.Dropout(p=conf.dropout),
conv,
]
self.out_layers = nn.Sequential(*layers)
if conf.out_channels == conf.channels:
self.skip_connection = nn.Identity()
else:
if conf.use_conv:
kernel_size = 3
padding = 1
else:
kernel_size = 1
padding = 0
self.skip_connection = conv_nd(conf.dims,
conf.channels,
conf.out_channels,
kernel_size,
padding=padding)
def forward(self, x, emb=None, cond=None, lateral=None):
"""
Apply the block to a Tensor, conditioned on a timestep embedding.
Args:
x: input
lateral: lateral connection from the encoder
"""
return torch_checkpoint(self._forward, (x, emb, cond, lateral),
self.conf.use_checkpoint)
def _forward(
self,
x,
emb=None,
cond=None,
lateral=None,
):
"""
Args:
lateral: required if "has_lateral" and non-gated, with gated, it can be supplied optionally
"""
if self.conf.has_lateral:
assert lateral is not None
x = th.cat([x, lateral], dim=1)
if self.updown:
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
h = in_rest(x)
h = self.h_upd(h)
x = self.x_upd(x)
h = in_conv(h)
else:
h = self.in_layers(x)
if self.conf.use_condition:
if emb is not None:
emb_out = self.emb_layers(emb).type(h.dtype)
else:
emb_out = None
if self.conf.two_cond:
if cond is None:
cond_out = None
else:
if not isinstance(cond, th.Tensor):
assert isinstance(cond, dict)
cond = cond['cond']
cond_out = self.cond_emb_layers(cond).type(h.dtype)
if cond_out is not None:
while len(cond_out.shape) < len(h.shape):
cond_out = cond_out[..., None]
else:
cond_out = None
h = apply_conditions(
h=h,
emb=emb_out,
cond=cond_out,
layers=self.out_layers,
scale_bias=1,
in_channels=self.conf.out_channels,
up_down_layer=None,
)
return self.skip_connection(x) + h
def apply_conditions(
h,
emb=None,
cond=None,
layers: nn.Sequential = None,
scale_bias: float = 1,
in_channels: int = 512,
up_down_layer: nn.Module = None,
):
"""
apply conditions on the feature maps
Args:
emb: time conditional (ready to scale + shift)
cond: encoder's conditional (read to scale + shift)
"""
two_cond = emb is not None and cond is not None
if emb is not None:
while len(emb.shape) < len(h.shape):
emb = emb[..., None]
if two_cond:
while len(cond.shape) < len(h.shape):
cond = cond[..., None]
scale_shifts = [emb, cond]
else:
scale_shifts = [emb]
for i, each in enumerate(scale_shifts):
if each is None:
a = None
b = None
else:
if each.shape[1] == in_channels * 2:
a, b = th.chunk(each, 2, dim=1)
else:
a = each
b = None
scale_shifts[i] = (a, b)
if isinstance(scale_bias, Number):
biases = [scale_bias] * len(scale_shifts)
else:
biases = scale_bias
pre_layers, post_layers = layers[0], layers[1:]
mid_layers, post_layers = post_layers[:-2], post_layers[-2:]
h = pre_layers(h)
for i, (scale, shift) in enumerate(scale_shifts):
if scale is not None:
h = h * (biases[i] + scale)
if shift is not None:
h = h + shift
h = mid_layers(h)
if up_down_layer is not None:
h = up_down_layer(h)
h = post_layers(h)
return h
class Upsample(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
if use_conv:
self.conv = conv_nd(dims,
self.channels,
self.out_channels,
3,
padding=1)
def forward(self, x):
assert x.shape[1] == self.channels
if self.dims == 3:
x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
mode="nearest")
else:
x = F.interpolate(x, scale_factor=2, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
stride = 2 if dims != 3 else (1, 2, 2)
if use_conv:
self.op = conv_nd(dims,
self.channels,
self.out_channels,
3,
stride=stride,
padding=1)
else:
assert self.channels == self.out_channels
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
def forward(self, x):
assert x.shape[1] == self.channels
return self.op(x)
class AttentionBlock(nn.Module):
"""
An attention block that allows spatial positions to attend to each other.
Originally ported from here, but adapted to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
"""
def __init__(
self,
channels,
num_heads=1,
num_head_channels=-1,
use_checkpoint=False,
use_new_attention_order=False,
):
super().__init__()
self.channels = channels
if num_head_channels == -1:
self.num_heads = num_heads
else:
assert (
channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels
self.use_checkpoint = use_checkpoint
self.norm = normalization(channels)
self.qkv = conv_nd(1, channels, channels * 3, 1)
if use_new_attention_order:
self.attention = QKVAttention(self.num_heads)
else:
self.attention = QKVAttentionLegacy(self.num_heads)
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
def forward(self, x):
return torch_checkpoint(self._forward, (x, ), self.use_checkpoint)
def _forward(self, x):
b, c, *spatial = x.shape
x = x.reshape(b, c, -1)
qkv = self.qkv(self.norm(x))
h = self.attention(qkv)
h = self.proj_out(h)
return (x + h).reshape(b, c, *spatial)
def count_flops_attn(model, _x, y):
"""
A counter for the `thop` package to count the operations in an
attention operation.
Meant to be used like:
macs, params = thop.profile(
model,
inputs=(inputs, timestamps),
custom_ops={QKVAttention: QKVAttention.count_flops},
)
"""
b, c, *spatial = y[0].shape
num_spatial = int(np.prod(spatial))
matmul_ops = 2 * b * (num_spatial**2) * c
model.total_ops += th.DoubleTensor([matmul_ops])
class QKVAttentionLegacy(nn.Module):
"""
A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
"""
def __init__(self, n_heads):
super().__init__()
self.n_heads = n_heads
def forward(self, qkv):
"""
Apply QKV attention.
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
:return: an [N x (H * C) x T] tensor after attention.
"""
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch,
dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum(
"bct,bcs->bts", q * scale,
k * scale)
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum("bts,bcs->bct", weight, v)
return a.reshape(bs, -1, length)
@staticmethod
def count_flops(model, _x, y):
return count_flops_attn(model, _x, y)
class QKVAttention(nn.Module):
"""
A module which performs QKV attention and splits in a different order.
"""
def __init__(self, n_heads):
super().__init__()
self.n_heads = n_heads
def forward(self, qkv):
"""
Apply QKV attention.
:param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
:return: an [N x (H * C) x T] tensor after attention.
"""
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.chunk(3, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum(
"bct,bcs->bts",
(q * scale).view(bs * self.n_heads, ch, length),
(k * scale).view(bs * self.n_heads, ch, length),
)
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum("bts,bcs->bct", weight,
v.reshape(bs * self.n_heads, ch, length))
return a.reshape(bs, -1, length)
@staticmethod
def count_flops(model, _x, y):
return count_flops_attn(model, _x, y)
class AttentionPool2d(nn.Module):
"""
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
"""
def __init__(
self,
spacial_dim: int,
embed_dim: int,
num_heads_channels: int,
output_dim: int = None,
):
super().__init__()
self.positional_embedding = nn.Parameter(
th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)
self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
self.num_heads = embed_dim // num_heads_channels
self.attention = QKVAttention(self.num_heads)
def forward(self, x):
b, c, *_spatial = x.shape
x = x.reshape(b, c, -1)
x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)
x = x + self.positional_embedding[None, :, :].to(x.dtype)
x = self.qkv_proj(x)
x = self.attention(x)
x = self.c_proj(x)
return x[:, :, 0]

184
model/latentnet.py Normal file
View file

@ -0,0 +1,184 @@
import math
from dataclasses import dataclass
from enum import Enum
from typing import NamedTuple, Tuple
import torch
from choices import *
from config_base import BaseConfig
from torch import nn
from torch.nn import init
from .blocks import *
from .nn import timestep_embedding
from .unet import *
class LatentNetType(Enum):
none = 'none'
skip = 'skip'
class LatentNetReturn(NamedTuple):
pred: torch.Tensor = None
@dataclass
class MLPSkipNetConfig(BaseConfig):
"""
default MLP for the latent DPM in the paper!
"""
num_channels: int
skip_layers: Tuple[int]
num_hid_channels: int
num_layers: int
num_time_emb_channels: int = 64
activation: Activation = Activation.silu
use_norm: bool = True
condition_bias: float = 1
dropout: float = 0
last_act: Activation = Activation.none
num_time_layers: int = 2
time_last_act: bool = False
def make_model(self):
return MLPSkipNet(self)
class MLPSkipNet(nn.Module):
"""
concat x to hidden layers
default MLP for the latent DPM in the paper!
"""
def __init__(self, conf: MLPSkipNetConfig):
super().__init__()
self.conf = conf
layers = []
for i in range(conf.num_time_layers):
if i == 0:
a = conf.num_time_emb_channels
b = conf.num_channels
else:
a = conf.num_channels
b = conf.num_channels
layers.append(nn.Linear(a, b))
if i < conf.num_time_layers - 1 or conf.time_last_act:
layers.append(conf.activation.get_act())
self.time_embed = nn.Sequential(*layers)
self.layers = nn.ModuleList([])
for i in range(conf.num_layers):
if i == 0:
act = conf.activation
norm = conf.use_norm
cond = True
a, b = conf.num_channels, conf.num_hid_channels
dropout = conf.dropout
elif i == conf.num_layers - 1:
act = Activation.none
norm = False
cond = False
a, b = conf.num_hid_channels, conf.num_channels
dropout = 0
else:
act = conf.activation
norm = conf.use_norm
cond = True
a, b = conf.num_hid_channels, conf.num_hid_channels
dropout = conf.dropout
if i in conf.skip_layers:
a += conf.num_channels
self.layers.append(
MLPLNAct(
a,
b,
norm=norm,
activation=act,
cond_channels=conf.num_channels,
use_cond=cond,
condition_bias=conf.condition_bias,
dropout=dropout,
))
self.last_act = conf.last_act.get_act()
def forward(self, x, t, **kwargs):
t = timestep_embedding(t, self.conf.num_time_emb_channels)
cond = self.time_embed(t)
h = x
for i in range(len(self.layers)):
if i in self.conf.skip_layers:
h = torch.cat([h, x], dim=1)
h = self.layers[i].forward(x=h, cond=cond)
h = self.last_act(h)
return LatentNetReturn(h)
class MLPLNAct(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
norm: bool,
use_cond: bool,
activation: Activation,
cond_channels: int,
condition_bias: float = 0,
dropout: float = 0,
):
super().__init__()
self.activation = activation
self.condition_bias = condition_bias
self.use_cond = use_cond
self.linear = nn.Linear(in_channels, out_channels)
self.act = activation.get_act()
if self.use_cond:
self.linear_emb = nn.Linear(cond_channels, out_channels)
self.cond_layers = nn.Sequential(self.act, self.linear_emb)
if norm:
self.norm = nn.LayerNorm(out_channels)
else:
self.norm = nn.Identity()
if dropout > 0:
self.dropout = nn.Dropout(p=dropout)
else:
self.dropout = nn.Identity()
self.init_weights()
def init_weights(self):
for module in self.modules():
if isinstance(module, nn.Linear):
if self.activation == Activation.relu:
init.kaiming_normal_(module.weight,
a=0,
nonlinearity='relu')
elif self.activation == Activation.lrelu:
init.kaiming_normal_(module.weight,
a=0.2,
nonlinearity='leaky_relu')
elif self.activation == Activation.silu:
init.kaiming_normal_(module.weight,
a=0,
nonlinearity='relu')
else:
pass
def forward(self, x, cond=None):
x = self.linear(x)
if self.use_cond:
cond = self.cond_layers(cond)
cond = (cond, None)
x = x * (self.condition_bias + cond[0])
if cond[1] is not None:
x = x + cond[1]
x = self.norm(x)
else:
x = self.norm(x)
x = self.act(x)
x = self.dropout(x)
return x

135
model/nn.py Executable file
View file

@ -0,0 +1,135 @@
"""
Various utilities for neural networks.
"""
import math
import torch as th
import torch.nn as nn
import torch.utils.checkpoint
import torch.nn.functional as F
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
# @th.jit.script
def forward(self, x):
return x * th.sigmoid(x)
class GroupNorm32(nn.GroupNorm):
def forward(self, x):
return super().forward(x.float()).type(x.dtype)
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
assert dims==1
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def linear(*args, **kwargs):
"""
Create a linear module.
"""
return nn.Linear(*args, **kwargs)
def avg_pool_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D average pooling module.
"""
assert dims==1
if dims == 1:
return nn.AvgPool1d(*args, **kwargs)
elif dims == 2:
return nn.AvgPool2d(*args, **kwargs)
elif dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def update_ema(target_params, source_params, rate=0.99):
"""
Update target parameters to be closer to those of source parameters using
an exponential moving average.
:param target_params: the target parameter sequence.
:param source_params: the source parameter sequence.
:param rate: the EMA rate (closer to 1 means slower).
"""
for targ, src in zip(target_params, source_params):
targ.detach().mul_(rate).add_(src, alpha=1 - rate)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def scale_module(module, scale):
"""
Scale the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().mul_(scale)
return module
def mean_flat(tensor):
"""
Take the mean over all non-batch dimensions.
"""
return tensor.mean(dim=list(range(1, len(tensor.shape))))
def normalization(channels):
"""
Make a standard normalization layer.
:param channels: number of input channels.
:return: an nn.Module for normalization.
"""
return GroupNorm32(min(32, channels), channels)
def timestep_embedding(timesteps, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
half = dim // 2
freqs = th.exp(-math.log(max_period) *
th.arange(start=0, end=half, dtype=th.float32) /
half).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None]
embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
if dim % 2:
embedding = th.cat(
[embedding, th.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def torch_checkpoint(func, args, flag, preserve_rng_state=False):
# torch's gradient checkpoint works with automatic mixed precision, given torch >= 1.8
if flag:
return torch.utils.checkpoint.checkpoint(
func, *args, preserve_rng_state=preserve_rng_state)
else:
return func(*args)

505
model/unet.py Normal file
View file

@ -0,0 +1,505 @@
import math
from dataclasses import dataclass
from numbers import Number
from typing import NamedTuple, Tuple, Union
import numpy as np
import torch as th
from torch import nn
import torch.nn.functional as F
from choices import *
from config_base import BaseConfig
from .blocks import *
from .nn import (conv_nd, linear, normalization, timestep_embedding,
torch_checkpoint, zero_module)
@dataclass
class BeatGANsUNetConfig(BaseConfig):
image_size: int = 64
in_channels: int = 2
model_channels: int = 64
out_channels: int = 2
num_res_blocks: int = 2
num_input_res_blocks: int = None
embed_channels: int = 512
attention_resolutions: Tuple[int] = (16, )
time_embed_channels: int = None
dropout: float = 0.1
channel_mult: Tuple[int] = (1, 2, 4, 8)
input_channel_mult: Tuple[int] = None
conv_resample: bool = True
dims: int = 2
num_classes: int = None
use_checkpoint: bool = False
num_heads: int = 1
num_head_channels: int = -1
num_heads_upsample: int = -1
resblock_updown: bool = True
use_new_attention_order: bool = False
resnet_two_cond: bool = False
resnet_cond_channels: int = None
resnet_use_zero_module: bool = True
attn_checkpoint: bool = False
num_users: int = None
def make_model(self):
return BeatGANsUNetModel(self)
class BeatGANsUNetModel(nn.Module):
def __init__(self, conf: BeatGANsUNetConfig):
super().__init__()
self.conf = conf
if conf.num_heads_upsample == -1:
self.num_heads_upsample = conf.num_heads
self.dtype = th.float32
self.time_emb_channels = conf.time_embed_channels or conf.model_channels
self.time_embed = nn.Sequential(
linear(self.time_emb_channels, conf.embed_channels),
nn.SiLU(),
linear(conf.embed_channels, conf.embed_channels),
)
if conf.num_classes is not None:
self.label_emb = nn.Embedding(conf.num_classes,
conf.embed_channels)
ch = input_ch = int(conf.channel_mult[0] * conf.model_channels)
self.input_blocks = nn.ModuleList([
TimestepEmbedSequential(
conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1))
])
kwargs = dict(
use_condition=True,
two_cond=conf.resnet_two_cond,
use_zero_module=conf.resnet_use_zero_module,
cond_emb_channels=conf.resnet_cond_channels,
)
self._feature_size = ch
# input_block_chans = [ch]
input_block_chans = [[] for _ in range(len(conf.channel_mult))]
input_block_chans[0].append(ch)
# number of blocks at each resolution
self.input_num_blocks = [0 for _ in range(len(conf.channel_mult))]
self.input_num_blocks[0] = 1
self.output_num_blocks = [0 for _ in range(len(conf.channel_mult))]
ds = 1
resolution = conf.image_size
for level, mult in enumerate(conf.input_channel_mult
or conf.channel_mult):
for _ in range(conf.num_input_res_blocks or conf.num_res_blocks):
layers = [
ResBlockConfig(
ch,
conf.embed_channels,
conf.dropout,
out_channels=int(mult * conf.model_channels),
dims=conf.dims,
use_checkpoint=conf.use_checkpoint,
**kwargs,
).make_model()
]
ch = int(mult * conf.model_channels)
if resolution in conf.attention_resolutions:
layers.append(
AttentionBlock(
ch,
use_checkpoint=conf.use_checkpoint
or conf.attn_checkpoint,
num_heads=conf.num_heads,
num_head_channels=conf.num_head_channels,
use_new_attention_order=conf.
use_new_attention_order,
))
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
input_block_chans[level].append(ch)
self.input_num_blocks[level] += 1
if level != len(conf.channel_mult) - 1:
resolution //= 2
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResBlockConfig(
ch,
conf.embed_channels,
conf.dropout,
out_channels=out_ch,
dims=conf.dims,
use_checkpoint=conf.use_checkpoint,
down=True,
**kwargs,
).make_model() if conf.
resblock_updown else Downsample(ch,
conf.conv_resample,
dims=conf.dims,
out_channels=out_ch)))
ch = out_ch
input_block_chans[level + 1].append(ch)
self.input_num_blocks[level + 1] += 1
ds *= 2
self._feature_size += ch
self.middle_block = TimestepEmbedSequential(
ResBlockConfig(
ch,
conf.embed_channels,
conf.dropout,
dims=conf.dims,
use_checkpoint=conf.use_checkpoint,
**kwargs,
).make_model(),
AttentionBlock(
ch,
use_checkpoint=conf.use_checkpoint or conf.attn_checkpoint,
num_heads=conf.num_heads,
num_head_channels=conf.num_head_channels,
use_new_attention_order=conf.use_new_attention_order,
),
ResBlockConfig(
ch,
conf.embed_channels,
conf.dropout,
dims=conf.dims,
use_checkpoint=conf.use_checkpoint,
**kwargs,
).make_model(),
)
self._feature_size += ch
self.output_blocks = nn.ModuleList([])
for level, mult in list(enumerate(conf.channel_mult))[::-1]:
for i in range(conf.num_res_blocks + 1):
try:
ich = input_block_chans[level].pop()
except IndexError:
ich = 0
layers = [
ResBlockConfig(
channels=ch + ich,
emb_channels=conf.embed_channels,
dropout=conf.dropout,
out_channels=int(conf.model_channels * mult),
dims=conf.dims,
use_checkpoint=conf.use_checkpoint,
has_lateral=True if ich > 0 else False,
lateral_channels=None,
**kwargs,
).make_model()
]
ch = int(conf.model_channels * mult)
if resolution in conf.attention_resolutions:
layers.append(
AttentionBlock(
ch,
use_checkpoint=conf.use_checkpoint
or conf.attn_checkpoint,
num_heads=self.num_heads_upsample,
num_head_channels=conf.num_head_channels,
use_new_attention_order=conf.
use_new_attention_order,
))
if level and i == conf.num_res_blocks:
resolution *= 2
out_ch = ch
layers.append(
ResBlockConfig(
ch,
conf.embed_channels,
conf.dropout,
out_channels=out_ch,
dims=conf.dims,
use_checkpoint=conf.use_checkpoint,
up=True,
**kwargs,
).make_model() if (
conf.resblock_updown
) else Upsample(ch,
conf.conv_resample,
dims=conf.dims,
out_channels=out_ch))
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
self.output_num_blocks[level] += 1
self._feature_size += ch
if conf.resnet_use_zero_module:
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
zero_module(
conv_nd(conf.dims,
input_ch,
conf.out_channels,
3,
padding=1)),
)
else:
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
conv_nd(conf.dims, input_ch, conf.out_channels, 3, padding=1),
)
def forward(self, x, t, y=None, **kwargs):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
:param timesteps: a 1-D batch of timesteps.
:param y: an [N] Tensor of labels, if class-conditional.
:return: an [N x C x ...] Tensor of outputs.
"""
assert (y is not None) == (
self.conf.num_classes is not None
), "must specify y if and only if the model is class-conditional"
# hs = []
hs = [[] for _ in range(len(self.conf.channel_mult))]
emb = self.time_embed(timestep_embedding(t, self.time_emb_channels))
if self.conf.num_classes is not None:
raise NotImplementedError()
h = x.type(self.dtype)
k = 0
for i in range(len(self.input_num_blocks)):
for j in range(self.input_num_blocks[i]):
h = self.input_blocks[k](h, emb=emb)
hs[i].append(h)
k += 1
assert k == len(self.input_blocks)
# middle blocks
h = self.middle_block(h, emb=emb)
# output blocks
k = 0
for i in range(len(self.output_num_blocks)):
for j in range(self.output_num_blocks[i]):
try:
lateral = hs[-i - 1].pop()
except IndexError:
lateral = None
h = self.output_blocks[k](h, emb=emb, lateral=lateral)
k += 1
h = h.type(x.dtype)
pred = self.out(h)
return Return(pred=pred)
class Return(NamedTuple):
pred: th.Tensor
@dataclass
class BeatGANsEncoderConfig(BaseConfig):
image_size: int
in_channels: int
model_channels: int
out_hid_channels: int
out_channels: int
num_res_blocks: int
attention_resolutions: Tuple[int]
dropout: float = 0
channel_mult: Tuple[int] = (1, 2, 4, 8)
use_time_condition: bool = True
conv_resample: bool = True
dims: int = 2
use_checkpoint: bool = False
num_heads: int = 1
num_head_channels: int = -1
resblock_updown: bool = False
use_new_attention_order: bool = False
pool: str = 'adaptivenonzero'
def make_model(self):
return BeatGANsEncoderModel(self)
class BeatGANsEncoderModel(nn.Module):
"""
The half UNet model with attention and timestep embedding.
For usage, see UNet.
"""
def __init__(self, conf: BeatGANsEncoderConfig):
super().__init__()
self.conf = conf
self.dtype = th.float32
if conf.use_time_condition:
time_embed_dim = conf.model_channels * 4
self.time_embed = nn.Sequential(
linear(conf.model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
else:
time_embed_dim = None
ch = int(conf.channel_mult[0] * conf.model_channels)
self.input_blocks = nn.ModuleList([
TimestepEmbedSequential(
conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1))
])
self._feature_size = ch
input_block_chans = [ch]
ds = 1
resolution = conf.image_size
for level, mult in enumerate(conf.channel_mult):
for _ in range(conf.num_res_blocks):
layers = [
ResBlockConfig(
ch,
time_embed_dim,
conf.dropout,
out_channels=int(mult * conf.model_channels),
dims=conf.dims,
use_condition=conf.use_time_condition,
use_checkpoint=conf.use_checkpoint,
).make_model()
]
ch = int(mult * conf.model_channels)
if resolution in conf.attention_resolutions:
layers.append(
AttentionBlock(
ch,
use_checkpoint=conf.use_checkpoint,
num_heads=conf.num_heads,
num_head_channels=conf.num_head_channels,
use_new_attention_order=conf.
use_new_attention_order,
))
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(conf.channel_mult) - 1:
resolution //= 2
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResBlockConfig(
ch,
time_embed_dim,
conf.dropout,
out_channels=out_ch,
dims=conf.dims,
use_condition=conf.use_time_condition,
use_checkpoint=conf.use_checkpoint,
down=True,
).make_model() if (
conf.resblock_updown
) else Downsample(ch,
conf.conv_resample,
dims=conf.dims,
out_channels=out_ch)))
ch = out_ch
input_block_chans.append(ch)
ds *= 2
self._feature_size += ch
self.middle_block = TimestepEmbedSequential(
ResBlockConfig(
ch,
time_embed_dim,
conf.dropout,
dims=conf.dims,
use_condition=conf.use_time_condition,
use_checkpoint=conf.use_checkpoint,
).make_model(),
AttentionBlock(
ch,
use_checkpoint=conf.use_checkpoint,
num_heads=conf.num_heads,
num_head_channels=conf.num_head_channels,
use_new_attention_order=conf.use_new_attention_order,
),
ResBlockConfig(
ch,
time_embed_dim,
conf.dropout,
dims=conf.dims,
use_condition=conf.use_time_condition,
use_checkpoint=conf.use_checkpoint,
).make_model(),
)
self._feature_size += ch
if conf.pool == "adaptivenonzero":
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
nn.AdaptiveAvgPool1d(1),
conv_nd(conf.dims, ch, conf.out_channels, 1),
nn.Flatten(),
)
else:
raise NotImplementedError(f"Unexpected {conf.pool} pooling")
def forward(self, x, t=None, return_2d_feature=False):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
:param timesteps: a 1-D batch of timesteps.
:return: an [N x K] Tensor of outputs.
"""
if self.conf.use_time_condition:
emb = self.time_embed(timestep_embedding(t, self.model_channels))
else:
emb = None
results = []
h = x.type(self.dtype)
for module in self.input_blocks:
h = module(h, emb=emb)
if self.conf.pool.startswith("spatial"):
results.append(h.type(x.dtype).mean(dim=(2, 3)))
h = self.middle_block(h, emb=emb)
if self.conf.pool.startswith("spatial"):
results.append(h.type(x.dtype).mean(dim=(2, 3)))
h = th.cat(results, axis=-1)
else:
h = h.type(x.dtype)
h_2d = h
h = self.out(h)
if return_2d_feature:
return h, h_2d
else:
return h
def forward_flatten(self, x):
"""
transform the last 2d feature into a flatten vector
"""
h = self.out(x)
return h
class SuperResModel(BeatGANsUNetModel):
"""
A UNetModel that performs super-resolution.
Expects an extra kwarg `low_res` to condition on a low-resolution image.
"""
def __init__(self, image_size, in_channels, *args, **kwargs):
super().__init__(image_size, in_channels * 2, *args, **kwargs)
def forward(self, x, timesteps, low_res=None, **kwargs):
_, _, new_height, new_width = x.shape
upsampled = F.interpolate(low_res, (new_height, new_width),
mode="bilinear")
x = th.cat([x, upsampled], dim=1)
return super().forward(x, timesteps, **kwargs)

310
model/unet_autoenc.py Normal file
View file

@ -0,0 +1,310 @@
import torch
from torch import Tensor, nn
from torch.nn.functional import silu
from .latentnet import *
from .unet import *
from choices import *
@dataclass
class BeatGANsAutoencConfig(BeatGANsUNetConfig):
enc_out_channels: int = 512
enc_attn_resolutions: Tuple[int] = None
enc_pool: str = 'depthconv'
enc_num_res_block: int = 2
enc_channel_mult: Tuple[int] = None
enc_grad_checkpoint: bool = False
latent_net_conf: MLPSkipNetConfig = None
def make_model(self):
return BeatGANsAutoencModel(self)
class BeatGANsAutoencModel(BeatGANsUNetModel):
def __init__(self, conf: BeatGANsAutoencConfig):
super().__init__(conf)
self.conf = conf
self.time_embed = TimeStyleSeperateEmbed(
time_channels=conf.model_channels,
time_out_channels=conf.embed_channels,
)
self.encoder = BeatGANsEncoderConfig(
image_size=conf.image_size,
in_channels=conf.in_channels,
model_channels=conf.model_channels,
out_hid_channels=conf.enc_out_channels,
out_channels=conf.enc_out_channels,
num_res_blocks=conf.enc_num_res_block,
attention_resolutions=(conf.enc_attn_resolutions
or conf.attention_resolutions),
dropout=conf.dropout,
channel_mult=conf.enc_channel_mult or conf.channel_mult,
use_time_condition=False,
conv_resample=conf.conv_resample,
dims=conf.dims,
use_checkpoint=conf.use_checkpoint or conf.enc_grad_checkpoint,
num_heads=conf.num_heads,
num_head_channels=conf.num_head_channels,
resblock_updown=conf.resblock_updown,
use_new_attention_order=conf.use_new_attention_order,
pool=conf.enc_pool,
).make_model()
self.user_classifier = UserClassifier(conf.enc_out_channels//2, conf.num_users)
self.non_user_classifier = UserClassifierGradientReverse(conf.enc_out_channels//2, conf.num_users)
if conf.latent_net_conf is not None:
self.latent_net = conf.latent_net_conf.make_model()
def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
"""
Reparameterization trick to sample from N(mu, var) from
N(0,1).
:param mu: (Tensor) Mean of the latent Gaussian [B x D]
:param logvar: (Tensor) Standard deviation of the latent Gaussian [B x D]
:return: (Tensor) [B x D]
"""
assert self.conf.is_stochastic
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return eps * std + mu
def sample_z(self, n: int, device):
assert self.conf.is_stochastic
return torch.randn(n, self.conf.enc_out_channels, device=device)
def noise_to_cond(self, noise: Tensor):
raise NotImplementedError()
assert self.conf.noise_net_conf is not None
return self.noise_net.forward(noise)
def encode(self, x):
cond = self.encoder.forward(x)
return {'cond': cond}
@property
def stylespace_sizes(self):
modules = list(self.input_blocks.modules()) + list(
self.middle_block.modules()) + list(self.output_blocks.modules())
sizes = []
for module in modules:
if isinstance(module, ResBlock):
linear = module.cond_emb_layers[-1]
sizes.append(linear.weight.shape[0])
return sizes
def encode_stylespace(self, x, return_vector: bool = True):
"""
encode to style space
"""
modules = list(self.input_blocks.modules()) + list(
self.middle_block.modules()) + list(self.output_blocks.modules())
cond = self.encoder.forward(x)
S = []
for module in modules:
if isinstance(module, ResBlock):
s = module.cond_emb_layers.forward(cond)
S.append(s)
if return_vector:
return torch.cat(S, dim=1)
else:
return S
def forward(self,
x,
t,
y=None,
x_start=None,
cond=None,
style=None,
noise=None,
t_cond=None,
**kwargs):
"""
Apply the model to an input batch.
Args:
x_start: the original image to encode
cond: output of the encoder
noise: random noise (to predict the cond)
"""
if t_cond is None:
t_cond = t
if noise is not None:
cond = self.noise_to_cond(noise)
if cond is None:
if x is not None:
assert len(x) == len(x_start), f'{len(x)} != {len(x_start)}'
tmp = self.encode(x_start)
cond = tmp['cond']
if t is not None:
_t_emb = timestep_embedding(t, self.conf.model_channels)
_t_cond_emb = timestep_embedding(t_cond, self.conf.model_channels)
else:
_t_emb = None
_t_cond_emb = None
if self.conf.resnet_two_cond:
res = self.time_embed.forward(
time_emb=_t_emb,
cond=cond,
time_cond_emb=_t_cond_emb
)
else:
raise NotImplementedError()
if self.conf.resnet_two_cond:
emb = res.time_emb
cond_emb = res.emb
else:
emb = res.emb
cond_emb = None
style = style or res.style
assert (y is not None) == (
self.conf.num_classes is not None
), "must specify y if and only if the model is class-conditional"
if self.conf.num_classes is not None:
raise NotImplementedError()
enc_time_emb = emb
mid_time_emb = emb
dec_time_emb = emb
enc_cond_emb = cond_emb
mid_cond_emb = cond_emb
dec_cond_emb = cond_emb
if self.conf.num_users is not None:
user_pred = self.user_classifier(cond_emb[:, :self.conf.enc_out_channels // 2])
non_user_pred = self.non_user_classifier(cond_emb[:, self.conf.enc_out_channels // 2:])
hs = [[] for _ in range(len(self.conf.channel_mult))]
if x is not None:
h = x.type(self.dtype)
# input blocks
k = 0
for i in range(len(self.input_num_blocks)):
for j in range(self.input_num_blocks[i]):
h = self.input_blocks[k](h,
emb=enc_time_emb,
cond=enc_cond_emb)
'''print(i, j, h.shape)
if h.shape[-1]%2==1:
pdb.set_trace()'''
hs[i].append(h)
k += 1
assert k == len(self.input_blocks)
# middle blocks
h = self.middle_block(h, emb=mid_time_emb, cond=mid_cond_emb)
else:
h = None
hs = [[] for _ in range(len(self.conf.channel_mult))]
# output blocks
k = 0
for i in range(len(self.output_num_blocks)):
for j in range(self.output_num_blocks[i]):
try:
lateral = hs[-i - 1].pop()
except IndexError:
lateral = None
h = self.output_blocks[k](h,
emb=dec_time_emb,
cond=dec_cond_emb,
lateral=lateral)
k += 1
pred = self.out(h)
return AutoencReturn(pred=pred, cond=cond, user_pred=user_pred, non_user_pred=non_user_pred)
class UserClassifier(nn.Module):
def __init__(self, in_channels, num_classes):
super().__init__()
self.fc = nn.Sequential(
nn.Linear(in_channels, 256),
nn.ReLU(),
nn.Linear(256, num_classes),
nn.Softmax(dim=1)
)
def forward(self, x):
return self.fc(x)
class GradReverse(torch.autograd.Function):
"""
Implement the gradient reversal layer for the convenience of domain adaptation neural network.
The forward part is the identity function while the backward part is the negative function.
"""
@staticmethod
def forward(ctx, x):
return x.view_as(x)
@staticmethod
def backward(ctx, grad_output):
return grad_output.neg()
class GradientReversalLayer(nn.Module):
def __init__(self):
super(GradientReversalLayer, self).__init__()
def forward(self, inputs):
return GradReverse.apply(inputs)
class UserClassifierGradientReverse(nn.Module):
def __init__(self, in_channels, num_classes):
super().__init__()
self.grl = GradientReversalLayer()
self.fc = UserClassifier(in_channels, num_classes)
def forward(self, x):
x = self.grl(x)
return self.fc(x)
class AutoencReturn(NamedTuple):
pred: Tensor
cond: Tensor = None
user_pred: Tensor = None
non_user_pred: Tensor = None
class EmbedReturn(NamedTuple):
emb: Tensor = None
time_emb: Tensor = None
style: Tensor = None
class TimeStyleSeperateEmbed(nn.Module):
def __init__(self, time_channels, time_out_channels):
super().__init__()
self.time_embed = nn.Sequential(
linear(time_channels, time_out_channels),
nn.SiLU(),
linear(time_out_channels, time_out_channels),
)
self.cond_combine = nn.Sequential(
nn.Linear(time_out_channels * 2, time_out_channels),
nn.SiLU()
)
self.style = nn.Identity()
def forward(self, time_emb=None, cond=None, **kwargs):
if time_emb is None:
time_emb = None
else:
time_emb = self.time_embed(time_emb)
style = self.style(cond)
return EmbedReturn(emb=style, time_emb=time_emb, style=style)