303 lines
12 KiB
Python
303 lines
12 KiB
Python
|
from model.unet import ScaleAt
|
||
|
from model.latentnet import *
|
||
|
from diffusion.resample import UniformSampler
|
||
|
from diffusion.diffusion import space_timesteps
|
||
|
from typing import Tuple
|
||
|
|
||
|
from config_base import BaseConfig
|
||
|
from dataset import *
|
||
|
from diffusion import *
|
||
|
from diffusion.base import GenerativeType, LossType, ModelMeanType, ModelVarType, get_named_beta_schedule
|
||
|
from model import *
|
||
|
from choices import *
|
||
|
|
||
|
@dataclass()
class PretrainConfig(BaseConfig):
    """Reference to an existing checkpoint.

    Used by ``TrainConfig.pretrain`` and ``TrainConfig.continue_from``.
    """
    # Identifier for the referenced run.
    name: str
    # Location of the checkpoint.
    path: str
|
||
|
|
||
|
@dataclass
class TrainConfig(BaseConfig):
    """Training, diffusion and network configuration.

    Besides holding hyper-parameters, this class builds the derived objects
    used elsewhere: the diffusion configs (``make_*diffusion_conf``), the
    timestep sampler (``make_T_sampler``), the dataset (``make_dataset``)
    and the network config (``make_model_conf``).
    """
    seed: int = 0
    train_mode: TrainMode = TrainMode.diffusion
    train_cond0_prob: float = 0
    # Forwarded to the latent diffusion config only.
    train_pred_xstart_detach: bool = True
    train_interpolate_prob: float = 0
    train_interpolate_img: bool = False
    accum_batches: int = 1
    autoenc_mid_attn: bool = True
    batch_size: int = 512
    # Replaced by ``batch_size`` in __post_init__ when falsy.
    batch_size_eval: int = 4
    # Settings read by _make_diffusion_conf.
    beatgans_gen_type: GenerativeType = GenerativeType.ddim
    beatgans_loss_type: LossType = LossType.mse
    beatgans_model_mean_type: ModelMeanType = ModelMeanType.eps
    beatgans_model_var_type: ModelVarType = ModelVarType.fixed_large
    beatgans_rescale_timesteps: bool = False
    # Latent-diffusion settings (the gen/loss/mean/var/rescale ones are read
    # by _make_latent_diffusion_conf).
    latent_infer_path: str = None
    latent_znormalize: bool = False
    latent_gen_type: GenerativeType = GenerativeType.ddim
    latent_loss_type: LossType = LossType.mse
    latent_model_mean_type: ModelMeanType = ModelMeanType.eps
    latent_model_var_type: ModelVarType = ModelVarType.fixed_large
    latent_rescale_timesteps: bool = False
    # Step count used by make_latent_eval_diffusion_conf.
    latent_T_eval: int = 1_000
    latent_clip_sample: bool = False
    latent_beta_scheduler: str = 'linear'
    beta_scheduler: str = 'linear'
    data_name: str = ''
    # Replaced by ``data_name`` in __post_init__ when falsy.
    data_val_name: str = None
    # Only 'beatgans' is supported by _make_diffusion_conf.
    diffusion_type: str = None
    dropout: float = 0.1
    ema_decay: float = 0.9999
    eval_num_images: int = 5_000
    eval_every_samples: int = 200_000
    eval_ema_every_samples: int = 200_000
    fid_use_torch: bool = True
    fp16: bool = False
    grad_clip: float = 1
    img_size: int = 64
    lr: float = 0.0001
    optimizer: OptimizerType = OptimizerType.adam
    weight_decay: float = 0
    # Populated by make_model_conf.
    model_conf: ModelConfig = None
    model_name: ModelName = None
    # Overwritten by make_model_conf according to ``model_name``.
    model_type: ModelType = None
    # Network architecture settings (most are forwarded by make_model_conf).
    net_attn: Tuple[int, ...] = None
    net_beatgans_attn_head: int = 1
    net_beatgans_embed_channels: int = 128
    net_resblock_updown: bool = True
    net_enc_use_time: bool = False
    net_enc_pool: str = 'adaptivenonzero'
    net_beatgans_gradient_checkpoint: bool = False
    net_beatgans_resnet_two_cond: bool = False
    net_beatgans_resnet_use_zero_module: bool = True
    net_beatgans_resnet_scale_at: ScaleAt = ScaleAt.after_norm
    net_beatgans_resnet_cond_channels: int = None
    net_ch_mult: Tuple[int, ...] = None
    net_ch: int = 64
    net_enc_attn: Tuple[int, ...] = None
    net_enc_k: int = None
    net_enc_num_res_blocks: int = 2
    net_enc_channel_mult: Tuple[int, ...] = None
    net_enc_grad_checkpoint: bool = False
    net_autoenc_stochastic: bool = False
    # Latent-net (MLPSkipNetConfig) settings; used when
    # ``net_latent_net_type`` is ``LatentNetType.skip``.
    net_latent_activation: Activation = Activation.silu
    net_latent_channel_mult: Tuple[int, ...] = (1, 2, 4)
    net_latent_condition_bias: float = 0
    net_latent_dropout: float = 0
    net_latent_layers: int = None
    net_latent_net_last_act: Activation = Activation.none
    net_latent_net_type: LatentNetType = LatentNetType.none
    net_latent_num_hid_channels: int = 1024
    net_latent_num_time_layers: int = 2
    net_latent_skip_layers: Tuple[int, ...] = None
    net_latent_time_emb_channels: int = 64
    net_latent_use_norm: bool = False
    net_latent_time_last_act: bool = False
    net_num_res_blocks: int = 2
    net_num_input_res_blocks: int = None
    net_enc_num_cls: int = None
    num_workers: int = 4
    parallel: bool = False
    postfix: str = ''
    sample_size: int = 64
    sample_every_samples: int = 20_000
    save_every_samples: int = 100_000
    # Passed both as the encoder's ``enc_out_channels`` and as the latent
    # MLP's ``num_channels``.
    style_ch: int = 128
    T_eval: int = 1_000
    T_sampler: str = 'uniform'
    T: int = 1_000
    total_samples: int = 10_000_000
    warmup: int = 0
    pretrain: PretrainConfig = None
    continue_from: PretrainConfig = None
    eval_programs: Tuple[str, ...] = None
    eval_path: str = None
    base_dir: str = 'checkpoints'
    name: str = ''
    # Derived from ``base_dir`` and ``name`` in __post_init__ unless passed
    # explicitly.
    logdir: str = None
    num_users: int = 0

    def __post_init__(self):
        """Fill in fields whose defaults depend on other fields."""
        self.batch_size_eval = self.batch_size_eval or self.batch_size
        self.data_val_name = self.data_val_name or self.data_name
        if self.logdir is None:
            # Computed per instance. A class-level f-string would be evaluated
            # once at class-definition time and ignore the constructor's
            # ``base_dir``/``name``.
            # NOTE: mutating ``name``/``base_dir`` after construction does not
            # update ``logdir``.
            self.logdir = (f'{self.base_dir}/{self.name}'
                           if self.name else self.base_dir)

    def scale_up_gpus(self, num_gpus, num_nodes=1):
        """Multiply batch sizes and sample-count intervals by the device count.

        Mutates the config in place and returns ``self`` for chaining.
        """
        factor = num_gpus * num_nodes
        self.eval_ema_every_samples *= factor
        self.eval_every_samples *= factor
        self.sample_every_samples *= factor
        self.batch_size *= factor
        self.batch_size_eval *= factor
        return self

    @property
    def batch_size_effective(self):
        """``batch_size`` multiplied by the number of accumulated batches."""
        return self.batch_size * self.accum_batches

    @staticmethod
    def _section_counts(gen_type, T):
        """Return the ``section_counts`` argument of ``space_timesteps``.

        Args:
            gen_type: a ``GenerativeType``; only ddpm and ddim are supported.
            T: requested number of timesteps.

        Raises:
            NotImplementedError: for any other ``gen_type``.
        """
        if gen_type == GenerativeType.ddpm:
            return [T]
        if gen_type == GenerativeType.ddim:
            return f'ddim{T}'
        raise NotImplementedError()

    def _make_diffusion_conf(self, T=None):
        """Build the main diffusion config.

        Args:
            T: timestep count used for respacing; defaults to ``self.T``.

        Raises:
            NotImplementedError: if ``diffusion_type`` is not 'beatgans' or
                ``beatgans_gen_type`` is unsupported.
        """
        if self.diffusion_type != 'beatgans':
            raise NotImplementedError()
        if T is None:
            T = self.T
        section_counts = self._section_counts(self.beatgans_gen_type, T)
        return SpacedDiffusionBeatGansConfig(
            gen_type=self.beatgans_gen_type,
            model_type=self.model_type,
            # The beta schedule is always built for ``self.T`` steps; ``T``
            # only affects the respacing via ``section_counts``.
            betas=get_named_beta_schedule(self.beta_scheduler, self.T),
            model_mean_type=self.beatgans_model_mean_type,
            model_var_type=self.beatgans_model_var_type,
            loss_type=self.beatgans_loss_type,
            rescale_timesteps=self.beatgans_rescale_timesteps,
            use_timesteps=space_timesteps(num_timesteps=self.T,
                                          section_counts=section_counts),
            fp16=self.fp16,
        )

    def _make_latent_diffusion_conf(self, T=None):
        """Build the latent diffusion config (always with ``ModelType.ddpm``).

        Args:
            T: timestep count used for respacing; defaults to ``self.T``.

        Raises:
            NotImplementedError: if ``latent_gen_type`` is unsupported.
        """
        if T is None:
            T = self.T
        section_counts = self._section_counts(self.latent_gen_type, T)
        return SpacedDiffusionBeatGansConfig(
            train_pred_xstart_detach=self.train_pred_xstart_detach,
            gen_type=self.latent_gen_type,
            model_type=ModelType.ddpm,
            betas=get_named_beta_schedule(self.latent_beta_scheduler, self.T),
            model_mean_type=self.latent_model_mean_type,
            model_var_type=self.latent_model_var_type,
            loss_type=self.latent_loss_type,
            rescale_timesteps=self.latent_rescale_timesteps,
            use_timesteps=space_timesteps(num_timesteps=self.T,
                                          section_counts=section_counts),
            fp16=self.fp16,
        )

    @property
    def model_out_channels(self):
        """Output channel count of the network (hard-coded to 2)."""
        return 2

    @property
    def model_input_channels(self):
        """Input channel count of the network (hard-coded to 2)."""
        return 2

    def make_T_sampler(self):
        """Return the timestep sampler named by ``T_sampler`` over ``self.T`` steps.

        Raises:
            NotImplementedError: for any sampler other than 'uniform'.
        """
        if self.T_sampler == 'uniform':
            return UniformSampler(self.T)
        raise NotImplementedError()

    def make_diffusion_conf(self):
        """Main diffusion config respaced to ``self.T`` steps."""
        return self._make_diffusion_conf(self.T)

    def make_eval_diffusion_conf(self):
        """Main diffusion config respaced to ``self.T_eval`` steps."""
        return self._make_diffusion_conf(T=self.T_eval)

    def make_latent_diffusion_conf(self):
        """Latent diffusion config respaced to ``self.T`` steps."""
        return self._make_latent_diffusion_conf(T=self.T)

    def make_latent_eval_diffusion_conf(self):
        """Latent diffusion config respaced to ``self.latent_T_eval`` steps."""
        return self._make_latent_diffusion_conf(T=self.latent_T_eval)

    def make_dataset(self, taskdata, tasklabels):
        """Wrap ``taskdata``/``tasklabels`` in a ``SimpleSet`` (with ``intflag=True``)."""
        return SimpleSet(taskdata, tasklabels, intflag=True)

    def _make_latent_net_conf(self):
        """Return the latent-net config, or None when ``net_latent_net_type`` is none.

        Raises:
            NotImplementedError: for latent net types other than none/skip.
        """
        if self.net_latent_net_type == LatentNetType.none:
            return None
        if self.net_latent_net_type == LatentNetType.skip:
            return MLPSkipNetConfig(
                num_channels=self.style_ch,
                skip_layers=self.net_latent_skip_layers,
                num_hid_channels=self.net_latent_num_hid_channels,
                num_layers=self.net_latent_layers,
                num_time_emb_channels=self.net_latent_time_emb_channels,
                activation=self.net_latent_activation,
                use_norm=self.net_latent_use_norm,
                condition_bias=self.net_latent_condition_bias,
                dropout=self.net_latent_dropout,
                last_act=self.net_latent_net_last_act,
                num_time_layers=self.net_latent_num_time_layers,
                time_last_act=self.net_latent_time_last_act,
            )
        raise NotImplementedError()

    def make_model_conf(self):
        """Build the network config for ``model_name``.

        Sets ``self.model_type`` and ``self.model_conf`` as side effects and
        returns ``self.model_conf``.

        Raises:
            NotImplementedError: for an unsupported ``model_name`` or
                ``net_latent_net_type``.
        """
        if self.model_name == ModelName.beatgans_ddpm:
            self.model_type = ModelType.ddpm
            # NOTE(review): this branch builds a 2-D network (dims=2) while
            # the autoencoder branch below uses dims=1 — confirm intended.
            self.model_conf = BeatGANsUNetConfig(
                attention_resolutions=self.net_attn,
                channel_mult=self.net_ch_mult,
                conv_resample=True,
                dims=2,
                dropout=self.dropout,
                embed_channels=self.net_beatgans_embed_channels,
                image_size=self.img_size,
                in_channels=self.model_input_channels,
                model_channels=self.net_ch,
                num_classes=None,
                num_head_channels=-1,
                num_heads_upsample=-1,
                num_heads=self.net_beatgans_attn_head,
                num_res_blocks=self.net_num_res_blocks,
                num_input_res_blocks=self.net_num_input_res_blocks,
                out_channels=self.model_out_channels,
                resblock_updown=self.net_resblock_updown,
                use_checkpoint=self.net_beatgans_gradient_checkpoint,
                use_new_attention_order=False,
                resnet_two_cond=self.net_beatgans_resnet_two_cond,
                resnet_use_zero_module=self.net_beatgans_resnet_use_zero_module,
            )
        elif self.model_name == ModelName.beatgans_autoenc:
            self.model_type = ModelType.autoencoder
            latent_net_conf = self._make_latent_net_conf()
            self.model_conf = BeatGANsAutoencConfig(
                attention_resolutions=self.net_attn,
                channel_mult=self.net_ch_mult,
                conv_resample=True,
                dims=1,
                dropout=self.dropout,
                embed_channels=self.net_beatgans_embed_channels,
                enc_out_channels=self.style_ch,
                enc_pool=self.net_enc_pool,
                enc_num_res_block=self.net_enc_num_res_blocks,
                enc_channel_mult=self.net_enc_channel_mult,
                enc_grad_checkpoint=self.net_enc_grad_checkpoint,
                enc_attn_resolutions=self.net_enc_attn,
                image_size=self.img_size,
                in_channels=self.model_input_channels,
                model_channels=self.net_ch,
                num_classes=None,
                num_head_channels=-1,
                num_heads_upsample=-1,
                num_heads=self.net_beatgans_attn_head,
                num_res_blocks=self.net_num_res_blocks,
                num_input_res_blocks=self.net_num_input_res_blocks,
                out_channels=self.model_out_channels,
                resblock_updown=self.net_resblock_updown,
                use_checkpoint=self.net_beatgans_gradient_checkpoint,
                use_new_attention_order=False,
                resnet_two_cond=self.net_beatgans_resnet_two_cond,
                resnet_use_zero_module=self.net_beatgans_resnet_use_zero_module,
                latent_net_conf=latent_net_conf,
                resnet_cond_channels=self.net_beatgans_resnet_cond_channels,
                num_users=self.num_users,
            )
        else:
            raise NotImplementedError(self.model_name)

        return self.model_conf
|