DisMouse/config.py

303 lines
12 KiB
Python
Raw Permalink Normal View History

2024-10-08 14:18:47 +02:00
from model.unet import ScaleAt
from model.latentnet import *
from diffusion.resample import UniformSampler
from diffusion.diffusion import space_timesteps
from typing import Tuple
from config_base import BaseConfig
from dataset import *
from diffusion import *
from diffusion.base import GenerativeType, LossType, ModelMeanType, ModelVarType, get_named_beta_schedule
from model import *
from choices import *
@dataclass
class PretrainConfig(BaseConfig):
    """Pointer to a previously trained run.

    Used by ``TrainConfig.pretrain`` and ``TrainConfig.continue_from``.
    """
    # Identifier of the pretrained run.
    name: str
    # Location of the stored weights (presumably a checkpoint file path —
    # NOTE(review): confirm against the loader that consumes it).
    path: str
@dataclass
class TrainConfig(BaseConfig):
    """Top-level training / evaluation configuration.

    Besides holding every hyper-parameter, it builds the derived
    sub-configs used elsewhere:

    - diffusion process configs (``make_diffusion_conf`` and friends);
    - the timestep sampler (``make_T_sampler``);
    - the dataset (``make_dataset``);
    - the network config (``make_model_conf``).
    """
    seed: int = 0
    # --- training mode / schedule ---
    train_mode: TrainMode = TrainMode.diffusion
    train_cond0_prob: float = 0
    train_pred_xstart_detach: bool = True
    train_interpolate_prob: float = 0
    train_interpolate_img: bool = False
    accum_batches: int = 1
    autoenc_mid_attn: bool = True
    batch_size: int = 512
    # Falls back to ``batch_size`` in ``__post_init__`` when falsy.
    batch_size_eval: int = 4
    # --- diffusion process of the main model ---
    beatgans_gen_type: GenerativeType = GenerativeType.ddim
    beatgans_loss_type: LossType = LossType.mse
    beatgans_model_mean_type: ModelMeanType = ModelMeanType.eps
    beatgans_model_var_type: ModelVarType = ModelVarType.fixed_large
    beatgans_rescale_timesteps: bool = False
    # --- latent diffusion process ---
    latent_infer_path: str = None
    latent_znormalize: bool = False
    latent_gen_type: GenerativeType = GenerativeType.ddim
    latent_loss_type: LossType = LossType.mse
    latent_model_mean_type: ModelMeanType = ModelMeanType.eps
    latent_model_var_type: ModelVarType = ModelVarType.fixed_large
    latent_rescale_timesteps: bool = False
    latent_T_eval: int = 1_000
    latent_clip_sample: bool = False
    latent_beta_scheduler: str = 'linear'
    beta_scheduler: str = 'linear'
    # --- data ---
    data_name: str = ''
    # Falls back to ``data_name`` in ``__post_init__`` when falsy.
    data_val_name: str = None
    diffusion_type: str = None
    dropout: float = 0.1
    ema_decay: float = 0.9999
    # --- evaluation ---
    eval_num_images: int = 5_000
    eval_every_samples: int = 200_000
    eval_ema_every_samples: int = 200_000
    fid_use_torch: bool = True
    # --- optimisation ---
    fp16: bool = False
    grad_clip: float = 1
    img_size: int = 64
    lr: float = 0.0001
    optimizer: OptimizerType = OptimizerType.adam
    weight_decay: float = 0
    # --- model selection (model_conf / model_type are filled by make_model_conf) ---
    model_conf: ModelConfig = None
    model_name: ModelName = None
    model_type: ModelType = None
    # --- network architecture ---
    net_attn: Tuple[int, ...] = None
    net_beatgans_attn_head: int = 1
    net_beatgans_embed_channels: int = 128
    net_resblock_updown: bool = True
    net_enc_use_time: bool = False
    net_enc_pool: str = 'adaptivenonzero'
    net_beatgans_gradient_checkpoint: bool = False
    net_beatgans_resnet_two_cond: bool = False
    net_beatgans_resnet_use_zero_module: bool = True
    net_beatgans_resnet_scale_at: ScaleAt = ScaleAt.after_norm
    net_beatgans_resnet_cond_channels: int = None
    net_ch_mult: Tuple[int, ...] = None
    net_ch: int = 64
    net_enc_attn: Tuple[int, ...] = None
    net_enc_k: int = None
    net_enc_num_res_blocks: int = 2
    net_enc_channel_mult: Tuple[int, ...] = None
    net_enc_grad_checkpoint: bool = False
    net_autoenc_stochastic: bool = False
    # --- latent network ---
    net_latent_activation: Activation = Activation.silu
    net_latent_channel_mult: Tuple[int, ...] = (1, 2, 4)
    net_latent_condition_bias: float = 0
    net_latent_dropout: float = 0
    net_latent_layers: int = None
    net_latent_net_last_act: Activation = Activation.none
    net_latent_net_type: LatentNetType = LatentNetType.none
    net_latent_num_hid_channels: int = 1024
    net_latent_num_time_layers: int = 2
    net_latent_skip_layers: Tuple[int, ...] = None
    net_latent_time_emb_channels: int = 64
    net_latent_use_norm: bool = False
    net_latent_time_last_act: bool = False
    net_num_res_blocks: int = 2
    net_num_input_res_blocks: int = None
    net_enc_num_cls: int = None
    # --- runtime / bookkeeping ---
    num_workers: int = 4
    parallel: bool = False
    postfix: str = ''
    sample_size: int = 64
    sample_every_samples: int = 20_000
    save_every_samples: int = 100_000
    style_ch: int = 128
    T_eval: int = 1_000
    T_sampler: str = 'uniform'
    T: int = 1_000
    total_samples: int = 10_000_000
    warmup: int = 0
    pretrain: PretrainConfig = None
    continue_from: PretrainConfig = None
    eval_programs: Tuple[str, ...] = None
    eval_path: str = None
    base_dir: str = 'checkpoints'
    name: str = ''
    # Derived per instance in ``__post_init__`` when left as None. A
    # class-level f-string would be evaluated once against the class
    # defaults and ignore per-instance ``base_dir`` / ``name``.
    logdir: str = None
    num_users: int = 0

    def __post_init__(self):
        """Fill in fields whose defaults depend on other fields.

        Only values present at construction time are used; assigning
        ``name`` / ``base_dir`` afterwards does not update ``logdir``.
        """
        # Eval / validation settings fall back to their training
        # counterparts when unset (any falsy value).
        self.batch_size_eval = self.batch_size_eval or self.batch_size
        self.data_val_name = self.data_val_name or self.data_name
        if self.logdir is None:
            # Plain concatenation, no separator inserted. NOTE(review):
            # ``name`` presumably carries its own leading '/' — confirm
            # against callers.
            self.logdir = f'{self.base_dir}{self.name}'

    def scale_up_gpus(self, num_gpus, num_nodes=1):
        """Scale batch sizes and sample-count intervals by ``num_gpus * num_nodes``.

        Mutates ``self`` in place and returns it so calls can be chained.
        """
        world = num_gpus * num_nodes
        self.eval_ema_every_samples *= world
        self.eval_every_samples *= world
        self.sample_every_samples *= world
        self.batch_size *= world
        self.batch_size_eval *= world
        return self

    @property
    def batch_size_effective(self):
        """Samples per optimiser step: ``batch_size * accum_batches``."""
        return self.batch_size * self.accum_batches

    @staticmethod
    def _section_counts(gen_type, T):
        """Return the ``section_counts`` argument for ``space_timesteps``.

        ``[T]`` for DDPM sampling, ``'ddim{T}'`` for DDIM sampling.

        Raises:
            NotImplementedError: for any other generative type.
        """
        if gen_type == GenerativeType.ddpm:
            return [T]
        if gen_type == GenerativeType.ddim:
            return f'ddim{T}'
        raise NotImplementedError()

    def _make_diffusion_conf(self, T=None):
        """Build the spaced diffusion config for the main model.

        Args:
            T: number of timesteps requested from ``space_timesteps``;
                defaults to ``self.T``. The beta schedule is always built
                over the full ``self.T`` steps.

        Raises:
            NotImplementedError: if ``diffusion_type`` is not ``'beatgans'``
                or ``beatgans_gen_type`` is neither ddpm nor ddim.
        """
        if self.diffusion_type != 'beatgans':
            raise NotImplementedError()
        if T is None:
            T = self.T
        section_counts = self._section_counts(self.beatgans_gen_type, T)
        return SpacedDiffusionBeatGansConfig(
            gen_type=self.beatgans_gen_type,
            model_type=self.model_type,
            betas=get_named_beta_schedule(self.beta_scheduler, self.T),
            model_mean_type=self.beatgans_model_mean_type,
            model_var_type=self.beatgans_model_var_type,
            loss_type=self.beatgans_loss_type,
            rescale_timesteps=self.beatgans_rescale_timesteps,
            use_timesteps=space_timesteps(num_timesteps=self.T,
                                          section_counts=section_counts),
            fp16=self.fp16,
        )

    def _make_latent_diffusion_conf(self, T=None):
        """Build the spaced diffusion config for the latent model.

        Args:
            T: number of timesteps requested from ``space_timesteps``;
                defaults to ``self.T``. The beta schedule is always built
                over the full ``self.T`` steps.

        Raises:
            NotImplementedError: if ``latent_gen_type`` is neither ddpm
                nor ddim.
        """
        if T is None:
            T = self.T
        section_counts = self._section_counts(self.latent_gen_type, T)
        return SpacedDiffusionBeatGansConfig(
            train_pred_xstart_detach=self.train_pred_xstart_detach,
            gen_type=self.latent_gen_type,
            # The latent process is always configured with ModelType.ddpm.
            model_type=ModelType.ddpm,
            betas=get_named_beta_schedule(self.latent_beta_scheduler, self.T),
            model_mean_type=self.latent_model_mean_type,
            model_var_type=self.latent_model_var_type,
            loss_type=self.latent_loss_type,
            rescale_timesteps=self.latent_rescale_timesteps,
            use_timesteps=space_timesteps(num_timesteps=self.T,
                                          section_counts=section_counts),
            fp16=self.fp16,
        )

    @property
    def model_out_channels(self):
        """Number of output channels of the network (fixed at 2)."""
        return 2

    @property
    def model_input_channels(self):
        """Number of input channels of the network (fixed at 2)."""
        return 2

    def make_T_sampler(self):
        """Return the training timestep sampler named by ``T_sampler``.

        Raises:
            NotImplementedError: for anything other than ``'uniform'``.
        """
        if self.T_sampler == 'uniform':
            return UniformSampler(self.T)
        raise NotImplementedError()

    def make_diffusion_conf(self):
        """Main-model diffusion config using all ``self.T`` steps."""
        return self._make_diffusion_conf(self.T)

    def make_eval_diffusion_conf(self):
        """Main-model diffusion config spaced to ``self.T_eval`` steps."""
        return self._make_diffusion_conf(T=self.T_eval)

    def make_latent_diffusion_conf(self):
        """Latent diffusion config using all ``self.T`` steps."""
        return self._make_latent_diffusion_conf(T=self.T)

    def make_latent_eval_diffusion_conf(self):
        """Latent diffusion config spaced to ``self.latent_T_eval`` steps."""
        return self._make_latent_diffusion_conf(T=self.latent_T_eval)

    def make_dataset(self, taskdata, tasklabels):
        """Wrap ``taskdata`` / ``tasklabels`` in a ``SimpleSet`` (``intflag=True``)."""
        return SimpleSet(taskdata, tasklabels, intflag=True)

    def _make_latent_net_conf(self):
        """Return the latent-net config selected by ``net_latent_net_type``.

        Returns None for ``LatentNetType.none``.

        Raises:
            NotImplementedError: for any other latent net type.
        """
        if self.net_latent_net_type == LatentNetType.none:
            return None
        if self.net_latent_net_type == LatentNetType.skip:
            return MLPSkipNetConfig(
                num_channels=self.style_ch,
                skip_layers=self.net_latent_skip_layers,
                num_hid_channels=self.net_latent_num_hid_channels,
                num_layers=self.net_latent_layers,
                num_time_emb_channels=self.net_latent_time_emb_channels,
                activation=self.net_latent_activation,
                use_norm=self.net_latent_use_norm,
                condition_bias=self.net_latent_condition_bias,
                dropout=self.net_latent_dropout,
                last_act=self.net_latent_net_last_act,
                num_time_layers=self.net_latent_num_time_layers,
                time_last_act=self.net_latent_time_last_act,
            )
        raise NotImplementedError()

    def _unet_common_kwargs(self):
        """Keyword arguments shared by both UNet config variants."""
        return dict(
            attention_resolutions=self.net_attn,
            channel_mult=self.net_ch_mult,
            conv_resample=True,
            dropout=self.dropout,
            embed_channels=self.net_beatgans_embed_channels,
            image_size=self.img_size,
            in_channels=self.model_input_channels,
            model_channels=self.net_ch,
            num_classes=None,
            num_head_channels=-1,
            num_heads_upsample=-1,
            num_heads=self.net_beatgans_attn_head,
            num_res_blocks=self.net_num_res_blocks,
            num_input_res_blocks=self.net_num_input_res_blocks,
            out_channels=self.model_out_channels,
            resblock_updown=self.net_resblock_updown,
            use_checkpoint=self.net_beatgans_gradient_checkpoint,
            use_new_attention_order=False,
            resnet_two_cond=self.net_beatgans_resnet_two_cond,
            resnet_use_zero_module=self.net_beatgans_resnet_use_zero_module,
        )

    def make_model_conf(self):
        """Build the network config for ``model_name``.

        Sets ``self.model_type`` and ``self.model_conf`` as side effects and
        returns ``self.model_conf``.

        Raises:
            NotImplementedError: for an unsupported ``model_name`` or
                ``net_latent_net_type``. In the latter case ``model_type``
                has already been updated.
        """
        if self.model_name == ModelName.beatgans_ddpm:
            self.model_type = ModelType.ddpm
            # NOTE(review): dims=2 here while the autoencoder branch uses
            # dims=1 — confirm this branch is still intended.
            self.model_conf = BeatGANsUNetConfig(
                dims=2,
                **self._unet_common_kwargs(),
            )
        elif self.model_name == ModelName.beatgans_autoenc:
            self.model_type = ModelType.autoencoder
            latent_net_conf = self._make_latent_net_conf()
            self.model_conf = BeatGANsAutoencConfig(
                dims=1,
                enc_out_channels=self.style_ch,
                enc_pool=self.net_enc_pool,
                enc_num_res_block=self.net_enc_num_res_blocks,
                enc_channel_mult=self.net_enc_channel_mult,
                enc_grad_checkpoint=self.net_enc_grad_checkpoint,
                enc_attn_resolutions=self.net_enc_attn,
                latent_net_conf=latent_net_conf,
                resnet_cond_channels=self.net_beatgans_resnet_cond_channels,
                num_users=self.num_users,
                **self._unet_common_kwargs(),
            )
        else:
            raise NotImplementedError(self.model_name)
        return self.model_conf