import copy
import os
import random

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import wandb
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import *
from torch.cuda import amp
from torch.optim.optimizer import Optimizer
from torch.utils.data.dataset import TensorDataset

from config import *
from dataset import *
from dist_utils import *


def MakeDir(dirName):
    """Create ``dirName`` (and any missing parents) unless it already exists."""
    # exist_ok=True replaces the racy "check, then create" sequence.
    os.makedirs(dirName, exist_ok=True)


class LitModel(pl.LightningModule):
    """Lightning wrapper holding the diffusion model, its EMA copy and the samplers.

    Args:
        conf: training configuration; it builds the model, samplers and dataset.
        betas: stored as ``conf.betas`` and forwarded to
            ``sampler.training_losses`` as ``lossbetas`` in diffusion mode.
    """

    def __init__(self, conf: TrainConfig, betas):
        super().__init__()
        self.save_hyperparameters(
            {k: v for (k, v) in vars(conf).items() if not callable(v)})
        self.save_hyperparameters(conf.as_dict_jsonable())
        assert conf.train_mode != TrainMode.manipulate
        if conf.seed is not None:
            pl.seed_everything(conf.seed)

        conf.betas = betas
        self.conf = conf

        self.model = conf.make_model_conf().make_model()
        # Frozen copy that is only updated through ema() in on_train_batch_end.
        self.ema_model = copy.deepcopy(self.model)
        self.ema_model.requires_grad_(False)
        self.ema_model.eval()

        model_size = sum(param.data.nelement()
                         for param in self.model.parameters())
        print('Model params: %.2f M' % (model_size / 1024 / 1024))

        self.sampler = conf.make_diffusion_conf().make_sampler()
        self.eval_sampler = conf.make_eval_diffusion_conf().make_sampler()
        self.T_sampler = conf.make_T_sampler()

        if conf.train_mode.use_latent_net():
            self.latent_sampler = conf.make_latent_diffusion_conf(
            ).make_sampler()
            self.eval_latent_sampler = conf.make_latent_eval_diffusion_conf(
            ).make_sampler()
        else:
            self.latent_sampler = None
            self.eval_latent_sampler = None

        if conf.pretrain is not None:
            print(f'loading pretrain ... {conf.pretrain.name}')
            state = torch.load(conf.pretrain.path, map_location='cpu')
            print('step:', state['global_step'])
            self.load_state_dict(state['state_dict'], strict=False)

        if conf.latent_infer_path is not None:
            print('loading latent stats ...')
            state = torch.load(conf.latent_infer_path)
            self.conds = state['conds']
            # Without these, normalize() and train_dataloader() would fail
            # with AttributeError. Use the same formula as train_dataloader.
            self._set_cond_stats()
        else:
            # None tells train_dataloader() that the conds must be inferred.
            self.conds = None
            self.conds_mean = None
            self.conds_std = None

    def _set_cond_stats(self):
        """Recompute ``conds_mean`` / ``conds_std`` (shape (1, ...)) from ``self.conds``."""
        conds = self.conds.float()
        self.conds_mean = conds.mean(dim=0, keepdim=True)
        self.conds_std = conds.std(dim=0, keepdim=True)

    def normalize(self, cond):
        """Z-normalize ``cond`` with the stored per-dimension mean/std."""
        cond = (cond - self.conds_mean.to(self.device)) / self.conds_std.to(
            self.device)
        return cond

    def denormalize(self, cond):
        """Invert :meth:`normalize`."""
        cond = (cond * self.conds_std.to(self.device)) + self.conds_mean.to(
            self.device)
        return cond

    def _sampler_for(self, T=None):
        """Return the eval sampler, or a fresh sampler with ``T`` steps if given."""
        if T is None:
            return self.eval_sampler
        return self.conf._make_diffusion_conf(T).make_sampler()

    def render(self, noise, cond=None, T=None):
        """Generate samples from ``noise`` with the EMA model, optionally conditioned on ``cond``."""
        sampler = self._sampler_for(T)
        # NOTE(review): render_condition / render_uncondition are not defined
        # or explicitly imported here; presumably they come from one of the
        # star imports — verify.
        if cond is not None:
            pred_img = render_condition(self.conf,
                                        self.ema_model,
                                        noise,
                                        sampler=sampler,
                                        cond=cond)
        else:
            pred_img = render_uncondition(self.conf,
                                          self.ema_model,
                                          noise,
                                          sampler=sampler,
                                          latent_sampler=None)
        return pred_img

    def encode(self, x):
        """Encode ``x`` with the EMA model's encoder (requires an autoencoder model)."""
        assert self.conf.model_type.has_autoenc()
        cond = self.ema_model.encoder.forward(x)
        return cond

    def encode_stochastic(self, x, cond, T=None):
        """Run the DDIM reverse loop on ``x``; return ``(sample, xstart_t)``."""
        sampler = self._sampler_for(T)
        out = sampler.ddim_reverse_sample_loop(self.ema_model,
                                               x,
                                               model_kwargs={'cond': cond})
        return out['sample'], out['xstart_t']

    def forward(self, noise=None, x_start=None, ema_model: bool = False):
        """Sample with the eval sampler using the live or (if ``ema_model``) EMA model."""
        with amp.autocast(False):
            model = self.ema_model if ema_model else self.model
            gen = self.eval_sampler.sample(model=model,
                                           noise=noise,
                                           x_start=x_start)
            return gen

    def setup(self, stage=None) -> None:
        """
        make datasets & seeding each worker separately
        """
        if self.conf.seed is not None:
            # Offset the seed by rank so each worker gets its own stream.
            seed = self.conf.seed * get_world_size() + self.global_rank
            np.random.seed(seed)
            random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            torch.backends.cudnn.benchmark = False
            torch.backends.cudnn.deterministic = True

        taskdata, tasklabels = loadDataset(self.conf)
        assert self.conf.num_users == tasklabels['user'].nunique()
        # Encode user ids as integer category codes.
        tasklabels = tasklabels['user'].astype(
            'category').cat.codes.to_numpy().astype(int)
        assert self.conf.num_users == len(np.unique(tasklabels))
        self.train_data = self.conf.make_dataset(taskdata, tasklabels)
        self.val_data = self.train_data

    def train_dataloader(self):
        """
        return the dataloader, if diffusion mode => return image dataset
        if latent mode => return the inferred latent dataset
        """
        if self.conf.train_mode.require_dataset_infer():
            if self.conds is None:
                # NOTE(review): infer_whole_dataset is not defined in this
                # file — confirm it exists before training in latent mode.
                self.conds = self.infer_whole_dataset()
                self._set_cond_stats()
            print('mean:', self.conds_mean.mean(), 'std:',
                  self.conds_std.mean())

            conf = self.conf.clone()
            conf.batch_size = self.batch_size
            data = TensorDataset(self.conds)
            return conf.make_loader(data, shuffle=True)
        # Use the per-worker batch size, as the latent branch does, so the
        # global batch under DDP equals conf.batch_size.
        return torch.utils.data.DataLoader(self.train_data,
                                           batch_size=self.batch_size,
                                           shuffle=True)

    @property
    def batch_size(self):
        """
        local batch size for each worker
        """
        ws = get_world_size()
        assert self.conf.batch_size % ws == 0
        return self.conf.batch_size // ws

    @property
    def num_samples(self):
        """
        (global) batch size * iterations
        """
        return self.global_step * self.conf.batch_size_effective

    def is_last_accum(self, batch_idx):
        """
        is it the last gradient accumulation loop?
        used with gradient_accum > 1 and to see if the optimizer will perform "step" in this iteration or not
        """
        return (batch_idx + 1) % self.conf.accum_batches == 0

    def training_step(self, batch, batch_idx):
        """
        given an input, calculate the loss function
        no optimization at this stage.
        """
        with amp.autocast(False):
            if self.conf.train_mode.require_dataset_infer():
                # Latent mode: the batch holds the pre-computed conds.
                cond = batch[0]
                if self.conf.latent_znormalize:
                    cond = self.normalize(cond)
            else:
                x_start = batch[0]

            if self.conf.train_mode == TrainMode.diffusion:
                t, _ = self.T_sampler.sample(len(x_start), x_start.device)
                losses = self.sampler.training_losses(
                    model=self.model,
                    x_start=x_start,
                    t=t,
                    user_label=batch[1],
                    lossbetas=self.conf.betas)
            elif self.conf.train_mode.is_latent_diffusion():
                t, _ = self.T_sampler.sample(len(cond), cond.device)
                latent_losses = self.latent_sampler.training_losses(
                    model=self.model.latent_net, x_start=cond, t=t)
                losses = {
                    'latent': latent_losses['loss'],
                    'loss': latent_losses['loss'],
                }
            else:
                raise NotImplementedError()

            loss = losses['loss'].mean()
            self.log("train_loss", loss)
        return {'loss': loss}

    def on_train_batch_end(self, outputs, batch, batch_idx: int,
                           dataloader_idx: int) -> None:
        """
        after each training step: save a checkpoint at the end of every 10th
        epoch and update the EMA model after each optimizer step
        """
        # trainer.num_training_batches is this rank's batch count. Rebuilding
        # the dataloader every step was costly, and its length ignored the
        # DDP sampler.
        is_last_batch = batch_idx == self.trainer.num_training_batches - 1
        if (is_last_batch and (self.current_epoch + 1) % 10 == 0
                and self.global_rank == 0):
            save_path = os.path.join(self.conf.logdir,
                                     'epoch%d.ckpt' % self.current_epoch)
            torch.save(
                {
                    'state_dict': self.state_dict(),
                    'global_step': self.global_step,
                    'loss': outputs['loss'],
                }, save_path)

        if self.is_last_accum(batch_idx):
            if self.conf.train_mode == TrainMode.latent_diffusion:
                ema(self.model.latent_net, self.ema_model.latent_net,
                    self.conf.ema_decay)
            else:
                ema(self.model, self.ema_model, self.conf.ema_decay)

    def on_before_optimizer_step(self, optimizer: Optimizer,
                                 optimizer_idx: int) -> None:
        """Clip the global gradient norm to ``conf.grad_clip`` when it is positive."""
        if self.conf.grad_clip > 0:
            params = [
                p for group in optimizer.param_groups for p in group['params']
            ]
            torch.nn.utils.clip_grad_norm_(params,
                                           max_norm=self.conf.grad_clip)

    def configure_optimizers(self):
        """Build Adam/AdamW and, if ``conf.warmup > 0``, a per-step linear warmup scheduler."""
        out = {}
        if self.conf.optimizer == OptimizerType.adam:
            optim = torch.optim.Adam(self.model.parameters(),
                                     lr=self.conf.lr,
                                     weight_decay=self.conf.weight_decay)
        elif self.conf.optimizer == OptimizerType.adamw:
            optim = torch.optim.AdamW(self.model.parameters(),
                                      lr=self.conf.lr,
                                      weight_decay=self.conf.weight_decay)
        else:
            raise NotImplementedError()
        out['optimizer'] = optim
        if self.conf.warmup > 0:
            sched = torch.optim.lr_scheduler.LambdaLR(optim,
                                                      lr_lambda=WarmupLR(
                                                          self.conf.warmup))
            out['lr_scheduler'] = {
                'scheduler': sched,
                'interval': 'step',
            }
        return out

    def split_tensor(self, x):
        """
        extract the tensor for a corresponding "worker" in the batch dimension

        Args:
            x: (n, c)

        Returns: x: (n_local, c)
        """
        n = len(x)
        rank = self.global_rank
        world_size = get_world_size()
        per_rank = n // world_size
        return x[rank * per_rank:(rank + 1) * per_rank]


def ema(source, target, decay):
    """Update ``target`` in place: each state entry becomes decay*target + (1-decay)*source."""
    source_dict = source.state_dict()
    target_dict = target.state_dict()
    for key in source_dict.keys():
        target_dict[key].data.copy_(target_dict[key].data * decay +
                                    source_dict[key].data * (1 - decay))


class WarmupLR:
    """LambdaLR multiplier that ramps linearly from 0 to 1 over ``warmup`` steps."""

    def __init__(self, warmup) -> None:
        self.warmup = warmup

    def __call__(self, step):
        return min(step, self.warmup) / self.warmup


def is_time(num_samples, every, step_size):
    """True if the last multiple of ``every`` is less than ``step_size`` samples behind ``num_samples``."""
    closest = (num_samples // every) * every
    return num_samples - closest < step_size


def train(conf: TrainConfig, model: LitModel, gpus, nodes=1):
    """Fit ``model`` with Lightning, resuming from ``<logdir>/last.ckpt`` when present."""
    checkpoint = ModelCheckpoint(dirpath=conf.logdir,
                                 save_last=True,
                                 save_top_k=1,
                                 every_n_train_steps=conf.save_every_samples //
                                 conf.batch_size_effective)
    # os.path.join works whether or not logdir ends with '/'.
    checkpoint_path = os.path.join(conf.logdir, 'last.ckpt')
    print('ckpt path:', checkpoint_path)
    if os.path.exists(checkpoint_path):
        resume = checkpoint_path
        print('resume!')
    elif conf.continue_from is not None:
        resume = conf.continue_from.path
    else:
        resume = None

    plugins = []
    if len(gpus) == 1 and nodes == 1:
        accelerator = None
    else:
        accelerator = 'ddp'
        from pytorch_lightning.plugins import DDPPlugin
        plugins.append(DDPPlugin(find_unused_parameters=False))

    # Last path component of logdir, with or without a trailing slash.
    run_dir = os.path.basename(os.path.normpath(conf.logdir))
    wandb_logger = pl_loggers.WandbLogger(
        project='dismouse',
        name='%s_%s' % (model.conf.pretrainDataset, run_dir),
        log_model=True,
        save_dir=conf.logdir,
        dir=conf.logdir,
        config=vars(model.conf),
        save_code=True,
        settings=wandb.Settings(code_dir="."))

    trainer = pl.Trainer(
        max_steps=conf.total_samples // conf.batch_size_effective,
        resume_from_checkpoint=resume,
        gpus=gpus,
        num_nodes=nodes,
        accelerator=accelerator,
        precision=16 if conf.fp16 else 32,
        callbacks=[
            checkpoint,
            LearningRateMonitor(),
        ],
        replace_sampler_ddp=True,
        logger=wandb_logger,
        accumulate_grad_batches=conf.accum_batches,
        plugins=plugins,
    )
    trainer.fit(model)
    wandb.finish()