222 lines
9.6 KiB
Python
222 lines
9.6 KiB
Python
import random
|
|
import numpy as np
|
|
import torch
|
|
from torch import nn
|
|
|
|
from .helpers import (
|
|
cosine_beta_schedule,
|
|
extract,
|
|
condition_projection,
|
|
Losses,
|
|
)
|
|
|
|
class GaussianDiffusion(nn.Module):
    """Gaussian diffusion process wrapped around a denoising ``model``.

    The wrapped model is called as ``model(x, t, task_label)``. Its output is
    used as a prediction of the clean sample: it is passed to ``q_posterior``
    as ``x_start`` during sampling and compared against ``x_start`` in the
    training loss.

    Sampling (``forward`` / ``p_sample_loop``) can either walk all
    ``n_timesteps`` steps or jump through a shorter DDIM timestep sequence.
    """

    def __init__(self, model, horizon, observation_dim, action_dim, class_dim, act_mean, act_std,
                 n_timesteps=200, loss_type='Weighted_MSE', clip_denoised=False,
                 ddim_discr_method='uniform', ddim_timesteps=10):
        """Store the configuration and precompute the diffusion schedule.

        Args:
            model: Callable invoked as ``model(x, t, task_label)``.
            horizon: Length of the second axis of every sampled tensor.
            observation_dim: Observation part of the feature (last) axis.
            action_dim: Action part of the feature axis.
            class_dim: Class part of the feature axis.
            act_mean: Indexable by horizon step; mean of the initial sample.
            act_std: Indexable by horizon step; std of the initial sample.
            n_timesteps: Number of diffusion steps.
            loss_type: Key into ``Losses`` selecting the training loss.
            clip_denoised: If true, the model prediction is clamped to
                ``[-1, 1]`` while sampling.
            ddim_discr_method: ``'uniform'`` or ``'quad'`` spacing of the
                DDIM timestep sequence.
            ddim_timesteps: Number of DDIM steps (previously fixed at 10).

        Raises:
            ValueError: If ``ddim_discr_method`` is unknown, or
                ``ddim_timesteps`` is not usable with ``n_timesteps``.
        """
        super().__init__()
        self.horizon = horizon
        self.observation_dim = observation_dim
        self.action_dim = action_dim
        self.class_dim = class_dim
        self.model = model

        self.n_timesteps = n_timesteps
        self.clip_denoised = clip_denoised
        # eta scales the DDIM sigma; 0.0 makes the DDIM update noise-free.
        self.eta = 0.0
        # Multiplier applied to the Gaussian noise used by q_sample/p_losses/p_sample.
        self.random_ratio = 1.0
        self.act_mean = act_mean
        self.act_std = act_std

        # ---------------------------ddim--------------------------------
        # Validated first so that a bad configuration fails before any
        # schedule tensors are built.
        self.ddim_timesteps = ddim_timesteps
        self.ddim_timestep_seq = self._build_ddim_timestep_seq(
            n_timesteps, ddim_timesteps, ddim_discr_method)
        # ----------------------------------------------------------------

        betas = cosine_beta_schedule(n_timesteps)
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # Shift right by one step; the entry before the first step is 1.
        alphas_cumprod_prev = torch.cat([alphas_cumprod.new_ones(1), alphas_cumprod[:-1]])

        self.register_buffer('betas', betas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        self.register_buffer('posterior_variance', posterior_variance)

        # log calculation clipped because the posterior variance is 0 at the
        # beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped',
                             torch.log(torch.clamp(posterior_variance, min=1e-20)))
        self.register_buffer('posterior_mean_coef1',
                             betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        self.register_buffer('posterior_mean_coef2',
                             (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

        self.loss_type = loss_type
        self.loss_fn = Losses[loss_type](None, self.action_dim, self.class_dim)

    # ------------------------------------------ helpers ------------------------------------------#

    @staticmethod
    def _build_ddim_timestep_seq(n_timesteps, ddim_timesteps, ddim_discr_method):
        """Return the increasing array of timesteps visited by DDIM sampling."""
        if ddim_timesteps < 1:
            raise ValueError(f'ddim_timesteps must be >= 1, got {ddim_timesteps}')

        if ddim_discr_method == 'uniform':
            stride = n_timesteps // ddim_timesteps
            if stride < 1:
                raise ValueError(
                    f'n_timesteps ({n_timesteps}) must be >= ddim_timesteps ({ddim_timesteps})')
            return np.arange(0, n_timesteps, stride)

        if ddim_discr_method == 'quad':
            seq = (np.linspace(0, np.sqrt(n_timesteps), ddim_timesteps) ** 2).astype(int)
            # sqrt(n) ** 2 can truncate to n itself, one past the last
            # schedule index; cap it so every entry is a valid timestep.
            return np.minimum(seq, n_timesteps - 1)

        raise ValueError(f'unknown ddim_discr_method: {ddim_discr_method!r}')

    @staticmethod
    def _nonzero_mask(x, t):
        """Return a mask broadcastable to ``x`` that is 0 where ``t == 0``, else 1."""
        return (1 - (t == 0).float()).reshape(x.shape[0], *((1,) * (len(x.shape) - 1)))

    def _maybe_clip(self, x_recon):
        """Clamp the model prediction to ``[-1, 1]`` when ``clip_denoised`` is set."""
        if self.clip_denoised:
            # clamp() is out-of-place, so the result must be kept.
            return x_recon.clamp(-1., 1.)
        return x_recon

    # ------------------------------------------ sampling ------------------------------------------#

    def q_posterior(self, x_start, x_t, t):
        """Return mean, variance and clipped log-variance of q(x_{t-1} | x_t, x_0)."""
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, x, cond, task_label, t):
        """Run the model on ``x`` at step ``t`` and return the posterior statistics.

        ``cond`` is accepted for interface compatibility and is not used here.
        """
        x_recon = self._maybe_clip(self.model(x, t, task_label))
        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
            x_start=x_recon, x_t=x, t=t)
        return model_mean, posterior_variance, posterior_log_variance

    @torch.no_grad()
    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
        """Recover the noise term implied by ``x_t`` and a predicted clean sample."""
        scaled_x_t = extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
        return (scaled_x_t - pred_xstart) / extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    @torch.no_grad()
    def p_sample_ddim(self, x, cond, avg_mask, task_label, t, t_prev, noise, if_prev=False, if_avg_mask=False):
        """Perform one DDIM step from timestep ``t`` to ``t_prev``.

        Args:
            x: Current sample.
            cond, avg_mask, if_avg_mask: Accepted for interface
                compatibility; not used by this method.
            task_label: Forwarded to the model.
            t: Current timestep indices, one per batch element.
            t_prev: Target timestep indices, one per batch element.
            noise: Tensor multiplied by ``sigma`` and added to the mean.
            if_prev: If true, the target cumulative alpha is read from
                ``alphas_cumprod_prev`` instead of ``alphas_cumprod``.

        Returns:
            The sample at ``t_prev``.
        """
        x_recon = self._maybe_clip(self.model(x, t, task_label))
        eps = self._predict_eps_from_xstart(x, t, x_recon)

        alpha_bar = extract(self.alphas_cumprod, t, x.shape)
        prev_table = self.alphas_cumprod_prev if if_prev else self.alphas_cumprod
        alpha_bar_prev = extract(prev_table, t_prev, x.shape)

        sigma = (
            self.eta
            * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
            * torch.sqrt(1 - alpha_bar / alpha_bar_prev)
        )
        mean_pred = (
            x_recon * torch.sqrt(alpha_bar_prev)
            + torch.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
        )
        return mean_pred + self._nonzero_mask(x, t) * sigma * noise

    @torch.no_grad()
    def p_sample(self, x, cond, task_label, t, noise=None):
        """Perform one ancestral sampling step at timestep ``t``.

        If ``noise`` is omitted, fresh Gaussian noise scaled by
        ``random_ratio`` is drawn. No noise is added where ``t == 0``.
        """
        model_mean, _, model_log_variance = self.p_mean_variance(
            x=x, cond=cond, task_label=task_label, t=t)
        if noise is None:
            noise = torch.randn_like(x) * self.random_ratio
        return model_mean + self._nonzero_mask(x, t) * (0.5 * model_log_variance).exp() * noise

    @torch.no_grad()
    def p_sample_loop(self, cond, avg_mask, task_label, if_jump, if_avg_mask):
        """Generate a sample by iterating the reverse process.

        Args:
            cond: Conditioning passed to ``condition_projection``;
                ``len(cond[0])`` is used as the batch size.
            avg_mask: Forwarded to ``p_sample_ddim``.
            task_label: Forwarded to the model.
            if_jump: If true, use the DDIM timestep sequence; otherwise
                walk all ``n_timesteps`` steps.
            if_avg_mask: Forwarded to ``p_sample_ddim``.

        Returns:
            Tensor of shape ``(batch, horizon, class_dim + action_dim +
            observation_dim)``.
        """
        device = self.betas.device
        batch_size = len(cond[0])
        shape = (batch_size, self.horizon, self.class_dim + self.action_dim + self.observation_dim)

        # Per-horizon-step mean/std of the distribution the chain starts from.
        mean = torch.zeros(shape, device=device)
        std = torch.zeros(shape, device=device)
        for i in range(shape[1]):
            std[:, i, :] = std[:, i, :] + self.act_std[i]
            mean[:, i, :] = mean[:, i, :] + self.act_mean[i]

        x = torch.normal(mean, std)
        # The initial draw doubles as the noise tensor for every DDIM step.
        noise = x
        x = condition_projection(x, cond, self.action_dim, self.class_dim)

        if not if_jump:
            for i in reversed(range(self.n_timesteps)):
                timesteps = torch.full((batch_size,), i, device=device, dtype=torch.long)
                # p_sample draws fresh noise for each step.
                x = self.p_sample(x, cond, task_label, timesteps)
                x = condition_projection(x, cond, self.action_dim, self.class_dim)
            return x

        for i in reversed(range(self.ddim_timesteps)):
            is_final = i == 0
            step_now = int(self.ddim_timestep_seq[i])
            step_prev = 0 if is_final else int(self.ddim_timestep_seq[i - 1])
            timesteps = torch.full((batch_size,), step_now, device=device, dtype=torch.long)
            timesteps_prev = torch.full((batch_size,), step_prev, device=device, dtype=torch.long)
            # Flags are passed by keyword so if_avg_mask can never land in
            # the if_prev slot; only the final step uses alphas_cumprod_prev.
            x = self.p_sample_ddim(x, cond, avg_mask, task_label, timesteps, timesteps_prev, noise,
                                   if_prev=is_final, if_avg_mask=if_avg_mask)
            x = condition_projection(x, cond, self.action_dim, self.class_dim)

        return x

    # ------------------------------------------ training ------------------------------------------#

    def q_sample(self, x_start, t, noise=None):
        """Diffuse ``x_start`` forward to timestep ``t``.

        If ``noise`` is omitted, Gaussian noise scaled by ``random_ratio``
        is drawn.
        """
        if noise is None:
            noise = torch.randn_like(x_start) * self.random_ratio

        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def p_losses(self, x_start, cond, t, act_emb_noise, task_label):
        """Compute the training loss for ``x_start`` at timesteps ``t``.

        ``act_emb_noise`` (scaled by ``random_ratio``) is the noise used to
        diffuse ``x_start``. Both the noisy input and the model output are
        passed through ``condition_projection`` before the loss is taken
        against ``x_start``.
        """
        noise = act_emb_noise * self.random_ratio

        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        x_noisy = condition_projection(x_noisy, cond, self.action_dim, self.class_dim)

        x_recon = self.model(x_noisy, t, task_label)  # with class condition
        x_recon = condition_projection(x_recon, cond, self.action_dim, self.class_dim)

        return self.loss_fn(x_recon, x_start)

    def loss(self, x, cond, act_emb_noise, task_label):
        """Draw one uniform random timestep per batch element and return ``p_losses``."""
        batch_size = len(x)
        t = torch.randint(0, self.n_timesteps, (batch_size,), device=x.device).long()
        return self.p_losses(x, cond, t, act_emb_noise, task_label)

    def forward(self, cond, avg_mask, task_label, if_jump=False, if_avg_mask=False):
        """Sample from the model; see ``p_sample_loop`` for the arguments."""
        return self.p_sample_loop(cond, avg_mask, task_label, if_jump, if_avg_mask)
|