init

commit b102c2a534

20 changed files with 4305 additions and 0 deletions
6 diffusion/__init__.py Normal file

@@ -0,0 +1,6 @@
from typing import Union

from .diffusion import SpacedDiffusionBeatGans, SpacedDiffusionBeatGansConfig

Sampler = Union[SpacedDiffusionBeatGans]
SamplerConfig = Union[SpacedDiffusionBeatGansConfig]
1086 diffusion/base.py Executable file

File diff suppressed because it is too large
154 diffusion/diffusion.py Executable file

@@ -0,0 +1,154 @@
from .base import *
from dataclasses import dataclass


def space_timesteps(num_timesteps, section_counts):
    """
    Create a list of timesteps to use from an original diffusion process,
    given the number of timesteps we want to take from equally-sized portions
    of the original process.

    For example, if there are 300 timesteps and the section counts are
    [10, 15, 20], then the first 100 timesteps are strided to be 10 timesteps,
    the second 100 are strided to be 15 timesteps, and the final 100 are
    strided to be 20.

    If the stride is a string starting with "ddim", then the fixed striding
    from the DDIM paper is used, and only one section is allowed.

    :param num_timesteps: the number of diffusion steps in the original
                          process to divide up.
    :param section_counts: either a list of numbers, or a string containing
                           comma-separated numbers, indicating the step count
                           per section. As a special case, use "ddimN" where N
                           is a number of steps to use the striding from the
                           DDIM paper.
    :return: a set of diffusion steps from the original process to use.
    """
    if isinstance(section_counts, str):
        if section_counts.startswith("ddim"):
            desired_count = int(section_counts[len("ddim"):])
            for i in range(1, num_timesteps):
                if len(range(0, num_timesteps, i)) == desired_count:
                    return set(range(0, num_timesteps, i))
            raise ValueError(
                f"cannot create exactly {desired_count} steps with an integer stride"
            )
        section_counts = [int(x) for x in section_counts.split(",")]
    size_per = num_timesteps // len(section_counts)
    extra = num_timesteps % len(section_counts)
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(
                f"cannot divide section of {size} steps into {section_count}")
        if section_count <= 1:
            frac_stride = 1
        else:
            frac_stride = (size - 1) / (section_count - 1)
        cur_idx = 0.0
        taken_steps = []
        for _ in range(section_count):
            taken_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        all_steps += taken_steps
        start_idx += size
    return set(all_steps)
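For intuition, a quick usage sketch of space_timesteps (not part of the commit; values chosen to mirror the docstring example):

    # 300 original steps split into three 100-step sections.
    steps = space_timesteps(300, [10, 15, 20])
    assert len(steps) == 10 + 15 + 20          # 45 retained timesteps

    # DDIM-style fixed striding: 1000 steps reduced to 50.
    assert space_timesteps(1000, "ddim50") == set(range(0, 1000, 20))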
@dataclass
class SpacedDiffusionBeatGansConfig(GaussianDiffusionBeatGansConfig):
    use_timesteps: Tuple[int] = None

    def make_sampler(self):
        return SpacedDiffusionBeatGans(self)


class SpacedDiffusionBeatGans(GaussianDiffusionBeatGans):
    """
    A diffusion process which can skip steps in a base diffusion process.

    :param use_timesteps: a collection (sequence or set) of timesteps from the
                          original diffusion process to retain.
    :param kwargs: the kwargs to create the base diffusion process.
    """
    def __init__(self, conf: SpacedDiffusionBeatGansConfig):
        self.conf = conf
        self.use_timesteps = set(conf.use_timesteps)
        self.timestep_map = []
        self.original_num_steps = len(conf.betas)

        base_diffusion = GaussianDiffusionBeatGans(conf)
        last_alpha_cumprod = 1.0
        new_betas = []
        for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
            if i in self.use_timesteps:
                # Pick the new beta so the running product of (1 - beta)
                # telescopes to the original alpha-bar at step i.
                new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
                last_alpha_cumprod = alpha_cumprod
                self.timestep_map.append(i)
        conf.betas = np.array(new_betas)
        super().__init__(conf)

    def p_mean_variance(self, model: Model, *args, **kwargs):
        return super().p_mean_variance(self._wrap_model(model), *args,
                                       **kwargs)

    def training_losses(self, model: Model, *args, **kwargs):
        return super().training_losses(self._wrap_model(model), *args,
                                       **kwargs)

    def condition_mean(self, cond_fn, *args, **kwargs):
        return super().condition_mean(self._wrap_model(cond_fn), *args,
                                      **kwargs)

    def condition_score(self, cond_fn, *args, **kwargs):
        return super().condition_score(self._wrap_model(cond_fn), *args,
                                       **kwargs)

    def _wrap_model(self, model: Model):
        if isinstance(model, _WrappedModel):
            return model
        return _WrappedModel(model, self.timestep_map, self.rescale_timesteps,
                             self.original_num_steps)

    def _scale_timesteps(self, t):
        # Scaling is done by the wrapped model.
        return t
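The new_betas loop in __init__ works because alpha-bar is a cumulative product: choosing beta_i = 1 - alpha_bar_i / alpha_bar_prev makes the products telescope, so the spaced process matches the original at every kept step. A sanity-check sketch using space_timesteps from this file (the linear beta schedule below is an assumption, not taken from base.py):

    import numpy as np

    # Respaced betas reproduce the original alpha-bar at every kept step.
    betas = np.linspace(1e-4, 0.02, 1000)              # assumed schedule
    alphas_cumprod = np.cumprod(1.0 - betas)

    kept = sorted(space_timesteps(1000, "ddim50"))
    last, new_betas = 1.0, []
    for i in kept:
        new_betas.append(1 - alphas_cumprod[i] / last)
        last = alphas_cumprod[i]

    respaced = np.cumprod(1.0 - np.array(new_betas))
    assert np.allclose(respaced, alphas_cumprod[kept])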
class _WrappedModel:
    """
    Wraps the model so that t's supplied in the spaced process are converted
    back to the original process's timestep scale before the model sees them.
    """
    def __init__(self, model, timestep_map, rescale_timesteps,
                 original_num_steps):
        self.model = model
        self.timestep_map = timestep_map
        self.rescale_timesteps = rescale_timesteps
        self.original_num_steps = original_num_steps

    def forward(self, x, t, t_cond=None, **kwargs):
        """
        Args:
            t: t's with different ranges (can be << T due to smaller eval T);
               these need to be converted to the original t's.
            t_cond: the same as t, but may hold different values.
        """
        map_tensor = th.tensor(self.timestep_map,
                               device=t.device,
                               dtype=t.dtype)

        def do(t):
            new_ts = map_tensor[t]
            if self.rescale_timesteps:
                new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
            return new_ts

        if t_cond is not None:
            t_cond = do(t_cond)
        return self.model(x=x, t=do(t), t_cond=t_cond, **kwargs)

    def __getattr__(self, name):
        if hasattr(self.model, name):
            func = getattr(self.model, name)
            return func
        raise AttributeError(name)
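The conversion in forward is just an index lookup into timestep_map; a minimal illustration (the map values here are hypothetical):

    import torch as th

    timestep_map = [0, 250, 500, 750]     # hypothetical: keep 4 of 1000 steps
    t = th.tensor([0, 3, 2])              # t's in the spaced process
    map_tensor = th.tensor(timestep_map, dtype=t.dtype)
    print(map_tensor[t])                  # tensor([  0, 750, 500])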
62 diffusion/resample.py Normal file

@@ -0,0 +1,62 @@
from abc import ABC, abstractmethod

import numpy as np
import torch as th


def create_named_schedule_sampler(name, diffusion):
    """
    Create a ScheduleSampler from a library of pre-defined samplers.

    :param name: the name of the sampler.
    :param diffusion: the diffusion object to sample for.
    """
    if name == "uniform":
        # UniformSampler expects a step count, not the diffusion object
        # itself (assumes the diffusion process exposes num_timesteps).
        return UniformSampler(diffusion.num_timesteps)
    else:
        raise NotImplementedError(f"unknown schedule sampler: {name}")
class ScheduleSampler(ABC):
    """
    A distribution over timesteps in the diffusion process, intended to reduce
    variance of the objective.

    By default, samplers perform unbiased importance sampling, in which the
    objective's mean is unchanged.
    However, subclasses may override sample() to change how the resampled
    terms are reweighted, allowing for actual changes in the objective.
    """
    @abstractmethod
    def weights(self):
        """
        Get a numpy array of weights, one per diffusion step.

        The weights needn't be normalized, but must be positive.
        """

    def sample(self, batch_size, device):
        """
        Importance-sample timesteps for a batch.

        :param batch_size: the number of timesteps.
        :param device: the torch device to place the tensors on.
        :return: a tuple (timesteps, weights):
                 - timesteps: a tensor of timestep indices.
                 - weights: a tensor of weights to scale the resulting losses.
        """
        w = self.weights()
        p = w / np.sum(w)
        indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
        indices = th.from_numpy(indices_np).long().to(device)
        # 1 / (N * p_i) keeps the weighted loss an unbiased estimate.
        weights_np = 1 / (len(p) * p[indices_np])
        weights = th.from_numpy(weights_np).float().to(device)
        return indices, weights
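sample() draws t ~ p and returns weights 1/(N * p_t), so the weighted loss stays an unbiased estimate of the uniform average: E[f(t) / (N * p_t)] = (1/N) * sum_t f(t). A quick numeric check (data and tolerance are arbitrary, seeded for reproducibility):

    import numpy as np

    rng = np.random.default_rng(0)
    w = rng.random(1000) + 0.5                 # arbitrary positive weights
    p = w / w.sum()
    f = rng.random(1000)                       # hypothetical per-step losses

    t = rng.choice(1000, size=200_000, p=p)    # importance-sample timesteps
    estimate = np.mean(f[t] / (1000 * p[t]))
    assert abs(estimate - f.mean()) < 0.01     # matches the uniform mean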
class UniformSampler(ScheduleSampler):
    def __init__(self, num_timesteps):
        self._weights = np.ones([num_timesteps])

    def weights(self):
        return self._weights
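A minimal end-to-end sketch of the sampler API (step count and batch size are arbitrary):

    sampler = UniformSampler(1000)
    t, weights = sampler.sample(batch_size=4, device=th.device("cpu"))
    # With uniform weights every importance weight is 1, so losses pass
    # through unscaled.
    assert th.allclose(weights, th.ones_like(weights))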