127 lines
2.5 KiB
Python
127 lines
2.5 KiB
Python
|
from enum import Enum
|
||
|
from torch import nn
|
||
|
|
||
|
|
||
|
class TrainMode(Enum):
    """Training modes, plus predicates that group the modes by capability."""

    manipulate = 'manipulate'
    diffusion = 'diffusion'
    latent_diffusion = 'latentdiffusion'

    def is_manipulate(self):
        """Return True only for ``TrainMode.manipulate``."""
        return self is TrainMode.manipulate

    def is_diffusion(self):
        """Return True for ``diffusion`` and for ``latent_diffusion``."""
        return self in (TrainMode.diffusion, TrainMode.latent_diffusion)

    def is_autoenc(self):
        """Return True only for ``TrainMode.diffusion``.

        ``latent_diffusion`` is not included in this check.
        """
        return self is TrainMode.diffusion

    def is_latent_diffusion(self):
        """Return True only for ``TrainMode.latent_diffusion``."""
        return self is TrainMode.latent_diffusion

    def use_latent_net(self):
        """Return whether a latent network is used; same answer as
        :meth:`is_latent_diffusion`."""
        return self.is_latent_diffusion()

    def require_dataset_infer(self):
        """Return whether training in this mode requires the latent
        variables to be available.

        True for ``latent_diffusion`` and ``manipulate``.
        """
        needs_latents = (TrainMode.latent_diffusion, TrainMode.manipulate)
        return self in needs_latents
|
||
|
|
||
|
class ModelType(Enum):
    """Kinds of backbone models."""

    ddpm = 'ddpm'
    autoencoder = 'autoencoder'

    def has_autoenc(self):
        """Return True only for ``ModelType.autoencoder``."""
        return self is ModelType.autoencoder

    def can_sample(self):
        """Return True only for ``ModelType.ddpm``."""
        return self is ModelType.ddpm
|
||
|
|
||
|
|
||
|
class ModelName(Enum):
    """Every supported model class, keyed by its configuration name."""

    beatgans_ddpm = 'beatgans_ddpm'
    beatgans_autoenc = 'beatgans_autoenc'
|
||
|
|
||
|
|
||
|
class ModelMeanType(Enum):
    """The kind of output the model is trained to predict."""

    # Only a single prediction target is defined.
    eps = 'eps'
|
||
|
|
||
|
|
||
|
class ModelVarType(Enum):
    """
    What is used as the model's output variance.

    Only fixed (non-learned) variance choices are defined here:
    ``fixed_small`` and ``fixed_large``. There is no learned or
    learned-range variance option in this enum.
    """

    fixed_small = 'fixed_small'
    fixed_large = 'fixed_large'
|
||
|
|
||
|
|
||
|
class LossType(Enum):
    """Training loss options, selectable by name."""

    mse = 'mse'
    l1 = 'l1'
|
||
|
|
||
|
|
||
|
class GenerativeType(Enum):
    """How a sample is generated."""

    ddpm = 'ddpm'
    ddim = 'ddim'
|
||
|
|
||
|
|
||
|
class OptimizerType(Enum):
    """Optimizer options, selectable by name."""

    adam = 'adam'
    adamw = 'adamw'
|
||
|
|
||
|
|
||
|
class Activation(Enum):
    """Activation functions selectable by name, convertible to torch modules."""

    none = 'none'
    relu = 'relu'
    lrelu = 'lrelu'
    silu = 'silu'
    tanh = 'tanh'

    def get_act(self):
        """Build and return a new ``nn.Module`` for this activation.

        ``none`` maps to ``nn.Identity``; ``lrelu`` uses a negative slope
        of 0.2. A fresh module instance is created on every call.

        Raises:
            NotImplementedError: if no module is registered for ``self``.
        """
        # The table is built inside the method on purpose. A dict assigned
        # in the Enum class body would itself become an enum member.
        factories = {
            Activation.none: nn.Identity,
            Activation.relu: nn.ReLU,
            Activation.lrelu: lambda: nn.LeakyReLU(negative_slope=0.2),
            Activation.silu: nn.SiLU,
            Activation.tanh: nn.Tanh,
        }
        make = factories.get(self)
        if make is None:
            raise NotImplementedError()
        return make()
|