commit b102c2a534aaace8336b0a23f2c572d3d3e9ec64 Author: Guanhua Zhang Date: Tue Oct 8 14:18:47 2024 +0200 init diff --git a/choices.py b/choices.py new file mode 100755 index 0000000..50dd4b1 --- /dev/null +++ b/choices.py @@ -0,0 +1,127 @@ +from enum import Enum +from torch import nn + + +class TrainMode(Enum): + manipulate = 'manipulate' + diffusion = 'diffusion' + latent_diffusion = 'latentdiffusion' + + def is_manipulate(self): + return self in [ + TrainMode.manipulate, + ] + + def is_diffusion(self): + return self in [ + TrainMode.diffusion, + TrainMode.latent_diffusion, + ] + + def is_autoenc(self): + return self in [ + TrainMode.diffusion, + ] + + def is_latent_diffusion(self): + return self in [ + TrainMode.latent_diffusion, + ] + + def use_latent_net(self): + return self.is_latent_diffusion() + + def require_dataset_infer(self): + """ + whether training in this mode requires the latent variables to be available? + """ + return self in [ + TrainMode.latent_diffusion, + TrainMode.manipulate, + ] + +class ModelType(Enum): + """ + Kinds of the backbone models + """ + + ddpm = 'ddpm' + autoencoder = 'autoencoder' + + def has_autoenc(self): + return self in [ + ModelType.autoencoder, + ] + + def can_sample(self): + return self in [ModelType.ddpm] + + +class ModelName(Enum): + """ + List of all supported model classes + """ + + beatgans_ddpm = 'beatgans_ddpm' + beatgans_autoenc = 'beatgans_autoenc' + + +class ModelMeanType(Enum): + """ + Which type of output the model predicts. + """ + + eps = 'eps' + + +class ModelVarType(Enum): + """ + What is used as the model's output variance. + + The LEARNED_RANGE option has been added to allow the model to predict + values between FIXED_SMALL and FIXED_LARGE, making its job easier. + """ + + fixed_small = 'fixed_small' + fixed_large = 'fixed_large' + + +class LossType(Enum): + mse = 'mse' + l1 = 'l1' + + +class GenerativeType(Enum): + """ + How's a sample generated + """ + + ddpm = 'ddpm' + ddim = 'ddim' + + +class OptimizerType(Enum): + adam = 'adam' + adamw = 'adamw' + + +class Activation(Enum): + none = 'none' + relu = 'relu' + lrelu = 'lrelu' + silu = 'silu' + tanh = 'tanh' + + def get_act(self): + if self == Activation.none: + return nn.Identity() + elif self == Activation.relu: + return nn.ReLU() + elif self == Activation.lrelu: + return nn.LeakyReLU(negative_slope=0.2) + elif self == Activation.silu: + return nn.SiLU() + elif self == Activation.tanh: + return nn.Tanh() + else: + raise NotImplementedError() \ No newline at end of file diff --git a/config.py b/config.py new file mode 100644 index 0000000..a27c743 --- /dev/null +++ b/config.py @@ -0,0 +1,303 @@ +from model.unet import ScaleAt +from model.latentnet import * +from diffusion.resample import UniformSampler +from diffusion.diffusion import space_timesteps +from typing import Tuple + +from config_base import BaseConfig +from dataset import * +from diffusion import * +from diffusion.base import GenerativeType, LossType, ModelMeanType, ModelVarType, get_named_beta_schedule +from model import * +from choices import * + +@dataclass +class PretrainConfig(BaseConfig): + name: str + path: str + +@dataclass +class TrainConfig(BaseConfig): + seed: int = 0 + train_mode: TrainMode = TrainMode.diffusion + train_cond0_prob: float = 0 + train_pred_xstart_detach: bool = True + train_interpolate_prob: float = 0 + train_interpolate_img: bool = False + accum_batches: int = 1 + autoenc_mid_attn: bool = True + batch_size: int = 512 + batch_size_eval: int = 4 + beatgans_gen_type: GenerativeType = 
GenerativeType.ddim + beatgans_loss_type: LossType = LossType.mse + beatgans_model_mean_type: ModelMeanType = ModelMeanType.eps + beatgans_model_var_type: ModelVarType = ModelVarType.fixed_large + beatgans_rescale_timesteps: bool = False + latent_infer_path: str = None + latent_znormalize: bool = False + latent_gen_type: GenerativeType = GenerativeType.ddim + latent_loss_type: LossType = LossType.mse + latent_model_mean_type: ModelMeanType = ModelMeanType.eps + latent_model_var_type: ModelVarType = ModelVarType.fixed_large + latent_rescale_timesteps: bool = False + latent_T_eval: int = 1_000 + latent_clip_sample: bool = False + latent_beta_scheduler: str = 'linear' + beta_scheduler: str = 'linear' + data_name: str = '' + data_val_name: str = None + diffusion_type: str = None + dropout: float = 0.1 + ema_decay: float = 0.9999 + eval_num_images: int = 5_000 + eval_every_samples: int = 200_000 + eval_ema_every_samples: int = 200_000 + fid_use_torch: bool = True + fp16: bool = False + grad_clip: float = 1 + img_size: int = 64 + lr: float = 0.0001 + optimizer: OptimizerType = OptimizerType.adam + weight_decay: float = 0 + model_conf: ModelConfig = None + model_name: ModelName = None + model_type: ModelType = None + net_attn: Tuple[int] = None + net_beatgans_attn_head: int = 1 + net_beatgans_embed_channels: int = 128 + net_resblock_updown: bool = True + net_enc_use_time: bool = False + net_enc_pool: str = 'adaptivenonzero' + net_beatgans_gradient_checkpoint: bool = False + net_beatgans_resnet_two_cond: bool = False + net_beatgans_resnet_use_zero_module: bool = True + net_beatgans_resnet_scale_at: ScaleAt = ScaleAt.after_norm + net_beatgans_resnet_cond_channels: int = None + net_ch_mult: Tuple[int] = None + net_ch: int = 64 + net_enc_attn: Tuple[int] = None + net_enc_k: int = None + net_enc_num_res_blocks: int = 2 + net_enc_channel_mult: Tuple[int] = None + net_enc_grad_checkpoint: bool = False + net_autoenc_stochastic: bool = False + net_latent_activation: Activation = Activation.silu + net_latent_channel_mult: Tuple[int] = (1, 2, 4) + net_latent_condition_bias: float = 0 + net_latent_dropout: float = 0 + net_latent_layers: int = None + net_latent_net_last_act: Activation = Activation.none + net_latent_net_type: LatentNetType = LatentNetType.none + net_latent_num_hid_channels: int = 1024 + net_latent_num_time_layers: int = 2 + net_latent_skip_layers: Tuple[int] = None + net_latent_time_emb_channels: int = 64 + net_latent_use_norm: bool = False + net_latent_time_last_act: bool = False + net_num_res_blocks: int = 2 + net_num_input_res_blocks: int = None + net_enc_num_cls: int = None + num_workers: int = 4 + parallel: bool = False + postfix: str = '' + sample_size: int = 64 + sample_every_samples: int = 20_000 + save_every_samples: int = 100_000 + style_ch: int = 128 + T_eval: int = 1_000 + T_sampler: str = 'uniform' + T: int = 1_000 + total_samples: int = 10_000_000 + warmup: int = 0 + pretrain: PretrainConfig = None + continue_from: PretrainConfig = None + eval_programs: Tuple[str] = None + eval_path: str = None + base_dir: str = 'checkpoints' + name: str = '' + logdir: str = f'{base_dir}{name}' + num_users: int = 0 + + def __post_init__(self): + self.batch_size_eval = self.batch_size_eval or self.batch_size + self.data_val_name = self.data_val_name or self.data_name + + def scale_up_gpus(self, num_gpus, num_nodes=1): + self.eval_ema_every_samples *= num_gpus * num_nodes + self.eval_every_samples *= num_gpus * num_nodes + self.sample_every_samples *= num_gpus * num_nodes + self.batch_size *= 
num_gpus * num_nodes + self.batch_size_eval *= num_gpus * num_nodes + return self + + @property + def batch_size_effective(self): + return self.batch_size * self.accum_batches + + def _make_diffusion_conf(self, T=None): + if self.diffusion_type == 'beatgans': + if self.beatgans_gen_type == GenerativeType.ddpm: + section_counts = [T] + elif self.beatgans_gen_type == GenerativeType.ddim: + section_counts = f'ddim{T}' + else: + raise NotImplementedError() + + return SpacedDiffusionBeatGansConfig( + gen_type=self.beatgans_gen_type, + model_type=self.model_type, + betas=get_named_beta_schedule(self.beta_scheduler, self.T), + model_mean_type=self.beatgans_model_mean_type, + model_var_type=self.beatgans_model_var_type, + loss_type=self.beatgans_loss_type, + rescale_timesteps=self.beatgans_rescale_timesteps, + use_timesteps=space_timesteps(num_timesteps=self.T, + section_counts=section_counts), + fp16=self.fp16, + ) + else: + raise NotImplementedError() + + def _make_latent_diffusion_conf(self, T=None): + if self.latent_gen_type == GenerativeType.ddpm: + section_counts = [T] + elif self.latent_gen_type == GenerativeType.ddim: + section_counts = f'ddim{T}' + else: + raise NotImplementedError() + + return SpacedDiffusionBeatGansConfig( + train_pred_xstart_detach=self.train_pred_xstart_detach, + gen_type=self.latent_gen_type, + model_type=ModelType.ddpm, + betas=get_named_beta_schedule(self.latent_beta_scheduler, self.T), + model_mean_type=self.latent_model_mean_type, + model_var_type=self.latent_model_var_type, + loss_type=self.latent_loss_type, + rescale_timesteps=self.latent_rescale_timesteps, + use_timesteps=space_timesteps(num_timesteps=self.T, + section_counts=section_counts), + fp16=self.fp16, + ) + + @property + def model_out_channels(self): + return 2 + @property + def model_input_channels(self): + return 2 + + def make_T_sampler(self): + if self.T_sampler == 'uniform': + return UniformSampler(self.T) + else: + raise NotImplementedError() + + def make_diffusion_conf(self): + return self._make_diffusion_conf(self.T) + + def make_eval_diffusion_conf(self): + return self._make_diffusion_conf(T=self.T_eval) + + def make_latent_diffusion_conf(self): + return self._make_latent_diffusion_conf(T=self.T) + + def make_latent_eval_diffusion_conf(self): + return self._make_latent_diffusion_conf(T=self.latent_T_eval) + + def make_dataset(self, taskdata, tasklabels): + return SimpleSet(taskdata, tasklabels, intflag=True) + + def make_model_conf(self): + if self.model_name == ModelName.beatgans_ddpm: + self.model_type = ModelType.ddpm + self.model_conf = BeatGANsUNetConfig( + attention_resolutions=self.net_attn, + channel_mult=self.net_ch_mult, + conv_resample=True, + dims=2, + dropout=self.dropout, + embed_channels=self.net_beatgans_embed_channels, + image_size=self.img_size, + in_channels=self.model_input_channels, + model_channels=self.net_ch, + num_classes=None, + num_head_channels=-1, + num_heads_upsample=-1, + num_heads=self.net_beatgans_attn_head, + num_res_blocks=self.net_num_res_blocks, + num_input_res_blocks=self.net_num_input_res_blocks, + out_channels=self.model_out_channels, + resblock_updown=self.net_resblock_updown, + use_checkpoint=self.net_beatgans_gradient_checkpoint, + use_new_attention_order=False, + resnet_two_cond=self.net_beatgans_resnet_two_cond, + resnet_use_zero_module=self. 
+ net_beatgans_resnet_use_zero_module, + ) + elif self.model_name in [ + ModelName.beatgans_autoenc, + ]: + cls = BeatGANsAutoencConfig + if self.model_name == ModelName.beatgans_autoenc: + self.model_type = ModelType.autoencoder + else: + raise NotImplementedError() + + if self.net_latent_net_type == LatentNetType.none: + latent_net_conf = None + elif self.net_latent_net_type == LatentNetType.skip: + latent_net_conf = MLPSkipNetConfig( + num_channels=self.style_ch, + skip_layers=self.net_latent_skip_layers, + num_hid_channels=self.net_latent_num_hid_channels, + num_layers=self.net_latent_layers, + num_time_emb_channels=self.net_latent_time_emb_channels, + activation=self.net_latent_activation, + use_norm=self.net_latent_use_norm, + condition_bias=self.net_latent_condition_bias, + dropout=self.net_latent_dropout, + last_act=self.net_latent_net_last_act, + num_time_layers=self.net_latent_num_time_layers, + time_last_act=self.net_latent_time_last_act, + ) + else: + raise NotImplementedError() + + self.model_conf = cls( + attention_resolutions=self.net_attn, + channel_mult=self.net_ch_mult, + conv_resample=True, + dims=1, + dropout=self.dropout, + embed_channels=self.net_beatgans_embed_channels, + enc_out_channels=self.style_ch, + enc_pool=self.net_enc_pool, + enc_num_res_block=self.net_enc_num_res_blocks, + enc_channel_mult=self.net_enc_channel_mult, + enc_grad_checkpoint=self.net_enc_grad_checkpoint, + enc_attn_resolutions=self.net_enc_attn, + image_size=self.img_size, + in_channels=self.model_input_channels, + model_channels=self.net_ch, + num_classes=None, + num_head_channels=-1, + num_heads_upsample=-1, + num_heads=self.net_beatgans_attn_head, + num_res_blocks=self.net_num_res_blocks, + num_input_res_blocks=self.net_num_input_res_blocks, + out_channels=self.model_out_channels, + resblock_updown=self.net_resblock_updown, + use_checkpoint=self.net_beatgans_gradient_checkpoint, + use_new_attention_order=False, + resnet_two_cond=self.net_beatgans_resnet_two_cond, + resnet_use_zero_module=self. 
+ net_beatgans_resnet_use_zero_module, + latent_net_conf=latent_net_conf, + resnet_cond_channels=self.net_beatgans_resnet_cond_channels, + num_users = self.num_users, + ) + else: + raise NotImplementedError(self.model_name) + + return self.model_conf \ No newline at end of file diff --git a/config_base.py b/config_base.py new file mode 100755 index 0000000..a071af8 --- /dev/null +++ b/config_base.py @@ -0,0 +1,71 @@ +import json +import os +from copy import deepcopy +from dataclasses import dataclass + + +@dataclass +class BaseConfig: + def clone(self): + return deepcopy(self) + + def inherit(self, another): + """inherit common keys from a given config""" + common_keys = set(self.__dict__.keys()) & set(another.__dict__.keys()) + for k in common_keys: + setattr(self, k, getattr(another, k)) + + def propagate(self): + """push down the configuration to all members""" + for k, v in self.__dict__.items(): + if isinstance(v, BaseConfig): + v.inherit(self) + v.propagate() + + def save(self, save_path): + """save config to json file""" + dirname = os.path.dirname(save_path) + if not os.path.exists(dirname): + os.makedirs(dirname) + conf = self.as_dict_jsonable() + with open(save_path, 'w') as f: + json.dump(conf, f) + + def load(self, load_path): + """load json config""" + with open(load_path) as f: + conf = json.load(f) + self.from_dict(conf) + + def from_dict(self, dict, strict=False): + for k, v in dict.items(): + if not hasattr(self, k): + if strict: + raise ValueError(f"loading extra '{k}'") + else: + print(f"loading extra '{k}'") + continue + if isinstance(self.__dict__[k], BaseConfig): + self.__dict__[k].from_dict(v) + else: + self.__dict__[k] = v + + def as_dict_jsonable(self): + conf = {} + for k, v in self.__dict__.items(): + if isinstance(v, BaseConfig): + conf[k] = v.as_dict_jsonable() + else: + if jsonable(v): + conf[k] = v + else: + pass + return conf + + +def jsonable(x): + try: + json.dumps(x) + return True + except TypeError: + return False diff --git a/dataset.py b/dataset.py new file mode 100755 index 0000000..328927d --- /dev/null +++ b/dataset.py @@ -0,0 +1,21 @@ +from torch.utils.data import Dataset +import torch +import pandas as pd + +def loadDataset(conf): + eval('taskdata, tasklabels = load%s(conf)'%(conf.pretrainDataset)) # plug in the function to load your own dataset + tasklabels = pd.DataFrame(tasklabels, columns=['user']) + print('taskdata.shape:', taskdata.shape) # (N, 2, window_length*sample_freq) + return taskdata, tasklabels + +class SimpleSet(Dataset): + def __init__(self, data, labels, intflag=True): + self.data = torch.tensor(data, dtype=torch.float) + if intflag: + self.label = torch.tensor(labels, dtype=torch.long) + else: + self.label = torch.tensor(labels, dtype=torch.float) + def __len__(self): + return len(self.data) + def __getitem__(self, index): + return self.data[index], self.label[index] \ No newline at end of file diff --git a/diffusion/__init__.py b/diffusion/__init__.py new file mode 100644 index 0000000..4e0838c --- /dev/null +++ b/diffusion/__init__.py @@ -0,0 +1,6 @@ +from typing import Union + +from .diffusion import SpacedDiffusionBeatGans, SpacedDiffusionBeatGansConfig + +Sampler = Union[SpacedDiffusionBeatGans] +SamplerConfig = Union[SpacedDiffusionBeatGansConfig] diff --git a/diffusion/base.py b/diffusion/base.py new file mode 100755 index 0000000..6275450 --- /dev/null +++ b/diffusion/base.py @@ -0,0 +1,1086 @@ +""" +This code started out as a PyTorch port of Ho et al's diffusion models: 
+https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py + +Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules. +""" + +from config_base import BaseConfig +import math + +import numpy as np +import torch as th +from model import * +from model.nn import mean_flat +from typing import NamedTuple, Tuple +from choices import * +from torch.cuda.amp import autocast +import torch.nn.functional as F + +from dataclasses import dataclass + +from model.MI import * + +@dataclass +class GaussianDiffusionBeatGansConfig(BaseConfig): + gen_type: GenerativeType + betas: Tuple[float] + model_type: ModelType + model_mean_type: ModelMeanType + model_var_type: ModelVarType + loss_type: LossType + rescale_timesteps: bool + fp16: bool + train_pred_xstart_detach: bool = True + + def make_sampler(self): + return GaussianDiffusionBeatGans(self) + + +class GaussianDiffusionBeatGans: + """ + Utilities for training and sampling diffusion models. + + Ported directly from here, and then adapted over time to further experimentation. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 + + :param betas: a 1-D numpy array of betas for each diffusion timestep, + starting at T and going to 1. + :param model_mean_type: a ModelMeanType determining what the model outputs. + :param model_var_type: a ModelVarType determining how variance is output. + :param loss_type: a LossType determining the loss function to use. + :param rescale_timesteps: if True, pass floating point timesteps into the + model so that they are always scaled like in the + original paper (0 to 1000). + """ + def __init__(self, conf: GaussianDiffusionBeatGansConfig): + self.conf = conf + self.model_mean_type = conf.model_mean_type + self.model_var_type = conf.model_var_type + self.loss_type = conf.loss_type + self.rescale_timesteps = conf.rescale_timesteps + + betas = np.array(conf.betas, dtype=np.float64) + self.betas = betas + assert len(betas.shape) == 1, "betas must be 1-D" + assert (betas > 0).all() and (betas <= 1).all() + + self.num_timesteps = int(betas.shape[0]) + + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) + assert self.alphas_cumprod_prev.shape == (self.num_timesteps, ) + + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - + 1) + + self.posterior_variance = (betas * (1.0 - self.alphas_cumprod_prev) / + (1.0 - self.alphas_cumprod)) + self.posterior_log_variance_clipped = np.log( + np.append(self.posterior_variance[1], self.posterior_variance[1:])) + self.posterior_mean_coef1 = (betas * + np.sqrt(self.alphas_cumprod_prev) / + (1.0 - self.alphas_cumprod)) + self.posterior_mean_coef2 = ((1.0 - self.alphas_cumprod_prev) * + np.sqrt(alphas) / + (1.0 - self.alphas_cumprod)) + + def training_losses(self, + model: Model, + x_start: th.Tensor, + t: th.Tensor, + model_kwargs=None, + noise: th.Tensor = None, + user_label=None, + lossbetas=None): + """ + Compute training losses for a single timestep. 
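+ In this adaptation, user_label carries the per-sample user indices for the
+ user / non-user prediction heads, and lossbetas is a dict of weights for the
+ individual loss terms ('recon', 'noise', 'user', 'nonuser', and optionally
+ 'mi'); terms whose weight is 0 are skipped. Note that the noise argument is
+ currently overwritten with freshly sampled Gaussian noise inside the function.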
+ + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param t: a batch of timestep indices. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param noise: if specified, the specific Gaussian noise to try to remove. + :return: a dict with the key "loss" containing a tensor of shape [N]. + Some mean or variance settings may also have other keys. + """ + if model_kwargs is None: + model_kwargs = {} + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start, t, noise=noise) + + terms = {'x_t': x_t} + + if self.loss_type in [ + LossType.mse, + LossType.l1, + ]: + with autocast(self.conf.fp16): + model_forward = model.forward(x=x_t, + t=self._scale_timesteps(t), + x_start=x_start, + **model_kwargs) + model_output = model_forward.pred + + _model_output = model_output + if self.conf.train_pred_xstart_detach: + _model_output = _model_output.detach() + p_mean_var = self.p_mean_variance( + model=DummyModel(pred=model_output), + x=x_t, + t=t, + clip_denoised=False) + terms['pred_xstart'] = p_mean_var['pred_xstart'] + target_types = { + ModelMeanType.eps: noise, + } + target = target_types[self.model_mean_type] + assert model_output.shape == target.shape == x_start.shape + + if self.loss_type == LossType.mse: + if self.model_mean_type == ModelMeanType.eps: + assert (x_start >= 0).all() and (x_start <= 1).all() + assert (terms['pred_xstart'] >= 0).all() and (terms['pred_xstart'] <= 1).all() + assert terms['pred_xstart'].requires_grad + + terms["mse"] = th.zeros((x_start.shape[0]), device=x_start.device) + if lossbetas['recon']!=0: + input_mse = mean_flat((x_start - terms['pred_xstart'])**2) + assert input_mse.requires_grad and input_mse.grad_fn is not None + terms["mse"]+=input_mse*lossbetas['recon'] + if lossbetas['noise']!=0: + noise_mse = mean_flat((model_output - target)**2) + assert noise_mse.requires_grad and noise_mse.grad_fn is not None + terms["mse"]+=noise_mse*lossbetas['noise'] + if lossbetas['user']!=0: + user_cross = F.cross_entropy(model_forward.user_pred, user_label, reduction='none') + assert user_cross.requires_grad and user_cross.grad_fn is not None + terms["mse"]+=user_cross*lossbetas['user'] + if lossbetas['nonuser']!=0: + non_user_cross = F.cross_entropy(model_forward.non_user_pred, user_label, reduction='none') + assert non_user_cross.requires_grad and non_user_cross.grad_fn is not None + terms["mse"]+=non_user_cross*lossbetas['nonuser'] + if 'mi' in lossbetas.keys() and lossbetas['mi']!=0: + user_emb = model_forward.cond[:, :model_forward.cond.shape[1]//2] + non_user_emb = model_forward.cond[:, model_forward.cond.shape[1]//2:] + minval = th.min(model_forward.cond) + maxval = th.max(model_forward.cond) + mutual_info = MI_pytorch(bins=20, min=minval, max=maxval, sigma=100, reduction='individual').to(user_emb.device) + mi = mutual_info(user_emb, non_user_emb) + assert mi.requires_grad and mi.grad_fn is not None + terms["mse"]+=mi*lossbetas['mi'] + else: + raise NotImplementedError() + elif self.loss_type == LossType.l1: + terms["mse"] = mean_flat((target - model_output).abs()) + else: + raise NotImplementedError() + if "vb" in terms: + terms["loss"] = terms["mse"] + terms["vb"] + else: + terms["loss"] = terms["mse"] + else: + raise NotImplementedError(self.loss_type) + + return terms + + def sample(self, + model: Model, + shape=None, + noise=None, + cond=None, + x_start=None, + clip_denoised=True, + model_kwargs=None, + progress=False): + """ + 
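+ Draw a batch of samples with the configured generative type: ancestral
+ sampling (p_sample_loop) for GenerativeType.ddpm, or deterministic DDIM
+ sampling (ddim_sample_loop) for GenerativeType.ddim. For autoencoding
+ models, x_start and cond are forwarded to the model as conditioning.
+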
Args: + x_start: given for the autoencoder + """ + if model_kwargs is None: + model_kwargs = {} + if self.conf.model_type.has_autoenc(): + model_kwargs['x_start'] = x_start + model_kwargs['cond'] = cond + + if self.conf.gen_type == GenerativeType.ddpm: + return self.p_sample_loop(model, + shape=shape, + noise=noise, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs, + progress=progress) + elif self.conf.gen_type == GenerativeType.ddim: + return self.ddim_sample_loop(model, + shape=shape, + noise=noise, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs, + progress=progress) + else: + raise NotImplementedError() + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * + x_start) + variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, + x_start.shape) + log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, + t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data for a given number of diffusion steps. + + In other words, sample from q(x_t | x_0). + + :param x_start: the initial data batch. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :param noise: if specified, the split-out normal noise. + :return: A noisy version of x_start. + """ + if noise is None: + noise = th.randn_like(x_start) + assert noise.shape == x_start.shape + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * + x_start + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, + t, x_start.shape) * noise) + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior: + + q(x_{t-1} | x_t, x_0) + + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * + x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * + x_t) + posterior_variance = _extract_into_tensor(self.posterior_variance, t, + x_t.shape) + posterior_log_variance_clipped = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape) + assert (posterior_mean.shape[0] == posterior_variance.shape[0] == + posterior_log_variance_clipped.shape[0] == x_start.shape[0]) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, + model: Model, + x, + t, + clip_denoised=True, + denoised_fn=None, + model_kwargs=None): + """ + Apply the model to get p(x_{t-1} | x_t), as well as a prediction of + the initial x, x_0. + + :param model: the model, which takes a signal and a batch of timesteps + as input. + :param x: the [N x C x ...] tensor at time t. + :param t: a 1-D Tensor of timesteps. + :param clip_denoised: if True, clip the denoised signal into [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. Applies before + clip_denoised. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict with the following keys: + - 'mean': the model mean output. + - 'variance': the model variance output. 
+ - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. + """ + if model_kwargs is None: + model_kwargs = {} + + B, C = x.shape[:2] + assert t.shape == (B, ) + with autocast(self.conf.fp16): + model_forward = model.forward(x=x, + t=self._scale_timesteps(t), + **model_kwargs) + model_output = model_forward.pred + + if self.model_var_type in [ + ModelVarType.fixed_large, ModelVarType.fixed_small + ]: + model_variance, model_log_variance = { + ModelVarType.fixed_large: ( + np.append(self.posterior_variance[1], self.betas[1:]), + np.log( + np.append(self.posterior_variance[1], self.betas[1:])), + ), + ModelVarType.fixed_small: ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + model_variance = _extract_into_tensor(model_variance, t, x.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, + x.shape) + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clamp(-1, 1) + return F.sigmoid(x) + + if self.model_mean_type in [ + ModelMeanType.eps, + ]: + if self.model_mean_type == ModelMeanType.eps: + pred_xstart = process_xstart( + self._predict_xstart_from_eps(x_t=x, t=t, + eps=model_output)) + else: + raise NotImplementedError() + model_mean, _, _ = self.q_posterior_mean_variance( + x_start=pred_xstart, x_t=x, t=t) + else: + raise NotImplementedError(self.model_mean_type) + + assert (model_mean.shape == model_log_variance.shape == + pred_xstart.shape == x.shape) + + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + 'model_forward': model_forward, + } + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return (_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, + x_t.shape) * x_t - + _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, + x_t.shape) * eps) + + def _predict_xstart_from_xprev(self, x_t, t, xprev): + assert x_t.shape == xprev.shape + return ( + _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) + * xprev - _extract_into_tensor( + self.posterior_mean_coef2 / self.posterior_mean_coef1, t, + x_t.shape) * x_t) + + def _predict_xstart_from_scaled_xstart(self, t, scaled_xstart): + return scaled_xstart * _extract_into_tensor( + self.sqrt_recip_alphas_cumprod, t, scaled_xstart.shape) + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return (_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, + x_t.shape) * x_t - + pred_xstart) / _extract_into_tensor( + self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _predict_eps_from_scaled_xstart(self, x_t, t, scaled_xstart): + """ + Args: + scaled_xstart: is supposed to be sqrt(alphacum) * x_0 + """ + return (x_t - scaled_xstart) / _extract_into_tensor( + self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) + + def _scale_timesteps(self, t): + if self.rescale_timesteps: + return t.float() * (1000.0 / self.num_timesteps) + return t + + def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute the mean for the previous step, given a function cond_fn that + computes the gradient of a conditional log probability with respect to + x. In particular, cond_fn computes grad(log(p(y|x))), and we want to + condition on y. + + This uses the conditioning strategy from Sohl-Dickstein et al. (2015). 
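+
+ Concretely, the returned mean is p_mean_var["mean"] + p_mean_var["variance"]
+ * cond_fn(x, t), i.e. the posterior mean shifted along the gradient of the
+ conditional log-probability.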
+ """ + gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs) + new_mean = (p_mean_var["mean"].float() + + p_mean_var["variance"] * gradient.float()) + return new_mean + + def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute what the p_mean_variance output would have been, should the + model's score function be conditioned by cond_fn. + + See condition_mean() for details on cond_fn. + + Unlike condition_mean(), this instead uses the conditioning strategy + from Song et al (2020). + """ + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + + eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) + eps = eps - (1 - alpha_bar).sqrt() * cond_fn( + x, self._scale_timesteps(t), **model_kwargs) + + out = p_mean_var.copy() + out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) + out["mean"], _, _ = self.q_posterior_mean_variance( + x_start=out["pred_xstart"], x_t=x, t=t) + return out + + def p_sample( + self, + model: Model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + ): + """ + Sample x_{t-1} from the model at the given timestep. + + :param model: the model to sample from. + :param x: the current tensor at x_{t-1}. + :param t: the value of t, starting at 0 for the first diffusion step. + :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - 'sample': a random sample from the model. + - 'pred_xstart': a prediction of x_0. + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + noise = th.randn_like(x) + nonzero_mask = ((t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) + if cond_fn is not None: + out["mean"] = self.condition_mean(cond_fn, + out, + x, + t, + model_kwargs=model_kwargs) + sample = out["mean"] + nonzero_mask * th.exp( + 0.5 * out["log_variance"]) * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def p_sample_loop( + self, + model: Model, + shape=None, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model. + + :param model: the model module. + :param shape: the shape of the samples, (N, C, H, W). + :param noise: if specified, the noise from the encoder to sample. + Should be of the same shape as `shape`. + :param clip_denoised: if True, clip x_start predictions to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param device: if specified, the device to create the samples on. + If not specified, use a model parameter's device. + :param progress: if True, show a tqdm progress bar. + :return: a non-differentiable batch of samples. 
+ """ + final = None + for sample in self.p_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + ): + final = sample + return final["sample"] + + def p_sample_loop_progressive( + self, + model: Model, + shape=None, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model and yield intermediate samples from + each timestep of diffusion. + + Arguments are the same as p_sample_loop(). + Returns a generator over dicts, where each dict is the return value of + p_sample(). + """ + if device is None: + device = next(model.parameters()).device + if noise is not None: + img = noise + else: + assert isinstance(shape, (tuple, list)) + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * len(img), device=device) + with th.no_grad(): + out = self.p_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + ) + yield out + img = out["sample"] + + def ddim_sample( + self, + model: Model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t-1} from the model using DDIM. + + Same usage as p_sample(). + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, + out, + x, + t, + model_kwargs=model_kwargs) + + eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, + x.shape) + sigma = (eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * + th.sqrt(1 - alpha_bar / alpha_bar_prev)) + noise = th.randn_like(x) + mean_pred = (out["pred_xstart"] * th.sqrt(alpha_bar_prev) + + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps) + nonzero_mask = ((t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) + sample = mean_pred + nonzero_mask * sigma * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def ddim_reverse_sample( + self, + model: Model, + x, + t, + clip_denoised=True, + denoised_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t+1} from the model using DDIM reverse ODE. 
+ """ + assert eta == 0.0, "Reverse ODE only for deterministic path" + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + eps = (_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) + * x - out["pred_xstart"]) / _extract_into_tensor( + self.sqrt_recipm1_alphas_cumprod, t, x.shape) + alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) + + mean_pred = (out["pred_xstart"] * th.sqrt(alpha_bar_next) + + th.sqrt(1 - alpha_bar_next) * eps) + return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} + + def ddim_reverse_sample_loop( + self, + model: Model, + x, + clip_denoised=True, + denoised_fn=None, + model_kwargs=None, + eta=0.0, + device=None, + ): + if device is None: + device = next(model.parameters()).device + sample_t = [] + xstart_t = [] + T = [] + indices = list(range(self.num_timesteps)) + sample = x + for i in indices: + t = th.tensor([i] * len(sample), device=device) + with th.no_grad(): + out = self.ddim_reverse_sample(model, + sample, + t=t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + eta=eta) + sample = out['sample'] + sample_t.append(sample) + xstart_t.append(out['pred_xstart']) + T.append(t) + + return { + 'sample': sample, + 'sample_t': sample_t, + 'xstart_t': xstart_t, + 'T': T, + } + + def ddim_sample_loop( + self, + model: Model, + shape=None, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Generate samples from the model using DDIM. + + Same usage as p_sample_loop(). + """ + final = None + for sample in self.ddim_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + eta=eta, + ): + final = sample + return final["sample"] + + def ddim_sample_loop_progressive( + self, + model: Model, + shape=None, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Use DDIM to sample from the model and yield intermediate samples from + each timestep of DDIM. + + Same usage as p_sample_loop_progressive(). + """ + if device is None: + device = next(model.parameters()).device + if noise is not None: + img = noise + else: + assert isinstance(shape, (tuple, list)) + img = th.randn(*shape, device=device) + + indices = list(range(self.num_timesteps))[::-1] + + if progress: + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + _kwargs = model_kwargs + + t = th.tensor([i] * len(img), device=device) + with th.no_grad(): + out = self.ddim_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=_kwargs, + eta=eta, + ) + out['t'] = t + yield out + img = out["sample"] + + def _vb_terms_bpd(self, + model: Model, + x_start, + x_t, + t, + clip_denoised=True, + model_kwargs=None): + """ + Get a term for the variational lower-bound. + + The resulting units are bits (rather than nats, as one might expect). + This allows for comparison to other papers. + + :return: a dict with the following keys: + - 'output': a shape [N] tensor of NLLs or KLs. + - 'pred_xstart': the x_0 predictions. 
+ """ + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t) + out = self.p_mean_variance(model, + x_t, + t, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs) + kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"], + out["log_variance"]) + kl = mean_flat(kl) / np.log(2.0) + + decoder_nll = -discretized_gaussian_log_likelihood( + x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]) + assert decoder_nll.shape == x_start.shape + decoder_nll = mean_flat(decoder_nll) / np.log(2.0) + + output = th.where((t == 0), decoder_nll, kl) + return { + "output": output, + "pred_xstart": out["pred_xstart"], + 'model_forward': out['model_forward'], + } + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + + This term can't be optimized, as it only depends on the encoder. + + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = th.tensor([self.num_timesteps - 1] * batch_size, + device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl(mean1=qt_mean, + logvar1=qt_log_variance, + mean2=0.0, + logvar2=0.0) + return mean_flat(kl_prior) / np.log(2.0) + + def calc_bpd_loop(self, + model: Model, + x_start, + clip_denoised=True, + model_kwargs=None): + """ + Compute the entire variational lower-bound, measured in bits-per-dim, + as well as other related quantities. + + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param clip_denoised: if True, clip denoised samples. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + + :return: a dict containing the following keys: + - total_bpd: the total variational lower-bound, per batch element. + - prior_bpd: the prior term in the lower-bound. + - vb: an [N x T] tensor of terms in the lower-bound. + - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. + - mse: an [N x T] tensor of epsilon MSEs for each timestep. + """ + device = x_start.device + batch_size = x_start.shape[0] + + vb = [] + xstart_mse = [] + mse = [] + for t in list(range(self.num_timesteps))[::-1]: + t_batch = th.tensor([t] * batch_size, device=device) + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) + with th.no_grad(): + out = self._vb_terms_bpd( + model, + x_start=x_start, + x_t=x_t, + t=t_batch, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs, + ) + vb.append(out["output"]) + xstart_mse.append(mean_flat((out["pred_xstart"] - x_start)**2)) + eps = self._predict_eps_from_xstart(x_t, t_batch, + out["pred_xstart"]) + mse.append(mean_flat((eps - noise)**2)) + + vb = th.stack(vb, dim=1) + xstart_mse = th.stack(xstart_mse, dim=1) + mse = th.stack(mse, dim=1) + + prior_bpd = self._prior_bpd(x_start) + total_bpd = vb.sum(dim=1) + prior_bpd + return { + "total_bpd": total_bpd, + "prior_bpd": prior_bpd, + "vb": vb, + "xstart_mse": xstart_mse, + "mse": mse, + } + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. 
+ :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. + """ + res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res.expand(broadcast_shape) + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): + """ + Get a pre-defined beta schedule for the given name. + + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name == "linear": + scale = 1000 / num_diffusion_timesteps + beta_start = scale * 0.0001 + beta_end = scale * 0.02 + return np.linspace(beta_start, + beta_end, + num_diffusion_timesteps, + dtype=np.float64) + elif schedule_name == "cosine": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2)**2, + ) + elif schedule_name == "const0.01": + scale = 1000 / num_diffusion_timesteps + return np.array([scale * 0.01] * num_diffusion_timesteps, + dtype=np.float64) + elif schedule_name == "const0.015": + scale = 1000 / num_diffusion_timesteps + return np.array([scale * 0.015] * num_diffusion_timesteps, + dtype=np.float64) + elif schedule_name == "const0.008": + scale = 1000 / num_diffusion_timesteps + return np.array([scale * 0.008] * num_diffusion_timesteps, + dtype=np.float64) + elif schedule_name == "const0.0065": + scale = 1000 / num_diffusion_timesteps + return np.array([scale * 0.0065] * num_diffusion_timesteps, + dtype=np.float64) + elif schedule_name == "const0.0055": + scale = 1000 / num_diffusion_timesteps + return np.array([scale * 0.0055] * num_diffusion_timesteps, + dtype=np.float64) + elif schedule_name == "const0.0045": + scale = 1000 / num_diffusion_timesteps + return np.array([scale * 0.0045] * num_diffusion_timesteps, + dtype=np.float64) + elif schedule_name == "const0.0035": + scale = 1000 / num_diffusion_timesteps + return np.array([scale * 0.0035] * num_diffusion_timesteps, + dtype=np.float64) + elif schedule_name == "const0.0025": + scale = 1000 / num_diffusion_timesteps + return np.array([scale * 0.0025] * num_diffusion_timesteps, + dtype=np.float64) + elif schedule_name == "const0.0015": + scale = 1000 / num_diffusion_timesteps + return np.array([scale * 0.0015] * num_diffusion_timesteps, + dtype=np.float64) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. 
+ """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + Compute the KL divergence between two gaussians. + + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, th.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + logvar1, logvar2 = [ + x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * (-1.0 + logvar2 - logvar1 + th.exp(logvar1 - logvar2) + + ((mean1 - mean2)**2) * th.exp(-logvar2)) + + +def approx_standard_normal_cdf(x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. + """ + return 0.5 * ( + 1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) + + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a Gaussian distribution discretizing to a + given image. + + :param x: the target images. It is assumed that this was uint8 values, + rescaled to the range [-1, 1]. + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). + """ + assert x.shape == means.shape == log_scales.shape + centered_x = x - means + inv_stdv = th.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 1.0 / 255.0) + cdf_plus = approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - 1.0 / 255.0) + cdf_min = approx_standard_normal_cdf(min_in) + log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = th.where( + x < -0.999, + log_cdf_plus, + th.where(x > 0.999, log_one_minus_cdf_min, + th.log(cdf_delta.clamp(min=1e-12))), + ) + assert log_probs.shape == x.shape + return log_probs + +class DummyModel(th.nn.Module): + def __init__(self, pred): + super().__init__() + self.pred = pred + + def forward(self, *args, **kwargs): + return DummyReturn(pred=self.pred) + +class DummyReturn(NamedTuple): + pred: th.Tensor \ No newline at end of file diff --git a/diffusion/diffusion.py b/diffusion/diffusion.py new file mode 100755 index 0000000..422bae4 --- /dev/null +++ b/diffusion/diffusion.py @@ -0,0 +1,154 @@ +from .base import * +from dataclasses import dataclass + + +def space_timesteps(num_timesteps, section_counts): + """ + Create a list of timesteps to use from an original diffusion process, + given the number of timesteps we want to take from equally-sized portions + of the original process. + + For example, if there's 300 timesteps and the section counts are [10,15,20] + then the first 100 timesteps are strided to be 10 timesteps, the second 100 + are strided to be 15 timesteps, and the final 100 are strided to be 20. + + If the stride is a string starting with "ddim", then the fixed striding + from the DDIM paper is used, and only one section is allowed. + + :param num_timesteps: the number of diffusion steps in the original + process to divide up. + :param section_counts: either a list of numbers, or a string containing + comma-separated numbers, indicating the step count + per section. 
As a special case, use "ddimN" where N + is a number of steps to use the striding from the + DDIM paper. + :return: a set of diffusion steps from the original process to use. + """ + if isinstance(section_counts, str): + if section_counts.startswith("ddim"): + desired_count = int(section_counts[len("ddim"):]) + for i in range(1, num_timesteps): + if len(range(0, num_timesteps, i)) == desired_count: + return set(range(0, num_timesteps, i)) + raise ValueError( + f"cannot create exactly {num_timesteps} steps with an integer stride" + ) + section_counts = [int(x) for x in section_counts.split(",")] + size_per = num_timesteps // len(section_counts) + extra = num_timesteps % len(section_counts) + start_idx = 0 + all_steps = [] + for i, section_count in enumerate(section_counts): + size = size_per + (1 if i < extra else 0) + if size < section_count: + raise ValueError( + f"cannot divide section of {size} steps into {section_count}") + if section_count <= 1: + frac_stride = 1 + else: + frac_stride = (size - 1) / (section_count - 1) + cur_idx = 0.0 + taken_steps = [] + for _ in range(section_count): + taken_steps.append(start_idx + round(cur_idx)) + cur_idx += frac_stride + all_steps += taken_steps + start_idx += size + return set(all_steps) + + +@dataclass +class SpacedDiffusionBeatGansConfig(GaussianDiffusionBeatGansConfig): + use_timesteps: Tuple[int] = None + + def make_sampler(self): + return SpacedDiffusionBeatGans(self) + + +class SpacedDiffusionBeatGans(GaussianDiffusionBeatGans): + """ + A diffusion process which can skip steps in a base diffusion process. + + :param use_timesteps: a collection (sequence or set) of timesteps from the + original diffusion process to retain. + :param kwargs: the kwargs to create the base diffusion process. + """ + def __init__(self, conf: SpacedDiffusionBeatGansConfig): + self.conf = conf + self.use_timesteps = set(conf.use_timesteps) + self.timestep_map = [] + self.original_num_steps = len(conf.betas) + + base_diffusion = GaussianDiffusionBeatGans(conf) + last_alpha_cumprod = 1.0 + new_betas = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + self.timestep_map.append(i) + conf.betas = np.array(new_betas) + super().__init__(conf) + + def p_mean_variance(self, model: Model, *args, **kwargs): + return super().p_mean_variance(self._wrap_model(model), *args, + **kwargs) + + def training_losses(self, model: Model, *args, **kwargs): + return super().training_losses(self._wrap_model(model), *args, + **kwargs) + + def condition_mean(self, cond_fn, *args, **kwargs): + return super().condition_mean(self._wrap_model(cond_fn), *args, + **kwargs) + + def condition_score(self, cond_fn, *args, **kwargs): + return super().condition_score(self._wrap_model(cond_fn), *args, + **kwargs) + + def _wrap_model(self, model: Model): + if isinstance(model, _WrappedModel): + return model + return _WrappedModel(model, self.timestep_map, self.rescale_timesteps, + self.original_num_steps) + + def _scale_timesteps(self, t): + return t + + +class _WrappedModel: + """ + converting the supplied t's to the old t's scales. 
+ """ + def __init__(self, model, timestep_map, rescale_timesteps, + original_num_steps): + self.model = model + self.timestep_map = timestep_map + self.rescale_timesteps = rescale_timesteps + self.original_num_steps = original_num_steps + + def forward(self, x, t, t_cond=None, **kwargs): + """ + Args: + t: t's with differrent ranges (can be << T due to smaller eval T) need to be converted to the original t's + t_cond: the same as t but can be of different values + """ + map_tensor = th.tensor(self.timestep_map, + device=t.device, + dtype=t.dtype) + + def do(t): + new_ts = map_tensor[t] + if self.rescale_timesteps: + new_ts = new_ts.float() * (1000.0 / self.original_num_steps) + return new_ts + + if t_cond is not None: + t_cond = do(t_cond) + return self.model(x=x, t=do(t), t_cond=t_cond, **kwargs) + + def __getattr__(self, name): + if hasattr(self.model, name): + func = getattr(self.model, name) + return func + raise AttributeError(name) diff --git a/diffusion/resample.py b/diffusion/resample.py new file mode 100644 index 0000000..658ce97 --- /dev/null +++ b/diffusion/resample.py @@ -0,0 +1,62 @@ +from abc import ABC, abstractmethod + +import numpy as np +import torch as th + + +def create_named_schedule_sampler(name, diffusion): + """ + Create a ScheduleSampler from a library of pre-defined samplers. + + :param name: the name of the sampler. + :param diffusion: the diffusion object to sample for. + """ + if name == "uniform": + return UniformSampler(diffusion) + else: + raise NotImplementedError(f"unknown schedule sampler: {name}") + + +class ScheduleSampler(ABC): + """ + A distribution over timesteps in the diffusion process, intended to reduce + variance of the objective. + + By default, samplers perform unbiased importance sampling, in which the + objective's mean is unchanged. + However, subclasses may override sample() to change how the resampled + terms are reweighted, allowing for actual changes in the objective. + """ + @abstractmethod + def weights(self): + """ + Get a numpy array of weights, one per diffusion step. + + The weights needn't be normalized, but must be positive. + """ + + def sample(self, batch_size, device): + """ + Importance-sample timesteps for a batch. + + :param batch_size: the number of timesteps. + :param device: the torch device to save to. + :return: a tuple (timesteps, weights): + - timesteps: a tensor of timestep indices. + - weights: a tensor of weights to scale the resulting losses. 
+ """ + w = self.weights() + p = w / np.sum(w) + indices_np = np.random.choice(len(p), size=(batch_size, ), p=p) + indices = th.from_numpy(indices_np).long().to(device) + weights_np = 1 / (len(p) * p[indices_np]) + weights = th.from_numpy(weights_np).float().to(device) + return indices, weights + + +class UniformSampler(ScheduleSampler): + def __init__(self, num_timesteps): + self._weights = np.ones([num_timesteps]) + + def weights(self): + return self._weights diff --git a/dist_utils.py b/dist_utils.py new file mode 100755 index 0000000..88bb446 --- /dev/null +++ b/dist_utils.py @@ -0,0 +1,42 @@ +from typing import List +from torch import distributed + + +def barrier(): + if distributed.is_initialized(): + distributed.barrier() + else: + pass + + +def broadcast(data, src): + if distributed.is_initialized(): + distributed.broadcast(data, src) + else: + pass + + +def all_gather(data: List, src): + if distributed.is_initialized(): + distributed.all_gather(data, src) + else: + data[0] = src + + +def get_rank(): + if distributed.is_initialized(): + return distributed.get_rank() + else: + return 0 + + +def get_world_size(): + if distributed.is_initialized(): + return distributed.get_world_size() + else: + return 1 + + +def chunk_size(size, rank, world_size): + extra = rank < size % world_size + return size // world_size + extra \ No newline at end of file diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..e1ccc71 --- /dev/null +++ b/environment.yml @@ -0,0 +1,160 @@ +name: dismouse +channels: + - defaults + - conda-forge +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=5.1=1_gnu + - ca-certificates=2023.12.12=h06a4308_0 + - ld_impl_linux-64=2.38=h1181459_1 + - libffi=3.4.4=h6a678d5_0 + - libgcc-ng=11.2.0=h1234567_1 + - libgomp=11.2.0=h1234567_1 + - libstdcxx-ng=11.2.0=h1234567_1 + - ncurses=6.4=h6a678d5_0 + - openssl=3.0.12=h7f8727e_0 + - pip=23.3.1=py38h06a4308_0 + - python=3.8.18=h955ad1f_0 + - readline=8.2=h5eee18b_0 + - setuptools=68.2.2=py38h06a4308_0 + - sqlite=3.41.2=h5eee18b_0 + - tk=8.6.12=h1ccaba5_0 + - wheel=0.41.2=py38h06a4308_0 + - xz=5.4.5=h5eee18b_0 + - zlib=1.2.13=h5eee18b_0 + - pip: + - absl-py==2.0.0 + - aiohttp==3.9.1 + - aiosignal==1.3.1 + - appdirs==1.4.4 + - async-timeout==4.0.3 + - attrs==23.2.0 + - bleach==6.1.0 + - bokeh==3.1.1 + - cachetools==5.3.2 + - certifi==2023.11.17 + - charset-normalizer==3.3.2 + - click==8.1.7 + - cloudpickle==3.0.0 + - colorcet==3.1.0 + - contourpy==1.1.1 + - cycler==0.12.1 + - cython==0.29.37 + - dask==2023.5.0 + - datashader==0.15.2 + - datashape==0.5.2 + - docker-pycreds==0.4.0 + - filelock==3.13.1 + - fonttools==4.47.2 + - frozenlist==1.4.1 + - fsspec==2023.12.2 + - ftfy==6.1.3 + - future==0.18.3 + - gitdb==4.0.11 + - gitpython==3.1.41 + - google-auth==2.26.2 + - google-auth-oauthlib==1.0.0 + - grpcio==1.60.0 + - hdbscan==0.8.33 + - hmmlearn==0.3.2 + - holoviews==1.17.1 + - idna==3.6 + - imageio==2.34.0 + - importlib-metadata==7.0.1 + - importlib-resources==6.1.1 + - jinja2==3.1.3 + - joblib==1.3.2 + - kiwisolver==1.4.5 + - kornia==0.7.1 + - lazy-loader==0.3 + - lightning-utilities==0.10.1 + - linkify-it-py==2.0.3 + - llvmlite==0.41.1 + - lmdb==1.2.1 + - locket==1.0.0 + - lpips==0.1.4 + - markdown==3.5.2 + - markdown-it-py==3.0.0 + - markupsafe==2.1.3 + - matplotlib==3.7.4 + - mdit-py-plugins==0.4.0 + - mdurl==0.1.2 + - mpmath==1.3.0 + - multidict==6.0.4 + - multipledispatch==1.0.0 + - networkx==3.1 + - numba==0.58.1 + - numpy==1.24.4 + - nvidia-cublas-cu12==12.1.3.1 + - 
nvidia-cuda-cupti-cu12==12.1.105 + - nvidia-cuda-nvrtc-cu12==12.1.105 + - nvidia-cuda-runtime-cu12==12.1.105 + - nvidia-cudnn-cu12==8.9.2.26 + - nvidia-cufft-cu12==11.0.2.54 + - nvidia-curand-cu12==10.3.2.106 + - nvidia-cusolver-cu12==11.4.5.107 + - nvidia-cusparse-cu12==12.1.0.106 + - nvidia-nccl-cu12==2.19.3 + - nvidia-nvjitlink-cu12==12.3.101 + - nvidia-nvtx-cu12==12.1.105 + - oauthlib==3.2.2 + - packaging==23.2 + - pandas==1.5.3 + - panel==1.2.3 + - param==2.0.2 + - partd==1.4.1 + - pillow==10.2.0 + - protobuf==4.25.2 + - psutil==5.9.8 + - pyasn1==0.5.1 + - pyasn1-modules==0.3.0 + - pyct==0.5.0 + - pydeprecate==0.3.1 + - pynndescent==0.5.11 + - pyparsing==3.1.1 + - python-crfsuite==0.9.10 + - python-dateutil==2.8.2 + - pytorch-fid==0.2.0 + - pytorch-lightning==1.4.5 + - pytz==2023.3.post1 + - pyviz-comms==3.0.1 + - pywavelets==1.4.1 + - pyyaml==6.0.1 + - regex==2023.12.25 + - requests==2.31.0 + - requests-oauthlib==1.3.1 + - rsa==4.9 + - scikit-image==0.21.0 + - scikit-learn==1.3.2 + - scipy==1.10.1 + - sentry-sdk==1.39.2 + - setproctitle==1.3.3 + - six==1.16.0 + - sklearn-crfsuite==0.3.6 + - smmap==5.0.1 + - sympy==1.12 + - tabulate==0.9.0 + - tensorboard==2.14.0 + - tensorboard-data-server==0.7.2 + - threadpoolctl==3.2.0 + - tifffile==2023.7.10 + - toolz==0.12.1 + - torch==1.8.1 + - torchmetrics==0.5.0 + - torchvision==0.9.1 + - tornado==6.4 + - tqdm==4.66.1 + - triton==2.2.0 + - typing-extensions==4.9.0 + - tzdata==2023.4 + - uc-micro-py==1.0.3 + - umap-learn==0.5.5 + - urllib3==2.1.0 + - wandb==0.16.2 + - wcwidth==0.2.13 + - webencodings==0.5.1 + - werkzeug==3.0.1 + - xarray==2023.1.0 + - xyzservices==2023.10.1 + - yarl==1.9.4 + - zipp==3.17.0 \ No newline at end of file diff --git a/experiment.py b/experiment.py new file mode 100755 index 0000000..c9d818e --- /dev/null +++ b/experiment.py @@ -0,0 +1,378 @@ +import copy, wandb +import os +import random +import numpy as np +import pandas as pd +import pytorch_lightning as pl +import torch +from pytorch_lightning import loggers as pl_loggers +from pytorch_lightning.callbacks import * +from torch.cuda import amp +from torch.optim.optimizer import Optimizer +from torch.utils.data.dataset import TensorDataset + +from config import * +from dataset import * +from dist_utils import * + +def MakeDir(dirName): + if not os.path.exists(dirName): + os.makedirs(dirName) + +class LitModel(pl.LightningModule): + def __init__(self, conf: TrainConfig, betas): + super().__init__() + + self.save_hyperparameters({k:v for (k,v) in vars(conf).items() if not callable(v)}) + self.save_hyperparameters(conf.as_dict_jsonable()) + + assert conf.train_mode != TrainMode.manipulate + if conf.seed is not None: + pl.seed_everything(conf.seed) + + conf.betas = betas + self.conf = conf + self.model = conf.make_model_conf().make_model() + + self.ema_model = copy.deepcopy(self.model) + self.ema_model.requires_grad_(False) + self.ema_model.eval() + + model_size = 0 + for param in self.model.parameters(): + model_size += param.data.nelement() + print('Model params: %.2f M' % (model_size / 1024 / 1024)) + + self.sampler = conf.make_diffusion_conf().make_sampler() + self.eval_sampler = conf.make_eval_diffusion_conf().make_sampler() + + self.T_sampler = conf.make_T_sampler() + + if conf.train_mode.use_latent_net(): + self.latent_sampler = conf.make_latent_diffusion_conf( + ).make_sampler() + self.eval_latent_sampler = conf.make_latent_eval_diffusion_conf( + ).make_sampler() + else: + self.latent_sampler = None + self.eval_latent_sampler = None + + if conf.pretrain is not 
None: + print(f'loading pretrain ... {conf.pretrain.name}') + state = torch.load(conf.pretrain.path, map_location='cpu') + print('step:', state['global_step']) + self.load_state_dict(state['state_dict'], strict=False) + + if conf.latent_infer_path is not None: + print('loading latent stats ...') + state = torch.load(conf.latent_infer_path) + self.conds = state['conds'] + else: + self.conds_mean = None + self.conds_std = None + + def normalize(self, cond): + cond = (cond - self.conds_mean.to(self.device)) / self.conds_std.to( + self.device) + return cond + + def denormalize(self, cond): + cond = (cond * self.conds_std.to(self.device)) + self.conds_mean.to( + self.device) + return cond + + def render(self, noise, cond=None, T=None): + if T is None: + sampler = self.eval_sampler + else: + sampler = self.conf._make_diffusion_conf(T).make_sampler() + + if cond is not None: + pred_img = render_condition(self.conf, + self.ema_model, + noise, + sampler=sampler, + cond=cond) + else: + pred_img = render_uncondition(self.conf, + self.ema_model, + noise, + sampler=sampler, + latent_sampler=None) + return pred_img + + def encode(self, x): + assert self.conf.model_type.has_autoenc() + cond = self.ema_model.encoder.forward(x) + return cond + + def encode_stochastic(self, x, cond, T=None): + if T is None: + sampler = self.eval_sampler + else: + sampler = self.conf._make_diffusion_conf(T).make_sampler() + + out = sampler.ddim_reverse_sample_loop(self.ema_model, + x, + model_kwargs={'cond': cond}) + return out['sample'], out['xstart_t'] + + def forward(self, noise=None, x_start=None, ema_model: bool = False): + with amp.autocast(False): + if ema_model: + model = self.ema_model + else: + model = self.model + gen = self.eval_sampler.sample(model=model, + noise=noise, + x_start=x_start) + return gen + + def setup(self, stage=None) -> None: + """ + make datasets & seeding each worker separately + """ + if self.conf.seed is not None: + seed = self.conf.seed * get_world_size() + self.global_rank + np.random.seed(seed) + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + + taskdata, tasklabels = loadDataset(self.conf) + assert self.conf.num_users==tasklabels['user'].nunique() + + tasklabels = pd.DataFrame(tasklabels['user'].astype('category').cat.codes.values)[0].values.astype(int) + assert self.conf.num_users==len(np.unique(tasklabels)) + + self.train_data = self.conf.make_dataset(taskdata, tasklabels) + self.val_data = self.train_data + + def train_dataloader(self): + """ + return the dataloader, if diffusion mode => return image dataset + if latent mode => return the inferred latent dataset + """ + if self.conf.train_mode.require_dataset_infer(): + if self.conds is None: + self.conds = self.infer_whole_dataset() + self.conds_mean.data = self.conds.float().mean(dim=0, + keepdim=True) + self.conds_std.data = self.conds.float().std(dim=0, + keepdim=True) + print('mean:', self.conds_mean.mean(), 'std:', + self.conds_std.mean()) + + conf = self.conf.clone() + conf.batch_size = self.batch_size + data = TensorDataset(self.conds) + return conf.make_loader(data, shuffle=True) + else: + return torch.utils.data.DataLoader(self.train_data, batch_size=self.conf.batch_size, shuffle=True) + + @property + def batch_size(self): + """ + local batch size for each worker + """ + ws = get_world_size() + assert self.conf.batch_size % ws == 0 + return self.conf.batch_size // ws + + 
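+    # Worked example (numbers are illustrative only): with conf.batch_size = 512
+    # and 4 workers, get_world_size() == 4 and each worker loads 512 // 4 = 128
+    # samples per step, while num_samples below advances by the global
+    # (effective) batch size per optimizer step, so eval/checkpoint intervals
+    # are expressed in samples rather than steps.
+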
@property + def num_samples(self): + """ + (global) batch size * iterations + """ + return self.global_step * self.conf.batch_size_effective + + def is_last_accum(self, batch_idx): + """ + is it the last gradient accumulation loop? + used with gradient_accum > 1 and to see if the optimizer will perform "step" in this iteration or not + """ + return (batch_idx + 1) % self.conf.accum_batches == 0 + + def training_step(self, batch, batch_idx): + """ + given an input, calculate the loss function + no optimization at this stage. + """ + with amp.autocast(False): + if self.conf.train_mode.require_dataset_infer(): + cond = batch[0] + if self.conf.latent_znormalize: + cond = (cond - self.conds_mean.to( + self.device)) / self.conds_std.to(self.device) + else: + imgs = batch[0] + x_start = imgs + + if self.conf.train_mode == TrainMode.diffusion: + t, weight = self.T_sampler.sample(len(x_start), x_start.device) + + losses = self.sampler.training_losses(model=self.model, + x_start=x_start, + t=t, + user_label=batch[1], + lossbetas=self.conf.betas) + elif self.conf.train_mode.is_latent_diffusion(): + t, weight = self.T_sampler.sample(len(cond), cond.device) + latent_losses = self.latent_sampler.training_losses( + model=self.model.latent_net, x_start=cond, t=t) + losses = { + 'latent': latent_losses['loss'], + 'loss': latent_losses['loss'] + } + else: + raise NotImplementedError() + + loss = losses['loss'].mean() + self.log("train_loss", loss) + + return {'loss': loss} + + def on_train_batch_end(self, outputs, batch, batch_idx: int, + dataloader_idx: int) -> None: + """ + after each training step + """ + if self.is_last_accum(batch_idx): + if (batch_idx==len(self.train_dataloader())-1) and ((self.current_epoch+1) % 10 == 0): + save_path = os.path.join(self.conf.logdir, 'epoch%d.ckpt' % self.current_epoch) + torch.save({ + 'state_dict': self.state_dict(), + 'global_step': self.global_step, + 'loss': outputs['loss'], + }, save_path) + + if self.conf.train_mode == TrainMode.latent_diffusion: + ema(self.model.latent_net, self.ema_model.latent_net, + self.conf.ema_decay) + else: + ema(self.model, self.ema_model, self.conf.ema_decay) + + def on_before_optimizer_step(self, optimizer: Optimizer, + optimizer_idx: int) -> None: + if self.conf.grad_clip > 0: + params = [ + p for group in optimizer.param_groups for p in group['params'] + ] + torch.nn.utils.clip_grad_norm_(params, + max_norm=self.conf.grad_clip) + + def configure_optimizers(self): + out = {} + if self.conf.optimizer == OptimizerType.adam: + optim = torch.optim.Adam(self.model.parameters(), + lr=self.conf.lr, + weight_decay=self.conf.weight_decay) + elif self.conf.optimizer == OptimizerType.adamw: + optim = torch.optim.AdamW(self.model.parameters(), + lr=self.conf.lr, + weight_decay=self.conf.weight_decay) + else: + raise NotImplementedError() + out['optimizer'] = optim + if self.conf.warmup > 0: + sched = torch.optim.lr_scheduler.LambdaLR(optim, + lr_lambda=WarmupLR( + self.conf.warmup)) + out['lr_scheduler'] = { + 'scheduler': sched, + 'interval': 'step', + } + return out + + def split_tensor(self, x): + """ + extract the tensor for a corresponding "worker" in the batch dimension + + Args: + x: (n, c) + + Returns: x: (n_local, c) + """ + n = len(x) + rank = self.global_rank + world_size = get_world_size() + per_rank = n // world_size + return x[rank * per_rank:(rank + 1) * per_rank] + +def ema(source, target, decay): + source_dict = source.state_dict() + target_dict = target.state_dict() + for key in source_dict.keys(): + 
target_dict[key].data.copy_(target_dict[key].data * decay + + source_dict[key].data * (1 - decay)) + +class WarmupLR: + def __init__(self, warmup) -> None: + self.warmup = warmup + + def __call__(self, step): + return min(step, self.warmup) / self.warmup + + +def is_time(num_samples, every, step_size): + closest = (num_samples // every) * every + return num_samples - closest < step_size + + +def train(conf: TrainConfig, model: LitModel, gpus, nodes=1): + checkpoint = ModelCheckpoint(dirpath=conf.logdir, + save_last=True, + save_top_k=1, + every_n_train_steps=conf.save_every_samples // + conf.batch_size_effective) + checkpoint_path = f'{conf.logdir}last.ckpt' + print('ckpt path:', checkpoint_path) + if os.path.exists(checkpoint_path): + resume = checkpoint_path + print('resume!') + else: + if conf.continue_from is not None: + resume = conf.continue_from.path + else: + resume = None + + plugins = [] + if len(gpus) == 1 and nodes == 1: + accelerator = None + else: + accelerator = 'ddp' + from pytorch_lightning.plugins import DDPPlugin + + plugins.append(DDPPlugin(find_unused_parameters=False)) + + wandb_logger = pl_loggers.WandbLogger(project='dismouse', + name='%s_%s'%(model.conf.pretrainDataset, conf.logdir.split('/')[-2]), + log_model=True, + save_dir=conf.logdir, + dir = conf.logdir, + config=vars(model.conf), + save_code=True, + settings=wandb.Settings(code_dir=".")) + + trainer = pl.Trainer( + max_steps=conf.total_samples // conf.batch_size_effective, + resume_from_checkpoint=resume, + gpus=gpus, + num_nodes=nodes, + accelerator=accelerator, + precision=16 if conf.fp16 else 32, + callbacks=[ + checkpoint, + LearningRateMonitor(), + ], + replace_sampler_ddp=True, + logger= wandb_logger, + accumulate_grad_batches=conf.accum_batches, + plugins=plugins, + ) + + trainer.fit(model) + wandb.finish() \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000..cdc6d47 --- /dev/null +++ b/main.py @@ -0,0 +1,29 @@ +import os +from templates import * + +if __name__ == '__main__': + gpus = [0] + + conf = mouse_autoenc('trainDiff') + + betas = { + 'recon':1, + 'noise':1, + 'user':0.01, + 'nonuser':0.01, + 'mi': 0.01 + } + betastr = '' + for k,v in betas.items(): + betastr += f'{k}{v}_' + betastr = betastr[:-1] + + diffmodel = LitModel(conf, betas) + + conf.logdir= f'{conf.logdir}/mouse_autoenc/{conf.pretrainDataset}/{betastr}/embDim{conf.net_beatgans_embed_channels}/win{conf.AEwin}/slid{conf.slid}/GRL/' + MakeDir(conf.logdir) + os.environ['WANDB_CACHE_DIR'] = conf.logdir + os.environ['WANDB_DATA_DIR'] = conf.logdir + os.environ['WANDB_IGNORE_GLOBS'] = '*.ckpt' + + train(conf, diffmodel, gpus=gpus) \ No newline at end of file diff --git a/model/MI.py b/model/MI.py new file mode 100644 index 0000000..a113cbf --- /dev/null +++ b/model/MI.py @@ -0,0 +1,176 @@ +''' +Differentiable approximation to the mutual information (MI) metric. 
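+The joint and marginal densities are estimated with soft (sigmoid-based)
+histograms, so the resulting MI estimate stays differentiable end to end.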
+Implementation in PyTorch
+'''
+
+# Imports #
+# ----------------------------------------------------------------------
+import torch
+import torch.nn as nn
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib import cm
+import os
+
+# Note: This code snippet was taken from the discussion found at:
+# https://discuss.pytorch.org/t/differentiable-torch-histc/25865/2
+# By Tony-Y
+class SoftHistogram1D(nn.Module):
+    '''
+    Differentiable 1D histogram calculation (supported via pytorch's autograd)
+    input:
+        x - N x D array, where N is the batch size and D is the length of each data series
+        bins - Number of bins for the histogram
+        min - Scalar min value to be included in the histogram
+        max - Scalar max value to be included in the histogram
+        sigma - Scalar smoothing factor for the bin approximation via sigmoid functions.
+                Larger values correspond to sharper edges, and thus yield a more accurate approximation
+    output:
+        N x bins array, where each row is a histogram
+    '''
+
+    def __init__(self, bins=50, min=0, max=1, sigma=10):
+        super(SoftHistogram1D, self).__init__()
+        self.bins = bins
+        self.min = min
+        self.max = max
+        self.sigma = sigma
+        self.delta = float(max - min) / float(bins)
+        self.centers = float(min) + self.delta * (torch.arange(bins).float() + 0.5)  # Bin centers
+        self.centers = nn.Parameter(self.centers, requires_grad=False)  # Wrapped as a Parameter to allow CUDA support
+
+    def forward(self, x):
+        # Replicate x and, for each row, subtract the bin centers
+        x = torch.unsqueeze(x, 1) - torch.unsqueeze(self.centers, 1)
+
+        # Bin approximation using a sigmoid function
+        x = torch.sigmoid(self.sigma * (x + self.delta / 2)) - torch.sigmoid(self.sigma * (x - self.delta / 2))
+
+        # Sum along the non-batch dimensions
+        x = x.sum(dim=-1)
+        # x = x / x.sum(dim=-1).unsqueeze(1)  # normalization
+        return x
+
+
+# Note: This is an extension to the 2D case of the previous code snippet
+class SoftHistogram2D(nn.Module):
+    '''
+    Differentiable 2D (joint) histogram calculation (supported via pytorch's autograd)
+    input:
+        x, y - N x D array, where N is the batch size and D is the length of each data series
+               (i.e. vectorized image or vectorized 3D volume)
+        bins - Number of bins for the histogram
+        min - Scalar min value to be included in the histogram
+        max - Scalar max value to be included in the histogram
+        sigma - Scalar smoothing factor for the bin approximation via sigmoid functions.
+                Larger values correspond to sharper edges, and thus yield a more accurate approximation
+    output:
+        N x bins x bins array, where each entry along the batch dimension is a joint (2D) histogram
+    '''
+
+    def __init__(self, bins=50, min=0, max=1, sigma=10):
+        super(SoftHistogram2D, self).__init__()
+        self.bins = bins
+        self.min = min
+        self.max = max
+        self.sigma = sigma
+        self.delta = float(max - min) / float(bins)
+        self.centers = float(min) + self.delta * (torch.arange(bins).float() + 0.5)  # Bin centers
+        self.centers = nn.Parameter(self.centers, requires_grad=False)  # Wrapped as a Parameter to allow CUDA support
+
+    def forward(self, x, y):
+        assert x.size() == y.size(), "(SoftHistogram2D) x and y sizes do not match"
+
+        # Replicate x and y and, for each row, subtract the bin centers
+        x = torch.unsqueeze(x, 1) - torch.unsqueeze(self.centers, 1)
+        y = torch.unsqueeze(y, 1) - torch.unsqueeze(self.centers, 1)
+
+        # Bin approximation using a sigmoid function (can be sigma_x and sigma_y respectively - same for delta)
+        x = torch.sigmoid(self.sigma * (x + self.delta / 2)) - torch.sigmoid(self.sigma * (x - self.delta / 2))
+        y = torch.sigmoid(self.sigma * (y + self.delta / 2)) - torch.sigmoid(self.sigma * (y - self.delta / 2))
+
+        # Batched matrix multiplication - this way we sum jointly
+        z = torch.matmul(x, y.permute((0, 2, 1)))
+        return z
+
+
+class MI_pytorch(nn.Module):
+    '''
+    This class is a pytorch implementation of the mutual information (MI) calculation between two images.
+    This is an approximation, as the images' histograms rely on differentiable approximations of rectangular windows.
+
+        I(X, Y) = H(X) + H(Y) - H(X, Y) = \sum(\sum(p(X, Y) * log(p(X, Y)/(p(X) * p(Y)))))
+
+    where H(X) = -\sum(p(x) * log(p(x))) is the entropy
+    '''
+
+    def __init__(self, bins=50, min=0, max=1, sigma=10, reduction='sum'):
+        super(MI_pytorch, self).__init__()
+        self.bins = bins
+        self.min = min
+        self.max = max
+        self.sigma = sigma
+        self.reduction = reduction
+
+        # 2D joint histogram
+        self.hist2d = SoftHistogram2D(bins, min, max, sigma)
+
+        # Epsilon - to avoid log(0)
+        self.eps = torch.tensor(0.00000001, dtype=torch.float32, requires_grad=False)
+
+    def forward(self, im1, im2):
+        '''
+        Forward implementation of a differentiable MI estimator for batched images
+        :param im1: N x ... tensor, where N is the batch size
+                    ... dimensions can take any form, i.e. 2D images or 3D volumes.
+        :param im2: N x ... tensor, where N is the batch size
+        :return: N x 1 vector - the approximate MI values between the batched im1 and im2
+        '''
+
+        # Check for valid inputs
+        assert im1.size() == im2.size(), "(MI_pytorch) Inputs should have the same dimensions."
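+
+        # Overview of the steps below: flatten both inputs, build a soft joint
+        # histogram, normalise it into P(x, y), reduce it to the marginals P(x)
+        # and P(y), and accumulate P(x, y) * log(P(x, y) / (P(x) * P(y))) as the
+        # MI estimate (with an epsilon guard in the batched case).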
+ + batch_size = im1.size()[0] + + # Flatten tensors + im1_flat = im1.view(im1.size()[0], -1) + im2_flat = im2.view(im2.size()[0], -1) + + # Calculate joint histogram + hgram = self.hist2d(im1_flat, im2_flat) + + # Convert to a joint distribution + # Pxy = torch.distributions.Categorical(probs=hgram).probs + Pxy = torch.div(hgram, torch.sum(hgram.view(hgram.size()[0], -1))) + + # Calculate the marginal distributions + Py = torch.sum(Pxy, dim=1).unsqueeze(1) + Px = torch.sum(Pxy, dim=2).unsqueeze(1) + + # Use the KL divergence distance to calculate the MI + Px_Py = torch.matmul(Px.permute((0, 2, 1)), Py) + + # Reshape to batch_size X all_the_rest + Pxy = Pxy.reshape(batch_size, -1) + Px_Py = Px_Py.reshape(batch_size, -1) + + # Calculate mutual information - this is an approximation due to the histogram calculation and eps, + # but it can handle batches + if batch_size == 1: + # No need for eps approximation in the case of a single batch + nzs = Pxy > 0 # Calculate based on the non-zero values only + mut_info = torch.matmul(Pxy[nzs], torch.log(Pxy[nzs]) - torch.log(Px_Py[nzs])) # MI calculation + else: + # For arbitrary batch size > 1 + mut_info = torch.sum(Pxy * (torch.log(Pxy + self.eps) - torch.log(Px_Py + self.eps)), dim=1) + + # Reduction + if self.reduction == 'sum': + mut_info = torch.sum(mut_info) + elif self.reduction == 'batchmean': + mut_info = torch.sum(mut_info) + mut_info = mut_info / float(batch_size) + elif self.reduction=='individual': + pass + + return mut_info \ No newline at end of file diff --git a/model/__init__.py b/model/__init__.py new file mode 100644 index 0000000..6a501aa --- /dev/null +++ b/model/__init__.py @@ -0,0 +1,6 @@ +from typing import Union +from .unet import BeatGANsUNetModel, BeatGANsUNetConfig +from .unet_autoenc import BeatGANsAutoencConfig, BeatGANsAutoencModel + +Model = Union[BeatGANsUNetModel, BeatGANsAutoencModel] +ModelConfig = Union[BeatGANsUNetConfig, BeatGANsAutoencConfig] diff --git a/model/blocks.py b/model/blocks.py new file mode 100644 index 0000000..1460cdb --- /dev/null +++ b/model/blocks.py @@ -0,0 +1,495 @@ +import math +from abc import abstractmethod +from dataclasses import dataclass +from numbers import Number + +import torch as th +import torch.nn.functional as F +from choices import * +from config_base import BaseConfig +from torch import nn + +from .nn import (avg_pool_nd, conv_nd, linear, normalization, + timestep_embedding, torch_checkpoint, zero_module) + + +class ScaleAt(Enum): + after_norm = 'afternorm' + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + @abstractmethod + def forward(self, x, emb=None, cond=None, lateral=None): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. 
+ """ + def forward(self, x, emb=None, cond=None, lateral=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb=emb, cond=cond, lateral=lateral) + else: + x = layer(x) + return x + + +@dataclass +class ResBlockConfig(BaseConfig): + channels: int + emb_channels: int + dropout: float + out_channels: int = None + use_condition: bool = True + use_conv: bool = False + dims: int = 2 + use_checkpoint: bool = False + up: bool = False + down: bool = False + two_cond: bool = False + cond_emb_channels: int = None + has_lateral: bool = False + lateral_channels: int = None + use_zero_module: bool = True + + def __post_init__(self): + self.out_channels = self.out_channels or self.channels + self.cond_emb_channels = self.cond_emb_channels or self.emb_channels + + def make_model(self): + return ResBlock(self) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + """ + def __init__(self, conf: ResBlockConfig): + super().__init__() + self.conf = conf + + assert conf.lateral_channels is None + layers = [ + normalization(conf.channels), + nn.SiLU(), + conv_nd(conf.dims, conf.channels, conf.out_channels, 3, padding=1) + ] + self.in_layers = nn.Sequential(*layers) + + self.updown = conf.up or conf.down + + if conf.up: + self.h_upd = Upsample(conf.channels, False, conf.dims) + self.x_upd = Upsample(conf.channels, False, conf.dims) + elif conf.down: + self.h_upd = Downsample(conf.channels, False, conf.dims) + self.x_upd = Downsample(conf.channels, False, conf.dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + if conf.use_condition: + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear(conf.emb_channels, 2 * conf.out_channels), + ) + + if conf.two_cond: + self.cond_emb_layers = nn.Sequential( + nn.SiLU(), + linear(conf.cond_emb_channels, conf.out_channels), + ) + conv = conv_nd(conf.dims, + conf.out_channels, + conf.out_channels, + 3, + padding=1) + if conf.use_zero_module: + conv = zero_module(conv) + + layers = [] + layers += [ + normalization(conf.out_channels), + nn.SiLU(), + nn.Dropout(p=conf.dropout), + conv, + ] + self.out_layers = nn.Sequential(*layers) + + if conf.out_channels == conf.channels: + self.skip_connection = nn.Identity() + else: + if conf.use_conv: + kernel_size = 3 + padding = 1 + else: + kernel_size = 1 + padding = 0 + + self.skip_connection = conv_nd(conf.dims, + conf.channels, + conf.out_channels, + kernel_size, + padding=padding) + + def forward(self, x, emb=None, cond=None, lateral=None): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. 
+ + Args: + x: input + lateral: lateral connection from the encoder + """ + return torch_checkpoint(self._forward, (x, emb, cond, lateral), + self.conf.use_checkpoint) + + def _forward( + self, + x, + emb=None, + cond=None, + lateral=None, + ): + """ + Args: + lateral: required if "has_lateral" and non-gated, with gated, it can be supplied optionally + """ + if self.conf.has_lateral: + assert lateral is not None + x = th.cat([x, lateral], dim=1) + + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + + if self.conf.use_condition: + if emb is not None: + emb_out = self.emb_layers(emb).type(h.dtype) + else: + emb_out = None + + if self.conf.two_cond: + if cond is None: + cond_out = None + else: + if not isinstance(cond, th.Tensor): + assert isinstance(cond, dict) + cond = cond['cond'] + cond_out = self.cond_emb_layers(cond).type(h.dtype) + + if cond_out is not None: + while len(cond_out.shape) < len(h.shape): + cond_out = cond_out[..., None] + else: + cond_out = None + + h = apply_conditions( + h=h, + emb=emb_out, + cond=cond_out, + layers=self.out_layers, + scale_bias=1, + in_channels=self.conf.out_channels, + up_down_layer=None, + ) + return self.skip_connection(x) + h + + +def apply_conditions( + h, + emb=None, + cond=None, + layers: nn.Sequential = None, + scale_bias: float = 1, + in_channels: int = 512, + up_down_layer: nn.Module = None, +): + """ + apply conditions on the feature maps + + Args: + emb: time conditional (ready to scale + shift) + cond: encoder's conditional (read to scale + shift) + """ + two_cond = emb is not None and cond is not None + + if emb is not None: + while len(emb.shape) < len(h.shape): + emb = emb[..., None] + + if two_cond: + while len(cond.shape) < len(h.shape): + cond = cond[..., None] + scale_shifts = [emb, cond] + else: + scale_shifts = [emb] + + for i, each in enumerate(scale_shifts): + if each is None: + a = None + b = None + else: + if each.shape[1] == in_channels * 2: + a, b = th.chunk(each, 2, dim=1) + else: + a = each + b = None + scale_shifts[i] = (a, b) + + if isinstance(scale_bias, Number): + biases = [scale_bias] * len(scale_shifts) + else: + biases = scale_bias + + pre_layers, post_layers = layers[0], layers[1:] + + mid_layers, post_layers = post_layers[:-2], post_layers[-2:] + + h = pre_layers(h) + for i, (scale, shift) in enumerate(scale_shifts): + if scale is not None: + h = h * (biases[i] + scale) + if shift is not None: + h = h + shift + h = mid_layers(h) + + if up_down_layer is not None: + h = up_down_layer(h) + h = post_layers(h) + return h + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. 
+ """ + def __init__(self, channels, use_conv, dims=2, out_channels=None): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, + self.channels, + self.out_channels, + 3, + padding=1) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), + mode="nearest") + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + def __init__(self, channels, use_conv, dims=2, out_channels=None): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd(dims, + self.channels, + self.out_channels, + 3, + stride=stride, + padding=1) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + self.attention = QKVAttention(self.num_heads) + else: + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return torch_checkpoint(self._forward, (x, ), self.use_checkpoint) + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + matmul_ops = 2 * b * (num_spatial**2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. 
Matches legacy QKVAttention + input/ouput heads shaping + """ + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, + dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, + k * scale) + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, + v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter( + th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] diff --git a/model/latentnet.py b/model/latentnet.py new file mode 100644 index 0000000..2c1bc86 --- /dev/null +++ b/model/latentnet.py @@ -0,0 +1,184 @@ +import math +from dataclasses import dataclass +from enum import Enum +from typing import NamedTuple, Tuple + +import torch +from choices import * +from config_base import BaseConfig +from torch import nn +from torch.nn import init + +from .blocks import * +from .nn import timestep_embedding +from .unet import * + + +class LatentNetType(Enum): + none = 'none' + skip = 'skip' + +class LatentNetReturn(NamedTuple): + pred: torch.Tensor = None + +@dataclass +class MLPSkipNetConfig(BaseConfig): + """ + default MLP for the latent DPM in the paper! 
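+
+    A hypothetical instantiation (the numbers are illustrative, not the values
+    used elsewhere in this repository):
+
+        conf = MLPSkipNetConfig(num_channels=128,
+                                skip_layers=(1, 2, 3),
+                                num_hid_channels=1024,
+                                num_layers=10)
+        latent_net = conf.make_model()  # -> MLPSkipNet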
+ """ + num_channels: int + skip_layers: Tuple[int] + num_hid_channels: int + num_layers: int + num_time_emb_channels: int = 64 + activation: Activation = Activation.silu + use_norm: bool = True + condition_bias: float = 1 + dropout: float = 0 + last_act: Activation = Activation.none + num_time_layers: int = 2 + time_last_act: bool = False + + def make_model(self): + return MLPSkipNet(self) + + +class MLPSkipNet(nn.Module): + """ + concat x to hidden layers + + default MLP for the latent DPM in the paper! + """ + def __init__(self, conf: MLPSkipNetConfig): + super().__init__() + self.conf = conf + + layers = [] + for i in range(conf.num_time_layers): + if i == 0: + a = conf.num_time_emb_channels + b = conf.num_channels + else: + a = conf.num_channels + b = conf.num_channels + layers.append(nn.Linear(a, b)) + if i < conf.num_time_layers - 1 or conf.time_last_act: + layers.append(conf.activation.get_act()) + self.time_embed = nn.Sequential(*layers) + + self.layers = nn.ModuleList([]) + for i in range(conf.num_layers): + if i == 0: + act = conf.activation + norm = conf.use_norm + cond = True + a, b = conf.num_channels, conf.num_hid_channels + dropout = conf.dropout + elif i == conf.num_layers - 1: + act = Activation.none + norm = False + cond = False + a, b = conf.num_hid_channels, conf.num_channels + dropout = 0 + else: + act = conf.activation + norm = conf.use_norm + cond = True + a, b = conf.num_hid_channels, conf.num_hid_channels + dropout = conf.dropout + + if i in conf.skip_layers: + a += conf.num_channels + + self.layers.append( + MLPLNAct( + a, + b, + norm=norm, + activation=act, + cond_channels=conf.num_channels, + use_cond=cond, + condition_bias=conf.condition_bias, + dropout=dropout, + )) + self.last_act = conf.last_act.get_act() + + def forward(self, x, t, **kwargs): + t = timestep_embedding(t, self.conf.num_time_emb_channels) + cond = self.time_embed(t) + h = x + for i in range(len(self.layers)): + if i in self.conf.skip_layers: + h = torch.cat([h, x], dim=1) + h = self.layers[i].forward(x=h, cond=cond) + h = self.last_act(h) + return LatentNetReturn(h) + + +class MLPLNAct(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + norm: bool, + use_cond: bool, + activation: Activation, + cond_channels: int, + condition_bias: float = 0, + dropout: float = 0, + ): + super().__init__() + self.activation = activation + self.condition_bias = condition_bias + self.use_cond = use_cond + + self.linear = nn.Linear(in_channels, out_channels) + self.act = activation.get_act() + if self.use_cond: + self.linear_emb = nn.Linear(cond_channels, out_channels) + self.cond_layers = nn.Sequential(self.act, self.linear_emb) + if norm: + self.norm = nn.LayerNorm(out_channels) + else: + self.norm = nn.Identity() + + if dropout > 0: + self.dropout = nn.Dropout(p=dropout) + else: + self.dropout = nn.Identity() + + self.init_weights() + + def init_weights(self): + for module in self.modules(): + if isinstance(module, nn.Linear): + if self.activation == Activation.relu: + init.kaiming_normal_(module.weight, + a=0, + nonlinearity='relu') + elif self.activation == Activation.lrelu: + init.kaiming_normal_(module.weight, + a=0.2, + nonlinearity='leaky_relu') + elif self.activation == Activation.silu: + init.kaiming_normal_(module.weight, + a=0, + nonlinearity='relu') + else: + pass + + def forward(self, x, cond=None): + x = self.linear(x) + if self.use_cond: + cond = self.cond_layers(cond) + cond = (cond, None) + + x = x * (self.condition_bias + cond[0]) + if cond[1] is not None: + x = x + 
cond[1] + x = self.norm(x) + else: + x = self.norm(x) + x = self.act(x) + x = self.dropout(x) + return x \ No newline at end of file diff --git a/model/nn.py b/model/nn.py new file mode 100755 index 0000000..4e3d8bb --- /dev/null +++ b/model/nn.py @@ -0,0 +1,135 @@ +""" +Various utilities for neural networks. +""" +import math +import torch as th +import torch.nn as nn +import torch.utils.checkpoint + +import torch.nn.functional as F + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + # @th.jit.script + def forward(self, x): + return x * th.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + assert dims==1 + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + assert dims==1 + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def update_ema(target_params, source_params, rate=0.99): + """ + Update target parameters to be closer to those of source parameters using + an exponential moving average. + + :param target_params: the target parameter sequence. + :param source_params: the source parameter sequence. + :param rate: the EMA rate (closer to 1 means slower). + """ + for targ, src in zip(target_params, source_params): + targ.detach().mul_(rate).add_(src, alpha=1 - rate) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(min(32, channels), channels) + + +def timestep_embedding(timesteps, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. 
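+
+    Concretely (a restatement of the code below): for i < dim // 2 the i-th
+    output column is cos(t / max_period ** (i / (dim // 2))), the matching sin
+    terms fill the second half, and an odd `dim` is padded with a final zero
+    column.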
+ """ + half = dim // 2 + freqs = th.exp(-math.log(max_period) * + th.arange(start=0, end=half, dtype=th.float32) / + half).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) + if dim % 2: + embedding = th.cat( + [embedding, th.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +def torch_checkpoint(func, args, flag, preserve_rng_state=False): + # torch's gradient checkpoint works with automatic mixed precision, given torch >= 1.8 + if flag: + return torch.utils.checkpoint.checkpoint( + func, *args, preserve_rng_state=preserve_rng_state) + else: + return func(*args) diff --git a/model/unet.py b/model/unet.py new file mode 100644 index 0000000..9e78615 --- /dev/null +++ b/model/unet.py @@ -0,0 +1,505 @@ +import math +from dataclasses import dataclass +from numbers import Number +from typing import NamedTuple, Tuple, Union + +import numpy as np +import torch as th +from torch import nn +import torch.nn.functional as F +from choices import * +from config_base import BaseConfig +from .blocks import * + +from .nn import (conv_nd, linear, normalization, timestep_embedding, + torch_checkpoint, zero_module) + + +@dataclass +class BeatGANsUNetConfig(BaseConfig): + image_size: int = 64 + in_channels: int = 2 + model_channels: int = 64 + out_channels: int = 2 + num_res_blocks: int = 2 + num_input_res_blocks: int = None + embed_channels: int = 512 + attention_resolutions: Tuple[int] = (16, ) + time_embed_channels: int = None + dropout: float = 0.1 + channel_mult: Tuple[int] = (1, 2, 4, 8) + input_channel_mult: Tuple[int] = None + conv_resample: bool = True + dims: int = 2 + num_classes: int = None + use_checkpoint: bool = False + num_heads: int = 1 + num_head_channels: int = -1 + num_heads_upsample: int = -1 + resblock_updown: bool = True + use_new_attention_order: bool = False + resnet_two_cond: bool = False + resnet_cond_channels: int = None + resnet_use_zero_module: bool = True + attn_checkpoint: bool = False + + num_users: int = None + + def make_model(self): + return BeatGANsUNetModel(self) + + +class BeatGANsUNetModel(nn.Module): + def __init__(self, conf: BeatGANsUNetConfig): + super().__init__() + self.conf = conf + + if conf.num_heads_upsample == -1: + self.num_heads_upsample = conf.num_heads + + self.dtype = th.float32 + + self.time_emb_channels = conf.time_embed_channels or conf.model_channels + self.time_embed = nn.Sequential( + linear(self.time_emb_channels, conf.embed_channels), + nn.SiLU(), + linear(conf.embed_channels, conf.embed_channels), + ) + + if conf.num_classes is not None: + self.label_emb = nn.Embedding(conf.num_classes, + conf.embed_channels) + + ch = input_ch = int(conf.channel_mult[0] * conf.model_channels) + self.input_blocks = nn.ModuleList([ + TimestepEmbedSequential( + conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1)) + ]) + + kwargs = dict( + use_condition=True, + two_cond=conf.resnet_two_cond, + use_zero_module=conf.resnet_use_zero_module, + cond_emb_channels=conf.resnet_cond_channels, + ) + + self._feature_size = ch + + # input_block_chans = [ch] + input_block_chans = [[] for _ in range(len(conf.channel_mult))] + input_block_chans[0].append(ch) + + # number of blocks at each resolution + self.input_num_blocks = [0 for _ in range(len(conf.channel_mult))] + self.input_num_blocks[0] = 1 + self.output_num_blocks = [0 for _ in range(len(conf.channel_mult))] + + ds = 1 + resolution = conf.image_size + for level, mult in enumerate(conf.input_channel_mult + or 
conf.channel_mult): + for _ in range(conf.num_input_res_blocks or conf.num_res_blocks): + layers = [ + ResBlockConfig( + ch, + conf.embed_channels, + conf.dropout, + out_channels=int(mult * conf.model_channels), + dims=conf.dims, + use_checkpoint=conf.use_checkpoint, + **kwargs, + ).make_model() + ] + ch = int(mult * conf.model_channels) + if resolution in conf.attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=conf.use_checkpoint + or conf.attn_checkpoint, + num_heads=conf.num_heads, + num_head_channels=conf.num_head_channels, + use_new_attention_order=conf. + use_new_attention_order, + )) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans[level].append(ch) + self.input_num_blocks[level] += 1 + if level != len(conf.channel_mult) - 1: + resolution //= 2 + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlockConfig( + ch, + conf.embed_channels, + conf.dropout, + out_channels=out_ch, + dims=conf.dims, + use_checkpoint=conf.use_checkpoint, + down=True, + **kwargs, + ).make_model() if conf. + resblock_updown else Downsample(ch, + conf.conv_resample, + dims=conf.dims, + out_channels=out_ch))) + ch = out_ch + input_block_chans[level + 1].append(ch) + self.input_num_blocks[level + 1] += 1 + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlockConfig( + ch, + conf.embed_channels, + conf.dropout, + dims=conf.dims, + use_checkpoint=conf.use_checkpoint, + **kwargs, + ).make_model(), + AttentionBlock( + ch, + use_checkpoint=conf.use_checkpoint or conf.attn_checkpoint, + num_heads=conf.num_heads, + num_head_channels=conf.num_head_channels, + use_new_attention_order=conf.use_new_attention_order, + ), + ResBlockConfig( + ch, + conf.embed_channels, + conf.dropout, + dims=conf.dims, + use_checkpoint=conf.use_checkpoint, + **kwargs, + ).make_model(), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(conf.channel_mult))[::-1]: + for i in range(conf.num_res_blocks + 1): + try: + ich = input_block_chans[level].pop() + except IndexError: + ich = 0 + layers = [ + ResBlockConfig( + channels=ch + ich, + emb_channels=conf.embed_channels, + dropout=conf.dropout, + out_channels=int(conf.model_channels * mult), + dims=conf.dims, + use_checkpoint=conf.use_checkpoint, + has_lateral=True if ich > 0 else False, + lateral_channels=None, + **kwargs, + ).make_model() + ] + ch = int(conf.model_channels * mult) + if resolution in conf.attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=conf.use_checkpoint + or conf.attn_checkpoint, + num_heads=self.num_heads_upsample, + num_head_channels=conf.num_head_channels, + use_new_attention_order=conf. 
+ use_new_attention_order, + )) + if level and i == conf.num_res_blocks: + resolution *= 2 + out_ch = ch + layers.append( + ResBlockConfig( + ch, + conf.embed_channels, + conf.dropout, + out_channels=out_ch, + dims=conf.dims, + use_checkpoint=conf.use_checkpoint, + up=True, + **kwargs, + ).make_model() if ( + conf.resblock_updown + ) else Upsample(ch, + conf.conv_resample, + dims=conf.dims, + out_channels=out_ch)) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self.output_num_blocks[level] += 1 + self._feature_size += ch + + if conf.resnet_use_zero_module: + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module( + conv_nd(conf.dims, + input_ch, + conf.out_channels, + 3, + padding=1)), + ) + else: + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + conv_nd(conf.dims, input_ch, conf.out_channels, 3, padding=1), + ) + + def forward(self, x, t, y=None, **kwargs): + """ + Apply the model to an input batch. + + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + assert (y is not None) == ( + self.conf.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + + # hs = [] + hs = [[] for _ in range(len(self.conf.channel_mult))] + emb = self.time_embed(timestep_embedding(t, self.time_emb_channels)) + + if self.conf.num_classes is not None: + raise NotImplementedError() + + h = x.type(self.dtype) + k = 0 + for i in range(len(self.input_num_blocks)): + for j in range(self.input_num_blocks[i]): + h = self.input_blocks[k](h, emb=emb) + hs[i].append(h) + k += 1 + assert k == len(self.input_blocks) + + # middle blocks + h = self.middle_block(h, emb=emb) + + # output blocks + k = 0 + for i in range(len(self.output_num_blocks)): + for j in range(self.output_num_blocks[i]): + try: + lateral = hs[-i - 1].pop() + except IndexError: + lateral = None + h = self.output_blocks[k](h, emb=emb, lateral=lateral) + k += 1 + + h = h.type(x.dtype) + pred = self.out(h) + return Return(pred=pred) + +class Return(NamedTuple): + pred: th.Tensor + + +@dataclass +class BeatGANsEncoderConfig(BaseConfig): + image_size: int + in_channels: int + model_channels: int + out_hid_channels: int + out_channels: int + num_res_blocks: int + attention_resolutions: Tuple[int] + dropout: float = 0 + channel_mult: Tuple[int] = (1, 2, 4, 8) + use_time_condition: bool = True + conv_resample: bool = True + dims: int = 2 + use_checkpoint: bool = False + num_heads: int = 1 + num_head_channels: int = -1 + resblock_updown: bool = False + use_new_attention_order: bool = False + pool: str = 'adaptivenonzero' + + def make_model(self): + return BeatGANsEncoderModel(self) + +class BeatGANsEncoderModel(nn.Module): + """ + The half UNet model with attention and timestep embedding. + + For usage, see UNet. 
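+
+    With the 'adaptivenonzero' pooling used in this repository, the final
+    feature map is average-pooled to length 1, projected to `out_channels`
+    with a kernel-size-1 convolution, and flattened, so forward() returns one
+    [N x out_channels] semantic vector per input sequence.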
+ """ + def __init__(self, conf: BeatGANsEncoderConfig): + super().__init__() + self.conf = conf + self.dtype = th.float32 + + if conf.use_time_condition: + time_embed_dim = conf.model_channels * 4 + self.time_embed = nn.Sequential( + linear(conf.model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + else: + time_embed_dim = None + + ch = int(conf.channel_mult[0] * conf.model_channels) + self.input_blocks = nn.ModuleList([ + TimestepEmbedSequential( + conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1)) + ]) + self._feature_size = ch + input_block_chans = [ch] + ds = 1 + resolution = conf.image_size + for level, mult in enumerate(conf.channel_mult): + for _ in range(conf.num_res_blocks): + layers = [ + ResBlockConfig( + ch, + time_embed_dim, + conf.dropout, + out_channels=int(mult * conf.model_channels), + dims=conf.dims, + use_condition=conf.use_time_condition, + use_checkpoint=conf.use_checkpoint, + ).make_model() + ] + ch = int(mult * conf.model_channels) + if resolution in conf.attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=conf.use_checkpoint, + num_heads=conf.num_heads, + num_head_channels=conf.num_head_channels, + use_new_attention_order=conf. + use_new_attention_order, + )) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(conf.channel_mult) - 1: + resolution //= 2 + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlockConfig( + ch, + time_embed_dim, + conf.dropout, + out_channels=out_ch, + dims=conf.dims, + use_condition=conf.use_time_condition, + use_checkpoint=conf.use_checkpoint, + down=True, + ).make_model() if ( + conf.resblock_updown + ) else Downsample(ch, + conf.conv_resample, + dims=conf.dims, + out_channels=out_ch))) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlockConfig( + ch, + time_embed_dim, + conf.dropout, + dims=conf.dims, + use_condition=conf.use_time_condition, + use_checkpoint=conf.use_checkpoint, + ).make_model(), + AttentionBlock( + ch, + use_checkpoint=conf.use_checkpoint, + num_heads=conf.num_heads, + num_head_channels=conf.num_head_channels, + use_new_attention_order=conf.use_new_attention_order, + ), + ResBlockConfig( + ch, + time_embed_dim, + conf.dropout, + dims=conf.dims, + use_condition=conf.use_time_condition, + use_checkpoint=conf.use_checkpoint, + ).make_model(), + ) + self._feature_size += ch + if conf.pool == "adaptivenonzero": + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + nn.AdaptiveAvgPool1d(1), + conv_nd(conf.dims, ch, conf.out_channels, 1), + nn.Flatten(), + ) + else: + raise NotImplementedError(f"Unexpected {conf.pool} pooling") + + def forward(self, x, t=None, return_2d_feature=False): + """ + Apply the model to an input batch. + + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :return: an [N x K] Tensor of outputs. 
+ """ + if self.conf.use_time_condition: + emb = self.time_embed(timestep_embedding(t, self.model_channels)) + else: + emb = None + + results = [] + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb=emb) + if self.conf.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = self.middle_block(h, emb=emb) + if self.conf.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = th.cat(results, axis=-1) + else: + h = h.type(x.dtype) + + h_2d = h + h = self.out(h) + + if return_2d_feature: + return h, h_2d + else: + return h + + def forward_flatten(self, x): + """ + transform the last 2d feature into a flatten vector + """ + h = self.out(x) + return h + + +class SuperResModel(BeatGANsUNetModel): + """ + A UNetModel that performs super-resolution. + + Expects an extra kwarg `low_res` to condition on a low-resolution image. + """ + def __init__(self, image_size, in_channels, *args, **kwargs): + super().__init__(image_size, in_channels * 2, *args, **kwargs) + + def forward(self, x, timesteps, low_res=None, **kwargs): + _, _, new_height, new_width = x.shape + upsampled = F.interpolate(low_res, (new_height, new_width), + mode="bilinear") + x = th.cat([x, upsampled], dim=1) + return super().forward(x, timesteps, **kwargs) \ No newline at end of file diff --git a/model/unet_autoenc.py b/model/unet_autoenc.py new file mode 100644 index 0000000..60b2761 --- /dev/null +++ b/model/unet_autoenc.py @@ -0,0 +1,310 @@ +import torch +from torch import Tensor, nn +from torch.nn.functional import silu +from .latentnet import * +from .unet import * +from choices import * + + +@dataclass +class BeatGANsAutoencConfig(BeatGANsUNetConfig): + enc_out_channels: int = 512 + enc_attn_resolutions: Tuple[int] = None + enc_pool: str = 'depthconv' + enc_num_res_block: int = 2 + enc_channel_mult: Tuple[int] = None + enc_grad_checkpoint: bool = False + latent_net_conf: MLPSkipNetConfig = None + + def make_model(self): + return BeatGANsAutoencModel(self) + + +class BeatGANsAutoencModel(BeatGANsUNetModel): + def __init__(self, conf: BeatGANsAutoencConfig): + super().__init__(conf) + self.conf = conf + + self.time_embed = TimeStyleSeperateEmbed( + time_channels=conf.model_channels, + time_out_channels=conf.embed_channels, + ) + + self.encoder = BeatGANsEncoderConfig( + image_size=conf.image_size, + in_channels=conf.in_channels, + model_channels=conf.model_channels, + out_hid_channels=conf.enc_out_channels, + out_channels=conf.enc_out_channels, + num_res_blocks=conf.enc_num_res_block, + attention_resolutions=(conf.enc_attn_resolutions + or conf.attention_resolutions), + dropout=conf.dropout, + channel_mult=conf.enc_channel_mult or conf.channel_mult, + use_time_condition=False, + conv_resample=conf.conv_resample, + dims=conf.dims, + use_checkpoint=conf.use_checkpoint or conf.enc_grad_checkpoint, + num_heads=conf.num_heads, + num_head_channels=conf.num_head_channels, + resblock_updown=conf.resblock_updown, + use_new_attention_order=conf.use_new_attention_order, + pool=conf.enc_pool, + ).make_model() + + self.user_classifier = UserClassifier(conf.enc_out_channels//2, conf.num_users) + self.non_user_classifier = UserClassifierGradientReverse(conf.enc_out_channels//2, conf.num_users) + + if conf.latent_net_conf is not None: + self.latent_net = conf.latent_net_conf.make_model() + + def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor: + """ + Reparameterization trick to sample from N(mu, var) from + N(0,1). 
+ :param mu: (Tensor) Mean of the latent Gaussian [B x D] + :param logvar: (Tensor) Standard deviation of the latent Gaussian [B x D] + :return: (Tensor) [B x D] + """ + assert self.conf.is_stochastic + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + return eps * std + mu + + def sample_z(self, n: int, device): + assert self.conf.is_stochastic + return torch.randn(n, self.conf.enc_out_channels, device=device) + + def noise_to_cond(self, noise: Tensor): + raise NotImplementedError() + assert self.conf.noise_net_conf is not None + return self.noise_net.forward(noise) + + def encode(self, x): + cond = self.encoder.forward(x) + return {'cond': cond} + + @property + def stylespace_sizes(self): + modules = list(self.input_blocks.modules()) + list( + self.middle_block.modules()) + list(self.output_blocks.modules()) + sizes = [] + for module in modules: + if isinstance(module, ResBlock): + linear = module.cond_emb_layers[-1] + sizes.append(linear.weight.shape[0]) + return sizes + + def encode_stylespace(self, x, return_vector: bool = True): + """ + encode to style space + """ + modules = list(self.input_blocks.modules()) + list( + self.middle_block.modules()) + list(self.output_blocks.modules()) + cond = self.encoder.forward(x) + S = [] + for module in modules: + if isinstance(module, ResBlock): + s = module.cond_emb_layers.forward(cond) + S.append(s) + + if return_vector: + return torch.cat(S, dim=1) + else: + return S + + def forward(self, + x, + t, + y=None, + x_start=None, + cond=None, + style=None, + noise=None, + t_cond=None, + **kwargs): + """ + Apply the model to an input batch. + + Args: + x_start: the original image to encode + cond: output of the encoder + noise: random noise (to predict the cond) + """ + if t_cond is None: + t_cond = t + + if noise is not None: + cond = self.noise_to_cond(noise) + + if cond is None: + if x is not None: + assert len(x) == len(x_start), f'{len(x)} != {len(x_start)}' + tmp = self.encode(x_start) + cond = tmp['cond'] + + if t is not None: + _t_emb = timestep_embedding(t, self.conf.model_channels) + _t_cond_emb = timestep_embedding(t_cond, self.conf.model_channels) + else: + _t_emb = None + _t_cond_emb = None + + if self.conf.resnet_two_cond: + res = self.time_embed.forward( + time_emb=_t_emb, + cond=cond, + time_cond_emb=_t_cond_emb + ) + else: + raise NotImplementedError() + + if self.conf.resnet_two_cond: + emb = res.time_emb + cond_emb = res.emb + else: + emb = res.emb + cond_emb = None + + style = style or res.style + + assert (y is not None) == ( + self.conf.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + + + if self.conf.num_classes is not None: + raise NotImplementedError() + + enc_time_emb = emb + mid_time_emb = emb + dec_time_emb = emb + enc_cond_emb = cond_emb + mid_cond_emb = cond_emb + dec_cond_emb = cond_emb + + if self.conf.num_users is not None: + user_pred = self.user_classifier(cond_emb[:, :self.conf.enc_out_channels // 2]) + non_user_pred = self.non_user_classifier(cond_emb[:, self.conf.enc_out_channels // 2:]) + + hs = [[] for _ in range(len(self.conf.channel_mult))] + + if x is not None: + h = x.type(self.dtype) + + # input blocks + k = 0 + for i in range(len(self.input_num_blocks)): + for j in range(self.input_num_blocks[i]): + h = self.input_blocks[k](h, + emb=enc_time_emb, + cond=enc_cond_emb) + '''print(i, j, h.shape) + if h.shape[-1]%2==1: + pdb.set_trace()''' + hs[i].append(h) + k += 1 + assert k == len(self.input_blocks) + + # middle blocks + h = self.middle_block(h, 
emb=mid_time_emb, cond=mid_cond_emb) + else: + h = None + hs = [[] for _ in range(len(self.conf.channel_mult))] + + # output blocks + k = 0 + for i in range(len(self.output_num_blocks)): + for j in range(self.output_num_blocks[i]): + try: + lateral = hs[-i - 1].pop() + except IndexError: + lateral = None + + h = self.output_blocks[k](h, + emb=dec_time_emb, + cond=dec_cond_emb, + lateral=lateral) + k += 1 + + pred = self.out(h) + + return AutoencReturn(pred=pred, cond=cond, user_pred=user_pred, non_user_pred=non_user_pred) + +class UserClassifier(nn.Module): + def __init__(self, in_channels, num_classes): + super().__init__() + self.fc = nn.Sequential( + nn.Linear(in_channels, 256), + nn.ReLU(), + nn.Linear(256, num_classes), + nn.Softmax(dim=1) + ) + + def forward(self, x): + return self.fc(x) + +class GradReverse(torch.autograd.Function): + """ + Implement the gradient reversal layer for the convenience of domain adaptation neural network. + The forward part is the identity function while the backward part is the negative function. + """ + @staticmethod + def forward(ctx, x): + return x.view_as(x) + + @staticmethod + def backward(ctx, grad_output): + return grad_output.neg() + +class GradientReversalLayer(nn.Module): + def __init__(self): + super(GradientReversalLayer, self).__init__() + + def forward(self, inputs): + return GradReverse.apply(inputs) + +class UserClassifierGradientReverse(nn.Module): + def __init__(self, in_channels, num_classes): + super().__init__() + self.grl = GradientReversalLayer() + self.fc = UserClassifier(in_channels, num_classes) + + def forward(self, x): + x = self.grl(x) + return self.fc(x) + +class AutoencReturn(NamedTuple): + pred: Tensor + cond: Tensor = None + user_pred: Tensor = None + non_user_pred: Tensor = None + + +class EmbedReturn(NamedTuple): + emb: Tensor = None + time_emb: Tensor = None + style: Tensor = None + +class TimeStyleSeperateEmbed(nn.Module): + def __init__(self, time_channels, time_out_channels): + super().__init__() + self.time_embed = nn.Sequential( + linear(time_channels, time_out_channels), + nn.SiLU(), + linear(time_out_channels, time_out_channels), + ) + self.cond_combine = nn.Sequential( + nn.Linear(time_out_channels * 2, time_out_channels), + nn.SiLU() + ) + self.style = nn.Identity() + + def forward(self, time_emb=None, cond=None, **kwargs): + if time_emb is None: + time_emb = None + else: + time_emb = self.time_embed(time_emb) + + style = self.style(cond) + return EmbedReturn(emb=style, time_emb=time_emb, style=style) \ No newline at end of file diff --git a/templates.py b/templates.py new file mode 100644 index 0000000..e4153bf --- /dev/null +++ b/templates.py @@ -0,0 +1,55 @@ +from experiment import * + +def autoenc_base(): + conf = TrainConfig() + conf.batch_size = 32 + conf.beatgans_gen_type = GenerativeType.ddim + conf.beta_scheduler = 'linear' + conf.data_name = 'ffhq' + conf.diffusion_type = 'beatgans' + conf.eval_ema_every_samples = 200_000 + conf.eval_every_samples = 200_000 + conf.fp16 = True + conf.lr = 1e-4 + conf.model_name = ModelName.beatgans_autoenc + conf.net_attn = (16, ) + conf.net_beatgans_attn_head = 1 + conf.net_beatgans_embed_channels = 128 + conf.net_beatgans_resnet_two_cond = True + conf.net_ch_mult = (1, 2, 4, 8) + conf.net_ch = 64 + conf.net_enc_channel_mult = (1, 2, 4, 8, 8) + conf.net_enc_pool = 'adaptivenonzero' + conf.sample_size = 32 + conf.T_eval = 20 + conf.T = 1000 + conf.make_model_conf() + return conf + +def mouse_autoenc(mode): + num_users = {'Clarkson': 75} + + conf = autoenc_base() + 
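# Note (added for clarity, not in the original commit): the overrides below swap
# autoenc_base()'s FFHQ image defaults for the Clarkson mouse setup. num_users sizes the
# user/non-user classifier heads built in BeatGANsAutoencModel; fields such as AEwin, slid
# and timeWinFreq are assumed to be consumed by the dataset/experiment code, not TrainConfig.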
conf.scale_up_gpus(1) + conf.img_size = 256 + conf.net_ch = 128 + conf.net_ch_mult = (1, 1, 2, 2, 4, 4) + conf.net_enc_channel_mult = (1, 1, 2, 2, 4, 4, 4) + conf.eval_every_samples = 10_000_000 + conf.eval_ema_every_samples = 10_000_000 + conf.total_samples = 200_000_000 + conf.batch_size = 512 + conf.name = 'mouse_autoenc' + + conf.pretrainDataset = 'Clarkson' + conf.data_name = 'Clarkson' + + conf.path = f'../mousedata/{conf.pretrainDataset}/' + conf.AEwin = 8 + conf.slid = 1 + conf.timeWinFreq = 20 + conf.num_users = num_users[conf.pretrainDataset] + + conf.mode = mode + conf.make_model_conf() + return conf \ No newline at end of file
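
Editor's addendum (not part of the commit): a minimal sketch of how the gradient-reversal head defined in model/unet_autoenc.py behaves. The forward pass is the identity and the backward pass negates the gradient, so the adversarial head itself still learns to predict the user, while the reversed gradients push the upstream encoder to strip user-identifying information from its half of the semantic code (the DANN-style trick the GradReverse docstring refers to). The snippet assumes it is run from the repository root so that model/unet_autoenc.py is importable; the sizes (batch of 4, 256-d features, 75 users) are illustrative and only match the defaults in BeatGANsAutoencConfig and the Clarkson entry of mouse_autoenc.

    import torch
    from model.unet_autoenc import GradReverse, UserClassifierGradientReverse

    # 1) GradReverse: identity forward, negated gradient backward.
    x = torch.ones(3, requires_grad=True)
    GradReverse.apply(x).sum().backward()
    print(x.grad)  # tensor([-1., -1., -1.]) -- the gradient is flipped

    # 2) The adversarial head: softmax over users, but gradients reaching the
    #    input features are reversed, which is what makes it adversarial.
    feat = torch.randn(4, 256, requires_grad=True)    # e.g. one half of the encoder output
    head = UserClassifierGradientReverse(in_channels=256, num_classes=75)
    probs = head(feat)                                # (4, 75) softmax probabilities
    targets = torch.randint(0, 75, (4,))
    loss = torch.nn.functional.nll_loss(torch.log(probs + 1e-12), targets)
    loss.backward()
    # feat.grad now points opposite to the direction that would improve user prediction,
    # so an encoder trained through this head is driven to remove user identity.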