"""Enumerations of the configuration choices used across the project."""

from enum import Enum

from torch import nn


class TrainMode(Enum):
    """Training modes, with predicates describing what each mode involves."""

    manipulate = 'manipulate'
    diffusion = 'diffusion'
    latent_diffusion = 'latentdiffusion'

    def is_manipulate(self) -> bool:
        """Return True for the ``manipulate`` mode."""
        return self is TrainMode.manipulate

    def is_diffusion(self) -> bool:
        """Return True for ``diffusion`` and ``latent_diffusion``."""
        return self in (TrainMode.diffusion, TrainMode.latent_diffusion)

    def is_autoenc(self) -> bool:
        """Return True for the ``diffusion`` mode only."""
        return self is TrainMode.diffusion

    def is_latent_diffusion(self) -> bool:
        """Return True for the ``latent_diffusion`` mode."""
        return self is TrainMode.latent_diffusion

    def use_latent_net(self) -> bool:
        """Return True exactly when the mode is latent diffusion."""
        return self.is_latent_diffusion()

    def require_dataset_infer(self) -> bool:
        """Return whether training in this mode requires the latent
        variables to be available.

        True for ``latent_diffusion`` and ``manipulate``.
        """
        return self in (TrainMode.latent_diffusion, TrainMode.manipulate)


class ModelType(Enum):
    """Kinds of the backbone models."""

    ddpm = 'ddpm'
    autoencoder = 'autoencoder'

    def has_autoenc(self) -> bool:
        """Return True for the ``autoencoder`` type."""
        return self is ModelType.autoencoder

    def can_sample(self) -> bool:
        """Return True for the ``ddpm`` type."""
        return self is ModelType.ddpm


class ModelName(Enum):
    """List of all supported model classes."""

    beatgans_ddpm = 'beatgans_ddpm'
    beatgans_autoenc = 'beatgans_autoenc'


class ModelMeanType(Enum):
    """Which type of output the model predicts."""

    eps = 'eps'


class ModelVarType(Enum):
    """What is used as the model's output variance.

    Only the two fixed choices, ``fixed_small`` and ``fixed_large``, are
    defined here. An earlier version of this docstring described a
    learned-range option predicting values between the two; no such member
    exists in this enum.
    """

    fixed_small = 'fixed_small'
    fixed_large = 'fixed_large'


class LossType(Enum):
    """Available loss choices."""

    mse = 'mse'
    l1 = 'l1'


class GenerativeType(Enum):
    """How a sample is generated."""

    ddpm = 'ddpm'
    ddim = 'ddim'


class OptimizerType(Enum):
    """Available optimizer choices."""

    adam = 'adam'
    adamw = 'adamw'


class Activation(Enum):
    """Activation choices, each convertible to a torch module."""

    none = 'none'
    relu = 'relu'
    lrelu = 'lrelu'
    silu = 'silu'
    tanh = 'tanh'

    def get_act(self) -> nn.Module:
        """Build a new torch module for this activation.

        Returns:
            A freshly constructed module on every call: ``nn.Identity`` for
            ``none``, ``nn.ReLU`` for ``relu``, ``nn.LeakyReLU`` with
            ``negative_slope=0.2`` for ``lrelu``, ``nn.SiLU`` for ``silu``
            and ``nn.Tanh`` for ``tanh``.

        Raises:
            NotImplementedError: If a member has no registered factory.
        """
        try:
            factory = _ACTIVATION_FACTORIES[self]
        except KeyError:
            raise NotImplementedError(
                f'no activation module registered for {self!r}') from None
        return factory()


# Kept at module level on purpose: a plain dict assigned inside the Enum body
# would itself be turned into an enum member.
_ACTIVATION_FACTORIES = {
    Activation.none: nn.Identity,
    Activation.relu: nn.ReLU,
    Activation.lrelu: lambda: nn.LeakyReLU(negative_slope=0.2),
    Activation.silu: nn.SiLU,
    Activation.tanh: nn.Tanh,
}