DisMouse/experiment.py

378 lines
14 KiB
Python
Raw Normal View History

2024-10-08 14:18:47 +02:00
import copy, wandb
import os
import random
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import *
from torch.cuda import amp
from torch.optim.optimizer import Optimizer
from torch.utils.data.dataset import TensorDataset
from config import *
from dataset import *
from dist_utils import *
def MakeDir(dirName):
    """Create directory ``dirName`` (including missing parents) if absent.

    Nothing happens when ``dirName`` already exists. ``exist_ok=True`` makes
    the call safe when another process creates the same directory between
    the existence check and the creation (for example several workers
    starting at once).

    Args:
        dirName: path of the directory to create.
    """
    if not os.path.exists(dirName):
        os.makedirs(dirName, exist_ok=True)
class LitModel(pl.LightningModule):
    """LightningModule wrapping a trainable model and its EMA copy.

    The module builds the model, the samplers and the optimizer from a
    ``TrainConfig`` and implements the training step for the diffusion and
    latent-diffusion training modes.
    """

    def __init__(self, conf: TrainConfig, betas):
        """Build the model, its EMA copy and the samplers from ``conf``.

        Args:
            conf: training configuration; it is mutated (``conf.betas`` is
                set) and kept as ``self.conf``.
            betas: stored on ``conf.betas`` and later passed as
                ``lossbetas`` to the sampler's ``training_losses``.
        """
        super().__init__()
        # NOTE(review): hyperparameters are saved twice, once from the raw
        # attributes and once from the JSON-able dict; presumably the second
        # call overrides overlapping keys - confirm both are wanted.
        self.save_hyperparameters(
            {k: v for (k, v) in vars(conf).items() if not callable(v)})
        self.save_hyperparameters(conf.as_dict_jsonable())
        assert conf.train_mode != TrainMode.manipulate
        if conf.seed is not None:
            pl.seed_everything(conf.seed)
        conf.betas = betas
        self.conf = conf

        self.model = conf.make_model_conf().make_model()
        # The EMA copy is never trained directly; it is only updated by ema().
        self.ema_model = copy.deepcopy(self.model)
        self.ema_model.requires_grad_(False)
        self.ema_model.eval()

        model_size = sum(
            param.data.nelement() for param in self.model.parameters())
        print('Model params: %.2f M' % (model_size / 1024 / 1024))

        self.sampler = conf.make_diffusion_conf().make_sampler()
        self.eval_sampler = conf.make_eval_diffusion_conf().make_sampler()
        self.T_sampler = conf.make_T_sampler()

        if conf.train_mode.use_latent_net():
            self.latent_sampler = conf.make_latent_diffusion_conf(
            ).make_sampler()
            self.eval_latent_sampler = conf.make_latent_eval_diffusion_conf(
            ).make_sampler()
        else:
            self.latent_sampler = None
            self.eval_latent_sampler = None

        if conf.pretrain is not None:
            print(f'loading pretrain ... {conf.pretrain.name}')
            state = torch.load(conf.pretrain.path, map_location='cpu')
            print('step:', state['global_step'])
            self.load_state_dict(state['state_dict'], strict=False)

        # Always define all three attributes so that later code can test
        # them against None instead of failing with AttributeError.
        self.conds = None
        self.conds_mean = None
        self.conds_std = None
        if conf.latent_infer_path is not None:
            print('loading latent stats ...')
            state = torch.load(conf.latent_infer_path)
            self.conds = state['conds']

    def normalize(self, cond):
        """Return ``(cond - conds_mean) / conds_std`` on ``self.device``."""
        cond = (cond - self.conds_mean.to(self.device)) / self.conds_std.to(
            self.device)
        return cond

    def denormalize(self, cond):
        """Inverse of :meth:`normalize`: ``cond * conds_std + conds_mean``."""
        cond = (cond * self.conds_std.to(self.device)) + self.conds_mean.to(
            self.device)
        return cond

    def render(self, noise, cond=None, T=None):
        """Generate a sample from ``noise`` with the EMA model.

        Args:
            noise: starting noise handed to the render helper.
            cond: optional condition; selects ``render_condition`` when
                given, ``render_uncondition`` otherwise.
            T: when given, a fresh sampler is built from
                ``conf._make_diffusion_conf(T)``; otherwise the evaluation
                sampler is used.
        """
        if T is None:
            sampler = self.eval_sampler
        else:
            sampler = self.conf._make_diffusion_conf(T).make_sampler()

        if cond is not None:
            pred_img = render_condition(self.conf,
                                        self.ema_model,
                                        noise,
                                        sampler=sampler,
                                        cond=cond)
        else:
            pred_img = render_uncondition(self.conf,
                                          self.ema_model,
                                          noise,
                                          sampler=sampler,
                                          latent_sampler=None)
        return pred_img

    def encode(self, x):
        """Run the EMA model's encoder on ``x`` and return its output."""
        assert self.conf.model_type.has_autoenc()
        cond = self.ema_model.encoder.forward(x)
        return cond

    def encode_stochastic(self, x, cond, T=None):
        """Run the sampler's DDIM reverse loop on ``x`` given ``cond``.

        Returns:
            Tuple of the loop output's ``'sample'`` and ``'xstart_t'``
            entries.
        """
        if T is None:
            sampler = self.eval_sampler
        else:
            sampler = self.conf._make_diffusion_conf(T).make_sampler()
        out = sampler.ddim_reverse_sample_loop(self.ema_model,
                                               x,
                                               model_kwargs={'cond': cond})
        return out['sample'], out['xstart_t']

    def forward(self, noise=None, x_start=None, ema_model: bool = False):
        """Sample with the evaluation sampler (autocast disabled).

        Args:
            noise: forwarded to ``eval_sampler.sample``.
            x_start: forwarded to ``eval_sampler.sample``.
            ema_model: use the EMA copy instead of the trained model.
        """
        with amp.autocast(False):
            model = self.ema_model if ema_model else self.model
            gen = self.eval_sampler.sample(model=model,
                                           noise=noise,
                                           x_start=x_start)
            return gen

    def setup(self, stage=None) -> None:
        """
        make datasets & seeding each worker separately
        """
        if self.conf.seed is not None:
            # A different seed per rank, derived from the configured seed.
            seed = self.conf.seed * get_world_size() + self.global_rank
            np.random.seed(seed)
            random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            torch.backends.cudnn.benchmark = False
            torch.backends.cudnn.deterministic = True

        taskdata, tasklabels = loadDataset(self.conf)
        assert self.conf.num_users == tasklabels['user'].nunique()
        # Map the 'user' column to integer category codes.
        tasklabels = tasklabels['user'].astype(
            'category').cat.codes.values.astype(int)
        assert self.conf.num_users == len(np.unique(tasklabels))
        self.train_data = self.conf.make_dataset(taskdata, tasklabels)
        # Validation reuses the training set.
        self.val_data = self.train_data

    def train_dataloader(self):
        """
        Return the training dataloader.

        When the train mode requires inferred latents, the loader iterates
        over ``self.conds`` (inferring them first if they were not loaded);
        otherwise it iterates over ``self.train_data``.
        """
        if self.conf.train_mode.require_dataset_infer():
            if self.conds is None:
                # NOTE(review): infer_whole_dataset is not defined in the
                # visible part of this class - confirm it exists elsewhere.
                self.conds = self.infer_whole_dataset()
            if self.conds_mean is None or self.conds_std is None:
                # Statistics are computed in float32 over the first dim.
                conds_f32 = self.conds.float()
                self.conds_mean = conds_f32.mean(dim=0, keepdim=True)
                self.conds_std = conds_f32.std(dim=0, keepdim=True)
            print('mean:', self.conds_mean.mean(), 'std:',
                  self.conds_std.mean())

            conf = self.conf.clone()
            conf.batch_size = self.batch_size
            data = TensorDataset(self.conds)
            return conf.make_loader(data, shuffle=True)
        else:
            # NOTE(review): this uses the configured batch size, not the
            # per-worker self.batch_size used by the latent branch - confirm.
            return torch.utils.data.DataLoader(self.train_data,
                                               batch_size=self.conf.batch_size,
                                               shuffle=True)

    @property
    def batch_size(self):
        """
        local batch size for each worker
        """
        ws = get_world_size()
        assert self.conf.batch_size % ws == 0
        return self.conf.batch_size // ws

    @property
    def num_samples(self):
        """
        (global) batch size * iterations
        """
        return self.global_step * self.conf.batch_size_effective

    def is_last_accum(self, batch_idx):
        """
        is it the last gradient accumulation loop?
        used with gradient_accum > 1 and to see if the optimizer will perform "step" in this iteration or not
        """
        return (batch_idx + 1) % self.conf.accum_batches == 0

    def training_step(self, batch, batch_idx):
        """
        given an input, calculate the loss function
        no optimization at this stage.

        Returns:
            ``{'loss': loss}`` where ``loss`` is the mean of the sampler's
            ``'loss'`` entry; the same value is logged as ``train_loss``.
        """
        with amp.autocast(False):
            if self.conf.train_mode.require_dataset_infer():
                # Latent mode: the batch holds inferred latents.
                cond = batch[0]
                if self.conf.latent_znormalize:
                    cond = self.normalize(cond)
            else:
                x_start = batch[0]

            if self.conf.train_mode == TrainMode.diffusion:
                t, _ = self.T_sampler.sample(len(x_start), x_start.device)
                losses = self.sampler.training_losses(
                    model=self.model,
                    x_start=x_start,
                    t=t,
                    user_label=batch[1],
                    lossbetas=self.conf.betas)
            elif self.conf.train_mode.is_latent_diffusion():
                t, _ = self.T_sampler.sample(len(cond), cond.device)
                latent_losses = self.latent_sampler.training_losses(
                    model=self.model.latent_net, x_start=cond, t=t)
                losses = {
                    'latent': latent_losses['loss'],
                    'loss': latent_losses['loss']
                }
            else:
                raise NotImplementedError()

            loss = losses['loss'].mean()
            self.log("train_loss", loss)
        return {'loss': loss}

    def on_train_batch_end(self, outputs, batch, batch_idx: int,
                           dataloader_idx: int) -> None:
        """
        after each training step

        On the last accumulation batch this updates the EMA model, and on
        the final batch of every 10th epoch it also writes
        ``epoch<N>.ckpt`` into ``conf.logdir``.
        """
        if self.is_last_accum(batch_idx):
            # Ask the trainer for the batch count instead of rebuilding the
            # dataloader on every step; this is the count that batch_idx
            # is actually compared against.
            is_final_batch = batch_idx == self.trainer.num_training_batches - 1
            if is_final_batch and ((self.current_epoch + 1) % 10 == 0):
                save_path = os.path.join(self.conf.logdir,
                                         'epoch%d.ckpt' % self.current_epoch)
                torch.save(
                    {
                        'state_dict': self.state_dict(),
                        'global_step': self.global_step,
                        'loss': outputs['loss'],
                    }, save_path)

            if self.conf.train_mode == TrainMode.latent_diffusion:
                # Only the latent network is tracked in this mode.
                ema(self.model.latent_net, self.ema_model.latent_net,
                    self.conf.ema_decay)
            else:
                ema(self.model, self.ema_model, self.conf.ema_decay)

    def on_before_optimizer_step(self, optimizer: Optimizer,
                                 optimizer_idx: int) -> None:
        """Clip the gradient norm to ``conf.grad_clip`` when it is > 0."""
        if self.conf.grad_clip > 0:
            params = [
                p for group in optimizer.param_groups for p in group['params']
            ]
            torch.nn.utils.clip_grad_norm_(params,
                                           max_norm=self.conf.grad_clip)

    def configure_optimizers(self):
        """Build the optimizer and, if ``conf.warmup > 0``, a warmup schedule.

        Returns:
            Dict with key ``'optimizer'`` and optionally ``'lr_scheduler'``
            (a per-step ``LambdaLR`` driven by ``WarmupLR``).

        Raises:
            NotImplementedError: for optimizer types other than adam/adamw.
        """
        out = {}
        if self.conf.optimizer == OptimizerType.adam:
            optim = torch.optim.Adam(self.model.parameters(),
                                     lr=self.conf.lr,
                                     weight_decay=self.conf.weight_decay)
        elif self.conf.optimizer == OptimizerType.adamw:
            optim = torch.optim.AdamW(self.model.parameters(),
                                      lr=self.conf.lr,
                                      weight_decay=self.conf.weight_decay)
        else:
            raise NotImplementedError()
        out['optimizer'] = optim

        if self.conf.warmup > 0:
            sched = torch.optim.lr_scheduler.LambdaLR(optim,
                                                      lr_lambda=WarmupLR(
                                                          self.conf.warmup))
            out['lr_scheduler'] = {
                'scheduler': sched,
                'interval': 'step',
            }
        return out

    def split_tensor(self, x):
        """
        extract the tensor for a corresponding "worker" in the batch dimension
        Args:
            x: (n, c)
        Returns: x: (n_local, c)
        """
        n = len(x)
        rank = self.global_rank
        world_size = get_world_size()
        per_rank = n // world_size
        return x[rank * per_rank:(rank + 1) * per_rank]
def ema(source, target, decay):
    """Update ``target``'s state in place as a moving average of ``source``.

    For every key of ``source.state_dict()`` the matching entry of
    ``target.state_dict()`` becomes
    ``target * decay + source * (1 - decay)``.

    Entries that are neither floating point nor complex (integer counters,
    boolean masks) are copied from ``source`` unchanged: blending them in
    float and casting back would truncate the result to a meaningless value.

    Args:
        source: module providing the fresh values.
        target: module whose state is overwritten in place.
        decay: weight kept from ``target``'s current value.
    """
    source_dict = source.state_dict()
    target_dict = target.state_dict()
    for key, src in source_dict.items():
        tgt = target_dict[key]
        if tgt.is_floating_point() or tgt.is_complex():
            tgt.data.copy_(tgt.data * decay + src.data * (1 - decay))
        else:
            tgt.data.copy_(src.data)
class WarmupLR:
    """Callable learning-rate multiplier that ramps linearly up to 1.

    The returned factor is ``step / warmup`` while ``step`` is below
    ``warmup`` and ``1`` from then on.
    """

    def __init__(self, warmup) -> None:
        # Number of steps over which the factor grows to 1.
        self.warmup = warmup

    def __call__(self, step):
        limit = self.warmup
        # Cap the step at the warmup length before scaling.
        capped = limit if limit < step else step
        return capped / limit
def is_time(num_samples, every, step_size):
    """Tell whether ``num_samples`` lies within ``step_size`` of a multiple of ``every``.

    The distance is measured from the largest multiple of ``every`` that
    does not exceed ``num_samples``.

    Returns:
        True when that distance is strictly smaller than ``step_size``.
    """
    periods_done = num_samples // every
    last_multiple = periods_done * every
    overshoot = num_samples - last_multiple
    return overshoot < step_size
def train(conf: TrainConfig, model: LitModel, gpus, nodes=1):
    """Fit ``model`` with a Lightning trainer, logging to Weights & Biases.

    Training resumes from ``<conf.logdir>/last.ckpt`` when that file exists,
    otherwise from ``conf.continue_from.path`` when configured, otherwise it
    starts from scratch.

    Args:
        conf: training configuration (log dir, batch sizes, sample budget,
            precision, accumulation).
        model: the module to train.
        gpus: GPU list passed to the trainer; its length selects between
            single-device and DDP execution.
        nodes: number of nodes; anything other than 1 also selects DDP.
    """
    checkpoint = ModelCheckpoint(dirpath=conf.logdir,
                                 save_last=True,
                                 save_top_k=1,
                                 every_n_train_steps=conf.save_every_samples //
                                 conf.batch_size_effective)

    # Join the path so the lookup matches where ModelCheckpoint writes,
    # whether or not logdir ends with a separator.
    checkpoint_path = os.path.join(conf.logdir, 'last.ckpt')
    print('ckpt path:', checkpoint_path)
    if os.path.exists(checkpoint_path):
        resume = checkpoint_path
        print('resume!')
    elif conf.continue_from is not None:
        resume = conf.continue_from.path
    else:
        resume = None

    plugins = []
    if len(gpus) == 1 and nodes == 1:
        accelerator = None
    else:
        accelerator = 'ddp'
        from pytorch_lightning.plugins import DDPPlugin
        plugins.append(DDPPlugin(find_unused_parameters=False))

    # Last component of logdir; same result as split('/')[-2] for a path
    # with a trailing slash, but also valid without one.
    run_dir_name = os.path.basename(os.path.normpath(conf.logdir))
    wandb_logger = pl_loggers.WandbLogger(
        project='dismouse',
        name='%s_%s' % (model.conf.pretrainDataset, run_dir_name),
        log_model=True,
        save_dir=conf.logdir,
        dir=conf.logdir,
        config=vars(model.conf),
        save_code=True,
        settings=wandb.Settings(code_dir="."))

    trainer = pl.Trainer(
        # The sample budget is converted to optimizer steps.
        max_steps=conf.total_samples // conf.batch_size_effective,
        resume_from_checkpoint=resume,
        gpus=gpus,
        num_nodes=nodes,
        accelerator=accelerator,
        precision=16 if conf.fp16 else 32,
        callbacks=[
            checkpoint,
            LearningRateMonitor(),
        ],
        replace_sampler_ddp=True,
        logger=wandb_logger,
        accumulate_grad_batches=conf.accum_batches,
        plugins=plugins,
    )
    trainer.fit(model)
    wandb.finish()