# VDGR/utils/optim_utils.py

import logging
import math
import numpy as np
import random
import functools
import glog as log
import torch
from torch import nn, optim
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler, ConstantLR
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
from pytorch_transformers.optimization import AdamW
class WarmupLinearScheduleNonZero(_LRScheduler):
""" Linear warmup and then linear decay.
Linearly increases learning rate from 0 to max_lr over `warmup_steps` training steps.
Linearly decreases learning rate linearly to min_lr over remaining `t_total - warmup_steps` steps.
"""
def __init__(self, optimizer, warmup_steps, t_total, min_lr=1e-5, last_epoch=-1):
self.warmup_steps = warmup_steps
self.t_total = t_total
self.min_lr = min_lr
super(WarmupLinearScheduleNonZero, self).__init__(optimizer, last_epoch=last_epoch)
def get_lr(self):
step = self.last_epoch
if step < self.warmup_steps:
lr_factor = float(step) / float(max(1, self.warmup_steps))
else:
lr_factor = max(0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps)))
return [base_lr * lr_factor if (base_lr * lr_factor) > self.min_lr else self.min_lr for base_lr in self.base_lrs]
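# A minimal usage sketch of the scheduler above (illustrative only; `model`,
# `loader` and the hyper-parameter values are assumptions, not project defaults).
# The learning rate rises linearly towards the base LR of 2e-4 over the first
# 100 steps, then decays linearly, and is floored at min_lr=1e-5 throughout:
#
#   optimizer = optim.AdamW(model.parameters(), lr=2e-4)
#   scheduler = WarmupLinearScheduleNonZero(optimizer, warmup_steps=100,
#                                           t_total=1000, min_lr=1e-5)
#   for batch in loader:
#       ...                      # forward/backward
#       optimizer.step()
#       scheduler.step()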
def init_optim(model, config):
optimizer_grouped_parameters = []
gnn_params = []
encoder_params_with_decay = []
encoder_params_without_decay = []
exclude_from_weight_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
for module_name, module in model.named_children():
for param_name, param in module.named_parameters():
if param.requires_grad:
if "gnn" in param_name:
gnn_params.append(param)
elif module_name == 'encoder':
if any(ex in param_name for ex in exclude_from_weight_decay):
encoder_params_without_decay.append(param)
else:
encoder_params_with_decay.append(param)
optimizer_grouped_parameters = [
{
'params': gnn_params,
'weight_decay': config.gnn_weight_decay,
'lr': config['learning_rate_gnn'] if config.use_diff_lr_gnn else config['learning_rate_bert']
}
]
optimizer_grouped_parameters.extend(
[
{
'params': encoder_params_without_decay,
'weight_decay': 0,
'lr': config['learning_rate_bert']
},
{
'params': encoder_params_with_decay,
'weight_decay': 0.01,
'lr': config['learning_rate_bert']
}
]
)
optimizer = AdamW(optimizer_grouped_parameters, lr=config['learning_rate_gnn'])
scheduler = WarmupLinearScheduleNonZero(
optimizer,
warmup_steps=config['warmup_steps'],
t_total=config['train_steps'],
min_lr=config['min_lr']
)
return optimizer, scheduler
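# Sketch of how init_optim is typically wired up (illustrative only; the values
# below are assumptions, not the project's defaults). Note that init_optim mixes
# attribute access (config.gnn_weight_decay) with item access
# (config['learning_rate_gnn']), so `config` must be a dict-like object that
# supports both, e.g. an EasyDict:
#
#   from easydict import EasyDict
#   config = EasyDict({'learning_rate_gnn': 5e-4, 'learning_rate_bert': 2e-5,
#                      'gnn_weight_decay': 0.01, 'use_diff_lr_gnn': True,
#                      'warmup_steps': 1000, 'train_steps': 20000,
#                      'min_lr': 1e-5})
#   optimizer, scheduler = init_optim(model, config)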
def build_torch_optimizer(model, config):
"""Builds the PyTorch optimizer.
    We use the default parameters for Adam that are suggested by the original
    paper (https://arxiv.org/pdf/1412.6980.pdf). These values are also used by
    other established implementations, e.g.
    https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer and
    https://keras.io/optimizers/.
    The paper "Attention Is All You Need" (https://arxiv.org/pdf/1706.03762.pdf)
    uses slightly different values, in particular beta2=0.98; however,
    beta2=0.999 is still arguably the more established choice, so we use it
    here as well.
Args:
model: The model to optimize.
config: The dictionary of options.
    Returns:
        A ``torch.optim.Optimizer`` instance, or a ``MultipleOptimizer`` wrapping
        separate BERT and task optimizers when the model has an encoder module.
"""
betas = [0.9, 0.999]
exclude_from_weight_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
params = {'bert': [], 'task': []}
for module_name, module in model.named_children():
if module_name == 'encoder':
param_type = 'bert'
else:
param_type = 'task'
for param_name, param in module.named_parameters():
if param.requires_grad:
if any(ex in param_name for ex in exclude_from_weight_decay):
params[param_type] += [
{
"params": [param],
"weight_decay": 0
}
]
else:
params[param_type] += [
{
"params": [param],
"weight_decay": 0.01
}
]
if config['task_optimizer'] == 'adamw':
log.info('Using AdamW as task optimizer')
task_optimizer = AdamWeightDecay(params['task'],
lr=config["learning_rate_task"],
betas=betas,
eps=1e-6)
    elif config['task_optimizer'] == 'adam':
        log.info('Using Adam as task optimizer')
        task_optimizer = optim.Adam(params['task'],
                                    lr=config["learning_rate_task"],
                                    betas=betas,
                                    eps=1e-6)
    else:
        raise ValueError(f"Unknown task optimizer: {config['task_optimizer']}")
if len(params['bert']) > 0:
bert_optimizer = AdamWeightDecay(params['bert'],
lr=config["learning_rate_bert"],
betas=betas,
eps=1e-6)
optimizer = MultipleOptimizer([bert_optimizer, task_optimizer])
else:
optimizer = task_optimizer
return optimizer
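# Illustrative call (the config values are assumptions): if the model exposes an
# 'encoder' child module, this returns a MultipleOptimizer holding one optimizer
# for the encoder parameters and one for the remaining task parameters, each
# with its own learning rate; otherwise a single task optimizer is returned.
#
#   optimizer = build_torch_optimizer(
#       model, {'task_optimizer': 'adam',
#               'learning_rate_task': 1e-4,
#               'learning_rate_bert': 2e-5})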
def make_learning_rate_decay_fn(decay_method, train_steps, **kwargs):
"""Returns the learning decay function from options."""
if decay_method == "linear":
return functools.partial(
linear_decay,
global_steps=train_steps,
**kwargs)
elif decay_method == "exp":
return functools.partial(
exp_decay,
global_steps=train_steps,
**kwargs)
else:
raise ValueError(f'{decay_method} not found')
def linear_decay(step, global_steps, warmup_steps=100, initial_learning_rate=1, end_learning_rate=0, **kwargs):
if step < warmup_steps:
return initial_learning_rate * step / warmup_steps
else:
return (initial_learning_rate - end_learning_rate) * \
(1 - (step - warmup_steps) / (global_steps - warmup_steps)) + \
end_learning_rate
def exp_decay(step, global_steps, decay_exp=1, warmup_steps=100, initial_learning_rate=1, end_learning_rate=0, **kwargs):
if step < warmup_steps:
return initial_learning_rate * step / warmup_steps
else:
return (initial_learning_rate - end_learning_rate) * \
((1 - (step - warmup_steps) / (global_steps - warmup_steps)) ** decay_exp) + \
end_learning_rate
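# For intuition (worked values; warmup_steps=100 and global_steps=1000 are chosen
# purely for illustration): these functions return scaling factors that
# OptimizerBase.learning_rate() below multiplies with the base learning rates.
#
#   linear_decay(50, 1000)   # -> 0.5, halfway through warmup
#   linear_decay(100, 1000)  # -> 1.0, warmup finished
#   linear_decay(550, 1000)  # -> 0.5, halfway through the decay phase
#
# exp_decay matches linear_decay for decay_exp=1 and decays faster for
# decay_exp > 1.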
class MultipleOptimizer(object):
""" Implement multiple optimizers needed for sparse adam """
def __init__(self, op):
""" ? """
self.optimizers = op
@property
def param_groups(self):
param_groups = []
for optimizer in self.optimizers:
param_groups.extend(optimizer.param_groups)
return param_groups
def zero_grad(self):
""" ? """
for op in self.optimizers:
op.zero_grad()
def step(self):
""" ? """
for op in self.optimizers:
op.step()
@property
def state(self):
""" ? """
return {k: v for op in self.optimizers for k, v in op.state.items()}
def state_dict(self):
""" ? """
return [op.state_dict() for op in self.optimizers]
def load_state_dict(self, state_dicts):
""" ? """
assert len(state_dicts) == len(self.optimizers)
for i in range(len(state_dicts)):
self.optimizers[i].load_state_dict(state_dicts[i])
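# Minimal sketch (`bert_params` and `task_params` are assumed parameter lists):
# the wrapper lets the pair be driven exactly like a single optimizer.
#
#   multi = MultipleOptimizer([
#       AdamW(bert_params, lr=2e-5),        # pytorch_transformers AdamW, imported above
#       optim.Adam(task_params, lr=1e-4),
#   ])
#   multi.zero_grad()
#   loss.backward()
#   multi.step()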
class OptimizerBase(object):
"""
    Controller class for optimization. Mostly a thin wrapper around `optim`,
    but also useful for implementing learning-rate scheduling beyond what is
    currently available, as well as gradient manipulations such as global-norm
    clipping.
"""
def __init__(self,
optimizer,
learning_rate,
learning_rate_decay_fn=None,
max_grad_norm=None):
"""Initializes the controller.
Args:
optimizer: A ``torch.optim.Optimizer`` instance.
learning_rate: The initial learning rate.
            learning_rate_decay_fn: An optional callable (or list of callables, one per
                wrapped optimizer) taking the current step as argument and returning a
                learning rate scaling factor.
max_grad_norm: Clip gradients to this global norm.
"""
self._optimizer = optimizer
self._learning_rate = learning_rate
self._learning_rate_decay_fn = learning_rate_decay_fn
self._max_grad_norm = max_grad_norm or 0
self._training_step = 1
self._decay_step = 1
@classmethod
def from_opt(cls, model, config, checkpoint=None):
"""Builds the optimizer from options.
Args:
cls: The ``Optimizer`` class to instantiate.
model: The model to optimize.
            config: The dict of user options.
checkpoint: An optional checkpoint to load states from.
Returns:
An ``Optimizer`` instance.
"""
optim_opt = config
optim_state_dict = None
if config["loads_ckpt"] and checkpoint is not None:
optim = checkpoint['optim']
ckpt_opt = checkpoint['opt']
ckpt_state_dict = {}
if isinstance(optim, Optimizer): # Backward compatibility.
ckpt_state_dict['training_step'] = optim._step + 1
ckpt_state_dict['decay_step'] = optim._step + 1
ckpt_state_dict['optimizer'] = optim.optimizer.state_dict()
else:
ckpt_state_dict = optim
if config["reset_optim"] == 'none':
# Load everything from the checkpoint.
optim_opt = ckpt_opt
optim_state_dict = ckpt_state_dict
elif config["reset_optim"] == 'all':
# Build everything from scratch.
pass
elif config["reset_optim"] == 'states':
# Reset optimizer, keep options.
optim_opt = ckpt_opt
optim_state_dict = ckpt_state_dict
del optim_state_dict['optimizer']
elif config["reset_optim"] == 'keep_states':
# Reset options, keep optimizer.
optim_state_dict = ckpt_state_dict
learning_rates = [
optim_opt["learning_rate_bert"],
optim_opt["learning_rate_gnn"]
]
decay_fn = [
make_learning_rate_decay_fn(optim_opt['decay_method_bert'],
optim_opt['train_steps'],
warmup_steps=optim_opt['warmup_steps'],
decay_exp=optim_opt['decay_exp']),
make_learning_rate_decay_fn(optim_opt['decay_method_gnn'],
optim_opt['train_steps'],
warmup_steps=optim_opt['warmup_steps'],
decay_exp=optim_opt['decay_exp']),
]
optimizer = cls(
build_torch_optimizer(model, optim_opt),
learning_rates,
learning_rate_decay_fn=decay_fn,
max_grad_norm=optim_opt["max_grad_norm"])
if optim_state_dict:
optimizer.load_state_dict(optim_state_dict)
return optimizer
@property
def training_step(self):
"""The current training step."""
return self._training_step
def learning_rate(self):
"""Returns the current learning rate."""
if self._learning_rate_decay_fn is None:
return self._learning_rate
return [decay_fn(self._decay_step) * learning_rate \
for decay_fn, learning_rate in \
zip(self._learning_rate_decay_fn, self._learning_rate)]
def state_dict(self):
return {
'training_step': self._training_step,
'decay_step': self._decay_step,
'optimizer': self._optimizer.state_dict()
}
def load_state_dict(self, state_dict):
self._training_step = state_dict['training_step']
# State can be partially restored.
if 'decay_step' in state_dict:
self._decay_step = state_dict['decay_step']
if 'optimizer' in state_dict:
self._optimizer.load_state_dict(state_dict['optimizer'])
def zero_grad(self):
"""Zero the gradients of optimized parameters."""
self._optimizer.zero_grad()
def backward(self, loss):
"""Wrapper for backward pass. Some optimizer requires ownership of the
backward pass."""
loss.backward()
def step(self):
"""Update the model parameters based on current gradients.
Optionally, will employ gradient modification or update learning
rate.
"""
learning_rate = self.learning_rate()
if isinstance(self._optimizer, MultipleOptimizer):
optimizers = self._optimizer.optimizers
else:
optimizers = [self._optimizer]
for lr, op in zip(learning_rate, optimizers):
for group in op.param_groups:
group['lr'] = lr
if self._max_grad_norm > 0:
clip_grad_norm_(group['params'], self._max_grad_norm)
self._optimizer.step()
self._decay_step += 1
self._training_step += 1
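# End-to-end sketch of the intended training loop (hedged; `model`, `config` and
# `batches` are assumed to exist, and `config` must provide the keys read in
# from_opt and build_torch_optimizer above):
#
#   optimizer = OptimizerBase.from_opt(model, config)
#   for batch in batches:
#       optimizer.zero_grad()
#       loss = model(batch)
#       optimizer.backward(loss)
#       optimizer.step()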