update readme
This commit is contained in: parent 35ee4b75e8, commit 249a01f342
18 changed files with 4936 additions and 0 deletions
BIN  checkpoints/haheae/last.ckpt  (new file)
Binary file not shown.
choices.py  (new file, 179 lines)
@@ -0,0 +1,179 @@
from enum import Enum

from torch import nn


class TrainMode(Enum):
    # manipulate mode = training the classifier
    manipulate = 'manipulate'
    # default training mode!
    diffusion = 'diffusion'
    # default latent training mode!
    # fitting a diffusion model to a given latent
    latent_diffusion = 'latentdiffusion'

    def is_manipulate(self):
        return self in [
            TrainMode.manipulate,
        ]

    def is_diffusion(self):
        return self in [
            TrainMode.diffusion,
            TrainMode.latent_diffusion,
        ]

    def is_autoenc(self):
        # the network possibly does autoencoding
        return self in [
            TrainMode.diffusion,
        ]

    def is_latent_diffusion(self):
        return self in [
            TrainMode.latent_diffusion,
        ]

    def use_latent_net(self):
        return self.is_latent_diffusion()

    def require_dataset_infer(self):
        """
        whether training in this mode requires the latent variables to be available
        """
        # this will precalculate all the latents beforehand
        # and the dataset will be all the predicted latents
        return self in [
            TrainMode.latent_diffusion,
            TrainMode.manipulate,
        ]


class ManipulateMode(Enum):
    """
    how to train the classifier to manipulate
    """
    # train on the whole CelebA attribute dataset
    celebahq_all = 'celebahq_all'
    # CelebA with D2C's crop
    d2c_fewshot = 'd2cfewshot'
    d2c_fewshot_allneg = 'd2cfewshotallneg'

    def is_celeba_attr(self):
        return self in [
            ManipulateMode.d2c_fewshot,
            ManipulateMode.d2c_fewshot_allneg,
            ManipulateMode.celebahq_all,
        ]

    def is_single_class(self):
        return self in [
            ManipulateMode.d2c_fewshot,
            ManipulateMode.d2c_fewshot_allneg,
        ]

    def is_fewshot(self):
        return self in [
            ManipulateMode.d2c_fewshot,
            ManipulateMode.d2c_fewshot_allneg,
        ]

    def is_fewshot_allneg(self):
        return self in [
            ManipulateMode.d2c_fewshot_allneg,
        ]


class ModelType(Enum):
    """
    Kinds of backbone models
    """

    # unconditional ddpm
    ddpm = 'ddpm'
    # autoencoding ddpm cannot do unconditional generation
    autoencoder = 'autoencoder'

    def has_autoenc(self):
        return self in [
            ModelType.autoencoder,
        ]

    def can_sample(self):
        return self in [ModelType.ddpm]


class ModelName(Enum):
    """
    List of all supported model classes
    """

    beatgans_ddpm = 'beatgans_ddpm'
    beatgans_autoenc = 'beatgans_autoenc'


class ModelMeanType(Enum):
    """
    Which type of output the model predicts.
    """

    eps = 'eps'  # the model predicts epsilon


class ModelVarType(Enum):
    """
    What is used as the model's output variance.

    The LEARNED_RANGE option has been added to allow the model to predict
    values between FIXED_SMALL and FIXED_LARGE, making its job easier.
    """

    # posterior beta_t
    fixed_small = 'fixed_small'
    # beta_t
    fixed_large = 'fixed_large'


class LossType(Enum):
    mse = 'mse'  # use raw MSE loss (and KL when learning variances)
    l1 = 'l1'


class GenerativeType(Enum):
    """
    How a sample is generated
    """

    ddpm = 'ddpm'
    ddim = 'ddim'


class OptimizerType(Enum):
    adam = 'adam'
    adamw = 'adamw'


class Activation(Enum):
    none = 'none'
    relu = 'relu'
    lrelu = 'lrelu'
    silu = 'silu'
    tanh = 'tanh'

    def get_act(self):
        if self == Activation.none:
            return nn.Identity()
        elif self == Activation.relu:
            return nn.ReLU()
        elif self == Activation.lrelu:
            return nn.LeakyReLU(negative_slope=0.2)
        elif self == Activation.silu:
            return nn.SiLU()
        elif self == Activation.tanh:
            return nn.Tanh()
        else:
            raise NotImplementedError()


class ManipulateLossType(Enum):
    bce = 'bce'
    mse = 'mse'
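A quick sketch of how these enums are consumed by the rest of the commit (the surrounding script lines are illustrative, not part of the diff):

    from choices import Activation, TrainMode

    mode = TrainMode.diffusion
    assert mode.is_diffusion() and mode.is_autoenc()
    assert not mode.require_dataset_infer()   # no precomputed latents needed

    act = Activation.lrelu.get_act()          # nn.LeakyReLU(negative_slope=0.2)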
config.py  (new file, 153 lines)
@@ -0,0 +1,153 @@
from model.blocks import *
from diffusion.resample import UniformSampler
from dataclasses import dataclass
from diffusion.diffusion import space_timesteps
from typing import Tuple
from config_base import BaseConfig
from diffusion import *
from diffusion.base import GenerativeType, LossType, ModelMeanType, ModelVarType, get_named_beta_schedule
from model import *
from choices import *
from preprocess import *
import os


@dataclass
class TrainConfig(BaseConfig):
    name: str = ''
    base_dir: str = './checkpoints/'
    logdir: str = f'{base_dir}{name}'
    data_name: str = ''
    data_val_name: str = ''
    seq_len: int = 40  # for reconstruction
    seq_len_future: int = 3  # for prediction
    in_channels = 9
    fp16: bool = True
    lr: float = 1e-4
    ema_decay: float = 0.9999
    seed: int = 0  # random seed
    batch_size: int = 64
    accum_batches: int = 1
    batch_size_eval: int = 1024
    total_epochs: int = 1_000
    save_every_epochs: int = 10
    eval_every_epochs: int = 10
    train_mode: TrainMode = TrainMode.diffusion
    T: int = 1000
    T_eval: int = 100
    diffusion_type: str = 'beatgans'
    semantic_encoder_type: str = 'gcn'
    net_beatgans_embed_channels: int = 128
    beatgans_gen_type: GenerativeType = GenerativeType.ddim
    beatgans_loss_type: LossType = LossType.mse
    hand_mse_factor = 1.0
    head_mse_factor = 1.0
    beatgans_model_mean_type: ModelMeanType = ModelMeanType.eps
    beatgans_model_var_type: ModelVarType = ModelVarType.fixed_large
    beatgans_rescale_timesteps: bool = False
    beta_scheduler: str = 'linear'
    net_ch: int = 64
    net_ch_mult: Tuple[int, ...] = (1, 2, 4)
    net_enc_channel_mult: Tuple[int] = (1, 2, 4)
    grad_clip: float = 1
    optimizer: OptimizerType = OptimizerType.adam
    weight_decay: float = 0
    warmup: int = 0
    model_conf: ModelConfig = None
    model_name: ModelName = ModelName.beatgans_autoenc
    model_type: ModelType = None

    @property
    def batch_size_effective(self):
        return self.batch_size * self.accum_batches

    def _make_diffusion_conf(self, T=None):
        if self.diffusion_type == 'beatgans':
            # can use T < self.T for evaluation
            # follows the guided-diffusion repo conventions
            # t's are evenly spaced
            if self.beatgans_gen_type == GenerativeType.ddpm:
                section_counts = [T]
            elif self.beatgans_gen_type == GenerativeType.ddim:
                section_counts = f'ddim{T}'
            else:
                raise NotImplementedError()

            return SpacedDiffusionBeatGansConfig(
                gen_type=self.beatgans_gen_type,
                model_type=self.model_type,
                betas=get_named_beta_schedule(self.beta_scheduler, T),
                model_mean_type=self.beatgans_model_mean_type,
                model_var_type=self.beatgans_model_var_type,
                loss_type=self.beatgans_loss_type,
                rescale_timesteps=self.beatgans_rescale_timesteps,
                use_timesteps=space_timesteps(num_timesteps=T, section_counts=section_counts),
                fp16=self.fp16,
            )
        else:
            raise NotImplementedError()

    @property
    def model_out_channels(self):
        return self.in_channels

    @property
    def model_input_channels(self):
        return self.in_channels

    def make_T_sampler(self):
        return UniformSampler(self.T)

    def make_diffusion_conf(self):
        return self._make_diffusion_conf(self.T)

    def make_eval_diffusion_conf(self):
        return self._make_diffusion_conf(T=self.T_eval)

    def make_model_conf(self):
        cls = BeatGANsAutoencConfig
        if self.model_name == ModelName.beatgans_autoenc:
            self.model_type = ModelType.autoencoder
        else:
            raise NotImplementedError()

        self.model_conf = cls(
            semantic_encoder_type=self.semantic_encoder_type,
            channel_mult=self.net_ch_mult,
            seq_len=self.seq_len,
            seq_len_future=self.seq_len_future,
            embed_channels=self.net_beatgans_embed_channels,
            enc_out_channels=self.net_beatgans_embed_channels,
            enc_channel_mult=self.net_enc_channel_mult,
            in_channels=self.model_input_channels,
            model_channels=self.net_ch,
            out_channels=self.model_out_channels,
        )

        return self.model_conf


def egobody_autoenc(mode, encoder_type='gcn', hand_mse_factor=1.0, head_mse_factor=1.0, data_sample_rate=1, epoch=130, in_channels=9, seq_len=40):
    conf = TrainConfig()
    conf.seq_len = seq_len
    conf.seq_len_future = 3
    conf.in_channels = in_channels
    conf.net_beatgans_embed_channels = 128
    conf.net_ch = 64
    conf.net_ch_mult = (1, 1, 1)
    conf.semantic_encoder_type = encoder_type
    conf.hand_mse_factor = hand_mse_factor
    conf.head_mse_factor = head_mse_factor
    conf.net_enc_channel_mult = conf.net_ch_mult
    conf.total_epochs = epoch
    conf.save_every_epochs = 10
    conf.eval_every_epochs = 10
    conf.batch_size = 64
    conf.batch_size_eval = 1024 * 4

    conf.data_dir = "/scratch/hu/pose_forecast/egobody_pose2gaze/"
    conf.data_sample_rate = data_sample_rate
    conf.name = 'egobody_autoenc'
    conf.data_name = 'egobody'

    conf.mode = mode
    conf.make_model_conf()
    return conf
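A minimal usage sketch for this config (argument values are illustrative; egobody_autoenc and the sampler factories are from this file):

    conf = egobody_autoenc(mode='train', encoder_type='gcn', in_channels=9, seq_len=40)

    train_sampler = conf.make_diffusion_conf().make_sampler()      # T = 1000 steps
    eval_sampler = conf.make_eval_diffusion_conf().make_sampler()  # T_eval = 100 DDIM steps
    print(conf.batch_size_effective)  # batch_size * accum_batches = 64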
config_base.py  (new file, 72 lines)
@@ -0,0 +1,72 @@
import json
import os
from copy import deepcopy
from dataclasses import dataclass


@dataclass
class BaseConfig:
    def clone(self):
        return deepcopy(self)

    def inherit(self, another):
        """inherit common keys from a given config"""
        common_keys = set(self.__dict__.keys()) & set(another.__dict__.keys())
        for k in common_keys:
            setattr(self, k, getattr(another, k))

    def propagate(self):
        """push down the configuration to all members"""
        for k, v in self.__dict__.items():
            if isinstance(v, BaseConfig):
                v.inherit(self)
                v.propagate()

    def save(self, save_path):
        """save config to json file"""
        dirname = os.path.dirname(save_path)
        if not os.path.exists(dirname):
            os.makedirs(dirname)
        conf = self.as_dict_jsonable()
        with open(save_path, 'w') as f:
            json.dump(conf, f)

    def load(self, load_path):
        """load json config"""
        with open(load_path) as f:
            conf = json.load(f)
        self.from_dict(conf)

    def from_dict(self, dict, strict=False):
        for k, v in dict.items():
            if not hasattr(self, k):
                if strict:
                    raise ValueError(f"loading extra '{k}'")
                else:
                    print(f"loading extra '{k}'")
                    continue
            if isinstance(self.__dict__[k], BaseConfig):
                self.__dict__[k].from_dict(v)
            else:
                self.__dict__[k] = v

    def as_dict_jsonable(self):
        conf = {}
        for k, v in self.__dict__.items():
            if isinstance(v, BaseConfig):
                conf[k] = v.as_dict_jsonable()
            else:
                if jsonable(v):
                    conf[k] = v
                else:
                    # ignore not jsonable
                    pass
        return conf


def jsonable(x):
    try:
        json.dumps(x)
        return True
    except TypeError:
        return False
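A round-trip sketch for BaseConfig.save/load (the path is illustrative; non-jsonable fields such as enums and model_conf are silently skipped by as_dict_jsonable):

    conf = TrainConfig()
    conf.save('./checkpoints/egobody_autoenc/config.json')

    restored = TrainConfig()
    restored.load('./checkpoints/egobody_autoenc/config.json')
    assert restored.T == conf.T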
diffusion/__init__.py  (new file, 6 lines)
@@ -0,0 +1,6 @@
from typing import Union

from .diffusion import SpacedDiffusionBeatGans, SpacedDiffusionBeatGansConfig

Sampler = Union[SpacedDiffusionBeatGans]
SamplerConfig = Union[SpacedDiffusionBeatGansConfig]
diffusion/base.py  (new file, 1148 lines)
File diff suppressed because it is too large.
diffusion/diffusion.py  (new file, 182 lines)
@@ -0,0 +1,182 @@
from .base import *
from dataclasses import dataclass


def space_timesteps(num_timesteps, section_counts):
    """
    Create a list of timesteps to use from an original diffusion process,
    given the number of timesteps we want to take from equally-sized portions
    of the original process.

    For example, if there are 300 timesteps and the section counts are [10,15,20]
    then the first 100 timesteps are strided to be 10 timesteps, the second 100
    are strided to be 15 timesteps, and the final 100 are strided to be 20.

    If the stride is a string starting with "ddim", then the fixed striding
    from the DDIM paper is used, and only one section is allowed.

    :param num_timesteps: the number of diffusion steps in the original
                          process to divide up.
    :param section_counts: either a list of numbers, or a string containing
                           comma-separated numbers, indicating the step count
                           per section. As a special case, use "ddimN" where N
                           is a number of steps to use the striding from the
                           DDIM paper.
    :return: a set of diffusion steps from the original process to use.
    """
    if isinstance(section_counts, str):
        if section_counts.startswith("ddim"):
            desired_count = int(section_counts[len("ddim"):])
            for i in range(1, num_timesteps):
                if len(range(0, num_timesteps, i)) == desired_count:
                    return set(range(0, num_timesteps, i))
            raise ValueError(
                f"cannot create exactly {desired_count} steps with an integer stride"
            )
        section_counts = [int(x) for x in section_counts.split(",")]
    size_per = num_timesteps // len(section_counts)
    extra = num_timesteps % len(section_counts)
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(
                f"cannot divide section of {size} steps into {section_count}")
        if section_count <= 1:
            frac_stride = 1
        else:
            frac_stride = (size - 1) / (section_count - 1)
        cur_idx = 0.0
        taken_steps = []
        for _ in range(section_count):
            taken_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        all_steps += taken_steps
        start_idx += size
    return set(all_steps)


@dataclass
class SpacedDiffusionBeatGansConfig(GaussianDiffusionBeatGansConfig):
    use_timesteps: Tuple[int] = None

    def make_sampler(self):
        return SpacedDiffusionBeatGans(self)


class SpacedDiffusionBeatGans(GaussianDiffusionBeatGans):
    """
    A diffusion process which can skip steps in a base diffusion process.

    :param use_timesteps: a collection (sequence or set) of timesteps from the
                          original diffusion process to retain.
    :param kwargs: the kwargs to create the base diffusion process.
    """
    def __init__(self, conf: SpacedDiffusionBeatGansConfig):
        self.conf = conf
        self.use_timesteps = set(conf.use_timesteps)
        # how the new t's are mapped to the old t's
        self.timestep_map = []
        self.original_num_steps = len(conf.betas)

        base_diffusion = GaussianDiffusionBeatGans(conf)  # pylint: disable=missing-kwoa
        last_alpha_cumprod = 1.0
        new_betas = []
        for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
            if i in self.use_timesteps:
                # getting the new betas of the new timesteps
                new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
                last_alpha_cumprod = alpha_cumprod
                self.timestep_map.append(i)
        conf.betas = np.array(new_betas)
        super().__init__(conf)

    def p_mean_variance(self, model: Model, *args, **kwargs):  # pylint: disable=signature-differs
        return super().p_mean_variance(self._wrap_model(model), *args,
                                       **kwargs)

    def training_losses(self, model: Model, *args, **kwargs):  # pylint: disable=signature-differs
        return super().training_losses(self._wrap_model(model), *args,
                                       **kwargs)

    def condition_mean(self, cond_fn, *args, **kwargs):
        return super().condition_mean(self._wrap_model(cond_fn), *args,
                                      **kwargs)

    def condition_score(self, cond_fn, *args, **kwargs):
        return super().condition_score(self._wrap_model(cond_fn), *args,
                                       **kwargs)

    def _wrap_model(self, model: Model):
        if isinstance(model, _WrappedModel):
            return model
        return _WrappedModel(model, self.timestep_map, self.rescale_timesteps,
                             self.original_num_steps)

    def _scale_timesteps(self, t):
        # Scaling is done by the wrapped model.
        return t


class _WrappedModel:
    """
    Converts the supplied t's back to the original t's scale.
    """
    def __init__(self, model, timestep_map, rescale_timesteps,
                 original_num_steps):
        self.model = model
        self.timestep_map = timestep_map
        self.rescale_timesteps = rescale_timesteps
        self.original_num_steps = original_num_steps

    def forward(self, x, t, t_cond=None, **kwargs):
        """
        Args:
            t: t's with different ranges (can be << T due to smaller eval T) need to be converted to the original t's
            t_cond: the same as t but can be of different values
        """
        map_tensor = th.tensor(self.timestep_map,
                               device=t.device,
                               dtype=t.dtype)

        def do(t):
            new_ts = map_tensor[t]
            if self.rescale_timesteps:
                new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
            return new_ts

        if t_cond is not None:
            # support t_cond
            t_cond = do(t_cond)
        ## run_ffhq256.py None
        return self.model(x=x, t=do(t), t_cond=t_cond, **kwargs)  ## self.model.__class__ == model.unet_autoenc.BeatGANsAutoencModel

    def __getattr__(self, name):
        # allow for calling the model's methods
        if hasattr(self.model, name):
            func = getattr(self.model, name)
            return func
        raise AttributeError(name)

    # def __call__(self, x, t, t_cond=None, **kwargs):
    #     """
    #     Args:
    #         t: t's with different ranges (can be << T due to smaller eval T) need to be converted to the original t's
    #         t_cond: the same as t but can be of different values
    #     """
    #     map_tensor = th.tensor(self.timestep_map,
    #                            device=t.device,
    #                            dtype=t.dtype)
    #
    #     def do(t):
    #         new_ts = map_tensor[t]
    #         if self.rescale_timesteps:
    #             new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
    #         return new_ts
    #
    #     if t_cond is not None:
    #         # support t_cond
    #         t_cond = do(t_cond)
    #         ## run_ffhq256.py None
    #     return self.model(x=x, t=do(t), t_cond=t_cond, **kwargs)
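A worked example of the spacing logic, computed by hand from space_timesteps above:

    assert space_timesteps(10, [3]) == {0, 4, 9}            # frac_stride = (10-1)/(3-1) = 4.5
    assert space_timesteps(10, 'ddim5') == {0, 2, 4, 6, 8}  # integer stride 2
    # With num_timesteps=1000 and 'ddim100' the retained steps are 0, 10, ..., 990;
    # _WrappedModel.forward then maps the new t in [0, 100) back onto that grid via timestep_map.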
diffusion/resample.py  (new file, 63 lines)
@@ -0,0 +1,63 @@
from abc import ABC, abstractmethod

import numpy as np
import torch as th
import torch.distributed as dist


def create_named_schedule_sampler(name, diffusion):
    """
    Create a ScheduleSampler from a library of pre-defined samplers.

    :param name: the name of the sampler.
    :param diffusion: the diffusion object to sample for.
    """
    if name == "uniform":
        # NOTE: UniformSampler below takes a step count, not the diffusion object;
        # this assumes the diffusion object exposes num_timesteps (as in guided-diffusion)
        return UniformSampler(diffusion.num_timesteps)
    else:
        raise NotImplementedError(f"unknown schedule sampler: {name}")


class ScheduleSampler(ABC):
    """
    A distribution over timesteps in the diffusion process, intended to reduce
    variance of the objective.

    By default, samplers perform unbiased importance sampling, in which the
    objective's mean is unchanged.
    However, subclasses may override sample() to change how the resampled
    terms are reweighted, allowing for actual changes in the objective.
    """
    @abstractmethod
    def weights(self):
        """
        Get a numpy array of weights, one per diffusion step.

        The weights needn't be normalized, but must be positive.
        """

    def sample(self, batch_size, device):
        """
        Importance-sample timesteps for a batch.

        :param batch_size: the number of timesteps.
        :param device: the torch device to save to.
        :return: a tuple (timesteps, weights):
                 - timesteps: a tensor of timestep indices.
                 - weights: a tensor of weights to scale the resulting losses.
        """
        w = self.weights()
        p = w / np.sum(w)
        indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
        indices = th.from_numpy(indices_np).long().to(device)
        weights_np = 1 / (len(p) * p[indices_np])
        weights = th.from_numpy(weights_np).float().to(device)
        return indices, weights


class UniformSampler(ScheduleSampler):
    def __init__(self, num_timesteps):  ## all step weights are 1
        self._weights = np.ones([num_timesteps])

    def weights(self):
        return self._weights
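The 1/(len(p) * p[i]) factor in sample() keeps the reweighted loss an unbiased estimate of the uniform average over timesteps; for UniformSampler every weight comes out as exactly 1. A quick check (UniformSampler and th are from this file):

    sampler = UniformSampler(1000)
    t, w = sampler.sample(batch_size=4, device='cpu')
    assert th.allclose(w, th.ones(4))  # p[i] = 1/1000, so 1/(1000 * 1/1000) = 1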
environment/haheae.yml  (new file, 101 lines)
@@ -0,0 +1,101 @@
name: haheae
channels:
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - ca-certificates=2023.12.12=h06a4308_0
  - ld_impl_linux-64=2.38=h1181459_1
  - libffi=3.4.4=h6a678d5_0
  - libgcc-ng=11.2.0=h1234567_1
  - libgomp=11.2.0=h1234567_1
  - libstdcxx-ng=11.2.0=h1234567_1
  - ncurses=6.4=h6a678d5_0
  - openssl=3.0.12=h7f8727e_0
  - pip=23.3.1=py38h06a4308_0
  - python=3.8.18=h955ad1f_0
  - readline=8.2=h5eee18b_0
  - setuptools=68.2.2=py38h06a4308_0
  - sqlite=3.41.2=h5eee18b_0
  - tk=8.6.12=h1ccaba5_0
  - wheel=0.41.2=py38h06a4308_0
  - xz=5.4.5=h5eee18b_0
  - zlib=1.2.13=h5eee18b_0
  - pip:
      - absl-py==2.0.0
      - aiohttp==3.9.1
      - aiosignal==1.3.1
      - appdirs==1.4.4
      - async-timeout==4.0.3
      - attrs==23.2.0
      - cachetools==5.3.2
      - certifi==2023.11.17
      - charset-normalizer==3.3.2
      - click==8.1.7
      - contourpy==1.1.1
      - cycler==0.12.1
      - cython==0.29.37
      - docker-pycreds==0.4.0
      - fonttools==4.47.2
      - frozenlist==1.4.1
      - fsspec==2023.12.2
      - ftfy==6.1.3
      - future==0.18.3
      - gitdb==4.0.11
      - gitpython==3.1.41
      - google-auth==2.26.2
      - google-auth-oauthlib==1.0.0
      - grpcio==1.60.0
      - hdbscan==0.8.33
      - idna==3.6
      - importlib-metadata==7.0.1
      - importlib-resources==6.1.1
      - joblib==1.3.2
      - kiwisolver==1.4.5
      - lmdb==1.2.1
      - lpips==0.1.4
      - markdown==3.5.2
      - markupsafe==2.1.3
      - matplotlib==3.5.3
      - multidict==6.0.4
      - numpy==1.24.4
      - oauthlib==3.2.2
      - packaging==23.2
      - pandas==1.5.3
      - pillow==10.2.0
      - protobuf==4.25.2
      - psutil==5.9.8
      - pyasn1==0.5.1
      - pyasn1-modules==0.3.0
      - pydeprecate==0.3.1
      - pyparsing==3.1.1
      - python-dateutil==2.8.2
      - pytorch-fid==0.2.0
      - pytorch-lightning==1.4.5
      - pytz==2023.3.post1
      - pyyaml==6.0.1
      - regex==2023.12.25
      - requests==2.31.0
      - requests-oauthlib==1.3.1
      - rsa==4.9
      - scikit-learn==1.3.2
      - scipy==1.5.4
      - sentry-sdk==1.39.2
      - setproctitle==1.3.3
      - six==1.16.0
      - smmap==5.0.1
      - tensorboard==2.14.0
      - tensorboard-data-server==0.7.2
      - threadpoolctl==3.2.0
      - torch==1.8.1
      - torchmetrics==0.5.0
      - torchvision==0.9.1
      - tqdm==4.66.1
      - typing-extensions==4.9.0
      - tzdata==2023.4
      - urllib3==2.1.0
      - wandb==0.16.2
      - wcwidth==0.2.13
      - werkzeug==3.0.1
      - yarl==1.9.4
      - zipp==3.17.0
main.py  (new file, 565 lines)
@@ -0,0 +1,565 @@
import warnings
warnings.filterwarnings("ignore")
import os
os.nice(5)
import copy, wandb
from tqdm import tqdm, trange
import argparse
import json
import re
import random
import math
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import *
from torch import nn
import torch.nn.functional as F
from torch.cuda import amp
from torch.optim.optimizer import Optimizer
from config import *
import time
import datetime


class LitModel(pl.LightningModule):
    def __init__(self, conf: TrainConfig):
        super().__init__()

        ## wandb
        self.save_hyperparameters({k: v for (k, v) in vars(conf).items() if not callable(v)})

        if conf.seed is not None:
            pl.seed_everything(conf.seed)

        self.save_hyperparameters(conf.as_dict_jsonable())
        self.conf = conf
        self.model = conf.make_model_conf().make_model()

        self.ema_model = copy.deepcopy(self.model)
        self.ema_model.requires_grad_(False)
        self.ema_model.eval()

        model_size = 0
        for param in self.model.parameters():
            model_size += param.data.nelement()
        print('Model params: %.3f M' % (model_size / 1024 / 1024))

        self.sampler = conf.make_diffusion_conf().make_sampler()
        self.eval_sampler = conf.make_eval_diffusion_conf().make_sampler()
        self.T_sampler = conf.make_T_sampler()
        self.save_every_epochs = conf.save_every_epochs
        self.eval_every_epochs = conf.eval_every_epochs

    def setup(self, stage=None) -> None:
        """
        make datasets & seed each worker separately
        """
        ##############################################
        # NEED TO SET THE SEED SEPARATELY HERE
        if self.conf.seed is not None:
            seed = self.conf.seed
            np.random.seed(seed)
            random.seed(seed)  # Python random module
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            torch.backends.cudnn.benchmark = False
            torch.backends.cudnn.deterministic = True
            print('seed:', seed)
        ##############################################

        ## Load dataset
        if self.conf.mode == 'train':
            self.train_data = load_egobody(self.conf.data_dir, self.conf.seq_len + self.conf.seq_len_future, self.conf.data_sample_rate, train=1)
            self.val_data = load_egobody(self.conf.data_dir, self.conf.seq_len + self.conf.seq_len_future, self.conf.data_sample_rate, train=0)
            if self.conf.in_channels == 6:  # hand only
                self.train_data = self.train_data[:, :6, :]
                self.val_data = self.val_data[:, :6, :]
            if self.conf.in_channels == 3:  # head only
                self.train_data = self.train_data[:, 6:, :]
                self.val_data = self.val_data[:, 6:, :]

    def encode(self, x):
        assert self.conf.model_type.has_autoenc()
        cond, pred_hand, pred_head = self.ema_model.encoder.forward(x)
        return cond, pred_hand, pred_head

    def encode_stochastic(self, x, cond, T=None):
        if T is None:
            sampler = self.eval_sampler
        else:
            sampler = self.conf._make_diffusion_conf(T).make_sampler()  # get noise at step T

        ## x_0 -> x_T using the reverse of inference
        out = sampler.ddim_reverse_sample_loop(self.ema_model, x, model_kwargs={'cond': cond})
        ''' 'sample': x_T
            'sample_t': x_t, t in (1, ..., T)
            'xstart_t': predicted x_0 at each timestep. "xstart here is a bit different from sampling from T = T-1 to T = 0"
            'T': (1, ..., T)
        '''
        return out['sample']

    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.train_data, batch_size=self.conf.batch_size, shuffle=True)

    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.val_data, batch_size=self.conf.batch_size_eval, shuffle=False)

    def is_last_accum(self, batch_idx):
        """
        is this the last gradient accumulation loop?
        used with gradient_accum > 1 and to see if the optimizer will perform "step" in this iteration or not
        """
        return (batch_idx + 1) % self.conf.accum_batches == 0

    def training_step(self, batch, batch_idx):
        """
        given an input, calculate the loss function;
        no optimization at this stage.
        """
        with amp.autocast(False):
            x_start = batch[:, :, :self.conf.seq_len]
            x_future = batch[:, :, self.conf.seq_len:]

            if self.conf.train_mode == TrainMode.diffusion:
                """
                main training mode!!!
                """
                t, weight = self.T_sampler.sample(len(x_start), x_start.device)
                ''' self.T_sampler: diffusion.resample.UniformSampler (weights for all timesteps are 1)
                    - t: a tensor of timestep indices.
                    - weight: a tensor of weights to scale the resulting losses.
                    ## num_timesteps is self.conf.T == 1000
                '''
                losses = self.sampler.training_losses(model=self.model,
                                                      x_start=x_start,
                                                      t=t,
                                                      x_future=x_future,
                                                      hand_mse_factor=self.conf.hand_mse_factor,
                                                      head_mse_factor=self.conf.head_mse_factor,
                                                      )
            else:
                raise NotImplementedError()

            loss = losses['loss'].mean()  ## average loss across mini-batches
            #noise_mse = losses['mse'].mean()
            #hand_mse = losses['hand_mse'].mean()
            #head_mse = losses['head_mse'].mean()

            ## Log loss and metric (wandb)
            self.log("train_loss", loss, on_epoch=True, prog_bar=True)
            #self.log("train_noise_mse", noise_mse, on_epoch=True, prog_bar=True)
            #self.log("train_hand_mse", hand_mse, on_epoch=True, prog_bar=True)
            #self.log("train_head_mse", head_mse, on_epoch=True, prog_bar=True)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        if self.conf.in_channels == 9:  # both hand and head
            if (self.current_epoch + 1) % self.eval_every_epochs == 0:
                batch_future = batch[:, :, self.conf.seq_len:]
                gt_hand_future = batch_future[:, :6, :]
                gt_head_future = batch_future[:, 6:, :]
                batch = batch[:, :, :self.conf.seq_len]
                cond, pred_hand_future, pred_head_future = self.encode(batch)
                xT = self.encode_stochastic(batch, cond)
                pred_xstart = self.generate(xT, cond)

                # hand reconstruction error
                gt_hand = batch[:, :6, :]
                pred_hand = pred_xstart[:, :6, :]
                bs, channels, seq_len = gt_hand.shape
                gt_hand = gt_hand.reshape(bs, 2, 3, seq_len)
                pred_hand = pred_hand.reshape(bs, 2, 3, seq_len)
                hand_traj = torch.mean(torch.norm(gt_hand - pred_hand, dim=2))

                # hand prediction error
                bs, channels, seq_len = gt_hand_future.shape
                gt_hand_future = gt_hand_future.reshape(bs, 2, 3, seq_len)
                pred_hand_future = pred_hand_future.reshape(bs, 2, 3, seq_len)
                baseline_hand_future = gt_hand[:, :, :, -1:].expand(-1, -1, -1, self.conf.seq_len_future).clone()
                hand_traj_future = torch.mean(torch.norm(gt_hand_future - pred_hand_future, dim=2))
                hand_traj_future_baseline = torch.mean(torch.norm(gt_hand_future - baseline_hand_future, dim=2))

                # head reconstruction error
                gt_head = batch[:, 6:, :]
                gt_head = F.normalize(gt_head, dim=1)  # normalize head orientation to unit vectors
                pred_head = pred_xstart[:, 6:, :]
                pred_head = F.normalize(pred_head, dim=1)
                head_ang = torch.mean(acos_safe(torch.sum(gt_head * pred_head, 1))) / torch.tensor(math.pi) * 180.0

                # head prediction error
                gt_head_future = F.normalize(gt_head_future, dim=1)
                pred_head_future = F.normalize(pred_head_future, dim=1)
                baseline_head_future = gt_head[:, :, -1:].expand(-1, -1, self.conf.seq_len_future).clone()
                head_ang_future = torch.mean(acos_safe(torch.sum(gt_head_future * pred_head_future, 1))) / torch.tensor(math.pi) * 180.0
                head_ang_future_baseline = torch.mean(acos_safe(torch.sum(gt_head_future * baseline_head_future, 1))) / torch.tensor(math.pi) * 180.0

                self.log("val_hand_traj", hand_traj, on_epoch=True, prog_bar=True)
                self.log("val_head_ang", head_ang, on_epoch=True, prog_bar=True)
                self.log("val_hand_traj_future", hand_traj_future, on_epoch=True, prog_bar=True)
                self.log("val_head_ang_future", head_ang_future, on_epoch=True, prog_bar=True)
                self.log("val_hand_traj_future_baseline", hand_traj_future_baseline, on_epoch=True, prog_bar=True)
                self.log("val_head_ang_future_baseline", head_ang_future_baseline, on_epoch=True, prog_bar=True)

        if self.conf.in_channels == 6:  # hand only
            if (self.current_epoch + 1) % self.eval_every_epochs == 0:
                batch_future = batch[:, :, self.conf.seq_len:]
                gt_hand_future = batch_future[:, :, :]
                batch = batch[:, :, :self.conf.seq_len]
                cond, pred_hand_future, pred_head_future = self.encode(batch)
                xT = self.encode_stochastic(batch, cond)
                pred_xstart = self.generate(xT, cond)

                # hand reconstruction error
                gt_hand = batch[:, :, :]
                pred_hand = pred_xstart[:, :, :]
                bs, channels, seq_len = gt_hand.shape
                gt_hand = gt_hand.reshape(bs, 2, 3, seq_len)
                pred_hand = pred_hand.reshape(bs, 2, 3, seq_len)
                hand_traj = torch.mean(torch.norm(gt_hand - pred_hand, dim=2))

                # hand prediction error
                bs, channels, seq_len = gt_hand_future.shape
                gt_hand_future = gt_hand_future.reshape(bs, 2, 3, seq_len)
                pred_hand_future = pred_hand_future.reshape(bs, 2, 3, seq_len)
                baseline_hand_future = gt_hand[:, :, :, -1:].expand(-1, -1, -1, self.conf.seq_len_future).clone()
                hand_traj_future = torch.mean(torch.norm(gt_hand_future - pred_hand_future, dim=2))
                hand_traj_future_baseline = torch.mean(torch.norm(gt_hand_future - baseline_hand_future, dim=2))

                self.log("val_hand_traj", hand_traj, on_epoch=True, prog_bar=True)
                self.log("val_hand_traj_future", hand_traj_future, on_epoch=True, prog_bar=True)
                self.log("val_hand_traj_future_baseline", hand_traj_future_baseline, on_epoch=True, prog_bar=True)

        if self.conf.in_channels == 3:  # head only
            if (self.current_epoch + 1) % self.eval_every_epochs == 0:
                batch_future = batch[:, :, self.conf.seq_len:]
                gt_head_future = batch_future[:, :, :]
                batch = batch[:, :, :self.conf.seq_len]
                cond, pred_hand_future, pred_head_future = self.encode(batch)
                xT = self.encode_stochastic(batch, cond)
                pred_xstart = self.generate(xT, cond)

                # head reconstruction error
                gt_head = batch[:, :, :]
                gt_head = F.normalize(gt_head, dim=1)  # normalize head orientation to unit vectors
                pred_head = pred_xstart[:, :, :]
                pred_head = F.normalize(pred_head, dim=1)
                head_ang = torch.mean(acos_safe(torch.sum(gt_head * pred_head, 1))) / torch.tensor(math.pi) * 180.0

                # head prediction error
                gt_head_future = F.normalize(gt_head_future, dim=1)
                pred_head_future = F.normalize(pred_head_future, dim=1)
                baseline_head_future = gt_head[:, :, -1:].expand(-1, -1, self.conf.seq_len_future).clone()
                head_ang_future = torch.mean(acos_safe(torch.sum(gt_head_future * pred_head_future, 1))) / torch.tensor(math.pi) * 180.0
                head_ang_future_baseline = torch.mean(acos_safe(torch.sum(gt_head_future * baseline_head_future, 1))) / torch.tensor(math.pi) * 180.0

                self.log("val_head_ang", head_ang, on_epoch=True, prog_bar=True)
                self.log("val_head_ang_future", head_ang_future, on_epoch=True, prog_bar=True)
                self.log("val_head_ang_future_baseline", head_ang_future_baseline, on_epoch=True, prog_bar=True)

    def test_step(self, batch, batch_idx):
        batch_future = batch[:, :, self.conf.seq_len:]
        gt_hand_future = batch_future[:, :6, :]
        gt_head_future = batch_future[:, 6:, :]
        batch = batch[:, :, :self.conf.seq_len]
        cond, pred_hand_future, pred_head_future = self.encode(batch)
        xT = self.encode_stochastic(batch, cond)
        pred_xstart = self.generate(xT, cond)

        # hand reconstruction error
        gt_hand = batch[:, :6, :]
        pred_hand = pred_xstart[:, :6, :]
        bs, channels, seq_len = gt_hand.shape
        gt_hand = gt_hand.reshape(bs, 2, 3, seq_len)
        pred_hand = pred_hand.reshape(bs, 2, 3, seq_len)
        hand_traj = torch.mean(torch.norm(gt_hand - pred_hand, dim=2))

        # hand prediction error
        bs, channels, seq_len = gt_hand_future.shape
        gt_hand_future = gt_hand_future.reshape(bs, 2, 3, seq_len)
        pred_hand_future = pred_hand_future.reshape(bs, 2, 3, seq_len)
        baseline_hand_future = gt_hand[:, :, :, -1:].expand(-1, -1, -1, self.conf.seq_len_future).clone()
        hand_traj_future = torch.mean(torch.norm(gt_hand_future - pred_hand_future, dim=2))
        hand_traj_future_baseline = torch.mean(torch.norm(gt_hand_future - baseline_hand_future, dim=2))

        # head reconstruction error
        gt_head = batch[:, 6:, :]
        gt_head = F.normalize(gt_head, dim=1)  # normalize head orientation to unit vectors
        pred_head = pred_xstart[:, 6:, :]
        pred_head = F.normalize(pred_head, dim=1)
        head_ang = torch.mean(acos_safe(torch.sum(gt_head * pred_head, 1))) / torch.tensor(math.pi) * 180.0

        # head prediction error
        gt_head_future = F.normalize(gt_head_future, dim=1)
        pred_head_future = F.normalize(pred_head_future, dim=1)
        baseline_head_future = gt_head[:, :, -1:].expand(-1, -1, self.conf.seq_len_future).clone()
        head_ang_future = torch.mean(acos_safe(torch.sum(gt_head_future * pred_head_future, 1))) / torch.tensor(math.pi) * 180.0
        head_ang_future_baseline = torch.mean(acos_safe(torch.sum(gt_head_future * baseline_head_future, 1))) / torch.tensor(math.pi) * 180.0

        self.log("test_hand_traj", hand_traj, on_epoch=True, prog_bar=True)
        self.log("test_head_ang", head_ang, on_epoch=True, prog_bar=True)
        self.log("test_hand_traj_future", hand_traj_future, on_epoch=True, prog_bar=True)
        self.log("test_head_ang_future", head_ang_future, on_epoch=True, prog_bar=True)
        self.log("test_hand_traj_future_baseline", hand_traj_future_baseline, on_epoch=True, prog_bar=True)
        self.log("test_head_ang_future_baseline", head_ang_future_baseline, on_epoch=True, prog_bar=True)

    def generate(self, noise, cond=None, ema=True, T=None):
        if T is None:
            sampler = self.eval_sampler
        else:
            sampler = self.conf._make_diffusion_conf(T).make_sampler()

        if ema:
            model = self.ema_model
        else:
            model = self.model

        gen = sampler.sample(model=model, noise=noise, model_kwargs={'cond': cond})
        return gen

    def on_train_batch_end(self, outputs, batch, batch_idx: int,
                           dataloader_idx: int) -> None:
        """
        after each training step ...
        """
        if self.is_last_accum(batch_idx):
            # only apply ema on the last gradient accumulation step,
            # i.e. the iteration that has optimizer.step()
            ema(self.model, self.ema_model, self.conf.ema_decay)

            if (batch_idx == len(self.train_dataloader()) - 1) and ((self.current_epoch + 1) % self.save_every_epochs == 0):
                save_path = os.path.join(self.conf.logdir, 'epoch_%d.ckpt' % (self.current_epoch + 1))
                torch.save({
                    'state_dict': self.state_dict(),
                    'global_step': self.global_step,
                    'loss': outputs['loss'],
                }, save_path)

    def on_before_optimizer_step(self, optimizer: Optimizer,
                                 optimizer_idx: int) -> None:
        # fix the fp16 + clip grad norm problem with PyTorch Lightning;
        # this is the currently correct way to do it
        if self.conf.grad_clip > 0:
            # from trainer.params_grads import grads_norm, iter_opt_params
            params = [
                p for group in optimizer.param_groups for p in group['params']
            ]
            # print('before:', grads_norm(iter_opt_params(optimizer)))
            torch.nn.utils.clip_grad_norm_(params, max_norm=self.conf.grad_clip)
            # print('after:', grads_norm(iter_opt_params(optimizer)))

    def configure_optimizers(self):
        out = {}
        if self.conf.optimizer == OptimizerType.adam:
            optim = torch.optim.Adam(self.model.parameters(),
                                     lr=self.conf.lr,
                                     weight_decay=self.conf.weight_decay)
        elif self.conf.optimizer == OptimizerType.adamw:
            optim = torch.optim.AdamW(self.model.parameters(),
                                      lr=self.conf.lr,
                                      weight_decay=self.conf.weight_decay)
        else:
            raise NotImplementedError()
        out['optimizer'] = optim
        if self.conf.warmup > 0:
            sched = torch.optim.lr_scheduler.LambdaLR(optim,
                                                      lr_lambda=WarmupLR(
                                                          self.conf.warmup))
            out['lr_scheduler'] = {
                'scheduler': sched,
                'interval': 'step',
            }
        return out


def ema(source, target, decay):
    source_dict = source.state_dict()
    target_dict = target.state_dict()
    for key in source_dict.keys():
        target_dict[key].data.copy_(target_dict[key].data * decay +
                                    source_dict[key].data * (1 - decay))


class WarmupLR:
    def __init__(self, warmup) -> None:
        self.warmup = warmup

    def __call__(self, step):
        return min(step, self.warmup) / self.warmup


def train(conf: TrainConfig, model: LitModel, gpus):
    checkpoint = ModelCheckpoint(dirpath=conf.logdir,
                                 filename='last',
                                 save_last=True,
                                 save_top_k=1,
                                 every_n_epochs=conf.save_every_epochs,
                                 )
    checkpoint_path = f'{conf.logdir}last.ckpt'
    if os.path.exists(checkpoint_path):
        resume = checkpoint_path
        if conf.mode == 'train':
            print('ckpt path:', checkpoint_path)
    else:
        print('checkpoint not found!')
        resume = None

    wandb_logger = pl_loggers.WandbLogger(project='haheae',
                                          name='%s_%s' % (model.conf.data_name, conf.logdir.split('/')[-2]),
                                          log_model=True,
                                          save_dir=conf.logdir,
                                          dir=conf.logdir,
                                          config=vars(model.conf),
                                          save_code=True,
                                          settings=wandb.Settings(code_dir="."))

    trainer = pl.Trainer(
        max_epochs=conf.total_epochs,
        resume_from_checkpoint=resume,
        gpus=gpus,
        precision=16 if conf.fp16 else 32,
        callbacks=[
            checkpoint,
            LearningRateMonitor(),
        ],
        logger=wandb_logger,
        accumulate_grad_batches=conf.accum_batches,
        progress_bar_refresh_rate=4,
    )

    if conf.mode == 'train':
        trainer.fit(model)
    elif conf.mode == 'eval':
        # load the latest checkpoint
        checkpoint_path = f'{conf.logdir}last.ckpt'
        print('loading from:', checkpoint_path)
        state = torch.load(checkpoint_path)
        model.load_state_dict(state['state_dict'])

        test_datasets = ['egobody', 'adt', 'gimo']
        for dataset_name in test_datasets:
            if dataset_name == 'egobody':
                data_dir = "/scratch/hu/pose_forecast/egobody_pose2gaze/"
                test_data = load_egobody(data_dir, conf.seq_len + conf.seq_len_future, 1, train=0)  # use the test set
            elif dataset_name == 'adt':
                data_dir = "/scratch/hu/pose_forecast/adt_pose2gaze/"
                test_data = load_adt(data_dir, conf.seq_len + conf.seq_len_future, 1, train=2)  # use the train+test set
            elif dataset_name == 'gimo':
                data_dir = "/scratch/hu/pose_forecast/gimo_pose2gaze/"
                test_data = load_gimo(data_dir, conf.seq_len + conf.seq_len_future, 1, train=2)  # use the train+test set

            test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=conf.batch_size_eval, shuffle=False)
            results = trainer.test(model, dataloaders=test_dataloader, verbose=False)
            print("\n\nTest on {}, dataset size: {}".format(dataset_name, test_data.shape))
            print("test_hand_traj: {:.3f} cm".format(results[0]['test_hand_traj'] * 100))
            print("test_head_ang: {:.3f} deg".format(results[0]['test_head_ang']))
            print("test_hand_traj_future: {:.3f} cm".format(results[0]['test_hand_traj_future'] * 100))
            print("test_head_ang_future: {:.3f} deg".format(results[0]['test_head_ang_future']))
            print("test_hand_traj_future_baseline: {:.3f} cm".format(results[0]['test_hand_traj_future_baseline'] * 100))
            print("test_head_ang_future_baseline: {:.3f} deg\n\n".format(results[0]['test_head_ang_future_baseline']))

    wandb.finish()


def acos_safe(x, eps=1e-6):
    # numerically safe arccos: linearly extrapolate near |x| = 1 to avoid infinite gradients
    slope = np.arccos(1 - eps) / eps
    buf = torch.empty_like(x)
    good = abs(x) <= 1 - eps
    bad = ~good
    sign = torch.sign(x[bad])
    buf[good] = torch.acos(x[good])
    buf[bad] = torch.acos(sign * (1 - eps)) - slope * sign * (abs(x[bad]) - 1 + eps)
    return buf


def get_representation(model, dataset, conf, device='cuda'):
    model = model.to(device)
    model.eval()
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=conf.batch_size_eval, shuffle=False)
    with torch.no_grad():
        conds = []  # semantic representation
        xTs = []  # stochastic representation
        for batch in tqdm(dataloader, total=len(dataloader), desc='infer'):
            batch = batch.to(device)
            cond, _, _ = model.encode(batch)
            xT = model.encode_stochastic(batch, cond)
            cond_cpu = cond.cpu().data.numpy()
            xT_cpu = xT.cpu().data.numpy()
            if len(conds) == 0:
                conds = cond_cpu
                xTs = xT_cpu
            else:
                conds = np.concatenate((conds, cond_cpu), axis=0)
                xTs = np.concatenate((xTs, xT_cpu), axis=0)
    return conds, xTs


def generate_from_representation(model, conds, xTs, device='cuda'):
    model = model.to(device)
    model.eval()
    conds = torch.from_numpy(conds).to(device)
    xTs = torch.from_numpy(xTs).to(device)
    rec = model.generate(xTs, conds)
    rec = rec.cpu().data.numpy()
    return rec


def evaluate_reconstruction(gt, rec):
    # hand reconstruction error (cm)
    gt_hand = gt[:, :6, :]
    rec_hand = rec[:, :6, :]
    bs, channels, seq_len = gt_hand.shape
    gt_hand = gt_hand.reshape(bs, 2, 3, seq_len)
    rec_hand = rec_hand.reshape(bs, 2, 3, seq_len)
    hand_traj_errors = np.mean(np.mean(np.linalg.norm(gt_hand - rec_hand, axis=2), axis=1), axis=1) * 100

    # head reconstruction error (deg)
    gt_head = gt[:, 6:, :]
    gt_head_norm = np.linalg.norm(gt_head, axis=1, keepdims=True)
    gt_head = gt_head / gt_head_norm
    rec_head = rec[:, 6:, :]
    rec_head_norm = np.linalg.norm(rec_head, axis=1, keepdims=True)
    rec_head = rec_head / rec_head_norm
    dot_sum = np.clip(np.sum(gt_head * rec_head, axis=1), -1, 1)
    head_ang_errors = np.mean(np.arccos(dot_sum), axis=1) / np.pi * 180.0
    return hand_traj_errors, head_ang_errors


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpus', default=7, type=int)
    parser.add_argument('--mode', default='eval', type=str)
    parser.add_argument('--encoder_type', default='gcn', type=str)
    parser.add_argument('--model_name', default='haheae', type=str)
    parser.add_argument('--hand_mse_factor', default=1.0, type=float)
    parser.add_argument('--head_mse_factor', default=1.0, type=float)
    parser.add_argument('--data_sample_rate', default=1, type=int)
    parser.add_argument('--epoch', default=130, type=int)
    parser.add_argument('--in_channels', default=9, type=int)
    args = parser.parse_args()

    conf = egobody_autoenc(args.mode, args.encoder_type, args.hand_mse_factor, args.head_mse_factor, args.data_sample_rate, args.epoch, args.in_channels)
    model = LitModel(conf)
    conf.logdir = f'{conf.logdir}{args.model_name}/'
    print('log dir: {}'.format(conf.logdir))
    MakeDir(conf.logdir)

    if conf.mode == 'train' or conf.mode == 'eval':  # train or evaluate the model
        os.environ['WANDB_CACHE_DIR'] = conf.logdir
        os.environ['WANDB_DATA_DIR'] = conf.logdir
        # set wandb to not upload checkpoints, but everything else
        os.environ['WANDB_IGNORE_GLOBS'] = '*.ckpt'
        local_time = time.asctime(time.localtime(time.time()))
        print('\n{} starts at {}'.format(conf.mode, local_time))
        start_time = datetime.datetime.now()
        train(conf, model, gpus=[args.gpus])
        end_time = datetime.datetime.now()
        total_time = (end_time - start_time).seconds / 60
        print('\nTotal time: {:.3f} min'.format(total_time))
        local_time = time.asctime(time.localtime(time.time()))
        print('\n{} ends at {}'.format(conf.mode, local_time))
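A sketch of offline representation extraction with the helpers above (the checkpoint path is illustrative, and it assumes load_egobody returns an array-like of shape [N, 9, seq_len + seq_len_future] as in setup()):

    conf = egobody_autoenc(mode='eval')
    model = LitModel(conf)
    state = torch.load('./checkpoints/haheae/last.ckpt', map_location='cpu')
    model.load_state_dict(state['state_dict'])

    data = load_egobody(conf.data_dir, conf.seq_len + conf.seq_len_future, 1, train=0)
    past = data[:, :, :conf.seq_len]                        # 40-frame observation window
    conds, xTs = get_representation(model, past, conf)      # semantic + stochastic codes
    rec = generate_from_representation(model, conds, xTs)   # decode back to trajectories
    hand_cm, head_deg = evaluate_reconstruction(np.asarray(past), rec)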
model/__init__.py  (new file, 6 lines)
@@ -0,0 +1,6 @@
from typing import Union
from .unet import BeatGANsUNetModel, BeatGANsUNetConfig, GCNUNetModel, GCNUNetConfig
from .unet_autoenc import BeatGANsAutoencConfig, BeatGANsAutoencModel, GCNAutoencConfig, GCNAutoencModel

Model = Union[BeatGANsUNetModel, BeatGANsAutoencModel, GCNUNetModel, GCNAutoencModel]
ModelConfig = Union[BeatGANsUNetConfig, BeatGANsAutoencConfig, GCNUNetConfig, GCNAutoencConfig]
model/blocks.py  (new file, 579 lines)
|
@ -0,0 +1,579 @@
|
|||
import math, pdb
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from numbers import Number
|
||||
|
||||
import torch as th
|
||||
import torch.nn.functional as F
|
||||
from choices import *
|
||||
from config_base import BaseConfig
|
||||
from torch import nn
|
||||
import numpy as np
|
||||
from .nn import (avg_pool_nd, conv_nd, linear, normalization,
|
||||
timestep_embedding, zero_module)
|
||||
|
||||
|
||||
class TimestepBlock(nn.Module):
|
||||
"""
|
||||
Any module where forward() takes timestep embeddings as a second argument.
|
||||
"""
|
||||
@abstractmethod
|
||||
def forward(self, x, emb=None, cond=None, lateral=None):
|
||||
"""
|
||||
Apply the module to `x` given `emb` timestep embeddings.
|
||||
"""
|
||||
|
||||
|
||||
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
||||
"""
|
||||
A sequential module that passes timestep embeddings to the children that
|
||||
support it as an extra input.
|
||||
"""
|
||||
def forward(self, x, emb=None, cond=None, lateral=None):
|
||||
for layer in self:
|
||||
if isinstance(layer, TimestepBlock):
|
||||
'''if layer(x, emb=emb, cond=cond, lateral=lateral).shape[-1]==10:
|
||||
pdb.set_trace()'''
|
||||
x = layer(x, emb=emb, cond=cond, lateral=lateral)
|
||||
else:
|
||||
'''if layer(x).shape[-1]==10:
|
||||
pdb.set_trace()'''
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResBlockConfig(BaseConfig):
|
||||
channels: int
|
||||
emb_channels: int
|
||||
dropout: float
|
||||
out_channels: int = None
|
||||
# condition the resblock with time (and encoder's output)
|
||||
use_condition: bool = True
|
||||
# whether to use 3x3 conv for skip path when the channels aren't matched
|
||||
use_conv: bool = False
|
||||
# dimension of conv (always 1 = 1d)
|
||||
dims: int = 1
|
||||
up: bool = False
|
||||
down: bool = False
|
||||
# whether to condition with both time & encoder's output
|
||||
two_cond: bool = False
|
||||
# number of encoders' output channels
|
||||
cond_emb_channels: int = None
|
||||
# suggest: False
|
||||
has_lateral: bool = False
|
||||
lateral_channels: int = None
|
||||
# whether to init the convolution with zero weights
|
||||
# this is default from BeatGANs and seems to help learning
|
||||
use_zero_module: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
self.out_channels = self.out_channels or self.channels
|
||||
self.cond_emb_channels = self.cond_emb_channels or self.emb_channels
|
||||
|
||||
def make_model(self):
|
||||
return ResBlock(self)
|
||||
|
||||
|
||||
class ResBlock(TimestepBlock):
|
||||
"""
|
||||
A residual block that can optionally change the number of channels.
|
||||
|
||||
total layers:
|
||||
in_layers
|
||||
- norm
|
||||
- act
|
||||
- conv
|
||||
out_layers
|
||||
- norm
|
||||
- (modulation)
|
||||
- act
|
||||
- conv
|
||||
"""
|
||||
def __init__(self, conf: ResBlockConfig):
|
||||
super().__init__()
|
||||
self.conf = conf
|
||||
|
||||
#############################
|
||||
# IN LAYERS
|
||||
#############################
|
||||
assert conf.lateral_channels is None
|
||||
layers = [
|
||||
normalization(conf.channels),
|
||||
nn.SiLU(),
|
||||
conv_nd(conf.dims, conf.channels, conf.out_channels, 3, padding=1) ## 3 is kernel size
|
||||
]
|
||||
self.in_layers = nn.Sequential(*layers)
|
||||
|
||||
self.updown = conf.up or conf.down
|
||||
|
||||
if conf.up:
|
||||
self.h_upd = Upsample(conf.channels, False, conf.dims)
|
||||
self.x_upd = Upsample(conf.channels, False, conf.dims)
|
||||
elif conf.down:
|
||||
self.h_upd = Downsample(conf.channels, False, conf.dims)
|
||||
self.x_upd = Downsample(conf.channels, False, conf.dims)
|
||||
else:
|
||||
self.h_upd = self.x_upd = nn.Identity()
|
||||
|
||||
#############################
|
||||
# OUT LAYERS CONDITIONS
|
||||
#############################
|
||||
if conf.use_condition:
|
||||
# condition layers for the out_layers
|
||||
self.emb_layers = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
linear(conf.emb_channels, 2 * conf.out_channels),
|
||||
)
|
||||
|
||||
if conf.two_cond:
|
||||
self.cond_emb_layers = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
linear(conf.cond_emb_channels, conf.out_channels),
|
||||
)
|
||||
#############################
|
||||
# OUT LAYERS (ignored when there is no condition)
|
||||
#############################
|
||||
# original version
|
||||
conv = conv_nd(conf.dims,
|
||||
conf.out_channels,
|
||||
conf.out_channels,
|
||||
3,
|
||||
padding=1)
|
||||
if conf.use_zero_module:
|
||||
# zere out the weights
|
||||
# it seems to help training
|
||||
conv = zero_module(conv)
|
||||
|
||||
# construct the layers
|
||||
# - norm
|
||||
# - (modulation)
|
||||
# - act
|
||||
# - dropout
|
||||
# - conv
|
||||
layers = []
|
||||
layers += [
|
||||
normalization(conf.out_channels),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(p=conf.dropout),
|
||||
conv,
|
||||
]
|
||||
self.out_layers = nn.Sequential(*layers)
|
||||
|
||||
#############################
|
||||
# SKIP LAYERS
|
||||
#############################
|
||||
if conf.out_channels == conf.channels:
|
||||
# cannot be used with gatedconv, also gatedconv is alsways used as the first block
|
||||
self.skip_connection = nn.Identity()
|
||||
else:
|
||||
if conf.use_conv:
|
||||
kernel_size = 3
|
||||
padding = 1
|
||||
else:
|
||||
kernel_size = 1
|
||||
padding = 0
|
||||
|
||||
self.skip_connection = conv_nd(conf.dims,
|
||||
conf.channels,
|
||||
conf.out_channels,
|
||||
kernel_size,
|
||||
padding=padding)
|
||||
|
||||
def forward(self, x, emb=None, cond=None, lateral=None):
|
||||
"""
|
||||
Apply the block to a Tensor, conditioned on a timestep embedding.
|
||||
|
||||
Args:
|
||||
x: input
|
||||
lateral: lateral connection from the encoder
|
||||
"""
|
||||
return self._forward(x, emb, cond, lateral)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x,
|
||||
emb=None,
|
||||
cond=None,
|
||||
lateral=None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
lateral: required if "has_lateral" is set and the block is non-gated; with a gated block it can be supplied optionally
|
||||
"""
|
||||
if self.conf.has_lateral:
|
||||
# lateral may be supplied even if the block doesn't require it
|
||||
# the model will take the lateral only if "has_lateral"
|
||||
assert lateral is not None
|
||||
# x = F.interpolate(x, size=(lateral.size(2)), mode='linear' )
|
||||
x = th.cat([x, lateral], dim=1)
|
||||
|
||||
if self.updown:
|
||||
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
|
||||
h = in_rest(x)
|
||||
h = self.h_upd(h)
|
||||
x = self.x_upd(x)
|
||||
h = in_conv(h)
|
||||
else:
|
||||
h = self.in_layers(x)
|
||||
|
||||
if self.conf.use_condition:
|
||||
# it's possible that the network may not receive the time emb
|
||||
# this happens with autoenc and setting the time_at
|
||||
if emb is not None:
|
||||
emb_out = self.emb_layers(emb).type(h.dtype)
|
||||
else:
|
||||
emb_out = None
|
||||
|
||||
if self.conf.two_cond:
|
||||
# it's possible that the network is two_cond
|
||||
# but it doesn't get the second condition
|
||||
# in which case, we ignore the second condition
|
||||
# and treat as if the network has one condition
|
||||
if cond is None:
|
||||
cond_out = None
|
||||
else:
|
||||
if not isinstance(cond, th.Tensor):
|
||||
assert isinstance(cond, dict)
|
||||
cond = cond['cond']
|
||||
cond_out = self.cond_emb_layers(cond).type(h.dtype)
|
||||
if cond_out is not None:
|
||||
while len(cond_out.shape) < len(h.shape):
|
||||
cond_out = cond_out[..., None]
|
||||
else:
|
||||
cond_out = None
|
||||
|
||||
# this is the new refactored code
|
||||
h = apply_conditions(
|
||||
h=h,
|
||||
emb=emb_out,
|
||||
cond=cond_out,
|
||||
layers=self.out_layers,
|
||||
scale_bias=1,
|
||||
in_channels=self.conf.out_channels,
|
||||
up_down_layer=None,
|
||||
)
|
||||
|
||||
return self.skip_connection(x) + h
|
||||
|
||||
|
||||
def apply_conditions(
|
||||
h,
|
||||
emb=None,
|
||||
cond=None,
|
||||
layers: nn.Sequential = None,
|
||||
scale_bias: float = 1,
|
||||
in_channels: int = 512,
|
||||
up_down_layer: nn.Module = None,
|
||||
):
|
||||
"""
|
||||
apply conditions on the feature maps
|
||||
|
||||
Args:
|
||||
emb: time conditional (ready to scale + shift)
|
||||
cond: encoder's conditional (ready to scale + shift)
|
||||
"""
|
||||
two_cond = emb is not None and cond is not None
|
||||
|
||||
if emb is not None:
|
||||
# adjusting shapes
|
||||
while len(emb.shape) < len(h.shape):
|
||||
emb = emb[..., None]
|
||||
|
||||
if two_cond:
|
||||
# adjusting shapes
|
||||
while len(cond.shape) < len(h.shape):
|
||||
cond = cond[..., None]
|
||||
# time first
|
||||
scale_shifts = [emb, cond]
|
||||
else:
|
||||
# "cond" is not used with single cond mode
|
||||
scale_shifts = [emb]
|
||||
|
||||
# support scale, shift or shift only
|
||||
for i, each in enumerate(scale_shifts):
|
||||
if each is None:
|
||||
# special case: the condition is not provided
|
||||
a = None
|
||||
b = None
|
||||
else:
|
||||
if each.shape[1] == in_channels * 2:
|
||||
a, b = th.chunk(each, 2, dim=1)
|
||||
else:
|
||||
a = each
|
||||
b = None
|
||||
scale_shifts[i] = (a, b)
|
||||
|
||||
# condition scale bias could be a list
|
||||
if isinstance(scale_bias, Number):
|
||||
biases = [scale_bias] * len(scale_shifts)
|
||||
else:
|
||||
# a list
|
||||
biases = scale_bias
|
||||
|
||||
# default, the scale & shift are applied after the group norm but BEFORE SiLU
|
||||
pre_layers, post_layers = layers[0], layers[1:]
|
||||
|
||||
# split the post layers to be able to scale up or down before the conv
|
||||
# the remaining post layers are the dropout and the final conv
|
||||
mid_layers, post_layers = post_layers[:-2], post_layers[-2:]
|
||||
|
||||
h = pre_layers(h)
|
||||
# scale and shift for each condition
|
||||
for i, (scale, shift) in enumerate(scale_shifts):
|
||||
# if scale is None, it indicates that the condition is not provided
|
||||
if scale is not None:
|
||||
h = h * (biases[i] + scale)
|
||||
if shift is not None:
|
||||
h = h + shift
|
||||
h = mid_layers(h)
|
||||
|
||||
# upscale or downscale if any just before the last conv
|
||||
if up_down_layer is not None:
|
||||
h = up_down_layer(h)
|
||||
h = post_layers(h)
|
||||
return h
|
||||
|
||||
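# Illustrative sanity check (not part of the original file; shapes and values are
# assumptions). With `emb` carrying 2*C channels it is split into (scale, shift);
# `cond` with C channels acts as a scale only, so the loop above reduces to
#   h = h * (1 + emb_scale) + emb_shift   # time condition
#   h = h * (1 + cond_scale)              # encoder condition
# applied after the group norm and before the activation. Guarded so it never runs
# on import.
if __name__ == "__main__":
    _h = th.randn(2, 8, 16)  # [N, C, T]
    _layers = nn.Sequential(normalization(8), nn.SiLU(), nn.Dropout(p=0.1),
                            conv_nd(1, 8, 8, 3, padding=1))
    _out = apply_conditions(h=_h,
                            emb=th.randn(2, 16, 1),  # scale + shift -> 2 * C channels
                            cond=th.randn(2, 8, 1),  # scale only -> C channels
                            layers=_layers,
                            scale_bias=1,
                            in_channels=8)
    assert _out.shape == _h.shape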
|
||||
class Upsample(nn.Module):
|
||||
"""
|
||||
An upsampling layer with an optional convolution.
|
||||
|
||||
:param channels: channels in the inputs and outputs.
|
||||
:param use_conv: a bool determining if a convolution is applied.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
||||
upsampling occurs in the inner-two dimensions.
|
||||
"""
|
||||
def __init__(self, channels, use_conv, dims=2, out_channels=None):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.dims = dims
|
||||
if use_conv:
|
||||
self.conv = conv_nd(dims,
|
||||
self.channels,
|
||||
self.out_channels,
|
||||
3,
|
||||
padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
assert x.shape[1] == self.channels
|
||||
if self.dims == 3:
|
||||
x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
|
||||
mode="nearest")
|
||||
else:
|
||||
# if x.shape[2] == 4:
|
||||
# feature = 9
|
||||
# x = F.interpolate(x, size=(feature), mode="nearest")
|
||||
# if x.shape[2] == 8:
|
||||
# feature = 9
|
||||
# x = F.interpolate(x, size=(feature), mode="nearest")
|
||||
# else:
|
||||
x = F.interpolate(x, scale_factor=2, mode="nearest")
|
||||
if self.use_conv:
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
"""
|
||||
A downsampling layer with an optional convolution.
|
||||
|
||||
:param channels: channels in the inputs and outputs.
|
||||
:param use_conv: a bool determining if a convolution is applied.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
||||
downsampling occurs in the inner-two dimensions.
|
||||
"""
|
||||
def __init__(self, channels, use_conv, dims=2, out_channels=None):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.dims = dims
|
||||
self.stride = 2 if dims != 3 else (1, 2, 2)
|
||||
if use_conv:
|
||||
self.op = conv_nd(dims,
|
||||
self.channels,
|
||||
self.out_channels,
|
||||
3,
|
||||
stride=self.stride,
|
||||
padding=1)
|
||||
else:
|
||||
assert self.channels == self.out_channels
|
||||
self.op = avg_pool_nd(dims, kernel_size=self.stride, stride=self.stride)
|
||||
|
||||
def forward(self, x):
|
||||
assert x.shape[1] == self.channels
|
||||
# if x.shape[2] % 2 != 0:
|
||||
# op = avg_pool_nd(self.dims, kernel_size=3, stride=2)
|
||||
# return op(x)
|
||||
# if x.shape[2] % 2 != 0:
|
||||
# op = avg_pool_nd(self.dims, kernel_size=2, stride=1)
|
||||
# return op(x)
|
||||
# else:
|
||||
return self.op(x)
|
||||
|
||||
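# Note (illustrative, not from the original file): for 1D signals the stride-2
# conv / avg-pool halves the sequence length, e.g. [N, C, 80] -> [N, C, 40]; for
# 3D inputs the stride (1, 2, 2) halves only the inner two dimensions.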
|
||||
class AttentionBlock(nn.Module):
|
||||
"""
|
||||
An attention block that allows spatial positions to attend to each other.
|
||||
|
||||
Originally ported from here, but adapted to the N-d case.
|
||||
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
channels,
|
||||
num_heads=1,
|
||||
num_head_channels=-1,
|
||||
use_new_attention_order=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
if num_head_channels == -1:
|
||||
self.num_heads = num_heads
|
||||
else:
|
||||
assert (
|
||||
channels % num_head_channels == 0
|
||||
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
|
||||
self.num_heads = channels // num_head_channels
|
||||
self.norm = normalization(channels)
|
||||
self.qkv = conv_nd(1, channels, channels * 3, 1)
|
||||
if use_new_attention_order:
|
||||
# split qkv before split heads
|
||||
self.attention = QKVAttention(self.num_heads)
|
||||
else:
|
||||
# split heads before split qkv
|
||||
self.attention = QKVAttentionLegacy(self.num_heads)
|
||||
|
||||
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
|
||||
|
||||
def forward(self, x):
|
||||
return self._forward(x)
|
||||
|
||||
def _forward(self, x):
|
||||
b, c, *spatial = x.shape
|
||||
x = x.reshape(b, c, -1)
|
||||
qkv = self.qkv(self.norm(x))
|
||||
h = self.attention(qkv)
|
||||
h = self.proj_out(h)
|
||||
return (x + h).reshape(b, c, *spatial)
|
||||
|
||||
|
||||
def count_flops_attn(model, _x, y):
|
||||
"""
|
||||
A counter for the `thop` package to count the operations in an
|
||||
attention operation.
|
||||
Meant to be used like:
|
||||
macs, params = thop.profile(
|
||||
model,
|
||||
inputs=(inputs, timestamps),
|
||||
custom_ops={QKVAttention: QKVAttention.count_flops},
|
||||
)
|
||||
"""
|
||||
b, c, *spatial = y[0].shape
|
||||
num_spatial = int(np.prod(spatial))
|
||||
# We perform two matmuls with the same number of ops.
|
||||
# The first computes the weight matrix, the second computes
|
||||
# the combination of the value vectors.
|
||||
matmul_ops = 2 * b * (num_spatial**2) * c
|
||||
model.total_ops += th.DoubleTensor([matmul_ops])
|
||||
|
||||
|
||||
class QKVAttentionLegacy(nn.Module):
|
||||
"""
|
||||
A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
|
||||
"""
|
||||
def __init__(self, n_heads):
|
||||
super().__init__()
|
||||
self.n_heads = n_heads
|
||||
|
||||
def forward(self, qkv):
|
||||
"""
|
||||
Apply QKV attention.
|
||||
|
||||
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
|
||||
:return: an [N x (H * C) x T] tensor after attention.
|
||||
"""
|
||||
bs, width, length = qkv.shape
|
||||
assert width % (3 * self.n_heads) == 0
|
||||
ch = width // (3 * self.n_heads)
|
||||
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
|
||||
scale = 1 / math.sqrt(math.sqrt(ch))
|
||||
weight = th.einsum(
|
||||
"bct,bcs->bts", q * scale,
|
||||
k * scale) # More stable with f16 than dividing afterwards
|
||||
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||
a = th.einsum("bts,bcs->bct", weight, v)
|
||||
return a.reshape(bs, -1, length)
|
||||
|
||||
@staticmethod
|
||||
def count_flops(model, _x, y):
|
||||
return count_flops_attn(model, _x, y)
|
||||
|
||||
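# Shape walk-through for QKVAttentionLegacy.forward (illustrative, not from the
# original file): qkv [N, H*3*C, T] -> reshape [N*H, 3*C, T] -> split into q, k, v
# of [N*H, C, T] each; weight = softmax(q^T k) is [N*H, T, T]; the output is
# [N*H, C, T] -> [N, H*C, T]. Scaling q and k by C^(-1/4) each is equivalent to
# dividing the logits by sqrt(C) but, as noted above, more stable in float16.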
|
||||
class QKVAttention(nn.Module):
|
||||
"""
|
||||
A module which performs QKV attention and splits in a different order.
|
||||
"""
|
||||
def __init__(self, n_heads):
|
||||
super().__init__()
|
||||
self.n_heads = n_heads
|
||||
|
||||
def forward(self, qkv):
|
||||
"""
|
||||
Apply QKV attention.
|
||||
|
||||
:param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
|
||||
:return: an [N x (H * C) x T] tensor after attention.
|
||||
"""
|
||||
bs, width, length = qkv.shape
|
||||
assert width % (3 * self.n_heads) == 0
|
||||
ch = width // (3 * self.n_heads)
|
||||
q, k, v = qkv.chunk(3, dim=1)
|
||||
scale = 1 / math.sqrt(math.sqrt(ch))
|
||||
weight = th.einsum(
|
||||
"bct,bcs->bts",
|
||||
(q * scale).view(bs * self.n_heads, ch, length),
|
||||
(k * scale).view(bs * self.n_heads, ch, length),
|
||||
) # More stable with f16 than dividing afterwards
|
||||
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||
a = th.einsum("bts,bcs->bct", weight,
|
||||
v.reshape(bs * self.n_heads, ch, length))
|
||||
return a.reshape(bs, -1, length)
|
||||
|
||||
@staticmethod
|
||||
def count_flops(model, _x, y):
|
||||
return count_flops_attn(model, _x, y)
|
||||
|
||||
|
||||
class AttentionPool2d(nn.Module):
|
||||
"""
|
||||
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
spacial_dim: int,
|
||||
embed_dim: int,
|
||||
num_heads_channels: int,
|
||||
output_dim: int = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.positional_embedding = nn.Parameter(
|
||||
th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)
|
||||
self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
|
||||
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
|
||||
self.num_heads = embed_dim // num_heads_channels
|
||||
self.attention = QKVAttention(self.num_heads)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, *_spatial = x.shape
|
||||
x = x.reshape(b, c, -1) # NC(HW)
|
||||
x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
|
||||
x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
|
||||
x = self.qkv_proj(x)
|
||||
x = self.attention(x)
|
||||
x = self.c_proj(x)
|
||||
return x[:, :, 0]
|
173
model/graph_convolution_network.py
Normal file
|
@ -0,0 +1,173 @@
|
|||
import torch.nn as nn
|
||||
import torch
|
||||
from dataclasses import dataclass
|
||||
from torch.nn.parameter import Parameter
|
||||
from numbers import Number
|
||||
import torch.nn.functional as F
|
||||
from .blocks import *
|
||||
import math
|
||||
|
||||
|
||||
class graph_convolution(nn.Module):
|
||||
def __init__(self, in_features, out_features, node_n=3, seq_len=80, bias=True):
|
||||
super(graph_convolution, self).__init__()
|
||||
|
||||
self.temporal_graph_weights = Parameter(torch.FloatTensor(seq_len, seq_len))
|
||||
self.feature_weights = Parameter(torch.FloatTensor(in_features, out_features))
|
||||
self.spatial_graph_weights = Parameter(torch.FloatTensor(node_n, node_n))
|
||||
|
||||
if bias:
|
||||
self.bias = Parameter(torch.FloatTensor(seq_len))
        else:
            # register as None so the `self.bias is not None` checks below do not raise
            self.register_parameter('bias', None)
|
||||
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
stdv = 1. / math.sqrt(self.spatial_graph_weights.size(1))
|
||||
self.feature_weights.data.uniform_(-stdv, stdv)
|
||||
self.temporal_graph_weights.data.uniform_(-stdv, stdv)
|
||||
self.spatial_graph_weights.data.uniform_(-stdv, stdv)
|
||||
if self.bias is not None:
|
||||
self.bias.data.uniform_(-stdv, stdv)
|
||||
|
||||
def forward(self, x):
|
||||
y = torch.matmul(x, self.temporal_graph_weights)
|
||||
y = torch.matmul(y.permute(0, 3, 2, 1), self.feature_weights)
|
||||
y = torch.matmul(self.spatial_graph_weights, y).permute(0, 3, 2, 1).contiguous()
|
||||
|
||||
if self.bias is not None:
|
||||
return y + self.bias
|
||||
else:
|
||||
return y
|
||||
|
||||
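# Shape sanity check (illustrative; the sizes are assumptions). The three matmuls
# in forward() mix, in order: time steps (seq_len x seq_len), feature channels
# (in_features x out_features), and the node_n graph nodes (node_n x node_n).
# Guarded so it never runs on import; run as a module because of the relative
# import above (e.g. `python -m model.graph_convolution_network`).
if __name__ == "__main__":
    _gc = graph_convolution(in_features=3, out_features=8, node_n=3, seq_len=80)
    _y = _gc(torch.randn(4, 3, 3, 80))  # [batch, features, nodes, seq_len]
    assert _y.shape == (4, 8, 3, 80)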
|
||||
@dataclass
|
||||
class residual_graph_convolution_config():
|
||||
in_features: int
|
||||
seq_len: int
|
||||
emb_channels: int
|
||||
dropout: float
|
||||
out_features: int = None
|
||||
node_n: int = 3
|
||||
# condition the block with time (and encoder's output)
|
||||
use_condition: bool = True
|
||||
# whether to condition with both time & encoder's output
|
||||
two_cond: bool = False
|
||||
# number of encoders' output channels
|
||||
cond_emb_channels: int = None
|
||||
has_lateral: bool = False
|
||||
graph_convolution_bias: bool = True
|
||||
scale_bias: float = 1
|
||||
|
||||
def __post_init__(self):
|
||||
self.out_features = self.out_features or self.in_features
|
||||
self.cond_emb_channels = self.cond_emb_channels or self.emb_channels
|
||||
|
||||
def make_model(self):
|
||||
return residual_graph_convolution(self)
|
||||
|
||||
|
||||
class residual_graph_convolution(TimestepBlock):
|
||||
def __init__(self, conf: residual_graph_convolution_config):
|
||||
super(residual_graph_convolution, self).__init__()
|
||||
self.conf = conf
|
||||
|
||||
self.gcn = graph_convolution(conf.in_features, conf.out_features, node_n=conf.node_n, seq_len=conf.seq_len, bias=conf.graph_convolution_bias)
|
||||
self.ln = nn.LayerNorm([conf.out_features, conf.node_n, conf.seq_len])
|
||||
self.act_f = nn.Tanh()
|
||||
self.dropout = nn.Dropout(conf.dropout)
|
||||
|
||||
if conf.use_condition:
|
||||
# condition layers for the out_layers
|
||||
self.emb_layers = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(conf.emb_channels, conf.out_features),
|
||||
)
|
||||
|
||||
if conf.two_cond:
|
||||
self.cond_emb_layers = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(conf.cond_emb_channels, conf.out_features),
|
||||
)
|
||||
|
||||
if conf.in_features == conf.out_features:
|
||||
self.skip_connection = nn.Identity()
|
||||
else:
|
||||
self.skip_connection = nn.Sequential(
|
||||
graph_convolution(conf.in_features, conf.out_features, node_n=conf.node_n, seq_len=conf.seq_len, bias=conf.graph_convolution_bias),
|
||||
nn.Tanh(),
|
||||
)
|
||||
|
||||
def forward(self, x, emb=None, cond=None, lateral=None):
|
||||
if self.conf.has_lateral:
|
||||
# lateral may be supplied even if it doesn't require
|
||||
# the model will take the lateral only if "has_lateral"
|
||||
assert lateral is not None
|
||||
x = torch.cat((x, lateral), dim=1)
|
||||
|
||||
y = self.gcn(x)
|
||||
y = self.ln(y)
|
||||
|
||||
if self.conf.use_condition:
|
||||
if emb is not None:
|
||||
emb = self.emb_layers(emb).type(x.dtype)
|
||||
# adjusting shapes
|
||||
while len(emb.shape) < len(y.shape):
|
||||
emb = emb[..., None]
|
||||
|
||||
if self.conf.two_cond or True:  # NOTE: 'or True' makes this branch unconditional; cond is still only applied when provided
|
||||
if cond is not None:
|
||||
if not isinstance(cond, torch.Tensor):
|
||||
assert isinstance(cond, dict)
|
||||
cond = cond['cond']
|
||||
cond = self.cond_emb_layers(cond).type(x.dtype)
|
||||
while len(cond.shape) < len(y.shape):
|
||||
cond = cond[..., None]
|
||||
scales = [emb, cond]
|
||||
else:
|
||||
scales = [emb]
|
||||
|
||||
# condition scale bias could be a list
|
||||
if isinstance(self.conf.scale_bias, Number):
|
||||
biases = [self.conf.scale_bias] * len(scales)
|
||||
else:
|
||||
# a list
|
||||
biases = self.conf.scale_bias
|
||||
|
||||
# scale for each condition
|
||||
for i, scale in enumerate(scales):
|
||||
# if scale is None, it indicates that the condition is not provided
|
||||
if scale is not None:
|
||||
y = y*(biases[i] + scale)
|
||||
|
||||
y = self.act_f(y)
|
||||
y = self.dropout(y)
|
||||
return self.skip_connection(x) + y
|
||||
|
||||
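# Note (illustrative): unlike the scale-and-shift used by `apply_conditions` in
# blocks.py, this block applies the conditions as pure channel-wise scales,
# y = y * (bias + scale), with no shift term, and it uses LayerNorm + Tanh in
# place of GroupNorm + SiLU.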
|
||||
class graph_downsample(nn.Module):
|
||||
"""
|
||||
A downsampling layer
|
||||
"""
|
||||
def __init__(self, kernel_size = 2):
|
||||
super().__init__()
|
||||
self.downsample = nn.AvgPool1d(kernel_size = kernel_size)
|
||||
|
||||
def forward(self, x):
|
||||
bs, features, node_n, seq_len = x.shape
|
||||
x = x.reshape(bs, features*node_n, seq_len)
|
||||
x = self.downsample(x)
|
||||
x = x.reshape(bs, features, node_n, -1)
|
||||
return x
|
||||
|
||||
|
||||
class graph_upsample(nn.Module):
|
||||
"""
|
||||
An upsampling layer
|
||||
"""
|
||||
def __init__(self, scale_factor=2):
|
||||
super().__init__()
|
||||
self.scale_factor = scale_factor
|
||||
|
||||
def forward(self, x):
|
||||
x = F.interpolate(x, (x.shape[2], x.shape[3]*self.scale_factor), mode="nearest")
|
||||
return x
|
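# Sketch (illustrative): graph_downsample / graph_upsample rescale only the
# temporal axis, e.g. [b, f, n, 80] -> [b, f, n, 40] -> [b, f, n, 80]; the
# feature and node dimensions are left untouched.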
141
model/nn.py
Normal file
|
@ -0,0 +1,141 @@
|
|||
"""
|
||||
Various utilities for neural networks.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
|
||||
class SiLU(nn.Module):
|
||||
# @th.jit.script
|
||||
def forward(self, x):
|
||||
return x * th.sigmoid(x)
|
||||
|
||||
|
||||
class GroupNorm32(nn.GroupNorm):
|
||||
def forward(self, x):
|
||||
y = super().forward(x.float()).type(x.dtype)
|
||||
return y
|
||||
|
||||
|
||||
def conv_nd(dims, *args, **kwargs):
|
||||
"""
|
||||
Create a 1D, 2D, or 3D convolution module.
|
||||
"""
|
||||
assert dims == 1  # this codebase only uses 1D (temporal) signals
|
||||
if dims == 1:
|
||||
return nn.Conv1d(*args, **kwargs)
|
||||
elif dims == 2:
|
||||
return nn.Conv2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return nn.Conv3d(*args, **kwargs)
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
|
||||
def linear(*args, **kwargs):
|
||||
"""
|
||||
Create a linear module.
|
||||
"""
|
||||
return nn.Linear(*args, **kwargs)
|
||||
|
||||
|
||||
def avg_pool_nd(dims, *args, **kwargs):
|
||||
"""
|
||||
Create a 1D, 2D, or 3D average pooling module.
|
||||
"""
|
||||
assert dims == 1  # this codebase only uses 1D (temporal) signals
|
||||
if dims == 1:
|
||||
return nn.AvgPool1d(*args, **kwargs)
|
||||
elif dims == 2:
|
||||
return nn.AvgPool2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return nn.AvgPool3d(*args, **kwargs)
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
|
||||
def update_ema(target_params, source_params, rate=0.99):
|
||||
"""
|
||||
Update target parameters to be closer to those of source parameters using
|
||||
an exponential moving average.
|
||||
|
||||
:param target_params: the target parameter sequence.
|
||||
:param source_params: the source parameter sequence.
|
||||
:param rate: the EMA rate (closer to 1 means slower).
|
||||
"""
|
||||
for targ, src in zip(target_params, source_params):
|
||||
targ.detach().mul_(rate).add_(src, alpha=1 - rate)
|
||||
|
||||
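# Worked example (illustrative): with rate = 0.99 each call moves the target 1%
# of the way toward the source, targ <- 0.99 * targ + 0.01 * src, so a parameter
# at 0 tracking a source at 1 reaches about 0.63 after 100 updates (1 - 0.99^100).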
|
||||
def zero_module(module):
|
||||
"""
|
||||
Zero out the parameters of a module and return it.
|
||||
"""
|
||||
for p in module.parameters():
|
||||
p.detach().zero_()
|
||||
return module
|
||||
|
||||
|
||||
def scale_module(module, scale):
|
||||
"""
|
||||
Scale the parameters of a module and return it.
|
||||
"""
|
||||
for p in module.parameters():
|
||||
p.detach().mul_(scale)
|
||||
return module
|
||||
|
||||
|
||||
def mean_flat(tensor):
|
||||
"""
|
||||
Take the mean over all non-batch dimensions.
|
||||
"""
|
||||
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
||||
|
||||
|
||||
def normalization(channels):
|
||||
"""
|
||||
Make a standard normalization layer.
|
||||
|
||||
:param channels: number of input channels.
|
||||
:return: an nn.Module for normalization.
|
||||
"""
|
||||
# return GroupNorm32(channels, channels)
|
||||
return GroupNorm32(min(4, channels), channels)
|
||||
|
||||
|
||||
def timestep_embedding(timesteps, dim, max_period=10000):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
|
||||
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param dim: the dimension of the output.
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
:return: an [N x dim] Tensor of positional embeddings.
|
||||
"""
|
||||
half = dim // 2
|
||||
freqs = th.exp(-math.log(max_period) *
|
||||
th.arange(start=0, end=half, dtype=th.float32) /
|
||||
half).to(device=timesteps.device)
|
||||
args = timesteps[:, None].float() * freqs[None]
|
||||
embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = th.cat(
|
||||
[embedding, th.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
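# Worked example (illustrative): with dim = 4 and max_period = 10000 the two
# frequencies are [1, 1/100] and the layout is [cos(args) | sin(args)], so
# timestep_embedding(th.tensor([0.]), 4) is [[1., 1., 0., 0.]].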
|
||||
def torch_checkpoint(func, args, flag, preserve_rng_state=False):
|
||||
# torch's gradient checkpoint works with automatic mixed precision, given torch >= 1.8
|
||||
if flag:
|
||||
return torch.utils.checkpoint.checkpoint(
|
||||
func, *args, preserve_rng_state=preserve_rng_state)
|
||||
else:
|
||||
return func(*args)
|
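# Usage sketch (illustrative; the names below are assumptions, not from this repo):
# torch_checkpoint trades compute for memory by re-running `func` in the backward
# pass instead of storing its activations. Guarded so it never runs on import.
if __name__ == "__main__":
    _block = nn.Linear(8, 8)
    _x = th.randn(2, 8, requires_grad=True)
    _y = torch_checkpoint(_block, (_x, ), flag=True)
    _y.sum().backward()
    assert _x.grad is not None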
954
model/unet.py
Normal file
954
model/unet.py
Normal file
|
@ -0,0 +1,954 @@
|
|||
import math
|
||||
from dataclasses import dataclass
|
||||
from numbers import Number
|
||||
from typing import NamedTuple, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from choices import *
|
||||
from config_base import BaseConfig
|
||||
from .blocks import *
|
||||
from .graph_convolution_network import *
|
||||
from .nn import (conv_nd, linear, normalization, timestep_embedding, zero_module)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BeatGANsUNetConfig(BaseConfig):
|
||||
seq_len: int = 80
|
||||
in_channels: int = 9
|
||||
# base channels, will be multiplied
|
||||
model_channels: int = 64
|
||||
# output of the unet
|
||||
out_channels: int = 9
|
||||
# how many repeating resblocks per resolution
|
||||
# the decoding side would have "one more" resblock
|
||||
# default: 2
|
||||
num_res_blocks: int = 2
|
||||
# number of time embed channels and style channels
|
||||
embed_channels: int = 256
|
||||
# at what resolutions you want to do self-attention of the feature maps
|
||||
# attentions generally improve performance
|
||||
attention_resolutions: Tuple[int] = (0, )
|
||||
# dropout applies to the resblocks (on feature maps)
|
||||
dropout: float = 0.1
|
||||
channel_mult: Tuple[int] = (1, 2, 4)
|
||||
conv_resample: bool = True
|
||||
# 1 = 1d conv
|
||||
dims: int = 1
|
||||
# number of attention heads
|
||||
num_heads: int = 1
|
||||
# or specify the number of channels per attention head
|
||||
num_head_channels: int = -1
|
||||
# use resblock for upscale/downscale blocks (expensive)
|
||||
# default: True (BeatGANs)
|
||||
resblock_updown: bool = True
|
||||
use_new_attention_order: bool = False
|
||||
resnet_two_cond: bool = True
|
||||
resnet_cond_channels: int = None
|
||||
# init the decoding conv layers with zero weights, this speeds up training
|
||||
# default: True (BeatGANs)
|
||||
resnet_use_zero_module: bool = True
|
||||
|
||||
def make_model(self):
|
||||
return BeatGANsUNetModel(self)
|
||||
|
||||
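# Minimal construction sketch (illustrative, not from the original file; the
# values are the defaults declared above):
#   conf = BeatGANsUNetConfig(seq_len=80, in_channels=9, model_channels=64)
#   model = conf.make_model()
#   out = model(th.randn(2, 9, 80), t=th.tensor([10, 20])).pred  # -> [2, 9, 80]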
|
||||
class BeatGANsUNetModel(nn.Module):
|
||||
def __init__(self, conf: BeatGANsUNetConfig):
|
||||
super().__init__()
|
||||
self.conf = conf
|
||||
|
||||
self.dtype = th.float32
|
||||
|
||||
self.time_emb_channels = conf.model_channels
|
||||
self.time_embed = nn.Sequential(
|
||||
linear(self.time_emb_channels, conf.embed_channels),
|
||||
nn.SiLU(),
|
||||
linear(conf.embed_channels, conf.embed_channels),
|
||||
)
|
||||
|
||||
ch = input_ch = int(conf.channel_mult[0] * conf.model_channels)
|
||||
self.input_blocks = nn.ModuleList([
|
||||
TimestepEmbedSequential(
|
||||
conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1)),
|
||||
])
|
||||
|
||||
kwargs = dict(
|
||||
use_condition=True,
|
||||
two_cond=conf.resnet_two_cond,
|
||||
use_zero_module=conf.resnet_use_zero_module,
|
||||
# style channels for the resnet block
|
||||
cond_emb_channels=conf.resnet_cond_channels,
|
||||
)
|
||||
|
||||
self._feature_size = ch
|
||||
|
||||
# input_block_chans = [ch]
|
||||
input_block_chans = [[] for _ in range(len(conf.channel_mult))]
|
||||
input_block_chans[0].append(ch)
|
||||
|
||||
# number of blocks at each resolution
|
||||
self.input_num_blocks = [0 for _ in range(len(conf.channel_mult))]
|
||||
self.input_num_blocks[0] = 1
|
||||
self.output_num_blocks = [0 for _ in range(len(conf.channel_mult))]
|
||||
|
||||
ds = 1
|
||||
resolution = conf.seq_len
|
||||
for level, mult in enumerate(conf.channel_mult):
|
||||
for _ in range(conf.num_res_blocks):
|
||||
layers = [
|
||||
ResBlockConfig(
|
||||
ch,
|
||||
conf.embed_channels,
|
||||
conf.dropout,
|
||||
out_channels=int(mult * conf.model_channels),
|
||||
dims=conf.dims,
|
||||
**kwargs,
|
||||
).make_model()
|
||||
]
|
||||
ch = int(mult * conf.model_channels)
|
||||
if resolution in conf.attention_resolutions:
|
||||
layers.append(
|
||||
AttentionBlock(
|
||||
ch,
|
||||
num_heads=conf.num_heads,
|
||||
num_head_channels=conf.num_head_channels,
|
||||
use_new_attention_order=conf.
|
||||
use_new_attention_order,
|
||||
))
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self._feature_size += ch
|
||||
# input_block_chans.append(ch)
|
||||
input_block_chans[level].append(ch)
|
||||
self.input_num_blocks[level] += 1
|
||||
# print(input_block_chans)
|
||||
if level != len(conf.channel_mult) - 1:
|
||||
resolution //= 2
|
||||
out_ch = ch
|
||||
self.input_blocks.append(
|
||||
TimestepEmbedSequential(
|
||||
ResBlockConfig(
|
||||
ch,
|
||||
conf.embed_channels,
|
||||
conf.dropout,
|
||||
out_channels=out_ch,
|
||||
dims=conf.dims,
|
||||
down=True,
|
||||
**kwargs,
|
||||
).make_model() if conf.
|
||||
resblock_updown else Downsample(ch,
|
||||
conf.conv_resample,
|
||||
dims=conf.dims,
|
||||
out_channels=out_ch)))
|
||||
ch = out_ch
|
||||
# input_block_chans.append(ch)
|
||||
input_block_chans[level + 1].append(ch)
|
||||
self.input_num_blocks[level + 1] += 1
|
||||
ds *= 2
|
||||
self._feature_size += ch
|
||||
|
||||
self.middle_block = TimestepEmbedSequential(
|
||||
ResBlockConfig(
|
||||
ch,
|
||||
conf.embed_channels,
|
||||
conf.dropout,
|
||||
dims=conf.dims,
|
||||
**kwargs,
|
||||
).make_model(),
|
||||
#AttentionBlock(
|
||||
# ch,
|
||||
# num_heads=conf.num_heads,
|
||||
# num_head_channels=conf.num_head_channels,
|
||||
# use_new_attention_order=conf.use_new_attention_order,
|
||||
#),
|
||||
ResBlockConfig(
|
||||
ch,
|
||||
conf.embed_channels,
|
||||
conf.dropout,
|
||||
dims=conf.dims,
|
||||
**kwargs,
|
||||
).make_model(),
|
||||
)
|
||||
self._feature_size += ch
|
||||
|
||||
self.output_blocks = nn.ModuleList([])
|
||||
for level, mult in list(enumerate(conf.channel_mult))[::-1]:
|
||||
for i in range(conf.num_res_blocks + 1):
|
||||
# print(input_block_chans)
|
||||
# ich = input_block_chans.pop()
|
||||
try:
|
||||
ich = input_block_chans[level].pop()
|
||||
except IndexError:
|
||||
# this happens only when num_res_block > num_enc_res_block
|
||||
# we will not have enough lateral (skip) connections for all decoder blocks
|
||||
ich = 0
|
||||
# print('pop:', ich)
|
||||
layers = [
|
||||
ResBlockConfig(
|
||||
# only direct channels when gated
|
||||
channels=ch + ich,
|
||||
emb_channels=conf.embed_channels,
|
||||
dropout=conf.dropout,
|
||||
out_channels=int(conf.model_channels * mult),
|
||||
dims=conf.dims,
|
||||
# lateral channels are described here when gated
|
||||
has_lateral=ich > 0,
|
||||
lateral_channels=None,
|
||||
**kwargs,
|
||||
).make_model()
|
||||
]
|
||||
ch = int(conf.model_channels * mult)
|
||||
if resolution in conf.attention_resolutions:
|
||||
layers.append(
|
||||
AttentionBlock(
|
||||
ch,
|
||||
num_heads=conf.num_heads,
|
||||
num_head_channels=conf.num_head_channels,
|
||||
use_new_attention_order=conf.
|
||||
use_new_attention_order,
|
||||
))
|
||||
if level and i == conf.num_res_blocks:
|
||||
resolution *= 2
|
||||
out_ch = ch
|
||||
layers.append(
|
||||
ResBlockConfig(
|
||||
ch,
|
||||
conf.embed_channels,
|
||||
conf.dropout,
|
||||
out_channels=out_ch,
|
||||
dims=conf.dims,
|
||||
up=True,
|
||||
**kwargs,
|
||||
).make_model() if (
|
||||
conf.resblock_updown
|
||||
) else Upsample(ch,
|
||||
conf.conv_resample,
|
||||
dims=conf.dims,
|
||||
out_channels=out_ch))
|
||||
ds //= 2
|
||||
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self.output_num_blocks[level] += 1
|
||||
self._feature_size += ch
|
||||
|
||||
# print(input_block_chans)
|
||||
# print('inputs:', self.input_num_blocks)
|
||||
# print('outputs:', self.output_num_blocks)
|
||||
|
||||
if conf.resnet_use_zero_module:
|
||||
self.out = nn.Sequential(
|
||||
normalization(ch),
|
||||
nn.SiLU(),
|
||||
zero_module(
|
||||
conv_nd(conf.dims,
|
||||
input_ch,
|
||||
conf.out_channels,
|
||||
3, ## kernel size
|
||||
padding=1)),
|
||||
)
|
||||
else:
|
||||
self.out = nn.Sequential(
|
||||
normalization(ch),
|
||||
nn.SiLU(),
|
||||
conv_nd(conf.dims, input_ch, conf.out_channels, 3, padding=1),
|
||||
)
|
||||
|
||||
def forward(self, x, t, **kwargs):
|
||||
"""
|
||||
Apply the model to an input batch.
|
||||
|
||||
:param x: an [N x C x ...] Tensor of inputs.
|
||||
:param t: a 1-D batch of timesteps.
|
||||
:return: an [N x C x ...] Tensor of outputs.
|
||||
"""
|
||||
|
||||
# hs = []
|
||||
hs = [[] for _ in range(len(self.conf.channel_mult))]
|
||||
emb = self.time_embed(timestep_embedding(t, self.time_emb_channels))
|
||||
|
||||
# new code supports input_num_blocks != output_num_blocks
|
||||
h = x.type(self.dtype)
|
||||
k = 0
|
||||
for i in range(len(self.input_num_blocks)):
|
||||
for j in range(self.input_num_blocks[i]):
|
||||
h = self.input_blocks[k](h, emb=emb)
|
||||
# print(i, j, h.shape)
|
||||
hs[i].append(h) ## Get output from each layer
|
||||
k += 1
|
||||
assert k == len(self.input_blocks)
|
||||
|
||||
# middle blocks
|
||||
h = self.middle_block(h, emb=emb)
|
||||
|
||||
# output blocks
|
||||
k = 0
|
||||
for i in range(len(self.output_num_blocks)):
|
||||
for j in range(self.output_num_blocks[i]):
|
||||
# take the lateral connection from the same layer (in reverse)
|
||||
# until there is no more, use None
|
||||
try:
|
||||
lateral = hs[-i - 1].pop()
|
||||
# print(i, j, lateral.shape)
|
||||
except IndexError:
|
||||
lateral = None
|
||||
# print(i, j, lateral)
|
||||
h = self.output_blocks[k](h, emb=emb, lateral=lateral)
|
||||
k += 1
|
||||
|
||||
h = h.type(x.dtype)
|
||||
pred = self.out(h)
|
||||
return Return(pred=pred)
|
||||
|
||||
|
||||
class Return(NamedTuple):
|
||||
pred: th.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class BeatGANsEncoderConfig(BaseConfig):
|
||||
in_channels: int
|
||||
seq_len: int = 80
|
||||
num_res_blocks: int = 2
|
||||
attention_resolutions: Tuple[int] = (0, )
|
||||
model_channels: int = 32
|
||||
out_channels: int = 256
|
||||
dropout: float = 0.1
|
||||
channel_mult: Tuple[int] = (1, 2, 4)
|
||||
use_time_condition: bool = False
|
||||
conv_resample: bool = True
|
||||
dims: int = 1
|
||||
num_heads: int = 1
|
||||
num_head_channels: int = -1
|
||||
resblock_updown: bool = True
|
||||
use_new_attention_order: bool = False
|
||||
pool: str = 'adaptivenonzero'
|
||||
|
||||
def make_model(self):
|
||||
return BeatGANsEncoderModel(self)
|
||||
|
||||
|
||||
class BeatGANsEncoderModel(nn.Module):
|
||||
"""
|
||||
The half UNet model with attention and timestep embedding.
|
||||
|
||||
For usage, see UNet.
|
||||
"""
|
||||
def __init__(self, conf: BeatGANsEncoderConfig):
|
||||
super().__init__()
|
||||
self.conf = conf
|
||||
self.dtype = th.float32
|
||||
|
||||
if conf.use_time_condition:
|
||||
time_embed_dim = conf.model_channels
|
||||
self.time_embed = nn.Sequential(
|
||||
linear(conf.model_channels, time_embed_dim),
|
||||
nn.SiLU(),
|
||||
linear(time_embed_dim, time_embed_dim),
|
||||
)
|
||||
else:
|
||||
time_embed_dim = None
|
||||
|
||||
ch = int(conf.channel_mult[0] * conf.model_channels)
|
||||
self.input_blocks = nn.ModuleList([
|
||||
TimestepEmbedSequential(
|
||||
conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1),
|
||||
)
|
||||
])
|
||||
self._feature_size = ch
|
||||
input_block_chans = [ch]
|
||||
ds = 1
|
||||
resolution = conf.seq_len
|
||||
for level, mult in enumerate(conf.channel_mult):
|
||||
for _ in range(conf.num_res_blocks):
|
||||
layers = [
|
||||
ResBlockConfig(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
conf.dropout,
|
||||
out_channels=int(mult * conf.model_channels),
|
||||
dims=conf.dims,
|
||||
use_condition=conf.use_time_condition,
|
||||
).make_model()
|
||||
]
|
||||
ch = int(mult * conf.model_channels)
|
||||
if resolution in conf.attention_resolutions:
|
||||
layers.append(
|
||||
AttentionBlock(
|
||||
ch,
|
||||
num_heads=conf.num_heads,
|
||||
num_head_channels=conf.num_head_channels,
|
||||
use_new_attention_order=conf.
|
||||
use_new_attention_order,
|
||||
))
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self._feature_size += ch
|
||||
input_block_chans.append(ch)
|
||||
if level != len(conf.channel_mult) - 1:
|
||||
resolution //= 2
|
||||
out_ch = ch
|
||||
self.input_blocks.append(
|
||||
TimestepEmbedSequential(
|
||||
ResBlockConfig(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
conf.dropout,
|
||||
out_channels=out_ch,
|
||||
dims=conf.dims,
|
||||
use_condition=conf.use_time_condition,
|
||||
down=True,
|
||||
).make_model() if (
|
||||
conf.resblock_updown
|
||||
) else Downsample(ch,
|
||||
conf.conv_resample,
|
||||
dims=conf.dims,
|
||||
out_channels=out_ch)))
|
||||
ch = out_ch
|
||||
input_block_chans.append(ch)
|
||||
ds *= 2
|
||||
self._feature_size += ch
|
||||
|
||||
self.middle_block = TimestepEmbedSequential(
|
||||
ResBlockConfig(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
conf.dropout,
|
||||
dims=conf.dims,
|
||||
use_condition=conf.use_time_condition,
|
||||
).make_model(),
|
||||
AttentionBlock(
|
||||
ch,
|
||||
num_heads=conf.num_heads,
|
||||
num_head_channels=conf.num_head_channels,
|
||||
use_new_attention_order=conf.use_new_attention_order,
|
||||
),
|
||||
ResBlockConfig(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
conf.dropout,
|
||||
dims=conf.dims,
|
||||
use_condition=conf.use_time_condition,
|
||||
).make_model(),
|
||||
)
|
||||
self._feature_size += ch
|
||||
if conf.pool == "adaptivenonzero":
|
||||
self.out = nn.Sequential(
|
||||
normalization(ch),
|
||||
nn.SiLU(),
|
||||
## nn.AdaptiveAvgPool2d((1, 1)),
|
||||
nn.AdaptiveAvgPool1d(1),
|
||||
conv_nd(conf.dims, ch, conf.out_channels, 1),
|
||||
nn.Flatten(),
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unexpected {conf.pool} pooling")
|
||||
|
||||
def forward(self, x, t=None):
|
||||
"""
|
||||
Apply the model to an input batch.
|
||||
|
||||
:param x: an [N x C x ...] Tensor of inputs.
|
||||
:param t: a 1-D batch of timesteps.
|
||||
:return: an [N x K] Tensor of outputs.
|
||||
"""
|
||||
if self.conf.use_time_condition:
|
||||
emb = self.time_embed(timestep_embedding(t, self.conf.model_channels))
|
||||
else: ## autoencoding.py
|
||||
emb = None
|
||||
|
||||
results = []
|
||||
h = x.type(self.dtype)
|
||||
for module in self.input_blocks: ## flow input x over all the input blocks
|
||||
h = module(h, emb=emb)
|
||||
if self.conf.pool.startswith("spatial"):
|
||||
results.append(h.type(x.dtype).mean(dim=(2, 3)))
|
||||
h = self.middle_block(h, emb=emb) ## TimestepEmbedSequential(...)
|
||||
if self.conf.pool.startswith("spatial"):
|
||||
results.append(h.type(x.dtype).mean(dim=(2, 3)))
|
||||
h = th.cat(results, axis=-1)
|
||||
else: ## autoencoder.py
|
||||
h = h.type(x.dtype)
|
||||
|
||||
h = h.float()
|
||||
h = self.out(h)
|
||||
|
||||
return h
|
||||
|
||||
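# Note (illustrative): with pool='adaptivenonzero' the encoder collapses the whole
# sequence into one semantic vector: [N, C, T] -> AdaptiveAvgPool1d(1) -> [N, C, 1]
# -> 1x1 conv to out_channels -> Flatten -> [N, out_channels].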
|
||||
@dataclass
|
||||
class GCNUNetConfig(BaseConfig):
|
||||
in_channels: int = 9
|
||||
node_n: int = 3
|
||||
seq_len: int = 80
|
||||
# base channels, will be multiplied
|
||||
model_channels: int = 32
|
||||
# output of the unet
|
||||
out_channels: int = 9
|
||||
# how many repeating resblocks per resolution
|
||||
num_res_blocks: int = 8
|
||||
# number of time embed channels and style channels
|
||||
embed_channels: int = 256
|
||||
# dropout applies to the resblocks
|
||||
dropout: float = 0.1
|
||||
channel_mult: Tuple[int] = (1, 2, 4)
|
||||
resnet_two_cond: bool = True
|
||||
|
||||
def make_model(self):
|
||||
return GCNUNetModel(self)
|
||||
|
||||
|
||||
class GCNUNetModel(nn.Module):
|
||||
def __init__(self, conf: GCNUNetConfig):
|
||||
super().__init__()
|
||||
self.conf = conf
|
||||
self.dtype = th.float32
|
||||
assert conf.in_channels%conf.node_n == 0
|
||||
self.in_features = conf.in_channels//conf.node_n
|
||||
|
||||
self.time_emb_channels = conf.model_channels*4
|
||||
self.time_embed = nn.Sequential(
|
||||
linear(self.time_emb_channels, conf.embed_channels),
|
||||
nn.SiLU(),
|
||||
linear(conf.embed_channels, conf.embed_channels),
|
||||
)
|
||||
|
||||
ch = int(conf.channel_mult[0] * conf.model_channels)
|
||||
self.input_blocks = nn.ModuleList([
|
||||
TimestepEmbedSequential(
|
||||
graph_convolution(in_features=self.in_features, out_features=ch, node_n=conf.node_n, seq_len=conf.seq_len)),
|
||||
])
|
||||
|
||||
kwargs = dict(
|
||||
use_condition=True,
|
||||
two_cond=conf.resnet_two_cond,
|
||||
)
|
||||
|
||||
input_block_chans = [[] for _ in range(len(conf.channel_mult))]
|
||||
input_block_chans[0].append(ch)
|
||||
|
||||
# number of blocks at each resolution
|
||||
self.input_num_blocks = [0 for _ in range(len(conf.channel_mult))]
|
||||
self.input_num_blocks[0] = 1
|
||||
self.output_num_blocks = [0 for _ in range(len(conf.channel_mult))]
|
||||
|
||||
ds = 1
|
||||
resolution = conf.seq_len
|
||||
for level, mult in enumerate(conf.channel_mult):
|
||||
for _ in range(conf.num_res_blocks):
|
||||
layers = [
|
||||
residual_graph_convolution_config(
|
||||
in_features=ch,
|
||||
seq_len=resolution,
|
||||
emb_channels = conf.embed_channels,
|
||||
dropout=conf.dropout,
|
||||
out_features=int(mult * conf.model_channels),
|
||||
node_n=conf.node_n,
|
||||
**kwargs,
|
||||
).make_model()
|
||||
]
|
||||
ch = int(mult * conf.model_channels)
|
||||
self.input_blocks.append(*layers)
|
||||
input_block_chans[level].append(ch)
|
||||
self.input_num_blocks[level] += 1
|
||||
if level != len(conf.channel_mult) - 1:
|
||||
resolution //= 2
|
||||
out_ch = ch
|
||||
self.input_blocks.append(
|
||||
TimestepEmbedSequential(
|
||||
graph_downsample()))
|
||||
ch = out_ch
|
||||
input_block_chans[level + 1].append(ch)
|
||||
self.input_num_blocks[level + 1] += 1
|
||||
ds *= 2
|
||||
|
||||
self.output_blocks = nn.ModuleList([])
|
||||
for level, mult in list(enumerate(conf.channel_mult))[::-1]:
|
||||
for i in range(conf.num_res_blocks + 1):
|
||||
try:
|
||||
ich = input_block_chans[level].pop()
|
||||
except IndexError:
|
||||
# this happens only when num_res_block > num_enc_res_block
|
||||
# we will not have enough lateral (skip) connections for all decoder blocks
|
||||
ich = 0
|
||||
layers = [
|
||||
residual_graph_convolution_config(
|
||||
in_features=ch + ich,
|
||||
seq_len=resolution,
|
||||
emb_channels = conf.embed_channels,
|
||||
dropout=conf.dropout,
|
||||
out_features=int(mult * conf.model_channels),
|
||||
node_n=conf.node_n,
|
||||
has_lateral=ich > 0,
|
||||
**kwargs,
|
||||
).make_model()
|
||||
]
|
||||
ch = int(mult*conf.model_channels)
|
||||
if level and i == conf.num_res_blocks:
|
||||
resolution *= 2
|
||||
out_ch = ch
|
||||
layers.append(graph_upsample())
|
||||
ds //= 2
|
||||
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self.output_num_blocks[level] += 1
|
||||
|
||||
self.out = nn.Sequential(
|
||||
graph_convolution(in_features=ch, out_features=self.in_features, node_n=conf.node_n, seq_len=conf.seq_len),
|
||||
nn.Tanh(),
|
||||
)
|
||||
|
||||
def forward(self, x, t, **kwargs):
|
||||
"""
|
||||
Apply the model to an input batch.
|
||||
|
||||
:param x: an [N x C x ...] Tensor of inputs.
|
||||
:param t: a 1-D batch of timesteps.
|
||||
:return: an [N x C x ...] Tensor of outputs.
|
||||
"""
|
||||
bs, channels, seq_len = x.shape
|
||||
x = x.reshape(bs, self.conf.node_n, self.in_features, seq_len).permute(0, 2, 1, 3)
|
||||
|
||||
hs = [[] for _ in range(len(self.conf.channel_mult))]
|
||||
emb = self.time_embed(timestep_embedding(t, self.time_emb_channels))
|
||||
|
||||
# new code supports input_num_blocks != output_num_blocks
|
||||
h = x.type(self.dtype)
|
||||
k = 0
|
||||
for i in range(len(self.input_num_blocks)):
|
||||
for j in range(self.input_num_blocks[i]):
|
||||
h = self.input_blocks[k](h, emb=emb)
|
||||
# print(i, j, h.shape)
|
||||
hs[i].append(h) ## Get output from each layer
|
||||
k += 1
|
||||
assert k == len(self.input_blocks)
|
||||
|
||||
# output blocks
|
||||
k = 0
|
||||
for i in range(len(self.output_num_blocks)):
|
||||
for j in range(self.output_num_blocks[i]):
|
||||
# take the lateral connection from the same layer (in reverse)
|
||||
# until there is no more, use None
|
||||
try:
|
||||
lateral = hs[-i - 1].pop()
|
||||
# print(i, j, lateral.shape)
|
||||
except IndexError:
|
||||
lateral = None
|
||||
# print(i, j, lateral)
|
||||
h = self.output_blocks[k](h, emb=emb, lateral=lateral)
|
||||
k += 1
|
||||
|
||||
h = h.type(x.dtype)
|
||||
pred = self.out(h)
|
||||
pred = pred.permute(0, 2, 1, 3).reshape(bs, -1, seq_len)
|
||||
|
||||
return Return(pred=pred)
|
||||
|
||||
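# Layout note (illustrative): the GCN UNet works internally on
# [b, features, nodes, T]; forward() unflattens the channel axis
# (node_n * in_features) into that layout on the way in and flattens it back to
# [b, channels, T] on the way out.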
|
||||
@dataclass
|
||||
class GCNEncoderConfig(BaseConfig):
|
||||
in_channels: int
|
||||
in_features: int = 3  # features for one node
|
||||
seq_len: int = 40
|
||||
seq_len_future: int = 3
|
||||
num_res_blocks: int = 2
|
||||
model_channels: int = 32
|
||||
out_channels: int = 32
|
||||
dropout: float = 0.1
|
||||
channel_mult: Tuple[int] = (1, 2, 4)
|
||||
use_time_condition: bool = False
|
||||
|
||||
def make_model(self):
|
||||
return GCNEncoderModel(self)
|
||||
|
||||
|
||||
class GCNEncoderModel(nn.Module):
|
||||
def __init__(self, conf: GCNEncoderConfig):
|
||||
super().__init__()
|
||||
self.conf = conf
|
||||
self.dtype = th.float32
|
||||
assert conf.in_channels%conf.in_features == 0
|
||||
self.in_features = conf.in_features
|
||||
self.node_n = conf.in_channels//conf.in_features
|
||||
|
||||
ch = int(conf.channel_mult[0] * conf.model_channels)
|
||||
self.input_blocks = nn.ModuleList([
|
||||
graph_convolution(in_features=self.in_features, out_features=ch, node_n=self.node_n, seq_len=conf.seq_len),
|
||||
])
|
||||
input_block_chans = [ch]
|
||||
ds = 1
|
||||
resolution = conf.seq_len
|
||||
for level, mult in enumerate(conf.channel_mult):
|
||||
for _ in range(conf.num_res_blocks):
|
||||
layers = [
|
||||
residual_graph_convolution_config(
|
||||
in_features=ch,
|
||||
seq_len=resolution,
|
||||
emb_channels = None,
|
||||
dropout=conf.dropout,
|
||||
out_features=int(mult * conf.model_channels),
|
||||
node_n=self.node_n,
|
||||
use_condition=conf.use_time_condition,
|
||||
).make_model()
|
||||
]
|
||||
ch = int(mult * conf.model_channels)
|
||||
self.input_blocks.append(*layers)
|
||||
input_block_chans.append(ch)
|
||||
if level != len(conf.channel_mult) - 1:
|
||||
resolution //= 2
|
||||
out_ch = ch
|
||||
self.input_blocks.append(
|
||||
graph_downsample())
|
||||
ch = out_ch
|
||||
input_block_chans.append(ch)
|
||||
ds *= 2
|
||||
|
||||
self.hand_prediction = nn.Sequential(
|
||||
conv_nd(1, ch*2, ch*2, 3, padding=1),
|
||||
nn.LayerNorm([ch*2, conf.seq_len_future]),
|
||||
nn.Tanh(),
|
||||
conv_nd(1, ch*2, self.in_features*2, 1),
|
||||
nn.Tanh(),
|
||||
)
|
||||
|
||||
self.head_prediction = nn.Sequential(
|
||||
conv_nd(1, ch, ch, 3, padding=1),
|
||||
nn.LayerNorm([ch, conf.seq_len_future]),
|
||||
nn.Tanh(),
|
||||
conv_nd(1, ch, self.in_features, 1),
|
||||
nn.Tanh(),
|
||||
)
|
||||
|
||||
self.out = nn.Sequential(
|
||||
nn.AdaptiveAvgPool1d(1),
|
||||
conv_nd(1, ch*self.node_n, conf.out_channels, 1),
|
||||
nn.Flatten(),
|
||||
)
|
||||
|
||||
|
||||
def forward(self, x, t=None):
|
||||
bs, channels, seq_len = x.shape
|
||||
|
||||
if self.node_n == 3: # both hand and head
|
||||
hand_last = x[:, :6, -1:].expand(-1, -1, self.conf.seq_len_future).clone() #last observed hand position
|
||||
head_last = x[:, 6:, -1:].expand(-1, -1, self.conf.seq_len_future).clone()# last observed head orientation
|
||||
|
||||
if self.node_n == 2: # hand only
|
||||
hand_last = x[:, :, -1:].expand(-1, -1, self.conf.seq_len_future).clone() #last observed hand position
|
||||
|
||||
if self.node_n == 1: # head only
|
||||
head_last = x[:, :, -1:].expand(-1, -1, self.conf.seq_len_future).clone()# last observed head orientation
|
||||
|
||||
|
||||
x = x.reshape(bs, self.node_n, self.in_features, seq_len).permute(0, 2, 1, 3)
|
||||
|
||||
h = x.type(self.dtype)
|
||||
for module in self.input_blocks:
|
||||
h = module(h)
|
||||
|
||||
h = h.type(x.dtype)
|
||||
h = h.float()
|
||||
bs, features, node_n, seq_len = h.shape
|
||||
|
||||
if self.node_n == 3: # both hand and head
|
||||
hand_features = h[:, :, :2, -self.conf.seq_len_future:].reshape(bs, features*2, -1)
|
||||
head_features = h[:, :, 2:, -self.conf.seq_len_future:].reshape(bs, features, -1)
|
||||
|
||||
pred_hand = self.hand_prediction(hand_features) + hand_last
|
||||
pred_head = self.head_prediction(head_features) + head_last
|
||||
pred_head = F.normalize(pred_head, dim=1)# normalize head orientation to unit vectors
|
||||
|
||||
if self.node_n == 2: # hand only
|
||||
hand_features = h[:, :, :, -self.conf.seq_len_future:].reshape(bs, features*2, -1)
|
||||
pred_hand = self.hand_prediction(hand_features) + hand_last
|
||||
pred_head = None
|
||||
|
||||
if self.node_n == 1: # head only
|
||||
head_features = h[:, :, :, -self.conf.seq_len_future:].reshape(bs, features, -1)
|
||||
pred_head = self.head_prediction(head_features) + head_last
|
||||
pred_head = F.normalize(pred_head, dim=1)# normalize head orientation to unit vectors
|
||||
pred_hand = None
|
||||
|
||||
h = h.reshape(bs, features*node_n, seq_len)
|
||||
h = self.out(h)
|
||||
|
||||
return h, pred_hand, pred_head
|
||||
|
||||
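# Note (illustrative): the encoder returns (h, pred_hand, pred_head): the semantic
# latent plus auxiliary seq_len_future-step forecasts that are anchored at the last
# observed frame (residual prediction); head orientations are normalized to unit
# vectors.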
|
||||
@dataclass
|
||||
class CNNEncoderConfig(BaseConfig):
|
||||
in_channels: int
|
||||
seq_len: int = 40
|
||||
seq_len_future: int = 3
|
||||
out_channels: int = 128
|
||||
|
||||
def make_model(self):
|
||||
return CNNEncoderModel(self)
|
||||
|
||||
|
||||
class CNNEncoderModel(nn.Module):
|
||||
def __init__(self, conf: CNNEncoderConfig):
|
||||
super().__init__()
|
||||
self.conf = conf
|
||||
self.dtype = th.float32
|
||||
input_dim = conf.in_channels
|
||||
length = conf.seq_len
|
||||
out_channels = conf.out_channels
|
||||
|
||||
self.encoder = nn.Sequential(
|
||||
nn.Conv1d(input_dim, 32, kernel_size=3, padding=1),
|
||||
nn.LayerNorm([32, length]),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv1d(32, 32, kernel_size=3, padding=1),
|
||||
nn.LayerNorm([32, length]),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv1d(32, 32, kernel_size=3, padding=1),
|
||||
nn.LayerNorm([32, length]),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
|
||||
self.out = nn.Linear(32 * length, out_channels)
|
||||
|
||||
def forward(self, x, t=None):
|
||||
bs, channels, seq_len = x.shape
|
||||
hand_last = x[:, :6, -1:].expand(-1, -1, self.conf.seq_len_future).clone() #last observed hand position
|
||||
head_last = x[:, 6:, -1:].expand(-1, -1, self.conf.seq_len_future).clone()# last observed head orientation
|
||||
|
||||
h = x.type(self.dtype)
|
||||
h = self.encoder(h)
|
||||
h = h.view(h.shape[0], -1)
|
||||
|
||||
h = h.type(x.dtype)
|
||||
h = h.float()
|
||||
|
||||
h = self.out(h)
|
||||
return h, hand_last, head_last
|
||||
|
||||
|
||||
@dataclass
|
||||
class GRUEncoderConfig(BaseConfig):
|
||||
in_channels: int
|
||||
seq_len: int = 40
|
||||
seq_len_future: int = 3
|
||||
out_channels: int = 128
|
||||
|
||||
def make_model(self):
|
||||
return GRUEncoderModel(self)
|
||||
|
||||
|
||||
class GRUEncoderModel(nn.Module):
|
||||
def __init__(self, conf: GRUEncoderConfig):
|
||||
super().__init__()
|
||||
self.conf = conf
|
||||
self.dtype = th.float32
|
||||
input_dim = conf.in_channels
|
||||
length = conf.seq_len
|
||||
feature_channels = 32
|
||||
out_channels = conf.out_channels
|
||||
|
||||
self.encoder = nn.GRU(input_dim, feature_channels, 1, batch_first=True)
|
||||
|
||||
self.out = nn.Linear(feature_channels * length, out_channels)
|
||||
|
||||
def forward(self, x, t=None):
|
||||
bs, channels, seq_len = x.shape
|
||||
hand_last = x[:, :6, -1:].expand(-1, -1, self.conf.seq_len_future).clone() #last observed hand position
|
||||
head_last = x[:, 6:, -1:].expand(-1, -1, self.conf.seq_len_future).clone()# last observed head orientation
|
||||
|
||||
h = x.type(self.dtype)
|
||||
h, _ = self.encoder(h.permute(0, 2, 1))
|
||||
h = h.reshape(h.shape[0], -1)
|
||||
|
||||
h = h.type(x.dtype)
|
||||
h = h.float()
|
||||
|
||||
h = self.out(h)
|
||||
return h, hand_last, head_last
|
||||
|
||||
|
||||
@dataclass
|
||||
class LSTMEncoderConfig(BaseConfig):
|
||||
in_channels: int
|
||||
seq_len: int = 40
|
||||
seq_len_future: int = 3
|
||||
out_channels: int = 128
|
||||
|
||||
def make_model(self):
|
||||
return LSTMEncoderModel(self)
|
||||
|
||||
|
||||
class LSTMEncoderModel(nn.Module):
|
||||
def __init__(self, conf: LSTMEncoderConfig):
|
||||
super().__init__()
|
||||
self.conf = conf
|
||||
self.dtype = th.float32
|
||||
input_dim = conf.in_channels
|
||||
length = conf.seq_len
|
||||
feature_channels = 32
|
||||
out_channels = conf.out_channels
|
||||
|
||||
self.encoder = nn.LSTM(input_dim, feature_channels, 1, batch_first=True)
|
||||
|
||||
self.out = nn.Linear(feature_channels * length, out_channels)
|
||||
|
||||
def forward(self, x, t=None):
|
||||
bs, channels, seq_len = x.shape
|
||||
hand_last = x[:, :6, -1:].expand(-1, -1, self.conf.seq_len_future).clone() #last observed hand position
|
||||
head_last = x[:, 6:, -1:].expand(-1, -1, self.conf.seq_len_future).clone()# last observed head orientation
|
||||
|
||||
h = x.type(self.dtype)
|
||||
h, _ = self.encoder(h.permute(0, 2, 1))
|
||||
h = h.reshape(h.shape[0], -1)
|
||||
|
||||
h = h.type(x.dtype)
|
||||
h = h.float()
|
||||
|
||||
h = self.out(h)
|
||||
return h, hand_last, head_last
|
||||
|
||||
|
||||
@dataclass
|
||||
class MLPEncoderConfig(BaseConfig):
|
||||
in_channels: int
|
||||
seq_len: int = 40
|
||||
seq_len_future: int = 3
|
||||
out_channels: int = 128
|
||||
|
||||
def make_model(self):
|
||||
return MLPEncoderModel(self)
|
||||
|
||||
|
||||
class MLPEncoderModel(nn.Module):
|
||||
def __init__(self, conf: MLPEncoderConfig):
|
||||
super().__init__()
|
||||
self.conf = conf
|
||||
self.dtype = th.float32
|
||||
input_dim = conf.in_channels
|
||||
length = conf.seq_len
|
||||
out_channels = conf.out_channels
|
||||
|
||||
linear_size = 128
|
||||
self.encoder = nn.Sequential(
|
||||
nn.Linear(length*input_dim, linear_size),
|
||||
nn.LayerNorm([linear_size]),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Linear(linear_size, linear_size),
|
||||
nn.LayerNorm([linear_size]),
|
||||
nn.ReLU(inplace=True),
|
||||
)
|
||||
|
||||
self.out = nn.Linear(linear_size, out_channels)
|
||||
|
||||
def forward(self, x, t=None):
|
||||
bs, channels, seq_len = x.shape
|
||||
hand_last = x[:, :6, -1:].expand(-1, -1, self.conf.seq_len_future).clone() #last observed hand position
|
||||
head_last = x[:, 6:, -1:].expand(-1, -1, self.conf.seq_len_future).clone()# last observed head orientation
|
||||
|
||||
h = x.type(self.dtype)
|
||||
|
||||
h = h.view(h.shape[0], -1)
|
||||
h = self.encoder(h)
|
||||
|
||||
h = h.type(x.dtype)
|
||||
h = h.float()
|
||||
|
||||
h = self.out(h)
|
||||
return h, hand_last, head_last
|
418
model/unet_autoenc.py
Normal file
|
@ -0,0 +1,418 @@
|
|||
from enum import Enum
|
||||
|
||||
import torch, pdb
|
||||
import os
|
||||
from torch import Tensor
|
||||
from torch.nn.functional import silu
|
||||
from .unet import *
|
||||
from choices import *
|
||||
|
||||
|
||||
@dataclass
|
||||
class BeatGANsAutoencConfig(BeatGANsUNetConfig):
|
||||
seq_len_future: int = 3
|
||||
enc_out_channels: int = 128
|
||||
semantic_encoder_type: str = 'gcn'
|
||||
enc_channel_mult: Tuple[int] = None
|
||||
def make_model(self):
|
||||
return BeatGANsAutoencModel(self)
|
||||
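# Note (illustrative): the autoencoder is the BeatGANs UNet above plus a semantic
# encoder; `semantic_encoder_type` selects among 'gcn', '1dcnn', 'gru', 'lstm' and
# 'mlp' in the constructor below.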
|
||||
class BeatGANsAutoencModel(BeatGANsUNetModel):
|
||||
def __init__(self, conf: BeatGANsAutoencConfig):
|
||||
super().__init__(conf)
|
||||
self.conf = conf
|
||||
|
||||
# having only time, cond
|
||||
self.time_embed = TimeStyleSeperateEmbed(
|
||||
time_channels=conf.model_channels,
|
||||
time_out_channels=conf.embed_channels,
|
||||
)
|
||||
|
||||
if conf.semantic_encoder_type == 'gcn':
|
||||
self.encoder = GCNEncoderConfig(
|
||||
seq_len=conf.seq_len,
|
||||
seq_len_future=conf.seq_len_future,
|
||||
in_channels=conf.in_channels,
|
||||
model_channels=16,
|
||||
out_channels=conf.enc_out_channels,
|
||||
channel_mult=conf.enc_channel_mult or conf.channel_mult,
|
||||
).make_model()
|
||||
elif conf.semantic_encoder_type == '1dcnn':
|
||||
self.encoder = CNNEncoderConfig(
|
||||
seq_len=conf.seq_len,
|
||||
seq_len_future=conf.seq_len_future,
|
||||
in_channels=conf.in_channels,
|
||||
out_channels=conf.enc_out_channels,
|
||||
).make_model()
|
||||
elif conf.semantic_encoder_type == 'gru':
|
||||
# ensure deterministic behavior of RNNs
|
||||
os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:2"
|
||||
self.encoder = GRUEncoderConfig(
|
||||
seq_len=conf.seq_len,
|
||||
seq_len_future=conf.seq_len_future,
|
||||
in_channels=conf.in_channels,
|
||||
out_channels=conf.enc_out_channels,
|
||||
).make_model()
|
||||
elif conf.semantic_encoder_type == 'lstm':
|
||||
# ensure deterministic behavior of RNNs
|
||||
os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:2"
|
||||
self.encoder = LSTMEncoderConfig(
|
||||
seq_len=conf.seq_len,
|
||||
seq_len_future=conf.seq_len_future,
|
||||
in_channels=conf.in_channels,
|
||||
out_channels=conf.enc_out_channels,
|
||||
).make_model()
|
||||
elif conf.semantic_encoder_type == 'mlp':
|
||||
self.encoder = MLPEncoderConfig(
|
||||
seq_len=conf.seq_len,
|
||||
seq_len_future=conf.seq_len_future,
|
||||
in_channels=conf.in_channels,
|
||||
out_channels=conf.enc_out_channels,
|
||||
).make_model()
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
|
||||
"""
|
||||
Reparameterization trick to sample from N(mu, var) using eps ~
|
||||
N(0,1).
|
||||
:param mu: (Tensor) Mean of the latent Gaussian [B x D]
|
||||
:param logvar: (Tensor) Standard deviation of the latent Gaussian [B x D]
|
||||
:return: (Tensor) [B x D]
|
||||
"""
|
||||
assert self.conf.is_stochastic
|
||||
std = torch.exp(0.5 * logvar)
|
||||
eps = torch.randn_like(std)
|
||||
return eps * std + mu
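
    # Editor's note: exp(0.5 * logvar) recovers the standard deviation from the
    # log-variance. E.g. with logvar = -2.0, std = exp(-1.0) ≈ 0.368, so
    #   z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)
    # stays differentiable w.r.t. both mu and logvar.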

    def sample_z(self, n: int, device):
        assert self.conf.is_stochastic
        return torch.randn(n, self.conf.enc_out_channels, device=device)

    def noise_to_cond(self, noise: Tensor):
        raise NotImplementedError()
        # dead code below the raise; left as documentation of the intended behavior
        assert self.conf.noise_net_conf is not None
        return self.noise_net.forward(noise)

    def encode(self, x):
        cond, pred_hand, pred_head = self.encoder.forward(x)
        return cond, pred_hand, pred_head

    @property
    def stylespace_sizes(self):
        modules = list(self.input_blocks.modules()) + list(
            self.middle_block.modules()) + list(self.output_blocks.modules())
        sizes = []
        for module in modules:
            if isinstance(module, ResBlock):
                linear = module.cond_emb_layers[-1]
                sizes.append(linear.weight.shape[0])
        return sizes

    def encode_stylespace(self, x, return_vector: bool = True):
        """
        encode to style space
        """
        modules = list(self.input_blocks.modules()) + list(
            self.middle_block.modules()) + list(self.output_blocks.modules())
        # (n, c)
        cond = self.encoder.forward(x)
        S = []
        for module in modules:
            if isinstance(module, ResBlock):
                # (n, c')
                s = module.cond_emb_layers.forward(cond)
                S.append(s)

        if return_vector:
            # (n, sum_c)
            return torch.cat(S, dim=1)
        else:
            return S

    def forward(self,
                x,
                t,
                x_start=None,
                cond=None,
                style=None,
                noise=None,
                t_cond=None,
                **kwargs):
        """
        Apply the model to an input batch.

        Args:
            x_start: the original input to encode
            cond: output of the encoder
            noise: random noise (to predict the cond)
        """
        if t_cond is None:
            t_cond = t  # randomly sampled timestep of size [batch_size]

        if noise is not None:
            # if the noise is given, we predict the cond from noise
            cond = self.noise_to_cond(noise)

        cond_given = True
        if cond is None:
            cond_given = False
            if x is not None:
                assert len(x) == len(x_start), f'{len(x)} != {len(x_start)}'

            cond, pred_hand, pred_head = self.encode(x_start)

        if t is not None:  # t == t_cond
            _t_emb = timestep_embedding(t, self.conf.model_channels)
            _t_cond_emb = timestep_embedding(t_cond, self.conf.model_channels)
        else:
            # this happens when training only the autoencoder
            _t_emb = None
            _t_cond_emb = None

        if self.conf.resnet_two_cond:
            res = self.time_embed.forward(  # self.time_embed is an MLP
                time_emb=_t_emb,
                cond=cond,
                time_cond_emb=_t_cond_emb,
            )
        else:
            raise NotImplementedError()

        if self.conf.resnet_two_cond:
            # two cond: first = time emb, second = cond emb
            emb = res.time_emb
            cond_emb = res.emb
        else:
            # one cond = combination of both time and cond
            emb = res.emb
            cond_emb = None

        # override the style if given (avoid `or`, which is ambiguous on tensors)
        style = style if style is not None else res.style  # res.style is cond

        # where in the model to supply time conditions
        enc_time_emb = emb  # time embeddings
        mid_time_emb = emb
        dec_time_emb = emb
        # where in the model to supply style conditions
        enc_cond_emb = cond_emb  # z_sem embeddings
        mid_cond_emb = cond_emb
        dec_cond_emb = cond_emb

        hs = [[] for _ in range(len(self.conf.channel_mult))]

        if x is not None:
            h = x.type(self.dtype)
            # input blocks
            k = 0
            for i in range(len(self.input_num_blocks)):
                for j in range(self.input_num_blocks[i]):
                    h = self.input_blocks[k](h,
                                             emb=enc_time_emb,
                                             cond=enc_cond_emb)
                    hs[i].append(h)
                    k += 1
            assert k == len(self.input_blocks)

            # middle blocks
            h = self.middle_block(h, emb=mid_time_emb, cond=mid_cond_emb)
        else:
            # no lateral connections
            # happens when training only the autoencoder
            h = None
            hs = [[] for _ in range(len(self.conf.channel_mult))]
            pdb.set_trace()  # untested path: drop into the debugger if reached

        # output blocks
        k = 0
        for i in range(len(self.output_num_blocks)):
            for j in range(self.output_num_blocks[i]):
                # take the lateral connection from the same layer (in reverse);
                # once there are no more, use None
                try:
                    lateral = hs[-i - 1].pop()  # reverse order (symmetric skip)
                except IndexError:
                    lateral = None
                h = self.output_blocks[k](h,
                                          emb=dec_time_emb,
                                          cond=dec_cond_emb,
                                          lateral=lateral)
                k += 1

        pred = self.out(h)

        if cond_given:
            return AutoencReturn(pred=pred, cond=cond)
        else:
            return AutoencReturn(pred=pred, cond=cond, pred_hand=pred_hand, pred_head=pred_head)
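
    # Editor's usage sketch (shapes are assumptions consistent with the
    # 9-channel hand/head layout used elsewhere in this commit):
    #
    #   x_t = torch.randn(16, 9, 40)          # noised input at timestep t
    #   x_start = torch.randn(16, 9, 40)      # clean input for the encoder
    #   t = torch.randint(0, 1000, (16,))
    #   ret = model(x_t, t, x_start=x_start)  # cond inferred from x_start
    #   # ret.pred: denoising output; ret.pred_hand / ret.pred_head:
    #   # last-frame baselines returned by the semantic encoder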


@dataclass
class GCNAutoencConfig(GCNUNetConfig):
    # number of style channels
    enc_out_channels: int = 256
    enc_channel_mult: Tuple[int] = None

    def make_model(self):
        return GCNAutoencModel(self)


class GCNAutoencModel(GCNUNetModel):
    def __init__(self, conf: GCNAutoencConfig):
        super().__init__(conf)
        self.conf = conf

        # having only time, cond
        self.time_emb_channels = conf.model_channels
        self.time_embed = TimeStyleSeperateEmbed(
            time_channels=self.time_emb_channels,
            time_out_channels=conf.embed_channels,
        )

        self.encoder = GCNEncoderConfig(
            seq_len=conf.seq_len,
            in_channels=conf.in_channels,
            model_channels=32,
            out_channels=conf.enc_out_channels,
            channel_mult=conf.enc_channel_mult or conf.channel_mult,
        ).make_model()

    def encode(self, x):
        cond = self.encoder.forward(x)
        return {'cond': cond}

    def forward(self,
                x,
                t,
                x_start=None,
                cond=None,
                **kwargs):
        """
        Apply the model to an input batch.

        Args:
            x_start: the original input to encode
            cond: output of the encoder
        """
        if cond is None:
            if x is not None:
                assert len(x) == len(x_start), f'{len(x)} != {len(x_start)}'

            tmp = self.encode(x_start)
            cond = tmp['cond']

        if t is not None:
            _t_emb = timestep_embedding(t, self.time_emb_channels)
        else:
            # this happens when training only the autoencoder
            _t_emb = None

        if self.conf.resnet_two_cond:
            res = self.time_embed.forward(  # self.time_embed is an MLP
                time_emb=_t_emb,
                cond=cond,
            )
            # two cond: first = time emb, second = cond emb
            emb = res.time_emb
            cond_emb = res.emb
        else:
            raise NotImplementedError()

        # where in the model to supply time conditions
        enc_time_emb = emb  # time embeddings
        mid_time_emb = emb
        dec_time_emb = emb
        enc_cond_emb = cond_emb  # z_sem embeddings
        mid_cond_emb = cond_emb
        dec_cond_emb = cond_emb

        bs, channels, seq_len = x.shape
        x = x.reshape(bs, self.conf.node_n, self.in_features, seq_len).permute(0, 2, 1, 3)
        hs = [[] for _ in range(len(self.conf.channel_mult))]
        h = x.type(self.dtype)

        # input blocks
        k = 0
        for i in range(len(self.input_num_blocks)):
            for j in range(self.input_num_blocks[i]):
                h = self.input_blocks[k](h,
                                         emb=enc_time_emb,
                                         cond=enc_cond_emb)
                hs[i].append(h)
                k += 1
        assert k == len(self.input_blocks)

        # output blocks
        k = 0
        for i in range(len(self.output_num_blocks)):
            for j in range(self.output_num_blocks[i]):
                # take the lateral connection from the same layer (in reverse);
                # once there are no more, use None
                try:
                    lateral = hs[-i - 1].pop()  # reverse order (symmetric skip)
                except IndexError:
                    lateral = None
                h = self.output_blocks[k](h,
                                         emb=dec_time_emb,
                                         cond=dec_cond_emb,
                                         lateral=lateral)
                k += 1

        pred = self.out(h)
        pred = pred.permute(0, 2, 1, 3).reshape(bs, -1, seq_len)

        return AutoencReturn(pred=pred, cond=cond)
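
    # Editor's shape check for the node reshape above (hypothetical sizes):
    #
    #   bs, node_n, in_features, seq_len = 16, 3, 3, 40
    #   x = torch.randn(bs, node_n * in_features, seq_len)
    #   h = x.reshape(bs, node_n, in_features, seq_len).permute(0, 2, 1, 3)
    #   # h: (bs, in_features, node_n, seq_len); the prediction is mapped back
    #   # with the inverse permute/reshape, an exact round trip.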


class AutoencReturn(NamedTuple):
    pred: Tensor
    cond: Tensor = None
    pred_hand: Tensor = None
    pred_head: Tensor = None


class EmbedReturn(NamedTuple):
    # style and time
    emb: Tensor = None
    # time only
    time_emb: Tensor = None
    # style only (but could depend on time)
    style: Tensor = None


class TimeStyleSeperateEmbed(nn.Module):
    # embed only style
    def __init__(self, time_channels, time_out_channels):
        super().__init__()
        self.time_embed = nn.Sequential(
            linear(time_channels, time_out_channels),
            nn.SiLU(),
            linear(time_out_channels, time_out_channels),
        )
        self.style = nn.Identity()

    def forward(self, time_emb=None, cond=None, **kwargs):
        if time_emb is not None:
            time_emb = self.time_embed(time_emb)
        # else: time_emb stays None (happens in autoenc-only training mode)
        style = self.style(cond)  # style == cond
        return EmbedReturn(emb=style, time_emb=time_emb, style=style)
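
# Editor's sketch of the embedding flow (sizes hypothetical; `linear` comes
# from model/unet.py):
#
#   embed = TimeStyleSeperateEmbed(time_channels=64, time_out_channels=512)
#   res = embed(time_emb=torch.randn(16, 64), cond=torch.randn(16, 512))
#   # res.time_emb: (16, 512) MLP-projected timestep embedding
#   # res.emb is res.style is cond: the style passes through nn.Identity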

193
preprocess.py
Normal file

@@ -0,0 +1,193 @@
import os, random, math, copy
import pandas as pd
import numpy as np
import pickle as pkl
import logging, sys
from torch.utils.data import DataLoader, Dataset
import multiprocessing as mp
import json
import matplotlib.pyplot as plt


def MakeDir(dirpath):
    if not os.path.exists(dirpath):
        os.makedirs(dirpath)


def load_egobody(data_dir, seq_len, sample_rate=1, train=1):
    data_dir_train = data_dir + 'train/'
    data_dir_test = data_dir + 'test/'
    if train == 0:
        data_dirs = [data_dir_test]  # test
    elif train == 1:
        data_dirs = [data_dir_train]  # train
    elif train == 2:
        data_dirs = [data_dir_train, data_dir_test]  # train + test
    else:
        raise ValueError("train must be 0 (test), 1 (train) or 2 (train + test)")

    hand_head = []
    for data_dir in data_dirs:
        file_paths = sorted(os.listdir(data_dir))
        pose_xyz_file_paths = []
        head_file_paths = []
        for path in file_paths:
            path_split = path.split('_')
            data_type = path_split[-1][:-4]
            if data_type == 'xyz':
                pose_xyz_file_paths.append(path)
            if data_type == 'head':
                head_file_paths.append(path)

        file_num = len(pose_xyz_file_paths)
        for i in range(file_num):
            pose_data = np.load(data_dir + pose_xyz_file_paths[i])
            head_data = np.load(data_dir + head_file_paths[i])
            num_frames = pose_data.shape[0]
            if num_frames < seq_len:
                continue

            head_pos = pose_data[:, 15*3:16*3]
            left_hand_pos = pose_data[:, 20*3:21*3]
            right_hand_pos = pose_data[:, 21*3:22*3]
            head_ori = head_data
            left_hand_pos -= head_pos  # convert hand positions to the head coordinate system
            right_hand_pos -= head_pos
            hand_head_data = left_hand_pos
            hand_head_data = np.concatenate((hand_head_data, right_hand_pos), axis=1)
            hand_head_data = np.concatenate((hand_head_data, head_ori), axis=1)

            # build all overlapping windows of length seq_len, then subsample
            fs = np.arange(0, num_frames - seq_len + 1)
            fs_sel = fs
            for j in np.arange(seq_len - 1):
                fs_sel = np.vstack((fs_sel, fs + j + 1))
            fs_sel = fs_sel.transpose()
            seq_sel = hand_head_data[fs_sel, :]
            seq_sel = seq_sel[0::sample_rate, :, :]
            if len(hand_head) == 0:
                hand_head = seq_sel
            else:
                hand_head = np.concatenate((hand_head, seq_sel), axis=0)

    hand_head = np.transpose(hand_head, (0, 2, 1))
    return hand_head
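
# Editor's worked example of the windowing above (num_frames=6, seq_len=3):
#
#   fs = np.arange(0, 6 - 3 + 1)   # [0 1 2 3]
#   # after the vstack/transpose loop, fs_sel is
#   # [[0 1 2]
#   #  [1 2 3]
#   #  [2 3 4]
#   #  [3 4 5]]   -> each row indexes one window of consecutive frames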


def load_adt(data_dir, seq_len, sample_rate=1, train=1):
    data_dir_train = data_dir + 'train/'
    data_dir_test = data_dir + 'test/'
    if train == 0:
        data_dirs = [data_dir_test]  # test
    elif train == 1:
        data_dirs = [data_dir_train]  # train
    elif train == 2:
        data_dirs = [data_dir_train, data_dir_test]  # train + test
    else:
        raise ValueError("train must be 0 (test), 1 (train) or 2 (train + test)")

    hand_head = []
    for data_dir in data_dirs:
        file_paths = sorted(os.listdir(data_dir))
        pose_xyz_file_paths = []
        head_file_paths = []
        for path in file_paths:
            path_split = path.split('_')
            data_type = path_split[-1][:-4]
            if data_type == 'xyz':
                pose_xyz_file_paths.append(path)
            if data_type == 'head':
                head_file_paths.append(path)

        file_num = len(pose_xyz_file_paths)
        for i in range(file_num):
            pose_data = np.load(data_dir + pose_xyz_file_paths[i])
            head_data = np.load(data_dir + head_file_paths[i])
            num_frames = pose_data.shape[0]
            if num_frames < seq_len:
                continue

            head_pos = pose_data[:, 4*3:5*3]
            left_hand_pos = pose_data[:, 8*3:9*3]
            right_hand_pos = pose_data[:, 12*3:13*3]
            head_ori = head_data
            left_hand_pos -= head_pos  # convert hand positions to the head coordinate system
            right_hand_pos -= head_pos
            hand_head_data = left_hand_pos
            hand_head_data = np.concatenate((hand_head_data, right_hand_pos), axis=1)
            hand_head_data = np.concatenate((hand_head_data, head_ori), axis=1)

            # build all overlapping windows of length seq_len, then subsample
            fs = np.arange(0, num_frames - seq_len + 1)
            fs_sel = fs
            for j in np.arange(seq_len - 1):
                fs_sel = np.vstack((fs_sel, fs + j + 1))
            fs_sel = fs_sel.transpose()
            seq_sel = hand_head_data[fs_sel, :]
            seq_sel = seq_sel[0::sample_rate, :, :]
            if len(hand_head) == 0:
                hand_head = seq_sel
            else:
                hand_head = np.concatenate((hand_head, seq_sel), axis=0)

    hand_head = np.transpose(hand_head, (0, 2, 1))
    return hand_head


def load_gimo(data_dir, seq_len, sample_rate=1, train=1):
    data_dir_train = data_dir + 'train/'
    data_dir_test = data_dir + 'test/'
    if train == 0:
        data_dirs = [data_dir_test]  # test
    elif train == 1:
        data_dirs = [data_dir_train]  # train
    elif train == 2:
        data_dirs = [data_dir_train, data_dir_test]  # train + test
    else:
        raise ValueError("train must be 0 (test), 1 (train) or 2 (train + test)")

    hand_head = []
    for data_dir in data_dirs:
        file_paths = sorted(os.listdir(data_dir))
        pose_xyz_file_paths = []
        head_file_paths = []
        for path in file_paths:
            path_split = path.split('_')
            data_type = path_split[-1][:-4]
            if data_type == 'xyz':
                pose_xyz_file_paths.append(path)
            if data_type == 'head':
                head_file_paths.append(path)

        file_num = len(pose_xyz_file_paths)
        for i in range(file_num):
            pose_data = np.load(data_dir + pose_xyz_file_paths[i])
            head_data = np.load(data_dir + head_file_paths[i])
            num_frames = pose_data.shape[0]
            if num_frames < seq_len:
                continue

            head_pos = pose_data[:, 15*3:16*3]
            left_hand_pos = pose_data[:, 20*3:21*3]
            right_hand_pos = pose_data[:, 21*3:22*3]
            head_ori = head_data
            left_hand_pos -= head_pos  # convert hand positions to the head coordinate system
            right_hand_pos -= head_pos
            hand_head_data = left_hand_pos
            hand_head_data = np.concatenate((hand_head_data, right_hand_pos), axis=1)
            hand_head_data = np.concatenate((hand_head_data, head_ori), axis=1)

            # build all overlapping windows of length seq_len, then subsample
            fs = np.arange(0, num_frames - seq_len + 1)
            fs_sel = fs
            for j in np.arange(seq_len - 1):
                fs_sel = np.vstack((fs_sel, fs + j + 1))
            fs_sel = fs_sel.transpose()
            seq_sel = hand_head_data[fs_sel, :]
            seq_sel = seq_sel[0::sample_rate, :, :]
            if len(hand_head) == 0:
                hand_head = seq_sel
            else:
                hand_head = np.concatenate((hand_head, seq_sel), axis=0)

    hand_head = np.transpose(hand_head, (0, 2, 1))
    return hand_head


if __name__ == "__main__":
    data_dir = "/scratch/hu/pose_forecast/egobody_pose2gaze/"
    seq_len = 40

    test_data = load_egobody(data_dir, seq_len, sample_rate=10, train=0)
    print("\ndataset size: {}".format(test_data.shape))
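
# Editor's note: train=1 loads the training split instead; with the 9-channel
# layout built above, the returned array is (num_windows, 9, seq_len), e.g.:
#
#   train_data = load_egobody(data_dir, seq_len, sample_rate=10, train=1)
#   print(train_data.shape)   # (N, 9, 40)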

3
train.sh
Normal file

@@ -0,0 +1,3 @@
#python main.py --gpus 7 --mode 'train' --model_name 'haheae';

python main.py --gpus 7 --mode 'eval' --model_name 'haheae';