# DisMouse/model/unet.py
# Original listing: 505 lines, 18 KiB, Python (web viewer: Raw | Normal View | History)
# Snapshot timestamp: 2024-10-08 14:18:47 +02:00
import math
from dataclasses import dataclass
from numbers import Number
from typing import NamedTuple, Tuple, Union
import numpy as np
import torch as th
from torch import nn
import torch.nn.functional as F
from choices import *
from config_base import BaseConfig
from .blocks import *
from .nn import (conv_nd, linear, normalization, timestep_embedding,
torch_checkpoint, zero_module)
@dataclass
class BeatGANsUNetConfig(BaseConfig):
    """
    Configuration for ``BeatGANsUNetModel``; ``make_model()`` builds the model.

    This is a dataclass, so the field declaration order below defines the
    positional order of the generated ``__init__``.
    """
    # Starting resolution; divided by 2 per encoder level (multiplied back in
    # the decoder) and compared against ``attention_resolutions``.
    image_size: int = 64
    # Channels of the input tensor (first conv of the encoder).
    in_channels: int = 2
    # Base width: level widths are ``mult * model_channels``. Also the default
    # size of the timestep embedding (see ``time_embed_channels``).
    model_channels: int = 64
    # Channels produced by the final output conv.
    out_channels: int = 2
    # Decoder builds ``num_res_blocks + 1`` blocks per level; also the encoder
    # count per level unless ``num_input_res_blocks`` is set.
    num_res_blocks: int = 2
    # Optional encoder-only override of ``num_res_blocks`` (None -> fall back).
    num_input_res_blocks: int = None
    # Output width of the time-embedding MLP; passed to every ResBlock.
    embed_channels: int = 512
    # Resolutions at which an AttentionBlock follows each ResBlock.
    # NOTE(review): annotated Tuple[int] but holds a variable-length tuple.
    attention_resolutions: Tuple[int] = (16, )
    # Size of the timestep embedding fed to the time-embedding MLP
    # (None -> model_channels).
    time_embed_channels: int = None
    # Passed to every ResBlock.
    dropout: float = 0.1
    # Width multipliers, one per resolution level; the number of entries is
    # the number of levels.
    channel_mult: Tuple[int] = (1, 2, 4, 8)
    # Optional encoder-only override of ``channel_mult`` (None -> fall back).
    input_channel_mult: Tuple[int] = None
    # Passed to Downsample/Upsample when ``resblock_updown`` is False.
    conv_resample: bool = True
    # Dimensionality passed to conv_nd and the blocks.
    dims: int = 2
    # If set, a label embedding is created, but class-conditional forward
    # passes raise NotImplementedError.
    num_classes: int = None
    # Passed as use_checkpoint to ResBlocks and AttentionBlocks
    # (presumably gradient checkpointing).
    use_checkpoint: bool = False
    # Attention heads / channels per head for AttentionBlocks.
    num_heads: int = 1
    num_head_channels: int = -1
    # Heads for decoder attention blocks; -1 means "use num_heads".
    num_heads_upsample: int = -1
    # Resample with a ResBlock (down=True / up=True) instead of
    # Downsample/Upsample.
    resblock_updown: bool = True
    # Passed to every AttentionBlock.
    use_new_attention_order: bool = False
    # The three options below are forwarded to every ResBlock
    # (as two_cond / cond_emb_channels / use_zero_module); the zero-module
    # flag additionally wraps the final output conv in ``zero_module``.
    resnet_two_cond: bool = False
    resnet_cond_channels: int = None
    resnet_use_zero_module: bool = True
    # Turn on use_checkpoint for attention blocks even when use_checkpoint
    # is False.
    attn_checkpoint: bool = False
    # NOTE(review): not referenced in this file; presumably consumed
    # elsewhere — confirm before relying on it.
    num_users: int = None

    def make_model(self):
        """Build a ``BeatGANsUNetModel`` from this config."""
        return BeatGANsUNetModel(self)
class BeatGANsUNetModel(nn.Module):
    """
    Timestep-conditioned UNet with per-level lateral (skip) connections.

    The encoder path is built level by level from ``conf.input_channel_mult``
    (falling back to ``conf.channel_mult``). It is followed by a
    ResBlock / AttentionBlock / ResBlock middle block and a decoder path
    built from ``conf.channel_mult`` in reverse. Encoder activations are
    stored per level and consumed by the decoder as ``lateral`` inputs.
    """

    def __init__(self, conf: BeatGANsUNetConfig):
        super().__init__()
        self.conf = conf
        # Heads used by decoder attention blocks; -1 means "same as num_heads".
        # Both branches must assign the attribute: the decoder below always
        # reads it.
        if conf.num_heads_upsample == -1:
            self.num_heads_upsample = conf.num_heads
        else:
            self.num_heads_upsample = conf.num_heads_upsample
        self.dtype = th.float32

        # Size of the timestep embedding fed into the time-embedding MLP.
        self.time_emb_channels = conf.time_embed_channels or conf.model_channels
        self.time_embed = nn.Sequential(
            linear(self.time_emb_channels, conf.embed_channels),
            nn.SiLU(),
            linear(conf.embed_channels, conf.embed_channels),
        )
        if conf.num_classes is not None:
            # Created for completeness; forward() raises NotImplementedError
            # for class-conditional models.
            self.label_emb = nn.Embedding(conf.num_classes,
                                          conf.embed_channels)

        ch = input_ch = int(conf.channel_mult[0] * conf.model_channels)
        self.input_blocks = nn.ModuleList([
            TimestepEmbedSequential(
                conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1))
        ])

        # Options shared by every ResBlock in the network.
        resblock_kwargs = dict(
            use_condition=True,
            two_cond=conf.resnet_two_cond,
            use_zero_module=conf.resnet_use_zero_module,
            cond_emb_channels=conf.resnet_cond_channels,
        )

        def attention(channels, num_heads):
            # All attention blocks in this model share these options.
            return AttentionBlock(
                channels,
                use_checkpoint=conf.use_checkpoint or conf.attn_checkpoint,
                num_heads=num_heads,
                num_head_channels=conf.num_head_channels,
                use_new_attention_order=conf.use_new_attention_order,
            )

        self._feature_size = ch
        num_levels = len(conf.channel_mult)
        # Output channels of every encoder block, grouped by level; the
        # decoder pops from these to size its lateral inputs.
        input_block_chans = [[] for _ in range(num_levels)]
        input_block_chans[0].append(ch)
        # Number of blocks at each resolution level.
        self.input_num_blocks = [0] * num_levels
        self.input_num_blocks[0] = 1
        self.output_num_blocks = [0] * num_levels

        # ---- encoder ----
        resolution = conf.image_size
        for level, mult in enumerate(conf.input_channel_mult
                                     or conf.channel_mult):
            for _ in range(conf.num_input_res_blocks or conf.num_res_blocks):
                out_ch = int(mult * conf.model_channels)
                layers = [
                    ResBlockConfig(
                        ch,
                        conf.embed_channels,
                        conf.dropout,
                        out_channels=out_ch,
                        dims=conf.dims,
                        use_checkpoint=conf.use_checkpoint,
                        **resblock_kwargs,
                    ).make_model()
                ]
                ch = out_ch
                if resolution in conf.attention_resolutions:
                    layers.append(attention(ch, conf.num_heads))
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans[level].append(ch)
                self.input_num_blocks[level] += 1
            if level != num_levels - 1:
                resolution //= 2
                # Down-sampling keeps the channel count unchanged and belongs
                # to the next (coarser) level.
                if conf.resblock_updown:
                    down = ResBlockConfig(
                        ch,
                        conf.embed_channels,
                        conf.dropout,
                        out_channels=ch,
                        dims=conf.dims,
                        use_checkpoint=conf.use_checkpoint,
                        down=True,
                        **resblock_kwargs,
                    ).make_model()
                else:
                    down = Downsample(ch,
                                      conf.conv_resample,
                                      dims=conf.dims,
                                      out_channels=ch)
                self.input_blocks.append(TimestepEmbedSequential(down))
                input_block_chans[level + 1].append(ch)
                self.input_num_blocks[level + 1] += 1
                self._feature_size += ch

        # ---- middle ----
        self.middle_block = TimestepEmbedSequential(
            ResBlockConfig(
                ch,
                conf.embed_channels,
                conf.dropout,
                dims=conf.dims,
                use_checkpoint=conf.use_checkpoint,
                **resblock_kwargs,
            ).make_model(),
            attention(ch, conf.num_heads),
            ResBlockConfig(
                ch,
                conf.embed_channels,
                conf.dropout,
                dims=conf.dims,
                use_checkpoint=conf.use_checkpoint,
                **resblock_kwargs,
            ).make_model(),
        )
        self._feature_size += ch

        # ---- decoder (deepest level first) ----
        self.output_blocks = nn.ModuleList([])
        for level, mult in reversed(list(enumerate(conf.channel_mult))):
            for i in range(conf.num_res_blocks + 1):
                try:
                    ich = input_block_chans[level].pop()
                except IndexError:
                    # More decoder blocks than encoder blocks at this level:
                    # this block gets no lateral input.
                    ich = 0
                out_ch = int(conf.model_channels * mult)
                layers = [
                    ResBlockConfig(
                        channels=ch + ich,
                        emb_channels=conf.embed_channels,
                        dropout=conf.dropout,
                        out_channels=out_ch,
                        dims=conf.dims,
                        use_checkpoint=conf.use_checkpoint,
                        has_lateral=ich > 0,
                        lateral_channels=None,
                        **resblock_kwargs,
                    ).make_model()
                ]
                ch = out_ch
                if resolution in conf.attention_resolutions:
                    layers.append(attention(ch, self.num_heads_upsample))
                if level and i == conf.num_res_blocks:
                    # Last block of every level except level 0 up-samples.
                    resolution *= 2
                    if conf.resblock_updown:
                        up = ResBlockConfig(
                            ch,
                            conf.embed_channels,
                            conf.dropout,
                            out_channels=ch,
                            dims=conf.dims,
                            use_checkpoint=conf.use_checkpoint,
                            up=True,
                            **resblock_kwargs,
                        ).make_model()
                    else:
                        up = Upsample(ch,
                                      conf.conv_resample,
                                      dims=conf.dims,
                                      out_channels=ch)
                    layers.append(up)
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self.output_num_blocks[level] += 1
                self._feature_size += ch

        # After the decoder, ch == input_ch (level 0 width).
        if conf.resnet_use_zero_module:
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                zero_module(
                    conv_nd(conf.dims,
                            input_ch,
                            conf.out_channels,
                            3,
                            padding=1)),
            )
        else:
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                conv_nd(conf.dims, input_ch, conf.out_channels, 3, padding=1),
            )

    def forward(self, x, t, y=None, **kwargs):
        """
        Apply the model to an input batch.

        :param x: an [N x C x ...] Tensor of inputs.
        :param t: a 1-D batch of timesteps.
        :param y: must be None. Class-conditional models
            (``conf.num_classes`` set) are not implemented and raise
            NotImplementedError.
        :param kwargs: accepted for interface compatibility; unused.
        :return: a ``Return`` whose ``pred`` is the output of ``self.out``.
        """
        assert (y is not None) == (
            self.conf.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        emb = self.time_embed(timestep_embedding(t, self.time_emb_channels))
        if self.conf.num_classes is not None:
            raise NotImplementedError()

        # Encoder: keep every activation, grouped by level, as lateral inputs.
        hs = [[] for _ in range(len(self.conf.channel_mult))]
        h = x.type(self.dtype)
        k = 0
        for level, num_blocks in enumerate(self.input_num_blocks):
            for _ in range(num_blocks):
                h = self.input_blocks[k](h, emb=emb)
                hs[level].append(h)
                k += 1
        assert k == len(self.input_blocks)

        h = self.middle_block(h, emb=emb)

        # Decoder stage i runs at level -i - 1 (deepest first). Laterals come
        # from the same level in reverse order; once exhausted, None is used.
        k = 0
        for i in range(len(self.output_num_blocks)):
            for _ in range(self.output_num_blocks[-i - 1]):
                try:
                    lateral = hs[-i - 1].pop()
                except IndexError:
                    lateral = None
                h = self.output_blocks[k](h, emb=emb, lateral=lateral)
                k += 1

        h = h.type(x.dtype)
        pred = self.out(h)
        return Return(pred=pred)
# Result container returned by BeatGANsUNetModel.forward; its single field
# ``pred`` holds the tensor produced by the model's output head.
Return = NamedTuple("Return", [("pred", th.Tensor)])
@dataclass
class BeatGANsEncoderConfig(BaseConfig):
    """
    Configuration for ``BeatGANsEncoderModel``; ``make_model()`` builds it.

    This is a dataclass and the first seven fields have no defaults, so they
    are required (and positional, in declaration order) in ``__init__``.
    """
    # Starting resolution; halved after every level except the last and
    # compared against ``attention_resolutions``.
    image_size: int
    # Channels of the input tensor (first conv).
    in_channels: int
    # Base width: level widths are ``mult * model_channels``. Also the size of
    # the timestep embedding when ``use_time_condition`` is True.
    model_channels: int
    # NOTE(review): not referenced in this file; presumably consumed
    # elsewhere — confirm before relying on it.
    out_hid_channels: int
    # Output channels of the final 1x1 conv in the pooled head.
    out_channels: int
    # ResBlocks per resolution level.
    num_res_blocks: int
    # Resolutions at which an AttentionBlock follows each ResBlock.
    attention_resolutions: Tuple[int]
    # Passed to every ResBlock.
    dropout: float = 0
    # Width multipliers, one per resolution level.
    channel_mult: Tuple[int] = (1, 2, 4, 8)
    # If True, build a time-embedding MLP (width 4 * model_channels) and
    # condition ResBlocks on it; forward() then embeds ``t``.
    use_time_condition: bool = True
    # Passed to Downsample when ``resblock_updown`` is False.
    conv_resample: bool = True
    # Dimensionality passed to conv_nd and the blocks.
    dims: int = 2
    # Passed as use_checkpoint to ResBlocks and AttentionBlocks.
    use_checkpoint: bool = False
    # Attention heads / channels per head for AttentionBlocks.
    num_heads: int = 1
    num_head_channels: int = -1
    # Down-sample with ResBlock(down=True) instead of Downsample.
    resblock_updown: bool = False
    # Passed to every AttentionBlock.
    use_new_attention_order: bool = False
    # Output-head type; only 'adaptivenonzero' is accepted by the model.
    pool: str = 'adaptivenonzero'

    def make_model(self):
        """Build a ``BeatGANsEncoderModel`` from this config."""
        return BeatGANsEncoderModel(self)
class BeatGANsEncoderModel(nn.Module):
    """
    The half UNet model with attention and timestep embedding.

    Runs the down-sampling path and middle block of a BeatGANs UNet, then
    maps the final feature map to ``conf.out_channels`` values per sample
    with a pooled 1x1 projection head (``self.out``).
    For usage, see UNet.
    """

    def __init__(self, conf: BeatGANsEncoderConfig):
        super().__init__()
        self.conf = conf
        self.dtype = th.float32

        if conf.use_time_condition:
            time_embed_dim = conf.model_channels * 4
            self.time_embed = nn.Sequential(
                linear(conf.model_channels, time_embed_dim),
                nn.SiLU(),
                linear(time_embed_dim, time_embed_dim),
            )
        else:
            time_embed_dim = None

        def res_block(channels, **extra):
            # ResBlock conditioned on the time embedding (when enabled).
            return ResBlockConfig(
                channels,
                time_embed_dim,
                conf.dropout,
                dims=conf.dims,
                use_condition=conf.use_time_condition,
                use_checkpoint=conf.use_checkpoint,
                **extra,
            ).make_model()

        def attention(channels):
            return AttentionBlock(
                channels,
                use_checkpoint=conf.use_checkpoint,
                num_heads=conf.num_heads,
                num_head_channels=conf.num_head_channels,
                use_new_attention_order=conf.use_new_attention_order,
            )

        ch = int(conf.channel_mult[0] * conf.model_channels)
        self.input_blocks = nn.ModuleList([
            TimestepEmbedSequential(
                conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1))
        ])
        self._feature_size = ch
        resolution = conf.image_size
        last_level = len(conf.channel_mult) - 1
        for level, mult in enumerate(conf.channel_mult):
            for _ in range(conf.num_res_blocks):
                out_ch = int(mult * conf.model_channels)
                layers = [res_block(ch, out_channels=out_ch)]
                ch = out_ch
                if resolution in conf.attention_resolutions:
                    layers.append(attention(ch))
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
            if level != last_level:
                resolution //= 2
                # Down-sampling keeps the channel count unchanged.
                if conf.resblock_updown:
                    down = res_block(ch, out_channels=ch, down=True)
                else:
                    down = Downsample(ch,
                                      conf.conv_resample,
                                      dims=conf.dims,
                                      out_channels=ch)
                self.input_blocks.append(TimestepEmbedSequential(down))
                self._feature_size += ch

        self.middle_block = TimestepEmbedSequential(
            res_block(ch),
            attention(ch),
            res_block(ch),
        )
        self._feature_size += ch

        if conf.pool == "adaptivenonzero":
            # Pool the feature map down to size 1 along every spatial axis.
            # The pooling layer must match ``dims``: the 1-D pool rejects
            # the 4-D feature maps produced when dims == 2.
            pool_by_dims = {
                1: nn.AdaptiveAvgPool1d,
                2: nn.AdaptiveAvgPool2d,
                3: nn.AdaptiveAvgPool3d,
            }
            if conf.dims not in pool_by_dims:
                raise ValueError(
                    f"Unsupported dims={conf.dims} for {conf.pool} pooling")
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                pool_by_dims[conf.dims](1),
                conv_nd(conf.dims, ch, conf.out_channels, 1),
                nn.Flatten(),
            )
        else:
            raise NotImplementedError(f"Unexpected {conf.pool} pooling")

    def forward(self, x, t=None, return_2d_feature=False):
        """
        Apply the model to an input batch.

        :param x: an [N x C x ...] Tensor of inputs.
        :param t: a 1-D batch of timesteps; required when
            ``conf.use_time_condition`` is True, ignored otherwise.
        :param return_2d_feature: if True, also return the tensor that was
            fed into ``self.out``.
        :return: an [N x K] Tensor of outputs, or ``(outputs, feature)`` when
            ``return_2d_feature`` is True.
        :raises ValueError: if time conditioning is enabled and ``t`` is None.
        """
        if self.conf.use_time_condition:
            if t is None:
                raise ValueError(
                    "t (timesteps) is required when use_time_condition=True")
            # The embedding width is a config value; the module itself has
            # no ``model_channels`` attribute.
            emb = self.time_embed(
                timestep_embedding(t, self.conf.model_channels))
        else:
            emb = None

        # The "spatial*" branch is only reachable if conf.pool starts with
        # "spatial"; __init__ currently rejects every pool except
        # "adaptivenonzero". The mean over dims (2, 3) assumes 2-D maps.
        spatial_pool = self.conf.pool.startswith("spatial")
        results = []
        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb=emb)
            if spatial_pool:
                results.append(h.type(x.dtype).mean(dim=(2, 3)))
        h = self.middle_block(h, emb=emb)
        if spatial_pool:
            results.append(h.type(x.dtype).mean(dim=(2, 3)))
            h = th.cat(results, axis=-1)
        else:
            h = h.type(x.dtype)
        h_2d = h
        h = self.out(h)
        if return_2d_feature:
            return h, h_2d
        return h

    def forward_flatten(self, x):
        """
        transform the last 2d feature into a flatten vector

        Applies only the output head (``self.out``) to a feature map such as
        the ``h_2d`` returned by ``forward(..., return_2d_feature=True)``.
        """
        return self.out(x)
class SuperResModel(BeatGANsUNetModel):
    """
    A UNetModel that performs super-resolution.

    Expects an extra kwarg ``low_res`` to condition on a low-resolution input.
    ``low_res`` is resized to the spatial size of ``x`` and concatenated with
    it along dim 1, so the underlying UNet is built with ``2 * in_channels``
    input channels.
    """

    # F.interpolate mode for each number of spatial dimensions of ``x``.
    _INTERP_MODES = {1: "linear", 2: "bilinear", 3: "trilinear"}

    def __init__(self, image_size, in_channels, *args, **kwargs):
        # BeatGANsUNetModel is built from a single config object, so map the
        # (image_size, in_channels, *args, **kwargs) signature onto
        # BeatGANsUNetConfig, whose leading fields are image_size, in_channels.
        # NOTE(review): assumes BaseConfig adds no dataclass fields ahead of
        # image_size — confirm.
        super().__init__(
            BeatGANsUNetConfig(image_size, in_channels * 2, *args, **kwargs))

    def forward(self, x, timesteps, low_res=None, **kwargs):
        """
        Resize ``low_res`` to ``x``'s spatial size, stack it onto ``x`` along
        dim 1 and run the UNet.

        :param x: an [N x C x ...] Tensor with 1, 2 or 3 spatial dims.
        :param timesteps: a 1-D batch of timesteps.
        :param low_res: the conditioning Tensor (required).
        :param kwargs: forwarded to ``BeatGANsUNetModel.forward``.
        :raises ValueError: if ``low_res`` is None or ``x`` has an
            unsupported number of dimensions.
        """
        if low_res is None:
            raise ValueError("SuperResModel.forward requires `low_res`")
        spatial_size = tuple(x.shape[2:])
        mode = self._INTERP_MODES.get(len(spatial_size))
        if mode is None:
            raise ValueError(
                f"unsupported input with {x.dim()} dims; expected 3, 4 or 5")
        upsampled = F.interpolate(low_res, spatial_size, mode=mode)
        x = th.cat([x, upsampled], dim=1)
        return super().forward(x, timesteps, **kwargs)