import copy
import os
import random

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import wandb
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import *
from torch.cuda import amp
from torch.optim.optimizer import Optimizer
from torch.utils.data.dataset import TensorDataset

from config import *
from dataset import *
from dist_utils import *


def MakeDir(dirName):
    if not os.path.exists(dirName):
        os.makedirs(dirName)


class LitModel(pl.LightningModule):
    def __init__(self, conf: TrainConfig, betas):
        super().__init__()

        # log the config both as raw (non-callable) attributes and as a jsonable dict
        self.save_hyperparameters({k: v for (k, v) in vars(conf).items() if not callable(v)})
        self.save_hyperparameters(conf.as_dict_jsonable())

        assert conf.train_mode != TrainMode.manipulate
        if conf.seed is not None:
            pl.seed_everything(conf.seed)

        conf.betas = betas
        self.conf = conf
        self.model = conf.make_model_conf().make_model()

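        # Keep a frozen exponential-moving-average copy of the model; it is updated
        # after each optimizer step in on_train_batch_end() and used for encoding,
        # rendering, and evaluation.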
        self.ema_model = copy.deepcopy(self.model)
        self.ema_model.requires_grad_(False)
        self.ema_model.eval()

        model_size = 0
        for param in self.model.parameters():
            model_size += param.data.nelement()
        print('Model params: %.2f M' % (model_size / 1024 / 1024))

        self.sampler = conf.make_diffusion_conf().make_sampler()
        self.eval_sampler = conf.make_eval_diffusion_conf().make_sampler()

        # shared between the diffusion model and the latent diffusion model
        self.T_sampler = conf.make_T_sampler()

        if conf.train_mode.use_latent_net():
            self.latent_sampler = conf.make_latent_diffusion_conf().make_sampler()
            self.eval_latent_sampler = conf.make_latent_eval_diffusion_conf().make_sampler()
        else:
            self.latent_sampler = None
            self.eval_latent_sampler = None

        if conf.pretrain is not None:
            print(f'loading pretrain ... {conf.pretrain.name}')
            state = torch.load(conf.pretrain.path, map_location='cpu')
            print('step:', state['global_step'])
            self.load_state_dict(state['state_dict'], strict=False)

        if conf.latent_infer_path is not None:
            print('loading latent stats ...')
            state = torch.load(conf.latent_infer_path)
            self.conds = state['conds']
        else:
            # no precomputed latents; train_dataloader() checks `self.conds is None`
            # before inferring them, so make sure the attribute always exists
            self.conds = None
            self.conds_mean = None
            self.conds_std = None

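    # normalize()/denormalize() standardize latent codes with the dataset statistics
    # (conds_mean / conds_std) loaded above or computed in train_dataloader().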
    def normalize(self, cond):
        cond = (cond - self.conds_mean.to(self.device)) / self.conds_std.to(self.device)
        return cond

    def denormalize(self, cond):
        cond = (cond * self.conds_std.to(self.device)) + self.conds_mean.to(self.device)
        return cond

    def render(self, noise, cond=None, T=None):
        if T is None:
            sampler = self.eval_sampler
        else:
            sampler = self.conf._make_diffusion_conf(T).make_sampler()

        if cond is not None:
            pred_img = render_condition(self.conf,
                                        self.ema_model,
                                        noise,
                                        sampler=sampler,
                                        cond=cond)
        else:
            pred_img = render_uncondition(self.conf,
                                          self.ema_model,
                                          noise,
                                          sampler=sampler,
                                          latent_sampler=None)
        return pred_img

    def encode(self, x):
        assert self.conf.model_type.has_autoenc()
        cond = self.ema_model.encoder.forward(x)
        return cond

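    # encode_stochastic() runs the sampler's DDIM loop in reverse on the EMA model:
    # given an input x and its semantic code `cond`, it returns the terminal noise
    # ('sample', i.e. x_T) along with the intermediate x0 predictions ('xstart_t').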
    def encode_stochastic(self, x, cond, T=None):
        if T is None:
            sampler = self.eval_sampler
        else:
            sampler = self.conf._make_diffusion_conf(T).make_sampler()

        out = sampler.ddim_reverse_sample_loop(self.ema_model,
                                               x,
                                               model_kwargs={'cond': cond})
        return out['sample'], out['xstart_t']

    def forward(self, noise=None, x_start=None, ema_model: bool = False):
        with amp.autocast(False):
            if ema_model:
                model = self.ema_model
            else:
                model = self.model
            gen = self.eval_sampler.sample(model=model,
                                           noise=noise,
                                           x_start=x_start)
            return gen

    def setup(self, stage=None) -> None:
        """
        make the datasets & seed each worker separately
        """
        if self.conf.seed is not None:
            seed = self.conf.seed * get_world_size() + self.global_rank
            np.random.seed(seed)
            random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            torch.backends.cudnn.benchmark = False
            torch.backends.cudnn.deterministic = True

        taskdata, tasklabels = loadDataset(self.conf)
        assert self.conf.num_users == tasklabels['user'].nunique()

        # map raw user identifiers to contiguous integer class codes
        tasklabels = pd.DataFrame(tasklabels['user'].astype('category').cat.codes.values)[0].values.astype(int)
        assert self.conf.num_users == len(np.unique(tasklabels))

        self.train_data = self.conf.make_dataset(taskdata, tasklabels)
        self.val_data = self.train_data

    def train_dataloader(self):
        """
        return the dataloader:
        if diffusion mode => return the training dataset
        if latent mode => return the inferred latent dataset
        """
        if self.conf.train_mode.require_dataset_infer():
            if self.conds is None:
                # infer the latents for the whole dataset once and cache their stats
                self.conds = self.infer_whole_dataset()
                self.conds_mean.data = self.conds.float().mean(dim=0, keepdim=True)
                self.conds_std.data = self.conds.float().std(dim=0, keepdim=True)
            print('mean:', self.conds_mean.mean(), 'std:', self.conds_std.mean())

            conf = self.conf.clone()
            conf.batch_size = self.batch_size
            data = TensorDataset(self.conds)
            return conf.make_loader(data, shuffle=True)
        else:
            return torch.utils.data.DataLoader(self.train_data,
                                               batch_size=self.conf.batch_size,
                                               shuffle=True)

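    # e.g. a global batch size of 128 across 4 workers yields a local batch size of
    # 32 per worker; conf.batch_size must be divisible by the world size.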
    @property
    def batch_size(self):
        """
        local batch size for each worker
        """
        ws = get_world_size()
        assert self.conf.batch_size % ws == 0
        return self.conf.batch_size // ws

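    # num_samples converts optimizer steps into samples seen so far; train() applies
    # the inverse (samples // batch_size_effective) for max_steps and the checkpoint cadence.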
    @property
    def num_samples(self):
        """
        (global) batch size * iterations
        """
        return self.global_step * self.conf.batch_size_effective

    def is_last_accum(self, batch_idx):
        """
        is this the last gradient accumulation loop?
        used with gradient_accum > 1 to check whether the optimizer will step in this iteration
        """
        return (batch_idx + 1) % self.conf.accum_batches == 0

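    # Lightning drives the optimization: training_step only returns the loss, and the
    # trainer handles backward, gradient accumulation (accumulate_grad_batches) and
    # optimizer.step(); the EMA update happens afterwards in on_train_batch_end().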
    def training_step(self, batch, batch_idx):
        """
        given an input, calculate the loss function
        no optimization at this stage.
        """
        with amp.autocast(False):
            if self.conf.train_mode.require_dataset_infer():
                # the batch is already a batch of inferred latent codes
                cond = batch[0]
                if self.conf.latent_znormalize:
                    cond = (cond - self.conds_mean.to(self.device)) / self.conds_std.to(self.device)
            else:
                imgs = batch[0]
                x_start = imgs

            if self.conf.train_mode == TrainMode.diffusion:
                # sample the diffusion timesteps (and their weights) for this batch
                t, weight = self.T_sampler.sample(len(x_start), x_start.device)
                losses = self.sampler.training_losses(model=self.model,
                                                      x_start=x_start,
                                                      t=t,
                                                      user_label=batch[1],
                                                      lossbetas=self.conf.betas)
            elif self.conf.train_mode.is_latent_diffusion():
                t, weight = self.T_sampler.sample(len(cond), cond.device)
                latent_losses = self.latent_sampler.training_losses(
                    model=self.model.latent_net, x_start=cond, t=t)
                losses = {
                    'latent': latent_losses['loss'],
                    'loss': latent_losses['loss']
                }
            else:
                raise NotImplementedError()

            loss = losses['loss'].mean()
            self.log("train_loss", loss)

        return {'loss': loss}

    def on_train_batch_end(self, outputs, batch, batch_idx: int,
                           dataloader_idx: int) -> None:
        """
        after each training step: save periodic checkpoints and update the EMA model
        """
        if self.is_last_accum(batch_idx):
            # save a full checkpoint at the end of every 10th epoch
            if (batch_idx == len(self.train_dataloader()) - 1) and ((self.current_epoch + 1) % 10 == 0):
                save_path = os.path.join(self.conf.logdir, 'epoch%d.ckpt' % self.current_epoch)
                torch.save({
                    'state_dict': self.state_dict(),
                    'global_step': self.global_step,
                    'loss': outputs['loss'],
                }, save_path)

            if self.conf.train_mode == TrainMode.latent_diffusion:
                # only the latent network is trained, so only its EMA copy is updated
                ema(self.model.latent_net, self.ema_model.latent_net,
                    self.conf.ema_decay)
            else:
                ema(self.model, self.ema_model, self.conf.ema_decay)

    def on_before_optimizer_step(self, optimizer: Optimizer,
                                 optimizer_idx: int) -> None:
        if self.conf.grad_clip > 0:
            params = [
                p for group in optimizer.param_groups for p in group['params']
            ]
            torch.nn.utils.clip_grad_norm_(params, max_norm=self.conf.grad_clip)

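    # configure_optimizers returns Lightning's dict format: the optimizer plus an
    # optional LR-scheduler entry; 'interval': 'step' applies the WarmupLR factor
    # every optimizer step rather than once per epoch.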
    def configure_optimizers(self):
        out = {}
        if self.conf.optimizer == OptimizerType.adam:
            optim = torch.optim.Adam(self.model.parameters(),
                                     lr=self.conf.lr,
                                     weight_decay=self.conf.weight_decay)
        elif self.conf.optimizer == OptimizerType.adamw:
            optim = torch.optim.AdamW(self.model.parameters(),
                                      lr=self.conf.lr,
                                      weight_decay=self.conf.weight_decay)
        else:
            raise NotImplementedError()
        out['optimizer'] = optim
        if self.conf.warmup > 0:
            sched = torch.optim.lr_scheduler.LambdaLR(optim,
                                                      lr_lambda=WarmupLR(self.conf.warmup))
            out['lr_scheduler'] = {
                'scheduler': sched,
                'interval': 'step',
            }
        return out

    def split_tensor(self, x):
        """
        extract the slice of the batch belonging to this worker (rank)

        Args:
            x: (n, c)

        Returns: x: (n_local, c)
        """
        n = len(x)
        rank = self.global_rank
        world_size = get_world_size()
        per_rank = n // world_size
        return x[rank * per_rank:(rank + 1) * per_rank]
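    # split_tensor example: with n = 8 rows and world_size = 2, rank 0 receives rows
    # 0-3 and rank 1 receives rows 4-7; remainder rows (when n is not divisible by
    # world_size) are dropped.

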
def ema(source, target, decay):
    source_dict = source.state_dict()
    target_dict = target.state_dict()
    for key in source_dict.keys():
        target_dict[key].data.copy_(target_dict[key].data * decay +
                                    source_dict[key].data * (1 - decay))

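# Example EMA update: with decay = 0.999 every parameter of `target` becomes
# 0.999 * target + 0.001 * source, i.e. a slowly moving average of the online
# model's weights.

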
class WarmupLR:
    def __init__(self, warmup) -> None:
        self.warmup = warmup

    def __call__(self, step):
        return min(step, self.warmup) / self.warmup

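# WarmupLR example: with warmup = 5000, the LR factor is 0.2 at step 1000, reaches
# 1.0 at step 5000, and stays at 1.0 afterwards (linear warmup, no decay).

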
def is_time(num_samples, every, step_size):
    closest = (num_samples // every) * every
    return num_samples - closest < step_size

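# is_time() is True when num_samples has just crossed a multiple of `every`, i.e.
# lies within one optimizer step (step_size samples) of that multiple; it appears
# unused in this file.

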
def train(conf: TrainConfig, model: LitModel, gpus, nodes=1):
    checkpoint = ModelCheckpoint(dirpath=conf.logdir,
                                 save_last=True,
                                 save_top_k=1,
                                 every_n_train_steps=conf.save_every_samples //
                                 conf.batch_size_effective)
    checkpoint_path = os.path.join(conf.logdir, 'last.ckpt')
    print('ckpt path:', checkpoint_path)
    if os.path.exists(checkpoint_path):
        resume = checkpoint_path
        print('resume!')
    else:
        if conf.continue_from is not None:
            resume = conf.continue_from.path
        else:
            resume = None

    plugins = []
    if len(gpus) == 1 and nodes == 1:
        accelerator = None
    else:
        accelerator = 'ddp'
        from pytorch_lightning.plugins import DDPPlugin

        plugins.append(DDPPlugin(find_unused_parameters=False))

    wandb_logger = pl_loggers.WandbLogger(project='dismouse',
                                          name='%s_%s' % (model.conf.pretrainDataset, conf.logdir.split('/')[-2]),
                                          log_model=True,
                                          save_dir=conf.logdir,
                                          dir=conf.logdir,
                                          config=vars(model.conf),
                                          save_code=True,
                                          settings=wandb.Settings(code_dir="."))

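    # Budgets are given in samples and converted to optimizer steps below:
    # max_steps = total_samples // batch_size_effective (batch_size_effective is
    # assumed to already account for gradient accumulation and world size).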
    trainer = pl.Trainer(
        max_steps=conf.total_samples // conf.batch_size_effective,
        resume_from_checkpoint=resume,
        gpus=gpus,
        num_nodes=nodes,
        accelerator=accelerator,
        precision=16 if conf.fp16 else 32,
        callbacks=[
            checkpoint,
            LearningRateMonitor(),
        ],
        replace_sampler_ddp=True,
        logger=wandb_logger,
        accumulate_grad_batches=conf.accum_batches,
        plugins=plugins,
    )

    trainer.fit(model)
    wandb.finish()
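

# Minimal launch sketch (hypothetical; the concrete TrainConfig values depend on
# config.py, which is not shown here):
#
#     conf = TrainConfig(...)          # or a config factory from config.py
#     model = LitModel(conf, betas)    # `betas` are the loss weights forwarded to training_losses
#     train(conf, model, gpus=[0])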