commit b102c2a534: init
20 changed files with 4305 additions and 0 deletions
choices.py  (new executable file, 127 lines)
@@ -0,0 +1,127 @@
from enum import Enum

from torch import nn


class TrainMode(Enum):
    manipulate = 'manipulate'
    diffusion = 'diffusion'
    latent_diffusion = 'latentdiffusion'

    def is_manipulate(self):
        return self in [
            TrainMode.manipulate,
        ]

    def is_diffusion(self):
        return self in [
            TrainMode.diffusion,
            TrainMode.latent_diffusion,
        ]

    def is_autoenc(self):
        return self in [
            TrainMode.diffusion,
        ]

    def is_latent_diffusion(self):
        return self in [
            TrainMode.latent_diffusion,
        ]

    def use_latent_net(self):
        return self.is_latent_diffusion()

    def require_dataset_infer(self):
        """
        Whether training in this mode requires the latent variables to be
        inferred and available beforehand.
        """
        return self in [
            TrainMode.latent_diffusion,
            TrainMode.manipulate,
        ]


class ModelType(Enum):
    """
    Kinds of backbone models.
    """
    ddpm = 'ddpm'
    autoencoder = 'autoencoder'

    def has_autoenc(self):
        return self in [
            ModelType.autoencoder,
        ]

    def can_sample(self):
        return self in [ModelType.ddpm]


class ModelName(Enum):
    """
    List of all supported model classes.
    """
    beatgans_ddpm = 'beatgans_ddpm'
    beatgans_autoenc = 'beatgans_autoenc'


class ModelMeanType(Enum):
    """
    Which type of output the model predicts.
    """
    eps = 'eps'


class ModelVarType(Enum):
    """
    What is used as the model's output variance.

    The LEARNED_RANGE option has been added to allow the model to predict
    values between FIXED_SMALL and FIXED_LARGE, making its job easier.
    """
    fixed_small = 'fixed_small'
    fixed_large = 'fixed_large'


class LossType(Enum):
    mse = 'mse'
    l1 = 'l1'


class GenerativeType(Enum):
    """
    How a sample is generated.
    """
    ddpm = 'ddpm'
    ddim = 'ddim'


class OptimizerType(Enum):
    adam = 'adam'
    adamw = 'adamw'


class Activation(Enum):
    none = 'none'
    relu = 'relu'
    lrelu = 'lrelu'
    silu = 'silu'
    tanh = 'tanh'

    def get_act(self):
        if self == Activation.none:
            return nn.Identity()
        elif self == Activation.relu:
            return nn.ReLU()
        elif self == Activation.lrelu:
            return nn.LeakyReLU(negative_slope=0.2)
        elif self == Activation.silu:
            return nn.SiLU()
        elif self == Activation.tanh:
            return nn.Tanh()
        else:
            raise NotImplementedError()
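A minimal usage sketch of these enums (illustrative only, not part of the commit):

    from choices import Activation, TrainMode

    act = Activation.silu.get_act()   # returns an nn.SiLU() module
    mode = TrainMode.latent_diffusion
    assert mode.is_diffusion() and mode.use_latent_net()
    assert mode.require_dataset_infer()   # latent mode needs pre-inferred latents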
config.py  (new file, 303 lines)
@@ -0,0 +1,303 @@
from dataclasses import dataclass

from model.unet import ScaleAt
from model.latentnet import *
from diffusion.resample import UniformSampler
from diffusion.diffusion import space_timesteps
from typing import Tuple

from config_base import BaseConfig
from dataset import *
from diffusion import *
from diffusion.base import GenerativeType, LossType, ModelMeanType, ModelVarType, get_named_beta_schedule
from model import *
from choices import *


@dataclass
class PretrainConfig(BaseConfig):
    name: str
    path: str


@dataclass
class TrainConfig(BaseConfig):
    seed: int = 0
    train_mode: TrainMode = TrainMode.diffusion
    train_cond0_prob: float = 0
    train_pred_xstart_detach: bool = True
    train_interpolate_prob: float = 0
    train_interpolate_img: bool = False
    accum_batches: int = 1
    autoenc_mid_attn: bool = True
    batch_size: int = 512
    batch_size_eval: int = 4
    beatgans_gen_type: GenerativeType = GenerativeType.ddim
    beatgans_loss_type: LossType = LossType.mse
    beatgans_model_mean_type: ModelMeanType = ModelMeanType.eps
    beatgans_model_var_type: ModelVarType = ModelVarType.fixed_large
    beatgans_rescale_timesteps: bool = False
    latent_infer_path: str = None
    latent_znormalize: bool = False
    latent_gen_type: GenerativeType = GenerativeType.ddim
    latent_loss_type: LossType = LossType.mse
    latent_model_mean_type: ModelMeanType = ModelMeanType.eps
    latent_model_var_type: ModelVarType = ModelVarType.fixed_large
    latent_rescale_timesteps: bool = False
    latent_T_eval: int = 1_000
    latent_clip_sample: bool = False
    latent_beta_scheduler: str = 'linear'
    beta_scheduler: str = 'linear'
    data_name: str = ''
    data_val_name: str = None
    diffusion_type: str = None
    dropout: float = 0.1
    ema_decay: float = 0.9999
    eval_num_images: int = 5_000
    eval_every_samples: int = 200_000
    eval_ema_every_samples: int = 200_000
    fid_use_torch: bool = True
    fp16: bool = False
    grad_clip: float = 1
    img_size: int = 64
    lr: float = 0.0001
    optimizer: OptimizerType = OptimizerType.adam
    weight_decay: float = 0
    model_conf: ModelConfig = None
    model_name: ModelName = None
    model_type: ModelType = None
    net_attn: Tuple[int] = None
    net_beatgans_attn_head: int = 1
    net_beatgans_embed_channels: int = 128
    net_resblock_updown: bool = True
    net_enc_use_time: bool = False
    net_enc_pool: str = 'adaptivenonzero'
    net_beatgans_gradient_checkpoint: bool = False
    net_beatgans_resnet_two_cond: bool = False
    net_beatgans_resnet_use_zero_module: bool = True
    net_beatgans_resnet_scale_at: ScaleAt = ScaleAt.after_norm
    net_beatgans_resnet_cond_channels: int = None
    net_ch_mult: Tuple[int] = None
    net_ch: int = 64
    net_enc_attn: Tuple[int] = None
    net_enc_k: int = None
    net_enc_num_res_blocks: int = 2
    net_enc_channel_mult: Tuple[int] = None
    net_enc_grad_checkpoint: bool = False
    net_autoenc_stochastic: bool = False
    net_latent_activation: Activation = Activation.silu
    net_latent_channel_mult: Tuple[int] = (1, 2, 4)
    net_latent_condition_bias: float = 0
    net_latent_dropout: float = 0
    net_latent_layers: int = None
    net_latent_net_last_act: Activation = Activation.none
    net_latent_net_type: LatentNetType = LatentNetType.none
    net_latent_num_hid_channels: int = 1024
    net_latent_num_time_layers: int = 2
    net_latent_skip_layers: Tuple[int] = None
    net_latent_time_emb_channels: int = 64
    net_latent_use_norm: bool = False
    net_latent_time_last_act: bool = False
    net_num_res_blocks: int = 2
    net_num_input_res_blocks: int = None
    net_enc_num_cls: int = None
    num_workers: int = 4
    parallel: bool = False
    postfix: str = ''
    sample_size: int = 64
    sample_every_samples: int = 20_000
    save_every_samples: int = 100_000
    style_ch: int = 128
    T_eval: int = 1_000
    T_sampler: str = 'uniform'
    T: int = 1_000
    total_samples: int = 10_000_000
    warmup: int = 0
    pretrain: PretrainConfig = None
    continue_from: PretrainConfig = None
    eval_programs: Tuple[str] = None
    eval_path: str = None
    base_dir: str = 'checkpoints'
    name: str = ''
    # evaluated at class-definition time; scripts (e.g. main.py) override it
    logdir: str = f'{base_dir}{name}'
    num_users: int = 0

    def __post_init__(self):
        self.batch_size_eval = self.batch_size_eval or self.batch_size
        self.data_val_name = self.data_val_name or self.data_name

    def scale_up_gpus(self, num_gpus, num_nodes=1):
        self.eval_ema_every_samples *= num_gpus * num_nodes
        self.eval_every_samples *= num_gpus * num_nodes
        self.sample_every_samples *= num_gpus * num_nodes
        self.batch_size *= num_gpus * num_nodes
        self.batch_size_eval *= num_gpus * num_nodes
        return self

    @property
    def batch_size_effective(self):
        return self.batch_size * self.accum_batches

    def _make_diffusion_conf(self, T=None):
        if self.diffusion_type == 'beatgans':
            if self.beatgans_gen_type == GenerativeType.ddpm:
                section_counts = [T]
            elif self.beatgans_gen_type == GenerativeType.ddim:
                section_counts = f'ddim{T}'
            else:
                raise NotImplementedError()

            return SpacedDiffusionBeatGansConfig(
                gen_type=self.beatgans_gen_type,
                model_type=self.model_type,
                betas=get_named_beta_schedule(self.beta_scheduler, self.T),
                model_mean_type=self.beatgans_model_mean_type,
                model_var_type=self.beatgans_model_var_type,
                loss_type=self.beatgans_loss_type,
                rescale_timesteps=self.beatgans_rescale_timesteps,
                use_timesteps=space_timesteps(num_timesteps=self.T,
                                              section_counts=section_counts),
                fp16=self.fp16,
            )
        else:
            raise NotImplementedError()

    def _make_latent_diffusion_conf(self, T=None):
        if self.latent_gen_type == GenerativeType.ddpm:
            section_counts = [T]
        elif self.latent_gen_type == GenerativeType.ddim:
            section_counts = f'ddim{T}'
        else:
            raise NotImplementedError()

        return SpacedDiffusionBeatGansConfig(
            train_pred_xstart_detach=self.train_pred_xstart_detach,
            gen_type=self.latent_gen_type,
            model_type=ModelType.ddpm,
            betas=get_named_beta_schedule(self.latent_beta_scheduler, self.T),
            model_mean_type=self.latent_model_mean_type,
            model_var_type=self.latent_model_var_type,
            loss_type=self.latent_loss_type,
            rescale_timesteps=self.latent_rescale_timesteps,
            use_timesteps=space_timesteps(num_timesteps=self.T,
                                          section_counts=section_counts),
            fp16=self.fp16,
        )

    @property
    def model_out_channels(self):
        return 2

    @property
    def model_input_channels(self):
        return 2

    def make_T_sampler(self):
        if self.T_sampler == 'uniform':
            return UniformSampler(self.T)
        else:
            raise NotImplementedError()

    def make_diffusion_conf(self):
        return self._make_diffusion_conf(self.T)

    def make_eval_diffusion_conf(self):
        return self._make_diffusion_conf(T=self.T_eval)

    def make_latent_diffusion_conf(self):
        return self._make_latent_diffusion_conf(T=self.T)

    def make_latent_eval_diffusion_conf(self):
        return self._make_latent_diffusion_conf(T=self.latent_T_eval)

    def make_dataset(self, taskdata, tasklabels):
        return SimpleSet(taskdata, tasklabels, intflag=True)

    def make_model_conf(self):
        if self.model_name == ModelName.beatgans_ddpm:
            self.model_type = ModelType.ddpm
            self.model_conf = BeatGANsUNetConfig(
                attention_resolutions=self.net_attn,
                channel_mult=self.net_ch_mult,
                conv_resample=True,
                dims=2,
                dropout=self.dropout,
                embed_channels=self.net_beatgans_embed_channels,
                image_size=self.img_size,
                in_channels=self.model_input_channels,
                model_channels=self.net_ch,
                num_classes=None,
                num_head_channels=-1,
                num_heads_upsample=-1,
                num_heads=self.net_beatgans_attn_head,
                num_res_blocks=self.net_num_res_blocks,
                num_input_res_blocks=self.net_num_input_res_blocks,
                out_channels=self.model_out_channels,
                resblock_updown=self.net_resblock_updown,
                use_checkpoint=self.net_beatgans_gradient_checkpoint,
                use_new_attention_order=False,
                resnet_two_cond=self.net_beatgans_resnet_two_cond,
                resnet_use_zero_module=self.net_beatgans_resnet_use_zero_module,
            )
        elif self.model_name in [
                ModelName.beatgans_autoenc,
        ]:
            cls = BeatGANsAutoencConfig
            if self.model_name == ModelName.beatgans_autoenc:
                self.model_type = ModelType.autoencoder
            else:
                raise NotImplementedError()

            if self.net_latent_net_type == LatentNetType.none:
                latent_net_conf = None
            elif self.net_latent_net_type == LatentNetType.skip:
                latent_net_conf = MLPSkipNetConfig(
                    num_channels=self.style_ch,
                    skip_layers=self.net_latent_skip_layers,
                    num_hid_channels=self.net_latent_num_hid_channels,
                    num_layers=self.net_latent_layers,
                    num_time_emb_channels=self.net_latent_time_emb_channels,
                    activation=self.net_latent_activation,
                    use_norm=self.net_latent_use_norm,
                    condition_bias=self.net_latent_condition_bias,
                    dropout=self.net_latent_dropout,
                    last_act=self.net_latent_net_last_act,
                    num_time_layers=self.net_latent_num_time_layers,
                    time_last_act=self.net_latent_time_last_act,
                )
            else:
                raise NotImplementedError()

            self.model_conf = cls(
                attention_resolutions=self.net_attn,
                channel_mult=self.net_ch_mult,
                conv_resample=True,
                dims=1,
                dropout=self.dropout,
                embed_channels=self.net_beatgans_embed_channels,
                enc_out_channels=self.style_ch,
                enc_pool=self.net_enc_pool,
                enc_num_res_block=self.net_enc_num_res_blocks,
                enc_channel_mult=self.net_enc_channel_mult,
                enc_grad_checkpoint=self.net_enc_grad_checkpoint,
                enc_attn_resolutions=self.net_enc_attn,
                image_size=self.img_size,
                in_channels=self.model_input_channels,
                model_channels=self.net_ch,
                num_classes=None,
                num_head_channels=-1,
                num_heads_upsample=-1,
                num_heads=self.net_beatgans_attn_head,
                num_res_blocks=self.net_num_res_blocks,
                num_input_res_blocks=self.net_num_input_res_blocks,
                out_channels=self.model_out_channels,
                resblock_updown=self.net_resblock_updown,
                use_checkpoint=self.net_beatgans_gradient_checkpoint,
                use_new_attention_order=False,
                resnet_two_cond=self.net_beatgans_resnet_two_cond,
                resnet_use_zero_module=self.net_beatgans_resnet_use_zero_module,
                latent_net_conf=latent_net_conf,
                resnet_cond_channels=self.net_beatgans_resnet_cond_channels,
                num_users=self.num_users,
            )
        else:
            raise NotImplementedError(self.model_name)

        return self.model_conf
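A hedged sketch of how this config is meant to be driven (not part of the commit; the field values are made up for illustration, and it assumes the model/, diffusion/, and dataset modules import cleanly):

    from config import TrainConfig
    from choices import ModelName

    conf = TrainConfig(diffusion_type='beatgans',
                       model_name=ModelName.beatgans_autoenc,
                       net_attn=(8, ),
                       net_ch_mult=(1, 2, 4),
                       net_enc_channel_mult=(1, 2, 4),
                       T=1000, T_eval=20)
    conf.make_model_conf()    # fills in conf.model_type and conf.model_conf
    # a 20-step DDIM sampler spaced over the full 1000-step schedule:
    eval_sampler = conf.make_eval_diffusion_conf().make_sampler()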
config_base.py  (new executable file, 71 lines)
@@ -0,0 +1,71 @@
import json
import os
from copy import deepcopy
from dataclasses import dataclass


@dataclass
class BaseConfig:
    def clone(self):
        return deepcopy(self)

    def inherit(self, another):
        """inherit common keys from a given config"""
        common_keys = set(self.__dict__.keys()) & set(another.__dict__.keys())
        for k in common_keys:
            setattr(self, k, getattr(another, k))

    def propagate(self):
        """push down the configuration to all members"""
        for k, v in self.__dict__.items():
            if isinstance(v, BaseConfig):
                v.inherit(self)
                v.propagate()

    def save(self, save_path):
        """save config to a json file"""
        dirname = os.path.dirname(save_path)
        if dirname and not os.path.exists(dirname):
            os.makedirs(dirname)
        conf = self.as_dict_jsonable()
        with open(save_path, 'w') as f:
            json.dump(conf, f)

    def load(self, load_path):
        """load a json config"""
        with open(load_path) as f:
            conf = json.load(f)
        self.from_dict(conf)

    def from_dict(self, d, strict=False):
        for k, v in d.items():
            if not hasattr(self, k):
                if strict:
                    raise ValueError(f"loading extra '{k}'")
                else:
                    print(f"loading extra '{k}'")
                continue
            if isinstance(self.__dict__[k], BaseConfig):
                self.__dict__[k].from_dict(v)
            else:
                self.__dict__[k] = v

    def as_dict_jsonable(self):
        conf = {}
        for k, v in self.__dict__.items():
            if isinstance(v, BaseConfig):
                conf[k] = v.as_dict_jsonable()
            else:
                if jsonable(v):
                    conf[k] = v
                else:
                    # ignore unserializable values (e.g. tensors, callables)
                    pass
        return conf


def jsonable(x):
    try:
        json.dumps(x)
        return True
    except TypeError:
        return False
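A small round-trip sketch of BaseConfig's JSON persistence (illustrative, not part of the commit):

    from dataclasses import dataclass
    from config_base import BaseConfig

    @dataclass
    class DemoConfig(BaseConfig):
        lr: float = 1e-4
        name: str = 'demo'

    c = DemoConfig(lr=3e-4)
    c.save('out/demo.json')    # writes {"lr": 0.0003, "name": "demo"}
    d = DemoConfig()
    d.load('out/demo.json')    # d.lr is now 3e-4; unknown keys only print a warning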
dataset.py  (new executable file, 21 lines)
@@ -0,0 +1,21 @@
from torch.utils.data import Dataset
import torch
import pandas as pd


def loadDataset(conf):
    # plug in the function that loads your own dataset here;
    # eval() cannot execute assignments, so resolve the loader by name and call it
    taskdata, tasklabels = eval('load%s' % conf.pretrainDataset)(conf)
    tasklabels = pd.DataFrame(tasklabels, columns=['user'])
    print('taskdata.shape:', taskdata.shape)  # (N, 2, window_length*sample_freq)
    return taskdata, tasklabels


class SimpleSet(Dataset):
    def __init__(self, data, labels, intflag=True):
        self.data = torch.tensor(data, dtype=torch.float)
        if intflag:
            self.label = torch.tensor(labels, dtype=torch.long)
        else:
            self.label = torch.tensor(labels, dtype=torch.float)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index], self.label[index]
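Usage sketch for SimpleSet with random stand-in data (shapes follow the comment above; not part of the commit):

    import numpy as np
    from dataset import SimpleSet

    data = np.random.randn(8, 2, 100).astype('float32')   # (N, 2, window_length*sample_freq)
    labels = np.arange(8) % 4                              # integer user ids
    ds = SimpleSet(data, labels, intflag=True)
    x, y = ds[0]    # float tensor of shape (2, 100), long scalar label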
diffusion/__init__.py  (new file, 6 lines)
@@ -0,0 +1,6 @@
from typing import Union

from .diffusion import SpacedDiffusionBeatGans, SpacedDiffusionBeatGansConfig

Sampler = Union[SpacedDiffusionBeatGans]
SamplerConfig = Union[SpacedDiffusionBeatGansConfig]
diffusion/base.py  (new executable file, 1086 lines)
@@ -0,0 +1,1086 @@
(File diff suppressed because it is too large.)
diffusion/diffusion.py  (new executable file, 154 lines)
@@ -0,0 +1,154 @@
from .base import *
from dataclasses import dataclass


def space_timesteps(num_timesteps, section_counts):
    """
    Create a list of timesteps to use from an original diffusion process,
    given the number of timesteps we want to take from equally-sized portions
    of the original process.

    For example, if there are 300 timesteps and the section counts are
    [10, 15, 20], then the first 100 timesteps are strided to be 10 timesteps,
    the second 100 are strided to be 15 timesteps, and the final 100 are
    strided to be 20.

    If the stride is a string starting with "ddim", then the fixed striding
    from the DDIM paper is used, and only one section is allowed.

    :param num_timesteps: the number of diffusion steps in the original
                          process to divide up.
    :param section_counts: either a list of numbers, or a string containing
                           comma-separated numbers, indicating the step count
                           per section. As a special case, use "ddimN" where N
                           is a number of steps to use the striding from the
                           DDIM paper.
    :return: a set of diffusion steps from the original process to use.
    """
    if isinstance(section_counts, str):
        if section_counts.startswith("ddim"):
            desired_count = int(section_counts[len("ddim"):])
            for i in range(1, num_timesteps):
                if len(range(0, num_timesteps, i)) == desired_count:
                    return set(range(0, num_timesteps, i))
            raise ValueError(
                f"cannot create exactly {desired_count} steps with an integer stride"
            )
        section_counts = [int(x) for x in section_counts.split(",")]
    size_per = num_timesteps // len(section_counts)
    extra = num_timesteps % len(section_counts)
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(
                f"cannot divide section of {size} steps into {section_count}")
        if section_count <= 1:
            frac_stride = 1
        else:
            frac_stride = (size - 1) / (section_count - 1)
        cur_idx = 0.0
        taken_steps = []
        for _ in range(section_count):
            taken_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        all_steps += taken_steps
        start_idx += size
    return set(all_steps)


@dataclass
class SpacedDiffusionBeatGansConfig(GaussianDiffusionBeatGansConfig):
    use_timesteps: Tuple[int] = None

    def make_sampler(self):
        return SpacedDiffusionBeatGans(self)


class SpacedDiffusionBeatGans(GaussianDiffusionBeatGans):
    """
    A diffusion process which can skip steps in a base diffusion process.

    :param use_timesteps: a collection (sequence or set) of timesteps from the
                          original diffusion process to retain.
    :param kwargs: the kwargs to create the base diffusion process.
    """
    def __init__(self, conf: SpacedDiffusionBeatGansConfig):
        self.conf = conf
        self.use_timesteps = set(conf.use_timesteps)
        self.timestep_map = []
        self.original_num_steps = len(conf.betas)

        base_diffusion = GaussianDiffusionBeatGans(conf)
        last_alpha_cumprod = 1.0
        new_betas = []
        for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
            if i in self.use_timesteps:
                # compute the betas of the retained timesteps
                new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
                last_alpha_cumprod = alpha_cumprod
                self.timestep_map.append(i)
        conf.betas = np.array(new_betas)
        super().__init__(conf)

    def p_mean_variance(self, model: Model, *args, **kwargs):
        return super().p_mean_variance(self._wrap_model(model), *args,
                                       **kwargs)

    def training_losses(self, model: Model, *args, **kwargs):
        return super().training_losses(self._wrap_model(model), *args,
                                       **kwargs)

    def condition_mean(self, cond_fn, *args, **kwargs):
        return super().condition_mean(self._wrap_model(cond_fn), *args,
                                      **kwargs)

    def condition_score(self, cond_fn, *args, **kwargs):
        return super().condition_score(self._wrap_model(cond_fn), *args,
                                       **kwargs)

    def _wrap_model(self, model: Model):
        if isinstance(model, _WrappedModel):
            return model
        return _WrappedModel(model, self.timestep_map, self.rescale_timesteps,
                             self.original_num_steps)

    def _scale_timesteps(self, t):
        # scaling is done by the wrapped model instead
        return t


class _WrappedModel:
    """
    Converts the supplied t's (in the spaced timescale) back to the original
    timescale before calling the wrapped model.
    """
    def __init__(self, model, timestep_map, rescale_timesteps,
                 original_num_steps):
        self.model = model
        self.timestep_map = timestep_map
        self.rescale_timesteps = rescale_timesteps
        self.original_num_steps = original_num_steps

    def forward(self, x, t, t_cond=None, **kwargs):
        """
        Args:
            t: t's with different ranges (can be << T due to a smaller eval T)
               that need to be converted to the original t's
            t_cond: the same as t but can take different values
        """
        map_tensor = th.tensor(self.timestep_map,
                               device=t.device,
                               dtype=t.dtype)

        def do(t):
            new_ts = map_tensor[t]
            if self.rescale_timesteps:
                new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
            return new_ts

        if t_cond is not None:
            t_cond = do(t_cond)
        return self.model(x=x, t=do(t), t_cond=t_cond, **kwargs)

    def __getattr__(self, name):
        # delegate attribute access to the wrapped model
        if hasattr(self.model, name):
            func = getattr(self.model, name)
            return func
        raise AttributeError(name)
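Two worked examples of space_timesteps (illustrative, not part of the commit):

    from diffusion.diffusion import space_timesteps

    # DDIM striding: exactly 20 steps out of 100, i.e. every 5th step
    assert space_timesteps(100, 'ddim20') == set(range(0, 100, 5))

    # four equal sections of 25 steps, 10 retained from each => 40 steps total
    steps = space_timesteps(100, [10, 10, 10, 10])
    assert len(steps) == 40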
diffusion/resample.py  (new file, 62 lines)
@@ -0,0 +1,62 @@
from abc import ABC, abstractmethod

import numpy as np
import torch as th


def create_named_schedule_sampler(name, diffusion):
    """
    Create a ScheduleSampler from a library of pre-defined samplers.

    :param name: the name of the sampler.
    :param diffusion: the diffusion object to sample for.
    """
    if name == "uniform":
        return UniformSampler(diffusion)
    else:
        raise NotImplementedError(f"unknown schedule sampler: {name}")


class ScheduleSampler(ABC):
    """
    A distribution over timesteps in the diffusion process, intended to reduce
    variance of the objective.

    By default, samplers perform unbiased importance sampling, in which the
    objective's mean is unchanged.
    However, subclasses may override sample() to change how the resampled
    terms are reweighted, allowing for actual changes in the objective.
    """
    @abstractmethod
    def weights(self):
        """
        Get a numpy array of weights, one per diffusion step.

        The weights needn't be normalized, but must be positive.
        """

    def sample(self, batch_size, device):
        """
        Importance-sample timesteps for a batch.

        :param batch_size: the number of timesteps.
        :param device: the torch device to save to.
        :return: a tuple (timesteps, weights):
                 - timesteps: a tensor of timestep indices.
                 - weights: a tensor of weights to scale the resulting losses.
        """
        w = self.weights()
        p = w / np.sum(w)
        indices_np = np.random.choice(len(p), size=(batch_size, ), p=p)
        indices = th.from_numpy(indices_np).long().to(device)
        weights_np = 1 / (len(p) * p[indices_np])
        weights = th.from_numpy(weights_np).float().to(device)
        return indices, weights


class UniformSampler(ScheduleSampler):
    def __init__(self, num_timesteps):
        self._weights = np.ones([num_timesteps])

    def weights(self):
        return self._weights
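Usage sketch (illustrative, not part of the commit): with uniform weights every timestep is equally likely and the importance weights are all 1.

    import torch
    from diffusion.resample import UniformSampler

    sampler = UniformSampler(1000)
    t, weight = sampler.sample(batch_size=16, device=torch.device('cpu'))
    # t: (16,) long tensor of timesteps drawn uniformly from [0, 1000)
    # weight: (16,) float tensor of ones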
dist_utils.py  (new executable file, 42 lines)
@@ -0,0 +1,42 @@
from typing import List

from torch import distributed


def barrier():
    if distributed.is_initialized():
        distributed.barrier()
    else:
        pass


def broadcast(data, src):
    if distributed.is_initialized():
        distributed.broadcast(data, src)
    else:
        pass


def all_gather(data: List, src):
    if distributed.is_initialized():
        distributed.all_gather(data, src)
    else:
        data[0] = src


def get_rank():
    if distributed.is_initialized():
        return distributed.get_rank()
    else:
        return 0


def get_world_size():
    if distributed.is_initialized():
        return distributed.get_world_size()
    else:
        return 1


def chunk_size(size, rank, world_size):
    extra = rank < size % world_size
    return size // world_size + extra
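A worked example of chunk_size, which spreads the remainder over the lowest ranks (illustrative, not part of the commit):

    from dist_utils import chunk_size

    # 10 items across 3 workers: ranks 0, 1, 2 receive 4, 3, 3 items
    assert [chunk_size(10, rank, 3) for rank in range(3)] == [4, 3, 3]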
environment.yml  (new file, 160 lines)
@@ -0,0 +1,160 @@
name: dismouse
channels:
  - defaults
  - conda-forge
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - ca-certificates=2023.12.12=h06a4308_0
  - ld_impl_linux-64=2.38=h1181459_1
  - libffi=3.4.4=h6a678d5_0
  - libgcc-ng=11.2.0=h1234567_1
  - libgomp=11.2.0=h1234567_1
  - libstdcxx-ng=11.2.0=h1234567_1
  - ncurses=6.4=h6a678d5_0
  - openssl=3.0.12=h7f8727e_0
  - pip=23.3.1=py38h06a4308_0
  - python=3.8.18=h955ad1f_0
  - readline=8.2=h5eee18b_0
  - setuptools=68.2.2=py38h06a4308_0
  - sqlite=3.41.2=h5eee18b_0
  - tk=8.6.12=h1ccaba5_0
  - wheel=0.41.2=py38h06a4308_0
  - xz=5.4.5=h5eee18b_0
  - zlib=1.2.13=h5eee18b_0
  - pip:
      - absl-py==2.0.0
      - aiohttp==3.9.1
      - aiosignal==1.3.1
      - appdirs==1.4.4
      - async-timeout==4.0.3
      - attrs==23.2.0
      - bleach==6.1.0
      - bokeh==3.1.1
      - cachetools==5.3.2
      - certifi==2023.11.17
      - charset-normalizer==3.3.2
      - click==8.1.7
      - cloudpickle==3.0.0
      - colorcet==3.1.0
      - contourpy==1.1.1
      - cycler==0.12.1
      - cython==0.29.37
      - dask==2023.5.0
      - datashader==0.15.2
      - datashape==0.5.2
      - docker-pycreds==0.4.0
      - filelock==3.13.1
      - fonttools==4.47.2
      - frozenlist==1.4.1
      - fsspec==2023.12.2
      - ftfy==6.1.3
      - future==0.18.3
      - gitdb==4.0.11
      - gitpython==3.1.41
      - google-auth==2.26.2
      - google-auth-oauthlib==1.0.0
      - grpcio==1.60.0
      - hdbscan==0.8.33
      - hmmlearn==0.3.2
      - holoviews==1.17.1
      - idna==3.6
      - imageio==2.34.0
      - importlib-metadata==7.0.1
      - importlib-resources==6.1.1
      - jinja2==3.1.3
      - joblib==1.3.2
      - kiwisolver==1.4.5
      - kornia==0.7.1
      - lazy-loader==0.3
      - lightning-utilities==0.10.1
      - linkify-it-py==2.0.3
      - llvmlite==0.41.1
      - lmdb==1.2.1
      - locket==1.0.0
      - lpips==0.1.4
      - markdown==3.5.2
      - markdown-it-py==3.0.0
      - markupsafe==2.1.3
      - matplotlib==3.7.4
      - mdit-py-plugins==0.4.0
      - mdurl==0.1.2
      - mpmath==1.3.0
      - multidict==6.0.4
      - multipledispatch==1.0.0
      - networkx==3.1
      - numba==0.58.1
      - numpy==1.24.4
      - nvidia-cublas-cu12==12.1.3.1
      - nvidia-cuda-cupti-cu12==12.1.105
      - nvidia-cuda-nvrtc-cu12==12.1.105
      - nvidia-cuda-runtime-cu12==12.1.105
      - nvidia-cudnn-cu12==8.9.2.26
      - nvidia-cufft-cu12==11.0.2.54
      - nvidia-curand-cu12==10.3.2.106
      - nvidia-cusolver-cu12==11.4.5.107
      - nvidia-cusparse-cu12==12.1.0.106
      - nvidia-nccl-cu12==2.19.3
      - nvidia-nvjitlink-cu12==12.3.101
      - nvidia-nvtx-cu12==12.1.105
      - oauthlib==3.2.2
      - packaging==23.2
      - pandas==1.5.3
      - panel==1.2.3
      - param==2.0.2
      - partd==1.4.1
      - pillow==10.2.0
      - protobuf==4.25.2
      - psutil==5.9.8
      - pyasn1==0.5.1
      - pyasn1-modules==0.3.0
      - pyct==0.5.0
      - pydeprecate==0.3.1
      - pynndescent==0.5.11
      - pyparsing==3.1.1
      - python-crfsuite==0.9.10
      - python-dateutil==2.8.2
      - pytorch-fid==0.2.0
      - pytorch-lightning==1.4.5
      - pytz==2023.3.post1
      - pyviz-comms==3.0.1
      - pywavelets==1.4.1
      - pyyaml==6.0.1
      - regex==2023.12.25
      - requests==2.31.0
      - requests-oauthlib==1.3.1
      - rsa==4.9
      - scikit-image==0.21.0
      - scikit-learn==1.3.2
      - scipy==1.10.1
      - sentry-sdk==1.39.2
      - setproctitle==1.3.3
      - six==1.16.0
      - sklearn-crfsuite==0.3.6
      - smmap==5.0.1
      - sympy==1.12
      - tabulate==0.9.0
      - tensorboard==2.14.0
      - tensorboard-data-server==0.7.2
      - threadpoolctl==3.2.0
      - tifffile==2023.7.10
      - toolz==0.12.1
      - torch==1.8.1
      - torchmetrics==0.5.0
      - torchvision==0.9.1
      - tornado==6.4
      - tqdm==4.66.1
      - triton==2.2.0
      - typing-extensions==4.9.0
      - tzdata==2023.4
      - uc-micro-py==1.0.3
      - umap-learn==0.5.5
      - urllib3==2.1.0
      - wandb==0.16.2
      - wcwidth==0.2.13
      - webencodings==0.5.1
      - werkzeug==3.0.1
      - xarray==2023.1.0
      - xyzservices==2023.10.1
      - yarl==1.9.4
      - zipp==3.17.0
experiment.py  (new executable file, 378 lines)
@@ -0,0 +1,378 @@
import copy, wandb
import os
import random

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import *
from torch.cuda import amp
from torch.optim.optimizer import Optimizer
from torch.utils.data.dataset import TensorDataset

from config import *
from dataset import *
from dist_utils import *


def MakeDir(dirName):
    if not os.path.exists(dirName):
        os.makedirs(dirName)


class LitModel(pl.LightningModule):
    def __init__(self, conf: TrainConfig, betas):
        super().__init__()

        self.save_hyperparameters({k: v for (k, v) in vars(conf).items() if not callable(v)})
        self.save_hyperparameters(conf.as_dict_jsonable())

        assert conf.train_mode != TrainMode.manipulate
        if conf.seed is not None:
            pl.seed_everything(conf.seed)

        conf.betas = betas
        self.conf = conf
        self.model = conf.make_model_conf().make_model()

        self.ema_model = copy.deepcopy(self.model)
        self.ema_model.requires_grad_(False)
        self.ema_model.eval()

        model_size = 0
        for param in self.model.parameters():
            model_size += param.data.nelement()
        print('Model params: %.2f M' % (model_size / 1024 / 1024))

        self.sampler = conf.make_diffusion_conf().make_sampler()
        self.eval_sampler = conf.make_eval_diffusion_conf().make_sampler()

        self.T_sampler = conf.make_T_sampler()

        if conf.train_mode.use_latent_net():
            self.latent_sampler = conf.make_latent_diffusion_conf(
            ).make_sampler()
            self.eval_latent_sampler = conf.make_latent_eval_diffusion_conf(
            ).make_sampler()
        else:
            self.latent_sampler = None
            self.eval_latent_sampler = None

        if conf.pretrain is not None:
            print(f'loading pretrain ... {conf.pretrain.name}')
            state = torch.load(conf.pretrain.path, map_location='cpu')
            print('step:', state['global_step'])
            self.load_state_dict(state['state_dict'], strict=False)

        if conf.latent_infer_path is not None:
            print('loading latent stats ...')
            state = torch.load(conf.latent_infer_path)
            self.conds = state['conds']
        else:
            # inferred lazily in train_dataloader() when needed
            self.conds = None
            self.conds_mean = None
            self.conds_std = None

    def normalize(self, cond):
        cond = (cond - self.conds_mean.to(self.device)) / self.conds_std.to(
            self.device)
        return cond

    def denormalize(self, cond):
        cond = (cond * self.conds_std.to(self.device)) + self.conds_mean.to(
            self.device)
        return cond

    def render(self, noise, cond=None, T=None):
        if T is None:
            sampler = self.eval_sampler
        else:
            sampler = self.conf._make_diffusion_conf(T).make_sampler()

        if cond is not None:
            pred_img = render_condition(self.conf,
                                        self.ema_model,
                                        noise,
                                        sampler=sampler,
                                        cond=cond)
        else:
            pred_img = render_uncondition(self.conf,
                                          self.ema_model,
                                          noise,
                                          sampler=sampler,
                                          latent_sampler=None)
        return pred_img

    def encode(self, x):
        assert self.conf.model_type.has_autoenc()
        cond = self.ema_model.encoder.forward(x)
        return cond

    def encode_stochastic(self, x, cond, T=None):
        if T is None:
            sampler = self.eval_sampler
        else:
            sampler = self.conf._make_diffusion_conf(T).make_sampler()

        out = sampler.ddim_reverse_sample_loop(self.ema_model,
                                               x,
                                               model_kwargs={'cond': cond})
        return out['sample'], out['xstart_t']

    def forward(self, noise=None, x_start=None, ema_model: bool = False):
        with amp.autocast(False):
            if ema_model:
                model = self.ema_model
            else:
                model = self.model
            gen = self.eval_sampler.sample(model=model,
                                           noise=noise,
                                           x_start=x_start)
            return gen

    def setup(self, stage=None) -> None:
        """
        make datasets & seed each worker separately
        """
        if self.conf.seed is not None:
            seed = self.conf.seed * get_world_size() + self.global_rank
            np.random.seed(seed)
            random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            torch.backends.cudnn.benchmark = False
            torch.backends.cudnn.deterministic = True

        taskdata, tasklabels = loadDataset(self.conf)
        assert self.conf.num_users == tasklabels['user'].nunique()

        # map user ids to contiguous integer codes
        tasklabels = pd.DataFrame(tasklabels['user'].astype('category').cat.codes.values)[0].values.astype(int)
        assert self.conf.num_users == len(np.unique(tasklabels))

        self.train_data = self.conf.make_dataset(taskdata, tasklabels)
        self.val_data = self.train_data

    def train_dataloader(self):
        """
        Return the dataloader:
        in diffusion mode => return the image dataset;
        in latent mode => return the inferred latent dataset.
        """
        if self.conf.train_mode.require_dataset_infer():
            if self.conds is None:
                self.conds = self.infer_whole_dataset()
                self.conds_mean.data = self.conds.float().mean(dim=0,
                                                               keepdim=True)
                self.conds_std.data = self.conds.float().std(dim=0,
                                                             keepdim=True)
            print('mean:', self.conds_mean.mean(), 'std:',
                  self.conds_std.mean())

            conf = self.conf.clone()
            conf.batch_size = self.batch_size
            data = TensorDataset(self.conds)
            return conf.make_loader(data, shuffle=True)
        else:
            return torch.utils.data.DataLoader(self.train_data, batch_size=self.conf.batch_size, shuffle=True)

    @property
    def batch_size(self):
        """
        local batch size for each worker
        """
        ws = get_world_size()
        assert self.conf.batch_size % ws == 0
        return self.conf.batch_size // ws

    @property
    def num_samples(self):
        """
        (global) batch size * iterations
        """
        return self.global_step * self.conf.batch_size_effective

    def is_last_accum(self, batch_idx):
        """
        Is this the last gradient-accumulation loop?
        Used with gradient_accum > 1 to see whether the optimizer will perform
        "step" in this iteration or not.
        """
        return (batch_idx + 1) % self.conf.accum_batches == 0

    def training_step(self, batch, batch_idx):
        """
        Given an input, calculate the loss function;
        no optimization happens at this stage.
        """
        with amp.autocast(False):
            if self.conf.train_mode.require_dataset_infer():
                cond = batch[0]
                if self.conf.latent_znormalize:
                    cond = (cond - self.conds_mean.to(
                        self.device)) / self.conds_std.to(self.device)
            else:
                imgs = batch[0]
                x_start = imgs

            if self.conf.train_mode == TrainMode.diffusion:
                t, weight = self.T_sampler.sample(len(x_start), x_start.device)

                losses = self.sampler.training_losses(model=self.model,
                                                      x_start=x_start,
                                                      t=t,
                                                      user_label=batch[1],
                                                      lossbetas=self.conf.betas)
            elif self.conf.train_mode.is_latent_diffusion():
                t, weight = self.T_sampler.sample(len(cond), cond.device)
                latent_losses = self.latent_sampler.training_losses(
                    model=self.model.latent_net, x_start=cond, t=t)
                losses = {
                    'latent': latent_losses['loss'],
                    'loss': latent_losses['loss']
                }
            else:
                raise NotImplementedError()

            loss = losses['loss'].mean()
            self.log("train_loss", loss)

        return {'loss': loss}

    def on_train_batch_end(self, outputs, batch, batch_idx: int,
                           dataloader_idx: int) -> None:
        """
        after each training step
        """
        if self.is_last_accum(batch_idx):
            # save a checkpoint at the end of every 10th epoch
            if (batch_idx == len(self.train_dataloader()) - 1) and ((self.current_epoch + 1) % 10 == 0):
                save_path = os.path.join(self.conf.logdir, 'epoch%d.ckpt' % self.current_epoch)
                torch.save({
                    'state_dict': self.state_dict(),
                    'global_step': self.global_step,
                    'loss': outputs['loss'],
                }, save_path)

            if self.conf.train_mode == TrainMode.latent_diffusion:
                # only the latent net is trained, so only update its EMA
                ema(self.model.latent_net, self.ema_model.latent_net,
                    self.conf.ema_decay)
            else:
                ema(self.model, self.ema_model, self.conf.ema_decay)

    def on_before_optimizer_step(self, optimizer: Optimizer,
                                 optimizer_idx: int) -> None:
        if self.conf.grad_clip > 0:
            params = [
                p for group in optimizer.param_groups for p in group['params']
            ]
            torch.nn.utils.clip_grad_norm_(params,
                                           max_norm=self.conf.grad_clip)

    def configure_optimizers(self):
        out = {}
        if self.conf.optimizer == OptimizerType.adam:
            optim = torch.optim.Adam(self.model.parameters(),
                                     lr=self.conf.lr,
                                     weight_decay=self.conf.weight_decay)
        elif self.conf.optimizer == OptimizerType.adamw:
            optim = torch.optim.AdamW(self.model.parameters(),
                                      lr=self.conf.lr,
                                      weight_decay=self.conf.weight_decay)
        else:
            raise NotImplementedError()
        out['optimizer'] = optim
        if self.conf.warmup > 0:
            sched = torch.optim.lr_scheduler.LambdaLR(optim,
                                                      lr_lambda=WarmupLR(
                                                          self.conf.warmup))
            out['lr_scheduler'] = {
                'scheduler': sched,
                'interval': 'step',
            }
        return out

    def split_tensor(self, x):
        """
        extract the tensor for the corresponding "worker" in the batch dimension

        Args:
            x: (n, c)

        Returns: x: (n_local, c)
        """
        n = len(x)
        rank = self.global_rank
        world_size = get_world_size()
        per_rank = n // world_size
        return x[rank * per_rank:(rank + 1) * per_rank]


def ema(source, target, decay):
    source_dict = source.state_dict()
    target_dict = target.state_dict()
    for key in source_dict.keys():
        target_dict[key].data.copy_(target_dict[key].data * decay +
                                    source_dict[key].data * (1 - decay))


class WarmupLR:
    def __init__(self, warmup) -> None:
        self.warmup = warmup

    def __call__(self, step):
        return min(step, self.warmup) / self.warmup


def is_time(num_samples, every, step_size):
    closest = (num_samples // every) * every
    return num_samples - closest < step_size


def train(conf: TrainConfig, model: LitModel, gpus, nodes=1):
    checkpoint = ModelCheckpoint(dirpath=conf.logdir,
                                 save_last=True,
                                 save_top_k=1,
                                 every_n_train_steps=conf.save_every_samples //
                                 conf.batch_size_effective)
    checkpoint_path = f'{conf.logdir}last.ckpt'
    print('ckpt path:', checkpoint_path)
    if os.path.exists(checkpoint_path):
        resume = checkpoint_path
        print('resume!')
    else:
        if conf.continue_from is not None:
            resume = conf.continue_from.path
        else:
            resume = None

    plugins = []
    if len(gpus) == 1 and nodes == 1:
        accelerator = None
    else:
        accelerator = 'ddp'
        from pytorch_lightning.plugins import DDPPlugin

        plugins.append(DDPPlugin(find_unused_parameters=False))

    wandb_logger = pl_loggers.WandbLogger(project='dismouse',
                                          name='%s_%s' % (model.conf.pretrainDataset, conf.logdir.split('/')[-2]),
                                          log_model=True,
                                          save_dir=conf.logdir,
                                          dir=conf.logdir,
                                          config=vars(model.conf),
                                          save_code=True,
                                          settings=wandb.Settings(code_dir="."))

    trainer = pl.Trainer(
        max_steps=conf.total_samples // conf.batch_size_effective,
        resume_from_checkpoint=resume,
        gpus=gpus,
        num_nodes=nodes,
        accelerator=accelerator,
        precision=16 if conf.fp16 else 32,
        callbacks=[
            checkpoint,
            LearningRateMonitor(),
        ],
        replace_sampler_ddp=True,
        logger=wandb_logger,
        accumulate_grad_batches=conf.accum_batches,
        plugins=plugins,
    )

    trainer.fit(model)
    wandb.finish()
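A small sketch of how WarmupLR plugs into LambdaLR, as done in configure_optimizers above (illustrative, not part of the commit):

    import torch
    from experiment import WarmupLR

    opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
    sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=WarmupLR(warmup=100))
    # the lr factor ramps linearly: step 0 -> 0.0, step 50 -> 0.5, step >= 100 -> 1.0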
main.py  (new file, 29 lines)
@@ -0,0 +1,29 @@
import os
from templates import *

if __name__ == '__main__':
    gpus = [0]

    conf = mouse_autoenc('trainDiff')

    betas = {
        'recon': 1,
        'noise': 1,
        'user': 0.01,
        'nonuser': 0.01,
        'mi': 0.01,
    }
    betastr = ''
    for k, v in betas.items():
        betastr += f'{k}{v}_'
    betastr = betastr[:-1]

    diffmodel = LitModel(conf, betas)

    conf.logdir = f'{conf.logdir}/mouse_autoenc/{conf.pretrainDataset}/{betastr}/embDim{conf.net_beatgans_embed_channels}/win{conf.AEwin}/slid{conf.slid}/GRL/'
    MakeDir(conf.logdir)
    os.environ['WANDB_CACHE_DIR'] = conf.logdir
    os.environ['WANDB_DATA_DIR'] = conf.logdir
    os.environ['WANDB_IGNORE_GLOBS'] = '*.ckpt'

    train(conf, diffmodel, gpus=gpus)
model/MI.py  (new file, 176 lines)
@@ -0,0 +1,176 @@
'''
Differentiable approximation to the mutual information (MI) metric.
Implementation in PyTorch.
'''

# Imports #
# ----------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
import os


# Note: This code snippet was taken from the discussion found at:
# https://discuss.pytorch.org/t/differentiable-torch-histc/25865/2
# By Tony-Y
class SoftHistogram1D(nn.Module):
    '''
    Differentiable 1D histogram calculation (supported via pytorch's autograd)
    input:
        x     - N x D array, where N is the batch size and D is the length of each data series
        bins  - Number of bins for the histogram
        min   - Scalar min value to be included in the histogram
        max   - Scalar max value to be included in the histogram
        sigma - Scalar smoothing factor for the bin approximation via sigmoid functions.
                Larger values correspond to sharper edges, and thus yield a more accurate approximation
    output:
        N x bins array, where each row is a histogram
    '''

    def __init__(self, bins=50, min=0, max=1, sigma=10):
        super(SoftHistogram1D, self).__init__()
        self.bins = bins
        self.min = min
        self.max = max
        self.sigma = sigma
        self.delta = float(max - min) / float(bins)
        self.centers = float(min) + self.delta * (torch.arange(bins).float() + 0.5)  # Bin centers
        self.centers = nn.Parameter(self.centers, requires_grad=False)  # Wrap to allow for cuda support

    def forward(self, x):
        # Replicate x and for each row remove center
        x = torch.unsqueeze(x, 1) - torch.unsqueeze(self.centers, 1)

        # Bin approximation using a sigmoid function
        x = torch.sigmoid(self.sigma * (x + self.delta / 2)) - torch.sigmoid(self.sigma * (x - self.delta / 2))

        # Sum along the non-batch dimensions
        x = x.sum(dim=-1)
        # x = x / x.sum(dim=-1).unsqueeze(1)  # normalization
        return x


# Note: This is an extension to the 2D case of the previous code snippet
class SoftHistogram2D(nn.Module):
    '''
    Differentiable 2D (joint) histogram calculation (supported via pytorch's autograd)
    input:
        x, y  - N x D array, where N is the batch size and D is the length of each data series
                (i.e. vectorized image or vectorized 3D volume)
        bins  - Number of bins for the histogram
        min   - Scalar min value to be included in the histogram
        max   - Scalar max value to be included in the histogram
        sigma - Scalar smoothing factor for the bin approximation via sigmoid functions.
                Larger values correspond to sharper edges, and thus yield a more accurate approximation
    output:
        N x bins array, where each row is a histogram
    '''

    def __init__(self, bins=50, min=0, max=1, sigma=10):
        super(SoftHistogram2D, self).__init__()
        self.bins = bins
        self.min = min
        self.max = max
        self.sigma = sigma
        self.delta = float(max - min) / float(bins)
        self.centers = float(min) + self.delta * (torch.arange(bins).float() + 0.5)  # Bin centers
        self.centers = nn.Parameter(self.centers, requires_grad=False)  # Wrap to allow for cuda support

    def forward(self, x, y):
        assert x.size() == y.size(), "(SoftHistogram2D) x and y sizes do not match"

        # Replicate x and for each row remove center
        x = torch.unsqueeze(x, 1) - torch.unsqueeze(self.centers, 1)
        y = torch.unsqueeze(y, 1) - torch.unsqueeze(self.centers, 1)

        # Bin approximation using a sigmoid function (can be sigma_x and sigma_y respectively - same for delta)
        x = torch.sigmoid(self.sigma * (x + self.delta / 2)) - torch.sigmoid(self.sigma * (x - self.delta / 2))
        y = torch.sigmoid(self.sigma * (y + self.delta / 2)) - torch.sigmoid(self.sigma * (y - self.delta / 2))

        # Batched matrix multiplication - this way we sum jointly
        z = torch.matmul(x, y.permute((0, 2, 1)))
        return z


class MI_pytorch(nn.Module):
    '''
    This class is a pytorch implementation of the mutual information (MI) calculation between two images.
    This is an approximation, as the images' histograms rely on differentiable approximations of rectangular windows.

        I(X, Y) = H(X) + H(Y) - H(X, Y) = sum(sum(p(X, Y) * log(p(X, Y) / (p(X) * p(Y)))))

    where H(X) = -sum(p(x) * log(p(x))) is the entropy
    '''

    def __init__(self, bins=50, min=0, max=1, sigma=10, reduction='sum'):
        super(MI_pytorch, self).__init__()
        self.bins = bins
        self.min = min
        self.max = max
        self.sigma = sigma
        self.reduction = reduction

        # 2D joint histogram
        self.hist2d = SoftHistogram2D(bins, min, max, sigma)

        # Epsilon - to avoid log(0)
        self.eps = torch.tensor(0.00000001, dtype=torch.float32, requires_grad=False)

    def forward(self, im1, im2):
        '''
        Forward implementation of a differentiable MI estimator for batched images
        :param im1: N x ... tensor, where N is the batch size
                    ... dimensions can take any form, i.e. 2D images or 3D volumes.
        :param im2: N x ... tensor, where N is the batch size
        :return: N x 1 vector - the approximate MI values between the batched im1 and im2
        '''

        # Check for valid inputs
        assert im1.size() == im2.size(), "(MI_pytorch) Inputs should have the same dimensions."

        batch_size = im1.size()[0]

        # Flatten tensors
        im1_flat = im1.view(im1.size()[0], -1)
        im2_flat = im2.view(im2.size()[0], -1)

        # Calculate joint histogram
        hgram = self.hist2d(im1_flat, im2_flat)

        # Convert to a joint distribution
        # Pxy = torch.distributions.Categorical(probs=hgram).probs
        Pxy = torch.div(hgram, torch.sum(hgram.view(hgram.size()[0], -1)))

        # Calculate the marginal distributions
        Py = torch.sum(Pxy, dim=1).unsqueeze(1)
        Px = torch.sum(Pxy, dim=2).unsqueeze(1)

        # Use the KL divergence distance to calculate the MI
        Px_Py = torch.matmul(Px.permute((0, 2, 1)), Py)

        # Reshape to batch_size X all_the_rest
        Pxy = Pxy.reshape(batch_size, -1)
        Px_Py = Px_Py.reshape(batch_size, -1)

        # Calculate mutual information - this is an approximation due to the histogram calculation and eps,
        # but it can handle batches
        if batch_size == 1:
            # No need for eps approximation in the case of a single batch
            nzs = Pxy > 0  # Calculate based on the non-zero values only
            mut_info = torch.matmul(Pxy[nzs], torch.log(Pxy[nzs]) - torch.log(Px_Py[nzs]))  # MI calculation
        else:
            # For arbitrary batch size > 1
            mut_info = torch.sum(Pxy * (torch.log(Pxy + self.eps) - torch.log(Px_Py + self.eps)), dim=1)

        # Reduction
        if self.reduction == 'sum':
            mut_info = torch.sum(mut_info)
        elif self.reduction == 'batchmean':
            mut_info = torch.sum(mut_info)
            mut_info = mut_info / float(batch_size)
        elif self.reduction == 'individual':
            pass

        return mut_info
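A quick sanity check for the MI estimator (illustrative, not part of the commit): two strongly related batches should typically score higher than unrelated ones.

    import torch
    from model.MI import MI_pytorch

    mi = MI_pytorch(bins=32, min=0, max=1, sigma=10, reduction='batchmean')
    a = torch.rand(4, 64)                        # batch of 4 vectors in [0, 1]
    b = (a + 0.05 * torch.randn(4, 64)).clamp(0, 1)
    print(mi(a, b), mi(a, torch.rand(4, 64)))    # the first is typically larger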
model/__init__.py  (new file, 6 lines)
@@ -0,0 +1,6 @@
from typing import Union

from .unet import BeatGANsUNetModel, BeatGANsUNetConfig
from .unet_autoenc import BeatGANsAutoencConfig, BeatGANsAutoencModel

Model = Union[BeatGANsUNetModel, BeatGANsAutoencModel]
ModelConfig = Union[BeatGANsUNetConfig, BeatGANsAutoencConfig]
495
model/blocks.py
Normal file

@ -0,0 +1,495 @@

import math
from abc import abstractmethod
from dataclasses import dataclass
from numbers import Number

import numpy as np
import torch as th
import torch.nn.functional as F
from choices import *
from config_base import BaseConfig
from torch import nn

from .nn import (avg_pool_nd, conv_nd, linear, normalization,
                 timestep_embedding, torch_checkpoint, zero_module)


class ScaleAt(Enum):
    after_norm = 'afternorm'


class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
    """
    @abstractmethod
    def forward(self, x, emb=None, cond=None, lateral=None):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """
    def forward(self, x, emb=None, cond=None, lateral=None):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb=emb, cond=cond, lateral=lateral)
            else:
                x = layer(x)
        return x


@dataclass
class ResBlockConfig(BaseConfig):
    channels: int
    emb_channels: int
    dropout: float
    out_channels: int = None
    # condition the block on the time embedding (and optionally a second cond)
    use_condition: bool = True
    use_conv: bool = False
    dims: int = 2
    use_checkpoint: bool = False
    up: bool = False
    down: bool = False
    # whether to receive both the time and the encoder conditions
    two_cond: bool = False
    cond_emb_channels: int = None
    # whether the block receives a lateral (skip) connection from the encoder
    has_lateral: bool = False
    lateral_channels: int = None
    use_zero_module: bool = True

    def __post_init__(self):
        self.out_channels = self.out_channels or self.channels
        self.cond_emb_channels = self.cond_emb_channels or self.emb_channels

    def make_model(self):
        return ResBlock(self)


class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.
    """
    def __init__(self, conf: ResBlockConfig):
        super().__init__()
        self.conf = conf

        assert conf.lateral_channels is None
        layers = [
            normalization(conf.channels),
            nn.SiLU(),
            conv_nd(conf.dims, conf.channels, conf.out_channels, 3, padding=1)
        ]
        self.in_layers = nn.Sequential(*layers)

        self.updown = conf.up or conf.down

        if conf.up:
            self.h_upd = Upsample(conf.channels, False, conf.dims)
            self.x_upd = Upsample(conf.channels, False, conf.dims)
        elif conf.down:
            self.h_upd = Downsample(conf.channels, False, conf.dims)
            self.x_upd = Downsample(conf.channels, False, conf.dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        if conf.use_condition:
            # time condition -> per-channel scale & shift
            self.emb_layers = nn.Sequential(
                nn.SiLU(),
                linear(conf.emb_channels, 2 * conf.out_channels),
            )

            if conf.two_cond:
                # encoder condition -> per-channel scale only
                self.cond_emb_layers = nn.Sequential(
                    nn.SiLU(),
                    linear(conf.cond_emb_channels, conf.out_channels),
                )
            conv = conv_nd(conf.dims,
                           conf.out_channels,
                           conf.out_channels,
                           3,
                           padding=1)
            if conf.use_zero_module:
                # zero-init the last conv so the block starts as an identity
                conv = zero_module(conv)

            layers = []
            layers += [
                normalization(conf.out_channels),
                nn.SiLU(),
                nn.Dropout(p=conf.dropout),
                conv,
            ]
            self.out_layers = nn.Sequential(*layers)

        if conf.out_channels == conf.channels:
            self.skip_connection = nn.Identity()
        else:
            if conf.use_conv:
                kernel_size = 3
                padding = 1
            else:
                kernel_size = 1
                padding = 0

            self.skip_connection = conv_nd(conf.dims,
                                           conf.channels,
                                           conf.out_channels,
                                           kernel_size,
                                           padding=padding)

    def forward(self, x, emb=None, cond=None, lateral=None):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.

        Args:
            x: input
            lateral: lateral connection from the encoder
        """
        return torch_checkpoint(self._forward, (x, emb, cond, lateral),
                                self.conf.use_checkpoint)

    def _forward(
        self,
        x,
        emb=None,
        cond=None,
        lateral=None,
    ):
        """
        Args:
            lateral: required when `has_lateral` is set; it is concatenated
                to the input along the channel dimension.
        """
        if self.conf.has_lateral:
            assert lateral is not None
            x = th.cat([x, lateral], dim=1)

        if self.updown:
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)

        # note: out_layers are only applied when the block is conditioned
        if self.conf.use_condition:
            if emb is not None:
                emb_out = self.emb_layers(emb).type(h.dtype)
            else:
                emb_out = None

            if self.conf.two_cond:
                if cond is None:
                    cond_out = None
                else:
                    if not isinstance(cond, th.Tensor):
                        assert isinstance(cond, dict)
                        cond = cond['cond']
                    cond_out = self.cond_emb_layers(cond).type(h.dtype)

                if cond_out is not None:
                    while len(cond_out.shape) < len(h.shape):
                        cond_out = cond_out[..., None]
            else:
                cond_out = None

            h = apply_conditions(
                h=h,
                emb=emb_out,
                cond=cond_out,
                layers=self.out_layers,
                scale_bias=1,
                in_channels=self.conf.out_channels,
                up_down_layer=None,
            )
        return self.skip_connection(x) + h


def apply_conditions(
    h,
    emb=None,
    cond=None,
    layers: nn.Sequential = None,
    scale_bias: float = 1,
    in_channels: int = 512,
    up_down_layer: nn.Module = None,
):
    """
    apply conditions on the feature maps

    Args:
        emb: time conditional (ready to scale + shift)
        cond: encoder's conditional (ready to scale + shift)
    """
    two_cond = emb is not None and cond is not None

    if emb is not None:
        # broadcast the condition over the spatial dimensions
        while len(emb.shape) < len(h.shape):
            emb = emb[..., None]

    if two_cond:
        while len(cond.shape) < len(h.shape):
            cond = cond[..., None]
        scale_shifts = [emb, cond]
    else:
        scale_shifts = [emb]

    # split each condition into a scale and an optional shift
    for i, each in enumerate(scale_shifts):
        if each is None:
            a = None
            b = None
        else:
            if each.shape[1] == in_channels * 2:
                a, b = th.chunk(each, 2, dim=1)
            else:
                a = each
                b = None
        scale_shifts[i] = (a, b)

    if isinstance(scale_bias, Number):
        biases = [scale_bias] * len(scale_shifts)
    else:
        biases = scale_bias

    # conditions are applied right after the normalization layer
    pre_layers, post_layers = layers[0], layers[1:]

    mid_layers, post_layers = post_layers[:-2], post_layers[-2:]

    h = pre_layers(h)
    for i, (scale, shift) in enumerate(scale_shifts):
        if scale is not None:
            h = h * (biases[i] + scale)
            if shift is not None:
                h = h + shift
    h = mid_layers(h)

    if up_down_layer is not None:
        h = up_down_layer(h)
    h = post_layers(h)
    return h


class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        upsampling occurs in the inner-two dimensions.
    """
    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims,
                                self.channels,
                                self.out_channels,
                                3,
                                padding=1)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
                              mode="nearest")
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.

    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        downsampling occurs in the inner-two dimensions.
    """
    def __init__(self, channels, use_conv, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
            self.op = conv_nd(dims,
                              self.channels,
                              self.out_channels,
                              3,
                              stride=stride,
                              padding=1)
        else:
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)


class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.

    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """
    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        use_new_attention_order=False,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        if use_new_attention_order:
            # split qkv before splitting heads
            self.attention = QKVAttention(self.num_heads)
        else:
            # split heads before splitting qkv
            self.attention = QKVAttentionLegacy(self.num_heads)

        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        return torch_checkpoint(self._forward, (x, ), self.use_checkpoint)

    def _forward(self, x):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)


def count_flops_attn(model, _x, y):
    """
    A counter for the `thop` package to count the operations in an
    attention operation.
    Meant to be used like:
        macs, params = thop.profile(
            model,
            inputs=(inputs, timestamps),
            custom_ops={QKVAttention: QKVAttention.count_flops},
        )
    """
    b, c, *spatial = y[0].shape
    num_spatial = int(np.prod(spatial))
    # two matmuls (QK^T and attn*V) with the same number of ops
    matmul_ops = 2 * b * (num_spatial**2) * c
    model.total_ops += th.DoubleTensor([matmul_ops])


class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
    """
    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.

        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch,
                                                                       dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        # scale q and k before the matmul: more stable in fp16 than
        # dividing the product afterwards
        weight = th.einsum(
            "bct,bcs->bts", q * scale,
            k * scale)
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum("bts,bcs->bct", weight, v)
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)


class QKVAttention(nn.Module):
    """
    A module which performs QKV attention and splits in a different order.
    """
    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.

        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.chunk(3, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts",
            (q * scale).view(bs * self.n_heads, ch, length),
            (k * scale).view(bs * self.n_heads, ch, length),
        )
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum("bts,bcs->bct", weight,
                      v.reshape(bs * self.n_heads, ch, length))
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)


class AttentionPool2d(nn.Module):
    """
    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
    """
    def __init__(
        self,
        spacial_dim: int,
        embed_dim: int,
        num_heads_channels: int,
        output_dim: int = None,
    ):
        super().__init__()
        self.positional_embedding = nn.Parameter(
            th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)
        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
        self.num_heads = embed_dim // num_heads_channels
        self.attention = QKVAttention(self.num_heads)

    def forward(self, x):
        b, c, *_spatial = x.shape
        x = x.reshape(b, c, -1)
        x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)
        x = x + self.positional_embedding[None, :, :].to(x.dtype)
        x = self.qkv_proj(x)
        x = self.attention(x)
        x = self.c_proj(x)
        return x[:, :, 0]
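A minimal shape check for the conditioned residual block above (a sketch, not part of the commit; it assumes the package is importable as `model` and uses 1-D signals, which is what `conv_nd` in model/nn.py enforces):

import torch as th
from model.blocks import ResBlockConfig

# 1-D ResBlock: 64 channels, conditioned on a 512-d time embedding
block = ResBlockConfig(channels=64, emb_channels=512, dropout=0.1,
                       dims=1).make_model()
x = th.randn(4, 64, 128)       # [batch, channels, length]
emb = th.randn(4, 512)         # time embedding
out = block(x, emb=emb)
assert out.shape == x.shape    # the residual block preserves the shape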
184
model/latentnet.py
Normal file

@ -0,0 +1,184 @@

import math
from dataclasses import dataclass
from enum import Enum
from typing import NamedTuple, Tuple

import torch
from choices import *
from config_base import BaseConfig
from torch import nn
from torch.nn import init

from .blocks import *
from .nn import timestep_embedding
from .unet import *


class LatentNetType(Enum):
    none = 'none'
    skip = 'skip'


class LatentNetReturn(NamedTuple):
    pred: torch.Tensor = None


@dataclass
class MLPSkipNetConfig(BaseConfig):
    """
    default MLP for the latent DPM in the paper!
    """
    num_channels: int
    skip_layers: Tuple[int]
    num_hid_channels: int
    num_layers: int
    num_time_emb_channels: int = 64
    activation: Activation = Activation.silu
    use_norm: bool = True
    condition_bias: float = 1
    dropout: float = 0
    last_act: Activation = Activation.none
    num_time_layers: int = 2
    time_last_act: bool = False

    def make_model(self):
        return MLPSkipNet(self)


class MLPSkipNet(nn.Module):
    """
    concat x to hidden layers

    default MLP for the latent DPM in the paper!
    """
    def __init__(self, conf: MLPSkipNetConfig):
        super().__init__()
        self.conf = conf

        # small MLP that embeds the timestep
        layers = []
        for i in range(conf.num_time_layers):
            if i == 0:
                a = conf.num_time_emb_channels
                b = conf.num_channels
            else:
                a = conf.num_channels
                b = conf.num_channels
            layers.append(nn.Linear(a, b))
            if i < conf.num_time_layers - 1 or conf.time_last_act:
                layers.append(conf.activation.get_act())
        self.time_embed = nn.Sequential(*layers)

        self.layers = nn.ModuleList([])
        for i in range(conf.num_layers):
            if i == 0:
                act = conf.activation
                norm = conf.use_norm
                cond = True
                a, b = conf.num_channels, conf.num_hid_channels
                dropout = conf.dropout
            elif i == conf.num_layers - 1:
                act = Activation.none
                norm = False
                cond = False
                a, b = conf.num_hid_channels, conf.num_channels
                dropout = 0
            else:
                act = conf.activation
                norm = conf.use_norm
                cond = True
                a, b = conf.num_hid_channels, conf.num_hid_channels
                dropout = conf.dropout

            if i in conf.skip_layers:
                # the input is concatenated to these layers
                a += conf.num_channels

            self.layers.append(
                MLPLNAct(
                    a,
                    b,
                    norm=norm,
                    activation=act,
                    cond_channels=conf.num_channels,
                    use_cond=cond,
                    condition_bias=conf.condition_bias,
                    dropout=dropout,
                ))
        self.last_act = conf.last_act.get_act()

    def forward(self, x, t, **kwargs):
        t = timestep_embedding(t, self.conf.num_time_emb_channels)
        cond = self.time_embed(t)
        h = x
        for i in range(len(self.layers)):
            if i in self.conf.skip_layers:
                # inject the input into the hidden layers
                h = torch.cat([h, x], dim=1)
            h = self.layers[i].forward(x=h, cond=cond)
        h = self.last_act(h)
        return LatentNetReturn(h)


class MLPLNAct(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        norm: bool,
        use_cond: bool,
        activation: Activation,
        cond_channels: int,
        condition_bias: float = 0,
        dropout: float = 0,
    ):
        super().__init__()
        self.activation = activation
        self.condition_bias = condition_bias
        self.use_cond = use_cond

        self.linear = nn.Linear(in_channels, out_channels)
        self.act = activation.get_act()
        if self.use_cond:
            self.linear_emb = nn.Linear(cond_channels, out_channels)
            self.cond_layers = nn.Sequential(self.act, self.linear_emb)
        if norm:
            self.norm = nn.LayerNorm(out_channels)
        else:
            self.norm = nn.Identity()

        if dropout > 0:
            self.dropout = nn.Dropout(p=dropout)
        else:
            self.dropout = nn.Identity()

        self.init_weights()

    def init_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                if self.activation == Activation.relu:
                    init.kaiming_normal_(module.weight,
                                         a=0,
                                         nonlinearity='relu')
                elif self.activation == Activation.lrelu:
                    init.kaiming_normal_(module.weight,
                                         a=0.2,
                                         nonlinearity='leaky_relu')
                elif self.activation == Activation.silu:
                    # kaiming init has no gain for silu; relu is the
                    # closest available approximation
                    init.kaiming_normal_(module.weight,
                                         a=0,
                                         nonlinearity='relu')
                else:
                    # keep the default initialization
                    pass

    def forward(self, x, cond=None):
        x = self.linear(x)
        if self.use_cond:
            cond = self.cond_layers(cond)
            cond = (cond, None)

            # scale (and optionally shift) the features with the condition
            x = x * (self.condition_bias + cond[0])
            if cond[1] is not None:
                x = x + cond[1]
            x = self.norm(x)
        else:
            x = self.norm(x)
        x = self.act(x)
        x = self.dropout(x)
        return x
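A quick sanity check of the skip MLP above (a sketch, not part of the commit; the config values are illustrative, not the paper's):

import torch
from model.latentnet import MLPSkipNetConfig

conf = MLPSkipNetConfig(num_channels=512, skip_layers=(1, 2),
                        num_hid_channels=1024, num_layers=4)
net = conf.make_model()
z = torch.randn(8, 512)               # latent batch
t = torch.randint(0, 1000, (8, ))     # diffusion timesteps
pred = net(z, t).pred
assert pred.shape == z.shape          # the net denoises in latent space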
135
model/nn.py
Executable file

@ -0,0 +1,135 @@

"""
Various utilities for neural networks.
"""
import math

import torch as th
import torch.nn as nn
import torch.utils.checkpoint

import torch.nn.functional as F


# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
    # @th.jit.script
    def forward(self, x):
        return x * th.sigmoid(x)


class GroupNorm32(nn.GroupNorm):
    def forward(self, x):
        return super().forward(x.float()).type(x.dtype)


def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    # this port currently only supports 1D signals
    assert dims == 1
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def linear(*args, **kwargs):
    """
    Create a linear module.
    """
    return nn.Linear(*args, **kwargs)


def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    """
    # this port currently only supports 1D signals
    assert dims == 1
    if dims == 1:
        return nn.AvgPool1d(*args, **kwargs)
    elif dims == 2:
        return nn.AvgPool2d(*args, **kwargs)
    elif dims == 3:
        return nn.AvgPool3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def update_ema(target_params, source_params, rate=0.99):
    """
    Update target parameters to be closer to those of source parameters using
    an exponential moving average.

    :param target_params: the target parameter sequence.
    :param source_params: the source parameter sequence.
    :param rate: the EMA rate (closer to 1 means slower).
    """
    for targ, src in zip(target_params, source_params):
        targ.detach().mul_(rate).add_(src, alpha=1 - rate)


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def scale_module(module, scale):
    """
    Scale the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().mul_(scale)
    return module


def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))


def normalization(channels):
    """
    Make a standard normalization layer.

    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    return GroupNorm32(min(32, channels), channels)


def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    freqs = th.exp(-math.log(max_period) *
                   th.arange(start=0, end=half, dtype=th.float32) /
                   half).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
    if dim % 2:
        embedding = th.cat(
            [embedding, th.zeros_like(embedding[:, :1])], dim=-1)
    return embedding


def torch_checkpoint(func, args, flag, preserve_rng_state=False):
    # torch's gradient checkpointing works with automatic mixed precision,
    # given torch >= 1.8
    if flag:
        return torch.utils.checkpoint.checkpoint(
            func, *args, preserve_rng_state=preserve_rng_state)
    else:
        return func(*args)
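For reference, the sinusoidal embedding above maps N scalar timesteps to N vectors of `dim` features, half cosines and half sines (a sketch, not part of the commit):

import torch as th
from model.nn import timestep_embedding

t = th.tensor([0., 10., 100.])
emb = timestep_embedding(t, dim=64)
print(emb.shape)   # torch.Size([3, 64]): 32 cosine + 32 sine features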
505
model/unet.py
Normal file

@ -0,0 +1,505 @@

import math
from dataclasses import dataclass
from numbers import Number
from typing import NamedTuple, Tuple, Union

import numpy as np
import torch as th
from torch import nn
import torch.nn.functional as F
from choices import *
from config_base import BaseConfig
from .blocks import *

from .nn import (conv_nd, linear, normalization, timestep_embedding,
                 torch_checkpoint, zero_module)


@dataclass
class BeatGANsUNetConfig(BaseConfig):
    image_size: int = 64
    in_channels: int = 2
    # base channels; each level multiplies this by its channel_mult entry
    model_channels: int = 64
    out_channels: int = 2
    # how many repeating resblocks per resolution
    num_res_blocks: int = 2
    # number of resblocks for the encoder half (if different)
    num_input_res_blocks: int = None
    embed_channels: int = 512
    # resolutions at which attention is applied
    attention_resolutions: Tuple[int] = (16, )
    time_embed_channels: int = None
    dropout: float = 0.1
    channel_mult: Tuple[int] = (1, 2, 4, 8)
    input_channel_mult: Tuple[int] = None
    conv_resample: bool = True
    dims: int = 2
    num_classes: int = None
    use_checkpoint: bool = False
    num_heads: int = 1
    num_head_channels: int = -1
    num_heads_upsample: int = -1
    resblock_updown: bool = True
    use_new_attention_order: bool = False
    resnet_two_cond: bool = False
    resnet_cond_channels: int = None
    resnet_use_zero_module: bool = True
    # gradient checkpoint the attention operation
    attn_checkpoint: bool = False

    num_users: int = None

    def make_model(self):
        return BeatGANsUNetModel(self)


class BeatGANsUNetModel(nn.Module):
    def __init__(self, conf: BeatGANsUNetConfig):
        super().__init__()
        self.conf = conf

        if conf.num_heads_upsample == -1:
            self.num_heads_upsample = conf.num_heads

        self.dtype = th.float32

        self.time_emb_channels = conf.time_embed_channels or conf.model_channels
        self.time_embed = nn.Sequential(
            linear(self.time_emb_channels, conf.embed_channels),
            nn.SiLU(),
            linear(conf.embed_channels, conf.embed_channels),
        )

        if conf.num_classes is not None:
            self.label_emb = nn.Embedding(conf.num_classes,
                                          conf.embed_channels)

        ch = input_ch = int(conf.channel_mult[0] * conf.model_channels)
        self.input_blocks = nn.ModuleList([
            TimestepEmbedSequential(
                conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1))
        ])

        kwargs = dict(
            use_condition=True,
            two_cond=conf.resnet_two_cond,
            use_zero_module=conf.resnet_use_zero_module,
            cond_emb_channels=conf.resnet_cond_channels,
        )

        self._feature_size = ch

        input_block_chans = [[] for _ in range(len(conf.channel_mult))]
        input_block_chans[0].append(ch)

        # number of blocks at each resolution
        self.input_num_blocks = [0 for _ in range(len(conf.channel_mult))]
        self.input_num_blocks[0] = 1
        self.output_num_blocks = [0 for _ in range(len(conf.channel_mult))]

        ds = 1
        resolution = conf.image_size
        for level, mult in enumerate(conf.input_channel_mult
                                     or conf.channel_mult):
            for _ in range(conf.num_input_res_blocks or conf.num_res_blocks):
                layers = [
                    ResBlockConfig(
                        ch,
                        conf.embed_channels,
                        conf.dropout,
                        out_channels=int(mult * conf.model_channels),
                        dims=conf.dims,
                        use_checkpoint=conf.use_checkpoint,
                        **kwargs,
                    ).make_model()
                ]
                ch = int(mult * conf.model_channels)
                if resolution in conf.attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=conf.use_checkpoint
                            or conf.attn_checkpoint,
                            num_heads=conf.num_heads,
                            num_head_channels=conf.num_head_channels,
                            use_new_attention_order=conf.
                            use_new_attention_order,
                        ))
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans[level].append(ch)
                self.input_num_blocks[level] += 1
            if level != len(conf.channel_mult) - 1:
                resolution //= 2
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlockConfig(
                            ch,
                            conf.embed_channels,
                            conf.dropout,
                            out_channels=out_ch,
                            dims=conf.dims,
                            use_checkpoint=conf.use_checkpoint,
                            down=True,
                            **kwargs,
                        ).make_model() if conf.
                        resblock_updown else Downsample(ch,
                                                        conf.conv_resample,
                                                        dims=conf.dims,
                                                        out_channels=out_ch)))
                ch = out_ch
                input_block_chans[level + 1].append(ch)
                self.input_num_blocks[level + 1] += 1
                ds *= 2
                self._feature_size += ch

        self.middle_block = TimestepEmbedSequential(
            ResBlockConfig(
                ch,
                conf.embed_channels,
                conf.dropout,
                dims=conf.dims,
                use_checkpoint=conf.use_checkpoint,
                **kwargs,
            ).make_model(),
            AttentionBlock(
                ch,
                use_checkpoint=conf.use_checkpoint or conf.attn_checkpoint,
                num_heads=conf.num_heads,
                num_head_channels=conf.num_head_channels,
                use_new_attention_order=conf.use_new_attention_order,
            ),
            ResBlockConfig(
                ch,
                conf.embed_channels,
                conf.dropout,
                dims=conf.dims,
                use_checkpoint=conf.use_checkpoint,
                **kwargs,
            ).make_model(),
        )
        self._feature_size += ch

        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(conf.channel_mult))[::-1]:
            for i in range(conf.num_res_blocks + 1):
                # the encoder side may run out of lateral channels at this
                # level, in which case the block takes no lateral input
                try:
                    ich = input_block_chans[level].pop()
                except IndexError:
                    ich = 0
                layers = [
                    ResBlockConfig(
                        channels=ch + ich,
                        emb_channels=conf.embed_channels,
                        dropout=conf.dropout,
                        out_channels=int(conf.model_channels * mult),
                        dims=conf.dims,
                        use_checkpoint=conf.use_checkpoint,
                        has_lateral=True if ich > 0 else False,
                        lateral_channels=None,
                        **kwargs,
                    ).make_model()
                ]
                ch = int(conf.model_channels * mult)
                if resolution in conf.attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=conf.use_checkpoint
                            or conf.attn_checkpoint,
                            num_heads=self.num_heads_upsample,
                            num_head_channels=conf.num_head_channels,
                            use_new_attention_order=conf.
                            use_new_attention_order,
                        ))
                if level and i == conf.num_res_blocks:
                    resolution *= 2
                    out_ch = ch
                    layers.append(
                        ResBlockConfig(
                            ch,
                            conf.embed_channels,
                            conf.dropout,
                            out_channels=out_ch,
                            dims=conf.dims,
                            use_checkpoint=conf.use_checkpoint,
                            up=True,
                            **kwargs,
                        ).make_model() if (
                            conf.resblock_updown
                        ) else Upsample(ch,
                                        conf.conv_resample,
                                        dims=conf.dims,
                                        out_channels=out_ch))
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self.output_num_blocks[level] += 1
                self._feature_size += ch

        if conf.resnet_use_zero_module:
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                zero_module(
                    conv_nd(conf.dims,
                            input_ch,
                            conf.out_channels,
                            3,
                            padding=1)),
            )
        else:
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                conv_nd(conf.dims, input_ch, conf.out_channels, 3, padding=1),
            )

    def forward(self, x, t, y=None, **kwargs):
        """
        Apply the model to an input batch.

        :param x: an [N x C x ...] Tensor of inputs.
        :param t: a 1-D batch of timesteps.
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert (y is not None) == (
            self.conf.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"

        hs = [[] for _ in range(len(self.conf.channel_mult))]
        emb = self.time_embed(timestep_embedding(t, self.time_emb_channels))

        if self.conf.num_classes is not None:
            raise NotImplementedError()

        h = x.type(self.dtype)
        k = 0
        for i in range(len(self.input_num_blocks)):
            for j in range(self.input_num_blocks[i]):
                h = self.input_blocks[k](h, emb=emb)
                hs[i].append(h)
                k += 1
        assert k == len(self.input_blocks)

        # middle blocks
        h = self.middle_block(h, emb=emb)

        # output blocks
        k = 0
        for i in range(len(self.output_num_blocks)):
            for j in range(self.output_num_blocks[i]):
                # take the lateral connection from the same level (in reverse)
                try:
                    lateral = hs[-i - 1].pop()
                except IndexError:
                    lateral = None
                h = self.output_blocks[k](h, emb=emb, lateral=lateral)
                k += 1

        h = h.type(x.dtype)
        pred = self.out(h)
        return Return(pred=pred)


class Return(NamedTuple):
    pred: th.Tensor


@dataclass
class BeatGANsEncoderConfig(BaseConfig):
    image_size: int
    in_channels: int
    model_channels: int
    out_hid_channels: int
    out_channels: int
    num_res_blocks: int
    attention_resolutions: Tuple[int]
    dropout: float = 0
    channel_mult: Tuple[int] = (1, 2, 4, 8)
    use_time_condition: bool = True
    conv_resample: bool = True
    dims: int = 2
    use_checkpoint: bool = False
    num_heads: int = 1
    num_head_channels: int = -1
    resblock_updown: bool = False
    use_new_attention_order: bool = False
    pool: str = 'adaptivenonzero'

    def make_model(self):
        return BeatGANsEncoderModel(self)


class BeatGANsEncoderModel(nn.Module):
    """
    The half UNet model with attention and timestep embedding.

    For usage, see UNet.
    """
    def __init__(self, conf: BeatGANsEncoderConfig):
        super().__init__()
        self.conf = conf
        self.dtype = th.float32

        if conf.use_time_condition:
            time_embed_dim = conf.model_channels * 4
            self.time_embed = nn.Sequential(
                linear(conf.model_channels, time_embed_dim),
                nn.SiLU(),
                linear(time_embed_dim, time_embed_dim),
            )
        else:
            time_embed_dim = None

        ch = int(conf.channel_mult[0] * conf.model_channels)
        self.input_blocks = nn.ModuleList([
            TimestepEmbedSequential(
                conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1))
        ])
        self._feature_size = ch
        input_block_chans = [ch]
        ds = 1
        resolution = conf.image_size
        for level, mult in enumerate(conf.channel_mult):
            for _ in range(conf.num_res_blocks):
                layers = [
                    ResBlockConfig(
                        ch,
                        time_embed_dim,
                        conf.dropout,
                        out_channels=int(mult * conf.model_channels),
                        dims=conf.dims,
                        use_condition=conf.use_time_condition,
                        use_checkpoint=conf.use_checkpoint,
                    ).make_model()
                ]
                ch = int(mult * conf.model_channels)
                if resolution in conf.attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=conf.use_checkpoint,
                            num_heads=conf.num_heads,
                            num_head_channels=conf.num_head_channels,
                            use_new_attention_order=conf.
                            use_new_attention_order,
                        ))
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(conf.channel_mult) - 1:
                resolution //= 2
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlockConfig(
                            ch,
                            time_embed_dim,
                            conf.dropout,
                            out_channels=out_ch,
                            dims=conf.dims,
                            use_condition=conf.use_time_condition,
                            use_checkpoint=conf.use_checkpoint,
                            down=True,
                        ).make_model() if (
                            conf.resblock_updown
                        ) else Downsample(ch,
                                          conf.conv_resample,
                                          dims=conf.dims,
                                          out_channels=out_ch)))
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        self.middle_block = TimestepEmbedSequential(
            ResBlockConfig(
                ch,
                time_embed_dim,
                conf.dropout,
                dims=conf.dims,
                use_condition=conf.use_time_condition,
                use_checkpoint=conf.use_checkpoint,
            ).make_model(),
            AttentionBlock(
                ch,
                use_checkpoint=conf.use_checkpoint,
                num_heads=conf.num_heads,
                num_head_channels=conf.num_head_channels,
                use_new_attention_order=conf.use_new_attention_order,
            ),
            ResBlockConfig(
                ch,
                time_embed_dim,
                conf.dropout,
                dims=conf.dims,
                use_condition=conf.use_time_condition,
                use_checkpoint=conf.use_checkpoint,
            ).make_model(),
        )
        self._feature_size += ch
        if conf.pool == "adaptivenonzero":
            self.out = nn.Sequential(
                normalization(ch),
                nn.SiLU(),
                nn.AdaptiveAvgPool1d(1),
                conv_nd(conf.dims, ch, conf.out_channels, 1),
                nn.Flatten(),
            )
        else:
            raise NotImplementedError(f"Unexpected {conf.pool} pooling")

    def forward(self, x, t=None, return_2d_feature=False):
        """
        Apply the model to an input batch.

        :param x: an [N x C x ...] Tensor of inputs.
        :param t: a 1-D batch of timesteps (optional).
        :return: an [N x K] Tensor of outputs.
        """
        if self.conf.use_time_condition:
            emb = self.time_embed(
                timestep_embedding(t, self.conf.model_channels))
        else:
            emb = None

        results = []
        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb=emb)
            if self.conf.pool.startswith("spatial"):
                results.append(h.type(x.dtype).mean(dim=(2, 3)))
        h = self.middle_block(h, emb=emb)
        if self.conf.pool.startswith("spatial"):
            results.append(h.type(x.dtype).mean(dim=(2, 3)))
            h = th.cat(results, axis=-1)
        else:
            h = h.type(x.dtype)

        h_2d = h
        h = self.out(h)

        if return_2d_feature:
            return h, h_2d
        else:
            return h

    def forward_flatten(self, x):
        """
        transform the last 2d feature into a flattened vector
        """
        h = self.out(x)
        return h


class SuperResModel(BeatGANsUNetModel):
    """
    A UNetModel that performs super-resolution.

    Expects an extra kwarg `low_res` to condition on a low-resolution image.
    Note: carried over from the upstream codebase; BeatGANsUNetModel now takes
    a config object, so this wrapper would need updating before use.
    """
    def __init__(self, image_size, in_channels, *args, **kwargs):
        super().__init__(image_size, in_channels * 2, *args, **kwargs)

    def forward(self, x, timesteps, low_res=None, **kwargs):
        _, _, new_height, new_width = x.shape
        upsampled = F.interpolate(low_res, (new_height, new_width),
                                  mode="bilinear")
        x = th.cat([x, upsampled], dim=1)
        return super().forward(x, timesteps, **kwargs)
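A minimal end-to-end shape check for the UNet above (a sketch, not part of the commit; `dims=1` is required because `conv_nd` in model/nn.py only accepts 1-D, and the sequence length must equal `image_size` so the lateral connections line up):

import torch as th
from model.unet import BeatGANsUNetConfig

conf = BeatGANsUNetConfig(image_size=64, in_channels=2, out_channels=2,
                          dims=1)
model = conf.make_model()
x = th.randn(2, 2, 64)                  # [batch, channels, length]
t = th.randint(0, 1000, (2, ))
pred = model(x, t).pred                 # eps prediction, same shape as x
assert pred.shape == x.shape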
310
model/unet_autoenc.py
Normal file

@ -0,0 +1,310 @@

import torch
from torch import Tensor, nn
from torch.nn.functional import silu

from .latentnet import *
from .unet import *
from choices import *


@dataclass
class BeatGANsAutoencConfig(BeatGANsUNetConfig):
    # number of channels of the semantic code
    enc_out_channels: int = 512
    enc_attn_resolutions: Tuple[int] = None
    enc_pool: str = 'depthconv'
    enc_num_res_block: int = 2
    enc_channel_mult: Tuple[int] = None
    enc_grad_checkpoint: bool = False
    latent_net_conf: MLPSkipNetConfig = None

    def make_model(self):
        return BeatGANsAutoencModel(self)


class BeatGANsAutoencModel(BeatGANsUNetModel):
    def __init__(self, conf: BeatGANsAutoencConfig):
        super().__init__(conf)
        self.conf = conf

        self.time_embed = TimeStyleSeperateEmbed(
            time_channels=conf.model_channels,
            time_out_channels=conf.embed_channels,
        )

        self.encoder = BeatGANsEncoderConfig(
            image_size=conf.image_size,
            in_channels=conf.in_channels,
            model_channels=conf.model_channels,
            out_hid_channels=conf.enc_out_channels,
            out_channels=conf.enc_out_channels,
            num_res_blocks=conf.enc_num_res_block,
            attention_resolutions=(conf.enc_attn_resolutions
                                   or conf.attention_resolutions),
            dropout=conf.dropout,
            channel_mult=conf.enc_channel_mult or conf.channel_mult,
            use_time_condition=False,
            conv_resample=conf.conv_resample,
            dims=conf.dims,
            use_checkpoint=conf.use_checkpoint or conf.enc_grad_checkpoint,
            num_heads=conf.num_heads,
            num_head_channels=conf.num_head_channels,
            resblock_updown=conf.resblock_updown,
            use_new_attention_order=conf.use_new_attention_order,
            pool=conf.enc_pool,
        ).make_model()

        # the first half of the semantic code is trained to identify the
        # user; the second half is pushed away from user identity by the
        # gradient-reversal branch
        self.user_classifier = UserClassifier(conf.enc_out_channels // 2,
                                              conf.num_users)
        self.non_user_classifier = UserClassifierGradientReverse(
            conf.enc_out_channels // 2, conf.num_users)

        if conf.latent_net_conf is not None:
            self.latent_net = conf.latent_net_conf.make_model()

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick to sample from N(mu, var) from
        N(0,1).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        assert self.conf.is_stochastic
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def sample_z(self, n: int, device):
        assert self.conf.is_stochastic
        return torch.randn(n, self.conf.enc_out_channels, device=device)

    def noise_to_cond(self, noise: Tensor):
        raise NotImplementedError()
        assert self.conf.noise_net_conf is not None
        return self.noise_net.forward(noise)

    def encode(self, x):
        cond = self.encoder.forward(x)
        return {'cond': cond}

    @property
    def stylespace_sizes(self):
        modules = list(self.input_blocks.modules()) + list(
            self.middle_block.modules()) + list(self.output_blocks.modules())
        sizes = []
        for module in modules:
            if isinstance(module, ResBlock):
                linear = module.cond_emb_layers[-1]
                sizes.append(linear.weight.shape[0])
        return sizes

    def encode_stylespace(self, x, return_vector: bool = True):
        """
        encode to style space
        """
        modules = list(self.input_blocks.modules()) + list(
            self.middle_block.modules()) + list(self.output_blocks.modules())
        cond = self.encoder.forward(x)
        S = []
        for module in modules:
            if isinstance(module, ResBlock):
                s = module.cond_emb_layers.forward(cond)
                S.append(s)

        if return_vector:
            return torch.cat(S, dim=1)
        else:
            return S

    def forward(self,
                x,
                t,
                y=None,
                x_start=None,
                cond=None,
                style=None,
                noise=None,
                t_cond=None,
                **kwargs):
        """
        Apply the model to an input batch.

        Args:
            x_start: the original image to encode
            cond: output of the encoder
            noise: random noise (to predict the cond)
        """
        if t_cond is None:
            t_cond = t

        if noise is not None:
            # if the noise is given, we predict the cond from noise
            cond = self.noise_to_cond(noise)

        if cond is None:
            if x is not None:
                assert len(x) == len(x_start), f'{len(x)} != {len(x_start)}'

            tmp = self.encode(x_start)
            cond = tmp['cond']

        if t is not None:
            _t_emb = timestep_embedding(t, self.conf.model_channels)
            _t_cond_emb = timestep_embedding(t_cond, self.conf.model_channels)
        else:
            # this happens when training only the autoencoder
            _t_emb = None
            _t_cond_emb = None

        if self.conf.resnet_two_cond:
            res = self.time_embed.forward(
                time_emb=_t_emb,
                cond=cond,
                time_cond_emb=_t_cond_emb,
            )
        else:
            raise NotImplementedError()

        if self.conf.resnet_two_cond:
            # two conds: first = time emb, second = semantic code
            emb = res.time_emb
            cond_emb = res.emb
        else:
            # one cond = combination of time and semantic code
            emb = res.emb
            cond_emb = None

        # override the style if given
        style = style or res.style

        assert (y is not None) == (
            self.conf.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"

        if self.conf.num_classes is not None:
            raise NotImplementedError()

        enc_time_emb = emb
        mid_time_emb = emb
        dec_time_emb = emb
        enc_cond_emb = cond_emb
        mid_cond_emb = cond_emb
        dec_cond_emb = cond_emb

        # default when the model has no user heads configured
        user_pred = None
        non_user_pred = None
        if self.conf.num_users is not None:
            user_pred = self.user_classifier(
                cond_emb[:, :self.conf.enc_out_channels // 2])
            non_user_pred = self.non_user_classifier(
                cond_emb[:, self.conf.enc_out_channels // 2:])

        hs = [[] for _ in range(len(self.conf.channel_mult))]

        if x is not None:
            h = x.type(self.dtype)

            # input blocks
            k = 0
            for i in range(len(self.input_num_blocks)):
                for j in range(self.input_num_blocks[i]):
                    h = self.input_blocks[k](h,
                                             emb=enc_time_emb,
                                             cond=enc_cond_emb)
                    hs[i].append(h)
                    k += 1
            assert k == len(self.input_blocks)

            # middle blocks
            h = self.middle_block(h, emb=mid_time_emb, cond=mid_cond_emb)
        else:
            # no lateral connections are used
            h = None
            hs = [[] for _ in range(len(self.conf.channel_mult))]

        # output blocks
        k = 0
        for i in range(len(self.output_num_blocks)):
            for j in range(self.output_num_blocks[i]):
                # take the lateral connection from the same level (in reverse)
                try:
                    lateral = hs[-i - 1].pop()
                except IndexError:
                    lateral = None

                h = self.output_blocks[k](h,
                                          emb=dec_time_emb,
                                          cond=dec_cond_emb,
                                          lateral=lateral)
                k += 1

        pred = self.out(h)

        return AutoencReturn(pred=pred, cond=cond, user_pred=user_pred,
                             non_user_pred=non_user_pred)


class UserClassifier(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(in_channels, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        return self.fc(x)


class GradReverse(torch.autograd.Function):
    """
    Implements the gradient reversal layer used in domain-adaptation
    networks: the forward pass is the identity, the backward pass negates
    the gradient.
    """
    @staticmethod
    def forward(ctx, x):
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.neg()


class GradientReversalLayer(nn.Module):
    def __init__(self):
        super(GradientReversalLayer, self).__init__()

    def forward(self, inputs):
        return GradReverse.apply(inputs)


class UserClassifierGradientReverse(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.grl = GradientReversalLayer()
        self.fc = UserClassifier(in_channels, num_classes)

    def forward(self, x):
        x = self.grl(x)
        return self.fc(x)


class AutoencReturn(NamedTuple):
    pred: Tensor
    cond: Tensor = None
    user_pred: Tensor = None
    non_user_pred: Tensor = None


class EmbedReturn(NamedTuple):
    emb: Tensor = None
    time_emb: Tensor = None
    style: Tensor = None


class TimeStyleSeperateEmbed(nn.Module):
    def __init__(self, time_channels, time_out_channels):
        super().__init__()
        self.time_embed = nn.Sequential(
            linear(time_channels, time_out_channels),
            nn.SiLU(),
            linear(time_out_channels, time_out_channels),
        )
        self.cond_combine = nn.Sequential(
            nn.Linear(time_out_channels * 2, time_out_channels),
            nn.SiLU(),
        )
        self.style = nn.Identity()

    def forward(self, time_emb=None, cond=None, **kwargs):
        if time_emb is not None:
            time_emb = self.time_embed(time_emb)
        style = self.style(cond)
        return EmbedReturn(emb=style, time_emb=time_emb, style=style)
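A smoke test for the autoencoder wiring above (a sketch, not part of the commit, with illustrative sizes; `resnet_two_cond=True` and `enc_pool='adaptivenonzero'` are required by the forward pass, `dims=1` by conv_nd, and `num_users` must be set because the user heads are always constructed):

import torch as th
from model.unet_autoenc import BeatGANsAutoencConfig

conf = BeatGANsAutoencConfig(image_size=64, in_channels=2, out_channels=2,
                             dims=1, resnet_two_cond=True,
                             enc_pool='adaptivenonzero', num_users=5)
model = conf.make_model()
x_t = th.randn(2, 2, 64)                # noisy input
x0 = th.randn(2, 2, 64)                 # clean input for the encoder
t = th.randint(0, 1000, (2, ))
res = model(x_t, t, x_start=x0)
print(res.pred.shape)                   # [2, 2, 64] eps prediction
print(res.cond.shape)                   # [2, 512] semantic code
print(res.user_pred.shape)              # [2, 5] user probabilities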
55
templates.py
Normal file

@ -0,0 +1,55 @@

from experiment import *


def autoenc_base():
    conf = TrainConfig()
    conf.batch_size = 32
    conf.beatgans_gen_type = GenerativeType.ddim
    conf.beta_scheduler = 'linear'
    conf.data_name = 'ffhq'
    conf.diffusion_type = 'beatgans'
    conf.eval_ema_every_samples = 200_000
    conf.eval_every_samples = 200_000
    conf.fp16 = True
    conf.lr = 1e-4
    conf.model_name = ModelName.beatgans_autoenc
    conf.net_attn = (16, )
    conf.net_beatgans_attn_head = 1
    conf.net_beatgans_embed_channels = 128
    conf.net_beatgans_resnet_two_cond = True
    conf.net_ch_mult = (1, 2, 4, 8)
    conf.net_ch = 64
    conf.net_enc_channel_mult = (1, 2, 4, 8, 8)
    conf.net_enc_pool = 'adaptivenonzero'
    conf.sample_size = 32
    conf.T_eval = 20
    conf.T = 1000
    conf.make_model_conf()
    return conf


def mouse_autoenc(mode):
    num_users = {'Clarkson': 75}

    conf = autoenc_base()
    conf.scale_up_gpus(1)
    conf.img_size = 256
    conf.net_ch = 128
    conf.net_ch_mult = (1, 1, 2, 2, 4, 4)
    conf.net_enc_channel_mult = (1, 1, 2, 2, 4, 4, 4)
    conf.eval_every_samples = 10_000_000
    conf.eval_ema_every_samples = 10_000_000
    conf.total_samples = 200_000_000
    conf.batch_size = 512
    conf.name = 'mouse_autoenc'

    conf.pretrainDataset = 'Clarkson'
    conf.data_name = 'Clarkson'

    conf.path = f'../mousedata/{conf.pretrainDataset}/'
    conf.AEwin = 8
    conf.slid = 1
    conf.timeWinFreq = 20
    conf.num_users = num_users[conf.pretrainDataset]

    conf.mode = mode
    conf.make_model_conf()
    return conf
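Typical use of these templates (a sketch, not part of the commit; the accepted values of `mode` are defined by TrainConfig in experiment.py, which is outside this excerpt, so 'train' below is only illustrative):

from templates import mouse_autoenc

conf = mouse_autoenc(mode='train')   # hypothetical mode value
print(conf.name, conf.num_users)     # mouse_autoenc 75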