# ActionDiffusion_WACV2025/model/helpers.py

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange
from torch.optim.lr_scheduler import LambdaLR
import os
import numpy as np
import logging
from torch.utils.tensorboard import SummaryWriter
# -----------------------------------------------------------------------------#
# ---------------------------------- modules ----------------------------------#
# -----------------------------------------------------------------------------#
def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module

class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

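# Illustrative usage sketch (not part of the original file): a batch of
# integer diffusion timesteps is mapped to a [batch, dim] sinusoidal embedding.
def _example_sinusoidal_pos_emb():
    emb = SinusoidalPosEmb(dim=128)
    t = torch.randint(0, 200, (16,))   # 16 sampled diffusion timesteps
    out = emb(t)                       # -> torch.Size([16, 128])
    return out.shape
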
class Downsample1d(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # self.conv = nn.Conv1d(dim, dim, 2, 1, 0)
        self.conv = nn.Conv1d(dim, dim, 1, 1, 0)

    def forward(self, x):
        return self.conv(x)

class Upsample1d(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # self.conv = nn.ConvTranspose1d(dim, dim, 2, 1, 0)
        self.conv = nn.ConvTranspose1d(dim, dim, 1, 1, 0)

    def forward(self, x):
        return self.conv(x)

class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=32, drop_out=0.0, if_zero=False):
        super().__init__()
        if drop_out > 0.0:
            self.block = nn.Sequential(
                zero_module(
                    nn.Conv1d(inp_channels, out_channels, kernel_size, padding=1),
                ),
                Rearrange('batch channels horizon -> batch channels 1 horizon'),
                nn.GroupNorm(n_groups, out_channels),
                Rearrange('batch channels 1 horizon -> batch channels horizon'),
                nn.Mish(),
                nn.Dropout(p=drop_out),
            )
        elif if_zero:
            self.block = nn.Sequential(
                zero_module(
                    nn.Conv1d(inp_channels, out_channels, kernel_size, padding=1),
                ),
                Rearrange('batch channels horizon -> batch channels 1 horizon'),
                nn.GroupNorm(n_groups, out_channels),
                Rearrange('batch channels 1 horizon -> batch channels horizon'),
                nn.Mish(),
            )
        else:
            self.block = nn.Sequential(
                nn.Conv1d(inp_channels, out_channels, kernel_size, padding=1),
                Rearrange('batch channels horizon -> batch channels 1 horizon'),
                nn.GroupNorm(n_groups, out_channels),
                Rearrange('batch channels 1 horizon -> batch channels horizon'),
                nn.Mish(),
            )

    def forward(self, x):
        return self.block(x)

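# Illustrative usage sketch (not part of the original file): with kernel_size=3
# and padding=1 the horizon length is preserved, and out_channels must be
# divisible by the default n_groups=32 for GroupNorm.
def _example_conv1d_block():
    block = Conv1dBlock(inp_channels=64, out_channels=128, kernel_size=3)
    x = torch.randn(8, 64, 20)     # [batch, channels, horizon]
    y = block(x)                   # -> torch.Size([8, 128, 20])
    return y.shape
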
# -----------------------------------------------------------------------------#
# ---------------------------------- sampling ---------------------------------#
# -----------------------------------------------------------------------------#
def extract(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))

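# Illustrative usage sketch (not part of the original file): gather the
# per-timestep coefficient for each batch element and reshape it so it
# broadcasts against a [B, T, D] trajectory tensor.
def _example_extract():
    alphas_cumprod = torch.linspace(1.0, 0.01, 200)   # one value per diffusion step
    t = torch.tensor([0, 10, 199])                    # one timestep per batch element
    coef = extract(alphas_cumprod, t, x_shape=(3, 12, 256))
    return coef.shape                                 # -> torch.Size([3, 1, 1])
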
def cosine_beta_schedule(timesteps, s=0.008, dtype=torch.float32):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    x = np.linspace(0, steps, steps)
    alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    betas_clipped = np.clip(betas, a_min=0, a_max=0.999)
    return torch.tensor(betas_clipped, dtype=dtype)

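# Illustrative usage sketch (not part of the original file): one beta per
# diffusion step, clipped to [0, 0.999]; the cumulative alphas decay towards 0.
def _example_cosine_beta_schedule():
    betas = cosine_beta_schedule(timesteps=200)
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
    return betas.shape, float(alphas_cumprod[-1])   # torch.Size([200]), close to 0
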
def condition_projection(x, conditions, action_dim, class_dim):
    # overwrite the observation features at the conditioned timesteps, zero
    # them at all intermediate steps, and keep the task class features fixed
    for t, val in conditions.items():
        if t != 'task':
            x[:, t, class_dim + action_dim:] = val.clone()
    x[:, 1:-1, class_dim + action_dim:] = 0.
    x[:, :, :class_dim] = conditions['task']
    return x

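# Illustrative usage sketch (not part of the original file): x is a noisy
# trajectory of shape [B, T, class_dim + action_dim + observation_dim]; the
# observed start/end features and the task features are re-imposed after every
# denoising step. The dimensions below are placeholders, not dataset values.
def _example_condition_projection():
    B, T, class_dim, action_dim, obs_dim = 4, 6, 18, 105, 1536
    x = torch.randn(B, T, class_dim + action_dim + obs_dim)
    conditions = {
        0: torch.randn(B, obs_dim),          # observed start state
        T - 1: torch.randn(B, obs_dim),      # observed goal state
        'task': torch.randn(B, class_dim),   # task class features
    }
    return condition_projection(x, conditions, action_dim, class_dim).shape
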
def condition_projection_noise(x, conditions, action_dim, class_dim):
    '''for t, val in conditions.items():
        if t != 'task':
            x[:, t, class_dim + action_dim:] = val.clone()'''
    x[:, 1:-1, class_dim + action_dim:] = 0.
    x[:, :, :class_dim] = conditions['task']
    return x

'''def condition_projection_dit(x, conditions, action_dim, class_dim):
    for t, val in conditions.items():
        x[:, t, action_dim:] = val.clone()
    x[:, 1:-1, action_dim:] = 0.
    return x'''


# variant for image-shaped feature tensors (32*64)
def condition_projection_dit(x, conditions, action_dim, class_dim):
    for t, val in conditions.items():
        # print(t, x.shape, val.shape)
        x[:, t, :, :48] = val.clone()
    x[:, 1:-1, :, :48] = 0.
    return x

# -----------------------------------------------------------------------------#
# ---------------------------------- Loss -------------------------------------#
# -----------------------------------------------------------------------------#
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    Compute the KL divergence between two Gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, torch.Tensor):
            tensor = obj
            break
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for torch.exp().
    logvar1, logvar2 = [
        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
        for x in (logvar1, logvar2)
    ]

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + torch.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    )

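# Illustrative usage sketch (not part of the original file): the KL between
# N(0, 1) and itself is zero; the scalar arguments broadcast against tensors.
def _example_normal_kl():
    mean = torch.zeros(4, 8)
    logvar = torch.zeros(4, 8)
    kl = normal_kl(mean, logvar, 0.0, 0.0)
    return kl.abs().max()   # -> tensor(0.)
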
class Weighted_MSE(nn.Module):
    def __init__(self, weights, action_dim, class_dim):
        super().__init__()
        # self.register_buffer('weights', weights)
        self.action_dim = action_dim
        self.class_dim = class_dim

    def forward(self, pred, targ):
        """
        :param pred: [B, T, task_dim+action_dim+observation_dim]
        :param targ: [B, T, task_dim+action_dim+observation_dim]
        :return:
        """
        loss_action = F.mse_loss(pred, targ, reduction='none')
        # up-weight the action logits at the first and last planning steps
        loss_action[:, 0, self.class_dim:self.class_dim + self.action_dim] *= 10.
        loss_action[:, -1, self.class_dim:self.class_dim + self.action_dim] *= 10.
        loss_action = loss_action.sum()
        return loss_action

class Weighted_MSE_dit(nn.Module):
    def __init__(self, weights, action_dim, class_dim):
        super().__init__()
        # self.register_buffer('weights', weights)
        self.action_dim = action_dim
        self.class_dim = class_dim
        self.weight = torch.full((32, 16), 10).cuda()

    def forward(self, pred, targ):
        """
        :param pred: [B, T, task_dim+action_dim+observation_dim]
        :param targ: [B, T, task_dim+action_dim+observation_dim]
        :return:
        """
        loss_action = F.mse_loss(pred, targ, reduction='none')
        # print('loss_action', loss_action.shape)
        '''print('loss_action', loss_action.shape)
        loss_action[:, 0, 32, 48:] *= 10.
        loss_action[:, -1, 32, 48:] *= 10.
        loss_action = loss_action.sum()'''
        loss_action[:, 0, :, 48:] *= self.weight
        loss_action[:, -1, :, 48:] *= self.weight
        loss_action = loss_action.sum()
        return loss_action

Losses = {
    'Weighted_MSE': Weighted_MSE,
    # 'Weighted_MSE': Weighted_MSE_dit,
    'normal_kl': normal_kl,
}

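# Illustrative usage sketch (not part of the original file): the loss is looked
# up by name; 'weights' is unused by Weighted_MSE and passed as None here, and
# the dimensions below are placeholders, not dataset values.
def _example_losses_registry():
    loss_fn = Losses['Weighted_MSE'](weights=None, action_dim=105, class_dim=18)
    pred = torch.randn(4, 6, 18 + 105 + 1536)
    targ = torch.randn(4, 6, 18 + 105 + 1536)
    return loss_fn(pred, targ)   # scalar tensor (summed, boundary-weighted MSE)
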
# -----------------------------------------------------------------------------#
# -------------------------------- lr_schedule --------------------------------#
# -----------------------------------------------------------------------------#
def get_lr_schedule_with_warmup(args, optimizer, num_training_steps, last_epoch=-1):
    if args.dataset == 'crosstask':
        num_warmup_steps = num_training_steps * 20 / 120
        decay_steps = num_training_steps * 30 / 120

        def lr_lambda(current_step):
            if current_step <= num_warmup_steps:
                return max(0., float(current_step) / float(max(1, num_warmup_steps)))
            else:
                return max(0.5 ** ((current_step - num_warmup_steps) // decay_steps), 0.)

        return LambdaLR(optimizer, lr_lambda, last_epoch)

    if args.dataset == 'coin':
        # COIN: roughly 160,000 total steps (200 * 800); warm up linearly over the first 20/800 of training
        num_warmup_steps = num_training_steps * 20 / 800
        # then halve the LR every 50/800 of the total steps
        decay_steps = num_training_steps * 50 / 800

        def lr_lambda(current_step):
            if current_step <= num_warmup_steps:
                return max(0., float(current_step) / float(max(1, num_warmup_steps)))
            else:
                return max(0.5 ** ((current_step - num_warmup_steps) // decay_steps), 0.)

        return LambdaLR(optimizer, lr_lambda, last_epoch)

    if args.dataset == 'NIV':
        # NIV: roughly 6,500 total steps (50 * 130); warm up linearly over the first 90/130 of training (~4,500 steps)
        num_warmup_steps = num_training_steps * 90 / 130
        # then halve the LR every 120/130 of the total steps (~6,000 steps)
        decay_steps = num_training_steps * 120 / 130

        def lr_lambda(current_step):
            if current_step <= num_warmup_steps:
                return max(0., float(current_step) / float(max(1, num_warmup_steps)))
            else:
                return max(0.5 ** ((current_step - num_warmup_steps) // decay_steps), 0.)

        return LambdaLR(optimizer, lr_lambda, last_epoch)

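# Illustrative usage sketch (not part of the original file): 'args' only needs
# a 'dataset' attribute, so a hypothetical Namespace stands in for the real
# training arguments here.
def _example_lr_schedule():
    from argparse import Namespace
    model = nn.Linear(8, 8)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    scheduler = get_lr_schedule_with_warmup(Namespace(dataset='crosstask'), optimizer, num_training_steps=120)
    for _ in range(10):
        optimizer.step()
        scheduler.step()
    return scheduler.get_last_lr()   # LR half-way through the linear warm-up
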
# -----------------------------------------------------------------------------#
# ---------------------------------- logging ----------------------------------#
# -----------------------------------------------------------------------------#
# Taken from PyTorch's examples.imagenet.main
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

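# Illustrative usage sketch (not part of the original file): update() takes the
# batch mean and the batch size, and avg tracks the running mean.
def _example_average_meter():
    meter = AverageMeter()
    meter.update(0.5, n=4)   # batch of 4 with mean loss 0.5
    meter.update(1.0, n=1)   # batch of 1 with loss 1.0
    return meter.avg         # -> 0.6
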
class Logger:
    def __init__(self, log_dir, n_logged_samples=10, summary_writer=SummaryWriter, if_exist=False):
        self._log_dir = log_dir
        print('logging outputs to ', log_dir)
        self._n_logged_samples = n_logged_samples
        self._summ_writer = summary_writer(log_dir, flush_secs=120, max_queue=10)
        if not if_exist:
            log = logging.getLogger(log_dir)
            if not log.handlers:
                log.setLevel(logging.DEBUG)
                if not os.path.exists(log_dir):
                    os.mkdir(log_dir)
                fh = logging.FileHandler(os.path.join(log_dir, 'log.txt'))
                fh.setLevel(logging.INFO)
                formatter = logging.Formatter(fmt='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S')
                fh.setFormatter(formatter)
                log.addHandler(fh)
            self.log = log

    def log_scalar(self, scalar, name, step_):
        self._summ_writer.add_scalar('{}'.format(name), scalar, step_)

    def log_scalars(self, scalar_dict, group_name, step, phase):
        """Will log all scalars in the same plot."""
        self._summ_writer.add_scalars('{}_{}'.format(group_name, phase), scalar_dict, step)

    def flush(self):
        self._summ_writer.flush()

    def log_info(self, info):
        self.log.info("{}".format(info))

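# Illustrative usage sketch (not part of the original file): scalars go to
# TensorBoard and free-form messages to <log_dir>/log.txt. The log directory
# below is a placeholder path.
def _example_logger(log_dir='/tmp/actiondiffusion_example_logs'):
    logger = Logger(log_dir)
    logger.log_scalar(0.42, name='train/loss', step_=1)
    logger.log_info('epoch 1 finished')
    logger.flush()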