310 lines
No EOL
9.9 KiB
Python
310 lines
No EOL
9.9 KiB
Python
import torch
|
|
from torch import Tensor, nn
|
|
from torch.nn.functional import silu
|
|
from .latentnet import *
|
|
from .unet import *
|
|
from choices import *
|
|
|
|
|
|
@dataclass
class BeatGANsAutoencConfig(BeatGANsUNetConfig):
    """Config for ``BeatGANsAutoencModel``: a UNet plus an image encoder.

    The ``enc_*`` fields configure the encoder. Where ``BeatGANsAutoencModel``
    combines an encoder field with ``or``, a ``None`` (or otherwise falsy)
    value falls back to the corresponding UNet field from the base config.
    """
    # Width of the encoder output (the conditioning vector ``cond``).
    enc_out_channels: int = 512
    # Encoder attention resolutions; None/empty -> ``attention_resolutions``.
    enc_attn_resolutions: Tuple[int, ...] = None
    # Passed to the encoder config as ``pool``.
    enc_pool: str = 'depthconv'
    # Passed to the encoder config as ``num_res_blocks``.
    enc_num_res_block: int = 2
    # Encoder channel multipliers; None/empty -> ``channel_mult``.
    enc_channel_mult: Tuple[int, ...] = None
    # Turns on gradient checkpointing in the encoder even when
    # ``use_checkpoint`` is False (the two are OR-ed).
    enc_grad_checkpoint: bool = False
    # Optional config for ``latent_net``; None -> no latent net is built.
    latent_net_conf: MLPSkipNetConfig = None

    def make_model(self):
        """Build a ``BeatGANsAutoencModel`` from this config."""
        return BeatGANsAutoencModel(self)
|
|
|
|
|
|
class BeatGANsAutoencModel(BeatGANsUNetModel):
    """UNet model conditioned on a vector produced by an image encoder.

    On top of ``BeatGANsUNetModel`` this adds:

    * ``encoder``: built by ``BeatGANsEncoderConfig(...).make_model()``; maps
      an image to a ``conf.enc_out_channels``-wide conditioning vector.
    * ``time_embed``: a ``TimeStyleSeperateEmbed`` replacing the base model's
      time embedding; its ``time_emb`` and ``emb`` outputs feed the UNet blocks.
    * ``user_classifier`` / ``non_user_classifier``: heads applied to the first
      and second half of the conditioning embedding. The second head sits
      behind a gradient-reversal layer. Both exist only when
      ``conf.num_users`` is not None.
    * ``latent_net``: optional, built when ``conf.latent_net_conf`` is set.
    """

    def __init__(self, conf: BeatGANsAutoencConfig):
        super().__init__(conf)
        self.conf = conf

        # Timestep and conditioning vector are embedded separately.
        self.time_embed = TimeStyleSeperateEmbed(
            time_channels=conf.model_channels,
            time_out_channels=conf.embed_channels,
        )

        # Image -> conditioning vector. Encoder-specific settings fall back to
        # the UNet's own settings when they are None/falsy.
        self.encoder = BeatGANsEncoderConfig(
            image_size=conf.image_size,
            in_channels=conf.in_channels,
            model_channels=conf.model_channels,
            out_hid_channels=conf.enc_out_channels,
            out_channels=conf.enc_out_channels,
            num_res_blocks=conf.enc_num_res_block,
            attention_resolutions=(conf.enc_attn_resolutions
                                   or conf.attention_resolutions),
            dropout=conf.dropout,
            channel_mult=conf.enc_channel_mult or conf.channel_mult,
            use_time_condition=False,
            conv_resample=conf.conv_resample,
            dims=conf.dims,
            use_checkpoint=conf.use_checkpoint or conf.enc_grad_checkpoint,
            num_heads=conf.num_heads,
            num_head_channels=conf.num_head_channels,
            resblock_updown=conf.resblock_updown,
            use_new_attention_order=conf.use_new_attention_order,
            pool=conf.enc_pool,
        ).make_model()

        # The user heads need an integer class count (nn.Linear cannot be
        # built with None). forward() only uses them when num_users is set.
        if conf.num_users is not None:
            half = conf.enc_out_channels // 2
            self.user_classifier = UserClassifier(half, conf.num_users)
            # Gradient-reversed head on the other half; presumably meant to
            # push that half to be user-agnostic -- confirm with training code.
            self.non_user_classifier = UserClassifierGradientReverse(
                half, conf.num_users)

        if conf.latent_net_conf is not None:
            self.latent_net = conf.latent_net_conf.make_model()

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """Sample from N(mu, exp(logvar)) using the reparameterization trick.

        :param mu: (Tensor) mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) log-variance of the latent Gaussian [B x D];
            the standard deviation is ``exp(0.5 * logvar)``.
        :return: (Tensor) ``mu + eps * std`` with ``eps ~ N(0, 1)``, [B x D]
        """
        assert self.conf.is_stochastic
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def sample_z(self, n: int, device):
        """Draw ``n`` standard-normal vectors of width ``conf.enc_out_channels``."""
        assert self.conf.is_stochastic
        return torch.randn(n, self.conf.enc_out_channels, device=device)

    def noise_to_cond(self, noise: Tensor):
        """Map noise to a conditioning vector. Not supported by this model.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError()

    def encode(self, x):
        """Run the encoder on image batch ``x``; returns ``{'cond': cond}``."""
        cond = self.encoder.forward(x)
        return {'cond': cond}

    def _resblock_modules(self):
        """Yield every ResBlock in the input, middle and output blocks, in order."""
        for part in (self.input_blocks, self.middle_block, self.output_blocks):
            for module in part.modules():
                if isinstance(module, ResBlock):
                    yield module

    @property
    def stylespace_sizes(self):
        """Output width of each ResBlock's last ``cond_emb_layers`` layer, in order."""
        return [module.cond_emb_layers[-1].weight.shape[0]
                for module in self._resblock_modules()]

    def encode_stylespace(self, x, return_vector: bool = True):
        """Encode ``x`` and project it through every ResBlock's ``cond_emb_layers``.

        Args:
            x: image batch passed to the encoder.
            return_vector: if True, concatenate the per-block projections
                along dim 1; otherwise return them as a list.
        """
        cond = self.encoder.forward(x)
        projections = [module.cond_emb_layers.forward(cond)
                       for module in self._resblock_modules()]
        if return_vector:
            return torch.cat(projections, dim=1)
        return projections

    def forward(self,
                x,
                t,
                y=None,
                x_start=None,
                cond=None,
                style=None,
                noise=None,
                t_cond=None,
                **kwargs):
        """Apply the model to an input batch.

        Args:
            x: input to the UNet. If None, the input and middle blocks are
                skipped and the output blocks run from ``h = None``.
            t: timesteps; None disables the timestep embedding.
            y: class labels; must be given iff ``conf.num_classes`` is set.
                Class conditioning is not implemented.
            x_start: the original image to encode when ``cond`` is not given.
            cond: output of the encoder; computed from ``x_start`` if None.
            style: accepted for interface compatibility; unused.
            noise: random noise to predict the cond from (unsupported:
                ``noise_to_cond`` raises ``NotImplementedError``).
            t_cond: timesteps for the conditioning path; defaults to ``t``.
            **kwargs: ignored.

        Returns:
            AutoencReturn with ``pred``, ``cond``, and -- when
            ``conf.num_users`` is set -- ``user_pred`` / ``non_user_pred``
            (otherwise those two are None).
        """
        if t_cond is None:
            t_cond = t

        if noise is not None:
            cond = self.noise_to_cond(noise)

        if cond is None and x is not None:
            assert len(x) == len(x_start), f'{len(x)} != {len(x_start)}'
            cond = self.encode(x_start)['cond']

        if t is not None:
            _t_emb = timestep_embedding(t, self.conf.model_channels)
            _t_cond_emb = timestep_embedding(t_cond, self.conf.model_channels)
        else:
            _t_emb = None
            _t_cond_emb = None

        if not self.conf.resnet_two_cond:
            raise NotImplementedError()
        res = self.time_embed.forward(
            time_emb=_t_emb,
            cond=cond,
            time_cond_emb=_t_cond_emb,
        )
        # Timestep embedding and conditioning embedding are fed separately to
        # every input / middle / output block.
        emb = res.time_emb
        cond_emb = res.emb

        assert (y is not None) == (
            self.conf.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        if self.conf.num_classes is not None:
            raise NotImplementedError()

        # Bound unconditionally so the return below works without user heads.
        user_pred = None
        non_user_pred = None
        if self.conf.num_users is not None:
            half = self.conf.enc_out_channels // 2
            user_pred = self.user_classifier(cond_emb[:, :half])
            non_user_pred = self.non_user_classifier(cond_emb[:, half:])

        # One list of skip activations per resolution level.
        hs = [[] for _ in range(len(self.conf.channel_mult))]

        if x is not None:
            h = x.type(self.dtype)

            # input blocks
            k = 0
            for i, n_blocks in enumerate(self.input_num_blocks):
                for _ in range(n_blocks):
                    h = self.input_blocks[k](h, emb=emb, cond=cond_emb)
                    hs[i].append(h)
                    k += 1
            assert k == len(self.input_blocks)

            # middle blocks
            h = self.middle_block(h, emb=emb, cond=cond_emb)
        else:
            h = None

        # output blocks: skip activations are consumed level-by-level in
        # reverse; a missing level or an exhausted list yields no lateral.
        k = 0
        for i, n_blocks in enumerate(self.output_num_blocks):
            for _ in range(n_blocks):
                try:
                    lateral = hs[-i - 1].pop()
                except IndexError:
                    lateral = None
                h = self.output_blocks[k](h,
                                          emb=emb,
                                          cond=cond_emb,
                                          lateral=lateral)
                k += 1

        pred = self.out(h)
        return AutoencReturn(pred=pred,
                             cond=cond,
                             user_pred=user_pred,
                             non_user_pred=non_user_pred)
|
|
|
|
class UserClassifier(nn.Module):
    """Two-layer MLP that turns a feature vector into per-user probabilities.

    Args:
        in_channels: size of the input's feature dimension (dim 1).
        num_classes: number of users, i.e. the output width.

    The final layer is ``nn.Softmax(dim=1)``, so each output row sums to 1.
    NOTE(review): if the training loss is ``nn.CrossEntropyLoss`` (which
    expects unnormalized logits), softmax gets applied twice -- confirm.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        hidden_width = 256
        layers = [
            nn.Linear(in_channels, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, num_classes),
            nn.Softmax(dim=1),
        ]
        # Keep the attribute name and layer indices stable (fc.0 / fc.2) so
        # saved state_dicts keep loading.
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        probabilities = self.fc(x)
        return probabilities
|
|
|
|
class GradReverse(torch.autograd.Function):
    """Gradient reversal layer, as used in domain-adversarial networks.

    Forward pass: identity (returns a view of the input with the same values).
    Backward pass: the incoming gradient is negated.
    """

    @staticmethod
    def forward(ctx, x):
        same_values = x.view_as(x)
        return same_values

    @staticmethod
    def backward(ctx, grad_output):
        return -grad_output
|
|
|
|
class GradientReversalLayer(nn.Module):
    """``nn.Module`` wrapper that applies ``GradReverse`` to its input.

    Identity on the forward pass; negates gradients on the backward pass.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        reversed_inputs = GradReverse.apply(inputs)
        return reversed_inputs
|
|
|
|
class UserClassifierGradientReverse(nn.Module):
    """A ``UserClassifier`` placed behind a ``GradientReversalLayer``.

    The forward output equals ``UserClassifier``'s. Gradients flowing back
    into the input are negated by the reversal layer.

    Args:
        in_channels: size of the input's feature dimension.
        num_classes: number of users (output width).
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        # Registration order (grl, then fc) and names kept for state_dict keys.
        self.grl = GradientReversalLayer()
        self.fc = UserClassifier(in_channels, num_classes)

    def forward(self, x):
        return self.fc(self.grl(x))
|
|
|
|
class AutoencReturn(NamedTuple):
    """Outputs returned by ``BeatGANsAutoencModel.forward``."""
    # Output of the model's final ``out`` layer.
    pred: Tensor
    # Conditioning vector that was used (encoder output, or the one passed in).
    cond: Tensor = None
    # ``user_classifier`` output on the first half of the conditioning embedding.
    user_pred: Tensor = None
    # ``non_user_classifier`` (gradient-reversed) output on the second half.
    non_user_pred: Tensor = None
|
|
|
|
|
|
class EmbedReturn(NamedTuple):
    """Outputs returned by ``TimeStyleSeperateEmbed.forward``."""
    # Conditioning embedding; TimeStyleSeperateEmbed sets it to the style vector.
    emb: Tensor = None
    # Embedded timestep, or None when no timestep embedding was supplied.
    time_emb: Tensor = None
    # Style vector derived from the conditioning input.
    style: Tensor = None
|
|
|
|
class TimeStyleSeperateEmbed(nn.Module):
    """Embed the timestep with an MLP and pass the conditioning vector through.

    Args:
        time_channels: width of the incoming timestep features (presumably the
            output of ``timestep_embedding`` -- confirm at call sites).
        time_out_channels: width of the produced time embedding.
    """

    def __init__(self, time_channels, time_out_channels):
        super().__init__()
        out_ch = time_out_channels
        self.time_embed = nn.Sequential(
            linear(time_channels, out_ch),
            nn.SiLU(),
            linear(out_ch, out_ch),
        )
        # NOTE(review): not used by forward(); kept so existing checkpoints
        # containing ``cond_combine.*`` keys still load.
        self.cond_combine = nn.Sequential(
            nn.Linear(out_ch * 2, out_ch),
            nn.SiLU(),
        )
        self.style = nn.Identity()

    def forward(self, time_emb=None, cond=None, **kwargs):
        """Return ``EmbedReturn(emb=style, time_emb=..., style=style)``.

        ``time_emb`` is run through the time MLP unless it is None, in which
        case None is passed on. ``cond`` goes through ``self.style`` (an
        identity). Extra keyword arguments (e.g. ``time_cond_emb``) are ignored.
        """
        embedded_time = None if time_emb is None else self.time_embed(time_emb)
        style_vec = self.style(cond)
        return EmbedReturn(emb=style_vec, time_emb=embedded_time, style=style_vec)