HaHeAE/config.py
2025-06-03 21:11:04 +02:00

153 lines
No EOL
5.4 KiB
Python

from model.blocks import *
from diffusion.resample import UniformSampler
from dataclasses import dataclass
from diffusion.diffusion import space_timesteps
from typing import Tuple
from config_base import BaseConfig
from diffusion import *
from diffusion.base import GenerativeType, LossType, ModelMeanType, ModelVarType, get_named_beta_schedule
from model import *
from choices import *
from preprocess import *
import os
@dataclass
class TrainConfig(BaseConfig):
    """Hyper-parameters and factory helpers for one training experiment.

    Holds plain settings (data, optimisation, diffusion, network sizes) and
    builds the derived configuration objects from them: the spaced-diffusion
    configs for training and evaluation, the ``UniformSampler`` over ``T``,
    and the autoencoder model config.
    """

    name: str = ''
    base_dir: str = './checkpoints/'
    # NOTE: this f-string is evaluated once, while the class body executes,
    # so the default is always './checkpoints/'. It does not track later
    # changes to `name` or `base_dir`; assign `logdir` explicitly if needed.
    logdir: str = f'{base_dir}{name}'
    data_name: str = ''
    data_val_name: str = ''
    seq_len: int = 40  # for reconstruction
    seq_len_future: int = 3  # for prediction
    # No annotation: @dataclass does not turn this into a field (it gets no
    # __init__ parameter). It is a class attribute instances may override.
    in_channels = 9
    fp16: bool = True
    lr: float = 1e-4
    ema_decay: float = 0.9999
    seed: int = 0  # random seed
    batch_size: int = 64
    accum_batches: int = 1
    batch_size_eval: int = 1024
    total_epochs: int = 1_000
    save_every_epochs: int = 10
    eval_every_epochs: int = 10
    train_mode: TrainMode = TrainMode.diffusion
    T: int = 1000
    T_eval: int = 100
    diffusion_type: str = 'beatgans'
    semantic_encoder_type: str = 'gcn'
    net_beatgans_embed_channels: int = 128
    beatgans_gen_type: GenerativeType = GenerativeType.ddim
    beatgans_loss_type: LossType = LossType.mse
    # Unannotated as well, hence plain class attributes rather than fields.
    hand_mse_factor = 1.0
    head_mse_factor = 1.0
    beatgans_model_mean_type: ModelMeanType = ModelMeanType.eps
    beatgans_model_var_type: ModelVarType = ModelVarType.fixed_large
    beatgans_rescale_timesteps: bool = False
    beta_scheduler: str = 'linear'
    net_ch: int = 64
    net_ch_mult: Tuple[int, ...] = (1, 2, 4)
    # Variable-length tuple, matching `net_ch_mult` and the 3-element default.
    net_enc_channel_mult: Tuple[int, ...] = (1, 2, 4)
    grad_clip: float = 1
    optimizer: OptimizerType = OptimizerType.adam
    weight_decay: float = 0
    warmup: int = 0
    # Filled in by `make_model_conf()`.
    model_conf: ModelConfig = None
    model_name: ModelName = ModelName.beatgans_autoenc
    # Set by `make_model_conf()` according to `model_name`.
    model_type: ModelType = None

    @property
    def batch_size_effective(self) -> int:
        """Return ``batch_size`` multiplied by ``accum_batches``."""
        return self.batch_size * self.accum_batches

    def _make_diffusion_conf(self, T=None):
        """Build the spaced-diffusion config for a schedule of ``T`` steps.

        Args:
            T: Number of diffusion steps. It drives the beta schedule, the
                timestep spacing and the section counts. Falls back to
                ``self.T`` when omitted or ``None``.

        Returns:
            A ``SpacedDiffusionBeatGansConfig`` built from this config.

        Raises:
            NotImplementedError: If ``diffusion_type`` is not ``'beatgans'``
                or ``beatgans_gen_type`` is neither ``ddpm`` nor ``ddim``.
        """
        if T is None:
            # The signature makes T optional; without this fallback a bare
            # call would forward None into the schedule helpers below.
            T = self.T

        if self.diffusion_type != 'beatgans':
            raise NotImplementedError()

        # can use T < self.T for evaluation
        # follows the guided-diffusion repo conventions
        # t's are evenly spaced
        if self.beatgans_gen_type == GenerativeType.ddpm:
            section_counts = [T]
        elif self.beatgans_gen_type == GenerativeType.ddim:
            section_counts = f'ddim{T}'
        else:
            raise NotImplementedError()

        return SpacedDiffusionBeatGansConfig(
            gen_type=self.beatgans_gen_type,
            model_type=self.model_type,
            betas=get_named_beta_schedule(self.beta_scheduler, T),
            model_mean_type=self.beatgans_model_mean_type,
            model_var_type=self.beatgans_model_var_type,
            loss_type=self.beatgans_loss_type,
            rescale_timesteps=self.beatgans_rescale_timesteps,
            use_timesteps=space_timesteps(
                num_timesteps=T,
                section_counts=section_counts,
            ),
            fp16=self.fp16,
        )

    @property
    def model_out_channels(self) -> int:
        """Number of output channels of the model; equals ``in_channels``."""
        return self.in_channels

    @property
    def model_input_channels(self) -> int:
        """Number of input channels of the model; equals ``in_channels``."""
        return self.in_channels

    def make_T_sampler(self):
        """Return a ``UniformSampler`` constructed with ``self.T``."""
        return UniformSampler(self.T)

    def make_diffusion_conf(self):
        """Return the diffusion config built with the training step count ``T``."""
        return self._make_diffusion_conf(self.T)

    def make_eval_diffusion_conf(self):
        """Return the diffusion config built with the evaluation step count ``T_eval``."""
        return self._make_diffusion_conf(T=self.T_eval)

    def make_model_conf(self):
        """Create the model config, store it on ``self.model_conf`` and return it.

        Also sets ``self.model_type`` as a side effect.

        Raises:
            NotImplementedError: If ``model_name`` is not
                ``ModelName.beatgans_autoenc``.
        """
        if self.model_name != ModelName.beatgans_autoenc:
            raise NotImplementedError()

        self.model_type = ModelType.autoencoder
        self.model_conf = BeatGANsAutoencConfig(
            semantic_encoder_type=self.semantic_encoder_type,
            channel_mult=self.net_ch_mult,
            seq_len=self.seq_len,
            seq_len_future=self.seq_len_future,
            embed_channels=self.net_beatgans_embed_channels,
            # The encoder output width reuses the embedding width setting.
            enc_out_channels=self.net_beatgans_embed_channels,
            enc_channel_mult=self.net_enc_channel_mult,
            in_channels=self.model_input_channels,
            model_channels=self.net_ch,
            out_channels=self.model_out_channels,
        )
        return self.model_conf
def egobody_autoenc(
    mode,
    encoder_type='gcn',
    hand_mse_factor=1.0,
    head_mse_factor=1.0,
    data_sample_rate=1,
    epoch=130,
    in_channels=9,
    seq_len=40,
    data_dir="/scratch/hu/pose_forecast/egobody_pose2gaze/",
):
    """Build the ``TrainConfig`` preset named ``'egobody_autoenc'``.

    Args:
        mode: Stored unchanged on ``conf.mode``.
        encoder_type: Value for ``conf.semantic_encoder_type``.
        hand_mse_factor: Value for ``conf.hand_mse_factor``.
        head_mse_factor: Value for ``conf.head_mse_factor``.
        data_sample_rate: Stored unchanged on ``conf.data_sample_rate``.
        epoch: Value for ``conf.total_epochs``.
        in_channels: Value for ``conf.in_channels``.
        seq_len: Value for ``conf.seq_len``.
        data_dir: Dataset root stored on ``conf.data_dir``. Defaults to the
            previously hard-coded path, so existing callers are unaffected.

    Returns:
        The populated ``TrainConfig`` with ``make_model_conf()`` already
        called, so ``conf.model_conf`` and ``conf.model_type`` are set.
    """
    conf = TrainConfig()

    # Sequence and network shape.
    conf.seq_len = seq_len
    conf.seq_len_future = 3
    conf.in_channels = in_channels
    conf.net_beatgans_embed_channels = 128
    conf.net_ch = 64
    conf.net_ch_mult = (1, 1, 1)
    conf.semantic_encoder_type = encoder_type
    conf.hand_mse_factor = hand_mse_factor
    conf.head_mse_factor = head_mse_factor
    # The encoder uses the same channel multipliers as the main network.
    conf.net_enc_channel_mult = conf.net_ch_mult

    # Training schedule and batch sizes.
    conf.total_epochs = epoch
    conf.save_every_epochs = 10
    conf.eval_every_epochs = 10
    conf.batch_size = 64
    conf.batch_size_eval = 1024 * 4

    # `data_dir`, `data_sample_rate` and `mode` are not declared on
    # TrainConfig in this file; they are attached to the instance here.
    conf.data_dir = data_dir
    conf.data_sample_rate = data_sample_rate
    conf.name = 'egobody_autoenc'
    conf.data_name = 'egobody'
    conf.mode = mode

    # Must run last: the model config is derived from the values set above.
    conf.make_model_conf()
    return conf