This commit is contained in:
Guanhua Zhang 2024-10-08 14:18:47 +02:00
commit b102c2a534
20 changed files with 4305 additions and 0 deletions

6
diffusion/__init__.py Normal file
View file

@ -0,0 +1,6 @@
from typing import Union
from .diffusion import SpacedDiffusionBeatGans, SpacedDiffusionBeatGansConfig
Sampler = Union[SpacedDiffusionBeatGans]
SamplerConfig = Union[SpacedDiffusionBeatGansConfig]

1086
diffusion/base.py Executable file

File diff suppressed because it is too large Load diff

154
diffusion/diffusion.py Executable file
View file

@ -0,0 +1,154 @@
from .base import *
from dataclasses import dataclass
def space_timesteps(num_timesteps, section_counts):
"""
Create a list of timesteps to use from an original diffusion process,
given the number of timesteps we want to take from equally-sized portions
of the original process.
For example, if there's 300 timesteps and the section counts are [10,15,20]
then the first 100 timesteps are strided to be 10 timesteps, the second 100
are strided to be 15 timesteps, and the final 100 are strided to be 20.
If the stride is a string starting with "ddim", then the fixed striding
from the DDIM paper is used, and only one section is allowed.
:param num_timesteps: the number of diffusion steps in the original
process to divide up.
:param section_counts: either a list of numbers, or a string containing
comma-separated numbers, indicating the step count
per section. As a special case, use "ddimN" where N
is a number of steps to use the striding from the
DDIM paper.
:return: a set of diffusion steps from the original process to use.
"""
if isinstance(section_counts, str):
if section_counts.startswith("ddim"):
desired_count = int(section_counts[len("ddim"):])
for i in range(1, num_timesteps):
if len(range(0, num_timesteps, i)) == desired_count:
return set(range(0, num_timesteps, i))
raise ValueError(
f"cannot create exactly {num_timesteps} steps with an integer stride"
)
section_counts = [int(x) for x in section_counts.split(",")]
size_per = num_timesteps // len(section_counts)
extra = num_timesteps % len(section_counts)
start_idx = 0
all_steps = []
for i, section_count in enumerate(section_counts):
size = size_per + (1 if i < extra else 0)
if size < section_count:
raise ValueError(
f"cannot divide section of {size} steps into {section_count}")
if section_count <= 1:
frac_stride = 1
else:
frac_stride = (size - 1) / (section_count - 1)
cur_idx = 0.0
taken_steps = []
for _ in range(section_count):
taken_steps.append(start_idx + round(cur_idx))
cur_idx += frac_stride
all_steps += taken_steps
start_idx += size
return set(all_steps)
@dataclass
class SpacedDiffusionBeatGansConfig(GaussianDiffusionBeatGansConfig):
use_timesteps: Tuple[int] = None
def make_sampler(self):
return SpacedDiffusionBeatGans(self)
class SpacedDiffusionBeatGans(GaussianDiffusionBeatGans):
"""
A diffusion process which can skip steps in a base diffusion process.
:param use_timesteps: a collection (sequence or set) of timesteps from the
original diffusion process to retain.
:param kwargs: the kwargs to create the base diffusion process.
"""
def __init__(self, conf: SpacedDiffusionBeatGansConfig):
self.conf = conf
self.use_timesteps = set(conf.use_timesteps)
self.timestep_map = []
self.original_num_steps = len(conf.betas)
base_diffusion = GaussianDiffusionBeatGans(conf)
last_alpha_cumprod = 1.0
new_betas = []
for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
if i in self.use_timesteps:
new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
last_alpha_cumprod = alpha_cumprod
self.timestep_map.append(i)
conf.betas = np.array(new_betas)
super().__init__(conf)
def p_mean_variance(self, model: Model, *args, **kwargs):
return super().p_mean_variance(self._wrap_model(model), *args,
**kwargs)
def training_losses(self, model: Model, *args, **kwargs):
return super().training_losses(self._wrap_model(model), *args,
**kwargs)
def condition_mean(self, cond_fn, *args, **kwargs):
return super().condition_mean(self._wrap_model(cond_fn), *args,
**kwargs)
def condition_score(self, cond_fn, *args, **kwargs):
return super().condition_score(self._wrap_model(cond_fn), *args,
**kwargs)
def _wrap_model(self, model: Model):
if isinstance(model, _WrappedModel):
return model
return _WrappedModel(model, self.timestep_map, self.rescale_timesteps,
self.original_num_steps)
def _scale_timesteps(self, t):
return t
class _WrappedModel:
"""
converting the supplied t's to the old t's scales.
"""
def __init__(self, model, timestep_map, rescale_timesteps,
original_num_steps):
self.model = model
self.timestep_map = timestep_map
self.rescale_timesteps = rescale_timesteps
self.original_num_steps = original_num_steps
def forward(self, x, t, t_cond=None, **kwargs):
"""
Args:
t: t's with differrent ranges (can be << T due to smaller eval T) need to be converted to the original t's
t_cond: the same as t but can be of different values
"""
map_tensor = th.tensor(self.timestep_map,
device=t.device,
dtype=t.dtype)
def do(t):
new_ts = map_tensor[t]
if self.rescale_timesteps:
new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
return new_ts
if t_cond is not None:
t_cond = do(t_cond)
return self.model(x=x, t=do(t), t_cond=t_cond, **kwargs)
def __getattr__(self, name):
if hasattr(self.model, name):
func = getattr(self.model, name)
return func
raise AttributeError(name)

62
diffusion/resample.py Normal file
View file

@ -0,0 +1,62 @@
from abc import ABC, abstractmethod
import numpy as np
import torch as th
def create_named_schedule_sampler(name, diffusion):
"""
Create a ScheduleSampler from a library of pre-defined samplers.
:param name: the name of the sampler.
:param diffusion: the diffusion object to sample for.
"""
if name == "uniform":
return UniformSampler(diffusion)
else:
raise NotImplementedError(f"unknown schedule sampler: {name}")
class ScheduleSampler(ABC):
"""
A distribution over timesteps in the diffusion process, intended to reduce
variance of the objective.
By default, samplers perform unbiased importance sampling, in which the
objective's mean is unchanged.
However, subclasses may override sample() to change how the resampled
terms are reweighted, allowing for actual changes in the objective.
"""
@abstractmethod
def weights(self):
"""
Get a numpy array of weights, one per diffusion step.
The weights needn't be normalized, but must be positive.
"""
def sample(self, batch_size, device):
"""
Importance-sample timesteps for a batch.
:param batch_size: the number of timesteps.
:param device: the torch device to save to.
:return: a tuple (timesteps, weights):
- timesteps: a tensor of timestep indices.
- weights: a tensor of weights to scale the resulting losses.
"""
w = self.weights()
p = w / np.sum(w)
indices_np = np.random.choice(len(p), size=(batch_size, ), p=p)
indices = th.from_numpy(indices_np).long().to(device)
weights_np = 1 / (len(p) * p[indices_np])
weights = th.from_numpy(weights_np).float().to(device)
return indices, weights
class UniformSampler(ScheduleSampler):
def __init__(self, num_timesteps):
self._weights = np.ones([num_timesteps])
def weights(self):
return self._weights