from model.unet import ScaleAt
from model.latentnet import *
from diffusion.resample import UniformSampler
from diffusion.diffusion import space_timesteps
from typing import Tuple

from config_base import BaseConfig
from dataset import *
from diffusion import *
from diffusion.base import (
    GenerativeType,
    LossType,
    ModelMeanType,
    ModelVarType,
    get_named_beta_schedule,
)
from model import *
from choices import *

# NOTE: the import order above is intentionally left untouched; with several
# star-imports, reordering could change which module wins a name clash.


def _section_counts_for(gen_type, T):
    """Build the ``section_counts`` argument passed to ``space_timesteps``.

    Args:
        gen_type: a ``GenerativeType`` member selecting the sampler family.
        T: number of sampling steps requested.

    Returns:
        ``[T]`` for ``GenerativeType.ddpm`` and the string ``'ddim<T>'`` for
        ``GenerativeType.ddim``.

    Raises:
        NotImplementedError: for any other ``gen_type``.
    """
    if gen_type == GenerativeType.ddpm:
        return [T]
    if gen_type == GenerativeType.ddim:
        return f'ddim{T}'
    raise NotImplementedError()


@dataclass
class PretrainConfig(BaseConfig):
    """A ``name`` / ``path`` pair; used as the type of ``TrainConfig.pretrain``
    and ``TrainConfig.continue_from``."""
    name: str
    path: str


@dataclass
class TrainConfig(BaseConfig):
    """Flat bag of run settings plus factory methods.

    The fields hold plain values; the ``make_*`` methods translate them into
    the config objects for the diffusion process, the timestep sampler, the
    dataset and the network.
    """
    seed: int = 0
    train_mode: TrainMode = TrainMode.diffusion
    train_cond0_prob: float = 0
    train_pred_xstart_detach: bool = True
    train_interpolate_prob: float = 0
    train_interpolate_img: bool = False
    accum_batches: int = 1
    autoenc_mid_attn: bool = True
    batch_size: int = 512
    # Falsy values are replaced by ``batch_size`` in ``__post_init__``.
    batch_size_eval: int = 4
    beatgans_gen_type: GenerativeType = GenerativeType.ddim
    beatgans_loss_type: LossType = LossType.mse
    beatgans_model_mean_type: ModelMeanType = ModelMeanType.eps
    beatgans_model_var_type: ModelVarType = ModelVarType.fixed_large
    beatgans_rescale_timesteps: bool = False
    latent_infer_path: str = None
    latent_znormalize: bool = False
    latent_gen_type: GenerativeType = GenerativeType.ddim
    latent_loss_type: LossType = LossType.mse
    latent_model_mean_type: ModelMeanType = ModelMeanType.eps
    latent_model_var_type: ModelVarType = ModelVarType.fixed_large
    latent_rescale_timesteps: bool = False
    latent_T_eval: int = 1_000
    latent_clip_sample: bool = False
    latent_beta_scheduler: str = 'linear'
    beta_scheduler: str = 'linear'
    data_name: str = ''
    # Falsy values are replaced by ``data_name`` in ``__post_init__``.
    data_val_name: str = None
    # Only 'beatgans' is accepted by ``_make_diffusion_conf``.
    diffusion_type: str = None
    dropout: float = 0.1
    ema_decay: float = 0.9999
    eval_num_images: int = 5_000
    eval_every_samples: int = 200_000
    eval_ema_every_samples: int = 200_000
    fid_use_torch: bool = True
    fp16: bool = False
    grad_clip: float = 1
    img_size: int = 64
    lr: float = 0.0001
    optimizer: OptimizerType = OptimizerType.adam
    weight_decay: float = 0
    # ``model_conf`` and ``model_type`` are overwritten by ``make_model_conf``.
    model_conf: ModelConfig = None
    model_name: ModelName = None
    model_type: ModelType = None
    net_attn: Tuple[int] = None
    net_beatgans_attn_head: int = 1
    net_beatgans_embed_channels: int = 128
    net_resblock_updown: bool = True
    net_enc_use_time: bool = False
    net_enc_pool: str = 'adaptivenonzero'
    net_beatgans_gradient_checkpoint: bool = False
    net_beatgans_resnet_two_cond: bool = False
    net_beatgans_resnet_use_zero_module: bool = True
    net_beatgans_resnet_scale_at: ScaleAt = ScaleAt.after_norm
    net_beatgans_resnet_cond_channels: int = None
    net_ch_mult: Tuple[int] = None
    net_ch: int = 64
    net_enc_attn: Tuple[int] = None
    net_enc_k: int = None
    net_enc_num_res_blocks: int = 2
    net_enc_channel_mult: Tuple[int] = None
    net_enc_grad_checkpoint: bool = False
    net_autoenc_stochastic: bool = False
    net_latent_activation: Activation = Activation.silu
    net_latent_channel_mult: Tuple[int] = (1, 2, 4)
    net_latent_condition_bias: float = 0
    net_latent_dropout: float = 0
    net_latent_layers: int = None
    net_latent_net_last_act: Activation = Activation.none
    net_latent_net_type: LatentNetType = LatentNetType.none
    net_latent_num_hid_channels: int = 1024
    net_latent_num_time_layers: int = 2
    net_latent_skip_layers: Tuple[int] = None
    net_latent_time_emb_channels: int = 64
    net_latent_use_norm: bool = False
    net_latent_time_last_act: bool = False
    net_num_res_blocks: int = 2
    net_num_input_res_blocks: int = None
    net_enc_num_cls: int = None
    num_workers: int = 4
    parallel: bool = False
    postfix: str = ''
    sample_size: int = 64
    sample_every_samples: int = 20_000
    save_every_samples: int = 100_000
    style_ch: int = 128
    T_eval: int = 1_000
    T_sampler: str = 'uniform'
    T: int = 1_000
    total_samples: int = 10_000_000
    warmup: int = 0
    pretrain: PretrainConfig = None
    continue_from: PretrainConfig = None
    eval_programs: Tuple[str] = None
    eval_path: str = None
    base_dir: str = 'checkpoints'
    name: str = ''
    # NOTE(review): this f-string is evaluated once, while the class body is
    # executed, using the class-level defaults just above. The default is
    # therefore always 'checkpoints' and does NOT follow a per-instance
    # ``base_dir`` / ``name``. Left as-is because callers may depend on the
    # current value; confirm whether a derived path was intended.
    logdir: str = f'{base_dir}{name}'
    num_users: int = 0

    def __post_init__(self):
        """Default the eval batch size and validation dataset name.

        A falsy ``batch_size_eval`` / ``data_val_name`` is replaced by
        ``batch_size`` / ``data_name`` respectively.
        """
        self.batch_size_eval = self.batch_size_eval or self.batch_size
        self.data_val_name = self.data_val_name or self.data_name

    def scale_up_gpus(self, num_gpus, num_nodes=1):
        """Multiply the per-step sample counters and batch sizes in place.

        Every affected field is scaled by ``num_gpus * num_nodes``.

        Returns:
            ``self``, so the call can be chained.
        """
        world = num_gpus * num_nodes
        self.eval_ema_every_samples *= world
        self.eval_every_samples *= world
        self.sample_every_samples *= world
        self.batch_size *= world
        self.batch_size_eval *= world
        return self

    @property
    def batch_size_effective(self):
        """``batch_size`` multiplied by ``accum_batches``."""
        return self.batch_size * self.accum_batches

    def _make_diffusion_conf(self, T=None):
        """Build the ``SpacedDiffusionBeatGansConfig`` for the main model.

        Args:
            T: number of sampling steps used to pick the timestep subset.
                ``None`` falls back to ``self.T`` (previously ``None`` was
                forwarded verbatim and produced ``[None]`` / ``'ddimNone'``).

        Raises:
            NotImplementedError: if ``diffusion_type`` is not ``'beatgans'``
                or ``beatgans_gen_type`` is neither ddpm nor ddim.
        """
        if self.diffusion_type != 'beatgans':
            raise NotImplementedError()
        if T is None:
            T = self.T
        # Resolved before the schedule is built so an unsupported generative
        # type fails first, as it always has.
        section_counts = _section_counts_for(self.beatgans_gen_type, T)
        return SpacedDiffusionBeatGansConfig(
            gen_type=self.beatgans_gen_type,
            model_type=self.model_type,
            betas=get_named_beta_schedule(self.beta_scheduler, self.T),
            model_mean_type=self.beatgans_model_mean_type,
            model_var_type=self.beatgans_model_var_type,
            loss_type=self.beatgans_loss_type,
            rescale_timesteps=self.beatgans_rescale_timesteps,
            use_timesteps=space_timesteps(num_timesteps=self.T,
                                          section_counts=section_counts),
            fp16=self.fp16,
        )

    def _make_latent_diffusion_conf(self, T=None):
        """Build the ``SpacedDiffusionBeatGansConfig`` for the latent model.

        Uses the ``latent_*`` fields and always sets
        ``model_type=ModelType.ddpm``.

        Args:
            T: number of sampling steps used to pick the timestep subset.
                ``None`` falls back to ``self.T``.

        Raises:
            NotImplementedError: if ``latent_gen_type`` is neither ddpm nor
                ddim.
        """
        if T is None:
            T = self.T
        section_counts = _section_counts_for(self.latent_gen_type, T)
        return SpacedDiffusionBeatGansConfig(
            train_pred_xstart_detach=self.train_pred_xstart_detach,
            gen_type=self.latent_gen_type,
            model_type=ModelType.ddpm,
            betas=get_named_beta_schedule(self.latent_beta_scheduler, self.T),
            model_mean_type=self.latent_model_mean_type,
            model_var_type=self.latent_model_var_type,
            loss_type=self.latent_loss_type,
            rescale_timesteps=self.latent_rescale_timesteps,
            use_timesteps=space_timesteps(num_timesteps=self.T,
                                          section_counts=section_counts),
            fp16=self.fp16,
        )

    @property
    def model_out_channels(self):
        """Number of output channels given to the network config (fixed: 2)."""
        return 2

    @property
    def model_input_channels(self):
        """Number of input channels given to the network config (fixed: 2)."""
        return 2

    def make_T_sampler(self):
        """Return a ``UniformSampler`` over ``self.T``.

        Raises:
            NotImplementedError: if ``T_sampler`` is not ``'uniform'``.
        """
        if self.T_sampler != 'uniform':
            raise NotImplementedError()
        return UniformSampler(self.T)

    def make_diffusion_conf(self):
        """Main-model diffusion config using ``self.T`` steps."""
        return self._make_diffusion_conf(self.T)

    def make_eval_diffusion_conf(self):
        """Main-model diffusion config using ``self.T_eval`` steps."""
        return self._make_diffusion_conf(T=self.T_eval)

    def make_latent_diffusion_conf(self):
        """Latent-model diffusion config using ``self.T`` steps."""
        return self._make_latent_diffusion_conf(T=self.T)

    def make_latent_eval_diffusion_conf(self):
        """Latent-model diffusion config using ``self.latent_T_eval`` steps."""
        return self._make_latent_diffusion_conf(T=self.latent_T_eval)

    def make_dataset(self, taskdata, tasklabels):
        """Wrap ``taskdata`` / ``tasklabels`` in a ``SimpleSet``.

        ``intflag=True`` is always passed through.
        """
        return SimpleSet(taskdata, tasklabels, intflag=True)

    def make_model_conf(self):
        """Create the network config selected by ``model_name``.

        Side effects: stores the result in ``self.model_conf`` and sets
        ``self.model_type`` to match the chosen model.

        Returns:
            The newly created ``self.model_conf``.

        Raises:
            NotImplementedError: for an unsupported ``model_name`` (the name
                is passed as the exception argument) or an unsupported
                ``net_latent_net_type``.
        """
        # Keyword arguments that both network configs receive unchanged.
        shared = dict(
            attention_resolutions=self.net_attn,
            channel_mult=self.net_ch_mult,
            conv_resample=True,
            dropout=self.dropout,
            embed_channels=self.net_beatgans_embed_channels,
            image_size=self.img_size,
            in_channels=self.model_input_channels,
            model_channels=self.net_ch,
            num_classes=None,
            num_head_channels=-1,
            num_heads_upsample=-1,
            num_heads=self.net_beatgans_attn_head,
            num_res_blocks=self.net_num_res_blocks,
            num_input_res_blocks=self.net_num_input_res_blocks,
            out_channels=self.model_out_channels,
            resblock_updown=self.net_resblock_updown,
            use_checkpoint=self.net_beatgans_gradient_checkpoint,
            use_new_attention_order=False,
            resnet_two_cond=self.net_beatgans_resnet_two_cond,
            resnet_use_zero_module=self.net_beatgans_resnet_use_zero_module,
        )

        if self.model_name == ModelName.beatgans_ddpm:
            self.model_type = ModelType.ddpm
            # NOTE(review): this branch uses dims=2 while the autoencoder
            # branch below uses dims=1 - confirm the difference is intended.
            self.model_conf = BeatGANsUNetConfig(dims=2, **shared)
        elif self.model_name == ModelName.beatgans_autoenc:
            self.model_type = ModelType.autoencoder

            if self.net_latent_net_type == LatentNetType.none:
                latent_net_conf = None
            elif self.net_latent_net_type == LatentNetType.skip:
                latent_net_conf = MLPSkipNetConfig(
                    num_channels=self.style_ch,
                    skip_layers=self.net_latent_skip_layers,
                    num_hid_channels=self.net_latent_num_hid_channels,
                    num_layers=self.net_latent_layers,
                    num_time_emb_channels=self.net_latent_time_emb_channels,
                    activation=self.net_latent_activation,
                    use_norm=self.net_latent_use_norm,
                    condition_bias=self.net_latent_condition_bias,
                    dropout=self.net_latent_dropout,
                    last_act=self.net_latent_net_last_act,
                    num_time_layers=self.net_latent_num_time_layers,
                    time_last_act=self.net_latent_time_last_act,
                )
            else:
                raise NotImplementedError()

            self.model_conf = BeatGANsAutoencConfig(
                dims=1,
                enc_out_channels=self.style_ch,
                enc_pool=self.net_enc_pool,
                enc_num_res_block=self.net_enc_num_res_blocks,
                enc_channel_mult=self.net_enc_channel_mult,
                enc_grad_checkpoint=self.net_enc_grad_checkpoint,
                enc_attn_resolutions=self.net_enc_attn,
                latent_net_conf=latent_net_conf,
                resnet_cond_channels=self.net_beatgans_resnet_cond_channels,
                num_users=self.num_users,
                **shared,
            )
        else:
            raise NotImplementedError(self.model_name)

        return self.model_conf