import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange
from torch.optim.lr_scheduler import LambdaLR
import os
import numpy as np
import logging
from torch.utils.tensorboard import SummaryWriter


# -----------------------------------------------------------------------------#
# ---------------------------------- modules ----------------------------------#
# -----------------------------------------------------------------------------#

def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


class Downsample1d(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # self.conv = nn.Conv1d(dim, dim, 2, 1, 0)
        # kernel 1 / stride 1: a pointwise conv that keeps the temporal resolution unchanged
        self.conv = nn.Conv1d(dim, dim, 1, 1, 0)

    def forward(self, x):
        return self.conv(x)


class Upsample1d(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # self.conv = nn.ConvTranspose1d(dim, dim, 2, 1, 0)
        # kernel 1 / stride 1: a pointwise transposed conv that keeps the temporal resolution unchanged
        self.conv = nn.ConvTranspose1d(dim, dim, 1, 1, 0)

    def forward(self, x):
        return self.conv(x)


class Conv1dBlock(nn.Module):
    """
    Conv1d --> GroupNorm --> Mish
    """

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=32, drop_out=0.0, if_zero=False):
        super().__init__()
        # Note: padding=1 preserves the horizon length only when kernel_size == 3.
        if drop_out > 0.0:
            self.block = nn.Sequential(
                zero_module(
                    nn.Conv1d(inp_channels, out_channels, kernel_size, padding=1),
                ),
                Rearrange('batch channels horizon -> batch channels 1 horizon'),
                nn.GroupNorm(n_groups, out_channels),
                Rearrange('batch channels 1 horizon -> batch channels horizon'),
                nn.Mish(),
                nn.Dropout(p=drop_out),
            )
        elif if_zero:
            self.block = nn.Sequential(
                zero_module(
                    nn.Conv1d(inp_channels, out_channels, kernel_size, padding=1),
                ),
                Rearrange('batch channels horizon -> batch channels 1 horizon'),
                nn.GroupNorm(n_groups, out_channels),
                Rearrange('batch channels 1 horizon -> batch channels horizon'),
                nn.Mish(),
            )
        else:
            self.block = nn.Sequential(
                nn.Conv1d(inp_channels, out_channels, kernel_size, padding=1),
                Rearrange('batch channels horizon -> batch channels 1 horizon'),
                nn.GroupNorm(n_groups, out_channels),
                Rearrange('batch channels 1 horizon -> batch channels horizon'),
                nn.Mish(),
            )

    def forward(self, x):
        return self.block(x)


# -----------------------------------------------------------------------------#
# ---------------------------------- sampling ---------------------------------#
# -----------------------------------------------------------------------------#

def extract(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


def cosine_beta_schedule(timesteps, s=0.008, dtype=torch.float32):
    """
    cosine schedule as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    x = np.linspace(0, steps, steps)
    alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    betas_clipped = np.clip(betas, a_min=0, a_max=0.999)
    return torch.tensor(betas_clipped, dtype=dtype)


def condition_projection(x, conditions, action_dim, class_dim):
    # Overwrite the observation channels at the conditioned timesteps, zero out
    # the intermediate observations, and re-impose the task one-hot everywhere.
    for t, val in conditions.items():
        if t != 'task':
            x[:, t, class_dim + action_dim:] = val.clone()
    x[:, 1:-1, class_dim + action_dim:] = 0.
    x[:, :, :class_dim] = conditions['task']
    return x
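
# ---- hedged usage sketch (illustration only; not part of the original file) --#
# How the helpers above are typically combined in a DDPM-style planner:
# `cosine_beta_schedule` yields the betas, `extract` broadcasts per-timestep
# coefficients over a batch, and `condition_projection` re-imposes the observed
# start/goal observations and the task one-hot after a (noising or denoising)
# step.  All shapes below (batch 4, horizon 3, class/action/obs dims 10/6/16)
# are arbitrary assumptions chosen for the example.
def _conditioning_usage_example(n_timesteps=200, class_dim=10, action_dim=6, observation_dim=16):
    betas = cosine_beta_schedule(n_timesteps)                 # (T,)
    alphas_cumprod = torch.cumprod(1. - betas, dim=0)         # (T,)

    x_start = torch.randn(4, 3, class_dim + action_dim + observation_dim)
    t = torch.randint(0, n_timesteps, (4,))                   # one diffusion step per sample
    noise = torch.randn_like(x_start)
    # Forward-diffused sample q(x_t | x_0) in the standard DDPM closed form.
    x_noisy = (extract(torch.sqrt(alphas_cumprod), t, x_start.shape) * x_start
               + extract(torch.sqrt(1. - alphas_cumprod), t, x_start.shape) * noise)

    # Observed start (index 0) / goal (index -1) observations plus task one-hot.
    conditions = {
        0: x_start[:, 0, class_dim + action_dim:],
        -1: x_start[:, -1, class_dim + action_dim:],
        'task': x_start[:, :, :class_dim],
    }
    return condition_projection(x_noisy, conditions, action_dim, class_dim)
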
def condition_projection_noise(x, conditions, action_dim, class_dim):
    '''for t, val in conditions.items():
        if t != 'task':
            x[:, t, class_dim + action_dim:] = val.clone()'''
    x[:, 1:-1, class_dim + action_dim:] = 0.
    x[:, :, :class_dim] = conditions['task']
    return x


'''def condition_projection_dit(x, conditions, action_dim, class_dim):
    for t, val in conditions.items():
        x[:, t, action_dim:] = val.clone()
    x[:, 1:-1, action_dim:] = 0.
    return x'''


# for img tensors as img, 32*64
def condition_projection_dit(x, conditions, action_dim, class_dim):
    for t, val in conditions.items():
        # print(t, x.shape, val.shape)
        x[:, t, :, :48] = val.clone()
    x[:, 1:-1, :, :48] = 0.
    return x


# -----------------------------------------------------------------------------#
# ---------------------------------- Loss -------------------------------------#
# -----------------------------------------------------------------------------#

def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, torch.Tensor):
            tensor = obj
            break
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for th.exp().
    logvar1, logvar2 = [
        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
        for x in (logvar1, logvar2)
    ]

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + torch.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    )


class Weighted_MSE(nn.Module):
    def __init__(self, weights, action_dim, class_dim):
        super().__init__()
        # self.register_buffer('weights', weights)
        self.action_dim = action_dim
        self.class_dim = class_dim

    def forward(self, pred, targ):
        """
        :param pred: [B, T, task_dim+action_dim+observation_dim]
        :param targ: [B, T, task_dim+action_dim+observation_dim]
        :return:
        """
        loss_action = F.mse_loss(pred, targ, reduction='none')
        # Up-weight the action logits of the first and last timesteps.
        loss_action[:, 0, self.class_dim:self.class_dim + self.action_dim] *= 10.
        loss_action[:, -1, self.class_dim:self.class_dim + self.action_dim] *= 10.
        loss_action = loss_action.sum()
        return loss_action


class Weighted_MSE_dit(nn.Module):
    def __init__(self, weights, action_dim, class_dim):
        super().__init__()
        # self.register_buffer('weights', weights)
        self.action_dim = action_dim
        self.class_dim = class_dim
        # hard-coded 10x boundary weight matching the [:, :, 48:] slice used below
        self.weight = torch.full((32, 16), 10).cuda()

    def forward(self, pred, targ):
        """
        :param pred: [B, T, task_dim+action_dim+observation_dim]
        :param targ: [B, T, task_dim+action_dim+observation_dim]
        :return:
        """
        loss_action = F.mse_loss(pred, targ, reduction='none')
        # print('loss_action', loss_action.shape)
        '''print('loss_action', loss_action.shape)
        loss_action[:, 0, 32, 48:] *= 10.
        loss_action[:, -1, 32, 48:] *= 10.
        loss_action = loss_action.sum()'''
        loss_action[:, 0, :, 48:] *= self.weight
        loss_action[:, -1, :, 48:] *= self.weight
        loss_action = loss_action.sum()
        return loss_action
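
# ---- hedged usage sketch (illustration only; not part of the original file) --#
# Weighted_MSE up-weights the action logits of the first and last timesteps by
# 10x and sums over all elements.  The dimensions used here (class/action/obs
# dims 10/6/16, horizon 3) are assumptions chosen for the example.
def _weighted_mse_usage_example(class_dim=10, action_dim=6, observation_dim=16):
    loss_fn = Weighted_MSE(weights=None, action_dim=action_dim, class_dim=class_dim)
    pred = torch.randn(4, 3, class_dim + action_dim + observation_dim)
    targ = torch.randn_like(pred)
    return loss_fn(pred, targ)   # scalar: summed, boundary-weighted MSE
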
Losses = {
    'Weighted_MSE': Weighted_MSE,
    # 'Weighted_MSE': Weighted_MSE_dit,
    'normal_kl': normal_kl,
}


# -----------------------------------------------------------------------------#
# -------------------------------- lr_schedule --------------------------------#
# -----------------------------------------------------------------------------#

def get_lr_schedule_with_warmup(args, optimizer, num_training_steps, last_epoch=-1):
    if args.dataset == 'crosstask':
        num_warmup_steps = num_training_steps * 20 / 120
        decay_steps = num_training_steps * 30 / 120

        def lr_lambda(current_step):
            if current_step <= num_warmup_steps:
                return max(0., float(current_step) / float(max(1, num_warmup_steps)))
            else:
                return max(0.5 ** ((current_step - num_warmup_steps) // decay_steps), 0.)

        return LambdaLR(optimizer, lr_lambda, last_epoch)

    if args.dataset == 'coin':
        num_warmup_steps = num_training_steps * 20 / 800  # total 160,000 steps (200*800): linear warmup
        decay_steps = num_training_steps * 50 / 800  # total 160,000 steps (200*800): step decay afterwards

        def lr_lambda(current_step):
            if current_step <= num_warmup_steps:
                return max(0., float(current_step) / float(max(1, num_warmup_steps)))
            else:
                return max(0.5 ** ((current_step - num_warmup_steps) // decay_steps), 0.)

        return LambdaLR(optimizer, lr_lambda, last_epoch)

    if args.dataset == 'NIV':
        num_warmup_steps = num_training_steps * 90 / 130  # total 6,500 steps (50*130): linear warmup up to ~4,500 steps
        decay_steps = num_training_steps * 120 / 130  # total 6,500 steps (50*130): decay from ~6,000 steps

        def lr_lambda(current_step):
            if current_step <= num_warmup_steps:
                return max(0., float(current_step) / float(max(1, num_warmup_steps)))
            else:
                return max(0.5 ** ((current_step - num_warmup_steps) // decay_steps), 0.)

        return LambdaLR(optimizer, lr_lambda, last_epoch)
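
# ---- hedged usage sketch (illustration only; not part of the original file) --#
# Attaching the warmup / step-decay schedule above to an optimizer.  The `args`
# namespace, the step count, and the learning rate are assumptions chosen for
# the example; in the real training script they come from the argument parser.
def _lr_schedule_usage_example(model, num_training_steps=24000):
    from argparse import Namespace
    args = Namespace(dataset='crosstask')          # hypothetical argument object
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)
    scheduler = get_lr_schedule_with_warmup(args, optimizer, num_training_steps)
    # Typical loop: call scheduler.step() once per optimizer step, e.g.
    # for step in range(num_training_steps):
    #     ...forward / backward...
    #     optimizer.step()
    #     scheduler.step()
    return scheduler
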
# -----------------------------------------------------------------------------#
# ---------------------------------- logging ----------------------------------#
# -----------------------------------------------------------------------------#

# Taken from PyTorch's examples.imagenet.main
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class Logger:
    def __init__(self, log_dir, n_logged_samples=10, summary_writer=SummaryWriter, if_exist=False):
        self._log_dir = log_dir
        print('logging outputs to ', log_dir)
        self._n_logged_samples = n_logged_samples
        self._summ_writer = summary_writer(log_dir, flush_secs=120, max_queue=10)
        if not if_exist:
            log = logging.getLogger(log_dir)
            if not log.handlers:
                log.setLevel(logging.DEBUG)
                if not os.path.exists(log_dir):
                    os.mkdir(log_dir)
                fh = logging.FileHandler(os.path.join(log_dir, 'log.txt'))
                fh.setLevel(logging.INFO)
                formatter = logging.Formatter(fmt='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S')
                fh.setFormatter(formatter)
                log.addHandler(fh)
            self.log = log

    def log_scalar(self, scalar, name, step_):
        self._summ_writer.add_scalar('{}'.format(name), scalar, step_)

    def log_scalars(self, scalar_dict, group_name, step, phase):
        """Will log all scalars in the same plot."""
        self._summ_writer.add_scalars('{}_{}'.format(group_name, phase), scalar_dict, step)

    def flush(self):
        self._summ_writer.flush()

    def log_info(self, info):
        self.log.info("{}".format(info))
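
# ---- hedged usage sketch (illustration only; not part of the original file) --#
# Typical pairing of AverageMeter and Logger during training.  The directory
# name, scalar tag, and dummy loss values are assumptions chosen for the example.
def _logging_usage_example(log_dir='example_logs'):
    logger = Logger(log_dir)
    loss_meter = AverageMeter()
    for step in range(10):
        loss = float(np.random.rand())            # stand-in for a training loss
        loss_meter.update(loss, n=1)
        logger.log_scalar(loss_meter.avg, 'train/loss', step)
    logger.log_info('final average loss: {:.4f}'.format(loss_meter.avg))
    logger.flush()
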