init

commit b102c2a534
20 changed files with 4305 additions and 0 deletions

model/MI.py (new file, 176 lines)
@@ -0,0 +1,176 @@

'''
Differentiable approximation to the mutual information (MI) metric.
Implementation in PyTorch.
'''

# Imports #
# ----------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
import os


# Note: This code snippet was taken from the discussion found at:
# https://discuss.pytorch.org/t/differentiable-torch-histc/25865/2
# By Tony-Y
class SoftHistogram1D(nn.Module):
    '''
    Differentiable 1D histogram calculation (supported via PyTorch's autograd).
    input:
        x - N x D array, where N is the batch size and D is the length of each data series
        bins - number of bins for the histogram
        min - scalar minimum value to be included in the histogram
        max - scalar maximum value to be included in the histogram
        sigma - scalar smoothing factor for the bin approximation via sigmoid functions.
                Larger values correspond to sharper edges, and thus yield a more accurate approximation.
    output:
        N x bins array, where each row is a histogram
    '''

    def __init__(self, bins=50, min=0, max=1, sigma=10):
        super(SoftHistogram1D, self).__init__()
        self.bins = bins
        self.min = min
        self.max = max
        self.sigma = sigma
        self.delta = float(max - min) / float(bins)
        self.centers = float(min) + self.delta * (torch.arange(bins).float() + 0.5)  # Bin centers
        self.centers = nn.Parameter(self.centers, requires_grad=False)  # Wrapped to allow for CUDA support

    def forward(self, x):
        # Replicate x and, for each row, subtract the bin centers
        x = torch.unsqueeze(x, 1) - torch.unsqueeze(self.centers, 1)

        # Bin approximation using a sigmoid function
        x = torch.sigmoid(self.sigma * (x + self.delta / 2)) - torch.sigmoid(self.sigma * (x - self.delta / 2))

        # Sum along the non-batch dimensions
        x = x.sum(dim=-1)
        # x = x / x.sum(dim=-1).unsqueeze(1)  # normalization
        return x
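
# --- Illustrative sketch (added for exposition; not part of the original commit) ---
# Assuming SoftHistogram1D as defined above: for a sharp (large) sigma, the soft
# histogram should closely match torch.histc while remaining differentiable.
def _demo_soft_histogram1d():
    x = torch.rand(4, 1000, requires_grad=True)        # batch of 4 data series
    soft_hist = SoftHistogram1D(bins=50, min=0, max=1, sigma=100)
    h = soft_hist(x)                                   # shape: (4, 50)
    h.sum().backward()                                 # gradients flow through the binning
    h_ref = torch.histc(x[0].detach(), bins=50, min=0, max=1)  # hard reference, no grads
    print((h[0].detach() - h_ref).abs().max())         # small for large sigma
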

# Note: This is an extension to the 2D case of the previous code snippet
class SoftHistogram2D(nn.Module):
    '''
    Differentiable 2D histogram calculation (supported via PyTorch's autograd).
    input:
        x, y - N x D arrays, where N is the batch size and D is the length of each data series
               (i.e. a vectorized image or vectorized 3D volume)
        bins - number of bins for the histogram
        min - scalar minimum value to be included in the histogram
        max - scalar maximum value to be included in the histogram
        sigma - scalar smoothing factor for the bin approximation via sigmoid functions.
                Larger values correspond to sharper edges, and thus yield a more accurate approximation.
    output:
        N x bins x bins array, where each entry is a joint histogram
    '''

    def __init__(self, bins=50, min=0, max=1, sigma=10):
        super(SoftHistogram2D, self).__init__()
        self.bins = bins
        self.min = min
        self.max = max
        self.sigma = sigma
        self.delta = float(max - min) / float(bins)
        self.centers = float(min) + self.delta * (torch.arange(bins).float() + 0.5)  # Bin centers
        self.centers = nn.Parameter(self.centers, requires_grad=False)  # Wrapped to allow for CUDA support

    def forward(self, x, y):
        assert x.size() == y.size(), "(SoftHistogram2D) x and y sizes do not match"

        # Replicate x and y and, for each row, subtract the bin centers
        x = torch.unsqueeze(x, 1) - torch.unsqueeze(self.centers, 1)
        y = torch.unsqueeze(y, 1) - torch.unsqueeze(self.centers, 1)

        # Bin approximation using a sigmoid function (can be sigma_x and sigma_y respectively - same for delta)
        x = torch.sigmoid(self.sigma * (x + self.delta / 2)) - torch.sigmoid(self.sigma * (x - self.delta / 2))
        y = torch.sigmoid(self.sigma * (y + self.delta / 2)) - torch.sigmoid(self.sigma * (y - self.delta / 2))

        # Batched matrix multiplication - this way we sum jointly
        z = torch.matmul(x, y.permute((0, 2, 1)))
        return z


class MI_pytorch(nn.Module):
    '''
    This class is a PyTorch implementation of the mutual information (MI) calculation between two images.
    This is an approximation, as the images' histograms rely on differentiable approximations of rectangular windows.

        I(X, Y) = H(X) + H(Y) - H(X, Y) = \sum(\sum(p(X, Y) * log(p(X, Y) / (p(X) * p(Y)))))

    where H(X) = -\sum(p(x) * log(p(x))) is the entropy.
    '''

    def __init__(self, bins=50, min=0, max=1, sigma=10, reduction='sum'):
        super(MI_pytorch, self).__init__()
        self.bins = bins
        self.min = min
        self.max = max
        self.sigma = sigma
        self.reduction = reduction

        # 2D joint histogram
        self.hist2d = SoftHistogram2D(bins, min, max, sigma)

        # Epsilon - to avoid log(0)
        self.eps = torch.tensor(0.00000001, dtype=torch.float32, requires_grad=False)

    def forward(self, im1, im2):
        '''
        Forward implementation of a differentiable MI estimator for batched images
        :param im1: N x ... tensor, where N is the batch size
                    ... dimensions can take any form, i.e. 2D images or 3D volumes.
        :param im2: N x ... tensor, where N is the batch size
        :return: N x 1 vector - the approximate MI values between the batched im1 and im2
        '''

        # Check for valid inputs
        assert im1.size() == im2.size(), "(MI_pytorch) Inputs should have the same dimensions."

        batch_size = im1.size()[0]

        # Flatten tensors
        im1_flat = im1.view(im1.size()[0], -1)
        im2_flat = im2.view(im2.size()[0], -1)

        # Calculate joint histogram
        hgram = self.hist2d(im1_flat, im2_flat)

        # Convert to a joint distribution (normalize each batch element separately)
        # Pxy = torch.distributions.Categorical(probs=hgram).probs
        Pxy = torch.div(hgram, torch.sum(hgram.view(hgram.size()[0], -1), dim=1).view(-1, 1, 1))

        # Calculate the marginal distributions
        Py = torch.sum(Pxy, dim=1).unsqueeze(1)
        Px = torch.sum(Pxy, dim=2).unsqueeze(1)

        # Use the KL divergence distance to calculate the MI
        Px_Py = torch.matmul(Px.permute((0, 2, 1)), Py)

        # Reshape to batch_size x all_the_rest
        Pxy = Pxy.reshape(batch_size, -1)
        Px_Py = Px_Py.reshape(batch_size, -1)

        # Calculate mutual information - this is an approximation due to the histogram calculation and eps,
        # but it can handle batches
        if batch_size == 1:
            # No need for the eps approximation in the case of a single batch
            nzs = Pxy > 0  # Calculate based on the non-zero values only
            mut_info = torch.matmul(Pxy[nzs], torch.log(Pxy[nzs]) - torch.log(Px_Py[nzs]))  # MI calculation
        else:
            # For arbitrary batch size > 1
            mut_info = torch.sum(Pxy * (torch.log(Pxy + self.eps) - torch.log(Px_Py + self.eps)), dim=1)

        # Reduction
        if self.reduction == 'sum':
            mut_info = torch.sum(mut_info)
        elif self.reduction == 'batchmean':
            mut_info = torch.sum(mut_info)
            mut_info = mut_info / float(batch_size)
        elif self.reduction == 'individual':
            pass

        return mut_info
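
# --- Illustrative sketch (added for exposition; not part of the original commit) ---
# Assuming MI_pytorch as defined above: the MI of an image with itself should
# exceed the MI between two unrelated noise images.
def _demo_mi():
    a = torch.rand(2, 1, 32, 32)
    b = torch.rand(2, 1, 32, 32)
    mi = MI_pytorch(bins=32, min=0, max=1, sigma=100, reduction='individual')
    print(mi(a, a))  # high: identical inputs share all information
    print(mi(a, b))  # low: independent noise shares almost none
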

model/__init__.py (new file, 6 lines)
@@ -0,0 +1,6 @@

from typing import Union

from .unet import BeatGANsUNetModel, BeatGANsUNetConfig
from .unet_autoenc import BeatGANsAutoencConfig, BeatGANsAutoencModel

Model = Union[BeatGANsUNetModel, BeatGANsAutoencModel]
ModelConfig = Union[BeatGANsUNetConfig, BeatGANsAutoencConfig]

model/blocks.py (new file, 495 lines)
@@ -0,0 +1,495 @@

import math
from abc import abstractmethod
from dataclasses import dataclass
from numbers import Number

import numpy as np  # needed by count_flops_attn below
import torch as th
import torch.nn.functional as F
from choices import *
from config_base import BaseConfig
from torch import nn

from .nn import (avg_pool_nd, conv_nd, linear, normalization,
                 timestep_embedding, torch_checkpoint, zero_module)


class ScaleAt(Enum):
    after_norm = 'afternorm'


class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
    """
    @abstractmethod
    def forward(self, x, emb=None, cond=None, lateral=None):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """
    def forward(self, x, emb=None, cond=None, lateral=None):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb=emb, cond=cond, lateral=lateral)
            else:
                x = layer(x)
        return x


@dataclass
class ResBlockConfig(BaseConfig):
    channels: int
    emb_channels: int
    dropout: float
    out_channels: int = None
    use_condition: bool = True
    use_conv: bool = False
    dims: int = 2
    use_checkpoint: bool = False
    up: bool = False
    down: bool = False
    two_cond: bool = False
    cond_emb_channels: int = None
    has_lateral: bool = False
    lateral_channels: int = None
    use_zero_module: bool = True

    def __post_init__(self):
        self.out_channels = self.out_channels or self.channels
        self.cond_emb_channels = self.cond_emb_channels or self.emb_channels

    def make_model(self):
        return ResBlock(self)


class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.
    """
    def __init__(self, conf: ResBlockConfig):
        super().__init__()
        self.conf = conf

        assert conf.lateral_channels is None
        layers = [
            normalization(conf.channels),
            nn.SiLU(),
            conv_nd(conf.dims, conf.channels, conf.out_channels, 3, padding=1)
        ]
        self.in_layers = nn.Sequential(*layers)

        self.updown = conf.up or conf.down

        if conf.up:
            self.h_upd = Upsample(conf.channels, False, conf.dims)
            self.x_upd = Upsample(conf.channels, False, conf.dims)
        elif conf.down:
            self.h_upd = Downsample(conf.channels, False, conf.dims)
            self.x_upd = Downsample(conf.channels, False, conf.dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        if conf.use_condition:
            # scale + shift from the time embedding
            self.emb_layers = nn.Sequential(
                nn.SiLU(),
                linear(conf.emb_channels, 2 * conf.out_channels),
            )

            if conf.two_cond:
                # scale from the encoder's condition
                self.cond_emb_layers = nn.Sequential(
                    nn.SiLU(),
                    linear(conf.cond_emb_channels, conf.out_channels),
                )
            conv = conv_nd(conf.dims,
                           conf.out_channels,
                           conf.out_channels,
                           3,
                           padding=1)
            if conf.use_zero_module:
                conv = zero_module(conv)

            layers = []
            layers += [
                normalization(conf.out_channels),
                nn.SiLU(),
                nn.Dropout(p=conf.dropout),
                conv,
            ]
            self.out_layers = nn.Sequential(*layers)

        if conf.out_channels == conf.channels:
            self.skip_connection = nn.Identity()
        else:
            if conf.use_conv:
                kernel_size = 3
                padding = 1
            else:
                kernel_size = 1
                padding = 0

            self.skip_connection = conv_nd(conf.dims,
                                           conf.channels,
                                           conf.out_channels,
                                           kernel_size,
                                           padding=padding)

    def forward(self, x, emb=None, cond=None, lateral=None):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.

        Args:
            x: input
            lateral: lateral connection from the encoder
        """
        return torch_checkpoint(self._forward, (x, emb, cond, lateral),
                                self.conf.use_checkpoint)

    def _forward(
        self,
        x,
        emb=None,
        cond=None,
        lateral=None,
    ):
        """
        Args:
            lateral: required if "has_lateral" and non-gated; with gated, it can be supplied optionally
        """
        if self.conf.has_lateral:
            assert lateral is not None
            x = th.cat([x, lateral], dim=1)

        if self.updown:
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)

        if self.conf.use_condition:
            if emb is not None:
                emb_out = self.emb_layers(emb).type(h.dtype)
            else:
                emb_out = None

            if self.conf.two_cond:
                if cond is None:
                    cond_out = None
                else:
                    if not isinstance(cond, th.Tensor):
                        assert isinstance(cond, dict)
                        cond = cond['cond']
                    cond_out = self.cond_emb_layers(cond).type(h.dtype)

                if cond_out is not None:
                    while len(cond_out.shape) < len(h.shape):
                        cond_out = cond_out[..., None]
            else:
                cond_out = None

            h = apply_conditions(
                h=h,
                emb=emb_out,
                cond=cond_out,
                layers=self.out_layers,
                scale_bias=1,
                in_channels=self.conf.out_channels,
                up_down_layer=None,
            )
        return self.skip_connection(x) + h


def apply_conditions(
    h,
    emb=None,
    cond=None,
    layers: nn.Sequential = None,
    scale_bias: float = 1,
    in_channels: int = 512,
    up_down_layer: nn.Module = None,
):
    """
    Apply conditions on the feature maps.

    Args:
        emb: time conditional (ready to scale + shift)
        cond: encoder's conditional (ready to scale + shift)
    """
    two_cond = emb is not None and cond is not None

    if emb is not None:
        # adjust shapes to broadcast over the feature map
        while len(emb.shape) < len(h.shape):
            emb = emb[..., None]

    if two_cond:
        while len(cond.shape) < len(h.shape):
            cond = cond[..., None]
        # time first
        scale_shifts = [emb, cond]
    else:
        scale_shifts = [emb]

    # each condition supplies either (scale, shift) or scale only
    for i, each in enumerate(scale_shifts):
        if each is None:
            a = None
            b = None
        else:
            if each.shape[1] == in_channels * 2:
                a, b = th.chunk(each, 2, dim=1)
            else:
                a = each
                b = None
        scale_shifts[i] = (a, b)

    if isinstance(scale_bias, Number):
        biases = [scale_bias] * len(scale_shifts)
    else:
        biases = scale_bias

    # split the layers so the modulation happens right after the first norm
    pre_layers, post_layers = layers[0], layers[1:]

    # post layers will contain only the conv
    mid_layers, post_layers = post_layers[:-2], post_layers[-2:]

    h = pre_layers(h)
    # scale and shift for each condition
    for i, (scale, shift) in enumerate(scale_shifts):
        if scale is not None:
            h = h * (biases[i] + scale)
            if shift is not None:
                h = h + shift
    h = mid_layers(h)

    # upscale or downscale, if any, just before the last conv
    if up_down_layer is not None:
        h = up_down_layer(h)
    h = post_layers(h)
    return h
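
# --- Illustrative sketch (added for exposition; not part of the original commit) ---
# Assuming apply_conditions as defined above: a (2*C)-channel embedding is split
# into scale and shift and modulates h after the first norm layer, AdaGN-style,
# i.e. h <- h * (1 + scale) + shift, before the remaining layers run.
def _demo_apply_conditions():
    C = 8
    h = th.randn(2, C, 16)            # N x C x T feature map (1D case)
    emb = th.randn(2, 2 * C)          # packed [scale, shift]
    layers = nn.Sequential(
        normalization(C), nn.SiLU(), nn.Dropout(0.1),
        conv_nd(1, C, C, 3, padding=1),
    )
    out = apply_conditions(h=h, emb=emb, layers=layers, scale_bias=1,
                           in_channels=C)
    print(out.shape)                  # torch.Size([2, 8, 16])
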

class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 upsampling occurs in the inner-two dimensions.
    """
    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims,
                                self.channels,
                                self.out_channels,
                                3,
                                padding=1)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
                              mode="nearest")
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 downsampling occurs in the inner-two dimensions.
    """
    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
            self.op = conv_nd(dims,
                              self.channels,
                              self.out_channels,
                              3,
                              stride=stride,
                              padding=1)
        else:
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)


class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.

    Originally ported from here, but adapted to the N-d case:
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66
    """
    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        use_new_attention_order=False,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        if use_new_attention_order:
            self.attention = QKVAttention(self.num_heads)
        else:
            self.attention = QKVAttentionLegacy(self.num_heads)

        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        return torch_checkpoint(self._forward, (x, ), self.use_checkpoint)

    def _forward(self, x):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)


def count_flops_attn(model, _x, y):
    """
    A counter for the `thop` package to count the operations in an
    attention operation.
    Meant to be used like:
        macs, params = thop.profile(
            model,
            inputs=(inputs, timestamps),
            custom_ops={QKVAttention: QKVAttention.count_flops},
        )
    """
    b, c, *spatial = y[0].shape
    num_spatial = int(np.prod(spatial))
    matmul_ops = 2 * b * (num_spatial**2) * c
    model.total_ops += th.DoubleTensor([matmul_ops])


class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping.
    """
    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.

        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch,
                                                                       dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts", q * scale,
            k * scale)
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum("bts,bcs->bct", weight, v)
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)


class QKVAttention(nn.Module):
    """
    A module which performs QKV attention and splits in a different order.
    """
    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.

        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.chunk(3, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts",
            (q * scale).view(bs * self.n_heads, ch, length),
            (k * scale).view(bs * self.n_heads, ch, length),
        )
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum("bts,bcs->bct", weight,
                      v.reshape(bs * self.n_heads, ch, length))
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
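
# --- Illustrative sketch (added for exposition; not part of the original commit) ---
# Assuming the two attention modules above: both consume a packed qkv tensor
# and return the same output shape; they differ only in how heads are split.
def _demo_qkv_attention():
    bs, heads, ch, length = 2, 4, 16, 10
    qkv = th.randn(bs, 3 * heads * ch, length)   # packed Q, K, V
    out_new = QKVAttention(heads)(qkv)           # split qkv first, then heads
    out_legacy = QKVAttentionLegacy(heads)(qkv)  # split heads first, then qkv
    print(out_new.shape, out_legacy.shape)       # both: torch.Size([2, 64, 10])
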

class AttentionPool2d(nn.Module):
    """
    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
    """
    def __init__(
        self,
        spacial_dim: int,
        embed_dim: int,
        num_heads_channels: int,
        output_dim: int = None,
    ):
        super().__init__()
        self.positional_embedding = nn.Parameter(
            th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)
        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
        self.num_heads = embed_dim // num_heads_channels
        self.attention = QKVAttention(self.num_heads)

    def forward(self, x):
        b, c, *_spatial = x.shape
        x = x.reshape(b, c, -1)
        x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)
        x = x + self.positional_embedding[None, :, :].to(x.dtype)
        x = self.qkv_proj(x)
        x = self.attention(x)
        x = self.c_proj(x)
        return x[:, :, 0]

model/latentnet.py (new file, 184 lines)
@@ -0,0 +1,184 @@

import math
from dataclasses import dataclass
from enum import Enum
from typing import NamedTuple, Tuple

import torch
from choices import *
from config_base import BaseConfig
from torch import nn
from torch.nn import init

from .blocks import *
from .nn import timestep_embedding
from .unet import *


class LatentNetType(Enum):
    none = 'none'
    skip = 'skip'


class LatentNetReturn(NamedTuple):
    pred: torch.Tensor = None


@dataclass
class MLPSkipNetConfig(BaseConfig):
    """
    default MLP for the latent DPM in the paper!
    """
    num_channels: int
    skip_layers: Tuple[int]
    num_hid_channels: int
    num_layers: int
    num_time_emb_channels: int = 64
    activation: Activation = Activation.silu
    use_norm: bool = True
    condition_bias: float = 1
    dropout: float = 0
    last_act: Activation = Activation.none
    num_time_layers: int = 2
    time_last_act: bool = False

    def make_model(self):
        return MLPSkipNet(self)


class MLPSkipNet(nn.Module):
    """
    concat x to hidden layers

    default MLP for the latent DPM in the paper!
    """
    def __init__(self, conf: MLPSkipNetConfig):
        super().__init__()
        self.conf = conf

        layers = []
        for i in range(conf.num_time_layers):
            if i == 0:
                a = conf.num_time_emb_channels
                b = conf.num_channels
            else:
                a = conf.num_channels
                b = conf.num_channels
            layers.append(nn.Linear(a, b))
            if i < conf.num_time_layers - 1 or conf.time_last_act:
                layers.append(conf.activation.get_act())
        self.time_embed = nn.Sequential(*layers)

        self.layers = nn.ModuleList([])
        for i in range(conf.num_layers):
            if i == 0:
                act = conf.activation
                norm = conf.use_norm
                cond = True
                a, b = conf.num_channels, conf.num_hid_channels
                dropout = conf.dropout
            elif i == conf.num_layers - 1:
                act = Activation.none
                norm = False
                cond = False
                a, b = conf.num_hid_channels, conf.num_channels
                dropout = 0
            else:
                act = conf.activation
                norm = conf.use_norm
                cond = True
                a, b = conf.num_hid_channels, conf.num_hid_channels
                dropout = conf.dropout

            if i in conf.skip_layers:
                a += conf.num_channels

            self.layers.append(
                MLPLNAct(
                    a,
                    b,
                    norm=norm,
                    activation=act,
                    cond_channels=conf.num_channels,
                    use_cond=cond,
                    condition_bias=conf.condition_bias,
                    dropout=dropout,
                ))
        self.last_act = conf.last_act.get_act()

    def forward(self, x, t, **kwargs):
        t = timestep_embedding(t, self.conf.num_time_emb_channels)
        cond = self.time_embed(t)
        h = x
        for i in range(len(self.layers)):
            if i in self.conf.skip_layers:
                h = torch.cat([h, x], dim=1)
            h = self.layers[i].forward(x=h, cond=cond)
        h = self.last_act(h)
        return LatentNetReturn(h)
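
# --- Illustrative sketch (added for exposition; not part of the original commit) ---
# Assuming MLPSkipNet as defined above (and that Activation comes from the
# project's `choices` module): a small latent DPM denoiser over 32-dim latents.
def _demo_mlp_skip_net():
    conf = MLPSkipNetConfig(
        num_channels=32,       # latent dimension
        skip_layers=(1, 2),    # layers that re-concatenate the input
        num_hid_channels=128,
        num_layers=4,
    )
    net = conf.make_model()
    z = torch.randn(8, 32)                     # batch of noised latents
    t = torch.randint(0, 1000, (8, )).float()  # diffusion timesteps
    out = net(z, t)
    print(out.pred.shape)                      # torch.Size([8, 32])
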

class MLPLNAct(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        norm: bool,
        use_cond: bool,
        activation: Activation,
        cond_channels: int,
        condition_bias: float = 0,
        dropout: float = 0,
    ):
        super().__init__()
        self.activation = activation
        self.condition_bias = condition_bias
        self.use_cond = use_cond

        self.linear = nn.Linear(in_channels, out_channels)
        self.act = activation.get_act()
        if self.use_cond:
            self.linear_emb = nn.Linear(cond_channels, out_channels)
            self.cond_layers = nn.Sequential(self.act, self.linear_emb)
        if norm:
            self.norm = nn.LayerNorm(out_channels)
        else:
            self.norm = nn.Identity()

        if dropout > 0:
            self.dropout = nn.Dropout(p=dropout)
        else:
            self.dropout = nn.Identity()

        self.init_weights()

    def init_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                if self.activation == Activation.relu:
                    init.kaiming_normal_(module.weight,
                                         a=0,
                                         nonlinearity='relu')
                elif self.activation == Activation.lrelu:
                    init.kaiming_normal_(module.weight,
                                         a=0.2,
                                         nonlinearity='leaky_relu')
                elif self.activation == Activation.silu:
                    init.kaiming_normal_(module.weight,
                                         a=0,
                                         nonlinearity='relu')
                else:
                    pass

    def forward(self, x, cond=None):
        x = self.linear(x)
        if self.use_cond:
            cond = self.cond_layers(cond)
            cond = (cond, None)

            # scale (and shift, if provided) first, then norm
            x = x * (self.condition_bias + cond[0])
            if cond[1] is not None:
                x = x + cond[1]
            x = self.norm(x)
        else:
            x = self.norm(x)
        x = self.act(x)
        x = self.dropout(x)
        return x

model/nn.py (new executable file, 135 lines)
@@ -0,0 +1,135 @@

"""
|
||||
Various utilities for neural networks.
|
||||
"""
|
||||
import math
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
|
||||
class SiLU(nn.Module):
|
||||
# @th.jit.script
|
||||
def forward(self, x):
|
||||
return x * th.sigmoid(x)
|
||||
|
||||
|
||||
class GroupNorm32(nn.GroupNorm):
|
||||
def forward(self, x):
|
||||
return super().forward(x.float()).type(x.dtype)
|
||||
|
||||
|
||||
def conv_nd(dims, *args, **kwargs):
|
||||
"""
|
||||
Create a 1D, 2D, or 3D convolution module.
|
||||
"""
|
||||
assert dims==1
|
||||
if dims == 1:
|
||||
return nn.Conv1d(*args, **kwargs)
|
||||
elif dims == 2:
|
||||
return nn.Conv2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return nn.Conv3d(*args, **kwargs)
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
|
||||
def linear(*args, **kwargs):
|
||||
"""
|
||||
Create a linear module.
|
||||
"""
|
||||
return nn.Linear(*args, **kwargs)
|
||||
|
||||
|
||||
def avg_pool_nd(dims, *args, **kwargs):
|
||||
"""
|
||||
Create a 1D, 2D, or 3D average pooling module.
|
||||
"""
|
||||
assert dims==1
|
||||
if dims == 1:
|
||||
return nn.AvgPool1d(*args, **kwargs)
|
||||
elif dims == 2:
|
||||
return nn.AvgPool2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return nn.AvgPool3d(*args, **kwargs)
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
|
||||
def update_ema(target_params, source_params, rate=0.99):
|
||||
"""
|
||||
Update target parameters to be closer to those of source parameters using
|
||||
an exponential moving average.
|
||||
|
||||
:param target_params: the target parameter sequence.
|
||||
:param source_params: the source parameter sequence.
|
||||
:param rate: the EMA rate (closer to 1 means slower).
|
||||
"""
|
||||
for targ, src in zip(target_params, source_params):
|
||||
targ.detach().mul_(rate).add_(src, alpha=1 - rate)
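
# --- Illustrative sketch (added for exposition; not part of the original commit) ---
# Assuming update_ema as defined above: each call moves the EMA copy a small
# step toward the live model, i.e. targ <- rate * targ + (1 - rate) * src.
def _demo_update_ema():
    model = nn.Linear(4, 4)
    ema_model = nn.Linear(4, 4)
    ema_model.load_state_dict(model.state_dict())  # start from identical weights
    # ... one optimizer step on `model` would happen here ...
    update_ema(ema_model.parameters(), model.parameters(), rate=0.99)
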

def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def scale_module(module, scale):
    """
    Scale the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().mul_(scale)
    return module


def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))


def normalization(channels):
    """
    Make a standard normalization layer.

    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    return GroupNorm32(min(32, channels), channels)


def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    freqs = th.exp(-math.log(max_period) *
                   th.arange(start=0, end=half, dtype=th.float32) /
                   half).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
    if dim % 2:
        embedding = th.cat(
            [embedding, th.zeros_like(embedding[:, :1])], dim=-1)
    return embedding
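
# --- Illustrative sketch (added for exposition; not part of the original commit) ---
# Assuming timestep_embedding as defined above: each timestep maps to a unique
# [cos | sin] vector spanning frequencies from 1 down to 1/max_period.
def _demo_timestep_embedding():
    t = th.tensor([0., 1., 500., 999.])
    emb = timestep_embedding(t, dim=128)
    print(emb.shape)     # torch.Size([4, 128])
    print(emb[0, :64])   # at t=0 all cosine components are 1 (sines are 0)
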

def torch_checkpoint(func, args, flag, preserve_rng_state=False):
    # torch's gradient checkpointing works with automatic mixed precision, given torch >= 1.8
    if flag:
        return torch.utils.checkpoint.checkpoint(
            func, *args, preserve_rng_state=preserve_rng_state)
    else:
        return func(*args)

model/unet.py (new file, 505 lines)
@@ -0,0 +1,505 @@

import math
from dataclasses import dataclass
from numbers import Number
from typing import NamedTuple, Tuple, Union

import numpy as np
import torch as th
from torch import nn
import torch.nn.functional as F
from choices import *
from config_base import BaseConfig
from .blocks import *

from .nn import (conv_nd, linear, normalization, timestep_embedding,
                 torch_checkpoint, zero_module)


@dataclass
class BeatGANsUNetConfig(BaseConfig):
    image_size: int = 64
    in_channels: int = 2
    model_channels: int = 64
    out_channels: int = 2
    num_res_blocks: int = 2
    num_input_res_blocks: int = None
    embed_channels: int = 512
    attention_resolutions: Tuple[int] = (16, )
    time_embed_channels: int = None
    dropout: float = 0.1
    channel_mult: Tuple[int] = (1, 2, 4, 8)
    input_channel_mult: Tuple[int] = None
    conv_resample: bool = True
    dims: int = 2
    num_classes: int = None
    use_checkpoint: bool = False
    num_heads: int = 1
    num_head_channels: int = -1
    num_heads_upsample: int = -1
    resblock_updown: bool = True
    use_new_attention_order: bool = False
    resnet_two_cond: bool = False
    resnet_cond_channels: int = None
    resnet_use_zero_module: bool = True
    attn_checkpoint: bool = False

    num_users: int = None

    def make_model(self):
        return BeatGANsUNetModel(self)


class BeatGANsUNetModel(nn.Module):
    def __init__(self, conf: BeatGANsUNetConfig):
        super().__init__()
        self.conf = conf

        if conf.num_heads_upsample == -1:
            self.num_heads_upsample = conf.num_heads

        self.dtype = th.float32

        self.time_emb_channels = conf.time_embed_channels or conf.model_channels
        self.time_embed = nn.Sequential(
            linear(self.time_emb_channels, conf.embed_channels),
            nn.SiLU(),
            linear(conf.embed_channels, conf.embed_channels),
        )

        if conf.num_classes is not None:
            self.label_emb = nn.Embedding(conf.num_classes,
                                          conf.embed_channels)

        ch = input_ch = int(conf.channel_mult[0] * conf.model_channels)
        self.input_blocks = nn.ModuleList([
            TimestepEmbedSequential(
                conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1))
        ])

        kwargs = dict(
            use_condition=True,
            two_cond=conf.resnet_two_cond,
            use_zero_module=conf.resnet_use_zero_module,
            cond_emb_channels=conf.resnet_cond_channels,
        )

        self._feature_size = ch

        # input_block_chans = [ch]
        input_block_chans = [[] for _ in range(len(conf.channel_mult))]
        input_block_chans[0].append(ch)

        # number of blocks at each resolution
        self.input_num_blocks = [0 for _ in range(len(conf.channel_mult))]
        self.input_num_blocks[0] = 1
        self.output_num_blocks = [0 for _ in range(len(conf.channel_mult))]

        ds = 1
        resolution = conf.image_size
        for level, mult in enumerate(conf.input_channel_mult
                                     or conf.channel_mult):
            for _ in range(conf.num_input_res_blocks or conf.num_res_blocks):
                layers = [
                    ResBlockConfig(
                        ch,
                        conf.embed_channels,
                        conf.dropout,
                        out_channels=int(mult * conf.model_channels),
                        dims=conf.dims,
                        use_checkpoint=conf.use_checkpoint,
                        **kwargs,
                    ).make_model()
                ]
                ch = int(mult * conf.model_channels)
                if resolution in conf.attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=conf.use_checkpoint
                            or conf.attn_checkpoint,
                            num_heads=conf.num_heads,
                            num_head_channels=conf.num_head_channels,
                            use_new_attention_order=conf.
                            use_new_attention_order,
                        ))
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans[level].append(ch)
                self.input_num_blocks[level] += 1
            if level != len(conf.channel_mult) - 1:
                resolution //= 2
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlockConfig(
                            ch,
                            conf.embed_channels,
                            conf.dropout,
                            out_channels=out_ch,
                            dims=conf.dims,
                            use_checkpoint=conf.use_checkpoint,
                            down=True,
                            **kwargs,
                        ).make_model() if conf.
                        resblock_updown else Downsample(ch,
                                                        conf.conv_resample,
                                                        dims=conf.dims,
                                                        out_channels=out_ch)))
                ch = out_ch
                input_block_chans[level + 1].append(ch)
                self.input_num_blocks[level + 1] += 1
                ds *= 2
                self._feature_size += ch

        self.middle_block = TimestepEmbedSequential(
            ResBlockConfig(
                ch,
                conf.embed_channels,
                conf.dropout,
                dims=conf.dims,
                use_checkpoint=conf.use_checkpoint,
                **kwargs,
            ).make_model(),
            AttentionBlock(
                ch,
                use_checkpoint=conf.use_checkpoint or conf.attn_checkpoint,
                num_heads=conf.num_heads,
                num_head_channels=conf.num_head_channels,
                use_new_attention_order=conf.use_new_attention_order,
            ),
            ResBlockConfig(
                ch,
                conf.embed_channels,
                conf.dropout,
                dims=conf.dims,
                use_checkpoint=conf.use_checkpoint,
                **kwargs,
            ).make_model(),
        )
        self._feature_size += ch

        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(conf.channel_mult))[::-1]:
            for i in range(conf.num_res_blocks + 1):
                try:
                    ich = input_block_chans[level].pop()
                except IndexError:
                    # no lateral (skip) connection available for this block
                    ich = 0
                layers = [
                    ResBlockConfig(
                        channels=ch + ich,
                        emb_channels=conf.embed_channels,
                        dropout=conf.dropout,
                        out_channels=int(conf.model_channels * mult),
                        dims=conf.dims,
                        use_checkpoint=conf.use_checkpoint,
                        has_lateral=True if ich > 0 else False,
                        lateral_channels=None,
                        **kwargs,
                    ).make_model()
                ]
                ch = int(conf.model_channels * mult)
                if resolution in conf.attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=conf.use_checkpoint
                            or conf.attn_checkpoint,
                            num_heads=self.num_heads_upsample,
                            num_head_channels=conf.num_head_channels,
                            use_new_attention_order=conf.
                            use_new_attention_order,
                        ))
                if level and i == conf.num_res_blocks:
                    resolution *= 2
                    out_ch = ch
                    layers.append(
                        ResBlockConfig(
                            ch,
                            conf.embed_channels,
                            conf.dropout,
                            out_channels=out_ch,
                            dims=conf.dims,
                            use_checkpoint=conf.use_checkpoint,
                            up=True,
                            **kwargs,
                        ).make_model() if (
                            conf.resblock_updown
                        ) else Upsample(ch,
                                        conf.conv_resample,
                                        dims=conf.dims,
                                        out_channels=out_ch))
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self.output_num_blocks[level] += 1
                self._feature_size += ch

        if conf.resnet_use_zero_module:
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                zero_module(
                    conv_nd(conf.dims,
                            input_ch,
                            conf.out_channels,
                            3,
                            padding=1)),
            )
        else:
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                conv_nd(conf.dims, input_ch, conf.out_channels, 3, padding=1),
            )

    def forward(self, x, t, y=None, **kwargs):
        """
        Apply the model to an input batch.

        :param x: an [N x C x ...] Tensor of inputs.
        :param t: a 1-D batch of timesteps.
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert (y is not None) == (
            self.conf.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"

        # hs = []
        hs = [[] for _ in range(len(self.conf.channel_mult))]
        emb = self.time_embed(timestep_embedding(t, self.time_emb_channels))

        if self.conf.num_classes is not None:
            raise NotImplementedError()

        h = x.type(self.dtype)
        k = 0
        for i in range(len(self.input_num_blocks)):
            for j in range(self.input_num_blocks[i]):
                h = self.input_blocks[k](h, emb=emb)
                hs[i].append(h)
                k += 1
        assert k == len(self.input_blocks)

        # middle blocks
        h = self.middle_block(h, emb=emb)

        # output blocks
        k = 0
        for i in range(len(self.output_num_blocks)):
            for j in range(self.output_num_blocks[i]):
                # take the lateral connection from the matching encoder level
                try:
                    lateral = hs[-i - 1].pop()
                except IndexError:
                    lateral = None
                h = self.output_blocks[k](h, emb=emb, lateral=lateral)
                k += 1

        h = h.type(x.dtype)
        pred = self.out(h)
        return Return(pred=pred)
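
# --- Illustrative sketch (added for exposition; not part of the original commit) ---
# Assuming the config and model above, and that conv_nd in model/nn.py
# currently only permits dims == 1: a forward pass on dummy 1D signals.
def _demo_unet():
    conf = BeatGANsUNetConfig(
        image_size=64,               # length of the 1D signal
        in_channels=2,
        model_channels=32,
        out_channels=2,
        dims=1,                      # 1D convolutions
        attention_resolutions=(16, ),
        channel_mult=(1, 2, 4),
    )
    model = conf.make_model()
    x = th.randn(2, 2, 64)                  # N x C x T
    t = th.randint(0, 1000, (2, )).float()
    out = model(x, t)
    print(out.pred.shape)                   # torch.Size([2, 2, 64])
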

class Return(NamedTuple):
    pred: th.Tensor


@dataclass
class BeatGANsEncoderConfig(BaseConfig):
    image_size: int
    in_channels: int
    model_channels: int
    out_hid_channels: int
    out_channels: int
    num_res_blocks: int
    attention_resolutions: Tuple[int]
    dropout: float = 0
    channel_mult: Tuple[int] = (1, 2, 4, 8)
    use_time_condition: bool = True
    conv_resample: bool = True
    dims: int = 2
    use_checkpoint: bool = False
    num_heads: int = 1
    num_head_channels: int = -1
    resblock_updown: bool = False
    use_new_attention_order: bool = False
    pool: str = 'adaptivenonzero'

    def make_model(self):
        return BeatGANsEncoderModel(self)


class BeatGANsEncoderModel(nn.Module):
    """
    The half UNet model with attention and timestep embedding.

    For usage, see UNet.
    """
    def __init__(self, conf: BeatGANsEncoderConfig):
        super().__init__()
        self.conf = conf
        self.dtype = th.float32

        if conf.use_time_condition:
            time_embed_dim = conf.model_channels * 4
            self.time_embed = nn.Sequential(
                linear(conf.model_channels, time_embed_dim),
                nn.SiLU(),
                linear(time_embed_dim, time_embed_dim),
            )
        else:
            time_embed_dim = None

        ch = int(conf.channel_mult[0] * conf.model_channels)
        self.input_blocks = nn.ModuleList([
            TimestepEmbedSequential(
                conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1))
        ])
        self._feature_size = ch
        input_block_chans = [ch]
        ds = 1
        resolution = conf.image_size
        for level, mult in enumerate(conf.channel_mult):
            for _ in range(conf.num_res_blocks):
                layers = [
                    ResBlockConfig(
                        ch,
                        time_embed_dim,
                        conf.dropout,
                        out_channels=int(mult * conf.model_channels),
                        dims=conf.dims,
                        use_condition=conf.use_time_condition,
                        use_checkpoint=conf.use_checkpoint,
                    ).make_model()
                ]
                ch = int(mult * conf.model_channels)
                if resolution in conf.attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=conf.use_checkpoint,
                            num_heads=conf.num_heads,
                            num_head_channels=conf.num_head_channels,
                            use_new_attention_order=conf.
                            use_new_attention_order,
                        ))
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(conf.channel_mult) - 1:
                resolution //= 2
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlockConfig(
                            ch,
                            time_embed_dim,
                            conf.dropout,
                            out_channels=out_ch,
                            dims=conf.dims,
                            use_condition=conf.use_time_condition,
                            use_checkpoint=conf.use_checkpoint,
                            down=True,
                        ).make_model() if (
                            conf.resblock_updown
                        ) else Downsample(ch,
                                          conf.conv_resample,
                                          dims=conf.dims,
                                          out_channels=out_ch)))
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        self.middle_block = TimestepEmbedSequential(
            ResBlockConfig(
                ch,
                time_embed_dim,
                conf.dropout,
                dims=conf.dims,
                use_condition=conf.use_time_condition,
                use_checkpoint=conf.use_checkpoint,
            ).make_model(),
            AttentionBlock(
                ch,
                use_checkpoint=conf.use_checkpoint,
                num_heads=conf.num_heads,
                num_head_channels=conf.num_head_channels,
                use_new_attention_order=conf.use_new_attention_order,
            ),
            ResBlockConfig(
                ch,
                time_embed_dim,
                conf.dropout,
                dims=conf.dims,
                use_condition=conf.use_time_condition,
                use_checkpoint=conf.use_checkpoint,
            ).make_model(),
        )
        self._feature_size += ch
        if conf.pool == "adaptivenonzero":
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                nn.AdaptiveAvgPool1d(1),
                conv_nd(conf.dims, ch, conf.out_channels, 1),
                nn.Flatten(),
            )
        else:
            raise NotImplementedError(f"Unexpected {conf.pool} pooling")

    def forward(self, x, t=None, return_2d_feature=False):
        """
        Apply the model to an input batch.

        :param x: an [N x C x ...] Tensor of inputs.
        :param t: a 1-D batch of timesteps.
        :return: an [N x K] Tensor of outputs.
        """
        if self.conf.use_time_condition:
            emb = self.time_embed(timestep_embedding(t, self.conf.model_channels))
        else:
            emb = None

        results = []
        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb=emb)
            if self.conf.pool.startswith("spatial"):
                results.append(h.type(x.dtype).mean(dim=(2, 3)))
        h = self.middle_block(h, emb=emb)
        if self.conf.pool.startswith("spatial"):
            results.append(h.type(x.dtype).mean(dim=(2, 3)))
            h = th.cat(results, axis=-1)
        else:
            h = h.type(x.dtype)

        h_2d = h
        h = self.out(h)

        if return_2d_feature:
            return h, h_2d
        else:
            return h

    def forward_flatten(self, x):
        """
        Transform the last 2D feature into a flattened vector.
        """
        h = self.out(x)
        return h


class SuperResModel(BeatGANsUNetModel):
    """
    A UNetModel that performs super-resolution.

    Expects an extra kwarg `low_res` to condition on a low-resolution image.
    """
    def __init__(self, image_size, in_channels, *args, **kwargs):
        super().__init__(image_size, in_channels * 2, *args, **kwargs)

    def forward(self, x, timesteps, low_res=None, **kwargs):
        _, _, new_height, new_width = x.shape
        upsampled = F.interpolate(low_res, (new_height, new_width),
                                  mode="bilinear")
        x = th.cat([x, upsampled], dim=1)
        return super().forward(x, timesteps, **kwargs)

model/unet_autoenc.py (new file, 310 lines)
@@ -0,0 +1,310 @@

import torch
from torch import Tensor, nn
from torch.nn.functional import silu

from .latentnet import *
from .unet import *
from choices import *


@dataclass
class BeatGANsAutoencConfig(BeatGANsUNetConfig):
    enc_out_channels: int = 512
    enc_attn_resolutions: Tuple[int] = None
    enc_pool: str = 'depthconv'
    enc_num_res_block: int = 2
    enc_channel_mult: Tuple[int] = None
    enc_grad_checkpoint: bool = False
    latent_net_conf: MLPSkipNetConfig = None

    def make_model(self):
        return BeatGANsAutoencModel(self)


class BeatGANsAutoencModel(BeatGANsUNetModel):
    def __init__(self, conf: BeatGANsAutoencConfig):
        super().__init__(conf)
        self.conf = conf

        self.time_embed = TimeStyleSeperateEmbed(
            time_channels=conf.model_channels,
            time_out_channels=conf.embed_channels,
        )

        self.encoder = BeatGANsEncoderConfig(
            image_size=conf.image_size,
            in_channels=conf.in_channels,
            model_channels=conf.model_channels,
            out_hid_channels=conf.enc_out_channels,
            out_channels=conf.enc_out_channels,
            num_res_blocks=conf.enc_num_res_block,
            attention_resolutions=(conf.enc_attn_resolutions
                                   or conf.attention_resolutions),
            dropout=conf.dropout,
            channel_mult=conf.enc_channel_mult or conf.channel_mult,
            use_time_condition=False,
            conv_resample=conf.conv_resample,
            dims=conf.dims,
            use_checkpoint=conf.use_checkpoint or conf.enc_grad_checkpoint,
            num_heads=conf.num_heads,
            num_head_channels=conf.num_head_channels,
            resblock_updown=conf.resblock_updown,
            use_new_attention_order=conf.use_new_attention_order,
            pool=conf.enc_pool,
        ).make_model()

        # split the semantic code in half: a user-specific part and a user-invariant part
        self.user_classifier = UserClassifier(conf.enc_out_channels // 2, conf.num_users)
        self.non_user_classifier = UserClassifierGradientReverse(conf.enc_out_channels // 2, conf.num_users)

        if conf.latent_net_conf is not None:
            self.latent_net = conf.latent_net_conf.make_model()

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick to sample from N(mu, var) from
        N(0,1).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        assert self.conf.is_stochastic
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def sample_z(self, n: int, device):
        assert self.conf.is_stochastic
        return torch.randn(n, self.conf.enc_out_channels, device=device)

    def noise_to_cond(self, noise: Tensor):
        raise NotImplementedError()
        assert self.conf.noise_net_conf is not None
        return self.noise_net.forward(noise)

    def encode(self, x):
        cond = self.encoder.forward(x)
        return {'cond': cond}

    @property
    def stylespace_sizes(self):
        modules = list(self.input_blocks.modules()) + list(
            self.middle_block.modules()) + list(self.output_blocks.modules())
        sizes = []
        for module in modules:
            if isinstance(module, ResBlock):
                linear = module.cond_emb_layers[-1]
                sizes.append(linear.weight.shape[0])
        return sizes

    def encode_stylespace(self, x, return_vector: bool = True):
        """
        Encode to the style space.
        """
        modules = list(self.input_blocks.modules()) + list(
            self.middle_block.modules()) + list(self.output_blocks.modules())
        cond = self.encoder.forward(x)
        S = []
        for module in modules:
            if isinstance(module, ResBlock):
                # extract the style vector of each resblock
                s = module.cond_emb_layers.forward(cond)
                S.append(s)

        if return_vector:
            return torch.cat(S, dim=1)
        else:
            return S

    def forward(self,
                x,
                t,
                y=None,
                x_start=None,
                cond=None,
                style=None,
                noise=None,
                t_cond=None,
                **kwargs):
        """
        Apply the model to an input batch.

        Args:
            x_start: the original image to encode
            cond: output of the encoder
            noise: random noise (to predict the cond)
        """
        if t_cond is None:
            t_cond = t

        if noise is not None:
            # if the noise is given, we predict the cond from noise
            cond = self.noise_to_cond(noise)

        if cond is None:
            if x is not None:
                assert len(x) == len(x_start), f'{len(x)} != {len(x_start)}'

            tmp = self.encode(x_start)
            cond = tmp['cond']

        if t is not None:
            _t_emb = timestep_embedding(t, self.conf.model_channels)
            _t_cond_emb = timestep_embedding(t_cond, self.conf.model_channels)
        else:
            _t_emb = None
            _t_cond_emb = None

        if self.conf.resnet_two_cond:
            res = self.time_embed.forward(
                time_emb=_t_emb,
                cond=cond,
                time_cond_emb=_t_cond_emb,
            )
        else:
            raise NotImplementedError()

        if self.conf.resnet_two_cond:
            # two conds: first = time emb, second = encoder's cond emb
            emb = res.time_emb
            cond_emb = res.emb
        else:
            emb = res.emb
            cond_emb = None

        # override the style if given
        style = style or res.style

        assert (y is not None) == (
            self.conf.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"

        if self.conf.num_classes is not None:
            raise NotImplementedError()

        enc_time_emb = emb
        mid_time_emb = emb
        dec_time_emb = emb
        enc_cond_emb = cond_emb
        mid_cond_emb = cond_emb
        dec_cond_emb = cond_emb

        # user/non-user predictions are only available when num_users is set
        user_pred = None
        non_user_pred = None
        if self.conf.num_users is not None:
            user_pred = self.user_classifier(cond_emb[:, :self.conf.enc_out_channels // 2])
            non_user_pred = self.non_user_classifier(cond_emb[:, self.conf.enc_out_channels // 2:])

        hs = [[] for _ in range(len(self.conf.channel_mult))]

        if x is not None:
            h = x.type(self.dtype)

            # input blocks
            k = 0
            for i in range(len(self.input_num_blocks)):
                for j in range(self.input_num_blocks[i]):
                    h = self.input_blocks[k](h,
                                             emb=enc_time_emb,
                                             cond=enc_cond_emb)
                    hs[i].append(h)
                    k += 1
            assert k == len(self.input_blocks)

            # middle blocks
            h = self.middle_block(h, emb=mid_time_emb, cond=mid_cond_emb)
        else:
            # no lateral connections
            h = None
            hs = [[] for _ in range(len(self.conf.channel_mult))]

        # output blocks
        k = 0
        for i in range(len(self.output_num_blocks)):
            for j in range(self.output_num_blocks[i]):
                try:
                    lateral = hs[-i - 1].pop()
                except IndexError:
                    lateral = None

                h = self.output_blocks[k](h,
                                          emb=dec_time_emb,
                                          cond=dec_cond_emb,
                                          lateral=lateral)
                k += 1

        pred = self.out(h)

        return AutoencReturn(pred=pred, cond=cond, user_pred=user_pred, non_user_pred=non_user_pred)


class UserClassifier(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(in_channels, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        return self.fc(x)


class GradReverse(torch.autograd.Function):
    """
    Implement the gradient reversal layer for the convenience of domain adaptation neural networks.
    The forward part is the identity function while the backward part is the negative function.
    """
    @staticmethod
    def forward(ctx, x):
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.neg()
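
# --- Illustrative sketch (added for exposition; not part of the original commit) ---
# Assuming GradReverse as defined above: the forward pass is the identity,
# but the backward pass flips the sign of the incoming gradient.
def _demo_grad_reverse():
    x = torch.ones(3, requires_grad=True)
    y = GradReverse.apply(x)   # y == x in the forward pass
    y.sum().backward()
    print(x.grad)              # tensor([-1., -1., -1.])
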

class GradientReversalLayer(nn.Module):
    def __init__(self):
        super(GradientReversalLayer, self).__init__()

    def forward(self, inputs):
        return GradReverse.apply(inputs)


class UserClassifierGradientReverse(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.grl = GradientReversalLayer()
        self.fc = UserClassifier(in_channels, num_classes)

    def forward(self, x):
        x = self.grl(x)
        return self.fc(x)
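
# --- Illustrative sketch (added for exposition; not part of the original commit) ---
# Assuming the two classifiers above: training both heads on the same user
# labels pushes one half of the semantic code to encode user identity, while
# the gradient reversal pushes the other half to become user-invariant.
def _demo_adversarial_split():
    code = torch.randn(8, 512, requires_grad=True)  # semantic code from the encoder
    user_head = UserClassifier(256, num_classes=10)
    non_user_head = UserClassifierGradientReverse(256, num_classes=10)
    p_user = user_head(code[:, :256])               # first half: predict the user
    p_non_user = non_user_head(code[:, 256:])       # second half: reversed gradients
    print(p_user.shape, p_non_user.shape)           # (8, 10), (8, 10)
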

class AutoencReturn(NamedTuple):
    pred: Tensor
    cond: Tensor = None
    user_pred: Tensor = None
    non_user_pred: Tensor = None


class EmbedReturn(NamedTuple):
    emb: Tensor = None
    time_emb: Tensor = None
    style: Tensor = None


class TimeStyleSeperateEmbed(nn.Module):
    def __init__(self, time_channels, time_out_channels):
        super().__init__()
        self.time_embed = nn.Sequential(
            linear(time_channels, time_out_channels),
            nn.SiLU(),
            linear(time_out_channels, time_out_channels),
        )
        self.cond_combine = nn.Sequential(
            nn.Linear(time_out_channels * 2, time_out_channels),
            nn.SiLU()
        )
        self.style = nn.Identity()

    def forward(self, time_emb=None, cond=None, **kwargs):
        if time_emb is not None:
            time_emb = self.time_embed(time_emb)

        style = self.style(cond)
        return EmbedReturn(emb=style, time_emb=time_emb, style=style)
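
# --- Illustrative sketch (added for exposition; not part of the original commit) ---
# Assuming TimeStyleSeperateEmbed as defined above: the time embedding is
# projected by an MLP while the semantic condition passes through unchanged
# and is reused as both `emb` and `style`.
def _demo_time_style_embed():
    embed = TimeStyleSeperateEmbed(time_channels=64, time_out_channels=512)
    t_emb = torch.randn(4, 64)    # sinusoidal timestep embedding
    cond = torch.randn(4, 512)    # encoder's semantic code
    res = embed.forward(time_emb=t_emb, cond=cond)
    print(res.time_emb.shape, res.emb.shape)  # (4, 512), (4, 512)
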