Guanhua Zhang 2024-10-08 14:18:47 +02:00
commit b102c2a534
20 changed files with 4305 additions and 0 deletions

127
choices.py Executable file

@ -0,0 +1,127 @@
from enum import Enum
from torch import nn
class TrainMode(Enum):
manipulate = 'manipulate'
diffusion = 'diffusion'
latent_diffusion = 'latentdiffusion'
def is_manipulate(self):
return self in [
TrainMode.manipulate,
]
def is_diffusion(self):
return self in [
TrainMode.diffusion,
TrainMode.latent_diffusion,
]
def is_autoenc(self):
return self in [
TrainMode.diffusion,
]
def is_latent_diffusion(self):
return self in [
TrainMode.latent_diffusion,
]
def use_latent_net(self):
return self.is_latent_diffusion()
def require_dataset_infer(self):
"""
Whether training in this mode requires the inferred latent variables to be available.
"""
return self in [
TrainMode.latent_diffusion,
TrainMode.manipulate,
]
class ModelType(Enum):
"""
Kinds of backbone models.
"""
ddpm = 'ddpm'
autoencoder = 'autoencoder'
def has_autoenc(self):
return self in [
ModelType.autoencoder,
]
def can_sample(self):
return self in [ModelType.ddpm]
class ModelName(Enum):
"""
List of all supported model classes
"""
beatgans_ddpm = 'beatgans_ddpm'
beatgans_autoenc = 'beatgans_autoenc'
class ModelMeanType(Enum):
"""
Which type of output the model predicts.
"""
eps = 'eps'
class ModelVarType(Enum):
"""
What is used as the model's output variance.
The LEARNED_RANGE option has been added to allow the model to predict
values between FIXED_SMALL and FIXED_LARGE, making its job easier.
"""
fixed_small = 'fixed_small'
fixed_large = 'fixed_large'
class LossType(Enum):
mse = 'mse'
l1 = 'l1'
class GenerativeType(Enum):
"""
How a sample is generated.
"""
ddpm = 'ddpm'
ddim = 'ddim'
class OptimizerType(Enum):
adam = 'adam'
adamw = 'adamw'
class Activation(Enum):
none = 'none'
relu = 'relu'
lrelu = 'lrelu'
silu = 'silu'
tanh = 'tanh'
def get_act(self):
if self == Activation.none:
return nn.Identity()
elif self == Activation.relu:
return nn.ReLU()
elif self == Activation.lrelu:
return nn.LeakyReLU(negative_slope=0.2)
elif self == Activation.silu:
return nn.SiLU()
elif self == Activation.tanh:
return nn.Tanh()
else:
raise NotImplementedError()

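A minimal usage sketch for the enums above (assuming choices.py is importable as-is):

from torch import nn
from choices import TrainMode, Activation

mode = TrainMode.latent_diffusion
assert mode.is_diffusion() and mode.use_latent_net()
assert mode.require_dataset_infer()          # latent training expects pre-inferred latents

act = Activation.lrelu.get_act()             # -> nn.LeakyReLU(negative_slope=0.2)
assert isinstance(act, nn.LeakyReLU)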
303
config.py Normal file

@ -0,0 +1,303 @@
from model.unet import ScaleAt
from model.latentnet import *
from diffusion.resample import UniformSampler
from diffusion.diffusion import space_timesteps
from typing import Tuple
from config_base import BaseConfig
from dataset import *
from diffusion import *
from diffusion.base import GenerativeType, LossType, ModelMeanType, ModelVarType, get_named_beta_schedule
from model import *
from choices import *
@dataclass
class PretrainConfig(BaseConfig):
name: str
path: str
@dataclass
class TrainConfig(BaseConfig):
seed: int = 0
train_mode: TrainMode = TrainMode.diffusion
train_cond0_prob: float = 0
train_pred_xstart_detach: bool = True
train_interpolate_prob: float = 0
train_interpolate_img: bool = False
accum_batches: int = 1
autoenc_mid_attn: bool = True
batch_size: int = 512
batch_size_eval: int = 4
beatgans_gen_type: GenerativeType = GenerativeType.ddim
beatgans_loss_type: LossType = LossType.mse
beatgans_model_mean_type: ModelMeanType = ModelMeanType.eps
beatgans_model_var_type: ModelVarType = ModelVarType.fixed_large
beatgans_rescale_timesteps: bool = False
latent_infer_path: str = None
latent_znormalize: bool = False
latent_gen_type: GenerativeType = GenerativeType.ddim
latent_loss_type: LossType = LossType.mse
latent_model_mean_type: ModelMeanType = ModelMeanType.eps
latent_model_var_type: ModelVarType = ModelVarType.fixed_large
latent_rescale_timesteps: bool = False
latent_T_eval: int = 1_000
latent_clip_sample: bool = False
latent_beta_scheduler: str = 'linear'
beta_scheduler: str = 'linear'
data_name: str = ''
data_val_name: str = None
diffusion_type: str = None
dropout: float = 0.1
ema_decay: float = 0.9999
eval_num_images: int = 5_000
eval_every_samples: int = 200_000
eval_ema_every_samples: int = 200_000
fid_use_torch: bool = True
fp16: bool = False
grad_clip: float = 1
img_size: int = 64
lr: float = 0.0001
optimizer: OptimizerType = OptimizerType.adam
weight_decay: float = 0
model_conf: ModelConfig = None
model_name: ModelName = None
model_type: ModelType = None
net_attn: Tuple[int] = None
net_beatgans_attn_head: int = 1
net_beatgans_embed_channels: int = 128
net_resblock_updown: bool = True
net_enc_use_time: bool = False
net_enc_pool: str = 'adaptivenonzero'
net_beatgans_gradient_checkpoint: bool = False
net_beatgans_resnet_two_cond: bool = False
net_beatgans_resnet_use_zero_module: bool = True
net_beatgans_resnet_scale_at: ScaleAt = ScaleAt.after_norm
net_beatgans_resnet_cond_channels: int = None
net_ch_mult: Tuple[int] = None
net_ch: int = 64
net_enc_attn: Tuple[int] = None
net_enc_k: int = None
net_enc_num_res_blocks: int = 2
net_enc_channel_mult: Tuple[int] = None
net_enc_grad_checkpoint: bool = False
net_autoenc_stochastic: bool = False
net_latent_activation: Activation = Activation.silu
net_latent_channel_mult: Tuple[int] = (1, 2, 4)
net_latent_condition_bias: float = 0
net_latent_dropout: float = 0
net_latent_layers: int = None
net_latent_net_last_act: Activation = Activation.none
net_latent_net_type: LatentNetType = LatentNetType.none
net_latent_num_hid_channels: int = 1024
net_latent_num_time_layers: int = 2
net_latent_skip_layers: Tuple[int] = None
net_latent_time_emb_channels: int = 64
net_latent_use_norm: bool = False
net_latent_time_last_act: bool = False
net_num_res_blocks: int = 2
net_num_input_res_blocks: int = None
net_enc_num_cls: int = None
num_workers: int = 4
parallel: bool = False
postfix: str = ''
sample_size: int = 64
sample_every_samples: int = 20_000
save_every_samples: int = 100_000
style_ch: int = 128
T_eval: int = 1_000
T_sampler: str = 'uniform'
T: int = 1_000
total_samples: int = 10_000_000
warmup: int = 0
pretrain: PretrainConfig = None
continue_from: PretrainConfig = None
eval_programs: Tuple[str] = None
eval_path: str = None
base_dir: str = 'checkpoints'
name: str = ''
logdir: str = f'{base_dir}{name}'
num_users: int = 0
def __post_init__(self):
self.batch_size_eval = self.batch_size_eval or self.batch_size
self.data_val_name = self.data_val_name or self.data_name
def scale_up_gpus(self, num_gpus, num_nodes=1):
self.eval_ema_every_samples *= num_gpus * num_nodes
self.eval_every_samples *= num_gpus * num_nodes
self.sample_every_samples *= num_gpus * num_nodes
self.batch_size *= num_gpus * num_nodes
self.batch_size_eval *= num_gpus * num_nodes
return self
@property
def batch_size_effective(self):
return self.batch_size * self.accum_batches
def _make_diffusion_conf(self, T=None):
if self.diffusion_type == 'beatgans':
if self.beatgans_gen_type == GenerativeType.ddpm:
section_counts = [T]
elif self.beatgans_gen_type == GenerativeType.ddim:
section_counts = f'ddim{T}'
else:
raise NotImplementedError()
return SpacedDiffusionBeatGansConfig(
gen_type=self.beatgans_gen_type,
model_type=self.model_type,
betas=get_named_beta_schedule(self.beta_scheduler, self.T),
model_mean_type=self.beatgans_model_mean_type,
model_var_type=self.beatgans_model_var_type,
loss_type=self.beatgans_loss_type,
rescale_timesteps=self.beatgans_rescale_timesteps,
use_timesteps=space_timesteps(num_timesteps=self.T,
section_counts=section_counts),
fp16=self.fp16,
)
else:
raise NotImplementedError()
def _make_latent_diffusion_conf(self, T=None):
if self.latent_gen_type == GenerativeType.ddpm:
section_counts = [T]
elif self.latent_gen_type == GenerativeType.ddim:
section_counts = f'ddim{T}'
else:
raise NotImplementedError()
return SpacedDiffusionBeatGansConfig(
train_pred_xstart_detach=self.train_pred_xstart_detach,
gen_type=self.latent_gen_type,
model_type=ModelType.ddpm,
betas=get_named_beta_schedule(self.latent_beta_scheduler, self.T),
model_mean_type=self.latent_model_mean_type,
model_var_type=self.latent_model_var_type,
loss_type=self.latent_loss_type,
rescale_timesteps=self.latent_rescale_timesteps,
use_timesteps=space_timesteps(num_timesteps=self.T,
section_counts=section_counts),
fp16=self.fp16,
)
@property
def model_out_channels(self):
return 2
@property
def model_input_channels(self):
return 2
def make_T_sampler(self):
if self.T_sampler == 'uniform':
return UniformSampler(self.T)
else:
raise NotImplementedError()
def make_diffusion_conf(self):
return self._make_diffusion_conf(self.T)
def make_eval_diffusion_conf(self):
return self._make_diffusion_conf(T=self.T_eval)
def make_latent_diffusion_conf(self):
return self._make_latent_diffusion_conf(T=self.T)
def make_latent_eval_diffusion_conf(self):
return self._make_latent_diffusion_conf(T=self.latent_T_eval)
def make_dataset(self, taskdata, tasklabels):
return SimpleSet(taskdata, tasklabels, intflag=True)
def make_model_conf(self):
if self.model_name == ModelName.beatgans_ddpm:
self.model_type = ModelType.ddpm
self.model_conf = BeatGANsUNetConfig(
attention_resolutions=self.net_attn,
channel_mult=self.net_ch_mult,
conv_resample=True,
dims=2,
dropout=self.dropout,
embed_channels=self.net_beatgans_embed_channels,
image_size=self.img_size,
in_channels=self.model_input_channels,
model_channels=self.net_ch,
num_classes=None,
num_head_channels=-1,
num_heads_upsample=-1,
num_heads=self.net_beatgans_attn_head,
num_res_blocks=self.net_num_res_blocks,
num_input_res_blocks=self.net_num_input_res_blocks,
out_channels=self.model_out_channels,
resblock_updown=self.net_resblock_updown,
use_checkpoint=self.net_beatgans_gradient_checkpoint,
use_new_attention_order=False,
resnet_two_cond=self.net_beatgans_resnet_two_cond,
resnet_use_zero_module=self.net_beatgans_resnet_use_zero_module,
)
elif self.model_name in [
ModelName.beatgans_autoenc,
]:
cls = BeatGANsAutoencConfig
if self.model_name == ModelName.beatgans_autoenc:
self.model_type = ModelType.autoencoder
else:
raise NotImplementedError()
if self.net_latent_net_type == LatentNetType.none:
latent_net_conf = None
elif self.net_latent_net_type == LatentNetType.skip:
latent_net_conf = MLPSkipNetConfig(
num_channels=self.style_ch,
skip_layers=self.net_latent_skip_layers,
num_hid_channels=self.net_latent_num_hid_channels,
num_layers=self.net_latent_layers,
num_time_emb_channels=self.net_latent_time_emb_channels,
activation=self.net_latent_activation,
use_norm=self.net_latent_use_norm,
condition_bias=self.net_latent_condition_bias,
dropout=self.net_latent_dropout,
last_act=self.net_latent_net_last_act,
num_time_layers=self.net_latent_num_time_layers,
time_last_act=self.net_latent_time_last_act,
)
else:
raise NotImplementedError()
self.model_conf = cls(
attention_resolutions=self.net_attn,
channel_mult=self.net_ch_mult,
conv_resample=True,
dims=1,
dropout=self.dropout,
embed_channels=self.net_beatgans_embed_channels,
enc_out_channels=self.style_ch,
enc_pool=self.net_enc_pool,
enc_num_res_block=self.net_enc_num_res_blocks,
enc_channel_mult=self.net_enc_channel_mult,
enc_grad_checkpoint=self.net_enc_grad_checkpoint,
enc_attn_resolutions=self.net_enc_attn,
image_size=self.img_size,
in_channels=self.model_input_channels,
model_channels=self.net_ch,
num_classes=None,
num_head_channels=-1,
num_heads_upsample=-1,
num_heads=self.net_beatgans_attn_head,
num_res_blocks=self.net_num_res_blocks,
num_input_res_blocks=self.net_num_input_res_blocks,
out_channels=self.model_out_channels,
resblock_updown=self.net_resblock_updown,
use_checkpoint=self.net_beatgans_gradient_checkpoint,
use_new_attention_order=False,
resnet_two_cond=self.net_beatgans_resnet_two_cond,
resnet_use_zero_module=self.net_beatgans_resnet_use_zero_module,
latent_net_conf=latent_net_conf,
resnet_cond_channels=self.net_beatgans_resnet_cond_channels,
num_users=self.num_users,
)
else:
raise NotImplementedError(self.model_name)
return self.model_conf

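A hedged sketch of the batch-size bookkeeping in TrainConfig, assuming the repo's model and diffusion modules import cleanly (values follow the defaults above):

from config import TrainConfig

conf = TrainConfig(batch_size=512, accum_batches=2)
conf.scale_up_gpus(num_gpus=4)        # multiplies batch sizes and eval/sample intervals by 4
print(conf.batch_size)                # 2048 per optimizer step, before accumulation
print(conf.batch_size_effective)      # 4096 samples consumed per logical step
print(conf.T, conf.T_eval)            # 1000 training steps, 1000 eval steps by default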
71
config_base.py Executable file

@ -0,0 +1,71 @@
import json
import os
from copy import deepcopy
from dataclasses import dataclass
@dataclass
class BaseConfig:
def clone(self):
return deepcopy(self)
def inherit(self, another):
"""inherit common keys from a given config"""
common_keys = set(self.__dict__.keys()) & set(another.__dict__.keys())
for k in common_keys:
setattr(self, k, getattr(another, k))
def propagate(self):
"""push down the configuration to all members"""
for k, v in self.__dict__.items():
if isinstance(v, BaseConfig):
v.inherit(self)
v.propagate()
def save(self, save_path):
"""save config to json file"""
dirname = os.path.dirname(save_path)
if not os.path.exists(dirname):
os.makedirs(dirname)
conf = self.as_dict_jsonable()
with open(save_path, 'w') as f:
json.dump(conf, f)
def load(self, load_path):
"""load json config"""
with open(load_path) as f:
conf = json.load(f)
self.from_dict(conf)
def from_dict(self, dict, strict=False):
for k, v in dict.items():
if not hasattr(self, k):
if strict:
raise ValueError(f"loading extra '{k}'")
else:
print(f"loading extra '{k}'")
continue
if isinstance(self.__dict__[k], BaseConfig):
self.__dict__[k].from_dict(v)
else:
self.__dict__[k] = v
def as_dict_jsonable(self):
conf = {}
for k, v in self.__dict__.items():
if isinstance(v, BaseConfig):
conf[k] = v.as_dict_jsonable()
else:
if jsonable(v):
conf[k] = v
else:
pass
return conf
def jsonable(x):
try:
json.dumps(x)
return True
except TypeError:
return False

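A round-trip sketch for BaseConfig; SmallConf is a hypothetical config used only for illustration:

from dataclasses import dataclass
from config_base import BaseConfig

@dataclass
class SmallConf(BaseConfig):                  # hypothetical, not part of the repo
    lr: float = 1e-4
    name: str = 'demo'

conf = SmallConf(lr=3e-4)
conf.save('checkpoints/demo/config.json')     # creates the directory and writes JSON

loaded = SmallConf()
loaded.load('checkpoints/demo/config.json')   # fills fields via from_dict
assert loaded.lr == 3e-4 and loaded.name == 'demo'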
21
dataset.py Executable file

@ -0,0 +1,21 @@
from torch.utils.data import Dataset
import torch
import pandas as pd
def loadDataset(conf):
taskdata, tasklabels = eval('load%s' % conf.pretrainDataset)(conf) # plug in the function that loads your own dataset (eval() cannot execute an assignment statement)
tasklabels = pd.DataFrame(tasklabels, columns=['user'])
print('taskdata.shape:', taskdata.shape) # (N, 2, window_length*sample_freq)
return taskdata, tasklabels
class SimpleSet(Dataset):
def __init__(self, data, labels, intflag=True):
self.data = torch.tensor(data, dtype=torch.float)
if intflag:
self.label = torch.tensor(labels, dtype=torch.long)
else:
self.label = torch.tensor(labels, dtype=torch.float)
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return self.data[index], self.label[index]

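A sketch of wrapping pre-windowed data in SimpleSet, assuming the (N, 2, window_length*sample_freq) layout printed in loadDataset; the array contents here are placeholders:

import numpy as np
from torch.utils.data import DataLoader
from dataset import SimpleSet

taskdata = np.random.randn(8, 2, 100)         # 8 windows of (dx, dy) mouse signals
tasklabels = np.arange(8) % 2                 # integer user ids
ds = SimpleSet(taskdata, tasklabels, intflag=True)
x, y = ds[0]                                  # float32 tensor (2, 100), int64 label
loader = DataLoader(ds, batch_size=4, shuffle=True)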
6
diffusion/__init__.py Normal file

@ -0,0 +1,6 @@
from typing import Union
from .diffusion import SpacedDiffusionBeatGans, SpacedDiffusionBeatGansConfig
Sampler = Union[SpacedDiffusionBeatGans]
SamplerConfig = Union[SpacedDiffusionBeatGansConfig]

1086
diffusion/base.py Executable file

File diff suppressed because it is too large

154
diffusion/diffusion.py Executable file

@ -0,0 +1,154 @@
from .base import *
from dataclasses import dataclass
def space_timesteps(num_timesteps, section_counts):
"""
Create a list of timesteps to use from an original diffusion process,
given the number of timesteps we want to take from equally-sized portions
of the original process.
For example, if there are 300 timesteps and the section counts are [10,15,20]
then the first 100 timesteps are strided to be 10 timesteps, the second 100
are strided to be 15 timesteps, and the final 100 are strided to be 20.
If the stride is a string starting with "ddim", then the fixed striding
from the DDIM paper is used, and only one section is allowed.
:param num_timesteps: the number of diffusion steps in the original
process to divide up.
:param section_counts: either a list of numbers, or a string containing
comma-separated numbers, indicating the step count
per section. As a special case, use "ddimN" where N
is a number of steps to use the striding from the
DDIM paper.
:return: a set of diffusion steps from the original process to use.
"""
if isinstance(section_counts, str):
if section_counts.startswith("ddim"):
desired_count = int(section_counts[len("ddim"):])
for i in range(1, num_timesteps):
if len(range(0, num_timesteps, i)) == desired_count:
return set(range(0, num_timesteps, i))
raise ValueError(
f"cannot create exactly {num_timesteps} steps with an integer stride"
)
section_counts = [int(x) for x in section_counts.split(",")]
size_per = num_timesteps // len(section_counts)
extra = num_timesteps % len(section_counts)
start_idx = 0
all_steps = []
for i, section_count in enumerate(section_counts):
size = size_per + (1 if i < extra else 0)
if size < section_count:
raise ValueError(
f"cannot divide section of {size} steps into {section_count}")
if section_count <= 1:
frac_stride = 1
else:
frac_stride = (size - 1) / (section_count - 1)
cur_idx = 0.0
taken_steps = []
for _ in range(section_count):
taken_steps.append(start_idx + round(cur_idx))
cur_idx += frac_stride
all_steps += taken_steps
start_idx += size
return set(all_steps)
@dataclass
class SpacedDiffusionBeatGansConfig(GaussianDiffusionBeatGansConfig):
use_timesteps: Tuple[int] = None
def make_sampler(self):
return SpacedDiffusionBeatGans(self)
class SpacedDiffusionBeatGans(GaussianDiffusionBeatGans):
"""
A diffusion process which can skip steps in a base diffusion process.
:param use_timesteps: a collection (sequence or set) of timesteps from the
original diffusion process to retain.
:param kwargs: the kwargs to create the base diffusion process.
"""
def __init__(self, conf: SpacedDiffusionBeatGansConfig):
self.conf = conf
self.use_timesteps = set(conf.use_timesteps)
self.timestep_map = []
self.original_num_steps = len(conf.betas)
base_diffusion = GaussianDiffusionBeatGans(conf)
last_alpha_cumprod = 1.0
new_betas = []
for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
if i in self.use_timesteps:
new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
last_alpha_cumprod = alpha_cumprod
self.timestep_map.append(i)
conf.betas = np.array(new_betas)
super().__init__(conf)
def p_mean_variance(self, model: Model, *args, **kwargs):
return super().p_mean_variance(self._wrap_model(model), *args,
**kwargs)
def training_losses(self, model: Model, *args, **kwargs):
return super().training_losses(self._wrap_model(model), *args,
**kwargs)
def condition_mean(self, cond_fn, *args, **kwargs):
return super().condition_mean(self._wrap_model(cond_fn), *args,
**kwargs)
def condition_score(self, cond_fn, *args, **kwargs):
return super().condition_score(self._wrap_model(cond_fn), *args,
**kwargs)
def _wrap_model(self, model: Model):
if isinstance(model, _WrappedModel):
return model
return _WrappedModel(model, self.timestep_map, self.rescale_timesteps,
self.original_num_steps)
def _scale_timesteps(self, t):
return t
class _WrappedModel:
"""
Wraps a model so that the supplied (subsampled) t's are mapped back to the original timestep scale.
"""
def __init__(self, model, timestep_map, rescale_timesteps,
original_num_steps):
self.model = model
self.timestep_map = timestep_map
self.rescale_timesteps = rescale_timesteps
self.original_num_steps = original_num_steps
def forward(self, x, t, t_cond=None, **kwargs):
"""
Args:
t: timesteps in the subsampled range (can be << T due to a smaller eval T); they are mapped back to the original timesteps
t_cond: the same as t but can be of different values
"""
map_tensor = th.tensor(self.timestep_map,
device=t.device,
dtype=t.dtype)
def do(t):
new_ts = map_tensor[t]
if self.rescale_timesteps:
new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
return new_ts
if t_cond is not None:
t_cond = do(t_cond)
return self.model(x=x, t=do(t), t_cond=t_cond, **kwargs)
def __getattr__(self, name):
if hasattr(self.model, name):
func = getattr(self.model, name)
return func
raise AttributeError(name)

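Worked examples of space_timesteps matching the docstring above:

from diffusion.diffusion import space_timesteps

# DDIM striding: keep every 20th of 1000 steps -> {0, 20, ..., 980}
steps = space_timesteps(num_timesteps=1000, section_counts='ddim50')
assert len(steps) == 50 and 980 in steps

# Three equal 100-step sections, strided to 10/15/20 steps each -> 45 steps total
steps = space_timesteps(num_timesteps=300, section_counts=[10, 15, 20])
assert len(steps) == 45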
62
diffusion/resample.py Normal file

@ -0,0 +1,62 @@
from abc import ABC, abstractmethod
import numpy as np
import torch as th
def create_named_schedule_sampler(name, diffusion):
"""
Create a ScheduleSampler from a library of pre-defined samplers.
:param name: the name of the sampler.
:param diffusion: the diffusion object to sample for.
"""
if name == "uniform":
return UniformSampler(diffusion)
else:
raise NotImplementedError(f"unknown schedule sampler: {name}")
class ScheduleSampler(ABC):
"""
A distribution over timesteps in the diffusion process, intended to reduce
variance of the objective.
By default, samplers perform unbiased importance sampling, in which the
objective's mean is unchanged.
However, subclasses may override sample() to change how the resampled
terms are reweighted, allowing for actual changes in the objective.
"""
@abstractmethod
def weights(self):
"""
Get a numpy array of weights, one per diffusion step.
The weights needn't be normalized, but must be positive.
"""
def sample(self, batch_size, device):
"""
Importance-sample timesteps for a batch.
:param batch_size: the number of timesteps to sample (i.e. the batch size).
:param device: the torch device to place the samples on.
:return: a tuple (timesteps, weights):
- timesteps: a tensor of timestep indices.
- weights: a tensor of weights to scale the resulting losses.
"""
w = self.weights()
p = w / np.sum(w)
indices_np = np.random.choice(len(p), size=(batch_size, ), p=p)
indices = th.from_numpy(indices_np).long().to(device)
weights_np = 1 / (len(p) * p[indices_np])
weights = th.from_numpy(weights_np).float().to(device)
return indices, weights
class UniformSampler(ScheduleSampler):
def __init__(self, num_timesteps):
self._weights = np.ones([num_timesteps])
def weights(self):
return self._weights

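Usage sketch for UniformSampler: every timestep is equally likely, so the importance weights come back as ones.

import torch as th
from diffusion.resample import UniformSampler

sampler = UniformSampler(num_timesteps=1000)
t, weight = sampler.sample(batch_size=16, device=th.device('cpu'))
assert t.shape == (16,) and t.dtype == th.int64
assert th.allclose(weight, th.ones(16))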
42
dist_utils.py Executable file

@ -0,0 +1,42 @@
from typing import List
from torch import distributed
def barrier():
if distributed.is_initialized():
distributed.barrier()
else:
pass
def broadcast(data, src):
if distributed.is_initialized():
distributed.broadcast(data, src)
else:
pass
def all_gather(data: List, src):
if distributed.is_initialized():
distributed.all_gather(data, src)
else:
data[0] = src
def get_rank():
if distributed.is_initialized():
return distributed.get_rank()
else:
return 0
def get_world_size():
if distributed.is_initialized():
return distributed.get_world_size()
else:
return 1
def chunk_size(size, rank, world_size):
extra = rank < size % world_size
return size // world_size + extra

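Worked example: these helpers degrade gracefully when torch.distributed is not initialized, and chunk_size spreads the remainder over the first ranks.

from dist_utils import chunk_size, get_rank, get_world_size

assert get_rank() == 0 and get_world_size() == 1   # single-process fallback
assert [chunk_size(10, rank, 3) for rank in range(3)] == [4, 3, 3]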
160
environment.yml Normal file

@ -0,0 +1,160 @@
name: dismouse
channels:
- defaults
- conda-forge
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- ca-certificates=2023.12.12=h06a4308_0
- ld_impl_linux-64=2.38=h1181459_1
- libffi=3.4.4=h6a678d5_0
- libgcc-ng=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libstdcxx-ng=11.2.0=h1234567_1
- ncurses=6.4=h6a678d5_0
- openssl=3.0.12=h7f8727e_0
- pip=23.3.1=py38h06a4308_0
- python=3.8.18=h955ad1f_0
- readline=8.2=h5eee18b_0
- setuptools=68.2.2=py38h06a4308_0
- sqlite=3.41.2=h5eee18b_0
- tk=8.6.12=h1ccaba5_0
- wheel=0.41.2=py38h06a4308_0
- xz=5.4.5=h5eee18b_0
- zlib=1.2.13=h5eee18b_0
- pip:
- absl-py==2.0.0
- aiohttp==3.9.1
- aiosignal==1.3.1
- appdirs==1.4.4
- async-timeout==4.0.3
- attrs==23.2.0
- bleach==6.1.0
- bokeh==3.1.1
- cachetools==5.3.2
- certifi==2023.11.17
- charset-normalizer==3.3.2
- click==8.1.7
- cloudpickle==3.0.0
- colorcet==3.1.0
- contourpy==1.1.1
- cycler==0.12.1
- cython==0.29.37
- dask==2023.5.0
- datashader==0.15.2
- datashape==0.5.2
- docker-pycreds==0.4.0
- filelock==3.13.1
- fonttools==4.47.2
- frozenlist==1.4.1
- fsspec==2023.12.2
- ftfy==6.1.3
- future==0.18.3
- gitdb==4.0.11
- gitpython==3.1.41
- google-auth==2.26.2
- google-auth-oauthlib==1.0.0
- grpcio==1.60.0
- hdbscan==0.8.33
- hmmlearn==0.3.2
- holoviews==1.17.1
- idna==3.6
- imageio==2.34.0
- importlib-metadata==7.0.1
- importlib-resources==6.1.1
- jinja2==3.1.3
- joblib==1.3.2
- kiwisolver==1.4.5
- kornia==0.7.1
- lazy-loader==0.3
- lightning-utilities==0.10.1
- linkify-it-py==2.0.3
- llvmlite==0.41.1
- lmdb==1.2.1
- locket==1.0.0
- lpips==0.1.4
- markdown==3.5.2
- markdown-it-py==3.0.0
- markupsafe==2.1.3
- matplotlib==3.7.4
- mdit-py-plugins==0.4.0
- mdurl==0.1.2
- mpmath==1.3.0
- multidict==6.0.4
- multipledispatch==1.0.0
- networkx==3.1
- numba==0.58.1
- numpy==1.24.4
- nvidia-cublas-cu12==12.1.3.1
- nvidia-cuda-cupti-cu12==12.1.105
- nvidia-cuda-nvrtc-cu12==12.1.105
- nvidia-cuda-runtime-cu12==12.1.105
- nvidia-cudnn-cu12==8.9.2.26
- nvidia-cufft-cu12==11.0.2.54
- nvidia-curand-cu12==10.3.2.106
- nvidia-cusolver-cu12==11.4.5.107
- nvidia-cusparse-cu12==12.1.0.106
- nvidia-nccl-cu12==2.19.3
- nvidia-nvjitlink-cu12==12.3.101
- nvidia-nvtx-cu12==12.1.105
- oauthlib==3.2.2
- packaging==23.2
- pandas==1.5.3
- panel==1.2.3
- param==2.0.2
- partd==1.4.1
- pillow==10.2.0
- protobuf==4.25.2
- psutil==5.9.8
- pyasn1==0.5.1
- pyasn1-modules==0.3.0
- pyct==0.5.0
- pydeprecate==0.3.1
- pynndescent==0.5.11
- pyparsing==3.1.1
- python-crfsuite==0.9.10
- python-dateutil==2.8.2
- pytorch-fid==0.2.0
- pytorch-lightning==1.4.5
- pytz==2023.3.post1
- pyviz-comms==3.0.1
- pywavelets==1.4.1
- pyyaml==6.0.1
- regex==2023.12.25
- requests==2.31.0
- requests-oauthlib==1.3.1
- rsa==4.9
- scikit-image==0.21.0
- scikit-learn==1.3.2
- scipy==1.10.1
- sentry-sdk==1.39.2
- setproctitle==1.3.3
- six==1.16.0
- sklearn-crfsuite==0.3.6
- smmap==5.0.1
- sympy==1.12
- tabulate==0.9.0
- tensorboard==2.14.0
- tensorboard-data-server==0.7.2
- threadpoolctl==3.2.0
- tifffile==2023.7.10
- toolz==0.12.1
- torch==1.8.1
- torchmetrics==0.5.0
- torchvision==0.9.1
- tornado==6.4
- tqdm==4.66.1
- triton==2.2.0
- typing-extensions==4.9.0
- tzdata==2023.4
- uc-micro-py==1.0.3
- umap-learn==0.5.5
- urllib3==2.1.0
- wandb==0.16.2
- wcwidth==0.2.13
- webencodings==0.5.1
- werkzeug==3.0.1
- xarray==2023.1.0
- xyzservices==2023.10.1
- yarl==1.9.4
- zipp==3.17.0

378
experiment.py Executable file

@ -0,0 +1,378 @@
import copy, wandb
import os
import random
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import *
from torch.cuda import amp
from torch.optim.optimizer import Optimizer
from torch.utils.data.dataset import TensorDataset
from config import *
from dataset import *
from dist_utils import *
def MakeDir(dirName):
if not os.path.exists(dirName):
os.makedirs(dirName)
class LitModel(pl.LightningModule):
def __init__(self, conf: TrainConfig, betas):
super().__init__()
self.save_hyperparameters({k:v for (k,v) in vars(conf).items() if not callable(v)})
self.save_hyperparameters(conf.as_dict_jsonable())
assert conf.train_mode != TrainMode.manipulate
if conf.seed is not None:
pl.seed_everything(conf.seed)
conf.betas = betas
self.conf = conf
self.model = conf.make_model_conf().make_model()
self.ema_model = copy.deepcopy(self.model)
self.ema_model.requires_grad_(False)
self.ema_model.eval()
model_size = 0
for param in self.model.parameters():
model_size += param.data.nelement()
print('Model params: %.2f M' % (model_size / 1024 / 1024))
self.sampler = conf.make_diffusion_conf().make_sampler()
self.eval_sampler = conf.make_eval_diffusion_conf().make_sampler()
self.T_sampler = conf.make_T_sampler()
if conf.train_mode.use_latent_net():
self.latent_sampler = conf.make_latent_diffusion_conf(
).make_sampler()
self.eval_latent_sampler = conf.make_latent_eval_diffusion_conf(
).make_sampler()
else:
self.latent_sampler = None
self.eval_latent_sampler = None
if conf.pretrain is not None:
print(f'loading pretrain ... {conf.pretrain.name}')
state = torch.load(conf.pretrain.path, map_location='cpu')
print('step:', state['global_step'])
self.load_state_dict(state['state_dict'], strict=False)
if conf.latent_infer_path is not None:
print('loading latent stats ...')
state = torch.load(conf.latent_infer_path)
self.conds = state['conds']
else:
self.conds = None
self.conds_mean = None
self.conds_std = None
def normalize(self, cond):
cond = (cond - self.conds_mean.to(self.device)) / self.conds_std.to(
self.device)
return cond
def denormalize(self, cond):
cond = (cond * self.conds_std.to(self.device)) + self.conds_mean.to(
self.device)
return cond
def render(self, noise, cond=None, T=None):
if T is None:
sampler = self.eval_sampler
else:
sampler = self.conf._make_diffusion_conf(T).make_sampler()
if cond is not None:
pred_img = render_condition(self.conf,
self.ema_model,
noise,
sampler=sampler,
cond=cond)
else:
pred_img = render_uncondition(self.conf,
self.ema_model,
noise,
sampler=sampler,
latent_sampler=None)
return pred_img
def encode(self, x):
assert self.conf.model_type.has_autoenc()
cond = self.ema_model.encoder.forward(x)
return cond
def encode_stochastic(self, x, cond, T=None):
if T is None:
sampler = self.eval_sampler
else:
sampler = self.conf._make_diffusion_conf(T).make_sampler()
out = sampler.ddim_reverse_sample_loop(self.ema_model,
x,
model_kwargs={'cond': cond})
return out['sample'], out['xstart_t']
def forward(self, noise=None, x_start=None, ema_model: bool = False):
with amp.autocast(False):
if ema_model:
model = self.ema_model
else:
model = self.model
gen = self.eval_sampler.sample(model=model,
noise=noise,
x_start=x_start)
return gen
def setup(self, stage=None) -> None:
"""
make datasets & seeding each worker separately
"""
if self.conf.seed is not None:
seed = self.conf.seed * get_world_size() + self.global_rank
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
taskdata, tasklabels = loadDataset(self.conf)
assert self.conf.num_users==tasklabels['user'].nunique()
tasklabels = pd.DataFrame(tasklabels['user'].astype('category').cat.codes.values)[0].values.astype(int)
assert self.conf.num_users==len(np.unique(tasklabels))
self.train_data = self.conf.make_dataset(taskdata, tasklabels)
self.val_data = self.train_data
def train_dataloader(self):
"""
return the dataloader, if diffusion mode => return image dataset
if latent mode => return the inferred latent dataset
"""
if self.conf.train_mode.require_dataset_infer():
if self.conds is None:
self.conds = self.infer_whole_dataset()
self.conds_mean.data = self.conds.float().mean(dim=0,
keepdim=True)
self.conds_std.data = self.conds.float().std(dim=0,
keepdim=True)
print('mean:', self.conds_mean.mean(), 'std:',
self.conds_std.mean())
conf = self.conf.clone()
conf.batch_size = self.batch_size
data = TensorDataset(self.conds)
return conf.make_loader(data, shuffle=True)
else:
return torch.utils.data.DataLoader(self.train_data, batch_size=self.conf.batch_size, shuffle=True)
@property
def batch_size(self):
"""
local batch size for each worker
"""
ws = get_world_size()
assert self.conf.batch_size % ws == 0
return self.conf.batch_size // ws
@property
def num_samples(self):
"""
(global) batch size * iterations
"""
return self.global_step * self.conf.batch_size_effective
def is_last_accum(self, batch_idx):
"""
Is this the last gradient-accumulation step?
Used with accum_batches > 1 to check whether the optimizer will perform "step" in this iteration.
"""
return (batch_idx + 1) % self.conf.accum_batches == 0
def training_step(self, batch, batch_idx):
"""
given an input, calculate the loss function
no optimization at this stage.
"""
with amp.autocast(False):
if self.conf.train_mode.require_dataset_infer():
cond = batch[0]
if self.conf.latent_znormalize:
cond = (cond - self.conds_mean.to(
self.device)) / self.conds_std.to(self.device)
else:
imgs = batch[0]
x_start = imgs
if self.conf.train_mode == TrainMode.diffusion:
t, weight = self.T_sampler.sample(len(x_start), x_start.device)
losses = self.sampler.training_losses(model=self.model,
x_start=x_start,
t=t,
user_label=batch[1],
lossbetas=self.conf.betas)
elif self.conf.train_mode.is_latent_diffusion():
t, weight = self.T_sampler.sample(len(cond), cond.device)
latent_losses = self.latent_sampler.training_losses(
model=self.model.latent_net, x_start=cond, t=t)
losses = {
'latent': latent_losses['loss'],
'loss': latent_losses['loss']
}
else:
raise NotImplementedError()
loss = losses['loss'].mean()
self.log("train_loss", loss)
return {'loss': loss}
def on_train_batch_end(self, outputs, batch, batch_idx: int,
dataloader_idx: int) -> None:
"""
after each training step
"""
if self.is_last_accum(batch_idx):
if (batch_idx==len(self.train_dataloader())-1) and ((self.current_epoch+1) % 10 == 0):
save_path = os.path.join(self.conf.logdir, 'epoch%d.ckpt' % self.current_epoch)
torch.save({
'state_dict': self.state_dict(),
'global_step': self.global_step,
'loss': outputs['loss'],
}, save_path)
if self.conf.train_mode == TrainMode.latent_diffusion:
ema(self.model.latent_net, self.ema_model.latent_net,
self.conf.ema_decay)
else:
ema(self.model, self.ema_model, self.conf.ema_decay)
def on_before_optimizer_step(self, optimizer: Optimizer,
optimizer_idx: int) -> None:
if self.conf.grad_clip > 0:
params = [
p for group in optimizer.param_groups for p in group['params']
]
torch.nn.utils.clip_grad_norm_(params,
max_norm=self.conf.grad_clip)
def configure_optimizers(self):
out = {}
if self.conf.optimizer == OptimizerType.adam:
optim = torch.optim.Adam(self.model.parameters(),
lr=self.conf.lr,
weight_decay=self.conf.weight_decay)
elif self.conf.optimizer == OptimizerType.adamw:
optim = torch.optim.AdamW(self.model.parameters(),
lr=self.conf.lr,
weight_decay=self.conf.weight_decay)
else:
raise NotImplementedError()
out['optimizer'] = optim
if self.conf.warmup > 0:
sched = torch.optim.lr_scheduler.LambdaLR(optim,
lr_lambda=WarmupLR(
self.conf.warmup))
out['lr_scheduler'] = {
'scheduler': sched,
'interval': 'step',
}
return out
def split_tensor(self, x):
"""
extract the tensor for a corresponding "worker" in the batch dimension
Args:
x: (n, c)
Returns: x: (n_local, c)
"""
n = len(x)
rank = self.global_rank
world_size = get_world_size()
per_rank = n // world_size
return x[rank * per_rank:(rank + 1) * per_rank]
def ema(source, target, decay):
source_dict = source.state_dict()
target_dict = target.state_dict()
for key in source_dict.keys():
target_dict[key].data.copy_(target_dict[key].data * decay +
source_dict[key].data * (1 - decay))
class WarmupLR:
def __init__(self, warmup) -> None:
self.warmup = warmup
def __call__(self, step):
return min(step, self.warmup) / self.warmup
def is_time(num_samples, every, step_size):
closest = (num_samples // every) * every
return num_samples - closest < step_size
def train(conf: TrainConfig, model: LitModel, gpus, nodes=1):
checkpoint = ModelCheckpoint(dirpath=conf.logdir,
save_last=True,
save_top_k=1,
every_n_train_steps=conf.save_every_samples //
conf.batch_size_effective)
checkpoint_path = f'{conf.logdir}last.ckpt'
print('ckpt path:', checkpoint_path)
if os.path.exists(checkpoint_path):
resume = checkpoint_path
print('resume!')
else:
if conf.continue_from is not None:
resume = conf.continue_from.path
else:
resume = None
plugins = []
if len(gpus) == 1 and nodes == 1:
accelerator = None
else:
accelerator = 'ddp'
from pytorch_lightning.plugins import DDPPlugin
plugins.append(DDPPlugin(find_unused_parameters=False))
wandb_logger = pl_loggers.WandbLogger(project='dismouse',
name='%s_%s'%(model.conf.pretrainDataset, conf.logdir.split('/')[-2]),
log_model=True,
save_dir=conf.logdir,
dir = conf.logdir,
config=vars(model.conf),
save_code=True,
settings=wandb.Settings(code_dir="."))
trainer = pl.Trainer(
max_steps=conf.total_samples // conf.batch_size_effective,
resume_from_checkpoint=resume,
gpus=gpus,
num_nodes=nodes,
accelerator=accelerator,
precision=16 if conf.fp16 else 32,
callbacks=[
checkpoint,
LearningRateMonitor(),
],
replace_sampler_ddp=True,
logger= wandb_logger,
accumulate_grad_batches=conf.accum_batches,
plugins=plugins,
)
trainer.fit(model)
wandb.finish()

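A small sketch of the EMA update and warmup schedule defined above, assuming experiment.py and its dependencies import cleanly:

import torch
from torch import nn
from experiment import ema, WarmupLR

src, tgt = nn.Linear(2, 2), nn.Linear(2, 2)
ema(src, tgt, decay=0.9999)      # tgt <- 0.9999 * tgt + 0.0001 * src, per parameter

warmup = WarmupLR(warmup=5000)   # used as a LambdaLR multiplier on the base lr
print(warmup(100))               # 0.02, linear ramp
print(warmup(10000))             # 1.0, clamped after warmup steps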
29
main.py Normal file

@ -0,0 +1,29 @@
import os
from templates import *
if __name__ == '__main__':
gpus = [0]
conf = mouse_autoenc('trainDiff')
betas = {
'recon':1,
'noise':1,
'user':0.01,
'nonuser':0.01,
'mi': 0.01
}
betastr = ''
for k,v in betas.items():
betastr += f'{k}{v}_'
betastr = betastr[:-1]
diffmodel = LitModel(conf, betas)
conf.logdir = f'{conf.logdir}/mouse_autoenc/{conf.pretrainDataset}/{betastr}/embDim{conf.net_beatgans_embed_channels}/win{conf.AEwin}/slid{conf.slid}/GRL/'
MakeDir(conf.logdir)
os.environ['WANDB_CACHE_DIR'] = conf.logdir
os.environ['WANDB_DATA_DIR'] = conf.logdir
os.environ['WANDB_IGNORE_GLOBS'] = '*.ckpt'
train(conf, diffmodel, gpus=gpus)

176
model/MI.py Normal file

@ -0,0 +1,176 @@
'''
Differentiable approximation to the mutual information (MI) metric.
Implementation in PyTorch
'''
# Imports #
# ----------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
import os
# Note: This code snippet was taken from the discussion found at:
# https://discuss.pytorch.org/t/differentiable-torch-histc/25865/2
# By Tony-Y
class SoftHistogram1D(nn.Module):
'''
Differentiable 1D histogram calculation (supported via pytorch's autograd)
input:
x - N x D array, where N is the batch size and D is the length of each data series
bins - Number of bins for the histogram
min - Scalar min value to be included in the histogram
max - Scalar max value to be included in the histogram
sigma - Scalar smoothing factor for the bin approximation via sigmoid functions.
Larger values correspond to sharper edges, and thus yield a more accurate approximation
output:
N x bins array, where each row is a histogram
'''
def __init__(self, bins=50, min=0, max=1, sigma=10):
super(SoftHistogram1D, self).__init__()
self.bins = bins
self.min = min
self.max = max
self.sigma = sigma
self.delta = float(max - min) / float(bins)
self.centers = float(min) + self.delta * (torch.arange(bins).float() + 0.5) # Bin centers
self.centers = nn.Parameter(self.centers, requires_grad=False) # Wrap to allow for CUDA support
def forward(self, x):
# Replicate x and for each row remove center
x = torch.unsqueeze(x, 1) - torch.unsqueeze(self.centers, 1)
# Bin approximation using a sigmoid function
x = torch.sigmoid(self.sigma * (x + self.delta / 2)) - torch.sigmoid(self.sigma * (x - self.delta / 2))
# Sum along the non-batch dimensions
x = x.sum(dim=-1)
# x = x / x.sum(dim=-1).unsqueeze(1) # normalization
return x
# Note: This is an extension to the 2D case of the previous code snippet
class SoftHistogram2D(nn.Module):
'''
Differentiable 2D (joint) histogram calculation (supported via pytorch's autograd)
input:
x, y - N x D arrays, where N is the batch size and D is the length of each data series
(i.e. vectorized image or vectorized 3D volume)
bins - Number of bins for the histogram
min - Scalar min value to be included in the histogram
max - Scalar max value to be included in the histogram
sigma - Scalar smoothing factor for the bin approximation via sigmoid functions.
Larger values correspond to sharper edges, and thus yield a more accurate approximation
output:
N x bins x bins array, where each slice is a joint histogram
'''
def __init__(self, bins=50, min=0, max=1, sigma=10):
super(SoftHistogram2D, self).__init__()
self.bins = bins
self.min = min
self.max = max
self.sigma = sigma
self.delta = float(max - min) / float(bins)
self.centers = float(min) + self.delta * (torch.arange(bins).float() + 0.5) # Bin centers
self.centers = nn.Parameter(self.centers, requires_grad=False) # Wrap to allow for CUDA support
def forward(self, x, y):
assert x.size() == y.size(), "(SoftHistogram2D) x and y sizes do not match"
# Replicate x and for each row remove center
x = torch.unsqueeze(x, 1) - torch.unsqueeze(self.centers, 1)
y = torch.unsqueeze(y, 1) - torch.unsqueeze(self.centers, 1)
# Bin approximation using a sigmoid function (can be sigma_x and sigma_y respectively - same for delta)
x = torch.sigmoid(self.sigma * (x + self.delta / 2)) - torch.sigmoid(self.sigma * (x - self.delta / 2))
y = torch.sigmoid(self.sigma * (y + self.delta / 2)) - torch.sigmoid(self.sigma * (y - self.delta / 2))
# Batched matrix multiplication - this way we sum jointly
z = torch.matmul(x, y.permute((0, 2, 1)))
return z
class MI_pytorch(nn.Module):
'''
This class is a pytorch implementation of the mutual information (MI) calculation between two images.
This is an approximation, as the images' histograms rely on differentiable approximations of rectangular windows.
I(X, Y) = H(X) + H(Y) - H(X, Y) = \sum(\sum(p(X, Y) * log(p(X, Y)/(p(X) * p(Y)))))
where H(X) = -\sum(p(x) * log(p(x))) is the entropy
'''
def __init__(self, bins=50, min=0, max=1, sigma=10, reduction='sum'):
super(MI_pytorch, self).__init__()
self.bins = bins
self.min = min
self.max = max
self.sigma = sigma
self.reduction = reduction
# 2D joint histogram
self.hist2d = SoftHistogram2D(bins, min, max, sigma)
# Epsilon - to avoid log(0)
self.eps = torch.tensor(0.00000001, dtype=torch.float32, requires_grad=False)
def forward(self, im1, im2):
'''
Forward implementation of a differentiable MI estimator for batched images
:param im1: N x ... tensor, where N is the batch size
... dimensions can take any form, i.e. 2D images or 3D volumes.
:param im2: N x ... tensor, where N is the batch size
:return: N x 1 vector - the approximate MI values between the batched im1 and im2
'''
# Check for valid inputs
assert im1.size() == im2.size(), "(MI_pytorch) Inputs should have the same dimensions."
batch_size = im1.size()[0]
# Flatten tensors
im1_flat = im1.view(im1.size()[0], -1)
im2_flat = im2.view(im2.size()[0], -1)
# Calculate joint histogram
hgram = self.hist2d(im1_flat, im2_flat)
# Convert to a joint distribution
# Pxy = torch.distributions.Categorical(probs=hgram).probs
Pxy = torch.div(hgram, torch.sum(hgram.view(hgram.size()[0], -1)))
# Calculate the marginal distributions
Py = torch.sum(Pxy, dim=1).unsqueeze(1)
Px = torch.sum(Pxy, dim=2).unsqueeze(1)
# Use the KL divergence distance to calculate the MI
Px_Py = torch.matmul(Px.permute((0, 2, 1)), Py)
# Reshape to batch_size X all_the_rest
Pxy = Pxy.reshape(batch_size, -1)
Px_Py = Px_Py.reshape(batch_size, -1)
# Calculate mutual information - this is an approximation due to the histogram calculation and eps,
# but it can handle batches
if batch_size == 1:
# No need for eps approximation in the case of a single batch
nzs = Pxy > 0 # Calculate based on the non-zero values only
mut_info = torch.matmul(Pxy[nzs], torch.log(Pxy[nzs]) - torch.log(Px_Py[nzs])) # MI calculation
else:
# For arbitrary batch size > 1
mut_info = torch.sum(Pxy * (torch.log(Pxy + self.eps) - torch.log(Px_Py + self.eps)), dim=1)
# Reduction
if self.reduction == 'sum':
mut_info = torch.sum(mut_info)
elif self.reduction == 'batchmean':
mut_info = torch.sum(mut_info)
mut_info = mut_info / float(batch_size)
elif self.reduction=='individual':
pass
return mut_info

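Usage sketch for the MI estimator; inputs are expected to already lie inside [min, max]:

import torch
from model.MI import MI_pytorch

mi = MI_pytorch(bins=50, min=0, max=1, sigma=10, reduction='individual')
a = torch.rand(4, 2, 100)                  # batch of 4 signals in [0, 1]
b = a.clone()                              # identical pair
c = torch.rand(4, 2, 100)                  # independent pair
print(mi(a, b))                            # per-sample MI estimates; typically higher for the identical pair
print(mi(a, c))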
6
model/__init__.py Normal file

@ -0,0 +1,6 @@
from typing import Union
from .unet import BeatGANsUNetModel, BeatGANsUNetConfig
from .unet_autoenc import BeatGANsAutoencConfig, BeatGANsAutoencModel
Model = Union[BeatGANsUNetModel, BeatGANsAutoencModel]
ModelConfig = Union[BeatGANsUNetConfig, BeatGANsAutoencConfig]

495
model/blocks.py Normal file

@ -0,0 +1,495 @@
import math
from abc import abstractmethod
from dataclasses import dataclass
from numbers import Number
import numpy as np  # needed by count_flops_attn below
import torch as th
import torch.nn.functional as F
from choices import *
from config_base import BaseConfig
from torch import nn
from .nn import (avg_pool_nd, conv_nd, linear, normalization,
timestep_embedding, torch_checkpoint, zero_module)
class ScaleAt(Enum):
after_norm = 'afternorm'
class TimestepBlock(nn.Module):
"""
Any module where forward() takes timestep embeddings as a second argument.
"""
@abstractmethod
def forward(self, x, emb=None, cond=None, lateral=None):
"""
Apply the module to `x` given `emb` timestep embeddings.
"""
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
"""
A sequential module that passes timestep embeddings to the children that
support it as an extra input.
"""
def forward(self, x, emb=None, cond=None, lateral=None):
for layer in self:
if isinstance(layer, TimestepBlock):
x = layer(x, emb=emb, cond=cond, lateral=lateral)
else:
x = layer(x)
return x
@dataclass
class ResBlockConfig(BaseConfig):
channels: int
emb_channels: int
dropout: float
out_channels: int = None
use_condition: bool = True
use_conv: bool = False
dims: int = 2
use_checkpoint: bool = False
up: bool = False
down: bool = False
two_cond: bool = False
cond_emb_channels: int = None
has_lateral: bool = False
lateral_channels: int = None
use_zero_module: bool = True
def __post_init__(self):
self.out_channels = self.out_channels or self.channels
self.cond_emb_channels = self.cond_emb_channels or self.emb_channels
def make_model(self):
return ResBlock(self)
class ResBlock(TimestepBlock):
"""
A residual block that can optionally change the number of channels.
"""
def __init__(self, conf: ResBlockConfig):
super().__init__()
self.conf = conf
assert conf.lateral_channels is None
layers = [
normalization(conf.channels),
nn.SiLU(),
conv_nd(conf.dims, conf.channels, conf.out_channels, 3, padding=1)
]
self.in_layers = nn.Sequential(*layers)
self.updown = conf.up or conf.down
if conf.up:
self.h_upd = Upsample(conf.channels, False, conf.dims)
self.x_upd = Upsample(conf.channels, False, conf.dims)
elif conf.down:
self.h_upd = Downsample(conf.channels, False, conf.dims)
self.x_upd = Downsample(conf.channels, False, conf.dims)
else:
self.h_upd = self.x_upd = nn.Identity()
if conf.use_condition:
self.emb_layers = nn.Sequential(
nn.SiLU(),
linear(conf.emb_channels, 2 * conf.out_channels),
)
if conf.two_cond:
self.cond_emb_layers = nn.Sequential(
nn.SiLU(),
linear(conf.cond_emb_channels, conf.out_channels),
)
conv = conv_nd(conf.dims,
conf.out_channels,
conf.out_channels,
3,
padding=1)
if conf.use_zero_module:
conv = zero_module(conv)
layers = []
layers += [
normalization(conf.out_channels),
nn.SiLU(),
nn.Dropout(p=conf.dropout),
conv,
]
self.out_layers = nn.Sequential(*layers)
if conf.out_channels == conf.channels:
self.skip_connection = nn.Identity()
else:
if conf.use_conv:
kernel_size = 3
padding = 1
else:
kernel_size = 1
padding = 0
self.skip_connection = conv_nd(conf.dims,
conf.channels,
conf.out_channels,
kernel_size,
padding=padding)
def forward(self, x, emb=None, cond=None, lateral=None):
"""
Apply the block to a Tensor, conditioned on a timestep embedding.
Args:
x: input
lateral: lateral connection from the encoder
"""
return torch_checkpoint(self._forward, (x, emb, cond, lateral),
self.conf.use_checkpoint)
def _forward(
self,
x,
emb=None,
cond=None,
lateral=None,
):
"""
Args:
lateral: required if "has_lateral" is set and the block is non-gated; with gating, it can be supplied optionally
"""
if self.conf.has_lateral:
assert lateral is not None
x = th.cat([x, lateral], dim=1)
if self.updown:
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
h = in_rest(x)
h = self.h_upd(h)
x = self.x_upd(x)
h = in_conv(h)
else:
h = self.in_layers(x)
if self.conf.use_condition:
if emb is not None:
emb_out = self.emb_layers(emb).type(h.dtype)
else:
emb_out = None
if self.conf.two_cond:
if cond is None:
cond_out = None
else:
if not isinstance(cond, th.Tensor):
assert isinstance(cond, dict)
cond = cond['cond']
cond_out = self.cond_emb_layers(cond).type(h.dtype)
if cond_out is not None:
while len(cond_out.shape) < len(h.shape):
cond_out = cond_out[..., None]
else:
cond_out = None
h = apply_conditions(
h=h,
emb=emb_out,
cond=cond_out,
layers=self.out_layers,
scale_bias=1,
in_channels=self.conf.out_channels,
up_down_layer=None,
)
return self.skip_connection(x) + h
def apply_conditions(
h,
emb=None,
cond=None,
layers: nn.Sequential = None,
scale_bias: float = 1,
in_channels: int = 512,
up_down_layer: nn.Module = None,
):
"""
apply conditions on the feature maps
Args:
emb: time conditional (ready to scale + shift)
cond: encoder's conditional (ready to scale + shift)
"""
two_cond = emb is not None and cond is not None
if emb is not None:
while len(emb.shape) < len(h.shape):
emb = emb[..., None]
if two_cond:
while len(cond.shape) < len(h.shape):
cond = cond[..., None]
scale_shifts = [emb, cond]
else:
scale_shifts = [emb]
for i, each in enumerate(scale_shifts):
if each is None:
a = None
b = None
else:
if each.shape[1] == in_channels * 2:
a, b = th.chunk(each, 2, dim=1)
else:
a = each
b = None
scale_shifts[i] = (a, b)
if isinstance(scale_bias, Number):
biases = [scale_bias] * len(scale_shifts)
else:
biases = scale_bias
pre_layers, post_layers = layers[0], layers[1:]
mid_layers, post_layers = post_layers[:-2], post_layers[-2:]
h = pre_layers(h)
for i, (scale, shift) in enumerate(scale_shifts):
if scale is not None:
h = h * (biases[i] + scale)
if shift is not None:
h = h + shift
h = mid_layers(h)
if up_down_layer is not None:
h = up_down_layer(h)
h = post_layers(h)
return h
class Upsample(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
if use_conv:
self.conv = conv_nd(dims,
self.channels,
self.out_channels,
3,
padding=1)
def forward(self, x):
assert x.shape[1] == self.channels
if self.dims == 3:
x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
mode="nearest")
else:
x = F.interpolate(x, scale_factor=2, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
stride = 2 if dims != 3 else (1, 2, 2)
if use_conv:
self.op = conv_nd(dims,
self.channels,
self.out_channels,
3,
stride=stride,
padding=1)
else:
assert self.channels == self.out_channels
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
def forward(self, x):
assert x.shape[1] == self.channels
return self.op(x)
class AttentionBlock(nn.Module):
"""
An attention block that allows spatial positions to attend to each other.
Originally ported from here, but adapted to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
"""
def __init__(
self,
channels,
num_heads=1,
num_head_channels=-1,
use_checkpoint=False,
use_new_attention_order=False,
):
super().__init__()
self.channels = channels
if num_head_channels == -1:
self.num_heads = num_heads
else:
assert (
channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels
self.use_checkpoint = use_checkpoint
self.norm = normalization(channels)
self.qkv = conv_nd(1, channels, channels * 3, 1)
if use_new_attention_order:
self.attention = QKVAttention(self.num_heads)
else:
self.attention = QKVAttentionLegacy(self.num_heads)
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
def forward(self, x):
return torch_checkpoint(self._forward, (x, ), self.use_checkpoint)
def _forward(self, x):
b, c, *spatial = x.shape
x = x.reshape(b, c, -1)
qkv = self.qkv(self.norm(x))
h = self.attention(qkv)
h = self.proj_out(h)
return (x + h).reshape(b, c, *spatial)
def count_flops_attn(model, _x, y):
"""
A counter for the `thop` package to count the operations in an
attention operation.
Meant to be used like:
macs, params = thop.profile(
model,
inputs=(inputs, timestamps),
custom_ops={QKVAttention: QKVAttention.count_flops},
)
"""
b, c, *spatial = y[0].shape
num_spatial = int(np.prod(spatial))
matmul_ops = 2 * b * (num_spatial**2) * c
model.total_ops += th.DoubleTensor([matmul_ops])
class QKVAttentionLegacy(nn.Module):
"""
A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
"""
def __init__(self, n_heads):
super().__init__()
self.n_heads = n_heads
def forward(self, qkv):
"""
Apply QKV attention.
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
:return: an [N x (H * C) x T] tensor after attention.
"""
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch,
dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum(
"bct,bcs->bts", q * scale,
k * scale)
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum("bts,bcs->bct", weight, v)
return a.reshape(bs, -1, length)
@staticmethod
def count_flops(model, _x, y):
return count_flops_attn(model, _x, y)
class QKVAttention(nn.Module):
"""
A module which performs QKV attention and splits in a different order.
"""
def __init__(self, n_heads):
super().__init__()
self.n_heads = n_heads
def forward(self, qkv):
"""
Apply QKV attention.
:param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
:return: an [N x (H * C) x T] tensor after attention.
"""
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.chunk(3, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum(
"bct,bcs->bts",
(q * scale).view(bs * self.n_heads, ch, length),
(k * scale).view(bs * self.n_heads, ch, length),
)
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum("bts,bcs->bct", weight,
v.reshape(bs * self.n_heads, ch, length))
return a.reshape(bs, -1, length)
@staticmethod
def count_flops(model, _x, y):
return count_flops_attn(model, _x, y)
class AttentionPool2d(nn.Module):
"""
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
"""
def __init__(
self,
spacial_dim: int,
embed_dim: int,
num_heads_channels: int,
output_dim: int = None,
):
super().__init__()
self.positional_embedding = nn.Parameter(
th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)
self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
self.num_heads = embed_dim // num_heads_channels
self.attention = QKVAttention(self.num_heads)
def forward(self, x):
b, c, *_spatial = x.shape
x = x.reshape(b, c, -1)
x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)
x = x + self.positional_embedding[None, :, :].to(x.dtype)
x = self.qkv_proj(x)
x = self.attention(x)
x = self.c_proj(x)
return x[:, :, 0]

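Shape sketch for the residual attention block on 1D signals of layout (batch, channels, length):

import torch as th
from model.blocks import AttentionBlock

attn = AttentionBlock(channels=64, num_heads=1)
x = th.randn(2, 64, 25)
y = attn(x)
assert y.shape == x.shape     # attention plus residual connection preserves the shape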
184
model/latentnet.py Normal file

@ -0,0 +1,184 @@
import math
from dataclasses import dataclass
from enum import Enum
from typing import NamedTuple, Tuple
import torch
from choices import *
from config_base import BaseConfig
from torch import nn
from torch.nn import init
from .blocks import *
from .nn import timestep_embedding
from .unet import *
class LatentNetType(Enum):
none = 'none'
skip = 'skip'
class LatentNetReturn(NamedTuple):
pred: torch.Tensor = None
@dataclass
class MLPSkipNetConfig(BaseConfig):
"""
default MLP for the latent DPM in the paper!
"""
num_channels: int
skip_layers: Tuple[int]
num_hid_channels: int
num_layers: int
num_time_emb_channels: int = 64
activation: Activation = Activation.silu
use_norm: bool = True
condition_bias: float = 1
dropout: float = 0
last_act: Activation = Activation.none
num_time_layers: int = 2
time_last_act: bool = False
def make_model(self):
return MLPSkipNet(self)
class MLPSkipNet(nn.Module):
"""
concat x to hidden layers
default MLP for the latent DPM in the paper!
"""
def __init__(self, conf: MLPSkipNetConfig):
super().__init__()
self.conf = conf
layers = []
for i in range(conf.num_time_layers):
if i == 0:
a = conf.num_time_emb_channels
b = conf.num_channels
else:
a = conf.num_channels
b = conf.num_channels
layers.append(nn.Linear(a, b))
if i < conf.num_time_layers - 1 or conf.time_last_act:
layers.append(conf.activation.get_act())
self.time_embed = nn.Sequential(*layers)
self.layers = nn.ModuleList([])
for i in range(conf.num_layers):
if i == 0:
act = conf.activation
norm = conf.use_norm
cond = True
a, b = conf.num_channels, conf.num_hid_channels
dropout = conf.dropout
elif i == conf.num_layers - 1:
act = Activation.none
norm = False
cond = False
a, b = conf.num_hid_channels, conf.num_channels
dropout = 0
else:
act = conf.activation
norm = conf.use_norm
cond = True
a, b = conf.num_hid_channels, conf.num_hid_channels
dropout = conf.dropout
if i in conf.skip_layers:
a += conf.num_channels
self.layers.append(
MLPLNAct(
a,
b,
norm=norm,
activation=act,
cond_channels=conf.num_channels,
use_cond=cond,
condition_bias=conf.condition_bias,
dropout=dropout,
))
self.last_act = conf.last_act.get_act()
def forward(self, x, t, **kwargs):
t = timestep_embedding(t, self.conf.num_time_emb_channels)
cond = self.time_embed(t)
h = x
for i in range(len(self.layers)):
if i in self.conf.skip_layers:
h = torch.cat([h, x], dim=1)
h = self.layers[i].forward(x=h, cond=cond)
h = self.last_act(h)
return LatentNetReturn(h)
class MLPLNAct(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
norm: bool,
use_cond: bool,
activation: Activation,
cond_channels: int,
condition_bias: float = 0,
dropout: float = 0,
):
super().__init__()
self.activation = activation
self.condition_bias = condition_bias
self.use_cond = use_cond
self.linear = nn.Linear(in_channels, out_channels)
self.act = activation.get_act()
if self.use_cond:
self.linear_emb = nn.Linear(cond_channels, out_channels)
self.cond_layers = nn.Sequential(self.act, self.linear_emb)
if norm:
self.norm = nn.LayerNorm(out_channels)
else:
self.norm = nn.Identity()
if dropout > 0:
self.dropout = nn.Dropout(p=dropout)
else:
self.dropout = nn.Identity()
self.init_weights()
def init_weights(self):
for module in self.modules():
if isinstance(module, nn.Linear):
if self.activation == Activation.relu:
init.kaiming_normal_(module.weight,
a=0,
nonlinearity='relu')
elif self.activation == Activation.lrelu:
init.kaiming_normal_(module.weight,
a=0.2,
nonlinearity='leaky_relu')
elif self.activation == Activation.silu:
init.kaiming_normal_(module.weight,
a=0,
nonlinearity='relu')
else:
pass
def forward(self, x, cond=None):
x = self.linear(x)
if self.use_cond:
cond = self.cond_layers(cond)
cond = (cond, None)
x = x * (self.condition_bias + cond[0])
if cond[1] is not None:
x = x + cond[1]
x = self.norm(x)
else:
x = self.norm(x)
x = self.act(x)
x = self.dropout(x)
return x
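# Illustrative usage sketch for the latent DPM network (a minimal example with
# assumed toy sizes; the real channel sizes come from the training config).
if __name__ == "__main__":
    conf = MLPSkipNetConfig(
        num_channels=512,        # latent dimensionality
        skip_layers=(1, 2, 3),   # hidden layers that re-concatenate the input
        num_hid_channels=1024,
        num_layers=5,
    )
    net = conf.make_model()
    z_t = torch.randn(4, 512)          # a batch of noisy latents
    t = torch.randint(0, 1000, (4, ))  # diffusion timesteps
    out = net(z_t, t)
    print(out.pred.shape)              # torch.Size([4, 512])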

135
model/nn.py Executable file
View file

@ -0,0 +1,135 @@
"""
Various utilities for neural networks.
"""
import math
import torch as th
import torch.nn as nn
import torch.utils.checkpoint
import torch.nn.functional as F
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
# @th.jit.script
def forward(self, x):
return x * th.sigmoid(x)
class GroupNorm32(nn.GroupNorm):
def forward(self, x):
return super().forward(x.float()).type(x.dtype)
def conv_nd(dims, *args, **kwargs):
"""
    Create a 1D, 2D, or 3D convolution module (restricted to 1D in this project).
    """
    assert dims == 1
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def linear(*args, **kwargs):
"""
Create a linear module.
"""
return nn.Linear(*args, **kwargs)
def avg_pool_nd(dims, *args, **kwargs):
"""
    Create a 1D, 2D, or 3D average pooling module (restricted to 1D in this project).
    """
    assert dims == 1
if dims == 1:
return nn.AvgPool1d(*args, **kwargs)
elif dims == 2:
return nn.AvgPool2d(*args, **kwargs)
elif dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def update_ema(target_params, source_params, rate=0.99):
"""
Update target parameters to be closer to those of source parameters using
an exponential moving average.
:param target_params: the target parameter sequence.
:param source_params: the source parameter sequence.
:param rate: the EMA rate (closer to 1 means slower).
"""
for targ, src in zip(target_params, source_params):
targ.detach().mul_(rate).add_(src, alpha=1 - rate)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def scale_module(module, scale):
"""
Scale the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().mul_(scale)
return module
def mean_flat(tensor):
"""
Take the mean over all non-batch dimensions.
"""
return tensor.mean(dim=list(range(1, len(tensor.shape))))
def normalization(channels):
"""
Make a standard normalization layer.
:param channels: number of input channels.
:return: an nn.Module for normalization.
"""
return GroupNorm32(min(32, channels), channels)
def timestep_embedding(timesteps, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
half = dim // 2
freqs = th.exp(-math.log(max_period) *
th.arange(start=0, end=half, dtype=th.float32) /
half).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None]
embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
if dim % 2:
embedding = th.cat(
[embedding, th.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def torch_checkpoint(func, args, flag, preserve_rng_state=False):
# torch's gradient checkpoint works with automatic mixed precision, given torch >= 1.8
if flag:
return torch.utils.checkpoint.checkpoint(
func, *args, preserve_rng_state=preserve_rng_state)
else:
return func(*args)
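# Illustrative sketch of the helpers above (minimal example with assumed sizes):
# timestep_embedding produces [N x dim] sinusoidal features, and update_ema
# moves the target parameters a fraction (1 - rate) towards the source.
if __name__ == "__main__":
    t = th.tensor([0, 10, 999])
    emb = timestep_embedding(t, dim=64)
    print(emb.shape)  # torch.Size([3, 64])

    target = [th.zeros(3)]
    source = [th.ones(3)]
    update_ema(target, source, rate=0.9)  # target <- 0.9 * target + 0.1 * source
    print(target[0])  # tensor([0.1000, 0.1000, 0.1000])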

505
model/unet.py Normal file
View file

@ -0,0 +1,505 @@
import math
from dataclasses import dataclass
from numbers import Number
from typing import NamedTuple, Tuple, Union
import numpy as np
import torch as th
from torch import nn
import torch.nn.functional as F
from choices import *
from config_base import BaseConfig
from .blocks import *
from .nn import (conv_nd, linear, normalization, timestep_embedding,
torch_checkpoint, zero_module)
@dataclass
class BeatGANsUNetConfig(BaseConfig):
image_size: int = 64
in_channels: int = 2
model_channels: int = 64
out_channels: int = 2
num_res_blocks: int = 2
num_input_res_blocks: int = None
embed_channels: int = 512
attention_resolutions: Tuple[int] = (16, )
time_embed_channels: int = None
dropout: float = 0.1
channel_mult: Tuple[int] = (1, 2, 4, 8)
input_channel_mult: Tuple[int] = None
conv_resample: bool = True
dims: int = 2
num_classes: int = None
use_checkpoint: bool = False
num_heads: int = 1
num_head_channels: int = -1
num_heads_upsample: int = -1
resblock_updown: bool = True
use_new_attention_order: bool = False
resnet_two_cond: bool = False
resnet_cond_channels: int = None
resnet_use_zero_module: bool = True
attn_checkpoint: bool = False
num_users: int = None
def make_model(self):
return BeatGANsUNetModel(self)
class BeatGANsUNetModel(nn.Module):
def __init__(self, conf: BeatGANsUNetConfig):
super().__init__()
self.conf = conf
if conf.num_heads_upsample == -1:
self.num_heads_upsample = conf.num_heads
self.dtype = th.float32
self.time_emb_channels = conf.time_embed_channels or conf.model_channels
self.time_embed = nn.Sequential(
linear(self.time_emb_channels, conf.embed_channels),
nn.SiLU(),
linear(conf.embed_channels, conf.embed_channels),
)
if conf.num_classes is not None:
self.label_emb = nn.Embedding(conf.num_classes,
conf.embed_channels)
ch = input_ch = int(conf.channel_mult[0] * conf.model_channels)
self.input_blocks = nn.ModuleList([
TimestepEmbedSequential(
conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1))
])
kwargs = dict(
use_condition=True,
two_cond=conf.resnet_two_cond,
use_zero_module=conf.resnet_use_zero_module,
cond_emb_channels=conf.resnet_cond_channels,
)
self._feature_size = ch
# input_block_chans = [ch]
input_block_chans = [[] for _ in range(len(conf.channel_mult))]
input_block_chans[0].append(ch)
# number of blocks at each resolution
self.input_num_blocks = [0 for _ in range(len(conf.channel_mult))]
self.input_num_blocks[0] = 1
self.output_num_blocks = [0 for _ in range(len(conf.channel_mult))]
ds = 1
resolution = conf.image_size
for level, mult in enumerate(conf.input_channel_mult
or conf.channel_mult):
for _ in range(conf.num_input_res_blocks or conf.num_res_blocks):
layers = [
ResBlockConfig(
ch,
conf.embed_channels,
conf.dropout,
out_channels=int(mult * conf.model_channels),
dims=conf.dims,
use_checkpoint=conf.use_checkpoint,
**kwargs,
).make_model()
]
ch = int(mult * conf.model_channels)
if resolution in conf.attention_resolutions:
layers.append(
AttentionBlock(
ch,
use_checkpoint=conf.use_checkpoint
or conf.attn_checkpoint,
num_heads=conf.num_heads,
num_head_channels=conf.num_head_channels,
use_new_attention_order=conf.
use_new_attention_order,
))
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
input_block_chans[level].append(ch)
self.input_num_blocks[level] += 1
if level != len(conf.channel_mult) - 1:
resolution //= 2
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResBlockConfig(
ch,
conf.embed_channels,
conf.dropout,
out_channels=out_ch,
dims=conf.dims,
use_checkpoint=conf.use_checkpoint,
down=True,
**kwargs,
).make_model() if conf.
resblock_updown else Downsample(ch,
conf.conv_resample,
dims=conf.dims,
out_channels=out_ch)))
ch = out_ch
input_block_chans[level + 1].append(ch)
self.input_num_blocks[level + 1] += 1
ds *= 2
self._feature_size += ch
self.middle_block = TimestepEmbedSequential(
ResBlockConfig(
ch,
conf.embed_channels,
conf.dropout,
dims=conf.dims,
use_checkpoint=conf.use_checkpoint,
**kwargs,
).make_model(),
AttentionBlock(
ch,
use_checkpoint=conf.use_checkpoint or conf.attn_checkpoint,
num_heads=conf.num_heads,
num_head_channels=conf.num_head_channels,
use_new_attention_order=conf.use_new_attention_order,
),
ResBlockConfig(
ch,
conf.embed_channels,
conf.dropout,
dims=conf.dims,
use_checkpoint=conf.use_checkpoint,
**kwargs,
).make_model(),
)
self._feature_size += ch
self.output_blocks = nn.ModuleList([])
for level, mult in list(enumerate(conf.channel_mult))[::-1]:
for i in range(conf.num_res_blocks + 1):
try:
ich = input_block_chans[level].pop()
except IndexError:
ich = 0
layers = [
ResBlockConfig(
channels=ch + ich,
emb_channels=conf.embed_channels,
dropout=conf.dropout,
out_channels=int(conf.model_channels * mult),
dims=conf.dims,
use_checkpoint=conf.use_checkpoint,
has_lateral=True if ich > 0 else False,
lateral_channels=None,
**kwargs,
).make_model()
]
ch = int(conf.model_channels * mult)
if resolution in conf.attention_resolutions:
layers.append(
AttentionBlock(
ch,
use_checkpoint=conf.use_checkpoint
or conf.attn_checkpoint,
num_heads=self.num_heads_upsample,
num_head_channels=conf.num_head_channels,
use_new_attention_order=conf.
use_new_attention_order,
))
if level and i == conf.num_res_blocks:
resolution *= 2
out_ch = ch
layers.append(
ResBlockConfig(
ch,
conf.embed_channels,
conf.dropout,
out_channels=out_ch,
dims=conf.dims,
use_checkpoint=conf.use_checkpoint,
up=True,
**kwargs,
).make_model() if (
conf.resblock_updown
) else Upsample(ch,
conf.conv_resample,
dims=conf.dims,
out_channels=out_ch))
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
self.output_num_blocks[level] += 1
self._feature_size += ch
if conf.resnet_use_zero_module:
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
zero_module(
conv_nd(conf.dims,
input_ch,
conf.out_channels,
3,
padding=1)),
)
else:
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
conv_nd(conf.dims, input_ch, conf.out_channels, 3, padding=1),
)
def forward(self, x, t, y=None, **kwargs):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
        :param t: a 1-D batch of timesteps.
:param y: an [N] Tensor of labels, if class-conditional.
:return: an [N x C x ...] Tensor of outputs.
"""
assert (y is not None) == (
self.conf.num_classes is not None
), "must specify y if and only if the model is class-conditional"
# hs = []
hs = [[] for _ in range(len(self.conf.channel_mult))]
emb = self.time_embed(timestep_embedding(t, self.time_emb_channels))
if self.conf.num_classes is not None:
raise NotImplementedError()
h = x.type(self.dtype)
k = 0
for i in range(len(self.input_num_blocks)):
for j in range(self.input_num_blocks[i]):
h = self.input_blocks[k](h, emb=emb)
hs[i].append(h)
k += 1
assert k == len(self.input_blocks)
# middle blocks
h = self.middle_block(h, emb=emb)
# output blocks
k = 0
for i in range(len(self.output_num_blocks)):
for j in range(self.output_num_blocks[i]):
try:
lateral = hs[-i - 1].pop()
except IndexError:
lateral = None
h = self.output_blocks[k](h, emb=emb, lateral=lateral)
k += 1
h = h.type(x.dtype)
pred = self.out(h)
return Return(pred=pred)
class Return(NamedTuple):
pred: th.Tensor
@dataclass
class BeatGANsEncoderConfig(BaseConfig):
image_size: int
in_channels: int
model_channels: int
out_hid_channels: int
out_channels: int
num_res_blocks: int
attention_resolutions: Tuple[int]
dropout: float = 0
channel_mult: Tuple[int] = (1, 2, 4, 8)
use_time_condition: bool = True
conv_resample: bool = True
dims: int = 2
use_checkpoint: bool = False
num_heads: int = 1
num_head_channels: int = -1
resblock_updown: bool = False
use_new_attention_order: bool = False
pool: str = 'adaptivenonzero'
def make_model(self):
return BeatGANsEncoderModel(self)
class BeatGANsEncoderModel(nn.Module):
"""
The half UNet model with attention and timestep embedding.
For usage, see UNet.
"""
def __init__(self, conf: BeatGANsEncoderConfig):
super().__init__()
self.conf = conf
self.dtype = th.float32
if conf.use_time_condition:
time_embed_dim = conf.model_channels * 4
self.time_embed = nn.Sequential(
linear(conf.model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
else:
time_embed_dim = None
ch = int(conf.channel_mult[0] * conf.model_channels)
self.input_blocks = nn.ModuleList([
TimestepEmbedSequential(
conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1))
])
self._feature_size = ch
input_block_chans = [ch]
ds = 1
resolution = conf.image_size
for level, mult in enumerate(conf.channel_mult):
for _ in range(conf.num_res_blocks):
layers = [
ResBlockConfig(
ch,
time_embed_dim,
conf.dropout,
out_channels=int(mult * conf.model_channels),
dims=conf.dims,
use_condition=conf.use_time_condition,
use_checkpoint=conf.use_checkpoint,
).make_model()
]
ch = int(mult * conf.model_channels)
if resolution in conf.attention_resolutions:
layers.append(
AttentionBlock(
ch,
use_checkpoint=conf.use_checkpoint,
num_heads=conf.num_heads,
num_head_channels=conf.num_head_channels,
use_new_attention_order=conf.
use_new_attention_order,
))
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(conf.channel_mult) - 1:
resolution //= 2
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResBlockConfig(
ch,
time_embed_dim,
conf.dropout,
out_channels=out_ch,
dims=conf.dims,
use_condition=conf.use_time_condition,
use_checkpoint=conf.use_checkpoint,
down=True,
).make_model() if (
conf.resblock_updown
) else Downsample(ch,
conf.conv_resample,
dims=conf.dims,
out_channels=out_ch)))
ch = out_ch
input_block_chans.append(ch)
ds *= 2
self._feature_size += ch
self.middle_block = TimestepEmbedSequential(
ResBlockConfig(
ch,
time_embed_dim,
conf.dropout,
dims=conf.dims,
use_condition=conf.use_time_condition,
use_checkpoint=conf.use_checkpoint,
).make_model(),
AttentionBlock(
ch,
use_checkpoint=conf.use_checkpoint,
num_heads=conf.num_heads,
num_head_channels=conf.num_head_channels,
use_new_attention_order=conf.use_new_attention_order,
),
ResBlockConfig(
ch,
time_embed_dim,
conf.dropout,
dims=conf.dims,
use_condition=conf.use_time_condition,
use_checkpoint=conf.use_checkpoint,
).make_model(),
)
self._feature_size += ch
if conf.pool == "adaptivenonzero":
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
nn.AdaptiveAvgPool1d(1),
conv_nd(conf.dims, ch, conf.out_channels, 1),
nn.Flatten(),
)
else:
raise NotImplementedError(f"Unexpected {conf.pool} pooling")
def forward(self, x, t=None, return_2d_feature=False):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
        :param t: a 1-D batch of timesteps (used only if use_time_condition is set).
:return: an [N x K] Tensor of outputs.
"""
if self.conf.use_time_condition:
            emb = self.time_embed(timestep_embedding(t, self.conf.model_channels))
else:
emb = None
results = []
h = x.type(self.dtype)
for module in self.input_blocks:
h = module(h, emb=emb)
if self.conf.pool.startswith("spatial"):
results.append(h.type(x.dtype).mean(dim=(2, 3)))
h = self.middle_block(h, emb=emb)
if self.conf.pool.startswith("spatial"):
results.append(h.type(x.dtype).mean(dim=(2, 3)))
h = th.cat(results, axis=-1)
else:
h = h.type(x.dtype)
h_2d = h
h = self.out(h)
if return_2d_feature:
return h, h_2d
else:
return h
def forward_flatten(self, x):
"""
        Transform the last 2D feature map into a flattened vector.
"""
h = self.out(x)
return h
class SuperResModel(BeatGANsUNetModel):
"""
A UNetModel that performs super-resolution.
Expects an extra kwarg `low_res` to condition on a low-resolution image.
"""
def __init__(self, image_size, in_channels, *args, **kwargs):
super().__init__(image_size, in_channels * 2, *args, **kwargs)
def forward(self, x, timesteps, low_res=None, **kwargs):
_, _, new_height, new_width = x.shape
upsampled = F.interpolate(low_res, (new_height, new_width),
mode="bilinear")
x = th.cat([x, upsampled], dim=1)
return super().forward(x, timesteps, **kwargs)

310
model/unet_autoenc.py Normal file
View file

@ -0,0 +1,310 @@
import torch
from torch import Tensor, nn
from torch.nn.functional import silu
from .latentnet import *
from .unet import *
from choices import *
@dataclass
class BeatGANsAutoencConfig(BeatGANsUNetConfig):
enc_out_channels: int = 512
enc_attn_resolutions: Tuple[int] = None
enc_pool: str = 'depthconv'
enc_num_res_block: int = 2
enc_channel_mult: Tuple[int] = None
enc_grad_checkpoint: bool = False
latent_net_conf: MLPSkipNetConfig = None
def make_model(self):
return BeatGANsAutoencModel(self)
class BeatGANsAutoencModel(BeatGANsUNetModel):
def __init__(self, conf: BeatGANsAutoencConfig):
super().__init__(conf)
self.conf = conf
self.time_embed = TimeStyleSeperateEmbed(
time_channels=conf.model_channels,
time_out_channels=conf.embed_channels,
)
self.encoder = BeatGANsEncoderConfig(
image_size=conf.image_size,
in_channels=conf.in_channels,
model_channels=conf.model_channels,
out_hid_channels=conf.enc_out_channels,
out_channels=conf.enc_out_channels,
num_res_blocks=conf.enc_num_res_block,
attention_resolutions=(conf.enc_attn_resolutions
or conf.attention_resolutions),
dropout=conf.dropout,
channel_mult=conf.enc_channel_mult or conf.channel_mult,
use_time_condition=False,
conv_resample=conf.conv_resample,
dims=conf.dims,
use_checkpoint=conf.use_checkpoint or conf.enc_grad_checkpoint,
num_heads=conf.num_heads,
num_head_channels=conf.num_head_channels,
resblock_updown=conf.resblock_updown,
use_new_attention_order=conf.use_new_attention_order,
pool=conf.enc_pool,
).make_model()
self.user_classifier = UserClassifier(conf.enc_out_channels//2, conf.num_users)
self.non_user_classifier = UserClassifierGradientReverse(conf.enc_out_channels//2, conf.num_users)
if conf.latent_net_conf is not None:
self.latent_net = conf.latent_net_conf.make_model()
def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
"""
        Reparameterization trick: sample from N(mu, var) using samples
        from N(0, 1).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
:return: (Tensor) [B x D]
"""
assert self.conf.is_stochastic
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return eps * std + mu
def sample_z(self, n: int, device):
assert self.conf.is_stochastic
return torch.randn(n, self.conf.enc_out_channels, device=device)
def noise_to_cond(self, noise: Tensor):
raise NotImplementedError()
assert self.conf.noise_net_conf is not None
return self.noise_net.forward(noise)
def encode(self, x):
cond = self.encoder.forward(x)
return {'cond': cond}
@property
def stylespace_sizes(self):
modules = list(self.input_blocks.modules()) + list(
self.middle_block.modules()) + list(self.output_blocks.modules())
sizes = []
for module in modules:
if isinstance(module, ResBlock):
linear = module.cond_emb_layers[-1]
sizes.append(linear.weight.shape[0])
return sizes
def encode_stylespace(self, x, return_vector: bool = True):
"""
        Encode x into the style space (one conditioning vector per ResBlock).
"""
modules = list(self.input_blocks.modules()) + list(
self.middle_block.modules()) + list(self.output_blocks.modules())
cond = self.encoder.forward(x)
S = []
for module in modules:
if isinstance(module, ResBlock):
s = module.cond_emb_layers.forward(cond)
S.append(s)
if return_vector:
return torch.cat(S, dim=1)
else:
return S
def forward(self,
x,
t,
y=None,
x_start=None,
cond=None,
style=None,
noise=None,
t_cond=None,
**kwargs):
"""
Apply the model to an input batch.
Args:
x_start: the original image to encode
cond: output of the encoder
noise: random noise (to predict the cond)
"""
if t_cond is None:
t_cond = t
if noise is not None:
cond = self.noise_to_cond(noise)
if cond is None:
if x is not None:
assert len(x) == len(x_start), f'{len(x)} != {len(x_start)}'
tmp = self.encode(x_start)
cond = tmp['cond']
if t is not None:
_t_emb = timestep_embedding(t, self.conf.model_channels)
_t_cond_emb = timestep_embedding(t_cond, self.conf.model_channels)
else:
_t_emb = None
_t_cond_emb = None
if self.conf.resnet_two_cond:
res = self.time_embed.forward(
time_emb=_t_emb,
cond=cond,
time_cond_emb=_t_cond_emb
)
else:
raise NotImplementedError()
if self.conf.resnet_two_cond:
emb = res.time_emb
cond_emb = res.emb
else:
emb = res.emb
cond_emb = None
style = style or res.style
assert (y is not None) == (
self.conf.num_classes is not None
), "must specify y if and only if the model is class-conditional"
if self.conf.num_classes is not None:
raise NotImplementedError()
enc_time_emb = emb
mid_time_emb = emb
dec_time_emb = emb
enc_cond_emb = cond_emb
mid_cond_emb = cond_emb
dec_cond_emb = cond_emb
        if self.conf.num_users is not None:
            user_pred = self.user_classifier(cond_emb[:, :self.conf.enc_out_channels // 2])
            non_user_pred = self.non_user_classifier(cond_emb[:, self.conf.enc_out_channels // 2:])
        else:
            user_pred = None
            non_user_pred = None
hs = [[] for _ in range(len(self.conf.channel_mult))]
if x is not None:
h = x.type(self.dtype)
# input blocks
k = 0
for i in range(len(self.input_num_blocks)):
for j in range(self.input_num_blocks[i]):
h = self.input_blocks[k](h,
emb=enc_time_emb,
cond=enc_cond_emb)
                # (debug) print(i, j, h.shape)
                # (debug) if h.shape[-1] % 2 == 1: pdb.set_trace()
hs[i].append(h)
k += 1
assert k == len(self.input_blocks)
# middle blocks
h = self.middle_block(h, emb=mid_time_emb, cond=mid_cond_emb)
else:
h = None
hs = [[] for _ in range(len(self.conf.channel_mult))]
# output blocks
k = 0
for i in range(len(self.output_num_blocks)):
for j in range(self.output_num_blocks[i]):
try:
lateral = hs[-i - 1].pop()
except IndexError:
lateral = None
h = self.output_blocks[k](h,
emb=dec_time_emb,
cond=dec_cond_emb,
lateral=lateral)
k += 1
pred = self.out(h)
return AutoencReturn(pred=pred, cond=cond, user_pred=user_pred, non_user_pred=non_user_pred)
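# Illustrative sketch of the reparameterization trick used in
# BeatGANsAutoencModel.reparameterize (standalone, with assumed toy shapes):
# z = mu + sigma * eps with eps ~ N(0, I), so gradients flow through mu and logvar.
if __name__ == "__main__":
    mu, logvar = torch.zeros(4, 8), torch.zeros(4, 8)  # [B x D]
    eps = torch.randn_like(mu)
    z = eps * torch.exp(0.5 * logvar) + mu             # sample from N(mu, exp(logvar))
    print(z.shape)                                     # torch.Size([4, 8])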
class UserClassifier(nn.Module):
def __init__(self, in_channels, num_classes):
super().__init__()
self.fc = nn.Sequential(
nn.Linear(in_channels, 256),
nn.ReLU(),
nn.Linear(256, num_classes),
nn.Softmax(dim=1)
)
def forward(self, x):
return self.fc(x)
class GradReverse(torch.autograd.Function):
"""
    Gradient reversal layer, used for domain-adversarial training.
    The forward pass is the identity; the backward pass negates the gradient.
"""
@staticmethod
def forward(ctx, x):
return x.view_as(x)
@staticmethod
def backward(ctx, grad_output):
return grad_output.neg()
class GradientReversalLayer(nn.Module):
def __init__(self):
super(GradientReversalLayer, self).__init__()
def forward(self, inputs):
return GradReverse.apply(inputs)
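# Illustrative check of the gradient reversal behaviour (minimal example):
# the forward pass is the identity, but gradients change sign on the way back.
if __name__ == "__main__":
    grl = GradientReversalLayer()
    v = torch.ones(3, requires_grad=True)
    grl(v).sum().backward()
    print(v.grad)  # tensor([-1., -1., -1.]) - reversed sign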
class UserClassifierGradientReverse(nn.Module):
def __init__(self, in_channels, num_classes):
super().__init__()
self.grl = GradientReversalLayer()
self.fc = UserClassifier(in_channels, num_classes)
def forward(self, x):
x = self.grl(x)
return self.fc(x)
class AutoencReturn(NamedTuple):
pred: Tensor
cond: Tensor = None
user_pred: Tensor = None
non_user_pred: Tensor = None
class EmbedReturn(NamedTuple):
emb: Tensor = None
time_emb: Tensor = None
style: Tensor = None
class TimeStyleSeperateEmbed(nn.Module):
def __init__(self, time_channels, time_out_channels):
super().__init__()
self.time_embed = nn.Sequential(
linear(time_channels, time_out_channels),
nn.SiLU(),
linear(time_out_channels, time_out_channels),
)
self.cond_combine = nn.Sequential(
nn.Linear(time_out_channels * 2, time_out_channels),
nn.SiLU()
)
self.style = nn.Identity()
def forward(self, time_emb=None, cond=None, **kwargs):
        if time_emb is not None:
            time_emb = self.time_embed(time_emb)
style = self.style(cond)
return EmbedReturn(emb=style, time_emb=time_emb, style=style)

55
templates.py Normal file
View file

@ -0,0 +1,55 @@
from experiment import *
def autoenc_base():
conf = TrainConfig()
conf.batch_size = 32
conf.beatgans_gen_type = GenerativeType.ddim
conf.beta_scheduler = 'linear'
conf.data_name = 'ffhq'
conf.diffusion_type = 'beatgans'
conf.eval_ema_every_samples = 200_000
conf.eval_every_samples = 200_000
conf.fp16 = True
conf.lr = 1e-4
conf.model_name = ModelName.beatgans_autoenc
conf.net_attn = (16, )
conf.net_beatgans_attn_head = 1
conf.net_beatgans_embed_channels = 128
conf.net_beatgans_resnet_two_cond = True
conf.net_ch_mult = (1, 2, 4, 8)
conf.net_ch = 64
conf.net_enc_channel_mult = (1, 2, 4, 8, 8)
conf.net_enc_pool = 'adaptivenonzero'
conf.sample_size = 32
conf.T_eval = 20
conf.T = 1000
conf.make_model_conf()
return conf
def mouse_autoenc(mode):
num_users = {'Clarkson': 75}
conf = autoenc_base()
conf.scale_up_gpus(1)
conf.img_size = 256
conf.net_ch = 128
conf.net_ch_mult = (1, 1, 2, 2, 4, 4)
conf.net_enc_channel_mult = (1, 1, 2, 2, 4, 4, 4)
conf.eval_every_samples = 10_000_000
conf.eval_ema_every_samples = 10_000_000
conf.total_samples = 200_000_000
conf.batch_size = 512
conf.name = 'mouse_autoenc'
conf.pretrainDataset = 'Clarkson'
conf.data_name = 'Clarkson'
conf.path = f'../mousedata/{conf.pretrainDataset}/'
conf.AEwin = 8
conf.slid = 1
conf.timeWinFreq = 20
conf.num_users = num_users[conf.pretrainDataset]
conf.mode = mode
conf.make_model_conf()
return conf