import math
from dataclasses import dataclass
from numbers import Number
from typing import NamedTuple, Optional, Tuple, Union

import numpy as np
import torch as th
from torch import nn
import torch.nn.functional as F

from choices import *
from config_base import BaseConfig
from .blocks import *
from .nn import (conv_nd, linear, normalization, timestep_embedding,
                 torch_checkpoint, zero_module)


@dataclass
class BeatGANsUNetConfig(BaseConfig):
    """Hyper-parameters for :class:`BeatGANsUNetModel`.

    Call :meth:`make_model` to build the network described by this config.
    """
    image_size: int = 64
    in_channels: int = 2
    # Base width; the block width at level i is channel_mult[i] * model_channels.
    model_channels: int = 64
    out_channels: int = 2
    # Resblocks per level; the decoder builds num_res_blocks + 1 per level.
    num_res_blocks: int = 2
    # When set, overrides num_res_blocks for the encoder half only.
    num_input_res_blocks: Optional[int] = None
    # Width of the timestep embedding handed to every resblock.
    embed_channels: int = 512
    # An attention block is added wherever the running `resolution`
    # (starting at image_size, halved per level) is in this tuple.
    attention_resolutions: Tuple[int, ...] = (16, )
    # Width of the raw timestep features; falls back to model_channels.
    time_embed_channels: Optional[int] = None
    dropout: float = 0.1
    channel_mult: Tuple[int, ...] = (1, 2, 4, 8)
    # When set, overrides channel_mult for the encoder half only.
    input_channel_mult: Optional[Tuple[int, ...]] = None
    conv_resample: bool = True
    # Forwarded to conv_nd and the resblocks as the number of spatial dims.
    dims: int = 2
    # forward() raises NotImplementedError when this is set.
    num_classes: Optional[int] = None
    use_checkpoint: bool = False
    num_heads: int = 1
    num_head_channels: int = -1
    # -1 means "reuse num_heads" for the decoder attention blocks.
    num_heads_upsample: int = -1
    resblock_updown: bool = True
    use_new_attention_order: bool = False
    resnet_two_cond: bool = False
    resnet_cond_channels: Optional[int] = None
    # Also controls whether the final output conv is zero-initialised.
    resnet_use_zero_module: bool = True
    attn_checkpoint: bool = False
    # Not read by any code in this module.
    num_users: Optional[int] = None

    def make_model(self):
        """Instantiate a :class:`BeatGANsUNetModel` from this config."""
        return BeatGANsUNetModel(self)


class BeatGANsUNetModel(nn.Module):
    """Timestep-conditioned UNet with per-level skip ("lateral") connections.

    Skip features are grouped by resolution level. The encoder appends every
    block output to ``hs[level]``, and the decoder pops them back (LIFO) for
    the same level, passing each one as ``lateral`` to the output block.
    """

    def __init__(self, conf: BeatGANsUNetConfig):
        super().__init__()
        self.conf = conf

        # -1 means "reuse num_heads". Previously any other value left this
        # attribute unset and construction crashed below.
        if conf.num_heads_upsample == -1:
            self.num_heads_upsample = conf.num_heads
        else:
            self.num_heads_upsample = conf.num_heads_upsample

        self.dtype = th.float32

        self.time_emb_channels = conf.time_embed_channels or conf.model_channels
        self.time_embed = nn.Sequential(
            linear(self.time_emb_channels, conf.embed_channels),
            nn.SiLU(),
            linear(conf.embed_channels, conf.embed_channels),
        )

        if conf.num_classes is not None:
            self.label_emb = nn.Embedding(conf.num_classes,
                                          conf.embed_channels)

        ch = input_ch = int(conf.channel_mult[0] * conf.model_channels)
        self.input_blocks = nn.ModuleList([
            TimestepEmbedSequential(
                conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1))
        ])

        # Keyword arguments shared by every resblock in the network.
        resblock_kwargs = dict(
            use_condition=True,
            two_cond=conf.resnet_two_cond,
            use_zero_module=conf.resnet_use_zero_module,
            cond_emb_channels=conf.resnet_cond_channels,
        )

        # Running sum of block output widths (not read in this module).
        self._feature_size = ch

        num_levels = len(conf.channel_mult)
        # Output width of every encoder block, grouped by level; the decoder
        # pops these to size its lateral inputs.
        input_block_chans = [[] for _ in range(num_levels)]
        input_block_chans[0].append(ch)

        # Number of blocks at each resolution level (consumed by forward()).
        self.input_num_blocks = [0] * num_levels
        self.input_num_blocks[0] = 1
        self.output_num_blocks = [0] * num_levels

        # ----------------------------------------------------------- encoder
        resolution = conf.image_size
        for level, mult in enumerate(conf.input_channel_mult
                                     or conf.channel_mult):
            for _ in range(conf.num_input_res_blocks or conf.num_res_blocks):
                layers = [
                    ResBlockConfig(
                        ch,
                        conf.embed_channels,
                        conf.dropout,
                        out_channels=int(mult * conf.model_channels),
                        dims=conf.dims,
                        use_checkpoint=conf.use_checkpoint,
                        **resblock_kwargs,
                    ).make_model()
                ]
                ch = int(mult * conf.model_channels)
                if resolution in conf.attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=conf.use_checkpoint
                            or conf.attn_checkpoint,
                            num_heads=conf.num_heads,
                            num_head_channels=conf.num_head_channels,
                            use_new_attention_order=conf.
                            use_new_attention_order,
                        ))
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans[level].append(ch)
                self.input_num_blocks[level] += 1

            # Every level except the last ends with a downsampling block,
            # which is counted as the first block of the next level.
            if level != num_levels - 1:
                resolution //= 2
                out_ch = ch
                if conf.resblock_updown:
                    down = ResBlockConfig(
                        ch,
                        conf.embed_channels,
                        conf.dropout,
                        out_channels=out_ch,
                        dims=conf.dims,
                        use_checkpoint=conf.use_checkpoint,
                        down=True,
                        **resblock_kwargs,
                    ).make_model()
                else:
                    down = Downsample(ch,
                                      conf.conv_resample,
                                      dims=conf.dims,
                                      out_channels=out_ch)
                self.input_blocks.append(TimestepEmbedSequential(down))
                ch = out_ch
                input_block_chans[level + 1].append(ch)
                self.input_num_blocks[level + 1] += 1
                self._feature_size += ch

        # ------------------------------------------------------------ middle
        self.middle_block = TimestepEmbedSequential(
            ResBlockConfig(
                ch,
                conf.embed_channels,
                conf.dropout,
                dims=conf.dims,
                use_checkpoint=conf.use_checkpoint,
                **resblock_kwargs,
            ).make_model(),
            AttentionBlock(
                ch,
                use_checkpoint=conf.use_checkpoint or conf.attn_checkpoint,
                num_heads=conf.num_heads,
                num_head_channels=conf.num_head_channels,
                use_new_attention_order=conf.use_new_attention_order,
            ),
            ResBlockConfig(
                ch,
                conf.embed_channels,
                conf.dropout,
                dims=conf.dims,
                use_checkpoint=conf.use_checkpoint,
                **resblock_kwargs,
            ).make_model(),
        )
        self._feature_size += ch

        # ----------------------------------------------------------- decoder
        # Built deepest level first; forward() walks the levels in the same order.
        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(conf.channel_mult))[::-1]:
            for i in range(conf.num_res_blocks + 1):
                # The decoder can have more blocks at a level than there are
                # stored skips; those blocks are built without a lateral input.
                try:
                    ich = input_block_chans[level].pop()
                except IndexError:
                    ich = 0
                layers = [
                    ResBlockConfig(
                        channels=ch + ich,
                        emb_channels=conf.embed_channels,
                        dropout=conf.dropout,
                        out_channels=int(conf.model_channels * mult),
                        dims=conf.dims,
                        use_checkpoint=conf.use_checkpoint,
                        has_lateral=ich > 0,
                        lateral_channels=None,
                        **resblock_kwargs,
                    ).make_model()
                ]
                ch = int(conf.model_channels * mult)
                if resolution in conf.attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=conf.use_checkpoint
                            or conf.attn_checkpoint,
                            num_heads=self.num_heads_upsample,
                            num_head_channels=conf.num_head_channels,
                            use_new_attention_order=conf.
                            use_new_attention_order,
                        ))
                # The last block of every level except level 0 upsamples.
                if level and i == conf.num_res_blocks:
                    resolution *= 2
                    out_ch = ch
                    if conf.resblock_updown:
                        layers.append(
                            ResBlockConfig(
                                ch,
                                conf.embed_channels,
                                conf.dropout,
                                out_channels=out_ch,
                                dims=conf.dims,
                                use_checkpoint=conf.use_checkpoint,
                                up=True,
                                **resblock_kwargs,
                            ).make_model())
                    else:
                        layers.append(
                            Upsample(ch,
                                     conf.conv_resample,
                                     dims=conf.dims,
                                     out_channels=out_ch))
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self.output_num_blocks[level] += 1
                self._feature_size += ch

        # ------------------------------------------------------------- head
        # Module construction order (norm, then conv) matches the original.
        norm = normalization(ch)
        final_conv = conv_nd(conf.dims,
                             input_ch,
                             conf.out_channels,
                             3,
                             padding=1)
        if conf.resnet_use_zero_module:
            final_conv = zero_module(final_conv)
        self.out = nn.Sequential(norm, nn.SiLU(), final_conv)

    def forward(self, x, t, y=None, **kwargs):
        """
        Apply the model to an input batch.

        :param x: an [N x C x ...] Tensor of inputs.
        :param t: a 1-D batch of timesteps.
        :param y: an [N] Tensor of labels; must be given if and only if
            ``conf.num_classes`` is set (class conditioning then raises
            NotImplementedError).
        :param kwargs: accepted for interface compatibility and ignored.
        :return: a :class:`Return` whose ``pred`` is an [N x C x ...] Tensor.
        """
        assert (y is not None) == (
            self.conf.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"

        # Encoder activations, grouped by resolution level.
        hs = [[] for _ in range(len(self.conf.channel_mult))]
        emb = self.time_embed(timestep_embedding(t, self.time_emb_channels))

        if self.conf.num_classes is not None:
            raise NotImplementedError()

        # Encoder: record every block output under its level.
        h = x.type(self.dtype)
        k = 0
        for level, n_blocks in enumerate(self.input_num_blocks):
            for _ in range(n_blocks):
                h = self.input_blocks[k](h, emb=emb)
                hs[level].append(h)
                k += 1
        assert k == len(self.input_blocks)

        h = self.middle_block(h, emb=emb)

        # Decoder: output blocks were built deepest level first, so walk the
        # levels in that order and pop the matching skips (LIFO).
        k = 0
        for level in reversed(range(len(self.output_num_blocks))):
            for _ in range(self.output_num_blocks[level]):
                try:
                    lateral = hs[level].pop()
                except IndexError:
                    lateral = None
                h = self.output_blocks[k](h, emb=emb, lateral=lateral)
                k += 1
        assert k == len(self.output_blocks)

        h = h.type(x.dtype)
        pred = self.out(h)
        return Return(pred=pred)


class Return(NamedTuple):
    """Output container returned by :meth:`BeatGANsUNetModel.forward`."""
    pred: th.Tensor


@dataclass
class BeatGANsEncoderConfig(BaseConfig):
    """Hyper-parameters for :class:`BeatGANsEncoderModel`."""
    image_size: int
    in_channels: int
    model_channels: int
    # Not read by any code in this module.
    out_hid_channels: int
    out_channels: int
    num_res_blocks: int
    attention_resolutions: Tuple[int, ...]
    dropout: float = 0
    channel_mult: Tuple[int, ...] = (1, 2, 4, 8)
    use_time_condition: bool = True
    conv_resample: bool = True
    dims: int = 2
    use_checkpoint: bool = False
    num_heads: int = 1
    num_head_channels: int = -1
    resblock_updown: bool = False
    use_new_attention_order: bool = False
    # Only 'adaptivenonzero' is accepted by BeatGANsEncoderModel.__init__.
    pool: str = 'adaptivenonzero'

    def make_model(self):
        """Instantiate a :class:`BeatGANsEncoderModel` from this config."""
        return BeatGANsEncoderModel(self)


class BeatGANsEncoderModel(nn.Module):
    """
    The half UNet model with attention and timestep embedding.

    It runs the encoder and middle block of a UNet, then a global-average-pool
    head that maps the final feature map to ``conf.out_channels`` values per
    sample.
    """

    def __init__(self, conf: BeatGANsEncoderConfig):
        super().__init__()
        self.conf = conf
        self.dtype = th.float32

        if conf.use_time_condition:
            time_embed_dim = conf.model_channels * 4
            self.time_embed = nn.Sequential(
                linear(conf.model_channels, time_embed_dim),
                nn.SiLU(),
                linear(time_embed_dim, time_embed_dim),
            )
        else:
            time_embed_dim = None

        ch = int(conf.channel_mult[0] * conf.model_channels)
        self.input_blocks = nn.ModuleList([
            TimestepEmbedSequential(
                conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1))
        ])
        # Running sum of block output widths (not read in this module).
        self._feature_size = ch

        resolution = conf.image_size
        for level, mult in enumerate(conf.channel_mult):
            for _ in range(conf.num_res_blocks):
                layers = [
                    ResBlockConfig(
                        ch,
                        time_embed_dim,
                        conf.dropout,
                        out_channels=int(mult * conf.model_channels),
                        dims=conf.dims,
                        use_condition=conf.use_time_condition,
                        use_checkpoint=conf.use_checkpoint,
                    ).make_model()
                ]
                ch = int(mult * conf.model_channels)
                if resolution in conf.attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=conf.use_checkpoint,
                            num_heads=conf.num_heads,
                            num_head_channels=conf.num_head_channels,
                            use_new_attention_order=conf.
                            use_new_attention_order,
                        ))
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

            # Every level except the last ends with a downsampling block.
            if level != len(conf.channel_mult) - 1:
                resolution //= 2
                out_ch = ch
                if conf.resblock_updown:
                    down = ResBlockConfig(
                        ch,
                        time_embed_dim,
                        conf.dropout,
                        out_channels=out_ch,
                        dims=conf.dims,
                        use_condition=conf.use_time_condition,
                        use_checkpoint=conf.use_checkpoint,
                        down=True,
                    ).make_model()
                else:
                    down = Downsample(ch,
                                      conf.conv_resample,
                                      dims=conf.dims,
                                      out_channels=out_ch)
                self.input_blocks.append(TimestepEmbedSequential(down))
                ch = out_ch
                self._feature_size += ch

        self.middle_block = TimestepEmbedSequential(
            ResBlockConfig(
                ch,
                time_embed_dim,
                conf.dropout,
                dims=conf.dims,
                use_condition=conf.use_time_condition,
                use_checkpoint=conf.use_checkpoint,
            ).make_model(),
            AttentionBlock(
                ch,
                use_checkpoint=conf.use_checkpoint,
                num_heads=conf.num_heads,
                num_head_channels=conf.num_head_channels,
                use_new_attention_order=conf.use_new_attention_order,
            ),
            ResBlockConfig(
                ch,
                time_embed_dim,
                conf.dropout,
                dims=conf.dims,
                use_condition=conf.use_time_condition,
                use_checkpoint=conf.use_checkpoint,
            ).make_model(),
        )
        self._feature_size += ch

        if conf.pool == "adaptivenonzero":
            # The pooling layer must match the spatial rank of the feature
            # map. A fixed AdaptiveAvgPool1d rejects the 4-D maps produced
            # when dims=2. Pooling has no parameters, so state_dict keys are
            # unchanged.
            pool_by_dims = {
                1: nn.AdaptiveAvgPool1d,
                2: nn.AdaptiveAvgPool2d,
                3: nn.AdaptiveAvgPool3d,
            }
            if conf.dims not in pool_by_dims:
                raise ValueError(f"unsupported dims: {conf.dims}")
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                pool_by_dims[conf.dims](1),
                conv_nd(conf.dims, ch, conf.out_channels, 1),
                nn.Flatten(),
            )
        else:
            raise NotImplementedError(f"Unexpected {conf.pool} pooling")

    def forward(self, x, t=None, return_2d_feature=False):
        """
        Apply the model to an input batch.

        :param x: an [N x C x ...] Tensor of inputs.
        :param t: a 1-D batch of timesteps; required when
            ``conf.use_time_condition`` is True, ignored otherwise.
        :param return_2d_feature: if True, also return the feature map that
            was fed to the output head.
        :return: an [N x K] Tensor of outputs, or ``(output, feature_map)``
            when ``return_2d_feature`` is True.
        """
        if self.conf.use_time_condition:
            # Previously read the nonexistent attribute self.model_channels.
            emb = self.time_embed(
                timestep_embedding(t, self.conf.model_channels))
        else:
            emb = None

        # Only populated for 'spatial*' pooling, which __init__ currently
        # rejects; kept so that mode can be re-enabled without touching this
        # method.
        results = []
        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb=emb)
            if self.conf.pool.startswith("spatial"):
                results.append(h.type(x.dtype).mean(dim=(2, 3)))
        h = self.middle_block(h, emb=emb)
        if self.conf.pool.startswith("spatial"):
            results.append(h.type(x.dtype).mean(dim=(2, 3)))
            h = th.cat(results, dim=-1)
        else:
            h = h.type(x.dtype)

        h_2d = h
        h = self.out(h)

        if return_2d_feature:
            return h, h_2d
        return h

    def forward_flatten(self, x):
        """
        transform the last 2d feature into a flatten vector
        """
        h = self.out(x)
        return h


class SuperResModel(BeatGANsUNetModel):
    """
    A UNetModel that performs super-resolution.

    Expects an extra kwarg `low_res` to condition on a low-resolution image.
    ``low_res`` is resized (bilinear) to ``x``'s spatial size and
    concatenated to ``x`` along dim 1.
    """

    def __init__(self, image_size, in_channels, *args, **kwargs):
        # The parent takes a single config object; build one with the input
        # channel count doubled to make room for the upsampled low-res image.
        # Remaining args map onto BeatGANsUNetConfig fields in order.
        conf = BeatGANsUNetConfig(image_size, in_channels * 2, *args,
                                  **kwargs)
        super().__init__(conf)

    def forward(self, x, timesteps, low_res=None, **kwargs):
        """Upsample ``low_res`` to ``x``'s size, concatenate, and run the UNet."""
        if low_res is None:
            raise ValueError("SuperResModel.forward requires `low_res`")
        _, _, new_height, new_width = x.shape
        upsampled = F.interpolate(low_res, (new_height, new_width),
                                  mode="bilinear")
        x = th.cat([x, upsampled], dim=1)
        return super().forward(x, timesteps, **kwargs)