update readme

Zhiming Hu 2025-06-03 21:11:04 +02:00
parent 35ee4b75e8
commit 249a01f342
18 changed files with 4936 additions and 0 deletions


179
choices.py Normal file

@@ -0,0 +1,179 @@
from enum import Enum
from torch import nn
class TrainMode(Enum):
# manipulate mode = training the classifier
manipulate = 'manipulate'
# default training mode!
diffusion = 'diffusion'
# default latent training mode!
# fitting a diffusion model to a given latent
latent_diffusion = 'latentdiffusion'
def is_manipulate(self):
return self in [
TrainMode.manipulate,
]
def is_diffusion(self):
return self in [
TrainMode.diffusion,
TrainMode.latent_diffusion,
]
def is_autoenc(self):
# the network possibly does autoencoding
return self in [
TrainMode.diffusion,
]
def is_latent_diffusion(self):
return self in [
TrainMode.latent_diffusion,
]
def use_latent_net(self):
return self.is_latent_diffusion()
def require_dataset_infer(self):
"""
Does training in this mode require the latent variables to be available?
"""
# this will precalculate all the latents before hand
# and the dataset will be all the predicted latents
return self in [
TrainMode.latent_diffusion,
TrainMode.manipulate,
]
class ManipulateMode(Enum):
"""
how to train the classifier to manipulate
"""
# train on whole celeba attr dataset
celebahq_all = 'celebahq_all'
# celeba with D2C's crop
d2c_fewshot = 'd2cfewshot'
d2c_fewshot_allneg = 'd2cfewshotallneg'
def is_celeba_attr(self):
return self in [
ManipulateMode.d2c_fewshot,
ManipulateMode.d2c_fewshot_allneg,
ManipulateMode.celebahq_all,
]
def is_single_class(self):
return self in [
ManipulateMode.d2c_fewshot,
ManipulateMode.d2c_fewshot_allneg,
]
def is_fewshot(self):
return self in [
ManipulateMode.d2c_fewshot,
ManipulateMode.d2c_fewshot_allneg,
]
def is_fewshot_allneg(self):
return self in [
ManipulateMode.d2c_fewshot_allneg,
]
class ModelType(Enum):
"""
Kinds of the backbone models
"""
# unconditional ddpm
ddpm = 'ddpm'
# autoencoding ddpm; cannot do unconditional generation on its own
autoencoder = 'autoencoder'
def has_autoenc(self):
return self in [
ModelType.autoencoder,
]
def can_sample(self):
return self in [ModelType.ddpm]
class ModelName(Enum):
"""
List of all supported model classes
"""
beatgans_ddpm = 'beatgans_ddpm'
beatgans_autoenc = 'beatgans_autoenc'
class ModelMeanType(Enum):
"""
Which type of output the model predicts.
"""
eps = 'eps' # the model predicts epsilon
class ModelVarType(Enum):
"""
What is used as the model's output variance.
Only the fixed (non-learned) variance options from guided-diffusion are kept here.
"""
# posterior beta_t
fixed_small = 'fixed_small'
# beta_t
fixed_large = 'fixed_large'
class LossType(Enum):
mse = 'mse' # use raw MSE loss (and KL when learning variances)
l1 = 'l1'
class GenerativeType(Enum):
"""
How a sample is generated
"""
ddpm = 'ddpm'
ddim = 'ddim'
class OptimizerType(Enum):
adam = 'adam'
adamw = 'adamw'
class Activation(Enum):
none = 'none'
relu = 'relu'
lrelu = 'lrelu'
silu = 'silu'
tanh = 'tanh'
def get_act(self):
if self == Activation.none:
return nn.Identity()
elif self == Activation.relu:
return nn.ReLU()
elif self == Activation.lrelu:
return nn.LeakyReLU(negative_slope=0.2)
elif self == Activation.silu:
return nn.SiLU()
elif self == Activation.tanh:
return nn.Tanh()
else:
raise NotImplementedError()
class ManipulateLossType(Enum):
bce = 'bce'
mse = 'mse'
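# Illustrative note (not part of the original file): keeping these choices as plain-string
# enums keeps TrainConfig JSON-serializable, and get_act() maps them to torch modules, e.g.
#   Activation.lrelu.get_act()             # -> nn.LeakyReLU(negative_slope=0.2)
#   TrainMode('diffusion').is_diffusion()  # -> True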

153
config.py Normal file

@@ -0,0 +1,153 @@
from model.blocks import *
from diffusion.resample import UniformSampler
from dataclasses import dataclass
from diffusion.diffusion import space_timesteps
from typing import Tuple
from config_base import BaseConfig
from diffusion import *
from diffusion.base import GenerativeType, LossType, ModelMeanType, ModelVarType, get_named_beta_schedule
from model import *
from choices import *
from preprocess import *
import os
@dataclass
class TrainConfig(BaseConfig):
name: str = ''
base_dir: str = './checkpoints/'
logdir: str = f'{base_dir}{name}'  # evaluated once with the class-level defaults; main.py appends the model name
data_name: str = ''
data_val_name: str = ''
seq_len: int = 40 # for reconstruction
seq_len_future: int = 3 # for prediction
in_channels: int = 9
fp16: bool = True
lr: float = 1e-4
ema_decay: float = 0.9999
seed: int = 0 # random seed
batch_size: int = 64
accum_batches: int = 1
batch_size_eval: int = 1024
total_epochs: int = 1_000
save_every_epochs: int = 10
eval_every_epochs: int = 10
train_mode: TrainMode = TrainMode.diffusion
T: int = 1000
T_eval: int = 100
diffusion_type: str = 'beatgans'
semantic_encoder_type: str = 'gcn'
net_beatgans_embed_channels: int = 128
beatgans_gen_type: GenerativeType = GenerativeType.ddim
beatgans_loss_type: LossType = LossType.mse
hand_mse_factor: float = 1.0
head_mse_factor: float = 1.0
beatgans_model_mean_type: ModelMeanType = ModelMeanType.eps
beatgans_model_var_type: ModelVarType = ModelVarType.fixed_large
beatgans_rescale_timesteps: bool = False
beta_scheduler: str = 'linear'
net_ch: int = 64
net_ch_mult: Tuple[int, ...] = (1, 2, 4)
net_enc_channel_mult: Tuple[int, ...] = (1, 2, 4)
grad_clip: float = 1
optimizer: OptimizerType = OptimizerType.adam
weight_decay: float = 0
warmup: int = 0
model_conf: ModelConfig = None
model_name: ModelName = ModelName.beatgans_autoenc
model_type: ModelType = None
@property
def batch_size_effective(self):
return self.batch_size*self.accum_batches
def _make_diffusion_conf(self, T=None):
if self.diffusion_type == 'beatgans':
# can use T < self.T for evaluation
# follows the guided-diffusion repo conventions
# t's are evenly spaced
if self.beatgans_gen_type == GenerativeType.ddpm:
section_counts = [T]
elif self.beatgans_gen_type == GenerativeType.ddim:
section_counts = f'ddim{T}'
else:
raise NotImplementedError()
return SpacedDiffusionBeatGansConfig(
gen_type=self.beatgans_gen_type,
model_type=self.model_type,
betas=get_named_beta_schedule(self.beta_scheduler, T),
model_mean_type=self.beatgans_model_mean_type,
model_var_type=self.beatgans_model_var_type,
loss_type=self.beatgans_loss_type,
rescale_timesteps=self.beatgans_rescale_timesteps,
use_timesteps=space_timesteps(num_timesteps=T, section_counts=section_counts),
fp16=self.fp16,
)
else:
raise NotImplementedError()
@property
def model_out_channels(self):
return self.in_channels
@property
def model_input_channels(self):
return self.in_channels
def make_T_sampler(self):
return UniformSampler(self.T)
def make_diffusion_conf(self):
return self._make_diffusion_conf(self.T)
def make_eval_diffusion_conf(self):
return self._make_diffusion_conf(T=self.T_eval)
def make_model_conf(self):
cls = BeatGANsAutoencConfig
if self.model_name == ModelName.beatgans_autoenc:
self.model_type = ModelType.autoencoder
else:
raise NotImplementedError()
self.model_conf = cls(
semantic_encoder_type = self.semantic_encoder_type,
channel_mult=self.net_ch_mult,
seq_len = self.seq_len,
seq_len_future = self.seq_len_future,
embed_channels=self.net_beatgans_embed_channels,
enc_out_channels=self.net_beatgans_embed_channels,
enc_channel_mult=self.net_enc_channel_mult,
in_channels=self.model_input_channels,
model_channels=self.net_ch,
out_channels=self.model_out_channels,
)
return self.model_conf
def egobody_autoenc(mode, encoder_type='gcn', hand_mse_factor=1.0, head_mse_factor=1.0, data_sample_rate=1, epoch=130, in_channels=9, seq_len=40):
conf = TrainConfig()
conf.seq_len = seq_len
conf.seq_len_future = 3
conf.in_channels = in_channels
conf.net_beatgans_embed_channels = 128
conf.net_ch = 64
conf.net_ch_mult = (1, 1, 1)
conf.semantic_encoder_type = encoder_type
conf.hand_mse_factor = hand_mse_factor
conf.head_mse_factor = head_mse_factor
conf.net_enc_channel_mult = conf.net_ch_mult
conf.total_epochs = epoch
conf.save_every_epochs = 10
conf.eval_every_epochs = 10
conf.batch_size = 64
conf.batch_size_eval = 1024*4
conf.data_dir = "/scratch/hu/pose_forecast/egobody_pose2gaze/"
conf.data_sample_rate = data_sample_rate
conf.name = 'egobody_autoenc'
conf.data_name = 'egobody'
conf.mode = mode
conf.make_model_conf()
return conf
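# Usage sketch (illustrative, not part of the diff; mirrors how main.py below builds everything):
#   conf = egobody_autoenc(mode='train', encoder_type='gcn', in_channels=9, seq_len=40)
#   sampler = conf.make_diffusion_conf().make_sampler()            # training sampler, T = 1000
#   eval_sampler = conf.make_eval_diffusion_conf().make_sampler()  # DDIM eval sampler, T_eval = 100
#   model = conf.make_model_conf().make_model()                    # BeatGANs autoencoder backbone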

72
config_base.py Normal file

@@ -0,0 +1,72 @@
import json
import os
from copy import deepcopy
from dataclasses import dataclass
@dataclass
class BaseConfig:
def clone(self):
return deepcopy(self)
def inherit(self, another):
"""inherit common keys from a given config"""
common_keys = set(self.__dict__.keys()) & set(another.__dict__.keys())
for k in common_keys:
setattr(self, k, getattr(another, k))
def propagate(self):
"""push down the configuration to all members"""
for k, v in self.__dict__.items():
if isinstance(v, BaseConfig):
v.inherit(self)
v.propagate()
def save(self, save_path):
"""save config to json file"""
dirname = os.path.dirname(save_path)
if not os.path.exists(dirname):
os.makedirs(dirname)
conf = self.as_dict_jsonable()
with open(save_path, 'w') as f:
json.dump(conf, f)
def load(self, load_path):
"""load json config"""
with open(load_path) as f:
conf = json.load(f)
self.from_dict(conf)
def from_dict(self, dict, strict=False):
for k, v in dict.items():
if not hasattr(self, k):
if strict:
raise ValueError(f"loading extra '{k}'")
else:
print(f"loading extra '{k}'")
continue
if isinstance(self.__dict__[k], BaseConfig):
self.__dict__[k].from_dict(v)
else:
self.__dict__[k] = v
def as_dict_jsonable(self):
conf = {}
for k, v in self.__dict__.items():
if isinstance(v, BaseConfig):
conf[k] = v.as_dict_jsonable()
else:
if jsonable(v):
conf[k] = v
else:
# ignore not jsonable
pass
return conf
def jsonable(x):
try:
json.dumps(x)
return True
except TypeError:
return False
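# Usage sketch (illustrative only; MyConf is a hypothetical subclass):
#   @dataclass
#   class MyConf(BaseConfig):
#       lr: float = 1e-4
#   conf = MyConf()
#   conf.save('./checkpoints/demo/config.json')      # writes only the JSON-serializable fields
#   restored = MyConf()
#   restored.load('./checkpoints/demo/config.json')  # unknown keys are reported and skipped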

6
diffusion/__init__.py Normal file

@@ -0,0 +1,6 @@
from typing import Union
from .diffusion import SpacedDiffusionBeatGans, SpacedDiffusionBeatGansConfig
Sampler = Union[SpacedDiffusionBeatGans]
SamplerConfig = Union[SpacedDiffusionBeatGansConfig]

1148
diffusion/base.py Normal file

File diff suppressed because it is too large.

182
diffusion/diffusion.py Normal file

@@ -0,0 +1,182 @@
from .base import *
from dataclasses import dataclass
def space_timesteps(num_timesteps, section_counts):
"""
Create a list of timesteps to use from an original diffusion process,
given the number of timesteps we want to take from equally-sized portions
of the original process.
For example, if there are 300 timesteps and the section counts are [10, 15, 20],
then the first 100 timesteps are strided to be 10 timesteps, the second 100
are strided to be 15 timesteps, and the final 100 are strided to be 20.
If the stride is a string starting with "ddim", then the fixed striding
from the DDIM paper is used, and only one section is allowed.
:param num_timesteps: the number of diffusion steps in the original
process to divide up.
:param section_counts: either a list of numbers, or a string containing
comma-separated numbers, indicating the step count
per section. As a special case, use "ddimN" where N
is a number of steps to use the striding from the
DDIM paper.
:return: a set of diffusion steps from the original process to use.
"""
if isinstance(section_counts, str):
if section_counts.startswith("ddim"):
desired_count = int(section_counts[len("ddim"):])
for i in range(1, num_timesteps):
if len(range(0, num_timesteps, i)) == desired_count:
return set(range(0, num_timesteps, i))
raise ValueError(
f"cannot create exactly {desired_count} steps with an integer stride"
)
section_counts = [int(x) for x in section_counts.split(",")]
size_per = num_timesteps // len(section_counts)
extra = num_timesteps % len(section_counts)
start_idx = 0
all_steps = []
for i, section_count in enumerate(section_counts):
size = size_per + (1 if i < extra else 0)
if size < section_count:
raise ValueError(
f"cannot divide section of {size} steps into {section_count}")
if section_count <= 1:
frac_stride = 1
else:
frac_stride = (size - 1) / (section_count - 1)
cur_idx = 0.0
taken_steps = []
for _ in range(section_count):
taken_steps.append(start_idx + round(cur_idx))
cur_idx += frac_stride
all_steps += taken_steps
start_idx += size
return set(all_steps)
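# Illustrative examples (not part of the original file):
#   sorted(space_timesteps(num_timesteps=1000, section_counts="ddim10"))
#     -> [0, 100, 200, ..., 900]                 (fixed DDIM striding, one section)
#   space_timesteps(num_timesteps=300, section_counts=[10, 15, 20])
#     -> 10/15/20 evenly strided steps from each 100-step third of the process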
@dataclass
class SpacedDiffusionBeatGansConfig(GaussianDiffusionBeatGansConfig):
use_timesteps: Tuple[int] = None
def make_sampler(self):
return SpacedDiffusionBeatGans(self)
class SpacedDiffusionBeatGans(GaussianDiffusionBeatGans):
"""
A diffusion process which can skip steps in a base diffusion process.
:param use_timesteps: a collection (sequence or set) of timesteps from the
original diffusion process to retain.
:param kwargs: the kwargs to create the base diffusion process.
"""
def __init__(self, conf: SpacedDiffusionBeatGansConfig):
self.conf = conf
self.use_timesteps = set(conf.use_timesteps)
# how the new t's are mapped to the old t's
self.timestep_map = []
self.original_num_steps = len(conf.betas)
base_diffusion = GaussianDiffusionBeatGans(conf) # pylint: disable=missing-kwoa
last_alpha_cumprod = 1.0
new_betas = []
for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
if i in self.use_timesteps:
# getting the new betas of the new timesteps
new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
last_alpha_cumprod = alpha_cumprod
self.timestep_map.append(i)
conf.betas = np.array(new_betas)
super().__init__(conf)
def p_mean_variance(self, model: Model, *args, **kwargs): # pylint: disable=signature-differs
return super().p_mean_variance(self._wrap_model(model), *args,
**kwargs)
def training_losses(self, model: Model, *args, **kwargs): # pylint: disable=signature-differs
return super().training_losses(self._wrap_model(model), *args,
**kwargs)
def condition_mean(self, cond_fn, *args, **kwargs):
return super().condition_mean(self._wrap_model(cond_fn), *args,
**kwargs)
def condition_score(self, cond_fn, *args, **kwargs):
return super().condition_score(self._wrap_model(cond_fn), *args,
**kwargs)
def _wrap_model(self, model: Model):
if isinstance(model, _WrappedModel):
return model
return _WrappedModel(model, self.timestep_map, self.rescale_timesteps,
self.original_num_steps)
def _scale_timesteps(self, t):
# Scaling is done by the wrapped model.
return t
class _WrappedModel:
"""
Converts the supplied t's back to the original timestep scale via timestep_map.
"""
def __init__(self, model, timestep_map, rescale_timesteps,
original_num_steps):
self.model = model
self.timestep_map = timestep_map
self.rescale_timesteps = rescale_timesteps
self.original_num_steps = original_num_steps
def forward(self, x, t, t_cond=None, **kwargs):
"""
Args:
t: t's with different ranges (can be << T due to smaller eval T) need to be converted to the original t's
t_cond: the same as t but can be of different values
"""
map_tensor = th.tensor(self.timestep_map,
device=t.device,
dtype=t.dtype)
def do(t):
new_ts = map_tensor[t]
if self.rescale_timesteps:
new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
return new_ts
if t_cond is not None:
# support t_cond
t_cond = do(t_cond)
## run_ffhq256.py None
return self.model(x=x, t=do(t), t_cond=t_cond, **kwargs) ## self.model.__class__==model.unet_autoenc.BeatGANsAutoencModel
def __getattr__(self, name):
# allow for calling the model's methods
if hasattr(self.model, name):
func = getattr(self.model, name)
return func
raise AttributeError(name)
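# Timestep-mapping sketch (illustrative): with T = 1000 and a "ddim100" eval sampler,
# timestep_map holds the 100 retained original steps {0, 10, ..., 990}, so a wrapped
# call with t == 3 runs the underlying model at original step timestep_map[3] == 30.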

63
diffusion/resample.py Normal file

@@ -0,0 +1,63 @@
from abc import ABC, abstractmethod
import numpy as np
import torch as th
import torch.distributed as dist
def create_named_schedule_sampler(name, diffusion):
"""
Create a ScheduleSampler from a library of pre-defined samplers.
:param name: the name of the sampler.
:param diffusion: the diffusion object to sample for.
"""
if name == "uniform":
return UniformSampler(diffusion.num_timesteps)  # UniformSampler below expects the number of timesteps, not the diffusion object
else:
raise NotImplementedError(f"unknown schedule sampler: {name}")
class ScheduleSampler(ABC):
"""
A distribution over timesteps in the diffusion process, intended to reduce
variance of the objective.
By default, samplers perform unbiased importance sampling, in which the
objective's mean is unchanged.
However, subclasses may override sample() to change how the resampled
terms are reweighted, allowing for actual changes in the objective.
"""
@abstractmethod
def weights(self):
"""
Get a numpy array of weights, one per diffusion step.
The weights needn't be normalized, but must be positive.
"""
def sample(self, batch_size, device):
"""
Importance-sample timesteps for a batch.
:param batch_size: the number of timesteps to sample (one per batch element).
:param device: the torch device to save to.
:return: a tuple (timesteps, weights):
- timesteps: a tensor of timestep indices.
- weights: a tensor of weights to scale the resulting losses.
"""
w = self.weights()
p = w / np.sum(w)
indices_np = np.random.choice(len(p), size=(batch_size, ), p=p)
indices = th.from_numpy(indices_np).long().to(device)
weights_np = 1 / (len(p) * p[indices_np])
weights = th.from_numpy(weights_np).float().to(device)
return indices, weights
class UniformSampler(ScheduleSampler):
def __init__(self, num_timesteps): ## weights for all timesteps are 1
self._weights = np.ones([num_timesteps])
def weights(self):
return self._weights
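# Usage sketch (illustrative; mirrors how conf.make_T_sampler() / T_sampler.sample() are used in main.py):
#   sampler = UniformSampler(num_timesteps=1000)
#   t, weight = sampler.sample(batch_size=64, device='cpu')  # t ~ Uniform{0, ..., 999}, weight == 1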

101
environment/haheae.yml Normal file

@@ -0,0 +1,101 @@
name: haheae
channels:
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- ca-certificates=2023.12.12=h06a4308_0
- ld_impl_linux-64=2.38=h1181459_1
- libffi=3.4.4=h6a678d5_0
- libgcc-ng=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libstdcxx-ng=11.2.0=h1234567_1
- ncurses=6.4=h6a678d5_0
- openssl=3.0.12=h7f8727e_0
- pip=23.3.1=py38h06a4308_0
- python=3.8.18=h955ad1f_0
- readline=8.2=h5eee18b_0
- setuptools=68.2.2=py38h06a4308_0
- sqlite=3.41.2=h5eee18b_0
- tk=8.6.12=h1ccaba5_0
- wheel=0.41.2=py38h06a4308_0
- xz=5.4.5=h5eee18b_0
- zlib=1.2.13=h5eee18b_0
- pip:
- absl-py==2.0.0
- aiohttp==3.9.1
- aiosignal==1.3.1
- appdirs==1.4.4
- async-timeout==4.0.3
- attrs==23.2.0
- cachetools==5.3.2
- certifi==2023.11.17
- charset-normalizer==3.3.2
- click==8.1.7
- contourpy==1.1.1
- cycler==0.12.1
- cython==0.29.37
- docker-pycreds==0.4.0
- fonttools==4.47.2
- frozenlist==1.4.1
- fsspec==2023.12.2
- ftfy==6.1.3
- future==0.18.3
- gitdb==4.0.11
- gitpython==3.1.41
- google-auth==2.26.2
- google-auth-oauthlib==1.0.0
- grpcio==1.60.0
- hdbscan==0.8.33
- idna==3.6
- importlib-metadata==7.0.1
- importlib-resources==6.1.1
- joblib==1.3.2
- kiwisolver==1.4.5
- lmdb==1.2.1
- lpips==0.1.4
- markdown==3.5.2
- markupsafe==2.1.3
- matplotlib==3.5.3
- multidict==6.0.4
- numpy==1.24.4
- oauthlib==3.2.2
- packaging==23.2
- pandas==1.5.3
- pillow==10.2.0
- protobuf==4.25.2
- psutil==5.9.8
- pyasn1==0.5.1
- pyasn1-modules==0.3.0
- pydeprecate==0.3.1
- pyparsing==3.1.1
- python-dateutil==2.8.2
- pytorch-fid==0.2.0
- pytorch-lightning==1.4.5
- pytz==2023.3.post1
- pyyaml==6.0.1
- regex==2023.12.25
- requests==2.31.0
- requests-oauthlib==1.3.1
- rsa==4.9
- scikit-learn==1.3.2
- scipy==1.5.4
- sentry-sdk==1.39.2
- setproctitle==1.3.3
- six==1.16.0
- smmap==5.0.1
- tensorboard==2.14.0
- tensorboard-data-server==0.7.2
- threadpoolctl==3.2.0
- torch==1.8.1
- torchmetrics==0.5.0
- torchvision==0.9.1
- tqdm==4.66.1
- typing-extensions==4.9.0
- tzdata==2023.4
- urllib3==2.1.0
- wandb==0.16.2
- wcwidth==0.2.13
- werkzeug==3.0.1
- yarl==1.9.4
- zipp==3.17.0

565
main.py Normal file

@@ -0,0 +1,565 @@
import warnings
warnings.filterwarnings("ignore")
import os
os.nice(5)
import copy, wandb
from tqdm import tqdm, trange
import argparse
import json
import re
import random
import math
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import *
from torch import nn
import torch.nn.functional as F
from torch.cuda import amp
from torch.optim.optimizer import Optimizer
from config import *
import time
import datetime
class LitModel(pl.LightningModule):
def __init__(self, conf: TrainConfig):
super().__init__()
## wandb
self.save_hyperparameters({k:v for (k,v) in vars(conf).items() if not callable(v)})
if conf.seed is not None:
pl.seed_everything(conf.seed)
self.save_hyperparameters(conf.as_dict_jsonable())
self.conf = conf
self.model = conf.make_model_conf().make_model()
self.ema_model = copy.deepcopy(self.model)
self.ema_model.requires_grad_(False)
self.ema_model.eval()
model_size = 0
for param in self.model.parameters():
model_size += param.data.nelement()
print('Model params: %.3f M' % (model_size / 1024 / 1024))
self.sampler = conf.make_diffusion_conf().make_sampler()
self.eval_sampler = conf.make_eval_diffusion_conf().make_sampler()
self.T_sampler = conf.make_T_sampler()
self.save_every_epochs = conf.save_every_epochs
self.eval_every_epochs = conf.eval_every_epochs
def setup(self, stage=None) -> None:
"""
make datasets & seeding each worker separately
"""
##############################################
# NEED TO SET THE SEED SEPARATELY HERE
if self.conf.seed is not None:
seed = self.conf.seed
np.random.seed(seed)
random.seed(seed) # Python random module.
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
print('seed:', seed)
##############################################
## Load dataset
if self.conf.mode == 'train':
self.train_data = load_egobody(self.conf.data_dir, self.conf.seq_len+self.conf.seq_len_future, self.conf.data_sample_rate, train=1)
self.val_data = load_egobody(self.conf.data_dir, self.conf.seq_len+self.conf.seq_len_future, self.conf.data_sample_rate, train=0)
if self.conf.in_channels == 6: # hand only
self.train_data = self.train_data[:, :6, :]
self.val_data = self.val_data[:, :6, :]
if self.conf.in_channels == 3: # head only
self.train_data = self.train_data[:, 6:, :]
self.val_data = self.val_data[:, 6:, :]
def encode(self, x):
assert self.conf.model_type.has_autoenc()
cond, pred_hand, pred_head = self.ema_model.encoder.forward(x)
return cond, pred_hand, pred_head
def encode_stochastic(self, x, cond, T=None):
if T is None:
sampler = self.eval_sampler
else:
sampler = self.conf._make_diffusion_conf(T).make_sampler() # get noise at step T
## x_0 -> x-T using reverse of inference
out = sampler.ddim_reverse_sample_loop(self.ema_model, x, model_kwargs={'cond': cond})
''' 'sample': x_T
'sample_t': x_t, t in (1, ..., T)
'xstart_t': predicted x_0 at each timestep. "xstart here is a bit different from sampling from T = T-1 to T = 0"
'T': (1, ..., T)
'''
return out['sample']
def train_dataloader(self):
return torch.utils.data.DataLoader(self.train_data, batch_size=self.conf.batch_size, shuffle=True)
def val_dataloader(self):
return torch.utils.data.DataLoader(self.val_data, batch_size=self.conf.batch_size_eval, shuffle=False)
def is_last_accum(self, batch_idx):
"""
is it the last gradient accumulation step?
used with accum_batches > 1 to check whether the optimizer will perform "step" in this iteration
"""
return (batch_idx + 1) % self.conf.accum_batches == 0
def training_step(self, batch, batch_idx):
"""
given an input, calculate the loss function
no optimization at this stage.
"""
with amp.autocast(False):
x_start = batch[:, :, :self.conf.seq_len]
x_future = batch[:, :, self.conf.seq_len:]
if self.conf.train_mode == TrainMode.diffusion:
"""
main training mode!!!
"""
t, weight = self.T_sampler.sample(len(x_start), x_start.device)
''' self.T_sampler: diffusion.resample.UniformSampler (weights for all timesteps are 1)
- t: a tensor of timestep indices.
- weight: a tensor of weights to scale the resulting losses.
## num_timesteps is self.conf.T == 1000
'''
losses = self.sampler.training_losses(model=self.model,
x_start=x_start,
t=t,
x_future=x_future,
hand_mse_factor = self.conf.hand_mse_factor,
head_mse_factor = self.conf.head_mse_factor,
)
else:
raise NotImplementedError()
loss = losses['loss'].mean() ## average loss across mini-batches
#noise_mse = losses['mse'].mean()
#hand_mse = losses['hand_mse'].mean()
#head_mse = losses['head_mse'].mean()
## Log loss and metric (wandb)
self.log("train_loss", loss, on_epoch=True, prog_bar=True)
#self.log("train_noise_mse", noise_mse, on_epoch=True, prog_bar=True)
#self.log("train_hand_mse", hand_mse, on_epoch=True, prog_bar=True)
#self.log("train_head_mse", head_mse, on_epoch=True, prog_bar=True)
return {'loss': loss}
def validation_step(self, batch, batch_idx):
if self.conf.in_channels == 9: # both hand and head
if((self.current_epoch+1)% self.eval_every_epochs == 0):
batch_future = batch[:, :, self.conf.seq_len:]
gt_hand_future = batch_future[:, :6, :]
gt_head_future = batch_future[:, 6:, :]
batch = batch[:, :, :self.conf.seq_len]
cond, pred_hand_future, pred_head_future = self.encode(batch)
xT = self.encode_stochastic(batch, cond)
pred_xstart = self.generate(xT, cond)
# hand reconstruction error
gt_hand = batch[:, :6, :]
pred_hand = pred_xstart[:, :6, :]
bs, channels, seq_len = gt_hand.shape
gt_hand = gt_hand.reshape(bs, 2, 3, seq_len)
pred_hand = pred_hand.reshape(bs, 2, 3, seq_len)
hand_traj = torch.mean(torch.norm(gt_hand - pred_hand, dim=2))
# hand prediction error
bs, channels, seq_len = gt_hand_future.shape
gt_hand_future = gt_hand_future.reshape(bs, 2, 3, seq_len)
pred_hand_future = pred_hand_future.reshape(bs, 2, 3, seq_len)
baseline_hand_future = gt_hand[:, :, :, -1:].expand(-1, -1, -1, self.conf.seq_len_future).clone()
hand_traj_future = torch.mean(torch.norm(gt_hand_future - pred_hand_future, dim=2))
hand_traj_future_baseline = torch.mean(torch.norm(gt_hand_future - baseline_hand_future, dim=2))
# head reconstruction error
gt_head = batch[:, 6:, :]
gt_head = F.normalize(gt_head, dim=1) # normalize head orientation to unit vectors
pred_head = pred_xstart[:, 6:, :]
pred_head = F.normalize(pred_head, dim=1)
head_ang = torch.mean(acos_safe(torch.sum(gt_head*pred_head, 1)))/torch.tensor(math.pi) * 180.0
# head prediction error
gt_head_future = F.normalize(gt_head_future, dim=1)
pred_head_future = F.normalize(pred_head_future, dim=1)
baseline_head_future = gt_head[:, :, -1:].expand(-1, -1, self.conf.seq_len_future).clone()
head_ang_future = torch.mean(acos_safe(torch.sum(gt_head_future*pred_head_future, 1)))/torch.tensor(math.pi) * 180.0
head_ang_future_baseline = torch.mean(acos_safe(torch.sum(gt_head_future*baseline_head_future, 1)))/torch.tensor(math.pi) * 180.0
self.log("val_hand_traj", hand_traj, on_epoch=True, prog_bar=True)
self.log("val_head_ang", head_ang, on_epoch=True, prog_bar=True)
self.log("val_hand_traj_future", hand_traj_future, on_epoch=True, prog_bar=True)
self.log("val_head_ang_future", head_ang_future, on_epoch=True, prog_bar=True)
self.log("val_hand_traj_future_baseline", hand_traj_future_baseline, on_epoch=True, prog_bar=True)
self.log("val_head_ang_future_baseline", head_ang_future_baseline, on_epoch=True, prog_bar=True)
if self.conf.in_channels == 6: # hand only
if((self.current_epoch+1)% self.eval_every_epochs == 0):
batch_future = batch[:, :, self.conf.seq_len:]
gt_hand_future = batch_future[:, :, :]
batch = batch[:, :, :self.conf.seq_len]
cond, pred_hand_future, pred_head_future = self.encode(batch)
xT = self.encode_stochastic(batch, cond)
pred_xstart = self.generate(xT, cond)
# hand reconstruction error
gt_hand = batch[:, :, :]
pred_hand = pred_xstart[:, :, :]
bs, channels, seq_len = gt_hand.shape
gt_hand = gt_hand.reshape(bs, 2, 3, seq_len)
pred_hand = pred_hand.reshape(bs, 2, 3, seq_len)
hand_traj = torch.mean(torch.norm(gt_hand - pred_hand, dim=2))
# hand prediction error
bs, channels, seq_len = gt_hand_future.shape
gt_hand_future = gt_hand_future.reshape(bs, 2, 3, seq_len)
pred_hand_future = pred_hand_future.reshape(bs, 2, 3, seq_len)
baseline_hand_future = gt_hand[:, :, :, -1:].expand(-1, -1, -1, self.conf.seq_len_future).clone()
hand_traj_future = torch.mean(torch.norm(gt_hand_future - pred_hand_future, dim=2))
hand_traj_future_baseline = torch.mean(torch.norm(gt_hand_future - baseline_hand_future, dim=2))
self.log("val_hand_traj", hand_traj, on_epoch=True, prog_bar=True)
self.log("val_hand_traj_future", hand_traj_future, on_epoch=True, prog_bar=True)
self.log("val_hand_traj_future_baseline", hand_traj_future_baseline, on_epoch=True, prog_bar=True)
if self.conf.in_channels == 3: # head only
if((self.current_epoch+1)% self.eval_every_epochs == 0):
batch_future = batch[:, :, self.conf.seq_len:]
gt_head_future = batch_future[:, :, :]
batch = batch[:, :, :self.conf.seq_len]
cond, pred_hand_future, pred_head_future = self.encode(batch)
xT = self.encode_stochastic(batch, cond)
pred_xstart = self.generate(xT, cond)
# head reconstruction error
gt_head = batch[:, :, :]
gt_head = F.normalize(gt_head, dim=1) # normalize head orientation to unit vectors
pred_head = pred_xstart[:, :, :]
pred_head = F.normalize(pred_head, dim=1)
head_ang = torch.mean(acos_safe(torch.sum(gt_head*pred_head, 1)))/torch.tensor(math.pi) * 180.0
# head prediction error
gt_head_future = F.normalize(gt_head_future, dim=1)
pred_head_future = F.normalize(pred_head_future, dim=1)
baseline_head_future = gt_head[:, :, -1:].expand(-1, -1, self.conf.seq_len_future).clone()
head_ang_future = torch.mean(acos_safe(torch.sum(gt_head_future*pred_head_future, 1)))/torch.tensor(math.pi) * 180.0
head_ang_future_baseline = torch.mean(acos_safe(torch.sum(gt_head_future*baseline_head_future, 1)))/torch.tensor(math.pi) * 180.0
self.log("val_head_ang", head_ang, on_epoch=True, prog_bar=True)
self.log("val_head_ang_future", head_ang_future, on_epoch=True, prog_bar=True)
self.log("val_head_ang_future_baseline", head_ang_future_baseline, on_epoch=True, prog_bar=True)
def test_step(self, batch, batch_idx):
batch_future = batch[:, :, self.conf.seq_len:]
gt_hand_future = batch_future[:, :6, :]
gt_head_future = batch_future[:, 6:, :]
batch = batch[:, :, :self.conf.seq_len]
cond, pred_hand_future, pred_head_future = self.encode(batch)
xT = self.encode_stochastic(batch, cond)
pred_xstart = self.generate(xT, cond)
# hand reconstruction error
gt_hand = batch[:, :6, :]
pred_hand = pred_xstart[:, :6, :]
bs, channels, seq_len = gt_hand.shape
gt_hand = gt_hand.reshape(bs, 2, 3, seq_len)
pred_hand = pred_hand.reshape(bs, 2, 3, seq_len)
hand_traj = torch.mean(torch.norm(gt_hand - pred_hand, dim=2))
# hand prediction error
bs, channels, seq_len = gt_hand_future.shape
gt_hand_future = gt_hand_future.reshape(bs, 2, 3, seq_len)
pred_hand_future = pred_hand_future.reshape(bs, 2, 3, seq_len)
baseline_hand_future = gt_hand[:, :, :, -1:].expand(-1, -1, -1, self.conf.seq_len_future).clone()
hand_traj_future = torch.mean(torch.norm(gt_hand_future - pred_hand_future, dim=2))
hand_traj_future_baseline = torch.mean(torch.norm(gt_hand_future - baseline_hand_future, dim=2))
# head reconstruction error
gt_head = batch[:, 6:, :]
gt_head = F.normalize(gt_head, dim=1) # normalize head orientation to unit vectors
pred_head = pred_xstart[:, 6:, :]
pred_head = F.normalize(pred_head, dim=1)
head_ang = torch.mean(acos_safe(torch.sum(gt_head*pred_head, 1)))/torch.tensor(math.pi) * 180.0
# head prediction error
gt_head_future = F.normalize(gt_head_future, dim=1)
pred_head_future = F.normalize(pred_head_future, dim=1)
baseline_head_future = gt_head[:, :, -1:].expand(-1, -1, self.conf.seq_len_future).clone()
head_ang_future = torch.mean(acos_safe(torch.sum(gt_head_future*pred_head_future, 1)))/torch.tensor(math.pi) * 180.0
head_ang_future_baseline = torch.mean(acos_safe(torch.sum(gt_head_future*baseline_head_future, 1)))/torch.tensor(math.pi) * 180.0
self.log("test_hand_traj", hand_traj, on_epoch=True, prog_bar=True)
self.log("test_head_ang", head_ang, on_epoch=True, prog_bar=True)
self.log("test_hand_traj_future", hand_traj_future, on_epoch=True, prog_bar=True)
self.log("test_head_ang_future", head_ang_future, on_epoch=True, prog_bar=True)
self.log("test_hand_traj_future_baseline", hand_traj_future_baseline, on_epoch=True, prog_bar=True)
self.log("test_head_ang_future_baseline", head_ang_future_baseline, on_epoch=True, prog_bar=True)
def generate(self, noise, cond=None, ema=True, T=None):
if T is None:
sampler = self.eval_sampler
else:
sampler = self.conf._make_diffusion_conf(T).make_sampler()
if ema:
model = self.ema_model
else:
model = self.model
gen = sampler.sample(model=model, noise=noise, model_kwargs={'cond': cond})
return gen
def on_train_batch_end(self, outputs, batch, batch_idx: int,
dataloader_idx: int) -> None:
"""
after each training step ...
"""
if self.is_last_accum(batch_idx):
# only apply ema on the last gradient accumulation step,
# if it is the iteration that has optimizer.step()
ema(self.model, self.ema_model, self.conf.ema_decay)
if (batch_idx==len(self.train_dataloader())-1) and ((self.current_epoch+1) % self.save_every_epochs == 0):
save_path = os.path.join(self.conf.logdir, 'epoch_%d.ckpt' % (self.current_epoch+1))
torch.save({
'state_dict': self.state_dict(),
'global_step': self.global_step,
'loss': outputs['loss'],
}, save_path)
def on_before_optimizer_step(self, optimizer: Optimizer,
optimizer_idx: int) -> None:
# fix the fp16 + clip grad norm problem with pytorch lightning
# this is the currently correct way to do it
if self.conf.grad_clip > 0:
# from trainer.params_grads import grads_norm, iter_opt_params
params = [
p for group in optimizer.param_groups for p in group['params']
]
# print('before:', grads_norm(iter_opt_params(optimizer)))
torch.nn.utils.clip_grad_norm_(params, max_norm=self.conf.grad_clip)
# print('after:', grads_norm(iter_opt_params(optimizer)))
def configure_optimizers(self):
out = {}
if self.conf.optimizer == OptimizerType.adam:
optim = torch.optim.Adam(self.model.parameters(),
lr=self.conf.lr,
weight_decay=self.conf.weight_decay)
elif self.conf.optimizer == OptimizerType.adamw:
optim = torch.optim.AdamW(self.model.parameters(),
lr=self.conf.lr,
weight_decay=self.conf.weight_decay)
else:
raise NotImplementedError()
out['optimizer'] = optim
if self.conf.warmup > 0:
sched = torch.optim.lr_scheduler.LambdaLR(optim,
lr_lambda=WarmupLR(
self.conf.warmup))
out['lr_scheduler'] = {
'scheduler': sched,
'interval': 'step',
}
return out
def ema(source, target, decay):
source_dict = source.state_dict()
target_dict = target.state_dict()
for key in source_dict.keys():
target_dict[key].data.copy_(target_dict[key].data * decay +
source_dict[key].data * (1 - decay))
class WarmupLR:
def __init__(self, warmup) -> None:
self.warmup = warmup
def __call__(self, step):
return min(step, self.warmup) / self.warmup
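# Warm-up sketch (illustrative): WarmupLR(warmup=1000) returns a linear factor per optimizer
# step, e.g. 0.5 at step 500 and 1.0 for every step >= 1000; configure_optimizers() above
# plugs it into a LambdaLR scheduler with interval='step' when conf.warmup > 0.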
def train(conf: TrainConfig, model: LitModel, gpus):
checkpoint = ModelCheckpoint(dirpath=conf.logdir,
filename='last',
save_last=True,
save_top_k=1,
every_n_epochs=conf.save_every_epochs,
)
checkpoint_path = f'{conf.logdir}last.ckpt'
if os.path.exists(checkpoint_path):
resume = checkpoint_path
if conf.mode == 'train':
print('ckpt path:', checkpoint_path)
else:
print('checkpoint not found!')
resume = None
wandb_logger = pl_loggers.WandbLogger(project='haheae',
name='%s_%s'%(model.conf.data_name, conf.logdir.split('/')[-2]),
log_model=True,
save_dir=conf.logdir,
dir = conf.logdir,
config=vars(model.conf),
save_code=True,
settings=wandb.Settings(code_dir="."))
trainer = pl.Trainer(
max_epochs=conf.total_epochs,
resume_from_checkpoint=resume,
gpus=gpus,
precision=16 if conf.fp16 else 32,
callbacks=[
checkpoint,
LearningRateMonitor(),
],
logger= wandb_logger,
accumulate_grad_batches=conf.accum_batches,
progress_bar_refresh_rate=4,
)
if conf.mode == 'train':
trainer.fit(model)
elif conf.mode == 'eval':
checkpoint_path = f'{conf.logdir}last.ckpt'
# load the latest checkpoint
print('loading from:', checkpoint_path)
state = torch.load(checkpoint_path)
model.load_state_dict(state['state_dict'])
test_datasets = ['egobody', 'adt', 'gimo']
for dataset_name in test_datasets:
if dataset_name == 'egobody':
data_dir = "/scratch/hu/pose_forecast/egobody_pose2gaze/"
test_data = load_egobody(data_dir, conf.seq_len+conf.seq_len_future, 1, train=0) # use the test set
elif dataset_name == 'adt':
data_dir = "/scratch/hu/pose_forecast/adt_pose2gaze/"
test_data = load_adt(data_dir, conf.seq_len+conf.seq_len_future, 1, train=2) # use the train+test set
elif dataset_name == 'gimo':
data_dir = "/scratch/hu/pose_forecast/gimo_pose2gaze/"
test_data = load_gimo(data_dir, conf.seq_len+conf.seq_len_future, 1, train=2) # use the train+test set
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=conf.batch_size_eval, shuffle=False)
results = trainer.test(model, dataloaders=test_dataloader, verbose=False)
print("\n\nTest on {}, dataset size: {}".format(dataset_name, test_data.shape))
print("test_hand_traj: {:.3f} cm".format(results[0]['test_hand_traj']*100))
print("test_head_ang: {:.3f} deg".format(results[0]['test_head_ang']))
print("test_hand_traj_future: {:.3f} cm".format(results[0]['test_hand_traj_future']*100))
print("test_head_ang_future: {:.3f} deg".format(results[0]['test_head_ang_future']))
print("test_hand_traj_future_baseline: {:.3f} cm".format(results[0]['test_hand_traj_future_baseline']*100))
print("test_head_ang_future_baseline: {:.3f} deg\n\n".format(results[0]['test_head_ang_future_baseline']))
wandb.finish()
def acos_safe(x, eps=1e-6):
slope = np.arccos(1-eps) / eps
buf = torch.empty_like(x)
good = abs(x) <= 1-eps
bad = ~good
sign = torch.sign(x[bad])
buf[good] = torch.acos(x[good])
buf[bad] = torch.acos(sign * (1 - eps)) - slope*sign*(abs(x[bad]) - 1 + eps)
return buf
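# Numerical note (illustrative): acos_safe linearly extrapolates arccos outside
# [-(1 - eps), 1 - eps], so dot products that land slightly outside [-1, 1] after
# normalization stay finite instead of producing NaN as torch.acos would.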
def get_representation(model, dataset, conf, device='cuda'):
model = model.to(device)
model.eval()
dataloader = torch.utils.data.DataLoader(dataset, batch_size=conf.batch_size_eval, shuffle=False)
with torch.no_grad():
conds = [] # semantic representation
xTs = [] # stochastic representation
for batch in tqdm(dataloader, total=len(dataloader), desc='infer'):
batch = batch.to(device)
cond, _, _ = model.encode(batch)
xT = model.encode_stochastic(batch, cond)
cond_cpu = cond.cpu().data.numpy()
xT_cpu = xT.cpu().data.numpy()
if len(conds) == 0:
conds = cond_cpu
xTs = xT_cpu
else:
conds = np.concatenate((conds, cond_cpu), axis=0)
xTs = np.concatenate((xTs, xT_cpu), axis=0)
return conds, xTs
def generate_from_representation(model, conds, xTs, device='cuda'):
model = model.to(device)
model.eval()
conds = torch.from_numpy(conds).to(device)
xTs = torch.from_numpy(xTs).to(device)
rec = model.generate(xTs, conds)
rec = rec.cpu().data.numpy()
return rec
def evaluate_reconstruction(gt, rec):
# hand reconstruction error (cm)
gt_hand = gt[:, :6, :]
rec_hand = rec[:, :6, :]
bs, channels, seq_len = gt_hand.shape
gt_hand = gt_hand.reshape(bs, 2, 3, seq_len)
rec_hand = rec_hand.reshape(bs, 2, 3, seq_len)
hand_traj_errors = np.mean(np.mean(np.linalg.norm(gt_hand - rec_hand, axis=2), axis=1), axis=1)*100
# head reconstruction error (deg)
gt_head = gt[:, 6:, :]
gt_head_norm = np.linalg.norm(gt_head, axis=1, keepdims=True)
gt_head = gt_head/gt_head_norm
rec_head = rec[:, 6:, :]
rec_head_norm = np.linalg.norm(rec_head, axis=1, keepdims=True)
rec_head = rec_head/rec_head_norm
dot_sum = np.clip(np.sum(gt_head*rec_head, axis=1), -1, 1)
head_ang_errors = np.mean(np.arccos(dot_sum), axis=1)/np.pi * 180.0
return hand_traj_errors, head_ang_errors
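# Representation workflow sketch (illustrative; assumes `model` is a trained LitModel and
# `data` a numpy array shaped [N, 9, seq_len] as loaded by load_egobody):
#   conds, xTs = get_representation(model, data, conf)      # semantic + stochastic codes
#   rec = generate_from_representation(model, conds, xTs)   # reconstructed sequences
#   hand_cm, head_deg = evaluate_reconstruction(data, rec)  # per-sample errors (cm, deg)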
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--gpus', default=7, type=int)
parser.add_argument('--mode', default='eval', type=str)
parser.add_argument('--encoder_type', default='gcn', type=str)
parser.add_argument('--model_name', default='haheae', type=str)
parser.add_argument('--hand_mse_factor', default=1.0, type=float)
parser.add_argument('--head_mse_factor', default=1.0, type=float)
parser.add_argument('--data_sample_rate', default=1, type=int)
parser.add_argument('--epoch', default=130, type=int)
parser.add_argument('--in_channels', default=9, type=int)
args = parser.parse_args()
conf = egobody_autoenc(args.mode, args.encoder_type, args.hand_mse_factor, args.head_mse_factor, args.data_sample_rate, args.epoch, args.in_channels)
model = LitModel(conf)
conf.logdir = f'{conf.logdir}{args.model_name}/'
print('log dir: {}'.format(conf.logdir))
MakeDir(conf.logdir)
if conf.mode == 'train' or conf.mode == 'eval': # train or evaluate the model
os.environ['WANDB_CACHE_DIR'] = conf.logdir
os.environ['WANDB_DATA_DIR'] = conf.logdir
# set wandb to not upload checkpoints, but all the others
os.environ['WANDB_IGNORE_GLOBS'] = '*.ckpt'
local_time = time.asctime(time.localtime(time.time()))
print('\n{} starts at {}'.format(conf.mode, local_time))
start_time = datetime.datetime.now()
train(conf, model, gpus=[args.gpus])
end_time = datetime.datetime.now()
total_time = (end_time - start_time).seconds/60
print('\nTotal time: {:.3f} min'.format(total_time))
local_time = time.asctime(time.localtime(time.time()))
print('\n{} ends at {}'.format(conf.mode, local_time))

6
model/__init__.py Normal file

@@ -0,0 +1,6 @@
from typing import Union
from .unet import BeatGANsUNetModel, BeatGANsUNetConfig, GCNUNetModel, GCNUNetConfig
from .unet_autoenc import BeatGANsAutoencConfig, BeatGANsAutoencModel, GCNAutoencConfig, GCNAutoencModel
Model = Union[BeatGANsUNetModel, BeatGANsAutoencModel, GCNUNetModel, GCNAutoencModel]
ModelConfig = Union[BeatGANsUNetConfig, BeatGANsAutoencConfig, GCNUNetConfig,GCNAutoencConfig]

579
model/blocks.py Normal file

@@ -0,0 +1,579 @@
import math, pdb
from abc import abstractmethod
from dataclasses import dataclass
from numbers import Number
import torch as th
import torch.nn.functional as F
from choices import *
from config_base import BaseConfig
from torch import nn
import numpy as np
from .nn import (avg_pool_nd, conv_nd, linear, normalization,
timestep_embedding, zero_module)
class TimestepBlock(nn.Module):
"""
Any module where forward() takes timestep embeddings as a second argument.
"""
@abstractmethod
def forward(self, x, emb=None, cond=None, lateral=None):
"""
Apply the module to `x` given `emb` timestep embeddings.
"""
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
"""
A sequential module that passes timestep embeddings to the children that
support it as an extra input.
"""
def forward(self, x, emb=None, cond=None, lateral=None):
for layer in self:
if isinstance(layer, TimestepBlock):
x = layer(x, emb=emb, cond=cond, lateral=lateral)
else:
x = layer(x)
return x
@dataclass
class ResBlockConfig(BaseConfig):
channels: int
emb_channels: int
dropout: float
out_channels: int = None
# condition the resblock with time (and encoder's output)
use_condition: bool = True
# whether to use 3x3 conv for skip path when the channels aren't matched
use_conv: bool = False
# dimension of conv (always 1 = 1d)
dims: int = 1
up: bool = False
down: bool = False
# whether to condition with both time & encoder's output
two_cond: bool = False
# number of encoders' output channels
cond_emb_channels: int = None
# suggest: False
has_lateral: bool = False
lateral_channels: int = None
# whether to init the convolution with zero weights
# this is default from BeatGANs and seems to help learning
use_zero_module: bool = True
def __post_init__(self):
self.out_channels = self.out_channels or self.channels
self.cond_emb_channels = self.cond_emb_channels or self.emb_channels
def make_model(self):
return ResBlock(self)
class ResBlock(TimestepBlock):
"""
A residual block that can optionally change the number of channels.
total layers:
in_layers
- norm
- act
- conv
out_layers
- norm
- (modulation)
- act
- conv
"""
def __init__(self, conf: ResBlockConfig):
super().__init__()
self.conf = conf
#############################
# IN LAYERS
#############################
assert conf.lateral_channels is None
layers = [
normalization(conf.channels),
nn.SiLU(),
conv_nd(conf.dims, conf.channels, conf.out_channels, 3, padding=1) ## 3 is kernel size
]
self.in_layers = nn.Sequential(*layers)
self.updown = conf.up or conf.down
if conf.up:
self.h_upd = Upsample(conf.channels, False, conf.dims)
self.x_upd = Upsample(conf.channels, False, conf.dims)
elif conf.down:
self.h_upd = Downsample(conf.channels, False, conf.dims)
self.x_upd = Downsample(conf.channels, False, conf.dims)
else:
self.h_upd = self.x_upd = nn.Identity()
#############################
# OUT LAYERS CONDITIONS
#############################
if conf.use_condition:
# condition layers for the out_layers
self.emb_layers = nn.Sequential(
nn.SiLU(),
linear(conf.emb_channels, 2 * conf.out_channels),
)
if conf.two_cond:
self.cond_emb_layers = nn.Sequential(
nn.SiLU(),
linear(conf.cond_emb_channels, conf.out_channels),
)
#############################
# OUT LAYERS (ignored when there is no condition)
#############################
# original version
conv = conv_nd(conf.dims,
conf.out_channels,
conf.out_channels,
3,
padding=1)
if conf.use_zero_module:
# zero out the weights
# it seems to help training
conv = zero_module(conv)
# construct the layers
# - norm
# - (modulation)
# - act
# - dropout
# - conv
layers = []
layers += [
normalization(conf.out_channels),
nn.SiLU(),
nn.Dropout(p=conf.dropout),
conv,
]
self.out_layers = nn.Sequential(*layers)
#############################
# SKIP LAYERS
#############################
if conf.out_channels == conf.channels:
# cannot be used with gatedconv; also, gatedconv is always used as the first block
self.skip_connection = nn.Identity()
else:
if conf.use_conv:
kernel_size = 3
padding = 1
else:
kernel_size = 1
padding = 0
self.skip_connection = conv_nd(conf.dims,
conf.channels,
conf.out_channels,
kernel_size,
padding=padding)
def forward(self, x, emb=None, cond=None, lateral=None):
"""
Apply the block to a Tensor, conditioned on a timestep embedding.
Args:
x: input
lateral: lateral connection from the encoder
"""
return self._forward(x, emb, cond, lateral)
def _forward(
self,
x,
emb=None,
cond=None,
lateral=None,
):
"""
Args:
lateral: required when "has_lateral" is set and the block is non-gated; with gated conv it can be supplied optionally
"""
if self.conf.has_lateral:
# lateral may be supplied even if it isn't required;
# the model will take the lateral only if "has_lateral" is set
assert lateral is not None
# x = F.interpolate(x, size=(lateral.size(2)), mode='linear' )
x = th.cat([x, lateral], dim=1)
if self.updown:
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
h = in_rest(x)
h = self.h_upd(h)
x = self.x_upd(x)
h = in_conv(h)
else:
h = self.in_layers(x)
if self.conf.use_condition:
# it's possible that the network may not receive the time emb
# this happens with autoenc and setting the time_at
if emb is not None:
emb_out = self.emb_layers(emb).type(h.dtype)
else:
emb_out = None
if self.conf.two_cond:
# it's possible that the network is two_cond
# but it doesn't get the second condition
# in which case, we ignore the second condition
# and treat as if the network has one condition
if cond is None:
cond_out = None
else:
if not isinstance(cond, th.Tensor):
assert isinstance(cond, dict)
cond = cond['cond']
cond_out = self.cond_emb_layers(cond).type(h.dtype)
if cond_out is not None:
while len(cond_out.shape) < len(h.shape):
cond_out = cond_out[..., None]
else:
cond_out = None
# this is the new refactored code
h = apply_conditions(
h=h,
emb=emb_out,
cond=cond_out,
layers=self.out_layers,
scale_bias=1,
in_channels=self.conf.out_channels,
up_down_layer=None,
)
return self.skip_connection(x) + h
def apply_conditions(
h,
emb=None,
cond=None,
layers: nn.Sequential = None,
scale_bias: float = 1,
in_channels: int = 512,
up_down_layer: nn.Module = None,
):
"""
apply conditions on the feature maps
Args:
emb: time conditional (ready to scale + shift)
cond: encoder's conditional (ready to scale + shift)
"""
two_cond = emb is not None and cond is not None
if emb is not None:
# adjusting shapes
while len(emb.shape) < len(h.shape):
emb = emb[..., None]
if two_cond:
# adjusting shapes
while len(cond.shape) < len(h.shape):
cond = cond[..., None]
# time first
scale_shifts = [emb, cond]
else:
# "cond" is not used with single cond mode
scale_shifts = [emb]
# support scale, shift or shift only
for i, each in enumerate(scale_shifts):
if each is None:
# special case: the condition is not provided
a = None
b = None
else:
if each.shape[1] == in_channels * 2:
a, b = th.chunk(each, 2, dim=1)
else:
a = each
b = None
scale_shifts[i] = (a, b)
# condition scale bias could be a list
if isinstance(scale_bias, Number):
biases = [scale_bias] * len(scale_shifts)
else:
# a list
biases = scale_bias
# default, the scale & shift are applied after the group norm but BEFORE SiLU
pre_layers, post_layers = layers[0], layers[1:]
# split the post layers to be able to scale up or down before the conv
# post_layers will contain the dropout and the final conv
mid_layers, post_layers = post_layers[:-2], post_layers[-2:]
h = pre_layers(h)
# scale and shift for each condition
for i, (scale, shift) in enumerate(scale_shifts):
# if scale is None, it indicates that the condition is not provided
if scale is not None:
h = h * (biases[i] + scale)
if shift is not None:
h = h + shift
h = mid_layers(h)
# upscale or downscale if any just before the last conv
if up_down_layer is not None:
h = up_down_layer(h)
h = post_layers(h)
return h
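# Shape sketch (illustrative): for h of shape [B, C, L], a time embedding of shape [B, 2*C]
# is chunked into (scale, shift) and applied as h * (scale_bias + scale) + shift right after
# the group norm, while an encoder condition of shape [B, C] contributes a scale-only
# modulation, matching the two_cond path in ResBlock above.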
class Upsample(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
if use_conv:
self.conv = conv_nd(dims,
self.channels,
self.out_channels,
3,
padding=1)
def forward(self, x):
assert x.shape[1] == self.channels
if self.dims == 3:
x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
mode="nearest")
else:
# if x.shape[2] == 4:
# feature = 9
# x = F.interpolate(x, size=(feature), mode="nearest")
# if x.shape[2] == 8:
# feature = 9
# x = F.interpolate(x, size=(feature), mode="nearest")
# else:
x = F.interpolate(x, scale_factor=2, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
self.stride = 2 if dims != 3 else (1, 2, 2)
if use_conv:
self.op = conv_nd(dims,
self.channels,
self.out_channels,
3,
stride=self.stride,
padding=1)
else:
assert self.channels == self.out_channels
self.op = avg_pool_nd(dims, kernel_size=self.stride, stride=self.stride)
def forward(self, x):
assert x.shape[1] == self.channels
# if x.shape[2] % 2 != 0:
# op = avg_pool_nd(self.dims, kernel_size=3, stride=2)
# return op(x)
# if x.shape[2] % 2 != 0:
# op = avg_pool_nd(self.dims, kernel_size=2, stride=1)
# return op(x)
# else:
return self.op(x)
class AttentionBlock(nn.Module):
"""
An attention block that allows spatial positions to attend to each other.
Originally ported from here, but adapted to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
"""
def __init__(
self,
channels,
num_heads=1,
num_head_channels=-1,
use_new_attention_order=False,
):
super().__init__()
self.channels = channels
if num_head_channels == -1:
self.num_heads = num_heads
else:
assert (
channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels
self.norm = normalization(channels)
self.qkv = conv_nd(1, channels, channels * 3, 1)
if use_new_attention_order:
# split qkv before split heads
self.attention = QKVAttention(self.num_heads)
else:
# split heads before split qkv
self.attention = QKVAttentionLegacy(self.num_heads)
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
def forward(self, x):
return self._forward(x)
def _forward(self, x):
b, c, *spatial = x.shape
x = x.reshape(b, c, -1)
qkv = self.qkv(self.norm(x))
h = self.attention(qkv)
h = self.proj_out(h)
return (x + h).reshape(b, c, *spatial)
def count_flops_attn(model, _x, y):
"""
A counter for the `thop` package to count the operations in an
attention operation.
Meant to be used like:
macs, params = thop.profile(
model,
inputs=(inputs, timestamps),
custom_ops={QKVAttention: QKVAttention.count_flops},
)
"""
b, c, *spatial = y[0].shape
num_spatial = int(np.prod(spatial))
# We perform two matmuls with the same number of ops.
# The first computes the weight matrix, the second computes
# the combination of the value vectors.
matmul_ops = 2 * b * (num_spatial**2) * c
model.total_ops += th.DoubleTensor([matmul_ops])
class QKVAttentionLegacy(nn.Module):
"""
A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
"""
def __init__(self, n_heads):
super().__init__()
self.n_heads = n_heads
def forward(self, qkv):
"""
Apply QKV attention.
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
:return: an [N x (H * C) x T] tensor after attention.
"""
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum(
"bct,bcs->bts", q * scale,
k * scale) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum("bts,bcs->bct", weight, v)
return a.reshape(bs, -1, length)
@staticmethod
def count_flops(model, _x, y):
return count_flops_attn(model, _x, y)
class QKVAttention(nn.Module):
"""
A module which performs QKV attention and splits in a different order.
"""
def __init__(self, n_heads):
super().__init__()
self.n_heads = n_heads
def forward(self, qkv):
"""
Apply QKV attention.
:param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
:return: an [N x (H * C) x T] tensor after attention.
"""
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.chunk(3, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum(
"bct,bcs->bts",
(q * scale).view(bs * self.n_heads, ch, length),
(k * scale).view(bs * self.n_heads, ch, length),
) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum("bts,bcs->bct", weight,
v.reshape(bs * self.n_heads, ch, length))
return a.reshape(bs, -1, length)
@staticmethod
def count_flops(model, _x, y):
return count_flops_attn(model, _x, y)
class AttentionPool2d(nn.Module):
"""
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
"""
def __init__(
self,
spacial_dim: int,
embed_dim: int,
num_heads_channels: int,
output_dim: int = None,
):
super().__init__()
self.positional_embedding = nn.Parameter(
th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)
self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
self.num_heads = embed_dim // num_heads_channels
self.attention = QKVAttention(self.num_heads)
def forward(self, x):
b, c, *_spatial = x.shape
x = x.reshape(b, c, -1) # NC(HW)
x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
x = self.qkv_proj(x)
x = self.attention(x)
x = self.c_proj(x)
return x[:, :, 0]


@@ -0,0 +1,173 @@
import torch.nn as nn
import torch
from dataclasses import dataclass
from torch.nn.parameter import Parameter
from numbers import Number
import torch.nn.functional as F
from .blocks import *
import math
class graph_convolution(nn.Module):
def __init__(self, in_features, out_features, node_n = 3, seq_len = 80, bias=True):
super(graph_convolution, self).__init__()
self.temporal_graph_weights = Parameter(torch.FloatTensor(seq_len, seq_len))
self.feature_weights = Parameter(torch.FloatTensor(in_features, out_features))
self.spatial_graph_weights = Parameter(torch.FloatTensor(node_n, node_n))
if bias:
self.bias = Parameter(torch.FloatTensor(seq_len))
else:
self.register_parameter('bias', None)  # keep the attribute defined so forward() can test it
self.reset_parameters()
def reset_parameters(self):
stdv = 1. / math.sqrt(self.spatial_graph_weights.size(1))
self.feature_weights.data.uniform_(-stdv, stdv)
self.temporal_graph_weights.data.uniform_(-stdv, stdv)
self.spatial_graph_weights.data.uniform_(-stdv, stdv)
if self.bias is not None:
self.bias.data.uniform_(-stdv, stdv)
def forward(self, x):
y = torch.matmul(x, self.temporal_graph_weights)
y = torch.matmul(y.permute(0, 3, 2, 1), self.feature_weights)
y = torch.matmul(self.spatial_graph_weights, y).permute(0, 3, 2, 1).contiguous()
if self.bias is not None:
return (y + self.bias)
else:
return y
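# Minimal usage sketch (illustrative only; sizes are assumptions): the layer mixes
# information along the temporal axis, the feature axis, and the spatial (node) axis in
# turn; input and output are [batch, features, node_n, seq_len].
def _example_graph_convolution():
    gcn = graph_convolution(in_features=3, out_features=16, node_n=3, seq_len=80)
    x = torch.randn(4, 3, 3, 80)
    return gcn(x).shape  # -> torch.Size([4, 16, 3, 80])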
@dataclass
class residual_graph_convolution_config():
in_features: int
seq_len: int
emb_channels: int
dropout: float
out_features: int = None
node_n: int = 3
# condition the block with time (and encoder's output)
use_condition: bool = True
# whether to condition with both time & encoder's output
two_cond: bool = False
# number of encoders' output channels
cond_emb_channels: int = None
has_lateral: bool = False
graph_convolution_bias: bool = True
scale_bias: float = 1
def __post_init__(self):
self.out_features = self.out_features or self.in_features
self.cond_emb_channels = self.cond_emb_channels or self.emb_channels
def make_model(self):
return residual_graph_convolution(self)
class residual_graph_convolution(TimestepBlock):
def __init__(self, conf: residual_graph_convolution_config):
super(residual_graph_convolution, self).__init__()
self.conf = conf
self.gcn = graph_convolution(conf.in_features, conf.out_features, node_n=conf.node_n, seq_len=conf.seq_len, bias=conf.graph_convolution_bias)
self.ln = nn.LayerNorm([conf.out_features, conf.node_n, conf.seq_len])
self.act_f = nn.Tanh()
self.dropout = nn.Dropout(conf.dropout)
if conf.use_condition:
# condition layers for the out_layers
self.emb_layers = nn.Sequential(
nn.SiLU(),
nn.Linear(conf.emb_channels, conf.out_features),
)
if conf.two_cond:
self.cond_emb_layers = nn.Sequential(
nn.SiLU(),
nn.Linear(conf.cond_emb_channels, conf.out_features),
)
if conf.in_features == conf.out_features:
self.skip_connection = nn.Identity()
else:
self.skip_connection = nn.Sequential(
graph_convolution(conf.in_features, conf.out_features, node_n=conf.node_n, seq_len=conf.seq_len, bias=conf.graph_convolution_bias),
nn.Tanh(),
)
def forward(self, x, emb=None, cond=None, lateral=None):
if self.conf.has_lateral:
            # a lateral (skip) connection may be supplied even when it is not required;
            # the model only uses it when "has_lateral" is set
assert lateral is not None
x = torch.cat((x, lateral), dim =1)
y = self.gcn(x)
y = self.ln(y)
if self.conf.use_condition:
if emb is not None:
emb = self.emb_layers(emb).type(x.dtype)
# adjusting shapes
while len(emb.shape) < len(y.shape):
emb = emb[..., None]
            if self.conf.two_cond or True:  # NOTE: 'or True' makes this branch run even when two_cond is False
if cond is not None:
if not isinstance(cond, torch.Tensor):
assert isinstance(cond, dict)
cond = cond['cond']
cond = self.cond_emb_layers(cond).type(x.dtype)
while len(cond.shape) < len(y.shape):
cond = cond[..., None]
scales = [emb, cond]
else:
scales = [emb]
# condition scale bias could be a list
if isinstance(self.conf.scale_bias, Number):
biases = [self.conf.scale_bias] * len(scales)
else:
# a list
biases = self.conf.scale_bias
# scale for each condition
for i, scale in enumerate(scales):
# if scale is None, it indicates that the condition is not provided
if scale is not None:
y = y*(biases[i] + scale)
y = self.act_f(y)
y = self.dropout(y)
return self.skip_connection(x) + y
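# Minimal usage sketch (illustrative only; sizes are assumptions): the block is modulated
# multiplicatively by the time embedding `emb` and, when two_cond is set, by the semantic
# condition `cond`, i.e. y <- y * (scale_bias + emb) and y <- y * (scale_bias + cond)
# before the activation and the residual connection.
def _example_residual_graph_convolution():
    conf = residual_graph_convolution_config(
        in_features=16, seq_len=80, emb_channels=128, dropout=0.1, two_cond=True)
    block = conf.make_model()
    x = torch.randn(2, 16, 3, 80)
    emb = torch.randn(2, 128)   # time embedding
    cond = torch.randn(2, 128)  # semantic condition
    return block(x, emb=emb, cond=cond).shape  # -> torch.Size([2, 16, 3, 80])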
class graph_downsample(nn.Module):
"""
A downsampling layer
"""
def __init__(self, kernel_size = 2):
super().__init__()
self.downsample = nn.AvgPool1d(kernel_size = kernel_size)
def forward(self, x):
bs, features, node_n, seq_len = x.shape
x = x.reshape(bs, features*node_n, seq_len)
x = self.downsample(x)
x = x.reshape(bs, features, node_n, -1)
return x
class graph_upsample(nn.Module):
"""
An upsampling layer
"""
def __init__(self, scale_factor=2):
super().__init__()
self.scale_factor = scale_factor
def forward(self, x):
x = F.interpolate(x, (x.shape[2], x.shape[3]*self.scale_factor), mode="nearest")
return x
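# Minimal usage sketch (illustrative only; sizes are assumptions): downsampling halves the
# temporal axis and upsampling restores it; the feature and node axes are untouched.
def _example_graph_resampling():
    x = torch.randn(2, 16, 3, 80)
    down = graph_downsample(kernel_size=2)(x)   # -> [2, 16, 3, 40]
    up = graph_upsample(scale_factor=2)(down)   # -> [2, 16, 3, 80]
    return down.shape, up.shape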

141
model/nn.py Normal file

@ -0,0 +1,141 @@
"""
Various utilities for neural networks.
"""
from enum import Enum
import math, pdb
from typing import Optional
import torch as th
import torch.nn as nn
import torch.utils.checkpoint
import torch.nn.functional as F
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
# @th.jit.script
def forward(self, x):
return x * th.sigmoid(x)
class GroupNorm32(nn.GroupNorm):
def forward(self, x):
y = super().forward(x.float()).type(x.dtype)
return y
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
assert dims==1
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def linear(*args, **kwargs):
"""
Create a linear module.
"""
return nn.Linear(*args, **kwargs)
def avg_pool_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D average pooling module.
"""
assert dims==1
if dims == 1:
return nn.AvgPool1d(*args, **kwargs)
elif dims == 2:
return nn.AvgPool2d(*args, **kwargs)
elif dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def update_ema(target_params, source_params, rate=0.99):
"""
Update target parameters to be closer to those of source parameters using
an exponential moving average.
:param target_params: the target parameter sequence.
:param source_params: the source parameter sequence.
:param rate: the EMA rate (closer to 1 means slower).
"""
for targ, src in zip(target_params, source_params):
targ.detach().mul_(rate).add_(src, alpha=1 - rate)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def scale_module(module, scale):
"""
Scale the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().mul_(scale)
return module
def mean_flat(tensor):
"""
Take the mean over all non-batch dimensions.
"""
return tensor.mean(dim=list(range(1, len(tensor.shape))))
def normalization(channels):
"""
Make a standard normalization layer.
:param channels: number of input channels.
:return: an nn.Module for normalization.
"""
# return GroupNorm32(channels, channels)
return GroupNorm32(min(4, channels), channels)
def timestep_embedding(timesteps, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
half = dim // 2
freqs = th.exp(-math.log(max_period) *
th.arange(start=0, end=half, dtype=th.float32) /
half).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None]
embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
if dim % 2:
embedding = th.cat(
[embedding, th.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def torch_checkpoint(func, args, flag, preserve_rng_state=False):
# torch's gradient checkpoint works with automatic mixed precision, given torch >= 1.8
if flag:
return torch.utils.checkpoint.checkpoint(
func, *args, preserve_rng_state=preserve_rng_state)
else:
return func(*args)

954
model/unet.py Normal file

@ -0,0 +1,954 @@
import math
from dataclasses import dataclass
from numbers import Number
from typing import NamedTuple, Tuple, Union
import numpy as np
import torch as th
from torch import nn
import torch.nn.functional as F
from choices import *
from config_base import BaseConfig
from .blocks import *
from .graph_convolution_network import *
from .nn import (conv_nd, linear, normalization, timestep_embedding, zero_module)
@dataclass
class BeatGANsUNetConfig(BaseConfig):
seq_len: int = 80
in_channels: int = 9
# base channels, will be multiplied
model_channels: int = 64
# output of the unet
out_channels: int = 9
# how many repeating resblocks per resolution
# the decoding side would have "one more" resblock
# default: 2
num_res_blocks: int = 2
# number of time embed channels and style channels
embed_channels: int = 256
# at what resolutions you want to do self-attention of the feature maps
# attentions generally improve performance
attention_resolutions: Tuple[int] = (0, )
# dropout applies to the resblocks (on feature maps)
dropout: float = 0.1
channel_mult: Tuple[int] = (1, 2, 4)
conv_resample: bool = True
# 1 = 1d conv
dims: int = 1
# number of attention heads
num_heads: int = 1
# or specify the number of channels per attention head
num_head_channels: int = -1
# use resblock for upscale/downscale blocks (expensive)
# default: True (BeatGANs)
resblock_updown: bool = True
use_new_attention_order: bool = False
resnet_two_cond: bool = True
resnet_cond_channels: int = None
# init the decoding conv layers with zero weights, this speeds up training
# default: True (BeatGANs)
resnet_use_zero_module: bool = True
def make_model(self):
return BeatGANsUNetModel(self)
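# Minimal usage sketch (illustrative only; assumes the default field values above are
# mutually consistent): build the 1D UNet from its config and run one denoising step on a
# batch of [N, in_channels, seq_len] sequences.
def _example_beatgans_unet():
    conf = BeatGANsUNetConfig(seq_len=80, in_channels=9, out_channels=9, model_channels=64)
    model = conf.make_model()
    x = th.randn(2, 9, 80)
    t = th.randint(0, 1000, (2, ))
    return model(x, t).pred.shape  # -> torch.Size([2, 9, 80])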
class BeatGANsUNetModel(nn.Module):
def __init__(self, conf: BeatGANsUNetConfig):
super().__init__()
self.conf = conf
self.dtype = th.float32
self.time_emb_channels = conf.model_channels
self.time_embed = nn.Sequential(
linear(self.time_emb_channels, conf.embed_channels),
nn.SiLU(),
linear(conf.embed_channels, conf.embed_channels),
)
ch = input_ch = int(conf.channel_mult[0] * conf.model_channels)
self.input_blocks = nn.ModuleList([
TimestepEmbedSequential(
conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1)),
])
kwargs = dict(
use_condition=True,
two_cond=conf.resnet_two_cond,
use_zero_module=conf.resnet_use_zero_module,
# style channels for the resnet block
cond_emb_channels=conf.resnet_cond_channels,
)
self._feature_size = ch
# input_block_chans = [ch]
input_block_chans = [[] for _ in range(len(conf.channel_mult))]
input_block_chans[0].append(ch)
# number of blocks at each resolution
self.input_num_blocks = [0 for _ in range(len(conf.channel_mult))]
self.input_num_blocks[0] = 1
self.output_num_blocks = [0 for _ in range(len(conf.channel_mult))]
ds = 1
resolution = conf.seq_len
for level, mult in enumerate(conf.channel_mult):
for _ in range(conf.num_res_blocks):
layers = [
ResBlockConfig(
ch,
conf.embed_channels,
conf.dropout,
out_channels=int(mult * conf.model_channels),
dims=conf.dims,
**kwargs,
).make_model()
]
ch = int(mult * conf.model_channels)
if resolution in conf.attention_resolutions:
layers.append(
AttentionBlock(
ch,
num_heads=conf.num_heads,
num_head_channels=conf.num_head_channels,
use_new_attention_order=conf.
use_new_attention_order,
))
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
# input_block_chans.append(ch)
input_block_chans[level].append(ch)
self.input_num_blocks[level] += 1
# print(input_block_chans)
if level != len(conf.channel_mult) - 1:
resolution //= 2
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResBlockConfig(
ch,
conf.embed_channels,
conf.dropout,
out_channels=out_ch,
dims=conf.dims,
down=True,
**kwargs,
).make_model() if conf.
resblock_updown else Downsample(ch,
conf.conv_resample,
dims=conf.dims,
out_channels=out_ch)))
ch = out_ch
# input_block_chans.append(ch)
input_block_chans[level + 1].append(ch)
self.input_num_blocks[level + 1] += 1
ds *= 2
self._feature_size += ch
self.middle_block = TimestepEmbedSequential(
ResBlockConfig(
ch,
conf.embed_channels,
conf.dropout,
dims=conf.dims,
**kwargs,
).make_model(),
#AttentionBlock(
# ch,
# num_heads=conf.num_heads,
# num_head_channels=conf.num_head_channels,
# use_new_attention_order=conf.use_new_attention_order,
#),
ResBlockConfig(
ch,
conf.embed_channels,
conf.dropout,
dims=conf.dims,
**kwargs,
).make_model(),
)
self._feature_size += ch
self.output_blocks = nn.ModuleList([])
for level, mult in list(enumerate(conf.channel_mult))[::-1]:
for i in range(conf.num_res_blocks + 1):
# print(input_block_chans)
# ich = input_block_chans.pop()
try:
ich = input_block_chans[level].pop()
except IndexError:
# this happens only when num_res_block > num_enc_res_block
                    # we will not have enough lateral (skip) connections for all decoder blocks
ich = 0
# print('pop:', ich)
layers = [
ResBlockConfig(
# only direct channels when gated
channels=ch + ich,
emb_channels=conf.embed_channels,
dropout=conf.dropout,
out_channels=int(conf.model_channels * mult),
dims=conf.dims,
# lateral channels are described here when gated
has_lateral=True if ich > 0 else False,
lateral_channels=None,
**kwargs,
).make_model()
]
ch = int(conf.model_channels * mult)
if resolution in conf.attention_resolutions:
layers.append(
AttentionBlock(
ch,
num_heads=conf.num_heads,
num_head_channels=conf.num_head_channels,
use_new_attention_order=conf.
use_new_attention_order,
))
if level and i == conf.num_res_blocks:
resolution *= 2
out_ch = ch
layers.append(
ResBlockConfig(
ch,
conf.embed_channels,
conf.dropout,
out_channels=out_ch,
dims=conf.dims,
up=True,
**kwargs,
).make_model() if (
conf.resblock_updown
) else Upsample(ch,
conf.conv_resample,
dims=conf.dims,
out_channels=out_ch))
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
self.output_num_blocks[level] += 1
self._feature_size += ch
# print(input_block_chans)
# print('inputs:', self.input_num_blocks)
# print('outputs:', self.output_num_blocks)
if conf.resnet_use_zero_module:
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
zero_module(
conv_nd(conf.dims,
input_ch,
conf.out_channels,
3, ## kernel size
padding=1)),
)
else:
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
conv_nd(conf.dims, input_ch, conf.out_channels, 3, padding=1),
)
def forward(self, x, t, **kwargs):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
        :param t: a 1-D batch of timesteps.
:return: an [N x C x ...] Tensor of outputs.
"""
# hs = []
hs = [[] for _ in range(len(self.conf.channel_mult))]
emb = self.time_embed(timestep_embedding(t, self.time_emb_channels))
# new code supports input_num_blocks != output_num_blocks
h = x.type(self.dtype)
k = 0
for i in range(len(self.input_num_blocks)):
for j in range(self.input_num_blocks[i]):
h = self.input_blocks[k](h, emb=emb)
# print(i, j, h.shape)
hs[i].append(h) ## Get output from each layer
k += 1
assert k == len(self.input_blocks)
# middle blocks
h = self.middle_block(h, emb=emb)
# output blocks
k = 0
for i in range(len(self.output_num_blocks)):
for j in range(self.output_num_blocks[i]):
                # take the lateral connection from the same resolution level (in reverse order);
                # once there are none left, use None
try:
lateral = hs[-i - 1].pop()
# print(i, j, lateral.shape)
except IndexError:
lateral = None
# print(i, j, lateral)
h = self.output_blocks[k](h, emb=emb, lateral=lateral)
k += 1
h = h.type(x.dtype)
pred = self.out(h)
return Return(pred=pred)
class Return(NamedTuple):
pred: th.Tensor
@dataclass
class BeatGANsEncoderConfig(BaseConfig):
in_channels: int
seq_len: int = 80
num_res_blocks: int = 2
attention_resolutions: Tuple[int] = (0, )
model_channels: int = 32
out_channels: int = 256
dropout: float = 0.1
channel_mult: Tuple[int] = (1, 2, 4)
use_time_condition: bool = False
conv_resample: bool = True
dims: int = 1
num_heads: int = 1
num_head_channels: int = -1
resblock_updown: bool = True
use_new_attention_order: bool = False
pool: str = 'adaptivenonzero'
def make_model(self):
return BeatGANsEncoderModel(self)
class BeatGANsEncoderModel(nn.Module):
"""
The half UNet model with attention and timestep embedding.
For usage, see UNet.
"""
def __init__(self, conf: BeatGANsEncoderConfig):
super().__init__()
self.conf = conf
self.dtype = th.float32
if conf.use_time_condition:
time_embed_dim = conf.model_channels
self.time_embed = nn.Sequential(
linear(conf.model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
else:
time_embed_dim = None
ch = int(conf.channel_mult[0] * conf.model_channels)
self.input_blocks = nn.ModuleList([
TimestepEmbedSequential(
conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1),
)
])
self._feature_size = ch
input_block_chans = [ch]
ds = 1
resolution = conf.seq_len
for level, mult in enumerate(conf.channel_mult):
for _ in range(conf.num_res_blocks):
layers = [
ResBlockConfig(
ch,
time_embed_dim,
conf.dropout,
out_channels=int(mult * conf.model_channels),
dims=conf.dims,
use_condition=conf.use_time_condition,
).make_model()
]
ch = int(mult * conf.model_channels)
if resolution in conf.attention_resolutions:
layers.append(
AttentionBlock(
ch,
num_heads=conf.num_heads,
num_head_channels=conf.num_head_channels,
use_new_attention_order=conf.
use_new_attention_order,
))
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(conf.channel_mult) - 1:
resolution //= 2
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResBlockConfig(
ch,
time_embed_dim,
conf.dropout,
out_channels=out_ch,
dims=conf.dims,
use_condition=conf.use_time_condition,
down=True,
).make_model() if (
conf.resblock_updown
) else Downsample(ch,
conf.conv_resample,
dims=conf.dims,
out_channels=out_ch)))
ch = out_ch
input_block_chans.append(ch)
ds *= 2
self._feature_size += ch
self.middle_block = TimestepEmbedSequential(
ResBlockConfig(
ch,
time_embed_dim,
conf.dropout,
dims=conf.dims,
use_condition=conf.use_time_condition,
).make_model(),
AttentionBlock(
ch,
num_heads=conf.num_heads,
num_head_channels=conf.num_head_channels,
use_new_attention_order=conf.use_new_attention_order,
),
ResBlockConfig(
ch,
time_embed_dim,
conf.dropout,
dims=conf.dims,
use_condition=conf.use_time_condition,
).make_model(),
)
self._feature_size += ch
if conf.pool == "adaptivenonzero":
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
## nn.AdaptiveAvgPool2d((1, 1)),
nn.AdaptiveAvgPool1d(1),
conv_nd(conf.dims, ch, conf.out_channels, 1),
nn.Flatten(),
)
else:
raise NotImplementedError(f"Unexpected {conf.pool} pooling")
def forward(self, x, t=None):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
        :param t: a 1-D batch of timesteps (optional).
:return: an [N x K] Tensor of outputs.
"""
if self.conf.use_time_condition:
            emb = self.time_embed(timestep_embedding(t, self.conf.model_channels))
else: ## autoencoding.py
emb = None
results = []
h = x.type(self.dtype)
for module in self.input_blocks: ## flow input x over all the input blocks
h = module(h, emb=emb)
if self.conf.pool.startswith("spatial"):
results.append(h.type(x.dtype).mean(dim=(2, 3)))
h = self.middle_block(h, emb=emb) ## TimestepEmbedSequential(...)
if self.conf.pool.startswith("spatial"):
results.append(h.type(x.dtype).mean(dim=(2, 3)))
h = th.cat(results, axis=-1)
else: ## autoencoder.py
h = h.type(x.dtype)
h = h.float()
h = self.out(h)
return h
@dataclass
class GCNUNetConfig(BaseConfig):
in_channels: int = 9
node_n: int = 3
seq_len: int = 80
# base channels, will be multiplied
model_channels: int = 32
# output of the unet
out_channels: int = 9
# how many repeating resblocks per resolution
num_res_blocks: int = 8
# number of time embed channels and style channels
embed_channels: int = 256
# dropout applies to the resblocks
dropout: float = 0.1
channel_mult: Tuple[int] = (1, 2, 4)
resnet_two_cond: bool = True
def make_model(self):
return GCNUNetModel(self)
class GCNUNetModel(nn.Module):
def __init__(self, conf: GCNUNetConfig):
super().__init__()
self.conf = conf
self.dtype = th.float32
assert conf.in_channels%conf.node_n == 0
self.in_features = conf.in_channels//conf.node_n
self.time_emb_channels = conf.model_channels*4
self.time_embed = nn.Sequential(
linear(self.time_emb_channels, conf.embed_channels),
nn.SiLU(),
linear(conf.embed_channels, conf.embed_channels),
)
ch = int(conf.channel_mult[0] * conf.model_channels)
self.input_blocks = nn.ModuleList([
TimestepEmbedSequential(
graph_convolution(in_features=self.in_features, out_features=ch, node_n=conf.node_n, seq_len=conf.seq_len)),
])
kwargs = dict(
use_condition=True,
two_cond=conf.resnet_two_cond,
)
input_block_chans = [[] for _ in range(len(conf.channel_mult))]
input_block_chans[0].append(ch)
# number of blocks at each resolution
self.input_num_blocks = [0 for _ in range(len(conf.channel_mult))]
self.input_num_blocks[0] = 1
self.output_num_blocks = [0 for _ in range(len(conf.channel_mult))]
ds = 1
resolution = conf.seq_len
for level, mult in enumerate(conf.channel_mult):
for _ in range(conf.num_res_blocks):
layers = [
residual_graph_convolution_config(
in_features=ch,
seq_len=resolution,
emb_channels = conf.embed_channels,
dropout=conf.dropout,
out_features=int(mult * conf.model_channels),
node_n=conf.node_n,
**kwargs,
).make_model()
]
ch = int(mult * conf.model_channels)
self.input_blocks.append(*layers)
input_block_chans[level].append(ch)
self.input_num_blocks[level] += 1
if level != len(conf.channel_mult) - 1:
resolution //= 2
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
graph_downsample()))
ch = out_ch
input_block_chans[level + 1].append(ch)
self.input_num_blocks[level + 1] += 1
ds *= 2
self.output_blocks = nn.ModuleList([])
for level, mult in list(enumerate(conf.channel_mult))[::-1]:
for i in range(conf.num_res_blocks + 1):
try:
ich = input_block_chans[level].pop()
except IndexError:
# this happens only when num_res_block > num_enc_res_block
                    # we will not have enough lateral (skip) connections for all decoder blocks
ich = 0
layers = [
residual_graph_convolution_config(
in_features=ch + ich,
seq_len=resolution,
emb_channels = conf.embed_channels,
dropout=conf.dropout,
out_features=int(mult * conf.model_channels),
node_n=conf.node_n,
has_lateral=True if ich > 0 else False,
**kwargs,
).make_model()
]
ch = int(mult*conf.model_channels)
if level and i == conf.num_res_blocks:
resolution *= 2
out_ch = ch
layers.append(graph_upsample())
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
self.output_num_blocks[level] += 1
self.out = nn.Sequential(
graph_convolution(in_features=ch, out_features=self.in_features, node_n=conf.node_n, seq_len=conf.seq_len),
nn.Tanh(),
)
def forward(self, x, t, **kwargs):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
        :param t: a 1-D batch of timesteps.
:return: an [N x C x ...] Tensor of outputs.
"""
bs, channels, seq_len = x.shape
x = x.reshape(bs, self.conf.node_n, self.in_features, seq_len).permute(0, 2, 1, 3)
hs = [[] for _ in range(len(self.conf.channel_mult))]
emb = self.time_embed(timestep_embedding(t, self.time_emb_channels))
# new code supports input_num_blocks != output_num_blocks
h = x.type(self.dtype)
k = 0
for i in range(len(self.input_num_blocks)):
for j in range(self.input_num_blocks[i]):
h = self.input_blocks[k](h, emb=emb)
# print(i, j, h.shape)
hs[i].append(h) ## Get output from each layer
k += 1
assert k == len(self.input_blocks)
# output blocks
k = 0
for i in range(len(self.output_num_blocks)):
for j in range(self.output_num_blocks[i]):
                # take the lateral connection from the same resolution level (in reverse order);
                # once there are none left, use None
try:
lateral = hs[-i - 1].pop()
# print(i, j, lateral.shape)
except IndexError:
lateral = None
# print(i, j, lateral)
h = self.output_blocks[k](h, emb=emb, lateral=lateral)
k += 1
h = h.type(x.dtype)
pred = self.out(h)
pred = pred.permute(0, 2, 1, 3).reshape(bs, -1, seq_len)
return Return(pred=pred)
@dataclass
class GCNEncoderConfig(BaseConfig):
in_channels: int
    in_features: int = 3  # features for one node
seq_len: int = 40
seq_len_future: int = 3
num_res_blocks: int = 2
model_channels: int = 32
out_channels: int = 32
dropout: float = 0.1
channel_mult: Tuple[int] = (1, 2, 4)
use_time_condition: bool = False
def make_model(self):
return GCNEncoderModel(self)
class GCNEncoderModel(nn.Module):
def __init__(self, conf: GCNEncoderConfig):
super().__init__()
self.conf = conf
self.dtype = th.float32
assert conf.in_channels%conf.in_features == 0
self.in_features = conf.in_features
self.node_n = conf.in_channels//conf.in_features
ch = int(conf.channel_mult[0] * conf.model_channels)
self.input_blocks = nn.ModuleList([
graph_convolution(in_features=self.in_features, out_features=ch, node_n=self.node_n, seq_len=conf.seq_len),
])
input_block_chans = [ch]
ds = 1
resolution = conf.seq_len
for level, mult in enumerate(conf.channel_mult):
for _ in range(conf.num_res_blocks):
layers = [
residual_graph_convolution_config(
in_features=ch,
seq_len=resolution,
emb_channels = None,
dropout=conf.dropout,
out_features=int(mult * conf.model_channels),
node_n=self.node_n,
use_condition=conf.use_time_condition,
).make_model()
]
ch = int(mult * conf.model_channels)
self.input_blocks.append(*layers)
input_block_chans.append(ch)
if level != len(conf.channel_mult) - 1:
resolution //= 2
out_ch = ch
self.input_blocks.append(
graph_downsample())
ch = out_ch
input_block_chans.append(ch)
ds *= 2
self.hand_prediction = nn.Sequential(
conv_nd(1, ch*2, ch*2, 3, padding=1),
nn.LayerNorm([ch*2, conf.seq_len_future]),
nn.Tanh(),
conv_nd(1, ch*2, self.in_features*2, 1),
nn.Tanh(),
)
self.head_prediction = nn.Sequential(
conv_nd(1, ch, ch, 3, padding=1),
nn.LayerNorm([ch, conf.seq_len_future]),
nn.Tanh(),
conv_nd(1, ch, self.in_features, 1),
nn.Tanh(),
)
self.out = nn.Sequential(
nn.AdaptiveAvgPool1d(1),
conv_nd(1, ch*self.node_n, conf.out_channels, 1),
nn.Flatten(),
)
def forward(self, x, t=None):
bs, channels, seq_len = x.shape
if self.node_n == 3: # both hand and head
hand_last = x[:, :6, -1:].expand(-1, -1, self.conf.seq_len_future).clone() #last observed hand position
head_last = x[:, 6:, -1:].expand(-1, -1, self.conf.seq_len_future).clone()# last observed head orientation
if self.node_n == 2: # hand only
hand_last = x[:, :, -1:].expand(-1, -1, self.conf.seq_len_future).clone() #last observed hand position
if self.node_n == 1: # head only
head_last = x[:, :, -1:].expand(-1, -1, self.conf.seq_len_future).clone()# last observed head orientation
x = x.reshape(bs, self.node_n, self.in_features, seq_len).permute(0, 2, 1, 3)
h = x.type(self.dtype)
for module in self.input_blocks:
h = module(h)
h = h.type(x.dtype)
h = h.float()
bs, features, node_n, seq_len = h.shape
if self.node_n == 3: # both hand and head
hand_features = h[:, :, :2, -self.conf.seq_len_future:].reshape(bs, features*2, -1)
head_features = h[:, :, 2:, -self.conf.seq_len_future:].reshape(bs, features, -1)
pred_hand = self.hand_prediction(hand_features) + hand_last
pred_head = self.head_prediction(head_features) + head_last
pred_head = F.normalize(pred_head, dim=1)# normalize head orientation to unit vectors
if self.node_n == 2: # hand only
hand_features = h[:, :, :, -self.conf.seq_len_future:].reshape(bs, features*2, -1)
pred_hand = self.hand_prediction(hand_features) + hand_last
pred_head = None
if self.node_n == 1: # head only
head_features = h[:, :, :, -self.conf.seq_len_future:].reshape(bs, features, -1)
pred_head = self.head_prediction(head_features) + head_last
pred_head = F.normalize(pred_head, dim=1)# normalize head orientation to unit vectors
pred_hand = None
h = h.reshape(bs, features*node_n, seq_len)
h = self.out(h)
return h, pred_hand, pred_head
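# Minimal usage sketch (illustrative only; sizes are assumptions): the semantic encoder
# returns a latent code together with auxiliary short-term forecasts of the hand positions
# and the head orientation.
def _example_gcn_encoder():
    enc = GCNEncoderConfig(in_channels=9, seq_len=40, seq_len_future=3,
                           out_channels=32).make_model()
    x = th.randn(2, 9, 40)  # left/right hand positions + head orientation, 40 observed frames
    cond, pred_hand, pred_head = enc(x)
    return cond.shape, pred_hand.shape, pred_head.shape  # [2, 32], [2, 6, 3], [2, 3, 3]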
@dataclass
class CNNEncoderConfig(BaseConfig):
in_channels: int
seq_len: int = 40
seq_len_future: int = 3
out_channels: int = 128
def make_model(self):
return CNNEncoderModel(self)
class CNNEncoderModel(nn.Module):
def __init__(self, conf: CNNEncoderConfig):
super().__init__()
self.conf = conf
self.dtype = th.float32
input_dim = conf.in_channels
length = conf.seq_len
out_channels = conf.out_channels
self.encoder = nn.Sequential(
nn.Conv1d(input_dim, 32, kernel_size=3, padding=1),
nn.LayerNorm([32, length]),
nn.ReLU(inplace=True),
nn.Conv1d(32, 32, kernel_size=3, padding=1),
nn.LayerNorm([32, length]),
nn.ReLU(inplace=True),
nn.Conv1d(32, 32, kernel_size=3, padding=1),
nn.LayerNorm([32, length]),
nn.ReLU(inplace=True)
)
self.out = nn.Linear(32 * length, out_channels)
def forward(self, x, t=None):
bs, channels, seq_len = x.shape
hand_last = x[:, :6, -1:].expand(-1, -1, self.conf.seq_len_future).clone() #last observed hand position
head_last = x[:, 6:, -1:].expand(-1, -1, self.conf.seq_len_future).clone()# last observed head orientation
h = x.type(self.dtype)
h = self.encoder(h)
h = h.view(h.shape[0], -1)
h = h.type(x.dtype)
h = h.float()
h = self.out(h)
return h, hand_last, head_last
@dataclass
class GRUEncoderConfig(BaseConfig):
in_channels: int
seq_len: int = 40
seq_len_future: int = 3
out_channels: int = 128
def make_model(self):
return GRUEncoderModel(self)
class GRUEncoderModel(nn.Module):
def __init__(self, conf: GRUEncoderConfig):
super().__init__()
self.conf = conf
self.dtype = th.float32
input_dim = conf.in_channels
length = conf.seq_len
feature_channels = 32
out_channels = conf.out_channels
self.encoder = nn.GRU(input_dim, feature_channels, 1, batch_first=True)
self.out = nn.Linear(feature_channels * length, out_channels)
def forward(self, x, t=None):
bs, channels, seq_len = x.shape
hand_last = x[:, :6, -1:].expand(-1, -1, self.conf.seq_len_future).clone() #last observed hand position
head_last = x[:, 6:, -1:].expand(-1, -1, self.conf.seq_len_future).clone()# last observed head orientation
h = x.type(self.dtype)
h, _ = self.encoder(h.permute(0, 2, 1))
h = h.reshape(h.shape[0], -1)
h = h.type(x.dtype)
h = h.float()
h = self.out(h)
return h, hand_last, head_last
@dataclass
class LSTMEncoderConfig(BaseConfig):
in_channels: int
seq_len: int = 40
seq_len_future: int = 3
out_channels: int = 128
def make_model(self):
return LSTMEncoderModel(self)
class LSTMEncoderModel(nn.Module):
def __init__(self, conf: LSTMEncoderConfig):
super().__init__()
self.conf = conf
self.dtype = th.float32
input_dim = conf.in_channels
length = conf.seq_len
feature_channels = 32
out_channels = conf.out_channels
self.encoder = nn.LSTM(input_dim, feature_channels, 1, batch_first=True)
self.out = nn.Linear(feature_channels * length, out_channels)
def forward(self, x, t=None):
bs, channels, seq_len = x.shape
hand_last = x[:, :6, -1:].expand(-1, -1, self.conf.seq_len_future).clone() #last observed hand position
head_last = x[:, 6:, -1:].expand(-1, -1, self.conf.seq_len_future).clone()# last observed head orientation
h = x.type(self.dtype)
h, _ = self.encoder(h.permute(0, 2, 1))
h = h.reshape(h.shape[0], -1)
h = h.type(x.dtype)
h = h.float()
h = self.out(h)
return h, hand_last, head_last
@dataclass
class MLPEncoderConfig(BaseConfig):
in_channels: int
seq_len: int = 40
seq_len_future: int = 3
out_channels: int = 128
def make_model(self):
return MLPEncoderModel(self)
class MLPEncoderModel(nn.Module):
def __init__(self, conf: MLPEncoderConfig):
super().__init__()
self.conf = conf
self.dtype = th.float32
input_dim = conf.in_channels
length = conf.seq_len
out_channels = conf.out_channels
linear_size = 128
self.encoder = nn.Sequential(
nn.Linear(length*input_dim, linear_size),
nn.LayerNorm([linear_size]),
nn.ReLU(inplace=True),
nn.Linear(linear_size, linear_size),
nn.LayerNorm([linear_size]),
nn.ReLU(inplace=True),
)
self.out = nn.Linear(linear_size, out_channels)
def forward(self, x, t=None):
bs, channels, seq_len = x.shape
hand_last = x[:, :6, -1:].expand(-1, -1, self.conf.seq_len_future).clone() #last observed hand position
head_last = x[:, 6:, -1:].expand(-1, -1, self.conf.seq_len_future).clone()# last observed head orientation
h = x.type(self.dtype)
h = h.view(h.shape[0], -1)
h = self.encoder(h)
h = h.type(x.dtype)
h = h.float()
h = self.out(h)
return h, hand_last, head_last

418
model/unet_autoenc.py Normal file

@ -0,0 +1,418 @@
from enum import Enum
import torch, pdb
import os
from torch import Tensor
from torch.nn.functional import silu
from .unet import *
from choices import *
@dataclass
class BeatGANsAutoencConfig(BeatGANsUNetConfig):
seq_len_future: int = 3
enc_out_channels: int = 128
semantic_encoder_type: str = 'gcn'
enc_channel_mult: Tuple[int] = None
def make_model(self):
return BeatGANsAutoencModel(self)
class BeatGANsAutoencModel(BeatGANsUNetModel):
def __init__(self, conf: BeatGANsAutoencConfig):
super().__init__(conf)
self.conf = conf
# having only time, cond
self.time_embed = TimeStyleSeperateEmbed(
time_channels=conf.model_channels,
time_out_channels=conf.embed_channels,
)
if conf.semantic_encoder_type == 'gcn':
self.encoder = GCNEncoderConfig(
seq_len=conf.seq_len,
seq_len_future=conf.seq_len_future,
in_channels=conf.in_channels,
model_channels=16,
out_channels=conf.enc_out_channels,
channel_mult=conf.enc_channel_mult or conf.channel_mult,
).make_model()
elif conf.semantic_encoder_type == '1dcnn':
self.encoder = CNNEncoderConfig(
seq_len=conf.seq_len,
seq_len_future=conf.seq_len_future,
in_channels=conf.in_channels,
out_channels=conf.enc_out_channels,
).make_model()
elif conf.semantic_encoder_type == 'gru':
# ensure deterministic behavior of RNNs
os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:2"
self.encoder = GRUEncoderConfig(
seq_len=conf.seq_len,
seq_len_future=conf.seq_len_future,
in_channels=conf.in_channels,
out_channels=conf.enc_out_channels,
).make_model()
elif conf.semantic_encoder_type == 'lstm':
# ensure deterministic behavior of RNNs
os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:2"
self.encoder = LSTMEncoderConfig(
seq_len=conf.seq_len,
seq_len_future=conf.seq_len_future,
in_channels=conf.in_channels,
out_channels=conf.enc_out_channels,
).make_model()
elif conf.semantic_encoder_type == 'mlp':
self.encoder = MLPEncoderConfig(
seq_len=conf.seq_len,
seq_len_future=conf.seq_len_future,
in_channels=conf.in_channels,
out_channels=conf.enc_out_channels,
).make_model()
else:
raise NotImplementedError()
def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
"""
        Reparameterization trick: sample from N(mu, var) using a draw from N(0, 1).
:param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
:return: (Tensor) [B x D]
"""
assert self.conf.is_stochastic
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return eps * std + mu
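    # In other words z = mu + exp(0.5 * logvar) * eps with eps ~ N(0, I), so gradients can
    # flow through mu and logvar while the sampling noise stays outside the graph.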
def sample_z(self, n: int, device):
assert self.conf.is_stochastic
return torch.randn(n, self.conf.enc_out_channels, device=device)
def noise_to_cond(self, noise: Tensor):
raise NotImplementedError()
assert self.conf.noise_net_conf is not None
return self.noise_net.forward(noise)
def encode(self, x):
cond, pred_hand, pred_head = self.encoder.forward(x)
return cond, pred_hand, pred_head
@property
def stylespace_sizes(self):
modules = list(self.input_blocks.modules()) + list(
self.middle_block.modules()) + list(self.output_blocks.modules())
sizes = []
for module in modules:
if isinstance(module, ResBlock):
linear = module.cond_emb_layers[-1]
sizes.append(linear.weight.shape[0])
return sizes
def encode_stylespace(self, x, return_vector: bool = True):
"""
encode to style space
"""
modules = list(self.input_blocks.modules()) + list(
self.middle_block.modules()) + list(self.output_blocks.modules())
# (n, c)
cond = self.encoder.forward(x)
S = []
for module in modules:
if isinstance(module, ResBlock):
# (n, c')
s = module.cond_emb_layers.forward(cond)
S.append(s)
if return_vector:
# (n, sum_c)
return torch.cat(S, dim=1)
else:
return S
def forward(self,
x,
t,
x_start=None,
cond=None,
style=None,
noise=None,
t_cond=None,
**kwargs):
"""
Apply the model to an input batch.
Args:
            x_start: the original input sequence to encode
cond: output of the encoder
noise: random noise (to predict the cond)
"""
if t_cond is None:
t_cond = t ## randomly sampled timestep with the size of [batch_size]
        # track whether cond was supplied by the caller (directly or via noise);
        # otherwise cond_given would be undefined when cond is passed in explicitly
        cond_given = cond is not None
        if noise is not None:
            # if the noise is given, we predict the cond from noise
            cond = self.noise_to_cond(noise)
            cond_given = True
        if cond is None:
if x is not None:
assert len(x) == len(x_start), f'{len(x)} != {len(x_start)}'
cond, pred_hand, pred_head = self.encode(x_start)
if t is not None: ## t==t_cond
_t_emb = timestep_embedding(t, self.conf.model_channels)
#print("t: {}, _t_emb:{}".format(t, _t_emb))
_t_cond_emb = timestep_embedding(t_cond, self.conf.model_channels)
#print("t_cond: {}, _t_cond_emb:{}".format(t, _t_cond_emb))
else:
# this happens when training only autoenc
_t_emb = None
_t_cond_emb = None
if self.conf.resnet_two_cond:
res = self.time_embed.forward( ## self.time_embed is an MLP
time_emb=_t_emb,
cond=cond,
time_cond_emb=_t_cond_emb,
)
else:
raise NotImplementedError()
if self.conf.resnet_two_cond:
# two cond: first = time emb, second = cond_emb
emb = res.time_emb
cond_emb = res.emb
else:
# one cond = combined of both time and cond
emb = res.emb
cond_emb = None
# override the style if given
style = style or res.style ## style==None, res.style: cond, torch.Size([64, 512])
# where in the model to supply time conditions
enc_time_emb = emb ## time embeddings
mid_time_emb = emb
dec_time_emb = emb
# where in the model to supply style conditions
enc_cond_emb = cond_emb ## z_sem embeddings
mid_cond_emb = cond_emb
dec_cond_emb = cond_emb
# hs = []
hs = [[] for _ in range(len(self.conf.channel_mult))]
if x is not None:
h = x.type(self.dtype)
# input blocks
k = 0
for i in range(len(self.input_num_blocks)):
for j in range(self.input_num_blocks[i]):
h = self.input_blocks[k](h,
emb=enc_time_emb,
cond=enc_cond_emb)
# print(i, j, h.shape)
'''if h.shape[-1]%2==1:
pdb.set_trace()'''
hs[i].append(h)
k += 1
assert k == len(self.input_blocks)
# middle blocks
h = self.middle_block(h, emb=mid_time_emb, cond=mid_cond_emb)
else:
# no lateral connections
            # happens when training only the autoencoder
h = None
hs = [[] for _ in range(len(self.conf.channel_mult))]
pdb.set_trace()
# output blocks
k = 0
for i in range(len(self.output_num_blocks)):
for j in range(self.output_num_blocks[i]):
                # take the lateral connection from the same resolution level (in reverse order);
                # once there are none left, use None
try:
lateral = hs[-i - 1].pop() ## in the reverse order (symmetric)
except IndexError:
lateral = None
'''print(i, j, lateral.shape, h.shape)
if lateral.shape[-1]!=h.shape[-1]:
pdb.set_trace()'''
# print("h is", h.size())
# print("lateral is", lateral.size())
h = self.output_blocks[k](h,
emb=dec_time_emb,
cond=dec_cond_emb,
lateral=lateral)
k += 1
pred = self.out(h)
# print("h:", h.shape)
# print("pred:", pred.shape)
        if cond_given:
return AutoencReturn(pred=pred, cond=cond)
else:
return AutoencReturn(pred=pred, cond=cond, pred_hand=pred_hand, pred_head=pred_head)
@dataclass
class GCNAutoencConfig(GCNUNetConfig):
# number of style channels
enc_out_channels: int = 256
enc_channel_mult: Tuple[int] = None
def make_model(self):
return GCNAutoencModel(self)
class GCNAutoencModel(GCNUNetModel):
def __init__(self, conf: GCNAutoencConfig):
super().__init__(conf)
self.conf = conf
# having only time, cond
self.time_emb_channels = conf.model_channels
self.time_embed = TimeStyleSeperateEmbed(
time_channels=self.time_emb_channels,
time_out_channels=conf.embed_channels,
)
self.encoder = GCNEncoderConfig(
seq_len=conf.seq_len,
in_channels=conf.in_channels,
model_channels=32,
out_channels=conf.enc_out_channels,
channel_mult=conf.enc_channel_mult or conf.channel_mult,
).make_model()
def encode(self, x):
cond = self.encoder.forward(x)
return {'cond': cond}
def forward(self,
x,
t,
x_start=None,
cond=None,
**kwargs):
"""
Apply the model to an input batch.
Args:
x_start: the original input to encode
cond: output of the encoder
"""
if cond is None:
if x is not None:
assert len(x) == len(x_start), f'{len(x)} != {len(x_start)}'
tmp = self.encode(x_start)
cond = tmp['cond']
if t is not None:
_t_emb = timestep_embedding(t, self.time_emb_channels)
else:
# this happens when training only autoenc
_t_emb = None
if self.conf.resnet_two_cond:
res = self.time_embed.forward( ## self.time_embed is an MLP
time_emb=_t_emb,
cond=cond,
)
# two cond: first = time emb, second = cond_emb
emb = res.time_emb
cond_emb = res.emb
else:
raise NotImplementedError()
# where in the model to supply time conditions
enc_time_emb = emb ## time embeddings
mid_time_emb = emb
dec_time_emb = emb
enc_cond_emb = cond_emb ## z_sem embeddings
mid_cond_emb = cond_emb
dec_cond_emb = cond_emb
bs, channels, seq_len = x.shape
x = x.reshape(bs, self.conf.node_n, self.in_features, seq_len).permute(0, 2, 1, 3)
hs = [[] for _ in range(len(self.conf.channel_mult))]
h = x.type(self.dtype)
# input blocks
k = 0
for i in range(len(self.input_num_blocks)):
for j in range(self.input_num_blocks[i]):
h = self.input_blocks[k](h,
emb=enc_time_emb,
cond=enc_cond_emb)
hs[i].append(h)
k += 1
assert k == len(self.input_blocks)
# output blocks
k = 0
for i in range(len(self.output_num_blocks)):
for j in range(self.output_num_blocks[i]):
                # take the lateral connection from the same resolution level (in reverse order);
                # once there are none left, use None
try:
lateral = hs[-i - 1].pop() ## in the reverse order (symmetric)
except IndexError:
lateral = None
h = self.output_blocks[k](h,
emb=dec_time_emb,
cond=dec_cond_emb,
lateral=lateral)
k += 1
pred = self.out(h)
pred = pred.permute(0, 2, 1, 3).reshape(bs, -1, seq_len)
return AutoencReturn(pred=pred, cond=cond)
class AutoencReturn(NamedTuple):
pred: Tensor
cond: Tensor = None
pred_hand: Tensor = None
pred_head: Tensor = None
class EmbedReturn(NamedTuple):
# style and time
emb: Tensor = None
# time only
time_emb: Tensor = None
# style only (but could depend on time)
style: Tensor = None
class TimeStyleSeperateEmbed(nn.Module):
# embed only style
def __init__(self, time_channels, time_out_channels):
super().__init__()
self.time_embed = nn.Sequential(
linear(time_channels, time_out_channels),
nn.SiLU(),
linear(time_out_channels, time_out_channels),
)
self.style = nn.Identity()
def forward(self, time_emb=None, cond=None, **kwargs):
if time_emb is None:
# happens with autoenc training mode
time_emb = None
else:
time_emb = self.time_embed(time_emb)
style = self.style(cond) ## style==cond
return EmbedReturn(emb=style, time_emb=time_emb, style=style)
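# Minimal usage sketch (illustrative only). enc_out_channels is set equal to embed_channels
# here on the assumption that the UNet's ResBlocks expect the style vector to have the same
# width as the time embedding; the training configs in config.py may use different values.
def _example_beatgans_autoenc():
    conf = BeatGANsAutoencConfig(seq_len=80, in_channels=9, out_channels=9,
                                 embed_channels=256, enc_out_channels=256,
                                 semantic_encoder_type='gcn')
    model = conf.make_model()
    x_start = torch.randn(2, 9, 80)  # clean sequence, encoded into the semantic cond
    x = torch.randn(2, 9, 80)        # noisy sequence to denoise
    t = torch.randint(0, 1000, (2, ))
    ret = model(x, t, x_start=x_start)
    return ret.pred.shape, ret.cond.shape  # -> [2, 9, 80], [2, 256]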

193
preprocess.py Normal file

@ -0,0 +1,193 @@
import os, random, math, copy
import pandas as pd
import numpy as np
import pickle as pkl
import logging, sys
from torch.utils.data import DataLoader,Dataset
import multiprocessing as mp
import json
import matplotlib.pyplot as plt
def MakeDir(dirpath):
if not os.path.exists(dirpath):
os.makedirs(dirpath)
def load_egobody(data_dir, seq_len, sample_rate=1, train=1):
data_dir_train = data_dir + 'train/'
data_dir_test = data_dir + 'test/'
if train == 0:
data_dirs = [data_dir_test] # test
elif train == 1:
data_dirs = [data_dir_train] # train
elif train == 2:
data_dirs = [data_dir_train, data_dir_test] # train + test
hand_head = []
for data_dir in data_dirs:
file_paths = sorted(os.listdir(data_dir))
pose_xyz_file_paths = []
head_file_paths = []
for path in file_paths:
path_split = path.split('_')
data_type = path_split[-1][:-4]
if(data_type == 'xyz'):
pose_xyz_file_paths.append(path)
if(data_type == 'head'):
head_file_paths.append(path)
file_num = len(pose_xyz_file_paths)
for i in range(file_num):
pose_data = np.load(data_dir + pose_xyz_file_paths[i])
head_data = np.load(data_dir + head_file_paths[i])
num_frames = pose_data.shape[0]
if num_frames < seq_len:
continue
head_pos = pose_data[:, 15*3:16*3]
left_hand_pos = pose_data[:, 20*3:21*3]
right_hand_pos = pose_data[:, 21*3:22*3]
head_ori = head_data
left_hand_pos -= head_pos # convert hand positions to head coordinate system
right_hand_pos -= head_pos
hand_head_data = left_hand_pos
hand_head_data = np.concatenate((hand_head_data, right_hand_pos), axis=1)
hand_head_data = np.concatenate((hand_head_data, head_ori), axis=1)
fs = np.arange(0, num_frames - seq_len + 1)
fs_sel = fs
for i in np.arange(seq_len - 1):
fs_sel = np.vstack((fs_sel, fs + i + 1))
fs_sel = fs_sel.transpose()
seq_sel = hand_head_data[fs_sel, :]
seq_sel = seq_sel[0::sample_rate, :, :]
if len(hand_head) == 0:
hand_head = seq_sel
else:
hand_head = np.concatenate((hand_head, seq_sel), axis=0)
hand_head = np.transpose(hand_head, (0, 2, 1))
return hand_head
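# Each loader in this file returns an array of shape [num_windows, 9, seq_len]:
# channels 0-2 hold the left-hand position and channels 3-5 the right-hand position (both
# expressed relative to the head position), and channels 6-8 hold the head orientation.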
def load_adt(data_dir, seq_len, sample_rate=1, train=1):
data_dir_train = data_dir + 'train/'
data_dir_test = data_dir + 'test/'
if train == 0:
data_dirs = [data_dir_test] # test
elif train == 1:
data_dirs = [data_dir_train] # train
elif train == 2:
data_dirs = [data_dir_train, data_dir_test] # train + test
hand_head = []
for data_dir in data_dirs:
file_paths = sorted(os.listdir(data_dir))
pose_xyz_file_paths = []
head_file_paths = []
for path in file_paths:
path_split = path.split('_')
data_type = path_split[-1][:-4]
if(data_type == 'xyz'):
pose_xyz_file_paths.append(path)
if(data_type == 'head'):
head_file_paths.append(path)
file_num = len(pose_xyz_file_paths)
for i in range(file_num):
pose_data = np.load(data_dir + pose_xyz_file_paths[i])
head_data = np.load(data_dir + head_file_paths[i])
num_frames = pose_data.shape[0]
if num_frames < seq_len:
continue
head_pos = pose_data[:, 4*3:5*3]
left_hand_pos = pose_data[:, 8*3:9*3]
right_hand_pos = pose_data[:, 12*3:13*3]
head_ori = head_data
left_hand_pos -= head_pos # convert hand positions to head coordinate system
right_hand_pos -= head_pos
hand_head_data = left_hand_pos
hand_head_data = np.concatenate((hand_head_data, right_hand_pos), axis=1)
hand_head_data = np.concatenate((hand_head_data, head_ori), axis=1)
fs = np.arange(0, num_frames - seq_len + 1)
fs_sel = fs
for i in np.arange(seq_len - 1):
fs_sel = np.vstack((fs_sel, fs + i + 1))
fs_sel = fs_sel.transpose()
seq_sel = hand_head_data[fs_sel, :]
seq_sel = seq_sel[0::sample_rate, :, :]
if len(hand_head) == 0:
hand_head = seq_sel
else:
hand_head = np.concatenate((hand_head, seq_sel), axis=0)
hand_head = np.transpose(hand_head, (0, 2, 1))
return hand_head
def load_gimo(data_dir, seq_len, sample_rate=1, train=1):
data_dir_train = data_dir + 'train/'
data_dir_test = data_dir + 'test/'
if train == 0:
data_dirs = [data_dir_test] # test
elif train == 1:
data_dirs = [data_dir_train] # train
elif train == 2:
data_dirs = [data_dir_train, data_dir_test] # train + test
hand_head = []
for data_dir in data_dirs:
file_paths = sorted(os.listdir(data_dir))
pose_xyz_file_paths = []
head_file_paths = []
for path in file_paths:
path_split = path.split('_')
data_type = path_split[-1][:-4]
if(data_type == 'xyz'):
pose_xyz_file_paths.append(path)
if(data_type == 'head'):
head_file_paths.append(path)
file_num = len(pose_xyz_file_paths)
for i in range(file_num):
pose_data = np.load(data_dir + pose_xyz_file_paths[i])
head_data = np.load(data_dir + head_file_paths[i])
num_frames = pose_data.shape[0]
if num_frames < seq_len:
continue
head_pos = pose_data[:, 15*3:16*3]
left_hand_pos = pose_data[:, 20*3:21*3]
right_hand_pos = pose_data[:, 21*3:22*3]
head_ori = head_data
left_hand_pos -= head_pos # convert hand positions to head coordinate system
right_hand_pos -= head_pos
hand_head_data = left_hand_pos
hand_head_data = np.concatenate((hand_head_data, right_hand_pos), axis=1)
hand_head_data = np.concatenate((hand_head_data, head_ori), axis=1)
fs = np.arange(0, num_frames - seq_len + 1)
fs_sel = fs
for i in np.arange(seq_len - 1):
fs_sel = np.vstack((fs_sel, fs + i + 1))
fs_sel = fs_sel.transpose()
seq_sel = hand_head_data[fs_sel, :]
seq_sel = seq_sel[0::sample_rate, :, :]
if len(hand_head) == 0:
hand_head = seq_sel
else:
hand_head = np.concatenate((hand_head, seq_sel), axis=0)
hand_head = np.transpose(hand_head, (0, 2, 1))
return hand_head
if __name__ == "__main__":
data_dir = "/scratch/hu/pose_forecast/egobody_pose2gaze/"
seq_len = 40
test_data = load_egobody(data_dir, seq_len, sample_rate=10, train=0)
print("\ndataset size: {}".format(test_data.shape))

3
train.sh Normal file

@ -0,0 +1,3 @@
#python main.py --gpus 7 --mode 'train' --model_name 'haheae';
python main.py --gpus 7 --mode 'eval' --model_name 'haheae';