diff --git a/checkpoints/haheae/last.ckpt b/checkpoints/haheae/last.ckpt new file mode 100644 index 0000000..d0e2c51 Binary files /dev/null and b/checkpoints/haheae/last.ckpt differ diff --git a/choices.py b/choices.py new file mode 100644 index 0000000..1cd1616 --- /dev/null +++ b/choices.py @@ -0,0 +1,179 @@ +from enum import Enum +from torch import nn + + +class TrainMode(Enum): + # manipulate mode = training the classifier + manipulate = 'manipulate' + # default training mode! + diffusion = 'diffusion' + # default latent training mode! + # fitting a diffusion model to a given latent + latent_diffusion = 'latentdiffusion' + + def is_manipulate(self): + return self in [ + TrainMode.manipulate, + ] + + def is_diffusion(self): + return self in [ + TrainMode.diffusion, + TrainMode.latent_diffusion, + ] + + def is_autoenc(self): + # the network possibly does autoencoding + return self in [ + TrainMode.diffusion, + ] + + def is_latent_diffusion(self): + return self in [ + TrainMode.latent_diffusion, + ] + + def use_latent_net(self): + return self.is_latent_diffusion() + + def require_dataset_infer(self): + """ + whether training in this mode requires the latent variables to be available? + """ + # this will precalculate all the latents before hand + # and the dataset will be all the predicted latents + return self in [ + TrainMode.latent_diffusion, + TrainMode.manipulate, + ] + + +class ManipulateMode(Enum): + """ + how to train the classifier to manipulate + """ + # train on whole celeba attr dataset + celebahq_all = 'celebahq_all' + # celeba with D2C's crop + d2c_fewshot = 'd2cfewshot' + d2c_fewshot_allneg = 'd2cfewshotallneg' + + def is_celeba_attr(self): + return self in [ + ManipulateMode.d2c_fewshot, + ManipulateMode.d2c_fewshot_allneg, + ManipulateMode.celebahq_all, + ] + + def is_single_class(self): + return self in [ + ManipulateMode.d2c_fewshot, + ManipulateMode.d2c_fewshot_allneg, + ] + + def is_fewshot(self): + return self in [ + ManipulateMode.d2c_fewshot, + ManipulateMode.d2c_fewshot_allneg, + ] + + def is_fewshot_allneg(self): + return self in [ + ManipulateMode.d2c_fewshot_allneg, + ] + + +class ModelType(Enum): + """ + Kinds of the backbone models + """ + + # unconditional ddpm + ddpm = 'ddpm' + # autoencoding ddpm cannot do unconditional generation + autoencoder = 'autoencoder' + + def has_autoenc(self): + return self in [ + ModelType.autoencoder, + ] + + def can_sample(self): + return self in [ModelType.ddpm] + + +class ModelName(Enum): + """ + List of all supported model classes + """ + + beatgans_ddpm = 'beatgans_ddpm' + beatgans_autoenc = 'beatgans_autoenc' + + +class ModelMeanType(Enum): + """ + Which type of output the model predicts. + """ + + eps = 'eps' # the model predicts epsilon + + +class ModelVarType(Enum): + """ + What is used as the model's output variance. + + The LEARNED_RANGE option has been added to allow the model to predict + values between FIXED_SMALL and FIXED_LARGE, making its job easier. 
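+    Only fixed_small (the posterior variance) and fixed_large (beta_t) are defined below.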
+ """ + + # posterior beta_t + fixed_small = 'fixed_small' + # beta_t + fixed_large = 'fixed_large' + + +class LossType(Enum): + mse = 'mse' # use raw MSE loss (and KL when learning variances) + l1 = 'l1' + + +class GenerativeType(Enum): + """ + How's a sample generated + """ + + ddpm = 'ddpm' + ddim = 'ddim' + + +class OptimizerType(Enum): + adam = 'adam' + adamw = 'adamw' + + +class Activation(Enum): + none = 'none' + relu = 'relu' + lrelu = 'lrelu' + silu = 'silu' + tanh = 'tanh' + + def get_act(self): + if self == Activation.none: + return nn.Identity() + elif self == Activation.relu: + return nn.ReLU() + elif self == Activation.lrelu: + return nn.LeakyReLU(negative_slope=0.2) + elif self == Activation.silu: + return nn.SiLU() + elif self == Activation.tanh: + return nn.Tanh() + else: + raise NotImplementedError() + + +class ManipulateLossType(Enum): + bce = 'bce' + mse = 'mse' \ No newline at end of file diff --git a/config.py b/config.py new file mode 100644 index 0000000..4f32ec7 --- /dev/null +++ b/config.py @@ -0,0 +1,153 @@ +from model.blocks import * +from diffusion.resample import UniformSampler +from dataclasses import dataclass +from diffusion.diffusion import space_timesteps +from typing import Tuple +from config_base import BaseConfig +from diffusion import * +from diffusion.base import GenerativeType, LossType, ModelMeanType, ModelVarType, get_named_beta_schedule +from model import * +from choices import * +from preprocess import * +import os + +@dataclass +class TrainConfig(BaseConfig): + name: str = '' + base_dir: str = './checkpoints/' + logdir: str = f'{base_dir}{name}' + data_name: str = '' + data_val_name: str = '' + seq_len: int = 40 # for reconstruction + seq_len_future: int = 3 # for prediction + in_channels = 9 + fp16: bool = True + lr: float = 1e-4 + ema_decay: float = 0.9999 + seed: int = 0 # random seed + batch_size: int = 64 + accum_batches: int = 1 + batch_size_eval: int = 1024 + total_epochs: int = 1_000 + save_every_epochs: int = 10 + eval_every_epochs: int = 10 + train_mode: TrainMode = TrainMode.diffusion + T: int = 1000 + T_eval: int = 100 + diffusion_type: str = 'beatgans' + semantic_encoder_type: str = 'gcn' + net_beatgans_embed_channels: int = 128 + beatgans_gen_type: GenerativeType = GenerativeType.ddim + beatgans_loss_type: LossType = LossType.mse + hand_mse_factor = 1.0 + head_mse_factor = 1.0 + beatgans_model_mean_type: ModelMeanType = ModelMeanType.eps + beatgans_model_var_type: ModelVarType = ModelVarType.fixed_large + beatgans_rescale_timesteps: bool = False + beta_scheduler: str = 'linear' + net_ch: int = 64 + net_ch_mult: Tuple[int, ...]= (1, 2, 4) + net_enc_channel_mult: Tuple[int] = (1, 2, 4) + grad_clip: float = 1 + optimizer: OptimizerType = OptimizerType.adam + weight_decay: float = 0 + warmup: int = 0 + model_conf: ModelConfig = None + model_name: ModelName = ModelName.beatgans_autoenc + model_type: ModelType = None + + @property + def batch_size_effective(self): + return self.batch_size*self.accum_batches + + def _make_diffusion_conf(self, T=None): + if self.diffusion_type == 'beatgans': + # can use T < self.T for evaluation + # follows the guided-diffusion repo conventions + # t's are evenly spaced + if self.beatgans_gen_type == GenerativeType.ddpm: + section_counts = [T] + elif self.beatgans_gen_type == GenerativeType.ddim: + section_counts = f'ddim{T}' + else: + raise NotImplementedError() + + return SpacedDiffusionBeatGansConfig( + gen_type=self.beatgans_gen_type, + model_type=self.model_type, + 
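+                # the spaced sampler keeps only the steps in use_timesteps and re-derives
+                # its betas from them (see SpacedDiffusionBeatGans in diffusion/diffusion.py)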
betas=get_named_beta_schedule(self.beta_scheduler, T), + model_mean_type=self.beatgans_model_mean_type, + model_var_type=self.beatgans_model_var_type, + loss_type=self.beatgans_loss_type, + rescale_timesteps=self.beatgans_rescale_timesteps, + use_timesteps=space_timesteps(num_timesteps=T, section_counts=section_counts), + fp16=self.fp16, + ) + else: + raise NotImplementedError() + + @property + def model_out_channels(self): + return self.in_channels + + @property + def model_input_channels(self): + return self.in_channels + + def make_T_sampler(self): + return UniformSampler(self.T) + + def make_diffusion_conf(self): + return self._make_diffusion_conf(self.T) + + def make_eval_diffusion_conf(self): + return self._make_diffusion_conf(T=self.T_eval) + + def make_model_conf(self): + cls = BeatGANsAutoencConfig + if self.model_name == ModelName.beatgans_autoenc: + self.model_type = ModelType.autoencoder + else: + raise NotImplementedError() + + self.model_conf = cls( + semantic_encoder_type = self.semantic_encoder_type, + channel_mult=self.net_ch_mult, + seq_len = self.seq_len, + seq_len_future = self.seq_len_future, + embed_channels=self.net_beatgans_embed_channels, + enc_out_channels=self.net_beatgans_embed_channels, + enc_channel_mult=self.net_enc_channel_mult, + in_channels=self.model_input_channels, + model_channels=self.net_ch, + out_channels=self.model_out_channels, + ) + + return self.model_conf + +def egobody_autoenc(mode, encoder_type='gcn', hand_mse_factor=1.0, head_mse_factor=1.0, data_sample_rate=1, epoch=130,in_channels=9, seq_len=40): + conf = TrainConfig() + conf.seq_len = seq_len + conf.seq_len_future = 3 + conf.in_channels = in_channels + conf.net_beatgans_embed_channels = 128 + conf.net_ch = 64 + conf.net_ch_mult = (1, 1, 1) + conf.semantic_encoder_type = encoder_type + conf.hand_mse_factor = hand_mse_factor + conf.head_mse_factor = head_mse_factor + conf.net_enc_channel_mult = conf.net_ch_mult + conf.total_epochs = epoch + conf.save_every_epochs = 10 + conf.eval_every_epochs = 10 + conf.batch_size = 64 + conf.batch_size_eval = 1024*4 + + conf.data_dir = "/scratch/hu/pose_forecast/egobody_pose2gaze/" + conf.data_sample_rate = data_sample_rate + conf.name = 'egobody_autoenc' + conf.data_name = 'egobody' + + conf.mode = mode + conf.make_model_conf() + return conf \ No newline at end of file diff --git a/config_base.py b/config_base.py new file mode 100644 index 0000000..385f9ee --- /dev/null +++ b/config_base.py @@ -0,0 +1,72 @@ +import json +import os +from copy import deepcopy +from dataclasses import dataclass + + +@dataclass +class BaseConfig: + def clone(self): + return deepcopy(self) + + def inherit(self, another): + """inherit common keys from a given config""" + common_keys = set(self.__dict__.keys()) & set(another.__dict__.keys()) + for k in common_keys: + setattr(self, k, getattr(another, k)) + + def propagate(self): + """push down the configuration to all members""" + for k, v in self.__dict__.items(): + if isinstance(v, BaseConfig): + v.inherit(self) + v.propagate() + + def save(self, save_path): + """save config to json file""" + dirname = os.path.dirname(save_path) + if not os.path.exists(dirname): + os.makedirs(dirname) + conf = self.as_dict_jsonable() + with open(save_path, 'w') as f: + json.dump(conf, f) + + def load(self, load_path): + """load json config""" + with open(load_path) as f: + conf = json.load(f) + self.from_dict(conf) + + def from_dict(self, dict, strict=False): + for k, v in dict.items(): + if not hasattr(self, k): + if strict: + raise 
ValueError(f"loading extra '{k}'") + else: + print(f"loading extra '{k}'") + continue + if isinstance(self.__dict__[k], BaseConfig): + self.__dict__[k].from_dict(v) + else: + self.__dict__[k] = v + + def as_dict_jsonable(self): + conf = {} + for k, v in self.__dict__.items(): + if isinstance(v, BaseConfig): + conf[k] = v.as_dict_jsonable() + else: + if jsonable(v): + conf[k] = v + else: + # ignore not jsonable + pass + return conf + + +def jsonable(x): + try: + json.dumps(x) + return True + except TypeError: + return False diff --git a/diffusion/__init__.py b/diffusion/__init__.py new file mode 100644 index 0000000..4e0838c --- /dev/null +++ b/diffusion/__init__.py @@ -0,0 +1,6 @@ +from typing import Union + +from .diffusion import SpacedDiffusionBeatGans, SpacedDiffusionBeatGansConfig + +Sampler = Union[SpacedDiffusionBeatGans] +SamplerConfig = Union[SpacedDiffusionBeatGansConfig] diff --git a/diffusion/base.py b/diffusion/base.py new file mode 100644 index 0000000..1ef28af --- /dev/null +++ b/diffusion/base.py @@ -0,0 +1,1148 @@ +""" +This code started out as a PyTorch port of Ho et al's diffusion models: +https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py + +Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules. +""" + +from model.unet_autoenc import AutoencReturn +from config_base import BaseConfig +import enum +import math, pdb + +import numpy as np +import torch as th +from model import * +from model.nn import mean_flat +from typing import NamedTuple, Tuple +from choices import * +from torch.cuda.amp import autocast +import torch.nn.functional as F + +from dataclasses import dataclass + + +@dataclass +class GaussianDiffusionBeatGansConfig(BaseConfig): + gen_type: GenerativeType + betas: Tuple[float] + model_type: ModelType + model_mean_type: ModelMeanType + model_var_type: ModelVarType + loss_type: LossType + rescale_timesteps: bool + fp16: bool + train_pred_xstart_detach: bool = True + + def make_sampler(self): + return GaussianDiffusionBeatGans(self) + + +class GaussianDiffusionBeatGans: + """ + Utilities for training and sampling diffusion models. + + Ported directly from here, and then adapted over time to further experimentation. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 + + :param betas: a 1-D numpy array of betas for each diffusion timestep, + starting at T and going to 1. + :param model_mean_type: a ModelMeanType determining what the model outputs. + :param model_var_type: a ModelVarType determining how variance is output. + :param loss_type: a LossType determining the loss function to use. + :param rescale_timesteps: if True, pass floating point timesteps into the + model so that they are always scaled like in the + original paper (0 to 1000). + """ + def __init__(self, conf: GaussianDiffusionBeatGansConfig): + self.conf = conf + self.model_mean_type = conf.model_mean_type + self.model_var_type = conf.model_var_type + self.loss_type = conf.loss_type + self.rescale_timesteps = conf.rescale_timesteps + + # Use float64 for accuracy. 
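+        # conf.betas is the full 1-D schedule (length == num_timesteps) built by get_named_beta_schedule()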
+ betas = np.array(conf.betas, dtype=np.float64) + self.betas = betas + assert len(betas.shape) == 1, "betas must be 1-D" + assert (betas > 0).all() and (betas <= 1).all() + + self.num_timesteps = int(betas.shape[0]) + + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) + assert self.alphas_cumprod_prev.shape == (self.num_timesteps, ) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - + 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.posterior_variance = (betas * (1.0 - self.alphas_cumprod_prev) / + (1.0 - self.alphas_cumprod)) + # log calculation clipped because the posterior variance is 0 at the + # beginning of the diffusion chain. + self.posterior_log_variance_clipped = np.log( + np.append(self.posterior_variance[1], self.posterior_variance[1:])) + self.posterior_mean_coef1 = (betas * + np.sqrt(self.alphas_cumprod_prev) / + (1.0 - self.alphas_cumprod)) + self.posterior_mean_coef2 = ((1.0 - self.alphas_cumprod_prev) * + np.sqrt(alphas) / + (1.0 - self.alphas_cumprod)) + + def training_losses(self, + model: Model, + x_start: th.Tensor, + t: th.Tensor, + x_future: th.Tensor, + hand_mse_factor = 1.0, + head_mse_factor = 1.0, + model_kwargs=None, + noise: th.Tensor=None, + ): + """ + Compute training losses for a single timestep. + + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param t: a batch of timestep indices. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param noise: if specified, the specific Gaussian noise to try to remove. + :return: a dict with the key "loss" containing a tensor of shape [N]. + Some mean or variance settings may also have other keys. + """ + if model_kwargs is None: + model_kwargs = {} + if noise is None: ## run_ffhq256.py + noise = th.randn_like(x_start) ## randomly sample from a Gaussian distribution + x_t = self.q_sample(x_start, t, noise=noise) ## get x_t from x_0 + + terms = {'x_t': x_t} + + if self.loss_type in [ + LossType.mse, + LossType.l1, + ]: + with autocast(self.conf.fp16): + # x_t is static wrt. 
to the diffusion process + model_forward = model.forward(x=x_t,##.detach(), + t=self._scale_timesteps(t), + x_start=x_start,##.detach(), + **model_kwargs) + model_output = model_forward.pred ## predicted noise + pred_hand = model_forward.pred_hand ## predicted hand trajectory + pred_head = model_forward.pred_head ## predicted head orientation + in_channels = x_future.shape[1] + if in_channels == 9: # both hand and head + gt_hand = x_future[:, :6, :] + gt_head = x_future[:, 6:, :] + if in_channels == 6: # hand only + gt_hand = x_future[:, :, :] + if in_channels == 3: # head only + gt_head = x_future[:, :, :] + + _model_output = model_output + if self.conf.train_pred_xstart_detach: + _model_output = _model_output.detach() + # get the pred xstart + p_mean_var = self.p_mean_variance( ## Apply the model to get p(x_{t-1} | x_t), and a prediction of x_0 + model=DummyModel(pred=model_output), ##_model_output), ## return the same value + # gradient goes through x_t + x=x_t, + t=t, + clip_denoised=False) + terms['pred_xstart'] = p_mean_var['pred_xstart'] + # model_output = model(x_t, self._scale_timesteps(t), **model_kwargs) + target_types = { + ModelMeanType.eps: noise, + } + target = target_types[self.model_mean_type] ## Gaussian noise + assert model_output.shape == target.shape == x_start.shape + + ## Calculate loss + if self.loss_type == LossType.mse: + if self.model_mean_type == ModelMeanType.eps: + noise_mse = mean_flat((target - model_output)**2) + terms["mse"] = noise_mse + if in_channels == 9: # both hand and head + hand_mse = mean_flat((gt_hand - pred_hand)**2) + head_mse = mean_flat((gt_head - pred_head)**2) + terms["hand_mse"] = hand_mse + terms["head_mse"] = head_mse + if in_channels == 6: # hand only + hand_mse = mean_flat((gt_hand - pred_hand)**2) + terms["hand_mse"] = hand_mse + terms["head_mse"] = 0 + if in_channels == 3: # head only + head_mse = mean_flat((gt_head - pred_head)**2) + terms["head_mse"] = head_mse + terms["hand_mse"] = 0 + else: + raise NotImplementedError() + else: + raise NotImplementedError() + terms["loss"] = terms["mse"] + terms["hand_mse"]*hand_mse_factor + terms["head_mse"]*head_mse_factor + else: + raise NotImplementedError(self.loss_type) + + return terms + + def sample(self, + model: Model, + shape=None, + noise=None, + cond=None, + x_start=None, + clip_denoised=True, + model_kwargs=None, + progress=False): + """ + Args: + x_start: given for the autoencoder + """ + ## autoencoding.py: model_kwargs not None, ['cond'] torch.Size([1, 512]) + if model_kwargs is None: + model_kwargs = {} + if self.conf.model_type.has_autoenc(): + model_kwargs['x_start'] = x_start + model_kwargs['cond'] = cond + + if self.conf.gen_type == GenerativeType.ddpm: + return self.p_sample_loop(model, + shape=shape, + noise=noise, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs, + progress=progress) + elif self.conf.gen_type == GenerativeType.ddim: ## autoencoding.py + return self.ddim_sample_loop(model, + shape=shape, + noise=noise, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs, + progress=progress) + else: + raise NotImplementedError() + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. 
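+        In closed form: q(x_t | x_0) = N(sqrt(alphas_cumprod[t]) * x_0, (1 - alphas_cumprod[t]) * I).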
+ """ + mean = ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * + x_start) + variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, + x_start.shape) + log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, + t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data for a given number of diffusion steps. + + In other words, sample from q(x_t | x_0). + + :param x_start: the initial data batch. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :param noise: if specified, the split-out normal noise. + :return: A noisy version of x_start. + """ + if noise is None: + noise = th.randn_like(x_start) + assert noise.shape == x_start.shape + x1 = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) + # x1 = x1.to(device="cpu") + # x1 = x1.numpy() + x2 = _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, + t, x_start.shape) + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * + x_start + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, + t, x_start.shape) * noise) + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior: + + q(x_{t-1} | x_t, x_0) + + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * + x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * + x_t) + posterior_variance = _extract_into_tensor(self.posterior_variance, t, + x_t.shape) + posterior_log_variance_clipped = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape) + assert (posterior_mean.shape[0] == posterior_variance.shape[0] == + posterior_log_variance_clipped.shape[0] == x_start.shape[0]) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, + model: Model, + x, + t, + clip_denoised=True, + denoised_fn=None, + model_kwargs=None): + """ + Apply the model to get p(x_{t-1} | x_t), as well as a prediction of + the initial x, x_0. + + :param model: the model, which takes a signal and a batch of timesteps + as input. + :param x: the [N x C x ...] tensor at time t. + :param t: a 1-D Tensor of timesteps. + :param clip_denoised: if True, clip the denoised signal into [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. Applies before + clip_denoised. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict with the following keys: + - 'mean': the model mean output. + - 'variance': the model variance output. + - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. + """ + if model_kwargs is None: + model_kwargs = {} + + B, C = x.shape[:2] + assert t.shape == (B, ) + with autocast(self.conf.fp16): + model_forward = model.forward(x=x, + t=self._scale_timesteps(t), + **model_kwargs) + model_output = model_forward.pred + + if self.model_var_type in [ + ModelVarType.fixed_large, ModelVarType.fixed_small + ]: + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so + # to get a better decoder log likelihood. 
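+                # i.e. beta_t for t >= 1, with the t = 0 entry replaced by posterior_variance[1]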
+ ModelVarType.fixed_large: ( + np.append(self.posterior_variance[1], self.betas[1:]), + np.log( + np.append(self.posterior_variance[1], self.betas[1:])), + ), + ModelVarType.fixed_small: ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + model_variance = _extract_into_tensor(model_variance, t, x.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, + x.shape) + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clamp(-1, 1) + return F.tanh(x) + + if self.model_mean_type in [ + ModelMeanType.eps, + ]: + if self.model_mean_type == ModelMeanType.eps: + pred_xstart = process_xstart( + self._predict_xstart_from_eps(x_t=x, t=t, + eps=model_output)) + else: + raise NotImplementedError() + model_mean, _, _ = self.q_posterior_mean_variance( + x_start=pred_xstart, x_t=x, t=t) + else: + raise NotImplementedError(self.model_mean_type) + + assert (model_mean.shape == model_log_variance.shape == + pred_xstart.shape == x.shape) + + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + 'model_forward': model_forward, + } + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return (_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, + x_t.shape) * x_t - + _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, + x_t.shape) * eps) + + def _predict_xstart_from_xprev(self, x_t, t, xprev): + assert x_t.shape == xprev.shape + return ( # (xprev - coef2*x_t) / coef1 + _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) + * xprev - _extract_into_tensor( + self.posterior_mean_coef2 / self.posterior_mean_coef1, t, + x_t.shape) * x_t) + + def _predict_xstart_from_scaled_xstart(self, t, scaled_xstart): + return scaled_xstart * _extract_into_tensor( + self.sqrt_recip_alphas_cumprod, t, scaled_xstart.shape) + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return (_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, + x_t.shape) * x_t - + pred_xstart) / _extract_into_tensor( + self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _predict_eps_from_scaled_xstart(self, x_t, t, scaled_xstart): + """ + Args: + scaled_xstart: is supposed to be sqrt(alphacum) * x_0 + """ + # 1 / sqrt(1-alphabar) * (x_t - scaled xstart) + return (x_t - scaled_xstart) / _extract_into_tensor( + self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) + + def _scale_timesteps(self, t): + if self.rescale_timesteps: + # scale t to be maxed out at 1000 steps + return t.float() * (1000.0 / self.num_timesteps) + return t + + def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute the mean for the previous step, given a function cond_fn that + computes the gradient of a conditional log probability with respect to + x. In particular, cond_fn computes grad(log(p(y|x))), and we want to + condition on y. + + This uses the conditioning strategy from Sohl-Dickstein et al. (2015). + """ + gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs) + new_mean = (p_mean_var["mean"].float() + + p_mean_var["variance"] * gradient.float()) + return new_mean + + def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute what the p_mean_variance output would have been, should the + model's score function be conditioned by cond_fn. + + See condition_mean() for details on cond_fn. 
+ + Unlike condition_mean(), this instead uses the conditioning strategy + from Song et al (2020). + """ + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + + eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) + eps = eps - (1 - alpha_bar).sqrt() * cond_fn( + x, self._scale_timesteps(t), **model_kwargs) + + out = p_mean_var.copy() + out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) + out["mean"], _, _ = self.q_posterior_mean_variance( + x_start=out["pred_xstart"], x_t=x, t=t) + return out + + def p_sample( + self, + model: Model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + ): + """ + Sample x_{t-1} from the model at the given timestep. + + :param model: the model to sample from. + :param x: the current tensor at x_{t-1}. + :param t: the value of t, starting at 0 for the first diffusion step. + :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - 'sample': a random sample from the model. + - 'pred_xstart': a prediction of x_0. + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + noise = th.randn_like(x) + nonzero_mask = ((t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + if cond_fn is not None: + out["mean"] = self.condition_mean(cond_fn, + out, + x, + t, + model_kwargs=model_kwargs) + sample = out["mean"] + nonzero_mask * th.exp( + 0.5 * out["log_variance"]) * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def p_sample_loop( + self, + model: Model, + shape=None, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model. + + :param model: the model module. + :param shape: the shape of the samples, (N, C, H, W). + :param noise: if specified, the noise from the encoder to sample. + Should be of the same shape as `shape`. + :param clip_denoised: if True, clip x_start predictions to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param device: if specified, the device to create the samples on. + If not specified, use a model parameter's device. + :param progress: if True, show a tqdm progress bar. + :return: a non-differentiable batch of samples. 
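+        In this repo the samples are pose sequences, so shape is (N, in_channels, seq_len) rather than an image shape.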
+ """ + final = None + for sample in self.p_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + ): + final = sample + return final["sample"] + + def p_sample_loop_progressive( + self, + model: Model, + shape=None, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model and yield intermediate samples from + each timestep of diffusion. + + Arguments are the same as p_sample_loop(). + Returns a generator over dicts, where each dict is the return value of + p_sample(). + """ + if device is None: + device = next(model.parameters()).device + if noise is not None: + img = noise + else: + assert isinstance(shape, (tuple, list)) + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + # t = th.tensor([i] * shape[0], device=device) + t = th.tensor([i] * len(img), device=device) + with th.no_grad(): + out = self.p_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + ) + yield out + img = out["sample"] + + def ddim_sample( + self, + model: Model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t-1} from the model using DDIM. + + Same usage as p_sample(). + """ + out = self.p_mean_variance( ## Apply the model to get p(x_{t-1} | x_t), as well as a prediction of the initial x, x_0 + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, + out, + x, + t, + model_kwargs=model_kwargs) + + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, + x.shape) + sigma = (eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * + th.sqrt(1 - alpha_bar / alpha_bar_prev)) + # Equation 12. + noise = th.randn_like(x) + mean_pred = (out["pred_xstart"] * th.sqrt(alpha_bar_prev) + + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps) + nonzero_mask = ((t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + sample = mean_pred + nonzero_mask * sigma * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def ddim_reverse_sample( ## In Stochastic + self, + model: Model, + x, + t, + clip_denoised=True, + denoised_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t+1} from the model using DDIM reverse ODE. + """ + assert eta == 0.0, "Reverse ODE only for deterministic path" + out = self.p_mean_variance( + model, ## model.unet_autoenc.BeatGANsAutoencModel + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + ''' out (dict), all torch.Size([1, 3, 256, 256]): + - 'mean': the model mean output. + - 'variance': the model variance output. + - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. 
+ ''' + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = (_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) + * x - out["pred_xstart"]) / _extract_into_tensor( + self.sqrt_recipm1_alphas_cumprod, t, x.shape) + alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) ## alphas_cumprod_next = alpha_{t+1} + + # Equation 12. reversed (DDIM paper) (th.sqrt == torch.sqrt) + ## I.e., Equation 8 in DiffAE paper + mean_pred = (out["pred_xstart"] * th.sqrt(alpha_bar_next) + + th.sqrt(1 - alpha_bar_next) * eps) + ## mean_pred = x_{t+1} + ## alpha_bar_next = \sqrt(\alpha_{t+1}) + ## out["pred_xstart"] = f_\theta(x_t,t,z_{sem}) + ## eps = \epsilon_\theta(x_t,t,z_{sem}) + return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} + + def ddim_reverse_sample_loop( + self, + model: Model, + x, + clip_denoised=True, + denoised_fn=None, + model_kwargs=None, + eta=0.0, + device=None, + ): + if device is None: + device = next(model.parameters()).device + sample_t = [] + xstart_t = [] + T = [] + indices = list(range(self.num_timesteps)) ## [0,250) list + sample = x + for i in indices: ## For each timestep + t = th.tensor([i] * len(sample), device=device) + with th.no_grad(): + out = self.ddim_reverse_sample(model, + sample, + t=t, + clip_denoised=clip_denoised, ## True + denoised_fn=denoised_fn, ## None + model_kwargs=model_kwargs, ## dict_keys(['cond']), ['cond'].shape ([1, 512]) + eta=eta) ## 0 + ## out.keys() == dict_keys(['sample', 'pred_xstart']). + sample = out['sample'] ## x_{t-1}, torch.Size([1, 3, 256, 256]) + # [1, ..., T] + sample_t.append(sample) + # [0, ...., T-1] + xstart_t.append(out['pred_xstart']) ## predicted x_0, torch.Size([1, 3, 256, 256]) + # [0, ..., T-1] ready to use + T.append(t) + + return { + # xT (stochastic) + 'sample': sample, + # (1, ..., T) + 'sample_t': sample_t, + # xstart here is a bit different from sampling from T = T-1 to T = 0 + # may not be exact + 'xstart_t': xstart_t, + 'T': T, + } + + def ddim_sample_loop( + self, + model: Model, + shape=None, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Generate samples from the model using DDIM. + + Same usage as p_sample_loop(). + """ + final = None + for sample in self.ddim_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + eta=eta, + ): + final = sample + return final["sample"] + + def ddim_sample_loop_progressive( + self, + model: Model, + shape=None, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Use DDIM to sample from the model and yield intermediate samples from + each timestep of DDIM. + + Same usage as p_sample_loop_progressive(). + """ + if device is None: + device = next(model.parameters()).device + if noise is not None: ## xT + img = noise + else: + assert isinstance(shape, (tuple, list)) + img = th.randn(*shape, device=device) + + indices = list(range(self.num_timesteps))[::-1] ## list starting from self.num_timesteps-1 to 0 + + if progress: ## autoencoding.py doesn't go here + # Lazy import so that we don't depend on tqdm. 
+ from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + + if isinstance(model_kwargs, list): + # index dependent model kwargs + # (T-1, ..., 0) + _kwargs = model_kwargs[i] + else: ## autoencoding.py, dict + _kwargs = model_kwargs + + t = th.tensor([i] * len(img), device=device) ## get each timestep (T-1 --> 0) + with th.no_grad(): + out = self.ddim_sample( + model, ## model.unet_autoenc.BeatGANsAutoencModel + img, + t, + clip_denoised=clip_denoised, ## True + denoised_fn=denoised_fn, ## None + cond_fn=cond_fn, ## None + model_kwargs=_kwargs, + eta=eta, ## 0.0 + ) + ## out: {"sample": sample, "pred_xstart": out["pred_xstart"]} + out['t'] = t + yield out + img = out["sample"] + + def _vb_terms_bpd(self, + model: Model, + x_start, + x_t, + t, + clip_denoised=True, + model_kwargs=None): + """ + Get a term for the variational lower-bound. + + The resulting units are bits (rather than nats, as one might expect). + This allows for comparison to other papers. + + :return: a dict with the following keys: + - 'output': a shape [N] tensor of NLLs or KLs. + - 'pred_xstart': the x_0 predictions. + """ + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t) + out = self.p_mean_variance(model, + x_t, + t, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs) + kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"], + out["log_variance"]) + kl = mean_flat(kl) / np.log(2.0) + + decoder_nll = -discretized_gaussian_log_likelihood( + x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]) + assert decoder_nll.shape == x_start.shape + decoder_nll = mean_flat(decoder_nll) / np.log(2.0) + + # At the first timestep return the decoder NLL, + # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) + output = th.where((t == 0), decoder_nll, kl) + return { + "output": output, + "pred_xstart": out["pred_xstart"], + 'model_forward': out['model_forward'], + } + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + + This term can't be optimized, as it only depends on the encoder. + + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = th.tensor([self.num_timesteps - 1] * batch_size, + device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl(mean1=qt_mean, + logvar1=qt_log_variance, + mean2=0.0, + logvar2=0.0) + return mean_flat(kl_prior) / np.log(2.0) + + def calc_bpd_loop(self, + model: Model, + x_start, + clip_denoised=True, + model_kwargs=None): + """ + Compute the entire variational lower-bound, measured in bits-per-dim, + as well as other related quantities. + + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param clip_denoised: if True, clip denoised samples. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + + :return: a dict containing the following keys: + - total_bpd: the total variational lower-bound, per batch element. + - prior_bpd: the prior term in the lower-bound. + - vb: an [N x T] tensor of terms in the lower-bound. + - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. + - mse: an [N x T] tensor of epsilon MSEs for each timestep. 
+ """ + device = x_start.device + batch_size = x_start.shape[0] + + vb = [] + xstart_mse = [] + mse = [] + for t in list(range(self.num_timesteps))[::-1]: + t_batch = th.tensor([t] * batch_size, device=device) + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) + # Calculate VLB term at the current timestep + with th.no_grad(): + out = self._vb_terms_bpd( + model, + x_start=x_start, + x_t=x_t, + t=t_batch, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs, + ) + vb.append(out["output"]) + xstart_mse.append(mean_flat((out["pred_xstart"] - x_start)**2)) + eps = self._predict_eps_from_xstart(x_t, t_batch, + out["pred_xstart"]) + mse.append(mean_flat((eps - noise)**2)) + + vb = th.stack(vb, dim=1) + xstart_mse = th.stack(xstart_mse, dim=1) + mse = th.stack(mse, dim=1) + + prior_bpd = self._prior_bpd(x_start) + total_bpd = vb.sum(dim=1) + prior_bpd + return { + "total_bpd": total_bpd, + "prior_bpd": prior_bpd, + "vb": vb, + "xstart_mse": xstart_mse, + "mse": mse, + } + +def atanh(x): + return 0.5 * th.log((1 + x) / (1 - x)) + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. + """ + ## E.g., broadcast_shape==torch.Size([1, 3, 256, 256]) + res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() ## get arr[timesteps] + while len(res.shape) < len(broadcast_shape): ## torch.Size([1, 1, 1, 1]) + res = res[..., None] + return res.expand(broadcast_shape) ## torch.Size([1, 3, 256, 256]), all elements are arr[timesteps] + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): + """ + Get a pre-defined beta schedule for the given name. + + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name == "linear": + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. 
+ scale = 1000 / num_diffusion_timesteps + beta_start = scale * 0.0001 + beta_end = scale * 0.02 + return np.linspace(beta_start, + beta_end, + num_diffusion_timesteps, + dtype=np.float64) + elif schedule_name == "cosine": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2)**2, + ) + elif schedule_name == "const0.01": + scale = 1000 / num_diffusion_timesteps + return np.array([scale * 0.01] * num_diffusion_timesteps, + dtype=np.float64) + elif schedule_name == "const0.015": + scale = 1000 / num_diffusion_timesteps + return np.array([scale * 0.015] * num_diffusion_timesteps, + dtype=np.float64) + elif schedule_name == "const0.008": + scale = 1000 / num_diffusion_timesteps + return np.array([scale * 0.008] * num_diffusion_timesteps, + dtype=np.float64) + elif schedule_name == "const0.0065": + scale = 1000 / num_diffusion_timesteps + return np.array([scale * 0.0065] * num_diffusion_timesteps, + dtype=np.float64) + elif schedule_name == "const0.0055": + scale = 1000 / num_diffusion_timesteps + return np.array([scale * 0.0055] * num_diffusion_timesteps, + dtype=np.float64) + elif schedule_name == "const0.0045": + scale = 1000 / num_diffusion_timesteps + return np.array([scale * 0.0045] * num_diffusion_timesteps, + dtype=np.float64) + elif schedule_name == "const0.0035": + scale = 1000 / num_diffusion_timesteps + return np.array([scale * 0.0035] * num_diffusion_timesteps, + dtype=np.float64) + elif schedule_name == "const0.0025": + scale = 1000 / num_diffusion_timesteps + return np.array([scale * 0.0025] * num_diffusion_timesteps, + dtype=np.float64) + elif schedule_name == "const0.0015": + scale = 1000 / num_diffusion_timesteps + return np.array([scale * 0.0015] * num_diffusion_timesteps, + dtype=np.float64) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + Compute the KL divergence between two gaussians. + + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, th.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for th.exp(). 
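+    # KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2)))
+    #   = 0.5 * (logvar2 - logvar1 + exp(logvar1 - logvar2) + (mean1 - mean2)^2 * exp(-logvar2) - 1)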
+ logvar1, logvar2 = [ + x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * (-1.0 + logvar2 - logvar1 + th.exp(logvar1 - logvar2) + + ((mean1 - mean2)**2) * th.exp(-logvar2)) + + +def approx_standard_normal_cdf(x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. + """ + return 0.5 * ( + 1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) + + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a Gaussian distribution discretizing to a + given image. + + :param x: the target images. It is assumed that this was uint8 values, + rescaled to the range [-1, 1]. + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). + """ + assert x.shape == means.shape == log_scales.shape + centered_x = x - means + inv_stdv = th.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 1.0 / 255.0) + cdf_plus = approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - 1.0 / 255.0) + cdf_min = approx_standard_normal_cdf(min_in) + log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = th.where( + x < -0.999, + log_cdf_plus, + th.where(x > 0.999, log_one_minus_cdf_min, + th.log(cdf_delta.clamp(min=1e-12))), + ) + assert log_probs.shape == x.shape + return log_probs + + +class DummyModel(th.nn.Module): + def __init__(self, pred): + super().__init__() + self.pred = pred + + def forward(self, *args, **kwargs): + return DummyReturn(pred=self.pred) + + +class DummyReturn(NamedTuple): + pred: th.Tensor \ No newline at end of file diff --git a/diffusion/diffusion.py b/diffusion/diffusion.py new file mode 100644 index 0000000..f8bf0d9 --- /dev/null +++ b/diffusion/diffusion.py @@ -0,0 +1,182 @@ +from .base import * +from dataclasses import dataclass + + +def space_timesteps(num_timesteps, section_counts): + """ + Create a list of timesteps to use from an original diffusion process, + given the number of timesteps we want to take from equally-sized portions + of the original process. + + For example, if there's 300 timesteps and the section counts are [10,15,20] + then the first 100 timesteps are strided to be 10 timesteps, the second 100 + are strided to be 15 timesteps, and the final 100 are strided to be 20. + + If the stride is a string starting with "ddim", then the fixed striding + from the DDIM paper is used, and only one section is allowed. + + :param num_timesteps: the number of diffusion steps in the original + process to divide up. + :param section_counts: either a list of numbers, or a string containing + comma-separated numbers, indicating the step count + per section. As a special case, use "ddimN" where N + is a number of steps to use the striding from the + DDIM paper. + :return: a set of diffusion steps from the original process to use. 
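+    Example: space_timesteps(1000, "ddim100") returns the 100 evenly strided steps {0, 10, ..., 990}.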
+ """ + if isinstance(section_counts, str): + if section_counts.startswith("ddim"): + desired_count = int(section_counts[len("ddim"):]) + for i in range(1, num_timesteps): + if len(range(0, num_timesteps, i)) == desired_count: + return set(range(0, num_timesteps, i)) + raise ValueError( + f"cannot create exactly {num_timesteps} steps with an integer stride" + ) + section_counts = [int(x) for x in section_counts.split(",")] + size_per = num_timesteps // len(section_counts) + extra = num_timesteps % len(section_counts) + start_idx = 0 + all_steps = [] + for i, section_count in enumerate(section_counts): + size = size_per + (1 if i < extra else 0) + if size < section_count: + raise ValueError( + f"cannot divide section of {size} steps into {section_count}") + if section_count <= 1: + frac_stride = 1 + else: + frac_stride = (size - 1) / (section_count - 1) + cur_idx = 0.0 + taken_steps = [] + for _ in range(section_count): + taken_steps.append(start_idx + round(cur_idx)) + cur_idx += frac_stride + all_steps += taken_steps + start_idx += size + return set(all_steps) + + +@dataclass +class SpacedDiffusionBeatGansConfig(GaussianDiffusionBeatGansConfig): + use_timesteps: Tuple[int] = None + + def make_sampler(self): + return SpacedDiffusionBeatGans(self) + + +class SpacedDiffusionBeatGans(GaussianDiffusionBeatGans): + """ + A diffusion process which can skip steps in a base diffusion process. + + :param use_timesteps: a collection (sequence or set) of timesteps from the + original diffusion process to retain. + :param kwargs: the kwargs to create the base diffusion process. + """ + def __init__(self, conf: SpacedDiffusionBeatGansConfig): + self.conf = conf + self.use_timesteps = set(conf.use_timesteps) + # how the new t's mapped to the old t's + self.timestep_map = [] + self.original_num_steps = len(conf.betas) + + base_diffusion = GaussianDiffusionBeatGans(conf) # pylint: disable=missing-kwoa + last_alpha_cumprod = 1.0 + new_betas = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + # getting the new betas of the new timesteps + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + self.timestep_map.append(i) + conf.betas = np.array(new_betas) + super().__init__(conf) + + def p_mean_variance(self, model: Model, *args, **kwargs): # pylint: disable=signature-differs + return super().p_mean_variance(self._wrap_model(model), *args, + **kwargs) + + def training_losses(self, model: Model, *args, **kwargs): # pylint: disable=signature-differs + return super().training_losses(self._wrap_model(model), *args, + **kwargs) + + def condition_mean(self, cond_fn, *args, **kwargs): + return super().condition_mean(self._wrap_model(cond_fn), *args, + **kwargs) + + def condition_score(self, cond_fn, *args, **kwargs): + return super().condition_score(self._wrap_model(cond_fn), *args, + **kwargs) + + def _wrap_model(self, model: Model): + if isinstance(model, _WrappedModel): + return model + return _WrappedModel(model, self.timestep_map, self.rescale_timesteps, + self.original_num_steps) + + def _scale_timesteps(self, t): + # Scaling is done by the wrapped model. + return t + + +class _WrappedModel: + """ + converting the supplied t's to the old t's scales. 
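+    E.g. with "ddim100" over T=1000, timestep_map is [0, 10, ..., 990], so an eval-time t=5 is
+    passed to the model as original timestep 50.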
+ """ + def __init__(self, model, timestep_map, rescale_timesteps, + original_num_steps): + self.model = model + self.timestep_map = timestep_map + self.rescale_timesteps = rescale_timesteps + self.original_num_steps = original_num_steps + + def forward(self, x, t, t_cond=None, **kwargs): + """ + Args: + t: t's with differrent ranges (can be << T due to smaller eval T) need to be converted to the original t's + t_cond: the same as t but can be of different values + """ + map_tensor = th.tensor(self.timestep_map, + device=t.device, + dtype=t.dtype) + + def do(t): + new_ts = map_tensor[t] + if self.rescale_timesteps: + new_ts = new_ts.float() * (1000.0 / self.original_num_steps) + return new_ts + + if t_cond is not None: + # support t_cond + t_cond = do(t_cond) + ## run_ffhq256.py None + return self.model(x=x, t=do(t), t_cond=t_cond, **kwargs) ## self.model.__class__==model.unet_autoenc.BeatGANsAutoencModel + + def __getattr__(self, name): + # allow for calling the model's methods + if hasattr(self.model, name): + func = getattr(self.model, name) + return func + raise AttributeError(name) + + # def __call__(self, x, t, t_cond=None, **kwargs): + # """ + # Args: + # t: t's with differrent ranges (can be << T due to smaller eval T) need to be converted to the original t's + # t_cond: the same as t but can be of different values + # """ + # map_tensor = th.tensor(self.timestep_map, + # device=t.device, + # dtype=t.dtype) + + # def do(t): + # new_ts = map_tensor[t] + # if self.rescale_timesteps: + # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) + # return new_ts + + # if t_cond is not None: + # # support t_cond + # t_cond = do(t_cond) + # ## run_ffhq256.py None + # return self.model(x=x, t=do(t), t_cond=t_cond, **kwargs) diff --git a/diffusion/resample.py b/diffusion/resample.py new file mode 100644 index 0000000..1d5e581 --- /dev/null +++ b/diffusion/resample.py @@ -0,0 +1,63 @@ +from abc import ABC, abstractmethod + +import numpy as np +import torch as th +import torch.distributed as dist + + +def create_named_schedule_sampler(name, diffusion): + """ + Create a ScheduleSampler from a library of pre-defined samplers. + + :param name: the name of the sampler. + :param diffusion: the diffusion object to sample for. + """ + if name == "uniform": + return UniformSampler(diffusion) + else: + raise NotImplementedError(f"unknown schedule sampler: {name}") + + +class ScheduleSampler(ABC): + """ + A distribution over timesteps in the diffusion process, intended to reduce + variance of the objective. + + By default, samplers perform unbiased importance sampling, in which the + objective's mean is unchanged. + However, subclasses may override sample() to change how the resampled + terms are reweighted, allowing for actual changes in the objective. + """ + @abstractmethod + def weights(self): + """ + Get a numpy array of weights, one per diffusion step. + + The weights needn't be normalized, but must be positive. + """ + + def sample(self, batch_size, device): + """ + Importance-sample timesteps for a batch. + + :param batch_size: the number of timesteps. + :param device: the torch device to save to. + :return: a tuple (timesteps, weights): + - timesteps: a tensor of timestep indices. + - weights: a tensor of weights to scale the resulting losses. 
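+        Weights are 1 / (T * p(t)), so weighting each per-sample loss keeps the objective unbiased.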
+ """ + w = self.weights() + p = w / np.sum(w) + indices_np = np.random.choice(len(p), size=(batch_size, ), p=p) + indices = th.from_numpy(indices_np).long().to(device) + weights_np = 1 / (len(p) * p[indices_np]) + weights = th.from_numpy(weights_np).float().to(device) + return indices, weights + + +class UniformSampler(ScheduleSampler): + def __init__(self, num_timesteps): ## all steps are 1 + self._weights = np.ones([num_timesteps]) + + def weights(self): + return self._weights \ No newline at end of file diff --git a/environment/haheae.yml b/environment/haheae.yml new file mode 100644 index 0000000..87fdcfa --- /dev/null +++ b/environment/haheae.yml @@ -0,0 +1,101 @@ +name: haheae +channels: + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=5.1=1_gnu + - ca-certificates=2023.12.12=h06a4308_0 + - ld_impl_linux-64=2.38=h1181459_1 + - libffi=3.4.4=h6a678d5_0 + - libgcc-ng=11.2.0=h1234567_1 + - libgomp=11.2.0=h1234567_1 + - libstdcxx-ng=11.2.0=h1234567_1 + - ncurses=6.4=h6a678d5_0 + - openssl=3.0.12=h7f8727e_0 + - pip=23.3.1=py38h06a4308_0 + - python=3.8.18=h955ad1f_0 + - readline=8.2=h5eee18b_0 + - setuptools=68.2.2=py38h06a4308_0 + - sqlite=3.41.2=h5eee18b_0 + - tk=8.6.12=h1ccaba5_0 + - wheel=0.41.2=py38h06a4308_0 + - xz=5.4.5=h5eee18b_0 + - zlib=1.2.13=h5eee18b_0 + - pip: + - absl-py==2.0.0 + - aiohttp==3.9.1 + - aiosignal==1.3.1 + - appdirs==1.4.4 + - async-timeout==4.0.3 + - attrs==23.2.0 + - cachetools==5.3.2 + - certifi==2023.11.17 + - charset-normalizer==3.3.2 + - click==8.1.7 + - contourpy==1.1.1 + - cycler==0.12.1 + - cython==0.29.37 + - docker-pycreds==0.4.0 + - fonttools==4.47.2 + - frozenlist==1.4.1 + - fsspec==2023.12.2 + - ftfy==6.1.3 + - future==0.18.3 + - gitdb==4.0.11 + - gitpython==3.1.41 + - google-auth==2.26.2 + - google-auth-oauthlib==1.0.0 + - grpcio==1.60.0 + - hdbscan==0.8.33 + - idna==3.6 + - importlib-metadata==7.0.1 + - importlib-resources==6.1.1 + - joblib==1.3.2 + - kiwisolver==1.4.5 + - lmdb==1.2.1 + - lpips==0.1.4 + - markdown==3.5.2 + - markupsafe==2.1.3 + - matplotlib==3.5.3 + - multidict==6.0.4 + - numpy==1.24.4 + - oauthlib==3.2.2 + - packaging==23.2 + - pandas==1.5.3 + - pillow==10.2.0 + - protobuf==4.25.2 + - psutil==5.9.8 + - pyasn1==0.5.1 + - pyasn1-modules==0.3.0 + - pydeprecate==0.3.1 + - pyparsing==3.1.1 + - python-dateutil==2.8.2 + - pytorch-fid==0.2.0 + - pytorch-lightning==1.4.5 + - pytz==2023.3.post1 + - pyyaml==6.0.1 + - regex==2023.12.25 + - requests==2.31.0 + - requests-oauthlib==1.3.1 + - rsa==4.9 + - scikit-learn==1.3.2 + - scipy==1.5.4 + - sentry-sdk==1.39.2 + - setproctitle==1.3.3 + - six==1.16.0 + - smmap==5.0.1 + - tensorboard==2.14.0 + - tensorboard-data-server==0.7.2 + - threadpoolctl==3.2.0 + - torch==1.8.1 + - torchmetrics==0.5.0 + - torchvision==0.9.1 + - tqdm==4.66.1 + - typing-extensions==4.9.0 + - tzdata==2023.4 + - urllib3==2.1.0 + - wandb==0.16.2 + - wcwidth==0.2.13 + - werkzeug==3.0.1 + - yarl==1.9.4 + - zipp==3.17.0 diff --git a/main.py b/main.py new file mode 100644 index 0000000..6a2d2df --- /dev/null +++ b/main.py @@ -0,0 +1,565 @@ +import warnings +warnings.filterwarnings("ignore") +import os +os.nice(5) +import copy, wandb +from tqdm import tqdm, trange +import argparse +import json +import re +import random +import math +import numpy as np +import pandas as pd +import pytorch_lightning as pl +import torch +from pytorch_lightning import loggers as pl_loggers +from pytorch_lightning.callbacks import * +from torch import nn +import torch.nn.functional as F +from torch.cuda import amp +from 
torch.optim.optimizer import Optimizer +from config import * +import time +import datetime + +class LitModel(pl.LightningModule): + def __init__(self, conf: TrainConfig): + super().__init__() + + ## wandb + self.save_hyperparameters({k:v for (k,v) in vars(conf).items() if not callable(v)}) + + if conf.seed is not None: + pl.seed_everything(conf.seed) + + self.save_hyperparameters(conf.as_dict_jsonable()) + self.conf = conf + self.model = conf.make_model_conf().make_model() + + self.ema_model = copy.deepcopy(self.model) + self.ema_model.requires_grad_(False) + self.ema_model.eval() + + model_size = 0 + for param in self.model.parameters(): + model_size += param.data.nelement() + print('Model params: %.3f M' % (model_size / 1024 / 1024)) + + self.sampler = conf.make_diffusion_conf().make_sampler() + self.eval_sampler = conf.make_eval_diffusion_conf().make_sampler() + self.T_sampler = conf.make_T_sampler() + self.save_every_epochs = conf.save_every_epochs + self.eval_every_epochs = conf.eval_every_epochs + + def setup(self, stage=None) -> None: + """ + make datasets & seeding each worker separately + """ + ############################################## + # NEED TO SET THE SEED SEPARATELY HERE + if self.conf.seed is not None: + seed = self.conf.seed + np.random.seed(seed) + random.seed(seed) # Python random module. + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + print('seed:', seed) + ############################################## + + ## Load dataset + if self.conf.mode == 'train': + self.train_data = load_egobody(self.conf.data_dir, self.conf.seq_len+self.conf.seq_len_future, self.conf.data_sample_rate, train=1) + self.val_data = load_egobody(self.conf.data_dir, self.conf.seq_len+self.conf.seq_len_future, self.conf.data_sample_rate, train=0) + if self.conf.in_channels == 6: # hand only + self.train_data = self.train_data[:, :6, :] + self.val_data = self.val_data[:, :6, :] + if self.conf.in_channels == 3: # head only + self.train_data = self.train_data[:, 6:, :] + self.val_data = self.val_data[:, 6:, :] + + def encode(self, x): + assert self.conf.model_type.has_autoenc() + cond, pred_hand, pred_head = self.ema_model.encoder.forward(x) + return cond, pred_hand, pred_head + + def encode_stochastic(self, x, cond, T=None): + if T is None: + sampler = self.eval_sampler + else: + sampler = self.conf._make_diffusion_conf(T).make_sampler() # get noise at step T + + ## x_0 -> x-T using reverse of inference + out = sampler.ddim_reverse_sample_loop(self.ema_model, x, model_kwargs={'cond': cond}) + ''' 'sample': x_T + 'sample_t': x_t, t in (1, ..., T) + 'xstart_t': predicted x_0 at each timestep. "xstart here is a bit different from sampling from T = T-1 to T = 0" + 'T': (1, ..., T) + ''' + return out['sample'] + + def train_dataloader(self): + return torch.utils.data.DataLoader(self.train_data, batch_size=self.conf.batch_size, shuffle=True) + + def val_dataloader(self): + return torch.utils.data.DataLoader(self.val_data, batch_size=self.conf.batch_size_eval, shuffle=False) + + def is_last_accum(self, batch_idx): + """ + is it the last gradient accumulation loop? + used with gradient_accum > 1 and to see if the optimizer will perform "step" in this iteration or not + """ + return (batch_idx + 1) % self.conf.accum_batches == 0 + + def training_step(self, batch, batch_idx): + """ + given an input, calculate the loss function + no optimization at this stage. 
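+
+        Sketch of the objective (assuming the default eps-prediction, MSE setup
+        from TrainConfig; the exact terms are computed by self.sampler.training_losses):
+            x_t  = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
+            loss = ||eps - model(x_t, t)||^2 + weighted hand/head future-prediction MSE terms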
+ """ + with amp.autocast(False): + x_start = batch[:, :, :self.conf.seq_len] + x_future = batch[:, :, self.conf.seq_len:] + + if self.conf.train_mode == TrainMode.diffusion: + """ + main training mode!!! + """ + t, weight = self.T_sampler.sample(len(x_start), x_start.device) + ''' self.T_sampler: diffusion.resample.UniformSampler (weights for all timesteps are 1) + - t: a tensor of timestep indices. + - weight: a tensor of weights to scale the resulting losses. + ## num_timesteps is self.conf.T == 1000 + ''' + losses = self.sampler.training_losses(model=self.model, + x_start=x_start, + t=t, + x_future=x_future, + hand_mse_factor = self.conf.hand_mse_factor, + head_mse_factor = self.conf.head_mse_factor, + ) + else: + raise NotImplementedError() + + loss = losses['loss'].mean() ## average loss across mini-batches + #noise_mse = losses['mse'].mean() + #hand_mse = losses['hand_mse'].mean() + #head_mse = losses['head_mse'].mean() + + ## Log loss and metric (wandb) + self.log("train_loss", loss, on_epoch=True, prog_bar=True) + #self.log("train_noise_mse", noise_mse, on_epoch=True, prog_bar=True) + #self.log("train_hand_mse", hand_mse, on_epoch=True, prog_bar=True) + #self.log("train_head_mse", head_mse, on_epoch=True, prog_bar=True) + return {'loss': loss} + + def validation_step(self, batch, batch_idx): + if self.conf.in_channels == 9: # both hand and head + if((self.current_epoch+1)% self.eval_every_epochs == 0): + batch_future = batch[:, :, self.conf.seq_len:] + gt_hand_future = batch_future[:, :6, :] + gt_head_future = batch_future[:, 6:, :] + batch = batch[:, :, :self.conf.seq_len] + cond, pred_hand_future, pred_head_future = self.encode(batch) + xT = self.encode_stochastic(batch, cond) + pred_xstart = self.generate(xT, cond) + + # hand reconstruction error + gt_hand = batch[:, :6, :] + pred_hand = pred_xstart[:, :6, :] + bs, channels, seq_len = gt_hand.shape + gt_hand = gt_hand.reshape(bs, 2, 3, seq_len) + pred_hand = pred_hand.reshape(bs, 2, 3, seq_len) + hand_traj = torch.mean(torch.norm(gt_hand - pred_hand, dim=2)) + + # hand prediction error + bs, channels, seq_len = gt_hand_future.shape + gt_hand_future = gt_hand_future.reshape(bs, 2, 3, seq_len) + pred_hand_future = pred_hand_future.reshape(bs, 2, 3, seq_len) + baseline_hand_future = gt_hand[:, :, :, -1:].expand(-1, -1, -1, self.conf.seq_len_future).clone() + hand_traj_future = torch.mean(torch.norm(gt_hand_future - pred_hand_future, dim=2)) + hand_traj_future_baseline = torch.mean(torch.norm(gt_hand_future - baseline_hand_future, dim=2)) + + # head reconstruction error + gt_head = batch[:, 6:, :] + gt_head = F.normalize(gt_head, dim=1) # normalize head orientation to unit vectors + pred_head = pred_xstart[:, 6:, :] + pred_head = F.normalize(pred_head, dim=1) + head_ang = torch.mean(acos_safe(torch.sum(gt_head*pred_head, 1)))/torch.tensor(math.pi) * 180.0 + + # head prediction error + gt_head_future = F.normalize(gt_head_future, dim=1) + pred_head_future = F.normalize(pred_head_future, dim=1) + baseline_head_future = gt_head[:, :, -1:].expand(-1, -1, self.conf.seq_len_future).clone() + head_ang_future = torch.mean(acos_safe(torch.sum(gt_head_future*pred_head_future, 1)))/torch.tensor(math.pi) * 180.0 + head_ang_future_baseline = torch.mean(acos_safe(torch.sum(gt_head_future*baseline_head_future, 1)))/torch.tensor(math.pi) * 180.0 + + self.log("val_hand_traj", hand_traj, on_epoch=True, prog_bar=True) + self.log("val_head_ang", head_ang, on_epoch=True, prog_bar=True) + self.log("val_hand_traj_future", hand_traj_future, on_epoch=True, 
prog_bar=True) + self.log("val_head_ang_future", head_ang_future, on_epoch=True, prog_bar=True) + self.log("val_hand_traj_future_baseline", hand_traj_future_baseline, on_epoch=True, prog_bar=True) + self.log("val_head_ang_future_baseline", head_ang_future_baseline, on_epoch=True, prog_bar=True) + + if self.conf.in_channels == 6: # hand only + if((self.current_epoch+1)% self.eval_every_epochs == 0): + batch_future = batch[:, :, self.conf.seq_len:] + gt_hand_future = batch_future[:, :, :] + batch = batch[:, :, :self.conf.seq_len] + cond, pred_hand_future, pred_head_future = self.encode(batch) + xT = self.encode_stochastic(batch, cond) + pred_xstart = self.generate(xT, cond) + + # hand reconstruction error + gt_hand = batch[:, :, :] + pred_hand = pred_xstart[:, :, :] + bs, channels, seq_len = gt_hand.shape + gt_hand = gt_hand.reshape(bs, 2, 3, seq_len) + pred_hand = pred_hand.reshape(bs, 2, 3, seq_len) + hand_traj = torch.mean(torch.norm(gt_hand - pred_hand, dim=2)) + + # hand prediction error + bs, channels, seq_len = gt_hand_future.shape + gt_hand_future = gt_hand_future.reshape(bs, 2, 3, seq_len) + pred_hand_future = pred_hand_future.reshape(bs, 2, 3, seq_len) + baseline_hand_future = gt_hand[:, :, :, -1:].expand(-1, -1, -1, self.conf.seq_len_future).clone() + hand_traj_future = torch.mean(torch.norm(gt_hand_future - pred_hand_future, dim=2)) + hand_traj_future_baseline = torch.mean(torch.norm(gt_hand_future - baseline_hand_future, dim=2)) + + self.log("val_hand_traj", hand_traj, on_epoch=True, prog_bar=True) + self.log("val_hand_traj_future", hand_traj_future, on_epoch=True, prog_bar=True) + self.log("val_hand_traj_future_baseline", hand_traj_future_baseline, on_epoch=True, prog_bar=True) + + if self.conf.in_channels == 3: # head only + if((self.current_epoch+1)% self.eval_every_epochs == 0): + batch_future = batch[:, :, self.conf.seq_len:] + gt_head_future = batch_future[:, :, :] + batch = batch[:, :, :self.conf.seq_len] + cond, pred_hand_future, pred_head_future = self.encode(batch) + xT = self.encode_stochastic(batch, cond) + pred_xstart = self.generate(xT, cond) + + # head reconstruction error + gt_head = batch[:, :, :] + gt_head = F.normalize(gt_head, dim=1) # normalize head orientation to unit vectors + pred_head = pred_xstart[:, :, :] + pred_head = F.normalize(pred_head, dim=1) + head_ang = torch.mean(acos_safe(torch.sum(gt_head*pred_head, 1)))/torch.tensor(math.pi) * 180.0 + + # head prediction error + gt_head_future = F.normalize(gt_head_future, dim=1) + pred_head_future = F.normalize(pred_head_future, dim=1) + baseline_head_future = gt_head[:, :, -1:].expand(-1, -1, self.conf.seq_len_future).clone() + head_ang_future = torch.mean(acos_safe(torch.sum(gt_head_future*pred_head_future, 1)))/torch.tensor(math.pi) * 180.0 + head_ang_future_baseline = torch.mean(acos_safe(torch.sum(gt_head_future*baseline_head_future, 1)))/torch.tensor(math.pi) * 180.0 + + self.log("val_head_ang", head_ang, on_epoch=True, prog_bar=True) + self.log("val_head_ang_future", head_ang_future, on_epoch=True, prog_bar=True) + self.log("val_head_ang_future_baseline", head_ang_future_baseline, on_epoch=True, prog_bar=True) + + + def test_step(self, batch, batch_idx): + batch_future = batch[:, :, self.conf.seq_len:] + gt_hand_future = batch_future[:, :6, :] + gt_head_future = batch_future[:, 6:, :] + batch = batch[:, :, :self.conf.seq_len] + cond, pred_hand_future, pred_head_future = self.encode(batch) + xT = self.encode_stochastic(batch, cond) + pred_xstart = self.generate(xT, cond) + + # hand reconstruction 
error + gt_hand = batch[:, :6, :] + pred_hand = pred_xstart[:, :6, :] + bs, channels, seq_len = gt_hand.shape + gt_hand = gt_hand.reshape(bs, 2, 3, seq_len) + pred_hand = pred_hand.reshape(bs, 2, 3, seq_len) + hand_traj = torch.mean(torch.norm(gt_hand - pred_hand, dim=2)) + + # hand prediction error + bs, channels, seq_len = gt_hand_future.shape + gt_hand_future = gt_hand_future.reshape(bs, 2, 3, seq_len) + pred_hand_future = pred_hand_future.reshape(bs, 2, 3, seq_len) + baseline_hand_future = gt_hand[:, :, :, -1:].expand(-1, -1, -1, self.conf.seq_len_future).clone() + hand_traj_future = torch.mean(torch.norm(gt_hand_future - pred_hand_future, dim=2)) + hand_traj_future_baseline = torch.mean(torch.norm(gt_hand_future - baseline_hand_future, dim=2)) + + # head reconstruction error + gt_head = batch[:, 6:, :] + gt_head = F.normalize(gt_head, dim=1) # normalize head orientation to unit vectors + pred_head = pred_xstart[:, 6:, :] + pred_head = F.normalize(pred_head, dim=1) + head_ang = torch.mean(acos_safe(torch.sum(gt_head*pred_head, 1)))/torch.tensor(math.pi) * 180.0 + + # head prediction error + gt_head_future = F.normalize(gt_head_future, dim=1) + pred_head_future = F.normalize(pred_head_future, dim=1) + baseline_head_future = gt_head[:, :, -1:].expand(-1, -1, self.conf.seq_len_future).clone() + head_ang_future = torch.mean(acos_safe(torch.sum(gt_head_future*pred_head_future, 1)))/torch.tensor(math.pi) * 180.0 + head_ang_future_baseline = torch.mean(acos_safe(torch.sum(gt_head_future*baseline_head_future, 1)))/torch.tensor(math.pi) * 180.0 + + self.log("test_hand_traj", hand_traj, on_epoch=True, prog_bar=True) + self.log("test_head_ang", head_ang, on_epoch=True, prog_bar=True) + self.log("test_hand_traj_future", hand_traj_future, on_epoch=True, prog_bar=True) + self.log("test_head_ang_future", head_ang_future, on_epoch=True, prog_bar=True) + self.log("test_hand_traj_future_baseline", hand_traj_future_baseline, on_epoch=True, prog_bar=True) + self.log("test_head_ang_future_baseline", head_ang_future_baseline, on_epoch=True, prog_bar=True) + + + def generate(self, noise, cond=None, ema=True, T=None): + if T is None: + sampler = self.eval_sampler + else: + sampler = self.conf._make_diffusion_conf(T).make_sampler() + + if ema: + model = self.ema_model + else: + model = self.model + + gen = sampler.sample(model=model, noise=noise, model_kwargs={'cond': cond}) + return gen + + def on_train_batch_end(self, outputs, batch, batch_idx: int, + dataloader_idx: int) -> None: + """ + after each training step ... 
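+
+        Here the EMA weights are updated as (see ema() below):
+            ema_param = ema_decay * ema_param + (1 - ema_decay) * param
+        and a checkpoint is written every save_every_epochs epochs.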
+ """ + if self.is_last_accum(batch_idx): + # only apply ema on the last gradient accumulation step, + # if it is the iteration that has optimizer.step() + + ema(self.model, self.ema_model, self.conf.ema_decay) + + if (batch_idx==len(self.train_dataloader())-1) and ((self.current_epoch+1) % self.save_every_epochs == 0): + save_path = os.path.join(self.conf.logdir, 'epoch_%d.ckpt' % (self.current_epoch+1)) + torch.save({ + 'state_dict': self.state_dict(), + 'global_step': self.global_step, + 'loss': outputs['loss'], + }, save_path) + + def on_before_optimizer_step(self, optimizer: Optimizer, + optimizer_idx: int) -> None: + # fix the fp16 + clip grad norm problem with pytorch lightinng + # this is the currently correct way to do it + if self.conf.grad_clip > 0: + # from trainer.params_grads import grads_norm, iter_opt_params + params = [ + p for group in optimizer.param_groups for p in group['params'] + ] + # print('before:', grads_norm(iter_opt_params(optimizer))) + torch.nn.utils.clip_grad_norm_(params, max_norm=self.conf.grad_clip) + # print('after:', grads_norm(iter_opt_params(optimizer))) + + def configure_optimizers(self): + out = {} + if self.conf.optimizer == OptimizerType.adam: + optim = torch.optim.Adam(self.model.parameters(), + lr=self.conf.lr, + weight_decay=self.conf.weight_decay) + elif self.conf.optimizer == OptimizerType.adamw: + optim = torch.optim.AdamW(self.model.parameters(), + lr=self.conf.lr, + weight_decay=self.conf.weight_decay) + else: + raise NotImplementedError() + out['optimizer'] = optim + if self.conf.warmup > 0: + sched = torch.optim.lr_scheduler.LambdaLR(optim, + lr_lambda=WarmupLR( + self.conf.warmup)) + out['lr_scheduler'] = { + 'scheduler': sched, + 'interval': 'step', + } + return out + + +def ema(source, target, decay): + source_dict = source.state_dict() + target_dict = target.state_dict() + for key in source_dict.keys(): + target_dict[key].data.copy_(target_dict[key].data * decay + + source_dict[key].data * (1 - decay)) + + +class WarmupLR: + def __init__(self, warmup) -> None: + self.warmup = warmup + + def __call__(self, step): + return min(step, self.warmup) / self.warmup + + +def train(conf: TrainConfig, model: LitModel, gpus): + checkpoint = ModelCheckpoint(dirpath=conf.logdir, + filename='last', + save_last=True, + save_top_k=1, + every_n_epochs=conf.save_every_epochs, + ) + checkpoint_path = f'{conf.logdir}last.ckpt' + if os.path.exists(checkpoint_path): + resume = checkpoint_path + if conf.mode == 'train': + print('ckpt path:', checkpoint_path) + else: + print('checkpoint not found!') + resume = None + + wandb_logger = pl_loggers.WandbLogger(project='haheae', + name='%s_%s'%(model.conf.data_name, conf.logdir.split('/')[-2]), + log_model=True, + save_dir=conf.logdir, + dir = conf.logdir, + config=vars(model.conf), + save_code=True, + settings=wandb.Settings(code_dir=".")) + + trainer = pl.Trainer( + max_epochs=conf.total_epochs, + resume_from_checkpoint=resume, + gpus=gpus, + precision=16 if conf.fp16 else 32, + callbacks=[ + checkpoint, + LearningRateMonitor(), + ], + logger= wandb_logger, + accumulate_grad_batches=conf.accum_batches, + progress_bar_refresh_rate=4, + ) + + if conf.mode == 'train': + trainer.fit(model) + elif conf.mode == 'eval': + checkpoint_path = f'{conf.logdir}last.ckpt' + # load the latest checkpoint + print('loading from:', checkpoint_path) + state = torch.load(checkpoint_path) + model.load_state_dict(state['state_dict']) + + test_datasets = ['egobody', 'adt', 'gimo'] + for dataset_name in test_datasets: + if dataset_name 
== 'egobody': + data_dir = "/scratch/hu/pose_forecast/egobody_pose2gaze/" + test_data = load_egobody(data_dir, conf.seq_len+conf.seq_len_future, 1, train=0) # use the test set + elif dataset_name == 'adt': + data_dir = "/scratch/hu/pose_forecast/adt_pose2gaze/" + test_data = load_adt(data_dir, conf.seq_len+conf.seq_len_future, 1, train=2) # use the train+test set + elif dataset_name == 'gimo': + data_dir = "/scratch/hu/pose_forecast/gimo_pose2gaze/" + test_data = load_gimo(data_dir, conf.seq_len+conf.seq_len_future, 1, train=2) # use the train+test set + + test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=conf.batch_size_eval, shuffle=False) + results = trainer.test(model, dataloaders=test_dataloader, verbose=False) + print("\n\nTest on {}, dataset size: {}".format(dataset_name, test_data.shape)) + print("test_hand_traj: {:.3f} cm".format(results[0]['test_hand_traj']*100)) + print("test_head_ang: {:.3f} deg".format(results[0]['test_head_ang'])) + print("test_hand_traj_future: {:.3f} cm".format(results[0]['test_hand_traj_future']*100)) + print("test_head_ang_future: {:.3f} deg".format(results[0]['test_head_ang_future'])) + print("test_hand_traj_future_baseline: {:.3f} cm".format(results[0]['test_hand_traj_future_baseline']*100)) + print("test_head_ang_future_baseline: {:.3f} deg\n\n".format(results[0]['test_head_ang_future_baseline'])) + + wandb.finish() + + +def acos_safe(x, eps=1e-6): + slope = np.arccos(1-eps) / eps + buf = torch.empty_like(x) + good = abs(x) <= 1-eps + bad = ~good + sign = torch.sign(x[bad]) + buf[good] = torch.acos(x[good]) + buf[bad] = torch.acos(sign * (1 - eps)) - slope*sign*(abs(x[bad]) - 1 + eps) + return buf + + +def get_representation(model, dataset, conf, device='cuda'): + model = model.to(device) + model.eval() + dataloader = torch.utils.data.DataLoader(dataset, batch_size=conf.batch_size_eval, shuffle=False) + with torch.no_grad(): + conds = [] # semantic representation + xTs = [] # stochastic representation + for batch in tqdm(dataloader, total=len(dataloader), desc='infer'): + batch = batch.to(device) + cond, _, _ = model.encode(batch) + xT = model.encode_stochastic(batch, cond) + cond_cpu = cond.cpu().data.numpy() + xT_cpu = xT.cpu().data.numpy() + if len(conds) == 0: + conds = cond_cpu + xTs = xT_cpu + else: + conds = np.concatenate((conds, cond_cpu), axis=0) + xTs = np.concatenate((xTs, xT_cpu), axis=0) + return conds, xTs + + +def generate_from_representation(model, conds, xTs, device='cuda'): + model = model.to(device) + model.eval() + conds = torch.from_numpy(conds).to(device) + xTs = torch.from_numpy(xTs).to(device) + rec = model.generate(xTs, conds) + rec = rec.cpu().data.numpy() + return rec + + +def evaluate_reconstruction(gt, rec): + # hand reconstruction error (cm) + gt_hand = gt[:, :6, :] + rec_hand = rec[:, :6, :] + bs, channels, seq_len = gt_hand.shape + gt_hand = gt_hand.reshape(bs, 2, 3, seq_len) + rec_hand = rec_hand.reshape(bs, 2, 3, seq_len) + hand_traj_errors = np.mean(np.mean(np.linalg.norm(gt_hand - rec_hand, axis=2), axis=1), axis=1)*100 + + # head reconstruction error (deg) + gt_head = gt[:, 6:, :] + gt_head_norm = np.linalg.norm(gt_head, axis=1, keepdims=True) + gt_head = gt_head/gt_head_norm + rec_head = rec[:, 6:, :] + rec_head_norm = np.linalg.norm(rec_head, axis=1, keepdims=True) + rec_head = rec_head/rec_head_norm + dot_sum = np.clip(np.sum(gt_head*rec_head, axis=1), -1, 1) + head_ang_errors = np.mean(np.arccos(dot_sum), axis=1)/np.pi * 180.0 + return hand_traj_errors, head_ang_errors + + +if __name__ == 
'__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--gpus', default=7, type=int) + parser.add_argument('--mode', default='eval', type=str) + parser.add_argument('--encoder_type', default='gcn', type=str) + parser.add_argument('--model_name', default='haheae', type=str) + parser.add_argument('--hand_mse_factor', default=1.0, type=float) + parser.add_argument('--head_mse_factor', default=1.0, type=float) + parser.add_argument('--data_sample_rate', default=1, type=int) + parser.add_argument('--epoch', default=130, type=int) + parser.add_argument('--in_channels', default=9, type=int) + args = parser.parse_args() + + conf = egobody_autoenc(args.mode, args.encoder_type, args.hand_mse_factor, args.head_mse_factor, args.data_sample_rate, args.epoch, args.in_channels) + model = LitModel(conf) + conf.logdir = f'{conf.logdir}{args.model_name}/' + print('log dir: {}'.format(conf.logdir)) + MakeDir(conf.logdir) + + if conf.mode == 'train' or conf.mode == 'eval': # train or evaluate the model + os.environ['WANDB_CACHE_DIR'] = conf.logdir + os.environ['WANDB_DATA_DIR'] = conf.logdir + # set wandb to not upload checkpoints, but all the others + os.environ['WANDB_IGNORE_GLOBS'] = '*.ckpt' + local_time = time.asctime(time.localtime(time.time())) + print('\n{} starts at {}'.format(conf.mode, local_time)) + start_time = datetime.datetime.now() + train(conf, model, gpus=[args.gpus]) + end_time = datetime.datetime.now() + total_time = (end_time - start_time).seconds/60 + print('\nTotal time: {:.3f} min'.format(total_time)) + local_time = time.asctime(time.localtime(time.time())) + print('\n{} ends at {}'.format(conf.mode, local_time)) \ No newline at end of file diff --git a/model/__init__.py b/model/__init__.py new file mode 100644 index 0000000..12a56ea --- /dev/null +++ b/model/__init__.py @@ -0,0 +1,6 @@ +from typing import Union +from .unet import BeatGANsUNetModel, BeatGANsUNetConfig, GCNUNetModel, GCNUNetConfig +from .unet_autoenc import BeatGANsAutoencConfig, BeatGANsAutoencModel, GCNAutoencConfig, GCNAutoencModel + +Model = Union[BeatGANsUNetModel, BeatGANsAutoencModel, GCNUNetModel, GCNAutoencModel] +ModelConfig = Union[BeatGANsUNetConfig, BeatGANsAutoencConfig, GCNUNetConfig,GCNAutoencConfig] \ No newline at end of file diff --git a/model/blocks.py b/model/blocks.py new file mode 100644 index 0000000..cf883c1 --- /dev/null +++ b/model/blocks.py @@ -0,0 +1,579 @@ +import math, pdb +from abc import abstractmethod +from dataclasses import dataclass +from numbers import Number + +import torch as th +import torch.nn.functional as F +from choices import * +from config_base import BaseConfig +from torch import nn +import numpy as np +from .nn import (avg_pool_nd, conv_nd, linear, normalization, + timestep_embedding, zero_module) + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + @abstractmethod + def forward(self, x, emb=None, cond=None, lateral=None): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. 
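+
+    Children that subclass TimestepBlock are called as
+        layer(x, emb=emb, cond=cond, lateral=lateral)
+    while all other children are called as layer(x).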
+ """ + def forward(self, x, emb=None, cond=None, lateral=None): + for layer in self: + if isinstance(layer, TimestepBlock): + '''if layer(x, emb=emb, cond=cond, lateral=lateral).shape[-1]==10: + pdb.set_trace()''' + x = layer(x, emb=emb, cond=cond, lateral=lateral) + else: + '''if layer(x).shape[-1]==10: + pdb.set_trace()''' + x = layer(x) + return x + + +@dataclass +class ResBlockConfig(BaseConfig): + channels: int + emb_channels: int + dropout: float + out_channels: int = None + # condition the resblock with time (and encoder's output) + use_condition: bool = True + # whether to use 3x3 conv for skip path when the channels aren't matched + use_conv: bool = False + # dimension of conv (always 1 = 1d) + dims: int = 1 + up: bool = False + down: bool = False + # whether to condition with both time & encoder's output + two_cond: bool = False + # number of encoders' output channels + cond_emb_channels: int = None + # suggest: False + has_lateral: bool = False + lateral_channels: int = None + # whether to init the convolution with zero weights + # this is default from BeatGANs and seems to help learning + use_zero_module: bool = True + + def __post_init__(self): + self.out_channels = self.out_channels or self.channels + self.cond_emb_channels = self.cond_emb_channels or self.emb_channels + + def make_model(self): + return ResBlock(self) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + + total layers: + in_layers + - norm + - act + - conv + out_layers + - norm + - (modulation) + - act + - conv + """ + def __init__(self, conf: ResBlockConfig): + super().__init__() + self.conf = conf + + ############################# + # IN LAYERS + ############################# + assert conf.lateral_channels is None + layers = [ + normalization(conf.channels), + nn.SiLU(), + conv_nd(conf.dims, conf.channels, conf.out_channels, 3, padding=1) ## 3 is kernel size + ] + self.in_layers = nn.Sequential(*layers) + + self.updown = conf.up or conf.down + + if conf.up: + self.h_upd = Upsample(conf.channels, False, conf.dims) + self.x_upd = Upsample(conf.channels, False, conf.dims) + elif conf.down: + self.h_upd = Downsample(conf.channels, False, conf.dims) + self.x_upd = Downsample(conf.channels, False, conf.dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + ############################# + # OUT LAYERS CONDITIONS + ############################# + if conf.use_condition: + # condition layers for the out_layers + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear(conf.emb_channels, 2 * conf.out_channels), + ) + + if conf.two_cond: + self.cond_emb_layers = nn.Sequential( + nn.SiLU(), + linear(conf.cond_emb_channels, conf.out_channels), + ) + ############################# + # OUT LAYERS (ignored when there is no condition) + ############################# + # original version + conv = conv_nd(conf.dims, + conf.out_channels, + conf.out_channels, + 3, + padding=1) + if conf.use_zero_module: + # zere out the weights + # it seems to help training + conv = zero_module(conv) + + # construct the layers + # - norm + # - (modulation) + # - act + # - dropout + # - conv + layers = [] + layers += [ + normalization(conf.out_channels), + nn.SiLU(), + nn.Dropout(p=conf.dropout), + conv, + ] + self.out_layers = nn.Sequential(*layers) + + ############################# + # SKIP LAYERS + ############################# + if conf.out_channels == conf.channels: + # cannot be used with gatedconv, also gatedconv is alsways used as the first block + self.skip_connection = 
nn.Identity() + else: + if conf.use_conv: + kernel_size = 3 + padding = 1 + else: + kernel_size = 1 + padding = 0 + + self.skip_connection = conv_nd(conf.dims, + conf.channels, + conf.out_channels, + kernel_size, + padding=padding) + + def forward(self, x, emb=None, cond=None, lateral=None): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + + Args: + x: input + lateral: lateral connection from the encoder + """ + return self._forward(x, emb, cond, lateral) + + def _forward( + self, + x, + emb=None, + cond=None, + lateral=None, + ): + """ + Args: + lateral: required if "has_lateral" and non-gated, with gated, it can be supplied optionally + """ + if self.conf.has_lateral: + # lateral may be supplied even if it doesn't require + # the model will take the lateral only if "has_lateral" + assert lateral is not None + # x = F.interpolate(x, size=(lateral.size(2)), mode='linear' ) + x = th.cat([x, lateral], dim=1) + + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + + if self.conf.use_condition: + # it's possible that the network may not receieve the time emb + # this happens with autoenc and setting the time_at + if emb is not None: + emb_out = self.emb_layers(emb).type(h.dtype) + else: + emb_out = None + + if self.conf.two_cond: + # it's possible that the network is two_cond + # but it doesn't get the second condition + # in which case, we ignore the second condition + # and treat as if the network has one condition + if cond is None: + cond_out = None + else: + if not isinstance(cond, th.Tensor): + assert isinstance(cond, dict) + cond = cond['cond'] + cond_out = self.cond_emb_layers(cond).type(h.dtype) + if cond_out is not None: + while len(cond_out.shape) < len(h.shape): + cond_out = cond_out[..., None] + else: + cond_out = None + + # this is the new refactored code + h = apply_conditions( + h=h, + emb=emb_out, + cond=cond_out, + layers=self.out_layers, + scale_bias=1, + in_channels=self.conf.out_channels, + up_down_layer=None, + ) + + return self.skip_connection(x) + h + + +def apply_conditions( + h, + emb=None, + cond=None, + layers: nn.Sequential = None, + scale_bias: float = 1, + in_channels: int = 512, + up_down_layer: nn.Module = None, +): + """ + apply conditions on the feature maps + + Args: + emb: time conditional (ready to scale + shift) + cond: encoder's conditional (ready to scale + shift) + """ + two_cond = emb is not None and cond is not None + + if emb is not None: + # adjusting shapes + while len(emb.shape) < len(h.shape): + emb = emb[..., None] + + if two_cond: + # adjusting shapes + while len(cond.shape) < len(h.shape): + cond = cond[..., None] + # time first + scale_shifts = [emb, cond] + else: + # "cond" is not used with single cond mode + scale_shifts = [emb] + + # support scale, shift or shift only + for i, each in enumerate(scale_shifts): + if each is None: + # special case: the condition is not provided + a = None + b = None + else: + if each.shape[1] == in_channels * 2: + a, b = th.chunk(each, 2, dim=1) + else: + a = each + b = None + scale_shifts[i] = (a, b) + + # condition scale bias could be a list + if isinstance(scale_bias, Number): + biases = [scale_bias] * len(scale_shifts) + else: + # a list + biases = scale_bias + + # default, the scale & shift are applied after the group norm but BEFORE SiLU + pre_layers, post_layers = layers[0], layers[1:] + + # spilt the post layer to be able to scale up or down 
before conv + # post layers will contain only the conv + mid_layers, post_layers = post_layers[:-2], post_layers[-2:] + + h = pre_layers(h) + # scale and shift for each condition + for i, (scale, shift) in enumerate(scale_shifts): + # if scale is None, it indicates that the condition is not provided + if scale is not None: + h = h * (biases[i] + scale) + if shift is not None: + h = h + shift + h = mid_layers(h) + + # upscale or downscale if any just before the last conv + if up_down_layer is not None: + h = up_down_layer(h) + h = post_layers(h) + return h + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + def __init__(self, channels, use_conv, dims=2, out_channels=None): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, + self.channels, + self.out_channels, + 3, + padding=1) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), + mode="nearest") + else: + # if x.shape[2] == 4: + # feature = 9 + # x = F.interpolate(x, size=(feature), mode="nearest") + # if x.shape[2] == 8: + # feature = 9 + # x = F.interpolate(x, size=(feature), mode="nearest") + # else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + def __init__(self, channels, use_conv, dims=2, out_channels=None): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + self.stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd(dims, + self.channels, + self.out_channels, + 3, + stride=self.stride, + padding=1) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=self.stride, stride=self.stride) + + def forward(self, x): + assert x.shape[1] == self.channels + # if x.shape[2] % 2 != 0: + # op = avg_pool_nd(self.dims, kernel_size=3, stride=2) + # return op(x) + # if x.shape[2] % 2 != 0: + # op = avg_pool_nd(self.dims, kernel_size=2, stride=1) + # return op(x) + # else: + return self.op(x) + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
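+
+    Here the input is an [N x C x T] feature map (any extra spatial dims are
+    flattened into T) and self-attention is applied across the T positions.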
+ """ + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return self._forward(x) + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial**2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, + k * scale) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. 
+ """ + bs, width, length = qkv.shape + pdb.set_trace() + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, + v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter( + th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] diff --git a/model/graph_convolution_network.py b/model/graph_convolution_network.py new file mode 100644 index 0000000..a8d946b --- /dev/null +++ b/model/graph_convolution_network.py @@ -0,0 +1,173 @@ +import torch.nn as nn +import torch +from dataclasses import dataclass +from torch.nn.parameter import Parameter +from numbers import Number +import torch.nn.functional as F +from .blocks import * +import math + + +class graph_convolution(nn.Module): + def __init__(self, in_features, out_features, node_n = 3, seq_len = 80, bias=True): + super(graph_convolution, self).__init__() + + self.temporal_graph_weights = Parameter(torch.FloatTensor(seq_len, seq_len)) + self.feature_weights = Parameter(torch.FloatTensor(in_features, out_features)) + self.spatial_graph_weights = Parameter(torch.FloatTensor(node_n, node_n)) + + if bias: + self.bias = Parameter(torch.FloatTensor(seq_len)) + + self.reset_parameters() + + def reset_parameters(self): + stdv = 1. 
/ math.sqrt(self.spatial_graph_weights.size(1)) + self.feature_weights.data.uniform_(-stdv, stdv) + self.temporal_graph_weights.data.uniform_(-stdv, stdv) + self.spatial_graph_weights.data.uniform_(-stdv, stdv) + if self.bias is not None: + self.bias.data.uniform_(-stdv, stdv) + + def forward(self, x): + y = torch.matmul(x, self.temporal_graph_weights) + y = torch.matmul(y.permute(0, 3, 2, 1), self.feature_weights) + y = torch.matmul(self.spatial_graph_weights, y).permute(0, 3, 2, 1).contiguous() + + if self.bias is not None: + return (y + self.bias) + else: + return y + + +@dataclass +class residual_graph_convolution_config(): + in_features: int + seq_len: int + emb_channels: int + dropout: float + out_features: int = None + node_n: int = 3 + # condition the block with time (and encoder's output) + use_condition: bool = True + # whether to condition with both time & encoder's output + two_cond: bool = False + # number of encoders' output channels + cond_emb_channels: int = None + has_lateral: bool = False + graph_convolution_bias: bool = True + scale_bias: float = 1 + + def __post_init__(self): + self.out_features = self.out_features or self.in_features + self.cond_emb_channels = self.cond_emb_channels or self.emb_channels + + def make_model(self): + return residual_graph_convolution(self) + + +class residual_graph_convolution(TimestepBlock): + def __init__(self, conf: residual_graph_convolution_config): + super(residual_graph_convolution, self).__init__() + self.conf = conf + + self.gcn = graph_convolution(conf.in_features, conf.out_features, node_n=conf.node_n, seq_len=conf.seq_len, bias=conf.graph_convolution_bias) + self.ln = nn.LayerNorm([conf.out_features, conf.node_n, conf.seq_len]) + self.act_f = nn.Tanh() + self.dropout = nn.Dropout(conf.dropout) + + if conf.use_condition: + # condition layers for the out_layers + self.emb_layers = nn.Sequential( + nn.SiLU(), + nn.Linear(conf.emb_channels, conf.out_features), + ) + + if conf.two_cond: + self.cond_emb_layers = nn.Sequential( + nn.SiLU(), + nn.Linear(conf.cond_emb_channels, conf.out_features), + ) + + if conf.in_features == conf.out_features: + self.skip_connection = nn.Identity() + else: + self.skip_connection = nn.Sequential( + graph_convolution(conf.in_features, conf.out_features, node_n=conf.node_n, seq_len=conf.seq_len, bias=conf.graph_convolution_bias), + nn.Tanh(), + ) + + def forward(self, x, emb=None, cond=None, lateral=None): + if self.conf.has_lateral: + # lateral may be supplied even if it doesn't require + # the model will take the lateral only if "has_lateral" + assert lateral is not None + x = torch.cat((x, lateral), dim =1) + + y = self.gcn(x) + y = self.ln(y) + + if self.conf.use_condition: + if emb is not None: + emb = self.emb_layers(emb).type(x.dtype) + # adjusting shapes + while len(emb.shape) < len(y.shape): + emb = emb[..., None] + + if self.conf.two_cond or True: + if cond is not None: + if not isinstance(cond, torch.Tensor): + assert isinstance(cond, dict) + cond = cond['cond'] + cond = self.cond_emb_layers(cond).type(x.dtype) + while len(cond.shape) < len(y.shape): + cond = cond[..., None] + scales = [emb, cond] + else: + scales = [emb] + + # condition scale bias could be a list + if isinstance(self.conf.scale_bias, Number): + biases = [self.conf.scale_bias] * len(scales) + else: + # a list + biases = self.conf.scale_bias + + # scale for each condition + for i, scale in enumerate(scales): + # if scale is None, it indicates that the condition is not provided + if scale is not None: + y = y*(biases[i] + 
scale) + + y = self.act_f(y) + y = self.dropout(y) + return self.skip_connection(x) + y + + +class graph_downsample(nn.Module): + """ + A downsampling layer + """ + def __init__(self, kernel_size = 2): + super().__init__() + self.downsample = nn.AvgPool1d(kernel_size = kernel_size) + + def forward(self, x): + bs, features, node_n, seq_len = x.shape + x = x.reshape(bs, features*node_n, seq_len) + x = self.downsample(x) + x = x.reshape(bs, features, node_n, -1) + return x + + +class graph_upsample(nn.Module): + """ + An upsampling layer + """ + def __init__(self, scale_factor=2): + super().__init__() + self.scale_factor = scale_factor + + def forward(self, x): + x = F.interpolate(x, (x.shape[2], x.shape[3]*self.scale_factor), mode="nearest") + return x \ No newline at end of file diff --git a/model/nn.py b/model/nn.py new file mode 100644 index 0000000..f9470a8 --- /dev/null +++ b/model/nn.py @@ -0,0 +1,141 @@ +""" +Various utilities for neural networks. +""" + +from enum import Enum +import math, pdb +from typing import Optional + +import torch as th +import torch.nn as nn +import torch.utils.checkpoint + +import torch.nn.functional as F + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + # @th.jit.script + def forward(self, x): + return x * th.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + y = super().forward(x.float()).type(x.dtype) + return y + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + assert dims==1 + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + assert dims==1 + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def update_ema(target_params, source_params, rate=0.99): + """ + Update target parameters to be closer to those of source parameters using + an exponential moving average. + + :param target_params: the target parameter sequence. + :param source_params: the source parameter sequence. + :param rate: the EMA rate (closer to 1 means slower). + """ + for targ, src in zip(target_params, source_params): + targ.detach().mul_(rate).add_(src, alpha=1 - rate) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + # return GroupNorm32(channels, channels) + return GroupNorm32(min(4, channels), channels) + + +def timestep_embedding(timesteps, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. 
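+
+    For output dimension `dim`, the embedding of timestep t is
+        [cos(t * f_0), ..., cos(t * f_{dim//2 - 1}), sin(t * f_0), ..., sin(t * f_{dim//2 - 1})]
+    with frequencies f_k = max_period ** (-k / (dim // 2)).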
+ + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + half = dim // 2 + freqs = th.exp(-math.log(max_period) * + th.arange(start=0, end=half, dtype=th.float32) / + half).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) + if dim % 2: + embedding = th.cat( + [embedding, th.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +def torch_checkpoint(func, args, flag, preserve_rng_state=False): + # torch's gradient checkpoint works with automatic mixed precision, given torch >= 1.8 + if flag: + return torch.utils.checkpoint.checkpoint( + func, *args, preserve_rng_state=preserve_rng_state) + else: + return func(*args) diff --git a/model/unet.py b/model/unet.py new file mode 100644 index 0000000..ec27f33 --- /dev/null +++ b/model/unet.py @@ -0,0 +1,954 @@ +import math +from dataclasses import dataclass +from numbers import Number +from typing import NamedTuple, Tuple, Union + +import numpy as np +import torch as th +from torch import nn +import torch.nn.functional as F +from choices import * +from config_base import BaseConfig +from .blocks import * +from .graph_convolution_network import * +from .nn import (conv_nd, linear, normalization, timestep_embedding, zero_module) + + +@dataclass +class BeatGANsUNetConfig(BaseConfig): + seq_len: int = 80 + in_channels: int = 9 + # base channels, will be multiplied + model_channels: int = 64 + # output of the unet + out_channels: int = 9 + # how many repeating resblocks per resolution + # the decoding side would have "one more" resblock + # default: 2 + num_res_blocks: int = 2 + # number of time embed channels and style channels + embed_channels: int = 256 + # at what resolutions you want to do self-attention of the feature maps + # attentions generally improve performance + attention_resolutions: Tuple[int] = (0, ) + # dropout applies to the resblocks (on feature maps) + dropout: float = 0.1 + channel_mult: Tuple[int] = (1, 2, 4) + conv_resample: bool = True + # 1 = 1d conv + dims: int = 1 + # number of attention heads + num_heads: int = 1 + # or specify the number of channels per attention head + num_head_channels: int = -1 + # use resblock for upscale/downscale blocks (expensive) + # default: True (BeatGANs) + resblock_updown: bool = True + use_new_attention_order: bool = False + resnet_two_cond: bool = True + resnet_cond_channels: int = None + # init the decoding conv layers with zero weights, this speeds up training + # default: True (BeatGANs) + resnet_use_zero_module: bool = True + + def make_model(self): + return BeatGANsUNetModel(self) + + +class BeatGANsUNetModel(nn.Module): + def __init__(self, conf: BeatGANsUNetConfig): + super().__init__() + self.conf = conf + + self.dtype = th.float32 + + self.time_emb_channels = conf.model_channels + self.time_embed = nn.Sequential( + linear(self.time_emb_channels, conf.embed_channels), + nn.SiLU(), + linear(conf.embed_channels, conf.embed_channels), + ) + + ch = input_ch = int(conf.channel_mult[0] * conf.model_channels) + self.input_blocks = nn.ModuleList([ + TimestepEmbedSequential( + conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1)), + ]) + + kwargs = dict( + use_condition=True, + two_cond=conf.resnet_two_cond, + use_zero_module=conf.resnet_use_zero_module, + # style channels for the 
resnet block + cond_emb_channels=conf.resnet_cond_channels, + ) + + self._feature_size = ch + + # input_block_chans = [ch] + input_block_chans = [[] for _ in range(len(conf.channel_mult))] + input_block_chans[0].append(ch) + + # number of blocks at each resolution + self.input_num_blocks = [0 for _ in range(len(conf.channel_mult))] + self.input_num_blocks[0] = 1 + self.output_num_blocks = [0 for _ in range(len(conf.channel_mult))] + + ds = 1 + resolution = conf.seq_len + for level, mult in enumerate(conf.channel_mult): + for _ in range(conf.num_res_blocks): + layers = [ + ResBlockConfig( + ch, + conf.embed_channels, + conf.dropout, + out_channels=int(mult * conf.model_channels), + dims=conf.dims, + **kwargs, + ).make_model() + ] + ch = int(mult * conf.model_channels) + if resolution in conf.attention_resolutions: + layers.append( + AttentionBlock( + ch, + num_heads=conf.num_heads, + num_head_channels=conf.num_head_channels, + use_new_attention_order=conf. + use_new_attention_order, + )) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + # input_block_chans.append(ch) + input_block_chans[level].append(ch) + self.input_num_blocks[level] += 1 + # print(input_block_chans) + if level != len(conf.channel_mult) - 1: + resolution //= 2 + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlockConfig( + ch, + conf.embed_channels, + conf.dropout, + out_channels=out_ch, + dims=conf.dims, + down=True, + **kwargs, + ).make_model() if conf. + resblock_updown else Downsample(ch, + conf.conv_resample, + dims=conf.dims, + out_channels=out_ch))) + ch = out_ch + # input_block_chans.append(ch) + input_block_chans[level + 1].append(ch) + self.input_num_blocks[level + 1] += 1 + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlockConfig( + ch, + conf.embed_channels, + conf.dropout, + dims=conf.dims, + **kwargs, + ).make_model(), + #AttentionBlock( + # ch, + # num_heads=conf.num_heads, + # num_head_channels=conf.num_head_channels, + # use_new_attention_order=conf.use_new_attention_order, + #), + ResBlockConfig( + ch, + conf.embed_channels, + conf.dropout, + dims=conf.dims, + **kwargs, + ).make_model(), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(conf.channel_mult))[::-1]: + for i in range(conf.num_res_blocks + 1): + # print(input_block_chans) + # ich = input_block_chans.pop() + try: + ich = input_block_chans[level].pop() + except IndexError: + # this happens only when num_res_block > num_enc_res_block + # we will not have enough lateral (skip) connecions for all decoder blocks + ich = 0 + # print('pop:', ich) + layers = [ + ResBlockConfig( + # only direct channels when gated + channels=ch + ich, + emb_channels=conf.embed_channels, + dropout=conf.dropout, + out_channels=int(conf.model_channels * mult), + dims=conf.dims, + # lateral channels are described here when gated + has_lateral=True if ich > 0 else False, + lateral_channels=None, + **kwargs, + ).make_model() + ] + ch = int(conf.model_channels * mult) + if resolution in conf.attention_resolutions: + layers.append( + AttentionBlock( + ch, + num_heads=conf.num_heads, + num_head_channels=conf.num_head_channels, + use_new_attention_order=conf. 
+ use_new_attention_order, + )) + if level and i == conf.num_res_blocks: + resolution *= 2 + out_ch = ch + layers.append( + ResBlockConfig( + ch, + conf.embed_channels, + conf.dropout, + out_channels=out_ch, + dims=conf.dims, + up=True, + **kwargs, + ).make_model() if ( + conf.resblock_updown + ) else Upsample(ch, + conf.conv_resample, + dims=conf.dims, + out_channels=out_ch)) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self.output_num_blocks[level] += 1 + self._feature_size += ch + + # print(input_block_chans) + # print('inputs:', self.input_num_blocks) + # print('outputs:', self.output_num_blocks) + + if conf.resnet_use_zero_module: + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module( + conv_nd(conf.dims, + input_ch, + conf.out_channels, + 3, ## kernel size + padding=1)), + ) + else: + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + conv_nd(conf.dims, input_ch, conf.out_channels, 3, padding=1), + ) + + def forward(self, x, t, **kwargs): + """ + Apply the model to an input batch. + + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :return: an [N x C x ...] Tensor of outputs. + """ + + # hs = [] + hs = [[] for _ in range(len(self.conf.channel_mult))] + emb = self.time_embed(timestep_embedding(t, self.time_emb_channels)) + + # new code supports input_num_blocks != output_num_blocks + h = x.type(self.dtype) + k = 0 + for i in range(len(self.input_num_blocks)): + for j in range(self.input_num_blocks[i]): + h = self.input_blocks[k](h, emb=emb) + # print(i, j, h.shape) + hs[i].append(h) ## Get output from each layer + k += 1 + assert k == len(self.input_blocks) + + # middle blocks + h = self.middle_block(h, emb=emb) + + # output blocks + k = 0 + for i in range(len(self.output_num_blocks)): + for j in range(self.output_num_blocks[i]): + # take the lateral connection from the same layer (in reserve) + # until there is no more, use None + try: + lateral = hs[-i - 1].pop() + # print(i, j, lateral.shape) + except IndexError: + lateral = None + # print(i, j, lateral) + h = self.output_blocks[k](h, emb=emb, lateral=lateral) + k += 1 + + h = h.type(x.dtype) + pred = self.out(h) + return Return(pred=pred) + + +class Return(NamedTuple): + pred: th.Tensor + + +@dataclass +class BeatGANsEncoderConfig(BaseConfig): + in_channels: int + seq_len: int = 80 + num_res_blocks: int = 2 + attention_resolutions: Tuple[int] = (0, ) + model_channels: int = 32 + out_channels: int = 256 + dropout: float = 0.1 + channel_mult: Tuple[int] = (1, 2, 4) + use_time_condition: bool = False + conv_resample: bool = True + dims: int = 1 + num_heads: int = 1 + num_head_channels: int = -1 + resblock_updown: bool = True + use_new_attention_order: bool = False + pool: str = 'adaptivenonzero' + + def make_model(self): + return BeatGANsEncoderModel(self) + + +class BeatGANsEncoderModel(nn.Module): + """ + The half UNet model with attention and timestep embedding. + + For usage, see UNet. 
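+
+    With the default 'adaptivenonzero' pooling, forward() returns a flat
+    [N x out_channels] semantic code: features are average-pooled over time,
+    projected by a 1x1 conv, and flattened.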
+ """ + def __init__(self, conf: BeatGANsEncoderConfig): + super().__init__() + self.conf = conf + self.dtype = th.float32 + + if conf.use_time_condition: + time_embed_dim = conf.model_channels + self.time_embed = nn.Sequential( + linear(conf.model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + else: + time_embed_dim = None + + ch = int(conf.channel_mult[0] * conf.model_channels) + self.input_blocks = nn.ModuleList([ + TimestepEmbedSequential( + conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1), + ) + ]) + self._feature_size = ch + input_block_chans = [ch] + ds = 1 + resolution = conf.seq_len + for level, mult in enumerate(conf.channel_mult): + for _ in range(conf.num_res_blocks): + layers = [ + ResBlockConfig( + ch, + time_embed_dim, + conf.dropout, + out_channels=int(mult * conf.model_channels), + dims=conf.dims, + use_condition=conf.use_time_condition, + ).make_model() + ] + ch = int(mult * conf.model_channels) + if resolution in conf.attention_resolutions: + layers.append( + AttentionBlock( + ch, + num_heads=conf.num_heads, + num_head_channels=conf.num_head_channels, + use_new_attention_order=conf. + use_new_attention_order, + )) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(conf.channel_mult) - 1: + resolution //= 2 + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlockConfig( + ch, + time_embed_dim, + conf.dropout, + out_channels=out_ch, + dims=conf.dims, + use_condition=conf.use_time_condition, + down=True, + ).make_model() if ( + conf.resblock_updown + ) else Downsample(ch, + conf.conv_resample, + dims=conf.dims, + out_channels=out_ch))) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlockConfig( + ch, + time_embed_dim, + conf.dropout, + dims=conf.dims, + use_condition=conf.use_time_condition, + ).make_model(), + AttentionBlock( + ch, + num_heads=conf.num_heads, + num_head_channels=conf.num_head_channels, + use_new_attention_order=conf.use_new_attention_order, + ), + ResBlockConfig( + ch, + time_embed_dim, + conf.dropout, + dims=conf.dims, + use_condition=conf.use_time_condition, + ).make_model(), + ) + self._feature_size += ch + if conf.pool == "adaptivenonzero": + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + ## nn.AdaptiveAvgPool2d((1, 1)), + nn.AdaptiveAvgPool1d(1), + conv_nd(conf.dims, ch, conf.out_channels, 1), + nn.Flatten(), + ) + else: + raise NotImplementedError(f"Unexpected {conf.pool} pooling") + + def forward(self, x, t=None): + """ + Apply the model to an input batch. + + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :return: an [N x K] Tensor of outputs. + """ + if self.conf.use_time_condition: + emb = self.time_embed(timestep_embedding(t, self.model_channels)) + else: ## autoencoding.py + emb = None + + results = [] + h = x.type(self.dtype) + for module in self.input_blocks: ## flow input x over all the input blocks + h = module(h, emb=emb) + if self.conf.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = self.middle_block(h, emb=emb) ## TimestepEmbedSequential(...) 
+ if self.conf.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = th.cat(results, axis=-1) + else: ## autoencoder.py + h = h.type(x.dtype) + + h = h.float() + h = self.out(h) + + return h + + +@dataclass +class GCNUNetConfig(BaseConfig): + in_channels: int = 9 + node_n: int = 3 + seq_len: int = 80 + # base channels, will be multiplied + model_channels: int = 32 + # output of the unet + out_channels: int = 9 + # how many repeating resblocks per resolution + num_res_blocks: int = 8 + # number of time embed channels and style channels + embed_channels: int = 256 + # dropout applies to the resblocks + dropout: float = 0.1 + channel_mult: Tuple[int] = (1, 2, 4) + resnet_two_cond: bool = True + + def make_model(self): + return GCNUNetModel(self) + + +class GCNUNetModel(nn.Module): + def __init__(self, conf: GCNUNetConfig): + super().__init__() + self.conf = conf + self.dtype = th.float32 + assert conf.in_channels%conf.node_n == 0 + self.in_features = conf.in_channels//conf.node_n + + self.time_emb_channels = conf.model_channels*4 + self.time_embed = nn.Sequential( + linear(self.time_emb_channels, conf.embed_channels), + nn.SiLU(), + linear(conf.embed_channels, conf.embed_channels), + ) + + ch = int(conf.channel_mult[0] * conf.model_channels) + self.input_blocks = nn.ModuleList([ + TimestepEmbedSequential( + graph_convolution(in_features=self.in_features, out_features=ch, node_n=conf.node_n, seq_len=conf.seq_len)), + ]) + + kwargs = dict( + use_condition=True, + two_cond=conf.resnet_two_cond, + ) + + input_block_chans = [[] for _ in range(len(conf.channel_mult))] + input_block_chans[0].append(ch) + + # number of blocks at each resolution + self.input_num_blocks = [0 for _ in range(len(conf.channel_mult))] + self.input_num_blocks[0] = 1 + self.output_num_blocks = [0 for _ in range(len(conf.channel_mult))] + + ds = 1 + resolution = conf.seq_len + for level, mult in enumerate(conf.channel_mult): + for _ in range(conf.num_res_blocks): + layers = [ + residual_graph_convolution_config( + in_features=ch, + seq_len=resolution, + emb_channels = conf.embed_channels, + dropout=conf.dropout, + out_features=int(mult * conf.model_channels), + node_n=conf.node_n, + **kwargs, + ).make_model() + ] + ch = int(mult * conf.model_channels) + self.input_blocks.append(*layers) + input_block_chans[level].append(ch) + self.input_num_blocks[level] += 1 + if level != len(conf.channel_mult) - 1: + resolution //= 2 + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + graph_downsample())) + ch = out_ch + input_block_chans[level + 1].append(ch) + self.input_num_blocks[level + 1] += 1 + ds *= 2 + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(conf.channel_mult))[::-1]: + for i in range(conf.num_res_blocks + 1): + try: + ich = input_block_chans[level].pop() + except IndexError: + # this happens only when num_res_block > num_enc_res_block + # we will not have enough lateral (skip) connecions for all decoder blocks + ich = 0 + layers = [ + residual_graph_convolution_config( + in_features=ch + ich, + seq_len=resolution, + emb_channels = conf.embed_channels, + dropout=conf.dropout, + out_features=int(mult * conf.model_channels), + node_n=conf.node_n, + has_lateral=True if ich > 0 else False, + **kwargs, + ).make_model() + ] + ch = int(mult*conf.model_channels) + if level and i == conf.num_res_blocks: + resolution *= 2 + out_ch = ch + layers.append(graph_upsample()) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + 
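+                # record one more decoder block at this resolution level so that
+                # forward() can pop the matching encoder activation as its lateral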
self.output_num_blocks[level] += 1 + + self.out = nn.Sequential( + graph_convolution(in_features=ch, out_features=self.in_features, node_n=conf.node_n, seq_len=conf.seq_len), + nn.Tanh(), + ) + + def forward(self, x, t, **kwargs): + """ + Apply the model to an input batch. + + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :return: an [N x C x ...] Tensor of outputs. + """ + bs, channels, seq_len = x.shape + x = x.reshape(bs, self.conf.node_n, self.in_features, seq_len).permute(0, 2, 1, 3) + + hs = [[] for _ in range(len(self.conf.channel_mult))] + emb = self.time_embed(timestep_embedding(t, self.time_emb_channels)) + + # new code supports input_num_blocks != output_num_blocks + h = x.type(self.dtype) + k = 0 + for i in range(len(self.input_num_blocks)): + for j in range(self.input_num_blocks[i]): + h = self.input_blocks[k](h, emb=emb) + # print(i, j, h.shape) + hs[i].append(h) ## Get output from each layer + k += 1 + assert k == len(self.input_blocks) + + # output blocks + k = 0 + for i in range(len(self.output_num_blocks)): + for j in range(self.output_num_blocks[i]): + # take the lateral connection from the same layer (in reserve) + # until there is no more, use None + try: + lateral = hs[-i - 1].pop() + # print(i, j, lateral.shape) + except IndexError: + lateral = None + # print(i, j, lateral) + h = self.output_blocks[k](h, emb=emb, lateral=lateral) + k += 1 + + h = h.type(x.dtype) + pred = self.out(h) + pred = pred.permute(0, 2, 1, 3).reshape(bs, -1, seq_len) + + return Return(pred=pred) + + +@dataclass +class GCNEncoderConfig(BaseConfig): + in_channels: int + in_features = 3 # features for one node + seq_len: int = 40 + seq_len_future: int = 3 + num_res_blocks: int = 2 + model_channels: int = 32 + out_channels: int = 32 + dropout: float = 0.1 + channel_mult: Tuple[int] = (1, 2, 4) + use_time_condition: bool = False + + def make_model(self): + return GCNEncoderModel(self) + + +class GCNEncoderModel(nn.Module): + def __init__(self, conf: GCNEncoderConfig): + super().__init__() + self.conf = conf + self.dtype = th.float32 + assert conf.in_channels%conf.in_features == 0 + self.in_features = conf.in_features + self.node_n = conf.in_channels//conf.in_features + + ch = int(conf.channel_mult[0] * conf.model_channels) + self.input_blocks = nn.ModuleList([ + graph_convolution(in_features=self.in_features, out_features=ch, node_n=self.node_n, seq_len=conf.seq_len), + ]) + input_block_chans = [ch] + ds = 1 + resolution = conf.seq_len + for level, mult in enumerate(conf.channel_mult): + for _ in range(conf.num_res_blocks): + layers = [ + residual_graph_convolution_config( + in_features=ch, + seq_len=resolution, + emb_channels = None, + dropout=conf.dropout, + out_features=int(mult * conf.model_channels), + node_n=self.node_n, + use_condition=conf.use_time_condition, + ).make_model() + ] + ch = int(mult * conf.model_channels) + self.input_blocks.append(*layers) + input_block_chans.append(ch) + if level != len(conf.channel_mult) - 1: + resolution //= 2 + out_ch = ch + self.input_blocks.append( + graph_downsample()) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + + self.hand_prediction = nn.Sequential( + conv_nd(1, ch*2, ch*2, 3, padding=1), + nn.LayerNorm([ch*2, conf.seq_len_future]), + nn.Tanh(), + conv_nd(1, ch*2, self.in_features*2, 1), + nn.Tanh(), + ) + + self.head_prediction = nn.Sequential( + conv_nd(1, ch, ch, 3, padding=1), + nn.LayerNorm([ch, conf.seq_len_future]), + nn.Tanh(), + conv_nd(1, ch, self.in_features, 1), + nn.Tanh(), + ) + 
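+        # hand_prediction / head_prediction are auxiliary heads that forecast the
+        # next seq_len_future frames as offsets from the last observed frame;
+        # self.out pools the GCN features over time into the semantic latent code.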
+ self.out = nn.Sequential( + nn.AdaptiveAvgPool1d(1), + conv_nd(1, ch*self.node_n, conf.out_channels, 1), + nn.Flatten(), + ) + + + def forward(self, x, t=None): + bs, channels, seq_len = x.shape + + if self.node_n == 3: # both hand and head + hand_last = x[:, :6, -1:].expand(-1, -1, self.conf.seq_len_future).clone() #last observed hand position + head_last = x[:, 6:, -1:].expand(-1, -1, self.conf.seq_len_future).clone()# last observed head orientation + + if self.node_n == 2: # hand only + hand_last = x[:, :, -1:].expand(-1, -1, self.conf.seq_len_future).clone() #last observed hand position + + if self.node_n == 1: # head only + head_last = x[:, :, -1:].expand(-1, -1, self.conf.seq_len_future).clone()# last observed head orientation + + + x = x.reshape(bs, self.node_n, self.in_features, seq_len).permute(0, 2, 1, 3) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h) + + h = h.type(x.dtype) + h = h.float() + bs, features, node_n, seq_len = h.shape + + if self.node_n == 3: # both hand and head + hand_features = h[:, :, :2, -self.conf.seq_len_future:].reshape(bs, features*2, -1) + head_features = h[:, :, 2:, -self.conf.seq_len_future:].reshape(bs, features, -1) + + pred_hand = self.hand_prediction(hand_features) + hand_last + pred_head = self.head_prediction(head_features) + head_last + pred_head = F.normalize(pred_head, dim=1)# normalize head orientation to unit vectors + + if self.node_n == 2: # hand only + hand_features = h[:, :, :, -self.conf.seq_len_future:].reshape(bs, features*2, -1) + pred_hand = self.hand_prediction(hand_features) + hand_last + pred_head = None + + if self.node_n == 1: # head only + head_features = h[:, :, :, -self.conf.seq_len_future:].reshape(bs, features, -1) + pred_head = self.head_prediction(head_features) + head_last + pred_head = F.normalize(pred_head, dim=1)# normalize head orientation to unit vectors + pred_hand = None + + h = h.reshape(bs, features*node_n, seq_len) + h = self.out(h) + + return h, pred_hand, pred_head + + +@dataclass +class CNNEncoderConfig(BaseConfig): + in_channels: int + seq_len: int = 40 + seq_len_future: int = 3 + out_channels: int = 128 + + def make_model(self): + return CNNEncoderModel(self) + + +class CNNEncoderModel(nn.Module): + def __init__(self, conf: CNNEncoderConfig): + super().__init__() + self.conf = conf + self.dtype = th.float32 + input_dim = conf.in_channels + length = conf.seq_len + out_channels = conf.out_channels + + self.encoder = nn.Sequential( + nn.Conv1d(input_dim, 32, kernel_size=3, padding=1), + nn.LayerNorm([32, length]), + nn.ReLU(inplace=True), + nn.Conv1d(32, 32, kernel_size=3, padding=1), + nn.LayerNorm([32, length]), + nn.ReLU(inplace=True), + nn.Conv1d(32, 32, kernel_size=3, padding=1), + nn.LayerNorm([32, length]), + nn.ReLU(inplace=True) + ) + + self.out = nn.Linear(32 * length, out_channels) + + def forward(self, x, t=None): + bs, channels, seq_len = x.shape + hand_last = x[:, :6, -1:].expand(-1, -1, self.conf.seq_len_future).clone() #last observed hand position + head_last = x[:, 6:, -1:].expand(-1, -1, self.conf.seq_len_future).clone()# last observed head orientation + + h = x.type(self.dtype) + h = self.encoder(h) + h = h.view(h.shape[0], -1) + + h = h.type(x.dtype) + h = h.float() + + h = self.out(h) + return h, hand_last, head_last + + +@dataclass +class GRUEncoderConfig(BaseConfig): + in_channels: int + seq_len: int = 40 + seq_len_future: int = 3 + out_channels: int = 128 + + def make_model(self): + return GRUEncoderModel(self) + + +class GRUEncoderModel(nn.Module): + 
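+    # Baseline encoder: a single-layer GRU over the input sequence; the hidden
+    # states are flattened over time and linearly projected to the latent size.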
def __init__(self, conf: GRUEncoderConfig): + super().__init__() + self.conf = conf + self.dtype = th.float32 + input_dim = conf.in_channels + length = conf.seq_len + feature_channels = 32 + out_channels = conf.out_channels + + self.encoder = nn.GRU(input_dim, feature_channels, 1, batch_first=True) + + self.out = nn.Linear(feature_channels * length, out_channels) + + def forward(self, x, t=None): + bs, channels, seq_len = x.shape + hand_last = x[:, :6, -1:].expand(-1, -1, self.conf.seq_len_future).clone() #last observed hand position + head_last = x[:, 6:, -1:].expand(-1, -1, self.conf.seq_len_future).clone()# last observed head orientation + + h = x.type(self.dtype) + h, _ = self.encoder(h.permute(0, 2, 1)) + h = h.reshape(h.shape[0], -1) + + h = h.type(x.dtype) + h = h.float() + + h = self.out(h) + return h, hand_last, head_last + + +@dataclass +class LSTMEncoderConfig(BaseConfig): + in_channels: int + seq_len: int = 40 + seq_len_future: int = 3 + out_channels: int = 128 + + def make_model(self): + return LSTMEncoderModel(self) + + +class LSTMEncoderModel(nn.Module): + def __init__(self, conf: LSTMEncoderConfig): + super().__init__() + self.conf = conf + self.dtype = th.float32 + input_dim = conf.in_channels + length = conf.seq_len + feature_channels = 32 + out_channels = conf.out_channels + + self.encoder = nn.LSTM(input_dim, feature_channels, 1, batch_first=True) + + self.out = nn.Linear(feature_channels * length, out_channels) + + def forward(self, x, t=None): + bs, channels, seq_len = x.shape + hand_last = x[:, :6, -1:].expand(-1, -1, self.conf.seq_len_future).clone() #last observed hand position + head_last = x[:, 6:, -1:].expand(-1, -1, self.conf.seq_len_future).clone()# last observed head orientation + + h = x.type(self.dtype) + h, _ = self.encoder(h.permute(0, 2, 1)) + h = h.reshape(h.shape[0], -1) + + h = h.type(x.dtype) + h = h.float() + + h = self.out(h) + return h, hand_last, head_last + + +@dataclass +class MLPEncoderConfig(BaseConfig): + in_channels: int + seq_len: int = 40 + seq_len_future: int = 3 + out_channels: int = 128 + + def make_model(self): + return MLPEncoderModel(self) + + +class MLPEncoderModel(nn.Module): + def __init__(self, conf: MLPEncoderConfig): + super().__init__() + self.conf = conf + self.dtype = th.float32 + input_dim = conf.in_channels + length = conf.seq_len + out_channels = conf.out_channels + + linear_size = 128 + self.encoder = nn.Sequential( + nn.Linear(length*input_dim, linear_size), + nn.LayerNorm([linear_size]), + nn.ReLU(inplace=True), + nn.Linear(linear_size, linear_size), + nn.LayerNorm([linear_size]), + nn.ReLU(inplace=True), + ) + + self.out = nn.Linear(linear_size, out_channels) + + def forward(self, x, t=None): + bs, channels, seq_len = x.shape + hand_last = x[:, :6, -1:].expand(-1, -1, self.conf.seq_len_future).clone() #last observed hand position + head_last = x[:, 6:, -1:].expand(-1, -1, self.conf.seq_len_future).clone()# last observed head orientation + + h = x.type(self.dtype) + + h = h.view(h.shape[0], -1) + h = self.encoder(h) + + h = h.type(x.dtype) + h = h.float() + + h = self.out(h) + return h, hand_last, head_last \ No newline at end of file diff --git a/model/unet_autoenc.py b/model/unet_autoenc.py new file mode 100644 index 0000000..4f31d05 --- /dev/null +++ b/model/unet_autoenc.py @@ -0,0 +1,418 @@ +from enum import Enum + +import torch, pdb +import os +from torch import Tensor +from torch.nn.functional import silu +from .unet import * +from choices import * + + +@dataclass +class 
BeatGANsAutoencConfig(BeatGANsUNetConfig): + seq_len_future: int = 3 + enc_out_channels: int = 128 + semantic_encoder_type: str = 'gcn' + enc_channel_mult: Tuple[int] = None + def make_model(self): + return BeatGANsAutoencModel(self) + +class BeatGANsAutoencModel(BeatGANsUNetModel): + def __init__(self, conf: BeatGANsAutoencConfig): + super().__init__(conf) + self.conf = conf + + # having only time, cond + self.time_embed = TimeStyleSeperateEmbed( + time_channels=conf.model_channels, + time_out_channels=conf.embed_channels, + ) + + if conf.semantic_encoder_type == 'gcn': + self.encoder = GCNEncoderConfig( + seq_len=conf.seq_len, + seq_len_future=conf.seq_len_future, + in_channels=conf.in_channels, + model_channels=16, + out_channels=conf.enc_out_channels, + channel_mult=conf.enc_channel_mult or conf.channel_mult, + ).make_model() + elif conf.semantic_encoder_type == '1dcnn': + self.encoder = CNNEncoderConfig( + seq_len=conf.seq_len, + seq_len_future=conf.seq_len_future, + in_channels=conf.in_channels, + out_channels=conf.enc_out_channels, + ).make_model() + elif conf.semantic_encoder_type == 'gru': + # ensure deterministic behavior of RNNs + os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:2" + self.encoder = GRUEncoderConfig( + seq_len=conf.seq_len, + seq_len_future=conf.seq_len_future, + in_channels=conf.in_channels, + out_channels=conf.enc_out_channels, + ).make_model() + elif conf.semantic_encoder_type == 'lstm': + # ensure deterministic behavior of RNNs + os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:2" + self.encoder = LSTMEncoderConfig( + seq_len=conf.seq_len, + seq_len_future=conf.seq_len_future, + in_channels=conf.in_channels, + out_channels=conf.enc_out_channels, + ).make_model() + elif conf.semantic_encoder_type == 'mlp': + self.encoder = MLPEncoderConfig( + seq_len=conf.seq_len, + seq_len_future=conf.seq_len_future, + in_channels=conf.in_channels, + out_channels=conf.enc_out_channels, + ).make_model() + else: + raise NotImplementedError() + + def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor: + """ + Reparameterization trick to sample from N(mu, var) from + N(0,1). 
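+        Computes z = eps * std + mu with std = exp(0.5 * logvar) and eps ~ N(0, I).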
+ :param mu: (Tensor) Mean of the latent Gaussian [B x D] + :param logvar: (Tensor) Standard deviation of the latent Gaussian [B x D] + :return: (Tensor) [B x D] + """ + assert self.conf.is_stochastic + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + return eps * std + mu + + def sample_z(self, n: int, device): + assert self.conf.is_stochastic + return torch.randn(n, self.conf.enc_out_channels, device=device) + + def noise_to_cond(self, noise: Tensor): + raise NotImplementedError() + assert self.conf.noise_net_conf is not None + return self.noise_net.forward(noise) + + def encode(self, x): + cond, pred_hand, pred_head = self.encoder.forward(x) + return cond, pred_hand, pred_head + + @property + def stylespace_sizes(self): + modules = list(self.input_blocks.modules()) + list( + self.middle_block.modules()) + list(self.output_blocks.modules()) + sizes = [] + for module in modules: + if isinstance(module, ResBlock): + linear = module.cond_emb_layers[-1] + sizes.append(linear.weight.shape[0]) + return sizes + + def encode_stylespace(self, x, return_vector: bool = True): + """ + encode to style space + """ + modules = list(self.input_blocks.modules()) + list( + self.middle_block.modules()) + list(self.output_blocks.modules()) + # (n, c) + cond = self.encoder.forward(x) + S = [] + for module in modules: + if isinstance(module, ResBlock): + # (n, c') + s = module.cond_emb_layers.forward(cond) + S.append(s) + + if return_vector: + # (n, sum_c) + return torch.cat(S, dim=1) + else: + return S + + def forward(self, + x, + t, + x_start=None, + cond=None, + style=None, + noise=None, + t_cond=None, + **kwargs): + """ + Apply the model to an input batch. + + Args: + x_start: the original image to encode + cond: output of the encoder + noise: random noise (to predict the cond) + """ + if t_cond is None: + t_cond = t ## randomly sampled timestep with the size of [batch_size] + + if noise is not None: + # if the noise is given, we predict the cond from noise + cond = self.noise_to_cond(noise) + + cond_given = True + if cond is None: + cond_given = False + if x is not None: + assert len(x) == len(x_start), f'{len(x)} != {len(x_start)}' + + cond, pred_hand, pred_head = self.encode(x_start) + + if t is not None: ## t==t_cond + _t_emb = timestep_embedding(t, self.conf.model_channels) + #print("t: {}, _t_emb:{}".format(t, _t_emb)) + _t_cond_emb = timestep_embedding(t_cond, self.conf.model_channels) + #print("t_cond: {}, _t_cond_emb:{}".format(t, _t_cond_emb)) + else: + # this happens when training only autoenc + _t_emb = None + _t_cond_emb = None + + if self.conf.resnet_two_cond: + res = self.time_embed.forward( ## self.time_embed is an MLP + time_emb=_t_emb, + cond=cond, + time_cond_emb=_t_cond_emb, + ) + else: + raise NotImplementedError() + + if self.conf.resnet_two_cond: + # two cond: first = time emb, second = cond_emb + emb = res.time_emb + cond_emb = res.emb + else: + # one cond = combined of both time and cond + emb = res.emb + cond_emb = None + + # override the style if given + style = style or res.style ## style==None, res.style: cond, torch.Size([64, 512]) + + + # where in the model to supply time conditions + enc_time_emb = emb ## time embeddings + mid_time_emb = emb + dec_time_emb = emb + # where in the model to supply style conditions + enc_cond_emb = cond_emb ## z_sem embeddings + mid_cond_emb = cond_emb + dec_cond_emb = cond_emb + + # hs = [] + hs = [[] for _ in range(len(self.conf.channel_mult))] + + if x is not None: + h = x.type(self.dtype) + # input blocks + k = 0 + for i in 
range(len(self.input_num_blocks)): + for j in range(self.input_num_blocks[i]): + h = self.input_blocks[k](h, + emb=enc_time_emb, + cond=enc_cond_emb) + # print(i, j, h.shape) + '''if h.shape[-1]%2==1: + pdb.set_trace()''' + hs[i].append(h) + k += 1 + assert k == len(self.input_blocks) + + # middle blocks + h = self.middle_block(h, emb=mid_time_emb, cond=mid_cond_emb) + else: + # no lateral connections + # happens when training only the autonecoder + h = None + hs = [[] for _ in range(len(self.conf.channel_mult))] + pdb.set_trace() + + # output blocks + k = 0 + for i in range(len(self.output_num_blocks)): + for j in range(self.output_num_blocks[i]): + # take the lateral connection from the same layer (in reserve) + # until there is no more, use None + try: + lateral = hs[-i - 1].pop() ## in the reverse order (symmetric) + except IndexError: + lateral = None + '''print(i, j, lateral.shape, h.shape) + if lateral.shape[-1]!=h.shape[-1]: + pdb.set_trace()''' + # print("h is", h.size()) + # print("lateral is", lateral.size()) + h = self.output_blocks[k](h, + emb=dec_time_emb, + cond=dec_cond_emb, + lateral=lateral) + k += 1 + + pred = self.out(h) + # print("h:", h.shape) + # print("pred:", pred.shape) + + if cond_given == True: + return AutoencReturn(pred=pred, cond=cond) + else: + return AutoencReturn(pred=pred, cond=cond, pred_hand=pred_hand, pred_head=pred_head) + + +@dataclass +class GCNAutoencConfig(GCNUNetConfig): + # number of style channels + enc_out_channels: int = 256 + enc_channel_mult: Tuple[int] = None + def make_model(self): + return GCNAutoencModel(self) + + +class GCNAutoencModel(GCNUNetModel): + def __init__(self, conf: GCNAutoencConfig): + super().__init__(conf) + self.conf = conf + + # having only time, cond + self.time_emb_channels = conf.model_channels + self.time_embed = TimeStyleSeperateEmbed( + time_channels=self.time_emb_channels, + time_out_channels=conf.embed_channels, + ) + + self.encoder = GCNEncoderConfig( + seq_len=conf.seq_len, + in_channels=conf.in_channels, + model_channels=32, + out_channels=conf.enc_out_channels, + channel_mult=conf.enc_channel_mult or conf.channel_mult, + ).make_model() + + def encode(self, x): + cond = self.encoder.forward(x) + return {'cond': cond} + + def forward(self, + x, + t, + x_start=None, + cond=None, + **kwargs): + """ + Apply the model to an input batch. 
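+        The UNet denoises x at timestep t while being conditioned on the semantic
+        code `cond`; if `cond` is not given, it is computed from x_start by the
+        GCN encoder.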
+ + Args: + x_start: the original input to encode + cond: output of the encoder + """ + + if cond is None: + if x is not None: + assert len(x) == len(x_start), f'{len(x)} != {len(x_start)}' + + tmp = self.encode(x_start) + cond = tmp['cond'] + + if t is not None: + _t_emb = timestep_embedding(t, self.time_emb_channels) + else: + # this happens when training only autoenc + _t_emb = None + + if self.conf.resnet_two_cond: + res = self.time_embed.forward( ## self.time_embed is an MLP + time_emb=_t_emb, + cond=cond, + ) + # two cond: first = time emb, second = cond_emb + emb = res.time_emb + cond_emb = res.emb + else: + raise NotImplementedError() + + # where in the model to supply time conditions + enc_time_emb = emb ## time embeddings + mid_time_emb = emb + dec_time_emb = emb + enc_cond_emb = cond_emb ## z_sem embeddings + mid_cond_emb = cond_emb + dec_cond_emb = cond_emb + + + bs, channels, seq_len = x.shape + x = x.reshape(bs, self.conf.node_n, self.in_features, seq_len).permute(0, 2, 1, 3) + hs = [[] for _ in range(len(self.conf.channel_mult))] + h = x.type(self.dtype) + + # input blocks + k = 0 + for i in range(len(self.input_num_blocks)): + for j in range(self.input_num_blocks[i]): + h = self.input_blocks[k](h, + emb=enc_time_emb, + cond=enc_cond_emb) + hs[i].append(h) + k += 1 + assert k == len(self.input_blocks) + + # output blocks + k = 0 + for i in range(len(self.output_num_blocks)): + for j in range(self.output_num_blocks[i]): + # take the lateral connection from the same layer (in reserve) + # until there is no more, use None + try: + lateral = hs[-i - 1].pop() ## in the reverse order (symmetric) + except IndexError: + lateral = None + h = self.output_blocks[k](h, + emb=dec_time_emb, + cond=dec_cond_emb, + lateral=lateral) + k += 1 + + + pred = self.out(h) + pred = pred.permute(0, 2, 1, 3).reshape(bs, -1, seq_len) + + return AutoencReturn(pred=pred, cond=cond) + + +class AutoencReturn(NamedTuple): + pred: Tensor + cond: Tensor = None + pred_hand: Tensor = None + pred_head: Tensor = None + + +class EmbedReturn(NamedTuple): + # style and time + emb: Tensor = None + # time only + time_emb: Tensor = None + # style only (but could depend on time) + style: Tensor = None + + +class TimeStyleSeperateEmbed(nn.Module): + # embed only style + def __init__(self, time_channels, time_out_channels): + super().__init__() + self.time_embed = nn.Sequential( + linear(time_channels, time_out_channels), + nn.SiLU(), + linear(time_out_channels, time_out_channels), + ) + self.style = nn.Identity() + + def forward(self, time_emb=None, cond=None, **kwargs): + if time_emb is None: + # happens with autoenc training mode + time_emb = None + else: + time_emb = self.time_embed(time_emb) + style = self.style(cond) ## style==cond + return EmbedReturn(emb=style, time_emb=time_emb, style=style) \ No newline at end of file diff --git a/preprocess.py b/preprocess.py new file mode 100644 index 0000000..417d61a --- /dev/null +++ b/preprocess.py @@ -0,0 +1,193 @@ +import os, random, math, copy +import pandas as pd +import numpy as np +import pickle as pkl +import logging, sys +from torch.utils.data import DataLoader,Dataset +import multiprocessing as mp +import json +import matplotlib.pyplot as plt + + +def MakeDir(dirpath): + if not os.path.exists(dirpath): + os.makedirs(dirpath) + + +def load_egobody(data_dir, seq_len, sample_rate=1, train=1): + data_dir_train = data_dir + 'train/' + data_dir_test = data_dir + 'test/' + if train == 0: + data_dirs = [data_dir_test] # test + elif train == 1: + data_dirs = 
[data_dir_train] # train + elif train == 2: + data_dirs = [data_dir_train, data_dir_test] # train + test + + hand_head = [] + for data_dir in data_dirs: + file_paths = sorted(os.listdir(data_dir)) + pose_xyz_file_paths = [] + head_file_paths = [] + for path in file_paths: + path_split = path.split('_') + data_type = path_split[-1][:-4] + if(data_type == 'xyz'): + pose_xyz_file_paths.append(path) + if(data_type == 'head'): + head_file_paths.append(path) + + file_num = len(pose_xyz_file_paths) + for i in range(file_num): + pose_data = np.load(data_dir + pose_xyz_file_paths[i]) + head_data = np.load(data_dir + head_file_paths[i]) + num_frames = pose_data.shape[0] + if num_frames < seq_len: + continue + + head_pos = pose_data[:, 15*3:16*3] + left_hand_pos = pose_data[:, 20*3:21*3] + right_hand_pos = pose_data[:, 21*3:22*3] + head_ori = head_data + left_hand_pos -= head_pos # convert hand positions to head coordinate system + right_hand_pos -= head_pos + hand_head_data = left_hand_pos + hand_head_data = np.concatenate((hand_head_data, right_hand_pos), axis=1) + hand_head_data = np.concatenate((hand_head_data, head_ori), axis=1) + + fs = np.arange(0, num_frames - seq_len + 1) + fs_sel = fs + for i in np.arange(seq_len - 1): + fs_sel = np.vstack((fs_sel, fs + i + 1)) + fs_sel = fs_sel.transpose() + seq_sel = hand_head_data[fs_sel, :] + seq_sel = seq_sel[0::sample_rate, :, :] + if len(hand_head) == 0: + hand_head = seq_sel + else: + hand_head = np.concatenate((hand_head, seq_sel), axis=0) + + hand_head = np.transpose(hand_head, (0, 2, 1)) + return hand_head + + +def load_adt(data_dir, seq_len, sample_rate=1, train=1): + data_dir_train = data_dir + 'train/' + data_dir_test = data_dir + 'test/' + if train == 0: + data_dirs = [data_dir_test] # test + elif train == 1: + data_dirs = [data_dir_train] # train + elif train == 2: + data_dirs = [data_dir_train, data_dir_test] # train + test + + hand_head = [] + for data_dir in data_dirs: + file_paths = sorted(os.listdir(data_dir)) + pose_xyz_file_paths = [] + head_file_paths = [] + for path in file_paths: + path_split = path.split('_') + data_type = path_split[-1][:-4] + if(data_type == 'xyz'): + pose_xyz_file_paths.append(path) + if(data_type == 'head'): + head_file_paths.append(path) + + file_num = len(pose_xyz_file_paths) + for i in range(file_num): + pose_data = np.load(data_dir + pose_xyz_file_paths[i]) + head_data = np.load(data_dir + head_file_paths[i]) + num_frames = pose_data.shape[0] + if num_frames < seq_len: + continue + + head_pos = pose_data[:, 4*3:5*3] + left_hand_pos = pose_data[:, 8*3:9*3] + right_hand_pos = pose_data[:, 12*3:13*3] + head_ori = head_data + left_hand_pos -= head_pos # convert hand positions to head coordinate system + right_hand_pos -= head_pos + hand_head_data = left_hand_pos + hand_head_data = np.concatenate((hand_head_data, right_hand_pos), axis=1) + hand_head_data = np.concatenate((hand_head_data, head_ori), axis=1) + + fs = np.arange(0, num_frames - seq_len + 1) + fs_sel = fs + for i in np.arange(seq_len - 1): + fs_sel = np.vstack((fs_sel, fs + i + 1)) + fs_sel = fs_sel.transpose() + seq_sel = hand_head_data[fs_sel, :] + seq_sel = seq_sel[0::sample_rate, :, :] + if len(hand_head) == 0: + hand_head = seq_sel + else: + hand_head = np.concatenate((hand_head, seq_sel), axis=0) + + hand_head = np.transpose(hand_head, (0, 2, 1)) + return hand_head + + +def load_gimo(data_dir, seq_len, sample_rate=1, train=1): + data_dir_train = data_dir + 'train/' + data_dir_test = data_dir + 'test/' + if train == 0: + data_dirs = 
[data_dir_test] # test + elif train == 1: + data_dirs = [data_dir_train] # train + elif train == 2: + data_dirs = [data_dir_train, data_dir_test] # train + test + + hand_head = [] + for data_dir in data_dirs: + file_paths = sorted(os.listdir(data_dir)) + pose_xyz_file_paths = [] + head_file_paths = [] + for path in file_paths: + path_split = path.split('_') + data_type = path_split[-1][:-4] + if(data_type == 'xyz'): + pose_xyz_file_paths.append(path) + if(data_type == 'head'): + head_file_paths.append(path) + + file_num = len(pose_xyz_file_paths) + for i in range(file_num): + pose_data = np.load(data_dir + pose_xyz_file_paths[i]) + head_data = np.load(data_dir + head_file_paths[i]) + num_frames = pose_data.shape[0] + if num_frames < seq_len: + continue + + head_pos = pose_data[:, 15*3:16*3] + left_hand_pos = pose_data[:, 20*3:21*3] + right_hand_pos = pose_data[:, 21*3:22*3] + head_ori = head_data + left_hand_pos -= head_pos # convert hand positions to head coordinate system + right_hand_pos -= head_pos + hand_head_data = left_hand_pos + hand_head_data = np.concatenate((hand_head_data, right_hand_pos), axis=1) + hand_head_data = np.concatenate((hand_head_data, head_ori), axis=1) + + fs = np.arange(0, num_frames - seq_len + 1) + fs_sel = fs + for i in np.arange(seq_len - 1): + fs_sel = np.vstack((fs_sel, fs + i + 1)) + fs_sel = fs_sel.transpose() + seq_sel = hand_head_data[fs_sel, :] + seq_sel = seq_sel[0::sample_rate, :, :] + if len(hand_head) == 0: + hand_head = seq_sel + else: + hand_head = np.concatenate((hand_head, seq_sel), axis=0) + + hand_head = np.transpose(hand_head, (0, 2, 1)) + return hand_head + + +if __name__ == "__main__": + data_dir = "/scratch/hu/pose_forecast/egobody_pose2gaze/" + seq_len = 40 + + test_data = load_egobody(data_dir, seq_len, sample_rate=10, train=0) + print("\ndataset size: {}".format(test_data.shape)) \ No newline at end of file diff --git a/train.sh b/train.sh new file mode 100644 index 0000000..30cd6e3 --- /dev/null +++ b/train.sh @@ -0,0 +1,3 @@ +#python main.py --gpus 7 --mode 'train' --model_name 'haheae'; + +python main.py --gpus 7 --mode 'eval' --model_name 'haheae'; \ No newline at end of file
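Note: load_egobody, load_adt, and load_gimo above all cut each recording into overlapping windows with the same index-matrix trick. A minimal, self-contained sketch of that sliding-window indexing (toy sizes and array names are illustrative, not part of the repository):

import numpy as np

num_frames, channels, seq_len, sample_rate = 6, 9, 3, 1
data = np.arange(num_frames * channels).reshape(num_frames, channels)  # toy recording

# Rows of fs_sel are consecutive frame indices: row k = [k, k+1, ..., k+seq_len-1].
fs = np.arange(0, num_frames - seq_len + 1)       # possible window start frames
fs_sel = fs
for i in np.arange(seq_len - 1):
    fs_sel = np.vstack((fs_sel, fs + i + 1))
fs_sel = fs_sel.transpose()                       # (num_windows, seq_len)

seq_sel = data[fs_sel, :]                         # (num_windows, seq_len, channels)
seq_sel = seq_sel[0::sample_rate, :, :]           # keep every sample_rate-th window
windows = np.transpose(seq_sel, (0, 2, 1))        # (num_windows, channels, seq_len)
print(windows.shape)                              # -> (4, 9, 3)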