DisMouse/choices.py

127 lines
2.5 KiB
Python
Raw Normal View History

2024-10-08 14:18:47 +02:00
from enum import Enum
from torch import nn
class TrainMode(Enum):
manipulate = 'manipulate'
diffusion = 'diffusion'
latent_diffusion = 'latentdiffusion'
def is_manipulate(self):
return self in [
TrainMode.manipulate,
]
def is_diffusion(self):
return self in [
TrainMode.diffusion,
TrainMode.latent_diffusion,
]
def is_autoenc(self):
return self in [
TrainMode.diffusion,
]
def is_latent_diffusion(self):
return self in [
TrainMode.latent_diffusion,
]
def use_latent_net(self):
return self.is_latent_diffusion()
def require_dataset_infer(self):
"""
whether training in this mode requires the latent variables to be available?
"""
return self in [
TrainMode.latent_diffusion,
TrainMode.manipulate,
]
class ModelType(Enum):
"""
Kinds of the backbone models
"""
ddpm = 'ddpm'
autoencoder = 'autoencoder'
def has_autoenc(self):
return self in [
ModelType.autoencoder,
]
def can_sample(self):
return self in [ModelType.ddpm]
class ModelName(Enum):
"""
List of all supported model classes
"""
beatgans_ddpm = 'beatgans_ddpm'
beatgans_autoenc = 'beatgans_autoenc'
class ModelMeanType(Enum):
"""
Which type of output the model predicts.
"""
eps = 'eps'
class ModelVarType(Enum):
"""
What is used as the model's output variance.
The LEARNED_RANGE option has been added to allow the model to predict
values between FIXED_SMALL and FIXED_LARGE, making its job easier.
"""
fixed_small = 'fixed_small'
fixed_large = 'fixed_large'
class LossType(Enum):
mse = 'mse'
l1 = 'l1'
class GenerativeType(Enum):
"""
How's a sample generated
"""
ddpm = 'ddpm'
ddim = 'ddim'
class OptimizerType(Enum):
adam = 'adam'
adamw = 'adamw'
class Activation(Enum):
none = 'none'
relu = 'relu'
lrelu = 'lrelu'
silu = 'silu'
tanh = 'tanh'
def get_act(self):
if self == Activation.none:
return nn.Identity()
elif self == Activation.relu:
return nn.ReLU()
elif self == Activation.lrelu:
return nn.LeakyReLU(negative_slope=0.2)
elif self == Activation.silu:
return nn.SiLU()
elif self == Activation.tanh:
return nn.Tanh()
else:
raise NotImplementedError()