first commit
Commit 8f8cf48929
2819 changed files with 33143 additions and 0 deletions
model/diffusion_act.py  (new file, 212 lines)
@@ -0,0 +1,212 @@
import random
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .helpers import (
|
||||
cosine_beta_schedule,
|
||||
extract,
|
||||
condition_projection,
|
||||
Losses,
|
||||
)
|
||||
|
||||
class GaussianDiffusion(nn.Module):
|
||||
def __init__(self, model, horizon, observation_dim, action_dim, class_dim, n_timesteps=200,
|
||||
loss_type='Weighted_MSE', clip_denoised=False, ddim_discr_method='uniform',
|
||||
):
|
||||
super().__init__()
|
||||
self.horizon = horizon
|
||||
self.observation_dim = observation_dim
|
||||
self.action_dim = action_dim
|
||||
self.class_dim = class_dim
|
||||
self.model = model
|
||||
|
||||
betas = cosine_beta_schedule(n_timesteps)
|
||||
alphas = 1. - betas
|
||||
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
||||
alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])
|
||||
|
||||
self.n_timesteps = n_timesteps
|
||||
self.clip_denoised = clip_denoised
|
||||
self.eta = 0.0
|
||||
self.random_ratio = 1.0
|
||||
|
||||
# ---------------------------ddim--------------------------------
|
||||
ddim_timesteps = 10
|
||||
|
||||
if ddim_discr_method == 'uniform':
|
||||
c = n_timesteps // ddim_timesteps
|
||||
ddim_timestep_seq = np.asarray(list(range(0, n_timesteps, c)))
|
||||
elif ddim_discr_method == 'quad':
|
||||
ddim_timestep_seq = (
|
||||
(np.linspace(0, np.sqrt(n_timesteps), ddim_timesteps)) ** 2
|
||||
).astype(int)
|
||||
else:
    raise RuntimeError(f'unknown ddim_discr_method: {ddim_discr_method}')
|
||||
|
||||
self.ddim_timesteps = ddim_timesteps
|
||||
self.ddim_timestep_seq = ddim_timestep_seq
|
||||
# ----------------------------------------------------------------
|
||||
|
||||
self.register_buffer('betas', betas)
|
||||
self.register_buffer('alphas_cumprod', alphas_cumprod)
|
||||
self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
|
||||
self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
|
||||
self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
|
||||
self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
|
||||
self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
|
||||
|
||||
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
||||
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
|
||||
self.register_buffer('posterior_variance', posterior_variance)
|
||||
|
||||
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
||||
self.register_buffer('posterior_log_variance_clipped',
|
||||
torch.log(torch.clamp(posterior_variance, min=1e-20)))
|
||||
self.register_buffer('posterior_mean_coef1',
|
||||
betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
|
||||
self.register_buffer('posterior_mean_coef2',
|
||||
(1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))
|
||||
self.loss_type = loss_type
|
||||
self.loss_fn = Losses[loss_type](None, self.action_dim, self.class_dim)
|
||||
|
||||
# ------------------------------------------ sampling ------------------------------------------#
|
||||
|
||||
def q_posterior(self, x_start, x_t, t):
|
||||
posterior_mean = (
|
||||
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
|
||||
extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
||||
)
|
||||
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
|
||||
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
|
||||
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
||||
|
||||
def p_mean_variance(self, x, cond, task_label, t):
|
||||
#x_recon = self.model(x, t)
|
||||
x_recon = self.model(x, t, task_label)
|
||||
|
||||
if self.clip_denoised:
    x_recon = x_recon.clamp(-1., 1.)
else:
    # NOTE: the original `assert RuntimeError()` here is always truthy, i.e. a
    # no-op, so the unclipped prediction is kept as-is.
    pass
|
||||
|
||||
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
|
||||
x_start=x_recon, x_t=x, t=t)
|
||||
return model_mean, posterior_variance, posterior_log_variance
|
||||
|
||||
@torch.no_grad()
|
||||
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
|
||||
return \
|
||||
(extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) \
|
||||
/ extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_ddim(self, x, cond, avg_mask, task_label, t, t_prev, if_prev=False, if_avg_mask=False):
|
||||
b, *_, device = *x.shape, x.device
|
||||
#x_recon = self.model(x, t) # without class condition
|
||||
x_recon = self.model(x, t, task_label)
|
||||
|
||||
if self.clip_denoised:
    x_recon = x_recon.clamp(-1., 1.)
else:
    # NOTE: the original `assert RuntimeError()` here is always truthy, i.e. a
    # no-op, so the unclipped prediction is kept as-is.
    pass
|
||||
|
||||
eps = self._predict_eps_from_xstart(x, t, x_recon)
|
||||
alpha_bar = extract(self.alphas_cumprod, t, x.shape)
|
||||
if if_prev:
|
||||
alpha_bar_prev = extract(self.alphas_cumprod_prev, t_prev, x.shape)
|
||||
else:
|
||||
alpha_bar_prev = extract(self.alphas_cumprod, t_prev, x.shape)
|
||||
sigma = (
|
||||
self.eta
|
||||
* torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
|
||||
* torch.sqrt(1 - alpha_bar / alpha_bar_prev)
|
||||
)
|
||||
|
||||
noise = torch.randn_like(x) * self.random_ratio
|
||||
|
||||
mean_pred = (
|
||||
x_recon * torch.sqrt(alpha_bar_prev)
|
||||
+ torch.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
|
||||
)
|
||||
|
||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||
return mean_pred + nonzero_mask * sigma * noise
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample(self, x, cond, task_label, t):
|
||||
b, *_, device = *x.shape, x.device
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(x=x, cond=cond, task_label=task_label, t=t)
|
||||
noise = torch.randn_like(x) * self.random_ratio
|
||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_loop(self, cond, avg_mask, task_label, if_jump, if_avg_mask):
|
||||
device = self.betas.device
|
||||
batch_size = len(cond[0])
|
||||
horizon = self.horizon
|
||||
shape = (batch_size, horizon, self.class_dim + self.action_dim + self.observation_dim)
|
||||
|
||||
x = torch.randn(shape, device=device) * self.random_ratio # xt for Noise and diffusion
|
||||
x = condition_projection(x, cond, self.action_dim, self.class_dim)
|
||||
|
||||
if not if_jump:
|
||||
for i in reversed(range(0, self.n_timesteps)):
|
||||
timesteps = torch.full((batch_size,), i, device=device, dtype=torch.long)
|
||||
x = self.p_sample(x, cond, task_label, timesteps)
|
||||
x = condition_projection(x, cond, self.action_dim, self.class_dim)
|
||||
|
||||
else:
|
||||
for i in reversed(range(0, self.ddim_timesteps)):
|
||||
timesteps = torch.full((batch_size,), self.ddim_timestep_seq[i], device=device, dtype=torch.long)
|
||||
if i == 0:
|
||||
timesteps_prev = torch.full((batch_size,), 0, device=device, dtype=torch.long)
|
||||
x = self.p_sample_ddim(x, cond, avg_mask, task_label, timesteps, timesteps_prev, True, if_avg_mask)
|
||||
else:
|
||||
timesteps_prev = torch.full((batch_size,), self.ddim_timestep_seq[i-1], device=device, dtype=torch.long)
|
||||
x = self.p_sample_ddim(x, cond, avg_mask, task_label, timesteps, timesteps_prev, if_avg_mask=if_avg_mask)
|
||||
x = condition_projection(x, cond, self.action_dim, self.class_dim)
|
||||
|
||||
return x
|
||||
|
||||
# ------------------------------------------ training ------------------------------------------#
|
||||
|
||||
def q_sample(self, x_start, t, noise=None):
|
||||
if noise is None:
|
||||
noise = torch.randn_like(x_start) * self.random_ratio
|
||||
|
||||
sample = (
|
||||
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
|
||||
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
||||
)
|
||||
|
||||
return sample
|
||||
|
||||
def p_losses(self, x_start, cond, t, act_emb_noise, task_label):
|
||||
noise = act_emb_noise * self.random_ratio
|
||||
#print('act noise shape', noise.shape)
|
||||
#print('x_start shape', x_start.shape)
|
||||
|
||||
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) # for diffusion, should be removed for Noise and Deterministic
|
||||
x_noisy = condition_projection(x_noisy, cond, self.action_dim, self.class_dim)
|
||||
|
||||
#x_recon = self.model(x_noisy, t) # without class condition
|
||||
x_recon = self.model(x_noisy, t, task_label) # with class condition
|
||||
#print('x_recon',x_recon.shape)
|
||||
x_recon = condition_projection(x_recon, cond, self.action_dim, self.class_dim)
|
||||
|
||||
loss = self.loss_fn(x_recon, x_start)
|
||||
return loss
|
||||
|
||||
def loss(self, x, cond, act_emb_noise, task_label):
|
||||
batch_size = len(x) # for diffusion
|
||||
t = torch.randint(0, self.n_timesteps, (batch_size,), device=x.device).long() # for diffusion
|
||||
# t = None # for Noise and Deterministic
|
||||
return self.p_losses(x, cond, t, act_emb_noise, task_label)
|
||||
|
||||
def forward(self, cond, avg_mask, task_label, if_jump=False, if_avg_mask=False):
|
||||
return self.p_sample_loop(cond, avg_mask, task_label, if_jump, if_avg_mask)
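
# --- Illustrative usage sketch (added; not part of the original file) ---------
# A minimal, hypothetical example of driving GaussianDiffusion: compute the
# training loss on a clean plan tensor, then sample a plan with the DDIM path
# (if_jump=True). The toy denoiser, the dimensions and the conditioning layout
# below are assumptions for illustration only; run with
# `python -m model.diffusion_act` because of the relative import above.
if __name__ == '__main__':
    class _ToyDenoiser(nn.Module):
        # hypothetical stand-in for the real temporal model:
        # forward(x, t, task_label) must return a tensor shaped like x
        def forward(self, x, t, task_label):
            return x

    class_dim, action_dim, observation_dim, horizon, B = 4, 3, 8, 5, 2
    diffusion = GaussianDiffusion(_ToyDenoiser(), horizon, observation_dim,
                                  action_dim, class_dim, n_timesteps=20)

    x0 = torch.randn(B, horizon, class_dim + action_dim + observation_dim)
    cond = {
        0: torch.randn(B, observation_dim),            # start observation
        horizon - 1: torch.randn(B, observation_dim),  # goal observation
        'task': torch.randn(B, horizon, class_dim),    # task conditioning
    }
    act_emb_noise = torch.randn_like(x0)
    task_label = torch.zeros(B, dtype=torch.long)

    loss = diffusion.loss(x0, cond, act_emb_noise, task_label)  # training objective
    plan = diffusion(cond, avg_mask=None, task_label=task_label, if_jump=True)
    print(loss.item(), plan.shape)  # scalar loss, (B, horizon, class+action+obs dims)
# -------------------------------------------------------------------------------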
model/diffusion_act_dist.py  (new file, 222 lines)
@@ -0,0 +1,222 @@
import random
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .helpers import (
|
||||
cosine_beta_schedule,
|
||||
extract,
|
||||
condition_projection,
|
||||
Losses,
|
||||
)
|
||||
|
||||
class GaussianDiffusion(nn.Module):
|
||||
def __init__(self, model, horizon, observation_dim, action_dim, class_dim, act_mean, act_std, n_timesteps=200,
|
||||
loss_type='Weighted_MSE', clip_denoised=False, ddim_discr_method='uniform',
|
||||
):
|
||||
super().__init__()
|
||||
self.horizon = horizon
|
||||
self.observation_dim = observation_dim
|
||||
self.action_dim = action_dim
|
||||
self.class_dim = class_dim
|
||||
self.model = model
|
||||
|
||||
betas = cosine_beta_schedule(n_timesteps)
|
||||
alphas = 1. - betas
|
||||
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
||||
alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])
|
||||
|
||||
self.n_timesteps = n_timesteps
|
||||
self.clip_denoised = clip_denoised
|
||||
self.eta = 0.0
|
||||
self.random_ratio = 1.0
|
||||
self.act_mean = act_mean
|
||||
self.act_std = act_std
|
||||
|
||||
# ---------------------------ddim--------------------------------
|
||||
ddim_timesteps = 10
|
||||
|
||||
if ddim_discr_method == 'uniform':
|
||||
c = n_timesteps // ddim_timesteps
|
||||
ddim_timestep_seq = np.asarray(list(range(0, n_timesteps, c)))
|
||||
elif ddim_discr_method == 'quad':
|
||||
ddim_timestep_seq = (
|
||||
(np.linspace(0, np.sqrt(n_timesteps), ddim_timesteps)) ** 2
|
||||
).astype(int)
|
||||
else:
    raise RuntimeError(f'unknown ddim_discr_method: {ddim_discr_method}')
|
||||
|
||||
self.ddim_timesteps = ddim_timesteps
|
||||
self.ddim_timestep_seq = ddim_timestep_seq
|
||||
# ----------------------------------------------------------------
|
||||
|
||||
self.register_buffer('betas', betas)
|
||||
self.register_buffer('alphas_cumprod', alphas_cumprod)
|
||||
self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
|
||||
self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
|
||||
self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
|
||||
self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
|
||||
self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
|
||||
|
||||
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
||||
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
|
||||
self.register_buffer('posterior_variance', posterior_variance)
|
||||
|
||||
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
||||
self.register_buffer('posterior_log_variance_clipped',
|
||||
torch.log(torch.clamp(posterior_variance, min=1e-20)))
|
||||
self.register_buffer('posterior_mean_coef1',
|
||||
betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
|
||||
self.register_buffer('posterior_mean_coef2',
|
||||
(1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))
|
||||
self.loss_type = loss_type
|
||||
self.loss_fn = Losses[loss_type](None, self.action_dim, self.class_dim)
|
||||
|
||||
# ------------------------------------------ sampling ------------------------------------------#
|
||||
|
||||
def q_posterior(self, x_start, x_t, t):
|
||||
posterior_mean = (
|
||||
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
|
||||
extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
||||
)
|
||||
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
|
||||
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
|
||||
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
||||
|
||||
def p_mean_variance(self, x, cond, task_label, t):
|
||||
#x_recon = self.model(x, t)
|
||||
x_recon = self.model(x, t, task_label)
|
||||
|
||||
if self.clip_denoised:
    x_recon = x_recon.clamp(-1., 1.)
else:
    # NOTE: the original `assert RuntimeError()` here is always truthy, i.e. a
    # no-op, so the unclipped prediction is kept as-is.
    pass
|
||||
|
||||
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
|
||||
x_start=x_recon, x_t=x, t=t)
|
||||
return model_mean, posterior_variance, posterior_log_variance
|
||||
|
||||
@torch.no_grad()
|
||||
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
|
||||
return \
|
||||
(extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) \
|
||||
/ extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_ddim(self, x, cond, avg_mask, task_label, t, t_prev, noise, if_prev=False, if_avg_mask=False):
|
||||
b, *_, device = *x.shape, x.device
|
||||
#x_recon = self.model(x, t) # without class condition
|
||||
x_recon = self.model(x, t, task_label)
|
||||
|
||||
if self.clip_denoised:
    x_recon = x_recon.clamp(-1., 1.)
else:
    # NOTE: the original `assert RuntimeError()` here is always truthy, i.e. a
    # no-op, so the unclipped prediction is kept as-is.
    pass
|
||||
|
||||
eps = self._predict_eps_from_xstart(x, t, x_recon)
|
||||
alpha_bar = extract(self.alphas_cumprod, t, x.shape)
|
||||
if if_prev:
|
||||
alpha_bar_prev = extract(self.alphas_cumprod_prev, t_prev, x.shape)
|
||||
else:
|
||||
alpha_bar_prev = extract(self.alphas_cumprod, t_prev, x.shape)
|
||||
sigma = (
|
||||
self.eta
|
||||
* torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
|
||||
* torch.sqrt(1 - alpha_bar / alpha_bar_prev)
|
||||
)
|
||||
|
||||
mean_pred = (
|
||||
x_recon * torch.sqrt(alpha_bar_prev)
|
||||
+ torch.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
|
||||
)
|
||||
|
||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||
return mean_pred + nonzero_mask * sigma * noise
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample(self, x, cond, task_label, t, noise):
|
||||
b, *_, device = *x.shape, x.device
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(x=x, cond=cond, task_label=task_label, t=t)
|
||||
#noise = torch.randn_like(x) * self.random_ratio
|
||||
|
||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_loop(self, cond, avg_mask, task_label, if_jump, if_avg_mask):
|
||||
device = self.betas.device
|
||||
batch_size = len(cond[0])
|
||||
horizon = self.horizon
|
||||
shape = (batch_size, horizon, self.class_dim + self.action_dim + self.observation_dim)
|
||||
|
||||
|
||||
mean = torch.zeros(shape, device=device)
|
||||
std = torch.zeros(shape, device=device)
|
||||
|
||||
for i in range(shape[1]):
|
||||
std[:,i,:] = std[:,i,:] + self.act_std[i]
|
||||
mean[:,i,:] = mean[:,i,:] + self.act_mean[i]
|
||||
|
||||
x = torch.normal(mean, std)
|
||||
noise = x
|
||||
x = condition_projection(x, cond, self.action_dim, self.class_dim)
|
||||
|
||||
|
||||
if not if_jump:
|
||||
for i in reversed(range(0, self.n_timesteps)):
|
||||
timesteps = torch.full((batch_size,), i, device=device, dtype=torch.long)
|
||||
x = self.p_sample(x, cond, task_label, timesteps, noise)
|
||||
x = condition_projection(x, cond, self.action_dim, self.class_dim)
|
||||
|
||||
else:
|
||||
for i in reversed(range(0, self.ddim_timesteps)):
|
||||
timesteps = torch.full((batch_size,), self.ddim_timestep_seq[i], device=device, dtype=torch.long)
|
||||
if i == 0:
|
||||
timesteps_prev = torch.full((batch_size,), 0, device=device, dtype=torch.long)
|
||||
x = self.p_sample_ddim(x, cond, avg_mask, task_label, timesteps, timesteps_prev, noise, True, if_avg_mask)
|
||||
else:
|
||||
timesteps_prev = torch.full((batch_size,), self.ddim_timestep_seq[i-1], device=device, dtype=torch.long)
|
||||
x = self.p_sample_ddim(x, cond, avg_mask, task_label, timesteps, timesteps_prev, noise, if_avg_mask=if_avg_mask)
|
||||
x = condition_projection(x, cond, self.action_dim, self.class_dim)
|
||||
|
||||
return x
|
||||
|
||||
# ------------------------------------------ training ------------------------------------------#
|
||||
|
||||
def q_sample(self, x_start, t, noise=None):
|
||||
if noise is None:
|
||||
noise = torch.randn_like(x_start) * self.random_ratio
|
||||
|
||||
sample = (
|
||||
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
|
||||
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
||||
)
|
||||
|
||||
return sample
|
||||
|
||||
def p_losses(self, x_start, cond, t, act_emb_noise, task_label):
|
||||
noise = act_emb_noise * self.random_ratio
|
||||
#print('act noise shape', noise.shape)
|
||||
#print('x_start shape', x_start.shape)
|
||||
|
||||
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) # for diffusion, should be removed for Noise and Deterministic
|
||||
x_noisy = condition_projection(x_noisy, cond, self.action_dim, self.class_dim)
|
||||
|
||||
#x_recon = self.model(x_noisy, t) # without class condition
|
||||
x_recon = self.model(x_noisy, t, task_label) # with class condition
|
||||
#print('x_recon',x_recon.shape)
|
||||
x_recon = condition_projection(x_recon, cond, self.action_dim, self.class_dim)
|
||||
|
||||
loss = self.loss_fn(x_recon, x_start)
|
||||
return loss
|
||||
|
||||
def loss(self, x, cond, act_emb_noise, task_label):
|
||||
batch_size = len(x)
|
||||
t = torch.randint(0, self.n_timesteps, (batch_size,), device=x.device).long() # for diffusion
|
||||
return self.p_losses(x, cond, t, act_emb_noise, task_label)
|
||||
|
||||
def forward(self, cond, avg_mask, task_label, if_jump=False, if_avg_mask=False):
|
||||
return self.p_sample_loop(cond, avg_mask, task_label, if_jump, if_avg_mask)
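
# --- Illustrative note (added; not part of the original file) -----------------
# The main behavioural difference from model/diffusion_act.py is that sampling
# starts from a per-horizon-step Gaussian N(act_mean[i], act_std[i]) instead of
# N(0, 1), and that this initial draw is reused as `noise` in every denoising
# step. A minimal, self-contained sketch of the initialization (the shapes and
# statistics below are assumptions):
if __name__ == '__main__':
    B, horizon, D = 2, 5, 15
    act_mean = [0.0, 0.1, 0.2, 0.1, 0.0]  # hypothetical per-step statistics
    act_std = [1.0, 0.9, 0.8, 0.9, 1.0]
    mean = torch.zeros(B, horizon, D)
    std = torch.zeros(B, horizon, D)
    for i in range(horizon):
        mean[:, i, :] += act_mean[i]
        std[:, i, :] += act_std[i]
    x_init = torch.normal(mean, std)      # replaces torch.randn in diffusion_act.py
    print(x_init.shape)                   # torch.Size([2, 5, 15])
# -------------------------------------------------------------------------------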
model/diffusion_no_mask.py  (new file, 208 lines)
@@ -0,0 +1,208 @@
import random
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .helpers import (
|
||||
cosine_beta_schedule,
|
||||
extract,
|
||||
condition_projection,
|
||||
Losses,
|
||||
)
|
||||
|
||||
class GaussianDiffusion(nn.Module):
|
||||
def __init__(self, model, horizon, observation_dim, action_dim, class_dim, n_timesteps=200,
|
||||
loss_type='Weighted_MSE', clip_denoised=False, ddim_discr_method='uniform',
|
||||
):
|
||||
super().__init__()
|
||||
self.horizon = horizon
|
||||
self.observation_dim = observation_dim
|
||||
self.action_dim = action_dim
|
||||
self.class_dim = class_dim
|
||||
self.model = model
|
||||
|
||||
betas = cosine_beta_schedule(n_timesteps)
|
||||
alphas = 1. - betas
|
||||
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
||||
alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])
|
||||
|
||||
self.n_timesteps = n_timesteps
|
||||
self.clip_denoised = clip_denoised
|
||||
self.eta = 0.0
|
||||
self.random_ratio = 1.0
|
||||
|
||||
# ---------------------------ddim--------------------------------
|
||||
ddim_timesteps = 10
|
||||
|
||||
if ddim_discr_method == 'uniform':
|
||||
c = n_timesteps // ddim_timesteps
|
||||
ddim_timestep_seq = np.asarray(list(range(0, n_timesteps, c)))
|
||||
elif ddim_discr_method == 'quad':
|
||||
ddim_timestep_seq = (
|
||||
(np.linspace(0, np.sqrt(n_timesteps), ddim_timesteps)) ** 2
|
||||
).astype(int)
|
||||
else:
    raise RuntimeError(f'unknown ddim_discr_method: {ddim_discr_method}')
|
||||
|
||||
self.ddim_timesteps = ddim_timesteps
|
||||
self.ddim_timestep_seq = ddim_timestep_seq
|
||||
# ----------------------------------------------------------------
|
||||
|
||||
self.register_buffer('betas', betas)
|
||||
self.register_buffer('alphas_cumprod', alphas_cumprod)
|
||||
self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
|
||||
self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
|
||||
self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
|
||||
self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
|
||||
self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
|
||||
|
||||
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
||||
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
|
||||
self.register_buffer('posterior_variance', posterior_variance)
|
||||
|
||||
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
||||
self.register_buffer('posterior_log_variance_clipped',
|
||||
torch.log(torch.clamp(posterior_variance, min=1e-20)))
|
||||
self.register_buffer('posterior_mean_coef1',
|
||||
betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
|
||||
self.register_buffer('posterior_mean_coef2',
|
||||
(1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))
|
||||
self.loss_type = loss_type
|
||||
self.loss_fn = Losses[loss_type](None, self.action_dim, self.class_dim)
|
||||
|
||||
# ------------------------------------------ sampling ------------------------------------------#
|
||||
|
||||
def q_posterior(self, x_start, x_t, t):
|
||||
posterior_mean = (
|
||||
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
|
||||
extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
||||
)
|
||||
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
|
||||
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
|
||||
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
||||
|
||||
def p_mean_variance(self, x, cond, t):
|
||||
x_recon = self.model(x, t, 0)  # class/task condition fixed to 0 in this variant
|
||||
|
||||
if self.clip_denoised:
    x_recon = x_recon.clamp(-1., 1.)
else:
    # NOTE: the original `assert RuntimeError()` here is always truthy, i.e. a
    # no-op, so the unclipped prediction is kept as-is.
    pass
|
||||
|
||||
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
|
||||
x_start=x_recon, x_t=x, t=t)
|
||||
return model_mean, posterior_variance, posterior_log_variance
|
||||
|
||||
@torch.no_grad()
|
||||
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
|
||||
return \
|
||||
(extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) \
|
||||
/ extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_ddim(self, x, cond, t, t_prev, if_prev=False):
|
||||
b, *_, device = *x.shape, x.device
|
||||
x_recon = self.model(x, t, 0)  # class/task condition fixed to 0 in this variant
|
||||
|
||||
if self.clip_denoised:
    x_recon = x_recon.clamp(-1., 1.)
else:
    # NOTE: the original `assert RuntimeError()` here is always truthy, i.e. a
    # no-op, so the unclipped prediction is kept as-is.
    pass
|
||||
|
||||
eps = self._predict_eps_from_xstart(x, t, x_recon)
|
||||
alpha_bar = extract(self.alphas_cumprod, t, x.shape)
|
||||
if if_prev:
|
||||
alpha_bar_prev = extract(self.alphas_cumprod_prev, t_prev, x.shape)
|
||||
else:
|
||||
alpha_bar_prev = extract(self.alphas_cumprod, t_prev, x.shape)
|
||||
sigma = (
|
||||
self.eta
|
||||
* torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
|
||||
* torch.sqrt(1 - alpha_bar / alpha_bar_prev)
|
||||
)
|
||||
|
||||
noise = torch.randn_like(x) * self.random_ratio
|
||||
mean_pred = (
|
||||
x_recon * torch.sqrt(alpha_bar_prev)
|
||||
+ torch.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
|
||||
)
|
||||
|
||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||
return mean_pred + nonzero_mask * sigma * noise
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample(self, x, cond, t):
|
||||
b, *_, device = *x.shape, x.device
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(x=x, cond=cond, t=t)
|
||||
noise = torch.randn_like(x) * self.random_ratio
|
||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_loop(self, cond, if_jump):
|
||||
device = self.betas.device
|
||||
batch_size = len(cond[0])
|
||||
horizon = self.horizon
|
||||
shape = (batch_size, horizon, self.class_dim + self.action_dim + self.observation_dim)
|
||||
|
||||
x = torch.randn(shape, device=device) * self.random_ratio # xt for Noise and diffusion
|
||||
# x = torch.zeros(shape, device=device) # for Deterministic
|
||||
x = condition_projection(x, cond, self.action_dim, self.class_dim)
|
||||
|
||||
if not if_jump:
|
||||
for i in reversed(range(0, self.n_timesteps)):
|
||||
timesteps = torch.full((batch_size,), i, device=device, dtype=torch.long)
|
||||
x = self.p_sample(x, cond, timesteps)
|
||||
x = condition_projection(x, cond, self.action_dim, self.class_dim)
|
||||
|
||||
else:
|
||||
for i in reversed(range(0, self.ddim_timesteps)):
|
||||
timesteps = torch.full((batch_size,), self.ddim_timestep_seq[i], device=device, dtype=torch.long)
|
||||
if i == 0:
|
||||
timesteps_prev = torch.full((batch_size,), 0, device=device, dtype=torch.long)
|
||||
x = self.p_sample_ddim(x, cond, timesteps, timesteps_prev, True)
|
||||
else:
|
||||
timesteps_prev = torch.full((batch_size,), self.ddim_timestep_seq[i-1], device=device, dtype=torch.long)
|
||||
x = self.p_sample_ddim(x, cond, timesteps, timesteps_prev)
|
||||
x = condition_projection(x, cond, self.action_dim, self.class_dim)
|
||||
|
||||
return x
|
||||
|
||||
# ------------------------------------------ training ------------------------------------------#
|
||||
|
||||
def q_sample(self, x_start, t, noise=None):
|
||||
if noise is None:
|
||||
noise = torch.randn_like(x_start) * self.random_ratio
|
||||
|
||||
sample = (
|
||||
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
|
||||
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
||||
)
|
||||
|
||||
return sample
|
||||
|
||||
def p_losses(self, x_start, cond, t):
|
||||
noise = torch.randn_like(x_start) * self.random_ratio # for Noise and diffusion
|
||||
# noise = torch.zeros_like(x_start) # for Deterministic
|
||||
# x_noisy = noise # for Noise and Deterministic
|
||||
|
||||
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) # for diffusion, should be removed for Noise and Deterministic
|
||||
x_noisy = condition_projection(x_noisy, cond, self.action_dim, self.class_dim)
|
||||
|
||||
x_recon = self.model(x_noisy, t, 0)
|
||||
x_recon = condition_projection(x_recon, cond, self.action_dim, self.class_dim)
|
||||
|
||||
loss = self.loss_fn(x_recon, x_start)
|
||||
return loss
|
||||
|
||||
def loss(self, x, cond):
|
||||
batch_size = len(x) # for diffusion
|
||||
t = torch.randint(0, self.n_timesteps, (batch_size,), device=x.device).long() # for diffusion
|
||||
# t = None # for Noise and Deterministic
|
||||
return self.p_losses(x, cond, t)
|
||||
|
||||
def forward(self, cond, if_jump=False):
|
||||
return self.p_sample_loop(cond, if_jump)
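
# --- Illustrative note (added; not part of the original file) -----------------
# How the DDIM sub-sequence used by p_sample_loop(if_jump=True) is built for the
# default settings (n_timesteps=200, ddim_timesteps=10); a standalone check:
if __name__ == '__main__':
    n_timesteps, ddim_timesteps = 200, 10
    uniform_seq = np.asarray(list(range(0, n_timesteps, n_timesteps // ddim_timesteps)))
    quad_seq = ((np.linspace(0, np.sqrt(n_timesteps), ddim_timesteps)) ** 2).astype(int)
    print(uniform_seq)  # [  0  20  40  60  80 100 120 140 160 180]
    print(quad_seq)     # quadratically spaced indices, denser near t = 0
# -------------------------------------------------------------------------------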
model/dit.py  (new file, 512 lines)
@@ -0,0 +1,512 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
# --------------------------------------------------------
|
||||
# References:
|
||||
# GLIDE: https://github.com/openai/glide-text2im
|
||||
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
|
||||
# --------------------------------------------------------
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
import math
|
||||
from timm.models.vision_transformer import Attention, Mlp # PatchEmbed
|
||||
from torch.nn import TransformerEncoder, TransformerEncoderLayer
|
||||
from .helpers import SinusoidalPosEmb
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
""" 2D Image to Patch Embedding
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
in_chans=3,
|
||||
embed_dim=768,
|
||||
norm_layer=None,
|
||||
flatten=True,
|
||||
bias=True,
|
||||
):
|
||||
super().__init__()
|
||||
img_size = (img_size, 1)
|
||||
patch_size = (patch_size, 1)
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
||||
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
||||
self.flatten = flatten
|
||||
|
||||
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
|
||||
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
#_assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
|
||||
#_assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
|
||||
x = self.proj(x)
|
||||
if self.flatten:
|
||||
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
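
# Shape note (added): in the DiT blocks below, x is (N, T, D) while shift/scale
# come out of the adaLN MLPs as (N, D); unsqueeze(1) broadcasts them over the T
# tokens so every token of a sample is modulated by the same conditioning vector.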
|
||||
|
||||
|
||||
#################################################################################
|
||||
# Embedding Layers for Timesteps and Class Labels #
|
||||
#################################################################################
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
"""
|
||||
Embeds scalar timesteps into vector representations.
|
||||
"""
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_size, hidden_size, bias=True),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10000):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
:param t: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param dim: the dimension of the output.
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
:return: an (N, D) Tensor of positional embeddings.
|
||||
"""
|
||||
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
||||
).to(device=t.device)
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
def forward(self, t):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
|
||||
class LabelEmbedder(nn.Module):
|
||||
"""
|
||||
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
|
||||
"""
|
||||
def __init__(self, num_classes, hidden_size, dropout_prob):
|
||||
super().__init__()
|
||||
use_cfg_embedding = dropout_prob > 0
|
||||
self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
|
||||
self.num_classes = num_classes
|
||||
self.dropout_prob = dropout_prob
|
||||
|
||||
def token_drop(self, labels, force_drop_ids=None):
|
||||
"""
|
||||
Drops labels to enable classifier-free guidance.
|
||||
"""
|
||||
if force_drop_ids is None:
|
||||
drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
|
||||
else:
|
||||
drop_ids = force_drop_ids == 1
|
||||
labels = torch.where(drop_ids, self.num_classes, labels)
|
||||
return labels
|
||||
|
||||
def forward(self, labels, train, force_drop_ids=None):
|
||||
use_dropout = self.dropout_prob > 0
|
||||
if (train and use_dropout) or (force_drop_ids is not None):
|
||||
labels = self.token_drop(labels, force_drop_ids)
|
||||
embeddings = self.embedding_table(labels)
|
||||
return embeddings
|
||||
|
||||
|
||||
#################################################################################
|
||||
# Core DiT Model #
|
||||
#################################################################################
|
||||
|
||||
class DiTBlock(nn.Module):
|
||||
"""
|
||||
A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
|
||||
"""
|
||||
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
|
||||
super().__init__()
|
||||
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
|
||||
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
||||
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
|
||||
x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
|
||||
x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
|
||||
return x
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
"""
|
||||
The final layer of DiT.
|
||||
"""
|
||||
def __init__(self, hidden_size, patch_size, out_channels):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(hidden_size, patch_size * 1 * out_channels, bias=True)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_size, 2 * hidden_size, bias=True)
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class DiT(nn.Module):
|
||||
"""
|
||||
Diffusion model with a Transformer backbone.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
input_size=32,
|
||||
patch_size=2,
|
||||
in_channels=4,
|
||||
hidden_size=384,
|
||||
depth=12,
|
||||
num_heads=6,
|
||||
mlp_ratio=4.0,
|
||||
class_dropout_prob=0.1,
|
||||
num_classes=1000,
|
||||
learn_sigma=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.learn_sigma = learn_sigma
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels * 2 if learn_sigma else in_channels
|
||||
self.patch_size = patch_size
|
||||
self.num_heads = num_heads
|
||||
|
||||
self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
|
||||
self.t_embedder = TimestepEmbedder(hidden_size)
|
||||
#self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
|
||||
num_patches = self.x_embedder.num_patches
|
||||
# Will use fixed sin-cos embedding:
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
|
||||
|
||||
self.blocks = nn.ModuleList([
|
||||
DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
|
||||
])
|
||||
self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
|
||||
self.initialize_weights()
|
||||
|
||||
def initialize_weights(self):
|
||||
# Initialize transformer layers:
|
||||
def _basic_init(module):
|
||||
if isinstance(module, nn.Linear):
|
||||
torch.nn.init.xavier_uniform_(module.weight)
|
||||
if module.bias is not None:
|
||||
nn.init.constant_(module.bias, 0)
|
||||
self.apply(_basic_init)
|
||||
|
||||
# Initialize (and freeze) pos_embed by sin-cos embedding:
|
||||
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches )) #** 0.5
|
||||
#print('pos_embed', pos_embed.shape, self.x_embedder.num_patches)
|
||||
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
|
||||
|
||||
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
|
||||
w = self.x_embedder.proj.weight.data
|
||||
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
||||
nn.init.constant_(self.x_embedder.proj.bias, 0)
|
||||
|
||||
# Initialize label embedding table:
|
||||
#nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
|
||||
|
||||
# Initialize timestep embedding MLP:
|
||||
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
||||
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
||||
|
||||
# Zero-out adaLN modulation layers in DiT blocks:
|
||||
for block in self.blocks:
|
||||
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
|
||||
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
|
||||
|
||||
# Zero-out output layers:
|
||||
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
|
||||
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
|
||||
nn.init.constant_(self.final_layer.linear.weight, 0)
|
||||
nn.init.constant_(self.final_layer.linear.bias, 0)
|
||||
|
||||
def unpatchify(self, x):
|
||||
"""
|
||||
x: (N, T, patch_size**2 * C)
|
||||
imgs: (N, H, W, C)
|
||||
"""
|
||||
c = self.out_channels
|
||||
p = self.x_embedder.patch_size[0]
|
||||
'''h = w = int(x.shape[1] ** 0.5)
|
||||
assert h * w == x.shape[1]'''
|
||||
h = x.shape[1]
|
||||
w = 1
|
||||
|
||||
#print(x.shape)
|
||||
|
||||
x = x.reshape(shape=(x.shape[0], h, w, p, 1, c))
|
||||
x = torch.einsum('nhwpqc->nchpwq', x)
|
||||
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * 1))
|
||||
return imgs
|
||||
|
||||
def forward(self, x, t):
|
||||
"""
|
||||
Forward pass of DiT.
|
||||
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
||||
t: (N,) tensor of diffusion timesteps
|
||||
(the class-label embedder is disabled in this variant, so forward takes no y argument)
|
||||
"""
|
||||
#print('x', x.shape, 'x_embedder', self.x_embedder(x).shape, 'pos_embed', self.pos_embed.shape)
|
||||
x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
|
||||
t = self.t_embedder(t) # (N, D)
|
||||
#y = self.y_embedder(y, self.training) # (N, D)
|
||||
#c = t + y # (N, D)
|
||||
c = t
|
||||
for block in self.blocks:
|
||||
x = block(x, c) # (N, T, D)
|
||||
x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
|
||||
x = self.unpatchify(x) # (N, out_channels, H, W)
|
||||
return x
|
||||
|
||||
def forward_with_cfg(self, x, t, y, cfg_scale):
|
||||
"""
|
||||
Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
|
||||
"""
|
||||
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
|
||||
half = x[: len(x) // 2]
|
||||
combined = torch.cat([half, half], dim=0)
|
||||
model_out = self.forward(combined, t)  # NOTE: forward() above takes no label argument in this variant, so y is unused here
|
||||
# For exact reproducibility reasons, we apply classifier-free guidance on only
|
||||
# three channels by default. The standard approach to cfg applies it to all channels.
|
||||
# This can be done by uncommenting the following line and commenting-out the line following that.
|
||||
# eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
|
||||
eps, rest = model_out[:, :3], model_out[:, 3:]
|
||||
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
|
||||
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
|
||||
eps = torch.cat([half_eps, half_eps], dim=0)
|
||||
return torch.cat([eps, rest], dim=1)
|
||||
|
||||
|
||||
#################################################################################
|
||||
# Sine/Cosine Positional Embedding Functions #
|
||||
#################################################################################
|
||||
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
|
||||
|
||||
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
|
||||
"""
|
||||
grid_size: int of the grid height and width
|
||||
return:
|
||||
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
||||
"""
|
||||
grid_h = np.arange(grid_size, dtype=np.float32)
|
||||
grid_w = np.arange(1, dtype=np.float32)
|
||||
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
||||
grid = np.stack(grid, axis=0)
|
||||
|
||||
grid = grid.reshape([2, 1, grid_size, 1])
|
||||
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
||||
if cls_token and extra_tokens > 0:
|
||||
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
|
||||
return pos_embed
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
||||
assert embed_dim % 2 == 0
|
||||
|
||||
# use half of dimensions to encode grid_h
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
||||
|
||||
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
||||
return emb
|
||||
|
||||
|
||||
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
||||
"""
|
||||
embed_dim: output dimension for each position
|
||||
pos: a list of positions to be encoded: size (M,)
|
||||
out: (M, D)
|
||||
"""
|
||||
assert embed_dim % 2 == 0
|
||||
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
||||
omega /= embed_dim / 2.
|
||||
omega = 1. / 10000**omega # (D/2,)
|
||||
|
||||
pos = pos.reshape(-1) # (M,)
|
||||
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
|
||||
|
||||
emb_sin = np.sin(out) # (M, D/2)
|
||||
emb_cos = np.cos(out) # (M, D/2)
|
||||
|
||||
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
||||
return emb
|
||||
|
||||
def get_emb(sin_inp):
|
||||
"""
|
||||
Gets a base embedding for one dimension with sin and cos intertwined
|
||||
"""
|
||||
emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=-1)
|
||||
return torch.flatten(emb, -2, -1)
|
||||
|
||||
|
||||
class PositionalEncoding1D(nn.Module):
|
||||
def __init__(self, channels, dtype_override=None):
|
||||
"""
|
||||
:param channels: The last dimension of the tensor you want to apply pos emb to.
|
||||
:param dtype_override: If set, overrides the dtype of the output embedding.
|
||||
"""
|
||||
super(PositionalEncoding1D, self).__init__()
|
||||
self.org_channels = channels
|
||||
channels = int(np.ceil(channels / 2) * 2)
|
||||
inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels))
|
||||
self.register_buffer("inv_freq", inv_freq)
|
||||
self.register_buffer("cached_penc", None, persistent=False)
|
||||
self.channels = channels
|
||||
self.dtype_override = dtype_override
|
||||
|
||||
def forward(self, tensor):
|
||||
"""
|
||||
:param tensor: A 3d tensor of size (batch_size, ch, x)
|
||||
:return: Positional Encoding Matrix of size (batch_size, ch, x)
|
||||
"""
|
||||
if len(tensor.shape) != 3:
|
||||
raise RuntimeError("The input tensor has to be 3d!")
|
||||
|
||||
if self.cached_penc is not None and self.cached_penc.shape == tensor.shape:
|
||||
return self.cached_penc
|
||||
|
||||
self.cached_penc = None
|
||||
batch_size, orig_ch, x = tensor.shape
|
||||
pos_x = torch.arange(x, device=tensor.device, dtype=self.inv_freq.dtype)
|
||||
sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
|
||||
emb_x = get_emb(sin_inp_x)
|
||||
#print('emb_x', emb_x.shape)
|
||||
emb = torch.zeros(
|
||||
(self.channels, x),
|
||||
device=tensor.device,
|
||||
dtype=(
|
||||
self.dtype_override if self.dtype_override is not None else tensor.dtype
|
||||
),
|
||||
)
|
||||
emb[:self.channels, :] = emb_x.permute(1,0)
|
||||
|
||||
self.cached_penc = emb[None, :orig_ch, :].repeat(batch_size, 1, 1)
|
||||
return self.cached_penc
|
||||
|
||||
class TransformerModel(nn.Module):
|
||||
|
||||
def __init__(self, ntoken, ninp, nhead, nhid, nlayers=6, dropout=0.0):
|
||||
super(TransformerModel, self).__init__()
|
||||
|
||||
self.model_type = 'Transformer'
|
||||
self.pos_encoder = PositionalEncoding1D(ninp)
|
||||
encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout, batch_first=True)
|
||||
self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
|
||||
#self.encoder = nn.Embedding(ntoken, ninp)
|
||||
self.ninp = ninp
|
||||
#self.decoder = nn.Linear(ninp, ntoken)
|
||||
self.time_mlp = nn.Sequential( # should be removed for Noise and Deterministic Baselines
|
||||
SinusoidalPosEmb(ninp),
|
||||
nn.Linear(ninp-1, ninp * 4),
|
||||
nn.Mish(),
|
||||
nn.Linear(ninp * 4, ninp),
|
||||
)
|
||||
|
||||
#self.init_weights()
|
||||
|
||||
def generate_square_subsequent_mask(self, sz):
|
||||
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
|
||||
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
|
||||
return mask
|
||||
|
||||
'''def init_weights(self):
|
||||
initrange = 0.1
|
||||
self.encoder.weight.data.uniform_(-initrange, initrange)
|
||||
self.decoder.bias.data.zero_()
|
||||
self.decoder.weight.data.uniform_(-initrange, initrange)'''
|
||||
|
||||
def forward(self, src, t):
|
||||
#print('self.ninp', self.ninp)
|
||||
#src = self.encoder(src) * math.sqrt(self.ninp)
|
||||
#print('src', src.shape)
|
||||
|
||||
#t = self.time_mlp(t).unsqueeze(1)
|
||||
|
||||
emb = self.pos_encoder(src)
|
||||
#time = torch.cat((t,t,t), dim=1)
|
||||
#print('time', time.shape)
|
||||
output = self.transformer_encoder(src+emb)
|
||||
#print('shape after transformer', output.shape)
|
||||
#output = self.decoder(output)
|
||||
return output
|
||||
|
||||
|
||||
#################################################################################
|
||||
# DiT Configs #
|
||||
#################################################################################
|
||||
|
||||
def DiT_XL_2(**kwargs):
|
||||
return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
|
||||
|
||||
def DiT_XL_4(**kwargs):
|
||||
return DiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)
|
||||
|
||||
def DiT_XL_8(**kwargs):
|
||||
return DiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)
|
||||
|
||||
def DiT_L_2(**kwargs):
|
||||
return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
|
||||
|
||||
def DiT_L_4(**kwargs):
|
||||
return DiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs)
|
||||
|
||||
def DiT_L_8(**kwargs):
|
||||
return DiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs)
|
||||
|
||||
def DiT_B_2(**kwargs):
|
||||
return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
|
||||
|
||||
def DiT_B_4(**kwargs):
|
||||
return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)
|
||||
|
||||
def DiT_B_8(**kwargs):
|
||||
return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)
|
||||
|
||||
def DiT_S_2(**kwargs):
|
||||
return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
|
||||
|
||||
def DiT_S_4(**kwargs):
|
||||
return DiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
|
||||
|
||||
def DiT_S_8(**kwargs):
|
||||
return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
|
||||
|
||||
|
||||
DiT_models = {
|
||||
'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8,
|
||||
'DiT-L/2': DiT_L_2, 'DiT-L/4': DiT_L_4, 'DiT-L/8': DiT_L_8,
|
||||
'DiT-B/2': DiT_B_2, 'DiT-B/4': DiT_B_4, 'DiT-B/8': DiT_B_8,
|
||||
'DiT-S/2': DiT_S_2, 'DiT-S/4': DiT_S_4, 'DiT-S/8': DiT_S_8,
|
||||
}
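
# --- Illustrative usage sketch (added; not part of the original file) ---------
# A quick shape check for the smallest config. PatchEmbed above treats the input
# as a (H = input_size, W = 1) "image", so the layout below is (N, C, H, 1); the
# concrete numbers are assumptions for illustration only.
if __name__ == '__main__':
    model = DiT_S_2(input_size=32, in_channels=4)
    x = torch.randn(2, 4, 32, 1)        # (N, C, H, W=1)
    t = torch.randint(0, 200, (2,))     # diffusion timesteps
    out = model(x, t)
    print(out.shape)                    # torch.Size([2, 4, 32, 1])
# -------------------------------------------------------------------------------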
model/helpers.py  (new file, 346 lines)
@@ -0,0 +1,346 @@
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops.layers.torch import Rearrange
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
import os
|
||||
import numpy as np
|
||||
import logging
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------#
|
||||
# ---------------------------------- modules ----------------------------------#
|
||||
# -----------------------------------------------------------------------------#
|
||||
|
||||
def zero_module(module):
|
||||
"""
|
||||
Zero out the parameters of a module and return it.
|
||||
"""
|
||||
for p in module.parameters():
|
||||
p.detach().zero_()
|
||||
return module
|
||||
|
||||
|
||||
class SinusoidalPosEmb(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, x):
|
||||
device = x.device
|
||||
half_dim = self.dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
||||
emb = x[:, None] * emb[None, :]
|
||||
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
||||
return emb
|
||||
|
||||
|
||||
class Downsample1d(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
#self.conv = nn.Conv1d(dim, dim, 2, 1, 0)
|
||||
self.conv = nn.Conv1d(dim, dim, 1, 1, 0)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Upsample1d(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
#self.conv = nn.ConvTranspose1d(dim, dim, 2, 1, 0)
|
||||
self.conv = nn.ConvTranspose1d(dim, dim, 1, 1, 0)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class Conv1dBlock(nn.Module):
|
||||
"""
|
||||
Conv1d --> GroupNorm --> Mish
|
||||
"""
|
||||
|
||||
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=32, drop_out=0.0, if_zero=False):
|
||||
super().__init__()
|
||||
if drop_out > 0.0:
|
||||
self.block = nn.Sequential(
|
||||
zero_module(
|
||||
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=1),
|
||||
),
|
||||
Rearrange('batch channels horizon -> batch channels 1 horizon'),
|
||||
nn.GroupNorm(n_groups, out_channels),
|
||||
Rearrange('batch channels 1 horizon -> batch channels horizon'),
|
||||
nn.Mish(),
|
||||
nn.Dropout(p=drop_out),
|
||||
)
|
||||
elif if_zero:
|
||||
self.block = nn.Sequential(
|
||||
zero_module(
|
||||
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=1),
|
||||
),
|
||||
Rearrange('batch channels horizon -> batch channels 1 horizon'),
|
||||
nn.GroupNorm(n_groups, out_channels),
|
||||
Rearrange('batch channels 1 horizon -> batch channels horizon'),
|
||||
nn.Mish(),
|
||||
|
||||
)
|
||||
else:
|
||||
self.block = nn.Sequential(
|
||||
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=1),
|
||||
Rearrange('batch channels horizon -> batch channels 1 horizon'),
|
||||
nn.GroupNorm(n_groups, out_channels),
|
||||
Rearrange('batch channels 1 horizon -> batch channels horizon'),
|
||||
nn.Mish(),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------#
|
||||
# ---------------------------------- sampling ---------------------------------#
|
||||
# -----------------------------------------------------------------------------#
|
||||
|
||||
def extract(a, t, x_shape):
|
||||
b, *_ = t.shape
|
||||
out = a.gather(-1, t)
|
||||
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
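
# Note (added): `extract` gathers the per-sample coefficient a[t] along the last
# dimension and reshapes it to (B, 1, ..., 1) so it broadcasts against tensors of
# shape x_shape.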
|
||||
|
||||
|
||||
def cosine_beta_schedule(timesteps, s=0.008, dtype=torch.float32):
|
||||
"""
|
||||
cosine schedule
|
||||
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
|
||||
"""
|
||||
steps = timesteps + 1
|
||||
x = np.linspace(0, steps, steps)
|
||||
alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
|
||||
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
|
||||
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
|
||||
betas_clipped = np.clip(betas, a_min=0, a_max=0.999)
|
||||
return torch.tensor(betas_clipped, dtype=dtype)
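
# Note (added): for `timesteps` diffusion steps this returns a length-`timesteps`
# tensor of betas derived from the cosine alpha-bar schedule cited above, clipped
# to at most 0.999.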
|
||||
|
||||
|
||||
def condition_projection(x, conditions, action_dim, class_dim):
|
||||
for t, val in conditions.items():
|
||||
if t != 'task':
|
||||
x[:, t, class_dim + action_dim:] = val.clone()
|
||||
|
||||
x[:, 1:-1, class_dim + action_dim:] = 0.
|
||||
x[:, :, :class_dim] = conditions['task']
|
||||
|
||||
return x
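
# Note (added): condition_projection overwrites the observation slice (columns
# class_dim + action_dim onward) at every conditioned timestep, zeroes the
# observations of the intermediate steps, and writes the task condition into the
# first class_dim columns of every step; the sampling loops re-apply this
# projection after every denoising step.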
|
||||
|
||||
def condition_projection_noise(x, conditions, action_dim, class_dim):
|
||||
'''for t, val in conditions.items():
|
||||
if t != 'task':
|
||||
x[:, t, class_dim + action_dim:] = val.clone()'''
|
||||
|
||||
x[:, 1:-1, class_dim + action_dim:] = 0.
|
||||
x[:, :, :class_dim] = conditions['task']
|
||||
return x
|
||||
|
||||
'''def condition_projection_dit(x, conditions, action_dim, class_dim):
|
||||
for t, val in conditions.items():
|
||||
x[:, t, action_dim:] = val.clone()
|
||||
|
||||
x[:, 1:-1, action_dim:] = 0.
|
||||
return x'''
|
||||
|
||||
# for img tensors as img, 32*64
|
||||
def condition_projection_dit(x, conditions, action_dim, class_dim):
|
||||
for t, val in conditions.items():
|
||||
#print(t, x.shape, val.shape)
|
||||
x[:, t, :, :48] = val.clone()
|
||||
|
||||
x[:, 1:-1, :, :48] = 0.
|
||||
return x
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------#
|
||||
# ---------------------------------- Loss -------------------------------------#
|
||||
# -----------------------------------------------------------------------------#
|
||||
|
||||
def normal_kl(mean1, logvar1, mean2, logvar2):
|
||||
"""
|
||||
Compute the KL divergence between two gaussians.
|
||||
Shapes are automatically broadcasted, so batches can be compared to
|
||||
scalars, among other use cases.
|
||||
"""
|
||||
tensor = None
|
||||
for obj in (mean1, logvar1, mean2, logvar2):
|
||||
if isinstance(obj, torch.Tensor):
|
||||
tensor = obj
|
||||
break
|
||||
assert tensor is not None, "at least one argument must be a Tensor"
|
||||
|
||||
# Force variances to be Tensors. Broadcasting helps convert scalars to
|
||||
# Tensors, but it does not work for th.exp().
|
||||
logvar1, logvar2 = [
|
||||
x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
|
||||
for x in (logvar1, logvar2)
|
||||
]
|
||||
|
||||
return 0.5 * (
|
||||
-1.0
|
||||
+ logvar2
|
||||
- logvar1
|
||||
+ torch.exp(logvar1 - logvar2)
|
||||
+ ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
|
||||
)
|
||||
|
||||
class Weighted_MSE(nn.Module):

    def __init__(self, weights, action_dim, class_dim):
        super().__init__()
        # self.register_buffer('weights', weights)
        self.action_dim = action_dim
        self.class_dim = class_dim

    def forward(self, pred, targ):
        """
        :param pred: [B, T, task_dim+action_dim+observation_dim]
        :param targ: [B, T, task_dim+action_dim+observation_dim]
        :return: summed MSE, with the action channels of the first and last
            timestep up-weighted by a factor of 10
        """
        loss_action = F.mse_loss(pred, targ, reduction='none')
        loss_action[:, 0, self.class_dim:self.class_dim + self.action_dim] *= 10.
        loss_action[:, -1, self.class_dim:self.class_dim + self.action_dim] *= 10.
        loss_action = loss_action.sum()
        return loss_action


class Weighted_MSE_dit(nn.Module):

    def __init__(self, weights, action_dim, class_dim):
        super().__init__()
        # self.register_buffer('weights', weights)
        self.action_dim = action_dim
        self.class_dim = class_dim
        self.weight = torch.full((32, 16), 10).cuda()

    def forward(self, pred, targ):
        """
        :param pred: [B, T, task_dim+action_dim+observation_dim]
        :param targ: [B, T, task_dim+action_dim+observation_dim]
        :return: summed MSE over image-shaped features, with the [..., 48:] slice
            of the first and last timestep up-weighted by a factor of 10
        """
        loss_action = F.mse_loss(pred, targ, reduction='none')
        # print('loss_action', loss_action.shape)
        '''print('loss_action', loss_action.shape)
        loss_action[:, 0, 32, 48:] *= 10.
        loss_action[:, -1, 32, 48:] *= 10.
        loss_action = loss_action.sum()'''
        loss_action[:, 0, :, 48:] *= self.weight
        loss_action[:, -1, :, 48:] *= self.weight
        loss_action = loss_action.sum()
        return loss_action


Losses = {
    'Weighted_MSE': Weighted_MSE,
    # 'Weighted_MSE': Weighted_MSE_dit,
    'normal_kl': normal_kl,
}

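# Illustrative sketch (not part of the original file): the Losses registry is looked up
# by name and the class entries are instantiated with (weights, action_dim, class_dim).
# Note that 'normal_kl' maps to a plain function and does not follow that constructor
# signature. The dimensions and dummy tensors below are made up for the example.
def _demo_weighted_mse():
    action_dim, class_dim, observation_dim = 5, 3, 7
    loss_fn = Losses['Weighted_MSE'](None, action_dim, class_dim)
    pred = torch.randn(2, 4, class_dim + action_dim + observation_dim)
    targ = torch.randn(2, 4, class_dim + action_dim + observation_dim)
    loss = loss_fn(pred, targ)   # scalar: summed MSE with 10x weight on first/last actions
    return loss
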
# -----------------------------------------------------------------------------#
# -------------------------------- lr_schedule --------------------------------#
# -----------------------------------------------------------------------------#


def get_lr_schedule_with_warmup(args, optimizer, num_training_steps, last_epoch=-1):
    if args.dataset == 'crosstask':
        num_warmup_steps = num_training_steps * 20 / 120
        decay_steps = num_training_steps * 30 / 120

        def lr_lambda(current_step):
            if current_step <= num_warmup_steps:
                return max(0., float(current_step) / float(max(1, num_warmup_steps)))
            else:
                return max(0.5 ** ((current_step - num_warmup_steps) // decay_steps), 0.)

        return LambdaLR(optimizer, lr_lambda, last_epoch)

    if args.dataset == 'coin':
        num_warmup_steps = num_training_steps * 20 / 800  # total 160,000 steps (200 * 800); linear warmup for the first 20/800
        decay_steps = num_training_steps * 50 / 800  # total 160,000 steps (200 * 800); halve the lr every decay_steps after warmup

        def lr_lambda(current_step):
            if current_step <= num_warmup_steps:
                return max(0., float(current_step) / float(max(1, num_warmup_steps)))
            else:
                return max(0.5 ** ((current_step - num_warmup_steps) // decay_steps), 0.)

        return LambdaLR(optimizer, lr_lambda, last_epoch)

    if args.dataset == 'NIV':
        num_warmup_steps = num_training_steps * 90 / 130  # total 6,500 steps (50 * 130); linear warmup up to ~4,500 steps (90 / 130)
        decay_steps = num_training_steps * 120 / 130  # total 6,500 steps (50 * 130); decay_steps of ~6,000 (120 / 130)

        def lr_lambda(current_step):
            if current_step <= num_warmup_steps:
                return max(0., float(current_step) / float(max(1, num_warmup_steps)))
            else:
                return max(0.5 ** ((current_step - num_warmup_steps) // decay_steps), 0.)

        return LambdaLR(optimizer, lr_lambda, last_epoch)

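# Illustrative sketch (not part of the original file): wiring the schedule to an
# optimizer. `args` only needs a `.dataset` attribute here; the SimpleNamespace, the
# Adam setup, the linear layer and num_training_steps=120 are arbitrary example values,
# and `nn`/`torch`/`LambdaLR` are assumed to be this module's existing imports.
def _demo_lr_schedule():
    from types import SimpleNamespace
    args = SimpleNamespace(dataset='crosstask')
    model = nn.Linear(8, 8)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    scheduler = get_lr_schedule_with_warmup(args, optimizer, num_training_steps=120)
    for step in range(120):
        optimizer.step()
        scheduler.step()   # linear warmup for the first ~20 steps, then stepwise 0.5x decay
    return scheduler.get_last_lr()
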
# -----------------------------------------------------------------------------#
# ---------------------------------- logging ----------------------------------#
# -----------------------------------------------------------------------------#


# Taken from PyTorch's examples.imagenet.main
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

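# Illustrative sketch (not part of the original file): typical AverageMeter usage inside
# a training loop; the batch losses and batch sizes are invented values.
def _demo_average_meter():
    meter = AverageMeter()
    for batch_loss, batch_size in [(2.0, 32), (1.0, 32)]:
        meter.update(batch_loss, n=batch_size)
    return meter.avg   # running average weighted by batch size -> 1.5 here
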
class Logger:
    def __init__(self, log_dir, n_logged_samples=10, summary_writer=SummaryWriter, if_exist=False):
        self._log_dir = log_dir
        print('logging outputs to ', log_dir)
        self._n_logged_samples = n_logged_samples
        self._summ_writer = summary_writer(log_dir, flush_secs=120, max_queue=10)
        if not if_exist:
            log = logging.getLogger(log_dir)
            if not log.handlers:
                log.setLevel(logging.DEBUG)
                if not os.path.exists(log_dir):
                    os.mkdir(log_dir)
                fh = logging.FileHandler(os.path.join(log_dir, 'log.txt'))
                fh.setLevel(logging.INFO)
                formatter = logging.Formatter(fmt='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S')
                fh.setFormatter(formatter)
                log.addHandler(fh)
            self.log = log

    def log_scalar(self, scalar, name, step_):
        self._summ_writer.add_scalar('{}'.format(name), scalar, step_)

    def log_scalars(self, scalar_dict, group_name, step, phase):
        """Will log all scalars in the same plot."""
        self._summ_writer.add_scalars('{}_{}'.format(group_name, phase), scalar_dict, step)

    def flush(self):
        self._summ_writer.flush()

    def log_info(self, info):
        self.log.info("{}".format(info))
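
# Illustrative sketch (not part of the original file): minimal Logger usage. The log
# directory name is made up; the class writes TensorBoard events via SummaryWriter and,
# when if_exist=False, also mirrors messages to <log_dir>/log.txt.
def _demo_logger(log_dir='./logs/example_run'):
    logger = Logger(log_dir, if_exist=False)
    logger.log_info('starting demo run')
    logger.log_scalar(0.5, 'train/loss', step_=0)
    logger.flush()
    return logger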
324
model/temporal_act.py
Normal file

@@ -0,0 +1,324 @@
import torch
import torch.nn as nn
import einops
from einops.layers.torch import Rearrange
import math

from .helpers import (
    SinusoidalPosEmb,
    Downsample1d,
    Upsample1d,
    Conv1dBlock,
)


class GroupNorm32(nn.GroupNorm):
    def forward(self, x):
        return super().forward(x.float()).type(x.dtype)


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def normalization(channels):
    """
    Make a standard normalization layer.

    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    return GroupNorm32(32, channels)


class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.

    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(self, channels, num_heads=1, use_checkpoint=False):
        super().__init__()
        self.channels = channels
        self.num_heads = num_heads
        self.use_checkpoint = use_checkpoint

        self.norm = normalization(channels)
        # self.qkv = conv_nd(1, channels, channels * 3, 1)
        self.qkv = Conv1dBlock(channels, channels * 3, 2)

        self.attention = QKVAttention()
        # self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
        self.proj_out = zero_module(Conv1dBlock(channels, channels, 4))

    def forward(self, x):
        # print(x.shape)
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        qkv = qkv.reshape(b * self.num_heads, -1, qkv.shape[2])
        h = self.attention(qkv)
        # print(h.shape, qkv.shape)
        h = h.reshape(b, -1, h.shape[-1])
        h = self.proj_out(h)
        # print(x.shape, h.shape)
        return (x + h).reshape(b, c, *spatial)


class QKVAttention(nn.Module):
    """
    A module which performs QKV attention.
    """

    def forward(self, qkv):
        """
        Apply QKV attention.

        :param qkv: an [N x (C * 3) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x C x T] tensor after attention.
        """
        ch = qkv.shape[1] // 3
        q, k, v = torch.split(qkv, ch, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = torch.einsum(
            "bct,bcs->bts", q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        return torch.einsum("bts,bcs->bct", weight, v)

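# Illustrative sketch (not part of the original file): QKVAttention operates on a packed
# [N, 3*C, T] tensor and returns [N, C, T]; the sizes below are arbitrary examples.
def _demo_qkv_attention():
    attn = QKVAttention()
    qkv = torch.randn(2, 3 * 16, 6)   # batch of 2, 16 channels each for q/k/v, 6 positions
    out = attn(qkv)
    assert out.shape == (2, 16, 6)
    return out
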
class ResidualTemporalBlock(nn.Module):

    def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=3):
        super().__init__()

        self.blocks = nn.ModuleList([
            Conv1dBlock(inp_channels, out_channels, kernel_size),
            Conv1dBlock(out_channels, out_channels, kernel_size, if_zero=True)
        ])
        self.time_mlp = nn.Sequential(  # should be removed for Noise and Deterministic Baselines
            nn.Mish(),
            nn.Linear(embed_dim, out_channels),
            Rearrange('batch t -> batch t 1'),
        )
        self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1) \
            if inp_channels != out_channels else nn.Identity()

        self.dropout = nn.Dropout(0.5)

    def forward(self, x, t):
        out = self.blocks[0](x) + self.time_mlp(t)  # for diffusion
        # out = self.blocks[0](x)  # for Noise and Deterministic Baselines
        out = self.blocks[1](out)
        return out + self.residual_conv(self.dropout(x))

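# Illustrative sketch (not part of the original file): the residual block maps
# inp_channels -> out_channels and adds a time embedding broadcast over the horizon.
# This assumes Conv1dBlock (from .helpers) is length-preserving for kernel_size=3;
# all sizes below are arbitrary example values.
def _demo_residual_block():
    block = ResidualTemporalBlock(inp_channels=8, out_channels=16, embed_dim=32)
    x = torch.randn(4, 8, 6)    # [batch, channels, horizon]
    t = torch.randn(4, 32)      # time embedding, [batch, embed_dim]
    out = block(x, t)
    return out                  # expected shape under the assumption above: [4, 16, 6]
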
class TemporalUnet(nn.Module):
    def __init__(
            self,
            transition_dim,
            num_class,
            dim=32,
            dim_mults=(1, 2, 4, 8),
    ):
        super().__init__()

        dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        time_dim = dim
        self.time_mlp = nn.Sequential(  # should be removed for Noise and Deterministic Baselines
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim * 4),
            nn.Mish(),
            nn.Linear(dim * 4, dim),
        )

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        self.label_embed = nn.Embedding(num_class, time_dim)

        # print(in_out)
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(nn.ModuleList([
                ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim),
                AttentionBlock(dim_out, use_checkpoint=False, num_heads=4),
                ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim),
                Downsample1d(dim_out) if not is_last else nn.Identity()
            ]))

            '''self.downs.append(nn.ModuleList([
                ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim*2),
                ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim*2),
                Downsample1d(dim_out) if not is_last else nn.Identity()
            ]))'''

        mid_dim = dims[-1]
        self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim)
        self.attention = AttentionBlock(mid_dim, use_checkpoint=False, num_heads=16)
        self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim)
        '''self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim*2)
        self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim*2)'''

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(nn.ModuleList([
                ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim),
                AttentionBlock(dim_in, use_checkpoint=False, num_heads=4),
                ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim),
                Upsample1d(dim_in) if not is_last else nn.Identity()
            ]))
            '''self.ups.append(nn.ModuleList([
                ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim*2),
                ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim*2),
                Upsample1d(dim_in) if not is_last else nn.Identity()
            ]))'''

        self.final_conv = nn.Sequential(
            Conv1dBlock(dim, dim, kernel_size=3, if_zero=True),
            nn.Conv1d(dim, transition_dim, 1),
        )

    def forward(self, x, time, class_label):
        x = einops.rearrange(x, 'b h t -> b t h')

        # t = None  # for Noise and Deterministic Baselines
        t = self.time_mlp(time)  # for diffusion
        # print(x.shape, time.shape, t.shape, class_label.shape)
        # y_emb = self.label_embed(class_label)
        # print(t.shape, y_emb.shape)
        # t = t + y_emb
        # t = torch.cat((t, y_emb), 1)
        h = []

        for resnet, attn, resnet2, downsample in self.downs:
            x = resnet(x, t)
            x = attn(x)
            x = resnet2(x, t)
            h.append(x)
            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.attention(x)
        x = self.mid_block2(x, t)

        for resnet, attn, resnet2, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, t)
            x = attn(x)
            x = resnet2(x, t)
            x = upsample(x)

        x = self.final_conv(x)
        x = einops.rearrange(x, 'b t h -> b h t')
        return x

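# Illustrative sketch (not part of the original file): constructing the U-Net.
# transition_dim is the per-step feature size (class + action + observation channels)
# and num_class feeds the label embedding, which is currently unused in forward; the
# values below are arbitrary. forward takes x as [batch, horizon, transition_dim] plus
# the diffusion timestep and class label, and rearranges its output back to that layout.
def _demo_temporal_unet():
    net = TemporalUnet(transition_dim=15, num_class=18, dim=32, dim_mults=(1, 2, 4, 8))
    n_params = sum(p.numel() for p in net.parameters())
    return net, n_params
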
class TemporalUnetNoAttn(nn.Module):
    def __init__(
            self,
            transition_dim,
            num_class,
            dim=32,
            dim_mults=(1, 2, 4, 8),
    ):
        super().__init__()

        dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        time_dim = dim
        self.time_mlp = nn.Sequential(  # should be removed for Noise and Deterministic Baselines
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim * 4),
            nn.Mish(),
            nn.Linear(dim * 4, dim),
        )

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        self.label_embed = nn.Embedding(num_class, time_dim)

        # print(in_out)
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(nn.ModuleList([
                ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim),
                # AttentionBlock(dim_out, use_checkpoint=False, num_heads=4),
                ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim),
                Downsample1d(dim_out) if not is_last else nn.Identity()
            ]))

            '''self.downs.append(nn.ModuleList([
                ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim*2),
                ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim*2),
                Downsample1d(dim_out) if not is_last else nn.Identity()
            ]))'''

        mid_dim = dims[-1]
        self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim)
        # self.attention = AttentionBlock(mid_dim, use_checkpoint=False, num_heads=16)
        self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim)
        '''self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim*2)
        self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim*2)'''

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(nn.ModuleList([
                ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim),
                # AttentionBlock(dim_in, use_checkpoint=False, num_heads=4),
                ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim),
                Upsample1d(dim_in) if not is_last else nn.Identity()
            ]))
            '''self.ups.append(nn.ModuleList([
                ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim*2),
                ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim*2),
                Upsample1d(dim_in) if not is_last else nn.Identity()
            ]))'''

        self.final_conv = nn.Sequential(
            Conv1dBlock(dim, dim, kernel_size=3, if_zero=True),
            nn.Conv1d(dim, transition_dim, 1),
        )

    def forward(self, x, time, class_label):
        x = einops.rearrange(x, 'b h t -> b t h')

        # t = None  # for Noise and Deterministic Baselines
        t = self.time_mlp(time)  # for diffusion
        # print(x.shape, time.shape, t.shape, class_label.shape)
        # y_emb = self.label_embed(class_label)
        # print(t.shape, y_emb.shape)
        # t = t + y_emb
        # t = torch.cat((t, y_emb), 1)
        h = []

        for resnet, resnet2, downsample in self.downs:
            x = resnet(x, t)
            x = resnet2(x, t)
            h.append(x)
            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_block2(x, t)

        for resnet, resnet2, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, t)
            x = resnet2(x, t)
            x = upsample(x)

        x = self.final_conv(x)
        x = einops.rearrange(x, 'b t h -> b h t')
        return x