first commit

commit 8f8cf48929
Author: Lei Shi
Date: 2024-12-02 15:42:58 +01:00

2819 changed files with 33143 additions and 0 deletions

model/diffusion_act.py (new file, +212 lines)

@@ -0,0 +1,212 @@
import random
import numpy as np
import torch
from torch import nn
from .helpers import (
cosine_beta_schedule,
extract,
condition_projection,
Losses,
)
class GaussianDiffusion(nn.Module):
def __init__(self, model, horizon, observation_dim, action_dim, class_dim, n_timesteps=200,
loss_type='Weighted_MSE', clip_denoised=False, ddim_discr_method='uniform',
):
super().__init__()
self.horizon = horizon
self.observation_dim = observation_dim
self.action_dim = action_dim
self.class_dim = class_dim
self.model = model
betas = cosine_beta_schedule(n_timesteps)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])
self.n_timesteps = n_timesteps
self.clip_denoised = clip_denoised
self.eta = 0.0
self.random_ratio = 1.0
# ---------------------------ddim--------------------------------
ddim_timesteps = 10
if ddim_discr_method == 'uniform':
c = n_timesteps // ddim_timesteps
ddim_timestep_seq = np.asarray(list(range(0, n_timesteps, c)))
elif ddim_discr_method == 'quad':
ddim_timestep_seq = (
(np.linspace(0, np.sqrt(n_timesteps), ddim_timesteps)) ** 2
).astype(int)
else:
raise RuntimeError(f'unknown ddim_discr_method: {ddim_discr_method}')
self.ddim_timesteps = ddim_timesteps
self.ddim_timestep_seq = ddim_timestep_seq
# ----------------------------------------------------------------
self.register_buffer('betas', betas)
self.register_buffer('alphas_cumprod', alphas_cumprod)
self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
self.register_buffer('posterior_variance', posterior_variance)
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.register_buffer('posterior_log_variance_clipped',
torch.log(torch.clamp(posterior_variance, min=1e-20)))
self.register_buffer('posterior_mean_coef1',
betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
self.register_buffer('posterior_mean_coef2',
(1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
self.loss_type = loss_type
self.loss_fn = Losses[loss_type](None, self.action_dim, self.class_dim)
# ------------------------------------------ sampling ------------------------------------------#
def q_posterior(self, x_start, x_t, t):
posterior_mean = (
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, x, cond, task_label, t):
#x_recon = self.model(x, t)
x_recon = self.model(x, t, task_label)
if self.clip_denoised:
x_recon = x_recon.clamp(-1., 1.)  # clamp is not in-place; keep the returned tensor
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
x_start=x_recon, x_t=x, t=t)
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
return \
(extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) \
/ extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
@torch.no_grad()
def p_sample_ddim(self, x, cond, avg_mask, task_label, t, t_prev, if_prev=False, if_avg_mask=False):
b, *_, device = *x.shape, x.device
#x_recon = self.model(x, t) # without class condition
x_recon = self.model(x, t, task_label)
if self.clip_denoised:
x_recon = x_recon.clamp(-1., 1.)  # clamp is not in-place; keep the returned tensor
eps = self._predict_eps_from_xstart(x, t, x_recon)
alpha_bar = extract(self.alphas_cumprod, t, x.shape)
if if_prev:
alpha_bar_prev = extract(self.alphas_cumprod_prev, t_prev, x.shape)
else:
alpha_bar_prev = extract(self.alphas_cumprod, t_prev, x.shape)
sigma = (
self.eta
* torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
* torch.sqrt(1 - alpha_bar / alpha_bar_prev)
)
noise = torch.randn_like(x) * self.random_ratio
mean_pred = (
x_recon * torch.sqrt(alpha_bar_prev)
+ torch.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
)
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return mean_pred + nonzero_mask * sigma * noise
@torch.no_grad()
def p_sample(self, x, cond, task_label, t):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(x=x, cond=cond, task_label=task_label, t=t)
noise = torch.randn_like(x) * self.random_ratio
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def p_sample_loop(self, cond, avg_mask, task_label, if_jump, if_avg_mask):
device = self.betas.device
batch_size = len(cond[0])
horizon = self.horizon
shape = (batch_size, horizon, self.class_dim + self.action_dim + self.observation_dim)
x = torch.randn(shape, device=device) * self.random_ratio # xt for Noise and diffusion
x = condition_projection(x, cond, self.action_dim, self.class_dim)
if not if_jump:
for i in reversed(range(0, self.n_timesteps)):
timesteps = torch.full((batch_size,), i, device=device, dtype=torch.long)
x = self.p_sample(x, cond, task_label, timesteps)
x = condition_projection(x, cond, self.action_dim, self.class_dim)
else:
for i in reversed(range(0, self.ddim_timesteps)):
timesteps = torch.full((batch_size,), self.ddim_timestep_seq[i], device=device, dtype=torch.long)
if i == 0:
timesteps_prev = torch.full((batch_size,), 0, device=device, dtype=torch.long)
x = self.p_sample_ddim(x, cond, avg_mask, task_label, timesteps, timesteps_prev, if_prev=True, if_avg_mask=if_avg_mask)
else:
timesteps_prev = torch.full((batch_size,), self.ddim_timestep_seq[i-1], device=device, dtype=torch.long)
x = self.p_sample_ddim(x, cond, avg_mask, task_label, timesteps, timesteps_prev, if_avg_mask=if_avg_mask)
x = condition_projection(x, cond, self.action_dim, self.class_dim)
return x
# ------------------------------------------ training ------------------------------------------#
def q_sample(self, x_start, t, noise=None):
if noise is None:
noise = torch.randn_like(x_start) * self.random_ratio
sample = (
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
return sample
def p_losses(self, x_start, cond, t, act_emb_noise, task_label):
noise = act_emb_noise * self.random_ratio
#print('act noise shape', noise.shape)
#print('x_start shape', x_start.shape)
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) # for diffusion, should be removed for Noise and Deterministic
x_noisy = condition_projection(x_noisy, cond, self.action_dim, self.class_dim)
#x_recon = self.model(x_noisy, t) # without class condition
x_recon = self.model(x_noisy, t, task_label) # with class condition
#print('x_recon',x_recon.shape)
x_recon = condition_projection(x_recon, cond, self.action_dim, self.class_dim)
loss = self.loss_fn(x_recon, x_start)
return loss
def loss(self, x, cond, act_emb_noise, task_label):
batch_size = len(x) # for diffusion
t = torch.randint(0, self.n_timesteps, (batch_size,), device=x.device).long() # for diffusion
# t = None # for Noise and Deterministic
return self.p_losses(x, cond, t, act_emb_noise, task_label)
def forward(self, cond, avg_mask, task_label, if_jump=False, if_avg_mask=False):
return self.p_sample_loop(cond, avg_mask, task_label, if_jump, if_avg_mask)
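
A minimal usage sketch for the GaussianDiffusion class above. The toy denoiser, dimensions, and condition shapes are illustrative assumptions, not values from this repo (which wires in TemporalUnet as the model); it assumes the file is importable as model.diffusion_act:

import torch
from torch import nn
from model.diffusion_act import GaussianDiffusion

class ToyDenoiser(nn.Module):
    """Stand-in for TemporalUnet with the same (x, t, task_label) signature; task_label is ignored."""
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Linear(dim + 1, dim)
    def forward(self, x, t, task_label):
        # broadcast the timestep as one extra feature channel
        t_feat = t.float().view(-1, 1, 1).expand(x.shape[0], x.shape[1], 1)
        return self.net(torch.cat([x, t_feat], dim=-1))

class_dim, action_dim, obs_dim, horizon = 4, 6, 8, 3
feat = class_dim + action_dim + obs_dim
diff = GaussianDiffusion(ToyDenoiser(feat), horizon, obs_dim, action_dim, class_dim)
cond = {0: torch.randn(2, obs_dim),             # observed start observation
        horizon - 1: torch.randn(2, obs_dim),   # observed goal observation
        'task': torch.randn(2, 1, class_dim)}   # task encoding, broadcast over the horizon
x0 = torch.randn(2, horizon, feat)
loss = diff.loss(x0, cond, act_emb_noise=torch.randn_like(x0), task_label=None)
plan = diff(cond, avg_mask=None, task_label=None, if_jump=True)  # 10-step DDIM sampling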

model/diffusion_act_dist.py (new file, +222 lines)

@@ -0,0 +1,222 @@
import random
import numpy as np
import torch
from torch import nn
from .helpers import (
cosine_beta_schedule,
extract,
condition_projection,
Losses,
)
class GaussianDiffusion(nn.Module):
def __init__(self, model, horizon, observation_dim, action_dim, class_dim, act_mean, act_std, n_timesteps=200,
loss_type='Weighted_MSE', clip_denoised=False, ddim_discr_method='uniform',
):
super().__init__()
self.horizon = horizon
self.observation_dim = observation_dim
self.action_dim = action_dim
self.class_dim = class_dim
self.model = model
betas = cosine_beta_schedule(n_timesteps)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])
self.n_timesteps = n_timesteps
self.clip_denoised = clip_denoised
self.eta = 0.0
self.random_ratio = 1.0
self.act_mean = act_mean
self.act_std = act_std
# ---------------------------ddim--------------------------------
ddim_timesteps = 10
if ddim_discr_method == 'uniform':
c = n_timesteps // ddim_timesteps
ddim_timestep_seq = np.asarray(list(range(0, n_timesteps, c)))
elif ddim_discr_method == 'quad':
ddim_timestep_seq = (
(np.linspace(0, np.sqrt(n_timesteps), ddim_timesteps)) ** 2
).astype(int)
else:
raise RuntimeError(f'unknown ddim_discr_method: {ddim_discr_method}')
self.ddim_timesteps = ddim_timesteps
self.ddim_timestep_seq = ddim_timestep_seq
# ----------------------------------------------------------------
self.register_buffer('betas', betas)
self.register_buffer('alphas_cumprod', alphas_cumprod)
self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
self.register_buffer('posterior_variance', posterior_variance)
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.register_buffer('posterior_log_variance_clipped',
torch.log(torch.clamp(posterior_variance, min=1e-20)))
self.register_buffer('posterior_mean_coef1',
betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
self.register_buffer('posterior_mean_coef2',
(1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
self.loss_type = loss_type
self.loss_fn = Losses[loss_type](None, self.action_dim, self.class_dim)
# ------------------------------------------ sampling ------------------------------------------#
def q_posterior(self, x_start, x_t, t):
posterior_mean = (
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, x, cond, task_label, t):
#x_recon = self.model(x, t)
x_recon = self.model(x, t, task_label)
if self.clip_denoised:
x_recon = x_recon.clamp(-1., 1.)  # clamp is not in-place; keep the returned tensor
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
x_start=x_recon, x_t=x, t=t)
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
return \
(extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) \
/ extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
@torch.no_grad()
def p_sample_ddim(self, x, cond, avg_mask, task_label, t, t_prev, noise, if_prev=False, if_avg_mask=False):
b, *_, device = *x.shape, x.device
#x_recon = self.model(x, t) # without class condition
x_recon = self.model(x, t, task_label)
if self.clip_denoised:
x_recon = x_recon.clamp(-1., 1.)  # clamp is not in-place; keep the returned tensor
eps = self._predict_eps_from_xstart(x, t, x_recon)
alpha_bar = extract(self.alphas_cumprod, t, x.shape)
if if_prev:
alpha_bar_prev = extract(self.alphas_cumprod_prev, t_prev, x.shape)
else:
alpha_bar_prev = extract(self.alphas_cumprod, t_prev, x.shape)
sigma = (
self.eta
* torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
* torch.sqrt(1 - alpha_bar / alpha_bar_prev)
)
mean_pred = (
x_recon * torch.sqrt(alpha_bar_prev)
+ torch.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
)
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return mean_pred + nonzero_mask * sigma * noise
@torch.no_grad()
def p_sample(self, x, cond, task_label, t, noise):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(x=x, cond=cond, task_label=task_label, t=t)
#noise = torch.randn_like(x) * self.random_ratio
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def p_sample_loop(self, cond, avg_mask, task_label, if_jump, if_avg_mask):
device = self.betas.device
batch_size = len(cond[0])
horizon = self.horizon
shape = (batch_size, horizon, self.class_dim + self.action_dim + self.observation_dim)
mean = torch.zeros(shape, device=device)
std = torch.zeros(shape, device=device)
for i in range(shape[1]):
std[:,i,:] = std[:,i,:] + self.act_std[i]
mean[:,i,:] = mean[:,i,:] + self.act_mean[i]
x = torch.normal(mean, std)
noise = x
x = condition_projection(x, cond, self.action_dim, self.class_dim)
if not if_jump:
for i in reversed(range(0, self.n_timesteps)):
timesteps = torch.full((batch_size,), i, device=device, dtype=torch.long)
x = self.p_sample(x, cond, task_label, timesteps, noise)  # p_sample in this file takes the shared noise draw
x = condition_projection(x, cond, self.action_dim, self.class_dim)
else:
for i in reversed(range(0, self.ddim_timesteps)):
timesteps = torch.full((batch_size,), self.ddim_timestep_seq[i], device=device, dtype=torch.long)
if i == 0:
timesteps_prev = torch.full((batch_size,), 0, device=device, dtype=torch.long)
x = self.p_sample_ddim(x, cond, avg_mask, task_label, timesteps, timesteps_prev, noise, if_prev=True, if_avg_mask=if_avg_mask)
else:
timesteps_prev = torch.full((batch_size,), self.ddim_timestep_seq[i-1], device=device, dtype=torch.long)
x = self.p_sample_ddim(x, cond, avg_mask, task_label, timesteps, timesteps_prev, noise, if_avg_mask=if_avg_mask)
x = condition_projection(x, cond, self.action_dim, self.class_dim)
return x
# ------------------------------------------ training ------------------------------------------#
def q_sample(self, x_start, t, noise=None):
if noise is None:
noise = torch.randn_like(x_start) * self.random_ratio
sample = (
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
return sample
def p_losses(self, x_start, cond, t, act_emb_noise, task_label):
noise = act_emb_noise * self.random_ratio
#print('act noise shape', noise.shape)
#print('x_start shape', x_start.shape)
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) # for diffusion, should be removed for Noise and Deterministic
x_noisy = condition_projection(x_noisy, cond, self.action_dim, self.class_dim)
#x_recon = self.model(x_noisy, t) # without class condition
x_recon = self.model(x_noisy, t, task_label) # with class condition
#print('x_recon',x_recon.shape)
x_recon = condition_projection(x_recon, cond, self.action_dim, self.class_dim)
loss = self.loss_fn(x_recon, x_start)
return loss
def loss(self, x, cond, act_emb_noise, task_label):
batch_size = len(x)
t = torch.randint(0, self.n_timesteps, (batch_size,), device=x.device).long() # for diffusion
return self.p_losses(x, cond, t, act_emb_noise, task_label)
def forward(self, cond, avg_mask, task_label, if_jump=False, if_avg_mask=False):
return self.p_sample_loop(cond, avg_mask, task_label, if_jump, if_avg_mask)
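
The behavioural difference in this _dist variant: sampling starts from per-horizon-step action statistics (act_mean, act_std) instead of a standard normal, and that same draw is reused as the DDIM noise term. A standalone sketch of the initialisation, with made-up statistics:

import torch

horizon, feat_dim, batch = 3, 18, 2
act_mean = torch.tensor([0.0, 0.1, -0.1])   # assumed: one scalar per horizon step
act_std = torch.tensor([1.0, 0.8, 1.2])

mean = torch.zeros(batch, horizon, feat_dim)
std = torch.zeros(batch, horizon, feat_dim)
for i in range(horizon):
    mean[:, i, :] += act_mean[i]
    std[:, i, :] += act_std[i]
x = torch.normal(mean, std)   # x_T; also passed through to p_sample_ddim as `noise`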

model/diffusion_no_mask.py (new file, +208 lines)

@@ -0,0 +1,208 @@
import random
import numpy as np
import torch
from torch import nn
from .helpers import (
cosine_beta_schedule,
extract,
condition_projection,
Losses,
)
class GaussianDiffusion(nn.Module):
def __init__(self, model, horizon, observation_dim, action_dim, class_dim, n_timesteps=200,
loss_type='Weighted_MSE', clip_denoised=False, ddim_discr_method='uniform',
):
super().__init__()
self.horizon = horizon
self.observation_dim = observation_dim
self.action_dim = action_dim
self.class_dim = class_dim
self.model = model
betas = cosine_beta_schedule(n_timesteps)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])
self.n_timesteps = n_timesteps
self.clip_denoised = clip_denoised
self.eta = 0.0
self.random_ratio = 1.0
# ---------------------------ddim--------------------------------
ddim_timesteps = 10
if ddim_discr_method == 'uniform':
c = n_timesteps // ddim_timesteps
ddim_timestep_seq = np.asarray(list(range(0, n_timesteps, c)))
elif ddim_discr_method == 'quad':
ddim_timestep_seq = (
(np.linspace(0, np.sqrt(n_timesteps), ddim_timesteps)) ** 2
).astype(int)
else:
raise RuntimeError(f'unknown ddim_discr_method: {ddim_discr_method}')
self.ddim_timesteps = ddim_timesteps
self.ddim_timestep_seq = ddim_timestep_seq
# ----------------------------------------------------------------
self.register_buffer('betas', betas)
self.register_buffer('alphas_cumprod', alphas_cumprod)
self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
self.register_buffer('posterior_variance', posterior_variance)
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.register_buffer('posterior_log_variance_clipped',
torch.log(torch.clamp(posterior_variance, min=1e-20)))
self.register_buffer('posterior_mean_coef1',
betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
self.register_buffer('posterior_mean_coef2',
(1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
self.loss_type = loss_type
self.loss_fn = Losses[loss_type](None, self.action_dim, self.class_dim)
# ------------------------------------------ sampling ------------------------------------------#
def q_posterior(self, x_start, x_t, t):
posterior_mean = (
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, x, cond, t):
x_recon = self.model(x, t, 0)  # class label fixed to 0: this file is the unconditional variant
if self.clip_denoised:
x_recon = x_recon.clamp(-1., 1.)  # clamp is not in-place; keep the returned tensor
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
x_start=x_recon, x_t=x, t=t)
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
return \
(extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) \
/ extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
@torch.no_grad()
def p_sample_ddim(self, x, cond, t, t_prev, if_prev=False):
b, *_, device = *x.shape, x.device
x_recon = self.model(x, t, 0)  # class label fixed to 0 (unconditional)
if self.clip_denoised:
x_recon = x_recon.clamp(-1., 1.)  # clamp is not in-place; keep the returned tensor
eps = self._predict_eps_from_xstart(x, t, x_recon)
alpha_bar = extract(self.alphas_cumprod, t, x.shape)
if if_prev:
alpha_bar_prev = extract(self.alphas_cumprod_prev, t_prev, x.shape)
else:
alpha_bar_prev = extract(self.alphas_cumprod, t_prev, x.shape)
sigma = (
self.eta
* torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
* torch.sqrt(1 - alpha_bar / alpha_bar_prev)
)
noise = torch.randn_like(x) * self.random_ratio
mean_pred = (
x_recon * torch.sqrt(alpha_bar_prev)
+ torch.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
)
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return mean_pred + nonzero_mask * sigma * noise
@torch.no_grad()
def p_sample(self, x, cond, t):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(x=x, cond=cond, t=t)
noise = torch.randn_like(x) * self.random_ratio
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def p_sample_loop(self, cond, if_jump):
device = self.betas.device
batch_size = len(cond[0])
horizon = self.horizon
shape = (batch_size, horizon, self.class_dim + self.action_dim + self.observation_dim)
x = torch.randn(shape, device=device) * self.random_ratio # xt for Noise and diffusion
# x = torch.zeros(shape, device=device) # for Deterministic
x = condition_projection(x, cond, self.action_dim, self.class_dim)
if not if_jump:
for i in reversed(range(0, self.n_timesteps)):
timesteps = torch.full((batch_size,), i, device=device, dtype=torch.long)
x = self.p_sample(x, cond, timesteps)
x = condition_projection(x, cond, self.action_dim, self.class_dim)
else:
for i in reversed(range(0, self.ddim_timesteps)):
timesteps = torch.full((batch_size,), self.ddim_timestep_seq[i], device=device, dtype=torch.long)
if i == 0:
timesteps_prev = torch.full((batch_size,), 0, device=device, dtype=torch.long)
x = self.p_sample_ddim(x, cond, timesteps, timesteps_prev, True)
else:
timesteps_prev = torch.full((batch_size,), self.ddim_timestep_seq[i-1], device=device, dtype=torch.long)
x = self.p_sample_ddim(x, cond, timesteps, timesteps_prev)
x = condition_projection(x, cond, self.action_dim, self.class_dim)
return x
# ------------------------------------------ training ------------------------------------------#
def q_sample(self, x_start, t, noise=None):
if noise is None:
noise = torch.randn_like(x_start) * self.random_ratio
sample = (
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
return sample
def p_losses(self, x_start, cond, t):
noise = torch.randn_like(x_start) * self.random_ratio # for Noise and diffusion
# noise = torch.zeros_like(x_start) # for Deterministic
# x_noisy = noise # for Noise and Deterministic
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) # for diffusion, should be removed for Noise and Deterministic
x_noisy = condition_projection(x_noisy, cond, self.action_dim, self.class_dim)
x_recon = self.model(x_noisy, t, 0)
x_recon = condition_projection(x_recon, cond, self.action_dim, self.class_dim)
loss = self.loss_fn(x_recon, x_start)
return loss
def loss(self, x, cond):
batch_size = len(x) # for diffusion
t = torch.randint(0, self.n_timesteps, (batch_size,), device=x.device).long() # for diffusion
# t = None # for Noise and Deterministic
return self.p_losses(x, cond, t)
def forward(self, cond, if_jump=False):
return self.p_sample_loop(cond, if_jump)
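
All three diffusion files build the same DDIM timestep sequence; shown standalone below with the defaults used above (n_timesteps=200, 10 DDIM steps). 'uniform' strides evenly through the chain, 'quad' spaces steps quadratically:

import numpy as np

n_timesteps, ddim_timesteps = 200, 10
uniform_seq = np.asarray(list(range(0, n_timesteps, n_timesteps // ddim_timesteps)))
quad_seq = ((np.linspace(0, np.sqrt(n_timesteps), ddim_timesteps)) ** 2).astype(int)
print(uniform_seq)  # 0, 20, 40, ..., 180
print(quad_seq)     # quadratically spaced: dense near t=0, sparse near t=T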

model/dit.py (new file, +512 lines)

@@ -0,0 +1,512 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# GLIDE: https://github.com/openai/glide-text2im
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
# --------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np
import math
from timm.models.vision_transformer import Attention, Mlp # PatchEmbed
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from .helpers import SinusoidalPosEmb
class PatchEmbed(nn.Module):
""" 2D Image to Patch Embedding
"""
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
embed_dim=768,
norm_layer=None,
flatten=True,
bias=True,
):
super().__init__()
img_size = (img_size, 1)
patch_size = (patch_size, 1)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
B, C, H, W = x.shape
#_assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
#_assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
x = self.norm(x)
return x
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
#################################################################################
# Embedding Layers for Timesteps and Class Labels #
#################################################################################
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class LabelEmbedder(nn.Module):
"""
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
"""
def __init__(self, num_classes, hidden_size, dropout_prob):
super().__init__()
use_cfg_embedding = dropout_prob > 0
self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
self.num_classes = num_classes
self.dropout_prob = dropout_prob
def token_drop(self, labels, force_drop_ids=None):
"""
Drops labels to enable classifier-free guidance.
"""
if force_drop_ids is None:
drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
else:
drop_ids = force_drop_ids == 1
labels = torch.where(drop_ids, self.num_classes, labels)
return labels
def forward(self, labels, train, force_drop_ids=None):
use_dropout = self.dropout_prob > 0
if (train and use_dropout) or (force_drop_ids is not None):
labels = self.token_drop(labels, force_drop_ids)
embeddings = self.embedding_table(labels)
return embeddings
#################################################################################
# Core DiT Model #
#################################################################################
class DiTBlock(nn.Module):
"""
A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
"""
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
super().__init__()
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
return x
class FinalLayer(nn.Module):
"""
The final layer of DiT.
"""
def __init__(self, hidden_size, patch_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, patch_size * 1 * out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 2 * hidden_size, bias=True)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class DiT(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(
self,
input_size=32,
patch_size=2,
in_channels=4,
hidden_size=384,
depth=12,
num_heads=6,
mlp_ratio=4.0,
class_dropout_prob=0.1,
num_classes=1000,
learn_sigma=False,
):
super().__init__()
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels * 2 if learn_sigma else in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size)
#self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
num_patches = self.x_embedder.num_patches
# Will use fixed sin-cos embedding:
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
self.blocks = nn.ModuleList([
DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
])
self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
self.initialize_weights()
def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize (and freeze) pos_embed by sin-cos embedding:
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches))  # 1D grid: num_patches itself, not sqrt(num_patches) as in the 2D original
#print('pos_embed', pos_embed.shape, self.x_embedder.num_patches)
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
nn.init.constant_(self.x_embedder.proj.bias, 0)
# Initialize label embedding table:
#nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# Zero-out adaLN modulation layers in DiT blocks:
for block in self.blocks:
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
def unpatchify(self, x):
"""
x: (N, T, patch_size**2 * C)
imgs: (N, H, W, C)
"""
c = self.out_channels
p = self.x_embedder.patch_size[0]
'''h = w = int(x.shape[1] ** 0.5)
assert h * w == x.shape[1]'''
h = x.shape[1]
w = 1
#print(x.shape)
x = x.reshape(shape=(x.shape[0], h, w, p, 1, c))
x = torch.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * 1))
return imgs
def forward(self, x, t):
"""
Forward pass of DiT.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N,) tensor of class labels
"""
#print('x', x.shape, 'x_embedder', self.x_embedder(x).shape, 'pos_embed', self.pos_embed.shape)
x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
t = self.t_embedder(t) # (N, D)
#y = self.y_embedder(y, self.training) # (N, D)
#c = t + y # (N, D)
c = t
for block in self.blocks:
x = block(x, c) # (N, T, D)
x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
x = self.unpatchify(x) # (N, out_channels, H, W)
return x
def forward_with_cfg(self, x, t, y, cfg_scale):
"""
Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
half = x[: len(x) // 2]
combined = torch.cat([half, half], dim=0)
model_out = self.forward(combined, t)  # this variant's forward is unconditional; y is unused here
# For exact reproducibility reasons, we apply classifier-free guidance on only
# three channels by default. The standard approach to cfg applies it to all channels.
# This can be done by uncommenting the following line and commenting-out the line following that.
# eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
eps, rest = model_out[:, :3], model_out[:, 3:]
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
eps = torch.cat([half_eps, half_eps], dim=0)
return torch.cat([eps, rest], dim=1)
#################################################################################
# Sine/Cosine Positional Embedding Functions #
#################################################################################
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(1, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size, 1])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float64)
omega /= embed_dim / 2.
omega = 1. / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
def get_emb(sin_inp):
"""
Gets a base embedding for one dimension with sin and cos intertwined
"""
emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=-1)
return torch.flatten(emb, -2, -1)
class PositionalEncoding1D(nn.Module):
def __init__(self, channels, dtype_override=None):
"""
:param channels: The last dimension of the tensor you want to apply pos emb to.
:param dtype_override: If set, overrides the dtype of the output embedding.
"""
super(PositionalEncoding1D, self).__init__()
self.org_channels = channels
channels = int(np.ceil(channels / 2) * 2)
inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels))
self.register_buffer("inv_freq", inv_freq)
self.register_buffer("cached_penc", None, persistent=False)
self.channels = channels
self.dtype_override = dtype_override
def forward(self, tensor):
"""
:param tensor: A 3d tensor of size (batch_size, ch, x)
:return: Positional Encoding Matrix of size (batch_size, ch, x)
"""
if len(tensor.shape) != 3:
raise RuntimeError("The input tensor has to be 3d!")
if self.cached_penc is not None and self.cached_penc.shape == tensor.shape:
return self.cached_penc
self.cached_penc = None
batch_size, orig_ch, x = tensor.shape
pos_x = torch.arange(x, device=tensor.device, dtype=self.inv_freq.dtype)
sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
emb_x = get_emb(sin_inp_x)
#print('emb_x', emb_x.shape)
emb = torch.zeros(
(self.channels, x),
device=tensor.device,
dtype=(
self.dtype_override if self.dtype_override is not None else tensor.dtype
),
)
emb[:self.channels, :] = emb_x.permute(1,0)
self.cached_penc = emb[None, :orig_ch, :].repeat(batch_size, 1, 1)
return self.cached_penc
class TransformerModel(nn.Module):
def __init__(self, ntoken, ninp, nhead, nhid, nlayers=6, dropout=0.0):
super(TransformerModel, self).__init__()
self.model_type = 'Transformer'
self.pos_encoder = PositionalEncoding1D(ninp)
encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout, batch_first=True)
self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
#self.encoder = nn.Embedding(ntoken, ninp)
self.ninp = ninp
#self.decoder = nn.Linear(ninp, ntoken)
self.time_mlp = nn.Sequential( # should be removed for Noise and Deterministic Baselines
SinusoidalPosEmb(ninp),
nn.Linear(ninp-1, ninp * 4),
nn.Mish(),
nn.Linear(ninp * 4, ninp),
)
#self.init_weights()
def generate_square_subsequent_mask(self, sz):
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
return mask
'''def init_weights(self):
initrange = 0.1
self.encoder.weight.data.uniform_(-initrange, initrange)
self.decoder.bias.data.zero_()
self.decoder.weight.data.uniform_(-initrange, initrange)'''
def forward(self, src, t):
#print('self.ninp', self.ninp)
#src = self.encoder(src) * math.sqrt(self.ninp)
#print('src', src.shape)
#t = self.time_mlp(t).unsqueeze(1)
emb = self.pos_encoder(src)
#time = torch.cat((t,t,t), dim=1)
#print('time', time.shape)
output = self.transformer_encoder(src+emb)
#print('shape after transformer', output.shape)
#output = self.decoder(output)
return output
#################################################################################
# DiT Configs #
#################################################################################
def DiT_XL_2(**kwargs):
return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
def DiT_XL_4(**kwargs):
return DiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)
def DiT_XL_8(**kwargs):
return DiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)
def DiT_L_2(**kwargs):
return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
def DiT_L_4(**kwargs):
return DiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs)
def DiT_L_8(**kwargs):
return DiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs)
def DiT_B_2(**kwargs):
return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
def DiT_B_4(**kwargs):
return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)
def DiT_B_8(**kwargs):
return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)
def DiT_S_2(**kwargs):
return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
def DiT_S_4(**kwargs):
return DiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
def DiT_S_8(**kwargs):
return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
DiT_models = {
'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8,
'DiT-L/2': DiT_L_2, 'DiT-L/4': DiT_L_4, 'DiT-L/8': DiT_L_8,
'DiT-B/2': DiT_B_2, 'DiT-B/4': DiT_B_4, 'DiT-B/8': DiT_B_8,
'DiT-S/2': DiT_S_2, 'DiT-S/4': DiT_S_4, 'DiT-S/8': DiT_S_8,
}
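
A shape-check sketch for this DiT variant: PatchEmbed fixes the width to 1, so inputs are (N, C, H, 1) "sequence images", and forward takes only (x, t) because the label embedder is commented out. Sizes below follow the constructor defaults:

import torch

model = DiT_S_2()                  # input_size=32, patch_size=2, in_channels=4
x = torch.randn(2, 4, 32, 1)       # (N, C, H, W=1)
t = torch.randint(0, 200, (2,))    # diffusion timesteps
out = model(x, t)                  # (2, 4, 32, 1), via 16 patches of height 2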

model/helpers.py (new file, +346 lines)

@@ -0,0 +1,346 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange
from torch.optim.lr_scheduler import LambdaLR
import os
import numpy as np
import logging
from torch.utils.tensorboard import SummaryWriter
# -----------------------------------------------------------------------------#
# ---------------------------------- modules ----------------------------------#
# -----------------------------------------------------------------------------#
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
class Downsample1d(nn.Module):
def __init__(self, dim):
super().__init__()
#self.conv = nn.Conv1d(dim, dim, 2, 1, 0)
self.conv = nn.Conv1d(dim, dim, 1, 1, 0)
def forward(self, x):
return self.conv(x)
class Upsample1d(nn.Module):
def __init__(self, dim):
super().__init__()
#self.conv = nn.ConvTranspose1d(dim, dim, 2, 1, 0)
self.conv = nn.ConvTranspose1d(dim, dim, 1, 1, 0)
def forward(self, x):
return self.conv(x)
class Conv1dBlock(nn.Module):
"""
Conv1d --> GroupNorm --> Mish
"""
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=32, drop_out=0.0, if_zero=False):
super().__init__()
if drop_out > 0.0:
self.block = nn.Sequential(
zero_module(
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=1),
),
Rearrange('batch channels horizon -> batch channels 1 horizon'),
nn.GroupNorm(n_groups, out_channels),
Rearrange('batch channels 1 horizon -> batch channels horizon'),
nn.Mish(),
nn.Dropout(p=drop_out),
)
elif if_zero:
self.block = nn.Sequential(
zero_module(
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=1),
),
Rearrange('batch channels horizon -> batch channels 1 horizon'),
nn.GroupNorm(n_groups, out_channels),
Rearrange('batch channels 1 horizon -> batch channels horizon'),
nn.Mish(),
)
else:
self.block = nn.Sequential(
nn.Conv1d(inp_channels, out_channels, kernel_size, padding=1),
Rearrange('batch channels horizon -> batch channels 1 horizon'),
nn.GroupNorm(n_groups, out_channels),
Rearrange('batch channels 1 horizon -> batch channels horizon'),
nn.Mish(),
)
def forward(self, x):
return self.block(x)
# -----------------------------------------------------------------------------#
# ---------------------------------- sampling ---------------------------------#
# -----------------------------------------------------------------------------#
def extract(a, t, x_shape):
# gather the per-sample coefficient a[t] and reshape it to broadcast over x
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
def cosine_beta_schedule(timesteps, s=0.008, dtype=torch.float32):
"""
cosine schedule
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
"""
steps = timesteps + 1
x = np.linspace(0, steps, steps)
alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
betas_clipped = np.clip(betas, a_min=0, a_max=0.999)
return torch.tensor(betas_clipped, dtype=dtype)
def condition_projection(x, conditions, action_dim, class_dim):
# inpaint the observed start/goal observations and the task encoding into x;
# intermediate observations are zeroed out
for t, val in conditions.items():
if t != 'task':
x[:, t, class_dim + action_dim:] = val.clone()
x[:, 1:-1, class_dim + action_dim:] = 0.
x[:, :, :class_dim] = conditions['task']
return x
def condition_projection_noise(x, conditions, action_dim, class_dim):
'''for t, val in conditions.items():
if t != 'task':
x[:, t, class_dim + action_dim:] = val.clone()'''
x[:, 1:-1, class_dim + action_dim:] = 0.
x[:, :, :class_dim] = conditions['task']
return x
'''def condition_projection_dit(x, conditions, action_dim, class_dim):
for t, val in conditions.items():
x[:, t, action_dim:] = val.clone()
x[:, 1:-1, action_dim:] = 0.
return x'''
# for image-shaped tensors, e.g. 32x64
def condition_projection_dit(x, conditions, action_dim, class_dim):
for t, val in conditions.items():
#print(t, x.shape, val.shape)
x[:, t, :, :48] = val.clone()
x[:, 1:-1, :, :48] = 0.
return x
# -----------------------------------------------------------------------------#
# ---------------------------------- Loss -------------------------------------#
# -----------------------------------------------------------------------------#
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
Compute the KL divergence between two gaussians.
Shapes are automatically broadcasted, so batches can be compared to
scalars, among other use cases.
"""
tensor = None
for obj in (mean1, logvar1, mean2, logvar2):
if isinstance(obj, torch.Tensor):
tensor = obj
break
assert tensor is not None, "at least one argument must be a Tensor"
# Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for th.exp().
logvar1, logvar2 = [
x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
for x in (logvar1, logvar2)
]
return 0.5 * (
-1.0
+ logvar2
- logvar1
+ torch.exp(logvar1 - logvar2)
+ ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
)
class Weighted_MSE(nn.Module):
def __init__(self, weights, action_dim, class_dim):
super().__init__()
# self.register_buffer('weights', weights)
self.action_dim = action_dim
self.class_dim = class_dim
def forward(self, pred, targ):
"""
:param pred: [B, T, task_dim+action_dim+observation_dim]
:param targ: [B, T, task_dim+action_dim+observation_dim]
:return:
"""
loss_action = F.mse_loss(pred, targ, reduction='none')
loss_action[:, 0, self.class_dim:self.class_dim + self.action_dim] *= 10.
loss_action[:, -1, self.class_dim:self.class_dim + self.action_dim] *= 10.
loss_action = loss_action.sum()
return loss_action
class Weighted_MSE_dit(nn.Module):
def __init__(self, weights, action_dim, class_dim):
super().__init__()
# self.register_buffer('weights', weights)
self.action_dim = action_dim
self.class_dim = class_dim
self.weight = torch.full((32, 16), 10).cuda()
def forward(self, pred, targ):
"""
:param pred: [B, T, task_dim+action_dim+observation_dim]
:param targ: [B, T, task_dim+action_dim+observation_dim]
:return:
"""
loss_action = F.mse_loss(pred, targ, reduction='none')
#print('loss_action', loss_action.shape)
'''print('loss_action', loss_action.shape)
loss_action[:, 0, 32, 48:] *= 10.
loss_action[:, -1, 32, 48:] *= 10.
loss_action = loss_action.sum()'''
loss_action[:, 0, :, 48:] *= self.weight
loss_action[:, -1, :, 48:] *= self.weight
loss_action = loss_action.sum()
return loss_action
Losses = {
'Weighted_MSE': Weighted_MSE,
#'Weighted_MSE': Weighted_MSE_dit,
'normal_kl': normal_kl,
}
# -----------------------------------------------------------------------------#
# -------------------------------- lr_schedule --------------------------------#
# -----------------------------------------------------------------------------#
def get_lr_schedule_with_warmup(args, optimizer, num_training_steps, last_epoch=-1):
if args.dataset == 'crosstask':
num_warmup_steps = num_training_steps * 20 / 120
decay_steps = num_training_steps * 30 / 120
def lr_lambda(current_step):
if current_step <= num_warmup_steps:
return max(0., float(current_step) / float(max(1, num_warmup_steps)))
else:
return max(0.5 ** ((current_step - num_warmup_steps) // decay_steps), 0.)
return LambdaLR(optimizer, lr_lambda, last_epoch)
if args.dataset == 'coin':
num_warmup_steps = num_training_steps * 20 / 800 # total 160,000 steps, 200*800, up to 4500 steps, increase linearly
decay_steps = num_training_steps * 50 / 800 # total 160,000 steps, 200*800, decay at 6000 steps
def lr_lambda(current_step):
if current_step <= num_warmup_steps:
return max(0., float(current_step) / float(max(1, num_warmup_steps)))
else:
return max(0.5 ** ((current_step - num_warmup_steps) // decay_steps), 0.)
return LambdaLR(optimizer, lr_lambda, last_epoch)
if args.dataset == 'NIV':
num_warmup_steps = num_training_steps * 90 / 130 # total 6500 steps, 50*130, up to 4500 steps, increase linearly 90 / 130
decay_steps = num_training_steps * 120 / 130 # total 6500 steps, 50*130, decay at 6000 steps 120 / 130
def lr_lambda(current_step):
if current_step <= num_warmup_steps:
return max(0., float(current_step) / float(max(1, num_warmup_steps)))
else:
return max(0.5 ** ((current_step - num_warmup_steps) // decay_steps), 0.)
return LambdaLR(optimizer, lr_lambda, last_epoch)
# -----------------------------------------------------------------------------#
# ---------------------------------- logging ----------------------------------#
# -----------------------------------------------------------------------------#
# Taken from PyTorch's examples.imagenet.main
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
class Logger:
def __init__(self, log_dir, n_logged_samples=10, summary_writer=SummaryWriter, if_exist=False):
self._log_dir = log_dir
print('logging outputs to ', log_dir)
self._n_logged_samples = n_logged_samples
self._summ_writer = summary_writer(log_dir, flush_secs=120, max_queue=10)
if not if_exist:
log = logging.getLogger(log_dir)
if not log.handlers:
log.setLevel(logging.DEBUG)
if not os.path.exists(log_dir):
os.mkdir(log_dir)
fh = logging.FileHandler(os.path.join(log_dir, 'log.txt'))
fh.setLevel(logging.INFO)
formatter = logging.Formatter(fmt='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S')
fh.setFormatter(formatter)
log.addHandler(fh)
self.log = log
def log_scalar(self, scalar, name, step_):
self._summ_writer.add_scalar('{}'.format(name), scalar, step_)
def log_scalars(self, scalar_dict, group_name, step, phase):
"""Will log all scalars in the same plot."""
self._summ_writer.add_scalars('{}_{}'.format(group_name, phase), scalar_dict, step)
def flush(self):
self._summ_writer.flush()
def log_info(self, info):
self.log.info("{}".format(info))
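
How extract broadcasts schedule coefficients, using the helpers above: gather one scalar per batch element at its timestep, then reshape to (B, 1, 1) so it scales a (B, horizon, feat) trajectory elementwise. Toy shapes for illustration:

import torch

betas = cosine_beta_schedule(200)               # (200,)
alphas_cumprod = torch.cumprod(1. - betas, dim=0)
sqrt_ac = torch.sqrt(alphas_cumprod)

t = torch.tensor([0, 57, 199])                  # one timestep per batch element
x = torch.randn(3, 5, 18)                       # (B, horizon, feat)
coef = extract(sqrt_ac, t, x.shape)             # (3, 1, 1)
x_scaled = coef * x                             # broadcasts over horizon and feat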

model/temporal_act.py (new file, +324 lines)

@@ -0,0 +1,324 @@
import torch
import torch.nn as nn
import einops
from einops.layers.torch import Rearrange
import math
from .helpers import (
SinusoidalPosEmb,
Downsample1d,
Upsample1d,
Conv1dBlock,
)
class GroupNorm32(nn.GroupNorm):
def forward(self, x):
return super().forward(x.float()).type(x.dtype)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def normalization(channels):
"""
Make a standard normalization layer.
:param channels: number of input channels.
:return: an nn.Module for normalization.
"""
return GroupNorm32(32, channels)
class AttentionBlock(nn.Module):
"""
An attention block that allows spatial positions to attend to each other.
Originally ported from here, but adapted to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
"""
def __init__(self, channels, num_heads=1, use_checkpoint=False):
super().__init__()
self.channels = channels
self.num_heads = num_heads
self.use_checkpoint = use_checkpoint
self.norm = normalization(channels)
#self.qkv = conv_nd(1, channels, channels * 3, 1)
self.qkv = Conv1dBlock(channels, channels * 3, 2)
self.attention = QKVAttention()
#self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
self.proj_out = zero_module(Conv1dBlock(channels, channels, 4))
def forward(self, x):
#print(x.shape)
b, c, *spatial = x.shape
x = x.reshape(b, c, -1)
qkv = self.qkv(self.norm(x))
qkv = qkv.reshape(b * self.num_heads, -1, qkv.shape[2])
h = self.attention(qkv)
#print(h.shape, qkv.shape)
h = h.reshape(b, -1, h.shape[-1])
h = self.proj_out(h)
#print(x.shape, h.shape)
return (x + h).reshape(b, c, *spatial)
class QKVAttention(nn.Module):
"""
A module which performs QKV attention.
"""
def forward(self, qkv):
"""
Apply QKV attention.
:param qkv: an [N x (C * 3) x T] tensor of Qs, Ks, and Vs.
:return: an [N x C x T] tensor after attention.
"""
ch = qkv.shape[1] // 3
q, k, v = torch.split(qkv, ch, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = torch.einsum(
"bct,bcs->bts", q * scale, k * scale
) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
return torch.einsum("bts,bcs->bct", weight, v)
class ResidualTemporalBlock(nn.Module):
def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=3):
super().__init__()
self.blocks = nn.ModuleList([
Conv1dBlock(inp_channels, out_channels, kernel_size),
Conv1dBlock(out_channels, out_channels, kernel_size, if_zero=True)
])
self.time_mlp = nn.Sequential( # should be removed for Noise and Deterministic Baselines
nn.Mish(),
nn.Linear(embed_dim, out_channels),
Rearrange('batch t -> batch t 1'),
)
self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1) \
if inp_channels != out_channels else nn.Identity()
self.dropout = nn.Dropout(0.5)
def forward(self, x, t):
out = self.blocks[0](x) + self.time_mlp(t) # for diffusion
# out = self.blocks[0](x) # for Noise and Deterministic Baselines
out = self.blocks[1](out)
return out + self.residual_conv(self.dropout(x))
class TemporalUnet(nn.Module):
def __init__(
self,
transition_dim,
num_class,
dim=32,
dim_mults=(1, 2, 4, 8),
):
super().__init__()
dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
time_dim = dim
self.time_mlp = nn.Sequential( # should be removed for Noise and Deterministic Baselines
SinusoidalPosEmb(dim),
nn.Linear(dim, dim * 4),
nn.Mish(),
nn.Linear(dim * 4, dim),
)
self.downs = nn.ModuleList([])
self.ups = nn.ModuleList([])
num_resolutions = len(in_out)
self.label_embed = nn.Embedding(num_class, time_dim)
# print(in_out)
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (num_resolutions - 1)
self.downs.append(nn.ModuleList([
ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim),
AttentionBlock(dim_out, use_checkpoint=False, num_heads=4),
ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim),
Downsample1d(dim_out) if not is_last else nn.Identity()
]))
'''self.downs.append(nn.ModuleList([
ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim*2),
ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim*2),
Downsample1d(dim_out) if not is_last else nn.Identity()
]))'''
mid_dim = dims[-1]
self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim)
self.attention = AttentionBlock(mid_dim, use_checkpoint=False, num_heads=16)
self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim)
'''self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim*2)
self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim*2)'''
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (num_resolutions - 1)
self.ups.append(nn.ModuleList([
ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim),
AttentionBlock(dim_in, use_checkpoint=False, num_heads=4),
ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim),
Upsample1d(dim_in) if not is_last else nn.Identity()
]))
'''self.ups.append(nn.ModuleList([
ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim*2),
ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim*2),
Upsample1d(dim_in) if not is_last else nn.Identity()
]))'''
self.final_conv = nn.Sequential(
Conv1dBlock(dim, dim, kernel_size=3, if_zero=True),
nn.Conv1d(dim, transition_dim, 1),
)
def forward(self, x, time, class_label):
x = einops.rearrange(x, 'b h t -> b t h')
# t = None # for Noise and Deterministic Baselines
t = self.time_mlp(time) # for diffusion
#print(x.shape, time.shape, t.shape, class_label.shape)
#y_emb = self.label_embed(class_label)
#print(t.shape, y_emb.shape)
#t = t + y_emb
#t = torch.cat((t, y_emb), 1)
h = []
for resnet, attn, resnet2, downsample in self.downs:
x = resnet(x, t)
x = attn(x)
x = resnet2(x, t)
h.append(x)
x = downsample(x)
x = self.mid_block1(x, t)
x = self.attention(x)
x = self.mid_block2(x, t)
for resnet, attn, resnet2, upsample in self.ups:
x = torch.cat((x, h.pop()), dim=1)
x = resnet(x, t)
x = attn(x)
x = resnet2(x, t)
x = upsample(x)
x = self.final_conv(x)
x = einops.rearrange(x, 'b t h -> b h t')
return x
class TemporalUnetNoAttn(nn.Module):
def __init__(
self,
transition_dim,
num_class,
dim=32,
dim_mults=(1, 2, 4, 8),
):
super().__init__()
dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
time_dim = dim
self.time_mlp = nn.Sequential( # should be removed for Noise and Deterministic Baselines
SinusoidalPosEmb(dim),
nn.Linear(dim, dim * 4),
nn.Mish(),
nn.Linear(dim * 4, dim),
)
self.downs = nn.ModuleList([])
self.ups = nn.ModuleList([])
num_resolutions = len(in_out)
self.label_embed = nn.Embedding(num_class, time_dim)
# print(in_out)
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (num_resolutions - 1)
self.downs.append(nn.ModuleList([
ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim),
#AttentionBlock(dim_out, use_checkpoint=False, num_heads=4),
ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim),
Downsample1d(dim_out) if not is_last else nn.Identity()
]))
'''self.downs.append(nn.ModuleList([
ResidualTemporalBlock(dim_in, dim_out, embed_dim=time_dim*2),
ResidualTemporalBlock(dim_out, dim_out, embed_dim=time_dim*2),
Downsample1d(dim_out) if not is_last else nn.Identity()
]))'''
mid_dim = dims[-1]
self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim)
#self.attention = AttentionBlock(mid_dim, use_checkpoint=False, num_heads=16)
self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim)
'''self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim*2)
self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim*2)'''
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (num_resolutions - 1)
self.ups.append(nn.ModuleList([
ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim),
#AttentionBlock(dim_in, use_checkpoint=False, num_heads=4),
ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim),
Upsample1d(dim_in) if not is_last else nn.Identity()
]))
'''self.ups.append(nn.ModuleList([
ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim*2),
ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim*2),
Upsample1d(dim_in) if not is_last else nn.Identity()
]))'''
self.final_conv = nn.Sequential(
Conv1dBlock(dim, dim, kernel_size=3, if_zero=True),
nn.Conv1d(dim, transition_dim, 1),
)
def forward(self, x, time, class_label):
x = einops.rearrange(x, 'b h t -> b t h')
# t = None # for Noise and Deterministic Baselines
t = self.time_mlp(time) # for diffusion
#print(x.shape, time.shape, t.shape, class_label.shape)
#y_emb = self.label_embed(class_label)
#print(t.shape, y_emb.shape)
#t = t + y_emb
#t = torch.cat((t, y_emb), 1)
h = []
for resnet, resnet2, downsample in self.downs:
x = resnet(x, t)
x = resnet2(x, t)
h.append(x)
x = downsample(x)
x = self.mid_block1(x, t)
x = self.mid_block2(x, t)
for resnet, resnet2, upsample in self.ups:
x = torch.cat((x, h.pop()), dim=1)
x = resnet(x, t)
x = resnet2(x, t)
x = upsample(x)
x = self.final_conv(x)
x = einops.rearrange(x, 'b t h -> b h t')
return x
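
A minimal forward-pass sketch for TemporalUnet, with made-up planning dimensions (transition_dim = class_dim + action_dim + observation_dim). class_label is accepted but unused in this version, since the label-embedding path is commented out above:

import torch

net = TemporalUnet(transition_dim=263, num_class=18)
x = torch.randn(2, 3, 263)             # (batch, horizon, transition_dim)
time = torch.randint(0, 200, (2,))     # diffusion timesteps
out = net(x, time, class_label=None)   # (2, 3, 263)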