"""Autoencoding U-Net with user / non-user classifier heads on the latent code."""

from dataclasses import dataclass
from typing import NamedTuple, Tuple

import torch
from torch import Tensor, nn
from torch.nn.functional import silu

from .latentnet import *
from .unet import *
from choices import *


@dataclass
class BeatGANsAutoencConfig(BeatGANsUNetConfig):
    """Configuration for :class:`BeatGANsAutoencModel`.

    Extends the U-Net configuration with the settings of the encoder that
    produces the conditioning vector ``cond``.
    """

    # Passed to the encoder as both ``out_hid_channels`` and ``out_channels``.
    enc_out_channels: int = 512
    # Encoder attention resolutions; falls back to ``attention_resolutions``.
    enc_attn_resolutions: Tuple[int] = None
    # Passed to the encoder as ``pool``.
    enc_pool: str = 'depthconv'
    # Passed to the encoder as ``num_res_blocks``.
    enc_num_res_block: int = 2
    # Encoder channel multipliers; falls back to ``channel_mult``.
    enc_channel_mult: Tuple[int] = None
    # OR-ed with ``use_checkpoint`` to enable checkpointing in the encoder.
    enc_grad_checkpoint: bool = False
    # Optional latent network config; when set, ``latent_net`` is built from it.
    latent_net_conf: MLPSkipNetConfig = None

    def make_model(self):
        """Build the model described by this configuration."""
        return BeatGANsAutoencModel(self)


class BeatGANsAutoencModel(BeatGANsUNetModel):
    """U-Net conditioned on a vector produced by an image encoder.

    When ``conf.num_users`` is set, two classifier heads are attached to the
    conditioning vector: ``user_classifier`` reads its first half and
    ``non_user_classifier`` (behind a gradient-reversal layer) reads its
    second half.
    """

    def __init__(self, conf: BeatGANsAutoencConfig):
        super().__init__(conf)
        self.conf = conf

        # Replaces the parent's time embedding: time and cond are kept separate.
        self.time_embed = TimeStyleSeperateEmbed(
            time_channels=conf.model_channels,
            time_out_channels=conf.embed_channels,
        )

        self.encoder = BeatGANsEncoderConfig(
            image_size=conf.image_size,
            in_channels=conf.in_channels,
            model_channels=conf.model_channels,
            out_hid_channels=conf.enc_out_channels,
            out_channels=conf.enc_out_channels,
            num_res_blocks=conf.enc_num_res_block,
            attention_resolutions=(conf.enc_attn_resolutions
                                   or conf.attention_resolutions),
            dropout=conf.dropout,
            channel_mult=conf.enc_channel_mult or conf.channel_mult,
            use_time_condition=False,
            conv_resample=conf.conv_resample,
            dims=conf.dims,
            use_checkpoint=conf.use_checkpoint or conf.enc_grad_checkpoint,
            num_heads=conf.num_heads,
            num_head_channels=conf.num_head_channels,
            resblock_updown=conf.resblock_updown,
            use_new_attention_order=conf.use_new_attention_order,
            pool=conf.enc_pool,
        ).make_model()

        # ``forward`` treats ``num_users`` as optional, so only build the heads
        # when it is set (nn.Linear cannot be built with a ``None`` width).
        if conf.num_users is not None:
            half_channels = conf.enc_out_channels // 2
            self.user_classifier = UserClassifier(half_channels,
                                                  conf.num_users)
            self.non_user_classifier = UserClassifierGradientReverse(
                half_channels, conf.num_users)

        if conf.latent_net_conf is not None:
            self.latent_net = conf.latent_net_conf.make_model()

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick to sample from N(mu, var) from N(0,1).

        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        assert self.conf.is_stochastic
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def sample_z(self, n: int, device):
        """Draw ``n`` standard-normal vectors of size ``enc_out_channels``."""
        assert self.conf.is_stochastic
        return torch.randn(n, self.conf.enc_out_channels, device=device)

    def noise_to_cond(self, noise: Tensor):
        """Map noise to a conditioning vector (not implemented).

        :raises NotImplementedError: always.
        """
        raise NotImplementedError()

    def encode(self, x):
        """Run the encoder on ``x`` and return ``{'cond': <encoder output>}``."""
        cond = self.encoder.forward(x)
        return {'cond': cond}

    def _stylespace_res_blocks(self):
        """Return every ResBlock of the input, middle and output stages, in order."""
        stages = (self.input_blocks, self.middle_block, self.output_blocks)
        return [
            module for stage in stages for module in stage.modules()
            if isinstance(module, ResBlock)
        ]

    @property
    def stylespace_sizes(self):
        """Output width of the last ``cond_emb_layers`` layer of each ResBlock."""
        return [
            block.cond_emb_layers[-1].weight.shape[0]
            for block in self._stylespace_res_blocks()
        ]

    def encode_stylespace(self, x, return_vector: bool = True):
        """
        encode to style space

        :param x: input passed to the encoder
        :param return_vector: when True, concatenate the per-block styles
            along dim 1; otherwise return them as a list
        """
        cond = self.encoder.forward(x)
        styles = [
            block.cond_emb_layers.forward(cond)
            for block in self._stylespace_res_blocks()
        ]
        if return_vector:
            return torch.cat(styles, dim=1)
        return styles

    def forward(self,
                x,
                t,
                y=None,
                x_start=None,
                cond=None,
                style=None,
                noise=None,
                t_cond=None,
                **kwargs):
        """
        Apply the model to an input batch.

        Args:
            x: the input batch; when None the input and middle blocks are
                skipped and the output blocks run without lateral connections
            t: timesteps; when None no timestep embedding is computed
            y: class labels; must be given iff ``conf.num_classes`` is set
                (class conditioning itself is not implemented)
            x_start: the original image to encode
            cond: output of the encoder
            style: accepted for interface compatibility; currently unused
            noise: random noise (to predict the cond)
            t_cond: timesteps for the cond embedding; defaults to ``t``

        Returns:
            AutoencReturn with ``pred``, ``cond`` and, when
            ``conf.num_users`` is set, ``user_pred`` / ``non_user_pred``
            (both None otherwise).
        """
        if t_cond is None:
            t_cond = t

        if noise is not None:
            cond = self.noise_to_cond(noise)

        if cond is None:
            if x is not None:
                assert len(x) == len(x_start), f'{len(x)} != {len(x_start)}'
            cond = self.encode(x_start)['cond']

        if t is not None:
            _t_emb = timestep_embedding(t, self.conf.model_channels)
            _t_cond_emb = timestep_embedding(t_cond, self.conf.model_channels)
        else:
            _t_emb = None
            _t_cond_emb = None

        # Only the two-condition ResNet variant is supported.
        if not self.conf.resnet_two_cond:
            raise NotImplementedError()

        res = self.time_embed.forward(
            time_emb=_t_emb,
            cond=cond,
            time_cond_emb=_t_cond_emb,
        )
        # two cond: first = time emb, second = cond emb
        emb = res.time_emb
        cond_emb = res.emb

        assert (y is not None) == (
            self.conf.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"

        if self.conf.num_classes is not None:
            raise NotImplementedError()

        enc_time_emb = emb
        mid_time_emb = emb
        dec_time_emb = emb
        enc_cond_emb = cond_emb
        mid_cond_emb = cond_emb
        dec_cond_emb = cond_emb

        # Default to None so the return below is well-defined when the model
        # has no user heads.
        user_pred = None
        non_user_pred = None
        if self.conf.num_users is not None:
            half_channels = self.conf.enc_out_channels // 2
            user_pred = self.user_classifier(cond_emb[:, :half_channels])
            non_user_pred = self.non_user_classifier(
                cond_emb[:, half_channels:])

        # One stack of lateral activations per resolution level.
        hs = [[] for _ in range(len(self.conf.channel_mult))]

        if x is not None:
            h = x.type(self.dtype)

            # input blocks
            k = 0
            for i, num_blocks in enumerate(self.input_num_blocks):
                for _ in range(num_blocks):
                    h = self.input_blocks[k](h,
                                             emb=enc_time_emb,
                                             cond=enc_cond_emb)
                    hs[i].append(h)
                    k += 1
            assert k == len(self.input_blocks)

            # middle blocks
            h = self.middle_block(h, emb=mid_time_emb, cond=mid_cond_emb)
        else:
            # no input: the output blocks run without lateral connections
            h = None

        # output blocks
        k = 0
        for i, num_blocks in enumerate(self.output_num_blocks):
            for _ in range(num_blocks):
                # Take laterals from the matching level (in reverse order)
                # until none are left, then fall back to None.
                try:
                    lateral = hs[-i - 1].pop()
                except IndexError:
                    lateral = None

                h = self.output_blocks[k](h,
                                          emb=dec_time_emb,
                                          cond=dec_cond_emb,
                                          lateral=lateral)
                k += 1

        pred = self.out(h)
        return AutoencReturn(pred=pred,
                             cond=cond,
                             user_pred=user_pred,
                             non_user_pred=non_user_pred)


class UserClassifier(nn.Module):
    """Two-layer MLP classifier whose output is softmax-normalized over dim 1."""

    def __init__(self, in_channels, num_classes):
        super().__init__()
        # NOTE(review): the head ends in Softmax, so it returns probabilities;
        # confirm the training loss expects probabilities rather than logits.
        self.fc = nn.Sequential(
            nn.Linear(in_channels, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        return self.fc(x)


class GradReverse(torch.autograd.Function):
    """
    Implement the gradient reversal layer for the convenience of domain
    adaptation neural network. The forward part is the identity function while
    the backward part is the negative function.
    """

    @staticmethod
    def forward(ctx, x):
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.neg()


class GradientReversalLayer(nn.Module):
    """Module wrapper around :class:`GradReverse`."""

    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        return GradReverse.apply(inputs)


class UserClassifierGradientReverse(nn.Module):
    """:class:`UserClassifier` preceded by a gradient-reversal layer."""

    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.grl = GradientReversalLayer()
        self.fc = UserClassifier(in_channels, num_classes)

    def forward(self, x):
        x = self.grl(x)
        return self.fc(x)


class AutoencReturn(NamedTuple):
    # Output of the U-Net's final layer.
    pred: Tensor
    # Conditioning vector used for this forward pass.
    cond: Tensor = None
    # Output of the user classifier head (None when the model has no heads).
    user_pred: Tensor = None
    # Output of the gradient-reversed classifier head (None when absent).
    non_user_pred: Tensor = None


class EmbedReturn(NamedTuple):
    # Conditioning embedding.
    emb: Tensor = None
    # Embedded timestep (None when no timestep embedding was supplied).
    time_emb: Tensor = None
    # Style vector; TimeStyleSeperateEmbed sets it to the same value as emb.
    style: Tensor = None


class TimeStyleSeperateEmbed(nn.Module):
    """Embed the timestep with an MLP and pass ``cond`` through unchanged."""

    def __init__(self, time_channels, time_out_channels):
        super().__init__()
        self.time_embed = nn.Sequential(
            linear(time_channels, time_out_channels),
            nn.SiLU(),
            linear(time_out_channels, time_out_channels),
        )
        # Not used by forward(); kept because it registers parameters.
        self.cond_combine = nn.Sequential(
            nn.Linear(time_out_channels * 2, time_out_channels),
            nn.SiLU(),
        )
        self.style = nn.Identity()

    def forward(self, time_emb=None, cond=None, **kwargs):
        """Return an EmbedReturn; ``time_emb`` stays None when none is given."""
        if time_emb is not None:
            time_emb = self.time_embed(time_emb)
        style = self.style(cond)
        return EmbedReturn(emb=style, time_emb=time_emb, style=style)