update readme
This commit is contained in:
parent
35ee4b75e8
commit
249a01f342
18 changed files with 4936 additions and 0 deletions
BIN
checkpoints/haheae/last.ckpt
Normal file
BIN
checkpoints/haheae/last.ckpt
Normal file
Binary file not shown.
179
choices.py
Normal file
179
choices.py
Normal file
|
@ -0,0 +1,179 @@
|
||||||
|
from enum import Enum
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
class TrainMode(Enum):
|
||||||
|
# manipulate mode = training the classifier
|
||||||
|
manipulate = 'manipulate'
|
||||||
|
# default training mode!
|
||||||
|
diffusion = 'diffusion'
|
||||||
|
# default latent training mode!
|
||||||
|
# fitting a diffusion model to a given latent
|
||||||
|
latent_diffusion = 'latentdiffusion'
|
||||||
|
|
||||||
|
def is_manipulate(self):
|
||||||
|
return self in [
|
||||||
|
TrainMode.manipulate,
|
||||||
|
]
|
||||||
|
|
||||||
|
def is_diffusion(self):
|
||||||
|
return self in [
|
||||||
|
TrainMode.diffusion,
|
||||||
|
TrainMode.latent_diffusion,
|
||||||
|
]
|
||||||
|
|
||||||
|
def is_autoenc(self):
|
||||||
|
# the network possibly does autoencoding
|
||||||
|
return self in [
|
||||||
|
TrainMode.diffusion,
|
||||||
|
]
|
||||||
|
|
||||||
|
def is_latent_diffusion(self):
|
||||||
|
return self in [
|
||||||
|
TrainMode.latent_diffusion,
|
||||||
|
]
|
||||||
|
|
||||||
|
def use_latent_net(self):
|
||||||
|
return self.is_latent_diffusion()
|
||||||
|
|
||||||
|
def require_dataset_infer(self):
|
||||||
|
"""
|
||||||
|
whether training in this mode requires the latent variables to be available?
|
||||||
|
"""
|
||||||
|
# this will precalculate all the latents before hand
|
||||||
|
# and the dataset will be all the predicted latents
|
||||||
|
return self in [
|
||||||
|
TrainMode.latent_diffusion,
|
||||||
|
TrainMode.manipulate,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class ManipulateMode(Enum):
|
||||||
|
"""
|
||||||
|
how to train the classifier to manipulate
|
||||||
|
"""
|
||||||
|
# train on whole celeba attr dataset
|
||||||
|
celebahq_all = 'celebahq_all'
|
||||||
|
# celeba with D2C's crop
|
||||||
|
d2c_fewshot = 'd2cfewshot'
|
||||||
|
d2c_fewshot_allneg = 'd2cfewshotallneg'
|
||||||
|
|
||||||
|
def is_celeba_attr(self):
|
||||||
|
return self in [
|
||||||
|
ManipulateMode.d2c_fewshot,
|
||||||
|
ManipulateMode.d2c_fewshot_allneg,
|
||||||
|
ManipulateMode.celebahq_all,
|
||||||
|
]
|
||||||
|
|
||||||
|
def is_single_class(self):
|
||||||
|
return self in [
|
||||||
|
ManipulateMode.d2c_fewshot,
|
||||||
|
ManipulateMode.d2c_fewshot_allneg,
|
||||||
|
]
|
||||||
|
|
||||||
|
def is_fewshot(self):
|
||||||
|
return self in [
|
||||||
|
ManipulateMode.d2c_fewshot,
|
||||||
|
ManipulateMode.d2c_fewshot_allneg,
|
||||||
|
]
|
||||||
|
|
||||||
|
def is_fewshot_allneg(self):
|
||||||
|
return self in [
|
||||||
|
ManipulateMode.d2c_fewshot_allneg,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class ModelType(Enum):
|
||||||
|
"""
|
||||||
|
Kinds of the backbone models
|
||||||
|
"""
|
||||||
|
|
||||||
|
# unconditional ddpm
|
||||||
|
ddpm = 'ddpm'
|
||||||
|
# autoencoding ddpm cannot do unconditional generation
|
||||||
|
autoencoder = 'autoencoder'
|
||||||
|
|
||||||
|
def has_autoenc(self):
|
||||||
|
return self in [
|
||||||
|
ModelType.autoencoder,
|
||||||
|
]
|
||||||
|
|
||||||
|
def can_sample(self):
|
||||||
|
return self in [ModelType.ddpm]
|
||||||
|
|
||||||
|
|
||||||
|
class ModelName(Enum):
|
||||||
|
"""
|
||||||
|
List of all supported model classes
|
||||||
|
"""
|
||||||
|
|
||||||
|
beatgans_ddpm = 'beatgans_ddpm'
|
||||||
|
beatgans_autoenc = 'beatgans_autoenc'
|
||||||
|
|
||||||
|
|
||||||
|
class ModelMeanType(Enum):
|
||||||
|
"""
|
||||||
|
Which type of output the model predicts.
|
||||||
|
"""
|
||||||
|
|
||||||
|
eps = 'eps' # the model predicts epsilon
|
||||||
|
|
||||||
|
|
||||||
|
class ModelVarType(Enum):
|
||||||
|
"""
|
||||||
|
What is used as the model's output variance.
|
||||||
|
|
||||||
|
The LEARNED_RANGE option has been added to allow the model to predict
|
||||||
|
values between FIXED_SMALL and FIXED_LARGE, making its job easier.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# posterior beta_t
|
||||||
|
fixed_small = 'fixed_small'
|
||||||
|
# beta_t
|
||||||
|
fixed_large = 'fixed_large'
|
||||||
|
|
||||||
|
|
||||||
|
class LossType(Enum):
|
||||||
|
mse = 'mse' # use raw MSE loss (and KL when learning variances)
|
||||||
|
l1 = 'l1'
|
||||||
|
|
||||||
|
|
||||||
|
class GenerativeType(Enum):
|
||||||
|
"""
|
||||||
|
How's a sample generated
|
||||||
|
"""
|
||||||
|
|
||||||
|
ddpm = 'ddpm'
|
||||||
|
ddim = 'ddim'
|
||||||
|
|
||||||
|
|
||||||
|
class OptimizerType(Enum):
|
||||||
|
adam = 'adam'
|
||||||
|
adamw = 'adamw'
|
||||||
|
|
||||||
|
|
||||||
|
class Activation(Enum):
|
||||||
|
none = 'none'
|
||||||
|
relu = 'relu'
|
||||||
|
lrelu = 'lrelu'
|
||||||
|
silu = 'silu'
|
||||||
|
tanh = 'tanh'
|
||||||
|
|
||||||
|
def get_act(self):
|
||||||
|
if self == Activation.none:
|
||||||
|
return nn.Identity()
|
||||||
|
elif self == Activation.relu:
|
||||||
|
return nn.ReLU()
|
||||||
|
elif self == Activation.lrelu:
|
||||||
|
return nn.LeakyReLU(negative_slope=0.2)
|
||||||
|
elif self == Activation.silu:
|
||||||
|
return nn.SiLU()
|
||||||
|
elif self == Activation.tanh:
|
||||||
|
return nn.Tanh()
|
||||||
|
else:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class ManipulateLossType(Enum):
|
||||||
|
bce = 'bce'
|
||||||
|
mse = 'mse'
|
153
config.py
Normal file
153
config.py
Normal file
|
@ -0,0 +1,153 @@
|
||||||
|
from model.blocks import *
|
||||||
|
from diffusion.resample import UniformSampler
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from diffusion.diffusion import space_timesteps
|
||||||
|
from typing import Tuple
|
||||||
|
from config_base import BaseConfig
|
||||||
|
from diffusion import *
|
||||||
|
from diffusion.base import GenerativeType, LossType, ModelMeanType, ModelVarType, get_named_beta_schedule
|
||||||
|
from model import *
|
||||||
|
from choices import *
|
||||||
|
from preprocess import *
|
||||||
|
import os
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TrainConfig(BaseConfig):
|
||||||
|
name: str = ''
|
||||||
|
base_dir: str = './checkpoints/'
|
||||||
|
logdir: str = f'{base_dir}{name}'
|
||||||
|
data_name: str = ''
|
||||||
|
data_val_name: str = ''
|
||||||
|
seq_len: int = 40 # for reconstruction
|
||||||
|
seq_len_future: int = 3 # for prediction
|
||||||
|
in_channels = 9
|
||||||
|
fp16: bool = True
|
||||||
|
lr: float = 1e-4
|
||||||
|
ema_decay: float = 0.9999
|
||||||
|
seed: int = 0 # random seed
|
||||||
|
batch_size: int = 64
|
||||||
|
accum_batches: int = 1
|
||||||
|
batch_size_eval: int = 1024
|
||||||
|
total_epochs: int = 1_000
|
||||||
|
save_every_epochs: int = 10
|
||||||
|
eval_every_epochs: int = 10
|
||||||
|
train_mode: TrainMode = TrainMode.diffusion
|
||||||
|
T: int = 1000
|
||||||
|
T_eval: int = 100
|
||||||
|
diffusion_type: str = 'beatgans'
|
||||||
|
semantic_encoder_type: str = 'gcn'
|
||||||
|
net_beatgans_embed_channels: int = 128
|
||||||
|
beatgans_gen_type: GenerativeType = GenerativeType.ddim
|
||||||
|
beatgans_loss_type: LossType = LossType.mse
|
||||||
|
hand_mse_factor = 1.0
|
||||||
|
head_mse_factor = 1.0
|
||||||
|
beatgans_model_mean_type: ModelMeanType = ModelMeanType.eps
|
||||||
|
beatgans_model_var_type: ModelVarType = ModelVarType.fixed_large
|
||||||
|
beatgans_rescale_timesteps: bool = False
|
||||||
|
beta_scheduler: str = 'linear'
|
||||||
|
net_ch: int = 64
|
||||||
|
net_ch_mult: Tuple[int, ...]= (1, 2, 4)
|
||||||
|
net_enc_channel_mult: Tuple[int] = (1, 2, 4)
|
||||||
|
grad_clip: float = 1
|
||||||
|
optimizer: OptimizerType = OptimizerType.adam
|
||||||
|
weight_decay: float = 0
|
||||||
|
warmup: int = 0
|
||||||
|
model_conf: ModelConfig = None
|
||||||
|
model_name: ModelName = ModelName.beatgans_autoenc
|
||||||
|
model_type: ModelType = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def batch_size_effective(self):
|
||||||
|
return self.batch_size*self.accum_batches
|
||||||
|
|
||||||
|
def _make_diffusion_conf(self, T=None):
|
||||||
|
if self.diffusion_type == 'beatgans':
|
||||||
|
# can use T < self.T for evaluation
|
||||||
|
# follows the guided-diffusion repo conventions
|
||||||
|
# t's are evenly spaced
|
||||||
|
if self.beatgans_gen_type == GenerativeType.ddpm:
|
||||||
|
section_counts = [T]
|
||||||
|
elif self.beatgans_gen_type == GenerativeType.ddim:
|
||||||
|
section_counts = f'ddim{T}'
|
||||||
|
else:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
return SpacedDiffusionBeatGansConfig(
|
||||||
|
gen_type=self.beatgans_gen_type,
|
||||||
|
model_type=self.model_type,
|
||||||
|
betas=get_named_beta_schedule(self.beta_scheduler, T),
|
||||||
|
model_mean_type=self.beatgans_model_mean_type,
|
||||||
|
model_var_type=self.beatgans_model_var_type,
|
||||||
|
loss_type=self.beatgans_loss_type,
|
||||||
|
rescale_timesteps=self.beatgans_rescale_timesteps,
|
||||||
|
use_timesteps=space_timesteps(num_timesteps=T, section_counts=section_counts),
|
||||||
|
fp16=self.fp16,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_out_channels(self):
|
||||||
|
return self.in_channels
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_input_channels(self):
|
||||||
|
return self.in_channels
|
||||||
|
|
||||||
|
def make_T_sampler(self):
|
||||||
|
return UniformSampler(self.T)
|
||||||
|
|
||||||
|
def make_diffusion_conf(self):
|
||||||
|
return self._make_diffusion_conf(self.T)
|
||||||
|
|
||||||
|
def make_eval_diffusion_conf(self):
|
||||||
|
return self._make_diffusion_conf(T=self.T_eval)
|
||||||
|
|
||||||
|
def make_model_conf(self):
|
||||||
|
cls = BeatGANsAutoencConfig
|
||||||
|
if self.model_name == ModelName.beatgans_autoenc:
|
||||||
|
self.model_type = ModelType.autoencoder
|
||||||
|
else:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
self.model_conf = cls(
|
||||||
|
semantic_encoder_type = self.semantic_encoder_type,
|
||||||
|
channel_mult=self.net_ch_mult,
|
||||||
|
seq_len = self.seq_len,
|
||||||
|
seq_len_future = self.seq_len_future,
|
||||||
|
embed_channels=self.net_beatgans_embed_channels,
|
||||||
|
enc_out_channels=self.net_beatgans_embed_channels,
|
||||||
|
enc_channel_mult=self.net_enc_channel_mult,
|
||||||
|
in_channels=self.model_input_channels,
|
||||||
|
model_channels=self.net_ch,
|
||||||
|
out_channels=self.model_out_channels,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.model_conf
|
||||||
|
|
||||||
|
def egobody_autoenc(mode, encoder_type='gcn', hand_mse_factor=1.0, head_mse_factor=1.0, data_sample_rate=1, epoch=130,in_channels=9, seq_len=40):
|
||||||
|
conf = TrainConfig()
|
||||||
|
conf.seq_len = seq_len
|
||||||
|
conf.seq_len_future = 3
|
||||||
|
conf.in_channels = in_channels
|
||||||
|
conf.net_beatgans_embed_channels = 128
|
||||||
|
conf.net_ch = 64
|
||||||
|
conf.net_ch_mult = (1, 1, 1)
|
||||||
|
conf.semantic_encoder_type = encoder_type
|
||||||
|
conf.hand_mse_factor = hand_mse_factor
|
||||||
|
conf.head_mse_factor = head_mse_factor
|
||||||
|
conf.net_enc_channel_mult = conf.net_ch_mult
|
||||||
|
conf.total_epochs = epoch
|
||||||
|
conf.save_every_epochs = 10
|
||||||
|
conf.eval_every_epochs = 10
|
||||||
|
conf.batch_size = 64
|
||||||
|
conf.batch_size_eval = 1024*4
|
||||||
|
|
||||||
|
conf.data_dir = "/scratch/hu/pose_forecast/egobody_pose2gaze/"
|
||||||
|
conf.data_sample_rate = data_sample_rate
|
||||||
|
conf.name = 'egobody_autoenc'
|
||||||
|
conf.data_name = 'egobody'
|
||||||
|
|
||||||
|
conf.mode = mode
|
||||||
|
conf.make_model_conf()
|
||||||
|
return conf
|
72
config_base.py
Normal file
72
config_base.py
Normal file
|
@ -0,0 +1,72 @@
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from copy import deepcopy
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BaseConfig:
|
||||||
|
def clone(self):
|
||||||
|
return deepcopy(self)
|
||||||
|
|
||||||
|
def inherit(self, another):
|
||||||
|
"""inherit common keys from a given config"""
|
||||||
|
common_keys = set(self.__dict__.keys()) & set(another.__dict__.keys())
|
||||||
|
for k in common_keys:
|
||||||
|
setattr(self, k, getattr(another, k))
|
||||||
|
|
||||||
|
def propagate(self):
|
||||||
|
"""push down the configuration to all members"""
|
||||||
|
for k, v in self.__dict__.items():
|
||||||
|
if isinstance(v, BaseConfig):
|
||||||
|
v.inherit(self)
|
||||||
|
v.propagate()
|
||||||
|
|
||||||
|
def save(self, save_path):
|
||||||
|
"""save config to json file"""
|
||||||
|
dirname = os.path.dirname(save_path)
|
||||||
|
if not os.path.exists(dirname):
|
||||||
|
os.makedirs(dirname)
|
||||||
|
conf = self.as_dict_jsonable()
|
||||||
|
with open(save_path, 'w') as f:
|
||||||
|
json.dump(conf, f)
|
||||||
|
|
||||||
|
def load(self, load_path):
|
||||||
|
"""load json config"""
|
||||||
|
with open(load_path) as f:
|
||||||
|
conf = json.load(f)
|
||||||
|
self.from_dict(conf)
|
||||||
|
|
||||||
|
def from_dict(self, dict, strict=False):
|
||||||
|
for k, v in dict.items():
|
||||||
|
if not hasattr(self, k):
|
||||||
|
if strict:
|
||||||
|
raise ValueError(f"loading extra '{k}'")
|
||||||
|
else:
|
||||||
|
print(f"loading extra '{k}'")
|
||||||
|
continue
|
||||||
|
if isinstance(self.__dict__[k], BaseConfig):
|
||||||
|
self.__dict__[k].from_dict(v)
|
||||||
|
else:
|
||||||
|
self.__dict__[k] = v
|
||||||
|
|
||||||
|
def as_dict_jsonable(self):
|
||||||
|
conf = {}
|
||||||
|
for k, v in self.__dict__.items():
|
||||||
|
if isinstance(v, BaseConfig):
|
||||||
|
conf[k] = v.as_dict_jsonable()
|
||||||
|
else:
|
||||||
|
if jsonable(v):
|
||||||
|
conf[k] = v
|
||||||
|
else:
|
||||||
|
# ignore not jsonable
|
||||||
|
pass
|
||||||
|
return conf
|
||||||
|
|
||||||
|
|
||||||
|
def jsonable(x):
|
||||||
|
try:
|
||||||
|
json.dumps(x)
|
||||||
|
return True
|
||||||
|
except TypeError:
|
||||||
|
return False
|
6
diffusion/__init__.py
Normal file
6
diffusion/__init__.py
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from .diffusion import SpacedDiffusionBeatGans, SpacedDiffusionBeatGansConfig
|
||||||
|
|
||||||
|
Sampler = Union[SpacedDiffusionBeatGans]
|
||||||
|
SamplerConfig = Union[SpacedDiffusionBeatGansConfig]
|
1148
diffusion/base.py
Normal file
1148
diffusion/base.py
Normal file
File diff suppressed because it is too large
Load diff
182
diffusion/diffusion.py
Normal file
182
diffusion/diffusion.py
Normal file
|
@ -0,0 +1,182 @@
|
||||||
|
from .base import *
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
def space_timesteps(num_timesteps, section_counts):
|
||||||
|
"""
|
||||||
|
Create a list of timesteps to use from an original diffusion process,
|
||||||
|
given the number of timesteps we want to take from equally-sized portions
|
||||||
|
of the original process.
|
||||||
|
|
||||||
|
For example, if there's 300 timesteps and the section counts are [10,15,20]
|
||||||
|
then the first 100 timesteps are strided to be 10 timesteps, the second 100
|
||||||
|
are strided to be 15 timesteps, and the final 100 are strided to be 20.
|
||||||
|
|
||||||
|
If the stride is a string starting with "ddim", then the fixed striding
|
||||||
|
from the DDIM paper is used, and only one section is allowed.
|
||||||
|
|
||||||
|
:param num_timesteps: the number of diffusion steps in the original
|
||||||
|
process to divide up.
|
||||||
|
:param section_counts: either a list of numbers, or a string containing
|
||||||
|
comma-separated numbers, indicating the step count
|
||||||
|
per section. As a special case, use "ddimN" where N
|
||||||
|
is a number of steps to use the striding from the
|
||||||
|
DDIM paper.
|
||||||
|
:return: a set of diffusion steps from the original process to use.
|
||||||
|
"""
|
||||||
|
if isinstance(section_counts, str):
|
||||||
|
if section_counts.startswith("ddim"):
|
||||||
|
desired_count = int(section_counts[len("ddim"):])
|
||||||
|
for i in range(1, num_timesteps):
|
||||||
|
if len(range(0, num_timesteps, i)) == desired_count:
|
||||||
|
return set(range(0, num_timesteps, i))
|
||||||
|
raise ValueError(
|
||||||
|
f"cannot create exactly {num_timesteps} steps with an integer stride"
|
||||||
|
)
|
||||||
|
section_counts = [int(x) for x in section_counts.split(",")]
|
||||||
|
size_per = num_timesteps // len(section_counts)
|
||||||
|
extra = num_timesteps % len(section_counts)
|
||||||
|
start_idx = 0
|
||||||
|
all_steps = []
|
||||||
|
for i, section_count in enumerate(section_counts):
|
||||||
|
size = size_per + (1 if i < extra else 0)
|
||||||
|
if size < section_count:
|
||||||
|
raise ValueError(
|
||||||
|
f"cannot divide section of {size} steps into {section_count}")
|
||||||
|
if section_count <= 1:
|
||||||
|
frac_stride = 1
|
||||||
|
else:
|
||||||
|
frac_stride = (size - 1) / (section_count - 1)
|
||||||
|
cur_idx = 0.0
|
||||||
|
taken_steps = []
|
||||||
|
for _ in range(section_count):
|
||||||
|
taken_steps.append(start_idx + round(cur_idx))
|
||||||
|
cur_idx += frac_stride
|
||||||
|
all_steps += taken_steps
|
||||||
|
start_idx += size
|
||||||
|
return set(all_steps)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SpacedDiffusionBeatGansConfig(GaussianDiffusionBeatGansConfig):
|
||||||
|
use_timesteps: Tuple[int] = None
|
||||||
|
|
||||||
|
def make_sampler(self):
|
||||||
|
return SpacedDiffusionBeatGans(self)
|
||||||
|
|
||||||
|
|
||||||
|
class SpacedDiffusionBeatGans(GaussianDiffusionBeatGans):
|
||||||
|
"""
|
||||||
|
A diffusion process which can skip steps in a base diffusion process.
|
||||||
|
|
||||||
|
:param use_timesteps: a collection (sequence or set) of timesteps from the
|
||||||
|
original diffusion process to retain.
|
||||||
|
:param kwargs: the kwargs to create the base diffusion process.
|
||||||
|
"""
|
||||||
|
def __init__(self, conf: SpacedDiffusionBeatGansConfig):
|
||||||
|
self.conf = conf
|
||||||
|
self.use_timesteps = set(conf.use_timesteps)
|
||||||
|
# how the new t's mapped to the old t's
|
||||||
|
self.timestep_map = []
|
||||||
|
self.original_num_steps = len(conf.betas)
|
||||||
|
|
||||||
|
base_diffusion = GaussianDiffusionBeatGans(conf) # pylint: disable=missing-kwoa
|
||||||
|
last_alpha_cumprod = 1.0
|
||||||
|
new_betas = []
|
||||||
|
for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
|
||||||
|
if i in self.use_timesteps:
|
||||||
|
# getting the new betas of the new timesteps
|
||||||
|
new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
|
||||||
|
last_alpha_cumprod = alpha_cumprod
|
||||||
|
self.timestep_map.append(i)
|
||||||
|
conf.betas = np.array(new_betas)
|
||||||
|
super().__init__(conf)
|
||||||
|
|
||||||
|
def p_mean_variance(self, model: Model, *args, **kwargs): # pylint: disable=signature-differs
|
||||||
|
return super().p_mean_variance(self._wrap_model(model), *args,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
def training_losses(self, model: Model, *args, **kwargs): # pylint: disable=signature-differs
|
||||||
|
return super().training_losses(self._wrap_model(model), *args,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
def condition_mean(self, cond_fn, *args, **kwargs):
|
||||||
|
return super().condition_mean(self._wrap_model(cond_fn), *args,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
def condition_score(self, cond_fn, *args, **kwargs):
|
||||||
|
return super().condition_score(self._wrap_model(cond_fn), *args,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
def _wrap_model(self, model: Model):
|
||||||
|
if isinstance(model, _WrappedModel):
|
||||||
|
return model
|
||||||
|
return _WrappedModel(model, self.timestep_map, self.rescale_timesteps,
|
||||||
|
self.original_num_steps)
|
||||||
|
|
||||||
|
def _scale_timesteps(self, t):
|
||||||
|
# Scaling is done by the wrapped model.
|
||||||
|
return t
|
||||||
|
|
||||||
|
|
||||||
|
class _WrappedModel:
|
||||||
|
"""
|
||||||
|
converting the supplied t's to the old t's scales.
|
||||||
|
"""
|
||||||
|
def __init__(self, model, timestep_map, rescale_timesteps,
|
||||||
|
original_num_steps):
|
||||||
|
self.model = model
|
||||||
|
self.timestep_map = timestep_map
|
||||||
|
self.rescale_timesteps = rescale_timesteps
|
||||||
|
self.original_num_steps = original_num_steps
|
||||||
|
|
||||||
|
def forward(self, x, t, t_cond=None, **kwargs):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
t: t's with differrent ranges (can be << T due to smaller eval T) need to be converted to the original t's
|
||||||
|
t_cond: the same as t but can be of different values
|
||||||
|
"""
|
||||||
|
map_tensor = th.tensor(self.timestep_map,
|
||||||
|
device=t.device,
|
||||||
|
dtype=t.dtype)
|
||||||
|
|
||||||
|
def do(t):
|
||||||
|
new_ts = map_tensor[t]
|
||||||
|
if self.rescale_timesteps:
|
||||||
|
new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
|
||||||
|
return new_ts
|
||||||
|
|
||||||
|
if t_cond is not None:
|
||||||
|
# support t_cond
|
||||||
|
t_cond = do(t_cond)
|
||||||
|
## run_ffhq256.py None
|
||||||
|
return self.model(x=x, t=do(t), t_cond=t_cond, **kwargs) ## self.model.__class__==model.unet_autoenc.BeatGANsAutoencModel
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
# allow for calling the model's methods
|
||||||
|
if hasattr(self.model, name):
|
||||||
|
func = getattr(self.model, name)
|
||||||
|
return func
|
||||||
|
raise AttributeError(name)
|
||||||
|
|
||||||
|
# def __call__(self, x, t, t_cond=None, **kwargs):
|
||||||
|
# """
|
||||||
|
# Args:
|
||||||
|
# t: t's with differrent ranges (can be << T due to smaller eval T) need to be converted to the original t's
|
||||||
|
# t_cond: the same as t but can be of different values
|
||||||
|
# """
|
||||||
|
# map_tensor = th.tensor(self.timestep_map,
|
||||||
|
# device=t.device,
|
||||||
|
# dtype=t.dtype)
|
||||||
|
|
||||||
|
# def do(t):
|
||||||
|
# new_ts = map_tensor[t]
|
||||||
|
# if self.rescale_timesteps:
|
||||||
|
# new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
|
||||||
|
# return new_ts
|
||||||
|
|
||||||
|
# if t_cond is not None:
|
||||||
|
# # support t_cond
|
||||||
|
# t_cond = do(t_cond)
|
||||||
|
# ## run_ffhq256.py None
|
||||||
|
# return self.model(x=x, t=do(t), t_cond=t_cond, **kwargs)
|
63
diffusion/resample.py
Normal file
63
diffusion/resample.py
Normal file
|
@ -0,0 +1,63 @@
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch as th
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
|
||||||
|
def create_named_schedule_sampler(name, diffusion):
|
||||||
|
"""
|
||||||
|
Create a ScheduleSampler from a library of pre-defined samplers.
|
||||||
|
|
||||||
|
:param name: the name of the sampler.
|
||||||
|
:param diffusion: the diffusion object to sample for.
|
||||||
|
"""
|
||||||
|
if name == "uniform":
|
||||||
|
return UniformSampler(diffusion)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"unknown schedule sampler: {name}")
|
||||||
|
|
||||||
|
|
||||||
|
class ScheduleSampler(ABC):
|
||||||
|
"""
|
||||||
|
A distribution over timesteps in the diffusion process, intended to reduce
|
||||||
|
variance of the objective.
|
||||||
|
|
||||||
|
By default, samplers perform unbiased importance sampling, in which the
|
||||||
|
objective's mean is unchanged.
|
||||||
|
However, subclasses may override sample() to change how the resampled
|
||||||
|
terms are reweighted, allowing for actual changes in the objective.
|
||||||
|
"""
|
||||||
|
@abstractmethod
|
||||||
|
def weights(self):
|
||||||
|
"""
|
||||||
|
Get a numpy array of weights, one per diffusion step.
|
||||||
|
|
||||||
|
The weights needn't be normalized, but must be positive.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def sample(self, batch_size, device):
|
||||||
|
"""
|
||||||
|
Importance-sample timesteps for a batch.
|
||||||
|
|
||||||
|
:param batch_size: the number of timesteps.
|
||||||
|
:param device: the torch device to save to.
|
||||||
|
:return: a tuple (timesteps, weights):
|
||||||
|
- timesteps: a tensor of timestep indices.
|
||||||
|
- weights: a tensor of weights to scale the resulting losses.
|
||||||
|
"""
|
||||||
|
w = self.weights()
|
||||||
|
p = w / np.sum(w)
|
||||||
|
indices_np = np.random.choice(len(p), size=(batch_size, ), p=p)
|
||||||
|
indices = th.from_numpy(indices_np).long().to(device)
|
||||||
|
weights_np = 1 / (len(p) * p[indices_np])
|
||||||
|
weights = th.from_numpy(weights_np).float().to(device)
|
||||||
|
return indices, weights
|
||||||
|
|
||||||
|
|
||||||
|
class UniformSampler(ScheduleSampler):
|
||||||
|
def __init__(self, num_timesteps): ## all steps are 1
|
||||||
|
self._weights = np.ones([num_timesteps])
|
||||||
|
|
||||||
|
def weights(self):
|
||||||
|
return self._weights
|
101
environment/haheae.yml
Normal file
101
environment/haheae.yml
Normal file
|
@ -0,0 +1,101 @@
|
||||||
|
name: haheae
|
||||||
|
channels:
|
||||||
|
- defaults
|
||||||
|
dependencies:
|
||||||
|
- _libgcc_mutex=0.1=main
|
||||||
|
- _openmp_mutex=5.1=1_gnu
|
||||||
|
- ca-certificates=2023.12.12=h06a4308_0
|
||||||
|
- ld_impl_linux-64=2.38=h1181459_1
|
||||||
|
- libffi=3.4.4=h6a678d5_0
|
||||||
|
- libgcc-ng=11.2.0=h1234567_1
|
||||||
|
- libgomp=11.2.0=h1234567_1
|
||||||
|
- libstdcxx-ng=11.2.0=h1234567_1
|
||||||
|
- ncurses=6.4=h6a678d5_0
|
||||||
|
- openssl=3.0.12=h7f8727e_0
|
||||||
|
- pip=23.3.1=py38h06a4308_0
|
||||||
|
- python=3.8.18=h955ad1f_0
|
||||||
|
- readline=8.2=h5eee18b_0
|
||||||
|
- setuptools=68.2.2=py38h06a4308_0
|
||||||
|
- sqlite=3.41.2=h5eee18b_0
|
||||||
|
- tk=8.6.12=h1ccaba5_0
|
||||||
|
- wheel=0.41.2=py38h06a4308_0
|
||||||
|
- xz=5.4.5=h5eee18b_0
|
||||||
|
- zlib=1.2.13=h5eee18b_0
|
||||||
|
- pip:
|
||||||
|
- absl-py==2.0.0
|
||||||
|
- aiohttp==3.9.1
|
||||||
|
- aiosignal==1.3.1
|
||||||
|
- appdirs==1.4.4
|
||||||
|
- async-timeout==4.0.3
|
||||||
|
- attrs==23.2.0
|
||||||
|
- cachetools==5.3.2
|
||||||
|
- certifi==2023.11.17
|
||||||
|
- charset-normalizer==3.3.2
|
||||||
|
- click==8.1.7
|
||||||
|
- contourpy==1.1.1
|
||||||
|
- cycler==0.12.1
|
||||||
|
- cython==0.29.37
|
||||||
|
- docker-pycreds==0.4.0
|
||||||
|
- fonttools==4.47.2
|
||||||
|
- frozenlist==1.4.1
|
||||||
|
- fsspec==2023.12.2
|
||||||
|
- ftfy==6.1.3
|
||||||
|
- future==0.18.3
|
||||||
|
- gitdb==4.0.11
|
||||||
|
- gitpython==3.1.41
|
||||||
|
- google-auth==2.26.2
|
||||||
|
- google-auth-oauthlib==1.0.0
|
||||||
|
- grpcio==1.60.0
|
||||||
|
- hdbscan==0.8.33
|
||||||
|
- idna==3.6
|
||||||
|
- importlib-metadata==7.0.1
|
||||||
|
- importlib-resources==6.1.1
|
||||||
|
- joblib==1.3.2
|
||||||
|
- kiwisolver==1.4.5
|
||||||
|
- lmdb==1.2.1
|
||||||
|
- lpips==0.1.4
|
||||||
|
- markdown==3.5.2
|
||||||
|
- markupsafe==2.1.3
|
||||||
|
- matplotlib==3.5.3
|
||||||
|
- multidict==6.0.4
|
||||||
|
- numpy==1.24.4
|
||||||
|
- oauthlib==3.2.2
|
||||||
|
- packaging==23.2
|
||||||
|
- pandas==1.5.3
|
||||||
|
- pillow==10.2.0
|
||||||
|
- protobuf==4.25.2
|
||||||
|
- psutil==5.9.8
|
||||||
|
- pyasn1==0.5.1
|
||||||
|
- pyasn1-modules==0.3.0
|
||||||
|
- pydeprecate==0.3.1
|
||||||
|
- pyparsing==3.1.1
|
||||||
|
- python-dateutil==2.8.2
|
||||||
|
- pytorch-fid==0.2.0
|
||||||
|
- pytorch-lightning==1.4.5
|
||||||
|
- pytz==2023.3.post1
|
||||||
|
- pyyaml==6.0.1
|
||||||
|
- regex==2023.12.25
|
||||||
|
- requests==2.31.0
|
||||||
|
- requests-oauthlib==1.3.1
|
||||||
|
- rsa==4.9
|
||||||
|
- scikit-learn==1.3.2
|
||||||
|
- scipy==1.5.4
|
||||||
|
- sentry-sdk==1.39.2
|
||||||
|
- setproctitle==1.3.3
|
||||||
|
- six==1.16.0
|
||||||
|
- smmap==5.0.1
|
||||||
|
- tensorboard==2.14.0
|
||||||
|
- tensorboard-data-server==0.7.2
|
||||||
|
- threadpoolctl==3.2.0
|
||||||
|
- torch==1.8.1
|
||||||
|
- torchmetrics==0.5.0
|
||||||
|
- torchvision==0.9.1
|
||||||
|
- tqdm==4.66.1
|
||||||
|
- typing-extensions==4.9.0
|
||||||
|
- tzdata==2023.4
|
||||||
|
- urllib3==2.1.0
|
||||||
|
- wandb==0.16.2
|
||||||
|
- wcwidth==0.2.13
|
||||||
|
- werkzeug==3.0.1
|
||||||
|
- yarl==1.9.4
|
||||||
|
- zipp==3.17.0
|
565
main.py
Normal file
565
main.py
Normal file
|
@ -0,0 +1,565 @@
|
||||||
|
import warnings
|
||||||
|
warnings.filterwarnings("ignore")
|
||||||
|
import os
|
||||||
|
os.nice(5)
|
||||||
|
import copy, wandb
|
||||||
|
from tqdm import tqdm, trange
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
import random
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
import torch
|
||||||
|
from pytorch_lightning import loggers as pl_loggers
|
||||||
|
from pytorch_lightning.callbacks import *
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.cuda import amp
|
||||||
|
from torch.optim.optimizer import Optimizer
|
||||||
|
from config import *
|
||||||
|
import time
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
class LitModel(pl.LightningModule):
|
||||||
|
def __init__(self, conf: TrainConfig):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
## wandb
|
||||||
|
self.save_hyperparameters({k:v for (k,v) in vars(conf).items() if not callable(v)})
|
||||||
|
|
||||||
|
if conf.seed is not None:
|
||||||
|
pl.seed_everything(conf.seed)
|
||||||
|
|
||||||
|
self.save_hyperparameters(conf.as_dict_jsonable())
|
||||||
|
self.conf = conf
|
||||||
|
self.model = conf.make_model_conf().make_model()
|
||||||
|
|
||||||
|
self.ema_model = copy.deepcopy(self.model)
|
||||||
|
self.ema_model.requires_grad_(False)
|
||||||
|
self.ema_model.eval()
|
||||||
|
|
||||||
|
model_size = 0
|
||||||
|
for param in self.model.parameters():
|
||||||
|
model_size += param.data.nelement()
|
||||||
|
print('Model params: %.3f M' % (model_size / 1024 / 1024))
|
||||||
|
|
||||||
|
self.sampler = conf.make_diffusion_conf().make_sampler()
|
||||||
|
self.eval_sampler = conf.make_eval_diffusion_conf().make_sampler()
|
||||||
|
self.T_sampler = conf.make_T_sampler()
|
||||||
|
self.save_every_epochs = conf.save_every_epochs
|
||||||
|
self.eval_every_epochs = conf.eval_every_epochs
|
||||||
|
|
||||||
|
def setup(self, stage=None) -> None:
|
||||||
|
"""
|
||||||
|
make datasets & seeding each worker separately
|
||||||
|
"""
|
||||||
|
##############################################
|
||||||
|
# NEED TO SET THE SEED SEPARATELY HERE
|
||||||
|
if self.conf.seed is not None:
|
||||||
|
seed = self.conf.seed
|
||||||
|
np.random.seed(seed)
|
||||||
|
random.seed(seed) # Python random module.
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
torch.backends.cudnn.benchmark = False
|
||||||
|
torch.backends.cudnn.deterministic = True
|
||||||
|
print('seed:', seed)
|
||||||
|
##############################################
|
||||||
|
|
||||||
|
## Load dataset
|
||||||
|
if self.conf.mode == 'train':
|
||||||
|
self.train_data = load_egobody(self.conf.data_dir, self.conf.seq_len+self.conf.seq_len_future, self.conf.data_sample_rate, train=1)
|
||||||
|
self.val_data = load_egobody(self.conf.data_dir, self.conf.seq_len+self.conf.seq_len_future, self.conf.data_sample_rate, train=0)
|
||||||
|
if self.conf.in_channels == 6: # hand only
|
||||||
|
self.train_data = self.train_data[:, :6, :]
|
||||||
|
self.val_data = self.val_data[:, :6, :]
|
||||||
|
if self.conf.in_channels == 3: # head only
|
||||||
|
self.train_data = self.train_data[:, 6:, :]
|
||||||
|
self.val_data = self.val_data[:, 6:, :]
|
||||||
|
|
||||||
|
def encode(self, x):
|
||||||
|
assert self.conf.model_type.has_autoenc()
|
||||||
|
cond, pred_hand, pred_head = self.ema_model.encoder.forward(x)
|
||||||
|
return cond, pred_hand, pred_head
|
||||||
|
|
||||||
|
def encode_stochastic(self, x, cond, T=None):
|
||||||
|
if T is None:
|
||||||
|
sampler = self.eval_sampler
|
||||||
|
else:
|
||||||
|
sampler = self.conf._make_diffusion_conf(T).make_sampler() # get noise at step T
|
||||||
|
|
||||||
|
## x_0 -> x-T using reverse of inference
|
||||||
|
out = sampler.ddim_reverse_sample_loop(self.ema_model, x, model_kwargs={'cond': cond})
|
||||||
|
''' 'sample': x_T
|
||||||
|
'sample_t': x_t, t in (1, ..., T)
|
||||||
|
'xstart_t': predicted x_0 at each timestep. "xstart here is a bit different from sampling from T = T-1 to T = 0"
|
||||||
|
'T': (1, ..., T)
|
||||||
|
'''
|
||||||
|
return out['sample']
|
||||||
|
|
||||||
|
def train_dataloader(self):
|
||||||
|
return torch.utils.data.DataLoader(self.train_data, batch_size=self.conf.batch_size, shuffle=True)
|
||||||
|
|
||||||
|
def val_dataloader(self):
|
||||||
|
return torch.utils.data.DataLoader(self.val_data, batch_size=self.conf.batch_size_eval, shuffle=False)
|
||||||
|
|
||||||
|
def is_last_accum(self, batch_idx):
|
||||||
|
"""
|
||||||
|
is it the last gradient accumulation loop?
|
||||||
|
used with gradient_accum > 1 and to see if the optimizer will perform "step" in this iteration or not
|
||||||
|
"""
|
||||||
|
return (batch_idx + 1) % self.conf.accum_batches == 0
|
||||||
|
|
||||||
|
def training_step(self, batch, batch_idx):
|
||||||
|
"""
|
||||||
|
given an input, calculate the loss function
|
||||||
|
no optimization at this stage.
|
||||||
|
"""
|
||||||
|
with amp.autocast(False):
|
||||||
|
x_start = batch[:, :, :self.conf.seq_len]
|
||||||
|
x_future = batch[:, :, self.conf.seq_len:]
|
||||||
|
|
||||||
|
if self.conf.train_mode == TrainMode.diffusion:
|
||||||
|
"""
|
||||||
|
main training mode!!!
|
||||||
|
"""
|
||||||
|
t, weight = self.T_sampler.sample(len(x_start), x_start.device)
|
||||||
|
''' self.T_sampler: diffusion.resample.UniformSampler (weights for all timesteps are 1)
|
||||||
|
- t: a tensor of timestep indices.
|
||||||
|
- weight: a tensor of weights to scale the resulting losses.
|
||||||
|
## num_timesteps is self.conf.T == 1000
|
||||||
|
'''
|
||||||
|
losses = self.sampler.training_losses(model=self.model,
|
||||||
|
x_start=x_start,
|
||||||
|
t=t,
|
||||||
|
x_future=x_future,
|
||||||
|
hand_mse_factor = self.conf.hand_mse_factor,
|
||||||
|
head_mse_factor = self.conf.head_mse_factor,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
loss = losses['loss'].mean() ## average loss across mini-batches
|
||||||
|
#noise_mse = losses['mse'].mean()
|
||||||
|
#hand_mse = losses['hand_mse'].mean()
|
||||||
|
#head_mse = losses['head_mse'].mean()
|
||||||
|
|
||||||
|
## Log loss and metric (wandb)
|
||||||
|
self.log("train_loss", loss, on_epoch=True, prog_bar=True)
|
||||||
|
#self.log("train_noise_mse", noise_mse, on_epoch=True, prog_bar=True)
|
||||||
|
#self.log("train_hand_mse", hand_mse, on_epoch=True, prog_bar=True)
|
||||||
|
#self.log("train_head_mse", head_mse, on_epoch=True, prog_bar=True)
|
||||||
|
return {'loss': loss}
|
||||||
|
|
||||||
|
def validation_step(self, batch, batch_idx):
|
||||||
|
if self.conf.in_channels == 9: # both hand and head
|
||||||
|
if((self.current_epoch+1)% self.eval_every_epochs == 0):
|
||||||
|
batch_future = batch[:, :, self.conf.seq_len:]
|
||||||
|
gt_hand_future = batch_future[:, :6, :]
|
||||||
|
gt_head_future = batch_future[:, 6:, :]
|
||||||
|
batch = batch[:, :, :self.conf.seq_len]
|
||||||
|
cond, pred_hand_future, pred_head_future = self.encode(batch)
|
||||||
|
xT = self.encode_stochastic(batch, cond)
|
||||||
|
pred_xstart = self.generate(xT, cond)
|
||||||
|
|
||||||
|
# hand reconstruction error
|
||||||
|
gt_hand = batch[:, :6, :]
|
||||||
|
pred_hand = pred_xstart[:, :6, :]
|
||||||
|
bs, channels, seq_len = gt_hand.shape
|
||||||
|
gt_hand = gt_hand.reshape(bs, 2, 3, seq_len)
|
||||||
|
pred_hand = pred_hand.reshape(bs, 2, 3, seq_len)
|
||||||
|
hand_traj = torch.mean(torch.norm(gt_hand - pred_hand, dim=2))
|
||||||
|
|
||||||
|
# hand prediction error
|
||||||
|
bs, channels, seq_len = gt_hand_future.shape
|
||||||
|
gt_hand_future = gt_hand_future.reshape(bs, 2, 3, seq_len)
|
||||||
|
pred_hand_future = pred_hand_future.reshape(bs, 2, 3, seq_len)
|
||||||
|
baseline_hand_future = gt_hand[:, :, :, -1:].expand(-1, -1, -1, self.conf.seq_len_future).clone()
|
||||||
|
hand_traj_future = torch.mean(torch.norm(gt_hand_future - pred_hand_future, dim=2))
|
||||||
|
hand_traj_future_baseline = torch.mean(torch.norm(gt_hand_future - baseline_hand_future, dim=2))
|
||||||
|
|
||||||
|
# head reconstruction error
|
||||||
|
gt_head = batch[:, 6:, :]
|
||||||
|
gt_head = F.normalize(gt_head, dim=1) # normalize head orientation to unit vectors
|
||||||
|
pred_head = pred_xstart[:, 6:, :]
|
||||||
|
pred_head = F.normalize(pred_head, dim=1)
|
||||||
|
head_ang = torch.mean(acos_safe(torch.sum(gt_head*pred_head, 1)))/torch.tensor(math.pi) * 180.0
|
||||||
|
|
||||||
|
# head prediction error
|
||||||
|
gt_head_future = F.normalize(gt_head_future, dim=1)
|
||||||
|
pred_head_future = F.normalize(pred_head_future, dim=1)
|
||||||
|
baseline_head_future = gt_head[:, :, -1:].expand(-1, -1, self.conf.seq_len_future).clone()
|
||||||
|
head_ang_future = torch.mean(acos_safe(torch.sum(gt_head_future*pred_head_future, 1)))/torch.tensor(math.pi) * 180.0
|
||||||
|
head_ang_future_baseline = torch.mean(acos_safe(torch.sum(gt_head_future*baseline_head_future, 1)))/torch.tensor(math.pi) * 180.0
|
||||||
|
|
||||||
|
self.log("val_hand_traj", hand_traj, on_epoch=True, prog_bar=True)
|
||||||
|
self.log("val_head_ang", head_ang, on_epoch=True, prog_bar=True)
|
||||||
|
self.log("val_hand_traj_future", hand_traj_future, on_epoch=True, prog_bar=True)
|
||||||
|
self.log("val_head_ang_future", head_ang_future, on_epoch=True, prog_bar=True)
|
||||||
|
self.log("val_hand_traj_future_baseline", hand_traj_future_baseline, on_epoch=True, prog_bar=True)
|
||||||
|
self.log("val_head_ang_future_baseline", head_ang_future_baseline, on_epoch=True, prog_bar=True)
|
||||||
|
|
||||||
|
if self.conf.in_channels == 6: # hand only
|
||||||
|
if((self.current_epoch+1)% self.eval_every_epochs == 0):
|
||||||
|
batch_future = batch[:, :, self.conf.seq_len:]
|
||||||
|
gt_hand_future = batch_future[:, :, :]
|
||||||
|
batch = batch[:, :, :self.conf.seq_len]
|
||||||
|
cond, pred_hand_future, pred_head_future = self.encode(batch)
|
||||||
|
xT = self.encode_stochastic(batch, cond)
|
||||||
|
pred_xstart = self.generate(xT, cond)
|
||||||
|
|
||||||
|
# hand reconstruction error
|
||||||
|
gt_hand = batch[:, :, :]
|
||||||
|
pred_hand = pred_xstart[:, :, :]
|
||||||
|
bs, channels, seq_len = gt_hand.shape
|
||||||
|
gt_hand = gt_hand.reshape(bs, 2, 3, seq_len)
|
||||||
|
pred_hand = pred_hand.reshape(bs, 2, 3, seq_len)
|
||||||
|
hand_traj = torch.mean(torch.norm(gt_hand - pred_hand, dim=2))
|
||||||
|
|
||||||
|
# hand prediction error
|
||||||
|
bs, channels, seq_len = gt_hand_future.shape
|
||||||
|
gt_hand_future = gt_hand_future.reshape(bs, 2, 3, seq_len)
|
||||||
|
pred_hand_future = pred_hand_future.reshape(bs, 2, 3, seq_len)
|
||||||
|
baseline_hand_future = gt_hand[:, :, :, -1:].expand(-1, -1, -1, self.conf.seq_len_future).clone()
|
||||||
|
hand_traj_future = torch.mean(torch.norm(gt_hand_future - pred_hand_future, dim=2))
|
||||||
|
hand_traj_future_baseline = torch.mean(torch.norm(gt_hand_future - baseline_hand_future, dim=2))
|
||||||
|
|
||||||
|
self.log("val_hand_traj", hand_traj, on_epoch=True, prog_bar=True)
|
||||||
|
self.log("val_hand_traj_future", hand_traj_future, on_epoch=True, prog_bar=True)
|
||||||
|
self.log("val_hand_traj_future_baseline", hand_traj_future_baseline, on_epoch=True, prog_bar=True)
|
||||||
|
|
||||||
|
if self.conf.in_channels == 3: # head only
|
||||||
|
if((self.current_epoch+1)% self.eval_every_epochs == 0):
|
||||||
|
batch_future = batch[:, :, self.conf.seq_len:]
|
||||||
|
gt_head_future = batch_future[:, :, :]
|
||||||
|
batch = batch[:, :, :self.conf.seq_len]
|
||||||
|
cond, pred_hand_future, pred_head_future = self.encode(batch)
|
||||||
|
xT = self.encode_stochastic(batch, cond)
|
||||||
|
pred_xstart = self.generate(xT, cond)
|
||||||
|
|
||||||
|
# head reconstruction error
|
||||||
|
gt_head = batch[:, :, :]
|
||||||
|
gt_head = F.normalize(gt_head, dim=1) # normalize head orientation to unit vectors
|
||||||
|
pred_head = pred_xstart[:, :, :]
|
||||||
|
pred_head = F.normalize(pred_head, dim=1)
|
||||||
|
head_ang = torch.mean(acos_safe(torch.sum(gt_head*pred_head, 1)))/torch.tensor(math.pi) * 180.0
|
||||||
|
|
||||||
|
# head prediction error
|
||||||
|
gt_head_future = F.normalize(gt_head_future, dim=1)
|
||||||
|
pred_head_future = F.normalize(pred_head_future, dim=1)
|
||||||
|
baseline_head_future = gt_head[:, :, -1:].expand(-1, -1, self.conf.seq_len_future).clone()
|
||||||
|
head_ang_future = torch.mean(acos_safe(torch.sum(gt_head_future*pred_head_future, 1)))/torch.tensor(math.pi) * 180.0
|
||||||
|
head_ang_future_baseline = torch.mean(acos_safe(torch.sum(gt_head_future*baseline_head_future, 1)))/torch.tensor(math.pi) * 180.0
|
||||||
|
|
||||||
|
self.log("val_head_ang", head_ang, on_epoch=True, prog_bar=True)
|
||||||
|
self.log("val_head_ang_future", head_ang_future, on_epoch=True, prog_bar=True)
|
||||||
|
self.log("val_head_ang_future_baseline", head_ang_future_baseline, on_epoch=True, prog_bar=True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_step(self, batch, batch_idx):
|
||||||
|
batch_future = batch[:, :, self.conf.seq_len:]
|
||||||
|
gt_hand_future = batch_future[:, :6, :]
|
||||||
|
gt_head_future = batch_future[:, 6:, :]
|
||||||
|
batch = batch[:, :, :self.conf.seq_len]
|
||||||
|
cond, pred_hand_future, pred_head_future = self.encode(batch)
|
||||||
|
xT = self.encode_stochastic(batch, cond)
|
||||||
|
pred_xstart = self.generate(xT, cond)
|
||||||
|
|
||||||
|
# hand reconstruction error
|
||||||
|
gt_hand = batch[:, :6, :]
|
||||||
|
pred_hand = pred_xstart[:, :6, :]
|
||||||
|
bs, channels, seq_len = gt_hand.shape
|
||||||
|
gt_hand = gt_hand.reshape(bs, 2, 3, seq_len)
|
||||||
|
pred_hand = pred_hand.reshape(bs, 2, 3, seq_len)
|
||||||
|
hand_traj = torch.mean(torch.norm(gt_hand - pred_hand, dim=2))
|
||||||
|
|
||||||
|
# hand prediction error
|
||||||
|
bs, channels, seq_len = gt_hand_future.shape
|
||||||
|
gt_hand_future = gt_hand_future.reshape(bs, 2, 3, seq_len)
|
||||||
|
pred_hand_future = pred_hand_future.reshape(bs, 2, 3, seq_len)
|
||||||
|
baseline_hand_future = gt_hand[:, :, :, -1:].expand(-1, -1, -1, self.conf.seq_len_future).clone()
|
||||||
|
hand_traj_future = torch.mean(torch.norm(gt_hand_future - pred_hand_future, dim=2))
|
||||||
|
hand_traj_future_baseline = torch.mean(torch.norm(gt_hand_future - baseline_hand_future, dim=2))
|
||||||
|
|
||||||
|
# head reconstruction error
|
||||||
|
gt_head = batch[:, 6:, :]
|
||||||
|
gt_head = F.normalize(gt_head, dim=1) # normalize head orientation to unit vectors
|
||||||
|
pred_head = pred_xstart[:, 6:, :]
|
||||||
|
pred_head = F.normalize(pred_head, dim=1)
|
||||||
|
head_ang = torch.mean(acos_safe(torch.sum(gt_head*pred_head, 1)))/torch.tensor(math.pi) * 180.0
|
||||||
|
|
||||||
|
# head prediction error
|
||||||
|
gt_head_future = F.normalize(gt_head_future, dim=1)
|
||||||
|
pred_head_future = F.normalize(pred_head_future, dim=1)
|
||||||
|
baseline_head_future = gt_head[:, :, -1:].expand(-1, -1, self.conf.seq_len_future).clone()
|
||||||
|
head_ang_future = torch.mean(acos_safe(torch.sum(gt_head_future*pred_head_future, 1)))/torch.tensor(math.pi) * 180.0
|
||||||
|
head_ang_future_baseline = torch.mean(acos_safe(torch.sum(gt_head_future*baseline_head_future, 1)))/torch.tensor(math.pi) * 180.0
|
||||||
|
|
||||||
|
self.log("test_hand_traj", hand_traj, on_epoch=True, prog_bar=True)
|
||||||
|
self.log("test_head_ang", head_ang, on_epoch=True, prog_bar=True)
|
||||||
|
self.log("test_hand_traj_future", hand_traj_future, on_epoch=True, prog_bar=True)
|
||||||
|
self.log("test_head_ang_future", head_ang_future, on_epoch=True, prog_bar=True)
|
||||||
|
self.log("test_hand_traj_future_baseline", hand_traj_future_baseline, on_epoch=True, prog_bar=True)
|
||||||
|
self.log("test_head_ang_future_baseline", head_ang_future_baseline, on_epoch=True, prog_bar=True)
|
||||||
|
|
||||||
|
|
||||||
|
def generate(self, noise, cond=None, ema=True, T=None):
|
||||||
|
if T is None:
|
||||||
|
sampler = self.eval_sampler
|
||||||
|
else:
|
||||||
|
sampler = self.conf._make_diffusion_conf(T).make_sampler()
|
||||||
|
|
||||||
|
if ema:
|
||||||
|
model = self.ema_model
|
||||||
|
else:
|
||||||
|
model = self.model
|
||||||
|
|
||||||
|
gen = sampler.sample(model=model, noise=noise, model_kwargs={'cond': cond})
|
||||||
|
return gen
|
||||||
|
|
||||||
|
def on_train_batch_end(self, outputs, batch, batch_idx: int,
|
||||||
|
dataloader_idx: int) -> None:
|
||||||
|
"""
|
||||||
|
after each training step ...
|
||||||
|
"""
|
||||||
|
if self.is_last_accum(batch_idx):
|
||||||
|
# only apply ema on the last gradient accumulation step,
|
||||||
|
# if it is the iteration that has optimizer.step()
|
||||||
|
|
||||||
|
ema(self.model, self.ema_model, self.conf.ema_decay)
|
||||||
|
|
||||||
|
if (batch_idx==len(self.train_dataloader())-1) and ((self.current_epoch+1) % self.save_every_epochs == 0):
|
||||||
|
save_path = os.path.join(self.conf.logdir, 'epoch_%d.ckpt' % (self.current_epoch+1))
|
||||||
|
torch.save({
|
||||||
|
'state_dict': self.state_dict(),
|
||||||
|
'global_step': self.global_step,
|
||||||
|
'loss': outputs['loss'],
|
||||||
|
}, save_path)
|
||||||
|
|
||||||
|
def on_before_optimizer_step(self, optimizer: Optimizer,
|
||||||
|
optimizer_idx: int) -> None:
|
||||||
|
# fix the fp16 + clip grad norm problem with pytorch lightinng
|
||||||
|
# this is the currently correct way to do it
|
||||||
|
if self.conf.grad_clip > 0:
|
||||||
|
# from trainer.params_grads import grads_norm, iter_opt_params
|
||||||
|
params = [
|
||||||
|
p for group in optimizer.param_groups for p in group['params']
|
||||||
|
]
|
||||||
|
# print('before:', grads_norm(iter_opt_params(optimizer)))
|
||||||
|
torch.nn.utils.clip_grad_norm_(params, max_norm=self.conf.grad_clip)
|
||||||
|
# print('after:', grads_norm(iter_opt_params(optimizer)))
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
out = {}
|
||||||
|
if self.conf.optimizer == OptimizerType.adam:
|
||||||
|
optim = torch.optim.Adam(self.model.parameters(),
|
||||||
|
lr=self.conf.lr,
|
||||||
|
weight_decay=self.conf.weight_decay)
|
||||||
|
elif self.conf.optimizer == OptimizerType.adamw:
|
||||||
|
optim = torch.optim.AdamW(self.model.parameters(),
|
||||||
|
lr=self.conf.lr,
|
||||||
|
weight_decay=self.conf.weight_decay)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError()
|
||||||
|
out['optimizer'] = optim
|
||||||
|
if self.conf.warmup > 0:
|
||||||
|
sched = torch.optim.lr_scheduler.LambdaLR(optim,
|
||||||
|
lr_lambda=WarmupLR(
|
||||||
|
self.conf.warmup))
|
||||||
|
out['lr_scheduler'] = {
|
||||||
|
'scheduler': sched,
|
||||||
|
'interval': 'step',
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def ema(source, target, decay):
|
||||||
|
source_dict = source.state_dict()
|
||||||
|
target_dict = target.state_dict()
|
||||||
|
for key in source_dict.keys():
|
||||||
|
target_dict[key].data.copy_(target_dict[key].data * decay +
|
||||||
|
source_dict[key].data * (1 - decay))
|
||||||
|
|
||||||
|
|
||||||
|
class WarmupLR:
|
||||||
|
def __init__(self, warmup) -> None:
|
||||||
|
self.warmup = warmup
|
||||||
|
|
||||||
|
def __call__(self, step):
|
||||||
|
return min(step, self.warmup) / self.warmup
|
||||||
|
|
||||||
|
|
||||||
|
def train(conf: TrainConfig, model: LitModel, gpus):
|
||||||
|
checkpoint = ModelCheckpoint(dirpath=conf.logdir,
|
||||||
|
filename='last',
|
||||||
|
save_last=True,
|
||||||
|
save_top_k=1,
|
||||||
|
every_n_epochs=conf.save_every_epochs,
|
||||||
|
)
|
||||||
|
checkpoint_path = f'{conf.logdir}last.ckpt'
|
||||||
|
if os.path.exists(checkpoint_path):
|
||||||
|
resume = checkpoint_path
|
||||||
|
if conf.mode == 'train':
|
||||||
|
print('ckpt path:', checkpoint_path)
|
||||||
|
else:
|
||||||
|
print('checkpoint not found!')
|
||||||
|
resume = None
|
||||||
|
|
||||||
|
wandb_logger = pl_loggers.WandbLogger(project='haheae',
|
||||||
|
name='%s_%s'%(model.conf.data_name, conf.logdir.split('/')[-2]),
|
||||||
|
log_model=True,
|
||||||
|
save_dir=conf.logdir,
|
||||||
|
dir = conf.logdir,
|
||||||
|
config=vars(model.conf),
|
||||||
|
save_code=True,
|
||||||
|
settings=wandb.Settings(code_dir="."))
|
||||||
|
|
||||||
|
trainer = pl.Trainer(
|
||||||
|
max_epochs=conf.total_epochs,
|
||||||
|
resume_from_checkpoint=resume,
|
||||||
|
gpus=gpus,
|
||||||
|
precision=16 if conf.fp16 else 32,
|
||||||
|
callbacks=[
|
||||||
|
checkpoint,
|
||||||
|
LearningRateMonitor(),
|
||||||
|
],
|
||||||
|
logger= wandb_logger,
|
||||||
|
accumulate_grad_batches=conf.accum_batches,
|
||||||
|
progress_bar_refresh_rate=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
if conf.mode == 'train':
|
||||||
|
trainer.fit(model)
|
||||||
|
elif conf.mode == 'eval':
|
||||||
|
checkpoint_path = f'{conf.logdir}last.ckpt'
|
||||||
|
# load the latest checkpoint
|
||||||
|
print('loading from:', checkpoint_path)
|
||||||
|
state = torch.load(checkpoint_path)
|
||||||
|
model.load_state_dict(state['state_dict'])
|
||||||
|
|
||||||
|
test_datasets = ['egobody', 'adt', 'gimo']
|
||||||
|
for dataset_name in test_datasets:
|
||||||
|
if dataset_name == 'egobody':
|
||||||
|
data_dir = "/scratch/hu/pose_forecast/egobody_pose2gaze/"
|
||||||
|
test_data = load_egobody(data_dir, conf.seq_len+conf.seq_len_future, 1, train=0) # use the test set
|
||||||
|
elif dataset_name == 'adt':
|
||||||
|
data_dir = "/scratch/hu/pose_forecast/adt_pose2gaze/"
|
||||||
|
test_data = load_adt(data_dir, conf.seq_len+conf.seq_len_future, 1, train=2) # use the train+test set
|
||||||
|
elif dataset_name == 'gimo':
|
||||||
|
data_dir = "/scratch/hu/pose_forecast/gimo_pose2gaze/"
|
||||||
|
test_data = load_gimo(data_dir, conf.seq_len+conf.seq_len_future, 1, train=2) # use the train+test set
|
||||||
|
|
||||||
|
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=conf.batch_size_eval, shuffle=False)
|
||||||
|
results = trainer.test(model, dataloaders=test_dataloader, verbose=False)
|
||||||
|
print("\n\nTest on {}, dataset size: {}".format(dataset_name, test_data.shape))
|
||||||
|
print("test_hand_traj: {:.3f} cm".format(results[0]['test_hand_traj']*100))
|
||||||
|
print("test_head_ang: {:.3f} deg".format(results[0]['test_head_ang']))
|
||||||
|
print("test_hand_traj_future: {:.3f} cm".format(results[0]['test_hand_traj_future']*100))
|
||||||
|
print("test_head_ang_future: {:.3f} deg".format(results[0]['test_head_ang_future']))
|
||||||
|
print("test_hand_traj_future_baseline: {:.3f} cm".format(results[0]['test_hand_traj_future_baseline']*100))
|
||||||
|
print("test_head_ang_future_baseline: {:.3f} deg\n\n".format(results[0]['test_head_ang_future_baseline']))
|
||||||
|
|
||||||
|
wandb.finish()
|
||||||
|
|
||||||
|
|
||||||
|
def acos_safe(x, eps=1e-6):
|
||||||
|
slope = np.arccos(1-eps) / eps
|
||||||
|
buf = torch.empty_like(x)
|
||||||
|
good = abs(x) <= 1-eps
|
||||||
|
bad = ~good
|
||||||
|
sign = torch.sign(x[bad])
|
||||||
|
buf[good] = torch.acos(x[good])
|
||||||
|
buf[bad] = torch.acos(sign * (1 - eps)) - slope*sign*(abs(x[bad]) - 1 + eps)
|
||||||
|
return buf
|
||||||
|
|
||||||
|
|
||||||
|
def get_representation(model, dataset, conf, device='cuda'):
|
||||||
|
model = model.to(device)
|
||||||
|
model.eval()
|
||||||
|
dataloader = torch.utils.data.DataLoader(dataset, batch_size=conf.batch_size_eval, shuffle=False)
|
||||||
|
with torch.no_grad():
|
||||||
|
conds = [] # semantic representation
|
||||||
|
xTs = [] # stochastic representation
|
||||||
|
for batch in tqdm(dataloader, total=len(dataloader), desc='infer'):
|
||||||
|
batch = batch.to(device)
|
||||||
|
cond, _, _ = model.encode(batch)
|
||||||
|
xT = model.encode_stochastic(batch, cond)
|
||||||
|
cond_cpu = cond.cpu().data.numpy()
|
||||||
|
xT_cpu = xT.cpu().data.numpy()
|
||||||
|
if len(conds) == 0:
|
||||||
|
conds = cond_cpu
|
||||||
|
xTs = xT_cpu
|
||||||
|
else:
|
||||||
|
conds = np.concatenate((conds, cond_cpu), axis=0)
|
||||||
|
xTs = np.concatenate((xTs, xT_cpu), axis=0)
|
||||||
|
return conds, xTs
|
||||||
|
|
||||||
|
|
||||||
|
def generate_from_representation(model, conds, xTs, device='cuda'):
|
||||||
|
model = model.to(device)
|
||||||
|
model.eval()
|
||||||
|
conds = torch.from_numpy(conds).to(device)
|
||||||
|
xTs = torch.from_numpy(xTs).to(device)
|
||||||
|
rec = model.generate(xTs, conds)
|
||||||
|
rec = rec.cpu().data.numpy()
|
||||||
|
return rec
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_reconstruction(gt, rec):
|
||||||
|
# hand reconstruction error (cm)
|
||||||
|
gt_hand = gt[:, :6, :]
|
||||||
|
rec_hand = rec[:, :6, :]
|
||||||
|
bs, channels, seq_len = gt_hand.shape
|
||||||
|
gt_hand = gt_hand.reshape(bs, 2, 3, seq_len)
|
||||||
|
rec_hand = rec_hand.reshape(bs, 2, 3, seq_len)
|
||||||
|
hand_traj_errors = np.mean(np.mean(np.linalg.norm(gt_hand - rec_hand, axis=2), axis=1), axis=1)*100
|
||||||
|
|
||||||
|
# head reconstruction error (deg)
|
||||||
|
gt_head = gt[:, 6:, :]
|
||||||
|
gt_head_norm = np.linalg.norm(gt_head, axis=1, keepdims=True)
|
||||||
|
gt_head = gt_head/gt_head_norm
|
||||||
|
rec_head = rec[:, 6:, :]
|
||||||
|
rec_head_norm = np.linalg.norm(rec_head, axis=1, keepdims=True)
|
||||||
|
rec_head = rec_head/rec_head_norm
|
||||||
|
dot_sum = np.clip(np.sum(gt_head*rec_head, axis=1), -1, 1)
|
||||||
|
head_ang_errors = np.mean(np.arccos(dot_sum), axis=1)/np.pi * 180.0
|
||||||
|
return hand_traj_errors, head_ang_errors
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--gpus', default=7, type=int)
|
||||||
|
parser.add_argument('--mode', default='eval', type=str)
|
||||||
|
parser.add_argument('--encoder_type', default='gcn', type=str)
|
||||||
|
parser.add_argument('--model_name', default='haheae', type=str)
|
||||||
|
parser.add_argument('--hand_mse_factor', default=1.0, type=float)
|
||||||
|
parser.add_argument('--head_mse_factor', default=1.0, type=float)
|
||||||
|
parser.add_argument('--data_sample_rate', default=1, type=int)
|
||||||
|
parser.add_argument('--epoch', default=130, type=int)
|
||||||
|
parser.add_argument('--in_channels', default=9, type=int)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
conf = egobody_autoenc(args.mode, args.encoder_type, args.hand_mse_factor, args.head_mse_factor, args.data_sample_rate, args.epoch, args.in_channels)
|
||||||
|
model = LitModel(conf)
|
||||||
|
conf.logdir = f'{conf.logdir}{args.model_name}/'
|
||||||
|
print('log dir: {}'.format(conf.logdir))
|
||||||
|
MakeDir(conf.logdir)
|
||||||
|
|
||||||
|
if conf.mode == 'train' or conf.mode == 'eval': # train or evaluate the model
|
||||||
|
os.environ['WANDB_CACHE_DIR'] = conf.logdir
|
||||||
|
os.environ['WANDB_DATA_DIR'] = conf.logdir
|
||||||
|
# set wandb to not upload checkpoints, but all the others
|
||||||
|
os.environ['WANDB_IGNORE_GLOBS'] = '*.ckpt'
|
||||||
|
local_time = time.asctime(time.localtime(time.time()))
|
||||||
|
print('\n{} starts at {}'.format(conf.mode, local_time))
|
||||||
|
start_time = datetime.datetime.now()
|
||||||
|
train(conf, model, gpus=[args.gpus])
|
||||||
|
end_time = datetime.datetime.now()
|
||||||
|
total_time = (end_time - start_time).seconds/60
|
||||||
|
print('\nTotal time: {:.3f} min'.format(total_time))
|
||||||
|
local_time = time.asctime(time.localtime(time.time()))
|
||||||
|
print('\n{} ends at {}'.format(conf.mode, local_time))
|
6
model/__init__.py
Normal file
6
model/__init__.py
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
from typing import Union
|
||||||
|
from .unet import BeatGANsUNetModel, BeatGANsUNetConfig, GCNUNetModel, GCNUNetConfig
|
||||||
|
from .unet_autoenc import BeatGANsAutoencConfig, BeatGANsAutoencModel, GCNAutoencConfig, GCNAutoencModel
|
||||||
|
|
||||||
|
Model = Union[BeatGANsUNetModel, BeatGANsAutoencModel, GCNUNetModel, GCNAutoencModel]
|
||||||
|
ModelConfig = Union[BeatGANsUNetConfig, BeatGANsAutoencConfig, GCNUNetConfig,GCNAutoencConfig]
|
579
model/blocks.py
Normal file
579
model/blocks.py
Normal file
|
@ -0,0 +1,579 @@
|
||||||
|
import math, pdb
|
||||||
|
from abc import abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from numbers import Number
|
||||||
|
|
||||||
|
import torch as th
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from choices import *
|
||||||
|
from config_base import BaseConfig
|
||||||
|
from torch import nn
|
||||||
|
import numpy as np
|
||||||
|
from .nn import (avg_pool_nd, conv_nd, linear, normalization,
|
||||||
|
timestep_embedding, zero_module)
|
||||||
|
|
||||||
|
|
||||||
|
class TimestepBlock(nn.Module):
|
||||||
|
"""
|
||||||
|
Any module where forward() takes timestep embeddings as a second argument.
|
||||||
|
"""
|
||||||
|
@abstractmethod
|
||||||
|
def forward(self, x, emb=None, cond=None, lateral=None):
|
||||||
|
"""
|
||||||
|
Apply the module to `x` given `emb` timestep embeddings.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
||||||
|
"""
|
||||||
|
A sequential module that passes timestep embeddings to the children that
|
||||||
|
support it as an extra input.
|
||||||
|
"""
|
||||||
|
def forward(self, x, emb=None, cond=None, lateral=None):
|
||||||
|
for layer in self:
|
||||||
|
if isinstance(layer, TimestepBlock):
|
||||||
|
'''if layer(x, emb=emb, cond=cond, lateral=lateral).shape[-1]==10:
|
||||||
|
pdb.set_trace()'''
|
||||||
|
x = layer(x, emb=emb, cond=cond, lateral=lateral)
|
||||||
|
else:
|
||||||
|
'''if layer(x).shape[-1]==10:
|
||||||
|
pdb.set_trace()'''
|
||||||
|
x = layer(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ResBlockConfig(BaseConfig):
|
||||||
|
channels: int
|
||||||
|
emb_channels: int
|
||||||
|
dropout: float
|
||||||
|
out_channels: int = None
|
||||||
|
# condition the resblock with time (and encoder's output)
|
||||||
|
use_condition: bool = True
|
||||||
|
# whether to use 3x3 conv for skip path when the channels aren't matched
|
||||||
|
use_conv: bool = False
|
||||||
|
# dimension of conv (always 1 = 1d)
|
||||||
|
dims: int = 1
|
||||||
|
up: bool = False
|
||||||
|
down: bool = False
|
||||||
|
# whether to condition with both time & encoder's output
|
||||||
|
two_cond: bool = False
|
||||||
|
# number of encoders' output channels
|
||||||
|
cond_emb_channels: int = None
|
||||||
|
# suggest: False
|
||||||
|
has_lateral: bool = False
|
||||||
|
lateral_channels: int = None
|
||||||
|
# whether to init the convolution with zero weights
|
||||||
|
# this is default from BeatGANs and seems to help learning
|
||||||
|
use_zero_module: bool = True
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
self.out_channels = self.out_channels or self.channels
|
||||||
|
self.cond_emb_channels = self.cond_emb_channels or self.emb_channels
|
||||||
|
|
||||||
|
def make_model(self):
|
||||||
|
return ResBlock(self)
|
||||||
|
|
||||||
|
|
||||||
|
class ResBlock(TimestepBlock):
|
||||||
|
"""
|
||||||
|
A residual block that can optionally change the number of channels.
|
||||||
|
|
||||||
|
total layers:
|
||||||
|
in_layers
|
||||||
|
- norm
|
||||||
|
- act
|
||||||
|
- conv
|
||||||
|
out_layers
|
||||||
|
- norm
|
||||||
|
- (modulation)
|
||||||
|
- act
|
||||||
|
- conv
|
||||||
|
"""
|
||||||
|
def __init__(self, conf: ResBlockConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.conf = conf
|
||||||
|
|
||||||
|
#############################
|
||||||
|
# IN LAYERS
|
||||||
|
#############################
|
||||||
|
assert conf.lateral_channels is None
|
||||||
|
layers = [
|
||||||
|
normalization(conf.channels),
|
||||||
|
nn.SiLU(),
|
||||||
|
conv_nd(conf.dims, conf.channels, conf.out_channels, 3, padding=1) ## 3 is kernel size
|
||||||
|
]
|
||||||
|
self.in_layers = nn.Sequential(*layers)
|
||||||
|
|
||||||
|
self.updown = conf.up or conf.down
|
||||||
|
|
||||||
|
if conf.up:
|
||||||
|
self.h_upd = Upsample(conf.channels, False, conf.dims)
|
||||||
|
self.x_upd = Upsample(conf.channels, False, conf.dims)
|
||||||
|
elif conf.down:
|
||||||
|
self.h_upd = Downsample(conf.channels, False, conf.dims)
|
||||||
|
self.x_upd = Downsample(conf.channels, False, conf.dims)
|
||||||
|
else:
|
||||||
|
self.h_upd = self.x_upd = nn.Identity()
|
||||||
|
|
||||||
|
#############################
|
||||||
|
# OUT LAYERS CONDITIONS
|
||||||
|
#############################
|
||||||
|
if conf.use_condition:
|
||||||
|
# condition layers for the out_layers
|
||||||
|
self.emb_layers = nn.Sequential(
|
||||||
|
nn.SiLU(),
|
||||||
|
linear(conf.emb_channels, 2 * conf.out_channels),
|
||||||
|
)
|
||||||
|
|
||||||
|
if conf.two_cond:
|
||||||
|
self.cond_emb_layers = nn.Sequential(
|
||||||
|
nn.SiLU(),
|
||||||
|
linear(conf.cond_emb_channels, conf.out_channels),
|
||||||
|
)
|
||||||
|
#############################
|
||||||
|
# OUT LAYERS (ignored when there is no condition)
|
||||||
|
#############################
|
||||||
|
# original version
|
||||||
|
conv = conv_nd(conf.dims,
|
||||||
|
conf.out_channels,
|
||||||
|
conf.out_channels,
|
||||||
|
3,
|
||||||
|
padding=1)
|
||||||
|
if conf.use_zero_module:
|
||||||
|
# zere out the weights
|
||||||
|
# it seems to help training
|
||||||
|
conv = zero_module(conv)
|
||||||
|
|
||||||
|
# construct the layers
|
||||||
|
# - norm
|
||||||
|
# - (modulation)
|
||||||
|
# - act
|
||||||
|
# - dropout
|
||||||
|
# - conv
|
||||||
|
layers = []
|
||||||
|
layers += [
|
||||||
|
normalization(conf.out_channels),
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Dropout(p=conf.dropout),
|
||||||
|
conv,
|
||||||
|
]
|
||||||
|
self.out_layers = nn.Sequential(*layers)
|
||||||
|
|
||||||
|
#############################
|
||||||
|
# SKIP LAYERS
|
||||||
|
#############################
|
||||||
|
if conf.out_channels == conf.channels:
|
||||||
|
# cannot be used with gatedconv, also gatedconv is alsways used as the first block
|
||||||
|
self.skip_connection = nn.Identity()
|
||||||
|
else:
|
||||||
|
if conf.use_conv:
|
||||||
|
kernel_size = 3
|
||||||
|
padding = 1
|
||||||
|
else:
|
||||||
|
kernel_size = 1
|
||||||
|
padding = 0
|
||||||
|
|
||||||
|
self.skip_connection = conv_nd(conf.dims,
|
||||||
|
conf.channels,
|
||||||
|
conf.out_channels,
|
||||||
|
kernel_size,
|
||||||
|
padding=padding)
|
||||||
|
|
||||||
|
def forward(self, x, emb=None, cond=None, lateral=None):
|
||||||
|
"""
|
||||||
|
Apply the block to a Tensor, conditioned on a timestep embedding.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: input
|
||||||
|
lateral: lateral connection from the encoder
|
||||||
|
"""
|
||||||
|
return self._forward(x, emb, cond, lateral)
|
||||||
|
|
||||||
|
def _forward(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
emb=None,
|
||||||
|
cond=None,
|
||||||
|
lateral=None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
lateral: required if "has_lateral" and non-gated, with gated, it can be supplied optionally
|
||||||
|
"""
|
||||||
|
if self.conf.has_lateral:
|
||||||
|
# lateral may be supplied even if it doesn't require
|
||||||
|
# the model will take the lateral only if "has_lateral"
|
||||||
|
assert lateral is not None
|
||||||
|
# x = F.interpolate(x, size=(lateral.size(2)), mode='linear' )
|
||||||
|
x = th.cat([x, lateral], dim=1)
|
||||||
|
|
||||||
|
if self.updown:
|
||||||
|
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
|
||||||
|
h = in_rest(x)
|
||||||
|
h = self.h_upd(h)
|
||||||
|
x = self.x_upd(x)
|
||||||
|
h = in_conv(h)
|
||||||
|
else:
|
||||||
|
h = self.in_layers(x)
|
||||||
|
|
||||||
|
if self.conf.use_condition:
|
||||||
|
# it's possible that the network may not receieve the time emb
|
||||||
|
# this happens with autoenc and setting the time_at
|
||||||
|
if emb is not None:
|
||||||
|
emb_out = self.emb_layers(emb).type(h.dtype)
|
||||||
|
else:
|
||||||
|
emb_out = None
|
||||||
|
|
||||||
|
if self.conf.two_cond:
|
||||||
|
# it's possible that the network is two_cond
|
||||||
|
# but it doesn't get the second condition
|
||||||
|
# in which case, we ignore the second condition
|
||||||
|
# and treat as if the network has one condition
|
||||||
|
if cond is None:
|
||||||
|
cond_out = None
|
||||||
|
else:
|
||||||
|
if not isinstance(cond, th.Tensor):
|
||||||
|
assert isinstance(cond, dict)
|
||||||
|
cond = cond['cond']
|
||||||
|
cond_out = self.cond_emb_layers(cond).type(h.dtype)
|
||||||
|
if cond_out is not None:
|
||||||
|
while len(cond_out.shape) < len(h.shape):
|
||||||
|
cond_out = cond_out[..., None]
|
||||||
|
else:
|
||||||
|
cond_out = None
|
||||||
|
|
||||||
|
# this is the new refactored code
|
||||||
|
h = apply_conditions(
|
||||||
|
h=h,
|
||||||
|
emb=emb_out,
|
||||||
|
cond=cond_out,
|
||||||
|
layers=self.out_layers,
|
||||||
|
scale_bias=1,
|
||||||
|
in_channels=self.conf.out_channels,
|
||||||
|
up_down_layer=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.skip_connection(x) + h
|
||||||
|
|
||||||
|
|
||||||
|
def apply_conditions(
|
||||||
|
h,
|
||||||
|
emb=None,
|
||||||
|
cond=None,
|
||||||
|
layers: nn.Sequential = None,
|
||||||
|
scale_bias: float = 1,
|
||||||
|
in_channels: int = 512,
|
||||||
|
up_down_layer: nn.Module = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
apply conditions on the feature maps
|
||||||
|
|
||||||
|
Args:
|
||||||
|
emb: time conditional (ready to scale + shift)
|
||||||
|
cond: encoder's conditional (ready to scale + shift)
|
||||||
|
"""
|
||||||
|
two_cond = emb is not None and cond is not None
|
||||||
|
|
||||||
|
if emb is not None:
|
||||||
|
# adjusting shapes
|
||||||
|
while len(emb.shape) < len(h.shape):
|
||||||
|
emb = emb[..., None]
|
||||||
|
|
||||||
|
if two_cond:
|
||||||
|
# adjusting shapes
|
||||||
|
while len(cond.shape) < len(h.shape):
|
||||||
|
cond = cond[..., None]
|
||||||
|
# time first
|
||||||
|
scale_shifts = [emb, cond]
|
||||||
|
else:
|
||||||
|
# "cond" is not used with single cond mode
|
||||||
|
scale_shifts = [emb]
|
||||||
|
|
||||||
|
# support scale, shift or shift only
|
||||||
|
for i, each in enumerate(scale_shifts):
|
||||||
|
if each is None:
|
||||||
|
# special case: the condition is not provided
|
||||||
|
a = None
|
||||||
|
b = None
|
||||||
|
else:
|
||||||
|
if each.shape[1] == in_channels * 2:
|
||||||
|
a, b = th.chunk(each, 2, dim=1)
|
||||||
|
else:
|
||||||
|
a = each
|
||||||
|
b = None
|
||||||
|
scale_shifts[i] = (a, b)
|
||||||
|
|
||||||
|
# condition scale bias could be a list
|
||||||
|
if isinstance(scale_bias, Number):
|
||||||
|
biases = [scale_bias] * len(scale_shifts)
|
||||||
|
else:
|
||||||
|
# a list
|
||||||
|
biases = scale_bias
|
||||||
|
|
||||||
|
# default, the scale & shift are applied after the group norm but BEFORE SiLU
|
||||||
|
pre_layers, post_layers = layers[0], layers[1:]
|
||||||
|
|
||||||
|
# spilt the post layer to be able to scale up or down before conv
|
||||||
|
# post layers will contain only the conv
|
||||||
|
mid_layers, post_layers = post_layers[:-2], post_layers[-2:]
|
||||||
|
|
||||||
|
h = pre_layers(h)
|
||||||
|
# scale and shift for each condition
|
||||||
|
for i, (scale, shift) in enumerate(scale_shifts):
|
||||||
|
# if scale is None, it indicates that the condition is not provided
|
||||||
|
if scale is not None:
|
||||||
|
h = h * (biases[i] + scale)
|
||||||
|
if shift is not None:
|
||||||
|
h = h + shift
|
||||||
|
h = mid_layers(h)
|
||||||
|
|
||||||
|
# upscale or downscale if any just before the last conv
|
||||||
|
if up_down_layer is not None:
|
||||||
|
h = up_down_layer(h)
|
||||||
|
h = post_layers(h)
|
||||||
|
return h
|
||||||
|
|
||||||
|
|
||||||
|
class Upsample(nn.Module):
|
||||||
|
"""
|
||||||
|
An upsampling layer with an optional convolution.
|
||||||
|
|
||||||
|
:param channels: channels in the inputs and outputs.
|
||||||
|
:param use_conv: a bool determining if a convolution is applied.
|
||||||
|
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
||||||
|
upsampling occurs in the inner-two dimensions.
|
||||||
|
"""
|
||||||
|
def __init__(self, channels, use_conv, dims=2, out_channels=None):
|
||||||
|
super().__init__()
|
||||||
|
self.channels = channels
|
||||||
|
self.out_channels = out_channels or channels
|
||||||
|
self.use_conv = use_conv
|
||||||
|
self.dims = dims
|
||||||
|
if use_conv:
|
||||||
|
self.conv = conv_nd(dims,
|
||||||
|
self.channels,
|
||||||
|
self.out_channels,
|
||||||
|
3,
|
||||||
|
padding=1)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
assert x.shape[1] == self.channels
|
||||||
|
if self.dims == 3:
|
||||||
|
x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
|
||||||
|
mode="nearest")
|
||||||
|
else:
|
||||||
|
# if x.shape[2] == 4:
|
||||||
|
# feature = 9
|
||||||
|
# x = F.interpolate(x, size=(feature), mode="nearest")
|
||||||
|
# if x.shape[2] == 8:
|
||||||
|
# feature = 9
|
||||||
|
# x = F.interpolate(x, size=(feature), mode="nearest")
|
||||||
|
# else:
|
||||||
|
x = F.interpolate(x, scale_factor=2, mode="nearest")
|
||||||
|
if self.use_conv:
|
||||||
|
x = self.conv(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Downsample(nn.Module):
|
||||||
|
"""
|
||||||
|
A downsampling layer with an optional convolution.
|
||||||
|
|
||||||
|
:param channels: channels in the inputs and outputs.
|
||||||
|
:param use_conv: a bool determining if a convolution is applied.
|
||||||
|
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
||||||
|
downsampling occurs in the inner-two dimensions.
|
||||||
|
"""
|
||||||
|
def __init__(self, channels, use_conv, dims=2, out_channels=None):
|
||||||
|
super().__init__()
|
||||||
|
self.channels = channels
|
||||||
|
self.out_channels = out_channels or channels
|
||||||
|
self.use_conv = use_conv
|
||||||
|
self.dims = dims
|
||||||
|
self.stride = 2 if dims != 3 else (1, 2, 2)
|
||||||
|
if use_conv:
|
||||||
|
self.op = conv_nd(dims,
|
||||||
|
self.channels,
|
||||||
|
self.out_channels,
|
||||||
|
3,
|
||||||
|
stride=self.stride,
|
||||||
|
padding=1)
|
||||||
|
else:
|
||||||
|
assert self.channels == self.out_channels
|
||||||
|
self.op = avg_pool_nd(dims, kernel_size=self.stride, stride=self.stride)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
assert x.shape[1] == self.channels
|
||||||
|
# if x.shape[2] % 2 != 0:
|
||||||
|
# op = avg_pool_nd(self.dims, kernel_size=3, stride=2)
|
||||||
|
# return op(x)
|
||||||
|
# if x.shape[2] % 2 != 0:
|
||||||
|
# op = avg_pool_nd(self.dims, kernel_size=2, stride=1)
|
||||||
|
# return op(x)
|
||||||
|
# else:
|
||||||
|
return self.op(x)
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionBlock(nn.Module):
|
||||||
|
"""
|
||||||
|
An attention block that allows spatial positions to attend to each other.
|
||||||
|
|
||||||
|
Originally ported from here, but adapted to the N-d case.
|
||||||
|
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channels,
|
||||||
|
num_heads=1,
|
||||||
|
num_head_channels=-1,
|
||||||
|
use_new_attention_order=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.channels = channels
|
||||||
|
if num_head_channels == -1:
|
||||||
|
self.num_heads = num_heads
|
||||||
|
else:
|
||||||
|
assert (
|
||||||
|
channels % num_head_channels == 0
|
||||||
|
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
|
||||||
|
self.num_heads = channels // num_head_channels
|
||||||
|
self.norm = normalization(channels)
|
||||||
|
self.qkv = conv_nd(1, channels, channels * 3, 1)
|
||||||
|
if use_new_attention_order:
|
||||||
|
# split qkv before split heads
|
||||||
|
self.attention = QKVAttention(self.num_heads)
|
||||||
|
else:
|
||||||
|
# split heads before split qkv
|
||||||
|
self.attention = QKVAttentionLegacy(self.num_heads)
|
||||||
|
|
||||||
|
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self._forward(x)
|
||||||
|
|
||||||
|
def _forward(self, x):
|
||||||
|
b, c, *spatial = x.shape
|
||||||
|
x = x.reshape(b, c, -1)
|
||||||
|
qkv = self.qkv(self.norm(x))
|
||||||
|
h = self.attention(qkv)
|
||||||
|
h = self.proj_out(h)
|
||||||
|
return (x + h).reshape(b, c, *spatial)
|
||||||
|
|
||||||
|
|
||||||
|
def count_flops_attn(model, _x, y):
|
||||||
|
"""
|
||||||
|
A counter for the `thop` package to count the operations in an
|
||||||
|
attention operation.
|
||||||
|
Meant to be used like:
|
||||||
|
macs, params = thop.profile(
|
||||||
|
model,
|
||||||
|
inputs=(inputs, timestamps),
|
||||||
|
custom_ops={QKVAttention: QKVAttention.count_flops},
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
b, c, *spatial = y[0].shape
|
||||||
|
num_spatial = int(np.prod(spatial))
|
||||||
|
# We perform two matmuls with the same number of ops.
|
||||||
|
# The first computes the weight matrix, the second computes
|
||||||
|
# the combination of the value vectors.
|
||||||
|
matmul_ops = 2 * b * (num_spatial**2) * c
|
||||||
|
model.total_ops += th.DoubleTensor([matmul_ops])
|
||||||
|
|
||||||
|
|
||||||
|
class QKVAttentionLegacy(nn.Module):
|
||||||
|
"""
|
||||||
|
A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
|
||||||
|
"""
|
||||||
|
def __init__(self, n_heads):
|
||||||
|
super().__init__()
|
||||||
|
self.n_heads = n_heads
|
||||||
|
|
||||||
|
def forward(self, qkv):
|
||||||
|
"""
|
||||||
|
Apply QKV attention.
|
||||||
|
|
||||||
|
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
|
||||||
|
:return: an [N x (H * C) x T] tensor after attention.
|
||||||
|
"""
|
||||||
|
bs, width, length = qkv.shape
|
||||||
|
assert width % (3 * self.n_heads) == 0
|
||||||
|
ch = width // (3 * self.n_heads)
|
||||||
|
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
|
||||||
|
scale = 1 / math.sqrt(math.sqrt(ch))
|
||||||
|
weight = th.einsum(
|
||||||
|
"bct,bcs->bts", q * scale,
|
||||||
|
k * scale) # More stable with f16 than dividing afterwards
|
||||||
|
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||||
|
a = th.einsum("bts,bcs->bct", weight, v)
|
||||||
|
return a.reshape(bs, -1, length)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def count_flops(model, _x, y):
|
||||||
|
return count_flops_attn(model, _x, y)
|
||||||
|
|
||||||
|
|
||||||
|
class QKVAttention(nn.Module):
|
||||||
|
"""
|
||||||
|
A module which performs QKV attention and splits in a different order.
|
||||||
|
"""
|
||||||
|
def __init__(self, n_heads):
|
||||||
|
super().__init__()
|
||||||
|
self.n_heads = n_heads
|
||||||
|
|
||||||
|
def forward(self, qkv):
|
||||||
|
"""
|
||||||
|
Apply QKV attention.
|
||||||
|
|
||||||
|
:param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
|
||||||
|
:return: an [N x (H * C) x T] tensor after attention.
|
||||||
|
"""
|
||||||
|
bs, width, length = qkv.shape
|
||||||
|
pdb.set_trace()
|
||||||
|
assert width % (3 * self.n_heads) == 0
|
||||||
|
ch = width // (3 * self.n_heads)
|
||||||
|
q, k, v = qkv.chunk(3, dim=1)
|
||||||
|
scale = 1 / math.sqrt(math.sqrt(ch))
|
||||||
|
weight = th.einsum(
|
||||||
|
"bct,bcs->bts",
|
||||||
|
(q * scale).view(bs * self.n_heads, ch, length),
|
||||||
|
(k * scale).view(bs * self.n_heads, ch, length),
|
||||||
|
) # More stable with f16 than dividing afterwards
|
||||||
|
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||||
|
a = th.einsum("bts,bcs->bct", weight,
|
||||||
|
v.reshape(bs * self.n_heads, ch, length))
|
||||||
|
return a.reshape(bs, -1, length)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def count_flops(model, _x, y):
|
||||||
|
return count_flops_attn(model, _x, y)
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionPool2d(nn.Module):
|
||||||
|
"""
|
||||||
|
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
spacial_dim: int,
|
||||||
|
embed_dim: int,
|
||||||
|
num_heads_channels: int,
|
||||||
|
output_dim: int = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.positional_embedding = nn.Parameter(
|
||||||
|
th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)
|
||||||
|
self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
|
||||||
|
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
|
||||||
|
self.num_heads = embed_dim // num_heads_channels
|
||||||
|
self.attention = QKVAttention(self.num_heads)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
b, c, *_spatial = x.shape
|
||||||
|
x = x.reshape(b, c, -1) # NC(HW)
|
||||||
|
x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
|
||||||
|
x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
|
||||||
|
x = self.qkv_proj(x)
|
||||||
|
x = self.attention(x)
|
||||||
|
x = self.c_proj(x)
|
||||||
|
return x[:, :, 0]
|
173
model/graph_convolution_network.py
Normal file
173
model/graph_convolution_network.py
Normal file
|
@ -0,0 +1,173 @@
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from torch.nn.parameter import Parameter
|
||||||
|
from numbers import Number
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from .blocks import *
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
class graph_convolution(nn.Module):
|
||||||
|
def __init__(self, in_features, out_features, node_n = 3, seq_len = 80, bias=True):
|
||||||
|
super(graph_convolution, self).__init__()
|
||||||
|
|
||||||
|
self.temporal_graph_weights = Parameter(torch.FloatTensor(seq_len, seq_len))
|
||||||
|
self.feature_weights = Parameter(torch.FloatTensor(in_features, out_features))
|
||||||
|
self.spatial_graph_weights = Parameter(torch.FloatTensor(node_n, node_n))
|
||||||
|
|
||||||
|
if bias:
|
||||||
|
self.bias = Parameter(torch.FloatTensor(seq_len))
|
||||||
|
|
||||||
|
self.reset_parameters()
|
||||||
|
|
||||||
|
def reset_parameters(self):
|
||||||
|
stdv = 1. / math.sqrt(self.spatial_graph_weights.size(1))
|
||||||
|
self.feature_weights.data.uniform_(-stdv, stdv)
|
||||||
|
self.temporal_graph_weights.data.uniform_(-stdv, stdv)
|
||||||
|
self.spatial_graph_weights.data.uniform_(-stdv, stdv)
|
||||||
|
if self.bias is not None:
|
||||||
|
self.bias.data.uniform_(-stdv, stdv)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
y = torch.matmul(x, self.temporal_graph_weights)
|
||||||
|
y = torch.matmul(y.permute(0, 3, 2, 1), self.feature_weights)
|
||||||
|
y = torch.matmul(self.spatial_graph_weights, y).permute(0, 3, 2, 1).contiguous()
|
||||||
|
|
||||||
|
if self.bias is not None:
|
||||||
|
return (y + self.bias)
|
||||||
|
else:
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class residual_graph_convolution_config():
|
||||||
|
in_features: int
|
||||||
|
seq_len: int
|
||||||
|
emb_channels: int
|
||||||
|
dropout: float
|
||||||
|
out_features: int = None
|
||||||
|
node_n: int = 3
|
||||||
|
# condition the block with time (and encoder's output)
|
||||||
|
use_condition: bool = True
|
||||||
|
# whether to condition with both time & encoder's output
|
||||||
|
two_cond: bool = False
|
||||||
|
# number of encoders' output channels
|
||||||
|
cond_emb_channels: int = None
|
||||||
|
has_lateral: bool = False
|
||||||
|
graph_convolution_bias: bool = True
|
||||||
|
scale_bias: float = 1
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
self.out_features = self.out_features or self.in_features
|
||||||
|
self.cond_emb_channels = self.cond_emb_channels or self.emb_channels
|
||||||
|
|
||||||
|
def make_model(self):
|
||||||
|
return residual_graph_convolution(self)
|
||||||
|
|
||||||
|
|
||||||
|
class residual_graph_convolution(TimestepBlock):
|
||||||
|
def __init__(self, conf: residual_graph_convolution_config):
|
||||||
|
super(residual_graph_convolution, self).__init__()
|
||||||
|
self.conf = conf
|
||||||
|
|
||||||
|
self.gcn = graph_convolution(conf.in_features, conf.out_features, node_n=conf.node_n, seq_len=conf.seq_len, bias=conf.graph_convolution_bias)
|
||||||
|
self.ln = nn.LayerNorm([conf.out_features, conf.node_n, conf.seq_len])
|
||||||
|
self.act_f = nn.Tanh()
|
||||||
|
self.dropout = nn.Dropout(conf.dropout)
|
||||||
|
|
||||||
|
if conf.use_condition:
|
||||||
|
# condition layers for the out_layers
|
||||||
|
self.emb_layers = nn.Sequential(
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(conf.emb_channels, conf.out_features),
|
||||||
|
)
|
||||||
|
|
||||||
|
if conf.two_cond:
|
||||||
|
self.cond_emb_layers = nn.Sequential(
|
||||||
|
nn.SiLU(),
|
||||||
|
nn.Linear(conf.cond_emb_channels, conf.out_features),
|
||||||
|
)
|
||||||
|
|
||||||
|
if conf.in_features == conf.out_features:
|
||||||
|
self.skip_connection = nn.Identity()
|
||||||
|
else:
|
||||||
|
self.skip_connection = nn.Sequential(
|
||||||
|
graph_convolution(conf.in_features, conf.out_features, node_n=conf.node_n, seq_len=conf.seq_len, bias=conf.graph_convolution_bias),
|
||||||
|
nn.Tanh(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, emb=None, cond=None, lateral=None):
|
||||||
|
if self.conf.has_lateral:
|
||||||
|
# lateral may be supplied even if it doesn't require
|
||||||
|
# the model will take the lateral only if "has_lateral"
|
||||||
|
assert lateral is not None
|
||||||
|
x = torch.cat((x, lateral), dim =1)
|
||||||
|
|
||||||
|
y = self.gcn(x)
|
||||||
|
y = self.ln(y)
|
||||||
|
|
||||||
|
if self.conf.use_condition:
|
||||||
|
if emb is not None:
|
||||||
|
emb = self.emb_layers(emb).type(x.dtype)
|
||||||
|
# adjusting shapes
|
||||||
|
while len(emb.shape) < len(y.shape):
|
||||||
|
emb = emb[..., None]
|
||||||
|
|
||||||
|
if self.conf.two_cond or True:
|
||||||
|
if cond is not None:
|
||||||
|
if not isinstance(cond, torch.Tensor):
|
||||||
|
assert isinstance(cond, dict)
|
||||||
|
cond = cond['cond']
|
||||||
|
cond = self.cond_emb_layers(cond).type(x.dtype)
|
||||||
|
while len(cond.shape) < len(y.shape):
|
||||||
|
cond = cond[..., None]
|
||||||
|
scales = [emb, cond]
|
||||||
|
else:
|
||||||
|
scales = [emb]
|
||||||
|
|
||||||
|
# condition scale bias could be a list
|
||||||
|
if isinstance(self.conf.scale_bias, Number):
|
||||||
|
biases = [self.conf.scale_bias] * len(scales)
|
||||||
|
else:
|
||||||
|
# a list
|
||||||
|
biases = self.conf.scale_bias
|
||||||
|
|
||||||
|
# scale for each condition
|
||||||
|
for i, scale in enumerate(scales):
|
||||||
|
# if scale is None, it indicates that the condition is not provided
|
||||||
|
if scale is not None:
|
||||||
|
y = y*(biases[i] + scale)
|
||||||
|
|
||||||
|
y = self.act_f(y)
|
||||||
|
y = self.dropout(y)
|
||||||
|
return self.skip_connection(x) + y
|
||||||
|
|
||||||
|
|
||||||
|
class graph_downsample(nn.Module):
|
||||||
|
"""
|
||||||
|
A downsampling layer
|
||||||
|
"""
|
||||||
|
def __init__(self, kernel_size = 2):
|
||||||
|
super().__init__()
|
||||||
|
self.downsample = nn.AvgPool1d(kernel_size = kernel_size)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
bs, features, node_n, seq_len = x.shape
|
||||||
|
x = x.reshape(bs, features*node_n, seq_len)
|
||||||
|
x = self.downsample(x)
|
||||||
|
x = x.reshape(bs, features, node_n, -1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class graph_upsample(nn.Module):
|
||||||
|
"""
|
||||||
|
An upsampling layer
|
||||||
|
"""
|
||||||
|
def __init__(self, scale_factor=2):
|
||||||
|
super().__init__()
|
||||||
|
self.scale_factor = scale_factor
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = F.interpolate(x, (x.shape[2], x.shape[3]*self.scale_factor), mode="nearest")
|
||||||
|
return x
|
141
model/nn.py
Normal file
141
model/nn.py
Normal file
|
@ -0,0 +1,141 @@
|
||||||
|
"""
|
||||||
|
Various utilities for neural networks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
import math, pdb
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch as th
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
|
||||||
|
class SiLU(nn.Module):
|
||||||
|
# @th.jit.script
|
||||||
|
def forward(self, x):
|
||||||
|
return x * th.sigmoid(x)
|
||||||
|
|
||||||
|
|
||||||
|
class GroupNorm32(nn.GroupNorm):
|
||||||
|
def forward(self, x):
|
||||||
|
y = super().forward(x.float()).type(x.dtype)
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
def conv_nd(dims, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Create a 1D, 2D, or 3D convolution module.
|
||||||
|
"""
|
||||||
|
assert dims==1
|
||||||
|
if dims == 1:
|
||||||
|
return nn.Conv1d(*args, **kwargs)
|
||||||
|
elif dims == 2:
|
||||||
|
return nn.Conv2d(*args, **kwargs)
|
||||||
|
elif dims == 3:
|
||||||
|
return nn.Conv3d(*args, **kwargs)
|
||||||
|
raise ValueError(f"unsupported dimensions: {dims}")
|
||||||
|
|
||||||
|
|
||||||
|
def linear(*args, **kwargs):
|
||||||
|
"""
|
||||||
|
Create a linear module.
|
||||||
|
"""
|
||||||
|
return nn.Linear(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def avg_pool_nd(dims, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Create a 1D, 2D, or 3D average pooling module.
|
||||||
|
"""
|
||||||
|
assert dims==1
|
||||||
|
if dims == 1:
|
||||||
|
return nn.AvgPool1d(*args, **kwargs)
|
||||||
|
elif dims == 2:
|
||||||
|
return nn.AvgPool2d(*args, **kwargs)
|
||||||
|
elif dims == 3:
|
||||||
|
return nn.AvgPool3d(*args, **kwargs)
|
||||||
|
raise ValueError(f"unsupported dimensions: {dims}")
|
||||||
|
|
||||||
|
|
||||||
|
def update_ema(target_params, source_params, rate=0.99):
|
||||||
|
"""
|
||||||
|
Update target parameters to be closer to those of source parameters using
|
||||||
|
an exponential moving average.
|
||||||
|
|
||||||
|
:param target_params: the target parameter sequence.
|
||||||
|
:param source_params: the source parameter sequence.
|
||||||
|
:param rate: the EMA rate (closer to 1 means slower).
|
||||||
|
"""
|
||||||
|
for targ, src in zip(target_params, source_params):
|
||||||
|
targ.detach().mul_(rate).add_(src, alpha=1 - rate)
|
||||||
|
|
||||||
|
|
||||||
|
def zero_module(module):
|
||||||
|
"""
|
||||||
|
Zero out the parameters of a module and return it.
|
||||||
|
"""
|
||||||
|
for p in module.parameters():
|
||||||
|
p.detach().zero_()
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
|
def scale_module(module, scale):
|
||||||
|
"""
|
||||||
|
Scale the parameters of a module and return it.
|
||||||
|
"""
|
||||||
|
for p in module.parameters():
|
||||||
|
p.detach().mul_(scale)
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
|
def mean_flat(tensor):
|
||||||
|
"""
|
||||||
|
Take the mean over all non-batch dimensions.
|
||||||
|
"""
|
||||||
|
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
||||||
|
|
||||||
|
|
||||||
|
def normalization(channels):
|
||||||
|
"""
|
||||||
|
Make a standard normalization layer.
|
||||||
|
|
||||||
|
:param channels: number of input channels.
|
||||||
|
:return: an nn.Module for normalization.
|
||||||
|
"""
|
||||||
|
# return GroupNorm32(channels, channels)
|
||||||
|
return GroupNorm32(min(4, channels), channels)
|
||||||
|
|
||||||
|
|
||||||
|
def timestep_embedding(timesteps, dim, max_period=10000):
|
||||||
|
"""
|
||||||
|
Create sinusoidal timestep embeddings.
|
||||||
|
|
||||||
|
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
||||||
|
These may be fractional.
|
||||||
|
:param dim: the dimension of the output.
|
||||||
|
:param max_period: controls the minimum frequency of the embeddings.
|
||||||
|
:return: an [N x dim] Tensor of positional embeddings.
|
||||||
|
"""
|
||||||
|
half = dim // 2
|
||||||
|
freqs = th.exp(-math.log(max_period) *
|
||||||
|
th.arange(start=0, end=half, dtype=th.float32) /
|
||||||
|
half).to(device=timesteps.device)
|
||||||
|
args = timesteps[:, None].float() * freqs[None]
|
||||||
|
embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
|
||||||
|
if dim % 2:
|
||||||
|
embedding = th.cat(
|
||||||
|
[embedding, th.zeros_like(embedding[:, :1])], dim=-1)
|
||||||
|
return embedding
|
||||||
|
|
||||||
|
|
||||||
|
def torch_checkpoint(func, args, flag, preserve_rng_state=False):
|
||||||
|
# torch's gradient checkpoint works with automatic mixed precision, given torch >= 1.8
|
||||||
|
if flag:
|
||||||
|
return torch.utils.checkpoint.checkpoint(
|
||||||
|
func, *args, preserve_rng_state=preserve_rng_state)
|
||||||
|
else:
|
||||||
|
return func(*args)
|
954
model/unet.py
Normal file
954
model/unet.py
Normal file
|
@ -0,0 +1,954 @@
|
||||||
|
import math
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from numbers import Number
|
||||||
|
from typing import NamedTuple, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch as th
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from choices import *
|
||||||
|
from config_base import BaseConfig
|
||||||
|
from .blocks import *
|
||||||
|
from .graph_convolution_network import *
|
||||||
|
from .nn import (conv_nd, linear, normalization, timestep_embedding, zero_module)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BeatGANsUNetConfig(BaseConfig):
|
||||||
|
seq_len: int = 80
|
||||||
|
in_channels: int = 9
|
||||||
|
# base channels, will be multiplied
|
||||||
|
model_channels: int = 64
|
||||||
|
# output of the unet
|
||||||
|
out_channels: int = 9
|
||||||
|
# how many repeating resblocks per resolution
|
||||||
|
# the decoding side would have "one more" resblock
|
||||||
|
# default: 2
|
||||||
|
num_res_blocks: int = 2
|
||||||
|
# number of time embed channels and style channels
|
||||||
|
embed_channels: int = 256
|
||||||
|
# at what resolutions you want to do self-attention of the feature maps
|
||||||
|
# attentions generally improve performance
|
||||||
|
attention_resolutions: Tuple[int] = (0, )
|
||||||
|
# dropout applies to the resblocks (on feature maps)
|
||||||
|
dropout: float = 0.1
|
||||||
|
channel_mult: Tuple[int] = (1, 2, 4)
|
||||||
|
conv_resample: bool = True
|
||||||
|
# 1 = 1d conv
|
||||||
|
dims: int = 1
|
||||||
|
# number of attention heads
|
||||||
|
num_heads: int = 1
|
||||||
|
# or specify the number of channels per attention head
|
||||||
|
num_head_channels: int = -1
|
||||||
|
# use resblock for upscale/downscale blocks (expensive)
|
||||||
|
# default: True (BeatGANs)
|
||||||
|
resblock_updown: bool = True
|
||||||
|
use_new_attention_order: bool = False
|
||||||
|
resnet_two_cond: bool = True
|
||||||
|
resnet_cond_channels: int = None
|
||||||
|
# init the decoding conv layers with zero weights, this speeds up training
|
||||||
|
# default: True (BeatGANs)
|
||||||
|
resnet_use_zero_module: bool = True
|
||||||
|
|
||||||
|
def make_model(self):
|
||||||
|
return BeatGANsUNetModel(self)
|
||||||
|
|
||||||
|
|
||||||
|
class BeatGANsUNetModel(nn.Module):
|
||||||
|
def __init__(self, conf: BeatGANsUNetConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.conf = conf
|
||||||
|
|
||||||
|
self.dtype = th.float32
|
||||||
|
|
||||||
|
self.time_emb_channels = conf.model_channels
|
||||||
|
self.time_embed = nn.Sequential(
|
||||||
|
linear(self.time_emb_channels, conf.embed_channels),
|
||||||
|
nn.SiLU(),
|
||||||
|
linear(conf.embed_channels, conf.embed_channels),
|
||||||
|
)
|
||||||
|
|
||||||
|
ch = input_ch = int(conf.channel_mult[0] * conf.model_channels)
|
||||||
|
self.input_blocks = nn.ModuleList([
|
||||||
|
TimestepEmbedSequential(
|
||||||
|
conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1)),
|
||||||
|
])
|
||||||
|
|
||||||
|
kwargs = dict(
|
||||||
|
use_condition=True,
|
||||||
|
two_cond=conf.resnet_two_cond,
|
||||||
|
use_zero_module=conf.resnet_use_zero_module,
|
||||||
|
# style channels for the resnet block
|
||||||
|
cond_emb_channels=conf.resnet_cond_channels,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._feature_size = ch
|
||||||
|
|
||||||
|
# input_block_chans = [ch]
|
||||||
|
input_block_chans = [[] for _ in range(len(conf.channel_mult))]
|
||||||
|
input_block_chans[0].append(ch)
|
||||||
|
|
||||||
|
# number of blocks at each resolution
|
||||||
|
self.input_num_blocks = [0 for _ in range(len(conf.channel_mult))]
|
||||||
|
self.input_num_blocks[0] = 1
|
||||||
|
self.output_num_blocks = [0 for _ in range(len(conf.channel_mult))]
|
||||||
|
|
||||||
|
ds = 1
|
||||||
|
resolution = conf.seq_len
|
||||||
|
for level, mult in enumerate(conf.channel_mult):
|
||||||
|
for _ in range(conf.num_res_blocks):
|
||||||
|
layers = [
|
||||||
|
ResBlockConfig(
|
||||||
|
ch,
|
||||||
|
conf.embed_channels,
|
||||||
|
conf.dropout,
|
||||||
|
out_channels=int(mult * conf.model_channels),
|
||||||
|
dims=conf.dims,
|
||||||
|
**kwargs,
|
||||||
|
).make_model()
|
||||||
|
]
|
||||||
|
ch = int(mult * conf.model_channels)
|
||||||
|
if resolution in conf.attention_resolutions:
|
||||||
|
layers.append(
|
||||||
|
AttentionBlock(
|
||||||
|
ch,
|
||||||
|
num_heads=conf.num_heads,
|
||||||
|
num_head_channels=conf.num_head_channels,
|
||||||
|
use_new_attention_order=conf.
|
||||||
|
use_new_attention_order,
|
||||||
|
))
|
||||||
|
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||||
|
self._feature_size += ch
|
||||||
|
# input_block_chans.append(ch)
|
||||||
|
input_block_chans[level].append(ch)
|
||||||
|
self.input_num_blocks[level] += 1
|
||||||
|
# print(input_block_chans)
|
||||||
|
if level != len(conf.channel_mult) - 1:
|
||||||
|
resolution //= 2
|
||||||
|
out_ch = ch
|
||||||
|
self.input_blocks.append(
|
||||||
|
TimestepEmbedSequential(
|
||||||
|
ResBlockConfig(
|
||||||
|
ch,
|
||||||
|
conf.embed_channels,
|
||||||
|
conf.dropout,
|
||||||
|
out_channels=out_ch,
|
||||||
|
dims=conf.dims,
|
||||||
|
down=True,
|
||||||
|
**kwargs,
|
||||||
|
).make_model() if conf.
|
||||||
|
resblock_updown else Downsample(ch,
|
||||||
|
conf.conv_resample,
|
||||||
|
dims=conf.dims,
|
||||||
|
out_channels=out_ch)))
|
||||||
|
ch = out_ch
|
||||||
|
# input_block_chans.append(ch)
|
||||||
|
input_block_chans[level + 1].append(ch)
|
||||||
|
self.input_num_blocks[level + 1] += 1
|
||||||
|
ds *= 2
|
||||||
|
self._feature_size += ch
|
||||||
|
|
||||||
|
self.middle_block = TimestepEmbedSequential(
|
||||||
|
ResBlockConfig(
|
||||||
|
ch,
|
||||||
|
conf.embed_channels,
|
||||||
|
conf.dropout,
|
||||||
|
dims=conf.dims,
|
||||||
|
**kwargs,
|
||||||
|
).make_model(),
|
||||||
|
#AttentionBlock(
|
||||||
|
# ch,
|
||||||
|
# num_heads=conf.num_heads,
|
||||||
|
# num_head_channels=conf.num_head_channels,
|
||||||
|
# use_new_attention_order=conf.use_new_attention_order,
|
||||||
|
#),
|
||||||
|
ResBlockConfig(
|
||||||
|
ch,
|
||||||
|
conf.embed_channels,
|
||||||
|
conf.dropout,
|
||||||
|
dims=conf.dims,
|
||||||
|
**kwargs,
|
||||||
|
).make_model(),
|
||||||
|
)
|
||||||
|
self._feature_size += ch
|
||||||
|
|
||||||
|
self.output_blocks = nn.ModuleList([])
|
||||||
|
for level, mult in list(enumerate(conf.channel_mult))[::-1]:
|
||||||
|
for i in range(conf.num_res_blocks + 1):
|
||||||
|
# print(input_block_chans)
|
||||||
|
# ich = input_block_chans.pop()
|
||||||
|
try:
|
||||||
|
ich = input_block_chans[level].pop()
|
||||||
|
except IndexError:
|
||||||
|
# this happens only when num_res_block > num_enc_res_block
|
||||||
|
# we will not have enough lateral (skip) connecions for all decoder blocks
|
||||||
|
ich = 0
|
||||||
|
# print('pop:', ich)
|
||||||
|
layers = [
|
||||||
|
ResBlockConfig(
|
||||||
|
# only direct channels when gated
|
||||||
|
channels=ch + ich,
|
||||||
|
emb_channels=conf.embed_channels,
|
||||||
|
dropout=conf.dropout,
|
||||||
|
out_channels=int(conf.model_channels * mult),
|
||||||
|
dims=conf.dims,
|
||||||
|
# lateral channels are described here when gated
|
||||||
|
has_lateral=True if ich > 0 else False,
|
||||||
|
lateral_channels=None,
|
||||||
|
**kwargs,
|
||||||
|
).make_model()
|
||||||
|
]
|
||||||
|
ch = int(conf.model_channels * mult)
|
||||||
|
if resolution in conf.attention_resolutions:
|
||||||
|
layers.append(
|
||||||
|
AttentionBlock(
|
||||||
|
ch,
|
||||||
|
num_heads=conf.num_heads,
|
||||||
|
num_head_channels=conf.num_head_channels,
|
||||||
|
use_new_attention_order=conf.
|
||||||
|
use_new_attention_order,
|
||||||
|
))
|
||||||
|
if level and i == conf.num_res_blocks:
|
||||||
|
resolution *= 2
|
||||||
|
out_ch = ch
|
||||||
|
layers.append(
|
||||||
|
ResBlockConfig(
|
||||||
|
ch,
|
||||||
|
conf.embed_channels,
|
||||||
|
conf.dropout,
|
||||||
|
out_channels=out_ch,
|
||||||
|
dims=conf.dims,
|
||||||
|
up=True,
|
||||||
|
**kwargs,
|
||||||
|
).make_model() if (
|
||||||
|
conf.resblock_updown
|
||||||
|
) else Upsample(ch,
|
||||||
|
conf.conv_resample,
|
||||||
|
dims=conf.dims,
|
||||||
|
out_channels=out_ch))
|
||||||
|
ds //= 2
|
||||||
|
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
||||||
|
self.output_num_blocks[level] += 1
|
||||||
|
self._feature_size += ch
|
||||||
|
|
||||||
|
# print(input_block_chans)
|
||||||
|
# print('inputs:', self.input_num_blocks)
|
||||||
|
# print('outputs:', self.output_num_blocks)
|
||||||
|
|
||||||
|
if conf.resnet_use_zero_module:
|
||||||
|
self.out = nn.Sequential(
|
||||||
|
normalization(ch),
|
||||||
|
nn.SiLU(),
|
||||||
|
zero_module(
|
||||||
|
conv_nd(conf.dims,
|
||||||
|
input_ch,
|
||||||
|
conf.out_channels,
|
||||||
|
3, ## kernel size
|
||||||
|
padding=1)),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.out = nn.Sequential(
|
||||||
|
normalization(ch),
|
||||||
|
nn.SiLU(),
|
||||||
|
conv_nd(conf.dims, input_ch, conf.out_channels, 3, padding=1),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, t, **kwargs):
|
||||||
|
"""
|
||||||
|
Apply the model to an input batch.
|
||||||
|
|
||||||
|
:param x: an [N x C x ...] Tensor of inputs.
|
||||||
|
:param timesteps: a 1-D batch of timesteps.
|
||||||
|
:return: an [N x C x ...] Tensor of outputs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# hs = []
|
||||||
|
hs = [[] for _ in range(len(self.conf.channel_mult))]
|
||||||
|
emb = self.time_embed(timestep_embedding(t, self.time_emb_channels))
|
||||||
|
|
||||||
|
# new code supports input_num_blocks != output_num_blocks
|
||||||
|
h = x.type(self.dtype)
|
||||||
|
k = 0
|
||||||
|
for i in range(len(self.input_num_blocks)):
|
||||||
|
for j in range(self.input_num_blocks[i]):
|
||||||
|
h = self.input_blocks[k](h, emb=emb)
|
||||||
|
# print(i, j, h.shape)
|
||||||
|
hs[i].append(h) ## Get output from each layer
|
||||||
|
k += 1
|
||||||
|
assert k == len(self.input_blocks)
|
||||||
|
|
||||||
|
# middle blocks
|
||||||
|
h = self.middle_block(h, emb=emb)
|
||||||
|
|
||||||
|
# output blocks
|
||||||
|
k = 0
|
||||||
|
for i in range(len(self.output_num_blocks)):
|
||||||
|
for j in range(self.output_num_blocks[i]):
|
||||||
|
# take the lateral connection from the same layer (in reserve)
|
||||||
|
# until there is no more, use None
|
||||||
|
try:
|
||||||
|
lateral = hs[-i - 1].pop()
|
||||||
|
# print(i, j, lateral.shape)
|
||||||
|
except IndexError:
|
||||||
|
lateral = None
|
||||||
|
# print(i, j, lateral)
|
||||||
|
h = self.output_blocks[k](h, emb=emb, lateral=lateral)
|
||||||
|
k += 1
|
||||||
|
|
||||||
|
h = h.type(x.dtype)
|
||||||
|
pred = self.out(h)
|
||||||
|
return Return(pred=pred)
|
||||||
|
|
||||||
|
|
||||||
|
class Return(NamedTuple):
|
||||||
|
pred: th.Tensor
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BeatGANsEncoderConfig(BaseConfig):
|
||||||
|
in_channels: int
|
||||||
|
seq_len: int = 80
|
||||||
|
num_res_blocks: int = 2
|
||||||
|
attention_resolutions: Tuple[int] = (0, )
|
||||||
|
model_channels: int = 32
|
||||||
|
out_channels: int = 256
|
||||||
|
dropout: float = 0.1
|
||||||
|
channel_mult: Tuple[int] = (1, 2, 4)
|
||||||
|
use_time_condition: bool = False
|
||||||
|
conv_resample: bool = True
|
||||||
|
dims: int = 1
|
||||||
|
num_heads: int = 1
|
||||||
|
num_head_channels: int = -1
|
||||||
|
resblock_updown: bool = True
|
||||||
|
use_new_attention_order: bool = False
|
||||||
|
pool: str = 'adaptivenonzero'
|
||||||
|
|
||||||
|
def make_model(self):
|
||||||
|
return BeatGANsEncoderModel(self)
|
||||||
|
|
||||||
|
|
||||||
|
class BeatGANsEncoderModel(nn.Module):
|
||||||
|
"""
|
||||||
|
The half UNet model with attention and timestep embedding.
|
||||||
|
|
||||||
|
For usage, see UNet.
|
||||||
|
"""
|
||||||
|
def __init__(self, conf: BeatGANsEncoderConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.conf = conf
|
||||||
|
self.dtype = th.float32
|
||||||
|
|
||||||
|
if conf.use_time_condition:
|
||||||
|
time_embed_dim = conf.model_channels
|
||||||
|
self.time_embed = nn.Sequential(
|
||||||
|
linear(conf.model_channels, time_embed_dim),
|
||||||
|
nn.SiLU(),
|
||||||
|
linear(time_embed_dim, time_embed_dim),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
time_embed_dim = None
|
||||||
|
|
||||||
|
ch = int(conf.channel_mult[0] * conf.model_channels)
|
||||||
|
self.input_blocks = nn.ModuleList([
|
||||||
|
TimestepEmbedSequential(
|
||||||
|
conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1),
|
||||||
|
)
|
||||||
|
])
|
||||||
|
self._feature_size = ch
|
||||||
|
input_block_chans = [ch]
|
||||||
|
ds = 1
|
||||||
|
resolution = conf.seq_len
|
||||||
|
for level, mult in enumerate(conf.channel_mult):
|
||||||
|
for _ in range(conf.num_res_blocks):
|
||||||
|
layers = [
|
||||||
|
ResBlockConfig(
|
||||||
|
ch,
|
||||||
|
time_embed_dim,
|
||||||
|
conf.dropout,
|
||||||
|
out_channels=int(mult * conf.model_channels),
|
||||||
|
dims=conf.dims,
|
||||||
|
use_condition=conf.use_time_condition,
|
||||||
|
).make_model()
|
||||||
|
]
|
||||||
|
ch = int(mult * conf.model_channels)
|
||||||
|
if resolution in conf.attention_resolutions:
|
||||||
|
layers.append(
|
||||||
|
AttentionBlock(
|
||||||
|
ch,
|
||||||
|
num_heads=conf.num_heads,
|
||||||
|
num_head_channels=conf.num_head_channels,
|
||||||
|
use_new_attention_order=conf.
|
||||||
|
use_new_attention_order,
|
||||||
|
))
|
||||||
|
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||||
|
self._feature_size += ch
|
||||||
|
input_block_chans.append(ch)
|
||||||
|
if level != len(conf.channel_mult) - 1:
|
||||||
|
resolution //= 2
|
||||||
|
out_ch = ch
|
||||||
|
self.input_blocks.append(
|
||||||
|
TimestepEmbedSequential(
|
||||||
|
ResBlockConfig(
|
||||||
|
ch,
|
||||||
|
time_embed_dim,
|
||||||
|
conf.dropout,
|
||||||
|
out_channels=out_ch,
|
||||||
|
dims=conf.dims,
|
||||||
|
use_condition=conf.use_time_condition,
|
||||||
|
down=True,
|
||||||
|
).make_model() if (
|
||||||
|
conf.resblock_updown
|
||||||
|
) else Downsample(ch,
|
||||||
|
conf.conv_resample,
|
||||||
|
dims=conf.dims,
|
||||||
|
out_channels=out_ch)))
|
||||||
|
ch = out_ch
|
||||||
|
input_block_chans.append(ch)
|
||||||
|
ds *= 2
|
||||||
|
self._feature_size += ch
|
||||||
|
|
||||||
|
self.middle_block = TimestepEmbedSequential(
|
||||||
|
ResBlockConfig(
|
||||||
|
ch,
|
||||||
|
time_embed_dim,
|
||||||
|
conf.dropout,
|
||||||
|
dims=conf.dims,
|
||||||
|
use_condition=conf.use_time_condition,
|
||||||
|
).make_model(),
|
||||||
|
AttentionBlock(
|
||||||
|
ch,
|
||||||
|
num_heads=conf.num_heads,
|
||||||
|
num_head_channels=conf.num_head_channels,
|
||||||
|
use_new_attention_order=conf.use_new_attention_order,
|
||||||
|
),
|
||||||
|
ResBlockConfig(
|
||||||
|
ch,
|
||||||
|
time_embed_dim,
|
||||||
|
conf.dropout,
|
||||||
|
dims=conf.dims,
|
||||||
|
use_condition=conf.use_time_condition,
|
||||||
|
).make_model(),
|
||||||
|
)
|
||||||
|
self._feature_size += ch
|
||||||
|
if conf.pool == "adaptivenonzero":
|
||||||
|
self.out = nn.Sequential(
|
||||||
|
normalization(ch),
|
||||||
|
nn.SiLU(),
|
||||||
|
## nn.AdaptiveAvgPool2d((1, 1)),
|
||||||
|
nn.AdaptiveAvgPool1d(1),
|
||||||
|
conv_nd(conf.dims, ch, conf.out_channels, 1),
|
||||||
|
nn.Flatten(),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Unexpected {conf.pool} pooling")
|
||||||
|
|
||||||
|
def forward(self, x, t=None):
|
||||||
|
"""
|
||||||
|
Apply the model to an input batch.
|
||||||
|
|
||||||
|
:param x: an [N x C x ...] Tensor of inputs.
|
||||||
|
:param timesteps: a 1-D batch of timesteps.
|
||||||
|
:return: an [N x K] Tensor of outputs.
|
||||||
|
"""
|
||||||
|
if self.conf.use_time_condition:
|
||||||
|
emb = self.time_embed(timestep_embedding(t, self.model_channels))
|
||||||
|
else: ## autoencoding.py
|
||||||
|
emb = None
|
||||||
|
|
||||||
|
results = []
|
||||||
|
h = x.type(self.dtype)
|
||||||
|
for module in self.input_blocks: ## flow input x over all the input blocks
|
||||||
|
h = module(h, emb=emb)
|
||||||
|
if self.conf.pool.startswith("spatial"):
|
||||||
|
results.append(h.type(x.dtype).mean(dim=(2, 3)))
|
||||||
|
h = self.middle_block(h, emb=emb) ## TimestepEmbedSequential(...)
|
||||||
|
if self.conf.pool.startswith("spatial"):
|
||||||
|
results.append(h.type(x.dtype).mean(dim=(2, 3)))
|
||||||
|
h = th.cat(results, axis=-1)
|
||||||
|
else: ## autoencoder.py
|
||||||
|
h = h.type(x.dtype)
|
||||||
|
|
||||||
|
h = h.float()
|
||||||
|
h = self.out(h)
|
||||||
|
|
||||||
|
return h
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GCNUNetConfig(BaseConfig):
|
||||||
|
in_channels: int = 9
|
||||||
|
node_n: int = 3
|
||||||
|
seq_len: int = 80
|
||||||
|
# base channels, will be multiplied
|
||||||
|
model_channels: int = 32
|
||||||
|
# output of the unet
|
||||||
|
out_channels: int = 9
|
||||||
|
# how many repeating resblocks per resolution
|
||||||
|
num_res_blocks: int = 8
|
||||||
|
# number of time embed channels and style channels
|
||||||
|
embed_channels: int = 256
|
||||||
|
# dropout applies to the resblocks
|
||||||
|
dropout: float = 0.1
|
||||||
|
channel_mult: Tuple[int] = (1, 2, 4)
|
||||||
|
resnet_two_cond: bool = True
|
||||||
|
|
||||||
|
def make_model(self):
|
||||||
|
return GCNUNetModel(self)
|
||||||
|
|
||||||
|
|
||||||
|
class GCNUNetModel(nn.Module):
|
||||||
|
def __init__(self, conf: GCNUNetConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.conf = conf
|
||||||
|
self.dtype = th.float32
|
||||||
|
assert conf.in_channels%conf.node_n == 0
|
||||||
|
self.in_features = conf.in_channels//conf.node_n
|
||||||
|
|
||||||
|
self.time_emb_channels = conf.model_channels*4
|
||||||
|
self.time_embed = nn.Sequential(
|
||||||
|
linear(self.time_emb_channels, conf.embed_channels),
|
||||||
|
nn.SiLU(),
|
||||||
|
linear(conf.embed_channels, conf.embed_channels),
|
||||||
|
)
|
||||||
|
|
||||||
|
ch = int(conf.channel_mult[0] * conf.model_channels)
|
||||||
|
self.input_blocks = nn.ModuleList([
|
||||||
|
TimestepEmbedSequential(
|
||||||
|
graph_convolution(in_features=self.in_features, out_features=ch, node_n=conf.node_n, seq_len=conf.seq_len)),
|
||||||
|
])
|
||||||
|
|
||||||
|
kwargs = dict(
|
||||||
|
use_condition=True,
|
||||||
|
two_cond=conf.resnet_two_cond,
|
||||||
|
)
|
||||||
|
|
||||||
|
input_block_chans = [[] for _ in range(len(conf.channel_mult))]
|
||||||
|
input_block_chans[0].append(ch)
|
||||||
|
|
||||||
|
# number of blocks at each resolution
|
||||||
|
self.input_num_blocks = [0 for _ in range(len(conf.channel_mult))]
|
||||||
|
self.input_num_blocks[0] = 1
|
||||||
|
self.output_num_blocks = [0 for _ in range(len(conf.channel_mult))]
|
||||||
|
|
||||||
|
ds = 1
|
||||||
|
resolution = conf.seq_len
|
||||||
|
for level, mult in enumerate(conf.channel_mult):
|
||||||
|
for _ in range(conf.num_res_blocks):
|
||||||
|
layers = [
|
||||||
|
residual_graph_convolution_config(
|
||||||
|
in_features=ch,
|
||||||
|
seq_len=resolution,
|
||||||
|
emb_channels = conf.embed_channels,
|
||||||
|
dropout=conf.dropout,
|
||||||
|
out_features=int(mult * conf.model_channels),
|
||||||
|
node_n=conf.node_n,
|
||||||
|
**kwargs,
|
||||||
|
).make_model()
|
||||||
|
]
|
||||||
|
ch = int(mult * conf.model_channels)
|
||||||
|
self.input_blocks.append(*layers)
|
||||||
|
input_block_chans[level].append(ch)
|
||||||
|
self.input_num_blocks[level] += 1
|
||||||
|
if level != len(conf.channel_mult) - 1:
|
||||||
|
resolution //= 2
|
||||||
|
out_ch = ch
|
||||||
|
self.input_blocks.append(
|
||||||
|
TimestepEmbedSequential(
|
||||||
|
graph_downsample()))
|
||||||
|
ch = out_ch
|
||||||
|
input_block_chans[level + 1].append(ch)
|
||||||
|
self.input_num_blocks[level + 1] += 1
|
||||||
|
ds *= 2
|
||||||
|
|
||||||
|
self.output_blocks = nn.ModuleList([])
|
||||||
|
for level, mult in list(enumerate(conf.channel_mult))[::-1]:
|
||||||
|
for i in range(conf.num_res_blocks + 1):
|
||||||
|
try:
|
||||||
|
ich = input_block_chans[level].pop()
|
||||||
|
except IndexError:
|
||||||
|
# this happens only when num_res_block > num_enc_res_block
|
||||||
|
# we will not have enough lateral (skip) connecions for all decoder blocks
|
||||||
|
ich = 0
|
||||||
|
layers = [
|
||||||
|
residual_graph_convolution_config(
|
||||||
|
in_features=ch + ich,
|
||||||
|
seq_len=resolution,
|
||||||
|
emb_channels = conf.embed_channels,
|
||||||
|
dropout=conf.dropout,
|
||||||
|
out_features=int(mult * conf.model_channels),
|
||||||
|
node_n=conf.node_n,
|
||||||
|
has_lateral=True if ich > 0 else False,
|
||||||
|
**kwargs,
|
||||||
|
).make_model()
|
||||||
|
]
|
||||||
|
ch = int(mult*conf.model_channels)
|
||||||
|
if level and i == conf.num_res_blocks:
|
||||||
|
resolution *= 2
|
||||||
|
out_ch = ch
|
||||||
|
layers.append(graph_upsample())
|
||||||
|
ds //= 2
|
||||||
|
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
||||||
|
self.output_num_blocks[level] += 1
|
||||||
|
|
||||||
|
self.out = nn.Sequential(
|
||||||
|
graph_convolution(in_features=ch, out_features=self.in_features, node_n=conf.node_n, seq_len=conf.seq_len),
|
||||||
|
nn.Tanh(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, t, **kwargs):
|
||||||
|
"""
|
||||||
|
Apply the model to an input batch.
|
||||||
|
|
||||||
|
:param x: an [N x C x ...] Tensor of inputs.
|
||||||
|
:param timesteps: a 1-D batch of timesteps.
|
||||||
|
:return: an [N x C x ...] Tensor of outputs.
|
||||||
|
"""
|
||||||
|
bs, channels, seq_len = x.shape
|
||||||
|
x = x.reshape(bs, self.conf.node_n, self.in_features, seq_len).permute(0, 2, 1, 3)
|
||||||
|
|
||||||
|
hs = [[] for _ in range(len(self.conf.channel_mult))]
|
||||||
|
emb = self.time_embed(timestep_embedding(t, self.time_emb_channels))
|
||||||
|
|
||||||
|
# new code supports input_num_blocks != output_num_blocks
|
||||||
|
h = x.type(self.dtype)
|
||||||
|
k = 0
|
||||||
|
for i in range(len(self.input_num_blocks)):
|
||||||
|
for j in range(self.input_num_blocks[i]):
|
||||||
|
h = self.input_blocks[k](h, emb=emb)
|
||||||
|
# print(i, j, h.shape)
|
||||||
|
hs[i].append(h) ## Get output from each layer
|
||||||
|
k += 1
|
||||||
|
assert k == len(self.input_blocks)
|
||||||
|
|
||||||
|
# output blocks
|
||||||
|
k = 0
|
||||||
|
for i in range(len(self.output_num_blocks)):
|
||||||
|
for j in range(self.output_num_blocks[i]):
|
||||||
|
# take the lateral connection from the same layer (in reserve)
|
||||||
|
# until there is no more, use None
|
||||||
|
try:
|
||||||
|
lateral = hs[-i - 1].pop()
|
||||||
|
# print(i, j, lateral.shape)
|
||||||
|
except IndexError:
|
||||||
|
lateral = None
|
||||||
|
# print(i, j, lateral)
|
||||||
|
h = self.output_blocks[k](h, emb=emb, lateral=lateral)
|
||||||
|
k += 1
|
||||||
|
|
||||||
|
h = h.type(x.dtype)
|
||||||
|
pred = self.out(h)
|
||||||
|
pred = pred.permute(0, 2, 1, 3).reshape(bs, -1, seq_len)
|
||||||
|
|
||||||
|
return Return(pred=pred)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GCNEncoderConfig(BaseConfig):
|
||||||
|
in_channels: int
|
||||||
|
in_features = 3 # features for one node
|
||||||
|
seq_len: int = 40
|
||||||
|
seq_len_future: int = 3
|
||||||
|
num_res_blocks: int = 2
|
||||||
|
model_channels: int = 32
|
||||||
|
out_channels: int = 32
|
||||||
|
dropout: float = 0.1
|
||||||
|
channel_mult: Tuple[int] = (1, 2, 4)
|
||||||
|
use_time_condition: bool = False
|
||||||
|
|
||||||
|
def make_model(self):
|
||||||
|
return GCNEncoderModel(self)
|
||||||
|
|
||||||
|
|
||||||
|
class GCNEncoderModel(nn.Module):
|
||||||
|
def __init__(self, conf: GCNEncoderConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.conf = conf
|
||||||
|
self.dtype = th.float32
|
||||||
|
assert conf.in_channels%conf.in_features == 0
|
||||||
|
self.in_features = conf.in_features
|
||||||
|
self.node_n = conf.in_channels//conf.in_features
|
||||||
|
|
||||||
|
ch = int(conf.channel_mult[0] * conf.model_channels)
|
||||||
|
self.input_blocks = nn.ModuleList([
|
||||||
|
graph_convolution(in_features=self.in_features, out_features=ch, node_n=self.node_n, seq_len=conf.seq_len),
|
||||||
|
])
|
||||||
|
input_block_chans = [ch]
|
||||||
|
ds = 1
|
||||||
|
resolution = conf.seq_len
|
||||||
|
for level, mult in enumerate(conf.channel_mult):
|
||||||
|
for _ in range(conf.num_res_blocks):
|
||||||
|
layers = [
|
||||||
|
residual_graph_convolution_config(
|
||||||
|
in_features=ch,
|
||||||
|
seq_len=resolution,
|
||||||
|
emb_channels = None,
|
||||||
|
dropout=conf.dropout,
|
||||||
|
out_features=int(mult * conf.model_channels),
|
||||||
|
node_n=self.node_n,
|
||||||
|
use_condition=conf.use_time_condition,
|
||||||
|
).make_model()
|
||||||
|
]
|
||||||
|
ch = int(mult * conf.model_channels)
|
||||||
|
self.input_blocks.append(*layers)
|
||||||
|
input_block_chans.append(ch)
|
||||||
|
if level != len(conf.channel_mult) - 1:
|
||||||
|
resolution //= 2
|
||||||
|
out_ch = ch
|
||||||
|
self.input_blocks.append(
|
||||||
|
graph_downsample())
|
||||||
|
ch = out_ch
|
||||||
|
input_block_chans.append(ch)
|
||||||
|
ds *= 2
|
||||||
|
|
||||||
|
self.hand_prediction = nn.Sequential(
|
||||||
|
conv_nd(1, ch*2, ch*2, 3, padding=1),
|
||||||
|
nn.LayerNorm([ch*2, conf.seq_len_future]),
|
||||||
|
nn.Tanh(),
|
||||||
|
conv_nd(1, ch*2, self.in_features*2, 1),
|
||||||
|
nn.Tanh(),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.head_prediction = nn.Sequential(
|
||||||
|
conv_nd(1, ch, ch, 3, padding=1),
|
||||||
|
nn.LayerNorm([ch, conf.seq_len_future]),
|
||||||
|
nn.Tanh(),
|
||||||
|
conv_nd(1, ch, self.in_features, 1),
|
||||||
|
nn.Tanh(),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.out = nn.Sequential(
|
||||||
|
nn.AdaptiveAvgPool1d(1),
|
||||||
|
conv_nd(1, ch*self.node_n, conf.out_channels, 1),
|
||||||
|
nn.Flatten(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, x, t=None):
|
||||||
|
bs, channels, seq_len = x.shape
|
||||||
|
|
||||||
|
if self.node_n == 3: # both hand and head
|
||||||
|
hand_last = x[:, :6, -1:].expand(-1, -1, self.conf.seq_len_future).clone() #last observed hand position
|
||||||
|
head_last = x[:, 6:, -1:].expand(-1, -1, self.conf.seq_len_future).clone()# last observed head orientation
|
||||||
|
|
||||||
|
if self.node_n == 2: # hand only
|
||||||
|
hand_last = x[:, :, -1:].expand(-1, -1, self.conf.seq_len_future).clone() #last observed hand position
|
||||||
|
|
||||||
|
if self.node_n == 1: # head only
|
||||||
|
head_last = x[:, :, -1:].expand(-1, -1, self.conf.seq_len_future).clone()# last observed head orientation
|
||||||
|
|
||||||
|
|
||||||
|
x = x.reshape(bs, self.node_n, self.in_features, seq_len).permute(0, 2, 1, 3)
|
||||||
|
|
||||||
|
h = x.type(self.dtype)
|
||||||
|
for module in self.input_blocks:
|
||||||
|
h = module(h)
|
||||||
|
|
||||||
|
h = h.type(x.dtype)
|
||||||
|
h = h.float()
|
||||||
|
bs, features, node_n, seq_len = h.shape
|
||||||
|
|
||||||
|
if self.node_n == 3: # both hand and head
|
||||||
|
hand_features = h[:, :, :2, -self.conf.seq_len_future:].reshape(bs, features*2, -1)
|
||||||
|
head_features = h[:, :, 2:, -self.conf.seq_len_future:].reshape(bs, features, -1)
|
||||||
|
|
||||||
|
pred_hand = self.hand_prediction(hand_features) + hand_last
|
||||||
|
pred_head = self.head_prediction(head_features) + head_last
|
||||||
|
pred_head = F.normalize(pred_head, dim=1)# normalize head orientation to unit vectors
|
||||||
|
|
||||||
|
if self.node_n == 2: # hand only
|
||||||
|
hand_features = h[:, :, :, -self.conf.seq_len_future:].reshape(bs, features*2, -1)
|
||||||
|
pred_hand = self.hand_prediction(hand_features) + hand_last
|
||||||
|
pred_head = None
|
||||||
|
|
||||||
|
if self.node_n == 1: # head only
|
||||||
|
head_features = h[:, :, :, -self.conf.seq_len_future:].reshape(bs, features, -1)
|
||||||
|
pred_head = self.head_prediction(head_features) + head_last
|
||||||
|
pred_head = F.normalize(pred_head, dim=1)# normalize head orientation to unit vectors
|
||||||
|
pred_hand = None
|
||||||
|
|
||||||
|
h = h.reshape(bs, features*node_n, seq_len)
|
||||||
|
h = self.out(h)
|
||||||
|
|
||||||
|
return h, pred_hand, pred_head
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CNNEncoderConfig(BaseConfig):
|
||||||
|
in_channels: int
|
||||||
|
seq_len: int = 40
|
||||||
|
seq_len_future: int = 3
|
||||||
|
out_channels: int = 128
|
||||||
|
|
||||||
|
def make_model(self):
|
||||||
|
return CNNEncoderModel(self)
|
||||||
|
|
||||||
|
|
||||||
|
class CNNEncoderModel(nn.Module):
|
||||||
|
def __init__(self, conf: CNNEncoderConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.conf = conf
|
||||||
|
self.dtype = th.float32
|
||||||
|
input_dim = conf.in_channels
|
||||||
|
length = conf.seq_len
|
||||||
|
out_channels = conf.out_channels
|
||||||
|
|
||||||
|
self.encoder = nn.Sequential(
|
||||||
|
nn.Conv1d(input_dim, 32, kernel_size=3, padding=1),
|
||||||
|
nn.LayerNorm([32, length]),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Conv1d(32, 32, kernel_size=3, padding=1),
|
||||||
|
nn.LayerNorm([32, length]),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Conv1d(32, 32, kernel_size=3, padding=1),
|
||||||
|
nn.LayerNorm([32, length]),
|
||||||
|
nn.ReLU(inplace=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.out = nn.Linear(32 * length, out_channels)
|
||||||
|
|
||||||
|
def forward(self, x, t=None):
|
||||||
|
bs, channels, seq_len = x.shape
|
||||||
|
hand_last = x[:, :6, -1:].expand(-1, -1, self.conf.seq_len_future).clone() #last observed hand position
|
||||||
|
head_last = x[:, 6:, -1:].expand(-1, -1, self.conf.seq_len_future).clone()# last observed head orientation
|
||||||
|
|
||||||
|
h = x.type(self.dtype)
|
||||||
|
h = self.encoder(h)
|
||||||
|
h = h.view(h.shape[0], -1)
|
||||||
|
|
||||||
|
h = h.type(x.dtype)
|
||||||
|
h = h.float()
|
||||||
|
|
||||||
|
h = self.out(h)
|
||||||
|
return h, hand_last, head_last
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GRUEncoderConfig(BaseConfig):
|
||||||
|
in_channels: int
|
||||||
|
seq_len: int = 40
|
||||||
|
seq_len_future: int = 3
|
||||||
|
out_channels: int = 128
|
||||||
|
|
||||||
|
def make_model(self):
|
||||||
|
return GRUEncoderModel(self)
|
||||||
|
|
||||||
|
|
||||||
|
class GRUEncoderModel(nn.Module):
|
||||||
|
def __init__(self, conf: GRUEncoderConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.conf = conf
|
||||||
|
self.dtype = th.float32
|
||||||
|
input_dim = conf.in_channels
|
||||||
|
length = conf.seq_len
|
||||||
|
feature_channels = 32
|
||||||
|
out_channels = conf.out_channels
|
||||||
|
|
||||||
|
self.encoder = nn.GRU(input_dim, feature_channels, 1, batch_first=True)
|
||||||
|
|
||||||
|
self.out = nn.Linear(feature_channels * length, out_channels)
|
||||||
|
|
||||||
|
def forward(self, x, t=None):
|
||||||
|
bs, channels, seq_len = x.shape
|
||||||
|
hand_last = x[:, :6, -1:].expand(-1, -1, self.conf.seq_len_future).clone() #last observed hand position
|
||||||
|
head_last = x[:, 6:, -1:].expand(-1, -1, self.conf.seq_len_future).clone()# last observed head orientation
|
||||||
|
|
||||||
|
h = x.type(self.dtype)
|
||||||
|
h, _ = self.encoder(h.permute(0, 2, 1))
|
||||||
|
h = h.reshape(h.shape[0], -1)
|
||||||
|
|
||||||
|
h = h.type(x.dtype)
|
||||||
|
h = h.float()
|
||||||
|
|
||||||
|
h = self.out(h)
|
||||||
|
return h, hand_last, head_last
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LSTMEncoderConfig(BaseConfig):
|
||||||
|
in_channels: int
|
||||||
|
seq_len: int = 40
|
||||||
|
seq_len_future: int = 3
|
||||||
|
out_channels: int = 128
|
||||||
|
|
||||||
|
def make_model(self):
|
||||||
|
return LSTMEncoderModel(self)
|
||||||
|
|
||||||
|
|
||||||
|
class LSTMEncoderModel(nn.Module):
|
||||||
|
def __init__(self, conf: LSTMEncoderConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.conf = conf
|
||||||
|
self.dtype = th.float32
|
||||||
|
input_dim = conf.in_channels
|
||||||
|
length = conf.seq_len
|
||||||
|
feature_channels = 32
|
||||||
|
out_channels = conf.out_channels
|
||||||
|
|
||||||
|
self.encoder = nn.LSTM(input_dim, feature_channels, 1, batch_first=True)
|
||||||
|
|
||||||
|
self.out = nn.Linear(feature_channels * length, out_channels)
|
||||||
|
|
||||||
|
def forward(self, x, t=None):
|
||||||
|
bs, channels, seq_len = x.shape
|
||||||
|
hand_last = x[:, :6, -1:].expand(-1, -1, self.conf.seq_len_future).clone() #last observed hand position
|
||||||
|
head_last = x[:, 6:, -1:].expand(-1, -1, self.conf.seq_len_future).clone()# last observed head orientation
|
||||||
|
|
||||||
|
h = x.type(self.dtype)
|
||||||
|
h, _ = self.encoder(h.permute(0, 2, 1))
|
||||||
|
h = h.reshape(h.shape[0], -1)
|
||||||
|
|
||||||
|
h = h.type(x.dtype)
|
||||||
|
h = h.float()
|
||||||
|
|
||||||
|
h = self.out(h)
|
||||||
|
return h, hand_last, head_last
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MLPEncoderConfig(BaseConfig):
|
||||||
|
in_channels: int
|
||||||
|
seq_len: int = 40
|
||||||
|
seq_len_future: int = 3
|
||||||
|
out_channels: int = 128
|
||||||
|
|
||||||
|
def make_model(self):
|
||||||
|
return MLPEncoderModel(self)
|
||||||
|
|
||||||
|
|
||||||
|
class MLPEncoderModel(nn.Module):
|
||||||
|
def __init__(self, conf: MLPEncoderConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.conf = conf
|
||||||
|
self.dtype = th.float32
|
||||||
|
input_dim = conf.in_channels
|
||||||
|
length = conf.seq_len
|
||||||
|
out_channels = conf.out_channels
|
||||||
|
|
||||||
|
linear_size = 128
|
||||||
|
self.encoder = nn.Sequential(
|
||||||
|
nn.Linear(length*input_dim, linear_size),
|
||||||
|
nn.LayerNorm([linear_size]),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
nn.Linear(linear_size, linear_size),
|
||||||
|
nn.LayerNorm([linear_size]),
|
||||||
|
nn.ReLU(inplace=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.out = nn.Linear(linear_size, out_channels)
|
||||||
|
|
||||||
|
def forward(self, x, t=None):
|
||||||
|
bs, channels, seq_len = x.shape
|
||||||
|
hand_last = x[:, :6, -1:].expand(-1, -1, self.conf.seq_len_future).clone() #last observed hand position
|
||||||
|
head_last = x[:, 6:, -1:].expand(-1, -1, self.conf.seq_len_future).clone()# last observed head orientation
|
||||||
|
|
||||||
|
h = x.type(self.dtype)
|
||||||
|
|
||||||
|
h = h.view(h.shape[0], -1)
|
||||||
|
h = self.encoder(h)
|
||||||
|
|
||||||
|
h = h.type(x.dtype)
|
||||||
|
h = h.float()
|
||||||
|
|
||||||
|
h = self.out(h)
|
||||||
|
return h, hand_last, head_last
|
418
model/unet_autoenc.py
Normal file
418
model/unet_autoenc.py
Normal file
|
@ -0,0 +1,418 @@
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
import torch, pdb
|
||||||
|
import os
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.nn.functional import silu
|
||||||
|
from .unet import *
|
||||||
|
from choices import *
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BeatGANsAutoencConfig(BeatGANsUNetConfig):
|
||||||
|
seq_len_future: int = 3
|
||||||
|
enc_out_channels: int = 128
|
||||||
|
semantic_encoder_type: str = 'gcn'
|
||||||
|
enc_channel_mult: Tuple[int] = None
|
||||||
|
def make_model(self):
|
||||||
|
return BeatGANsAutoencModel(self)
|
||||||
|
|
||||||
|
class BeatGANsAutoencModel(BeatGANsUNetModel):
|
||||||
|
def __init__(self, conf: BeatGANsAutoencConfig):
|
||||||
|
super().__init__(conf)
|
||||||
|
self.conf = conf
|
||||||
|
|
||||||
|
# having only time, cond
|
||||||
|
self.time_embed = TimeStyleSeperateEmbed(
|
||||||
|
time_channels=conf.model_channels,
|
||||||
|
time_out_channels=conf.embed_channels,
|
||||||
|
)
|
||||||
|
|
||||||
|
if conf.semantic_encoder_type == 'gcn':
|
||||||
|
self.encoder = GCNEncoderConfig(
|
||||||
|
seq_len=conf.seq_len,
|
||||||
|
seq_len_future=conf.seq_len_future,
|
||||||
|
in_channels=conf.in_channels,
|
||||||
|
model_channels=16,
|
||||||
|
out_channels=conf.enc_out_channels,
|
||||||
|
channel_mult=conf.enc_channel_mult or conf.channel_mult,
|
||||||
|
).make_model()
|
||||||
|
elif conf.semantic_encoder_type == '1dcnn':
|
||||||
|
self.encoder = CNNEncoderConfig(
|
||||||
|
seq_len=conf.seq_len,
|
||||||
|
seq_len_future=conf.seq_len_future,
|
||||||
|
in_channels=conf.in_channels,
|
||||||
|
out_channels=conf.enc_out_channels,
|
||||||
|
).make_model()
|
||||||
|
elif conf.semantic_encoder_type == 'gru':
|
||||||
|
# ensure deterministic behavior of RNNs
|
||||||
|
os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:2"
|
||||||
|
self.encoder = GRUEncoderConfig(
|
||||||
|
seq_len=conf.seq_len,
|
||||||
|
seq_len_future=conf.seq_len_future,
|
||||||
|
in_channels=conf.in_channels,
|
||||||
|
out_channels=conf.enc_out_channels,
|
||||||
|
).make_model()
|
||||||
|
elif conf.semantic_encoder_type == 'lstm':
|
||||||
|
# ensure deterministic behavior of RNNs
|
||||||
|
os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:2"
|
||||||
|
self.encoder = LSTMEncoderConfig(
|
||||||
|
seq_len=conf.seq_len,
|
||||||
|
seq_len_future=conf.seq_len_future,
|
||||||
|
in_channels=conf.in_channels,
|
||||||
|
out_channels=conf.enc_out_channels,
|
||||||
|
).make_model()
|
||||||
|
elif conf.semantic_encoder_type == 'mlp':
|
||||||
|
self.encoder = MLPEncoderConfig(
|
||||||
|
seq_len=conf.seq_len,
|
||||||
|
seq_len_future=conf.seq_len_future,
|
||||||
|
in_channels=conf.in_channels,
|
||||||
|
out_channels=conf.enc_out_channels,
|
||||||
|
).make_model()
|
||||||
|
else:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
|
||||||
|
"""
|
||||||
|
Reparameterization trick to sample from N(mu, var) from
|
||||||
|
N(0,1).
|
||||||
|
:param mu: (Tensor) Mean of the latent Gaussian [B x D]
|
||||||
|
:param logvar: (Tensor) Standard deviation of the latent Gaussian [B x D]
|
||||||
|
:return: (Tensor) [B x D]
|
||||||
|
"""
|
||||||
|
assert self.conf.is_stochastic
|
||||||
|
std = torch.exp(0.5 * logvar)
|
||||||
|
eps = torch.randn_like(std)
|
||||||
|
return eps * std + mu
|
||||||
|
|
||||||
|
def sample_z(self, n: int, device):
|
||||||
|
assert self.conf.is_stochastic
|
||||||
|
return torch.randn(n, self.conf.enc_out_channels, device=device)
|
||||||
|
|
||||||
|
def noise_to_cond(self, noise: Tensor):
|
||||||
|
raise NotImplementedError()
|
||||||
|
assert self.conf.noise_net_conf is not None
|
||||||
|
return self.noise_net.forward(noise)
|
||||||
|
|
||||||
|
def encode(self, x):
|
||||||
|
cond, pred_hand, pred_head = self.encoder.forward(x)
|
||||||
|
return cond, pred_hand, pred_head
|
||||||
|
|
||||||
|
@property
|
||||||
|
def stylespace_sizes(self):
|
||||||
|
modules = list(self.input_blocks.modules()) + list(
|
||||||
|
self.middle_block.modules()) + list(self.output_blocks.modules())
|
||||||
|
sizes = []
|
||||||
|
for module in modules:
|
||||||
|
if isinstance(module, ResBlock):
|
||||||
|
linear = module.cond_emb_layers[-1]
|
||||||
|
sizes.append(linear.weight.shape[0])
|
||||||
|
return sizes
|
||||||
|
|
||||||
|
def encode_stylespace(self, x, return_vector: bool = True):
|
||||||
|
"""
|
||||||
|
encode to style space
|
||||||
|
"""
|
||||||
|
modules = list(self.input_blocks.modules()) + list(
|
||||||
|
self.middle_block.modules()) + list(self.output_blocks.modules())
|
||||||
|
# (n, c)
|
||||||
|
cond = self.encoder.forward(x)
|
||||||
|
S = []
|
||||||
|
for module in modules:
|
||||||
|
if isinstance(module, ResBlock):
|
||||||
|
# (n, c')
|
||||||
|
s = module.cond_emb_layers.forward(cond)
|
||||||
|
S.append(s)
|
||||||
|
|
||||||
|
if return_vector:
|
||||||
|
# (n, sum_c)
|
||||||
|
return torch.cat(S, dim=1)
|
||||||
|
else:
|
||||||
|
return S
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
x,
|
||||||
|
t,
|
||||||
|
x_start=None,
|
||||||
|
cond=None,
|
||||||
|
style=None,
|
||||||
|
noise=None,
|
||||||
|
t_cond=None,
|
||||||
|
**kwargs):
|
||||||
|
"""
|
||||||
|
Apply the model to an input batch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x_start: the original image to encode
|
||||||
|
cond: output of the encoder
|
||||||
|
noise: random noise (to predict the cond)
|
||||||
|
"""
|
||||||
|
if t_cond is None:
|
||||||
|
t_cond = t ## randomly sampled timestep with the size of [batch_size]
|
||||||
|
|
||||||
|
if noise is not None:
|
||||||
|
# if the noise is given, we predict the cond from noise
|
||||||
|
cond = self.noise_to_cond(noise)
|
||||||
|
|
||||||
|
cond_given = True
|
||||||
|
if cond is None:
|
||||||
|
cond_given = False
|
||||||
|
if x is not None:
|
||||||
|
assert len(x) == len(x_start), f'{len(x)} != {len(x_start)}'
|
||||||
|
|
||||||
|
cond, pred_hand, pred_head = self.encode(x_start)
|
||||||
|
|
||||||
|
if t is not None: ## t==t_cond
|
||||||
|
_t_emb = timestep_embedding(t, self.conf.model_channels)
|
||||||
|
#print("t: {}, _t_emb:{}".format(t, _t_emb))
|
||||||
|
_t_cond_emb = timestep_embedding(t_cond, self.conf.model_channels)
|
||||||
|
#print("t_cond: {}, _t_cond_emb:{}".format(t, _t_cond_emb))
|
||||||
|
else:
|
||||||
|
# this happens when training only autoenc
|
||||||
|
_t_emb = None
|
||||||
|
_t_cond_emb = None
|
||||||
|
|
||||||
|
if self.conf.resnet_two_cond:
|
||||||
|
res = self.time_embed.forward( ## self.time_embed is an MLP
|
||||||
|
time_emb=_t_emb,
|
||||||
|
cond=cond,
|
||||||
|
time_cond_emb=_t_cond_emb,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
if self.conf.resnet_two_cond:
|
||||||
|
# two cond: first = time emb, second = cond_emb
|
||||||
|
emb = res.time_emb
|
||||||
|
cond_emb = res.emb
|
||||||
|
else:
|
||||||
|
# one cond = combined of both time and cond
|
||||||
|
emb = res.emb
|
||||||
|
cond_emb = None
|
||||||
|
|
||||||
|
# override the style if given
|
||||||
|
style = style or res.style ## style==None, res.style: cond, torch.Size([64, 512])
|
||||||
|
|
||||||
|
|
||||||
|
# where in the model to supply time conditions
|
||||||
|
enc_time_emb = emb ## time embeddings
|
||||||
|
mid_time_emb = emb
|
||||||
|
dec_time_emb = emb
|
||||||
|
# where in the model to supply style conditions
|
||||||
|
enc_cond_emb = cond_emb ## z_sem embeddings
|
||||||
|
mid_cond_emb = cond_emb
|
||||||
|
dec_cond_emb = cond_emb
|
||||||
|
|
||||||
|
# hs = []
|
||||||
|
hs = [[] for _ in range(len(self.conf.channel_mult))]
|
||||||
|
|
||||||
|
if x is not None:
|
||||||
|
h = x.type(self.dtype)
|
||||||
|
# input blocks
|
||||||
|
k = 0
|
||||||
|
for i in range(len(self.input_num_blocks)):
|
||||||
|
for j in range(self.input_num_blocks[i]):
|
||||||
|
h = self.input_blocks[k](h,
|
||||||
|
emb=enc_time_emb,
|
||||||
|
cond=enc_cond_emb)
|
||||||
|
# print(i, j, h.shape)
|
||||||
|
'''if h.shape[-1]%2==1:
|
||||||
|
pdb.set_trace()'''
|
||||||
|
hs[i].append(h)
|
||||||
|
k += 1
|
||||||
|
assert k == len(self.input_blocks)
|
||||||
|
|
||||||
|
# middle blocks
|
||||||
|
h = self.middle_block(h, emb=mid_time_emb, cond=mid_cond_emb)
|
||||||
|
else:
|
||||||
|
# no lateral connections
|
||||||
|
# happens when training only the autonecoder
|
||||||
|
h = None
|
||||||
|
hs = [[] for _ in range(len(self.conf.channel_mult))]
|
||||||
|
pdb.set_trace()
|
||||||
|
|
||||||
|
# output blocks
|
||||||
|
k = 0
|
||||||
|
for i in range(len(self.output_num_blocks)):
|
||||||
|
for j in range(self.output_num_blocks[i]):
|
||||||
|
# take the lateral connection from the same layer (in reserve)
|
||||||
|
# until there is no more, use None
|
||||||
|
try:
|
||||||
|
lateral = hs[-i - 1].pop() ## in the reverse order (symmetric)
|
||||||
|
except IndexError:
|
||||||
|
lateral = None
|
||||||
|
'''print(i, j, lateral.shape, h.shape)
|
||||||
|
if lateral.shape[-1]!=h.shape[-1]:
|
||||||
|
pdb.set_trace()'''
|
||||||
|
# print("h is", h.size())
|
||||||
|
# print("lateral is", lateral.size())
|
||||||
|
h = self.output_blocks[k](h,
|
||||||
|
emb=dec_time_emb,
|
||||||
|
cond=dec_cond_emb,
|
||||||
|
lateral=lateral)
|
||||||
|
k += 1
|
||||||
|
|
||||||
|
pred = self.out(h)
|
||||||
|
# print("h:", h.shape)
|
||||||
|
# print("pred:", pred.shape)
|
||||||
|
|
||||||
|
if cond_given == True:
|
||||||
|
return AutoencReturn(pred=pred, cond=cond)
|
||||||
|
else:
|
||||||
|
return AutoencReturn(pred=pred, cond=cond, pred_hand=pred_hand, pred_head=pred_head)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GCNAutoencConfig(GCNUNetConfig):
|
||||||
|
# number of style channels
|
||||||
|
enc_out_channels: int = 256
|
||||||
|
enc_channel_mult: Tuple[int] = None
|
||||||
|
def make_model(self):
|
||||||
|
return GCNAutoencModel(self)
|
||||||
|
|
||||||
|
|
||||||
|
class GCNAutoencModel(GCNUNetModel):
|
||||||
|
def __init__(self, conf: GCNAutoencConfig):
|
||||||
|
super().__init__(conf)
|
||||||
|
self.conf = conf
|
||||||
|
|
||||||
|
# having only time, cond
|
||||||
|
self.time_emb_channels = conf.model_channels
|
||||||
|
self.time_embed = TimeStyleSeperateEmbed(
|
||||||
|
time_channels=self.time_emb_channels,
|
||||||
|
time_out_channels=conf.embed_channels,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.encoder = GCNEncoderConfig(
|
||||||
|
seq_len=conf.seq_len,
|
||||||
|
in_channels=conf.in_channels,
|
||||||
|
model_channels=32,
|
||||||
|
out_channels=conf.enc_out_channels,
|
||||||
|
channel_mult=conf.enc_channel_mult or conf.channel_mult,
|
||||||
|
).make_model()
|
||||||
|
|
||||||
|
def encode(self, x):
|
||||||
|
cond = self.encoder.forward(x)
|
||||||
|
return {'cond': cond}
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
x,
|
||||||
|
t,
|
||||||
|
x_start=None,
|
||||||
|
cond=None,
|
||||||
|
**kwargs):
|
||||||
|
"""
|
||||||
|
Apply the model to an input batch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x_start: the original input to encode
|
||||||
|
cond: output of the encoder
|
||||||
|
"""
|
||||||
|
|
||||||
|
if cond is None:
|
||||||
|
if x is not None:
|
||||||
|
assert len(x) == len(x_start), f'{len(x)} != {len(x_start)}'
|
||||||
|
|
||||||
|
tmp = self.encode(x_start)
|
||||||
|
cond = tmp['cond']
|
||||||
|
|
||||||
|
if t is not None:
|
||||||
|
_t_emb = timestep_embedding(t, self.time_emb_channels)
|
||||||
|
else:
|
||||||
|
# this happens when training only autoenc
|
||||||
|
_t_emb = None
|
||||||
|
|
||||||
|
if self.conf.resnet_two_cond:
|
||||||
|
res = self.time_embed.forward( ## self.time_embed is an MLP
|
||||||
|
time_emb=_t_emb,
|
||||||
|
cond=cond,
|
||||||
|
)
|
||||||
|
# two cond: first = time emb, second = cond_emb
|
||||||
|
emb = res.time_emb
|
||||||
|
cond_emb = res.emb
|
||||||
|
else:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
# where in the model to supply time conditions
|
||||||
|
enc_time_emb = emb ## time embeddings
|
||||||
|
mid_time_emb = emb
|
||||||
|
dec_time_emb = emb
|
||||||
|
enc_cond_emb = cond_emb ## z_sem embeddings
|
||||||
|
mid_cond_emb = cond_emb
|
||||||
|
dec_cond_emb = cond_emb
|
||||||
|
|
||||||
|
|
||||||
|
bs, channels, seq_len = x.shape
|
||||||
|
x = x.reshape(bs, self.conf.node_n, self.in_features, seq_len).permute(0, 2, 1, 3)
|
||||||
|
hs = [[] for _ in range(len(self.conf.channel_mult))]
|
||||||
|
h = x.type(self.dtype)
|
||||||
|
|
||||||
|
# input blocks
|
||||||
|
k = 0
|
||||||
|
for i in range(len(self.input_num_blocks)):
|
||||||
|
for j in range(self.input_num_blocks[i]):
|
||||||
|
h = self.input_blocks[k](h,
|
||||||
|
emb=enc_time_emb,
|
||||||
|
cond=enc_cond_emb)
|
||||||
|
hs[i].append(h)
|
||||||
|
k += 1
|
||||||
|
assert k == len(self.input_blocks)
|
||||||
|
|
||||||
|
# output blocks
|
||||||
|
k = 0
|
||||||
|
for i in range(len(self.output_num_blocks)):
|
||||||
|
for j in range(self.output_num_blocks[i]):
|
||||||
|
# take the lateral connection from the same layer (in reserve)
|
||||||
|
# until there is no more, use None
|
||||||
|
try:
|
||||||
|
lateral = hs[-i - 1].pop() ## in the reverse order (symmetric)
|
||||||
|
except IndexError:
|
||||||
|
lateral = None
|
||||||
|
h = self.output_blocks[k](h,
|
||||||
|
emb=dec_time_emb,
|
||||||
|
cond=dec_cond_emb,
|
||||||
|
lateral=lateral)
|
||||||
|
k += 1
|
||||||
|
|
||||||
|
|
||||||
|
pred = self.out(h)
|
||||||
|
pred = pred.permute(0, 2, 1, 3).reshape(bs, -1, seq_len)
|
||||||
|
|
||||||
|
return AutoencReturn(pred=pred, cond=cond)
|
||||||
|
|
||||||
|
|
||||||
|
class AutoencReturn(NamedTuple):
|
||||||
|
pred: Tensor
|
||||||
|
cond: Tensor = None
|
||||||
|
pred_hand: Tensor = None
|
||||||
|
pred_head: Tensor = None
|
||||||
|
|
||||||
|
|
||||||
|
class EmbedReturn(NamedTuple):
|
||||||
|
# style and time
|
||||||
|
emb: Tensor = None
|
||||||
|
# time only
|
||||||
|
time_emb: Tensor = None
|
||||||
|
# style only (but could depend on time)
|
||||||
|
style: Tensor = None
|
||||||
|
|
||||||
|
|
||||||
|
class TimeStyleSeperateEmbed(nn.Module):
|
||||||
|
# embed only style
|
||||||
|
def __init__(self, time_channels, time_out_channels):
|
||||||
|
super().__init__()
|
||||||
|
self.time_embed = nn.Sequential(
|
||||||
|
linear(time_channels, time_out_channels),
|
||||||
|
nn.SiLU(),
|
||||||
|
linear(time_out_channels, time_out_channels),
|
||||||
|
)
|
||||||
|
self.style = nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, time_emb=None, cond=None, **kwargs):
|
||||||
|
if time_emb is None:
|
||||||
|
# happens with autoenc training mode
|
||||||
|
time_emb = None
|
||||||
|
else:
|
||||||
|
time_emb = self.time_embed(time_emb)
|
||||||
|
style = self.style(cond) ## style==cond
|
||||||
|
return EmbedReturn(emb=style, time_emb=time_emb, style=style)
|
193
preprocess.py
Normal file
193
preprocess.py
Normal file
|
@ -0,0 +1,193 @@
|
||||||
|
import os, random, math, copy
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
import pickle as pkl
|
||||||
|
import logging, sys
|
||||||
|
from torch.utils.data import DataLoader,Dataset
|
||||||
|
import multiprocessing as mp
|
||||||
|
import json
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
|
||||||
|
def MakeDir(dirpath):
|
||||||
|
if not os.path.exists(dirpath):
|
||||||
|
os.makedirs(dirpath)
|
||||||
|
|
||||||
|
|
||||||
|
def load_egobody(data_dir, seq_len, sample_rate=1, train=1):
|
||||||
|
data_dir_train = data_dir + 'train/'
|
||||||
|
data_dir_test = data_dir + 'test/'
|
||||||
|
if train == 0:
|
||||||
|
data_dirs = [data_dir_test] # test
|
||||||
|
elif train == 1:
|
||||||
|
data_dirs = [data_dir_train] # train
|
||||||
|
elif train == 2:
|
||||||
|
data_dirs = [data_dir_train, data_dir_test] # train + test
|
||||||
|
|
||||||
|
hand_head = []
|
||||||
|
for data_dir in data_dirs:
|
||||||
|
file_paths = sorted(os.listdir(data_dir))
|
||||||
|
pose_xyz_file_paths = []
|
||||||
|
head_file_paths = []
|
||||||
|
for path in file_paths:
|
||||||
|
path_split = path.split('_')
|
||||||
|
data_type = path_split[-1][:-4]
|
||||||
|
if(data_type == 'xyz'):
|
||||||
|
pose_xyz_file_paths.append(path)
|
||||||
|
if(data_type == 'head'):
|
||||||
|
head_file_paths.append(path)
|
||||||
|
|
||||||
|
file_num = len(pose_xyz_file_paths)
|
||||||
|
for i in range(file_num):
|
||||||
|
pose_data = np.load(data_dir + pose_xyz_file_paths[i])
|
||||||
|
head_data = np.load(data_dir + head_file_paths[i])
|
||||||
|
num_frames = pose_data.shape[0]
|
||||||
|
if num_frames < seq_len:
|
||||||
|
continue
|
||||||
|
|
||||||
|
head_pos = pose_data[:, 15*3:16*3]
|
||||||
|
left_hand_pos = pose_data[:, 20*3:21*3]
|
||||||
|
right_hand_pos = pose_data[:, 21*3:22*3]
|
||||||
|
head_ori = head_data
|
||||||
|
left_hand_pos -= head_pos # convert hand positions to head coordinate system
|
||||||
|
right_hand_pos -= head_pos
|
||||||
|
hand_head_data = left_hand_pos
|
||||||
|
hand_head_data = np.concatenate((hand_head_data, right_hand_pos), axis=1)
|
||||||
|
hand_head_data = np.concatenate((hand_head_data, head_ori), axis=1)
|
||||||
|
|
||||||
|
fs = np.arange(0, num_frames - seq_len + 1)
|
||||||
|
fs_sel = fs
|
||||||
|
for i in np.arange(seq_len - 1):
|
||||||
|
fs_sel = np.vstack((fs_sel, fs + i + 1))
|
||||||
|
fs_sel = fs_sel.transpose()
|
||||||
|
seq_sel = hand_head_data[fs_sel, :]
|
||||||
|
seq_sel = seq_sel[0::sample_rate, :, :]
|
||||||
|
if len(hand_head) == 0:
|
||||||
|
hand_head = seq_sel
|
||||||
|
else:
|
||||||
|
hand_head = np.concatenate((hand_head, seq_sel), axis=0)
|
||||||
|
|
||||||
|
hand_head = np.transpose(hand_head, (0, 2, 1))
|
||||||
|
return hand_head
|
||||||
|
|
||||||
|
|
||||||
|
def load_adt(data_dir, seq_len, sample_rate=1, train=1):
|
||||||
|
data_dir_train = data_dir + 'train/'
|
||||||
|
data_dir_test = data_dir + 'test/'
|
||||||
|
if train == 0:
|
||||||
|
data_dirs = [data_dir_test] # test
|
||||||
|
elif train == 1:
|
||||||
|
data_dirs = [data_dir_train] # train
|
||||||
|
elif train == 2:
|
||||||
|
data_dirs = [data_dir_train, data_dir_test] # train + test
|
||||||
|
|
||||||
|
hand_head = []
|
||||||
|
for data_dir in data_dirs:
|
||||||
|
file_paths = sorted(os.listdir(data_dir))
|
||||||
|
pose_xyz_file_paths = []
|
||||||
|
head_file_paths = []
|
||||||
|
for path in file_paths:
|
||||||
|
path_split = path.split('_')
|
||||||
|
data_type = path_split[-1][:-4]
|
||||||
|
if(data_type == 'xyz'):
|
||||||
|
pose_xyz_file_paths.append(path)
|
||||||
|
if(data_type == 'head'):
|
||||||
|
head_file_paths.append(path)
|
||||||
|
|
||||||
|
file_num = len(pose_xyz_file_paths)
|
||||||
|
for i in range(file_num):
|
||||||
|
pose_data = np.load(data_dir + pose_xyz_file_paths[i])
|
||||||
|
head_data = np.load(data_dir + head_file_paths[i])
|
||||||
|
num_frames = pose_data.shape[0]
|
||||||
|
if num_frames < seq_len:
|
||||||
|
continue
|
||||||
|
|
||||||
|
head_pos = pose_data[:, 4*3:5*3]
|
||||||
|
left_hand_pos = pose_data[:, 8*3:9*3]
|
||||||
|
right_hand_pos = pose_data[:, 12*3:13*3]
|
||||||
|
head_ori = head_data
|
||||||
|
left_hand_pos -= head_pos # convert hand positions to head coordinate system
|
||||||
|
right_hand_pos -= head_pos
|
||||||
|
hand_head_data = left_hand_pos
|
||||||
|
hand_head_data = np.concatenate((hand_head_data, right_hand_pos), axis=1)
|
||||||
|
hand_head_data = np.concatenate((hand_head_data, head_ori), axis=1)
|
||||||
|
|
||||||
|
fs = np.arange(0, num_frames - seq_len + 1)
|
||||||
|
fs_sel = fs
|
||||||
|
for i in np.arange(seq_len - 1):
|
||||||
|
fs_sel = np.vstack((fs_sel, fs + i + 1))
|
||||||
|
fs_sel = fs_sel.transpose()
|
||||||
|
seq_sel = hand_head_data[fs_sel, :]
|
||||||
|
seq_sel = seq_sel[0::sample_rate, :, :]
|
||||||
|
if len(hand_head) == 0:
|
||||||
|
hand_head = seq_sel
|
||||||
|
else:
|
||||||
|
hand_head = np.concatenate((hand_head, seq_sel), axis=0)
|
||||||
|
|
||||||
|
hand_head = np.transpose(hand_head, (0, 2, 1))
|
||||||
|
return hand_head
|
||||||
|
|
||||||
|
|
||||||
|
def load_gimo(data_dir, seq_len, sample_rate=1, train=1):
|
||||||
|
data_dir_train = data_dir + 'train/'
|
||||||
|
data_dir_test = data_dir + 'test/'
|
||||||
|
if train == 0:
|
||||||
|
data_dirs = [data_dir_test] # test
|
||||||
|
elif train == 1:
|
||||||
|
data_dirs = [data_dir_train] # train
|
||||||
|
elif train == 2:
|
||||||
|
data_dirs = [data_dir_train, data_dir_test] # train + test
|
||||||
|
|
||||||
|
hand_head = []
|
||||||
|
for data_dir in data_dirs:
|
||||||
|
file_paths = sorted(os.listdir(data_dir))
|
||||||
|
pose_xyz_file_paths = []
|
||||||
|
head_file_paths = []
|
||||||
|
for path in file_paths:
|
||||||
|
path_split = path.split('_')
|
||||||
|
data_type = path_split[-1][:-4]
|
||||||
|
if(data_type == 'xyz'):
|
||||||
|
pose_xyz_file_paths.append(path)
|
||||||
|
if(data_type == 'head'):
|
||||||
|
head_file_paths.append(path)
|
||||||
|
|
||||||
|
file_num = len(pose_xyz_file_paths)
|
||||||
|
for i in range(file_num):
|
||||||
|
pose_data = np.load(data_dir + pose_xyz_file_paths[i])
|
||||||
|
head_data = np.load(data_dir + head_file_paths[i])
|
||||||
|
num_frames = pose_data.shape[0]
|
||||||
|
if num_frames < seq_len:
|
||||||
|
continue
|
||||||
|
|
||||||
|
head_pos = pose_data[:, 15*3:16*3]
|
||||||
|
left_hand_pos = pose_data[:, 20*3:21*3]
|
||||||
|
right_hand_pos = pose_data[:, 21*3:22*3]
|
||||||
|
head_ori = head_data
|
||||||
|
left_hand_pos -= head_pos # convert hand positions to head coordinate system
|
||||||
|
right_hand_pos -= head_pos
|
||||||
|
hand_head_data = left_hand_pos
|
||||||
|
hand_head_data = np.concatenate((hand_head_data, right_hand_pos), axis=1)
|
||||||
|
hand_head_data = np.concatenate((hand_head_data, head_ori), axis=1)
|
||||||
|
|
||||||
|
fs = np.arange(0, num_frames - seq_len + 1)
|
||||||
|
fs_sel = fs
|
||||||
|
for i in np.arange(seq_len - 1):
|
||||||
|
fs_sel = np.vstack((fs_sel, fs + i + 1))
|
||||||
|
fs_sel = fs_sel.transpose()
|
||||||
|
seq_sel = hand_head_data[fs_sel, :]
|
||||||
|
seq_sel = seq_sel[0::sample_rate, :, :]
|
||||||
|
if len(hand_head) == 0:
|
||||||
|
hand_head = seq_sel
|
||||||
|
else:
|
||||||
|
hand_head = np.concatenate((hand_head, seq_sel), axis=0)
|
||||||
|
|
||||||
|
hand_head = np.transpose(hand_head, (0, 2, 1))
|
||||||
|
return hand_head
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
data_dir = "/scratch/hu/pose_forecast/egobody_pose2gaze/"
|
||||||
|
seq_len = 40
|
||||||
|
|
||||||
|
test_data = load_egobody(data_dir, seq_len, sample_rate=10, train=0)
|
||||||
|
print("\ndataset size: {}".format(test_data.shape))
|
3
train.sh
Normal file
3
train.sh
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
#python main.py --gpus 7 --mode 'train' --model_name 'haheae';
|
||||||
|
|
||||||
|
python main.py --gpus 7 --mode 'eval' --model_name 'haheae';
|
Loading…
Add table
Add a link
Reference in a new issue