update readme

parent 35ee4b75e8
commit 249a01f342

18 changed files with 4936 additions and 0 deletions
diffusion/resample.py (new file, 63 lines)

@@ -0,0 +1,63 @@
from abc import ABC, abstractmethod

import numpy as np
import torch as th
import torch.distributed as dist


def create_named_schedule_sampler(name, diffusion):
    """
    Create a ScheduleSampler from a library of pre-defined samplers.

    :param name: the name of the sampler.
    :param diffusion: the diffusion object to sample for.
    """
    if name == "uniform":
        # UniformSampler expects a step count; the diffusion object is assumed
        # to expose it as `num_timesteps` (the standard guided-diffusion attribute).
        return UniformSampler(diffusion.num_timesteps)
    else:
        raise NotImplementedError(f"unknown schedule sampler: {name}")


class ScheduleSampler(ABC):
    """
    A distribution over timesteps in the diffusion process, intended to reduce
    variance of the objective.

    By default, samplers perform unbiased importance sampling, in which the
    objective's mean is unchanged.
    However, subclasses may override sample() to change how the resampled
    terms are reweighted, allowing for actual changes in the objective.
    """

    @abstractmethod
    def weights(self):
        """
        Get a numpy array of weights, one per diffusion step.

        The weights needn't be normalized, but must be positive.
        """

    def sample(self, batch_size, device):
        """
        Importance-sample timesteps for a batch.

        :param batch_size: the number of timesteps.
        :param device: the torch device to save to.
        :return: a tuple (timesteps, weights):
                 - timesteps: a tensor of timestep indices.
                 - weights: a tensor of weights to scale the resulting losses.
        """
        w = self.weights()
        p = w / np.sum(w)
        indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
        indices = th.from_numpy(indices_np).long().to(device)
        # 1 / (N * p) keeps the weighted loss an unbiased estimate of the
        # uniform average over timesteps.
        weights_np = 1 / (len(p) * p[indices_np])
        weights = th.from_numpy(weights_np).float().to(device)
        return indices, weights


class UniformSampler(ScheduleSampler):
    def __init__(self, num_timesteps):  # all steps are weighted 1
        self._weights = np.ones([num_timesteps])

    def weights(self):
        return self._weights
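The ScheduleSampler docstring claims the resampling is unbiased. A standalone numpy check, purely illustrative and not in the commit, reproduces the 1 / (N * p) weighting from sample() with an arbitrary positive weighting and shows the weighted Monte Carlo average matching the plain per-timestep mean.

    import numpy as np

    rng = np.random.default_rng(0)
    num_timesteps = 100
    f = rng.random(num_timesteps)          # arbitrary per-timestep quantity
    w = rng.random(num_timesteps) + 0.1    # arbitrary positive weights
    p = w / w.sum()

    draws = rng.choice(num_timesteps, size=200_000, p=p)
    weights = 1.0 / (num_timesteps * p[draws])
    print(np.mean(weights * f[draws]))     # close to f.mean()
    print(f.mean())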